Iris Classification
from airflow import DAG
from airflow.operators.python_operator import PythonOperator
from datetime import datetime, timedelta
import mysql.connector
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
import joblib
import os
import numpy as np
import random

# Set the Airflow base directory
DAGS_DIR = os.environ.get('AIRFLOW__CORE__DAGS_FOLDER', '/opt/airflow/dags')
def train_and_export_model():
    # Load the dataset
    iris = load_iris()
    X = iris.data
    y = iris.target

    # Split the data into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # Train the model
    model = DecisionTreeClassifier()
    print(f"model: {model}")
    model.fit(X_train, y_train)

    # Generate a unique folder path with a timestamp
    unique_folder = datetime.now().strftime('%Y%m%d%H%M%S')
    folder_path = os.path.join(DAGS_DIR, unique_folder)

    # Create the folder if it doesn't exist
    os.makedirs(folder_path, exist_ok=True)

    # Build a unique path for the .pkl file
    model_filename = 'iris_model.pkl'
    model_path = os.path.join(folder_path, model_filename)

    # Export the model
    joblib.dump(model, model_path)

    # Store the model path in MySQL
    connection = mysql.connector.connect(
        host='mysql.default.svc.cluster.local',
        user='root',
        password='password',
        database='my_database'
    )
    cursor = connection.cursor()
    # Create the table that stores the model paths first, so the ALTER below
    # never runs against a table that doesn't exist yet
    create_table_query = '''
        CREATE TABLE IF NOT EXISTS models (
            id INT AUTO_INCREMENT PRIMARY KEY,
            model_path VARCHAR(255)
        )
    '''
    cursor.execute(create_table_query)

    try:
        # Add the model_path column in case an older 'models' table exists without it
        alter_table_query = '''
            ALTER TABLE models
            ADD COLUMN model_path VARCHAR(255)
        '''
        cursor.execute(alter_table_query)
    except mysql.connector.Error as err:
        # Error 1060: duplicate column name, i.e. the column already exists
        if err.errno == 1060:
            print("Column 'model_path' already exists in table 'models'")
        else:
            print("Error:", err)
    # Insert the model path into the table
    insert_query = '''
        INSERT INTO models (model_path) VALUES (%s)
    '''
    cursor.execute(insert_query, (model_path,))
    connection.commit()

    # Close the connection
    cursor.close()
    connection.close()
def load_model_from_database():
    connection = mysql.connector.connect(
        host='mysql.default.svc.cluster.local',
        user='root',
        password='password',
        database='my_database'
    )
    cursor = connection.cursor()

    select_query = '''
        SELECT model_path FROM models ORDER BY id DESC LIMIT 1
    '''
    cursor.execute(select_query)
    result = cursor.fetchone()

    if result is not None:
        model_path = result[0]
        # Load the model from the stored path
        loaded_model = joblib.load(os.path.join(DAGS_DIR, model_path))
        # Use the loaded model for inference or further processing
        species_names = ['setosa', 'versicolor', 'virginica']

        # Generate a random input on each iteration; the features are random and
        # not conditioned on the species name, which only labels the iteration
        for species in species_names:
            sepal_length = random.uniform(4.0, 8.0)
            sepal_width = random.uniform(2.0, 4.5)
            petal_length = random.uniform(1.0, 7.0)
            petal_width = random.uniform(0.1, 2.5)
            X_new = np.array([[sepal_length, sepal_width, petal_length, petal_width]])  # Input for prediction
            prediction = loaded_model.predict(X_new)
            print("Species:", species)
            print("Prediction:", prediction, "->", species_names[prediction[0]])

            # Generate a link to view images of the predicted species
            species_name = species_names[prediction[0]]
            image_link = f"https://en.wikipedia.org/wiki/Iris_{species_name}"
            print("Image Link:", image_link)
            print()

        # Remove the model file
        os.remove(os.path.join(DAGS_DIR, model_path))

    cursor.close()
    connection.close()
default_args = {
    'start_date': datetime(2023, 5, 1),
    'retries': 3,
    'retry_delay': timedelta(minutes=5)
}

dag = DAG(
    'iris_classification_try_me_again',
    default_args=default_args,
    description="An ETL with lineage",
    tags=['workshop', 'ETL'],
    schedule_interval=None
)

with dag:
    train_export_operator = PythonOperator(
        task_id='train_and_export',
        python_callable=train_and_export_model,
        outlets={
            "tables": ["mysql.my_database.my_database.models"]
        },
        inlets={
            "tables": ["mysql.my_database.my_database.models"]
        }
    )

    load_model_operator = PythonOperator(
        task_id='load_model',
        python_callable=load_model_from_database,
    )

    train_export_operator >> load_model_operator  # Define the task dependency
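
# --- Optional local smoke test (not part of the DAG definition) ---
# A minimal sketch for exercising the two callables outside of Airflow, assuming
# the MySQL host and credentials above are reachable from where this script runs.
# The __main__ guard keeps this from executing when the Airflow scheduler parses
# the file; it only runs if you invoke the module directly with `python`.
if __name__ == "__main__":
    train_and_export_model()
    load_model_from_database()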