Last active
March 16, 2021 14:59
-
-
Save oscar-defelice/4e8516451d3b27a807720b17f30973c1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from tensorflow.keras.models import load_model | |
from .services import | |
class PriceEstimator:
    """
    PriceEstimator object to collect prediction methods to be accessed
    by API services.

    Wraps a model object exposing a ``predict`` method; ``from_pretrained``
    builds one from a saved Keras model via ``load_model``.
    """

    def __init__(self, model):
        # Any object with a ``predict(X)`` method; ``from_pretrained`` passes a Keras model.
        self.model = model

    @classmethod
    def from_pretrained(cls, pretrained_model_path):
        """
        Build an estimator from a model saved at ``pretrained_model_path``.

        Arguments:
            pretrained_model_path   path accepted by ``tensorflow.keras.models.load_model``
        Returns:
            PriceEstimator wrapping the loaded model
        """
        model = load_model(pretrained_model_path)
        return cls(model)

    def predict(self, X):
        """
        Normalise ``X`` feature-wise and run the wrapped model on it.

        Arguments:
            X   array-like of shape (n_examples, n_features)
        Returns:
            list with one dict per model output row, keyed
            ``'Prediction price <i>'`` with a float value
        """
        # feature_normalisation is a staticmethod: it has to be reached through
        # the instance/class, a bare name would not resolve to this method.
        X = self.feature_normalisation(X)
        preds = self.model.predict(X)
        return [{f'Prediction price {i}': float(prediction)} for i, prediction in enumerate(preds)]

    @staticmethod
    def feature_normalisation(data):
        """
        feature_normalisation function.
        It takes a data array and returns a NEW array in which every feature
        (column) is shifted to zero mean and scaled to unit standard deviation.
        The input array is left unmodified.

        Arguments:
            data              np.array (or array-like) of shape (n_training_example, n_features)
        Returns:
            data_normalised   np.array of shape (n_training_example, n_features)

        Features whose standard deviation is zero (e.g. a single-row batch or a
        constant column) are only centred, so they come out as 0 instead of NaN.

        NOTE(review): mean/std are computed from ``data`` itself, so at
        prediction time each batch is normalised with its own statistics rather
        than those used during training — confirm this is intended.
        """
        import numpy as np  # local import: the module does not import numpy at top level

        data = np.asarray(data)
        # Out-of-place arithmetic: never mutates the caller's array, and integer
        # input is promoted to float instead of failing on in-place casting.
        centred = data - data.mean(axis=0)
        std = centred.std(axis=0)
        # Guard against 0/0 -> NaN for constant features.
        std = np.where(std == 0, 1, std)
        return centred / std
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment