Skip to content

dataset_utils: add mask_thinking_tokens to mask the </think> token#833

Open
Sushankthatipally wants to merge 1 commit into
unslothai:mainfrom
Sushankthatipally:mask-thinking-tokens
Open

dataset_utils: add mask_thinking_tokens to mask the </think> token#833
Sushankthatipally wants to merge 1 commit into
unslothai:mainfrom
Sushankthatipally:mask-thinking-tokens

Conversation

@Sushankthatipally

Copy link
Copy Markdown

Adds mask_thinking_tokens, a convenience function in the same spirit as
train_on_responses_only, that masks the thinking closing token (</think>
by default) to -100 in the labels so it isn't trained on.

Requested in unslothai/unsloth#6695. Nemotron Ultra masks out </think>
during training so the model isn't trained on the reasoning terminator and
doesn't memorise a fixed reasoning length. This turns that into a one-liner,
composable after train_on_responses_only:

trainer = train_on_responses_only(trainer, instruction_part=..., response_part=...)
trainer = mask_thinking_tokens(trainer)

The function mirrors train_on_responses_only's interface (tokenizer,
return_function, num_proc, torch.Tensor/list labels, IterableDataset)
and handles </think> whether it tokenises to one id or several. It's purely
subtractive: only positions matching think_token are set to -100; existing
-100s and every other label are left untouched. Tokenised datasets that don't
yet carry labels fall back to deriving them from input_ids.

Added offline CPU-only unit tests in tests/test_mask_thinking_tokens.py
covering single/multi-token </think>, multiple occurrences, label
preservation, tensor and list inputs, and the no-labels fallback.

Happy to follow up with a companion re-export from unsloth.chat_templates
(like train_on_responses_only) once this lands, and to rename the function
if you'd prefer something else.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the mask_thinking_tokens utility function to mask reasoning terminator tokens (like </think>) to -100 in dataset labels, preventing models from memorizing fixed reasoning lengths. It also adds a comprehensive offline test suite to verify the masking logic. The review feedback highlights several robust improvements: adding a type check to prevent silent failures when a dataset is mistakenly passed instead of a trainer, using getattr instead of hasattr to correctly handle None values for the tokenizer, ensuring consistent tensor type preservation when only labels are provided as tensors, and adding corresponding unit tests for these edge cases.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +440 to +455
def mask_thinking_tokens(
trainer,
think_token = "</think>",
tokenizer = None, # Optional
return_function = False, # Useful for iterating over lists
num_proc = None,
):
"""Mask the thinking closing token (e.g. </think>) to -100 in the labels.

Inspired by Nemotron Ultra, which masks out </think> during training so the
model is not trained on the reasoning terminator and does not memorise a
fixed reasoning length. Use it like train_on_responses_only - apply it after
train_on_responses_only (or on any tokenized dataset whose labels the data
collator preserves). Only the label positions matching think_token are set
to -100; every other label is left untouched.
"""

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If a user mistakenly passes a Dataset or IterableDataset as the first argument instead of a Trainer, the function will silently do nothing and return the dataset unmodified (since datasets do not have train_dataset or eval_dataset attributes). Adding a check for the map attribute prevents this silent failure and guides the user to use return_function=True.

def mask_thinking_tokens(
    trainer,
    think_token     = "</think>",
    tokenizer       = None,  # Optional
    return_function = False, # Useful for iterating over lists
    num_proc        = None,
):
    """Mask the thinking closing token (e.g. </think>) to -100 in the labels.

    Inspired by Nemotron Ultra, which masks out </think> during training so the
    model is not trained on the reasoning terminator and does not memorise a
    fixed reasoning length. Use it like train_on_responses_only - apply it after
    train_on_responses_only (or on any tokenized dataset whose labels the data
    collator preserves). Only the label positions matching think_token are set
    to -100; every other label is left untouched.
    """
    if hasattr(trainer, "map"):
        raise TypeError(
            "Unsloth: mask_thinking_tokens expects a Trainer as the first argument. "
            "If you want to apply it directly to a dataset, use return_function=True."
        )

Comment thread unsloth_zoo/dataset_utils.py Outdated
Comment on lines +457 to +458
if tokenizer is None and trainer is not None:
tokenizer = trainer.processing_class if hasattr(trainer, "processing_class") else trainer.tokenizer

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using hasattr returns True even if processing_class is set to None, which would prevent falling back to trainer.tokenizer. Using a truthy getattr check instead ensures that if processing_class is None or missing, we correctly fall back to trainer.tokenizer. Additionally, if both trainer and tokenizer are None, raising a clear ValueError improves error clarity.

Suggested change
if tokenizer is None and trainer is not None:
tokenizer = trainer.processing_class if hasattr(trainer, "processing_class") else trainer.tokenizer
if tokenizer is None and trainer is not None:
tokenizer = getattr(trainer, "processing_class", None) or getattr(trainer, "tokenizer", None)
if tokenizer is None:
raise ValueError("Unsloth: A tokenizer must be provided or accessible from the trainer.")
References
  1. Use a truthy getattr check (e.g., getattr(cfg, "attr", None)) instead of hasattr when you need to preserve falsy values like None or False. hasattr returns True even if the attribute is set to None or False, which can lead to unintended overwriting of default behaviors.

Comment on lines +474 to +482
use_tensors = False
if type(input_ids_) is torch_Tensor:
use_tensors = True
input_ids_ = input_ids_.tolist()
if "labels" in examples:
labels_ = examples["labels"]
if type(labels_) is torch_Tensor:
labels_ = labels_.tolist()
assert(len(labels_) == len(input_ids_))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If input_ids is a list of lists but labels is a torch.Tensor, use_tensors will remain False, and the returned dictionary will contain labels as a list of lists instead of preserving its original tensor type. Setting use_tensors = True if either input_ids or labels is a tensor ensures consistent type preservation.

Suggested change
use_tensors = False
if type(input_ids_) is torch_Tensor:
use_tensors = True
input_ids_ = input_ids_.tolist()
if "labels" in examples:
labels_ = examples["labels"]
if type(labels_) is torch_Tensor:
labels_ = labels_.tolist()
assert(len(labels_) == len(input_ids_))
use_tensors = False
if type(input_ids_) is torch_Tensor:
use_tensors = True
input_ids_ = input_ids_.tolist()
if "labels" in examples:
labels_ = examples["labels"]
if type(labels_) is torch_Tensor:
use_tensors = True
labels_ = labels_.tolist()
assert(len(labels_) == len(input_ids_))

Comment on lines +107 to +110
def test_empty_think_token_raises():
tok = FakeTokenizer("</think>", [])
with pytest.raises(ValueError):
mask_thinking_tokens(trainer = None, think_token = "</think>", tokenizer = tok, return_function = True)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Add unit tests to verify that passing a dataset as the trainer raises a TypeError, and that calling the function without a tokenizer or trainer raises a ValueError.

Suggested change
def test_empty_think_token_raises():
tok = FakeTokenizer("</think>", [])
with pytest.raises(ValueError):
mask_thinking_tokens(trainer = None, think_token = "</think>", tokenizer = tok, return_function = True)
def test_empty_think_token_raises():
tok = FakeTokenizer("</think>", [])
with pytest.raises(ValueError):
mask_thinking_tokens(trainer = None, think_token = "</think>", tokenizer = tok, return_function = True)
def test_dataset_as_trainer_raises():
class FakeDataset:
def map(self): pass
with pytest.raises(TypeError):
mask_thinking_tokens(trainer = FakeDataset(), tokenizer = FakeTokenizer("</think>", [42]))
def test_no_tokenizer_or_trainer_raises():
with pytest.raises(ValueError):
mask_thinking_tokens(trainer = None, tokenizer = None)

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 82a41f564f

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

n = len(input_ids)
j = 0
while j <= n - len_think:
if input_ids[j] == first_think and input_ids[j : j + len_think] == think_ids:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Handle context-merged think-token encodings

When the closing marker is adjacent to whitespace in the serialized chat (for example </think>\n), byte-level/BPE tokenizers can fold that boundary into different token IDs than the bare think_token used here; this file already works around that class of issue for chat parts via _find_common_token_ids. Because the matcher only searches for the isolated think_ids, those common samples leave the terminator label unmasked, silently defeating the new helper for affected tokenizers/templates. Please also match contextual encodings of the marker or mask before tokenization.

Useful? React with 👍 / 👎.

pass

if hasattr(trainer, "train_dataset") and trainer.train_dataset is not None:
trainer.train_dataset = _apply(trainer.train_dataset)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Filter rows fully masked by think-token removal

When the assistant span contains no trainable token besides the thinking terminator (for example an empty response after </think> or a truncation that leaves only that marker), this mapping turns the whole row into -100 labels but then keeps it in train_dataset. train_on_responses_only explicitly filters such rows because an all-ignored batch can produce NaN loss, so this helper needs the same post-mask filtering or zero-loss guard after applying _apply.

Useful? React with 👍 / 👎.

trainer.eval_dataset = _apply(trainer.eval_dataset)
pass

return trainer

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve masked labels with the trainer collator

When this helper is used without first calling train_on_responses_only (especially the no-label fallback above), common SFTTrainer/DataCollatorForLanguageModeling setups recreate labels from input_ids during collation, so the mapped -100 positions are overwritten and </think> is still trained on. train_on_responses_only avoids this by swapping to DataCollatorForSeq2Seq before returning; this helper should do the same for non-packing trainers or otherwise fail fast when the active collator will not preserve labels.

Useful? React with 👍 / 👎.

Adds a convenience function, like train_on_responses_only, that sets the label of the thinking closing token (</think> by default) to -100 so the model is not trained on it. Follows Nemotron Ultra, which masks out </think> during training to avoid teaching the model a fixed reasoning length. Switches to DataCollatorForSeq2Seq (like train_on_responses_only) so masked labels survive collation, validates the tokenizer and trainer arguments, and includes offline unit tests for the masking logic and error paths.
@Sushankthatipally

Copy link
Copy Markdown
Author

Thanks for the reviews! Pushed an update addressing most of the points:

  • Trainer-vs-Dataset guard: added a TypeError if a Dataset is passed as the
    first arg (it would otherwise silently no-op), pointing to return_function=True.
  • Tokenizer resolution: switched to getattr(..., None) or getattr(..., None) so a
    None processing_class falls back to tokenizer, plus a clear ValueError when no
    tokenizer is available.
  • Tensor preservation: use_tensors is now set when labels is a tensor too.
  • Collator: switch to DataCollatorForSeq2Seq for non-packing trainers (same as
    train_on_responses_only) so the masked -100s aren't rebuilt from input_ids
    during collation.
  • Tests: added coverage for the TypeError, the ValueError, and tensor-label
    preservation (13 tests, CPU-only/offline).

Two I held off on, with reasoning:

  • Context-merged encodings: the reasoning models this targets (Qwen3,
    DeepSeek-R1, Nemotron) tokenize as an atomic special token that doesn't
    merge with adjacent whitespace, so the bare-id match is correct. Routing through
    _find_common_token_ids also pulls in its [0] fallback (which can resolve to
    token-id 0 / pad) for tokens that don't encode cleanly. Happy to add
    context-probing if you'd like to support non-special-token templates.
  • Filtering fully-masked rows: train_on_responses_only filters because it masks the
    entire instruction, which can zero out a row. This masks only the single
    terminator token, so a row with any other trainable label can't become fully
    -100. Glad to add the filter for symmetry if you'd prefer it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant