Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions tests/test_dora_merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""DoRA (use_dora=True) merge support in the safetensors merge path.

The dense merge must (1) no longer raise a key-mismatch on a DoRA adapter (the magnitude vector
is now captured), and (2) produce the same merged weight as PEFT's own DoRA merge. MoE-expert
DoRA is explicitly refused (fail loud) rather than silently dropping the magnitude.
"""
import copy

import pytest
import torch
import torch.nn as nn

from unsloth_zoo.saving_utils import create_lora_statistics, _merge_lora, LoraStats


class _Tiny(nn.Module):
def __init__(self, d_in=32, d_out=24):
super().__init__()
self.q_proj = nn.Linear(d_in, d_out, bias=False)

def forward(self, x):
return self.q_proj(x)


def _find_q_stats(lora_weights):
for v in lora_weights.values():
if v.lora_A is not None and v.lora_B is not None:
return v
return None


def test_dora_merge_matches_peft():
    """The dense merge of a DoRA adapter must reproduce PEFT's own ``merge_and_unload`` weight."""
    from peft import LoraConfig, get_peft_model

    torch.manual_seed(0)
    base = _Tiny().to(torch.float32)
    W0 = base.q_proj.weight.detach().clone()

    dora_cfg = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj"], use_dora=True)
    peft_model = get_peft_model(copy.deepcopy(base), dora_cfg)

    # Perturb lora_B and the magnitude vector so the adapter carries a real delta and the DoRA
    # rescale is not a no-op.
    with torch.no_grad():
        for param_name, param in peft_model.named_parameters():
            if param_name.endswith("lora_B.default.weight"):
                param.copy_(torch.randn_like(param) * 0.1)
            if param_name.endswith("lora_magnitude_vector.default.weight"):
                param.add_(torch.randn_like(param) * 0.1)

    # Reference result: merge a copy with PEFT so `peft_model` keeps its adapter intact.
    reference_model = copy.deepcopy(peft_model).merge_and_unload()
    q_proj_weights = [
        param.detach().float().clone()
        for param_name, param in reference_model.named_parameters()
        if param_name.endswith("q_proj.weight")
    ]
    W_peft = q_proj_weights[-1] if q_proj_weights else None
    assert W_peft is not None

    # Path under test: collect the adapter stats (this must not raise for DoRA), then fold them
    # into the base weight with _merge_lora.
    result = create_lora_statistics(peft_model, merge_into_original=True, return_state_dict=True)
    if isinstance(result, tuple):
        lora_weights = result[0]
    else:
        lora_weights = result
    stats = _find_q_stats(lora_weights)
    assert stats is not None
    assert stats.magnitude is not None, "DoRA magnitude was not captured"

    W_uns = _merge_lora(W0.clone(), stats, "q_proj").cpu().float()

    max_abs = (W_uns - W_peft).abs().max().item()
    assert torch.allclose(W_uns, W_peft, atol=1e-4, rtol=1e-4), f"max abs diff {max_abs}"

    # Guard against a vacuous pass: the merged weight has to differ from the untouched base.
    drift_from_base = (W_uns - W0.float()).abs().max().item()
    assert drift_from_base > 1e-3


def test_plain_lora_unaffected():
    """Without DoRA the captured magnitude is None and the merge equals W0 + alpha * (B @ A)."""
    from peft import LoraConfig, get_peft_model

    torch.manual_seed(1)
    base = _Tiny().to(torch.float32)
    W0 = base.q_proj.weight.detach().clone()

    lora_cfg = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj"], use_dora=False)
    peft_model = get_peft_model(copy.deepcopy(base), lora_cfg)

    # Overwrite lora_B with random values so the low-rank delta is not trivially zero.
    with torch.no_grad():
        for param_name, param in peft_model.named_parameters():
            if param_name.endswith("lora_B.default.weight"):
                param.copy_(torch.randn_like(param) * 0.1)

    result = create_lora_statistics(peft_model, merge_into_original=True, return_state_dict=True)
    if isinstance(result, tuple):
        lora_weights = result[0]
    else:
        lora_weights = result
    stats = _find_q_stats(lora_weights)
    assert stats is not None and stats.magnitude is None

    W_uns = _merge_lora(W0.clone(), stats, "q_proj").cpu().float()

    # Closed-form plain-LoRA merge computed directly from the captured stats.
    low_rank_delta = stats.lora_B.float() @ stats.lora_A.float()
    expected = W0.float() + stats.alpha * low_rank_delta
    assert torch.allclose(W_uns, expected, atol=1e-5)


def test_dora_on_moe_expert_is_refused():
    """The fused gate/up MoE merge must raise for stats that carry a DoRA magnitude."""
    from unsloth_zoo.saving_utils import _merge_moe_fused_gate_up_expert

    num_experts, lora_rank, hidden, inter = 4, 4, 8, 6
    fused_out = 2 * inter  # gate and up rows stacked along the output dimension
    stacked_rank = num_experts * lora_rank

    gate_up_W = torch.randn(num_experts, fused_out, hidden)
    lora_A = torch.randn(stacked_rank, hidden)
    lora_B = torch.randn(fused_out, stacked_rank)
    dora_stats = LoraStats(None, lora_A, lora_B, 1.0, magnitude=torch.randn(fused_out))

    with pytest.raises(RuntimeError, match="DoRA"):
        _merge_moe_fused_gate_up_expert(gate_up_W, dora_stats, torch.float32)
