diff --git a/.github/workflows/pr5351-cpu-inference-macos.yml b/.github/workflows/pr5351-cpu-inference-macos.yml new file mode 100644 index 0000000000..df154f7354 --- /dev/null +++ b/.github/workflows/pr5351-cpu-inference-macos.yml @@ -0,0 +1,52 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. +# +# PR-5351 CPU-inference cross-OS lane: macOS (Apple Silicon). +# Same as the Ubuntu lane but on macos-14. llama-cpp-python builds +# with Metal autodetect disabled to stay on the CPU code path so the +# result mirrors a non-GPU Mac. + +name: PR-5351 CPU inference macOS + +on: + push: + branches: [pr-5351-cross-os-validation] + paths: + - 'studio/backend/core/chat/**' + - 'tests/studio/test_cpu_inference_on_extracted_document.py' + - '.github/workflows/pr5351-cpu-inference-macos.yml' + workflow_dispatch: + +concurrency: + group: pr5351-cpu-inference-macos-${{ github.ref }} + cancel-in-progress: true + +jobs: + cpu-inference: + runs-on: macos-14 + timeout-minutes: 40 + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: '3.11' + cache: 'pip' + + - name: Install backend + llama-cpp-python (CPU build) + run: | + python -m pip install --upgrade pip + pip install -r studio/backend/requirements/studio.txt + pip install \ + python-multipart aiofiles sqlalchemy cryptography \ + pyyaml jinja2 mammoth pymupdf pymupdf4llm pytest pytest-asyncio \ + pytest-timeout huggingface_hub requests numpy + # Disable Metal so the lane stays CPU-only; mirrors a no-GPU Mac. + CMAKE_ARGS="-DGGML_METAL=OFF -DGGML_ACCELERATE=OFF -DGGML_NATIVE=OFF" \ + pip install --upgrade --quiet llama-cpp-python + + - name: CPU inference on extracted document + env: + PR5351_LLAMA_THREADS: '3' + run: | + python -m pytest -q tests/studio/test_cpu_inference_on_extracted_document.py -s --tb=short diff --git a/.github/workflows/pr5351-cpu-inference-ubuntu.yml b/.github/workflows/pr5351-cpu-inference-ubuntu.yml new file mode 100644 index 0000000000..4b0a441a12 --- /dev/null +++ b/.github/workflows/pr5351-cpu-inference-ubuntu.yml @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. +# +# PR-5351 CPU-inference cross-OS lane: Ubuntu. +# Builds llama-cpp-python from source for CPU, downloads a 0.5B GGUF +# from HF, extracts a synthetic PDF via the PR's document extractor, +# and asserts the model answers a ground-truth question. Proves +# end-to-end document-attach -> extract -> inference works on a CPU +# runner with no GPU. + +name: PR-5351 CPU inference Ubuntu + +on: + push: + branches: [pr-5351-cross-os-validation] + paths: + - 'studio/backend/core/chat/**' + - 'tests/studio/test_cpu_inference_on_extracted_document.py' + - '.github/workflows/pr5351-cpu-inference-ubuntu.yml' + workflow_dispatch: + +concurrency: + group: pr5351-cpu-inference-ubuntu-${{ github.ref }} + cancel-in-progress: true + +jobs: + cpu-inference: + runs-on: ubuntu-latest + timeout-minutes: 30 + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: '3.11' + cache: 'pip' + + - name: Install backend + llama-cpp-python (CPU build) + run: | + python -m pip install --upgrade pip + pip install -r studio/backend/requirements/studio.txt + pip install \ + python-multipart aiofiles sqlalchemy cryptography \ + pyyaml jinja2 mammoth pymupdf pymupdf4llm pytest pytest-asyncio \ + pytest-timeout huggingface_hub requests numpy + # CPU wheel ships pre-built on Linux; falls back to source if needed. + CMAKE_ARGS="-DGGML_NATIVE=OFF" pip install --upgrade --quiet llama-cpp-python + + - name: CPU inference on extracted document + env: + PR5351_LLAMA_THREADS: '4' + run: | + python -m pytest -q tests/studio/test_cpu_inference_on_extracted_document.py -s --tb=short diff --git a/.github/workflows/pr5351-cpu-inference-windows.yml b/.github/workflows/pr5351-cpu-inference-windows.yml new file mode 100644 index 0000000000..50972f17e7 --- /dev/null +++ b/.github/workflows/pr5351-cpu-inference-windows.yml @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. +# +# PR-5351 CPU-inference cross-OS lane: Windows. +# llama-cpp-python wheels exist for Windows; if pip falls back to +# source, MSVC is preinstalled on windows-latest. CPU-only. + +name: PR-5351 CPU inference Windows + +on: + push: + branches: [pr-5351-cross-os-validation] + paths: + - 'studio/backend/core/chat/**' + - 'tests/studio/test_cpu_inference_on_extracted_document.py' + - '.github/workflows/pr5351-cpu-inference-windows.yml' + workflow_dispatch: + +concurrency: + group: pr5351-cpu-inference-windows-${{ github.ref }} + cancel-in-progress: true + +jobs: + cpu-inference: + runs-on: windows-latest + timeout-minutes: 40 + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: '3.11' + cache: 'pip' + + - name: Install backend + llama-cpp-python (CPU build) + shell: pwsh + run: | + python -m pip install --upgrade pip + pip install -r studio/backend/requirements/studio.txt + pip install python-multipart aiofiles sqlalchemy cryptography pyyaml jinja2 mammoth pymupdf pymupdf4llm pytest pytest-asyncio pytest-timeout huggingface_hub requests numpy + $env:CMAKE_ARGS = "-DGGML_NATIVE=OFF" + pip install --upgrade --quiet llama-cpp-python + + - name: CPU inference on extracted document + shell: pwsh + env: + PR5351_LLAMA_THREADS: '4' + run: | + python -m pytest -q tests/studio/test_cpu_inference_on_extracted_document.py -s --tb=short diff --git a/.github/workflows/pr5351-macos.yml b/.github/workflows/pr5351-macos.yml new file mode 100644 index 0000000000..6bb149659b --- /dev/null +++ b/.github/workflows/pr5351-macos.yml @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. +# +# PR-5351 cross-OS validation: macOS lane. +# macos-14 (arm64). Validates the multiprocessing `spawn` path that +# differs from Linux's default `fork`, the MLX detection branch in +# core/chat/vlm_capability.py, and Safari/WebKit-relevant filesystem +# behaviour. CPU-only; CUDA spoof auto-engages via tests/conftest.py. + +name: PR-5351 macOS + +on: + push: + branches: [pr-5351-cross-os-validation] + paths: + - 'studio/backend/**' + - 'tests/studio/**' + - 'tests/conftest.py' + - '.github/workflows/pr5351-macos.yml' + workflow_dispatch: + +concurrency: + group: pr5351-macos-${{ github.ref }} + cancel-in-progress: true + +jobs: + pytest: + runs-on: macos-14 + timeout-minutes: 25 + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: '3.11' + cache: 'pip' + + - name: Install backend test dependencies (CPU only) + run: | + python -m pip install --upgrade pip + pip install -r studio/backend/requirements/studio.txt + pip install \ + python-multipart aiofiles sqlalchemy cryptography \ + pyyaml jinja2 mammoth unpdf requests \ + 'numpy<3' pytest pytest-asyncio httpx + pip install --index-url https://download.pytorch.org/whl/cpu 'torch>=2.4,<2.11' + pip install 'transformers>=4.51,<5.5' + + - name: PR-5351 document tests (macOS spawn semantics) + working-directory: studio/backend + env: + # macOS's default start method is spawn; exercise the same + # config users see in production. + UNSLOTH_STUDIO_EXTRACT_CONCURRENCY: '2' + run: | + python -m pytest -q tests/test_chat_document_extraction.py tests/test_chat_document_routes.py tests/test_inference_worker.py tests/test_vision_cache.py tests/test_anthropic_messages.py tests/test_openai_tool_passthrough.py tests/test_models_get_model_config_case_resolution.py --tb=short + + - name: PR-5351 regression tests + cancel timing + run: | + python -m pytest -q tests/studio/test_extractor_semaphore_leak.py tests/studio/test_html_independent_of_inference.py tests/studio/test_gguf_singleton_shared.py tests/studio/test_stream_cancel_registration_timing.py --tb=short diff --git a/.github/workflows/pr5351-ubuntu.yml b/.github/workflows/pr5351-ubuntu.yml new file mode 100644 index 0000000000..d1dd6d8712 --- /dev/null +++ b/.github/workflows/pr5351-ubuntu.yml @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. +# +# PR-5351 cross-OS validation: Ubuntu lane. +# Runs the document-extraction tests, the cancellation-timing structural +# test, and the three regression tests added in the fix commit against +# Python 3.11 on ubuntu-latest. CPU-only; the existing tests/conftest.py +# auto-installs the CUDA spoof so unsloth/unsloth_zoo device probes +# return "cuda". + +name: PR-5351 Ubuntu + +on: + push: + branches: [pr-5351-cross-os-validation] + paths: + - 'studio/backend/**' + - 'tests/studio/**' + - 'tests/conftest.py' + - '.github/workflows/pr5351-ubuntu.yml' + workflow_dispatch: + +concurrency: + group: pr5351-ubuntu-${{ github.ref }} + cancel-in-progress: true + +jobs: + pytest: + runs-on: ubuntu-latest + timeout-minutes: 20 + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: '3.11' + cache: 'pip' + + - name: Install backend test dependencies (CPU only) + run: | + python -m pip install --upgrade pip + pip install -r studio/backend/requirements/studio.txt + pip install \ + python-multipart aiofiles sqlalchemy cryptography \ + pyyaml jinja2 mammoth unpdf requests \ + 'numpy<3' pytest pytest-asyncio httpx + pip install --index-url https://download.pytorch.org/whl/cpu 'torch>=2.4,<2.11' + pip install 'transformers>=4.51,<5.5' + + - name: PR-5351 document tests + working-directory: studio/backend + run: | + python -m pytest -q tests/test_chat_document_extraction.py tests/test_chat_document_routes.py tests/test_inference_worker.py tests/test_vision_cache.py tests/test_anthropic_messages.py tests/test_openai_tool_passthrough.py tests/test_models_get_model_config_case_resolution.py --tb=short + + - name: PR-5351 regression tests + cancel timing + run: | + python -m pytest -q tests/studio/test_extractor_semaphore_leak.py tests/studio/test_html_independent_of_inference.py tests/studio/test_gguf_singleton_shared.py tests/studio/test_stream_cancel_registration_timing.py --tb=short diff --git a/.github/workflows/pr5351-windows.yml b/.github/workflows/pr5351-windows.yml new file mode 100644 index 0000000000..777e1c38ec --- /dev/null +++ b/.github/workflows/pr5351-windows.yml @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. +# +# PR-5351 cross-OS validation: Windows lane. +# windows-latest. Validates the multiprocessing `spawn` path +# (mandatory on Windows), path normalisation, and EAGAIN-style +# Process construction failures under load (the exact bug class the +# semaphore-leak fix protects against). CPU-only; CUDA spoof +# auto-engages via tests/conftest.py. + +name: PR-5351 Windows + +on: + push: + branches: [pr-5351-cross-os-validation] + paths: + - 'studio/backend/**' + - 'tests/studio/**' + - 'tests/conftest.py' + - '.github/workflows/pr5351-windows.yml' + workflow_dispatch: + +concurrency: + group: pr5351-windows-${{ github.ref }} + cancel-in-progress: true + +jobs: + pytest: + runs-on: windows-latest + timeout-minutes: 30 + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: '3.11' + cache: 'pip' + + - name: Install backend test dependencies (CPU only) + shell: pwsh + run: | + python -m pip install --upgrade pip + pip install -r studio/backend/requirements/studio.txt + pip install python-multipart aiofiles sqlalchemy cryptography pyyaml jinja2 mammoth unpdf requests "numpy<3" pytest pytest-asyncio httpx + pip install --index-url https://download.pytorch.org/whl/cpu "torch>=2.4,<2.11" + pip install "transformers>=4.51,<5.5" + + - name: PR-5351 document tests (Windows spawn semantics) + working-directory: studio/backend + shell: pwsh + env: + UNSLOTH_STUDIO_EXTRACT_CONCURRENCY: '2' + run: | + python -m pytest -q tests/test_chat_document_extraction.py tests/test_chat_document_routes.py tests/test_inference_worker.py tests/test_vision_cache.py tests/test_anthropic_messages.py tests/test_openai_tool_passthrough.py tests/test_models_get_model_config_case_resolution.py --tb=short + + - name: PR-5351 regression tests + cancel timing + shell: pwsh + run: | + python -m pytest -q tests/studio/test_extractor_semaphore_leak.py tests/studio/test_html_independent_of_inference.py tests/studio/test_gguf_singleton_shared.py tests/studio/test_stream_cancel_registration_timing.py --tb=short diff --git a/.github/workflows/release-desktop.yml b/.github/workflows/release-desktop.yml deleted file mode 100644 index e747605322..0000000000 --- a/.github/workflows/release-desktop.yml +++ /dev/null @@ -1,902 +0,0 @@ -name: Release Desktop App - -on: - workflow_dispatch: - inputs: - studio_version: - description: 'Studio version tag to release (for example, v0.1.39-beta)' - type: string - required: true - pypi_version: - description: 'Exact PyPI unsloth version just published/stamped (for example, 2026.5.3); leave blank to use MIN_DESKTOP_BACKEND_VERSION' - type: string - required: false - draft: - description: 'Create as draft release; draft runs do not advance desktop-latest updater channel' - type: boolean - default: true - -permissions: - contents: read - -concurrency: - group: release-desktop-${{ github.repository }} - cancel-in-progress: false - -jobs: - prepare-version: - name: Prepare release versions - runs-on: ubuntu-latest - outputs: - studio_version: ${{ steps.prepare.outputs.studio_version }} - app_version: ${{ steps.prepare.outputs.app_version }} - desktop_release_tag: ${{ steps.prepare.outputs.desktop_release_tag }} - prerelease: ${{ steps.prepare.outputs.prerelease }} - pypi_version: ${{ steps.prepare.outputs.pypi_version }} - - steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd - with: - persist-credentials: false - - - name: Validate release versions - id: prepare - shell: bash - env: - INPUT_STUDIO_VERSION: ${{ inputs.studio_version }} - INPUT_PYPI_VERSION: ${{ inputs.pypi_version }} - run: | - python3 <<'PY' - import os - import pathlib - import re - import sys - - studio_version = os.environ['INPUT_STUDIO_VERSION'].strip() - if not studio_version: - sys.exit('studio_version is required, for example v0.1.39-beta') - if re.fullmatch(r'v?20\d{2}\.\d+\.\d+(?:[-+][0-9A-Za-z.-]+)?', studio_version): - sys.exit(f'studio_version must be a Studio SemVer tag, not a date-style backend version: {studio_version}') - - semver_tag = re.compile( - r'^v(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)' - r'(?:-[0-9A-Za-z.][0-9A-Za-z.-]*)?$' - ) - if not semver_tag.fullmatch(studio_version): - sys.exit(f'studio_version must be a SemVer tag with leading v, for example v0.1.39-beta: {studio_version}') - - app_version = studio_version.removeprefix('v') - desktop_release_tag = f'desktop-v{app_version}' - prerelease = 'true' if '-' in app_version.split('+', 1)[0] else 'false' - - def parse_backend_version(version): - match = re.fullmatch( - r'(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)' - r'(?:([a-zA-Z]|\.dev|dev|\.rc|rc|\.post|post)(\d*))?' - r'(?:[-+]([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?', - version, - ) - if not match: - return None - major, minor, patch, suffix_name, suffix_number, suffix_text = match.groups() - if suffix_name: - normalized = suffix_name.lower().lstrip('.') - order = {'dev': 0, 'a': 1, 'b': 2, 'rc': 3, 'post': 5}.get(normalized) - if order is None: - return None - number = int(suffix_number or '0') - elif suffix_text: - order = 3 if version[version.find(suffix_text) - 1] == '-' else 4 - number = 0 - else: - order = 4 - number = 0 - return (int(major), int(minor), int(patch), order, number) - - preflight = pathlib.Path('studio/src-tauri/src/preflight/version.rs').read_text() - match = re.search(r'MIN_DESKTOP_BACKEND_VERSION:\s*&str\s*=\s*"([^"]+)"', preflight) - if not match: - sys.exit('Could not read MIN_DESKTOP_BACKEND_VERSION') - min_backend_version = match.group(1) - - input_pypi_version = os.environ.get('INPUT_PYPI_VERSION', '').strip() - parsed_min_backend = parse_backend_version(min_backend_version) - if parsed_min_backend is None: - sys.exit(f'MIN_DESKTOP_BACKEND_VERSION is not a supported backend package version: {min_backend_version}') - - pypi_version = input_pypi_version or min_backend_version - parsed_pypi = parse_backend_version(pypi_version) - if parsed_pypi is None: - sys.exit(f'pypi_version is not a supported backend package version: {pypi_version}') - if parsed_pypi < parsed_min_backend: - sys.exit( - f'pypi_version {pypi_version} is lower than desktop minimum ' - f'MIN_DESKTOP_BACKEND_VERSION {min_backend_version}' - ) - - if input_pypi_version: - print( - 'Using exact PyPI unsloth version from pypi_version input: ' - f'{pypi_version} (desktop minimum: {min_backend_version})' - ) - else: - print( - 'Using exact PyPI unsloth version from MIN_DESKTOP_BACKEND_VERSION: ' - f'{pypi_version}' - ) - - with open(os.environ['GITHUB_OUTPUT'], 'a', encoding='utf-8') as output: - print(f'studio_version={studio_version}', file=output) - print(f'app_version={app_version}', file=output) - print(f'desktop_release_tag={desktop_release_tag}', file=output) - print(f'prerelease={prerelease}', file=output) - print(f'pypi_version={pypi_version}', file=output) - PY - - - name: Verify PyPI package and Studio stamp - shell: bash - env: - STUDIO_VERSION: ${{ steps.prepare.outputs.studio_version }} - PYPI_VERSION: ${{ steps.prepare.outputs.pypi_version }} - run: | - set -euo pipefail - python3 <<'PY' - import json - import os - import pathlib - import sys - import time - import urllib.error - import urllib.request - - pypi_version = os.environ['PYPI_VERSION'] - dist_dir = pathlib.Path(os.environ['RUNNER_TEMP'], 'pypi-unsloth-dist') - dist_dir.mkdir(parents=True, exist_ok=True) - metadata_url = f'https://pypi.org/pypi/unsloth/{pypi_version}/json' - - last_error = None - for attempt in range(1, 6): - try: - with urllib.request.urlopen(metadata_url, timeout=30) as response: - metadata = json.load(response) - break - except Exception as exc: - last_error = exc - if attempt < 5: - time.sleep(10 * attempt) - else: - sys.exit(f'Publish unsloth=={pypi_version} to PyPI before the desktop release ({last_error})') - - files = metadata.get('urls') or [] - if not files: - sys.exit(f'PyPI returned no distribution files for unsloth=={pypi_version}') - - for file_info in files: - filename = file_info.get('filename') - url = file_info.get('url') - if not filename or '/' in filename or not url: - sys.exit(f'Unexpected PyPI file entry for unsloth=={pypi_version}: {file_info!r}') - target = dist_dir / filename - for attempt in range(1, 4): - try: - with urllib.request.urlopen(url, timeout=60) as response: - target.write_bytes(response.read()) - break - except Exception as exc: - last_error = exc - if attempt < 3: - time.sleep(5 * attempt) - else: - sys.exit(f'Could not download {filename} from PyPI ({last_error})') - PY - - if [ -f scripts/stamp_studio_release.py ]; then - mapfile -t dists < <(find "$RUNNER_TEMP/pypi-unsloth-dist" -type f \( -name '*.whl' -o -name '*.tar.gz' \) | sort) - if [ "${#dists[@]}" -eq 0 ]; then - echo "No PyPI wheel/sdist artifacts downloaded for unsloth==$PYPI_VERSION" >&2 - exit 1 - fi - python3 scripts/stamp_studio_release.py --verify-dist "$RUNNER_TEMP/pypi-unsloth-dist" --expected "$STUDIO_VERSION" - else - echo "scripts/stamp_studio_release.py not found; release-desktop requires #5308 to verify the PyPI Studio stamp." >&2 - exit 1 - fi - - - name: Guard public updater channel version - if: ${{ !inputs.draft }} - shell: bash - env: - GH_REPO: ${{ github.repository }} - GH_TOKEN: ${{ github.token }} - APP_VERSION: ${{ steps.prepare.outputs.app_version }} - run: | - set -euo pipefail - mkdir -p "$RUNNER_TEMP/desktop-current" - if ! gh release download desktop-latest --pattern latest.json --dir "$RUNNER_TEMP/desktop-current" --clobber 2>/dev/null; then - echo "No existing desktop-latest latest.json found; allowing first channel publish." - exit 0 - fi - python3 <<'PY' - import json - import os - import pathlib - import re - import sys - - def parse(value: str): - value = value.removeprefix('v') - match = re.fullmatch( - r'(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)' - r'(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?' - r'(?:\+[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*)?', - value, - ) - if not match: - sys.exit(f'desktop-latest latest.json has invalid version: {value}') - major, minor, patch, prerelease = match.groups() - return (int(major), int(minor), int(patch), prerelease) - - def numeric_tail(identifier: str) -> tuple[str, int] | None: - match = re.fullmatch(r'([A-Za-z-]+)(\d+)', identifier) - if not match: - return None - return (match.group(1).lower(), int(match.group(2))) - - def compare_identifier(left: str, right: str) -> int: - left_num = left.isdigit() - right_num = right.isdigit() - if left_num and right_num: - return (int(left) > int(right)) - (int(left) < int(right)) - if left_num: - return -1 - if right_num: - return 1 - - left_tail = numeric_tail(left) - right_tail = numeric_tail(right) - if left_tail and right_tail and left_tail[0] == right_tail[0]: - return (left_tail[1] > right_tail[1]) - (left_tail[1] < right_tail[1]) - - return (left > right) - (left < right) - - def compare_prerelease(left: str | None, right: str | None) -> int: - if left == right: - return 0 - if left is None: - return 1 - if right is None: - return -1 - left_parts = left.split('.') - right_parts = right.split('.') - for left_part, right_part in zip(left_parts, right_parts): - order = compare_identifier(left_part, right_part) - if order: - return order - return (len(left_parts) > len(right_parts)) - (len(left_parts) < len(right_parts)) - - def compare(left: str, right: str) -> int: - left_major, left_minor, left_patch, left_pre = parse(left) - right_major, right_minor, right_patch, right_pre = parse(right) - left_core = (left_major, left_minor, left_patch) - right_core = (right_major, right_minor, right_patch) - if left_core != right_core: - return (left_core > right_core) - (left_core < right_core) - return compare_prerelease(left_pre, right_pre) - - current_path = pathlib.Path(os.environ['RUNNER_TEMP'], 'desktop-current', 'latest.json') - current = json.loads(current_path.read_text()).get('version') - next_version = os.environ['APP_VERSION'] - if not isinstance(current, str): - sys.exit('desktop-latest latest.json has missing version') - if compare(next_version, current) < 0: - sys.exit( - f'Refusing to publish {next_version}; desktop-latest currently points at newer version {current}.' - ) - PY - - build: - # TODO: split into a "build (no secrets)" + "publish (secrets)" job pair - # with actions/upload-artifact handoff so the matrix build cannot - # publish a Release on its own. The current matrix runs across - # Linux/macOS/Windows in a single job, so the split needs artefact - # collection across the OS matrix and is out of scope for this - # hardening pass. - permissions: - contents: write # tauri-apps/tauri-action creates / uploads a GitHub Release - strategy: - fail-fast: false - max-parallel: 1 - matrix: - include: - - platform: macos-latest - args: '--target aarch64-apple-darwin' - label: macOS (Apple Silicon) - # - platform: macos-latest - # args: '--target x86_64-apple-darwin' - # label: macOS (Intel) - - platform: ubuntu-22.04 - args: '' - label: Linux (x64) - - platform: windows-latest - args: '' - label: Windows (x64) - - name: Build ${{ matrix.label }} - needs: prepare-version - runs-on: ${{ matrix.platform }} - - env: - FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true - APP_VERSION: ${{ needs.prepare-version.outputs.app_version }} - STUDIO_VERSION: ${{ needs.prepare-version.outputs.studio_version }} - DESKTOP_RELEASE_TAG: ${{ needs.prepare-version.outputs.desktop_release_tag }} - DESKTOP_PRERELEASE: ${{ needs.prepare-version.outputs.prerelease }} - - steps: - # harden-runner in audit mode: surfaces every egress destination in - # the runner log so the allowlist for a future `egress-policy: block` - # promotion can be derived from observed traffic. Audit mode is - # cross-platform (Linux / macOS / Windows runners); blocking mode is - # currently Linux-only, so we deliberately stay in audit until the - # macOS + Windows codesign paths have been observed. - - name: Harden runner (audit) - uses: step-security/harden-runner@a5ad31d6a139d249332a2605b85202e8c0b78450 # v2.19.1 - with: - egress-policy: audit - - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd - with: - persist-credentials: false - - # ── Linux dependencies ── - - name: Install Linux dependencies - if: matrix.platform == 'ubuntu-22.04' - run: | - sudo apt-get update - sudo apt-get install -y libwebkit2gtk-4.1-dev libayatana-appindicator3-dev librsvg2-dev libxdo-dev libssl-dev patchelf - - # ── Node.js ── - - name: Setup Node.js - uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e - with: - node-version: 24 - - - name: Install pinned Tauri CLI - # Lifecycle scripts (esbuild native-binary postinstall, etc.) are - # required for `vite build`. The pre-install lockfile structural - # audit (lockfile_supply_chain_audit.py) is the practical defence - # against the npm postinstall-dropper class -- it fires BEFORE any - # tarball runs, on the injection pattern itself rather than an - # advisory-DB lookup. - run: npm install --save-dev --prefix studio @tauri-apps/cli@2.10.1 --no-fund --no-audit - - - name: Verify pinned Tauri CLI - shell: bash - run: | - out="$(npx --prefix studio tauri --version)" - echo "$out" - if [ "$out" != "tauri-cli 2.10.1" ]; then - echo "Expected tauri-cli 2.10.1, got $out" >&2 - exit 1 - fi - - - name: Verify desktop updater and Linux package config - shell: bash - run: | - node <<'JS' - const { readFileSync } = require('node:fs'); - - const expected = 'https://github.com/unslothai/unsloth/releases/download/desktop-latest/latest.json'; - const config = JSON.parse(readFileSync('studio/src-tauri/tauri.conf.json', 'utf8')); - const endpoints = config.plugins?.updater?.endpoints; - if (!Array.isArray(endpoints) || endpoints.length !== 1) { - throw new Error('Expected exactly one desktop updater endpoint'); - } - if (endpoints[0] !== expected) { - throw new Error('Desktop updater endpoint must be ' + expected + ', got ' + endpoints[0]); - } - if (endpoints.some((endpoint) => endpoint.includes('/releases/latest/'))) { - throw new Error('Desktop updater endpoint must not use repo-wide /releases/latest/'); - } - - const targets = config.bundle?.targets; - if (Array.isArray(targets) && targets.some((target) => String(target).toLowerCase() === 'rpm')) { - throw new Error('Desktop release must not target RPM packages'); - } - if (config.bundle?.linux?.rpm) { - throw new Error('bundle.linux.rpm must not be configured'); - } - - const workflow = readFileSync('.github/workflows/release-desktop.yml', 'utf8'); - const lines = workflow.split(/\r?\n/); - const releaseBodies = []; - for (let i = 0; i < lines.length; i += 1) { - const match = lines[i].match(/^(\s*)releaseBody:\s*\|\s*$/); - if (!match) continue; - const baseIndent = match[1].length; - const bodyLines = []; - i += 1; - for (; i < lines.length; i += 1) { - const line = lines[i]; - if (line.trim() === '') { - bodyLines.push(''); - continue; - } - const indent = line.match(/^\s*/)[0].length; - if (indent <= baseIndent) { - i -= 1; - break; - } - bodyLines.push(line.slice(baseIndent + 2)); - } - releaseBodies.push(bodyLines.join('\n')); - } - if (releaseBodies.length === 0) { - throw new Error('Expected at least one desktop release body'); - } - for (const body of releaseBodies) { - if (/\brpm\b|\.rpm/i.test(body)) { - throw new Error('Desktop release body must not advertise RPM packages'); - } - } - JS - - - name: Install frontend dependencies - working-directory: studio/frontend - # Lifecycle scripts (esbuild native-binary postinstall, etc.) are - # required for `vite build`. The pre-install lockfile structural - # audit (lockfile_supply_chain_audit.py) is the practical defence - # against the npm postinstall-dropper class -- it fires BEFORE any - # tarball runs, on the injection pattern itself rather than an - # advisory-DB lookup. - run: npm install --no-fund --no-audit - - # ── Rust ── - - name: Install Rust stable - uses: dtolnay/rust-toolchain@29eef336d9b2848a0b548edc03f92a220660cdb8 # stable @ 2026-03-27 - with: - targets: ${{ matrix.platform == 'macos-latest' && 'aarch64-apple-darwin,x86_64-apple-darwin' || '' }} - - - name: Patch desktop app version - shell: bash - working-directory: studio/src-tauri - run: | - set -euo pipefail - if command -v python3 >/dev/null 2>&1; then - PYTHON=python3 - else - PYTHON=python - fi - "$PYTHON" <<'PY' - import os - import pathlib - import re - import sys - - app_version = os.environ['APP_VERSION'] - if not app_version: - sys.exit('APP_VERSION is required') - - cargo_toml = pathlib.Path('Cargo.toml') - lines = cargo_toml.read_text().splitlines(keepends=True) - in_package = False - patched = False - for index, line in enumerate(lines): - stripped = line.strip() - if stripped == '[package]': - in_package = True - continue - if stripped.startswith('[') and stripped.endswith(']'): - in_package = False - if in_package and re.fullmatch(r'version\s*=\s*"[^"]+"\s*', stripped): - lines[index] = f'version = "{app_version}"\n' - patched = True - break - if not patched: - sys.exit('Could not patch [package] version in Cargo.toml') - cargo_toml.write_text(''.join(lines)) - - cargo_lock = pathlib.Path('Cargo.lock') - lock_text = cargo_lock.read_text() - lock_text, count = re.subn( - r'(?m)(^\[\[package\]\]\nname = "unsloth-studio"\nversion = ")[^"]+(")', - lambda match: f'{match.group(1)}{app_version}{match.group(2)}', - lock_text, - ) - if count != 1: - sys.exit(f'Could not patch unsloth-studio version in Cargo.lock (matches={count})') - cargo_lock.write_text(lock_text) - PY - - cargo metadata --locked --no-deps --format-version 1 > "$RUNNER_TEMP/cargo-metadata.json" - "$PYTHON" <<'PY' - import json - import os - import pathlib - import sys - - app_version = os.environ['APP_VERSION'] - metadata = json.loads(pathlib.Path(os.environ['RUNNER_TEMP'], 'cargo-metadata.json').read_text()) - versions = [package['version'] for package in metadata.get('packages', []) if package.get('name') == 'unsloth-studio'] - if versions != [app_version]: - sys.exit(f'cargo metadata unsloth-studio version mismatch: expected {app_version}, got {versions}') - PY - - git diff -- Cargo.toml Cargo.lock - - - name: Rust cache - uses: swatinem/rust-cache@e18b497796c12c097a38f9edb9d0641fb99eee32 - with: - workspaces: 'studio/src-tauri -> target' - - # ── macOS: import signing certificate ── - - name: Import Apple certificate - if: matrix.platform == 'macos-latest' - env: - APPLE_CERTIFICATE: ${{ secrets.APPLE_CERTIFICATE }} - APPLE_CERTIFICATE_PASSWORD: ${{ secrets.APPLE_CERTIFICATE_PASSWORD }} - KEYCHAIN_PASSWORD: ${{ secrets.KEYCHAIN_PASSWORD }} - run: | - echo $APPLE_CERTIFICATE | base64 --decode > certificate.p12 - security create-keychain -p "$KEYCHAIN_PASSWORD" build.keychain - security default-keychain -s build.keychain - security unlock-keychain -p "$KEYCHAIN_PASSWORD" build.keychain - security set-keychain-settings -t 3600 -u build.keychain - security import certificate.p12 -k build.keychain -P "$APPLE_CERTIFICATE_PASSWORD" -T /usr/bin/codesign - security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k "$KEYCHAIN_PASSWORD" build.keychain - security find-identity -v -p codesigning build.keychain - rm -f certificate.p12 - - # ── Windows: install Azure Trusted Signing CLI ── - - name: Install trusted-signing-cli - if: matrix.platform == 'windows-latest' - run: | - cargo install trusted-signing-cli --version 0.10.0 --locked - echo "$env:USERPROFILE\.cargo\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - - # ── Windows: verify signing CLI is accessible ── - - name: Verify trusted-signing-cli - if: matrix.platform == 'windows-latest' - run: | - Write-Output "PATH: $env:PATH" - Get-Command trusted-signing-cli -ErrorAction SilentlyContinue || Write-Output "trusted-signing-cli NOT in PATH" - trusted-signing-cli --version || Write-Output "trusted-signing-cli failed to run" - - # ── Linux: build + sign + upload ── - - name: Build Linux app - if: matrix.platform == 'ubuntu-22.04' - uses: tauri-apps/tauri-action@84b9d35b5fc46c1e45415bdb6144030364f7ebc5 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - TAURI_SIGNING_PRIVATE_KEY: ${{ secrets.TAURI_SIGNING_PRIVATE_KEY }} - TAURI_SIGNING_PRIVATE_KEY_PASSWORD: ${{ secrets.TAURI_SIGNING_PRIVATE_KEY_PASSWORD }} - with: - projectPath: studio - tauriScript: npx --prefix . tauri - tagName: ${{ needs.prepare-version.outputs.desktop_release_tag }} - releaseName: 'Unsloth Studio (Desktop) ${{ needs.prepare-version.outputs.studio_version }}' - releaseBody: | - Desktop app for Unsloth Studio. - - **macOS**: Download the Apple Silicon `.dmg`. - **Windows**: Download the `-setup.exe` installer. - **Linux**: Download `.deb` (Ubuntu/Debian) or `.AppImage` (universal). - - > Linux in-app updates are AppImage-oriented. Package installs should update by downloading a new package. - > Linux AppImage on Ubuntu 24.04+ may require: `sudo apt install libfuse2t64` - > First-run system dependency elevation is supported on Ubuntu/Debian. Other Linux distributions should install system packages manually. - releaseDraft: ${{ inputs.draft }} - prerelease: ${{ needs.prepare-version.outputs.prerelease }} - args: -v ${{ matrix.args }} - - # ── macOS: build + sign + notarize + upload ── - - name: Build macOS app - if: matrix.platform == 'macos-latest' - uses: tauri-apps/tauri-action@84b9d35b5fc46c1e45415bdb6144030364f7ebc5 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - TAURI_SIGNING_PRIVATE_KEY: ${{ secrets.TAURI_SIGNING_PRIVATE_KEY }} - TAURI_SIGNING_PRIVATE_KEY_PASSWORD: ${{ secrets.TAURI_SIGNING_PRIVATE_KEY_PASSWORD }} - APPLE_SIGNING_IDENTITY: ${{ secrets.APPLE_SIGNING_IDENTITY }} - APPLE_ID: ${{ secrets.APPLE_ID }} - APPLE_PASSWORD: ${{ secrets.APPLE_PASSWORD }} - APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }} - with: - projectPath: studio - tauriScript: npx --prefix . tauri - tagName: ${{ needs.prepare-version.outputs.desktop_release_tag }} - releaseName: 'Unsloth Studio (Desktop) ${{ needs.prepare-version.outputs.studio_version }}' - releaseBody: | - Desktop app for Unsloth Studio. - - **macOS**: Download the Apple Silicon `.dmg`. - **Windows**: Download the `-setup.exe` installer. - **Linux**: Download `.deb` (Ubuntu/Debian) or `.AppImage` (universal). - - > Linux in-app updates are AppImage-oriented. Package installs should update by downloading a new package. - > Linux AppImage on Ubuntu 24.04+ may require: `sudo apt install libfuse2t64` - > First-run system dependency elevation is supported on Ubuntu/Debian. Other Linux distributions should install system packages manually. - releaseDraft: ${{ inputs.draft }} - prerelease: ${{ needs.prepare-version.outputs.prerelease }} - args: -v ${{ matrix.args }} - - # ── Windows: build + sign + upload ── - - name: Build Windows app - if: matrix.platform == 'windows-latest' - uses: tauri-apps/tauri-action@84b9d35b5fc46c1e45415bdb6144030364f7ebc5 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - TAURI_SIGNING_PRIVATE_KEY: ${{ secrets.TAURI_SIGNING_PRIVATE_KEY }} - TAURI_SIGNING_PRIVATE_KEY_PASSWORD: ${{ secrets.TAURI_SIGNING_PRIVATE_KEY_PASSWORD }} - AZURE_CLIENT_ID: ${{ secrets.AZURE_CLIENT_ID }} - AZURE_CLIENT_SECRET: ${{ secrets.AZURE_CLIENT_SECRET }} - AZURE_TENANT_ID: ${{ secrets.AZURE_TENANT_ID }} - AZURE_TRUSTED_SIGNING_ACCOUNT_NAME: ${{ secrets.AZURE_TRUSTED_SIGNING_ACCOUNT_NAME }} - AZURE_CERTIFICATE_PROFILE_NAME: ${{ secrets.AZURE_CERTIFICATE_PROFILE_NAME }} - with: - projectPath: studio - tauriScript: npx --prefix . tauri - tagName: ${{ needs.prepare-version.outputs.desktop_release_tag }} - releaseName: 'Unsloth Studio (Desktop) ${{ needs.prepare-version.outputs.studio_version }}' - releaseBody: | - Desktop app for Unsloth Studio. - - **macOS**: Download the Apple Silicon `.dmg`. - **Windows**: Download the `-setup.exe` installer. - **Linux**: Download `.deb` (Ubuntu/Debian) or `.AppImage` (universal). - - > Linux in-app updates are AppImage-oriented. Package installs should update by downloading a new package. - > Linux AppImage on Ubuntu 24.04+ may require: `sudo apt install libfuse2t64` - > First-run system dependency elevation is supported on Ubuntu/Debian. Other Linux distributions should install system packages manually. - releaseDraft: ${{ inputs.draft }} - prerelease: ${{ needs.prepare-version.outputs.prerelease }} - args: -v ${{ matrix.args }} - - # Release process note: only non-draft workflow runs advance the public - # desktop-latest updater channel. Draft builds are for private review; if a - # draft is manually published later, this channel intentionally remains - # unchanged until a narrow manual channel-publish flow is added or a public - # desktop release is created by running this workflow with draft=false. - publish-updater-channel: - name: Publish desktop updater channel - needs: [prepare-version, build] - if: ${{ !inputs.draft }} - runs-on: ubuntu-latest - permissions: - contents: write - env: - GH_REPO: ${{ github.repository }} - APP_VERSION: ${{ needs.prepare-version.outputs.app_version }} - STUDIO_VERSION: ${{ needs.prepare-version.outputs.studio_version }} - DESKTOP_RELEASE_TAG: ${{ needs.prepare-version.outputs.desktop_release_tag }} - DESKTOP_PRERELEASE: ${{ needs.prepare-version.outputs.prerelease }} - - steps: - - name: Download versioned updater metadata - shell: bash - env: - GH_TOKEN: ${{ github.token }} - run: | - set -euo pipefail - mkdir -p "$RUNNER_TEMP/desktop-updater" - gh api "repos/${GITHUB_REPOSITORY}/releases/tags/${DESKTOP_RELEASE_TAG}" > "$RUNNER_TEMP/source-release.json" - python3 <<'PY' - import json - import os - import pathlib - import sys - - source = json.loads(pathlib.Path(os.environ['RUNNER_TEMP'], 'source-release.json').read_text()) - expected_tag = os.environ['DESKTOP_RELEASE_TAG'] - if source.get('tag_name') != expected_tag: - sys.exit(f'Expected source release {expected_tag}, got {source.get("tag_name")}') - if source.get('draft'): - sys.exit(f'Source desktop release {expected_tag} is draft; refusing to publish public updater channel') - PY - gh release download "$DESKTOP_RELEASE_TAG" --pattern latest.json --dir "$RUNNER_TEMP/desktop-updater" --clobber - test -s "$RUNNER_TEMP/desktop-updater/latest.json" - - - name: Validate versioned updater metadata - shell: bash - run: | - python3 <<'PY' - import json - import os - import pathlib - import re - import sys - - app_version = os.environ['APP_VERSION'] - release_tag = os.environ['DESKTOP_RELEASE_TAG'] - latest_path = pathlib.Path(os.environ['RUNNER_TEMP'], 'desktop-updater', 'latest.json') - data = json.loads(latest_path.read_text()) - if not isinstance(data, dict): - sys.exit('latest.json must be a JSON object') - - version = data.get('version') - if not isinstance(version, str) or not version: - sys.exit('latest.json missing version') - if not re.fullmatch(r'v?\d+\.\d+\.\d+(?:[-+][0-9A-Za-z.-]+)?', version): - sys.exit(f'latest.json version is not SemVer-like: {version}') - if version.removeprefix('v') != app_version: - sys.exit(f'latest.json version {version} does not match desktop app version {app_version}') - - platforms = data.get('platforms') - if not isinstance(platforms, dict) or not platforms: - sys.exit('latest.json missing platforms') - - required_families = { - 'darwin-aarch64': False, - 'linux-x86_64': False, - 'windows-x86_64': False, - } - expected_prefix = f'https://github.com/unslothai/unsloth/releases/download/{release_tag}/' - forbidden_fragments = ('/releases/latest/', '/releases/download/desktop-latest/') - - for platform, entry in platforms.items(): - if not isinstance(entry, dict): - sys.exit(f'Platform {platform} must be an object') - url = entry.get('url') - signature = entry.get('signature') - if not isinstance(url, str) or not url.strip(): - sys.exit(f'Platform {platform} missing url') - if not isinstance(signature, str) or not signature.strip(): - sys.exit(f'Platform {platform} missing signature') - if any(fragment in url for fragment in forbidden_fragments): - sys.exit(f'Platform {platform} points at a moving updater channel: {url}') - if not url.startswith(expected_prefix): - sys.exit(f'Platform {platform} URL must point at {release_tag}: {url}') - for family in required_families: - if platform == family or platform.startswith(family + '-'): - required_families[family] = True - - missing = [family for family, found in required_families.items() if not found] - if missing: - sys.exit('latest.json missing required platform families: ' + ', '.join(missing)) - PY - - - name: Ensure desktop updater channel release - shell: bash - env: - GH_TOKEN: ${{ github.token }} - run: | - set -euo pipefail - channel_json="$RUNNER_TEMP/desktop-latest-release.json" - if ! gh api "repos/${GITHUB_REPOSITORY}/releases/tags/desktop-latest" > "$channel_json" 2>/dev/null; then - gh release create desktop-latest \ - --title "Unsloth Studio Desktop updater channel" \ - --notes "Machine-managed desktop updater channel; latest.json is replaced by release-desktop.yml." \ - --prerelease \ - --latest=false \ - --target "$GITHUB_SHA" - gh api "repos/${GITHUB_REPOSITORY}/releases/tags/desktop-latest" > "$channel_json" - fi - - python3 <<'PY' - import json - import os - import pathlib - import sys - - channel = json.loads(pathlib.Path(os.environ['RUNNER_TEMP'], 'desktop-latest-release.json').read_text()) - if channel.get('draft'): - sys.exit('desktop-latest release is draft; refusing to publish updater channel') - if channel.get('immutable'): - sys.exit('desktop-latest release is immutable; cannot replace latest.json') - if not channel.get('prerelease'): - sys.exit('desktop-latest release must be a prerelease so it cannot compete with repo-wide latest') - PY - - - name: Prevent updater channel downgrade - shell: bash - env: - GH_TOKEN: ${{ github.token }} - run: | - set -euo pipefail - mkdir -p "$RUNNER_TEMP/desktop-current" - if ! gh release download desktop-latest --pattern latest.json --dir "$RUNNER_TEMP/desktop-current" --clobber 2>/dev/null; then - echo "No existing desktop-latest latest.json found; allowing first channel publish." - exit 0 - fi - python3 <<'PY' - import json - import os - import pathlib - import re - import sys - - def parse(value: str): - value = value.removeprefix('v') - match = re.fullmatch( - r'(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)' - r'(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?' - r'(?:\+[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*)?', - value, - ) - if not match: - sys.exit(f'desktop-latest latest.json has invalid version: {value}') - major, minor, patch, prerelease = match.groups() - return (int(major), int(minor), int(patch), prerelease) - - def numeric_tail(identifier: str) -> tuple[str, int] | None: - match = re.fullmatch(r'([A-Za-z-]+)(\d+)', identifier) - if not match: - return None - return (match.group(1).lower(), int(match.group(2))) - - def compare_identifier(left: str, right: str) -> int: - left_num = left.isdigit() - right_num = right.isdigit() - if left_num and right_num: - return (int(left) > int(right)) - (int(left) < int(right)) - if left_num: - return -1 - if right_num: - return 1 - - left_tail = numeric_tail(left) - right_tail = numeric_tail(right) - if left_tail and right_tail and left_tail[0] == right_tail[0]: - return (left_tail[1] > right_tail[1]) - (left_tail[1] < right_tail[1]) - - return (left > right) - (left < right) - - def compare_prerelease(left: str | None, right: str | None) -> int: - if left == right: - return 0 - if left is None: - return 1 - if right is None: - return -1 - left_parts = left.split('.') - right_parts = right.split('.') - for left_part, right_part in zip(left_parts, right_parts): - order = compare_identifier(left_part, right_part) - if order: - return order - return (len(left_parts) > len(right_parts)) - (len(left_parts) < len(right_parts)) - - def compare(left: str, right: str) -> int: - left_major, left_minor, left_patch, left_pre = parse(left) - right_major, right_minor, right_patch, right_pre = parse(right) - left_core = (left_major, left_minor, left_patch) - right_core = (right_major, right_minor, right_patch) - if left_core != right_core: - return (left_core > right_core) - (left_core < right_core) - return compare_prerelease(left_pre, right_pre) - - current_path = pathlib.Path(os.environ['RUNNER_TEMP'], 'desktop-current', 'latest.json') - next_path = pathlib.Path(os.environ['RUNNER_TEMP'], 'desktop-updater', 'latest.json') - current = json.loads(current_path.read_text()).get('version') - next_version = json.loads(next_path.read_text()).get('version') - if not isinstance(current, str) or not isinstance(next_version, str): - sys.exit('Could not compare desktop-latest channel versions') - if compare(next_version, current) < 0: - sys.exit( - f'Refusing to move desktop-latest from {current} to older version {next_version}.' - ) - PY - - - name: Publish desktop updater channel metadata - shell: bash - env: - GH_TOKEN: ${{ github.token }} - run: | - set -euo pipefail - gh release upload desktop-latest "$RUNNER_TEMP/desktop-updater/latest.json" --clobber - gh api "repos/${GITHUB_REPOSITORY}/releases/tags/desktop-latest" > "$RUNNER_TEMP/desktop-latest-release.json" - python3 <<'PY' - import json - import os - import pathlib - import sys - - channel = json.loads(pathlib.Path(os.environ['RUNNER_TEMP'], 'desktop-latest-release.json').read_text()) - assets = [asset for asset in channel.get('assets', []) if asset.get('name') == 'latest.json'] - if len(assets) != 1: - sys.exit(f'Expected exactly one desktop-latest latest.json asset, found {len(assets)}') - expected_url = f'https://github.com/{os.environ["GITHUB_REPOSITORY"]}/releases/download/desktop-latest/latest.json' - actual_url = assets[0].get('browser_download_url') - if actual_url != expected_url: - sys.exit(f'desktop-latest latest.json URL mismatch: expected {expected_url}, got {actual_url}') - PY diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml deleted file mode 100644 index 1a4cf841d0..0000000000 --- a/.github/workflows/stale.yml +++ /dev/null @@ -1,37 +0,0 @@ -name: 'Inactive Issue Pinger' - -on: - schedule: - - cron: '30 5 * * *' # Runs at 5:30 UTC every day - -jobs: - stale: - runs-on: ubuntu-latest - permissions: - issues: write - - steps: - - uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10.2.0 - with: - # The message to post on stale issues. - # This message will ping the issue author. - # Note: The stale bot action does not currently support a direct placeholder for the last commenter. - # As a workaround, this message encourages any participant to reply. - stale-issue-message: > - Is this issue still important to you? - Apologies in advance we might have missed this issue as well. - For faster response times, please post on our Reddit server - https://www.reddit.com/r/unsloth or our Discord - https://discord.com/invite/unsloth - - # The number of days of inactivity before an issue is considered stale. - days-before-issue-stale: 9999 - - # Set to -1 to never close stale issues. - days-before-issue-close: -1 - - # A label to apply to stale issues. - stale-issue-label: 'inactive' - - # The number of operations to perform per run to avoid rate limiting. - operations-per-run: 500 - - enable-statistics: false diff --git a/.github/workflows/studio-frontend-ci.yml b/.github/workflows/studio-frontend-ci.yml deleted file mode 100644 index 1270a57ef6..0000000000 --- a/.github/workflows/studio-frontend-ci.yml +++ /dev/null @@ -1,151 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. - -# Frontend PR gate: lockfile freshness, typecheck, build, and a bundle grep -# that catches the 2026.5.1 chat-history regression at the JS level. -# -# biome runs as non-blocking for now: the codebase currently has accumulated -# ~470 errors and ~1650 warnings against the existing biome config. Surfacing -# the count in CI lets us drive it down without forcing a fleet-wide cleanup -# in the same PR. Drop `continue-on-error` once that number is zero. - -name: Frontend CI - -on: - pull_request: - paths: - - 'studio/frontend/**' - - 'scripts/check_frontend_dep_removal.py' - - 'tests/studio/test_frontend_dep_removal.py' - - '.github/workflows/studio-frontend-ci.yml' - push: - branches: [main, pip] - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -permissions: - contents: read - -jobs: - build: - name: Frontend build + bundle sanity - runs-on: ubuntu-latest - timeout-minutes: 10 - defaults: - run: - working-directory: studio/frontend - steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - persist-credentials: false - - # FIXME: drop this step once @assistant-ui/* and assistant-stream - # leave 0.x -- on 1.x, caret ranges are conventional. Until then, - # every 0.minor on this surface is a SemVer-major (this is exactly - # how 2026.5.1 shipped a broken chat runtime: ^0.12.19 quietly - # resolved to 0.12.28). - - name: '@assistant-ui must be pinned exactly (no caret/tilde)' - working-directory: ${{ github.workspace }} - run: | - set -e - if grep -nE '"(@assistant-ui/[a-z-]+|assistant-stream)":[[:space:]]*"[\^~]' studio/frontend/package.json; then - echo "::error file=studio/frontend/package.json::These packages must be pinned to exact versions until they leave 0.x. Drop the leading ^ or ~." - exit 1 - fi - echo "All assistant-ui packages are pinned exactly." - - - uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0 - with: - node-version: '22' - - # Run the structural lockfile scan BEFORE npm ci. A compromised - # tarball runs its `prepare` / `postinstall` during `npm ci`, - # so any catch has to fire upstream of that. The scanner is - # pure-Python read-only; safe to call ahead of every install. - - name: Lockfile supply-chain audit (pre-install scan) - working-directory: ${{ github.workspace }} - run: python3 scripts/lockfile_supply_chain_audit.py - - - name: Lockfile must agree with package.json (npm ci is strict) - # Lifecycle scripts (esbuild native-binary postinstall, etc.) are - # required for `vite build`. The pre-install lockfile structural - # audit (lockfile_supply_chain_audit.py) is the practical defence - # against the npm postinstall-dropper class -- it fires BEFORE any - # tarball runs, on the injection pattern itself rather than an - # advisory-DB lookup. - run: npm ci --no-fund --no-audit - - - name: npm ci must not have modified the working tree - working-directory: ${{ github.workspace }} - run: | - if ! git diff --quiet -- studio/frontend; then - echo "::error::npm ci modified files; commit the updated lockfile" - git status -- studio/frontend - exit 1 - fi - - # Catch the common foot-gun: a dep dropped from package.json that is - # still imported somewhere. The script walks the lockfile dep graph - # from the new top-level deps and only counts top-level node_modules - # paths as valid resolution targets for bare src/ imports. - # - # actions/checkout uses fetch-depth: 1 by default, so the base branch - # is not available locally. Fetch the single base commit with an - # explicit refspec so origin/ is reliably created (a bare - # `git fetch origin ` only updates FETCH_HEAD in some configs). - - name: Dependency removal safety check - if: github.event_name == 'pull_request' - working-directory: ${{ github.workspace }} - run: | - git fetch --no-tags --depth=1 origin \ - "${{ github.base_ref }}:refs/remotes/origin/${{ github.base_ref }}" - python3 scripts/check_frontend_dep_removal.py \ - --base "origin/${{ github.base_ref }}" \ - --enumerate-dead - python3 tests/studio/test_frontend_dep_removal.py - - - name: Typecheck - run: npm run typecheck - - - name: Build - run: npm run build - - - name: Built bundle must not contain Studio's unstable_Provider call site - run: | - set -e - JS=$(ls dist/assets/index-*.js | head -1) - HITS=$(grep -c 'unstable_Provider:' "$JS" || echo 0) - echo "main bundle: $JS" - echo "unstable_Provider: hits=$HITS (assistant-ui internals contribute up to 3)" - if [ "$HITS" -gt 3 ]; then - echo "::error file=studio/frontend/src/features/chat/runtime-provider.tsx::Studio bundle still passes unstable_Provider through useRemoteThreadListRuntime; this is the 2026.5.1 chat-history regression. Pass adapters directly into useLocalRuntime instead." - exit 1 - fi - - - name: Bundle size budget (75 MB) - run: | - SIZE=$(du -sb dist | cut -f1) - BUDGET=$((75 * 1024 * 1024)) - echo "dist size: $SIZE bytes ($((SIZE/1024/1024)) MB), budget: $BUDGET bytes (75 MB)" - if [ "$SIZE" -gt "$BUDGET" ]; then - echo "::error::studio/frontend/dist/ exceeded the 75 MB budget. Drop dead deps (e.g. the unused next dep) or split chunks." - exit 1 - fi - - - name: Biome (non-blocking until accumulated drift is cleared) - continue-on-error: true - run: npm run biome:check - - - name: Upload built dist - # Always upload so a green run is reviewable too -- the dist - # output catches "tests passed but bundle changed unexpectedly" - # regressions that would be invisible if we only kept artifacts - # on failure. - if: always() - uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 - with: - name: studio-frontend-dist - path: studio/frontend/dist - retention-days: 3 diff --git a/.github/workflows/studio-inference-smoke.yml b/.github/workflows/studio-inference-smoke.yml deleted file mode 100644 index 6def56f769..0000000000 --- a/.github/workflows/studio-inference-smoke.yml +++ /dev/null @@ -1,1052 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. - -# Three end-to-end smoke jobs that boot a freshly-installed Studio and -# exercise the surfaces real users hit through the OpenAI / Anthropic -# SDKs and curl. Each job picks the smallest model that exercises the -# behaviour under test, primes HF_HOME via actions/cache, and shares -# the install.sh --local --no-torch bootstrap. -# -# 1. OpenAI, Anthropic API tests -# gemma-3-270m-it UD-Q4_K_XL (~254 MiB). -# Password rotation via /api/auth/change-password (old fails, -# new works), then OpenAI + Anthropic Python SDKs against /v1/* -# with temperature=0 and a fixed seed. Asserts the four-turn -# conversation is deterministic across two runs. -# -# 2. Tool calling Tests -# Qwen3.5-2B UD-IQ3_XXS (~890 MiB). OpenAI function calling, -# server-side tools (python, terminal, web_search) via -# enable_tools / enabled_tools, and enable_thinking on/off. -# -# 3. JSON, images -# gemma-4-E2B-it UD-IQ3_XXS (~2.4 GiB) + mmproj-F16 (~986 MiB). -# response_format JSON-schema decoding and OpenAI image_url -# (data URI) plus Anthropic source/base64 image inputs. -# -# All three jobs run in parallel. Total wall time is dominated by job 3 -# on a cold cache; warm cache cuts that to ~3 min. - -name: Studio GGUF CI - -on: - pull_request: - paths: - - 'studio/**' - - 'unsloth/**' - - 'unsloth_cli/**' - - 'install.sh' - - 'pyproject.toml' - - '.github/workflows/studio-inference-smoke.yml' - push: - branches: [main, pip] - # Manual trigger for pre-warming HF_HOME caches on main, or re-running - # against an arbitrary branch without pushing a no-op commit. - workflow_dispatch: - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -permissions: - contents: read - -jobs: - # ───────────────────────────────────────────────────────────────────── - # Job 1: OpenAI, Anthropic API tests - # ───────────────────────────────────────────────────────────────────── - openai-anthropic: - name: OpenAI, Anthropic API tests - runs-on: ubuntu-latest - timeout-minutes: 25 - env: - GGUF_REPO: unsloth/gemma-3-270m-it-GGUF - GGUF_VARIANT: UD-Q4_K_XL - GGUF_FILE: gemma-3-270m-it-UD-Q4_K_XL.gguf - STUDIO_PORT: '18888' - HF_HOME: ${{ github.workspace }}/hf-cache - steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - persist-credentials: false - - - name: Linux deps for llama.cpp prebuilt - run: | - sudo apt-get update - sudo apt-get install -y --no-install-recommends \ - libcurl4-openssl-dev libssl-dev jq - - - uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0 - with: - node-version: '22' - - - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 - with: - python-version: '3.12' - cache: 'pip' - - - name: Restore HF_HOME for ${{ env.GGUF_REPO }} - id: cache-hf - uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 - continue-on-error: true - with: - path: hf-cache - key: ${{ runner.os }}-hf-${{ env.GGUF_REPO }}-${{ env.GGUF_VARIANT }}-v1 - - - name: Prime HF_HOME with the GGUF - id: prime-hf - if: steps.cache-hf.outputs.cache-hit != 'true' || steps.cache-hf.outcome != 'success' - env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} - run: | - python -m pip install --upgrade huggingface_hub - mkdir -p hf-cache - bash .github/scripts/hf-download-with-retry.sh "$GGUF_REPO" "$GGUF_FILE" - - - name: Save HF_HOME for ${{ env.GGUF_REPO }} - if: always() && steps.prime-hf.outcome == 'success' - uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 - with: - path: hf-cache - key: ${{ runner.os }}-hf-${{ env.GGUF_REPO }}-${{ env.GGUF_VARIANT }}-v1 - - - name: Install Studio (--local, --no-torch) - env: - GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: | - mkdir -p logs - set -o pipefail - bash install.sh --local --no-torch 2>&1 | tee logs/install.log - - - name: Install OpenAI + Anthropic Python SDKs - run: pip install 'openai>=1.50' 'anthropic>=0.40' - - - name: Reset auth + boot Studio (API-only) - run: | - unsloth studio reset-password - mkdir -p logs - UNSLOTH_API_ONLY=1 unsloth studio -H 127.0.0.1 -p "$STUDIO_PORT" \ - > logs/studio.log 2>&1 & - echo "STUDIO_PID=$!" >> "$GITHUB_ENV" - - - name: Wait for /api/health - run: | - for i in $(seq 1 180); do - if curl -fs "http://127.0.0.1:${STUDIO_PORT}/api/health" > /tmp/health.json; then - jq -e '.status == "healthy"' /tmp/health.json - exit 0 - fi - sleep 1 - done - echo "Studio did not become healthy in 180s" - tail -200 logs/studio.log - exit 1 - - - name: Password rotation (old must fail, new must work) - run: | - OLD=$(cat ~/.unsloth/studio/auth/.bootstrap_password) - NEW="CIRotated-$(python -c 'import secrets; print(secrets.token_urlsafe(12))')" - echo "::add-mask::$OLD" - echo "::add-mask::$NEW" - # 1. Login with the bootstrap password. - OLD_TOKEN=$(curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/auth/login" \ - -H 'content-type: application/json' \ - -d "{\"username\":\"unsloth\",\"password\":\"$OLD\"}" | jq -r .access_token) - [ -n "$OLD_TOKEN" ] && [ "$OLD_TOKEN" != "null" ] || { echo "bootstrap login failed"; exit 1; } - # 2. Rotate to a fresh random password. - curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/auth/change-password" \ - -H "Authorization: Bearer $OLD_TOKEN" -H 'content-type: application/json' \ - -d "{\"current_password\":\"$OLD\",\"new_password\":\"$NEW\"}" > /dev/null - # 3. Old password must now be rejected (HTTP 401). - OLD_STATUS=$(curl -s -o /dev/null -w '%{http_code}' \ - -X POST "http://127.0.0.1:${STUDIO_PORT}/api/auth/login" \ - -H 'content-type: application/json' \ - -d "{\"username\":\"unsloth\",\"password\":\"$OLD\"}") - if [ "$OLD_STATUS" != "401" ]; then - echo "::error::Login with old password returned $OLD_STATUS, expected 401" - exit 1 - fi - # 4. New password must succeed; capture the JWT for downstream steps. - NEW_TOKEN=$(curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/auth/login" \ - -H 'content-type: application/json' \ - -d "{\"username\":\"unsloth\",\"password\":\"$NEW\"}" | jq -r .access_token) - [ -n "$NEW_TOKEN" ] && [ "$NEW_TOKEN" != "null" ] || { echo "new login failed"; exit 1; } - echo "TOKEN=$NEW_TOKEN" >> "$GITHUB_ENV" - echo "password rotation OK (old=401, new=200)" - - - name: Load the GGUF (HF repo + variant, served from HF_HOME cache) - run: | - curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/inference/load" \ - -H "Authorization: Bearer $TOKEN" -H 'content-type: application/json' \ - --max-time 600 \ - -d "{\"model_path\":\"$GGUF_REPO\",\"gguf_variant\":\"$GGUF_VARIANT\",\"is_lora\":false,\"max_seq_length\":2048}" \ - | jq '{status, display_name, is_gguf, context_length}' - - - name: Multi-turn determinism via OpenAI + Anthropic SDKs - env: - BASE_URL: http://127.0.0.1:18888 - run: | - python - <<'PY' - import json - import os - from openai import OpenAI - from anthropic import Anthropic - - BASE = os.environ["BASE_URL"] - KEY = os.environ["TOKEN"] # JWT also accepted as Bearer on /v1/* - SEED = 3407 - - # Four-turn conversation: the second and fourth turns can only be - # answered correctly if the model sees the prior turns, so this - # also exercises the conversation-history wiring. - PROMPTS = [ - "What is 1+1?", - "What did I ask before?", - "What is the capital of France?", - "Repeat the city name", - ] - - def run_openai(): - client = OpenAI(base_url = f"{BASE}/v1", api_key = KEY) - history, replies = [], [] - for prompt in PROMPTS: - history.append({"role": "user", "content": prompt}) - resp = client.chat.completions.create( - model = "default", - messages = history, - temperature = 0.0, - max_tokens = 80, - seed = SEED, - extra_body = {"enable_thinking": False}, - ) - text = resp.choices[0].message.content or "" - replies.append(text) - history.append({"role": "assistant", "content": text}) - return replies - - def run_anthropic(): - # Two SDK quirks vs. Studio: - # 1. base_url must NOT include /v1 -- the SDK appends - # /v1/messages itself; otherwise the request hits - # /v1/v1/messages and 405s. - # 2. The SDK sends `x-api-key` by default, but Studio's - # auth layer is HTTPBearer-only. Override via - # default_headers so Authorization: Bearer ... is - # sent instead. - client = Anthropic( - base_url = BASE, - api_key = "unused", - default_headers = {"Authorization": f"Bearer {KEY}"}, - ) - history, replies = [], [] - for prompt in PROMPTS: - history.append({"role": "user", "content": prompt}) - msg = client.messages.create( - model = "default", - max_tokens = 80, - messages = history, - temperature = 0.0, - extra_body = {"seed": SEED, "enable_thinking": False}, - ) - text = "".join(b.text for b in msg.content if getattr(b, "type", None) == "text") - replies.append(text) - history.append({"role": "assistant", "content": text}) - return replies - - for label, runner in (("openai", run_openai), ("anthropic", run_anthropic)): - first = runner() - second = runner() - determinism_failures = [] - for i, (a, b) in enumerate(zip(first, second), start = 1): - print(f"[{label} turn {i}] {a!r}") - # Both runs must be non-empty; small-quant drift - # across runs is WARN-only (grounding asserts below - # are the stronger signal). - assert a, f"{label}: empty turn {i} response in first run" - assert b, f"{label}: empty turn {i} response in second run" - if a.strip() != b.strip(): - determinism_failures.append( - f"turn {i}: run1={a!r} run2={b!r}" - ) - if determinism_failures: - print( - f"[{label}] WARN non-determinism at temperature=0.0 across " - f"{len(determinism_failures)} of {len(first)} turn(s); " - f"small-quant model drift, not a Studio regression. " - f"Details: " + " | ".join(determinism_failures) - ) - # Sanity: turn-2 reply should mention the earlier question, and - # turn-4 reply should mention Paris (model echoes the city it - # produced for turn 3). Lower-cased substring checks keep the - # assertion robust to formatting jitter. - joined = " ".join(first).lower() - assert "1" in first[0], f"{label}: turn-1 answer should contain '1', got {first[0]!r}" - assert "paris" in joined, f"{label}: expected 'paris' somewhere in the four-turn transcript: {first}" - status_word = "PASS" if not determinism_failures else "PASS (with drift)" - print(f"[{label}] {status_word} -- 4 turns, history grounded ('paris' present)") - PY - - - name: Stop Studio - if: always() - run: | - kill "${STUDIO_PID}" 2>/dev/null || true - sleep 2 - ss -tln | grep ":${STUDIO_PORT}" || true - - - name: Upload logs - # Always upload so green runs are still reviewable. - if: always() - uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 - with: - name: openai-anthropic-log - path: | - logs/studio.log - logs/install.log - retention-days: 7 - - # ───────────────────────────────────────────────────────────────────── - # Job 2: Tool calling Tests - # ───────────────────────────────────────────────────────────────────── - tool-calling: - name: Tool calling Tests - runs-on: ubuntu-latest - timeout-minutes: 25 - env: - # Tool calling is the highest-volume GGUF in this workflow - # (Qwen3.5-2B at IQ3_XXS = ~890 MiB). Caching HF_HOME would - # store xet chunks + blobs + snapshots = ~4 GiB compressed -- - # 4-5x file-size inflation, dominated by xet chunks. Use main's - # `--local-dir gguf-cache` pattern to cache the flat .gguf only. - # Studio's /api/inference/load accepts either a HF repo (which - # uses HF_HOME) or an absolute file path; passing the absolute - # path keeps the test off HF_HOME entirely so the cache size - # tracks the GGUF file 1:1. The OpenAI/Anth and JSON+images - # jobs still cover the gguf_variant resolution path. - GGUF_REPO: unsloth/Qwen3.5-2B-GGUF - GGUF_FILE: Qwen3.5-2B-UD-IQ3_XXS.gguf - STUDIO_PORT: '18889' - steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - persist-credentials: false - - - name: Linux deps for llama.cpp prebuilt - run: | - sudo apt-get update - sudo apt-get install -y --no-install-recommends \ - libcurl4-openssl-dev libssl-dev jq - - - uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0 - with: - node-version: '22' - - - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 - with: - python-version: '3.12' - cache: 'pip' - - - name: Restore GGUF model file - id: cache-gguf - uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 - continue-on-error: true - with: - path: gguf-cache - key: ${{ runner.os }}-gguf-${{ env.GGUF_REPO }}-${{ env.GGUF_FILE }}-v1 - - - name: Download GGUF if cache miss - id: download-gguf - if: steps.cache-gguf.outputs.cache-hit != 'true' || steps.cache-gguf.outcome != 'success' - env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} - run: | - python -m pip install --upgrade huggingface_hub - mkdir -p gguf-cache - bash .github/scripts/hf-download-with-retry.sh "$GGUF_REPO" "$GGUF_FILE" gguf-cache - - - name: Save GGUF model file - if: always() && steps.download-gguf.outcome == 'success' - uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 - with: - path: gguf-cache - key: ${{ runner.os }}-gguf-${{ env.GGUF_REPO }}-${{ env.GGUF_FILE }}-v1 - - - name: Install Studio (--local, --no-torch) - env: - GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: | - mkdir -p logs - set -o pipefail - bash install.sh --local --no-torch 2>&1 | tee logs/install.log - - - name: Reset auth + boot Studio (API-only, default tool policy) - # We deliberately use the API-only mode rather than - # `unsloth studio run` because the latter calls - # `set_tool_policy(...)` with a resolved bool: on loopback the - # default resolves to True, which forces every request through - # the server-side agentic loop and breaks the standard - # function-calling test below. API-only mode leaves - # tool_policy=None so each request's `enable_tools` field is - # honoured. - run: | - unsloth studio reset-password - mkdir -p logs - UNSLOTH_API_ONLY=1 unsloth studio -H 127.0.0.1 -p "$STUDIO_PORT" \ - > logs/studio.log 2>&1 & - echo "STUDIO_PID=$!" >> "$GITHUB_ENV" - - - name: Wait for /api/health, log in, change password, load model - run: | - for i in $(seq 1 180); do - if curl -fs "http://127.0.0.1:${STUDIO_PORT}/api/health" > /tmp/health.json; then - jq -e '.status == "healthy"' /tmp/health.json && break - fi - sleep 1 - done - jq -e '.status == "healthy"' /tmp/health.json - OLD=$(cat ~/.unsloth/studio/auth/.bootstrap_password) - NEW="CITool-$(python -c 'import secrets; print(secrets.token_urlsafe(12))')" - echo "::add-mask::$OLD" - echo "::add-mask::$NEW" - OLD_TOKEN=$(curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/auth/login" \ - -H 'content-type: application/json' \ - -d "{\"username\":\"unsloth\",\"password\":\"$OLD\"}" | jq -r .access_token) - curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/auth/change-password" \ - -H "Authorization: Bearer $OLD_TOKEN" -H 'content-type: application/json' \ - -d "{\"current_password\":\"$OLD\",\"new_password\":\"$NEW\"}" > /dev/null - TOKEN=$(curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/auth/login" \ - -H 'content-type: application/json' \ - -d "{\"username\":\"unsloth\",\"password\":\"$NEW\"}" | jq -r .access_token) - echo "API_KEY=$TOKEN" >> "$GITHUB_ENV" - GGUF_PATH="$GITHUB_WORKSPACE/gguf-cache/${GGUF_FILE}" - ls -lh "$GGUF_PATH" - curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/inference/load" \ - -H "Authorization: Bearer $TOKEN" -H 'content-type: application/json' \ - --max-time 600 \ - -d "{\"model_path\":\"$GGUF_PATH\",\"is_lora\":false,\"max_seq_length\":2048}" \ - | jq '{status, display_name}' - - - name: Tool calling, server-side tools, thinking on/off - env: - BASE_URL: http://127.0.0.1:18889 - run: | - python - <<'PY' - import json - import os - import urllib.request - - BASE = os.environ["BASE_URL"] - KEY = os.environ["API_KEY"] - SEED = 3407 - - def post(path, body, *, timeout = 240): - """Plain JSON POST. For requests that don't go through - the server-side agentic loop, the response is one JSON - object.""" - data = json.dumps(body).encode() - req = urllib.request.Request( - f"{BASE}{path}", - data = data, - method = "POST", - headers = { - "Authorization": f"Bearer {KEY}", - "Content-Type": "application/json", - }, - ) - with urllib.request.urlopen(req, timeout = timeout) as resp: - return resp.status, json.loads(resp.read().decode()) - - def post_sse(path, body, *, timeout = 600): - """POST a streaming request and accumulate the assistant - text deltas. The server-side agentic loop ALWAYS returns - SSE regardless of the request's `stream` field, so any - call with enable_tools=true must use this helper. - - Returns (content, raw_payloads): - content -- concatenated assistant delta.content - raw_payloads -- list of every raw "data: ..." event - payload (JSON strings). Callers asserting - that a server-side tool actually ran (and - not just that the model emitted some - text) should grep raw_payloads for tool - invocation markers / tool output, since - `delta.content` alone is not evidence - that the tool path executed. - """ - body = {**body, "stream": True} - data = json.dumps(body).encode() - req = urllib.request.Request( - f"{BASE}{path}", - data = data, - method = "POST", - headers = { - "Authorization": f"Bearer {KEY}", - "Content-Type": "application/json", - }, - ) - parts = [] - events = [] - with urllib.request.urlopen(req, timeout = timeout) as resp: - for raw in resp: - line = raw.decode().strip() - if not line.startswith("data: "): - continue - payload = line[6:] - if payload == "[DONE]": - break - events.append(payload) - try: - chunk = json.loads(payload) - except json.JSONDecodeError: - continue - for choice in chunk.get("choices", []): - delta = choice.get("delta", {}) or {} - if delta.get("content"): - parts.append(delta["content"]) - return "".join(parts), events - - _STUDIO_TOOL_TYPES = { - "tool_start", "tool_end", "tool_use", "tool_result", - } - - def _tool_invoked(events): - """Structural check: True iff some SSE payload is a real - tool envelope (Studio tool_start/tool_end, Anthropic - tool_use/tool_result, OpenAI non-empty delta.tool_calls / - message.tool_calls / finish_reason='tool_calls' / - role:'tool' / function_call). tool_status is NOT - evidence: Studio emits empty tool_status events on - iteration boundaries even when no tool ran. - """ - for raw in events: - try: - ev = json.loads(raw) - except (json.JSONDecodeError, TypeError): - continue - if not isinstance(ev, dict): - continue - if ev.get("type") in _STUDIO_TOOL_TYPES: - return True - for choice in ev.get("choices", []) or []: - if not isinstance(choice, dict): - continue - if choice.get("finish_reason") == "tool_calls": - return True - for src_key in ("delta", "message"): - src = choice.get(src_key) or {} - if not isinstance(src, dict): - continue - tc = src.get("tool_calls") - if isinstance(tc, list) and tc: - return True - if src.get("function_call"): - return True - if src.get("role") == "tool": - return True - for item in ev.get("output", []) or []: - if isinstance(item, dict) and item.get("type") in { - "tool_call", "function_call", "tool_use", - }: - return True - content = ev.get("content") - if isinstance(content, list): - for blk in content: - if isinstance(blk, dict) and blk.get("type") in { - "tool_use", "tool_result", - }: - return True - return False - - def _tool_output_contains(events, *needles): - """True iff any tool_end.result / tool_result.content / - tool-role message content contains a needle. Inspects - the tool's own output, not the model's narration.""" - for raw in events: - try: - ev = json.loads(raw) - except (json.JSONDecodeError, TypeError): - continue - if not isinstance(ev, dict): - continue - if ev.get("type") == "tool_end": - result = ev.get("result") - if isinstance(result, str) and any(n in result for n in needles if n): - return True - if ev.get("type") == "tool_result": - content = ev.get("content") - if isinstance(content, str) and any(n in content for n in needles if n): - return True - if isinstance(content, list): - for blk in content: - if isinstance(blk, dict): - text = blk.get("text") or blk.get("content") - if isinstance(text, str) and any(n in text for n in needles if n): - return True - for choice in ev.get("choices", []) or []: - delta = (choice or {}).get("delta") or {} - msg = (choice or {}).get("message") or {} - for src in (delta, msg): - if src.get("role") == "tool": - content = src.get("content") or "" - if isinstance(content, str) and any(n in content for n in needles if n): - return True - return False - - # ── 1. Standard OpenAI function calling ────────────────────── - weather_tool = { - "type": "function", - "function": { - "name": "get_weather", - "description": "Get current weather for a city.", - "parameters": { - "type": "object", - "properties": {"city": {"type": "string"}}, - "required": ["city"], - }, - }, - } - - status, data = post("/v1/chat/completions", { - "messages": [{"role": "user", "content": "What is the weather in Paris?"}], - "tools": [weather_tool], - "tool_choice": "required", - "stream": False, - "temperature": 0.0, - "seed": SEED, - "max_tokens": 120, - }) - assert status == 200, f"tool call status {status}: {data}" - choice = data["choices"][0] - assert choice["finish_reason"] == "tool_calls", f"finish_reason={choice['finish_reason']!r}" - tc = choice["message"]["tool_calls"][0] - assert tc["function"]["name"] == "get_weather" - args = json.loads(tc["function"]["arguments"]) - assert args.get("city"), f"missing city arg: {args}" - print(f"[tools] PASS function calling -> {tc['function']['name']}({args})") - - # T=0 = deterministic argmax in llama.cpp; T>0 lets seed - # rotation explore distinct trajectories on retry. - TOOL_PROBE_TEMP = 0.4 - - def _run_tool_probe(*, label, prompt, enabled, session, needles, - max_attempts = 4): - """Drive a server-side tool with retries. Hard FAIL if no - attempt has structural invocation evidence. WARN (not - FAIL) if invoked but no attempt produces the expected - literal in tool_end.result -- small-quant Qwen3.5-2B can - emit OpenAI tool_calls deltas without Studio's GGUF - agentic loop intercepting them, and that GGUF-vs-OpenAI - format mismatch is out of scope for #5642. - """ - attempts_log = [] - best = None - for attempt_i in range(max_attempts): - attempt_seed = SEED + attempt_i - content, events = post_sse("/v1/chat/completions", { - "messages": [{"role": "user", "content": prompt}], - "enable_tools": True, - "enabled_tools": enabled, - "session_id": f"{session}-att{attempt_i}", - "temperature": TOOL_PROBE_TEMP, - "seed": attempt_seed, - "max_tokens": 600, - }) - invoked = _tool_invoked(events) - produced = _tool_output_contains(events, *needles) - attempts_log.append({ - "attempt": attempt_i, "seed": attempt_seed, - "n_events": len(events), - "tool_invoked": invoked, "tool_output_contains": produced, - "content_len": len(content), - }) - if invoked and produced: - print(f"[tools] PASS {label} attempt {attempt_i}") - return content, events, attempts_log - if invoked and best is None: - best = (content, events) - print(f"[tools] retry {label} attempt {attempt_i}: invoked={invoked} output_ok={produced} events={len(events)}") - if best is not None: - print(f"[tools] WARN {label}: invoked but no tool_end.result match (small-quant flake). Attempts: {attempts_log}") - content, events = best - return content, events, attempts_log - raise AssertionError( - f"{label}: no structural tool-invocation evidence across " - f"{max_attempts} attempts. enable_tools may be silently " - f"ignored. Attempts: {attempts_log}" - ) - - # ── 2. Server-side python tool ─────────────────────────────── - content, events, _attempts = _run_tool_probe( - label = "python tool", - prompt = "What is 123 * 456? Use the python tool to compute it and tell me the number.", - enabled = ["python"], - session = "ci-tool-calling-py", - needles = ("56088", "56,088"), - ) - if "56088" in content or "56,088" in content: - print(f"[tools] python tool narration OK") - else: - print(f"[tools] python tool narration drifted -- content={content!r}") - - # ── 3. Server-side bash (terminal) tool ────────────────────── - content, events, _attempts = _run_tool_probe( - label = "bash/terminal tool", - prompt = "Use the terminal tool to run `echo hello-bash-tool` and tell me the exact output.", - enabled = ["terminal"], - session = "ci-tool-calling-bash", - needles = ("hello-bash-tool",), - ) - if "hello-bash-tool" in content: - print(f"[tools] bash/terminal narration OK") - else: - print(f"[tools] bash/terminal narration dropped literal -- content={content!r}") - - # ── 4. Server-side web_search tool ─────────────────────────── - # DuckDuckGo is flaky from CI runners and small Qwen3.5-2B - # may not actually search. Only assert that the SSE stream - # opens and yields any data; HTTP / parser failures already - # raise above. Tool-invocation strictness is relaxed here - # because (a) the search may legitimately return no results, - # and (b) DuckDuckGo upstream blocks GHA IP ranges often - # enough that requiring a tool_call marker would create - # red-herring failures from infra rather than from Studio. - try: - content, events = post_sse("/v1/chat/completions", { - "messages": [{"role": "user", "content": "Search the web for 'unsloth ai github' and summarise."}], - "enable_tools": True, - "enabled_tools": ["web_search"], - "session_id": "ci-tool-calling-web", - "temperature": 0.0, - "seed": SEED, - "max_tokens": 400, - }) - print( - f"[tools] PASS web_search stream ({len(content)} chars in content, " - f"{len(events)} raw events)" - ) - except Exception as exc: - print(f"[tools] WARN web_search probe failed (non-blocking): {exc}") - - # ── 5. Thinking on / off ───────────────────────────────────── - # Studio strips think blocks from message.content for tools-mode - # responses, so we toggle plain chat (no enable_tools) and look - # at the surfaced reasoning_content / message.thinking field. - def thinking_call(enable): - status, data = post("/v1/chat/completions", { - "messages": [{"role": "user", "content": "Briefly: is 17 prime?"}], - "stream": False, - "enable_thinking": enable, - "temperature": 0.0, - "seed": SEED, - "max_tokens": 300, - }) - assert status == 200 - msg = data["choices"][0]["message"] - # Studio surfaces thinking via reasoning_content (OpenAI - # extension). Fall back to inline markers for - # robustness across template versions. - raw = (msg.get("content") or "") + (msg.get("reasoning_content") or "") - return raw - - on_text = thinking_call(True) - off_text = thinking_call(False) - had_think_on = ("" in on_text) or len(on_text) > 80 - had_think_off = ("" in off_text) and len(off_text) > 0 - assert had_think_on, ( - f"enable_thinking=True produced no thinking signal: {on_text!r}" - ) - # Off-mode should not contain the literal marker. - assert "" not in off_text, ( - f"enable_thinking=False but still present: {off_text!r}" - ) - print(f"[tools] PASS thinking on/off (on={len(on_text)} chars, off={len(off_text)} chars)") - PY - - - name: Stop Studio - if: always() - run: | - kill "${STUDIO_PID}" 2>/dev/null || true - sleep 2 - ss -tln | grep ":${STUDIO_PORT}" || true - - - name: Upload logs - # Always upload so green runs are still reviewable. - if: always() - uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 - with: - name: tool-calling-log - path: | - logs/studio.log - logs/install.log - retention-days: 7 - - # ───────────────────────────────────────────────────────────────────── - # Job 3: JSON, images - # ───────────────────────────────────────────────────────────────────── - json-images: - name: JSON, images - runs-on: ubuntu-latest - timeout-minutes: 30 - env: - GGUF_REPO: unsloth/gemma-4-E2B-it-GGUF - GGUF_VARIANT: UD-IQ3_XXS - GGUF_FILE: gemma-4-E2B-it-UD-IQ3_XXS.gguf - MMPROJ_FILE: mmproj-F16.gguf - STUDIO_PORT: '18890' - HF_HOME: ${{ github.workspace }}/hf-cache - steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - persist-credentials: false - - - name: Linux deps for llama.cpp prebuilt - run: | - sudo apt-get update - sudo apt-get install -y --no-install-recommends \ - libcurl4-openssl-dev libssl-dev jq - - - uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0 - with: - node-version: '22' - - - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 - with: - python-version: '3.12' - cache: 'pip' - - - name: Restore HF_HOME for ${{ env.GGUF_REPO }} (model + mmproj) - id: cache-hf - uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 - continue-on-error: true - with: - path: hf-cache - key: ${{ runner.os }}-hf-${{ env.GGUF_REPO }}-${{ env.GGUF_VARIANT }}-${{ env.MMPROJ_FILE }}-v1 - - - name: Prime HF_HOME with the GGUF + mmproj - id: prime-hf - if: steps.cache-hf.outputs.cache-hit != 'true' || steps.cache-hf.outcome != 'success' - env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} - run: | - python -m pip install --upgrade huggingface_hub - mkdir -p hf-cache - bash .github/scripts/hf-download-with-retry.sh "$GGUF_REPO" "$GGUF_FILE" - bash .github/scripts/hf-download-with-retry.sh "$GGUF_REPO" "$MMPROJ_FILE" - - - name: Save HF_HOME for ${{ env.GGUF_REPO }} (model + mmproj) - if: always() && steps.prime-hf.outcome == 'success' - uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 - with: - path: hf-cache - key: ${{ runner.os }}-hf-${{ env.GGUF_REPO }}-${{ env.GGUF_VARIANT }}-${{ env.MMPROJ_FILE }}-v1 - - - name: Install Studio (--local, --no-torch) - env: - GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: | - mkdir -p logs - set -o pipefail - bash install.sh --local --no-torch 2>&1 | tee logs/install.log - - - name: Install OpenAI + Anthropic Python SDKs - run: pip install 'openai>=1.50' 'anthropic>=0.40' - - - name: Reset auth + boot Studio (API-only) - # See Job 2's comment: API-only mode keeps tool_policy=None so - # response_format requests aren't routed through the agentic - # tool loop. - run: | - unsloth studio reset-password - mkdir -p logs - UNSLOTH_API_ONLY=1 unsloth studio -H 127.0.0.1 -p "$STUDIO_PORT" \ - > logs/studio.log 2>&1 & - echo "STUDIO_PID=$!" >> "$GITHUB_ENV" - - - name: Wait for /api/health, log in, change password, load model - run: | - for i in $(seq 1 180); do - if curl -fs "http://127.0.0.1:${STUDIO_PORT}/api/health" > /tmp/health.json; then - jq -e '.status == "healthy"' /tmp/health.json && break - fi - sleep 1 - done - jq -e '.status == "healthy"' /tmp/health.json - OLD=$(cat ~/.unsloth/studio/auth/.bootstrap_password) - NEW="CIJson-$(python -c 'import secrets; print(secrets.token_urlsafe(12))')" - echo "::add-mask::$OLD" - echo "::add-mask::$NEW" - OLD_TOKEN=$(curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/auth/login" \ - -H 'content-type: application/json' \ - -d "{\"username\":\"unsloth\",\"password\":\"$OLD\"}" | jq -r .access_token) - curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/auth/change-password" \ - -H "Authorization: Bearer $OLD_TOKEN" -H 'content-type: application/json' \ - -d "{\"current_password\":\"$OLD\",\"new_password\":\"$NEW\"}" > /dev/null - TOKEN=$(curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/auth/login" \ - -H 'content-type: application/json' \ - -d "{\"username\":\"unsloth\",\"password\":\"$NEW\"}" | jq -r .access_token) - echo "API_KEY=$TOKEN" >> "$GITHUB_ENV" - # Load the GGUF (mmproj is auto-detected via the HF repo - # lookup, the cached file is pulled out of HF_HOME). - curl -fs -X POST "http://127.0.0.1:${STUDIO_PORT}/api/inference/load" \ - -H "Authorization: Bearer $TOKEN" -H 'content-type: application/json' \ - --max-time 900 \ - -d "{\"model_path\":\"$GGUF_REPO\",\"gguf_variant\":\"$GGUF_VARIANT\",\"is_lora\":false,\"max_seq_length\":2048}" \ - | jq '{status, display_name, is_vision}' - - - name: JSON schema decoding + image input - env: - BASE_URL: http://127.0.0.1:18890 - run: | - python - <<'PY' - import base64 - import json - import os - import urllib.request - from openai import OpenAI - from anthropic import Anthropic - - BASE = os.environ["BASE_URL"] - KEY = os.environ["API_KEY"] - SEED = 3407 - - def post(path, body, *, timeout = 240): - req = urllib.request.Request( - f"{BASE}{path}", - data = json.dumps(body).encode(), - method = "POST", - headers = { - "Authorization": f"Bearer {KEY}", - "Content-Type": "application/json", - }, - ) - with urllib.request.urlopen(req, timeout = timeout) as resp: - return resp.status, json.loads(resp.read().decode()) - - # ── 1. response_format = json_object (JSON mode) ───────────── - # llama.cpp's HTTP server supports OpenAI-compatible JSON - # mode: `response_format: {"type": "json_object"}` constrains - # the model to emit syntactically-valid JSON. We use raw HTTP - # rather than the OpenAI SDK so that the field shape Studio - # forwards to llama-server is unambiguous (the SDK rewrites - # response_format depending on which variant it recognises). - # We deliberately do NOT pass a strict JSON schema -- on - # small Gemma-4 quants the GBNF-from-schema path occasionally - # produces empty output, and JSON mode is the surface we care - # about exposing through Studio. - status, data = post("/v1/chat/completions", { - "model": "default", - "messages": [ - {"role": "system", "content": 'Reply with a single JSON object of the form {"city": "...", "country": "..."}. Output ONLY the JSON, nothing else.'}, - {"role": "user", "content": "What is the capital of France?"}, - ], - "temperature": 0.0, - "max_tokens": 200, - "seed": SEED, - "stream": False, - "enable_thinking": False, - "response_format": {"type": "json_object"}, - }, timeout = 600) - assert status == 200, f"json status {status}: {data}" - content = (data["choices"][0]["message"].get("content") or "").strip() - # Some chat templates wrap JSON in ```json fences even in JSON - # mode -- strip those before parsing. - if content.startswith("```"): - content = content.split("```", 2)[1] - if content.startswith("json"): - content = content[4:] - content = content.strip("`\n ") - parsed = json.loads(content) - assert "paris" in str(parsed.get("city", "")).lower(), ( - f"city != Paris: {parsed}" - ) - print(f"[json] PASS json_object -> {parsed}") - - # ── 2. OpenAI image_url (data URI base64) ─────────────────── - # 64x64 solid-red PNG. stb_image (used by Studio's image - # normaliser at routes/inference.py:3410) rejects 4x4 or - # smaller PNGs as truncated, so we go up to 64x64 -- still - # tiny in token cost. The assertion is loose: any non-empty - # response from the vision path proves multimodal end-to-end - # wiring; small VL quants are weak at colour identification. - PNG_64X64_RED_B64 = ( - "iVBORw0KGgoAAAANSUhEUgAAAEAAAABACAIAAAAlC+aJAAAAYklEQVR4nO3PMQ0AIADAMEAI/k" - "UhBhEcDcmqYJtn7/GzpQNeNaA1oDWgNaA1oDWgNaA1oDWgNaA1oDWgNaA1oDWgNaA1oDWgNaA" - "1oDWgNaA1oDWgNaA1oDWgNaA1oDWgNaA1oDWgNaBdCJ0BmMJ25zMAAAAASUVORK5CYII=" - ) - data_uri = f"data:image/png;base64,{PNG_64X64_RED_B64}" - - client = OpenAI(base_url = f"{BASE}/v1", api_key = KEY) - openai_resp = client.chat.completions.create( - model = "default", - temperature = 0.0, - max_tokens = 80, - seed = SEED, - messages = [{ - "role": "user", - "content": [ - {"type": "image_url", "image_url": {"url": data_uri}}, - {"type": "text", "text": "What colour dominates this image? Reply in one word."}, - ], - }], - ) - openai_text = (openai_resp.choices[0].message.content or "").lower() - print(f"[image/openai] reply: {openai_text!r}") - assert openai_text, "OpenAI image_url returned empty content" - # We do not strictly require 'red' -- some quants of small VL - # models are weak at colour names. Just require a non-empty - # answer; the vision path is the part under test. - print("[image/openai] PASS image_url accepted, non-empty response") - - # ── 3. Anthropic source/base64 image ──────────────────────── - # Two SDK quirks vs. Studio: base_url must NOT include /v1 - # (the SDK appends it itself; otherwise /v1/v1/messages -> 405), - # and Studio's auth is HTTPBearer-only so the SDK's default - # x-api-key header is ignored -- send Authorization: Bearer - # via default_headers. - anthropic = Anthropic( - base_url = BASE, - api_key = "unused", - default_headers = {"Authorization": f"Bearer {KEY}"}, - ) - a_msg = anthropic.messages.create( - model = "default", - max_tokens = 80, - temperature = 0.0, - extra_body = {"seed": SEED}, - messages = [{ - "role": "user", - "content": [ - { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/png", - "data": PNG_64X64_RED_B64, - }, - }, - {"type": "text", "text": "Describe this image briefly."}, - ], - }], - ) - a_text = "".join(b.text for b in a_msg.content if getattr(b, "type", None) == "text") - print(f"[image/anthropic] reply: {a_text!r}") - assert a_text, "Anthropic source/base64 returned empty content" - print("[image/anthropic] PASS source/base64 accepted, non-empty response") - PY - - - name: Stop Studio - if: always() - run: | - kill "${STUDIO_PID}" 2>/dev/null || true - sleep 2 - ss -tln | grep ":${STUDIO_PORT}" || true - - - name: Upload logs - # Always upload so green runs are still reviewable. - if: always() - uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 - with: - name: json-images-log - path: | - logs/studio.log - logs/install.log - retention-days: 7 diff --git a/.github/workflows/studio-tauri-smoke.yml b/.github/workflows/studio-tauri-smoke.yml deleted file mode 100644 index 1156c264ae..0000000000 --- a/.github/workflows/studio-tauri-smoke.yml +++ /dev/null @@ -1,128 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. - -# PR-time smoke for the Tauri desktop wrapper. Builds the frontend and the -# Tauri Linux debug binary, with no codesigning. Catches: -# - tauri.conf.json drift -# - src-tauri Cargo.toml or rust source breakage -# - Tauri CLI version drift (we pin 2.10.1, matching release-desktop.yml) -# - frontend output not picked up by Tauri's distDir -# -# Linux-only on a free `ubuntu-latest` runner. Mac and Windows desktop builds -# stay in release-desktop.yml (manual `workflow_dispatch`) because they need -# code-signing secrets and ~30 min of runner time each. - -name: Studio Tauri CI - -on: - pull_request: - paths: - - 'studio/frontend/**' - - 'studio/src-tauri/**' - # CLI rename / signature change can break Tauri's spawned - # `unsloth studio` -- include unsloth_cli in the trigger set. - - 'unsloth_cli/**' - - '.github/workflows/studio-tauri-smoke.yml' - push: - branches: [main, pip] - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -permissions: - contents: read - -jobs: - linux-debug-build: - name: Tauri Linux debug build (no codesign) - runs-on: ubuntu-22.04 - timeout-minutes: 25 - steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - persist-credentials: false - - - name: Linux native deps for Tauri / WebKit2GTK - run: | - sudo apt-get update - sudo apt-get install -y \ - libwebkit2gtk-4.1-dev libayatana-appindicator3-dev \ - librsvg2-dev libxdo-dev libssl-dev patchelf - - - uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0 - with: - node-version: '24' - - - uses: dtolnay/rust-toolchain@29eef336d9b2848a0b548edc03f92a220660cdb8 # stable @ 2026-03-27 - - - uses: swatinem/rust-cache@e18b497796c12c097a38f9edb9d0641fb99eee32 # v2.9.1 - with: - workspaces: studio/src-tauri -> target - - - name: Install pinned Tauri CLI (matches release-desktop.yml) - # Lifecycle scripts (esbuild native-binary postinstall, etc.) are - # required for `vite build`. The pre-install lockfile structural - # audit (lockfile_supply_chain_audit.py) is the practical defence - # against the npm postinstall-dropper class -- it fires BEFORE any - # tarball runs, on the injection pattern itself rather than an - # advisory-DB lookup. - run: npm install --save-dev --prefix studio @tauri-apps/cli@2.10.1 --no-fund --no-audit - - - name: Verify pinned Tauri CLI version - run: | - out="$(npx --prefix studio tauri --version)" - echo "$out" - [ "$out" = "tauri-cli 2.10.1" ] || { echo "::error::expected tauri-cli 2.10.1, got $out"; exit 1; } - - - name: Lockfile supply-chain audit (pre-install scan) - run: python3 scripts/lockfile_supply_chain_audit.py - - - name: Frontend build (npm ci, vite) - working-directory: studio/frontend - # Lifecycle scripts (esbuild native-binary postinstall, etc.) are - # required for `vite build`. The pre-install lockfile structural - # audit (lockfile_supply_chain_audit.py) is the practical defence - # against the npm postinstall-dropper class -- it fires BEFORE any - # tarball runs, on the injection pattern itself rather than an - # advisory-DB lookup. - run: | - npm ci --no-fund --no-audit - npm run build - test -f dist/index.html - - - name: Tauri debug build (Linux, no bundle, no codesign) - # `--debug` + `--no-bundle` keeps this lean: compiles the Rust crate, - # confirms the frontend dist is wired into Tauri, but skips the AppImage - # / .deb production. Code signing is irrelevant because we never produce - # a distributable artifact. - env: - TAURI_SIGNING_PRIVATE_KEY: '' - TAURI_SIGNING_PRIVATE_KEY_PASSWORD: '' - run: npx --prefix studio tauri build --debug --no-bundle - - - name: Inspect produced binary - run: | - BIN=$(find studio/src-tauri/target/debug -maxdepth 1 -type f -executable 2>/dev/null \ - | grep -Ev '\.(d|so|dylib|dll)$' \ - | grep -Ev '/(deps|build|examples)$' \ - | head -1) - echo "binary: $BIN" - if [ -z "$BIN" ]; then - echo "::error::Tauri debug binary not produced" - ls -la studio/src-tauri/target/debug/ || true - exit 1 - fi - file "$BIN" - du -h "$BIN" - - - name: Upload Tauri debug build - # Always upload so a green run leaves the binary inspectable too. - if: always() - uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 - with: - name: tauri-debug-build - path: | - studio/src-tauri/target/debug - studio/frontend/dist - retention-days: 3 diff --git a/.github/workflows/wheel-smoke.yml b/.github/workflows/wheel-smoke.yml deleted file mode 100644 index 3de3c33ca2..0000000000 --- a/.github/workflows/wheel-smoke.yml +++ /dev/null @@ -1,136 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. - -# Builds the PyPI wheel from the PR branch, then verifies the built wheel -# actually contains what we expect to ship and does NOT contain the broken -# Studio bundle that 2026.5.1 published. This is the single workflow that -# would have blocked the 2026.5.1 release before twine upload. -# -# Verified locally end-to-end against this branch: -# - python -m build produces unsloth--py3-none-any.whl in 13s -# - wheel content sanity passes: -# lockfile shipped, frontend dist shipped, -# no node_modules in wheel, no bun.lock in wheel, -# main bundle has unstable_Provider hits=1 (assistant-ui internals only). -# - Studio backend imports cleanly from the installed wheel with the -# lightweight dep set below. - -name: Wheel CI - -on: - pull_request: - paths: - - 'pyproject.toml' - - 'studio/**' - - 'unsloth/**' - - 'unsloth_cli/**' - - '.github/workflows/wheel-smoke.yml' - push: - branches: [main, pip] - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -permissions: - contents: read - -jobs: - wheel: - name: Wheel build + content sanity + import smoke - runs-on: ubuntu-latest - timeout-minutes: 15 - steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - persist-credentials: false - - - uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0 - with: - node-version: '22' - - - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 - with: - python-version: '3.12' - - - name: Lockfile supply-chain audit (pre-install scan) - run: python3 scripts/lockfile_supply_chain_audit.py - - - name: Build frontend - # Lifecycle scripts (esbuild native-binary postinstall, etc.) are - # required for `vite build`. The pre-install lockfile structural - # audit (lockfile_supply_chain_audit.py) is the practical defence - # against the npm postinstall-dropper class -- it fires BEFORE any - # tarball runs, on the injection pattern itself rather than an - # advisory-DB lookup. - run: | - cd studio/frontend - npm ci --no-fund --no-audit - npm run build - - - name: Build wheel + sdist - run: | - python -m pip install --upgrade pip build - rm -rf dist build ./*.egg-info - python -m build - - - name: Wheel content sanity - run: | - python - <<'PY' - import zipfile, glob, sys - w = glob.glob("dist/unsloth-*.whl") - if not w: - print("FAIL: no wheel produced"); sys.exit(2) - w = w[0] - print(f"wheel: {w}") - with zipfile.ZipFile(w) as z: - n = z.namelist() - checks = { - "lockfile shipped": any(s.endswith("studio/frontend/package-lock.json") for s in n), - "frontend dist shipped": any(s.endswith("studio/frontend/dist/index.html") for s in n), - "no node_modules": not any("studio/frontend/node_modules/" in s for s in n), - "no bun.lock": not any(s.endswith("studio/frontend/bun.lock") for s in n), - } - js = [s for s in n - if "studio/frontend/dist/assets/" in s - and s.endswith(".js") - and "/index-" in s] - if not js: - print("FAIL: no main bundle index-*.js in wheel"); sys.exit(2) - data = z.read(js[0]).decode("utf-8", "replace") - hits = data.count("unstable_Provider:") - print(f"main bundle: {js[0]}") - print(f"unstable_Provider hits: {hits} (>=4 indicates 2026.5.1 regression)") - checks["bundle has no Studio unstable_Provider call site"] = (hits < 4) - - print() - for k, v in checks.items(): - print(f" [{'PASS' if v else 'FAIL'}] {k}") - sys.exit(0 if all(checks.values()) else 1) - PY - - - name: Studio backend import smoke - # Imports `studio.backend.main:app` from the freshly-installed wheel in - # a clean venv. This catches the class of bug that 2026.5.1 shipped with: - # frontend dist missing, package-lock.json missing, or the wheel's Python - # source tree broken in a way that surfaces only at app construction time. - run: | - python -m venv /tmp/v - /tmp/v/bin/pip install --upgrade pip - /tmp/v/bin/pip install -r studio/backend/requirements/studio.txt - /tmp/v/bin/pip install \ - python-multipart aiofiles sqlalchemy cryptography \ - pyyaml jinja2 mammoth unpdf requests \ - 'numpy<3' - /tmp/v/bin/pip install --no-deps dist/unsloth-*.whl - # Run from /tmp so Python imports the installed package, not the source tree. - cd /tmp - /tmp/v/bin/python -c "from studio.backend.main import app; print('Studio backend OK:', app.title)" - - - name: Upload wheel on failure - if: failure() - uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 - with: - name: unsloth-wheel - path: dist/ - retention-days: 7 diff --git a/.gitignore b/.gitignore index a839633790..da33583e29 100644 --- a/.gitignore +++ b/.gitignore @@ -235,3 +235,9 @@ package-lock.json !studio/backend/core/data_recipe/oxc-validator/package-lock.json !studio/package-lock.json llama.cpp/ +/.omc +/studio/frontend/.omc +/.codex +/studio/.omc +/studio/backend/.omc +*.patch diff --git a/studio/backend/assets/configs/model_defaults/other/deepseek-ai_DeepSeek-OCR.yaml b/studio/backend/assets/configs/model_defaults/other/deepseek-ai_DeepSeek-OCR.yaml new file mode 100644 index 0000000000..b827a1f910 --- /dev/null +++ b/studio/backend/assets/configs/model_defaults/other/deepseek-ai_DeepSeek-OCR.yaml @@ -0,0 +1,22 @@ +# Model defaults for deepseek-ai/DeepSeek-OCR +# Custom-code OCR vision model. Used by Studio chat as a temporary OCR +# model swap during scanned-PDF extraction; never used for training. + +model: + identifier: deepseek-ai/DeepSeek-OCR + display_name: DeepSeek-OCR + is_vision: true + is_ocr: true + +training: + trust_remote_code: true + max_seq_length: 8192 + packing: false + +inference: + trust_remote_code: true + temperature: 0.0 + top_p: 1.0 + top_k: -1 + min_p: 0.0 + default_max_seq_length: 8192 diff --git a/studio/backend/assets/configs/model_defaults/other/unsloth_PaddleOCR-VL.yaml b/studio/backend/assets/configs/model_defaults/other/unsloth_PaddleOCR-VL.yaml index b7587bbd91..bffb79902c 100644 --- a/studio/backend/assets/configs/model_defaults/other/unsloth_PaddleOCR-VL.yaml +++ b/studio/backend/assets/configs/model_defaults/other/unsloth_PaddleOCR-VL.yaml @@ -3,6 +3,12 @@ # Also applies to: unsloth/PaddleOCR-VL # added inference parameters from unsloth notebook +model: + identifier: unsloth/PaddleOCR-VL + display_name: PaddleOCR-VL + is_vision: true + is_ocr: true + training: trust_remote_code: true max_seq_length: 2048 @@ -50,6 +56,11 @@ logging: inference: trust_remote_code: true - temperature: 1.5 - min_p: 0.1 + # OCR is a closed-form transcription task; sibling OCR presets + # (DeepSeek-OCR, GLM-OCR) use deterministic decoding so the + # transcription is reproducible. Match that convention here. + temperature: 0.0 + min_p: 0.0 + top_p: 1.0 + top_k: -1 diff --git a/studio/backend/assets/configs/model_defaults/other/zai-org_GLM-OCR.yaml b/studio/backend/assets/configs/model_defaults/other/zai-org_GLM-OCR.yaml new file mode 100644 index 0000000000..2249aa4487 --- /dev/null +++ b/studio/backend/assets/configs/model_defaults/other/zai-org_GLM-OCR.yaml @@ -0,0 +1,22 @@ +# Model defaults for zai-org/GLM-OCR +# GLM family OCR vision model with model_type "glm_ocr". Used by Studio chat +# as a temporary OCR model swap during scanned-PDF extraction. + +model: + identifier: zai-org/GLM-OCR + display_name: GLM-OCR + is_vision: true + is_ocr: true + +training: + trust_remote_code: true + max_seq_length: 8192 + packing: false + +inference: + trust_remote_code: true + temperature: 0.0 + top_p: 1.0 + top_k: -1 + min_p: 0.0 + default_max_seq_length: 8192 diff --git a/studio/backend/core/chat/__init__.py b/studio/backend/core/chat/__init__.py new file mode 100644 index 0000000000..8ce71de2e8 --- /dev/null +++ b/studio/backend/core/chat/__init__.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 + +""" +Chat-surface helpers that do not belong in ``core/inference`` (tightly +coupled to model backends) and explicitly not in ``core/data_recipe`` +(owns dataset pipelines). + +Exposes the document-extraction pipeline used when a user drops a +PDF / DOCX / HTML / MD / TXT file into the chat composer. PDF parsing +uses PyMuPDF4LLM, DOCX uses mammoth. PPTX is not supported here — +convert to PDF first. +""" + +from __future__ import annotations + +from .document_extractor import ( + DOCUMENT_EXTRACTION_AVAILABLE, + DEFAULT_DOCUMENT_VISUAL_PAYLOADS, + DocumentExtractionBusy, + DocumentExtractionCancelled, + DocumentExtractionEncrypted, + DocumentExtractionTimeout, + DocumentExtractionUnavailable, + ExtractedFigure, + ExtractResult, + _EXTRACT_CONCURRENCY, + MAX_DOCUMENT_VISUAL_PAYLOADS, + SUPPORTED_MIME_TYPES, + SUPPORTED_SUFFIXES, + _EXTRACT_SEMAPHORE, + _drain_future_exception, + document_parser_support, + document_parser_unavailable_reasons, + extract_document, +) +from .vlm_capability import ( + VlmCapability, + detect_loaded_vlm, + extract_self_base_url, +) + +__all__ = [ + "DOCUMENT_EXTRACTION_AVAILABLE", + "DEFAULT_DOCUMENT_VISUAL_PAYLOADS", + "DocumentExtractionBusy", + "DocumentExtractionCancelled", + "DocumentExtractionEncrypted", + "DocumentExtractionTimeout", + "DocumentExtractionUnavailable", + "ExtractedFigure", + "ExtractResult", + "_EXTRACT_CONCURRENCY", + "MAX_DOCUMENT_VISUAL_PAYLOADS", + "SUPPORTED_MIME_TYPES", + "SUPPORTED_SUFFIXES", + "VlmCapability", + "_EXTRACT_SEMAPHORE", + "_drain_future_exception", + "detect_loaded_vlm", + "document_parser_support", + "document_parser_unavailable_reasons", + "extract_document", + "extract_self_base_url", +] diff --git a/studio/backend/core/chat/document_extractor.py b/studio/backend/core/chat/document_extractor.py new file mode 100644 index 0000000000..915fc596c2 --- /dev/null +++ b/studio/backend/core/chat/document_extractor.py @@ -0,0 +1,1243 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 + +""" +Document extractor for the Chat composer. + +Given raw file bytes (PDF / DOCX / HTML / MD / TXT), produce Markdown +suitable to splice into an outgoing chat message. When a vision-capable +model is loaded, selected figures are captioned through our OpenAI-compatible +``/v1/chat/completions`` surface after conversion. + +This build uses **PyMuPDF4LLM** (via ``pymupdf4llm`` / ``pymupdf``) for PDF +parsing and **mammoth** for DOCX conversion. Plain-text and Markdown inputs +are decoded as UTF-8 with replacement; HTML inputs are converted to Markdown. + +Notes and limitations: + +* **OCR is disabled.** There is no local OCR pass in this build, so scanned + PDFs without a text layer will yield empty or near-empty Markdown. The + ``use_vlm_ocr`` flag is still accepted for API compatibility; when set it + renders bounded page images so a loaded vision model can describe them. +* **PPTX is not supported** in this build. ``SUPPORTED_SUFFIXES`` and + ``SUPPORTED_MIME_TYPES`` no longer advertise the PowerPoint types. +* Parser dependencies are checked per format so plain-text, Markdown, and HTML + still work when optional PDF or DOCX libraries are missing. +* If the loaded model is not vision-capable, image description is silently + skipped and ``figures`` comes back with captions set to ``None``; + ``describe_skipped_reason`` carries the diagnostic text. +""" + +from __future__ import annotations + +import asyncio +import base64 +import inspect +import io +import logging +import math +import multiprocessing +import os +import queue +import threading +import time +from dataclasses import dataclass, field, replace +from typing import Any, Awaitable, Callable, Literal, List, Optional + +from .vlm_capability import VlmCapability, detect_loaded_vlm + + +logger = logging.getLogger(__name__) + + +SUPPORTED_MIME_TYPES = frozenset( + { + "application/pdf", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/json", + "application/x-ndjson", + "application/xml", + "application/yaml", + "application/javascript", + "text/html", + "text/markdown", + "text/plain", + "text/csv", + "text/css", + "text/javascript", + "text/xml", + "text/yaml", + } +) + +SUPPORTED_SUFFIXES = frozenset( + { + ".pdf", + ".docx", + ".html", + ".htm", + ".md", + ".txt", + ".csv", + ".json", + ".jsonl", + ".yaml", + ".yml", + ".py", + ".js", + ".jsx", + ".ts", + ".tsx", + ".go", + ".rs", + ".java", + ".c", + ".cpp", + ".h", + ".hpp", + ".cs", + ".php", + ".rb", + ".swift", + ".kt", + ".kts", + ".scala", + ".sh", + ".bash", + ".zsh", + ".ps1", + ".sql", + ".toml", + ".ini", + ".cfg", + ".log", + ".xml", + ".css", + ".scss", + } +) + + +_DESCRIBE_PROMPT = ( + "Describe this figure in <=60 words. Focus on factual content " + "(axes, labels, captions, visible text, main objects). Do not " + "speculate beyond what is visible." +) + + +DEFAULT_DOCUMENT_VISUAL_PAYLOADS = 3 +MAX_DOCUMENT_VISUAL_PAYLOADS = 10 +_MAX_ENCODED_VISUALS = DEFAULT_DOCUMENT_VISUAL_PAYLOADS +_EXTRACT_TIMEOUT_SECONDS = 120 +_VLM_CAPTION_TOTAL_TIMEOUT_SECONDS = 180 +_LOCAL_VLM_CAPTION_CONCURRENCY = 1 +_DEFAULT_VLM_CAPTION_CONCURRENCY = 3 +_EXTRACT_CONCURRENCY = max( + 1, int(os.environ.get("UNSLOTH_STUDIO_EXTRACT_CONCURRENCY", "2")) +) +_EXTRACT_SEMAPHORE = threading.BoundedSemaphore(_EXTRACT_CONCURRENCY) +# Bounded queue wait: callers park here for a slot instead of failing fast +# with 503 when the worker pool is saturated. Tuned so a fast burst (e.g. +# multi-select 4 PDFs) drains naturally without surfacing busy errors, +# while truly stuck workers still time out via _EXTRACT_TIMEOUT_SECONDS. +_EXTRACT_QUEUE_WAIT_SECONDS = max( + 0.0, + float(os.environ.get("UNSLOTH_STUDIO_EXTRACT_QUEUE_WAIT", "60")), +) +_PAGE_RENDER_DPI = 150 +_MAX_PAGE_RENDER_PIXELS = 4_000_000 +_MIME_TO_SUFFIX = { + "application/pdf": ".pdf", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx", + "application/json": ".json", + "application/x-ndjson": ".jsonl", + "application/xml": ".xml", + "application/yaml": ".yaml", + "application/javascript": ".js", + "text/html": ".html", + "text/markdown": ".md", + "text/plain": ".txt", + "text/csv": ".csv", + "text/css": ".css", + "text/javascript": ".js", + "text/xml": ".xml", + "text/yaml": ".yaml", +} + +_PLAIN_TEXT_SUFFIXES = SUPPORTED_SUFFIXES - {".pdf", ".docx", ".html", ".htm"} + + +def _normalized_suffix(filename: str, content_type: str = "") -> str: + suffix = os.path.splitext(filename)[1].lower() + if suffix in SUPPORTED_SUFFIXES: + return suffix + mime = (content_type or "").split(";", 1)[0].strip().lower() + return _MIME_TO_SUFFIX.get(mime, suffix) + + +class DocumentExtractionUnavailable(RuntimeError): + """Document extraction backend is not installed or failed to import. + + The backend is PyMuPDF4LLM + mammoth for parsed document formats. + """ + + +class DocumentExtractionTimeout(RuntimeError): + """Raised when document parsing exceeds the 120-second worker limit.""" + + +class DocumentExtractionBusy(RuntimeError): + """Raised when the bounded document extraction worker pool is saturated.""" + + +class DocumentExtractionCancelled(RuntimeError): + """Raised when the caller cancels an in-flight extraction.""" + + +class DocumentExtractionEncrypted(RuntimeError): + """Raised when a PDF is encrypted and cannot be parsed without a password.""" + + +try: # pragma: no cover - presence depends on optional install + import pymupdf # type: ignore + import pymupdf4llm # type: ignore +except Exception as _pdf_extract_exc: # pragma: no cover + pymupdf = None # type: ignore[assignment] + pymupdf4llm = None # type: ignore[assignment] + _PDF_EXTRACTION_IMPORT_ERROR: Optional[BaseException] = _pdf_extract_exc +else: + _PDF_EXTRACTION_IMPORT_ERROR = None + +try: # pragma: no cover - presence depends on optional install + import mammoth # type: ignore +except Exception as _docx_extract_exc: # pragma: no cover + mammoth = None # type: ignore[assignment] + _DOCX_EXTRACTION_IMPORT_ERROR: Optional[BaseException] = _docx_extract_exc +else: + _DOCX_EXTRACTION_IMPORT_ERROR = None + +# The dispatcher can still extract plain text / code / data files when PDF or +# DOCX optional parsers are missing. Format-specific helpers raise +# DocumentExtractionUnavailable only when that format is actually requested. +DOCUMENT_EXTRACTION_AVAILABLE = True +_DOCUMENT_EXTRACTION_IMPORT_ERROR: Optional[BaseException] = ( + _PDF_EXTRACTION_IMPORT_ERROR or _DOCX_EXTRACTION_IMPORT_ERROR +) + + +def document_parser_support() -> dict[str, bool]: + return { + "pdf": _PDF_EXTRACTION_IMPORT_ERROR is None, + "docx": _DOCX_EXTRACTION_IMPORT_ERROR is None, + "html": True, + "text": True, + "data": True, + "code": True, + } + + +def document_parser_unavailable_reasons() -> dict[str, str]: + reasons: dict[str, str] = {} + if _PDF_EXTRACTION_IMPORT_ERROR is not None: + reasons["pdf"] = "PDF extraction requires pymupdf and pymupdf4llm." + if _DOCX_EXTRACTION_IMPORT_ERROR is not None: + reasons["docx"] = "DOCX extraction requires mammoth." + return reasons + + +@dataclass +class ExtractedFigure: + id: str + page: Optional[int] + caption: Optional[str] + error: Optional[str] = None + kind: Literal["figure", "page"] = "figure" + image_mime: Optional[str] = None + image_base64: Optional[str] = None + image_width: Optional[int] = None + image_height: Optional[int] = None + + +@dataclass +class ExtractResult: + markdown: str + figures: List[ExtractedFigure] = field(default_factory = list) + page_count: int = 0 + tokens_est: int = 0 + describe_skipped_reason: Optional[str] = None + vlm_source: Optional[str] = None + vlm_model: Optional[str] = None + image_input_available: bool = False + warnings: List[str] = field(default_factory = list) + + +ProgressCb = Callable[[dict], Awaitable[None]] + + +def _ensure_pdf_backend() -> None: + if pymupdf is None or pymupdf4llm is None: + if _PDF_EXTRACTION_IMPORT_ERROR is not None: + logger.debug( + "PDF extraction parser import failed: %s", + _PDF_EXTRACTION_IMPORT_ERROR, + ) + raise DocumentExtractionUnavailable( + "PDF extraction requires pymupdf and pymupdf4llm. Re-run Studio " + "setup to install the parser dependencies from " + "studio/backend/requirements/single-env/data-designer-deps.txt" + ) + + +def _ensure_docx_backend() -> None: + if mammoth is None: + if _DOCX_EXTRACTION_IMPORT_ERROR is not None: + logger.debug( + "DOCX extraction parser import failed: %s", + _DOCX_EXTRACTION_IMPORT_ERROR, + ) + raise DocumentExtractionUnavailable( + "DOCX extraction requires mammoth. Re-run Studio setup to install " + "the parser dependencies from " + "studio/backend/requirements/single-env/data-designer-deps.txt" + ) + + +def _estimate_tokens(text: str) -> int: + return max(0, len(text) // 4) + + +def _encode_pil_image_for_chat( + image: Any, +) -> tuple[Optional[str], Optional[int], Optional[int], Optional[str]]: + if image is None: + return None, None, None, None + try: + from PIL import Image as PILImage + + img = image.copy() + img.thumbnail((1600, 1600)) + if img.mode in ("RGBA", "LA"): + background = PILImage.new("RGB", img.size, (255, 255, 255)) + alpha = img.getchannel("A") + background.paste(img.convert("RGB"), mask = alpha) + img = background + elif img.mode != "RGB": + img = img.convert("RGB") + + out = io.BytesIO() + img.save(out, format = "JPEG", quality = 88, optimize = True) + encoded = base64.b64encode(out.getvalue()).decode("ascii") + return encoded, img.width, img.height, "image/jpeg" + except (ImportError, AttributeError, ValueError, OSError) as exc: + logger.warning("Failed to encode extracted document image", exc_info = exc) + return None, None, None, None + + +async def _describe_image_via_vlm( + *, + image_base64: str, + image_mime: str, + endpoint_url: str, + model_name: str, + authorization_header: Optional[str], + timeout_seconds: float, +) -> tuple[Optional[str], Optional[str]]: + try: + import httpx + except Exception as exc: + return None, f"httpx unavailable: {exc}" + + headers = {"Content-Type": "application/json"} + if authorization_header: + headers["Authorization"] = authorization_header + + data_url = f"data:{image_mime};base64,{image_base64}" + payload = { + "model": model_name, + "stream": False, + "max_tokens": 512, + "temperature": 0.2, + "top_p": 0.9, + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": _DESCRIBE_PROMPT}, + {"type": "image_url", "image_url": {"url": data_url}}, + ], + } + ], + } + try: + async with httpx.AsyncClient(timeout = timeout_seconds) as client: + response = await client.post( + endpoint_url.rstrip("/") + "/v1/chat/completions", + headers = headers, + json = payload, + ) + if response.status_code >= 400: + return None, ( + f"VLM caption request failed with HTTP " f"{response.status_code}" + ) + body = response.json() + choice = (body.get("choices") or [{}])[0] + message = choice.get("message") or {} + finish_reason = choice.get("finish_reason") + + # Some chat templates (Gemma 3/3n via llama-server, Qwen3 always-think) + # route the entire visible reply into ``reasoning_content`` and leave + # ``content`` empty. The chat UI handles this in its streaming + # consumer (see ``llama_cpp._chat_completion``); mirror that fallback + # here so non-streaming callers see the same answer. + candidates: list[Any] = [ + message.get("content"), + message.get("reasoning_content"), + message.get("text"), + ] + # Some servers return content as a list of parts (OpenAI multimodal); + # join any text parts into one string before checking emptiness. + normalized: list[str] = [] + for raw in candidates: + if isinstance(raw, str): + if raw.strip(): + normalized.append(raw.strip()) + elif isinstance(raw, list): + parts = [ + part.get("text", "") + for part in raw + if isinstance(part, dict) and isinstance(part.get("text"), str) + ] + joined = "".join(parts).strip() + if joined: + normalized.append(joined) + + if not normalized: + logger.warning( + "VLM caption empty: finish_reason=%r message_keys=%s", + finish_reason, + list(message.keys()), + ) + return None, (f"VLM caption empty (finish_reason={finish_reason!r})") + # Prefer the first non-empty candidate + # (content > reasoning_content > text). + return normalized[0], None + except Exception as exc: + logger.debug("VLM caption request failed", exc_info = True) + return None, f"VLM caption request failed: {type(exc).__name__}" + + +def _build_extract_options( + *, + extract_images: bool, + use_vlm_ocr: bool, + max_visual_payloads: int, +) -> tuple[dict, list[str]]: + """Return ``(options, build_warnings)``. + + The options dict is a simple bag of flags consumed by the synchronous + extract dispatcher. There is no local OCR pass available in this build; + ``use_vlm_ocr=True`` is implemented as a bounded full-page visual + extraction fallback for VLM captioning. + """ + build_warnings: list[str] = [] + if use_vlm_ocr: + build_warnings.append( + "Full-page OCR was requested, but this build has no local OCR " + "engine; rendered page images will be sent to the loaded vision " + "model when image description is enabled." + ) + options = { + "extract_images": bool(extract_images), + "use_vlm_ocr": bool(use_vlm_ocr), + "max_visual_payloads": max(0, max_visual_payloads), + } + return options, build_warnings + + +def _pymupdf4llm_markdown_kwargs() -> dict[str, Any]: + """Return kwargs supported by the installed pymupdf4llm.to_markdown().""" + preferred = { + "write_images": False, + "show_progress": False, + "ignore_images": True, + "table_strategy": "lines_strict", + "use_ocr": False, + "force_ocr": False, + } + try: + signature = inspect.signature(pymupdf4llm.to_markdown) + except (TypeError, ValueError): + return { + key: value + for key, value in preferred.items() + if key not in {"use_ocr", "force_ocr"} + } + params = signature.parameters + if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in params.values()): + return preferred + return {key: value for key, value in preferred.items() if key in params} + + +def _safe_page_pixmap(page: Any) -> Any: + rect = getattr(page, "rect", None) + width_pt = max(float(getattr(rect, "width", 0) or 0), 1.0) + height_pt = max(float(getattr(rect, "height", 0) or 0), 1.0) + scale = _PAGE_RENDER_DPI / 72.0 + projected_pixels = width_pt * scale * height_pt * scale + if projected_pixels > _MAX_PAGE_RENDER_PIXELS: + scale *= math.sqrt(_MAX_PAGE_RENDER_PIXELS / projected_pixels) + scale = max(scale, 0.05) + matrix = pymupdf.Matrix(scale, scale) # type: ignore[union-attr] + return page.get_pixmap(matrix = matrix, alpha = False) + + +def _append_page_image_figure( + doc: Any, + figures_out: list[ExtractedFigure], + *, + page_index: int, + max_figures: int, + encode_image: bool = True, +) -> bool: + if len(figures_out) >= max_figures: + return False + if not encode_image: + figures_out.append( + ExtractedFigure( + id = f"page-{page_index + 1}", + page = page_index + 1, + caption = None, + error = None, + kind = "page", + ) + ) + return True + try: + from PIL import Image as PILImage + + pix = _safe_page_pixmap(doc[page_index]) + png_bytes = pix.tobytes("png") + page_image = PILImage.open(io.BytesIO(png_bytes)) + image_base64, image_width, image_height, image_mime = ( + _encode_pil_image_for_chat(page_image) + ) + if not image_base64: + return False + figures_out.append( + ExtractedFigure( + id = f"page-{page_index + 1}", + page = page_index + 1, + caption = None, + error = None, + kind = "page", + image_mime = image_mime, + image_base64 = image_base64, + image_width = image_width, + image_height = image_height, + ) + ) + return True + except ( + ImportError, + MemoryError, + OverflowError, + ValueError, + OSError, + RuntimeError, + ) as exc: + logger.warning( + "Failed to render page %d preview for PDF", + page_index + 1, + exc_info = exc, + ) + return False + + +def _extract_pdf( + file_bytes: bytes, + max_figures: int, + use_vlm_ocr: bool, + max_visual_payloads: int, +) -> tuple[str, list[ExtractedFigure], int, int, int]: + """Extract Markdown + figures from a PDF via PyMuPDF4LLM. + + Returns ``(markdown, figures, page_count, truncated_count, seen)``. + """ + _ensure_pdf_backend() + assert pymupdf is not None and pymupdf4llm is not None # for type-checkers + + doc = pymupdf.open(stream = file_bytes, filetype = "pdf") + try: + # ``is_encrypted`` is True for any file with an /Encrypt dict + # (very common for Acrobat-distilled PDFs, scanner output, the + # classic Orimi test file). ``needs_pass`` is the real "user + # password required" signal. Refuse extraction only when an + # actual password is missing. + if getattr(doc, "needs_pass", False): + raise DocumentExtractionEncrypted( + "Encrypted PDF; provide a password before extracting it." + ) + markdown = pymupdf4llm.to_markdown(doc, **_pymupdf4llm_markdown_kwargs()) + + figures_out: list[ExtractedFigure] = [] + encoded_visuals = 0 + seen = 0 + truncated_count = 0 + page_count = len(doc) + + if max_figures > 0 and page_count > 0: + if use_vlm_ocr: + for page_index in range(page_count): + if len(figures_out) >= max_figures: + truncated_count += page_count - page_index + break + if _append_page_image_figure( + doc, + figures_out, + page_index = page_index, + max_figures = max_figures, + encode_image = encoded_visuals < max_visual_payloads, + ): + if figures_out[-1].image_base64: + encoded_visuals += 1 + seen += 1 + elif _append_page_image_figure( + doc, + figures_out, + page_index = 0, + max_figures = max_figures, + encode_image = encoded_visuals < max_visual_payloads, + ): + if figures_out[-1].image_base64: + encoded_visuals += 1 + + if not use_vlm_ocr: + try: + from PIL import Image as PILImage + + for page_index in range(page_count): + page = doc[page_index] + try: + images = page.get_images(full = True) + except (ValueError, RuntimeError) as exc: + logger.debug( + "page.get_images failed on page %d", + page_index + 1, + exc_info = exc, + ) + continue + for img_info in images: + xref = img_info[0] if img_info else 0 + if not xref: + continue + try: + extracted = doc.extract_image(xref) + except (ValueError, RuntimeError) as exc: + logger.debug( + "doc.extract_image failed for xref %s", + xref, + exc_info = exc, + ) + continue + if not extracted: + continue + raw_bytes = extracted.get("image") + if not raw_bytes: + continue + try: + pil_img = PILImage.open(io.BytesIO(raw_bytes)) + pil_img.load() + except (OSError, ValueError) as exc: + logger.debug( + "PIL failed to decode extracted image xref %s", + xref, + exc_info = exc, + ) + continue + if pil_img.width < 50 or pil_img.height < 50: + continue + seen += 1 + if len(figures_out) >= max_figures: + truncated_count += 1 + continue + image_base64 = None + image_width = None + image_height = None + image_mime = None + if encoded_visuals < max_visual_payloads: + ( + image_base64, + image_width, + image_height, + image_mime, + ) = _encode_pil_image_for_chat(pil_img) + if image_base64: + encoded_visuals += 1 + figures_out.append( + ExtractedFigure( + id = f"fig-{len(figures_out)}", + page = page_index + 1, + caption = None, + error = None, + kind = "figure", + image_mime = image_mime, + image_base64 = image_base64, + image_width = image_width, + image_height = image_height, + ) + ) + except ImportError as exc: + logger.warning( + "Pillow is unavailable; skipping embedded-image extraction", + exc_info = exc, + ) + + return markdown, figures_out, page_count, truncated_count, seen + finally: + try: + doc.close() + except Exception: # pragma: no cover - defensive + logger.debug("pymupdf doc.close() raised", exc_info = True) + + +def _extract_docx( + file_bytes: bytes, +) -> tuple[str, list[ExtractedFigure], int, int, int]: + _ensure_docx_backend() + assert mammoth is not None # for type-checkers + stream = io.BytesIO(file_bytes) + result = mammoth.convert_to_markdown(stream) + markdown = result.value or "" + return markdown, [], 0, 0, 0 + + +def _extract_plaintext( + file_bytes: bytes, +) -> tuple[str, list[ExtractedFigure], int, int, int]: + text = file_bytes.decode("utf-8", errors = "replace") + return text, [], 0, 0, 0 + + +def _extract_html( + file_bytes: bytes, +) -> tuple[str, list[ExtractedFigure], int, int, int]: + html = file_bytes.decode("utf-8", errors = "replace") + try: + from core.inference._html_to_md import html_to_markdown + except Exception as exc: + logger.warning( + "HTML-to-Markdown converter unavailable; using raw HTML", + exc_info = exc, + ) + return html, [], 0, 0, 0 + return html_to_markdown(html), [], 0, 0, 0 + + +def _run_extract_sync( + file_bytes: bytes, + filename: str, + options: dict, + content_type: str = "", +) -> tuple[str, list[ExtractedFigure], int, int, int]: + """Synchronous dispatch by file suffix. + + Returns ``(markdown, figures, page_count, truncated_count, seen)``. + """ + suffix = _normalized_suffix(filename, content_type) + extract_images = bool(options.get("extract_images")) + use_vlm_ocr = bool(options.get("use_vlm_ocr")) + max_figures = int(options.get("max_figures", 0)) if extract_images else 0 + max_visual_payloads = int( + options.get("max_visual_payloads", DEFAULT_DOCUMENT_VISUAL_PAYLOADS) + ) + + if suffix == ".pdf": + return _extract_pdf(file_bytes, max_figures, use_vlm_ocr, max_visual_payloads) + if suffix == ".docx": + return _extract_docx(file_bytes) + if suffix in {".html", ".htm"}: + return _extract_html(file_bytes) + if suffix in _PLAIN_TEXT_SUFFIXES: + return _extract_plaintext(file_bytes) + raise ValueError(f"Unsupported file type: {filename}") + + +_RUN_EXTRACT_SYNC_ORIGINAL = _run_extract_sync + + +def _run_extract_worker( + result_queue: Any, + file_bytes: bytes, + filename: str, + options: dict, + content_type: str, +) -> None: + try: + result_queue.put( + ("ok", _run_extract_sync(file_bytes, filename, options, content_type)) + ) + except DocumentExtractionUnavailable as exc: + result_queue.put(("extraction_unavailable", str(exc))) + except DocumentExtractionEncrypted as exc: + result_queue.put(("encrypted", str(exc))) + except ValueError as exc: + result_queue.put(("value_error", str(exc))) + except BaseException as exc: + result_queue.put(("error", type(exc).__name__, str(exc))) + + +def _drain_future_exception(fut: Any) -> None: + """Retrieve a future's exception (if any) so asyncio's gc-time + "Future exception was never retrieved" warning stays quiet when the + awaiting task is cancelled mid-flight (e.g. client disconnect or + AbortController abort).""" + try: + if fut.cancelled(): + return + fut.exception() + except BaseException: + # Never let a drain hook itself raise — best effort only. + pass + + +def _terminate_extract_process(proc: multiprocessing.Process) -> None: + if not proc.is_alive(): + return + proc.terminate() + proc.join(5) + if proc.is_alive() and hasattr(proc, "kill"): + proc.kill() + proc.join(2) + + +def _run_extract_process_sync( + file_bytes: bytes, + filename: str, + options: dict, + content_type: str, + timeout_seconds: int, + cancel_event: Optional[threading.Event] = None, +) -> tuple[str, list[ExtractedFigure], int, int, int]: + if cancel_event is not None and cancel_event.is_set(): + raise DocumentExtractionCancelled("document extraction was cancelled") + # Park up to _EXTRACT_QUEUE_WAIT_SECONDS waiting for a slot, polling + # cancel_event so a client disconnect during the wait short-circuits + # cleanly instead of holding the request open. + deadline = time.monotonic() + _EXTRACT_QUEUE_WAIT_SECONDS + acquired = _EXTRACT_SEMAPHORE.acquire(blocking = False) + while True: + if acquired: + break + if cancel_event is not None and cancel_event.is_set(): + raise DocumentExtractionCancelled("document extraction was cancelled") + remaining = deadline - time.monotonic() + if remaining <= 0: + break + wait = min(remaining, 0.5) + if _EXTRACT_SEMAPHORE.acquire(timeout = wait): + acquired = True + break + if not acquired: + raise DocumentExtractionBusy("document extraction is busy") + + # Everything past the semaphore acquisition must live inside the + # try/finally so the slot is released even if multiprocessing + # context creation / Queue allocation / Process construction + # itself raises (e.g. OSError on fork-resource exhaustion, EAGAIN + # on Windows under load). + result_queue = None + proc = None + try: + # Prefer "fork" only on Linux. macOS defaults to "spawn" in + # modern Python because Objective-C runtimes (loaded by + # PyMuPDF/CoreFoundation/Quartz) crash under fork. Windows has + # never supported fork. + import sys as _sys + if os.name == "nt" or _sys.platform == "darwin": + mp_method = "spawn" + else: + mp_method = "fork" + ctx = multiprocessing.get_context(mp_method) + result_queue = ctx.Queue(maxsize = 1) + proc = ctx.Process( + target = _run_extract_worker, + args = ( + result_queue, + file_bytes, + filename, + options, + content_type, + ), + daemon = True, + ) + if cancel_event is not None and cancel_event.is_set(): + raise DocumentExtractionCancelled("document extraction was cancelled") + proc.start() + deadline = time.monotonic() + timeout_seconds + message = None + while message is None: + try: + message = result_queue.get(timeout = 0.1) + break + except queue.Empty: + if cancel_event is not None and cancel_event.is_set(): + _terminate_extract_process(proc) + raise DocumentExtractionCancelled( + "document extraction was cancelled" + ) + if not proc.is_alive(): + # The worker may have put its result and exited + # between the queue.get timeout and this is_alive + # check. Drain the queue once more before declaring + # failure so a successful extraction is not lost. + try: + message = result_queue.get_nowait() + except queue.Empty: + pass + break + if time.monotonic() >= deadline: + _terminate_extract_process(proc) + raise DocumentExtractionTimeout( + "document parsing exceeded the 120-second worker limit" + ) + + proc.join(2) + if proc.is_alive(): + proc.terminate() + proc.join(2) + if message is None: + # One more attempt after the join completes; covers the + # case where the worker exited cleanly with a result still + # queued. + try: + message = result_queue.get_nowait() + except queue.Empty: + pass + if message is None: + raise RuntimeError( + f"document extraction worker exited without a result " + f"(exitcode={proc.exitcode})" + ) + + kind = message[0] + if kind == "ok": + return message[1] + if kind == "extraction_unavailable": + raise DocumentExtractionUnavailable(message[1]) + if kind == "encrypted": + raise DocumentExtractionEncrypted(message[1]) + if kind == "value_error": + raise ValueError(message[1]) + if kind == "error": + raise RuntimeError(f"{message[1]}: {message[2]}") + raise RuntimeError(f"unexpected document worker result: {kind!r}") + finally: + if proc is not None: + try: + _terminate_extract_process(proc) + except Exception: + pass + if result_queue is not None: + try: + result_queue.close() + result_queue.join_thread() + except Exception: + pass + _EXTRACT_SEMAPHORE.release() + + +async def extract_document( + file_bytes: bytes, + filename: str, + *, + content_type: str = "", + describe_images: bool = True, + use_vlm_ocr: bool = False, + max_figures: int = 40, + max_visual_payloads: int = DEFAULT_DOCUMENT_VISUAL_PAYLOADS, + vlm_timeout_seconds: float = 60.0, + capability: Optional[VlmCapability] = None, + self_base_url: Optional[str] = None, + authorization_header: Optional[str] = None, + progress_cb: Optional[ProgressCb] = None, + cancel_event: Optional[threading.Event] = None, +) -> ExtractResult: + """Extract layout-aware Markdown plus figure metadata. + + When ``describe_images`` is True and the active model is + vision-capable, the selected visual references are captioned via the + OpenAI-compat ``/v1/chat/completions`` surface after extraction. + Otherwise figures come back with ``caption=None`` and + ``describe_skipped_reason`` carries the human-readable reason. + """ + + async def _emit(**event: Any) -> None: + if cancel_event is not None and cancel_event.is_set(): + raise DocumentExtractionCancelled("document extraction was cancelled") + if progress_cb is not None: + try: + await progress_cb(event) + except Exception: + logger.debug("progress_cb raised; continuing", exc_info = True) + + max_figures = max(0, max_figures) + max_visual_payloads = max(0, min(max_visual_payloads, max_figures)) + cap = capability if capability is not None else detect_loaded_vlm(self_base_url) + image_input_available = bool(cap.is_vlm and cap.endpoint_url and cap.model_name) + describe_available = bool( + describe_images and cap.is_vlm and cap.endpoint_url and cap.model_name + ) + effective_describe = ( + describe_available and max_figures > 0 and max_visual_payloads > 0 + ) + extract_images = max_figures > 0 + + skipped_reason: Optional[str] = None + if describe_images and not effective_describe: + if describe_available and max_figures <= 0: + skipped_reason = "figure description disabled because max_figures is 0" + elif describe_available and max_visual_payloads <= 0: + skipped_reason = ( + "figure description disabled because max_visual_payloads is 0" + ) + else: + skipped_reason = cap.reason or "no_vlm" + + await _emit(stage = "parsing") + + options, build_warnings = _build_extract_options( + extract_images = extract_images, + use_vlm_ocr = use_vlm_ocr, + max_visual_payloads = max_visual_payloads, + ) + options["max_figures"] = max_figures + + try: + if _run_extract_sync is _RUN_EXTRACT_SYNC_ORIGINAL: + # Drive run_in_executor directly (rather than asyncio.to_thread) + # so we can attach a done-callback that retrieves the future's + # exception even when the awaiting task is cancelled — silences + # "Future exception was never retrieved" noise on busy/cancel. + loop = asyncio.get_running_loop() + extract_future = loop.run_in_executor( + None, + _run_extract_process_sync, + file_bytes, + filename, + options, + content_type, + _EXTRACT_TIMEOUT_SECONDS, + cancel_event, + ) + extract_future.add_done_callback(_drain_future_exception) + ( + markdown, + figures_out, + page_count, + truncated_count, + seen, + ) = await extract_future + else: + # Tests monkeypatch _run_extract_sync directly; preserve that seam + # without forcing patched callables through multiprocessing spawn. + loop = asyncio.get_running_loop() + ( + markdown, + figures_out, + page_count, + truncated_count, + seen, + ) = await asyncio.wait_for( + loop.run_in_executor( + None, + _run_extract_sync, + file_bytes, + filename, + options, + content_type, + ), + timeout = _EXTRACT_TIMEOUT_SECONDS, + ) + except asyncio.TimeoutError: + raise DocumentExtractionTimeout( + "document parsing exceeded the 120-second worker limit" + ) + except DocumentExtractionTimeout: + raise + except DocumentExtractionBusy: + raise + except DocumentExtractionCancelled: + raise + except DocumentExtractionEncrypted: + raise + except DocumentExtractionUnavailable: + raise + except ValueError: + # Unsupported file type — surface unchanged so the route can map to 415. + raise + except Exception as exc: + logger.exception("document extraction failed for %s", filename) + raise RuntimeError("document extraction failed") from exc + + caption_deadline_hit = False + if effective_describe: + caption_concurrency = ( + _LOCAL_VLM_CAPTION_CONCURRENCY + if cap.source in {"transformers", "unsloth"} + else _DEFAULT_VLM_CAPTION_CONCURRENCY + ) + sem = asyncio.Semaphore(caption_concurrency) + + captionable_total = sum( + 1 + for fig in figures_out[:max_figures] + if fig.image_base64 and fig.image_mime + ) + captioned_completed = 0 + await _emit( + stage = "captioning", + current = 0, + total = captionable_total, + page = None, + total_pages = page_count, + ) + + async def _describe_one(index: int, figure: ExtractedFigure) -> None: + nonlocal captioned_completed + if figure.caption or not figure.image_base64 or not figure.image_mime: + return + if cancel_event is not None and cancel_event.is_set(): + raise DocumentExtractionCancelled("document extraction was cancelled") + async with sem: + if cancel_event is not None and cancel_event.is_set(): + raise DocumentExtractionCancelled( + "document extraction was cancelled" + ) + try: + caption, error = await _describe_image_via_vlm( + image_base64 = figure.image_base64, + image_mime = figure.image_mime, + endpoint_url = cap.endpoint_url or "", + model_name = cap.model_name or "", + authorization_header = authorization_header, + timeout_seconds = vlm_timeout_seconds, + ) + figures_out[index] = replace( + figure, + caption = caption, + error = error, + ) + except asyncio.TimeoutError as exc: + logger.warning( + "VLM describe timed out for figure %s", figure.id, exc_info = exc + ) + figures_out[index] = replace( + figure, + error = f"VLM describe timed out: {type(exc).__name__}", + ) + except Exception as exc: + logger.warning( + "VLM describe failed for figure %s", figure.id, exc_info = exc + ) + figures_out[index] = replace( + figure, + error = f"VLM describe failed: {type(exc).__name__}", + ) + finally: + captioned_completed += 1 + await _emit( + stage = "captioning", + current = captioned_completed, + total = captionable_total, + page = figure.page, + total_pages = page_count, + ) + + tasks = [ + _describe_one(index, fig) + for index, fig in enumerate(figures_out[:max_figures]) + if fig.image_base64 and fig.image_mime + ] + if tasks: + try: + caption_timeout_seconds = _VLM_CAPTION_TOTAL_TIMEOUT_SECONDS + if cap.source in {"transformers", "unsloth"}: + caption_timeout_seconds = max( + caption_timeout_seconds, + len(tasks) * vlm_timeout_seconds + 15, + ) + results = await asyncio.wait_for( + asyncio.gather(*tasks, return_exceptions = True), + timeout = caption_timeout_seconds, + ) + for result in results: + if isinstance( + result, + (DocumentExtractionCancelled, asyncio.CancelledError), + ): + raise result + except asyncio.TimeoutError: + caption_deadline_hit = True + for index, figure in enumerate(figures_out): + if figure.image_base64 and not figure.caption and not figure.error: + figures_out[index] = replace( + figure, + error = "VLM caption deadline exceeded", + ) + + warnings: List[str] = list(build_warnings) + if truncated_count > 0: + warnings.append( + f"Document has {seen} figures; showing the first {max_figures} " + f"({truncated_count} truncated)." + ) + visual_payload_count = sum(1 for figure in figures_out if figure.image_base64) + if ( + visual_payload_count >= max_visual_payloads + and len(figures_out) > visual_payload_count + ): + warnings.append( + f"Only the first {max_visual_payloads} visual payloads " + "were attached; remaining figure references are text-only." + ) + if ( + effective_describe + and figures_out + and all(f.caption is None for f in figures_out) + ): + error_samples: list[str] = [] + seen_errors: set[str] = set() + for figure in figures_out: + if not figure.error or figure.error in seen_errors: + continue + seen_errors.add(figure.error) + error_samples.append(f"{figure.id}: {figure.error}") + if len(error_samples) >= 3: + break + sample_suffix = ( + " Examples: " + "; ".join(error_samples) + "." if error_samples else "" + ) + warnings.append( + "Figure descriptions were requested but none were produced — " + "check that the loaded model accepts image inputs via /v1." + f"{sample_suffix}" + ) + if caption_deadline_hit: + warnings.append( + "Figure captioning reached the inline timeout; some image " + "descriptions were skipped." + ) + + await _emit(stage = "done") + + return ExtractResult( + markdown = markdown, + figures = figures_out, + page_count = page_count, + tokens_est = _estimate_tokens(markdown), + describe_skipped_reason = skipped_reason, + vlm_source = cap.source, + vlm_model = cap.model_name, + image_input_available = image_input_available, + warnings = warnings, + ) diff --git a/studio/backend/core/chat/vlm_capability.py b/studio/backend/core/chat/vlm_capability.py new file mode 100644 index 0000000000..f8992c6455 --- /dev/null +++ b/studio/backend/core/chat/vlm_capability.py @@ -0,0 +1,213 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 + +""" +Runtime probe: is the currently loaded model vision-capable, and where +is its OpenAI-compatible endpoint? + +Unifies the three Studio inference backends (embedded llama-server for +GGUF, transformers, Unsloth/LoRA) behind a single ``VlmCapability`` +dataclass. Read-only — never loads or modifies models. + +Why this replaces the old ``VISION_ARCHITECTURES`` allow-list: +- Allow-lists silently exclude legitimately new vision architectures. +- Runtime probing matches the user's actual loaded model. +- The document extractor can caption selected visual references through + any loaded backend exposing ``/v1/chat/completions`` without + hard-coding architecture names. +""" + +from __future__ import annotations + +import logging +from dataclasses import asdict, dataclass +from typing import Any, Literal, Optional +from urllib.parse import urlparse + + +logger = logging.getLogger(__name__) + + +VlmSource = Literal["gguf", "transformers", "unsloth", "none"] + + +@dataclass(frozen = True) +class VlmCapability: + """Immutable snapshot of the loaded model's image-input capability.""" + + is_vlm: bool + endpoint_url: Optional[str] + model_name: Optional[str] + source: VlmSource + reason: Optional[str] = None + + @classmethod + def none(cls, reason: str = "no model loaded") -> "VlmCapability": + return cls( + is_vlm = False, + endpoint_url = None, + model_name = None, + source = "none", + reason = reason, + ) + + def to_dict(self) -> dict: + return asdict(self) + + +def _probe_gguf(llama: Any = None) -> Optional[VlmCapability]: + if llama is None: + try: + from core.inference.llama_cpp import get_llama_cpp_backend + except Exception: # pragma: no cover - older embedding paths + return None + + try: + llama = get_llama_cpp_backend() + except Exception: + return None + + if not getattr(llama, "is_loaded", False): + return None + + base_url = getattr(llama, "base_url", None) + model_id = getattr(llama, "model_identifier", None) + is_vision = bool(getattr(llama, "is_vision", False)) + + if not base_url or not model_id: + # Half-initialised llama-server state — fall through to the + # transformers probe instead of returning a misleading + # non-vision GGUF result that suppresses the fallback chain. + logger.debug( + "llama-server reports is_loaded=True but base_url / model id missing" + ) + return None + + return VlmCapability( + is_vlm = is_vision, + endpoint_url = base_url, + model_name = model_id, + source = "gguf", + reason = None + if is_vision + else "gguf: model loaded, is_vision=False (no mmproj clip)", + ) + + +def _probe_transformers(self_base_url: Optional[str]) -> Optional[VlmCapability]: + try: + from core.inference import get_inference_backend + except ModuleNotFoundError as exc: + if exc.name == "core.inference" or ( + exc.name and exc.name.startswith("core.inference.") + ): + return None + logger.exception("Failed to import transformers inference backend") + return None + except ImportError: + # A different ImportError variant (e.g. circular import). Treat as + # backend-unavailable. Anything else (NameError/AttributeError raised + # by core.inference.__init__) propagates so real bugs aren't masked + # as "no VLM loaded". + logger.exception("Failed to import transformers inference backend") + return None + + try: + ib = get_inference_backend() + except Exception: + return None + + name: Optional[str] = getattr(ib, "active_model_name", None) + if not name: + return None + + models: dict = getattr(ib, "models", {}) or {} + info: dict = models.get(name) or {} + is_vision = bool(info.get("is_vision", False)) + is_lora = bool(info.get("is_lora", False)) + source: VlmSource = "unsloth" if is_lora else "transformers" + + if not self_base_url: + return VlmCapability( + is_vlm = False, + endpoint_url = None, + model_name = name, + source = source, + reason = f"{source}: self_base_url=None (cannot self-loopback to /v1/chat/completions)", + ) + + return VlmCapability( + is_vlm = is_vision, + endpoint_url = self_base_url.rstrip("/"), + model_name = name, + source = source, + reason = None if is_vision else f"{source}: active model not marked is_vision", + ) + + +def detect_loaded_vlm( + self_base_url: Optional[str] = None, + *, + llama_backend: Any = None, +) -> VlmCapability: + """Identify the active model and whether it can describe images. + + ``self_base_url`` is only consulted when the active model is served + by the transformers / Unsloth backend; document image captioning must + loop back through our own ``/v1/chat/completions``. GGUF models return + llama-server's own URL and ignore this argument. + """ + gguf = _probe_gguf(llama_backend) + if gguf is not None: + return gguf + + tf = _probe_transformers(self_base_url) + if tf is not None: + return tf + + return VlmCapability.none() + + +def extract_self_base_url(request: Any) -> Optional[str]: + """Derive a trusted local base URL for the active Studio server. + + The request Host header is attacker-controlled in many deployments, + so the returned origin always uses ``127.0.0.1``. Only the server + port is discovered, preferring the port published by ``run.py`` and + then uvicorn's ASGI scope. ``request.base_url`` is a last-resort + fallback for tests and non-uvicorn embedding. + """ + port: Optional[int] = None + + try: + candidate = getattr(getattr(request, "app", None), "state", None) + candidate = getattr(candidate, "server_port", None) + if isinstance(candidate, int) and candidate > 0: + port = candidate + except Exception: + port = None + + if port is None: + try: + server = getattr(request, "scope", {}).get("server") + if ( + isinstance(server, tuple) + and len(server) >= 2 + and isinstance(server[1], int) + and server[1] > 0 + ): + port = server[1] + except Exception: + port = None + + if port is None: + try: + base = str(getattr(request, "base_url", "") or "") + if not base: + return None + parsed = urlparse(base) + port = parsed.port if parsed.port is not None else 8888 + except Exception: + return None + + return f"http://127.0.0.1:{int(port)}" diff --git a/studio/backend/core/export/export.py b/studio/backend/core/export/export.py index 7cabd382eb..1ad3a3607b 100644 --- a/studio/backend/core/export/export.py +++ b/studio/backend/core/export/export.py @@ -182,7 +182,10 @@ def load_checkpoint( # Detect audio type and vision self._audio_type = detect_audio_type(model_id) - self.is_vision = not self._audio_type and is_vision_model(model_id) + self.is_vision = not self._audio_type and is_vision_model( + model_id, + trust_remote_code = trust_remote_code, + ) # Load model based on type if self._audio_type == "csm": diff --git a/studio/backend/core/inference/__init__.py b/studio/backend/core/inference/__init__.py index 35318f6357..8c56a56564 100644 --- a/studio/backend/core/inference/__init__.py +++ b/studio/backend/core/inference/__init__.py @@ -7,17 +7,43 @@ The default get_inference_backend() returns an InferenceOrchestrator that delegates to a subprocess. The original InferenceBackend runs inside the subprocess and can be imported directly from .inference when needed. -""" -from .orchestrator import InferenceOrchestrator, get_inference_backend -from .llama_cpp import LlamaCppBackend +Symbols are exposed lazily through ``__getattr__`` (PEP 562) so that +importing a stdlib-only helper from this package (e.g. +``from core.inference._html_to_md import html_to_markdown``) does not +eagerly pull in the orchestrator or the GGUF/llama-server backend. +That matters for the document-extractor HTML path which must keep +working in environments where the inference extras are unavailable or +broken. +""" -# Expose InferenceOrchestrator as InferenceBackend for backward compat -InferenceBackend = InferenceOrchestrator +from typing import Any __all__ = [ "InferenceBackend", "InferenceOrchestrator", "get_inference_backend", + "get_llama_cpp_backend", "LlamaCppBackend", ] + + +def __getattr__(name: str) -> Any: + if name in ("InferenceOrchestrator", "get_inference_backend", "InferenceBackend"): + from .orchestrator import InferenceOrchestrator, get_inference_backend + + globals()["InferenceOrchestrator"] = InferenceOrchestrator + globals()["get_inference_backend"] = get_inference_backend + globals()["InferenceBackend"] = InferenceOrchestrator + return globals()[name] + if name in ("LlamaCppBackend", "get_llama_cpp_backend"): + from .llama_cpp import LlamaCppBackend, get_llama_cpp_backend + + globals()["LlamaCppBackend"] = LlamaCppBackend + globals()["get_llama_cpp_backend"] = get_llama_cpp_backend + return globals()[name] + raise AttributeError(name) + + +def __dir__() -> list[str]: + return sorted(set(globals()) | set(__all__)) diff --git a/studio/backend/core/inference/llama_cpp.py b/studio/backend/core/inference/llama_cpp.py index 76234386aa..b41ee999b8 100644 --- a/studio/backend/core/inference/llama_cpp.py +++ b/studio/backend/core/inference/llama_cpp.py @@ -711,6 +711,10 @@ def is_active(self) -> bool: def base_url(self) -> str: return f"http://127.0.0.1:{self._port}" + @property + def api_key(self) -> Optional[str]: + return self._api_key + @property def model_identifier(self) -> Optional[str]: return self._model_identifier @@ -4077,6 +4081,9 @@ def _parse_tool_calls_from_text(content: str) -> list[dict]: def _build_openai_messages( messages: list[dict], image_b64: Optional[str] = None, + image_b64s: Optional[list[str]] = None, + image_mime: Optional[str] = None, + image_mimes: Optional[list[str]] = None, ) -> list[dict]: """ Build OpenAI-format messages, optionally injecting an image_url @@ -4084,8 +4091,18 @@ def _build_openai_messages( If no image is provided, returns messages as-is. """ - if not image_b64: + images = ( + image_b64s if image_b64s is not None else ([image_b64] if image_b64 else []) + ) + images = [image for image in images if image] + if not images: return messages + if image_b64s is not None: + mimes = image_mimes or ["image/png"] * len(images) + else: + mimes = [image_mime or "image/png"] + if len(mimes) < len(images): + mimes = [*mimes, *(["image/png"] * (len(images) - len(mimes)))] # Find the last user message and convert to multimodal content parts result = [msg.copy() for msg in messages] @@ -4096,14 +4113,18 @@ def _build_openai_messages( if last_user_idx is not None: text_content = result[last_user_idx].get("content", "") - result[last_user_idx]["content"] = [ - {"type": "text", "text": text_content}, + image_parts = [ { "type": "image_url", "image_url": { - "url": f"data:image/png;base64,{image_b64}", + "url": f"data:{mime if mime and '/' in mime else 'image/png'};base64,{image}", }, - }, + } + for image, mime in zip(images, mimes) + ] + result[last_user_idx]["content"] = [ + {"type": "text", "text": text_content}, + *image_parts, ] return result @@ -4235,6 +4256,9 @@ def generate_chat_completion( self, messages: list[dict], image_b64: Optional[str] = None, + image_b64s: Optional[list[str]] = None, + image_mime: Optional[str] = None, + image_mimes: Optional[list[str]] = None, temperature: float = 0.6, top_p: float = 0.95, top_k: int = 20, @@ -4259,7 +4283,13 @@ def generate_chat_completion( if not self.is_loaded: raise RuntimeError("llama-server is not loaded") - openai_messages = self._build_openai_messages(messages, image_b64) + openai_messages = self._build_openai_messages( + messages, + image_b64 = image_b64, + image_b64s = image_b64s, + image_mime = image_mime, + image_mimes = image_mimes, + ) payload = { "messages": openai_messages, @@ -5490,3 +5520,20 @@ def generate_audio_response( return LlamaCppBackend._codec_mgr.decode( audio_type, device, token_ids = token_ids, text = data.get("content", "") ) + + +_llama_cpp_backend: Optional[LlamaCppBackend] = None + + +def get_llama_cpp_backend() -> LlamaCppBackend: + """Return the process-wide GGUF llama-server backend. + + Keep the singleton in ``core.inference`` so core helpers such as + ``core.chat.detect_loaded_vlm`` do not need to import route modules. + The instance is lazy to avoid subprocess cleanup side effects for + callers that only import model helpers. + """ + global _llama_cpp_backend + if _llama_cpp_backend is None: + _llama_cpp_backend = LlamaCppBackend() + return _llama_cpp_backend diff --git a/studio/backend/core/inference/worker.py b/studio/backend/core/inference/worker.py index 20a7d2d16c..ba12157780 100644 --- a/studio/backend/core/inference/worker.py +++ b/studio/backend/core/inference/worker.py @@ -74,7 +74,28 @@ def _send_response(resp_queue: Any, response: dict) -> None: logger.error("Failed to send response: %s", exc) -def _build_model_config(config: dict): +def _resolve_trust_remote_code(config: dict) -> bool: + # Auto-enable trust_remote_code for NemotronH/Nano models only. + # NemotronH has config parsing bugs requiring trust_remote_code=True. + # Other transformers 5.x models are native and do NOT need it. + # NOTE: Must NOT match Llama-Nemotron (standard Llama architecture). + trust_remote_code = config.get("trust_remote_code", False) + if not trust_remote_code: + model_name = config["model_name"] + _mn_lower = model_name.lower() + _NEMOTRON_TRUST_SUBSTRINGS = ("nemotron_h", "nemotron-h", "nemotron-3-nano") + if any(sub in _mn_lower for sub in _NEMOTRON_TRUST_SUBSTRINGS) and ( + _mn_lower.startswith("unsloth/") or _mn_lower.startswith("nvidia/") + ): + trust_remote_code = True + logger.info( + "Auto-enabled trust_remote_code for Nemotron model: %s", + model_name, + ) + return bool(trust_remote_code) + + +def _build_model_config(config: dict, *, trust_remote_code: bool | None = None): """Build a ModelConfig from the config dict.""" from utils.models import ModelConfig @@ -82,11 +103,14 @@ def _build_model_config(config: dict): hf_token = config.get("hf_token") hf_token = hf_token if hf_token and hf_token.strip() else None gguf_variant = config.get("gguf_variant") + if trust_remote_code is None: + trust_remote_code = _resolve_trust_remote_code(config) mc = ModelConfig.from_identifier( model_id = model_name, hf_token = hf_token, gguf_variant = gguf_variant, + trust_remote_code = trust_remote_code, ) if not mc: raise ValueError(f"Invalid model identifier: {model_name}") @@ -247,7 +271,8 @@ def _beat(): def _handle_load(backend, config: dict, resp_queue: Any) -> None: """Handle a load command: load a model into the backend.""" try: - mc = _build_model_config(config) + trust_remote_code = _resolve_trust_remote_code(config) + mc = _build_model_config(config, trust_remote_code = trust_remote_code) hf_token = config.get("hf_token") hf_token = hf_token if hf_token and hf_token.strip() else None @@ -287,24 +312,6 @@ def _handle_load(backend, config: dict, resp_queue: Any) -> None: except Exception as e: logger.warning("Could not read adapter_config.json: %s", e) - # Auto-enable trust_remote_code for NemotronH/Nano models only. - # NemotronH has config parsing bugs requiring trust_remote_code=True. - # Other transformers 5.x models are native and do NOT need it. - # NOTE: Must NOT match Llama-Nemotron (standard Llama architecture). - _NEMOTRON_TRUST_SUBSTRINGS = ("nemotron_h", "nemotron-h", "nemotron-3-nano") - trust_remote_code = config.get("trust_remote_code", False) - if not trust_remote_code: - model_name = config["model_name"] - _mn_lower = model_name.lower() - if any(sub in _mn_lower for sub in _NEMOTRON_TRUST_SUBSTRINGS) and ( - _mn_lower.startswith("unsloth/") or _mn_lower.startswith("nvidia/") - ): - trust_remote_code = True - logger.info( - "Auto-enabled trust_remote_code for Nemotron model: %s", - model_name, - ) - # Send heartbeats every 30s so the orchestrator knows we're still alive # (download / weight loading can take a long time on slow connections) xet_disabled = os.environ.get("HF_HUB_DISABLE_XET") == "1" diff --git a/studio/backend/core/training/trainer.py b/studio/backend/core/training/trainer.py index b128fb5338..39b491b488 100644 --- a/studio/backend/core/training/trainer.py +++ b/studio/backend/core/training/trainer.py @@ -201,7 +201,11 @@ def pre_detect_and_load_tokenizer( # --- Detect VLM --- vision = ( - is_vision_model(model_name, hf_token = hf_token) + is_vision_model( + model_name, + hf_token = hf_token, + trust_remote_code = trust_remote_code, + ) if not self.is_audio else False ) @@ -574,7 +578,11 @@ def load_model( # VLM: vision model with image dataset (mutually exclusive with audio paths) vision = ( - is_vision_model(model_name, hf_token = hf_token) + is_vision_model( + model_name, + hf_token = hf_token, + trust_remote_code = trust_remote_code, + ) if not self.is_audio else False ) diff --git a/studio/backend/models/inference.py b/studio/backend/models/inference.py index 0af9425fdc..21949baf5b 100644 --- a/studio/backend/models/inference.py +++ b/studio/backend/models/inference.py @@ -129,6 +129,10 @@ class ValidateModelRequest(BaseModel): gguf_variant: Optional[str] = Field( None, description = "GGUF quantization variant (e.g. 'Q4_K_M')" ) + trust_remote_code: bool = Field( + False, + description = "Allow validation probes that require custom model code.", + ) class ValidateModelResponse(BaseModel): @@ -172,6 +176,14 @@ class GenerateRequest(BaseModel): image_base64: Optional[str] = Field( None, description = "Base64 encoded image for vision models" ) + session_id: Optional[str] = Field( + None, + description = "[x-unsloth] Session/thread ID for cancellation scoping.", + ) + cancel_id: Optional[str] = Field( + None, + description = "[x-unsloth] Per-request cancellation token matched by /inference/cancel.", + ) class LoadResponse(BaseModel): @@ -353,6 +365,10 @@ class InferenceStatusResponse(BaseModel): supports_tools: bool = Field( False, description = "Whether the active model supports tool calling" ) + cache_type_kv: Optional[str] = Field( + None, + description = "KV cache data type for K and V (e.g. 'f16', 'bf16', 'q8_0')", + ) context_length: Optional[int] = Field( None, description = "Context length of the active model" ) @@ -1471,3 +1487,159 @@ class AnthropicMessagesResponse(BaseModel): stop_reason: Optional[str] = None stop_sequence: Optional[str] = None usage: AnthropicUsage = Field(default_factory = AnthropicUsage) + + +# ---------------------------------------------------------------------- # +# Chat document extraction (parsed documents + optional VLM captions) # +# ---------------------------------------------------------------------- # + + +class ExtractedFigureModel(BaseModel): + """A single extracted visual reference, optionally described by a + locally-loaded vision model.""" + + id: str = Field(..., description = "Stable id (e.g. 'fig-0')") + page: Optional[int] = Field(None, description = "1-based page number, if known") + caption: Optional[str] = Field( + None, description = "Short VLM-generated caption, or null if skipped/failed" + ) + error: Optional[str] = Field( + None, description = "Reason the describe call failed, if any" + ) + kind: Literal["figure", "page"] = Field( + "figure", + description = "Whether this reference is a detected figure or page image", + ) + image_mime: Optional[str] = Field( + None, description = "MIME type for image_base64 when a visual payload is present" + ) + image_base64: Optional[str] = Field( + None, + description = ( + "Base64-encoded visual payload for this reference. The first visual " + "reference is sent to vision-capable chat models as [Image #1]." + ), + ) + image_width: Optional[int] = Field( + None, ge = 1, description = "Width of image_base64 after resize" + ) + image_height: Optional[int] = Field( + None, ge = 1, description = "Height of image_base64 after resize" + ) + + +class ExtractDocumentResponse(BaseModel): + """ + Returned synchronously from ``POST /chat/extract-document`` for + small docs, or as the final SSE event for larger ones. + """ + + schema_version: int = Field( + 1, description = "Document extraction payload schema version" + ) + filename: str = Field(..., description = "Original filename uploaded") + markdown: str = Field( + ..., description = "Layout-aware Markdown extracted from the document" + ) + page_count: int = Field(0, ge = 0, description = "Number of pages in the source") + tokens_est: int = Field( + 0, ge = 0, description = "Rough char/4 token estimate for the markdown" + ) + truncated: bool = Field( + False, + description = "Whether markdown was clipped to the requested token budget", + ) + figures: List[ExtractedFigureModel] = Field( + default_factory = list, + description = "Figures discovered in the document (captions optional)", + ) + describe_skipped_reason: Optional[str] = Field( + None, + description = ( + "If image description was requested but skipped, the reason " + "(e.g. 'loaded GGUF is not vision-capable'). Mirrors the " + "``reason`` surfaced by /chat/document-support." + ), + ) + vlm_source: Optional[str] = Field( + None, + description = ( + "Which inference backend served the describe calls: 'gguf', " + "'transformers', 'unsloth', or 'none' when no VLM was used." + ), + ) + vlm_model: Optional[str] = Field( + None, + description = "Identifier of the VLM whose captions appear in this document", + ) + image_input_available: bool = Field( + False, + description = ( + "Whether the active model can receive an extracted visual payload " + "alongside the markdown." + ), + ) + warnings: List[str] = Field( + default_factory = list, + description = "Non-fatal warnings surfaced to the UI", + ) + + +class VlmCapabilityModel(BaseModel): + """Runtime probe result for the currently-loaded model.""" + + is_vlm: bool = Field( + ..., description = "Whether the active model accepts image inputs" + ) + endpoint_url: Optional[str] = Field( + None, + description = "Root URL serving /v1/chat/completions for the active model", + ) + model_name: Optional[str] = Field( + None, description = "Identifier of the active model, if any is loaded" + ) + source: Literal["gguf", "transformers", "unsloth", "none"] = Field( + ..., description = "Which backend currently owns the active model" + ) + reason: Optional[str] = Field( + None, + description = "Populated when is_vlm is false; explains why the UI toggle is disabled", + ) + + +class DocumentSupportResponse(BaseModel): + """Returned by GET /chat/document-support. + + Drives the Chat settings-card toggles. ``max_visual_payloads`` is kept + for older clients as an informational hint, not a hard request cap. + """ + + schema_version: int = Field( + 1, description = "Document support payload schema version" + ) + extraction_available: bool = Field( + ..., + description = ( + "Whether the document extraction backend successfully imported " + "on the server" + ), + ) + max_visual_payloads: int = Field( + ..., + ge = 0, + description = "Legacy visual-payload hint; not a hard request cap", + ) + max_extract_concurrency: int = Field( + 1, + ge = 1, + description = "Maximum server-side document extraction workers", + ) + format_support: Dict[str, bool] = Field( + default_factory = dict, + description = "Per-format parser availability for document extraction", + ) + unavailable_formats: Dict[str, str] = Field( + default_factory = dict, + description = "Per-format parser unavailability reasons", + ) + vlm: VlmCapabilityModel diff --git a/studio/backend/requirements/studio.txt b/studio/backend/requirements/studio.txt index 96f8816b57..13c556b878 100644 --- a/studio/backend/requirements/studio.txt +++ b/studio/backend/requirements/studio.txt @@ -16,5 +16,15 @@ huggingface-hub==0.36.2 structlog>=24.1.0 diceware ddgs +pypdf>=6.0.0,<7 +python-multipart>=0.0.26 +# Document extraction relies on pymupdf4llm 1.27+ (installed via +# data-designer-deps.txt), which pulls pymupdf-layout. The bundled ONNX +# models work fine on modern onnxruntime; we require >=1.19 because +# earlier wheels (e.g. 1.17.x) were built against NumPy 1.x and crash +# on import in venvs that have NumPy 2.x installed (pymupdf.layout -> +# onnxruntime -> numpy._multiarray_umath ABI mismatch). Verified +# end-to-end with onnxruntime 1.25.0 + numpy 2.4.x. +onnxruntime>=1.19 cryptography>=42.0.0 httpx>=0.27.0 diff --git a/studio/backend/routes/inference.py b/studio/backend/routes/inference.py index a156f2397c..6b7fec5bdb 100644 --- a/studio/backend/routes/inference.py +++ b/studio/backend/routes/inference.py @@ -9,9 +9,11 @@ import sys import time import uuid +from contextlib import suppress from pathlib import Path from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.responses import StreamingResponse, JSONResponse, Response +from pydantic import ValidationError from typing import Any, Optional, Union import json import httpx @@ -124,6 +126,7 @@ def _friendly_error(exc: Exception) -> str: _canonicalize_spec_mode, _hf_offline_if_dns_dead, detect_reasoning_flags, + get_llama_cpp_backend, ) from core.inference.llama_server_args import ( strip_shadowing_flags, @@ -151,6 +154,7 @@ def _friendly_error(exc: Exception) -> str: _canonicalize_spec_mode, _hf_offline_if_dns_dead, detect_reasoning_flags, + get_llama_cpp_backend, ) from core.inference.llama_server_args import ( strip_shadowing_flags, @@ -210,10 +214,14 @@ def _friendly_error(exc: Exception) -> str: AnthropicUsage, CreateOpenAIContainerBody, DeleteOpenAIContainerBody, + DocumentSupportResponse, + ExtractDocumentResponse, + ExtractedFigureModel, ListOpenAIContainersResponse, OpenAIContainerRequest, OpenAIContainerSummary, ) +from dataclasses import asdict as _asdict from core.inference.anthropic_compat import ( anthropic_messages_to_openai, anthropic_tools_to_openai, @@ -558,12 +566,12 @@ def _resolve_model_identifier_for_request( return str(grant.canonical_path), display_label, True -# GGUF inference backend (llama-server) -_llama_cpp_backend = LlamaCppBackend() - - -def get_llama_cpp_backend() -> LlamaCppBackend: - return _llama_cpp_backend +# GGUF inference backend (llama-server) singleton lives in +# ``core.inference.llama_cpp``. ``get_llama_cpp_backend`` is already +# imported above and re-exported from this module so external callers +# that do ``from routes.inference import get_llama_cpp_backend`` keep +# resolving to the same process-wide instance that load/list/delete/ +# shutdown all consult. @router.post("/load", response_model = LoadResponse) @@ -661,6 +669,7 @@ async def load_model( reasoning_always_on = llama_backend.reasoning_always_on, supports_preserve_thinking = llama_backend.supports_preserve_thinking, supports_tools = llama_backend.supports_tools, + cache_type_kv = llama_backend.cache_type_kv, chat_template = llama_backend.chat_template, speculative_type = llama_backend.requested_spec_mode, spec_draft_n_max = llama_backend.spec_draft_n_max, @@ -713,6 +722,26 @@ async def load_model( chat_template = _chat_template, ) + model_defaults = load_model_defaults(request.model_path) + defaults_require_trust_remote_code = bool( + model_defaults.get("model", {}).get("trust_remote_code", False) + or model_defaults.get("inference", {}).get("trust_remote_code", False) + ) + if defaults_require_trust_remote_code and not request.trust_remote_code: + display_name = ( + model_defaults.get("model", {}).get("display_name") + or request.model_path.split("/")[-1] + or request.model_path + ) + raise HTTPException( + status_code = 400, + detail = ( + f"Model '{display_name}' requires trust_remote_code to be enabled. " + "Please enable 'Trust remote code' in Chat Settings and try again." + ), + ) + + # Create config using clean factory method. # is_lora auto-detected from adapter_config.json on disk/HF. # DNS-probe wrap so offline loads skip 30-60s of soft-failed # network checks before the worker starts. @@ -721,6 +750,7 @@ async def load_model( model_id = model_identifier, hf_token = request.hf_token, gguf_variant = request.gguf_variant, + trust_remote_code = request.trust_remote_code, ) if not config: @@ -1122,10 +1152,39 @@ async def validate_model( model_identifier, model_log_label, native_grant_backed = ( _resolve_model_identifier_for_request(request, operation = "validate-model") ) + if not native_grant_backed: + model_defaults = load_model_defaults(request.model_path) + default_model_config = model_defaults.get("model", {}) + default_inference_config = model_defaults.get("inference", {}) + defaults_require_trust_remote_code = bool( + default_model_config.get("trust_remote_code", False) + or default_inference_config.get("trust_remote_code", False) + ) + if defaults_require_trust_remote_code and not request.trust_remote_code: + display_name = ( + default_model_config.get("display_name") + or request.model_path.split("/")[-1] + or request.model_path + ) + return ValidateModelResponse( + valid = True, + message = ( + "Model identifier is valid, but this model requires " + "trust_remote_code before probing or loading." + ), + identifier = request.model_path, + display_name = display_name, + is_gguf = False, + is_lora = False, + is_vision = bool(default_model_config.get("is_vision", False)), + requires_trust_remote_code = True, + ) + config = ModelConfig.from_identifier( model_id = model_identifier, hf_token = request.hf_token, gguf_variant = request.gguf_variant, + trust_remote_code = request.trust_remote_code, ) if not config: @@ -1231,10 +1290,15 @@ async def cancel_inference( A cancel_id arriving before its stream registers is stashed briefly and replayed on registration. Returns {"cancelled": N}. """ + # The cancel body is a tiny dict of identifiers; cap the read so an + # authenticated client cannot make this endpoint buffer megabytes + # the way the sibling JSON inference endpoints already prevent. try: - body = await request.json() + body = await _read_json_body_limited(request, max_bytes = 64 * 1024) if not isinstance(body, dict): body = {} + except HTTPException: + raise except Exception as e: logger.debug("Failed to parse cancel request body: %s", e) body = {} @@ -1260,6 +1324,7 @@ async def cancel_inference( @router.post("/generate/stream") async def generate_stream( + fastapi_request: Request, request: GenerateRequest, current_subject: str = Depends(get_current_subject), ): @@ -1302,9 +1367,21 @@ async def generate_stream( status_code = 400, detail = f"Failed to decode image: {str(e)}" ) + cancel_event = threading.Event() + completion_id = f"legacy-{uuid.uuid4().hex[:12]}" + _tracker = _TrackedCancel( + cancel_event, + request.cancel_id, + request.session_id, + completion_id, + ) + _tracker.__enter__() + async def stream(): + _DONE = object() try: - for chunk in backend.generate_chat_response( + yield f"data: {json.dumps({'completion_id': completion_id})}\n\n" + gen = backend.generate_chat_response( messages = request.messages, system_prompt = request.system_prompt, image = image, @@ -1313,7 +1390,19 @@ async def stream(): top_k = request.top_k, max_new_tokens = request.max_new_tokens, repetition_penalty = request.repetition_penalty, - ): + cancel_event = cancel_event, + ) + while True: + if cancel_event.is_set(): + backend.reset_generation_state() + break + if await fastapi_request.is_disconnected(): + cancel_event.set() + backend.reset_generation_state() + return + chunk = await asyncio.to_thread(next, gen, _DONE) + if chunk is _DONE: + break yield f"data: {json.dumps({'content': chunk})}\n\n" yield "data: [DONE]\n\n" @@ -1321,6 +1410,9 @@ async def stream(): backend.reset_generation_state() logger.error(f"Error during generation: {e}", exc_info = True) yield f"data: {json.dumps({'error': _friendly_error(e)})}\n\n" + finally: + cancel_event.set() + _tracker.__exit__(None, None, None) return StreamingResponse( stream(), @@ -1632,9 +1724,123 @@ def _decode_audio_base64(b64: str) -> np.ndarray: return waveform.squeeze(0).numpy() +_OPENAI_CHAT_MAX_IMAGES = 256 +_OPENAI_CHAT_MAX_IMAGE_BYTES = 20 * 1024 * 1024 +_OPENAI_CHAT_MAX_IMAGE_PIXELS = 40_000_000 +_OPENAI_CHAT_MAX_IMAGE_BASE64_CHARS = ( + (_OPENAI_CHAT_MAX_IMAGE_BYTES + 2) // 3 +) * 4 + 1024 + + +def _convert_openai_image_b64_to_png_b64(image_b64: str) -> str: + if len(image_b64) > _OPENAI_CHAT_MAX_IMAGE_BASE64_CHARS: + raise HTTPException( + status_code = 413, + detail = "Image payload exceeds the 20 MB decoded-image limit.", + ) + + try: + import base64 as _b64 + from io import BytesIO as _BytesIO + from PIL import Image as _Image + + raw = _b64.b64decode(image_b64, validate = True) + if len(raw) > _OPENAI_CHAT_MAX_IMAGE_BYTES: + raise HTTPException( + status_code = 413, + detail = "Image payload exceeds the 20 MB decoded-image limit.", + ) + with _Image.open(_BytesIO(raw)) as img: + width, height = img.size + if width * height > _OPENAI_CHAT_MAX_IMAGE_PIXELS: + raise HTTPException( + status_code = 413, + detail = "Image dimensions exceed the 40 MP limit.", + ) + converted = img.convert("RGB") + buf = _BytesIO() + converted.save(buf, format = "PNG") + png = buf.getvalue() + if len(png) > _OPENAI_CHAT_MAX_IMAGE_BYTES: + raise HTTPException( + status_code = 413, + detail = "Converted image payload exceeds the 20 MB limit.", + ) + return _b64.b64encode(png).decode("ascii") + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code = 400, detail = f"Failed to process image: {e}" + ) from e + + +def _data_url_base64_payload(url: str) -> str: + try: + header, b64data = url.split(",", 1) + except ValueError as exc: + raise HTTPException( + status_code = 400, detail = "Image data URL is missing base64 payload." + ) from exc + if ";base64" not in header.lower(): + raise HTTPException( + status_code = 400, detail = "Image data URL must be base64 encoded." + ) + return b64data + + +def _normalize_openai_message_images( + openai_messages: list[dict], + *, + is_vision: bool, + not_vision_detail: str, +) -> bool: + """Apply image count/size/pixel guards and normalize data URLs to PNG.""" + has_image = False + image_count = 0 + + for msg in openai_messages: + content = msg.get("content") + if not isinstance(content, list): + continue + for part in content: + if not isinstance(part, dict) or part.get("type") != "image_url": + continue + + has_image = True + image_count += 1 + if image_count > _OPENAI_CHAT_MAX_IMAGES: + raise HTTPException( + status_code = 413, + detail = f"Too many images provided; maximum is {_OPENAI_CHAT_MAX_IMAGES}.", + ) + if not is_vision: + raise HTTPException(status_code = 400, detail = not_vision_detail) + + image_url = part.get("image_url") or {} + if not isinstance(image_url, dict): + raise HTTPException( + status_code = 400, detail = "Invalid image_url content part." + ) + url = image_url.get("url", "") + if not isinstance(url, str): + raise HTTPException(status_code = 400, detail = "Invalid image_url URL.") + if not url.startswith("data:"): + # Remote URLs are counted but cannot be byte/pixel checked here. + continue + + b64data = _data_url_base64_payload(url) + png_b64 = _convert_openai_image_b64_to_png_b64(b64data) + normalized = dict(image_url) + normalized["url"] = f"data:image/png;base64,{png_b64}" + part["image_url"] = normalized + + return has_image + + def _extract_content_parts( messages: list, -) -> tuple[str, list[dict], "Optional[str]"]: +) -> tuple[str, list[dict], list[str]]: """ Parse OpenAI-format messages into components the inference backend expects. @@ -1644,11 +1850,11 @@ def _extract_content_parts( Returns: system_prompt: The system message text (empty string if none provided). chat_messages: Non-system messages with content flattened to strings. - image_base64: Base64 data of the *first* image found, or ``None``. + image_base64s: Base64 data for image parts, in request order. """ system_prompt = "" chat_messages: list[dict] = [] - first_image_b64: Optional[str] = None + image_b64s: list[str] = [] for msg in messages: # ── System messages → extract as system_prompt ──────── @@ -1672,11 +1878,12 @@ def _extract_content_parts( for part in msg.content: if part.type == "text": text_parts.append(part.text) - elif part.type == "image_url" and first_image_b64 is None: + elif part.type == "image_url": url = part.image_url.url if url.startswith("data:"): # data:image/png;base64, → extract - first_image_b64 = url.split(",", 1)[1] if "," in url else None + if "," in url: + image_b64s.append(url.split(",", 1)[1]) else: logger.warning( f"Remote image URLs not yet supported: {url[:80]}..." @@ -1684,7 +1891,7 @@ def _extract_content_parts( combined_text = "\n".join(text_parts) if text_parts else "" chat_messages.append({"role": msg.role, "content": combined_text}) - return system_prompt, chat_messages, first_image_b64 + return system_prompt, chat_messages, image_b64s # ── External provider proxy ────────────────────────────────────── @@ -2149,9 +2356,23 @@ async def delete_openai_container( @router.post("/chat/completions") async def openai_chat_completions( - payload: ChatCompletionRequest, request: Request, current_subject: str = Depends(get_current_subject), +): + body = await _read_json_body_limited( + request, + max_bytes = _OPENAI_CHAT_BODY_MAX_BYTES, + ) + try: + payload = ChatCompletionRequest.model_validate(body) + except ValidationError as exc: + raise HTTPException(status_code = 422, detail = exc.errors()) from exc + return await _openai_chat_completions_impl(payload, request) + + +async def _openai_chat_completions_impl( + payload: ChatCompletionRequest, + request: Request, ): """ OpenAI-compatible chat completions endpoint. @@ -2406,7 +2627,7 @@ async def audio_input_stream(): ) # ── Parse messages (handles multimodal content parts) ───── - system_prompt, chat_messages, extracted_image_b64 = _extract_content_parts( + system_prompt, chat_messages, extracted_image_b64s = _extract_content_parts( payload.messages ) @@ -2710,7 +2931,7 @@ async def gguf_tool_stream(): def gguf_generate(): return llama_backend.generate_chat_completion( messages = gguf_messages, - image_b64 = image_b64, + image_b64s = image_b64s, temperature = payload.temperature, top_p = payload.top_p, top_k = payload.top_k, @@ -2879,7 +3100,9 @@ async def gguf_stream_chunks(): # ── Standard Unsloth path ───────────────────────────────── # Decode image (from content parts OR legacy field) - image_b64 = extracted_image_b64 or payload.image_base64 + image_b64 = ( + extracted_image_b64s[0] if extracted_image_b64s else payload.image_base64 + ) image = None if image_b64: @@ -3425,9 +3648,9 @@ async def serve_sandbox_file( # ── Path containment check ────────────────────────────────── home = os.path.expanduser("~") sandbox_root = os.path.realpath(os.path.join(home, "studio_sandbox")) - safe_session = os.path.basename(session_id.replace("..", "")) - if not safe_session: + if not _re.fullmatch(r"[A-Za-z0-9_-]+", session_id or ""): raise HTTPException(status_code = 404, detail = "Not found") + safe_session = session_id file_path = os.path.realpath( os.path.join(sandbox_root, safe_session, safe_filename) @@ -3516,7 +3739,9 @@ async def openai_completions( detail = "No GGUF model loaded. Load a GGUF model first.", ) - body = await request.json() + body = await _read_json_body_limited( + request, max_bytes = _OPENAI_PROXY_BODY_MAX_BYTES + ) target_url = f"{llama_backend.base_url}/v1/completions" is_stream = body.get("stream", False) @@ -3595,7 +3820,9 @@ async def openai_embeddings( detail = "No GGUF model loaded. Load a GGUF model first.", ) - body = await request.json() + body = await _read_json_body_limited( + request, max_bytes = _OPENAI_PROXY_BODY_MAX_BYTES + ) target_url = f"{llama_backend.base_url}/v1/embeddings" async with httpx.AsyncClient() as client: @@ -3894,7 +4121,7 @@ async def _responses_non_streaming( ) -> JSONResponse: """Handle a non-streaming Responses API call.""" chat_req = _build_chat_request(payload, messages, stream = False) - result = await openai_chat_completions(chat_req, request) + result = await _openai_chat_completions_impl(chat_req, request) # openai_chat_completions returns a JSONResponse for non-streaming if isinstance(result, JSONResponse): @@ -4410,45 +4637,11 @@ def _normalize_anthropic_openai_images( HTTPException(400) when images are present but the active model is not a vision model, or when an image cannot be decoded. """ - from PIL import Image - - has_image = False - for msg in openai_messages: - content = msg.get("content") - if not isinstance(content, list): - continue - for part in content: - if part.get("type") != "image_url": - continue - - has_image = True - if not is_vision: - raise HTTPException( - status_code = 400, - detail = "Image provided but current GGUF model does not support vision.", - ) - - url = (part.get("image_url") or {}).get("url", "") - if not url.startswith("data:"): - # Remote URLs are forwarded as-is; llama-server will - # fetch (or fail) per its own support matrix. - continue - - try: - _, b64data = url.split(",", 1) - raw = base64.b64decode(b64data) - img = Image.open(io.BytesIO(raw)).convert("RGB") - buf = io.BytesIO() - img.save(buf, format = "PNG") - png_b64 = base64.b64encode(buf.getvalue()).decode("ascii") - except Exception: - raise HTTPException( - status_code = 400, - detail = "Failed to process image.", - ) - part["image_url"] = {"url": f"data:image/png;base64,{png_b64}"} - - return has_image + return _normalize_openai_message_images( + openai_messages, + is_vision = is_vision, + not_vision_detail = "Image provided but current GGUF model does not support vision.", + ) @router.post("/messages") @@ -5271,7 +5464,7 @@ def _drop_empty_assistant_sentinels(messages: list[dict]) -> list[dict]: return out -def _openai_messages_for_passthrough(payload) -> list[dict]: +def _openai_messages_for_passthrough(payload, *, is_vision: bool = True) -> list[dict]: """Build OpenAI-format message dicts for the /v1/chat/completions passthrough path. @@ -5279,7 +5472,7 @@ def _openai_messages_for_passthrough(payload) -> list[dict]: unset optional fields) so they are already in standard OpenAI format — including ``role="tool"`` tool-result messages and assistant messages that carry structured ``tool_calls``. Content-parts images already in - the message list are left untouched. + the message list are counted, bounded, and data URLs are normalized to PNG. When a client uses Studio's legacy ``image_base64`` top-level field, the image is re-encoded to PNG (llama-server's stb_image has limited format @@ -5291,41 +5484,29 @@ def _openai_messages_for_passthrough(payload) -> list[dict]: [m.model_dump(exclude_none = True) for m in payload.messages] ) - if not payload.image_base64: - return messages - - try: - import base64 as _b64 - from io import BytesIO as _BytesIO - from PIL import Image as _Image - - raw = _b64.b64decode(payload.image_base64) - img = _Image.open(_BytesIO(raw)).convert("RGB") - buf = _BytesIO() - img.save(buf, format = "PNG") - png_b64 = _b64.b64encode(buf.getvalue()).decode("ascii") - except Exception: - raise HTTPException( - status_code = 400, - detail = "Failed to process image.", - ) + if payload.image_base64: + data_url = f"data:image/unknown;base64,{payload.image_base64}" + image_part = {"type": "image_url", "image_url": {"url": data_url}} - data_url = f"data:image/png;base64,{png_b64}" - image_part = {"type": "image_url", "image_url": {"url": data_url}} - - for msg in reversed(messages): - if msg.get("role") != "user": - continue - existing = msg.get("content") - if isinstance(existing, str): - msg["content"] = [{"type": "text", "text": existing}, image_part] - elif isinstance(existing, list): - existing.append(image_part) + for msg in reversed(messages): + if msg.get("role") != "user": + continue + existing = msg.get("content") + if isinstance(existing, str): + msg["content"] = [{"type": "text", "text": existing}, image_part] + elif isinstance(existing, list): + existing.append(image_part) + else: + msg["content"] = [image_part] + break else: - msg["content"] = [image_part] - break - else: - messages.append({"role": "user", "content": [image_part]}) + messages.append({"role": "user", "content": [image_part]}) + + _normalize_openai_message_images( + messages, + is_vision = is_vision, + not_vision_detail = "Image provided but current GGUF model does not support vision.", + ) return messages @@ -5385,14 +5566,16 @@ def _extract_response_format(payload): return rf if isinstance(rf, dict) else None -def _build_openai_passthrough_body(payload, backend_ctx = None) -> dict: +def _build_openai_passthrough_body( + payload, backend_ctx = None, *, is_vision: bool = True +) -> dict: """Assemble the llama-server request body from a ChatCompletionRequest. Only explicitly-known OpenAI / llama-server fields are forwarded so that Studio-specific extensions (``enable_tools``, ``enabled_tools``, ``session_id``, ...) never leak to the backend. """ - messages = _openai_messages_for_passthrough(payload) + messages = _openai_messages_for_passthrough(payload, is_vision = is_vision) tool_choice = payload.tool_choice if payload.tool_choice is not None else "auto" # When the caller asked for a specific reasoning mode, forward it to # llama-server via chat_template_kwargs so the Jinja template renders @@ -5437,7 +5620,9 @@ async def _openai_passthrough_stream( """ target_url = f"{llama_backend.base_url}/v1/chat/completions" body = _build_openai_passthrough_body( - payload, backend_ctx = llama_backend.context_length + payload, + backend_ctx = llama_backend.context_length, + is_vision = llama_backend.is_vision, ) _cancel_keys = (payload.cancel_id, payload.session_id, completion_id) @@ -5595,7 +5780,9 @@ async def _openai_passthrough_non_streaming( """ target_url = f"{llama_backend.base_url}/v1/chat/completions" body = _build_openai_passthrough_body( - payload, backend_ctx = llama_backend.context_length + payload, + backend_ctx = llama_backend.context_length, + is_vision = llama_backend.is_vision, ) try: @@ -5657,3 +5844,952 @@ async def _openai_passthrough_non_streaming( # verbatim (matches the docstring). Status is guaranteed 200 by # the check above. return Response(content = resp.content, media_type = "application/json") + + +# ---------------------------------------------------------------------- # +# Chat document extraction (PyMuPDF4LLM + optional VLM image description)# +# ---------------------------------------------------------------------- # + +try: + from core.chat import ( + DOCUMENT_EXTRACTION_AVAILABLE as _DOCUMENT_EXTRACTION_AVAILABLE, + DEFAULT_DOCUMENT_VISUAL_PAYLOADS as _DEFAULT_DOCUMENT_VISUAL_PAYLOADS, + DocumentExtractionBusy as _DocumentExtractionBusy, + DocumentExtractionCancelled as _DocumentExtractionCancelled, + DocumentExtractionEncrypted as _DocumentExtractionEncrypted, + DocumentExtractionTimeout as _DocumentExtractionTimeout, + DocumentExtractionUnavailable as _DocumentExtractionUnavailable, + _EXTRACT_CONCURRENCY as _DOCUMENT_EXTRACT_CONCURRENCY, + MAX_DOCUMENT_VISUAL_PAYLOADS as _MAX_DOCUMENT_VISUAL_PAYLOADS, + SUPPORTED_MIME_TYPES as _DOC_MIME_OK, + SUPPORTED_SUFFIXES as _DOC_SUFFIX_OK, + VlmCapability as _VlmCapability, + _EXTRACT_SEMAPHORE, + _drain_future_exception as _drain_doc_future_exception, + detect_loaded_vlm as _detect_loaded_vlm, + document_parser_support as _document_parser_support, + document_parser_unavailable_reasons as _document_parser_unavailable_reasons, + extract_document as _extract_document, + extract_self_base_url as _extract_self_base_url, + ) +except ImportError: # pragma: no cover - package always installed alongside + _DOCUMENT_EXTRACTION_AVAILABLE = False + _DEFAULT_DOCUMENT_VISUAL_PAYLOADS = 0 + _DOCUMENT_EXTRACT_CONCURRENCY = 1 + _MAX_DOCUMENT_VISUAL_PAYLOADS = 0 + _DOC_MIME_OK = frozenset() + _DOC_SUFFIX_OK = frozenset() + _detect_loaded_vlm = None # type: ignore[assignment] + _extract_document = None # type: ignore[assignment] + _extract_self_base_url = None # type: ignore[assignment] + _document_parser_support = lambda: {} # type: ignore[assignment] + _document_parser_unavailable_reasons = lambda: {} # type: ignore[assignment] + _VlmCapability = None # type: ignore[assignment] + _drain_doc_future_exception = lambda _f: None # type: ignore[assignment] + + class _DocumentExtractionUnavailable(RuntimeError): # type: ignore[no-redef] + pass + + class _DocumentExtractionTimeout(RuntimeError): # type: ignore[no-redef] + pass + + class _DocumentExtractionBusy(RuntimeError): # type: ignore[no-redef] + pass + + class _DocumentExtractionCancelled(RuntimeError): # type: ignore[no-redef] + pass + + class _DocumentExtractionEncrypted(RuntimeError): # type: ignore[no-redef] + pass + + _EXTRACT_SEMAPHORE = threading.BoundedSemaphore(1) + + +_EXTRACT_MAX_BYTES = 100 * 1024 * 1024 +_EXTRACT_MULTIPART_OVERHEAD_BYTES = 1024 * 1024 +_EXTRACT_READ_CHUNK_BYTES = 64 * 1024 +_EXTRACT_MAX_PAGES_INLINE = 200 +_EXTRACT_TOKEN_BUDGET_DEFAULT = 8000 +_EXTRACT_TOKEN_BUDGET_MIN = 0 + +_DOCX_MIME = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" +_HTML_MIME_TYPES = {"text/html"} +_DATA_MIME_TYPES = { + "application/json", + "application/x-ndjson", + "application/xml", + "application/yaml", + "text/csv", + "text/xml", + "text/yaml", +} +_CODE_MIME_TYPES = { + "application/javascript", + "text/css", + "text/javascript", +} +_DATA_SUFFIXES = {".csv", ".json", ".jsonl", ".yaml", ".yml", ".xml"} +_CODE_SUFFIXES = { + ".py", + ".js", + ".jsx", + ".ts", + ".tsx", + ".go", + ".rs", + ".java", + ".c", + ".cpp", + ".h", + ".hpp", + ".cs", + ".php", + ".rb", + ".swift", + ".kt", + ".kts", + ".scala", + ".sh", + ".bash", + ".zsh", + ".ps1", + ".sql", + ".toml", + ".ini", + ".cfg", + ".css", + ".scss", +} + + +async def _wait_for_document_request_disconnect( + fastapi_request: Request, + cancel_event: threading.Event, +) -> bool: + while not cancel_event.is_set(): + if await fastapi_request.is_disconnected(): + cancel_event.set() + return True + await asyncio.sleep(0.2) + return False + + +def _extract_ext(filename: str) -> str: + return os.path.splitext(filename or "")[1].lower() + + +def _is_supported_upload(filename: str, content_type: str) -> bool: + if (content_type or "").split(";")[0].strip().lower() in _DOC_MIME_OK: + return True + return _extract_ext(filename) in _DOC_SUFFIX_OK + + +def _document_upload_format(filename: str, content_type: str) -> Optional[str]: + mime = (content_type or "").split(";")[0].strip().lower() + ext = _extract_ext(filename) + if mime == "application/pdf" or ext == ".pdf": + return "pdf" + if mime == _DOCX_MIME or ext == ".docx": + return "docx" + if mime in _HTML_MIME_TYPES or ext in {".html", ".htm"}: + return "html" + if mime in _DATA_MIME_TYPES or ext in _DATA_SUFFIXES: + return "data" + if mime in _CODE_MIME_TYPES or ext in _CODE_SUFFIXES: + return "code" + if mime.startswith("text/") or ext in {".md", ".txt", ".log"}: + return "text" + return None + + +def _raise_if_document_parser_unavailable( + filename: str, + content_type: str, +) -> None: + format_key = _document_upload_format(filename, content_type) + if format_key is None: + return + support = _document_parser_support() + if support.get(format_key, True): + return + reason = _document_parser_unavailable_reasons().get( + format_key, + f"{format_key.upper()} extraction is not available on this server.", + ) + raise HTTPException(status_code = 501, detail = reason) + + +def _document_caption_authorization_header( + capability: Any, + llama_backend: Any, + studio_authorization_header: Optional[str], +) -> Optional[str]: + if getattr(capability, "source", None) != "gguf": + return studio_authorization_header + api_key = getattr(llama_backend, "api_key", None) or getattr( + llama_backend, "_api_key", None + ) + return f"Bearer {api_key}" if api_key else None + + +_FORM_TRUE = {"1", "true", "yes", "on"} +_FORM_FALSE = {"0", "false", "no", "off"} + + +def _parse_bool_form(value: Any, *, default: bool, field: str = "value") -> bool: + if value is None: + return default + norm = str(value).strip().lower() + if not norm: + return default + if norm in _FORM_TRUE: + return True + if norm in _FORM_FALSE: + return False + raise HTTPException( + status_code = 400, + detail = f"Invalid boolean value for {field}: {value!r}", + ) + + +def _parse_int_form( + value: Any, + *, + default: int, + lo: int, + hi: Optional[int] = None, +) -> int: + try: + parsed = int(value) if value is not None else default + except (TypeError, ValueError): + parsed = default + parsed = max(lo, parsed) + return min(parsed, hi) if hi is not None else parsed + + +def _reject_oversized_content_length(request: Request) -> None: + raw = request.headers.get("content-length") + if raw is None: + return + try: + total = int(raw) + except ValueError: + raise HTTPException( + status_code = 400, + detail = "Invalid Content-Length header", + ) + max_request_bytes = _EXTRACT_MAX_BYTES + _EXTRACT_MULTIPART_OVERHEAD_BYTES + if total > max_request_bytes: + raise HTTPException( + status_code = 413, + detail = ( + f"Request exceeds the {_EXTRACT_MAX_BYTES // (1024*1024)} MB " + "file limit" + ), + ) + + +async def _iter_request_body_limited(request: Request, *, max_bytes: int): + total = 0 + async for chunk in request.stream(): + if not chunk: + continue + total += len(chunk) + if total > max_bytes: + raise HTTPException( + status_code = 413, + detail = ( + f"Request exceeds the {_EXTRACT_MAX_BYTES // (1024*1024)} MB " + "file limit" + ), + ) + yield chunk + + +async def _read_multipart_form_limited(request: Request, *, max_bytes: int): + from starlette.formparsers import MultiPartException, MultiPartParser + + try: + parser = MultiPartParser( + request.headers, + _iter_request_body_limited(request, max_bytes = max_bytes), + ) + return await parser.parse() + except HTTPException: + raise + except MultiPartException as exc: + raise HTTPException(status_code = 400, detail = exc.message) from exc + + +# Cap on /completions and /embeddings JSON bodies. Those proxy payloads should +# be small (a few prompts + sampling params); 10 MB is generous headroom while +# still protecting against unbounded buffering when a client sends a falsified +# Content-Length and streams a much larger body. +_OPENAI_PROXY_BODY_MAX_BYTES = 10 * 1024 * 1024 +# Chat-completions also carries multimodal data URLs. Keep it bounded, but +# large enough that document extraction's visual-payload budget reaches the +# existing per-image guards instead of being rejected by the JSON body reader +# first. +_OPENAI_CHAT_BODY_IMAGE_SLOTS = max( + 1, + min( + _OPENAI_CHAT_MAX_IMAGES, + _MAX_DOCUMENT_VISUAL_PAYLOADS or _DEFAULT_DOCUMENT_VISUAL_PAYLOADS or 1, + ), +) +_OPENAI_CHAT_BODY_MAX_BYTES = max( + 32 * 1024 * 1024, + (_OPENAI_CHAT_MAX_IMAGE_BASE64_CHARS * _OPENAI_CHAT_BODY_IMAGE_SLOTS) + + (2 * 1024 * 1024), +) + + +async def _read_json_body_limited(request: Request, *, max_bytes: int) -> Any: + """Stream the request body, enforce a hard byte cap, then parse as JSON. + + Unlike trusting Content-Length, this aborts mid-stream once the cap is + exceeded so a spoofed header cannot force the server to buffer arbitrary + payloads before parsing. + """ + total = 0 + chunks: list[bytes] = [] + async for chunk in request.stream(): + if not chunk: + continue + total += len(chunk) + if total > max_bytes: + raise HTTPException( + status_code = 413, + detail = f"Request body exceeds the {max_bytes // (1024 * 1024)} MB limit", + ) + chunks.append(chunk) + raw = b"".join(chunks) + try: + return json.loads(raw) if raw else {} + except json.JSONDecodeError as exc: + raise HTTPException(status_code = 400, detail = f"Invalid JSON body: {exc.msg}") + + +async def _read_upload_limited(upload: Any, *, max_bytes: int) -> bytes: + buf = bytearray() + while True: + chunk = await upload.read(_EXTRACT_READ_CHUNK_BYTES) + if not chunk: + break + buf.extend(chunk) + if len(buf) > max_bytes: + raise HTTPException( + status_code = 413, + detail = f"File exceeds the {max_bytes // (1024*1024)} MB limit", + ) + return bytes(buf) + + +def _is_pdf_upload(filename: str, content_type: str) -> bool: + mime = (content_type or "").split(";")[0].strip().lower() + return mime == "application/pdf" or _extract_ext(filename) == ".pdf" + + +def _preflight_pdf_page_count( + file_bytes: bytes, + filename: str, + content_type: str, +) -> Optional[int]: + if not _is_pdf_upload(filename, content_type): + return None + + pypdf_error: Optional[BaseException] = None + try: + from pypdf import PdfReader + + reader = PdfReader(io.BytesIO(file_bytes), strict = False) + # Many PDFs report ``is_encrypted=True`` even though they only use a + # null/empty user password and open fine (Acrobat-distilled docs, + # the classic Orimi test PDF, scanner output). Try the empty + # password before refusing; PyMuPDF's ``needs_pass`` is the real + # signal in the fallback branch below. + if getattr(reader, "is_encrypted", False): + try: + if reader.decrypt("") == 0: + raise HTTPException( + status_code = 422, + detail = "Encrypted PDFs are not supported for inline extraction", + ) + except HTTPException: + raise + except Exception: + # ``decrypt`` itself failed (corrupt /Encrypt dict, unknown + # algorithm). Fall through to the PyMuPDF fallback rather + # than declaring the file encrypted. + raise RuntimeError("pypdf decrypt probe failed") + return len(reader.pages) + except HTTPException: + raise + except Exception as exc: + pypdf_error = exc + logger.warning( + "pypdf page-count preflight failed for %s; trying PyMuPDF fallback", + filename, + ) + + try: + import pymupdf as _pymupdf # type: ignore + + doc = _pymupdf.open(stream = file_bytes, filetype = "pdf") + try: + # PyMuPDF's ``needs_pass`` is True only when an actual password + # is required. ``is_encrypted`` is True for any file with an + # /Encrypt dict, which includes the common null-password case + # that opens fine. Refuse only when a password is actually + # needed. + if getattr(doc, "needs_pass", False): + raise HTTPException( + status_code = 422, + detail = "Encrypted PDFs are not supported for inline extraction", + ) + return len(doc) + finally: + doc.close() + except HTTPException: + raise + except Exception as exc: + if pypdf_error is not None: + logger.warning( + "PyMuPDF page-count fallback also failed for %s: %s", + filename, + exc, + ) + else: + logger.exception("PDF page-count preflight failed for %s", filename) + raise HTTPException( + status_code = 400, + detail = "Unable to read PDF page count before extraction", + ) from exc + + +def _truncate_markdown_to_token_budget( + markdown: str, + *, + token_budget: int, + original_tokens_est: int, +) -> tuple[str, int, Optional[str]]: + char_budget = max(_EXTRACT_TOKEN_BUDGET_MIN, token_budget) * 4 + if len(markdown) <= char_budget: + return markdown, original_tokens_est, None + + clipped = markdown[:char_budget] + clipped = ( + _re.sub(r"\s+\S*$", "", clipped).rstrip() or markdown[:char_budget].rstrip() + ) + clipped += f"\n\n[... truncated; original was ~{original_tokens_est} tokens ...]" + warning = ( + f"Extracted markdown was truncated to {token_budget} tokens " + f"(original was ~{original_tokens_est} tokens)." + ) + return clipped, max(0, len(clipped) // 4), warning + + +@studio_router.get("/chat/document-support", response_model = DocumentSupportResponse) +async def document_support_endpoint( + fastapi_request: Request, + current_subject: str = Depends(get_current_subject), +): + """Whether document extraction + per-figure captions are available. + + Polled by the frontend when the settings panel mounts and when the + loaded model changes. The response drives the "describe figures" + toggle: when ``vlm.is_vlm`` is false the UI disables the toggle and + surfaces ``vlm.reason`` as tooltip text. + """ + if _extract_document is None or _detect_loaded_vlm is None: + return DocumentSupportResponse( + extraction_available = False, + max_visual_payloads = 0, + max_extract_concurrency = 1, + format_support = {}, + unavailable_formats = {}, + vlm = { + "is_vlm": False, + "endpoint_url": None, + "model_name": None, + "source": "none", + "reason": "document extraction backend is not installed", + }, + ) + + self_base_url = ( + _extract_self_base_url(fastapi_request) if _extract_self_base_url else None + ) + try: + cap = _detect_loaded_vlm( + self_base_url, + llama_backend = get_llama_cpp_backend(), + ) + except Exception as exc: + logger.exception("Document support VLM probe failed") + if _VlmCapability is not None: + cap = _VlmCapability.none( + f"document support probe failed: {type(exc).__name__}" + ) + else: # pragma: no cover - only when core.chat import fallback is active + cap = None + return DocumentSupportResponse( + extraction_available = True, + max_visual_payloads = _MAX_DOCUMENT_VISUAL_PAYLOADS, + max_extract_concurrency = _DOCUMENT_EXTRACT_CONCURRENCY, + format_support = _document_parser_support(), + unavailable_formats = _document_parser_unavailable_reasons(), + vlm = cap.to_dict() + if cap is not None + else { + "is_vlm": False, + "endpoint_url": None, + "model_name": None, + "source": "none", + "reason": "document support probe failed", + }, + ) + + +@studio_router.post("/chat/extract-document") +async def extract_document_endpoint( + fastapi_request: Request, + current_subject: str = Depends(get_current_subject), +): + """Upload a PDF / DOCX / HTML / MD / text file and stream + progress events plus a final layout-aware Markdown payload. + + Response is NDJSON (one JSON object per line). Validation errors + raised before streaming begins return as standard HTTP 4xx/5xx. + Once the stream starts, the final line is `{"stage":"result", ...}` + or `{"stage":"error", ...}`. Large documents (>200 pages) are + rejected with 413 until the background-job path lands. + """ + if _extract_document is None: + raise HTTPException( + status_code = 501, + detail = ( + "document extraction backend is not installed. Re-run Studio " + "setup to install the parser dependencies." + ), + ) + + _reject_oversized_content_length(fastapi_request) + + try: + try: + form = await _read_multipart_form_limited( + fastapi_request, + max_bytes = _EXTRACT_MAX_BYTES + _EXTRACT_MULTIPART_OVERHEAD_BYTES, + ) + except HTTPException: + raise + except Exception as exc: + logger.exception("Invalid multipart document extraction payload") + raise HTTPException(status_code = 400, detail = "Invalid multipart payload") + + upload = form.get("file") + if upload is None or not hasattr(upload, "read"): + raise HTTPException(status_code = 400, detail = "Missing 'file' field") + + filename = getattr(upload, "filename", None) or "upload" + content_type = getattr(upload, "content_type", "") or "" + if not _is_supported_upload(filename, content_type): + raise HTTPException( + status_code = 415, + detail = f"Unsupported file type: {filename} ({content_type})", + ) + _raise_if_document_parser_unavailable(filename, content_type) + + file_bytes = await _read_upload_limited(upload, max_bytes = _EXTRACT_MAX_BYTES) + if not file_bytes: + raise HTTPException(status_code = 400, detail = "Uploaded file is empty") + + preflight_page_count = _preflight_pdf_page_count( + file_bytes, filename, content_type + ) + if ( + preflight_page_count is not None + and preflight_page_count > _EXTRACT_MAX_PAGES_INLINE + ): + raise HTTPException( + status_code = 413, + detail = ( + f"Document has {preflight_page_count} pages; inline extraction " + f"is capped at {_EXTRACT_MAX_PAGES_INLINE}. Split into smaller " + f"documents or reduce the page range." + ), + ) + + describe_images = _parse_bool_form( + form.get("describe_images"), default = False, field = "describe_images" + ) + use_vlm_ocr = _parse_bool_form( + form.get("use_vlm_ocr"), default = False, field = "use_vlm_ocr" + ) + max_figures = _parse_int_form( + form.get("max_figures"), + default = 40, + lo = 0, + ) + max_visual_payloads = _parse_int_form( + form.get("max_visual_payloads"), + default = _DEFAULT_DOCUMENT_VISUAL_PAYLOADS, + lo = 0, + ) + token_budget = _parse_int_form( + form.get("token_budget"), + default = _EXTRACT_TOKEN_BUDGET_DEFAULT, + lo = 0, + ) + + self_base_url = ( + _extract_self_base_url(fastapi_request) if _extract_self_base_url else None + ) + llama_backend = get_llama_cpp_backend() + capability = ( + _detect_loaded_vlm( + self_base_url, + llama_backend = llama_backend, + ) + if _detect_loaded_vlm + else None + ) + caption_authorization_header = _document_caption_authorization_header( + capability, + llama_backend, + fastapi_request.headers.get("authorization"), + ) + + if await fastapi_request.is_disconnected(): + raise HTTPException(status_code = 499, detail = "Client closed request") + + accept_header = (fastapi_request.headers.get("accept", "") or "").lower() + wants_stream = "application/x-ndjson" in accept_header + + def _build_response_payload(result: Any) -> ExtractDocumentResponse: + markdown_, tokens_est_, truncate_warning_ = ( + _truncate_markdown_to_token_budget( + result.markdown, + token_budget = token_budget, + original_tokens_est = result.tokens_est, + ) + ) + warnings_ = list(result.warnings) + if truncate_warning_: + warnings_.append(truncate_warning_) + return ExtractDocumentResponse( + filename = filename, + markdown = markdown_, + page_count = result.page_count, + tokens_est = tokens_est_, + truncated = truncate_warning_ is not None, + figures = [ExtractedFigureModel(**_asdict(f)) for f in result.figures], + describe_skipped_reason = result.describe_skipped_reason, + vlm_source = result.vlm_source, + vlm_model = result.vlm_model, + image_input_available = getattr(result, "image_input_available", False), + warnings = warnings_, + ) + + if not wants_stream: + # ---- Legacy JSON path (no progress events) ----------------- + cancel_event = threading.Event() + extraction_task = asyncio.create_task( + _extract_document( + file_bytes, + filename, + content_type = content_type, + describe_images = describe_images, + use_vlm_ocr = use_vlm_ocr, + max_figures = max_figures, + max_visual_payloads = max_visual_payloads, + capability = capability, + self_base_url = self_base_url, + authorization_header = caption_authorization_header, + cancel_event = cancel_event, + ) + ) + disconnect_task = asyncio.create_task( + _wait_for_document_request_disconnect(fastapi_request, cancel_event) + ) + try: + done, _pending = await asyncio.wait( + {extraction_task, disconnect_task}, + return_when = asyncio.FIRST_COMPLETED, + ) + if extraction_task in done: + result = await extraction_task + elif disconnect_task in done and disconnect_task.result(): + cancel_event.set() + with suppress( + _DocumentExtractionCancelled, + asyncio.CancelledError, + asyncio.TimeoutError, + ): + await asyncio.wait_for( + asyncio.shield(extraction_task), timeout = 10 + ) + if not extraction_task.done(): + extraction_task.cancel() + raise _DocumentExtractionCancelled( + "document extraction was cancelled" + ) + else: + result = await extraction_task + except _DocumentExtractionUnavailable as exc: + raise HTTPException(status_code = 501, detail = str(exc)) + except _DocumentExtractionTimeout: + raise HTTPException( + status_code = 504, + detail = "Document parsing timed out after 120s before image captioning", + ) + except _DocumentExtractionBusy: + raise HTTPException( + status_code = 503, detail = "Document extraction is busy" + ) + except _DocumentExtractionCancelled: + raise HTTPException(status_code = 499, detail = "Client closed request") + except _DocumentExtractionEncrypted as exc: + raise HTTPException(status_code = 422, detail = str(exc)) + except ValueError as exc: + detail = str(exc) + status_code = ( + 415 if detail.lower().startswith("unsupported file type") else 400 + ) + raise HTTPException(status_code = status_code, detail = detail) + except Exception: + logger.exception("Document extraction failed for %s", filename) + raise HTTPException(status_code = 500, detail = "Extraction failed") + finally: + cancel_event.set() + disconnect_task.cancel() + with suppress(asyncio.CancelledError): + await disconnect_task + + if result.page_count > _EXTRACT_MAX_PAGES_INLINE: + raise HTTPException( + status_code = 413, + detail = ( + f"Document has {result.page_count} pages; inline extraction " + f"is capped at {_EXTRACT_MAX_PAGES_INLINE}. Split into smaller " + f"documents or reduce the page range." + ), + ) + return _build_response_payload(result) + + # ---- Streaming NDJSON path (Accept: application/x-ndjson) ------ + progress_queue: asyncio.Queue = asyncio.Queue() + + async def _progress_cb(event: dict) -> None: + await progress_queue.put(dict(event)) + + async def _ndjson_stream(): + cancel_event = threading.Event() + extraction_task = asyncio.create_task( + _extract_document( + file_bytes, + filename, + content_type = content_type, + describe_images = describe_images, + use_vlm_ocr = use_vlm_ocr, + max_figures = max_figures, + max_visual_payloads = max_visual_payloads, + capability = capability, + self_base_url = self_base_url, + authorization_header = caption_authorization_header, + cancel_event = cancel_event, + progress_cb = _progress_cb, + ) + ) + # Always drain the task's exception so a busy/cancel race + # doesn't leave an orphan "Future exception was never retrieved" + # in the logs when the body iterator exits early. + extraction_task.add_done_callback(_drain_doc_future_exception) + disconnect_task = asyncio.create_task( + _wait_for_document_request_disconnect(fastapi_request, cancel_event) + ) + try: + extract_wait = asyncio.ensure_future(asyncio.shield(extraction_task)) + extract_wait.add_done_callback(_drain_doc_future_exception) + while True: + queue_get = asyncio.ensure_future(progress_queue.get()) + queue_get.add_done_callback(_drain_doc_future_exception) + done, _pending = await asyncio.wait( + {queue_get, extract_wait, disconnect_task}, + return_when = asyncio.FIRST_COMPLETED, + ) + if queue_get in done: + event = queue_get.result() + yield json.dumps(event) + "\n" + else: + queue_get.cancel() + with suppress(asyncio.CancelledError): + await queue_get + + if disconnect_task in done and disconnect_task.result(): + cancel_event.set() + with suppress( + _DocumentExtractionCancelled, + asyncio.CancelledError, + asyncio.TimeoutError, + ): + await asyncio.wait_for( + asyncio.shield(extraction_task), timeout = 10 + ) + if not extraction_task.done(): + extraction_task.cancel() + raise _DocumentExtractionCancelled( + "document extraction was cancelled" + ) + + # The shield-wrapper may complete (cancelled) before + # the underlying extraction_task is done; calling + # ``.result()`` in that window raises + # InvalidStateError. Wait for the real task before + # consuming its result. + if extraction_task.done(): + # Drain any remaining progress events before result. + while not progress_queue.empty(): + try: + event = progress_queue.get_nowait() + except asyncio.QueueEmpty: + break + yield json.dumps(event) + "\n" + result = extraction_task.result() + break + if extract_wait in done: + # Shield-wrapper finished but the real task is + # still running. Re-arm the wait on a fresh + # shielded future and loop. + extract_wait = asyncio.ensure_future( + asyncio.shield(extraction_task) + ) + extract_wait.add_done_callback( + _drain_doc_future_exception + ) + + if result.page_count > _EXTRACT_MAX_PAGES_INLINE: + yield ( + json.dumps( + { + "stage": "error", + "status_code": 413, + "detail": ( + f"Document has {result.page_count} pages; inline extraction " + f"is capped at {_EXTRACT_MAX_PAGES_INLINE}. Split into smaller " + f"documents or reduce the page range." + ), + } + ) + + "\n" + ) + return + + response = _build_response_payload(result) + yield ( + json.dumps( + { + "stage": "result", + "data": response.model_dump(mode = "json"), + } + ) + + "\n" + ) + except _DocumentExtractionUnavailable as exc: + yield ( + json.dumps( + { + "stage": "error", + "status_code": 501, + "detail": str(exc), + } + ) + + "\n" + ) + except _DocumentExtractionTimeout: + yield ( + json.dumps( + { + "stage": "error", + "status_code": 504, + "detail": "Document parsing timed out after 120s before image captioning", + } + ) + + "\n" + ) + except _DocumentExtractionBusy: + yield ( + json.dumps( + { + "stage": "error", + "status_code": 503, + "detail": "Document extraction is busy", + } + ) + + "\n" + ) + except _DocumentExtractionCancelled: + yield ( + json.dumps( + { + "stage": "error", + "status_code": 499, + "detail": "Client closed request", + } + ) + + "\n" + ) + except _DocumentExtractionEncrypted as exc: + yield ( + json.dumps( + { + "stage": "error", + "status_code": 422, + "detail": str(exc), + } + ) + + "\n" + ) + except ValueError as exc: + detail = str(exc) + status_code = ( + 415 if detail.lower().startswith("unsupported file type") else 400 + ) + yield ( + json.dumps( + { + "stage": "error", + "status_code": status_code, + "detail": detail, + } + ) + + "\n" + ) + except Exception: + logger.exception("Document extraction failed for %s", filename) + yield ( + json.dumps( + { + "stage": "error", + "status_code": 500, + "detail": "Extraction failed", + } + ) + + "\n" + ) + finally: + cancel_event.set() + disconnect_task.cancel() + with suppress(asyncio.CancelledError): + await disconnect_task + + return StreamingResponse( + _ndjson_stream(), + media_type = "application/x-ndjson", + ) + finally: + # _EXTRACT_SEMAPHORE is owned solely by _run_extract_process_sync; the + # worker maps a busy semaphore to DocumentExtractionBusy → an in-stream + # error event above. + pass diff --git a/studio/backend/routes/models.py b/studio/backend/routes/models.py index 9ea113e488..826da462cf 100644 --- a/studio/backend/routes/models.py +++ b/studio/backend/routes/models.py @@ -12,7 +12,8 @@ import sys import uuid from pathlib import Path -from fastapi import APIRouter, Body, Depends, HTTPException, Query +from fastapi import APIRouter, Body, Depends, HTTPException, Query, Request +from pydantic import BaseModel, Field from typing import List, Optional import structlog from loggers import get_logger @@ -139,6 +140,24 @@ def _safe_is_dir(path) -> bool: logger = get_logger(__name__) +class ModelProbeRequest(BaseModel): + model_name: str = Field(..., description = "Model identifier or local path") + hf_token: Optional[str] = Field( + None, description = "HuggingFace token for gated/private models" + ) + trust_remote_code: bool = Field( + False, description = "Allow probes that require custom model code" + ) + + +def _reject_hf_token_query(request: Request) -> None: + if "hf_token" in request.query_params: + raise HTTPException( + status_code = 400, + detail = "HF tokens must be sent with POST JSON probe endpoints, not GET query parameters.", + ) + + def derive_model_type( is_vision: bool, audio_type: Optional[str], is_embedding: bool = False ) -> ModelType: @@ -152,6 +171,40 @@ def derive_model_type( return "text" +def _defaults_vision_flags(config_dict: dict) -> tuple[bool, bool]: + model_config = config_dict.get("model", {}) if isinstance(config_dict, dict) else {} + inference_config = ( + config_dict.get("inference", {}) if isinstance(config_dict, dict) else {} + ) + yaml_is_vision = bool(model_config.get("is_vision", False)) + yaml_requires_trust_remote_code = bool( + model_config.get("trust_remote_code", False) + or inference_config.get("trust_remote_code", False) + ) + return yaml_is_vision, yaml_requires_trust_remote_code + + +def _detect_vision_for_config_endpoint( + model_name: str, + *, + hf_token: Optional[str] = None, + trust_remote_code: bool = False, + config_dict: Optional[dict] = None, +) -> bool: + defaults = ( + config_dict if config_dict is not None else load_model_defaults(model_name) + ) + yaml_is_vision, yaml_requires_trust_remote_code = _defaults_vision_flags(defaults) + if yaml_is_vision and yaml_requires_trust_remote_code: + return True + detected = is_vision_model( + model_name, + hf_token = hf_token, + trust_remote_code = trust_remote_code, + ) + return detected + + def _resolve_hf_cache_dir() -> Path: """Resolve local HF cache root used by hub downloads.""" try: @@ -1479,7 +1532,7 @@ async def list_models( loaded_models.append(model_info) # Include active GGUF model (loaded via llama-server) - from routes.inference import get_llama_cpp_backend + from core.inference.llama_cpp import get_llama_cpp_backend llama_backend = get_llama_cpp_backend() if llama_backend.is_loaded and llama_backend.model_identifier: @@ -1562,9 +1615,35 @@ def _get_model_size_bytes( @router.get("/config/{model_name:path}") async def get_model_config( + request: Request, model_name: str, - hf_token: Optional[str] = Query(None), + trust_remote_code: bool = False, current_subject: str = Depends(get_current_subject), +): + _reject_hf_token_query(request) + return await _build_model_config_response( + model_name, + hf_token = None, + trust_remote_code = trust_remote_code, + ) + + +@router.post("/config") +async def post_model_config( + request: ModelProbeRequest, + current_subject: str = Depends(get_current_subject), +): + return await _build_model_config_response( + request.model_name, + hf_token = request.hf_token, + trust_remote_code = request.trust_remote_code, + ) + + +async def _build_model_config_response( + model_name: str, + hf_token: Optional[str] = None, + trust_remote_code: bool = False, ): """ Get configuration for a specific model. @@ -1589,7 +1668,12 @@ async def get_model_config( config_dict = load_model_defaults(model_name) # Detect model capabilities (pass HF token for gated models) - is_vision = is_vision_model(model_name, hf_token = hf_token) + is_vision = _detect_vision_for_config_endpoint( + model_name, + hf_token = hf_token, + trust_remote_code = trust_remote_code, + config_dict = config_dict, + ) is_embedding = is_embedding_model(model_name, hf_token = hf_token) audio_type = detect_audio_type(model_name, hf_token = hf_token) @@ -1598,7 +1682,11 @@ async def get_model_config( base_model = None max_position_embeddings = None try: - model_config = ModelConfig.from_identifier(model_name) + model_config = ModelConfig.from_identifier( + model_name, + hf_token = hf_token, + trust_remote_code = trust_remote_code, + ) is_lora = model_config.is_lora base_model = model_config.base_model if is_lora else None max_position_embeddings = _get_max_position_embeddings(model_config) @@ -2068,8 +2156,35 @@ async def get_lora_base_model( @router.get("/check-vision/{model_name:path}", response_model = VisionCheckResponse) async def check_vision_model( + request: Request, model_name: str, + trust_remote_code: bool = False, current_subject: str = Depends(get_current_subject), +): + _reject_hf_token_query(request) + return await _check_vision_model_response( + model_name, + hf_token = None, + trust_remote_code = trust_remote_code, + ) + + +@router.post("/check-vision", response_model = VisionCheckResponse) +async def post_check_vision_model( + request: ModelProbeRequest, + current_subject: str = Depends(get_current_subject), +): + return await _check_vision_model_response( + request.model_name, + hf_token = request.hf_token, + trust_remote_code = request.trust_remote_code, + ) + + +async def _check_vision_model_response( + model_name: str, + hf_token: Optional[str] = None, + trust_remote_code: bool = False, ): """ Check if a model is a vision model. @@ -2078,7 +2193,11 @@ async def check_vision_model( """ try: logger.info(f"Checking if vision model: {model_name}") - is_vision = is_vision_model(model_name) + is_vision = _detect_vision_for_config_endpoint( + model_name, + hf_token = hf_token, + trust_remote_code = trust_remote_code, + ) logger.info(f"Vision check result for {model_name}: is_vision={is_vision}") return VisionCheckResponse( @@ -2603,7 +2722,7 @@ async def delete_cached_model( # Check if model is currently loaded try: - from routes.inference import get_llama_cpp_backend + from core.inference.llama_cpp import get_llama_cpp_backend llama_backend = get_llama_cpp_backend() if llama_backend.is_loaded and llama_backend.model_identifier: diff --git a/studio/backend/run.py b/studio/backend/run.py index 3bde8abd3c..5e5da55858 100644 --- a/studio/backend/run.py +++ b/studio/backend/run.py @@ -494,11 +494,15 @@ def _graceful_shutdown(server = None): logger.warning("Error shutting down training subprocess: %s", e) # 5. Kill llama-server subprocess (if loaded) + # + # Read the module-level singleton directly so we don't instantiate a + # fresh backend during shutdown when none was ever loaded. try: - from routes.inference import _llama_cpp_backend + from core.inference import llama_cpp as _llama_cpp_mod - if _llama_cpp_backend is not None: - _llama_cpp_backend._kill_process() + backend = getattr(_llama_cpp_mod, "_llama_cpp_backend", None) + if backend is not None: + backend._kill_process() except Exception as e: logger.warning("Error shutting down llama-server: %s", e) diff --git a/studio/backend/tests/test_anthropic_messages.py b/studio/backend/tests/test_anthropic_messages.py index 842429d5af..7f0cf5d56a 100644 --- a/studio/backend/tests/test_anthropic_messages.py +++ b/studio/backend/tests/test_anthropic_messages.py @@ -34,6 +34,7 @@ AnthropicStreamEmitter, AnthropicPassthroughEmitter, ) +import routes.inference as route from routes.inference import ( _normalize_anthropic_openai_images, _select_anthropic_server_tools, @@ -1056,6 +1057,24 @@ def test_bad_base64_raises_400(self): _normalize_anthropic_openai_images(msgs, is_vision = True) assert exc.value.status_code == 400 + def test_image_count_limit_applies(self, monkeypatch): + monkeypatch.setattr(route, "_OPENAI_CHAT_MAX_IMAGES", 1) + data_url = _jpeg_data_url() + msgs = [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": data_url}}, + {"type": "image_url", "image_url": {"url": data_url}}, + ], + } + ] + + with pytest.raises(HTTPException) as exc: + _normalize_anthropic_openai_images(msgs, is_vision = True) + + assert exc.value.status_code == 413 + # ===================================================================== # Studio-tool alias detection (/v1/messages tool routing) diff --git a/studio/backend/tests/test_chat_document_extraction.py b/studio/backend/tests/test_chat_document_extraction.py new file mode 100644 index 0000000000..3d89883952 --- /dev/null +++ b/studio/backend/tests/test_chat_document_extraction.py @@ -0,0 +1,906 @@ +# SPDX-License-Identifier: AGPL-3.0-only +# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 + +""" +Tests for the chat document extractor + VLM capability probe. + +Probe tests run regardless of the extraction backend because they only +shape-check :mod:`core.chat.vlm_capability`. Backend-backed tests skip +cleanly when the optional deps (pymupdf / pymupdf4llm / mammoth) are +missing. +""" + +from __future__ import annotations + +import importlib.util +import sys +from types import ModuleType, SimpleNamespace +from typing import Any, Dict, Optional + +import pytest + +from core.chat.vlm_capability import ( + VlmCapability, + detect_loaded_vlm, + extract_self_base_url, +) + + +# ---------------------------------------------------------------------- # +# VlmCapability dataclass # +# ---------------------------------------------------------------------- # + + +def test_vlm_capability_none_factory_is_safe_default() -> None: + cap = VlmCapability.none() + assert cap.is_vlm is False + assert cap.endpoint_url is None + assert cap.model_name is None + assert cap.source == "none" + assert cap.reason # non-empty + + +def test_vlm_capability_to_dict_round_trips_fields() -> None: + cap = VlmCapability( + is_vlm = True, + endpoint_url = "http://127.0.0.1:8080", + model_name = "qwen2-vl", + source = "gguf", + reason = None, + ) + assert cap.to_dict() == { + "is_vlm": True, + "endpoint_url": "http://127.0.0.1:8080", + "model_name": "qwen2-vl", + "source": "gguf", + "reason": None, + } + + +# ---------------------------------------------------------------------- # +# detect_loaded_vlm() across backend shapes # +# ---------------------------------------------------------------------- # + + +class _FakeLlama: + def __init__( + self, + *, + loaded: bool, + vision: bool = False, + base_url: str = "http://127.0.0.1:8080", + model_id: str = "fake-gguf", + ) -> None: + self.is_loaded = loaded + self.is_vision = vision + self.base_url = base_url + self.model_identifier = model_id + + +class _FakeInferenceBackend: + def __init__( + self, + *, + active: Optional[str], + info: Optional[Dict[str, Any]] = None, + ) -> None: + self.active_model_name = active + self.models: Dict[str, Dict[str, Any]] = {active: info or {}} if active else {} + + +def _patch_probes( + monkeypatch: pytest.MonkeyPatch, + *, + llama: Optional[_FakeLlama], + inference: Optional[_FakeInferenceBackend], +) -> None: + from core.chat import vlm_capability as vc + + if llama is None: + monkeypatch.setattr(vc, "_probe_gguf", lambda _llama = None: None) + else: + + def probe_gguf(llama_backend = None): + backend = llama_backend or llama + if not backend.is_loaded: + return None + is_vision = bool(backend.is_vision) + return VlmCapability( + is_vlm = is_vision, + endpoint_url = backend.base_url, + model_name = backend.model_identifier, + source = "gguf", + reason = None if is_vision else "loaded GGUF is not vision-capable", + ) + + monkeypatch.setattr(vc, "_probe_gguf", probe_gguf) + + if inference is None: + monkeypatch.setattr(vc, "_probe_transformers", lambda _u: None) + else: + + def probe_tf(self_base_url): + name = inference.active_model_name + if not name: + return None + info = inference.models.get(name) or {} + is_vision = bool(info.get("is_vision", False)) + source = "unsloth" if info.get("is_lora") else "transformers" + if not self_base_url: + return VlmCapability( + is_vlm = False, + endpoint_url = None, + model_name = name, + source = source, + reason = "cannot self-loopback: request base URL unavailable", + ) + return VlmCapability( + is_vlm = is_vision, + endpoint_url = self_base_url.rstrip("/"), + model_name = name, + source = source, + reason = None if is_vision else "loaded model is not vision-capable", + ) + + monkeypatch.setattr(vc, "_probe_transformers", probe_tf) + + +def test_detect_returns_none_when_no_model_loaded( + monkeypatch: pytest.MonkeyPatch, +) -> None: + _patch_probes(monkeypatch, llama = None, inference = None) + cap = detect_loaded_vlm() + assert cap.source == "none" + assert cap.is_vlm is False + + +def test_detect_gguf_vision_returns_llama_endpoint( + monkeypatch: pytest.MonkeyPatch, +) -> None: + llama = _FakeLlama(loaded = True, vision = True, base_url = "http://127.0.0.1:9999") + _patch_probes(monkeypatch, llama = llama, inference = None) + cap = detect_loaded_vlm("http://studio.local") + assert cap.source == "gguf" + assert cap.is_vlm is True + assert cap.endpoint_url == "http://127.0.0.1:9999" # GGUF ignores self_base_url + assert cap.reason is None + + +def test_detect_gguf_vision_accepts_injected_backend( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from core.chat import vlm_capability as vc + + llama = _FakeLlama(loaded = True, vision = True, base_url = "http://127.0.0.1:9999") + monkeypatch.setattr(vc, "_probe_transformers", lambda _u: None) + + cap = detect_loaded_vlm( + "http://127.0.0.1:8000", + llama_backend = llama, + ) + + assert cap.source == "gguf" + assert cap.is_vlm is True + assert cap.endpoint_url == "http://127.0.0.1:9999" + + +def test_detect_gguf_vision_uses_core_llama_accessor( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """The implicit GGUF fallback must use the core-owned singleton path.""" + from core.chat import vlm_capability as vc + from core.inference import llama_cpp + + llama = _FakeLlama(loaded = True, vision = True, base_url = "http://127.0.0.1:9999") + assert hasattr(llama_cpp, "get_llama_cpp_backend") + monkeypatch.setattr(llama_cpp, "_llama_cpp_backend", llama) + monkeypatch.setattr(vc, "_probe_transformers", lambda _u: None) + + cap = detect_loaded_vlm("http://127.0.0.1:8000") + + assert cap.source == "gguf" + assert cap.is_vlm is True + assert cap.endpoint_url == "http://127.0.0.1:9999" + + +def test_detect_gguf_non_vision_surfaces_reason( + monkeypatch: pytest.MonkeyPatch, +) -> None: + llama = _FakeLlama(loaded = True, vision = False) + _patch_probes(monkeypatch, llama = llama, inference = None) + cap = detect_loaded_vlm() + assert cap.source == "gguf" + assert cap.is_vlm is False + assert cap.reason and "vision" in cap.reason.lower() + + +def test_detect_transformers_vision_uses_self_loopback( + monkeypatch: pytest.MonkeyPatch, +) -> None: + ib = _FakeInferenceBackend( + active = "Qwen2-VL-7B", + info = {"is_vision": True, "is_lora": False}, + ) + _patch_probes(monkeypatch, llama = None, inference = ib) + cap = detect_loaded_vlm("http://127.0.0.1:8000/") + assert cap.source == "transformers" + assert cap.is_vlm is True + assert cap.endpoint_url == "http://127.0.0.1:8000" + assert cap.model_name == "Qwen2-VL-7B" + + +def test_detect_unsloth_lora_vision_reports_unsloth_source( + monkeypatch: pytest.MonkeyPatch, +) -> None: + ib = _FakeInferenceBackend( + active = "my-qwen-vl-lora", + info = {"is_vision": True, "is_lora": True}, + ) + _patch_probes(monkeypatch, llama = None, inference = ib) + cap = detect_loaded_vlm("http://studio.local:8000") + assert cap.source == "unsloth" + assert cap.is_vlm is True + + +def test_detect_falls_through_when_gguf_is_loaded_but_endpoint_data_missing( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A half-initialised llama-server (is_loaded=True but base_url/model + missing) must not suppress the transformers fallback path — otherwise + a misleading non-vision GGUF result hides an active transformers VLM. + """ + from core.chat import vlm_capability as vc + + fake_llama_cpp = ModuleType("core.inference.llama_cpp") + fake_llama_cpp.get_llama_cpp_backend = lambda: _FakeLlama( + loaded = True, + base_url = "", + model_id = "", + ) + fake_inference = ModuleType("core.inference") + fake_inference.__path__ = [] # type: ignore[attr-defined] + fake_inference.llama_cpp = fake_llama_cpp # type: ignore[attr-defined] + monkeypatch.setitem(sys.modules, "core.inference", fake_inference) + monkeypatch.setitem(sys.modules, "core.inference.llama_cpp", fake_llama_cpp) + + ib = _FakeInferenceBackend( + active = "Qwen2-VL-7B", + info = {"is_vision": True, "is_lora": False}, + ) + monkeypatch.setattr( + vc, + "_probe_transformers", + lambda self_base_url: VlmCapability( + is_vlm = True, + endpoint_url = self_base_url.rstrip("/") if self_base_url else None, + model_name = ib.active_model_name, + source = "transformers", + reason = None, + ), + ) + + cap = detect_loaded_vlm("http://127.0.0.1:8000") + assert cap.source == "transformers" + assert cap.is_vlm is True + + +def test_detect_transformers_without_self_url_reports_missing_loopback( + monkeypatch: pytest.MonkeyPatch, +) -> None: + ib = _FakeInferenceBackend( + active = "Qwen2-VL-7B", + info = {"is_vision": True, "is_lora": False}, + ) + _patch_probes(monkeypatch, llama = None, inference = ib) + cap = detect_loaded_vlm(None) + assert cap.is_vlm is False + assert cap.reason and "loopback" in cap.reason.lower() + + +# ---------------------------------------------------------------------- # +# extract_self_base_url — request base-URL extraction # +# ---------------------------------------------------------------------- # + + +class _FakeState: + def __init__(self, server_port: Optional[int] = None) -> None: + if server_port is not None: + self.server_port = server_port + + +class _FakeApp: + def __init__(self, server_port: Optional[int] = None) -> None: + self.state = _FakeState(server_port) + + +class _FakeRequest: + def __init__( + self, + base_url: str, + *, + server_port: Optional[int] = None, + scope_server: Optional[tuple[str, int]] = None, + ) -> None: + self.base_url = base_url + self.app = _FakeApp(server_port) + self.scope = {"server": scope_server} if scope_server else {} + + +def test_extract_self_base_url_strips_trailing_slash() -> None: + assert ( + extract_self_base_url(_FakeRequest("http://127.0.0.1:8000/")) + == "http://127.0.0.1:8000" + ) + + +def test_extract_self_base_url_prefers_trusted_server_port() -> None: + assert ( + extract_self_base_url( + _FakeRequest( + "http://attacker.invalid:9999/", + server_port = 7777, + scope_server = ("127.0.0.1", 6666), + ) + ) + == "http://127.0.0.1:7777" + ) + assert ( + extract_self_base_url( + _FakeRequest( + "http://attacker.invalid:9999/", + scope_server = ("127.0.0.1", 6666), + ) + ) + == "http://127.0.0.1:6666" + ) + + +def test_extract_self_base_url_ignores_host_header() -> None: + assert ( + extract_self_base_url(_FakeRequest("http://studio.local:8000/")) + == "http://127.0.0.1:8000" + ) + assert ( + extract_self_base_url(_FakeRequest("https://example.com:9443/")) + == "http://127.0.0.1:9443" + ) + + +def test_extract_self_base_url_none_when_empty() -> None: + assert extract_self_base_url(_FakeRequest("")) is None + + +def test_extract_self_base_url_none_on_missing_attribute() -> None: + assert extract_self_base_url(object()) is None + + +# ---------------------------------------------------------------------- # +# extract_document orchestration — backend-agnostic (monkey-patched) # +# ---------------------------------------------------------------------- # + + +@pytest.mark.asyncio +async def test_max_figures_zero_sets_describe_skipped_reason( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """max_figures=0 must skip description with a specific diagnostic even + when a VLM is available.""" + from core.chat import document_extractor as de + + def fake_extract(_fb, _fn, _opts, _ct = ""): + return "# Smoke\n", [], 1, 0, 0 + + monkeypatch.setattr(de, "DOCUMENT_EXTRACTION_AVAILABLE", True) + monkeypatch.setattr(de, "_run_extract_sync", fake_extract) + + result = await de.extract_document( + b"# Smoke\n", + "sample.md", + describe_images = True, + max_figures = 0, + capability = VlmCapability( + is_vlm = True, + endpoint_url = "http://127.0.0.1:8000", + model_name = "vlm", + source = "transformers", + ), + ) + + assert result.describe_skipped_reason == ( + "figure description disabled because max_figures is 0" + ) + assert result.markdown == "# Smoke\n" + assert result.figures == [] + + +@pytest.mark.asyncio +async def test_run_extract_sync_seam_receives_content_type( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """The test seam path (monkeypatched _run_extract_sync) must be invoked + with the content_type so dispatch-by-content-type can be exercised in + tests, not only by filename suffix.""" + from core.chat import document_extractor as de + + received: dict[str, str] = {} + + def fake_extract(_fb, _fn, _opts, ct = ""): + received["content_type"] = ct + return "ok", [], 0, 0, 0 + + monkeypatch.setattr(de, "DOCUMENT_EXTRACTION_AVAILABLE", True) + monkeypatch.setattr(de, "_run_extract_sync", fake_extract) + + await de.extract_document( + b"hello", + "no-suffix-file", + content_type = "text/plain", + describe_images = False, + ) + assert received["content_type"] == "text/plain" + + +@pytest.mark.asyncio +async def test_describe_image_via_vlm_sends_auth_header_and_max_tokens( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from core.chat import document_extractor as de + + captured: dict[str, Any] = {} + + class FakeResponse: + status_code = 200 + + def json(self): + return {"choices": [{"message": {"content": "A chart."}}]} + + class FakeAsyncClient: + def __init__(self, *, timeout: float) -> None: + captured["timeout"] = timeout + + async def __aenter__(self): + return self + + async def __aexit__(self, *_args): + return None + + async def post(self, url, *, headers, json): + captured["url"] = url + captured["headers"] = headers + captured["json"] = json + return FakeResponse() + + fake_httpx = ModuleType("httpx") + fake_httpx.AsyncClient = FakeAsyncClient + monkeypatch.setitem(sys.modules, "httpx", fake_httpx) + + caption, error = await de._describe_image_via_vlm( + image_base64 = "abc", + image_mime = "image/jpeg", + endpoint_url = "http://127.0.0.1:8000", + model_name = "vlm", + authorization_header = "Bearer token", + timeout_seconds = 7, + ) + + assert caption == "A chart." + assert error is None + assert captured["url"] == "http://127.0.0.1:8000/v1/chat/completions" + assert captured["headers"]["Authorization"] == "Bearer token" + assert captured["json"]["max_tokens"] == 512 + assert "max_completion_tokens" not in captured["json"] + + +# ---------------------------------------------------------------------- # +# Backend dispatch — real _run_extract_sync (requires pymupdf/mammoth) # +# ---------------------------------------------------------------------- # + + +_BACKEND_INSTALLED = ( + importlib.util.find_spec("pymupdf") is not None + and importlib.util.find_spec("pymupdf4llm") is not None + and importlib.util.find_spec("mammoth") is not None +) + + +def test_run_extract_sync_rejects_pptx_with_value_error() -> None: + """PPTX was dropped in the PyMuPDF4LLM migration. _run_extract_sync + must raise ValueError so the route can map it to HTTP 415.""" + if not _BACKEND_INSTALLED: + pytest.skip("extraction backend not installed") + from core.chat import document_extractor as de + + with pytest.raises(ValueError): + de._run_extract_sync( + b"PK\x03\x04", + "deck.pptx", + {"max_figures": 0, "extract_images": False, "use_vlm_ocr": False}, + ) + + +def test_run_extract_sync_text_path_decodes_utf8() -> None: + """TXT / MD paths must not require PDF/DOCX parser dependencies.""" + from core.chat import document_extractor as de + + md, figs, pages, trunc, seen = de._run_extract_sync( + "# Héllo\n".encode("utf-8"), + "notes.md", + {"max_figures": 0, "extract_images": False, "use_vlm_ocr": False}, + ) + assert md == "# Héllo\n" + assert figs == [] + assert pages == 0 and trunc == 0 and seen == 0 + + +def test_run_extract_sync_html_converts_to_markdown_without_parser_deps() -> None: + """HTML must be cleaned before prompt injection and not depend on PDF/DOCX deps.""" + from core.chat import document_extractor as de + + md, figs, pages, trunc, seen = de._run_extract_sync( + b"

