Created
April 4, 2026 04:41
-
-
Save alexfanqi/ee45c7638ade877236dbd25cdff8138c to your computer and use it in GitHub Desktop.
omegaconf file override and config defaults
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from pathlib import Path | |
| from typing import Any, TypeVar | |
| from omegaconf import OmegaConf | |
# Type of the structured-config schema passed to load_config_with_case.
T = TypeVar("T")


def load_config_with_case(
    Param: type[T],
    use_config_file: bool = True,
    search_path: Path = Path("."),
    argv_arglist: list[str] | None = None,
    return_cli: bool = False,
    make_concrete: bool = True,
    case_name_field: str = "case_name",
) -> T | tuple[T, Any]:
    """Load a structured config from defaults, a config file, and CLI arguments.

    Configuration is merged in this order (later sources win):
    dataclass defaults -> config file -> CLI arguments. The config file path
    is read from the ``config_path`` CLI argument.

    This function supports the case-name "hack": if the field named by
    ``case_name_field`` is given on the command line, the instance's
    ``default_from_case()`` is re-run for that case before the config file
    and the remaining CLI overrides are applied. A case name set only in the
    config file does NOT trigger this reload.

    Args:
        Param: Dataclass type defining the configuration schema. It must be
            constructible with no arguments, provide a ``default_from_case()``
            method, and declare the field named by ``case_name_field``. It
            must not declare a ``config_path`` field when a config file is
            loaded, because that name is reserved for the file path.
        use_config_file: Whether to load the file named by the
            ``config_path`` CLI argument.
        search_path: Directory used to resolve ``config_path`` when it does
            not name an existing file on its own.
        argv_arglist: Optional list of ``key=value`` CLI arguments. When
            ``None``, ``OmegaConf.from_cli()`` reads ``sys.argv``.
        return_cli: If True, also return the raw CLI config (still containing
            ``config_path`` if it was given).
        make_concrete: If True, convert the result to Python objects with
            ``OmegaConf.to_object``. If False, return the OmegaConf container.
        case_name_field: Name of the field that triggers the case reload.

    Returns:
        The loaded configuration, or a tuple ``(config, cli_config)`` when
        ``return_cli=True``.

    Raises:
        FileNotFoundError: If ``config_path`` is given but neither it nor
            ``search_path / config_path`` is an existing file.
        ValueError: If a config file is loaded but ``Param`` itself declares
            a ``config_path`` field.

    Example:
        >>> from dataclasses import dataclass
        >>> @dataclass
        ... class MyConfig:
        ...     case_name: str = "default_case"
        ...     value: float = 1.0
        ...     def default_from_case(self):
        ...         # Load case-specific defaults
        ...         return self
        >>> cfg = load_config_with_case(MyConfig, argv_arglist=["value=2.0"])
        >>> cfg.value
        2.0
    """
    # Apply the defaults of the initial case. The return value of
    # default_from_case() is ignored: the instance is used afterwards, so the
    # method is relied on to update it in place.
    param_instance = Param()
    param_instance.default_from_case()

    # Parse CLI; OmegaConf.from_cli() with no argument reads sys.argv.
    if argv_arglist is None:
        cli_cfg = OmegaConf.from_cli()
    else:
        cli_cfg = OmegaConf.from_cli(argv_arglist)
    # The CLI config handed back when return_cli=True. It aliases cli_cfg
    # unless config_path has to be stripped from cli_cfg below.
    cli_cfg_out = cli_cfg

    if use_config_file and "config_path" in cli_cfg:
        # config_path is consumed here and removed before merging, so it
        # must not collide with a real schema field.
        if "config_path" in Param.__dataclass_fields__:
            raise ValueError(
                "'config_path' is reserved for the config file path and "
                f"cannot be a field of {Param.__name__}"
            )
        # str() guards against non-string CLI values (e.g. config_path=123).
        raw_path = Path(str(cli_cfg.config_path))
        config_path = raw_path if raw_path.is_file() else search_path / raw_path
        if not config_path.is_file():
            raise FileNotFoundError(f"Config file not found: {config_path}")
        # Load the *resolved* path so files found via search_path work.
        file_cfg = OmegaConf.load(config_path)
        if return_cli:
            # Shallow copy so the returned CLI config keeps config_path.
            cli_cfg_out = cli_cfg.copy()
        # config_path is not part of the schema; drop it before merging.
        cli_cfg.pop("config_path")
    else:
        file_cfg = OmegaConf.create()

    # Hack: reload case defaults if the case name was specified via CLI.
    # Membership (not hasattr) so DictConfig method names and '???' values
    # are not mistaken for a user-supplied key.
    if case_name_field in cli_cfg:
        setattr(param_instance, case_name_field, cli_cfg[case_name_field])
        param_instance.default_from_case()
    defaults = OmegaConf.structured(param_instance)

    # Merge: defaults <- config file <- CLI overrides.
    cfg = OmegaConf.unsafe_merge(defaults, file_cfg, cli_cfg)
    if make_concrete:
        cfg = OmegaConf.to_object(cfg)
    if return_cli:
        return cfg, cli_cfg_out
    return cfg
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment