import sqlite3
import tempfile
import unittest
from contextlib import closing
from datetime import datetime, timedelta, timezone
from pathlib import Path
from unittest.mock import patch

import music_server.services.token_service as token_service_module
from music_server.services.token_service import TokenService

# Dotted path patched by the race tests so TokenService opens intercepted connections.
_CONNECT_TARGET = "music_server.services.token_service.connect_sqlite"


class _BindInterceptingConnection:
    """Proxy a sqlite3 connection and run a hook just before the first-bind UPDATE.

    Attribute access is forwarded to the wrapped connection. The first time
    ``execute`` sees the conditional "bind token to client" UPDATE (tracked via
    the shared ``race_state`` dict), it calls ``_inject`` with that statement's
    token id before running the statement itself. This simulates a competing
    writer slipping in between TokenService's read and its bind.
    """

    # Position of ``token_id`` in the bind UPDATE's parameters:
    # (bound_client_id, bound_client_label, bound_at, last_seen_at, token_id).
    _TOKEN_ID_INDEX = 4

    def __init__(self, conn: sqlite3.Connection, race_state: dict[str, bool]) -> None:
        self._conn = conn
        # Shared by every connection produced by one patched connect function,
        # so the competing write is injected at most once per test.
        self._race_state = race_state

    @staticmethod
    def _is_bind_update(sql: str) -> bool:
        """Return True if *sql* is the conditional first-bind UPDATE (whitespace/case-insensitive)."""
        normalized_sql = " ".join(sql.lower().split())
        return (
            normalized_sql.startswith("update access_tokens")
            and "set bound_client_id = ?, bound_client_label = ?, bound_at = ?, last_seen_at = ?"
            in normalized_sql
            and "where token_id = ? and bound_client_id is null" in normalized_sql
        )

    def _inject(self, token_id) -> None:
        """Apply the competing write for *token_id*; subclasses must implement this."""
        raise NotImplementedError

    def execute(self, sql: str, parameters=()):
        """Forward to the wrapped connection, injecting the race once before the bind UPDATE."""
        if not self._race_state["done"] and self._is_bind_update(sql):
            # _inject writes through the raw connection, so it never re-enters here.
            self._inject(parameters[self._TOKEN_ID_INDEX])
            self._race_state["done"] = True
        if parameters is None:
            return self._conn.execute(sql)
        return self._conn.execute(sql, parameters)

    # Implicit dunder lookups (``with conn:``) bypass __getattr__, so the
    # context-manager protocol is forwarded explicitly. __enter__ returns the
    # proxy, not the raw connection, so statements inside the block are still
    # intercepted.
    def __enter__(self):
        self._conn.__enter__()
        return self

    def __exit__(self, exc_type, exc, tb):
        return self._conn.__exit__(exc_type, exc, tb)

    def __getattr__(self, name: str):
        return getattr(self._conn, name)


class _RaceInjectingConnection(_BindInterceptingConnection):
    """Let another client ("racer-client") win the first bind of the token."""

    def _inject(self, token_id) -> None:
        self._conn.execute(
            """
            update access_tokens
            set bound_client_id = ?, bound_client_label = ?, bound_at = ?, last_seen_at = ?
            where token_id = ? and bound_client_id is null
            """,
            (
                "racer-client",
                "Race Winner",
                "2000-01-01T00:00:00+00:00",
                "2000-01-01T00:00:00+00:00",
                token_id,
            ),
        )


class _RevokeBeforeBindConnection(_BindInterceptingConnection):
    """Revoke the token just before TokenService attempts its first bind."""

    def _inject(self, token_id) -> None:
        self._conn.execute(
            """
            update access_tokens
            set revoked_at = ?, revoked_reason = ?
            where token_id = ?
            """,
            (
                "2000-01-01T00:00:00+00:00",
                "revoked-during-bind-race",
                token_id,
            ),
        )


class TokenServiceTests(unittest.TestCase):
    def setUp(self) -> None:
        # Each test gets its own throwaway database file and service instance.
        tmpdir = tempfile.TemporaryDirectory()
        self.addCleanup(tmpdir.cleanup)
        self.db_path = Path(tmpdir.name) / "player.db"
        self.service = TokenService(str(self.db_path))

    def _build_intercepting_connect(self, wrapper_cls):
        """Return a connect_sqlite replacement wrapping every connection in *wrapper_cls*."""
        race_state = {"done": False}
        # Capture the real function now, before patch() replaces the module attribute.
        real_connect = token_service_module.connect_sqlite

        def intercepting_connect(db_path: str):
            return wrapper_cls(real_connect(db_path), race_state)

        return intercepting_connect

    def _build_racing_connect(self):
        return self._build_intercepting_connect(_RaceInjectingConnection)

    def _build_revoke_before_bind_connect(self):
        return self._build_intercepting_connect(_RevokeBeforeBindConnection)

    def test_issue_token_persists_hash_and_listable_metadata(self):
        issued = self.service.issue_token(days=90, label="iphone16")
        listed = self.service.list_tokens(include_revoked=True)

        self.assertTrue(issued.plaintext_token.startswith("msv1_"))
        self.assertEqual("iphone16", listed[0]["label"])
        self.assertEqual(issued.token_id, listed[0]["token_id"])
        self.assertIsNone(listed[0]["bound_client_id"])

        # closing() releases the file handle even if the query raises; sqlite3's
        # own context manager only commits/rolls back and never closes.
        with closing(sqlite3.connect(self.db_path)) as conn:
            row = conn.execute(
                "select token_hash, expires_at from access_tokens where token_id = ?",
                (issued.token_id,),
            ).fetchone()

        self.assertIsNotNone(row)
        self.assertNotEqual(issued.plaintext_token, row[0])
        self.assertIsNotNone(row[1])

    def test_authenticate_binds_first_client_and_reuses_same_client(self):
        issued = self.service.issue_token(days=90, label="ipad")

        first = self.service.authenticate(
            plaintext_token=issued.plaintext_token,
            client_id="client-a",
            client_label="Alice iPad",
        )
        second = self.service.authenticate(
            plaintext_token=issued.plaintext_token,
            client_id="client-a",
            client_label="Alice iPad",
        )
        third = self.service.authenticate(
            plaintext_token=issued.plaintext_token,
            client_id="client-b",
            client_label="Other Device",
        )

        self.assertTrue(first.valid)
        self.assertTrue(second.valid)
        self.assertEqual("token_bound_to_other_client", third.error_code)

    def test_unbind_and_revoke_change_future_auth_outcome(self):
        issued = self.service.issue_token(days=90, label="android")

        self.service.authenticate(
            plaintext_token=issued.plaintext_token,
            client_id="client-a",
            client_label="Pixel",
        )
        self.service.unbind_token(issued.token_id)
        rebound = self.service.authenticate(
            plaintext_token=issued.plaintext_token,
            client_id="client-b",
            client_label="New Pixel",
        )
        self.service.revoke_token(issued.token_id, reason="replaced")
        revoked = self.service.authenticate(
            plaintext_token=issued.plaintext_token,
            client_id="client-b",
            client_label="New Pixel",
        )

        self.assertTrue(rebound.valid)
        self.assertEqual("token_revoked", revoked.error_code)

    def test_authenticate_returns_bound_other_when_first_bind_loses_race(self):
        issued = self.service.issue_token(days=90, label="race")

        with patch(_CONNECT_TARGET, side_effect=self._build_racing_connect()):
            result = self.service.authenticate(
                plaintext_token=issued.plaintext_token,
                client_id="client-a",
                client_label="Alice Phone",
            )

        self.assertFalse(result.valid)
        self.assertEqual("token_bound_to_other_client", result.error_code)
        self.assertEqual("racer-client", result.bound_client_id)

    def test_authenticate_compares_expiration_by_datetime_not_string(self):
        issued = self.service.issue_token(days=90, label="offset-expiry")
        # Expressed at UTC-12, a moment 5 minutes in the future has a wall-clock
        # string that sorts *before* "now" in UTC, so a naive string comparison
        # would wrongly treat the token as expired.
        future_utc = datetime.now(timezone.utc) + timedelta(minutes=5)
        future_with_negative_offset = future_utc.astimezone(
            timezone(timedelta(hours=-12))
        ).isoformat()

        with closing(sqlite3.connect(self.db_path)) as conn:
            conn.execute(
                "update access_tokens set expires_at = ? where token_id = ?",
                (future_with_negative_offset, issued.token_id),
            )
            conn.commit()

        result = self.service.authenticate(
            plaintext_token=issued.plaintext_token,
            client_id="client-a",
            client_label="Offset Device",
        )

        self.assertTrue(result.valid)
        self.assertIsNone(result.error_code)

    def test_status_uses_final_binding_state_when_first_bind_loses_race(self):
        issued = self.service.issue_token(days=90, label="status-race")

        with patch(_CONNECT_TARGET, side_effect=self._build_racing_connect()):
            payload = self.service.status(
                plaintext_token=issued.plaintext_token,
                client_id="client-a",
                client_label="Alice Phone",
            )

        self.assertFalse(payload["valid"])
        self.assertEqual("token_bound_to_other_client", payload["status"])
        self.assertTrue(payload["bound"])
        self.assertFalse(payload["isCurrentClientBound"])
        self.assertEqual("Race Winner", payload["boundClientLabel"])

    def test_authenticate_does_not_pass_when_token_revoked_before_first_bind(self):
        issued = self.service.issue_token(days=90, label="revoke-race")

        with patch(_CONNECT_TARGET, side_effect=self._build_revoke_before_bind_connect()):
            result = self.service.authenticate(
                plaintext_token=issued.plaintext_token,
                client_id="client-a",
                client_label="Alice Phone",
            )

        self.assertFalse(result.valid)
        self.assertEqual("token_revoked", result.error_code)


if __name__ == "__main__":
    unittest.main()