[xLSTM] Fix train mode crashes with return_last_states=False and use_cache #47046
Open
Heatdh wants to merge 1 commit into
…cache
- xLSTMBackend.forward (native path) now always returns an (h, last_states) tuple, matching the inference branch; the chunkwise train kernels compute the last states either way, so they are always requested. train_with_padding returns (h, None) since padding pollutes the last state.
- xLSTMModel.forward disables use_cache under train_with_padding (with a warning) and skips the initial all-zero cache state so the in-place cache update cannot break the backward pass; stored states are detached.
- xLSTMCache initializes rnn_state_initial (previously only ever assigned in xLSTMModel.forward, never initialized).
- Add regression tests.
Fixes huggingface#47013
Contributor
[For maintainers] Suggested jobs to run (before merge)
run-slow: xlstm
Author
Thanks @lcheng321 for the initial effort. I reran my training and discovered some issues with the patch along the way. @vasqu, I added 2 additional tests and manually inspected the regression results; they are stable. The hand-written code was auto-refactored to match the implementation guidelines.
What does this PR do?
Fixes #47013
In train mode with `return_last_states=False`, `xLSTMBackend.forward` returned a bare tensor while its only caller, `xLSTMLayer.forward`, always unpacks `h, state = ...`. The forward pass crashed for every batch size (batch size 2 mis-unpacked the hidden states along the batch dimension and tripped the shape assertion; every other batch size failed to unpack).

This is an alternative to #47026 (thanks @lcheng321 for the first patch!). While writing regression tests for the tuple normalization, two follow-on crashes surfaced on the same path, which this PR also fixes:
- `train_with_padding` + default `use_cache=True`: the padded kernels cannot produce last states, so there is nothing to fill the cache with. With the tuple normalization alone, `xLSTMModel.forward` fails at the cache copy with `TypeError: 'NoneType' object is not subscriptable`. Now `use_cache` falls back to `False` with a `warning_once`, and the cache-copy loop guards against `rnn_state=None`.
- `use_cache=True` + backward: `xLSTMModel.forward` fed the cache's state tensors into the autograd graph as initial states and then mutated them in place via `copy_`, so `loss.backward()` failed with `RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation`. An all-zero initial cache state is equivalent to no state, so it is no longer passed into the layers, and states are detached before being stored in the cache.

Changes
- `xLSTMBackend.forward` (native path) always returns an `(h, last_states)` tuple, matching the inference branch. The chunkwise train kernels compute the last states as a byproduct either way, so they are always requested; `train_with_padding` returns `(h, None)` since padding pollutes the last state. Return type annotation and docstring updated accordingly.
- `xLSTMModel.forward` disables `use_cache` under `train_with_padding` (with a warning) and skips the initial all-zero cache state so the in-place cache update cannot break the backward pass.
- `xLSTMCache.__init__`/`reset` initialize `rnn_state_initial = True`. This attribute was previously assigned in `xLSTMModel.forward` but never initialized, so it did not exist on fresh caches.
- Regression tests for `return_last_states=False` with batch sizes 2 and 3, and for a `train_with_padding` forward with default settings.

Like #47026, this keeps the fix local to the xLSTM files; routing the last states through the output recorder (as suggested in #47013) can be a follow-up.
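
For reviewers who want the failure mode in isolation, here is a minimal pure-Python sketch of the unpacking problem and of the normalized return contract. It is not the code in this PR: nested lists stand in for tensors (iterating over them yields one row per batch element, as iterating over a tensor's first dimension would), and `backend_forward_old`, `backend_forward_new`, `unpack_outcome` and `store_state` are hypothetical names introduced only for illustration.

```python
from typing import Dict, List, Optional, Tuple

Hidden = List[List[float]]          # stand-in for a (batch, seq) tensor
State = Optional[Dict[str, float]]  # stand-in for the last recurrent states


def backend_forward_old(batch_size: int) -> Hidden:
    """Hypothetical pre-fix behaviour: a bare hidden-state 'tensor'."""
    return [[0.0, 0.0] for _ in range(batch_size)]


def backend_forward_new(batch_size: int, train_with_padding: bool = False) -> Tuple[Hidden, State]:
    """Hypothetical post-fix behaviour: always an (h, last_states) tuple."""
    h = [[0.0, 0.0] for _ in range(batch_size)]
    # Padding pollutes the last state, so the padded path reports None.
    last_states = None if train_with_padding else {"c": 0.0}
    return h, last_states


def unpack_outcome(result) -> str:
    """Mimic the caller's `h, state = ...` and report what happens."""
    try:
        h, state = result
    except ValueError:
        return "unpack error"  # batch size != 2 on a bare result
    # A bare batch-size-2 result unpacks "successfully" into two batch rows.
    return "ok" if state is None or isinstance(state, dict) else "mis-unpacked"


def store_state(cache: Dict[int, Dict[str, float]], layer: int, state: State) -> None:
    """Cache copy with a None guard, as the padded path has nothing to store."""
    if state is None:
        return
    cache[layer] = dict(state)
```

With a bare return value, batch size 2 unpacks silently into two batch rows and every other batch size raises; with the tuple contract, both the regular and the padded path unpack cleanly, and the `None` guard keeps the cache copy from subscripting a missing state.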
Testing
Before submitting
Who can review?
@vasqu @ArthurZucker