Created
February 3, 2025 20:46
-
-
Save matham/2a499bbba251117287857da0aa6c5aeb to your computer and use it in GitHub Desktop.
Export results for teaball experiments: sniffing, occupancy, etc.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass, field | |
from typing import Callable, Union | |
from pathlib import Path | |
import csv | |
import math | |
import numpy as np | |
from functools import partial | |
import matplotlib.pyplot as plt | |
from mpl_toolkits.axes_grid1 import make_axes_locatable | |
from scipy.ndimage import gaussian_filter | |
import json | |
def save_or_show(save_fig_root: None | Path = None, save_fig_prefix: str = "", width_inch: int = 8, height_inch: int = 6):
    """Save the current matplotlib figure to disk, or display it interactively.

    When ``save_fig_root`` is given:

    - the directory is created if needed;
    - the current figure is resized to ``width_inch`` x ``height_inch`` inches and laid out tightly;
    - it is written to ``save_fig_root / f"{save_fig_prefix}.png"`` at 300 dpi;
    - the figure is then closed.

    Otherwise the layout is tightened and ``plt.show()`` is called. The size arguments
    are ignored in that case.
    """
    if not save_fig_root:
        plt.tight_layout()
        plt.show()
        return

    save_fig_root.mkdir(parents=True, exist_ok=True)
    figure = plt.gcf()
    figure.set_size_inches(width_inch, height_inch)
    figure.tight_layout()
    target = save_fig_root / f"{save_fig_prefix}.png"
    figure.savefig(target, bbox_inches='tight', dpi=300)
    plt.close()
@dataclass | |
class Experiment: | |
date: int | |
subject: int | |
condition: str | |
cell_label: str | |
sex: str | |
pre_start: float | |
pre_end: float | |
trial_start: float | |
trial_end: float | |
post_start: float | |
post_end: float | |
pos_data: "SubjectPos" = field(default=None, init=False, repr=False) | |
pre_pos_data: "SubjectPos" = field(default=None, init=False, repr=False) | |
trial_pos_data: "SubjectPos" = field(default=None, init=False, repr=False) | |
post_pos_data: "SubjectPos" = field(default=None, init=False, repr=False) | |
box_center: tuple[int, int] = field(default=None, init=False) | |
box_size: tuple[int, int] = field(default=None, init=False) | |
image_size: tuple[int, int] = field(default=None, init=False) | |
image_offset: tuple[int, int] = field(default=None, init=False) | |
enlarged_image_size: tuple[int, int] = field(default=None, init=False, repr=False) | |
largest_box_size: tuple[int, int] = field(default=None, init=False, repr=False) | |
def pos_path_filename(self, data_root: Path) -> Path: | |
return data_root / f"{self.date}_{self.subject}_top_tracked.csv" | |
def box_metadata_filename(self, data_root: Path) -> Path: | |
return data_root / f"{self.date}_{self.subject}_top_000000000.json" | |
@classmethod | |
def parse_experiment_spec_csv(cls, filename: Path) -> list["Experiment"]: | |
experiments = [] | |
with open(filename, "r") as fh: | |
reader = csv.reader(fh) | |
header = next(reader) | |
for row in reader: | |
date = int(row[0]) | |
subject = int(row[1]) | |
condition = row[2] | |
label = row[3] | |
sex = row[4] | |
times = [] | |
for t in row[5:]: | |
min, sec = map(int, t.split(":")) | |
times.append(min * 60 + sec) | |
experiment = cls( | |
date=date, subject=subject, condition=condition, cell_label=label, sex=sex, | |
pre_start=times[0], pre_end=times[1], | |
trial_start=times[2], trial_end=times[3], | |
post_start=times[4], post_end=times[5], | |
) | |
experiments.append(experiment) | |
return experiments | |
def parse_box_metadata(self, data_root: Path, box_names: tuple[str] = ("farside", "nearside")): | |
filename = self.box_metadata_filename(data_root) | |
with open(filename, "r") as fh: | |
data = json.load(fh) | |
shape = None | |
for s in data["shapes"]: | |
for name in box_names: | |
if s["label"] == name: | |
shape = s | |
if shape is None: | |
raise ValueError(f"Cannot find {box_names} in the json file. {filename}") | |
points = np.array(shape["points"]) | |
min_x, min_y = np.min(points, axis=0) | |
max_x, max_y = np.max(points, axis=0) | |
w = max_x - min_x | |
h = max_y - min_y | |
self.box_center = int(min_x + w / 2), int(min_y + h / 2) | |
self.box_size = int(w), int(h) | |
self.image_size = int(data["imageWidth"]), int(data["imageHeight"]) | |
@classmethod | |
def enlarge_canvas(cls, experiments: list["Experiment"]) -> tuple[int, int]: | |
box_centers = np.array([e.box_center for e in experiments]) | |
box_sizes = np.array([e.box_size for e in experiments]) | |
image_sizes = np.array([e.image_size for e in experiments]) | |
max_image = np.max(image_sizes, axis=0) | |
image_centers = np.floor(image_sizes / 2) | |
adjusted_image_centers = np.floor(image_centers + (max_image[None, :] - image_sizes) / 2) | |
adjusted_image_offsets = adjusted_image_centers - image_centers | |
adjusted_box_centers = box_centers + adjusted_image_offsets | |
mean_adjusted_box_centers = np.floor(np.mean(adjusted_box_centers, axis=0)) | |
aligned_box_offsets = mean_adjusted_box_centers - adjusted_box_centers | |
total_image_offset = adjusted_image_offsets + aligned_box_offsets | |
min_offset = np.min(total_image_offset, axis=0) | |
max_offset = np.max(total_image_offset, axis=0) | |
final_image_size = max_image + max_offset - min_offset | |
max_box_size = np.max(box_sizes, axis=0) | |
for i, experiment in enumerate(experiments): | |
experiment.image_offset = tuple(map(int, total_image_offset[i, :])) | |
experiment.enlarged_image_size = tuple(map(int, final_image_size)) | |
experiment.largest_box_size = tuple(map(int, max_box_size)) | |
return tuple(map(int, final_image_size)) | |
def position_to_side( | |
self, x: int, y: int, split_horizontally: bool = True, | |
sides_label: tuple[str, ...] = ("Near-side", "Far-side"), | |
): | |
cx, cy = self.box_center | |
if split_horizontally: | |
i = 0 if x < cx else 1 | |
else: | |
i = 0 if y < cy else 1 | |
return sides_label[i] | |
def position_to_grid_index(self, x, y, grid_width: int, grid_height: int) -> tuple[float, float, float]: | |
if grid_width == self.image_size[0] and grid_height == self.image_size[1]: | |
return x + self.image_offset[0], y + self.image_offset[1], 1 | |
cx, cy = self.box_center | |
bw, bh = self.box_size | |
iw, ih = self.image_size | |
prop_x = 1 | |
prop_y = 1 | |
grid_item_width = bw / grid_width | |
grid_item_height = bh / grid_height | |
left_x_prop = 1 + (cx - bw / 2) / grid_item_width | |
right_x_prop = 1 + (iw - (cx + bw / 2)) / grid_item_width | |
bottom_y_prop = 1 + (cy - bh / 2) / grid_item_height | |
top_y_prop = 1 + (ih - (cy + bh / 2)) / grid_item_height | |
if x < cx - bw / 2: | |
x = 0 | |
prop_x = left_x_prop | |
elif x >= cx + bw / 2: | |
x = grid_width - 1 | |
prop_x = right_x_prop | |
else: | |
left = cx - bw / 2 | |
box_prop = (x - left) / bw | |
x = min(math.floor(box_prop * grid_width), grid_width - 1) | |
if x == 0: | |
prop_x = left_x_prop | |
elif x == grid_width - 1: | |
prop_x = right_x_prop | |
if y <= cy - bh / 2: | |
y = 0 | |
prop_y = bottom_y_prop | |
elif y >= cy + bh / 2: | |
y = grid_height - 1 | |
prop_y = top_y_prop | |
else: | |
bottom = cy - bh / 2 | |
box_prop = (y - bottom) / bh | |
y = min(math.floor(box_prop * grid_height), grid_height - 1) | |
if y == 0: | |
prop_y = bottom_y_prop | |
elif y == grid_height - 1: | |
prop_y = top_y_prop | |
return x, y, 1 / (prop_x * prop_y) | |
def read_pos_track( | |
self, data_root: Path, hab_duration: float, trial_duration: float, post_duration: float | |
) -> None: | |
self.pos_data = SubjectPos.parse_csv_track(self.pos_path_filename(data_root)) | |
self.pre_pos_data = self.pos_data.extract_range( | |
self.pre_start, min(self.pre_end, self.pre_start + hab_duration) | |
) | |
self.trial_pos_data = self.pos_data.extract_range( | |
self.trial_start, min(self.trial_end, self.trial_start + trial_duration) | |
) | |
self.post_pos_data = self.pos_data.extract_range( | |
self.post_start, min(self.post_end, self.post_start + post_duration) | |
) | |
@classmethod | |
def _get_data_items(cls, obj: Union["Experiment", None] = None) -> list[tuple[Union["SubjectPos", str], str]]: | |
if obj is None: | |
res = [ | |
("pre_pos_data", "Habituation"), | |
("trial_pos_data", "Trial"), | |
("post_pos_data", "Post Trial"), | |
] | |
else: | |
res = [ | |
(obj.pre_pos_data, "Habituation"), | |
(obj.trial_pos_data, "Trial"), | |
(obj.post_pos_data, "Post Trial"), | |
] | |
return res | |
@classmethod | |
def _iter_periods_and_groups(cls, experiments, periods, axs, filter_args: list[dict] | None): | |
n_groups = 1 if filter_args is None else len(filter_args) | |
it = iter(axs.flatten()) | |
for i, filter_group in enumerate(filter_args or [None, ]): | |
if n_groups > 1: | |
experiments_ = cls.filter(experiments, **filter_group) | |
else: | |
experiments_ = experiments | |
for j, (data_name, t) in enumerate(periods): | |
ax = next(it) | |
yield experiments_, i, filter_group, j, data_name, t, ax | |
def plot_occupancy( | |
self, grid_width: int, grid_height: int, | |
gaussian_sigma: float = 0, intensity_limit: float = 0, frame_normalize: bool = True, | |
scale_to_one: bool = True, save_fig_root: None | Path = None, save_fig_prefix: str = "", | |
): | |
fig, axs = plt.subplots(1, 3, sharey=True, sharex=True) | |
for i, ((data, title), ax) in enumerate(zip(self._get_data_items(self), axs.flatten())): | |
data.plot_occupancy( | |
grid_width, grid_height, pos_to_index=self.position_to_grid_index, fig=fig, ax=ax, | |
gaussian_sigma=gaussian_sigma, | |
intensity_limit=intensity_limit, frame_normalize=frame_normalize, scale_to_one=scale_to_one, | |
color_bar=not i, | |
x_label="X (pixels)", | |
y_label="Y (pixels)" if not i else "", | |
title="", | |
) | |
ax.set_title(f"$\\bf{{{title}}}$") | |
fig.suptitle(f"#{self.subject} occupancy density") | |
save_or_show(save_fig_root, save_fig_prefix) | |
@classmethod | |
def plot_multi_experiment_occupancy( | |
cls, experiments: list["Experiment"], grid_width: int, grid_height: int, | |
gaussian_sigma: float = 0, intensity_limit: float = 0, title: str = "Subjects occupancy density", | |
frame_normalize: bool = True, experiment_normalize: bool = True, scale_to_one: bool = True, | |
save_fig_root: None | Path = None, save_fig_prefix: str = "", | |
filter_args: list[dict] | None = None, | |
): | |
n_groups = 1 if filter_args is None else len(filter_args) | |
periods = cls._get_data_items() | |
n_periods = len(periods) | |
fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=True) | |
for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups( | |
experiments, periods, axs, filter_args): | |
occupancy = np.zeros((grid_width, grid_height)) | |
for experiment in experiments_: | |
getattr(experiment, data_name).calculate_occupancy( | |
occupancy, experiment.position_to_grid_index, frame_normalize, | |
) | |
if experiment_normalize: | |
occupancy /= len(experiments_) | |
group = "" | |
if n_groups > 1: | |
group = f"$\\bf{{{filter_group['condition']}}}$\n\n" | |
SubjectPos.plot_occupancy_data( | |
occupancy, fig, ax, gaussian_sigma, intensity_limit, scale_to_one, | |
color_bar=not i and not j, | |
title="", | |
x_label="X (pixels)" if i == n_groups - 1 else "", | |
y_label=f"{group}Y (pixels)" if not j else "", | |
) | |
if not i: | |
ax.set_title(f"$\\bf{{{t}}}$") | |
fig.suptitle(title) | |
save_or_show(save_fig_root, save_fig_prefix) | |
def plot_motion(self, y_limit: float | None = None, save_fig_root: None | Path = None, save_fig_prefix: str = ""): | |
fig, axs = plt.subplots(1, 3, sharey=True, sharex=False) | |
for i, ((data, title), ax) in enumerate(zip(self._get_data_items(self), axs.flatten())): | |
data.plot_motion( | |
fig, ax, **{"y_label": ""} if i else {}, | |
) | |
if y_limit is not None: | |
ax.set_ylim(0, y_limit) | |
ax.set_title(f"$\\bf{{{title}}}$") | |
fig.suptitle(f"#{self.subject} motion index") | |
save_or_show(save_fig_root, save_fig_prefix) | |
@classmethod | |
def plot_multi_experiment_motion( | |
cls, experiments: list["Experiment"], y_limit: float | None = None, title: str = "Subjects motion index", | |
save_fig_root: None | Path = None, save_fig_prefix: str = "", | |
filter_args: list[dict] | None = None, | |
): | |
n_groups = 1 if filter_args is None else len(filter_args) | |
periods = cls._get_data_items() | |
n_periods = len(periods) | |
fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=False) | |
for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups( | |
experiments, periods, axs, filter_args): | |
for experiment in experiments_: | |
group = "" | |
if n_groups > 1: | |
group = f"$\\bf{{{filter_group['condition']}}}$\n\n" | |
kwargs = {"y_label": ""} | |
if i != n_groups - 1: | |
kwargs["x_label"] = "" | |
if not j: | |
kwargs["y_label"] = f"{group}Motion index" | |
getattr(experiment, data_name).plot_motion(fig, ax, **kwargs) | |
if y_limit is not None: | |
ax.set_ylim(0, y_limit) | |
if not i: | |
ax.set_title(f"$\\bf{{{t}}}$") | |
fig.suptitle(title) | |
save_or_show(save_fig_root, save_fig_prefix) | |
def plot_motion_histogram(self, n_bins=100, save_fig_root: None | Path = None, save_fig_prefix: str = ""): | |
fig, axs = plt.subplots(1, 3, sharey=True, sharex=True) | |
for i, ((data, title), ax) in enumerate(zip(self._get_data_items(self), axs.flatten())): | |
data = data.motion_index | |
data = data[data >= 0] | |
ax.hist(data, bins=n_bins, density=True, range=(0, 3)) | |
ax.set_xlabel("Motion index") | |
if not i: | |
ax.set_ylabel("Density") | |
ax.set_title(f"$\\bf{{{title}}}$") | |
fig.suptitle(f"#{self.subject} motion index density") | |
save_or_show(save_fig_root, save_fig_prefix) | |
@classmethod | |
def plot_multi_experiment_motion_histogram( | |
cls, experiments: list["Experiment"], n_bins=100, title: str = "Subjects motion index", | |
save_fig_root: None | Path = None, save_fig_prefix: str = "", | |
filter_args: list[dict] | None = None, | |
): | |
n_groups = 1 if filter_args is None else len(filter_args) | |
periods = cls._get_data_items() | |
n_periods = len(periods) | |
fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=True) | |
for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups( | |
experiments, periods, axs, filter_args): | |
items = [] | |
for experiment in experiments_: | |
data = getattr(experiment, data_name).motion_index | |
items.append(data[data >= 0]) | |
ax.hist(np.concatenate(items), bins=n_bins, density=True, range=(0, 3)) | |
group = "" | |
if n_groups > 1: | |
group = f"$\\bf{{{filter_group['condition']}}}$\n\n" | |
if i == n_groups - 1: | |
ax.set_xlabel("Motion index") | |
if not j: | |
ax.set_ylabel(f"{group}Density") | |
if not i: | |
ax.set_title(f"$\\bf{{{t}}}$") | |
fig.suptitle(title) | |
save_or_show(save_fig_root, save_fig_prefix) | |
def plot_side_of_box( | |
self, split_horizontally: bool = True, | |
sides_label: tuple[str, ...] = ("Near-side", "Far-side"), save_fig_root: None | Path = None, save_fig_prefix: str = "", | |
): | |
fig, axs = plt.subplots(1, 3, sharey=True, sharex=False) | |
f = partial( | |
self.position_to_side, | |
split_horizontally=split_horizontally, sides_label=sides_label, | |
) | |
for (data, title), ax in zip(self._get_data_items(self), axs.flatten()): | |
data.plot_side_of_box(f, sides_label, fig, ax) | |
ax.set_title(title) | |
fig.suptitle(f"Subject {self.subject}") | |
save_or_show(save_fig_root, save_fig_prefix) | |
@classmethod | |
def plot_multi_experiment_side_of_box( | |
cls, experiments: list["Experiment"], split_horizontally: bool = True, | |
sides_label: tuple[str, ...] = ("Near-side", "Far-side"), title: str = "Subjects motion", | |
save_fig_root: None | Path = None, save_fig_prefix: str = "", | |
): | |
fig, axs = plt.subplots(1, 3, sharey=True, sharex=False) | |
for (data_name, t), ax in zip(cls._get_data_items(), axs.flatten()): | |
for experiment in experiments: | |
f = partial( | |
experiment.position_to_side, | |
split_horizontally=split_horizontally, sides_label=sides_label, | |
) | |
getattr(experiment, data_name).plot_side_of_box(f, sides_label, fig, ax) | |
ax.set_title(t) | |
fig.suptitle(title) | |
save_or_show(save_fig_root, save_fig_prefix) | |
def _descriptor_to_point(self, point: tuple[str, str]): | |
cx, cy = self.box_center | |
bw, bh = self.box_size | |
match point[0]: | |
case "left": | |
x = cx - bw / 2 | |
case "right": | |
x = cx + bw / 2 | |
case _: | |
raise ValueError(f"Can't recognize {point[0]}") | |
match point[1]: | |
case "bottom": | |
y = cy + bh / 2 | |
case "top": | |
y = cy - bh / 2 | |
case _: | |
raise ValueError(f"Can't recognize {point[1]}") | |
return x, y | |
def plot_distance_from_point( | |
self, teaball_corner: tuple[str, str], save_fig_root: None | Path = None, | |
save_fig_prefix: str = "", | |
): | |
fig, axs = plt.subplots(1, 3, sharey=True, sharex=False) | |
point_xy = self._descriptor_to_point(teaball_corner) | |
for i, ((data, title), ax) in enumerate(zip(self._get_data_items(self), axs.flatten())): | |
data.plot_distance_from_point( | |
point_xy, fig, ax, **{"y_label": ""} if i else {}, | |
) | |
ax.set_title(f"$\\bf{{{title}}}$") | |
fig.suptitle(f"#{self.subject} distance from teaball corner") | |
save_or_show(save_fig_root, save_fig_prefix) | |
@classmethod | |
def plot_multi_experiment_distance_from_point( | |
cls, experiments: list["Experiment"], teaball_corner: tuple[str, str], | |
title: str = "Subjects distance from teaball corner", | |
save_fig_root: None | Path = None, save_fig_prefix: str = "", | |
filter_args: list[dict] | None = None, | |
): | |
n_groups = 1 if filter_args is None else len(filter_args) | |
periods = cls._get_data_items() | |
n_periods = len(periods) | |
fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=False) | |
for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups( | |
experiments, periods, axs, filter_args): | |
for experiment in experiments_: | |
group = "" | |
if n_groups > 1: | |
group = f"$\\bf{{{filter_group['condition']}}}$\n\n" | |
kwargs = {"y_label": ""} | |
if i != n_groups - 1: | |
kwargs["x_label"] = "" | |
if not j: | |
kwargs["y_label"] = f"{group}Distance" | |
point_xy = experiment._descriptor_to_point(teaball_corner) | |
getattr(experiment, data_name).plot_distance_from_point(point_xy, fig, ax, **kwargs) | |
if not i: | |
ax.set_title(f"$\\bf{{{t}}}$") | |
fig.suptitle(title) | |
save_or_show(save_fig_root, save_fig_prefix) | |
@classmethod | |
def _get_side_percent(cls, sides_index: list[np.ndarray], sorted_sides: tuple[str, ...]): | |
counts = np.array([ | |
[np.sum(arr == i) for i in range(len(sorted_sides))] | |
for arr in sides_index | |
]) | |
percents = counts / np.sum(counts, axis=1, keepdims=True) * 100 | |
mean_prop = np.mean(percents, axis=0).squeeze() | |
return mean_prop, [prop.squeeze() for prop in percents] | |
@classmethod | |
def _plot_percents( | |
cls, sides_index: list[np.ndarray], sorted_sides: tuple[str, ...], fig: plt.Figure, ax: plt.Axes, | |
x_label: str = "Teaball side", y_label: str = "% time spent", | |
): | |
mean_prop, percents = cls._get_side_percent(sides_index, sorted_sides) | |
ax.bar(np.arange(len(sorted_sides)), mean_prop, tick_label=sorted_sides) | |
if len(sides_index) > 1: | |
for prop in percents: | |
ax.plot(np.arange(len(sorted_sides)), prop.squeeze(), "k.") | |
if x_label: | |
ax.set_xlabel(x_label) | |
if y_label: | |
ax.set_ylabel(y_label) | |
def plot_side_of_box_percent( | |
self, split_horizontally: bool = True, sides_label: tuple[str, ...] = ("Near-side", "Far-side"), | |
save_fig_root: None | Path = None, save_fig_prefix: str = "", | |
): | |
fig, axs = plt.subplots(1, 3, sharey=True, sharex=False) | |
f = partial( | |
self.position_to_side, | |
split_horizontally=split_horizontally, sides_label=sides_label, | |
) | |
for i, ((data, title), ax) in enumerate(zip(self._get_data_items(self), axs.flatten())): | |
_, sides_index = data.transform_to_side_of_box(f, sides_label) | |
self._plot_percents( | |
[sides_index], sides_label, fig, ax, **{"y_label": ""} if i else {}, | |
) | |
ax.set_title(f"$\\bf{{{title}}}$") | |
fig.suptitle(f"#{self.subject} % time spent in side of box") | |
save_or_show(save_fig_root, save_fig_prefix) | |
@classmethod | |
def plot_multi_experiment_side_of_box_percent( | |
cls, experiments: list["Experiment"], split_horizontally: bool = True, | |
sides_label: tuple[str, ...] = ("Near-side", "Far-side"), title: str = "Subjects motion", | |
save_fig_root: None | Path = None, save_fig_prefix: str = "", | |
filter_args: list[dict] | None = None, | |
): | |
n_groups = 1 if filter_args is None else len(filter_args) | |
periods = cls._get_data_items() | |
n_periods = len(periods) | |
fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=False) | |
for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups( | |
experiments, periods, axs, filter_args): | |
sides_index_all = [] | |
for experiment in experiments_: | |
f = partial( | |
experiment.position_to_side, | |
split_horizontally=split_horizontally, sides_label=sides_label, | |
) | |
_, sides_index = getattr(experiment, data_name).transform_to_side_of_box(f, sides_label) | |
sides_index_all.append(sides_index) | |
group = "" | |
if n_groups > 1: | |
group = f"$\\bf{{{filter_group['condition']}}}$\n\n" | |
kwargs = {"y_label": ""} | |
if i != n_groups - 1: | |
kwargs["x_label"] = "" | |
if not j: | |
kwargs["y_label"] = f"{group}% time spent" | |
cls._plot_percents(sides_index_all, sides_label, fig, ax, **kwargs) | |
if not i: | |
ax.set_title(f"$\\bf{{{t}}}$") | |
fig.suptitle(title) | |
save_or_show(save_fig_root, save_fig_prefix) | |
@classmethod | |
def plot_multi_experiment_merged_by_period_side_of_box_percent( | |
cls, experiments: list["Experiment"], filter_args: list[dict], split_horizontally: bool = True, | |
sides_label: tuple[str, ...] = ("Near-side", "Far-side"), title: str = "Subjects motion", | |
save_fig_root: None | Path = None, save_fig_prefix: str = "", | |
): | |
n_groups = len(filter_args) | |
periods = cls._get_data_items() | |
n_periods = len(periods) | |
fig, axs = plt.subplots(1, n_periods, sharey=True, sharex=True) | |
axes = list(axs.flatten()) | |
bar_width = 1 / (n_groups + 1) | |
n_sides = len(sides_label) | |
for i, filter_group in enumerate(filter_args): | |
experiments_ = cls.filter(experiments, **filter_group) | |
for j, (data_name, t) in enumerate(periods): | |
sides_index_all = [] | |
for experiment in experiments_: | |
f = partial( | |
experiment.position_to_side, | |
split_horizontally=split_horizontally, sides_label=sides_label, | |
) | |
_, sides_index = getattr(experiment, data_name).transform_to_side_of_box(f, sides_label) | |
sides_index_all.append(sides_index) | |
ax = axes[j] | |
mean_prop, percents = cls._get_side_percent(sides_index_all, sides_label) | |
ax.bar(i * bar_width + np.arange(n_sides), mean_prop, bar_width, label=filter_group["condition"]) | |
for prop in percents: | |
ax.plot(i * bar_width + np.arange(n_sides), prop, "k.") | |
ax.set_xlabel("Teaball side") | |
if not j: | |
ax.set_ylabel("% time spent") | |
if not i: | |
ax.set_title(f"$\\bf{{{t}}}$") | |
for ax in axes: | |
ax.set_xticks(np.arange(n_sides) + (1 - bar_width) / 2 - bar_width / 2, sides_label) | |
ax.set_xlim(-bar_width, n_sides + bar_width) | |
handles, labels = axes[-1].get_legend_handles_labels() | |
fig.legend(handles, labels, ncols=n_groups, bbox_to_anchor=(0, 0), loc=2) | |
fig.suptitle(title) | |
save_or_show(save_fig_root, save_fig_prefix) | |
@classmethod | |
def plot_multi_experiment_merged_by_group_side_of_box_percent( | |
cls, experiments: list["Experiment"], filter_args: list[dict], split_horizontally: bool = True, | |
sides_label: tuple[str, ...] = ("Teaball-side", "other-side"), title: str = "Subjects motion", | |
save_fig_root: None | Path = None, save_fig_prefix: str = "", show_titles: bool = True, | |
show_legend: bool = True, only_side: str | None = None, show_xlabel: bool = True, | |
): | |
n_groups = len(filter_args) | |
periods = cls._get_data_items() | |
n_periods = len(periods) | |
fig, axs = plt.subplots(n_groups, 1, sharey=True, sharex=True) | |
axes = list(axs.flatten()) | |
bar_width = 1 / (n_periods + 1) | |
n_sides = len(sides_label) | |
sides_s = slice(0, n_sides) | |
n_sides_used = n_sides | |
sides_used = sides_label | |
if only_side: | |
i = sides_label.index(only_side) | |
sides_s = slice(i, i + 1) | |
n_sides_used = 1 | |
sides_used = [only_side] | |
for i, filter_group in enumerate(filter_args): | |
experiments_ = cls.filter(experiments, **filter_group) | |
for j, (data_name, t) in enumerate(periods): | |
sides_index_all = [] | |
for experiment in experiments_: | |
f = partial( | |
experiment.position_to_side, | |
split_horizontally=split_horizontally, sides_label=sides_label, | |
) | |
_, sides_index = getattr(experiment, data_name).transform_to_side_of_box(f, sides_label) | |
sides_index_all.append(sides_index) | |
ax = axes[i] | |
mean_prop, percents = cls._get_side_percent(sides_index_all, sides_label) | |
ax.bar(j * bar_width + np.arange(n_sides_used), mean_prop[sides_s], bar_width, label=t) | |
for prop in percents: | |
ax.plot(j * bar_width + np.arange(n_sides_used), prop[sides_s], "k.") | |
if i == n_groups - 1 and show_xlabel: | |
ax.set_xlabel("Teaball side") | |
ax.set_ylabel("% time spent") | |
if not j and show_titles: | |
ax.set_title(f"$\\bf{{{filter_group['condition']}}}$") | |
for ax in axes: | |
ax.set_xticks(np.arange(n_sides_used) + (1 - bar_width) / 2 - bar_width / 2, sides_used) | |
ax.set_xlim(-bar_width, n_sides_used - bar_width) | |
if show_legend: | |
handles, labels = axes[-1].get_legend_handles_labels() | |
fig.legend(handles, labels, ncols=1, bbox_to_anchor=(0, 0), loc=2) | |
fig.suptitle(title) | |
save_or_show(save_fig_root, save_fig_prefix, width_inch=3) | |
@classmethod | |
def filter( | |
cls, experiments: list["Experiment"], date: int | None = None, subject: int | None = None, | |
condition: str | None = None, cell_label: str | None = None, sex: str | None = None | |
): | |
if date is not None: | |
experiments = [t for t in experiments if t.date == date] | |
if subject is not None: | |
experiments = [t for t in experiments if t.subject == subject] | |
if condition is not None: | |
experiments = [t for t in experiments if t.condition == condition] | |
if cell_label is not None: | |
experiments = [t for t in experiments if t.cell_label == cell_label] | |
if sex is not None: | |
experiments = [t for t in experiments if t.sex == sex] | |
return experiments | |
@dataclass | |
class SubjectPos: | |
times: np.ndarray | |
track: np.ndarray | |
motion_index: np.ndarray | |
@property | |
def min_x(self): | |
return np.min(self.track[:, 0]) | |
@property | |
def min_y(self): | |
return np.min(self.track[:, 1]) | |
@property | |
def max_x(self): | |
return np.max(self.track[:, 0]) | |
@property | |
def max_y(self): | |
return np.max(self.track[:, 1]) | |
@classmethod | |
def parse_csv_track(cls, filename: Path, subject_name: str = "mouse") -> "SubjectPos": | |
track = [] | |
times = [] | |
index = [] | |
with open(filename, "r") as fh: | |
reader = csv.reader(fh) | |
header = next(reader) | |
for row in reader: | |
frame, instance, cx, cy, index_val, t = row | |
if instance != subject_name: | |
continue | |
track.append((int(float(cx)), int(float(cy)))) | |
index.append(float(index_val)) | |
th, tm, ts = map(float, t.split(":")) | |
times.append(th * 60 ** 2 + tm * 60 + ts) | |
return SubjectPos(times=np.array(times), track=np.array(track), motion_index=np.array(index)) | |
def extract_range(self, t_start: float | None = None, t_end: float | None = None): | |
if t_start is None: | |
t_start = self.times[0] | |
if t_end is None: | |
t_end = self.times[-1] + 1 | |
i_s = np.sum(self.times < t_start) | |
i_e = np.sum(self.times <= t_end) | |
return SubjectPos(times=self.times[i_s:i_e], track=self.track[i_s:i_e], motion_index=self.motion_index[i_s:i_e]) | |
def calculate_occupancy( | |
self, occupancy: np.ndarray, pos_to_index: Callable | None = None, frame_normalize: bool = True, | |
) -> None: | |
n = self.track.shape[0] | |
grid_width, grid_height = occupancy.shape | |
frame_proportion = 1 | |
if frame_normalize: | |
frame_proportion = 1 / n | |
for i in range(n): | |
if self.track[i, 0] < 0 or self.track[i, 1] < 0: | |
continue | |
if pos_to_index is None: | |
x, y = self.track[i, :] | |
area_prop = 1 | |
else: | |
x, y, area_prop = pos_to_index(*self.track[i, :], grid_width, grid_height) | |
x = int(min(x, grid_width)) | |
y = int(min(y, grid_height)) | |
occupancy[x, y] += frame_proportion * area_prop | |
@classmethod | |
def plot_occupancy_data( | |
cls, occupancy: np.ndarray, fig: plt.Figure, ax: plt.Axes, | |
gaussian_sigma: float = 0, intensity_limit: float = 0, scale_to_one: bool = True, | |
x_label: str = "Box X", y_label: str = "Box Y", title: str = "Occupancy", | |
color_bar: bool = True, | |
): | |
if gaussian_sigma: | |
occupancy = gaussian_filter(occupancy, gaussian_sigma) | |
if scale_to_one: | |
occupancy /= occupancy.max() | |
im = ax.imshow( | |
occupancy.T, aspect="auto", origin="upper", cmap="viridis", interpolation="sinc", | |
interpolation_stage="data", vmax=intensity_limit or None | |
) | |
if x_label: | |
ax.set_xlabel(x_label) | |
if y_label: | |
ax.set_ylabel(y_label) | |
if title: | |
ax.set_title(title) | |
if color_bar: | |
divider = make_axes_locatable(ax) | |
cax = divider.append_axes('right', size='5%', pad=0.05) | |
fig.colorbar(im, cax=cax, orientation='vertical') | |
def plot_occupancy( | |
self, grid_width: int, grid_height: int, | |
pos_to_index: Callable | None = None, fig: plt.Figure | None = None, ax: plt.Axes | None = None, | |
gaussian_sigma: float = 0, intensity_limit: int = 0, frame_normalize: bool = True, | |
scale_to_one: bool = True, save_fig_root: None | Path = None, save_fig_prefix: str = "", | |
x_label: str = "Box X", y_label: str = "Box Y", title: str = "Occupancy", | |
color_bar: bool = True, | |
): | |
occupancy = np.zeros((grid_width, grid_height)) | |
show_plot = ax is None | |
if ax is None: | |
assert fig is None | |
fig, ax = plt.subplots() | |
else: | |
assert fig is not None | |
self.calculate_occupancy(occupancy, pos_to_index, frame_normalize) | |
self.plot_occupancy_data( | |
occupancy, fig, ax, gaussian_sigma, intensity_limit, scale_to_one, x_label, y_label, title, color_bar, | |
) | |
if show_plot: | |
save_or_show(save_fig_root, save_fig_prefix) | |
def plot_motion( | |
self, fig: plt.Figure | None = None, ax: plt.Axes | None = None, | |
save_fig_root: None | Path = None, save_fig_prefix: str = "", | |
x_label: str = "Time (min)", y_label: str = "Motion index", | |
): | |
show_plot = ax is None | |
if ax is None: | |
assert fig is None | |
fig, ax = plt.subplots() | |
else: | |
assert fig is not None | |
valid = self.motion_index >= 0 | |
times = self.times[valid] | |
ax.plot(times / 60 - times[0] / 60, self.motion_index[valid], ",") | |
if x_label: | |
ax.set_xlabel(x_label) | |
if y_label: | |
ax.set_ylabel(y_label) | |
if show_plot: | |
save_or_show(save_fig_root, save_fig_prefix) | |
def transform_to_side_of_box( | |
self, position_to_side: Callable, sides_label: tuple[str, ...], | |
) -> tuple[np.ndarray, np.ndarray]: | |
valid = np.logical_and(self.track[:, 0] >= 0, self.track[:, 1] >= 0) | |
pos_sides_name = [position_to_side(*p) for p in self.track[valid, :]] | |
sides = {name: i for i, name in enumerate(sides_label)} | |
sides_index = np.array([sides[n] for n in pos_sides_name]) | |
times = self.times[valid] | |
return times, sides_index | |
def plot_side_of_box( | |
self, position_to_side: Callable, sides_label: tuple[str, ...], | |
fig: plt.Figure | None = None, ax: plt.Axes | None = None, | |
save_fig_root: None | Path = None, save_fig_prefix: str = "", | |
): | |
show_plot = ax is None | |
if ax is None: | |
assert fig is None | |
fig, ax = plt.subplots() | |
else: | |
assert fig is not None | |
times, sides_index = self.transform_to_side_of_box(position_to_side, sides_label) | |
ax.plot(times / 60 - times[0] / 60, sides_index, "*", alpha=.3, ms=5) | |
ax.set_yticks(np.arange(len(sides_label)), sides_label) | |
ax.set_xlabel("Time (min)") | |
ax.set_ylabel("Side of box relative to teaball") | |
if show_plot: | |
save_or_show(save_fig_root, save_fig_prefix) | |
def plot_distance_from_point(
        self, point_xy: tuple[int, int], fig: plt.Figure | None = None, ax: plt.Axes | None = None,
        save_fig_root: None | Path = None, save_fig_prefix: str = "",
        x_label: str = "Time (min)", y_label: str = "Distance",
):
    """Plot the Euclidean distance from each valid tracked position to ``point_xy``.

    A position is valid when both its first and second coordinates are ``>= 0``.
    The x axis is ``times / 60`` shifted so the first valid sample is at 0.
    Distances are in the same units as ``self.track``.

    :param point_xy: Reference point, in the same coordinate space as ``self.track``.
    :param fig: Figure owning ``ax``. It must be given together with ``ax``, or
        both must be omitted.
    :param ax: Axes to draw into. When omitted, a new figure is created and then
        saved or shown via ``save_or_show``.
    :param save_fig_root: Directory to save into (only when ``ax`` is None).
        None shows it instead.
    :param save_fig_prefix: File name (without ``.png``) used when saving.
    :param x_label: X-axis label; falsy to skip.
    :param y_label: Y-axis label; falsy to skip.
    """
    show_plot = ax is None
    if ax is None:
        assert fig is None
        fig, ax = plt.subplots()
    else:
        assert fig is not None

    valid = np.logical_and(self.track[:, 0] >= 0, self.track[:, 1] >= 0)
    times = self.times[valid]
    # norm() casts integer tracks to float before squaring, so large pixel
    # coordinates cannot overflow a small integer dtype.
    distance = np.linalg.norm(self.track[valid, :] - np.asarray(point_xy), axis=1)
    # No valid samples means no first time stamp to shift by; leave axes empty.
    if len(times):
        ax.plot(times / 60 - times[0] / 60, distance, ",")

    if x_label:
        ax.set_xlabel(x_label)
    if y_label:
        ax.set_ylabel(y_label)

    if show_plot:
        save_or_show(save_fig_root, save_fig_prefix)
if __name__ == "__main__":
    import sys

    # Every input and output lives under one data root. Pass a different directory as the
    # first command-line argument to run elsewhere; with no argument the original location is used.
    data_root = Path(sys.argv[1]) if len(sys.argv) > 1 else Path(r"C:\Users\Matthew Einhorn\Downloads\tmt24")
    csv_experiment_times = data_root / "Batch3 data analysis - timestamps.csv"
    tracking_data_root = data_root / "tracking"
    json_data_root = data_root / "json"
    figure_root = data_root / "figures"
    # Conditions in plotting order; shared by the per-condition loop and the single-figure plots.
    conditions = "Blank", "TMT", "2MBA", "IAMM"

    experiments = Experiment.parse_experiment_spec_csv(csv_experiment_times)
    for experiment in experiments:
        experiment.parse_box_metadata(json_data_root)
        # Durations are written as multiples of 60, presumably seconds (5 min habituation, 2 min trial/post).
        experiment.read_pos_track(tracking_data_root, hab_duration=5 * 60, trial_duration=2 * 60, post_duration=2 * 60)
    # Grid size computed once over all experiments and passed to every occupancy plot below.
    grid_width, grid_height = Experiment.enlarge_canvas(experiments)

    # --- Per-experiment figures -----------------------------------------------------------
    for experiment in experiments:
        # "<subject>_<date>" names both the output folder and the file-name suffix.
        tag = f"{experiment.subject}_{experiment.date}"
        experiment_fig_root = figure_root / tag
        experiment.plot_occupancy(
            grid_width, grid_height, scale_to_one=False, intensity_limit=1e-5,
            save_fig_root=experiment_fig_root,
            save_fig_prefix=f"intensity_limit100K_{tag}",
        )
        # Alternative per-experiment plots; uncomment as needed.
        # experiment.plot_occupancy(
        #     grid_width, grid_height, intensity_limit=1e-6,
        #     save_fig_root=experiment_fig_root,
        #     save_fig_prefix=f"intensity_limit1M_{tag}",
        # )
        # experiment.plot_occupancy(
        #     grid_width, grid_height, gaussian_sigma=5,
        #     save_fig_root=experiment_fig_root,
        #     save_fig_prefix=f"occupancy_sigma5_{tag}",
        # )
        # experiment.plot_occupancy(
        #     grid_width, grid_height, gaussian_sigma=1,
        #     save_fig_root=experiment_fig_root,
        #     save_fig_prefix=f"occupancy_sigma1_{tag}",
        # )
        # experiment.plot_motion_histogram(
        #     n_bins=10,
        #     save_fig_root=experiment_fig_root,
        #     save_fig_prefix=f"motion_histogram_{tag}",
        # )
        # experiment.plot_motion(
        #     y_limit=3,
        #     save_fig_root=experiment_fig_root,
        #     save_fig_prefix=f"motion_index_{tag}",
        # )
        # experiment.plot_distance_from_point(
        #     ("left", "bottom"),
        #     save_fig_root=experiment_fig_root,
        #     save_fig_prefix=f"distance_{tag}",
        # )
        # experiment.plot_side_of_box_percent(
        #     save_fig_root=experiment_fig_root,
        #     save_fig_prefix=f"side_percent_{tag}",
        # )

    # --- One figure per condition --------------------------------------------------------
    grouped_root = figure_root / "grouped"
    for condition in conditions:
        experiments_ = Experiment.filter(experiments, condition=condition)
        Experiment.plot_multi_experiment_occupancy(
            experiments_, grid_width, grid_height, title=f"{condition} occupancy density", scale_to_one=False,
            intensity_limit=1e-5,
            save_fig_root=grouped_root / "occupancy",
            save_fig_prefix=f"occupancy_{condition}_intensity_limit100K",
        )
        # Alternative per-condition plots; uncomment as needed.
        # Experiment.plot_multi_experiment_occupancy(
        #     experiments_, grid_width, grid_height, title=f"{condition} occupancy density", intensity_limit=1e-6,
        #     save_fig_root=grouped_root / "occupancy",
        #     save_fig_prefix=f"occupancy_{condition}_intensity_limit1M",
        # )
        # Experiment.plot_multi_experiment_occupancy(
        #     experiments_, grid_width, grid_height, title=f"{condition} occupancy", gaussian_sigma=5,
        #     save_fig_root=grouped_root / "occupancy", save_fig_prefix=f"occupancy_{condition}_sigma5",
        # )
        # Experiment.plot_multi_experiment_occupancy(
        #     experiments_, grid_width, grid_height, title=f"{condition} occupancy", gaussian_sigma=1,
        #     save_fig_root=grouped_root / "occupancy", save_fig_prefix=f"occupancy_{condition}_sigma1",
        # )
        # Experiment.plot_multi_experiment_motion_histogram(
        #     experiments_, title=f"{condition} motion index density", n_bins=10,
        #     save_fig_root=grouped_root / "motion_histogram", save_fig_prefix=f"motion_histogram_{condition}",
        # )
        # Experiment.plot_multi_experiment_motion(
        #     experiments_, y_limit=3, title=f"{condition} motion index",
        #     save_fig_root=grouped_root / "motion_index", save_fig_prefix=f"motion_index_{condition}",
        # )
        # Experiment.plot_multi_experiment_distance_from_point(
        #     experiments_, ("left", "bottom"), title=f"{condition} distance from teaball corner",
        #     save_fig_root=grouped_root / "distance", save_fig_prefix=f"distance_{condition}",
        # )
        # Experiment.plot_multi_experiment_side_of_box_percent(
        #     experiments_, title=f"{condition} side of box percent",
        #     save_fig_root=grouped_root / "side_percent", save_fig_prefix=f"side_percent_{condition}",
        # )

    # --- All conditions in a single figure (one filter per condition) -----------------------
    single_fig_root = figure_root / "grouped_single_figure"
    Experiment.plot_multi_experiment_occupancy(
        experiments, grid_width, grid_height, title="Occupancy density", scale_to_one=False, intensity_limit=1e-5,
        filter_args=[{"condition": c} for c in conditions],
        save_fig_root=single_fig_root,
        save_fig_prefix="occupancy_intensity_limit100K",
    )
    # Alternative single-figure plots; uncomment as needed.
    # Experiment.plot_multi_experiment_occupancy(
    #     experiments, grid_width, grid_height, title="Occupancy density", intensity_limit=1e-6,
    #     filter_args=[{"condition": c} for c in conditions],
    #     save_fig_root=single_fig_root,
    #     save_fig_prefix="occupancy_intensity_limit1M",
    # )
    # Experiment.plot_multi_experiment_occupancy(
    #     experiments, grid_width, grid_height, title="Occupancy density", gaussian_sigma=5,
    #     filter_args=[{"condition": c} for c in conditions],
    #     save_fig_root=single_fig_root,
    #     save_fig_prefix="occupancy_sigma5",
    # )
    # Experiment.plot_multi_experiment_occupancy(
    #     experiments, grid_width, grid_height, title="Occupancy density", gaussian_sigma=1,
    #     filter_args=[{"condition": c} for c in conditions],
    #     save_fig_root=single_fig_root,
    #     save_fig_prefix="occupancy_sigma1",
    # )
    # Experiment.plot_multi_experiment_motion_histogram(
    #     experiments, title="Motion index density", n_bins=10,
    #     filter_args=[{"condition": c} for c in conditions],
    #     save_fig_root=single_fig_root,
    #     save_fig_prefix="motion_histogram",
    # )
    # Experiment.plot_multi_experiment_motion(
    #     experiments, y_limit=3, title="Motion index",
    #     filter_args=[{"condition": c} for c in conditions],
    #     save_fig_root=single_fig_root,
    #     save_fig_prefix="motion_index",
    # )
    # Experiment.plot_multi_experiment_distance_from_point(
    #     experiments, ("left", "bottom"), title="Distance from teaball corner",
    #     filter_args=[{"condition": c} for c in conditions],
    #     save_fig_root=single_fig_root,
    #     save_fig_prefix="distance",
    # )
    # Experiment.plot_multi_experiment_side_of_box_percent(
    #     experiments, title="Side of box percent",
    #     filter_args=[{"condition": c} for c in conditions],
    #     save_fig_root=single_fig_root,
    #     save_fig_prefix="side_percent",
    # )
    # Experiment.plot_multi_experiment_merged_by_period_side_of_box_percent(
    #     experiments, title="Side of box percent",
    #     filter_args=[{"condition": c} for c in conditions],
    #     save_fig_root=single_fig_root,
    #     save_fig_prefix="side_percent_merged_by_period",
    # )
    # Experiment.plot_multi_experiment_merged_by_group_side_of_box_percent(
    #     experiments, title="% spent on teaball side", show_titles=False, show_legend=False,
    #     filter_args=[{"condition": c} for c in conditions], only_side="Teaball-side", show_xlabel=False,
    #     save_fig_root=single_fig_root,
    #     save_fig_prefix="side_percent_merged_by_group",
    # )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment