# Created: February 4, 2023 00:52
# Source: GitHub gist tuulos/8e22d5e64bc93fb143d774a146f9170b
# Title: Train a model with a config file using Metaflow
from metaflow import FlowSpec, step, IncludeFile
def dataset_wine():
    """Load the scikit-learn wine dataset.

    Returns:
        The ``(data, target)`` pair produced by
        ``sklearn.datasets.load_wine(return_X_y=True)``.
    """
    # Imported inside the function so scikit-learn is only needed when the
    # loader is actually called.
    from sklearn import datasets

    return datasets.load_wine(return_X_y=True)
def model_knn(train_data, train_labels):
    """Fit a k-nearest-neighbours classifier with default settings.

    Args:
        train_data: Training features, passed unchanged to ``fit``.
        train_labels: Training targets, passed unchanged to ``fit``.

    Returns:
        The fitted ``KNeighborsClassifier``.
    """
    # Imported inside the function so scikit-learn is only needed when this
    # model is actually selected.
    from sklearn.neighbors import KNeighborsClassifier

    model = KNeighborsClassifier()
    model.fit(train_data, train_labels)
    return model
def model_svm(train_data, train_labels):
    """Fit a support-vector classifier that uses a polynomial kernel.

    Args:
        train_data: Training features, passed unchanged to ``fit``.
        train_labels: Training targets, passed unchanged to ``fit``.

    Returns:
        The fitted ``svm.SVC`` instance.
    """
    # Imported inside the function so scikit-learn is only needed when this
    # model is actually selected.
    from sklearn import svm

    model = svm.SVC(kernel='poly')
    model.fit(train_data, train_labels)
    return model
# Model builders, keyed by the name given in the config's 'model' entry.
# Each one is called with (train_data, train_labels) and returns a fitted
# model.
MODELS = {
    'knn': model_knn,
    'svm': model_svm,
}

# Dataset loaders, keyed by the name given in the config's 'dataset' entry.
# Each one is called without arguments and returns an (X, y) pair.
DATASETS = {'wine': dataset_wine}
class TrainWithConfigFlow(FlowSpec):
    """Train and evaluate a model selected by a JSON config file.

    The config must be a JSON object with two entries:

    * ``dataset`` - a key of ``DATASETS``
    * ``model`` - a key of ``MODELS``

    The steps run in sequence: ``start`` -> ``load_data`` -> ``train`` ->
    ``end``.
    """

    # Contents of the config file, exposed as the flow parameter 'config'
    # (defaults to 'config.json').
    config_file = IncludeFile(
        'config',
        default='config.json',
        help='JSON file naming the dataset and the model to use',
    )

    @step
    def start(self):
        """Parse the config file and check its dataset and model choices.

        Raises:
            KeyError: If ``dataset`` or ``model`` is missing from the config
                or names an entry that is not registered.
        """
        import json

        self.config = json.loads(self.config_file)

        # Check both choices up front so a bad config stops the run before
        # any data is loaded or any model is trained. KeyError is kept so
        # the failure type matches the plain dict lookups done later.
        for key, registry in (('dataset', DATASETS), ('model', MODELS)):
            choice = self.config[key]
            if choice not in registry:
                raise KeyError(
                    'config entry {!r} is {!r}; expected one of {}'.format(
                        key, choice, sorted(registry)))

        self.next(self.load_data)

    @step
    def load_data(self):
        """Load the configured dataset and split it into train/test parts."""
        from sklearn.model_selection import train_test_split

        print('Loading dataset', self.config['dataset'])
        X, y = DATASETS[self.config['dataset']]()

        # Hold out 20% of the samples for evaluation; the fixed random_state
        # keeps the split identical from run to run.
        (self.train_data,
         self.test_data,
         self.train_labels,
         self.test_labels) = train_test_split(
             X, y, test_size=0.2, random_state=0)

        self.next(self.train)

    @step
    def train(self):
        """Fit the configured model on the training split."""
        print("Training model", self.config['model'])
        build_model = MODELS[self.config['model']]
        self.model = build_model(self.train_data, self.train_labels)
        self.next(self.end)

    @step
    def end(self):
        """Score the trained model on the held-out test split."""
        self.score = self.model.score(self.test_data, self.test_labels)
        print('Eval score', self.score)
# Only start the flow when the file is executed as a script, not on import.
if __name__ == '__main__':
    TrainWithConfigFlow()
# (GitHub page footer: sign up or sign in on GitHub to comment on this gist.)