Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 22 additions & 10 deletions src/agents/realtime/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1246,11 +1246,26 @@ def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None:
)
)

def _cleanup_guardrail_tasks(self) -> None:
for task in self._guardrail_tasks:
async def _cancel_and_await_tasks(self, tasks: set[asyncio.Task[Any]]) -> None:
if not tasks:
return

current_task = asyncio.current_task()
tasks_to_await: list[asyncio.Task[Any]] = []
for task in list(tasks):
if task is current_task:
continue
if not task.done():
task.cancel()
self._guardrail_tasks.clear()
tasks_to_await.append(task)

if tasks_to_await:
await asyncio.gather(*tasks_to_await, return_exceptions=True)

tasks.clear()

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Don't drop tasks added during cleanup

When cleanup is awaiting cancellation finalizers, the session is still accepting model events because the listener is removed only after these awaits. If a transcript/function_call event or approval path enqueues another task after list(tasks) is taken, tasks.clear() removes that new live task from tracking without cancelling or awaiting it, so it can continue after close() and send events to a model that is being closed. Consider marking the session closing/removing the listener before awaiting, or only discarding the snapshot and looping until no tracked tasks remain.

Useful? React with 👍 / 👎.


async def _cleanup_guardrail_tasks(self) -> None:
await self._cancel_and_await_tasks(self._guardrail_tasks)

def _enqueue_tool_call_task(
self,
Expand Down Expand Up @@ -1316,11 +1331,8 @@ def _on_tool_call_task_done(self, task: asyncio.Task[Any]) -> None:
)
)

def _cleanup_tool_call_tasks(self) -> None:
for task in self._tool_call_tasks:
if not task.done():
task.cancel()
self._tool_call_tasks.clear()
async def _cleanup_tool_call_tasks(self) -> None:
await self._cancel_and_await_tasks(self._tool_call_tasks)

def _wake_event_iterators(self) -> None:
for _ in range(self._event_iterator_waiters):
Expand All @@ -1333,8 +1345,8 @@ async def _cleanup(self) -> None:
return

# Cancel and cleanup guardrail tasks
self._cleanup_guardrail_tasks()
self._cleanup_tool_call_tasks()
await self._cleanup_guardrail_tasks()
await self._cleanup_tool_call_tasks()

# Remove ourselves as a listener
self._model.remove_listener(self)
Expand Down
45 changes: 45 additions & 0 deletions tests/realtime/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,51 @@ async def test_aiter_exits_waiting_iterators_when_session_closes():
task.result()


@pytest.mark.asyncio
async def test_cleanup_awaits_cancelled_background_tasks():
model = _DummyModel()
agent = RealtimeAgent(name="agent")
session = RealtimeSession(model, agent, None)

guardrail_started = asyncio.Event()
guardrail_finished = asyncio.Event()
tool_started = asyncio.Event()
tool_finished = asyncio.Event()

async def guardrail_task():
guardrail_started.set()
try:
await asyncio.Event().wait()
finally:
await asyncio.sleep(0)
guardrail_finished.set()

async def tool_call_task():
tool_started.set()
try:
await asyncio.Event().wait()
finally:
await asyncio.sleep(0)
tool_finished.set()

guardrail = asyncio.create_task(guardrail_task())
tool_call = asyncio.create_task(tool_call_task())
session._guardrail_tasks.add(guardrail)
session._tool_call_tasks.add(tool_call)

await guardrail_started.wait()
await tool_started.wait()

await session._cleanup()

assert guardrail.done()
assert tool_call.done()
assert guardrail_finished.is_set()
assert tool_finished.is_set()
assert session._guardrail_tasks == set()
assert session._tool_call_tasks == set()


@pytest.mark.asyncio
async def test_transcription_completed_adds_new_user_item():
model = _DummyModel()
Expand Down