Created
July 14, 2020 11:09
-
-
Save sansagara/e589a8b7cc2cc035c3842d83635fcfd1 to your computer and use it in GitHub Desktop.
MLflow Log Model decorator
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools
import logging
import subprocess

import mlflow  # type: ignore
import mlflow.sklearn  # type: ignore

# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
def _git_log(format_spec):
    """Return `git log -1` output for *format_spec*, or None when unavailable.

    check_output returns bytes; decode and strip so the value is usable as
    an MLflow tag. Outside a git repo (or without git installed) fall back
    to None instead of raising.
    """
    try:
        out = subprocess.check_output(["git", "log", "-1", format_spec])
    except (subprocess.CalledProcessError, OSError):
        return None
    return out.decode("utf-8", errors="replace").strip() or None


def log_skl_model(func):
    """
    Log a scikit-learn model's statistics to MLflow.

    The decorated function must accept a `model_reporting` kwarg (from
    parameters.yml) that is a dict with `tracking_uri` and
    `default_experiment` keys, and must return a tuple of the fitted
    `sklearn.pipeline.Pipeline` and a dict of statistics to log.

    Example: ::

        from util.mlflow import log_skl_model

        @log_skl_model
        def my_model_run(arg1, arg2, arg3, model_reporting=model_reporting):
            ...
            return skl_pipeline, {statistic1: value1, statistic2: value2, ...}

    Raises:
        ValueError: if `model_reporting` is missing, or it does not supply
            both a tracking URI and an experiment name.
    """

    @functools.wraps(func)
    def wrapper_mlflow(*args, **kwargs):
        # Get URI and Experiment for MLflow
        model_reporting = kwargs.get("model_reporting")
        if not model_reporting:
            raise ValueError(
                "You must set `model_reporting` kwarg to use the log_model util. "
                "Pass from parameters.yml"
            )
        uri = model_reporting.get("tracking_uri")
        # An explicit `experiment` kwarg overrides the configured default.
        experiment = kwargs.get("experiment") or model_reporting.get(
            "default_experiment"
        )
        # BUG FIX: the original tested `not uri and experiment`, which only
        # raised when the URI was missing but an experiment WAS provided.
        # Both values are required to report to MLflow.
        if not (uri and experiment):
            raise ValueError(
                "The `model_reporting` must be a dict containing uri and default_experiment keys"
            )
        logger.info(
            "Logging model performance using MLflow. URI: {} Experiment: {}".format(
                uri, experiment
            )
        )
        # Debug: record the exact call signature for traceability.
        args_repr = [repr(a) for a in args]
        kwargs_repr = [f"{k}={v!r}" for k, v in kwargs.items()]
        signature = ", ".join(args_repr + kwargs_repr)
        logger.debug(f"Calling {func.__name__}({signature})")
        # Run decorated function and get model/statistics tuple
        skl_pipeline, statistics = func(*args, **kwargs)
        # Start the MLflow context
        mlflow.set_tracking_uri(uri)
        # BUG FIX: create_experiment raises if the experiment already
        # exists (i.e. on every run after the first); reuse it when present.
        existing = mlflow.get_experiment_by_name(experiment)
        experiment_id = (
            existing.experiment_id
            if existing is not None
            else mlflow.create_experiment(experiment)
        )
        commit_hash = _git_log("--pretty=format:%h")
        commit_msg = _git_log("--pretty=format:%B")
        # BUG FIX: tags must be set inside the active run. The original
        # called set_tags() before start_run(), which made the fluent API
        # implicitly open a *separate* run and attach the tags there.
        with mlflow.start_run(experiment_id=experiment_id):
            mlflow.set_tags(
                {
                    "mlflow.runName": commit_hash,
                    "git.hash": commit_hash,
                    "git.msg": commit_msg,
                    # FIX: typo "funtion" in the emitted tag keys.
                    "calling_function.name": func.__name__,
                    "calling_function.args": args_repr,
                    "calling_function.kwargs": kwargs_repr,
                }
            )
            # Log parameters: only values coercible to float are recorded;
            # everything else (estimators, strings, ...) is skipped.
            for parameter_name, value in skl_pipeline.get_params().items():
                try:
                    mlflow.log_param(parameter_name, float(value or 0))
                except (ValueError, TypeError):
                    continue
            # Log statistics
            for key, stat in statistics.items():
                mlflow.log_metric(key, stat)
            # Log model: the final pipeline step is assumed to be the
            # estimator — TODO confirm only the estimator (not the full
            # pipeline) is intended to be persisted.
            mlflow.sklearn.log_model(skl_pipeline.steps[-1][1], "model")
        return skl_pipeline

    return wrapper_mlflow
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Test the model_tracking utility built around mlflow.""" | |
# pylint: disable=too-few-public-methods | |
# pylint: skip-file | |
# flake8: noqa | |
import pytest | |
from sklearn.datasets import fetch_20newsgroups | |
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer | |
from sklearn.metrics import f1_score | |
from sklearn.model_selection import cross_val_score | |
from sklearn.pipeline import Pipeline | |
from sklearn.svm import LinearSVC | |
from project_utsugi.nodes.common.utils.model_tracking import log_skl_model | |
def test_log_skl_model(dummy_string, mock_model_reporting):
    """Smoke-test: the log_skl_model decorator must not raise on a real pipeline.

    Trains a small text-classification pipeline on two 20newsgroups
    categories and reports metrics through the decorator using the mocked
    `model_reporting` fixture.
    """

    @log_skl_model
    def test_skl_pipeline(dummy_arg, model_reporting):
        _, _ = dummy_arg, model_reporting
        cats = ["alt.atheism", "sci.space"]
        newsgroups_train = fetch_20newsgroups(subset="train", categories=cats)
        newsgroups_test = fetch_20newsgroups(subset="test", categories=cats)
        x_train, x_test = newsgroups_train.data, newsgroups_test.data
        y_train, y_test = newsgroups_train.target, newsgroups_test.target
        pipeline = Pipeline(
            [
                ("vect", CountVectorizer()),
                ("tfidf", TfidfTransformer()),
                ("clf", LinearSVC()),
            ]
        )
        # now train and predict test instances
        pipeline.fit(x_train, y_train)
        y_pred = pipeline.predict(x_test)
        # get scores
        cross = cross_val_score(pipeline, x_train, y_train, cv=3, scoring="f1_micro")
        f1 = f1_score(y_test, y_pred, average="micro")
        return pipeline, {"mean_cross_score": cross.mean(), "f1_score": f1}

    try:
        test_skl_pipeline(dummy_string, model_reporting=mock_model_reporting)
    except Exception as exc:  # pylint: disable=broad-except
        # BUG FIX: pytest.fail() raises Failed itself and never returns,
        # so the original `raise pytest.fail(...)` wrapped a dead `raise`
        # around a non-exception value. Call it directly instead.
        pytest.fail("log_skl_model decorator raised {0!r}".format(exc))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment