Exploring Image Processing and 3D Visualization of JWST Images in Python

JWST Featured Image

Introduction

In this article, we walk through a Python script that processes and visualizes an image from the James Webb Space Telescope (JWST). The script smooths the image with filters, detects local maxima, labels them as candidate features, and plots those features in three dimensions. A final section shows how a pre-trained convolutional neural network (CNN) can extract features from the same image. We break the code into parts and explain what each one does and when you might use it.

Importing the Required Libraries

The code begins by making sure the libraries it needs for image processing, visualization, and numerical computation are available. These are NumPy, SciPy, imageio, scikit-image, Matplotlib, Pillow (imported as PIL), PyTorch (torch), and torchvision. A small loop installs any of them that are missing.

import importlib
import subprocess
import sys

# Map each pip package name to the module name used to import it
libraries = {
    "numpy": "numpy",
    "scipy": "scipy",
    "imageio": "imageio",
    "scikit-image": "skimage",
    "matplotlib": "matplotlib",
    "Pillow": "PIL",
    "torch": "torch",
    "torchvision": "torchvision",
}

# Check each library and install it with pip if the import fails
for package, module in libraries.items():
    try:
        importlib.import_module(module)
        print(f"{package} is already installed.")
    except ImportError:
        print(f"{package} is not installed. Installing now...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
        importlib.invalidate_caches()  # make the newly installed module importable

# Import the libraries only after the check above has run
import imageio.v3 as iio
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from mpl_toolkits.mplot3d import Axes3D  # registers the 3D projection on older Matplotlib versions
from scipy.ndimage import gaussian_filter, median_filter, maximum_filter, label

In the snippet above, each pip package name is paired with the module name used to import it, because the two can differ (scikit-image is imported as skimage, Pillow as PIL). The loop tries to import each module. If the import succeeds, it prints that the package is already installed; if it raises ImportError, it installs the package with pip using the current Python interpreter. The remaining imports come after the loop, so they run only once every dependency is available.

Loading and Displaying the Image

The next section loads and displays the original image. The image is read with iio.imread() from imageio's v3 API (imported as iio), which returns the pixel data as a NumPy array stored in the image variable.

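image_path = ""  # Set this to the path of your image file
image = iio.imread(image_path)

# Convert color images (H x W x 3 or 4) to a single grayscale channel,
# so the filters below work on intensity instead of blurring across color channels
if image.ndim == 3:
    image = image[..., :3].mean(axis=2)
image = image.astype(float)

# Display the original image
plt.imshow(image, cmap='gray')
plt.title('Original Image')
plt.show()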

The loaded image is then displayed using the plt.imshow() function, which visualizes the image with a grayscale colormap. The plt.title() function sets the title of the displayed image to “Original Image.” Finally, the plt.show() function is called to render the image.

Applying Filters for Image Enhancement

After displaying the original image, the code proceeds to apply filters to enhance the image. Two filters, namely the Gaussian filter and the median filter, are utilized for this purpose.

Gaussian Filter

The Gaussian filter smooths the image and reduces noise. It is applied with the gaussian_filter() function from scipy.ndimage, where sigma=2 sets the standard deviation of the Gaussian kernel in pixels (larger values smooth more strongly). The filtered image is stored back in the image variable.

image = gaussian_filter(image, sigma=2)  # Apply Gaussian filter to smooth the image
# Display the image after Gaussian filter
plt.imshow(image, cmap='gray')
plt.title('Image after Gaussian filter')
plt.show()

The filtered image is then displayed with the same plt.imshow() call, and the title is set to “Image after Gaussian filter.”
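To see what sigma controls, the minimal sketch below (an illustration, not part of the original script) filters a synthetic image that contains a single bright pixel. gaussian_filter() spreads that pixel into a blob whose peak drops as sigma grows, while the total brightness stays essentially unchanged. The array size and sigma values are arbitrary choices for the demo.

```python
import numpy as np
from scipy.ndimage import gaussian_filter

# Synthetic 21x21 image with one bright "star" in the center
impulse = np.zeros((21, 21))
impulse[10, 10] = 1.0

# Blur the same impulse with two different kernel widths
results = {sigma: gaussian_filter(impulse, sigma=sigma) for sigma in (1, 2)}

for sigma, blurred in results.items():
    # The peak falls as the light spreads out; the total stays close to 1
    print(f"sigma={sigma}: peak={blurred.max():.4f}, total={blurred.sum():.4f}")
```

In a real image, this is why a larger sigma suppresses pixel-level noise more strongly but also makes faint, compact sources harder to separate.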

Median Filter

Following the Gaussian filter, the code applies a median filter to further enhance the image. The median filter is effective in reducing impulse noise and preserving edges in the image. Similar to the Gaussian filter, the median_filter() function from the scipy.ndimage module is used to apply the filter.

image = median_filter(image, size=3)  # Apply median filter to reduce noise
# Display the image after median filter
plt.imshow(image, cmap='gray')
plt.title('Image after Median filter')
plt.show()

The resulting image after the median filter is displayed using plt.imshow(), and the title is set to “Image after Median filter.”
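The difference between the two filters is easiest to see in one dimension. In this sketch, which uses made-up toy data, a signal contains a one-sample spike followed by a step edge. median_filter() with size=3 removes the spike and keeps the edge sharp, whereas gaussian_filter() smears both.

```python
import numpy as np
from scipy.ndimage import gaussian_filter, median_filter

# A flat signal with an impulse (index 3) and a step edge (between indices 5 and 6)
signal = np.array([0, 0, 0, 9, 0, 0, 5, 5, 5, 5], dtype=float)

median_out = median_filter(signal, size=3)
gauss_out = gaussian_filter(signal, sigma=1)

print("median:  ", median_out)
print("gaussian:", np.round(gauss_out, 2))
```

The median output is [0, 0, 0, 0, 0, 0, 5, 5, 5, 5]: the spike is gone and the edge still jumps from 0 to 5 in one step. The Gaussian output instead has a spread-out bump where the spike was and intermediate values on both sides of the edge.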

Detecting Local Maxima

In this section, the code identifies local maxima in the filtered image: pixels whose value is at least as high as every other pixel in their neighborhood. Bright, point-like sources such as stars show up as local maxima, so this step produces the candidate features used in the rest of the script.

The maximum_filter() function from scipy.ndimage replaces each pixel's value with the maximum value in its neighborhood (a 3×3 window here, set by size=3). Pixels that are unchanged by this filter, that is, pixels equal to the maximum of their own neighborhood, are the local maxima. Comparing the two images with == produces a Boolean mask that is True at exactly those pixels.

# Detect local maxima
local_max = image == maximum_filter(image, size=3)

# Display the local maxima
plt.imshow(local_max, cmap='gray')
plt.title('Local Maxima')
plt.show()

The resulting local maxima image is displayed using plt.imshow(), with the title set to "Local Maxima."
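One caveat: because the comparison uses equality, every pixel in a flat region (for example, a uniformly black background) also equals its neighborhood maximum and is flagged. A common refinement, sketched below as an addition to the original code, keeps only maxima brighter than a threshold. The toy image and the threshold value are hypothetical; for a real JWST image you would choose the threshold from the data, for example a high percentile of the pixel values.

```python
import numpy as np
from scipy.ndimage import maximum_filter

# Toy 7x7 image: two point sources on a zero background
img = np.zeros((7, 7))
img[2, 2] = 5.0   # bright source
img[2, 3] = 2.0   # fainter pixel next to it (not a maximum)
img[4, 5] = 3.0   # second, isolated source

# Equality test from the article: True wherever a pixel equals its 3x3 maximum
local_max = img == maximum_filter(img, size=3)

# Hypothetical threshold: keep only maxima above the background level
threshold = 0.0
peaks = local_max & (img > threshold)

print("pixels flagged without threshold:", int(local_max.sum()))
print("peak coordinates (row, col):", np.argwhere(peaks).tolist())
```

Without the threshold, most of the zero background is reported as "maxima"; with it, only the two sources at (2, 2) and (4, 5) remain.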

Labeling Connected Components

After identifying the local maxima, the code labels the connected components of the local-maxima mask. A connected component is a group of adjacent True pixels in a Boolean mask. Labeling assigns each group its own number, which lets us tell individual features, such as separate stars, apart.

The label() function from scipy.ndimage performs the labeling. It returns two values: an integer array with the same shape as the input, in which background pixels are 0 and the pixels of each component carry that component's label (1, 2, 3, and so on), and the number of components found.

# Label connected components
labels, num_labels = label(local_max)

# Display the labeled components
plt.imshow(labels, cmap='nipy_spectral')
plt.title('Connected Components')
plt.show()

The labeled components image is displayed using plt.imshow(), with the colormap set to ‘nipy_spectral’, which enhances the visual distinction between different labels. The title is set to “Connected Components.”
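By default, label() treats only horizontally and vertically adjacent pixels as connected, so two pixels that touch only at a corner get different labels. Passing a 3×3 structure of ones switches to 8-connectivity. The small mask below is a made-up example that shows the difference:

```python
import numpy as np
from scipy.ndimage import label

mask = np.array([
    [1, 1, 0, 0],
    [0, 0, 0, 1],
    [0, 0, 1, 0],
], dtype=bool)

# Default structuring element: 4-connectivity (edge-sharing neighbors only)
labels4, n4 = label(mask)

# 3x3 structure of ones: 8-connectivity (diagonal neighbors also connect)
labels8, n8 = label(mask, structure=np.ones((3, 3), dtype=int))

print("4-connected components:", n4)
print("8-connected components:", n8)
```

Here the diagonal pixels at (1, 3) and (2, 2) are separate components under 4-connectivity (3 components in total) but merge into one under 8-connectivity (2 components). Which choice suits a given image depends on how compact the features are after filtering.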

Plotting the Detected Features in 3D

Finally, the code visualizes the detected features in three dimensions. Each labeled pixel becomes a point whose X and Y coordinates are its column and row in the image and whose Z coordinate is its intensity.

The code uses np.nonzero() to get the row and column of every labeled pixel and reads the intensity of the image at those same positions. These X, Y, and Z values are then plotted with ax.scatter() on a 3D Axes created with projection='3d', giving a spatial representation of the detected features.

# Get the coordinates and intensity values of the detected features
coords = np.array(np.nonzero(labels)).T
intensity_values = image[labels > 0]

# Plot the detected features in 3D
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(coords[:, 1], coords[:, 0], intensity_values, c=intensity_values, cmap='jet')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Intensity')
ax.set_title('Detected Features')
plt.show()

The resulting 3D plot is displayed using plt.show(), with the X, Y, and Z axes labeled accordingly. The title of the plot is set to “Detected Features.”
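The plot above draws one point for every labeled pixel. If you would rather have exactly one point per detected feature, scipy.ndimage.center_of_mass() and scipy.ndimage.maximum() can summarize each labeled component. The helper name summarize_features and the toy data below are illustrative and not part of the original script:

```python
import numpy as np
from scipy.ndimage import center_of_mass, label, maximum


def summarize_features(image, labels, num_labels):
    """Hypothetical helper: one (x, y, peak intensity) triple per labeled component."""
    if num_labels == 0:
        empty = np.empty(0)
        return empty, empty, empty
    index = np.arange(1, num_labels + 1)
    # Intensity-weighted centroid (row, col) of each component
    centers = np.array(center_of_mass(image, labels, index))
    # Brightest pixel value inside each component
    peaks = np.array(maximum(image, labels, index))
    return centers[:, 1], centers[:, 0], peaks


# Toy example: a single-pixel source and a two-pixel source
img = np.zeros((6, 6))
img[1, 1] = 4.0
img[4, 3] = 2.0
img[4, 4] = 2.0
labels, num_labels = label(img > 0)

x, y, peaks = summarize_features(img, labels, num_labels)
print("x:", x, "y:", y, "peak:", peaks)
```

The three returned arrays can stand in for coords[:, 1], coords[:, 0], and intensity_values in the ax.scatter() call. Because center_of_mass() weights positions by pixel value, a component whose pixels are all zero (such as a flagged patch of black background) would get a NaN centroid, which is another reason to threshold the maxima first.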

That concludes the traditional image-processing part of the script. By applying filters, identifying local maxima, labeling connected components, and plotting the results in 3D, it forms a simple pipeline for detecting point-like features in an image.

Feature Extraction with Convolutional Neural Networks (CNN)

In addition to traditional image processing techniques, convolutional neural networks (CNNs) have gained popularity for feature extraction in image analysis. CNNs are deep learning models specifically designed for image-related tasks and have shown remarkable performance in various computer vision applications.

To perform feature extraction using CNNs, the code utilizes a pre-trained CNN model and extracts features from an input image. Popular deep learning libraries, such as TensorFlow or PyTorch, provide pre-trained CNN models that can be easily loaded and used for feature extraction.

Let’s assume we are using the PyTorch library, which provides a wide range of pre-trained CNN models. Here’s an example of how to perform feature extraction using a pre-trained CNN model:

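# Load a pre-trained ResNet-18 (weights API of torchvision 0.13 and later)
weights = models.ResNet18_Weights.DEFAULT
model = models.resnet18(weights=weights)

# Replace the final classification layer so the model returns the
# 512-dimensional pooled feature vector instead of 1000 ImageNet class scores
model.fc = torch.nn.Identity()

# Set the model to evaluation mode
model.eval()

# Load and preprocess the input image
pil_image = Image.open(image_path).convert('RGB')
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
input_tensor = preprocess(pil_image)
input_batch = input_tensor.unsqueeze(0)  # add a batch dimension: shape (1, 3, 224, 224)

# Pass the input through the model to extract features
with torch.no_grad():
    features = model(input_batch)

# Print the extracted features (shape (1, 512) for ResNet-18 without its fc layer)
print(features.shape)
print(features)

In this example, we use ResNet-18, a popular CNN architecture pre-trained on the ImageNet dataset. The model is loaded with models.resnet18(weights=models.ResNet18_Weights.DEFAULT); the older pretrained=True argument is deprecated in recent torchvision releases. Replacing model.fc with torch.nn.Identity() removes the final classification layer, so the model outputs a 512-dimensional feature vector rather than class scores. We then set the model to evaluation mode using model.eval(), which switches layers such as batch normalization to their inference behavior.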

Next, we load and preprocess the input image. The image is opened with PIL (Image.open(image_path)), converted to the RGB color space, and passed through a series of transformations combined with transforms.Compose. Resize(256) scales the image so that its shorter side is 256 pixels while preserving the aspect ratio, CenterCrop(224) cuts a 224×224 pixel square from the center (the input size ResNet-18 was trained on), ToTensor() converts the image to a tensor with values in [0, 1], and Normalize() standardizes each channel with the ImageNet mean and standard deviation that the pre-trained weights expect. Finally, unsqueeze(0) adds a batch dimension, because the model expects a batch of images.

The preprocessed image tensor is then passed through the model using model(input_batch) to extract the features. Since we are only interested in feature extraction, we use torch.no_grad() to disable gradient computation, as we don’t need to update the model’s parameters.

Finally, the extracted features are printed (print(features)). The specific format and shape of the features depend on the chosen pre-trained model.

This example demonstrates how to extract features from an image using a pre-trained CNN model. These extracted features can be further utilized for various tasks, such as classification, object detection, or even generating image captions using techniques like transfer learning or fine-tuning.
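As one concrete form of transfer learning, the extracted features can feed a small classifier that is trained while the CNN itself stays frozen (often called a linear probe). The sketch below works under stated assumptions: it uses random tensors in place of real ResNet-18 features, a hypothetical three-class problem, and 512-dimensional inputs, which is what ResNet-18 produces once its final fc layer is replaced.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-ins for real data: 16 feature vectors of length 512 and their class labels
num_classes = 3  # hypothetical number of classes
features = torch.randn(16, 512)
targets = torch.randint(0, num_classes, (16,))

# Only this linear layer is trained; the CNN that produced the features stays frozen
classifier = nn.Linear(512, num_classes)
optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

losses = []
for epoch in range(50):
    optimizer.zero_grad()
    loss = loss_fn(classifier(features), targets)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"initial loss: {losses[0]:.3f}, final loss: {losses[-1]:.3f}")
```

With real data you would stack the feature vectors of many labeled images into the features tensor; fine-tuning differs in that the CNN's own weights are also updated.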

Remember to adjust the code based on the specific pre-trained model you choose and the requirements of your application.
