Last active
January 7, 2024 08:24
-
-
Save bathtimefish/53df12ea64e758d9cb3b4de11007c946 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "4e7a48be-d8f7-4443-b9c6-c1d728fbc8e1", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from langchain import HuggingFacePipeline, PromptTemplate, LLMChain\n", | |
"from langchain.prompts import (\n", | |
" ChatPromptTemplate,\n", | |
" SystemMessagePromptTemplate,\n", | |
" HumanMessagePromptTemplate,\n", | |
")\n", | |
"from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline\n", | |
"import torch" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "50bcc067-a43a-4a6e-8a41-158c8e7efc10", | |
"metadata": {}, | |
"outputs": [], | |
"source": [
  "# Basic parameters: pipeline task and the Hugging Face model to run it with\n",
  "task = \"text-generation\"\n",
  "model_name = \"elyza/ELYZA-japanese-Llama-2-7b-instruct\""
]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "7c0b2818-3f10-40ac-99d1-40b4e0a7fe9d", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"device(type='cuda')" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [
  "# Use the GPU when CUDA is available, otherwise fall back to the CPU\n",
  "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
  "device"
]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "568f42e6-485c-4d16-8946-b5eeed8f2d57", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "7420d1901cd6496db92701d74f7dac62", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [
  "# Download the tokenizer and model. from_pretrained loads weights in float32 by default,\n",
  "# which for a ~7B-parameter model is roughly 27 GB; use float16 on the GPU to halve that\n",
  "# (the CPU fallback keeps float32, where half-precision support is limited).\n",
  "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
  "model = AutoModelForCausalLM.from_pretrained(\n",
  "    model_name,\n",
  "    torch_dtype=torch.float16 if device.type == \"cuda\" else torch.float32,\n",
  ").to(device)"
]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "672151dc-5780-4175-8635-b190fa8e9fb3", | |
"metadata": {}, | |
"outputs": [], | |
"source": [
  "# LLMs: wrap the model above in a transformers pipeline so LangChain can use it\n",
  "pipe = pipeline(\n",
  "    task,\n",
  "    model=model,\n",
  "    tokenizer=tokenizer,\n",
  "    device=device,  # follow the device chosen above; a hard-coded 0 breaks the CPU fallback\n",
  "    framework='pt',  # the model is a PyTorch model\n",
  "    max_new_tokens=1024,\n",
  ")\n",
  "\n",
  "llm = HuggingFacePipeline(pipeline=pipe)"
]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "4a638e89-5526-47e4-b5eb-c2c5c6a4dad6", | |
"metadata": {}, | |
"outputs": [], | |
"source": [
  "# Prompts: Llama-2 instruction format used by ELYZA-japanese-Llama-2 -- the system prompt\n",
  "# sits on its own lines inside <<SYS>> ... <</SYS>>, followed by a blank line, the question,\n",
  "# and a space before [/INST]. <s> is the Llama BOS token, written explicitly in the template.\n",
  "system_prompt = \"あなたはユーザの質問に回答する優秀なアシスタントです。以下の質問に可能な限り丁寧に回答してください。\"\n",
  "template = \"<s>[INST] <<SYS>>\\n\" + system_prompt + \"\\n<</SYS>>\\n\\n{question} [/INST]\"\n",
  "prompt = PromptTemplate.from_template(template)"
]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "65b6708e-4633-413f-96b2-fea3a7579937", | |
"metadata": {}, | |
"outputs": [], | |
"source": [
  "# Chains: combine the prompt template and the LLM; verbose=True echoes the formatted prompt\n",
  "llm_chain = LLMChain(\n",
  "    llm=llm,\n",
  "    prompt=prompt,\n",
  "    verbose=True,\n",
  ")"
]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "b2af34d3-768a-46aa-87ba-1dd33dba1bb2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [
  "# Question to send through the chain\n",
  "question = \"カレーライスとは何ですか?\""
]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "58db6470-64fc-4ad2-b1f3-e2c4a91fc017", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
"\n", | |
"\u001b[1m> Entering new LLMChain chain...\u001b[0m\n", | |
"Prompt after formatting:\n", | |
"\u001b[32;1m\u001b[1;3m<s>[INST] <<SYS>>あなたはユーザの質問に回答する優秀なアシスタントです。以下の質問に可能な限り丁寧に回答してください。 <</SYS>>\n", | |
"\n", | |
"カレーライスとは何ですか?[/INST]\u001b[0m\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/user/.pyenv/versions/3.10.12/lib/python3.10/site-packages/transformers/generation/utils.py:1411: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use a generation configuration file (see https://huggingface.co/docs/transformers/main_classes/text_generation )\n", | |
" warnings.warn(\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
"\u001b[1m> Finished chain.\u001b[0m\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"' カレーライスとは、ご飯の上にカレーが乗っている料理のことです。一般的には、米を炊いた後、その米をカレーのルーで炒めたり、浸けたりして作ります。また、具材としてチキンやエビ、肉などが入ったり、サラダやヨーグルトが添えられることもあります。'" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [
  "# Run the chain on the question; the chain's verbose flag prints the formatted prompt\n",
  "llm_chain.run(question=question)"
]
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.12" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment