Load prequantized bnb 4-bit MoE expert checkpoints under transformers v5#856
Load prequantized bnb 4-bit MoE expert checkpoints under transformers v5#856danielhanchen wants to merge 1 commit into
Conversation
|
Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits. |
There was a problem hiding this comment.
Code Review
This pull request introduces support for loading prequantized bitsandbytes 4-bit MoE checkpoints under transformers v5, including custom weight converters, per-expert stack deserialization, and a dequantization pass for plain parameter slots, along with corresponding unit tests. The feedback highlights critical issues where the newly created Params4bit parameters lack the _original_shape attribute, which will cause PEFT LoRA to crash during training or loading. Additionally, potential KeyError and IndexError vulnerabilities were identified in the per-expert deserialization logic when handling missing or mismatched source keys.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| counts = {len(input_dict[a]) for a in self.anchored_sources if a in input_dict} | ||
| if len(counts) != 1: | ||
| raise ValueError( | ||
| f"Unsloth: inconsistent per-expert tensor counts {counts} for {full_layer_name}" | ||
| ) | ||
| num_experts = counts.pop() | ||
|
|
||
| device = input_dict[self.anchored_sources[0]][0].device | ||
| first_qs = None | ||
| src_shapes = [] | ||
| per_src_absmax = [] # [src][expert] fp32 absmax | ||
| for base, anch in zip(self.base_sources, self.anchored_sources): | ||
| absmax_list = [] | ||
| for e in range(num_experts): | ||
| qd = {} | ||
| for suf in _AUX_SUFFIXES: | ||
| vals = input_dict.get(base + suf) | ||
| if vals is not None: | ||
| qd["weight" + suf] = vals[e] | ||
| qs = QuantState.from_dict(qs_dict=qd, device=device) | ||
| if first_qs is None: | ||
| first_qs = qs | ||
| if e == 0: | ||
| src_shapes.append(tuple(qs.shape)) | ||
| absmax_list.append(_quantstate_absmax_fp32(qs)) | ||
| per_src_absmax.append(absmax_list) | ||
|
|
||
| out_dim = sum(s[0] for s in src_shapes) | ||
| in_dim = src_shapes[0][1] | ||
| if any(s[1] != in_dim for s in src_shapes): | ||
| raise ValueError(f"Unsloth: mismatched expert in_dims {src_shapes} for {full_layer_name}") | ||
|
|
||
| packed_rows, absmax_rows = [], [] | ||
| for e in range(num_experts): | ||
| packed_rows.append(torch.cat( | ||
| [input_dict[a][e].reshape(-1) for a in self.anchored_sources] | ||
| )) | ||
| absmax_rows.append(torch.cat([per_src_absmax[i][e] for i in range(len(self.base_sources))])) | ||
| data = torch.stack(packed_rows).unsqueeze(-1) # (E, bytes_per_expert, 1) | ||
| absmax = torch.cat(absmax_rows) | ||
|
|
||
| quant_state = QuantState( | ||
| absmax=absmax, | ||
| shape=torch.Size((num_experts, out_dim, in_dim)), | ||
| code=first_qs.code.to(device), | ||
| blocksize=first_qs.blocksize, | ||
| quant_type=first_qs.quant_type, | ||
| dtype=first_qs.dtype, | ||
| ) | ||
| new_param = torch.Tensor._make_subclass(Params4bit, data.to(device)) | ||
| new_param.requires_grad = False | ||
| new_param.quant_state = quant_state | ||
| new_param.blocksize = quant_state.blocksize | ||
| new_param.compress_statistics = False | ||
| new_param.quant_type = quant_state.quant_type | ||
| new_param.quant_storage = data.dtype | ||
| new_param.bnb_quantized = True |
There was a problem hiding this comment.
In _PerExpertStackDeserialize.convert, there are several potential correctness and safety issues:
- Missing
_original_shape: The newly createdParams4bitparameter does not have the_original_shapeattribute set. This will cause PEFT LoRA'sParamWrapper.get_param()to raise aValueErrorand crash during training or loading. - Potential
KeyError: Accessinginput_dict[a]for allainself.anchored_sourcesassumes all sources are present. However, the checkif a in input_dictat line 750 indicates some might be missing. If any are missing, it will crash with aKeyError. - Potential
IndexError: Accessingvals[e]assumesvalshas at leastnum_expertselements, which might not be true if some expert aux keys are missing or mismatched. - Potential
KeyError/IndexErroron device retrieval: Accessinginput_dict[self.anchored_sources[0]][0]is unsafe if the first source is missing or empty.
We can resolve all of these issues by filtering self.base_sources and self.anchored_sources to only include present sources, performing safe bounds checks, and explicitly setting _original_shape on the returned parameter.
present_sources = [
(base, anch)
for base, anch in zip(self.base_sources, self.anchored_sources)
if anch in input_dict
]
if not present_sources:
raise ValueError(f"Unsloth: no expert weights found in input_dict for {full_layer_name}")
counts = {len(input_dict[anch]) for _, anch in present_sources}
if len(counts) != 1:
raise ValueError(
f"Unsloth: inconsistent per-expert tensor counts {counts} for {full_layer_name}"
)
num_experts = counts.pop()
device = input_dict[present_sources[0][1]][0].device
first_qs = None
src_shapes = []
per_src_absmax = [] # [src][expert] fp32 absmax
for base, anch in present_sources:
absmax_list = []
for e in range(num_experts):
qd = {}
for suf in _AUX_SUFFIXES:
vals = input_dict.get(base + suf)
if vals is not None and e < len(vals):
qd["weight" + suf] = vals[e]
qs = QuantState.from_dict(qs_dict=qd, device=device)
if first_qs is None:
first_qs = qs
if e == 0:
src_shapes.append(tuple(qs.shape))
absmax_list.append(_quantstate_absmax_fp32(qs))
per_src_absmax.append(absmax_list)
out_dim = sum(s[0] for s in src_shapes)
in_dim = src_shapes[0][1]
if any(s[1] != in_dim for s in src_shapes):
raise ValueError(f"Unsloth: mismatched expert in_dims {src_shapes} for {full_layer_name}")
packed_rows, absmax_rows = [], []
for e in range(num_experts):
packed_rows.append(torch.cat(
[input_dict[anch][e].reshape(-1) for _, anch in present_sources]
))
absmax_rows.append(torch.cat([per_src_absmax[i][e] for i in range(len(present_sources))]))
data = torch.stack(packed_rows).unsqueeze(-1) # (E, bytes_per_expert, 1)
absmax = torch.cat(absmax_rows)
quant_state = QuantState(
absmax=absmax,
shape=torch.Size((num_experts, out_dim, in_dim)),
code=first_qs.code.to(device),
blocksize=first_qs.blocksize,
quant_type=first_qs.quant_type,
dtype=first_qs.dtype,
)
new_param = torch.Tensor._make_subclass(Params4bit, data.to(device))
new_param.requires_grad = False
new_param.quant_state = quant_state
new_param._original_shape = quant_state.shape
new_param.blocksize = quant_state.blocksize
new_param.compress_statistics = False
new_param.quant_type = quant_state.quant_type
new_param.quant_storage = data.dtype
new_param.bnb_quantized = True| new_value = Params4bit.from_prequantized( | ||
| data=weight, | ||
| quantized_stats=input_dict, | ||
| requires_grad=False, | ||
| device=weight.device, | ||
| module=module, | ||
| ) | ||
| module._is_hf_initialized = True | ||
| return {self.param_name: new_value} |
There was a problem hiding this comment.
In _ExpertDeserialize.convert, the deserialized Params4bit parameter does not have the _original_shape attribute set. This custom attribute is required by PEFT LoRA to recover the logical 3D shape of the expert weights; without it, ParamWrapper.get_param() will raise a ValueError and crash. We should set _original_shape to new_value.quant_state.shape if quant_state is available.
| new_value = Params4bit.from_prequantized( | |
| data=weight, | |
| quantized_stats=input_dict, | |
| requires_grad=False, | |
| device=weight.device, | |
| module=module, | |
| ) | |
| module._is_hf_initialized = True | |
| return {self.param_name: new_value} | |
| new_value = Params4bit.from_prequantized( | |
| data=weight, | |
| quantized_stats=input_dict, | |
| requires_grad=False, | |
| device=weight.device, | |
| module=module, | |
| ) | |
| if getattr(new_value, "quant_state", None) is not None: | |
| new_value._original_shape = new_value.quant_state.shape | |
| module._is_hf_initialized = True | |
| return {self.param_name: new_value} |
|
|
||
| patched_get_weight_conversions._unsloth_moe_patched = True | ||
| patch_function(Bnb4BitHfQuantizer, "get_weight_conversions", patched_get_weight_conversions) | ||
| pass |
| conversion_mapping.get_model_conversion_mapping = patched_get_model_conversion_mapping | ||
| if getattr(modeling_utils, "get_model_conversion_mapping", None) is original: | ||
| modeling_utils.get_model_conversion_mapping = patched_get_model_conversion_mapping | ||
| pass |
| Bnb4BitHfQuantizer, "_process_model_after_weight_loading", patched_after_load, | ||
| match_level="relaxed", | ||
| ) | ||
| pass |
| f"Unsloth: dequantized non-Linear 4-bit param {module_name}.{name} " | ||
| f"to {tuple(dequant.shape)} {dequant.dtype}" | ||
| ) | ||
| pass |
| return converters | ||
|
|
||
|
|
||
| def patch_bnb4bit_quantizer_weight_conversions(): |
| return twins | ||
|
|
||
|
|
||
| def patch_bnb4bit_model_conversion_mapping(): |
| pass | ||
|
|
||
|
|
||
| def patch_bnb4bit_dequantize_plain_params(): |
| twin = twins[0] | ||
|
|
||
| # Aux keys come first so they are collected alongside the weights. | ||
| n_base = len(conv.source_patterns) |
Older unsloth-bnb-4bit MoE checkpoints store one quantized tensor per expert per projection plus absmax/quant_map/quant_state aux keys, and also quantize small non-Linear weights such as the MoE router. Under transformers v5 the model's MergeModulelist converters match the per-expert keys first, byte concatenate the packed uint8 as if it were bf16, and drop every aux key as unexpected; the router lands as raw packed bytes on a plain nn.Parameter slot. Fix in three parts, active only when a Bnb4BitHfQuantizer is pre_quantized: 1. patch_bnb4bit_model_conversion_mapping prepends quantized twins of the model's per-expert merge converters. Each twin collects the aux keys, rebuilds one QuantState per expert, concatenates the packed bytes in target element order (NF4 packing is elementwise row-major, so this is exact), and reassembles a single stacked Params4bit with a flat fp32 absmax. Bare weight patterns are anchored so the twins never shadow the model converters for bf16 checkpoints, and layers without aux keys fall back to the model's own merge. 2. patch_bnb4bit_quantizer_weight_conversions registers deserializers for fused expert param names (gate_up_proj/down_proj), which the stock quantizer only registers for params literally named weight. 3. patch_bnb4bit_dequantize_plain_params dequantizes any Params4bit that landed on a module that is neither a bnb Linear4bit nor a v5 experts module (e.g. the router) back to a plain float Parameter after loading. Verified bitwise identical dequantized expert weights against a reference load and end-to-end 4-bit MoE training on Qwen3-30B-A3B.
3b8819b to
3d70f00
Compare
Summary
Older unsloth-bnb-4bit MoE checkpoints store one quantized tensor per expert per projection (
experts.<N>.gate_proj.weightplus absmax/quant_map/quant_state aux keys), and also quantize small non-Linear weights such as the MoE router. These checkpoints fail to load under transformers v5: the model's MergeModulelist converters match the per-expert keys first, byte-concatenate the packed uint8 as if it were bf16, and drop every aux key as unexpected, while the router lands as raw packed bytes on a plainnn.Parameterslot andF.linearfails with a reduction dim mismatch.What this does
All three patches are active only when a
Bnb4BitHfQuantizerispre_quantized:patch_bnb4bit_model_conversion_mappingprepends quantized twins of the model's per-expert merge converters. Each twin collects the aux keys, rebuilds oneQuantStateper expert, concatenates the packed bytes in target element order (NF4 packing is elementwise row-major, so this is exact), and reassembles a single stackedParams4bitwith a flat fp32 absmax. Bare weight patterns are anchored so the twins never shadow the model converters for bf16 checkpoints, and layers without aux keys (dynamic-quant skip layers) fall back to the model's own merge.patch_bnb4bit_quantizer_weight_conversionsregisters deserializers for fused expert param names (gate_up_proj/down_proj), which the stock quantizer only registers for params literally namedweight.patch_bnb4bit_dequantize_plain_paramsdequantizes anyParams4bitthat landed on a module that is neither a bnbLinear4bitnor a v5 experts module (e.g. the router) back to a plain float Parameter after loading.Testing