Last active
November 20, 2018 18:08
-
-
Save wolfv/73f56e4a9cac84eea6a796fde3213456 to your computer and use it in GitHub Desktop.
A Taste Of Julia / C++ in Python – simple Python multiple dispatch from type hints
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import collections.abc
import re
from typing import *
def to_regex(typevar, groups):
    """Translate one parameter annotation into a regex fragment.

    The fragment is matched against the type string that ``to_callee``
    produces for a runtime argument (e.g. ``"int"`` or ``"list[int]"``).

    Args:
        typevar: the annotation -- ``float``/``int``/``str``, a subscripted
            ``Sequence[...]``, a ``TypeVar``, or anything else (treated as a
            wildcard).
        groups: set of TypeVars already bound in this signature. It is
            updated in place: a TypeVar's first occurrence becomes a named
            capture group, later occurrences become back-references, which
            forces every occurrence to the same type string.

    Returns:
        A regex fragment string. Fragments never match a comma, so a single
        parameter cannot swallow the ``", "`` separator between arguments.
    """
    # Matches one argument's type string without crossing into the next one.
    any_type = "[^,]*?"

    if isinstance(typevar, TypeVar):
        if typevar in groups:
            return "(?P={})".format(typevar.__name__)
        groups.add(typevar)
        return "(?P<{}>{})".format(typevar.__name__, any_type)

    if typevar in (float, int, str):
        return typevar.__name__

    # A subscripted ``Sequence[X]`` exposes ``__origin__``: typing.Sequence on
    # Python 3.6, collections.abc.Sequence on 3.7+. The former
    # ``mro()[1] is Sequence`` test only worked on 3.6. Bare, unsubscripted
    # ``Sequence`` has no usable element argument, so it stays a wildcard.
    origin = getattr(typevar, "__origin__", None)
    if typevar is not Sequence and (
        origin is Sequence or origin is collections.abc.Sequence
    ):
        return r"(?:list|set|tuple)\[{}\]".format(
            to_regex(typevar.__args__[0], groups)
        )

    return any_type
def get_element_types(sequence):
    """Return the set of distinct element types found in *sequence*."""
    element_types = set()
    for element in sequence:
        element_types.add(type(element))
    return element_types
def to_callee(arg):
    """Describe a runtime argument's type as a string for overload matching.

    Scalars become their type name (``"int"``, ``"float"``, ``"str"``).
    A list/set/tuple whose elements all share one type becomes
    ``"<container>[<element type>]"``, e.g. ``"list[int]"``.

    Raises:
        RuntimeError: for any other type, and for containers that are empty
            or hold elements of more than one type.
    """
    kind = type(arg)
    # str is accepted because to_regex emits "str" for str annotations;
    # without it a str overload could never be selected.
    if kind in (float, int, str):
        return kind.__name__
    if kind in (list, set, tuple):
        element_types = get_element_types(arg)
        if len(element_types) == 1:
            (element_type,) = element_types
            return "{}[{}]".format(kind.__name__, element_type.__name__)
        raise RuntimeError("Not implemented yet.")
    raise RuntimeError("Not implemented yet.")
def to_match_target(caller_signature):
    """Render the runtime types of *caller_signature* as ``"t1, t2, ..."``.

    Each argument is described by ``to_callee``. The resulting string is what
    the overload signature regexes are matched against.
    """
    type_names = (to_callee(argument) for argument in caller_signature)
    return ", ".join(type_names)
def to_regex_sig(caller_signature):
    """Build the regex for a whole annotated signature.

    One fragment per annotation is produced by ``to_regex``, and the
    fragments are joined with ``", "`` -- the same separator that
    ``to_match_target`` uses. A single set of bound TypeVars is shared
    across all parameters, so a TypeVar repeated in different parameters
    refers to one capture group.
    """
    bound_typevars = set()
    fragments = []
    for annotation in caller_signature:
        fragments.append(to_regex(annotation, bound_typevars))
    return ", ".join(fragments)
class overloaded(object):
    """Decorator that registers type-annotated overloads under one name.

    Each function decorated with ``@overloaded`` adds an entry to a shared
    registry, grouped by the function's module and qualified name. Calling
    the decorated name does three things:

    1. It renders the runtime argument types as a string (``to_match_target``).
    2. It finds the first registered overload whose signature regex
       (``to_regex_sig``) matches that whole string.
    3. It calls that overload.
    """

    # (module, qualname) -> {signature regex: function}, in registration order.
    # Grouping by name keeps same-signature overloads of *different*
    # functions (e.g. ``add`` and ``mul``) from overwriting each other.
    fmap = {}
    # When true, every dispatch attempt is traced to stdout.
    verbose = True

    def __init__(self, f):
        # Only parameter annotations describe the call signature; a
        # ``-> ...`` return annotation must not count as an extra argument.
        signature = tuple(
            hint for name, hint in f.__annotations__.items() if name != "return"
        )
        self._overloads = self.fmap.setdefault((f.__module__, f.__qualname__), {})
        self._overloads[to_regex_sig(signature)] = f

    def __call__(self, *args):
        """Dispatch *args* to the first matching overload.

        Raises:
            RuntimeError: if no registered overload matches the argument types
                (``to_callee`` may also raise it for unsupported argument types).
        """
        match_sig = to_match_target(args)
        for key, func in self._overloads.items():
            if self.verbose:
                print("Matching: {} against\n {}\n".format(match_sig, key))
            # fullmatch: the signature must cover every argument; a plain
            # prefix match would accept e.g. "int" for the call "int, int".
            if re.fullmatch(key, match_sig):
                if self.verbose:
                    print(" === MATCH ===\n\n")
                return func(*args)
        raise RuntimeError("No overload found for {}".format(match_sig))
@overloaded
def add(a: int, b: int):
    """Integer overload: return a + b plus an extra 100.

    The +100 presumably exists so the demo output shows this overload was picked.
    """
    partial_sum = a + b
    return partial_sum + 100
@overloaded
def add(a: float, b: float):
    """Float overload: the plain sum of the two numbers."""
    total = a + b
    return total
# Type variables for the generic overloads. When the same TypeVar is the
# element type of several Sequence parameters, the regex built by to_regex
# requires those elements to have identical type names. U is currently not
# referenced by the overloads defined here.
T, U = TypeVar('T'), TypeVar('U')
@overloaded
def add(a: Sequence[T], b: float):
    """Return a new list with the scalar *b* added to each element of *a*."""
    shifted = []
    for element in a:
        shifted.append(element + b)
    return shifted
@overloaded
def add(a: Sequence[T], b: Sequence[T]):
    """Add two sequences (annotated with a shared element type T) pairwise.

    Pairing uses zip, so extra elements of the longer input are ignored.
    """
    pairs = zip(a, b)
    return [left + right for left, right in pairs]
@overloaded
def add(a: Sequence[T], b: Sequence[str]):
    """Concatenate str() of each element of *a* with the matching string in *b*.

    Pairing uses zip, so extra elements of the longer input are ignored.
    """
    joined = []
    for left, right in zip(a, b):
        joined.append(str(left) + right)
    return joined
if __name__ == '__main__':
    # Argument pairs intended to exercise each of the ``add`` overloads in turn.
    demo_arguments = [
        (3, 5),
        (4.5, 8.2),
        ([1, 2, 3], 5.0),
        ([1, 2, 3], [1, 2, 3]),
        ([1, 2, 3], ["a", "b", "c"]),
    ]
    for first, second in demo_arguments:
        print(add(first, second))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment