fix(api): repair REST API downloads, add /services flags & live progress (#113)

* feat(api): live download phase, granular progress, swallowed-failure detection, per-request CDM

- Tee Track.download progress so the job reports real percentage (not just 5/100%) and a
  human-readable phase ('downloading video 1080p') via the new job.phase field.
- Detect a swallowed download-worker failure (dl.result() prints 'Download Failed' but
  exits 0) and raise, so the job is marked failed instead of completed-with-no-output.
- Per-request CDM override (dl_instance.cdm_override; get_cdm prefers it) so a job can use
  a specific CDM device without mutating shared config.

* feat(api): expose service capability flags + auth methods in /services

needs_auth / has_search / has_drm derived from which Service hooks are overridden, and
auth_methods inferred from what the service's authenticate() body references (cookies /
credentials), so clients can show only the relevant auth options per service.

* feat(api): per-request credential injection for downloads

Accept a 'credential' ('user:pass') job parameter and feed it into the credentials map that
dl.get_credentials() reads, so a client can authenticate a download without persisting
anything to disk. (Kept on the deployment branch; the PR branch uses the client-sent path.)

* feat(api): gate per-request CDM override behind serve.cdm_overrides

A per-request `cdm` selects a server-side device, so honour it only when allow-listed.
`serve.cdm_overrides` opts in: a list permits those device names, or `true` permits any
(single trusted client). Unset/false rejects every override with 403, so an arbitrary device
can't be selected by default.

* fix(api): redact credentials and proxy userinfo in serialized job parameters

Job parameters can carry a raw user:pass credential and a proxy URL with embedded userinfo;
mask them wherever parameters are serialized for an API response so secrets don't leak via
the job-detail endpoint. Also read skipped-subtitle / download-failure state from the dl
instance instead of scraping stdout, and drop the dead n_m3u8dl percent branch.

* feat(api): prefer explicit AUTH_METHODS class var over source inference in /services

Inferring auth methods from authenticate() source mostly returns both options because services
call super().authenticate(cookies, credential). Prefer an explicit AUTH_METHODS class var when a
service declares one, falling back to inference.

* style: use plain hyphens instead of em-dashes in comments

* feat(api): gate per-job credentials, isolate caches, scrub error fields

Address review feedback on #113:
- Gate per-request credential/credentials behind serve.allow_job_credentials (default off,
  403 when not opted in), mirroring the existing serve.cdm_overrides CDM gate.
- Isolate the token cache per credential: when a per-job credential is set, namespace
  config.directories.cache by its hash in the worker, so two clients on the same service
  with different credentials can't share a cached session.
- Scrub the credential, its password half, and proxy userinfo out of error_message,
  error_details, error_traceback and worker_stderr before they leave via the job-detail API.
- Remove the unused _execute_download_sync in-process path (would have leaked one job's
  credential into the shared global config).
- Document serve.cdm_overrides and serve.allow_job_credentials in the example config.
- Add tests for both gates (403 default, allowlist pass) and the parameter/error redaction.

* fix(dl): flag download_failed when result() swallows a worker failure

dl.result() catches a download-worker exception, reports it, and returns
without re-raising so the CLI still exits cleanly. An embedding caller (the
API worker) had no way to tell the title actually failed and would report
the job completed with no output. Expose a download_failed flag, set in the
swallow path, that the worker reads after result() returns.

* feat(api): surface skipped subtitles and pass skip_subtitle_errors

Thread skip_subtitle_errors from the job into dl() so the API can opt into
non-fatal subtitle handling, and report which subtitles were skipped: store
them on the job (dl.SkippedSubtitle dicts) and include them in job details
so a client can tell the user which weren't available.

---------

Co-authored-by: Avi Cohen <avraham.coh770@gmail.com>
This commit is contained in:
AviDev
2026-06-08 20:38:08 +03:00
committed by GitHub
parent 79b884fb6b
commit 1a3cd09fc8
6 changed files with 390 additions and 14 deletions

View File

@@ -0,0 +1,142 @@
"""Unit tests for the /api/download security gates (per-request CDM + credential overrides)
and the secret redaction applied to job parameters and error/stderr fields."""
from __future__ import annotations
from datetime import datetime
import pytest
from aiohttp import web
from unshackle.core.api import handlers
from unshackle.core.api.download_manager import (
DownloadJob,
JobStatus,
_redact_parameters,
_redact_text,
_secret_values,
)
from unshackle.core.api.errors import APIError, APIErrorCode
pytestmark = pytest.mark.unit
# ---------- redaction ----------
def test_redact_parameters_masks_secrets_and_proxy_userinfo():
params = {
"service": "ATV",
"credential": "user:hunter2",
"password": "pw",
"token": "tok",
"api_key": "ak",
"proxy": "http://bob:secret@proxy.example:8080",
"quality": "1080p",
}
red = _redact_parameters(params)
assert red["credential"] == "***"
assert red["password"] == "***"
assert red["token"] == "***"
assert red["api_key"] == "***"
assert red["proxy"] == "http://***@proxy.example:8080"
assert red["quality"] == "1080p" # non-secret left intact
assert params["credential"] == "user:hunter2" # original dict not mutated
def test_redact_parameters_masks_credentials_dict():
assert _redact_parameters({"credentials": {"default": "u:p"}})["credentials"] == "***"
def test_secret_values_includes_password_half_and_dict_values():
secrets = _secret_values({"credential": "user:hunter2", "credentials": {"d": "alice:wonder"}})
assert "user:hunter2" in secrets # full credential
assert "hunter2" in secrets # password half of user:pass
assert "alice:wonder" in secrets # value from the credentials map
def test_redact_text_scrubs_credential_and_proxy_from_free_text():
params = {"credential": "user:hunter2", "proxy": "http://bob:secret@p:1"}
out = _redact_text("auth failed for user:hunter2 via http://bob:secret@p:1", params)
assert "hunter2" not in out
assert "bob:secret@" not in out
assert "***" in out
def test_redact_text_passthrough_without_secrets():
assert _redact_text("plain error", {}) == "plain error"
assert _redact_text(None, {}) is None
def test_to_dict_full_details_redacts_error_fields_and_parameters():
job = DownloadJob(
job_id="j1",
status=JobStatus.FAILED,
created_time=datetime(2026, 1, 1),
service="ATV",
title_id="t",
parameters={"credential": "user:hunter2"},
)
job.error_message = "login failed for user:hunter2"
job.worker_stderr = "Traceback ... user:hunter2 ..."
d = job.to_dict(include_full_details=True)
assert "hunter2" not in d["error_message"]
assert "hunter2" not in d["worker_stderr"]
assert d["parameters"]["credential"] == "***"
# ---------- gates ----------
class _PastGate(Exception):
"""Raised by the stubbed Services.load to prove a request got past the gate into the try block."""
@pytest.fixture
def stub_handler(monkeypatch):
"""Make the service valid and make the first call after the gate (Services.load) explode, so a
forbidden request raises APIError *before* the try block and an allowed one is caught inside it."""
monkeypatch.setattr(handlers, "validate_service", lambda tag, request=None: tag)
def _boom(*_args, **_kwargs):
raise _PastGate()
monkeypatch.setattr(handlers.Services, "load", _boom)
return monkeypatch
async def test_cdm_override_forbidden_by_default(stub_handler):
stub_handler.setattr(handlers.config, "serve", {})
with pytest.raises(APIError) as ei:
await handlers.download_handler({"service": "ATV", "title_id": "t", "cdm": "dev"})
assert ei.value.error_code == APIErrorCode.FORBIDDEN
async def test_cdm_override_allowed_when_enabled(stub_handler):
stub_handler.setattr(handlers.config, "serve", {"cdm_overrides": True})
# passing the gate reaches the stubbed Services.load, whose error is caught and returned as a response
resp = await handlers.download_handler({"service": "ATV", "title_id": "t", "cdm": "dev"})
assert isinstance(resp, web.Response)
async def test_cdm_override_allowlist_permits_only_named_device(stub_handler):
stub_handler.setattr(handlers.config, "serve", {"cdm_overrides": ["good"]})
assert isinstance(
await handlers.download_handler({"service": "ATV", "title_id": "t", "cdm": "good"}), web.Response
)
with pytest.raises(APIError) as ei:
await handlers.download_handler({"service": "ATV", "title_id": "t", "cdm": "other"})
assert ei.value.error_code == APIErrorCode.FORBIDDEN
async def test_credential_forbidden_by_default(stub_handler):
stub_handler.setattr(handlers.config, "serve", {})
with pytest.raises(APIError) as ei:
await handlers.download_handler({"service": "ATV", "title_id": "t", "credential": "u:p"})
assert ei.value.error_code == APIErrorCode.FORBIDDEN
async def test_credential_allowed_when_enabled(stub_handler):
stub_handler.setattr(handlers.config, "serve", {"allow_job_credentials": True})
resp = await handlers.download_handler({"service": "ATV", "title_id": "t", "credential": "u:p"})
assert isinstance(resp, web.Response)

View File

@@ -673,6 +673,10 @@ class dl:
# Subtitles skipped under --skip-subtitle-errors, recorded so an embedding caller can
# report which weren't available without parsing the console output. See SkippedSubtitle.
self.skipped_subtitles: list[SkippedSubtitle] = []
# result() catches a download-worker failure, reports it, and returns rather than
# re-raising (so the CLI still exits cleanly). Flag it so an embedding caller (the API
# worker) can tell the title did not complete instead of reading it as a success.
self.download_failed: bool = False
if not config.output_template:
raise click.ClickException(
@@ -2318,6 +2322,9 @@ class dl:
)
return
except Exception as e: # noqa
# Reported and swallowed (no re-raise) so the CLI exits cleanly; flag it so the
# API worker sees the title failed rather than completing with no output.
self.download_failed = True
error_messages = [
":x: Download Failed...",
f" {type(e).__name__}: {e}",
@@ -3455,7 +3462,9 @@ class dl:
Now supports quality-based selection when quality is provided.
Raises a ValueError if there's a problem getting a CDM.
"""
cdm_name = config.cdm.get(service) or config.cdm.get("default")
# A per-request override (set by the REST API per job) takes precedence over the
# global config, so a job can select a specific CDM device without mutating shared state.
cdm_name = getattr(self, "cdm_override", None) or config.cdm.get(service) or config.cdm.get("default")
if not cdm_name:
return None

View File

@@ -22,6 +22,60 @@ def _sanitize_log(value: object) -> str:
return str(value).replace("\n", "").replace("\r", "").replace("\x00", "")
# Job parameters may carry secrets (a raw "user:pass" credential, a proxy URL with embedded
# userinfo). These must never leave the process via the API or logs, so they are masked
# wherever parameters are serialized for a response.
_REDACTED = "***"
_SENSITIVE_PARAM_KEYS = ("credential", "credentials", "password", "token", "api_key")
_PROXY_USERINFO_RE = re.compile(r"(?<=://)[^/@]+@")
def _redact_parameters(parameters: Dict[str, Any]) -> Dict[str, Any]:
"""Return a copy of job parameters with secrets masked, safe to serialize."""
if not isinstance(parameters, dict):
return parameters
redacted = dict(parameters)
for key in _SENSITIVE_PARAM_KEYS:
if redacted.get(key):
redacted[key] = _REDACTED
proxy = redacted.get("proxy")
if isinstance(proxy, str) and "@" in proxy:
redacted["proxy"] = _PROXY_USERINFO_RE.sub(f"{_REDACTED}@", proxy)
return redacted
def _secret_values(parameters: Dict[str, Any]) -> List[str]:
"""Raw secret strings carried in job parameters, longest first, for scrubbing free text."""
if not isinstance(parameters, dict):
return []
secrets: List[str] = []
for key in ("credential", "password", "token", "api_key"):
value = parameters.get(key)
if isinstance(value, str) and value:
secrets.append(value)
if key == "credential" and ":" in value:
password = value.split(":", 1)[1]
if len(password) >= 4: # short passwords would blanket-replace and garble the text
secrets.append(password)
creds = parameters.get("credentials")
if isinstance(creds, dict):
secrets.extend(v for v in creds.values() if isinstance(v, str) and v)
elif isinstance(creds, str) and creds:
secrets.append(creds)
return sorted(set(secrets), key=len, reverse=True) # longest first so substrings don't survive
def _redact_text(text: Optional[str], parameters: Dict[str, Any]) -> Optional[str]:
"""Mask proxy userinfo and any known parameter secrets that leaked into a free-text field
(error message / details / traceback / worker stderr) before it is returned via the API."""
if not isinstance(text, str) or not text:
return text
text = _PROXY_USERINFO_RE.sub(f"{_REDACTED}@", text)
for secret in _secret_values(parameters):
text = text.replace(secret, _REDACTED)
return text
class JobStatus(Enum):
QUEUED = "queued"
DOWNLOADING = "downloading"
@@ -54,6 +108,13 @@ class DownloadJob:
error_traceback: Optional[str] = None
worker_stderr: Optional[str] = None
# Human-readable current phase (e.g. "downloading video 1080p")
phase: Optional[str] = None
# Subtitles skipped under skip_subtitle_errors (non-fatal). Each entry is a dl.SkippedSubtitle
# dict (id / language / title) so a client can report which weren't available.
skipped_subtitles: List[Dict[str, Any]] = field(default_factory=list)
# Cancellation support
cancel_event: threading.Event = field(default_factory=threading.Event)
@@ -66,20 +127,24 @@ class DownloadJob:
"service": self.service,
"title_id": self.title_id,
"progress": self.progress,
"phase": self.phase,
"skipped_subtitles": self.skipped_subtitles,
}
if include_full_details:
# Error/stderr/traceback are free text a service may have echoed a credential or proxy
# URL into, so scrub them with the same secrets that _redact_parameters masks.
result.update(
{
"parameters": self.parameters,
"parameters": _redact_parameters(self.parameters),
"started_time": self.started_time.isoformat() if self.started_time else None,
"completed_time": self.completed_time.isoformat() if self.completed_time else None,
"output_files": self.output_files,
"error_message": self.error_message,
"error_details": self.error_details,
"error_message": _redact_text(self.error_message, self.parameters),
"error_details": _redact_text(self.error_details, self.parameters),
"error_code": self.error_code,
"error_traceback": self.error_traceback,
"worker_stderr": self.worker_stderr,
"error_traceback": _redact_text(self.error_traceback, self.parameters),
"worker_stderr": _redact_text(self.worker_stderr, self.parameters),
}
)
@@ -118,6 +183,16 @@ def _perform_download(
log.info(f"Starting sync download for job {job_id}")
# A service caches tokens under cache/<Service>/, keyed by service name only, so two jobs on
# one service with different credentials would share a cache. When a per-job credential is set,
# namespace the cache dir by a hash of it so the sessions can't cross.
job_credential = params.get("credential")
if job_credential:
import hashlib
cred_hash = hashlib.sha256(job_credential.encode("utf-8")).hexdigest()[:12]
config.directories.cache = config.directories.cache / "_jobs" / cred_hash
# Convert string parameters to enums (API receives strings, dl.result() expects enums)
vcodec_raw = params.get("vcodec")
if vcodec_raw:
@@ -224,6 +299,19 @@ def _perform_download(
enrich=params.get("enrich", False),
output_dir=Path(params["output_dir"]) if params.get("output_dir") else None,
)
# Per-request CDM override (a device name in the WVDs dir); get_cdm() takes it first.
if params.get("cdm"):
dl_instance.cdm_override = params["cdm"]
# Per-request credential ("user:pass"); feed it into the map get_credentials() reads so a
# client can authenticate without anything being persisted to disk. Without a profile,
# get_credentials() falls back to "default", so store it there too rather than dropping it
# (which would silently authenticate as the server's own default account).
if params.get("credential"):
svc_creds = config.credentials.get(service)
if not isinstance(svc_creds, dict):
config.credentials[service] = svc_creds = {}
svc_creds[params.get("profile") or "default"] = params["credential"]
service_module = Services.load(service)
@@ -276,7 +364,53 @@ def _perform_download(
# Report initial progress
progress_callback({"progress": 0.0, "status": "starting"})
# Simple approach: report progress at key points
# Tee each Track.download's progress callable so the downloader's live percentage
# is forwarded to the API job (not just 5%/100%), and expose which track is being
# downloaded now as a human-readable phase.
from unshackle.core.tracks.track import Track as _Track
if not getattr(_Track, "_api_progress_patched", False):
_orig_track_download = _Track.download
def _download_with_progress(self, *args, **kwargs):
inner_progress = kwargs.get("progress")
track_type = type(self).__name__
phase = {
"Video": "downloading video",
"Audio": "downloading audio",
"Subtitle": "downloading subtitle",
}.get(track_type, f"downloading {track_type.lower()}")
height = getattr(self, "height", None)
language = getattr(self, "language", None)
if height:
phase += f" {height}p"
elif track_type in ("Audio", "Subtitle") and language:
phase += f" {language}"
progress_callback({"phase": phase, "status": "downloading"})
if callable(inner_progress):
counts = {"completed": 0.0, "total": 0.0}
def tee(*tee_args, **tee_kwargs):
if tee_kwargs.get("total"):
counts["total"] = tee_kwargs["total"]
if tee_kwargs.get("completed") is not None:
counts["completed"] = tee_kwargs["completed"]
if "advance" in tee_kwargs:
counts["completed"] += tee_kwargs["advance"]
pct = counts["completed"] * 100.0 / counts["total"] if counts["total"] else 0
if pct:
progress_callback(
{"progress": min(99.0, float(pct)), "phase": phase, "status": "downloading"}
)
return inner_progress(*tee_args, **tee_kwargs)
kwargs["progress"] = tee
return _orig_track_download(self, *args, **kwargs)
_Track.download = _download_with_progress
_Track._api_progress_patched = True
original_result = dl_instance.result
def result_with_progress(*args, **kwargs):
@@ -326,6 +460,7 @@ def _perform_download(
subs_only=params.get("subs_only", False),
chapters_only=params.get("chapters_only", False),
no_subs=params.get("no_subs", False),
skip_subtitle_errors=params.get("skip_subtitle_errors", False),
no_audio=params.get("no_audio", False),
no_chapters=params.get("no_chapters", False),
no_video=params.get("no_video", False),
@@ -365,6 +500,18 @@ def _perform_download(
log.error(f"Stderr: {stderr_str}")
raise
# dl.result() catches a download-worker exception, reports it, but returns normally (exit 0).
# It sets download_failed in that case, so the job isn't reported as completed with no output.
if getattr(dl_instance, "download_failed", False):
detail = (stdout_capture.getvalue() + stderr_capture.getvalue())[-200:].strip()
raise Exception("download worker failed: " + (detail or "see logs"))
# Surface any subtitles that were skipped (non-fatal failures) so the client can report them.
if progress_callback:
skipped_subs = getattr(dl_instance, "skipped_subtitles", None)
if skipped_subs:
progress_callback({"skipped_subtitles": list(skipped_subs)})
output_files = [str(p) for p in dl_instance.completed_files]
log.info(f"Download completed for job {job_id}, {len(output_files)} file(s) in {original_download_dir}")
@@ -647,6 +794,10 @@ class DownloadQueueManager:
if os.path.exists(progress_path):
with open(progress_path, "r", encoding="utf-8") as handle:
progress_data = json.load(handle)
if progress_data.get("phase") and progress_data["phase"] != job.phase:
job.phase = progress_data["phase"]
if progress_data.get("skipped_subtitles"):
job.skipped_subtitles = progress_data["skipped_subtitles"]
if "progress" in progress_data:
new_progress = float(progress_data["progress"])
if new_progress != job.progress:
@@ -670,14 +821,18 @@ class DownloadQueueManager:
stdout = stdout_bytes.decode("utf-8", errors="ignore")
stderr = stderr_bytes.decode("utf-8", errors="ignore")
# A service can echo a credential or a proxy URL into its output, so scrub it before
# it reaches the log as well, not only the API response.
safe_stdout = _redact_text(stdout.strip(), job.parameters)
safe_stderr = _redact_text(stderr.strip(), job.parameters)
if stdout.strip():
log.debug(f"Worker stdout for job {job.job_id}: {stdout.strip()}")
log.debug(f"Worker stdout for job {job.job_id}: {safe_stdout}")
if stderr.strip():
job.worker_stderr = stderr.strip()
if returncode != 0:
log.warning(f"Worker stderr for job {job.job_id}: {stderr.strip()}")
log.warning(f"Worker stderr for job {job.job_id}: {safe_stderr}")
else:
log.debug(f"Worker stderr for job {job.job_id}: {stderr.strip()}")
log.debug(f"Worker stderr for job {job.job_id}: {safe_stderr}")
result_data: Optional[Dict[str, Any]] = None
try:
@@ -719,10 +874,6 @@ class DownloadQueueManager:
except OSError:
pass
def _execute_download_sync(self, job: DownloadJob) -> List[str]:
"""Execute download synchronously using existing dl.py logic."""
return _perform_download(job.job_id, job.service, job.title_id, job.parameters.copy(), job.cancel_event)
async def _cleanup_worker(self):
"""Worker that periodically cleans up old jobs."""
while not self._shutdown_event.is_set():

View File

@@ -1005,6 +1005,32 @@ async def download_handler(data: Dict[str, Any], request: Optional[web.Request]
details={"service": normalized_service, "title_id": title_id},
)
# A per-request `cdm` selects a server-side device, so it is gated here rather than honoured
# blindly. `serve.cdm_overrides` opts in: a list permits only those device names, or `true`
# permits any (for a single trusted client). Unset/false rejects every override.
requested_cdm = data.get("cdm")
if requested_cdm:
allowed = (config.serve or {}).get("cdm_overrides")
permitted = allowed is True or (
isinstance(allowed, (list, tuple, set)) and requested_cdm in allowed
)
if not permitted:
raise APIError(
APIErrorCode.FORBIDDEN,
"The requested CDM is not permitted for API downloads.",
details={"cdm": requested_cdm},
)
# A per-request `credential` (or `credentials` map) authenticates the job with client-supplied
# secrets instead of the server-side credentials. Gate it behind `serve.allow_job_credentials`
# (default off) so a default deployment stays locked to its own credentials; mirrors the CDM gate.
if data.get("credential") or data.get("credentials"):
if not (config.serve or {}).get("allow_job_credentials"):
raise APIError(
APIErrorCode.FORBIDDEN,
"Per-request credentials are not permitted for API downloads.",
)
try:
# Load service module to extract service-specific parameter defaults
service_module = Services.load(normalized_service)

View File

@@ -221,6 +221,43 @@ async def services(request: web.Request) -> web.Response:
if service_module.__doc__:
service_data["help"] = service_module.__doc__.strip()
# Capability flags, derived from which Service hooks the service overrides.
from unshackle.core.service import Service as _BaseService
service_data["needs_auth"] = (
getattr(service_module, "authenticate", None) is not _BaseService.authenticate
)
service_data["has_search"] = getattr(service_module, "search", None) is not _BaseService.search
service_data["has_drm"] = (
getattr(service_module, "get_widevine_license", None) is not _BaseService.get_widevine_license
or getattr(service_module, "get_playready_license", None) is not _BaseService.get_playready_license
)
# Auth methods the service accepts. Prefer an explicit `AUTH_METHODS` class var
# (reliable); otherwise fall back to inferring from what authenticate() references
# - that mostly returns both because services call super().authenticate(...).
methods = []
if service_data["needs_auth"]:
declared = getattr(service_module, "AUTH_METHODS", None)
if declared:
methods = list(declared)
else:
try:
import inspect as _inspect
src_lines = _inspect.getsource(service_module.authenticate).splitlines()
start = next((i + 1 for i, ln in enumerate(src_lines) if ln.rstrip().endswith(":")), 1)
body = "\n".join(src_lines[start:])
if "cookies" in body:
methods.append("cookies")
if "credential" in body:
methods.append("credentials")
except (OSError, TypeError):
pass
if not methods:
methods = ["cookies"]
service_data["auth_methods"] = methods
except Exception as e:
log.warning(f"Could not load details for service {tag}: {e}")

View File

@@ -525,6 +525,17 @@ serve:
# playready_devices: # PlayReady device paths (auto-populated from directories.prds)
# - '/path/to/device.prd'
# Per-request override gates (both default OFF = locked to server-side config).
# A /api/download job may ask to use a specific server-side CDM device or to
# authenticate with client-supplied credentials; both are rejected with 403 unless
# explicitly opted in here, so a default deployment can't be driven to use arbitrary
# devices or inject credentials.
# cdm_overrides: true # true = allow any device; or a list of permitted device names:
# cdm_overrides:
# - generic_nexus_4464_l3
# allow_job_credentials: true # allow a per-request `credential` / `credentials` in a job
# # (each distinct credential gets an isolated token cache)
# Optional: any /api/download flag can be set here as a server-side default.
# Per-request body values still win. Useful for raising concurrency without
# changing every client call. Full list of accepted keys: see docs/API.md.