"""Resolve a downloadable ``SongInfo`` for a track across several music sources.

``MultiSourceSongResolver`` first tries the track's own (preferred) source,
then searches the remaining configured sources in an order that can be
re-ranked by an optional stats repository.
"""

from __future__ import annotations

import copy
import math
import re
from typing import Any, Callable

from .models import normalize_source_name, parse_size_to_bytes

# Normalized source key -> client name stored in ``SongInfo.source``.
SOURCE_CLIENT_NAMES = {
    "netease": "NeteaseMusicClient",
    "qq": "QQMusicClient",
    "kuwo": "KuwoMusicClient",
    "migu": "MiguMusicClient",
    "qianqian": "QianqianMusicClient",
    "kugou": "KugouMusicClient",
}
# Sources tried, in this order, when the caller passes no ``download_sources``.
DEFAULT_DOWNLOAD_SOURCES = ["qq", "kuwo", "migu", "qianqian", "kugou", "netease"]
# Forwarded as ``warmup_attempts`` to the stats repo's ``rank_fallback_sources``.
DEFAULT_FALLBACK_RANK_WARMUP_ATTEMPTS = 1000
LOSSLESS_EXTENSIONS = {"flac", "wav", "alac", "ape", "wv", "tta", "dsf", "dff"}
# NOTE(review): not referenced by the code in this module; presumably imported elsewhere.
ARTIST_SEPARATOR_RE = re.compile(r"\s*(?:/|,|&|\|)\s*")

# Values treated as "missing" when merging metadata. Tuples rather than sets:
# tuple membership compares with ``==`` and never hashes the tested value, so an
# unhashable attribute value (e.g. a list) cannot raise TypeError here.
_MISSING_CURRENT_VALUES = (None, "", "NULL", "-:-:-")
_MISSING_FALLBACK_VALUES = (None, "", "NULL")


def normalize_audio_ext(value: str | None) -> str:
    """Return *value* as a bare lowercase extension (``" .FLAC"`` -> ``"flac"``)."""
    return str(value or "").strip().lower().lstrip(".")


def normalize_keyword(value: str | None) -> str:
    """Lowercase *value* and collapse every whitespace run into a single space."""
    return " ".join(str(value or "").strip().lower().split())


def normalize_artist_keyword(value: str | None) -> str:
    """Normalize an artist string, treating common separators as whitespace.

    ``"A & B"``, ``"A/B"`` and ``"a, b"`` all normalize to ``"a b"``.
    """
    normalized = normalize_keyword(value)
    for token in ("&", "/", "\\", ",", "|", ";"):
        normalized = normalized.replace(token, " ")
    return " ".join(normalized.split())


def dedupe_preserve_order(values: list[str]) -> list[str]:
    """Normalize each source name and drop duplicates, keeping first occurrences."""
    seen: set[str] = set()
    result: list[str] = []
    for value in values:
        normalized = normalize_source_name(value)
        if normalized in seen:
            continue
        seen.add(normalized)
        result.append(normalized)
    return result


def candidate_file_size_bytes(song_info: Any) -> int:
    """Return the candidate's size in bytes, or 0 when it is unknown.

    A positive numeric ``file_size_bytes`` wins; otherwise the ``file_size``
    attribute is parsed with ``parse_size_to_bytes``.
    """
    size_bytes = getattr(song_info, "file_size_bytes", None)
    # int(float("inf")) raises OverflowError, and this value is used as a sort
    # key, so non-finite floats are treated as "unknown" instead.
    if isinstance(size_bytes, float) and not math.isfinite(size_bytes):
        size_bytes = None
    if isinstance(size_bytes, (int, float)) and size_bytes > 0:
        return int(size_bytes)
    return int(parse_size_to_bytes(getattr(song_info, "file_size", None)) or 0)


def search_result_quality_group(song_info: Any) -> int:
    """Bucket a candidate by audio format: 0 lossless, 1 mp3, 2 other/unknown.

    Extensions are checked in order: ``ext``, ``codec``, then
    ``download_url_status["probe_status"]["ext"]``; the first one that is
    lossless or mp3 decides the group.
    """
    ext_candidates = [
        normalize_audio_ext(getattr(song_info, "ext", None)),
        normalize_audio_ext(getattr(song_info, "codec", None)),
    ]
    download_url_status = getattr(song_info, "download_url_status", None)
    if isinstance(download_url_status, dict):
        probe_status = download_url_status.get("probe_status")
        # Guard the shape: calling .get on a non-dict probe_status would raise
        # and abort the whole candidate sort.
        if isinstance(probe_status, dict):
            ext_candidates.append(normalize_audio_ext(probe_status.get("ext")))
    for ext in ext_candidates:
        if not ext:
            continue
        if ext in LOSSLESS_EXTENSIONS:
            return 0
        if ext == "mp3":
            return 1
    return 2


def song_info_match_priority(candidate_song_info: Any, target_song_info: Any) -> int:
    """Score how well a candidate matches the target (lower is better).

    Returns:
        0 when both have the same normalized source and the same non-empty identifier;
        1 when normalized song name and normalized singers are both equal and non-empty;
        2 when only the normalized song name is equal and non-empty;
        99 when none of the above holds.
    """
    candidate_source = normalize_source_name(getattr(candidate_song_info, "source", None))
    target_source = normalize_source_name(getattr(target_song_info, "source", None))
    candidate_identifier = str(getattr(candidate_song_info, "identifier", "") or "").strip()
    target_identifier = str(getattr(target_song_info, "identifier", "") or "").strip()
    candidate_song_name = normalize_keyword(getattr(candidate_song_info, "song_name", None))
    target_song_name = normalize_keyword(getattr(target_song_info, "song_name", None))
    candidate_singers = normalize_artist_keyword(getattr(candidate_song_info, "singers", None))
    target_singers = normalize_artist_keyword(getattr(target_song_info, "singers", None))

    if (
        candidate_source == target_source
        and candidate_identifier
        and target_identifier
        and candidate_identifier == target_identifier
    ):
        return 0
    same_name = bool(candidate_song_name and target_song_name and candidate_song_name == target_song_name)
    if same_name and candidate_singers and target_singers and candidate_singers == target_singers:
        return 1
    if same_name:
        return 2
    return 99


def match_priority_group(match_priority: int) -> int:
    """Collapse a match priority into 0 (priority <= 1), 1 (2..98) or 99 (no match)."""
    if match_priority >= 99:
        return 99
    if match_priority <= 1:
        return 0
    return 1


def is_high_confidence_match(match_priority: int) -> bool:
    """Return True for identifier matches and name+singer matches (priority <= 1)."""
    return match_priority_group(match_priority) == 0


def build_resolve_keyword(song_info: Any, row: dict[str, Any]) -> str:
    """Build a search keyword from the song name and singers.

    Combines ``song_info.song_name``, ``row["name"]``, ``song_info.singers`` and
    ``row["singers"]`` (skipping empty, ``"NULL"`` and exact-duplicate values).
    If nothing remains, falls back to the song identifier or
    ``row["remote_song_id"]``; returns ``""`` when those are missing too.
    """
    keyword_parts: list[str] = []
    for value in (
        getattr(song_info, "song_name", None),
        row.get("name"),
        getattr(song_info, "singers", None),
        row.get("singers"),
    ):
        text = str(value or "").strip()
        if text and text.upper() != "NULL" and text not in keyword_parts:
            keyword_parts.append(text)
    if keyword_parts:
        return " ".join(keyword_parts)
    return str(getattr(song_info, "identifier", None) or row.get("remote_song_id") or "").strip()


def merge_resolved_song_info(base_song_info: Any, resolved_song_info: Any) -> Any:
    """Merge a freshly resolved song into a copy of it, backfilling from the base.

    If ``resolved_song_info`` is falsy or has no valid download URL, a deep copy
    of ``base_song_info`` is returned. Otherwise a deep copy of the resolved song
    is returned with the base's ``work_dir``, its ``raw_data["search"]`` (when
    the resolved song lacks one) and ``raw_data["deferred_search"] = False``;
    missing metadata fields are filled from the base. Neither input is mutated.
    """
    if not resolved_song_info or not getattr(resolved_song_info, "with_valid_download_url", False):
        return copy.deepcopy(base_song_info)

    merged_song_info = copy.deepcopy(resolved_song_info)
    merged_song_info.work_dir = getattr(base_song_info, "work_dir", getattr(merged_song_info, "work_dir", None))

    if not isinstance(getattr(merged_song_info, "raw_data", None), dict):
        merged_song_info.raw_data = {}
    base_raw_data = getattr(base_song_info, "raw_data", None)
    if isinstance(base_raw_data, dict) and "search" in base_raw_data and "search" not in merged_song_info.raw_data:
        merged_song_info.raw_data["search"] = copy.deepcopy(base_raw_data["search"])
    merged_song_info.raw_data["deferred_search"] = False

    if not getattr(merged_song_info, "source", None):
        merged_song_info.source = getattr(base_song_info, "source", None)
    if not getattr(merged_song_info, "root_source", None):
        merged_song_info.root_source = getattr(base_song_info, "root_source", None)

    # "-:-:-" is only treated as missing on the resolved side; it is never
    # copied over, because the fallback check below excludes only None/""/"NULL".
    for attr in ("song_name", "singers", "album", "duration_s", "duration", "cover_url"):
        current_value = getattr(merged_song_info, attr, None)
        fallback_value = getattr(base_song_info, attr, None)
        if current_value in _MISSING_CURRENT_VALUES and fallback_value not in _MISSING_FALLBACK_VALUES:
            setattr(merged_song_info, attr, fallback_value)

    if not getattr(merged_song_info, "ext", None):
        merged_song_info.ext = getattr(base_song_info, "ext", None)
    if not getattr(merged_song_info, "file_size_bytes", None):
        merged_song_info.file_size_bytes = getattr(base_song_info, "file_size_bytes", None)
    if not getattr(merged_song_info, "file_size", None):
        merged_song_info.file_size = getattr(base_song_info, "file_size", None)
    return merged_song_info


def _candidate_sort_key(item: tuple[Any, int, int]) -> tuple[int, int, int, int, int]:
    """Sort key for ``(song_info, match_priority, source_rank)`` tuples.

    Orders by match-confidence group, then audio quality group, then larger
    file size first, then earlier source rank, then the exact match priority.
    """
    song_info, match_priority, source_rank = item
    return (
        match_priority_group(match_priority),
        search_result_quality_group(song_info),
        -candidate_file_size_bytes(song_info),
        source_rank,
        match_priority,
    )


