-
Notifications
You must be signed in to change notification settings - Fork 4.3k
feat(realtime): add input guardrails for RealtimeAgent and RealtimeRunConfig #3721
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
99d9bac
a006387
bcf6d51
058565e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |
| ) | ||
| from ..agent import Agent | ||
| from ..exceptions import ToolInputGuardrailTripwireTriggered, UserError | ||
| from ..guardrail import InputGuardrail, InputGuardrailResult | ||
| from ..handoffs import Handoff | ||
| from ..items import ToolApprovalItem | ||
| from ..logger import logger | ||
|
|
@@ -43,6 +44,7 @@ | |
| RealtimeHistoryAdded, | ||
| RealtimeHistoryUpdated, | ||
| RealtimeInputAudioTimeoutTriggered, | ||
| RealtimeInputGuardrailTripped, | ||
| RealtimeRawModelEvent, | ||
| RealtimeSessionEvent, | ||
| RealtimeToolApprovalRequired, | ||
|
|
@@ -202,6 +204,8 @@ def __init__( | |
|
|
||
| # Guardrails state tracking | ||
| self._interrupted_response_ids: set[str] = set() | ||
| # User item_ids for which an input guardrail has already interrupted the response. | ||
| self._interrupted_input_item_ids: set[str] = set() | ||
| self._item_transcripts: dict[str, str] = {} # item_id -> accumulated transcript | ||
| self._item_guardrail_run_counts: dict[str, int] = {} # item_id -> run count | ||
| self._debounce_text_length = self._run_config.get("guardrails_settings", {}).get( | ||
|
|
@@ -365,6 +369,10 @@ async def on_event(self, event: RealtimeModelEvent) -> None: | |
| await self._put_event( | ||
| RealtimeHistoryUpdated(info=self._event_info, history=self._history) | ||
| ) | ||
| # Run input guardrails on the finalized user transcript. The transcription completes | ||
| # around the time the server begins generating a response, so we mirror the | ||
| # output-guardrail trip behavior and force a response cancel when a guardrail trips. | ||
| self._enqueue_input_guardrail_task(event.transcript, event.item_id) | ||
| elif event.type == "input_audio_timeout_triggered": | ||
| await self._put_event( | ||
| RealtimeInputAudioTimeoutTriggered( | ||
|
|
@@ -1263,6 +1271,94 @@ async def _run_output_guardrails(self, text: str, response_id: str) -> bool: | |
|
|
||
| return False | ||
|
|
||
async def _run_input_guardrails(
    self,
    text: str,
    item_id: str,
    agent: RealtimeAgent,
    input_guardrails: list[InputGuardrail[Any]],
) -> bool:
    """Run input guardrails on the user's transcribed input.

    ``agent`` and ``input_guardrails`` are snapshotted when the transcription event is handled
    so that a concurrent ``update_agent()`` or handoff cannot swap in a different agent's
    guardrails before this background task runs.

    Args:
        text: The finalized user transcript to check.
        item_id: ID of the user item the transcript belongs to. A given item is interrupted
            at most once; later calls for the same item return False.
        agent: The agent snapshot passed to each guardrail.
        input_guardrails: The guardrail snapshot to run against ``text``.

    Returns:
        True if a guardrail tripped and this call interrupted the model, otherwise False
        (no guardrails, no tripwire, or the item was already interrupted).
    """
    # If we've already interrupted the response for this user item, skip.
    if not input_guardrails or item_id in self._interrupted_input_item_ids:
        return False

    async def _run_one(guardrail: InputGuardrail[Any]) -> InputGuardrailResult | None:
        # A failing guardrail is logged and skipped so it cannot block the others.
        try:
            return await guardrail.run(
                # TODO (rm) Remove this cast, it's wrong
                cast(Agent[Any], agent),
                text,
                self._context_wrapper,
            )
        except Exception as exc:
            logger.warning(
                "Input guardrail %r raised %s: %s; skipping it.",
                guardrail.get_name(),
                type(exc).__name__,
                exc,
            )
            logger.debug("Input guardrail failure details.", exc_info=True)
            return None

    # Run the guardrails concurrently and act on the first tripwire as soon as it is available.
    guardrail_tasks = [
        asyncio.create_task(_run_one(guardrail)) for guardrail in input_guardrails
    ]

    def _cancel_pending() -> None:
        # Task.cancel() only *requests* cancellation; it never blocks.
        for task in guardrail_tasks:
            if not task.done():
                task.cancel()

    try:
        tripped: InputGuardrailResult | None = None
        for completed in asyncio.as_completed(guardrail_tasks):
            result = await completed
            if result is not None and result.output.tripwire_triggered:
                tripped = result
                break

        # Ask the remaining guardrails to stop now, but do NOT await their shutdown yet:
        # a sibling that is slow to acknowledge cancellation must not delay the forced
        # response cancel below.
        _cancel_pending()

        if tripped is None:
            return False

        # Double-check: bail if already interrupted for this user item.
        if item_id in self._interrupted_input_item_ids:
            return False

        # Mark as interrupted immediately (before any awaits) to minimize the race window.
        self._interrupted_input_item_ids.add(item_id)
        triggered_results: list[InputGuardrailResult] = [tripped]

        # Emit input guardrail tripped event.
        await self._put_event(
            RealtimeInputGuardrailTripped(
                guardrail_results=triggered_results,
                message=text,
                info=self._event_info,
            )
        )

        # Interrupt the model, forcing a cancel of any in-progress response.
        # NOTE(review): the interrupt carries no response/item target, so a guardrail that
        # finishes after a later turn has started may cancel that later response instead.
        # Per the PR discussion this is accepted for parity with the output-guardrail path.
        await self._model.send_event(RealtimeModelSendInterrupt(force_response_cancel=True))

        # Send guardrail triggered message.
        guardrail_names = [result.guardrail.get_name() for result in triggered_results]
        await self._model.send_event(
            RealtimeModelSendUserInput(
                user_input=f"input guardrail triggered: {', '.join(guardrail_names)}"
            )
        )

        return True
    finally:
        # Reap sibling tasks last so no guardrail task outlives this call, including when
        # this coroutine itself is cancelled or an await above raises.
        _cancel_pending()
        await asyncio.gather(*guardrail_tasks, return_exceptions=True)
|
|
||
| def _enqueue_guardrail_task(self, text: str, response_id: str) -> None: | ||
| # Runs the guardrails in a separate task to avoid blocking the main loop | ||
|
|
||
|
|
@@ -1272,6 +1368,33 @@ def _enqueue_guardrail_task(self, text: str, response_id: str) -> None: | |
| # Add callback to remove completed tasks and handle exceptions | ||
| task.add_done_callback(self._on_guardrail_task_done) | ||
|
|
||
| def _enqueue_input_guardrail_task(self, text: str, item_id: str) -> None: | ||
| # Snapshot the active agent and its guardrails now; a later update_agent()/handoff must not | ||
| # change which guardrails run against this transcript. | ||
| agent = self._current_agent | ||
| combined_guardrails = agent.input_guardrails + self._run_config.get("input_guardrails", []) | ||
|
|
||
| seen_ids: set[int] = set() | ||
| input_guardrails: list[InputGuardrail[Any]] = [] | ||
| for guardrail in combined_guardrails: | ||
| guardrail_id = id(guardrail) | ||
| if guardrail_id not in seen_ids: | ||
| input_guardrails.append(guardrail) | ||
| seen_ids.add(guardrail_id) | ||
|
|
||
| # Skip creating a no-op task when no input guardrails are configured. | ||
| if not input_guardrails: | ||
| return | ||
|
|
||
| # Runs the input guardrails in a separate task to avoid blocking the main loop. | ||
| task = asyncio.create_task( | ||
| self._run_input_guardrails(text, item_id, agent, input_guardrails) | ||
| ) | ||
| # Reuse the shared guardrail task set + done callback so completed tasks are removed, | ||
| # exceptions surface as events, and close() cancels any still-running task. | ||
| self._guardrail_tasks.add(task) | ||
| task.add_done_callback(self._on_guardrail_task_done) | ||
|
|
||
| def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None: | ||
| """Handle completion of a guardrail task.""" | ||
| # Remove from tracking set | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fresh evidence: this version now uses `as_completed`, but this `await asyncio.gather(...)` still runs before the forced cancel at line 1348. When one input guardrail trips quickly and another model-backed guardrail is slow to acknowledge cancellation or does cleanup, the session waits here before sending `response.cancel`, so the unsafe realtime response can continue generating for that latency; request the interrupt and mark the item interrupted before awaiting sibling task cleanup. Useful? React with 👍 / 👎.