Initial import: Music_Server, MusicFree, catalog-sync
This commit is contained in:
@@ -0,0 +1,166 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from contextlib import suppress
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
SQLITE_BUSY_TIMEOUT_MS = 30000
|
||||
RESOLVER_STATS_DB_FILENAME = "resolver_stats.db"
|
||||
|
||||
|
||||
SCHEMA_STATEMENTS = [
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS resolver_source_stats (
|
||||
origin_source TEXT NOT NULL,
|
||||
candidate_source TEXT NOT NULL,
|
||||
attempt_count INTEGER NOT NULL DEFAULT 0,
|
||||
resolve_success_count INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TEXT DEFAULT CURRENT_TIMESTAMP,
|
||||
last_attempt_at TEXT,
|
||||
last_success_at TEXT,
|
||||
PRIMARY KEY(origin_source, candidate_source)
|
||||
)
|
||||
""",
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_resolver_source_stats_origin_source
|
||||
ON resolver_source_stats (origin_source)
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
def default_resolver_stats_db_path(db_path: str | Path) -> Path:
|
||||
return Path(db_path).parent / RESOLVER_STATS_DB_FILENAME
|
||||
|
||||
|
||||
def connect_resolver_stats_database(db_path: str | Path) -> sqlite3.Connection:
|
||||
path = Path(db_path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(path, timeout=SQLITE_BUSY_TIMEOUT_MS / 1000)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute(f"PRAGMA busy_timeout = {SQLITE_BUSY_TIMEOUT_MS}")
|
||||
with suppress(sqlite3.OperationalError):
|
||||
conn.execute("PRAGMA journal_mode = WAL")
|
||||
with suppress(sqlite3.OperationalError):
|
||||
conn.execute("PRAGMA synchronous = NORMAL")
|
||||
return conn
|
||||
|
||||
|
||||
def initialize_resolver_stats_database(db_path: str | Path) -> sqlite3.Connection:
|
||||
conn = connect_resolver_stats_database(db_path)
|
||||
for statement in SCHEMA_STATEMENTS:
|
||||
conn.execute(statement)
|
||||
conn.commit()
|
||||
return conn
|
||||
|
||||
|
||||
class ResolverStatsRepository:
|
||||
def __init__(self, db_path: str | Path):
|
||||
self.db_path = Path(db_path)
|
||||
conn = initialize_resolver_stats_database(self.db_path)
|
||||
conn.close()
|
||||
|
||||
def record_fallback_result(
|
||||
self,
|
||||
origin_source: str,
|
||||
candidate_source: str,
|
||||
*,
|
||||
succeeded: bool,
|
||||
) -> None:
|
||||
conn = connect_resolver_stats_database(self.db_path)
|
||||
try:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO resolver_source_stats (
|
||||
origin_source,
|
||||
candidate_source,
|
||||
attempt_count,
|
||||
resolve_success_count,
|
||||
created_at,
|
||||
updated_at,
|
||||
last_attempt_at,
|
||||
last_success_at
|
||||
)
|
||||
VALUES (
|
||||
?, ?, 1, ?,
|
||||
CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP,
|
||||
CASE WHEN ? THEN CURRENT_TIMESTAMP ELSE NULL END
|
||||
)
|
||||
ON CONFLICT(origin_source, candidate_source) DO UPDATE SET
|
||||
attempt_count = attempt_count + 1,
|
||||
resolve_success_count = (
|
||||
resolve_success_count + excluded.resolve_success_count
|
||||
),
|
||||
updated_at = CURRENT_TIMESTAMP,
|
||||
last_attempt_at = CURRENT_TIMESTAMP,
|
||||
last_success_at = CASE
|
||||
WHEN excluded.resolve_success_count > 0
|
||||
THEN CURRENT_TIMESTAMP
|
||||
ELSE last_success_at
|
||||
END
|
||||
""",
|
||||
(
|
||||
origin_source,
|
||||
candidate_source,
|
||||
int(succeeded),
|
||||
int(succeeded),
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def rank_fallback_sources(
|
||||
self,
|
||||
origin_source: str,
|
||||
fallback_sources: Iterable[str],
|
||||
*,
|
||||
warmup_attempts: int = 1000,
|
||||
) -> list[str]:
|
||||
sources = list(fallback_sources)
|
||||
if len(sources) <= 1:
|
||||
return sources
|
||||
|
||||
conn = connect_resolver_stats_database(self.db_path)
|
||||
try:
|
||||
if warmup_attempts > 0:
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT COALESCE(SUM(attempt_count), 0) AS total_attempt_count
|
||||
FROM resolver_source_stats
|
||||
WHERE origin_source = ?
|
||||
""",
|
||||
(origin_source,),
|
||||
).fetchone()
|
||||
total_attempt_count = int(row["total_attempt_count"] if row else 0)
|
||||
if total_attempt_count < warmup_attempts:
|
||||
return sources
|
||||
|
||||
placeholders = ", ".join("?" for _ in sources)
|
||||
rows = conn.execute(
|
||||
f"""
|
||||
SELECT candidate_source, attempt_count, resolve_success_count
|
||||
FROM resolver_source_stats
|
||||
WHERE origin_source = ? AND candidate_source IN ({placeholders})
|
||||
""",
|
||||
(origin_source, *sources),
|
||||
).fetchall()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
stats_by_source = {
|
||||
str(row["candidate_source"]): (
|
||||
int(row["attempt_count"]),
|
||||
int(row["resolve_success_count"]),
|
||||
)
|
||||
for row in rows
|
||||
}
|
||||
order_index = {source: idx for idx, source in enumerate(sources)}
|
||||
|
||||
def _sort_key(source: str) -> tuple[float, int]:
|
||||
attempts, successes = stats_by_source.get(source, (0, 0))
|
||||
smoothed_success_rate = (successes + 1) / (attempts + 2)
|
||||
return (-smoothed_success_rate, order_index[source])
|
||||
|
||||
return sorted(sources, key=_sort_key)
|
||||
Reference in New Issue
Block a user