Route gemma-4 sliding-window layers through FlashAttention-2 #860

Open
danielhanchen wants to merge 1 commit into main from gemma4-flash-sliding

@danielhanchen

Member

Summary

Gemma-4 mixes 25 sliding-window layers (head_dim 256, window 1024) with 5 global layers (head_dim 512). FlashAttention-2 caps head_dim at 256, so the global layers disable it for the whole model and every layer falls back to SDPA, which spends O(seq^2) compute per sliding layer and dominates the step at long context.

What this does

A causal sliding window maps directly onto FA2's window_size argument. The patch intercepts the registered sdpa attention function and calls FA2 with window_size=(w-1, 0) when all of the following hold: the layer is a gemma-4 sliding layer, its head_dim is at most 256, its dtype is bf16 or fp16, and its mask is the plain causal band (no token padding). Global head_dim 512 layers and padded or non-band masks defer to the original SDPA unchanged. FA2 is an opaque op, so the compiled graph is unaffected.
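
To make the window_size=(w-1, 0) mapping concrete, here is a minimal pure-Python sketch that compares the two definitions position by position. It assumes FA2's documented local-window rule for equal query and key lengths (query i attends to keys in [i - left, i + right]); the function names are illustrative and are not part of this PR.

```python
def sliding_causal_allowed(i: int, j: int, w: int) -> bool:
    """Causal sliding window of width w: query i sees keys i-w+1 .. i."""
    return 0 <= i - j < w


def fa2_window_allowed(i: int, j: int, left: int, right: int) -> bool:
    """Local window for equal query/key lengths: keys in [i-left, i+right]."""
    return i - left <= j <= i + right


def masks_agree(seq_len: int, w: int) -> bool:
    """True if both rules allow exactly the same (query, key) pairs."""
    return all(
        sliding_causal_allowed(i, j, w) == fa2_window_allowed(i, j, w - 1, 0)
        for i in range(seq_len)
        for j in range(seq_len)
    )


print(masks_agree(16, 4))
```

The left extent is w-1 rather than w because the window of w keys includes the query's own position.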

On by default when flash-attn is importable; UNSLOTH_GEMMA4_FLASH_SLIDING=0 reverts to SDPA.
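
The gating described above could look roughly like the sketch below. The helper name and its parameters are hypothetical (they exist only to make the example testable); only the UNSLOTH_GEMMA4_FLASH_SLIDING variable and the "flash-attn importable" condition come from this PR.

```python
import importlib.util
import os


def flash_sliding_enabled(environ=None, module_name="flash_attn"):
    """Hypothetical helper: on by default, off when the env var is "0"
    or when the flash-attn package cannot be located."""
    environ = os.environ if environ is None else environ
    if environ.get("UNSLOTH_GEMMA4_FLASH_SLIDING", "1") == "0":
        return False
    # find_spec only checks that the package can be located; an actual
    # import can still fail, so real code may prefer a guarded import.
    return importlib.util.find_spec(module_name) is not None


print(flash_sliding_enabled())
```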

Testing

  • Result matches SDPA with the explicit band mask to bf16 rounding (forward rel 2.5e-3, backward 3.5e-3); unit test included.
  • Measured on a 26B-A4B LoRA SFT (4-bit, gradient checkpointing): 1.66x at 4k, 2.21x at 8k, 2.65x at 16k, 3.11x at 32k, memory unchanged. The speedup grows with sequence length as attention becomes a larger share of the step.
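
As a rough illustration of why the gap widens with sequence length, the sketch below counts (query, key) score pairs for a dense S x S computation versus the causal band with window 1024. It assumes the SDPA fallback evaluates all S^2 pairs, in line with the O(seq^2) statement in the summary. The ratios cover attention scores only and are not a prediction of end-to-end speedup, since the rest of the step is unchanged.

```python
def band_pairs(seq_len: int, w: int) -> int:
    """Pairs inside the causal band: row i sees min(i + 1, w) keys."""
    ramp_rows = min(seq_len, w)
    ramp = ramp_rows * (ramp_rows + 1) // 2   # rows 0 .. w-1
    steady = max(seq_len - w, 0) * w          # rows w .. seq_len-1
    return ramp + steady


def dense_pairs(seq_len: int) -> int:
    """Pairs touched when every query scores every key."""
    return seq_len * seq_len


for s in (4096, 8192, 16384, 32768):
    print(s, round(dense_pairs(s) / band_pairs(s, 1024), 2))
```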

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces a temporary patch to optimize Gemma-4's sliding-window layers by routing them through FlashAttention-2 when applicable, avoiding the O(S^2) SDPA fallback. The feedback highlights two critical issues: first, a memory address reuse bug in the caching mechanism of _mask_is_plain_band that relies on id(mask), which should be resolved by attaching the cached value directly to the tensor; second, a potential correctness issue where bidirectional attention could be incorrectly treated as causal when attention_mask is None and is_causal is False.

Comment on lines +54 to +83
def _mask_is_plain_band(mask, S, w):
    """True if `mask` is exactly the causal + sliding band with no token padding."""
    if mask is None:
        return True
    if not torch.is_tensor(mask) or mask.dim() != 4:
        return False
    key = id(mask)
    cached = _BAND_OK.get(key)
    if cached is not None:
        return cached
    if len(_BAND_OK) > 64:
        _BAND_OK.clear()
    m = mask[0, 0]
    if m.shape[-2] != S or m.shape[-1] != S:
        _BAND_OK[key] = False
        return False
    allowed = m if m.dtype == torch.bool else (m > -1e4)
    dev = m.device
    cand = sorted({x for x in (0, 1, w - 1, w, w + 1, 2 * w, S // 2, S - 1, S - w, S - w - 1) if 0 <= x < S})
    probes = torch.tensor(cand, device=dev)
    rows = allowed[probes]  # (P, S)
    idx = torch.arange(S, device=dev)[None, :]
    cnt = rows.sum(-1)
    minidx = torch.where(rows, idx, torch.full_like(idx, S)).amin(-1)
    maxidx = torch.where(rows, idx, torch.full_like(idx, -1)).amax(-1)
    lo = torch.clamp(probes - w + 1, min=0)
    hi = probes
    ok = bool(((cnt == (hi - lo + 1)) & (minidx == lo) & (maxidx == hi)).all().item())
    _BAND_OK[key] = ok
    return ok
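
The per-row test in the quoted function (count, smallest index, and largest index all matching the expected band) can be illustrated without tensors. The sketch below is a pure-Python analogue with hypothetical helper names, not the PR's code: if a row has exactly hi - lo + 1 allowed keys, the smallest is lo, and the largest is hi, then the allowed keys must be exactly lo..hi.

```python
def make_band(seq_len: int, w: int):
    """Boolean causal sliding-window mask as nested lists (illustrative)."""
    return [[0 <= i - j < w for j in range(seq_len)] for i in range(seq_len)]


def row_is_band(row, i: int, w: int) -> bool:
    """True iff the allowed keys of query row i are exactly i-w+1 .. i."""
    lo, hi = max(i - w + 1, 0), i
    allowed = [j for j, ok in enumerate(row) if ok]
    # The length test runs first, so an empty row never reaches allowed[0].
    return len(allowed) == hi - lo + 1 and allowed[0] == lo and allowed[-1] == hi


band = make_band(12, 4)
print(all(row_is_band(band[i], i, 4) for i in range(12)))

# Simulate one padded key: column 0 is masked out for every query.
padded = [[ok and j != 0 for j, ok in enumerate(row)] for row in band]
print(row_is_band(padded[2], 2, 4))
```

Because the quoted function applies this test only to a handful of probe rows, it acts as a spot check: a deviation confined to rows that are never probed would not be detected by this test alone.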

high

Memory Address Reuse Bug with id(mask) Caching

Using id(mask) as a cache key in _BAND_OK is prone to correctness bugs due to Python's memory reuse. When a mask tensor is deallocated, its memory address (id) can be reassigned to a newly created mask tensor. If the new mask has different values (e.g., due to different padding or sequence length) but gets the same id, the cache will return a stale/incorrect result.
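
A small sketch of the underlying hazard, using a hypothetical stand-in class rather than a real tensor: an entry keyed by id() stays in the dictionary after its object is gone. Whether a later object actually receives the same id depends on the allocator, so the sketch only shows the part that is deterministic.

```python
class Mask:
    """Hypothetical stand-in for a mask tensor."""

    def __init__(self, padded: bool):
        self.padded = padded


cache = {}


def remember(mask) -> int:
    """Store a verdict under id(mask) and return the key that was used."""
    cache[id(mask)] = not mask.padded
    return id(mask)


first = Mask(padded=False)
key = remember(first)
del first  # the object is released, but the cache entry is not

stale_entry_survives = key in cache
print(stale_entry_survives)
```

Any new mask that happens to be assigned the same id would read this leftover verdict, regardless of its own contents.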

Instead of a global _BAND_OK dictionary, we can safely stash the result as a custom attribute directly on the mask tensor (e.g., mask._unsloth_plain_band = ok). This binds the cached value to the lifetime of that specific tensor object, completely avoiding id collisions and memory leaks. You can also safely delete the _BAND_OK dictionary on line 45.

We read the cached value with getattr(mask, "_unsloth_plain_band", None) and compare it against None rather than testing truthiness, so a cached False is preserved; this also avoids hasattr, which can lead to issues.

def _mask_is_plain_band(mask, S, w):
    """True if mask is exactly the causal + sliding band with no token padding."""
    if mask is None:
        return True
    if not torch.is_tensor(mask) or mask.dim() != 4:
        return False
    cached = getattr(mask, "_unsloth_plain_band", None)
    if cached is not None:
        return cached
    m = mask[0, 0]
    if m.shape[-2] != S or m.shape[-1] != S:
        try:
            mask._unsloth_plain_band = False
        except AttributeError:
            pass
        return False
    allowed = m if m.dtype == torch.bool else (m > -1e4)
    dev = m.device
    cand = sorted({x for x in (0, 1, w - 1, w, w + 1, 2 * w, S // 2, S - 1, S - w, S - w - 1) if 0 <= x < S})
    probes = torch.tensor(cand, device=dev)
    rows = allowed[probes]                                # (P, S)
    idx = torch.arange(S, device=dev)[None, :]
    cnt = rows.sum(-1)
    minidx = torch.where(rows, idx, torch.full_like(idx, S)).amin(-1)
    maxidx = torch.where(rows, idx, torch.full_like(idx, -1)).amax(-1)
    lo = torch.clamp(probes - w + 1, min=0)
    hi = probes
    ok = bool(((cnt == (hi - lo + 1)) & (minidx == lo) & (maxidx == hi)).all().item())
    try:
        mask._unsloth_plain_band = ok
    except AttributeError:
        pass
    return ok
References
  1. Use a truthy getattr check (e.g., getattr(cfg, "attr", None)) instead of hasattr when you need to preserve falsy values like None or False.
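
The difference between an "is not None" test and a truthiness test matters precisely when the cached verdict is False. The sketch below uses a hypothetical FakeMask class and attribute names (not the PR's) to count how often the expensive check runs under each style.

```python
class FakeMask:
    """Hypothetical stand-in for a tensor that accepts custom attributes."""

    def __init__(self, is_band: bool):
        self.is_band = is_band
        self.checks = 0  # how many times the expensive check ran


def expensive_check(mask) -> bool:
    mask.checks += 1
    return mask.is_band


def cached_with_none_test(mask) -> bool:
    cached = getattr(mask, "_plain_band", None)
    if cached is not None:  # a cached False counts as a hit
        return cached
    mask._plain_band = expensive_check(mask)
    return mask._plain_band


def cached_with_truthiness(mask) -> bool:
    cached = getattr(mask, "_plain_band_truthy", None)
    if cached:  # a cached False is treated as a miss
        return cached
    mask._plain_band_truthy = expensive_check(mask)
    return mask._plain_band_truthy


a, b = FakeMask(False), FakeMask(False)
for _ in range(3):
    cached_with_none_test(a)
    cached_with_truthiness(b)
print(a.checks, b.checks)
```

Both variants return the correct answer; the truthiness variant simply recomputes it on every call for masks that fail the check.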

Comment on lines +93 to +96
w = kwargs.get("sliding_window", None) or getattr(module, "sliding_window", None)
Sq = query.shape[2]
Sk = key.shape[2]
if (w and Sq == Sk and Sq > w and _mask_is_plain_band(attention_mask, Sq, w)):

high

Causal Check when attention_mask is None

If attention_mask is None, _mask_is_plain_band returns True. However, if is_causal is False (meaning bidirectional attention), we would still route to FA2 with causal=True.

To prevent incorrect causal masking on bidirectional attention, we should explicitly check that is_causal is True when attention_mask is None.

        w = kwargs.get("sliding_window", None) or getattr(module, "sliding_window", None)
        Sq = query.shape[2]
        Sk = key.shape[2]
        if (w and Sq == Sk and Sq > w 
                and (attention_mask is not None or is_causal)
                and _mask_is_plain_band(attention_mask, Sq, w)):
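
The suggested condition can be read as a small truth table. The sketch below restates it as a standalone predicate; the function and its boolean parameters are hypothetical simplifications (the mask is reduced to "present or not" plus the band verdict) meant only to show which cases the extra clause excludes.

```python
def should_route_to_fa2(w, sq, sk, has_mask, is_causal, mask_is_band) -> bool:
    """Hypothetical restatement of the suggested guard.

    w            -- sliding window size, or None/0 when the layer has none
    sq, sk       -- query and key sequence lengths
    has_mask     -- whether an attention_mask tensor was passed
    is_causal    -- the is_causal flag given to the attention function
    mask_is_band -- verdict of the plain-band check (True when mask is None)
    """
    return bool(
        w
        and sq == sk
        and sq > w
        and (has_mask or is_causal)
        and mask_is_band
    )


# No mask and is_causal=False means bidirectional attention: stay on SDPA.
print(should_route_to_fa2(1024, 4096, 4096, False, False, True))
# No mask but is_causal=True: the causal band applies, so FA2 is eligible.
print(should_route_to_fa2(1024, 4096, 4096, False, True, True))
```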

Gemma-4 mixes 25 sliding-window layers (head_dim 256, window 1024) with 5 global
layers (head_dim 512). FlashAttention-2 caps head_dim at 256, so the global
layers disable it for the whole model and every layer falls back to SDPA, which
spends O(seq^2) compute per sliding layer and dominates the step at long context.

A causal sliding window is exactly FA2's window_size argument, so intercept the
registered sdpa attention function and, for a gemma-4 sliding layer whose
head_dim is 256, whose dtype is bf16/fp16, and whose mask is the plain causal
band (no token padding), call FA2 with window_size=(w-1, 0). Global head_dim 512
layers and padded or non-band masks defer to the original SDPA unchanged. FA2 is
an opaque op so the compiled graph is unaffected, and the result matches SDPA
with the explicit band mask to bf16 rounding (forward rel 2.5e-3, backward
3.5e-3).

Measured on a 26B-A4B LoRA SFT (4-bit, gradient checkpointing) at bf16:
1.66x at 4k, 2.21x at 8k, 2.65x at 16k, 3.11x at 32k, memory unchanged. The
speedup grows with sequence length as attention becomes a larger share of the
step. On by default; set UNSLOTH_GEMMA4_FLASH_SLIDING=0 to force SDPA.

@danielhanchen force-pushed the gemma4-flash-sliding branch from 8efc467 to 698c80b on July 3, 2026 19:05

dropout=dropout, scaling=scaling, is_causal=is_causal, **kwargs)


def patch_gemma4_flash_sliding():