Train a deep learning model with distributed hyperparameter optimisation (HPO), using Optuna for the HPO and Metaflow to distribute it across tasks
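The flow below fans a single Optuna study out over several Kubernetes tasks that share one Postgres backend. Assuming the file is saved as hpo_flow.py (the name is arbitrary), it can be run with Metaflow's standard CLI, where each Parameter becomes a flag:

python hpo_flow.py run --n_trials 50 --n_tasks 5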
import logging
import os
import uuid
from functools import (
    partial,
)
from typing import (
    List,
    Optional,
)

import optuna
import pandas as pd
import torch
from metaflow import (
    FlowSpec,
    Parameter,
    kubernetes,
    step,
)
from model import (
    LitMNIST,
)
from pytorch_lightning import (
    Trainer,
)
from pytorch_lightning.callbacks import (
    TQDMProgressBar,
)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s | [%(levelname)s] | %(name)s | %(message)s',
)
logger = logging.getLogger(__name__)


def construct_study_storage_url(db_username_key: str = 'username', db_password_key: str = 'password',
                                db_host_key: str = 'host', db_port_key: str = 'port', db_name_key: str = 'db_name',
                                default_port: str = '5432') -> str:
    """
    Construct the storage URL for the Optuna study.

    The connection details are read from environment variables, which are expected to be
    populated from the Kubernetes secret attached via the @kubernetes decorator.

    :param db_username_key: The secret key for the database user name.
    :param db_password_key: The secret key for the database password.
    :param db_host_key: The secret key for the database host.
    :param db_port_key: The secret key for the database port.
    :param db_name_key: The secret key for the database name.
    :param default_port: The default port to use if the port is not a key in the secret.
    """
    return 'postgresql://{0}:{1}@{2}:{3}/{4}'.format(
        os.getenv(db_username_key),
        os.getenv(db_password_key),
        os.getenv(db_host_key),
        os.getenv(db_port_key, default_port),
        os.getenv(db_name_key),
    )
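
# With the default keys, construct_study_storage_url() yields a URL of this shape
# (hypothetical values): postgresql://optuna_user:s3cret@postgres.example.internal:5432/optuna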


def construct_study_name(study_name_prefix: str) -> str:
    """Construct the name of the Optuna study by appending a UUID to the given prefix."""
    study_uuid = str(uuid.uuid4())
    return '{}{}'.format(study_name_prefix, study_uuid)


def calculate_n_trials_by_task(n_trials: int, n_tasks: int) -> List[int]:
    """
    Calculate the number of trials to perform in each of the Metaflow tasks.

    Split a total of n_trials Optuna trials as evenly as possible across n_tasks Metaflow tasks.

    :param n_trials: The total number of Optuna trials.
    :param n_tasks: The number of Metaflow tasks.
    """
    n_trials_by_task_lower = n_trials // n_tasks
    n_trials_by_task_upper = n_trials // n_tasks + 1
    n_trials_by_task = n_tasks * [n_trials_by_task_lower]
    # The first (n_trials % n_tasks) tasks take one extra trial to absorb the remainder.
    n_trials_by_task[:n_trials % n_tasks] = n_trials % n_tasks * [n_trials_by_task_upper]
    return n_trials_by_task
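
# For example, calculate_n_trials_by_task(10, 3) gives [4, 3, 3]: 10 // 3 = 3 trials per
# task, with the remainder of 10 % 3 = 1 extra trial assigned to the first task.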


def objective(trial: optuna.Trial, data_path: str, available_gpus: int, max_epochs: int = 3,
              progress_bar_refresh_rate: Optional[int] = None) -> float:
    """
    Calculate the value of the Optuna objective function for the given Optuna trial.

    This function trains a model with the trial's suggested learning rate and returns the loss
    on the test set, which the study then minimises.

    :param trial: The instance of the Optuna trial for which the model will be trained and evaluated.
    :param data_path: The path to the data set for the model.
    :param available_gpus: Number of GPUs to train on (int) or which GPUs to train on (list or str), applied per node.
    :param max_epochs: Stop training once this number of epochs is reached. Defaults to 3.
    :param progress_bar_refresh_rate: How often to refresh the progress bar (in steps). ``None`` omits the
        progress bar callback entirely; ``0`` disables the progress bar.
    """
    learning_rate = trial.suggest_float('learning_rate', 0.0001, 0.1)
    model = LitMNIST(
        data_path,
        learning_rate=learning_rate,
    )
    # TQDMProgressBar requires an integer refresh rate, so only add the callback when one is given.
    callbacks = []
    if progress_bar_refresh_rate is not None:
        callbacks.append(TQDMProgressBar(refresh_rate=progress_bar_refresh_rate))
    trainer = Trainer(
        gpus=available_gpus,
        max_epochs=max_epochs,
        callbacks=callbacks,
    )
    trainer.fit(model)
    test_results = trainer.test(model)
    return test_results[0]['val_loss']


class HyperparameterOptimisationFlow(FlowSpec):
    """Flow for hyperparameter optimisation of the model defined in model.py on the MNIST data set."""

    study_name_prefix: str = Parameter(
        'study_name_prefix',
        help='The prefix of the name of the study.',
        default='mnist_hyperparameter_study',
        type=str,
    )
    n_trials: int = Parameter(
        'n_trials',
        help='An upper limit on the number of Optuna trials to perform.',
        type=int,
        required=True,
    )
    n_tasks: int = Parameter(
        'n_tasks',
        help='The number of Metaflow tasks over which to split the hyperparameter optimisation.',
        type=int,
        required=True,
    )

    def __init__(self, use_cli=True):
        """Initialise the flow class to perform model hyperparameter optimisation."""
        super(HyperparameterOptimisationFlow, self).__init__(use_cli=use_cli)
        self.PATH_DATASETS = None
        self.AVAIL_GPUS = None
        self.BATCH_SIZE = None
        self.study_name: Optional[str] = None
        self.n_trials_by_task: Optional[List[int]] = None
        self.trials_dataframe: Optional[pd.DataFrame] = None

    @step
    def start(self):
        """Initialise constants to be used during model training."""
        self.PATH_DATASETS = os.environ.get('PATH_DATASETS', '.')
        self.AVAIL_GPUS = min(1, torch.cuda.device_count())
        self.BATCH_SIZE = 256 if self.AVAIL_GPUS else 64
        self.next(self.download_dataset)

    @step
    def download_dataset(self):
        """Download the MNIST data set."""
        LitMNIST(self.PATH_DATASETS).prepare_data()
        self.next(self.initialise_study)

    @kubernetes(secrets=['optunasecret'])
    @step
    def initialise_study(self):
        """Initialise the Optuna study with the specified storage."""
        self.study_name = construct_study_name(self.study_name_prefix)
        logger.info('Creating study: %s', self.study_name)
        # The objective returns a loss, so the study must minimise it.
        optuna.create_study(
            direction='minimize',
            study_name=self.study_name,
            storage=construct_study_storage_url(),
        )
        self.n_trials_by_task = calculate_n_trials_by_task(self.n_trials, self.n_tasks)
        logger.info('Specifying number of trials per task: %s', self.n_trials_by_task)
        self.next(self.optimise_study, foreach='n_trials_by_task')
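
    # Each foreach branch below loads the same study from the shared Postgres storage, so
    # Optuna's storage backend coordinates the trials across tasks: workers running in
    # parallel pods draw suggestions from, and report results to, one central study.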

    @kubernetes(secrets=['optunasecret'])
    @step
    def optimise_study(self):
        """Perform the optimisation of the Optuna study."""
        partial_objective = partial(
            objective,
            data_path=self.PATH_DATASETS,
            available_gpus=self.AVAIL_GPUS,
            progress_bar_refresh_rate=20,
        )
        study = optuna.load_study(
            study_name=self.study_name,
            storage=construct_study_storage_url(),
        )
        # self.input holds this branch's element of n_trials_by_task.
        study.optimize(
            partial_objective,
            n_trials=self.input,
        )
        self.next(self.summarise_study_results)

    @kubernetes(secrets=['optunasecret'])
    @step
    def summarise_study_results(self, inputs):
        """
        Summarise the results of the study, saving the trials dataframe to the data artifacts.

        This join step merges the artifacts from the preceding foreach tasks and logs the best trial.

        :param inputs: The inputs from the preceding foreach tasks.
        """
        self.merge_artifacts(inputs)
        study = optuna.load_study(
            study_name=self.study_name,
            storage=construct_study_storage_url(),
        )
        logger.info('Number of finished trials: %s', len(study.trials))
        logger.info('Best trial: %s', study.best_trial)
        self.trials_dataframe = study.trials_dataframe()
        self.next(self.end)

    @step
    def end(self):
        """Perform some final logging."""
        logger.info('Dataset Paths: %s', self.PATH_DATASETS)
        logger.info('Available GPUs: %s', self.AVAIL_GPUS)
        logger.info('Batch Size: %s', self.BATCH_SIZE)


if __name__ == '__main__':
    HyperparameterOptimisationFlow()
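
Note on configuration: the @kubernetes(secrets=['optunasecret']) decorator assumes a Kubernetes secret named optunasecret whose keys (username, password, host, port, db_name) are exposed to the task as environment variables; construct_study_storage_url() reads exactly these keys to build the Postgres connection string.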