feat(api): aggregate REST download progress with weighting, track labels and mux stage

Replace the class-level Track.download monkeypatch with a per-job progress sink threaded through dl.result(). The API now reports a single aggregate signal instead of each track's bouncing 0-100%:

- bitrate-weighted completion so video/audio dominate subtitles
- completed_tracks/total_tracks counts and active_tracks labels (e.g. "video 2160p DV", "audio en-US 5.1")
- downloads fill 0-90%; repackaging (when needed) and a "muxing" stage carry it to 100% so post-download work is no longer frozen at 100%
- monotonic throughout (handles the download->decrypt callable reuse)

Also:
- accept "HDR10P" as the canonical API range value ("HDR10+" still works)
- declare AUTH_METHODS opt-in on the Service base
- raise typed APIError (WORKER_ERROR/DOWNLOAD_ERROR) from the worker path
- move the progress helpers to unshackle/core/api/progress.py
This commit is contained in:
imSp4rky
2026-06-08 15:37:40 -06:00
parent 1a3cd09fc8
commit 2f35a4d468
8 changed files with 333 additions and 76 deletions

View File

@@ -9,13 +9,8 @@ import pytest
from aiohttp import web
from unshackle.core.api import handlers
from unshackle.core.api.download_manager import (
DownloadJob,
JobStatus,
_redact_parameters,
_redact_text,
_secret_values,
)
from unshackle.core.api.download_manager import (DownloadJob, JobStatus, _redact_parameters, _redact_text,
_secret_values)
from unshackle.core.api.errors import APIError, APIErrorCode
pytestmark = pytest.mark.unit
@@ -140,3 +135,18 @@ async def test_credential_allowed_when_enabled(stub_handler):
stub_handler.setattr(handlers.config, "serve", {"allow_job_credentials": True})
resp = await handlers.download_handler({"service": "ATV", "title_id": "t", "credential": "u:p"})
assert isinstance(resp, web.Response)
# ---------- range validation ----------
def test_range_validation_accepts_hdr10p_and_alias():
    """Canonical "HDR10P", the back-compat "HDR10+" alias and a bare string all validate."""
    # Covers a list of canonical values, a lower-cased alias, and a non-list value.
    for accepted in (["HDR10P", "DV", "SDR"], ["hdr10+"], "HYBRID"):
        assert handlers.validate_download_parameters({"range": accepted}) is None
def test_range_validation_rejects_unknown_and_lists_hdr10p():
    """An unknown range is rejected with a message naming it and listing "HDR10P"."""
    err = handlers.validate_download_parameters({"range": ["HDR99"]})
    assert err
    assert "HDR10P" in err
    assert "HDR99" in err

View File

@@ -0,0 +1,155 @@
"""Unit tests for the aggregate per-job download progress sink.
``build_job_progress_callables`` wraps the per-track progress callables so the API job sees one
aggregate signal - a bitrate-weighted completion percentage, track counts, and the labels of the
tracks downloading now - instead of each track's own bouncing 0-100%. These tests pin that
contract."""
from __future__ import annotations
import pytest
from unshackle.core.api.progress import (DOWNLOAD_PROGRESS_CEILING, build_job_progress_callables,
track_progress_label, track_progress_weight)
pytestmark = pytest.mark.unit
# --- lightweight track stand-ins (label/weight key off class name + attributes) ---
class _Range:
def __init__(self, value):
self.value = value
class Video:
    """Stand-in video track exposing ``height``, ``range`` and ``bitrate``.

    The class name matters: the helpers under test dispatch on ``type(track).__name__``.
    """

    def __init__(self, height=1080, range_value="SDR", bitrate=4_000_000):
        self.height, self.bitrate = height, bitrate
        # Wrapped so the label helper can read ``range.value``.
        self.range = _Range(range_value)
class Audio:
    """Stand-in audio track exposing ``language``, ``channels`` and ``bitrate``."""

    def __init__(self, language="en-US", channels="2.0", bitrate=200_000):
        self.language, self.channels, self.bitrate = language, channels, bitrate
class Subtitle:
    """Stand-in subtitle track: a ``language`` and no bitrate."""

    def __init__(self, language="fr"):
        # ``bitrate`` is None so the weight helper uses its no-bitrate fallback.
        self.bitrate = None
        self.language = language
def _noop(**kwargs):
pass
def test_track_progress_label():
    """Labels combine the track kind with its identifying attributes."""
    expectations = [
        (Video(2160, "DV"), "video 2160p DV"),
        (Video(1080, "HDR10+"), "video 1080p HDR10+"),
        (Audio("en-US", "5.1"), "audio en-US 5.1"),
        (Subtitle("ro"), "subtitle ro"),
    ]
    for track, label in expectations:
        assert track_progress_label(track) == label
def test_weight_video_over_audio_over_subtitle():
    """Media tracks weigh their bitrate; a bitrate-less subtitle weighs far less."""
    video_weight = track_progress_weight(Video(bitrate=4_000_000))
    audio_weight = track_progress_weight(Audio(bitrate=200_000))
    assert video_weight == 4_000_000
    assert audio_weight == 200_000
    # The subtitle has no bitrate, so it gets the small fixed fallback weight.
    assert track_progress_weight(Subtitle()) < audio_weight
def test_weighting_makes_video_dominate_progress():
    """A finished subtitle barely moves the bar; a half-done video moves it a lot."""
    seen: list[dict] = []
    video_cb, subtitle_cb = build_job_progress_callables(
        [Video(bitrate=4_000_000), Subtitle()], [_noop, _noop], seen.append
    )

    # Subtitle fully done while the video is untouched: progress stays tiny.
    subtitle_cb(downloaded="Downloaded")
    after_subtitle = seen[-1]
    assert after_subtitle["completed_tracks"] == 1
    assert after_subtitle["progress"] < 5.0

    # Video half done: the aggregate is dominated by the video's weight
    # (scaled into the 0..ceiling download band).
    video_cb(total=100, completed=50)
    assert seen[-1]["progress"] > 40.0
def test_active_tracks_labels_reported_and_cleared_on_done():
    """``active_tracks`` lists started-but-unfinished tracks and drops finished ones."""
    seen: list[dict] = []
    video_label, audio_label = "video 2160p DV", "audio en-US 2.0"
    video_cb, audio_cb = build_job_progress_callables(
        [Video(2160, "DV"), Audio("en-US", "2.0")], [_noop, _noop], seen.append
    )

    # Only the video has reported so far.
    video_cb(total=100, completed=10)
    assert seen[-1]["active_tracks"] == [video_label]
    assert seen[-1]["phase"] == "downloading video 2160p DV"

    # Audio joins; labels keep track order.
    audio_cb(total=100, completed=10)
    assert seen[-1]["active_tracks"] == [video_label, audio_label]

    # Video reaches a terminal state and leaves the active list.
    video_cb(downloaded="Downloaded")
    assert seen[-1]["active_tracks"] == [audio_label]
def test_aggregate_progress_is_monotonic_with_counts():
    """Progress never decreases, counts are reported, and every inner callable is forwarded to."""
    seen: list[dict] = []
    inner_calls = [0, 0, 0]

    def counting_inner(slot):
        # Each inner callable only records how often the wrapper forwarded to it.
        def inner(**kwargs):
            inner_calls[slot] += 1

        return inner

    tracks = [Video(bitrate=1000), Audio(bitrate=1000), Subtitle()]
    inners = [counting_inner(slot) for slot in range(3)]
    cbs = build_job_progress_callables(tracks, inners, seen.append)
    assert len(cbs) == 3

    first, second, _unused = cbs
    first(total=100, completed=50)
    first(downloaded="Downloaded")
    second(total=100, completed=50)

    reported = [update["progress"] for update in seen]
    assert reported == sorted(reported)
    assert seen[-1]["completed_tracks"] == 1
    assert seen[-1]["total_tracks"] == 3
    assert inner_calls == [2, 1, 0]
def test_all_tracks_done_reaches_download_ceiling():
    """With every track terminal the aggregate equals the download ceiling, not 100."""
    # Downloads only fill up to the ceiling; later stages carry the job to 100.
    seen: list[dict] = []
    first, second = build_job_progress_callables(
        [Audio(bitrate=1000), Audio(bitrate=1000)], [_noop, _noop], seen.append
    )

    first(total=10, completed=10, downloaded="Downloaded")
    after_first = seen[-1]
    assert after_first["progress"] < DOWNLOAD_PROGRESS_CEILING
    assert after_first["completed_tracks"] == 1

    second(total=10, completed=10, downloaded="Decrypted")
    after_second = seen[-1]
    assert after_second["progress"] == pytest.approx(DOWNLOAD_PROGRESS_CEILING)
    assert after_second["completed_tracks"] == 2
def test_finished_track_does_not_dip_when_callable_reused_for_decrypt():
    """Reusing a track's callable for decrypt must not pull the aggregate back down.

    The track reaches 100% before any terminal string arrives, then decrypt restarts the
    counts at ``completed=0``; the reported progress has to hold rather than dip.
    """
    seen: list[dict] = []
    reused, _idle = build_job_progress_callables(
        [Video(bitrate=1000), Video(bitrate=1000)], [_noop, _noop], seen.append
    )

    steps = [
        {"total": 100, "completed": 100},  # download complete -> half of the download band
        {"total": 200, "completed": 0},  # decrypt resets the counts, no terminal string yet
        {"total": 200, "completed": 100},  # decrypt mid-way
        {"total": 200, "completed": 200, "downloaded": "Decrypted"},  # terminal
    ]
    for step in steps:
        reused(**step)

    reported = [update["progress"] for update in seen]
    assert reported == sorted(reported)  # monotonic, no dip
    assert seen[-1]["progress"] == pytest.approx(DOWNLOAD_PROGRESS_CEILING / 2)
    assert seen[-1]["completed_tracks"] == 1
def test_skipped_subtitle_counts_as_done():
    """A "[yellow]SKIPPED" subtitle is terminal: it counts as completed and fills its share."""
    seen: list[dict] = []
    (subtitle_cb,) = build_job_progress_callables([Subtitle()], [_noop], seen.append)

    subtitle_cb(downloaded="[yellow]SKIPPED")

    final = seen[-1]
    assert final["completed_tracks"] == 1
    assert final["progress"] == pytest.approx(DOWNLOAD_PROGRESS_CEILING)

View File

@@ -1143,6 +1143,7 @@ class dl:
split_audio: Optional[bool] = None,
real_video_bitrate: bool = False,
real_audio_bitrate: bool = False,
progress_sink: Optional[Callable[[dict[str, Any]], None]] = None,
*_: Any,
**__: Any,
) -> None:
@@ -2214,6 +2215,13 @@ class dl:
selected_tracks, tracks_progress_callables = title.tracks.tree(add_progress=True)
if progress_sink is not None:
from unshackle.core.api.progress import build_job_progress_callables
tracks_progress_callables = build_job_progress_callables(
list(title.tracks), tracks_progress_callables, progress_sink
)
for track in title.tracks:
if hasattr(track, "needs_drm_loading") and track.needs_drm_loading:
track.load_drm_if_needed(service)
@@ -2470,6 +2478,8 @@ class dl:
break
# Now repack the decrypted tracks
if progress_sink and any(getattr(t, "needs_repack", False) for t in title.tracks):
progress_sink({"phase": "repackaging", "progress": 92.0, "status": "downloading", "active_tracks": []})
with console.status("Repackaging tracks with FFMPEG..."):
has_repacked = False
for track in title.tracks:
@@ -2636,6 +2646,8 @@ class dl:
for video_track in title.tracks.videos or [None]:
mux_video_standalone(video_track)
if progress_sink:
progress_sink({"phase": "muxing", "progress": 96.0, "status": "downloading", "active_tracks": []})
try:
with Live(Padding(progress, (0, 5, 1, 5)), console=console):
mux_index = 0

View File

@@ -108,8 +108,11 @@ class DownloadJob:
error_traceback: Optional[str] = None
worker_stderr: Optional[str] = None
# Human-readable current phase (e.g. "downloading video 1080p")
# Current phase, track counts, and labels of the tracks downloading now.
phase: Optional[str] = None
completed_tracks: int = 0
total_tracks: int = 0
active_tracks: List[str] = field(default_factory=list)
# Subtitles skipped under skip_subtitle_errors (non-fatal). Each entry is a dl.SkippedSubtitle
# dict (id / language / title) so a client can report which weren't available.
@@ -128,6 +131,9 @@ class DownloadJob:
"title_id": self.title_id,
"progress": self.progress,
"phase": self.phase,
"completed_tracks": self.completed_tracks,
"total_tracks": self.total_tracks,
"active_tracks": self.active_tracks,
"skipped_subtitles": self.skipped_subtitles,
}
@@ -175,6 +181,7 @@ def _perform_download(
import yaml
from unshackle.commands.dl import dl
from unshackle.core.api.errors import APIError, APIErrorCode
from unshackle.core.config import config
from unshackle.core.services import Services
from unshackle.core.tracks import Subtitle, Video
@@ -359,69 +366,16 @@ def _perform_download(
stdout_capture = StringIO()
stderr_capture = StringIO()
# Simple progress tracking if callback provided
# The progress_sink passed to dl.result (wrapped there by build_job_progress_callables from
# unshackle.core.api.progress) owns the percentage; only status changes are emitted here.
if progress_callback:
# Report initial progress
progress_callback({"progress": 0.0, "status": "starting"})
# Tee each Track.download's progress callable so the downloader's live percentage
# is forwarded to the API job (not just 5%/100%), and expose which track is being
# downloaded now as a human-readable phase.
from unshackle.core.tracks.track import Track as _Track
if not getattr(_Track, "_api_progress_patched", False):
_orig_track_download = _Track.download
def _download_with_progress(self, *args, **kwargs):
inner_progress = kwargs.get("progress")
track_type = type(self).__name__
phase = {
"Video": "downloading video",
"Audio": "downloading audio",
"Subtitle": "downloading subtitle",
}.get(track_type, f"downloading {track_type.lower()}")
height = getattr(self, "height", None)
language = getattr(self, "language", None)
if height:
phase += f" {height}p"
elif track_type in ("Audio", "Subtitle") and language:
phase += f" {language}"
progress_callback({"phase": phase, "status": "downloading"})
if callable(inner_progress):
counts = {"completed": 0.0, "total": 0.0}
def tee(*tee_args, **tee_kwargs):
if tee_kwargs.get("total"):
counts["total"] = tee_kwargs["total"]
if tee_kwargs.get("completed") is not None:
counts["completed"] = tee_kwargs["completed"]
if "advance" in tee_kwargs:
counts["completed"] += tee_kwargs["advance"]
pct = counts["completed"] * 100.0 / counts["total"] if counts["total"] else 0
if pct:
progress_callback(
{"progress": min(99.0, float(pct)), "phase": phase, "status": "downloading"}
)
return inner_progress(*tee_args, **tee_kwargs)
kwargs["progress"] = tee
return _orig_track_download(self, *args, **kwargs)
_Track.download = _download_with_progress
_Track._api_progress_patched = True
original_result = dl_instance.result
def result_with_progress(*args, **kwargs):
try:
# Report that download started
progress_callback({"progress": 5.0, "status": "downloading"})
# Call original method
progress_callback({"status": "downloading"})
result = original_result(*args, **kwargs)
# Report completion
progress_callback({"progress": 100.0, "status": "completed"})
return result
except Exception as e:
@@ -481,6 +435,7 @@ def _perform_download(
worst=params.get("worst", False),
best_available=params.get("best_available", False),
split_audio=params.get("split_audio"),
progress_sink=progress_callback,
)
except SystemExit as exc:
@@ -490,7 +445,7 @@ def _perform_download(
log.error(f"Download exited with code {exc.code}")
log.error(f"Stdout: {stdout_str}")
log.error(f"Stderr: {stderr_str}")
raise Exception(f"Download failed with exit code {exc.code}")
raise APIError(APIErrorCode.DOWNLOAD_ERROR, f"Download failed with exit code {exc.code}")
except Exception as exc: # noqa: BLE001 - propagate to caller
stdout_str = stdout_capture.getvalue()
@@ -504,7 +459,7 @@ def _perform_download(
# It sets download_failed in that case, so the job isn't reported as completed with no output.
if getattr(dl_instance, "download_failed", False):
detail = (stdout_capture.getvalue() + stderr_capture.getvalue())[-200:].strip()
raise Exception("download worker failed: " + (detail or "see logs"))
raise APIError(APIErrorCode.WORKER_ERROR, "download worker failed: " + (detail or "see logs"))
# Surface any subtitles that were skipped (non-fatal failures) so the client can report them.
if progress_callback:
@@ -796,6 +751,12 @@ class DownloadQueueManager:
progress_data = json.load(handle)
if progress_data.get("phase") and progress_data["phase"] != job.phase:
job.phase = progress_data["phase"]
if progress_data.get("total_tracks"):
job.total_tracks = int(progress_data["total_tracks"])
if progress_data.get("completed_tracks") is not None:
job.completed_tracks = int(progress_data["completed_tracks"])
if "active_tracks" in progress_data:
job.active_tracks = list(progress_data["active_tracks"])
if progress_data.get("skipped_subtitles"):
job.skipped_subtitles = progress_data["skipped_subtitles"]
if "progress" in progress_data:

View File

@@ -957,13 +957,13 @@ def validate_download_parameters(data: Dict[str, Any]) -> Optional[str]:
return "Cannot use both s_lang and require_subs"
if "range" in data and data["range"]:
valid_ranges = ["SDR", "HDR10", "HDR10+", "DV", "HLG", "HYBRID"]
if isinstance(data["range"], list):
for r in data["range"]:
if r.upper() not in valid_ranges:
return f"Invalid range value: {r}. Must be one of: {', '.join(valid_ranges)}"
elif data["range"].upper() not in valid_ranges:
return f"Invalid range value: {data['range']}. Must be one of: {', '.join(valid_ranges)}"
# "HDR10P" is the canonical range value ("+" is awkward in scripts); "HDR10+" stays valid.
valid_ranges = ["SDR", "HDR10", "HDR10P", "DV", "HLG", "HYBRID"]
accepted = {*valid_ranges, "HDR10+"}
values = data["range"] if isinstance(data["range"], list) else [data["range"]]
for r in values:
if r.upper() not in accepted:
return f"Invalid range value: {r}. Must be one of: {', '.join(valid_ranges)}"
return None

View File

@@ -0,0 +1,118 @@
"""Aggregate job-level download progress for the REST API.
Turns the per-track progress callables from ``Tracks.tree`` into one signal a job can report:
a bitrate-weighted percentage, track counts, and the labels of the tracks downloading now.
"""
from __future__ import annotations
from typing import Any, Callable
from unshackle.core.constants import AnyTrack
JOB_PROGRESS_TERMINAL_STATES = {"Downloaded", "Decrypted", "[yellow]SKIPPED"}
# Weight for a track with no bitrate (subtitles); small vs media bitrates so subs barely move the bar.
SUBTITLE_PROGRESS_WEIGHT = 50_000.0
# Downloads fill 0..this; dl.result drives the remainder (repackaging, muxing) up to 100.
DOWNLOAD_PROGRESS_CEILING = 90.0
def track_progress_label(track: AnyTrack) -> str:
    """Return a short label for a track, e.g. "video 2160p DV", "audio en-US 5.1", "subtitle fr".

    The track kind is taken from the class name (``type(track).__name__``), so any object
    whose class is named ``Video``/``Audio``/``Subtitle`` is labelled by these rules.
    Attributes are read with ``getattr`` and skipped when missing or falsy. Any other class
    is labelled with its lower-cased class name only.

    Args:
        track: The track to describe.

    Returns:
        The lower-cased kind followed by the available details, space separated.
    """
    track_type = type(track).__name__
    if track_type == "Video":
        height = getattr(track, "height", None)
        track_range = getattr(track, "range", None)
        # Use ``range.value`` when present; fall back to the object itself so a plain
        # string range does not raise AttributeError.
        range_value = getattr(track_range, "value", track_range)
        details = [f"{height}p" if height else None, range_value]
    elif track_type == "Audio":
        details = [getattr(track, "language", None), getattr(track, "channels", None)]
    elif track_type == "Subtitle":
        details = [getattr(track, "language", None)]
    else:
        return track_type.lower()
    # str() every detail so a non-string value (language object, numeric channels,
    # non-str range value) cannot break the join; falsy details are dropped so the
    # label never has a trailing or doubled space.
    return " ".join([track_type.lower(), *(str(detail) for detail in details if detail)])
def track_progress_weight(track: AnyTrack) -> float:
    """Return the track's weight in the aggregate progress.

    A track with a truthy ``bitrate`` weighs that bitrate as a float. A track with no
    ``bitrate`` attribute, or a falsy one, weighs ``SUBTITLE_PROGRESS_WEIGHT``.
    """
    bitrate = getattr(track, "bitrate", None)
    if not bitrate:
        return SUBTITLE_PROGRESS_WEIGHT
    return float(bitrate)
def build_job_progress_callables(
    tracks: list[AnyTrack],
    inner_callables: list[Callable[..., None]],
    sink: Callable[[dict[str, Any]], None],
) -> list[Callable[..., None]]:
    """Wrap each track's progress callable so ``sink`` receives aggregate job progress.

    Every wrapper updates its own track's state, sends one aggregate update to ``sink``,
    then forwards the call unchanged to the original ``inner`` callable.

    Args:
        tracks: Tracks used to derive weights and labels; expected to be index-aligned
            with ``inner_callables``.
        inner_callables: The per-track progress callables to wrap, one per track.
        sink: Receives a dict with ``progress`` (weighted mean completion scaled to
            ``0..DOWNLOAD_PROGRESS_CEILING``), ``phase``, ``completed_tracks``,
            ``total_tracks``, ``active_tracks`` and ``status`` on every wrapped call.

    Returns:
        One wrapper per entry of ``inner_callables``, in the same order.

    Each track's fraction only ever increases, so the reported percentage only climbs.
    """
    total = len(inner_callables)
    weights = [track_progress_weight(t) for t in tracks]
    labels = [track_progress_label(t) for t in tracks]
    total_weight = sum(weights) or 1.0  # avoid dividing by zero when there are no tracks
    fractions = [0.0] * total
    done = [False] * total
    started = [False] * total

    def emit() -> None:
        """Compute the aggregate state and hand it to ``sink``."""
        completed = sum(done)
        # Downloads fill 0..DOWNLOAD_PROGRESS_CEILING; later stages drive it up to 100.
        progress = sum(w * f for w, f in zip(weights, fractions)) * DOWNLOAD_PROGRESS_CEILING / total_weight
        # Active means the track has reported at least once and is not yet terminal.
        active = [
            label
            for label, is_started, is_done in zip(labels, started, done)
            if is_started and not is_done
        ]
        if active:
            # Name at most three tracks in the phase text; summarise the rest.
            phase = "downloading " + ", ".join(active[:3])
            if len(active) > 3:
                phase += f" (+{len(active) - 3} more)"
        else:
            phase = f"downloading {completed}/{total} tracks"
        sink(
            {
                "progress": progress,
                "phase": phase,
                "completed_tracks": completed,
                "total_tracks": total,
                "active_tracks": active,
                "status": "downloading",
            }
        )

    def wrap(index: int, inner: Callable[..., None]) -> Callable[..., None]:
        """Build the wrapper for the track at ``index``."""
        counts = {"completed": 0.0, "total": 0.0}

        def tee(*args: Any, **kwargs: Any) -> None:
            started[index] = True
            if kwargs.get("total"):
                counts["total"] = kwargs["total"]
            if kwargs.get("completed") is not None:
                counts["completed"] = kwargs["completed"]
            # ``advance=None`` means "no change"; adding it would raise TypeError and
            # break progress reporting for the whole job.
            advance = kwargs.get("advance")
            if advance is not None:
                counts["completed"] += advance
            if kwargs.get("downloaded") in JOB_PROGRESS_TERMINAL_STATES:
                done[index] = True
                fractions[index] = 1.0
            elif counts["total"]:
                # max() keeps the fraction monotonic across the download->decrypt callable reuse.
                fractions[index] = max(fractions[index], min(1.0, counts["completed"] / counts["total"]))
            emit()
            return inner(*args, **kwargs)

        return tee

    return [wrap(i, inner) for i, inner in enumerate(inner_callables)]

View File

@@ -233,9 +233,7 @@ async def services(request: web.Request) -> web.Response:
or getattr(service_module, "get_playready_license", None) is not _BaseService.get_playready_license
)
# Auth methods the service accepts. Prefer an explicit `AUTH_METHODS` class var
# (reliable); otherwise fall back to inferring from what authenticate() references
# - that mostly returns both because services call super().authenticate(...).
# Prefer the service's explicit AUTH_METHODS; otherwise infer from authenticate().
methods = []
if service_data["needs_auth"]:
declared = getattr(service_module, "AUTH_METHODS", None)

View File

@@ -95,6 +95,9 @@ class Service(metaclass=ABCMeta):
GEOFENCE: tuple[str, ...] = () # list of ip regions required to use the service. empty list == no specific region.
# vault namespace override; when set, key vault read/write uses this tag instead of the service's own.
VAULT_TAG: Optional[str] = None
# Auth methods the service accepts ("cookies"/"credentials"); when None the REST /services
# endpoint infers them from authenticate().
AUTH_METHODS: Optional[tuple[str, ...]] = None
def __init__(self, ctx: click.Context):
console.print(Padding(Rule(f"[rule.text]Service: {self.__class__.__name__}"), (1, 2)))