Skip to content

Instantly share code, notes, and snippets.

@matham
Created February 3, 2025 20:48
Show Gist options
  • Save matham/02290662c9e9cf1a85d511c075d002e8 to your computer and use it in GitHub Desktop.
Save matham/02290662c9e9cf1a85d511c075d002e8 to your computer and use it in GitHub Desktop.
Analyzes light-sheet microscope optics: aberrations introduced by lenses and by refractive-index (RI) mismatch
import math
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import scipy.special
from scipy.stats import linregress
from scipy import integrate
import scipy.optimize as optimize
from scipy.optimize import curve_fit, minimize, Bounds
import numpy as np
from functools import partial
from multiprocessing import Pool, cpu_count
from pathlib import Path
import multiprocessing.pool as mpp
from functools import partial
from tqdm import tqdm
try:
import sage.all as sage
except ImportError:
pass
# Number of worker processes for the multiprocessing pools below: leave a few cores free for the
# rest of the system, but never go below one process (``Pool(0)`` raises ``ValueError``).
NUM_THREADS = max(cpu_count() - 4, 1)
def istarmap(self, func, iterable, chunksize=1):
    """starmap-version of imap

    Lazily applies ``func(*args)`` to every ``args`` tuple of ``iterable`` on the pool's workers and
    yields the individual results, like ``Pool.imap`` but unpacking each item like ``Pool.starmap``.
    The plotting code below reshapes the results by position, so it relies on them arriving in input
    order.

    NOTE(review): relies on private ``multiprocessing.pool`` internals (``_check_running``,
    ``Pool._get_tasks``, ``IMapIterator``, ``_taskqueue``, ``_guarded_task_generation``,
    ``starmapstar``) that are not public API — confirm they exist in the Python version in use.

    :param self: The ``multiprocessing.pool.Pool`` instance (this function is installed as a method below).
    :param func: Callable applied as ``func(*args)``.
    :param iterable: Iterable of argument tuples.
    :param chunksize: Number of argument tuples sent to a worker per task; must be >= 1.
    :return: Generator over the individual results.
    :raises ValueError: If ``chunksize`` is less than 1.
    """
    self._check_running()
    if chunksize < 1:
        raise ValueError(
            "Chunksize must be 1+, not {0:n}".format(
                chunksize))
    # Group the argument tuples into tasks of ``chunksize`` items each.
    task_batches = mpp.Pool._get_tasks(func, iterable, chunksize)
    # Result iterator of the same class ``Pool.imap`` uses.
    result = mpp.IMapIterator(self)
    # Queue the task generator; each batch is executed in a worker through ``starmapstar``.
    self._taskqueue.put(
        (
            self._guarded_task_generation(result._job,
                                          mpp.starmapstar,
                                          task_batches),
            result._set_length
        ))
    # Each item produced by ``result`` is one chunk of results; flatten chunks into single results.
    return (item for chunk in result for item in chunk)


# Install ``istarmap`` as a method on every multiprocessing ``Pool`` (used as ``pool.istarmap`` below).
mpp.Pool.istarmap = istarmap
def save_or_show(save_fig_root: None | str | Path = None, save_fig_prefix: str = ""):
    """Save the current matplotlib figure to disk, or show it interactively.

    :param save_fig_root: Directory to save into. If falsy (``None`` or ``""``) the figure is laid out
        and shown with ``plt.show()`` instead. A ``str`` is accepted and converted to ``Path``; the
        directory (and its parents) is created if missing.
    :param save_fig_prefix: File stem; the figure is written to ``<save_fig_root>/<save_fig_prefix>.png``.
    """
    if not save_fig_root:
        plt.tight_layout()
        plt.show()
        return

    save_fig_root = Path(save_fig_root)
    save_fig_root.mkdir(parents=True, exist_ok=True)
    fig = plt.gcf()
    # Fixed, large canvas so saved figures have a consistent size independent of the screen.
    fig.set_size_inches(16, 12)
    fig.tight_layout()
    fig.savefig(save_fig_root / f"{save_fig_prefix}.png", bbox_inches='tight', dpi=300)
    # Close exactly the figure that was saved so figures do not accumulate across calls.
    plt.close(fig)
def reduce_na_via_aperture(na: float, aperture_remaining_fraction, medium_n=1) -> float:
    """Return the numerical aperture that remains after stopping down the pupil.

    The marginal ray angle ``asin(na / medium_n)`` is reduced by scaling its tangent (the pupil
    radius) by ``aperture_remaining_fraction``. The new NA is ``medium_n * sin(new_angle)``.
    """
    half_angle = math.asin(na / medium_n)
    scaled_tan = aperture_remaining_fraction * math.tan(half_angle)
    return medium_n * math.sin(math.atan(scaled_tan))
def B_n(order, gamma):
    """Return the per-angle term B_n of the Zernike decomposition coefficient.

        B_n(gamma) = t**(n-1) / (2 (n-1) sqrt(n+1)) * (1 - (n-1)/(n+3) * t**4),   t = tan(gamma / 2)

    ``order`` (n) must not be 1 (the denominator would be zero).
    """
    t = math.tan(gamma / 2)
    prefactor = t ** (order - 1) / (2 * (order - 1) * math.sqrt(order + 1))
    correction = 1 - (order - 1) / (order + 3) * t ** 4
    return prefactor * correction
def zernike_An0_coef(order, alpha, beta):
    """Return the Zernike coefficient A_n0 between two angles: ``B_n(order, alpha) - B_n(order, beta)``.

    Callers pass ``asin(NA / n)`` of the two media for ``alpha`` and ``beta``.
    """
    term_alpha = B_n(order, alpha)
    term_beta = B_n(order, beta)
    return term_alpha - term_beta
def zernike_poly(p, order):
    """Return the normalized rotationally-symmetric Zernike polynomial ``Z_n^0(p)`` for an even order n.

    ``Z_n^0(p) = sqrt(n + 1) * R_n^0(p)`` with the radial polynomial

        R_n^0(p) = sum_{k=0}^{n/2} (-1)^k (n - k)! / (k! ((n/2 - k)!)^2) * p^(n - 2k)

    e.g. ``Z_2^0 = sqrt(3) (2p^2 - 1)`` and ``Z_4^0 = sqrt(5) (6p^4 - 6p^2 + 1)``. ``p`` is the
    normalized pupil radius and may be a scalar, a NumPy array or a Sage expression. Order 0 returns
    the constant ``1`` (not broadcast to the shape of ``p``).

    :raises ValueError: If ``order`` is negative or odd. Such orders used to silently return ``None``,
        which only failed later with an unrelated ``TypeError``.
    """
    if order < 0 or order % 2:
        raise ValueError(f"Zernike order must be a non-negative even integer, got {order!r}")
    if order == 0:
        return 1
    n = int(order)
    half = n // 2
    radial = 0
    for k in range(half + 1):
        # Integer coefficient (exact) of the p^(n - 2k) term.
        coef = math.factorial(n - k) // (math.factorial(k) * math.factorial(half - k) ** 2)
        term = coef * p ** (n - 2 * k)
        # Signs alternate: +, -, +, ...
        radial = radial - term if k % 2 else radial + term
    return math.sqrt(n + 1) * radial
def zernike_term(p, alpha, beta, order):
    """Return one Zernike expansion term ``A_n0(alpha, beta) * Z_n(p)`` at normalized radius ``p``."""
    coefficient = zernike_An0_coef(order, alpha, beta)
    return coefficient * zernike_poly(p, order)
def aberration_An0_coef(objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, order, wavelength=None):
    """Return the decomposed aberration coefficient for Zernike ``order``.

        NA * (glass_d * A_n0(air, glass) + liquid_d * A_n0(air, liquid))

    where each angle is ``asin(objective_na / n)`` for that medium. When ``wavelength`` is given, the
    result is further multiplied by ``k = 2*pi / wavelength``.
    """
    theta_air = math.asin(objective_na / air_n)
    theta_glass = math.asin(objective_na / glass_n)
    theta_liquid = math.asin(objective_na / liquid_n)
    glass_part = glass_d * zernike_An0_coef(order, theta_air, theta_glass)
    liquid_part = liquid_d * zernike_An0_coef(order, theta_air, theta_liquid)
    coef = (glass_part + liquid_part) * objective_na
    if wavelength is None:
        return coef
    return coef * (2 * math.pi) / wavelength
def aberration_An0_coefs(objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength=None):
    """Return ``aberration_An0_coef`` for Zernike orders 0, 2, 4, 6 and 8, as a list in that order."""
    return [
        aberration_An0_coef(objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, order, wavelength)
        for order in (0, 2, 4, 6, 8)
    ]
def decomposed_aberration(p, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, order, wavelength=None):
    """Return the single-order decomposed aberration ``coef * Z_order(p)`` at normalized radius ``p``.

    ``coef`` is ``aberration_An0_coef`` for the same setup (scaled by ``2*pi / wavelength`` when given).
    """
    an0 = aberration_An0_coef(objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, order, wavelength)
    shape = zernike_poly(p, order)
    return an0 * shape
def decomposed_aberrations(
    p, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, orders=(0, 2, 4, 6, 8), wavelength=None
):
    """Return a list of ``decomposed_aberration`` values at ``p``, one per entry of ``orders``, in order."""
    return [
        decomposed_aberration(p, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, order, wavelength)
        for order in orders
    ]
def plot_decomposed_aberrations(
    objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, save_fig_root: None | Path = None,
    save_fig_prefix: str = "",
):
    """Plot each decomposed aberration order (0, 2, 4, 6, 8) and their sum across the pupil.

    Every curve is evaluated on ``n`` normalized radii spanning ``[0, 1]`` (the pupil edge is included;
    the previous ``i / n`` sampling stopped just short of it), then mirrored about the axis. The figure
    is saved or shown via ``save_or_show``.
    """
    n = 10000
    orders = (0, 2, 4, 6, 8)
    radii = np.linspace(0, 1, n)
    # Mirror the radial samples to show a full pupil cross-section (-1..1).
    x = np.concatenate((-radii[1:][::-1], radii), axis=None)
    _, axs = plt.subplots(2, 3, sharex=True)
    for ax, order in zip(axs.flatten(), orders + (None,)):
        if order is None:
            abb = sum(decomposed_aberrations(
                radii, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, orders, wavelength
            ))
            ax.set_title("Decomposed aberration sum due to orders 0-8")
        else:
            abb = decomposed_aberration(
                radii, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, order, wavelength
            )
            ax.set_title(f"Decomposed aberration due to order {order}")
        # Order 0 is constant over the pupil and comes back as a scalar: broadcast it to the samples.
        y = np.zeros_like(radii) + abb
        ax.plot(x, np.concatenate((y[1:][::-1], y), axis=None))
        ax.set_xlabel("Normalized radius")
        ax.set_ylabel("Phase aberration K * ψ(radius)")
    save_or_show(save_fig_root, save_fig_prefix)
def plot_decomposed_coef_na_vs_depth(
    air_n, glass_n, liquid_n, glass_d, wavelength, order=4, save_fig_root: None | Path = None,
    save_fig_prefix: str = "",
):
    """Plot the decomposed coefficient ``K * A<order>0`` as a 2D map over NA (x) and oil depth (y).

    NA spans 0.01-0.5 and oil depth spans 1-30 mm (``liquid_d`` in meters, shown in mm). Values are
    clipped at ``2*pi`` so the colormap resolves the region below one wave. The figure is shown
    (or saved when ``save_fig_root`` is given) via ``save_or_show``.
    """
    n_samples = 1000
    na_values = np.linspace(.01, .5, n_samples)
    depths = np.linspace(.001, .03, n_samples)
    img = np.zeros((n_samples, n_samples))
    with tqdm(total=n_samples * n_samples) as pb:
        for i, depth in enumerate(depths):
            for j, na in enumerate(na_values):
                img[i, j] = aberration_An0_coef(
                    na, air_n, glass_n, liquid_n, glass_d, depth, order, wavelength
                )
                pb.update()
    img[img >= 2 * math.pi] = 2 * math.pi
    # imshow's extent is the outer edge of the pixel grid: pad by half a sample so each pixel is
    # centered on the NA/depth at which it was computed.
    half_na = (na_values[1] - na_values[0]) / 2
    half_depth = (depths[1] - depths[0]) / 2
    extent = (
        na_values[0] - half_na, na_values[-1] + half_na,
        (depths[0] - half_depth) * 1e3, (depths[-1] + half_depth) * 1e3,
    )
    im = plt.imshow(img, aspect="auto", origin="lower", extent=extent, cmap="nipy_spectral")
    plt.xlabel("NA")
    plt.ylabel("Oil depth (mm)")
    plt.title(f"Decomposed aberration coefficient K * A{order}0")
    plt.colorbar(im)
    save_or_show(save_fig_root, save_fig_prefix)
def plot_decomposed_coef_na_at_depth(
    depth, air_n, glass_n, liquid_n, glass_d, wavelength, save_fig_root: None | Path = None,
    save_fig_prefix: str = "",
):
    """Plot the decomposed coefficient ``K * A_n0`` versus NA (0.01-0.5) at a fixed oil ``depth``.

    One panel per Zernike order 0, 2, 4, 6, 8. The spare sixth panel of the 2x3 grid is hidden rather
    than left empty. ``depth`` is in meters (shown in mm). The figure is shown (or saved when
    ``save_fig_root`` is given) via ``save_or_show``.
    """
    n = 100000
    na = np.linspace(.01, .5, n)
    orders = (0, 2, 4, 6, 8)
    _, axs = plt.subplots(2, 3, sharex=True)
    axs = axs.flatten()
    for ax, order in zip(axs, orders):
        vals = np.zeros(n)
        for i, na_ in tqdm(enumerate(na), total=n):
            vals[i] = aberration_An0_coef(
                na_, air_n, glass_n, liquid_n, glass_d, depth, order, wavelength
            )
        ax.plot(na, vals)
        ax.set_xlabel("NA")
        ax.set_ylabel("Decomposed aberration coefficient")
        ax.set_title(f"K * A{order}0")
    for idx in range(len(orders), len(axs)):
        axs[idx].set_visible(False)
        # With sharex only the bottom row shows x tick labels; restore them on the panel above the hidden one.
        axs[idx - 3].tick_params(labelbottom=True)
    plt.suptitle(f"Oil depth = {depth * 1000} mm")
    save_or_show(save_fig_root, save_fig_prefix)
def plot_decomposed_coef_depth_at_na(
    na, air_n, glass_n, liquid_n, glass_d, wavelength, save_fig_root: None | Path = None,
    save_fig_prefix: str = "",
):
    """Plot the decomposed coefficient ``K * A_n0`` versus oil depth (0-20 mm) at a fixed ``na``.

    One panel per Zernike order 0, 2, 4, 6, 8. The spare sixth panel of the 2x3 grid is hidden rather
    than left empty. Depths are in meters internally and shown in mm. The figure is shown (or saved
    when ``save_fig_root`` is given) via ``save_or_show``.
    """
    n = 100000
    depth = np.linspace(0, .02, n)
    orders = (0, 2, 4, 6, 8)
    _, axs = plt.subplots(2, 3, sharex=True)
    axs = axs.flatten()
    for ax, order in zip(axs, orders):
        vals = np.zeros(n)
        for i, d in tqdm(enumerate(depth), total=n):
            vals[i] = aberration_An0_coef(
                na, air_n, glass_n, liquid_n, glass_d, d, order, wavelength
            )
        ax.plot(depth * 1000, vals)
        ax.set_xlabel("Oil depth (mm)")
        ax.set_ylabel("Decomposed aberration coefficient")
        ax.set_title(f"K * A{order}0")
    for idx in range(len(orders), len(axs)):
        axs[idx].set_visible(False)
        # With sharex only the bottom row shows x tick labels; restore them on the panel above the hidden one.
        axs[idx - 3].tick_params(labelbottom=True)
    plt.suptitle(f"NA = {na}")
    save_or_show(save_fig_root, save_fig_prefix)
def _total_phase_aberration_cosec(p, alpha, beta):
    """Return ``sqrt(csc(beta)**2 - p**2) - sqrt(csc(alpha)**2 - p**2)`` at normalized radius ``p``."""
    def cosec_root(angle):
        return math.sqrt(1 / math.sin(angle) ** 2 - p ** 2)

    return cosec_root(beta) - cosec_root(alpha)
def _total_phase_aberration_cosec_sym(p, alpha, beta):
    """Symbolic (Sage) version of ``_total_phase_aberration_cosec``.

    Returns the Sage expression ``sqrt(1/sin(beta)^2 - p^2) - sqrt(1/sin(alpha)^2 - p^2)``, built with
    ``sage.Integer`` constants.
    """
    one = sage.Integer(1)
    two = sage.Integer(2)

    def cosec_root(angle):
        return sage.sqrt(one / sage.sin(angle) ** two - p ** two)

    return cosec_root(beta) - cosec_root(alpha)
def total_phase_aberration(
    p, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength=None, correction_func=None
):
    """Return the total (non-decomposed) phase aberration at normalized pupil radius ``p``.

    The aberration is ``NA * (glass_d * g(air, glass) + liquid_d * g(air, liquid))``, where ``g`` is
    ``_total_phase_aberration_cosec`` and each angle is ``asin(objective_na / n)`` for that medium.
    When ``wavelength`` is not None the result is multiplied by ``k = 2*pi / wavelength``.

    If ``correction_func`` is given, it is called with the keyword arguments ``p, objective_na, air_n,
    glass_n, liquid_n, glass_d, liquid_d, wavelength``. Its return value is ADDED to the result, AFTER
    the optional multiplication by ``k``. A correction that cancels the aberration must therefore
    return the negated phase, in the same units as the final result.
    """
    theta_air = math.asin(objective_na / air_n)
    theta_glass = math.asin(objective_na / glass_n)
    theta_liquid = math.asin(objective_na / liquid_n)
    layers = (
        glass_d * _total_phase_aberration_cosec(p, theta_air, theta_glass)
        + liquid_d * _total_phase_aberration_cosec(p, theta_air, theta_liquid)
    )
    phase = layers * objective_na
    if wavelength is not None:
        phase = phase * (2 * math.pi) / wavelength
    if correction_func is None:
        return phase
    return phase + correction_func(
        p=p, objective_na=objective_na, air_n=air_n, glass_n=glass_n, liquid_n=liquid_n, glass_d=glass_d,
        liquid_d=liquid_d, wavelength=wavelength,
    )
def total_phase_aberration_sym(
    p, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength=None, correction_func=None
):
    """Symbolic (Sage) counterpart of ``total_phase_aberration``.

    Builds ``NA * (glass_d * g(air, glass) + liquid_d * g(air, liquid))`` with the Sage variable ``p``
    (``g`` = ``_total_phase_aberration_cosec_sym``), multiplied by ``k = 2*pi / wavelength`` when
    ``wavelength`` is not None.

    If ``correction_func`` is given, it is called exactly as in the numeric version: keyword arguments
    that include ``wavelength``, which the correction functions in this module require. Its result is
    added AFTER the multiplication by ``k``, so the symbolic and numeric paths agree and a correction
    already expressed in radians is not scaled by ``k`` a second time.
    """
    air_angle = math.asin(objective_na / air_n)
    glass_angle = math.asin(objective_na / glass_n)
    liquid_angle = math.asin(objective_na / liquid_n)
    f_p = glass_d * _total_phase_aberration_cosec_sym(p, air_angle, glass_angle)
    f_p = f_p + liquid_d * _total_phase_aberration_cosec_sym(p, air_angle, liquid_angle)
    f_p = f_p * objective_na
    if wavelength is not None:
        f_p = f_p * 2 * math.pi
        f_p = f_p / wavelength
    if correction_func is not None:
        f_p = f_p + correction_func(
            p=p, objective_na=objective_na, air_n=air_n, glass_n=glass_n, liquid_n=liquid_n, glass_d=glass_d,
            liquid_d=liquid_d, wavelength=wavelength,
        )
    return f_p
def decomposed_phase_correction_func(
    p, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, corrected_orders
):
    """Return the phase correction that cancels the given Zernike orders of the decomposed aberration.

    The correction is the negated sum over ``corrected_orders`` of the glass and liquid layer terms
    ``d * A_n0 * Z_n(p)``, scaled by ``objective_na`` and, when ``wavelength`` is not None, by
    ``k = 2*pi / wavelength``. An empty ``corrected_orders`` gives a zero correction.

    ``total_phase_aberration`` does not pass ``corrected_orders``, so it must be bound (e.g. with
    ``functools.partial``) before this is used as a ``correction_func``.
    """
    air_angle = math.asin(objective_na / air_n)
    glass_angle = math.asin(objective_na / glass_n)
    liquid_angle = math.asin(objective_na / liquid_n)
    # Integer 0 start keeps float, NumPy and Sage inputs all working (0 - x == -x).
    f_p = 0
    for order in corrected_orders:
        f_p -= glass_d * zernike_term(p, air_angle, glass_angle, order)
        f_p -= liquid_d * zernike_term(p, air_angle, liquid_angle, order)
    f_p *= objective_na
    if wavelength is not None:
        f_p *= 2 * math.pi
        f_p /= wavelength
    return f_p
def plot_total_phase_aberrations_vs_rad(
    objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, correction_func=None
):
    """Plot the total phase aberration across the pupil, plus an optional correction and the corrected result.

    Left: ``total_phase_aberration`` without correction, over normalized radius -1..1 (the positive half
    is mirrored). Middle and right, only when ``correction_func`` is given: ``correction_func(p, ...)``
    and the sum of both.
    """
    n = 10000
    radii = np.linspace(0, 1, n)
    setup = (objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength)
    half = np.zeros((n, 2))
    for row, p in tqdm(enumerate(radii), total=len(radii)):
        half[row, 0] = total_phase_aberration(p, *setup)
        if correction_func is not None:
            half[row, 1] = correction_func(p, *setup)
    # Mirror the positive half onto negative radii (radial symmetry).
    vals = np.concatenate((half[1:][::-1], half), axis=0)
    x = np.concatenate((-radii[1:][::-1], radii), axis=None)
    _, (ax, ax2, ax3) = plt.subplots(1, 3, sharex=True)
    ax.plot(x, vals[:, 0])
    ax.set_xlabel("Normalized radius")
    ax.set_ylabel("Total phase aberration K * ψ(radius)")
    if correction_func is not None:
        ax2.plot(x, vals[:, 1])
        ax2.set_ylabel("Correction function K * ψ(radius)")
        ax3.plot(x, vals[:, 0] + vals[:, 1])
        ax3.set_ylabel("Corrected K * ψ(radius)")
    plt.show()
def _intensity_radial_point_spread_func(
    p, v, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, correction_func, exp_f
):
    """Integrand of the radial PSF: ``exp_f(phase(p)) * J0(v * p) * p``.

    ``exp_f`` is ``math.cos`` or ``math.sin``, giving the real or imaginary part of ``exp(i * phase)``.
    """
    phase = total_phase_aberration(
        p, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, correction_func
    )
    oscillation = exp_f(phase)
    return oscillation * scipy.special.jv(0, v * p) * p
def _intensity_radial_point_spread_func_sym(
    p, v, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, correction_func, exp_f
):
    """Symbolic (Sage) integrand of the radial PSF: ``exp_f(phase(p)) * J0(v * p) * p``.

    ``exp_f`` is ``sage.cos`` or ``sage.sin``, giving the real or imaginary part of ``exp(i * phase)``.
    """
    phase = total_phase_aberration_sym(
        p, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, correction_func
    )
    oscillation = exp_f(phase)
    return oscillation * sage.bessel_J(0, v * p) * p
def intensity_radial_point_spread(
    v, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, correction_func=None, symbolic=False
):
    """Return ``|∫_0^1 exp(i * phase(p)) J0(v p) p dp|`` at optical radial coordinate ``v``.

    The cosine (real) and sine (imaginary) parts are integrated separately, with SciPy ``quad`` or, if
    ``symbolic``, with Sage ``numerical_integral``. Sage returns ``(value, error_estimate)``; only
    ``value`` belongs to the integral. Previously the error estimate was used as an imaginary part.

    NOTE(review): this returns the amplitude ``|U|``, not ``|U|**2`` as ``planar_object_at_depth_intensity``
    does — confirm which is intended for an "intensity" PSF.
    """
    parts = []
    if symbolic:
        p = sage.var("p", domain="real")
        for exp_f in (sage.cos, sage.sin):
            integrand = _intensity_radial_point_spread_func_sym(
                p, v, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, correction_func, exp_f
            )
            value, _error = sage.numerical_integral(integrand, sage.Integer(0), sage.Integer(1), max_points=1000)
            parts.append(float(value))
    else:
        for exp_f in (math.cos, math.sin):
            # noinspection PyTupleAssignmentBalance
            value, _error = integrate.quad(
                _intensity_radial_point_spread_func, 0, 1,
                args=(
                    v, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, correction_func, exp_f
                ), limit=1000,
            )
            parts.append(value)
    real, imag = parts
    return abs(complex(real, imag))
def plot_psf(
    objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, correction_func=None, symbolic=False
):
    """Point spread function. Visualized radially.

    Evaluates ``intensity_radial_point_spread`` in a process pool at ``n`` optical coordinates
    ``v = 2*pi*NA * r/λ`` covering ``r/λ`` in ``0..max_r_over_l``. The curves are mirrored about the
    axis and plotted: the raw PSF on the left, and on the right the corrected PSF when
    ``correction_func`` is given. The corrected PSF is only computed when a correction exists;
    otherwise it would just duplicate the raw PSF.
    """
    n = 1000
    max_r_over_l = 50
    v_factor = 2 * math.pi * objective_na
    radii = np.linspace(0, max_r_over_l * v_factor, n)
    corrections = (None,) if correction_func is None else (None, correction_func)
    # Radius-major ordering, so the ordered results reshape to (n, len(corrections)).
    inputs = [
        (v, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, corr, symbolic)
        for v in radii
        for corr in corrections
    ]
    results = []
    pb = tqdm(total=len(inputs))
    with Pool(NUM_THREADS) as pool:
        for res in pool.istarmap(intensity_radial_point_spread, inputs):
            results.append(res)
            pb.update()
    pb.close()
    vals = np.array(results).reshape((n, len(corrections)))
    vals = np.concatenate((vals[1:, :][::-1, :], vals), axis=0)
    x = np.concatenate((-radii[1:][::-1], radii), axis=None) / v_factor
    _, (ax, ax2) = plt.subplots(1, 2, sharex=True)
    ax.plot(x, vals[:, 0])
    ax.set_xlabel("r/λ distance along radius")
    ax.set_ylabel("PSF intensity")
    ax.set_title("Raw PSF radially")
    if correction_func is not None:
        ax2.plot(x, vals[:, 1])
        ax2.set_xlabel("r/λ distance along radius")
        ax2.set_title("Corrected PSF radially")
    plt.show()
def _planar_object_at_depth_intensity_func(
    p, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, correction_func, exp_f
):
    """Integrand ``exp_f(2 * phase(p)) * p`` used by ``planar_object_at_depth_intensity``.

    ``exp_f`` is ``math.cos`` or ``math.sin`` (real or imaginary part).
    NOTE(review): the phase enters doubled — presumably because the light traverses the layers twice
    (to the planar object and back); confirm.
    """
    phase = total_phase_aberration(
        p, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, correction_func
    )
    doubled = 2 * phase
    return exp_f(doubled) * p
def _planar_object_at_depth_intensity_func_sym(
    p, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, correction_func, exp_f
):
    """Symbolic (Sage) integrand ``exp_f(2 * phase(p)) * p`` used by ``planar_object_at_depth_intensity``.

    The parameter order (``wavelength`` before ``correction_func``) matches the numeric
    ``_planar_object_at_depth_intensity_func`` and the order in which ``planar_object_at_depth_intensity``
    supplies its arguments. With the two swapped, the wavelength was treated as the correction function.
    ``exp_f`` is ``sage.cos`` or ``sage.sin``.
    """
    abb = total_phase_aberration_sym(
        p, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, correction_func
    )
    return exp_f(2 * abb) * p
def planar_object_at_depth_intensity(
    objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, correction_func=None, symbolic=False
):
    """Return ``|∫_0^1 exp(2i * phase(p)) p dp|**2`` for a planar object at oil depth ``liquid_d``.

    The cosine (real) and sine (imaginary) parts are integrated separately with SciPy ``quad`` or, if
    ``symbolic``, with Sage ``numerical_integral``. Sage returns ``(value, error_estimate)``; only
    ``value`` is part of the integral. The error estimate was previously used as an imaginary part.
    """
    parts = []
    if symbolic:
        p = sage.var("p")
        for exp_f in (sage.cos, sage.sin):
            # Pass wavelength/correction_func by keyword so they bind by name whatever the helper's
            # positional parameter order is.
            integrand = _planar_object_at_depth_intensity_func_sym(
                p, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d,
                wavelength=wavelength, correction_func=correction_func, exp_f=exp_f,
            )
            value, _error = sage.numerical_integral(integrand, sage.Integer(0), sage.Integer(1), max_points=1000)
            parts.append(float(value))
    else:
        for exp_f in (math.cos, math.sin):
            # noinspection PyTupleAssignmentBalance
            value, _error = integrate.quad(
                _planar_object_at_depth_intensity_func, 0, 1,
                args=(
                    objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, correction_func, exp_f
                ), limit=1000,
            )
            parts.append(value)
    real, imag = parts
    return abs(complex(real, imag)) ** 2
def plot_planar_object_at_depth_intensity(
    objective_na, air_n, glass_n, liquid_n, glass_d, wavelength, center_z_over_l,
    side_z_over_l, correction_func=None, symbolic=False
):
    """Plot the planar-object intensity versus oil depth, raw and (optionally) corrected.

    Oil depths span ``(center_z_over_l ± side_z_over_l) * wavelength`` with ``n`` samples. Each is
    evaluated by ``planar_object_at_depth_intensity`` in a process pool and plotted against ``z / λ``.
    The corrected curve is only computed when ``correction_func`` is given; otherwise it would just
    duplicate the raw curve.
    """
    n = 1000
    liquid_depths = np.linspace(
        (center_z_over_l - side_z_over_l) * wavelength, (center_z_over_l + side_z_over_l) * wavelength, n
    )
    corrections = (None,) if correction_func is None else (None, correction_func)
    # Depth-major ordering, so the ordered results reshape to (n, len(corrections)).
    inputs = [
        (objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, corr, symbolic)
        for liquid_d in liquid_depths
        for corr in corrections
    ]
    results = []
    pb = tqdm(total=len(inputs))
    with Pool(NUM_THREADS) as pool:
        for res in pool.istarmap(planar_object_at_depth_intensity, inputs):
            results.append(res)
            pb.update()
    pb.close()
    vals = np.array(results).reshape((n, len(corrections)))
    liquid_depths /= wavelength
    _, (ax, ax2) = plt.subplots(1, 2, sharex=True)
    ax.plot(liquid_depths, vals[:, 0])
    ax.set_xlabel("Depth (z/λ)")
    ax.set_ylabel("Intensity")
    ax.set_title("Raw PSF axially")
    if correction_func is not None:
        ax2.plot(liquid_depths, vals[:, 1])
        ax2.set_xlabel("Depth (z/λ)")
        ax2.set_ylabel("Intensity")
        ax2.set_title("Corrected PSF axially")
    plt.show()
def _radial_z_point_spread_func(
    p, r, z, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, correction_func, exp_f
):
    """Integrand ``exp_f(phase(p) + u p^2 / 2) * J0(v p) * p`` for a point at radius ``r``, depth ``z``.

    ``u = 8*pi*liquid_n*z/λ * sin^2(theta_liquid / 2)`` is the axial (defocus) coordinate for the offset
    ``z`` from the focal plane, ``v = 2*pi/λ * NA * r`` the radial one, and
    ``theta_liquid = asin(NA / liquid_n)``. ``exp_f`` is ``math.cos`` or ``math.sin``.
    """
    phase = total_phase_aberration(
        p, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, correction_func
    )
    theta_liquid = math.asin(objective_na / liquid_n)
    u = 8 * math.pi * liquid_n * z / wavelength * math.sin(theta_liquid / 2) ** 2
    defocus = u * p ** 2 / 2
    oscillation = exp_f(phase + defocus)
    v = 2 * math.pi / wavelength * objective_na * r
    return oscillation * scipy.special.jv(0, p * v) * p
def _radial_z_point_spread_func_sym(
    p, r, z, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, correction_func, exp_f
):
    """Symbolic (Sage) integrand ``exp_f(phase(p) + u p^2 / 2) * J0(v p) * p``.

    Same ``u`` and ``v`` as ``_radial_z_point_spread_func``. ``exp_f`` is ``sage.cos`` or ``sage.sin``.
    """
    phase = total_phase_aberration_sym(
        p, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, correction_func
    )
    theta_liquid = math.asin(objective_na / liquid_n)
    u = 8 * math.pi * liquid_n * z / wavelength * math.sin(theta_liquid / 2) ** 2
    defocus = u * p ** 2 / 2
    oscillation = exp_f(phase + defocus)
    v = 2 * math.pi / wavelength * objective_na * r
    return oscillation * sage.bessel_J(0, p * v) * p
def radial_z_point_spread(
    r, z, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, objective_wd, wavelength,
    correction_func=None, symbolic=False
):
    """Return the PSF value at radial distance ``r`` and axial offset ``z`` from the focal plane.

    The complex integral ``I = ∫_0^1 exp(i (phase(p) + u p^2 / 2)) J0(v p) p dp`` is computed from its
    cosine (real) and sine (imaginary) parts, using SciPy ``quad`` or, if ``symbolic``, Sage
    ``numerical_integral``. Sage returns ``(value, error_estimate)``; only ``value`` is part of the
    integral. The error estimate was previously used as an imaginary part. The result is

        |2 air_n / (wavelength * objective_wd) * I|**4 * (pi objective_wd**2 NA**2 / (air_n**2 - NA**2))**4

    NOTE(review): the 4th powers presumably model a squared-intensity (e.g. illumination x detection)
    PSF — confirm.
    """
    parts = []
    if symbolic:
        p = sage.var("p")
        for exp_f in (sage.cos, sage.sin):
            integrand = _radial_z_point_spread_func_sym(
                p, r, z, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, correction_func,
                exp_f
            )
            value, _error = sage.numerical_integral(integrand, sage.Integer(0), sage.Integer(1), max_points=1000)
            parts.append(float(value))
    else:
        for exp_f in (math.cos, math.sin):
            # noinspection PyTupleAssignmentBalance
            value, _error = integrate.quad(
                _radial_z_point_spread_func, 0, 1,
                args=(
                    r, z, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, correction_func,
                    exp_f
                ), limit=1000,
            )
            parts.append(value)
    integral = complex(parts[0], parts[1])
    res = abs(2 * air_n / (wavelength * objective_wd) * integral) ** 4
    res *= (math.pi * objective_wd ** 2 * objective_na ** 2 / (air_n ** 2 - objective_na ** 2)) ** 4
    return res
def plot_radial_z_psf(
    objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, objective_wd, wavelength, center_z_over_l,
    side_z_over_l, max_r_over_l, correction_func=None, max_val: float | None = None,
    max_corrected_val: float | None = None, save_fig_root: None | Path = None, save_fig_prefix: str = "",
):
    """Plot the 2D (radius x depth) PSF from ``radial_z_point_spread``: raw/corrected, linear/log10.

    Depths span ``(center_z_over_l ± side_z_over_l) * wavelength``. Radii span
    ``0..max_r_over_l * wavelength`` and are mirrored about the axis. ``max_val`` clips the raw PSF and
    ``max_corrected_val`` clips the corrected PSF; each is applied independently when not None.
    The figure is saved or shown via ``save_or_show``.

    :return: Array of shape ``(n_z, 2 * n_r - 1, 2)``. ``[..., 0]`` is the raw PSF and ``[..., 1]`` the
        corrected PSF (recomputed raw PSF when ``correction_func`` is None), after mirroring and clipping.
    """
    n_z = 10000
    n_r = 1000
    liquid_depths = np.linspace(
        (center_z_over_l - side_z_over_l) * wavelength, (center_z_over_l + side_z_over_l) * wavelength, n_z
    )
    radii = np.linspace(0, max_r_over_l * wavelength, n_r)
    # Ordered depth-major, then radius, then (raw, corrected): the ordered results reshape to (n_z, n_r, 2).
    inputs = [
        (r, z, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, objective_wd, wavelength, corr)
        for z in liquid_depths
        for r in radii
        for corr in (None, correction_func)
    ]
    results = []
    pb = tqdm(total=n_z * n_r * 2)
    with Pool(NUM_THREADS) as pool:
        for res in pool.istarmap(radial_z_point_spread, inputs):
            results.append(res)
            pb.update()
    pb.close()
    vals = np.array(results, dtype=float).reshape((n_z, n_r, 2))
    vals = np.concatenate((vals[:, 1:, :][:, ::-1, :], vals), axis=1)
    radii = np.concatenate((-radii[1:][::-1], radii), axis=None) / wavelength
    liquid_depths /= wavelength
    raw, corrected = vals[:, :, 0], vals[:, :, 1]  # views into ``vals``
    if max_val is not None:
        raw[raw >= max_val] = max_val
    if max_corrected_val is not None:
        corrected[corrected >= max_corrected_val] = max_corrected_val
    # imshow's extent is the outer edge of the pixel grid: pad by half a sample so pixel centers sit
    # on the sampled radius/depth values.
    half_r = (radii[1] - radii[0]) / 2
    half_z = (liquid_depths[1] - liquid_depths[0]) / 2
    extent = radii[0] - half_r, radii[-1] + half_r, liquid_depths[0] - half_z, liquid_depths[-1] + half_z
    fig, axs = plt.subplots(2, 2, sharex=True, sharey=True)
    ax, ax2, ax3, ax4 = axs.flatten()
    # noinspection PyTypeChecker
    im = ax.imshow(raw, aspect="auto", origin="lower", extent=extent, cmap="nipy_spectral")
    im2 = ax2.imshow(np.log10(raw), aspect="auto", origin="lower", extent=extent, cmap="nipy_spectral")
    ax.set_xlabel("Radial distance/λ of point from axis")
    secx = ax.secondary_xaxis('top', functions=(lambda x: x * wavelength * 1e3, lambda x: x / (wavelength * 1e3)))
    secx.set_xlabel("Radial distance (mm) of point from axis")
    ax.set_ylabel("Z-depth/λ of point from focal plane")
    secy = ax.secondary_yaxis('right', functions=(lambda x: x * wavelength * 1e3, lambda x: x / (wavelength * 1e3)))
    secy.set_ylabel("Z-depth (mm) of point from focal plane")
    ax.set_title("Raw PSF")
    ax2.set_title("Log10 intensity")
    divider = make_axes_locatable(ax)
    # Larger pad than the other colorbars — presumably to leave room for the secondary y-axis on the right.
    cax = divider.append_axes('right', size='5%', pad=0.75)
    fig.colorbar(im, cax=cax, orientation='vertical')
    divider = make_axes_locatable(ax2)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    fig.colorbar(im2, cax=cax, orientation='vertical')
    if correction_func is not None:
        im = ax3.imshow(corrected, aspect="auto", origin="lower", extent=extent, cmap="nipy_spectral")
        im2 = ax4.imshow(np.log10(corrected), aspect="auto", origin="lower", extent=extent, cmap="nipy_spectral")
        ax3.set_title("Corrected PSF")
        divider = make_axes_locatable(ax3)
        cax = divider.append_axes('right', size='5%', pad=0.05)
        fig.colorbar(im, cax=cax, orientation='vertical')
        divider = make_axes_locatable(ax4)
        cax = divider.append_axes('right', size='5%', pad=0.05)
        fig.colorbar(im2, cax=cax, orientation='vertical')
    save_or_show(save_fig_root, save_fig_prefix)
    return vals
