Skip to content

Instantly share code, notes, and snippets.

@kohya-ss
Created February 9, 2025 13:13
Show Gist options
  • Save kohya-ss/bdc5cdaf6ff4566e42a50391b3cf61e5 to your computer and use it in GitHub Desktop.
Save kohya-ss/bdc5cdaf6ff4566e42a50391b3cf61e5 to your computer and use it in GitHub Desktop.
ノベルゲーム風簡易CUI:会話とプロンプト生成を別のchatにして一貫性を向上させる
# ライセンスは Apache License 2.0 です。
# お勧めのオプションは:
# --flash_attn --q8_kv_cache --n_gpu_layers <GPUのVRAMに応じて> --n_ctx <VRAMとモデルのcontext lengthに応じて>
import argparse
import random
import msvcrt # Windowsの場合
import toml
from typing import List, Optional, Union, Iterator
from datetime import datetime
import importlib
import os
import llama_cpp
from llama_cpp.llama_chat_format import _convert_completion_to_chat, register_chat_completion_handler
import llama_cpp.llama_types as llama_types
from llama_cpp.llama import LogitsProcessorList, LlamaGrammar
from llama_cpp import Llama, llama_chat_format
MESSAGE_HISTORY_SIZE = 20 # コンテキスト長に余裕があれば長くするとよい
# デフォルト設定
DEFAULT_CHARACTER_NAME = "美咲"
DEFAULT_USER_NAME = "彰"
DEFAULT_CHARACTER_SYSTEM_PROMPT = """あなたは高校二年生の女子生徒、美咲です。
ある日、学校の帰り道、突然の雨に見舞われ、ひとつの古びた洋館に避難することになりました。
その洋館には、あなたのクラスメイトである男子生徒、彰(ユーザー)がいました。
彼は、あなたに「ここには、幽霊が出ると言われている」と教えてくれました。
あなたは、彰と一緒に、その洋館の中を探索することになります。
ユーザーの言葉に対応して、美咲のセリフだけを出力してください。"""
DEFAULT_PROMPTER_SYSTEM_PROMPT = """あなたは有能な脚本家です。
女子高生・美咲と男子高生・彰(ユーザー)が、古びた洋館で出会い、幽霊の噂を聞いて探索するストーリーを考えてください。
会話の履歴から、適切な次のセリフを考えてください。"""
DEFAULT_PROMPTER_USER_PROMPT = """以下が直近の会話です。
---
$conversation
---
行動「$policy」を表す適切なユーザーの次のセリフを考えてください。セリフだけを出力してください。"""
DEFAULT_POLICIES = [
"美咲と会話する",
"消極的に行動する",
"積極的に行動する",
"あたりを調べる",
"核心に迫る",
"もしかしてあれは……?",
"逃げ出す",
"なんかいい感じに振る舞う",
]
llama_state_chat = None
llama_state_prompter = None
debug_mode = False
def log(log_file, message):
if log_file:
with open(log_file, "a", encoding="utf-8") as f:
f.write(message + "\n")
def debug(log_msg):
# TODO 標準のloggingモジュールを使う
if debug_mode:
print(log_msg)
# the latest llama.cpp seems to have "command-r" handler, but we keep this until llama-cpp-python is updated
# we can also use the chat template from GGUF
@register_chat_completion_handler("command-r")
def command_r_chat_handler(
llama: Llama,
messages: List[llama_types.ChatCompletionRequestMessage],
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
tools: Optional[List[llama_types.ChatCompletionTool]] = None,
tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 40,
min_p: float = 0.05,
typical_p: float = 1.0,
stream: bool = False,
stop: Optional[Union[str, List[str]]] = [],
response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
max_tokens: Optional[int] = None,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
repeat_penalty: float = 1.1,
tfs_z: float = 1.0,
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
model: Optional[str] = None,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
seed: Optional[int] = None,
**kwargs, # type: ignore
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
# bos_token = "<BOS_TOKEN>"
start_turn_token = "<|START_OF_TURN_TOKEN|>"
end_turn_token = "<|END_OF_TURN_TOKEN|>"
user_token = "<|USER_TOKEN|>"
chatbot_token = "<|CHATBOT_TOKEN|>"
system_token = "<|SYSTEM_TOKEN|>"
prompt = "" # bos_token # suppress warning
if len(messages) > 0 and messages[0]["role"] == "system":
prompt += start_turn_token + system_token + messages[0]["content"] + end_turn_token
messages = messages[1:]
for message in messages:
if message["role"] == "user":
prompt += start_turn_token + user_token + message["content"] + end_turn_token
elif message["role"] == "assistant":
prompt += start_turn_token + chatbot_token + message["content"] + end_turn_token
prompt += start_turn_token + chatbot_token
# prompt += start_turn_token + chatbot_token + "<think>" # temp: add <think> to test thinking mode
# if debug_flag:
# print(f"Prompt: {prompt}")
stop_tokens = [end_turn_token] # , bos_token]
return _convert_completion_to_chat(
llama.create_completion(
prompt=prompt,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
typical_p=typical_p,
stream=stream,
stop=stop_tokens,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
repeat_penalty=repeat_penalty,
tfs_z=tfs_z,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
model=model,
logits_processor=logits_processor,
grammar=grammar,
seed=seed,
),
stream=stream,
)
def generate(llama, handler, generation_params, messages, max_tokens):
debug(f"messages: {messages}")
seed = random.randint(0, 2**31 - 1)
response = handler(llama=llama, messages=messages, stream=False, max_tokens=max_tokens, seed=seed, **generation_params)
debug(f"response: {response}")
content = response["choices"][0]["message"]["content"]
return content
def generate_chat(
llama: Llama,
handler,
tts,
tts_char_name,
tts_call_letters,
generation_params,
character_system_prompt,
conversation,
max_tokens,
):
# モデルの状態を復元
global llama_state_chat
if llama_state_chat is not None:
llama.load_state(llama_state_chat)
if tts_call_letters:
tts_call_letters += "\n"
# モデルが想定するmessagesの形式に変換
messages = []
if character_system_prompt:
messages.append({"role": "system", "content": character_system_prompt})
for role, message in conversation:
messages.append({"role": role, "content": message})
# メッセージのトークン数を制限
n_ctx = llama.n_ctx()
total_tokens = 0
token_counts = []
for message in messages:
message_bytes = message["content"].encode("utf-8")
tokens = len(llama.tokenize(message_bytes, add_bos=False)) + 3 # add start/end
total_tokens += tokens
token_counts.append(tokens)
if total_tokens > n_ctx - max_tokens:
print(f"Total tokens: {total_tokens}, n_ctx: {n_ctx}, max_tokens: {max_tokens}")
index = 1 if character_system_prompt else 0 # システムプロンプトがある場合は除外
tokens_removed = 0
while total_tokens > n_ctx - max_tokens:
if len(token_counts) == 0:
break
total_tokens -= token_counts[index]
tokens_removed += token_counts[index]
token_counts.pop(index)
messages.pop(index)
# tokens and messages are removed, so we don't need to update index
print(f"Removed {tokens_removed} tokens")
# # モデルによる応答生成
# response = generate(llama, handler, generation_params, messages)
# # モデルが"「」"で囲まれたセリフを生成することがあるので、"」"以降を削除し、さらに"「"を削除
# response = response.split("」", 1)[0].replace("「", "")
# streamで生成し、都度ttsを呼び出す
chat_completion_chunks = handler(llama=llama, messages=messages, stream=True, max_tokens=max_tokens, **generation_params)
response = ""
response_sent_to_tts = ""
i = 0
for chunk in chat_completion_chunks:
debug(chunk)
i += 1
if not debug_mode and i % 20 == 0:
print(".", end="", flush=True)
if "choices" in chunk and len(chunk["choices"]) > 0:
if "delta" in chunk["choices"][0]:
if "content" in chunk["choices"][0]["delta"]:
content = chunk["choices"][0]["delta"]["content"]
response += content
# TTSを呼び出す区切り文字が含まれている場合、TTSを呼び出す
if tts and any([letter in content for letter in tts_call_letters]):
text_to_tts = response[len(response_sent_to_tts) :] # 前回からの差分
print(f"\nTTS: {text_to_tts}")
response_sent_to_tts = response
tts(tts_char_name, text_to_tts)
# TTSへ未送信のテキストがある場合、最後にTTSを呼び出す
if tts and response != response_sent_to_tts:
text_to_tts = response[len(response_sent_to_tts) :]
response_sent_to_tts = response
tts(tts_char_name, text_to_tts)
# モデルの状態を保存
llama_state_chat = llama.save_state()
# モデルが"「」"で囲まれたセリフを生成することがあるので、"」"以降を削除し、さらに"「"を削除
response = response.strip()
if response.startswith("「") and response.endswith("」"):
response = response[1:]
response = response[:-1]
return response
def generate_user_response(
llama, handler, generation_params, prompter_system_prompt, prompter_user_prompt, conversation_str, policy, max_tokens
):
# モデルの状態を復元
global llama_state_prompter
if llama_state_prompter is not None:
llama.load_state(llama_state_prompter)
# プロンプトの作成
prompt = prompter_user_prompt.replace("$conversation", conversation_str).replace("$policy", policy)
# モデルが想定するmessagesの形式に変換
messages = []
messages.append({"role": "system", "content": prompter_system_prompt})
messages.append({"role": "user", "content": prompt})
# モデルによる応答生成
response = generate(llama, handler, generation_params, messages, max_tokens) # , first_turn=first_turn)
# モデルが"「」"で囲まれたセリフを生成することがあるので、"」"以降を削除し、さらに"「"を削除
response = response.split("」", 1)[0].replace("「", "")
# 同様に `"` も削除
response = response.replace('"', "")
# モデルの状態を保存
llama_state_prompter = llama.save_state()
return response
def get_key():
"""エンターキーを押さずに入力を取得する"""
try:
return msvcrt.getch().decode()
except UnicodeDecodeError:
return ""
def display_policies(policies):
"""ポリシーの選択肢を表示する"""
print("\n=== 次の行動を選んでください ===")
for i, policy in enumerate(policies, 1):
c = str(i) if i < 10 else chr(ord("a") + i - 10)
print(f"{c}. {policy}, ", end="")
print("\nx. 自由入力, y. 直前の選択をやり直す, z. 終了")
def format_conversation(conversation, user_name, character_name):
"""会話履歴をフォーマットする"""
if not conversation:
return "(会話はまだありません)"
if len(conversation) > MESSAGE_HISTORY_SIZE:
# 残りがN件以上になるように、N/2件単位で先頭から削除:モデルのキャッシュを有効に使うため
while len(conversation) - MESSAGE_HISTORY_SIZE // 2 >= MESSAGE_HISTORY_SIZE:
conversation = conversation[MESSAGE_HISTORY_SIZE // 2 :]
result = ""
for role, message in conversation:
name = user_name if role == "user" else character_name
result += f"{name}: {message}\n"
return result
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=None, help="toml file path")
parser.add_argument("-m", "--model", type=str, default=None, help="Model file path")
parser.add_argument("-ngl", "--n_gpu_layers", type=int, default=0, help="Number of GPU layers")
parser.add_argument("-c", "--n_ctx", type=int, default=4096, help="Context length")
parser.add_argument(
"-ch",
"--chat_handler",
type=str,
default=None,
help="Chat handler, e.g. command-r, mistral-instruct, alpaca, llama-3 etc. default: None (use template from GGUF)",
)
parser.add_argument("--max_tokens", type=int, default=512, help="Max tokens for each prompt")
parser.add_argument(
"-ts", "--tensor_split", type=str, default=None, help="Tensor split, float values separated by comma for each gpu"
)
parser.add_argument("--disable_mmap", action="store_true", help="Disable mmap")
parser.add_argument("--q8_kv_cache", action="store_true", help="Use quantized kv cache (Q8)")
parser.add_argument("--q4_kv_cache", action="store_true", help="Use quantized kv cache (Q4)")
parser.add_argument("--flash_attn", action="store_true", help="Use flash attention")
parser.add_argument("--debug", action="store_true", help="Debug mode")
parser.add_argument("--tts_module", type=str, default=None, help="TTS module. must have `tts(tts_char_name, message)` function")
parser.add_argument(
"--tts_call_letters", type=str, default=None, help="letters to call TTS in middle of the message, like `。?」♪`"
)
parser.add_argument("--tts_char_name", type=str, default=None, help="Character name for TTS (TTS model selection)")
parser.add_argument("--output_dir", type=str, default=None, help="Output directory of logs")
args = parser.parse_args()
if args.debug:
global debug_mode
debug_mode = True
# load .toml file
if args.config is not None:
config = toml.load(args.config)
character_name = config.get("character_name", DEFAULT_CHARACTER_NAME)
user_name = config.get("user_name", DEFAULT_USER_NAME)
character_system_prompt = config.get("character_system_prompt", DEFAULT_CHARACTER_SYSTEM_PROMPT)
prompter_system_prompt = config.get("prompter_system_prompt", DEFAULT_PROMPTER_SYSTEM_PROMPT)
prompter_user_prompt = config.get("prompter_user_prompt", DEFAULT_PROMPTER_USER_PROMPT)
policies = config.get("policies", DEFAULT_POLICIES)
character_generation_params = config.get("character_generation_params", {})
prompter_generation_params = config.get("prompter_generation_params", character_generation_params)
else:
character_name = DEFAULT_CHARACTER_NAME
user_name = DEFAULT_USER_NAME
character_system_prompt = DEFAULT_CHARACTER_SYSTEM_PROMPT
prompter_system_prompt = DEFAULT_PROMPTER_SYSTEM_PROMPT
prompter_user_prompt = DEFAULT_PROMPTER_USER_PROMPT
policies = DEFAULT_POLICIES
character_generation_params = {}
prompter_generation_params = character_generation_params
print(f"character generation params: {character_generation_params}")
print(f"prompter generation params: {prompter_generation_params}")
# initialize Llama
print(f"Initializing Llama. Model ID: {args.model}, N_GPU_LAYERS: {args.n_gpu_layers}, N_CTX: {args.n_ctx}")
tensor_split = None if args.tensor_split is None else [float(x) for x in args.tensor_split.split(",")]
llama = Llama(
model_path=args.model,
n_gpu_layers=args.n_gpu_layers,
tensor_split=tensor_split,
n_ctx=args.n_ctx,
use_mmap=not args.disable_mmap,
type_k=llama_cpp.GGML_TYPE_Q8_0 if args.q8_kv_cache else (llama_cpp.GGML_TYPE_Q4_0 if args.q4_kv_cache else None),
type_v=llama_cpp.GGML_TYPE_Q8_0 if args.q8_kv_cache else (llama_cpp.GGML_TYPE_Q4_0 if args.q4_kv_cache else None),
flash_attn=args.flash_attn,
)
if args.chat_handler is not None:
handler = llama_chat_format.get_chat_completion_handler(args.chat_handler)
else:
handler = llama._chat_handlers[llama.chat_format]
# TTS
if args.tts_module:
print(f"Loading TTS module: {args.tts_module}")
tts_module = importlib.import_module(args.tts_module)
tts = tts_module.tts
tts_call_letters = args.tts_call_letters
tts_char_name = args.tts_char_name
else:
tts = tts_call_letters = tts_char_name = None
# ログファイル
log_file = None
if args.output_dir:
os.makedirs(args.output_dir, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
log_file = f"{args.output_dir}/chat_log_{timestamp}.txt"
print(f"Log file: {log_file}")
# 会話履歴
conversation = []
print("\n=== ゲームを開始します ===")
print("\n")
while True:
# 1. ポリシーの選択
display_policies(policies)
choice = get_key()
print(f"選択: {choice}")
if choice == "z":
print("\nゲームを終了します。")
break
# ユーザーの発言を取得
if choice == "y":
if len(conversation) >= 2:
conversation = conversation[:-2]
print(
f"\n直前の選択をやり直しました。({character_name}の最後の台詞: {conversation[-1][1] if len(conversation)>0 else ''})"
)
continue
else:
print("\n直前の選択がありません。")
continue
if choice == "x":
print("\n自由入力モード(入力後、Enterを押してください):")
user_message = input("> ")
elif choice.isdigit() or (choice.isalpha() and ord("a") <= ord(choice) <= ord("z")):
if choice.isalpha():
choice = ord(choice) - ord("a") + 9
else:
choice = int(choice) - 1
policy = policies[choice]
# 2. プロンプターによるセリフ生成
conversation_str = format_conversation(conversation, user_name, character_name)
user_message = generate_user_response(
llama,
handler,
prompter_generation_params,
prompter_system_prompt,
prompter_user_prompt,
conversation_str,
policy,
args.max_tokens,
)
else:
print("\n無効な選択です。")
continue
# 3. ユーザーのセリフを表示
print(f"\nあなた: {user_message}")
log(log_file, f"{user_name}: {user_message}")
conversation.append(("user", user_message))
# 4. キャラクターの応答を生成
character_response = generate_chat(
llama,
handler,
tts,
tts_char_name,
tts_call_letters,
character_generation_params,
character_system_prompt,
conversation,
args.max_tokens,
)
print(f"{character_name}: {character_response}")
log(log_file, f"{character_name}: {character_response}")
conversation.append(("assistant", character_response))
print("\n" + "=" * 50)
if __name__ == "__main__":
main()
character_name = "美咲"
user_name = ""
character_system_prompt = """あなたは高校二年生の女子生徒、美咲です。
ある日、学校の帰り道、突然の雨に見舞われ、ひとつの古びた洋館に避難することになりました。
その洋館には、あなたのクラスメイトである男子生徒、彰(ユーザー)がいました。
彼は、あなたに「ここには、幽霊が出ると言われている」と教えてくれました。
あなたは、彰と一緒に、その洋館の中を探索することになります。
ユーザーの言葉に対応して、美咲のセリフだけを出力してください。"""
prompter_system_prompt = """あなたは有能な脚本家です。
女子高生・美咲と男子高生・彰(ユーザー)が、古びた洋館で出会い、幽霊の噂を聞いて探索するストーリーを考えてください。
会話の履歴から、適切な次のセリフを考えてください。"""
prompter_user_prompt = """以下が直近の会話です。
---
$conversation
---
行動「$policy」を表す適切なユーザーの次のセリフを考えてください。セリフだけを出力してください。"""
policies = [
"美咲と会話する",
"消極的に行動する",
"積極的に行動する",
"あたりを調べる",
"核心に迫る",
"もしかしてあれは……?",
"逃げ出す",
"なんかいい感じに振る舞う",
]
character_generation_params = {"temperature"=0.5}
prompter_generation_params = {"temperature"=0.5}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment