Created April 14, 2024 08:46
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline


def custom_chat_template(chat):
    # Concatenate each turn as "role: content", one turn per line.
    return "\n".join(f"{turn['role']}: {turn['content']}" for turn in chat)


def generate_response(system_prompt, user_prompt):
    try:
        print("Initializing the model and tokenizer...")
        model_name = "macadeliccc/WestLake-7B-v2-laser-truthy-dpo"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Note: float16 weights are supported on CPU, but generation with a 7B model will be very slow.
        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

        # Explicitly run on CPU
        device = torch.device("cpu")
        print(f"Using device: {device}")

        print("Setting up the text generation pipeline...")
        text_gen_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device=-1  # -1 tells the pipeline to use the CPU
        )

        print("Preparing the chat template...")
        chat = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]
        # Build the prompt with the custom chat template function defined above
        # (tokenizer.chat_template expects a Jinja string, so we call the function directly).
        prompt = custom_chat_template(chat)
        print(f"Generated prompt: {prompt[:50]}...")  # Print the first 50 characters of the prompt

        print("Generating response...")
        start_time = time.time()
        outputs = text_gen_pipeline(prompt, max_new_tokens=256, do_sample=True, temperature=0.1, top_k=50, top_p=0.95)
        print(f"Response generated in {time.time() - start_time:.2f} seconds.")
        return outputs[0]['generated_text']
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        return None


# Example usage
system_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
user_prompt = "Give me top 3 jokes"
response = generate_response(system_prompt, user_prompt)
print(response if response else "No response generated.")
The last thing I see in the console is "Generating response...". After that it just gets stuck on the pipeline call.
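If it hangs at that point, the most likely cause is simply speed: a 7B model loaded in float16 on a CPU can take many minutes per response, and the pipeline prints nothing until generation finishes, so the process looks frozen. A minimal sketch of a way to check this, under the assumption that you can call model.generate directly: load the weights in float32 (CPU float16 kernels are slow), reduce max_new_tokens for a first test, and attach a TextStreamer so tokens are printed as they are produced. This is a diagnostic variant, not the original gist's code.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

# Same model as the gist, but float32 on CPU and streaming output
# so a slow run is visibly progressing instead of appearing stuck.
model_name = "macadeliccc/WestLake-7B-v2-laser-truthy-dpo"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)

# Prompt built the same way as the custom chat template above.
prompt = "system: You are a helpful assistant.\nuser: Give me top 3 jokes"
inputs = tokenizer(prompt, return_tensors="pt")

# TextStreamer prints each decoded token to stdout as soon as it is sampled.
streamer = TextStreamer(tokenizer, skip_prompt=True)
output_ids = model.generate(
    **inputs,
    max_new_tokens=64,  # keep this small for a first CPU test
    do_sample=True,
    temperature=0.1,
    top_k=50,
    top_p=0.95,
    streamer=streamer,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

If tokens appear in the console one by one, the original script is working but simply too slow on CPU; running on a GPU or using a quantized build of the model would be the practical fix.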