A shameless ripoff of https://github.com/ml-explore/mlx-lm/pull/148 with gradient accumulation added.
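Example invocation (the script filename dwq_grad_accum.py is assumed; the flags correspond to the argparse options in the file, and --grad-accum-steps sets how many micro-batches are averaged into each optimizer step):

python dwq_grad_accum.py --model Qwen/Qwen3-1.7B --bits 4 --group-size 64 --batch-size 2 --grad-accum-steps 4 --mlx-path mlx_model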
import argparse
import copy
import glob
import shutil
import time
import types
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optimizers
import numpy as np
from mlx.utils import tree_flatten, tree_map
from mlx_lm.tokenizer_utils import TokenizerWrapper
from mlx_lm.tuner.datasets import load_dataset
from mlx_lm.tuner.trainer import iterate_batches
from mlx_lm.tuner.utils import print_trainable_parameters
from mlx_lm.utils import (
    create_model_card,
    fetch_from_hub,
    get_model_path,
    quantize_model,
    save_config,
    save_weights,
)


def dwq_quantize(
    model,
    q_model,
    opt,
    data,
    batch_size: int = 2,
    max_seq_length: int = 2048,
    temperature: float = 0.5,
    dtype: mx.Dtype = mx.bfloat16,
    grad_accum_steps: int = 1,
):
    """Distillation-aware weight quantization with gradient accumulation.

    After every *optimizer* step (i.e. after `grad_accum_steps` micro-batches),
    the function logs the **mean KL-div loss across those micro-batches**
    instead of logging once per micro-batch.
    """
    assert grad_accum_steps > 0, "grad_accum_steps must be >= 1"

    group = mx.distributed.init()
    world_size = group.size()
    rank = group.rank()

    # ──────────────────── Trainable parameters ────────────────────────────────
    def unfreeze(_, m):
        if hasattr(m, "bits") and hasattr(m, "group_size"):
            m.unfreeze(keys=["scales", "biases"], recurse=False)

    q_model.apply_to_modules(unfreeze)
    print_trainable_parameters(q_model)

    # ─────────────────────── Helper functions ────────────────────────────────
    def log_norm(x):
        x = x * (1 / temperature)
        return x - mx.logsumexp(x, axis=-1, keepdims=True)

    def loss_fn(params, x, targets, lengths):
        q_model.update(tree_map(lambda t: t.astype(dtype), params))
        logits = q_model(x).astype(mx.float32)
        losses = nn.losses.kl_div_loss(log_norm(logits), targets, reduction="none")
        mask = mx.arange(targets.shape[1]) < lengths[:, 1:]
        ntoks = mask.sum()
        loss = (mask * losses).sum() / ntoks
        return loss, ntoks

    def calc_grads(inputs, targets, lengths, params):
        (loss, ntoks), grads = mx.value_and_grad(loss_fn)(
            params, inputs, targets, lengths
        )
        grads = nn.average_gradients(grads)
        return loss, ntoks, grads

    # Keep params in fp32 while learning.
    params = tree_map(lambda x: x.astype(mx.float32), q_model.trainable_parameters())

    # Accumulation buffers.
    grad_accum = tree_map(lambda p: mx.zeros_like(p), params)
    accum_counter = 0

    # Logging helpers.
    accum_loss_sum = 0.0
    accum_loss_count = 0
    global_step = 0  # Counts *optimizer* steps.
    tokens = 0
    tic = time.time()

    for it, (batch, lengths) in enumerate(
        iterate_batches(data, batch_size, max_seq_length)
    ):
        # ───────────── Teacher forward pass (no grad) ────────────────────────
        targets = log_norm(model(batch).astype(mx.float32))
        mx.eval(targets)
        mx.clear_cache()

        # ───────────── Student forward/backward ──────────────────────────────
        loss, ntoks, grads = calc_grads(batch, targets, lengths, params)
        mx.eval(loss, grads)
        mx.clear_cache()

        # Distributed reduction for consistent loss/ntoks across devices.
        loss_red = mx.distributed.all_sum(loss, stream=mx.cpu).item() / world_size
        ntoks_red = mx.distributed.all_sum(ntoks, stream=mx.cpu).item()

        tokens += ntoks_red
        accum_loss_sum += loss_red
        accum_loss_count += 1

        # ───────────── Gradient accumulation ────────────────────────────────
        grad_accum = tree_map(lambda a, b: a + b, grad_accum, grads)
        accum_counter += 1

        step_now = accum_counter == grad_accum_steps
        # iterate_batches yields len(data) // batch_size batches in one pass.
        last_batch = it == (len(data) // batch_size) - 1
        if step_now or last_batch:
            scale = accum_counter  # May be < grad_accum_steps on the tail batch.
            avg_grads = tree_map(lambda g: g / scale, grad_accum)
            params = opt.apply_gradients(avg_grads, params)

            # ──────────────── Logging (once per *optimizer* step) ──────────
            avg_step_loss = accum_loss_sum / accum_loss_count
            global_step += 1
            toks_per_sec = tokens / (time.time() - tic)
            if rank == 0:
                print(
                    f"step={global_step}, avg_loss={avg_step_loss:.3f}, "
                    f"tokens={tokens}, toks_per_sec={toks_per_sec:.3f}",
                    flush=True,
                )

            # Reset accumulators for the next optimizer step.
            grad_accum = tree_map(lambda p: mx.zeros_like(p), grad_accum)
            accum_counter = 0
            accum_loss_sum = 0.0
            accum_loss_count = 0

    # Push learned params back into the student model (cast to final dtype).
    q_model.update(tree_map(lambda x: x.astype(dtype), params))


def save_model(
    model: nn.Module,
    tokenizer: TokenizerWrapper,
    config,
    model_path: Path,
    mlx_path: str,
    hf_path: str,
):
    weights = dict(tree_flatten(model.parameters()))
    mlx_path = Path(mlx_path)
    save_weights(mlx_path, weights, donate_weights=True)
    for file in glob.glob(str(model_path / "*.py")):
        shutil.copy(file, mlx_path)
    tokenizer.save_pretrained(mlx_path)
    save_config(config, config_path=mlx_path / "config.json")
    create_model_card(mlx_path, hf_path)


def load_data(tokenizer, data_path: str, num_samples: int):
    args = types.SimpleNamespace(
        hf_dataset={
            "path": data_path,
            "train_split": f"train[:{num_samples}]",
            "valid_split": "train[:1]",
        },
        train=True,
        test=False,
    )
    dataset = load_dataset(args, tokenizer)[0]
    return [dataset.process(d) for d in dataset]


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", "-m", default="Qwen/Qwen3-1.7B")
    parser.add_argument("--quantized-model", default=None)
    parser.add_argument("--mlx-path", default="mlx_model", help="Path to save the quantized model.")
    parser.add_argument("--bits", type=int, default=4, help="Bits per weight for quantization.")
    parser.add_argument("--group-size", type=int, default=64, help="Group size for quantization.")
    parser.add_argument("--num-samples", type=int, default=1024, help="Number of samples to use for training.")
    parser.add_argument("--max-seq-length", type=int, default=2048)
    parser.add_argument("--seed", type=int, default=123)
    parser.add_argument("--learning-rate", type=float, default=1e-5)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--data-path", type=str, default="allenai/tulu-3-sft-mixture", help="HuggingFace dataset path.")
    parser.add_argument("--temperature", type=float, default=0.5, help="Temperature scaling for the loss.")
    parser.add_argument("--grad-accum-steps", type=int, default=1, help="Micro-batches per optimizer step.")
    args = parser.parse_args()

    group = mx.distributed.init()
    # Round num_samples up to a multiple of the world size so every rank
    # receives the same number of calibration samples.
    num_samples = args.num_samples
    if num_samples % group.size() > 0:
        num_samples += group.size() - num_samples % group.size()

    np.random.seed(args.seed)
    mx.random.seed(args.seed)

    model_path = get_model_path(args.model, revision=None)
    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
    # Use the padded num_samples (a multiple of the world size) computed above.
    calibration_data = load_data(tokenizer, args.data_path, num_samples)
    if args.quantized_model is not None:
        q_model_path = get_model_path(args.quantized_model, revision=None)
        q_model, config, _ = fetch_from_hub(q_model_path, lazy=True)
    else:
        q_model = copy.deepcopy(model)
        _, config = quantize_model(q_model, config, q_group_size=args.group_size, q_bits=args.bits)

    opt = optimizers.Adam(learning_rate=args.learning_rate, bias_correction=True)
    dwq_quantize(
        model,
        q_model,
        opt,
        calibration_data,
        batch_size=args.batch_size,
        max_seq_length=args.max_seq_length,
        temperature=args.temperature,
        grad_accum_steps=args.grad_accum_steps,
    )
    save_model(q_model, tokenizer, config, model_path, args.mlx_path, args.model)


if __name__ == "__main__":
    main()