Title

Hello world

", + "page.html", + {"max_figures": 0, "extract_images": False, "use_vlm_ocr": False}, + ) + assert "# Title" in md + assert "**world**" in md + assert "\"\n" + " b\"

hello

\")\n" + "out, *_rest = mod._extract_html(dirty)\n" + "import json\n" + "print(json.dumps({'out': out}))\n" + ) + proc = _run_subprocess(body) + assert proc.returncode == 0, proc.stderr + + import json + + parsed = json.loads(proc.stdout.strip().splitlines()[-1]) + out = parsed["out"] + # Pre-fix this returns the raw HTML because the fallback branch + # in _extract_html swallows the ImportError. + assert "alert" not in out, ( + f" survived into the prompt; raw output:\n{out}" + ) + assert " bytes: + """Mint a tiny PDF with an empty user password (mirrors what + Orimi's test file and many distiller pipelines produce).""" + pymupdf = pytest.importorskip("pymupdf") + doc = pymupdf.open() + page = doc.new_page() + page.insert_text( + (72, 100), + "pseudo-encrypted PDF: null user password, opens without prompt", + fontsize=12, + ) + out = doc.tobytes( + encryption=pymupdf.PDF_ENCRYPT_AES_256, + owner_pw="owner-pw", + user_pw="", + ) + doc.close() + return out + + +def test_extract_pdf_accepts_null_password(monkeypatch): + """The extractor must not raise DocumentExtractionEncrypted for a + PDF whose user password is the empty string. PyMuPDF's + ``needs_pass`` is the canonical signal; ``is_encrypted`` is too + aggressive.""" + from core.chat import document_extractor as mod + + file_bytes = _make_pseudo_encrypted_pdf() + + md, figures, page_count, truncated, seen = mod._extract_pdf( + file_bytes, + max_figures=0, + use_vlm_ocr=False, + max_visual_payloads=0, + ) + + assert page_count == 1 + assert "pseudo-encrypted PDF" in md + assert figures == [] + + +def test_preflight_pdf_page_count_accepts_null_password(): + """The pre-extraction preflight at + ``routes.inference._preflight_pdf_page_count`` must accept + null-password PDFs.""" + from routes.inference import _preflight_pdf_page_count + + file_bytes = _make_pseudo_encrypted_pdf() + n = _preflight_pdf_page_count( + file_bytes, + filename="pseudo_encrypted.pdf", + content_type="application/pdf", + ) + assert n == 1 + + +def test_extract_pdf_still_rejects_password_required(monkeypatch): + """Sanity-check the other direction: a PDF that actually requires + a non-empty user password must still raise + DocumentExtractionEncrypted.""" + pymupdf = pytest.importorskip("pymupdf") + doc = pymupdf.open() + page = doc.new_page() + page.insert_text((72, 100), "this one needs a password", fontsize=12) + encrypted = doc.tobytes( + encryption=pymupdf.PDF_ENCRYPT_AES_256, + owner_pw="owner", + user_pw="real-password", + ) + doc.close() + + from core.chat import document_extractor as mod + + with pytest.raises(mod.DocumentExtractionEncrypted): + mod._extract_pdf( + encrypted, + max_figures=0, + use_vlm_ocr=False, + max_visual_payloads=0, + ) diff --git a/tests/studio/test_stream_cancel_registration_timing.py b/tests/studio/test_stream_cancel_registration_timing.py index 40ec3d6e1f..d70ab8ee42 100644 --- a/tests/studio/test_stream_cancel_registration_timing.py +++ b/tests/studio/test_stream_cancel_registration_timing.py @@ -121,11 +121,19 @@ def test_no_tracker_enter_inside_async_generators(): def test_tracker_enter_exists_in_sync_body_of_chat_completions(): + # The handler `openai_chat_completions` is a thin wrapper around + # `_openai_chat_completions_impl`, where the streaming bodies (and + # therefore the tracker registration) live after the document- + # extractor refactor. Accept tracker-__enter__ calls that appear in + # either function so the structural guarantee survives the wrapper. top = None for n in ast.walk(_TREE): - if isinstance(n, ast.AsyncFunctionDef) and n.name == "openai_chat_completions": - top = n - break + if isinstance(n, ast.AsyncFunctionDef) and n.name in { + "openai_chat_completions", + "_openai_chat_completions_impl", + }: + if top is None or n.name == "_openai_chat_completions_impl": + top = n assert top is not None, "openai_chat_completions handler missing" count = 0 for sub in ast.walk(top): @@ -171,11 +179,17 @@ def test_async_generators_cleanup_tracker_in_finally(): def test_streaming_responses_have_no_background_task(): + # The streaming bodies live in `_openai_chat_completions_impl` after + # the document-extractor refactor; the public handler is a thin + # wrapper. Walk the impl so this guard does not vacuously pass. top = None for n in ast.walk(_TREE): - if isinstance(n, ast.AsyncFunctionDef) and n.name == "openai_chat_completions": - top = n - break + if isinstance(n, ast.AsyncFunctionDef) and n.name in { + "openai_chat_completions", + "_openai_chat_completions_impl", + }: + if top is None or n.name == "_openai_chat_completions_impl": + top = n assert top is not None for sub in ast.walk(top): if not (isinstance(sub, ast.Call) and isinstance(sub.func, ast.Name)): @@ -482,12 +496,19 @@ def test_stream_chunks_cancel_branch_resets_backend_state(): # internal cancel path does not do this, so a cancel-via-POST that # only broke the loop would leave the subprocess in a dirty state # for the next request. + # `stream_chunks` is now nested inside `_openai_chat_completions_impl` + # (the implementation function the thin `openai_chat_completions` + # wrapper delegates to). Search either function so the test survives + # the document-extractor refactor. fn = None top = None for n in ast.walk(_TREE): - if isinstance(n, ast.AsyncFunctionDef) and n.name == "openai_chat_completions": - top = n - break + if isinstance(n, ast.AsyncFunctionDef) and n.name in { + "openai_chat_completions", + "_openai_chat_completions_impl", + }: + if top is None or n.name == "_openai_chat_completions_impl": + top = n assert top is not None for n in ast.walk(top): if isinstance(n, ast.AsyncFunctionDef) and n.name == "stream_chunks":