150 lines
6.2 KiB
Python
150 lines
6.2 KiB
Python
from concurrent.futures import ThreadPoolExecutor
|
|
import tempfile
|
|
import unittest
|
|
from pathlib import Path
|
|
|
|
|
|
def _fetch_stats_row(db_path: Path, origin_source: str, candidate_source: str):
    """Fetch the ``resolver_source_stats`` row for one origin/candidate pair.

    A dedicated connection is opened with ``connect_resolver_stats_database``
    and is closed again whether or not the query succeeds.

    Args:
        db_path: Path of the resolver-stats SQLite database file.
        origin_source: Value matched against the ``origin_source`` column.
        candidate_source: Value matched against the ``candidate_source`` column.

    Returns:
        Whatever ``fetchone()`` yields for the query: the matching row, or
        ``None`` when no row exists for the pair. Callers index the row by
        column name, which presumably relies on the row factory configured
        by ``connect_resolver_stats_database`` — confirm there.
    """
    from musicdl.catalogsync.resolver_stats import connect_resolver_stats_database

    query = """
        SELECT attempt_count, resolve_success_count,
               created_at, updated_at, last_attempt_at, last_success_at
        FROM resolver_source_stats
        WHERE origin_source = ? AND candidate_source = ?
    """
    params = (origin_source, candidate_source)

    connection = connect_resolver_stats_database(db_path)
    try:
        row = connection.execute(query, params).fetchone()
    finally:
        connection.close()
    return row
|
|
|
|
|
|
class ResolverStatsRepositoryTests(unittest.TestCase):
    """Tests for the resolver-stats schema and ``ResolverStatsRepository``."""

    def test_initialize_resolver_stats_database_creates_stats_table_with_timestamps(self):
        """Initialization creates ``resolver_source_stats`` with timestamp columns."""
        from musicdl.catalogsync.resolver_stats import initialize_resolver_stats_database

        with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
            db_path = Path(tmpdir) / "resolver_stats.db"
            conn = initialize_resolver_stats_database(db_path)
            try:
                table_names = {
                    row["name"]
                    for row in conn.execute(
                        "SELECT name FROM sqlite_master WHERE type = 'table'"
                    ).fetchall()
                }
                column_names = {
                    row["name"]
                    for row in conn.execute(
                        "PRAGMA table_info(resolver_source_stats)"
                    ).fetchall()
                }
            finally:
                conn.close()

            self.assertIn("resolver_source_stats", table_names)
            # assertLessEqual on sets is a subset check; unlike assertTrue it
            # prints both sets, so a missing column is visible in the failure.
            self.assertLessEqual(
                {"created_at", "updated_at", "last_attempt_at", "last_success_at"},
                column_names,
            )

    def test_record_fallback_result_tracks_attempt_and_success_timestamps(self):
        """Counters and timestamps update after a failed, then a successful, attempt."""
        from musicdl.catalogsync.resolver_stats import ResolverStatsRepository

        with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
            db_path = Path(tmpdir) / "resolver_stats.db"
            repo = ResolverStatsRepository(db_path)

            repo.record_fallback_result("qq", "kuwo", succeeded=False)
            first_row = _fetch_stats_row(db_path, "qq", "kuwo")

            repo.record_fallback_result("qq", "kuwo", succeeded=True)
            second_row = _fetch_stats_row(db_path, "qq", "kuwo")

            # Guard first: subscripting a missing (None) row would otherwise
            # surface as an opaque TypeError instead of a test failure.
            self.assertIsNotNone(first_row, "no stats row after the failed attempt")
            self.assertIsNotNone(second_row, "no stats row after the successful attempt")

            # After one failed attempt: counted, but no success recorded yet.
            self.assertEqual(1, int(first_row["attempt_count"]))
            self.assertEqual(0, int(first_row["resolve_success_count"]))
            self.assertIsNotNone(first_row["created_at"])
            self.assertIsNotNone(first_row["updated_at"])
            self.assertIsNotNone(first_row["last_attempt_at"])
            self.assertIsNone(first_row["last_success_at"])

            # After a further successful attempt: both counters advance and
            # the success timestamp is populated.
            self.assertEqual(2, int(second_row["attempt_count"]))
            self.assertEqual(1, int(second_row["resolve_success_count"]))
            self.assertIsNotNone(second_row["created_at"])
            self.assertIsNotNone(second_row["updated_at"])
            self.assertIsNotNone(second_row["last_attempt_at"])
            self.assertIsNotNone(second_row["last_success_at"])

    def test_repository_operations_can_run_from_non_creator_thread(self):
        """Recording and ranking work from a thread other than the one that built the repo."""
        from musicdl.catalogsync.resolver_stats import ResolverStatsRepository

        with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
            repo = ResolverStatsRepository(Path(tmpdir) / "resolver_stats.db")
            # The repository is created on the test thread; with a single
            # worker both calls run, in submission order, on the executor's
            # worker thread instead.
            with ThreadPoolExecutor(max_workers=1) as executor:
                record_future = executor.submit(
                    repo.record_fallback_result,
                    "qq",
                    "kuwo",
                    succeeded=True,
                )
                rank_future = executor.submit(
                    lambda: repo.rank_fallback_sources(
                        "qq",
                        ["kuwo", "migu"],
                        warmup_attempts=0,
                    )
                )

            self.assertIsNone(record_future.exception())
            self.assertEqual(["kuwo", "migu"], rank_future.result())

    def test_rank_fallback_sources_keeps_config_order_before_warmup(self):
        """Below the warmup threshold the configured order is returned unchanged."""
        from musicdl.catalogsync.resolver_stats import ResolverStatsRepository

        with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
            repo = ResolverStatsRepository(Path(tmpdir) / "resolver_stats.db")
            repo.record_fallback_result("qq", "kuwo", succeeded=True)
            # One recorded attempt is far below warmup_attempts=1000.
            ranked = repo.rank_fallback_sources(
                "qq",
                ["kuwo", "migu", "qianqian"],
                warmup_attempts=1000,
            )

            self.assertEqual(["kuwo", "migu", "qianqian"], ranked)

    def test_rank_fallback_sources_reorders_after_warmup_per_origin_source(self):
        """Once warmup is reached, candidates for the origin are reordered by their stats."""
        from musicdl.catalogsync.resolver_stats import ResolverStatsRepository

        with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
            repo = ResolverStatsRepository(Path(tmpdir) / "resolver_stats.db")
            # 800 successes for migu plus 200 failures for kuwo makes exactly
            # 1000 attempts for origin "qq", matching warmup_attempts below.
            for _ in range(800):
                repo.record_fallback_result("qq", "migu", succeeded=True)
            for _ in range(200):
                repo.record_fallback_result("qq", "kuwo", succeeded=False)
            ranked = repo.rank_fallback_sources(
                "qq",
                ["kuwo", "migu", "qianqian"],
                warmup_attempts=1000,
            )

            # Expected: the always-successful source first, the unseen source
            # next, and the always-failing source last.
            self.assertEqual(["migu", "qianqian", "kuwo"], ranked)

    def test_rank_fallback_sources_uses_config_order_as_tie_breaker(self):
        """Candidates with identical stats keep their configured relative order."""
        from musicdl.catalogsync.resolver_stats import ResolverStatsRepository

        with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
            repo = ResolverStatsRepository(Path(tmpdir) / "resolver_stats.db")
            # kuwo and migu each get 5/5 successes: 10 attempts total, equal
            # to warmup_attempts, and identical success records.
            for _ in range(5):
                repo.record_fallback_result("qq", "kuwo", succeeded=True)
                repo.record_fallback_result("qq", "migu", succeeded=True)

            ranked = repo.rank_fallback_sources(
                "qq",
                ["kuwo", "migu", "qianqian"],
                warmup_attempts=10,
            )

            self.assertEqual(["kuwo", "migu", "qianqian"], ranked)
|