Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 44 additions & 10 deletions src/transformers/models/xlstm/modeling_xlstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, is_xlstm_available
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, is_xlstm_available, logging
from ...utils.generic import merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
from .configuration_xlstm import xLSTMConfig


logger = logging.get_logger(__name__)


if is_xlstm_available():
from xlstm.xlstm_large.model import RMSNorm as xLSTMRMSNorm
from xlstm.xlstm_large.model import mLSTMBlock, mLSTMStateType, soft_cap
Expand Down Expand Up @@ -786,7 +789,7 @@ def forward(
m_initial: torch.Tensor | None = None,
return_last_states: bool | None = None,
mode: Literal["train", "inference"] | None = None,
) -> torch.Tensor | tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None]:
"""Forward pass of the mLSTM backend.

Depending on the configured mode, this method will call the appropriate kernel function.
Expand All @@ -805,9 +808,10 @@ def forward(
If None, the value from the config is used.

Returns:
hidden states of shape (batch_size, nh, sequence_length, dhhv)
hidden states and last states the last states are the cell state cstate (batch_size, nh, dhqk, dhhv),
the normalizer state nstate (batch_size, nh, dhqk), and the max state mstate (batch_size, nh, 1)
A tuple of the hidden states of shape (batch_size, nh, sequence_length, dhhv) and the last states
(None in train_with_padding mode). The last states are the cell state cstate
(batch_size, nh, dhqk, dhhv), the normalizer state nstate (batch_size, nh, dhqk), and the max state
mstate (batch_size, nh, 1).
"""
if mode is None:
mode = self.config.mode
Expand All @@ -819,8 +823,24 @@ def forward(
if self.config.mode == "train_with_padding":
if return_last_states:
raise ValueError("return_last_states=True is not supported with train_with_padding mode.")
# The padded chunkwise kernels cannot compute meaningful last states.
h = self._train_fn(
query=query,
key=key,
value=value,
igate=igate,
fgate=fgate,
c_initial=c_initial,
n_initial=n_initial,
m_initial=m_initial,
return_last_states=False,
)
return h, None

return self._train_fn(
# The last states are a cheap byproduct of the chunkwise recurrence. Always request them
# from the kernel so that this method returns an (h, last_states) tuple in every mode,
# matching what the callers unpack.
h, last_states = self._train_fn(
query=query,
key=key,
value=value,
Expand All @@ -829,8 +849,9 @@ def forward(
c_initial=c_initial,
n_initial=n_initial,
m_initial=m_initial,
return_last_states=return_last_states,
return_last_states=True,
)
return h, last_states

elif "inference" in mode:
# inference mode always returns the last states
Expand Down Expand Up @@ -1367,6 +1388,7 @@ def __init__(
)
for layer in range(config.num_hidden_layers)
}
self.rnn_state_initial = True

def reset(self):
self.rnn_state = {
Expand All @@ -1377,6 +1399,7 @@ def reset(self):
)
for layer in self.rnn_state
}
self.rnn_state_initial = True


@auto_docstring
Expand Down Expand Up @@ -1438,6 +1461,13 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids)

if use_cache and "with_padding" in self.config.mode:
logger.warning_once(
"`use_cache=True` is not supported with `mode='train_with_padding'` as no last states can be "
"computed on padded sequences. Setting `use_cache=False`."
)
use_cache = False

if use_cache and cache_params is None:
cache_params = xLSTMCache(
self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
Expand Down Expand Up @@ -1474,15 +1504,19 @@ def forward(
offset += self.config.max_inference_chunksize
hidden_states = final_state
else:
# An all-zero initial cache state is equivalent to no state. Skipping it keeps the cache
# tensors out of the autograd graph, where the in-place update below would otherwise
# break the backward pass in train mode.
cache_has_state = cache_params is not None and not cache_params.rnn_state_initial
for layer_idx, xlstm_block in enumerate(self.blocks):
hidden_states, rnn_state = xlstm_block(
hidden_states,
cache_params.rnn_state[layer_idx] if cache_params is not None else None,
cache_params.rnn_state[layer_idx] if cache_has_state else None,
)

if cache_params:
if cache_params and rnn_state is not None:
for state_idx in range(len(cache_params.rnn_state[layer_idx])):
local_rnn_state = rnn_state[state_idx]
local_rnn_state = rnn_state[state_idx].detach()
cache_params.rnn_state[layer_idx][state_idx].copy_(local_rnn_state)
cache_params.rnn_state_initial = False

Expand Down
36 changes: 36 additions & 0 deletions tests/models/xlstm/test_modeling_xlstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,42 @@ def test_chunkwise_shape_calculation(self):
expected_shape = (batch_size, seq_length, config.hidden_size)
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)

def test_train_mode_without_return_last_states(self):
    # Regression test for #47013. With `mode="train"` and `return_last_states=False`, the mLSTM
    # backend handed back a bare tensor, and `xLSTMLayer.forward` unpacked it along the batch
    # dimension instead of receiving an (h, last_states) pair.
    config = self.model_tester.get_config()
    config.mode = "train"
    config.return_last_states = False

    model = xLSTMModel(config).to(torch_device)
    model.train()

    # Two batch sizes are exercised on purpose: batch_size=2 used to trip the internal shape
    # check, while every other batch size failed at the unpacking step.
    num_tokens = config.chunk_size
    for num_sequences in (2, 3):
        token_ids = ids_tensor([num_sequences, num_tokens], config.vocab_size)
        result = model(token_ids)

        expected_shape = (num_sequences, num_tokens, config.hidden_size)
        self.assertEqual(result.last_hidden_state.shape, expected_shape)

        # The backward pass must run through as well, not only the forward pass.
        result.last_hidden_state.sum().backward()

def test_train_with_padding_mode_forward(self):
    # The padded chunkwise kernels cannot compute last states, so a cache would have nothing to
    # hold. The forward pass is expected to fall back to `use_cache=False` rather than crash.
    config = self.model_tester.get_config()
    config.mode = "train_with_padding"
    config.return_last_states = False

    model = xLSTMModel(config).to(torch_device)
    model.train()

    num_sequences, num_tokens = 2, config.chunk_size
    token_ids = ids_tensor([num_sequences, num_tokens], config.vocab_size)
    result = model(token_ids)

    expected_shape = (num_sequences, num_tokens, config.hidden_size)
    self.assertEqual(result.last_hidden_state.shape, expected_shape)
    # No cache object may be returned in this mode.
    self.assertIsNone(result.cache_params)

@unittest.skip("This model doesn't support beam search with cache, as the cache cannot be reordered")
def test_beam_search_generate(self):
# Intentionally empty: the test is skipped unconditionally for the reason given in the decorator.
pass
Expand Down