From 6b90a19632e54abd137269b66b0dcdbe59539be6 Mon Sep 17 00:00:00 2001 From: Andy Date: Fri, 23 Jan 2026 17:14:12 -0700 Subject: [PATCH] perf(aria2c): improve download performance with singleton manager - Use singleton _Aria2Manager to reuse one aria2c process via RPC - Add downloads via aria2.addUri instead of stdin input file - Track per-GID byte-level progress (completedLength/totalLength) - Add thread-safe operations with threading.Lock - Enable graceful cancellation by removing individual downloads via RPC --- unshackle/core/downloaders/aria2c.py | 459 +++++++++++++++------------ 1 file changed, 261 insertions(+), 198 deletions(-) diff --git a/unshackle/core/downloaders/aria2c.py b/unshackle/core/downloaders/aria2c.py index 6f5b5d0..af7d34f 100644 --- a/unshackle/core/downloaders/aria2c.py +++ b/unshackle/core/downloaders/aria2c.py @@ -1,6 +1,7 @@ import os import subprocess import textwrap +import threading import time from functools import partial from http.cookiejar import CookieJar @@ -49,6 +50,138 @@ def rpc(caller: Callable, secret: str, method: str, params: Optional[list[Any]] return +class _Aria2Manager: + """Singleton manager to run one aria2c process and enqueue downloads via RPC.""" + + def __init__(self) -> None: + self._proc: Optional[subprocess.Popen] = None + self._rpc_port: Optional[int] = None + self._rpc_secret: Optional[str] = None + self._rpc_uri: Optional[str] = None + self._session: Session = Session() + self._max_concurrent_downloads: int = 0 + self._max_connection_per_server: int = 1 + self._split_default: int = 5 + self._file_allocation: str = "prealloc" + self._proxy: Optional[str] = None + self._lock: threading.Lock = threading.Lock() + + def _build_args(self) -> list[str]: + args = [ + "--continue=true", + f"--max-concurrent-downloads={self._max_concurrent_downloads}", + f"--max-connection-per-server={self._max_connection_per_server}", + f"--split={self._split_default}", + "--max-file-not-found=5", + "--max-tries=5", + 
"--retry-wait=2", + "--allow-overwrite=true", + "--auto-file-renaming=false", + "--console-log-level=warn", + "--download-result=default", + f"--file-allocation={self._file_allocation}", + "--summary-interval=0", + "--enable-rpc=true", + f"--rpc-listen-port={self._rpc_port}", + f"--rpc-secret={self._rpc_secret}", + ] + if self._proxy: + args.extend(["--all-proxy", self._proxy]) + return args + + def ensure_started( + self, + proxy: Optional[str], + max_workers: Optional[int], + ) -> None: + with self._lock: + if self._proc and self._proc.poll() is None: + return + + if not binaries.Aria2: + debug_logger = get_debug_logger() + if debug_logger: + debug_logger.log( + level="ERROR", + operation="downloader_aria2c_binary_missing", + message="Aria2c executable not found in PATH or local binaries directory", + context={"searched_names": ["aria2c", "aria2"]}, + ) + raise EnvironmentError("Aria2c executable not found...") + + if not max_workers: + max_workers = min(32, (os.cpu_count() or 1) + 4) + elif not isinstance(max_workers, int): + raise TypeError(f"Expected max_workers to be {int}, not {type(max_workers)}") + + self._rpc_port = get_free_port() + self._rpc_secret = get_random_bytes(16).hex() + self._rpc_uri = f"http://127.0.0.1:{self._rpc_port}/jsonrpc" + + self._max_concurrent_downloads = int(config.aria2c.get("max_concurrent_downloads", max_workers)) + self._max_connection_per_server = int(config.aria2c.get("max_connection_per_server", 1)) + self._split_default = int(config.aria2c.get("split", 5)) + self._file_allocation = config.aria2c.get("file_allocation", "prealloc") + self._proxy = proxy or None + + args = self._build_args() + self._proc = subprocess.Popen( + [binaries.Aria2, *args], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL + ) + # Give aria2c a moment to start up and bind to the RPC port + time.sleep(0.5) + + @property + def rpc_uri(self) -> str: + assert self._rpc_uri + return self._rpc_uri + + @property + def rpc_secret(self) -> str: + assert 
self._rpc_secret + return self._rpc_secret + + @property + def session(self) -> Session: + return self._session + + def add_uris(self, uris: list[str], options: dict[str, Any]) -> str: + """Add a single download with multiple URIs via RPC.""" + gid = rpc( + caller=partial(self._session.post, url=self.rpc_uri), + secret=self.rpc_secret, + method="aria2.addUri", + params=[uris, options], + ) + return gid or "" + + def get_global_stat(self) -> dict[str, Any]: + return rpc( + caller=partial(self.session.post, url=self.rpc_uri), + secret=self.rpc_secret, + method="aria2.getGlobalStat", + ) or {} + + def tell_status(self, gid: str) -> Optional[dict[str, Any]]: + return rpc( + caller=partial(self.session.post, url=self.rpc_uri), + secret=self.rpc_secret, + method="aria2.tellStatus", + params=[gid, ["status", "errorCode", "errorMessage", "files", "completedLength", "totalLength"]], + ) + + def remove(self, gid: str) -> None: + rpc( + caller=partial(self.session.post, url=self.rpc_uri), + secret=self.rpc_secret, + method="aria2.forceRemove", + params=[gid], + ) + + +_manager = _Aria2Manager() + + def download( urls: Union[str, list[str], dict[str, Any], list[dict[str, Any]]], output_dir: Path, @@ -58,6 +191,7 @@ def download( proxy: Optional[str] = None, max_workers: Optional[int] = None, ) -> Generator[dict[str, Any], None, None]: + """Enqueue downloads on the singleton aria2c instance via RPC (aria2.addUri) and track per-call progress per GID.""" debug_logger = get_debug_logger() if not urls: @@ -92,102 +226,10 @@ def download( if not isinstance(urls, list): urls = [urls] - if not binaries.Aria2: - if debug_logger: - debug_logger.log( - level="ERROR", - operation="downloader_aria2c_binary_missing", - message="Aria2c executable not found in PATH or local binaries directory", - context={"searched_names": ["aria2c", "aria2"]}, - ) - raise EnvironmentError("Aria2c executable not found...") - - if proxy and not proxy.lower().startswith("http://"): - raise ValueError("Only HTTP proxies are 
supported by aria2(c)") - if cookies and not isinstance(cookies, CookieJar): cookies = cookiejar_from_dict(cookies) - url_files = [] - for i, url in enumerate(urls): - if isinstance(url, str): - url_data = {"url": url} - else: - url_data: dict[str, Any] = url - url_filename = filename.format(i=i, ext=get_extension(url_data["url"])) - url_text = url_data["url"] - url_text += f"\n\tdir={output_dir}" - url_text += f"\n\tout={url_filename}" - if cookies: - mock_request = requests.Request(url=url_data["url"]) - cookie_header = get_cookie_header(cookies, mock_request) - if cookie_header: - url_text += f"\n\theader=Cookie: {cookie_header}" - for key, value in url_data.items(): - if key == "url": - continue - if key == "headers": - for header_name, header_value in value.items(): - url_text += f"\n\theader={header_name}: {header_value}" - else: - url_text += f"\n\t{key}={value}" - url_files.append(url_text) - url_file = "\n".join(url_files) - - rpc_port = get_free_port() - rpc_secret = get_random_bytes(16).hex() - rpc_uri = f"http://127.0.0.1:{rpc_port}/jsonrpc" - rpc_session = Session() - - max_concurrent_downloads = int(config.aria2c.get("max_concurrent_downloads", max_workers)) - max_connection_per_server = int(config.aria2c.get("max_connection_per_server", 1)) - split = int(config.aria2c.get("split", 5)) - file_allocation = config.aria2c.get("file_allocation", "prealloc") - if len(urls) > 1: - split = 1 - file_allocation = "none" - - arguments = [ - # [Basic Options] - "--input-file", - "-", - "--all-proxy", - proxy or "", - "--continue=true", - # [Connection Options] - f"--max-concurrent-downloads={max_concurrent_downloads}", - f"--max-connection-per-server={max_connection_per_server}", - f"--split={split}", # each split uses their own connection - "--max-file-not-found=5", # counted towards --max-tries - "--max-tries=5", - "--retry-wait=2", - # [Advanced Options] - "--allow-overwrite=true", - "--auto-file-renaming=false", - "--console-log-level=warn", - 
"--download-result=default", - f"--file-allocation={file_allocation}", - "--summary-interval=0", - # [RPC Options] - "--enable-rpc=true", - f"--rpc-listen-port={rpc_port}", - f"--rpc-secret={rpc_secret}", - ] - - for header, value in (headers or {}).items(): - if header.lower() == "cookie": - raise ValueError("You cannot set Cookies as a header manually, please use the `cookies` param.") - if header.lower() == "accept-encoding": - # we cannot set an allowed encoding, or it will return compressed - # and the code is not set up to uncompress the data - continue - if header.lower() == "referer": - arguments.extend(["--referer", value]) - continue - if header.lower() == "user-agent": - arguments.extend(["--user-agent", value]) - continue - arguments.extend(["--header", f"{header}: {value}"]) + _manager.ensure_started(proxy=proxy, max_workers=max_workers) if debug_logger: first_url = urls[0] if isinstance(urls[0], str) else urls[0].get("url", "") @@ -202,128 +244,151 @@ def download( "first_url": url_display, "output_dir": str(output_dir), "filename": filename, - "max_concurrent_downloads": max_concurrent_downloads, - "max_connection_per_server": max_connection_per_server, - "split": split, - "file_allocation": file_allocation, "has_proxy": bool(proxy), - "rpc_port": rpc_port, }, ) - yield dict(total=len(urls)) + # Build options for each URI and add via RPC + gids: list[str] = [] + + for i, url in enumerate(urls): + if isinstance(url, str): + url_data = {"url": url} + else: + url_data: dict[str, Any] = url + + url_filename = filename.format(i=i, ext=get_extension(url_data["url"])) + + opts: dict[str, Any] = { + "dir": str(output_dir), + "out": url_filename, + "split": str(1 if len(urls) > 1 else int(config.aria2c.get("split", 5))), + } + + # Cookies as header + if cookies: + mock_request = requests.Request(url=url_data["url"]) + cookie_header = get_cookie_header(cookies, mock_request) + if cookie_header: + opts.setdefault("header", []).append(f"Cookie: {cookie_header}") 
+ + # Global headers + for header, value in (headers or {}).items(): + if header.lower() == "cookie": + raise ValueError("You cannot set Cookies as a header manually, please use the `cookies` param.") + if header.lower() == "accept-encoding": + continue + if header.lower() == "referer": + opts["referer"] = str(value) + continue + if header.lower() == "user-agent": + opts["user-agent"] = str(value) + continue + opts.setdefault("header", []).append(f"{header}: {value}") + + # Per-url extra args + for key, value in url_data.items(): + if key == "url": + continue + if key == "headers": + for header_name, header_value in value.items(): + opts.setdefault("header", []).append(f"{header_name}: {header_value}") + else: + opts[key] = str(value) + + # Add via RPC + gid = _manager.add_uris([url_data["url"]], opts) + if gid: + gids.append(gid) + + yield dict(total=len(gids)) + + completed: set[str] = set() try: - p = subprocess.Popen([binaries.Aria2, *arguments], stdin=subprocess.PIPE, stdout=subprocess.DEVNULL) + while len(completed) < len(gids): + if DOWNLOAD_CANCELLED.is_set(): + # Remove tracked downloads on cancel + for gid in gids: + if gid not in completed: + _manager.remove(gid) + yield dict(downloaded="[yellow]CANCELLED") + raise KeyboardInterrupt() - p.stdin.write(url_file.encode()) - p.stdin.close() + stats = _manager.get_global_stat() + dl_speed = int(stats.get("downloadSpeed", -1)) - while p.poll() is None: - global_stats: dict[str, Any] = ( - rpc(caller=partial(rpc_session.post, url=rpc_uri), secret=rpc_secret, method="aria2.getGlobalStat") - or {} - ) + # Aggregate progress across all GIDs for this call + total_completed = 0 + total_size = 0 - number_stopped = int(global_stats.get("numStoppedTotal", 0)) - download_speed = int(global_stats.get("downloadSpeed", -1)) + # Check each tracked GID + for gid in gids: + if gid in completed: + continue - if number_stopped: - yield dict(completed=number_stopped) - if download_speed != -1: - yield 
dict(downloaded=f"{filesize.decimal(download_speed)}/s") + status = _manager.tell_status(gid) + if not status: + continue - stopped_downloads: list[dict[str, Any]] = ( - rpc( - caller=partial(rpc_session.post, url=rpc_uri), - secret=rpc_secret, - method="aria2.tellStopped", - params=[0, 999999], - ) - or [] - ) + completed_length = int(status.get("completedLength", 0)) + total_length = int(status.get("totalLength", 0)) + total_completed += completed_length + total_size += total_length - for dl in stopped_downloads: - if dl["status"] == "error": - used_uri = next( - uri["uri"] - for file in dl["files"] - if file["selected"] == "true" - for uri in file["uris"] - if uri["status"] == "used" - ) - error = f"Download Error (#{dl['gid']}): {dl['errorMessage']} ({dl['errorCode']}), {used_uri}" - error_pretty = "\n ".join( - textwrap.wrap(error, width=console.width - 20, initial_indent="") - ) - console.log(Text.from_ansi("\n[Aria2c]: " + error_pretty)) - if debug_logger: - debug_logger.log( - level="ERROR", - operation="downloader_aria2c_download_error", - message=f"Aria2c download failed: {dl['errorMessage']}", - context={ - "gid": dl["gid"], - "error_code": dl["errorCode"], - "error_message": dl["errorMessage"], - "used_uri": used_uri[:200] + "..." 
if len(used_uri) > 200 else used_uri, - "completed_length": dl.get("completedLength"), - "total_length": dl.get("totalLength"), - }, - ) - raise ValueError(error) + state = status.get("status") + if state in ("complete", "error"): + completed.add(gid) + yield dict(completed=len(completed)) - if number_stopped == len(urls): - rpc(caller=partial(rpc_session.post, url=rpc_uri), secret=rpc_secret, method="aria2.shutdown") - break + if state == "error": + used_uri = None + try: + used_uri = next( + uri["uri"] + for file in status.get("files", []) + for uri in file.get("uris", []) + if uri.get("status") == "used" + ) + except Exception: + used_uri = "unknown" + error = f"Download Error (#{gid}): {status.get('errorMessage')} ({status.get('errorCode')}), {used_uri}" + error_pretty = "\n ".join(textwrap.wrap(error, width=console.width - 20, initial_indent="")) + console.log(Text.from_ansi("\n[Aria2c]: " + error_pretty)) + if debug_logger: + debug_logger.log( + level="ERROR", + operation="downloader_aria2c_download_error", + message=f"Aria2c download failed: {status.get('errorMessage')}", + context={ + "gid": gid, + "error_code": status.get("errorCode"), + "error_message": status.get("errorMessage"), + "used_uri": used_uri[:200] + "..." 
if used_uri and len(used_uri) > 200 else used_uri, + "completed_length": status.get("completedLength"), + "total_length": status.get("totalLength"), + }, + ) + raise ValueError(error) + + # Yield aggregate progress for this call's downloads + if total_size > 0: + # Yield cumulative byte counts (completed/total) for rich progress; advance stays 0 because completed is absolute + if dl_speed != -1: + yield dict(downloaded=f"{filesize.decimal(dl_speed)}/s", advance=0, completed=total_completed, total=total_size) + else: + yield dict(advance=0, completed=total_completed, total=total_size) + elif dl_speed != -1: + yield dict(downloaded=f"{filesize.decimal(dl_speed)}/s") time.sleep(1) - - p.wait() - - if p.returncode != 0: - if debug_logger: - debug_logger.log( - level="ERROR", - operation="downloader_aria2c_failed", - message=f"Aria2c exited with code {p.returncode}", - context={ - "returncode": p.returncode, - "url_count": len(urls), - "output_dir": str(output_dir), - }, - ) - raise subprocess.CalledProcessError(p.returncode, arguments) - - if debug_logger: - debug_logger.log( - level="DEBUG", - operation="downloader_aria2c_complete", - message="Aria2c download completed successfully", - context={ - "url_count": len(urls), - "output_dir": str(output_dir), - "filename": filename, - }, - ) - - except ConnectionResetError: - # interrupted while passing URI to download - raise KeyboardInterrupt() - except subprocess.CalledProcessError as e: - if e.returncode in (7, 0xC000013A): - # 7 is when Aria2(c) handled the CTRL+C - # 0xC000013A is when it never got the chance to - raise KeyboardInterrupt() - raise except KeyboardInterrupt: - DOWNLOAD_CANCELLED.set() # skip pending track downloads - yield dict(downloaded="[yellow]CANCELLED") + DOWNLOAD_CANCELLED.set() raise except Exception as e: - DOWNLOAD_CANCELLED.set() # skip pending track downloads + DOWNLOAD_CANCELLED.set() yield dict(downloaded="[red]FAILED") - if debug_logger and not isinstance(e, (subprocess.CalledProcessError, ValueError)): + if 
debug_logger and not isinstance(e, ValueError): debug_logger.log( level="ERROR", operation="downloader_aria2c_exception", @@ -335,8 +400,6 @@ def download( }, ) raise - finally: - rpc(caller=partial(rpc_session.post, url=rpc_uri), secret=rpc_secret, method="aria2.shutdown") def aria2c(