Created
August 18, 2020 05:04
-
-
Save visionNoob/830729f4269bd060dc6bb7901c301d1a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import collections | |
from functools import reduce, partial | |
from operator import getitem | |
from logger import setup_logging | |
from omegaconf import OmegaConf | |
class ConfigParser:
    """Hold a training configuration and build objects described by it.

    The configuration is a nested mapping; ``from_args`` loads it from a YAML
    file with ``OmegaConf.load``. Entries consumed by ``init_obj`` /
    ``init_ftn`` must provide a ``'type'`` key (attribute name looked up on a
    module) and an ``'args'`` key (keyword arguments for that attribute).
    """

    def __init__(self, config, modification=None):
        """
        Store *config* after applying *modification* to it.

        :param config: Nested mapping containing configurations and
            hyperparameters for training, e.g. a YAML file loaded with OmegaConf.
        :param modification: Dict of ``keychain: value`` pairs. Each key is a
            ``;``-separated path into *config* whose value is replaced; pairs
            whose value is ``None`` are skipped. ``None`` means "no changes".
        """
        # No file is read here: overrides are written into `config` in place
        # by `_update_config`, which then returns the same object.
        self._config = _update_config(config, modification)

    @classmethod
    def from_args(cls, args, options=(), config_path='./test_config.yaml'):
        """
        Build a ConfigParser from command-line arguments. Used in train, test.

        :param args: An ``argparse.ArgumentParser``. Each option in *options*
            is registered on it, then it is parsed. A tuple is passed through
            unparsed (NOTE(review): registering options on a tuple fails, so
            that path only works with empty *options*).
        :param options: Iterable of objects with ``flags``, ``type`` and
            ``target`` attributes (e.g. a ``CustomArgs`` namedtuple). The parsed
            value of each option overrides the config entry at the
            ``;``-separated path ``target``. Options not given on the command
            line default to ``None`` and therefore leave the config untouched.
        :param config_path: YAML file loaded with ``OmegaConf.load``; relative
            paths are resolved against the current working directory.
        :return: A new instance wrapping the loaded, overridden config.
        """
        # Materialize once: `options` is iterated twice below, which would
        # silently drop every override if a one-shot iterator were passed.
        options = tuple(options)
        for opt in options:
            args.add_argument(*opt.flags, default=None, type=opt.type)
        if not isinstance(args, tuple):
            args = args.parse_args()

        config = OmegaConf.load(config_path)
        modification = {
            opt.target: getattr(args, _get_opt_name(opt.flags))
            for opt in options
        }
        return cls(config, modification=modification)

    def _resolve_entry(self, name, module, kwargs):
        """Return ``(callable, merged_kwargs)`` for config entry *name*.

        Raises ValueError if *kwargs* would overwrite a key that the config
        entry's ``'args'`` already defines.
        """
        entry = self[name]
        module_args = dict(entry['args'])
        clashes = [k for k in kwargs if k in module_args]
        if clashes:
            # An explicit raise (not `assert`) so the check survives `python -O`.
            raise ValueError(
                'Overwriting kwargs given in config file is not allowed: '
                '{}'.format(clashes))
        module_args.update(kwargs)
        return getattr(module, entry['type']), module_args

    def init_obj(self, name, module, *args, **kwargs):
        """
        Finds a function handle with the name given as 'type' in config, and returns the
        instance initialized with corresponding arguments given.

        `object = config.init_obj('name', module, a, b=1)`
        is equivalent to
        `object = module.name(a, b=1)`

        Raises ValueError if a keyword in *kwargs* is already set in the
        config entry's 'args'.
        """
        target, module_args = self._resolve_entry(name, module, kwargs)
        return target(*args, **module_args)

    def init_ftn(self, name, module, *args, **kwargs):
        """
        Finds a function handle with the name given as 'type' in config, and returns the
        function with given arguments fixed with functools.partial.

        `function = config.init_ftn('name', module, a, b=1)`
        is equivalent to
        `function = lambda *args, **kwargs: module.name(a, *args, b=1, **kwargs)`.

        Raises ValueError if a keyword in *kwargs* is already set in the
        config entry's 'args'.
        """
        target, module_args = self._resolve_entry(name, module, kwargs)
        return partial(target, *args, **module_args)

    def __getitem__(self, name):
        """Access items like ordinary dict."""
        return self.config[name]

    # setting read-only attributes
    @property
    def config(self):
        """The wrapped configuration mapping (read-only attribute)."""
        return self._config
# helper functions to update config dict with custom cli options | |
def _update_config(config, modification): | |
if modification is None: | |
return config | |
for k, v in modification.items(): | |
if v is not None: | |
_set_by_path(config, k, v) | |
return config | |
def _get_opt_name(flags): | |
for flg in flags: | |
if flg.startswith('--'): | |
return flg.replace('--', '') | |
return flags[0].replace('--', '') | |
def _set_by_path(tree, keys, value):
    """Assign *value* at the ``;``-separated key path *keys* inside *tree*.

    All keys but the last walk down to a nested container; the last key is
    then assigned in that container, mutating *tree* in place.
    """
    *parent_keys, leaf_key = keys.split(';')
    parent = _get_by_path(tree, parent_keys)
    parent[leaf_key] = value
def _get_by_path(tree, keys): | |
"""Access a nested object in tree by sequence of keys.""" | |
return reduce(getitem, keys, tree) | |
def main(config):
    """Print the ``data_loader`` section of *config*."""
    data_loader_section = config['data_loader']
    print(data_loader_section)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Model Train Phase')

    # Extra command-line flags that override values loaded from the YAML
    # config; `target` is the ';'-separated key path of the value to replace.
    CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
    batch_size_opt = CustomArgs(['--bs', '--batch_size'], type=int,
                                target='data_loader;args;batch_size')
    options = [batch_size_opt]

    config = ConfigParser.from_args(parser, options)
    main(config)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment