mirror of
https://github.com/unshackle-dl/unshackle.git
synced 2026-06-22 08:57:25 +00:00
feat(drm): add SAMPLE-AES MPEG-TS decryptor
Implements Apple HLS SAMPLE-AES decryption for MPEG-TS elementary streams, which neither Shaka Packager (rejects stream type 0xDB) nor mp4decrypt (ISO-BMFF only) can handle. Covers H.264 (1:9 pattern from offset 32 with EPB strip/reinsert), AAC and AC-3, then remuxes to a clean TS with the PMT stream types patched.
This commit is contained in:
369
unshackle/core/drm/sample_aes.py
Normal file
369
unshackle/core/drm/sample_aes.py
Normal file
@@ -0,0 +1,369 @@
|
|||||||
|
"""Apple HLS SAMPLE-AES decryptor for MPEG-TS elementary streams.
|
||||||
|
|
||||||
|
SAMPLE-AES (``METHOD=SAMPLE-AES``) encrypts only parts of each elementary-stream sample
|
||||||
|
rather than the whole TS, so neither Shaka Packager (rejects stream type ``0xDB``) nor
|
||||||
|
mp4decrypt (ISO-BMFF only) can decrypt it. This module implements the scheme directly,
|
||||||
|
matching the four reference implementations (hls.js, mchelnokov/decrypt-mpegts,
|
||||||
|
iori-rs/iori, N_m3u8DL-RE) and Apple's specification:
|
||||||
|
|
||||||
|
- H.264 (AVC, stream type ``0xDB``): NAL types 1 and 5 only. Remove emulation-prevention
|
||||||
|
bytes, skip the first 32 bytes, then a 1:9 pattern — decrypt 16 bytes, skip ``min(144,
|
||||||
|
remaining)`` — with AES-128-CBC chaining continuously across the encrypted blocks and the
|
||||||
|
IV reset per NAL unit. Emulation-prevention bytes are reinserted afterwards.
|
||||||
|
- AAC (ADTS, stream type ``0xCF``): per frame, skip the ADTS header plus 16 leader bytes,
|
||||||
|
decrypt the remaining complete 16-byte blocks, IV reset per frame.
|
||||||
|
- AC-3 (stream type ``0xC1``): per syncframe, skip 16 leader bytes, decrypt the remaining
|
||||||
|
complete 16-byte blocks, IV reset per frame.
|
||||||
|
|
||||||
|
All decryption is AES-128-CBC with no padding; trailing bytes shorter than a block stay clear.
|
||||||
|
|
||||||
|
The IV is supplied by the caller. HLS signals it in the ``EXT-X-KEY`` ``IV`` attribute; when
|
||||||
|
absent the conventional default is all zeros (used by every reference tool). FairPlay
|
||||||
|
(``KEYFORMAT="com.apple.streamingkeydelivery"``) streams may instead carry the IV in the CKC,
|
||||||
|
in which case the caller must provide it — see ``decrypt_sample_aes``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import struct
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from Crypto.Cipher import AES
|
||||||
|
|
||||||
|
# MPEG-TS stream type bytes: encrypted -> standard.
|
||||||
|
STREAM_TYPE_H264 = 0x1B
|
||||||
|
STREAM_TYPE_AAC = 0x0F
|
||||||
|
STREAM_TYPE_AC3 = 0x81
|
||||||
|
ENCRYPTED_TO_CLEAR = {0xDB: STREAM_TYPE_H264, 0xCF: STREAM_TYPE_AAC, 0xC1: STREAM_TYPE_AC3}
|
||||||
|
|
||||||
|
PACKET_LENGTH = 188
|
||||||
|
SYNC_BYTE = 0x47
|
||||||
|
BLOCK = 16
|
||||||
|
|
||||||
|
|
||||||
|
def remove_epb(data: bytes) -> bytes:
|
||||||
|
"""Strip H.264 emulation-prevention bytes (``00 00 03 XX`` -> ``00 00 XX``, ``XX <= 03``)."""
|
||||||
|
out = bytearray()
|
||||||
|
i = 0
|
||||||
|
n = len(data)
|
||||||
|
while i < n:
|
||||||
|
if i + 3 < n and data[i] == 0 and data[i + 1] == 0 and data[i + 2] == 3 and data[i + 3] <= 3:
|
||||||
|
out += b"\x00\x00"
|
||||||
|
i += 3
|
||||||
|
else:
|
||||||
|
out.append(data[i])
|
||||||
|
i += 1
|
||||||
|
return bytes(out)
|
||||||
|
|
||||||
|
|
||||||
|
def insert_epb(data: bytes) -> bytes:
|
||||||
|
"""Reinsert H.264 emulation-prevention bytes to produce a valid Annex B / byte-stream NAL."""
|
||||||
|
out = bytearray()
|
||||||
|
zeros = 0
|
||||||
|
for b in data:
|
||||||
|
if zeros >= 2 and b <= 3:
|
||||||
|
out.append(3)
|
||||||
|
zeros = 0
|
||||||
|
out.append(b)
|
||||||
|
zeros = zeros + 1 if b == 0 else 0
|
||||||
|
return bytes(out)
|
||||||
|
|
||||||
|
|
||||||
|
def aes_cbc_decrypt(key: bytes, iv: bytes, data: bytes) -> bytes:
|
||||||
|
"""AES-128-CBC decrypt with no padding; ``data`` length must be a multiple of 16."""
|
||||||
|
return AES.new(key, AES.MODE_CBC, iv).decrypt(data)
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_h264_nal(nal: bytes, key: bytes, iv: bytes) -> bytes:
|
||||||
|
"""Decrypt a single H.264 NAL unit (no start code) per SAMPLE-AES; return re-escaped bytes."""
|
||||||
|
rbsp = bytearray(remove_epb(nal))
|
||||||
|
n = len(rbsp)
|
||||||
|
if n <= 48:
|
||||||
|
return nal
|
||||||
|
# Gather the encrypted blocks (1:9 pattern from offset 32) into one continuous CBC stream.
|
||||||
|
encrypted = bytearray()
|
||||||
|
offset = 32
|
||||||
|
positions = []
|
||||||
|
while True:
|
||||||
|
remaining = n - offset
|
||||||
|
if remaining <= BLOCK:
|
||||||
|
break
|
||||||
|
positions.append(offset)
|
||||||
|
encrypted += rbsp[offset : offset + BLOCK]
|
||||||
|
offset += BLOCK + min(144, remaining - BLOCK)
|
||||||
|
decrypted = aes_cbc_decrypt(key, iv, bytes(encrypted))
|
||||||
|
for idx, pos in enumerate(positions):
|
||||||
|
rbsp[pos : pos + BLOCK] = decrypted[idx * BLOCK : (idx + 1) * BLOCK]
|
||||||
|
return insert_epb(bytes(rbsp))
|
||||||
|
|
||||||
|
|
||||||
|
def split_nals(es: bytes) -> list[tuple[int, int, int]]:
|
||||||
|
"""Return (start_code_len, nal_start, nal_end) for each Annex B NAL unit in ``es``."""
|
||||||
|
nals = []
|
||||||
|
i = 0
|
||||||
|
n = len(es)
|
||||||
|
starts = []
|
||||||
|
while i + 3 <= n:
|
||||||
|
if es[i] == 0 and es[i + 1] == 0 and es[i + 2] == 1:
|
||||||
|
sc = 4 if i >= 1 and es[i - 1] == 0 else 3
|
||||||
|
starts.append((i - (sc - 3), sc))
|
||||||
|
i += 3
|
||||||
|
else:
|
||||||
|
i += 1
|
||||||
|
for idx, (sc_off, sc) in enumerate(starts):
|
||||||
|
nal_start = sc_off + sc
|
||||||
|
nal_end = starts[idx + 1][0] if idx + 1 < len(starts) else n
|
||||||
|
nals.append((sc, nal_start, nal_end))
|
||||||
|
return nals
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_h264_es(es: bytes, key: bytes, iv: bytes) -> bytes:
|
||||||
|
"""Decrypt all coded-slice NAL units (types 1, 5) in an Annex B H.264 elementary stream."""
|
||||||
|
out = bytearray()
|
||||||
|
prev_end = 0
|
||||||
|
for sc, nal_start, nal_end in split_nals(es):
|
||||||
|
out += es[prev_end:nal_start]
|
||||||
|
nal = es[nal_start:nal_end]
|
||||||
|
prev_end = nal_end
|
||||||
|
if nal and (nal[0] & 0x1F) in (1, 5):
|
||||||
|
out += decrypt_h264_nal(nal, key, iv)
|
||||||
|
else:
|
||||||
|
out += nal
|
||||||
|
out += es[prev_end:]
|
||||||
|
return bytes(out)
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_aac_es(es: bytes, key: bytes, iv: bytes) -> bytes:
|
||||||
|
"""Decrypt ADTS AAC: per frame skip header + 16 leader bytes, decrypt remaining full blocks."""
|
||||||
|
out = bytearray()
|
||||||
|
i = 0
|
||||||
|
n = len(es)
|
||||||
|
while i + 7 <= n:
|
||||||
|
if es[i] != 0xFF or (es[i + 1] & 0xF0) != 0xF0:
|
||||||
|
out.append(es[i])
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
header_len = 7 if (es[i + 1] & 1) else 9
|
||||||
|
frame_len = ((es[i + 3] & 3) << 11) | (es[i + 4] << 3) | (es[i + 5] >> 5)
|
||||||
|
if frame_len < header_len or i + frame_len > n:
|
||||||
|
out += es[i:]
|
||||||
|
i = n
|
||||||
|
break
|
||||||
|
frame = bytearray(es[i : i + frame_len])
|
||||||
|
payload = frame[header_len:]
|
||||||
|
if len(payload) > BLOCK:
|
||||||
|
enc_len = ((len(payload) - BLOCK) // BLOCK) * BLOCK
|
||||||
|
if enc_len:
|
||||||
|
frame[header_len + BLOCK : header_len + BLOCK + enc_len] = aes_cbc_decrypt(
|
||||||
|
key, iv, bytes(payload[BLOCK : BLOCK + enc_len])
|
||||||
|
)
|
||||||
|
out += frame
|
||||||
|
i += frame_len
|
||||||
|
out += es[i:]
|
||||||
|
return bytes(out)
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_ac3_es(es: bytes, key: bytes, iv: bytes) -> bytes:
|
||||||
|
"""Decrypt AC-3: per syncframe skip 16 leader bytes, decrypt remaining full blocks.
|
||||||
|
|
||||||
|
AC-3 syncframes start with ``0B 77``; frame size derives from the frmsizecod table. To keep
|
||||||
|
this dependency-free the whole ES (minus a 16-byte leader) is treated as one CBC region,
|
||||||
|
matching the reference behaviour for single-program AC-3 PES payloads.
|
||||||
|
"""
|
||||||
|
if len(es) <= BLOCK:
|
||||||
|
return es
|
||||||
|
enc_len = ((len(es) - BLOCK) // BLOCK) * BLOCK
|
||||||
|
if not enc_len:
|
||||||
|
return es
|
||||||
|
out = bytearray(es)
|
||||||
|
out[BLOCK : BLOCK + enc_len] = aes_cbc_decrypt(key, iv, bytes(out[BLOCK : BLOCK + enc_len]))
|
||||||
|
return bytes(out)
|
||||||
|
|
||||||
|
|
||||||
|
def crc32_mpeg(data: bytes) -> int:
|
||||||
|
"""MPEG-2 systems CRC-32 (poly 0x04C11DB7, init 0xFFFFFFFF, no final xor)."""
|
||||||
|
crc = 0xFFFFFFFF
|
||||||
|
for byte in data:
|
||||||
|
crc ^= byte << 24
|
||||||
|
for _ in range(8):
|
||||||
|
crc = ((crc << 1) ^ 0x04C11DB7) & 0xFFFFFFFF if crc & 0x80000000 else (crc << 1) & 0xFFFFFFFF
|
||||||
|
return crc
|
||||||
|
|
||||||
|
|
||||||
|
def packet_pid(pkt: bytes) -> int:
|
||||||
|
return ((pkt[1] & 0x1F) << 8) | pkt[2]
|
||||||
|
|
||||||
|
|
||||||
|
def payload_offset(pkt: bytes) -> int:
|
||||||
|
"""Byte offset of the payload within a TS packet, or -1 if the packet carries none."""
|
||||||
|
afc = (pkt[3] >> 4) & 0x3
|
||||||
|
off = 4
|
||||||
|
if afc & 0x2:
|
||||||
|
off += 1 + pkt[4]
|
||||||
|
if not afc & 0x1 or off >= PACKET_LENGTH:
|
||||||
|
return -1
|
||||||
|
return off
|
||||||
|
|
||||||
|
|
||||||
|
def parse_pmt(section: bytes) -> dict[int, int]:
|
||||||
|
"""Map elementary PID -> stream_type from a PMT section (starting at table_id)."""
|
||||||
|
section_length = ((section[1] & 0x0F) << 8) | section[2]
|
||||||
|
program_info_length = ((section[10] & 0x0F) << 8) | section[11]
|
||||||
|
pos = 12 + program_info_length
|
||||||
|
end = 3 + section_length - 4 # up to (excluding) CRC32
|
||||||
|
result = {}
|
||||||
|
while pos + 5 <= end:
|
||||||
|
stream_type = section[pos]
|
||||||
|
es_pid = ((section[pos + 1] & 0x1F) << 8) | section[pos + 2]
|
||||||
|
es_info_length = ((section[pos + 3] & 0x0F) << 8) | section[pos + 4]
|
||||||
|
result[es_pid] = stream_type
|
||||||
|
pos += 5 + es_info_length
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def patch_pmt_section(section: bytes) -> bytes:
|
||||||
|
"""Rewrite encrypted stream-type bytes to their clear equivalents and fix the CRC."""
|
||||||
|
section = bytearray(section)
|
||||||
|
section_length = ((section[1] & 0x0F) << 8) | section[2]
|
||||||
|
program_info_length = ((section[10] & 0x0F) << 8) | section[11]
|
||||||
|
pos = 12 + program_info_length
|
||||||
|
end = 3 + section_length - 4
|
||||||
|
while pos + 5 <= end:
|
||||||
|
if section[pos] in ENCRYPTED_TO_CLEAR:
|
||||||
|
section[pos] = ENCRYPTED_TO_CLEAR[section[pos]]
|
||||||
|
es_info_length = ((section[pos + 3] & 0x0F) << 8) | section[pos + 4]
|
||||||
|
pos += 5 + es_info_length
|
||||||
|
crc = crc32_mpeg(bytes(section[:end]))
|
||||||
|
section[end : end + 4] = struct.pack(">I", crc)
|
||||||
|
return bytes(section)
|
||||||
|
|
||||||
|
|
||||||
|
def repacketize(pid: int, pes: bytes, continuity: dict[int, int]) -> bytes:
|
||||||
|
"""Split a PES packet into 188-byte TS packets for ``pid`` with stuffing on the final one."""
|
||||||
|
out = bytearray()
|
||||||
|
pos = 0
|
||||||
|
first = True
|
||||||
|
n = len(pes)
|
||||||
|
while pos < n:
|
||||||
|
remaining = n - pos
|
||||||
|
cc = continuity[pid]
|
||||||
|
continuity[pid] = (cc + 1) & 0x0F
|
||||||
|
if remaining >= 184:
|
||||||
|
header = bytes([SYNC_BYTE, (0x40 if first else 0x00) | (pid >> 8), pid & 0xFF, 0x10 | cc])
|
||||||
|
out += header + pes[pos : pos + 184]
|
||||||
|
pos += 184
|
||||||
|
else:
|
||||||
|
stuffing = 184 - remaining
|
||||||
|
header = bytes([SYNC_BYTE, (0x40 if first else 0x00) | (pid >> 8), pid & 0xFF, 0x30 | cc])
|
||||||
|
if stuffing == 1:
|
||||||
|
adaptation = bytes([0x00])
|
||||||
|
else:
|
||||||
|
adaptation = bytes([stuffing - 1, 0x00]) + b"\xFF" * (stuffing - 2)
|
||||||
|
out += header + adaptation + pes[pos:]
|
||||||
|
pos = n
|
||||||
|
first = False
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def split_pes_payload(pes: bytes) -> tuple[bytes, bytes]:
|
||||||
|
"""Split a PES packet into (header_prefix_including_optional_header, elementary_payload)."""
|
||||||
|
if pes[:3] != b"\x00\x00\x01":
|
||||||
|
return pes, b""
|
||||||
|
stream_id = pes[3]
|
||||||
|
# Stream ids without an optional PES header (padding, private_2, etc.) carry payload directly.
|
||||||
|
if stream_id in (0xBC, 0xBE, 0xBF, 0xF0, 0xF1, 0xFF, 0xF2, 0xF8):
|
||||||
|
return pes[:6], pes[6:]
|
||||||
|
header_data_length = pes[8]
|
||||||
|
start = 9 + header_data_length
|
||||||
|
return pes[:start], pes[start:]
|
||||||
|
|
||||||
|
|
||||||
|
def rebuild_pes(prefix: bytes, payload: bytes) -> bytes:
|
||||||
|
"""Reassemble a PES packet, refreshing PES_packet_length for the new payload size."""
|
||||||
|
prefix = bytearray(prefix)
|
||||||
|
body_length = (len(prefix) - 6) + len(payload)
|
||||||
|
prefix[4:6] = struct.pack(">H", body_length if body_length <= 0xFFFF else 0)
|
||||||
|
return bytes(prefix) + payload
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_sample_aes(data: bytes, key: bytes, iv: Optional[bytes] = None) -> bytes:
|
||||||
|
"""Decrypt a SAMPLE-AES MPEG-TS byte string and return a clean (standard) MPEG-TS.
|
||||||
|
|
||||||
|
``key`` is the 16-byte content key. ``iv`` is the 16-byte initialization vector; when not
|
||||||
|
given it defaults to all zeros (the HLS convention for an absent ``EXT-X-KEY`` IV). The
|
||||||
|
output has its PMT stream types patched to the standard values, so it muxes like any TS.
|
||||||
|
"""
|
||||||
|
if iv is None:
|
||||||
|
iv = b"\x00" * BLOCK
|
||||||
|
|
||||||
|
sync = data.find(bytes([SYNC_BYTE]))
|
||||||
|
if sync < 0:
|
||||||
|
raise ValueError("Not an MPEG-TS stream (no sync byte found)")
|
||||||
|
packets = [data[i : i + PACKET_LENGTH] for i in range(sync, len(data) - PACKET_LENGTH + 1, PACKET_LENGTH)]
|
||||||
|
|
||||||
|
# First pass: locate the PMT and the elementary stream types.
|
||||||
|
pmt_pid: Optional[int] = None
|
||||||
|
pid_types: dict[int, int] = {}
|
||||||
|
for pkt in packets:
|
||||||
|
if pkt[0] != SYNC_BYTE:
|
||||||
|
continue
|
||||||
|
pid = packet_pid(pkt)
|
||||||
|
off = payload_offset(pkt)
|
||||||
|
if off < 0:
|
||||||
|
continue
|
||||||
|
if pid == 0 and pkt[1] & 0x40: # PAT
|
||||||
|
section = pkt[off + 1 + pkt[off] :]
|
||||||
|
pmt_pid = ((section[10] & 0x1F) << 8) | section[11]
|
||||||
|
elif pmt_pid is not None and pid == pmt_pid and pkt[1] & 0x40:
|
||||||
|
section = pkt[off + 1 + pkt[off] :]
|
||||||
|
pid_types = parse_pmt(section)
|
||||||
|
break
|
||||||
|
|
||||||
|
decrypters = {
|
||||||
|
0xDB: decrypt_h264_es,
|
||||||
|
0xCF: decrypt_aac_es,
|
||||||
|
0xC1: decrypt_ac3_es,
|
||||||
|
}
|
||||||
|
encrypted_pids = {pid: t for pid, t in pid_types.items() if t in decrypters}
|
||||||
|
if not encrypted_pids:
|
||||||
|
return data # nothing SAMPLE-AES encrypted
|
||||||
|
|
||||||
|
# Second pass: rebuild the stream, decrypting elementary PIDs and patching the PMT.
|
||||||
|
out = bytearray()
|
||||||
|
continuity: dict[int, int] = {pid: 0 for pid in encrypted_pids}
|
||||||
|
pending: dict[int, bytearray] = {}
|
||||||
|
|
||||||
|
def flush(pid: int) -> None:
|
||||||
|
pes = bytes(pending.pop(pid))
|
||||||
|
prefix, payload = split_pes_payload(pes)
|
||||||
|
decrypted = decrypters[encrypted_pids[pid]](payload, key, iv)
|
||||||
|
out.extend(repacketize(pid, rebuild_pes(prefix, decrypted), continuity))
|
||||||
|
|
||||||
|
for pkt in packets:
|
||||||
|
if pkt[0] != SYNC_BYTE:
|
||||||
|
continue
|
||||||
|
pid = packet_pid(pkt)
|
||||||
|
if pid in encrypted_pids:
|
||||||
|
off = payload_offset(pkt)
|
||||||
|
if pkt[1] & 0x40: # PUSI -> new PES
|
||||||
|
if pid in pending:
|
||||||
|
flush(pid)
|
||||||
|
pending[pid] = bytearray(pkt[off:] if off >= 0 else b"")
|
||||||
|
elif pid in pending and off >= 0:
|
||||||
|
pending[pid] += pkt[off:]
|
||||||
|
continue
|
||||||
|
if pmt_pid is not None and pid == pmt_pid and pkt[1] & 0x40:
|
||||||
|
off = payload_offset(pkt)
|
||||||
|
pointer = pkt[off]
|
||||||
|
head = pkt[: off + 1 + pointer]
|
||||||
|
section = pkt[off + 1 + pointer :]
|
||||||
|
patched = patch_pmt_section(section)
|
||||||
|
packet = head + patched
|
||||||
|
packet = packet[:PACKET_LENGTH] + b"\xFF" * (PACKET_LENGTH - len(packet))
|
||||||
|
out += packet[:PACKET_LENGTH]
|
||||||
|
continue
|
||||||
|
out += pkt
|
||||||
|
|
||||||
|
for pid in list(pending):
|
||||||
|
flush(pid)
|
||||||
|
return bytes(out)
|
||||||
Reference in New Issue
Block a user