Skip to content

Instantly share code, notes, and snippets.

@brandon-lockaby
Created June 10, 2024 02:56
Show Gist options
Save brandon-lockaby/76a2aae68b4ea427dfb9efb2e11eb89b to your computer and use it in GitHub Desktop.
logit chart notebook
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# List the GPUs the NVIDIA driver can see before installing the CUDA build below.\n",
"# (`!nvidia-smi` is IPython shorthand for exactly this call.)\n",
"get_ipython().system('nvidia-smi')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# `%pip` installs into the environment of the running kernel (`!pip` may target another interpreter).\n",
"# --force-reinstall replaces an already-installed build (e.g. a CPU-only wheel of the same version)\n",
"# with the CUDA 12.4 wheel from the extra index. Restart the kernel afterwards so the new build is imported.\n",
"# TODO: pin llama-cpp-python==<version> for reproducible runs.\n",
"%pip install --upgrade --force-reinstall llama-cpp-python \\\n",
"  --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cu124"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"from llama_cpp import Llama\n",
"\n",
"# GGUF model to load. Set LLAMA_MODEL_PATH to point at a local copy instead of\n",
"# editing this machine-specific absolute path (the default below is unchanged).\n",
"# Base-model alternative previously commented out here:\n",
"#   /home/axyo/dev/LLM/models/Meta-Llama-3-8B-GGUF-v2/Meta-Llama-3-8B.Q5_0.gguf\n",
"MODEL_PATH = os.environ.get(\n",
"    \"LLAMA_MODEL_PATH\",\n",
"    \"/home/axyo/dev/LLM/models/Meta-Llama-3-8B-Instruct-GGUF-v2/Meta-Llama-3-8B-Instruct-v2.Q5_0.gguf\",\n",
")\n",
"\n",
"llm = Llama(\n",
"    model_path=MODEL_PATH,\n",
"    n_gpu_layers=-1,   # -1 offloads all layers to the GPU\n",
"    seed=8,\n",
"    n_ctx=4096,\n",
"    logits_all=True,   # needed for the completion calls below to return logprobs\n",
")\n",
"print(llm)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Ask for a single greedy token and keep the 100 most likely candidates for that\n",
"# first answer token; they are extended and charted in the following cells.\n",
"prompt = (\n",
"    \"<|start_header_id|>user<|end_header_id|>\\n\\n\"\n",
"    \"What game console is Chrono Trigger for? Reply only that<|eot_id|>\"\n",
"    \"<|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n",
")\n",
"\n",
"output = llm(\n",
"    prompt,\n",
"    max_tokens=1,        # only the first answer token is scored here\n",
"    logprobs=100,        # how many top candidates to report for that token\n",
"    top_k=1,             # greedy sampling\n",
"    repeat_penalty=1.0,  # disable penalties\n",
"    stop=[\"\\n\"],\n",
"    echo=False,\n",
")\n",
"\n",
"choice = output['choices'][0]\n",
"text = choice['text']\n",
"logprobs = choice['logprobs']\n",
"\n",
"print(f'text: \"{text}\"')\n",
"print(f\"tokens: {logprobs['tokens']}\")\n",
"print(f\"token_logprobs: {logprobs['token_logprobs']}\")\n",
"print(f\"top_logprobs: {logprobs['top_logprobs']}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Extend each candidate first token greedily (up to 20 tokens, stopping at a newline)\n",
"# so each bar in the chart is labelled with the answer that token leads to.\n",
"# NOTE(review): `tok` is decoded token text; tokenizing `prompt + tok` again may not\n",
"# reproduce exactly that token - confirm for candidates where it matters.\n",
"\n",
"first_token_logprobs = logprobs['top_logprobs'][-1]\n",
"new_logprobs = {}\n",
"for idx, (tok, logprob) in enumerate(first_token_logprobs.items()):\n",
"    # A separate name keeps `output` from the single-token query above intact.\n",
"    completion = llm(\n",
"        prompt + tok,\n",
"        echo=False,\n",
"        max_tokens=20,\n",
"        repeat_penalty=1.0,  # disable penalties\n",
"        top_k=1,  # greedy continuation\n",
"        stop=[\"\\n\"],\n",
"    )\n",
"    result = tok + completion['choices'][0]['text']\n",
"    # The rank prefix keeps keys unique even if two candidates expand to the same text.\n",
"    new_logprobs[f\"{idx}. {result}\"] = logprob\n",
"\n",
"print(new_logprobs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import re\n",
"\n",
"# Rank the extended candidates from most to least likely.\n",
"ranked_logprobs = dict(sorted(new_logprobs.items(), key=lambda item: item[1], reverse=True))\n",
"print(ranked_logprobs)\n",
"\n",
"token_labels = list(ranked_logprobs.keys())\n",
"logprob_values = list(ranked_logprobs.values())\n",
"\n",
"# Make whitespace that would be invisible in tick labels visible.\n",
"replacement_dict = {\n",
"    \"\\n\": \"↩️\",\n",
"    \"\\t\": \"➡️➡️➡️➡️\"\n",
"}\n",
"for src, dest in replacement_dict.items():\n",
"    token_labels = [label.replace(src, dest) for label in token_labels]\n",
"\n",
"# Strip the whole \"<rank>. \" prefix, including its separator space, that the extension\n",
"# cell added (presumably only to keep dict keys unique); a token's own leading space is kept.\n",
"token_labels = [re.sub(r'^\\d+\\. ', '', label) for label in token_labels]\n",
"\n",
"# Scale the height with the number of bars (40 in for 100 candidates, as before).\n",
"n_bars = len(logprob_values)\n",
"fig, ax = plt.subplots(figsize=(5, max(2, 0.4 * n_bars)))\n",
"ax.margins(y=0.0028)\n",
"ax.barh(range(n_bars), logprob_values, align='center', color='steelblue')\n",
"ax.set_yticks(range(n_bars))\n",
"ax.set_yticklabels(token_labels)\n",
"ax.set_xlabel('logprob')\n",
"ax.set_title(f\"{prompt}…\")\n",
"ax.invert_yaxis()  # most likely candidate at the top\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "jupyter",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment