Calculating Chemotaxis Index from Image of a Petri Dish
#!/usr/bin/env python

from math import pi, sqrt, acos, sin

import cv2
import matplotlib.pyplot as plt
import numpy as np
import scipy.ndimage as ndi
import skimage.measure as measure
from sklearn.linear_model import RANSACRegressor
from sklearn.linear_model import LinearRegression

# Constants:
# the plate diameter should be no less than 70% of the smaller dimension of the image
MIN_RADIUS_FACTOR = 0.7
# parameters to blur the image before Canny edge detection
GAUSSIAN_BLUR_SIGMA = 2
GAUSSIAN_BLUR_KERNEL = 5
# thresholds for Canny edge detection, between 0 and 255; the low threshold is usually 1/3 of the high
CANNY_LOW = 50
CANNY_HIGH = 150
# kernel size for closing and dilation
DILATION_KERNEL = 2
# effective radius as a percentage of the plate radius, used to remove the edges of the plate
EFFECTIVE_RADIUS = 0.97
# "radius divisor" used in chemotaxis_index(): the center (origin) circle has radius radius/RADIUS_DIVISOR
RADIUS_DIVISOR = 5
# thresholds for object sizes in pixel counts, as a percentage of the determined plate radius
SMALL_OBJECT_THRESHOLD = 0.1
LARGE_OBJECT_THRESHOLD = 2.6

print(f"{cv2.__version__ = }")

def circle_overlap_area(c1: tuple[float, float, float], c2: tuple[float, float, float]) -> float:
    """Calculate the area of overlap between two circles.

    Args:
        c1, c2 : tuples of (x, y, radius)
            Two circles defined by their center coordinates and radius

    Returns:
        Area of overlap between the two circles
    """
    x1, y1, r1 = c1
    x2, y2, r2 = c2
    # distance between centers
    d_sq = (x2 - x1)**2 + (y2 - y1)**2
    d = sqrt(d_sq)
    if d >= r1 + r2:
        return 0.0  # no overlap
    if d <= abs(r1 - r2):
        return pi * min(r1, r2)**2  # one circle is inside the other
    # calculate the half-angle of each sector using the cosine rule
    r1_sq = r1**2
    r2_sq = r2**2
    alpha = acos((d_sq + r1_sq - r2_sq) / (2 * d * r1))
    beta = acos((d_sq + r2_sq - r1_sq) / (2 * d * r2))
    # area of each sector = R^2 * half_angle
    area1 = r1_sq * alpha
    area2 = r2_sq * beta
    # area of the triangle between the two radii and the chord, for each circle
    area3 = 0.5 * r1_sq * sin(2 * alpha)
    area4 = 0.5 * r2_sq * sin(2 * beta)
    # Total overlap area = sum of the two circular segments
    return area1 + area2 - area3 - area4
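
# A minimal sanity check of circle_overlap_area (not part of the original gist; the
# values below are chosen only for illustration).  The overlap is the sum of two
# circular segments: for each circle, the segment cut off by the chord through the
# two intersection points has area R^2 * theta - 0.5 * R^2 * sin(2 * theta), where
# theta is the half-angle at that circle's center, which is what the function
# computes as area1 + area2 - area3 - area4.
def _overlap_area_sanity_check():
    """Optional demo; call manually to verify a few exact cases."""
    # two coincident unit circles overlap over the full area pi
    assert abs(circle_overlap_area((0, 0, 1), (0, 0, 1)) - pi) < 1e-9
    # two unit circles whose centers are two radii apart do not overlap
    assert circle_overlap_area((0, 0, 1), (2, 0, 1)) == 0.0
    # the overlap area is symmetric in its arguments
    a = circle_overlap_area((0, 0, 2), (1, 1, 1.5))
    b = circle_overlap_area((1, 1, 1.5), (0, 0, 2))
    assert abs(a - b) < 1e-9
    print("circle_overlap_area sanity checks passed")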

def circle_overlap_matrix(circles: list[tuple[float, float, float]]) -> np.ndarray:
    """Create a matrix of overlap areas between all pairs of circles.

    Args:
        circles : list of tuples
            List of (x, y, radius) tuples representing circles

    Returns:
        N x N matrix where N is the number of circles
        Matrix[i,j] contains the overlap area between circle i and circle j
    """
    n = len(circles)
    overlap_matrix = np.zeros((n, n))
    for i in range(n):
        for j in range(i, n):  # only calculate the upper triangle
            overlap = circle_overlap_area(circles[i], circles[j])
            overlap_matrix[i, j] = overlap
            overlap_matrix[j, i] = overlap  # the matrix is symmetric
    return overlap_matrix

def find_circles(circles) -> list[bool]:
    """Find the right circles with a majority vote over pairwise overlap areas"""
    overlap_matrix = circle_overlap_matrix(circles)
    row_medians = np.median(overlap_matrix, axis=1)
    threshold = max(row_medians) - np.median(row_medians)
    X = np.arange(len(row_medians)).reshape(-1, 1)
    y = row_medians
    # Create and fit the RANSAC regressor: Expect >60% of the circles to be similar
    ransac = RANSACRegressor(
        estimator=LinearRegression(),
        min_samples=0.6,
        residual_threshold=threshold,
        random_state=42
    )
    ransac.fit(X, y)
    # Get the inlier mask
    inlier_mask = ransac.inlier_mask_
    return inlier_mask
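
# Illustration of the voting signal used by find_circles (this demo is not in the
# original gist; the circle coordinates are made up).  The Hough transform tends to
# report many nearly coincident circles on the plate rim plus a few spurious ones;
# the row medians of the pairwise overlap matrix are large for circles that agree
# with the majority and near zero for the spurious ones, and find_circles() fits a
# RANSAC regressor over these medians to mask out the outliers.
def _overlap_vote_example():
    """Optional demo; call manually to print the per-circle median overlap."""
    candidates = [
        (250.0, 250.0, 200.0),  # four candidates that roughly agree...
        (252.0, 249.0, 198.0),
        (248.0, 251.0, 201.0),
        (251.0, 252.0, 199.0),
        (800.0, 800.0, 150.0),  # ...and one spurious circle far away
    ]
    medians = np.median(circle_overlap_matrix(candidates), axis=1)
    for circle, med in zip(candidates, medians):
        print(f"circle {circle}: median overlap {med:.0f} px^2")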

def process_image(imagepath: str | bytes):
    """Read an image using OpenCV and preprocess it for chemotaxis analysis"""
    # Read the image as RGB; accept either a file path or an encoded image buffer
    if isinstance(imagepath, str):
        img_bgr = cv2.imread(imagepath)
    else:
        img_bgr = cv2.imdecode(np.frombuffer(imagepath, np.uint8), cv2.IMREAD_COLOR)
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    print(f"Image size: {img_rgb.shape[1]}x{img_rgb.shape[0]}")
    # Grayscale -> Gaussian blur -> Canny edge detection
    img_gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
    img_blur = cv2.GaussianBlur(img_gray, (GAUSSIAN_BLUR_KERNEL, GAUSSIAN_BLUR_KERNEL), GAUSSIAN_BLUR_SIGMA)
    img_edges = cv2.Canny(img_blur, CANNY_LOW, CANNY_HIGH)
    # From the Canny edges, find circles using the Hough transform (gradient method)
    max_radius = min(img_gray.shape[:2]) // 2
    min_radius = int(MIN_RADIUS_FACTOR * max_radius)
    print(f"Radius range in Hough transform: {min_radius} to {max_radius}; with minDist {max(5, int(max_radius*0.02))}")
    circles = cv2.HoughCircles(
        img_edges, cv2.HOUGH_GRADIENT, dp=1, minDist=max(5, int(max_radius*0.02)),
        param1=28, param2=25,
        minRadius=min_radius, maxRadius=max_radius)
    assert circles is not None, "No circles are found"
    circles = circles[0, :, :3]  # each row is [x, y, radius]
    # keep only circles that lie entirely within the image
    good = (
        (circles[:, 0] - circles[:, 2] >= 0) &                   # x - r >= 0
        (circles[:, 0] + circles[:, 2] <= img_edges.shape[1]) &  # x + r <= width
        (circles[:, 1] - circles[:, 2] >= 0) &                   # y - r >= 0
        (circles[:, 1] + circles[:, 2] <= img_edges.shape[0])    # y + r <= height
    )
    circles = circles[good]
    assert len(circles), "No valid circles are found"
    # Find the plate using RANSAC, then average the inlier circles
    inlier = find_circles(circles)
    selected_circle = circles[inlier].mean(axis=0)
    # Crop the plate
    x, y, radius = selected_circle
    print(f"Determined the plate centered at ({x:.1f}, {y:.1f}) with radius {radius:.1f} pixels")
    img_crop = img_gray[int(y-radius):int(y+radius), int(x-radius):int(x+radius)]
    # Apply the Canny filter to highlight the subjects
    img_blur = cv2.GaussianBlur(img_crop, (GAUSSIAN_BLUR_KERNEL, GAUSSIAN_BLUR_KERNEL), GAUSSIAN_BLUR_SIGMA)
    img_subject = cv2.Canny(img_blur, CANNY_LOW, CANNY_HIGH)
    # Set up a mask to crop the image of the highlighted subjects, removing the plate edge
    mask = np.zeros(img_gray.shape[:2], dtype=np.uint8)
    cv2.circle(mask, (int(x), int(y)), int(EFFECTIVE_RADIUS * radius), (255), cv2.FILLED)
    mask = mask[int(y-radius):int(y+radius), int(x-radius):int(x+radius)]
    img_subject = cv2.bitwise_and(img_subject, img_subject, mask=mask)
    # Apply closing (dilation followed by erosion), then a dilation
    kernel = np.ones((DILATION_KERNEL, DILATION_KERNEL), np.uint8)
    img_dilated = cv2.morphologyEx(img_subject, cv2.MORPH_CLOSE, kernel, iterations=1)
    img_dilated = cv2.dilate(img_dilated, kernel, iterations=1)
    # Fill holes in the image: flood-fill from point (0,0) with color 255, which marks the background,
    # then invert the image to find the subjects, then merge with the original using bitwise_or
    img_filled = cv2.floodFill(img_dilated.copy(), None, (0, 0), 255)[1]
    img_filled = cv2.bitwise_not(img_filled)
    img_filled = cv2.bitwise_or(img_dilated, img_filled)
    # Count the number of connected regions
    img_labelled, n_features = ndi.label(img_filled)  # each connected component is labelled with a unique integer
    sizes = np.bincount(img_labelled.ravel())  # count the number of pixels for each label
    # Measure labelled image regions
    # the returned list of RegionProperties objects contains various attributes. See:
    # https://scikit-image.org/docs/dev/api/skimage.measure.html#skimage.measure.regionprops
    reg_props = measure.regionprops(img_labelled)
    # build a lookup table to classify the objects as too large or too small;
    # keep only objects labelled as 4 (large but not too large) or 5 (small but above threshold)
    small = SMALL_OBJECT_THRESHOLD * radius
    large = LARGE_OBJECT_THRESHOLD * radius
    lut = np.zeros(len(reg_props)+1, dtype='int32')
    lut[lut == 0] = 4         # value for default = 4
    lut[sizes < small*5] = 5  # smaller than 5x the small threshold is marked as "5"
    lut[sizes < small] = 1    # objects too small are labelled as "1"
    lut[sizes > large] = 2    # objects too large are labelled as "2"
    lut[0] = 0                # background is labelled as "0"
    img_filter = lut[img_labelled]
    img_filter = ((img_filter == 4) | (img_filter == 5)).astype(np.uint8)
    # keep a copy of which pixels are removed
    img_removed = cv2.bitwise_and(img_filled, img_filled, mask=(1-img_filter))
    # Return all the images and useful data
    return {
        "img_rgb": img_rgb,                  # original image
        "img_gray": img_gray,                # grayscale image
        "img_edges": img_edges,              # edges: showing the plate and the subjects
        "img_crop": img_crop,                # cropped image zoomed into the plate
        "img_subject": img_subject,          # subjects of the image with the plate edge removed
        "img_dilated": img_dilated,          # subject edges after dilation
        "img_filled": img_filled,            # subjects filled
        "img_labelled": img_labelled,        # integer labels of pixels on the subjects
        "img_filter": img_filter,            # filter mask (0 and 1) of the subjects with objects too large or too small removed
        "img_removed": img_removed,          # copy of which pixels are removed because the object is too large or too small
        "circles": circles,                  # circles found by the Hough transform
        "selected_circle": selected_circle,  # selected circle from the RANSAC inliers, averaged
    }
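
# Optional helper (not part of the original gist) to visualise what process_image()
# determined: it draws every Hough candidate circle in one colour and the selected
# (averaged inlier) plate circle in another on top of the original image.  It is not
# called by main(); invoke it manually when tuning the Hough-transform parameters.
def _show_selected_circle(processed: dict):
    """Draw the candidate circles and the selected plate circle on the original image."""
    canvas = processed["img_rgb"].copy()
    for cx, cy, r in processed["circles"]:
        cv2.circle(canvas, (int(cx), int(cy)), int(r), (0, 255, 0), 2)  # candidates in green
    x, y, radius = processed["selected_circle"]
    cv2.circle(canvas, (int(x), int(y)), int(radius), (255, 0, 0), 4)   # selected plate in red
    plt.imshow(canvas)
    plt.title("Hough candidates (green) and selected plate (red)")
    plt.axis("off")
    plt.show()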

def chemotaxis_index(img_filter, radius=None, radius_divisor=5):
    """Calculate the chemotaxis index of an image"""
    # Draw a mask isolating the center (origin) part of the plate
    if radius is None:
        radius = img_filter.shape[0]/2
    radius = int(radius)
    x = img_filter.shape[0]//2
    y = img_filter.shape[1]//2
    mask = np.zeros(img_filter.shape, dtype=np.uint8)
    cv2.circle(mask, (x, y), int(radius/radius_divisor), (1), cv2.FILLED)
    # Keep visualizations of the plate: img_q shows the center (origin) region, img_n the outer region
    img_q = cv2.cvtColor(255*cv2.bitwise_and(img_filter, img_filter, mask=mask), cv2.COLOR_GRAY2RGB)
    img_n = cv2.cvtColor(255*cv2.bitwise_and(img_filter, img_filter, mask=(1-mask)), cv2.COLOR_GRAY2RGB)
    for img in [img_q, img_n]:
        cv2.circle(img, (x, y), int(radius/radius_divisor), (255, 0, 0), 2)
        cv2.circle(img, (x, y), radius, (255, 0, 0), 2)
        cv2.line(img, (0, y), (2*x, y), (255, 0, 0), 2)
        cv2.line(img, (x, 0), (x, 2*y), (255, 0, 0), 2)
    # count pixels
    n, q = img_filter.copy(), img_filter.copy()
    q[mask == 1] = 0  # copy of the image with the center region removed
    n[mask == 0] = 0  # copy of the image with the outer region removed
    tl = sum(q[0:radius, 0:radius].flatten())  # pixel count in the top left quadrant
    tr = sum(q[0:radius, radius:].flatten())   # pixel count in the top right quadrant
    bl = sum(q[radius:, 0:radius].flatten())   # pixel count in the bottom left quadrant
    br = sum(q[radius:, radius:].flatten())    # pixel count in the bottom right quadrant
    n = sum(n.flatten())                       # total number of highlighted pixels in the center region
    total_q = sum(q.flatten())                 # total number of highlighted pixels in the outer region
    total = sum(img_filter.flatten())          # total number of highlighted pixels
    ci_val = ((tl + br) - (tr + bl)) / total   # chemotaxis index
    return {
        "img_q": img_q,
        "img_n": img_n,
        "Q1 (top left)": tl,
        "Q2 (top right)": tr,
        "Q3 (bottom left)": bl,
        "Q4 (bottom right)": br,
        "N (center region)": n,
        "Total (outer region)": total_q,
        "Total (whole plate)": total,
        "CI = ((Q1 + Q4) - (Q2 + Q3)) \u00f7 Total": ci_val
    }
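
# Worked example of the CI formula (hypothetical numbers, for illustration only):
# with Q1 = 120, Q2 = 30, Q3 = 25, Q4 = 110 highlighted pixels in the four quadrants
# and 400 highlighted pixels on the whole plate,
#     CI = ((Q1 + Q4) - (Q2 + Q3)) / Total = ((120 + 110) - (30 + 25)) / 400
#        = 175 / 400 = 0.4375
# A CI near +1 means the subjects cluster in quadrants Q1/Q4, near -1 in Q2/Q3,
# and near 0 means no preference.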

def main(imagepath: str):
    processed = process_image(imagepath)
    result = chemotaxis_index(processed["img_filter"])
    for k, v in result.items():
        if not k.startswith("img_"):
            print(f"{k} = {v}")
    images = {k: v for k, v in processed.items() if k.startswith("img_")}
    images.update({k: v for k, v in result.items() if k.startswith("img_")})
    fig, axs = plt.subplots(4, 3, figsize=(15, 20))
    for i, (img_name, img) in enumerate(images.items()):
        axs[i//3, i%3].imshow(img)
        axs[i//3, i%3].set_title(img_name)
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    main("/path/to/your_image.png")
This is a rewrite of https://github.com/AndersenLab/chemotaxis-cli using OpenCV, adapted to Python 3 syntax.