Custom mlflow pyfunc wrapper for Keras models
Gist sllynn/36441b1cbb3258f2e619ed0f896b5a97, created September 20, 2019 07:36
```python
import mlflow.pyfunc
import mlflow.keras


class KerasWrapper(mlflow.pyfunc.PythonModel):
    def __init__(self, keras_model_name):
        # Store only the artifact name. The Keras model is loaded in load_context,
        # so it is never part of the wrapper state that gets serialised at save time.
        self.keras_model_name = keras_model_name

    def load_context(self, context):
        # context.artifacts maps each artifact name to its local path.
        self.keras_model = mlflow.keras.load_model(
            model_uri=context.artifacts[self.keras_model_name], compile=False
        )

    def predict(self, context, input_data):
        import pandas as pd
        import numpy as np

        # Column "0" holds one nested list per row; stack each row into an array,
        # then stack the rows into a single batch for Keras.
        input_data_padded = np.stack(input_data["0"].apply(np.stack, axis=0), axis=0)
        scores = self.keras_model.predict(input_data_padded)

        # Index of the highest-scoring class in each row.
        predicted_class = np.argmax(scores, axis=1)

        # Gap between the two largest class scores in each row.
        sorted_scores = np.sort(scores, axis=1)
        marginal_confidence = sorted_scores[:, -1] - sorted_scores[:, -2]

        predicted = pd.DataFrame(
            data=dict(
                predicted_class=predicted_class.astype(dtype="float32"),
                marginal_confidence=marginal_confidence,
            )
        )
        return predicted
```
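The reshaping at the top of `predict` is easiest to see on a toy input. This is a minimal sketch, assuming each row of column "0" is a padded sequence stored as a list of equal-length numeric vectors. The 2 × 3 × 4 shape is hypothetical, not taken from the original model. It runs the same `np.stack` expression as `KerasWrapper.predict`, without Keras or mlflow:

```python
import numpy as np
import pandas as pd

# Hypothetical input: 2 rows, each a padded sequence of 3 timesteps x 4 features.
input_data = pd.DataFrame({"0": [
    [[0.0] * 4, [1.0] * 4, [2.0] * 4],
    [[3.0] * 4, [4.0] * 4, [5.0] * 4],
]})

# Same expression as in KerasWrapper.predict: stack each row, then stack the rows.
batch = np.stack(input_data["0"].apply(np.stack, axis=0), axis=0)
print(batch.shape)  # (2, 3, 4)
```

Every row must have the same padded shape, otherwise the outer `np.stack` raises a `ValueError`.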
I wrote this to post-process the model's class probabilities into a predicted class and a confidence score (the margin between the top two class probabilities).
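To make the post-processing concrete, here is a sketch that applies the same two steps to a hypothetical score matrix (made-up softmax outputs, no Keras or mlflow involved). A small margin flags rows where the model was nearly split between two classes:

```python
import numpy as np

# Hypothetical softmax outputs for 3 rows and 3 classes.
scores = np.array([
    [0.1, 0.7, 0.2],
    [0.5, 0.4, 0.1],
    [0.3, 0.3, 0.4],
])

# Predicted class: index of the largest score in each row.
predicted_class = np.argmax(scores, axis=1)

# Marginal confidence: largest score minus second-largest score in each row.
sorted_scores = np.sort(scores, axis=1)
marginal_confidence = sorted_scores[:, -1] - sorted_scores[:, -2]

print(predicted_class)      # [1 0 2]
print(marginal_confidence)  # approximately [0.5, 0.1, 0.1]
```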
The error you're getting is because your model isn't serialisable. That's something to do with the way Keras is implemented (I don't fully understand the details), but I'm afraid it will prevent you from using Keras and mlflow.pyfunc together until this issue is resolved.
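The serialisation failure can be reproduced in miniature with the standard library. This is a sketch under stated assumptions. MLflow documents that the `python_model` instance is serialised with CloudPickle when the model is saved; plain `pickle` is used here instead. A `threading.Lock` stands in for an unpicklable Keras model, and the classes `HoldsModel` and `HoldsName` and the helper `can_pickle` are hypothetical. An object that keeps the model as state fails to pickle. An object that keeps only a name, as `KerasWrapper` does with `keras_model_name`, round-trips and builds the heavy object afterwards, like `load_context`:

```python
import pickle
import threading


class HoldsModel:
    """Hypothetical: keeps the (unpicklable) model object itself as state."""

    def __init__(self, model):
        self.model = model


class HoldsName:
    """Hypothetical: keeps only a name; the object is built after unpickling."""

    def __init__(self, name):
        self.name = name

    def load(self):
        # Stand-in for load_context: create the unpicklable object after loading.
        self.model = threading.Lock()


def can_pickle(obj):
    """Hypothetical helper: True if pickle.dumps succeeds on obj."""
    try:
        pickle.dumps(obj)
        return True
    except (TypeError, pickle.PicklingError):
        return False


print(can_pickle(HoldsModel(threading.Lock())))  # False
restored = pickle.loads(pickle.dumps(HoldsName("keras_model")))
restored.load()
print(restored.name)  # keras_model
```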