feat(dash): refactor segment extraction and add content period validation

This commit is contained in:
Andy
2026-03-20 12:47:49 -06:00
parent c323db9481
commit dc197af29e

View File

@@ -107,13 +107,7 @@ class DASH:
if period_id := period.get("id"): if period_id := period.get("id"):
filtered_period_ids.append(period_id) filtered_period_ids.append(period_id)
continue continue
if next(iter(period.xpath("SegmentType/@value")), "content") != "content": if not DASH._is_content_period(period, []):
if period_id := period.get("id"):
filtered_period_ids.append(period_id)
continue
if "urn:amazon:primevideo:cachingBreadth" in [
x.get("schemeIdUri") for x in period.findall("SupplementalProperty")
]:
if period_id := period.get("id"): if period_id := period.get("id"):
filtered_period_ids.append(period_id) filtered_period_ids.append(period_id)
continue continue
@@ -242,6 +236,7 @@ class DASH:
"period": period, "period": period,
"adaptation_set": adaptation_set, "adaptation_set": adaptation_set,
"representation": rep, "representation": rep,
"representation_id": rep.get("id"),
"filtered_period_ids": filtered_period_ids, "filtered_period_ids": filtered_period_ids,
} }
}, },
@@ -278,9 +273,10 @@ class DASH:
log = logging.getLogger("DASH") log = logging.getLogger("DASH")
manifest: ElementTree = track.data["dash"]["manifest"] manifest: ElementTree = track.data["dash"]["manifest"]
period: Element = track.data["dash"]["period"]
adaptation_set: Element = track.data["dash"]["adaptation_set"] adaptation_set: Element = track.data["dash"]["adaptation_set"]
representation: Element = track.data["dash"]["representation"] representation: Element = track.data["dash"]["representation"]
rep_id: Optional[str] = track.data["dash"].get("representation_id") or representation.get("id")
filtered_period_ids: list[str] = track.data["dash"].get("filtered_period_ids", [])
# Preserve existing DRM if it was set by the service, especially when service set Widevine # Preserve existing DRM if it was set by the service, especially when service set Widevine
# but manifest only contains PlayReady protection (common scenario for some services) # but manifest only contains PlayReady protection (common scenario for some services)
@@ -321,174 +317,63 @@ class DASH:
if kid not in drm_obj.content_keys: if kid not in drm_obj.content_keys:
drm_obj.content_keys[kid] = key drm_obj.content_keys[kid] = key
manifest_base_url = manifest.findtext("BaseURL") # Collect segments from all content periods in the manifest
if not manifest_base_url: all_periods = manifest.findall("Period")
manifest_base_url = track.url
elif not re.match("^https?://", manifest_base_url, re.IGNORECASE):
manifest_base_url = urljoin(track.url, f"./{manifest_base_url}")
period_base_url = urljoin(manifest_base_url, period.findtext("BaseURL") or "")
adaptation_set_base_url = urljoin(period_base_url, adaptation_set.findtext("BaseURL") or "")
rep_base_url = urljoin(adaptation_set_base_url, representation.findtext("BaseURL") or "")
period_duration = period.get("duration") or manifest.get("mediaPresentationDuration")
init_data: Optional[bytes] = None
segment_template = representation.find("SegmentTemplate")
if segment_template is None:
segment_template = adaptation_set.find("SegmentTemplate")
segment_list = representation.find("SegmentList")
if segment_list is None:
segment_list = adaptation_set.find("SegmentList")
segment_base = representation.find("SegmentBase")
if segment_base is None:
segment_base = adaptation_set.find("SegmentBase")
segments: list[tuple[str, Optional[str]]] = [] segments: list[tuple[str, Optional[str]]] = []
segment_timescale: float = 0
segment_durations: list[int] = [] segment_durations: list[int] = []
segment_timescale: float = 0
init_data: Optional[bytes] = None
track_kid: Optional[UUID] = None track_kid: Optional[UUID] = None
if segment_template is not None: content_periods = [p for p in all_periods if DASH._is_content_period(p, filtered_period_ids)]
segment_template = copy(segment_template) period_count = len(content_periods)
start_number = int(segment_template.get("startNumber") or 1)
end_number = int(segment_template.get("endNumber") or 0) or None
segment_timeline = segment_template.find("SegmentTimeline")
segment_timescale = float(segment_template.get("timescale") or 1)
for item in ("initialization", "media"): if period_count > 1:
value = segment_template.get(item) log.info(f"Multi-period manifest detected with {period_count} content periods")
if not value:
for period_idx, content_period in enumerate(content_periods):
# Find the matching representation in this period
matched_rep = None
matched_as = None
for as_ in content_period.findall("AdaptationSet"):
if DASH.is_trick_mode(as_):
continue continue
if not re.match("^https?://", value, re.IGNORECASE): for rep in as_.findall("Representation"):
if not rep_base_url: if rep.get("id") == rep_id:
raise ValueError("Resolved Segment URL is not absolute, and no Base URL is available.") matched_rep = rep
value = urljoin(rep_base_url, value) matched_as = as_
if not urlparse(value).query: break
manifest_url_query = urlparse(track.url).query if matched_rep is not None:
if manifest_url_query: break
value += f"?{manifest_url_query}"
segment_template.set(item, value)
init_url = segment_template.get("initialization") if matched_rep is None or matched_as is None:
if init_url: period_id = content_period.get("id", period_idx)
res = session.get( log.warning(f"Representation '{rep_id}' not found in period '{period_id}', skipping")
DASH.replace_fields( continue
init_url, Bandwidth=representation.get("bandwidth"), RepresentationID=representation.get("id")
)
)
res.raise_for_status()
init_data = res.content
track_kid = track.get_key_id(init_data)
if segment_timeline is not None: p_init, p_segments, p_timescale, p_durations, p_kid = DASH._get_period_segments(
current_time = 0 period=content_period,
for s in segment_timeline.findall("S"): adaptation_set=matched_as,
if s.get("t"): representation=matched_rep,
current_time = int(s.get("t")) manifest=manifest,
for _ in range(1 + (int(s.get("r") or 0))): track=track,
segment_durations.append(current_time) track_url=track.url,
current_time += int(s.get("d")) session=session,
if not end_number:
end_number = len(segment_durations)
# Handle high startNumber in DVR/catch-up manifests where startNumber > segment count
if start_number > end_number:
end_number = start_number + len(segment_durations) - 1
for t, n in zip(segment_durations, range(start_number, end_number + 1)):
segments.append(
(
DASH.replace_fields(
segment_template.get("media"),
Bandwidth=representation.get("bandwidth"),
Number=n,
RepresentationID=representation.get("id"),
Time=t,
),
None,
)
) )
if period_idx == 0:
# First period: use its init data and KID for DRM licensing
init_data = p_init
track_kid = p_kid
segment_timescale = p_timescale
else: else:
if not period_duration: if p_kid and track_kid and p_kid != track_kid:
raise ValueError("Duration of the Period was unable to be determined.") log.debug(f"Period {content_period.get('id', period_idx)} has different KID: {p_kid}")
period_duration = DASH.pt_to_sec(period_duration)
segment_duration = float(segment_template.get("duration")) or 1
if not end_number: segments.extend(p_segments)
segment_count = math.ceil(period_duration / (segment_duration / segment_timescale)) segment_durations.extend(p_durations)
end_number = start_number + segment_count - 1
for s in range(start_number, end_number + 1): if not segments:
segments.append(
(
DASH.replace_fields(
segment_template.get("media"),
Bandwidth=representation.get("bandwidth"),
Number=s,
RepresentationID=representation.get("id"),
Time=s,
),
None,
)
)
# TODO: Should we floor/ceil/round, or is int() ok?
segment_durations.append(int(segment_duration))
elif segment_list is not None:
segment_timescale = float(segment_list.get("timescale") or 1)
init_data = None
initialization = segment_list.find("Initialization")
if initialization is not None:
source_url = initialization.get("sourceURL")
if not source_url:
source_url = rep_base_url
elif not re.match("^https?://", source_url, re.IGNORECASE):
source_url = urljoin(rep_base_url, f"./{source_url}")
if initialization.get("range"):
init_range_header = {"Range": f"bytes={initialization.get('range')}"}
else:
init_range_header = None
res = session.get(url=source_url, headers=init_range_header)
res.raise_for_status()
init_data = res.content
track_kid = track.get_key_id(init_data)
segment_urls = segment_list.findall("SegmentURL")
for segment_url in segment_urls:
media_url = segment_url.get("media")
if not media_url:
media_url = rep_base_url
elif not re.match("^https?://", media_url, re.IGNORECASE):
media_url = urljoin(rep_base_url, f"./{media_url}")
segments.append((media_url, segment_url.get("mediaRange")))
segment_durations.append(int(segment_url.get("duration") or 1))
elif segment_base is not None:
media_range = None
init_data = None
initialization = segment_base.find("Initialization")
if initialization is not None:
if initialization.get("range"):
init_range_header = {"Range": f"bytes={initialization.get('range')}"}
else:
init_range_header = None
res = session.get(url=rep_base_url, headers=init_range_header)
res.raise_for_status()
init_data = res.content
track_kid = track.get_key_id(init_data)
total_size = res.headers.get("Content-Range", "").split("/")[-1]
if total_size:
media_range = f"{len(init_data)}-{total_size}"
segments.append((rep_base_url, media_range))
elif rep_base_url:
segments.append((rep_base_url, None))
else:
log.error("Could not find a way to get segments from this MPD manifest.") log.error("Could not find a way to get segments from this MPD manifest.")
log.debug(track.url) log.debug(track.url)
sys.exit(1) sys.exit(1)
@@ -682,6 +567,212 @@ class DASH:
progress(downloaded="Downloaded") progress(downloaded="Downloaded")
@staticmethod
def _is_content_period(period: Element, filtered_period_ids: list[str]) -> bool:
    """
    Decide whether a Period should contribute segments to a track.

    A period is rejected when any of the following holds:
      - it has an ``id`` that appears in ``filtered_period_ids``;
      - its first ``SegmentType/@value`` is anything other than "content"
        (a period without a SegmentType is treated as "content");
      - it carries a ``SupplementalProperty`` whose ``schemeIdUri`` is
        "urn:amazon:primevideo:cachingBreadth".

    Parameters:
        period: The Period element to inspect.
        filtered_period_ids: IDs of periods that were already excluded.

    Returns:
        True if none of the rejection rules matched, otherwise False.
    """
    current_id = period.get("id")
    if current_id and current_id in filtered_period_ids:
        return False

    # Only the first SegmentType value is consulted; absence means "content".
    segment_type = next(iter(period.xpath("SegmentType/@value")), "content")
    if segment_type != "content":
        return False

    has_caching_breadth = any(
        prop.get("schemeIdUri") == "urn:amazon:primevideo:cachingBreadth"
        for prop in period.findall("SupplementalProperty")
    )
    return not has_caching_breadth
