diff --git a/.github/workflows/pr5351-cpu-inference-macos.yml b/.github/workflows/pr5351-cpu-inference-macos.yml
new file mode 100644
index 0000000000..df154f7354
--- /dev/null
+++ b/.github/workflows/pr5351-cpu-inference-macos.yml
@@ -0,0 +1,59 @@
+# SPDX-License-Identifier: AGPL-3.0-only
+# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved.
+#
+# PR-5351 CPU-inference cross-OS lane: macOS (Apple Silicon).
+# Same as the Ubuntu lane but on macos-14. llama-cpp-python builds
+# with Metal autodetect disabled to stay on the CPU code path so the
+# result mirrors a non-GPU Mac.
+
+name: PR-5351 CPU inference macOS
+
+on:
+  push:
+    branches: [pr-5351-cross-os-validation]
+    paths:
+      - 'studio/backend/core/chat/**'
+      - 'tests/studio/test_cpu_inference_on_extracted_document.py'
+      - '.github/workflows/pr5351-cpu-inference-macos.yml'
+  workflow_dispatch:
+
+# Least privilege: the job only reads the repository.
+permissions:
+  contents: read
+
+concurrency:
+  group: pr5351-cpu-inference-macos-${{ github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  cpu-inference:
+    runs-on: macos-14
+    timeout-minutes: 40
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          # No step pushes back; keep the token out of .git/config.
+          persist-credentials: false
+
+      - uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+          cache: 'pip'
+
+      - name: Install backend + llama-cpp-python (CPU build)
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r studio/backend/requirements/studio.txt
+          pip install \
+            python-multipart aiofiles sqlalchemy cryptography \
+            pyyaml jinja2 mammoth pymupdf pymupdf4llm pytest pytest-asyncio \
+            pytest-timeout huggingface_hub requests numpy
+          # Disable Metal so the lane stays CPU-only; mirrors a no-GPU Mac.
+          CMAKE_ARGS="-DGGML_METAL=OFF -DGGML_ACCELERATE=OFF -DGGML_NATIVE=OFF" \
+            pip install --upgrade --quiet llama-cpp-python
+
+      - name: CPU inference on extracted document
+        env:
+          PR5351_LLAMA_THREADS: '3'
+        run: |
+          python -m pytest -q tests/studio/test_cpu_inference_on_extracted_document.py -s --tb=short
diff --git a/.github/workflows/pr5351-cpu-inference-ubuntu.yml b/.github/workflows/pr5351-cpu-inference-ubuntu.yml
new file mode 100644
index 0000000000..4b0a441a12
--- /dev/null
+++ b/.github/workflows/pr5351-cpu-inference-ubuntu.yml
@@ -0,0 +1,62 @@
+# SPDX-License-Identifier: AGPL-3.0-only
+# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved.
+#
+# PR-5351 CPU-inference cross-OS lane: Ubuntu.
+# Installs llama-cpp-python for CPU (compiled from source when pip
+# finds no usable wheel), downloads a 0.5B GGUF from HF, extracts a
+# synthetic PDF via the PR's document extractor, and asserts the model
+# answers a ground-truth question. Proves end-to-end
+# document-attach -> extract -> inference works on a CPU runner with
+# no GPU.
+
+name: PR-5351 CPU inference Ubuntu
+
+on:
+  push:
+    branches: [pr-5351-cross-os-validation]
+    paths:
+      - 'studio/backend/core/chat/**'
+      - 'tests/studio/test_cpu_inference_on_extracted_document.py'
+      - '.github/workflows/pr5351-cpu-inference-ubuntu.yml'
+  workflow_dispatch:
+
+# Least privilege: the job only reads the repository.
+permissions:
+  contents: read
+
+concurrency:
+  group: pr5351-cpu-inference-ubuntu-${{ github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  cpu-inference:
+    runs-on: ubuntu-latest
+    timeout-minutes: 30
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          # No step pushes back; keep the token out of .git/config.
+          persist-credentials: false
+
+      - uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+          cache: 'pip'
+
+      - name: Install backend + llama-cpp-python (CPU build)
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r studio/backend/requirements/studio.txt
+          pip install \
+            python-multipart aiofiles sqlalchemy cryptography \
+            pyyaml jinja2 mammoth pymupdf pymupdf4llm pytest pytest-asyncio \
+            pytest-timeout huggingface_hub requests numpy
+          # CMAKE_ARGS only matters when pip compiles from source; a
+          # pre-built wheel, if pip finds one, is installed as-is.
+          CMAKE_ARGS="-DGGML_NATIVE=OFF" pip install --upgrade --quiet llama-cpp-python
+
+      - name: CPU inference on extracted document
+        env:
+          PR5351_LLAMA_THREADS: '4'
+        run: |
+          python -m pytest -q tests/studio/test_cpu_inference_on_extracted_document.py -s --tb=short
diff --git a/.github/workflows/pr5351-cpu-inference-windows.yml b/.github/workflows/pr5351-cpu-inference-windows.yml
new file mode 100644
index 0000000000..50972f17e7
--- /dev/null
+++ b/.github/workflows/pr5351-cpu-inference-windows.yml
@@ -0,0 +1,60 @@
+# SPDX-License-Identifier: AGPL-3.0-only
+# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved.
+#
+# PR-5351 CPU-inference cross-OS lane: Windows.
+# pip installs a llama-cpp-python wheel when one is available; if it
+# falls back to a source build, MSVC is preinstalled on
+# windows-latest. CPU-only.
+
+name: PR-5351 CPU inference Windows
+
+on:
+  push:
+    branches: [pr-5351-cross-os-validation]
+    paths:
+      - 'studio/backend/core/chat/**'
+      - 'tests/studio/test_cpu_inference_on_extracted_document.py'
+      - '.github/workflows/pr5351-cpu-inference-windows.yml'
+  workflow_dispatch:
+
+# Least privilege: the job only reads the repository.
+permissions:
+  contents: read
+
+concurrency:
+  group: pr5351-cpu-inference-windows-${{ github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  cpu-inference:
+    runs-on: windows-latest
+    timeout-minutes: 40
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          # No step pushes back; keep the token out of .git/config.
+          persist-credentials: false
+
+      - uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+          cache: 'pip'
+
+      - name: Install backend + llama-cpp-python (CPU build)
+        shell: pwsh
+        run: |
+          # pwsh only reports the last command's exit code by default;
+          # make any failing pip call abort the step immediately.
+          $PSNativeCommandUseErrorActionPreference = $true
+          python -m pip install --upgrade pip
+          pip install -r studio/backend/requirements/studio.txt
+          pip install python-multipart aiofiles sqlalchemy cryptography pyyaml jinja2 mammoth pymupdf pymupdf4llm pytest pytest-asyncio pytest-timeout huggingface_hub requests numpy
+          $env:CMAKE_ARGS = "-DGGML_NATIVE=OFF"
+          pip install --upgrade --quiet llama-cpp-python
+
+      - name: CPU inference on extracted document
+        shell: pwsh
+        env:
+          PR5351_LLAMA_THREADS: '4'
+        run: |
+          python -m pytest -q tests/studio/test_cpu_inference_on_extracted_document.py -s --tb=short
diff --git a/.github/workflows/pr5351-macos.yml b/.github/workflows/pr5351-macos.yml
new file mode 100644
index 0000000000..6bb149659b
--- /dev/null
+++ b/.github/workflows/pr5351-macos.yml
@@ -0,0 +1,67 @@
+# SPDX-License-Identifier: AGPL-3.0-only
+# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved.
+#
+# PR-5351 cross-OS validation: macOS lane.
+# macos-14 (arm64). Validates the multiprocessing `spawn` path that
+# differs from Linux's default `fork`, the MLX detection branch in
+# core/chat/vlm_capability.py, and Safari/WebKit-relevant filesystem
+# behaviour. CPU-only; CUDA spoof auto-engages via tests/conftest.py.
+
+name: PR-5351 macOS
+
+on:
+  push:
+    branches: [pr-5351-cross-os-validation]
+    paths:
+      - 'studio/backend/**'
+      - 'tests/studio/**'
+      - 'tests/conftest.py'
+      - '.github/workflows/pr5351-macos.yml'
+  workflow_dispatch:
+
+# Least privilege: the job only reads the repository.
+permissions:
+  contents: read
+
+concurrency:
+  group: pr5351-macos-${{ github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  pytest:
+    runs-on: macos-14
+    timeout-minutes: 25
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          # No step pushes back; keep the token out of .git/config.
+          persist-credentials: false
+
+      - uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+          cache: 'pip'
+
+      - name: Install backend test dependencies (CPU only)
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r studio/backend/requirements/studio.txt
+          pip install \
+            python-multipart aiofiles sqlalchemy cryptography \
+            pyyaml jinja2 mammoth unpdf requests \
+            'numpy<3' pytest pytest-asyncio httpx
+          pip install --index-url https://download.pytorch.org/whl/cpu 'torch>=2.4,<2.11'
+          pip install 'transformers>=4.51,<5.5'
+
+      - name: PR-5351 document tests (macOS spawn semantics)
+        working-directory: studio/backend
+        env:
+          # macOS's default start method is spawn; exercise the same
+          # config users see in production.
+          UNSLOTH_STUDIO_EXTRACT_CONCURRENCY: '2'
+        run: |
+          python -m pytest -q tests/test_chat_document_extraction.py tests/test_chat_document_routes.py tests/test_inference_worker.py tests/test_vision_cache.py tests/test_anthropic_messages.py tests/test_openai_tool_passthrough.py tests/test_models_get_model_config_case_resolution.py --tb=short
+
+      - name: PR-5351 regression tests + cancel timing
+        run: |
+          python -m pytest -q tests/studio/test_extractor_semaphore_leak.py tests/studio/test_html_independent_of_inference.py tests/studio/test_gguf_singleton_shared.py tests/studio/test_stream_cancel_registration_timing.py --tb=short
diff --git a/.github/workflows/pr5351-ubuntu.yml b/.github/workflows/pr5351-ubuntu.yml
new file mode 100644
index 0000000000..d1dd6d8712
--- /dev/null
+++ b/.github/workflows/pr5351-ubuntu.yml
@@ -0,0 +1,64 @@
+# SPDX-License-Identifier: AGPL-3.0-only
+# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved.
+#
+# PR-5351 cross-OS validation: Ubuntu lane.
+# Runs the document-extraction tests, the cancellation-timing structural
+# test, and the three regression tests added in the fix commit against
+# Python 3.11 on ubuntu-latest. CPU-only; the existing tests/conftest.py
+# auto-installs the CUDA spoof so unsloth/unsloth_zoo device probes
+# return "cuda".
+
+name: PR-5351 Ubuntu
+
+on:
+  push:
+    branches: [pr-5351-cross-os-validation]
+    paths:
+      - 'studio/backend/**'
+      - 'tests/studio/**'
+      - 'tests/conftest.py'
+      - '.github/workflows/pr5351-ubuntu.yml'
+  workflow_dispatch:
+
+# Least privilege: the job only reads the repository.
+permissions:
+  contents: read
+
+concurrency:
+  group: pr5351-ubuntu-${{ github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  pytest:
+    runs-on: ubuntu-latest
+    timeout-minutes: 20
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          # No step pushes back; keep the token out of .git/config.
+          persist-credentials: false
+
+      - uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+          cache: 'pip'
+
+      - name: Install backend test dependencies (CPU only)
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r studio/backend/requirements/studio.txt
+          pip install \
+            python-multipart aiofiles sqlalchemy cryptography \
+            pyyaml jinja2 mammoth unpdf requests \
+            'numpy<3' pytest pytest-asyncio httpx
+          pip install --index-url https://download.pytorch.org/whl/cpu 'torch>=2.4,<2.11'
+          pip install 'transformers>=4.51,<5.5'
+
+      - name: PR-5351 document tests
+        working-directory: studio/backend
+        run: |
+          python -m pytest -q tests/test_chat_document_extraction.py tests/test_chat_document_routes.py tests/test_inference_worker.py tests/test_vision_cache.py tests/test_anthropic_messages.py tests/test_openai_tool_passthrough.py tests/test_models_get_model_config_case_resolution.py --tb=short
+
+      - name: PR-5351 regression tests + cancel timing
+        run: |
+          python -m pytest -q tests/studio/test_extractor_semaphore_leak.py tests/studio/test_html_independent_of_inference.py tests/studio/test_gguf_singleton_shared.py tests/studio/test_stream_cancel_registration_timing.py --tb=short
diff --git a/.github/workflows/pr5351-windows.yml b/.github/workflows/pr5351-windows.yml
new file mode 100644
index 0000000000..777e1c38ec
--- /dev/null
+++ b/.github/workflows/pr5351-windows.yml
@@ -0,0 +1,69 @@
+# SPDX-License-Identifier: AGPL-3.0-only
+# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved.
+#
+# PR-5351 cross-OS validation: Windows lane.
+# windows-latest. Validates the multiprocessing `spawn` path
+# (mandatory on Windows), path normalisation, and EAGAIN-style
+# Process construction failures under load (the exact bug class the
+# semaphore-leak fix protects against). CPU-only; CUDA spoof
+# auto-engages via tests/conftest.py.
+
+name: PR-5351 Windows
+
+on:
+  push:
+    branches: [pr-5351-cross-os-validation]
+    paths:
+      - 'studio/backend/**'
+      - 'tests/studio/**'
+      - 'tests/conftest.py'
+      - '.github/workflows/pr5351-windows.yml'
+  workflow_dispatch:
+
+# Least privilege: the job only reads the repository.
+permissions:
+  contents: read
+
+concurrency:
+  group: pr5351-windows-${{ github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  pytest:
+    runs-on: windows-latest
+    timeout-minutes: 30
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          # No step pushes back; keep the token out of .git/config.
+          persist-credentials: false
+
+      - uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+          cache: 'pip'
+
+      - name: Install backend test dependencies (CPU only)
+        shell: pwsh
+        run: |
+          # pwsh only reports the last command's exit code by default;
+          # make any failing pip call abort the step immediately.
+          $PSNativeCommandUseErrorActionPreference = $true
+          python -m pip install --upgrade pip
+          pip install -r studio/backend/requirements/studio.txt
+          pip install python-multipart aiofiles sqlalchemy cryptography pyyaml jinja2 mammoth unpdf requests "numpy<3" pytest pytest-asyncio httpx
+          pip install --index-url https://download.pytorch.org/whl/cpu "torch>=2.4,<2.11"
+          pip install "transformers>=4.51,<5.5"
+
+      - name: PR-5351 document tests (Windows spawn semantics)
+        working-directory: studio/backend
+        shell: pwsh
+        env:
+          UNSLOTH_STUDIO_EXTRACT_CONCURRENCY: '2'
+        run: |
+          python -m pytest -q tests/test_chat_document_extraction.py tests/test_chat_document_routes.py tests/test_inference_worker.py tests/test_vision_cache.py tests/test_anthropic_messages.py tests/test_openai_tool_passthrough.py tests/test_models_get_model_config_case_resolution.py --tb=short
+
+      - name: PR-5351 regression tests + cancel timing
+        shell: pwsh
+        run: |
+          python -m pytest -q tests/studio/test_extractor_semaphore_leak.py tests/studio/test_html_independent_of_inference.py tests/studio/test_gguf_singleton_shared.py tests/studio/test_stream_cancel_registration_timing.py --tb=short
diff --git a/.github/workflows/release-desktop.yml b/.github/workflows/release-desktop.yml
deleted file mode 100644
index e747605322..0000000000
--- a/.github/workflows/release-desktop.yml
+++ /dev/null
@@ -1,902 +0,0 @@
-name: Release Desktop App
-
-on:
- workflow_dispatch:
- inputs:
- studio_version:
- description: 'Studio version tag to release (for example, v0.1.39-beta)'
- type: string
- required: true
- pypi_version:
- description: 'Exact PyPI unsloth version just published/stamped (for example, 2026.5.3); leave blank to use MIN_DESKTOP_BACKEND_VERSION'
- type: string
- required: false
- draft:
- description: 'Create as draft release; draft runs do not advance desktop-latest updater channel'
- type: boolean
- default: true
-
-permissions:
- contents: read
-
-concurrency:
- group: release-desktop-${{ github.repository }}
- cancel-in-progress: false
-
-jobs:
- prepare-version:
- name: Prepare release versions
- runs-on: ubuntu-latest
- outputs:
- studio_version: ${{ steps.prepare.outputs.studio_version }}
- app_version: ${{ steps.prepare.outputs.app_version }}
- desktop_release_tag: ${{ steps.prepare.outputs.desktop_release_tag }}
- prerelease: ${{ steps.prepare.outputs.prerelease }}
- pypi_version: ${{ steps.prepare.outputs.pypi_version }}
-
- steps:
- - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd
- with:
- persist-credentials: false
-
- - name: Validate release versions
- id: prepare
- shell: bash
- env:
- INPUT_STUDIO_VERSION: ${{ inputs.studio_version }}
- INPUT_PYPI_VERSION: ${{ inputs.pypi_version }}
- run: |
- python3 <<'PY'
- import os
- import pathlib
- import re
- import sys
-
- studio_version = os.environ['INPUT_STUDIO_VERSION'].strip()
- if not studio_version:
- sys.exit('studio_version is required, for example v0.1.39-beta')
- if re.fullmatch(r'v?20\d{2}\.\d+\.\d+(?:[-+][0-9A-Za-z.-]+)?', studio_version):
- sys.exit(f'studio_version must be a Studio SemVer tag, not a date-style backend version: {studio_version}')
-
- semver_tag = re.compile(
- r'^v(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)'
- r'(?:-[0-9A-Za-z.][0-9A-Za-z.-]*)?$'
- )
- if not semver_tag.fullmatch(studio_version):
- sys.exit(f'studio_version must be a SemVer tag with leading v, for example v0.1.39-beta: {studio_version}')
-
- app_version = studio_version.removeprefix('v')
- desktop_release_tag = f'desktop-v{app_version}'
- prerelease = 'true' if '-' in app_version.split('+', 1)[0] else 'false'
-
- def parse_backend_version(version):
- match = re.fullmatch(
- r'(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)'
- r'(?:([a-zA-Z]|\.dev|dev|\.rc|rc|\.post|post)(\d*))?'
- r'(?:[-+]([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?',
- version,
- )
- if not match:
- return None
- major, minor, patch, suffix_name, suffix_number, suffix_text = match.groups()
- if suffix_name:
- normalized = suffix_name.lower().lstrip('.')
- order = {'dev': 0, 'a': 1, 'b': 2, 'rc': 3, 'post': 5}.get(normalized)
- if order is None:
- return None
- number = int(suffix_number or '0')
- elif suffix_text:
- order = 3 if version[version.find(suffix_text) - 1] == '-' else 4
- number = 0
- else:
- order = 4
- number = 0
- return (int(major), int(minor), int(patch), order, number)
-
- preflight = pathlib.Path('studio/src-tauri/src/preflight/version.rs').read_text()
- match = re.search(r'MIN_DESKTOP_BACKEND_VERSION:\s*&str\s*=\s*"([^"]+)"', preflight)
- if not match:
- sys.exit('Could not read MIN_DESKTOP_BACKEND_VERSION')
- min_backend_version = match.group(1)
-
- input_pypi_version = os.environ.get('INPUT_PYPI_VERSION', '').strip()
- parsed_min_backend = parse_backend_version(min_backend_version)
- if parsed_min_backend is None:
- sys.exit(f'MIN_DESKTOP_BACKEND_VERSION is not a supported backend package version: {min_backend_version}')
-
- pypi_version = input_pypi_version or min_backend_version
- parsed_pypi = parse_backend_version(pypi_version)
- if parsed_pypi is None:
- sys.exit(f'pypi_version is not a supported backend package version: {pypi_version}')
- if parsed_pypi < parsed_min_backend:
- sys.exit(
- f'pypi_version {pypi_version} is lower than desktop minimum '
- f'MIN_DESKTOP_BACKEND_VERSION {min_backend_version}'
- )
-
- if input_pypi_version:
- print(
- 'Using exact PyPI unsloth version from pypi_version input: '
- f'{pypi_version} (desktop minimum: {min_backend_version})'
- )
- else:
- print(
- 'Using exact PyPI unsloth version from MIN_DESKTOP_BACKEND_VERSION: '
- f'{pypi_version}'
- )
-
- with open(os.environ['GITHUB_OUTPUT'], 'a', encoding='utf-8') as output:
- print(f'studio_version={studio_version}', file=output)
- print(f'app_version={app_version}', file=output)
- print(f'desktop_release_tag={desktop_release_tag}', file=output)
- print(f'prerelease={prerelease}', file=output)
- print(f'pypi_version={pypi_version}', file=output)
- PY
-
- - name: Verify PyPI package and Studio stamp
- shell: bash
- env:
- STUDIO_VERSION: ${{ steps.prepare.outputs.studio_version }}
- PYPI_VERSION: ${{ steps.prepare.outputs.pypi_version }}
- run: |
- set -euo pipefail
- python3 <<'PY'
- import json
- import os
- import pathlib
- import sys
- import time
- import urllib.error
- import urllib.request
-
- pypi_version = os.environ['PYPI_VERSION']
- dist_dir = pathlib.Path(os.environ['RUNNER_TEMP'], 'pypi-unsloth-dist')
- dist_dir.mkdir(parents=True, exist_ok=True)
- metadata_url = f'https://pypi.org/pypi/unsloth/{pypi_version}/json'
-
- last_error = None
- for attempt in range(1, 6):
- try:
- with urllib.request.urlopen(metadata_url, timeout=30) as response:
- metadata = json.load(response)
- break
- except Exception as exc:
- last_error = exc
- if attempt < 5:
- time.sleep(10 * attempt)
- else:
- sys.exit(f'Publish unsloth=={pypi_version} to PyPI before the desktop release ({last_error})')
-
- files = metadata.get('urls') or []
- if not files:
- sys.exit(f'PyPI returned no distribution files for unsloth=={pypi_version}')
-
- for file_info in files:
- filename = file_info.get('filename')
- url = file_info.get('url')
- if not filename or '/' in filename or not url:
- sys.exit(f'Unexpected PyPI file entry for unsloth=={pypi_version}: {file_info!r}')
- target = dist_dir / filename
- for attempt in range(1, 4):
- try:
- with urllib.request.urlopen(url, timeout=60) as response:
- target.write_bytes(response.read())
- break
- except Exception as exc:
- last_error = exc
- if attempt < 3:
- time.sleep(5 * attempt)
- else:
- sys.exit(f'Could not download {filename} from PyPI ({last_error})')
- PY
-
- if [ -f scripts/stamp_studio_release.py ]; then
- mapfile -t dists < <(find "$RUNNER_TEMP/pypi-unsloth-dist" -type f \( -name '*.whl' -o -name '*.tar.gz' \) | sort)
- if [ "${#dists[@]}" -eq 0 ]; then
- echo "No PyPI wheel/sdist artifacts downloaded for unsloth==$PYPI_VERSION" >&2
- exit 1
- fi
- python3 scripts/stamp_studio_release.py --verify-dist "$RUNNER_TEMP/pypi-unsloth-dist" --expected "$STUDIO_VERSION"
- else
- echo "scripts/stamp_studio_release.py not found; release-desktop requires #5308 to verify the PyPI Studio stamp." >&2
- exit 1
- fi
-
- - name: Guard public updater channel version
- if: ${{ !inputs.draft }}
- shell: bash
- env:
- GH_REPO: ${{ github.repository }}
- GH_TOKEN: ${{ github.token }}
- APP_VERSION: ${{ steps.prepare.outputs.app_version }}
- run: |
- set -euo pipefail
- mkdir -p "$RUNNER_TEMP/desktop-current"
- if ! gh release download desktop-latest --pattern latest.json --dir "$RUNNER_TEMP/desktop-current" --clobber 2>/dev/null; then
- echo "No existing desktop-latest latest.json found; allowing first channel publish."
- exit 0
- fi
- python3 <<'PY'
- import json
- import os
- import pathlib
- import re
- import sys
-
- def parse(value: str):
- value = value.removeprefix('v')
- match = re.fullmatch(
- r'(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)'
- r'(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?'
- r'(?:\+[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*)?',
- value,
- )
- if not match:
- sys.exit(f'desktop-latest latest.json has invalid version: {value}')
- major, minor, patch, prerelease = match.groups()
- return (int(major), int(minor), int(patch), prerelease)
-
- def numeric_tail(identifier: str) -> tuple[str, int] | None:
- match = re.fullmatch(r'([A-Za-z-]+)(\d+)', identifier)
- if not match:
- return None
- return (match.group(1).lower(), int(match.group(2)))
-
- def compare_identifier(left: str, right: str) -> int:
- left_num = left.isdigit()
- right_num = right.isdigit()
- if left_num and right_num:
- return (int(left) > int(right)) - (int(left) < int(right))
- if left_num:
- return -1
- if right_num:
- return 1
-
- left_tail = numeric_tail(left)
- right_tail = numeric_tail(right)
- if left_tail and right_tail and left_tail[0] == right_tail[0]:
- return (left_tail[1] > right_tail[1]) - (left_tail[1] < right_tail[1])
-
- return (left > right) - (left < right)
-
- def compare_prerelease(left: str | None, right: str | None) -> int:
- if left == right:
- return 0
- if left is None:
- return 1
- if right is None:
- return -1
- left_parts = left.split('.')
- right_parts = right.split('.')
- for left_part, right_part in zip(left_parts, right_parts):
- order = compare_identifier(left_part, right_part)
- if order:
- return order
- return (len(left_parts) > len(right_parts)) - (len(left_parts) < len(right_parts))
-
- def compare(left: str, right: str) -> int:
- left_major, left_minor, left_patch, left_pre = parse(left)
- right_major, right_minor, right_patch, right_pre = parse(right)
- left_core = (left_major, left_minor, left_patch)
- right_core = (right_major, right_minor, right_patch)
- if left_core != right_core:
- return (left_core > right_core) - (left_core < right_core)
- return compare_prerelease(left_pre, right_pre)
-
- current_path = pathlib.Path(os.environ['RUNNER_TEMP'], 'desktop-current', 'latest.json')
- current = json.loads(current_path.read_text()).get('version')
- next_version = os.environ['APP_VERSION']
- if not isinstance(current, str):
- sys.exit('desktop-latest latest.json has missing version')
- if compare(next_version, current) < 0:
- sys.exit(
- f'Refusing to publish {next_version}; desktop-latest currently points at newer version {current}.'
- )
- PY
-
- build:
- # TODO: split into a "build (no secrets)" + "publish (secrets)" job pair
- # with actions/upload-artifact handoff so the matrix build cannot
- # publish a Release on its own. The current matrix runs across
- # Linux/macOS/Windows in a single job, so the split needs artefact
- # collection across the OS matrix and is out of scope for this
- # hardening pass.
- permissions:
- contents: write # tauri-apps/tauri-action creates / uploads a GitHub Release
- strategy:
- fail-fast: false
- max-parallel: 1
- matrix:
- include:
- - platform: macos-latest
- args: '--target aarch64-apple-darwin'
- label: macOS (Apple Silicon)
- # - platform: macos-latest
- # args: '--target x86_64-apple-darwin'
- # label: macOS (Intel)
- - platform: ubuntu-22.04
- args: ''
- label: Linux (x64)
- - platform: windows-latest
- args: ''
- label: Windows (x64)
-
- name: Build ${{ matrix.label }}
- needs: prepare-version
- runs-on: ${{ matrix.platform }}
-
- env:
- FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true
- APP_VERSION: ${{ needs.prepare-version.outputs.app_version }}
- STUDIO_VERSION: ${{ needs.prepare-version.outputs.studio_version }}
- DESKTOP_RELEASE_TAG: ${{ needs.prepare-version.outputs.desktop_release_tag }}
- DESKTOP_PRERELEASE: ${{ needs.prepare-version.outputs.prerelease }}
-
- steps:
- # harden-runner in audit mode: surfaces every egress destination in
- # the runner log so the allowlist for a future `egress-policy: block`
- # promotion can be derived from observed traffic. Audit mode is
- # cross-platform (Linux / macOS / Windows runners); blocking mode is
- # currently Linux-only, so we deliberately stay in audit until the
- # macOS + Windows codesign paths have been observed.
- - name: Harden runner (audit)
- uses: step-security/harden-runner@a5ad31d6a139d249332a2605b85202e8c0b78450 # v2.19.1
- with:
- egress-policy: audit
-
- - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd
- with:
- persist-credentials: false
-
- # ── Linux dependencies ──
- - name: Install Linux dependencies
- if: matrix.platform == 'ubuntu-22.04'
- run: |
- sudo apt-get update
- sudo apt-get install -y libwebkit2gtk-4.1-dev libayatana-appindicator3-dev librsvg2-dev libxdo-dev libssl-dev patchelf
-
- # ── Node.js ──
- - name: Setup Node.js
- uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e
- with:
- node-version: 24
-
- - name: Install pinned Tauri CLI
- # Lifecycle scripts (esbuild native-binary postinstall, etc.) are
- # required for `vite build`. The pre-install lockfile structural
- # audit (lockfile_supply_chain_audit.py) is the practical defence
- # against the npm postinstall-dropper class -- it fires BEFORE any
- # tarball runs, on the injection pattern itself rather than an
- # advisory-DB lookup.
- run: npm install --save-dev --prefix studio @tauri-apps/cli@2.10.1 --no-fund --no-audit
-
- - name: Verify pinned Tauri CLI
- shell: bash
- run: |
- out="$(npx --prefix studio tauri --version)"
- echo "$out"
- if [ "$out" != "tauri-cli 2.10.1" ]; then
- echo "Expected tauri-cli 2.10.1, got $out" >&2
- exit 1
- fi
-
- - name: Verify desktop updater and Linux package config
- shell: bash
- run: |
- node <<'JS'
- const { readFileSync } = require('node:fs');
-
- const expected = 'https://github.com/unslothai/unsloth/releases/download/desktop-latest/latest.json';
- const config = JSON.parse(readFileSync('studio/src-tauri/tauri.conf.json', 'utf8'));
- const endpoints = config.plugins?.updater?.endpoints;
- if (!Array.isArray(endpoints) || endpoints.length !== 1) {
- throw new Error('Expected exactly one desktop updater endpoint');
- }
- if (endpoints[0] !== expected) {
- throw new Error('Desktop updater endpoint must be ' + expected + ', got ' + endpoints[0]);
- }
- if (endpoints.some((endpoint) => endpoint.includes('/releases/latest/'))) {
- throw new Error('Desktop updater endpoint must not use repo-wide /releases/latest/');
- }
-
- const targets = config.bundle?.targets;
- if (Array.isArray(targets) && targets.some((target) => String(target).toLowerCase() === 'rpm')) {
- throw new Error('Desktop release must not target RPM packages');
- }
- if (config.bundle?.linux?.rpm) {
- throw new Error('bundle.linux.rpm must not be configured');
- }
-
- const workflow = readFileSync('.github/workflows/release-desktop.yml', 'utf8');
- const lines = workflow.split(/\r?\n/);
- const releaseBodies = [];
- for (let i = 0; i < lines.length; i += 1) {
- const match = lines[i].match(/^(\s*)releaseBody:\s*\|\s*$/);
- if (!match) continue;
- const baseIndent = match[1].length;
- const bodyLines = [];
- i += 1;
- for (; i < lines.length; i += 1) {
- const line = lines[i];
- if (line.trim() === '') {
- bodyLines.push('');
- continue;
- }
- const indent = line.match(/^\s*/)[0].length;
- if (indent <= baseIndent) {
- i -= 1;
- break;
- }
- bodyLines.push(line.slice(baseIndent + 2));
- }
- releaseBodies.push(bodyLines.join('\n'));
- }
- if (releaseBodies.length === 0) {
- throw new Error('Expected at least one desktop release body');
- }
- for (const body of releaseBodies) {
- if (/\brpm\b|\.rpm/i.test(body)) {
- throw new Error('Desktop release body must not advertise RPM packages');
- }
- }
- JS
-
- - name: Install frontend dependencies
- working-directory: studio/frontend
- # Lifecycle scripts (esbuild native-binary postinstall, etc.) are
- # required for `vite build`. The pre-install lockfile structural
- # audit (lockfile_supply_chain_audit.py) is the practical defence
- # against the npm postinstall-dropper class -- it fires BEFORE any
- # tarball runs, on the injection pattern itself rather than an
- # advisory-DB lookup.
- run: npm install --no-fund --no-audit
-
- # ── Rust ──
- - name: Install Rust stable
- uses: dtolnay/rust-toolchain@29eef336d9b2848a0b548edc03f92a220660cdb8 # stable @ 2026-03-27
- with:
- targets: ${{ matrix.platform == 'macos-latest' && 'aarch64-apple-darwin,x86_64-apple-darwin' || '' }}
-
- - name: Patch desktop app version
- shell: bash
- working-directory: studio/src-tauri
- run: |
- set -euo pipefail
- if command -v python3 >/dev/null 2>&1; then
- PYTHON=python3
- else
- PYTHON=python
- fi
- "$PYTHON" <<'PY'
- import os
- import pathlib
- import re
- import sys
-
- app_version = os.environ['APP_VERSION']
- if not app_version:
- sys.exit('APP_VERSION is required')
-
- cargo_toml = pathlib.Path('Cargo.toml')
- lines = cargo_toml.read_text().splitlines(keepends=True)
- in_package = False
- patched = False
- for index, line in enumerate(lines):
- stripped = line.strip()
- if stripped == '[package]':
- in_package = True
- continue
- if stripped.startswith('[') and stripped.endswith(']'):
- in_package = False
- if in_package and re.fullmatch(r'version\s*=\s*"[^"]+"\s*', stripped):
- lines[index] = f'version = "{app_version}"\n'
- patched = True
- break
- if not patched:
- sys.exit('Could not patch [package] version in Cargo.toml')
- cargo_toml.write_text(''.join(lines))
-
- cargo_lock = pathlib.Path('Cargo.lock')
- lock_text = cargo_lock.read_text()
- lock_text, count = re.subn(
- r'(?m)(^\[\[package\]\]\nname = "unsloth-studio"\nversion = ")[^"]+(")',
- lambda match: f'{match.group(1)}{app_version}{match.group(2)}',
- lock_text,
- )
- if count != 1:
- sys.exit(f'Could not patch unsloth-studio version in Cargo.lock (matches={count})')
- cargo_lock.write_text(lock_text)
- PY
-
- cargo metadata --locked --no-deps --format-version 1 > "$RUNNER_TEMP/cargo-metadata.json"
- "$PYTHON" <<'PY'
- import json
- import os
- import pathlib
- import sys
-
- app_version = os.environ['APP_VERSION']
- metadata = json.loads(pathlib.Path(os.environ['RUNNER_TEMP'], 'cargo-metadata.json').read_text())
- versions = [package['version'] for package in metadata.get('packages', []) if package.get('name') == 'unsloth-studio']
- if versions != [app_version]:
- sys.exit(f'cargo metadata unsloth-studio version mismatch: expected {app_version}, got {versions}')
- PY
-
- git diff -- Cargo.toml Cargo.lock
-
- - name: Rust cache
- uses: swatinem/rust-cache@e18b497796c12c097a38f9edb9d0641fb99eee32
- with:
- workspaces: 'studio/src-tauri -> target'
-
- # ── macOS: import signing certificate ──
- - name: Import Apple certificate
- if: matrix.platform == 'macos-latest'
- env:
- APPLE_CERTIFICATE: ${{ secrets.APPLE_CERTIFICATE }}
- APPLE_CERTIFICATE_PASSWORD: ${{ secrets.APPLE_CERTIFICATE_PASSWORD }}
- KEYCHAIN_PASSWORD: ${{ secrets.KEYCHAIN_PASSWORD }}
- run: |
- echo $APPLE_CERTIFICATE | base64 --decode > certificate.p12
- security create-keychain -p "$KEYCHAIN_PASSWORD" build.keychain
- security default-keychain -s build.keychain
- security unlock-keychain -p "$KEYCHAIN_PASSWORD" build.keychain
- security set-keychain-settings -t 3600 -u build.keychain
- security import certificate.p12 -k build.keychain -P "$APPLE_CERTIFICATE_PASSWORD" -T /usr/bin/codesign
- security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k "$KEYCHAIN_PASSWORD" build.keychain
- security find-identity -v -p codesigning build.keychain
- rm -f certificate.p12
-
- # ── Windows: install Azure Trusted Signing CLI ──
- - name: Install trusted-signing-cli
- if: matrix.platform == 'windows-latest'
- run: |
- cargo install trusted-signing-cli --version 0.10.0 --locked
- echo "$env:USERPROFILE\.cargo\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
-
- # ── Windows: verify signing CLI is accessible ──
- - name: Verify trusted-signing-cli
- if: matrix.platform == 'windows-latest'
- run: |
- Write-Output "PATH: $env:PATH"
- Get-Command trusted-signing-cli -ErrorAction SilentlyContinue || Write-Output "trusted-signing-cli NOT in PATH"
- trusted-signing-cli --version || Write-Output "trusted-signing-cli failed to run"
-
- # ── Linux: build + sign + upload ──
- - name: Build Linux app
- if: matrix.platform == 'ubuntu-22.04'
- uses: tauri-apps/tauri-action@84b9d35b5fc46c1e45415bdb6144030364f7ebc5
- env:
- GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- TAURI_SIGNING_PRIVATE_KEY: ${{ secrets.TAURI_SIGNING_PRIVATE_KEY }}
- TAURI_SIGNING_PRIVATE_KEY_PASSWORD: ${{ secrets.TAURI_SIGNING_PRIVATE_KEY_PASSWORD }}
- with:
- projectPath: studio
- tauriScript: npx --prefix . tauri
- tagName: ${{ needs.prepare-version.outputs.desktop_release_tag }}
- releaseName: 'Unsloth Studio (Desktop) ${{ needs.prepare-version.outputs.studio_version }}'
- releaseBody: |
- Desktop app for Unsloth Studio.
-
- **macOS**: Download the Apple Silicon `.dmg`.
- **Windows**: Download the `-setup.exe` installer.
- **Linux**: Download `.deb` (Ubuntu/Debian) or `.AppImage` (universal).
-
- > Linux in-app updates are AppImage-oriented. Package installs should update by downloading a new package.
- > Linux AppImage on Ubuntu 24.04+ may require: `sudo apt install libfuse2t64`
- > First-run system dependency elevation is supported on Ubuntu/Debian. Other Linux distributions should install system packages manually.
- releaseDraft: ${{ inputs.draft }}
- prerelease: ${{ needs.prepare-version.outputs.prerelease }}
- args: -v ${{ matrix.args }}
-
- # ── macOS: build + sign + notarize + upload ──
- - name: Build macOS app
- if: matrix.platform == 'macos-latest'
- uses: tauri-apps/tauri-action@84b9d35b5fc46c1e45415bdb6144030364f7ebc5
- env:
- GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- TAURI_SIGNING_PRIVATE_KEY: ${{ secrets.TAURI_SIGNING_PRIVATE_KEY }}
- TAURI_SIGNING_PRIVATE_KEY_PASSWORD: ${{ secrets.TAURI_SIGNING_PRIVATE_KEY_PASSWORD }}
- APPLE_SIGNING_IDENTITY: ${{ secrets.APPLE_SIGNING_IDENTITY }}
- APPLE_ID: ${{ secrets.APPLE_ID }}
- APPLE_PASSWORD: ${{ secrets.APPLE_PASSWORD }}
- APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
- with:
- projectPath: studio
- tauriScript: npx --prefix . tauri
- tagName: ${{ needs.prepare-version.outputs.desktop_release_tag }}
- releaseName: 'Unsloth Studio (Desktop) ${{ needs.prepare-version.outputs.studio_version }}'
- releaseBody: |
- Desktop app for Unsloth Studio.
-
- **macOS**: Download the Apple Silicon `.dmg`.
- **Windows**: Download the `-setup.exe` installer.
- **Linux**: Download `.deb` (Ubuntu/Debian) or `.AppImage` (universal).
-
- > Linux in-app updates are AppImage-oriented. Package installs should update by downloading a new package.
- > Linux AppImage on Ubuntu 24.04+ may require: `sudo apt install libfuse2t64`
- > First-run system dependency elevation is supported on Ubuntu/Debian. Other Linux distributions should install system packages manually.
- releaseDraft: ${{ inputs.draft }}
- prerelease: ${{ needs.prepare-version.outputs.prerelease }}
- args: -v ${{ matrix.args }}
-
- # ── Windows: build + sign + upload ──
- - name: Build Windows app
- if: matrix.platform == 'windows-latest'
- uses: tauri-apps/tauri-action@84b9d35b5fc46c1e45415bdb6144030364f7ebc5
- env:
- GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- TAURI_SIGNING_PRIVATE_KEY: ${{ secrets.TAURI_SIGNING_PRIVATE_KEY }}
- TAURI_SIGNING_PRIVATE_KEY_PASSWORD: ${{ secrets.TAURI_SIGNING_PRIVATE_KEY_PASSWORD }}
- AZURE_CLIENT_ID: ${{ secrets.AZURE_CLIENT_ID }}
- AZURE_CLIENT_SECRET: ${{ secrets.AZURE_CLIENT_SECRET }}
- AZURE_TENANT_ID: ${{ secrets.AZURE_TENANT_ID }}
- AZURE_TRUSTED_SIGNING_ACCOUNT_NAME: ${{ secrets.AZURE_TRUSTED_SIGNING_ACCOUNT_NAME }}
- AZURE_CERTIFICATE_PROFILE_NAME: ${{ secrets.AZURE_CERTIFICATE_PROFILE_NAME }}
- with:
- projectPath: studio
- tauriScript: npx --prefix . tauri
- tagName: ${{ needs.prepare-version.outputs.desktop_release_tag }}
- releaseName: 'Unsloth Studio (Desktop) ${{ needs.prepare-version.outputs.studio_version }}'
- releaseBody: |
- Desktop app for Unsloth Studio.
-
- **macOS**: Download the Apple Silicon `.dmg`.
- **Windows**: Download the `-setup.exe` installer.
- **Linux**: Download `.deb` (Ubuntu/Debian) or `.AppImage` (universal).
-
- > Linux in-app updates are AppImage-oriented. Package installs should update by downloading a new package.
- > Linux AppImage on Ubuntu 24.04+ may require: `sudo apt install libfuse2t64`
- > First-run system dependency elevation is supported on Ubuntu/Debian. Other Linux distributions should install system packages manually.
- releaseDraft: ${{ inputs.draft }}
- prerelease: ${{ needs.prepare-version.outputs.prerelease }}
- args: -v ${{ matrix.args }}
-
- # Release process note: only non-draft workflow runs advance the public
- # desktop-latest updater channel. Draft builds are for private review; if a
- # draft is manually published later, this channel intentionally remains
- # unchanged until a narrow manual channel-publish flow is added or a public
- # desktop release is created by running this workflow with draft=false.
- publish-updater-channel:
- name: Publish desktop updater channel
- needs: [prepare-version, build]
- if: ${{ !inputs.draft }}
- runs-on: ubuntu-latest
- permissions:
- contents: write
- env:
- GH_REPO: ${{ github.repository }}
- APP_VERSION: ${{ needs.prepare-version.outputs.app_version }}
- STUDIO_VERSION: ${{ needs.prepare-version.outputs.studio_version }}
- DESKTOP_RELEASE_TAG: ${{ needs.prepare-version.outputs.desktop_release_tag }}
- DESKTOP_PRERELEASE: ${{ needs.prepare-version.outputs.prerelease }}
-
- steps:
- - name: Download versioned updater metadata
- shell: bash
- env:
- GH_TOKEN: ${{ github.token }}
- run: |
- set -euo pipefail
- mkdir -p "$RUNNER_TEMP/desktop-updater"
- gh api "repos/${GITHUB_REPOSITORY}/releases/tags/${DESKTOP_RELEASE_TAG}" > "$RUNNER_TEMP/source-release.json"
- python3 <<'PY'
- import json
- import os
- import pathlib
- import sys
-
- source = json.loads(pathlib.Path(os.environ['RUNNER_TEMP'], 'source-release.json').read_text())
- expected_tag = os.environ['DESKTOP_RELEASE_TAG']
- if source.get('tag_name') != expected_tag:
- sys.exit(f'Expected source release {expected_tag}, got {source.get("tag_name")}')
- if source.get('draft'):
- sys.exit(f'Source desktop release {expected_tag} is draft; refusing to publish public updater channel')
- PY
- gh release download "$DESKTOP_RELEASE_TAG" --pattern latest.json --dir "$RUNNER_TEMP/desktop-updater" --clobber
- test -s "$RUNNER_TEMP/desktop-updater/latest.json"
-
- - name: Validate versioned updater metadata
- shell: bash
- run: |
- python3 <<'PY'
- import json
- import os
- import pathlib
- import re
- import sys
-
- app_version = os.environ['APP_VERSION']
- release_tag = os.environ['DESKTOP_RELEASE_TAG']
- latest_path = pathlib.Path(os.environ['RUNNER_TEMP'], 'desktop-updater', 'latest.json')
- data = json.loads(latest_path.read_text())
- if not isinstance(data, dict):
- sys.exit('latest.json must be a JSON object')
-
- version = data.get('version')
- if not isinstance(version, str) or not version:
- sys.exit('latest.json missing version')
- if not re.fullmatch(r'v?\d+\.\d+\.\d+(?:[-+][0-9A-Za-z.-]+)?', version):
- sys.exit(f'latest.json version is not SemVer-like: {version}')
- if version.removeprefix('v') != app_version:
- sys.exit(f'latest.json version {version} does not match desktop app version {app_version}')
-
- platforms = data.get('platforms')
- if not isinstance(platforms, dict) or not platforms:
- sys.exit('latest.json missing platforms')
-
- required_families = {
- 'darwin-aarch64': False,
- 'linux-x86_64': False,
- 'windows-x86_64': False,
- }
- expected_prefix = f'https://github.com/unslothai/unsloth/releases/download/{release_tag}/'
- forbidden_fragments = ('/releases/latest/', '/releases/download/desktop-latest/')
-
- for platform, entry in platforms.items():
- if not isinstance(entry, dict):
- sys.exit(f'Platform {platform} must be an object')
- url = entry.get('url')
- signature = entry.get('signature')
- if not isinstance(url, str) or not url.strip():
- sys.exit(f'Platform {platform} missing url')
- if not isinstance(signature, str) or not signature.strip():
- sys.exit(f'Platform {platform} missing signature')
- if any(fragment in url for fragment in forbidden_fragments):
- sys.exit(f'Platform {platform} points at a moving updater channel: {url}')
- if not url.startswith(expected_prefix):
- sys.exit(f'Platform {platform} URL must point at {release_tag}: {url}')
- for family in required_families:
- if platform == family or platform.startswith(family + '-'):
- required_families[family] = True
-
- missing = [family for family, found in required_families.items() if not found]
- if missing:
- sys.exit('latest.json missing required platform families: ' + ', '.join(missing))
- PY
-
- - name: Ensure desktop updater channel release
- shell: bash
- env:
- GH_TOKEN: ${{ github.token }}
- run: |
- set -euo pipefail
- channel_json="$RUNNER_TEMP/desktop-latest-release.json"
- if ! gh api "repos/${GITHUB_REPOSITORY}/releases/tags/desktop-latest" > "$channel_json" 2>/dev/null; then
- gh release create desktop-latest \
- --title "Unsloth Studio Desktop updater channel" \
- --notes "Machine-managed desktop updater channel; latest.json is replaced by release-desktop.yml." \
- --prerelease \
- --latest=false \
- --target "$GITHUB_SHA"
- gh api "repos/${GITHUB_REPOSITORY}/releases/tags/desktop-latest" > "$channel_json"
- fi
-
- python3 <<'PY'
- import json
- import os
- import pathlib
- import sys
-
- channel = json.loads(pathlib.Path(os.environ['RUNNER_TEMP'], 'desktop-latest-release.json').read_text())
- if channel.get('draft'):
- sys.exit('desktop-latest release is draft; refusing to publish updater channel')
- if channel.get('immutable'):
- sys.exit('desktop-latest release is immutable; cannot replace latest.json')
- if not channel.get('prerelease'):
- sys.exit('desktop-latest release must be a prerelease so it cannot compete with repo-wide latest')
- PY
-
-      - name: Prevent updater channel downgrade
-        shell: bash
-        env:
-          GH_TOKEN: ${{ github.token }}
-        run: |
-          set -euo pipefail
-          mkdir -p "$RUNNER_TEMP/desktop-current"
-          # Decide "first publish" from the release's asset list instead of from a
-          # failed download: a transient network/auth error must fail the job, not
-          # silently skip the downgrade guard below.
-          existing_assets="$(gh api "repos/${GITHUB_REPOSITORY}/releases/tags/desktop-latest" \
-            --jq '[.assets[] | select(.name == "latest.json")] | length')"
-          if [ "$existing_assets" -eq 0 ]; then
-            echo "No existing desktop-latest latest.json found; allowing first channel publish."
-            exit 0
-          fi
-          gh release download desktop-latest --pattern latest.json --dir "$RUNNER_TEMP/desktop-current" --clobber
-          python3 <<'PY'
-          import json
-          import os
-          import pathlib
-          import re
-          import sys
-
-          # Split a version string into (major, minor, patch, prerelease-or-None).
-          # A leading 'v' is dropped and build metadata ('+...') is ignored.
-          def parse(value: str):
-              value = value.removeprefix('v')
-              match = re.fullmatch(
-                  r'(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)'
-                  r'(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?'
-                  r'(?:\+[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*)?',
-                  value,
-              )
-              if not match:
-                  sys.exit(f'desktop-latest latest.json has invalid version: {value}')
-              major, minor, patch, prerelease = match.groups()
-              return (int(major), int(minor), int(patch), prerelease)
-
-          # 'rc10' -> ('rc', 10); None when the identifier is not letters followed by digits.
-          def numeric_tail(identifier: str) -> tuple[str, int] | None:
-              match = re.fullmatch(r'([A-Za-z-]+)(\d+)', identifier)
-              if not match:
-                  return None
-              return (match.group(1).lower(), int(match.group(2)))
-
-          # Compare two prerelease identifiers; returns -1, 0 or 1.
-          # Numeric identifiers compare as integers and sort below non-numeric ones.
-          # Identifiers sharing a letter prefix compare by their numeric tail
-          # (so rc2 < rc10); everything else falls back to plain string order.
-          def compare_identifier(left: str, right: str) -> int:
-              left_num = left.isdigit()
-              right_num = right.isdigit()
-              if left_num and right_num:
-                  return (int(left) > int(right)) - (int(left) < int(right))
-              if left_num:
-                  return -1
-              if right_num:
-                  return 1
-
-              left_tail = numeric_tail(left)
-              right_tail = numeric_tail(right)
-              if left_tail and right_tail and left_tail[0] == right_tail[0]:
-                  return (left_tail[1] > right_tail[1]) - (left_tail[1] < right_tail[1])
-
-              return (left > right) - (left < right)
-
-          # Compare prerelease strings; a missing prerelease (None) ranks above any
-          # prerelease, and with equal leading identifiers the longer list ranks higher.
-          def compare_prerelease(left: str | None, right: str | None) -> int:
-              if left == right:
-                  return 0
-              if left is None:
-                  return 1
-              if right is None:
-                  return -1
-              left_parts = left.split('.')
-              right_parts = right.split('.')
-              for left_part, right_part in zip(left_parts, right_parts):
-                  order = compare_identifier(left_part, right_part)
-                  if order:
-                      return order
-              return (len(left_parts) > len(right_parts)) - (len(left_parts) < len(right_parts))
-
-          # Full version comparison: core triple first, then prerelease.
-          def compare(left: str, right: str) -> int:
-              left_major, left_minor, left_patch, left_pre = parse(left)
-              right_major, right_minor, right_patch, right_pre = parse(right)
-              left_core = (left_major, left_minor, left_patch)
-              right_core = (right_major, right_minor, right_patch)
-              if left_core != right_core:
-                  return (left_core > right_core) - (left_core < right_core)
-              return compare_prerelease(left_pre, right_pre)
-
-          current_path = pathlib.Path(os.environ['RUNNER_TEMP'], 'desktop-current', 'latest.json')
-          next_path = pathlib.Path(os.environ['RUNNER_TEMP'], 'desktop-updater', 'latest.json')
-          current = json.loads(current_path.read_text()).get('version')
-          next_version = json.loads(next_path.read_text()).get('version')
-          if not isinstance(current, str) or not isinstance(next_version, str):
-              sys.exit('Could not compare desktop-latest channel versions')
-          # Equal versions are allowed (re-publish); only a strictly older version is refused.
-          if compare(next_version, current) < 0:
-              sys.exit(
-                  f'Refusing to move desktop-latest from {current} to older version {next_version}.'
-              )
-          PY
-
- - name: Publish desktop updater channel metadata
- shell: bash
- env:
- GH_TOKEN: ${{ github.token }}
- run: |
- set -euo pipefail
- gh release upload desktop-latest "$RUNNER_TEMP/desktop-updater/latest.json" --clobber
- gh api "repos/${GITHUB_REPOSITORY}/releases/tags/desktop-latest" > "$RUNNER_TEMP/desktop-latest-release.json"
- python3 <<'PY'
- import json
- import os
- import pathlib
- import sys
-
- channel = json.loads(pathlib.Path(os.environ['RUNNER_TEMP'], 'desktop-latest-release.json').read_text())
- assets = [asset for asset in channel.get('assets', []) if asset.get('name') == 'latest.json']
- if len(assets) != 1:
- sys.exit(f'Expected exactly one desktop-latest latest.json asset, found {len(assets)}')
- expected_url = f'https://github.com/{os.environ["GITHUB_REPOSITORY"]}/releases/download/desktop-latest/latest.json'
- actual_url = assets[0].get('browser_download_url')
- if actual_url != expected_url:
- sys.exit(f'desktop-latest latest.json URL mismatch: expected {expected_url}, got {actual_url}')
- PY
diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml
deleted file mode 100644
index 1a4cf841d0..0000000000
--- a/.github/workflows/stale.yml
+++ /dev/null
@@ -1,37 +0,0 @@
-name: 'Inactive Issue Pinger'
-
-on:
- schedule:
- - cron: '30 5 * * *' # Runs at 5:30 UTC every day
-
-jobs:
- stale:
- runs-on: ubuntu-latest
- permissions:
- issues: write
-
- steps:
- - uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10.2.0
- with:
- # The message to post on stale issues.
- # This message will ping the issue author.
- # Note: The stale bot action does not currently support a direct placeholder for the last commenter.
- # As a workaround, this message encourages any participant to reply.
- stale-issue-message: >
- Is this issue still important to you?
- Apologies in advance we might have missed this issue as well.
- For faster response times, please post on our Reddit server - https://www.reddit.com/r/unsloth or our Discord - https://discord.com/invite/unsloth
-
- # The number of days of inactivity before an issue is considered stale.
- days-before-issue-stale: 9999
-
- # Set to -1 to never close stale issues.
- days-before-issue-close: -1
-
- # A label to apply to stale issues.
- stale-issue-label: 'inactive'
-
- # The number of operations to perform per run to avoid rate limiting.
- operations-per-run: 500
-
- enable-statistics: false
diff --git a/.github/workflows/studio-frontend-ci.yml b/.github/workflows/studio-frontend-ci.yml
deleted file mode 100644
index 1270a57ef6..0000000000
--- a/.github/workflows/studio-frontend-ci.yml
+++ /dev/null
@@ -1,151 +0,0 @@
-# SPDX-License-Identifier: AGPL-3.0-only
-# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved.
-
-# Frontend PR gate: lockfile freshness, typecheck, build, and a bundle grep
-# that catches the 2026.5.1 chat-history regression at the JS level.
-#
-# biome runs as non-blocking for now: the codebase currently has accumulated
-# ~470 errors and ~1650 warnings against the existing biome config. Surfacing
-# the count in CI lets us drive it down without forcing a fleet-wide cleanup
-# in the same PR. Drop `continue-on-error` once that number is zero.
-
-name: Frontend CI
-
-on:
- pull_request:
- paths:
- - 'studio/frontend/**'
- - 'scripts/check_frontend_dep_removal.py'
- - 'tests/studio/test_frontend_dep_removal.py'
- - '.github/workflows/studio-frontend-ci.yml'
- push:
- branches: [main, pip]
-
-concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}
- cancel-in-progress: true
-
-permissions:
- contents: read
-
-jobs:
- build:
- name: Frontend build + bundle sanity
- runs-on: ubuntu-latest
- timeout-minutes: 10
- defaults:
- run:
- working-directory: studio/frontend
- steps:
- - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- with:
- persist-credentials: false
-
- # FIXME: drop this step once @assistant-ui/* and assistant-stream
- # leave 0.x -- on 1.x, caret ranges are conventional. Until then,
- # every 0.minor on this surface is a SemVer-major (this is exactly
- # how 2026.5.1 shipped a broken chat runtime: ^0.12.19 quietly
- # resolved to 0.12.28).
- - name: '@assistant-ui must be pinned exactly (no caret/tilde)'
- working-directory: ${{ github.workspace }}
- run: |
- set -e
- if grep -nE '"(@assistant-ui/[a-z-]+|assistant-stream)":[[:space:]]*"[\^~]' studio/frontend/package.json; then
- echo "::error file=studio/frontend/package.json::These packages must be pinned to exact versions until they leave 0.x. Drop the leading ^ or ~."
- exit 1
- fi
- echo "All assistant-ui packages are pinned exactly."
-
- - uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0
- with:
- node-version: '22'
-
- # Run the structural lockfile scan BEFORE npm ci. A compromised
- # tarball runs its `prepare` / `postinstall` during `npm ci`,
- # so any catch has to fire upstream of that. The scanner is
- # pure-Python read-only; safe to call ahead of every install.
- - name: Lockfile supply-chain audit (pre-install scan)
- working-directory: ${{ github.workspace }}
- run: python3 scripts/lockfile_supply_chain_audit.py
-
- - name: Lockfile must agree with package.json (npm ci is strict)
- # Lifecycle scripts (esbuild native-binary postinstall, etc.) are
- # required for `vite build`. The pre-install lockfile structural
- # audit (lockfile_supply_chain_audit.py) is the practical defence
- # against the npm postinstall-dropper class -- it fires BEFORE any
- # tarball runs, on the injection pattern itself rather than an
- # advisory-DB lookup.
- run: npm ci --no-fund --no-audit
-
- - name: npm ci must not have modified the working tree
- working-directory: ${{ github.workspace }}
- run: |
- if ! git diff --quiet -- studio/frontend; then
- echo "::error::npm ci modified files; commit the updated lockfile"
- git status -- studio/frontend
- exit 1
- fi
-
- # Catch the common foot-gun: a dep dropped from package.json that is
- # still imported somewhere. The script walks the lockfile dep graph
- # from the new top-level deps and only counts top-level node_modules
- # paths as valid resolution targets for bare src/ imports.
- #
- # actions/checkout uses fetch-depth: 1 by default, so the base branch
- # is not available locally. Fetch the single base commit with an
- # explicit refspec so origin/<base_ref> is reliably created (a bare
- # `git fetch origin <base_ref>` only updates FETCH_HEAD in some configs).
- - name: Dependency removal safety check
- if: github.event_name == 'pull_request'
- working-directory: ${{ github.workspace }}
- run: |
- git fetch --no-tags --depth=1 origin \
- "${{ github.base_ref }}:refs/remotes/origin/${{ github.base_ref }}"
- python3 scripts/check_frontend_dep_removal.py \
- --base "origin/${{ github.base_ref }}" \
- --enumerate-dead
- python3 tests/studio/test_frontend_dep_removal.py
-
- - name: Typecheck
- run: npm run typecheck
-
- - name: Build
- run: npm run build
-
-      - name: Built bundle must not contain Studio's unstable_Provider call site
-        run: |
-          set -e
-          JS=$(ls dist/assets/index-*.js | head -1)
-          HITS=$(grep -c 'unstable_Provider:' "$JS" || true)  # grep -c already prints 0 (exit 1) on no match; `|| echo 0` produced "0\n0"
-          echo "main bundle: $JS"
-          echo "unstable_Provider: hits=$HITS (assistant-ui internals contribute up to 3)"
-          if [ "$HITS" -gt 3 ]; then
-            echo "::error file=studio/frontend/src/features/chat/runtime-provider.tsx::Studio bundle still passes unstable_Provider through useRemoteThreadListRuntime; this is the 2026.5.1 chat-history regression. Pass adapters directly into useLocalRuntime instead."
-            exit 1
-          fi
-
- - name: Bundle size budget (75 MB)
- run: |
- SIZE=$(du -sb dist | cut -f1)
- BUDGET=$((75 * 1024 * 1024))
- echo "dist size: $SIZE bytes ($((SIZE/1024/1024)) MB), budget: $BUDGET bytes (75 MB)"
- if [ "$SIZE" -gt "$BUDGET" ]; then
- echo "::error::studio/frontend/dist/ exceeded the 75 MB budget. Drop dead deps (e.g. the unused next dep) or split chunks."
- exit 1
- fi
-
- - name: Biome (non-blocking until accumulated drift is cleared)
- continue-on-error: true
- run: npm run biome:check
-
- - name: Upload built dist
- # Always upload so a green run is reviewable too -- the dist
- # output catches "tests passed but bundle changed unexpectedly"
- # regressions that would be invisible if we only kept artifacts
- # on failure.
- if: always()
- uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
- with:
- name: studio-frontend-dist
- path: studio/frontend/dist
- retention-days: 3
diff --git a/.github/workflows/studio-inference-smoke.yml b/.github/workflows/studio-inference-smoke.yml
deleted file mode 100644
index 6def56f769..0000000000
--- a/.github/workflows/studio-inference-smoke.yml
+++ /dev/null
@@ -1,1052 +0,0 @@
-# SPDX-License-Identifier: AGPL-3.0-only
-# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved.
-
-# Three end-to-end smoke jobs that boot a freshly-installed Studio and
-# exercise the surfaces real users hit through the OpenAI / Anthropic
-# SDKs and curl. Each job picks the smallest model that exercises the
-# behaviour under test, primes HF_HOME via actions/cache, and shares
-# the install.sh --local --no-torch bootstrap.
-#
-# 1. OpenAI, Anthropic API tests
-# gemma-3-270m-it UD-Q4_K_XL (~254 MiB).
-# Password rotation via /api/auth/change-password (old fails,
-# new works), then OpenAI + Anthropic Python SDKs against /v1/*
-# with temperature=0 and a fixed seed. Asserts the four-turn
-# conversation is deterministic across two runs.
-#
-# 2. Tool calling Tests
-# Qwen3.5-2B UD-IQ3_XXS (~890 MiB). OpenAI function calling,
-# server-side tools (python, terminal, web_search) via
-# enable_tools / enabled_tools, and enable_thinking on/off.
-#
-# 3. JSON, images
-# gemma-4-E2B-it UD-IQ3_XXS (~2.4 GiB) + mmproj-F16 (~986 MiB).
-# response_format JSON-schema decoding and OpenAI image_url
-# (data URI) plus Anthropic source/base64 image inputs.
-#
-# All three jobs run in parallel. Total wall time is dominated by job 3
-# on a cold cache; warm cache cuts that to ~3 min.
-
-name: Studio GGUF CI
-
-on:
- pull_request:
- paths:
- - 'studio/**'
- - 'unsloth/**'
- - 'unsloth_cli/**'
- - 'install.sh'
- - 'pyproject.toml'
- - '.github/workflows/studio-inference-smoke.yml'
- push:
- branches: [main, pip]
- # Manual trigger for pre-warming HF_HOME caches on main, or re-running
- # against an arbitrary branch without pushing a no-op commit.
- workflow_dispatch:
-
-concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}
- cancel-in-progress: true
-
-permissions:
- contents: read
-
-jobs:
- # ─────────────────────────────────────────────────────────────────────
- # Job 1: OpenAI, Anthropic API tests
- # ─────────────────────────────────────────────────────────────────────
- openai-anthropic:
- name: OpenAI, Anthropic API tests
- runs-on: ubuntu-latest
- timeout-minutes: 25
- env:
- GGUF_REPO: unsloth/gemma-3-270m-it-GGUF
- GGUF_VARIANT: UD-Q4_K_XL
- GGUF_FILE: gemma-3-270m-it-UD-Q4_K_XL.gguf
- STUDIO_PORT: '18888'
- HF_HOME: ${{ github.workspace }}/hf-cache
- steps:
- - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- with:
- persist-credentials: false
-
- - name: Linux deps for llama.cpp prebuilt
- run: |
- sudo apt-get update
- sudo apt-get install -y --no-install-recommends \
- libcurl4-openssl-dev libssl-dev jq
-
- - uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0
- with:
- node-version: '22'
-
- - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
- with:
- python-version: '3.12'
- cache: 'pip'
-
- - name: Restore HF_HOME for ${{ env.GGUF_REPO }}
- id: cache-hf
- uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
- continue-on-error: true
- with:
- path: hf-cache
- key: ${{ runner.os }}-hf-${{ env.GGUF_REPO }}-${{ env.GGUF_VARIANT }}-v1
-
- - name: Prime HF_HOME with the GGUF
- id: prime-hf
- if: steps.cache-hf.outputs.cache-hit != 'true' || steps.cache-hf.outcome != 'success'
- env:
- HF_TOKEN: ${{ secrets.HF_TOKEN }}
- run: |
- python -m pip install --upgrade huggingface_hub
- mkdir -p hf-cache
- bash .github/scripts/hf-download-with-retry.sh "$GGUF_REPO" "$GGUF_FILE"
-
- - name: Save HF_HOME for ${{ env.GGUF_REPO }}
- if: always() && steps.prime-hf.outcome == 'success'
- uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
- with:
- path: hf-cache
- key: ${{ runner.os }}-hf-${{ env.GGUF_REPO }}-${{ env.GGUF_VARIANT }}-v1
-
- - name: Install Studio (--local, --no-torch)
- env:
- GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- run: |
- mkdir -p logs
- set -o pipefail
- bash install.sh --local --no-torch 2>&1 | tee logs/install.log
-
- - name: Install OpenAI + Anthropic Python SDKs
- run: pip install 'openai>=1.50' 'anthropic>=0.40'
-
- - name: Reset auth + boot Studio (API-only)
- run: |
- unsloth studio reset-password
- mkdir -p logs
- UNSLOTH_API_ONLY=1 unsloth studio -H 127.0.0.1 -p "$STUDIO_PORT" \
- > logs/studio.log 2>&1 &
- echo "STUDIO_PID=$!" >> "$GITHUB_ENV"
-
-      - name: Wait for /api/health
-        run: |
-          for i in $(seq 1 180); do
-            if curl -fs "http://127.0.0.1:${STUDIO_PORT}/api/health" > /tmp/health.json \
-              && jq -e '.status == "healthy"' /tmp/health.json > /dev/null; then  # in the condition so a not-yet-healthy 200 keeps polling instead of tripping `bash -e`
-              exit 0
-            fi
-            sleep 1
-          done
-          echo "Studio did not become healthy in 180s"
-          tail -200 logs/studio.log
-          exit 1
-
- - name: Password rotation (old must fail, new must work)
- run: |
- OLD=$(cat ~/.unsloth/studio/auth/.bootstrap_password)
- NEW="CIRotated-$(python -c 'import secrets; print(secrets.token_urlsafe(12))')"
- echo "::add-mask::$OLD"
- echo "::add-mask::$NEW"
- # 1. Login with the bootstrap password.
- OLD_TOKEN=$(curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/auth/login" \
- -H 'content-type: application/json' \
- -d "{\"username\":\"unsloth\",\"password\":\"$OLD\"}" | jq -r .access_token)
- [ -n "$OLD_TOKEN" ] && [ "$OLD_TOKEN" != "null" ] || { echo "bootstrap login failed"; exit 1; }
- # 2. Rotate to a fresh random password.
- curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/auth/change-password" \
- -H "Authorization: Bearer $OLD_TOKEN" -H 'content-type: application/json' \
- -d "{\"current_password\":\"$OLD\",\"new_password\":\"$NEW\"}" > /dev/null
- # 3. Old password must now be rejected (HTTP 401).
- OLD_STATUS=$(curl -s -o /dev/null -w '%{http_code}' \
- -X POST "http://127.0.0.1:${STUDIO_PORT}/api/auth/login" \
- -H 'content-type: application/json' \
- -d "{\"username\":\"unsloth\",\"password\":\"$OLD\"}")
- if [ "$OLD_STATUS" != "401" ]; then
- echo "::error::Login with old password returned $OLD_STATUS, expected 401"
- exit 1
- fi
- # 4. New password must succeed; capture the JWT for downstream steps.
- NEW_TOKEN=$(curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/auth/login" \
- -H 'content-type: application/json' \
- -d "{\"username\":\"unsloth\",\"password\":\"$NEW\"}" | jq -r .access_token)
- [ -n "$NEW_TOKEN" ] && [ "$NEW_TOKEN" != "null" ] || { echo "new login failed"; exit 1; }
- echo "TOKEN=$NEW_TOKEN" >> "$GITHUB_ENV"
- echo "password rotation OK (old=401, new=200)"
-
- - name: Load the GGUF (HF repo + variant, served from HF_HOME cache)
- run: |
- curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/inference/load" \
- -H "Authorization: Bearer $TOKEN" -H 'content-type: application/json' \
- --max-time 600 \
- -d "{\"model_path\":\"$GGUF_REPO\",\"gguf_variant\":\"$GGUF_VARIANT\",\"is_lora\":false,\"max_seq_length\":2048}" \
- | jq '{status, display_name, is_gguf, context_length}'
-
- - name: Multi-turn determinism via OpenAI + Anthropic SDKs
- env:
- BASE_URL: http://127.0.0.1:18888
- run: |
- python - <<'PY'
- import json
- import os
- from openai import OpenAI
- from anthropic import Anthropic
-
- BASE = os.environ["BASE_URL"]
- KEY = os.environ["TOKEN"] # JWT also accepted as Bearer on /v1/*
- SEED = 3407
-
- # Four-turn conversation: the second and fourth turns can only be
- # answered correctly if the model sees the prior turns, so this
- # also exercises the conversation-history wiring.
- PROMPTS = [
- "What is 1+1?",
- "What did I ask before?",
- "What is the capital of France?",
- "Repeat the city name",
- ]
-
- def run_openai():
- client = OpenAI(base_url = f"{BASE}/v1", api_key = KEY)
- history, replies = [], []
- for prompt in PROMPTS:
- history.append({"role": "user", "content": prompt})
- resp = client.chat.completions.create(
- model = "default",
- messages = history,
- temperature = 0.0,
- max_tokens = 80,
- seed = SEED,
- extra_body = {"enable_thinking": False},
- )
- text = resp.choices[0].message.content or ""
- replies.append(text)
- history.append({"role": "assistant", "content": text})
- return replies
-
- def run_anthropic():
- # Two SDK quirks vs. Studio:
- # 1. base_url must NOT include /v1 -- the SDK appends
- # /v1/messages itself; otherwise the request hits
- # /v1/v1/messages and 405s.
- # 2. The SDK sends `x-api-key` by default, but Studio's
- # auth layer is HTTPBearer-only. Override via
- # default_headers so Authorization: Bearer ... is
- # sent instead.
- client = Anthropic(
- base_url = BASE,
- api_key = "unused",
- default_headers = {"Authorization": f"Bearer {KEY}"},
- )
- history, replies = [], []
- for prompt in PROMPTS:
- history.append({"role": "user", "content": prompt})
- msg = client.messages.create(
- model = "default",
- max_tokens = 80,
- messages = history,
- temperature = 0.0,
- extra_body = {"seed": SEED, "enable_thinking": False},
- )
- text = "".join(b.text for b in msg.content if getattr(b, "type", None) == "text")
- replies.append(text)
- history.append({"role": "assistant", "content": text})
- return replies
-
- for label, runner in (("openai", run_openai), ("anthropic", run_anthropic)):
- first = runner()
- second = runner()
- determinism_failures = []
- for i, (a, b) in enumerate(zip(first, second), start = 1):
- print(f"[{label} turn {i}] {a!r}")
- # Both runs must be non-empty; small-quant drift
- # across runs is WARN-only (grounding asserts below
- # are the stronger signal).
- assert a, f"{label}: empty turn {i} response in first run"
- assert b, f"{label}: empty turn {i} response in second run"
- if a.strip() != b.strip():
- determinism_failures.append(
- f"turn {i}: run1={a!r} run2={b!r}"
- )
- if determinism_failures:
- print(
- f"[{label}] WARN non-determinism at temperature=0.0 across "
- f"{len(determinism_failures)} of {len(first)} turn(s); "
- f"small-quant model drift, not a Studio regression. "
- f"Details: " + " | ".join(determinism_failures)
- )
- # Sanity: turn-2 reply should mention the earlier question, and
- # turn-4 reply should mention Paris (model echoes the city it
- # produced for turn 3). Lower-cased substring checks keep the
- # assertion robust to formatting jitter.
- joined = " ".join(first).lower()
- assert "1" in first[0], f"{label}: turn-1 answer should contain '1', got {first[0]!r}"
- assert "paris" in joined, f"{label}: expected 'paris' somewhere in the four-turn transcript: {first}"
- status_word = "PASS" if not determinism_failures else "PASS (with drift)"
- print(f"[{label}] {status_word} -- 4 turns, history grounded ('paris' present)")
- PY
-
- - name: Stop Studio
- if: always()
- run: |
- kill "${STUDIO_PID}" 2>/dev/null || true
- sleep 2
- ss -tln | grep ":${STUDIO_PORT}" || true
-
- - name: Upload logs
- # Always upload so green runs are still reviewable.
- if: always()
- uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
- with:
- name: openai-anthropic-log
- path: |
- logs/studio.log
- logs/install.log
- retention-days: 7
-
- # ─────────────────────────────────────────────────────────────────────
- # Job 2: Tool calling Tests
- # ─────────────────────────────────────────────────────────────────────
- tool-calling:
- name: Tool calling Tests
- runs-on: ubuntu-latest
- timeout-minutes: 25
- env:
- # Tool calling is the highest-volume GGUF in this workflow
- # (Qwen3.5-2B at IQ3_XXS = ~890 MiB). Caching HF_HOME would
- # store xet chunks + blobs + snapshots = ~4 GiB compressed --
- # 4-5x file-size inflation, dominated by xet chunks. Use main's
- # `--local-dir gguf-cache` pattern to cache the flat .gguf only.
- # Studio's /api/inference/load accepts either a HF repo (which
- # uses HF_HOME) or an absolute file path; passing the absolute
- # path keeps the test off HF_HOME entirely so the cache size
- # tracks the GGUF file 1:1. The OpenAI/Anth and JSON+images
- # jobs still cover the gguf_variant resolution path.
- GGUF_REPO: unsloth/Qwen3.5-2B-GGUF
- GGUF_FILE: Qwen3.5-2B-UD-IQ3_XXS.gguf
- STUDIO_PORT: '18889'
- steps:
- - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- with:
- persist-credentials: false
-
- - name: Linux deps for llama.cpp prebuilt
- run: |
- sudo apt-get update
- sudo apt-get install -y --no-install-recommends \
- libcurl4-openssl-dev libssl-dev jq
-
- - uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0
- with:
- node-version: '22'
-
- - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
- with:
- python-version: '3.12'
- cache: 'pip'
-
- - name: Restore GGUF model file
- id: cache-gguf
- uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
- continue-on-error: true
- with:
- path: gguf-cache
- key: ${{ runner.os }}-gguf-${{ env.GGUF_REPO }}-${{ env.GGUF_FILE }}-v1
-
- - name: Download GGUF if cache miss
- id: download-gguf
- if: steps.cache-gguf.outputs.cache-hit != 'true' || steps.cache-gguf.outcome != 'success'
- env:
- HF_TOKEN: ${{ secrets.HF_TOKEN }}
- run: |
- python -m pip install --upgrade huggingface_hub
- mkdir -p gguf-cache
- bash .github/scripts/hf-download-with-retry.sh "$GGUF_REPO" "$GGUF_FILE" gguf-cache
-
- - name: Save GGUF model file
- if: always() && steps.download-gguf.outcome == 'success'
- uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
- with:
- path: gguf-cache
- key: ${{ runner.os }}-gguf-${{ env.GGUF_REPO }}-${{ env.GGUF_FILE }}-v1
-
- - name: Install Studio (--local, --no-torch)
- env:
- GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- run: |
- mkdir -p logs
- set -o pipefail
- bash install.sh --local --no-torch 2>&1 | tee logs/install.log
-
- - name: Reset auth + boot Studio (API-only, default tool policy)
- # We deliberately use the API-only mode rather than
- # `unsloth studio run` because the latter calls
- # `set_tool_policy(...)` with a resolved bool: on loopback the
- # default resolves to True, which forces every request through
- # the server-side agentic loop and breaks the standard
- # function-calling test below. API-only mode leaves
- # tool_policy=None so each request's `enable_tools` field is
- # honoured.
- run: |
- unsloth studio reset-password
- mkdir -p logs
- UNSLOTH_API_ONLY=1 unsloth studio -H 127.0.0.1 -p "$STUDIO_PORT" \
- > logs/studio.log 2>&1 &
- echo "STUDIO_PID=$!" >> "$GITHUB_ENV"
-
- - name: Wait for /api/health, log in, change password, load model
-   # Polls /api/health (up to 180 attempts, 1 s apart), swaps the bootstrap
-   # password for a random one, exports the resulting bearer token as
-   # API_KEY, then loads the cached GGUF by absolute path.
-   run: |
-     # The implicit Actions shell is `bash -e` WITHOUT pipefail, so a failed
-     # `curl -f` feeding `jq` would otherwise be hidden behind jq's exit 0.
-     set -o pipefail
-     STUDIO_URL="http://127.0.0.1:${STUDIO_PORT}"
-     for i in $(seq 1 180); do
-       if curl -fs "$STUDIO_URL/api/health" > /tmp/health.json; then
-         jq -e '.status == "healthy"' /tmp/health.json && break
-       fi
-       sleep 1
-     done
-     # Hard-fail here when the loop ran out without a healthy response.
-     jq -e '.status == "healthy"' /tmp/health.json
-     OLD=$(cat ~/.unsloth/studio/auth/.bootstrap_password)
-     NEW="CITool-$(python -c 'import secrets; print(secrets.token_urlsafe(12))')"
-     echo "::add-mask::$OLD"
-     echo "::add-mask::$NEW"
-     # `jq -e` exits non-zero on a null/missing access_token, so a rejected
-     # login stops the step instead of exporting the literal string "null".
-     OLD_TOKEN=$(curl -fs -X POST "$STUDIO_URL/api/auth/login" \
-       -H 'content-type: application/json' \
-       -d "{\"username\":\"unsloth\",\"password\":\"$OLD\"}" | jq -er .access_token)
-     curl -fs -X POST "$STUDIO_URL/api/auth/change-password" \
-       -H "Authorization: Bearer $OLD_TOKEN" -H 'content-type: application/json' \
-       -d "{\"current_password\":\"$OLD\",\"new_password\":\"$NEW\"}" > /dev/null
-     TOKEN=$(curl -fs -X POST "$STUDIO_URL/api/auth/login" \
-       -H 'content-type: application/json' \
-       -d "{\"username\":\"unsloth\",\"password\":\"$NEW\"}" | jq -er .access_token)
-     echo "API_KEY=$TOKEN" >> "$GITHUB_ENV"
-     GGUF_PATH="$GITHUB_WORKSPACE/gguf-cache/${GGUF_FILE}"
-     ls -lh "$GGUF_PATH"
-     # With pipefail set, an HTTP error from the load endpoint fails the step.
-     curl -fs -X POST "$STUDIO_URL/api/inference/load" \
-       -H "Authorization: Bearer $TOKEN" -H 'content-type: application/json' \
-       --max-time 600 \
-       -d "{\"model_path\":\"$GGUF_PATH\",\"is_lora\":false,\"max_seq_length\":2048}" \
-       | jq '{status, display_name}'
-
- - name: Tool calling, server-side tools, thinking on/off
-   env:
-     # Built from the job-level STUDIO_PORT so the test URL cannot drift
-     # from the port Studio was actually booted on.
-     BASE_URL: http://127.0.0.1:${{ env.STUDIO_PORT }}
- run: |
- python - <<'PY'
- import json
- import os
- import urllib.request
-
- BASE = os.environ["BASE_URL"]
- KEY = os.environ["API_KEY"]
- SEED = 3407
-
- def post(path, body, *, timeout = 240):
-     """Send `body` as a JSON POST to BASE + path with the bearer token.
-
-     Intended for requests that do not go through the server-side
-     agentic loop, where the reply is a single JSON object.
-
-     Returns (http_status, parsed_json_body).
-     """
-     auth_headers = {
-         "Authorization": f"Bearer {KEY}",
-         "Content-Type": "application/json",
-     }
-     request = urllib.request.Request(
-         f"{BASE}{path}",
-         data = json.dumps(body).encode(),
-         method = "POST",
-         headers = auth_headers,
-     )
-     with urllib.request.urlopen(request, timeout = timeout) as response:
-         http_status = response.status
-         raw_body = response.read()
-         return http_status, json.loads(raw_body.decode())
-
- def post_sse(path, body, *, timeout = 600):
-     """POST a streaming request and accumulate the assistant text deltas.
-
-     The server-side agentic loop ALWAYS returns SSE regardless of the
-     request's `stream` field, so any call with enable_tools=true must
-     use this helper. `stream` is forced to True on a copy of `body`.
-
-     Returns (content, raw_payloads):
-       content      -- concatenated assistant delta.content strings
-       raw_payloads -- every non-empty raw "data:" payload (JSON strings)
-                       seen before "[DONE]". Callers asserting that a
-                       server-side tool actually ran should inspect these
-                       for tool envelopes / tool output, since
-                       `delta.content` alone is not evidence that the
-                       tool path executed.
-     """
-     request = urllib.request.Request(
-         f"{BASE}{path}",
-         data = json.dumps({**body, "stream": True}).encode(),
-         method = "POST",
-         headers = {
-             "Authorization": f"Bearer {KEY}",
-             "Content-Type": "application/json",
-         },
-     )
-     parts = []
-     events = []
-     with urllib.request.urlopen(request, timeout = timeout) as resp:
-         for raw in resp:
-             line = raw.decode().strip()
-             # The space after "data:" is optional in the SSE wire format.
-             if not line.startswith("data:"):
-                 continue
-             payload = line[5:].lstrip()
-             if not payload:
-                 continue
-             if payload == "[DONE]":
-                 break
-             events.append(payload)
-             try:
-                 chunk = json.loads(payload)
-             except json.JSONDecodeError:
-                 continue
-             # Valid JSON that is not an object (list, string, number) has
-             # no choices to read; keep it in `events` but skip parsing.
-             if not isinstance(chunk, dict):
-                 continue
-             # `or []` also covers an explicit `"choices": null`.
-             for choice in chunk.get("choices") or []:
-                 if not isinstance(choice, dict):
-                     continue
-                 delta = choice.get("delta") or {}
-                 if not isinstance(delta, dict):
-                     continue
-                 text = delta.get("content")
-                 # Only non-empty strings can be joined into the transcript.
-                 if isinstance(text, str) and text:
-                     parts.append(text)
-     return "".join(parts), events
-
- # Event `type` values that are themselves a tool envelope.
- _STUDIO_TOOL_TYPES = {
-     "tool_start", "tool_end", "tool_use", "tool_result",
- }
- # `type` values that mark a tool call inside an event's `output` list.
- _OUTPUT_ITEM_TOOL_TYPES = {"tool_call", "function_call", "tool_use"}
- # `type` values that mark a tool block inside an event's `content` list.
- _CONTENT_BLOCK_TOOL_TYPES = {"tool_use", "tool_result"}
-
- def _choice_has_tool_call(choice):
-     """True iff one OpenAI-style choice dict carries tool-call evidence:
-     finish_reason 'tool_calls', or a delta/message with a non-empty
-     tool_calls list, a truthy function_call, or role 'tool'."""
-     if choice.get("finish_reason") == "tool_calls":
-         return True
-     for key in ("delta", "message"):
-         part = choice.get(key) or {}
-         if not isinstance(part, dict):
-             continue
-         calls = part.get("tool_calls")
-         if (
-             (isinstance(calls, list) and calls)
-             or part.get("function_call")
-             or part.get("role") == "tool"
-         ):
-             return True
-     return False
-
- def _event_is_tool_envelope(ev):
-     """True iff a decoded SSE event (a dict) is a tool envelope in any of
-     the shapes checked: top-level type, choices, output items, or
-     content blocks."""
-     if ev.get("type") in _STUDIO_TOOL_TYPES:
-         return True
-     if any(
-         isinstance(choice, dict) and _choice_has_tool_call(choice)
-         for choice in ev.get("choices", []) or []
-     ):
-         return True
-     if any(
-         isinstance(item, dict) and item.get("type") in _OUTPUT_ITEM_TOOL_TYPES
-         for item in ev.get("output", []) or []
-     ):
-         return True
-     blocks = ev.get("content")
-     return isinstance(blocks, list) and any(
-         isinstance(blk, dict) and blk.get("type") in _CONTENT_BLOCK_TOOL_TYPES
-         for blk in blocks
-     )
-
- def _tool_invoked(events):
-     """Structural check over raw SSE payload strings: True iff at least
-     one payload decodes to a real tool envelope (Studio
-     tool_start/tool_end, Anthropic tool_use/tool_result, OpenAI
-     non-empty delta.tool_calls / message.tool_calls /
-     finish_reason='tool_calls' / role:'tool' / function_call).
-     tool_status is NOT evidence: Studio emits empty tool_status events
-     on iteration boundaries even when no tool ran. Payloads that are
-     not JSON objects are ignored.
-     """
-     for payload in events:
-         try:
-             ev = json.loads(payload)
-         except (json.JSONDecodeError, TypeError):
-             continue
-         if isinstance(ev, dict) and _event_is_tool_envelope(ev):
-             return True
-     return False
-
- def _tool_output_contains(events, *needles):
-     """True iff any tool_end.result / tool_result.content / tool-role
-     message content contains one of `needles`.
-
-     Inspects the tool's own output, not the model's narration. Empty
-     needles are ignored. Payloads that are not JSON objects, and
-     choices / delta / message values that are not dicts, are skipped
-     rather than raising, matching the guards in _tool_invoked.
-     """
-     wanted = [n for n in needles if n]
-
-     def _hit(value):
-         # Only string values can contain a needle.
-         return isinstance(value, str) and any(n in value for n in wanted)
-
-     for raw in events:
-         try:
-             ev = json.loads(raw)
-         except (json.JSONDecodeError, TypeError):
-             continue
-         if not isinstance(ev, dict):
-             continue
-         kind = ev.get("type")
-         if kind == "tool_end" and _hit(ev.get("result")):
-             return True
-         if kind == "tool_result":
-             content = ev.get("content")
-             if _hit(content):
-                 return True
-             if isinstance(content, list):
-                 for blk in content:
-                     if isinstance(blk, dict) and _hit(
-                         blk.get("text") or blk.get("content")
-                     ):
-                         return True
-         for choice in ev.get("choices") or []:
-             if not isinstance(choice, dict):
-                 continue
-             for key in ("delta", "message"):
-                 src = choice.get(key) or {}
-                 if not isinstance(src, dict):
-                     continue
-                 if src.get("role") == "tool" and _hit(src.get("content") or ""):
-                     return True
-     return False
-
- # ── 1. Standard OpenAI function calling ──────────────────────
- weather_tool = {
- "type": "function",
- "function": {
- "name": "get_weather",
- "description": "Get current weather for a city.",
- "parameters": {
- "type": "object",
- "properties": {"city": {"type": "string"}},
- "required": ["city"],
- },
- },
- }
-
- status, data = post("/v1/chat/completions", {
- "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
- "tools": [weather_tool],
- "tool_choice": "required",
- "stream": False,
- "temperature": 0.0,
- "seed": SEED,
- "max_tokens": 120,
- })
- assert status == 200, f"tool call status {status}: {data}"
- choice = data["choices"][0]
- assert choice["finish_reason"] == "tool_calls", f"finish_reason={choice['finish_reason']!r}"
- tc = choice["message"]["tool_calls"][0]
- assert tc["function"]["name"] == "get_weather"
- args = json.loads(tc["function"]["arguments"])
- assert args.get("city"), f"missing city arg: {args}"
- print(f"[tools] PASS function calling -> {tc['function']['name']}({args})")
-
- # Temperature 0 is a deterministic argmax in llama.cpp, so a retry would
- # replay the same trajectory; a non-zero temperature lets the rotated
- # seed explore a different one on each attempt.
- TOOL_PROBE_TEMP = 0.4
-
- def _run_tool_probe(*, label, prompt, enabled, session, needles,
-                     max_attempts = 4):
-     """Drive one server-side tool, retrying with seeds SEED, SEED+1, ...
-
-     Returns (content, events, attempts_log) from the first attempt that
-     shows both structural tool invocation and an expected needle in the
-     tool output. If some attempt was invoked but none produced a
-     needle, prints a WARN and returns the first invoked attempt --
-     small-quant Qwen3.5-2B can emit OpenAI tool_calls deltas without
-     Studio's GGUF agentic loop intercepting them, and that
-     GGUF-vs-OpenAI format mismatch is out of scope for #5642.
-
-     Raises AssertionError when no attempt shows structural invocation
-     evidence at all.
-     """
-     attempts_log = []
-     fallback = None
-     for attempt_i, attempt_seed in enumerate(range(SEED, SEED + max_attempts)):
-         request = {
-             "messages": [{"role": "user", "content": prompt}],
-             "enable_tools": True,
-             "enabled_tools": enabled,
-             "session_id": f"{session}-att{attempt_i}",
-             "temperature": TOOL_PROBE_TEMP,
-             "seed": attempt_seed,
-             "max_tokens": 600,
-         }
-         content, events = post_sse("/v1/chat/completions", request)
-         invoked = _tool_invoked(events)
-         produced = _tool_output_contains(events, *needles)
-         attempts_log.append(dict(
-             attempt = attempt_i,
-             seed = attempt_seed,
-             n_events = len(events),
-             tool_invoked = invoked,
-             tool_output_contains = produced,
-             content_len = len(content),
-         ))
-         if invoked:
-             if produced:
-                 print(f"[tools] PASS {label} attempt {attempt_i}")
-                 return content, events, attempts_log
-             # Remember only the first invoked-but-unmatched attempt.
-             if fallback is None:
-                 fallback = (content, events)
-         print(f"[tools] retry {label} attempt {attempt_i}: invoked={invoked} output_ok={produced} events={len(events)}")
-     if fallback is None:
-         raise AssertionError(
-             f"{label}: no structural tool-invocation evidence across "
-             f"{max_attempts} attempts. enable_tools may be silently "
-             f"ignored. Attempts: {attempts_log}"
-         )
-     print(f"[tools] WARN {label}: invoked but no tool_end.result match (small-quant flake). Attempts: {attempts_log}")
-     content, events = fallback
-     return content, events, attempts_log
-
- # ── 2. Server-side python tool ───────────────────────────────
- content, events, _attempts = _run_tool_probe(
- label = "python tool",
- prompt = "What is 123 * 456? Use the python tool to compute it and tell me the number.",
- enabled = ["python"],
- session = "ci-tool-calling-py",
- needles = ("56088", "56,088"),
- )
- if "56088" in content or "56,088" in content:
- print(f"[tools] python tool narration OK")
- else:
- print(f"[tools] python tool narration drifted -- content={content!r}")
-
- # ── 3. Server-side bash (terminal) tool ──────────────────────
- content, events, _attempts = _run_tool_probe(
- label = "bash/terminal tool",
- prompt = "Use the terminal tool to run `echo hello-bash-tool` and tell me the exact output.",
- enabled = ["terminal"],
- session = "ci-tool-calling-bash",
- needles = ("hello-bash-tool",),
- )
- if "hello-bash-tool" in content:
- print(f"[tools] bash/terminal narration OK")
- else:
- print(f"[tools] bash/terminal narration dropped literal -- content={content!r}")
-
- # ── 4. Server-side web_search tool ───────────────────────────
- # DuckDuckGo is flaky from CI runners and small Qwen3.5-2B
- # may not actually search. Only assert that the SSE stream
- # opens and yields any data; HTTP / parser failures already
- # raise above. Tool-invocation strictness is relaxed here
- # because (a) the search may legitimately return no results,
- # and (b) DuckDuckGo upstream blocks GHA IP ranges often
- # enough that requiring a tool_call marker would create
- # red-herring failures from infra rather than from Studio.
- try:
- content, events = post_sse("/v1/chat/completions", {
- "messages": [{"role": "user", "content": "Search the web for 'unsloth ai github' and summarise."}],
- "enable_tools": True,
- "enabled_tools": ["web_search"],
- "session_id": "ci-tool-calling-web",
- "temperature": 0.0,
- "seed": SEED,
- "max_tokens": 400,
- })
- print(
- f"[tools] PASS web_search stream ({len(content)} chars in content, "
- f"{len(events)} raw events)"
- )
- except Exception as exc:
- print(f"[tools] WARN web_search probe failed (non-blocking): {exc}")
-
- # ── 5. Thinking on / off ─────────────────────────────────────
- # Studio strips think blocks from message.content for tools-mode
- # responses, so we toggle plain chat (no enable_tools) and look
- # at the surfaced reasoning_content / message.thinking field.
- def thinking_call(enable):
-     """Ask one short plain-chat question with enable_thinking=`enable`.
-
-     Returns message.content followed by message.reasoning_content, each
-     treated as "" when missing or empty. Asserts an HTTP 200 reply.
-     """
-     request = {
-         "messages": [{"role": "user", "content": "Briefly: is 17 prime?"}],
-         "stream": False,
-         "enable_thinking": enable,
-         "temperature": 0.0,
-         "seed": SEED,
-         "max_tokens": 300,
-     }
-     status, data = post("/v1/chat/completions", request)
-     assert status == 200
-     message = data["choices"][0]["message"]
-     # Studio surfaces thinking through reasoning_content (an OpenAI
-     # extension). The visible content is included as well so that any
-     # inline thinking markers left by a chat template are still seen.
-     visible = message.get("content") or ""
-     reasoning = message.get("reasoning_content") or ""
-     return visible + reasoning
-
- # Literal opening marker of an inline thinking block. It must be a
- # non-empty string: `"" in s` is true for every s, which would make the
- # on-check vacuous and the off-check impossible to satisfy.
- THINK_MARKER = "<think>"
-
- on_text = thinking_call(True)
- off_text = thinking_call(False)
- # Thinking-on evidence: either the inline marker, or enough combined
- # content + reasoning_content text (> 80 chars) to show reasoning ran.
- had_think_on = (THINK_MARKER in on_text) or len(on_text) > 80
- assert had_think_on, (
-     f"enable_thinking=True produced no thinking signal: {on_text!r}"
- )
- # Off-mode should not contain the literal marker.
- assert THINK_MARKER not in off_text, (
-     f"enable_thinking=False but <think> still present: {off_text!r}"
- )
- print(f"[tools] PASS thinking on/off (on={len(on_text)} chars, off={len(off_text)} chars)")
- PY
-
- - name: Stop Studio
- if: always()
- run: |
- kill "${STUDIO_PID}" 2>/dev/null || true
- sleep 2
- ss -tln | grep ":${STUDIO_PORT}" || true
-
- - name: Upload logs
- # Always upload so green runs are still reviewable.
- if: always()
- uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
- with:
- name: tool-calling-log
- path: |
- logs/studio.log
- logs/install.log
- retention-days: 7
-
- # ─────────────────────────────────────────────────────────────────────
- # Job 3: JSON, images
- # ─────────────────────────────────────────────────────────────────────
- json-images:
- name: JSON, images
- runs-on: ubuntu-latest
- timeout-minutes: 30
- env:
- GGUF_REPO: unsloth/gemma-4-E2B-it-GGUF
- GGUF_VARIANT: UD-IQ3_XXS
- GGUF_FILE: gemma-4-E2B-it-UD-IQ3_XXS.gguf
- MMPROJ_FILE: mmproj-F16.gguf
- STUDIO_PORT: '18890'
- HF_HOME: ${{ github.workspace }}/hf-cache
- steps:
- - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- with:
- persist-credentials: false
-
- - name: Linux deps for llama.cpp prebuilt
- run: |
- sudo apt-get update
- sudo apt-get install -y --no-install-recommends \
- libcurl4-openssl-dev libssl-dev jq
-
- - uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0
- with:
- node-version: '22'
-
- - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
- with:
- python-version: '3.12'
- cache: 'pip'
-
- - name: Restore HF_HOME for ${{ env.GGUF_REPO }} (model + mmproj)
- id: cache-hf
- uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
- continue-on-error: true
- with:
- path: hf-cache
- key: ${{ runner.os }}-hf-${{ env.GGUF_REPO }}-${{ env.GGUF_VARIANT }}-${{ env.MMPROJ_FILE }}-v1
-
- - name: Prime HF_HOME with the GGUF + mmproj
- id: prime-hf
- if: steps.cache-hf.outputs.cache-hit != 'true' || steps.cache-hf.outcome != 'success'
- env:
- HF_TOKEN: ${{ secrets.HF_TOKEN }}
- run: |
- python -m pip install --upgrade huggingface_hub
- mkdir -p hf-cache
- bash .github/scripts/hf-download-with-retry.sh "$GGUF_REPO" "$GGUF_FILE"
- bash .github/scripts/hf-download-with-retry.sh "$GGUF_REPO" "$MMPROJ_FILE"
-
- - name: Save HF_HOME for ${{ env.GGUF_REPO }} (model + mmproj)
- if: always() && steps.prime-hf.outcome == 'success'
- uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
- with:
- path: hf-cache
- key: ${{ runner.os }}-hf-${{ env.GGUF_REPO }}-${{ env.GGUF_VARIANT }}-${{ env.MMPROJ_FILE }}-v1
-
- - name: Install Studio (--local, --no-torch)
- env:
- GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- run: |
- mkdir -p logs
- set -o pipefail
- bash install.sh --local --no-torch 2>&1 | tee logs/install.log
-
- - name: Install OpenAI + Anthropic Python SDKs
- run: pip install 'openai>=1.50' 'anthropic>=0.40'
-
- - name: Reset auth + boot Studio (API-only)
- # See Job 2's comment: API-only mode keeps tool_policy=None so
- # response_format requests aren't routed through the agentic
- # tool loop.
- run: |
- unsloth studio reset-password
- mkdir -p logs
- UNSLOTH_API_ONLY=1 unsloth studio -H 127.0.0.1 -p "$STUDIO_PORT" \
- > logs/studio.log 2>&1 &
- echo "STUDIO_PID=$!" >> "$GITHUB_ENV"
-
- - name: Wait for /api/health, log in, change password, load model
-   # Polls /api/health (up to 180 attempts, 1 s apart), swaps the bootstrap
-   # password for a random one, exports the resulting bearer token as
-   # API_KEY, then loads the GGUF by HF repo + variant.
-   run: |
-     # The implicit Actions shell is `bash -e` WITHOUT pipefail, so a failed
-     # `curl -f` feeding `jq` would otherwise be hidden behind jq's exit 0.
-     set -o pipefail
-     STUDIO_URL="http://127.0.0.1:${STUDIO_PORT}"
-     for i in $(seq 1 180); do
-       if curl -fs "$STUDIO_URL/api/health" > /tmp/health.json; then
-         jq -e '.status == "healthy"' /tmp/health.json && break
-       fi
-       sleep 1
-     done
-     # Hard-fail here when the loop ran out without a healthy response.
-     jq -e '.status == "healthy"' /tmp/health.json
-     OLD=$(cat ~/.unsloth/studio/auth/.bootstrap_password)
-     NEW="CIJson-$(python -c 'import secrets; print(secrets.token_urlsafe(12))')"
-     echo "::add-mask::$OLD"
-     echo "::add-mask::$NEW"
-     # `jq -e` exits non-zero on a null/missing access_token, so a rejected
-     # login stops the step instead of exporting the literal string "null".
-     OLD_TOKEN=$(curl -fs -X POST "$STUDIO_URL/api/auth/login" \
-       -H 'content-type: application/json' \
-       -d "{\"username\":\"unsloth\",\"password\":\"$OLD\"}" | jq -er .access_token)
-     curl -fs -X POST "$STUDIO_URL/api/auth/change-password" \
-       -H "Authorization: Bearer $OLD_TOKEN" -H 'content-type: application/json' \
-       -d "{\"current_password\":\"$OLD\",\"new_password\":\"$NEW\"}" > /dev/null
-     TOKEN=$(curl -fs -X POST "$STUDIO_URL/api/auth/login" \
-       -H 'content-type: application/json' \
-       -d "{\"username\":\"unsloth\",\"password\":\"$NEW\"}" | jq -er .access_token)
-     echo "API_KEY=$TOKEN" >> "$GITHUB_ENV"
-     # Load the GGUF (mmproj is auto-detected via the HF repo
-     # lookup, the cached file is pulled out of HF_HOME).
-     # With pipefail set, an HTTP error from the load endpoint fails the step.
-     curl -fs -X POST "$STUDIO_URL/api/inference/load" \
-       -H "Authorization: Bearer $TOKEN" -H 'content-type: application/json' \
-       --max-time 900 \
-       -d "{\"model_path\":\"$GGUF_REPO\",\"gguf_variant\":\"$GGUF_VARIANT\",\"is_lora\":false,\"max_seq_length\":2048}" \
-       | jq '{status, display_name, is_vision}'
-
- - name: JSON schema decoding + image input
-   env:
-     # Built from the job-level STUDIO_PORT so the test URL cannot drift
-     # from the port Studio was actually booted on.
-     BASE_URL: http://127.0.0.1:${{ env.STUDIO_PORT }}
- run: |
- python - <<'PY'
- import base64
- import json
- import os
- import urllib.request
- from openai import OpenAI
- from anthropic import Anthropic
-
- BASE = os.environ["BASE_URL"]
- KEY = os.environ["API_KEY"]
- SEED = 3407
-
- def post(path, body, *, timeout = 240):
-     """Send `body` as a JSON POST to BASE + path with the bearer token
-     and return (http_status, parsed_json_body)."""
-     auth_headers = {
-         "Authorization": f"Bearer {KEY}",
-         "Content-Type": "application/json",
-     }
-     encoded = json.dumps(body).encode()
-     request = urllib.request.Request(
-         f"{BASE}{path}",
-         data = encoded,
-         method = "POST",
-         headers = auth_headers,
-     )
-     with urllib.request.urlopen(request, timeout = timeout) as response:
-         http_status = response.status
-         raw_body = response.read()
-         return http_status, json.loads(raw_body.decode())
-
- # ── 1. response_format = json_object (JSON mode) ─────────────
- # llama.cpp's HTTP server supports OpenAI-compatible JSON
- # mode: `response_format: {"type": "json_object"}` constrains
- # the model to emit syntactically-valid JSON. We use raw HTTP
- # rather than the OpenAI SDK so that the field shape Studio
- # forwards to llama-server is unambiguous (the SDK rewrites
- # response_format depending on which variant it recognises).
- # We deliberately do NOT pass a strict JSON schema -- on
- # small Gemma-4 quants the GBNF-from-schema path occasionally
- # produces empty output, and JSON mode is the surface we care
- # about exposing through Studio.
- status, data = post("/v1/chat/completions", {
- "model": "default",
- "messages": [
- {"role": "system", "content": 'Reply with a single JSON object of the form {"city": "...", "country": "..."}. Output ONLY the JSON, nothing else.'},
- {"role": "user", "content": "What is the capital of France?"},
- ],
- "temperature": 0.0,
- "max_tokens": 200,
- "seed": SEED,
- "stream": False,
- "enable_thinking": False,
- "response_format": {"type": "json_object"},
- }, timeout = 600)
- assert status == 200, f"json status {status}: {data}"
- content = (data["choices"][0]["message"].get("content") or "").strip()
- # Some chat templates wrap JSON in ```json fences even in JSON
- # mode -- strip those before parsing.
- if content.startswith("```"):
- content = content.split("```", 2)[1]
- if content.startswith("json"):
- content = content[4:]
- content = content.strip("`\n ")
- parsed = json.loads(content)
- assert "paris" in str(parsed.get("city", "")).lower(), (
- f"city != Paris: {parsed}"
- )
- print(f"[json] PASS json_object -> {parsed}")
-
- # ── 2. OpenAI image_url (data URI base64) ───────────────────
- # 64x64 solid-red PNG. stb_image (used by Studio's image
- # normaliser at routes/inference.py:3410) rejects 4x4 or
- # smaller PNGs as truncated, so we go up to 64x64 -- still
- # tiny in token cost. The assertion is loose: any non-empty
- # response from the vision path proves multimodal end-to-end
- # wiring; small VL quants are weak at colour identification.
- PNG_64X64_RED_B64 = (
- "iVBORw0KGgoAAAANSUhEUgAAAEAAAABACAIAAAAlC+aJAAAAYklEQVR4nO3PMQ0AIADAMEAI/k"
- "UhBhEcDcmqYJtn7/GzpQNeNaA1oDWgNaA1oDWgNaA1oDWgNaA1oDWgNaA1oDWgNaA1oDWgNaA"
- "1oDWgNaA1oDWgNaA1oDWgNaA1oDWgNaA1oDWgNaBdCJ0BmMJ25zMAAAAASUVORK5CYII="
- )
- data_uri = f"data:image/png;base64,{PNG_64X64_RED_B64}"
-
- client = OpenAI(base_url = f"{BASE}/v1", api_key = KEY)
- openai_resp = client.chat.completions.create(
- model = "default",
- temperature = 0.0,
- max_tokens = 80,
- seed = SEED,
- messages = [{
- "role": "user",
- "content": [
- {"type": "image_url", "image_url": {"url": data_uri}},
- {"type": "text", "text": "What colour dominates this image? Reply in one word."},
- ],
- }],
- )
- openai_text = (openai_resp.choices[0].message.content or "").lower()
- print(f"[image/openai] reply: {openai_text!r}")
- assert openai_text, "OpenAI image_url returned empty content"
- # We do not strictly require 'red' -- some quants of small VL
- # models are weak at colour names. Just require a non-empty
- # answer; the vision path is the part under test.
- print("[image/openai] PASS image_url accepted, non-empty response")
-
- # ── 3. Anthropic source/base64 image ────────────────────────
- # Two SDK quirks vs. Studio: base_url must NOT include /v1
- # (the SDK appends it itself; otherwise /v1/v1/messages -> 405),
- # and Studio's auth is HTTPBearer-only so the SDK's default
- # x-api-key header is ignored -- send Authorization: Bearer
- # via default_headers.
- anthropic = Anthropic(
- base_url = BASE,
- api_key = "unused",
- default_headers = {"Authorization": f"Bearer {KEY}"},
- )
- a_msg = anthropic.messages.create(
- model = "default",
- max_tokens = 80,
- temperature = 0.0,
- extra_body = {"seed": SEED},
- messages = [{
- "role": "user",
- "content": [
- {
- "type": "image",
- "source": {
- "type": "base64",
- "media_type": "image/png",
- "data": PNG_64X64_RED_B64,
- },
- },
- {"type": "text", "text": "Describe this image briefly."},
- ],
- }],
- )
- a_text = "".join(b.text for b in a_msg.content if getattr(b, "type", None) == "text")
- print(f"[image/anthropic] reply: {a_text!r}")
- assert a_text, "Anthropic source/base64 returned empty content"
- print("[image/anthropic] PASS source/base64 accepted, non-empty response")
- PY
-
- - name: Stop Studio
- if: always()
- run: |
- kill "${STUDIO_PID}" 2>/dev/null || true
- sleep 2
- ss -tln | grep ":${STUDIO_PORT}" || true
-
- - name: Upload logs
- # Always upload so green runs are still reviewable.
- if: always()
- uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
- with:
- name: json-images-log
- path: |
- logs/studio.log
- logs/install.log
- retention-days: 7
diff --git a/.github/workflows/studio-tauri-smoke.yml b/.github/workflows/studio-tauri-smoke.yml
deleted file mode 100644
index 1156c264ae..0000000000
--- a/.github/workflows/studio-tauri-smoke.yml
+++ /dev/null
@@ -1,128 +0,0 @@
-# SPDX-License-Identifier: AGPL-3.0-only
-# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved.
-
-# PR-time smoke for the Tauri desktop wrapper. Builds the frontend and the
-# Tauri Linux debug binary, with no codesigning. Catches:
-# - tauri.conf.json drift
-# - src-tauri Cargo.toml or rust source breakage
-# - Tauri CLI version drift (we pin 2.10.1, matching release-desktop.yml)
-# - frontend output not picked up by Tauri's distDir
-#
-# Linux-only on a free `ubuntu-latest` runner. Mac and Windows desktop builds
-# stay in release-desktop.yml (manual `workflow_dispatch`) because they need
-# code-signing secrets and ~30 min of runner time each.
-
-name: Studio Tauri CI
-
-on:
- pull_request:
- paths:
- - 'studio/frontend/**'
- - 'studio/src-tauri/**'
- # CLI rename / signature change can break Tauri's spawned
- # `unsloth studio` -- include unsloth_cli in the trigger set.
- - 'unsloth_cli/**'
- - '.github/workflows/studio-tauri-smoke.yml'
- push:
- branches: [main, pip]
-
-concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}
- cancel-in-progress: true
-
-permissions:
- contents: read
-
-jobs:
- linux-debug-build:
- name: Tauri Linux debug build (no codesign)
- runs-on: ubuntu-22.04
- timeout-minutes: 25
- steps:
- - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- with:
- persist-credentials: false
-
- - name: Linux native deps for Tauri / WebKit2GTK
- run: |
- sudo apt-get update
- sudo apt-get install -y \
- libwebkit2gtk-4.1-dev libayatana-appindicator3-dev \
- librsvg2-dev libxdo-dev libssl-dev patchelf
-
- - uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0
- with:
- node-version: '24'
-
- - uses: dtolnay/rust-toolchain@29eef336d9b2848a0b548edc03f92a220660cdb8 # stable @ 2026-03-27
-
- - uses: swatinem/rust-cache@e18b497796c12c097a38f9edb9d0641fb99eee32 # v2.9.1
- with:
- workspaces: studio/src-tauri -> target
-
- - name: Install pinned Tauri CLI (matches release-desktop.yml)
- # Lifecycle scripts (esbuild native-binary postinstall, etc.) are
- # required for `vite build`. The pre-install lockfile structural
- # audit (lockfile_supply_chain_audit.py) is the practical defence
- # against the npm postinstall-dropper class -- it fires BEFORE any
- # tarball runs, on the injection pattern itself rather than an
- # advisory-DB lookup.
- run: npm install --save-dev --prefix studio @tauri-apps/cli@2.10.1 --no-fund --no-audit
-
- - name: Verify pinned Tauri CLI version
- run: |
- out="$(npx --prefix studio tauri --version)"
- echo "$out"
- [ "$out" = "tauri-cli 2.10.1" ] || { echo "::error::expected tauri-cli 2.10.1, got $out"; exit 1; }
-
- - name: Lockfile supply-chain audit (pre-install scan)
- run: python3 scripts/lockfile_supply_chain_audit.py
-
- - name: Frontend build (npm ci, vite)
- working-directory: studio/frontend
- # Lifecycle scripts (esbuild native-binary postinstall, etc.) are
- # required for `vite build`. The pre-install lockfile structural
- # audit (lockfile_supply_chain_audit.py) is the practical defence
- # against the npm postinstall-dropper class -- it fires BEFORE any
- # tarball runs, on the injection pattern itself rather than an
- # advisory-DB lookup.
- run: |
- npm ci --no-fund --no-audit
- npm run build
- test -f dist/index.html
-
- - name: Tauri debug build (Linux, no bundle, no codesign)
- # `--debug` + `--no-bundle` keeps this lean: compiles the Rust crate,
- # confirms the frontend dist is wired into Tauri, but skips the AppImage
- # / .deb production. Code signing is irrelevant because we never produce
- # a distributable artifact.
- env:
- TAURI_SIGNING_PRIVATE_KEY: ''
- TAURI_SIGNING_PRIVATE_KEY_PASSWORD: ''
- run: npx --prefix studio tauri build --debug --no-bundle
-
- - name: Inspect produced binary
- run: |
- BIN=$(find studio/src-tauri/target/debug -maxdepth 1 -type f -executable 2>/dev/null \
- | grep -Ev '\.(d|so|dylib|dll)$' \
- | grep -Ev '/(deps|build|examples)$' \
- | head -1)
- echo "binary: $BIN"
- if [ -z "$BIN" ]; then
- echo "::error::Tauri debug binary not produced"
- ls -la studio/src-tauri/target/debug/ || true
- exit 1
- fi
- file "$BIN"
- du -h "$BIN"
-
- - name: Upload Tauri debug build
- # Always upload so a green run leaves the binary inspectable too.
- if: always()
- uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
- with:
- name: tauri-debug-build
- path: |
- studio/src-tauri/target/debug
- studio/frontend/dist
- retention-days: 3
diff --git a/.github/workflows/wheel-smoke.yml b/.github/workflows/wheel-smoke.yml
deleted file mode 100644
index 3de3c33ca2..0000000000
--- a/.github/workflows/wheel-smoke.yml
+++ /dev/null
@@ -1,136 +0,0 @@
-# SPDX-License-Identifier: AGPL-3.0-only
-# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved.
-
-# Builds the PyPI wheel from the PR branch, then verifies the built wheel
-# actually contains what we expect to ship and does NOT contain the broken
-# Studio bundle that 2026.5.1 published. This is the single workflow that
-# would have blocked the 2026.5.1 release before twine upload.
-#
-# Verified locally end-to-end against this branch:
-# - python -m build produces unsloth--py3-none-any.whl in 13s
-# - wheel content sanity passes:
-# lockfile shipped, frontend dist shipped,
-# no node_modules in wheel, no bun.lock in wheel,
-# main bundle has unstable_Provider hits=1 (assistant-ui internals only).
-# - Studio backend imports cleanly from the installed wheel with the
-# lightweight dep set below.
-
-name: Wheel CI
-
-on:
- pull_request:
- paths:
- - 'pyproject.toml'
- - 'studio/**'
- - 'unsloth/**'
- - 'unsloth_cli/**'
- - '.github/workflows/wheel-smoke.yml'
- push:
- branches: [main, pip]
-
-concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}
- cancel-in-progress: true
-
-permissions:
- contents: read
-
-jobs:
- wheel:
- name: Wheel build + content sanity + import smoke
- runs-on: ubuntu-latest
- timeout-minutes: 15
- steps:
- - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- with:
- persist-credentials: false
-
- - uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0
- with:
- node-version: '22'
-
- - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
- with:
- python-version: '3.12'
-
- - name: Lockfile supply-chain audit (pre-install scan)
- run: python3 scripts/lockfile_supply_chain_audit.py
-
- - name: Build frontend
- # Lifecycle scripts (esbuild native-binary postinstall, etc.) are
- # required for `vite build`. The pre-install lockfile structural
- # audit (lockfile_supply_chain_audit.py) is the practical defence
- # against the npm postinstall-dropper class -- it fires BEFORE any
- # tarball runs, on the injection pattern itself rather than an
- # advisory-DB lookup.
- run: |
- cd studio/frontend
- npm ci --no-fund --no-audit
- npm run build
-
- - name: Build wheel + sdist
- run: |
- python -m pip install --upgrade pip build
- rm -rf dist build ./*.egg-info
- python -m build
-
- - name: Wheel content sanity
- run: |
- python - <<'PY'
- import zipfile, glob, sys
- w = glob.glob("dist/unsloth-*.whl")
- if not w:
- print("FAIL: no wheel produced"); sys.exit(2)
- w = w[0]
- print(f"wheel: {w}")
- with zipfile.ZipFile(w) as z:
- n = z.namelist()
- checks = {
- "lockfile shipped": any(s.endswith("studio/frontend/package-lock.json") for s in n),
- "frontend dist shipped": any(s.endswith("studio/frontend/dist/index.html") for s in n),
- "no node_modules": not any("studio/frontend/node_modules/" in s for s in n),
- "no bun.lock": not any(s.endswith("studio/frontend/bun.lock") for s in n),
- }
- js = [s for s in n
- if "studio/frontend/dist/assets/" in s
- and s.endswith(".js")
- and "/index-" in s]
- if not js:
- print("FAIL: no main bundle index-*.js in wheel"); sys.exit(2)
- data = z.read(js[0]).decode("utf-8", "replace")
- hits = data.count("unstable_Provider:")
- print(f"main bundle: {js[0]}")
- print(f"unstable_Provider hits: {hits} (>=4 indicates 2026.5.1 regression)")
- checks["bundle has no Studio unstable_Provider call site"] = (hits < 4)
-
- print()
- for k, v in checks.items():
- print(f" [{'PASS' if v else 'FAIL'}] {k}")
- sys.exit(0 if all(checks.values()) else 1)
- PY
-
- - name: Studio backend import smoke
- # Imports `studio.backend.main:app` from the freshly-installed wheel in
- # a clean venv. This catches the class of bug that 2026.5.1 shipped with:
- # frontend dist missing, package-lock.json missing, or the wheel's Python
- # source tree broken in a way that surfaces only at app construction time.
- run: |
- python -m venv /tmp/v
- /tmp/v/bin/pip install --upgrade pip
- /tmp/v/bin/pip install -r studio/backend/requirements/studio.txt
- /tmp/v/bin/pip install \
- python-multipart aiofiles sqlalchemy cryptography \
- pyyaml jinja2 mammoth unpdf requests \
- 'numpy<3'
- /tmp/v/bin/pip install --no-deps dist/unsloth-*.whl
- # Run from /tmp so Python imports the installed package, not the source tree.
- cd /tmp
- /tmp/v/bin/python -c "from studio.backend.main import app; print('Studio backend OK:', app.title)"
-
- - name: Upload wheel on failure
- if: failure()
- uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
- with:
- name: unsloth-wheel
- path: dist/
- retention-days: 7
diff --git a/.gitignore b/.gitignore
index a839633790..da33583e29 100644
--- a/.gitignore
+++ b/.gitignore
@@ -235,3 +235,9 @@ package-lock.json
!studio/backend/core/data_recipe/oxc-validator/package-lock.json
!studio/package-lock.json
llama.cpp/
+/.omc
+/studio/frontend/.omc
+/.codex
+/studio/.omc
+/studio/backend/.omc
+*.patch
diff --git a/studio/backend/assets/configs/model_defaults/other/deepseek-ai_DeepSeek-OCR.yaml b/studio/backend/assets/configs/model_defaults/other/deepseek-ai_DeepSeek-OCR.yaml
new file mode 100644
index 0000000000..b827a1f910
--- /dev/null
+++ b/studio/backend/assets/configs/model_defaults/other/deepseek-ai_DeepSeek-OCR.yaml
@@ -0,0 +1,22 @@
+# Model defaults for deepseek-ai/DeepSeek-OCR
+# Custom-code OCR vision model. Used by Studio chat as a temporary OCR
+# model swap during scanned-PDF extraction; never used for training.
+
+model:
+ identifier: deepseek-ai/DeepSeek-OCR
+ display_name: DeepSeek-OCR
+ is_vision: true
+ is_ocr: true
+
+training:
+ trust_remote_code: true
+ max_seq_length: 8192
+ packing: false
+
+inference:
+ trust_remote_code: true
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ min_p: 0.0
+ default_max_seq_length: 8192
diff --git a/studio/backend/assets/configs/model_defaults/other/unsloth_PaddleOCR-VL.yaml b/studio/backend/assets/configs/model_defaults/other/unsloth_PaddleOCR-VL.yaml
index b7587bbd91..bffb79902c 100644
--- a/studio/backend/assets/configs/model_defaults/other/unsloth_PaddleOCR-VL.yaml
+++ b/studio/backend/assets/configs/model_defaults/other/unsloth_PaddleOCR-VL.yaml
@@ -3,6 +3,12 @@
# Also applies to: unsloth/PaddleOCR-VL
# added inference parameters from unsloth notebook
+model:
+ identifier: unsloth/PaddleOCR-VL
+ display_name: PaddleOCR-VL
+ is_vision: true
+ is_ocr: true
+
training:
trust_remote_code: true
max_seq_length: 2048
@@ -50,6 +56,11 @@ logging:
inference:
trust_remote_code: true
- temperature: 1.5
- min_p: 0.1
+ # OCR is a closed-form transcription task; sibling OCR presets
+ # (DeepSeek-OCR, GLM-OCR) use deterministic decoding so the
+ # transcription is reproducible. Match that convention here.
+ temperature: 0.0
+ min_p: 0.0
+ top_p: 1.0
+ top_k: -1
diff --git a/studio/backend/assets/configs/model_defaults/other/zai-org_GLM-OCR.yaml b/studio/backend/assets/configs/model_defaults/other/zai-org_GLM-OCR.yaml
new file mode 100644
index 0000000000..2249aa4487
--- /dev/null
+++ b/studio/backend/assets/configs/model_defaults/other/zai-org_GLM-OCR.yaml
@@ -0,0 +1,22 @@
+# Model defaults for zai-org/GLM-OCR
+# GLM family OCR vision model with model_type "glm_ocr". Used by Studio chat
+# as a temporary OCR model swap during scanned-PDF extraction.
+
+model:
+ identifier: zai-org/GLM-OCR
+ display_name: GLM-OCR
+ is_vision: true
+ is_ocr: true
+
+training:
+ trust_remote_code: true
+ max_seq_length: 8192
+ packing: false
+
+inference:
+ trust_remote_code: true
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ min_p: 0.0
+ default_max_seq_length: 8192
diff --git a/studio/backend/core/chat/__init__.py b/studio/backend/core/chat/__init__.py
new file mode 100644
index 0000000000..8ce71de2e8
--- /dev/null
+++ b/studio/backend/core/chat/__init__.py
@@ -0,0 +1,65 @@
+# SPDX-License-Identifier: AGPL-3.0-only
+# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
+
+"""
+Chat-surface helpers that belong neither in ``core/inference`` (tightly
+coupled to model backends) nor in ``core/data_recipe`` (which owns
+dataset pipelines).
+
+Exposes the document-extraction pipeline used when a user drops a
+PDF / DOCX / HTML / MD / TXT file into the chat composer. PDF parsing
+uses PyMuPDF4LLM, DOCX uses mammoth. PPTX is not supported here —
+convert to PDF first.
+"""
+
+from __future__ import annotations
+
+from .document_extractor import (
+ DOCUMENT_EXTRACTION_AVAILABLE,
+ DEFAULT_DOCUMENT_VISUAL_PAYLOADS,
+ DocumentExtractionBusy,
+ DocumentExtractionCancelled,
+ DocumentExtractionEncrypted,
+ DocumentExtractionTimeout,
+ DocumentExtractionUnavailable,
+ ExtractedFigure,
+ ExtractResult,
+ _EXTRACT_CONCURRENCY,
+ MAX_DOCUMENT_VISUAL_PAYLOADS,
+ SUPPORTED_MIME_TYPES,
+ SUPPORTED_SUFFIXES,
+ _EXTRACT_SEMAPHORE,
+ _drain_future_exception,
+ document_parser_support,
+ document_parser_unavailable_reasons,
+ extract_document,
+)
+from .vlm_capability import (
+ VlmCapability,
+ detect_loaded_vlm,
+ extract_self_base_url,
+)
+
+__all__ = [
+ "DOCUMENT_EXTRACTION_AVAILABLE",
+ "DEFAULT_DOCUMENT_VISUAL_PAYLOADS",
+ "DocumentExtractionBusy",
+ "DocumentExtractionCancelled",
+ "DocumentExtractionEncrypted",
+ "DocumentExtractionTimeout",
+ "DocumentExtractionUnavailable",
+ "ExtractedFigure",
+ "ExtractResult",
+ "_EXTRACT_CONCURRENCY",
+ "MAX_DOCUMENT_VISUAL_PAYLOADS",
+ "SUPPORTED_MIME_TYPES",
+ "SUPPORTED_SUFFIXES",
+ "VlmCapability",
+ "_EXTRACT_SEMAPHORE",
+ "_drain_future_exception",
+ "detect_loaded_vlm",
+ "document_parser_support",
+ "document_parser_unavailable_reasons",
+ "extract_document",
+ "extract_self_base_url",
+]
diff --git a/studio/backend/core/chat/document_extractor.py b/studio/backend/core/chat/document_extractor.py
new file mode 100644
index 0000000000..915fc596c2
--- /dev/null
+++ b/studio/backend/core/chat/document_extractor.py
@@ -0,0 +1,1243 @@
+# SPDX-License-Identifier: AGPL-3.0-only
+# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
+
+"""
+Document extractor for the Chat composer.
+
+Given raw file bytes (PDF / DOCX / HTML / MD / TXT), produce Markdown
+suitable to splice into an outgoing chat message. When a vision-capable
+model is loaded, selected figures are captioned through our OpenAI-compatible
+``/v1/chat/completions`` surface after conversion.
+
+This build uses **PyMuPDF4LLM** (via ``pymupdf4llm`` / ``pymupdf``) for PDF
+parsing and **mammoth** for DOCX conversion. Plain-text and Markdown inputs
+are decoded as UTF-8 with replacement; HTML inputs are converted to Markdown.
+
+Notes and limitations:
+
+* **OCR is disabled.** There is no local OCR pass in this build, so scanned
+ PDFs without a text layer will yield empty or near-empty Markdown. The
+ ``use_vlm_ocr`` flag is still accepted for API compatibility; when set it
+ renders bounded page images so a loaded vision model can describe them.
+* **PPTX is not supported** in this build. ``SUPPORTED_SUFFIXES`` and
+ ``SUPPORTED_MIME_TYPES`` no longer advertise the PowerPoint types.
+* Parser dependencies are checked per format so plain-text, Markdown, and HTML
+ still work when optional PDF or DOCX libraries are missing.
+* If the loaded model is not vision-capable, image description is silently
+ skipped and ``figures`` comes back with captions set to ``None``;
+ ``describe_skipped_reason`` carries the diagnostic text.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import base64
+import inspect
+import io
+import logging
+import math
+import multiprocessing
+import os
+import queue
+import threading
+import time
+from dataclasses import dataclass, field, replace
+from typing import Any, Awaitable, Callable, Literal, List, Optional
+
+from .vlm_capability import VlmCapability, detect_loaded_vlm
+
+
+logger = logging.getLogger(__name__)
+
+
+SUPPORTED_MIME_TYPES = frozenset(
+ {
+ "application/pdf",
+ "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
+ "application/json",
+ "application/x-ndjson",
+ "application/xml",
+ "application/yaml",
+ "application/javascript",
+ "text/html",
+ "text/markdown",
+ "text/plain",
+ "text/csv",
+ "text/css",
+ "text/javascript",
+ "text/xml",
+ "text/yaml",
+ }
+)
+
+SUPPORTED_SUFFIXES = frozenset(
+ {
+ ".pdf",
+ ".docx",
+ ".html",
+ ".htm",
+ ".md",
+ ".txt",
+ ".csv",
+ ".json",
+ ".jsonl",
+ ".yaml",
+ ".yml",
+ ".py",
+ ".js",
+ ".jsx",
+ ".ts",
+ ".tsx",
+ ".go",
+ ".rs",
+ ".java",
+ ".c",
+ ".cpp",
+ ".h",
+ ".hpp",
+ ".cs",
+ ".php",
+ ".rb",
+ ".swift",
+ ".kt",
+ ".kts",
+ ".scala",
+ ".sh",
+ ".bash",
+ ".zsh",
+ ".ps1",
+ ".sql",
+ ".toml",
+ ".ini",
+ ".cfg",
+ ".log",
+ ".xml",
+ ".css",
+ ".scss",
+ }
+)
+
+
+_DESCRIBE_PROMPT = (
+ "Describe this figure in <=60 words. Focus on factual content "
+ "(axes, labels, captions, visible text, main objects). Do not "
+ "speculate beyond what is visible."
+)
+
+
+DEFAULT_DOCUMENT_VISUAL_PAYLOADS = 3
+MAX_DOCUMENT_VISUAL_PAYLOADS = 10
+_MAX_ENCODED_VISUALS = DEFAULT_DOCUMENT_VISUAL_PAYLOADS
+_EXTRACT_TIMEOUT_SECONDS = 120
+_VLM_CAPTION_TOTAL_TIMEOUT_SECONDS = 180
+_LOCAL_VLM_CAPTION_CONCURRENCY = 1
+_DEFAULT_VLM_CAPTION_CONCURRENCY = 3
+_EXTRACT_CONCURRENCY = max(
+ 1, int(os.environ.get("UNSLOTH_STUDIO_EXTRACT_CONCURRENCY", "2"))
+)
+_EXTRACT_SEMAPHORE = threading.BoundedSemaphore(_EXTRACT_CONCURRENCY)
+# Bounded queue wait: callers park here for a slot instead of failing fast
+# with 503 when the worker pool is saturated. Tuned so a fast burst (e.g.
+# multi-select 4 PDFs) drains naturally without surfacing busy errors,
+# while truly stuck workers still time out via _EXTRACT_TIMEOUT_SECONDS.
+_EXTRACT_QUEUE_WAIT_SECONDS = max(
+ 0.0,
+ float(os.environ.get("UNSLOTH_STUDIO_EXTRACT_QUEUE_WAIT", "60")),
+)
+_PAGE_RENDER_DPI = 150
+_MAX_PAGE_RENDER_PIXELS = 4_000_000
+_MIME_TO_SUFFIX = {
+ "application/pdf": ".pdf",
+ "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
+ "application/json": ".json",
+ "application/x-ndjson": ".jsonl",
+ "application/xml": ".xml",
+ "application/yaml": ".yaml",
+ "application/javascript": ".js",
+ "text/html": ".html",
+ "text/markdown": ".md",
+ "text/plain": ".txt",
+ "text/csv": ".csv",
+ "text/css": ".css",
+ "text/javascript": ".js",
+ "text/xml": ".xml",
+ "text/yaml": ".yaml",
+}
+
+_PLAIN_TEXT_SUFFIXES = SUPPORTED_SUFFIXES - {".pdf", ".docx", ".html", ".htm"}
+
+
+def _normalized_suffix(filename: str, content_type: str = "") -> str:
+    """Return the lowercased extension, deferring to the MIME type if unknown."""
+    ext = os.path.splitext(filename)[1].lower()
+    mime = (content_type or "").split(";", 1)[0].strip().lower()
+    # An extension we do not support yields to the MIME mapping, when it has one.
+    return ext if ext in SUPPORTED_SUFFIXES else _MIME_TO_SUFFIX.get(mime, ext)
+
+
+class DocumentExtractionUnavailable(RuntimeError):
+    """A format-specific parser dependency is not installed or failed to import.
+
+    Raised for PDF when pymupdf/pymupdf4llm is missing and for DOCX when mammoth is.
+    """
+
+
+class DocumentExtractionTimeout(RuntimeError):
+    """Raised when document parsing exceeds its worker time limit."""
+
+
+class DocumentExtractionBusy(RuntimeError):
+    """Raised when no extraction slot frees up within the bounded queue wait."""
+
+
+class DocumentExtractionCancelled(RuntimeError):
+    """Raised when the caller cancels a queued or in-flight extraction."""
+
+
+class DocumentExtractionEncrypted(RuntimeError):
+    """Raised when a PDF reports ``needs_pass``, i.e. a password is required."""
+
+
+try: # pragma: no cover - presence depends on optional install
+ import pymupdf # type: ignore
+ import pymupdf4llm # type: ignore
+except Exception as _pdf_extract_exc: # pragma: no cover
+ pymupdf = None # type: ignore[assignment]
+ pymupdf4llm = None # type: ignore[assignment]
+ _PDF_EXTRACTION_IMPORT_ERROR: Optional[BaseException] = _pdf_extract_exc
+else:
+ _PDF_EXTRACTION_IMPORT_ERROR = None
+
+try: # pragma: no cover - presence depends on optional install
+ import mammoth # type: ignore
+except Exception as _docx_extract_exc: # pragma: no cover
+ mammoth = None # type: ignore[assignment]
+ _DOCX_EXTRACTION_IMPORT_ERROR: Optional[BaseException] = _docx_extract_exc
+else:
+ _DOCX_EXTRACTION_IMPORT_ERROR = None
+
+# The dispatcher can still extract plain text / code / data files when PDF or
+# DOCX optional parsers are missing. Format-specific helpers raise
+# DocumentExtractionUnavailable only when that format is actually requested.
+DOCUMENT_EXTRACTION_AVAILABLE = True
+_DOCUMENT_EXTRACTION_IMPORT_ERROR: Optional[BaseException] = (
+ _PDF_EXTRACTION_IMPORT_ERROR or _DOCX_EXTRACTION_IMPORT_ERROR
+)
+
+
+def document_parser_support() -> dict[str, bool]:
+    """Report which format families can be parsed in this process.
+
+    Only ``pdf`` and ``docx`` depend on optional imports; the rest are always on.
+    """
+    support = dict.fromkeys(("pdf", "docx", "html", "text", "data", "code"), True)
+    support["pdf"] = _PDF_EXTRACTION_IMPORT_ERROR is None
+    support["docx"] = _DOCX_EXTRACTION_IMPORT_ERROR is None
+    return support
+
+
+def document_parser_unavailable_reasons() -> dict[str, str]:
+    """Explain, per format family, why its optional parser is unavailable."""
+    pdf_missing = _PDF_EXTRACTION_IMPORT_ERROR is not None
+    docx_missing = _DOCX_EXTRACTION_IMPORT_ERROR is not None
+    pdf = {"pdf": "PDF extraction requires pymupdf and pymupdf4llm."}
+    docx = {"docx": "DOCX extraction requires mammoth."}
+    return {**(pdf if pdf_missing else {}), **(docx if docx_missing else {})}
+
+
+@dataclass
+class ExtractedFigure:
+    id: str  # "page-<n>" for page renders, "fig-<i>" for embedded images
+    page: Optional[int]  # 1-based page number
+    caption: Optional[str]  # None at extraction time
+    error: Optional[str] = None
+    kind: Literal["figure", "page"] = "figure"  # "page" = whole-page render
+    image_mime: Optional[str] = None  # image fields stay None when not encoded
+    image_base64: Optional[str] = None
+    image_width: Optional[int] = None
+    image_height: Optional[int] = None
+
+
+@dataclass
+class ExtractResult:
+    markdown: str
+    figures: List[ExtractedFigure] = field(default_factory = list)
+    page_count: int = 0
+    tokens_est: int = 0  # presumably _estimate_tokens(markdown); confirm at call site
+    describe_skipped_reason: Optional[str] = None  # why captioning was skipped
+    vlm_source: Optional[str] = None
+    vlm_model: Optional[str] = None
+    image_input_available: bool = False
+    warnings: List[str] = field(default_factory = list)
+
+
+ProgressCb = Callable[[dict], Awaitable[None]]
+
+
+def _ensure_pdf_backend() -> None:
+    """Raise ``DocumentExtractionUnavailable`` unless both PDF parsers imported."""
+    if pymupdf is not None and pymupdf4llm is not None:
+        return
+    import_error = _PDF_EXTRACTION_IMPORT_ERROR
+    if import_error is not None:
+        logger.debug("PDF extraction parser import failed: %s", import_error)
+    raise DocumentExtractionUnavailable(
+        "PDF extraction requires pymupdf and pymupdf4llm. Re-run Studio "
+        "setup to install the parser dependencies from "
+        "studio/backend/requirements/single-env/data-designer-deps.txt"
+    )
+
+
+def _ensure_docx_backend() -> None:
+    """Raise ``DocumentExtractionUnavailable`` unless mammoth imported."""
+    if mammoth is not None:
+        return
+    import_error = _DOCX_EXTRACTION_IMPORT_ERROR
+    if import_error is not None:
+        logger.debug("DOCX extraction parser import failed: %s", import_error)
+    raise DocumentExtractionUnavailable(
+        "DOCX extraction requires mammoth. Re-run Studio setup to install "
+        "the parser dependencies from "
+        "studio/backend/requirements/single-env/data-designer-deps.txt"
+    )
+
+
+def _estimate_tokens(text: str) -> int:
+    return len(text) // 4  # crude heuristic: one token per four characters
+
+
+def _encode_pil_image_for_chat(
+    image: Any,
+) -> tuple[Optional[str], Optional[int], Optional[int], Optional[str]]:
+    """Return ``(base64_jpeg, width, height, mime)``; all ``None`` on failure."""
+    failure = (None, None, None, None)
+    if image is None:
+        return failure
+    try:
+        from PIL import Image as PILImage
+
+        thumb = image.copy()
+        thumb.thumbnail((1600, 1600))  # in place; fits within 1600x1600
+        if thumb.mode in ("RGBA", "LA"):  # flatten transparency onto white
+            canvas = PILImage.new("RGB", thumb.size, (255, 255, 255))
+            canvas.paste(thumb.convert("RGB"), mask = thumb.getchannel("A"))
+            thumb = canvas
+        elif thumb.mode != "RGB":
+            thumb = thumb.convert("RGB")
+        buffer = io.BytesIO()
+        thumb.save(buffer, format = "JPEG", quality = 88, optimize = True)
+        payload = base64.b64encode(buffer.getvalue()).decode("ascii")
+        return payload, thumb.width, thumb.height, "image/jpeg"
+    except (ImportError, AttributeError, ValueError, OSError) as exc:
+        logger.warning("Failed to encode extracted document image", exc_info = exc)
+        return failure
+
+
+async def _describe_image_via_vlm(
+    *,
+    image_base64: str,
+    image_mime: str,
+    endpoint_url: str,
+    model_name: str,
+    authorization_header: Optional[str],
+    timeout_seconds: float,
+) -> tuple[Optional[str], Optional[str]]:
+    """POST one image to ``/v1/chat/completions``; return ``(caption, error)``.
+
+    Exactly one element is ``None``; ordinary failures come back as ``error``.
+    """
+    try:
+        import httpx
+    except Exception as exc:
+        return None, f"httpx unavailable: {exc}"
+
+    headers = {"Content-Type": "application/json"}
+    if authorization_header:
+        headers["Authorization"] = authorization_header
+
+    data_url = f"data:{image_mime};base64,{image_base64}"
+    payload = {
+        "model": model_name,
+        "stream": False,
+        "max_tokens": 512,
+        "temperature": 0.2,
+        "top_p": 0.9,
+        "messages": [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": _DESCRIBE_PROMPT},
+                    {"type": "image_url", "image_url": {"url": data_url}},
+                ],
+            }
+        ],
+    }
+    try:
+        async with httpx.AsyncClient(timeout = timeout_seconds) as client:
+            response = await client.post(
+                endpoint_url.rstrip("/") + "/v1/chat/completions",
+                headers = headers,
+                json = payload,
+            )
+        if response.status_code >= 400:
+            return None, (
+                f"VLM caption request failed with HTTP " f"{response.status_code}"
+            )
+        body = response.json()
+        choice = (body.get("choices") or [{}])[0]
+        message = choice.get("message") or {}
+        finish_reason = choice.get("finish_reason")
+
+        # Some chat templates (Gemma 3/3n via llama-server, Qwen3 always-think)
+        # leave ``content`` empty and put the reply in ``reasoning_content``. Mirror
+        # the streaming fallback in ``llama_cpp._chat_completion`` for such replies.
+        candidates: list[Any] = [
+            message.get("content"),
+            message.get("reasoning_content"),
+            message.get("text"),
+        ]
+        # Content may be a list of parts (OpenAI multimodal); join their text.
+        normalized: list[str] = []
+        for raw in candidates:
+            if isinstance(raw, str):
+                if raw.strip():
+                    normalized.append(raw.strip())
+            elif isinstance(raw, list):
+                parts = [
+                    part.get("text", "")
+                    for part in raw
+                    if isinstance(part, dict) and isinstance(part.get("text"), str)
+                ]
+                joined = "".join(parts).strip()
+                if joined:
+                    normalized.append(joined)
+
+        if not normalized:
+            logger.warning(
+                "VLM caption empty: finish_reason=%r message_keys=%s",
+                finish_reason,
+                list(message.keys()),
+            )
+            return None, (f"VLM caption empty (finish_reason={finish_reason!r})")
+        # First non-empty candidate wins (content > reasoning_content > text).
+        return normalized[0], None
+    except Exception as exc:
+        logger.debug("VLM caption request failed", exc_info = True)
+        return None, f"VLM caption request failed: {type(exc).__name__}"
+
+
+def _build_extract_options(
+    *,
+    extract_images: bool,
+    use_vlm_ocr: bool,
+    max_visual_payloads: int,
+) -> tuple[dict, list[str]]:
+    """Assemble the flag bag read by the synchronous extract dispatcher.
+
+    Returns ``(options, build_warnings)``. ``max_visual_payloads`` is clamped
+    to be non-negative. This build has no local OCR engine, so a truthy
+    ``use_vlm_ocr`` is recorded as a flag and also yields one warning that
+    rendered page images will go to the vision model instead.
+    """
+    ocr_notice = (
+        "Full-page OCR was requested, but this build has no local OCR "
+        "engine; rendered page images will be sent to the loaded vision "
+        "model when image description is enabled."
+    )
+    # Truthiness decides: any truthy ``use_vlm_ocr`` produces the notice.
+    notices = [ocr_notice] if use_vlm_ocr else []
+    flags = {
+        "extract_images": bool(extract_images),
+        "use_vlm_ocr": bool(use_vlm_ocr),
+        "max_visual_payloads": max(0, max_visual_payloads),
+    }
+    return flags, notices
+
+
+def _pymupdf4llm_markdown_kwargs() -> dict[str, Any]:
+    """Pick the ``to_markdown()`` kwargs that the installed pymupdf4llm accepts."""
+    wanted = {
+        "write_images": False,
+        "show_progress": False,
+        "ignore_images": True,
+        "table_strategy": "lines_strict",
+        "use_ocr": False,
+        "force_ocr": False,
+    }
+    try:
+        accepted = inspect.signature(pymupdf4llm.to_markdown).parameters
+    except (TypeError, ValueError):
+        # Signature cannot be introspected: pass everything except the OCR flags.
+        return {k: v for k, v in wanted.items() if k not in {"use_ocr", "force_ocr"}}
+    takes_var_kwargs = any(
+        p.kind == inspect.Parameter.VAR_KEYWORD for p in accepted.values()
+    )
+    if takes_var_kwargs:
+        # A ``**kwargs`` parameter accepts every name, so nothing is filtered.
+        return wanted
+    return {k: v for k, v in wanted.items() if k in accepted}
+
+
+def _safe_page_pixmap(page: Any) -> Any:
+    """Render ``page`` at ``_PAGE_RENDER_DPI``, shrinking to fit the pixel budget."""
+    box = getattr(page, "rect", None)
+    w_pt = max(float(getattr(box, "width", 0) or 0), 1.0)
+    h_pt = max(float(getattr(box, "height", 0) or 0), 1.0)
+    zoom = _PAGE_RENDER_DPI / 72.0  # zoom factor relative to 72 dpi
+    pixels = w_pt * zoom * h_pt * zoom
+    if pixels > _MAX_PAGE_RENDER_PIXELS:
+        zoom = max(zoom * math.sqrt(_MAX_PAGE_RENDER_PIXELS / pixels), 0.05)
+    matrix = pymupdf.Matrix(zoom, zoom)  # type: ignore[union-attr]
+    return page.get_pixmap(matrix = matrix, alpha = False)
+
+
+def _append_page_image_figure(
+    doc: Any,
+    figures_out: list[ExtractedFigure],
+    *,
+    page_index: int,
+    max_figures: int,
+    encode_image: bool = True,
+) -> bool:
+    """Append a ``kind="page"`` entry for ``doc[page_index]`` to ``figures_out``.
+
+    With ``encode_image`` false the entry is a placeholder without image data;
+    otherwise the page is rendered and JPEG-encoded first. Returns ``True`` if
+    an entry was appended, ``False`` when ``figures_out`` is already at
+    ``max_figures`` or rendering/encoding failed.
+    """
+    if len(figures_out) >= max_figures:
+        return False
+    page_number = page_index + 1
+    placeholder = ExtractedFigure(
+        id = f"page-{page_number}",
+        page = page_number,
+        caption = None,
+        error = None,
+        kind = "page",
+    )
+    if not encode_image:
+        figures_out.append(placeholder)
+        return True
+    try:
+        from PIL import Image as PILImage
+
+        png_bytes = _safe_page_pixmap(doc[page_index]).tobytes("png")
+        rendered = PILImage.open(io.BytesIO(png_bytes))
+        encoded, width, height, mime = _encode_pil_image_for_chat(rendered)
+        if not encoded:
+            return False
+        figures_out.append(
+            replace(
+                placeholder,
+                image_mime = mime,
+                image_base64 = encoded,
+                image_width = width,
+                image_height = height,
+            )
+        )
+        return True
+    except (
+        ImportError,
+        MemoryError,
+        OverflowError,
+        ValueError,
+        OSError,
+        RuntimeError,
+    ) as exc:
+        logger.warning(
+            "Failed to render page %d preview for PDF",
+            page_number,
+            exc_info = exc,
+        )
+        return False
+
+
+def _extract_pdf(
+ file_bytes: bytes,
+ max_figures: int,
+ use_vlm_ocr: bool,
+ max_visual_payloads: int,
+) -> tuple[str, list[ExtractedFigure], int, int, int]:
+ """Extract Markdown + figures from a PDF via PyMuPDF4LLM.
+
+ Returns ``(markdown, figures, page_count, truncated_count, seen)``.
+ """
+ _ensure_pdf_backend()
+ assert pymupdf is not None and pymupdf4llm is not None # for type-checkers
+
+ doc = pymupdf.open(stream = file_bytes, filetype = "pdf")
+ try:
+ # ``is_encrypted`` is True for any file with an /Encrypt dict
+ # (very common for Acrobat-distilled PDFs, scanner output, the
+ # classic Orimi test file). ``needs_pass`` is the real "user
+ # password required" signal. Refuse extraction only when an
+ # actual password is missing.
+ if getattr(doc, "needs_pass", False):
+ raise DocumentExtractionEncrypted(
+ "Encrypted PDF; provide a password before extracting it."
+ )
+ markdown = pymupdf4llm.to_markdown(doc, **_pymupdf4llm_markdown_kwargs())
+
+ figures_out: list[ExtractedFigure] = []
+ encoded_visuals = 0
+ seen = 0
+ truncated_count = 0
+ page_count = len(doc)
+
+ if max_figures > 0 and page_count > 0:
+ if use_vlm_ocr:
+ for page_index in range(page_count):
+ if len(figures_out) >= max_figures:
+ truncated_count += page_count - page_index
+ break
+ if _append_page_image_figure(
+ doc,
+ figures_out,
+ page_index = page_index,
+ max_figures = max_figures,
+ encode_image = encoded_visuals < max_visual_payloads,
+ ):
+ if figures_out[-1].image_base64:
+ encoded_visuals += 1
+ seen += 1
+ elif _append_page_image_figure(
+ doc,
+ figures_out,
+ page_index = 0,
+ max_figures = max_figures,
+ encode_image = encoded_visuals < max_visual_payloads,
+ ):
+ if figures_out[-1].image_base64:
+ encoded_visuals += 1
+
+ if not use_vlm_ocr:
+ try:
+ from PIL import Image as PILImage
+
+ for page_index in range(page_count):
+ page = doc[page_index]
+ try:
+ images = page.get_images(full = True)
+ except (ValueError, RuntimeError) as exc:
+ logger.debug(
+ "page.get_images failed on page %d",
+ page_index + 1,
+ exc_info = exc,
+ )
+ continue
+ for img_info in images:
+ xref = img_info[0] if img_info else 0
+ if not xref:
+ continue
+ try:
+ extracted = doc.extract_image(xref)
+ except (ValueError, RuntimeError) as exc:
+ logger.debug(
+ "doc.extract_image failed for xref %s",
+ xref,
+ exc_info = exc,
+ )
+ continue
+ if not extracted:
+ continue
+ raw_bytes = extracted.get("image")
+ if not raw_bytes:
+ continue
+ try:
+ pil_img = PILImage.open(io.BytesIO(raw_bytes))
+ pil_img.load()
+ except (OSError, ValueError) as exc:
+ logger.debug(
+ "PIL failed to decode extracted image xref %s",
+ xref,
+ exc_info = exc,
+ )
+ continue
+ if pil_img.width < 50 or pil_img.height < 50:
+ continue
+ seen += 1
+ if len(figures_out) >= max_figures:
+ truncated_count += 1
+ continue
+ image_base64 = None
+ image_width = None
+ image_height = None
+ image_mime = None
+ if encoded_visuals < max_visual_payloads:
+ (
+ image_base64,
+ image_width,
+ image_height,
+ image_mime,
+ ) = _encode_pil_image_for_chat(pil_img)
+ if image_base64:
+ encoded_visuals += 1
+ figures_out.append(
+ ExtractedFigure(
+ id = f"fig-{len(figures_out)}",
+ page = page_index + 1,
+ caption = None,
+ error = None,
+ kind = "figure",
+ image_mime = image_mime,
+ image_base64 = image_base64,
+ image_width = image_width,
+ image_height = image_height,
+ )
+ )
+ except ImportError as exc:
+ logger.warning(
+ "Pillow is unavailable; skipping embedded-image extraction",
+ exc_info = exc,
+ )
+
+ return markdown, figures_out, page_count, truncated_count, seen
+ finally:
+ try:
+ doc.close()
+ except Exception: # pragma: no cover - defensive
+ logger.debug("pymupdf doc.close() raised", exc_info = True)
+
+
+def _extract_docx(
+    file_bytes: bytes,
+) -> tuple[str, list[ExtractedFigure], int, int, int]:
+    """Convert DOCX bytes to Markdown via mammoth; figures and counts stay empty."""
+    _ensure_docx_backend()
+    assert mammoth is not None  # for type-checkers
+    converted = mammoth.convert_to_markdown(io.BytesIO(file_bytes))
+    # A falsy ``value`` is normalised to an empty string.
+    return converted.value or "", [], 0, 0, 0
+
+
+def _extract_plaintext(
+    file_bytes: bytes,
+) -> tuple[str, list[ExtractedFigure], int, int, int]:
+    """Decode bytes as UTF-8 (invalid bytes become U+FFFD); no figures or pages."""
+    return file_bytes.decode("utf-8", errors = "replace"), [], 0, 0, 0
+
+
+def _extract_html(
+    file_bytes: bytes,
+) -> tuple[str, list[ExtractedFigure], int, int, int]:
+    """Convert HTML bytes to Markdown; pass raw HTML through without a converter."""
+    markup = file_bytes.decode("utf-8", errors = "replace")
+    try:
+        from core.inference._html_to_md import html_to_markdown
+    except Exception as exc:
+        # Any import failure degrades to raw HTML instead of failing the call.
+        note = "HTML-to-Markdown converter unavailable; using raw HTML"
+        logger.warning(note, exc_info = exc)
+        return markup, [], 0, 0, 0
+    return html_to_markdown(markup), [], 0, 0, 0
+
+
+def _run_extract_sync(
+    file_bytes: bytes,
+    filename: str,
+    options: dict,
+    content_type: str = "",
+) -> tuple[str, list[ExtractedFigure], int, int, int]:
+    """Route ``file_bytes`` to the extractor matching its normalised suffix.
+
+    Returns ``(markdown, figures, page_count, truncated_count, seen)``; raises
+    ``ValueError`` when no extractor handles the suffix.
+    """
+    suffix = _normalized_suffix(filename, content_type)
+    wants_images = bool(options.get("extract_images"))
+    page_ocr = bool(options.get("use_vlm_ocr"))
+    # Without ``extract_images`` the figure budget is forced to zero.
+    figure_cap = int(options.get("max_figures", 0)) if wants_images else 0
+    payloads = int(options.get("max_visual_payloads", DEFAULT_DOCUMENT_VISUAL_PAYLOADS))
+
+    if suffix == ".pdf":
+        return _extract_pdf(file_bytes, figure_cap, page_ocr, payloads)
+    if suffix == ".docx":
+        return _extract_docx(file_bytes)
+    if suffix in {".html", ".htm"}:
+        return _extract_html(file_bytes)
+    if suffix in _PLAIN_TEXT_SUFFIXES:
+        return _extract_plaintext(file_bytes)
+    raise ValueError(f"Unsupported file type: {filename}")
+
+
+_RUN_EXTRACT_SYNC_ORIGINAL = _run_extract_sync
+
+
+def _run_extract_worker(
+    result_queue: Any,
+    file_bytes: bytes,
+    filename: str,
+    options: dict,
+    content_type: str,
+) -> None:
+    """Run the extraction and put a tagged outcome tuple on ``result_queue``."""
+    try:
+        outcome = _run_extract_sync(file_bytes, filename, options, content_type)
+        result_queue.put(("ok", outcome))
+    except DocumentExtractionUnavailable as missing:
+        result_queue.put(("extraction_unavailable", str(missing)))
+    except DocumentExtractionEncrypted as locked:
+        result_queue.put(("encrypted", str(locked)))
+    except ValueError as bad_input:
+        result_queue.put(("value_error", str(bad_input)))
+    except BaseException as failure:
+        result_queue.put(("error", type(failure).__name__, str(failure)))
+
+
+def _drain_future_exception(fut: Any) -> None:
+    """Done-callback that marks a future's exception as retrieved.
+
+    Reading ``fut.exception()`` keeps asyncio from logging "Future
+    exception was never retrieved" when the awaiting task was cancelled
+    mid-flight (e.g. client disconnect or AbortController abort).
+    Cancelled futures are left untouched.
+    """
+    try:
+        if not fut.cancelled():
+            fut.exception()
+    except BaseException:
+        # Best effort only: a drain hook must never raise itself.
+        pass
+
+
+def _terminate_extract_process(proc: multiprocessing.Process) -> None:
+    """Stop a worker process: terminate first, kill if it survives.
+
+    Does nothing when the process is not alive. Otherwise calls
+    ``terminate()`` and joins for up to 5 seconds; if the process is
+    still alive and exposes ``kill``, calls it and joins for 2 more.
+    """
+    if proc.is_alive():
+        proc.terminate()
+        proc.join(5)
+        survived = proc.is_alive()
+        if survived and hasattr(proc, "kill"):
+            proc.kill()
+            proc.join(2)
+
+
+def _run_extract_process_sync(
+ file_bytes: bytes,
+ filename: str,
+ options: dict,
+ content_type: str,
+ timeout_seconds: int,
+ cancel_event: Optional[threading.Event] = None,
+) -> tuple[str, list[ExtractedFigure], int, int, int]:
+ """Run ``_run_extract_worker`` in a child process and return its result.
+
+ Blocking helper. It first waits up to ``_EXTRACT_QUEUE_WAIT_SECONDS``
+ for a slot on ``_EXTRACT_SEMAPHORE``, then starts a daemon worker
+ process and polls its result queue with a 0.1 s timeout until a
+ message arrives, the worker exits, ``cancel_event`` is set, or
+ ``timeout_seconds`` elapses. The ``finally`` block terminates a
+ still-running worker, closes the queue and releases the slot.
+
+ Raises:
+ DocumentExtractionCancelled: ``cancel_event`` was set before or
+ during the run.
+ DocumentExtractionBusy: no semaphore slot became free in time.
+ DocumentExtractionTimeout: the worker exceeded ``timeout_seconds``.
+ DocumentExtractionUnavailable, DocumentExtractionEncrypted,
+ ValueError: rebuilt from the worker's tagged message.
+ RuntimeError: the worker reported another error, exited without
+ posting a result, or posted an unknown tag.
+ """
+ if cancel_event is not None and cancel_event.is_set():
+ raise DocumentExtractionCancelled("document extraction was cancelled")
+ # Park up to _EXTRACT_QUEUE_WAIT_SECONDS waiting for a slot, polling
+ # cancel_event so a client disconnect during the wait short-circuits
+ # cleanly instead of holding the request open.
+ deadline = time.monotonic() + _EXTRACT_QUEUE_WAIT_SECONDS
+ acquired = _EXTRACT_SEMAPHORE.acquire(blocking = False)
+ while True:
+ if acquired:
+ break
+ if cancel_event is not None and cancel_event.is_set():
+ raise DocumentExtractionCancelled("document extraction was cancelled")
+ remaining = deadline - time.monotonic()
+ if remaining <= 0:
+ break
+ # Wait in slices of at most 0.5 s so cancel_event is re-checked often.
+ wait = min(remaining, 0.5)
+ if _EXTRACT_SEMAPHORE.acquire(timeout = wait):
+ acquired = True
+ break
+ if not acquired:
+ raise DocumentExtractionBusy("document extraction is busy")
+
+ # Everything past the semaphore acquisition must live inside the
+ # try/finally so the slot is released even if multiprocessing
+ # context creation / Queue allocation / Process construction
+ # itself raises (e.g. OSError on fork-resource exhaustion, EAGAIN
+ # on Windows under load).
+ result_queue = None
+ proc = None
+ try:
+ # Prefer "fork" only on Linux. macOS defaults to "spawn" in
+ # modern Python because Objective-C runtimes (loaded by
+ # PyMuPDF/CoreFoundation/Quartz) crash under fork. Windows has
+ # never supported fork.
+ import sys as _sys
+ if os.name == "nt" or _sys.platform == "darwin":
+ mp_method = "spawn"
+ else:
+ mp_method = "fork"
+ ctx = multiprocessing.get_context(mp_method)
+ result_queue = ctx.Queue(maxsize = 1)
+ proc = ctx.Process(
+ target = _run_extract_worker,
+ args = (
+ result_queue,
+ file_bytes,
+ filename,
+ options,
+ content_type,
+ ),
+ daemon = True,
+ )
+ if cancel_event is not None and cancel_event.is_set():
+ raise DocumentExtractionCancelled("document extraction was cancelled")
+ proc.start()
+ # From here on ``deadline`` is the worker's run-time limit, not the
+ # semaphore wait limit computed above.
+ deadline = time.monotonic() + timeout_seconds
+ message = None
+ while message is None:
+ try:
+ message = result_queue.get(timeout = 0.1)
+ break
+ except queue.Empty:
+ if cancel_event is not None and cancel_event.is_set():
+ _terminate_extract_process(proc)
+ raise DocumentExtractionCancelled(
+ "document extraction was cancelled"
+ )
+ if not proc.is_alive():
+ # The worker may have put its result and exited
+ # between the queue.get timeout and this is_alive
+ # check. Drain the queue once more before declaring
+ # failure so a successful extraction is not lost.
+ try:
+ message = result_queue.get_nowait()
+ except queue.Empty:
+ pass
+ break
+ if time.monotonic() >= deadline:
+ _terminate_extract_process(proc)
+ # NOTE(review): the message hard-codes 120 s although the
+ # limit is the ``timeout_seconds`` argument — confirm
+ # callers always pass 120.
+ raise DocumentExtractionTimeout(
+ "document parsing exceeded the 120-second worker limit"
+ )
+
+ proc.join(2)
+ if proc.is_alive():
+ proc.terminate()
+ proc.join(2)
+ if message is None:
+ # One more attempt after the join completes; covers the
+ # case where the worker exited cleanly with a result still
+ # queued.
+ try:
+ message = result_queue.get_nowait()
+ except queue.Empty:
+ pass
+ if message is None:
+ raise RuntimeError(
+ f"document extraction worker exited without a result "
+ f"(exitcode={proc.exitcode})"
+ )
+
+ # Translate the worker's tagged tuple back into a return value or the
+ # matching exception type.
+ kind = message[0]
+ if kind == "ok":
+ return message[1]
+ if kind == "extraction_unavailable":
+ raise DocumentExtractionUnavailable(message[1])
+ if kind == "encrypted":
+ raise DocumentExtractionEncrypted(message[1])
+ if kind == "value_error":
+ raise ValueError(message[1])
+ if kind == "error":
+ raise RuntimeError(f"{message[1]}: {message[2]}")
+ raise RuntimeError(f"unexpected document worker result: {kind!r}")
+ finally:
+ if proc is not None:
+ try:
+ _terminate_extract_process(proc)
+ except Exception:
+ pass
+ if result_queue is not None:
+ try:
+ result_queue.close()
+ result_queue.join_thread()
+ except Exception:
+ pass
+ _EXTRACT_SEMAPHORE.release()
+
+
+async def extract_document(
+ file_bytes: bytes,
+ filename: str,
+ *,
+ content_type: str = "",
+ describe_images: bool = True,
+ use_vlm_ocr: bool = False,
+ max_figures: int = 40,
+ max_visual_payloads: int = DEFAULT_DOCUMENT_VISUAL_PAYLOADS,
+ vlm_timeout_seconds: float = 60.0,
+ capability: Optional[VlmCapability] = None,
+ self_base_url: Optional[str] = None,
+ authorization_header: Optional[str] = None,
+ progress_cb: Optional[ProgressCb] = None,
+ cancel_event: Optional[threading.Event] = None,
+) -> ExtractResult:
+ """Extract layout-aware Markdown plus figure metadata.
+
+ When ``describe_images`` is True and the active model is
+ vision-capable, the selected visual references are captioned via the
+ OpenAI-compat ``/v1/chat/completions`` surface after extraction.
+ Otherwise figures come back with ``caption=None`` and
+ ``describe_skipped_reason`` carries the human-readable reason.
+
+ Progress is reported through ``progress_cb`` as events whose ``stage``
+ is ``"parsing"``, ``"captioning"`` or ``"done"``; exceptions raised by
+ the callback are logged at debug level and ignored. Once
+ ``cancel_event`` is set, the next progress emit or caption step raises
+ ``DocumentExtractionCancelled``.
+
+ Raises:
+ DocumentExtractionTimeout, DocumentExtractionBusy,
+ DocumentExtractionCancelled, DocumentExtractionEncrypted,
+ DocumentExtractionUnavailable, ValueError: passed through
+ unchanged from the extraction step.
+ RuntimeError: any other extraction failure, chained to its cause.
+ """
+
+ async def _emit(**event: Any) -> None:
+ if cancel_event is not None and cancel_event.is_set():
+ raise DocumentExtractionCancelled("document extraction was cancelled")
+ if progress_cb is not None:
+ try:
+ await progress_cb(event)
+ except Exception:
+ logger.debug("progress_cb raised; continuing", exc_info = True)
+
+ # Clamp both caps to be non-negative; the payload cap can never exceed
+ # the figure cap.
+ max_figures = max(0, max_figures)
+ max_visual_payloads = max(0, min(max_visual_payloads, max_figures))
+ cap = capability if capability is not None else detect_loaded_vlm(self_base_url)
+ image_input_available = bool(cap.is_vlm and cap.endpoint_url and cap.model_name)
+ describe_available = bool(
+ describe_images and cap.is_vlm and cap.endpoint_url and cap.model_name
+ )
+ effective_describe = (
+ describe_available and max_figures > 0 and max_visual_payloads > 0
+ )
+ extract_images = max_figures > 0
+
+ skipped_reason: Optional[str] = None
+ if describe_images and not effective_describe:
+ if describe_available and max_figures <= 0:
+ skipped_reason = "figure description disabled because max_figures is 0"
+ elif describe_available and max_visual_payloads <= 0:
+ skipped_reason = (
+ "figure description disabled because max_visual_payloads is 0"
+ )
+ else:
+ skipped_reason = cap.reason or "no_vlm"
+
+ await _emit(stage = "parsing")
+
+ options, build_warnings = _build_extract_options(
+ extract_images = extract_images,
+ use_vlm_ocr = use_vlm_ocr,
+ max_visual_payloads = max_visual_payloads,
+ )
+ options["max_figures"] = max_figures
+
+ try:
+ if _run_extract_sync is _RUN_EXTRACT_SYNC_ORIGINAL:
+ # Drive run_in_executor directly (rather than asyncio.to_thread)
+ # so we can attach a done-callback that retrieves the future's
+ # exception even when the awaiting task is cancelled — silences
+ # "Future exception was never retrieved" noise on busy/cancel.
+ loop = asyncio.get_running_loop()
+ extract_future = loop.run_in_executor(
+ None,
+ _run_extract_process_sync,
+ file_bytes,
+ filename,
+ options,
+ content_type,
+ _EXTRACT_TIMEOUT_SECONDS,
+ cancel_event,
+ )
+ extract_future.add_done_callback(_drain_future_exception)
+ (
+ markdown,
+ figures_out,
+ page_count,
+ truncated_count,
+ seen,
+ ) = await extract_future
+ else:
+ # Tests monkeypatch _run_extract_sync directly; preserve that seam
+ # without forcing patched callables through multiprocessing spawn.
+ loop = asyncio.get_running_loop()
+ (
+ markdown,
+ figures_out,
+ page_count,
+ truncated_count,
+ seen,
+ ) = await asyncio.wait_for(
+ loop.run_in_executor(
+ None,
+ _run_extract_sync,
+ file_bytes,
+ filename,
+ options,
+ content_type,
+ ),
+ timeout = _EXTRACT_TIMEOUT_SECONDS,
+ )
+ except asyncio.TimeoutError:
+ # NOTE(review): the message hard-codes 120 s while the limit applied
+ # above is _EXTRACT_TIMEOUT_SECONDS — confirm the two agree.
+ raise DocumentExtractionTimeout(
+ "document parsing exceeded the 120-second worker limit"
+ )
+ except DocumentExtractionTimeout:
+ raise
+ except DocumentExtractionBusy:
+ raise
+ except DocumentExtractionCancelled:
+ raise
+ except DocumentExtractionEncrypted:
+ raise
+ except DocumentExtractionUnavailable:
+ raise
+ except ValueError:
+ # Unsupported file type — surface unchanged so the route can map to 415.
+ raise
+ except Exception as exc:
+ logger.exception("document extraction failed for %s", filename)
+ raise RuntimeError("document extraction failed") from exc
+
+ caption_deadline_hit = False
+ if effective_describe:
+ caption_concurrency = (
+ _LOCAL_VLM_CAPTION_CONCURRENCY
+ if cap.source in {"transformers", "unsloth"}
+ else _DEFAULT_VLM_CAPTION_CONCURRENCY
+ )
+ sem = asyncio.Semaphore(caption_concurrency)
+
+ captionable_total = sum(
+ 1
+ for fig in figures_out[:max_figures]
+ if fig.image_base64 and fig.image_mime
+ )
+ captioned_completed = 0
+ await _emit(
+ stage = "captioning",
+ current = 0,
+ total = captionable_total,
+ page = None,
+ total_pages = page_count,
+ )
+
+ async def _describe_one(index: int, figure: ExtractedFigure) -> None:
+ nonlocal captioned_completed
+ if figure.caption or not figure.image_base64 or not figure.image_mime:
+ return
+ if cancel_event is not None and cancel_event.is_set():
+ raise DocumentExtractionCancelled("document extraction was cancelled")
+ async with sem:
+ if cancel_event is not None and cancel_event.is_set():
+ raise DocumentExtractionCancelled(
+ "document extraction was cancelled"
+ )
+ try:
+ caption, error = await _describe_image_via_vlm(
+ image_base64 = figure.image_base64,
+ image_mime = figure.image_mime,
+ endpoint_url = cap.endpoint_url or "",
+ model_name = cap.model_name or "",
+ authorization_header = authorization_header,
+ timeout_seconds = vlm_timeout_seconds,
+ )
+ figures_out[index] = replace(
+ figure,
+ caption = caption,
+ error = error,
+ )
+ except asyncio.TimeoutError as exc:
+ logger.warning(
+ "VLM describe timed out for figure %s", figure.id, exc_info = exc
+ )
+ figures_out[index] = replace(
+ figure,
+ error = f"VLM describe timed out: {type(exc).__name__}",
+ )
+ except Exception as exc:
+ logger.warning(
+ "VLM describe failed for figure %s", figure.id, exc_info = exc
+ )
+ figures_out[index] = replace(
+ figure,
+ error = f"VLM describe failed: {type(exc).__name__}",
+ )
+ finally:
+ captioned_completed += 1
+ await _emit(
+ stage = "captioning",
+ current = captioned_completed,
+ total = captionable_total,
+ page = figure.page,
+ total_pages = page_count,
+ )
+
+ tasks = [
+ _describe_one(index, fig)
+ for index, fig in enumerate(figures_out[:max_figures])
+ if fig.image_base64 and fig.image_mime
+ ]
+ if tasks:
+ try:
+ caption_timeout_seconds = _VLM_CAPTION_TOTAL_TIMEOUT_SECONDS
+ # Local backends get a budget that grows with the number of
+ # figures: one per-figure timeout each plus 15 s, never below
+ # the default total.
+ if cap.source in {"transformers", "unsloth"}:
+ caption_timeout_seconds = max(
+ caption_timeout_seconds,
+ len(tasks) * vlm_timeout_seconds + 15,
+ )
+ results = await asyncio.wait_for(
+ asyncio.gather(*tasks, return_exceptions = True),
+ timeout = caption_timeout_seconds,
+ )
+ # gather(return_exceptions = True) hands exceptions back as
+ # values; re-raise cancellations so they are not swallowed.
+ for result in results:
+ if isinstance(
+ result,
+ (DocumentExtractionCancelled, asyncio.CancelledError),
+ ):
+ raise result
+ except asyncio.TimeoutError:
+ caption_deadline_hit = True
+ # Flag every figure that still has an image but neither a
+ # caption nor an error recorded.
+ for index, figure in enumerate(figures_out):
+ if figure.image_base64 and not figure.caption and not figure.error:
+ figures_out[index] = replace(
+ figure,
+ error = "VLM caption deadline exceeded",
+ )
+
+ warnings: List[str] = list(build_warnings)
+ if truncated_count > 0:
+ warnings.append(
+ f"Document has {seen} figures; showing the first {max_figures} "
+ f"({truncated_count} truncated)."
+ )
+ visual_payload_count = sum(1 for figure in figures_out if figure.image_base64)
+ if (
+ visual_payload_count >= max_visual_payloads
+ and len(figures_out) > visual_payload_count
+ ):
+ warnings.append(
+ f"Only the first {max_visual_payloads} visual payloads "
+ "were attached; remaining figure references are text-only."
+ )
+ if (
+ effective_describe
+ and figures_out
+ and all(f.caption is None for f in figures_out)
+ ):
+ # Collect up to three distinct error strings as examples.
+ error_samples: list[str] = []
+ seen_errors: set[str] = set()
+ for figure in figures_out:
+ if not figure.error or figure.error in seen_errors:
+ continue
+ seen_errors.add(figure.error)
+ error_samples.append(f"{figure.id}: {figure.error}")
+ if len(error_samples) >= 3:
+ break
+ sample_suffix = (
+ " Examples: " + "; ".join(error_samples) + "." if error_samples else ""
+ )
+ warnings.append(
+ "Figure descriptions were requested but none were produced — "
+ "check that the loaded model accepts image inputs via /v1."
+ f"{sample_suffix}"
+ )
+ if caption_deadline_hit:
+ warnings.append(
+ "Figure captioning reached the inline timeout; some image "
+ "descriptions were skipped."
+ )
+
+ await _emit(stage = "done")
+
+ return ExtractResult(
+ markdown = markdown,
+ figures = figures_out,
+ page_count = page_count,
+ tokens_est = _estimate_tokens(markdown),
+ describe_skipped_reason = skipped_reason,
+ vlm_source = cap.source,
+ vlm_model = cap.model_name,
+ image_input_available = image_input_available,
+ warnings = warnings,
+ )
diff --git a/studio/backend/core/chat/vlm_capability.py b/studio/backend/core/chat/vlm_capability.py
new file mode 100644
index 0000000000..f8992c6455
--- /dev/null
+++ b/studio/backend/core/chat/vlm_capability.py
@@ -0,0 +1,213 @@
+# SPDX-License-Identifier: AGPL-3.0-only
+# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
+
+"""
+Runtime probe: is the currently loaded model vision-capable, and where
+is its OpenAI-compatible endpoint?
+
+Unifies the three Studio inference backends (embedded llama-server for
+GGUF, transformers, Unsloth/LoRA) behind a single ``VlmCapability``
+dataclass. Read-only — never loads or modifies models.
+
+Why this replaces the old ``VISION_ARCHITECTURES`` allow-list:
+- Allow-lists silently exclude legitimately new vision architectures.
+- Runtime probing matches the user's actual loaded model.
+- The document extractor can caption selected visual references through
+ any loaded backend exposing ``/v1/chat/completions`` without
+ hard-coding architecture names.
+"""
+
+from __future__ import annotations
+
+import logging
+from dataclasses import asdict, dataclass
+from typing import Any, Literal, Optional
+from urllib.parse import urlparse
+
+
+logger = logging.getLogger(__name__)
+
+
+VlmSource = Literal["gguf", "transformers", "unsloth", "none"]
+
+
+@dataclass(frozen = True)
+class VlmCapability:
+ """Immutable snapshot of the loaded model's image-input capability."""
+
+ # True when the probed backend flags the active model as vision-capable.
+ is_vlm: bool
+ # Base URL of the backend serving the model; None when no endpoint is usable.
+ endpoint_url: Optional[str]
+ # Identifier of the active model; None when nothing is loaded.
+ model_name: Optional[str]
+ # Which backend produced this snapshot ("none" when no model was found).
+ source: VlmSource
+ # Human-readable explanation, filled in by the probes when is_vlm is False.
+ reason: Optional[str] = None
+
+ @classmethod
+ def none(cls, reason: str = "no model loaded") -> "VlmCapability":
+ """Build the empty snapshot: not a VLM, no endpoint, no model, source ``"none"``."""
+ return cls(
+ is_vlm = False,
+ endpoint_url = None,
+ model_name = None,
+ source = "none",
+ reason = reason,
+ )
+
+ def to_dict(self) -> dict:
+ """Return the fields as a plain dict via ``dataclasses.asdict``."""
+ return asdict(self)
+
+
+def _probe_gguf(llama: Any = None) -> Optional[VlmCapability]:
+    """Describe the GGUF llama-server backend's loaded model, if any.
+
+    ``llama`` may be supplied directly; otherwise the process-wide backend
+    is looked up. Returns ``None`` when the backend is unavailable, not
+    loaded, or loaded without a base URL / model identifier, so the
+    caller can move on to the next probe.
+    """
+    if llama is None:
+        # Import and lookup failures both mean "no GGUF backend here".
+        try:
+            from core.inference.llama_cpp import get_llama_cpp_backend
+
+            llama = get_llama_cpp_backend()
+        except Exception:
+            return None
+
+    loaded = getattr(llama, "is_loaded", False)
+    if not loaded:
+        return None
+
+    endpoint = getattr(llama, "base_url", None)
+    identifier = getattr(llama, "model_identifier", None)
+    vision_flag = bool(getattr(llama, "is_vision", False))
+
+    if not (endpoint and identifier):
+        # Half-initialised llama-server state — fall through to the
+        # transformers probe instead of returning a misleading
+        # non-vision GGUF result that suppresses the fallback chain.
+        logger.debug(
+            "llama-server reports is_loaded=True but base_url / model id missing"
+        )
+        return None
+
+    if vision_flag:
+        why_not = None
+    else:
+        why_not = "gguf: model loaded, is_vision=False (no mmproj clip)"
+    return VlmCapability(
+        is_vlm = vision_flag,
+        endpoint_url = endpoint,
+        model_name = identifier,
+        source = "gguf",
+        reason = why_not,
+    )
+
+
+def _probe_transformers(self_base_url: Optional[str]) -> Optional[VlmCapability]:
+    """Describe the transformers / Unsloth backend's active model, if any.
+
+    Returns ``None`` when the backend cannot be imported or obtained, or
+    when it reports no active model name. Otherwise returns a
+    ``VlmCapability`` whose ``source`` is ``"unsloth"`` for entries
+    flagged ``is_lora`` and ``"transformers"`` for the rest. Without a
+    ``self_base_url`` the result is always non-VLM and carries a reason,
+    because there is no endpoint to loop back through.
+    """
+    try:
+        from core.inference import get_inference_backend
+    except ImportError as import_error:
+        # ModuleNotFoundError subclasses ImportError, so this one handler
+        # covers both. A missing ``core.inference`` (or one of its
+        # submodules) is treated as quietly unavailable; any other import
+        # failure is logged. Non-import errors raised during the import
+        # (NameError/AttributeError) still propagate so real bugs aren't
+        # masked as "no VLM loaded".
+        missing = (
+            import_error.name
+            if isinstance(import_error, ModuleNotFoundError)
+            else None
+        )
+        backend_absent = missing == "core.inference" or bool(
+            missing and missing.startswith("core.inference.")
+        )
+        if not backend_absent:
+            logger.exception("Failed to import transformers inference backend")
+        return None
+
+    try:
+        backend = get_inference_backend()
+    except Exception:
+        return None
+
+    active_name: Optional[str] = getattr(backend, "active_model_name", None)
+    if not active_name:
+        return None
+
+    registry: dict = getattr(backend, "models", {}) or {}
+    entry: dict = registry.get(active_name) or {}
+    vision_flag = bool(entry.get("is_vision", False))
+    lora_flag = bool(entry.get("is_lora", False))
+    source: VlmSource = "unsloth" if lora_flag else "transformers"
+
+    if self_base_url:
+        return VlmCapability(
+            is_vlm = vision_flag,
+            endpoint_url = self_base_url.rstrip("/"),
+            model_name = active_name,
+            source = source,
+            reason = None
+            if vision_flag
+            else f"{source}: active model not marked is_vision",
+        )
+
+    return VlmCapability(
+        is_vlm = False,
+        endpoint_url = None,
+        model_name = active_name,
+        source = source,
+        reason = f"{source}: self_base_url=None (cannot self-loopback to /v1/chat/completions)",
+    )
+
+
+def detect_loaded_vlm(
+    self_base_url: Optional[str] = None,
+    *,
+    llama_backend: Any = None,
+) -> VlmCapability:
+    """Report which model is active and whether it accepts image input.
+
+    The GGUF probe runs first and wins whenever it returns a result; it
+    reports llama-server's own URL and never looks at ``self_base_url``.
+    Only when it returns ``None`` is the transformers / Unsloth probe
+    consulted, which needs ``self_base_url`` because captioning must loop
+    back through our own ``/v1/chat/completions``. If neither probe finds
+    a model, ``VlmCapability.none()`` is returned.
+    """
+    capability = _probe_gguf(llama_backend)
+    if capability is None:
+        capability = _probe_transformers(self_base_url)
+    if capability is None:
+        capability = VlmCapability.none()
+    return capability
+
+
+def extract_self_base_url(request: Any) -> Optional[str]:
+    """Derive a trusted local base URL for the active Studio server.
+
+    The request Host header is attacker-controlled in many deployments,
+    so the returned origin always uses ``127.0.0.1``. Only the server
+    port is discovered, in this order:
+
+    1. ``request.app.state.server_port`` (the port published by ``run.py``);
+    2. the second item of the ASGI ``scope["server"]`` pair;
+    3. the port of ``request.base_url``, defaulting to 8888 when the URL
+       carries none — a last-resort fallback for tests and non-uvicorn
+       embedding.
+
+    Returns ``None`` when no source yields a port.
+    """
+
+    def _is_port(value: Any) -> bool:
+        # bool is a subclass of int; without the explicit exclusion a
+        # stray ``True`` would be accepted and rendered as port 1.
+        return isinstance(value, int) and not isinstance(value, bool) and value > 0
+
+    port: Optional[int] = None
+
+    try:
+        state = getattr(getattr(request, "app", None), "state", None)
+        candidate = getattr(state, "server_port", None)
+        if _is_port(candidate):
+            port = candidate
+    except Exception:
+        port = None
+
+    if port is None:
+        try:
+            server = getattr(request, "scope", {}).get("server")
+            # ASGI defines ``server`` as a two-item iterable, so accept a
+            # list as well as the tuple uvicorn supplies.
+            if (
+                isinstance(server, (tuple, list))
+                and len(server) >= 2
+                and _is_port(server[1])
+            ):
+                port = server[1]
+        except Exception:
+            port = None
+
+    if port is None:
+        try:
+            base = str(getattr(request, "base_url", "") or "")
+            if not base:
+                return None
+            parsed = urlparse(base)
+            # Only the port is taken from the URL; its host is discarded.
+            port = parsed.port if parsed.port is not None else 8888
+        except Exception:
+            return None
+
+    return f"http://127.0.0.1:{int(port)}"
diff --git a/studio/backend/core/export/export.py b/studio/backend/core/export/export.py
index 7cabd382eb..1ad3a3607b 100644
--- a/studio/backend/core/export/export.py
+++ b/studio/backend/core/export/export.py
@@ -182,7 +182,10 @@ def load_checkpoint(
# Detect audio type and vision
self._audio_type = detect_audio_type(model_id)
- self.is_vision = not self._audio_type and is_vision_model(model_id)
+ self.is_vision = not self._audio_type and is_vision_model(
+ model_id,
+ trust_remote_code = trust_remote_code,
+ )
# Load model based on type
if self._audio_type == "csm":
diff --git a/studio/backend/core/inference/__init__.py b/studio/backend/core/inference/__init__.py
index 35318f6357..8c56a56564 100644
--- a/studio/backend/core/inference/__init__.py
+++ b/studio/backend/core/inference/__init__.py
@@ -7,17 +7,43 @@
The default get_inference_backend() returns an InferenceOrchestrator that
delegates to a subprocess. The original InferenceBackend runs inside
the subprocess and can be imported directly from .inference when needed.
-"""
-from .orchestrator import InferenceOrchestrator, get_inference_backend
-from .llama_cpp import LlamaCppBackend
+Symbols are exposed lazily through ``__getattr__`` (PEP 562) so that
+importing a stdlib-only helper from this package (e.g.
+``from core.inference._html_to_md import html_to_markdown``) does not
+eagerly pull in the orchestrator or the GGUF/llama-server backend.
+That matters for the document-extractor HTML path which must keep
+working in environments where the inference extras are unavailable or
+broken.
+"""
-# Expose InferenceOrchestrator as InferenceBackend for backward compat
-InferenceBackend = InferenceOrchestrator
+from typing import Any
__all__ = [
"InferenceBackend",
"InferenceOrchestrator",
"get_inference_backend",
+ "get_llama_cpp_backend",
"LlamaCppBackend",
]
+
+
+def __getattr__(name: str) -> Any:
+    """Resolve the package's public names on first access (PEP 562).
+
+    The orchestrator and llama.cpp modules are imported only when one of
+    their names is requested. Every name from the imported module is then
+    stored in the package globals, so later lookups bypass this hook.
+
+    Raises:
+        AttributeError: ``name`` is not one of the lazily exported symbols.
+    """
+    if name in ("InferenceOrchestrator", "get_inference_backend", "InferenceBackend"):
+        from .orchestrator import InferenceOrchestrator, get_inference_backend
+
+        globals()["InferenceOrchestrator"] = InferenceOrchestrator
+        globals()["get_inference_backend"] = get_inference_backend
+        # InferenceBackend is kept as a backward-compatible alias.
+        globals()["InferenceBackend"] = InferenceOrchestrator
+        return globals()[name]
+    if name in ("LlamaCppBackend", "get_llama_cpp_backend"):
+        from .llama_cpp import LlamaCppBackend, get_llama_cpp_backend
+
+        globals()["LlamaCppBackend"] = LlamaCppBackend
+        globals()["get_llama_cpp_backend"] = get_llama_cpp_backend
+        return globals()[name]
+    # Use the message shape of a normal failed module attribute lookup so
+    # tracebacks name the module as well as the attribute.
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+def __dir__() -> list[str]:
+    """List the module globals plus the names in ``__all__``, sorted and de-duplicated."""
+    visible = {*globals(), *__all__}
+    return sorted(visible)
diff --git a/studio/backend/core/inference/llama_cpp.py b/studio/backend/core/inference/llama_cpp.py
index 76234386aa..b41ee999b8 100644
--- a/studio/backend/core/inference/llama_cpp.py
+++ b/studio/backend/core/inference/llama_cpp.py
@@ -711,6 +711,10 @@ def is_active(self) -> bool:
def base_url(self) -> str:
return f"http://127.0.0.1:{self._port}"
+ @property
+ def api_key(self) -> Optional[str]:
+ """Return the value stored in ``self._api_key``; may be ``None``."""
+ return self._api_key
+
@property
def model_identifier(self) -> Optional[str]:
return self._model_identifier
@@ -4077,6 +4081,9 @@ def _parse_tool_calls_from_text(content: str) -> list[dict]:
def _build_openai_messages(
messages: list[dict],
image_b64: Optional[str] = None,
+ image_b64s: Optional[list[str]] = None,
+ image_mime: Optional[str] = None,
+ image_mimes: Optional[list[str]] = None,
) -> list[dict]:
"""
Build OpenAI-format messages, optionally injecting an image_url
@@ -4084,8 +4091,18 @@ def _build_openai_messages(
If no image is provided, returns messages as-is.
"""
- if not image_b64:
+ images = (
+ image_b64s if image_b64s is not None else ([image_b64] if image_b64 else [])
+ )
+ images = [image for image in images if image]
+ if not images:
return messages
+ if image_b64s is not None:
+ mimes = image_mimes or ["image/png"] * len(images)
+ else:
+ mimes = [image_mime or "image/png"]
+ if len(mimes) < len(images):
+ mimes = [*mimes, *(["image/png"] * (len(images) - len(mimes)))]
# Find the last user message and convert to multimodal content parts
result = [msg.copy() for msg in messages]
@@ -4096,14 +4113,18 @@ def _build_openai_messages(
if last_user_idx is not None:
text_content = result[last_user_idx].get("content", "")
- result[last_user_idx]["content"] = [
- {"type": "text", "text": text_content},
+ image_parts = [
{
"type": "image_url",
"image_url": {
- "url": f"data:image/png;base64,{image_b64}",
+ "url": f"data:{mime if mime and '/' in mime else 'image/png'};base64,{image}",
},
- },
+ }
+ for image, mime in zip(images, mimes)
+ ]
+ result[last_user_idx]["content"] = [
+ {"type": "text", "text": text_content},
+ *image_parts,
]
return result
@@ -4235,6 +4256,9 @@ def generate_chat_completion(
self,
messages: list[dict],
image_b64: Optional[str] = None,
+ image_b64s: Optional[list[str]] = None,
+ image_mime: Optional[str] = None,
+ image_mimes: Optional[list[str]] = None,
temperature: float = 0.6,
top_p: float = 0.95,
top_k: int = 20,
@@ -4259,7 +4283,13 @@ def generate_chat_completion(
if not self.is_loaded:
raise RuntimeError("llama-server is not loaded")
- openai_messages = self._build_openai_messages(messages, image_b64)
+ openai_messages = self._build_openai_messages(
+ messages,
+ image_b64 = image_b64,
+ image_b64s = image_b64s,
+ image_mime = image_mime,
+ image_mimes = image_mimes,
+ )
payload = {
"messages": openai_messages,
@@ -5490,3 +5520,20 @@ def generate_audio_response(
return LlamaCppBackend._codec_mgr.decode(
audio_type, device, token_ids = token_ids, text = data.get("content", "")
)
+
+
+_llama_cpp_backend: Optional[LlamaCppBackend] = None
+
+
+def get_llama_cpp_backend() -> LlamaCppBackend:
+    """Return the process-wide GGUF llama-server backend, creating it on first use.
+
+    The singleton lives in ``core.inference`` so core helpers such as
+    ``core.chat.detect_loaded_vlm`` do not need to import route modules.
+    Construction is deferred until the first call to avoid subprocess
+    cleanup side effects for callers that only import model helpers.
+    """
+    global _llama_cpp_backend
+    backend = _llama_cpp_backend
+    if backend is None:
+        backend = _llama_cpp_backend = LlamaCppBackend()
+    return backend
diff --git a/studio/backend/core/inference/worker.py b/studio/backend/core/inference/worker.py
index 20a7d2d16c..ba12157780 100644
--- a/studio/backend/core/inference/worker.py
+++ b/studio/backend/core/inference/worker.py
@@ -74,7 +74,28 @@ def _send_response(resp_queue: Any, response: dict) -> None:
logger.error("Failed to send response: %s", exc)
-def _build_model_config(config: dict):
+def _resolve_trust_remote_code(config: dict) -> bool:
+    """Decide whether the model should load with ``trust_remote_code``.
+
+    An explicit truthy ``config["trust_remote_code"]`` wins. Otherwise
+    the flag is auto-enabled for NemotronH/Nano models only: NemotronH
+    has config parsing bugs requiring trust_remote_code=True, while other
+    transformers 5.x models are native and do NOT need it. The match is
+    limited to ``unsloth/`` and ``nvidia/`` repos and must NOT catch
+    Llama-Nemotron (standard Llama architecture).
+    """
+    if config.get("trust_remote_code", False):
+        return True
+
+    model_name = config["model_name"]
+    lowered = model_name.lower()
+    nemotron_markers = ("nemotron_h", "nemotron-h", "nemotron-3-nano")
+    trusted_prefixes = ("unsloth/", "nvidia/")
+
+    has_marker = any(marker in lowered for marker in nemotron_markers)
+    if has_marker and lowered.startswith(trusted_prefixes):
+        logger.info(
+            "Auto-enabled trust_remote_code for Nemotron model: %s",
+            model_name,
+        )
+        return True
+    return False
+
+
+def _build_model_config(config: dict, *, trust_remote_code: bool | None = None):
"""Build a ModelConfig from the config dict."""
from utils.models import ModelConfig
@@ -82,11 +103,14 @@ def _build_model_config(config: dict):
hf_token = config.get("hf_token")
hf_token = hf_token if hf_token and hf_token.strip() else None
gguf_variant = config.get("gguf_variant")
+ if trust_remote_code is None:
+ trust_remote_code = _resolve_trust_remote_code(config)
mc = ModelConfig.from_identifier(
model_id = model_name,
hf_token = hf_token,
gguf_variant = gguf_variant,
+ trust_remote_code = trust_remote_code,
)
if not mc:
raise ValueError(f"Invalid model identifier: {model_name}")
@@ -247,7 +271,8 @@ def _beat():
def _handle_load(backend, config: dict, resp_queue: Any) -> None:
"""Handle a load command: load a model into the backend."""
try:
- mc = _build_model_config(config)
+ trust_remote_code = _resolve_trust_remote_code(config)
+ mc = _build_model_config(config, trust_remote_code = trust_remote_code)
hf_token = config.get("hf_token")
hf_token = hf_token if hf_token and hf_token.strip() else None
@@ -287,24 +312,6 @@ def _handle_load(backend, config: dict, resp_queue: Any) -> None:
except Exception as e:
logger.warning("Could not read adapter_config.json: %s", e)
- # Auto-enable trust_remote_code for NemotronH/Nano models only.
- # NemotronH has config parsing bugs requiring trust_remote_code=True.
- # Other transformers 5.x models are native and do NOT need it.
- # NOTE: Must NOT match Llama-Nemotron (standard Llama architecture).
- _NEMOTRON_TRUST_SUBSTRINGS = ("nemotron_h", "nemotron-h", "nemotron-3-nano")
- trust_remote_code = config.get("trust_remote_code", False)
- if not trust_remote_code:
- model_name = config["model_name"]
- _mn_lower = model_name.lower()
- if any(sub in _mn_lower for sub in _NEMOTRON_TRUST_SUBSTRINGS) and (
- _mn_lower.startswith("unsloth/") or _mn_lower.startswith("nvidia/")
- ):
- trust_remote_code = True
- logger.info(
- "Auto-enabled trust_remote_code for Nemotron model: %s",
- model_name,
- )
-
# Send heartbeats every 30s so the orchestrator knows we're still alive
# (download / weight loading can take a long time on slow connections)
xet_disabled = os.environ.get("HF_HUB_DISABLE_XET") == "1"
diff --git a/studio/backend/core/training/trainer.py b/studio/backend/core/training/trainer.py
index b128fb5338..39b491b488 100644
--- a/studio/backend/core/training/trainer.py
+++ b/studio/backend/core/training/trainer.py
@@ -201,7 +201,11 @@ def pre_detect_and_load_tokenizer(
# --- Detect VLM ---
vision = (
- is_vision_model(model_name, hf_token = hf_token)
+ is_vision_model(
+ model_name,
+ hf_token = hf_token,
+ trust_remote_code = trust_remote_code,
+ )
if not self.is_audio
else False
)
@@ -574,7 +578,11 @@ def load_model(
# VLM: vision model with image dataset (mutually exclusive with audio paths)
vision = (
- is_vision_model(model_name, hf_token = hf_token)
+ is_vision_model(
+ model_name,
+ hf_token = hf_token,
+ trust_remote_code = trust_remote_code,
+ )
if not self.is_audio
else False
)
diff --git a/studio/backend/models/inference.py b/studio/backend/models/inference.py
index 0af9425fdc..21949baf5b 100644
--- a/studio/backend/models/inference.py
+++ b/studio/backend/models/inference.py
@@ -129,6 +129,10 @@ class ValidateModelRequest(BaseModel):
gguf_variant: Optional[str] = Field(
None, description = "GGUF quantization variant (e.g. 'Q4_K_M')"
)
+ trust_remote_code: bool = Field(
+ False,
+ description = "Allow validation probes that require custom model code.",
+ )
class ValidateModelResponse(BaseModel):
@@ -172,6 +176,14 @@ class GenerateRequest(BaseModel):
image_base64: Optional[str] = Field(
None, description = "Base64 encoded image for vision models"
)
+ session_id: Optional[str] = Field(
+ None,
+ description = "[x-unsloth] Session/thread ID for cancellation scoping.",
+ )
+ cancel_id: Optional[str] = Field(
+ None,
+ description = "[x-unsloth] Per-request cancellation token matched by /inference/cancel.",
+ )
class LoadResponse(BaseModel):
@@ -353,6 +365,10 @@ class InferenceStatusResponse(BaseModel):
supports_tools: bool = Field(
False, description = "Whether the active model supports tool calling"
)
+ cache_type_kv: Optional[str] = Field(
+ None,
+ description = "KV cache data type for K and V (e.g. 'f16', 'bf16', 'q8_0')",
+ )
context_length: Optional[int] = Field(
None, description = "Context length of the active model"
)
@@ -1471,3 +1487,159 @@ class AnthropicMessagesResponse(BaseModel):
stop_reason: Optional[str] = None
stop_sequence: Optional[str] = None
usage: AnthropicUsage = Field(default_factory = AnthropicUsage)
+
+
+# ---------------------------------------------------------------------- #
+# Chat document extraction (parsed documents + optional VLM captions) #
+# ---------------------------------------------------------------------- #
+
+
+class ExtractedFigureModel(BaseModel):
+ """One visual reference (a detected figure or a page image) from a document,
+ optionally carrying a VLM-generated caption and a base64 image payload."""
+
+ id: str = Field(..., description = "Stable id (e.g. 'fig-0')")  # the only required field
+ page: Optional[int] = Field(None, description = "1-based page number, if known")
+ caption: Optional[str] = Field(
+ None, description = "Short VLM-generated caption, or null if skipped/failed"
+ )
+ error: Optional[str] = Field(
+ None, description = "Reason the describe call failed, if any"
+ )
+ kind: Literal["figure", "page"] = Field(  # any other value fails validation
+ "figure",
+ description = "Whether this reference is a detected figure or page image",
+ )
+ image_mime: Optional[str] = Field(
+ None, description = "MIME type for image_base64 when a visual payload is present"
+ )
+ image_base64: Optional[str] = Field(
+ None,
+ description = (
+ "Base64-encoded visual payload for this reference. The first visual "
+ "reference is sent to vision-capable chat models as [Image #1]."
+ ),
+ )
+ image_width: Optional[int] = Field(  # ge = 1: when set, must be a positive integer
+ None, ge = 1, description = "Width of image_base64 after resize"
+ )
+ image_height: Optional[int] = Field(  # ge = 1: when set, must be a positive integer
+ None, ge = 1, description = "Height of image_base64 after resize"
+ )
+
+
+class ExtractDocumentResponse(BaseModel):
+ """
+ Returned synchronously from ``POST /chat/extract-document`` for
+ small docs, or as the final SSE event for larger ones.
+ """
+
+ schema_version: int = Field(
+ 1, description = "Document extraction payload schema version"
+ )
+ filename: str = Field(..., description = "Original filename uploaded")  # required
+ markdown: str = Field(  # required; every field other than filename/markdown has a default
+ ..., description = "Layout-aware Markdown extracted from the document"
+ )
+ page_count: int = Field(0, ge = 0, description = "Number of pages in the source")
+ tokens_est: int = Field(  # ge = 0: negative estimates fail validation
+ 0, ge = 0, description = "Rough char/4 token estimate for the markdown"
+ )
+ truncated: bool = Field(
+ False,
+ description = "Whether markdown was clipped to the requested token budget",
+ )
+ figures: List[ExtractedFigureModel] = Field(
+ default_factory = list,  # fresh list per instance rather than a shared default
+ description = "Figures discovered in the document (captions optional)",
+ )
+ describe_skipped_reason: Optional[str] = Field(
+ None,
+ description = (
+ "If image description was requested but skipped, the reason "
+ "(e.g. 'loaded GGUF is not vision-capable'). Mirrors the "
+ "``reason`` surfaced by /chat/document-support."
+ ),
+ )
+ vlm_source: Optional[str] = Field(  # plain str here, not validated against a fixed set of values
+ None,
+ description = (
+ "Which inference backend served the describe calls: 'gguf', "
+ "'transformers', 'unsloth', or 'none' when no VLM was used."
+ ),
+ )
+ vlm_model: Optional[str] = Field(
+ None,
+ description = "Identifier of the VLM whose captions appear in this document",
+ )
+ image_input_available: bool = Field(
+ False,
+ description = (
+ "Whether the active model can receive an extracted visual payload "
+ "alongside the markdown."
+ ),
+ )
+ warnings: List[str] = Field(
+ default_factory = list,  # fresh list per instance rather than a shared default
+ description = "Non-fatal warnings surfaced to the UI",
+ )
+
+
+class VlmCapabilityModel(BaseModel):
+ """Runtime probe result for the currently-loaded model."""
+
+ is_vlm: bool = Field(  # required
+ ..., description = "Whether the active model accepts image inputs"
+ )
+ endpoint_url: Optional[str] = Field(
+ None,
+ description = "Root URL serving /v1/chat/completions for the active model",
+ )
+ model_name: Optional[str] = Field(
+ None, description = "Identifier of the active model, if any is loaded"
+ )
+ source: Literal["gguf", "transformers", "unsloth", "none"] = Field(  # required; restricted to these four strings
+ ..., description = "Which backend currently owns the active model"
+ )
+ reason: Optional[str] = Field(
+ None,
+ description = "Populated when is_vlm is false; explains why the UI toggle is disabled",
+ )
+
+
+class DocumentSupportResponse(BaseModel):
+ """Returned by GET /chat/document-support.
+
+ Drives the Chat settings-card toggles. ``max_visual_payloads`` is kept
+ for older clients as an informational hint, not a hard request cap.
+ """
+
+ schema_version: int = Field(
+ 1, description = "Document support payload schema version"
+ )
+ extraction_available: bool = Field(  # required
+ ...,
+ description = (
+ "Whether the document extraction backend successfully imported "
+ "on the server"
+ ),
+ )
+ max_visual_payloads: int = Field(  # required; ge = 0 allows zero
+ ...,
+ ge = 0,
+ description = "Legacy visual-payload hint; not a hard request cap",
+ )
+ max_extract_concurrency: int = Field(  # defaults to 1; ge = 1 rejects zero or negative values
+ 1,
+ ge = 1,
+ description = "Maximum server-side document extraction workers",
+ )
+ format_support: Dict[str, bool] = Field(
+ default_factory = dict,  # fresh dict per instance rather than a shared default
+ description = "Per-format parser availability for document extraction",
+ )
+ unavailable_formats: Dict[str, str] = Field(
+ default_factory = dict,  # fresh dict per instance rather than a shared default
+ description = "Per-format parser unavailability reasons",
+ )
+ vlm: VlmCapabilityModel  # required: declared without a default
diff --git a/studio/backend/requirements/studio.txt b/studio/backend/requirements/studio.txt
index 96f8816b57..13c556b878 100644
--- a/studio/backend/requirements/studio.txt
+++ b/studio/backend/requirements/studio.txt
@@ -16,5 +16,15 @@ huggingface-hub==0.36.2
structlog>=24.1.0
diceware
ddgs
+pypdf>=6.0.0,<7
+python-multipart>=0.0.26
+# Document extraction relies on pymupdf4llm 1.27+ (installed via
+# data-designer-deps.txt), which pulls pymupdf-layout. The bundled ONNX
+# models work fine on modern onnxruntime; we require >=1.19 because
+# earlier wheels (e.g. 1.17.x) were built against NumPy 1.x and crash
+# on import in venvs that have NumPy 2.x installed (pymupdf.layout ->
+# onnxruntime -> numpy._multiarray_umath ABI mismatch). Verified
+# end-to-end with onnxruntime 1.25.0 + numpy 2.4.x.
+onnxruntime>=1.19
cryptography>=42.0.0
httpx>=0.27.0
diff --git a/studio/backend/routes/inference.py b/studio/backend/routes/inference.py
index a156f2397c..6b7fec5bdb 100644
--- a/studio/backend/routes/inference.py
+++ b/studio/backend/routes/inference.py
@@ -9,9 +9,11 @@
import sys
import time
import uuid
+from contextlib import suppress
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import StreamingResponse, JSONResponse, Response
+from pydantic import ValidationError
from typing import Any, Optional, Union
import json
import httpx
@@ -124,6 +126,7 @@ def _friendly_error(exc: Exception) -> str:
_canonicalize_spec_mode,
_hf_offline_if_dns_dead,
detect_reasoning_flags,
+ get_llama_cpp_backend,
)
from core.inference.llama_server_args import (
strip_shadowing_flags,
@@ -151,6 +154,7 @@ def _friendly_error(exc: Exception) -> str:
_canonicalize_spec_mode,
_hf_offline_if_dns_dead,
detect_reasoning_flags,
+ get_llama_cpp_backend,
)
from core.inference.llama_server_args import (
strip_shadowing_flags,
@@ -210,10 +214,14 @@ def _friendly_error(exc: Exception) -> str:
AnthropicUsage,
CreateOpenAIContainerBody,
DeleteOpenAIContainerBody,
+ DocumentSupportResponse,
+ ExtractDocumentResponse,
+ ExtractedFigureModel,
ListOpenAIContainersResponse,
OpenAIContainerRequest,
OpenAIContainerSummary,
)
+from dataclasses import asdict as _asdict
from core.inference.anthropic_compat import (
anthropic_messages_to_openai,
anthropic_tools_to_openai,
@@ -558,12 +566,12 @@ def _resolve_model_identifier_for_request(
return str(grant.canonical_path), display_label, True
-# GGUF inference backend (llama-server)
-_llama_cpp_backend = LlamaCppBackend()
-
-
-def get_llama_cpp_backend() -> LlamaCppBackend:
- return _llama_cpp_backend
+# GGUF inference backend (llama-server) singleton lives in
+# ``core.inference.llama_cpp``. ``get_llama_cpp_backend`` is already
+# imported above and re-exported from this module so external callers
+# that do ``from routes.inference import get_llama_cpp_backend`` keep
+# resolving to the same process-wide instance that load/list/delete/
+# shutdown all consult.
@router.post("/load", response_model = LoadResponse)
@@ -661,6 +669,7 @@ async def load_model(
reasoning_always_on = llama_backend.reasoning_always_on,
supports_preserve_thinking = llama_backend.supports_preserve_thinking,
supports_tools = llama_backend.supports_tools,
+ cache_type_kv = llama_backend.cache_type_kv,
chat_template = llama_backend.chat_template,
speculative_type = llama_backend.requested_spec_mode,
spec_draft_n_max = llama_backend.spec_draft_n_max,
@@ -713,6 +722,26 @@ async def load_model(
chat_template = _chat_template,
)
+ model_defaults = load_model_defaults(request.model_path)
+ defaults_require_trust_remote_code = bool(
+ model_defaults.get("model", {}).get("trust_remote_code", False)
+ or model_defaults.get("inference", {}).get("trust_remote_code", False)
+ )
+ if defaults_require_trust_remote_code and not request.trust_remote_code:
+ display_name = (
+ model_defaults.get("model", {}).get("display_name")
+ or request.model_path.split("/")[-1]
+ or request.model_path
+ )
+ raise HTTPException(
+ status_code = 400,
+ detail = (
+ f"Model '{display_name}' requires trust_remote_code to be enabled. "
+ "Please enable 'Trust remote code' in Chat Settings and try again."
+ ),
+ )
+
+ # Create config using clean factory method.
# is_lora auto-detected from adapter_config.json on disk/HF.
# DNS-probe wrap so offline loads skip 30-60s of soft-failed
# network checks before the worker starts.
@@ -721,6 +750,7 @@ async def load_model(
model_id = model_identifier,
hf_token = request.hf_token,
gguf_variant = request.gguf_variant,
+ trust_remote_code = request.trust_remote_code,
)
if not config:
@@ -1122,10 +1152,39 @@ async def validate_model(
model_identifier, model_log_label, native_grant_backed = (
_resolve_model_identifier_for_request(request, operation = "validate-model")
)
+ if not native_grant_backed:
+ model_defaults = load_model_defaults(request.model_path)
+ default_model_config = model_defaults.get("model", {})
+ default_inference_config = model_defaults.get("inference", {})
+ defaults_require_trust_remote_code = bool(
+ default_model_config.get("trust_remote_code", False)
+ or default_inference_config.get("trust_remote_code", False)
+ )
+ if defaults_require_trust_remote_code and not request.trust_remote_code:
+ display_name = (
+ default_model_config.get("display_name")
+ or request.model_path.split("/")[-1]
+ or request.model_path
+ )
+ return ValidateModelResponse(
+ valid = True,
+ message = (
+ "Model identifier is valid, but this model requires "
+ "trust_remote_code before probing or loading."
+ ),
+ identifier = request.model_path,
+ display_name = display_name,
+ is_gguf = False,
+ is_lora = False,
+ is_vision = bool(default_model_config.get("is_vision", False)),
+ requires_trust_remote_code = True,
+ )
+
config = ModelConfig.from_identifier(
model_id = model_identifier,
hf_token = request.hf_token,
gguf_variant = request.gguf_variant,
+ trust_remote_code = request.trust_remote_code,
)
if not config:
@@ -1231,10 +1290,15 @@ async def cancel_inference(
A cancel_id arriving before its stream registers is stashed briefly
and replayed on registration. Returns {"cancelled": N}.
"""
+ # The cancel body is a tiny dict of identifiers; cap the read so an
+ # authenticated client cannot make this endpoint buffer megabytes,
+ # mirroring the bounded reads the sibling JSON inference endpoints use.
try:
- body = await request.json()
+ body = await _read_json_body_limited(request, max_bytes = 64 * 1024)
if not isinstance(body, dict):
body = {}
+ except HTTPException:
+ raise
except Exception as e:
logger.debug("Failed to parse cancel request body: %s", e)
body = {}
@@ -1260,6 +1324,7 @@ async def cancel_inference(
@router.post("/generate/stream")
async def generate_stream(
+ fastapi_request: Request,
request: GenerateRequest,
current_subject: str = Depends(get_current_subject),
):
@@ -1302,9 +1367,21 @@ async def generate_stream(
status_code = 400, detail = f"Failed to decode image: {str(e)}"
)
+ cancel_event = threading.Event()
+ completion_id = f"legacy-{uuid.uuid4().hex[:12]}"
+ _tracker = _TrackedCancel(
+ cancel_event,
+ request.cancel_id,
+ request.session_id,
+ completion_id,
+ )
+ _tracker.__enter__()
+
async def stream():
+ _DONE = object()
try:
- for chunk in backend.generate_chat_response(
+ yield f"data: {json.dumps({'completion_id': completion_id})}\n\n"
+ gen = backend.generate_chat_response(
messages = request.messages,
system_prompt = request.system_prompt,
image = image,
@@ -1313,7 +1390,19 @@ async def stream():
top_k = request.top_k,
max_new_tokens = request.max_new_tokens,
repetition_penalty = request.repetition_penalty,
- ):
+ cancel_event = cancel_event,
+ )
+ while True:
+ if cancel_event.is_set():
+ backend.reset_generation_state()
+ break
+ if await fastapi_request.is_disconnected():
+ cancel_event.set()
+ backend.reset_generation_state()
+ return
+ chunk = await asyncio.to_thread(next, gen, _DONE)
+ if chunk is _DONE:
+ break
yield f"data: {json.dumps({'content': chunk})}\n\n"
yield "data: [DONE]\n\n"
@@ -1321,6 +1410,9 @@ async def stream():
backend.reset_generation_state()
logger.error(f"Error during generation: {e}", exc_info = True)
yield f"data: {json.dumps({'error': _friendly_error(e)})}\n\n"
+ finally:
+ cancel_event.set()
+ _tracker.__exit__(None, None, None)
return StreamingResponse(
stream(),
@@ -1632,9 +1724,123 @@ def _decode_audio_base64(b64: str) -> np.ndarray:
return waveform.squeeze(0).numpy()
+_OPENAI_CHAT_MAX_IMAGES = 256  # max image_url parts across one normalized message list
+_OPENAI_CHAT_MAX_IMAGE_BYTES = 20 * 1024 * 1024  # 20 MiB; applied to decoded input and to the re-encoded PNG
+_OPENAI_CHAT_MAX_IMAGE_PIXELS = 40_000_000  # cap on width * height
+_OPENAI_CHAT_MAX_IMAGE_BASE64_CHARS = (  # cheap length gate applied before any base64 decoding
+ (_OPENAI_CHAT_MAX_IMAGE_BYTES + 2) // 3  # ceil(max_bytes / 3): number of 3-byte groups
+) * 4 + 1024  # 4 base64 chars per group, plus 1024 chars of slack
+
+
+def _convert_openai_image_b64_to_png_b64(image_b64: str) -> str:  # base64 image in -> size-checked base64 PNG out
+ if len(image_b64) > _OPENAI_CHAT_MAX_IMAGE_BASE64_CHARS:  # reject oversized input before decoding it
+ raise HTTPException(
+ status_code = 413,
+ detail = "Image payload exceeds the 20 MB decoded-image limit.",
+ )
+
+ try:
+ import base64 as _b64
+ from io import BytesIO as _BytesIO
+ from PIL import Image as _Image
+
+ raw = _b64.b64decode(image_b64, validate = True)  # validate = True: non-alphabet characters raise instead of being skipped
+ if len(raw) > _OPENAI_CHAT_MAX_IMAGE_BYTES:  # exact check on the decoded byte count
+ raise HTTPException(
+ status_code = 413,
+ detail = "Image payload exceeds the 20 MB decoded-image limit.",
+ )
+ with _Image.open(_BytesIO(raw)) as img:
+ width, height = img.size  # dimensions are checked before convert() is called
+ if width * height > _OPENAI_CHAT_MAX_IMAGE_PIXELS:
+ raise HTTPException(
+ status_code = 413,
+ detail = "Image dimensions exceed the 40 MP limit.",
+ )
+ converted = img.convert("RGB")
+ buf = _BytesIO()
+ converted.save(buf, format = "PNG")
+ png = buf.getvalue()
+ if len(png) > _OPENAI_CHAT_MAX_IMAGE_BYTES:  # the re-encoded PNG is size-checked separately from the input
+ raise HTTPException(
+ status_code = 413,
+ detail = "Converted image payload exceeds the 20 MB limit.",
+ )
+ return _b64.b64encode(png).decode("ascii")
+ except HTTPException:  # let the 413 responses raised above pass through unchanged
+ raise
+ except Exception as e:  # any other failure inside the try block is reported as a 400
+ raise HTTPException(
+ status_code = 400, detail = f"Failed to process image: {e}"
+ ) from e
+
+
+def _data_url_base64_payload(url: str) -> str:  # returns the text after the first comma of a base64 data URL
+ try:
+ header, b64data = url.split(",", 1)  # unpacking raises ValueError when the URL has no comma
+ except ValueError as exc:
+ raise HTTPException(
+ status_code = 400, detail = "Image data URL is missing base64 payload."
+ ) from exc
+ if ";base64" not in header.lower():  # case-insensitive check on the part before the comma
+ raise HTTPException(
+ status_code = 400, detail = "Image data URL must be base64 encoded."
+ )
+ return b64data  # returned as-is; not decoded or validated here
+
+
+def _normalize_openai_message_images(
+ openai_messages: list[dict],
+ *,
+ is_vision: bool,
+ not_vision_detail: str,
+) -> bool:
+ """Apply image count/size/pixel guards and normalize data URLs to PNG."""
+ has_image = False
+ image_count = 0  # cumulative across every message, not per message
+
+ for msg in openai_messages:
+ content = msg.get("content")
+ if not isinstance(content, list):  # plain-string or missing content has no image parts
+ continue
+ for part in content:
+ if not isinstance(part, dict) or part.get("type") != "image_url":
+ continue
+
+ has_image = True  # set before any validation, so remote-URL parts count too
+ image_count += 1
+ if image_count > _OPENAI_CHAT_MAX_IMAGES:
+ raise HTTPException(
+ status_code = 413,
+ detail = f"Too many images provided; maximum is {_OPENAI_CHAT_MAX_IMAGES}.",
+ )
+ if not is_vision:  # checked after the count limit, so the 413 above takes precedence
+ raise HTTPException(status_code = 400, detail = not_vision_detail)
+
+ image_url = part.get("image_url") or {}  # a missing or None image_url becomes {} and yields url == ""
+ if not isinstance(image_url, dict):
+ raise HTTPException(
+ status_code = 400, detail = "Invalid image_url content part."
+ )
+ url = image_url.get("url", "")
+ if not isinstance(url, str):
+ raise HTTPException(status_code = 400, detail = "Invalid image_url URL.")
+ if not url.startswith("data:"):
+ # Remote URLs are counted but cannot be byte/pixel checked here.
+ continue
+
+ b64data = _data_url_base64_payload(url)
+ png_b64 = _convert_openai_image_b64_to_png_b64(b64data)
+ normalized = dict(image_url)  # shallow copy keeps any other keys of the image_url dict
+ normalized["url"] = f"data:image/png;base64,{png_b64}"
+ part["image_url"] = normalized  # mutates the caller's message part in place
+
+ return has_image  # True if at least one image_url part was seen
+
+
def _extract_content_parts(
messages: list,
-) -> tuple[str, list[dict], "Optional[str]"]:
+) -> tuple[str, list[dict], list[str]]:
"""
Parse OpenAI-format messages into components the inference backend expects.
@@ -1644,11 +1850,11 @@ def _extract_content_parts(
Returns:
system_prompt: The system message text (empty string if none provided).
chat_messages: Non-system messages with content flattened to strings.
- image_base64: Base64 data of the *first* image found, or ``None``.
+ image_base64s: Base64 data for image parts, in request order.
"""
system_prompt = ""
chat_messages: list[dict] = []
- first_image_b64: Optional[str] = None
+ image_b64s: list[str] = []
for msg in messages:
# ── System messages → extract as system_prompt ────────
@@ -1672,11 +1878,12 @@ def _extract_content_parts(
for part in msg.content:
if part.type == "text":
text_parts.append(part.text)
- elif part.type == "image_url" and first_image_b64 is None:
+ elif part.type == "image_url":
url = part.image_url.url
if url.startswith("data:"):
# data:image/png;base64, → extract
- first_image_b64 = url.split(",", 1)[1] if "," in url else None
+ if "," in url:
+ image_b64s.append(url.split(",", 1)[1])
else:
logger.warning(
f"Remote image URLs not yet supported: {url[:80]}..."
@@ -1684,7 +1891,7 @@ def _extract_content_parts(
combined_text = "\n".join(text_parts) if text_parts else ""
chat_messages.append({"role": msg.role, "content": combined_text})
- return system_prompt, chat_messages, first_image_b64
+ return system_prompt, chat_messages, image_b64s
# ── External provider proxy ──────────────────────────────────────
@@ -2149,9 +2356,23 @@ async def delete_openai_container(
@router.post("/chat/completions")
async def openai_chat_completions(
- payload: ChatCompletionRequest,
request: Request,
current_subject: str = Depends(get_current_subject),
+):
+ body = await _read_json_body_limited(  # size-capped read; the payload is no longer parsed from the signature
+ request,
+ max_bytes = _OPENAI_CHAT_BODY_MAX_BYTES,
+ )
+ try:
+ payload = ChatCompletionRequest.model_validate(body)  # validate manually now that the body is read by hand
+ except ValidationError as exc:  # exc.errors() may hold raw exception objects in "ctx"; exc.json() is always JSON-safe
+ raise HTTPException(status_code = 422, detail = json.loads(exc.json())) from exc
+ return await _openai_chat_completions_impl(payload, request)  # shared implementation takes only payload and request
+
+
+async def _openai_chat_completions_impl(
+ payload: ChatCompletionRequest,
+ request: Request,
):
"""
OpenAI-compatible chat completions endpoint.
@@ -2406,7 +2627,7 @@ async def audio_input_stream():
)
# ── Parse messages (handles multimodal content parts) ─────
- system_prompt, chat_messages, extracted_image_b64 = _extract_content_parts(
+ system_prompt, chat_messages, extracted_image_b64s = _extract_content_parts(
payload.messages
)
@@ -2710,7 +2931,7 @@ async def gguf_tool_stream():
def gguf_generate():
return llama_backend.generate_chat_completion(
messages = gguf_messages,
- image_b64 = image_b64,
+ image_b64s = image_b64s,
temperature = payload.temperature,
top_p = payload.top_p,
top_k = payload.top_k,
@@ -2879,7 +3100,9 @@ async def gguf_stream_chunks():
# ── Standard Unsloth path ─────────────────────────────────
# Decode image (from content parts OR legacy field)
- image_b64 = extracted_image_b64 or payload.image_base64
+ image_b64 = (
+ extracted_image_b64s[0] if extracted_image_b64s else payload.image_base64
+ )
image = None
if image_b64:
@@ -3425,9 +3648,9 @@ async def serve_sandbox_file(
# ── Path containment check ──────────────────────────────────
home = os.path.expanduser("~")
sandbox_root = os.path.realpath(os.path.join(home, "studio_sandbox"))
- safe_session = os.path.basename(session_id.replace("..", ""))
- if not safe_session:
+ if not _re.fullmatch(r"[A-Za-z0-9_-]+", session_id or ""):
raise HTTPException(status_code = 404, detail = "Not found")
+ safe_session = session_id
file_path = os.path.realpath(
os.path.join(sandbox_root, safe_session, safe_filename)
@@ -3516,7 +3739,9 @@ async def openai_completions(
detail = "No GGUF model loaded. Load a GGUF model first.",
)
- body = await request.json()
+ body = await _read_json_body_limited(
+ request, max_bytes = _OPENAI_PROXY_BODY_MAX_BYTES
+ )
target_url = f"{llama_backend.base_url}/v1/completions"
is_stream = body.get("stream", False)
@@ -3595,7 +3820,9 @@ async def openai_embeddings(
detail = "No GGUF model loaded. Load a GGUF model first.",
)
- body = await request.json()
+ body = await _read_json_body_limited(
+ request, max_bytes = _OPENAI_PROXY_BODY_MAX_BYTES
+ )
target_url = f"{llama_backend.base_url}/v1/embeddings"
async with httpx.AsyncClient() as client:
@@ -3894,7 +4121,7 @@ async def _responses_non_streaming(
) -> JSONResponse:
"""Handle a non-streaming Responses API call."""
chat_req = _build_chat_request(payload, messages, stream = False)
- result = await openai_chat_completions(chat_req, request)
+ result = await _openai_chat_completions_impl(chat_req, request)
# openai_chat_completions returns a JSONResponse for non-streaming
if isinstance(result, JSONResponse):
@@ -4410,45 +4637,11 @@ def _normalize_anthropic_openai_images(
HTTPException(400) when images are present but the active model is
not a vision model, or when an image cannot be decoded.
"""
- from PIL import Image
-
- has_image = False
- for msg in openai_messages:
- content = msg.get("content")
- if not isinstance(content, list):
- continue
- for part in content:
- if part.get("type") != "image_url":
- continue
-
- has_image = True
- if not is_vision:
- raise HTTPException(
- status_code = 400,
- detail = "Image provided but current GGUF model does not support vision.",
- )
-
- url = (part.get("image_url") or {}).get("url", "")
- if not url.startswith("data:"):
- # Remote URLs are forwarded as-is; llama-server will
- # fetch (or fail) per its own support matrix.
- continue
-
- try:
- _, b64data = url.split(",", 1)
- raw = base64.b64decode(b64data)
- img = Image.open(io.BytesIO(raw)).convert("RGB")
- buf = io.BytesIO()
- img.save(buf, format = "PNG")
- png_b64 = base64.b64encode(buf.getvalue()).decode("ascii")
- except Exception:
- raise HTTPException(
- status_code = 400,
- detail = "Failed to process image.",
- )
- part["image_url"] = {"url": f"data:image/png;base64,{png_b64}"}
-
- return has_image
+ return _normalize_openai_message_images(
+ openai_messages,
+ is_vision = is_vision,
+ not_vision_detail = "Image provided but current GGUF model does not support vision.",
+ )
@router.post("/messages")
@@ -5271,7 +5464,7 @@ def _drop_empty_assistant_sentinels(messages: list[dict]) -> list[dict]:
return out
-def _openai_messages_for_passthrough(payload) -> list[dict]:
+def _openai_messages_for_passthrough(payload, *, is_vision: bool = True) -> list[dict]:
"""Build OpenAI-format message dicts for the /v1/chat/completions
passthrough path.
@@ -5279,7 +5472,7 @@ def _openai_messages_for_passthrough(payload) -> list[dict]:
unset optional fields) so they are already in standard OpenAI format
— including ``role="tool"`` tool-result messages and assistant messages
that carry structured ``tool_calls``. Content-parts images already in
- the message list are left untouched.
+ the message list are counted, bounded, and data URLs are normalized to PNG.
When a client uses Studio's legacy ``image_base64`` top-level field, the
image is re-encoded to PNG (llama-server's stb_image has limited format
@@ -5291,41 +5484,29 @@ def _openai_messages_for_passthrough(payload) -> list[dict]:
[m.model_dump(exclude_none = True) for m in payload.messages]
)
- if not payload.image_base64:
- return messages
-
- try:
- import base64 as _b64
- from io import BytesIO as _BytesIO
- from PIL import Image as _Image
-
- raw = _b64.b64decode(payload.image_base64)
- img = _Image.open(_BytesIO(raw)).convert("RGB")
- buf = _BytesIO()
- img.save(buf, format = "PNG")
- png_b64 = _b64.b64encode(buf.getvalue()).decode("ascii")
- except Exception:
- raise HTTPException(
- status_code = 400,
- detail = "Failed to process image.",
- )
+ if payload.image_base64:
+ data_url = f"data:image/unknown;base64,{payload.image_base64}"
+ image_part = {"type": "image_url", "image_url": {"url": data_url}}
- data_url = f"data:image/png;base64,{png_b64}"
- image_part = {"type": "image_url", "image_url": {"url": data_url}}
-
- for msg in reversed(messages):
- if msg.get("role") != "user":
- continue
- existing = msg.get("content")
- if isinstance(existing, str):
- msg["content"] = [{"type": "text", "text": existing}, image_part]
- elif isinstance(existing, list):
- existing.append(image_part)
+ for msg in reversed(messages):
+ if msg.get("role") != "user":
+ continue
+ existing = msg.get("content")
+ if isinstance(existing, str):
+ msg["content"] = [{"type": "text", "text": existing}, image_part]
+ elif isinstance(existing, list):
+ existing.append(image_part)
+ else:
+ msg["content"] = [image_part]
+ break
else:
- msg["content"] = [image_part]
- break
- else:
- messages.append({"role": "user", "content": [image_part]})
+ messages.append({"role": "user", "content": [image_part]})
+
+ _normalize_openai_message_images(
+ messages,
+ is_vision = is_vision,
+ not_vision_detail = "Image provided but current GGUF model does not support vision.",
+ )
return messages
@@ -5385,14 +5566,16 @@ def _extract_response_format(payload):
return rf if isinstance(rf, dict) else None
-def _build_openai_passthrough_body(payload, backend_ctx = None) -> dict:
+def _build_openai_passthrough_body(
+ payload, backend_ctx = None, *, is_vision: bool = True
+) -> dict:
"""Assemble the llama-server request body from a ChatCompletionRequest.
Only explicitly-known OpenAI / llama-server fields are forwarded so that
Studio-specific extensions (``enable_tools``, ``enabled_tools``,
``session_id``, ...) never leak to the backend.
"""
- messages = _openai_messages_for_passthrough(payload)
+ messages = _openai_messages_for_passthrough(payload, is_vision = is_vision)
tool_choice = payload.tool_choice if payload.tool_choice is not None else "auto"
# When the caller asked for a specific reasoning mode, forward it to
# llama-server via chat_template_kwargs so the Jinja template renders
@@ -5437,7 +5620,9 @@ async def _openai_passthrough_stream(
"""
target_url = f"{llama_backend.base_url}/v1/chat/completions"
body = _build_openai_passthrough_body(
- payload, backend_ctx = llama_backend.context_length
+ payload,
+ backend_ctx = llama_backend.context_length,
+ is_vision = llama_backend.is_vision,
)
_cancel_keys = (payload.cancel_id, payload.session_id, completion_id)
@@ -5595,7 +5780,9 @@ async def _openai_passthrough_non_streaming(
"""
target_url = f"{llama_backend.base_url}/v1/chat/completions"
body = _build_openai_passthrough_body(
- payload, backend_ctx = llama_backend.context_length
+ payload,
+ backend_ctx = llama_backend.context_length,
+ is_vision = llama_backend.is_vision,
)
try:
@@ -5657,3 +5844,952 @@ async def _openai_passthrough_non_streaming(
# verbatim (matches the docstring). Status is guaranteed 200 by
# the check above.
return Response(content = resp.content, media_type = "application/json")
+
+
+# ---------------------------------------------------------------------- #
+# Chat document extraction (PyMuPDF4LLM + optional VLM image description)#
+# ---------------------------------------------------------------------- #
+
+try:
+ from core.chat import (
+ DOCUMENT_EXTRACTION_AVAILABLE as _DOCUMENT_EXTRACTION_AVAILABLE,
+ DEFAULT_DOCUMENT_VISUAL_PAYLOADS as _DEFAULT_DOCUMENT_VISUAL_PAYLOADS,
+ DocumentExtractionBusy as _DocumentExtractionBusy,
+ DocumentExtractionCancelled as _DocumentExtractionCancelled,
+ DocumentExtractionEncrypted as _DocumentExtractionEncrypted,
+ DocumentExtractionTimeout as _DocumentExtractionTimeout,
+ DocumentExtractionUnavailable as _DocumentExtractionUnavailable,
+ _EXTRACT_CONCURRENCY as _DOCUMENT_EXTRACT_CONCURRENCY,
+ MAX_DOCUMENT_VISUAL_PAYLOADS as _MAX_DOCUMENT_VISUAL_PAYLOADS,
+ SUPPORTED_MIME_TYPES as _DOC_MIME_OK,
+ SUPPORTED_SUFFIXES as _DOC_SUFFIX_OK,
+ VlmCapability as _VlmCapability,
+ _EXTRACT_SEMAPHORE,
+ _drain_future_exception as _drain_doc_future_exception,
+ detect_loaded_vlm as _detect_loaded_vlm,
+ document_parser_support as _document_parser_support,
+ document_parser_unavailable_reasons as _document_parser_unavailable_reasons,
+ extract_document as _extract_document,
+ extract_self_base_url as _extract_self_base_url,
+ )
+except ImportError: # pragma: no cover - package always installed alongside
+ _DOCUMENT_EXTRACTION_AVAILABLE = False
+ _DEFAULT_DOCUMENT_VISUAL_PAYLOADS = 0
+ _DOCUMENT_EXTRACT_CONCURRENCY = 1
+ _MAX_DOCUMENT_VISUAL_PAYLOADS = 0
+ _DOC_MIME_OK = frozenset()
+ _DOC_SUFFIX_OK = frozenset()
+ _detect_loaded_vlm = None # type: ignore[assignment]
+ _extract_document = None # type: ignore[assignment]
+ _extract_self_base_url = None # type: ignore[assignment]
+ _document_parser_support = lambda: {} # type: ignore[assignment]
+ _document_parser_unavailable_reasons = lambda: {} # type: ignore[assignment]
+ _VlmCapability = None # type: ignore[assignment]
+ _drain_doc_future_exception = lambda _f: None # type: ignore[assignment]
+
+ class _DocumentExtractionUnavailable(RuntimeError): # type: ignore[no-redef]
+ pass
+
+ class _DocumentExtractionTimeout(RuntimeError): # type: ignore[no-redef]
+ pass
+
+ class _DocumentExtractionBusy(RuntimeError): # type: ignore[no-redef]
+ pass
+
+ class _DocumentExtractionCancelled(RuntimeError): # type: ignore[no-redef]
+ pass
+
+ class _DocumentExtractionEncrypted(RuntimeError): # type: ignore[no-redef]
+ pass
+
+ _EXTRACT_SEMAPHORE = threading.BoundedSemaphore(1)
+
+
+# Upper bound on the uploaded file itself (100 MiB).
+_EXTRACT_MAX_BYTES = 100 * 1024 * 1024
+# Extra 1 MiB added on top of the file cap when limiting the whole request
+# body, leaving room for multipart framing and the other form fields.
+_EXTRACT_MULTIPART_OVERHEAD_BYTES = 1024 * 1024
+# Chunk size used when draining the upload object.
+_EXTRACT_READ_CHUNK_BYTES = 64 * 1024
+# Documents with more pages than this are rejected with HTTP 413.
+_EXTRACT_MAX_PAGES_INLINE = 200
+# Default for the ``token_budget`` form field.
+_EXTRACT_TOKEN_BUDGET_DEFAULT = 8000
+# Floor applied to the token budget before it is converted to characters.
+_EXTRACT_TOKEN_BUDGET_MIN = 0
+
+# MIME types and filename suffixes used to bucket an upload into a parser
+# format key ("docx", "html", "data", "code", ...).
+_DOCX_MIME = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
+_HTML_MIME_TYPES = {"text/html"}
+_DATA_MIME_TYPES = {
+    "application/json",
+    "application/x-ndjson",
+    "application/xml",
+    "application/yaml",
+    "text/csv",
+    "text/xml",
+    "text/yaml",
+}
+_CODE_MIME_TYPES = {
+    "application/javascript",
+    "text/css",
+    "text/javascript",
+}
+_DATA_SUFFIXES = {".csv", ".json", ".jsonl", ".yaml", ".yml", ".xml"}
+_CODE_SUFFIXES = {
+    ".py",
+    ".js",
+    ".jsx",
+    ".ts",
+    ".tsx",
+    ".go",
+    ".rs",
+    ".java",
+    ".c",
+    ".cpp",
+    ".h",
+    ".hpp",
+    ".cs",
+    ".php",
+    ".rb",
+    ".swift",
+    ".kt",
+    ".kts",
+    ".scala",
+    ".sh",
+    ".bash",
+    ".zsh",
+    ".ps1",
+    ".sql",
+    ".toml",
+    ".ini",
+    ".cfg",
+    ".css",
+    ".scss",
+}
+
+
+async def _wait_for_document_request_disconnect(
+    fastapi_request: Request,
+    cancel_event: threading.Event,
+) -> bool:
+    """Poll the request until the client disconnects or ``cancel_event`` is set.
+
+    Returns ``True`` (after setting ``cancel_event``) when a disconnect was
+    observed, and ``False`` when the event was already set by someone else.
+    """
+    while True:
+        if cancel_event.is_set():
+            return False
+        client_gone = await fastapi_request.is_disconnected()
+        if client_gone:
+            cancel_event.set()
+            return True
+        # Poll roughly five times per second.
+        await asyncio.sleep(0.2)
+
+
+def _extract_ext(filename: str) -> str:
+    """Return the lower-cased extension of ``filename`` (with the dot), or ""."""
+    _stem, suffix = os.path.splitext(filename or "")
+    return suffix.lower()
+
+
+def _is_supported_upload(filename: str, content_type: str) -> bool:
+    """True when either the MIME type or the filename suffix is accepted.
+
+    Any ``;``-separated parameters on the content type (such as a charset)
+    are ignored.
+    """
+    mime = (content_type or "").split(";")[0].strip().lower()
+    return mime in _DOC_MIME_OK or _extract_ext(filename) in _DOC_SUFFIX_OK
+
+
+def _document_upload_format(filename: str, content_type: str) -> Optional[str]:
+    """Map an upload to a parser format key, or ``None`` when unrecognised.
+
+    Rules are tried in order and the first match wins, so for example
+    ``text/csv`` resolves to "data" before the generic ``text/*`` rule.
+    """
+    mime = (content_type or "").split(";")[0].strip().lower()
+    ext = _extract_ext(filename)
+    ordered_rules = (
+        ("pdf", {"application/pdf"}, {".pdf"}),
+        ("docx", {_DOCX_MIME}, {".docx"}),
+        ("html", _HTML_MIME_TYPES, {".html", ".htm"}),
+        ("data", _DATA_MIME_TYPES, _DATA_SUFFIXES),
+        ("code", _CODE_MIME_TYPES, _CODE_SUFFIXES),
+    )
+    for format_key, mimes, suffixes in ordered_rules:
+        if mime in mimes or ext in suffixes:
+            return format_key
+    # Generic text is the last resort.
+    if mime.startswith("text/") or ext in {".md", ".txt", ".log"}:
+        return "text"
+    return None
+
+
+def _raise_if_document_parser_unavailable(
+    filename: str,
+    content_type: str,
+) -> None:
+    """Raise HTTP 501 when the parser for this upload's format is reported missing.
+
+    Uploads with an unrecognised format, or formats absent from the support
+    map, are allowed through.
+    """
+    format_key = _document_upload_format(filename, content_type)
+    if format_key is None:
+        return
+    if _document_parser_support().get(format_key, True):
+        return
+    fallback_reason = f"{format_key.upper()} extraction is not available on this server."
+    reasons = _document_parser_unavailable_reasons()
+    raise HTTPException(
+        status_code = 501,
+        detail = reasons.get(format_key, fallback_reason),
+    )
+
+
+def _document_caption_authorization_header(
+    capability: Any,
+    llama_backend: Any,
+    studio_authorization_header: Optional[str],
+) -> Optional[str]:
+    """Pick the Authorization header to use for caption requests.
+
+    When the capability's ``source`` is "gguf" the header is built from the
+    backend's ``api_key`` (or ``_api_key``), or is ``None`` when neither is
+    set. For every other source the caller's Studio header is passed through.
+    """
+    if getattr(capability, "source", None) == "gguf":
+        api_key = getattr(llama_backend, "api_key", None)
+        if not api_key:
+            api_key = getattr(llama_backend, "_api_key", None)
+        if api_key:
+            return f"Bearer {api_key}"
+        return None
+    return studio_authorization_header
+
+
+# Accepted spellings (compared after strip + lower-casing) for boolean form fields.
+_FORM_TRUE = {"1", "true", "yes", "on"}
+_FORM_FALSE = {"0", "false", "no", "off"}
+
+
+def _parse_bool_form(value: Any, *, default: bool, field: str = "value") -> bool:
+    """Interpret a form field as a boolean.
+
+    ``None`` and blank values yield ``default``; anything that is not one of
+    the recognised spellings raises HTTP 400 naming ``field``.
+    """
+    token = "" if value is None else str(value).strip().lower()
+    if not token:
+        return default
+    for spellings, outcome in ((_FORM_TRUE, True), (_FORM_FALSE, False)):
+        if token in spellings:
+            return outcome
+    raise HTTPException(
+        status_code = 400,
+        detail = f"Invalid boolean value for {field}: {value!r}",
+    )
+
+
+def _parse_int_form(
+    value: Any,
+    *,
+    default: int,
+    lo: int,
+    hi: Optional[int] = None,
+) -> int:
+    """Parse a form field as an int and clamp it to ``[lo, hi]``.
+
+    Missing or unparsable values fall back to ``default`` (which is clamped
+    too); ``hi = None`` means no upper bound.
+    """
+    parsed = default
+    if value is not None:
+        try:
+            parsed = int(value)
+        except (TypeError, ValueError):
+            parsed = default
+    floored = max(lo, parsed)
+    if hi is None:
+        return floored
+    return min(floored, hi)
+
+
+def _reject_oversized_content_length(request: Request) -> None:
+    """Fail fast on a declared Content-Length that is invalid or too large.
+
+    This is only an early rejection based on the header; a request without
+    the header passes through here unchecked.
+
+    Raises:
+        HTTPException: 400 when the header is not a non-negative integer,
+            413 when it exceeds the file cap plus the multipart overhead.
+    """
+    raw = request.headers.get("content-length")
+    if raw is None:
+        return
+    try:
+        total = int(raw)
+    except ValueError as exc:
+        raise HTTPException(
+            status_code = 400,
+            detail = "Invalid Content-Length header",
+        ) from exc
+    # A negative length is as malformed as a non-numeric one; without this
+    # check it would pass the size comparison below.
+    if total < 0:
+        raise HTTPException(
+            status_code = 400,
+            detail = "Invalid Content-Length header",
+        )
+    max_request_bytes = _EXTRACT_MAX_BYTES + _EXTRACT_MULTIPART_OVERHEAD_BYTES
+    if total > max_request_bytes:
+        raise HTTPException(
+            status_code = 413,
+            detail = (
+                f"Request exceeds the {_EXTRACT_MAX_BYTES // (1024*1024)} MB "
+                "file limit"
+            ),
+        )
+
+
+async def _iter_request_body_limited(request: Request, *, max_bytes: int):
+    """Yield the request body chunk by chunk, aborting past ``max_bytes``.
+
+    Empty chunks are skipped. Raises HTTP 413 as soon as the running total
+    exceeds the cap, before the offending chunk is yielded.
+    """
+    consumed = 0
+    async for chunk in request.stream():
+        if chunk:
+            consumed += len(chunk)
+            if consumed > max_bytes:
+                raise HTTPException(
+                    status_code = 413,
+                    detail = (
+                        f"Request exceeds the {_EXTRACT_MAX_BYTES // (1024*1024)} MB "
+                        "file limit"
+                    ),
+                )
+            yield chunk
+
+
+async def _read_multipart_form_limited(request: Request, *, max_bytes: int):
+    """Parse a multipart form from a size-limited view of the request body.
+
+    HTTP errors raised by the limiter propagate unchanged; multipart
+    parsing errors are converted to HTTP 400 carrying the parser's message.
+    """
+    from starlette.formparsers import MultiPartException, MultiPartParser
+
+    try:
+        limited_body = _iter_request_body_limited(request, max_bytes = max_bytes)
+        form_parser = MultiPartParser(request.headers, limited_body)
+        parsed_form = await form_parser.parse()
+        return parsed_form
+    except HTTPException:
+        raise
+    except MultiPartException as exc:
+        raise HTTPException(status_code = 400, detail = exc.message) from exc
+
+
+# Cap on /completions and /embeddings JSON bodies. Those proxy payloads should
+# be small (a few prompts + sampling params); 10 MB is generous headroom while
+# still protecting against unbounded buffering when a client sends a falsified
+# Content-Length and streams a much larger body.
+_OPENAI_PROXY_BODY_MAX_BYTES = 10 * 1024 * 1024
+# Chat-completions also carries multimodal data URLs. Keep it bounded, but
+# large enough that document extraction's visual-payload budget reaches the
+# existing per-image guards instead of being rejected by the JSON body reader
+# first.
+# Number of images the body cap budgets for: the smaller of the chat image
+# limit and the document visual-payload limit (max, else default, else 1),
+# and never fewer than one.
+_OPENAI_CHAT_BODY_IMAGE_SLOTS = max(
+    1,
+    min(
+        _OPENAI_CHAT_MAX_IMAGES,
+        _MAX_DOCUMENT_VISUAL_PAYLOADS or _DEFAULT_DOCUMENT_VISUAL_PAYLOADS or 1,
+    ),
+)
+# At least 32 MiB; otherwise one maximum-size base64 image per slot plus
+# 2 MiB for the rest of the JSON payload.
+_OPENAI_CHAT_BODY_MAX_BYTES = max(
+    32 * 1024 * 1024,
+    (_OPENAI_CHAT_MAX_IMAGE_BASE64_CHARS * _OPENAI_CHAT_BODY_IMAGE_SLOTS)
+    + (2 * 1024 * 1024),
+)
+
+
+async def _read_json_body_limited(request: Request, *, max_bytes: int) -> Any:
+    """Stream the request body, enforce a hard byte cap, then parse as JSON.
+
+    Unlike trusting Content-Length, this aborts mid-stream once the cap is
+    exceeded so a spoofed header cannot force the server to buffer arbitrary
+    payloads before parsing.
+
+    Returns:
+        The decoded JSON value, or ``{}`` when the body is empty.
+
+    Raises:
+        HTTPException: 413 when the body exceeds ``max_bytes``; 400 when the
+            body is not decodable text or not valid JSON.
+    """
+    total = 0
+    chunks: list[bytes] = []
+    async for chunk in request.stream():
+        if not chunk:
+            continue
+        total += len(chunk)
+        if total > max_bytes:
+            raise HTTPException(
+                status_code = 413,
+                detail = f"Request body exceeds the {max_bytes // (1024 * 1024)} MB limit",
+            )
+        chunks.append(chunk)
+    raw = b"".join(chunks)
+    if not raw:
+        return {}
+    try:
+        return json.loads(raw)
+    except json.JSONDecodeError as exc:
+        raise HTTPException(
+            status_code = 400, detail = f"Invalid JSON body: {exc.msg}"
+        ) from exc
+    except UnicodeDecodeError as exc:
+        # json.loads decodes bytes itself; undecodable input raises
+        # UnicodeDecodeError, which is not a JSONDecodeError and would
+        # otherwise surface as a 500.
+        raise HTTPException(
+            status_code = 400,
+            detail = "Invalid JSON body: request body is not valid Unicode text",
+        ) from exc
+
+
+async def _read_upload_limited(upload: Any, *, max_bytes: int) -> bytes:
+    """Read ``upload`` fully into memory, raising HTTP 413 past ``max_bytes``.
+
+    The upload is drained in ``_EXTRACT_READ_CHUNK_BYTES`` pieces and the
+    size check runs after each piece is appended.
+    """
+    received = bytearray()
+    piece = await upload.read(_EXTRACT_READ_CHUNK_BYTES)
+    while piece:
+        received += piece
+        if len(received) > max_bytes:
+            raise HTTPException(
+                status_code = 413,
+                detail = f"File exceeds the {max_bytes // (1024*1024)} MB limit",
+            )
+        piece = await upload.read(_EXTRACT_READ_CHUNK_BYTES)
+    return bytes(received)
+
+
+def _is_pdf_upload(filename: str, content_type: str) -> bool:
+    """True when the MIME type is ``application/pdf`` or the suffix is ``.pdf``."""
+    declared_mime = (content_type or "").split(";")[0].strip().lower()
+    if declared_mime == "application/pdf":
+        return True
+    return _extract_ext(filename) == ".pdf"
+
+
+def _preflight_pdf_page_count(
+    file_bytes: bytes,
+    filename: str,
+    content_type: str,
+) -> Optional[int]:
+    """Count the pages of a PDF upload before running the full extraction.
+
+    pypdf is tried first; if it fails for any reason (including not being
+    importable) PyMuPDF is used as a fallback.
+
+    Returns:
+        The page count, or ``None`` when the upload is not a PDF.
+
+    Raises:
+        HTTPException: 422 when the PDF needs a real password; 400 when
+            neither library can read the page count.
+    """
+    if not _is_pdf_upload(filename, content_type):
+        return None
+
+    encrypted_detail = "Encrypted PDFs are not supported for inline extraction"
+
+    try:
+        from pypdf import PdfReader
+
+        reader = PdfReader(io.BytesIO(file_bytes), strict = False)
+        # Many PDFs report ``is_encrypted=True`` even though they only use a
+        # null/empty user password and open fine (Acrobat-distilled docs,
+        # the classic Orimi test PDF, scanner output). Try the empty
+        # password before refusing; PyMuPDF's ``needs_pass`` is the real
+        # signal in the fallback branch below.
+        if getattr(reader, "is_encrypted", False):
+            try:
+                opened_with_empty_password = reader.decrypt("") != 0
+            except Exception as exc:
+                # ``decrypt`` itself failed (corrupt /Encrypt dict, unknown
+                # algorithm). Fall through to the PyMuPDF fallback rather
+                # than declaring the file encrypted.
+                raise RuntimeError("pypdf decrypt probe failed") from exc
+            if not opened_with_empty_password:
+                raise HTTPException(status_code = 422, detail = encrypted_detail)
+        return len(reader.pages)
+    except HTTPException:
+        raise
+    except Exception as exc:
+        # Keep the failure reason in the log; it is the only trace of why
+        # the primary path was abandoned.
+        logger.warning(
+            "pypdf page-count preflight failed for %s (%s: %s); trying PyMuPDF fallback",
+            filename,
+            type(exc).__name__,
+            exc,
+        )
+
+    # Reaching this point means the pypdf path failed and was logged above.
+    try:
+        import pymupdf as _pymupdf  # type: ignore
+
+        doc = _pymupdf.open(stream = file_bytes, filetype = "pdf")
+        try:
+            # PyMuPDF's ``needs_pass`` is True only when an actual password
+            # is required. ``is_encrypted`` is True for any file with an
+            # /Encrypt dict, which includes the common null-password case
+            # that opens fine. Refuse only when a password is actually
+            # needed.
+            if getattr(doc, "needs_pass", False):
+                raise HTTPException(status_code = 422, detail = encrypted_detail)
+            return len(doc)
+        finally:
+            doc.close()
+    except HTTPException:
+        raise
+    except Exception as exc:
+        logger.warning(
+            "PyMuPDF page-count fallback also failed for %s: %s",
+            filename,
+            exc,
+        )
+        raise HTTPException(
+            status_code = 400,
+            detail = "Unable to read PDF page count before extraction",
+        ) from exc
+
+
+def _truncate_markdown_to_token_budget(
+    markdown: str,
+    *,
+    token_budget: int,
+    original_tokens_est: int,
+) -> tuple[str, int, Optional[str]]:
+    """Clip ``markdown`` to roughly ``token_budget`` tokens (4 chars per token).
+
+    Returns ``(text, tokens_est, warning)``. When nothing is clipped the
+    input and ``original_tokens_est`` are returned with a ``None`` warning.
+    Otherwise the text is cut at the character budget, trimmed back to the
+    last whitespace boundary where possible, and a truncation marker is
+    appended; the returned estimate is derived from the clipped text
+    including that marker.
+    """
+    char_limit = 4 * max(_EXTRACT_TOKEN_BUDGET_MIN, token_budget)
+    if len(markdown) <= char_limit:
+        return markdown, original_tokens_est, None
+
+    head = markdown[:char_limit]
+    # Drop the trailing partial word; keep the raw cut if that leaves nothing.
+    trimmed = _re.sub(r"\s+\S*$", "", head).rstrip()
+    if not trimmed:
+        trimmed = head.rstrip()
+    marker = f"\n\n[... truncated; original was ~{original_tokens_est} tokens ...]"
+    clipped = trimmed + marker
+    warning = (
+        f"Extracted markdown was truncated to {token_budget} tokens "
+        f"(original was ~{original_tokens_est} tokens)."
+    )
+    return clipped, max(0, len(clipped) // 4), warning
+
+
+@studio_router.get("/chat/document-support", response_model = DocumentSupportResponse)
+async def document_support_endpoint(
+    fastapi_request: Request,
+    current_subject: str = Depends(get_current_subject),
+):
+    """Whether document extraction + per-figure captions are available.
+
+    Polled by the frontend when the settings panel mounts and when the
+    loaded model changes. The response drives the "describe figures"
+    toggle: when ``vlm.is_vlm`` is false the UI disables the toggle and
+    surfaces ``vlm.reason`` as tooltip text.
+    """
+
+    def _absent_vlm(reason: str) -> dict:
+        # Payload used whenever no usable VLM capability object exists.
+        return {
+            "is_vlm": False,
+            "endpoint_url": None,
+            "model_name": None,
+            "source": "none",
+            "reason": reason,
+        }
+
+    backend_missing = _extract_document is None or _detect_loaded_vlm is None
+    if backend_missing:
+        return DocumentSupportResponse(
+            extraction_available = False,
+            max_visual_payloads = 0,
+            max_extract_concurrency = 1,
+            format_support = {},
+            unavailable_formats = {},
+            vlm = _absent_vlm("document extraction backend is not installed"),
+        )
+
+    self_base_url = None
+    if _extract_self_base_url:
+        self_base_url = _extract_self_base_url(fastapi_request)
+
+    try:
+        cap = _detect_loaded_vlm(
+            self_base_url,
+            llama_backend = get_llama_cpp_backend(),
+        )
+    except Exception as exc:
+        # A failed probe is reported as "no VLM" instead of failing the call.
+        logger.exception("Document support VLM probe failed")
+        cap = None
+        if _VlmCapability is not None:
+            cap = _VlmCapability.none(
+                f"document support probe failed: {type(exc).__name__}"
+            )
+
+    format_support = _document_parser_support()
+    unavailable_formats = _document_parser_unavailable_reasons()
+    if cap is not None:
+        vlm_payload = cap.to_dict()
+    else:
+        vlm_payload = _absent_vlm("document support probe failed")
+    return DocumentSupportResponse(
+        extraction_available = True,
+        max_visual_payloads = _MAX_DOCUMENT_VISUAL_PAYLOADS,
+        max_extract_concurrency = _DOCUMENT_EXTRACT_CONCURRENCY,
+        format_support = format_support,
+        unavailable_formats = unavailable_formats,
+        vlm = vlm_payload,
+    )
+
+
+@studio_router.post("/chat/extract-document")
+async def extract_document_endpoint(
+    fastapi_request: Request,
+    current_subject: str = Depends(get_current_subject),
+):
+    """Upload a PDF / DOCX / HTML / MD / text file and stream
+    progress events plus a final layout-aware Markdown payload.
+
+    Response is NDJSON (one JSON object per line) when the client sends
+    ``Accept: application/x-ndjson``; otherwise a single JSON document is
+    returned. Validation errors raised before streaming begins return as
+    standard HTTP 4xx/5xx. Once the stream starts, the final line is
+    `{"stage":"result", ...}` or `{"stage":"error", ...}`. Large documents
+    (>200 pages) are rejected with 413 until the background-job path lands.
+
+    Form fields read: ``file`` (required), ``describe_images``,
+    ``use_vlm_ocr``, ``max_figures``, ``max_visual_payloads``,
+    ``token_budget``.
+    """
+    if _extract_document is None:
+        raise HTTPException(
+            status_code = 501,
+            detail = (
+                "document extraction backend is not installed. Re-run Studio "
+                "setup to install the parser dependencies."
+            ),
+        )
+
+    _reject_oversized_content_length(fastapi_request)
+
+    def _page_cap_detail(page_count: int) -> str:
+        # Single source for the "too many pages" message (preflight, JSON
+        # path and NDJSON path all report it).
+        return (
+            f"Document has {page_count} pages; inline extraction "
+            f"is capped at {_EXTRACT_MAX_PAGES_INLINE}. Split into smaller "
+            f"documents or reduce the page range."
+        )
+
+    def _error_line(status_code: int, detail: str) -> str:
+        # One NDJSON error event, newline-terminated.
+        return (
+            json.dumps(
+                {
+                    "stage": "error",
+                    "status_code": status_code,
+                    "detail": detail,
+                }
+            )
+            + "\n"
+        )
+
+    async def _stop_extraction_after_disconnect(
+        extraction_task: asyncio.Task,
+        cancel_event: threading.Event,
+    ) -> None:
+        # Signal cooperative cancellation, give the worker up to 10 s to
+        # wind down, then cancel the task outright if it is still running.
+        cancel_event.set()
+        with suppress(
+            _DocumentExtractionCancelled,
+            asyncio.CancelledError,
+            asyncio.TimeoutError,
+        ):
+            await asyncio.wait_for(asyncio.shield(extraction_task), timeout = 10)
+        if not extraction_task.done():
+            extraction_task.cancel()
+
+    # Bound before the ``try`` so the ``finally`` can tell whether parsing
+    # produced a form that needs closing.
+    form = None
+    try:
+        try:
+            form = await _read_multipart_form_limited(
+                fastapi_request,
+                max_bytes = _EXTRACT_MAX_BYTES + _EXTRACT_MULTIPART_OVERHEAD_BYTES,
+            )
+        except HTTPException:
+            raise
+        except Exception as exc:
+            logger.exception("Invalid multipart document extraction payload")
+            raise HTTPException(
+                status_code = 400, detail = "Invalid multipart payload"
+            ) from exc
+
+        upload = form.get("file")
+        if upload is None or not hasattr(upload, "read"):
+            raise HTTPException(status_code = 400, detail = "Missing 'file' field")
+
+        filename = getattr(upload, "filename", None) or "upload"
+        content_type = getattr(upload, "content_type", "") or ""
+        if not _is_supported_upload(filename, content_type):
+            raise HTTPException(
+                status_code = 415,
+                detail = f"Unsupported file type: {filename} ({content_type})",
+            )
+        _raise_if_document_parser_unavailable(filename, content_type)
+
+        file_bytes = await _read_upload_limited(upload, max_bytes = _EXTRACT_MAX_BYTES)
+        if not file_bytes:
+            raise HTTPException(status_code = 400, detail = "Uploaded file is empty")
+
+        preflight_page_count = _preflight_pdf_page_count(
+            file_bytes, filename, content_type
+        )
+        if (
+            preflight_page_count is not None
+            and preflight_page_count > _EXTRACT_MAX_PAGES_INLINE
+        ):
+            raise HTTPException(
+                status_code = 413,
+                detail = _page_cap_detail(preflight_page_count),
+            )
+
+        describe_images = _parse_bool_form(
+            form.get("describe_images"), default = False, field = "describe_images"
+        )
+        use_vlm_ocr = _parse_bool_form(
+            form.get("use_vlm_ocr"), default = False, field = "use_vlm_ocr"
+        )
+        max_figures = _parse_int_form(
+            form.get("max_figures"),
+            default = 40,
+            lo = 0,
+        )
+        max_visual_payloads = _parse_int_form(
+            form.get("max_visual_payloads"),
+            default = _DEFAULT_DOCUMENT_VISUAL_PAYLOADS,
+            lo = 0,
+        )
+        token_budget = _parse_int_form(
+            form.get("token_budget"),
+            default = _EXTRACT_TOKEN_BUDGET_DEFAULT,
+            lo = 0,
+        )
+
+        self_base_url = (
+            _extract_self_base_url(fastapi_request) if _extract_self_base_url else None
+        )
+        llama_backend = get_llama_cpp_backend()
+        capability = (
+            _detect_loaded_vlm(
+                self_base_url,
+                llama_backend = llama_backend,
+            )
+            if _detect_loaded_vlm
+            else None
+        )
+        caption_authorization_header = _document_caption_authorization_header(
+            capability,
+            llama_backend,
+            fastapi_request.headers.get("authorization"),
+        )
+
+        if await fastapi_request.is_disconnected():
+            raise HTTPException(status_code = 499, detail = "Client closed request")
+
+        accept_header = (fastapi_request.headers.get("accept", "") or "").lower()
+        wants_stream = "application/x-ndjson" in accept_header
+
+        # Keyword arguments shared by both response paths; each path adds
+        # its own ``cancel_event`` (and ``progress_cb`` when streaming).
+        extract_kwargs = dict(
+            content_type = content_type,
+            describe_images = describe_images,
+            use_vlm_ocr = use_vlm_ocr,
+            max_figures = max_figures,
+            max_visual_payloads = max_visual_payloads,
+            capability = capability,
+            self_base_url = self_base_url,
+            authorization_header = caption_authorization_header,
+        )
+
+        def _classify_extraction_failure(exc: BaseException) -> tuple[int, str]:
+            # Map an extraction failure to ``(status_code, detail)``. Checks
+            # run in the same order for the JSON and NDJSON paths. Must be
+            # called from inside the ``except`` block so the traceback is
+            # available to ``logger.exception``.
+            if isinstance(exc, _DocumentExtractionUnavailable):
+                return 501, str(exc)
+            if isinstance(exc, _DocumentExtractionTimeout):
+                return (
+                    504,
+                    "Document parsing timed out after 120s before image captioning",
+                )
+            if isinstance(exc, _DocumentExtractionBusy):
+                return 503, "Document extraction is busy"
+            if isinstance(exc, _DocumentExtractionCancelled):
+                return 499, "Client closed request"
+            if isinstance(exc, _DocumentExtractionEncrypted):
+                return 422, str(exc)
+            if isinstance(exc, ValueError):
+                detail = str(exc)
+                status_code = (
+                    415 if detail.lower().startswith("unsupported file type") else 400
+                )
+                return status_code, detail
+            logger.exception("Document extraction failed for %s", filename)
+            return 500, "Extraction failed"
+
+        def _build_response_payload(result: Any) -> ExtractDocumentResponse:
+            # Apply the caller's token budget and assemble the response model.
+            markdown_, tokens_est_, truncate_warning_ = (
+                _truncate_markdown_to_token_budget(
+                    result.markdown,
+                    token_budget = token_budget,
+                    original_tokens_est = result.tokens_est,
+                )
+            )
+            warnings_ = list(result.warnings)
+            if truncate_warning_:
+                warnings_.append(truncate_warning_)
+            return ExtractDocumentResponse(
+                filename = filename,
+                markdown = markdown_,
+                page_count = result.page_count,
+                tokens_est = tokens_est_,
+                truncated = truncate_warning_ is not None,
+                figures = [ExtractedFigureModel(**_asdict(f)) for f in result.figures],
+                describe_skipped_reason = result.describe_skipped_reason,
+                vlm_source = result.vlm_source,
+                vlm_model = result.vlm_model,
+                image_input_available = getattr(result, "image_input_available", False),
+                warnings = warnings_,
+            )
+
+        if not wants_stream:
+            # ---- Legacy JSON path (no progress events) -----------------
+            cancel_event = threading.Event()
+            extraction_task = asyncio.create_task(
+                _extract_document(
+                    file_bytes,
+                    filename,
+                    cancel_event = cancel_event,
+                    **extract_kwargs,
+                )
+            )
+            disconnect_task = asyncio.create_task(
+                _wait_for_document_request_disconnect(fastapi_request, cancel_event)
+            )
+            try:
+                done, _pending = await asyncio.wait(
+                    {extraction_task, disconnect_task},
+                    return_when = asyncio.FIRST_COMPLETED,
+                )
+                if extraction_task in done:
+                    result = await extraction_task
+                elif disconnect_task in done and disconnect_task.result():
+                    await _stop_extraction_after_disconnect(
+                        extraction_task, cancel_event
+                    )
+                    raise _DocumentExtractionCancelled(
+                        "document extraction was cancelled"
+                    )
+                else:
+                    result = await extraction_task
+            except Exception as exc:
+                status_code, detail = _classify_extraction_failure(exc)
+                raise HTTPException(
+                    status_code = status_code, detail = detail
+                ) from exc
+            finally:
+                cancel_event.set()
+                disconnect_task.cancel()
+                with suppress(asyncio.CancelledError):
+                    await disconnect_task
+
+            if result.page_count > _EXTRACT_MAX_PAGES_INLINE:
+                raise HTTPException(
+                    status_code = 413,
+                    detail = _page_cap_detail(result.page_count),
+                )
+            return _build_response_payload(result)
+
+        # ---- Streaming NDJSON path (Accept: application/x-ndjson) ------
+        progress_queue: asyncio.Queue = asyncio.Queue()
+
+        async def _progress_cb(event: dict) -> None:
+            # Copy the event so later mutation by the producer cannot change
+            # what gets serialised.
+            await progress_queue.put(dict(event))
+
+        async def _ndjson_stream():
+            cancel_event = threading.Event()
+            extraction_task = asyncio.create_task(
+                _extract_document(
+                    file_bytes,
+                    filename,
+                    cancel_event = cancel_event,
+                    progress_cb = _progress_cb,
+                    **extract_kwargs,
+                )
+            )
+            # Always drain the task's exception so a busy/cancel race
+            # doesn't leave an orphan "Future exception was never retrieved"
+            # in the logs when the body iterator exits early.
+            extraction_task.add_done_callback(_drain_doc_future_exception)
+            disconnect_task = asyncio.create_task(
+                _wait_for_document_request_disconnect(fastapi_request, cancel_event)
+            )
+            try:
+                extract_wait = asyncio.ensure_future(asyncio.shield(extraction_task))
+                extract_wait.add_done_callback(_drain_doc_future_exception)
+                while True:
+                    queue_get = asyncio.ensure_future(progress_queue.get())
+                    queue_get.add_done_callback(_drain_doc_future_exception)
+                    done, _pending = await asyncio.wait(
+                        {queue_get, extract_wait, disconnect_task},
+                        return_when = asyncio.FIRST_COMPLETED,
+                    )
+                    if queue_get in done:
+                        event = queue_get.result()
+                        yield json.dumps(event) + "\n"
+                    else:
+                        queue_get.cancel()
+                        with suppress(asyncio.CancelledError):
+                            await queue_get
+
+                    if disconnect_task in done and disconnect_task.result():
+                        await _stop_extraction_after_disconnect(
+                            extraction_task, cancel_event
+                        )
+                        raise _DocumentExtractionCancelled(
+                            "document extraction was cancelled"
+                        )
+
+                    # The shield-wrapper may complete (cancelled) before
+                    # the underlying extraction_task is done; calling
+                    # ``.result()`` in that window raises
+                    # InvalidStateError. Wait for the real task before
+                    # consuming its result.
+                    if extraction_task.done():
+                        # Drain any remaining progress events before result.
+                        while not progress_queue.empty():
+                            try:
+                                event = progress_queue.get_nowait()
+                            except asyncio.QueueEmpty:
+                                break
+                            yield json.dumps(event) + "\n"
+                        result = extraction_task.result()
+                        break
+                    if extract_wait in done:
+                        # Shield-wrapper finished but the real task is
+                        # still running. Re-arm the wait on a fresh
+                        # shielded future and loop.
+                        extract_wait = asyncio.ensure_future(
+                            asyncio.shield(extraction_task)
+                        )
+                        extract_wait.add_done_callback(
+                            _drain_doc_future_exception
+                        )
+
+                if result.page_count > _EXTRACT_MAX_PAGES_INLINE:
+                    yield _error_line(413, _page_cap_detail(result.page_count))
+                    return
+
+                response = _build_response_payload(result)
+                yield (
+                    json.dumps(
+                        {
+                            "stage": "result",
+                            "data": response.model_dump(mode = "json"),
+                        }
+                    )
+                    + "\n"
+                )
+            except Exception as exc:
+                # Headers are already sent, so failures are reported as a
+                # final in-stream error event instead of an HTTP status.
+                status_code, detail = _classify_extraction_failure(exc)
+                yield _error_line(status_code, detail)
+            finally:
+                cancel_event.set()
+                disconnect_task.cancel()
+                with suppress(asyncio.CancelledError):
+                    await disconnect_task
+
+        return StreamingResponse(
+            _ndjson_stream(),
+            media_type = "application/x-ndjson",
+        )
+    finally:
+        # _EXTRACT_SEMAPHORE is owned solely by _run_extract_process_sync; the
+        # worker maps a busy semaphore to DocumentExtractionBusy → an in-stream
+        # error event above.
+        #
+        # Close the parsed form so the spooled upload file is released
+        # deterministically. Both response paths work from ``file_bytes``,
+        # which is already fully read, so the upload is no longer needed.
+        if form is not None:
+            with suppress(Exception):
+                await form.close()
diff --git a/studio/backend/routes/models.py b/studio/backend/routes/models.py
index 9ea113e488..826da462cf 100644
--- a/studio/backend/routes/models.py
+++ b/studio/backend/routes/models.py
@@ -12,7 +12,8 @@
import sys
import uuid
from pathlib import Path
-from fastapi import APIRouter, Body, Depends, HTTPException, Query
+from fastapi import APIRouter, Body, Depends, HTTPException, Query, Request
+from pydantic import BaseModel, Field
from typing import List, Optional
import structlog
from loggers import get_logger
@@ -139,6 +140,24 @@ def _safe_is_dir(path) -> bool:
logger = get_logger(__name__)
+class ModelProbeRequest(BaseModel):
+    # JSON body for the POST probe endpoints; lets the HF token travel in
+    # the request body instead of the URL.
+    model_name: str = Field(
+        default = ...,
+        description = "Model identifier or local path",
+    )
+    hf_token: Optional[str] = Field(
+        default = None,
+        description = "HuggingFace token for gated/private models",
+    )
+    trust_remote_code: bool = Field(
+        default = False,
+        description = "Allow probes that require custom model code",
+    )
+
+
+def _reject_hf_token_query(request: Request) -> None:
+    """Raise HTTP 400 when ``hf_token`` appears in the query string."""
+    token_in_url = "hf_token" in request.query_params
+    if not token_in_url:
+        return
+    raise HTTPException(
+        status_code = 400,
+        detail = "HF tokens must be sent with POST JSON probe endpoints, not GET query parameters.",
+    )
+
+
def derive_model_type(
is_vision: bool, audio_type: Optional[str], is_embedding: bool = False
) -> ModelType:
@@ -152,6 +171,40 @@ def derive_model_type(
return "text"
+def _defaults_vision_flags(config_dict: dict) -> tuple[bool, bool]:
+    """Read the vision / trust-remote-code hints from a model-defaults dict.
+
+    Returns ``(is_vision, requires_trust_remote_code)``. ``is_vision`` comes
+    from ``model.is_vision``; the trust flag is set when either
+    ``model.trust_remote_code`` or ``inference.trust_remote_code`` is truthy.
+    A missing, null or non-mapping section is treated as empty, so a config
+    with a bare ``model:`` key no longer raises.
+    """
+
+    def _section(name: str) -> dict:
+        # ``.get(name, {})`` would hand back None for an explicit null.
+        section = config_dict.get(name) if isinstance(config_dict, dict) else None
+        return section if isinstance(section, dict) else {}
+
+    model_config = _section("model")
+    inference_config = _section("inference")
+    yaml_is_vision = bool(model_config.get("is_vision", False))
+    yaml_requires_trust_remote_code = bool(
+        model_config.get("trust_remote_code", False)
+        or inference_config.get("trust_remote_code", False)
+    )
+    return yaml_is_vision, yaml_requires_trust_remote_code
+
+
+def _detect_vision_for_config_endpoint(
+    model_name: str,
+    *,
+    hf_token: Optional[str] = None,
+    trust_remote_code: bool = False,
+    config_dict: Optional[dict] = None,
+) -> bool:
+    """Decide whether ``model_name`` is a vision model.
+
+    When the model defaults mark the model as both vision and requiring
+    remote code, that is accepted without probing. Otherwise the answer
+    comes from ``is_vision_model``. Defaults are loaded here only when
+    ``config_dict`` is not supplied.
+    """
+    if config_dict is None:
+        defaults = load_model_defaults(model_name)
+    else:
+        defaults = config_dict
+    yaml_is_vision, yaml_requires_trust_remote_code = _defaults_vision_flags(defaults)
+    if yaml_is_vision and yaml_requires_trust_remote_code:
+        return True
+    return is_vision_model(
+        model_name,
+        hf_token = hf_token,
+        trust_remote_code = trust_remote_code,
+    )
+
+
def _resolve_hf_cache_dir() -> Path:
"""Resolve local HF cache root used by hub downloads."""
try:
@@ -1479,7 +1532,7 @@ async def list_models(
loaded_models.append(model_info)
# Include active GGUF model (loaded via llama-server)
- from routes.inference import get_llama_cpp_backend
+ from core.inference.llama_cpp import get_llama_cpp_backend
llama_backend = get_llama_cpp_backend()
if llama_backend.is_loaded and llama_backend.model_identifier:
@@ -1562,9 +1615,35 @@ def _get_model_size_bytes(
@router.get("/config/{model_name:path}")
async def get_model_config(
+ request: Request,
model_name: str,
- hf_token: Optional[str] = Query(None),
+ trust_remote_code: bool = False,
current_subject: str = Depends(get_current_subject),
+):
+ _reject_hf_token_query(request)
+ return await _build_model_config_response(
+ model_name,
+ hf_token = None,
+ trust_remote_code = trust_remote_code,
+ )
+
+
+@router.post("/config")
+async def post_model_config(
+    request: ModelProbeRequest,
+    current_subject: str = Depends(get_current_subject),
+):
+    # POST variant of the config probe: model name, HF token and
+    # trust_remote_code all arrive in the JSON body.
+    config_response = await _build_model_config_response(
+        request.model_name,
+        hf_token = request.hf_token,
+        trust_remote_code = request.trust_remote_code,
+    )
+    return config_response
+
+
+async def _build_model_config_response(
+ model_name: str,
+ hf_token: Optional[str] = None,
+ trust_remote_code: bool = False,
):
"""
Get configuration for a specific model.
@@ -1589,7 +1668,12 @@ async def get_model_config(
config_dict = load_model_defaults(model_name)
# Detect model capabilities (pass HF token for gated models)
- is_vision = is_vision_model(model_name, hf_token = hf_token)
+ is_vision = _detect_vision_for_config_endpoint(
+ model_name,
+ hf_token = hf_token,
+ trust_remote_code = trust_remote_code,
+ config_dict = config_dict,
+ )
is_embedding = is_embedding_model(model_name, hf_token = hf_token)
audio_type = detect_audio_type(model_name, hf_token = hf_token)
@@ -1598,7 +1682,11 @@ async def get_model_config(
base_model = None
max_position_embeddings = None
try:
- model_config = ModelConfig.from_identifier(model_name)
+ model_config = ModelConfig.from_identifier(
+ model_name,
+ hf_token = hf_token,
+ trust_remote_code = trust_remote_code,
+ )
is_lora = model_config.is_lora
base_model = model_config.base_model if is_lora else None
max_position_embeddings = _get_max_position_embeddings(model_config)
@@ -2068,8 +2156,35 @@ async def get_lora_base_model(
@router.get("/check-vision/{model_name:path}", response_model = VisionCheckResponse)
async def check_vision_model(
+ request: Request,
model_name: str,
+ trust_remote_code: bool = False,
current_subject: str = Depends(get_current_subject),
+):
+ _reject_hf_token_query(request)
+ return await _check_vision_model_response(
+ model_name,
+ hf_token = None,
+ trust_remote_code = trust_remote_code,
+ )
+
+
+@router.post("/check-vision", response_model = VisionCheckResponse)
+async def post_check_vision_model(
+    request: ModelProbeRequest,
+    current_subject: str = Depends(get_current_subject),
+):
+    # POST counterpart of the GET vision check: the model name, HF token and
+    # trust_remote_code flag are all read from the request model and handed
+    # unchanged to the shared vision-check helper.
+    options = {"hf_token": request.hf_token, "trust_remote_code": request.trust_remote_code}
+    return await _check_vision_model_response(request.model_name, **options)
+
+
+async def _check_vision_model_response(
+ model_name: str,
+ hf_token: Optional[str] = None,
+ trust_remote_code: bool = False,
):
"""
Check if a model is a vision model.
@@ -2078,7 +2193,11 @@ async def check_vision_model(
"""
try:
logger.info(f"Checking if vision model: {model_name}")
- is_vision = is_vision_model(model_name)
+ is_vision = _detect_vision_for_config_endpoint(
+ model_name,
+ hf_token = hf_token,
+ trust_remote_code = trust_remote_code,
+ )
logger.info(f"Vision check result for {model_name}: is_vision={is_vision}")
return VisionCheckResponse(
@@ -2603,7 +2722,7 @@ async def delete_cached_model(
# Check if model is currently loaded
try:
- from routes.inference import get_llama_cpp_backend
+ from core.inference.llama_cpp import get_llama_cpp_backend
llama_backend = get_llama_cpp_backend()
if llama_backend.is_loaded and llama_backend.model_identifier:
diff --git a/studio/backend/run.py b/studio/backend/run.py
index 3bde8abd3c..5e5da55858 100644
--- a/studio/backend/run.py
+++ b/studio/backend/run.py
@@ -494,11 +494,15 @@ def _graceful_shutdown(server = None):
logger.warning("Error shutting down training subprocess: %s", e)
# 5. Kill llama-server subprocess (if loaded)
+ #
+ # Read the module-level singleton directly so we don't instantiate a
+ # fresh backend during shutdown when none was ever loaded.
try:
- from routes.inference import _llama_cpp_backend
+ from core.inference import llama_cpp as _llama_cpp_mod
- if _llama_cpp_backend is not None:
- _llama_cpp_backend._kill_process()
+ backend = getattr(_llama_cpp_mod, "_llama_cpp_backend", None)
+ if backend is not None:
+ backend._kill_process()
except Exception as e:
logger.warning("Error shutting down llama-server: %s", e)
diff --git a/studio/backend/tests/test_anthropic_messages.py b/studio/backend/tests/test_anthropic_messages.py
index 842429d5af..7f0cf5d56a 100644
--- a/studio/backend/tests/test_anthropic_messages.py
+++ b/studio/backend/tests/test_anthropic_messages.py
@@ -34,6 +34,7 @@
AnthropicStreamEmitter,
AnthropicPassthroughEmitter,
)
+import routes.inference as route
from routes.inference import (
_normalize_anthropic_openai_images,
_select_anthropic_server_tools,
@@ -1056,6 +1057,24 @@ def test_bad_base64_raises_400(self):
_normalize_anthropic_openai_images(msgs, is_vision = True)
assert exc.value.status_code == 400
+    def test_image_count_limit_applies(self, monkeypatch):
+        """Exceeding the per-request image cap must be rejected with HTTP 413.
+
+        The cap is patched down to 1 so two image parts are enough to trip it.
+        """
+        monkeypatch.setattr(route, "_OPENAI_CHAT_MAX_IMAGES", 1)
+        data_url = _jpeg_data_url()
+        image_parts = [
+            {"type": "image_url", "image_url": {"url": data_url}}
+            for _ in range(2)
+        ]
+        msgs = [{"role": "user", "content": image_parts}]
+
+        with pytest.raises(HTTPException) as exc:
+            _normalize_anthropic_openai_images(msgs, is_vision = True)
+
+        assert exc.value.status_code == 413
+
# =====================================================================
# Studio-tool alias detection (/v1/messages tool routing)
diff --git a/studio/backend/tests/test_chat_document_extraction.py b/studio/backend/tests/test_chat_document_extraction.py
new file mode 100644
index 0000000000..3d89883952
--- /dev/null
+++ b/studio/backend/tests/test_chat_document_extraction.py
@@ -0,0 +1,906 @@
+# SPDX-License-Identifier: AGPL-3.0-only
+# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
+
+"""
+Tests for the chat document extractor + VLM capability probe.
+
+Probe tests run regardless of the extraction backend because they only
+shape-check :mod:`core.chat.vlm_capability`. Backend-backed tests skip
+cleanly when the optional deps (pymupdf / pymupdf4llm / mammoth) are
+missing.
+"""
+
+from __future__ import annotations
+
+import importlib.util
+import sys
+from types import ModuleType, SimpleNamespace
+from typing import Any, Dict, Optional
+
+import pytest
+
+from core.chat.vlm_capability import (
+ VlmCapability,
+ detect_loaded_vlm,
+ extract_self_base_url,
+)
+
+
+# ---------------------------------------------------------------------- #
+# VlmCapability dataclass #
+# ---------------------------------------------------------------------- #
+
+
+def test_vlm_capability_none_factory_is_safe_default() -> None:
+    """`VlmCapability.none()` reports no VLM, no endpoint, no model, and a reason."""
+    empty = VlmCapability.none()
+    assert empty.is_vlm is False
+    assert empty.endpoint_url is None and empty.model_name is None
+    assert empty.source == "none"
+    assert empty.reason  # truthy, i.e. a non-empty explanation
+
+
+def test_vlm_capability_to_dict_round_trips_fields() -> None:
+    """`to_dict()` returns every constructor field under its own name."""
+    expected = {
+        "is_vlm": True,
+        "endpoint_url": "http://127.0.0.1:8080",
+        "model_name": "qwen2-vl",
+        "source": "gguf",
+        "reason": None,
+    }
+
+    # Build the capability from the very mapping we expect back, so a renamed
+    # or dropped key in to_dict() shows up as a plain dict inequality.
+    cap = VlmCapability(**expected)
+
+    assert cap.to_dict() == expected
+
+
+# ---------------------------------------------------------------------- #
+# detect_loaded_vlm() across backend shapes #
+# ---------------------------------------------------------------------- #
+
+
+class _FakeLlama:
+    """llama.cpp backend double carrying only the attributes these tests read."""
+
+    def __init__(
+        self,
+        *,
+        loaded: bool,
+        vision: bool = False,
+        base_url: str = "http://127.0.0.1:8080",
+        model_id: str = "fake-gguf",
+    ) -> None:
+        self.is_loaded, self.is_vision = loaded, vision
+        self.base_url, self.model_identifier = base_url, model_id
+
+
+class _FakeInferenceBackend:
+    """Inference-backend double: one optional active model plus its info dict."""
+
+    def __init__(
+        self, *, active: Optional[str], info: Optional[Dict[str, Any]] = None
+    ) -> None:
+        self.active_model_name = active
+        # No active model -> empty registry; otherwise a single entry keyed by name.
+        self.models: Dict[str, Dict[str, Any]] = {active: info or {}} if active else {}
+
+
+def _patch_probes(
+    monkeypatch: pytest.MonkeyPatch,
+    *,
+    llama: Optional[_FakeLlama],
+    inference: Optional[_FakeInferenceBackend],
+) -> None:
+    """Swap both `vlm_capability` probes for stubs driven by the given fakes (None = no hit)."""
+    from core.chat import vlm_capability as vc
+
+    if llama is None:
+        # Parameter names match the stubs below so keyword calls work on either path.
+        monkeypatch.setattr(vc, "_probe_gguf", lambda llama_backend = None: None)
+    else:
+        def probe_gguf(llama_backend = None):
+            backend = llama_backend or llama
+            if not backend.is_loaded:
+                return None
+            is_vision = bool(backend.is_vision)
+            return VlmCapability(
+                is_vlm = is_vision,
+                endpoint_url = backend.base_url,
+                model_name = backend.model_identifier,
+                source = "gguf",
+                reason = None if is_vision else "loaded GGUF is not vision-capable",
+            )
+
+        monkeypatch.setattr(vc, "_probe_gguf", probe_gguf)
+
+    if inference is None:
+        monkeypatch.setattr(vc, "_probe_transformers", lambda self_base_url: None)
+    else:
+        def probe_tf(self_base_url):
+            name = inference.active_model_name
+            if not name:
+                return None
+            info = inference.models.get(name) or {}
+            is_vision = bool(info.get("is_vision", False))
+            source = "unsloth" if info.get("is_lora") else "transformers"
+            if not self_base_url:
+                return VlmCapability(
+                    is_vlm = False,
+                    endpoint_url = None,
+                    model_name = name,
+                    source = source,
+                    reason = "cannot self-loopback: request base URL unavailable",
+                )
+            return VlmCapability(
+                is_vlm = is_vision,
+                endpoint_url = self_base_url.rstrip("/"),
+                model_name = name,
+                source = source,
+                reason = None if is_vision else "loaded model is not vision-capable",
+            )
+
+        monkeypatch.setattr(vc, "_probe_transformers", probe_tf)
+
+
+def test_detect_returns_none_when_no_model_loaded(monkeypatch: pytest.MonkeyPatch) -> None:
+    """With both probes stubbed to return nothing, detection reports source "none"."""
+    _patch_probes(monkeypatch, llama = None, inference = None)
+    result = detect_loaded_vlm()
+
+    assert result.source == "none"
+    assert result.is_vlm is False
+
+
+def test_detect_gguf_vision_returns_llama_endpoint(monkeypatch: pytest.MonkeyPatch) -> None:
+    """A loaded vision GGUF is reported with its own URL, not the caller's base URL."""
+    gguf_url = "http://127.0.0.1:9999"
+    fake = _FakeLlama(loaded = True, vision = True, base_url = gguf_url)
+    _patch_probes(monkeypatch, llama = fake, inference = None)
+
+    result = detect_loaded_vlm("http://studio.local")
+
+    assert result.source == "gguf" and result.is_vlm is True
+    assert result.endpoint_url == gguf_url and result.reason is None
+
+
+def test_detect_gguf_vision_accepts_injected_backend(monkeypatch: pytest.MonkeyPatch) -> None:
+    """A llama backend passed via `llama_backend` is what detection reports on.
+
+    `_probe_gguf` is left unpatched here; only the transformers probe is stubbed.
+    """
+    from core.chat import vlm_capability as vc
+
+    injected = _FakeLlama(loaded = True, vision = True, base_url = "http://127.0.0.1:9999")
+    # Keep the transformers probe silent so only the injected backend can answer.
+    monkeypatch.setattr(vc, "_probe_transformers", lambda _u: None)
+
+    result = detect_loaded_vlm("http://127.0.0.1:8000", llama_backend = injected)
+
+    assert result.source == "gguf"
+    assert result.is_vlm is True
+    assert result.endpoint_url == "http://127.0.0.1:9999"
+
+
+def test_detect_gguf_vision_uses_core_llama_accessor(monkeypatch: pytest.MonkeyPatch) -> None:
+    """The implicit GGUF fallback must use the core-owned singleton path."""
+    from core.chat import vlm_capability as vc
+    from core.inference import llama_cpp
+
+    # The accessor has to exist on the core module; its backing singleton is
+    # then swapped for a vision-capable fake.
+    assert hasattr(llama_cpp, "get_llama_cpp_backend")
+    singleton = _FakeLlama(loaded = True, vision = True, base_url = "http://127.0.0.1:9999")
+    monkeypatch.setattr(llama_cpp, "_llama_cpp_backend", singleton)
+    monkeypatch.setattr(vc, "_probe_transformers", lambda _u: None)
+
+    result = detect_loaded_vlm("http://127.0.0.1:8000")
+
+    assert result.source == "gguf"
+    assert result.is_vlm is True
+    assert result.endpoint_url == "http://127.0.0.1:9999"
+
+
+def test_detect_gguf_non_vision_surfaces_reason(monkeypatch: pytest.MonkeyPatch) -> None:
+    """A loaded GGUF without vision is still reported, with a reason mentioning vision."""
+    text_only = _FakeLlama(loaded = True, vision = False)
+    _patch_probes(monkeypatch, llama = text_only, inference = None)
+
+    result = detect_loaded_vlm()
+    assert result.source == "gguf"
+    assert result.is_vlm is False
+    assert result.reason and "vision" in result.reason.lower()
+
+
+def test_detect_transformers_vision_uses_self_loopback(monkeypatch: pytest.MonkeyPatch) -> None:
+    """A transformers VLM is addressed through the caller-supplied base URL."""
+    vision_info = {"is_vision": True, "is_lora": False}
+    backend = _FakeInferenceBackend(active = "Qwen2-VL-7B", info = vision_info)
+    _patch_probes(monkeypatch, llama = None, inference = backend)
+
+    # The trailing slash on the supplied URL must not survive into endpoint_url.
+    result = detect_loaded_vlm("http://127.0.0.1:8000/")
+
+    assert result.source == "transformers"
+    assert result.is_vlm is True
+    assert result.endpoint_url == "http://127.0.0.1:8000"
+    assert result.model_name == "Qwen2-VL-7B"
+
+
+def test_detect_unsloth_lora_vision_reports_unsloth_source(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """An active model flagged `is_lora` is reported with source "unsloth"."""
+    lora_info = {"is_vision": True, "is_lora": True}
+    backend = _FakeInferenceBackend(active = "my-qwen-vl-lora", info = lora_info)
+    _patch_probes(monkeypatch, llama = None, inference = backend)
+
+    result = detect_loaded_vlm("http://studio.local:8000")
+    assert result.source == "unsloth"
+    assert result.is_vlm is True
+
+
+def test_detect_falls_through_when_gguf_is_loaded_but_endpoint_data_missing(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """A GGUF backend that claims to be loaded but has an empty base URL and
+    model id must not shadow the transformers probe: detection is expected to
+    fall through and report the transformers VLM instead.
+    """
+    from core.chat import vlm_capability as vc
+
+    # Half-initialised llama-server double, served through throwaway module
+    # objects registered in sys.modules for the duration of the test.
+    llama_mod = ModuleType("core.inference.llama_cpp")
+    llama_mod.get_llama_cpp_backend = lambda: _FakeLlama(
+        loaded = True, base_url = "", model_id = ""
+    )
+    inference_pkg = ModuleType("core.inference")
+    inference_pkg.__path__ = []  # type: ignore[attr-defined]
+    inference_pkg.llama_cpp = llama_mod  # type: ignore[attr-defined]
+    monkeypatch.setitem(sys.modules, "core.inference", inference_pkg)
+    monkeypatch.setitem(sys.modules, "core.inference.llama_cpp", llama_mod)
+
+    ib = _FakeInferenceBackend(
+        active = "Qwen2-VL-7B",
+        info = {"is_vision": True, "is_lora": False},
+    )
+
+    def probe_transformers(self_base_url):
+        return VlmCapability(
+            is_vlm = True,
+            endpoint_url = self_base_url.rstrip("/") if self_base_url else None,
+            model_name = ib.active_model_name,
+            source = "transformers",
+            reason = None,
+        )
+
+    monkeypatch.setattr(vc, "_probe_transformers", probe_transformers)
+
+    result = detect_loaded_vlm("http://127.0.0.1:8000")
+    assert result.source == "transformers"
+    assert result.is_vlm is True
+
+
+def test_detect_transformers_without_self_url_reports_missing_loopback(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """Without a base URL a vision model is not usable; the reason names the loopback."""
+    vision_info = {"is_vision": True, "is_lora": False}
+    backend = _FakeInferenceBackend(active = "Qwen2-VL-7B", info = vision_info)
+    _patch_probes(monkeypatch, llama = None, inference = backend)
+
+    result = detect_loaded_vlm(None)
+    assert result.is_vlm is False
+    assert result.reason and "loopback" in result.reason.lower()
+
+
+# ---------------------------------------------------------------------- #
+# extract_self_base_url — request base-URL extraction #
+# ---------------------------------------------------------------------- #
+
+
+class _FakeState:
+    """App-state double that only has `server_port` when one was supplied."""
+    def __init__(self, server_port: Optional[int] = None) -> None:
+        if server_port is not None:
+            self.server_port = server_port
+
+
+class _FakeApp:
+    def __init__(self, server_port: Optional[int] = None) -> None:
+        self.state = _FakeState(server_port)
+
+
+class _FakeRequest:
+    def __init__(
+        self,
+        base_url: str,
+        *,
+        server_port: Optional[int] = None,
+        scope_server: Optional[tuple[str, int]] = None,
+    ) -> None:
+        self.base_url, self.app = base_url, _FakeApp(server_port)
+        self.scope = {} if not scope_server else {"server": scope_server}
+
+
+def test_extract_self_base_url_strips_trailing_slash() -> None:
+    """A trailing slash on the request base URL is dropped."""
+    request = _FakeRequest("http://127.0.0.1:8000/")
+    resolved = extract_self_base_url(request)
+    assert resolved == "http://127.0.0.1:8000"
+
+
+def test_extract_self_base_url_prefers_trusted_server_port() -> None:
+    """Port precedence for the loopback URL.
+
+    `app.state.server_port` outranks the scope's server port, and neither the
+    host nor the port of a hostile `base_url` ends up in the result.
+    """
+
+    hostile_url = "http://attacker.invalid:9999/"
+    scope_server = ("127.0.0.1", 6666)
+
+    # Both a configured server port and a scope entry: the configured port wins.
+    with_state_port = _FakeRequest(
+        hostile_url,
+        server_port = 7777,
+        scope_server = scope_server,
+    )
+    assert extract_self_base_url(with_state_port) == "http://127.0.0.1:7777"
+
+    # Only the scope entry is present: its port is the one that is used.
+    scope_only = _FakeRequest(hostile_url, scope_server = scope_server)
+    assert extract_self_base_url(scope_only) == "http://127.0.0.1:6666"
+
+
+def test_extract_self_base_url_ignores_host_header() -> None:
+    """Host and scheme from `base_url` are discarded; only its port is reused."""
+    cases = {
+        "http://studio.local:8000/": "http://127.0.0.1:8000",
+        "https://example.com:9443/": "http://127.0.0.1:9443",
+    }
+
+    for base_url, expected in cases.items():
+        assert extract_self_base_url(_FakeRequest(base_url)) == expected
+
+
+def test_extract_self_base_url_none_when_empty() -> None:
+ assert extract_self_base_url(_FakeRequest("")) is None  # empty base_url and no port hints
+
+
+def test_extract_self_base_url_none_on_missing_attribute() -> None:
+ assert extract_self_base_url(object()) is None  # bare object has no base_url/app/scope
+
+
+# ---------------------------------------------------------------------- #
+# extract_document orchestration — backend-agnostic (monkey-patched) #
+# ---------------------------------------------------------------------- #
+
+
+@pytest.mark.asyncio
+async def test_max_figures_zero_sets_describe_skipped_reason(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """With max_figures=0, figure description is skipped with a dedicated
+    reason string even though a usable VLM capability is supplied."""
+    from core.chat import document_extractor as de
+
+    def fake_extract(_fb, _fn, _opts, _ct = ""):
+        return "# Smoke\n", [], 1, 0, 0
+
+    monkeypatch.setattr(de, "_run_extract_sync", fake_extract)
+    monkeypatch.setattr(de, "DOCUMENT_EXTRACTION_AVAILABLE", True)
+
+    available_vlm = VlmCapability(
+        is_vlm = True,
+        endpoint_url = "http://127.0.0.1:8000",
+        model_name = "vlm",
+        source = "transformers",
+    )
+    outcome = await de.extract_document(
+        b"# Smoke\n",
+        "sample.md",
+        describe_images = True,
+        max_figures = 0,
+        capability = available_vlm,
+    )
+
+    expected_reason = "figure description disabled because max_figures is 0"
+    assert outcome.describe_skipped_reason == expected_reason
+    assert outcome.markdown == "# Smoke\n"
+    assert outcome.figures == []
+
+
+@pytest.mark.asyncio
+async def test_run_extract_sync_seam_receives_content_type(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """A monkeypatched `_run_extract_sync` must be handed the caller's
+    content_type, so tests can drive dispatch by content type rather than
+    only by filename suffix."""
+    from core.chat import document_extractor as de
+
+    seen: dict[str, str] = {}
+
+    def recording_extract(_fb, _fn, _opts, ct = ""):
+        seen["content_type"] = ct
+        return "ok", [], 0, 0, 0
+
+    monkeypatch.setattr(de, "_run_extract_sync", recording_extract)
+    monkeypatch.setattr(de, "DOCUMENT_EXTRACTION_AVAILABLE", True)
+
+    # The filename has no suffix, so only the content type identifies the input.
+    await de.extract_document(
+        b"hello", "no-suffix-file",
+        content_type = "text/plain", describe_images = False,
+    )
+
+    assert seen["content_type"] == "text/plain"
+
+
+@pytest.mark.asyncio
+async def test_describe_image_via_vlm_sends_auth_header_and_max_tokens(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """The caption request carries the Authorization header and `max_tokens` only."""
+    from core.chat import document_extractor as de
+
+    captured: dict[str, Any] = {}
+
+    class FakeResponse:
+        status_code = 200
+
+        def json(self):
+            return {"choices": [{"message": {"content": "A chart."}}]}
+
+    class FakeAsyncClient:
+        # Async-context-manager double whose `post` records what it was given.
+        def __init__(self, *, timeout: float) -> None:
+            captured["timeout"] = timeout
+
+        async def __aenter__(self):
+            return self
+
+        async def __aexit__(self, *_args):
+            return None
+
+        async def post(self, url, *, headers, json):
+            captured.update(url = url, headers = headers, json = json)
+            return FakeResponse()
+
+    fake_httpx = ModuleType("httpx")
+    fake_httpx.AsyncClient = FakeAsyncClient
+    monkeypatch.setitem(sys.modules, "httpx", fake_httpx)
+
+    caption, error = await de._describe_image_via_vlm(
+        image_base64 = "abc",
+        image_mime = "image/jpeg",
+        endpoint_url = "http://127.0.0.1:8000",
+        model_name = "vlm",
+        authorization_header = "Bearer token",
+        timeout_seconds = 7,
+    )
+
+    assert caption == "A chart."
+    assert error is None
+    assert captured["url"] == "http://127.0.0.1:8000/v1/chat/completions"
+    assert captured["headers"]["Authorization"] == "Bearer token"
+    assert captured["json"]["max_tokens"] == 512
+    assert "max_completion_tokens" not in captured["json"]
+
+
+# ---------------------------------------------------------------------- #
+# Backend dispatch — real _run_extract_sync (requires pymupdf/mammoth) #
+# ---------------------------------------------------------------------- #
+
+
+# True only when every optional extraction dependency can be located.
+_BACKEND_INSTALLED = all(
+    importlib.util.find_spec(module_name) is not None
+    for module_name in ("pymupdf", "pymupdf4llm", "mammoth")
+)
+
+
+def test_run_extract_sync_rejects_pptx_with_value_error() -> None:
+    """PPTX was dropped in the PyMuPDF4LLM migration. _run_extract_sync
+    must raise ValueError so the route can map it to HTTP 415."""
+    if not _BACKEND_INSTALLED:
+        pytest.skip("extraction backend not installed")
+    from core.chat import document_extractor as de
+
+    # Four-byte ZIP local-file-header signature standing in for a real deck.
+    zip_magic = b"PK\x03\x04"
+    no_figures = {"max_figures": 0, "extract_images": False, "use_vlm_ocr": False}
+
+    with pytest.raises(ValueError):
+        de._run_extract_sync(zip_magic, "deck.pptx", no_figures)
+
+
+def test_run_extract_sync_text_path_decodes_utf8() -> None:
+    """TXT / MD paths must decode UTF-8 without requiring PDF/DOCX parser deps."""
+    from core.chat import document_extractor as de
+
+    source_text = "# Héllo\n"
+    no_figures = {"max_figures": 0, "extract_images": False, "use_vlm_ocr": False}
+
+    md, figs, pages, trunc, seen = de._run_extract_sync(
+        source_text.encode("utf-8"), "notes.md", no_figures
+    )
+    assert md == source_text
+    assert figs == [] and (pages, trunc, seen) == (0, 0, 0)
+
+
+def test_run_extract_sync_html_converts_to_markdown_without_parser_deps() -> None:
+ """HTML must be cleaned before prompt injection and not depend on PDF/DOCX deps."""
+ from core.chat import document_extractor as de
+
+ md, figs, pages, trunc, seen = de._run_extract_sync(
+ b"]Title
Hello world
",
+ "page.html",
+ {"max_figures": 0, "extract_images": False, "use_vlm_ocr": False},
+ )
+ assert "# Title" in md
+ assert "**world**" in md
+ assert "\"\n"
+ " b\"hello