import argparse
import math
from pathlib import Path
from urllib import request

import mlx.core as mx
import mlx.nn as nn
from mlx_lm.utils import load


def eval_ppl(model, data, batch_size=32):
    """Compute perplexity over all (shifted) tokens in `data`."""
    all_loss = 0.0
    ntoks = 0
    for s in range(0, len(data), batch_size):
        batch = data[s : s + batch_size]
        # Predict token t+1 from tokens [0, t]
        logits = model(batch[:, :-1]).astype(mx.float32)
        losses = nn.losses.cross_entropy(logits, batch[:, 1:])
        all_loss += losses.sum().item()
        ntoks += losses.size
    ppl = math.exp(all_loss / ntoks)
    return ppl


def load_dataset(tokenizer, num_samples: int, sequence_length: int) -> mx.array:
    """Download and tokenize the calibration text, then sample random chunks."""
    save_path = Path.home() / ".cache/mlx-lm/calibration_v5.txt"
    if not save_path.exists():
        save_path.parent.mkdir(parents=True, exist_ok=True)
        url = "https://gist.githubusercontent.com/tristandruyen/9e207a95c7d75ddf37525d353e00659c/raw/571fda718462de863e5a0171078c175420c7649a/calibration_data_v5_rc.txt"
        request.urlretrieve(url, save_path)
    with open(save_path) as fid:
        texts = fid.read()
    tokens = tokenizer.encode(texts, return_tensors="mlx")[0]
    # Select random non-overlapping chunks of `sequence_length` tokens
    tokens = tokens[: (tokens.size // sequence_length) * sequence_length]
    tokens = tokens.reshape(-1, sequence_length)
    segments = mx.random.permutation(tokens.shape[0])[:num_samples]
    return tokens[segments]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", "-m", default="Qwen/Qwen3-1.7B")
    parser.add_argument("--num-samples", type=int, default=32)
    parser.add_argument("--sequence-length", type=int, default=512)
    parser.add_argument("--seed", type=int, default=123)
    args = parser.parse_args()

    mx.random.seed(args.seed)
    model, tokenizer = load(args.model)
    data = load_dataset(tokenizer, args.num_samples, args.sequence_length)
    ppl = eval_ppl(model, data)
    print(f"Original PPL: {ppl:.3f}")
import argparse
import math

import mlx.core as mx
import mlx.nn as nn
from tqdm import tqdm
from mlx_lm.utils import load


def eval_ppl(model, data, batch_size=4):
    """Compute perplexity on the second half of each sequence, using the
    first half as context."""
    all_loss = 0.0
    ntoks = 0
    n_ctx = data.shape[1] // 2
    for s in tqdm(range(0, len(data), batch_size), total=len(data) // batch_size):
        batch = data[s : s + batch_size]
        logits = model(batch[:, :-1]).astype(mx.float32)
        # Only score predictions made after the first `n_ctx` context tokens
        logits = logits[:, n_ctx:]
        targets = batch[:, n_ctx + 1 :]
        losses = nn.losses.cross_entropy(logits, targets)
        all_loss += losses.sum().item()
        ntoks += losses.size
    ppl = math.exp(all_loss / ntoks)
    return ppl


def load_wikitext(
    tokenizer, num_samples: int = 256, sequence_length: int = 1024
) -> mx.array:
    """Tokenize the WikiText-2 training split and sample random chunks."""
    from datasets import load_dataset

    dataset = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="train")
    texts = "\n\n".join(dataset["text"])
    tokens = tokenizer.encode(texts, return_tensors="mlx")[0]
    # Select random non-overlapping chunks of `sequence_length` tokens
    tokens = tokens[: (tokens.size // sequence_length) * sequence_length]
    tokens = tokens.reshape(-1, sequence_length)
    segments = mx.random.permutation(tokens.shape[0])[:num_samples]
    return tokens[segments]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", "-m", default="Qwen/Qwen3-4B-base")
    parser.add_argument("--num-samples", type=int, default=512)
    parser.add_argument("--sequence-length", type=int, default=1024)
    parser.add_argument("--seed", type=int, default=123)
    args = parser.parse_args()

    mx.random.seed(args.seed)
    model, tokenizer = load(args.model)
    # Use sequence_length + 1 tokens per row so every position has a next-token target
    data = load_wikitext(tokenizer, args.num_samples, args.sequence_length + 1)
    ppl = eval_ppl(model, data)
    print(f"PPL: {ppl:.3f}")