@staticmethod
def _get_period_segments(
    period: Element,
    adaptation_set: Element,
    representation: Element,
    manifest: ElementTree,
    track: AnyTrack,
    track_url: str,
    session: Union[Session, CurlSession],
) -> tuple[
    Optional[bytes],
    list[tuple[str, Optional[str]]],
    float,
    list[int],
    Optional[UUID],
]:
    """
    Extract segments from a single period's representation.

    The base URL is resolved by chaining the ``BaseURL`` of the manifest, period,
    adaptation set and representation (in that order), starting from ``track_url``.
    Segment addressing is then taken from the first of SegmentTemplate, SegmentList
    or SegmentBase that is present on the representation or, failing that, on the
    adaptation set. If none is present, the representation base URL itself is
    returned as the only segment.

    Parameters:
        period: Period element the representation belongs to.
        adaptation_set: AdaptationSet element containing the representation.
        representation: Representation element to extract segments for.
        manifest: The MPD manifest; read for its ``BaseURL`` and
            ``mediaPresentationDuration``.
        track: Track whose ``get_key_id()`` is called on downloaded init data.
        track_url: URL of the manifest; used as the base for relative URLs and as
            the source of a query string for template URLs that have none.
        session: HTTP session used to download the initialization data.

    Returns:
        A tuple of (init_data, segments, segment_timescale, segment_durations, track_kid).
        ``segments`` is a list of (url, byte range or None). ``init_data`` and
        ``track_kid`` stay None when no initialization data was downloaded.
        NOTE: for a SegmentTemplate with a SegmentTimeline, ``segment_durations``
        holds each segment's start time (the value substituted as ``Time``), not
        its duration.

    Raises:
        ValueError: If a SegmentTemplate URL is relative and no base URL is
            available, or if the period duration is needed but cannot be determined.
        Whatever ``raise_for_status()`` raises when an initialization request fails.
    """
    manifest_base_url = manifest.findtext("BaseURL")
    if not manifest_base_url:
        manifest_base_url = track_url
    elif not re.match("^https?://", manifest_base_url, re.IGNORECASE):
        manifest_base_url = urljoin(track_url, f"./{manifest_base_url}")

    period_base_url = urljoin(manifest_base_url, period.findtext("BaseURL") or "")
    adaptation_set_base_url = urljoin(period_base_url, adaptation_set.findtext("BaseURL") or "")
    rep_base_url = urljoin(adaptation_set_base_url, representation.findtext("BaseURL") or "")

    period_duration = period.get("duration") or manifest.get("mediaPresentationDuration")
    init_data: Optional[bytes] = None

    # Representation-level addressing takes priority over the AdaptationSet's.
    segment_template = representation.find("SegmentTemplate")
    if segment_template is None:
        segment_template = adaptation_set.find("SegmentTemplate")

    segment_list = representation.find("SegmentList")
    if segment_list is None:
        segment_list = adaptation_set.find("SegmentList")

    segment_base = representation.find("SegmentBase")
    if segment_base is None:
        segment_base = adaptation_set.find("SegmentBase")

    segments: list[tuple[str, Optional[str]]] = []
    segment_timescale: float = 0
    segment_durations: list[int] = []
    track_kid: Optional[UUID] = None

    if segment_template is not None:
        # Work on a copy: the template attributes are rewritten to absolute URLs below
        # and the manifest element itself must stay untouched.
        segment_template = copy(segment_template)
        start_number = int(segment_template.get("startNumber") or 1)
        end_number = int(segment_template.get("endNumber") or 0) or None
        segment_timeline = segment_template.find("SegmentTimeline")
        segment_timescale = float(segment_template.get("timescale") or 1)

        for item in ("initialization", "media"):
            value = segment_template.get(item)
            if not value:
                continue
            if not re.match("^https?://", value, re.IGNORECASE):
                if not rep_base_url:
                    raise ValueError("Resolved Segment URL is not absolute, and no Base URL is available.")
                value = urljoin(rep_base_url, value)
            if not urlparse(value).query:
                # Carry over the manifest URL's query string when the template has none.
                manifest_url_query = urlparse(track_url).query
                if manifest_url_query:
                    value += f"?{manifest_url_query}"
            segment_template.set(item, value)

        init_url = segment_template.get("initialization")
        if init_url:
            res = session.get(
                DASH.replace_fields(
                    init_url, Bandwidth=representation.get("bandwidth"), RepresentationID=representation.get("id")
                )
            )
            res.raise_for_status()
            init_data = res.content
            track_kid = track.get_key_id(init_data)

        if segment_timeline is not None:
            current_time = 0
            for s in segment_timeline.findall("S"):
                if s.get("t"):
                    current_time = int(s.get("t"))
                for _ in range(1 + (int(s.get("r") or 0))):
                    # Records the segment's start time; it is used as $Time$ below.
                    segment_durations.append(current_time)
                    current_time += int(s.get("d"))

            if not end_number:
                end_number = len(segment_durations)
            # Handle high startNumber in DVR/catch-up manifests where startNumber > segment count
            if start_number > end_number:
                end_number = start_number + len(segment_durations) - 1

            for t, n in zip(segment_durations, range(start_number, end_number + 1)):
                segments.append(
                    (
                        DASH.replace_fields(
                            segment_template.get("media"),
                            Bandwidth=representation.get("bandwidth"),
                            Number=n,
                            RepresentationID=representation.get("id"),
                            Time=t,
                        ),
                        None,
                    )
                )
        else:
            if not period_duration:
                raise ValueError("Duration of the Period was unable to be determined.")
            period_duration = DASH.pt_to_sec(period_duration)
            # The fallback sits inside float() so a missing attribute no longer raises
            # TypeError (same pattern as startNumber/timescale above); the trailing
            # `or 1` still guards against a zero duration and division by zero.
            segment_duration = float(segment_template.get("duration") or 1) or 1

            if not end_number:
                segment_count = math.ceil(period_duration / (segment_duration / segment_timescale))
                end_number = start_number + segment_count - 1

            for s in range(start_number, end_number + 1):
                segments.append(
                    (
                        DASH.replace_fields(
                            segment_template.get("media"),
                            Bandwidth=representation.get("bandwidth"),
                            Number=s,
                            RepresentationID=representation.get("id"),
                            Time=s,
                        ),
                        None,
                    )
                )
                # TODO: Should we floor/ceil/round, or is int() ok?
                segment_durations.append(int(segment_duration))
    elif segment_list is not None:
        segment_timescale = float(segment_list.get("timescale") or 1)

        initialization = segment_list.find("Initialization")
        if initialization is not None:
            source_url = initialization.get("sourceURL")
            if not source_url:
                source_url = rep_base_url
            elif not re.match("^https?://", source_url, re.IGNORECASE):
                source_url = urljoin(rep_base_url, f"./{source_url}")

            if initialization.get("range"):
                init_range_header = {"Range": f"bytes={initialization.get('range')}"}
            else:
                init_range_header = None

            res = session.get(url=source_url, headers=init_range_header)
            res.raise_for_status()
            init_data = res.content
            track_kid = track.get_key_id(init_data)

        for segment_url in segment_list.findall("SegmentURL"):
            media_url = segment_url.get("media")
            if not media_url:
                media_url = rep_base_url
            elif not re.match("^https?://", media_url, re.IGNORECASE):
                media_url = urljoin(rep_base_url, f"./{media_url}")

            segments.append((media_url, segment_url.get("mediaRange")))
            segment_durations.append(int(segment_url.get("duration") or 1))
    elif segment_base is not None:
        media_range = None
        initialization = segment_base.find("Initialization")
        if initialization is not None:
            if initialization.get("range"):
                init_range_header = {"Range": f"bytes={initialization.get('range')}"}
            else:
                init_range_header = None

            res = session.get(url=rep_base_url, headers=init_range_header)
            res.raise_for_status()
            init_data = res.content
            track_kid = track.get_key_id(init_data)
            # The media range starts right after the init bytes and ends at the value
            # following the "/" in the Content-Range response header, when one is sent.
            total_size = res.headers.get("Content-Range", "").split("/")[-1]
            if total_size:
                media_range = f"{len(init_data)}-{total_size}"

        segments.append((rep_base_url, media_range))
    elif rep_base_url:
        # No segment addressing at all: the representation URL is the single segment.
        segments.append((rep_base_url, None))

    return init_data, segments, segment_timescale, segment_durations, track_kid
@staticmethod @staticmethod
def _get(item: str, adaptation_set: Element, representation: Optional[Element] = None) -> Optional[Any]: def _get(item: str, adaptation_set: Element, representation: Optional[Element] = None) -> Optional[Any]:
"""Helper to get a requested item from the Representation, otherwise from the AdaptationSet.""" """Helper to get a requested item from the Representation, otherwise from the AdaptationSet."""