Created
August 5, 2019 18:03
-
-
Save fallenflint/31543a6ab9b283bf1ae52e76ae88d973 to your computer and use it in GitHub Desktop.
Example application that walks through the whole training process via the Python CLI
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import logging | |
| import os | |
| import sys | |
| import time | |
| from auger.api.dataset import DataSet | |
| from auger.api.model import Model | |
| from auger.api.project import Project | |
| from auger.api.experiment import Experiment, AugerExperimentSessionApi | |
| from auger.api.utils.context import Context | |
| from auger.cli.utils.config import AugerConfig | |
# CSV path handed to Model.predict() by ExampleApp.predict().
PREDICTION_SOURCE = "files/iris_data_test.csv"
# Path checked by ExampleApp.predict(): if it exists, prediction is skipped.
PREDICTION_TARGET = "files/irist_data_set_predict.csv"
# Cloud dataset name that prepare_dataset() selects when none is configured
# and a dataset with this name already exists in the project.
DATASET_NAME = "iris.csv"

# Logger for the "auger.ai" namespace.
log = logging.getLogger("auger.ai")
class ExampleApp:
    """Drive a complete Auger.ai workflow from Python.

    The intended sequence is: prepare_dataset(), run_experiment(),
    deploy(), predict() and finally cleanup(), which deletes the project
    created by the constructor.
    """

    def __init__(self, ctx):
        """Bind to *ctx* and create the project named in the 'auger' config.

        :param ctx: auger ``Context``; its ``log`` attribute is replaced so
            that messages are printed to stdout.
        """
        self.ctx = ctx
        # print accepts exactly the arguments the context logger is called
        # with here, so no wrapper function is needed.
        self.ctx.log = print
        self.dataset = None
        self.model_id = None
        self.project = Project(
            self.ctx, self.ctx.get_config('auger').get('project', None))
        self.project.create()

    def _get_datasets(self):
        """Return the names of the datasets existing in the Auger Cloud project."""
        return [dataset['name']
                for dataset in DataSet(self.ctx, self.project).list()]

    def _start_experiment(self, experiment_name):
        """Start an experiment on ``self.dataset`` and record it in the config.

        :param experiment_name: name read from the 'auger' config; may be None.
        :returns: ``(experiment, session_id)``.
        """
        experiment = Experiment(self.ctx, self.dataset, experiment_name)
        # start() returns a (name, session id) pair; the returned name is
        # the one persisted to the config.
        experiment_name, session_id = experiment.start()
        AugerConfig(self.ctx).set_experiment(experiment_name, session_id)
        return experiment, session_id

    def prepare_dataset(self):
        """Ensure a dataset is selected, selecting or creating one if needed.

        Uses the 'dataset' entry of the 'config' config when present.
        Otherwise it selects an existing cloud dataset named DATASET_NAME,
        or creates a new dataset from the configured 'source'. In the last
        two cases the choice is saved via AugerConfig and the selected
        experiment is reset. ``self.dataset`` is set in every case.
        """
        self.ctx.log("Checking dataset...")
        selected = self.ctx.get_config('config').get('dataset', None)
        if selected is None:
            if DATASET_NAME in self._get_datasets():
                # try to select existing
                selected = DATASET_NAME
                AugerConfig(self.ctx).set_data_set(
                    selected, '').set_experiment(None)
            else:
                # or create new
                self.ctx.log("No dataset found, creating the first one...")
                source = self.ctx.get_config('config').get('source', None)
                dataset = DataSet(self.ctx, self.project).create(source)
                selected = dataset.name
                AugerConfig(self.ctx).set_data_set(
                    selected, source).set_experiment(None)
        self.dataset = DataSet(self.ctx, self.project, selected)
        # Log the name actually in use, including one picked or created
        # above (previously this printed "None" in those cases).
        self.ctx.log("Currently selected: %s" % selected)

    def run_experiment(self):
        """Run an experiment, wait for it, and keep the first leaderboard model.

        :raises RuntimeError: if the finished run's leaderboard is empty.
        """
        experiment_name = self.ctx.get_config('auger').get(
            'experiment/name', None)
        self.experiment, run_id = self._start_experiment(experiment_name)
        self.ctx.log("waiting for experiment %s to finish" % experiment_name)
        self.experiment.wait(run_id)
        leaderboard, _status = self.experiment.leaderboard(run_id)
        if not leaderboard:
            raise RuntimeError(
                "Experiment run %s finished with an empty leaderboard" % run_id)
        # NOTE(review): assumes the leaderboard is ordered best-first.
        self.model_id = leaderboard[0]['model id']

    def deploy(self):
        """Deploy the model chosen by run_experiment() locally."""
        Model(self.ctx, self.project).deploy(self.model_id, locally=True)

    def predict(self, source=None, target=None):
        """Run a local prediction unless *target* already exists.

        :param source: CSV path to predict on; defaults to PREDICTION_SOURCE.
        :param target: path checked before predicting; if it exists, a
            message is logged and no prediction is made. Defaults to
            PREDICTION_TARGET.
        """
        # Defaults are resolved at call time so the module constants remain
        # the single source of truth.
        if target is None:
            target = PREDICTION_TARGET
        if os.path.exists(target):
            # The %s placeholder is required: without it, the % operator
            # raises TypeError instead of producing the message.
            self.ctx.log(
                "Prediction already exists."
                " If you want to re-run predict, just delete prediction file: %s" %
                target)
            return
        if source is None:
            source = PREDICTION_SOURCE
        Model(self.ctx, self.project).predict(
            source, self.model_id, locally=True)

    def cleanup(self):
        """Delete the project created in __init__."""
        self.project.delete()
def main():
    """Run the example workflow end to end.

    Any exception is caught: its traceback is printed and a failure
    message is logged through the context.

    :returns: exit status, 0 on success and 1 on failure.
    """
    context = Context()
    try:
        app = ExampleApp(context)
        app.prepare_dataset()
        time.sleep(15)  # wait for redis to be ready
        app.run_experiment()
        app.deploy()
        app.predict()
        app.cleanup()
    except Exception as e:
        import traceback
        traceback.print_exc()
        context.log(
            "Example application execution has failed with error: '%s'" %
            str(e))
        return 1
    return 0


if __name__ == '__main__':
    # Propagate main()'s status so shell callers can detect a failed run.
    sys.exit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment