Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 45 additions & 13 deletions src/agents/realtime/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def __init__(
asyncio.Queue()
)
self._event_iterator_waiters = 0
self._closing = False
self._closed = False
self._stored_exception: BaseException | None = None
self._pending_tool_calls: dict[
Expand Down Expand Up @@ -291,8 +292,14 @@ async def update_agent(self, agent: RealtimeAgent) -> None:
)

async def on_event(self, event: RealtimeModelEvent) -> None:
if self._closing or self._closed:
return

await self._put_event(RealtimeRawModelEvent(data=event, info=self._event_info))

if self._closing or self._closed:
return

if event.type == "error":
await self._put_event(RealtimeError(info=self._event_info, error=event.error))
elif event.type == "function_call":
Expand Down Expand Up @@ -466,6 +473,8 @@ async def on_event(self, event: RealtimeModelEvent) -> None:

async def _put_event(self, event: RealtimeSessionEvent) -> None:
"""Put an event into the queue."""
if self._closed:
return
await self._event_queue.put(event)

async def _function_needs_approval(
Expand Down Expand Up @@ -1220,6 +1229,8 @@ async def _run_output_guardrails(self, text: str, response_id: str) -> bool:

def _enqueue_guardrail_task(self, text: str, response_id: str) -> None:
# Runs the guardrails in a separate task to avoid blocking the main loop
if self._closing or self._closed:
return

task = asyncio.create_task(self._run_output_guardrails(text, response_id))
self._guardrail_tasks.add(task)
Expand All @@ -1246,11 +1257,28 @@ def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None:
)
)

def _cleanup_guardrail_tasks(self) -> None:
for task in self._guardrail_tasks:
if not task.done():
task.cancel()
self._guardrail_tasks.clear()
async def _cancel_and_await_tasks(self, tasks: set[asyncio.Task[Any]]) -> None:
current_task = asyncio.current_task()

while tasks:
tasks_to_await: list[asyncio.Task[Any]] = []
for task in list(tasks):
if task is current_task:
tasks.discard(task)
continue
if not task.done():
task.cancel()
tasks_to_await.append(task)

if not tasks_to_await:
return

await asyncio.gather(*tasks_to_await, return_exceptions=True)
for task in tasks_to_await:
tasks.discard(task)

async def _cleanup_guardrail_tasks(self) -> None:
await self._cancel_and_await_tasks(self._guardrail_tasks)

def _enqueue_tool_call_task(
self,
Expand All @@ -1261,6 +1289,11 @@ def _enqueue_tool_call_task(
call_id_reserved: bool = False,
) -> None:
"""Run tool calls in the background to avoid blocking realtime transport."""
if self._closing or self._closed:
if call_id_reserved:
self._finish_tool_call(event.call_id, mark_completed=False)
return

handle_kwargs: dict[str, Any] = {"agent_snapshot": agent_snapshot}
if from_pending_approval:
handle_kwargs["from_pending_approval"] = True
Expand Down Expand Up @@ -1316,11 +1349,8 @@ def _on_tool_call_task_done(self, task: asyncio.Task[Any]) -> None:
)
)

def _cleanup_tool_call_tasks(self) -> None:
for task in self._tool_call_tasks:
if not task.done():
task.cancel()
self._tool_call_tasks.clear()
async def _cleanup_tool_call_tasks(self) -> None:
await self._cancel_and_await_tasks(self._tool_call_tasks)

def _wake_event_iterators(self) -> None:
for _ in range(self._event_iterator_waiters):
Expand All @@ -1332,13 +1362,15 @@ async def _cleanup(self) -> None:
self._wake_event_iterators()
return

# Cancel and cleanup guardrail tasks
self._cleanup_guardrail_tasks()
self._cleanup_tool_call_tasks()
self._closing = True

# Remove ourselves as a listener
self._model.remove_listener(self)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Prevent in-flight events from enqueuing after cleanup

Removing the listener here does not stop model events that already copied this listener before close() began; fresh evidence is OpenAIRealtimeModel._emit_event in src/agents/realtime/openai_realtime.py lines 624-626, which iterates over list(self._listeners) and then awaits listener.on_event(event). When close() races with an already-dispatched function_call and _tool_call_tasks is still empty, _cleanup_tool_call_tasks() returns before that in-flight on_event reaches _enqueue_tool_call_task, so the newly added tool task is never cancelled or awaited and can continue after the session is closed.

Useful? React with 👍 / 👎.


# Cancel and cleanup guardrail tasks
await self._cleanup_guardrail_tasks()
await self._cleanup_tool_call_tasks()

# Close the model connection
await self._model.close()

Expand Down
126 changes: 126 additions & 0 deletions tests/realtime/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,90 @@ async def test_aiter_exits_waiting_iterators_when_session_closes():
task.result()


@pytest.mark.asyncio
async def test_cleanup_awaits_cancelled_background_tasks():
model = _DummyModel()
agent = RealtimeAgent(name="agent")
session = RealtimeSession(model, agent, None)

guardrail_started = asyncio.Event()
guardrail_finished = asyncio.Event()
tool_started = asyncio.Event()
tool_finished = asyncio.Event()

async def guardrail_task():
guardrail_started.set()
try:
await asyncio.Event().wait()
finally:
await asyncio.sleep(0)
guardrail_finished.set()

async def tool_call_task():
tool_started.set()
try:
await asyncio.Event().wait()
finally:
await asyncio.sleep(0)
tool_finished.set()

guardrail = asyncio.create_task(guardrail_task())
tool_call = asyncio.create_task(tool_call_task())
session._guardrail_tasks.add(guardrail)
session._tool_call_tasks.add(tool_call)

await guardrail_started.wait()
await tool_started.wait()

await session._cleanup()

assert guardrail.done()
assert tool_call.done()
assert guardrail_finished.is_set()
assert tool_finished.is_set()
assert session._guardrail_tasks == set()
assert session._tool_call_tasks == set()


@pytest.mark.asyncio
async def test_cleanup_awaits_background_tasks_added_during_cancellation():
model = _DummyModel()
agent = RealtimeAgent(name="agent")
session = RealtimeSession(model, agent, None)

first_started = asyncio.Event()
second_started = asyncio.Event()
second_finished = asyncio.Event()

async def second_task():
second_started.set()
try:
await asyncio.Event().wait()
finally:
await asyncio.sleep(0)
second_finished.set()

async def first_task():
first_started.set()
try:
await asyncio.Event().wait()
finally:
task = asyncio.create_task(second_task())
session._guardrail_tasks.add(task)
await second_started.wait()

first = asyncio.create_task(first_task())
session._guardrail_tasks.add(first)

await first_started.wait()

await session._cleanup()

assert first.done()
assert second_finished.is_set()
assert session._guardrail_tasks == set()


@pytest.mark.asyncio
async def test_transcription_completed_adds_new_user_item():
model = _DummyModel()
Expand Down Expand Up @@ -685,6 +769,48 @@ async def test_function_call_event_runs_async_by_default(self, mock_model, mock_
assert isinstance(raw_event, RealtimeRawModelEvent)
assert raw_event.data == function_call_event

@pytest.mark.asyncio
async def test_cleanup_prevents_in_flight_function_call_from_enqueuing_task(
self, mock_model, mock_agent
):
session = RealtimeSession(mock_model, mock_agent, None)
function_call_event = RealtimeModelToolCallEvent(
name="test_function",
call_id="call_cleanup_race",
arguments="{}",
)

first_put_started = asyncio.Event()
release_first_put = asyncio.Event()
tool_task_started = asyncio.Event()
original_put_event = session._put_event

async def blocked_put_event(event):
first_put_started.set()
await release_first_put.wait()
await original_put_event(event)

async def blocked_handle_tool_call(*_args, **_kwargs):
tool_task_started.set()
await asyncio.Event().wait()

with pytest.MonkeyPatch().context() as m:
handle_tool_call_mock = AsyncMock(side_effect=blocked_handle_tool_call)
m.setattr(session, "_put_event", blocked_put_event)
m.setattr(session, "_handle_tool_call", handle_tool_call_mock)

on_event_task = asyncio.create_task(session.on_event(function_call_event))
await first_put_started.wait()

await session._cleanup()
release_first_put.set()
await on_event_task
await asyncio.sleep(0)

assert session._tool_call_tasks == set()
assert not tool_task_started.is_set()
handle_tool_call_mock.assert_not_awaited()


class TestHistoryManagement:
"""Test suite for history management and audio transcription in
Expand Down
Loading