From 2f35a4d46887defb3800ddfe663cd59a7cba9fa6 Mon Sep 17 00:00:00 2001 From: imSp4rky Date: Mon, 8 Jun 2026 15:37:40 -0600 Subject: [PATCH] feat(api): aggregate REST download progress with weighting, track labels and mux stage Replace the class-level Track.download monkeypatch with a per-job progress sink threaded through dl.result(). The API now reports a single aggregate signal instead of each track's bouncing 0-100%: - bitrate-weighted completion so video/audio dominate subtitles - completed_tracks/total_tracks counts and active_tracks labels (e.g. "video 2160p DV", "audio en-US 5.1") - downloads fill 0-90%; repackaging (when needed) and a "muxing" stage carry it to 100% so post-download work is no longer frozen at 100% - monotonic throughout (handles the download->decrypt callable reuse) Also: - accept "HDR10P" as the canonical API range value ("HDR10+" still works) - declare AUTH_METHODS opt-in on the Service base - raise typed APIError (WORKER_ERROR/DOWNLOAD_ERROR) from the worker path - move the progress helpers to unshackle/core/api/progress.py --- .../unit/test_download_gates_redaction.py | 24 ++- tests/remote/unit/test_progress_sink.py | 155 ++++++++++++++++++ unshackle/commands/dl.py | 12 ++ unshackle/core/api/download_manager.py | 79 +++------ unshackle/core/api/handlers.py | 14 +- unshackle/core/api/progress.py | 118 +++++++++++++ unshackle/core/api/routes.py | 4 +- unshackle/core/service.py | 3 + 8 files changed, 333 insertions(+), 76 deletions(-) create mode 100644 tests/remote/unit/test_progress_sink.py create mode 100644 unshackle/core/api/progress.py diff --git a/tests/remote/unit/test_download_gates_redaction.py b/tests/remote/unit/test_download_gates_redaction.py index 5d45d97..3d36138 100644 --- a/tests/remote/unit/test_download_gates_redaction.py +++ b/tests/remote/unit/test_download_gates_redaction.py @@ -9,13 +9,8 @@ import pytest from aiohttp import web from unshackle.core.api import handlers -from unshackle.core.api.download_manager import ( - DownloadJob, - JobStatus, - _redact_parameters, - _redact_text, - _secret_values, -) +from unshackle.core.api.download_manager import (DownloadJob, JobStatus, _redact_parameters, _redact_text, + _secret_values) from unshackle.core.api.errors import APIError, APIErrorCode pytestmark = pytest.mark.unit @@ -140,3 +135,18 @@ async def test_credential_allowed_when_enabled(stub_handler): stub_handler.setattr(handlers.config, "serve", {"allow_job_credentials": True}) resp = await handlers.download_handler({"service": "ATV", "title_id": "t", "credential": "u:p"}) assert isinstance(resp, web.Response) + + +# ---------- range validation ---------- + + +def test_range_validation_accepts_hdr10p_and_alias(): + # canonical "HDR10P" and back-compat "HDR10+" both pass; mixed casing too + assert handlers.validate_download_parameters({"range": ["HDR10P", "DV", "SDR"]}) is None + assert handlers.validate_download_parameters({"range": ["hdr10+"]}) is None + assert handlers.validate_download_parameters({"range": "HYBRID"}) is None + + +def test_range_validation_rejects_unknown_and_lists_hdr10p(): + err = handlers.validate_download_parameters({"range": ["HDR99"]}) + assert err and "HDR10P" in err and "HDR99" in err diff --git a/tests/remote/unit/test_progress_sink.py b/tests/remote/unit/test_progress_sink.py new file mode 100644 index 0000000..934832e --- /dev/null +++ b/tests/remote/unit/test_progress_sink.py @@ -0,0 +1,155 @@ +"""Unit tests for the aggregate per-job download progress sink. + +``build_job_progress_callables`` wraps the per-track progress callables so the API job sees one +aggregate signal - a bitrate-weighted completion percentage, track counts, and the labels of the +tracks downloading now - instead of each track's own bouncing 0-100%. These tests pin that +contract.""" + +from __future__ import annotations + +import pytest + +from unshackle.core.api.progress import (DOWNLOAD_PROGRESS_CEILING, build_job_progress_callables, + track_progress_label, track_progress_weight) + +pytestmark = pytest.mark.unit + + +# --- lightweight track stand-ins (label/weight key off class name + attributes) --- +class _Range: + def __init__(self, value): + self.value = value + + +class Video: + def __init__(self, height=1080, range_value="SDR", bitrate=4_000_000): + self.height = height + self.range = _Range(range_value) + self.bitrate = bitrate + + +class Audio: + def __init__(self, language="en-US", channels="2.0", bitrate=200_000): + self.language = language + self.channels = channels + self.bitrate = bitrate + + +class Subtitle: + def __init__(self, language="fr"): + self.language = language + self.bitrate = None + + +def _noop(**kwargs): + pass + + +def test_track_progress_label(): + assert track_progress_label(Video(2160, "DV")) == "video 2160p DV" + assert track_progress_label(Video(1080, "HDR10+")) == "video 1080p HDR10+" + assert track_progress_label(Audio("en-US", "5.1")) == "audio en-US 5.1" + assert track_progress_label(Subtitle("ro")) == "subtitle ro" + + +def test_weight_video_over_audio_over_subtitle(): + assert track_progress_weight(Video(bitrate=4_000_000)) == 4_000_000 + assert track_progress_weight(Audio(bitrate=200_000)) == 200_000 + # subtitle has no bitrate -> small fixed weight, far below media + assert track_progress_weight(Subtitle()) < track_progress_weight(Audio(bitrate=200_000)) + + +def test_weighting_makes_video_dominate_progress(): + updates: list[dict] = [] + video, sub = Video(bitrate=4_000_000), Subtitle() + cbs = build_job_progress_callables([video, sub], [_noop, _noop], updates.append) + + # subtitle fully done, video untouched -> progress is tiny (subtitle barely weighted) + cbs[1](downloaded="Downloaded") + assert updates[-1]["completed_tracks"] == 1 + assert updates[-1]["progress"] < 5.0 + + # video half done -> progress is dominated by video (scaled into the 0..ceiling download band) + cbs[0](total=100, completed=50) + assert updates[-1]["progress"] > 40.0 + + +def test_active_tracks_labels_reported_and_cleared_on_done(): + updates: list[dict] = [] + cbs = build_job_progress_callables( + [Video(2160, "DV"), Audio("en-US", "2.0")], [_noop, _noop], updates.append + ) + + cbs[0](total=100, completed=10) # video downloading + assert updates[-1]["active_tracks"] == ["video 2160p DV"] + assert updates[-1]["phase"] == "downloading video 2160p DV" + + cbs[1](total=100, completed=10) # audio also downloading + assert updates[-1]["active_tracks"] == ["video 2160p DV", "audio en-US 2.0"] + + cbs[0](downloaded="Downloaded") # video done -> drops out of active + assert updates[-1]["active_tracks"] == ["audio en-US 2.0"] + + +def test_aggregate_progress_is_monotonic_with_counts(): + updates: list[dict] = [] + inner_calls = [0, 0, 0] + + def make_inner(i): + def inner(**kwargs): + inner_calls[i] += 1 + + return inner + + tracks = [Video(bitrate=1000), Audio(bitrate=1000), Subtitle()] + cbs = build_job_progress_callables(tracks, [make_inner(0), make_inner(1), make_inner(2)], updates.append) + assert len(cbs) == 3 + + cbs[0](total=100, completed=50) + cbs[0](downloaded="Downloaded") + cbs[1](total=100, completed=50) + + progresses = [u["progress"] for u in updates] + assert progresses == sorted(progresses) + assert updates[-1]["completed_tracks"] == 1 + assert updates[-1]["total_tracks"] == 3 + assert inner_calls == [2, 1, 0] + + +def test_all_tracks_done_reaches_download_ceiling(): + # Downloads fill up to the ceiling; dl.result drives muxing the rest of the way to 100. + updates: list[dict] = [] + cbs = build_job_progress_callables([Audio(bitrate=1000), Audio(bitrate=1000)], [_noop, _noop], updates.append) + + cbs[0](total=10, completed=10, downloaded="Downloaded") + assert updates[-1]["progress"] < DOWNLOAD_PROGRESS_CEILING + assert updates[-1]["completed_tracks"] == 1 + + cbs[1](total=10, completed=10, downloaded="Decrypted") + assert updates[-1]["progress"] == pytest.approx(DOWNLOAD_PROGRESS_CEILING) + assert updates[-1]["completed_tracks"] == 2 + + +def test_finished_track_does_not_dip_when_callable_reused_for_decrypt(): + """A track hits 100% (then decrypt reuses the callable with completed=0); the aggregate must + hold, never dip - even before the terminal 'Downloaded'/'Decrypted' string arrives.""" + updates: list[dict] = [] + cbs = build_job_progress_callables([Video(bitrate=1000), Video(bitrate=1000)], [_noop, _noop], updates.append) + + cbs[0](total=100, completed=100) # download hits 100% BEFORE any terminal string -> 50% + cbs[0](total=200, completed=0) # decrypt phase resets counts, still no terminal string + cbs[0](total=200, completed=100) # decrypt mid-way + cbs[0](total=200, completed=200, downloaded="Decrypted") # terminal + + progresses = [u["progress"] for u in updates] + assert progresses == sorted(progresses) # monotonic, no dip + assert updates[-1]["progress"] == pytest.approx(DOWNLOAD_PROGRESS_CEILING / 2) + assert updates[-1]["completed_tracks"] == 1 + + +def test_skipped_subtitle_counts_as_done(): + updates: list[dict] = [] + cbs = build_job_progress_callables([Subtitle()], [_noop], updates.append) + cbs[0](downloaded="[yellow]SKIPPED") + assert updates[-1]["completed_tracks"] == 1 + assert updates[-1]["progress"] == pytest.approx(DOWNLOAD_PROGRESS_CEILING) diff --git a/unshackle/commands/dl.py b/unshackle/commands/dl.py index 0ce5010..cc51f79 100644 --- a/unshackle/commands/dl.py +++ b/unshackle/commands/dl.py @@ -1143,6 +1143,7 @@ class dl: split_audio: Optional[bool] = None, real_video_bitrate: bool = False, real_audio_bitrate: bool = False, + progress_sink: Optional[Callable[[dict[str, Any]], None]] = None, *_: Any, **__: Any, ) -> None: @@ -2214,6 +2215,13 @@ class dl: selected_tracks, tracks_progress_callables = title.tracks.tree(add_progress=True) + if progress_sink is not None: + from unshackle.core.api.progress import build_job_progress_callables + + tracks_progress_callables = build_job_progress_callables( + list(title.tracks), tracks_progress_callables, progress_sink + ) + for track in title.tracks: if hasattr(track, "needs_drm_loading") and track.needs_drm_loading: track.load_drm_if_needed(service) @@ -2470,6 +2478,8 @@ class dl: break # Now repack the decrypted tracks + if progress_sink and any(getattr(t, "needs_repack", False) for t in title.tracks): + progress_sink({"phase": "repackaging", "progress": 92.0, "status": "downloading", "active_tracks": []}) with console.status("Repackaging tracks with FFMPEG..."): has_repacked = False for track in title.tracks: @@ -2636,6 +2646,8 @@ class dl: for video_track in title.tracks.videos or [None]: mux_video_standalone(video_track) + if progress_sink: + progress_sink({"phase": "muxing", "progress": 96.0, "status": "downloading", "active_tracks": []}) try: with Live(Padding(progress, (0, 5, 1, 5)), console=console): mux_index = 0 diff --git a/unshackle/core/api/download_manager.py b/unshackle/core/api/download_manager.py index 50514d1..4b3a40e 100644 --- a/unshackle/core/api/download_manager.py +++ b/unshackle/core/api/download_manager.py @@ -108,8 +108,11 @@ class DownloadJob: error_traceback: Optional[str] = None worker_stderr: Optional[str] = None - # Human-readable current phase (e.g. "downloading video 1080p") + # Current phase, track counts, and labels of the tracks downloading now. phase: Optional[str] = None + completed_tracks: int = 0 + total_tracks: int = 0 + active_tracks: List[str] = field(default_factory=list) # Subtitles skipped under skip_subtitle_errors (non-fatal). Each entry is a dl.SkippedSubtitle # dict (id / language / title) so a client can report which weren't available. @@ -128,6 +131,9 @@ class DownloadJob: "title_id": self.title_id, "progress": self.progress, "phase": self.phase, + "completed_tracks": self.completed_tracks, + "total_tracks": self.total_tracks, + "active_tracks": self.active_tracks, "skipped_subtitles": self.skipped_subtitles, } @@ -175,6 +181,7 @@ def _perform_download( import yaml from unshackle.commands.dl import dl + from unshackle.core.api.errors import APIError, APIErrorCode from unshackle.core.config import config from unshackle.core.services import Services from unshackle.core.tracks import Subtitle, Video @@ -359,69 +366,16 @@ def _perform_download( stdout_capture = StringIO() stderr_capture = StringIO() - # Simple progress tracking if callback provided + # The progress_sink (dl.build_job_progress_callables) owns the percentage; status changes + # are emitted here. if progress_callback: - # Report initial progress progress_callback({"progress": 0.0, "status": "starting"}) - - # Tee each Track.download's progress callable so the downloader's live percentage - # is forwarded to the API job (not just 5%/100%), and expose which track is being - # downloaded now as a human-readable phase. - from unshackle.core.tracks.track import Track as _Track - - if not getattr(_Track, "_api_progress_patched", False): - _orig_track_download = _Track.download - - def _download_with_progress(self, *args, **kwargs): - inner_progress = kwargs.get("progress") - track_type = type(self).__name__ - phase = { - "Video": "downloading video", - "Audio": "downloading audio", - "Subtitle": "downloading subtitle", - }.get(track_type, f"downloading {track_type.lower()}") - height = getattr(self, "height", None) - language = getattr(self, "language", None) - if height: - phase += f" {height}p" - elif track_type in ("Audio", "Subtitle") and language: - phase += f" {language}" - progress_callback({"phase": phase, "status": "downloading"}) - - if callable(inner_progress): - counts = {"completed": 0.0, "total": 0.0} - - def tee(*tee_args, **tee_kwargs): - if tee_kwargs.get("total"): - counts["total"] = tee_kwargs["total"] - if tee_kwargs.get("completed") is not None: - counts["completed"] = tee_kwargs["completed"] - if "advance" in tee_kwargs: - counts["completed"] += tee_kwargs["advance"] - pct = counts["completed"] * 100.0 / counts["total"] if counts["total"] else 0 - if pct: - progress_callback( - {"progress": min(99.0, float(pct)), "phase": phase, "status": "downloading"} - ) - return inner_progress(*tee_args, **tee_kwargs) - - kwargs["progress"] = tee - return _orig_track_download(self, *args, **kwargs) - - _Track.download = _download_with_progress - _Track._api_progress_patched = True - original_result = dl_instance.result def result_with_progress(*args, **kwargs): try: - # Report that download started - progress_callback({"progress": 5.0, "status": "downloading"}) - - # Call original method + progress_callback({"status": "downloading"}) result = original_result(*args, **kwargs) - - # Report completion progress_callback({"progress": 100.0, "status": "completed"}) return result except Exception as e: @@ -481,6 +435,7 @@ def _perform_download( worst=params.get("worst", False), best_available=params.get("best_available", False), split_audio=params.get("split_audio"), + progress_sink=progress_callback, ) except SystemExit as exc: @@ -490,7 +445,7 @@ def _perform_download( log.error(f"Download exited with code {exc.code}") log.error(f"Stdout: {stdout_str}") log.error(f"Stderr: {stderr_str}") - raise Exception(f"Download failed with exit code {exc.code}") + raise APIError(APIErrorCode.DOWNLOAD_ERROR, f"Download failed with exit code {exc.code}") except Exception as exc: # noqa: BLE001 - propagate to caller stdout_str = stdout_capture.getvalue() @@ -504,7 +459,7 @@ def _perform_download( # It sets download_failed in that case, so the job isn't reported as completed with no output. if getattr(dl_instance, "download_failed", False): detail = (stdout_capture.getvalue() + stderr_capture.getvalue())[-200:].strip() - raise Exception("download worker failed: " + (detail or "see logs")) + raise APIError(APIErrorCode.WORKER_ERROR, "download worker failed: " + (detail or "see logs")) # Surface any subtitles that were skipped (non-fatal failures) so the client can report them. if progress_callback: @@ -796,6 +751,12 @@ class DownloadQueueManager: progress_data = json.load(handle) if progress_data.get("phase") and progress_data["phase"] != job.phase: job.phase = progress_data["phase"] + if progress_data.get("total_tracks"): + job.total_tracks = int(progress_data["total_tracks"]) + if progress_data.get("completed_tracks") is not None: + job.completed_tracks = int(progress_data["completed_tracks"]) + if "active_tracks" in progress_data: + job.active_tracks = list(progress_data["active_tracks"]) if progress_data.get("skipped_subtitles"): job.skipped_subtitles = progress_data["skipped_subtitles"] if "progress" in progress_data: diff --git a/unshackle/core/api/handlers.py b/unshackle/core/api/handlers.py index ef509af..8641e4a 100644 --- a/unshackle/core/api/handlers.py +++ b/unshackle/core/api/handlers.py @@ -957,13 +957,13 @@ def validate_download_parameters(data: Dict[str, Any]) -> Optional[str]: return "Cannot use both s_lang and require_subs" if "range" in data and data["range"]: - valid_ranges = ["SDR", "HDR10", "HDR10+", "DV", "HLG", "HYBRID"] - if isinstance(data["range"], list): - for r in data["range"]: - if r.upper() not in valid_ranges: - return f"Invalid range value: {r}. Must be one of: {', '.join(valid_ranges)}" - elif data["range"].upper() not in valid_ranges: - return f"Invalid range value: {data['range']}. Must be one of: {', '.join(valid_ranges)}" + # "HDR10P" is the canonical range value ("+" is awkward in scripts); "HDR10+" stays valid. + valid_ranges = ["SDR", "HDR10", "HDR10P", "DV", "HLG", "HYBRID"] + accepted = {*valid_ranges, "HDR10+"} + values = data["range"] if isinstance(data["range"], list) else [data["range"]] + for r in values: + if r.upper() not in accepted: + return f"Invalid range value: {r}. Must be one of: {', '.join(valid_ranges)}" return None diff --git a/unshackle/core/api/progress.py b/unshackle/core/api/progress.py new file mode 100644 index 0000000..1cfbd45 --- /dev/null +++ b/unshackle/core/api/progress.py @@ -0,0 +1,118 @@ +"""Aggregate job-level download progress for the REST API. + +Turns the per-track progress callables from ``Tracks.tree`` into one signal a job can report: +a bitrate-weighted percentage, track counts, and the labels of the tracks downloading now. +""" + +from __future__ import annotations + +from typing import Any, Callable + +from unshackle.core.constants import AnyTrack + +JOB_PROGRESS_TERMINAL_STATES = {"Downloaded", "Decrypted", "[yellow]SKIPPED"} + +# Weight for a track with no bitrate (subtitles); small vs media bitrates so subs barely move the bar. +SUBTITLE_PROGRESS_WEIGHT = 50_000.0 + +# Downloads fill 0..this; dl.result drives the remainder (repackaging, muxing) up to 100. +DOWNLOAD_PROGRESS_CEILING = 90.0 + + +def track_progress_label(track: AnyTrack) -> str: + """Short label for a track, e.g. "video 2160p DV", "audio en-US 5.1", "subtitle fr".""" + track_type = type(track).__name__ + if track_type == "Video": + parts = ["video"] + height = getattr(track, "height", None) + if height: + parts.append(f"{height}p") + track_range = getattr(track, "range", None) + if track_range is not None: + parts.append(track_range.value) + return " ".join(parts) + if track_type == "Audio": + parts = ["audio"] + language = getattr(track, "language", None) + if language: + parts.append(str(language)) + channels = getattr(track, "channels", None) + if channels: + parts.append(str(channels)) + return " ".join(parts) + if track_type == "Subtitle": + language = getattr(track, "language", None) + return f"subtitle {language}" if language else "subtitle" + return track_type.lower() + + +def track_progress_weight(track: AnyTrack) -> float: + """Track weight in the aggregate (its bitrate in bits/s), so video/audio dominate subtitles.""" + bitrate = getattr(track, "bitrate", None) + return float(bitrate) if bitrate else SUBTITLE_PROGRESS_WEIGHT + + +def build_job_progress_callables( + tracks: list[AnyTrack], + inner_callables: list[Callable[..., None]], + sink: Callable[[dict[str, Any]], None], +) -> list[Callable[..., None]]: + """Wrap each track's progress callable so ``sink`` receives aggregate job progress. + + The sink gets a bitrate-weighted mean completion across all tracks, ``completed_tracks`` / + ``total_tracks`` counts, and ``active_tracks`` labels. Each track's fraction is monotonic, so + the percentage only climbs. The original ``inner`` callable is always invoked. + """ + total = len(inner_callables) + weights = [track_progress_weight(t) for t in tracks] + labels = [track_progress_label(t) for t in tracks] + total_weight = sum(weights) or 1.0 + fractions = [0.0] * total + done = [False] * total + started = [False] * total + + def emit() -> None: + completed = sum(done) + # Downloads fill 0..DOWNLOAD_PROGRESS_CEILING; dl.result drives muxing up to 100. + progress = sum(w * f for w, f in zip(weights, fractions)) * DOWNLOAD_PROGRESS_CEILING / total_weight + active = [labels[i] for i in range(total) if started[i] and not done[i]] + if active: + phase = "downloading " + ", ".join(active[:3]) + if len(active) > 3: + phase += f" (+{len(active) - 3} more)" + else: + phase = f"downloading {completed}/{total} tracks" + sink( + { + "progress": progress, + "phase": phase, + "completed_tracks": completed, + "total_tracks": total, + "active_tracks": active, + "status": "downloading", + } + ) + + def wrap(index: int, inner: Callable[..., None]) -> Callable[..., None]: + counts = {"completed": 0.0, "total": 0.0} + + def tee(*args: Any, **kwargs: Any) -> None: + started[index] = True + if kwargs.get("total"): + counts["total"] = kwargs["total"] + if kwargs.get("completed") is not None: + counts["completed"] = kwargs["completed"] + if "advance" in kwargs: + counts["completed"] += kwargs["advance"] + if kwargs.get("downloaded") in JOB_PROGRESS_TERMINAL_STATES: + done[index] = True + fractions[index] = 1.0 + elif counts["total"]: + # max() keeps the fraction monotonic across the download->decrypt callable reuse. + fractions[index] = max(fractions[index], min(1.0, counts["completed"] / counts["total"])) + emit() + return inner(*args, **kwargs) + + return tee + + return [wrap(i, inner) for i, inner in enumerate(inner_callables)] diff --git a/unshackle/core/api/routes.py b/unshackle/core/api/routes.py index 43daa40..d6d9961 100644 --- a/unshackle/core/api/routes.py +++ b/unshackle/core/api/routes.py @@ -233,9 +233,7 @@ async def services(request: web.Request) -> web.Response: or getattr(service_module, "get_playready_license", None) is not _BaseService.get_playready_license ) - # Auth methods the service accepts. Prefer an explicit `AUTH_METHODS` class var - # (reliable); otherwise fall back to inferring from what authenticate() references - # - that mostly returns both because services call super().authenticate(...). + # Prefer the service's explicit AUTH_METHODS; otherwise infer from authenticate(). methods = [] if service_data["needs_auth"]: declared = getattr(service_module, "AUTH_METHODS", None) diff --git a/unshackle/core/service.py b/unshackle/core/service.py index 5c252da..417984b 100644 --- a/unshackle/core/service.py +++ b/unshackle/core/service.py @@ -95,6 +95,9 @@ class Service(metaclass=ABCMeta): GEOFENCE: tuple[str, ...] = () # list of ip regions required to use the service. empty list == no specific region. # vault namespace override; when set, key vault read/write uses this tag instead of the service's own. VAULT_TAG: Optional[str] = None + # Auth methods the service accepts ("cookies"/"credentials"); when None the REST /services + # endpoint infers them from authenticate(). + AUTH_METHODS: Optional[tuple[str, ...]] = None def __init__(self, ctx: click.Context): console.print(Padding(Rule(f"[rule.text]Service: {self.__class__.__name__}"), (1, 2)))