# Generate a text sample from the GPT-J-6B model with HuggingFace Transformers.
# Gist by @sxjscience (last active June 18, 2021).

import argparse
import hashlib
import logging
import os

import requests
import torch
from tqdm import tqdm
from transformers import GPTNeoForCausalLM, GPT2Tokenizer

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
if logger.hasHandlers():
    logger.handlers.clear()
console = logging.StreamHandler()
logger.addHandler(console)

parser = argparse.ArgumentParser(description='Generate a sample from the GPT-J-6B model.')
parser.add_argument('--max_length', type=int, default=512)
parser.add_argument('--top_p', type=float, default=0.9)
parser.add_argument('--context', type=str, default="Why AutoGluon is great?")
args = parser.parse_args()

def check_sha1(filename, sha1_hash):
    """Check whether the sha1 hash of the file content matches the expected hash.

    Parameters
    ----------
    filename : str
        Path to the file.
    sha1_hash : str
        Expected sha1 hash in hexadecimal digits.

    Returns
    -------
    bool
        Whether the file content matches the expected hash.
    """
    sha1 = hashlib.sha1()
    with open(filename, 'rb') as f:
        while True:
            data = f.read(1048576)
            if not data:
                break
            sha1.update(data)
    return sha1.hexdigest() == sha1_hash
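
# Example (mirrors the verification main() performs below on the downloaded config):
#   check_sha1('./gpt-j-hf/config.json', 'a0af27bcff3c0fa17ec9718ffb6060b8db5e54e4')
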
def download(url, path=None, overwrite=False, sha1_hash=None):
    """Download files from a given URL.

    Parameters
    ----------
    url : str
        URL where the file is located.
    path : str, optional
        Destination path to store the downloaded file. By default, stores under the
        current directory using the last two components of the URL path.
    overwrite : bool, optional
        Whether to overwrite the destination file if one already exists at this location.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits (an existing file is ignored when a hash
        is specified but doesn't match).

    Returns
    -------
    str
        The file path of the downloaded file.
    """
    if path is None:
        fname = os.path.join(url.split('/')[-2], url.split('/')[-1])
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split('/')[-1])
        else:
            fname = path
    if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        if not os.path.exists(dirname):
            os.makedirs(dirname)
        logger.info('Downloading %s from %s...' % (fname, url))
        r = requests.get(url, stream=True)
        if r.status_code != 200:
            raise RuntimeError("Failed downloading url %s" % url)
        total_length = r.headers.get('content-length')
        with open(fname, 'wb') as f:
            if total_length is None:  # no content length header
                for chunk in r.iter_content(chunk_size=1024):
                    if chunk:  # filter out keep-alive new chunks
                        f.write(chunk)
            else:
                total_length = int(total_length)
                for chunk in tqdm(r.iter_content(chunk_size=1024),
                                  total=int(total_length / 1024. + 0.5),
                                  unit='KB', unit_scale=False, dynamic_ncols=True):
                    f.write(chunk)
        if sha1_hash and not check_sha1(fname, sha1_hash):
            raise UserWarning('File {} is downloaded but the content hash does not match. '
                              'The repo may be outdated or download may be incomplete. '
                              'If the "repo_url" is overridden, consider switching to '
                              'the default repo.'.format(fname))
    return fname
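
# Example (the same call main() makes below); omit sha1_hash to skip verification:
#   download('https://zhisu-nlp.s3.us-west-2.amazonaws.com/gpt-j-hf/config.json',
#            sha1_hash='a0af27bcff3c0fa17ec9718ffb6060b8db5e54e4')
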
def main():
    urls = [('https://zhisu-nlp.s3.us-west-2.amazonaws.com/gpt-j-hf/config.json',
             'a0af27bcff3c0fa17ec9718ffb6060b8db5e54e4'),
            ('https://zhisu-nlp.s3.us-west-2.amazonaws.com/gpt-j-hf/pytorch_model.bin', None)]
    for url, sha1_hash in urls:
        download(url, sha1_hash=sha1_hash)
    print("Download finished", flush=True)
    print("Load the GPT-J-6B model", flush=True)
    # The converted GPT-J checkpoint in ./gpt-j-hf is loaded through GPTNeoForCausalLM.
    model = GPTNeoForCausalLM.from_pretrained("./gpt-j-hf")
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model.eval()
    # Casting to fp16 takes about 12GB of GPU memory; on a GPU with more than 16GB
    # you don't need the half().
    model.half().cuda()
    print('Loaded!')
    input_text = args.context
    input_ids = tokenizer.encode(str(input_text), return_tensors='pt').cuda()
    # Nucleus (top-p) sampling; top_k=0 disables top-k filtering.
    output = model.generate(
        input_ids,
        do_sample=True,
        max_length=args.max_length,
        top_p=args.top_p,
        top_k=0,
        temperature=1.0,
    )
    print(tokenizer.decode(output[0], skip_special_tokens=True))


if __name__ == '__main__':
    main()
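
# Example invocation (the filename generate_gptj.py is hypothetical, not part of the gist):
#   python generate_gptj.py --max_length 256 --top_p 0.9 --context "Why AutoGluon is great?"
# On the first run this downloads config.json and the large pytorch_model.bin checkpoint
# into ./gpt-j-hf/, then prints one sampled continuation of the given context.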