diff --git a/unshackle/core/manifests/dash.py b/unshackle/core/manifests/dash.py index fe338c0..3ee61fd 100644 --- a/unshackle/core/manifests/dash.py +++ b/unshackle/core/manifests/dash.py @@ -107,13 +107,7 @@ class DASH: if period_id := period.get("id"): filtered_period_ids.append(period_id) continue - if next(iter(period.xpath("SegmentType/@value")), "content") != "content": - if period_id := period.get("id"): - filtered_period_ids.append(period_id) - continue - if "urn:amazon:primevideo:cachingBreadth" in [ - x.get("schemeIdUri") for x in period.findall("SupplementalProperty") - ]: + if not DASH._is_content_period(period, []): if period_id := period.get("id"): filtered_period_ids.append(period_id) continue @@ -242,6 +236,7 @@ class DASH: "period": period, "adaptation_set": adaptation_set, "representation": rep, + "representation_id": rep.get("id"), "filtered_period_ids": filtered_period_ids, } }, @@ -278,9 +273,10 @@ class DASH: log = logging.getLogger("DASH") manifest: ElementTree = track.data["dash"]["manifest"] - period: Element = track.data["dash"]["period"] adaptation_set: Element = track.data["dash"]["adaptation_set"] representation: Element = track.data["dash"]["representation"] + rep_id: Optional[str] = track.data["dash"].get("representation_id") or representation.get("id") + filtered_period_ids: list[str] = track.data["dash"].get("filtered_period_ids", []) # Preserve existing DRM if it was set by the service, especially when service set Widevine # but manifest only contains PlayReady protection (common scenario for some services) @@ -321,174 +317,63 @@ class DASH: if kid not in drm_obj.content_keys: drm_obj.content_keys[kid] = key - manifest_base_url = manifest.findtext("BaseURL") - if not manifest_base_url: - manifest_base_url = track.url - elif not re.match("^https?://", manifest_base_url, re.IGNORECASE): - manifest_base_url = urljoin(track.url, f"./{manifest_base_url}") - period_base_url = urljoin(manifest_base_url, period.findtext("BaseURL") or "") - adaptation_set_base_url = urljoin(period_base_url, adaptation_set.findtext("BaseURL") or "") - rep_base_url = urljoin(adaptation_set_base_url, representation.findtext("BaseURL") or "") - - period_duration = period.get("duration") or manifest.get("mediaPresentationDuration") - init_data: Optional[bytes] = None - - segment_template = representation.find("SegmentTemplate") - if segment_template is None: - segment_template = adaptation_set.find("SegmentTemplate") - - segment_list = representation.find("SegmentList") - if segment_list is None: - segment_list = adaptation_set.find("SegmentList") - - segment_base = representation.find("SegmentBase") - if segment_base is None: - segment_base = adaptation_set.find("SegmentBase") - + # Collect segments from all content periods in the manifest + all_periods = manifest.findall("Period") segments: list[tuple[str, Optional[str]]] = [] - segment_timescale: float = 0 segment_durations: list[int] = [] + segment_timescale: float = 0 + init_data: Optional[bytes] = None track_kid: Optional[UUID] = None - if segment_template is not None: - segment_template = copy(segment_template) - start_number = int(segment_template.get("startNumber") or 1) - end_number = int(segment_template.get("endNumber") or 0) or None - segment_timeline = segment_template.find("SegmentTimeline") - segment_timescale = float(segment_template.get("timescale") or 1) + content_periods = [p for p in all_periods if DASH._is_content_period(p, filtered_period_ids)] + period_count = len(content_periods) - for item in ("initialization", "media"): - value = segment_template.get(item) - if not value: + if period_count > 1: + log.info(f"Multi-period manifest detected with {period_count} content periods") + + for period_idx, content_period in enumerate(content_periods): + # Find the matching representation in this period + matched_rep = None + matched_as = None + for as_ in content_period.findall("AdaptationSet"): + if DASH.is_trick_mode(as_): continue - if not re.match("^https?://", value, re.IGNORECASE): - if not rep_base_url: - raise ValueError("Resolved Segment URL is not absolute, and no Base URL is available.") - value = urljoin(rep_base_url, value) - if not urlparse(value).query: - manifest_url_query = urlparse(track.url).query - if manifest_url_query: - value += f"?{manifest_url_query}" - segment_template.set(item, value) + for rep in as_.findall("Representation"): + if rep.get("id") == rep_id: + matched_rep = rep + matched_as = as_ + break + if matched_rep is not None: + break - init_url = segment_template.get("initialization") - if init_url: - res = session.get( - DASH.replace_fields( - init_url, Bandwidth=representation.get("bandwidth"), RepresentationID=representation.get("id") - ) - ) - res.raise_for_status() - init_data = res.content - track_kid = track.get_key_id(init_data) + if matched_rep is None or matched_as is None: + period_id = content_period.get("id", period_idx) + log.warning(f"Representation '{rep_id}' not found in period '{period_id}', skipping") + continue - if segment_timeline is not None: - current_time = 0 - for s in segment_timeline.findall("S"): - if s.get("t"): - current_time = int(s.get("t")) - for _ in range(1 + (int(s.get("r") or 0))): - segment_durations.append(current_time) - current_time += int(s.get("d")) + p_init, p_segments, p_timescale, p_durations, p_kid = DASH._get_period_segments( + period=content_period, + adaptation_set=matched_as, + representation=matched_rep, + manifest=manifest, + track=track, + track_url=track.url, + session=session, + ) - if not end_number: - end_number = len(segment_durations) - # Handle high startNumber in DVR/catch-up manifests where startNumber > segment count - if start_number > end_number: - end_number = start_number + len(segment_durations) - 1 - - for t, n in zip(segment_durations, range(start_number, end_number + 1)): - segments.append( - ( - DASH.replace_fields( - segment_template.get("media"), - Bandwidth=representation.get("bandwidth"), - Number=n, - RepresentationID=representation.get("id"), - Time=t, - ), - None, - ) - ) + if period_idx == 0: + # First period: use its init data and KID for DRM licensing + init_data = p_init + track_kid = p_kid + segment_timescale = p_timescale else: - if not period_duration: - raise ValueError("Duration of the Period was unable to be determined.") - period_duration = DASH.pt_to_sec(period_duration) - segment_duration = float(segment_template.get("duration")) or 1 + if p_kid and track_kid and p_kid != track_kid: + log.debug(f"Period {content_period.get('id', period_idx)} has different KID: {p_kid}") - if not end_number: - segment_count = math.ceil(period_duration / (segment_duration / segment_timescale)) - end_number = start_number + segment_count - 1 + segments.extend(p_segments) + segment_durations.extend(p_durations) - for s in range(start_number, end_number + 1): - segments.append( - ( - DASH.replace_fields( - segment_template.get("media"), - Bandwidth=representation.get("bandwidth"), - Number=s, - RepresentationID=representation.get("id"), - Time=s, - ), - None, - ) - ) - # TODO: Should we floor/ceil/round, or is int() ok? - segment_durations.append(int(segment_duration)) - elif segment_list is not None: - segment_timescale = float(segment_list.get("timescale") or 1) - - init_data = None - initialization = segment_list.find("Initialization") - if initialization is not None: - source_url = initialization.get("sourceURL") - if not source_url: - source_url = rep_base_url - elif not re.match("^https?://", source_url, re.IGNORECASE): - source_url = urljoin(rep_base_url, f"./{source_url}") - - if initialization.get("range"): - init_range_header = {"Range": f"bytes={initialization.get('range')}"} - else: - init_range_header = None - - res = session.get(url=source_url, headers=init_range_header) - res.raise_for_status() - init_data = res.content - track_kid = track.get_key_id(init_data) - - segment_urls = segment_list.findall("SegmentURL") - for segment_url in segment_urls: - media_url = segment_url.get("media") - if not media_url: - media_url = rep_base_url - elif not re.match("^https?://", media_url, re.IGNORECASE): - media_url = urljoin(rep_base_url, f"./{media_url}") - - segments.append((media_url, segment_url.get("mediaRange"))) - segment_durations.append(int(segment_url.get("duration") or 1)) - elif segment_base is not None: - media_range = None - init_data = None - initialization = segment_base.find("Initialization") - if initialization is not None: - if initialization.get("range"): - init_range_header = {"Range": f"bytes={initialization.get('range')}"} - else: - init_range_header = None - - res = session.get(url=rep_base_url, headers=init_range_header) - res.raise_for_status() - init_data = res.content - track_kid = track.get_key_id(init_data) - total_size = res.headers.get("Content-Range", "").split("/")[-1] - if total_size: - media_range = f"{len(init_data)}-{total_size}" - - segments.append((rep_base_url, media_range)) - elif rep_base_url: - segments.append((rep_base_url, None)) - else: + if not segments: log.error("Could not find a way to get segments from this MPD manifest.") log.debug(track.url) sys.exit(1) @@ -682,6 +567,212 @@ class DASH: progress(downloaded="Downloaded") + @staticmethod + def _is_content_period(period: Element, filtered_period_ids: list[str]) -> bool: + """Check if a period is a valid content period (not an ad, not filtered, not trick mode).""" + period_id = period.get("id") + if period_id and period_id in filtered_period_ids: + return False + if next(iter(period.xpath("SegmentType/@value")), "content") != "content": + return False + if "urn:amazon:primevideo:cachingBreadth" in [ + x.get("schemeIdUri") for x in period.findall("SupplementalProperty") + ]: + return False + return True + + @staticmethod + def _get_period_segments( + period: Element, + adaptation_set: Element, + representation: Element, + manifest: ElementTree, + track: AnyTrack, + track_url: str, + session: Union[Session, CurlSession], + ) -> tuple[ + Optional[bytes], + list[tuple[str, Optional[str]]], + float, + list[int], + Optional[UUID], + ]: + """ + Extract segments from a single period's representation. + + Returns: + A tuple of (init_data, segments, segment_timescale, segment_durations, track_kid). + """ + manifest_base_url = manifest.findtext("BaseURL") + if not manifest_base_url: + manifest_base_url = track_url + elif not re.match("^https?://", manifest_base_url, re.IGNORECASE): + manifest_base_url = urljoin(track_url, f"./{manifest_base_url}") + period_base_url = urljoin(manifest_base_url, period.findtext("BaseURL") or "") + adaptation_set_base_url = urljoin(period_base_url, adaptation_set.findtext("BaseURL") or "") + rep_base_url = urljoin(adaptation_set_base_url, representation.findtext("BaseURL") or "") + + period_duration = period.get("duration") or manifest.get("mediaPresentationDuration") + init_data: Optional[bytes] = None + + segment_template = representation.find("SegmentTemplate") + if segment_template is None: + segment_template = adaptation_set.find("SegmentTemplate") + + segment_list = representation.find("SegmentList") + if segment_list is None: + segment_list = adaptation_set.find("SegmentList") + + segment_base = representation.find("SegmentBase") + if segment_base is None: + segment_base = adaptation_set.find("SegmentBase") + + segments: list[tuple[str, Optional[str]]] = [] + segment_timescale: float = 0 + segment_durations: list[int] = [] + track_kid: Optional[UUID] = None + + if segment_template is not None: + segment_template = copy(segment_template) + start_number = int(segment_template.get("startNumber") or 1) + end_number = int(segment_template.get("endNumber") or 0) or None + segment_timeline = segment_template.find("SegmentTimeline") + segment_timescale = float(segment_template.get("timescale") or 1) + + for item in ("initialization", "media"): + value = segment_template.get(item) + if not value: + continue + if not re.match("^https?://", value, re.IGNORECASE): + if not rep_base_url: + raise ValueError("Resolved Segment URL is not absolute, and no Base URL is available.") + value = urljoin(rep_base_url, value) + if not urlparse(value).query: + manifest_url_query = urlparse(track_url).query + if manifest_url_query: + value += f"?{manifest_url_query}" + segment_template.set(item, value) + + init_url = segment_template.get("initialization") + if init_url: + res = session.get( + DASH.replace_fields( + init_url, Bandwidth=representation.get("bandwidth"), RepresentationID=representation.get("id") + ) + ) + res.raise_for_status() + init_data = res.content + track_kid = track.get_key_id(init_data) + + if segment_timeline is not None: + current_time = 0 + for s in segment_timeline.findall("S"): + if s.get("t"): + current_time = int(s.get("t")) + for _ in range(1 + (int(s.get("r") or 0))): + segment_durations.append(current_time) + current_time += int(s.get("d")) + + if not end_number: + end_number = len(segment_durations) + # Handle high startNumber in DVR/catch-up manifests where startNumber > segment count + if start_number > end_number: + end_number = start_number + len(segment_durations) - 1 + + for t, n in zip(segment_durations, range(start_number, end_number + 1)): + segments.append( + ( + DASH.replace_fields( + segment_template.get("media"), + Bandwidth=representation.get("bandwidth"), + Number=n, + RepresentationID=representation.get("id"), + Time=t, + ), + None, + ) + ) + else: + if not period_duration: + raise ValueError("Duration of the Period was unable to be determined.") + period_duration = DASH.pt_to_sec(period_duration) + segment_duration = float(segment_template.get("duration")) or 1 + + if not end_number: + segment_count = math.ceil(period_duration / (segment_duration / segment_timescale)) + end_number = start_number + segment_count - 1 + + for s in range(start_number, end_number + 1): + segments.append( + ( + DASH.replace_fields( + segment_template.get("media"), + Bandwidth=representation.get("bandwidth"), + Number=s, + RepresentationID=representation.get("id"), + Time=s, + ), + None, + ) + ) + # TODO: Should we floor/ceil/round, or is int() ok? + segment_durations.append(int(segment_duration)) + elif segment_list is not None: + segment_timescale = float(segment_list.get("timescale") or 1) + + init_data = None + initialization = segment_list.find("Initialization") + if initialization is not None: + source_url = initialization.get("sourceURL") + if not source_url: + source_url = rep_base_url + elif not re.match("^https?://", source_url, re.IGNORECASE): + source_url = urljoin(rep_base_url, f"./{source_url}") + + if initialization.get("range"): + init_range_header = {"Range": f"bytes={initialization.get('range')}"} + else: + init_range_header = None + + res = session.get(url=source_url, headers=init_range_header) + res.raise_for_status() + init_data = res.content + track_kid = track.get_key_id(init_data) + + segment_urls = segment_list.findall("SegmentURL") + for segment_url in segment_urls: + media_url = segment_url.get("media") + if not media_url: + media_url = rep_base_url + elif not re.match("^https?://", media_url, re.IGNORECASE): + media_url = urljoin(rep_base_url, f"./{media_url}") + + segments.append((media_url, segment_url.get("mediaRange"))) + segment_durations.append(int(segment_url.get("duration") or 1)) + elif segment_base is not None: + media_range = None + init_data = None + initialization = segment_base.find("Initialization") + if initialization is not None: + if initialization.get("range"): + init_range_header = {"Range": f"bytes={initialization.get('range')}"} + else: + init_range_header = None + + res = session.get(url=rep_base_url, headers=init_range_header) + res.raise_for_status() + init_data = res.content + track_kid = track.get_key_id(init_data) + total_size = res.headers.get("Content-Range", "").split("/")[-1] + if total_size: + media_range = f"{len(init_data)}-{total_size}" + + segments.append((rep_base_url, media_range)) + elif rep_base_url: + segments.append((rep_base_url, None)) + + return init_data, segments, segment_timescale, segment_durations, track_kid + @staticmethod def _get(item: str, adaptation_set: Element, representation: Optional[Element] = None) -> Optional[Any]: """Helper to get a requested item from the Representation, otherwise from the AdaptationSet."""