
@sllynn
Created September 20, 2019 07:36
Custom mlflow pyfunc wrapper for Keras models
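```python
import mlflow.keras
import mlflow.pyfunc
import numpy as np
import pandas as pd


class KerasWrapper(mlflow.pyfunc.PythonModel):
    def __init__(self, keras_model_name):
        # Key of the Keras model's entry in the `artifacts` dict passed
        # to mlflow.pyfunc.log_model / save_model.
        self.keras_model_name = keras_model_name

    def load_context(self, context):
        # Load the Keras model from its artifact path; compile=False skips
        # compiling the model, since only inference is needed here.
        self.keras_model = mlflow.keras.load_model(
            model_uri=context.artifacts[self.keras_model_name], compile=False
        )

    def predict(self, context, input_data):
        # Column "0" holds one nested sequence per row: stack each cell into
        # an array, then stack the rows into a single batch.
        input_data_padded = np.stack(input_data["0"].apply(np.stack, axis=0), axis=0)
        scores = self.keras_model.predict(input_data_padded)

        # Highest-scoring class per row.
        predicted_class = np.argmax(scores, axis=1)

        # Marginal confidence: gap between the top two class scores per row.
        sorted_scores = np.sort(scores, axis=1)
        marginal_confidence = sorted_scores[:, -1] - sorted_scores[:, -2]

        return pd.DataFrame(
            data=dict(
                predicted_class=predicted_class.astype("float32"),
                marginal_confidence=marginal_confidence,
            )
        )
```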
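The predict method expects column `"0"` to hold one nested sequence per row. The sketch below reproduces only the batching step with NumPy and pandas so the shapes can be checked without loading a Keras model; the two-row payload, its 2x3 cell shape, and the values are hypothetical, chosen purely for illustration.

```python
import numpy as np
import pandas as pd

# Hypothetical payload: column "0" holds one 2x3 nested list per row,
# e.g. a padded sequence of 2 steps with 3 features each.
input_data = pd.DataFrame(
    {"0": [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [0, 0, 0]]]}
)

# Same two-step stacking as KerasWrapper.predict: each cell becomes a
# (2, 3) array, then the rows are stacked into one (2, 2, 3) batch.
batch = np.stack(input_data["0"].apply(np.stack, axis=0), axis=0)
print(batch.shape)
```

The leading batch dimension is the number of DataFrame rows, so every cell must have the same shape (hence "padded") for the outer np.stack to succeed.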
@strangan

Hi, did you write this wrapper to get around the "cannot pickle 'weakref' object" error when trying to log a Keras model with MLflow? I used your KerasWrapper class, but I'm still running into this error when I call mlflow.pyfunc.log_model(path, python_model=wrappedModel).

Here's a link to a question I posted on Stack Overflow describing my problem:

https://stackoverflow.com/questions/70385171/mlflow-on-databricks-cannot-log-a-keras-model-as-a-mlflow-pyfunc-model-get-ty

@sllynn
Author

sllynn commented Dec 17, 2021

I wrote this to postprocess the model's class probabilities into a predicted class and confidence.
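A minimal sketch of that post-processing, pulled out of predict into a standalone function: `postprocess` is a hypothetical name, and the small score matrix stands in for the output of `keras_model.predict`. The predicted class is the argmax of each row, and the confidence is the difference between the highest and second-highest scores.

```python
import numpy as np
import pandas as pd


def postprocess(scores):
    """Map an (n_rows, n_classes) score matrix to a predicted class and a
    marginal confidence (top score minus runner-up score) per row."""
    scores = np.asarray(scores)
    predicted_class = np.argmax(scores, axis=1)
    sorted_scores = np.sort(scores, axis=1)  # ascending within each row
    marginal_confidence = sorted_scores[:, -1] - sorted_scores[:, -2]
    return pd.DataFrame(
        {
            "predicted_class": predicted_class.astype("float32"),
            "marginal_confidence": marginal_confidence,
        }
    )


# Hypothetical class probabilities for two rows and three classes.
scores = np.array([[0.125, 0.625, 0.25], [0.5, 0.375, 0.125]])
result = postprocess(scores)
print(result)
```

The margin is 0 when the two best classes tie and can reach 1 for a softmax output where one class takes all the probability; it needs at least two classes to be defined.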

The error you're getting is because your model isn't serialisable. That has to do with the way Keras is implemented (I don't fully understand the details), but I'm afraid it will prevent you from using Keras and mlflow.pyfunc together until the underlying issue is resolved.
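MLflow serializes the object passed as `python_model` with CloudPickle, so everything reachable from it must be picklable. The sketch below reproduces the failure mode from the question with the standard library's `pickle` alone. `ModelHoldingWeakref` and `WrapperHoldingName` are hypothetical stand-ins, not Keras or MLflow classes, and the assumption is that the Keras model (or something it references) holds a `weakref`.

```python
import pickle
import weakref


class ModelHoldingWeakref:
    """Hypothetical stand-in for a model object that keeps a weakref internally."""

    def __init__(self):
        self.layers = []
        self._owner_ref = weakref.ref(self)  # weakrefs cannot be pickled


class WrapperHoldingName:
    """Hypothetical stand-in for a wrapper that stores only a string name."""

    def __init__(self, keras_model_name):
        self.keras_model_name = keras_model_name


def can_pickle(obj):
    """Return True if pickle.dumps succeeds, False on the weakref TypeError."""
    try:
        pickle.dumps(obj)
        return True
    except TypeError as exc:
        print(f"{type(obj).__name__}: {exc}")
        return False


print(can_pickle(ModelHoldingWeakref()))
print(can_pickle(WrapperHoldingName("keras_model")))
```

In this gist, KerasWrapper.__init__ stores only the string keras_model_name and the Keras model is attached later in load_context; whether a given logged object stays picklable depends on what it references at the time log_model is called.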
