diff --git a/studio/backend/tests/test_vision_cache.py b/studio/backend/tests/test_vision_cache.py index 9e7bbdd1fb..97579b3b4d 100644 --- a/studio/backend/tests/test_vision_cache.py +++ b/studio/backend/tests/test_vision_cache.py @@ -117,6 +117,18 @@ def test_subprocess_called_once_with_cache(self, mock_needs_t5, mock_subprocess) mock_subprocess.assert_called_once() assert _vision_detection_cache[("unsloth/Qwen3.5-2B", None)] is True + @patch("utils.models.model_config._raw_config_has_vision_config", return_value = True) + @patch("utils.models.model_config._is_vision_model_subprocess", return_value = None) + @patch("utils.transformers_version.needs_transformers_5", return_value = True) + def test_subprocess_none_falls_back_to_raw_vision_config( + self, mock_needs_t5, mock_subprocess, mock_raw_config + ): + assert is_vision_model("unsloth/gemma-4-E4B-it") is True + assert is_vision_model("unsloth/gemma-4-E4B-it") is True + + mock_subprocess.assert_called_once() + mock_raw_config.assert_called_once_with("unsloth/gemma-4-E4B-it", hf_token = None) + # --------------------------------------------------------------------------- # Exception handling — cache the False fallback @@ -223,6 +235,20 @@ def test_vision_config_attr_detected_and_cached( assert is_vision_model("Qwen/Qwen2-VL-7B") is True mock_load_config.assert_called_once() + @patch("utils.transformers_version.needs_transformers_5", return_value = False) + @patch("utils.models.model_config.load_model_config") + def test_model_type_prefix_detected_and_cached( + self, mock_load_config, mock_needs_t5 + ): + cfg = MagicMock(spec = []) + cfg.model_type = "gemma4audio" + cfg.architectures = ["Gemma4AudioForCausalLM"] + mock_load_config.return_value = cfg + + assert is_vision_model("google/gemma-4-audio") is True + assert is_vision_model("google/gemma-4-audio") is True + mock_load_config.assert_called_once() + @patch("utils.transformers_version.needs_transformers_5", return_value = False) @patch("utils.models.model_config.load_model_config") def test_audio_model_excluded_and_cached(self, mock_load_config, mock_needs_t5): diff --git a/studio/backend/utils/models/model_config.py b/studio/backend/utils/models/model_config.py index 16f6d21edb..88e0535da3 100644 --- a/studio/backend/utils/models/model_config.py +++ b/studio/backend/utils/models/model_config.py @@ -496,6 +496,7 @@ def load_model_config( "internvl_chat", "cogvlm2", "minicpmv", + "gemma4", } # Pre-computed .venv_t5 paths and backend dir for subprocess version switching. @@ -503,6 +504,45 @@ def load_model_config( _VENV_T5_DIR = str(Path.home() / ".unsloth" / "studio" / ".venv_t5_550") _BACKEND_DIR = str(Path(__file__).resolve().parent.parent.parent) + +def _is_vlm(config) -> bool: + architectures = getattr(config, "architectures", None) or [] + model_type = getattr(config, "model_type", None) + return ( + any(x.endswith(_VLM_ARCH_SUFFIXES) for x in architectures) + or hasattr(config, "vision_config") + or hasattr(config, "img_processor") + or hasattr(config, "image_token_index") + or ( + model_type is not None + and any(model_type.startswith(vlm_type) for vlm_type in _VLM_MODEL_TYPES) + ) + ) + + +def _raw_config_has_vision_config( + model_name: str, hf_token: Optional[str] = None +) -> Optional[bool]: + try: + if is_local_path(model_name): + config_path = Path(normalize_path(model_name)).expanduser() / "config.json" + else: + from huggingface_hub import hf_hub_download + + config_path = Path( + hf_hub_download( + repo_id = model_name, + filename = "config.json", + token = hf_token, + ) + ) + config = json.loads(config_path.read_text()) + return "vision_config" in config and bool(config["vision_config"]) + except Exception as exc: + logger.warning("Could not read config.json for '%s': %s", model_name, exc) + return None + + # Inline script executed in a subprocess with transformers 5.x activated. # Receives model_name and token via argv, prints JSON result to stdout. _VISION_CHECK_SCRIPT = r""" @@ -521,30 +561,16 @@ def load_model_config( try: from transformers import AutoConfig + from utils.models.model_config import _is_vlm + kwargs = {"trust_remote_code": True} if token: kwargs["token"] = token config = AutoConfig.from_pretrained(model_name, **kwargs) - is_vlm = False - if hasattr(config, "architectures"): - is_vlm = any( - x.endswith(("ForConditionalGeneration", "ForVisionText2Text")) - for x in config.architectures - ) - if not is_vlm and hasattr(config, "vision_config"): - is_vlm = True - if not is_vlm and hasattr(config, "img_processor"): - is_vlm = True - if not is_vlm and hasattr(config, "image_token_index"): - is_vlm = True - if not is_vlm and hasattr(config, "model_type"): - vlm_types = {"phi3_v","llava","llava_next","llava_onevision", - "internvl_chat","cogvlm2","minicpmv"} - if config.model_type in vlm_types: - is_vlm = True - - model_type = getattr(config, "model_type", "unknown") + is_vlm = _is_vlm(config) + + model_type = getattr(config, "model_type", None) archs = getattr(config, "architectures", []) print(json.dumps({"is_vision": is_vlm, "model_type": model_type, "architectures": archs})) @@ -719,7 +745,10 @@ def _is_vision_model_uncached( "Model '%s' needs transformers 5.x -- checking vision via subprocess", model_name, ) - return _is_vision_model_subprocess(model_name, hf_token = hf_token) + result = _is_vision_model_subprocess(model_name, hf_token = hf_token) + if result is not None: + return result + return _raw_config_has_vision_config(model_name, hf_token = hf_token) try: config = load_model_config(model_name, use_auth = True, token = hf_token) @@ -731,38 +760,10 @@ def _is_vision_model_uncached( if model_type in _audio_only_model_types: return False - # Check 1: Architecture class name patterns - if hasattr(config, "architectures"): - is_vlm = any(x.endswith(_VLM_ARCH_SUFFIXES) for x in config.architectures) - if is_vlm: - logger.info( - f"Model {model_name} detected as VLM: architecture {config.architectures}" - ) - return True - - # Check 2: Has vision_config (most VLMs: LLaVA, Gemma-3, Qwen2-VL, etc.) - if hasattr(config, "vision_config"): - logger.info(f"Model {model_name} detected as VLM: has vision_config") + if _is_vlm(config): + logger.info(f"Model {model_name} detected as VLM") return True - # Check 3: Has img_processor (Phi-3.5 Vision uses this instead of vision_config) - if hasattr(config, "img_processor"): - logger.info(f"Model {model_name} detected as VLM: has img_processor") - return True - - # Check 4: Has image_token_index (common in VLMs for image placeholder tokens) - if hasattr(config, "image_token_index"): - logger.info(f"Model {model_name} detected as VLM: has image_token_index") - return True - - # Check 5: Known VLM model_type values that may not match above checks - if hasattr(config, "model_type"): - if config.model_type in _VLM_MODEL_TYPES: - logger.info( - f"Model {model_name} detected as VLM: model_type={config.model_type}" - ) - return True - return False except Exception as e: