-
Notifications
You must be signed in to change notification settings - Fork 289
dataset_utils: add mask_out_tokens to train_on_responses_only (fixes unslothai/unsloth#6695) #852
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -323,12 +323,21 @@ def train_on_responses_only( | |
| return_function = False, # Useful for iterating over lists | ||
| num_proc = None, | ||
| last_response_only = False, # Train only on the last assistant turn | ||
| mask_out_tokens = None, # e.g. ["</think>"] - also mask these inside kept responses | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
When using the MLX API ( Useful? React with 👍 / 👎.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed in 76587db — the MLX wrapper signature now accepts mask_out_tokens and forwards it to the HF implementation, mirroring the rest of the API surface. |
||
| ): | ||
| """Train only on responses by masking instruction labels to -100. | ||
|
|
||
| With last_response_only=True, only the final assistant turn is unmasked; | ||
| earlier assistant turns stay at -100 (never written, never copied from | ||
| old_labels). | ||
|
|
||
| mask_out_tokens re-masks the given token strings to -100 even inside kept | ||
| response spans - e.g. mask_out_tokens=["</think>"] reproduces the Nemotron | ||
| Ultra recipe of never training on the thinking closer. Each entry is matched | ||
| as its tokenized id sequence (a leading-space variant is matched too, for | ||
| SentencePiece-style tokenizers). Atomic added tokens such as "</think>" | ||
| always match exactly; multi-token strings match only where the in-context | ||
| tokenization equals the standalone one. | ||
| """ | ||
| # All Unsloth Zoo code licensed under LGPLv3 | ||
| if tokenizer is None and trainer is not None: | ||
|
|
@@ -375,6 +384,18 @@ def train_on_responses_only( | |
| torch_Tensor = torch.Tensor | ||
| torch_int64 = torch.int64 | ||
|
|
||
| # Precompute id sequences for mask_out_tokens (see docstring). Done once here so | ||
| # the per-example closure below only does integer comparisons. | ||
| mask_out_sequences = [] | ||
| if mask_out_tokens: | ||
| if isinstance(mask_out_tokens, str): mask_out_tokens = [mask_out_tokens] | ||
| for token_string in mask_out_tokens: | ||
| for candidate in dict.fromkeys((token_string, " " + token_string)): | ||
| ids = tokenizer(candidate, add_special_tokens = False).input_ids | ||
| if ids and ids not in mask_out_sequences: | ||
| mask_out_sequences.append(ids) | ||
| pass | ||
|
Comment on lines
+389
to
+397
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To ensure robust defensive programming and prevent cryptic runtime errors (such as mask_out_sequences = []
if mask_out_tokens:
if isinstance(mask_out_tokens, str):
mask_out_tokens = [mask_out_tokens]
elif not isinstance(mask_out_tokens, (list, tuple)):
raise TypeError("Unsloth: mask_out_tokens must be a string, list, or tuple of strings.")
for token_string in mask_out_tokens:
if not isinstance(token_string, str):
raise TypeError(f"Unsloth: mask_out_tokens elements must be strings, but got {type(token_string).__name__}")
for candidate in dict.fromkeys((token_string, " " + token_string)):
ids = tokenizer(candidate, add_special_tokens = False).input_ids
if ids and ids not in mask_out_sequences:
mask_out_sequences.append(ids)
pass |
||
|
|
||
| def _train_on_responses_only(examples): | ||
| input_ids_ = examples["input_ids"] | ||
| use_tensors = False | ||
|
|
@@ -468,6 +489,19 @@ def _train_on_responses_only(examples): | |
| else: | ||
| labels[assistant_k : user_j] = old_labels[assistant_k : user_j] | ||
|
|
||
| # Re-mask requested token sequences (e.g. "</think>") wherever they occur; | ||
| # positions outside kept spans are already -100, so re-masking is harmless. | ||
| for seq in mask_out_sequences: | ||
| seq_len, first = len(seq), seq[0] | ||
| i, limit = 0, n - len(seq) | ||
| while i <= limit: | ||
| if input_ids[i] == first and input_ids[i : i + seq_len] == seq: | ||
| labels[i : i + seq_len] = [-100] * seq_len | ||
| i += seq_len | ||
| else: | ||
| i += 1 | ||
| pass | ||
|
|
||
| all_labels.append(labels) | ||
| pass | ||
| return { "labels" : torch.tensor(all_labels, dtype = torch.int64) if use_tensors else all_labels } | ||
|
|
@@ -664,6 +698,7 @@ def _is_vision_collator(collator): | |
| tokenizer = coll_proc, | ||
| return_function = True, | ||
| last_response_only = last_response_only, | ||
| mask_out_tokens = mask_out_tokens, | ||
| **parts, | ||
| ) | ||
| print(f"Unsloth: Enabled response-only masking on your {type(data_collator).__name__} (image handling kept intact).") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When callers use this new kwarg with a vision trainer,
train_on_responses_onlytakes the later_is_vision_collatorbranch and assignsdata_collator.train_on_responses_onlyfrom a recursive call at lines 695-702, but that call omitsmask_out_tokens, so the inner closure builds an emptymask_out_sequences. This means VLM response-only training still includes</think>or any requested token in the loss despite acceptingmask_out_tokens=[...]; pass the kwarg through the collator masking setup as well.Useful? React with 👍 / 👎.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in 76587db — the vision-collator recursive call now passes mask_out_tokens through, so the VLM path applies the same re-masking instead of silently ignoring the kwarg. Text-path label output is regression-identical.