Last active
February 19, 2023 09:59
-
-
Save thuwarakeshm/c47d05526c545d329700e649f9ff2384 to your computer and use it in GitHub Desktop.
deploy ml
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Train a tiny Keras regression model and export it as a TensorFlow Lite file."""
import numpy as np
import tensorflow as tf

# Toy training data following y = 2x - 1. Keras expects 2-D
# (samples, features) arrays; a bare Python list can instead be read as a
# list of separate model inputs.
x_train = np.array([[-1.0], [0.0], [1.0]], dtype=np.float32)
y_train = np.array([[-3.0], [-1.0], [1.0]], dtype=np.float32)

# Create and train a Keras neural network (a regressor: the loss is MSE).
classifier = tf.keras.models.Sequential([
    tf.keras.Input(shape=(1,)),
    tf.keras.layers.Dense(units=1),
    tf.keras.layers.Dense(units=28, activation='relu'),
    tf.keras.layers.Dense(units=1),
])
classifier.compile(optimizer='sgd', loss='mean_squared_error')
classifier.fit(x=x_train, y=y_train, epochs=5)

# Convert the model to a TensorFlow Lite flatbuffer (returned as bytes).
converter = tf.lite.TFLiteConverter.from_keras_model(classifier)
tfl_classifier = converter.convert()

# Save the model as a .tflite file; binary mode because the payload is bytes.
with open('classifier.tflite', 'wb') as f:
    f.write(tfl_classifier)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Train a random-forest classifier that predicts the 'Survived' column."""
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

df = pd.read_csv('titanic.csv')

# Features are every column except the target; the target is 'Survived'.
x = df[df.columns.difference(['Survived'])]
y = df['Survived']

# NOTE(review): RandomForestClassifier requires numeric features. If
# titanic.csv has text columns, encode them here exactly as the prediction
# code does (e.g. pd.get_dummies) — TODO confirm against the dataset.
classifier = RandomForestClassifier()
classifier.fit(x, y)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Persist the trained model to disk. `sklearn.externals.joblib` was
# deprecated in scikit-learn 0.21 and removed in 0.23, so import the
# standalone joblib package instead.
import joblib

# NOTE(review): `classifier` is the fitted model from the training script;
# this assumes both run in the same session.
joblib.dump(classifier, 'classifier.pkl')
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json
from datetime import timedelta, datetime

import joblib
import pandas as pd
import requests  # NOTE(review): unused in this script
from prefect import task, Flow
from prefect.schedules import IntervalSchedule


@task(max_retries=3, retry_delay=timedelta(minutes=5))
def predict(input_data_path: str) -> dict:
    """Load the saved model and the input data, and return the predictions.

    Returns a dict ``{'prediction': [...]}`` holding built-in Python values,
    so it can be serialized to JSON. If the task fails, it is retried up to
    3 times at 5-minute intervals before failing permanently.
    """
    classifier = joblib.load('classifier.pkl')
    df = pd.read_csv(input_data_path)
    prediction = classifier.predict(df)
    # .tolist() converts NumPy scalars into types the json module can encode.
    return {'prediction': prediction.tolist()}


@task(max_retries=3, retry_delay=timedelta(minutes=5))
def save_prediction(data: dict, output_data_path: str) -> None:
    """Write the prediction dict to *output_data_path* as JSON.

    If the task fails, it is retried up to 3 times at 5-minute intervals
    before failing permanently.
    """
    with open(output_data_path, 'w', encoding='utf-8') as f:
        json.dump(data, f)


# Create a schedule object that starts 5 seconds after the script is
# executed and repeats once a week.
schedule = IntervalSchedule(
    start_date=datetime.utcnow() + timedelta(seconds=5),
    interval=timedelta(weeks=1),
)

# Attach the schedule object and orchestrate the workflow.
with Flow("predictions", schedule=schedule) as flow:
    prediction = predict("./input_data.csv")
    save_prediction(prediction, "./output_data.csv")

if __name__ == '__main__':
    # The `with Flow(...)` block only declares the workflow; run() executes
    # it according to the attached schedule.
    flow.run()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import lru_cache

import joblib
import pandas as pd
from flask import Flask, jsonify, request

app = Flask(__name__)


@lru_cache(maxsize=1)
def _load_classifier():
    """Unpickle the model once and reuse it across requests.

    Restart the server to pick up a new classifier.pkl.
    """
    return joblib.load('classifier.pkl')


@app.route('/predict', methods=['POST'])
def predict():
    """Return predictions for the JSON payload POSTed to /predict.

    The payload is turned into a DataFrame and one-hot encoded with
    ``pd.get_dummies`` before prediction. Responds with
    ``{"prediction": [...]}``, or with HTTP 400 if the body is not valid JSON.
    """
    json_ = request.get_json(silent=True)
    if json_ is None:
        return jsonify({'error': 'request body must be JSON'}), 400
    # NOTE(review): assumes the payload is a list of records (or a dict of
    # columns); a dict of scalars makes pd.DataFrame raise — TODO confirm.
    query_df = pd.DataFrame(json_)
    query = pd.get_dummies(query_df)
    classifier = _load_classifier()
    # get_dummies only creates columns for categories present in this request,
    # so align to the training columns when the model recorded them.
    feature_names = getattr(classifier, 'feature_names_in_', None)
    if feature_names is not None:
        query = query.reindex(columns=feature_names, fill_value=0)
    prediction = classifier.predict(query)
    # .tolist() yields built-in types; jsonify cannot encode NumPy int64.
    return jsonify({'prediction': prediction.tolist()})


if __name__ == '__main__':
    app.run(port=8080)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi, thanks for posting this example. I think there is an error: some brackets are not closed.
https://gist.github.com/thuwarakeshm/c47d05526c545d329700e649f9ff2384/1d18360f32e4e28a14a82c811ad7e33edb0c7cd0#file-load-py