Last active
August 22, 2021 14:21
-
-
Save hanjinliu/b2b84ec6774b79d2027fa8a5ed945411 to your computer and use it in GitHub Desktop.
Make image analysis protocols in napari from a function
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import inspect | |
from typing import Callable | |
import napari | |
from magicgui.widgets import Label, Table, create_widget, Container | |
import numpy as np | |
from skimage.measure import regionprops_table | |
# The GUI class turns a generator function into an image analysis protocol via the @gui.bind_protocol decorator.
class GUI:
    """
    Bind a step-by-step protocol to two keys of a napari viewer.

    A protocol is a generator function that yields step functions. Each key
    press calls the current step with the values of the widgets in the
    "Parameter Container" dock widget, then advances the generator.

    Attributes
    ----------
    viewer : napari.Viewer
        Viewer that receives the key bindings and the dock widget.
    proceed : bool
        Set to False when ``key1`` is pressed and True when ``key2`` is pressed,
        right before the current step is called.
    """

    def __init__(self, viewer:"napari.Viewer"):
        self.viewer = viewer
        self.proceed = False
        self._yielded_func = None
        # Created lazily by ``_add_parameter_container``.
        self._container = None

    def bind_protocol(self, func=None, key1:str="F1", key2:str="F2",
                      allowed_dims:int|tuple[int, ...]=(1, 2, 3), exit_with_error:bool=False):
        """
        Make protocol from a generator of functions.

        Can be used both as ``@gui.bind_protocol`` and as
        ``@gui.bind_protocol(key1=..., ...)``.

        Parameters
        ----------
        func : function generator
            Protocol function. This function must accept ``func(self)`` and yield functions that accept
            ``f(self, **kwargs)`` . Docstring of the yielded functions will be displayed on the top of the
            parameter container. Therefore it would be very useful if you write procedure of
            the protocol as docstrings.
        key1 : str, default is "F1"
            First key binding. When this key is pushed ``self.proceed`` will be False.
        key2 : str, default is "F2"
            Second key binding. When this key is pushed ``self.proceed`` will be True.
        allowed_dims : int or tuple of int, default is (1, 2, 3)
            Function will not be called if the number of displayed dimensions does not match it.
        exit_with_error : bool, default is False
            If True, protocol will quit whenever exception is raised and key bindings will be released. If
            False, protocol continues from the same step.

        Returns
        -------
        The protocol function itself when ``func`` is given, otherwise a decorator
        that binds the protocol it receives.

        Raises
        ------
        TypeError
            If the protocol is not callable.
        """
        allowed_dims = (allowed_dims,) if isinstance(allowed_dims, int) else tuple(allowed_dims)
        widget_name = "Parameter Container"

        def wrapper(protocol):
            if not callable(protocol):
                raise TypeError("func must be callable.")

            gen = protocol(self) # prepare generator from protocol function

            # initialize with the first step
            self.proceed = False
            self._yielded_func = next(gen)
            self._add_parameter_container(self._yielded_func)

            def _exit(viewer:"napari.Viewer"):
                # Release BOTH key bindings; a default is given so that a key that
                # is already gone does not raise.
                for key in (key1, key2):
                    viewer.keymap.pop(key, None)
                # The container only exists if some step had a docstring or parameters.
                dock_widgets = viewer.window._dock_widgets
                if widget_name in dock_widgets:
                    viewer.window.remove_dock_widget(dock_widgets[widget_name])
                return None

            @self.viewer.bind_key(key1, overwrite=True)
            def _1(viewer:"napari.Viewer"):
                self.proceed = False
                return _base(viewer)

            @self.viewer.bind_key(key2, overwrite=True)
            def _2(viewer:"napari.Viewer"):
                self.proceed = True
                return _base(viewer)

            def _base(viewer:"napari.Viewer"):
                if viewer.dims.ndisplay not in allowed_dims:
                    return None

                try:
                    # call the current function
                    self._yielded_func(self, **self.params)
                except Exception:
                    if exit_with_error:
                        _exit(viewer)
                    raise

                try:
                    next_func = next(gen)
                except StopIteration:
                    # protocol finished: release keys and remove the container
                    _exit(viewer)
                else:
                    # Rebuild the container only when the step actually changed,
                    # so that parameter values survive repeated calls of one step.
                    if next_func != self._yielded_func:
                        self._yielded_func = next_func
                        self._add_parameter_container(self._yielded_func)

                # update all the layers
                for layer in viewer.layers:
                    layer.refresh()
                return None

            return protocol

        return wrapper if func is None else wrapper(func)

    def _clear_container(self):
        """Remove every widget from the existing parameter container."""
        container = self._container
        if container is None:
            return None
        container.clear()
        layout = container.native.layout()
        while layout.count() > 0:
            layout.takeAt(0)
        return None

    def _add_parameter_container(self, f:Callable):
        """
        Show the docstring and the parameters of step ``f`` in the dock widget.

        The first parameter of ``f`` is skipped because it receives this GUI
        object; every other parameter gets a widget made by ``create_widget``.
        Widgets of the previous step are always removed first, so that their
        values are not passed to ``f`` as unexpected keyword arguments.
        """
        widget_name = "Parameter Container"
        params = inspect.signature(f).parameters
        exists = widget_name in self.viewer.window._dock_widgets

        if exists:
            # clear all the widgets if container already exists
            self._clear_container()

        if not f.__doc__ and len(params) <= 1:
            # nothing to display for this step
            return None

        if not exists:
            # make new container
            self._container = Container(name=widget_name)
            wid = self.viewer.window.add_dock_widget(self._container, area="right", name=widget_name)
            wid.resize(140, 100)
            wid.setFloating(True)

        if f.__doc__:
            self._container.append(Label(value=f.__doc__))

        for name, param in list(params.items())[1:]:
            value = None if param.default is inspect.Parameter.empty else param.default
            widget = create_widget(value=value, annotation=param.annotation,
                                   name=name, param_kind=param.kind)
            self._container.append(widget)

        self.viewer.window._dock_widgets[widget_name].show()
        return None

    @property
    def params(self) -> dict:
        """
        Get parameter values from the container

        Returns a ``{widget name: widget value}`` dict of every widget except
        labels, or an empty dict if no container has been created yet.
        """
        container = getattr(self, "_container", None)
        if container is None:
            return {}
        return {wid.name: wid.value for wid in container if not isinstance(wid, Label)}
