Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions src/transformers/pipelines/image_text_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,12 @@ def _sanitize_parameters(
if generate_kwargs is not None:
forward_kwargs["generate_kwargs"] = generate_kwargs
if stop_sequence is not None:
stop_sequence_ids = self.processor.tokenizer.encode(stop_sequence, add_special_tokens=False)
if len(stop_sequence_ids) > 1:
logger.warning_once(
"Stopping on a multiple token sequence is not yet supported on transformers. The first token of"
" the stop sequence will be used as the stop sequence string in the interim."
)
generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
if generate_kwargs is None:
generate_kwargs = {}
if isinstance(stop_sequence, str):
stop_sequence = [stop_sequence]
generate_kwargs["stop_strings"] = stop_sequence
generate_kwargs["tokenizer"] = self.processor.tokenizer
if generate_kwargs is not None:
forward_kwargs["generate_kwargs"] = generate_kwargs
if max_new_tokens is not None:
Expand Down
12 changes: 12 additions & 0 deletions tests/pipelines/test_pipelines_image_text_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import base64
import unittest
from types import SimpleNamespace

from transformers import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING, is_vision_available
from transformers.pipelines import ImageTextToTextPipeline, pipeline
Expand Down Expand Up @@ -68,6 +69,17 @@ def run_pipeline_test(self, pipe, examples):
],
)

def test_stop_sequence_without_generate_kwargs(self):
    """Check `_sanitize_parameters` when `stop_sequence` is given but `generate_kwargs` is not.

    The pipeline object is allocated with `object.__new__`, so `__init__` is
    not run; only the `processor.tokenizer` attribute is populated by hand.
    The resulting forward kwargs are expected to carry the stop string as a
    one-element list, the tokenizer object itself, and `max_new_tokens`.
    """
    # Opaque placeholder: the assertion only needs to see the same object come back.
    stub_tokenizer = object()
    bare_pipe = object.__new__(ImageTextToTextPipeline)
    bare_pipe.processor = SimpleNamespace(tokenizer=stub_tokenizer)

    sanitized = bare_pipe._sanitize_parameters(stop_sequence=".", max_new_tokens=3)
    # Only the middle element of the returned triple is inspected here.
    _, forward_kwargs, _ = sanitized

    expected_generate_kwargs = {
        "stop_strings": ["."],
        "tokenizer": stub_tokenizer,
        "max_new_tokens": 3,
    }
    self.assertEqual(forward_kwargs["generate_kwargs"], expected_generate_kwargs)

@require_torch
def test_small_model_pt_token_text_only(self):
pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
Expand Down
Loading