Skip to content

Load prequantized bnb 4-bit MoE expert checkpoints under transformers v5#856

Open
danielhanchen wants to merge 1 commit into
mainfrom
bnb4bit-prequant-moe-loading
Open

Load prequantized bnb 4-bit MoE expert checkpoints under transformers v5#856
danielhanchen wants to merge 1 commit into
mainfrom
bnb4bit-prequant-moe-loading

Conversation

@danielhanchen

@danielhanchen danielhanchen commented Jul 3, 2026

Copy link
Copy Markdown
Member

Summary

Older unsloth-bnb-4bit MoE checkpoints store one quantized tensor per expert per projection (experts.<N>.gate_proj.weight plus absmax/quant_map/quant_state aux keys), and also quantize small non-Linear weights such as the MoE router. These checkpoints fail to load under transformers v5: the model's MergeModulelist converters match the per-expert keys first, byte-concatenate the packed uint8 as if it were bf16, and drop every aux key as unexpected, while the router lands as raw packed bytes on a plain nn.Parameter slot and F.linear fails with a reduction dim mismatch.

What this does

All three patches are active only when a Bnb4BitHfQuantizer is pre_quantized:

  1. patch_bnb4bit_model_conversion_mapping prepends quantized twins of the model's per-expert merge converters. Each twin collects the aux keys, rebuilds one QuantState per expert, concatenates the packed bytes in target element order (NF4 packing is elementwise row-major, so this is exact), and reassembles a single stacked Params4bit with a flat fp32 absmax. Bare weight patterns are anchored so the twins never shadow the model converters for bf16 checkpoints, and layers without aux keys (dynamic-quant skip layers) fall back to the model's own merge.
  2. patch_bnb4bit_quantizer_weight_conversions registers deserializers for fused expert param names (gate_up_proj/down_proj), which the stock quantizer only registers for params literally named weight.
  3. patch_bnb4bit_dequantize_plain_params dequantizes any Params4bit that landed on a module that is neither a bnb Linear4bit nor a v5 experts module (e.g. the router) back to a plain float Parameter after loading.

Testing

  • Dequantized expert weights verified bitwise identical to a reference load.
  • End-to-end 4-bit LoRA training validated on Qwen3-30B-A3B (finite losses, 4-bit resident memory).
  • 5 new CPU unit tests cover twin construction, pattern anchoring, the unquantized-layer fallback, and absmax denesting.
  • Validated across transformers 4.57.6, 5.5, and 5.13 (the last routes stacked experts through PEFT target_parameters, which needs the logical 3D shape now set on the reassembled param).

@danielhanchen danielhanchen requested a review from Datta0 as a code owner July 3, 2026 18:39
@chatgpt-codex-connector

Copy link
Copy Markdown

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Repo admins can enable using credits for code reviews in their settings.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for loading prequantized bitsandbytes 4-bit MoE checkpoints under transformers v5, including custom weight converters, per-expert stack deserialization, and a dequantization pass for plain parameter slots, along with corresponding unit tests. The feedback highlights critical issues where the newly created Params4bit parameters lack the _original_shape attribute, which will cause PEFT LoRA to crash during training or loading. Additionally, potential KeyError and IndexError vulnerabilities were identified in the per-expert deserialization logic when handling missing or mismatched source keys.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +750 to +806
counts = {len(input_dict[a]) for a in self.anchored_sources if a in input_dict}
if len(counts) != 1:
raise ValueError(
f"Unsloth: inconsistent per-expert tensor counts {counts} for {full_layer_name}"
)
num_experts = counts.pop()

device = input_dict[self.anchored_sources[0]][0].device
first_qs = None
src_shapes = []
per_src_absmax = [] # [src][expert] fp32 absmax
for base, anch in zip(self.base_sources, self.anchored_sources):
absmax_list = []
for e in range(num_experts):
qd = {}
for suf in _AUX_SUFFIXES:
vals = input_dict.get(base + suf)
if vals is not None:
qd["weight" + suf] = vals[e]
qs = QuantState.from_dict(qs_dict=qd, device=device)
if first_qs is None:
first_qs = qs
if e == 0:
src_shapes.append(tuple(qs.shape))
absmax_list.append(_quantstate_absmax_fp32(qs))
per_src_absmax.append(absmax_list)

out_dim = sum(s[0] for s in src_shapes)
in_dim = src_shapes[0][1]
if any(s[1] != in_dim for s in src_shapes):
raise ValueError(f"Unsloth: mismatched expert in_dims {src_shapes} for {full_layer_name}")

packed_rows, absmax_rows = [], []
for e in range(num_experts):
packed_rows.append(torch.cat(
[input_dict[a][e].reshape(-1) for a in self.anchored_sources]
))
absmax_rows.append(torch.cat([per_src_absmax[i][e] for i in range(len(self.base_sources))]))
data = torch.stack(packed_rows).unsqueeze(-1) # (E, bytes_per_expert, 1)
absmax = torch.cat(absmax_rows)

quant_state = QuantState(
absmax=absmax,
shape=torch.Size((num_experts, out_dim, in_dim)),
code=first_qs.code.to(device),
blocksize=first_qs.blocksize,
quant_type=first_qs.quant_type,
dtype=first_qs.dtype,
)
new_param = torch.Tensor._make_subclass(Params4bit, data.to(device))
new_param.requires_grad = False
new_param.quant_state = quant_state
new_param.blocksize = quant_state.blocksize
new_param.compress_statistics = False
new_param.quant_type = quant_state.quant_type
new_param.quant_storage = data.dtype
new_param.bnb_quantized = True

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

In _PerExpertStackDeserialize.convert, there are several potential correctness and safety issues:

  1. Missing _original_shape: The newly created Params4bit parameter does not have the _original_shape attribute set. This will cause PEFT LoRA's ParamWrapper.get_param() to raise a ValueError and crash during training or loading.
  2. Potential KeyError: Accessing input_dict[a] for all a in self.anchored_sources assumes all sources are present. However, the check if a in input_dict at line 750 indicates some might be missing. If any are missing, it will crash with a KeyError.
  3. Potential IndexError: Accessing vals[e] assumes vals has at least num_experts elements, which might not be true if some expert aux keys are missing or mismatched.
  4. Potential KeyError/IndexError on device retrieval: Accessing input_dict[self.anchored_sources[0]][0] is unsafe if the first source is missing or empty.

We can resolve all of these issues by filtering self.base_sources and self.anchored_sources to only include present sources, performing safe bounds checks, and explicitly setting _original_shape on the returned parameter.

            present_sources = [
                (base, anch)
                for base, anch in zip(self.base_sources, self.anchored_sources)
                if anch in input_dict
            ]
            if not present_sources:
                raise ValueError(f"Unsloth: no expert weights found in input_dict for {full_layer_name}")

            counts = {len(input_dict[anch]) for _, anch in present_sources}
            if len(counts) != 1:
                raise ValueError(
                    f"Unsloth: inconsistent per-expert tensor counts {counts} for {full_layer_name}"
                )
            num_experts = counts.pop()

            device = input_dict[present_sources[0][1]][0].device
            first_qs = None
            src_shapes = []
            per_src_absmax = []  # [src][expert] fp32 absmax
            for base, anch in present_sources:
                absmax_list = []
                for e in range(num_experts):
                    qd = {}
                    for suf in _AUX_SUFFIXES:
                        vals = input_dict.get(base + suf)
                        if vals is not None and e < len(vals):
                            qd["weight" + suf] = vals[e]
                    qs = QuantState.from_dict(qs_dict=qd, device=device)
                    if first_qs is None:
                        first_qs = qs
                    if e == 0:
                        src_shapes.append(tuple(qs.shape))
                    absmax_list.append(_quantstate_absmax_fp32(qs))
                per_src_absmax.append(absmax_list)

            out_dim = sum(s[0] for s in src_shapes)
            in_dim = src_shapes[0][1]
            if any(s[1] != in_dim for s in src_shapes):
                raise ValueError(f"Unsloth: mismatched expert in_dims {src_shapes} for {full_layer_name}")

            packed_rows, absmax_rows = [], []
            for e in range(num_experts):
                packed_rows.append(torch.cat(
                    [input_dict[anch][e].reshape(-1) for _, anch in present_sources]
                ))
                absmax_rows.append(torch.cat([per_src_absmax[i][e] for i in range(len(present_sources))]))
            data = torch.stack(packed_rows).unsqueeze(-1)  # (E, bytes_per_expert, 1)
            absmax = torch.cat(absmax_rows)

            quant_state = QuantState(
                absmax=absmax,
                shape=torch.Size((num_experts, out_dim, in_dim)),
                code=first_qs.code.to(device),
                blocksize=first_qs.blocksize,
                quant_type=first_qs.quant_type,
                dtype=first_qs.dtype,
            )
            new_param = torch.Tensor._make_subclass(Params4bit, data.to(device))
            new_param.requires_grad = False
            new_param.quant_state = quant_state
            new_param._original_shape = quant_state.shape
            new_param.blocksize = quant_state.blocksize
            new_param.compress_statistics = False
            new_param.quant_type = quant_state.quant_type
            new_param.quant_storage = data.dtype
            new_param.bnb_quantized = True

Comment on lines +620 to +628
new_value = Params4bit.from_prequantized(
data=weight,
quantized_stats=input_dict,
requires_grad=False,
device=weight.device,
module=module,
)
module._is_hf_initialized = True
return {self.param_name: new_value}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

In _ExpertDeserialize.convert, the deserialized Params4bit parameter does not have the _original_shape attribute set. This custom attribute is required by PEFT LoRA to recover the logical 3D shape of the expert weights; without it, ParamWrapper.get_param() will raise a ValueError and crash. We should set _original_shape to new_value.quant_state.shape if quant_state is available.

Suggested change
new_value = Params4bit.from_prequantized(
data=weight,
quantized_stats=input_dict,
requires_grad=False,
device=weight.device,
module=module,
)
module._is_hf_initialized = True
return {self.param_name: new_value}
new_value = Params4bit.from_prequantized(
data=weight,
quantized_stats=input_dict,
requires_grad=False,
device=weight.device,
module=module,
)
if getattr(new_value, "quant_state", None) is not None:
new_value._original_shape = new_value.quant_state.shape
module._is_hf_initialized = True
return {self.param_name: new_value}


patched_get_weight_conversions._unsloth_moe_patched = True
patch_function(Bnb4BitHfQuantizer, "get_weight_conversions", patched_get_weight_conversions)
pass
conversion_mapping.get_model_conversion_mapping = patched_get_model_conversion_mapping
if getattr(modeling_utils, "get_model_conversion_mapping", None) is original:
modeling_utils.get_model_conversion_mapping = patched_get_model_conversion_mapping
pass
Bnb4BitHfQuantizer, "_process_model_after_weight_loading", patched_after_load,
match_level="relaxed",
)
pass
f"Unsloth: dequantized non-Linear 4-bit param {module_name}.{name} "
f"to {tuple(dequant.shape)} {dequant.dtype}"
)
pass
return converters


def patch_bnb4bit_quantizer_weight_conversions():
return twins


def patch_bnb4bit_model_conversion_mapping():
pass


def patch_bnb4bit_dequantize_plain_params():
twin = twins[0]

# Aux keys come first so they are collected alongside the weights.
n_base = len(conv.source_patterns)
Older unsloth-bnb-4bit MoE checkpoints store one quantized tensor per expert
per projection plus absmax/quant_map/quant_state aux keys, and also quantize
small non-Linear weights such as the MoE router. Under transformers v5 the
model's MergeModulelist converters match the per-expert keys first, byte
concatenate the packed uint8 as if it were bf16, and drop every aux key as
unexpected; the router lands as raw packed bytes on a plain nn.Parameter slot.

Fix in three parts, active only when a Bnb4BitHfQuantizer is pre_quantized:

1. patch_bnb4bit_model_conversion_mapping prepends quantized twins of the
   model's per-expert merge converters. Each twin collects the aux keys,
   rebuilds one QuantState per expert, concatenates the packed bytes in
   target element order (NF4 packing is elementwise row-major, so this is
   exact), and reassembles a single stacked Params4bit with a flat fp32
   absmax. Bare weight patterns are anchored so the twins never shadow the
   model converters for bf16 checkpoints, and layers without aux keys fall
   back to the model's own merge.

2. patch_bnb4bit_quantizer_weight_conversions registers deserializers for
   fused expert param names (gate_up_proj/down_proj), which the stock
   quantizer only registers for params literally named weight.

3. patch_bnb4bit_dequantize_plain_params dequantizes any Params4bit that
   landed on a module that is neither a bnb Linear4bit nor a v5 experts
   module (e.g. the router) back to a plain float Parameter after loading.

Verified bitwise identical dequantized expert weights against a reference
load and end-to-end 4-bit MoE training on Qwen3-30B-A3B.
@danielhanchen danielhanchen force-pushed the bnb4bit-prequant-moe-loading branch from 3b8819b to 3d70f00 Compare July 4, 2026 02:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant