284 lines
11 KiB
Python
284 lines
11 KiB
Python
import sqlite3
import tempfile
import unittest
from contextlib import closing
from datetime import datetime, timedelta, timezone
from pathlib import Path
from unittest.mock import patch

import music_server.services.token_service as token_service_module
from music_server.services.token_service import TokenService
|
|
|
|
|
|
class _RaceInjectingConnection:
|
|
def __init__(self, conn: sqlite3.Connection, race_state: dict[str, bool]) -> None:
|
|
self._conn = conn
|
|
self._race_state = race_state
|
|
self._injecting = False
|
|
|
|
def execute(self, sql: str, parameters=()):
|
|
normalized_sql = " ".join(sql.lower().split())
|
|
is_bind_update = (
|
|
normalized_sql.startswith("update access_tokens")
|
|
and "set bound_client_id = ?, bound_client_label = ?, bound_at = ?, last_seen_at = ?" in normalized_sql
|
|
and "where token_id = ? and bound_client_id is null" in normalized_sql
|
|
)
|
|
if is_bind_update and not self._injecting and not self._race_state["done"]:
|
|
token_id = parameters[4]
|
|
self._injecting = True
|
|
try:
|
|
self._conn.execute(
|
|
"""
|
|
update access_tokens
|
|
set bound_client_id = ?, bound_client_label = ?, bound_at = ?, last_seen_at = ?
|
|
where token_id = ? and bound_client_id is null
|
|
""",
|
|
(
|
|
"racer-client",
|
|
"Race Winner",
|
|
"2000-01-01T00:00:00+00:00",
|
|
"2000-01-01T00:00:00+00:00",
|
|
token_id,
|
|
),
|
|
)
|
|
self._race_state["done"] = True
|
|
finally:
|
|
self._injecting = False
|
|
|
|
if parameters is None:
|
|
return self._conn.execute(sql)
|
|
return self._conn.execute(sql, parameters)
|
|
|
|
def __getattr__(self, name: str):
|
|
return getattr(self._conn, name)
|
|
|
|
|
|
class _RevokeBeforeBindConnection:
|
|
def __init__(self, conn: sqlite3.Connection, race_state: dict[str, bool]) -> None:
|
|
self._conn = conn
|
|
self._race_state = race_state
|
|
self._injecting = False
|
|
|
|
def execute(self, sql: str, parameters=()):
|
|
normalized_sql = " ".join(sql.lower().split())
|
|
is_bind_update = (
|
|
normalized_sql.startswith("update access_tokens")
|
|
and "set bound_client_id = ?, bound_client_label = ?, bound_at = ?, last_seen_at = ?" in normalized_sql
|
|
and "where token_id = ? and bound_client_id is null" in normalized_sql
|
|
)
|
|
if is_bind_update and not self._injecting and not self._race_state["done"]:
|
|
token_id = parameters[4]
|
|
self._injecting = True
|
|
try:
|
|
self._conn.execute(
|
|
"""
|
|
update access_tokens
|
|
set revoked_at = ?, revoked_reason = ?
|
|
where token_id = ?
|
|
""",
|
|
(
|
|
"2000-01-01T00:00:00+00:00",
|
|
"revoked-during-bind-race",
|
|
token_id,
|
|
),
|
|
)
|
|
self._race_state["done"] = True
|
|
finally:
|
|
self._injecting = False
|
|
|
|
if parameters is None:
|
|
return self._conn.execute(sql)
|
|
return self._conn.execute(sql, parameters)
|
|
|
|
def __getattr__(self, name: str):
|
|
return getattr(self._conn, name)
|
|
|
|
|
|
class TokenServiceTests(unittest.TestCase):
    """Tests for ``TokenService`` against a throwaway SQLite file.

    Race scenarios patch ``token_service_module.connect_sqlite`` so that every
    connection the service opens is wrapped in one of this module's
    interception proxies.
    """

    def _build_connect(self, wrapper_cls):
        """Return a ``connect_sqlite`` replacement that wraps each connection.

        Every connection produced by the returned factory shares one race-state
        dict, so the wrapper's injected statement runs at most once per test.
        Call this *before* the patch is active: ``real_connect`` is read here,
        and reading it under the patch would capture the mock itself.
        """
        race_state = {"done": False}
        real_connect = token_service_module.connect_sqlite

        def wrapped_connect(db_path: str):
            return wrapper_cls(real_connect(db_path), race_state)

        return wrapped_connect

    def _build_racing_connect(self):
        """Factory whose first bind is pre-empted by ``racer-client``."""
        return self._build_connect(_RaceInjectingConnection)

    def _build_revoke_before_bind_connect(self):
        """Factory that revokes the token just before the first bind UPDATE."""
        return self._build_connect(_RevokeBeforeBindConnection)

    def test_issue_token_persists_hash_and_listable_metadata(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            db_path = Path(tmpdir) / "player.db"
            service = TokenService(str(db_path))

            issued = service.issue_token(days=90, label="iphone16")
            listed = service.list_tokens(include_revoked=True)

            self.assertTrue(issued.plaintext_token.startswith("msv1_"))
            self.assertEqual("iphone16", listed[0]["label"])
            self.assertEqual(issued.token_id, listed[0]["token_id"])
            self.assertIsNone(listed[0]["bound_client_id"])

            # closing() releases the raw handle even if the query raises. A
            # leaked handle can keep player.db open and make TemporaryDirectory
            # cleanup fail on some platforms, hiding the real error.
            with closing(sqlite3.connect(db_path)) as conn:
                row = conn.execute(
                    "select token_hash, expires_at from access_tokens where token_id = ?",
                    (issued.token_id,),
                ).fetchone()

            self.assertIsNotNone(row)
            # The stored token_hash must never equal the plaintext token.
            self.assertNotEqual(issued.plaintext_token, row[0])

    def test_authenticate_binds_first_client_and_reuses_same_client(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            db_path = Path(tmpdir) / "player.db"
            service = TokenService(str(db_path))
            issued = service.issue_token(days=90, label="ipad")

            first = service.authenticate(
                plaintext_token=issued.plaintext_token,
                client_id="client-a",
                client_label="Alice iPad",
            )
            second = service.authenticate(
                plaintext_token=issued.plaintext_token,
                client_id="client-a",
                client_label="Alice iPad",
            )
            third = service.authenticate(
                plaintext_token=issued.plaintext_token,
                client_id="client-b",
                client_label="Other Device",
            )

            self.assertTrue(first.valid)
            self.assertTrue(second.valid)
            self.assertFalse(third.valid)
            self.assertEqual("token_bound_to_other_client", third.error_code)

    def test_unbind_and_revoke_change_future_auth_outcome(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            db_path = Path(tmpdir) / "player.db"
            service = TokenService(str(db_path))
            issued = service.issue_token(days=90, label="android")

            service.authenticate(
                plaintext_token=issued.plaintext_token,
                client_id="client-a",
                client_label="Pixel",
            )
            service.unbind_token(issued.token_id)

            # After unbinding, a different client may claim the token.
            rebound = service.authenticate(
                plaintext_token=issued.plaintext_token,
                client_id="client-b",
                client_label="New Pixel",
            )
            service.revoke_token(issued.token_id, reason="replaced")
            revoked = service.authenticate(
                plaintext_token=issued.plaintext_token,
                client_id="client-b",
                client_label="New Pixel",
            )

            self.assertTrue(rebound.valid)
            self.assertFalse(revoked.valid)
            self.assertEqual("token_revoked", revoked.error_code)

    def test_authenticate_returns_bound_other_when_first_bind_loses_race(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            db_path = Path(tmpdir) / "player.db"
            service = TokenService(str(db_path))
            issued = service.issue_token(days=90, label="race")

            with patch(
                "music_server.services.token_service.connect_sqlite",
                side_effect=self._build_racing_connect(),
            ):
                result = service.authenticate(
                    plaintext_token=issued.plaintext_token,
                    client_id="client-a",
                    client_label="Alice Phone",
                )

            self.assertFalse(result.valid)
            self.assertEqual("token_bound_to_other_client", result.error_code)
            self.assertEqual("racer-client", result.bound_client_id)

    def test_authenticate_compares_expiration_by_datetime_not_string(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            db_path = Path(tmpdir) / "player.db"
            service = TokenService(str(db_path))
            issued = service.issue_token(days=90, label="offset-expiry")

            # Still in the future, but rendered with a -12:00 offset, so its
            # ISO string sorts before "now" in UTC. A plain string comparison
            # would wrongly report it as expired.
            future_utc = datetime.now(timezone.utc) + timedelta(minutes=5)
            future_with_negative_offset = future_utc.astimezone(
                timezone(timedelta(hours=-12))
            ).isoformat()

            with closing(sqlite3.connect(db_path)) as conn:
                conn.execute(
                    "update access_tokens set expires_at = ? where token_id = ?",
                    (future_with_negative_offset, issued.token_id),
                )
                conn.commit()

            result = service.authenticate(
                plaintext_token=issued.plaintext_token,
                client_id="client-a",
                client_label="Offset Device",
            )

            self.assertTrue(result.valid)
            self.assertIsNone(result.error_code)

    def test_status_uses_final_binding_state_when_first_bind_loses_race(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            db_path = Path(tmpdir) / "player.db"
            service = TokenService(str(db_path))
            issued = service.issue_token(days=90, label="status-race")

            with patch(
                "music_server.services.token_service.connect_sqlite",
                side_effect=self._build_racing_connect(),
            ):
                payload = service.status(
                    plaintext_token=issued.plaintext_token,
                    client_id="client-a",
                    client_label="Alice Phone",
                )

            self.assertFalse(payload["valid"])
            self.assertEqual("token_bound_to_other_client", payload["status"])
            self.assertTrue(payload["bound"])
            self.assertFalse(payload["isCurrentClientBound"])
            self.assertEqual("Race Winner", payload["boundClientLabel"])

    def test_authenticate_does_not_pass_when_token_revoked_before_first_bind(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            db_path = Path(tmpdir) / "player.db"
            service = TokenService(str(db_path))
            issued = service.issue_token(days=90, label="revoke-race")

            with patch(
                "music_server.services.token_service.connect_sqlite",
                side_effect=self._build_revoke_before_bind_connect(),
            ):
                result = service.authenticate(
                    plaintext_token=issued.plaintext_token,
                    client_id="client-a",
                    client_label="Alice Phone",
                )

            self.assertFalse(result.valid)
            self.assertEqual("token_revoked", result.error_code)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: discover and run the TestCases defined in this module.
    unittest.main(module="__main__")
|