Created
May 7, 2025 08:23
-
-
Save mzbac/06c932b0884771fa59e786b348ddbb5e to your computer and use it in GitHub Desktop.
MLX LLM embedding
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import mlx.core as mx | |
import numpy as np | |
from transformers import PreTrainedTokenizer, AutoModel, AutoTokenizer | |
import torch | |
import torch.nn.functional as F | |
from torch import Tensor | |
from typing import List, Dict, Any | |
from mlx_lm.utils import load | |
def tokenize_texts(
    tokenizer: PreTrainedTokenizer,
    sentences: List[str],
    max_length: int
) -> Dict[str, mx.array]:
    """Tokenize *sentences* into MLX arrays.

    Sequences are truncated to ``max_length`` and padded to the longest
    sequence in the batch (``padding=True``), not to ``max_length``.

    Args:
        tokenizer: Hugging Face tokenizer used for both the MLX and HF paths.
        sentences: Input texts. An empty (falsy) value short-circuits.
        max_length: Truncation length; also the column count of the empty
            placeholder arrays returned for empty input.

    Returns:
        Mapping with at least ``"input_ids"`` and ``"attention_mask"``. For
        empty input, both are int32 zero arrays of shape ``(0, max_length)``.
    """
    if not sentences:
        placeholder_shape = (0, max_length)
        return {
            field: mx.zeros(placeholder_shape, dtype=mx.int32)
            for field in ("input_ids", "attention_mask")
        }

    return tokenizer(
        sentences,
        max_length=max_length,
        padding=True,
        truncation=True,
        return_tensors="mlx",
    )
def encode_batch(
    model: Any,
    batch_mx: Dict[str, mx.array]
) -> mx.array:
    """Run the inner transformer of an ``mlx_lm`` model on a tokenized batch.

    Calls ``model.model`` (the backbone, not the LM head) and returns its
    output unchanged. The caller treats that output as the last hidden state.

    NOTE(review): only ``batch_mx["input_ids"]`` is forwarded. The attention
    mask is not passed, so padding positions are not masked inside the MLX
    model. This is presumably harmless for right-padded batches under causal
    attention, but not for left padding. Confirm against the tokenizer's
    ``padding_side``.
    """
    token_ids = batch_mx["input_ids"]
    hidden_states = model.model(token_ids)
    return hidden_states
def pool_last_token_simple(
    last_hidden_state: mx.array,
    attention_mask: mx.array
) -> mx.array:
    """Pick each sequence's hidden state at its last non-padding token.

    The index is taken from the mask itself: the highest position whose mask
    value is non-zero. This works for right- and left-padded batches alike.
    For contiguous right padding it equals ``sum(mask) - 1``, which is what
    the previous implementation computed.

    Args:
        last_hidden_state: Hidden states indexed ``[batch, position, ...]``.
            Presumably ``(batch, seq_len, hidden)`` from ``encode_batch``;
            confirm.
        attention_mask: ``(batch, seq_len)`` mask where non-zero marks a real
            token.

    Returns:
        ``last_hidden_state[b, idx[b]]`` for every row ``b``. A row whose mask
        is entirely zero falls back to position 0, as before.
    """
    batch_size = last_hidden_state.shape[0]
    seq_len = attention_mask.shape[1]
    # 1-based positions: masked-out slots contribute 0 to the max, and
    # subtracting 1 afterwards gives the 0-based index (-1 for an all-zero row).
    positions = mx.arange(1, seq_len + 1)
    last_token_indices = mx.max((attention_mask != 0) * positions, axis=1) - 1
    # Clamp the all-zero-row case to 0 instead of wrapping to the last slot.
    last_token_indices = mx.maximum(last_token_indices, 0)
    return last_hidden_state[mx.arange(batch_size), last_token_indices]
def normalize_embeddings(
    embeddings: mx.array
) -> mx.array:
    """Scale each vector along the last axis to unit L2 norm.

    The norm is floored at 1e-9 so all-zero vectors do not cause a division
    by zero. (The PyTorch reference path uses ``F.normalize``, whose default
    floor is 1e-12. The difference only matters for near-zero vectors.)
    """
    epsilon = 1e-9
    magnitudes = mx.linalg.norm(embeddings, ord=2, axis=-1, keepdims=True)
    safe_magnitudes = mx.maximum(magnitudes, epsilon)
    return embeddings / safe_magnitudes
