diff --git a/src/transformers/pipelines/image_text_to_text.py b/src/transformers/pipelines/image_text_to_text.py
index cf26c31f76e0..1b2e11fcf21e 100644
--- a/src/transformers/pipelines/image_text_to_text.py
+++ b/src/transformers/pipelines/image_text_to_text.py
@@ -160,13 +160,12 @@ def _sanitize_parameters(
         if generate_kwargs is not None:
             forward_kwargs["generate_kwargs"] = generate_kwargs
         if stop_sequence is not None:
-            stop_sequence_ids = self.processor.tokenizer.encode(stop_sequence, add_special_tokens=False)
-            if len(stop_sequence_ids) > 1:
-                logger.warning_once(
-                    "Stopping on a multiple token sequence is not yet supported on transformers. The first token of"
-                    " the stop sequence will be used as the stop sequence string in the interim."
-                )
-            generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
+            if generate_kwargs is None:
+                generate_kwargs = {}
+            if isinstance(stop_sequence, str):
+                stop_sequence = [stop_sequence]
+            generate_kwargs["stop_strings"] = stop_sequence
+            generate_kwargs["tokenizer"] = self.processor.tokenizer
         if generate_kwargs is not None:
             forward_kwargs["generate_kwargs"] = generate_kwargs
         if max_new_tokens is not None:
diff --git a/tests/pipelines/test_pipelines_image_text_to_text.py b/tests/pipelines/test_pipelines_image_text_to_text.py
index 894c7f4587f9..76199896524d 100644
--- a/tests/pipelines/test_pipelines_image_text_to_text.py
+++ b/tests/pipelines/test_pipelines_image_text_to_text.py
@@ -14,6 +14,7 @@
 
 import base64
 import unittest
+from types import SimpleNamespace
 
 from transformers import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING, is_vision_available
 from transformers.pipelines import ImageTextToTextPipeline, pipeline
@@ -68,6 +69,17 @@ def run_pipeline_test(self, pipe, examples):
             ],
         )
 
+    def test_stop_sequence_without_generate_kwargs(self):
+        pipe = object.__new__(ImageTextToTextPipeline)
+        tokenizer = object()
+        pipe.processor = SimpleNamespace(tokenizer=tokenizer)
+
+        _, forward_kwargs, _ = pipe._sanitize_parameters(stop_sequence=".", max_new_tokens=3)
+        self.assertEqual(
+            forward_kwargs["generate_kwargs"],
+            {"stop_strings": ["."], "tokenizer": tokenizer, "max_new_tokens": 3},
+        )
+
     @require_torch
     def test_small_model_pt_token_text_only(self):
         pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")