diff --git a/src/transformers/models/xlstm/modeling_xlstm.py b/src/transformers/models/xlstm/modeling_xlstm.py index 3a674cf51b6d..35232d6bbfc7 100644 --- a/src/transformers/models/xlstm/modeling_xlstm.py +++ b/src/transformers/models/xlstm/modeling_xlstm.py @@ -816,11 +816,25 @@ def forward( if return_last_states is None: return_last_states = self.config.return_last_states + # padding pollutes the last state, so it can never be returned if self.config.mode == "train_with_padding": if return_last_states: raise ValueError("return_last_states=True is not supported with train_with_padding mode.") + h = self._train_fn( + query=query, + key=key, + value=value, + igate=igate, + fgate=fgate, + c_initial=c_initial, + n_initial=n_initial, + m_initial=m_initial, + return_last_states=False, + ) + return h, None - return self._train_fn( + # last states are a free byproduct of the chunkwise recurrence, so always return them + h, last_states = self._train_fn( query=query, key=key, value=value, @@ -829,8 +843,9 @@ def forward( c_initial=c_initial, n_initial=n_initial, m_initial=m_initial, - return_last_states=return_last_states, + return_last_states=True, ) + return h, last_states elif "inference" in mode: # inference mode always returns the last states