Source code for agent_framework_evaluator.usage

"""Evaluator-side usage aggregation from trace events."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any

from agent_framework.model import LlmUsage
from agent_framework.tracing import TraceEvent


[docs] @dataclass(slots=True) class UsageTotals: input_tokens: int = 0 input_cached_tokens: int = 0 output_tokens: int = 0 output_cached_tokens: int = 0 total_tokens: int = 0
[docs] def add(self, usage: LlmUsage | dict[str, int] | None) -> None: if usage is None: return payload = usage.to_dict() if isinstance(usage, LlmUsage) else dict(usage) self.input_tokens += int(payload.get("input_tokens") or 0) self.input_cached_tokens += int(payload.get("input_cached_tokens") or 0) self.output_tokens += int(payload.get("output_tokens") or 0) self.output_cached_tokens += int(payload.get("output_cached_tokens") or 0) self.total_tokens += int(payload.get("total_tokens") or 0)
[docs] def copy(self) -> "UsageTotals": return UsageTotals( input_tokens=self.input_tokens, input_cached_tokens=self.input_cached_tokens, output_tokens=self.output_tokens, output_cached_tokens=self.output_cached_tokens, total_tokens=self.total_tokens, )
[docs] def replace(self, usage: LlmUsage | dict[str, int] | None) -> None: self.input_tokens = 0 self.input_cached_tokens = 0 self.output_tokens = 0 self.output_cached_tokens = 0 self.total_tokens = 0 self.add(usage)
[docs] def to_dict(self) -> dict[str, int]: return { "input_tokens": self.input_tokens, "input_cached_tokens": self.input_cached_tokens, "output_tokens": self.output_tokens, "output_cached_tokens": self.output_cached_tokens, "total_tokens": self.total_tokens, }
@dataclass(slots=True)
class AgentUsageSummary:
    """Usage recorded for a single agent run."""

    agent_id: str
    run_id: str
    # None when the run has no known parent run.
    parent_run_id: str | None
    # Usage attributed to this run by itself.
    self_totals: UsageTotals = field(default_factory=UsageTotals)
    # Usage of this run including its child runs.
    inclusive_totals: UsageTotals = field(default_factory=UsageTotals)
    # One record per LLM call made by this run.
    llm_calls: list[dict[str, Any]] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict view of this summary.

        Each ``llm_calls`` entry is copied, along with any dict values nested
        inside it (such as its ``usage`` dict). Mutating the returned structure
        therefore cannot alter the records held by this summary.
        """
        return {
            "agent_id": self.agent_id,
            "run_id": self.run_id,
            "parent_run_id": self.parent_run_id,
            "self_totals": self.self_totals.to_dict(),
            "inclusive_totals": self.inclusive_totals.to_dict(),
            "llm_calls": [
                {
                    key: dict(value) if isinstance(value, dict) else value
                    for key, value in call.items()
                }
                for call in self.llm_calls
            ],
        }
@dataclass(slots=True)
class SessionUsageSummary:
    """Top-level usage report: session totals plus per-agent and per-run breakdowns."""

    # Totals for the whole session.
    session_totals: UsageTotals = field(default_factory=UsageTotals)
    # Per-agent aggregates, already in plain-dict form.
    agents: dict[str, Any] = field(default_factory=dict)
    # Per-run summaries keyed by run id.
    runs: dict[str, AgentUsageSummary] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize the report into plain dicts (agents are shallow-copied)."""
        report: dict[str, Any] = {}
        report["session_totals"] = self.session_totals.to_dict()
        report["agents"] = dict(self.agents)
        serialized_runs: dict[str, Any] = {}
        for key, run_summary in self.runs.items():
            serialized_runs[key] = run_summary.to_dict()
        report["runs"] = serialized_runs
        return report
class EvaluatorUsageTracker:
    """Consume trace events and expose a stable usage summary snapshot.

    Handled event kinds (any other kind is ignored):

    * ``runtime.audit.agent_call_started``: registers a run with its agent and
      parent run.
    * ``llm.response``: records the call and adds its ``usage`` to the run's
      self totals. Until runtime session totals arrive, it is also added to the
      session totals.
    * ``runtime.audit.agent_call_finished`` / ``runtime.agent_finished``:
      overwrite the run's self / inclusive totals with ``usage_self`` /
      ``usage_inclusive`` when those are dicts.
    * ``runtime.session_finished``: overwrites the session totals with
      ``usage_session_totals``. After that, ``llm.response`` events no longer
      add to the session totals.
    """

    def __init__(self) -> None:
        # Run summaries keyed by run id, in first-seen order.
        self._runs: dict[str, AgentUsageSummary] = {}
        self._session_totals = UsageTotals()
        # True once runtime.session_finished has supplied session totals.
        self._session_totals_from_runtime = False

    def reset(self) -> None:
        """Discard all tracked runs and session totals."""
        self._runs.clear()
        self._session_totals = UsageTotals()
        self._session_totals_from_runtime = False

    def consume_trace_event(self, event: dict[str, Any] | TraceEvent) -> None:
        """Fold a single trace event into the tracked usage.

        Args:
            event: Either a plain dict with ``kind`` / ``context`` / ``payload``
                keys, or a ``TraceEvent``. For a ``TraceEvent``, its ``kind``,
                ``context.run_id``, ``context.agent_id`` and ``payload`` are read.
                Events without a usable run id are ignored.
        """
        kind, payload, context = self._normalize_event(event)
        if kind == "runtime.audit.agent_call_started":
            self._on_agent_call_started(payload, context)
        elif kind == "runtime.audit.agent_call_finished":
            # This event resolves ids from the payload first, then the context.
            self._apply_runtime_run_totals(payload, primary=payload, fallback=context)
        elif kind == "llm.response":
            self._on_llm_response(payload, context)
        elif kind == "runtime.agent_finished":
            # This event resolves ids from the context first, then the payload.
            self._apply_runtime_run_totals(payload, primary=context, fallback=payload)
        elif kind == "runtime.session_finished":
            self._on_session_finished(payload)

    def snapshot(self) -> dict[str, Any]:
        """Return a plain-dict usage report for everything consumed so far.

        The report has ``session_totals``, ``agents`` and ``runs``.
        ``agents`` holds per-agent sums keyed by agent id, using
        ``"(unknown)"`` for runs without one. ``runs`` holds per-run summaries
        keyed by run id.

        A run's inclusive totals are the runtime-reported ``usage_inclusive``
        when its ``total_tokens`` is non-zero. Otherwise they are recomputed
        from the self totals of the run and its descendant runs. Tracker state
        is never modified, so a later snapshot reflects events consumed in
        between.
        """
        children = self._children_index()
        runs: dict[str, AgentUsageSummary] = {}
        for run_id, summary in self._runs.items():
            if summary.inclusive_totals.total_tokens == 0:
                # Recompute into a fresh object instead of writing it back.
                # Storing it would make total_tokens non-zero and freeze the
                # value, hiding usage from later llm.response events.
                inclusive = self._recompute_inclusive(run_id, _children=children)
            else:
                inclusive = summary.inclusive_totals.copy()
            runs[run_id] = AgentUsageSummary(
                agent_id=summary.agent_id,
                run_id=summary.run_id,
                parent_run_id=summary.parent_run_id,
                self_totals=summary.self_totals.copy(),
                inclusive_totals=inclusive,
                llm_calls=list(summary.llm_calls),
            )

        agents: dict[str, dict[str, Any]] = {}
        for summary in runs.values():
            agent_id = summary.agent_id or "(unknown)"
            entry = agents.setdefault(
                agent_id,
                {
                    "agent_id": agent_id,
                    "run_ids": [],
                    "self_totals": UsageTotals(),
                    "inclusive_totals": UsageTotals(),
                },
            )
            entry["run_ids"].append(summary.run_id)
            entry["self_totals"].add(summary.self_totals.to_dict())
            entry["inclusive_totals"].add(summary.inclusive_totals.to_dict())

        agent_payloads = {
            agent_id: {
                "agent_id": data["agent_id"],
                "run_ids": data["run_ids"],
                "self_totals": data["self_totals"].to_dict(),
                "inclusive_totals": data["inclusive_totals"].to_dict(),
            }
            for agent_id, data in agents.items()
        }
        return SessionUsageSummary(
            session_totals=self._session_totals.copy(),
            agents=agent_payloads,
            runs=runs,
        ).to_dict()

    # -- event normalization and handlers ---------------------------------

    @staticmethod
    def _normalize_event(
        event: dict[str, Any] | TraceEvent,
    ) -> tuple[str, dict[str, Any], dict[str, Any]]:
        """Turn a dict or ``TraceEvent`` into ``(kind, payload, context)``.

        Non-dict payload or context values become empty dicts.
        """
        if isinstance(event, dict):
            raw = event
        else:
            raw = {
                "kind": event.kind,
                "context": {
                    "run_id": getattr(event.context, "run_id", None),
                    "agent_id": getattr(event.context, "agent_id", None),
                },
                "payload": event.payload or {},
            }
        kind = str(raw.get("kind") or "")
        payload = raw.get("payload")
        context = raw.get("context")
        return (
            kind,
            payload if isinstance(payload, dict) else {},
            context if isinstance(context, dict) else {},
        )

    @staticmethod
    def _lookup(primary: dict[str, Any], fallback: dict[str, Any], key: str) -> str:
        """Return ``primary[key]``, else ``fallback[key]``, as a str; ``""`` if both are falsy."""
        return str(primary.get(key) or fallback.get(key) or "")

    def _on_agent_call_started(self, payload: dict[str, Any], context: dict[str, Any]) -> None:
        """Register a run together with its agent name and parent run."""
        run_id = self._lookup(payload, context, "run_id")
        if not run_id:
            return
        agent_id = str(payload.get("agent_name") or context.get("agent_id") or "")
        parent_run_id = payload.get("parent_run_id")
        # _ensure_run (rather than setdefault) fills in the agent and parent
        # even when an earlier llm.response already created the run.
        self._ensure_run(
            run_id=run_id,
            agent_id=agent_id,
            parent_run_id=str(parent_run_id) if parent_run_id else None,
        )

    def _on_llm_response(self, payload: dict[str, Any], context: dict[str, Any]) -> None:
        """Record one LLM call and add its usage to the run (and maybe session) totals."""
        run_id = self._lookup(payload, context, "run_id")
        if not run_id:
            return
        summary = self._ensure_run(
            run_id=run_id,
            agent_id=self._lookup(payload, context, "agent_id"),
            parent_run_id=None,
        )
        usage = payload.get("usage")
        usage_dict = dict(usage) if isinstance(usage, dict) else None
        summary.llm_calls.append(
            {
                "provider_name": payload.get("provider_name"),
                "model_name": payload.get("model_name"),
                "usage": usage_dict,
                "raw_usage": payload.get("raw_usage"),
            }
        )
        summary.self_totals.add(usage_dict)
        if not self._session_totals_from_runtime:
            self._session_totals.add(usage_dict)

    def _apply_runtime_run_totals(
        self,
        payload: dict[str, Any],
        *,
        primary: dict[str, Any],
        fallback: dict[str, Any],
    ) -> None:
        """Overwrite a run's totals with runtime-reported ``usage_self`` / ``usage_inclusive``.

        Ids are resolved from *primary* first and then *fallback*.
        """
        run_id = self._lookup(primary, fallback, "run_id")
        if not run_id:
            return
        summary = self._ensure_run(
            run_id=run_id,
            agent_id=self._lookup(primary, fallback, "agent_id"),
            parent_run_id=None,
        )
        usage_self = payload.get("usage_self")
        if isinstance(usage_self, dict):
            summary.self_totals.replace(usage_self)
        usage_inclusive = payload.get("usage_inclusive")
        if isinstance(usage_inclusive, dict):
            summary.inclusive_totals.replace(usage_inclusive)

    def _on_session_finished(self, payload: dict[str, Any]) -> None:
        """Adopt runtime-reported session totals when the event carries them."""
        usage_session_totals = payload.get("usage_session_totals")
        if isinstance(usage_session_totals, dict):
            self._session_totals.replace(usage_session_totals)
            self._session_totals_from_runtime = True

    # -- run-tree helpers --------------------------------------------------

    def _children_index(self) -> dict[str, list[str]]:
        """Map each parent run id to its child run ids, in first-seen order."""
        children: dict[str, list[str]] = {}
        for child_run_id, child in self._runs.items():
            if child.parent_run_id:
                children.setdefault(child.parent_run_id, []).append(child_run_id)
        return children

    def _recompute_inclusive(
        self,
        run_id: str,
        _children: dict[str, list[str]] | None = None,
        _visiting: set[str] | None = None,
    ) -> UsageTotals:
        """Sum the self totals of *run_id* and all of its descendant runs.

        Unknown run ids yield zero totals. A run already on the current path is
        treated as zero. This guards against malformed parent links (e.g. a
        run listed as its own parent), which would otherwise recurse without end.
        """
        summary = self._runs.get(run_id)
        if summary is None:
            return UsageTotals()
        if _children is None:
            _children = self._children_index()
        if _visiting is None:
            _visiting = set()
        if run_id in _visiting:
            return UsageTotals()
        _visiting.add(run_id)
        totals = summary.self_totals.copy()
        for child_run_id in _children.get(run_id, ()):
            totals.add(self._recompute_inclusive(child_run_id, _children, _visiting).to_dict())
        _visiting.discard(run_id)
        return totals

    def _ensure_run(self, *, run_id: str, agent_id: str, parent_run_id: str | None) -> AgentUsageSummary:
        """Return the summary for *run_id*, creating it if needed.

        On an existing summary, ``agent_id`` and ``parent_run_id`` are only
        filled in when currently empty. Known values are never overwritten.
        """
        summary = self._runs.get(run_id)
        if summary is None:
            summary = AgentUsageSummary(agent_id=agent_id, run_id=run_id, parent_run_id=parent_run_id)
            self._runs[run_id] = summary
            return summary
        if agent_id and not summary.agent_id:
            summary.agent_id = agent_id
        if parent_run_id and not summary.parent_run_id:
            summary.parent_run_id = parent_run_id
        return summary
# Names exported by ``from <module> import *`` (kept alphabetical).
__all__ = ["AgentUsageSummary", "EvaluatorUsageTracker", "SessionUsageSummary", "UsageTotals"]