[xLSTM] Fix train mode crashes with return_last_states=False and use_cache #47046

Open
Heatdh wants to merge 1 commit into huggingface:main from Heatdh:fix-xlstm-return-last-states

@Heatdh Heatdh commented Jul 3, 2026

What does this PR do?

Fixes #47013

In train mode with return_last_states=False, xLSTMBackend.forward returned a bare tensor while its only caller, xLSTMLayer.forward, always unpacks h, state = .... The forward pass crashed for every batch size (batch size 2 mis-unpacked the hidden states along the batch dimension and tripped the shape assertion; every other batch size failed to unpack).
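The batch-size dependence described above can be reproduced without any model code. The following is a minimal sketch, assuming a nested list as a stand-in for a tensor (iterating a tensor, like iterating this list, yields slices along the first dimension). The function names `backend_forward_bare`, `backend_forward_tuple`, and `try_unpack` are hypothetical and are not taken from the xLSTM source.

```python
# Toy stand-in for a hidden-state tensor of shape (batch, seq_len):
# a list with one row per batch item. Assumption: unpacking it behaves
# like unpacking a tensor, i.e. it splits along the batch dimension.

def backend_forward_bare(batch_size, seq_len=4):
    """Hypothetical pre-fix behaviour: return only the hidden states."""
    return [[0.0] * seq_len for _ in range(batch_size)]


def backend_forward_tuple(batch_size, seq_len=4):
    """Hypothetical post-fix behaviour: always return (h, last_states)."""
    h = [[0.0] * seq_len for _ in range(batch_size)]
    return h, None


def try_unpack(result):
    """Mimic the caller's `h, state = ...` and report what happened."""
    try:
        h, state = result
    except ValueError as err:
        return f"ValueError: {err}"
    return f"ok: len(h)={len(h)}"


# Batch sizes 1 and 3 fail to unpack; batch size 2 "succeeds" but h is a
# single row (length seq_len), which is the silent mis-unpack case.
outcomes = {b: try_unpack(backend_forward_bare(b)) for b in (1, 2, 3)}
fixed = try_unpack(backend_forward_tuple(3))
```

With the tuple return, the unpack works for any batch size and `h` keeps its batch dimension.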

This is an alternative to #47026 (thanks @lcheng321 for the first patch!). While writing regression tests for the tuple normalization, two follow-on crashes surfaced on the same path, which this PR also fixes:

  1. train_with_padding + default use_cache=True: the padded kernels cannot produce last states, so there is nothing to fill the cache with — with the tuple normalization alone, xLSTMModel.forward fails at the cache copy with TypeError: 'NoneType' object is not subscriptable. Now use_cache falls back to False with a warning_once, and the cache-copy loop guards against rnn_state=None.
  2. train mode + default use_cache=True + backward: xLSTMModel.forward fed the cache's state tensors into the autograd graph as initial states and then mutated them in place via copy_, so loss.backward() failed with RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation. An all-zero initial cache state is equivalent to no state, so it is no longer passed into the layers, and states are detached before being stored in the cache.
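The first fix combines a fallback and a guard, which can be sketched in plain Python. This is an illustration under stated assumptions, not the code in the PR: `resolve_use_cache` and `update_cache` are hypothetical helpers, `warnings.warn` stands in for the `warning_once` call mentioned above, and copying a list stands in for detaching a state before storing it.

```python
import warnings


def resolve_use_cache(mode, use_cache):
    """Hypothetical helper: padded kernels produce no last states,
    so caching is switched off (with a warning) in that mode."""
    if mode == "train_with_padding" and use_cache:
        warnings.warn(
            "use_cache is not supported with train_with_padding; "
            "falling back to use_cache=False"
        )
        return False
    return use_cache


def update_cache(cache_states, new_states):
    """Store per-layer states, skipping layers that produced none."""
    for layer_idx, rnn_state in enumerate(new_states):
        if rnn_state is None:  # guard: nothing to copy for this layer
            continue
        # Toy stand-in for "detach, then store": keep an independent copy.
        cache_states[layer_idx] = list(rnn_state)
    return cache_states


# Usage: record the warning instead of printing it.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    padded = resolve_use_cache("train_with_padding", True)

unpadded = resolve_use_cache("train", True)
cache = update_cache([[0.0], [0.0]], [None, [1.0]])
```

Storing an independent copy rather than a reference is the toy analogue of detaching: later updates to the cache cannot touch values that the backward pass still needs.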

Changes

  • xLSTMBackend.forward (native path) always returns an (h, last_states) tuple, matching the inference branch. The chunkwise train kernels compute the last states as a byproduct either way, so they are always requested; train_with_padding returns (h, None) since padding pollutes the last state. Return type annotation and docstring updated accordingly.
  • xLSTMModel.forward disables use_cache under train_with_padding (with a warning) and skips the initial all-zero cache state so the in-place cache update cannot break the backward pass.
  • xLSTMCache.__init__/reset initialize rnn_state_initial = True — this attribute was previously assigned in xLSTMModel.forward but never initialized, so it did not exist on fresh caches.
  • Regression tests: train-mode forward + backward with return_last_states=False for batch sizes 2 and 3, and a train_with_padding forward with default settings.
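The `rnn_state_initial` change can be pictured with a toy cache. This is a sketch under assumptions: `ToyCache` and `initial_states_for` are hypothetical names, and using the flag to decide whether a state is passed into a layer is an illustrative reading of the description above, not a copy of the actual implementation.

```python
class ToyCache:
    """Hypothetical stand-in for a recurrent-state cache."""

    def __init__(self, num_layers):
        self.num_layers = num_layers
        self.reset()  # guarantees every attribute exists on a fresh cache

    def reset(self):
        self.rnn_state = [None] * self.num_layers
        # Initialized here, so readers never hit an AttributeError.
        self.rnn_state_initial = True


def initial_states_for(cache, layer_idx):
    """Return the state to feed a layer, or None while the cache is fresh.

    Assumption: a fresh (all-zero) state is equivalent to no state,
    so nothing is passed into the layer in that case.
    """
    if cache.rnn_state_initial:
        return None
    return cache.rnn_state[layer_idx]


cache = ToyCache(num_layers=2)
first = initial_states_for(cache, 0)   # fresh cache: no state passed

cache.rnn_state[0] = [1.0]             # a forward pass stored a state
cache.rnn_state_initial = False
second = initial_states_for(cache, 0)  # now the stored state is used

cache.reset()                          # reset restores the fresh flag
```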

Like #47026, this keeps the fix local to the xLSTM files; routing the last states through the output recorder (as suggested in #47013) can be a follow-up.

Testing

pytest tests/models/xlstm/test_modeling_xlstm.py -k "train_mode_without_return_last_states or train_with_padding_mode_forward"
# on main: 2 failed
# with this PR: 2 passed

pytest tests/models/xlstm/test_modeling_xlstm.py
# 90 passed, 177 skipped, 737 subtests passed
# (test_can_load_with_device_context_manager and test_can_load_with_global_device_set fail
#  identically on unpatched main in my environment — unrelated to this change)

Who can review?

@vasqu @ArthurZucker

…cache

- xLSTMBackend.forward (native path) now always returns an
  (h, last_states) tuple, matching the inference branch; the chunkwise
  train kernels compute the last states either way, so they are always
  requested. train_with_padding returns (h, None) since padding pollutes
  the last state.
- xLSTMModel.forward disables use_cache under train_with_padding (with a
  warning) and skips the initial all-zero cache state so the in-place
  cache update cannot break the backward pass; stored states are
  detached.
- xLSTMCache initializes rnn_state_initial (previously only ever
  assigned in xLSTMModel.forward, never initialized).
- Add regression tests.

Fixes huggingface#47013

github-actions Bot commented Jul 3, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: xlstm

Heatdh commented Jul 3, 2026

Thanks @lcheng321 for the initial effort. I reran my training and discovered some issues with the patch along the way. @vasqu I added 2 additional tests and manually inspected the regression results; they are stable. The manually written code was auto-refactored to match the implementation guidelines.

github-actions Bot commented Jul 3, 2026

CI recap

Dashboard: View test results in Grafana
Latest run: 28674037288
Result: success | Grafana metrics are not available yet.

Development

Successfully merging this pull request may close these issues.

[xLSTM] Crash in train mode with return_last_states=False: backend returns bare tensor but xLSTMLayer always unpacks (h, state)
