# Grr. Can't really use jupyter-as-script convention: https://twitter.com/uogbuji/status/1829325187965170154
# %%
'''
test/llm_struct_probe.py
'''
import asyncio


# Note: Model and ToolCallResponder are defined further down in this file;
# they're only referenced at run time, from the calls at the bottom.
class llm_struct_probe:
    def __init__(self, model_path):
        self.model_path = model_path
        self.model = Model()
        self.model.load(model_path)
        self.model_type = self.model.model.model_type

    async def test(self):
        async for chunk in self.test1():
            print(chunk, end='')

    async def test1(self):
        sysprompt = ('You are a helpful assistant with access to a set of tool which you may '
                     "invoke to help respond to the user's request.\n"
                     "You may also choose not to use any of the tools, if you're sure they're not "
                     'useful for this response. In that case, you can fill out the\n'
                     '`toolio_none` pattern for your response\n'
                     '\n'
                     '\n'
                     'Tool name: today_kfabe\n'
                     ' Get the current date\n'
                     'Invocation schema:\n'
                     '{"type": "object", "properties": {"name": {"type": "const", "const": '
                     '"today_kfabe"}, "arguments": {"type": "object", "properties": {}, '
                     '"required": []}}, "required": ["name", "arguments"]}\n'
                     '\n'
                     'Tool name: toolio_none\n'
                     ' Call this tool to indicate that no other provided tool is useful for '
                     'responding to the user\n'
                     'Invocation schema:\n'
                     '{"type": "object", "properties": {"name": {"type": "const", "const": '
                     '"toolio_none"}, "arguments": {"type": "object", "properties": {"response": '
                     '{"type": "string", "description": "Your normal response to the user"}}}}, '
                     '"required": ["name", "arguments"]}\n'
                     'Your answer is a JSON array with one or more tool invocations according to '
                     'the appropriate schema(s),\n'
                     'or it follows the `toolio_none` pattern, as appropriate to respond to the '
                     "user's prompt below.\n")
        prompt = 'Write me a haiku about AI'
        tool_schemas = [
            {'properties': {'arguments': {'properties': {},
                                          'required': [],
                                          'type': 'object'},
                            'name': {'const': 'today_kfabe',
                                     'type': 'const'}},
             'required': ['name', 'arguments'],
             'type': 'object'},
            {'properties': {'arguments': {'properties': {'response': {'description': 'Your normal response to the user',
                                                                      'type': 'string'}},
                                          'type': 'object'},
                            'name': {'const': 'toolio_none',
                                     'type': 'const'}},
             'required': ['name', 'arguments'],
             'type': 'object'}]
        full_schema = {'type': 'array', 'items': {'anyOf': tool_schemas}}
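        # Illustrative only: a completion conforming to full_schema would look something like
        # '[{"name": "toolio_none", "arguments": {"response": "Silicon minds hum..."}}]'
        # (the haiku text here is made up; only the JSON shape is constrained by the schema).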
        messages = [{'role': 'system', 'content': sysprompt}, {'role': 'user', 'content': prompt}]
        # ToolCallResponder's second parameter is declared as `functions: list[dict]`,
        # but it's unused in its __init__, so passing the model type through is harmless here.
        responder = ToolCallResponder(self.model_path, self.model_type)
        prompt_tokens = None
        for result in self.model.completion(messages, full_schema, max_tokens=1024, temp=0.0, cache_prompt=False):
            if result['op'] == 'evaluatedPrompt':
                prompt_tokens = result['token_count']
            elif result['op'] == 'generatedTokens':
                message = responder.generated_tokens(result['text'])
                if message:
                    yield message
            elif result['op'] == 'stop':
                completion_tokens = result['token_count']
                yield responder.generation_stopped(
                    result['reason'], prompt_tokens, completion_tokens
                )
            else:
                raise RuntimeError(f'Unknown result operation {result["op"]}')


# Below is just most of https://github.com/otriscon/llm-structured-output/blob/main/src/examples/llm_schema.py
"""
Example of JSON schema decoding with MLX.
"""
import argparse
import json
import time
from math import inf
from operator import itemgetter
from typing import Iterable, Optional, Union

import mlx.core as mx
import mlx.nn as nn
from mlx_lm.utils import load

from llm_structured_output import JsonSchemaAcceptorDriver
from llm_structured_output.util.bitmap import (
    bias_logits,
    count_set_bits,
    enumerate_set_bits,
)
from llm_structured_output.util.output import info, bold, bolddim, debug
from llm_structured_output.util.tokenization import HuggingfaceTokenizerHelper


class RejectedCompletion(Exception):
    """
    It's rare, but sometimes we reach a state where it's not possible to
    advance the acceptor. For example, when closing a JSON string we get
    a higher probability for slanted quotes than straight ones and select
    the wrong token. At that point, the LLM will continue generating with
    the prior that the string is closed, but our acceptor will remain in
    the string-accepting state. This can indicate an issue with the
    tokenizer vocabulary passed to the acceptor, or a bug in the code
    used to decode tokens from the LLM. If none of these apply, check that
    the LLM is actually able to generate JSON, although most are.
    """


class Model:
    def __init__(self):
        mx.random.seed(0)
        self.model = None
        self.tokenizer = None
        self.vocabulary = None
        self.eos_id = None
        self.json_schema_acceptor_driver_factory = None
        self._cached_prompt = None
        self._cached_cache = None

    def load(self, model_path: str):
        """
        Load locally or download from Huggingface hub.
        """
        self.model, tokenizer = load(model_path)
        self.tokenizer = HuggingfaceTokenizerHelper(tokenizer)
        self.vocabulary, self.eos_id = self.tokenizer.extract_vocabulary()
        self.json_schema_acceptor_driver_factory = (
            JsonSchemaAcceptorDriver.driver_factory_for_model(
                self.vocabulary, self.eos_id
            )
        )

    def get_driver_for_json_schema(self, schema, encapsulated: bool = False):
        return self.json_schema_acceptor_driver_factory(
            schema, is_encapsulated_json=encapsulated
        )

    def _evaluate_prompt(
        self, prompt: list[int], prior_prompt: list[int] = None, prior_cache=None
    ):
        if prior_prompt:
            i = 0
            for i, t in enumerate(prior_prompt):
                # We need to leave at least one token to evaluate because we don't
                # save the past logits.
                if i >= len(prompt) - 1 or prompt[i] != t:
                    break
            cache = prior_cache
            for layer_cache in cache:
                layer_cache.reuse(len(prompt), i)
            tokens = prompt[i:]
        else:
            cache = ReusableKVCache.for_model(self.model)
            tokens = prompt
        logits = self.model(mx.array(tokens)[None], cache)
        return logits, cache

    def _decode(self, tokens):
        return self.tokenizer.no_strip_decode(tokens)

    def _debug_top_tokens(self, logits, count=10):
        token_logits = sorted(
            enumerate(logits.tolist()), key=itemgetter(1), reverse=True
        )
        top_tokens = [
            (self._decode([t]), p) for t, p in token_logits[:count] if p != -inf
        ]
        debug("TOP TOKENS:", top_tokens)

    def _sample(self, logits, temp: float = 0):
        if temp == 0:
            result = mx.argmax(logits, axis=-1)
        else:
            result = mx.random.categorical(logits * (1 / temp))
        return result.item()

    def _sample_with_bias(
        self, logits, temp: float = 0, token_acceptor=None, lazy_bias: bool = True
    ):
        if token_acceptor is None:
            return self._sample(logits, temp)
        if lazy_bias:
            token = self._sample(logits, temp)
            try:
                token_acceptor.advance_token(token)
                return token
            except JsonSchemaAcceptorDriver.TokenRejected:
                pass
        accepted_token_bitmap = token_acceptor.select_valid_tokens()
        if not accepted_token_bitmap:
            debug(token_acceptor.cursors)
            self._debug_top_tokens(logits)
            raise RejectedCompletion()
        token = self._sample(bias_logits(mx, logits, accepted_token_bitmap), temp)
        token_acceptor.advance_token(token)
        return token

    def generate_without_schema(self, logits, cache, temp: Optional[float] = 0.0):
        """
        For testing / comparison purposes.
        """
        while True:
            tokens = [self._sample(logits[0, -1, :], temp)]
            yield tokens
            if tokens[-1] == self.eos_id:
                break
            logits = self.model(mx.array(tokens)[None], cache)

    def generate_with_schema(
        self, logits, cache, token_acceptor, temp: Optional[float] = 0.0
    ):
        while True:
            tokens = [self._sample_with_bias(logits[0, -1, :], temp, token_acceptor)]
            yield tokens
            if tokens[-1] == self.eos_id:
                break
            logits = self.model(mx.array(tokens)[None], cache)

    def generate_with_preemptive_decoding(
        self,
        logits,
        cache,
        token_acceptor,
        temp: Optional[float] = 0.0,
        max_batch_size=5,
    ):
        """
        Try to generate faster by precomputing two tokens at a time when possible.
        If we know that the acceptor will only accept a small set of tokens after
        the current one, we can evaluate a batch with one entry per possible
        future token. Each entry in the batch contains the current token sampled,
        which we have to evaluate anyway, and a second token corresponding to one
        of the possible tokens that could be sampled from the output to the first
        token. We get back logits for both tokens for each item in the batch: the
        logits for the first token will be the same (as long as the model applies
        a causal mask), and we can sample those logits to select from which of the
        items in the batch we can select the second token.
        In practice, this only seems to accelerate things for unquantized models.
        """
        # Sample token from prompt evaluation
        accepted_token_bitmap = token_acceptor.select_valid_tokens()
        first_token_logits = bias_logits(mx, logits[0, -1, :], accepted_token_bitmap)
        first_token = self._sample(first_token_logits, temp)
        tokens = [first_token]
        yield tokens
        token_acceptor.advance_token(first_token)
        accepted_token_bitmap = token_acceptor.select_valid_tokens()

        while True:
            last_token = tokens[-1]
            if count_set_bits(accepted_token_bitmap) in range(1, max_batch_size + 1):
                # If the number of possible follow-up tokens is small, submit for
                # evaluation a batch of 2-token continuations.
                batch = []
                for followup_token in enumerate_set_bits(accepted_token_bitmap):
                    batch.append([last_token, followup_token])
                # Re-shape the cache to match the input.
                for layer_cache in cache:
                    layer_cache.keys = mx.concatenate([layer_cache.keys] * len(batch))
                    layer_cache.values = mx.concatenate(
                        [layer_cache.values] * len(batch)
                    )
            else:  # Otherwise, submit the normal one-token continuation.
                batch = [[last_token]]
            logits = self.model(mx.array(batch), cache)
            mx.eval(logits)
            first_token_logits = bias_logits(mx, logits[0, 0, :], accepted_token_bitmap)
            first_token = self._sample(first_token_logits, temp)
            tokens = [first_token]
            if first_token == self.eos_id:
                yield tokens
                break
            token_acceptor.advance_token(first_token)
            accepted_token_bitmap = token_acceptor.select_valid_tokens()
            if not accepted_token_bitmap:
                raise RejectedCompletion()
            # If we had submitted 2-token continuations, we can decode a second token
            if len(batch[0]) > 1:
                index = next(  # Find which of the second tokens was selected
                    i
                    for i, batch_item in enumerate(batch)
                    if batch_item[1] == first_token
                )
                second_token_logits = bias_logits(
                    mx, logits[index, 1, :], accepted_token_bitmap
                )
                second_token = self._sample(second_token_logits, temp)
                tokens.append(second_token)
                token_acceptor.advance_token(second_token)
                accepted_token_bitmap = token_acceptor.select_valid_tokens()
                # Select the accepted generation in the cache, restoring it to batch dimension 1.
                for layer_cache in cache:
                    layer_cache.keys = layer_cache.keys.split([index, index + 1])[1]
                    layer_cache.values = layer_cache.values.split([index, index + 1])[1]
            yield tokens

    def _generate_tokens(
        self,
        generator: Iterable,
        max_tokens: int = 1000,
    ) -> Iterable:
        start_time = time.time_ns()
        token_count = 0
        for tokens in generator:
            token_count += len(tokens)

            try:
                eos_index = tokens.index(self.eos_id)
                tokens = tokens[0:eos_index]
            except ValueError:
                eos_index = -1

            if tokens:
                text = self._decode(tokens)
                yield {
                    "op": "generatedTokens",
                    "text": text,
                    "token_count": len(tokens),
                    "time_ms": (time.time_ns() - start_time) / 1e6,
                }

            if eos_index >= 0:
                yield {"op": "stop", "reason": "end"}
                return

            if token_count >= max_tokens:
                yield {"op": "stop", "reason": "max_tokens"}
                return

            start_time = time.time_ns()
        assert False  # Unreachable: a "stop" event is always emitted before the generator is exhausted.

    def completion(
        self,
        prompt: Union[str, Iterable[dict[str, str]]],
        schema: dict,
        encapsulated: bool = False,
        max_tokens: int = 1000,
        temp: float = 0.0,
        seed: int = None,
        preemptive_batch_size: int = 0,
        cache_prompt: bool = False,
    ):
        if seed is not None:
            mx.random.seed(seed)

        start_time = time.time_ns()
        prompt_tokens = self.tokenizer.encode_prompt(prompt)
        logits, cache = self._evaluate_prompt(
            prompt_tokens, self._cached_prompt, self._cached_cache
        )
        if cache_prompt:
            self._cached_prompt = prompt_tokens
            self._cached_cache = cache
        # Eager eval to more accurately reflect the prompt evaluation time.
        mx.eval(logits)
        prompt_time = time.time_ns() - start_time
        yield {
            "op": "evaluatedPrompt",
            "prompt": prompt,
            "token_count": len(prompt_tokens),
            "time_ms": prompt_time / 1e6,
            "prompt_tps": len(prompt_tokens) / (prompt_time / 1e9),
        }

        if schema:
            token_acceptor = self.get_driver_for_json_schema(schema, encapsulated)
            if preemptive_batch_size > 0:
                generator = self.generate_with_preemptive_decoding(
                    logits,
                    cache,
                    token_acceptor,
                    temp,
                    max_batch_size=preemptive_batch_size,
                )
            else:
                generator = self.generate_with_schema(
                    logits, cache, token_acceptor, temp
                )
        else:
            generator = self.generate_without_schema(logits, cache, temp)

        token_count = 0
        generation_time = 0
        for generation_result in self._generate_tokens(generator, max_tokens):
            if generation_result["op"] == "generatedTokens":
                token_count += generation_result["token_count"]
                generation_time += generation_result["time_ms"]
            elif generation_result["op"] == "stop":
                generation_result["token_count"] = token_count
                generation_result["time_ms"] = generation_time
                # This is slightly incorrect, because the first token is generated
                # from the prompt evaluation.
                generation_result["generation_tps"] = token_count / (
                    generation_time / 1e3
                )
            yield generation_result


# Below is a big reduction from https://github.com/otriscon/llm-structured-output/blob/main/src/examples/server.py
class ChatCompletionResponder:
    def __init__(self, model_name: str):
        self.object_type = "chat.completion"
        self.model_name = model_name
        self.created = int(time.time())
        self.id = f"{id(self)}_{self.created}"
        self.content = ""

    def message_properties(self):
        return {
            "object": self.object_type,
            "id": f"chatcmpl-{self.id}",
            "created": self.created,
            "model": self.model_name,
        }

    def translate_reason(self, reason):
        """
        Translate our reason codes to OpenAI ones.
        """
        if reason == "end":
            return "stop"
        if reason == "max_tokens":
            return "length"
        return f"error: {reason}"  # Not a standard OpenAI API reason

    def format_usage(self, prompt_tokens: int, completion_tokens: int):
        return {
            "usage": {
                "completion_tokens": completion_tokens,
                "prompt_tokens": prompt_tokens,
                "total_tokens": completion_tokens + prompt_tokens,
            },
        }

    def generated_tokens(
        self,
        text: str,
    ):
        self.content += text
        return None

    def generation_stopped(
        self,
        stop_reason: str,
        prompt_tokens: int,
        completion_tokens: int,
    ):
        finish_reason = self.translate_reason(stop_reason)
        message = {"role": "assistant", "content": self.content}
        return {
            "choices": [
                {"index": 0, "message": message, "finish_reason": finish_reason}
            ],
            **self.format_usage(prompt_tokens, completion_tokens),
            **self.message_properties(),
        }


class ToolCallResponder(ChatCompletionResponder):
    def __init__(self, model_name: str, functions: list[dict]):
        # `functions` is accepted for interface parity with the original server code,
        # but it isn't used in this reduced version.
        super().__init__(model_name)

    def translate_reason(self, reason):
        if reason == "end":
            return "tool_calls"
        return super().translate_reason(reason)

    def generation_stopped(
        self,
        stop_reason: str,
        prompt_tokens: int,
        completion_tokens: int,
    ):
        finish_reason = self.translate_reason(stop_reason)
        if finish_reason == "tool_calls":
            tool_calls = json.loads(self.content)
            if not isinstance(tool_calls, list):
                # len(functions) == 1 was special cased
                tool_calls = [tool_calls]
            message = {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": f"call_{self.id}_{i}",
                        "type": "function",
                        "function": {
                            "name": function_call["name"],
                            "arguments": json.dumps(function_call["arguments"]),
                        },
                    }
                    for i, function_call in enumerate(tool_calls)
                ],
            }
        elif finish_reason == "function_call":
            function_call = json.loads(self.content)
            message = {
                "role": "assistant",
                "function_call": {
                    "name": function_call["name"],
                    "arguments": json.dumps(function_call["arguments"]),
                },
            }
        else:
            message = None
        return {
            "choices": [
                {"index": 0, "message": message, "finish_reason": finish_reason}
            ],
            **self.format_usage(prompt_tokens, completion_tokens),
            **self.message_properties(),
        }
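

# Illustrative example (not part of the original): if the constrained generation
# accumulated content like '[{"name": "today_kfabe", "arguments": {}}]', then
# generation_stopped('end', ...) returns a chat.completion-style dict whose
# message resembles:
#   {'role': 'assistant',
#    'tool_calls': [{'id': 'call_<id>_0', 'type': 'function',
#                    'function': {'name': 'today_kfabe', 'arguments': '{}'}}]}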


# Below from src/examples/reusable_kv_cache.py
from mlx_lm.models.base import KVCache


class ReusableKVCache(KVCache):
    """
    Usability improvements over KVCache.
    """

    @classmethod
    def for_model(cls, model):
        kv_heads = (
            [model.n_kv_heads] * len(model.layers)
            if isinstance(model.n_kv_heads, int)
            else model.n_kv_heads
        )
        return [cls(model.head_dim, n) for n in kv_heads]

    def reuse(self, new_prompt_length, common_prefix_length):
        """
        Reuse (part of) this cache for a new prompt that shares a prefix with it.
        """
        if self.keys is None:
            return
        # Clip the cache to the common length.
        self.offset = common_prefix_length
        # Make sure the cache can fit the whole prompt. Because the offset is
        # (very likely) not a multiple of the step size, update_and_fetch()
        # won't resize the cache when evaluating the rest of the prompt as it
        # would if it were an empty cache.
        current_size = self.keys.shape[2]
        if current_size < new_prompt_length:
            n_steps = (self.step + new_prompt_length - 1) // self.step
            k_add_shape = (1, self.n_kv_heads, n_steps * self.step - current_size, self.k_head_dim)
            v_add_shape = (1, self.n_kv_heads, n_steps * self.step - current_size, self.v_head_dim)
            k_zeros = mx.zeros(k_add_shape, self.keys.dtype)
            v_zeros = mx.zeros(v_add_shape, self.values.dtype)
            self.keys = mx.concatenate([self.keys, k_zeros], axis=2)
            self.values = mx.concatenate([self.values, v_zeros], axis=2)

    def update_and_fetch(self, keys, values):
        """
        Override the base class method to allow the cache to be used with batches of
        size greater than 1.
        This is just a tiny change in the line that determines the shape.
        """
        prev = self.offset
        if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
            n_steps = (self.step + keys.shape[2] - 1) // self.step
            k_shape = (keys.shape[0], self.n_kv_heads, n_steps * self.step, self.k_head_dim)
            v_shape = (keys.shape[0], self.n_kv_heads, n_steps * self.step, self.v_head_dim)
            new_k = mx.zeros(k_shape, keys.dtype)
            new_v = mx.zeros(v_shape, values.dtype)
            if self.keys is not None:
                if prev % self.step != 0:
                    self.keys = self.keys[..., :prev, :]
                    self.values = self.values[..., :prev, :]
                self.keys = mx.concatenate([self.keys, new_k], axis=2)
                self.values = mx.concatenate([self.values, new_v], axis=2)
            else:
                self.keys, self.values = new_k, new_v

        self.offset += keys.shape[2]
        self.keys[..., prev : self.offset, :] = keys
        self.values[..., prev : self.offset, :] = values
        return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]


lsp = llm_struct_probe('mlx-community/Hermes-2-Theta-Llama-3-8B-4bit')
# lsp = llm_struct_probe('mlx-community/Mistral-Nemo-Instruct-2407-4bit')
asyncio.run(lsp.test())