diff --git a/Makefile b/Makefile index 801c59cc281a37..6921ccaaf025ae 100644 --- a/Makefile +++ b/Makefile @@ -80,6 +80,7 @@ lint: @uv run --project api --dev ruff check --fix ./api @$(MAKE) api-contract-lint @uv run --directory api --dev lint-imports + @$(MAKE) api-import-baseline-lint @uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example @echo "✅ Linting complete" @@ -88,6 +89,11 @@ api-contract-lint: @uv run --project api --dev python api/dev/lint_response_contracts.py @echo "✅ Response contract lint complete" +api-import-baseline-lint: + @echo "🏗️ Checking import-linter baseline..." + @uv run --project api --dev python scripts/lint_imports_baseline.py --baseline import_linter_baseline.json + @echo "✅ Import baseline lint complete" + type-check: @echo "📝 Running type checks (pyrefly + mypy)..." @./dev/pyrefly-check-local $(PATH_TO_CHECK) diff --git a/api/tests/unit_tests/commands/test_lint_imports_baseline.py b/api/tests/unit_tests/commands/test_lint_imports_baseline.py new file mode 100644 index 00000000000000..a961a085db1643 --- /dev/null +++ b/api/tests/unit_tests/commands/test_lint_imports_baseline.py @@ -0,0 +1,291 @@ +import importlib.util +import json +import sys +from dataclasses import dataclass +from pathlib import Path +from types import SimpleNamespace + +import pytest +from pydantic import ValidationError + + +def _load_lint_imports_baseline_module(): + repo_root = Path(__file__).parents[4] + script_path = repo_root / "scripts" / "lint_imports_baseline.py" + spec = importlib.util.spec_from_file_location("lint_imports_baseline", script_path) + assert spec is not None + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +@dataclass(frozen=True) +class _FakeContract: + name: str + + +@dataclass(frozen=True) +class _ProtectedImportGroup: + top_level_module: str + illegal_links: list[dict[str, object]] + original_expression: str | None = None + + +class _FakeReport: + def __init__(self, entries: list[tuple[_FakeContract, SimpleNamespace]]) -> None: + self._entries = entries + + def get_contracts_and_checks(self): + return iter(self._entries) + + +def test_snapshot_from_report_collects_unique_direct_imports_from_nested_metadata(): + module = _load_lint_imports_baseline_module() + report = _FakeReport( + [ + ( + _FakeContract("layers"), + SimpleNamespace( + kept=False, + metadata={ + "invalid_dependencies": [ + { + "routes": [ + { + "chain": [ + { + "importer": "controllers.apps.list", + "imported": "services.apps", + "line_numbers": (3,), + }, + { + "importer": "services.apps", + "imported": "core.apps", + "line_numbers": (4,), + }, + ], + "extra_firsts": [], + "extra_lasts": [], + }, + { + "chain": [ + { + "importer": "controllers.apps.list", + "imported": "services.apps", + "line_numbers": (9,), + } + ], + "extra_firsts": [], + "extra_lasts": [], + }, + ] + } + ] + }, + ), + ), + ( + _FakeContract("protected"), + SimpleNamespace( + kept=False, + metadata={ + "illegal_imports": [ + _ProtectedImportGroup( + top_level_module="extensions.secret", + illegal_links=[ + { + "importer": "controllers.apps.list", + "imported": "extensions.secret", + "line_numbers": (12,), + } + ], + ) + ] + }, + ), + ), + ] + ) + + assert module.snapshot_from_report(report) == { + "layers": { + "controllers.apps.list": ["services.apps"], + "services.apps": ["core.apps"], + }, + "protected": { + "controllers.apps.list": ["extensions.secret"], + }, + } + + +def test_compare_snapshots_reports_new_direct_imports_in_subset_mode(): + module = _load_lint_imports_baseline_module() + + failures = module.compare_snapshots( + current_snapshot={ + "layers": { + "controllers.apps.list": ["services.apps", "services.billing"], + } + }, + baseline_snapshot={ + "layers": { + "controllers.apps.list": ["services.apps"], + } + }, + comparison="subset", + ) + + assert len(failures) == 1 + assert failures[0].contract_name == "layers" + assert failures[0].importer == "controllers.apps.list" + assert failures[0].extra_imports == ("services.billing",) + assert failures[0].baseline_count == 1 + assert failures[0].current_count == 2 + + +def test_compare_snapshots_count_mode_rejects_only_growth(): + module = _load_lint_imports_baseline_module() + + failures = module.compare_snapshots( + current_snapshot={ + "layers": { + "controllers.apps.list": ["services.apps", "services.billing", "services.audit"], + } + }, + baseline_snapshot={ + "layers": { + "controllers.apps.list": ["services.apps", "services.billing"], + } + }, + comparison="count", + ) + + assert len(failures) == 1 + assert failures[0].current_count == 3 + assert failures[0].baseline_count == 2 + + +def test_main_writes_baseline_snapshot(tmp_path: Path, monkeypatch): + module = _load_lint_imports_baseline_module() + baseline_path = tmp_path / "import-baseline.json" + + monkeypatch.setattr( + module, + "load_report", + lambda **_: _FakeReport( + [ + ( + _FakeContract("layers"), + SimpleNamespace( + kept=False, + metadata={ + "invalid_dependencies": [ + { + "routes": [ + { + "chain": [ + { + "importer": "controllers.apps.list", + "imported": "services.apps", + "line_numbers": (3,), + } + ], + "extra_firsts": [], + "extra_lasts": [], + } + ] + } + ] + }, + ), + ) + ] + ), + ) + + assert module.main(["--baseline", str(baseline_path), "--write-baseline"]) == 0 + assert json.loads(baseline_path.read_text(encoding="utf-8")) == { + "version": 1, + "contracts": { + "layers": { + "controllers.apps.list": ["services.apps"], + } + }, + } + + +def test_main_fails_on_replacement_violation_in_default_subset_mode(tmp_path: Path, monkeypatch, capsys): + module = _load_lint_imports_baseline_module() + baseline_path = tmp_path / "import-baseline.json" + baseline_path.write_text( + json.dumps( + { + "version": 1, + "contracts": { + "layers": { + "controllers.apps.list": ["services.apps"], + } + }, + } + ), + encoding="utf-8", + ) + + monkeypatch.setattr( + module, + "load_report", + lambda **_: _FakeReport( + [ + ( + _FakeContract("layers"), + SimpleNamespace( + kept=False, + metadata={ + "invalid_dependencies": [ + { + "routes": [ + { + "chain": [ + { + "importer": "controllers.apps.list", + "imported": "services.billing", + "line_numbers": (8,), + } + ], + "extra_firsts": [], + "extra_lasts": [], + } + ] + } + ] + }, + ), + ) + ] + ), + ) + + assert module.main(["--baseline", str(baseline_path)]) == 1 + output = capsys.readouterr().out + assert "controllers.apps.list" in output + assert "services.billing" in output + + +def test_load_baseline_rejects_unexpected_top_level_fields(tmp_path: Path): + module = _load_lint_imports_baseline_module() + baseline_path = tmp_path / "import-baseline.json" + baseline_path.write_text( + json.dumps( + { + "version": 1, + "contracts": {}, + "unexpected": True, + } + ), + encoding="utf-8", + ) + + with pytest.raises(ValidationError, match="Extra inputs are not permitted"): + module.load_baseline(baseline_path) diff --git a/import_linter_baseline.json b/import_linter_baseline.json new file mode 100644 index 00000000000000..9f9dfddefae718 --- /dev/null +++ b/import_linter_baseline.json @@ -0,0 +1,4 @@ +{ + "contracts": {}, + "version": 1 +} diff --git a/scripts/lint_imports_baseline.py b/scripts/lint_imports_baseline.py new file mode 100644 index 00000000000000..c882a566997c13 --- /dev/null +++ b/scripts/lint_imports_baseline.py @@ -0,0 +1,284 @@ +"""Gate import-linter violations against a committed baseline snapshot. + +This wrapper keeps import-linter as the source of truth for architectural +contracts, then snapshots the broken direct-import edges per contract and +importer module. The default comparison mode is ``subset`` because it prevents +same-count replacements from silently regressing the architecture. A weaker +``count`` mode is also available when a team explicitly wants count-only gating. +""" + +from __future__ import annotations + +import argparse +import dataclasses +from collections.abc import Iterator +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Literal, NewType +import sys + +from importlinter import configuration +from importlinter.application import use_cases +from pydantic import BaseModel, ConfigDict + +type BaselineVersion = Literal[1] +type ComparisonMode = Literal["subset", "count"] +ContractName = NewType("ContractName", str) +ModuleName = NewType("ModuleName", str) +type ModulesByImporter = dict[ModuleName, list[ModuleName]] +type MutableModulesByImporter = dict[ModuleName, set[ModuleName]] +type BaselineSnapshot = dict[ContractName, ModulesByImporter] +type ImportEdge = tuple[ModuleName, ModuleName] + + +class BaselinePayload(BaseModel): + """Serialized baseline file payload.""" + + version: BaselineVersion = 1 + contracts: BaselineSnapshot + model_config = ConfigDict(extra="forbid") + + +REPO_ROOT = Path(__file__).resolve().parents[1] +API_DIR = REPO_ROOT / "api" +DEFAULT_CONFIG_PATH = API_DIR / ".importlinter" + + +@dataclass(frozen=True) +class SnapshotFailure: + contract_name: ContractName + importer: ModuleName + baseline_count: int + current_count: int + extra_imports: tuple[ModuleName, ...] + + +def load_report(config_path: str | None = None, contract_ids: tuple[str, ...] = ()) -> Any: + """Build and return an import-linter report using the same path setup as the CLI.""" + + configuration.configure() + api_dir = str(API_DIR) + if api_dir not in sys.path: + sys.path.insert(0, api_dir) + + resolved_config_path = config_path or str(DEFAULT_CONFIG_PATH) + user_options = use_cases.read_user_options(config_filename=resolved_config_path) + return use_cases.create_report( + user_options=user_options, + limit_to_contracts=contract_ids, + ) + + +def snapshot_from_report(report: Any) -> BaselineSnapshot: + """Return broken direct-import edges grouped by contract and importer module.""" + + snapshot: BaselineSnapshot = {} + for contract, check in report.get_contracts_and_checks(): + if check.kept: + continue + + imports_by_importer: MutableModulesByImporter = {} + for importer, imported in _iter_direct_imports(check.metadata): + imports_by_importer.setdefault(importer, set()).add(imported) + + if check.metadata and not imports_by_importer: + raise ValueError(f"Broken contract '{contract.name}' does not expose direct import edges in metadata.") + + if imports_by_importer: + snapshot[ContractName(contract.name)] = { + importer: sorted(imported_modules) for importer, imported_modules in sorted(imports_by_importer.items()) + } + + return {contract_name: snapshot[contract_name] for contract_name in sorted(snapshot)} + + +def normalize_snapshot(snapshot: BaselineSnapshot) -> BaselineSnapshot: + """Return a stable snapshot with sorted keys and deduplicated imported modules.""" + + normalized_snapshot: BaselineSnapshot = {} + for contract_name, importers in snapshot.items(): + normalized_importers: ModulesByImporter = {} + for importer, imported_modules in importers.items(): + normalized_importers[importer] = sorted(set(imported_modules)) + normalized_snapshot[contract_name] = normalized_importers + + return {contract_name: normalized_snapshot[contract_name] for contract_name in sorted(normalized_snapshot)} + + +def compare_snapshots( + current_snapshot: BaselineSnapshot, + baseline_snapshot: BaselineSnapshot, + comparison: ComparisonMode = "subset", +) -> list[SnapshotFailure]: + """Compare the current and baseline snapshots and return any regressions.""" + + failures: list[SnapshotFailure] = [] + + contract_names = sorted(set(current_snapshot) | set(baseline_snapshot)) + for contract_name in contract_names: + current_by_importer = current_snapshot.get(contract_name, {}) + baseline_by_importer = baseline_snapshot.get(contract_name, {}) + + for importer in sorted(set(current_by_importer) | set(baseline_by_importer)): + current_imports = set(current_by_importer.get(importer, [])) + baseline_imports = set(baseline_by_importer.get(importer, [])) + extra_imports = tuple(sorted(current_imports - baseline_imports)) + + if comparison == "subset": + is_failure = bool(extra_imports) + else: + is_failure = len(current_imports) > len(baseline_imports) + + if is_failure: + failures.append( + SnapshotFailure( + contract_name=contract_name, + importer=importer, + baseline_count=len(baseline_imports), + current_count=len(current_imports), + extra_imports=extra_imports, + ) + ) + + return failures + + +def load_baseline(path: Path) -> BaselineSnapshot: + """Load and validate a baseline file.""" + + payload = BaselinePayload.model_validate_json(path.read_text(encoding="utf-8")) + return normalize_snapshot(payload.contracts) + + +def write_baseline(path: Path, snapshot: BaselineSnapshot) -> None: + """Persist the supplied snapshot as a JSON baseline file.""" + + payload = BaselinePayload(contracts=normalize_snapshot(snapshot)) + path.write_text(payload.model_dump_json(indent=2) + "\n", encoding="utf-8") + + +def main(argv: list[str] | None = None) -> int: + parser = build_argument_parser() + args = parser.parse_args(argv) + + baseline_path = args.baseline + current_snapshot = snapshot_from_report(load_report(config_path=args.config, contract_ids=tuple(args.contract))) + + if args.write_baseline: + write_baseline(baseline_path, current_snapshot) + _write_line(f"Wrote import baseline to {baseline_path}.") + return 0 + + baseline_snapshot = load_baseline(baseline_path) + failures = compare_snapshots( + current_snapshot=current_snapshot, + baseline_snapshot=baseline_snapshot, + comparison=args.comparison, + ) + if failures: + _print_failures(failures, comparison=args.comparison) + return 1 + + _write_line( + "Import baseline OK. " + f"Checked {sum(len(importers) for importers in current_snapshot.values())} importer entries." + ) + return 0 + + +def build_argument_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Compare current import-linter violations against a committed baseline." + ) + parser.add_argument( + "--baseline", + type=Path, + required=True, + help="Path to the committed baseline JSON file.", + ) + parser.add_argument( + "--write-baseline", + action="store_true", + help="Write the current violation snapshot to the baseline file and exit.", + ) + parser.add_argument( + "--comparison", + choices=("subset", "count"), + default="subset", + help="Comparison strategy. 'subset' is stricter and is the default.", + ) + parser.add_argument( + "--config", + help="Optional import-linter config file path.", + ) + parser.add_argument( + "--contract", + action="append", + default=[], + help="Optional contract id filter. May be passed multiple times.", + ) + return parser + + +def _iter_direct_imports(node: object, seen: set[int] | None = None) -> Iterator[ImportEdge]: + if seen is None: + seen = set() + + if node is None or isinstance(node, (str, int, float, bool)): + return + + if not isinstance(node, type) and dataclasses.is_dataclass(node): + for field in dataclasses.fields(node): + yield from _iter_direct_imports(getattr(node, field.name), seen) + return + + if isinstance(node, dict): + marker = id(node) + if marker in seen: + return + seen.add(marker) + + importer = node.get("importer") + imported = node.get("imported") + if isinstance(importer, str) and isinstance(imported, str): + yield ModuleName(importer), ModuleName(imported) + + for value in node.values(): + yield from _iter_direct_imports(value, seen) + return + + if isinstance(node, (list, tuple, set, frozenset)): + marker = id(node) + if marker in seen: + return + seen.add(marker) + + for item in node: + yield from _iter_direct_imports(item, seen) + return + + if hasattr(node, "__dict__"): + marker = id(node) + if marker in seen: + return + seen.add(marker) + yield from _iter_direct_imports(vars(node), seen) + + +def _print_failures(failures: list[SnapshotFailure], comparison: ComparisonMode) -> None: + _write_line(f"Import baseline regression detected ({comparison} mode):") + for failure in failures: + _write_line( + f"- [{failure.contract_name}] {failure.importer}: " + f"baseline={failure.baseline_count}, current={failure.current_count}" + ) + if failure.extra_imports: + _write_line(f" new imports: {', '.join(failure.extra_imports)}") + + +def _write_line(message: str) -> None: + sys.stdout.write(f"{message}\n") + + +if __name__ == "__main__": + raise SystemExit(main())