mirror of
https://github.com/unshackle-dl/unshackle.git
synced 2026-06-10 11:12:13 +00:00
test(remote): add unit + e2e suite for remote-services subsystem
Covers RemoteClient/RemoteService, REST routes, handlers, SessionStore, InputBridge, DownloadQueueManager, errors, compression, and serve CLI. E2e tier opts in via --live and can auto-spawn its own serve.
This commit is contained in:
0
tests/remote/unit/__init__.py
Normal file
0
tests/remote/unit/__init__.py
Normal file
65
tests/remote/unit/test_compression.py
Normal file
65
tests/remote/unit/test_compression.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""Unit tests for unshackle.core.api.compression.compression_middleware."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import gzip
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from unshackle.core.api.compression import compression_middleware
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _FakeReq:
|
||||
def __init__(self, accept_encoding: str = "gzip") -> None:
|
||||
self.headers = {"Accept-Encoding": accept_encoding}
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.get_event_loop().run_until_complete(coro) if False else asyncio.run(coro)
|
||||
|
||||
|
||||
def test_skips_when_client_does_not_accept_gzip() -> None:
|
||||
payload = b"x" * 4096
|
||||
body_json = json.dumps({"data": "x" * 4096}).encode()
|
||||
|
||||
async def handler(req): # noqa: ARG001
|
||||
return web.json_response({"data": "x" * 4096})
|
||||
|
||||
req = _FakeReq(accept_encoding="identity")
|
||||
resp = _run(compression_middleware(req, handler))
|
||||
assert resp.headers.get("Content-Encoding") != "gzip"
|
||||
assert resp.body == body_json or len(resp.body) >= len(body_json) - 8
|
||||
|
||||
|
||||
def test_skips_when_body_below_threshold() -> None:
|
||||
async def handler(req): # noqa: ARG001
|
||||
return web.json_response({"hi": "x"})
|
||||
|
||||
resp = _run(compression_middleware(_FakeReq(), handler))
|
||||
assert resp.headers.get("Content-Encoding") != "gzip"
|
||||
|
||||
|
||||
def test_skips_non_json_response() -> None:
|
||||
async def handler(req): # noqa: ARG001
|
||||
return web.Response(body=b"x" * 4096, content_type="text/plain")
|
||||
|
||||
resp = _run(compression_middleware(_FakeReq(), handler))
|
||||
assert resp.headers.get("Content-Encoding") != "gzip"
|
||||
|
||||
|
||||
def test_compresses_large_json_when_accepted() -> None:
|
||||
big = {"data": "x" * 4096}
|
||||
|
||||
async def handler(req): # noqa: ARG001
|
||||
return web.json_response(big)
|
||||
|
||||
resp = _run(compression_middleware(_FakeReq(), handler))
|
||||
assert resp.headers.get("Content-Encoding") == "gzip"
|
||||
decompressed = gzip.decompress(resp.body)
|
||||
assert json.loads(decompressed) == big
|
||||
assert resp.headers["Content-Length"] == str(len(resp.body))
|
||||
120
tests/remote/unit/test_download_manager.py
Normal file
120
tests/remote/unit/test_download_manager.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Unit tests for DownloadJob + DownloadQueueManager state machine.
|
||||
|
||||
These tests focus on the queue manager's data layer (create/get/list/cancel/
|
||||
cleanup/serialize) — they do not exercise the actual subprocess download path.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from unshackle.core.api.download_manager import DownloadJob, DownloadQueueManager, JobStatus, get_download_manager
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manager() -> DownloadQueueManager:
|
||||
"""Fresh manager. We never call start_workers() so no async tasks are created."""
|
||||
return DownloadQueueManager(max_concurrent_downloads=2, job_retention_hours=24)
|
||||
|
||||
|
||||
def test_create_job_returns_queued_job(manager: DownloadQueueManager) -> None:
|
||||
job = manager.create_job("ATV", "movie-123", profile="default")
|
||||
assert isinstance(job, DownloadJob)
|
||||
assert job.status is JobStatus.QUEUED
|
||||
assert job.service == "ATV"
|
||||
assert job.title_id == "movie-123"
|
||||
assert job.parameters == {"profile": "default"}
|
||||
|
||||
|
||||
def test_get_and_list_jobs(manager: DownloadQueueManager) -> None:
|
||||
a = manager.create_job("ATV", "a")
|
||||
b = manager.create_job("NF", "b")
|
||||
assert manager.get_job(a.job_id) is a
|
||||
assert manager.get_job("missing") is None
|
||||
listed = manager.list_jobs()
|
||||
assert {j.job_id for j in listed} == {a.job_id, b.job_id}
|
||||
|
||||
|
||||
def test_to_dict_short_vs_full(manager: DownloadQueueManager) -> None:
|
||||
job = manager.create_job("ATV", "t", profile="p")
|
||||
short = job.to_dict()
|
||||
assert "parameters" not in short
|
||||
assert short["status"] == "queued"
|
||||
assert short["service"] == "ATV"
|
||||
full = job.to_dict(include_full_details=True)
|
||||
assert full["parameters"] == {"profile": "p"}
|
||||
assert "error_message" in full
|
||||
assert "output_files" in full
|
||||
|
||||
|
||||
def test_cancel_queued_job_sets_cancelled_and_signals_event(manager: DownloadQueueManager) -> None:
|
||||
job = manager.create_job("ATV", "t")
|
||||
assert manager.cancel_job(job.job_id) is True
|
||||
assert job.status is JobStatus.CANCELLED
|
||||
assert job.cancel_event.is_set()
|
||||
|
||||
|
||||
def test_cancel_unknown_job_returns_false(manager: DownloadQueueManager) -> None:
|
||||
assert manager.cancel_job("never-existed") is False
|
||||
|
||||
|
||||
def test_cancel_completed_job_returns_false(manager: DownloadQueueManager) -> None:
|
||||
job = manager.create_job("ATV", "t")
|
||||
job.status = JobStatus.COMPLETED
|
||||
assert manager.cancel_job(job.job_id) is False
|
||||
|
||||
|
||||
def test_cancel_downloading_job_signals(manager: DownloadQueueManager) -> None:
|
||||
job = manager.create_job("ATV", "t")
|
||||
job.status = JobStatus.DOWNLOADING
|
||||
assert manager.cancel_job(job.job_id) is True
|
||||
assert job.status is JobStatus.CANCELLED
|
||||
assert job.cancel_event.is_set()
|
||||
|
||||
|
||||
def test_cleanup_old_jobs_drops_old_terminal_states(manager: DownloadQueueManager) -> None:
|
||||
now = datetime.now()
|
||||
old = now - timedelta(hours=48)
|
||||
keep_recent = manager.create_job("ATV", "recent")
|
||||
drop_old_done = manager.create_job("ATV", "old-done")
|
||||
drop_old_failed = manager.create_job("ATV", "old-failed")
|
||||
keep_running = manager.create_job("ATV", "running")
|
||||
|
||||
keep_recent.status = JobStatus.COMPLETED
|
||||
keep_recent.completed_time = now
|
||||
|
||||
drop_old_done.status = JobStatus.COMPLETED
|
||||
drop_old_done.completed_time = old
|
||||
|
||||
drop_old_failed.status = JobStatus.FAILED
|
||||
drop_old_failed.created_time = old # never set completed_time
|
||||
|
||||
keep_running.status = JobStatus.DOWNLOADING
|
||||
|
||||
removed = manager.cleanup_old_jobs()
|
||||
assert removed == 2
|
||||
remaining = {j.job_id for j in manager.list_jobs()}
|
||||
assert keep_recent.job_id in remaining
|
||||
assert keep_running.job_id in remaining
|
||||
assert drop_old_done.job_id not in remaining
|
||||
assert drop_old_failed.job_id not in remaining
|
||||
|
||||
|
||||
def test_get_download_manager_returns_singleton() -> None:
|
||||
a = get_download_manager()
|
||||
b = get_download_manager()
|
||||
assert a is b
|
||||
|
||||
|
||||
def test_job_status_values() -> None:
|
||||
assert {s.value for s in JobStatus} == {
|
||||
"queued",
|
||||
"downloading",
|
||||
"completed",
|
||||
"failed",
|
||||
"cancelled",
|
||||
}
|
||||
132
tests/remote/unit/test_errors.py
Normal file
132
tests/remote/unit/test_errors.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Unit tests for unshackle.core.api.errors."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from unshackle.core.api.errors import (
|
||||
APIError,
|
||||
APIErrorCode,
|
||||
build_error_response,
|
||||
categorize_exception,
|
||||
handle_api_exception,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _body(resp) -> dict:
|
||||
return json.loads(resp.body.decode("utf-8"))
|
||||
|
||||
|
||||
def test_api_error_default_http_status_per_code() -> None:
|
||||
cases = {
|
||||
APIErrorCode.INVALID_INPUT: 400,
|
||||
APIErrorCode.INVALID_SERVICE: 400,
|
||||
APIErrorCode.AUTH_REQUIRED: 401,
|
||||
APIErrorCode.AUTH_FAILED: 401,
|
||||
APIErrorCode.FORBIDDEN: 403,
|
||||
APIErrorCode.GEOFENCE: 403,
|
||||
APIErrorCode.NOT_FOUND: 404,
|
||||
APIErrorCode.SESSION_NOT_FOUND: 404,
|
||||
APIErrorCode.TRACK_NOT_FOUND: 404,
|
||||
APIErrorCode.RATE_LIMITED: 429,
|
||||
APIErrorCode.INTERNAL_ERROR: 500,
|
||||
APIErrorCode.SERVICE_ERROR: 502,
|
||||
APIErrorCode.DRM_ERROR: 502,
|
||||
APIErrorCode.NETWORK_ERROR: 503,
|
||||
APIErrorCode.SERVICE_UNAVAILABLE: 503,
|
||||
}
|
||||
for code, expected in cases.items():
|
||||
assert APIError(code, "x").http_status == expected, code
|
||||
|
||||
|
||||
def test_api_error_explicit_http_status_overrides_default() -> None:
|
||||
err = APIError(APIErrorCode.INVALID_INPUT, "x", http_status=418)
|
||||
assert err.http_status == 418
|
||||
|
||||
|
||||
def test_build_error_response_from_api_error() -> None:
|
||||
err = APIError(
|
||||
APIErrorCode.SESSION_NOT_FOUND,
|
||||
"no such session",
|
||||
details={"session_id": "abc"},
|
||||
retryable=False,
|
||||
)
|
||||
resp = build_error_response(err)
|
||||
assert resp.status == 404
|
||||
body = _body(resp)
|
||||
assert body["status"] == "error"
|
||||
assert body["error_code"] == "SESSION_NOT_FOUND"
|
||||
assert body["message"] == "no such session"
|
||||
assert body["details"] == {"session_id": "abc"}
|
||||
assert "retryable" not in body
|
||||
assert "debug_info" not in body
|
||||
assert "timestamp" in body
|
||||
|
||||
|
||||
def test_build_error_response_retryable_flag() -> None:
|
||||
err = APIError(APIErrorCode.NETWORK_ERROR, "boom", retryable=True)
|
||||
body = _body(build_error_response(err))
|
||||
assert body["retryable"] is True
|
||||
|
||||
|
||||
def test_build_error_response_from_generic_exception() -> None:
|
||||
resp = build_error_response(RuntimeError("oops"))
|
||||
assert resp.status == 500
|
||||
body = _body(resp)
|
||||
assert body["error_code"] == "INTERNAL_ERROR"
|
||||
assert body["message"] == "oops"
|
||||
|
||||
|
||||
def test_build_error_response_debug_mode_includes_traceback() -> None:
|
||||
try:
|
||||
raise ValueError("kaboom")
|
||||
except ValueError as e:
|
||||
resp = build_error_response(e, debug_mode=True, extra_debug_info={"foo": "bar"})
|
||||
body = _body(resp)
|
||||
assert body["debug_info"]["exception_type"] == "ValueError"
|
||||
assert "traceback" in body["debug_info"]
|
||||
assert body["debug_info"]["foo"] == "bar"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"exc, expected_code",
|
||||
[
|
||||
(Exception("Invalid credentials provided"), APIErrorCode.AUTH_FAILED),
|
||||
(Exception("Connection refused"), APIErrorCode.NETWORK_ERROR),
|
||||
(TimeoutError("read timeout"), APIErrorCode.NETWORK_ERROR),
|
||||
(Exception("Not available in your region"), APIErrorCode.GEOFENCE),
|
||||
(Exception("Title not found"), APIErrorCode.NOT_FOUND),
|
||||
(Exception("HTTP 429 too many requests"), APIErrorCode.RATE_LIMITED),
|
||||
(Exception("DRM license fetch failed"), APIErrorCode.DRM_ERROR),
|
||||
(Exception("503 service unavailable"), APIErrorCode.SERVICE_UNAVAILABLE),
|
||||
(ValueError("malformed body"), APIErrorCode.INVALID_INPUT),
|
||||
(RuntimeError("totally novel failure xyz"), APIErrorCode.INTERNAL_ERROR),
|
||||
],
|
||||
)
|
||||
def test_categorize_exception(exc: Exception, expected_code: APIErrorCode) -> None:
|
||||
api_err = categorize_exception(exc, context={"service": "ATV"})
|
||||
assert api_err.error_code == expected_code
|
||||
assert api_err.details.get("service") == "ATV"
|
||||
|
||||
|
||||
def test_categorize_preserves_context() -> None:
|
||||
api_err = categorize_exception(ValueError("bad"), context={"op": "search"})
|
||||
assert api_err.details["op"] == "search"
|
||||
|
||||
|
||||
def test_handle_api_exception_with_api_error_preserves_code() -> None:
|
||||
err = APIError(APIErrorCode.TRACK_NOT_FOUND, "no track")
|
||||
resp = handle_api_exception(err)
|
||||
body = _body(resp)
|
||||
assert body["error_code"] == "TRACK_NOT_FOUND"
|
||||
assert resp.status == 404
|
||||
|
||||
|
||||
def test_handle_api_exception_categorizes_generic() -> None:
|
||||
resp = handle_api_exception(ConnectionError("oops"))
|
||||
body = _body(resp)
|
||||
assert body["error_code"] == "NETWORK_ERROR"
|
||||
273
tests/remote/unit/test_handlers_serialize.py
Normal file
273
tests/remote/unit/test_handlers_serialize.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""Unit tests for unshackle.core.api.handlers serializers + validators."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from langcodes import Language
|
||||
|
||||
from unshackle.core.api.handlers import (
|
||||
sanitize_log,
|
||||
serialize_audio_track,
|
||||
serialize_drm,
|
||||
serialize_subtitle_track,
|
||||
serialize_title,
|
||||
serialize_video_track,
|
||||
validate_download_parameters,
|
||||
validate_service,
|
||||
)
|
||||
from unshackle.core.titles.episode import Episode
|
||||
from unshackle.core.titles.movie import Movie
|
||||
from unshackle.core.tracks import Audio, Subtitle, Video
|
||||
from unshackle.core.tracks.track import Track
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _FakeSvc:
|
||||
pass
|
||||
|
||||
|
||||
def _video(**overrides) -> Video:
|
||||
base = dict(
|
||||
url="https://example.com/v.mpd",
|
||||
language=Language.get("en"),
|
||||
descriptor=Track.Descriptor.URL,
|
||||
codec=Video.Codec.AVC,
|
||||
range_=Video.Range.SDR,
|
||||
bitrate=5_000_000,
|
||||
width=1920,
|
||||
height=1080,
|
||||
fps=24,
|
||||
id_="video-001",
|
||||
)
|
||||
base.update(overrides)
|
||||
return Video(**base)
|
||||
|
||||
|
||||
def _audio(**overrides) -> Audio:
|
||||
base = dict(
|
||||
url="https://example.com/a.mpd",
|
||||
language=Language.get("en"),
|
||||
descriptor=Track.Descriptor.URL,
|
||||
codec=Audio.Codec.AAC,
|
||||
bitrate=128_000,
|
||||
channels=2,
|
||||
joc=0,
|
||||
descriptive=False,
|
||||
id_="audio-001",
|
||||
)
|
||||
base.update(overrides)
|
||||
return Audio(**base)
|
||||
|
||||
|
||||
def _subtitle(**overrides) -> Subtitle:
|
||||
base = dict(
|
||||
url="https://example.com/s.vtt",
|
||||
language=Language.get("en"),
|
||||
descriptor=Track.Descriptor.URL,
|
||||
codec=Subtitle.Codec.WebVTT,
|
||||
cc=False,
|
||||
sdh=False,
|
||||
forced=False,
|
||||
id_="sub-001",
|
||||
)
|
||||
base.update(overrides)
|
||||
return Subtitle(**base)
|
||||
|
||||
|
||||
# ---------- sanitize_log ----------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"raw, expected",
|
||||
[
|
||||
("hello\nworld", "helloworld"),
|
||||
("a\r\nb\x00c", "abc"),
|
||||
("clean", "clean"),
|
||||
(12345, "12345"),
|
||||
],
|
||||
)
|
||||
def test_sanitize_log(raw, expected: str) -> None:
|
||||
assert sanitize_log(raw) == expected
|
||||
|
||||
|
||||
# ---------- serialize_title ----------
|
||||
|
||||
|
||||
def test_serialize_title_movie() -> None:
|
||||
movie = Movie(id_="movie-0001", service=_FakeSvc, name="Title X", year=2024, language=Language.get("en"))
|
||||
d = serialize_title(movie)
|
||||
assert d["type"] == "movie"
|
||||
assert d["name"] == "Title X"
|
||||
assert d["year"] == 2024
|
||||
assert d["id"] == "movie-0001"
|
||||
assert d["language"] == "en"
|
||||
|
||||
|
||||
def test_serialize_title_episode_named() -> None:
|
||||
ep = Episode(
|
||||
id_="ep-00001",
|
||||
service=_FakeSvc,
|
||||
title="My Show",
|
||||
season=2,
|
||||
number=3,
|
||||
name="Pilot",
|
||||
year=2024,
|
||||
language=Language.get("en"),
|
||||
)
|
||||
d = serialize_title(ep)
|
||||
assert d["type"] == "episode"
|
||||
assert d["series_title"] == "My Show"
|
||||
assert d["season"] == 2
|
||||
assert d["number"] == 3
|
||||
assert d["name"] == "Pilot"
|
||||
|
||||
|
||||
def test_serialize_title_episode_unnamed_falls_back_to_number() -> None:
|
||||
ep = Episode(
|
||||
id_="ep-00002",
|
||||
service=_FakeSvc,
|
||||
title="Show",
|
||||
season=1,
|
||||
number=5,
|
||||
name=None,
|
||||
year=None,
|
||||
language=Language.get("en"),
|
||||
)
|
||||
d = serialize_title(ep)
|
||||
assert d["name"] == "Episode 05"
|
||||
|
||||
|
||||
# ---------- serialize_video/audio/subtitle ----------
|
||||
|
||||
|
||||
def test_serialize_video_track_basic() -> None:
|
||||
d = serialize_video_track(_video())
|
||||
assert d["id"] == "video-001"
|
||||
assert d["codec"] == "AVC"
|
||||
assert d["bitrate"] == 5000 # kbps
|
||||
assert d["resolution"] == "1920x1080"
|
||||
assert d["fps"] == 24
|
||||
assert d["range"] == "SDR"
|
||||
assert d["language"] == "en"
|
||||
assert d["drm"] is None
|
||||
assert "url" not in d
|
||||
|
||||
|
||||
def test_serialize_video_track_include_url() -> None:
|
||||
d = serialize_video_track(_video(), include_url=True)
|
||||
assert d["url"] == "https://example.com/v.mpd"
|
||||
|
||||
|
||||
def test_serialize_audio_track_basic() -> None:
|
||||
d = serialize_audio_track(_audio())
|
||||
assert d["id"] == "audio-001"
|
||||
assert d["codec"] == "AAC"
|
||||
assert d["bitrate"] == 128
|
||||
assert d["channels"] == 2
|
||||
assert d["descriptive"] is False
|
||||
|
||||
|
||||
def test_serialize_subtitle_track_basic() -> None:
|
||||
d = serialize_subtitle_track(_subtitle(forced=True))
|
||||
assert d["id"] == "sub-001"
|
||||
assert d["codec"] == "WebVTT"
|
||||
assert d["forced"] is True
|
||||
assert d["sdh"] is False
|
||||
assert d["cc"] is False
|
||||
|
||||
|
||||
# ---------- serialize_drm ----------
|
||||
|
||||
|
||||
def test_serialize_drm_none_returns_none() -> None:
|
||||
assert serialize_drm(None) is None
|
||||
assert serialize_drm([]) is None
|
||||
|
||||
|
||||
def test_serialize_drm_widevine_minimal() -> None:
|
||||
class _PSSH:
|
||||
def dumps(self) -> str:
|
||||
return "BASE64PSSH=="
|
||||
|
||||
class _Widevine:
|
||||
def __init__(self) -> None:
|
||||
self._pssh = _PSSH()
|
||||
self.kids = ["00112233445566778899aabbccddeeff"]
|
||||
self.license_url = "https://lic.example.com/wv"
|
||||
|
||||
out = serialize_drm(_Widevine())
|
||||
assert isinstance(out, list)
|
||||
assert len(out) == 1
|
||||
info = out[0]
|
||||
assert info["type"] == "_widevine" # class name lowercased
|
||||
assert info["pssh"] == "BASE64PSSH=="
|
||||
assert info["kids"] == ["00112233445566778899aabbccddeeff"]
|
||||
assert info["license_url"] == "https://lic.example.com/wv"
|
||||
|
||||
|
||||
# ---------- validate_service ----------
|
||||
|
||||
|
||||
def test_validate_service_unknown_returns_none() -> None:
|
||||
assert validate_service("NOPE_THIS_IS_NOT_REAL_") is None
|
||||
|
||||
|
||||
# ---------- validate_download_parameters ----------
|
||||
|
||||
|
||||
def test_validate_download_params_accepts_defaults() -> None:
|
||||
assert validate_download_parameters({}) is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"data, fragment",
|
||||
[
|
||||
({"vcodec": "WUT"}, "Invalid vcodec"),
|
||||
({"vcodec": 123}, "vcodec must be a string or list"),
|
||||
({"acodec": "MP9"}, "Invalid acodec"),
|
||||
({"sub_format": "doc"}, "Invalid sub_format"),
|
||||
({"vbitrate": -1}, "vbitrate"),
|
||||
({"abitrate": "no"}, "abitrate"),
|
||||
({"vbitrate_range": "no-dash-but-letters"}, None),
|
||||
({"vbitrate_range": "nope"}, "MIN-MAX"),
|
||||
({"channels": -3}, "channels"),
|
||||
({"workers": 0}, "workers"),
|
||||
({"downloads": 0}, "downloads"),
|
||||
({"video_only": True, "audio_only": True}, "exclusive"),
|
||||
({"no_subs": True, "subs_only": True}, "no_subs and subs_only"),
|
||||
({"no_audio": True, "audio_only": True}, "no_audio and audio_only"),
|
||||
({"s_lang": ["en"], "require_subs": ["en"]}, "s_lang and require_subs"),
|
||||
({"range": "UHD"}, "Invalid range"),
|
||||
({"range": ["SDR", "UHD"]}, "Invalid range value"),
|
||||
],
|
||||
)
|
||||
def test_validate_download_params_errors(data: dict, fragment) -> None:
|
||||
result = validate_download_parameters(data)
|
||||
if fragment is None:
|
||||
# A dash-containing string is valid syntactically per current rule
|
||||
assert result is None
|
||||
else:
|
||||
assert result is not None
|
||||
assert fragment in result
|
||||
|
||||
|
||||
def test_validate_download_params_accepts_valid_values() -> None:
|
||||
assert (
|
||||
validate_download_parameters(
|
||||
{
|
||||
"vcodec": "H264,H265",
|
||||
"acodec": ["AAC", "EAC3"],
|
||||
"sub_format": "VTT",
|
||||
"vbitrate": 6000,
|
||||
"abitrate": 128,
|
||||
"vbitrate_range": "6000-7000",
|
||||
"abitrate_range": "96-192",
|
||||
"channels": 5.1,
|
||||
"workers": 8,
|
||||
"downloads": 2,
|
||||
"range": ["SDR", "HDR10"],
|
||||
}
|
||||
)
|
||||
is None
|
||||
)
|
||||
100
tests/remote/unit/test_input_bridge.py
Normal file
100
tests/remote/unit/test_input_bridge.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""Unit tests for unshackle.core.api.input_bridge.InputBridge."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from unshackle.core.api.input_bridge import AuthStatus, BridgeCancelledError, InputBridge
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_initial_status_is_authenticating() -> None:
|
||||
bridge = InputBridge()
|
||||
assert bridge.status is AuthStatus.AUTHENTICATING
|
||||
assert bridge.get_pending_prompt() is None
|
||||
assert bridge.error is None
|
||||
|
||||
|
||||
def test_submit_response_returns_false_when_no_prompt_pending() -> None:
|
||||
bridge = InputBridge()
|
||||
assert bridge.submit_response("foo") is False
|
||||
|
||||
|
||||
def test_request_input_blocks_until_submit() -> None:
|
||||
bridge = InputBridge()
|
||||
result: list[str] = []
|
||||
|
||||
def worker() -> None:
|
||||
result.append(bridge.request_input("OTP?", timeout=5))
|
||||
|
||||
t = threading.Thread(target=worker)
|
||||
t.start()
|
||||
for _ in range(50):
|
||||
if bridge.get_pending_prompt() == "OTP?":
|
||||
break
|
||||
time.sleep(0.02)
|
||||
assert bridge.status is AuthStatus.PENDING_INPUT
|
||||
assert bridge.get_pending_prompt() == "OTP?"
|
||||
|
||||
assert bridge.submit_response("123456") is True
|
||||
t.join(timeout=2)
|
||||
assert result == ["123456"]
|
||||
assert bridge.status is AuthStatus.AUTHENTICATING
|
||||
assert bridge.get_pending_prompt() is None
|
||||
|
||||
|
||||
def test_request_input_times_out() -> None:
|
||||
bridge = InputBridge()
|
||||
with pytest.raises(TimeoutError):
|
||||
bridge.request_input("hello", timeout=0.6)
|
||||
assert bridge.status is AuthStatus.FAILED
|
||||
assert "timed out" in (bridge.error or "")
|
||||
|
||||
|
||||
def test_cancel_before_request_raises_immediately() -> None:
|
||||
bridge = InputBridge()
|
||||
bridge.cancel()
|
||||
with pytest.raises(BridgeCancelledError):
|
||||
bridge.request_input("hello", timeout=5)
|
||||
assert bridge.status is AuthStatus.FAILED
|
||||
|
||||
|
||||
def test_cancel_unblocks_pending_request() -> None:
|
||||
bridge = InputBridge()
|
||||
exc: list[Exception] = []
|
||||
|
||||
def worker() -> None:
|
||||
try:
|
||||
bridge.request_input("OTP?", timeout=5)
|
||||
except BridgeCancelledError as e:
|
||||
exc.append(e)
|
||||
|
||||
t = threading.Thread(target=worker)
|
||||
t.start()
|
||||
for _ in range(50):
|
||||
if bridge.status is AuthStatus.PENDING_INPUT:
|
||||
break
|
||||
time.sleep(0.02)
|
||||
|
||||
bridge.cancel()
|
||||
t.join(timeout=2)
|
||||
assert exc and isinstance(exc[0], BridgeCancelledError)
|
||||
assert bridge.status is AuthStatus.FAILED
|
||||
|
||||
|
||||
def test_get_pending_prompt_returns_none_outside_pending_state() -> None:
|
||||
bridge = InputBridge()
|
||||
bridge.status = AuthStatus.AUTHENTICATED
|
||||
assert bridge.get_pending_prompt() is None
|
||||
|
||||
|
||||
def test_status_and_error_setters() -> None:
|
||||
bridge = InputBridge()
|
||||
bridge.status = AuthStatus.AUTHENTICATED
|
||||
bridge.error = "boom"
|
||||
assert bridge.status is AuthStatus.AUTHENTICATED
|
||||
assert bridge.error == "boom"
|
||||
111
tests/remote/unit/test_remote_client.py
Normal file
111
tests/remote/unit/test_remote_client.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""Unit tests for unshackle.core.remote_service.RemoteClient."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
import responses
|
||||
|
||||
from unshackle.core.remote_service import RemoteClient
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client() -> RemoteClient:
|
||||
return RemoteClient(server_url="http://srv:8786", api_key="secret-xyz")
|
||||
|
||||
|
||||
def test_session_sets_secret_key_header(client: RemoteClient) -> None:
|
||||
s = client.session
|
||||
assert s.headers.get("X-Secret-Key") == "secret-xyz"
|
||||
assert s.headers["User-Agent"].startswith("unshackle/")
|
||||
|
||||
|
||||
def test_session_omits_secret_key_when_empty() -> None:
|
||||
c = RemoteClient(server_url="http://srv:8786", api_key="")
|
||||
assert "X-Secret-Key" not in c.session.headers
|
||||
|
||||
|
||||
def test_server_url_trailing_slash_stripped() -> None:
|
||||
c = RemoteClient(server_url="http://srv:8786/", api_key="")
|
||||
assert c.server_url == "http://srv:8786"
|
||||
|
||||
|
||||
@responses.activate
|
||||
def test_get_returns_json(client: RemoteClient) -> None:
|
||||
responses.add(
|
||||
responses.GET,
|
||||
"http://srv:8786/api/health",
|
||||
json={"status": "ok"},
|
||||
status=200,
|
||||
)
|
||||
assert client.get("/api/health") == {"status": "ok"}
|
||||
|
||||
|
||||
@responses.activate
|
||||
def test_post_sends_json_body(client: RemoteClient) -> None:
|
||||
captured = {}
|
||||
|
||||
def cb(request):
|
||||
captured["body"] = json.loads(request.body)
|
||||
return (200, {}, json.dumps({"session_id": "abc"}))
|
||||
|
||||
responses.add_callback(
|
||||
responses.POST, "http://srv:8786/api/session/create", callback=cb, content_type="application/json"
|
||||
)
|
||||
result = client.post("/api/session/create", {"service": "ATV"})
|
||||
assert result == {"session_id": "abc"}
|
||||
assert captured["body"] == {"service": "ATV"}
|
||||
|
||||
|
||||
@responses.activate
|
||||
def test_delete_returns_json(client: RemoteClient) -> None:
|
||||
responses.add(
|
||||
responses.DELETE,
|
||||
"http://srv:8786/api/session/abc",
|
||||
json={"status": "deleted"},
|
||||
status=200,
|
||||
)
|
||||
assert client.delete("/api/session/abc") == {"status": "deleted"}
|
||||
|
||||
|
||||
@responses.activate
|
||||
def test_4xx_raises_systemexit_with_logged_error(client: RemoteClient, caplog: pytest.LogCaptureFixture) -> None:
|
||||
responses.add(
|
||||
responses.GET,
|
||||
"http://srv:8786/api/session/none",
|
||||
json={"error_code": "SESSION_NOT_FOUND", "message": "no such session"},
|
||||
status=404,
|
||||
)
|
||||
with caplog.at_level("ERROR"), pytest.raises(SystemExit):
|
||||
client.get("/api/session/none")
|
||||
assert "SESSION_NOT_FOUND" in caplog.text
|
||||
assert "no such session" in caplog.text
|
||||
|
||||
|
||||
@responses.activate
|
||||
def test_connection_error_raises_systemexit(client: RemoteClient) -> None:
|
||||
import requests
|
||||
|
||||
responses.add(
|
||||
responses.GET,
|
||||
"http://srv:8786/api/health",
|
||||
body=requests.ConnectionError("boom"),
|
||||
)
|
||||
with pytest.raises(SystemExit):
|
||||
client.get("/api/health")
|
||||
|
||||
|
||||
@responses.activate
|
||||
def test_timeout_raises_systemexit(client: RemoteClient) -> None:
|
||||
import requests
|
||||
|
||||
responses.add(
|
||||
responses.GET,
|
||||
"http://srv:8786/api/health",
|
||||
body=requests.Timeout("slow"),
|
||||
)
|
||||
with pytest.raises(SystemExit):
|
||||
client.get("/api/health")
|
||||
117
tests/remote/unit/test_remote_resolve.py
Normal file
117
tests/remote/unit/test_remote_resolve.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""Unit tests for resolve_server / _resolve_proxy in unshackle.core.remote_service."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import click
|
||||
import pytest
|
||||
|
||||
from unshackle.core.remote_service import _resolve_proxy, resolve_server
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def empty_remote_services(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from unshackle.core import remote_service as rs
|
||||
|
||||
monkeypatch.setattr(rs.config, "remote_services", {})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def single_remote_service(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from unshackle.core import remote_service as rs
|
||||
|
||||
monkeypatch.setattr(
|
||||
rs.config,
|
||||
"remote_services",
|
||||
{
|
||||
"primary": {
|
||||
"url": "https://primary:8080",
|
||||
"api_key": "key-abc",
|
||||
"services": {"ATV": True, "NF": True},
|
||||
"server_cdm": True,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def multi_remote_services(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from unshackle.core import remote_service as rs
|
||||
|
||||
monkeypatch.setattr(
|
||||
rs.config,
|
||||
"remote_services",
|
||||
{
|
||||
"a": {"url": "https://a:8080", "api_key": "ka"},
|
||||
"b": {"url": "https://b:8080", "api_key": "kb"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_server_no_config_raises_click(empty_remote_services) -> None:
|
||||
with pytest.raises(click.ClickException) as exc:
|
||||
resolve_server(None)
|
||||
assert "remote_services" in str(exc.value.message)
|
||||
|
||||
|
||||
def test_resolve_server_single_picks_only_entry(single_remote_service) -> None:
|
||||
url, key, services = resolve_server(None)
|
||||
assert url == "https://primary:8080"
|
||||
assert key == "key-abc"
|
||||
assert services["_server_cdm"] is True
|
||||
assert services.get("ATV") is True
|
||||
|
||||
|
||||
def test_resolve_server_explicit_name(single_remote_service) -> None:
|
||||
url, key, services = resolve_server("primary")
|
||||
assert url == "https://primary:8080"
|
||||
assert services["_server_cdm"] is True
|
||||
|
||||
|
||||
def test_resolve_server_unknown_name_raises(single_remote_service) -> None:
|
||||
with pytest.raises(click.ClickException) as exc:
|
||||
resolve_server("bogus")
|
||||
assert "bogus" in str(exc.value.message)
|
||||
|
||||
|
||||
def test_resolve_server_multi_requires_explicit(multi_remote_services) -> None:
|
||||
with pytest.raises(click.ClickException) as exc:
|
||||
resolve_server(None)
|
||||
assert "--server" in str(exc.value.message)
|
||||
|
||||
|
||||
def test_resolve_server_multi_with_name(multi_remote_services) -> None:
|
||||
url, key, services = resolve_server("b")
|
||||
assert url == "https://b:8080"
|
||||
assert key == "kb"
|
||||
assert services["_server_cdm"] is False
|
||||
|
||||
|
||||
def test_resolve_proxy_none_returns_none() -> None:
|
||||
assert _resolve_proxy(None) is None
|
||||
assert _resolve_proxy("") is None
|
||||
|
||||
|
||||
def test_resolve_proxy_passes_through_value(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
import unshackle.core.proxies.resolve as resolve_mod
|
||||
|
||||
monkeypatch.setattr(resolve_mod, "initialize_proxy_providers", lambda: [])
|
||||
monkeypatch.setattr(resolve_mod, "resolve_proxy", lambda arg, providers: f"http://proxy/{arg}")
|
||||
|
||||
assert _resolve_proxy("us") == "http://proxy/us"
|
||||
|
||||
|
||||
def test_resolve_proxy_value_error_becomes_click(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
import unshackle.core.proxies.resolve as resolve_mod
|
||||
|
||||
monkeypatch.setattr(resolve_mod, "initialize_proxy_providers", lambda: [])
|
||||
|
||||
def boom(*_):
|
||||
raise ValueError("no such country")
|
||||
|
||||
monkeypatch.setattr(resolve_mod, "resolve_proxy", boom)
|
||||
|
||||
with pytest.raises(click.ClickException) as exc:
|
||||
_resolve_proxy("xx")
|
||||
assert "no such country" in str(exc.value.message)
|
||||
157
tests/remote/unit/test_remote_service_helpers.py
Normal file
157
tests/remote/unit/test_remote_service_helpers.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""Unit tests for module-level helpers in unshackle.core.remote_service."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import pytest
|
||||
|
||||
from unshackle.core.remote_service import (
|
||||
_build_title,
|
||||
_build_tracks,
|
||||
_deserialize_audio,
|
||||
_deserialize_subtitle,
|
||||
_deserialize_video,
|
||||
_enum_get,
|
||||
_match_track,
|
||||
_reconstruct_drm,
|
||||
)
|
||||
from unshackle.core.titles.episode import Episode
|
||||
from unshackle.core.titles.movie import Movie
|
||||
from unshackle.core.tracks import Audio, Subtitle, Video
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _Color(Enum):
|
||||
RED = 1
|
||||
BLUE = 2
|
||||
|
||||
|
||||
def test_enum_get_known() -> None:
|
||||
assert _enum_get(_Color, "RED") is _Color.RED
|
||||
|
||||
|
||||
def test_enum_get_unknown_returns_default() -> None:
|
||||
assert _enum_get(_Color, "PURPLE", default=_Color.BLUE) is _Color.BLUE
|
||||
|
||||
|
||||
def test_enum_get_none_returns_default() -> None:
|
||||
assert _enum_get(_Color, None, default=_Color.RED) is _Color.RED
|
||||
|
||||
|
||||
def test_deserialize_video_minimal() -> None:
|
||||
v = _deserialize_video({"id": "video-1", "codec": "AVC", "width": 1920, "height": 1080, "bitrate": 5000})
|
||||
assert isinstance(v, Video)
|
||||
assert v.id == "video-1"
|
||||
assert v.codec is Video.Codec.AVC
|
||||
assert v.bitrate == 5_000_000 # kbps -> bps
|
||||
assert v.width == 1920
|
||||
assert v.height == 1080
|
||||
assert v.range is Video.Range.SDR
|
||||
|
||||
|
||||
def test_deserialize_video_unknown_codec_falls_back_to_none() -> None:
|
||||
v = _deserialize_video({"id": "v2", "codec": "MADE_UP", "width": 0, "height": 0})
|
||||
assert v.codec is None
|
||||
|
||||
|
||||
def test_deserialize_audio_atmos_flag_sets_joc() -> None:
|
||||
a = _deserialize_audio({"id": "a1", "codec": "AAC", "atmos": True, "channels": 6, "bitrate": 256})
|
||||
assert isinstance(a, Audio)
|
||||
assert a.joc == 1
|
||||
assert a.channels == 6
|
||||
assert a.bitrate == 256_000
|
||||
|
||||
|
||||
def test_deserialize_audio_no_atmos() -> None:
|
||||
a = _deserialize_audio({"id": "a2", "codec": "AAC", "channels": 2})
|
||||
assert a.joc == 0
|
||||
|
||||
|
||||
def test_deserialize_subtitle_forced_flag() -> None:
|
||||
s = _deserialize_subtitle({"id": "s1", "codec": "WebVTT", "language": "en", "forced": True})
|
||||
assert isinstance(s, Subtitle)
|
||||
assert s.forced is True
|
||||
assert s.sdh is False
|
||||
|
||||
|
||||
def test_deserialize_subtitle_sdh_flag() -> None:
|
||||
s = _deserialize_subtitle({"id": "s2", "codec": "WebVTT", "language": "en", "sdh": True})
|
||||
assert s.sdh is True
|
||||
assert s.forced is False
|
||||
|
||||
|
||||
def test_reconstruct_drm_empty() -> None:
|
||||
assert _reconstruct_drm(None) == []
|
||||
assert _reconstruct_drm([]) == []
|
||||
|
||||
|
||||
def test_reconstruct_drm_skips_entries_without_pssh() -> None:
|
||||
assert _reconstruct_drm([{"type": "widevine"}]) == []
|
||||
|
||||
|
||||
def test_reconstruct_drm_invalid_pssh_silently_dropped() -> None:
|
||||
assert _reconstruct_drm([{"type": "widevine", "pssh": "not-real-pssh"}]) == []
|
||||
|
||||
|
||||
def test_build_tracks_aggregates() -> None:
|
||||
data = {
|
||||
"video": [{"id": "v", "codec": "AVC", "width": 1280, "height": 720, "bitrate": 2500}],
|
||||
"audio": [{"id": "a", "codec": "AAC", "channels": 2, "bitrate": 128}],
|
||||
"subtitles": [{"id": "s", "codec": "WebVTT", "language": "en"}],
|
||||
"attachments": [],
|
||||
}
|
||||
t = _build_tracks(data)
|
||||
assert len(t.videos) == 1
|
||||
assert len(t.audio) == 1
|
||||
assert len(t.subtitles) == 1
|
||||
|
||||
|
||||
def test_match_track_by_id() -> None:
|
||||
a = _deserialize_video({"id": "v1", "codec": "AVC", "width": 1920, "height": 1080})
|
||||
b = _deserialize_video({"id": "v2", "codec": "AVC", "width": 1280, "height": 720})
|
||||
remote = _deserialize_video({"id": "v2", "codec": "AVC", "width": 1280, "height": 720})
|
||||
assert _match_track(remote, [a, b]) is b
|
||||
|
||||
|
||||
def test_match_track_by_attributes_when_id_missing() -> None:
|
||||
local = _deserialize_video({"id": "X", "codec": "AVC", "width": 1920, "height": 1080, "language": "en"})
|
||||
remote = _deserialize_video({"id": "Y", "codec": "AVC", "width": 1920, "height": 1080, "language": "en"})
|
||||
assert _match_track(remote, [local]) is local
|
||||
|
||||
|
||||
def test_match_track_no_candidates_returns_none() -> None:
|
||||
remote = _deserialize_video({"id": "X", "codec": "AVC", "width": 1, "height": 1})
|
||||
assert _match_track(remote, []) is None
|
||||
|
||||
|
||||
def test_build_title_movie() -> None:
|
||||
info = {"type": "movie", "id": "movie-0001", "name": "Foo", "year": 2024, "language": "en"}
|
||||
title = _build_title(info, "ATV", "fallback")
|
||||
assert isinstance(title, Movie)
|
||||
assert title.id == "movie-0001"
|
||||
assert title.name == "Foo"
|
||||
|
||||
|
||||
def test_build_title_episode() -> None:
|
||||
info = {
|
||||
"type": "episode",
|
||||
"id": "ep-00001",
|
||||
"series_title": "Show",
|
||||
"season": 1,
|
||||
"number": 2,
|
||||
"name": "Pilot",
|
||||
"year": 2024,
|
||||
"language": "en",
|
||||
}
|
||||
title = _build_title(info, "ATV", "fallback")
|
||||
assert isinstance(title, Episode)
|
||||
assert title.season == 1
|
||||
assert title.number == 2
|
||||
assert title.name == "Pilot"
|
||||
|
||||
|
||||
def test_build_title_falls_back_to_id_when_missing() -> None:
|
||||
title = _build_title({"type": "movie", "name": "x"}, "ATV", "fallback-id")
|
||||
assert title.id == "fallback-id"
|
||||
159
tests/remote/unit/test_routes.py
Normal file
159
tests/remote/unit/test_routes.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""Unit tests for unshackle.core.api.routes.setup_routes wiring + CORS + auth gating.
|
||||
|
||||
We build small aiohttp apps in-test with setup_routes(), mirroring what
|
||||
unshackle/commands/serve.py does. We avoid hitting the real handlers by
|
||||
stubbing the route table for selected paths.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from unshackle.core.api.compression import compression_middleware
|
||||
from unshackle.core.api.routes import cors_middleware, setup_routes
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_app():
|
||||
"""Factory that builds an aiohttp app for tests."""
|
||||
|
||||
def _factory(remote_only: bool = False, with_auth_middleware: bool = False):
|
||||
middlewares = [cors_middleware, compression_middleware]
|
||||
if with_auth_middleware:
|
||||
middlewares.insert(1, _no_key_required_auth())
|
||||
app = web.Application(middlewares=middlewares)
|
||||
app["config"] = {"users": {}}
|
||||
app["debug_api"] = False
|
||||
setup_routes(app, remote_only=remote_only)
|
||||
return app
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
def _no_key_required_auth():
|
||||
"""Mirror serve.py's api_key_authentication middleware: required X-Secret-Key
|
||||
on every endpoint except /api/health."""
|
||||
|
||||
@web.middleware
|
||||
async def mw(request, handler):
|
||||
if request.path == "/api/health":
|
||||
return await handler(request)
|
||||
secret = request.headers.get("X-Secret-Key")
|
||||
if not secret:
|
||||
return web.json_response({"status": 401, "message": "Secret Key is Empty."}, status=401)
|
||||
if secret not in request.app["config"]["users"]:
|
||||
return web.json_response({"status": 401, "message": "Secret Key is Invalid."}, status=401)
|
||||
return await handler(request)
|
||||
|
||||
return mw
|
||||
|
||||
|
||||
def _collect_paths(app: web.Application) -> list[tuple[str, str]]:
|
||||
return sorted({(r.method, r.resource.canonical) for r in app.router.routes()})
|
||||
|
||||
|
||||
def test_setup_routes_full_mode_wires_all_endpoints(make_app) -> None:
|
||||
app = make_app(remote_only=False)
|
||||
paths = _collect_paths(app)
|
||||
expected = {
|
||||
("GET", "/api/health"),
|
||||
("GET", "/api/services"),
|
||||
("POST", "/api/search"),
|
||||
("POST", "/api/list-titles"),
|
||||
("POST", "/api/list-tracks"),
|
||||
("POST", "/api/download"),
|
||||
("GET", "/api/download/jobs"),
|
||||
("GET", "/api/download/jobs/{job_id}"),
|
||||
("DELETE", "/api/download/jobs/{job_id}"),
|
||||
("POST", "/api/session/create"),
|
||||
("GET", "/api/session/{session_id}/titles"),
|
||||
("POST", "/api/session/{session_id}/tracks"),
|
||||
("POST", "/api/session/{session_id}/segments"),
|
||||
("POST", "/api/session/{session_id}/license"),
|
||||
("GET", "/api/session/{session_id}/prompt"),
|
||||
("POST", "/api/session/{session_id}/prompt"),
|
||||
("GET", "/api/session/{session_id}"),
|
||||
("DELETE", "/api/session/{session_id}"),
|
||||
}
|
||||
assert expected.issubset(set(paths))
|
||||
|
||||
|
||||
def test_setup_routes_remote_only_excludes_list_and_download(make_app) -> None:
|
||||
app = make_app(remote_only=True)
|
||||
paths = set(_collect_paths(app))
|
||||
assert ("POST", "/api/list-titles") not in paths
|
||||
assert ("POST", "/api/list-tracks") not in paths
|
||||
assert ("POST", "/api/download") not in paths
|
||||
assert ("GET", "/api/download/jobs") not in paths
|
||||
# session endpoints still present
|
||||
assert ("POST", "/api/session/create") in paths
|
||||
assert ("GET", "/api/session/{session_id}/titles") in paths
|
||||
assert ("POST", "/api/session/{session_id}/license") in paths
|
||||
|
||||
|
||||
async def test_cors_preflight_returns_headers(make_app, aiohttp_client) -> None:
|
||||
app = make_app(remote_only=True)
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.options("/api/health")
|
||||
assert resp.status == 200
|
||||
assert resp.headers["Access-Control-Allow-Origin"] == "*"
|
||||
assert "GET" in resp.headers["Access-Control-Allow-Methods"]
|
||||
assert "X-Secret-Key" in resp.headers["Access-Control-Allow-Headers"]
|
||||
|
||||
|
||||
async def test_health_endpoint_responds_ok(make_app, aiohttp_client, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from unshackle.core.api import routes as routes_mod
|
||||
|
||||
async def _no_update(_):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(routes_mod.UpdateChecker, "check_for_updates", _no_update)
|
||||
app = make_app(remote_only=True)
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/api/health")
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert body["status"] == "ok"
|
||||
assert "version" in body
|
||||
|
||||
|
||||
async def test_health_bypasses_api_key_auth_middleware(make_app, aiohttp_client, monkeypatch) -> None:
|
||||
from unshackle.core.api import routes as routes_mod
|
||||
|
||||
async def _no_update(_):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(routes_mod.UpdateChecker, "check_for_updates", _no_update)
|
||||
|
||||
app = make_app(remote_only=True, with_auth_middleware=True)
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/api/health")
|
||||
assert resp.status == 200 # health bypasses auth
|
||||
|
||||
|
||||
async def test_auth_middleware_rejects_missing_key(make_app, aiohttp_client) -> None:
|
||||
app = make_app(remote_only=True, with_auth_middleware=True)
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/api/session/abc")
|
||||
assert resp.status == 401
|
||||
body = await resp.json()
|
||||
assert "Secret Key" in body["message"]
|
||||
|
||||
|
||||
async def test_auth_middleware_rejects_invalid_key(make_app, aiohttp_client) -> None:
|
||||
app = make_app(remote_only=True, with_auth_middleware=True)
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/api/session/abc", headers={"X-Secret-Key": "wrong"})
|
||||
assert resp.status == 401
|
||||
|
||||
|
||||
async def test_auth_middleware_accepts_known_key(make_app, aiohttp_client) -> None:
|
||||
app = make_app(remote_only=True, with_auth_middleware=True)
|
||||
app["config"]["users"]["good-key"] = {"devices": []}
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/api/session/nonexistent", headers={"X-Secret-Key": "good-key"})
|
||||
# Auth passed; handler then 404s the session — anything other than 401 is fine here.
|
||||
assert resp.status != 401
|
||||
79
tests/remote/unit/test_serve_cli.py
Normal file
79
tests/remote/unit/test_serve_cli.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Unit tests for the `unshackle serve` Click command flag surface."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
|
||||
from unshackle.commands.serve import serve
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner() -> CliRunner:
|
||||
return CliRunner()
|
||||
|
||||
|
||||
def test_serve_help_lists_documented_flags(runner: CliRunner) -> None:
|
||||
result = runner.invoke(serve, ["--help"])
|
||||
assert result.exit_code == 0
|
||||
out = result.output
|
||||
for flag in (
|
||||
"--host",
|
||||
"--port",
|
||||
"--caddy",
|
||||
"--api-only",
|
||||
"--no-widevine",
|
||||
"--no-playready",
|
||||
"--no-key",
|
||||
"--debug-api",
|
||||
"--debug",
|
||||
"--remote-only",
|
||||
):
|
||||
assert flag in out, f"missing flag in --help: {flag}"
|
||||
|
||||
|
||||
def test_serve_api_only_with_no_widevine_rejected(runner: CliRunner, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""`--api-only` is mutually exclusive with `--no-widevine`/`--no-playready`."""
|
||||
monkeypatch.setenv("UNSHACKLE_NO_RUN", "1") # belt-and-braces; not currently checked
|
||||
|
||||
# Stub web.run_app to avoid actually starting the server if validation passes.
|
||||
from aiohttp import web
|
||||
|
||||
monkeypatch.setattr(web, "run_app", lambda *a, **kw: None)
|
||||
|
||||
# Force a clean config.serve so no_key path doesn't blow up loading wvds.
|
||||
from unshackle.core.config import config as cfg
|
||||
|
||||
monkeypatch.setattr(cfg, "serve", {"api_secret": "x"})
|
||||
|
||||
result = runner.invoke(serve, ["--api-only", "--no-widevine", "--no-key"])
|
||||
assert result.exit_code != 0
|
||||
assert "Cannot use --api-only" in (result.output or str(result.exception))
|
||||
|
||||
|
||||
def test_serve_no_key_without_api_secret_does_not_require_secret(
|
||||
runner: CliRunner, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""With --no-key, the missing api_secret check is bypassed."""
|
||||
from aiohttp import web
|
||||
|
||||
monkeypatch.setattr(web, "run_app", lambda *a, **kw: None)
|
||||
from unshackle.core.config import config as cfg
|
||||
|
||||
monkeypatch.setattr(cfg, "serve", {})
|
||||
|
||||
result = runner.invoke(serve, ["--api-only", "--no-key", "--remote-only"])
|
||||
# No exception should escape, exit code 0 means startup proceeded then run_app stub returned.
|
||||
assert result.exit_code == 0, result.output
|
||||
|
||||
|
||||
def test_serve_without_no_key_requires_api_secret(runner: CliRunner, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from unshackle.core.config import config as cfg
|
||||
|
||||
monkeypatch.setattr(cfg, "serve", {}) # no api_secret configured
|
||||
|
||||
result = runner.invoke(serve, ["--api-only"])
|
||||
assert result.exit_code != 0
|
||||
assert "api_secret" in (result.output or "").lower() or "api_secret" in str(result.exception).lower()
|
||||
117
tests/remote/unit/test_session_store.py
Normal file
117
tests/remote/unit/test_session_store.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""Unit tests for unshackle.core.api.session_store.SessionStore."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from unshackle.core.api.input_bridge import AuthStatus, InputBridge
|
||||
from unshackle.core.api.session_store import SessionEntry, SessionStore, get_session_store
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class _FakeService:
|
||||
"""Minimal stub Service used to fill SessionEntry.service_instance."""
|
||||
|
||||
def __init__(self, tag: str = "TEST") -> None:
|
||||
self.tag = tag
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store() -> SessionStore:
|
||||
return SessionStore()
|
||||
|
||||
|
||||
async def test_create_returns_entry_with_uuid(store: SessionStore) -> None:
|
||||
entry = await store.create("ATV", _FakeService())
|
||||
assert isinstance(entry, SessionEntry)
|
||||
assert entry.service_tag == "ATV"
|
||||
assert entry.session_id and len(entry.session_id) >= 32
|
||||
assert store.session_count == 1
|
||||
|
||||
|
||||
async def test_create_with_explicit_session_id(store: SessionStore) -> None:
|
||||
entry = await store.create("NF", _FakeService(), session_id="fixed-id")
|
||||
assert entry.session_id == "fixed-id"
|
||||
|
||||
|
||||
async def test_get_returns_none_for_missing(store: SessionStore) -> None:
|
||||
assert await store.get("nope") is None
|
||||
|
||||
|
||||
async def test_get_touches_last_accessed(store: SessionStore) -> None:
|
||||
entry = await store.create("DSNP", _FakeService())
|
||||
before = entry.last_accessed
|
||||
await asyncio.sleep(0.01)
|
||||
fetched = await store.get(entry.session_id)
|
||||
assert fetched is entry
|
||||
assert fetched.last_accessed > before
|
||||
|
||||
|
||||
async def test_delete_removes_and_cancels_bridge(store: SessionStore) -> None:
|
||||
entry = await store.create("CRAV", _FakeService())
|
||||
entry.input_bridge = InputBridge()
|
||||
assert entry.input_bridge.status is AuthStatus.AUTHENTICATING
|
||||
|
||||
deleted = await store.delete(entry.session_id)
|
||||
assert deleted is True
|
||||
assert entry.input_bridge.status is AuthStatus.FAILED # cancelled
|
||||
assert store.session_count == 0
|
||||
|
||||
|
||||
async def test_delete_returns_false_when_missing(store: SessionStore) -> None:
|
||||
assert await store.delete("missing") is False
|
||||
|
||||
|
||||
async def test_cleanup_expired_drops_old_authenticated(store: SessionStore, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
entry = await store.create("ATV", _FakeService())
|
||||
entry.last_accessed = datetime.now(timezone.utc) - timedelta(seconds=store._ttl + 100)
|
||||
removed = await store.cleanup_expired()
|
||||
assert removed == 1
|
||||
assert store.session_count == 0
|
||||
|
||||
|
||||
async def test_cleanup_expired_keeps_pending_input_under_grace(store: SessionStore) -> None:
|
||||
"""Sessions awaiting user input get a longer grace period (10 min) than authenticated TTL."""
|
||||
entry = await store.create("ATV", _FakeService())
|
||||
entry.input_bridge = InputBridge()
|
||||
entry.auth_status = AuthStatus.PENDING_INPUT
|
||||
removed = await store.cleanup_expired()
|
||||
assert removed == 0
|
||||
assert store.session_count == 1
|
||||
|
||||
|
||||
async def test_cancel_all_bridges(store: SessionStore) -> None:
|
||||
a = await store.create("ATV", _FakeService())
|
||||
b = await store.create("NF", _FakeService())
|
||||
a.input_bridge = InputBridge()
|
||||
b.input_bridge = InputBridge()
|
||||
|
||||
await store.cancel_all_bridges()
|
||||
assert a.input_bridge.status is AuthStatus.FAILED
|
||||
assert b.input_bridge.status is AuthStatus.FAILED
|
||||
|
||||
|
||||
async def test_get_session_store_returns_singleton() -> None:
|
||||
a = get_session_store()
|
||||
b = get_session_store()
|
||||
assert a is b
|
||||
|
||||
|
||||
async def test_max_sessions_evicts_oldest(store: SessionStore, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(type(store), "_max_sessions", property(lambda _: 2))
|
||||
|
||||
a = await store.create("A", _FakeService(), session_id="a")
|
||||
await asyncio.sleep(0.01)
|
||||
b = await store.create("B", _FakeService(), session_id="b")
|
||||
await asyncio.sleep(0.01)
|
||||
c = await store.create("C", _FakeService(), session_id="c")
|
||||
|
||||
assert store.session_count == 2
|
||||
assert await store.get("a") is None # evicted
|
||||
assert (await store.get("b")) is b
|
||||
assert (await store.get("c")) is c
|
||||
Reference in New Issue
Block a user