Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions api/services/workflow_run_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections.abc import Sequence
from typing import TypedDict

from sqlalchemy import Engine
from sqlalchemy import Engine, select
from sqlalchemy.orm import sessionmaker

import contexts
Expand All @@ -12,6 +12,7 @@
Account,
App,
EndUser,
Message,
WorkflowNodeExecutionModel,
WorkflowRun,
WorkflowRunTriggeredFrom,
Expand Down Expand Up @@ -72,9 +73,29 @@ def __getattr__(self, item):

pagination = self.get_paginate_workflow_runs(app_model, args, triggered_from)

# Batch-load the associated Message for every run in a single query to avoid
# an N+1 pattern: the deprecated WorkflowRun.message property issues one query
# per run. The filter matches that property exactly (app_id + workflow_run_id).
workflow_runs = pagination.data
run_ids = [workflow_run.id for workflow_run in workflow_runs]
messages_by_run_id: dict[str, Message] = {}
if run_ids:
messages = db.session.scalars(
select(Message).where(
Message.app_id == app_model.id,
Message.workflow_run_id.in_(run_ids),
)
).all()
for loaded_message in messages:
run_id = loaded_message.workflow_run_id
if run_id is None:
continue
# setdefault mirrors scalar()'s single-row-per-run semantics.
messages_by_run_id.setdefault(run_id, loaded_message)

with_message_workflow_runs = []
for workflow_run in pagination.data:
message = workflow_run.message
for workflow_run in workflow_runs:
message = messages_by_run_id.get(workflow_run.id)
with_message_workflow_run = WorkflowWithMessage(workflow_run=workflow_run)
if message:
with_message_workflow_run.message_id = message.id
Expand Down
45 changes: 39 additions & 6 deletions api/tests/unit_tests/services/test_workflow_run_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ def _end_user(**kwargs: Any) -> EndUser:
return cast(EndUser, SimpleNamespace(**kwargs))


def _fake_session_returning_messages(messages: list[Any]) -> SimpleNamespace:
"""A stand-in db session whose scalars(...).all() returns the given messages."""
scalars_result = MagicMock()
scalars_result.all.return_value = messages
return SimpleNamespace(scalars=MagicMock(return_value=scalars_result))


class TestWorkflowRunServiceInitialization:
def test___init___should_create_sessionmaker_from_db_engine_when_session_factory_missing(
self,
Expand Down Expand Up @@ -120,15 +127,15 @@ def test_get_paginate_advanced_chat_workflow_runs_should_attach_message_fields_w
) -> None:
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
app_model = _app_model(tenant_id="tenant-1", id="app-1")
run_with_message = SimpleNamespace(
id="run-1",
status="running",
message=SimpleNamespace(id="msg-1", conversation_id="conv-1"),
)
run_without_message = SimpleNamespace(id="run-2", status="succeeded", message=None)
run_with_message = SimpleNamespace(id="run-1", status="running")
run_without_message = SimpleNamespace(id="run-2", status="succeeded")
pagination = SimpleNamespace(data=[run_with_message, run_without_message])
monkeypatch.setattr(service, "get_paginate_workflow_runs", MagicMock(return_value=pagination))

message = SimpleNamespace(id="msg-1", conversation_id="conv-1", workflow_run_id="run-1")
fake_session = _fake_session_returning_messages([message])
monkeypatch.setattr(service_module, "db", SimpleNamespace(session=fake_session))

result = service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args={"limit": "2"})

assert result is pagination
Expand All @@ -138,6 +145,32 @@ def test_get_paginate_advanced_chat_workflow_runs_should_attach_message_fields_w
assert result.data[0].status == "running"
assert not hasattr(result.data[1], "message_id")
assert result.data[1].id == "run-2"
# Messages are batch-loaded in a single query, not one per run.
fake_session.scalars.assert_called_once()

def test_get_paginate_advanced_chat_workflow_runs_batch_loads_messages_without_n_plus_one(
self,
repository_factory_mocks: tuple[MagicMock, MagicMock, Any],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Messages must load with a constant query count regardless of run count.

Previously the deprecated WorkflowRun.message property issued one query per
run (N+1); they are now batch-loaded in a single query.
"""
service = WorkflowRunService(session_factory=MagicMock(name="session_factory"))
app_model = _app_model(tenant_id="tenant-1", id="app-1")
runs = [SimpleNamespace(id=f"run-{i}", status="succeeded") for i in range(5)]
pagination = SimpleNamespace(data=runs)
monkeypatch.setattr(service, "get_paginate_workflow_runs", MagicMock(return_value=pagination))

fake_session = _fake_session_returning_messages([])
monkeypatch.setattr(service_module, "db", SimpleNamespace(session=fake_session))

service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args={})

# Exactly one message query for the whole page, independent of run count.
assert fake_session.scalars.call_count == 1

def test_get_workflow_run_should_delegate_to_repository_by_tenant_and_app(
self,
Expand Down
Loading