-
Notifications
You must be signed in to change notification settings - Fork 290
feat(mlx): load GPTQ/AWQ pre-quantized checkpoints on Apple Silicon #848
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2701,6 +2701,195 @@ def _bnb_nested_absmax(absmax): | |
| return dequantized.reshape(original_shape).astype(original_dtype) | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # GPTQ / AWQ pre-quantized HF checkpoint support. | ||
| # | ||
| # mlx-lm cannot load AutoGPTQ/AutoAWQ packed weights (qweight/qzeros/g_idx), | ||
| # so we dequantize them to a dense fp16 checkpoint on Apple Silicon and load | ||
| # that instead. The standard runtime-quant path then re-quantizes the dense | ||
| # weights to MLX affine for the LoRA base (mirrors the bnb NF4->fp16->MLX-4bit | ||
| # flow). Dequant math is pure MLX array ops (bit-unpack + group scale/zero), | ||
| # verified bit-exact against AutoGPTQ/AutoAWQ conventions. | ||
| # --------------------------------------------------------------------------- | ||
| _HF_RUNTIME_DEQUANT_METHODS = frozenset({"gptq", "awq"}) | ||
| # Known HF packed quantization methods that mlx-lm cannot load and that we do | ||
| # not dequantize here. These must fail loud with a clear message rather than | ||
| # fall through to the generic MLX-compatibility check (which misreports them as | ||
| # a bits/group_size mismatch). bitsandbytes is intentionally excluded — it is | ||
| # handled by a separate runtime-dequant path. | ||
| _HF_UNSUPPORTED_PACKED_METHODS = frozenset({ | ||
| "compressed-tensors", "compressed_tensors", "aqlm", | ||
| "quip", "quip_sharp", "eetq", "hqq", "vptq", "fp_quant", | ||
| }) | ||
| _AWQ_REVERSE_ORDER = (0, 4, 1, 5, 2, 6, 3, 7) | ||
|
|
||
|
|
||
def _mlx_reinterpret_uint32(arr):
    """Return ``arr`` as int64 holding the unsigned reading of each element.

    The input is widened to int64 and every negative entry is shifted up by
    2**32, so an int32 bit pattern comes out as the value it would have as an
    unsigned 32-bit integer. Working in int64 keeps later right-shifts of the
    top nibble from sign-extending.
    """
    import mlx.core as mx

    widened = arr.astype(mx.int64)
    wrapped = widened + (1 << 32)
    return mx.where(widened < 0, wrapped, widened)
|
|
||
|
|
||
| def _gptq_dequantize_weight(qweight, qzeros, scales, g_idx, bits=4): | ||
| """AutoGPTQ 4-bit -> dense [out, in] weight (HF nn.Linear orientation). | ||
|
|
||
| qweight [in//8, out] packs 8 input rows per int32; qzeros [groups, out//8] | ||
| packs 8 output cols per int32; g_idx [in] maps each input row to its group | ||
| (a real permutation when desc_act/act-order is enabled). Zeros use the | ||
| AutoGPTQ ``stored + 1`` convention (symmetric models store a constant). | ||
| """ | ||
| import mlx.core as mx | ||
|
|
||
| qw = _mlx_reinterpret_uint32(qweight) # [in//8, out] | ||
| qz = _mlx_reinterpret_uint32(qzeros) # [groups, out//8] | ||
| scales = scales.astype(mx.float32) # [groups, out] | ||
| shifts = mx.arange(0, 32, bits).astype(mx.int64) # [8] | ||
| w = (mx.right_shift(qw[:, None, :], shifts[None, :, None]) & 0xF) | ||
| w = w.reshape(-1, qw.shape[1]).astype(mx.float32) # [in, out] | ||
| z = (mx.right_shift(qz[:, :, None], shifts[None, None, :]) & 0xF) | ||
| z = z.reshape(qz.shape[0], -1).astype(mx.float32) + 1.0 # [groups, out] | ||
| g = g_idx.astype(mx.int32) # [in] | ||
| eff = (w - z[g]) * scales[g] # [in, out] | ||
| return mx.transpose(eff) # [out, in] | ||
|
|
||
|
|
||
| def _awq_dequantize_weight(qweight, qzeros, scales, group_size, bits=4): | ||
| """AutoAWQ GEMM 4-bit -> dense [out, in] weight (HF nn.Linear orientation). | ||
|
|
||
| qweight [in, out//8] and qzeros [groups, out//8] pack 8 output cols per | ||
| int32 with the AWQ interleave; ``_AWQ_REVERSE_ORDER`` restores natural | ||
| column order to align with the (non-interleaved) scales. Note the | ||
| reconstructed weight is the AWQ *smoothed* weight (per-channel scales are | ||
| folded into the checkpoint), which is exactly what the forward pass needs. | ||
| """ | ||
| import mlx.core as mx | ||
|
|
||
| qw = _mlx_reinterpret_uint32(qweight) # [in, out//8] | ||
| qz = _mlx_reinterpret_uint32(qzeros) # [groups, out//8] | ||
| scales = scales.astype(mx.float32) # [groups, out] | ||
| shifts = mx.arange(0, 32, bits).astype(mx.int64) | ||
| pack = 32 // bits # 8 | ||
| reorder = mx.array( | ||
| [b * pack + o for b in range(scales.shape[1] // pack) for o in _AWQ_REVERSE_ORDER] | ||
| ) | ||
|
|
||
| def _unpack(x): | ||
| y = (mx.right_shift(x[:, :, None], shifts[None, None, :]) & 0xF).reshape(x.shape[0], -1) | ||
| return y[:, reorder] | ||
|
|
||
| w = _unpack(qw).astype(mx.float32) # [in, out] | ||
| z = _unpack(qz).astype(mx.float32) # [groups, out] | ||
| g = (mx.arange(w.shape[0]) // group_size).astype(mx.int32) # [in] | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
When an AWQ config uses Useful? React with 👍 / 👎. |
||
| eff = (w - z[g]) * scales[g] # [in, out] | ||
| return mx.transpose(eff) # [out, in] | ||
|
|
||
|
|
||
| def _detect_hf_prequant_method(config_data): | ||
| """Return (method, quant_config_dict) for a GPTQ/AWQ repo, else (None, None).""" | ||
| if not isinstance(config_data, dict): | ||
| return None, None | ||
| quant_config = config_data.get("quantization_config", None) | ||
| if not isinstance(quant_config, dict): | ||
| return None, None | ||
|
Comment on lines
+2794
to
+2796
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This detection only looks at Useful? React with 👍 / 👎. |
||
| method = str(quant_config.get("quant_method", "")).lower() | ||
| if method in _HF_RUNTIME_DEQUANT_METHODS: | ||
| return method, quant_config | ||
| return None, None | ||
|
|
||
|
|
||
| def _materialize_dequantized_hf_checkpoint(local_path, config_data, method, quant_config): | ||
| """Dequantize a GPTQ/AWQ checkpoint to a temporary dense fp16 checkpoint. | ||
|
|
||
| Returns (temp_dir, new_config_data) where new_config_data has the HF | ||
| quantization metadata stripped so the downstream MLX load treats it as an | ||
| ordinary dense model (and may re-quantize it to MLX affine for LoRA). | ||
| """ | ||
| import glob | ||
| import shutil | ||
| import mlx.core as mx | ||
|
|
||
| bits = int(quant_config.get("bits", 4) or 4) | ||
| group_size = int(quant_config.get("group_size", 128) or 128) | ||
| if bits != 4: | ||
| raise NotImplementedError( | ||
| f"Unsloth: {method.upper()} runtime dequant on MLX currently supports " | ||
| f"4-bit checkpoints only (got bits={bits})." | ||
| ) | ||
|
|
||
| shard_paths = sorted(glob.glob(os.path.join(local_path, "*.safetensors"))) | ||
| if not shard_paths: | ||
| raise FileNotFoundError( | ||
| f"Unsloth: no .safetensors weights found in '{local_path}' for " | ||
| f"{method.upper()} dequantization." | ||
| ) | ||
| weights = {} | ||
| for shard in shard_paths: | ||
| weights.update(mx.load(shard)) | ||
|
|
||
| quant_modules = sorted( | ||
| {k[: -len(".qweight")] for k in weights if k.endswith(".qweight")} | ||
| ) | ||
| if not quant_modules: | ||
| raise ValueError( | ||
| f"Unsloth: '{local_path}' is declared {method.upper()} but no packed " | ||
| "'.qweight' tensors were found." | ||
| ) | ||
|
|
||
| new_weights = {} | ||
| quant_related = set() | ||
| for name in quant_modules: | ||
| qweight = weights[name + ".qweight"] | ||
| qzeros = weights[name + ".qzeros"] | ||
| scales = weights[name + ".scales"] | ||
| if method == "gptq": | ||
| g_idx = weights[name + ".g_idx"] | ||
| dense = _gptq_dequantize_weight(qweight, qzeros, scales, g_idx, bits=bits) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
When a GPTQ repo advertises a Marlin checkpoint ( Useful? React with 👍 / 👎. |
||
| quant_related.update( | ||
| name + suffix | ||
| for suffix in (".qweight", ".qzeros", ".scales", ".g_idx") | ||
| ) | ||
| else: | ||
| dense = _awq_dequantize_weight(qweight, qzeros, scales, group_size, bits=bits) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The new AWQ branch always calls Useful? React with 👍 / 👎. |
||
| quant_related.update( | ||
| name + suffix for suffix in (".qweight", ".qzeros", ".scales") | ||
| ) | ||
| new_weights[name + ".weight"] = dense.astype(mx.float16) | ||
|
|
||
| for key, tensor in weights.items(): | ||
| if key in quant_related: | ||
| continue | ||
| # GPTQ QuantLinear allocates a zero bias even for architectures that | ||
| # have no bias (e.g. Llama); drop those so the dense checkpoint matches | ||
| # the target module tree. Real (non-zero) biases (e.g. Qwen2 q/k/v) are | ||
| # preserved. | ||
| if key.endswith(".bias") and bool(mx.all(tensor == 0).item()): | ||
| continue | ||
|
Comment on lines
+2868
to
+2869
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
For GPTQ/AWQ checkpoints with a bias-enabled layer whose saved bias happens to be all zeros, this drops the tensor solely because of its value. That turns a valid zero bias into a missing key (or leaves the dense layer with whatever default bias the constructor used), while the same model with a nonzero bias is preserved. Restrict this workaround to synthetic QuantLinear biases for modules whose target dense architecture has no bias, rather than dropping every zero Useful? React with 👍 / 👎. |
||
| new_weights[key] = tensor | ||
|
|
||
| mx.eval(list(new_weights.values())) | ||
|
|
||
| temp_dir = tempfile.mkdtemp(prefix="unsloth_mlx_dequant_") | ||
| for filename in os.listdir(local_path): | ||
| src = os.path.join(local_path, filename) | ||
| if not os.path.isfile(src): | ||
| continue | ||
| if filename.endswith(".safetensors") or filename.endswith(".safetensors.index.json"): | ||
| continue | ||
| shutil.copy(src, os.path.join(temp_dir, filename)) | ||
|
|
||
| new_config_data = dict(config_data) | ||
| new_config_data.pop("quantization_config", None) | ||
| new_config_data.pop("quantization", None) | ||
| with open(os.path.join(temp_dir, "config.json"), "w") as f: | ||
| json.dump(new_config_data, f, indent=2) | ||
|
|
||
| mx.save_safetensors(os.path.join(temp_dir, "model.safetensors"), new_weights) | ||
| return temp_dir, new_config_data | ||
|
Comment on lines
+2803
to
+2890
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using def _materialize_dequantized_hf_checkpoint(local_path, config_data, method, quant_config):
"""Dequantize a GPTQ/AWQ checkpoint to a temporary dense fp16 checkpoint.
Returns (temp_dir_obj, temp_dir, new_config_data) where new_config_data has the HF
quantization metadata stripped so the downstream MLX load treats it as an
ordinary dense model (and may re-quantize it to MLX affine for LoRA).
"""
import glob
import shutil
import mlx.core as mx
bits = int(quant_config.get("bits", 4) or 4)
group_size = int(quant_config.get("group_size", 128) or 128)
if bits != 4:
raise NotImplementedError(
f"Unsloth: {method.upper()} runtime dequant on MLX currently supports "
f"4-bit checkpoints only (got bits={bits})."
)
shard_paths = sorted(glob.glob(os.path.join(local_path, "*.safetensors")))
if not shard_paths:
raise FileNotFoundError(
f"Unsloth: no .safetensors weights found in '{local_path}' for "
f"{method.upper()} dequantization."
)
weights = {}
for shard in shard_paths:
weights.update(mx.load(shard))
quant_modules = sorted(
{k[: -len(".qweight")] for k in weights if k.endswith(".qweight")}
)
if not quant_modules:
raise ValueError(
f"Unsloth: '{local_path}' is declared {method.upper()} but no packed "
"'.qweight' tensors were found."
)
new_weights = {}
quant_related = set()
for name in quant_modules:
qweight = weights[name + ".qweight"]
qzeros = weights[name + ".qzeros"]
scales = weights[name + ".scales"]
if method == "gptq":
g_idx = weights[name + ".g_idx"]
dense = _gptq_dequantize_weight(qweight, qzeros, scales, g_idx, bits=bits)
quant_related.update(
name + suffix
for suffix in (".qweight", ".qzeros", ".scales", ".g_idx")
)
else:
dense = _awq_dequantize_weight(qweight, qzeros, scales, group_size, bits=bits)
quant_related.update(
name + suffix for suffix in (".qweight", ".qzeros", ".scales")
)
new_weights[name + ".weight"] = dense.astype(mx.float16)
for key, tensor in weights.items():
if key in quant_related:
continue
# GPTQ QuantLinear allocates a zero bias even for architectures that
# have no bias (e.g. Llama); drop those so the dense checkpoint matches
# the target module tree. Real (non-zero) biases (e.g. Qwen2 q/k/v) are
# preserved.
if key.endswith(".bias") and bool(mx.all(tensor == 0).item()):
continue
new_weights[key] = tensor
mx.eval(list(new_weights.values()))
temp_dir_obj = tempfile.TemporaryDirectory(prefix="unsloth_mlx_dequant_")
temp_dir = temp_dir_obj.name
try:
for filename in os.listdir(local_path):
src = os.path.join(local_path, filename)
if not os.path.isfile(src):
continue
if filename.endswith(".safetensors") or filename.endswith(".safetensors.index.json"):
continue
shutil.copy(src, os.path.join(temp_dir, filename))
new_config_data = dict(config_data)
new_config_data.pop("quantization_config", None)
new_config_data.pop("quantization", None)
with open(os.path.join(temp_dir, "config.json"), "w") as f:
json.dump(new_config_data, f, indent=2)
mx.save_safetensors(os.path.join(temp_dir, "model.safetensors"), new_weights)
return temp_dir_obj, temp_dir, new_config_data
except Exception:
temp_dir_obj.cleanup()
raise |
||
|
|
||
|
|
||
| def _apply_dense_nf4_quantization(model, config, spec: _MLXQuantizationSpec, predicate): | ||
| import mlx.core as mx | ||
|
|
||
|
|
@@ -3635,6 +3824,60 @@ def from_pretrained( | |
| config_data, | ||
| ) | ||
|
|
||
| # Preserve the caller-facing identity: GPTQ/AWQ dequant reroutes the | ||
| # load through a temp dir, but metadata (_hf_repo/_src_path) must keep | ||
| # pointing at the original repo so save/reload resolve correctly. | ||
| original_model_name = model_name | ||
| original_local_path = local_path | ||
|
|
||
| # GPTQ/AWQ pre-quantized checkpoints: mlx-lm can't load their packed | ||
| # weights. Dequantize to a temporary dense fp16 checkpoint and load | ||
| # that; the runtime-quant path below then re-quantizes to MLX affine | ||
| # for the LoRA base (bnb NF4->fp16->MLX-4bit style flow). | ||
| dequant_temp_dir = None | ||
| hf_prequant_method, hf_prequant_config = _detect_hf_prequant_method(config_data) | ||
| if hf_prequant_method is not None: | ||
| if local_path is None: | ||
| raise FileNotFoundError( | ||
| f"Unsloth: could not resolve local files for " | ||
| f"{hf_prequant_method.upper()} model '{model_name}'." | ||
| ) | ||
| if _is_vlm(config_data): | ||
| raise NotImplementedError( | ||
| f"Unsloth: {hf_prequant_method.upper()} runtime dequant is not " | ||
| "yet supported for vision models on MLX. Load an unquantized " | ||
| "VLM base for LoRA instead." | ||
| ) | ||
| print( | ||
| f"Unsloth: Detected {hf_prequant_method.upper()} pre-quantized " | ||
| f"checkpoint '{model_name}'; dequantizing to fp16 for MLX " | ||
| "(LoRA base will be re-quantized to MLX affine)..." | ||
| ) | ||
| dequant_dir, config_data = _materialize_dequantized_hf_checkpoint( | ||
| original_local_path, config_data, hf_prequant_method, hf_prequant_config, | ||
| ) | ||
| local_path = dequant_dir | ||
| model_name = dequant_dir | ||
| dequant_temp_dir = dequant_dir | ||
|
Comment on lines
+3837
to
+3861
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update the dequantization setup to accept and track the dequant_temp_dir_obj = None
hf_prequant_method, hf_prequant_config = _detect_hf_prequant_method(config_data)
if hf_prequant_method is not None:
if local_path is None:
raise FileNotFoundError(
f"Unsloth: could not resolve local files for "
f"{hf_prequant_method.upper()} model '{model_name}'."
)
if _is_vlm(config_data):
raise NotImplementedError(
f"Unsloth: {hf_prequant_method.upper()} runtime dequant is not "
"yet supported for vision models on MLX. Load an unquantized "
"VLM base for LoRA instead."
)
print(
f"Unsloth: Detected {hf_prequant_method.upper()} pre-quantized "
f"checkpoint '{model_name}'; dequantizing to fp16 for MLX "
"(LoRA base will be re-quantized to MLX affine)..."
)
dequant_dir_obj, dequant_dir, config_data = _materialize_dequantized_hf_checkpoint(
original_local_path, config_data, hf_prequant_method, hf_prequant_config,
)
local_path = dequant_dir
model_name = dequant_dir
dequant_temp_dir_obj = dequant_dir_obj |
||
| else: | ||
| # A recognized-but-unsupported packed quant format must fail loud | ||
| # with a clear message instead of misrouting into the generic | ||
| # MLX-compatibility check. | ||
| _other_quant = ( | ||
| config_data.get("quantization_config") | ||
| if isinstance(config_data, dict) else None | ||
| ) | ||
| if isinstance(_other_quant, dict): | ||
| _other_method = str(_other_quant.get("quant_method", "")).lower() | ||
| if _other_method in _HF_UNSUPPORTED_PACKED_METHODS: | ||
| raise NotImplementedError( | ||
| f"Unsloth: '{model_name}' uses '{_other_method}' " | ||
| "quantization, which is not supported on the MLX path. " | ||
| "Supported pre-quantized formats are GPTQ and AWQ " | ||
| "(dequantized to MLX affine for LoRA); otherwise load an " | ||
| "unquantized or MLX-quantized checkpoint." | ||
| ) | ||
|
|
||
| # Reject full_finetuning on a pre-quantized repo: int4/int8 weights | ||
| # aren't trainable (our CCE backward zeros the quantized weight grad), | ||
| # so full FT would silently update only LayerNorms/biases. | ||
|
|
@@ -4218,10 +4461,10 @@ def from_pretrained( | |
| model._is_vlm_model = False | ||
|
|
||
| model._config = config | ||
| model._hf_repo = model_name | ||
| model._src_path = local_path | ||
| model._hf_repo = original_model_name | ||
| model._src_path = original_local_path | ||
| model._unsloth_base_revision = revision | ||
| model._unsloth_base_commit_hash = _infer_snapshot_commit(local_path) | ||
| model._unsloth_base_commit_hash = _infer_snapshot_commit(original_local_path) | ||
| model.max_seq_length = max_seq_length | ||
| model._unsloth_patch_mode = patch_mode | ||
| model._unsloth_full_finetuning = bool(full_finetuning) | ||
|
|
@@ -4232,6 +4475,15 @@ def from_pretrained( | |
| _patch_mixed_precision_set_dtype(model) | ||
|
|
||
| _patch_mlx_saving(model, tokenizer) | ||
|
|
||
| if dequant_temp_dir is not None: | ||
| # The dequantized weights are now materialized in memory; the | ||
| # temporary fp16 checkpoint on disk is no longer referenced. | ||
| import mlx.core as mx | ||
| import shutil | ||
|
|
||
| mx.eval(model.parameters()) | ||
| shutil.rmtree(dequant_temp_dir, ignore_errors=True) | ||
| return model, tokenizer | ||
|
Comment on lines
+4479
to
4487
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Clean up the temporary directory using the if dequant_temp_dir_obj is not None:
# The dequantized weights are now materialized in memory; the
# temporary fp16 checkpoint on disk is no longer referenced.
import mlx.core as mx
try:
mx.eval(model.parameters())
finally:
try:
dequant_temp_dir_obj.cleanup()
except Exception:
pass |
||
|
|
||
| @staticmethod | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of manually casting to `int64` and using `mx.where` to handle negative values, you can cast directly to `mx.uint32` first and then to `mx.int64`. This is more concise, avoids the overhead of conditional operations, and leverages MLX's built-in modulo arithmetic for signed-to-unsigned conversion.