167 lines
5.5 KiB
Python
167 lines
5.5 KiB
Python
from __future__ import annotations
|
|
|
|
import sqlite3
|
|
from contextlib import suppress
|
|
from pathlib import Path
|
|
from typing import Iterable
|
|
|
|
# Busy timeout applied to every connection, in milliseconds.
SQLITE_BUSY_TIMEOUT_MS = 30000
# File name used for the resolver stats database.
RESOLVER_STATS_DB_FILENAME = "resolver_stats.db"

# One row per (origin_source, candidate_source) pair, holding running
# attempt/success counters and bookkeeping timestamps.
_CREATE_STATS_TABLE_SQL = """
    CREATE TABLE IF NOT EXISTS resolver_source_stats (
        origin_source TEXT NOT NULL,
        candidate_source TEXT NOT NULL,
        attempt_count INTEGER NOT NULL DEFAULT 0,
        resolve_success_count INTEGER NOT NULL DEFAULT 0,
        created_at TEXT DEFAULT CURRENT_TIMESTAMP,
        updated_at TEXT DEFAULT CURRENT_TIMESTAMP,
        last_attempt_at TEXT,
        last_success_at TEXT,
        PRIMARY KEY(origin_source, candidate_source)
    )
    """

# Secondary index for lookups that filter on origin_source only.
_CREATE_ORIGIN_INDEX_SQL = """
    CREATE INDEX IF NOT EXISTS idx_resolver_source_stats_origin_source
    ON resolver_source_stats (origin_source)
    """

# Idempotent DDL, executed in order; every statement uses IF NOT EXISTS so it
# is safe to run against a database that already has the schema.
SCHEMA_STATEMENTS = [
    _CREATE_STATS_TABLE_SQL,
    _CREATE_ORIGIN_INDEX_SQL,
]
|
|
|
|
|
|
def default_resolver_stats_db_path(db_path: str | Path) -> Path:
    """Return the resolver stats database path that sits next to *db_path*.

    The result is ``RESOLVER_STATS_DB_FILENAME`` inside the parent directory
    of *db_path*; *db_path* itself is neither opened nor required to exist.
    """
    sibling_dir = Path(db_path).parent
    return sibling_dir.joinpath(RESOLVER_STATS_DB_FILENAME)
|
|
|
|
|
|
def connect_resolver_stats_database(db_path: str | Path) -> sqlite3.Connection:
    """Open a configured SQLite connection to the resolver stats database.

    The parent directory of *db_path* is created if missing. The connection
    uses ``sqlite3.Row`` as its row factory and a busy timeout of
    ``SQLITE_BUSY_TIMEOUT_MS``. WAL journaling and ``synchronous = NORMAL``
    are requested on a best-effort basis: an ``sqlite3.OperationalError``
    from either PRAGMA is ignored.

    Args:
        db_path: Location of the SQLite database file.

    Returns:
        An open connection; the caller is responsible for closing it.

    Raises:
        sqlite3.Error: If configuring the connection fails with anything
            other than the suppressed ``OperationalError`` cases. The
            connection is closed before the error propagates.
        OSError: If the parent directory cannot be created.
    """
    path = Path(db_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(path, timeout=SQLITE_BUSY_TIMEOUT_MS / 1000)
    try:
        conn.row_factory = sqlite3.Row
        conn.execute(f"PRAGMA busy_timeout = {SQLITE_BUSY_TIMEOUT_MS}")
        # Best-effort tuning: these may be refused (e.g. by the underlying
        # filesystem), in which case the connection is still usable.
        for pragma in (
            "PRAGMA journal_mode = WAL",
            "PRAGMA synchronous = NORMAL",
        ):
            with suppress(sqlite3.OperationalError):
                conn.execute(pragma)
    except BaseException:
        # Only OperationalError is suppressed above; any other failure would
        # otherwise leave the freshly opened handle dangling.
        conn.close()
        raise
    return conn
|
|
|
|
|
|
def initialize_resolver_stats_database(db_path: str | Path) -> sqlite3.Connection:
    """Open the resolver stats database and make sure its schema exists.

    Every statement in ``SCHEMA_STATEMENTS`` is executed in order and the
    result is committed.

    Args:
        db_path: Location of the SQLite database file.

    Returns:
        An open connection to the initialized database; the caller is
        responsible for closing it.

    Raises:
        sqlite3.Error: If a schema statement or the commit fails. The
            connection is closed before the error propagates.
    """
    conn = connect_resolver_stats_database(db_path)
    try:
        for statement in SCHEMA_STATEMENTS:
            conn.execute(statement)
        conn.commit()
    except BaseException:
        # The caller never receives the connection on failure, so it has to
        # be released here.
        conn.close()
        raise
    return conn
|
|
|
|
|
|
class ResolverStatsRepository:
|
|
def __init__(self, db_path: str | Path):
|
|
self.db_path = Path(db_path)
|
|
conn = initialize_resolver_stats_database(self.db_path)
|
|
conn.close()
|
|
|
|
def record_fallback_result(
|
|
self,
|
|
origin_source: str,
|
|
candidate_source: str,
|
|
*,
|
|
succeeded: bool,
|
|
) -> None:
|
|
conn = connect_resolver_stats_database(self.db_path)
|
|
try:
|
|
conn.execute(
|
|
"""
|
|
INSERT INTO resolver_source_stats (
|
|
origin_source,
|
|
candidate_source,
|
|
attempt_count,
|
|
resolve_success_count,
|
|
created_at,
|
|
updated_at,
|
|
last_attempt_at,
|
|
last_success_at
|
|
)
|
|
VALUES (
|
|
?, ?, 1, ?,
|
|
CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP,
|
|
CASE WHEN ? THEN CURRENT_TIMESTAMP ELSE NULL END
|
|
)
|
|
ON CONFLICT(origin_source, candidate_source) DO UPDATE SET
|
|
attempt_count = attempt_count + 1,
|
|
resolve_success_count = (
|
|
resolve_success_count + excluded.resolve_success_count
|
|
),
|
|
updated_at = CURRENT_TIMESTAMP,
|
|
last_attempt_at = CURRENT_TIMESTAMP,
|
|
last_success_at = CASE
|
|
WHEN excluded.resolve_success_count > 0
|
|
THEN CURRENT_TIMESTAMP
|
|
ELSE last_success_at
|
|
END
|
|
""",
|
|
(
|
|
origin_source,
|
|
candidate_source,
|
|
int(succeeded),
|
|
int(succeeded),
|
|
),
|
|
)
|
|
conn.commit()
|
|
finally:
|
|
conn.close()
|
|
|
|
def rank_fallback_sources(
|
|
self,
|
|
origin_source: str,
|
|
fallback_sources: Iterable[str],
|
|
*,
|
|
warmup_attempts: int = 1000,
|
|
) -> list[str]:
|
|
sources = list(fallback_sources)
|
|
if len(sources) <= 1:
|
|
return sources
|
|
|
|
conn = connect_resolver_stats_database(self.db_path)
|
|
try:
|
|
if warmup_attempts > 0:
|
|
row = conn.execute(
|
|
"""
|
|
SELECT COALESCE(SUM(attempt_count), 0) AS total_attempt_count
|
|
FROM resolver_source_stats
|
|
WHERE origin_source = ?
|
|
""",
|
|
(origin_source,),
|
|
).fetchone()
|
|
total_attempt_count = int(row["total_attempt_count"] if row else 0)
|
|
if total_attempt_count < warmup_attempts:
|
|
return sources
|
|
|
|
placeholders = ", ".join("?" for _ in sources)
|
|
rows = conn.execute(
|
|
f"""
|
|
SELECT candidate_source, attempt_count, resolve_success_count
|
|
FROM resolver_source_stats
|
|
WHERE origin_source = ? AND candidate_source IN ({placeholders})
|
|
""",
|
|
(origin_source, *sources),
|
|
).fetchall()
|
|
finally:
|
|
conn.close()
|
|
|
|
stats_by_source = {
|
|
str(row["candidate_source"]): (
|
|
int(row["attempt_count"]),
|
|
int(row["resolve_success_count"]),
|
|
)
|
|
for row in rows
|
|
}
|
|
order_index = {source: idx for idx, source in enumerate(sources)}
|
|
|
|
def _sort_key(source: str) -> tuple[float, int]:
|
|
attempts, successes = stats_by_source.get(source, (0, 0))
|
|
smoothed_success_rate = (successes + 1) / (attempts + 2)
|
|
return (-smoothed_success_rate, order_index[source])
|
|
|
|
return sorted(sources, key=_sort_key)
|