Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
258 changes: 255 additions & 3 deletions unsloth_zoo/mlx/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2701,6 +2701,195 @@ def _bnb_nested_absmax(absmax):
return dequantized.reshape(original_shape).astype(original_dtype)


# ---------------------------------------------------------------------------
# GPTQ / AWQ pre-quantized HF checkpoint support.
#
# mlx-lm cannot load AutoGPTQ/AutoAWQ packed weights (qweight/qzeros/g_idx),
# so we dequantize them to a dense fp16 checkpoint on Apple Silicon and load
# that instead. The standard runtime-quant path then re-quantizes the dense
# weights to MLX affine for the LoRA base (mirrors the bnb NF4->fp16->MLX-4bit
# flow). Dequant math is pure MLX array ops (bit-unpack + group scale/zero),
# verified bit-exact against AutoGPTQ/AutoAWQ conventions.
# ---------------------------------------------------------------------------
# HF quantization methods whose packed checkpoints are dequantized to dense
# fp16 at load time by the helpers below.
_HF_RUNTIME_DEQUANT_METHODS = frozenset(("gptq", "awq"))

# Known HF packed quantization methods that mlx-lm cannot load and that are
# NOT dequantized here. They must fail loudly with a clear message instead of
# falling through to the generic MLX-compatibility check, which would
# misreport them as a bits/group_size mismatch. bitsandbytes is deliberately
# absent: it has its own runtime-dequant path.
_HF_UNSUPPORTED_PACKED_METHODS = frozenset(
    (
        "compressed-tensors",
        "compressed_tensors",
        "aqlm",
        "quip",
        "quip_sharp",
        "eetq",
        "hqq",
        "vptq",
        "fp_quant",
    )
)

# Per-int32 nibble permutation applied after unpacking AWQ tensors so the
# unpacked columns line up with the (non-interleaved) scales.
_AWQ_REVERSE_ORDER = (0, 4, 1, 5, 2, 6, 3, 7)


def _mlx_reinterpret_uint32(arr):
    """Return ``arr`` as int64 holding the unsigned reading of its bit pattern.

    The input is widened to int64 and every negative element is shifted up by
    2**32, so later right-shifts on the top byte cannot sign-extend.
    """
    import mlx.core as mx

    widened = arr.astype(mx.int64)
    wrapped = widened + (1 << 32)
    # Only negative entries need the +2**32 correction; the rest pass through.
    return mx.where(widened < 0, wrapped, widened)
Comment on lines +2727 to +2733

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Instead of manually casting to int64 and using mx.where to handle negative values, you can cast directly to mx.uint32 first and then to mx.int64. This is more concise, avoids the overhead of conditional operations, and leverages MLX's built-in modulo arithmetic for signed-to-unsigned conversion.

Suggested change
def _mlx_reinterpret_uint32(arr):
# Reinterpret an int32 bit pattern as unsigned (widen to int64 so 4-bit
# nibble shifts on the top byte don't sign-extend).
import mlx.core as mx
a = arr.astype(mx.int64)
return mx.where(a < 0, a + (1 << 32), a)
def _mlx_reinterpret_uint32(arr):
# Reinterpret an int32 bit pattern as unsigned (widen to int64 so 4-bit
# nibble shifts on the top byte don't sign-extend).
import mlx.core as mx
return arr.astype(mx.uint32).astype(mx.int64)



