Last active
January 27, 2022 21:31
-
-
Save tacaswell/95177903175dbc28be5353b4a0e5118f to your computer and use it in GitHub Desktop.
datathoughts
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib | |
import matplotlib.lines | |
from matplotlib.artist import allow_rasterization | |
import matplotlib.pyplot as plt | |
class MatplotlibException(Exception):
    """Base class for the exceptions defined by this data-source prototype."""
class InvalidDatasource(MatplotlibException, ValueError):
    """Raised when a data source does not provide the keys an artist needs.

    Also derives from ``ValueError`` so it can be caught as one.
    """
class SimpleSource:
    """Data source backed by in-memory arrays passed as keyword arguments.

    Each keyword value is converted with ``np.asanyarray`` and stored under
    its keyword name. ``md`` maps every name to a dict holding the array's
    ``"ndim"`` and ``"dtype"``.
    """

    def __init__(self, **kwargs):
        self._data = {k: np.asanyarray(v) for k, v in kwargs.items()}
        self.md = {
            k: {"ndim": v.ndim, "dtype": v.dtype} for k, v in self._data.items()
        }

    def get(self, keys, ax=None, renderer=None):
        """Return ``{key: array}`` for each requested key.

        ``ax`` and ``renderer`` are accepted for interface compatibility and
        are not used. Raises ``KeyError`` for a key that was not supplied at
        construction.
        """
        return {k: self._data[k] for k in keys}
class DFSource:
    """Data source that exposes columns of a DataFrame under new names.

    Parameters
    ----------
    df
        Object indexable by column name whose columns have a ``dtype``
        (presumably a pandas DataFrame).
    **kwargs
        Mapping of exposed key -> column name in ``df``.
    """

    def __init__(self, df, **kwargs):
        self._remapping = kwargs
        self._data = df
        self.md = {k: {"ndim": 1, "dtype": df[v].dtype} for k, v in kwargs.items()}

    def get(self, keys, ax=None, renderer=None):
        """Return ``{key: column}`` for each requested key.

        ``ax`` and ``renderer`` are accepted for interface compatibility and
        are not used. Raises ``KeyError`` for a key with no column mapping.
        """
        # Look keys up in the mapping stored by __init__ (``_remapping``).
        return {k: self._data[self._remapping[k]] for k in keys}
class FuncSource1D:
    """Data source that samples a 1D function across the visible x-range.

    Parameters
    ----------
    func : callable
        Called with a 1D array of x values; its result is returned as "y".
    """

    def __init__(self, func):
        self._func = func
        self.md = {"x": {"ndim": 1, "dtype": float}, "y": {"ndim": 1, "dtype": float}}

    def get(self, keys, ax, renderer):
        """Evaluate the function across the current x-limits of ``ax``.

        The number of samples is the width of ``ax.get_window_extent(renderer)``
        truncated to an integer.

        Raises
        ------
        ValueError
            If ``keys`` is not exactly the set of keys in ``md``.
        """
        if set(keys) != set(self.md):
            raise ValueError(
                f"keys must be exactly {sorted(self.md)}, got {sorted(keys)}"
            )
        xlim = ax.get_xlim()
        bbox = ax.get_window_extent(renderer)
        # bbox.width is a float, but np.linspace requires an integer count.
        xpixels = int(bbox.width)
        x = np.linspace(*xlim, xpixels)
        return {"x": x, "y": self._func(x)}
class DSLine2D(matplotlib.lines.Line2D):
    """Line2D whose x/y data are pulled from a data source at draw time.

    Parameters
    ----------
    DS
        Data source exposing an ``md`` mapping that contains "x" and "y",
        and a ``get(keys, ax, renderer)`` method returning a mapping with
        those keys.
    **kwargs
        Forwarded to ``matplotlib.lines.Line2D``.

    Raises
    ------
    InvalidDatasource
        If ``DS.md`` lacks "x" or "y".
    """

    def __init__(self, DS, **kwargs):
        missing = [k for k in ("x", "y") if k not in DS.md]
        if missing:
            raise InvalidDatasource(
                f"data source is missing required keys: {missing}"
            )
        self._DS = DS
        # Start with no data; it is fetched from the source on every draw.
        super().__init__([], [], **kwargs)

    @allow_rasterization
    def draw(self, renderer):
        """Fetch "x" and "y" from the data source, then draw the line."""
        data = self._DS.get({"x", "y"}, self.axes, renderer)
        super().set_data(data["x"], data["y"])
        return super().draw(renderer)
# Demo: add one array-backed line and one function-backed line to the
# current axes.
plt.close("all")
ax = plt.gca()
DS = SimpleSource(x=np.linspace(0, 10, 100), y=np.sin(np.linspace(0, 10, 100)))
DS2 = FuncSource1D(lambda x: np.cos(x) + 1)
dsl = DSLine2D(DS, color="red")
ax.add_artist(dsl)
dsl2 = DSLine2D(DS2, color="blue")
ax.add_artist(dsl2)
I suspect we will want an easy way to wrap an xarray, but I am not sure that xarray can be the whole data model (as I am not sure how to fit things like the function source into it).
I could see expecting the source to return a dict-of-arrays-like object (e.g. a pandas DataFrame or an xarray Dataset) from a single `get` call, instead of the artists calling it n times?
Not so much xarray as the data model; more stealing ideas from its architecture (and likewise from dask for some of the functional ideas).
👍
for simple indexing of an array
from typing import TYPE_CHECKING, Set
from numbers import Integral
from matplotlib.axes import Axes
class ArraySource1D:
    """Data source that indexes a 1D array over the visible x-range.

    Parameters
    ----------
    array
        Array-like supporting ``len()`` and indexing with an integer array.
        If it has a ``vindex`` attribute (as zarr arrays do), that is used
        for indexing instead.
    scale : int, default 1
        Step between consecutive indices returned by `get`.

    Raises
    ------
    TypeError
        If ``scale`` is not an integer.
    """

    def __init__(self, array, scale=1) -> None:
        self._arr = array
        # Go through the property setter so the type check applies here too.
        self.scale = scale
        if hasattr(self._arr, "vindex"):
            # account for zarr
            self._indexer = self._arr.vindex
        else:
            self._indexer = self._arr
        self.md = {"x": {"ndim": 1, "dtype": float}, "y": {"ndim": 1, "dtype": float}}

    @property
    def scale(self) -> int:
        """Step between consecutive indices returned by `get`."""
        return self._scale

    @scale.setter
    def scale(self, value: int):
        if not isinstance(value, Integral):
            raise TypeError(f"scale must be integer values but is type {type(value)}")
        self._scale = value

    def get(self, keys: Set[str], ax: "Axes", renderer):
        """Return indices within the x-limits of ``ax`` and the array values there.

        Both limits are truncated to integers and clamped to
        ``[0, len(array)]`` so that a view extending past the data does not
        index out of bounds. ``keys`` and ``renderer`` are not used.
        """
        xlim = ax.get_xlim()
        n = len(self._arr)
        xmin = min(max(int(xlim[0]), 0), n)
        xmax = min(max(int(xlim[1]), 0), n)
        x = np.arange(xmin, xmax, self._scale)
        return {"x": x, "y": self._indexer[x]}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
we should maybe talk to @shoyer about how xarray does this? They've apparently got a super clean model under the hood...
And I'm thinking for many users the datasource stuff will get hidden in @process_data?