6 Commits

10 changed files with 137 additions and 51 deletions

View File

@@ -2,6 +2,10 @@
<img width="16" height="16" alt="no_encryption" src="https://github.com/user-attachments/assets/6ff88473-0dd2-4bbc-b1ea-c683d5d7a134" /> unshackle
<br/>
<sup><em>Movie, TV, and Music Archival Software</em></sup>
<br/>
<a href="https://discord.gg/mHYyPaCbFK">
<img src="https://img.shields.io/discord/1395571732001325127?label=&logo=discord&logoColor=ffffff&color=7289DA&labelColor=7289DA" alt="Discord">
</a>
</p>
## What is unshackle?

View File

@@ -153,6 +153,13 @@ class dl:
default=[],
help="Language wanted for Video, you would use this if the video language doesn't match the audio.",
)
@click.option(
"-al",
"--a-lang",
type=LANGUAGE_RANGE,
default=[],
help="Language wanted for Audio, overrides -l/--lang for audio tracks.",
)
@click.option("-sl", "--s-lang", type=LANGUAGE_RANGE, default=["all"], help="Language wanted for Subtitles.")
@click.option("-fs", "--forced-subs", is_flag=True, default=False, help="Include forced subtitle tracks.")
@click.option(
@@ -413,6 +420,7 @@ class dl:
wanted: list[str],
lang: list[str],
v_lang: list[str],
a_lang: list[str],
s_lang: list[str],
forced_subs: bool,
sub_format: Optional[Subtitle.Codec],
@@ -588,8 +596,9 @@ class dl:
if language not in processed_video_sort_lang:
processed_video_sort_lang.append(language)
audio_sort_lang = a_lang or lang
processed_audio_sort_lang = []
for language in lang:
for language in audio_sort_lang:
if language == "orig":
if title.language:
orig_lang = str(title.language) if hasattr(title.language, "__str__") else title.language
@@ -753,9 +762,10 @@ class dl:
if not title.tracks.audio:
self.log.error(f"There's no {abitrate}kbps Audio Track...")
sys.exit(1)
if lang:
audio_languages = a_lang or lang
if audio_languages:
processed_lang = []
for language in lang:
for language in audio_languages:
if language == "orig":
if title.language:
orig_lang = (
@@ -782,7 +792,7 @@ class dl:
selected_audio.append(highest_quality)
title.tracks.audio = selected_audio
elif "all" not in processed_lang:
per_language = 0 if len(processed_lang) > 1 else 1
per_language = 1
title.tracks.audio = title.tracks.by_language(
title.tracks.audio, processed_lang, per_language=per_language
)

View File

@@ -4,6 +4,7 @@ import base64
import html
import json
import logging
import os
import shutil
import subprocess
import sys
@@ -584,11 +585,24 @@ class HLS:
if DOWNLOAD_LICENCE_ONLY.is_set():
return
if segment_save_dir.exists():
segment_save_dir.rmdir()
def find_segments_recursively(directory: Path) -> list[Path]:
"""Find all segment files recursively in any directory structure created by downloaders."""
segments = []
# First check direct files in the directory
if directory.exists():
segments.extend([x for x in directory.iterdir() if x.is_file()])
# If no direct files, recursively search subdirectories
if not segments:
for subdir in directory.iterdir():
if subdir.is_dir():
segments.extend(find_segments_recursively(subdir))
return sorted(segments)
# finally merge all the discontinuity save files together to the final path
segments_to_merge = [x for x in sorted(save_dir.iterdir()) if x.is_file()]
segments_to_merge = find_segments_recursively(save_dir)
if len(segments_to_merge) == 1:
shutil.move(segments_to_merge[0], save_path)
else:
@@ -601,9 +615,16 @@ class HLS:
discontinuity_data = discontinuity_file.read_bytes()
f.write(discontinuity_data)
f.flush()
os.fsync(f.fileno())
discontinuity_file.unlink()
save_dir.rmdir()
# Clean up empty segment directory
if save_dir.exists() and save_dir.name.endswith("_segments"):
try:
save_dir.rmdir()
except OSError:
# Directory might not be empty, try removing recursively
shutil.rmtree(save_dir, ignore_errors=True)
progress(downloaded="Downloaded")
@@ -613,40 +634,75 @@ class HLS:
@staticmethod
def merge_segments(segments: list[Path], save_path: Path) -> int:
"""
Concatenate Segments by first demuxing with FFmpeg.
Concatenate Segments using FFmpeg concat with binary fallback.
Returns the file size of the merged file.
"""
if not binaries.FFMPEG:
raise EnvironmentError("FFmpeg executable was not found but is required to merge HLS segments.")
demuxer_file = segments[0].parent / "ffmpeg_concat_demuxer.txt"
demuxer_file.write_text("\n".join([f"file '{segment}'" for segment in segments]))
subprocess.check_call(
[
binaries.FFMPEG,
"-hide_banner",
"-loglevel",
"panic",
"-f",
"concat",
"-safe",
"0",
"-i",
demuxer_file,
"-map",
"0",
"-c",
"copy",
save_path,
]
)
demuxer_file.unlink()
# Track segment directories for cleanup
segment_dirs = set()
for segment in segments:
segment.unlink()
# Track all parent directories that contain segments
current_dir = segment.parent
while current_dir.name and "_segments" in str(current_dir):
segment_dirs.add(current_dir)
current_dir = current_dir.parent
def cleanup_segments_and_dirs():
"""Clean up segments and directories after successful merge."""
for segment in segments:
segment.unlink(missing_ok=True)
for segment_dir in segment_dirs:
if segment_dir.exists():
try:
shutil.rmtree(segment_dir)
except OSError:
pass # Directory cleanup failed, but merge succeeded
# Try FFmpeg concat first (preferred method)
if binaries.FFMPEG:
try:
demuxer_file = save_path.parent / f"ffmpeg_concat_demuxer_{save_path.stem}.txt"
demuxer_file.write_text("\n".join([f"file '{segment.absolute()}'" for segment in segments]))
subprocess.check_call(
[
binaries.FFMPEG,
"-hide_banner",
"-loglevel",
"error",
"-f",
"concat",
"-safe",
"0",
"-i",
demuxer_file,
"-map",
"0",
"-c",
"copy",
save_path,
],
timeout=300, # 5 minute timeout
)
demuxer_file.unlink(missing_ok=True)
cleanup_segments_and_dirs()
return save_path.stat().st_size
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, OSError) as e:
# FFmpeg failed, clean up demuxer file and fall back to binary concat
logging.getLogger("HLS").debug(f"FFmpeg concat failed ({e}), falling back to binary concatenation")
demuxer_file.unlink(missing_ok=True)
# Remove partial output file if it exists
save_path.unlink(missing_ok=True)
# Fallback: Binary concatenation
logging.getLogger("HLS").debug(f"Using binary concatenation for {len(segments)} segments")
with open(save_path, "wb") as output_file:
for segment in segments:
with open(segment, "rb") as segment_file:
output_file.write(segment_file.read())
cleanup_segments_and_dirs()
return save_path.stat().st_size
@staticmethod

View File

@@ -4,8 +4,9 @@ from uuid import UUID
class Vault(metaclass=ABCMeta):
def __init__(self, name: str):
def __init__(self, name: str, no_push: bool = False):
self.name = name
self.no_push = no_push
def __str__(self) -> str:
return f"{self.name} {type(self).__name__}"

View File

@@ -57,7 +57,7 @@ class Vaults:
"""Add a KID:KEY to all Vaults, optionally with an exclusion."""
success = 0
for vault in self.vaults:
if vault != excluding:
if vault != excluding and not vault.no_push:
try:
success += vault.add_key(self.service, kid, key)
except (PermissionError, NotImplementedError):
@@ -68,13 +68,15 @@ class Vaults:
"""
Add multiple KID:KEYs to all Vaults. Duplicate Content Keys are skipped.
PermissionErrors when the user cannot create Tables are absorbed and ignored.
Vaults with no_push=True are skipped.
"""
success = 0
for vault in self.vaults:
try:
success += bool(vault.add_keys(self.service, kid_keys))
except (PermissionError, NotImplementedError):
pass
if not vault.no_push:
try:
success += bool(vault.add_keys(self.service, kid_keys))
except (PermissionError, NotImplementedError):
pass
return success

View File

@@ -101,6 +101,8 @@ remote_cdm:
secret: secret_key
# Key Vaults store your obtained Content Encryption Keys (CEKs)
# Use 'no_push: true' to prevent a vault from receiving pushed keys
# while still allowing it to provide keys when requested
key_vaults:
- type: SQLite
name: Local
@@ -110,6 +112,7 @@ key_vaults:
# name: "Remote Vault"
# uri: "https://key-vault.example.com"
# token: "secret_token"
# no_push: true # This vault will only provide keys, not receive them
# - type: MySQL
# name: "MySQL Vault"
# host: "127.0.0.1"
@@ -117,6 +120,7 @@ key_vaults:
# database: vault
# username: user
# password: pass
# no_push: false # Default behavior - vault both provides and receives keys
# Choose what software to use to download data
downloader: aria2c

View File

@@ -10,8 +10,8 @@ from unshackle.core.vault import Vault
class API(Vault):
"""Key Vault using a simple RESTful HTTP API call."""
def __init__(self, name: str, uri: str, token: str):
super().__init__(name)
def __init__(self, name: str, uri: str, token: str, no_push: bool = False):
super().__init__(name, no_push)
self.uri = uri.rstrip("/")
self.session = Session()
self.session.headers.update({"User-Agent": f"unshackle v{__version__}"})

View File

@@ -18,7 +18,15 @@ class InsertResult(Enum):
class HTTP(Vault):
"""Key Vault using HTTP API with support for both query parameters and JSON payloads."""
def __init__(self, name: str, host: str, password: str, username: Optional[str] = None, api_mode: str = "query"):
def __init__(
self,
name: str,
host: str,
password: str,
username: Optional[str] = None,
api_mode: str = "query",
no_push: bool = False,
):
"""
Initialize HTTP Vault.
@@ -28,8 +36,9 @@ class HTTP(Vault):
password: Password for query mode or API token for json mode
username: Username (required for query mode, ignored for json mode)
api_mode: "query" for query parameters or "json" for JSON API
no_push: If True, this vault will not receive pushed keys
"""
super().__init__(name)
super().__init__(name, no_push)
self.url = host
self.password = password
self.username = username

View File

@@ -12,12 +12,12 @@ from unshackle.core.vault import Vault
class MySQL(Vault):
"""Key Vault using a remotely-accessed mysql database connection."""
def __init__(self, name: str, host: str, database: str, username: str, **kwargs):
def __init__(self, name: str, host: str, database: str, username: str, no_push: bool = False, **kwargs):
"""
All extra arguments provided via **kwargs will be sent to pymysql.connect.
This can be used to provide more specific connection information.
"""
super().__init__(name)
super().__init__(name, no_push)
self.slug = f"{host}:{database}:{username}"
self.conn_factory = ConnectionFactory(
dict(host=host, db=database, user=username, cursorclass=DictCursor, **kwargs)

View File

@@ -12,8 +12,8 @@ from unshackle.core.vault import Vault
class SQLite(Vault):
"""Key Vault using a locally-accessed sqlite DB file."""
def __init__(self, name: str, path: Union[str, Path]):
super().__init__(name)
def __init__(self, name: str, path: Union[str, Path], no_push: bool = False):
super().__init__(name, no_push)
self.path = Path(path).expanduser()
# TODO: Use a DictCursor or such to get fetches as dict?
self.conn_factory = ConnectionFactory(self.path)