Last active
December 17, 2019 06:10
-
-
Save Multihuntr/55df3a9543e0d6f29af6942dab112b39 to your computer and use it in GitHub Desktop.
Decorator for lazy-loaded derived values
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def derive(_from, _coll='_derived', name=None):
    '''
    Create a decorator that lazily computes and caches a derived value.

    The decorated function is called as ``fnc(self, value)``, where ``value``
    is ``getattr(self, _from)``. Its result is stored in the dict found at
    ``getattr(self, _coll)`` and reused on later accesses. The owning object
    must create that dict and clear it whenever ``_from`` changes
    (e.g. ``self._derived = {}``). This decorator never does either.

    If ``getattr(self, _from, None)`` is None, the wrapper returns None
    without calling the decorated function. A cached value of None counts as
    "not cached", so a function that returns None is re-run on every access.

    Typically stacked directly under ``@property``.

    Args:
        _from (str): Attribute this value is derived from.
        _coll (str, optional): Name of the attribute holding the cache dict.
        name (str, optional): Cache key to use; defaults to the decorated
            function's ``__name__``.
    Returns:
        function(function(self, value)): a decorator producing ``wrapper(self)``.
    '''
    from functools import wraps

    def wrapped(fnc):
        # Resolve the cache key per decorated function instead of rebinding
        # the enclosing `name`. This lets one `derive(...)` result decorate
        # several functions without them sharing a single cache slot.
        key = fnc.__name__ if name is None else name

        @wraps(fnc)
        def wrapper(self):
            coll = getattr(self, _coll)
            p = getattr(self, _from, None)
            if p is None:
                return None
            if key not in coll or coll[key] is None:
                coll[key] = fnc(self, p)
            return coll[key]
        return wrapper
    return wrapped
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from derive import derive | |
class Evaluation:
    ''' Accumulator for distances between targets and outputs.

    Each call to ``update`` appends one array of Euclidean distances to
    ``dist_list`` and clears the cached derived values (``dist``, ``mean``).
    Those values are recomputed lazily through ``derive``.
    '''
    def __init__(self, name=None):
        self.name = name
        # One array of distances per `update` call.
        self.dist_list = []
        # Cache dict read by `derive` (its default `_coll` is '_derived').
        self._derived = {}

    def reset_derived(self):
        ''' Drop all cached derived values so they are recomputed on next access. '''
        self._derived = {}

    @property
    @derive(_from='dist_list')
    def dist(self, dist_list):
        ''' All accumulated distances concatenated along the first axis,
        or None if `update` has not been called yet. '''
        if not dist_list:
            # np.concatenate raises on an empty sequence. Returning None lets
            # `derive` short-circuit dependants such as `mean` to None too.
            return None
        return np.concatenate(dist_list)

    @property
    @derive(_from='dist')
    def mean(self, dist):
        ''' Mean of all accumulated distances, or None if there are none. '''
        return dist.mean()

    def update(self, target, output):
        ''' Given a batch of examples and outputs update metrics.

        Args:
            target: array-like shaped [..., 2]; the last axis holds coordinates.
            output: array-like with the same shape as ``target``.
        '''
        # Accept lists/tuples as well as ndarrays (list - list would fail).
        target = np.asarray(target)
        output = np.asarray(output)
        dist = np.linalg.norm(output - target, axis=-1)
        # A single point yields a 0-d result, which np.concatenate rejects.
        # Promote it to 1-d so `dist` can always be built.
        self.dist_list.append(np.atleast_1d(dist))
        self.reset_derived()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This is very similar in concept to caching a property with something like `functools.lru_cache`.
However, it exposes a simpler interface for clearing all derived values when the cache becomes invalid, and it automatically handles the case where the value to derive from is None.
Also, `functools.lru_cache` requires a bit of finagling to work as expected on methods (https://stackoverflow.com/questions/33672412/python-functools-lru-cache-with-class-methods-release-object).