def _gptq_dequantize_weight(qweight, qzeros, scales, g_idx, bits=4):
    """AutoGPTQ packed weight -> dense [out, in] weight (HF nn.Linear orientation).

    Args:
        qweight: int32 array [in // pack, out]; each word packs ``pack`` input
            rows, where ``pack = 32 // bits``.
        qzeros: int32 array [groups, out // pack]; each word packs ``pack``
            output columns.
        scales: array [groups, out].
        g_idx: array [in] mapping each input row to its group (a real
            permutation when desc_act/act-order is enabled). ``None`` is
            accepted and treated as contiguous, equally sized groups.
        bits: bit width of the packed values. Must divide 32 evenly; 2, 4 and
            8 are accepted.

    Returns:
        float32 array of shape [out, in].

    Raises:
        ValueError: if ``bits`` is not one of 2, 4 or 8.

    Zeros use the AutoGPTQ ``stored + 1`` convention (symmetric models store a
    constant).
    """
    # Validate before touching MLX so a bad bit width fails with a clear
    # message instead of a reshape error (or silently wrong values).
    if bits not in (2, 4, 8):
        raise ValueError(
            "Unsloth: GPTQ dequantization supports bits in (2, 4, 8) only "
            f"(got bits={bits})."
        )

    import mlx.core as mx

    # The mask must follow ``bits``; a fixed 0xF is only right for 4-bit.
    mask = (1 << bits) - 1

    qw = _mlx_reinterpret_uint32(qweight)                      # [in//pack, out]
    qz = _mlx_reinterpret_uint32(qzeros)                       # [groups, out//pack]
    scales = scales.astype(mx.float32)                         # [groups, out]
    shifts = mx.arange(0, 32, bits).astype(mx.int64)           # [pack]

    w = mx.right_shift(qw[:, None, :], shifts[None, :, None]) & mask
    w = w.reshape(-1, qw.shape[1]).astype(mx.float32)          # [in, out]
    z = mx.right_shift(qz[:, :, None], shifts[None, None, :]) & mask
    z = z.reshape(qz.shape[0], -1).astype(mx.float32) + 1.0    # [groups, out]

    if g_idx is None:
        # No explicit mapping: assume contiguous groups of equal size.
        rows_per_group = max(w.shape[0] // max(scales.shape[0], 1), 1)
        g = (mx.arange(w.shape[0]) // rows_per_group).astype(mx.int32)
    else:
        g = g_idx.astype(mx.int32)                             # [in]

    eff = (w - z[g]) * scales[g]                               # [in, out]
    return mx.transpose(eff)                                   # [out, in]


def _awq_dequantize_weight(qweight, qzeros, scales, group_size, bits=4):
    """AutoAWQ GEMM 4-bit -> dense [out, in] weight (HF nn.Linear orientation).

    Args:
        qweight: int32 array [in, out // 8].
        qzeros: int32 array [groups, out // 8].
        scales: array [groups, out].
        group_size: number of input rows sharing one scale/zero row. A
            non-positive value (AWQ configs use ``-1``) or ``None`` means the
            group size is derived from ``scales`` instead, which for a single
            scale row is the full input dimension.
        bits: bit width of the packed values; only 4 is supported because
            ``_AWQ_REVERSE_ORDER`` describes the 8-nibble interleave.

    Returns:
        float32 array of shape [out, in].

    Raises:
        ValueError: if ``bits`` is not 4.

    qweight and qzeros pack 8 output cols per int32 with the AWQ interleave;
    ``_AWQ_REVERSE_ORDER`` restores natural column order to align with the
    (non-interleaved) scales. Note the reconstructed weight is the AWQ
    *smoothed* weight (per-channel scales are folded into the checkpoint),
    which is exactly what the forward pass needs.
    """
    # The reorder table below has 8 entries, so any other bit width would
    # unpack to the wrong column order. Fail loudly instead.
    if bits != 4:
        raise ValueError(
            "Unsloth: AWQ dequantization only supports the 4-bit GEMM packing "
            f"(got bits={bits})."
        )

    import mlx.core as mx

    qw = _mlx_reinterpret_uint32(qweight)                      # [in, out//8]
    qz = _mlx_reinterpret_uint32(qzeros)                       # [groups, out//8]
    scales = scales.astype(mx.float32)                         # [groups, out]
    pack = 32 // bits                                          # 8
    mask = (1 << bits) - 1
    shifts = mx.arange(0, 32, bits).astype(mx.int64)
    reorder = mx.array(
        [
            block * pack + offset
            for block in range(scales.shape[1] // pack)
            for offset in _AWQ_REVERSE_ORDER
        ]
    )

    def _unpack(x):
        nibbles = mx.right_shift(x[:, :, None], shifts[None, None, :]) & mask
        return nibbles.reshape(x.shape[0], -1)[:, reorder]

    w = _unpack(qw).astype(mx.float32)                         # [in, out]
    z = _unpack(qz).astype(mx.float32)                         # [groups, out]

    in_features = w.shape[0]
    # group_size=-1 means "one group over the whole input dimension". Passing
    # it straight into the floor division would yield 0, -1, -2, ... and index
    # the wrong scale/zero rows, so normalize it from the scales shape first.
    if group_size is None or group_size <= 0:
        group_size = max(in_features // max(scales.shape[0], 1), 1)
    g = (mx.arange(in_features) // group_size).astype(mx.int32)  # [in]

    eff = (w - z[g]) * scales[g]                               # [in, out]
    return mx.transpose(eff)                                   # [out, in]


def _detect_hf_prequant_method(config_data):
    """Return (method, quant_config_dict) for a GPTQ/AWQ repo, else (None, None).

    ``method`` is the lower-cased ``quant_method`` taken from
    ``config_data["quantization_config"]`` when it is one of
    ``_HF_RUNTIME_DEQUANT_METHODS``.
    """
    # NOTE(review): only the ``quantization_config`` entry of the supplied
    # config dict is inspected. A repo that keeps its GPTQ/AWQ metadata solely
    # in a separate quant_config.json (including legacy w_bit / q_group_size /
    # version fields) is therefore reported as not pre-quantized here.
    not_prequantized = (None, None)

    quant_config = (
        config_data.get("quantization_config", None)
        if isinstance(config_data, dict)
        else None
    )
    if isinstance(quant_config, dict):
        declared = str(quant_config.get("quant_method", "")).lower()
        if declared in _HF_RUNTIME_DEQUANT_METHODS:
            return declared, quant_config
    return not_prequantized


def _check_hf_prequant_layout(method, quant_config):
    """Return the bit width, or raise NotImplementedError for unreadable layouts."""
    bits = int(quant_config.get("bits", 4) or 4)
    if bits != 4:
        raise NotImplementedError(
            f"Unsloth: {method.upper()} runtime dequant on MLX currently supports "
            f"4-bit checkpoints only (got bits={bits})."
        )

    if method == "gptq":
        # Marlin (and any other non-default) GPTQ checkpoints are repacked and
        # cannot be read with the default AutoGPTQ qweight/qzeros/g_idx layout.
        fmt = str(
            quant_config.get("checkpoint_format")
            or quant_config.get("format")
            or "gptq"
        ).lower()
        if quant_config.get("is_marlin_format"):
            fmt = "marlin"
        if fmt != "gptq":
            raise NotImplementedError(
                f"Unsloth: GPTQ checkpoint format '{fmt}' is not supported on "
                "the MLX path; only the default AutoGPTQ ('gptq') packed "
                "layout can be dequantized."
            )
    else:
        # The AWQ unpacker assumes the GEMM layout with packed zero points.
        version = str(quant_config.get("version") or "gemm").lower()
        if version != "gemm":
            raise NotImplementedError(
                f"Unsloth: AWQ version '{version}' is not supported on the MLX "
                "path; only GEMM-packed checkpoints can be dequantized."
            )
        if quant_config.get("zero_point", True) is False:
            raise NotImplementedError(
                "Unsloth: AWQ checkpoints without zero points "
                "(zero_point=false) are not supported on the MLX path."
            )
    return bits


def _materialize_dequantized_hf_checkpoint(local_path, config_data, method, quant_config):
    """Dequantize a GPTQ/AWQ checkpoint to a temporary dense fp16 checkpoint.

    Args:
        local_path: directory holding the packed ``*.safetensors`` shards and
            the accompanying repo files.
        config_data: parsed config dict; copied, never mutated.
        method: ``"gptq"`` or anything else, which is treated as AWQ.
        quant_config: the checkpoint's quantization config dict.

    Returns:
        (temp_dir, new_config_data) where new_config_data has the HF
        quantization metadata stripped so the downstream MLX load treats it as
        an ordinary dense model (and may re-quantize it to MLX affine for LoRA).
        The caller owns ``temp_dir`` and is responsible for removing it.

    Raises:
        NotImplementedError: non-4-bit checkpoints, non-default GPTQ formats
            (e.g. Marlin), non-GEMM AWQ versions, or AWQ without zero points.
        FileNotFoundError: no ``*.safetensors`` files in ``local_path``.
        ValueError: no ``.qweight`` tensors, or a packed module is missing its
            ``.qzeros`` / ``.scales`` tensor.
    """
    # Reject unsupported layouts before any weights are read so the user gets
    # a clear message instead of a shape error or silently wrong weights.
    bits = _check_hf_prequant_layout(method, quant_config)
    group_size = int(quant_config.get("group_size", 128) or 128)

    import glob
    import shutil
    import mlx.core as mx

    shard_paths = sorted(glob.glob(os.path.join(local_path, "*.safetensors")))
    if not shard_paths:
        raise FileNotFoundError(
            f"Unsloth: no .safetensors weights found in '{local_path}' for "
            f"{method.upper()} dequantization."
        )
    weights = {}
    for shard in shard_paths:
        weights.update(mx.load(shard))

    quant_modules = sorted(
        {k[: -len(".qweight")] for k in weights if k.endswith(".qweight")}
    )
    if not quant_modules:
        raise ValueError(
            f"Unsloth: '{local_path}' is declared {method.upper()} but no packed "
            "'.qweight' tensors were found."
        )

    new_weights = {}
    quant_related = set()
    for name in quant_modules:
        missing = [
            suffix for suffix in (".qzeros", ".scales") if name + suffix not in weights
        ]
        if missing:
            raise ValueError(
                f"Unsloth: {method.upper()} module '{name}' is missing packed "
                f"tensor(s) {missing}; the checkpoint layout is not supported."
            )
        qweight = weights[name + ".qweight"]
        qzeros = weights[name + ".qzeros"]
        scales = weights[name + ".scales"]
        if method == "gptq":
            g_idx = weights.get(name + ".g_idx")
            if g_idx is None:
                # No stored row->group map: fall back to contiguous groups.
                in_features = qweight.shape[0] * (32 // bits)
                rows_per_group = max(in_features // max(scales.shape[0], 1), 1)
                g_idx = mx.arange(in_features) // rows_per_group
            dense = _gptq_dequantize_weight(qweight, qzeros, scales, g_idx, bits=bits)
            quant_related.update(
                name + suffix
                for suffix in (".qweight", ".qzeros", ".scales", ".g_idx")
            )
        else:
            # AWQ uses group_size=-1 for "one group over the whole input
            # dimension"; normalize it so group indices stay non-negative.
            module_group_size = group_size if group_size > 0 else qweight.shape[0]
            dense = _awq_dequantize_weight(
                qweight, qzeros, scales, module_group_size, bits=bits
            )
            quant_related.update(
                name + suffix for suffix in (".qweight", ".qzeros", ".scales")
            )
        new_weights[name + ".weight"] = dense.astype(mx.float16)

    # GPTQ QuantLinear allocates a zero bias even for architectures that have
    # no bias (e.g. Llama); drop those so the dense checkpoint matches the
    # target module tree. Only biases belonging to dequantized modules are
    # candidates, so zero-valued biases on unquantized modules are always kept.
    # Real (non-zero) biases (e.g. Qwen2 q/k/v) are preserved.
    # NOTE(review): a genuinely all-zero bias on a quantized, bias-enabled
    # layer is still indistinguishable from the synthetic one and is dropped.
    synthetic_bias_candidates = {name + ".bias" for name in quant_modules}
    for key, tensor in weights.items():
        if key in quant_related:
            continue
        if key in synthetic_bias_candidates and bool(mx.all(tensor == 0).item()):
            continue
        new_weights[key] = tensor

    mx.eval(list(new_weights.values()))

    temp_dir = tempfile.mkdtemp(prefix="unsloth_mlx_dequant_")
    try:
        for filename in os.listdir(local_path):
            src = os.path.join(local_path, filename)
            if not os.path.isfile(src):
                continue
            # Packed shards and their index are replaced by the dense file.
            if filename.endswith((".safetensors", ".safetensors.index.json")):
                continue
            shutil.copy(src, os.path.join(temp_dir, filename))

        new_config_data = dict(config_data)
        new_config_data.pop("quantization_config", None)
        new_config_data.pop("quantization", None)
        with open(os.path.join(temp_dir, "config.json"), "w") as f:
            json.dump(new_config_data, f, indent=2)

        mx.save_safetensors(os.path.join(temp_dir, "model.safetensors"), new_weights)
    except BaseException:
        # Do not leave a partially written multi-GB checkpoint behind.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    return temp_dir, new_config_data
Comment on lines +2803 to +2890

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using tempfile.mkdtemp can lead to persistent resource leaks of several gigabytes on disk if an exception is raised during the loading process (e.g., due to network issues, missing files, or shape mismatches). To ensure robust cleanup, use tempfile.TemporaryDirectory which automatically deletes the directory and its contents when garbage collected or when .cleanup() is called.

def _materialize_dequantized_hf_checkpoint(local_path, config_data, method, quant_config):
    """Dequantize a GPTQ/AWQ checkpoint to a temporary dense fp16 checkpoint.

    Args:
        local_path: directory holding the packed ``*.safetensors`` shards and
            the accompanying repo files.
        config_data: parsed config dict; copied, never mutated.
        method: ``"gptq"`` or anything else, which is treated as AWQ.
        quant_config: the checkpoint's quantization config dict.

    Returns:
        (temp_dir_obj, temp_dir, new_config_data). ``temp_dir_obj`` is the
        ``tempfile.TemporaryDirectory`` that owns the directory, ``temp_dir``
        is its path, and new_config_data has the HF quantization metadata
        stripped so the downstream MLX load treats it as an ordinary dense
        model (and may re-quantize it to MLX affine for LoRA).

    Raises:
        NotImplementedError: the checkpoint is not 4-bit.
        FileNotFoundError: no ``*.safetensors`` files in ``local_path``.
        ValueError: no ``.qweight`` tensors were found.
    """
    # Check the bit width first so an unsupported checkpoint is rejected
    # before any weights are read.
    bits = int(quant_config.get("bits", 4) or 4)
    group_size = int(quant_config.get("group_size", 128) or 128)
    if bits != 4:
        raise NotImplementedError(
            f"Unsloth: {method.upper()} runtime dequant on MLX currently supports "
            f"4-bit checkpoints only (got bits={bits})."
        )

    import glob
    import shutil
    import mlx.core as mx

    shard_paths = sorted(glob.glob(os.path.join(local_path, "*.safetensors")))
    if not shard_paths:
        raise FileNotFoundError(
            f"Unsloth: no .safetensors weights found in '{local_path}' for "
            f"{method.upper()} dequantization."
        )
    weights = {}
    for shard in shard_paths:
        weights.update(mx.load(shard))

    quant_modules = sorted(
        {k[: -len(".qweight")] for k in weights if k.endswith(".qweight")}
    )
    if not quant_modules:
        raise ValueError(
            f"Unsloth: '{local_path}' is declared {method.upper()} but no packed "
            "'.qweight' tensors were found."
        )

    new_weights = {}
    quant_related = set()
    for name in quant_modules:
        qweight = weights[name + ".qweight"]
        qzeros = weights[name + ".qzeros"]
        scales = weights[name + ".scales"]
        if method == "gptq":
            g_idx = weights[name + ".g_idx"]
            dense = _gptq_dequantize_weight(qweight, qzeros, scales, g_idx, bits=bits)
            quant_related.update(
                name + suffix
                for suffix in (".qweight", ".qzeros", ".scales", ".g_idx")
            )
        else:
            dense = _awq_dequantize_weight(qweight, qzeros, scales, group_size, bits=bits)
            quant_related.update(
                name + suffix for suffix in (".qweight", ".qzeros", ".scales")
            )
        new_weights[name + ".weight"] = dense.astype(mx.float16)

    for key, tensor in weights.items():
        if key in quant_related:
            continue
        # GPTQ QuantLinear allocates a zero bias even for architectures that
        # have no bias (e.g. Llama); drop those so the dense checkpoint matches
        # the target module tree. Real (non-zero) biases (e.g. Qwen2 q/k/v) are
        # preserved.
        if key.endswith(".bias") and bool(mx.all(tensor == 0).item()):
            continue
        new_weights[key] = tensor

    mx.eval(list(new_weights.values()))

    temp_dir_obj = tempfile.TemporaryDirectory(prefix="unsloth_mlx_dequant_")
    temp_dir = temp_dir_obj.name
    try:
        for filename in os.listdir(local_path):
            src = os.path.join(local_path, filename)
            if not os.path.isfile(src):
                continue
            # Packed shards and their index are replaced by the dense file.
            if filename.endswith(".safetensors") or filename.endswith(".safetensors.index.json"):
                continue
            shutil.copy(src, os.path.join(temp_dir, filename))

        new_config_data = dict(config_data)
        new_config_data.pop("quantization_config", None)
        new_config_data.pop("quantization", None)
        with open(os.path.join(temp_dir, "config.json"), "w") as f:
            json.dump(new_config_data, f, indent=2)

        mx.save_safetensors(os.path.join(temp_dir, "model.safetensors"), new_weights)
        return temp_dir_obj, temp_dir, new_config_data
    except Exception:
        # Remove the partially written checkpoint before propagating.
        temp_dir_obj.cleanup()
        raise



def _apply_dense_nf4_quantization(model, config, spec: _MLXQuantizationSpec, predicate):
import mlx.core as mx

Expand Down Expand Up @@ -3635,6 +3824,60 @@ def from_pretrained(
config_data,
)

# Preserve the caller-facing identity: GPTQ/AWQ dequant reroutes the
# load through a temp dir, but metadata (_hf_repo/_src_path) must keep
# pointing at the original repo so save/reload resolve correctly.
original_model_name = model_name
original_local_path = local_path

# GPTQ/AWQ pre-quantized checkpoints: mlx-lm can't load their packed
# weights. Dequantize to a temporary dense fp16 checkpoint and load
# that; the runtime-quant path below then re-quantizes to MLX affine
# for the LoRA base (bnb NF4->fp16->MLX-4bit style flow).
dequant_temp_dir = None
hf_prequant_method, hf_prequant_config = _detect_hf_prequant_method(config_data)
if hf_prequant_method is not None:
if local_path is None:
raise FileNotFoundError(
f"Unsloth: could not resolve local files for "
f"{hf_prequant_method.upper()} model '{model_name}'."
)
if _is_vlm(config_data):
raise NotImplementedError(
f"Unsloth: {hf_prequant_method.upper()} runtime dequant is not "
"yet supported for vision models on MLX. Load an unquantized "
"VLM base for LoRA instead."
)
print(
f"Unsloth: Detected {hf_prequant_method.upper()} pre-quantized "
f"checkpoint '{model_name}'; dequantizing to fp16 for MLX "
"(LoRA base will be re-quantized to MLX affine)..."
)
dequant_dir, config_data = _materialize_dequantized_hf_checkpoint(
original_local_path, config_data, hf_prequant_method, hf_prequant_config,
)
local_path = dequant_dir
model_name = dequant_dir
dequant_temp_dir = dequant_dir
Comment on lines +3837 to +3861

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Update the dequantization setup to accept and track the tempfile.TemporaryDirectory object returned by the updated _materialize_dequantized_hf_checkpoint function.

        dequant_temp_dir_obj = None
        hf_prequant_method, hf_prequant_config = _detect_hf_prequant_method(config_data)
        if hf_prequant_method is not None:
            if local_path is None:
                raise FileNotFoundError(
                    f"Unsloth: could not resolve local files for "
                    f"{hf_prequant_method.upper()} model '{model_name}'."
                )
            if _is_vlm(config_data):
                raise NotImplementedError(
                    f"Unsloth: {hf_prequant_method.upper()} runtime dequant is not "
                    "yet supported for vision models on MLX. Load an unquantized "
                    "VLM base for LoRA instead."
                )
            print(
                f"Unsloth: Detected {hf_prequant_method.upper()} pre-quantized "
                f"checkpoint '{model_name}'; dequantizing to fp16 for MLX "
                "(LoRA base will be re-quantized to MLX affine)..."
            )
            dequant_dir_obj, dequant_dir, config_data = _materialize_dequantized_hf_checkpoint(
                original_local_path, config_data, hf_prequant_method, hf_prequant_config,
            )
            local_path = dequant_dir
            model_name = dequant_dir
            dequant_temp_dir_obj = dequant_dir_obj

else:
# A recognized-but-unsupported packed quant format must fail loud
# with a clear message instead of misrouting into the generic
# MLX-compatibility check.
_other_quant = (
config_data.get("quantization_config")
if isinstance(config_data, dict) else None
)
if isinstance(_other_quant, dict):
_other_method = str(_other_quant.get("quant_method", "")).lower()
if _other_method in _HF_UNSUPPORTED_PACKED_METHODS:
raise NotImplementedError(
f"Unsloth: '{model_name}' uses '{_other_method}' "
"quantization, which is not supported on the MLX path. "
"Supported pre-quantized formats are GPTQ and AWQ "
"(dequantized to MLX affine for LoRA); otherwise load an "
"unquantized or MLX-quantized checkpoint."
)

# Reject full_finetuning on a pre-quantized repo: int4/int8 weights
# aren't trainable (our CCE backward zeros the quantized weight grad),
# so full FT would silently update only LayerNorms/biases.
Expand Down Expand Up @@ -4218,10 +4461,10 @@ def from_pretrained(
model._is_vlm_model = False

model._config = config
model._hf_repo = model_name
model._src_path = local_path
model._hf_repo = original_model_name
model._src_path = original_local_path
model._unsloth_base_revision = revision
model._unsloth_base_commit_hash = _infer_snapshot_commit(local_path)
model._unsloth_base_commit_hash = _infer_snapshot_commit(original_local_path)
model.max_seq_length = max_seq_length
model._unsloth_patch_mode = patch_mode
model._unsloth_full_finetuning = bool(full_finetuning)
Expand All @@ -4232,6 +4475,15 @@ def from_pretrained(
_patch_mixed_precision_set_dtype(model)

_patch_mlx_saving(model, tokenizer)

if dequant_temp_dir is not None:
# The dequantized weights are now materialized in memory; the
# temporary fp16 checkpoint on disk is no longer referenced.
import mlx.core as mx
import shutil

mx.eval(model.parameters())
shutil.rmtree(dequant_temp_dir, ignore_errors=True)
return model, tokenizer
Comment on lines +4479 to 4487

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Clean up the temporary directory using the TemporaryDirectory.cleanup() method inside a try...finally block to ensure robust cleanup even if mx.eval fails, and gracefully ignore any errors during cleanup.

            if dequant_temp_dir_obj is not None:
                # The dequantized weights are now materialized in memory; the
                # temporary fp16 checkpoint on disk is no longer referenced.
                import mlx.core as mx

                try:
                    mx.eval(model.parameters())
                finally:
                    try:
                        dequant_temp_dir_obj.cleanup()
                    except Exception:
                        pass


@staticmethod
Expand Down
Loading