Skip to content

Instantly share code, notes, and snippets.

@righthandabacus
Last active April 14, 2025 15:27
Show Gist options
  • Save righthandabacus/921be256ee8b95ccb046f327d5036168 to your computer and use it in GitHub Desktop.
Save righthandabacus/921be256ee8b95ccb046f327d5036168 to your computer and use it in GitHub Desktop.
Calculating Chemotaxis Index from Image of a Petri Dish
#!/usr/bin/env python
from math import pi, sqrt, acos, sin
import cv2
import matplotlib.pyplot as plt
import numpy as np
import scipy.ndimage as ndi
import skimage.measure as measure
from sklearn.linear_model import RANSACRegressor
from sklearn.linear_model import LinearRegression
# Constants:
# Tunable parameters of the image-processing pipeline below.
# the plate should be no less than 70% of the smaller dimension of the image
# (used as the ratio between the minimum and maximum radius of the Hough circle search)
MIN_RADIUS_FACTOR = 0.7
# parameters to blur the image before Canny edge detection:
# the sigma passed to cv2.GaussianBlur and the side length of its square kernel
GAUSSIAN_BLUR_SIGMA = 2
GAUSSIAN_BLUR_KERNEL = 5
# parameters to Canny edge detection, between 0 and 255, usually low is 1/3 of high
CANNY_LOW = 50
CANNY_HIGH = 150
# side length of the square all-ones kernel used for the morphological closing and dilation
DILATION_KERNEL = 2
# effective radius as a fraction of the plate radius, used to mask out the edge of the plate
EFFECTIVE_RADIUS = 0.97
# "radius divisor" in pixel_counts()
# NOTE(review): pixel_counts() is not defined in the code shown here; chemotaxis_index()
# takes a radius_divisor argument with the same default value and this constant is not
# referenced by it -- confirm whether they are meant to be linked
RADIUS_DIVISOR = 5
# Thresholds for object sizes: each factor is multiplied by the determined plate radius
# (in pixels) and the product is compared against the pixel count of every labelled object
SMALL_OBJECT_THRESHOLD = 0.1
LARGE_OBJECT_THRESHOLD = 2.6
# report the OpenCV version; this runs every time the module is loaded
print(f"{cv2.__version__ = }")
def circle_overlap_area(c1: tuple[float, float, float], c2: tuple[float, float, float]) -> float:
    """Calculate the area of overlap between two circles.

    Args:
        c1, c2 : tuples of (x, y, radius)
            Two circles defined by their center coordinates and radius

    Returns:
        Area of overlap between the two circles: 0.0 when they do not overlap,
        the area of the smaller circle when one lies inside the other, and the
        area of the lens-shaped intersection otherwise.
    """
    x1, y1, r1 = c1
    x2, y2, r2 = c2
    # distance between centers
    d_sq = (x2 - x1)**2 + (y2 - y1)**2
    d = sqrt(d_sq)
    if d >= r1 + r2:
        return 0.0  # no overlap
    if d <= abs(r1 - r2):
        return pi * min(r1, r2)**2  # one circle is inside another (also covers d == 0)
    # calculate half-angle of each circle's sector using the cosine rule;
    # clamp to [-1, 1] so that rounding near the tangent cases cannot break acos()
    r1_sq = r1**2
    r2_sq = r2**2
    cos_alpha = (d_sq + r1_sq - r2_sq) / (2 * d * r1)
    cos_beta = (d_sq + r2_sq - r1_sq) / (2 * d * r2)
    alpha = acos(max(-1.0, min(1.0, cos_alpha)))
    beta = acos(max(-1.0, min(1.0, cos_beta)))
    # area of sector = R^2 * half_angle; each circle uses its OWN half-angle
    sector1 = r1_sq * alpha
    sector2 = r2_sq * beta
    # area of the triangle spanned by each center and the two intersection points
    triangle1 = 0.5 * r1_sq * sin(2 * alpha)
    triangle2 = 0.5 * r2_sq * sin(2 * beta)
    # lens area = sum of the two circular segments (sector minus triangle)
    return sector1 + sector2 - triangle1 - triangle2
def circle_overlap_matrix(circles: list[tuple[float, float, float]]) -> np.ndarray:
    """Build the symmetric table of pairwise overlap areas for a set of circles.

    Args:
        circles : list of tuples
            Each entry is an (x, y, radius) description of one circle

    Returns:
        N x N array, N being the number of circles, whose element [i, j] is the
        overlap area of circle i and circle j (the diagonal holds each circle's
        overlap with itself)
    """
    count = len(circles)
    areas = np.zeros((count, count))
    for row, first in enumerate(circles):
        # the overlap is symmetric, so evaluate each unordered pair once
        for col in range(row, count):
            area = circle_overlap_area(first, circles[col])
            areas[row, col] = area
            areas[col, row] = area
    return areas
def find_circles(circles) -> list[bool]:
    """Pick the mutually consistent circles by a RANSAC-based majority vote.

    Each circle is summarised by the median of its overlap areas with every
    circle; a RANSAC linear fit of those medians against the circle index then
    flags the circles that agree with the consensus.

    Returns:
        The inlier mask of the fitted RANSAC model, one flag per input circle
    """
    medians = np.median(circle_overlap_matrix(circles), axis=1)
    # tolerated residual: gap between the best median and the typical median
    residual_limit = max(medians) - np.median(medians)
    indices = np.arange(len(medians)).reshape(-1, 1)
    # min_samples=0.6: expect more than 60% of the circles to be similar
    model = RANSACRegressor(
        estimator=LinearRegression(),
        min_samples=0.6,
        residual_threshold=residual_limit,
        random_state=42,
    )
    model.fit(indices, medians)
    return model.inlier_mask_
def process_image(imagepath: str):
    """Read an image using OpenCV and preprocess it for chemotaxis analysis.

    The plate is located with a Hough circle transform followed by a RANSAC
    vote (find_circles), the image is cropped to the plate, and the subjects
    inside it are highlighted, filled, labelled and filtered by size.

    Args:
        imagepath: path of the image file. Any non-str value is treated as an
            in-memory buffer of an encoded image and decoded with cv2.imdecode().

    Returns:
        dict holding every intermediate image (keys starting with "img_"), the
        candidate circles ("circles") and the chosen plate ("selected_circle",
        as [x, y, radius]).

    Raises:
        ValueError: if the image cannot be read or decoded.
        AssertionError: if the Hough transform finds no circle, or no circle
            that lies completely inside the image.
    """
    # Read the image (OpenCV loads it as BGR) and convert to RGB
    if isinstance(imagepath, str):
        img_bgr = cv2.imread(imagepath)
    else:
        img_bgr = cv2.imdecode(np.frombuffer(imagepath, np.uint8), cv2.IMREAD_COLOR)
    if img_bgr is None:
        # imread()/imdecode() report failure by returning None instead of raising
        source = repr(imagepath) if isinstance(imagepath, str) else "the given buffer"
        raise ValueError(f"Cannot read an image from {source}")
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    print(f"Image size: {img_rgb.shape[1]}x{img_rgb.shape[0]}")

    # Grayscale -> Gaussian blur -> Canny edge detection
    blur_kernel = (GAUSSIAN_BLUR_KERNEL, GAUSSIAN_BLUR_KERNEL)
    img_gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
    img_blur = cv2.GaussianBlur(img_gray, blur_kernel, GAUSSIAN_BLUR_SIGMA)
    img_edges = cv2.Canny(img_blur, CANNY_LOW, CANNY_HIGH)

    # From the Canny edges, find circles using Hough transform (gradient method)
    max_radius = min(img_gray.shape[:2]) // 2
    min_radius = int(MIN_RADIUS_FACTOR * max_radius)
    min_dist = max(5, int(max_radius*0.02))
    print(f"Radius range in Hough transform: {min_radius} to {max_radius}; with minDist {min_dist}")
    circles = cv2.HoughCircles(
        img_edges, cv2.HOUGH_GRADIENT, dp=1, minDist=min_dist,
        param1=28, param2=25,
        minRadius=min_radius, maxRadius=max_radius)
    assert circles is not None, "No circles are found"
    circles = circles[0, :, :3]  # each row is [x, y, radius]

    # Keep only the circles that lie completely inside the image
    good = (
        (circles[:, 0] - circles[:, 2] >= 0) &                   # x - r >= 0
        (circles[:, 0] + circles[:, 2] <= img_edges.shape[1]) &  # x + r <= width
        (circles[:, 1] - circles[:, 2] >= 0) &                   # y - r >= 0
        (circles[:, 1] + circles[:, 2] <= img_edges.shape[0])    # y + r <= height
    )
    circles = circles[good]
    assert len(circles), "No valid circles are found"

    # Find the plate using RANSAC, then average the agreeing circles
    inlier = find_circles(circles)
    selected_circle = circles[inlier].mean(axis=0)

    # Crop the plate: bounding square of the selected circle
    x, y, radius = selected_circle
    print(f"Determined the plate centered at ({x:.1f}, {y:.1f}) with radius {radius:.1f} pixels")
    rows = slice(int(y-radius), int(y+radius))
    cols = slice(int(x-radius), int(x+radius))
    img_crop = img_gray[rows, cols]

    # Apply Canny filter to highlight the subjects
    img_blur = cv2.GaussianBlur(img_crop, blur_kernel, GAUSSIAN_BLUR_SIGMA)
    img_subject = cv2.Canny(img_blur, CANNY_LOW, CANNY_HIGH)

    # Set up mask to crop the image of highlighted subject: a slightly shrunk
    # disc drawn in full-image coordinates and cropped exactly like img_crop
    mask = np.zeros(img_gray.shape[:2], dtype=np.uint8)
    cv2.circle(mask, (int(x), int(y)), int(EFFECTIVE_RADIUS * radius), (255), cv2.FILLED)
    mask = mask[rows, cols]
    img_subject = cv2.bitwise_and(img_subject, img_subject, mask=mask)

    # Apply closing (dilation followed by erosion), then a dilation
    kernel = np.ones((DILATION_KERNEL, DILATION_KERNEL), np.uint8)
    img_dilated = cv2.morphologyEx(img_subject, cv2.MORPH_CLOSE, kernel, iterations=1)
    img_dilated = cv2.dilate(img_dilated, kernel, iterations=1)

    # Fill holes in the image: Flood fill point (0,0) with color 255, which marks the background
    # then invert the image to find the subjects, then merge with the original using bitwise_or
    img_filled = cv2.floodFill(img_dilated.copy(), None, (0,0), 255)[1]
    img_filled = cv2.bitwise_not(img_filled)
    img_filled = cv2.bitwise_or(img_dilated, img_filled)

    # Count the number of connected regions
    img_labelled, n_features = ndi.label(img_filled)  # each connected element is labelled with a unique integer
    sizes = np.bincount(img_labelled.ravel())         # count the number of pixels for each label

    # Measure labelled image regions
    # the returned list of RegionProperties objects contains various attributes. See:
    # https://scikit-image.org/docs/dev/api/skimage.measure.html#skimage.measure.regionprops
    reg_props = measure.regionprops(img_labelled)

    # build a lookup table (label -> class) to classify the objects too large or too small
    # keep only objects labelled as 4 (large but not too large) or 5 (small but above threshold)
    # the assignments below are order-dependent: later ones override earlier ones
    small = SMALL_OBJECT_THRESHOLD * radius
    large = LARGE_OBJECT_THRESHOLD * radius
    lut = np.full(len(reg_props)+1, 4, dtype='int32')  # value for default = 4
    lut[sizes < small*5] = 5  # smaller than 5x small threshold is marked as "5"
    lut[sizes < small] = 1    # object too small labels as "1"
    lut[sizes > large] = 2    # object too large labels as "2"
    lut[0] = 0                # background labels as "0"
    img_filter = lut[img_labelled]
    img_filter = ((img_filter == 4) | (img_filter == 5)).astype(np.uint8)

    # keep a copy of what pixels are removed
    img_removed = cv2.bitwise_and(img_filled, img_filled, mask=(1-img_filter))

    # Return all the images and useful data
    return {
        "img_rgb": img_rgb,            # original image
        "img_gray": img_gray,          # grayscale image
        "img_edges": img_edges,        # edges: showing the plate and the subject
        "img_crop": img_crop,          # cropped image zoomed into the plate
        "img_subject": img_subject,    # subject of the image with the plate edge removed
        "img_dilated": img_dilated,    # subject edges after dilation
        "img_filled": img_filled,      # subjects filled
        "img_labelled": img_labelled,  # integer labels of pixels on the subjects
        "img_filter": img_filter,      # filter mask (0 and 1) of showing subjects with objects too large or too small removed
        "img_removed": img_removed,    # copy of what pixels are removed due to object too large or too small
        "circles": circles,            # circles found by Hough transform
        "selected_circle": selected_circle,  # selected circle from RANSAC and average
    }
def chemotaxis_index(img_filter, radius=None, radius_divisor=5):
    """Calculate the chemotaxis index of an image.

    The plate is divided into a center disc of radius ``radius / radius_divisor``
    and the remaining area, which is split into four quadrants about the image
    center. The index is ((Q1 + Q4) - (Q2 + Q3)) / total highlighted pixels.

    Args:
        img_filter: 2-D 0/1 mask of the highlighted subjects, as produced under
            the key "img_filter" by process_image().
        radius: plate radius in pixels; defaults to half of the image height.
        radius_divisor: the center disc has radius ``radius / radius_divisor``.

    Returns:
        dict with two visualisation images ("img_q": pixels inside the center
        disc, "img_n": pixels outside it, both with guide lines drawn), the
        four quadrant counts, the region totals and the chemotaxis index (NaN
        when the mask has no highlighted pixel).

        NOTE: the keys "N (outer region)" and "Total (center region)" are kept
        for backward compatibility, but their labels are swapped with respect
        to their values: the former holds the pixel count INSIDE the center
        disc, the latter the pixel count OUTSIDE it (the sum of the quadrants).
    """
    height, width = img_filter.shape[:2]
    if radius is None:
        radius = height/2
    radius = int(radius)
    # image center in OpenCV's (x, y) = (column, row) convention
    x = width//2
    y = height//2
    inner_radius = int(radius/radius_divisor)

    # Draw mask isolating the center part of the plate
    mask = np.zeros(img_filter.shape, dtype=np.uint8)
    cv2.circle(mask, (x, y), inner_radius, (1), cv2.FILLED)

    # Keep a visualization of the plate
    img_q = cv2.cvtColor(255*cv2.bitwise_and(img_filter, img_filter, mask=mask), cv2.COLOR_GRAY2RGB)
    img_n = cv2.cvtColor(255*cv2.bitwise_and(img_filter, img_filter, mask=(1-mask)), cv2.COLOR_GRAY2RGB)
    for img in [img_q, img_n]:
        cv2.circle(img, (x, y), inner_radius, (255,0,0), 2)
        cv2.circle(img, (x, y), radius, (255,0,0), 2)
        cv2.line(img, (0, y), (2*x, y), (255,0,0), 2)
        cv2.line(img, (x, 0), (x, 2*y), (255,0,0), 2)

    def count(region) -> int:
        # Sum inside NumPy with a wide accumulator. The builtin sum() walks the
        # pixels one by one in Python and, under NumPy 2 scalar promotion, keeps
        # the running total in the mask's own dtype, so a uint8 mask would wrap
        # around at 256.
        return int(region.sum(dtype=np.int64))

    # count pixels
    q = img_filter.copy()
    q[mask == 1] = 0  # copy of the image with center region removed
    n = img_filter.copy()
    n[mask == 0] = 0  # copy of the image with outer region removed
    # quadrants are split at the image center, matching the guide lines drawn above
    tl = count(q[:y, :x])   # pixel count at top left quadrant
    tr = count(q[:y, x:])   # pixel count at top right quadrant
    bl = count(q[y:, :x])   # pixel count at bottom left quadrant
    br = count(q[y:, x:])   # pixel count at bottom right quadrant
    n = count(n)            # total number of highlighted pixels inside the center disc
    total_q = count(q)      # total number of highlighted pixels outside the center disc
    total = count(img_filter)  # total number of highlighted pixels
    # chemotaxis index; undefined (NaN) when nothing is highlighted
    ci_val = ((tl + br) - (tr + bl)) / total if total else float("nan")
    return {
        "img_q": img_q,
        "img_n": img_n,
        "Q1 (top left)": tl,
        "Q2 (top right)": tr,
        "Q3 (bottom left)": bl,
        "Q4 (bottom right)": br,
        "N (outer region)": n,
        "Total (center region)": total_q,
        "Total (whole plate)": total,
        "CI = ((Q1 + Q4) - (Q2 + Q3)) \u00f7 Total": ci_val
    }
def main(imagepath: str):
    """Run the whole pipeline on one image, print the figures and plot every image.

    The numeric results of chemotaxis_index() are printed one per line, then all
    entries whose key starts with "img_" (from both process_image() and
    chemotaxis_index()) are shown on a 4 x 3 grid of subplots.
    """
    processed = process_image(imagepath)
    result = chemotaxis_index(processed["img_filter"])

    # report everything that is not an image
    for key, value in result.items():
        if key.startswith("img_"):
            continue
        print(f"{key} = {value}")

    # gather the images: preprocessing stages first, then the index visualisations
    images = {}
    for source in (processed, result):
        for key, value in source.items():
            if key.startswith("img_"):
                images[key] = value

    # lay the images out row by row, three per row
    _, axs = plt.subplots(4, 3, figsize=(15, 20))
    for position, (title, picture) in enumerate(images.items()):
        row, col = divmod(position, 3)
        panel = axs[row, col]
        panel.imshow(picture)
        panel.set_title(title)
    plt.tight_layout()
    plt.show()


# Script entry point; the path below is a placeholder to be replaced by a real image file
if __name__ == "__main__":
    main("/path/to/your_image.png")
@righthandabacus
Copy link
Author

This is a rewrite of https://github.com/AndersenLab/chemotaxis-cli using OpenCV, adapted to Python 3 syntax.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment