Created
October 1, 2019 16:06
-
-
Save zero323/ee36bce57ddeac82322e3ab4ef547611 to your computer and use it in GitHub Desktop.
Generating ML setters and getters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pyspark.ml.param import Param, Params, TypeConverters | |
from pyspark import keyword_only | |
class FooBar(Params):
    """Example ``Params`` holder with hand-written accessors.

    Declares two params: ``foo`` (converted with ``TypeConverters.toInt``,
    getter only) and ``bar`` (converted with ``TypeConverters.toString``,
    getter and setter).
    """

    foo = Param(
        Params._dummy(),
        "foo", "Just foo",
        typeConverter=TypeConverters.toInt)

    bar = Param(
        Params._dummy(),
        "bar", "Just bar",
        typeConverter=TypeConverters.toString)

    @keyword_only
    def __init__(self, foo=42, bar=None):
        """Register the defaults (``foo=42``, ``bar=""``) and forward the
        keyword arguments found in ``self._input_kwargs`` to
        :py:meth:`setParams`.
        """
        super(FooBar, self).__init__()
        self._setDefault(foo=42, bar="")
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, foo=None, bar=None):
        """Pass the keyword arguments found in ``self._input_kwargs`` to
        ``self._set`` and return its result.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def getFoo(self):
        """Gets the value of :py:attr:`foo` or its default value."""
        return self.getOrDefault(self.foo)

    def getBar(self):
        """Gets the value of :py:attr:`bar` or its default value."""
        return self.getOrDefault(self.bar)

    def setBar(self, value):
        """Sets the value of :py:attr:`bar`."""
        return self._set(bar=value)
# Usage example: build an instance without arguments, then call the
# hand-written setter and both getters.
foobar = FooBar()
foobar.setBar("bar")
foobar.getFoo()
foobar.getBar()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pyspark.ml.param import Param, Params, TypeConverters | |
from pyspark import keyword_only | |
def _make_getter(param):
    """Build a getter method for a single param.

    :param param: a ``Param`` instance (its ``name`` is used) or the param
        name itself.
    :return: a function ``get(self)`` that returns
        ``self.getOrDefault(name)``, with a generated docstring.
    """
    if isinstance(param, Param):
        name = param.name
    else:
        name = param

    def get(self):
        return self.getOrDefault(name)

    doc_template = "Gets the value of :py:attr:`{}` or its default value."
    get.__doc__ = doc_template.format(name)
    return get
def _make_setter(param):
    """Build a setter method for a single param.

    :param param: a ``Param`` instance (its ``name`` is used) or the param
        name itself.
    :return: a function ``set(self, value)`` that calls
        ``self._set(<name>=value)`` and returns its result, with a
        generated docstring.
    """
    if isinstance(param, Param):
        name = param.name
    else:
        name = param

    # The inner function keeps the name ``set`` so the ``__name__`` of the
    # generated method stays the same.
    def set(self, value):
        updates = {name: value}
        return self._set(**updates)

    doc_template = "Sets the value of :py:attr:`{}`"
    set.__doc__ = doc_template.format(name)
    return set
class FooBar(Params):
    """Example ``Params`` holder whose accessors are produced by the
    ``_make_setter`` / ``_make_getter`` helpers instead of being written
    out by hand.
    """

    foo = Param(
        Params._dummy(), "foo", "Just foo",
        typeConverter=TypeConverters.toInt,
    )
    bar = Param(
        Params._dummy(), "bar", "Just bar",
        typeConverter=TypeConverters.toString,
    )

    @keyword_only
    def __init__(self, foo=42, bar=None):
        """Register the defaults (``foo=42``, ``bar=""``), then hand
        ``self._input_kwargs`` over to :py:meth:`setParams`.
        """
        super(FooBar, self).__init__()
        self._setDefault(foo=42, bar="")
        self.setParams(**self._input_kwargs)

    @keyword_only
    def setParams(self, foo=None, bar=None):
        """Pass ``self._input_kwargs`` to ``self._set`` and return its result."""
        return self._set(**self._input_kwargs)

    # Generated accessors: setter and getter for ``bar``, getter for ``foo``.
    setBar = _make_setter(bar)
    getBar = _make_getter(bar)
    getFoo = _make_getter(foo)
# Usage example: the helper-generated accessors are called exactly like
# hand-written ones.
foobar = FooBar()
foobar.setBar("bar")
foobar.getFoo()
foobar.getBar()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def makeMLAccessors(setters=False):
    """Create a class-decorator factory that attaches param accessors.

    :param setters: if true, the produced decorators add ``set<Name>``
        methods built with ``_make_setter``; otherwise they add
        ``get<Name>`` methods built with ``_make_getter``.
    :return: a function taking param names (``*params``) and returning a
        decorator. Applied to a class, the decorator adds one accessor per
        name and returns the class. Applied to anything else (assumed to be
        another class decorator), it returns a composed decorator that runs
        the other decorator first and then adds the accessors.
    """
    def factory(*params):
        def augment(cls):
            for param in params:
                # Upper-case only the first character. ``str.title`` would
                # also lower-case the rest, turning "maxIter" into
                # "Maxiter" instead of "MaxIter".
                cap_name = param[:1].upper() + param[1:]
                if setters:
                    setattr(
                        cls,
                        "set{}".format(cap_name),
                        _make_setter(param))
                else:
                    setattr(
                        cls,
                        "get{}".format(cap_name),
                        _make_getter(param))
            return cls

        def decorate(x):
            if isinstance(x, type):
                # We have reached the annotated class itself.
                return augment(x)

            # We are in the middle of a decorator stack: ``x`` is another
            # decorator, so compose with it.
            def composed(cls):
                # The result must be returned, otherwise the decorated
                # name would end up bound to None.
                return augment(x(cls))
            return composed

        return decorate

    return factory
# Ready-made decorator factories: call with param names to obtain a class
# decorator that adds the matching setters or getters.
withMLSetters = makeMLAccessors(setters=True)
withMLGetters = makeMLAccessors(setters=False)
@withMLSetters("bar")
@withMLGetters("foo", "bar")
class FooBar(Params):
    """Example ``Params`` holder whose accessors are attached by the
    ``withMLSetters`` / ``withMLGetters`` class decorators.
    """

    foo = Param(
        Params._dummy(), "foo", "Just foo",
        typeConverter=TypeConverters.toInt,
    )
    bar = Param(
        Params._dummy(), "bar", "Just bar",
        typeConverter=TypeConverters.toString,
    )

    @keyword_only
    def __init__(self, foo=42, bar=None):
        """Register the defaults (``foo=42``, ``bar=""``), then hand
        ``self._input_kwargs`` over to :py:meth:`setParams`.
        """
        super(FooBar, self).__init__()
        self._setDefault(foo=42, bar="")
        self.setParams(**self._input_kwargs)

    @keyword_only
    def setParams(self, foo=None, bar=None):
        """Pass ``self._input_kwargs`` to ``self._set`` and return its result."""
        return self._set(**self._input_kwargs)
# Usage example: the accessors below were added by the class decorators.
foobar = FooBar()
foobar.setBar("bar")
foobar.getFoo()
foobar.getBar()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def MLSetters(*params):
    """Build a ``Params`` subclass exposing a setter for each given param.

    :param params: param names (strings).
    :return: a new class derived from ``Params`` with one ``set<Name>``
        method per param, each built with ``_make_setter``. The class name
        is ``"Params"`` followed by a ``HasSet<Name>`` part per param, all
        joined with ``"With"`` (e.g. ``ParamsWithHasSetBar``).
    """
    # Upper-case only the first character. ``str.title`` would also
    # lower-case the rest, turning "maxIter" into "Maxiter" instead of
    # "MaxIter".
    cap_names = [param[:1].upper() + param[1:] for param in params]
    class_name = "With".join(
        ["Params"] + ["HasSet{}".format(cap) for cap in cap_names]
    )
    methods = {
        "set{}".format(cap): _make_setter(param)
        for cap, param in zip(cap_names, params)
    }
    return type(class_name, (Params, ), methods)
def MLGetters(*params):
    """Build a ``Params`` subclass exposing a getter for each given param.

    :param params: param names (strings).
    :return: a new class derived from ``Params`` with one ``get<Name>``
        method per param, each built with ``_make_getter``. The class name
        is ``"Params"`` followed by a ``HasGet<Name>`` part per param, all
        joined with ``"With"`` (e.g. ``ParamsWithHasGetFooWithHasGetBar``).
    """
    # Upper-case only the first character. ``str.title`` would also
    # lower-case the rest, turning "maxIter" into "Maxiter" instead of
    # "MaxIter".
    cap_names = [param[:1].upper() + param[1:] for param in params]
    class_name = "With".join(
        ["Params"] + ["HasGet{}".format(cap) for cap in cap_names]
    )
    methods = {
        "get{}".format(cap): _make_getter(param)
        for cap, param in zip(cap_names, params)
    }
    return type(class_name, (Params, ), methods)
class FooBar(MLSetters("bar"), MLGetters("foo", "bar")):
    """Example ``Params`` holder that inherits its accessors from the
    dynamically generated ``MLSetters`` / ``MLGetters`` base classes.
    """

    foo = Param(
        Params._dummy(), "foo", "Just foo",
        typeConverter=TypeConverters.toInt,
    )
    bar = Param(
        Params._dummy(), "bar", "Just bar",
        typeConverter=TypeConverters.toString,
    )

    @keyword_only
    def __init__(self, foo=42, bar=None):
        """Register the defaults (``foo=42``, ``bar=""``), then hand
        ``self._input_kwargs`` over to :py:meth:`setParams`.
        """
        super(FooBar, self).__init__()
        self._setDefault(foo=42, bar="")
        self.setParams(**self._input_kwargs)

    @keyword_only
    def setParams(self, foo=None, bar=None):
        """Pass ``self._input_kwargs`` to ``self._set`` and return its result."""
        return self._set(**self._input_kwargs)
# Usage example: the accessors below are inherited from the generated
# base classes.
foobar = FooBar()
foobar.setBar("bar")
foobar.getFoo()
foobar.getBar()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Right now
pyspark.ml
wrappers require a significant amount of boilerplate code to provide getters and setters, as well as initializers (__init__
) and setParams
methods. This is partially addressed by a rich set of shared mixins (mimicking their Scala counterparts), which are partially code generated. This approach, however, is somewhat limited, especially in the case of one-off Params. Its benefits are further limited by ongoing work (https://issues.apache.org/jira/browse/SPARK-29093), therefore it might be worthwhile to explore other options. At the moment, a simple Params definition with two Param fields (foo, with a getter, and bar, with a setter and a getter) requires something like this. get* methods are virtually identical for all Params
with varying
paramName
, while set*
show some variation, although most of the time a basic pattern is used: Because of that, one could simply generate these methods using simple helpers,
This is still quite verbose, nonetheless it can significantly reduce the amount of code required to provide more complex wrappers.
It can be further simplified by providing simple class decorators, resulting in something along these lines
While concise, such an approach has serious disadvantages, as it is completely opaque to static analysis tools.
One possible compromise is to use dynamically generated base classes, resulting in something like this:
It is still rather opaque, but clearly indicates any changes in the public API (something that I personally find quite important). Dynamically generated classes are not for the faint of heart, and in the past I argued against using these, but they are still widely used (with namedtuples and data classes being the most prominent examples).
It is clear that similar, though more complex, methods can be used to generate both initializers and
setParams
, but I consider this a separate issue. Each of the proposed solutions (and likely any similar approach) has some serious caveats and I am not completely convinced that the pros outweigh the cons.