Files

167 lines
5.5 KiB
Python

from __future__ import annotations
import sqlite3
from contextlib import suppress
from pathlib import Path
from typing import Iterable
# Busy timeout in milliseconds. connect_resolver_stats_database() passes it to
# sqlite3.connect() converted to seconds and to PRAGMA busy_timeout unchanged.
SQLITE_BUSY_TIMEOUT_MS = 30000
# File name of the resolver stats database; default_resolver_stats_db_path()
# places it in the parent directory of the path it is given.
RESOLVER_STATS_DB_FILENAME = "resolver_stats.db"
# DDL executed in list order by initialize_resolver_stats_database(). Every
# statement uses IF NOT EXISTS, so running them against an already initialized
# database does not fail.
SCHEMA_STATEMENTS = [
# One row per (origin_source, candidate_source) pair, holding running attempt
# and success counters plus bookkeeping timestamps.
"""
CREATE TABLE IF NOT EXISTS resolver_source_stats (
origin_source TEXT NOT NULL,
candidate_source TEXT NOT NULL,
attempt_count INTEGER NOT NULL DEFAULT 0,
resolve_success_count INTEGER NOT NULL DEFAULT 0,
created_at TEXT DEFAULT CURRENT_TIMESTAMP,
updated_at TEXT DEFAULT CURRENT_TIMESTAMP,
last_attempt_at TEXT,
last_success_at TEXT,
PRIMARY KEY(origin_source, candidate_source)
)
""",
# Secondary index for lookups that filter on origin_source alone.
# NOTE(review): origin_source is also the leading column of the composite
# primary key, so this index may be redundant -- confirm before relying on it.
"""
CREATE INDEX IF NOT EXISTS idx_resolver_source_stats_origin_source
ON resolver_source_stats (origin_source)
""",
]
def default_resolver_stats_db_path(db_path: str | Path) -> Path:
    """Return the resolver stats database path that sits beside *db_path*.

    Only the parent directory of *db_path* is used; its file name is replaced
    by ``RESOLVER_STATS_DB_FILENAME``.
    """
    containing_directory = Path(db_path).parent
    return containing_directory / RESOLVER_STATS_DB_FILENAME
def connect_resolver_stats_database(db_path: str | Path) -> sqlite3.Connection:
    """Open a configured SQLite connection to the resolver stats database.

    The parent directory of *db_path* is created when missing. The returned
    connection uses ``sqlite3.Row`` as its row factory and has the busy
    timeout set to ``SQLITE_BUSY_TIMEOUT_MS``. Switching to WAL journaling
    and ``synchronous = NORMAL`` is attempted on a best-effort basis: an
    ``sqlite3.OperationalError`` from either PRAGMA is ignored.

    Args:
        db_path: Location of the SQLite database file.

    Returns:
        An open connection; the caller is responsible for closing it.

    Raises:
        OSError: If the parent directory cannot be created.
        sqlite3.Error: If opening or configuring the connection fails with
            anything other than the suppressed ``OperationalError``. The
            connection is closed before the error propagates.
    """
    path = Path(db_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    # sqlite3.connect() takes its timeout in seconds.
    conn = sqlite3.connect(path, timeout=SQLITE_BUSY_TIMEOUT_MS / 1000)
    try:
        conn.row_factory = sqlite3.Row
        conn.execute(f"PRAGMA busy_timeout = {SQLITE_BUSY_TIMEOUT_MS}")
        with suppress(sqlite3.OperationalError):
            conn.execute("PRAGMA journal_mode = WAL")
        with suppress(sqlite3.OperationalError):
            conn.execute("PRAGMA synchronous = NORMAL")
    except BaseException:
        # The caller never receives the handle on failure, so release it here
        # instead of leaking an open connection.
        conn.close()
        raise
    return conn
def initialize_resolver_stats_database(db_path: str | Path) -> sqlite3.Connection:
    """Open the resolver stats database and make sure its schema exists.

    Every statement in ``SCHEMA_STATEMENTS`` is executed in order and the
    result is committed.

    Args:
        db_path: Location of the SQLite database file.

    Returns:
        The open connection used for initialization; the caller is
        responsible for closing it.

    Raises:
        sqlite3.Error: If a schema statement or the commit fails. The
            connection is closed before the error propagates.
    """
    conn = connect_resolver_stats_database(db_path)
    try:
        for statement in SCHEMA_STATEMENTS:
            conn.execute(statement)
        conn.commit()
    except BaseException:
        # On failure the caller has no handle to close, so do it here.
        conn.close()
        raise
    return conn
class ResolverStatsRepository:
    """Record and query fallback-resolution statistics stored in SQLite.

    Each public method opens its own connection to ``db_path`` and closes it
    before returning.
    """

    def __init__(self, db_path: str | Path):
        """Store *db_path* and make sure the stats schema exists there."""
        self.db_path = Path(db_path)
        # Only the schema side effect is wanted here, so the connection
        # returned by the initializer is released straight away.
        initialize_resolver_stats_database(self.db_path).close()

    def record_fallback_result(
        self,
        origin_source: str,
        candidate_source: str,
        *,
        succeeded: bool,
    ) -> None:
        """Count one fallback attempt for an (origin, candidate) pair.

        A missing row is inserted with an attempt count of 1. An existing row
        has its attempt count incremented, and its success count and
        ``last_success_at`` are advanced only when *succeeded* is true.
        """
        upsert_sql = """
            INSERT INTO resolver_source_stats (
            origin_source,
            candidate_source,
            attempt_count,
            resolve_success_count,
            created_at,
            updated_at,
            last_attempt_at,
            last_success_at
            )
            VALUES (
            ?, ?, 1, ?,
            CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP,
            CASE WHEN ? THEN CURRENT_TIMESTAMP ELSE NULL END
            )
            ON CONFLICT(origin_source, candidate_source) DO UPDATE SET
            attempt_count = attempt_count + 1,
            resolve_success_count = (
            resolve_success_count + excluded.resolve_success_count
            ),
            updated_at = CURRENT_TIMESTAMP,
            last_attempt_at = CURRENT_TIMESTAMP,
            last_success_at = CASE
            WHEN excluded.resolve_success_count > 0
            THEN CURRENT_TIMESTAMP
            ELSE last_success_at
            END
            """
        conn = connect_resolver_stats_database(self.db_path)
        try:
            # The 0/1 flag is bound twice: once as the initial success count
            # and once as the condition of the CASE in the VALUES clause.
            success_flag = int(succeeded)
            bindings = (
                origin_source,
                candidate_source,
                success_flag,
                success_flag,
            )
            conn.execute(upsert_sql, bindings)
            conn.commit()
        finally:
            conn.close()

    @staticmethod
    def _total_attempts(conn: sqlite3.Connection, origin_source: str) -> int:
        """Return the summed attempt count of every row for *origin_source*."""
        row = conn.execute(
            """
            SELECT COALESCE(SUM(attempt_count), 0) AS total_attempt_count
            FROM resolver_source_stats
            WHERE origin_source = ?
            """,
            (origin_source,),
        ).fetchone()
        return int(row["total_attempt_count"] if row else 0)

    @staticmethod
    def _fetch_candidate_stats(
        conn: sqlite3.Connection,
        origin_source: str,
        candidates: list[str],
    ) -> list:
        """Return the stored stat rows of *origin_source* for *candidates*."""
        # One positional placeholder per candidate keeps the IN list bound.
        placeholders = ", ".join(["?"] * len(candidates))
        cursor = conn.execute(
            f"""
            SELECT candidate_source, attempt_count, resolve_success_count
            FROM resolver_source_stats
            WHERE origin_source = ? AND candidate_source IN ({placeholders})
            """,
            (origin_source, *candidates),
        )
        return cursor.fetchall()

    def rank_fallback_sources(
        self,
        origin_source: str,
        fallback_sources: Iterable[str],
        *,
        warmup_attempts: int = 1000,
    ) -> list[str]:
        """Order *fallback_sources* by their recorded success rate.

        The input order is returned unchanged when there are fewer than two
        sources, or when *warmup_attempts* is positive and the total number of
        attempts recorded for *origin_source* is still below it. Otherwise
        sources are sorted by descending add-one smoothed success rate,
        ``(successes + 1) / (attempts + 2)``, so a source without stats scores
        0.5. Ties are broken by position in the input.
        """
        candidates = list(fallback_sources)
        if len(candidates) < 2:
            return candidates

        conn = connect_resolver_stats_database(self.db_path)
        try:
            still_warming_up = (
                warmup_attempts > 0
                and self._total_attempts(conn, origin_source) < warmup_attempts
            )
            if still_warming_up:
                return candidates
            stat_rows = self._fetch_candidate_stats(
                conn, origin_source, candidates
            )
        finally:
            conn.close()

        # candidate name -> (attempts, successes)
        tallies: dict[str, tuple[int, int]] = {}
        for record in stat_rows:
            tallies[str(record["candidate_source"])] = (
                int(record["attempt_count"]),
                int(record["resolve_success_count"]),
            )

        # A name that occurs more than once is keyed by its last position.
        positions: dict[str, int] = {}
        for position, name in enumerate(candidates):
            positions[name] = position

        def _ranking_key(name: str) -> tuple[float, int]:
            attempts, successes = tallies.get(name, (0, 0))
            # Negated so that an ascending sort yields the best rate first.
            return (-((successes + 1) / (attempts + 2)), positions[name])

        return sorted(candidates, key=_ranking_key)