# Defining a class is not required, but it makes protocols easier to write.
# Here we define the Measure class, which runs regionprops around manually picked points.
class Measure:
    # Protocol steps for running regionprops around manually picked points.
    #
    # NOTE: the docstrings of the step methods below are displayed to the user
    # in the parameter container, so they are written as user instructions and
    # developer documentation is kept in comments.

    def __init__(self):
        # Layers are created/selected by the steps below.
        self.image_layer = None
        self.labels_layer = None
        self.points_layer = None

    # Step 1: add a marker at the cursor position. When called with
    # ``gui.proceed`` set, nothing is added.
    def select_molecules(self, gui:GUI):
        """
        Add markers with "F1".
        Go to next step with "F2".
        """
        if gui.proceed:
            return
        pos = gui.viewer.cursor.position
        if self.points_layer is None:
            self.points_layer = gui.viewer.add_points(pos,
                                                      face_color=[0,0,0,0],
                                                      edge_color=[0,1,0,1],
                                                      )
        else:
            self.points_layer.add(pos)

    # Step 2: remember the selected image layer and add an empty labels layer
    # of the same shape.
    # Raises ValueError if no layer is selected and TypeError if the selected
    # layer is not an image.
    def select_image(self, gui:GUI):
        """
        Select target image and push "F1".
        """
        selection = list(gui.viewer.layers.selection)
        if not selection:
            raise ValueError("No layer is selected.")
        selected = selection[0]
        if not isinstance(selected, napari.layers.Image):
            raise TypeError("Selected layer is not an image.")
        self.image_layer = selected
        labels = np.zeros(self.image_layer.data.shape, dtype=np.uint32)
        self.labels_layer = gui.viewer.add_labels(labels, opacity=0.5)

    # Step 3: paint a square of half-width ``radius`` around every marker, using
    # label ``i+1`` for the i-th marker. The labels array is modified in place.
    # Only the first two coordinates of each marker and the first two axes of
    # the labels are used (assumes a 2D image -- TODO confirm for nD data).
    # Raises ValueError if no marker has been added.
    def label(self, gui:GUI, radius=3):
        """
        Set proper radius to label around markers.
        Push "F1" to preview.
        Push "F2" to apply.
        """
        if self.points_layer is None:
            raise ValueError("No markers have been added.")
        # Round to the nearest pixel instead of truncating toward zero.
        coords = np.rint(self.points_layer.data).astype(np.int32)
        lbl = self.labels_layer.data
        # Erase the previous preview; otherwise shrinking the radius would leave
        # the old, larger regions behind.
        lbl[...] = 0
        height, width = lbl.shape[:2]
        for index, (row, col) in enumerate(coords[:, :2], start=1):
            # Window [c - radius, c + radius] centered on the marker. Both ends
            # are clamped into [0, size] so that off-image markers give an empty
            # slice rather than a negative (wrap-around) index.
            y0 = min(max(row - radius, 0), height)
            y1 = min(max(row + radius + 1, 0), height)
            x0 = min(max(col - radius, 0), width)
            x1 = min(max(col + radius + 1, 0), width)
            lbl[y0:y1, x0:x1] = index

    def measure(self, gui:GUI):
        # Measure mean intensity and area of every labeled region and show the
        # result as a table docked on the right of the viewer.
        d = regionprops_table(self.labels_layer.data, self.image_layer.data, properties=("mean_intensity", "area"))
        table = Table(d)
        gui.viewer.window.add_dock_widget(table.native, name="Measurement", area="right")
if __name__ == "__main__":
    viewer = napari.Viewer()
    gui = GUI(viewer)

    # The generator "func" below becomes an image analysis protocol bound to
    # the viewer's keys.
    @gui.bind_protocol
    def func(gui):
        measure = Measure()

        def repeat_until_proceed(step):
            # Keep yielding the same step until it is called with the "proceed" key.
            gui.proceed = False
            while not gui.proceed:
                yield step

        # 1. pick molecules, 2. choose the image, 3. tune the label radius
        yield from repeat_until_proceed(measure.select_molecules)
        yield measure.select_image
        yield from repeat_until_proceed(measure.label)

        # Runs after the last step, right before the protocol ends.
        measure.measure(gui)

    napari.run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment