Last active: August 19, 2021 08:30
-
-
Save AntoineToubhans/40c436052e2c5437ae9cb50ad89d0e0a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import shutil
from pathlib import Path

import tensorflow as tf

# Warning: this is private internal dvc api, it may change with future version
import dvc.repo.get
# Root of the local model cache: each revision gets its own sub-directory
# (see load_model). The path is relative to the current working directory
# and the directory is created once, at import time, if it is missing.
ROOT_MODEL_CACHE_DIR = Path(".model_cache")
ROOT_MODEL_CACHE_DIR.mkdir(parents=False, exist_ok=True)
# NOTE(review): `st` is presumably `streamlit`; it is not imported in the
# visible part of this file — confirm it is imported elsewhere.
# allow_output_mutation=True stops st.cache from hashing the returned model on
# every call to detect mutation; Keras models are large and may not be
# hashable by Streamlit at all.
@st.cache(allow_output_mutation=True)
def load_model(rev: str):
    """Return the Keras model tracked by DVC at git revision ``rev``.

    The model is first loaded from the local cache entry
    ``ROOT_MODEL_CACHE_DIR / rev``. If that fails, any stale cache entry is
    discarded, the model at ``data/train/model`` is fetched with ``dvc get``
    from the repository at ``"."`` (the current working directory) into that
    cache entry, and loaded from there.

    Args:
        rev: Git revision (branch, tag or commit) to take the model from.

    Returns:
        The object returned by ``tf.keras.models.load_model``.

    Raises:
        Any exception raised by ``dvc.repo.get.get`` or by the final
        ``tf.keras.models.load_model`` call (download or load failure).
    """
    print(f"Loading model for revision {rev}")
    model_cache_path = ROOT_MODEL_CACHE_DIR / rev
    model_cache_dir = str(model_cache_path)

    # 1. Try to load the model directly (if it is already in the cache dir).
    try:
        return tf.keras.models.load_model(model_cache_dir)
    except OSError:
        print(f"Could not find model {rev} in cache")
    except Exception as e:
        # Keep going (we re-download below) but do not lose the cause.
        print(f"Could not load model {rev} from cache: {e!r}")

    # 2. Drop any stale/partial cache entry. When `out` is an existing
    #    directory, `dvc get` places the download inside it (<out>/model) or
    #    refuses to overwrite it, so the reload below would fail again.
    if model_cache_path.is_dir():
        shutil.rmtree(model_cache_path)
    elif model_cache_path.exists():
        model_cache_path.unlink()

    # 3. Download the model into the cache dir using `dvc get`.
    #    See https://dvc.org/doc/command-reference/get
    dvc.repo.get.get(
        url=".",
        path="data/train/model",
        out=model_cache_dir,
        rev=rev,
    )
    print(f"Model downloaded to {model_cache_dir}")

    # 4. Load the freshly downloaded model.
    return tf.keras.models.load_model(model_cache_dir)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.