
@N8python
Created April 18, 2025 16:07
from mlx.utils import tree_flatten, tree_map
from mlx_lm import load, generate
import mlx.core as mx
from mlx_lm.utils import (
    dequantize_model,
    fetch_from_hub,
    get_model_path,
    quantize_model,
    save_config,
    save_weights,
    upload_to_hub,
)
model, tokenizer = load("gemma-3-27b-it-qat-4bit")
weights = dict(tree_flatten(model.parameters()))
for k, v in weights.items():
    if v.dtype == mx.float16:
        weights[k] = v.astype(mx.bfloat16)
model.load_weights(list(weights.items()))
save_weights("gemma-3-27b-it-qat-4bit", weights, donate_weights=True)
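The pattern in the script is: flatten the parameter tree into a flat dict, cast only the entries of the source dtype, and leave everything else (e.g. quantized integer weights and their scales) untouched. For readers without MLX installed, here is a plain-dict stand-in that illustrates the same selective-cast logic; `cast_params` and the string dtype tags are hypothetical illustrations, not part of mlx or mlx_lm:

```python
# Hypothetical stand-in for the dtype-cast loop above, using plain
# Python dicts and string tags instead of MLX arrays and dtypes.

def cast_params(params, from_tag, to_tag):
    """Return a copy of `params` where every (tag, value) entry whose
    tag equals `from_tag` is re-tagged as `to_tag`; all other entries
    pass through unchanged."""
    out = {}
    for key, (tag, value) in params.items():
        out[key] = (to_tag, value) if tag == from_tag else (tag, value)
    return out

# A float16 parameter gets cast; a quantized uint32 tensor is untouched.
weights = {
    "layers.0.w": ("float16", [0.1, 0.2]),
    "layers.0.scales": ("uint32", [7]),
}
casted = cast_params(weights, "float16", "bfloat16")
```

The same selectivity is why the real script checks `v.dtype == mx.float16` before calling `astype`: blindly casting every tensor would corrupt the quantized weight buffers.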
awni commented Apr 18, 2025

Nice! FYI you can replace lines 14-19 with:

model.set_dtype(mx.bfloat16)

@N8python (Author)
Woah. That's convenient. Is it new?

awni commented Apr 18, 2025

Not so new :). It does more or less what you wrote; it's just such a common use case that we added a method for it.

@N8python (Author)
Can't believe I didn't know lol.
