Created July 11, 2023 17:54
simple HF transformers inference (HTTP API wrapped)
utils/ntk_rope_scale.py — the NTK-aware RoPE scaling patch used by the server:
import transformers
import transformers.models.llama.modeling_llama


def enable_ntk_rope_scaling(alpha=4):
    # Monkeypatch LlamaRotaryEmbedding so the RoPE base is rescaled by
    # alpha ** (dim / (dim - 2)) ("NTK-aware" scaling), extending the usable
    # context window to roughly 2048 * alpha tokens.
    old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__

    def ntk_scaled_init(self, dim, max_position_embeddings=2048, base=10000, device=None):
        max_position_embeddings = 2048 * alpha
        base = base * alpha ** (dim / (dim - 2))
        old_init(self, dim, max_position_embeddings, base, device)

    transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = ntk_scaled_init
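The patch has to be applied before the model is instantiated, since it replaces LlamaRotaryEmbedding.__init__; the server script below does exactly that right after parsing its arguments. A minimal standalone sketch (the model path here is a placeholder, not from the gist):

from transformers import LlamaForCausalLM, LlamaTokenizer

from utils.ntk_rope_scale import enable_ntk_rope_scaling

# Patch RoPE first: alpha=2 targets roughly a 4096-token context.
enable_ntk_rope_scaling(alpha=2)

model_path = "/path/to/llama-model"  # placeholder path
tokenizer = LlamaTokenizer.from_pretrained(model_path)
model = LlamaForCausalLM.from_pretrained(model_path, device_map="auto")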
The inference server, which imports the module above as utils.ntk_rope_scale:
import argparse
import time
import uuid

import torch
from transformers import LlamaTokenizer, LlamaForCausalLM, BitsAndBytesConfig
from transformers import StoppingCriteria, StoppingCriteriaList
from bottle import Bottle, run, request

from utils.ntk_rope_scale import enable_ntk_rope_scaling

parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model', required=False, type=str, default='/storage/models/LLaMA/FP16/Wizard-Vicuna-Uncensored-13B', help="Path to the model")
parser.add_argument('-a', '--model_name', required=False, type=str, default="WizVicUncen13.NF4", help="Model's alias")
# argparse's type=bool treats any non-empty string as True, so use a real
# boolean flag instead (--four_bit / --no-four_bit, Python 3.9+).
parser.add_argument('-4', '--four_bit', action=argparse.BooleanOptionalAction, default=True, help="Load in 4 bit (otherwise 8 bit)")
parser.add_argument('-A', '--alpha', default=1, required=False, type=int, help="NTK Scaled RoPE's alpha")
parser.add_argument('-c', '--context', default=2048, required=False, type=int, help="Context length")
parser.add_argument('--port', default=8013, required=False, type=int, help="Port to listen on")
parser.add_argument('--ip', default='127.0.0.1', required=False, type=str, help="IP to listen on")
args = parser.parse_args()

# Patch RoPE before the model is loaded if context extension is requested.
if args.alpha != 1:
    enable_ntk_rope_scaling(args.alpha)

app = Bottle()


class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation as soon as one of the given token sequences appears at the end of the output."""

    def __init__(self, stops=(), encounters=1):
        super().__init__()
        self.stops = [stop.to("cuda") for stop in stops]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True
        return False

def load_model():
    model_id = args.model
    tokenizer = LlamaTokenizer.from_pretrained(model_id)
    # Cap per-GPU memory, leaving ~2GB of headroom for activations.
    free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3)
    max_memory = f'{free_in_GB - 2}GB'
    n_gpus = torch.cuda.device_count()
    max_memory = {i: max_memory for i in range(n_gpus)}
    if args.four_bit:
        nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16
        )
        model = LlamaForCausalLM.from_pretrained(model_id, device_map='auto', quantization_config=nf4_config, max_memory=max_memory)
    else:
        model = LlamaForCausalLM.from_pretrained(model_id, device_map='auto', load_in_8bit=True, max_memory=max_memory)
    print(model.generation_config)
    return model, tokenizer


llm, tokenizer = load_model()

conversations = {}


def full_conversation(idx):
    # Rebuild the full prompt text from the stored message history.
    chat = ''
    for message in conversations[idx]['messages']:
        if message['role'] == 'system':
            chat += message['content'] + '\n\n'
        elif message['role'] == 'user':
            chat += conversations[idx]['prefix'] + ' ' + message['content'] + '\n'
        elif message['role'] == 'assistant':
            chat += conversations[idx]['suffix'] + ' ' + message['content'] + '\n'
    # If the last message is from the user, append the assistant suffix so the
    # model continues as the assistant.
    if conversations[idx]['messages'][-1]['role'] == 'user':
        chat += conversations[idx]['suffix']
    return chat

@app.route('/prompt', method='PUT')
def set_prompt():
    # Create (or overwrite) a conversation: system prompt plus the speaker
    # tags used to format user and assistant turns.
    data = request.json
    conversation_uuid = data.get('uuid', str(uuid.uuid4()))
    prompt = data.get('prompt', '')
    messages = data.get('messages', [{'role': 'system', 'content': prompt}])
    prefix = data.get('prefix', 'USER:')
    suffix = data.get('suffix', 'ASSISTANT:')
    conversations[conversation_uuid] = {
        "messages": messages,
        "prefix": prefix,
        "suffix": suffix
    }
    return {"message": "Prompt set", "uuid": conversation_uuid}

@app.route('/chat', method='POST')
def chat():
    data = request.json
    conversation_uuid = data['uuid']
    if conversation_uuid not in conversations:
        return {"uuid": conversation_uuid, "message": "not found"}
    temperature = data.get('temperature', 0.7)
    max_new_tokens = data.get('max_length', 512)
    query = data.get('query')
    conversations[conversation_uuid]['messages'].append({'role': 'user', 'content': query})
    full_ctx = full_conversation(conversation_uuid)
    # Stop generating as soon as the model starts a new user turn.
    # add_special_tokens=False keeps the BOS token out of the stop sequence,
    # which would otherwise prevent it from ever matching.
    stop_words = [conversations[conversation_uuid]['prefix']]
    stop_words_ids = [tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'].squeeze() for stop_word in stop_words]
    start_time = time.time_ns()
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
    input_ids = tokenizer(full_ctx, return_tensors="pt").input_ids.to('cuda')
    outputs = llm.generate(
        input_ids,
        do_sample=temperature > 0,  # greedy decoding would silently ignore temperature
        num_beams=1,
        stopping_criteria=stopping_criteria,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        num_return_sequences=1,
        remove_invalid_values=True,
    )
    # Decode only the newly generated tokens instead of string-replacing the
    # prompt, which can fail when detokenization does not round-trip exactly.
    answer = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
    conversations[conversation_uuid]['messages'].append({'role': 'assistant', 'content': answer})
    new_tokens = len(outputs[0]) - len(input_ids[0])
    end_time = time.time_ns()
    secs = (end_time - start_time) / 1e9
    return {
        "text": answer,
        "ctx": len(outputs[0]),
        "tokens": new_tokens,
        "rate": new_tokens / secs,
        "model": args.model_name,
    }


run(app, host=args.ip, port=args.port)
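A quick way to exercise the API once the server is running (host, port, and payload fields follow the defaults above; the requests library is assumed to be installed):

import requests

BASE = "http://127.0.0.1:8013"  # default --ip / --port from above

# 1. Create a conversation with a system prompt and speaker tags.
resp = requests.put(f"{BASE}/prompt", json={
    "prompt": "You are a helpful assistant.",
    "prefix": "USER:",
    "suffix": "ASSISTANT:",
})
conv = resp.json()["uuid"]

# 2. Send a user message and print the generated reply plus the token rate.
resp = requests.post(f"{BASE}/chat", json={
    "uuid": conv,
    "query": "Explain NTK-aware RoPE scaling in one sentence.",
    "temperature": 0.7,
    "max_length": 256,
})
reply = resp.json()
print(reply["text"])
print(f'{reply["rate"]:.1f} tokens/s with {reply["model"]}')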