def plot_center_r_z_psf(
    objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, objective_wd, wavelength, center_z_over_l,
    side_z_over_l, r=0, correction_func=None, save_fig_root: None | Path = None, save_fig_prefix: str = "",
):
    """Plot the axial PSF profile at a fixed radius ``r``: raw on top, corrected below.

    ``radial_z_point_spread`` is evaluated in a process pool at ``n`` depths spanning
    ``(center_z_over_l ± side_z_over_l) * wavelength``. The figure is saved or shown via ``save_or_show``.

    :return: Array of shape ``(n, 2)``: column 0 is the raw PSF, column 1 the corrected PSF.
    """
    n = 10000
    depths = np.linspace(
        (center_z_over_l - side_z_over_l) * wavelength, (center_z_over_l + side_z_over_l) * wavelength, n,
    )
    progress = tqdm(total=len(depths) * 2)
    jobs = [
        (r, z, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, objective_wd, wavelength, corr)
        for z in depths
        for corr in (None, correction_func)
    ]
    collected = []
    with Pool(NUM_THREADS) as pool:
        for value in pool.istarmap(radial_z_point_spread, jobs):
            collected.append(value)
            progress.update()
    progress.close()
    results = np.array(collected).reshape((-1, 2))
    z_over_l = depths / wavelength
    _, (top, bottom) = plt.subplots(2, 1, sharex=True, sharey=False)
    top.plot(z_over_l, results[:, 0])
    secax = top.secondary_xaxis('top', functions=(lambda x: x * wavelength * 1e3, lambda x: x / (wavelength * 1e3)))
    secax.set_xlabel("Z-depth (mm) of point from focal plane")
    top.set_ylabel("Intensity")
    top.set_title(f"Raw PSF at radius/λ = {r / wavelength}")
    bottom.plot(z_over_l, results[:, 1])
    bottom.set_xlabel("Z-depth/λ of point from focal plane")
    bottom.set_ylabel("Intensity")
    bottom.set_title(f"Corrected PSF at radius/λ = {r / wavelength}")
    save_or_show(save_fig_root, save_fig_prefix)
    return results
def phase_plate_surface_from_spec(r, plate_curvature_r, plate_k2, plate_k4):
    """Return the aspheric plate sag at radial distance ``r``.

        z(r) = c r^2 / (1 + sqrt(1 - (c r)^2)) + k2 r^2 + k4 r^4,   c = 1 / plate_curvature_r

    Element-wise on NumPy arrays. Lengths share the caller's units (callers here pass mm).
    """
    c = 1 / plate_curvature_r
    r_sq = r ** 2
    spherical = r_sq * c / (1 + np.sqrt(1 - (r * c) ** 2))
    return spherical + plate_k2 * r_sq + plate_k4 * r ** 4
def phase_plate_surface_derivative_from_spec(r, plate_curvature_r, plate_k2, plate_k4):
    """Return the slope ``dz/dr`` of the plate sag from ``phase_plate_surface_from_spec``.

        z'(r) = c r / sqrt(1 - (c r)^2) + 2 k2 r + 4 k4 r^3,   c = 1 / plate_curvature_r

    Element-wise on NumPy arrays.
    """
    c = 1 / plate_curvature_r
    spherical_slope = r * c / np.sqrt(1 - (r * c) ** 2)
    return spherical_slope + 2 * plate_k2 * r + 4 * plate_k4 * r ** 3
def phase_plate_surface_wavefront_from_spec(r, plate_curvature_r, plate_k2, plate_k4, air_n, plate_n, wavelength):
    """Return the phase added by the aspheric plate surface at radial distance ``r``.

    With ``dn2 = air_n**2 - plate_n**2``, sag ``z(r)`` and slope ``z'(r)`` from the ``*_from_spec`` helpers:

        phase = 2*pi/wavelength * dn2 * z * sqrt(S) / (air_n**2 * (plate_n + sqrt(S))),
        S = air_n**2 + dn2 * z'(r)**2

    ``r``, ``plate_curvature_r`` and ``wavelength`` must share units (callers pass mm). Element-wise on arrays.

    :raises ValueError: If ``S`` is negative anywhere (surface slope too large for the index contrast).
    """
    n_a_minus_i_sq = air_n ** 2 - plate_n ** 2
    sqrt_term = air_n ** 2 + n_a_minus_i_sq * phase_plate_surface_derivative_from_spec(
        r, plate_curvature_r, plate_k2, plate_k4
    ) ** 2
    if np.any(sqrt_term < 0):
        # Report this function's own coefficients (module globals may not exist, e.g. in pool workers,
        # which would turn this error into a NameError).
        raise ValueError(
            f"Got negative sqrt term ({sqrt_term}) for r = {r}, k2 = {plate_k2}, k4 = {plate_k4}"
        )
    root = np.sqrt(sqrt_term)
    z = n_a_minus_i_sq
    z *= phase_plate_surface_from_spec(r, plate_curvature_r, plate_k2, plate_k4)
    z *= root
    z /= air_n ** 2 * (plate_n + root)
    z *= 2 * np.pi / wavelength
    return z
def phase_plate_correction_func(
    p, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, plate_r, plate_curvature_r,
    plate_k2_mm, plate_k4_mm, plate_n
):
    """Correction function (for ``total_phase_aberration``) modeling an aspheric phase plate.

    Evaluates ``phase_plate_surface_wavefront_from_spec`` at physical radius ``p * plate_r``, where ``p``
    is the normalized pupil radius. Lengths in meters (``plate_r``, ``plate_curvature_r``, ``wavelength``)
    are converted to mm to match the ``*_mm`` coefficients. The objective/sample parameters only
    satisfy the ``correction_func`` call signature and are unused.
    """
    to_mm = 1e3
    plate_phase = phase_plate_surface_wavefront_from_spec(
        p * plate_r * to_mm, plate_curvature_r * to_mm, plate_k2_mm, plate_k4_mm, air_n, plate_n, wavelength * to_mm
    )
    # No extra decomposed (Zernike) correction is combined with the plate at the moment.
    extra_correction = 0
    return plate_phase + extra_correction
def plot_phase_plate(
    plate_r, plate_curvature_r, plate_k2_mm, plate_k4_mm, plate_n, objective_na, air_n, glass_n, liquid_n,
    glass_d, liquid_d, wavelength
):
    """Plot a phase plate design against the lens aberration it should cancel.

    Panels: plate sag (mm), plate phase, lens aberration (decomposed orders 4, 6, 8) and their sum,
    all over normalized radius -1..1. Also prints the summed absolute residual, aberration and plate phase.
    """
    n = 1000
    radii = np.linspace(-plate_r, plate_r, n)
    curvature_mm = plate_curvature_r * 1e3
    surface_mm = np.array([
        phase_plate_surface_from_spec(r * 1e3, curvature_mm, plate_k2_mm, plate_k4_mm) for r in radii
    ])
    plate_phase = np.array([
        phase_plate_surface_wavefront_from_spec(
            r * 1e3, curvature_mm, plate_k2_mm, plate_k4_mm, air_n, plate_n, wavelength * 1e3
        )
        for r in radii
    ])
    p_norm = radii / plate_r
    lens_abb = np.array([
        sum(decomposed_aberrations(
            p, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, (4, 6, 8), wavelength
        ))
        for p in p_norm
    ])
    _, axs = plt.subplots(2, 2, sharex=True)
    panels = zip(
        axs.flatten(),
        (surface_mm, plate_phase, lens_abb, lens_abb + plate_phase),
        ("Asphere surface (mm)", "Plate wavefront (waves)", "Lens aberration (waves)",
         "Remaining aberration after correction"),
    )
    for axis, values, ylabel in panels:
        axis.plot(p_norm, values)
        axis.set_xlabel("Radial distance (normalized)")
        axis.set_ylabel(ylabel)
    print(np.sum(np.abs(lens_abb + plate_phase)), np.sum(np.abs(lens_abb)), np.sum(np.abs(plate_phase)))
    plt.show()
def func(x, radii, radii_norm, plate_curvature_r, plate_n, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d,
         wavelength, ):
    """Cost for fitting the phase plate coefficients ``x = (k2_mm, k4_mm)``; used with ``optimize.brute``.

    cost = max(|lens + plate|) - min(|lens + plate|) + 0.3 * max(|plate|)

    ``lens`` is the sum of decomposed aberration orders 4, 6 and 8 at ``radii_norm``. ``plate`` is the
    plate phase at ``radii`` (meters, converted to mm together with the curvature and wavelength).
    """
    k2_mm, k4_mm = x
    lens_abb = np.array(decomposed_aberrations(
        radii_norm, objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, (4, 6, 8), wavelength
    )).sum(axis=0)
    plate_abb = phase_plate_surface_wavefront_from_spec(
        radii * 1e3, plate_curvature_r * 1e3, k2_mm, k4_mm, air_n, plate_n, wavelength * 1e3
    )
    residual = np.abs(lens_abb + plate_abb)
    spread = residual.max() - residual.min()
    return spread + .3 * np.abs(plate_abb).max()
def compute_phase_plate_zernike_to_wavefront(
    plate_r, plate_curvature_r, plate_k2_mm, plate_k4_mm, plate_n, objective_na, air_n, glass_n, liquid_n,
    glass_d, liquid_d, wavelength, ns=10000,
):
    """Brute-force search for plate coefficients ``(k2_mm, k4_mm)`` minimizing ``func``.

    ``func`` is evaluated on ``n`` radii across the plate diameter, over a fixed search grid with
    ``ns`` points per coefficient (``ns**2`` evaluations in total), using all CPUs (``workers=-1``).
    ``plate_k2_mm`` and ``plate_k4_mm`` are not used by the search; they are kept for interface
    compatibility.

    :param ns: Grid points per coefficient passed to ``scipy.optimize.brute`` as ``Ns``.
    :return: The best ``(k2_mm, k4_mm)`` found by ``optimize.brute`` (it is also printed).
    """
    n = 1000
    radii = np.linspace(-plate_r, plate_r, n)
    radii_norm = radii / plate_r
    # Search ranges for (k2_mm, k4_mm). The k2 range is given high-to-low; brute samples it either way.
    ranges = ((-2.45e-7, -2.45e-5), (3.9e-10, 3.9e-8))
    res = optimize.brute(
        func, ranges, Ns=ns, workers=-1,
        args=(
            radii, radii_norm, plate_curvature_r, plate_n, objective_na, air_n, glass_n, liquid_n, glass_d,
            liquid_d, wavelength,
        ),
    )
    print(res)
    return res
def plot_phase_for_k(
    plate_r, plate_curvature_r, plate_k2_mm, plate_k4_mm, plate_n, objective_na, air_n, glass_n, liquid_n,
    glass_d, liquid_d, wavelength
):
    """Plot the plate phase profile for a sweep of scaled ``k2`` coefficients.

    For each ``o`` in ``range(orders)`` and ``i`` in ``range(total)``, ``k2 = plate_k2_mm * i / (total * 10**o)``.
    Each curve is labeled ``"{i}E{o}"``, and a legend is drawn so the labels are actually visible.
    The objective/sample parameters are unused.
    """
    n = 100
    radii = np.linspace(-plate_r, plate_r, n)
    total = 10
    orders = 1
    for o in range(orders):
        for i in range(total):
            vals = phase_plate_surface_wavefront_from_spec(
                radii * 1e3, plate_curvature_r * 1e3, plate_k2_mm * i / (total * 10 ** o), plate_k4_mm, air_n, plate_n,
                wavelength * 1e3
            )
            plt.plot(radii, vals, label=f"{i}E{o}")
    plt.legend()
    plt.show()
def objective_resolution(objective_na, air_n, wavelength, fwhm=False):
    """Return the lateral and axial PSF widths ``(w_xy, w_z)`` of an objective.

        w_xy = 0.320 λ / (sqrt(m) NA)          for 0.1 < NA <= 0.7
        w_xy = 0.325 λ / (sqrt(m) NA**0.91)    for NA > 0.7
        w_z  = 0.532 λ / sqrt(m) / (n - sqrt(n**2 - NA**2))

    with ``m = 1`` and ``n = air_n``. NOTE(review): these look like the empirical Gaussian-fit widths of
    Zipfel et al. (2003) — confirm. NA == 0.7 exactly used to fall through both branches and
    return ``w_xy = 0``; it now uses the first formula. ``w_xy`` is still 0 for NA <= 0.1.
    With ``fwhm`` both widths are multiplied by ``2 * sqrt(ln 2)``.
    """
    m = 1
    w_xy = 0
    if 0.1 < objective_na <= 0.7:
        w_xy = 0.32 * wavelength / (math.sqrt(m) * objective_na)
    elif objective_na > 0.7:
        w_xy = 0.325 * wavelength / (math.sqrt(m) * objective_na ** 0.91)
    w_z = 0.532 * wavelength / math.sqrt(m)
    w_z /= air_n - math.sqrt(air_n ** 2 - objective_na ** 2)
    if fwhm:
        w_xy *= 2 * math.sqrt(math.log(2))
        w_z *= 2 * math.sqrt(math.log(2))
    return w_xy, w_z
def circle_height_at_x(radius, x):
    """The distance between top to bottom of circle in y, at a distance x from center along X.

    Returns the chord length ``2 * radius * sin(arccos(x / radius))``, element-wise on arrays (NaN where
    ``|x| > radius``). Uses ``np.arccos``: the ``np.acos`` alias only exists in NumPy >= 2.0.
    """
    theta = np.arccos(x / radius)
    y = radius * np.sin(theta)
    return 2 * y
def circle_inner_outer_y_dist_at_x(outer_radius, inner_radius, inner_offset_x, n_samples: int = 10_000):
    """Find the smallest chord-length gap between an inner circle and an enclosing outer circle.

    ``n_samples`` x positions are taken across the inner circle, from its center (0) to its
    edge (``inner_radius``). Relative to the outer circle's center, the same points lie
    ``inner_offset_x`` further out. At each point the inner circle's vertical chord length
    is subtracted from the outer circle's (both via ``circle_height_at_x``). When both
    circles share the same vertical center, this is the total top-plus-bottom clearance.

    Returns:
        ``(min_gap, x)``: the smallest outer-minus-inner chord difference and the
        inner-circle x coordinate where it occurs.
    """
    # NOTE(review): if x + inner_offset_x exceeds outer_radius, the outer chord is NaN and
    # NaNs propagate through np.argmin; consider np.nanargmin if that case matters.
    sample_x = np.linspace(0, inner_radius, n_samples)
    gap = (
        circle_height_at_x(outer_radius, sample_x + inner_offset_x)
        - circle_height_at_x(inner_radius, sample_x)
    )
    best = np.argmin(gap)
    return gap[best], sample_x[best]
def plot_inner_outer_movable_dist(outer_radius, inner_radius, max_inner_offset_x, n_offsets: int = 100):
    """Plot the remaining vertical clearance as the inner circle is shifted sideways.

    For ``n_offsets`` horizontal offsets in ``[0, max_inner_offset_x]``, this calls
    ``circle_inner_outer_y_dist_at_x``. The smallest chord gap is drawn in red on the
    left axis. The inner-circle x where that minimum occurs is drawn in blue on a twin
    right axis. After the figure is shown, the gap at the largest offset is printed.
    """
    offsets = np.linspace(0, max_inner_offset_x, n_offsets)
    pairs = [circle_inner_outer_y_dist_at_x(outer_radius, inner_radius, offset) for offset in offsets]
    min_gaps = [gap for gap, _ in pairs]
    x_at_min = [where for _, where in pairs]

    left_color, right_color = "tab:red", "tab:blue"
    fig, left_ax = plt.subplots()
    left_ax.plot(offsets, min_gaps, color=left_color)
    right_ax = left_ax.twinx()
    right_ax.plot(offsets, x_at_min, color=right_color)

    left_ax.set_ylabel("Max movable height", color=left_color)
    left_ax.tick_params(axis='y', labelcolor=left_color)
    left_ax.set_xlabel('Lens horizontal offset')
    right_ax.set_ylabel("Horizontal lens offset at min height", color=right_color)
    right_ax.tick_params(axis='y', labelcolor=right_color)

    fig.tight_layout()
    plt.show()
    print(f"At lens offset of {offsets[-1]}, movable height is {min_gaps[-1]}")
def compute_oil_working_distance(
    lens_r, liquid_n, liquid_d, original_wd, mag, n_pixels=2650, pixel_size_um=6.5, air_n=1.0
):
    """Compute the air gap in front of a lens when part of the light path is liquid.

    The marginal ray runs from the lens edge (``lens_r`` off-axis) to the edge of the
    camera field of view. In air alone it would reach that point at ``original_wd``.
    When ``liquid_d`` of the path is liquid of index ``liquid_n``, the ray refracts at
    the air/liquid interface (Snell's law) and covers less lateral distance inside the
    liquid. The remaining lateral distance must be covered in air, which fixes the
    axial air gap.

    :param lens_r: Radius of the lens aperture (presumably meters, like ``original_wd``).
    :param liquid_n: Refractive index of the liquid layer.
    :param liquid_d: Thickness of the liquid layer along the optical axis.
    :param original_wd: Working distance of the lens in air.
    :param mag: Magnification used to map the camera sensor onto the sample.
    :param n_pixels: Number of camera pixels across the field of view.
    :param pixel_size_um: Camera pixel size in micrometers.
    :param air_n: Refractive index of the medium between the lens and the liquid.
    :returns: ``(air_d, air_r)``, the axial air distance and the lateral distance the
        marginal ray covers in air.
    :raises ValueError: From ``math.asin`` if the ray is totally internally reflected
        (only possible when ``liquid_n < air_n``).
    """
    # Field of view at the sample, converted from micrometers to meters.
    fov = pixel_size_um / mag * n_pixels / 1e6
    fov_2 = fov / 2
    # Marginal-ray angle in air: lens edge to field-of-view edge at the nominal working distance.
    theta_1 = math.atan((lens_r - fov_2) / original_wd)
    # Snell's law: air_n * sin(theta_1) = liquid_n * sin(theta_2).
    # The sine must be taken before dividing by the index, not after.
    theta_2 = math.asin(air_n * math.sin(theta_1) / liquid_n)
    liquid_r = liquid_d * math.tan(theta_2)
    air_r = lens_r - fov_2 - liquid_r
    air_d = air_r / math.tan(theta_1)
    return air_d, air_r
if __name__ == "__main__":
    # Script driver. The shared optical parameters are set first. The loop at the
    # bottom then prints the aperture-limited NA and the estimated resolution of each
    # candidate objective. The many commented-out calls are alternative analyses
    # toggled by hand and refer to the local names below, so keep those names stable.
    figure_root = Path(r"/home/matte/lightsheet_figures")
    # Default objective (same values as lens_2xc below).
    # NOTE(review): lengths appear to be in meters (wavelength is 500e-9) — confirm.
    objective_na = 0.5
    objective_wd = 0.02
    objective_diameter = 0.047
    tube_diameter = 0.03
    # Refractive indices (*_n) and thicknesses (*_d) of the air / glass / liquid layers.
    air_n = 1
    glass_n = 1.516
    liquid_n = 1.517
    glass_d = 0.001
    liquid_d = 0.013
    wavelength = 500e-9
    # NA of the default objective after clipping by the tube aperture. Only the
    # commented-out calls read this value; the loop below recomputes it per lens.
    corrected_objective_na = reduce_na_via_aperture(objective_na, tube_diameter / objective_diameter, medium_n=air_n)
    # Phase-plate specification, consumed only by the commented-out phase-plate calls.
    # phase_plate_t is not read anywhere in this block.
    phase_plate_n = 1.5168
    phase_plate_t = .004
    phase_plate_r = .05 / 2
    phase_plate_curvature_r = float("inf")
    phase_plate_k2_mm = -2.44876418e-06
    phase_plate_k4_mm = 3.92193993e-09
    # print(reduce_na_via_aperture(0.176, 35 / 50, medium_n=1))
    # Active correction: decomposed_phase_correction_func restricted to corrected_orders=(4,).
    # The alternatives (phase-plate model, or no correction) stay commented out below.
    correction_func = partial(decomposed_phase_correction_func, corrected_orders=(4, ))
    # correction_func = partial(
    # phase_plate_correction_func, plate_r=phase_plate_r, plate_curvature_r=phase_plate_curvature_r,
    # plate_k2_mm=phase_plate_k2_mm, plate_k4_mm=phase_plate_k4_mm, plate_n=phase_plate_n,
    # )
    # correction_func = None
    # print(aberration_An0_coef(objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, 4, wavelength))
    # plot_decomposed_aberrations(objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength)
    # plot_decomposed_coef_na_vs_depth(air_n, glass_n, liquid_n, glass_d, wavelength, 4)
    # plot_decomposed_coef_na_at_depth(liquid_d, air_n, glass_n, liquid_n, glass_d, wavelength)
    # plot_decomposed_coef_depth_at_na(objective_na, air_n, glass_n, liquid_n, glass_d, wavelength)
    # plot_total_phase_aberrations_vs_rad(objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, correction_func)
    # plot_psf(objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength, correction_func, symbolic=False)
    # plot_planar_object_at_depth_intensity(objective_na, air_n, glass_n, liquid_n, glass_d, wavelength, -1800, 50, correction_func, symbolic=False)
    # plot_radial_z_psf(objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, objective_wd, wavelength, -10700, 1250, 10, correction_func, symbolic=False)
    # plot_center_r_z_psf(objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, objective_wd, wavelength, -10310, 1000, 0, correction_func, symbolic=False)
    # plot_phase_plate(
    # phase_plate_r, phase_plate_curvature_r, phase_plate_k2_mm, phase_plate_k4_mm, phase_plate_n, objective_na,
    # air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength
    # )
    # Per-objective parameters. The *_z_over_l values and prefix are used only by the
    # commented-out PSF plots inside the loop.
    lens_2xc = {
        "objective_na": 0.5,
        "objective_wd": 0.02,
        "objective_diameter": 0.047,
        "center_z_over_l": -10725,
        "side_z_over_l": 1000,
        "prefix": "MV_PLAPO_2XC",
    }
    lens_1_6x = {
        "objective_na": 0.176,
        "objective_wd": 0.034,
        "objective_diameter": 0.051,
        "center_z_over_l": -10190,
        "side_z_over_l": 500,
        "prefix": "DFPLFL_1_6X",
    }
    lens_1_2x = {
        "objective_na": 0.25,
        "objective_wd": 0.065,
        "objective_diameter": 0.051,
        "center_z_over_l": -10190,
        "side_z_over_l": 800,
        "prefix": "DFPLAPO_1_2X",
    }
    for lens in (lens_2xc, lens_1_6x, lens_1_2x):
        # objective_wd, center_z_over_l, side_z_over_l and prefix are rebound here for the
        # commented-out calls below. objective_na and objective_diameter are read straight
        # from the dict; the outer names of the same name are not updated.
        objective_wd = lens["objective_wd"]
        center_z_over_l = lens["center_z_over_l"]
        side_z_over_l = lens["side_z_over_l"]
        prefix = lens["prefix"]
        corrected_objective_na = reduce_na_via_aperture(
            lens["objective_na"], tube_diameter / lens["objective_diameter"], medium_n=air_n
        )
        # Prints: nominal NA, aperture-limited NA, and the (lateral, axial) widths
        # returned by objective_resolution at the aperture-limited NA.
        print(lens["objective_na"], corrected_objective_na, objective_resolution(corrected_objective_na, air_n, wavelength))
        # plot_decomposed_aberrations(
        #     corrected_objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, wavelength,
        #     # save_fig_root=figure_root / "zernike_aberration", save_fig_prefix=f"{prefix}_zernike_aberration",
        # )
        # NOTE(review): in the block below, plot_radial_z_psf is passed objective_na (the
        # default objective's nominal NA), not corrected_objective_na as plot_center_r_z_psf
        # is. Confirm this is intended before re-enabling.
        # for orders, orders_name in (((4,), "4"), ((4, 6, 8), "4-8")):
        #     correction_func = partial(decomposed_phase_correction_func, corrected_orders=orders)
        #
        #     vals = plot_center_r_z_psf(
        #         corrected_objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, objective_wd, wavelength,
        #         center_z_over_l, side_z_over_l, 0, correction_func,
        #         save_fig_root=figure_root / "center_r_psf",
        #         save_fig_prefix=f"{prefix}_center_r_psf_corrected_orders_{orders_name}",
        #     )
        #     plot_radial_z_psf(
        #         objective_na, air_n, glass_n, liquid_n, glass_d, liquid_d, objective_wd, wavelength, center_z_over_l,
        #         side_z_over_l, 10, correction_func, max_val=np.max(vals[:, :, 0]),
        #         max_corrected_val=np.max(vals[:, :, 1]),
        #         save_fig_root=figure_root / "psf",
        #         save_fig_prefix=f"{prefix}_psf_corrected_orders_{orders_name}",
        #     )
    # plot_inner_outer_movable_dist(95.5 / 2, 67 / 2, 6)
    # Air gap needed in front of a lens when liquid_d of the light path is liquid.
    # The first call's lens diameter (0.051) and working distance (0.034) match lens_1_6x.
    # The second call's working distance (0.065) matches lens_1_2x, but its diameter
    # (0.054) differs from lens_1_2x's 0.051. NOTE(review): confirm which is intended.
    print(compute_oil_working_distance(0.051 / 2, liquid_n, liquid_d, 0.034, 3.2))
    print(compute_oil_working_distance(0.054 / 2, liquid_n, liquid_d, 0.065, 4))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment