Skip to content

Instantly share code, notes, and snippets.

@matham
Created February 3, 2025 20:46
Show Gist options
  • Save matham/2a499bbba251117287857da0aa6c5aeb to your computer and use it in GitHub Desktop.
Save matham/2a499bbba251117287857da0aa6c5aeb to your computer and use it in GitHub Desktop.
Export results for teaball experiments: sniffing, occupancy, etc.
from dataclasses import dataclass, field
from typing import Callable, Union
from pathlib import Path
import csv
import math
import numpy as np
from functools import partial
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.ndimage import gaussian_filter
import json
def save_or_show(save_fig_root: None | Path = None, save_fig_prefix: str = "", width_inch: int = 8, height_inch: int = 6):
    """Save the current matplotlib figure to disk, or display it interactively.

    When ``save_fig_root`` is truthy, the directory is created if needed, the
    current figure is resized to ``width_inch`` x ``height_inch`` inches,
    written to ``<save_fig_root>/<save_fig_prefix>.png`` at 300 dpi and then
    closed. Otherwise the layout is tightened and the figure is shown.
    """
    if not save_fig_root:
        # Interactive mode: nothing is written to disk.
        plt.tight_layout()
        plt.show()
        return

    save_fig_root.mkdir(parents=True, exist_ok=True)
    current = plt.gcf()
    current.set_size_inches(width_inch, height_inch)
    current.tight_layout()
    target = save_fig_root / f"{save_fig_prefix}.png"
    current.savefig(target, bbox_inches='tight', dpi=300)
    # Close so repeated calls in a batch loop do not accumulate open figures.
    plt.close()
@dataclass
class Experiment:
date: int
subject: int
condition: str
cell_label: str
sex: str
pre_start: float
pre_end: float
trial_start: float
trial_end: float
post_start: float
post_end: float
pos_data: "SubjectPos" = field(default=None, init=False, repr=False)
pre_pos_data: "SubjectPos" = field(default=None, init=False, repr=False)
trial_pos_data: "SubjectPos" = field(default=None, init=False, repr=False)
post_pos_data: "SubjectPos" = field(default=None, init=False, repr=False)
box_center: tuple[int, int] = field(default=None, init=False)
box_size: tuple[int, int] = field(default=None, init=False)
image_size: tuple[int, int] = field(default=None, init=False)
image_offset: tuple[int, int] = field(default=None, init=False)
enlarged_image_size: tuple[int, int] = field(default=None, init=False, repr=False)
largest_box_size: tuple[int, int] = field(default=None, init=False, repr=False)
def pos_path_filename(self, data_root: Path) -> Path:
return data_root / f"{self.date}_{self.subject}_top_tracked.csv"
def box_metadata_filename(self, data_root: Path) -> Path:
return data_root / f"{self.date}_{self.subject}_top_000000000.json"
@classmethod
def parse_experiment_spec_csv(cls, filename: Path) -> list["Experiment"]:
experiments = []
with open(filename, "r") as fh:
reader = csv.reader(fh)
header = next(reader)
for row in reader:
date = int(row[0])
subject = int(row[1])
condition = row[2]
label = row[3]
sex = row[4]
times = []
for t in row[5:]:
min, sec = map(int, t.split(":"))
times.append(min * 60 + sec)
experiment = cls(
date=date, subject=subject, condition=condition, cell_label=label, sex=sex,
pre_start=times[0], pre_end=times[1],
trial_start=times[2], trial_end=times[3],
post_start=times[4], post_end=times[5],
)
experiments.append(experiment)
return experiments
def parse_box_metadata(self, data_root: Path, box_names: tuple[str] = ("farside", "nearside")):
filename = self.box_metadata_filename(data_root)
with open(filename, "r") as fh:
data = json.load(fh)
shape = None
for s in data["shapes"]:
for name in box_names:
if s["label"] == name:
shape = s
if shape is None:
raise ValueError(f"Cannot find {box_names} in the json file. {filename}")
points = np.array(shape["points"])
min_x, min_y = np.min(points, axis=0)
max_x, max_y = np.max(points, axis=0)
w = max_x - min_x
h = max_y - min_y
self.box_center = int(min_x + w / 2), int(min_y + h / 2)
self.box_size = int(w), int(h)
self.image_size = int(data["imageWidth"]), int(data["imageHeight"])
@classmethod
def enlarge_canvas(cls, experiments: list["Experiment"]) -> tuple[int, int]:
box_centers = np.array([e.box_center for e in experiments])
box_sizes = np.array([e.box_size for e in experiments])
image_sizes = np.array([e.image_size for e in experiments])
max_image = np.max(image_sizes, axis=0)
image_centers = np.floor(image_sizes / 2)
adjusted_image_centers = np.floor(image_centers + (max_image[None, :] - image_sizes) / 2)
adjusted_image_offsets = adjusted_image_centers - image_centers
adjusted_box_centers = box_centers + adjusted_image_offsets
mean_adjusted_box_centers = np.floor(np.mean(adjusted_box_centers, axis=0))
aligned_box_offsets = mean_adjusted_box_centers - adjusted_box_centers
total_image_offset = adjusted_image_offsets + aligned_box_offsets
min_offset = np.min(total_image_offset, axis=0)
max_offset = np.max(total_image_offset, axis=0)
final_image_size = max_image + max_offset - min_offset
max_box_size = np.max(box_sizes, axis=0)
for i, experiment in enumerate(experiments):
experiment.image_offset = tuple(map(int, total_image_offset[i, :]))
experiment.enlarged_image_size = tuple(map(int, final_image_size))
experiment.largest_box_size = tuple(map(int, max_box_size))
return tuple(map(int, final_image_size))
def position_to_side(
self, x: int, y: int, split_horizontally: bool = True,
sides_label: tuple[str, ...] = ("Near-side", "Far-side"),
):
cx, cy = self.box_center
if split_horizontally:
i = 0 if x < cx else 1
else:
i = 0 if y < cy else 1
return sides_label[i]
def position_to_grid_index(self, x, y, grid_width: int, grid_height: int) -> tuple[float, float, float]:
if grid_width == self.image_size[0] and grid_height == self.image_size[1]:
return x + self.image_offset[0], y + self.image_offset[1], 1
cx, cy = self.box_center
bw, bh = self.box_size
iw, ih = self.image_size
prop_x = 1
prop_y = 1
grid_item_width = bw / grid_width
grid_item_height = bh / grid_height
left_x_prop = 1 + (cx - bw / 2) / grid_item_width
right_x_prop = 1 + (iw - (cx + bw / 2)) / grid_item_width
bottom_y_prop = 1 + (cy - bh / 2) / grid_item_height
top_y_prop = 1 + (ih - (cy + bh / 2)) / grid_item_height
if x < cx - bw / 2:
x = 0
prop_x = left_x_prop
elif x >= cx + bw / 2:
x = grid_width - 1
prop_x = right_x_prop
else:
left = cx - bw / 2
box_prop = (x - left) / bw
x = min(math.floor(box_prop * grid_width), grid_width - 1)
if x == 0:
prop_x = left_x_prop
elif x == grid_width - 1:
prop_x = right_x_prop
if y <= cy - bh / 2:
y = 0
prop_y = bottom_y_prop
elif y >= cy + bh / 2:
y = grid_height - 1
prop_y = top_y_prop
else:
bottom = cy - bh / 2
box_prop = (y - bottom) / bh
y = min(math.floor(box_prop * grid_height), grid_height - 1)
if y == 0:
prop_y = bottom_y_prop
elif y == grid_height - 1:
prop_y = top_y_prop
return x, y, 1 / (prop_x * prop_y)
def read_pos_track(
self, data_root: Path, hab_duration: float, trial_duration: float, post_duration: float
) -> None:
self.pos_data = SubjectPos.parse_csv_track(self.pos_path_filename(data_root))
self.pre_pos_data = self.pos_data.extract_range(
self.pre_start, min(self.pre_end, self.pre_start + hab_duration)
)
self.trial_pos_data = self.pos_data.extract_range(
self.trial_start, min(self.trial_end, self.trial_start + trial_duration)
)
self.post_pos_data = self.pos_data.extract_range(
self.post_start, min(self.post_end, self.post_start + post_duration)
)
@classmethod
def _get_data_items(cls, obj: Union["Experiment", None] = None) -> list[tuple[Union["SubjectPos", str], str]]:
if obj is None:
res = [
("pre_pos_data", "Habituation"),
("trial_pos_data", "Trial"),
("post_pos_data", "Post Trial"),
]
else:
res = [
(obj.pre_pos_data, "Habituation"),
(obj.trial_pos_data, "Trial"),
(obj.post_pos_data, "Post Trial"),
]
return res
@classmethod
def _iter_periods_and_groups(cls, experiments, periods, axs, filter_args: list[dict] | None):
n_groups = 1 if filter_args is None else len(filter_args)
it = iter(axs.flatten())
for i, filter_group in enumerate(filter_args or [None, ]):
if n_groups > 1:
experiments_ = cls.filter(experiments, **filter_group)
else:
experiments_ = experiments
for j, (data_name, t) in enumerate(periods):
ax = next(it)
yield experiments_, i, filter_group, j, data_name, t, ax
def plot_occupancy(
self, grid_width: int, grid_height: int,
gaussian_sigma: float = 0, intensity_limit: float = 0, frame_normalize: bool = True,
scale_to_one: bool = True, save_fig_root: None | Path = None, save_fig_prefix: str = "",
):
fig, axs = plt.subplots(1, 3, sharey=True, sharex=True)
for i, ((data, title), ax) in enumerate(zip(self._get_data_items(self), axs.flatten())):
data.plot_occupancy(
grid_width, grid_height, pos_to_index=self.position_to_grid_index, fig=fig, ax=ax,
gaussian_sigma=gaussian_sigma,
intensity_limit=intensity_limit, frame_normalize=frame_normalize, scale_to_one=scale_to_one,
color_bar=not i,
x_label="X (pixels)",
y_label="Y (pixels)" if not i else "",
title="",
)
ax.set_title(f"$\\bf{{{title}}}$")
fig.suptitle(f"#{self.subject} occupancy density")
save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_occupancy(
cls, experiments: list["Experiment"], grid_width: int, grid_height: int,
gaussian_sigma: float = 0, intensity_limit: float = 0, title: str = "Subjects occupancy density",
frame_normalize: bool = True, experiment_normalize: bool = True, scale_to_one: bool = True,
save_fig_root: None | Path = None, save_fig_prefix: str = "",
filter_args: list[dict] | None = None,
):
n_groups = 1 if filter_args is None else len(filter_args)
periods = cls._get_data_items()
n_periods = len(periods)
fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=True)
for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups(
experiments, periods, axs, filter_args):
occupancy = np.zeros((grid_width, grid_height))
for experiment in experiments_:
getattr(experiment, data_name).calculate_occupancy(
occupancy, experiment.position_to_grid_index, frame_normalize,
)
if experiment_normalize:
occupancy /= len(experiments_)
group = ""
if n_groups > 1:
group = f"$\\bf{{{filter_group['condition']}}}$\n\n"
SubjectPos.plot_occupancy_data(
occupancy, fig, ax, gaussian_sigma, intensity_limit, scale_to_one,
color_bar=not i and not j,
title="",
x_label="X (pixels)" if i == n_groups - 1 else "",
y_label=f"{group}Y (pixels)" if not j else "",
)
if not i:
ax.set_title(f"$\\bf{{{t}}}$")
fig.suptitle(title)
save_or_show(save_fig_root, save_fig_prefix)
def plot_motion(self, y_limit: float | None = None, save_fig_root: None | Path = None, save_fig_prefix: str = ""):
fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
for i, ((data, title), ax) in enumerate(zip(self._get_data_items(self), axs.flatten())):
data.plot_motion(
fig, ax, **{"y_label": ""} if i else {},
)
if y_limit is not None:
ax.set_ylim(0, y_limit)
ax.set_title(f"$\\bf{{{title}}}$")
fig.suptitle(f"#{self.subject} motion index")
save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_motion(
cls, experiments: list["Experiment"], y_limit: float | None = None, title: str = "Subjects motion index",
save_fig_root: None | Path = None, save_fig_prefix: str = "",
filter_args: list[dict] | None = None,
):
n_groups = 1 if filter_args is None else len(filter_args)
periods = cls._get_data_items()
n_periods = len(periods)
fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=False)
for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups(
experiments, periods, axs, filter_args):
for experiment in experiments_:
group = ""
if n_groups > 1:
group = f"$\\bf{{{filter_group['condition']}}}$\n\n"
kwargs = {"y_label": ""}
if i != n_groups - 1:
kwargs["x_label"] = ""
if not j:
kwargs["y_label"] = f"{group}Motion index"
getattr(experiment, data_name).plot_motion(fig, ax, **kwargs)
if y_limit is not None:
ax.set_ylim(0, y_limit)
if not i:
ax.set_title(f"$\\bf{{{t}}}$")
fig.suptitle(title)
save_or_show(save_fig_root, save_fig_prefix)
def plot_motion_histogram(self, n_bins=100, save_fig_root: None | Path = None, save_fig_prefix: str = ""):
fig, axs = plt.subplots(1, 3, sharey=True, sharex=True)
for i, ((data, title), ax) in enumerate(zip(self._get_data_items(self), axs.flatten())):
data = data.motion_index
data = data[data >= 0]
ax.hist(data, bins=n_bins, density=True, range=(0, 3))
ax.set_xlabel("Motion index")
if not i:
ax.set_ylabel("Density")
ax.set_title(f"$\\bf{{{title}}}$")
fig.suptitle(f"#{self.subject} motion index density")
save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_motion_histogram(
cls, experiments: list["Experiment"], n_bins=100, title: str = "Subjects motion index",
save_fig_root: None | Path = None, save_fig_prefix: str = "",
filter_args: list[dict] | None = None,
):
n_groups = 1 if filter_args is None else len(filter_args)
periods = cls._get_data_items()
n_periods = len(periods)
fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=True)
for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups(
experiments, periods, axs, filter_args):
items = []
for experiment in experiments_:
data = getattr(experiment, data_name).motion_index
items.append(data[data >= 0])
ax.hist(np.concatenate(items), bins=n_bins, density=True, range=(0, 3))
group = ""
if n_groups > 1:
group = f"$\\bf{{{filter_group['condition']}}}$\n\n"
if i == n_groups - 1:
ax.set_xlabel("Motion index")
if not j:
ax.set_ylabel(f"{group}Density")
if not i:
ax.set_title(f"$\\bf{{{t}}}$")
fig.suptitle(title)
save_or_show(save_fig_root, save_fig_prefix)
def plot_side_of_box(
self, split_horizontally: bool = True,
sides_label: tuple[str, ...] = ("Near-side", "Far-side"), save_fig_root: None | Path = None, save_fig_prefix: str = "",
):
fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
f = partial(
self.position_to_side,
split_horizontally=split_horizontally, sides_label=sides_label,
)
for (data, title), ax in zip(self._get_data_items(self), axs.flatten()):
data.plot_side_of_box(f, sides_label, fig, ax)
ax.set_title(title)
fig.suptitle(f"Subject {self.subject}")
save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_side_of_box(
cls, experiments: list["Experiment"], split_horizontally: bool = True,
sides_label: tuple[str, ...] = ("Near-side", "Far-side"), title: str = "Subjects motion",
save_fig_root: None | Path = None, save_fig_prefix: str = "",
):
fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
for (data_name, t), ax in zip(cls._get_data_items(), axs.flatten()):
for experiment in experiments:
f = partial(
experiment.position_to_side,
split_horizontally=split_horizontally, sides_label=sides_label,
)
getattr(experiment, data_name).plot_side_of_box(f, sides_label, fig, ax)
ax.set_title(t)
fig.suptitle(title)
save_or_show(save_fig_root, save_fig_prefix)
def _descriptor_to_point(self, point: tuple[str, str]):
cx, cy = self.box_center
bw, bh = self.box_size
match point[0]:
case "left":
x = cx - bw / 2
case "right":
x = cx + bw / 2
case _:
raise ValueError(f"Can't recognize {point[0]}")
match point[1]:
case "bottom":
y = cy + bh / 2
case "top":
y = cy - bh / 2
case _:
raise ValueError(f"Can't recognize {point[1]}")
return x, y
def plot_distance_from_point(
self, teaball_corner: tuple[str, str], save_fig_root: None | Path = None,
save_fig_prefix: str = "",
):
fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
point_xy = self._descriptor_to_point(teaball_corner)
for i, ((data, title), ax) in enumerate(zip(self._get_data_items(self), axs.flatten())):
data.plot_distance_from_point(
point_xy, fig, ax, **{"y_label": ""} if i else {},
)
ax.set_title(f"$\\bf{{{title}}}$")
fig.suptitle(f"#{self.subject} distance from teaball corner")
save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_distance_from_point(
cls, experiments: list["Experiment"], teaball_corner: tuple[str, str],
title: str = "Subjects distance from teaball corner",
save_fig_root: None | Path = None, save_fig_prefix: str = "",
filter_args: list[dict] | None = None,
):
n_groups = 1 if filter_args is None else len(filter_args)
periods = cls._get_data_items()
n_periods = len(periods)
fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=False)
for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups(
experiments, periods, axs, filter_args):
for experiment in experiments_:
group = ""
if n_groups > 1:
group = f"$\\bf{{{filter_group['condition']}}}$\n\n"
kwargs = {"y_label": ""}
if i != n_groups - 1:
kwargs["x_label"] = ""
if not j:
kwargs["y_label"] = f"{group}Distance"
point_xy = experiment._descriptor_to_point(teaball_corner)
getattr(experiment, data_name).plot_distance_from_point(point_xy, fig, ax, **kwargs)
if not i:
ax.set_title(f"$\\bf{{{t}}}$")
fig.suptitle(title)
save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def _get_side_percent(cls, sides_index: list[np.ndarray], sorted_sides: tuple[str, ...]):
counts = np.array([
[np.sum(arr == i) for i in range(len(sorted_sides))]
for arr in sides_index
])
percents = counts / np.sum(counts, axis=1, keepdims=True) * 100
mean_prop = np.mean(percents, axis=0).squeeze()
return mean_prop, [prop.squeeze() for prop in percents]
@classmethod
def _plot_percents(
cls, sides_index: list[np.ndarray], sorted_sides: tuple[str, ...], fig: plt.Figure, ax: plt.Axes,
x_label: str = "Teaball side", y_label: str = "% time spent",
):
mean_prop, percents = cls._get_side_percent(sides_index, sorted_sides)
ax.bar(np.arange(len(sorted_sides)), mean_prop, tick_label=sorted_sides)
if len(sides_index) > 1:
for prop in percents:
ax.plot(np.arange(len(sorted_sides)), prop.squeeze(), "k.")
if x_label:
ax.set_xlabel(x_label)
if y_label:
ax.set_ylabel(y_label)
def plot_side_of_box_percent(
self, split_horizontally: bool = True, sides_label: tuple[str, ...] = ("Near-side", "Far-side"),
save_fig_root: None | Path = None, save_fig_prefix: str = "",
):
fig, axs = plt.subplots(1, 3, sharey=True, sharex=False)
f = partial(
self.position_to_side,
split_horizontally=split_horizontally, sides_label=sides_label,
)
for i, ((data, title), ax) in enumerate(zip(self._get_data_items(self), axs.flatten())):
_, sides_index = data.transform_to_side_of_box(f, sides_label)
self._plot_percents(
[sides_index], sides_label, fig, ax, **{"y_label": ""} if i else {},
)
ax.set_title(f"$\\bf{{{title}}}$")
fig.suptitle(f"#{self.subject} % time spent in side of box")
save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_side_of_box_percent(
cls, experiments: list["Experiment"], split_horizontally: bool = True,
sides_label: tuple[str, ...] = ("Near-side", "Far-side"), title: str = "Subjects motion",
save_fig_root: None | Path = None, save_fig_prefix: str = "",
filter_args: list[dict] | None = None,
):
n_groups = 1 if filter_args is None else len(filter_args)
periods = cls._get_data_items()
n_periods = len(periods)
fig, axs = plt.subplots(n_groups, n_periods, sharey=True, sharex=False)
for experiments_, i, filter_group, j, data_name, t, ax in cls._iter_periods_and_groups(
experiments, periods, axs, filter_args):
sides_index_all = []
for experiment in experiments_:
f = partial(
experiment.position_to_side,
split_horizontally=split_horizontally, sides_label=sides_label,
)
_, sides_index = getattr(experiment, data_name).transform_to_side_of_box(f, sides_label)
sides_index_all.append(sides_index)
group = ""
if n_groups > 1:
group = f"$\\bf{{{filter_group['condition']}}}$\n\n"
kwargs = {"y_label": ""}
if i != n_groups - 1:
kwargs["x_label"] = ""
if not j:
kwargs["y_label"] = f"{group}% time spent"
cls._plot_percents(sides_index_all, sides_label, fig, ax, **kwargs)
if not i:
ax.set_title(f"$\\bf{{{t}}}$")
fig.suptitle(title)
save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_merged_by_period_side_of_box_percent(
cls, experiments: list["Experiment"], filter_args: list[dict], split_horizontally: bool = True,
sides_label: tuple[str, ...] = ("Near-side", "Far-side"), title: str = "Subjects motion",
save_fig_root: None | Path = None, save_fig_prefix: str = "",
):
n_groups = len(filter_args)
periods = cls._get_data_items()
n_periods = len(periods)
fig, axs = plt.subplots(1, n_periods, sharey=True, sharex=True)
axes = list(axs.flatten())
bar_width = 1 / (n_groups + 1)
n_sides = len(sides_label)
for i, filter_group in enumerate(filter_args):
experiments_ = cls.filter(experiments, **filter_group)
for j, (data_name, t) in enumerate(periods):
sides_index_all = []
for experiment in experiments_:
f = partial(
experiment.position_to_side,
split_horizontally=split_horizontally, sides_label=sides_label,
)
_, sides_index = getattr(experiment, data_name).transform_to_side_of_box(f, sides_label)
sides_index_all.append(sides_index)
ax = axes[j]
mean_prop, percents = cls._get_side_percent(sides_index_all, sides_label)
ax.bar(i * bar_width + np.arange(n_sides), mean_prop, bar_width, label=filter_group["condition"])
for prop in percents:
ax.plot(i * bar_width + np.arange(n_sides), prop, "k.")
ax.set_xlabel("Teaball side")
if not j:
ax.set_ylabel("% time spent")
if not i:
ax.set_title(f"$\\bf{{{t}}}$")
for ax in axes:
ax.set_xticks(np.arange(n_sides) + (1 - bar_width) / 2 - bar_width / 2, sides_label)
ax.set_xlim(-bar_width, n_sides + bar_width)
handles, labels = axes[-1].get_legend_handles_labels()
fig.legend(handles, labels, ncols=n_groups, bbox_to_anchor=(0, 0), loc=2)
fig.suptitle(title)
save_or_show(save_fig_root, save_fig_prefix)
@classmethod
def plot_multi_experiment_merged_by_group_side_of_box_percent(
cls, experiments: list["Experiment"], filter_args: list[dict], split_horizontally: bool = True,
sides_label: tuple[str, ...] = ("Teaball-side", "other-side"), title: str = "Subjects motion",
save_fig_root: None | Path = None, save_fig_prefix: str = "", show_titles: bool = True,
show_legend: bool = True, only_side: str | None = None, show_xlabel: bool = True,
):
n_groups = len(filter_args)
periods = cls._get_data_items()
n_periods = len(periods)
fig, axs = plt.subplots(n_groups, 1, sharey=True, sharex=True)
axes = list(axs.flatten())
bar_width = 1 / (n_periods + 1)
n_sides = len(sides_label)
sides_s = slice(0, n_sides)
n_sides_used = n_sides
sides_used = sides_label
if only_side:
i = sides_label.index(only_side)
sides_s = slice(i, i + 1)
n_sides_used = 1
sides_used = [only_side]
for i, filter_group in enumerate(filter_args):
experiments_ = cls.filter(experiments, **filter_group)
for j, (data_name, t) in enumerate(periods):
sides_index_all = []
for experiment in experiments_:
f = partial(
experiment.position_to_side,
split_horizontally=split_horizontally, sides_label=sides_label,
)
_, sides_index = getattr(experiment, data_name).transform_to_side_of_box(f, sides_label)
sides_index_all.append(sides_index)
ax = axes[i]
mean_prop, percents = cls._get_side_percent(sides_index_all, sides_label)
ax.bar(j * bar_width + np.arange(n_sides_used), mean_prop[sides_s], bar_width, label=t)
for prop in percents:
ax.plot(j * bar_width + np.arange(n_sides_used), prop[sides_s], "k.")
if i == n_groups - 1 and show_xlabel:
ax.set_xlabel("Teaball side")
ax.set_ylabel("% time spent")
if not j and show_titles:
ax.set_title(f"$\\bf{{{filter_group['condition']}}}$")
for ax in axes:
ax.set_xticks(np.arange(n_sides_used) + (1 - bar_width) / 2 - bar_width / 2, sides_used)
ax.set_xlim(-bar_width, n_sides_used - bar_width)
if show_legend:
handles, labels = axes[-1].get_legend_handles_labels()
fig.legend(handles, labels, ncols=1, bbox_to_anchor=(0, 0), loc=2)
fig.suptitle(title)
save_or_show(save_fig_root, save_fig_prefix, width_inch=3)
@classmethod
def filter(
cls, experiments: list["Experiment"], date: int | None = None, subject: int | None = None,
condition: str | None = None, cell_label: str | None = None, sex: str | None = None
):
if date is not None:
experiments = [t for t in experiments if t.date == date]
if subject is not None:
experiments = [t for t in experiments if t.subject == subject]
if condition is not None:
experiments = [t for t in experiments if t.condition == condition]
if cell_label is not None:
experiments = [t for t in experiments if t.cell_label == cell_label]
if sex is not None:
experiments = [t for t in experiments if t.sex == sex]
return experiments
@dataclass
class SubjectPos:
times: np.ndarray
track: np.ndarray
motion_index: np.ndarray
@property
def min_x(self):
return np.min(self.track[:, 0])
@property
def min_y(self):
return np.min(self.track[:, 1])
@property
def max_x(self):
return np.max(self.track[:, 0])
@property
def max_y(self):
return np.max(self.track[:, 1])
@classmethod
def parse_csv_track(cls, filename: Path, subject_name: str = "mouse") -> "SubjectPos":
track = []
times = []
index = []
with open(filename, "r") as fh:
reader = csv.reader(fh)
header = next(reader)
for row in reader:
frame, instance, cx, cy, index_val, t = row
if instance != subject_name:
continue
track.append((int(float(cx)), int(float(cy))))
index.append(float(index_val))
th, tm, ts = map(float, t.split(":"))
times.append(th * 60 ** 2 + tm * 60 + ts)
return SubjectPos(times=np.array(times), track=np.array(track), motion_index=np.array(index))
def extract_range(self, t_start: float | None = None, t_end: float | None = None):
if t_start is None:
t_start = self.times[0]
if t_end is None:
t_end = self.times[-1] + 1
i_s = np.sum(self.times < t_start)
i_e = np.sum(self.times <= t_end)
return SubjectPos(times=self.times[i_s:i_e], track=self.track[i_s:i_e], motion_index=self.motion_index[i_s:i_e])
def calculate_occupancy(
self, occupancy: np.ndarray, pos_to_index: Callable | None = None, frame_normalize: bool = True,
) -> None:
n = self.track.shape[0]
grid_width, grid_height = occupancy.shape
frame_proportion = 1
if frame_normalize:
frame_proportion = 1 / n
for i in range(n):
if self.track[i, 0] < 0 or self.track[i, 1] < 0:
continue
if pos_to_index is None:
x, y = self.track[i, :]
area_prop = 1
else:
x, y, area_prop = pos_to_index(*self.track[i, :], grid_width, grid_height)
x = int(min(x, grid_width))
y = int(min(y, grid_height))
occupancy[x, y] += frame_proportion * area_prop
@classmethod
def plot_occupancy_data(
cls, occupancy: np.ndarray, fig: plt.Figure, ax: plt.Axes,
gaussian_sigma: float = 0, intensity_limit: float = 0, scale_to_one: bool = True,
x_label: str = "Box X", y_label: str = "Box Y", title: str = "Occupancy",
color_bar: bool = True,
):
if gaussian_sigma:
occupancy = gaussian_filter(occupancy, gaussian_sigma)
if scale_to_one:
occupancy /= occupancy.max()
im = ax.imshow(
occupancy.T, aspect="auto", origin="upper", cmap="viridis", interpolation="sinc",
interpolation_stage="data", vmax=intensity_limit or None
)
if x_label:
ax.set_xlabel(x_label)
if y_label:
ax.set_ylabel(y_label)
if title:
ax.set_title(title)
if color_bar:
divider = make_axes_locatable(ax)
cax = divider.append_axes('right', size='5%', pad=0.05)
fig.colorbar(im, cax=cax, orientation='vertical')
def plot_occupancy(
self, grid_width: int, grid_height: int,
pos_to_index: Callable | None = None, fig: plt.Figure | None = None, ax: plt.Axes | None = None,
gaussian_sigma: float = 0, intensity_limit: int = 0, frame_normalize: bool = True,
scale_to_one: bool = True, save_fig_root: None | Path = None, save_fig_prefix: str = "",
x_label: str = "Box X", y_label: str = "Box Y", title: str = "Occupancy",
color_bar: bool = True,
):
occupancy = np.zeros((grid_width, grid_height))
show_plot = ax is None
if ax is None:
assert fig is None
fig, ax = plt.subplots()
else:
assert fig is not None
self.calculate_occupancy(occupancy, pos_to_index, frame_normalize)
self.plot_occupancy_data(
occupancy, fig, ax, gaussian_sigma, intensity_limit, scale_to_one, x_label, y_label, title, color_bar,
)
if show_plot:
save_or_show(save_fig_root, save_fig_prefix)
def plot_motion(
self, fig: plt.Figure | None = None, ax: plt.Axes | None = None,
save_fig_root: None | Path = None, save_fig_prefix: str = "",
x_label: str = "Time (min)", y_label: str = "Motion index",
):
show_plot = ax is None
if ax is None:
assert fig is None
fig, ax = plt.subplots()
else:
assert fig is not None
valid = self.motion_index >= 0
times = self.times[valid]
ax.plot(times / 60 - times[0] / 60, self.motion_index[valid], ",")
if x_label:
ax.set_xlabel(x_label)
if y_label:
ax.set_ylabel(y_label)
if show_plot:
save_or_show(save_fig_root, save_fig_prefix)
def transform_to_side_of_box(
self, position_to_side: Callable, sides_label: tuple[str, ...],
) -> tuple[np.ndarray, np.ndarray]:
valid = np.logical_and(self.track[:, 0] >= 0, self.track[:, 1] >= 0)
pos_sides_name = [position_to_side(*p) for p in self.track[valid, :]]
sides = {name: i for i, name in enumerate(sides_label)}
sides_index = np.array([sides[n] for n in pos_sides_name])
times = self.times[valid]
return times, sides_index
def plot_side_of_box(
self, position_to_side: Callable, sides_label: tuple[str, ...],
fig: plt.Figure | None = None, ax: plt.Axes | None = None,
save_fig_root: None | Path = None, save_fig_prefix: str = "",
):
show_plot = ax is None
if ax is None:
assert fig is None
fig, ax = plt.subplots()
else:
assert fig is not None
times, sides_index = self.transform_to_side_of_box(position_to_side, sides_label)
ax.plot(times / 60 - times[0] / 60, sides_index, "*", alpha=.3, ms=5)
ax.set_yticks(np.arange(len(sides_label)), sides_label)
ax.set_xlabel("Time (min)")
ax.set_ylabel("Side of box relative to teaball")
if show_plot:
save_or_show(save_fig_root, save_fig_prefix)
def plot_distance_from_point(
    self, point_xy: tuple[float, float], fig: plt.Figure | None = None, ax: plt.Axes | None = None,
    save_fig_root: None | Path = None, save_fig_prefix: str = "",
    x_label: str = "Time (min)", y_label: str = "Distance",
):
    """Plot the Euclidean distance from each tracked position to a fixed point.

    Samples with a negative x (column 0) or y (column 1) in ``self.track``
    are treated as untracked and skipped. Distances are in the same units as
    ``self.track``. The x axis is ``self.times`` divided by 60, offset so the
    first valid sample is at 0. If no sample is valid, nothing is drawn but
    the labels are still applied.

    :param point_xy: ``(x, y)`` reference point, in track coordinates.
    :param fig: Figure owning ``ax``; must be given exactly when ``ax`` is.
    :param ax: Axes to draw into. When None, a new figure/axes is created and,
        after drawing, passed to ``save_or_show``.
    :param save_fig_root: Only used when ``ax`` is None. If set, the figure is
        saved as ``<save_fig_root>/<save_fig_prefix>.png``; otherwise shown.
    :param save_fig_prefix: File name (without extension) of the saved figure.
    :param x_label: X-axis label; an empty string leaves it unset.
    :param y_label: Y-axis label; an empty string leaves it unset.
    """
    show_plot = ax is None
    if ax is None:
        assert fig is None
        fig, ax = plt.subplots()
    else:
        assert fig is not None

    valid = np.logical_and(self.track[:, 0] >= 0, self.track[:, 1] >= 0)
    times = self.times[valid]
    # Guard: times[0] below would raise IndexError on an empty track.
    if len(times):
        offset = self.track[valid, :] - np.asarray(point_xy)[None, :]
        distance = np.hypot(offset[:, 0], offset[:, 1])
        ax.plot((times - times[0]) / 60, distance, ",")

    if x_label:
        ax.set_xlabel(x_label)
    if y_label:
        ax.set_ylabel(y_label)

    if show_plot:
        save_or_show(save_fig_root, save_fig_prefix)
if __name__ == "__main__":
    # Machine-specific locations; point ``data_root`` at the local copy of the data.
    data_root = Path(r"C:\Users\Matthew Einhorn\Downloads\tmt24")
    csv_experiment_times = data_root / "Batch3 data analysis - timestamps.csv"
    tracking_data_root = data_root / "tracking"
    json_data_root = data_root / "json"
    figure_root = data_root / "figures"
    grouped_root = figure_root / "grouped"
    single_figure_root = figure_root / "grouped_single_figure"

    # Experimental conditions. Used both for the per-condition plots and for the
    # order of groups in the combined single-figure plots.
    conditions = "Blank", "TMT", "2MBA", "IAMM"

    experiments = Experiment.parse_experiment_spec_csv(csv_experiment_times)
    for experiment in experiments:
        experiment.parse_box_metadata(json_data_root)
        # Durations are presumably in seconds (5 * 60 -> 5 min habituation).
        experiment.read_pos_track(tracking_data_root, hab_duration=5 * 60, trial_duration=2 * 60, post_duration=2 * 60)
    # Shared grid size so every occupancy map is drawn on the same canvas.
    grid_width, grid_height = Experiment.enlarge_canvas(experiments)

    # Per-experiment figures. "100K" in the prefix refers to intensity_limit=1e-5 (1 / 100,000).
    for experiment in experiments:
        experiment_name = f"{experiment.subject}_{experiment.date}"
        experiment_fig_root = figure_root / experiment_name
        experiment.plot_occupancy(
            grid_width, grid_height, scale_to_one=False, intensity_limit=1e-5,
            save_fig_root=experiment_fig_root,
            save_fig_prefix=f"intensity_limit100K_{experiment_name}",
        )
        # Alternative per-experiment plots; uncomment as needed.
        # experiment.plot_occupancy(
        #     grid_width, grid_height, intensity_limit=1e-6,
        #     save_fig_root=experiment_fig_root,
        #     save_fig_prefix=f"intensity_limit1M_{experiment_name}",
        # )
        # experiment.plot_occupancy(
        #     grid_width, grid_height, gaussian_sigma=5,
        #     save_fig_root=experiment_fig_root,
        #     save_fig_prefix=f"occupancy_sigma5_{experiment_name}",
        # )
        # experiment.plot_occupancy(
        #     grid_width, grid_height, gaussian_sigma=1,
        #     save_fig_root=experiment_fig_root,
        #     save_fig_prefix=f"occupancy_sigma1_{experiment_name}",
        # )
        # experiment.plot_motion_histogram(
        #     n_bins=10,
        #     save_fig_root=experiment_fig_root,
        #     save_fig_prefix=f"motion_histogram_{experiment_name}",
        # )
        # experiment.plot_motion(
        #     y_limit=3,
        #     save_fig_root=experiment_fig_root,
        #     save_fig_prefix=f"motion_index_{experiment_name}",
        # )
        # experiment.plot_distance_from_point(
        #     ("left", "bottom"),
        #     save_fig_root=experiment_fig_root,
        #     save_fig_prefix=f"distance_{experiment_name}",
        # )
        # experiment.plot_side_of_box_percent(
        #     save_fig_root=experiment_fig_root,
        #     save_fig_prefix=f"side_percent_{experiment_name}",
        # )

    # One figure per condition, pooling all experiments of that condition.
    for condition in conditions:
        condition_experiments = Experiment.filter(experiments, condition=condition)
        Experiment.plot_multi_experiment_occupancy(
            condition_experiments, grid_width, grid_height, title=f"{condition} occupancy density", scale_to_one=False,
            intensity_limit=1e-5,
            save_fig_root=grouped_root / "occupancy",
            save_fig_prefix=f"occupancy_{condition}_intensity_limit100K",
        )
        # Alternative per-condition plots; uncomment as needed.
        # Experiment.plot_multi_experiment_occupancy(
        #     condition_experiments, grid_width, grid_height, title=f"{condition} occupancy density",
        #     intensity_limit=1e-6,
        #     save_fig_root=grouped_root / "occupancy",
        #     save_fig_prefix=f"occupancy_{condition}_intensity_limit1M",
        # )
        # Experiment.plot_multi_experiment_occupancy(
        #     condition_experiments, grid_width, grid_height, title=f"{condition} occupancy", gaussian_sigma=5,
        #     save_fig_root=grouped_root / "occupancy", save_fig_prefix=f"occupancy_{condition}_sigma5",
        # )
        # Experiment.plot_multi_experiment_occupancy(
        #     condition_experiments, grid_width, grid_height, title=f"{condition} occupancy", gaussian_sigma=1,
        #     save_fig_root=grouped_root / "occupancy", save_fig_prefix=f"occupancy_{condition}_sigma1",
        # )
        # Experiment.plot_multi_experiment_motion_histogram(
        #     condition_experiments, title=f"{condition} motion index density", n_bins=10,
        #     save_fig_root=grouped_root / "motion_histogram", save_fig_prefix=f"motion_histogram_{condition}",
        # )
        # Experiment.plot_multi_experiment_motion(
        #     condition_experiments, y_limit=3, title=f"{condition} motion index",
        #     save_fig_root=grouped_root / "motion_index", save_fig_prefix=f"motion_index_{condition}",
        # )
        # Experiment.plot_multi_experiment_distance_from_point(
        #     condition_experiments, ("left", "bottom"), title=f"{condition} distance from teaball corner",
        #     save_fig_root=grouped_root / "distance", save_fig_prefix=f"distance_{condition}",
        # )
        # Experiment.plot_multi_experiment_side_of_box_percent(
        #     condition_experiments, title=f"{condition} side of box percent",
        #     save_fig_root=grouped_root / "side_percent", save_fig_prefix=f"side_percent_{condition}",
        # )

    # A single figure with one group per condition, in the order of ``conditions``.
    Experiment.plot_multi_experiment_occupancy(
        experiments, grid_width, grid_height, title="Occupancy density", scale_to_one=False, intensity_limit=1e-5,
        filter_args=[{"condition": c} for c in conditions],
        save_fig_root=single_figure_root,
        save_fig_prefix="occupancy_intensity_limit100K",
    )
    # Alternative combined plots; uncomment as needed.
    # Experiment.plot_multi_experiment_occupancy(
    #     experiments, grid_width, grid_height, title="Occupancy density", intensity_limit=1e-6,
    #     filter_args=[{"condition": c} for c in conditions],
    #     save_fig_root=single_figure_root,
    #     save_fig_prefix="occupancy_intensity_limit1M",
    # )
    # Experiment.plot_multi_experiment_occupancy(
    #     experiments, grid_width, grid_height, title="Occupancy density", gaussian_sigma=5,
    #     filter_args=[{"condition": c} for c in conditions],
    #     save_fig_root=single_figure_root,
    #     save_fig_prefix="occupancy_sigma5",
    # )
    # Experiment.plot_multi_experiment_occupancy(
    #     experiments, grid_width, grid_height, title="Occupancy density", gaussian_sigma=1,
    #     filter_args=[{"condition": c} for c in conditions],
    #     save_fig_root=single_figure_root,
    #     save_fig_prefix="occupancy_sigma1",
    # )
    # Experiment.plot_multi_experiment_motion_histogram(
    #     experiments, title="Motion index density", n_bins=10,
    #     filter_args=[{"condition": c} for c in conditions],
    #     save_fig_root=single_figure_root,
    #     save_fig_prefix="motion_histogram",
    # )
    # Experiment.plot_multi_experiment_motion(
    #     experiments, y_limit=3, title="Motion index",
    #     filter_args=[{"condition": c} for c in conditions],
    #     save_fig_root=single_figure_root,
    #     save_fig_prefix="motion_index",
    # )
    # Experiment.plot_multi_experiment_distance_from_point(
    #     experiments, ("left", "bottom"), title="Distance from teaball corner",
    #     filter_args=[{"condition": c} for c in conditions],
    #     save_fig_root=single_figure_root,
    #     save_fig_prefix="distance",
    # )
    # Experiment.plot_multi_experiment_side_of_box_percent(
    #     experiments, title="Side of box percent",
    #     filter_args=[{"condition": c} for c in conditions],
    #     save_fig_root=single_figure_root,
    #     save_fig_prefix="side_percent",
    # )
    # Experiment.plot_multi_experiment_merged_by_period_side_of_box_percent(
    #     experiments, title="Side of box percent",
    #     filter_args=[{"condition": c} for c in conditions],
    #     save_fig_root=single_figure_root,
    #     save_fig_prefix="side_percent_merged_by_period",
    # )
    # Experiment.plot_multi_experiment_merged_by_group_side_of_box_percent(
    #     experiments, title="% spent on teaball side", show_titles=False, show_legend=False,
    #     filter_args=[{"condition": c} for c in conditions], only_side="Teaball-side", show_xlabel=False,
    #     save_fig_root=single_figure_root,
    #     save_fig_prefix="side_percent_merged_by_group",
    # )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment