# Created May 24, 2019 22:02
# Save hunan-rostomyan/874014f2dedc45590eb3d1e8c3cea193 to your computer and use it in GitHub Desktop.
# GPT-2 345M download_model fork
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled
# differently than what appears below. To review, open the file in an editor that reveals
# hidden Unicode characters. (Learn more about bidirectional Unicode characters.)
import os
import sys

import requests
from tqdm import tqdm
if len(sys.argv) != 3: | |
print('You must enter the model name as a parameter, e.g.: download_model.py 117M') | |
sys.exit(1) | |
model = sys.argv[1] | |
directory = sys.argv[2] | |
subdir = os.path.join(directory, model) | |
if not os.path.exists(subdir): | |
os.makedirs(subdir) | |
subdir = subdir.replace('\\','/') # needed for Windows | |
for filename in ['checkpoint','encoder.json','hparams.json','model.ckpt.data-00000-of-00001', 'model.ckpt.index', 'model.ckpt.meta', 'vocab.bpe']: | |
r = requests.get("https://storage.googleapis.com/gpt-2/" + subdir + "/" + filename, stream=True) | |
with open(os.path.join(subdir, filename), 'wb') as f: | |
file_size = int(r.headers["content-length"]) | |
chunk_size = 1000 | |
with tqdm(ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar: | |
# 1k for chunk_size, since Ethernet packet size is around 1500 bytes | |
for chunk in r.iter_content(chunk_size=chunk_size): | |
f.write(chunk) | |
pbar.update(chunk_size) |
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.