From d579ce962a4b290be7124d149b9e260c6e80f90a Mon Sep 17 00:00:00 2001 From: Etherll Date: Sat, 9 May 2026 10:53:09 +0300 Subject: [PATCH 01/11] Document extractor refactor (rebased onto origin/main with conflicts resolved) --- .gitignore | 6 + .../other/deepseek-ai_DeepSeek-OCR.yaml | 22 + .../other/unsloth_PaddleOCR-VL.yaml | 6 + .../model_defaults/other/zai-org_GLM-OCR.yaml | 22 + studio/backend/core/chat/__init__.py | 61 + .../backend/core/chat/document_extractor.py | 1076 +++++++++++++++++ studio/backend/core/chat/vlm_capability.py | 211 ++++ studio/backend/core/export/export.py | 5 +- studio/backend/core/inference/__init__.py | 3 +- studio/backend/core/inference/llama_cpp.py | 61 +- studio/backend/core/inference/worker.py | 47 +- studio/backend/core/training/trainer.py | 12 +- studio/backend/models/inference.py | 161 +++ studio/backend/requirements/studio.txt | 10 + studio/backend/routes/inference.py | 1056 ++++++++++++++-- studio/backend/routes/models.py | 119 +- studio/backend/run.py | 10 +- .../backend/tests/test_anthropic_messages.py | 19 + .../tests/test_chat_document_extraction.py | 900 ++++++++++++++ .../tests/test_chat_document_routes.py | 895 ++++++++++++++ studio/backend/tests/test_inference_worker.py | 37 + ...models_get_model_config_case_resolution.py | 108 +- .../tests/test_openai_tool_passthrough.py | 67 + studio/backend/tests/test_vision_cache.py | 10 +- studio/backend/utils/models/model_config.py | 119 +- .../components/assistant-ui/attachment.tsx | 357 +++++- .../src/components/assistant-ui/thread.tsx | 113 +- studio/frontend/src/components/ui/tabs.tsx | 307 +++-- .../src/features/chat/api/chat-adapter.ts | 154 ++- .../src/features/chat/api/chat-api.ts | 245 +++- .../src/features/chat/chat-settings-sheet.tsx | 889 +++++++++++++- .../components/attachment-chip-primitives.tsx | 242 ++++ .../chat/components/doc-attachment-chip.tsx | 160 +++ .../components/document-preview-panel.tsx | 732 +++++++++++ .../chat/components/document-stack.tsx | 748 
++++++++++++ .../chat/hooks/use-chat-model-runtime.ts | 2 + .../chat/hooks/use-document-extraction.ts | 150 +++ studio/frontend/src/features/chat/index.ts | 17 + .../src/features/chat/runtime-provider.tsx | 503 +++++--- .../src/features/chat/shared-composer.tsx | 822 +++++++++++-- .../chat/stores/chat-runtime-store.ts | 327 +++-- studio/frontend/src/features/chat/types.ts | 132 ++ .../frontend/src/features/chat/types/api.ts | 9 +- .../src/features/chat/types/runtime.ts | 5 + .../chat/utils/document-extraction.ts | 461 +++++++ .../src/features/chat/utils/ocr-model-lock.ts | 240 ++++ .../chat/utils/ocr-model-orchestrator.ts | 901 ++++++++++++++ .../features/chat/utils/ocr-model-presets.ts | 121 ++ .../src/features/training/api/models-api.ts | 33 +- .../training/stores/training-config-store.ts | 13 +- 50 files changed, 11956 insertions(+), 770 deletions(-) create mode 100644 studio/backend/assets/configs/model_defaults/other/deepseek-ai_DeepSeek-OCR.yaml create mode 100644 studio/backend/assets/configs/model_defaults/other/zai-org_GLM-OCR.yaml create mode 100644 studio/backend/core/chat/__init__.py create mode 100644 studio/backend/core/chat/document_extractor.py create mode 100644 studio/backend/core/chat/vlm_capability.py create mode 100644 studio/backend/tests/test_chat_document_extraction.py create mode 100644 studio/backend/tests/test_chat_document_routes.py create mode 100644 studio/backend/tests/test_inference_worker.py create mode 100644 studio/frontend/src/features/chat/components/attachment-chip-primitives.tsx create mode 100644 studio/frontend/src/features/chat/components/doc-attachment-chip.tsx create mode 100644 studio/frontend/src/features/chat/components/document-preview-panel.tsx create mode 100644 studio/frontend/src/features/chat/components/document-stack.tsx create mode 100644 studio/frontend/src/features/chat/hooks/use-document-extraction.ts create mode 100644 studio/frontend/src/features/chat/utils/document-extraction.ts create mode 100644 
studio/frontend/src/features/chat/utils/ocr-model-lock.ts create mode 100644 studio/frontend/src/features/chat/utils/ocr-model-orchestrator.ts create mode 100644 studio/frontend/src/features/chat/utils/ocr-model-presets.ts diff --git a/.gitignore b/.gitignore index ae6770bc07..b960de5787 100644 --- a/.gitignore +++ b/.gitignore @@ -229,3 +229,9 @@ server.pid *.log package-lock.json llama.cpp/ +/.omc +/studio/frontend/.omc +/.codex +/studio/.omc +/studio/backend/.omc +*.patch diff --git a/studio/backend/assets/configs/model_defaults/other/deepseek-ai_DeepSeek-OCR.yaml b/studio/backend/assets/configs/model_defaults/other/deepseek-ai_DeepSeek-OCR.yaml new file mode 100644 index 0000000000..b827a1f910 --- /dev/null +++ b/studio/backend/assets/configs/model_defaults/other/deepseek-ai_DeepSeek-OCR.yaml @@ -0,0 +1,22 @@ +# Model defaults for deepseek-ai/DeepSeek-OCR +# Custom-code OCR vision model. Used by Studio chat as a temporary OCR +# model swap during scanned-PDF extraction; never used for training. 
+ +model: + identifier: deepseek-ai/DeepSeek-OCR + display_name: DeepSeek-OCR + is_vision: true + is_ocr: true + +training: + trust_remote_code: true + max_seq_length: 8192 + packing: false + +inference: + trust_remote_code: true + temperature: 0.0 + top_p: 1.0 + top_k: -1 + min_p: 0.0 + default_max_seq_length: 8192 diff --git a/studio/backend/assets/configs/model_defaults/other/unsloth_PaddleOCR-VL.yaml b/studio/backend/assets/configs/model_defaults/other/unsloth_PaddleOCR-VL.yaml index b7587bbd91..2a270ed282 100644 --- a/studio/backend/assets/configs/model_defaults/other/unsloth_PaddleOCR-VL.yaml +++ b/studio/backend/assets/configs/model_defaults/other/unsloth_PaddleOCR-VL.yaml @@ -3,6 +3,12 @@ # Also applies to: unsloth/PaddleOCR-VL # added inference parameters from unsloth notebook +model: + identifier: unsloth/PaddleOCR-VL + display_name: PaddleOCR-VL + is_vision: true + is_ocr: true + training: trust_remote_code: true max_seq_length: 2048 diff --git a/studio/backend/assets/configs/model_defaults/other/zai-org_GLM-OCR.yaml b/studio/backend/assets/configs/model_defaults/other/zai-org_GLM-OCR.yaml new file mode 100644 index 0000000000..2249aa4487 --- /dev/null +++ b/studio/backend/assets/configs/model_defaults/other/zai-org_GLM-OCR.yaml @@ -0,0 +1,22 @@ +# Model defaults for zai-org/GLM-OCR +# GLM family OCR vision model with model_type "glm_ocr". Used by Studio chat +# as a temporary OCR model swap during scanned-PDF extraction. 
+ +model: + identifier: zai-org/GLM-OCR + display_name: GLM-OCR + is_vision: true + is_ocr: true + +training: + trust_remote_code: true + max_seq_length: 8192 + packing: false + +inference: + trust_remote_code: true + temperature: 0.0 + top_p: 1.0 + top_k: -1 + min_p: 0.0 + default_max_seq_length: 8192 diff --git a/studio/backend/core/chat/__init__.py b/studio/backend/core/chat/__init__.py new file mode 100644 index 0000000000..ba0d556b64 --- /dev/null +++ b/studio/backend/core/chat/__init__.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 + +""" +Chat-surface helpers that do not belong in ``core/inference`` (tightly +coupled to model backends) and explicitly not in ``core/data_recipe`` +(owns dataset pipelines). + +Exposes the document-extraction pipeline used when a user drops a +PDF / DOCX / HTML / MD / TXT file into the chat composer. PDF parsing +uses PyMuPDF4LLM, DOCX uses mammoth. PPTX is not supported here — +convert to PDF first. 
+""" + +from __future__ import annotations + +from .document_extractor import ( + DOCUMENT_EXTRACTION_AVAILABLE, + DEFAULT_DOCUMENT_VISUAL_PAYLOADS, + DocumentExtractionBusy, + DocumentExtractionCancelled, + DocumentExtractionEncrypted, + DocumentExtractionTimeout, + DocumentExtractionUnavailable, + ExtractedFigure, + ExtractResult, + MAX_DOCUMENT_VISUAL_PAYLOADS, + SUPPORTED_MIME_TYPES, + SUPPORTED_SUFFIXES, + _EXTRACT_SEMAPHORE, + document_parser_support, + document_parser_unavailable_reasons, + extract_document, +) +from .vlm_capability import ( + VlmCapability, + detect_loaded_vlm, + extract_self_base_url, +) + +__all__ = [ + "DOCUMENT_EXTRACTION_AVAILABLE", + "DEFAULT_DOCUMENT_VISUAL_PAYLOADS", + "DocumentExtractionBusy", + "DocumentExtractionCancelled", + "DocumentExtractionEncrypted", + "DocumentExtractionTimeout", + "DocumentExtractionUnavailable", + "ExtractedFigure", + "ExtractResult", + "MAX_DOCUMENT_VISUAL_PAYLOADS", + "SUPPORTED_MIME_TYPES", + "SUPPORTED_SUFFIXES", + "VlmCapability", + "_EXTRACT_SEMAPHORE", + "detect_loaded_vlm", + "document_parser_support", + "document_parser_unavailable_reasons", + "extract_document", + "extract_self_base_url", +] diff --git a/studio/backend/core/chat/document_extractor.py b/studio/backend/core/chat/document_extractor.py new file mode 100644 index 0000000000..50e1e46551 --- /dev/null +++ b/studio/backend/core/chat/document_extractor.py @@ -0,0 +1,1076 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 + +""" +Document extractor for the Chat composer. + +Given raw file bytes (PDF / DOCX / HTML / MD / TXT), produce Markdown +suitable to splice into an outgoing chat message. When a vision-capable +model is loaded, selected figures are captioned through our OpenAI-compatible +``/v1/chat/completions`` surface after conversion. 
+ +This build uses **PyMuPDF4LLM** (via ``pymupdf4llm`` / ``pymupdf``) for PDF +parsing and **mammoth** for DOCX conversion. Plain-text and Markdown inputs +are decoded as UTF-8 with replacement; HTML inputs are converted to Markdown. + +Notes and limitations: + +* **OCR is disabled.** There is no local OCR pass in this build, so scanned + PDFs without a text layer will yield empty or near-empty Markdown. The + ``use_vlm_ocr`` flag is still accepted for API compatibility; when set it + renders bounded page images so a loaded vision model can describe them. +* **PPTX is not supported** in this build. ``SUPPORTED_SUFFIXES`` and + ``SUPPORTED_MIME_TYPES`` no longer advertise the PowerPoint types. +* Parser dependencies are checked per format so plain-text, Markdown, and HTML + still work when optional PDF or DOCX libraries are missing. +* If the loaded model is not vision-capable, image description is silently + skipped and ``figures`` comes back with captions set to ``None``; + ``describe_skipped_reason`` carries the diagnostic text. 
+""" + +from __future__ import annotations + +import asyncio +import base64 +import inspect +import io +import logging +import math +import multiprocessing +import os +import queue +import threading +import time +from dataclasses import dataclass, field, replace +from typing import Any, Awaitable, Callable, Literal, List, Optional + +from .vlm_capability import VlmCapability, detect_loaded_vlm + + +logger = logging.getLogger(__name__) + + +SUPPORTED_MIME_TYPES = frozenset( + { + "application/pdf", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/json", + "application/x-ndjson", + "application/xml", + "application/yaml", + "application/javascript", + "text/html", + "text/markdown", + "text/plain", + "text/csv", + "text/css", + "text/javascript", + "text/xml", + "text/yaml", + } +) + +SUPPORTED_SUFFIXES = frozenset( + { + ".pdf", ".docx", ".html", ".htm", ".md", ".txt", + ".csv", ".json", ".jsonl", ".yaml", ".yml", + ".py", ".js", ".jsx", ".ts", ".tsx", ".go", ".rs", ".java", + ".c", ".cpp", ".h", ".hpp", ".cs", ".php", ".rb", ".swift", + ".kt", ".kts", ".scala", ".sh", ".bash", ".zsh", ".ps1", + ".sql", ".toml", ".ini", ".cfg", ".log", ".xml", ".css", ".scss", + } +) + + +_DESCRIBE_PROMPT = ( + "Describe this figure in <=60 words. Focus on factual content " + "(axes, labels, captions, visible text, main objects). Do not " + "speculate beyond what is visible." 
+) + + +DEFAULT_DOCUMENT_VISUAL_PAYLOADS = 3 +MAX_DOCUMENT_VISUAL_PAYLOADS = 10 +_MAX_ENCODED_VISUALS = DEFAULT_DOCUMENT_VISUAL_PAYLOADS +_EXTRACT_TIMEOUT_SECONDS = 120 +_VLM_CAPTION_TOTAL_TIMEOUT_SECONDS = 180 +_LOCAL_VLM_CAPTION_CONCURRENCY = 1 +_DEFAULT_VLM_CAPTION_CONCURRENCY = 3 +_EXTRACT_CONCURRENCY = max( + 1, int(os.environ.get("UNSLOTH_STUDIO_EXTRACT_CONCURRENCY", "2")) +) +_EXTRACT_SEMAPHORE = threading.BoundedSemaphore(_EXTRACT_CONCURRENCY) +_PAGE_RENDER_DPI = 150 +_MAX_PAGE_RENDER_PIXELS = 4_000_000 +_MIME_TO_SUFFIX = { + "application/pdf": ".pdf", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx", + "application/json": ".json", + "application/x-ndjson": ".jsonl", + "application/xml": ".xml", + "application/yaml": ".yaml", + "application/javascript": ".js", + "text/html": ".html", + "text/markdown": ".md", + "text/plain": ".txt", + "text/csv": ".csv", + "text/css": ".css", + "text/javascript": ".js", + "text/xml": ".xml", + "text/yaml": ".yaml", +} + +_PLAIN_TEXT_SUFFIXES = SUPPORTED_SUFFIXES - {".pdf", ".docx", ".html", ".htm"} + + +def _normalized_suffix(filename: str, content_type: str = "") -> str: + suffix = os.path.splitext(filename)[1].lower() + if suffix in SUPPORTED_SUFFIXES: + return suffix + mime = (content_type or "").split(";", 1)[0].strip().lower() + return _MIME_TO_SUFFIX.get(mime, suffix) + + +class DocumentExtractionUnavailable(RuntimeError): + """Document extraction backend is not installed or failed to import. + + The backend is PyMuPDF4LLM + mammoth for parsed document formats. 
+ """ + + +class DocumentExtractionTimeout(RuntimeError): + """Raised when document parsing exceeds the 120-second worker limit.""" + + +class DocumentExtractionBusy(RuntimeError): + """Raised when the bounded document extraction worker pool is saturated.""" + + +class DocumentExtractionCancelled(RuntimeError): + """Raised when the caller cancels an in-flight extraction.""" + + +class DocumentExtractionEncrypted(RuntimeError): + """Raised when a PDF is encrypted and cannot be parsed without a password.""" + + +try: # pragma: no cover - presence depends on optional install + import pymupdf # type: ignore + import pymupdf4llm # type: ignore +except Exception as _pdf_extract_exc: # pragma: no cover + pymupdf = None # type: ignore[assignment] + pymupdf4llm = None # type: ignore[assignment] + _PDF_EXTRACTION_IMPORT_ERROR: Optional[BaseException] = _pdf_extract_exc +else: + _PDF_EXTRACTION_IMPORT_ERROR = None + +try: # pragma: no cover - presence depends on optional install + import mammoth # type: ignore +except Exception as _docx_extract_exc: # pragma: no cover + mammoth = None # type: ignore[assignment] + _DOCX_EXTRACTION_IMPORT_ERROR: Optional[BaseException] = _docx_extract_exc +else: + _DOCX_EXTRACTION_IMPORT_ERROR = None + +# The dispatcher can still extract plain text / code / data files when PDF or +# DOCX optional parsers are missing. Format-specific helpers raise +# DocumentExtractionUnavailable only when that format is actually requested. 
+DOCUMENT_EXTRACTION_AVAILABLE = True +_DOCUMENT_EXTRACTION_IMPORT_ERROR: Optional[BaseException] = ( + _PDF_EXTRACTION_IMPORT_ERROR or _DOCX_EXTRACTION_IMPORT_ERROR +) + + +def document_parser_support() -> dict[str, bool]: + return { + "pdf": _PDF_EXTRACTION_IMPORT_ERROR is None, + "docx": _DOCX_EXTRACTION_IMPORT_ERROR is None, + "html": True, + "text": True, + "data": True, + "code": True, + } + + +def document_parser_unavailable_reasons() -> dict[str, str]: + reasons: dict[str, str] = {} + if _PDF_EXTRACTION_IMPORT_ERROR is not None: + reasons["pdf"] = "PDF extraction requires pymupdf and pymupdf4llm." + if _DOCX_EXTRACTION_IMPORT_ERROR is not None: + reasons["docx"] = "DOCX extraction requires mammoth." + return reasons + + +@dataclass +class ExtractedFigure: + id: str + page: Optional[int] + caption: Optional[str] + error: Optional[str] = None + kind: Literal["figure", "page"] = "figure" + image_mime: Optional[str] = None + image_base64: Optional[str] = None + image_width: Optional[int] = None + image_height: Optional[int] = None + + +@dataclass +class ExtractResult: + markdown: str + figures: List[ExtractedFigure] = field(default_factory = list) + page_count: int = 0 + tokens_est: int = 0 + describe_skipped_reason: Optional[str] = None + vlm_source: Optional[str] = None + vlm_model: Optional[str] = None + image_input_available: bool = False + warnings: List[str] = field(default_factory = list) + + +ProgressCb = Callable[[dict], Awaitable[None]] + + +def _ensure_pdf_backend() -> None: + if pymupdf is None or pymupdf4llm is None: + if _PDF_EXTRACTION_IMPORT_ERROR is not None: + logger.debug( + "PDF extraction parser import failed: %s", + _PDF_EXTRACTION_IMPORT_ERROR, + ) + raise DocumentExtractionUnavailable( + "PDF extraction requires pymupdf and pymupdf4llm. 
Re-run Studio " + "setup to install the parser dependencies from " + "studio/backend/requirements/single-env/data-designer-deps.txt" + ) + + +def _ensure_docx_backend() -> None: + if mammoth is None: + if _DOCX_EXTRACTION_IMPORT_ERROR is not None: + logger.debug( + "DOCX extraction parser import failed: %s", + _DOCX_EXTRACTION_IMPORT_ERROR, + ) + raise DocumentExtractionUnavailable( + "DOCX extraction requires mammoth. Re-run Studio setup to install " + "the parser dependencies from " + "studio/backend/requirements/single-env/data-designer-deps.txt" + ) + + +def _estimate_tokens(text: str) -> int: + return max(0, len(text) // 4) + + +def _encode_pil_image_for_chat(image: Any) -> tuple[Optional[str], Optional[int], Optional[int], Optional[str]]: + if image is None: + return None, None, None, None + try: + from PIL import Image as PILImage + + img = image.copy() + img.thumbnail((1600, 1600)) + if img.mode in ("RGBA", "LA"): + background = PILImage.new("RGB", img.size, (255, 255, 255)) + alpha = img.getchannel("A") + background.paste(img.convert("RGB"), mask = alpha) + img = background + elif img.mode != "RGB": + img = img.convert("RGB") + + out = io.BytesIO() + img.save(out, format = "JPEG", quality = 88, optimize = True) + encoded = base64.b64encode(out.getvalue()).decode("ascii") + return encoded, img.width, img.height, "image/jpeg" + except (ImportError, AttributeError, ValueError, OSError) as exc: + logger.warning("Failed to encode extracted document image", exc_info=exc) + return None, None, None, None + + +async def _describe_image_via_vlm( + *, + image_base64: str, + image_mime: str, + endpoint_url: str, + model_name: str, + authorization_header: Optional[str], + timeout_seconds: float, +) -> tuple[Optional[str], Optional[str]]: + try: + import httpx + except Exception as exc: + return None, f"httpx unavailable: {exc}" + + headers = {"Content-Type": "application/json"} + if authorization_header: + headers["Authorization"] = authorization_header + + data_url = 
f"data:{image_mime};base64,{image_base64}" + payload = { + "model": model_name, + "stream": False, + "max_tokens": 512, + "temperature": 0.2, + "top_p": 0.9, + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": _DESCRIBE_PROMPT}, + {"type": "image_url", "image_url": {"url": data_url}}, + ], + } + ], + } + try: + async with httpx.AsyncClient(timeout = timeout_seconds) as client: + response = await client.post( + endpoint_url.rstrip("/") + "/v1/chat/completions", + headers = headers, + json = payload, + ) + if response.status_code >= 400: + return None, ( + f"VLM caption request failed with HTTP " + f"{response.status_code}" + ) + body = response.json() + choice = (body.get("choices") or [{}])[0] + message = choice.get("message") or {} + finish_reason = choice.get("finish_reason") + + # Some chat templates (Gemma 3/3n via llama-server, Qwen3 always-think) + # route the entire visible reply into ``reasoning_content`` and leave + # ``content`` empty. The chat UI handles this in its streaming + # consumer (see ``llama_cpp._chat_completion``); mirror that fallback + # here so non-streaming callers see the same answer. + candidates: list[Any] = [ + message.get("content"), + message.get("reasoning_content"), + message.get("text"), + ] + # Some servers return content as a list of parts (OpenAI multimodal); + # join any text parts into one string before checking emptiness. 
+ normalized: list[str] = [] + for raw in candidates: + if isinstance(raw, str): + if raw.strip(): + normalized.append(raw.strip()) + elif isinstance(raw, list): + parts = [ + part.get("text", "") + for part in raw + if isinstance(part, dict) + and isinstance(part.get("text"), str) + ] + joined = "".join(parts).strip() + if joined: + normalized.append(joined) + + if not normalized: + logger.warning( + "VLM caption empty: finish_reason=%r message_keys=%s", + finish_reason, + list(message.keys()), + ) + return None, ( + f"VLM caption empty (finish_reason={finish_reason!r})" + ) + # Prefer the first non-empty candidate + # (content > reasoning_content > text). + return normalized[0], None + except Exception as exc: + logger.debug("VLM caption request failed", exc_info = True) + return None, f"VLM caption request failed: {type(exc).__name__}" + + +def _build_extract_options( + *, + extract_images: bool, + use_vlm_ocr: bool, + max_visual_payloads: int, +) -> tuple[dict, list[str]]: + """Return ``(options, build_warnings)``. + + The options dict is a simple bag of flags consumed by the synchronous + extract dispatcher. There is no local OCR pass available in this build; + ``use_vlm_ocr=True`` is implemented as a bounded full-page visual + extraction fallback for VLM captioning. + """ + build_warnings: list[str] = [] + if use_vlm_ocr: + build_warnings.append( + "Full-page OCR was requested, but this build has no local OCR " + "engine; rendered page images will be sent to the loaded vision " + "model when image description is enabled." 
+ ) + options = { + "extract_images": bool(extract_images), + "use_vlm_ocr": bool(use_vlm_ocr), + "max_visual_payloads": max(0, max_visual_payloads), + } + return options, build_warnings + + +def _pymupdf4llm_markdown_kwargs() -> dict[str, Any]: + """Return kwargs supported by the installed pymupdf4llm.to_markdown().""" + preferred = { + "write_images": False, + "show_progress": False, + "ignore_images": True, + "table_strategy": "lines_strict", + "use_ocr": False, + "force_ocr": False, + } + try: + signature = inspect.signature(pymupdf4llm.to_markdown) + except (TypeError, ValueError): + return { + key: value + for key, value in preferred.items() + if key not in {"use_ocr", "force_ocr"} + } + params = signature.parameters + if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in params.values()): + return preferred + return {key: value for key, value in preferred.items() if key in params} + + +def _safe_page_pixmap(page: Any) -> Any: + rect = getattr(page, "rect", None) + width_pt = max(float(getattr(rect, "width", 0) or 0), 1.0) + height_pt = max(float(getattr(rect, "height", 0) or 0), 1.0) + scale = _PAGE_RENDER_DPI / 72.0 + projected_pixels = width_pt * scale * height_pt * scale + if projected_pixels > _MAX_PAGE_RENDER_PIXELS: + scale *= math.sqrt(_MAX_PAGE_RENDER_PIXELS / projected_pixels) + scale = max(scale, 0.05) + matrix = pymupdf.Matrix(scale, scale) # type: ignore[union-attr] + return page.get_pixmap(matrix = matrix, alpha = False) + + +def _append_page_image_figure( + doc: Any, + figures_out: list[ExtractedFigure], + *, + page_index: int, + max_figures: int, + encode_image: bool = True, +) -> bool: + if len(figures_out) >= max_figures: + return False + if not encode_image: + figures_out.append( + ExtractedFigure( + id = f"page-{page_index + 1}", + page = page_index + 1, + caption = None, + error = None, + kind = "page", + ) + ) + return True + try: + from PIL import Image as PILImage + + pix = _safe_page_pixmap(doc[page_index]) + png_bytes = 
pix.tobytes("png") + page_image = PILImage.open(io.BytesIO(png_bytes)) + image_base64, image_width, image_height, image_mime = ( + _encode_pil_image_for_chat(page_image) + ) + if not image_base64: + return False + figures_out.append( + ExtractedFigure( + id = f"page-{page_index + 1}", + page = page_index + 1, + caption = None, + error = None, + kind = "page", + image_mime = image_mime, + image_base64 = image_base64, + image_width = image_width, + image_height = image_height, + ) + ) + return True + except ( + ImportError, + MemoryError, + OverflowError, + ValueError, + OSError, + RuntimeError, + ) as exc: + logger.warning( + "Failed to render page %d preview for PDF", + page_index + 1, + exc_info = exc, + ) + return False + + +def _extract_pdf( + file_bytes: bytes, + max_figures: int, + use_vlm_ocr: bool, + max_visual_payloads: int, +) -> tuple[str, list[ExtractedFigure], int, int, int]: + """Extract Markdown + figures from a PDF via PyMuPDF4LLM. + + Returns ``(markdown, figures, page_count, truncated_count, seen)``. + """ + _ensure_pdf_backend() + assert pymupdf is not None and pymupdf4llm is not None # for type-checkers + + doc = pymupdf.open(stream = file_bytes, filetype = "pdf") + try: + if getattr(doc, "is_encrypted", False) or getattr(doc, "needs_pass", False): + raise DocumentExtractionEncrypted( + "Encrypted PDF; provide a password before extracting it." 
+ ) + markdown = pymupdf4llm.to_markdown(doc, **_pymupdf4llm_markdown_kwargs()) + + figures_out: list[ExtractedFigure] = [] + encoded_visuals = 0 + seen = 0 + truncated_count = 0 + page_count = len(doc) + + if max_figures > 0 and page_count > 0: + if use_vlm_ocr: + for page_index in range(page_count): + if len(figures_out) >= max_figures: + truncated_count += page_count - page_index + break + if _append_page_image_figure( + doc, + figures_out, + page_index = page_index, + max_figures = max_figures, + encode_image = encoded_visuals < max_visual_payloads, + ): + if figures_out[-1].image_base64: + encoded_visuals += 1 + seen += 1 + elif _append_page_image_figure( + doc, + figures_out, + page_index = 0, + max_figures = max_figures, + encode_image = encoded_visuals < max_visual_payloads, + ): + if figures_out[-1].image_base64: + encoded_visuals += 1 + + if not use_vlm_ocr: + try: + from PIL import Image as PILImage + + for page_index in range(page_count): + page = doc[page_index] + try: + images = page.get_images(full = True) + except (ValueError, RuntimeError) as exc: + logger.debug( + "page.get_images failed on page %d", + page_index + 1, + exc_info = exc, + ) + continue + for img_info in images: + xref = img_info[0] if img_info else 0 + if not xref: + continue + try: + extracted = doc.extract_image(xref) + except (ValueError, RuntimeError) as exc: + logger.debug( + "doc.extract_image failed for xref %s", + xref, + exc_info = exc, + ) + continue + if not extracted: + continue + raw_bytes = extracted.get("image") + if not raw_bytes: + continue + try: + pil_img = PILImage.open(io.BytesIO(raw_bytes)) + pil_img.load() + except (OSError, ValueError) as exc: + logger.debug( + "PIL failed to decode extracted image xref %s", + xref, + exc_info = exc, + ) + continue + if pil_img.width < 50 or pil_img.height < 50: + continue + seen += 1 + if len(figures_out) >= max_figures: + truncated_count += 1 + continue + image_base64 = None + image_width = None + image_height = None + 
image_mime = None + if encoded_visuals < max_visual_payloads: + ( + image_base64, + image_width, + image_height, + image_mime, + ) = _encode_pil_image_for_chat(pil_img) + if image_base64: + encoded_visuals += 1 + figures_out.append( + ExtractedFigure( + id = f"fig-{len(figures_out)}", + page = page_index + 1, + caption = None, + error = None, + kind = "figure", + image_mime = image_mime, + image_base64 = image_base64, + image_width = image_width, + image_height = image_height, + ) + ) + except ImportError as exc: + logger.warning( + "Pillow is unavailable; skipping embedded-image extraction", + exc_info = exc, + ) + + return markdown, figures_out, page_count, truncated_count, seen + finally: + try: + doc.close() + except Exception: # pragma: no cover - defensive + logger.debug("pymupdf doc.close() raised", exc_info = True) + + +def _extract_docx( + file_bytes: bytes, +) -> tuple[str, list[ExtractedFigure], int, int, int]: + _ensure_docx_backend() + assert mammoth is not None # for type-checkers + stream = io.BytesIO(file_bytes) + result = mammoth.convert_to_markdown(stream) + markdown = result.value or "" + return markdown, [], 0, 0, 0 + + +def _extract_plaintext( + file_bytes: bytes, +) -> tuple[str, list[ExtractedFigure], int, int, int]: + text = file_bytes.decode("utf-8", errors = "replace") + return text, [], 0, 0, 0 + + +def _extract_html( + file_bytes: bytes, +) -> tuple[str, list[ExtractedFigure], int, int, int]: + html = file_bytes.decode("utf-8", errors = "replace") + try: + from core.inference._html_to_md import html_to_markdown + except Exception as exc: + logger.warning( + "HTML-to-Markdown converter unavailable; using raw HTML", + exc_info = exc, + ) + return html, [], 0, 0, 0 + return html_to_markdown(html), [], 0, 0, 0 + + +def _run_extract_sync( + file_bytes: bytes, + filename: str, + options: dict, + content_type: str = "", +) -> tuple[str, list[ExtractedFigure], int, int, int]: + """Synchronous dispatch by file suffix. 
+ + Returns ``(markdown, figures, page_count, truncated_count, seen)``. + """ + suffix = _normalized_suffix(filename, content_type) + extract_images = bool(options.get("extract_images")) + use_vlm_ocr = bool(options.get("use_vlm_ocr")) + max_figures = int(options.get("max_figures", 0)) if extract_images else 0 + max_visual_payloads = int( + options.get("max_visual_payloads", DEFAULT_DOCUMENT_VISUAL_PAYLOADS) + ) + + if suffix == ".pdf": + return _extract_pdf(file_bytes, max_figures, use_vlm_ocr, max_visual_payloads) + if suffix == ".docx": + return _extract_docx(file_bytes) + if suffix in {".html", ".htm"}: + return _extract_html(file_bytes) + if suffix in _PLAIN_TEXT_SUFFIXES: + return _extract_plaintext(file_bytes) + raise ValueError(f"Unsupported file type: {filename}") + + +_RUN_EXTRACT_SYNC_ORIGINAL = _run_extract_sync + + +def _run_extract_worker( + result_queue: Any, + file_bytes: bytes, + filename: str, + options: dict, + content_type: str, +) -> None: + try: + result_queue.put( + ("ok", _run_extract_sync(file_bytes, filename, options, content_type)) + ) + except DocumentExtractionUnavailable as exc: + result_queue.put(("extraction_unavailable", str(exc))) + except DocumentExtractionEncrypted as exc: + result_queue.put(("encrypted", str(exc))) + except ValueError as exc: + result_queue.put(("value_error", str(exc))) + except BaseException as exc: + result_queue.put(("error", type(exc).__name__, str(exc))) + + +def _terminate_extract_process(proc: multiprocessing.Process) -> None: + if not proc.is_alive(): + return + proc.terminate() + proc.join(5) + if proc.is_alive() and hasattr(proc, "kill"): + proc.kill() + proc.join(2) + + +def _run_extract_process_sync( + file_bytes: bytes, + filename: str, + options: dict, + content_type: str, + timeout_seconds: int, + cancel_event: Optional[threading.Event] = None, +) -> tuple[str, list[ExtractedFigure], int, int, int]: + if cancel_event is not None and cancel_event.is_set(): + raise 
async def extract_document(
    file_bytes: bytes,
    filename: str,
    *,
    content_type: str = "",
    describe_images: bool = True,
    use_vlm_ocr: bool = False,
    max_figures: int = 40,
    max_visual_payloads: int = DEFAULT_DOCUMENT_VISUAL_PAYLOADS,
    vlm_timeout_seconds: float = 60.0,
    capability: Optional[VlmCapability] = None,
    self_base_url: Optional[str] = None,
    authorization_header: Optional[str] = None,
    progress_cb: Optional[ProgressCb] = None,
    cancel_event: Optional[threading.Event] = None,
) -> ExtractResult:
    """Extract layout-aware Markdown plus figure metadata.

    When ``describe_images`` is True and the active model is
    vision-capable, the selected visual references are captioned via the
    OpenAI-compat ``/v1/chat/completions`` surface after extraction.
    Otherwise figures come back with ``caption=None`` and
    ``describe_skipped_reason`` carries the human-readable reason.

    Args:
        file_bytes: Raw bytes of the uploaded document.
        filename: Original filename; forwarded to the parser and used in logs.
        content_type: Declared MIME type, forwarded to the parser unchanged.
        describe_images: Request VLM captions for figures with an image payload.
        use_vlm_ocr: Forwarded to ``_build_extract_options``.
        max_figures: Upper bound on figures; negative values are clamped to 0,
            and 0 disables image extraction altogether.
        max_visual_payloads: Upper bound on figures that carry image bytes;
            clamped into ``[0, max_figures]``.
        vlm_timeout_seconds: Per-figure timeout handed to the describe call.
        capability: Pre-computed probe result; when ``None`` the loaded model
            is probed with ``detect_loaded_vlm(self_base_url)``.
        self_base_url: Only used for the capability probe.
        authorization_header: Forwarded to each describe call.
        progress_cb: Awaited with ``{"stage": "parsing"}`` and
            ``{"stage": "done"}``; exceptions it raises are logged at debug
            level and otherwise ignored.
        cancel_event: Checked before each progress event and before each
            describe call; also handed to the worker-process runner.

    Returns:
        An ``ExtractResult`` with the markdown, figures (captions/errors filled
        in where a describe call ran), page count, token estimate, the skip
        reason (if captioning was requested but not performed), the probed
        backend/model, and accumulated non-fatal warnings.

    Raises:
        DocumentExtractionCancelled: ``cancel_event`` was set.
        DocumentExtractionTimeout: Parsing exceeded ``_EXTRACT_TIMEOUT_SECONDS``.
        DocumentExtractionBusy, DocumentExtractionEncrypted,
        DocumentExtractionUnavailable, ValueError: Re-raised unchanged from
            the extraction step.
        RuntimeError: Any other extraction failure (original exception chained).
    """
    async def _emit(**event: Any) -> None:
        # Every progress emission doubles as a cancellation checkpoint.
        if cancel_event is not None and cancel_event.is_set():
            raise DocumentExtractionCancelled("document extraction was cancelled")
        if progress_cb is not None:
            try:
                await progress_cb(event)
            except Exception:
                # Progress reporting is best-effort; never fail extraction for it.
                logger.debug("progress_cb raised; continuing", exc_info = True)

    # Normalise the caller-supplied limits before anything depends on them.
    max_figures = max(0, max_figures)
    max_visual_payloads = max(0, min(max_visual_payloads, max_figures))
    cap = capability if capability is not None else detect_loaded_vlm(self_base_url)
    # A usable VLM needs all three: the vision flag, an endpoint and a model name.
    image_input_available = bool(cap.is_vlm and cap.endpoint_url and cap.model_name)
    describe_available = bool(
        describe_images and cap.is_vlm and cap.endpoint_url and cap.model_name
    )
    effective_describe = (
        describe_available and max_figures > 0 and max_visual_payloads > 0
    )
    extract_images = max_figures > 0

    # Only report a skip reason when the caller actually asked for captions.
    skipped_reason: Optional[str] = None
    if describe_images and not effective_describe:
        if describe_available and max_figures <= 0:
            skipped_reason = "figure description disabled because max_figures is 0"
        elif describe_available and max_visual_payloads <= 0:
            skipped_reason = (
                "figure description disabled because max_visual_payloads is 0"
            )
        else:
            skipped_reason = cap.reason or "no_vlm"

    await _emit(stage = "parsing")

    options, build_warnings = _build_extract_options(
        extract_images = extract_images,
        use_vlm_ocr = use_vlm_ocr,
        max_visual_payloads = max_visual_payloads,
    )
    options["max_figures"] = max_figures

    try:
        if _run_extract_sync is _RUN_EXTRACT_SYNC_ORIGINAL:
            # Production path: the blocking runner is pushed onto a thread so
            # the event loop stays responsive; it receives the timeout and
            # the cancel event itself.
            markdown, figures_out, page_count, truncated_count, seen = await asyncio.to_thread(
                _run_extract_process_sync,
                file_bytes,
                filename,
                options,
                content_type,
                _EXTRACT_TIMEOUT_SECONDS,
                cancel_event,
            )
        else:
            # Tests monkeypatch _run_extract_sync directly; preserve that seam
            # without forcing patched callables through multiprocessing spawn.
            loop = asyncio.get_running_loop()
            markdown, figures_out, page_count, truncated_count, seen = (
                await asyncio.wait_for(
                    loop.run_in_executor(
                        None,
                        _run_extract_sync,
                        file_bytes,
                        filename,
                        options,
                        content_type,
                    ),
                    timeout = _EXTRACT_TIMEOUT_SECONDS,
                )
            )
    except asyncio.TimeoutError:
        raise DocumentExtractionTimeout(
            "document parsing exceeded the 120-second worker limit"
        )
    # The typed extraction errors below must reach the caller untouched, so
    # they are listed ahead of the catch-all that wraps everything else.
    except DocumentExtractionTimeout:
        raise
    except DocumentExtractionBusy:
        raise
    except DocumentExtractionCancelled:
        raise
    except DocumentExtractionEncrypted:
        raise
    except DocumentExtractionUnavailable:
        raise
    except ValueError:
        # Unsupported file type — surface unchanged so the route can map to 415.
        raise
    except Exception as exc:
        logger.exception("document extraction failed for %s", filename)
        raise RuntimeError("document extraction failed") from exc

    caption_deadline_hit = False
    if effective_describe:
        # Backends served in-process get their own concurrency limit.
        caption_concurrency = (
            _LOCAL_VLM_CAPTION_CONCURRENCY
            if cap.source in {"transformers", "unsloth"}
            else _DEFAULT_VLM_CAPTION_CONCURRENCY
        )
        sem = asyncio.Semaphore(caption_concurrency)

        async def _describe_one(index: int, figure: ExtractedFigure) -> None:
            # Nothing to do for figures already captioned or without a payload.
            if figure.caption or not figure.image_base64 or not figure.image_mime:
                return
            if cancel_event is not None and cancel_event.is_set():
                raise DocumentExtractionCancelled("document extraction was cancelled")
            async with sem:
                # Re-check after waiting for a slot: cancellation may have
                # arrived while this task was queued on the semaphore.
                if cancel_event is not None and cancel_event.is_set():
                    raise DocumentExtractionCancelled(
                        "document extraction was cancelled"
                    )
                try:
                    caption, error = await _describe_image_via_vlm(
                        image_base64 = figure.image_base64,
                        image_mime = figure.image_mime,
                        endpoint_url = cap.endpoint_url or "",
                        model_name = cap.model_name or "",
                        authorization_header = authorization_header,
                        timeout_seconds = vlm_timeout_seconds,
                    )
                    figures_out[index] = replace(
                        figure,
                        caption = caption,
                        error = error,
                    )
                except asyncio.TimeoutError as exc:
                    # A per-figure failure is recorded on the figure, not raised.
                    logger.warning(
                        "VLM describe timed out for figure %s", figure.id, exc_info=exc
                    )
                    figures_out[index] = replace(
                        figure,
                        error = f"VLM describe timed out: {type(exc).__name__}",
                    )
                except Exception as exc:
                    logger.warning(
                        "VLM describe failed for figure %s", figure.id, exc_info=exc
                    )
                    figures_out[index] = replace(
                        figure,
                        error = f"VLM describe failed: {type(exc).__name__}",
                    )

        tasks = [
            _describe_one(index, fig)
            for index, fig in enumerate(figures_out[:max_figures])
            if fig.image_base64 and fig.image_mime
        ]
        if tasks:
            try:
                caption_timeout_seconds = _VLM_CAPTION_TOTAL_TIMEOUT_SECONDS
                if cap.source in {"transformers", "unsloth"}:
                    # Widen the overall deadline so every queued figure can
                    # use its full per-figure timeout, plus 15 s of slack.
                    caption_timeout_seconds = max(
                        caption_timeout_seconds,
                        len(tasks) * vlm_timeout_seconds + 15,
                    )
                results = await asyncio.wait_for(
                    asyncio.gather(*tasks, return_exceptions=True),
                    timeout = caption_timeout_seconds,
                )
                # gather() collected exceptions instead of raising them;
                # cancellation is the one outcome that must still propagate.
                for result in results:
                    if isinstance(
                        result,
                        (DocumentExtractionCancelled, asyncio.CancelledError),
                    ):
                        raise result
            except asyncio.TimeoutError:
                caption_deadline_hit = True
                # Mark every payload-bearing figure that ended with neither a
                # caption nor an error so the UI can tell why it is blank.
                for index, figure in enumerate(figures_out):
                    if figure.image_base64 and not figure.caption and not figure.error:
                        figures_out[index] = replace(
                            figure,
                            error = "VLM caption deadline exceeded",
                        )

    warnings: List[str] = list(build_warnings)
    if truncated_count > 0:
        warnings.append(
            f"Document has {seen} figures; showing the first {max_figures} "
            f"({truncated_count} truncated)."
        )
    visual_payload_count = sum(1 for figure in figures_out if figure.image_base64)
    # The payload cap was reached and some figures ended up text-only.
    if (
        visual_payload_count >= max_visual_payloads
        and len(figures_out) > visual_payload_count
    ):
        warnings.append(
            f"Only the first {max_visual_payloads} visual payloads "
            "were attached; remaining figure references are text-only."
        )
    if effective_describe and figures_out and all(f.caption is None for f in figures_out):
        # Surface up to three distinct per-figure errors to aid diagnosis.
        error_samples: list[str] = []
        seen_errors: set[str] = set()
        for figure in figures_out:
            if not figure.error or figure.error in seen_errors:
                continue
            seen_errors.add(figure.error)
            error_samples.append(f"{figure.id}: {figure.error}")
            if len(error_samples) >= 3:
                break
        sample_suffix = (
            " Examples: " + "; ".join(error_samples) + "."
            if error_samples
            else ""
        )
        warnings.append(
            "Figure descriptions were requested but none were produced — "
            "check that the loaded model accepts image inputs via /v1."
            f"{sample_suffix}"
        )
    if caption_deadline_hit:
        warnings.append(
            "Figure captioning reached the inline timeout; some image "
            "descriptions were skipped."
        )

    await _emit(stage = "done")

    return ExtractResult(
        markdown = markdown,
        figures = figures_out,
        page_count = page_count,
        tokens_est = _estimate_tokens(markdown),
        describe_skipped_reason = skipped_reason,
        vlm_source = cap.source,
        vlm_model = cap.model_name,
        image_input_available = image_input_available,
        warnings = warnings,
    )
logger = logging.getLogger(__name__)


# Backend that owns the active model; "none" means nothing is loaded.
VlmSource = Literal["gguf", "transformers", "unsloth", "none"]


@dataclass(frozen = True)
class VlmCapability:
    """Immutable snapshot of the loaded model's image-input capability."""

    # True only when the active model is reported as vision-capable.
    is_vlm: bool
    # Root URL that serves /v1/chat/completions for the active model.
    endpoint_url: Optional[str]
    model_name: Optional[str]
    source: VlmSource
    # Explains why is_vlm is False; None when the model is vision-capable.
    reason: Optional[str] = None

    @classmethod
    def none(cls, reason: str = "no model loaded") -> "VlmCapability":
        """Build the "nothing usable is loaded" result carrying *reason*."""
        return cls(
            is_vlm = False,
            endpoint_url = None,
            model_name = None,
            source = "none",
            reason = reason,
        )

    def to_dict(self) -> dict:
        """Return the snapshot as a plain dict of its five fields."""
        return asdict(self)


def _probe_gguf(llama: Any = None) -> Optional[VlmCapability]:
    """Probe the GGUF llama-server backend.

    Args:
        llama: Backend object to inspect. When ``None`` the process-wide
            backend is fetched via ``get_llama_cpp_backend()``.

    Returns:
        A ``VlmCapability`` with ``source="gguf"`` when a model is loaded and
        both its base URL and model identifier are known; ``None`` when the
        backend is unavailable, not loaded, or only half-initialised, so the
        caller can try the next backend.
    """
    if llama is None:
        try:
            from core.inference.llama_cpp import get_llama_cpp_backend
        except Exception:  # pragma: no cover - older embedding paths
            return None

        try:
            llama = get_llama_cpp_backend()
        except Exception:
            return None

    if not getattr(llama, "is_loaded", False):
        return None

    base_url = getattr(llama, "base_url", None)
    model_id = getattr(llama, "model_identifier", None)
    is_vision = bool(getattr(llama, "is_vision", False))

    if not base_url or not model_id:
        # Half-initialised llama-server state — fall through to the
        # transformers probe instead of returning a misleading
        # non-vision GGUF result that suppresses the fallback chain.
        logger.debug(
            "llama-server reports is_loaded=True but base_url / model id missing"
        )
        return None

    return VlmCapability(
        is_vlm = is_vision,
        endpoint_url = base_url,
        model_name = model_id,
        source = "gguf",
        reason = None if is_vision else "gguf: model loaded, is_vision=False (no mmproj clip)",
    )


def _probe_transformers(self_base_url: Optional[str]) -> Optional[VlmCapability]:
    """Probe the transformers / Unsloth inference backend.

    Args:
        self_base_url: Origin of this server, used as the captioning endpoint
            because these backends are reached by looping back through our
            own ``/v1/chat/completions``.

    Returns:
        ``None`` when the backend cannot be imported or obtained, or when it
        has no active model. Otherwise a ``VlmCapability`` whose source is
        ``"unsloth"`` for LoRA models and ``"transformers"`` for the rest;
        ``is_vlm`` is forced to False when ``self_base_url`` is missing.
    """
    try:
        from core.inference import get_inference_backend
    except ModuleNotFoundError as exc:
        # The backend package itself is absent: quietly "not available".
        if exc.name == "core.inference" or (
            exc.name and exc.name.startswith("core.inference.")
        ):
            return None
        # Some other module is missing — log it so the cause is visible.
        logger.exception("Failed to import transformers inference backend")
        return None
    except ImportError:
        # A different ImportError variant (e.g. circular import). Treat as
        # backend-unavailable. Anything else (NameError/AttributeError raised
        # by core.inference.__init__) propagates so real bugs aren't masked
        # as "no VLM loaded".
        logger.exception("Failed to import transformers inference backend")
        return None

    try:
        ib = get_inference_backend()
    except Exception:
        return None

    name: Optional[str] = getattr(ib, "active_model_name", None)
    if not name:
        return None

    models: dict = getattr(ib, "models", {}) or {}
    info: dict = models.get(name) or {}
    is_vision = bool(info.get("is_vision", False))
    is_lora = bool(info.get("is_lora", False))
    source: VlmSource = "unsloth" if is_lora else "transformers"

    if not self_base_url:
        # Without a loopback origin there is no endpoint to send images to,
        # so report the model but mark captioning as unavailable.
        return VlmCapability(
            is_vlm = False,
            endpoint_url = None,
            model_name = name,
            source = source,
            reason = f"{source}: self_base_url=None (cannot self-loopback to /v1/chat/completions)",
        )

    return VlmCapability(
        is_vlm = is_vision,
        endpoint_url = self_base_url.rstrip("/"),
        model_name = name,
        source = source,
        reason = None if is_vision else f"{source}: active model not marked is_vision",
    )


def detect_loaded_vlm(
    self_base_url: Optional[str] = None,
    *,
    llama_backend: Any = None,
) -> VlmCapability:
    """Identify the active model and whether it can describe images.

    ``self_base_url`` is only consulted when the active model is served
    by the transformers / Unsloth backend; document image captioning must
    loop back through our own ``/v1/chat/completions``. GGUF models return
    llama-server's own URL and ignore this argument.

    Args:
        self_base_url: Loopback origin of this server, if known.
        llama_backend: Optional GGUF backend object to probe instead of the
            process-wide one.

    Returns:
        The first non-``None`` probe result (GGUF first, then
        transformers / Unsloth), or ``VlmCapability.none()``.
    """
    gguf = _probe_gguf(llama_backend)
    if gguf is not None:
        return gguf

    tf = _probe_transformers(self_base_url)
    if tf is not None:
        return tf

    return VlmCapability.none()


# Port assumed when request.base_url carries no explicit port.
_FALLBACK_SELF_PORT = 8888


def _is_usable_port(value: Any) -> bool:
    """Return True when *value* is an integer TCP port in 1..65535."""
    # bool is a subclass of int, so it must be excluded explicitly;
    # otherwise True would be accepted and rendered as port 1.
    return (
        isinstance(value, int)
        and not isinstance(value, bool)
        and 0 < value <= 65535
    )


def extract_self_base_url(request: Any) -> Optional[str]:
    """Derive a trusted local base URL for the active Studio server.

    The request Host header is attacker-controlled in many deployments,
    so the returned origin always uses ``127.0.0.1``. Only the server
    port is discovered, preferring the port published by ``run.py`` and
    then uvicorn's ASGI scope. ``request.base_url`` is a last-resort
    fallback for tests and non-uvicorn embedding.

    Args:
        request: Request-like object; ``app.state.server_port``, ``scope``
            and ``base_url`` are read with ``getattr`` and may be absent.

    Returns:
        ``"http://127.0.0.1:<port>"``, or ``None`` when no port source is
        available (no usable state/scope port and an empty or unparsable
        ``base_url``).
    """
    port: Optional[int] = None

    # 1) Port published on the application state.
    try:
        state = getattr(getattr(request, "app", None), "state", None)
        candidate = getattr(state, "server_port", None)
        if _is_usable_port(candidate):
            port = candidate
    except Exception:
        port = None

    # 2) Port the ASGI server recorded in the connection scope.
    if port is None:
        try:
            server = getattr(request, "scope", {}).get("server")
            if (
                isinstance(server, tuple)
                and len(server) >= 2
                and _is_usable_port(server[1])
            ):
                port = server[1]
        except Exception:
            port = None

    # 3) Port parsed out of request.base_url (host part is discarded).
    if port is None:
        try:
            base = str(getattr(request, "base_url", "") or "")
            if not base:
                return None
            parsed = urlparse(base)
            port = parsed.port if parsed.port is not None else _FALLBACK_SELF_PORT
        except Exception:
            return None

    return f"http://127.0.0.1:{int(port)}"
@@ -182,7 +182,10 @@ def load_checkpoint( # Detect audio type and vision self._audio_type = detect_audio_type(model_id) - self.is_vision = not self._audio_type and is_vision_model(model_id) + self.is_vision = not self._audio_type and is_vision_model( + model_id, + trust_remote_code = trust_remote_code, + ) # Load model based on type if self._audio_type == "csm": diff --git a/studio/backend/core/inference/__init__.py b/studio/backend/core/inference/__init__.py index 35318f6357..12315b706a 100644 --- a/studio/backend/core/inference/__init__.py +++ b/studio/backend/core/inference/__init__.py @@ -10,7 +10,7 @@ """ from .orchestrator import InferenceOrchestrator, get_inference_backend -from .llama_cpp import LlamaCppBackend +from .llama_cpp import LlamaCppBackend, get_llama_cpp_backend # Expose InferenceOrchestrator as InferenceBackend for backward compat InferenceBackend = InferenceOrchestrator @@ -19,5 +19,6 @@ "InferenceBackend", "InferenceOrchestrator", "get_inference_backend", + "get_llama_cpp_backend", "LlamaCppBackend", ] diff --git a/studio/backend/core/inference/llama_cpp.py b/studio/backend/core/inference/llama_cpp.py index 38c0261f5a..cecc919f22 100644 --- a/studio/backend/core/inference/llama_cpp.py +++ b/studio/backend/core/inference/llama_cpp.py @@ -491,6 +491,10 @@ def is_active(self) -> bool: def base_url(self) -> str: return f"http://127.0.0.1:{self._port}" + @property + def api_key(self) -> Optional[str]: + return self._api_key + @property def model_identifier(self) -> Optional[str]: return self._model_identifier @@ -2943,6 +2947,9 @@ def _parse_tool_calls_from_text(content: str) -> list[dict]: def _build_openai_messages( messages: list[dict], image_b64: Optional[str] = None, + image_b64s: Optional[list[str]] = None, + image_mime: Optional[str] = None, + image_mimes: Optional[list[str]] = None, ) -> list[dict]: """ Build OpenAI-format messages, optionally injecting an image_url @@ -2950,8 +2957,20 @@ def _build_openai_messages( If no image is provided, 
returns messages as-is. """ - if not image_b64: + images = ( + image_b64s + if image_b64s is not None + else ([image_b64] if image_b64 else []) + ) + images = [image for image in images if image] + if not images: return messages + if image_b64s is not None: + mimes = image_mimes or ["image/png"] * len(images) + else: + mimes = [image_mime or "image/png"] + if len(mimes) < len(images): + mimes = [*mimes, *(["image/png"] * (len(images) - len(mimes)))] # Find the last user message and convert to multimodal content parts result = [msg.copy() for msg in messages] @@ -2962,14 +2981,18 @@ def _build_openai_messages( if last_user_idx is not None: text_content = result[last_user_idx].get("content", "") - result[last_user_idx]["content"] = [ - {"type": "text", "text": text_content}, + image_parts = [ { "type": "image_url", "image_url": { - "url": f"data:image/png;base64,{image_b64}", + "url": f"data:{mime if mime and '/' in mime else 'image/png'};base64,{image}", }, - }, + } + for image, mime in zip(images, mimes) + ] + result[last_user_idx]["content"] = [ + {"type": "text", "text": text_content}, + *image_parts, ] return result @@ -3101,6 +3124,9 @@ def generate_chat_completion( self, messages: list[dict], image_b64: Optional[str] = None, + image_b64s: Optional[list[str]] = None, + image_mime: Optional[str] = None, + image_mimes: Optional[list[str]] = None, temperature: float = 0.6, top_p: float = 0.95, top_k: int = 20, @@ -3125,7 +3151,13 @@ def generate_chat_completion( if not self.is_loaded: raise RuntimeError("llama-server is not loaded") - openai_messages = self._build_openai_messages(messages, image_b64) + openai_messages = self._build_openai_messages( + messages, + image_b64 = image_b64, + image_b64s = image_b64s, + image_mime = image_mime, + image_mimes = image_mimes, + ) payload = { "messages": openai_messages, @@ -4338,3 +4370,20 @@ def generate_audio_response( return LlamaCppBackend._codec_mgr.decode( audio_type, device, token_ids = token_ids, text = 
def _resolve_trust_remote_code(config: dict) -> bool:
    """Decide whether the model described by *config* loads with trust_remote_code.

    A truthy ``config["trust_remote_code"]`` always wins. Otherwise the flag
    is switched on automatically for NemotronH / Nemotron-3-Nano checkpoints
    only, and only when the model id starts with ``unsloth/`` or ``nvidia/``:

    - NemotronH has config parsing bugs requiring trust_remote_code=True.
    - Other transformers 5.x models are native and do NOT need it.
    - Llama-Nemotron (standard Llama architecture) must NOT match.

    Raises:
        KeyError: ``config`` has no ``"model_name"`` and no truthy
            ``"trust_remote_code"``.
    """
    if config.get("trust_remote_code", False):
        return True

    name = config["model_name"]
    lowered = name.lower()
    nemotron_markers = ("nemotron_h", "nemotron-h", "nemotron-3-nano")
    is_nemotron_h = any(marker in lowered for marker in nemotron_markers)
    from_known_org = lowered.startswith(("unsloth/", "nvidia/"))
    if not (is_nemotron_h and from_known_org):
        return False

    logger.info(
        "Auto-enabled trust_remote_code for Nemotron model: %s",
        name,
    )
    return True
- # Other transformers 5.x models are native and do NOT need it. - # NOTE: Must NOT match Llama-Nemotron (standard Llama architecture). - _NEMOTRON_TRUST_SUBSTRINGS = ("nemotron_h", "nemotron-h", "nemotron-3-nano") - trust_remote_code = config.get("trust_remote_code", False) - if not trust_remote_code: - model_name = config["model_name"] - _mn_lower = model_name.lower() - if any(sub in _mn_lower for sub in _NEMOTRON_TRUST_SUBSTRINGS) and ( - _mn_lower.startswith("unsloth/") or _mn_lower.startswith("nvidia/") - ): - trust_remote_code = True - logger.info( - "Auto-enabled trust_remote_code for Nemotron model: %s", - model_name, - ) - # Send heartbeats every 30s so the orchestrator knows we're still alive # (download / weight loading can take a long time on slow connections) xet_disabled = os.environ.get("HF_HUB_DISABLE_XET") == "1" diff --git a/studio/backend/core/training/trainer.py b/studio/backend/core/training/trainer.py index a3f063694f..635d547606 100644 --- a/studio/backend/core/training/trainer.py +++ b/studio/backend/core/training/trainer.py @@ -199,7 +199,11 @@ def pre_detect_and_load_tokenizer( # --- Detect VLM --- vision = ( - is_vision_model(model_name, hf_token = hf_token) + is_vision_model( + model_name, + hf_token = hf_token, + trust_remote_code = trust_remote_code, + ) if not self.is_audio else False ) @@ -572,7 +576,11 @@ def load_model( # VLM: vision model with image dataset (mutually exclusive with audio paths) vision = ( - is_vision_model(model_name, hf_token = hf_token) + is_vision_model( + model_name, + hf_token = hf_token, + trust_remote_code = trust_remote_code, + ) if not self.is_audio else False ) diff --git a/studio/backend/models/inference.py b/studio/backend/models/inference.py index 7a4c7d0b3c..51eda76c45 100644 --- a/studio/backend/models/inference.py +++ b/studio/backend/models/inference.py @@ -108,6 +108,10 @@ class ValidateModelRequest(BaseModel): gguf_variant: Optional[str] = Field( None, description = "GGUF quantization variant 
# ---------------------------------------------------------------------- #
# Chat document extraction (parsed documents + optional VLM captions)    #
# ---------------------------------------------------------------------- #


class ExtractedFigureModel(BaseModel):
    """A single extracted visual reference, optionally described by a
    locally-loaded vision model.

    Only ``id`` is required. ``caption``/``error`` record the outcome of the
    describe call, and the ``image_*`` fields are populated only when the
    reference carries a visual payload.
    """

    id: str = Field(..., description = "Stable id (e.g. 'fig-0')")
    page: Optional[int] = Field(None, description = "1-based page number, if known")
    caption: Optional[str] = Field(
        None, description = "Short VLM-generated caption, or null if skipped/failed"
    )
    error: Optional[str] = Field(
        None, description = "Reason the describe call failed, if any"
    )
    kind: Literal["figure", "page"] = Field(
        "figure",
        description = "Whether this reference is a detected figure or page image",
    )
    # The four image_* fields travel together: mime/width/height describe
    # image_base64 and are meaningless without it.
    image_mime: Optional[str] = Field(
        None, description = "MIME type for image_base64 when a visual payload is present"
    )
    image_base64: Optional[str] = Field(
        None,
        description = (
            "Base64-encoded visual payload for this reference. The first visual "
            "reference is sent to vision-capable chat models as [Image #1]."
        ),
    )
    image_width: Optional[int] = Field(
        None, ge = 1, description = "Width of image_base64 after resize"
    )
    image_height: Optional[int] = Field(
        None, ge = 1, description = "Height of image_base64 after resize"
    )


class ExtractDocumentResponse(BaseModel):
    """
    Returned synchronously from ``POST /chat/extract-document`` for
    small docs, or as the final SSE event for larger ones.

    Only ``filename`` and ``markdown`` are required; every other field has a
    default so a minimal payload validates.
    """

    schema_version: int = Field(1, description = "Document extraction payload schema version")
    filename: str = Field(..., description = "Original filename uploaded")
    markdown: str = Field(
        ..., description = "Layout-aware Markdown extracted from the document"
    )
    page_count: int = Field(0, ge = 0, description = "Number of pages in the source")
    tokens_est: int = Field(
        0, ge = 0, description = "Rough char/4 token estimate for the markdown"
    )
    truncated: bool = Field(
        False,
        description = "Whether markdown was clipped to the requested token budget",
    )
    figures: List[ExtractedFigureModel] = Field(
        default_factory = list,
        description = "Figures discovered in the document (captions optional)",
    )
    describe_skipped_reason: Optional[str] = Field(
        None,
        description = (
            "If image description was requested but skipped, the reason "
            "(e.g. 'loaded GGUF is not vision-capable'). Mirrors the "
            "``reason`` surfaced by /chat/document-support."
        ),
    )
    # NOTE(review): the extractor appears to fill vlm_source / vlm_model from
    # the capability probe of the active model, whether or not any caption
    # was actually produced — confirm the descriptions below match that.
    vlm_source: Optional[str] = Field(
        None,
        description = (
            "Which inference backend served the describe calls: 'gguf', "
            "'transformers', 'unsloth', or 'none' when no VLM was used."
        ),
    )
    vlm_model: Optional[str] = Field(
        None,
        description = "Identifier of the VLM whose captions appear in this document",
    )
    image_input_available: bool = Field(
        False,
        description = (
            "Whether the active model can receive an extracted visual payload "
            "alongside the markdown."
        ),
    )
    warnings: List[str] = Field(
        default_factory = list,
        description = "Non-fatal warnings surfaced to the UI",
    )


class VlmCapabilityModel(BaseModel):
    """Runtime probe result for the currently-loaded model.

    ``is_vlm`` and ``source`` are required; the remaining fields default to
    ``None``.
    """

    is_vlm: bool = Field(..., description = "Whether the active model accepts image inputs")
    endpoint_url: Optional[str] = Field(
        None,
        description = "Root URL serving /v1/chat/completions for the active model",
    )
    model_name: Optional[str] = Field(
        None, description = "Identifier of the active model, if any is loaded"
    )
    source: Literal["gguf", "transformers", "unsloth", "none"] = Field(
        ..., description = "Which backend currently owns the active model"
    )
    reason: Optional[str] = Field(
        None,
        description = "Populated when is_vlm is false; explains why the UI toggle is disabled",
    )


class DocumentSupportResponse(BaseModel):
    """Returned by GET /chat/document-support.

    Drives the Chat settings-card toggles. ``max_visual_payloads`` is kept
    for older clients as an informational hint, not a hard request cap.
    """

    schema_version: int = Field(1, description = "Document support payload schema version")
    extraction_available: bool = Field(
        ...,
        description = (
            "Whether the document extraction backend successfully imported "
            "on the server"
        ),
    )
    max_visual_payloads: int = Field(
        ...,
        ge = 0,
        description = "Legacy visual-payload hint; not a hard request cap",
    )
    # Both maps are keyed per format; unavailable_formats carries the reason
    # text for formats whose parser is not available.
    format_support: Dict[str, bool] = Field(
        default_factory = dict,
        description = "Per-format parser availability for document extraction",
    )
    unavailable_formats: Dict[str, str] = Field(
        default_factory = dict,
        description = "Per-format parser unavailability reasons",
    )
    # Required: the capability probe result for the active model.
    vlm: VlmCapabilityModel
+onnxruntime>=1.19 diff --git a/studio/backend/routes/inference.py b/studio/backend/routes/inference.py index 6b559b9c45..b24772cf6c 100644 --- a/studio/backend/routes/inference.py +++ b/studio/backend/routes/inference.py @@ -9,6 +9,7 @@ import sys import time import uuid +from contextlib import suppress from pathlib import Path from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.responses import StreamingResponse, JSONResponse, Response @@ -118,6 +119,7 @@ def _friendly_error(exc: Exception) -> str: _DEFAULT_MAX_TOKENS_FLOOR, _DEFAULT_T_MAX_PREDICT_MS, detect_reasoning_flags, + get_llama_cpp_backend, ) from core.inference.llama_server_args import validate_extra_args from utils.models import ModelConfig @@ -140,6 +142,7 @@ def _friendly_error(exc: Exception) -> str: _DEFAULT_MAX_TOKENS_FLOOR, _DEFAULT_T_MAX_PREDICT_MS, detect_reasoning_flags, + get_llama_cpp_backend, ) from core.inference.llama_server_args import validate_extra_args from utils.models import ModelConfig @@ -194,7 +197,11 @@ def _friendly_error(exc: Exception) -> str: AnthropicResponseTextBlock, AnthropicResponseToolUseBlock, AnthropicUsage, + DocumentSupportResponse, + ExtractDocumentResponse, + ExtractedFigureModel, ) +from dataclasses import asdict as _asdict from core.inference.anthropic_compat import ( anthropic_messages_to_openai, anthropic_tools_to_openai, @@ -343,7 +350,6 @@ async def _await_cancel_then_close(cancel_event, resp) -> None: except asyncio.CancelledError: return - # Appended to tool-use nudge to discourage plan-without-action _TOOL_ACTION_NUDGE = ( " IMPORTANT: Always call tools directly -- never write code yourself." 
@@ -506,6 +512,8 @@ async def load_model( reasoning_style = llama_backend.reasoning_style, reasoning_always_on = llama_backend.reasoning_always_on, supports_preserve_thinking = llama_backend.supports_preserve_thinking, + supports_tools = llama_backend.supports_tools, + cache_type_kv = llama_backend.cache_type_kv, chat_template = llama_backend.chat_template, speculative_type = llama_backend.speculative_type, ) @@ -568,12 +576,32 @@ async def load_model( chat_template = _chat_template, ) + model_defaults = load_model_defaults(request.model_path) + defaults_require_trust_remote_code = bool( + model_defaults.get("model", {}).get("trust_remote_code", False) + or model_defaults.get("inference", {}).get("trust_remote_code", False) + ) + if defaults_require_trust_remote_code and not request.trust_remote_code: + display_name = ( + model_defaults.get("model", {}).get("display_name") + or request.model_path.split("/")[-1] + or request.model_path + ) + raise HTTPException( + status_code = 400, + detail = ( + f"Model '{display_name}' requires trust_remote_code to be enabled. " + "Please enable 'Trust remote code' in Chat Settings and try again." 
+ ), + ) + # Create config using clean factory method # is_lora is auto-detected from adapter_config.json on disk/HF config = ModelConfig.from_identifier( model_id = model_identifier, hf_token = request.hf_token, gguf_variant = request.gguf_variant, + trust_remote_code = request.trust_remote_code, ) if not config: @@ -918,10 +946,40 @@ async def validate_model( model_identifier, model_log_label, native_grant_backed = ( _resolve_model_identifier_for_request(request, operation = "validate-model") ) + if not native_grant_backed: + model_defaults = load_model_defaults(request.model_path) + default_model_config = model_defaults.get("model", {}) + default_inference_config = model_defaults.get("inference", {}) + defaults_require_trust_remote_code = bool( + default_model_config.get("trust_remote_code", False) + or default_inference_config.get("trust_remote_code", False) + ) + if defaults_require_trust_remote_code and not request.trust_remote_code: + display_name = ( + default_model_config.get("display_name") + or request.model_path.split("/")[-1] + or request.model_path + ) + return ValidateModelResponse( + valid = True, + message = ( + "Model identifier is valid, but this model requires " + "trust_remote_code before probing or loading." 
+ ), + identifier = request.model_path, + display_name = display_name, + is_gguf = False, + is_lora = False, + is_vision = bool(default_model_config.get("is_vision", False)), + requires_trust_remote_code = True, + ) + + config = ModelConfig.from_identifier( model_id = model_identifier, hf_token = request.hf_token, gguf_variant = request.gguf_variant, + trust_remote_code = request.trust_remote_code, ) if not config: @@ -1056,6 +1114,7 @@ async def cancel_inference( @router.post("/generate/stream") async def generate_stream( + fastapi_request: Request, request: GenerateRequest, current_subject: str = Depends(get_current_subject), ): @@ -1098,9 +1157,21 @@ async def generate_stream( status_code = 400, detail = f"Failed to decode image: {str(e)}" ) + cancel_event = threading.Event() + completion_id = f"legacy-{uuid.uuid4().hex[:12]}" + _tracker = _TrackedCancel( + cancel_event, + request.cancel_id, + request.session_id, + completion_id, + ) + _tracker.__enter__() + async def stream(): + _DONE = object() try: - for chunk in backend.generate_chat_response( + yield f"data: {json.dumps({'completion_id': completion_id})}\n\n" + gen = backend.generate_chat_response( messages = request.messages, system_prompt = request.system_prompt, image = image, @@ -1109,7 +1180,19 @@ async def stream(): top_k = request.top_k, max_new_tokens = request.max_new_tokens, repetition_penalty = request.repetition_penalty, - ): + cancel_event = cancel_event, + ) + while True: + if cancel_event.is_set(): + backend.reset_generation_state() + break + if await fastapi_request.is_disconnected(): + cancel_event.set() + backend.reset_generation_state() + return + chunk = await asyncio.to_thread(next, gen, _DONE) + if chunk is _DONE: + break yield f"data: {json.dumps({'content': chunk})}\n\n" yield "data: [DONE]\n\n" @@ -1117,6 +1200,9 @@ async def stream(): backend.reset_generation_state() logger.error(f"Error during generation: {e}", exc_info = True) yield f"data: {json.dumps({'error': 
_friendly_error(e)})}\n\n" + finally: + cancel_event.set() + _tracker.__exit__(None, None, None) return StreamingResponse( stream(), @@ -1409,9 +1495,123 @@ def _decode_audio_base64(b64: str) -> np.ndarray: return waveform.squeeze(0).numpy() +_OPENAI_CHAT_MAX_IMAGES = 256 +_OPENAI_CHAT_MAX_IMAGE_BYTES = 20 * 1024 * 1024 +_OPENAI_CHAT_MAX_IMAGE_PIXELS = 40_000_000 +_OPENAI_CHAT_MAX_IMAGE_BASE64_CHARS = ((_OPENAI_CHAT_MAX_IMAGE_BYTES + 2) // 3) * 4 + 1024 + + +def _convert_openai_image_b64_to_png_b64(image_b64: str) -> str: + if len(image_b64) > _OPENAI_CHAT_MAX_IMAGE_BASE64_CHARS: + raise HTTPException( + status_code = 413, + detail = "Image payload exceeds the 20 MB decoded-image limit.", + ) + + try: + import base64 as _b64 + from io import BytesIO as _BytesIO + from PIL import Image as _Image + + raw = _b64.b64decode(image_b64, validate = True) + if len(raw) > _OPENAI_CHAT_MAX_IMAGE_BYTES: + raise HTTPException( + status_code = 413, + detail = "Image payload exceeds the 20 MB decoded-image limit.", + ) + with _Image.open(_BytesIO(raw)) as img: + width, height = img.size + if width * height > _OPENAI_CHAT_MAX_IMAGE_PIXELS: + raise HTTPException( + status_code = 413, + detail = "Image dimensions exceed the 40 MP limit.", + ) + converted = img.convert("RGB") + buf = _BytesIO() + converted.save(buf, format = "PNG") + png = buf.getvalue() + if len(png) > _OPENAI_CHAT_MAX_IMAGE_BYTES: + raise HTTPException( + status_code = 413, + detail = "Converted image payload exceeds the 20 MB limit.", + ) + return _b64.b64encode(png).decode("ascii") + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code = 400, detail = f"Failed to process image: {e}" + ) from e + + +def _data_url_base64_payload(url: str) -> str: + try: + header, b64data = url.split(",", 1) + except ValueError as exc: + raise HTTPException( + status_code = 400, detail = "Image data URL is missing base64 payload." 
+ ) from exc + if ";base64" not in header.lower(): + raise HTTPException( + status_code = 400, detail = "Image data URL must be base64 encoded." + ) + return b64data + + +def _normalize_openai_message_images( + openai_messages: list[dict], + *, + is_vision: bool, + not_vision_detail: str, +) -> bool: + """Apply image count/size/pixel guards and normalize data URLs to PNG.""" + has_image = False + image_count = 0 + + for msg in openai_messages: + content = msg.get("content") + if not isinstance(content, list): + continue + for part in content: + if not isinstance(part, dict) or part.get("type") != "image_url": + continue + + has_image = True + image_count += 1 + if image_count > _OPENAI_CHAT_MAX_IMAGES: + raise HTTPException( + status_code = 413, + detail = f"Too many images provided; maximum is {_OPENAI_CHAT_MAX_IMAGES}.", + ) + if not is_vision: + raise HTTPException(status_code = 400, detail = not_vision_detail) + + image_url = part.get("image_url") or {} + if not isinstance(image_url, dict): + raise HTTPException( + status_code = 400, detail = "Invalid image_url content part." + ) + url = image_url.get("url", "") + if not isinstance(url, str): + raise HTTPException( + status_code = 400, detail = "Invalid image_url URL." + ) + if not url.startswith("data:"): + # Remote URLs are counted but cannot be byte/pixel checked here. + continue + + b64data = _data_url_base64_payload(url) + png_b64 = _convert_openai_image_b64_to_png_b64(b64data) + normalized = dict(image_url) + normalized["url"] = f"data:image/png;base64,{png_b64}" + part["image_url"] = normalized + + return has_image + + def _extract_content_parts( messages: list, -) -> tuple[str, list[dict], "Optional[str]"]: +) -> tuple[str, list[dict], list[str]]: """ Parse OpenAI-format messages into components the inference backend expects. @@ -1421,11 +1621,11 @@ def _extract_content_parts( Returns: system_prompt: The system message text (empty string if none provided). 
chat_messages: Non-system messages with content flattened to strings. - image_base64: Base64 data of the *first* image found, or ``None``. + image_base64s: Base64 data for image parts, in request order. """ system_prompt = "" chat_messages: list[dict] = [] - first_image_b64: Optional[str] = None + image_b64s: list[str] = [] for msg in messages: # ── System messages → extract as system_prompt ──────── @@ -1449,11 +1649,12 @@ def _extract_content_parts( for part in msg.content: if part.type == "text": text_parts.append(part.text) - elif part.type == "image_url" and first_image_b64 is None: + elif part.type == "image_url": url = part.image_url.url if url.startswith("data:"): # data:image/png;base64, → extract - first_image_b64 = url.split(",", 1)[1] if "," in url else None + if "," in url: + image_b64s.append(url.split(",", 1)[1]) else: logger.warning( f"Remote image URLs not yet supported: {url[:80]}..." @@ -1461,7 +1662,7 @@ def _extract_content_parts( combined_text = "\n".join(text_parts) if text_parts else "" chat_messages.append({"role": msg.role, "content": combined_text}) - return system_prompt, chat_messages, first_image_b64 + return system_prompt, chat_messages, image_b64s @router.post("/chat/completions") @@ -1712,7 +1913,7 @@ async def audio_input_stream(): ) # ── Parse messages (handles multimodal content parts) ───── - system_prompt, chat_messages, extracted_image_b64 = _extract_content_parts( + system_prompt, chat_messages, extracted_image_b64s = _extract_content_parts( payload.messages ) @@ -1731,33 +1932,26 @@ async def audio_input_stream(): ) # Reject images if this GGUF model doesn't support vision - image_b64 = extracted_image_b64 or payload.image_base64 - if image_b64 and not llama_backend.is_vision: + image_b64s = list(extracted_image_b64s) + if payload.image_base64: + image_b64s.append(payload.image_base64) + if image_b64s and not llama_backend.is_vision: raise HTTPException( status_code = 400, detail = "Image provided but current GGUF model does 
not support vision.", ) + if len(image_b64s) > _OPENAI_CHAT_MAX_IMAGES: + raise HTTPException( + status_code = 413, + detail = f"Too many images provided; maximum is {_OPENAI_CHAT_MAX_IMAGES}.", + ) # Convert image to PNG for llama-server (stb_image has limited format support) - if image_b64: - try: - import base64 as _b64 - from io import BytesIO as _BytesIO - from PIL import Image as _Image - - raw = _b64.b64decode(image_b64) - # Normalize to RGB so PNG encoding succeeds regardless of - # source mode (RGBA, P, L, CMYK, I, F, ...). Previously - # we only converted RGBA, which left CMYK/I/F to raise at - # img.save(PNG). - img = _Image.open(_BytesIO(raw)).convert("RGB") - buf = _BytesIO() - img.save(buf, format = "PNG") - image_b64 = _b64.b64encode(buf.getvalue()).decode("ascii") - except Exception as e: - raise HTTPException( - status_code = 400, detail = f"Failed to process image: {e}" - ) + if image_b64s: + image_b64s = [ + _convert_openai_image_b64_to_png_b64(image_b64) + for image_b64 in image_b64s + ] # Build message list with system prompt prepended gguf_messages = [] @@ -1777,7 +1971,7 @@ async def audio_input_stream(): use_tools = ( _effective_enable_tools(payload) and llama_backend.supports_tools - and not image_b64 + and not image_b64s ) if use_tools: @@ -2045,7 +2239,7 @@ async def gguf_tool_stream(): def gguf_generate(): return llama_backend.generate_chat_completion( messages = gguf_messages, - image_b64 = image_b64, + image_b64s = image_b64s, temperature = payload.temperature, top_p = payload.top_p, top_k = payload.top_k, @@ -2214,7 +2408,11 @@ async def gguf_stream_chunks(): # ── Standard Unsloth path ───────────────────────────────── # Decode image (from content parts OR legacy field) - image_b64 = extracted_image_b64 or payload.image_base64 + image_b64 = ( + extracted_image_b64s[0] + if extracted_image_b64s + else payload.image_base64 + ) image = None if image_b64: @@ -2464,9 +2662,9 @@ async def serve_sandbox_file( # ── Path containment check 
────────────────────────────────── home = os.path.expanduser("~") sandbox_root = os.path.realpath(os.path.join(home, "studio_sandbox")) - safe_session = os.path.basename(session_id.replace("..", "")) - if not safe_session: + if not _re.fullmatch(r"[A-Za-z0-9_-]+", session_id or ""): raise HTTPException(status_code = 404, detail = "Not found") + safe_session = session_id file_path = os.path.realpath( os.path.join(sandbox_root, safe_session, safe_filename) @@ -2555,7 +2753,9 @@ async def openai_completions( detail = "No GGUF model loaded. Load a GGUF model first.", ) - body = await request.json() + body = await _read_json_body_limited( + request, max_bytes = _OPENAI_PROXY_BODY_MAX_BYTES + ) target_url = f"{llama_backend.base_url}/v1/completions" is_stream = body.get("stream", False) @@ -2634,7 +2834,9 @@ async def openai_embeddings( detail = "No GGUF model loaded. Load a GGUF model first.", ) - body = await request.json() + body = await _read_json_body_limited( + request, max_bytes = _OPENAI_PROXY_BODY_MAX_BYTES + ) target_url = f"{llama_backend.base_url}/v1/embeddings" async with httpx.AsyncClient() as client: @@ -3395,45 +3597,11 @@ def _normalize_anthropic_openai_images( HTTPException(400) when images are present but the active model is not a vision model, or when an image cannot be decoded. """ - from PIL import Image - - has_image = False - for msg in openai_messages: - content = msg.get("content") - if not isinstance(content, list): - continue - for part in content: - if part.get("type") != "image_url": - continue - - has_image = True - if not is_vision: - raise HTTPException( - status_code = 400, - detail = "Image provided but current GGUF model does not support vision.", - ) - - url = (part.get("image_url") or {}).get("url", "") - if not url.startswith("data:"): - # Remote URLs are forwarded as-is; llama-server will - # fetch (or fail) per its own support matrix. 
- continue - - try: - _, b64data = url.split(",", 1) - raw = base64.b64decode(b64data) - img = Image.open(io.BytesIO(raw)).convert("RGB") - buf = io.BytesIO() - img.save(buf, format = "PNG") - png_b64 = base64.b64encode(buf.getvalue()).decode("ascii") - except Exception as e: - raise HTTPException( - status_code = 400, - detail = f"Failed to process image: {e}", - ) - part["image_url"] = {"url": f"data:image/png;base64,{png_b64}"} - - return has_image + return _normalize_openai_message_images( + openai_messages, + is_vision = is_vision, + not_vision_detail = "Image provided but current GGUF model does not support vision.", + ) @router.post("/messages") @@ -4190,7 +4358,7 @@ async def _anthropic_passthrough_non_streaming( # ===================================================================== -def _openai_messages_for_passthrough(payload) -> list[dict]: +def _openai_messages_for_passthrough(payload, *, is_vision: bool = True) -> list[dict]: """Build OpenAI-format message dicts for the /v1/chat/completions passthrough path. @@ -4198,7 +4366,7 @@ def _openai_messages_for_passthrough(payload) -> list[dict]: unset optional fields) so they are already in standard OpenAI format — including ``role="tool"`` tool-result messages and assistant messages that carry structured ``tool_calls``. Content-parts images already in - the message list are left untouched. + the message list are counted, bounded, and data URLs are normalized to PNG. 
When a client uses Studio's legacy ``image_base64`` top-level field, the image is re-encoded to PNG (llama-server's stb_image has limited format @@ -4208,41 +4376,29 @@ def _openai_messages_for_passthrough(payload) -> list[dict]: """ messages = [m.model_dump(exclude_none = True) for m in payload.messages] - if not payload.image_base64: - return messages + if payload.image_base64: + data_url = f"data:image/unknown;base64,{payload.image_base64}" + image_part = {"type": "image_url", "image_url": {"url": data_url}} - try: - import base64 as _b64 - from io import BytesIO as _BytesIO - from PIL import Image as _Image - - raw = _b64.b64decode(payload.image_base64) - img = _Image.open(_BytesIO(raw)).convert("RGB") - buf = _BytesIO() - img.save(buf, format = "PNG") - png_b64 = _b64.b64encode(buf.getvalue()).decode("ascii") - except Exception as e: - raise HTTPException( - status_code = 400, - detail = f"Failed to process image: {e}", - ) - - data_url = f"data:image/png;base64,{png_b64}" - image_part = {"type": "image_url", "image_url": {"url": data_url}} - - for msg in reversed(messages): - if msg.get("role") != "user": - continue - existing = msg.get("content") - if isinstance(existing, str): - msg["content"] = [{"type": "text", "text": existing}, image_part] - elif isinstance(existing, list): - existing.append(image_part) + for msg in reversed(messages): + if msg.get("role") != "user": + continue + existing = msg.get("content") + if isinstance(existing, str): + msg["content"] = [{"type": "text", "text": existing}, image_part] + elif isinstance(existing, list): + existing.append(image_part) + else: + msg["content"] = [image_part] + break else: - msg["content"] = [image_part] - break - else: - messages.append({"role": "user", "content": [image_part]}) + messages.append({"role": "user", "content": [image_part]}) + + _normalize_openai_message_images( + messages, + is_vision = is_vision, + not_vision_detail = "Image provided but current GGUF model does not support vision.", + 
) return messages @@ -4261,14 +4417,16 @@ def _extract_response_format(payload): return rf if isinstance(rf, dict) else None -def _build_openai_passthrough_body(payload, backend_ctx = None) -> dict: +def _build_openai_passthrough_body( + payload, backend_ctx = None, *, is_vision: bool = True +) -> dict: """Assemble the llama-server request body from a ChatCompletionRequest. Only explicitly-known OpenAI / llama-server fields are forwarded so that Studio-specific extensions (``enable_tools``, ``enabled_tools``, ``session_id``, ...) never leak to the backend. """ - messages = _openai_messages_for_passthrough(payload) + messages = _openai_messages_for_passthrough(payload, is_vision = is_vision) tool_choice = payload.tool_choice if payload.tool_choice is not None else "auto" # When the caller asked for a specific reasoning mode, forward it to # llama-server via chat_template_kwargs so the Jinja template renders @@ -4313,7 +4471,9 @@ async def _openai_passthrough_stream( """ target_url = f"{llama_backend.base_url}/v1/chat/completions" body = _build_openai_passthrough_body( - payload, backend_ctx = llama_backend.context_length + payload, + backend_ctx = llama_backend.context_length, + is_vision = llama_backend.is_vision, ) _cancel_keys = (payload.cancel_id, payload.session_id, completion_id) @@ -4471,7 +4631,9 @@ async def _openai_passthrough_non_streaming( """ target_url = f"{llama_backend.base_url}/v1/chat/completions" body = _build_openai_passthrough_body( - payload, backend_ctx = llama_backend.context_length + payload, + backend_ctx = llama_backend.context_length, + is_vision = llama_backend.is_vision, ) try: @@ -4533,3 +4695,665 @@ async def _openai_passthrough_non_streaming( # verbatim (matches the docstring). Status is guaranteed 200 by # the check above. 
return Response(content = resp.content, media_type = "application/json") + + +# ---------------------------------------------------------------------- # +# Chat document extraction (PyMuPDF4LLM + optional VLM image description)# +# ---------------------------------------------------------------------- # + +try: + from core.chat import ( + DOCUMENT_EXTRACTION_AVAILABLE as _DOCUMENT_EXTRACTION_AVAILABLE, + DEFAULT_DOCUMENT_VISUAL_PAYLOADS as _DEFAULT_DOCUMENT_VISUAL_PAYLOADS, + DocumentExtractionBusy as _DocumentExtractionBusy, + DocumentExtractionCancelled as _DocumentExtractionCancelled, + DocumentExtractionEncrypted as _DocumentExtractionEncrypted, + DocumentExtractionTimeout as _DocumentExtractionTimeout, + DocumentExtractionUnavailable as _DocumentExtractionUnavailable, + MAX_DOCUMENT_VISUAL_PAYLOADS as _MAX_DOCUMENT_VISUAL_PAYLOADS, + SUPPORTED_MIME_TYPES as _DOC_MIME_OK, + SUPPORTED_SUFFIXES as _DOC_SUFFIX_OK, + VlmCapability as _VlmCapability, + _EXTRACT_SEMAPHORE, + detect_loaded_vlm as _detect_loaded_vlm, + document_parser_support as _document_parser_support, + document_parser_unavailable_reasons as _document_parser_unavailable_reasons, + extract_document as _extract_document, + extract_self_base_url as _extract_self_base_url, + ) +except ImportError: # pragma: no cover - package always installed alongside + _DOCUMENT_EXTRACTION_AVAILABLE = False + _DEFAULT_DOCUMENT_VISUAL_PAYLOADS = 0 + _MAX_DOCUMENT_VISUAL_PAYLOADS = 0 + _DOC_MIME_OK = frozenset() + _DOC_SUFFIX_OK = frozenset() + _detect_loaded_vlm = None # type: ignore[assignment] + _extract_document = None # type: ignore[assignment] + _extract_self_base_url = None # type: ignore[assignment] + _document_parser_support = lambda: {} # type: ignore[assignment] + _document_parser_unavailable_reasons = lambda: {} # type: ignore[assignment] + _VlmCapability = None # type: ignore[assignment] + + class _DocumentExtractionUnavailable(RuntimeError): # type: ignore[no-redef] + pass + + class 
_DocumentExtractionTimeout(RuntimeError): # type: ignore[no-redef] + pass + + class _DocumentExtractionBusy(RuntimeError): # type: ignore[no-redef] + pass + + class _DocumentExtractionCancelled(RuntimeError): # type: ignore[no-redef] + pass + + class _DocumentExtractionEncrypted(RuntimeError): # type: ignore[no-redef] + pass + + _EXTRACT_SEMAPHORE = threading.BoundedSemaphore(1) + + +_EXTRACT_MAX_BYTES = 100 * 1024 * 1024 +_EXTRACT_MULTIPART_OVERHEAD_BYTES = 1024 * 1024 +_EXTRACT_READ_CHUNK_BYTES = 64 * 1024 +_EXTRACT_MAX_PAGES_INLINE = 200 +_EXTRACT_TOKEN_BUDGET_DEFAULT = 8000 +_EXTRACT_TOKEN_BUDGET_MIN = 0 + +_DOCX_MIME = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" +_HTML_MIME_TYPES = {"text/html"} +_DATA_MIME_TYPES = { + "application/json", + "application/x-ndjson", + "application/xml", + "application/yaml", + "text/csv", + "text/xml", + "text/yaml", +} +_CODE_MIME_TYPES = { + "application/javascript", + "text/css", + "text/javascript", +} +_DATA_SUFFIXES = {".csv", ".json", ".jsonl", ".yaml", ".yml", ".xml"} +_CODE_SUFFIXES = { + ".py", ".js", ".jsx", ".ts", ".tsx", ".go", ".rs", ".java", + ".c", ".cpp", ".h", ".hpp", ".cs", ".php", ".rb", ".swift", + ".kt", ".kts", ".scala", ".sh", ".bash", ".zsh", ".ps1", + ".sql", ".toml", ".ini", ".cfg", ".css", ".scss", +} + + +async def _wait_for_document_request_disconnect( + fastapi_request: Request, + cancel_event: threading.Event, +) -> bool: + while not cancel_event.is_set(): + if await fastapi_request.is_disconnected(): + cancel_event.set() + return True + await asyncio.sleep(0.2) + return False + + +def _extract_ext(filename: str) -> str: + return os.path.splitext(filename or "")[1].lower() + + +def _is_supported_upload(filename: str, content_type: str) -> bool: + if (content_type or "").split(";")[0].strip().lower() in _DOC_MIME_OK: + return True + return _extract_ext(filename) in _DOC_SUFFIX_OK + + +def _document_upload_format(filename: str, content_type: str) -> Optional[str]: + 
mime = (content_type or "").split(";")[0].strip().lower() + ext = _extract_ext(filename) + if mime == "application/pdf" or ext == ".pdf": + return "pdf" + if mime == _DOCX_MIME or ext == ".docx": + return "docx" + if mime in _HTML_MIME_TYPES or ext in {".html", ".htm"}: + return "html" + if mime in _DATA_MIME_TYPES or ext in _DATA_SUFFIXES: + return "data" + if mime in _CODE_MIME_TYPES or ext in _CODE_SUFFIXES: + return "code" + if mime.startswith("text/") or ext in {".md", ".txt", ".log"}: + return "text" + return None + + +def _raise_if_document_parser_unavailable( + filename: str, + content_type: str, +) -> None: + format_key = _document_upload_format(filename, content_type) + if format_key is None: + return + support = _document_parser_support() + if support.get(format_key, True): + return + reason = _document_parser_unavailable_reasons().get( + format_key, + f"{format_key.upper()} extraction is not available on this server.", + ) + raise HTTPException(status_code = 501, detail = reason) + + +def _document_caption_authorization_header( + capability: Any, + llama_backend: Any, + studio_authorization_header: Optional[str], +) -> Optional[str]: + if getattr(capability, "source", None) != "gguf": + return studio_authorization_header + api_key = getattr(llama_backend, "api_key", None) or getattr( + llama_backend, "_api_key", None + ) + return f"Bearer {api_key}" if api_key else None + + +_FORM_TRUE = {"1", "true", "yes", "on"} +_FORM_FALSE = {"0", "false", "no", "off"} + + +def _parse_bool_form(value: Any, *, default: bool, field: str = "value") -> bool: + if value is None: + return default + norm = str(value).strip().lower() + if not norm: + return default + if norm in _FORM_TRUE: + return True + if norm in _FORM_FALSE: + return False + raise HTTPException( + status_code = 400, + detail = f"Invalid boolean value for {field}: {value!r}", + ) + + +def _parse_int_form( + value: Any, + *, + default: int, + lo: int, + hi: Optional[int] = None, +) -> int: + try: + parsed 
= int(value) if value is not None else default + except (TypeError, ValueError): + parsed = default + parsed = max(lo, parsed) + return min(parsed, hi) if hi is not None else parsed + + +def _reject_oversized_content_length(request: Request) -> None: + raw = request.headers.get("content-length") + if raw is None: + return + try: + total = int(raw) + except ValueError: + raise HTTPException( + status_code = 400, + detail = "Invalid Content-Length header", + ) + max_request_bytes = _EXTRACT_MAX_BYTES + _EXTRACT_MULTIPART_OVERHEAD_BYTES + if total > max_request_bytes: + raise HTTPException( + status_code = 413, + detail = ( + f"Request exceeds the {_EXTRACT_MAX_BYTES // (1024*1024)} MB " + "file limit" + ), + ) + + +async def _iter_request_body_limited(request: Request, *, max_bytes: int): + total = 0 + async for chunk in request.stream(): + if not chunk: + continue + total += len(chunk) + if total > max_bytes: + raise HTTPException( + status_code = 413, + detail = ( + f"Request exceeds the {_EXTRACT_MAX_BYTES // (1024*1024)} MB " + "file limit" + ), + ) + yield chunk + + +async def _read_multipart_form_limited(request: Request, *, max_bytes: int): + from starlette.formparsers import MultiPartException, MultiPartParser + + try: + parser = MultiPartParser( + request.headers, + _iter_request_body_limited(request, max_bytes = max_bytes), + ) + return await parser.parse() + except HTTPException: + raise + except MultiPartException as exc: + raise HTTPException(status_code = 400, detail = exc.message) from exc + + +# Cap on /completions and /embeddings JSON bodies. The OpenAI-compatible +# payload should be small (a few prompts + sampling params); 10 MB is generous +# headroom while still protecting against unbounded buffering when a client +# sends a falsified Content-Length and streams a much larger body. 
+_OPENAI_PROXY_BODY_MAX_BYTES = 10 * 1024 * 1024 + + +async def _read_json_body_limited(request: Request, *, max_bytes: int) -> Any: + """Stream the request body, enforce a hard byte cap, then parse as JSON. + + Unlike trusting Content-Length, this aborts mid-stream once the cap is + exceeded so a spoofed header cannot force the server to buffer arbitrary + payloads before parsing. + """ + total = 0 + chunks: list[bytes] = [] + async for chunk in request.stream(): + if not chunk: + continue + total += len(chunk) + if total > max_bytes: + raise HTTPException( + status_code = 413, + detail = f"Request body exceeds the {max_bytes // (1024 * 1024)} MB limit", + ) + chunks.append(chunk) + raw = b"".join(chunks) + try: + return json.loads(raw) if raw else {} + except json.JSONDecodeError as exc: + raise HTTPException( + status_code = 400, detail = f"Invalid JSON body: {exc.msg}" + ) + + +async def _read_upload_limited(upload: Any, *, max_bytes: int) -> bytes: + buf = bytearray() + while True: + chunk = await upload.read(_EXTRACT_READ_CHUNK_BYTES) + if not chunk: + break + buf.extend(chunk) + if len(buf) > max_bytes: + raise HTTPException( + status_code = 413, + detail = f"File exceeds the {max_bytes // (1024*1024)} MB limit", + ) + return bytes(buf) + + +def _is_pdf_upload(filename: str, content_type: str) -> bool: + mime = (content_type or "").split(";")[0].strip().lower() + return mime == "application/pdf" or _extract_ext(filename) == ".pdf" + + +def _preflight_pdf_page_count( + file_bytes: bytes, + filename: str, + content_type: str, +) -> Optional[int]: + if not _is_pdf_upload(filename, content_type): + return None + + pypdf_error: Optional[BaseException] = None + try: + from pypdf import PdfReader + + reader = PdfReader(io.BytesIO(file_bytes), strict = False) + if getattr(reader, "is_encrypted", False): + raise HTTPException( + status_code = 422, + detail = "Encrypted PDFs are not supported for inline extraction", + ) + return len(reader.pages) + except 
HTTPException: + raise + except Exception as exc: + pypdf_error = exc + logger.warning( + "pypdf page-count preflight failed for %s; trying PyMuPDF fallback", + filename, + ) + + try: + import pymupdf as _pymupdf # type: ignore + + doc = _pymupdf.open(stream = file_bytes, filetype = "pdf") + try: + if getattr(doc, "is_encrypted", False) or getattr(doc, "needs_pass", False): + raise HTTPException( + status_code = 422, + detail = "Encrypted PDFs are not supported for inline extraction", + ) + return len(doc) + finally: + doc.close() + except HTTPException: + raise + except Exception as exc: + if pypdf_error is not None: + logger.warning( + "PyMuPDF page-count fallback also failed for %s: %s", + filename, + exc, + ) + else: + logger.exception("PDF page-count preflight failed for %s", filename) + raise HTTPException( + status_code = 400, + detail = "Unable to read PDF page count before extraction", + ) from exc + + +def _truncate_markdown_to_token_budget( + markdown: str, + *, + token_budget: int, + original_tokens_est: int, +) -> tuple[str, int, Optional[str]]: + char_budget = max(_EXTRACT_TOKEN_BUDGET_MIN, token_budget) * 4 + if len(markdown) <= char_budget: + return markdown, original_tokens_est, None + + clipped = markdown[:char_budget] + clipped = ( + _re.sub(r"\s+\S*$", "", clipped).rstrip() + or markdown[:char_budget].rstrip() + ) + clipped += f"\n\n[... truncated; original was ~{original_tokens_est} tokens ...]" + warning = ( + f"Extracted markdown was truncated to {token_budget} tokens " + f"(original was ~{original_tokens_est} tokens)." + ) + return clipped, max(0, len(clipped) // 4), warning + + +@studio_router.get("/chat/document-support", response_model = DocumentSupportResponse) +async def document_support_endpoint( + fastapi_request: Request, + current_subject: str = Depends(get_current_subject), +): + """Whether document extraction + per-figure captions are available. 
+ + Polled by the frontend when the settings panel mounts and when the + loaded model changes. The response drives the "describe figures" + toggle: when ``vlm.is_vlm`` is false the UI disables the toggle and + surfaces ``vlm.reason`` as tooltip text. + """ + if _extract_document is None or _detect_loaded_vlm is None: + return DocumentSupportResponse( + extraction_available = False, + max_visual_payloads = 0, + format_support = {}, + unavailable_formats = {}, + vlm = { + "is_vlm": False, + "endpoint_url": None, + "model_name": None, + "source": "none", + "reason": "document extraction backend is not installed", + }, + ) + + self_base_url = ( + _extract_self_base_url(fastapi_request) if _extract_self_base_url else None + ) + try: + cap = _detect_loaded_vlm( + self_base_url, + llama_backend = get_llama_cpp_backend(), + ) + except Exception as exc: + logger.exception("Document support VLM probe failed") + if _VlmCapability is not None: + cap = _VlmCapability.none( + f"document support probe failed: {type(exc).__name__}" + ) + else: # pragma: no cover - only when core.chat import fallback is active + cap = None + return DocumentSupportResponse( + extraction_available = True, + max_visual_payloads = _MAX_DOCUMENT_VISUAL_PAYLOADS, + format_support = _document_parser_support(), + unavailable_formats = _document_parser_unavailable_reasons(), + vlm = cap.to_dict() + if cap is not None + else { + "is_vlm": False, + "endpoint_url": None, + "model_name": None, + "source": "none", + "reason": "document support probe failed", + }, + ) + + +@studio_router.post("/chat/extract-document", response_model = ExtractDocumentResponse) +async def extract_document_endpoint( + fastapi_request: Request, + current_subject: str = Depends(get_current_subject), +): + """Upload a PDF / DOCX / HTML / MD / text file and return + layout-aware Markdown plus optional figure captions + generated by the currently-loaded vision model. + + The response is inlined as JSON. 
Large documents (>200 pages) are + rejected with 413 until the background-job path lands. + """ + if _extract_document is None: + raise HTTPException( + status_code = 501, + detail = ( + "document extraction backend is not installed. Re-run Studio " + "setup to install the parser dependencies." + ), + ) + + _reject_oversized_content_length(fastapi_request) + + try: + try: + form = await _read_multipart_form_limited( + fastapi_request, + max_bytes = _EXTRACT_MAX_BYTES + _EXTRACT_MULTIPART_OVERHEAD_BYTES, + ) + except HTTPException: + raise + except Exception as exc: + logger.exception("Invalid multipart document extraction payload") + raise HTTPException( + status_code = 400, detail = "Invalid multipart payload" + ) + + upload = form.get("file") + if upload is None or not hasattr(upload, "read"): + raise HTTPException(status_code = 400, detail = "Missing 'file' field") + + filename = getattr(upload, "filename", None) or "upload" + content_type = getattr(upload, "content_type", "") or "" + if not _is_supported_upload(filename, content_type): + raise HTTPException( + status_code = 415, + detail = f"Unsupported file type: {filename} ({content_type})", + ) + _raise_if_document_parser_unavailable(filename, content_type) + + file_bytes = await _read_upload_limited(upload, max_bytes = _EXTRACT_MAX_BYTES) + if not file_bytes: + raise HTTPException(status_code = 400, detail = "Uploaded file is empty") + + preflight_page_count = _preflight_pdf_page_count(file_bytes, filename, content_type) + if ( + preflight_page_count is not None + and preflight_page_count > _EXTRACT_MAX_PAGES_INLINE + ): + raise HTTPException( + status_code = 413, + detail = ( + f"Document has {preflight_page_count} pages; inline extraction " + f"is capped at {_EXTRACT_MAX_PAGES_INLINE}. Split into smaller " + f"documents or reduce the page range." 
+ ), + ) + + describe_images = _parse_bool_form( + form.get("describe_images"), default = False, field = "describe_images" + ) + use_vlm_ocr = _parse_bool_form( + form.get("use_vlm_ocr"), default = False, field = "use_vlm_ocr" + ) + max_figures = _parse_int_form( + form.get("max_figures"), + default = 40, + lo = 0, + ) + max_visual_payloads = _parse_int_form( + form.get("max_visual_payloads"), + default = _DEFAULT_DOCUMENT_VISUAL_PAYLOADS, + lo = 0, + ) + token_budget = _parse_int_form( + form.get("token_budget"), + default = _EXTRACT_TOKEN_BUDGET_DEFAULT, + lo = 0, + ) + + self_base_url = ( + _extract_self_base_url(fastapi_request) if _extract_self_base_url else None + ) + llama_backend = get_llama_cpp_backend() + capability = ( + _detect_loaded_vlm( + self_base_url, + llama_backend = llama_backend, + ) + if _detect_loaded_vlm else None + ) + caption_authorization_header = _document_caption_authorization_header( + capability, + llama_backend, + fastapi_request.headers.get("authorization"), + ) + + if await fastapi_request.is_disconnected(): + raise HTTPException(status_code = 499, detail = "Client closed request") + + cancel_event = threading.Event() + extraction_task = asyncio.create_task( + _extract_document( + file_bytes, + filename, + content_type = content_type, + describe_images = describe_images, + use_vlm_ocr = use_vlm_ocr, + max_figures = max_figures, + max_visual_payloads = max_visual_payloads, + capability = capability, + self_base_url = self_base_url, + authorization_header = caption_authorization_header, + cancel_event = cancel_event, + ) + ) + disconnect_task = asyncio.create_task( + _wait_for_document_request_disconnect(fastapi_request, cancel_event) + ) + try: + done, _pending = await asyncio.wait( + {extraction_task, disconnect_task}, + return_when = asyncio.FIRST_COMPLETED, + ) + if extraction_task in done: + result = await extraction_task + elif disconnect_task in done and disconnect_task.result(): + cancel_event.set() + with suppress( + 
_DocumentExtractionCancelled, + asyncio.CancelledError, + asyncio.TimeoutError, + ): + await asyncio.wait_for(asyncio.shield(extraction_task), timeout = 10) + if not extraction_task.done(): + extraction_task.cancel() + raise _DocumentExtractionCancelled( + "document extraction was cancelled" + ) + else: + result = await extraction_task + except _DocumentExtractionUnavailable as exc: + raise HTTPException(status_code = 501, detail = str(exc)) + except _DocumentExtractionTimeout: + raise HTTPException( + status_code = 504, + detail = "Document parsing timed out after 120s before image captioning", + ) + except _DocumentExtractionBusy: + raise HTTPException(status_code = 503, detail = "Document extraction is busy") + except _DocumentExtractionCancelled: + raise HTTPException(status_code = 499, detail = "Client closed request") + except _DocumentExtractionEncrypted as exc: + raise HTTPException(status_code = 422, detail = str(exc)) + except ValueError as exc: + detail = str(exc) + status_code = 415 if detail.lower().startswith("unsupported file type") else 400 + raise HTTPException(status_code = status_code, detail = detail) + except Exception as exc: + logger.exception("Document extraction failed for %s", filename) + raise HTTPException( + status_code = 500, detail = "Extraction failed" + ) + finally: + cancel_event.set() + disconnect_task.cancel() + with suppress(asyncio.CancelledError): + await disconnect_task + + if result.page_count > _EXTRACT_MAX_PAGES_INLINE: + raise HTTPException( + status_code = 413, + detail = ( + f"Document has {result.page_count} pages; inline extraction " + f"is capped at {_EXTRACT_MAX_PAGES_INLINE}. Split into smaller " + f"documents or reduce the page range." 
+ ), + ) + + markdown, tokens_est, truncate_warning = _truncate_markdown_to_token_budget( + result.markdown, + token_budget = token_budget, + original_tokens_est = result.tokens_est, + ) + warnings = list(result.warnings) + if truncate_warning: + warnings.append(truncate_warning) + + return ExtractDocumentResponse( + filename = filename, + markdown = markdown, + page_count = result.page_count, + tokens_est = tokens_est, + truncated = truncate_warning is not None, + figures = [ + ExtractedFigureModel(**_asdict(f)) + for f in result.figures + ], + describe_skipped_reason = result.describe_skipped_reason, + vlm_source = result.vlm_source, + vlm_model = result.vlm_model, + image_input_available = getattr(result, "image_input_available", False), + warnings = warnings, + ) + finally: + # _EXTRACT_SEMAPHORE is owned solely by _run_extract_process_sync; the + # worker maps a busy semaphore to DocumentExtractionBusy → 503 above. + pass diff --git a/studio/backend/routes/models.py b/studio/backend/routes/models.py index d01e94b0c9..2980e1f1ff 100644 --- a/studio/backend/routes/models.py +++ b/studio/backend/routes/models.py @@ -13,6 +13,7 @@ import uuid from pathlib import Path from fastapi import APIRouter, Body, Depends, HTTPException, Query +from pydantic import BaseModel, Field from typing import List, Optional import structlog from loggers import get_logger @@ -123,6 +124,16 @@ def _is_valid_repo_id(repo_id: str) -> bool: logger = get_logger(__name__) +class ModelProbeRequest(BaseModel): + model_name: str = Field(..., description = "Model identifier or local path") + hf_token: Optional[str] = Field( + None, description = "HuggingFace token for gated/private models" + ) + trust_remote_code: bool = Field( + False, description = "Allow probes that require custom model code" + ) + + def derive_model_type( is_vision: bool, audio_type: Optional[str], is_embedding: bool = False ) -> ModelType: @@ -136,6 +147,38 @@ def derive_model_type( return "text" +def 
_defaults_vision_flags(config_dict: dict) -> tuple[bool, bool]: + model_config = config_dict.get("model", {}) if isinstance(config_dict, dict) else {} + inference_config = ( + config_dict.get("inference", {}) if isinstance(config_dict, dict) else {} + ) + yaml_is_vision = bool(model_config.get("is_vision", False)) + yaml_requires_trust_remote_code = bool( + model_config.get("trust_remote_code", False) + or inference_config.get("trust_remote_code", False) + ) + return yaml_is_vision, yaml_requires_trust_remote_code + + +def _detect_vision_for_config_endpoint( + model_name: str, + *, + hf_token: Optional[str] = None, + trust_remote_code: bool = False, + config_dict: Optional[dict] = None, +) -> bool: + defaults = config_dict if config_dict is not None else load_model_defaults(model_name) + yaml_is_vision, yaml_requires_trust_remote_code = _defaults_vision_flags(defaults) + if yaml_is_vision and yaml_requires_trust_remote_code: + return True + detected = is_vision_model( + model_name, + hf_token = hf_token, + trust_remote_code = trust_remote_code, + ) + return detected + + def _resolve_hf_cache_dir() -> Path: """Resolve local HF cache root used by hub downloads.""" try: @@ -1463,7 +1506,7 @@ async def list_models( loaded_models.append(model_info) # Include active GGUF model (loaded via llama-server) - from routes.inference import get_llama_cpp_backend + from core.inference.llama_cpp import get_llama_cpp_backend llama_backend = get_llama_cpp_backend() if llama_backend.is_loaded and llama_backend.model_identifier: @@ -1547,8 +1590,33 @@ def _get_model_size_bytes( @router.get("/config/{model_name:path}") async def get_model_config( model_name: str, - hf_token: Optional[str] = Query(None), + hf_token: Optional[str] = None, + trust_remote_code: bool = False, current_subject: str = Depends(get_current_subject), +): + return await _build_model_config_response( + model_name, + hf_token = hf_token, + trust_remote_code = trust_remote_code, + ) + + +@router.post("/config") 
+async def post_model_config( + request: ModelProbeRequest, + current_subject: str = Depends(get_current_subject), +): + return await _build_model_config_response( + request.model_name, + hf_token = request.hf_token, + trust_remote_code = request.trust_remote_code, + ) + + +async def _build_model_config_response( + model_name: str, + hf_token: Optional[str] = None, + trust_remote_code: bool = False, ): """ Get configuration for a specific model. @@ -1573,7 +1641,12 @@ async def get_model_config( config_dict = load_model_defaults(model_name) # Detect model capabilities (pass HF token for gated models) - is_vision = is_vision_model(model_name, hf_token = hf_token) + is_vision = _detect_vision_for_config_endpoint( + model_name, + hf_token = hf_token, + trust_remote_code = trust_remote_code, + config_dict = config_dict, + ) is_embedding = is_embedding_model(model_name, hf_token = hf_token) audio_type = detect_audio_type(model_name, hf_token = hf_token) @@ -1582,7 +1655,11 @@ async def get_model_config( base_model = None max_position_embeddings = None try: - model_config = ModelConfig.from_identifier(model_name) + model_config = ModelConfig.from_identifier( + model_name, + hf_token = hf_token, + trust_remote_code = trust_remote_code, + ) is_lora = model_config.is_lora base_model = model_config.base_model if is_lora else None max_position_embeddings = _get_max_position_embeddings(model_config) @@ -2053,7 +2130,33 @@ async def get_lora_base_model( @router.get("/check-vision/{model_name:path}", response_model = VisionCheckResponse) async def check_vision_model( model_name: str, + hf_token: Optional[str] = None, + trust_remote_code: bool = False, + current_subject: str = Depends(get_current_subject), +): + return await _check_vision_model_response( + model_name, + hf_token = hf_token, + trust_remote_code = trust_remote_code, + ) + + +@router.post("/check-vision", response_model = VisionCheckResponse) +async def post_check_vision_model( + request: ModelProbeRequest, 
current_subject: str = Depends(get_current_subject), +): + return await _check_vision_model_response( + request.model_name, + hf_token = request.hf_token, + trust_remote_code = request.trust_remote_code, + ) + + +async def _check_vision_model_response( + model_name: str, + hf_token: Optional[str] = None, + trust_remote_code: bool = False, ): """ Check if a model is a vision model. @@ -2062,7 +2165,11 @@ async def check_vision_model( """ try: logger.info(f"Checking if vision model: {model_name}") - is_vision = is_vision_model(model_name) + is_vision = _detect_vision_for_config_endpoint( + model_name, + hf_token = hf_token, + trust_remote_code = trust_remote_code, + ) logger.info(f"Vision check result for {model_name}: is_vision={is_vision}") return VisionCheckResponse( @@ -2587,7 +2694,7 @@ async def delete_cached_model( # Check if model is currently loaded try: - from routes.inference import get_llama_cpp_backend + from core.inference.llama_cpp import get_llama_cpp_backend llama_backend = get_llama_cpp_backend() if llama_backend.is_loaded and llama_backend.model_identifier: diff --git a/studio/backend/run.py b/studio/backend/run.py index 1dd1230a17..e18c763af7 100644 --- a/studio/backend/run.py +++ b/studio/backend/run.py @@ -244,11 +244,15 @@ def _graceful_shutdown(server = None): logger.warning("Error shutting down training subprocess: %s", e) # 5. Kill llama-server subprocess (if loaded) + # + # Read the module-level singleton directly so we don't instantiate a + # fresh backend during shutdown when none was ever loaded. 
try: - from routes.inference import _llama_cpp_backend + from core.inference import llama_cpp as _llama_cpp_mod - if _llama_cpp_backend is not None: - _llama_cpp_backend._kill_process() + backend = getattr(_llama_cpp_mod, "_llama_cpp_backend", None) + if backend is not None: + backend._kill_process() except Exception as e: logger.warning("Error shutting down llama-server: %s", e) diff --git a/studio/backend/tests/test_anthropic_messages.py b/studio/backend/tests/test_anthropic_messages.py index 0825ef9337..e723c0091b 100644 --- a/studio/backend/tests/test_anthropic_messages.py +++ b/studio/backend/tests/test_anthropic_messages.py @@ -34,6 +34,7 @@ AnthropicStreamEmitter, AnthropicPassthroughEmitter, ) +import routes.inference as route from routes.inference import _normalize_anthropic_openai_images from fastapi import HTTPException import base64 as _b64 @@ -1011,3 +1012,21 @@ def test_bad_base64_raises_400(self): with pytest.raises(HTTPException) as exc: _normalize_anthropic_openai_images(msgs, is_vision = True) assert exc.value.status_code == 400 + + def test_image_count_limit_applies(self, monkeypatch): + monkeypatch.setattr(route, "_OPENAI_CHAT_MAX_IMAGES", 1) + data_url = _jpeg_data_url() + msgs = [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": data_url}}, + {"type": "image_url", "image_url": {"url": data_url}}, + ], + } + ] + + with pytest.raises(HTTPException) as exc: + _normalize_anthropic_openai_images(msgs, is_vision = True) + + assert exc.value.status_code == 413 diff --git a/studio/backend/tests/test_chat_document_extraction.py b/studio/backend/tests/test_chat_document_extraction.py new file mode 100644 index 0000000000..297a9ddebc --- /dev/null +++ b/studio/backend/tests/test_chat_document_extraction.py @@ -0,0 +1,900 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0 + +""" +Tests for the chat document extractor + VLM capability probe. + +Probe tests run regardless of the extraction backend because they only +shape-check :mod:`core.chat.vlm_capability`. Backend-backed tests skip +cleanly when the optional deps (pymupdf / pymupdf4llm / mammoth) are +missing. +""" + +from __future__ import annotations + +import importlib.util +import sys +from types import ModuleType, SimpleNamespace +from typing import Any, Dict, Optional + +import pytest + +from core.chat.vlm_capability import ( + VlmCapability, + detect_loaded_vlm, + extract_self_base_url, +) + + +# ---------------------------------------------------------------------- # +# VlmCapability dataclass # +# ---------------------------------------------------------------------- # + + +def test_vlm_capability_none_factory_is_safe_default() -> None: + cap = VlmCapability.none() + assert cap.is_vlm is False + assert cap.endpoint_url is None + assert cap.model_name is None + assert cap.source == "none" + assert cap.reason # non-empty + + +def test_vlm_capability_to_dict_round_trips_fields() -> None: + cap = VlmCapability( + is_vlm = True, + endpoint_url = "http://127.0.0.1:8080", + model_name = "qwen2-vl", + source = "gguf", + reason = None, + ) + assert cap.to_dict() == { + "is_vlm": True, + "endpoint_url": "http://127.0.0.1:8080", + "model_name": "qwen2-vl", + "source": "gguf", + "reason": None, + } + + +# ---------------------------------------------------------------------- # +# detect_loaded_vlm() across backend shapes # +# ---------------------------------------------------------------------- # + + +class _FakeLlama: + def __init__( + self, + *, + loaded: bool, + vision: bool = False, + base_url: str = "http://127.0.0.1:8080", + model_id: str = "fake-gguf", + ) -> None: + self.is_loaded = loaded + self.is_vision = vision + self.base_url = base_url + self.model_identifier = model_id + + +class _FakeInferenceBackend: + def __init__( + self, + *, + 
active: Optional[str], + info: Optional[Dict[str, Any]] = None, + ) -> None: + self.active_model_name = active + self.models: Dict[str, Dict[str, Any]] = ( + {active: info or {}} if active else {} + ) + + +def _patch_probes( + monkeypatch: pytest.MonkeyPatch, + *, + llama: Optional[_FakeLlama], + inference: Optional[_FakeInferenceBackend], +) -> None: + from core.chat import vlm_capability as vc + + if llama is None: + monkeypatch.setattr(vc, "_probe_gguf", lambda _llama = None: None) + else: + def probe_gguf(llama_backend = None): + backend = llama_backend or llama + if not backend.is_loaded: + return None + is_vision = bool(backend.is_vision) + return VlmCapability( + is_vlm = is_vision, + endpoint_url = backend.base_url, + model_name = backend.model_identifier, + source = "gguf", + reason = None if is_vision else "loaded GGUF is not vision-capable", + ) + + monkeypatch.setattr(vc, "_probe_gguf", probe_gguf) + + if inference is None: + monkeypatch.setattr(vc, "_probe_transformers", lambda _u: None) + else: + def probe_tf(self_base_url): + name = inference.active_model_name + if not name: + return None + info = inference.models.get(name) or {} + is_vision = bool(info.get("is_vision", False)) + source = "unsloth" if info.get("is_lora") else "transformers" + if not self_base_url: + return VlmCapability( + is_vlm = False, + endpoint_url = None, + model_name = name, + source = source, + reason = "cannot self-loopback: request base URL unavailable", + ) + return VlmCapability( + is_vlm = is_vision, + endpoint_url = self_base_url.rstrip("/"), + model_name = name, + source = source, + reason = None if is_vision else "loaded model is not vision-capable", + ) + + monkeypatch.setattr(vc, "_probe_transformers", probe_tf) + + +def test_detect_returns_none_when_no_model_loaded( + monkeypatch: pytest.MonkeyPatch, +) -> None: + _patch_probes(monkeypatch, llama = None, inference = None) + cap = detect_loaded_vlm() + assert cap.source == "none" + assert cap.is_vlm is False + + 
+def test_detect_gguf_vision_returns_llama_endpoint( + monkeypatch: pytest.MonkeyPatch, +) -> None: + llama = _FakeLlama(loaded = True, vision = True, base_url = "http://127.0.0.1:9999") + _patch_probes(monkeypatch, llama = llama, inference = None) + cap = detect_loaded_vlm("http://studio.local") + assert cap.source == "gguf" + assert cap.is_vlm is True + assert cap.endpoint_url == "http://127.0.0.1:9999" # GGUF ignores self_base_url + assert cap.reason is None + + +def test_detect_gguf_vision_accepts_injected_backend( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from core.chat import vlm_capability as vc + + llama = _FakeLlama(loaded = True, vision = True, base_url = "http://127.0.0.1:9999") + monkeypatch.setattr(vc, "_probe_transformers", lambda _u: None) + + cap = detect_loaded_vlm( + "http://127.0.0.1:8000", + llama_backend = llama, + ) + + assert cap.source == "gguf" + assert cap.is_vlm is True + assert cap.endpoint_url == "http://127.0.0.1:9999" + + +def test_detect_gguf_vision_uses_core_llama_accessor( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """The implicit GGUF fallback must use the core-owned singleton path.""" + from core.chat import vlm_capability as vc + from core.inference import llama_cpp + + llama = _FakeLlama(loaded = True, vision = True, base_url = "http://127.0.0.1:9999") + assert hasattr(llama_cpp, "get_llama_cpp_backend") + monkeypatch.setattr(llama_cpp, "_llama_cpp_backend", llama) + monkeypatch.setattr(vc, "_probe_transformers", lambda _u: None) + + cap = detect_loaded_vlm("http://127.0.0.1:8000") + + assert cap.source == "gguf" + assert cap.is_vlm is True + assert cap.endpoint_url == "http://127.0.0.1:9999" + + +def test_detect_gguf_non_vision_surfaces_reason( + monkeypatch: pytest.MonkeyPatch, +) -> None: + llama = _FakeLlama(loaded = True, vision = False) + _patch_probes(monkeypatch, llama = llama, inference = None) + cap = detect_loaded_vlm() + assert cap.source == "gguf" + assert cap.is_vlm is False + assert cap.reason and 
"vision" in cap.reason.lower() + + +def test_detect_transformers_vision_uses_self_loopback( + monkeypatch: pytest.MonkeyPatch, +) -> None: + ib = _FakeInferenceBackend( + active = "Qwen2-VL-7B", info = {"is_vision": True, "is_lora": False}, + ) + _patch_probes(monkeypatch, llama = None, inference = ib) + cap = detect_loaded_vlm("http://127.0.0.1:8000/") + assert cap.source == "transformers" + assert cap.is_vlm is True + assert cap.endpoint_url == "http://127.0.0.1:8000" + assert cap.model_name == "Qwen2-VL-7B" + + +def test_detect_unsloth_lora_vision_reports_unsloth_source( + monkeypatch: pytest.MonkeyPatch, +) -> None: + ib = _FakeInferenceBackend( + active = "my-qwen-vl-lora", info = {"is_vision": True, "is_lora": True}, + ) + _patch_probes(monkeypatch, llama = None, inference = ib) + cap = detect_loaded_vlm("http://studio.local:8000") + assert cap.source == "unsloth" + assert cap.is_vlm is True + + +def test_detect_falls_through_when_gguf_is_loaded_but_endpoint_data_missing( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A half-initialised llama-server (is_loaded=True but base_url/model + missing) must not suppress the transformers fallback path — otherwise + a misleading non-vision GGUF result hides an active transformers VLM. 
+ """ + from core.chat import vlm_capability as vc + + fake_llama_cpp = ModuleType("core.inference.llama_cpp") + fake_llama_cpp.get_llama_cpp_backend = lambda: _FakeLlama( + loaded = True, base_url = "", model_id = "", + ) + fake_inference = ModuleType("core.inference") + fake_inference.__path__ = [] # type: ignore[attr-defined] + fake_inference.llama_cpp = fake_llama_cpp # type: ignore[attr-defined] + monkeypatch.setitem(sys.modules, "core.inference", fake_inference) + monkeypatch.setitem(sys.modules, "core.inference.llama_cpp", fake_llama_cpp) + + ib = _FakeInferenceBackend( + active = "Qwen2-VL-7B", info = {"is_vision": True, "is_lora": False}, + ) + monkeypatch.setattr( + vc, + "_probe_transformers", + lambda self_base_url: VlmCapability( + is_vlm = True, + endpoint_url = self_base_url.rstrip("/") if self_base_url else None, + model_name = ib.active_model_name, + source = "transformers", + reason = None, + ), + ) + + cap = detect_loaded_vlm("http://127.0.0.1:8000") + assert cap.source == "transformers" + assert cap.is_vlm is True + + +def test_detect_transformers_without_self_url_reports_missing_loopback( + monkeypatch: pytest.MonkeyPatch, +) -> None: + ib = _FakeInferenceBackend( + active = "Qwen2-VL-7B", info = {"is_vision": True, "is_lora": False}, + ) + _patch_probes(monkeypatch, llama = None, inference = ib) + cap = detect_loaded_vlm(None) + assert cap.is_vlm is False + assert cap.reason and "loopback" in cap.reason.lower() + + +# ---------------------------------------------------------------------- # +# extract_self_base_url — request base-URL extraction # +# ---------------------------------------------------------------------- # + + +class _FakeState: + def __init__(self, server_port: Optional[int] = None) -> None: + if server_port is not None: + self.server_port = server_port + + +class _FakeApp: + def __init__(self, server_port: Optional[int] = None) -> None: + self.state = _FakeState(server_port) + + +class _FakeRequest: + def __init__( + self, + 
base_url: str, + *, + server_port: Optional[int] = None, + scope_server: Optional[tuple[str, int]] = None, + ) -> None: + self.base_url = base_url + self.app = _FakeApp(server_port) + self.scope = {"server": scope_server} if scope_server else {} + + +def test_extract_self_base_url_strips_trailing_slash() -> None: + assert ( + extract_self_base_url(_FakeRequest("http://127.0.0.1:8000/")) + == "http://127.0.0.1:8000" + ) + + +def test_extract_self_base_url_prefers_trusted_server_port() -> None: + assert ( + extract_self_base_url( + _FakeRequest( + "http://attacker.invalid:9999/", + server_port = 7777, + scope_server = ("127.0.0.1", 6666), + ) + ) + == "http://127.0.0.1:7777" + ) + assert ( + extract_self_base_url( + _FakeRequest( + "http://attacker.invalid:9999/", + scope_server = ("127.0.0.1", 6666), + ) + ) + == "http://127.0.0.1:6666" + ) + + +def test_extract_self_base_url_ignores_host_header() -> None: + assert ( + extract_self_base_url(_FakeRequest("http://studio.local:8000/")) + == "http://127.0.0.1:8000" + ) + assert ( + extract_self_base_url(_FakeRequest("https://example.com:9443/")) + == "http://127.0.0.1:9443" + ) + + +def test_extract_self_base_url_none_when_empty() -> None: + assert extract_self_base_url(_FakeRequest("")) is None + + +def test_extract_self_base_url_none_on_missing_attribute() -> None: + assert extract_self_base_url(object()) is None + + +# ---------------------------------------------------------------------- # +# extract_document orchestration — backend-agnostic (monkey-patched) # +# ---------------------------------------------------------------------- # + + +@pytest.mark.asyncio +async def test_max_figures_zero_sets_describe_skipped_reason( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """max_figures=0 must skip description with a specific diagnostic even + when a VLM is available.""" + from core.chat import document_extractor as de + + def fake_extract(_fb, _fn, _opts, _ct = ""): + return "# Smoke\n", [], 1, 0, 0 + + 
monkeypatch.setattr(de, "DOCUMENT_EXTRACTION_AVAILABLE", True) + monkeypatch.setattr(de, "_run_extract_sync", fake_extract) + + result = await de.extract_document( + b"# Smoke\n", + "sample.md", + describe_images = True, + max_figures = 0, + capability = VlmCapability( + is_vlm = True, + endpoint_url = "http://127.0.0.1:8000", + model_name = "vlm", + source = "transformers", + ), + ) + + assert result.describe_skipped_reason == ( + "figure description disabled because max_figures is 0" + ) + assert result.markdown == "# Smoke\n" + assert result.figures == [] + + +@pytest.mark.asyncio +async def test_run_extract_sync_seam_receives_content_type( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """The test seam path (monkeypatched _run_extract_sync) must be invoked + with the content_type so dispatch-by-content-type can be exercised in + tests, not only by filename suffix.""" + from core.chat import document_extractor as de + + received: dict[str, str] = {} + + def fake_extract(_fb, _fn, _opts, ct = ""): + received["content_type"] = ct + return "ok", [], 0, 0, 0 + + monkeypatch.setattr(de, "DOCUMENT_EXTRACTION_AVAILABLE", True) + monkeypatch.setattr(de, "_run_extract_sync", fake_extract) + + await de.extract_document( + b"hello", + "no-suffix-file", + content_type = "text/plain", + describe_images = False, + ) + assert received["content_type"] == "text/plain" + + +@pytest.mark.asyncio +async def test_describe_image_via_vlm_sends_auth_header_and_max_tokens( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from core.chat import document_extractor as de + + captured: dict[str, Any] = {} + + class FakeResponse: + status_code = 200 + + def json(self): + return {"choices": [{"message": {"content": "A chart."}}]} + + class FakeAsyncClient: + def __init__(self, *, timeout: float) -> None: + captured["timeout"] = timeout + + async def __aenter__(self): + return self + + async def __aexit__(self, *_args): + return None + + async def post(self, url, *, headers, json): + 
captured["url"] = url + captured["headers"] = headers + captured["json"] = json + return FakeResponse() + + fake_httpx = ModuleType("httpx") + fake_httpx.AsyncClient = FakeAsyncClient + monkeypatch.setitem(sys.modules, "httpx", fake_httpx) + + caption, error = await de._describe_image_via_vlm( + image_base64 = "abc", + image_mime = "image/jpeg", + endpoint_url = "http://127.0.0.1:8000", + model_name = "vlm", + authorization_header = "Bearer token", + timeout_seconds = 7, + ) + + assert caption == "A chart." + assert error is None + assert captured["url"] == "http://127.0.0.1:8000/v1/chat/completions" + assert captured["headers"]["Authorization"] == "Bearer token" + assert captured["json"]["max_tokens"] == 512 + assert "max_completion_tokens" not in captured["json"] + + +# ---------------------------------------------------------------------- # +# Backend dispatch — real _run_extract_sync (requires pymupdf/mammoth) # +# ---------------------------------------------------------------------- # + + +_BACKEND_INSTALLED = ( + importlib.util.find_spec("pymupdf") is not None + and importlib.util.find_spec("pymupdf4llm") is not None + and importlib.util.find_spec("mammoth") is not None +) + + +def test_run_extract_sync_rejects_pptx_with_value_error() -> None: + """PPTX was dropped in the PyMuPDF4LLM migration. 
_run_extract_sync + must raise ValueError so the route can map it to HTTP 415.""" + if not _BACKEND_INSTALLED: + pytest.skip("extraction backend not installed") + from core.chat import document_extractor as de + + with pytest.raises(ValueError): + de._run_extract_sync( + b"PK\x03\x04", + "deck.pptx", + {"max_figures": 0, "extract_images": False, "use_vlm_ocr": False}, + ) + + +def test_run_extract_sync_text_path_decodes_utf8() -> None: + """TXT / MD paths must not require PDF/DOCX parser dependencies.""" + from core.chat import document_extractor as de + + md, figs, pages, trunc, seen = de._run_extract_sync( + "# Héllo\n".encode("utf-8"), + "notes.md", + {"max_figures": 0, "extract_images": False, "use_vlm_ocr": False}, + ) + assert md == "# Héllo\n" + assert figs == [] + assert pages == 0 and trunc == 0 and seen == 0 + + +def test_run_extract_sync_html_converts_to_markdown_without_parser_deps() -> None: + """HTML must be cleaned before prompt injection and not depend on PDF/DOCX deps.""" + from core.chat import document_extractor as de + + md, figs, pages, trunc, seen = de._run_extract_sync( + b"

Title

Hello world

", + "page.html", + {"max_figures": 0, "extract_images": False, "use_vlm_ocr": False}, + ) + assert "# Title" in md + assert "**world**" in md + assert "\"\n" + " b\"

hello

\")\n" + "out, *_rest = mod._extract_html(dirty)\n" + "import json\n" + "print(json.dumps({'out': out}))\n" + ) + proc = _run_subprocess(body) + assert proc.returncode == 0, proc.stderr + + import json + + parsed = json.loads(proc.stdout.strip().splitlines()[-1]) + out = parsed["out"] + # Pre-fix this returns the raw HTML because the fallback branch + # in _extract_html swallows the ImportError. + assert "alert" not in out, ( + f" survived into the prompt; raw output:\n{out}" + ) + assert " Date: Mon, 25 May 2026 13:20:16 +0000 Subject: [PATCH 07/11] studio: accept null-password PDFs, harden extractor process lifecycle Real-world testing (Orimi test PDF, RFC 8259 PDF, "Attention Is All You Need", calibre demo DOCX) plus an additional review pass surfaced five follow-ups on top of the earlier singleton/semaphore/HTML fix: 1. Null-password PDFs were rejected as encrypted. The classic Orimi PDF, Acrobat distilled scans, and a long tail of PDFs in the wild carry a /Encrypt dict with an empty user password so the file opens without prompting. pypdf.PdfReader.is_encrypted and PyMuPDF's doc.is_encrypted both flag them, but the canonical "needs a password" signal is PyMuPDF's needs_pass. The preflight in routes.inference._preflight_pdf_page_count and the extractor in core.chat.document_extractor._extract_pdf now refuse only when needs_pass is True. pypdf's branch tries decrypt("") first and falls through to PyMuPDF on failure. 2. Worker put-result-then-die race. _run_extract_process_sync could observe proc.is_alive() == False after the worker had already queued a successful result, exit the loop with message=None, and surface a RuntimeError. Both the in-loop is_alive() branch and the post-join branch now perform a final result_queue.get_nowait() before declaring failure. 3. macOS multiprocessing start method. The ternary picked "fork" on macOS, which is unsafe with Quartz / PyObjC / PyMuPDF's CoreFoundation linkage. 
macOS now uses "spawn" like Windows; Linux keeps "fork" for the copy-on-write (no-pickling) win. 4. NDJSON streaming InvalidStateError on shield-cancel race. The streaming NDJSON loop accepted extract_wait completion as a signal to call extraction_task.result(). When asyncio.shield's outer future was cancelled before the inner task finished, that raised InvalidStateError and surfaced as a generic HTTP 500. The branch now waits for extraction_task.done() and re-arms a fresh shielded future when only the outer wrapper completes. 5. PaddleOCR-VL nondeterministic inference defaults. The preset shipped with temperature=1.5, min_p=0.1, which causes hallucinated glyphs and reorderings on a closed-form transcription task. It is now aligned with the sibling OCR presets (DeepSeek-OCR, GLM-OCR) at temperature=0.0, top_p=1.0, top_k=-1, min_p=0.0. Regression test additions: - tests/studio/test_pseudo_encrypted_pdf.py mints a null-password PDF with PyMuPDF, asserts both the preflight and _extract_pdf accept it, and confirms a real password-required PDF still raises DocumentExtractionEncrypted. Also drops importlib.reload from test_extractor_semaphore_leak.py: the reload swapped _drain_future_exception out from under routes.inference, breaking an existing identity assertion. The new fixture snapshots and restores the semaphore counter instead. Local: studio backend suite + 4 regression files: 91/91 PR-relevant tests pass; the remaining 9 failures are the pre-existing gpu_selection / kv_cache_estimation / help_output test failures, which are unchanged.
--- .../other/unsloth_PaddleOCR-VL.yaml | 9 +- .../backend/core/chat/document_extractor.py | 36 +++++- studio/backend/routes/inference.py | 46 ++++++- tests/studio/test_extractor_semaphore_leak.py | 15 ++- tests/studio/test_pseudo_encrypted_pdf.py | 112 ++++++++++++++++++ 5 files changed, 200 insertions(+), 18 deletions(-) create mode 100644 tests/studio/test_pseudo_encrypted_pdf.py diff --git a/studio/backend/assets/configs/model_defaults/other/unsloth_PaddleOCR-VL.yaml b/studio/backend/assets/configs/model_defaults/other/unsloth_PaddleOCR-VL.yaml index 2a270ed282..bffb79902c 100644 --- a/studio/backend/assets/configs/model_defaults/other/unsloth_PaddleOCR-VL.yaml +++ b/studio/backend/assets/configs/model_defaults/other/unsloth_PaddleOCR-VL.yaml @@ -56,6 +56,11 @@ logging: inference: trust_remote_code: true - temperature: 1.5 - min_p: 0.1 + # OCR is a closed-form transcription task; sibling OCR presets + # (DeepSeek-OCR, GLM-OCR) use deterministic decoding so the + # transcription is reproducible. Match that convention here. + temperature: 0.0 + min_p: 0.0 + top_p: 1.0 + top_k: -1 diff --git a/studio/backend/core/chat/document_extractor.py b/studio/backend/core/chat/document_extractor.py index 419621178e..915fc596c2 100644 --- a/studio/backend/core/chat/document_extractor.py +++ b/studio/backend/core/chat/document_extractor.py @@ -568,7 +568,12 @@ def _extract_pdf( doc = pymupdf.open(stream = file_bytes, filetype = "pdf") try: - if getattr(doc, "is_encrypted", False) or getattr(doc, "needs_pass", False): + # ``is_encrypted`` is True for any file with an /Encrypt dict + # (very common for Acrobat-distilled PDFs, scanner output, the + # classic Orimi test file). ``needs_pass`` is the real "user + # password required" signal. Refuse extraction only when an + # actual password is missing. + if getattr(doc, "needs_pass", False): raise DocumentExtractionEncrypted( "Encrypted PDF; provide a password before extracting it." 
) @@ -843,9 +848,16 @@ def _run_extract_process_sync( result_queue = None proc = None try: - ctx = multiprocessing.get_context( - "spawn" if os.name == "nt" else "fork" - ) + # Prefer "fork" only on Linux. macOS defaults to "spawn" in + # modern Python because Objective-C runtimes (loaded by + # PyMuPDF/CoreFoundation/Quartz) crash under fork. Windows has + # never supported fork. + import sys as _sys + if os.name == "nt" or _sys.platform == "darwin": + mp_method = "spawn" + else: + mp_method = "fork" + ctx = multiprocessing.get_context(mp_method) result_queue = ctx.Queue(maxsize = 1) proc = ctx.Process( target = _run_extract_worker, @@ -874,6 +886,14 @@ def _run_extract_process_sync( "document extraction was cancelled" ) if not proc.is_alive(): + # The worker may have put its result and exited + # between the queue.get timeout and this is_alive + # check. Drain the queue once more before declaring + # failure so a successful extraction is not lost. + try: + message = result_queue.get_nowait() + except queue.Empty: + pass break if time.monotonic() >= deadline: _terminate_extract_process(proc) @@ -885,6 +905,14 @@ def _run_extract_process_sync( if proc.is_alive(): proc.terminate() proc.join(2) + if message is None: + # One more attempt after the join completes; covers the + # case where the worker exited cleanly with a result still + # queued. 
+ try: + message = result_queue.get_nowait() + except queue.Empty: + pass if message is None: raise RuntimeError( f"document extraction worker exited without a result " diff --git a/studio/backend/routes/inference.py b/studio/backend/routes/inference.py index 3a9ec43d88..c15b18492d 100644 --- a/studio/backend/routes/inference.py +++ b/studio/backend/routes/inference.py @@ -5067,11 +5067,25 @@ def _preflight_pdf_page_count( from pypdf import PdfReader reader = PdfReader(io.BytesIO(file_bytes), strict = False) + # Many PDFs report ``is_encrypted=True`` even though they only use a + # null/empty user password and open fine (Acrobat-distilled docs, + # the classic Orimi test PDF, scanner output). Try the empty + # password before refusing; PyMuPDF's ``needs_pass`` is the real + # signal in the fallback branch below. if getattr(reader, "is_encrypted", False): - raise HTTPException( - status_code = 422, - detail = "Encrypted PDFs are not supported for inline extraction", - ) + try: + if reader.decrypt("") == 0: + raise HTTPException( + status_code = 422, + detail = "Encrypted PDFs are not supported for inline extraction", + ) + except HTTPException: + raise + except Exception: + # ``decrypt`` itself failed (corrupt /Encrypt dict, unknown + # algorithm). Fall through to the PyMuPDF fallback rather + # than declaring the file encrypted. + raise RuntimeError("pypdf decrypt probe failed") return len(reader.pages) except HTTPException: raise @@ -5087,7 +5101,12 @@ def _preflight_pdf_page_count( doc = _pymupdf.open(stream = file_bytes, filetype = "pdf") try: - if getattr(doc, "is_encrypted", False) or getattr(doc, "needs_pass", False): + # PyMuPDF's ``needs_pass`` is True only when an actual password + # is required. ``is_encrypted`` is True for any file with an + # /Encrypt dict, which includes the common null-password case + # that opens fine. Refuse only when a password is actually + # needed. 
+ if getattr(doc, "needs_pass", False): raise HTTPException( status_code = 422, detail = "Encrypted PDFs are not supported for inline extraction", @@ -5488,7 +5507,12 @@ async def _ndjson_stream(): "document extraction was cancelled" ) - if extract_wait in done or extraction_task.done(): + # The shield-wrapper may complete (cancelled) before + # the underlying extraction_task is done; calling + # ``.result()`` in that window raises + # InvalidStateError. Wait for the real task before + # consuming its result. + if extraction_task.done(): # Drain any remaining progress events before result. while not progress_queue.empty(): try: @@ -5498,6 +5522,16 @@ async def _ndjson_stream(): yield json.dumps(event) + "\n" result = extraction_task.result() break + if extract_wait in done: + # Shield-wrapper finished but the real task is + # still running. Re-arm the wait on a fresh + # shielded future and loop. + extract_wait = asyncio.ensure_future( + asyncio.shield(extraction_task) + ) + extract_wait.add_done_callback( + _drain_doc_future_exception + ) if result.page_count > _EXTRACT_MAX_PAGES_INLINE: yield ( diff --git a/tests/studio/test_extractor_semaphore_leak.py b/tests/studio/test_extractor_semaphore_leak.py index 12ca3a1b35..4cd46e3b8d 100644 --- a/tests/studio/test_extractor_semaphore_leak.py +++ b/tests/studio/test_extractor_semaphore_leak.py @@ -22,7 +22,6 @@ from __future__ import annotations -import importlib import os import sys from pathlib import Path @@ -36,19 +35,23 @@ sys.path.insert(0, str(_BACKEND)) -# Force a small concurrency so the test is fast and obvious. -os.environ.setdefault("UNSLOTH_STUDIO_EXTRACT_CONCURRENCY", "2") # Don't park the test waiting for a slot to free. os.environ.setdefault("UNSLOTH_STUDIO_EXTRACT_QUEUE_WAIT", "0") @pytest.fixture def extractor(): - # Re-import each test so the env vars above take effect and the - # semaphore counter starts at the configured ceiling. + """Yield the document_extractor module. 
+ + We avoid ``importlib.reload`` here because reloading swaps the + module-level ``_drain_future_exception`` function object out from + under ``routes.inference`` (which captured it at import time), + and other tests assert identity between the two references. + Instead we snapshot ``_EXTRACT_SEMAPHORE._value`` before each + test and assert restoration after; no reload required. + """ from core.chat import document_extractor as mod - importlib.reload(mod) yield mod diff --git a/tests/studio/test_pseudo_encrypted_pdf.py b/tests/studio/test_pseudo_encrypted_pdf.py new file mode 100644 index 0000000000..34b2455fd0 --- /dev/null +++ b/tests/studio/test_pseudo_encrypted_pdf.py @@ -0,0 +1,112 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 +""" +Tests that PDFs with a null/empty user password (very common; Acrobat +distillation often writes /Encrypt dicts with no password) are NOT +falsely rejected as "encrypted" by either the preflight or the +extractor. + +Failure mode the test pins: + The classic Orimi PDF Test File (and many scanner-output PDFs) + carry "Standard V2 R3 128-bit RC4" encryption with an empty user + password -- the file opens without prompting in any reader. + Pre-fix, both ``routes.inference._preflight_pdf_page_count`` and + ``core.chat.document_extractor._extract_pdf`` returned HTTP 422 + "Encrypted PDFs are not supported" because they checked + ``is_encrypted`` rather than ``needs_pass``. After the fix the + file is accepted and its text is extracted. 
+""" + +from __future__ import annotations + +import sys +from pathlib import Path + +import pytest + + +_BACKEND = Path(__file__).resolve().parents[2] / "studio" / "backend" +if str(_BACKEND) not in sys.path: + sys.path.insert(0, str(_BACKEND)) + + +def _make_pseudo_encrypted_pdf() -> bytes: + """Mint a tiny PDF with an empty user password (mirrors what + Orimi's test file and many distiller pipelines produce).""" + pymupdf = pytest.importorskip("pymupdf") + doc = pymupdf.open() + page = doc.new_page() + page.insert_text( + (72, 100), + "pseudo-encrypted PDF: null user password, opens without prompt", + fontsize=12, + ) + out = doc.tobytes( + encryption=pymupdf.PDF_ENCRYPT_AES_256, + owner_pw="owner-pw", + user_pw="", + ) + doc.close() + return out + + +def test_extract_pdf_accepts_null_password(monkeypatch): + """The extractor must not raise DocumentExtractionEncrypted for a + PDF whose user password is the empty string. PyMuPDF's + ``needs_pass`` is the canonical signal; ``is_encrypted`` is too + aggressive.""" + from core.chat import document_extractor as mod + + file_bytes = _make_pseudo_encrypted_pdf() + + md, figures, page_count, truncated, seen = mod._extract_pdf( + file_bytes, + max_figures=0, + use_vlm_ocr=False, + max_visual_payloads=0, + ) + + assert page_count == 1 + assert "pseudo-encrypted PDF" in md + assert figures == [] + + +def test_preflight_pdf_page_count_accepts_null_password(): + """The pre-extraction preflight at + ``routes.inference._preflight_pdf_page_count`` must accept + null-password PDFs.""" + from routes.inference import _preflight_pdf_page_count + + file_bytes = _make_pseudo_encrypted_pdf() + n = _preflight_pdf_page_count( + file_bytes, + filename="pseudo_encrypted.pdf", + content_type="application/pdf", + ) + assert n == 1 + + +def test_extract_pdf_still_rejects_password_required(monkeypatch): + """Sanity-check the other direction: a PDF that actually requires + a non-empty user password must still raise + 
DocumentExtractionEncrypted.""" + pymupdf = pytest.importorskip("pymupdf") + doc = pymupdf.open() + page = doc.new_page() + page.insert_text((72, 100), "this one needs a password", fontsize=12) + encrypted = doc.tobytes( + encryption=pymupdf.PDF_ENCRYPT_AES_256, + owner_pw="owner", + user_pw="real-password", + ) + doc.close() + + from core.chat import document_extractor as mod + + with pytest.raises(mod.DocumentExtractionEncrypted): + mod._extract_pdf( + encrypted, + max_figures=0, + use_vlm_ocr=False, + max_visual_payloads=0, + ) From 22a0b2406f72ed09cdc89bd0a6701f9f9d02b1fd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 25 May 2026 13:31:03 +0000 Subject: [PATCH 08/11] studio: cap /api/inference/cancel body size Mirrors the request-body hardening this PR already added to the sibling JSON inference endpoints (/v1/chat/completions at :1674, /v1/anthropic/messages at :2769, /v1/anthropic/messages_count at :2850). /api/inference/cancel still used await request.json() with no streaming cap, so an authenticated client could force the server to buffer arbitrarily large bodies and slip past the exact overflow hardening this PR added elsewhere. Switched to _read_json_body_limited(request, max_bytes=64 KiB). The real cancel payload is a small dict of identifiers (cancel_id, completion_id, session_id, message_id); 64 KiB is generous and matches the cap pattern used in the other authenticated route handlers. Stream-cancel registration timing (test_stream_cancel_registration_timing + test_cancel_atomicity + test_cancel_id_wiring) is unchanged. 
--- studio/backend/routes/inference.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/studio/backend/routes/inference.py b/studio/backend/routes/inference.py index c15b18492d..0bd93b1c82 100644 --- a/studio/backend/routes/inference.py +++ b/studio/backend/routes/inference.py @@ -1086,10 +1086,15 @@ async def cancel_inference( A cancel_id arriving before its stream registers is stashed briefly and replayed on registration. Returns {"cancelled": N}. """ + # The cancel body is a tiny dict of identifiers; cap the read so an + # authenticated client cannot make this endpoint buffer megabytes + # the way the sibling JSON inference endpoints already prevent. try: - body = await request.json() + body = await _read_json_body_limited(request, max_bytes = 64 * 1024) if not isinstance(body, dict): body = {} + except HTTPException: + raise except Exception as e: logger.debug("Failed to parse cancel request body: %s", e) body = {} From 7b6d399819c3765ac1d27de58fa2f8cd48b19894 Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Mon, 25 May 2026 11:48:21 +0000 Subject: [PATCH 09/11] ci: trim staging workflows and add PR-5351 cross-OS lanes Drops studio-frontend-ci.yml, studio-inference-smoke.yml, studio-tauri-smoke.yml, wheel-smoke.yml, release-desktop.yml and stale.yml from this staging branch so the matrix stays below the 5-concurrent-Windows-runner cap. Keeps studio-backend-ci.yml as the Ubuntu sanity baseline. Adds three lanes that re-run the PR-5351 backend tests plus the three regression tests added in the fix commit: - pr5351-ubuntu.yml: ubuntu-latest, Python 3.11. CUDA spoof in tests/conftest.py engages on CPU runners. - pr5351-macos.yml: macos-14 (arm64). Exercises the multiprocessing spawn start-method and the MLX branch in core.chat.vlm_capability. - pr5351-windows.yml: windows-latest. Validates spawn + path normalisation + Process-construction-under-pressure (exactly the EAGAIN class the semaphore-leak fix protects against). 
Each workflow gates on paths: studio/backend/**, tests/studio/**, tests/conftest.py and its own file so unrelated commits do not re-trigger. --- .github/workflows/pr5351-macos.yml | 60 + .github/workflows/pr5351-ubuntu.yml | 57 + .github/workflows/pr5351-windows.yml | 59 + .github/workflows/release-desktop.yml | 902 --------------- .github/workflows/stale.yml | 37 - .github/workflows/studio-frontend-ci.yml | 151 --- .github/workflows/studio-inference-smoke.yml | 1052 ------------------ .github/workflows/studio-tauri-smoke.yml | 128 --- .github/workflows/wheel-smoke.yml | 136 --- 9 files changed, 176 insertions(+), 2406 deletions(-) create mode 100644 .github/workflows/pr5351-macos.yml create mode 100644 .github/workflows/pr5351-ubuntu.yml create mode 100644 .github/workflows/pr5351-windows.yml delete mode 100644 .github/workflows/release-desktop.yml delete mode 100644 .github/workflows/stale.yml delete mode 100644 .github/workflows/studio-frontend-ci.yml delete mode 100644 .github/workflows/studio-inference-smoke.yml delete mode 100644 .github/workflows/studio-tauri-smoke.yml delete mode 100644 .github/workflows/wheel-smoke.yml diff --git a/.github/workflows/pr5351-macos.yml b/.github/workflows/pr5351-macos.yml new file mode 100644 index 0000000000..6bb149659b --- /dev/null +++ b/.github/workflows/pr5351-macos.yml @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. +# +# PR-5351 cross-OS validation: macOS lane. +# macos-14 (arm64). Validates the multiprocessing `spawn` path that +# differs from Linux's default `fork`, the MLX detection branch in +# core/chat/vlm_capability.py, and Safari/WebKit-relevant filesystem +# behaviour. CPU-only; CUDA spoof auto-engages via tests/conftest.py. 
+ +name: PR-5351 macOS + +on: + push: + branches: [pr-5351-cross-os-validation] + paths: + - 'studio/backend/**' + - 'tests/studio/**' + - 'tests/conftest.py' + - '.github/workflows/pr5351-macos.yml' + workflow_dispatch: + +concurrency: + group: pr5351-macos-${{ github.ref }} + cancel-in-progress: true + +jobs: + pytest: + runs-on: macos-14 + timeout-minutes: 25 + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: '3.11' + cache: 'pip' + + - name: Install backend test dependencies (CPU only) + run: | + python -m pip install --upgrade pip + pip install -r studio/backend/requirements/studio.txt + pip install \ + python-multipart aiofiles sqlalchemy cryptography \ + pyyaml jinja2 mammoth unpdf requests \ + 'numpy<3' pytest pytest-asyncio httpx + pip install --index-url https://download.pytorch.org/whl/cpu 'torch>=2.4,<2.11' + pip install 'transformers>=4.51,<5.5' + + - name: PR-5351 document tests (macOS spawn semantics) + working-directory: studio/backend + env: + # macOS's default start method is spawn; exercise the same + # config users see in production. 
+ UNSLOTH_STUDIO_EXTRACT_CONCURRENCY: '2' + run: | + python -m pytest -q tests/test_chat_document_extraction.py tests/test_chat_document_routes.py tests/test_inference_worker.py tests/test_vision_cache.py tests/test_anthropic_messages.py tests/test_openai_tool_passthrough.py tests/test_models_get_model_config_case_resolution.py --tb=short + + - name: PR-5351 regression tests + cancel timing + run: | + python -m pytest -q tests/studio/test_extractor_semaphore_leak.py tests/studio/test_html_independent_of_inference.py tests/studio/test_gguf_singleton_shared.py tests/studio/test_stream_cancel_registration_timing.py --tb=short diff --git a/.github/workflows/pr5351-ubuntu.yml b/.github/workflows/pr5351-ubuntu.yml new file mode 100644 index 0000000000..d1dd6d8712 --- /dev/null +++ b/.github/workflows/pr5351-ubuntu.yml @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. +# +# PR-5351 cross-OS validation: Ubuntu lane. +# Runs the document-extraction tests, the cancellation-timing structural +# test, and the three regression tests added in the fix commit against +# Python 3.11 on ubuntu-latest. CPU-only; the existing tests/conftest.py +# auto-installs the CUDA spoof so unsloth/unsloth_zoo device probes +# return "cuda". 
+ +name: PR-5351 Ubuntu + +on: + push: + branches: [pr-5351-cross-os-validation] + paths: + - 'studio/backend/**' + - 'tests/studio/**' + - 'tests/conftest.py' + - '.github/workflows/pr5351-ubuntu.yml' + workflow_dispatch: + +concurrency: + group: pr5351-ubuntu-${{ github.ref }} + cancel-in-progress: true + +jobs: + pytest: + runs-on: ubuntu-latest + timeout-minutes: 20 + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: '3.11' + cache: 'pip' + + - name: Install backend test dependencies (CPU only) + run: | + python -m pip install --upgrade pip + pip install -r studio/backend/requirements/studio.txt + pip install \ + python-multipart aiofiles sqlalchemy cryptography \ + pyyaml jinja2 mammoth unpdf requests \ + 'numpy<3' pytest pytest-asyncio httpx + pip install --index-url https://download.pytorch.org/whl/cpu 'torch>=2.4,<2.11' + pip install 'transformers>=4.51,<5.5' + + - name: PR-5351 document tests + working-directory: studio/backend + run: | + python -m pytest -q tests/test_chat_document_extraction.py tests/test_chat_document_routes.py tests/test_inference_worker.py tests/test_vision_cache.py tests/test_anthropic_messages.py tests/test_openai_tool_passthrough.py tests/test_models_get_model_config_case_resolution.py --tb=short + + - name: PR-5351 regression tests + cancel timing + run: | + python -m pytest -q tests/studio/test_extractor_semaphore_leak.py tests/studio/test_html_independent_of_inference.py tests/studio/test_gguf_singleton_shared.py tests/studio/test_stream_cancel_registration_timing.py --tb=short diff --git a/.github/workflows/pr5351-windows.yml b/.github/workflows/pr5351-windows.yml new file mode 100644 index 0000000000..777e1c38ec --- /dev/null +++ b/.github/workflows/pr5351-windows.yml @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. +# +# PR-5351 cross-OS validation: Windows lane. +# windows-latest. 
Validates the multiprocessing `spawn` path +# (mandatory on Windows), path normalisation, and EAGAIN-style +# Process construction failures under load (the exact bug class the +# semaphore-leak fix protects against). CPU-only; CUDA spoof +# auto-engages via tests/conftest.py. + +name: PR-5351 Windows + +on: + push: + branches: [pr-5351-cross-os-validation] + paths: + - 'studio/backend/**' + - 'tests/studio/**' + - 'tests/conftest.py' + - '.github/workflows/pr5351-windows.yml' + workflow_dispatch: + +concurrency: + group: pr5351-windows-${{ github.ref }} + cancel-in-progress: true + +jobs: + pytest: + runs-on: windows-latest + timeout-minutes: 30 + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: '3.11' + cache: 'pip' + + - name: Install backend test dependencies (CPU only) + shell: pwsh + run: | + python -m pip install --upgrade pip + pip install -r studio/backend/requirements/studio.txt + pip install python-multipart aiofiles sqlalchemy cryptography pyyaml jinja2 mammoth unpdf requests "numpy<3" pytest pytest-asyncio httpx + pip install --index-url https://download.pytorch.org/whl/cpu "torch>=2.4,<2.11" + pip install "transformers>=4.51,<5.5" + + - name: PR-5351 document tests (Windows spawn semantics) + working-directory: studio/backend + shell: pwsh + env: + UNSLOTH_STUDIO_EXTRACT_CONCURRENCY: '2' + run: | + python -m pytest -q tests/test_chat_document_extraction.py tests/test_chat_document_routes.py tests/test_inference_worker.py tests/test_vision_cache.py tests/test_anthropic_messages.py tests/test_openai_tool_passthrough.py tests/test_models_get_model_config_case_resolution.py --tb=short + + - name: PR-5351 regression tests + cancel timing + shell: pwsh + run: | + python -m pytest -q tests/studio/test_extractor_semaphore_leak.py tests/studio/test_html_independent_of_inference.py tests/studio/test_gguf_singleton_shared.py tests/studio/test_stream_cancel_registration_timing.py --tb=short diff --git 
a/.github/workflows/release-desktop.yml b/.github/workflows/release-desktop.yml deleted file mode 100644 index e747605322..0000000000 --- a/.github/workflows/release-desktop.yml +++ /dev/null @@ -1,902 +0,0 @@ -name: Release Desktop App - -on: - workflow_dispatch: - inputs: - studio_version: - description: 'Studio version tag to release (for example, v0.1.39-beta)' - type: string - required: true - pypi_version: - description: 'Exact PyPI unsloth version just published/stamped (for example, 2026.5.3); leave blank to use MIN_DESKTOP_BACKEND_VERSION' - type: string - required: false - draft: - description: 'Create as draft release; draft runs do not advance desktop-latest updater channel' - type: boolean - default: true - -permissions: - contents: read - -concurrency: - group: release-desktop-${{ github.repository }} - cancel-in-progress: false - -jobs: - prepare-version: - name: Prepare release versions - runs-on: ubuntu-latest - outputs: - studio_version: ${{ steps.prepare.outputs.studio_version }} - app_version: ${{ steps.prepare.outputs.app_version }} - desktop_release_tag: ${{ steps.prepare.outputs.desktop_release_tag }} - prerelease: ${{ steps.prepare.outputs.prerelease }} - pypi_version: ${{ steps.prepare.outputs.pypi_version }} - - steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd - with: - persist-credentials: false - - - name: Validate release versions - id: prepare - shell: bash - env: - INPUT_STUDIO_VERSION: ${{ inputs.studio_version }} - INPUT_PYPI_VERSION: ${{ inputs.pypi_version }} - run: | - python3 <<'PY' - import os - import pathlib - import re - import sys - - studio_version = os.environ['INPUT_STUDIO_VERSION'].strip() - if not studio_version: - sys.exit('studio_version is required, for example v0.1.39-beta') - if re.fullmatch(r'v?20\d{2}\.\d+\.\d+(?:[-+][0-9A-Za-z.-]+)?', studio_version): - sys.exit(f'studio_version must be a Studio SemVer tag, not a date-style backend version: {studio_version}') - - semver_tag = 
re.compile( - r'^v(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)' - r'(?:-[0-9A-Za-z.][0-9A-Za-z.-]*)?$' - ) - if not semver_tag.fullmatch(studio_version): - sys.exit(f'studio_version must be a SemVer tag with leading v, for example v0.1.39-beta: {studio_version}') - - app_version = studio_version.removeprefix('v') - desktop_release_tag = f'desktop-v{app_version}' - prerelease = 'true' if '-' in app_version.split('+', 1)[0] else 'false' - - def parse_backend_version(version): - match = re.fullmatch( - r'(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)' - r'(?:([a-zA-Z]|\.dev|dev|\.rc|rc|\.post|post)(\d*))?' - r'(?:[-+]([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?', - version, - ) - if not match: - return None - major, minor, patch, suffix_name, suffix_number, suffix_text = match.groups() - if suffix_name: - normalized = suffix_name.lower().lstrip('.') - order = {'dev': 0, 'a': 1, 'b': 2, 'rc': 3, 'post': 5}.get(normalized) - if order is None: - return None - number = int(suffix_number or '0') - elif suffix_text: - order = 3 if version[version.find(suffix_text) - 1] == '-' else 4 - number = 0 - else: - order = 4 - number = 0 - return (int(major), int(minor), int(patch), order, number) - - preflight = pathlib.Path('studio/src-tauri/src/preflight/version.rs').read_text() - match = re.search(r'MIN_DESKTOP_BACKEND_VERSION:\s*&str\s*=\s*"([^"]+)"', preflight) - if not match: - sys.exit('Could not read MIN_DESKTOP_BACKEND_VERSION') - min_backend_version = match.group(1) - - input_pypi_version = os.environ.get('INPUT_PYPI_VERSION', '').strip() - parsed_min_backend = parse_backend_version(min_backend_version) - if parsed_min_backend is None: - sys.exit(f'MIN_DESKTOP_BACKEND_VERSION is not a supported backend package version: {min_backend_version}') - - pypi_version = input_pypi_version or min_backend_version - parsed_pypi = parse_backend_version(pypi_version) - if parsed_pypi is None: - sys.exit(f'pypi_version is not a supported backend package version: {pypi_version}') - if parsed_pypi < 
parsed_min_backend: - sys.exit( - f'pypi_version {pypi_version} is lower than desktop minimum ' - f'MIN_DESKTOP_BACKEND_VERSION {min_backend_version}' - ) - - if input_pypi_version: - print( - 'Using exact PyPI unsloth version from pypi_version input: ' - f'{pypi_version} (desktop minimum: {min_backend_version})' - ) - else: - print( - 'Using exact PyPI unsloth version from MIN_DESKTOP_BACKEND_VERSION: ' - f'{pypi_version}' - ) - - with open(os.environ['GITHUB_OUTPUT'], 'a', encoding='utf-8') as output: - print(f'studio_version={studio_version}', file=output) - print(f'app_version={app_version}', file=output) - print(f'desktop_release_tag={desktop_release_tag}', file=output) - print(f'prerelease={prerelease}', file=output) - print(f'pypi_version={pypi_version}', file=output) - PY - - - name: Verify PyPI package and Studio stamp - shell: bash - env: - STUDIO_VERSION: ${{ steps.prepare.outputs.studio_version }} - PYPI_VERSION: ${{ steps.prepare.outputs.pypi_version }} - run: | - set -euo pipefail - python3 <<'PY' - import json - import os - import pathlib - import sys - import time - import urllib.error - import urllib.request - - pypi_version = os.environ['PYPI_VERSION'] - dist_dir = pathlib.Path(os.environ['RUNNER_TEMP'], 'pypi-unsloth-dist') - dist_dir.mkdir(parents=True, exist_ok=True) - metadata_url = f'https://pypi.org/pypi/unsloth/{pypi_version}/json' - - last_error = None - for attempt in range(1, 6): - try: - with urllib.request.urlopen(metadata_url, timeout=30) as response: - metadata = json.load(response) - break - except Exception as exc: - last_error = exc - if attempt < 5: - time.sleep(10 * attempt) - else: - sys.exit(f'Publish unsloth=={pypi_version} to PyPI before the desktop release ({last_error})') - - files = metadata.get('urls') or [] - if not files: - sys.exit(f'PyPI returned no distribution files for unsloth=={pypi_version}') - - for file_info in files: - filename = file_info.get('filename') - url = file_info.get('url') - if not filename or '/' 
in filename or not url: - sys.exit(f'Unexpected PyPI file entry for unsloth=={pypi_version}: {file_info!r}') - target = dist_dir / filename - for attempt in range(1, 4): - try: - with urllib.request.urlopen(url, timeout=60) as response: - target.write_bytes(response.read()) - break - except Exception as exc: - last_error = exc - if attempt < 3: - time.sleep(5 * attempt) - else: - sys.exit(f'Could not download {filename} from PyPI ({last_error})') - PY - - if [ -f scripts/stamp_studio_release.py ]; then - mapfile -t dists < <(find "$RUNNER_TEMP/pypi-unsloth-dist" -type f \( -name '*.whl' -o -name '*.tar.gz' \) | sort) - if [ "${#dists[@]}" -eq 0 ]; then - echo "No PyPI wheel/sdist artifacts downloaded for unsloth==$PYPI_VERSION" >&2 - exit 1 - fi - python3 scripts/stamp_studio_release.py --verify-dist "$RUNNER_TEMP/pypi-unsloth-dist" --expected "$STUDIO_VERSION" - else - echo "scripts/stamp_studio_release.py not found; release-desktop requires #5308 to verify the PyPI Studio stamp." >&2 - exit 1 - fi - - - name: Guard public updater channel version - if: ${{ !inputs.draft }} - shell: bash - env: - GH_REPO: ${{ github.repository }} - GH_TOKEN: ${{ github.token }} - APP_VERSION: ${{ steps.prepare.outputs.app_version }} - run: | - set -euo pipefail - mkdir -p "$RUNNER_TEMP/desktop-current" - if ! gh release download desktop-latest --pattern latest.json --dir "$RUNNER_TEMP/desktop-current" --clobber 2>/dev/null; then - echo "No existing desktop-latest latest.json found; allowing first channel publish." - exit 0 - fi - python3 <<'PY' - import json - import os - import pathlib - import re - import sys - - def parse(value: str): - value = value.removeprefix('v') - match = re.fullmatch( - r'(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)' - r'(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?' 
- r'(?:\+[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*)?', - value, - ) - if not match: - sys.exit(f'desktop-latest latest.json has invalid version: {value}') - major, minor, patch, prerelease = match.groups() - return (int(major), int(minor), int(patch), prerelease) - - def numeric_tail(identifier: str) -> tuple[str, int] | None: - match = re.fullmatch(r'([A-Za-z-]+)(\d+)', identifier) - if not match: - return None - return (match.group(1).lower(), int(match.group(2))) - - def compare_identifier(left: str, right: str) -> int: - left_num = left.isdigit() - right_num = right.isdigit() - if left_num and right_num: - return (int(left) > int(right)) - (int(left) < int(right)) - if left_num: - return -1 - if right_num: - return 1 - - left_tail = numeric_tail(left) - right_tail = numeric_tail(right) - if left_tail and right_tail and left_tail[0] == right_tail[0]: - return (left_tail[1] > right_tail[1]) - (left_tail[1] < right_tail[1]) - - return (left > right) - (left < right) - - def compare_prerelease(left: str | None, right: str | None) -> int: - if left == right: - return 0 - if left is None: - return 1 - if right is None: - return -1 - left_parts = left.split('.') - right_parts = right.split('.') - for left_part, right_part in zip(left_parts, right_parts): - order = compare_identifier(left_part, right_part) - if order: - return order - return (len(left_parts) > len(right_parts)) - (len(left_parts) < len(right_parts)) - - def compare(left: str, right: str) -> int: - left_major, left_minor, left_patch, left_pre = parse(left) - right_major, right_minor, right_patch, right_pre = parse(right) - left_core = (left_major, left_minor, left_patch) - right_core = (right_major, right_minor, right_patch) - if left_core != right_core: - return (left_core > right_core) - (left_core < right_core) - return compare_prerelease(left_pre, right_pre) - - current_path = pathlib.Path(os.environ['RUNNER_TEMP'], 'desktop-current', 'latest.json') - current = 
json.loads(current_path.read_text()).get('version') - next_version = os.environ['APP_VERSION'] - if not isinstance(current, str): - sys.exit('desktop-latest latest.json has missing version') - if compare(next_version, current) < 0: - sys.exit( - f'Refusing to publish {next_version}; desktop-latest currently points at newer version {current}.' - ) - PY - - build: - # TODO: split into a "build (no secrets)" + "publish (secrets)" job pair - # with actions/upload-artifact handoff so the matrix build cannot - # publish a Release on its own. The current matrix runs across - # Linux/macOS/Windows in a single job, so the split needs artefact - # collection across the OS matrix and is out of scope for this - # hardening pass. - permissions: - contents: write # tauri-apps/tauri-action creates / uploads a GitHub Release - strategy: - fail-fast: false - max-parallel: 1 - matrix: - include: - - platform: macos-latest - args: '--target aarch64-apple-darwin' - label: macOS (Apple Silicon) - # - platform: macos-latest - # args: '--target x86_64-apple-darwin' - # label: macOS (Intel) - - platform: ubuntu-22.04 - args: '' - label: Linux (x64) - - platform: windows-latest - args: '' - label: Windows (x64) - - name: Build ${{ matrix.label }} - needs: prepare-version - runs-on: ${{ matrix.platform }} - - env: - FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true - APP_VERSION: ${{ needs.prepare-version.outputs.app_version }} - STUDIO_VERSION: ${{ needs.prepare-version.outputs.studio_version }} - DESKTOP_RELEASE_TAG: ${{ needs.prepare-version.outputs.desktop_release_tag }} - DESKTOP_PRERELEASE: ${{ needs.prepare-version.outputs.prerelease }} - - steps: - # harden-runner in audit mode: surfaces every egress destination in - # the runner log so the allowlist for a future `egress-policy: block` - # promotion can be derived from observed traffic. 
Audit mode is - # cross-platform (Linux / macOS / Windows runners); blocking mode is - # currently Linux-only, so we deliberately stay in audit until the - # macOS + Windows codesign paths have been observed. - - name: Harden runner (audit) - uses: step-security/harden-runner@a5ad31d6a139d249332a2605b85202e8c0b78450 # v2.19.1 - with: - egress-policy: audit - - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd - with: - persist-credentials: false - - # ── Linux dependencies ── - - name: Install Linux dependencies - if: matrix.platform == 'ubuntu-22.04' - run: | - sudo apt-get update - sudo apt-get install -y libwebkit2gtk-4.1-dev libayatana-appindicator3-dev librsvg2-dev libxdo-dev libssl-dev patchelf - - # ── Node.js ── - - name: Setup Node.js - uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e - with: - node-version: 24 - - - name: Install pinned Tauri CLI - # Lifecycle scripts (esbuild native-binary postinstall, etc.) are - # required for `vite build`. The pre-install lockfile structural - # audit (lockfile_supply_chain_audit.py) is the practical defence - # against the npm postinstall-dropper class -- it fires BEFORE any - # tarball runs, on the injection pattern itself rather than an - # advisory-DB lookup. 
- run: npm install --save-dev --prefix studio @tauri-apps/cli@2.10.1 --no-fund --no-audit - - - name: Verify pinned Tauri CLI - shell: bash - run: | - out="$(npx --prefix studio tauri --version)" - echo "$out" - if [ "$out" != "tauri-cli 2.10.1" ]; then - echo "Expected tauri-cli 2.10.1, got $out" >&2 - exit 1 - fi - - - name: Verify desktop updater and Linux package config - shell: bash - run: | - node <<'JS' - const { readFileSync } = require('node:fs'); - - const expected = 'https://github.com/unslothai/unsloth/releases/download/desktop-latest/latest.json'; - const config = JSON.parse(readFileSync('studio/src-tauri/tauri.conf.json', 'utf8')); - const endpoints = config.plugins?.updater?.endpoints; - if (!Array.isArray(endpoints) || endpoints.length !== 1) { - throw new Error('Expected exactly one desktop updater endpoint'); - } - if (endpoints[0] !== expected) { - throw new Error('Desktop updater endpoint must be ' + expected + ', got ' + endpoints[0]); - } - if (endpoints.some((endpoint) => endpoint.includes('/releases/latest/'))) { - throw new Error('Desktop updater endpoint must not use repo-wide /releases/latest/'); - } - - const targets = config.bundle?.targets; - if (Array.isArray(targets) && targets.some((target) => String(target).toLowerCase() === 'rpm')) { - throw new Error('Desktop release must not target RPM packages'); - } - if (config.bundle?.linux?.rpm) { - throw new Error('bundle.linux.rpm must not be configured'); - } - - const workflow = readFileSync('.github/workflows/release-desktop.yml', 'utf8'); - const lines = workflow.split(/\r?\n/); - const releaseBodies = []; - for (let i = 0; i < lines.length; i += 1) { - const match = lines[i].match(/^(\s*)releaseBody:\s*\|\s*$/); - if (!match) continue; - const baseIndent = match[1].length; - const bodyLines = []; - i += 1; - for (; i < lines.length; i += 1) { - const line = lines[i]; - if (line.trim() === '') { - bodyLines.push(''); - continue; - } - const indent = line.match(/^\s*/)[0].length; - if 
(indent <= baseIndent) { - i -= 1; - break; - } - bodyLines.push(line.slice(baseIndent + 2)); - } - releaseBodies.push(bodyLines.join('\n')); - } - if (releaseBodies.length === 0) { - throw new Error('Expected at least one desktop release body'); - } - for (const body of releaseBodies) { - if (/\brpm\b|\.rpm/i.test(body)) { - throw new Error('Desktop release body must not advertise RPM packages'); - } - } - JS - - - name: Install frontend dependencies - working-directory: studio/frontend - # Lifecycle scripts (esbuild native-binary postinstall, etc.) are - # required for `vite build`. The pre-install lockfile structural - # audit (lockfile_supply_chain_audit.py) is the practical defence - # against the npm postinstall-dropper class -- it fires BEFORE any - # tarball runs, on the injection pattern itself rather than an - # advisory-DB lookup. - run: npm install --no-fund --no-audit - - # ── Rust ── - - name: Install Rust stable - uses: dtolnay/rust-toolchain@29eef336d9b2848a0b548edc03f92a220660cdb8 # stable @ 2026-03-27 - with: - targets: ${{ matrix.platform == 'macos-latest' && 'aarch64-apple-darwin,x86_64-apple-darwin' || '' }} - - - name: Patch desktop app version - shell: bash - working-directory: studio/src-tauri - run: | - set -euo pipefail - if command -v python3 >/dev/null 2>&1; then - PYTHON=python3 - else - PYTHON=python - fi - "$PYTHON" <<'PY' - import os - import pathlib - import re - import sys - - app_version = os.environ['APP_VERSION'] - if not app_version: - sys.exit('APP_VERSION is required') - - cargo_toml = pathlib.Path('Cargo.toml') - lines = cargo_toml.read_text().splitlines(keepends=True) - in_package = False - patched = False - for index, line in enumerate(lines): - stripped = line.strip() - if stripped == '[package]': - in_package = True - continue - if stripped.startswith('[') and stripped.endswith(']'): - in_package = False - if in_package and re.fullmatch(r'version\s*=\s*"[^"]+"\s*', stripped): - lines[index] = f'version = 
"{app_version}"\n' - patched = True - break - if not patched: - sys.exit('Could not patch [package] version in Cargo.toml') - cargo_toml.write_text(''.join(lines)) - - cargo_lock = pathlib.Path('Cargo.lock') - lock_text = cargo_lock.read_text() - lock_text, count = re.subn( - r'(?m)(^\[\[package\]\]\nname = "unsloth-studio"\nversion = ")[^"]+(")', - lambda match: f'{match.group(1)}{app_version}{match.group(2)}', - lock_text, - ) - if count != 1: - sys.exit(f'Could not patch unsloth-studio version in Cargo.lock (matches={count})') - cargo_lock.write_text(lock_text) - PY - - cargo metadata --locked --no-deps --format-version 1 > "$RUNNER_TEMP/cargo-metadata.json" - "$PYTHON" <<'PY' - import json - import os - import pathlib - import sys - - app_version = os.environ['APP_VERSION'] - metadata = json.loads(pathlib.Path(os.environ['RUNNER_TEMP'], 'cargo-metadata.json').read_text()) - versions = [package['version'] for package in metadata.get('packages', []) if package.get('name') == 'unsloth-studio'] - if versions != [app_version]: - sys.exit(f'cargo metadata unsloth-studio version mismatch: expected {app_version}, got {versions}') - PY - - git diff -- Cargo.toml Cargo.lock - - - name: Rust cache - uses: swatinem/rust-cache@e18b497796c12c097a38f9edb9d0641fb99eee32 - with: - workspaces: 'studio/src-tauri -> target' - - # ── macOS: import signing certificate ── - - name: Import Apple certificate - if: matrix.platform == 'macos-latest' - env: - APPLE_CERTIFICATE: ${{ secrets.APPLE_CERTIFICATE }} - APPLE_CERTIFICATE_PASSWORD: ${{ secrets.APPLE_CERTIFICATE_PASSWORD }} - KEYCHAIN_PASSWORD: ${{ secrets.KEYCHAIN_PASSWORD }} - run: | - echo $APPLE_CERTIFICATE | base64 --decode > certificate.p12 - security create-keychain -p "$KEYCHAIN_PASSWORD" build.keychain - security default-keychain -s build.keychain - security unlock-keychain -p "$KEYCHAIN_PASSWORD" build.keychain - security set-keychain-settings -t 3600 -u build.keychain - security import certificate.p12 -k build.keychain 
-P "$APPLE_CERTIFICATE_PASSWORD" -T /usr/bin/codesign - security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k "$KEYCHAIN_PASSWORD" build.keychain - security find-identity -v -p codesigning build.keychain - rm -f certificate.p12 - - # ── Windows: install Azure Trusted Signing CLI ── - - name: Install trusted-signing-cli - if: matrix.platform == 'windows-latest' - run: | - cargo install trusted-signing-cli --version 0.10.0 --locked - echo "$env:USERPROFILE\.cargo\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - - # ── Windows: verify signing CLI is accessible ── - - name: Verify trusted-signing-cli - if: matrix.platform == 'windows-latest' - run: | - Write-Output "PATH: $env:PATH" - Get-Command trusted-signing-cli -ErrorAction SilentlyContinue || Write-Output "trusted-signing-cli NOT in PATH" - trusted-signing-cli --version || Write-Output "trusted-signing-cli failed to run" - - # ── Linux: build + sign + upload ── - - name: Build Linux app - if: matrix.platform == 'ubuntu-22.04' - uses: tauri-apps/tauri-action@84b9d35b5fc46c1e45415bdb6144030364f7ebc5 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - TAURI_SIGNING_PRIVATE_KEY: ${{ secrets.TAURI_SIGNING_PRIVATE_KEY }} - TAURI_SIGNING_PRIVATE_KEY_PASSWORD: ${{ secrets.TAURI_SIGNING_PRIVATE_KEY_PASSWORD }} - with: - projectPath: studio - tauriScript: npx --prefix . tauri - tagName: ${{ needs.prepare-version.outputs.desktop_release_tag }} - releaseName: 'Unsloth Studio (Desktop) ${{ needs.prepare-version.outputs.studio_version }}' - releaseBody: | - Desktop app for Unsloth Studio. - - **macOS**: Download the Apple Silicon `.dmg`. - **Windows**: Download the `-setup.exe` installer. - **Linux**: Download `.deb` (Ubuntu/Debian) or `.AppImage` (universal). - - > Linux in-app updates are AppImage-oriented. Package installs should update by downloading a new package. 
- > Linux AppImage on Ubuntu 24.04+ may require: `sudo apt install libfuse2t64` - > First-run system dependency elevation is supported on Ubuntu/Debian. Other Linux distributions should install system packages manually. - releaseDraft: ${{ inputs.draft }} - prerelease: ${{ needs.prepare-version.outputs.prerelease }} - args: -v ${{ matrix.args }} - - # ── macOS: build + sign + notarize + upload ── - - name: Build macOS app - if: matrix.platform == 'macos-latest' - uses: tauri-apps/tauri-action@84b9d35b5fc46c1e45415bdb6144030364f7ebc5 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - TAURI_SIGNING_PRIVATE_KEY: ${{ secrets.TAURI_SIGNING_PRIVATE_KEY }} - TAURI_SIGNING_PRIVATE_KEY_PASSWORD: ${{ secrets.TAURI_SIGNING_PRIVATE_KEY_PASSWORD }} - APPLE_SIGNING_IDENTITY: ${{ secrets.APPLE_SIGNING_IDENTITY }} - APPLE_ID: ${{ secrets.APPLE_ID }} - APPLE_PASSWORD: ${{ secrets.APPLE_PASSWORD }} - APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }} - with: - projectPath: studio - tauriScript: npx --prefix . tauri - tagName: ${{ needs.prepare-version.outputs.desktop_release_tag }} - releaseName: 'Unsloth Studio (Desktop) ${{ needs.prepare-version.outputs.studio_version }}' - releaseBody: | - Desktop app for Unsloth Studio. - - **macOS**: Download the Apple Silicon `.dmg`. - **Windows**: Download the `-setup.exe` installer. - **Linux**: Download `.deb` (Ubuntu/Debian) or `.AppImage` (universal). - - > Linux in-app updates are AppImage-oriented. Package installs should update by downloading a new package. - > Linux AppImage on Ubuntu 24.04+ may require: `sudo apt install libfuse2t64` - > First-run system dependency elevation is supported on Ubuntu/Debian. Other Linux distributions should install system packages manually. 
- releaseDraft: ${{ inputs.draft }} - prerelease: ${{ needs.prepare-version.outputs.prerelease }} - args: -v ${{ matrix.args }} - - # ── Windows: build + sign + upload ── - - name: Build Windows app - if: matrix.platform == 'windows-latest' - uses: tauri-apps/tauri-action@84b9d35b5fc46c1e45415bdb6144030364f7ebc5 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - TAURI_SIGNING_PRIVATE_KEY: ${{ secrets.TAURI_SIGNING_PRIVATE_KEY }} - TAURI_SIGNING_PRIVATE_KEY_PASSWORD: ${{ secrets.TAURI_SIGNING_PRIVATE_KEY_PASSWORD }} - AZURE_CLIENT_ID: ${{ secrets.AZURE_CLIENT_ID }} - AZURE_CLIENT_SECRET: ${{ secrets.AZURE_CLIENT_SECRET }} - AZURE_TENANT_ID: ${{ secrets.AZURE_TENANT_ID }} - AZURE_TRUSTED_SIGNING_ACCOUNT_NAME: ${{ secrets.AZURE_TRUSTED_SIGNING_ACCOUNT_NAME }} - AZURE_CERTIFICATE_PROFILE_NAME: ${{ secrets.AZURE_CERTIFICATE_PROFILE_NAME }} - with: - projectPath: studio - tauriScript: npx --prefix . tauri - tagName: ${{ needs.prepare-version.outputs.desktop_release_tag }} - releaseName: 'Unsloth Studio (Desktop) ${{ needs.prepare-version.outputs.studio_version }}' - releaseBody: | - Desktop app for Unsloth Studio. - - **macOS**: Download the Apple Silicon `.dmg`. - **Windows**: Download the `-setup.exe` installer. - **Linux**: Download `.deb` (Ubuntu/Debian) or `.AppImage` (universal). - - > Linux in-app updates are AppImage-oriented. Package installs should update by downloading a new package. - > Linux AppImage on Ubuntu 24.04+ may require: `sudo apt install libfuse2t64` - > First-run system dependency elevation is supported on Ubuntu/Debian. Other Linux distributions should install system packages manually. - releaseDraft: ${{ inputs.draft }} - prerelease: ${{ needs.prepare-version.outputs.prerelease }} - args: -v ${{ matrix.args }} - - # Release process note: only non-draft workflow runs advance the public - # desktop-latest updater channel. 
Draft builds are for private review; if a - # draft is manually published later, this channel intentionally remains - # unchanged until a narrow manual channel-publish flow is added or a public - # desktop release is created by running this workflow with draft=false. - publish-updater-channel: - name: Publish desktop updater channel - needs: [prepare-version, build] - if: ${{ !inputs.draft }} - runs-on: ubuntu-latest - permissions: - contents: write - env: - GH_REPO: ${{ github.repository }} - APP_VERSION: ${{ needs.prepare-version.outputs.app_version }} - STUDIO_VERSION: ${{ needs.prepare-version.outputs.studio_version }} - DESKTOP_RELEASE_TAG: ${{ needs.prepare-version.outputs.desktop_release_tag }} - DESKTOP_PRERELEASE: ${{ needs.prepare-version.outputs.prerelease }} - - steps: - - name: Download versioned updater metadata - shell: bash - env: - GH_TOKEN: ${{ github.token }} - run: | - set -euo pipefail - mkdir -p "$RUNNER_TEMP/desktop-updater" - gh api "repos/${GITHUB_REPOSITORY}/releases/tags/${DESKTOP_RELEASE_TAG}" > "$RUNNER_TEMP/source-release.json" - python3 <<'PY' - import json - import os - import pathlib - import sys - - source = json.loads(pathlib.Path(os.environ['RUNNER_TEMP'], 'source-release.json').read_text()) - expected_tag = os.environ['DESKTOP_RELEASE_TAG'] - if source.get('tag_name') != expected_tag: - sys.exit(f'Expected source release {expected_tag}, got {source.get("tag_name")}') - if source.get('draft'): - sys.exit(f'Source desktop release {expected_tag} is draft; refusing to publish public updater channel') - PY - gh release download "$DESKTOP_RELEASE_TAG" --pattern latest.json --dir "$RUNNER_TEMP/desktop-updater" --clobber - test -s "$RUNNER_TEMP/desktop-updater/latest.json" - - - name: Validate versioned updater metadata - shell: bash - run: | - python3 <<'PY' - import json - import os - import pathlib - import re - import sys - - app_version = os.environ['APP_VERSION'] - release_tag = os.environ['DESKTOP_RELEASE_TAG'] - latest_path = 
pathlib.Path(os.environ['RUNNER_TEMP'], 'desktop-updater', 'latest.json') - data = json.loads(latest_path.read_text()) - if not isinstance(data, dict): - sys.exit('latest.json must be a JSON object') - - version = data.get('version') - if not isinstance(version, str) or not version: - sys.exit('latest.json missing version') - if not re.fullmatch(r'v?\d+\.\d+\.\d+(?:[-+][0-9A-Za-z.-]+)?', version): - sys.exit(f'latest.json version is not SemVer-like: {version}') - if version.removeprefix('v') != app_version: - sys.exit(f'latest.json version {version} does not match desktop app version {app_version}') - - platforms = data.get('platforms') - if not isinstance(platforms, dict) or not platforms: - sys.exit('latest.json missing platforms') - - required_families = { - 'darwin-aarch64': False, - 'linux-x86_64': False, - 'windows-x86_64': False, - } - expected_prefix = f'https://github.com/unslothai/unsloth/releases/download/{release_tag}/' - forbidden_fragments = ('/releases/latest/', '/releases/download/desktop-latest/') - - for platform, entry in platforms.items(): - if not isinstance(entry, dict): - sys.exit(f'Platform {platform} must be an object') - url = entry.get('url') - signature = entry.get('signature') - if not isinstance(url, str) or not url.strip(): - sys.exit(f'Platform {platform} missing url') - if not isinstance(signature, str) or not signature.strip(): - sys.exit(f'Platform {platform} missing signature') - if any(fragment in url for fragment in forbidden_fragments): - sys.exit(f'Platform {platform} points at a moving updater channel: {url}') - if not url.startswith(expected_prefix): - sys.exit(f'Platform {platform} URL must point at {release_tag}: {url}') - for family in required_families: - if platform == family or platform.startswith(family + '-'): - required_families[family] = True - - missing = [family for family, found in required_families.items() if not found] - if missing: - sys.exit('latest.json missing required platform families: ' + ', 
'.join(missing)) - PY - - - name: Ensure desktop updater channel release - shell: bash - env: - GH_TOKEN: ${{ github.token }} - run: | - set -euo pipefail - channel_json="$RUNNER_TEMP/desktop-latest-release.json" - if ! gh api "repos/${GITHUB_REPOSITORY}/releases/tags/desktop-latest" > "$channel_json" 2>/dev/null; then - gh release create desktop-latest \ - --title "Unsloth Studio Desktop updater channel" \ - --notes "Machine-managed desktop updater channel; latest.json is replaced by release-desktop.yml." \ - --prerelease \ - --latest=false \ - --target "$GITHUB_SHA" - gh api "repos/${GITHUB_REPOSITORY}/releases/tags/desktop-latest" > "$channel_json" - fi - - python3 <<'PY' - import json - import os - import pathlib - import sys - - channel = json.loads(pathlib.Path(os.environ['RUNNER_TEMP'], 'desktop-latest-release.json').read_text()) - if channel.get('draft'): - sys.exit('desktop-latest release is draft; refusing to publish updater channel') - if channel.get('immutable'): - sys.exit('desktop-latest release is immutable; cannot replace latest.json') - if not channel.get('prerelease'): - sys.exit('desktop-latest release must be a prerelease so it cannot compete with repo-wide latest') - PY - - - name: Prevent updater channel downgrade - shell: bash - env: - GH_TOKEN: ${{ github.token }} - run: | - set -euo pipefail - mkdir -p "$RUNNER_TEMP/desktop-current" - if ! gh release download desktop-latest --pattern latest.json --dir "$RUNNER_TEMP/desktop-current" --clobber 2>/dev/null; then - echo "No existing desktop-latest latest.json found; allowing first channel publish." - exit 0 - fi - python3 <<'PY' - import json - import os - import pathlib - import re - import sys - - def parse(value: str): - value = value.removeprefix('v') - match = re.fullmatch( - r'(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)' - r'(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?' 
- r'(?:\+[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*)?', - value, - ) - if not match: - sys.exit(f'desktop-latest latest.json has invalid version: {value}') - major, minor, patch, prerelease = match.groups() - return (int(major), int(minor), int(patch), prerelease) - - def numeric_tail(identifier: str) -> tuple[str, int] | None: - match = re.fullmatch(r'([A-Za-z-]+)(\d+)', identifier) - if not match: - return None - return (match.group(1).lower(), int(match.group(2))) - - def compare_identifier(left: str, right: str) -> int: - left_num = left.isdigit() - right_num = right.isdigit() - if left_num and right_num: - return (int(left) > int(right)) - (int(left) < int(right)) - if left_num: - return -1 - if right_num: - return 1 - - left_tail = numeric_tail(left) - right_tail = numeric_tail(right) - if left_tail and right_tail and left_tail[0] == right_tail[0]: - return (left_tail[1] > right_tail[1]) - (left_tail[1] < right_tail[1]) - - return (left > right) - (left < right) - - def compare_prerelease(left: str | None, right: str | None) -> int: - if left == right: - return 0 - if left is None: - return 1 - if right is None: - return -1 - left_parts = left.split('.') - right_parts = right.split('.') - for left_part, right_part in zip(left_parts, right_parts): - order = compare_identifier(left_part, right_part) - if order: - return order - return (len(left_parts) > len(right_parts)) - (len(left_parts) < len(right_parts)) - - def compare(left: str, right: str) -> int: - left_major, left_minor, left_patch, left_pre = parse(left) - right_major, right_minor, right_patch, right_pre = parse(right) - left_core = (left_major, left_minor, left_patch) - right_core = (right_major, right_minor, right_patch) - if left_core != right_core: - return (left_core > right_core) - (left_core < right_core) - return compare_prerelease(left_pre, right_pre) - - current_path = pathlib.Path(os.environ['RUNNER_TEMP'], 'desktop-current', 'latest.json') - next_path = pathlib.Path(os.environ['RUNNER_TEMP'], 
'desktop-updater', 'latest.json') - current = json.loads(current_path.read_text()).get('version') - next_version = json.loads(next_path.read_text()).get('version') - if not isinstance(current, str) or not isinstance(next_version, str): - sys.exit('Could not compare desktop-latest channel versions') - if compare(next_version, current) < 0: - sys.exit( - f'Refusing to move desktop-latest from {current} to older version {next_version}.' - ) - PY - - - name: Publish desktop updater channel metadata - shell: bash - env: - GH_TOKEN: ${{ github.token }} - run: | - set -euo pipefail - gh release upload desktop-latest "$RUNNER_TEMP/desktop-updater/latest.json" --clobber - gh api "repos/${GITHUB_REPOSITORY}/releases/tags/desktop-latest" > "$RUNNER_TEMP/desktop-latest-release.json" - python3 <<'PY' - import json - import os - import pathlib - import sys - - channel = json.loads(pathlib.Path(os.environ['RUNNER_TEMP'], 'desktop-latest-release.json').read_text()) - assets = [asset for asset in channel.get('assets', []) if asset.get('name') == 'latest.json'] - if len(assets) != 1: - sys.exit(f'Expected exactly one desktop-latest latest.json asset, found {len(assets)}') - expected_url = f'https://github.com/{os.environ["GITHUB_REPOSITORY"]}/releases/download/desktop-latest/latest.json' - actual_url = assets[0].get('browser_download_url') - if actual_url != expected_url: - sys.exit(f'desktop-latest latest.json URL mismatch: expected {expected_url}, got {actual_url}') - PY diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml deleted file mode 100644 index 1a4cf841d0..0000000000 --- a/.github/workflows/stale.yml +++ /dev/null @@ -1,37 +0,0 @@ -name: 'Inactive Issue Pinger' - -on: - schedule: - - cron: '30 5 * * *' # Runs at 5:30 UTC every day - -jobs: - stale: - runs-on: ubuntu-latest - permissions: - issues: write - - steps: - - uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10.2.0 - with: - # The message to post on stale issues. 
- # This message will ping the issue author. - # Note: The stale bot action does not currently support a direct placeholder for the last commenter. - # As a workaround, this message encourages any participant to reply. - stale-issue-message: > - Is this issue still important to you? - Apologies in advance we might have missed this issue as well. - For faster response times, please post on our Reddit server - https://www.reddit.com/r/unsloth or our Discord - https://discord.com/invite/unsloth - - # The number of days of inactivity before an issue is considered stale. - days-before-issue-stale: 9999 - - # Set to -1 to never close stale issues. - days-before-issue-close: -1 - - # A label to apply to stale issues. - stale-issue-label: 'inactive' - - # The number of operations to perform per run to avoid rate limiting. - operations-per-run: 500 - - enable-statistics: false diff --git a/.github/workflows/studio-frontend-ci.yml b/.github/workflows/studio-frontend-ci.yml deleted file mode 100644 index 1270a57ef6..0000000000 --- a/.github/workflows/studio-frontend-ci.yml +++ /dev/null @@ -1,151 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. - -# Frontend PR gate: lockfile freshness, typecheck, build, and a bundle grep -# that catches the 2026.5.1 chat-history regression at the JS level. -# -# biome runs as non-blocking for now: the codebase currently has accumulated -# ~470 errors and ~1650 warnings against the existing biome config. Surfacing -# the count in CI lets us drive it down without forcing a fleet-wide cleanup -# in the same PR. Drop `continue-on-error` once that number is zero. 
- -name: Frontend CI - -on: - pull_request: - paths: - - 'studio/frontend/**' - - 'scripts/check_frontend_dep_removal.py' - - 'tests/studio/test_frontend_dep_removal.py' - - '.github/workflows/studio-frontend-ci.yml' - push: - branches: [main, pip] - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -permissions: - contents: read - -jobs: - build: - name: Frontend build + bundle sanity - runs-on: ubuntu-latest - timeout-minutes: 10 - defaults: - run: - working-directory: studio/frontend - steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - persist-credentials: false - - # FIXME: drop this step once @assistant-ui/* and assistant-stream - # leave 0.x -- on 1.x, caret ranges are conventional. Until then, - # every 0.minor on this surface is a SemVer-major (this is exactly - # how 2026.5.1 shipped a broken chat runtime: ^0.12.19 quietly - # resolved to 0.12.28). - - name: '@assistant-ui must be pinned exactly (no caret/tilde)' - working-directory: ${{ github.workspace }} - run: | - set -e - if grep -nE '"(@assistant-ui/[a-z-]+|assistant-stream)":[[:space:]]*"[\^~]' studio/frontend/package.json; then - echo "::error file=studio/frontend/package.json::These packages must be pinned to exact versions until they leave 0.x. Drop the leading ^ or ~." - exit 1 - fi - echo "All assistant-ui packages are pinned exactly." - - - uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0 - with: - node-version: '22' - - # Run the structural lockfile scan BEFORE npm ci. A compromised - # tarball runs its `prepare` / `postinstall` during `npm ci`, - # so any catch has to fire upstream of that. The scanner is - # pure-Python read-only; safe to call ahead of every install. 
- - name: Lockfile supply-chain audit (pre-install scan) - working-directory: ${{ github.workspace }} - run: python3 scripts/lockfile_supply_chain_audit.py - - - name: Lockfile must agree with package.json (npm ci is strict) - # Lifecycle scripts (esbuild native-binary postinstall, etc.) are - # required for `vite build`. The pre-install lockfile structural - # audit (lockfile_supply_chain_audit.py) is the practical defence - # against the npm postinstall-dropper class -- it fires BEFORE any - # tarball runs, on the injection pattern itself rather than an - # advisory-DB lookup. - run: npm ci --no-fund --no-audit - - - name: npm ci must not have modified the working tree - working-directory: ${{ github.workspace }} - run: | - if ! git diff --quiet -- studio/frontend; then - echo "::error::npm ci modified files; commit the updated lockfile" - git status -- studio/frontend - exit 1 - fi - - # Catch the common foot-gun: a dep dropped from package.json that is - # still imported somewhere. The script walks the lockfile dep graph - # from the new top-level deps and only counts top-level node_modules - # paths as valid resolution targets for bare src/ imports. - # - # actions/checkout uses fetch-depth: 1 by default, so the base branch - # is not available locally. Fetch the single base commit with an - # explicit refspec so origin/ is reliably created (a bare - # `git fetch origin ` only updates FETCH_HEAD in some configs). 
- - name: Dependency removal safety check - if: github.event_name == 'pull_request' - working-directory: ${{ github.workspace }} - run: | - git fetch --no-tags --depth=1 origin \ - "${{ github.base_ref }}:refs/remotes/origin/${{ github.base_ref }}" - python3 scripts/check_frontend_dep_removal.py \ - --base "origin/${{ github.base_ref }}" \ - --enumerate-dead - python3 tests/studio/test_frontend_dep_removal.py - - - name: Typecheck - run: npm run typecheck - - - name: Build - run: npm run build - - - name: Built bundle must not contain Studio's unstable_Provider call site - run: | - set -e - JS=$(ls dist/assets/index-*.js | head -1) - HITS=$(grep -c 'unstable_Provider:' "$JS" || echo 0) - echo "main bundle: $JS" - echo "unstable_Provider: hits=$HITS (assistant-ui internals contribute up to 3)" - if [ "$HITS" -gt 3 ]; then - echo "::error file=studio/frontend/src/features/chat/runtime-provider.tsx::Studio bundle still passes unstable_Provider through useRemoteThreadListRuntime; this is the 2026.5.1 chat-history regression. Pass adapters directly into useLocalRuntime instead." - exit 1 - fi - - - name: Bundle size budget (75 MB) - run: | - SIZE=$(du -sb dist | cut -f1) - BUDGET=$((75 * 1024 * 1024)) - echo "dist size: $SIZE bytes ($((SIZE/1024/1024)) MB), budget: $BUDGET bytes (75 MB)" - if [ "$SIZE" -gt "$BUDGET" ]; then - echo "::error::studio/frontend/dist/ exceeded the 75 MB budget. Drop dead deps (e.g. the unused next dep) or split chunks." - exit 1 - fi - - - name: Biome (non-blocking until accumulated drift is cleared) - continue-on-error: true - run: npm run biome:check - - - name: Upload built dist - # Always upload so a green run is reviewable too -- the dist - # output catches "tests passed but bundle changed unexpectedly" - # regressions that would be invisible if we only kept artifacts - # on failure. 
- if: always() - uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 - with: - name: studio-frontend-dist - path: studio/frontend/dist - retention-days: 3 diff --git a/.github/workflows/studio-inference-smoke.yml b/.github/workflows/studio-inference-smoke.yml deleted file mode 100644 index 6def56f769..0000000000 --- a/.github/workflows/studio-inference-smoke.yml +++ /dev/null @@ -1,1052 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. - -# Three end-to-end smoke jobs that boot a freshly-installed Studio and -# exercise the surfaces real users hit through the OpenAI / Anthropic -# SDKs and curl. Each job picks the smallest model that exercises the -# behaviour under test, primes HF_HOME via actions/cache, and shares -# the install.sh --local --no-torch bootstrap. -# -# 1. OpenAI, Anthropic API tests -# gemma-3-270m-it UD-Q4_K_XL (~254 MiB). -# Password rotation via /api/auth/change-password (old fails, -# new works), then OpenAI + Anthropic Python SDKs against /v1/* -# with temperature=0 and a fixed seed. Asserts the four-turn -# conversation is deterministic across two runs. -# -# 2. Tool calling Tests -# Qwen3.5-2B UD-IQ3_XXS (~890 MiB). OpenAI function calling, -# server-side tools (python, terminal, web_search) via -# enable_tools / enabled_tools, and enable_thinking on/off. -# -# 3. JSON, images -# gemma-4-E2B-it UD-IQ3_XXS (~2.4 GiB) + mmproj-F16 (~986 MiB). -# response_format JSON-schema decoding and OpenAI image_url -# (data URI) plus Anthropic source/base64 image inputs. -# -# All three jobs run in parallel. Total wall time is dominated by job 3 -# on a cold cache; warm cache cuts that to ~3 min. 
- -name: Studio GGUF CI - -on: - pull_request: - paths: - - 'studio/**' - - 'unsloth/**' - - 'unsloth_cli/**' - - 'install.sh' - - 'pyproject.toml' - - '.github/workflows/studio-inference-smoke.yml' - push: - branches: [main, pip] - # Manual trigger for pre-warming HF_HOME caches on main, or re-running - # against an arbitrary branch without pushing a no-op commit. - workflow_dispatch: - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -permissions: - contents: read - -jobs: - # ───────────────────────────────────────────────────────────────────── - # Job 1: OpenAI, Anthropic API tests - # ───────────────────────────────────────────────────────────────────── - openai-anthropic: - name: OpenAI, Anthropic API tests - runs-on: ubuntu-latest - timeout-minutes: 25 - env: - GGUF_REPO: unsloth/gemma-3-270m-it-GGUF - GGUF_VARIANT: UD-Q4_K_XL - GGUF_FILE: gemma-3-270m-it-UD-Q4_K_XL.gguf - STUDIO_PORT: '18888' - HF_HOME: ${{ github.workspace }}/hf-cache - steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - persist-credentials: false - - - name: Linux deps for llama.cpp prebuilt - run: | - sudo apt-get update - sudo apt-get install -y --no-install-recommends \ - libcurl4-openssl-dev libssl-dev jq - - - uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0 - with: - node-version: '22' - - - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 - with: - python-version: '3.12' - cache: 'pip' - - - name: Restore HF_HOME for ${{ env.GGUF_REPO }} - id: cache-hf - uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 - continue-on-error: true - with: - path: hf-cache - key: ${{ runner.os }}-hf-${{ env.GGUF_REPO }}-${{ env.GGUF_VARIANT }}-v1 - - - name: Prime HF_HOME with the GGUF - id: prime-hf - if: steps.cache-hf.outputs.cache-hit != 'true' || steps.cache-hf.outcome != 'success' - env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} - run: 
| - python -m pip install --upgrade huggingface_hub - mkdir -p hf-cache - bash .github/scripts/hf-download-with-retry.sh "$GGUF_REPO" "$GGUF_FILE" - - - name: Save HF_HOME for ${{ env.GGUF_REPO }} - if: always() && steps.prime-hf.outcome == 'success' - uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 - with: - path: hf-cache - key: ${{ runner.os }}-hf-${{ env.GGUF_REPO }}-${{ env.GGUF_VARIANT }}-v1 - - - name: Install Studio (--local, --no-torch) - env: - GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: | - mkdir -p logs - set -o pipefail - bash install.sh --local --no-torch 2>&1 | tee logs/install.log - - - name: Install OpenAI + Anthropic Python SDKs - run: pip install 'openai>=1.50' 'anthropic>=0.40' - - - name: Reset auth + boot Studio (API-only) - run: | - unsloth studio reset-password - mkdir -p logs - UNSLOTH_API_ONLY=1 unsloth studio -H 127.0.0.1 -p "$STUDIO_PORT" \ - > logs/studio.log 2>&1 & - echo "STUDIO_PID=$!" >> "$GITHUB_ENV" - - - name: Wait for /api/health - run: | - for i in $(seq 1 180); do - if curl -fs "http://127.0.0.1:${STUDIO_PORT}/api/health" > /tmp/health.json; then - jq -e '.status == "healthy"' /tmp/health.json - exit 0 - fi - sleep 1 - done - echo "Studio did not become healthy in 180s" - tail -200 logs/studio.log - exit 1 - - - name: Password rotation (old must fail, new must work) - run: | - OLD=$(cat ~/.unsloth/studio/auth/.bootstrap_password) - NEW="CIRotated-$(python -c 'import secrets; print(secrets.token_urlsafe(12))')" - echo "::add-mask::$OLD" - echo "::add-mask::$NEW" - # 1. Login with the bootstrap password. - OLD_TOKEN=$(curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/auth/login" \ - -H 'content-type: application/json' \ - -d "{\"username\":\"unsloth\",\"password\":\"$OLD\"}" | jq -r .access_token) - [ -n "$OLD_TOKEN" ] && [ "$OLD_TOKEN" != "null" ] || { echo "bootstrap login failed"; exit 1; } - # 2. Rotate to a fresh random password. 
- curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/auth/change-password" \ - -H "Authorization: Bearer $OLD_TOKEN" -H 'content-type: application/json' \ - -d "{\"current_password\":\"$OLD\",\"new_password\":\"$NEW\"}" > /dev/null - # 3. Old password must now be rejected (HTTP 401). - OLD_STATUS=$(curl -s -o /dev/null -w '%{http_code}' \ - -X POST "http://127.0.0.1:${STUDIO_PORT}/api/auth/login" \ - -H 'content-type: application/json' \ - -d "{\"username\":\"unsloth\",\"password\":\"$OLD\"}") - if [ "$OLD_STATUS" != "401" ]; then - echo "::error::Login with old password returned $OLD_STATUS, expected 401" - exit 1 - fi - # 4. New password must succeed; capture the JWT for downstream steps. - NEW_TOKEN=$(curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/auth/login" \ - -H 'content-type: application/json' \ - -d "{\"username\":\"unsloth\",\"password\":\"$NEW\"}" | jq -r .access_token) - [ -n "$NEW_TOKEN" ] && [ "$NEW_TOKEN" != "null" ] || { echo "new login failed"; exit 1; } - echo "TOKEN=$NEW_TOKEN" >> "$GITHUB_ENV" - echo "password rotation OK (old=401, new=200)" - - - name: Load the GGUF (HF repo + variant, served from HF_HOME cache) - run: | - curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/inference/load" \ - -H "Authorization: Bearer $TOKEN" -H 'content-type: application/json' \ - --max-time 600 \ - -d "{\"model_path\":\"$GGUF_REPO\",\"gguf_variant\":\"$GGUF_VARIANT\",\"is_lora\":false,\"max_seq_length\":2048}" \ - | jq '{status, display_name, is_gguf, context_length}' - - - name: Multi-turn determinism via OpenAI + Anthropic SDKs - env: - BASE_URL: http://127.0.0.1:18888 - run: | - python - <<'PY' - import json - import os - from openai import OpenAI - from anthropic import Anthropic - - BASE = os.environ["BASE_URL"] - KEY = os.environ["TOKEN"] # JWT also accepted as Bearer on /v1/* - SEED = 3407 - - # Four-turn conversation: the second and fourth turns can only be - # answered correctly if the model sees the prior turns, so this - # also 
exercises the conversation-history wiring. - PROMPTS = [ - "What is 1+1?", - "What did I ask before?", - "What is the capital of France?", - "Repeat the city name", - ] - - def run_openai(): - client = OpenAI(base_url = f"{BASE}/v1", api_key = KEY) - history, replies = [], [] - for prompt in PROMPTS: - history.append({"role": "user", "content": prompt}) - resp = client.chat.completions.create( - model = "default", - messages = history, - temperature = 0.0, - max_tokens = 80, - seed = SEED, - extra_body = {"enable_thinking": False}, - ) - text = resp.choices[0].message.content or "" - replies.append(text) - history.append({"role": "assistant", "content": text}) - return replies - - def run_anthropic(): - # Two SDK quirks vs. Studio: - # 1. base_url must NOT include /v1 -- the SDK appends - # /v1/messages itself; otherwise the request hits - # /v1/v1/messages and 405s. - # 2. The SDK sends `x-api-key` by default, but Studio's - # auth layer is HTTPBearer-only. Override via - # default_headers so Authorization: Bearer ... is - # sent instead. 
- client = Anthropic( - base_url = BASE, - api_key = "unused", - default_headers = {"Authorization": f"Bearer {KEY}"}, - ) - history, replies = [], [] - for prompt in PROMPTS: - history.append({"role": "user", "content": prompt}) - msg = client.messages.create( - model = "default", - max_tokens = 80, - messages = history, - temperature = 0.0, - extra_body = {"seed": SEED, "enable_thinking": False}, - ) - text = "".join(b.text for b in msg.content if getattr(b, "type", None) == "text") - replies.append(text) - history.append({"role": "assistant", "content": text}) - return replies - - for label, runner in (("openai", run_openai), ("anthropic", run_anthropic)): - first = runner() - second = runner() - determinism_failures = [] - for i, (a, b) in enumerate(zip(first, second), start = 1): - print(f"[{label} turn {i}] {a!r}") - # Both runs must be non-empty; small-quant drift - # across runs is WARN-only (grounding asserts below - # are the stronger signal). - assert a, f"{label}: empty turn {i} response in first run" - assert b, f"{label}: empty turn {i} response in second run" - if a.strip() != b.strip(): - determinism_failures.append( - f"turn {i}: run1={a!r} run2={b!r}" - ) - if determinism_failures: - print( - f"[{label}] WARN non-determinism at temperature=0.0 across " - f"{len(determinism_failures)} of {len(first)} turn(s); " - f"small-quant model drift, not a Studio regression. " - f"Details: " + " | ".join(determinism_failures) - ) - # Sanity: turn-2 reply should mention the earlier question, and - # turn-4 reply should mention Paris (model echoes the city it - # produced for turn 3). Lower-cased substring checks keep the - # assertion robust to formatting jitter. 
- joined = " ".join(first).lower() - assert "1" in first[0], f"{label}: turn-1 answer should contain '1', got {first[0]!r}" - assert "paris" in joined, f"{label}: expected 'paris' somewhere in the four-turn transcript: {first}" - status_word = "PASS" if not determinism_failures else "PASS (with drift)" - print(f"[{label}] {status_word} -- 4 turns, history grounded ('paris' present)") - PY - - - name: Stop Studio - if: always() - run: | - kill "${STUDIO_PID}" 2>/dev/null || true - sleep 2 - ss -tln | grep ":${STUDIO_PORT}" || true - - - name: Upload logs - # Always upload so green runs are still reviewable. - if: always() - uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 - with: - name: openai-anthropic-log - path: | - logs/studio.log - logs/install.log - retention-days: 7 - - # ───────────────────────────────────────────────────────────────────── - # Job 2: Tool calling Tests - # ───────────────────────────────────────────────────────────────────── - tool-calling: - name: Tool calling Tests - runs-on: ubuntu-latest - timeout-minutes: 25 - env: - # Tool calling is the highest-volume GGUF in this workflow - # (Qwen3.5-2B at IQ3_XXS = ~890 MiB). Caching HF_HOME would - # store xet chunks + blobs + snapshots = ~4 GiB compressed -- - # 4-5x file-size inflation, dominated by xet chunks. Use main's - # `--local-dir gguf-cache` pattern to cache the flat .gguf only. - # Studio's /api/inference/load accepts either a HF repo (which - # uses HF_HOME) or an absolute file path; passing the absolute - # path keeps the test off HF_HOME entirely so the cache size - # tracks the GGUF file 1:1. The OpenAI/Anth and JSON+images - # jobs still cover the gguf_variant resolution path. 
- GGUF_REPO: unsloth/Qwen3.5-2B-GGUF - GGUF_FILE: Qwen3.5-2B-UD-IQ3_XXS.gguf - STUDIO_PORT: '18889' - steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - persist-credentials: false - - - name: Linux deps for llama.cpp prebuilt - run: | - sudo apt-get update - sudo apt-get install -y --no-install-recommends \ - libcurl4-openssl-dev libssl-dev jq - - - uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0 - with: - node-version: '22' - - - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 - with: - python-version: '3.12' - cache: 'pip' - - - name: Restore GGUF model file - id: cache-gguf - uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 - continue-on-error: true - with: - path: gguf-cache - key: ${{ runner.os }}-gguf-${{ env.GGUF_REPO }}-${{ env.GGUF_FILE }}-v1 - - - name: Download GGUF if cache miss - id: download-gguf - if: steps.cache-gguf.outputs.cache-hit != 'true' || steps.cache-gguf.outcome != 'success' - env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} - run: | - python -m pip install --upgrade huggingface_hub - mkdir -p gguf-cache - bash .github/scripts/hf-download-with-retry.sh "$GGUF_REPO" "$GGUF_FILE" gguf-cache - - - name: Save GGUF model file - if: always() && steps.download-gguf.outcome == 'success' - uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 - with: - path: gguf-cache - key: ${{ runner.os }}-gguf-${{ env.GGUF_REPO }}-${{ env.GGUF_FILE }}-v1 - - - name: Install Studio (--local, --no-torch) - env: - GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: | - mkdir -p logs - set -o pipefail - bash install.sh --local --no-torch 2>&1 | tee logs/install.log - - - name: Reset auth + boot Studio (API-only, default tool policy) - # We deliberately use the API-only mode rather than - # `unsloth studio run` because the latter calls - # `set_tool_policy(...)` with a resolved bool: on loopback the - # default resolves to True, which 
forces every request through - # the server-side agentic loop and breaks the standard - # function-calling test below. API-only mode leaves - # tool_policy=None so each request's `enable_tools` field is - # honoured. - run: | - unsloth studio reset-password - mkdir -p logs - UNSLOTH_API_ONLY=1 unsloth studio -H 127.0.0.1 -p "$STUDIO_PORT" \ - > logs/studio.log 2>&1 & - echo "STUDIO_PID=$!" >> "$GITHUB_ENV" - - - name: Wait for /api/health, log in, change password, load model - run: | - for i in $(seq 1 180); do - if curl -fs "http://127.0.0.1:${STUDIO_PORT}/api/health" > /tmp/health.json; then - jq -e '.status == "healthy"' /tmp/health.json && break - fi - sleep 1 - done - jq -e '.status == "healthy"' /tmp/health.json - OLD=$(cat ~/.unsloth/studio/auth/.bootstrap_password) - NEW="CITool-$(python -c 'import secrets; print(secrets.token_urlsafe(12))')" - echo "::add-mask::$OLD" - echo "::add-mask::$NEW" - OLD_TOKEN=$(curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/auth/login" \ - -H 'content-type: application/json' \ - -d "{\"username\":\"unsloth\",\"password\":\"$OLD\"}" | jq -r .access_token) - curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/auth/change-password" \ - -H "Authorization: Bearer $OLD_TOKEN" -H 'content-type: application/json' \ - -d "{\"current_password\":\"$OLD\",\"new_password\":\"$NEW\"}" > /dev/null - TOKEN=$(curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/auth/login" \ - -H 'content-type: application/json' \ - -d "{\"username\":\"unsloth\",\"password\":\"$NEW\"}" | jq -r .access_token) - echo "API_KEY=$TOKEN" >> "$GITHUB_ENV" - GGUF_PATH="$GITHUB_WORKSPACE/gguf-cache/${GGUF_FILE}" - ls -lh "$GGUF_PATH" - curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/inference/load" \ - -H "Authorization: Bearer $TOKEN" -H 'content-type: application/json' \ - --max-time 600 \ - -d "{\"model_path\":\"$GGUF_PATH\",\"is_lora\":false,\"max_seq_length\":2048}" \ - | jq '{status, display_name}' - - - name: Tool calling, server-side tools, 
thinking on/off - env: - BASE_URL: http://127.0.0.1:18889 - run: | - python - <<'PY' - import json - import os - import urllib.request - - BASE = os.environ["BASE_URL"] - KEY = os.environ["API_KEY"] - SEED = 3407 - - def post(path, body, *, timeout = 240): - """Plain JSON POST. For requests that don't go through - the server-side agentic loop, the response is one JSON - object.""" - data = json.dumps(body).encode() - req = urllib.request.Request( - f"{BASE}{path}", - data = data, - method = "POST", - headers = { - "Authorization": f"Bearer {KEY}", - "Content-Type": "application/json", - }, - ) - with urllib.request.urlopen(req, timeout = timeout) as resp: - return resp.status, json.loads(resp.read().decode()) - - def post_sse(path, body, *, timeout = 600): - """POST a streaming request and accumulate the assistant - text deltas. The server-side agentic loop ALWAYS returns - SSE regardless of the request's `stream` field, so any - call with enable_tools=true must use this helper. - - Returns (content, raw_payloads): - content -- concatenated assistant delta.content - raw_payloads -- list of every raw "data: ..." event - payload (JSON strings). Callers asserting - that a server-side tool actually ran (and - not just that the model emitted some - text) should grep raw_payloads for tool - invocation markers / tool output, since - `delta.content` alone is not evidence - that the tool path executed. 
- """ - body = {**body, "stream": True} - data = json.dumps(body).encode() - req = urllib.request.Request( - f"{BASE}{path}", - data = data, - method = "POST", - headers = { - "Authorization": f"Bearer {KEY}", - "Content-Type": "application/json", - }, - ) - parts = [] - events = [] - with urllib.request.urlopen(req, timeout = timeout) as resp: - for raw in resp: - line = raw.decode().strip() - if not line.startswith("data: "): - continue - payload = line[6:] - if payload == "[DONE]": - break - events.append(payload) - try: - chunk = json.loads(payload) - except json.JSONDecodeError: - continue - for choice in chunk.get("choices", []): - delta = choice.get("delta", {}) or {} - if delta.get("content"): - parts.append(delta["content"]) - return "".join(parts), events - - _STUDIO_TOOL_TYPES = { - "tool_start", "tool_end", "tool_use", "tool_result", - } - - def _tool_invoked(events): - """Structural check: True iff some SSE payload is a real - tool envelope (Studio tool_start/tool_end, Anthropic - tool_use/tool_result, OpenAI non-empty delta.tool_calls / - message.tool_calls / finish_reason='tool_calls' / - role:'tool' / function_call). tool_status is NOT - evidence: Studio emits empty tool_status events on - iteration boundaries even when no tool ran. 
- """ - for raw in events: - try: - ev = json.loads(raw) - except (json.JSONDecodeError, TypeError): - continue - if not isinstance(ev, dict): - continue - if ev.get("type") in _STUDIO_TOOL_TYPES: - return True - for choice in ev.get("choices", []) or []: - if not isinstance(choice, dict): - continue - if choice.get("finish_reason") == "tool_calls": - return True - for src_key in ("delta", "message"): - src = choice.get(src_key) or {} - if not isinstance(src, dict): - continue - tc = src.get("tool_calls") - if isinstance(tc, list) and tc: - return True - if src.get("function_call"): - return True - if src.get("role") == "tool": - return True - for item in ev.get("output", []) or []: - if isinstance(item, dict) and item.get("type") in { - "tool_call", "function_call", "tool_use", - }: - return True - content = ev.get("content") - if isinstance(content, list): - for blk in content: - if isinstance(blk, dict) and blk.get("type") in { - "tool_use", "tool_result", - }: - return True - return False - - def _tool_output_contains(events, *needles): - """True iff any tool_end.result / tool_result.content / - tool-role message content contains a needle. 
Inspects - the tool's own output, not the model's narration.""" - for raw in events: - try: - ev = json.loads(raw) - except (json.JSONDecodeError, TypeError): - continue - if not isinstance(ev, dict): - continue - if ev.get("type") == "tool_end": - result = ev.get("result") - if isinstance(result, str) and any(n in result for n in needles if n): - return True - if ev.get("type") == "tool_result": - content = ev.get("content") - if isinstance(content, str) and any(n in content for n in needles if n): - return True - if isinstance(content, list): - for blk in content: - if isinstance(blk, dict): - text = blk.get("text") or blk.get("content") - if isinstance(text, str) and any(n in text for n in needles if n): - return True - for choice in ev.get("choices", []) or []: - delta = (choice or {}).get("delta") or {} - msg = (choice or {}).get("message") or {} - for src in (delta, msg): - if src.get("role") == "tool": - content = src.get("content") or "" - if isinstance(content, str) and any(n in content for n in needles if n): - return True - return False - - # ── 1. 
Standard OpenAI function calling ────────────────────── - weather_tool = { - "type": "function", - "function": { - "name": "get_weather", - "description": "Get current weather for a city.", - "parameters": { - "type": "object", - "properties": {"city": {"type": "string"}}, - "required": ["city"], - }, - }, - } - - status, data = post("/v1/chat/completions", { - "messages": [{"role": "user", "content": "What is the weather in Paris?"}], - "tools": [weather_tool], - "tool_choice": "required", - "stream": False, - "temperature": 0.0, - "seed": SEED, - "max_tokens": 120, - }) - assert status == 200, f"tool call status {status}: {data}" - choice = data["choices"][0] - assert choice["finish_reason"] == "tool_calls", f"finish_reason={choice['finish_reason']!r}" - tc = choice["message"]["tool_calls"][0] - assert tc["function"]["name"] == "get_weather" - args = json.loads(tc["function"]["arguments"]) - assert args.get("city"), f"missing city arg: {args}" - print(f"[tools] PASS function calling -> {tc['function']['name']}({args})") - - # T=0 = deterministic argmax in llama.cpp; T>0 lets seed - # rotation explore distinct trajectories on retry. - TOOL_PROBE_TEMP = 0.4 - - def _run_tool_probe(*, label, prompt, enabled, session, needles, - max_attempts = 4): - """Drive a server-side tool with retries. Hard FAIL if no - attempt has structural invocation evidence. WARN (not - FAIL) if invoked but no attempt produces the expected - literal in tool_end.result -- small-quant Qwen3.5-2B can - emit OpenAI tool_calls deltas without Studio's GGUF - agentic loop intercepting them, and that GGUF-vs-OpenAI - format mismatch is out of scope for #5642. 
- """ - attempts_log = [] - best = None - for attempt_i in range(max_attempts): - attempt_seed = SEED + attempt_i - content, events = post_sse("/v1/chat/completions", { - "messages": [{"role": "user", "content": prompt}], - "enable_tools": True, - "enabled_tools": enabled, - "session_id": f"{session}-att{attempt_i}", - "temperature": TOOL_PROBE_TEMP, - "seed": attempt_seed, - "max_tokens": 600, - }) - invoked = _tool_invoked(events) - produced = _tool_output_contains(events, *needles) - attempts_log.append({ - "attempt": attempt_i, "seed": attempt_seed, - "n_events": len(events), - "tool_invoked": invoked, "tool_output_contains": produced, - "content_len": len(content), - }) - if invoked and produced: - print(f"[tools] PASS {label} attempt {attempt_i}") - return content, events, attempts_log - if invoked and best is None: - best = (content, events) - print(f"[tools] retry {label} attempt {attempt_i}: invoked={invoked} output_ok={produced} events={len(events)}") - if best is not None: - print(f"[tools] WARN {label}: invoked but no tool_end.result match (small-quant flake). Attempts: {attempts_log}") - content, events = best - return content, events, attempts_log - raise AssertionError( - f"{label}: no structural tool-invocation evidence across " - f"{max_attempts} attempts. enable_tools may be silently " - f"ignored. Attempts: {attempts_log}" - ) - - # ── 2. Server-side python tool ─────────────────────────────── - content, events, _attempts = _run_tool_probe( - label = "python tool", - prompt = "What is 123 * 456? Use the python tool to compute it and tell me the number.", - enabled = ["python"], - session = "ci-tool-calling-py", - needles = ("56088", "56,088"), - ) - if "56088" in content or "56,088" in content: - print(f"[tools] python tool narration OK") - else: - print(f"[tools] python tool narration drifted -- content={content!r}") - - # ── 3. 
Server-side bash (terminal) tool ────────────────────── - content, events, _attempts = _run_tool_probe( - label = "bash/terminal tool", - prompt = "Use the terminal tool to run `echo hello-bash-tool` and tell me the exact output.", - enabled = ["terminal"], - session = "ci-tool-calling-bash", - needles = ("hello-bash-tool",), - ) - if "hello-bash-tool" in content: - print(f"[tools] bash/terminal narration OK") - else: - print(f"[tools] bash/terminal narration dropped literal -- content={content!r}") - - # ── 4. Server-side web_search tool ─────────────────────────── - # DuckDuckGo is flaky from CI runners and small Qwen3.5-2B - # may not actually search. Only assert that the SSE stream - # opens and yields any data; HTTP / parser failures already - # raise above. Tool-invocation strictness is relaxed here - # because (a) the search may legitimately return no results, - # and (b) DuckDuckGo upstream blocks GHA IP ranges often - # enough that requiring a tool_call marker would create - # red-herring failures from infra rather than from Studio. - try: - content, events = post_sse("/v1/chat/completions", { - "messages": [{"role": "user", "content": "Search the web for 'unsloth ai github' and summarise."}], - "enable_tools": True, - "enabled_tools": ["web_search"], - "session_id": "ci-tool-calling-web", - "temperature": 0.0, - "seed": SEED, - "max_tokens": 400, - }) - print( - f"[tools] PASS web_search stream ({len(content)} chars in content, " - f"{len(events)} raw events)" - ) - except Exception as exc: - print(f"[tools] WARN web_search probe failed (non-blocking): {exc}") - - # ── 5. Thinking on / off ───────────────────────────────────── - # Studio strips think blocks from message.content for tools-mode - # responses, so we toggle plain chat (no enable_tools) and look - # at the surfaced reasoning_content / message.thinking field. 
- def thinking_call(enable): - status, data = post("/v1/chat/completions", { - "messages": [{"role": "user", "content": "Briefly: is 17 prime?"}], - "stream": False, - "enable_thinking": enable, - "temperature": 0.0, - "seed": SEED, - "max_tokens": 300, - }) - assert status == 200 - msg = data["choices"][0]["message"] - # Studio surfaces thinking via reasoning_content (OpenAI - # extension). Fall back to inline markers for - # robustness across template versions. - raw = (msg.get("content") or "") + (msg.get("reasoning_content") or "") - return raw - - on_text = thinking_call(True) - off_text = thinking_call(False) - had_think_on = ("" in on_text) or len(on_text) > 80 - had_think_off = ("" in off_text) and len(off_text) > 0 - assert had_think_on, ( - f"enable_thinking=True produced no thinking signal: {on_text!r}" - ) - # Off-mode should not contain the literal marker. - assert "" not in off_text, ( - f"enable_thinking=False but still present: {off_text!r}" - ) - print(f"[tools] PASS thinking on/off (on={len(on_text)} chars, off={len(off_text)} chars)") - PY - - - name: Stop Studio - if: always() - run: | - kill "${STUDIO_PID}" 2>/dev/null || true - sleep 2 - ss -tln | grep ":${STUDIO_PORT}" || true - - - name: Upload logs - # Always upload so green runs are still reviewable. 
- if: always() - uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 - with: - name: tool-calling-log - path: | - logs/studio.log - logs/install.log - retention-days: 7 - - # ───────────────────────────────────────────────────────────────────── - # Job 3: JSON, images - # ───────────────────────────────────────────────────────────────────── - json-images: - name: JSON, images - runs-on: ubuntu-latest - timeout-minutes: 30 - env: - GGUF_REPO: unsloth/gemma-4-E2B-it-GGUF - GGUF_VARIANT: UD-IQ3_XXS - GGUF_FILE: gemma-4-E2B-it-UD-IQ3_XXS.gguf - MMPROJ_FILE: mmproj-F16.gguf - STUDIO_PORT: '18890' - HF_HOME: ${{ github.workspace }}/hf-cache - steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - persist-credentials: false - - - name: Linux deps for llama.cpp prebuilt - run: | - sudo apt-get update - sudo apt-get install -y --no-install-recommends \ - libcurl4-openssl-dev libssl-dev jq - - - uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0 - with: - node-version: '22' - - - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 - with: - python-version: '3.12' - cache: 'pip' - - - name: Restore HF_HOME for ${{ env.GGUF_REPO }} (model + mmproj) - id: cache-hf - uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 - continue-on-error: true - with: - path: hf-cache - key: ${{ runner.os }}-hf-${{ env.GGUF_REPO }}-${{ env.GGUF_VARIANT }}-${{ env.MMPROJ_FILE }}-v1 - - - name: Prime HF_HOME with the GGUF + mmproj - id: prime-hf - if: steps.cache-hf.outputs.cache-hit != 'true' || steps.cache-hf.outcome != 'success' - env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} - run: | - python -m pip install --upgrade huggingface_hub - mkdir -p hf-cache - bash .github/scripts/hf-download-with-retry.sh "$GGUF_REPO" "$GGUF_FILE" - bash .github/scripts/hf-download-with-retry.sh "$GGUF_REPO" "$MMPROJ_FILE" - - - name: Save HF_HOME for ${{ env.GGUF_REPO }} (model + 
mmproj) - if: always() && steps.prime-hf.outcome == 'success' - uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 - with: - path: hf-cache - key: ${{ runner.os }}-hf-${{ env.GGUF_REPO }}-${{ env.GGUF_VARIANT }}-${{ env.MMPROJ_FILE }}-v1 - - - name: Install Studio (--local, --no-torch) - env: - GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: | - mkdir -p logs - set -o pipefail - bash install.sh --local --no-torch 2>&1 | tee logs/install.log - - - name: Install OpenAI + Anthropic Python SDKs - run: pip install 'openai>=1.50' 'anthropic>=0.40' - - - name: Reset auth + boot Studio (API-only) - # See Job 2's comment: API-only mode keeps tool_policy=None so - # response_format requests aren't routed through the agentic - # tool loop. - run: | - unsloth studio reset-password - mkdir -p logs - UNSLOTH_API_ONLY=1 unsloth studio -H 127.0.0.1 -p "$STUDIO_PORT" \ - > logs/studio.log 2>&1 & - echo "STUDIO_PID=$!" >> "$GITHUB_ENV" - - - name: Wait for /api/health, log in, change password, load model - run: | - for i in $(seq 1 180); do - if curl -fs "http://127.0.0.1:${STUDIO_PORT}/api/health" > /tmp/health.json; then - jq -e '.status == "healthy"' /tmp/health.json && break - fi - sleep 1 - done - jq -e '.status == "healthy"' /tmp/health.json - OLD=$(cat ~/.unsloth/studio/auth/.bootstrap_password) - NEW="CIJson-$(python -c 'import secrets; print(secrets.token_urlsafe(12))')" - echo "::add-mask::$OLD" - echo "::add-mask::$NEW" - OLD_TOKEN=$(curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/auth/login" \ - -H 'content-type: application/json' \ - -d "{\"username\":\"unsloth\",\"password\":\"$OLD\"}" | jq -r .access_token) - curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/auth/change-password" \ - -H "Authorization: Bearer $OLD_TOKEN" -H 'content-type: application/json' \ - -d "{\"current_password\":\"$OLD\",\"new_password\":\"$NEW\"}" > /dev/null - TOKEN=$(curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/auth/login" \ - -H 'content-type: 
application/json' \ - -d "{\"username\":\"unsloth\",\"password\":\"$NEW\"}" | jq -r .access_token) - echo "API_KEY=$TOKEN" >> "$GITHUB_ENV" - # Load the GGUF (mmproj is auto-detected via the HF repo - # lookup, the cached file is pulled out of HF_HOME). - curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/inference/load" \ - -H "Authorization: Bearer $TOKEN" -H 'content-type: application/json' \ - --max-time 900 \ - -d "{\"model_path\":\"$GGUF_REPO\",\"gguf_variant\":\"$GGUF_VARIANT\",\"is_lora\":false,\"max_seq_length\":2048}" \ - | jq '{status, display_name, is_vision}' - - - name: JSON schema decoding + image input - env: - BASE_URL: http://127.0.0.1:18890 - run: | - python - <<'PY' - import base64 - import json - import os - import urllib.request - from openai import OpenAI - from anthropic import Anthropic - - BASE = os.environ["BASE_URL"] - KEY = os.environ["API_KEY"] - SEED = 3407 - - def post(path, body, *, timeout = 240): - req = urllib.request.Request( - f"{BASE}{path}", - data = json.dumps(body).encode(), - method = "POST", - headers = { - "Authorization": f"Bearer {KEY}", - "Content-Type": "application/json", - }, - ) - with urllib.request.urlopen(req, timeout = timeout) as resp: - return resp.status, json.loads(resp.read().decode()) - - # ── 1. response_format = json_object (JSON mode) ───────────── - # llama.cpp's HTTP server supports OpenAI-compatible JSON - # mode: `response_format: {"type": "json_object"}` constrains - # the model to emit syntactically-valid JSON. We use raw HTTP - # rather than the OpenAI SDK so that the field shape Studio - # forwards to llama-server is unambiguous (the SDK rewrites - # response_format depending on which variant it recognises). - # We deliberately do NOT pass a strict JSON schema -- on - # small Gemma-4 quants the GBNF-from-schema path occasionally - # produces empty output, and JSON mode is the surface we care - # about exposing through Studio. 
- status, data = post("/v1/chat/completions", { - "model": "default", - "messages": [ - {"role": "system", "content": 'Reply with a single JSON object of the form {"city": "...", "country": "..."}. Output ONLY the JSON, nothing else.'}, - {"role": "user", "content": "What is the capital of France?"}, - ], - "temperature": 0.0, - "max_tokens": 200, - "seed": SEED, - "stream": False, - "enable_thinking": False, - "response_format": {"type": "json_object"}, - }, timeout = 600) - assert status == 200, f"json status {status}: {data}" - content = (data["choices"][0]["message"].get("content") or "").strip() - # Some chat templates wrap JSON in ```json fences even in JSON - # mode -- strip those before parsing. - if content.startswith("```"): - content = content.split("```", 2)[1] - if content.startswith("json"): - content = content[4:] - content = content.strip("`\n ") - parsed = json.loads(content) - assert "paris" in str(parsed.get("city", "")).lower(), ( - f"city != Paris: {parsed}" - ) - print(f"[json] PASS json_object -> {parsed}") - - # ── 2. OpenAI image_url (data URI base64) ─────────────────── - # 64x64 solid-red PNG. stb_image (used by Studio's image - # normaliser at routes/inference.py:3410) rejects 4x4 or - # smaller PNGs as truncated, so we go up to 64x64 -- still - # tiny in token cost. The assertion is loose: any non-empty - # response from the vision path proves multimodal end-to-end - # wiring; small VL quants are weak at colour identification. 
- PNG_64X64_RED_B64 = ( - "iVBORw0KGgoAAAANSUhEUgAAAEAAAABACAIAAAAlC+aJAAAAYklEQVR4nO3PMQ0AIADAMEAI/k" - "UhBhEcDcmqYJtn7/GzpQNeNaA1oDWgNaA1oDWgNaA1oDWgNaA1oDWgNaA1oDWgNaA1oDWgNaA" - "1oDWgNaA1oDWgNaA1oDWgNaA1oDWgNaA1oDWgNaBdCJ0BmMJ25zMAAAAASUVORK5CYII=" - ) - data_uri = f"data:image/png;base64,{PNG_64X64_RED_B64}" - - client = OpenAI(base_url = f"{BASE}/v1", api_key = KEY) - openai_resp = client.chat.completions.create( - model = "default", - temperature = 0.0, - max_tokens = 80, - seed = SEED, - messages = [{ - "role": "user", - "content": [ - {"type": "image_url", "image_url": {"url": data_uri}}, - {"type": "text", "text": "What colour dominates this image? Reply in one word."}, - ], - }], - ) - openai_text = (openai_resp.choices[0].message.content or "").lower() - print(f"[image/openai] reply: {openai_text!r}") - assert openai_text, "OpenAI image_url returned empty content" - # We do not strictly require 'red' -- some quants of small VL - # models are weak at colour names. Just require a non-empty - # answer; the vision path is the part under test. - print("[image/openai] PASS image_url accepted, non-empty response") - - # ── 3. Anthropic source/base64 image ──────────────────────── - # Two SDK quirks vs. Studio: base_url must NOT include /v1 - # (the SDK appends it itself; otherwise /v1/v1/messages -> 405), - # and Studio's auth is HTTPBearer-only so the SDK's default - # x-api-key header is ignored -- send Authorization: Bearer - # via default_headers. 
- anthropic = Anthropic( - base_url = BASE, - api_key = "unused", - default_headers = {"Authorization": f"Bearer {KEY}"}, - ) - a_msg = anthropic.messages.create( - model = "default", - max_tokens = 80, - temperature = 0.0, - extra_body = {"seed": SEED}, - messages = [{ - "role": "user", - "content": [ - { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/png", - "data": PNG_64X64_RED_B64, - }, - }, - {"type": "text", "text": "Describe this image briefly."}, - ], - }], - ) - a_text = "".join(b.text for b in a_msg.content if getattr(b, "type", None) == "text") - print(f"[image/anthropic] reply: {a_text!r}") - assert a_text, "Anthropic source/base64 returned empty content" - print("[image/anthropic] PASS source/base64 accepted, non-empty response") - PY - - - name: Stop Studio - if: always() - run: | - kill "${STUDIO_PID}" 2>/dev/null || true - sleep 2 - ss -tln | grep ":${STUDIO_PORT}" || true - - - name: Upload logs - # Always upload so green runs are still reviewable. - if: always() - uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 - with: - name: json-images-log - path: | - logs/studio.log - logs/install.log - retention-days: 7 diff --git a/.github/workflows/studio-tauri-smoke.yml b/.github/workflows/studio-tauri-smoke.yml deleted file mode 100644 index 1156c264ae..0000000000 --- a/.github/workflows/studio-tauri-smoke.yml +++ /dev/null @@ -1,128 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. - -# PR-time smoke for the Tauri desktop wrapper. Builds the frontend and the -# Tauri Linux debug binary, with no codesigning. Catches: -# - tauri.conf.json drift -# - src-tauri Cargo.toml or rust source breakage -# - Tauri CLI version drift (we pin 2.10.1, matching release-desktop.yml) -# - frontend output not picked up by Tauri's distDir -# -# Linux-only on a free `ubuntu-latest` runner. 
Mac and Windows desktop builds -# stay in release-desktop.yml (manual `workflow_dispatch`) because they need -# code-signing secrets and ~30 min of runner time each. - -name: Studio Tauri CI - -on: - pull_request: - paths: - - 'studio/frontend/**' - - 'studio/src-tauri/**' - # CLI rename / signature change can break Tauri's spawned - # `unsloth studio` -- include unsloth_cli in the trigger set. - - 'unsloth_cli/**' - - '.github/workflows/studio-tauri-smoke.yml' - push: - branches: [main, pip] - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -permissions: - contents: read - -jobs: - linux-debug-build: - name: Tauri Linux debug build (no codesign) - runs-on: ubuntu-22.04 - timeout-minutes: 25 - steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - persist-credentials: false - - - name: Linux native deps for Tauri / WebKit2GTK - run: | - sudo apt-get update - sudo apt-get install -y \ - libwebkit2gtk-4.1-dev libayatana-appindicator3-dev \ - librsvg2-dev libxdo-dev libssl-dev patchelf - - - uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0 - with: - node-version: '24' - - - uses: dtolnay/rust-toolchain@29eef336d9b2848a0b548edc03f92a220660cdb8 # stable @ 2026-03-27 - - - uses: swatinem/rust-cache@e18b497796c12c097a38f9edb9d0641fb99eee32 # v2.9.1 - with: - workspaces: studio/src-tauri -> target - - - name: Install pinned Tauri CLI (matches release-desktop.yml) - # Lifecycle scripts (esbuild native-binary postinstall, etc.) are - # required for `vite build`. The pre-install lockfile structural - # audit (lockfile_supply_chain_audit.py) is the practical defence - # against the npm postinstall-dropper class -- it fires BEFORE any - # tarball runs, on the injection pattern itself rather than an - # advisory-DB lookup. 
- run: npm install --save-dev --prefix studio @tauri-apps/cli@2.10.1 --no-fund --no-audit - - - name: Verify pinned Tauri CLI version - run: | - out="$(npx --prefix studio tauri --version)" - echo "$out" - [ "$out" = "tauri-cli 2.10.1" ] || { echo "::error::expected tauri-cli 2.10.1, got $out"; exit 1; } - - - name: Lockfile supply-chain audit (pre-install scan) - run: python3 scripts/lockfile_supply_chain_audit.py - - - name: Frontend build (npm ci, vite) - working-directory: studio/frontend - # Lifecycle scripts (esbuild native-binary postinstall, etc.) are - # required for `vite build`. The pre-install lockfile structural - # audit (lockfile_supply_chain_audit.py) is the practical defence - # against the npm postinstall-dropper class -- it fires BEFORE any - # tarball runs, on the injection pattern itself rather than an - # advisory-DB lookup. - run: | - npm ci --no-fund --no-audit - npm run build - test -f dist/index.html - - - name: Tauri debug build (Linux, no bundle, no codesign) - # `--debug` + `--no-bundle` keeps this lean: compiles the Rust crate, - # confirms the frontend dist is wired into Tauri, but skips the AppImage - # / .deb production. Code signing is irrelevant because we never produce - # a distributable artifact. - env: - TAURI_SIGNING_PRIVATE_KEY: '' - TAURI_SIGNING_PRIVATE_KEY_PASSWORD: '' - run: npx --prefix studio tauri build --debug --no-bundle - - - name: Inspect produced binary - run: | - BIN=$(find studio/src-tauri/target/debug -maxdepth 1 -type f -executable 2>/dev/null \ - | grep -Ev '\.(d|so|dylib|dll)$' \ - | grep -Ev '/(deps|build|examples)$' \ - | head -1) - echo "binary: $BIN" - if [ -z "$BIN" ]; then - echo "::error::Tauri debug binary not produced" - ls -la studio/src-tauri/target/debug/ || true - exit 1 - fi - file "$BIN" - du -h "$BIN" - - - name: Upload Tauri debug build - # Always upload so a green run leaves the binary inspectable too. 
- if: always() - uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 - with: - name: tauri-debug-build - path: | - studio/src-tauri/target/debug - studio/frontend/dist - retention-days: 3 diff --git a/.github/workflows/wheel-smoke.yml b/.github/workflows/wheel-smoke.yml deleted file mode 100644 index 3de3c33ca2..0000000000 --- a/.github/workflows/wheel-smoke.yml +++ /dev/null @@ -1,136 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. - -# Builds the PyPI wheel from the PR branch, then verifies the built wheel -# actually contains what we expect to ship and does NOT contain the broken -# Studio bundle that 2026.5.1 published. This is the single workflow that -# would have blocked the 2026.5.1 release before twine upload. -# -# Verified locally end-to-end against this branch: -# - python -m build produces unsloth-<version>-py3-none-any.whl in 13s -# - wheel content sanity passes: -# lockfile shipped, frontend dist shipped, -# no node_modules in wheel, no bun.lock in wheel, -# main bundle has unstable_Provider hits=1 (assistant-ui internals only). -# - Studio backend imports cleanly from the installed wheel with the -# lightweight dep set below.
- -name: Wheel CI - -on: - pull_request: - paths: - - 'pyproject.toml' - - 'studio/**' - - 'unsloth/**' - - 'unsloth_cli/**' - - '.github/workflows/wheel-smoke.yml' - push: - branches: [main, pip] - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -permissions: - contents: read - -jobs: - wheel: - name: Wheel build + content sanity + import smoke - runs-on: ubuntu-latest - timeout-minutes: 15 - steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - persist-credentials: false - - - uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0 - with: - node-version: '22' - - - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 - with: - python-version: '3.12' - - - name: Lockfile supply-chain audit (pre-install scan) - run: python3 scripts/lockfile_supply_chain_audit.py - - - name: Build frontend - # Lifecycle scripts (esbuild native-binary postinstall, etc.) are - # required for `vite build`. The pre-install lockfile structural - # audit (lockfile_supply_chain_audit.py) is the practical defence - # against the npm postinstall-dropper class -- it fires BEFORE any - # tarball runs, on the injection pattern itself rather than an - # advisory-DB lookup. 
- run: | - cd studio/frontend - npm ci --no-fund --no-audit - npm run build - - - name: Build wheel + sdist - run: | - python -m pip install --upgrade pip build - rm -rf dist build ./*.egg-info - python -m build - - - name: Wheel content sanity - run: | - python - <<'PY' - import zipfile, glob, sys - w = glob.glob("dist/unsloth-*.whl") - if not w: - print("FAIL: no wheel produced"); sys.exit(2) - w = w[0] - print(f"wheel: {w}") - with zipfile.ZipFile(w) as z: - n = z.namelist() - checks = { - "lockfile shipped": any(s.endswith("studio/frontend/package-lock.json") for s in n), - "frontend dist shipped": any(s.endswith("studio/frontend/dist/index.html") for s in n), - "no node_modules": not any("studio/frontend/node_modules/" in s for s in n), - "no bun.lock": not any(s.endswith("studio/frontend/bun.lock") for s in n), - } - js = [s for s in n - if "studio/frontend/dist/assets/" in s - and s.endswith(".js") - and "/index-" in s] - if not js: - print("FAIL: no main bundle index-*.js in wheel"); sys.exit(2) - data = z.read(js[0]).decode("utf-8", "replace") - hits = data.count("unstable_Provider:") - print(f"main bundle: {js[0]}") - print(f"unstable_Provider hits: {hits} (>=4 indicates 2026.5.1 regression)") - checks["bundle has no Studio unstable_Provider call site"] = (hits < 4) - - print() - for k, v in checks.items(): - print(f" [{'PASS' if v else 'FAIL'}] {k}") - sys.exit(0 if all(checks.values()) else 1) - PY - - - name: Studio backend import smoke - # Imports `studio.backend.main:app` from the freshly-installed wheel in - # a clean venv. This catches the class of bug that 2026.5.1 shipped with: - # frontend dist missing, package-lock.json missing, or the wheel's Python - # source tree broken in a way that surfaces only at app construction time. 
- run: | - python -m venv /tmp/v - /tmp/v/bin/pip install --upgrade pip - /tmp/v/bin/pip install -r studio/backend/requirements/studio.txt - /tmp/v/bin/pip install \ - python-multipart aiofiles sqlalchemy cryptography \ - pyyaml jinja2 mammoth unpdf requests \ - 'numpy<3' - /tmp/v/bin/pip install --no-deps dist/unsloth-*.whl - # Run from /tmp so Python imports the installed package, not the source tree. - cd /tmp - /tmp/v/bin/python -c "from studio.backend.main import app; print('Studio backend OK:', app.title)" - - - name: Upload wheel on failure - if: failure() - uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 - with: - name: unsloth-wheel - path: dist/ - retention-days: 7 From 19f1718c0a2e60324aedfe17f9879504cc11e67a Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Wed, 27 May 2026 05:09:05 +0000 Subject: [PATCH 10/11] ci: add PR-5351 CPU-inference cross-OS lanes Adds a CPU end-to-end smoke that exercises: - the PR's `_extract_pdf` against an in-process synthetic PDF - llama-cpp-python (CPU build) loading Qwen2.5-0.5B-Instruct GGUF - inference on the extracted markdown with a ground-truth question Runs on ubuntu-latest, macos-14, and windows-latest with no GPU. Disables Metal on macOS and native autodetect on Windows/Linux so the lanes stay strictly CPU. Path-filtered to studio/backend/core/chat/, the test itself, and each workflow file so unrelated commits don't re-trigger. 
--- .../workflows/pr5351-cpu-inference-macos.yml | 52 ++++++ .../workflows/pr5351-cpu-inference-ubuntu.yml | 53 ++++++ .../pr5351-cpu-inference-windows.yml | 49 ++++++ ...est_cpu_inference_on_extracted_document.py | 157 ++++++++++++++++++ 4 files changed, 311 insertions(+) create mode 100644 .github/workflows/pr5351-cpu-inference-macos.yml create mode 100644 .github/workflows/pr5351-cpu-inference-ubuntu.yml create mode 100644 .github/workflows/pr5351-cpu-inference-windows.yml create mode 100644 tests/studio/test_cpu_inference_on_extracted_document.py diff --git a/.github/workflows/pr5351-cpu-inference-macos.yml b/.github/workflows/pr5351-cpu-inference-macos.yml new file mode 100644 index 0000000000..df154f7354 --- /dev/null +++ b/.github/workflows/pr5351-cpu-inference-macos.yml @@ -0,0 +1,52 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. +# +# PR-5351 CPU-inference cross-OS lane: macOS (Apple Silicon). +# Same as the Ubuntu lane but on macos-14. llama-cpp-python builds +# with Metal autodetect disabled to stay on the CPU code path so the +# result mirrors a non-GPU Mac. 
+ +name: PR-5351 CPU inference macOS + +on: + push: + branches: [pr-5351-cross-os-validation] + paths: + - 'studio/backend/core/chat/**' + - 'tests/studio/test_cpu_inference_on_extracted_document.py' + - '.github/workflows/pr5351-cpu-inference-macos.yml' + workflow_dispatch: + +concurrency: + group: pr5351-cpu-inference-macos-${{ github.ref }} + cancel-in-progress: true + +jobs: + cpu-inference: + runs-on: macos-14 + timeout-minutes: 40 + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: '3.11' + cache: 'pip' + + - name: Install backend + llama-cpp-python (CPU build) + run: | + python -m pip install --upgrade pip + pip install -r studio/backend/requirements/studio.txt + pip install \ + python-multipart aiofiles sqlalchemy cryptography \ + pyyaml jinja2 mammoth pymupdf pymupdf4llm pytest pytest-asyncio \ + pytest-timeout huggingface_hub requests numpy + # Disable Metal so the lane stays CPU-only; mirrors a no-GPU Mac. + CMAKE_ARGS="-DGGML_METAL=OFF -DGGML_ACCELERATE=OFF -DGGML_NATIVE=OFF" \ + pip install --upgrade --quiet llama-cpp-python + + - name: CPU inference on extracted document + env: + PR5351_LLAMA_THREADS: '3' + run: | + python -m pytest -q tests/studio/test_cpu_inference_on_extracted_document.py -s --tb=short diff --git a/.github/workflows/pr5351-cpu-inference-ubuntu.yml b/.github/workflows/pr5351-cpu-inference-ubuntu.yml new file mode 100644 index 0000000000..4b0a441a12 --- /dev/null +++ b/.github/workflows/pr5351-cpu-inference-ubuntu.yml @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. +# +# PR-5351 CPU-inference cross-OS lane: Ubuntu. +# Builds llama-cpp-python from source for CPU, downloads a 0.5B GGUF +# from HF, extracts a synthetic PDF via the PR's document extractor, +# and asserts the model answers a ground-truth question. 
Proves +# end-to-end document-attach -> extract -> inference works on a CPU +# runner with no GPU. + +name: PR-5351 CPU inference Ubuntu + +on: + push: + branches: [pr-5351-cross-os-validation] + paths: + - 'studio/backend/core/chat/**' + - 'tests/studio/test_cpu_inference_on_extracted_document.py' + - '.github/workflows/pr5351-cpu-inference-ubuntu.yml' + workflow_dispatch: + +concurrency: + group: pr5351-cpu-inference-ubuntu-${{ github.ref }} + cancel-in-progress: true + +jobs: + cpu-inference: + runs-on: ubuntu-latest + timeout-minutes: 30 + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: '3.11' + cache: 'pip' + + - name: Install backend + llama-cpp-python (CPU build) + run: | + python -m pip install --upgrade pip + pip install -r studio/backend/requirements/studio.txt + pip install \ + python-multipart aiofiles sqlalchemy cryptography \ + pyyaml jinja2 mammoth pymupdf pymupdf4llm pytest pytest-asyncio \ + pytest-timeout huggingface_hub requests numpy + # CPU wheel ships pre-built on Linux; falls back to source if needed. + CMAKE_ARGS="-DGGML_NATIVE=OFF" pip install --upgrade --quiet llama-cpp-python + + - name: CPU inference on extracted document + env: + PR5351_LLAMA_THREADS: '4' + run: | + python -m pytest -q tests/studio/test_cpu_inference_on_extracted_document.py -s --tb=short diff --git a/.github/workflows/pr5351-cpu-inference-windows.yml b/.github/workflows/pr5351-cpu-inference-windows.yml new file mode 100644 index 0000000000..50972f17e7 --- /dev/null +++ b/.github/workflows/pr5351-cpu-inference-windows.yml @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. +# +# PR-5351 CPU-inference cross-OS lane: Windows. +# llama-cpp-python wheels exist for Windows; if pip falls back to +# source, MSVC is preinstalled on windows-latest. CPU-only. 
+ +name: PR-5351 CPU inference Windows + +on: + push: + branches: [pr-5351-cross-os-validation] + paths: + - 'studio/backend/core/chat/**' + - 'tests/studio/test_cpu_inference_on_extracted_document.py' + - '.github/workflows/pr5351-cpu-inference-windows.yml' + workflow_dispatch: + +concurrency: + group: pr5351-cpu-inference-windows-${{ github.ref }} + cancel-in-progress: true + +jobs: + cpu-inference: + runs-on: windows-latest + timeout-minutes: 40 + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: '3.11' + cache: 'pip' + + - name: Install backend + llama-cpp-python (CPU build) + shell: pwsh + run: | + python -m pip install --upgrade pip + pip install -r studio/backend/requirements/studio.txt + pip install python-multipart aiofiles sqlalchemy cryptography pyyaml jinja2 mammoth pymupdf pymupdf4llm pytest pytest-asyncio pytest-timeout huggingface_hub requests numpy + $env:CMAKE_ARGS = "-DGGML_NATIVE=OFF" + pip install --upgrade --quiet llama-cpp-python + + - name: CPU inference on extracted document + shell: pwsh + env: + PR5351_LLAMA_THREADS: '4' + run: | + python -m pytest -q tests/studio/test_cpu_inference_on_extracted_document.py -s --tb=short diff --git a/tests/studio/test_cpu_inference_on_extracted_document.py b/tests/studio/test_cpu_inference_on_extracted_document.py new file mode 100644 index 0000000000..9f2afeadb8 --- /dev/null +++ b/tests/studio/test_cpu_inference_on_extracted_document.py @@ -0,0 +1,157 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. + +"""PR-5351 cross-OS CPU-inference smoke test. + +End-to-end: extract a small public PDF locally (no network during +extraction), then feed the extracted markdown into a tiny GGUF via +llama-cpp-python on CPU and assert the model identifies the document. + +Runs on ubuntu-latest / macos-14 / windows-latest GitHub-Actions +runners. 
CPU-only; no real GPU is required because the test path +imports `_extract_pdf` directly and runs llama-cpp-python's CPU build. +""" + +from __future__ import annotations + +import importlib +import io +import os +import sys +import textwrap +from pathlib import Path + +import pytest + + +def _make_text_pdf(body: str) -> bytes: + """Build a tiny one-page PDF whose stream is the literal `body`. + + Avoids pulling a real LaTeX/wkhtmltopdf chain into CI -- the PR's + pymupdf-based extractor recovers the text via its standard pdfminer + fallback path even without a content-stream filter. + """ + pdf = io.BytesIO() + pdf.write(b"%PDF-1.4\n") + objects = [] + + def write(obj_bytes: bytes) -> int: + offset = pdf.tell() + objects.append(offset) + pdf.write(obj_bytes) + return len(objects) + + write(b"1 0 obj\n<< /Type /Catalog /Pages 2 0 R >>\nendobj\n") + write(b"2 0 obj\n<< /Type /Pages /Kids [3 0 R] /Count 1 >>\nendobj\n") + write( + b"3 0 obj\n<< /Type /Page /Parent 2 0 R /MediaBox [0 0 612 792] " + b"/Contents 4 0 R /Resources << /Font << /F1 5 0 R >> >> >>\nendobj\n" + ) + text_stream = ( + "BT\n/F1 12 Tf\n72 720 Td\n" + + "\n".join( + f"({line}) Tj T* " + for line in body.splitlines() + if line.strip() + ) + + "\nET\n" + ) + stream_bytes = text_stream.encode("latin-1", errors="replace") + write( + f"4 0 obj\n<< /Length {len(stream_bytes)} >>\nstream\n".encode("latin-1") + + stream_bytes + + b"\nendstream\nendobj\n" + ) + write(b"5 0 obj\n<< /Type /Font /Subtype /Type1 /BaseFont /Helvetica >>\nendobj\n") + xref_offset = pdf.tell() + pdf.write(f"xref\n0 {len(objects) + 1}\n0000000000 65535 f \n".encode()) + for off in objects: + pdf.write(f"{off:010d} 00000 n \n".encode()) + pdf.write( + f"trailer\n<< /Size {len(objects) + 1} /Root 1 0 R >>\n" + f"startxref\n{xref_offset}\n%%EOF\n".encode() + ) + return pdf.getvalue() + + +@pytest.fixture(scope="module") +def extractor(): + """Import the PR's `_extract_pdf` directly so this is a unit-level + test of the extractor + a 
CPU integration test of llama-cpp-python.""" + sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "studio" / "backend")) + mod = importlib.import_module("core.chat.document_extractor") + return mod._extract_pdf + + +@pytest.fixture(scope="module") +def llama(): + """Load a tiny GGUF on CPU. Skips if llama-cpp-python isn't installed.""" + pytest.importorskip("llama_cpp") + from huggingface_hub import hf_hub_download + from llama_cpp import Llama + + cache_dir = Path(os.environ.get("PR5351_GGUF_CACHE", str(Path.home() / ".cache" / "pr5351_gguf"))) + cache_dir.mkdir(parents=True, exist_ok=True) + # Tiny instruction-tuned model that fits 7 GB CPU runners. + repo = "unsloth/Qwen2.5-0.5B-Instruct-GGUF" + fname = "Qwen2.5-0.5B-Instruct-Q4_K_M.gguf" + path = hf_hub_download( + repo_id=repo, + filename=fname, + local_dir=str(cache_dir), + ) + return Llama( + model_path=path, + n_ctx=4096, + n_threads=int(os.environ.get("PR5351_LLAMA_THREADS", "2")), + verbose=False, + ) + + +@pytest.mark.timeout(900) +def test_cpu_inference_identifies_extracted_document(extractor, llama, tmp_path): + """Extract a synthetic PDF and have a 0.5B model identify it.""" + body = textwrap.dedent( + """ + RFC 8259 The JavaScript Object Notation (JSON) Data Interchange Format + Internet Engineering Task Force + Abstract: JSON is a lightweight, text-based, language-independent data + interchange format. It was derived from the JavaScript programming + language. JSON defines a small set of formatting rules for the + portable representation of structured data. + """ + ).strip() + pdf_bytes = _make_text_pdf(body) + + text, figures, *_ = extractor(pdf_bytes) + assert "JSON" in text or "Object Notation" in text, ( + f"Extractor lost the body text. Got: {text[:200]!r}" + ) + + prompt = textwrap.dedent( + f""" + You read attached documents and answer in 1-2 sentences. + + [DOCUMENT] + {text[:3000]} + [/DOCUMENT] + + Question: Which RFC number does this document define and what is JSON? 
+ Answer: + """ + ).strip() + + out = llama( + prompt, + max_tokens=160, + temperature=0.2, + stop=["\n\n", "</s>", "<|im_end|>"], + ) + answer = out["choices"][0]["text"].strip().lower() + print(f"\n[answer]\n{answer}\n") + + matched_keywords = [kw for kw in ("8259", "json", "object notation") if kw in answer] + assert len(matched_keywords) >= 2, ( + f"Answer missed too many keywords. Got: {answer!r}; " + f"matched: {matched_keywords}" + ) From 8efab55dd58f0b39868559e2cc98b4252877bd6f Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Wed, 27 May 2026 05:43:09 +0000 Subject: [PATCH 11/11] ci: fix PR-5351 CPU-inference test repo + extractor signature Two corrections after the first run: - Point at Qwen/Qwen2.5-0.5B-Instruct-GGUF (the canonical Qwen team's repo); the unsloth/* fork at that name does not exist and returned 401 on all three runners. - Pass the required `max_figures`, `use_vlm_ocr`, and `max_visual_payloads` kwargs to `_extract_pdf`. Verified locally on the merge tip: PYTHONPATH=studio/backend python -c '...' -> extracted 97 chars including the expected 'RFC 8259' / 'JSON' tokens. --- .../test_cpu_inference_on_extracted_document.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/studio/test_cpu_inference_on_extracted_document.py b/tests/studio/test_cpu_inference_on_extracted_document.py index 9f2afeadb8..aaf953cc63 100644 --- a/tests/studio/test_cpu_inference_on_extracted_document.py +++ b/tests/studio/test_cpu_inference_on_extracted_document.py @@ -93,8 +93,8 @@ def llama(): cache_dir = Path(os.environ.get("PR5351_GGUF_CACHE", str(Path.home() / ".cache" / "pr5351_gguf"))) cache_dir.mkdir(parents=True, exist_ok=True) # Tiny instruction-tuned model that fits 7 GB CPU runners.
- repo = "unsloth/Qwen2.5-0.5B-Instruct-GGUF" - fname = "Qwen2.5-0.5B-Instruct-Q4_K_M.gguf" + repo = "Qwen/Qwen2.5-0.5B-Instruct-GGUF" + fname = "qwen2.5-0.5b-instruct-q4_k_m.gguf" path = hf_hub_download( repo_id=repo, filename=fname, @@ -123,7 +123,12 @@ def test_cpu_inference_identifies_extracted_document(extractor, llama, tmp_path) ).strip() pdf_bytes = _make_text_pdf(body) - text, figures, *_ = extractor(pdf_bytes) + text, figures, *_ = extractor( + pdf_bytes, + max_figures=0, + use_vlm_ocr=False, + max_visual_payloads=0, + ) assert "JSON" in text or "Object Notation" in text, ( f"Extractor lost the body text. Got: {text[:200]!r}" )