def hf_last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """PyTorch reference for last-token pooling.

    Selects, for each row, the hidden state at the highest position whose
    mask value is non-zero. This handles both right- and left-padded
    batches. A row whose mask is all zeros falls back to position 0.
    Previously such a row indexed ``-1``, silently wrapping to the final
    (padding) position, and disagreed with the MLX implementation, which
    clamps to 0.

    Args:
        last_hidden_states: Tensor indexed ``[batch, position, ...]``.
        attention_mask: ``(batch, seq_len)`` mask; non-zero marks real tokens.

    Returns:
        Tensor of shape ``last_hidden_states.shape[:1] + last_hidden_states.shape[2:]``.
    """
    seq_len = attention_mask.shape[1]
    # 1-based positions so masked slots contribute 0 to the max.
    positions = torch.arange(1, seq_len + 1, device=attention_mask.device)
    last_indices = ((attention_mask != 0) * positions).max(dim=1).values - 1
    last_indices = last_indices.clamp(min=0)
    batch_size = last_hidden_states.shape[0]
    return last_hidden_states[
        torch.arange(batch_size, device=last_hidden_states.device),
        last_indices.to(last_hidden_states.device),
    ]
def _run_mlx_pipeline(
    mlx_model: Any,
    tokenizer: PreTrainedTokenizer,
    sentences: List[str],
    max_length: int
) -> "np.ndarray | None":
    """Tokenize, encode, last-token-pool and L2-normalize with MLX.

    Returns the normalized embeddings as a NumPy array. If any step raises,
    prints the error with its traceback and returns None.
    """
    import traceback

    print("\nRunning MLX Implementation...")
    try:
        batch_mx = tokenize_texts(tokenizer, sentences, max_length)
        print(f"MLX Tokenized input_ids shape: {batch_mx['input_ids'].shape}")
        # Upcast before pooling/normalizing so the math matches the float32 HF path.
        mlx_hidden = encode_batch(mlx_model, batch_mx).astype(mx.float32)
        print(f"MLX Hidden state shape: {mlx_hidden.shape}")
        mlx_pooled = pool_last_token_simple(mlx_hidden, batch_mx['attention_mask'])
        print(f"MLX Pooled shape: {mlx_pooled.shape}")
        mlx_normalized = normalize_embeddings(mlx_pooled)
        print(f"MLX Normalized shape: {mlx_normalized.shape}")
        print(f"MLX Normalized dtype: {mlx_normalized.dtype}")
        # Materialize the result before converting it to NumPy.
        mx.eval(mlx_normalized)
        result = np.array(mlx_normalized, copy=True)
    except Exception as e:
        print(f"Error during MLX execution: {e}")
        traceback.print_exc()
        return None
    print("MLX Implementation finished.")
    print(f"MLX Result NumPy shape: {result.shape}, dtype: {result.dtype}")
    return result


def _run_hf_pipeline(
    hf_model: torch.nn.Module,
    tokenizer: PreTrainedTokenizer,
    sentences: List[str],
    max_length: int,
    device: str
) -> "np.ndarray | None":
    """Compute the reference embeddings with the Hugging Face PyTorch model.

    Moves ``hf_model`` to ``device`` and switches it to eval mode. Both are
    in-place changes to the caller's model. Returns the L2-normalized
    last-token embeddings as a NumPy array. If any step raises, prints the
    error with its traceback and returns None.
    """
    import traceback

    print("\nRunning PyTorch (Hugging Face) Reference...")
    try:
        hf_model.to(device)
        hf_model.eval()
        batch_pt = tokenizer(
            sentences,
            max_length=max_length,
            padding=True,
            truncation=True,
            return_tensors="pt"
        ).to(device)
        print(f"HF Tokenized input_ids shape: {batch_pt['input_ids'].shape}")
        with torch.no_grad():
            outputs = hf_model(**batch_pt)
        hf_hidden = outputs.last_hidden_state.to(torch.float32)
        print(f"HF Hidden state shape: {hf_hidden.shape}")
        hf_pooled = hf_last_token_pool(hf_hidden, batch_pt['attention_mask'])
        print(f"HF Pooled shape: {hf_pooled.shape}")
        hf_normalized = F.normalize(hf_pooled, p=2, dim=1)
        print(f"HF Normalized shape: {hf_normalized.shape}")
        print(f"HF Normalized dtype: {hf_normalized.dtype}")
        result = hf_normalized.cpu().numpy()
    except Exception as e:
        print(f"Error during PyTorch execution: {e}")
        traceback.print_exc()
        return None
    print("PyTorch (HF) Implementation finished.")
    return result


def _report_comparison(
    mlx_result_np: np.ndarray,
    hf_result_np: np.ndarray,
    rtol: float,
    atol: float
) -> bool:
    """Print a shape and numerical comparison; return True when they agree."""
    print("\nComparing Results...")
    if mlx_result_np.shape != hf_result_np.shape:
        print("❌ FAILED: Shape mismatch!")
        print(f" MLX Shape: {mlx_result_np.shape}")
        print(f" HF Shape: {hf_result_np.shape}")
        return False
    print(f"✅ Shapes Match: {mlx_result_np.shape}")

    if np.allclose(mlx_result_np, hf_result_np, rtol=rtol, atol=atol):
        print(f"✅ PASSED: Numerical values are close within tolerance (rtol={rtol}, atol={atol}).")
        return True

    print("❌ FAILED: Numerical values differ significantly!")
    diff = np.abs(mlx_result_np - hf_result_np)
    print(f" Max absolute difference: {np.max(diff)}")
    print(f" Mean absolute difference: {np.mean(diff)}")
    # Both sides are L2-normalized, so the row-wise dot product is the cosine
    # similarity: an easier-to-interpret agreement measure for embeddings.
    cosine = np.sum(mlx_result_np * hf_result_np, axis=-1)
    print(f" Min cosine similarity: {np.min(cosine)}")
    return False


def compare_mlx_hf_embeddings(
    mlx_model: Any,
    hf_model: torch.nn.Module,
    tokenizer: PreTrainedTokenizer,
    sentences: List[str],
    max_length: int,
    rtol: float = 1e-4,
    atol: float = 1e-5,
    device: str = 'cpu'
) -> bool:
    """Embed *sentences* with an MLX model and an HF PyTorch model, then compare.

    Each pipeline tokenizes, takes the last-token hidden state, and
    L2-normalizes it. The two results are compared with
    ``np.allclose(rtol=rtol, atol=atol)``.

    Args:
        mlx_model: Model returned by ``mlx_lm.utils.load``; its ``.model``
            attribute is called directly.
        hf_model: Reference PyTorch model. It is moved to ``device`` and set
            to eval mode in place.
        tokenizer: Tokenizer shared by both pipelines.
        sentences: Texts to embed.
        max_length: Truncation length for both pipelines.
        rtol, atol: Tolerances forwarded to ``np.allclose``.
        device: PyTorch device for the reference model (e.g. ``"cpu"``,
            ``"mps"``).

    Returns:
        True only if both pipelines ran, shapes match, and values are close.
        False otherwise, including when either pipeline raised (the error and
        traceback are printed).
    """
    print("\n--- Starting Comparison ---")
    print(f"Using sentences: {sentences}")
    print(f"Max length: {max_length}")
    print(f"PyTorch Device: {device}")
    print(f"Tolerances: rtol={rtol}, atol={atol}")

    mlx_result_np = _run_mlx_pipeline(mlx_model, tokenizer, sentences, max_length)
    if mlx_result_np is None:
        return False

    hf_result_np = _run_hf_pipeline(hf_model, tokenizer, sentences, max_length, device)
    if hf_result_np is None:
        return False

    passed = _report_comparison(mlx_result_np, hf_result_np, rtol, atol)
    print("\n--- Comparison Finished ---")
    return passed
if __name__ == '__main__':
    # Script entry point: load both models, compare embeddings, and exit
    # non-zero on any failure so the script is usable from CI/shell.
    import sys

    MODEL_NAME = "Alibaba-NLP/gte-Qwen2-7B-instruct"
    MAX_LEN_TEST = 128

    print(f"Loading mlx model '{MODEL_NAME}' ...")
    try:
        mlx_model, _ = load(MODEL_NAME)
    except Exception as e:
        print(f"Failed to load mlx model '{MODEL_NAME}': {e}")
        sys.exit(1)

    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    except Exception as e:
        print(f"Failed to load tokenizer '{MODEL_NAME}': {e}")
        # Bare exit() would report success (status 0) on this failure path.
        sys.exit(1)

    print(f"Loading HF PyTorch Model '{MODEL_NAME}'...")
    try:
        hf_model = AutoModel.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)
        print(hf_model)
    except Exception as e:
        print(f"Failed to load HF PyTorch model '{MODEL_NAME}': {e}")
        sys.exit(1)

    test_sentences = [
        "This is a test sentence.",
        "Let's compare MLX and PyTorch.",
        "Short one.",
        "A significantly longer sentence to test padding and truncation mechanisms effectively."
    ]

    # Prefer the Metal (MPS) backend, but fall back to CPU when it is unavailable.
    pt_device = "mps" if torch.backends.mps.is_available() else "cpu"

    # NOTE(review): the HF model is loaded in bfloat16 (and the MLX weights'
    # dtype is not visible here). Tolerances of 1e-5 may be stricter than
    # bf16 precision allows; confirm this is intended.
    test_passed = compare_mlx_hf_embeddings(
        mlx_model=mlx_model,
        hf_model=hf_model,
        tokenizer=tokenizer,
        sentences=test_sentences,
        max_length=MAX_LEN_TEST,
        device=pt_device,
        rtol=1e-5,
        atol=1e-5
    )
    print(f"\nOverall Test Result: {'PASSED' if test_passed else 'FAILED'}")
    sys.exit(0 if test_passed else 1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment