feat(mlx): load GPTQ/AWQ pre-quantized checkpoints on Apple Silicon #848
BardiaKoopah wants to merge 1 commit into
mlx-lm cannot load AutoGPTQ/AutoAWQ packed weights (qweight/qzeros/g_idx), so pre-quantized GPTQ and AWQ repos previously failed or misrouted on the MLX path.

Add runtime dequantization: detect quant_method gptq/awq, dequantize the packed weights to a temporary dense fp16 checkpoint, and load that through the existing pipeline, which re-quantizes to MLX affine 4-bit for the LoRA base (mirrors the bnb NF4->fp16->MLX-4bit flow).

Dequant math is pure MLX array ops (bit-unpack + group scale/zero), verified bit-exact against AutoGPTQ/AutoAWQ conventions:
- GPTQ: 8-per-int32 nibble unpack, +1 zero-point, g_idx group mapping (including desc_act/act-order permutation and non-128 group sizes).
- AWQ (gemm): reverse-order column interleave, direct zero-point; reconstructs the AWQ smoothed weight the forward pass expects.

Unsupported inputs now fail loud instead of silently misrouting: 3-bit GPTQ, VLM GPTQ/AWQ, and known-but-unhandled packed formats (compressed-tensors, aqlm, quip, eetq, hqq, vptq, fp_quant) each raise a clear error.

The original repo id is preserved for _hf_repo/_src_path and the temp checkpoint is cleaned up after the weights are materialized.

Verified end-to-end (load, generate, LoRA SFT) on GPTQ (desc_act True/False, gs 32/128) and AWQ (gs128) across Qwen2 and Llama, plus a 3B AWQ model; MLX baseline suite unaffected.
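The GPTQ convention named in the commit message (8 nibbles per int32, +1 zero-point, g_idx group mapping) can be sketched as follows. This is a minimal illustration, not the PR's implementation: NumPy stands in for the MLX array ops, the function names are hypothetical, and it assumes the default AutoGPTQ 4-bit layout with qweight packed along the input dimension and qzeros packed along the output dimension.

```python
import numpy as np

NIBBLE_SHIFTS = np.arange(0, 32, 4, dtype=np.uint32)  # 8 nibbles per int32


def unpack_rows(packed):
    """int32 [rows/8, cols] -> int32 [rows, cols]; nibble k of packed row r is row 8*r + k."""
    u = np.ascontiguousarray(packed).view(np.uint32)  # reinterpret bits, no sign extension
    nib = (u[:, None, :] >> NIBBLE_SHIFTS[None, :, None]) & 0xF  # [rows/8, 8, cols]
    return nib.reshape(-1, u.shape[1]).astype(np.int32)


def unpack_cols(packed):
    """int32 [rows, cols/8] -> int32 [rows, cols]; nibble k of packed column c is column 8*c + k."""
    u = np.ascontiguousarray(packed).view(np.uint32)
    nib = (u[:, :, None] >> NIBBLE_SHIFTS[None, None, :]) & 0xF  # [rows, cols/8, 8]
    return nib.reshape(u.shape[0], -1).astype(np.int32)


def gptq_dequantize(qweight, qzeros, scales, g_idx):
    """Return a dense [out, in] fp16 weight (hypothetical helper, assumed layout)."""
    q = unpack_rows(qweight)        # [in, out] quantized values in 0..15
    z = unpack_cols(qzeros) + 1     # [groups, out]; stored zero-point is offset by one
    # g_idx maps every input row to its group, which also covers act-order
    # (desc_act) checkpoints where the mapping is a permutation.
    w = (q - z[g_idx]) * scales[g_idx]   # [in, out]
    return w.T.astype(np.float16)
```

Because g_idx is used as a gather index, the same code path handles any group size without a separate group_size argument.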
Code Review
This pull request adds support for loading GPTQ and AWQ pre-quantized Hugging Face checkpoints on Apple Silicon by dequantizing them to a dense fp16 format in a temporary directory before loading. The review feedback suggests using tempfile.TemporaryDirectory instead of tempfile.mkdtemp to prevent resource leaks in case of exceptions, along with a more concise implementation for reinterpreting signed integers as unsigned using MLX's built-in casting.
def _materialize_dequantized_hf_checkpoint(local_path, config_data, method, quant_config):
    """Dequantize a GPTQ/AWQ checkpoint to a temporary dense fp16 checkpoint.

    Returns (temp_dir, new_config_data) where new_config_data has the HF
    quantization metadata stripped so the downstream MLX load treats it as an
    ordinary dense model (and may re-quantize it to MLX affine for LoRA).
    """
    import glob
    import shutil
    import mlx.core as mx

    bits = int(quant_config.get("bits", 4) or 4)
    group_size = int(quant_config.get("group_size", 128) or 128)
    if bits != 4:
        raise NotImplementedError(
            f"Unsloth: {method.upper()} runtime dequant on MLX currently supports "
            f"4-bit checkpoints only (got bits={bits})."
        )

    shard_paths = sorted(glob.glob(os.path.join(local_path, "*.safetensors")))
    if not shard_paths:
        raise FileNotFoundError(
            f"Unsloth: no .safetensors weights found in '{local_path}' for "
            f"{method.upper()} dequantization."
        )
    weights = {}
    for shard in shard_paths:
        weights.update(mx.load(shard))

    quant_modules = sorted(
        {k[: -len(".qweight")] for k in weights if k.endswith(".qweight")}
    )
    if not quant_modules:
        raise ValueError(
            f"Unsloth: '{local_path}' is declared {method.upper()} but no packed "
            "'.qweight' tensors were found."
        )

    new_weights = {}
    quant_related = set()
    for name in quant_modules:
        qweight = weights[name + ".qweight"]
        qzeros = weights[name + ".qzeros"]
        scales = weights[name + ".scales"]
        if method == "gptq":
            g_idx = weights[name + ".g_idx"]
            dense = _gptq_dequantize_weight(qweight, qzeros, scales, g_idx, bits=bits)
            quant_related.update(
                name + suffix
                for suffix in (".qweight", ".qzeros", ".scales", ".g_idx")
            )
        else:
            dense = _awq_dequantize_weight(qweight, qzeros, scales, group_size, bits=bits)
            quant_related.update(
                name + suffix for suffix in (".qweight", ".qzeros", ".scales")
            )
        new_weights[name + ".weight"] = dense.astype(mx.float16)

    for key, tensor in weights.items():
        if key in quant_related:
            continue
        # GPTQ QuantLinear allocates a zero bias even for architectures that
        # have no bias (e.g. Llama); drop those so the dense checkpoint matches
        # the target module tree. Real (non-zero) biases (e.g. Qwen2 q/k/v) are
        # preserved.
        if key.endswith(".bias") and bool(mx.all(tensor == 0).item()):
            continue
        new_weights[key] = tensor

    mx.eval(list(new_weights.values()))

    temp_dir = tempfile.mkdtemp(prefix="unsloth_mlx_dequant_")
    for filename in os.listdir(local_path):
        src = os.path.join(local_path, filename)
        if not os.path.isfile(src):
            continue
        if filename.endswith(".safetensors") or filename.endswith(".safetensors.index.json"):
            continue
        shutil.copy(src, os.path.join(temp_dir, filename))

    new_config_data = dict(config_data)
    new_config_data.pop("quantization_config", None)
    new_config_data.pop("quantization", None)
    with open(os.path.join(temp_dir, "config.json"), "w") as f:
        json.dump(new_config_data, f, indent=2)

    mx.save_safetensors(os.path.join(temp_dir, "model.safetensors"), new_weights)
    return temp_dir, new_config_data
Using tempfile.mkdtemp can lead to persistent resource leaks of several gigabytes on disk if an exception is raised during the loading process (e.g., due to network issues, missing files, or shape mismatches). To ensure robust cleanup, use tempfile.TemporaryDirectory which automatically deletes the directory and its contents when garbage collected or when .cleanup() is called.
def _materialize_dequantized_hf_checkpoint(local_path, config_data, method, quant_config):
    """Dequantize a GPTQ/AWQ checkpoint to a temporary dense fp16 checkpoint.

    Returns (temp_dir_obj, temp_dir, new_config_data) where new_config_data has the HF
    quantization metadata stripped so the downstream MLX load treats it as an
    ordinary dense model (and may re-quantize it to MLX affine for LoRA).
    """
    import glob
    import shutil
    import mlx.core as mx

    bits = int(quant_config.get("bits", 4) or 4)
    group_size = int(quant_config.get("group_size", 128) or 128)
    if bits != 4:
        raise NotImplementedError(
            f"Unsloth: {method.upper()} runtime dequant on MLX currently supports "
            f"4-bit checkpoints only (got bits={bits})."
        )

    shard_paths = sorted(glob.glob(os.path.join(local_path, "*.safetensors")))
    if not shard_paths:
        raise FileNotFoundError(
            f"Unsloth: no .safetensors weights found in '{local_path}' for "
            f"{method.upper()} dequantization."
        )
    weights = {}
    for shard in shard_paths:
        weights.update(mx.load(shard))

    quant_modules = sorted(
        {k[: -len(".qweight")] for k in weights if k.endswith(".qweight")}
    )
    if not quant_modules:
        raise ValueError(
            f"Unsloth: '{local_path}' is declared {method.upper()} but no packed "
            "'.qweight' tensors were found."
        )

    new_weights = {}
    quant_related = set()
    for name in quant_modules:
        qweight = weights[name + ".qweight"]
        qzeros = weights[name + ".qzeros"]
        scales = weights[name + ".scales"]
        if method == "gptq":
            g_idx = weights[name + ".g_idx"]
            dense = _gptq_dequantize_weight(qweight, qzeros, scales, g_idx, bits=bits)
            quant_related.update(
                name + suffix
                for suffix in (".qweight", ".qzeros", ".scales", ".g_idx")
            )
        else:
            dense = _awq_dequantize_weight(qweight, qzeros, scales, group_size, bits=bits)
            quant_related.update(
                name + suffix for suffix in (".qweight", ".qzeros", ".scales")
            )
        new_weights[name + ".weight"] = dense.astype(mx.float16)

    for key, tensor in weights.items():
        if key in quant_related:
            continue
        # GPTQ QuantLinear allocates a zero bias even for architectures that
        # have no bias (e.g. Llama); drop those so the dense checkpoint matches
        # the target module tree. Real (non-zero) biases (e.g. Qwen2 q/k/v) are
        # preserved.
        if key.endswith(".bias") and bool(mx.all(tensor == 0).item()):
            continue
        new_weights[key] = tensor

    mx.eval(list(new_weights.values()))

    temp_dir_obj = tempfile.TemporaryDirectory(prefix="unsloth_mlx_dequant_")
    temp_dir = temp_dir_obj.name
    try:
        for filename in os.listdir(local_path):
            src = os.path.join(local_path, filename)
            if not os.path.isfile(src):
                continue
            if filename.endswith(".safetensors") or filename.endswith(".safetensors.index.json"):
                continue
            shutil.copy(src, os.path.join(temp_dir, filename))

        new_config_data = dict(config_data)
        new_config_data.pop("quantization_config", None)
        new_config_data.pop("quantization", None)
        with open(os.path.join(temp_dir, "config.json"), "w") as f:
            json.dump(new_config_data, f, indent=2)

        mx.save_safetensors(os.path.join(temp_dir, "model.safetensors"), new_weights)
        return temp_dir_obj, temp_dir, new_config_data
    except Exception:
        temp_dir_obj.cleanup()
        raise

    dequant_temp_dir = None
    hf_prequant_method, hf_prequant_config = _detect_hf_prequant_method(config_data)
    if hf_prequant_method is not None:
        if local_path is None:
            raise FileNotFoundError(
                f"Unsloth: could not resolve local files for "
                f"{hf_prequant_method.upper()} model '{model_name}'."
            )
        if _is_vlm(config_data):
            raise NotImplementedError(
                f"Unsloth: {hf_prequant_method.upper()} runtime dequant is not "
                "yet supported for vision models on MLX. Load an unquantized "
                "VLM base for LoRA instead."
            )
        print(
            f"Unsloth: Detected {hf_prequant_method.upper()} pre-quantized "
            f"checkpoint '{model_name}'; dequantizing to fp16 for MLX "
            "(LoRA base will be re-quantized to MLX affine)..."
        )
        dequant_dir, config_data = _materialize_dequantized_hf_checkpoint(
            original_local_path, config_data, hf_prequant_method, hf_prequant_config,
        )
        local_path = dequant_dir
        model_name = dequant_dir
        dequant_temp_dir = dequant_dir
Update the dequantization setup to accept and track the tempfile.TemporaryDirectory object returned by the updated _materialize_dequantized_hf_checkpoint function.
    dequant_temp_dir_obj = None
    hf_prequant_method, hf_prequant_config = _detect_hf_prequant_method(config_data)
    if hf_prequant_method is not None:
        if local_path is None:
            raise FileNotFoundError(
                f"Unsloth: could not resolve local files for "
                f"{hf_prequant_method.upper()} model '{model_name}'."
            )
        if _is_vlm(config_data):
            raise NotImplementedError(
                f"Unsloth: {hf_prequant_method.upper()} runtime dequant is not "
                "yet supported for vision models on MLX. Load an unquantized "
                "VLM base for LoRA instead."
            )
        print(
            f"Unsloth: Detected {hf_prequant_method.upper()} pre-quantized "
            f"checkpoint '{model_name}'; dequantizing to fp16 for MLX "
            "(LoRA base will be re-quantized to MLX affine)..."
        )
        dequant_dir_obj, dequant_dir, config_data = _materialize_dequantized_hf_checkpoint(
            original_local_path, config_data, hf_prequant_method, hf_prequant_config,
        )
        local_path = dequant_dir
        model_name = dequant_dir
        dequant_temp_dir_obj = dequant_dir_obj

    if dequant_temp_dir is not None:
        # The dequantized weights are now materialized in memory; the
        # temporary fp16 checkpoint on disk is no longer referenced.
        import mlx.core as mx
        import shutil

        mx.eval(model.parameters())
        shutil.rmtree(dequant_temp_dir, ignore_errors=True)
    return model, tokenizer
Clean up the temporary directory using the TemporaryDirectory.cleanup() method inside a try...finally block to ensure robust cleanup even if mx.eval fails, and gracefully ignore any errors during cleanup.
    if dequant_temp_dir_obj is not None:
        # The dequantized weights are now materialized in memory; the
        # temporary fp16 checkpoint on disk is no longer referenced.
        import mlx.core as mx
        try:
            mx.eval(model.parameters())
        finally:
            try:
                dequant_temp_dir_obj.cleanup()
            except Exception:
                pass

def _mlx_reinterpret_uint32(arr):
    # Reinterpret an int32 bit pattern as unsigned (widen to int64 so 4-bit
    # nibble shifts on the top byte don't sign-extend).
    import mlx.core as mx

    a = arr.astype(mx.int64)
    return mx.where(a < 0, a + (1 << 32), a)
Instead of manually casting to int64 and using mx.where to handle negative values, you can cast directly to mx.uint32 first and then to mx.int64. This is more concise, avoids the overhead of conditional operations, and leverages MLX's built-in modulo arithmetic for signed-to-unsigned conversion.
-def _mlx_reinterpret_uint32(arr):
-    # Reinterpret an int32 bit pattern as unsigned (widen to int64 so 4-bit
-    # nibble shifts on the top byte don't sign-extend).
-    import mlx.core as mx
-    a = arr.astype(mx.int64)
-    return mx.where(a < 0, a + (1 << 32), a)
+def _mlx_reinterpret_uint32(arr):
+    # Reinterpret an int32 bit pattern as unsigned (widen to int64 so 4-bit
+    # nibble shifts on the top byte don't sign-extend).
+    import mlx.core as mx
+    return arr.astype(mx.uint32).astype(mx.int64)
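The equivalence the suggestion relies on can be checked with a NumPy analogue. This is only a sketch: NumPy's astype wraps negative int32 values modulo 2^32 when converting to uint32, and whether MLX's astype behaves identically on every backend is an assumption that should be confirmed on the target device before adopting the shorter form. The function names below are hypothetical.

```python
import numpy as np


def reinterpret_via_where(arr):
    # Mirrors the original helper: widen first, then add 2**32 to negatives.
    a = arr.astype(np.int64)
    return np.where(a < 0, a + (1 << 32), a)


def reinterpret_via_cast(arr):
    # Mirrors the suggested helper: wrap to unsigned, then widen.
    return arr.astype(np.uint32).astype(np.int64)


def top_nibble(arr):
    # The case the comment in the helper worries about: bits 28..31.
    return (reinterpret_via_cast(arr) >> 28) & 0xF


sample = np.array([-1, -(2**31), 0, 1, 2**31 - 1], dtype=np.int32)
same = np.array_equal(reinterpret_via_where(sample), reinterpret_via_cast(sample))
```

A test along these lines, run against the MLX arrays themselves, would settle whether the two forms are interchangeable in the loader.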
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 437bf1a167
                for suffix in (".qweight", ".qzeros", ".scales", ".g_idx")
            )
        else:
            dense = _awq_dequantize_weight(qweight, qzeros, scales, group_size, bits=bits)
Reject non-GEMM AWQ checkpoints before dequantizing
The new AWQ branch always calls _awq_dequantize_weight, whose unpacker assumes the AutoAWQ GEMM layout (qweight as [in, out//8] and qzeros as [groups, out//8]). AWQ checkpoints can declare other versions such as GEMV/GEMVFast/Marlin with different tensor layouts or no qzeros, so those repos will either fail during unpack/broadcasting or materialize weights with the wrong shape/orientation instead of a clear unsupported-format error. Check quantization_config.version/zero_point before this call and only route GEMM checkpoints into this dequantizer.
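A guard along the lines the comment asks for might look like the sketch below. The helper name is hypothetical, the field names version and zero_point are taken from the comment, and treating a missing version as GEMM is an assumption rather than documented AutoAWQ behavior.

```python
def check_awq_layout(quant_config):
    """Raise unless the AWQ config describes the GEMM layout with zero-points."""
    # Assumption: a missing or empty "version" means the GEMM layout.
    version = str(quant_config.get("version") or "gemm").lower()
    if version != "gemm":
        raise NotImplementedError(
            f"AWQ version '{version}' is not supported by the runtime dequantizer; "
            "only GEMM-packed checkpoints are handled."
        )
    # The GEMM dequantizer subtracts qzeros, so a config that disables
    # zero-points cannot be routed through it.
    if quant_config.get("zero_point", True) is False:
        raise NotImplementedError(
            "AWQ checkpoints without zero-points are not supported by the "
            "runtime dequantizer."
        )
    return version
```

Calling it before _awq_dequantize_weight turns a shape or broadcasting failure deep in the unpacker into an explicit unsupported-format error.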
    w = _unpack(qw).astype(mx.float32)  # [in, out]
    z = _unpack(qz).astype(mx.float32)  # [groups, out]
    g = (mx.arange(w.shape[0]) // group_size).astype(mx.int32)  # [in]
Normalize AWQ group_size=-1 before indexing scales
When an AWQ config uses group_size=-1, AutoAWQ treats that as a single group spanning the full input dimension. This code passes -1 through, so mx.arange(w.shape[0]) // group_size produces 0, -1, -2, ... and scales[g]/z[g] will index invalid or unintended groups after the first input row, breaking otherwise valid per-channel AWQ checkpoints. Normalize -1 to the input width before computing g.
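The normalization can be sketched together with the GEMM unpack it feeds. NumPy stands in for MLX here, the function names are hypothetical, and the layout (qweight as [in, out//8], qzeros as [groups, out//8], nibbles reordered with [0,4,1,5,2,6,3,7]) is the AutoAWQ GEMM convention as described in this PR.

```python
import numpy as np

AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
NIBBLE_SHIFTS = np.arange(0, 32, 4, dtype=np.uint32)


def unpack_awq(packed):
    """int32 [rows, cols/8] -> int32 [rows, cols] in logical column order."""
    u = np.ascontiguousarray(packed).view(np.uint32)
    nib = (u[:, :, None] >> NIBBLE_SHIFTS) & 0xF   # [rows, cols/8, 8]
    nib = nib[:, :, AWQ_REVERSE_ORDER]             # undo the packing interleave
    return nib.reshape(u.shape[0], -1).astype(np.int32)


def awq_group_index(in_features, group_size):
    """Map each input row to its group, treating -1 as one group over all rows."""
    if group_size == -1:
        group_size = in_features
    if group_size <= 0 or in_features % group_size != 0:
        raise ValueError(f"invalid group_size={group_size} for in_features={in_features}")
    return np.arange(in_features) // group_size


def awq_dequantize(qweight, qzeros, scales, group_size):
    """Return a dense [out, in] fp16 weight (hypothetical helper)."""
    w = unpack_awq(qweight)                   # [in, out]
    z = unpack_awq(qzeros)                    # [groups, out], used directly (no +1)
    g = awq_group_index(w.shape[0], group_size)
    return ((w - z[g]) * scales[g]).T.astype(np.float16)
```

Without the normalization, floor division by -1 yields 0, -1, -2, ... for the group index, which is the failure mode the comment describes.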
        scales = weights[name + ".scales"]
        if method == "gptq":
            g_idx = weights[name + ".g_idx"]
            dense = _gptq_dequantize_weight(qweight, qzeros, scales, g_idx, bits=bits)
Reject Marlin GPTQ before default dequantization
When a GPTQ repo advertises a Marlin checkpoint (format/checkpoint_format set to marlin), those weights are repacked for Marlin rather than the default AutoGPTQ layout assumed by _gptq_dequantize_weight. This branch still treats every quant_method='gptq' repo as ordinary AutoGPTQ and indexes qweight/qzeros/g_idx, so Marlin repos can either fail with raw missing-tensor/shape errors or materialize dense weights from misinterpreted packed data. Gate on the GPTQ checkpoint format and reject non-default layouts before this call.
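One possible shape for that gate is sketched below. The helper name is hypothetical, the keys format and checkpoint_format come from the comment, and treating "gptq" (or an absent field) as the default AutoGPTQ layout is an assumption to verify against the checkpoints the loader is expected to see.

```python
def check_gptq_checkpoint_format(quant_config):
    """Raise unless the GPTQ config describes the default packed layout."""
    # Assumption: an absent field means the default AutoGPTQ layout.
    fmt = quant_config.get("checkpoint_format") or quant_config.get("format") or "gptq"
    fmt = str(fmt).lower()
    if fmt != "gptq":
        raise NotImplementedError(
            f"GPTQ checkpoint format '{fmt}' is not supported by the runtime "
            "dequantizer; only the default packed layout is handled."
        )
    return fmt
```

An allowlist is used instead of rejecting only "marlin" so that any other repacked layout also fails with a clear message instead of being dequantized incorrectly.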
    quant_config = config_data.get("quantization_config", None)
    if not isinstance(quant_config, dict):
        return None, None
Read quant_config.json-only GPTQ/AWQ metadata
This detection only looks at config.json's quantization_config, but valid AutoGPTQ/AutoAWQ repos can keep the quantization metadata in a separate quant_config.json while the checkpoint still contains packed qweight/qzeros tensors. In that case this returns None, skips the new dequantization path, and lets mlx-lm try to load packed tensors as a dense model. Read quant_config.json (including legacy w_bit/q_group_size/version fields) before concluding the repo is not pre-quantized.
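A fallback of the kind described could be sketched as follows. The helper name and the returned dictionary shape are hypothetical; the file name quant_config.json and the legacy keys w_bit, q_group_size and version are the ones named in the comment. Deciding whether the repo is GPTQ or AWQ from this file is deliberately left out, since the comment does not say how the method should be inferred.

```python
import json
import os


def read_prequant_metadata(model_dir, config_data):
    """Return normalized quantization metadata, or None if none is declared."""
    quant_config = config_data.get("quantization_config")
    if isinstance(quant_config, dict):
        return dict(quant_config)

    # Fallback: a standalone quant_config.json next to the weights.
    path = os.path.join(model_dir, "quant_config.json")
    if not os.path.isfile(path):
        return None
    with open(path) as f:
        raw = json.load(f)

    # Map the legacy key names onto the ones the loader already reads.
    return {
        "bits": raw.get("bits", raw.get("w_bit")),
        "group_size": raw.get("group_size", raw.get("q_group_size")),
        "version": raw.get("version"),
        "zero_point": raw.get("zero_point"),
    }
```

With a helper like this, detection would return None only after both locations have been checked.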
        if key.endswith(".bias") and bool(mx.all(tensor == 0).item()):
            continue
Preserve real zero-valued biases
For GPTQ/AWQ checkpoints with a bias-enabled layer whose saved bias happens to be all zeros, this drops the tensor solely because of its value. That turns a valid zero bias into a missing key (or leaves the dense layer with whatever default bias the constructor used), while the same model with a nonzero bias is preserved. Restrict this workaround to synthetic QuantLinear biases for modules whose target dense architecture has no bias, rather than dropping every zero .bias in the checkpoint.
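One way to narrow the workaround is to key the decision on the target architecture, not on the tensor's value alone. In the sketch below every name is hypothetical: quantized_modules would be the set of module names that had a .qweight, and dense_has_bias would be a callable answering whether the dense model declares a bias for that module. How the loader would obtain that information is left open.

```python
def should_drop_bias(key, is_all_zero, quantized_modules, dense_has_bias):
    """Decide whether a '.bias' tensor is a synthetic QuantLinear placeholder."""
    if not key.endswith(".bias"):
        return False
    module = key[: -len(".bias")]
    # Biases on layers that were never quantized are always kept.
    if module not in quantized_modules:
        return False
    # A bias the dense architecture expects is kept even if it is all zeros.
    if dense_has_bias(module):
        return False
    # Remaining case: the dense module has no bias slot, so an all-zero
    # tensor here is treated as the placeholder the original comment describes.
    return is_all_zero
```

A non-zero bias on a module with no bias slot is kept by this sketch, so the mismatch surfaces at load time and is not discarded silently.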
feat(mlx): load & finetune GPTQ/AWQ pre-quantized checkpoints on Apple Silicon
Summary
GPTQ and AWQ pre-quantized HuggingFace checkpoints can now be loaded and LoRA-finetuned directly on the MLX / Apple-Silicon path. A large fraction of popular models on the Hub ship only as GPTQ or AWQ; until now none of them worked on MLX.
Worse than "not supported," the old behavior was a silent mis-route: from_pretrained on a GPTQ repo would print "already quantized — using existing compatible MLX quantization" and then fail deep inside mlx-lm with a confusing Found ... g_idx in weights ... not currently supported error (AWQ failed similarly). The loader treated the HF quantization_config (which carries bits/group_size) as if it were MLX-native quant metadata, because the only field-level guard understood bitsandbytes.
How it works
mlx-lm cannot ingest AutoGPTQ/AutoAWQ packed tensors (qweight/qzeros/g_idx), so we dequantize them ourselves before the model is built:
- Detect quantization_config.quant_method = gptq / awq.
This mirrors the accepted bitsandbytes runtime-dequant architecture (bnb NF4 → fp16 → MLX 4-bit): the same detect → materialize → recursive-load → cleanup shape, so it reuses the whole load + quant + LoRA path rather than forking it. The original repo id is preserved for _hf_repo/_src_path, and the temp checkpoint is removed once the weights are materialized in memory.
Net change: unsloth_zoo/mlx/loader.py only, +255 / −3 (the 3 deletions reroute metadata to the original repo).
What's supported
- GPTQ: desc_act / act-order (real permuted g_idx); non-128 group sizes (verified gs=32 and gs=128).
- AWQ: interleave order [0,4,1,5,2,6,3,7]; direct zero-point. Reconstructs the AWQ smoothed weight (per-channel scales are folded into the checkpoint), which is exactly what the forward pass consumes.
After dequant the base is re-quantized to MLX affine 4-bit and trains via the normal LoRA path (load_in_16bit=True / full_finetuning=True keep it dense fp16 instead).
Dequant correctness
The dequant math was verified bit-exact against the AutoGPTQ/AutoAWQ conventions before any integration:
... g_idx/interleave/zero-point convention produces garbage, not fluent output). This caught a real subtlety: AWQ's dequantized weights intentionally differ from the original unquantized weights (the smoothing transform), so weight-vs-original comparison is invalid for AWQ and generation is the correct signal.
Tested end-to-end (load → generate → LoRA SFT)
LoRA SFT produces finite, decreasing loss (e.g. GPTQ desc_act=True: 3.65→2.76; AWQ: 4.04→2.58). Non-quantized and mlx-native prequantized loads are unaffected; the MLX baseline test suite still passes (90).
Fails loud instead of mis-routing
Every unsupported input now raises a clear error up front rather than silently mis-routing:
- VLM GPTQ/AWQ: NotImplementedError (vision not yet supported).
- 3-bit GPTQ: NotImplementedError (4-bit only).
- Known-but-unhandled packed formats (compressed-tensors, aqlm, quip, eetq, hqq, vptq, fp_quant): NotImplementedError naming the method (bitsandbytes deliberately excluded; it has its own path).
Known limitations