Low-level libllama test script
import os
import sys
import ctypes
import numpy as np
from easy_llama import libllama as lib
# -------------------------------------------------------------------------------------------- #
# set the shared library path
LIBLLAMA = '/Users/dylan/Documents/AI/llama.cpp/build/bin/libllama.dylib'
# set the GGUF file path
PATH_MODEL = '/Users/dylan/Documents/AI/models/Fireball-Meta-Llama-3.1-8B-Instruct-Agent-0.003-128K.Q5_K_M.gguf'
# enable or disable pure CPU computation
# (this also overrides n_gpu_layers and offload_kqv)
FORCE_CPU_ONLY = True
# configure model params
USE_MMAP = True
USE_MLOCK = False
N_GPU_LAYERS = 0
# configure context params
N_CTX = 4096
N_BATCH = 2048
N_UBATCH = 512
N_THREADS = 4
N_THREADS_BATCH = 8
TYPE_K = lib.GGMLType.GGML_TYPE_F16
TYPE_V = lib.GGMLType.GGML_TYPE_F16
OFFLOAD_KQV = True
FLASH_ATTN = False
# set the input text
INPUT_TXT = """<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
What is the capital of France?<|im_end|>
<|im_start|>assistant
"""
# -------------------------------------------------------------------------------------------- #
# misc.
MAX_TOKEN_LENGTH = 256
# load libllama
os.environ['LIBLLAMA'] = LIBLLAMA
lib.llama_backend_init()
# create model params
model_params = lib.llama_model_default_params()
if FORCE_CPU_ONLY:
    devices = []
    model_params.devices = (ctypes.c_void_p * (len(devices) + 1))(*devices, None)
model_params.use_mmap = USE_MMAP
model_params.use_mlock = USE_MLOCK
if not FORCE_CPU_ONLY:
    model_params.n_gpu_layers = N_GPU_LAYERS
# load the model
llama_model = lib.llama_model_load_from_file(PATH_MODEL, model_params)
llama_vocab = lib.llama_model_get_vocab(llama_model)
# create context params
context_params = lib.llama_context_default_params()
context_params.n_ctx = N_CTX
context_params.n_batch = N_BATCH
context_params.n_ubatch = N_UBATCH
context_params.n_threads = N_THREADS
context_params.n_threads_batch = N_THREADS_BATCH
context_params.type_k = TYPE_K
context_params.type_v = TYPE_V
if not FORCE_CPU_ONLY:
    context_params.offload_kqv = OFFLOAD_KQV
context_params.flash_attn = FLASH_ATTN
# load the context
llama_context = lib.llama_init_from_model(llama_model, context_params)
# tokenize the input text
text_bytes = INPUT_TXT.encode()
n_tokens_max = N_CTX
tokens_buf = (ctypes.c_int32 * n_tokens_max)()
n_tokens = lib.llama_tokenize(vocab=llama_vocab, text=text_bytes, text_len=len(text_bytes),
                              tokens=tokens_buf, n_tokens_max=n_tokens_max, add_special=False,
                              parse_special=False)
if n_tokens < 0:
    raise ValueError(f'n_tokens value {-n_tokens} exceeds n_tokens_max value {n_tokens_max}')
tokens = list(tokens_buf[:n_tokens])
print(f'tokens: {tokens}')
# create a batch with the input tokens
llama_batch = lib.llama_batch_init(n_tokens=n_tokens, embd=0, n_seq_max=1)
llama_batch.n_tokens = n_tokens
for i in range(n_tokens):
    llama_batch.token[i] = tokens[i]
    llama_batch.pos[i] = i
    llama_batch.seq_id[i][0] = 0
    llama_batch.n_seq_id[i] = 1
    llama_batch.logits[i] = True
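# note (added comment): setting llama_batch.logits[i] = True asks llama_decode to produce
# logits for every position in the batch; only the final position is strictly needed for
# next-token prediction, but keeping all rows makes the logits array below easier to inspect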
# decode the batch
status_code = lib.llama_decode(llama_context, llama_batch)
if status_code != 0:
    raise RuntimeError(f'llama_decode failed with status code {status_code}')
# get the resulting logits from the decoded batch
ctypes_logits = lib.llama_get_logits(llama_context)
n_vocab = lib.llama_vocab_n_tokens(llama_vocab)
logits = np.ctypeslib.as_array(ctypes_logits, shape=(n_tokens, n_vocab))
print(f'logits shape: {np.shape(logits)} (dtype {logits.dtype})')
# find the most likely next token
tok_id = np.argmax(logits[n_tokens - 1])
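# optional sketch (not in the original script): turn the final row of logits into
# probabilities with a numerically stable softmax so the greedy pick can be sanity-checked
last_row = logits[n_tokens - 1]
probs = np.exp(last_row - np.max(last_row))
probs /= np.sum(probs)
print(f'probability of most likely next token: {probs[tok_id]:.4f}')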
# get the string representation of that token
detok_buffer = ctypes.create_string_buffer(MAX_TOKEN_LENGTH)
n_bytes = lib.llama_token_to_piece(llama_vocab, tok_id, detok_buffer, MAX_TOKEN_LENGTH, 0, True)
if n_bytes < 0:
    raise ValueError(f"token_to_piece: the token with ID {tok_id} requires a buffer of size "
                     f"{-n_bytes}, but the maximum buffer size is {MAX_TOKEN_LENGTH}")
tok_bytes = detok_buffer.raw[:n_bytes]
print(f'most likely next token: {tok_id} ({tok_bytes})')
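# optional cleanup (not in the original script): free the batch, context, and model, then
# tear down the backend; assumes easy_llama.libllama exposes these standard llama.cpp calls
lib.llama_batch_free(llama_batch)
lib.llama_free(llama_context)
lib.llama_model_free(llama_model)
lib.llama_backend_free()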