# Files - 150 lines - 6.2 KiB - Python (source-viewer header captured with this file; not module code)

from concurrent.futures import ThreadPoolExecutor
import tempfile
import unittest
from pathlib import Path
def _fetch_stats_row(db_path: Path, origin_source: str, candidate_source: str):
    """Read the stored stats row for one (origin, candidate) source pair.

    A separate connection is opened through
    ``connect_resolver_stats_database`` for the lookup and is always closed
    before this helper returns, whether or not the query succeeds.

    Args:
        db_path: Path of the resolver stats database file to query.
        origin_source: Value matched against the ``origin_source`` column.
        candidate_source: Value matched against the ``candidate_source``
            column.

    Returns:
        The result of ``fetchone()`` for the query, selecting
        ``attempt_count``, ``resolve_success_count``, ``created_at``,
        ``updated_at``, ``last_attempt_at`` and ``last_success_at``.
        Callers in this file index the row by column name.
    """
    from musicdl.catalogsync.resolver_stats import connect_resolver_stats_database

    # Assembled from adjacent literals; the resulting text is the same query
    # string that is handed to the database as a single statement.
    stats_query = (
        "\n"
        "SELECT attempt_count, resolve_success_count,\n"
        "created_at, updated_at, last_attempt_at, last_success_at\n"
        "FROM resolver_source_stats\n"
        "WHERE origin_source = ? AND candidate_source = ?\n"
    )
    query_params = (origin_source, candidate_source)

    connection = connect_resolver_stats_database(db_path)
    try:
        cursor = connection.execute(stats_query, query_params)
        return cursor.fetchone()
    finally:
        connection.close()
class ResolverStatsRepositoryTests(unittest.TestCase):
    """Tests for the resolver stats schema and ``ResolverStatsRepository``.

    Each test creates its database file inside its own temporary directory.
    The ``musicdl`` names are imported inside the individual tests rather
    than at module level.
    """

    def test_initialize_resolver_stats_database_creates_stats_table_with_timestamps(self):
        """A freshly initialised database has the stats table and its timestamp columns."""
        from musicdl.catalogsync.resolver_stats import initialize_resolver_stats_database

        required_columns = {
            "created_at",
            "updated_at",
            "last_attempt_at",
            "last_success_at",
        }

        with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
            db_path = Path(tmpdir) / "resolver_stats.db"
            conn = initialize_resolver_stats_database(db_path)
            try:
                table_names = {
                    row["name"]
                    for row in conn.execute(
                        "SELECT name FROM sqlite_master WHERE type = 'table'"
                    ).fetchall()
                }
                column_names = {
                    row["name"]
                    for row in conn.execute(
                        "PRAGMA table_info(resolver_source_stats)"
                    ).fetchall()
                }
            finally:
                conn.close()

            self.assertIn("resolver_source_stats", table_names)
            # Subset check (set ``<=``); unlike assertTrue, a failure here
            # prints both sets so the missing column is visible.
            self.assertLessEqual(required_columns, column_names)

    def test_record_fallback_result_tracks_attempt_and_success_timestamps(self):
        """A failed then a successful attempt update counters and timestamps."""
        from musicdl.catalogsync.resolver_stats import ResolverStatsRepository

        with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
            db_path = Path(tmpdir) / "resolver_stats.db"
            repo = ResolverStatsRepository(db_path)

            repo.record_fallback_result("qq", "kuwo", succeeded=False)
            first_row = _fetch_stats_row(db_path, "qq", "kuwo")
            # Fail with a clear message instead of a TypeError from
            # subscripting None when nothing was written.
            self.assertIsNotNone(
                first_row, "no stats row stored after the first recorded attempt"
            )

            repo.record_fallback_result("qq", "kuwo", succeeded=True)
            second_row = _fetch_stats_row(db_path, "qq", "kuwo")
            self.assertIsNotNone(
                second_row, "no stats row stored after the second recorded attempt"
            )

            # After the failed attempt: one attempt, no success recorded yet.
            self.assertEqual(1, int(first_row["attempt_count"]))
            self.assertEqual(0, int(first_row["resolve_success_count"]))
            self.assertIsNotNone(first_row["created_at"])
            self.assertIsNotNone(first_row["updated_at"])
            self.assertIsNotNone(first_row["last_attempt_at"])
            self.assertIsNone(first_row["last_success_at"])

            # After the successful attempt: both counters advanced and the
            # success timestamp is populated.
            self.assertEqual(2, int(second_row["attempt_count"]))
            self.assertEqual(1, int(second_row["resolve_success_count"]))
            self.assertIsNotNone(second_row["created_at"])
            self.assertIsNotNone(second_row["updated_at"])
            self.assertIsNotNone(second_row["last_attempt_at"])
            self.assertIsNotNone(second_row["last_success_at"])

    def test_repository_operations_can_run_from_non_creator_thread(self):
        """Recording and ranking work when called from an executor worker thread."""
        from musicdl.catalogsync.resolver_stats import ResolverStatsRepository

        with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
            # The repository is constructed on the test's own thread ...
            repo = ResolverStatsRepository(Path(tmpdir) / "resolver_stats.db")

            # ... and then used from the single worker thread of the pool.
            # With one worker the two submissions run in submission order.
            with ThreadPoolExecutor(max_workers=1) as executor:
                record_future = executor.submit(
                    repo.record_fallback_result,
                    "qq",
                    "kuwo",
                    succeeded=True,
                )
                rank_future = executor.submit(
                    lambda: repo.rank_fallback_sources(
                        "qq",
                        ["kuwo", "migu"],
                        warmup_attempts=0,
                    )
                )

            self.assertIsNone(record_future.exception())
            self.assertEqual(["kuwo", "migu"], rank_future.result())

    def test_rank_fallback_sources_keeps_config_order_before_warmup(self):
        """With one recorded attempt and a 1000-attempt warmup, order is unchanged."""
        from musicdl.catalogsync.resolver_stats import ResolverStatsRepository

        with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
            repo = ResolverStatsRepository(Path(tmpdir) / "resolver_stats.db")
            repo.record_fallback_result("qq", "kuwo", succeeded=True)

            ranked = repo.rank_fallback_sources(
                "qq",
                ["kuwo", "migu", "qianqian"],
                warmup_attempts=1000,
            )

            self.assertEqual(["kuwo", "migu", "qianqian"], ranked)

    def test_rank_fallback_sources_reorders_after_warmup_per_origin_source(self):
        """Once 1000 attempts are recorded for the origin, ranking is reordered."""
        from musicdl.catalogsync.resolver_stats import ResolverStatsRepository

        with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
            repo = ResolverStatsRepository(Path(tmpdir) / "resolver_stats.db")

            # 800 successes for "migu" plus 200 failures for "kuwo" give
            # exactly 1000 recorded attempts for origin "qq".
            for _ in range(800):
                repo.record_fallback_result("qq", "migu", succeeded=True)
            for _ in range(200):
                repo.record_fallback_result("qq", "kuwo", succeeded=False)

            ranked = repo.rank_fallback_sources(
                "qq",
                ["kuwo", "migu", "qianqian"],
                warmup_attempts=1000,
            )

            self.assertEqual(["migu", "qianqian", "kuwo"], ranked)

    def test_rank_fallback_sources_uses_config_order_as_tie_breaker(self):
        """Sources with identical recorded results keep their configured order."""
        from musicdl.catalogsync.resolver_stats import ResolverStatsRepository

        with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
            repo = ResolverStatsRepository(Path(tmpdir) / "resolver_stats.db")

            # Five successes each for "kuwo" and "migu": ten attempts in
            # total, matching the warmup threshold used below.
            for _ in range(5):
                repo.record_fallback_result("qq", "kuwo", succeeded=True)
                repo.record_fallback_result("qq", "migu", succeeded=True)

            ranked = repo.rank_fallback_sources(
                "qq",
                ["kuwo", "migu", "qianqian"],
                warmup_attempts=10,
            )

            self.assertEqual(["kuwo", "migu", "qianqian"], ranked)