class MultiSourceSongResolver:
    """Find a ``SongInfo`` with a valid download URL, falling back across sources.

    Args:
        client_factory: Builds a client for a normalized source name. The client
            must provide ``search(...)`` and may provide the parsers
            ``_parsewiththirdpartapis`` / ``_parsewithofficialapiv1``.
        request_overrides_factory: Maps a timeout tuple such as ``(10, 30)`` to
            the ``request_overrides`` dict passed to clients; defaults to
            ``{"timeout": timeout}``.
        resolver_stats_repo: Optional object exposing ``rank_fallback_sources``
            and ``record_fallback_result``; any exception it raises is ignored.
        warmup_attempts: Forwarded to ``rank_fallback_sources``; clamped to >= 0.
    """

    def __init__(
        self,
        client_factory: Callable[[str], object],
        request_overrides_factory: Callable[[tuple[int, int]], dict] | None = None,
        resolver_stats_repo: Any | None = None,
        warmup_attempts: int = DEFAULT_FALLBACK_RANK_WARMUP_ATTEMPTS,
    ) -> None:
        self.client_factory = client_factory
        self.request_overrides_factory = request_overrides_factory or (lambda timeout: {"timeout": timeout})
        self.resolver_stats_repo = resolver_stats_repo
        self.warmup_attempts = max(0, int(warmup_attempts))

    @staticmethod
    def _has_valid_download_url(song_info: Any) -> bool:
        """Return the truthiness of ``song_info.with_valid_download_url`` (False if absent)."""
        return bool(getattr(song_info, "with_valid_download_url", False))

    def _request_overrides(self, timeout: tuple[int, int]) -> dict:
        """Return a fresh dict copy of the overrides for *timeout*."""
        return dict(self.request_overrides_factory(timeout))

    @staticmethod
    def _emit_progress(progress_callback: Callable[[str], None] | None, message: str) -> None:
        """Send *message* (as str) to *progress_callback* when one is given."""
        if progress_callback is None:
            return
        progress_callback(str(message))

    @staticmethod
    def _parse_with_third_party(client: object, search_result: dict, request_overrides: dict) -> Any:
        """Call ``client._parsewiththirdpartapis`` if present; None if absent or it raises."""
        parse = getattr(client, "_parsewiththirdpartapis", None)
        if parse is None:
            return None
        try:
            return parse(search_result=search_result, request_overrides=request_overrides)
        except Exception:
            return None

    @staticmethod
    def _parse_with_official_v1(
        client: object,
        search_result: dict,
        request_overrides: dict,
        third_party_song: Any,
    ) -> Any:
        """Call ``client._parsewithofficialapiv1`` if present; None if absent or it fails.

        A non-None ``third_party_song`` is passed as ``song_info_flac``. If that
        call raises TypeError (e.g. the client does not accept the keyword), the
        call is retried once without it. Without the extra keyword there is
        nothing different to retry, so the TypeError is treated as a failure.
        """
        parse = getattr(client, "_parsewithofficialapiv1", None)
        if parse is None:
            return None
        kwargs: dict[str, Any] = {"search_result": search_result, "request_overrides": request_overrides}
        if third_party_song is not None:
            kwargs["song_info_flac"] = third_party_song
        try:
            return parse(**kwargs)
        except TypeError:
            if "song_info_flac" not in kwargs:
                return None
        except Exception:
            return None
        try:
            return parse(search_result=search_result, request_overrides=request_overrides)
        except Exception:
            return None

    def _refresh_song_info(self, client: object, song_info: Any) -> Any:
        """Re-parse *song_info* through *client* to obtain a valid download URL.

        Returns a deep copy of *song_info* unchanged when it already has a valid
        URL, carries no ``raw_data["search"]`` dict, or neither parser produced
        a valid URL. Otherwise returns the official-API result (preferred) or the
        third-party result, merged via ``merge_resolved_song_info``.
        """
        if self._has_valid_download_url(song_info):
            return copy.deepcopy(song_info)
        raw_data = getattr(song_info, "raw_data", None)
        search_result = raw_data.get("search") if isinstance(raw_data, dict) else None
        if not isinstance(search_result, dict):
            return copy.deepcopy(song_info)

        request_overrides = self._request_overrides((10, 30))
        third_party_song = self._parse_with_third_party(client, search_result, request_overrides)
        refreshed_song = self._parse_with_official_v1(client, search_result, request_overrides, third_party_song)

        for candidate in (refreshed_song, third_party_song):
            if self._has_valid_download_url(candidate):
                return merge_resolved_song_info(song_info, candidate)
        return copy.deepcopy(song_info)

    def _search_source_candidates(self, source: str, keyword: str) -> list[Any]:
        """Search *source* for *keyword*; empty list for an empty keyword or any error."""
        if not keyword:
            return []
        try:
            client = self.client_factory(source)
            results = client.search(
                keyword=keyword,
                num_threadings=1,
                request_overrides=self._request_overrides((10, 30)),
                rule={},
            )
        except Exception:
            return []
        return list(results or [])

    def _pick_best_candidate(self, candidates: list[Any], target_song_info: Any, source_rank: int) -> Any:
        """Return the best candidate with a valid URL that matches the target, else None."""
        matched_candidates: list[tuple[Any, int, int]] = []
        for candidate in candidates:
            if not self._has_valid_download_url(candidate):
                continue
            match_priority = song_info_match_priority(candidate, target_song_info)
            if match_priority >= 99:
                continue
            matched_candidates.append((candidate, match_priority, source_rank))
        if not matched_candidates:
            return None
        # min() returns the first minimal item, matching "sort (stable) then take [0]".
        return min(matched_candidates, key=_candidate_sort_key)[0]

    def _build_target_song_info(self, row: dict[str, Any], snapshot_song_info: Any) -> Any:
        """Return a deep copy of the snapshot, or build a ``SongInfo`` from *row*."""
        if snapshot_song_info is not None:
            return copy.deepcopy(snapshot_song_info)
        from musicdl.modules.utils.data import SongInfo

        return SongInfo(
            source=SOURCE_CLIENT_NAMES.get(normalize_source_name(row.get("platform"))),
            identifier=str(row.get("remote_song_id") or row.get("id") or ""),
            song_name=row.get("name"),
            singers=row.get("singers"),
            album=row.get("album"),
            ext=row.get("ext"),
            file_size_bytes=row.get("file_size_bytes"),
            raw_data={},
        )

    def _rank_fallback_sources(self, origin_source: str, fallback_sources: list[str]) -> list[str]:
        """Order *fallback_sources* using the stats repo, when one is configured.

        The repo's ranking is filtered to the requested sources; any requested
        source it omits is appended in its original order. Repo errors fall back
        to the normalized, deduplicated input order.
        """
        ordered_sources = dedupe_preserve_order(list(fallback_sources))
        if len(ordered_sources) <= 1 or self.resolver_stats_repo is None:
            return ordered_sources
        try:
            ranked_sources = self.resolver_stats_repo.rank_fallback_sources(
                origin_source,
                ordered_sources,
                warmup_attempts=self.warmup_attempts,
            )
        except Exception:
            return ordered_sources
        ranked_ordered_sources = dedupe_preserve_order(list(ranked_sources or []))
        filtered_ranked_sources = [source for source in ranked_ordered_sources if source in ordered_sources]
        for source in ordered_sources:
            if source not in filtered_ranked_sources:
                filtered_ranked_sources.append(source)
        return filtered_ranked_sources

    def _record_fallback_result(self, origin_source: str, candidate_source: str, *, succeeded: bool) -> None:
        """Report a fallback attempt to the stats repo; errors are ignored (best effort)."""
        if self.resolver_stats_repo is None:
            return
        try:
            self.resolver_stats_repo.record_fallback_result(
                origin_source,
                candidate_source,
                succeeded=succeeded,
            )
        except Exception:
            return

    def _append_match(self, target_song_info: Any, song: Any, candidate_rows: list[tuple[Any, int, int]]) -> Any:
        """Merge *song* into the target, record it at source rank 0, and return it if high-confidence."""
        merged = merge_resolved_song_info(target_song_info, song)
        match_priority = song_info_match_priority(merged, target_song_info)
        candidate_rows.append((merged, match_priority, 0))
        return merged if is_high_confidence_match(match_priority) else None

    def _resolve_from_preferred_source(
        self,
        preferred_source: str,
        target_song_info: Any,
        keyword: str,
        candidate_rows: list[tuple[Any, int, int]],
    ) -> Any:
        """Try the origin source: refresh the snapshot, then search by keyword.

        Every match found is appended to *candidate_rows* with source rank 0.
        Returns the first high-confidence match, or None. The refresh and the
        search are guarded separately, so a failing refresh no longer skips the
        keyword search on the same source.
        """
        try:
            client = self.client_factory(preferred_source)
            refreshed_song = self._refresh_song_info(client, target_song_info)
            if self._has_valid_download_url(refreshed_song):
                confident = self._append_match(target_song_info, refreshed_song, candidate_rows)
                if confident is not None:
                    return confident
        except Exception:
            pass  # best effort: continue with the keyword search below

        try:
            search_candidates = self._search_source_candidates(preferred_source, keyword)
            best_candidate = self._pick_best_candidate(search_candidates, target_song_info, 0)
            if best_candidate is not None:
                return self._append_match(target_song_info, best_candidate, candidate_rows)
        except Exception:
            pass  # best effort: the caller moves on to fallback sources
        return None

    def resolve_song_info(
        self,
        row: dict[str, Any],
        snapshot_song_info: Any,
        download_sources: list[str] | None = None,
        progress_callback: Callable[[str], None] | None = None,
    ) -> Any:
        """Return the best downloadable ``SongInfo`` found for *row*.

        Resolution order:
          1. The preferred source (the target's ``source`` or ``row["platform"]``,
             skipped when it normalizes to ``""``/``"unknown"``): refresh, then
             keyword search. A high-confidence match (priority <= 1) is
             returned immediately; weaker matches are kept for step 3.
          2. The other *download_sources* (default ``DEFAULT_DOWNLOAD_SOURCES``)
             in ranked order. The first source yielding any match is returned
             at once; each attempt is recorded in the stats repo.
          3. Otherwise the best match kept from step 1, or the unresolved target
             when nothing matched.

        Args:
            row: Track record; reads ``platform``, ``name``, ``singers`` and the
                other keys used by ``_build_target_song_info``.
            snapshot_song_info: Previously stored ``SongInfo`` or None.
            download_sources: Source names to consider, in priority order.
            progress_callback: Receives a ``"resolving source X (i/N)"`` message
                before each source is tried.
        """
        target_song_info = self._build_target_song_info(row=row, snapshot_song_info=snapshot_song_info)
        preferred_source = normalize_source_name(getattr(target_song_info, "source", None) or row.get("platform"))
        ordered_sources = dedupe_preserve_order(list(download_sources or DEFAULT_DOWNLOAD_SOURCES))
        keyword = build_resolve_keyword(target_song_info, row)
        candidate_rows: list[tuple[Any, int, int]] = []

        fallback_sources = [source for source in ordered_sources if source != preferred_source]
        ranked_fallback_sources = self._rank_fallback_sources(preferred_source, fallback_sources)
        should_attempt_preferred = preferred_source not in {"", "unknown", None}
        total_attempts = len(ranked_fallback_sources) + (1 if should_attempt_preferred else 0)

        if should_attempt_preferred:
            self._emit_progress(
                progress_callback,
                f"resolving source {preferred_source} ({0 + 1}/{total_attempts})",
            )
            confident_match = self._resolve_from_preferred_source(
                preferred_source, target_song_info, keyword, candidate_rows
            )
            if confident_match is not None:
                return confident_match

        # Progress positions are 1-based; the sort rank passed on is position - 1.
        fallback_start_rank = 2 if should_attempt_preferred else 1
        for source_rank, source in enumerate(ranked_fallback_sources, start=fallback_start_rank):
            self._emit_progress(
                progress_callback,
                f"resolving source {source} ({source_rank}/{total_attempts})",
            )
            search_candidates = self._search_source_candidates(source, keyword)
            best_candidate = self._pick_best_candidate(search_candidates, target_song_info, source_rank - 1)
            if best_candidate is None:
                self._record_fallback_result(preferred_source, source, succeeded=False)
                continue
            self._record_fallback_result(preferred_source, source, succeeded=True)
            return merge_resolved_song_info(target_song_info, best_candidate)

        if not candidate_rows:
            return target_song_info
        return min(candidate_rows, key=_candidate_sort_key)[0]