Source code for agent_framework.drivers.openai

"""OpenAI Responses API driver implementation."""

from __future__ import annotations

import json
from dataclasses import dataclass, field
from typing import Any, ClassVar

from openai import OpenAI

from agent_framework.model import (
    DriverCapabilities,
    ModelContext,
    ModelDriverBase,
    ModelResponse,
    ProviderRequestTrace,
    ProviderResponseTrace,
    _FallbackMixin,
    normalize_openai_usage,
    openai_responses_text_format_field,
    parse_json_object_model_output,
    resolved_response_format_dict,
)


[docs] @dataclass(slots=True) class OpenAiModelDriver(ModelDriverBase, _FallbackMixin): """OpenAI-backed model driver for the Responses API.""" capabilities: ClassVar[DriverCapabilities] = DriverCapabilities( is_async=False, supports_multimodal=False, supports_response_format=True, supports_tools=False, ) api_key: str on_request_trace: Any | None = None on_response_trace: Any | None = None _fallback_state: dict[tuple[str, ...], int] = field(default_factory=dict, repr=False) _client: Any = field(default=None, repr=False)
[docs] def set_trace_callbacks( self, *, on_request: Any | None = None, on_response: Any | None = None, ) -> None: """Attach optional trace callbacks for exact provider I/O logging.""" self.on_request_trace = on_request self.on_response_trace = on_response
def _get_client(self) -> OpenAI: """Return a cached OpenAI client, constructing it lazily on first use.""" if self._client is None: self._client = OpenAI(api_key=self.api_key) return self._client @staticmethod def _raw_response_text(response: Any) -> str: """Return the raw provider response body when the SDK exposes it.""" if hasattr(response, "model_dump_json"): try: return response.model_dump_json(indent=2) except TypeError: return response.model_dump_json() if hasattr(response, "model_dump"): return json.dumps(response.model_dump(), indent=2, ensure_ascii=False) if hasattr(response, "to_dict"): return json.dumps(response.to_dict(), indent=2, ensure_ascii=False) return getattr(response, "output_text", "") or str(response) @staticmethod def _usage_payload(value: Any) -> dict[str, Any] | None: """Return a plain JSON-serializable usage dict from an SDK usage object.""" if value is None: return None if isinstance(value, dict): return dict(value) if hasattr(value, "model_dump"): dumped = value.model_dump() return dumped if isinstance(dumped, dict) else None if hasattr(value, "to_dict"): dumped = value.to_dict() return dumped if isinstance(dumped, dict) else None payload: dict[str, Any] = {} for name in ( "input_tokens", "output_tokens", "total_tokens", "input_tokens_details", "output_tokens_details", ): field_value = getattr(value, name, None) if field_value is None: continue if isinstance(field_value, dict): payload[name] = dict(field_value) elif hasattr(field_value, "model_dump"): dumped = field_value.model_dump() payload[name] = dumped if isinstance(dumped, dict) else field_value elif hasattr(field_value, "to_dict"): dumped = field_value.to_dict() payload[name] = dumped if isinstance(dumped, dict) else field_value else: payload[name] = field_value return payload or None
[docs] def decide( self, *, agent_id: str | None, provider_name: str, model_names: tuple[str, ...], temperature: float, context: ModelContext, ) -> ModelResponse: """Request a structured decision from the OpenAI Responses API.""" if provider_name != "openai": raise ValueError(f"Unsupported provider: {provider_name}") if not self.api_key: raise ValueError("OPENAI_API_KEY is required for OpenAI-backed agents.") if context.exact_input_payload is not None: model_input = context.exact_input_payload elif context.messages: model_input = list(context.messages) else: model_input = [ {"role": "system", "content": context.system_prompt}, {"role": "user", "content": context.user_prompt}, ] def _try_model(model_name: str) -> ModelResponse: client = self._get_client() fmt_dict = resolved_response_format_dict(context) request_trace_payload: Any = model_input if fmt_dict is not None: request_trace_payload = { "input": model_input, "text": {"format": openai_responses_text_format_field(fmt_dict)}, } if callable(self.on_request_trace): self.on_request_trace( ProviderRequestTrace( agent_id=agent_id, provider_name=provider_name, model_name=model_name, input_payload=request_trace_payload, temperature=temperature, run_id=context.run_id, ) ) create_kwargs: dict[str, Any] = { "model": model_name, "temperature": temperature, "input": model_input, } if fmt_dict is not None: create_kwargs["text"] = { "format": openai_responses_text_format_field(fmt_dict), } response = client.responses.create(**create_kwargs) raw_response_text = self._raw_response_text(response) raw_text = response.output_text.strip() raw_usage = self._usage_payload(getattr(response, "usage", None)) usage = normalize_openai_usage(getattr(response, "usage", None)) parsed_payload: dict[str, object] | None = None normalized_text = raw_text if context.response_mode == "text": parsed_payload = {"kind": "final_message", "message": raw_text} else: try: parsed_payload, normalized_text = parse_json_object_model_output( raw_text, provider_label="OpenAI", ) except Exception: if callable(self.on_response_trace): self.on_response_trace( ProviderResponseTrace( agent_id=agent_id, provider_name=provider_name, model_name=model_name, raw_text=raw_response_text, parsed_payload=None, usage=usage, raw_usage=raw_usage, run_id=context.run_id, ) ) raise if callable(self.on_response_trace): self.on_response_trace( ProviderResponseTrace( agent_id=agent_id, provider_name=provider_name, model_name=model_name, raw_text=raw_response_text, parsed_payload=parsed_payload, usage=usage, raw_usage=raw_usage, run_id=context.run_id, ) ) return ModelResponse( payload=parsed_payload, raw_text=normalized_text, usage=usage, raw_usage=raw_usage, ) return self._fallback_decide(model_names, _try_model)
__all__ = ["OpenAiModelDriver"]