"""SQLite-backed success statistics for resolver fallback sources."""

from __future__ import annotations

import sqlite3
from contextlib import closing, suppress
from pathlib import Path
from typing import Iterable

SQLITE_BUSY_TIMEOUT_MS = 30000
RESOLVER_STATS_DB_FILENAME = "resolver_stats.db"

# One row per (origin_source, candidate_source) pair with running counters.
SCHEMA_STATEMENTS = [
    """
    CREATE TABLE IF NOT EXISTS resolver_source_stats (
        origin_source TEXT NOT NULL,
        candidate_source TEXT NOT NULL,
        attempt_count INTEGER NOT NULL DEFAULT 0,
        resolve_success_count INTEGER NOT NULL DEFAULT 0,
        created_at TEXT DEFAULT CURRENT_TIMESTAMP,
        updated_at TEXT DEFAULT CURRENT_TIMESTAMP,
        last_attempt_at TEXT,
        last_success_at TEXT,
        PRIMARY KEY(origin_source, candidate_source)
    )
    """,
    """
    CREATE INDEX IF NOT EXISTS idx_resolver_source_stats_origin_source
    ON resolver_source_stats (origin_source)
    """,
]


def default_resolver_stats_db_path(db_path: str | Path) -> Path:
    """Return the stats database path located in the same directory as *db_path*."""
    return Path(db_path).parent / RESOLVER_STATS_DB_FILENAME


def connect_resolver_stats_database(db_path: str | Path) -> sqlite3.Connection:
    """Open a connection to the stats database at *db_path*.

    Missing parent directories are created. Rows are returned as
    ``sqlite3.Row`` and the busy timeout is ``SQLITE_BUSY_TIMEOUT_MS``.

    Returns:
        An open connection; the caller is responsible for closing it.

    Raises:
        sqlite3.Error: If configuring the connection fails. The connection is
            closed before the error propagates.
    """
    path = Path(db_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(path, timeout=SQLITE_BUSY_TIMEOUT_MS / 1000)
    try:
        conn.row_factory = sqlite3.Row
        conn.execute(f"PRAGMA busy_timeout = {SQLITE_BUSY_TIMEOUT_MS}")
        # Best-effort tuning: an OperationalError from either pragma is ignored.
        with suppress(sqlite3.OperationalError):
            conn.execute("PRAGMA journal_mode = WAL")
        with suppress(sqlite3.OperationalError):
            conn.execute("PRAGMA synchronous = NORMAL")
    except BaseException:
        # Do not leak the handle when setup fails (e.g. the file is not a
        # database); the original error is re-raised unchanged.
        conn.close()
        raise
    return conn


def initialize_resolver_stats_database(db_path: str | Path) -> sqlite3.Connection:
    """Connect to *db_path* and apply every statement in ``SCHEMA_STATEMENTS``.

    Returns:
        An open connection with the schema committed; the caller is
        responsible for closing it.

    Raises:
        sqlite3.Error: If connecting or applying the schema fails. The
            connection is closed before the error propagates.
    """
    conn = connect_resolver_stats_database(db_path)
    try:
        for statement in SCHEMA_STATEMENTS:
            conn.execute(statement)
        conn.commit()
    except BaseException:
        conn.close()
        raise
    return conn


class ResolverStatsRepository:
    """Record and query per-origin success statistics of fallback sources.

    Every operation opens its own short-lived connection to ``db_path``.
    """

    def __init__(self, db_path: str | Path):
        """Store *db_path* and make sure the schema exists there."""
        self.db_path = Path(db_path)
        initialize_resolver_stats_database(self.db_path).close()

    def record_fallback_result(
        self,
        origin_source: str,
        candidate_source: str,
        *,
        succeeded: bool,
    ) -> None:
        """Count one attempt of *candidate_source* as a fallback for *origin_source*.

        Inserts the pair's row on first use; otherwise increments
        ``attempt_count``, adds to ``resolve_success_count`` when *succeeded*
        is true, and refreshes the timestamps. ``last_success_at`` only
        changes on a successful attempt.
        """
        success_flag = int(succeeded)
        with closing(connect_resolver_stats_database(self.db_path)) as conn:
            # In the DO UPDATE clause, bare column names refer to the existing
            # row and ``excluded.*`` to the values this call tried to insert.
            conn.execute(
                """
                INSERT INTO resolver_source_stats (
                    origin_source,
                    candidate_source,
                    attempt_count,
                    resolve_success_count,
                    created_at,
                    updated_at,
                    last_attempt_at,
                    last_success_at
                )
                VALUES (
                    ?, ?, 1, ?,
                    CURRENT_TIMESTAMP,
                    CURRENT_TIMESTAMP,
                    CURRENT_TIMESTAMP,
                    CASE WHEN ? THEN CURRENT_TIMESTAMP ELSE NULL END
                )
                ON CONFLICT(origin_source, candidate_source) DO UPDATE SET
                    attempt_count = attempt_count + 1,
                    resolve_success_count = (
                        resolve_success_count + excluded.resolve_success_count
                    ),
                    updated_at = CURRENT_TIMESTAMP,
                    last_attempt_at = CURRENT_TIMESTAMP,
                    last_success_at = CASE
                        WHEN excluded.resolve_success_count > 0
                        THEN CURRENT_TIMESTAMP
                        ELSE last_success_at
                    END
                """,
                (origin_source, candidate_source, success_flag, success_flag),
            )
            conn.commit()

    def rank_fallback_sources(
        self,
        origin_source: str,
        fallback_sources: Iterable[str],
        *,
        warmup_attempts: int = 1000,
    ) -> list[str]:
        """Order *fallback_sources* by their recorded success rate for *origin_source*.

        Args:
            origin_source: Origin whose statistics are consulted.
            fallback_sources: Candidate sources in their default order.
            warmup_attempts: While the total number of attempts recorded for
                *origin_source* is below this value, the input order is
                returned unchanged. Values ``<= 0`` disable the warm-up check.

        Returns:
            A list with the same elements as *fallback_sources*, highest
            smoothed success rate first. Sources with equal rates (including
            duplicates) keep their relative input order.
        """
        sources = list(fallback_sources)
        if len(sources) <= 1:
            return sources

        with closing(connect_resolver_stats_database(self.db_path)) as conn:
            if warmup_attempts > 0:
                row = conn.execute(
                    """
                    SELECT COALESCE(SUM(attempt_count), 0) AS total_attempt_count
                    FROM resolver_source_stats
                    WHERE origin_source = ?
                    """,
                    (origin_source,),
                ).fetchone()
                total_attempt_count = int(row["total_attempt_count"] if row else 0)
                if total_attempt_count < warmup_attempts:
                    return sources

            # Only the placeholder list is interpolated; values stay bound.
            placeholders = ", ".join("?" for _ in sources)
            rows = conn.execute(
                f"""
                SELECT candidate_source, attempt_count, resolve_success_count
                FROM resolver_source_stats
                WHERE origin_source = ?
                  AND candidate_source IN ({placeholders})
                """,
                (origin_source, *sources),
            ).fetchall()

        stats_by_source = {
            str(row["candidate_source"]): (
                int(row["attempt_count"]),
                int(row["resolve_success_count"]),
            )
            for row in rows
        }

        def _smoothed_success_rate(source: str) -> float:
            # Add-one smoothing: a source without stats scores 1/2.
            attempts, successes = stats_by_source.get(source, (0, 0))
            return (successes + 1) / (attempts + 2)

        # sorted() is stable (also with reverse=True), so ties keep the input
        # order without an explicit index tie-breaker, which mis-ordered
        # duplicated sources.
        return sorted(sources, key=_smoothed_success_rate, reverse=True)