Summary of this patch (free text ahead of the first file header).

Service change: the Message rows for a page of workflow runs are loaded with
one select(Message) query filtered on app_id and workflow_run_id IN (run ids),
instead of reading workflow_run.message once per run. The first message seen
for each run id is kept (dict.setdefault).

Test change: db.session is replaced by a stand-in whose scalars(...).all()
returns canned messages, and the tests assert scalars is called exactly once
per page.

diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py
index 2499e6cc094401..1ff0b59514f682 100644
--- a/api/services/workflow_run_service.py
+++ b/api/services/workflow_run_service.py
@@ -2,7 +2,7 @@
 from collections.abc import Sequence
 from typing import TypedDict
 
-from sqlalchemy import Engine
+from sqlalchemy import Engine, select
 from sqlalchemy.orm import sessionmaker
 
 import contexts
@@ -12,6 +12,7 @@
     Account,
     App,
     EndUser,
+    Message,
     WorkflowNodeExecutionModel,
     WorkflowRun,
     WorkflowRunTriggeredFrom,
@@ -72,9 +73,29 @@ def __getattr__(self, item):
 
         pagination = self.get_paginate_workflow_runs(app_model, args, triggered_from)
 
+        # Batch-load the associated Message for every run in a single query to avoid
+        # an N+1 pattern: the deprecated WorkflowRun.message property issues one query
+        # per run. The filter matches that property exactly (app_id + workflow_run_id).
+        workflow_runs = pagination.data
+        run_ids = [workflow_run.id for workflow_run in workflow_runs]
+        messages_by_run_id: dict[str, Message] = {}
+        if run_ids:
+            messages = db.session.scalars(
+                select(Message).where(
+                    Message.app_id == app_model.id,
+                    Message.workflow_run_id.in_(run_ids),
+                )
+            ).all()
+            for loaded_message in messages:
+                run_id = loaded_message.workflow_run_id
+                if run_id is None:
+                    continue
+                # setdefault mirrors scalar()'s single-row-per-run semantics.
+                messages_by_run_id.setdefault(run_id, loaded_message)
+
         with_message_workflow_runs = []
-        for workflow_run in pagination.data:
-            message = workflow_run.message
+        for workflow_run in workflow_runs:
+            message = messages_by_run_id.get(workflow_run.id)
             with_message_workflow_run = WorkflowWithMessage(workflow_run=workflow_run)
             if message:
                 with_message_workflow_run.message_id = message.id
diff --git a/api/tests/unit_tests/services/test_workflow_run_service.py b/api/tests/unit_tests/services/test_workflow_run_service.py
index 03471389a6597d..2c69a742f15abe 100644
--- a/api/tests/unit_tests/services/test_workflow_run_service.py
+++ b/api/tests/unit_tests/services/test_workflow_run_service.py
@@ -34,6 +34,13 @@ def _end_user(**kwargs: Any) -> EndUser:
     return cast(EndUser, SimpleNamespace(**kwargs))
 
 
+def _fake_session_returning_messages(messages: list[Any]) -> SimpleNamespace:
+    """A stand-in db session whose scalars(...).all() returns the given messages."""
+    scalars_result = MagicMock()
+    scalars_result.all.return_value = messages
+    return SimpleNamespace(scalars=MagicMock(return_value=scalars_result))
+
+
 class TestWorkflowRunServiceInitialization:
     def test___init___should_create_sessionmaker_from_db_engine_when_session_factory_missing(
         self,
@@ -120,15 +127,15 @@ def test_get_paginate_advanced_chat_workflow_runs_should_attach_message_fields_w
     ) -> None:
         service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
         app_model = _app_model(tenant_id="tenant-1", id="app-1")
-        run_with_message = SimpleNamespace(
-            id="run-1",
-            status="running",
-            message=SimpleNamespace(id="msg-1", conversation_id="conv-1"),
-        )
-        run_without_message = SimpleNamespace(id="run-2", status="succeeded", message=None)
+        run_with_message = SimpleNamespace(id="run-1", status="running")
+        run_without_message = SimpleNamespace(id="run-2", status="succeeded")
         pagination = SimpleNamespace(data=[run_with_message, run_without_message])
         monkeypatch.setattr(service, "get_paginate_workflow_runs", MagicMock(return_value=pagination))
 
+        message = SimpleNamespace(id="msg-1", conversation_id="conv-1", workflow_run_id="run-1")
+        fake_session = _fake_session_returning_messages([message])
+        monkeypatch.setattr(service_module, "db", SimpleNamespace(session=fake_session))
+
         result = service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args={"limit": "2"})
 
         assert result is pagination
@@ -138,6 +145,32 @@ def test_get_paginate_advanced_chat_workflow_runs_should_attach_message_fields_w
         assert result.data[0].status == "running"
         assert not hasattr(result.data[1], "message_id")
         assert result.data[1].id == "run-2"
+        # Messages are batch-loaded in a single query, not one per run.
+        fake_session.scalars.assert_called_once()
+
+    def test_get_paginate_advanced_chat_workflow_runs_batch_loads_messages_without_n_plus_one(
+        self,
+        repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        """Messages must load with a constant query count regardless of run count.
+
+        Previously the deprecated WorkflowRun.message property issued one query per
+        run (N+1); they are now batch-loaded in a single query.
+        """
+        service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
+        app_model = _app_model(tenant_id="tenant-1", id="app-1")
+        runs = [SimpleNamespace(id=f"run-{i}", status="succeeded") for i in range(5)]
+        pagination = SimpleNamespace(data=runs)
+        monkeypatch.setattr(service, "get_paginate_workflow_runs", MagicMock(return_value=pagination))
+
+        fake_session = _fake_session_returning_messages([])
+        monkeypatch.setattr(service_module, "db", SimpleNamespace(session=fake_session))
+
+        service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args={})
+
+        # Exactly one message query for the whole page, independent of run count.
+        assert fake_session.scalars.call_count == 1
 
     def test_get_workflow_run_should_delegate_to_repository_by_tenant_and_app(
         self,