Route gemma-4 sliding-window layers through FlashAttention-2#860
Route gemma-4 sliding-window layers through FlashAttention-2#860danielhanchen wants to merge 1 commit into
Conversation
|
Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits. |
There was a problem hiding this comment.
Code Review
This pull request introduces a temporary patch to optimize Gemma-4's sliding-window layers by routing them through FlashAttention-2 when applicable, avoiding the O(S^2) SDPA fallback. The feedback highlights two critical issues: first, a memory address reuse bug in the caching mechanism of _mask_is_plain_band that relies on id(mask), which should be resolved by attaching the cached value directly to the tensor; second, a potential correctness issue where bidirectional attention could be incorrectly treated as causal when attention_mask is None and is_causal is False.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| def _mask_is_plain_band(mask, S, w): | ||
| """True if `mask` is exactly the causal + sliding band with no token padding.""" | ||
| if mask is None: | ||
| return True | ||
| if not torch.is_tensor(mask) or mask.dim() != 4: | ||
| return False | ||
| key = id(mask) | ||
| cached = _BAND_OK.get(key) | ||
| if cached is not None: | ||
| return cached | ||
| if len(_BAND_OK) > 64: | ||
| _BAND_OK.clear() | ||
| m = mask[0, 0] | ||
| if m.shape[-2] != S or m.shape[-1] != S: | ||
| _BAND_OK[key] = False | ||
| return False | ||
| allowed = m if m.dtype == torch.bool else (m > -1e4) | ||
| dev = m.device | ||
| cand = sorted({x for x in (0, 1, w - 1, w, w + 1, 2 * w, S // 2, S - 1, S - w, S - w - 1) if 0 <= x < S}) | ||
| probes = torch.tensor(cand, device=dev) | ||
| rows = allowed[probes] # (P, S) | ||
| idx = torch.arange(S, device=dev)[None, :] | ||
| cnt = rows.sum(-1) | ||
| minidx = torch.where(rows, idx, torch.full_like(idx, S)).amin(-1) | ||
| maxidx = torch.where(rows, idx, torch.full_like(idx, -1)).amax(-1) | ||
| lo = torch.clamp(probes - w + 1, min=0) | ||
| hi = probes | ||
| ok = bool(((cnt == (hi - lo + 1)) & (minidx == lo) & (maxidx == hi)).all().item()) | ||
| _BAND_OK[key] = ok | ||
| return ok |
There was a problem hiding this comment.
Memory Address Reuse Bug with id(mask) Caching
Using id(mask) as a cache key in _BAND_OK is prone to correctness bugs due to Python's memory reuse. When a mask tensor is deallocated, its memory address (id) can be reassigned to a newly created mask tensor. If the new mask has different values (e.g., due to different padding or sequence length) but gets the same id, the cache will return a stale/incorrect result.
Instead of a global _BAND_OK dictionary, we can safely stash the result as a custom attribute directly on the mask tensor (e.g., mask._unsloth_plain_band = ok). This binds the cached value to the lifetime of that specific tensor object, completely avoiding id collisions and memory leaks. You can also safely delete the _BAND_OK dictionary on line 45.
We use a truthy getattr check (getattr(mask, "_unsloth_plain_band", None)) to preserve the falsy False value, avoiding hasattr which can lead to issues.
def _mask_is_plain_band(mask, S, w):
"""True if mask is exactly the causal + sliding band with no token padding."""
if mask is None:
return True
if not torch.is_tensor(mask) or mask.dim() != 4:
return False
cached = getattr(mask, "_unsloth_plain_band", None)
if cached is not None:
return cached
m = mask[0, 0]
if m.shape[-2] != S or m.shape[-1] != S:
try:
mask._unsloth_plain_band = False
except AttributeError:
pass
return False
allowed = m if m.dtype == torch.bool else (m > -1e4)
dev = m.device
cand = sorted({x for x in (0, 1, w - 1, w, w + 1, 2 * w, S // 2, S - 1, S - w, S - w - 1) if 0 <= x < S})
probes = torch.tensor(cand, device=dev)
rows = allowed[probes] # (P, S)
idx = torch.arange(S, device=dev)[None, :]
cnt = rows.sum(-1)
minidx = torch.where(rows, idx, torch.full_like(idx, S)).amin(-1)
maxidx = torch.where(rows, idx, torch.full_like(idx, -1)).amax(-1)
lo = torch.clamp(probes - w + 1, min=0)
hi = probes
ok = bool(((cnt == (hi - lo + 1)) & (minidx == lo) & (maxidx == hi)).all().item())
try:
mask._unsloth_plain_band = ok
except AttributeError:
pass
return okReferences
- Use a truthy
getattrcheck (e.g.,getattr(cfg, "attr", None)) instead ofhasattrwhen you need to preserve falsy values likeNoneorFalse.
| w = kwargs.get("sliding_window", None) or getattr(module, "sliding_window", None) | ||
| Sq = query.shape[2] | ||
| Sk = key.shape[2] | ||
| if (w and Sq == Sk and Sq > w and _mask_is_plain_band(attention_mask, Sq, w)): |
There was a problem hiding this comment.
Causal Check when attention_mask is None
If attention_mask is None, _mask_is_plain_band returns True. However, if is_causal is False (meaning bidirectional attention), we would still route to FA2 with causal=True.
To prevent incorrect causal masking on bidirectional attention, we should explicitly check that is_causal is True when attention_mask is None.
w = kwargs.get("sliding_window", None) or getattr(module, "sliding_window", None)
Sq = query.shape[2]
Sk = key.shape[2]
if (w and Sq == Sk and Sq > w
and (attention_mask is not None or is_causal)
and _mask_is_plain_band(attention_mask, Sq, w)):Gemma-4 mixes 25 sliding-window layers (head_dim 256, window 1024) with 5 global layers (head_dim 512). FlashAttention-2 caps head_dim at 256, so the global layers disable it for the whole model and every layer falls back to SDPA, which spends O(seq^2) compute per sliding layer and dominates the step at long context. A causal sliding window is exactly FA2's window_size argument, so intercept the registered sdpa attention function and, for a gemma-4 sliding layer whose head_dim is 256, whose dtype is bf16/fp16, and whose mask is the plain causal band (no token padding), call FA2 with window_size=(w-1, 0). Global head_dim 512 layers and padded or non-band masks defer to the original SDPA unchanged. FA2 is an opaque op so the compiled graph is unaffected, and the result matches SDPA with the explicit band mask to bf16 rounding (forward rel 2.5e-3, backward 3.5e-3). Measured on a 26B-A4B LoRA SFT (4-bit, gradient checkpointing) at bf16: 1.66x at 4k, 2.21x at 8k, 2.65x at 16k, 3.11x at 32k, memory unchanged. The speedup grows with sequence length as attention becomes a larger share of the step. On by default; set UNSLOTH_GEMMA4_FLASH_SLIDING=0 to force SDPA.
8efc467 to
698c80b
Compare
| dropout=dropout, scaling=scaling, is_causal=is_causal, **kwargs) | ||
|
|
||
|
|
||
| def patch_gemma4_flash_sliding(): |
Summary
Gemma-4 mixes 25 sliding-window layers (head_dim 256, window 1024) with 5 global layers (head_dim 512). FlashAttention-2 caps head_dim at 256, so the global layers disable it for the whole model and every layer falls back to SDPA, which spends O(seq^2) compute per sliding layer and dominates the step at long context.
What this does
A causal sliding window is exactly FA2's
window_sizeargument, so intercept the registered sdpa attention function and, for a gemma-4 sliding layer whose head_dim is at most 256, whose dtype is bf16/fp16, and whose mask is the plain causal band (no token padding), call FA2 withwindow_size=(w-1, 0). Global head_dim 512 layers and padded or non-band masks defer to the original SDPA unchanged. FA2 is an opaque op so the compiled graph is unaffected.On by default when flash-attn is importable;
UNSLOTH_GEMMA4_FLASH_SLIDING=0reverts to SDPA.Testing