Created
August 6, 2019 17:43
-
-
Save dzil123/ce01158e30eca73bd0c82d2e487424be to your computer and use it in GitHub Desktop.
Reference the value of other arguments in your function definitions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools | |
import inspect | |
import itertools | |
import typing | |
class Arg(typing.NamedTuple):
    """Marker default that refers to another parameter of the same function.

    Used together with ``wrap``: in ``def f(a, b=Arg("a")): ...``, calling
    ``f(1)`` gives ``b`` the value bound to ``a``.
    """

    # Name of the parameter whose bound value is substituted for this default.
    name: str


def wrap(func):
    """Decorate *func* so ``Arg(name)`` defaults take the value of parameter *name*.

    The signature is validated once, when the decorator is applied. On every
    call, any parameter whose default is ``Arg(target)`` and which the caller
    omitted receives the value bound to ``target`` (explicit or defaulted).

    Raises:
        TypeError: at decoration time, if an ``Arg`` default names a parameter
            that does not exist, or names a parameter whose own default is
            also an ``Arg`` (chains of references are not resolved).

    The returned wrapper raises ``TypeError`` if an ``Arg`` instance is passed
    as an argument, or if the arguments do not bind to *func*'s signature.
    """
    sig = inspect.signature(func)

    # Map each parameter that has an Arg default to the parameter it refers
    # to, rejecting references to missing parameters and Arg -> Arg chains,
    # e.g. no `def foo(a, b=Arg('a'), c=Arg('b')):`.
    references = {}
    for param in sig.parameters.values():
        if not isinstance(param.default, Arg):
            continue
        target_name = param.default.name
        # .get() rather than try/except KeyError, so the TypeError below is
        # not chained to an irrelevant KeyError in the traceback.
        target = sig.parameters.get(target_name)
        if target is None:
            raise TypeError(
                f"Arg '{param.name}' points to nonexistent argument '{target_name}'"
            )
        if isinstance(target.default, Arg):
            raise TypeError(
                f"Arg '{param.name}' points to another Arg "
                f"'{target.name}' = Arg('{target.default.name}')"
            )
        references[param.name] = target_name

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Arg is reserved as a default marker; accepting one at call time
        # would make it indistinguishable from an omitted argument.
        for value in itertools.chain(args, kwargs.values()):
            if isinstance(value, Arg):
                raise TypeError(f"Cannot pass in Arg instance {value}")

        # Raises TypeError on arguments that do not match the signature.
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()

        # An Arg still bound after apply_defaults() means the caller omitted
        # that parameter, so substitute the referenced parameter's value.
        # Targets are never Args (checked above), so order does not matter.
        for param_name, target_name in references.items():
            if isinstance(bound.arguments[param_name], Arg):
                bound.arguments[param_name] = bound.arguments[target_name]

        return func(*bound.args, **bound.kwargs)

    return wrapper
@wrap
def foo(a, b, c=None, d=Arg("c"), e=Arg("b")):
    """Return every argument as received after ``wrap`` resolves defaults.

    When omitted, ``d`` takes the value of ``c`` and ``e`` takes the value of
    ``b``, so the returned dict shows what each parameter actually got.
    """
    return dict(a=a, b=b, c=c, d=d, e=e)
def test(*args, **kargs):
    """Call ``foo`` with the given arguments and print the call and its result.

    The header is rendered as the equivalent call expression, e.g.
    ``foo(5, 7, d=0)``, instead of the raw ``args`` tuple and ``kargs`` dict
    (which previously printed as ``foo((5, 7) {'d': 0})``). A blank line
    follows each result to separate consecutive demos.
    """
    rendered = ", ".join(
        [repr(value) for value in args]
        + [f"{key}={value!r}" for key, value in kargs.items()]
    )
    print(f"Result of calling foo({rendered}):")
    print(foo(*args, **kargs))
    print()
# Demo: first omit d and e so they fall back to c and b respectively,
# then pass d explicitly to override its Arg("c") default.
for _call_args, _call_kwargs in (((1, 2), {}), ((5, 7), {"d": 0})):
    test(*_call_args, **_call_kwargs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment