test(remote): add unit + e2e suite for remote-services subsystem

Covers RemoteClient/RemoteService, REST routes, handlers, SessionStore, InputBridge, DownloadQueueManager, errors, compression, and serve CLI. E2e tier opts in via --live and can auto-spawn its own serve.
This commit is contained in:
imSp4rky
2026-05-21 10:45:25 -06:00
parent 9c905ef7a3
commit 746b573711
29 changed files with 2541 additions and 0 deletions

View File

View File

@@ -0,0 +1,65 @@
"""Unit tests for unshackle.core.api.compression.compression_middleware."""
from __future__ import annotations
import asyncio
import gzip
import json
import pytest
from aiohttp import web
from unshackle.core.api.compression import compression_middleware
pytestmark = pytest.mark.unit
class _FakeReq:
def __init__(self, accept_encoding: str = "gzip") -> None:
self.headers = {"Accept-Encoding": accept_encoding}
def _run(coro):
return asyncio.get_event_loop().run_until_complete(coro) if False else asyncio.run(coro)
def test_skips_when_client_does_not_accept_gzip() -> None:
payload = b"x" * 4096
body_json = json.dumps({"data": "x" * 4096}).encode()
async def handler(req): # noqa: ARG001
return web.json_response({"data": "x" * 4096})
req = _FakeReq(accept_encoding="identity")
resp = _run(compression_middleware(req, handler))
assert resp.headers.get("Content-Encoding") != "gzip"
assert resp.body == body_json or len(resp.body) >= len(body_json) - 8
def test_skips_when_body_below_threshold() -> None:
async def handler(req): # noqa: ARG001
return web.json_response({"hi": "x"})
resp = _run(compression_middleware(_FakeReq(), handler))
assert resp.headers.get("Content-Encoding") != "gzip"
def test_skips_non_json_response() -> None:
async def handler(req): # noqa: ARG001
return web.Response(body=b"x" * 4096, content_type="text/plain")
resp = _run(compression_middleware(_FakeReq(), handler))
assert resp.headers.get("Content-Encoding") != "gzip"
def test_compresses_large_json_when_accepted() -> None:
big = {"data": "x" * 4096}
async def handler(req): # noqa: ARG001
return web.json_response(big)
resp = _run(compression_middleware(_FakeReq(), handler))
assert resp.headers.get("Content-Encoding") == "gzip"
decompressed = gzip.decompress(resp.body)
assert json.loads(decompressed) == big
assert resp.headers["Content-Length"] == str(len(resp.body))

View File

@@ -0,0 +1,120 @@
"""Unit tests for DownloadJob + DownloadQueueManager state machine.
These tests focus on the queue manager's data layer (create/get/list/cancel/
cleanup/serialize) — they do not exercise the actual subprocess download path.
"""
from __future__ import annotations
from datetime import datetime, timedelta
import pytest
from unshackle.core.api.download_manager import DownloadJob, DownloadQueueManager, JobStatus, get_download_manager
pytestmark = pytest.mark.unit
@pytest.fixture
def manager() -> DownloadQueueManager:
"""Fresh manager. We never call start_workers() so no async tasks are created."""
return DownloadQueueManager(max_concurrent_downloads=2, job_retention_hours=24)
def test_create_job_returns_queued_job(manager: DownloadQueueManager) -> None:
job = manager.create_job("ATV", "movie-123", profile="default")
assert isinstance(job, DownloadJob)
assert job.status is JobStatus.QUEUED
assert job.service == "ATV"
assert job.title_id == "movie-123"
assert job.parameters == {"profile": "default"}
def test_get_and_list_jobs(manager: DownloadQueueManager) -> None:
a = manager.create_job("ATV", "a")
b = manager.create_job("NF", "b")
assert manager.get_job(a.job_id) is a
assert manager.get_job("missing") is None
listed = manager.list_jobs()
assert {j.job_id for j in listed} == {a.job_id, b.job_id}
def test_to_dict_short_vs_full(manager: DownloadQueueManager) -> None:
job = manager.create_job("ATV", "t", profile="p")
short = job.to_dict()
assert "parameters" not in short
assert short["status"] == "queued"
assert short["service"] == "ATV"
full = job.to_dict(include_full_details=True)
assert full["parameters"] == {"profile": "p"}
assert "error_message" in full
assert "output_files" in full
def test_cancel_queued_job_sets_cancelled_and_signals_event(manager: DownloadQueueManager) -> None:
job = manager.create_job("ATV", "t")
assert manager.cancel_job(job.job_id) is True
assert job.status is JobStatus.CANCELLED
assert job.cancel_event.is_set()
def test_cancel_unknown_job_returns_false(manager: DownloadQueueManager) -> None:
assert manager.cancel_job("never-existed") is False
def test_cancel_completed_job_returns_false(manager: DownloadQueueManager) -> None:
job = manager.create_job("ATV", "t")
job.status = JobStatus.COMPLETED
assert manager.cancel_job(job.job_id) is False
def test_cancel_downloading_job_signals(manager: DownloadQueueManager) -> None:
job = manager.create_job("ATV", "t")
job.status = JobStatus.DOWNLOADING
assert manager.cancel_job(job.job_id) is True
assert job.status is JobStatus.CANCELLED
assert job.cancel_event.is_set()
def test_cleanup_old_jobs_drops_old_terminal_states(manager: DownloadQueueManager) -> None:
now = datetime.now()
old = now - timedelta(hours=48)
keep_recent = manager.create_job("ATV", "recent")
drop_old_done = manager.create_job("ATV", "old-done")
drop_old_failed = manager.create_job("ATV", "old-failed")
keep_running = manager.create_job("ATV", "running")
keep_recent.status = JobStatus.COMPLETED
keep_recent.completed_time = now
drop_old_done.status = JobStatus.COMPLETED
drop_old_done.completed_time = old
drop_old_failed.status = JobStatus.FAILED
drop_old_failed.created_time = old # never set completed_time
keep_running.status = JobStatus.DOWNLOADING
removed = manager.cleanup_old_jobs()
assert removed == 2
remaining = {j.job_id for j in manager.list_jobs()}
assert keep_recent.job_id in remaining
assert keep_running.job_id in remaining
assert drop_old_done.job_id not in remaining
assert drop_old_failed.job_id not in remaining
def test_get_download_manager_returns_singleton() -> None:
a = get_download_manager()
b = get_download_manager()
assert a is b
def test_job_status_values() -> None:
assert {s.value for s in JobStatus} == {
"queued",
"downloading",
"completed",
"failed",
"cancelled",
}

View File

@@ -0,0 +1,132 @@
"""Unit tests for unshackle.core.api.errors."""
from __future__ import annotations
import json
import pytest
from unshackle.core.api.errors import (
APIError,
APIErrorCode,
build_error_response,
categorize_exception,
handle_api_exception,
)
pytestmark = pytest.mark.unit
def _body(resp) -> dict:
return json.loads(resp.body.decode("utf-8"))
def test_api_error_default_http_status_per_code() -> None:
cases = {
APIErrorCode.INVALID_INPUT: 400,
APIErrorCode.INVALID_SERVICE: 400,
APIErrorCode.AUTH_REQUIRED: 401,
APIErrorCode.AUTH_FAILED: 401,
APIErrorCode.FORBIDDEN: 403,
APIErrorCode.GEOFENCE: 403,
APIErrorCode.NOT_FOUND: 404,
APIErrorCode.SESSION_NOT_FOUND: 404,
APIErrorCode.TRACK_NOT_FOUND: 404,
APIErrorCode.RATE_LIMITED: 429,
APIErrorCode.INTERNAL_ERROR: 500,
APIErrorCode.SERVICE_ERROR: 502,
APIErrorCode.DRM_ERROR: 502,
APIErrorCode.NETWORK_ERROR: 503,
APIErrorCode.SERVICE_UNAVAILABLE: 503,
}
for code, expected in cases.items():
assert APIError(code, "x").http_status == expected, code
def test_api_error_explicit_http_status_overrides_default() -> None:
err = APIError(APIErrorCode.INVALID_INPUT, "x", http_status=418)
assert err.http_status == 418
def test_build_error_response_from_api_error() -> None:
err = APIError(
APIErrorCode.SESSION_NOT_FOUND,
"no such session",
details={"session_id": "abc"},
retryable=False,
)
resp = build_error_response(err)
assert resp.status == 404
body = _body(resp)
assert body["status"] == "error"
assert body["error_code"] == "SESSION_NOT_FOUND"
assert body["message"] == "no such session"
assert body["details"] == {"session_id": "abc"}
assert "retryable" not in body
assert "debug_info" not in body
assert "timestamp" in body
def test_build_error_response_retryable_flag() -> None:
err = APIError(APIErrorCode.NETWORK_ERROR, "boom", retryable=True)
body = _body(build_error_response(err))
assert body["retryable"] is True
def test_build_error_response_from_generic_exception() -> None:
resp = build_error_response(RuntimeError("oops"))
assert resp.status == 500
body = _body(resp)
assert body["error_code"] == "INTERNAL_ERROR"
assert body["message"] == "oops"
def test_build_error_response_debug_mode_includes_traceback() -> None:
try:
raise ValueError("kaboom")
except ValueError as e:
resp = build_error_response(e, debug_mode=True, extra_debug_info={"foo": "bar"})
body = _body(resp)
assert body["debug_info"]["exception_type"] == "ValueError"
assert "traceback" in body["debug_info"]
assert body["debug_info"]["foo"] == "bar"
@pytest.mark.parametrize(
"exc, expected_code",
[
(Exception("Invalid credentials provided"), APIErrorCode.AUTH_FAILED),
(Exception("Connection refused"), APIErrorCode.NETWORK_ERROR),
(TimeoutError("read timeout"), APIErrorCode.NETWORK_ERROR),
(Exception("Not available in your region"), APIErrorCode.GEOFENCE),
(Exception("Title not found"), APIErrorCode.NOT_FOUND),
(Exception("HTTP 429 too many requests"), APIErrorCode.RATE_LIMITED),
(Exception("DRM license fetch failed"), APIErrorCode.DRM_ERROR),
(Exception("503 service unavailable"), APIErrorCode.SERVICE_UNAVAILABLE),
(ValueError("malformed body"), APIErrorCode.INVALID_INPUT),
(RuntimeError("totally novel failure xyz"), APIErrorCode.INTERNAL_ERROR),
],
)
def test_categorize_exception(exc: Exception, expected_code: APIErrorCode) -> None:
api_err = categorize_exception(exc, context={"service": "ATV"})
assert api_err.error_code == expected_code
assert api_err.details.get("service") == "ATV"
def test_categorize_preserves_context() -> None:
api_err = categorize_exception(ValueError("bad"), context={"op": "search"})
assert api_err.details["op"] == "search"
def test_handle_api_exception_with_api_error_preserves_code() -> None:
err = APIError(APIErrorCode.TRACK_NOT_FOUND, "no track")
resp = handle_api_exception(err)
body = _body(resp)
assert body["error_code"] == "TRACK_NOT_FOUND"
assert resp.status == 404
def test_handle_api_exception_categorizes_generic() -> None:
resp = handle_api_exception(ConnectionError("oops"))
body = _body(resp)
assert body["error_code"] == "NETWORK_ERROR"

View File

@@ -0,0 +1,273 @@
"""Unit tests for unshackle.core.api.handlers serializers + validators."""
from __future__ import annotations
import pytest
from langcodes import Language
from unshackle.core.api.handlers import (
sanitize_log,
serialize_audio_track,
serialize_drm,
serialize_subtitle_track,
serialize_title,
serialize_video_track,
validate_download_parameters,
validate_service,
)
from unshackle.core.titles.episode import Episode
from unshackle.core.titles.movie import Movie
from unshackle.core.tracks import Audio, Subtitle, Video
from unshackle.core.tracks.track import Track
pytestmark = pytest.mark.unit
class _FakeSvc:
pass
def _video(**overrides) -> Video:
base = dict(
url="https://example.com/v.mpd",
language=Language.get("en"),
descriptor=Track.Descriptor.URL,
codec=Video.Codec.AVC,
range_=Video.Range.SDR,
bitrate=5_000_000,
width=1920,
height=1080,
fps=24,
id_="video-001",
)
base.update(overrides)
return Video(**base)
def _audio(**overrides) -> Audio:
base = dict(
url="https://example.com/a.mpd",
language=Language.get("en"),
descriptor=Track.Descriptor.URL,
codec=Audio.Codec.AAC,
bitrate=128_000,
channels=2,
joc=0,
descriptive=False,
id_="audio-001",
)
base.update(overrides)
return Audio(**base)
def _subtitle(**overrides) -> Subtitle:
base = dict(
url="https://example.com/s.vtt",
language=Language.get("en"),
descriptor=Track.Descriptor.URL,
codec=Subtitle.Codec.WebVTT,
cc=False,
sdh=False,
forced=False,
id_="sub-001",
)
base.update(overrides)
return Subtitle(**base)
# ---------- sanitize_log ----------
@pytest.mark.parametrize(
"raw, expected",
[
("hello\nworld", "helloworld"),
("a\r\nb\x00c", "abc"),
("clean", "clean"),
(12345, "12345"),
],
)
def test_sanitize_log(raw, expected: str) -> None:
assert sanitize_log(raw) == expected
# ---------- serialize_title ----------
def test_serialize_title_movie() -> None:
movie = Movie(id_="movie-0001", service=_FakeSvc, name="Title X", year=2024, language=Language.get("en"))
d = serialize_title(movie)
assert d["type"] == "movie"
assert d["name"] == "Title X"
assert d["year"] == 2024
assert d["id"] == "movie-0001"
assert d["language"] == "en"
def test_serialize_title_episode_named() -> None:
ep = Episode(
id_="ep-00001",
service=_FakeSvc,
title="My Show",
season=2,
number=3,
name="Pilot",
year=2024,
language=Language.get("en"),
)
d = serialize_title(ep)
assert d["type"] == "episode"
assert d["series_title"] == "My Show"
assert d["season"] == 2
assert d["number"] == 3
assert d["name"] == "Pilot"
def test_serialize_title_episode_unnamed_falls_back_to_number() -> None:
ep = Episode(
id_="ep-00002",
service=_FakeSvc,
title="Show",
season=1,
number=5,
name=None,
year=None,
language=Language.get("en"),
)
d = serialize_title(ep)
assert d["name"] == "Episode 05"
# ---------- serialize_video/audio/subtitle ----------
def test_serialize_video_track_basic() -> None:
d = serialize_video_track(_video())
assert d["id"] == "video-001"
assert d["codec"] == "AVC"
assert d["bitrate"] == 5000 # kbps
assert d["resolution"] == "1920x1080"
assert d["fps"] == 24
assert d["range"] == "SDR"
assert d["language"] == "en"
assert d["drm"] is None
assert "url" not in d
def test_serialize_video_track_include_url() -> None:
d = serialize_video_track(_video(), include_url=True)
assert d["url"] == "https://example.com/v.mpd"
def test_serialize_audio_track_basic() -> None:
d = serialize_audio_track(_audio())
assert d["id"] == "audio-001"
assert d["codec"] == "AAC"
assert d["bitrate"] == 128
assert d["channels"] == 2
assert d["descriptive"] is False
def test_serialize_subtitle_track_basic() -> None:
d = serialize_subtitle_track(_subtitle(forced=True))
assert d["id"] == "sub-001"
assert d["codec"] == "WebVTT"
assert d["forced"] is True
assert d["sdh"] is False
assert d["cc"] is False
# ---------- serialize_drm ----------
def test_serialize_drm_none_returns_none() -> None:
assert serialize_drm(None) is None
assert serialize_drm([]) is None
def test_serialize_drm_widevine_minimal() -> None:
class _PSSH:
def dumps(self) -> str:
return "BASE64PSSH=="
class _Widevine:
def __init__(self) -> None:
self._pssh = _PSSH()
self.kids = ["00112233445566778899aabbccddeeff"]
self.license_url = "https://lic.example.com/wv"
out = serialize_drm(_Widevine())
assert isinstance(out, list)
assert len(out) == 1
info = out[0]
assert info["type"] == "_widevine" # class name lowercased
assert info["pssh"] == "BASE64PSSH=="
assert info["kids"] == ["00112233445566778899aabbccddeeff"]
assert info["license_url"] == "https://lic.example.com/wv"
# ---------- validate_service ----------
def test_validate_service_unknown_returns_none() -> None:
assert validate_service("NOPE_THIS_IS_NOT_REAL_") is None
# ---------- validate_download_parameters ----------
def test_validate_download_params_accepts_defaults() -> None:
assert validate_download_parameters({}) is None
@pytest.mark.parametrize(
"data, fragment",
[
({"vcodec": "WUT"}, "Invalid vcodec"),
({"vcodec": 123}, "vcodec must be a string or list"),
({"acodec": "MP9"}, "Invalid acodec"),
({"sub_format": "doc"}, "Invalid sub_format"),
({"vbitrate": -1}, "vbitrate"),
({"abitrate": "no"}, "abitrate"),
({"vbitrate_range": "no-dash-but-letters"}, None),
({"vbitrate_range": "nope"}, "MIN-MAX"),
({"channels": -3}, "channels"),
({"workers": 0}, "workers"),
({"downloads": 0}, "downloads"),
({"video_only": True, "audio_only": True}, "exclusive"),
({"no_subs": True, "subs_only": True}, "no_subs and subs_only"),
({"no_audio": True, "audio_only": True}, "no_audio and audio_only"),
({"s_lang": ["en"], "require_subs": ["en"]}, "s_lang and require_subs"),
({"range": "UHD"}, "Invalid range"),
({"range": ["SDR", "UHD"]}, "Invalid range value"),
],
)
def test_validate_download_params_errors(data: dict, fragment) -> None:
result = validate_download_parameters(data)
if fragment is None:
# A dash-containing string is valid syntactically per current rule
assert result is None
else:
assert result is not None
assert fragment in result
def test_validate_download_params_accepts_valid_values() -> None:
assert (
validate_download_parameters(
{
"vcodec": "H264,H265",
"acodec": ["AAC", "EAC3"],
"sub_format": "VTT",
"vbitrate": 6000,
"abitrate": 128,
"vbitrate_range": "6000-7000",
"abitrate_range": "96-192",
"channels": 5.1,
"workers": 8,
"downloads": 2,
"range": ["SDR", "HDR10"],
}
)
is None
)

View File

@@ -0,0 +1,100 @@
"""Unit tests for unshackle.core.api.input_bridge.InputBridge."""
from __future__ import annotations
import threading
import time
import pytest
from unshackle.core.api.input_bridge import AuthStatus, BridgeCancelledError, InputBridge
pytestmark = pytest.mark.unit
def test_initial_status_is_authenticating() -> None:
bridge = InputBridge()
assert bridge.status is AuthStatus.AUTHENTICATING
assert bridge.get_pending_prompt() is None
assert bridge.error is None
def test_submit_response_returns_false_when_no_prompt_pending() -> None:
bridge = InputBridge()
assert bridge.submit_response("foo") is False
def test_request_input_blocks_until_submit() -> None:
bridge = InputBridge()
result: list[str] = []
def worker() -> None:
result.append(bridge.request_input("OTP?", timeout=5))
t = threading.Thread(target=worker)
t.start()
for _ in range(50):
if bridge.get_pending_prompt() == "OTP?":
break
time.sleep(0.02)
assert bridge.status is AuthStatus.PENDING_INPUT
assert bridge.get_pending_prompt() == "OTP?"
assert bridge.submit_response("123456") is True
t.join(timeout=2)
assert result == ["123456"]
assert bridge.status is AuthStatus.AUTHENTICATING
assert bridge.get_pending_prompt() is None
def test_request_input_times_out() -> None:
bridge = InputBridge()
with pytest.raises(TimeoutError):
bridge.request_input("hello", timeout=0.6)
assert bridge.status is AuthStatus.FAILED
assert "timed out" in (bridge.error or "")
def test_cancel_before_request_raises_immediately() -> None:
bridge = InputBridge()
bridge.cancel()
with pytest.raises(BridgeCancelledError):
bridge.request_input("hello", timeout=5)
assert bridge.status is AuthStatus.FAILED
def test_cancel_unblocks_pending_request() -> None:
bridge = InputBridge()
exc: list[Exception] = []
def worker() -> None:
try:
bridge.request_input("OTP?", timeout=5)
except BridgeCancelledError as e:
exc.append(e)
t = threading.Thread(target=worker)
t.start()
for _ in range(50):
if bridge.status is AuthStatus.PENDING_INPUT:
break
time.sleep(0.02)
bridge.cancel()
t.join(timeout=2)
assert exc and isinstance(exc[0], BridgeCancelledError)
assert bridge.status is AuthStatus.FAILED
def test_get_pending_prompt_returns_none_outside_pending_state() -> None:
bridge = InputBridge()
bridge.status = AuthStatus.AUTHENTICATED
assert bridge.get_pending_prompt() is None
def test_status_and_error_setters() -> None:
bridge = InputBridge()
bridge.status = AuthStatus.AUTHENTICATED
bridge.error = "boom"
assert bridge.status is AuthStatus.AUTHENTICATED
assert bridge.error == "boom"

View File

@@ -0,0 +1,111 @@
"""Unit tests for unshackle.core.remote_service.RemoteClient."""
from __future__ import annotations
import json
import pytest
import responses
from unshackle.core.remote_service import RemoteClient
pytestmark = pytest.mark.unit
@pytest.fixture
def client() -> RemoteClient:
return RemoteClient(server_url="http://srv:8786", api_key="secret-xyz")
def test_session_sets_secret_key_header(client: RemoteClient) -> None:
s = client.session
assert s.headers.get("X-Secret-Key") == "secret-xyz"
assert s.headers["User-Agent"].startswith("unshackle/")
def test_session_omits_secret_key_when_empty() -> None:
c = RemoteClient(server_url="http://srv:8786", api_key="")
assert "X-Secret-Key" not in c.session.headers
def test_server_url_trailing_slash_stripped() -> None:
c = RemoteClient(server_url="http://srv:8786/", api_key="")
assert c.server_url == "http://srv:8786"
@responses.activate
def test_get_returns_json(client: RemoteClient) -> None:
responses.add(
responses.GET,
"http://srv:8786/api/health",
json={"status": "ok"},
status=200,
)
assert client.get("/api/health") == {"status": "ok"}
@responses.activate
def test_post_sends_json_body(client: RemoteClient) -> None:
captured = {}
def cb(request):
captured["body"] = json.loads(request.body)
return (200, {}, json.dumps({"session_id": "abc"}))
responses.add_callback(
responses.POST, "http://srv:8786/api/session/create", callback=cb, content_type="application/json"
)
result = client.post("/api/session/create", {"service": "ATV"})
assert result == {"session_id": "abc"}
assert captured["body"] == {"service": "ATV"}
@responses.activate
def test_delete_returns_json(client: RemoteClient) -> None:
responses.add(
responses.DELETE,
"http://srv:8786/api/session/abc",
json={"status": "deleted"},
status=200,
)
assert client.delete("/api/session/abc") == {"status": "deleted"}
@responses.activate
def test_4xx_raises_systemexit_with_logged_error(client: RemoteClient, caplog: pytest.LogCaptureFixture) -> None:
responses.add(
responses.GET,
"http://srv:8786/api/session/none",
json={"error_code": "SESSION_NOT_FOUND", "message": "no such session"},
status=404,
)
with caplog.at_level("ERROR"), pytest.raises(SystemExit):
client.get("/api/session/none")
assert "SESSION_NOT_FOUND" in caplog.text
assert "no such session" in caplog.text
@responses.activate
def test_connection_error_raises_systemexit(client: RemoteClient) -> None:
import requests
responses.add(
responses.GET,
"http://srv:8786/api/health",
body=requests.ConnectionError("boom"),
)
with pytest.raises(SystemExit):
client.get("/api/health")
@responses.activate
def test_timeout_raises_systemexit(client: RemoteClient) -> None:
import requests
responses.add(
responses.GET,
"http://srv:8786/api/health",
body=requests.Timeout("slow"),
)
with pytest.raises(SystemExit):
client.get("/api/health")

View File

@@ -0,0 +1,117 @@
"""Unit tests for resolve_server / _resolve_proxy in unshackle.core.remote_service."""
from __future__ import annotations
import click
import pytest
from unshackle.core.remote_service import _resolve_proxy, resolve_server
pytestmark = pytest.mark.unit
@pytest.fixture
def empty_remote_services(monkeypatch: pytest.MonkeyPatch) -> None:
from unshackle.core import remote_service as rs
monkeypatch.setattr(rs.config, "remote_services", {})
@pytest.fixture
def single_remote_service(monkeypatch: pytest.MonkeyPatch) -> None:
from unshackle.core import remote_service as rs
monkeypatch.setattr(
rs.config,
"remote_services",
{
"primary": {
"url": "https://primary:8080",
"api_key": "key-abc",
"services": {"ATV": True, "NF": True},
"server_cdm": True,
}
},
)
@pytest.fixture
def multi_remote_services(monkeypatch: pytest.MonkeyPatch) -> None:
from unshackle.core import remote_service as rs
monkeypatch.setattr(
rs.config,
"remote_services",
{
"a": {"url": "https://a:8080", "api_key": "ka"},
"b": {"url": "https://b:8080", "api_key": "kb"},
},
)
def test_resolve_server_no_config_raises_click(empty_remote_services) -> None:
with pytest.raises(click.ClickException) as exc:
resolve_server(None)
assert "remote_services" in str(exc.value.message)
def test_resolve_server_single_picks_only_entry(single_remote_service) -> None:
url, key, services = resolve_server(None)
assert url == "https://primary:8080"
assert key == "key-abc"
assert services["_server_cdm"] is True
assert services.get("ATV") is True
def test_resolve_server_explicit_name(single_remote_service) -> None:
url, key, services = resolve_server("primary")
assert url == "https://primary:8080"
assert services["_server_cdm"] is True
def test_resolve_server_unknown_name_raises(single_remote_service) -> None:
with pytest.raises(click.ClickException) as exc:
resolve_server("bogus")
assert "bogus" in str(exc.value.message)
def test_resolve_server_multi_requires_explicit(multi_remote_services) -> None:
with pytest.raises(click.ClickException) as exc:
resolve_server(None)
assert "--server" in str(exc.value.message)
def test_resolve_server_multi_with_name(multi_remote_services) -> None:
url, key, services = resolve_server("b")
assert url == "https://b:8080"
assert key == "kb"
assert services["_server_cdm"] is False
def test_resolve_proxy_none_returns_none() -> None:
assert _resolve_proxy(None) is None
assert _resolve_proxy("") is None
def test_resolve_proxy_passes_through_value(monkeypatch: pytest.MonkeyPatch) -> None:
import unshackle.core.proxies.resolve as resolve_mod
monkeypatch.setattr(resolve_mod, "initialize_proxy_providers", lambda: [])
monkeypatch.setattr(resolve_mod, "resolve_proxy", lambda arg, providers: f"http://proxy/{arg}")
assert _resolve_proxy("us") == "http://proxy/us"
def test_resolve_proxy_value_error_becomes_click(monkeypatch: pytest.MonkeyPatch) -> None:
import unshackle.core.proxies.resolve as resolve_mod
monkeypatch.setattr(resolve_mod, "initialize_proxy_providers", lambda: [])
def boom(*_):
raise ValueError("no such country")
monkeypatch.setattr(resolve_mod, "resolve_proxy", boom)
with pytest.raises(click.ClickException) as exc:
_resolve_proxy("xx")
assert "no such country" in str(exc.value.message)

View File

@@ -0,0 +1,157 @@
"""Unit tests for module-level helpers in unshackle.core.remote_service."""
from __future__ import annotations
from enum import Enum
import pytest
from unshackle.core.remote_service import (
_build_title,
_build_tracks,
_deserialize_audio,
_deserialize_subtitle,
_deserialize_video,
_enum_get,
_match_track,
_reconstruct_drm,
)
from unshackle.core.titles.episode import Episode
from unshackle.core.titles.movie import Movie
from unshackle.core.tracks import Audio, Subtitle, Video
pytestmark = pytest.mark.unit
class _Color(Enum):
RED = 1
BLUE = 2
def test_enum_get_known() -> None:
assert _enum_get(_Color, "RED") is _Color.RED
def test_enum_get_unknown_returns_default() -> None:
assert _enum_get(_Color, "PURPLE", default=_Color.BLUE) is _Color.BLUE
def test_enum_get_none_returns_default() -> None:
assert _enum_get(_Color, None, default=_Color.RED) is _Color.RED
def test_deserialize_video_minimal() -> None:
v = _deserialize_video({"id": "video-1", "codec": "AVC", "width": 1920, "height": 1080, "bitrate": 5000})
assert isinstance(v, Video)
assert v.id == "video-1"
assert v.codec is Video.Codec.AVC
assert v.bitrate == 5_000_000 # kbps -> bps
assert v.width == 1920
assert v.height == 1080
assert v.range is Video.Range.SDR
def test_deserialize_video_unknown_codec_falls_back_to_none() -> None:
v = _deserialize_video({"id": "v2", "codec": "MADE_UP", "width": 0, "height": 0})
assert v.codec is None
def test_deserialize_audio_atmos_flag_sets_joc() -> None:
a = _deserialize_audio({"id": "a1", "codec": "AAC", "atmos": True, "channels": 6, "bitrate": 256})
assert isinstance(a, Audio)
assert a.joc == 1
assert a.channels == 6
assert a.bitrate == 256_000
def test_deserialize_audio_no_atmos() -> None:
a = _deserialize_audio({"id": "a2", "codec": "AAC", "channels": 2})
assert a.joc == 0
def test_deserialize_subtitle_forced_flag() -> None:
s = _deserialize_subtitle({"id": "s1", "codec": "WebVTT", "language": "en", "forced": True})
assert isinstance(s, Subtitle)
assert s.forced is True
assert s.sdh is False
def test_deserialize_subtitle_sdh_flag() -> None:
s = _deserialize_subtitle({"id": "s2", "codec": "WebVTT", "language": "en", "sdh": True})
assert s.sdh is True
assert s.forced is False
def test_reconstruct_drm_empty() -> None:
assert _reconstruct_drm(None) == []
assert _reconstruct_drm([]) == []
def test_reconstruct_drm_skips_entries_without_pssh() -> None:
assert _reconstruct_drm([{"type": "widevine"}]) == []
def test_reconstruct_drm_invalid_pssh_silently_dropped() -> None:
assert _reconstruct_drm([{"type": "widevine", "pssh": "not-real-pssh"}]) == []
def test_build_tracks_aggregates() -> None:
data = {
"video": [{"id": "v", "codec": "AVC", "width": 1280, "height": 720, "bitrate": 2500}],
"audio": [{"id": "a", "codec": "AAC", "channels": 2, "bitrate": 128}],
"subtitles": [{"id": "s", "codec": "WebVTT", "language": "en"}],
"attachments": [],
}
t = _build_tracks(data)
assert len(t.videos) == 1
assert len(t.audio) == 1
assert len(t.subtitles) == 1
def test_match_track_by_id() -> None:
a = _deserialize_video({"id": "v1", "codec": "AVC", "width": 1920, "height": 1080})
b = _deserialize_video({"id": "v2", "codec": "AVC", "width": 1280, "height": 720})
remote = _deserialize_video({"id": "v2", "codec": "AVC", "width": 1280, "height": 720})
assert _match_track(remote, [a, b]) is b
def test_match_track_by_attributes_when_id_missing() -> None:
local = _deserialize_video({"id": "X", "codec": "AVC", "width": 1920, "height": 1080, "language": "en"})
remote = _deserialize_video({"id": "Y", "codec": "AVC", "width": 1920, "height": 1080, "language": "en"})
assert _match_track(remote, [local]) is local
def test_match_track_no_candidates_returns_none() -> None:
remote = _deserialize_video({"id": "X", "codec": "AVC", "width": 1, "height": 1})
assert _match_track(remote, []) is None
def test_build_title_movie() -> None:
info = {"type": "movie", "id": "movie-0001", "name": "Foo", "year": 2024, "language": "en"}
title = _build_title(info, "ATV", "fallback")
assert isinstance(title, Movie)
assert title.id == "movie-0001"
assert title.name == "Foo"
def test_build_title_episode() -> None:
info = {
"type": "episode",
"id": "ep-00001",
"series_title": "Show",
"season": 1,
"number": 2,
"name": "Pilot",
"year": 2024,
"language": "en",
}
title = _build_title(info, "ATV", "fallback")
assert isinstance(title, Episode)
assert title.season == 1
assert title.number == 2
assert title.name == "Pilot"
def test_build_title_falls_back_to_id_when_missing() -> None:
title = _build_title({"type": "movie", "name": "x"}, "ATV", "fallback-id")
assert title.id == "fallback-id"

View File

@@ -0,0 +1,159 @@
"""Unit tests for unshackle.core.api.routes.setup_routes wiring + CORS + auth gating.
We build small aiohttp apps in-test with setup_routes(), mirroring what
unshackle/commands/serve.py does. We avoid hitting the real handlers by
stubbing the route table for selected paths.
"""
from __future__ import annotations
import pytest
from aiohttp import web
from unshackle.core.api.compression import compression_middleware
from unshackle.core.api.routes import cors_middleware, setup_routes
pytestmark = pytest.mark.unit
@pytest.fixture
def make_app():
"""Factory that builds an aiohttp app for tests."""
def _factory(remote_only: bool = False, with_auth_middleware: bool = False):
middlewares = [cors_middleware, compression_middleware]
if with_auth_middleware:
middlewares.insert(1, _no_key_required_auth())
app = web.Application(middlewares=middlewares)
app["config"] = {"users": {}}
app["debug_api"] = False
setup_routes(app, remote_only=remote_only)
return app
return _factory
def _no_key_required_auth():
"""Mirror serve.py's api_key_authentication middleware: required X-Secret-Key
on every endpoint except /api/health."""
@web.middleware
async def mw(request, handler):
if request.path == "/api/health":
return await handler(request)
secret = request.headers.get("X-Secret-Key")
if not secret:
return web.json_response({"status": 401, "message": "Secret Key is Empty."}, status=401)
if secret not in request.app["config"]["users"]:
return web.json_response({"status": 401, "message": "Secret Key is Invalid."}, status=401)
return await handler(request)
return mw
def _collect_paths(app: web.Application) -> list[tuple[str, str]]:
return sorted({(r.method, r.resource.canonical) for r in app.router.routes()})
def test_setup_routes_full_mode_wires_all_endpoints(make_app) -> None:
app = make_app(remote_only=False)
paths = _collect_paths(app)
expected = {
("GET", "/api/health"),
("GET", "/api/services"),
("POST", "/api/search"),
("POST", "/api/list-titles"),
("POST", "/api/list-tracks"),
("POST", "/api/download"),
("GET", "/api/download/jobs"),
("GET", "/api/download/jobs/{job_id}"),
("DELETE", "/api/download/jobs/{job_id}"),
("POST", "/api/session/create"),
("GET", "/api/session/{session_id}/titles"),
("POST", "/api/session/{session_id}/tracks"),
("POST", "/api/session/{session_id}/segments"),
("POST", "/api/session/{session_id}/license"),
("GET", "/api/session/{session_id}/prompt"),
("POST", "/api/session/{session_id}/prompt"),
("GET", "/api/session/{session_id}"),
("DELETE", "/api/session/{session_id}"),
}
assert expected.issubset(set(paths))
def test_setup_routes_remote_only_excludes_list_and_download(make_app) -> None:
app = make_app(remote_only=True)
paths = set(_collect_paths(app))
assert ("POST", "/api/list-titles") not in paths
assert ("POST", "/api/list-tracks") not in paths
assert ("POST", "/api/download") not in paths
assert ("GET", "/api/download/jobs") not in paths
# session endpoints still present
assert ("POST", "/api/session/create") in paths
assert ("GET", "/api/session/{session_id}/titles") in paths
assert ("POST", "/api/session/{session_id}/license") in paths
async def test_cors_preflight_returns_headers(make_app, aiohttp_client) -> None:
app = make_app(remote_only=True)
client = await aiohttp_client(app)
resp = await client.options("/api/health")
assert resp.status == 200
assert resp.headers["Access-Control-Allow-Origin"] == "*"
assert "GET" in resp.headers["Access-Control-Allow-Methods"]
assert "X-Secret-Key" in resp.headers["Access-Control-Allow-Headers"]
async def test_health_endpoint_responds_ok(make_app, aiohttp_client, monkeypatch: pytest.MonkeyPatch) -> None:
from unshackle.core.api import routes as routes_mod
async def _no_update(_):
return None
monkeypatch.setattr(routes_mod.UpdateChecker, "check_for_updates", _no_update)
app = make_app(remote_only=True)
client = await aiohttp_client(app)
resp = await client.get("/api/health")
assert resp.status == 200
body = await resp.json()
assert body["status"] == "ok"
assert "version" in body
async def test_health_bypasses_api_key_auth_middleware(make_app, aiohttp_client, monkeypatch) -> None:
from unshackle.core.api import routes as routes_mod
async def _no_update(_):
return None
monkeypatch.setattr(routes_mod.UpdateChecker, "check_for_updates", _no_update)
app = make_app(remote_only=True, with_auth_middleware=True)
client = await aiohttp_client(app)
resp = await client.get("/api/health")
assert resp.status == 200 # health bypasses auth
async def test_auth_middleware_rejects_missing_key(make_app, aiohttp_client) -> None:
app = make_app(remote_only=True, with_auth_middleware=True)
client = await aiohttp_client(app)
resp = await client.get("/api/session/abc")
assert resp.status == 401
body = await resp.json()
assert "Secret Key" in body["message"]
async def test_auth_middleware_rejects_invalid_key(make_app, aiohttp_client) -> None:
app = make_app(remote_only=True, with_auth_middleware=True)
client = await aiohttp_client(app)
resp = await client.get("/api/session/abc", headers={"X-Secret-Key": "wrong"})
assert resp.status == 401
async def test_auth_middleware_accepts_known_key(make_app, aiohttp_client) -> None:
app = make_app(remote_only=True, with_auth_middleware=True)
app["config"]["users"]["good-key"] = {"devices": []}
client = await aiohttp_client(app)
resp = await client.get("/api/session/nonexistent", headers={"X-Secret-Key": "good-key"})
# Auth passed; handler then 404s the session — anything other than 401 is fine here.
assert resp.status != 401

View File

@@ -0,0 +1,79 @@
"""Unit tests for the `unshackle serve` Click command flag surface."""
from __future__ import annotations
import pytest
from click.testing import CliRunner
from unshackle.commands.serve import serve
pytestmark = pytest.mark.unit
@pytest.fixture
def runner() -> CliRunner:
return CliRunner()
def test_serve_help_lists_documented_flags(runner: CliRunner) -> None:
result = runner.invoke(serve, ["--help"])
assert result.exit_code == 0
out = result.output
for flag in (
"--host",
"--port",
"--caddy",
"--api-only",
"--no-widevine",
"--no-playready",
"--no-key",
"--debug-api",
"--debug",
"--remote-only",
):
assert flag in out, f"missing flag in --help: {flag}"
def test_serve_api_only_with_no_widevine_rejected(runner: CliRunner, monkeypatch: pytest.MonkeyPatch) -> None:
"""`--api-only` is mutually exclusive with `--no-widevine`/`--no-playready`."""
monkeypatch.setenv("UNSHACKLE_NO_RUN", "1") # belt-and-braces; not currently checked
# Stub web.run_app to avoid actually starting the server if validation passes.
from aiohttp import web
monkeypatch.setattr(web, "run_app", lambda *a, **kw: None)
# Force a clean config.serve so no_key path doesn't blow up loading wvds.
from unshackle.core.config import config as cfg
monkeypatch.setattr(cfg, "serve", {"api_secret": "x"})
result = runner.invoke(serve, ["--api-only", "--no-widevine", "--no-key"])
assert result.exit_code != 0
assert "Cannot use --api-only" in (result.output or str(result.exception))
def test_serve_no_key_without_api_secret_does_not_require_secret(
runner: CliRunner, monkeypatch: pytest.MonkeyPatch
) -> None:
"""With --no-key, the missing api_secret check is bypassed."""
from aiohttp import web
monkeypatch.setattr(web, "run_app", lambda *a, **kw: None)
from unshackle.core.config import config as cfg
monkeypatch.setattr(cfg, "serve", {})
result = runner.invoke(serve, ["--api-only", "--no-key", "--remote-only"])
# No exception should escape, exit code 0 means startup proceeded then run_app stub returned.
assert result.exit_code == 0, result.output
def test_serve_without_no_key_requires_api_secret(runner: CliRunner, monkeypatch: pytest.MonkeyPatch) -> None:
from unshackle.core.config import config as cfg
monkeypatch.setattr(cfg, "serve", {}) # no api_secret configured
result = runner.invoke(serve, ["--api-only"])
assert result.exit_code != 0
assert "api_secret" in (result.output or "").lower() or "api_secret" in str(result.exception).lower()

View File

@@ -0,0 +1,117 @@
"""Unit tests for unshackle.core.api.session_store.SessionStore."""
from __future__ import annotations
import asyncio
import pytest
from unshackle.core.api.input_bridge import AuthStatus, InputBridge
from unshackle.core.api.session_store import SessionEntry, SessionStore, get_session_store
pytestmark = pytest.mark.unit
class _FakeService:
"""Minimal stub Service used to fill SessionEntry.service_instance."""
def __init__(self, tag: str = "TEST") -> None:
self.tag = tag
@pytest.fixture
def store() -> SessionStore:
return SessionStore()
async def test_create_returns_entry_with_uuid(store: SessionStore) -> None:
entry = await store.create("ATV", _FakeService())
assert isinstance(entry, SessionEntry)
assert entry.service_tag == "ATV"
assert entry.session_id and len(entry.session_id) >= 32
assert store.session_count == 1
async def test_create_with_explicit_session_id(store: SessionStore) -> None:
entry = await store.create("NF", _FakeService(), session_id="fixed-id")
assert entry.session_id == "fixed-id"
async def test_get_returns_none_for_missing(store: SessionStore) -> None:
assert await store.get("nope") is None
async def test_get_touches_last_accessed(store: SessionStore) -> None:
entry = await store.create("DSNP", _FakeService())
before = entry.last_accessed
await asyncio.sleep(0.01)
fetched = await store.get(entry.session_id)
assert fetched is entry
assert fetched.last_accessed > before
async def test_delete_removes_and_cancels_bridge(store: SessionStore) -> None:
entry = await store.create("CRAV", _FakeService())
entry.input_bridge = InputBridge()
assert entry.input_bridge.status is AuthStatus.AUTHENTICATING
deleted = await store.delete(entry.session_id)
assert deleted is True
assert entry.input_bridge.status is AuthStatus.FAILED # cancelled
assert store.session_count == 0
async def test_delete_returns_false_when_missing(store: SessionStore) -> None:
assert await store.delete("missing") is False
async def test_cleanup_expired_drops_old_authenticated(store: SessionStore, monkeypatch: pytest.MonkeyPatch) -> None:
from datetime import datetime, timedelta, timezone
entry = await store.create("ATV", _FakeService())
entry.last_accessed = datetime.now(timezone.utc) - timedelta(seconds=store._ttl + 100)
removed = await store.cleanup_expired()
assert removed == 1
assert store.session_count == 0
async def test_cleanup_expired_keeps_pending_input_under_grace(store: SessionStore) -> None:
"""Sessions awaiting user input get a longer grace period (10 min) than authenticated TTL."""
entry = await store.create("ATV", _FakeService())
entry.input_bridge = InputBridge()
entry.auth_status = AuthStatus.PENDING_INPUT
removed = await store.cleanup_expired()
assert removed == 0
assert store.session_count == 1
async def test_cancel_all_bridges(store: SessionStore) -> None:
a = await store.create("ATV", _FakeService())
b = await store.create("NF", _FakeService())
a.input_bridge = InputBridge()
b.input_bridge = InputBridge()
await store.cancel_all_bridges()
assert a.input_bridge.status is AuthStatus.FAILED
assert b.input_bridge.status is AuthStatus.FAILED
async def test_get_session_store_returns_singleton() -> None:
a = get_session_store()
b = get_session_store()
assert a is b
async def test_max_sessions_evicts_oldest(store: SessionStore, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(type(store), "_max_sessions", property(lambda _: 2))
a = await store.create("A", _FakeService(), session_id="a")
await asyncio.sleep(0.01)
b = await store.create("B", _FakeService(), session_id="b")
await asyncio.sleep(0.01)
c = await store.create("C", _FakeService(), session_id="c")
assert store.session_count == 2
assert await store.get("a") is None # evicted
assert (await store.get("b")) is b
assert (await store.get("c")) is c