41 changes: 40 additions & 1 deletion unsloth_zoo/saving_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,19 @@ def _merge_lora(W, lora_stats, name, use_dequant_base = False):
W = W_new.addmm_(lora_B, lora_A, alpha=lora_stats.alpha)
else:
W = W.addmm_(lora_B, lora_A, alpha=lora_stats.alpha)
# DoRA: rescale the merged direction to the learned magnitude. With delta = alpha*(B@A),
# PEFT's DoRA merge is (m / ||W0 + delta||_row) * (W0 + delta), one L2 norm per output row
# over the input dim. W already holds W0 + delta here, so fold m onto it.
magnitude = getattr(lora_stats, "magnitude", None)
if magnitude is not None:
magnitude = magnitude.to(device, dtype = torch.float32, non_blocking = True).reshape(-1)
if magnitude.shape[0] != W.shape[0]:
raise ValueError(
f"Unsloth: DoRA magnitude for `{name}` has {magnitude.shape[0]} entries but the "
f"merged weight has {W.shape[0]} output rows."
)
weight_norm = torch.linalg.norm(W, dim = 1).clamp_min(1e-9)
W = (magnitude / weight_norm).unsqueeze(1) * W
if not torch.isfinite(torch.amax(W)).item():
raise ValueError('Unsloth: Merge failed as there are infinite elements in ' + name)
return W
Expand Down Expand Up @@ -332,6 +345,7 @@ class LoraStats:
lora_A : torch.Tensor
lora_B : torch.Tensor
alpha : float
magnitude : object = None # DoRA lora_magnitude_vector weight (None for plain LoRA)
pass


Expand All @@ -344,13 +358,15 @@ def assert_same_keys(model, new_state_dict):

def _should_ignore(key: str) -> bool:
# Ignore helper wrappers and raw LoRA adapter tensors; the merged
# state_dict intentionally omits lora_A / lora_B weights.
# state_dict intentionally omits lora_A / lora_B / DoRA magnitude weights
# (the magnitude is folded into the merged weight in _merge_lora).
return (
"modules_to_save" in key
or "original_module" in key
or ".lora_A" in key
or ".lora_B" in key
or ".lora_embedding" in key
or ".lora_magnitude_vector" in key
)

def _normalize(key: str) -> str:
Expand Down Expand Up @@ -440,6 +456,12 @@ def create_lora_statistics(model, merge_into_original = False, return_state_dict
lora_B_count += 1
expand_module_keys(name, module, remove_keys)

elif name.endswith(".lora_magnitude_vector.default"):
# DoRA magnitude vector m; folded onto the merged weight in _merge_lora. Register its
# key so the key-consistency check does not flag it (the merged model omits it).
lora_weights[name[:-len(".lora_magnitude_vector.default")]].magnitude = module.weight
expand_module_keys(name, module, remove_keys)

elif isinstance(module, Linear_LoRA_Layers):
lora_weights[name].alpha = _get_lora_scaling(module)
scaling_count += 1
Expand Down Expand Up @@ -1179,10 +1201,24 @@ def _detect_moe_lora_layout(lora_A, lora_B, num_experts, out_dim, in_dim, lora_m
return "unknown", r


def _refuse_dora_on_moe(lora_stats):
"""DoRA on MoE experts is not yet supported: the expert merge helpers fold only the LoRA
delta, not the DoRA magnitude (the dense path handles it in _merge_lora). Fail loud rather
than emit a checkpoint with the magnitude silently dropped."""
if getattr(lora_stats, "magnitude", None) is not None:
raise RuntimeError(
"Unsloth: DoRA (use_dora=True) merging is not yet supported for MoE expert layers. "
"Fine-tune only the non-expert (attention/MLP) layers with DoRA, or open an issue at "
"https://github.com/unslothai/unsloth/issues."
)
pass


def _merge_moe_gate_or_up_expert(W, lora_stats, expert_idx, num_experts, output_dtype, *, role):
"""Per-expert merge for gate_proj/up_proj (role='gate' -> first I, 'up' -> last I)."""
if lora_stats is None or lora_stats.lora_A is None or lora_stats.lora_B is None:
return W
_refuse_dora_on_moe(lora_stats)
_MOE_MERGE_STATE["attempted"] += 1
try:
num_experts = _resolve_num_experts_from_lora_stats(lora_stats, num_experts)
Expand Down Expand Up @@ -1266,6 +1302,7 @@ def _merge_moe_up_expert(up_W, lora_stats, expert_idx, num_experts, output_dtype
def _merge_moe_down_proj_expert(down_W, lora_stats, expert_idx, num_experts, output_dtype):
if lora_stats is None or lora_stats.lora_A is None or lora_stats.lora_B is None:
return down_W
_refuse_dora_on_moe(lora_stats)
_MOE_MERGE_STATE["attempted"] += 1
try:
num_experts = _resolve_num_experts_from_lora_stats(lora_stats, num_experts)
Expand Down Expand Up @@ -1650,6 +1687,7 @@ def _merge_moe_fused_gate_up_expert(gate_up_W, lora_stats, output_dtype, is_tran
- Standard (Gemma4): (E, 2*I, H) with lora_A (E*R, H), lora_B (2*I, E*R)
is_transposed: if provided, overrides dimension-based heuristic (needed when dims are equal).
"""
_refuse_dora_on_moe(lora_stats)
_MOE_MERGE_STATE["attempted"] += 1
try:
if lora_stats.lora_A is None or lora_stats.lora_B is None:
Expand Down Expand Up @@ -1733,6 +1771,7 @@ def _merge_moe_fused_down_proj_expert(down_W, lora_stats, output_dtype, is_trans
- Standard (Gemma4): (E, H, I) with lora_A (E*R, H), lora_B (I, E*R)
is_transposed: if provided, overrides dimension-based heuristic (needed when H==I).
"""
_refuse_dora_on_moe(lora_stats)
_MOE_MERGE_STATE["attempted"] += 1
try:
if lora_stats.lora_A is None or lora_stats.lora_B is None:
Expand Down
Loading