perf(aria2c): improve download performance with singleton manager

- Use singleton _Aria2Manager to reuse one aria2c process via RPC
- Add downloads via aria2.addUri instead of stdin input file
- Track per-GID byte-level progress (completedLength/totalLength)
- Add thread-safe operations with threading.Lock
- Enable graceful cancellation by removing individual downloads via RPC
This commit is contained in:
Andy
2026-01-23 17:14:12 -07:00
parent b8e2f3da3f
commit 6b90a19632

View File

@@ -1,6 +1,7 @@
import os import os
import subprocess import subprocess
import textwrap import textwrap
import threading
import time import time
from functools import partial from functools import partial
from http.cookiejar import CookieJar from http.cookiejar import CookieJar
@@ -49,6 +50,138 @@ def rpc(caller: Callable, secret: str, method: str, params: Optional[list[Any]]
return return
class _Aria2Manager:
    """Singleton manager to run one aria2c process and enqueue downloads via RPC.

    The process is started lazily by :meth:`ensure_started` and then reused for
    every subsequent call. Settings captured at start-up (proxy, concurrency,
    split, file allocation) stay in effect for as long as that process is
    alive; arguments passed to later ``ensure_started`` calls are ignored while
    it is still running.
    """

    # Upper bound on how long ensure_started() waits for the RPC port to accept
    # connections, and the pause between connection attempts.
    _STARTUP_TIMEOUT: float = 5.0
    _STARTUP_POLL_INTERVAL: float = 0.05
    # How long shutdown() waits after terminate() before escalating to kill().
    _SHUTDOWN_TIMEOUT: float = 5.0

    def __init__(self) -> None:
        self._proc: Optional[subprocess.Popen] = None
        self._rpc_port: Optional[int] = None
        self._rpc_secret: Optional[str] = None
        self._rpc_uri: Optional[str] = None
        self._session: Session = Session()
        self._max_concurrent_downloads: int = 0
        self._max_connection_per_server: int = 1
        self._split_default: int = 5
        self._file_allocation: str = "prealloc"
        self._proxy: Optional[str] = None
        self._lock: threading.Lock = threading.Lock()
        # Ensures the interpreter-exit hook is registered only once, even if
        # aria2c has to be restarted.
        self._atexit_registered: bool = False

    def _build_args(self) -> list[str]:
        """Return the aria2c command-line arguments for the current settings."""
        args = [
            "--continue=true",
            f"--max-concurrent-downloads={self._max_concurrent_downloads}",
            f"--max-connection-per-server={self._max_connection_per_server}",
            f"--split={self._split_default}",
            "--max-file-not-found=5",
            "--max-tries=5",
            "--retry-wait=2",
            "--allow-overwrite=true",
            "--auto-file-renaming=false",
            "--console-log-level=warn",
            "--download-result=default",
            f"--file-allocation={self._file_allocation}",
            "--summary-interval=0",
            "--enable-rpc=true",
            f"--rpc-listen-port={self._rpc_port}",
            f"--rpc-secret={self._rpc_secret}",
        ]
        if self._proxy:
            # Only pass the option when a proxy is configured.
            args.extend(["--all-proxy", self._proxy])
        return args

    def ensure_started(
        self,
        proxy: Optional[str],
        max_workers: Optional[int],
    ) -> None:
        """Start the shared aria2c process if it is not already running.

        Args:
            proxy: Proxy passed to aria2c as ``--all-proxy``; falsy means none.
            max_workers: Fallback for ``max_concurrent_downloads`` when the
                config does not set it. A falsy value selects
                ``min(32, cpu_count + 4)``.

        Raises:
            EnvironmentError: The aria2c binary could not be located.
            TypeError: ``max_workers`` is truthy but not an ``int``.
            RuntimeError: aria2c exited while it was starting up.
        """
        with self._lock:
            if self._proc and self._proc.poll() is None:
                return

            if not binaries.Aria2:
                debug_logger = get_debug_logger()
                if debug_logger:
                    debug_logger.log(
                        level="ERROR",
                        operation="downloader_aria2c_binary_missing",
                        message="Aria2c executable not found in PATH or local binaries directory",
                        context={"searched_names": ["aria2c", "aria2"]},
                    )
                raise EnvironmentError("Aria2c executable not found...")

            if not max_workers:
                max_workers = min(32, (os.cpu_count() or 1) + 4)
            elif not isinstance(max_workers, int):
                raise TypeError(f"Expected max_workers to be {int}, not {type(max_workers)}")

            self._rpc_port = get_free_port()
            self._rpc_secret = get_random_bytes(16).hex()
            self._rpc_uri = f"http://127.0.0.1:{self._rpc_port}/jsonrpc"

            self._max_concurrent_downloads = int(config.aria2c.get("max_concurrent_downloads", max_workers))
            self._max_connection_per_server = int(config.aria2c.get("max_connection_per_server", 1))
            self._split_default = int(config.aria2c.get("split", 5))
            self._file_allocation = config.aria2c.get("file_allocation", "prealloc")
            self._proxy = proxy or None

            args = self._build_args()
            self._proc = subprocess.Popen(
                [binaries.Aria2, *args], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
            )

            # An RPC-enabled aria2c keeps running until told to stop, so make
            # sure it does not outlive this interpreter.
            if not self._atexit_registered:
                import atexit

                atexit.register(self.shutdown)
                self._atexit_registered = True

            self._wait_until_ready()

    def _wait_until_ready(self) -> None:
        """Block until the RPC port accepts connections or the timeout elapses.

        Replaces a fixed start-up sleep: returns as soon as the port is
        reachable. If the timeout elapses the method returns anyway
        (best-effort), leaving any failure to surface on the first RPC call.

        Raises:
            RuntimeError: aria2c exited before its RPC port became reachable.
        """
        import socket

        proc = self._proc
        port = self._rpc_port
        if proc is None or port is None:
            return

        deadline = time.monotonic() + self._STARTUP_TIMEOUT
        while time.monotonic() < deadline:
            if proc.poll() is not None:
                raise RuntimeError(f"aria2c exited during startup with code {proc.returncode}")
            try:
                with socket.create_connection(("127.0.0.1", port), timeout=0.25):
                    return
            except OSError:
                # Not listening yet; retry shortly.
                time.sleep(self._STARTUP_POLL_INTERVAL)

    def shutdown(self) -> None:
        """Terminate the managed aria2c process, if one is running.

        Safe to call repeatedly. A later :meth:`ensure_started` call starts a
        fresh process.
        """
        with self._lock:
            proc, self._proc = self._proc, None

        if proc is None or proc.poll() is not None:
            return

        proc.terminate()
        try:
            proc.wait(timeout=self._SHUTDOWN_TIMEOUT)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()

    @property
    def rpc_uri(self) -> str:
        """JSON-RPC endpoint URL of the running aria2c instance.

        Raises:
            RuntimeError: No endpoint has been set up yet.
        """
        # Explicit check instead of ``assert`` so it still fires under ``-O``.
        if not self._rpc_uri:
            raise RuntimeError("aria2c RPC endpoint is not available; call ensure_started() first")
        return self._rpc_uri

    @property
    def rpc_secret(self) -> str:
        """Secret token sent with every RPC call.

        Raises:
            RuntimeError: No secret has been generated yet.
        """
        if not self._rpc_secret:
            raise RuntimeError("aria2c RPC secret is not available; call ensure_started() first")
        return self._rpc_secret

    @property
    def session(self) -> "Session":
        """HTTP session reused for all RPC requests."""
        return self._session

    def add_uris(self, uris: list[str], options: dict[str, Any]) -> str:
        """Add a single download with multiple URIs via RPC.

        Returns:
            The GID returned by ``aria2.addUri``, or ``""`` when the RPC call
            returned nothing.
        """
        gid = rpc(
            caller=partial(self._session.post, url=self.rpc_uri),
            secret=self.rpc_secret,
            method="aria2.addUri",
            params=[uris, options],
        )
        return gid or ""

    def get_global_stat(self) -> dict[str, Any]:
        """Return the result of ``aria2.getGlobalStat`` (empty dict if none)."""
        return rpc(
            caller=partial(self.session.post, url=self.rpc_uri),
            secret=self.rpc_secret,
            method="aria2.getGlobalStat",
        ) or {}

    def tell_status(self, gid: str) -> Optional[dict[str, Any]]:
        """Return status, error, file and byte-progress fields for ``gid``."""
        return rpc(
            caller=partial(self.session.post, url=self.rpc_uri),
            secret=self.rpc_secret,
            method="aria2.tellStatus",
            params=[gid, ["status", "errorCode", "errorMessage", "files", "completedLength", "totalLength"]],
        )

    def remove(self, gid: str) -> None:
        """Ask aria2c to force-remove the download identified by ``gid``."""
        rpc(
            caller=partial(self.session.post, url=self.rpc_uri),
            secret=self.rpc_secret,
            method="aria2.forceRemove",
            params=[gid],
        )
# Module-level manager instance used by download() to start aria2c and talk to it over RPC.
_manager = _Aria2Manager()
def download( def download(
urls: Union[str, list[str], dict[str, Any], list[dict[str, Any]]], urls: Union[str, list[str], dict[str, Any], list[dict[str, Any]]],
output_dir: Path, output_dir: Path,
@@ -58,6 +191,7 @@ def download(
proxy: Optional[str] = None, proxy: Optional[str] = None,
max_workers: Optional[int] = None, max_workers: Optional[int] = None,
) -> Generator[dict[str, Any], None, None]: ) -> Generator[dict[str, Any], None, None]:
"""Enqueue downloads to the singleton aria2c instance via stdin and track per-call progress via RPC."""
debug_logger = get_debug_logger() debug_logger = get_debug_logger()
if not urls: if not urls:
@@ -92,102 +226,10 @@ def download(
if not isinstance(urls, list): if not isinstance(urls, list):
urls = [urls] urls = [urls]
if not binaries.Aria2:
if debug_logger:
debug_logger.log(
level="ERROR",
operation="downloader_aria2c_binary_missing",
message="Aria2c executable not found in PATH or local binaries directory",
context={"searched_names": ["aria2c", "aria2"]},
)
raise EnvironmentError("Aria2c executable not found...")
if proxy and not proxy.lower().startswith("http://"):
raise ValueError("Only HTTP proxies are supported by aria2(c)")
if cookies and not isinstance(cookies, CookieJar): if cookies and not isinstance(cookies, CookieJar):
cookies = cookiejar_from_dict(cookies) cookies = cookiejar_from_dict(cookies)
url_files = [] _manager.ensure_started(proxy=proxy, max_workers=max_workers)
for i, url in enumerate(urls):
if isinstance(url, str):
url_data = {"url": url}
else:
url_data: dict[str, Any] = url
url_filename = filename.format(i=i, ext=get_extension(url_data["url"]))
url_text = url_data["url"]
url_text += f"\n\tdir={output_dir}"
url_text += f"\n\tout={url_filename}"
if cookies:
mock_request = requests.Request(url=url_data["url"])
cookie_header = get_cookie_header(cookies, mock_request)
if cookie_header:
url_text += f"\n\theader=Cookie: {cookie_header}"
for key, value in url_data.items():
if key == "url":
continue
if key == "headers":
for header_name, header_value in value.items():
url_text += f"\n\theader={header_name}: {header_value}"
else:
url_text += f"\n\t{key}={value}"
url_files.append(url_text)
url_file = "\n".join(url_files)
rpc_port = get_free_port()
rpc_secret = get_random_bytes(16).hex()
rpc_uri = f"http://127.0.0.1:{rpc_port}/jsonrpc"
rpc_session = Session()
max_concurrent_downloads = int(config.aria2c.get("max_concurrent_downloads", max_workers))
max_connection_per_server = int(config.aria2c.get("max_connection_per_server", 1))
split = int(config.aria2c.get("split", 5))
file_allocation = config.aria2c.get("file_allocation", "prealloc")
if len(urls) > 1:
split = 1
file_allocation = "none"
arguments = [
# [Basic Options]
"--input-file",
"-",
"--all-proxy",
proxy or "",
"--continue=true",
# [Connection Options]
f"--max-concurrent-downloads={max_concurrent_downloads}",
f"--max-connection-per-server={max_connection_per_server}",
f"--split={split}", # each split uses their own connection
"--max-file-not-found=5", # counted towards --max-tries
"--max-tries=5",
"--retry-wait=2",
# [Advanced Options]
"--allow-overwrite=true",
"--auto-file-renaming=false",
"--console-log-level=warn",
"--download-result=default",
f"--file-allocation={file_allocation}",
"--summary-interval=0",
# [RPC Options]
"--enable-rpc=true",
f"--rpc-listen-port={rpc_port}",
f"--rpc-secret={rpc_secret}",
]
for header, value in (headers or {}).items():
if header.lower() == "cookie":
raise ValueError("You cannot set Cookies as a header manually, please use the `cookies` param.")
if header.lower() == "accept-encoding":
# we cannot set an allowed encoding, or it will return compressed
# and the code is not set up to uncompress the data
continue
if header.lower() == "referer":
arguments.extend(["--referer", value])
continue
if header.lower() == "user-agent":
arguments.extend(["--user-agent", value])
continue
arguments.extend(["--header", f"{header}: {value}"])
if debug_logger: if debug_logger:
first_url = urls[0] if isinstance(urls[0], str) else urls[0].get("url", "") first_url = urls[0] if isinstance(urls[0], str) else urls[0].get("url", "")
@@ -202,128 +244,151 @@ def download(
"first_url": url_display, "first_url": url_display,
"output_dir": str(output_dir), "output_dir": str(output_dir),
"filename": filename, "filename": filename,
"max_concurrent_downloads": max_concurrent_downloads,
"max_connection_per_server": max_connection_per_server,
"split": split,
"file_allocation": file_allocation,
"has_proxy": bool(proxy), "has_proxy": bool(proxy),
"rpc_port": rpc_port,
}, },
) )
yield dict(total=len(urls)) # Build options for each URI and add via RPC
gids: list[str] = []
for i, url in enumerate(urls):
if isinstance(url, str):
url_data = {"url": url}
else:
url_data: dict[str, Any] = url
url_filename = filename.format(i=i, ext=get_extension(url_data["url"]))
opts: dict[str, Any] = {
"dir": str(output_dir),
"out": url_filename,
"split": str(1 if len(urls) > 1 else int(config.aria2c.get("split", 5))),
}
# Cookies as header
if cookies:
mock_request = requests.Request(url=url_data["url"])
cookie_header = get_cookie_header(cookies, mock_request)
if cookie_header:
opts.setdefault("header", []).append(f"Cookie: {cookie_header}")
# Global headers
for header, value in (headers or {}).items():
if header.lower() == "cookie":
raise ValueError("You cannot set Cookies as a header manually, please use the `cookies` param.")
if header.lower() == "accept-encoding":
continue
if header.lower() == "referer":
opts["referer"] = str(value)
continue
if header.lower() == "user-agent":
opts["user-agent"] = str(value)
continue
opts.setdefault("header", []).append(f"{header}: {value}")
# Per-url extra args
for key, value in url_data.items():
if key == "url":
continue
if key == "headers":
for header_name, header_value in value.items():
opts.setdefault("header", []).append(f"{header_name}: {header_value}")
else:
opts[key] = str(value)
# Add via RPC
gid = _manager.add_uris([url_data["url"]], opts)
if gid:
gids.append(gid)
yield dict(total=len(gids))
completed: set[str] = set()
try: try:
p = subprocess.Popen([binaries.Aria2, *arguments], stdin=subprocess.PIPE, stdout=subprocess.DEVNULL) while len(completed) < len(gids):
if DOWNLOAD_CANCELLED.is_set():
# Remove tracked downloads on cancel
for gid in gids:
if gid not in completed:
_manager.remove(gid)
yield dict(downloaded="[yellow]CANCELLED")
raise KeyboardInterrupt()
p.stdin.write(url_file.encode()) stats = _manager.get_global_stat()
p.stdin.close() dl_speed = int(stats.get("downloadSpeed", -1))
while p.poll() is None: # Aggregate progress across all GIDs for this call
global_stats: dict[str, Any] = ( total_completed = 0
rpc(caller=partial(rpc_session.post, url=rpc_uri), secret=rpc_secret, method="aria2.getGlobalStat") total_size = 0
or {}
)
number_stopped = int(global_stats.get("numStoppedTotal", 0)) # Check each tracked GID
download_speed = int(global_stats.get("downloadSpeed", -1)) for gid in gids:
if gid in completed:
continue
if number_stopped: status = _manager.tell_status(gid)
yield dict(completed=number_stopped) if not status:
if download_speed != -1: continue
yield dict(downloaded=f"{filesize.decimal(download_speed)}/s")
stopped_downloads: list[dict[str, Any]] = ( completed_length = int(status.get("completedLength", 0))
rpc( total_length = int(status.get("totalLength", 0))
caller=partial(rpc_session.post, url=rpc_uri), total_completed += completed_length
secret=rpc_secret, total_size += total_length
method="aria2.tellStopped",
params=[0, 999999],
)
or []
)
for dl in stopped_downloads: state = status.get("status")
if dl["status"] == "error": if state in ("complete", "error"):
completed.add(gid)
yield dict(completed=len(completed))
if state == "error":
used_uri = None
try:
used_uri = next( used_uri = next(
uri["uri"] uri["uri"]
for file in dl["files"] for file in status.get("files", [])
if file["selected"] == "true" for uri in file.get("uris", [])
for uri in file["uris"] if uri.get("status") == "used"
if uri["status"] == "used"
)
error = f"Download Error (#{dl['gid']}): {dl['errorMessage']} ({dl['errorCode']}), {used_uri}"
error_pretty = "\n ".join(
textwrap.wrap(error, width=console.width - 20, initial_indent="")
) )
except Exception:
used_uri = "unknown"
error = f"Download Error (#{gid}): {status.get('errorMessage')} ({status.get('errorCode')}), {used_uri}"
error_pretty = "\n ".join(textwrap.wrap(error, width=console.width - 20, initial_indent=""))
console.log(Text.from_ansi("\n[Aria2c]: " + error_pretty)) console.log(Text.from_ansi("\n[Aria2c]: " + error_pretty))
if debug_logger: if debug_logger:
debug_logger.log( debug_logger.log(
level="ERROR", level="ERROR",
operation="downloader_aria2c_download_error", operation="downloader_aria2c_download_error",
message=f"Aria2c download failed: {dl['errorMessage']}", message=f"Aria2c download failed: {status.get('errorMessage')}",
context={ context={
"gid": dl["gid"], "gid": gid,
"error_code": dl["errorCode"], "error_code": status.get("errorCode"),
"error_message": dl["errorMessage"], "error_message": status.get("errorMessage"),
"used_uri": used_uri[:200] + "..." if len(used_uri) > 200 else used_uri, "used_uri": used_uri[:200] + "..." if used_uri and len(used_uri) > 200 else used_uri,
"completed_length": dl.get("completedLength"), "completed_length": status.get("completedLength"),
"total_length": dl.get("totalLength"), "total_length": status.get("totalLength"),
}, },
) )
raise ValueError(error) raise ValueError(error)
if number_stopped == len(urls): # Yield aggregate progress for this call's downloads
rpc(caller=partial(rpc_session.post, url=rpc_uri), secret=rpc_secret, method="aria2.shutdown") if total_size > 0:
break # Yield both advance (bytes downloaded this iteration) and total for rich progress
if dl_speed != -1:
yield dict(downloaded=f"{filesize.decimal(dl_speed)}/s", advance=0, completed=total_completed, total=total_size)
else:
yield dict(advance=0, completed=total_completed, total=total_size)
elif dl_speed != -1:
yield dict(downloaded=f"{filesize.decimal(dl_speed)}/s")
time.sleep(1) time.sleep(1)
p.wait()
if p.returncode != 0:
if debug_logger:
debug_logger.log(
level="ERROR",
operation="downloader_aria2c_failed",
message=f"Aria2c exited with code {p.returncode}",
context={
"returncode": p.returncode,
"url_count": len(urls),
"output_dir": str(output_dir),
},
)
raise subprocess.CalledProcessError(p.returncode, arguments)
if debug_logger:
debug_logger.log(
level="DEBUG",
operation="downloader_aria2c_complete",
message="Aria2c download completed successfully",
context={
"url_count": len(urls),
"output_dir": str(output_dir),
"filename": filename,
},
)
except ConnectionResetError:
# interrupted while passing URI to download
raise KeyboardInterrupt()
except subprocess.CalledProcessError as e:
if e.returncode in (7, 0xC000013A):
# 7 is when Aria2(c) handled the CTRL+C
# 0xC000013A is when it never got the chance to
raise KeyboardInterrupt()
raise
except KeyboardInterrupt: except KeyboardInterrupt:
DOWNLOAD_CANCELLED.set() # skip pending track downloads DOWNLOAD_CANCELLED.set()
yield dict(downloaded="[yellow]CANCELLED")
raise raise
except Exception as e: except Exception as e:
DOWNLOAD_CANCELLED.set() # skip pending track downloads DOWNLOAD_CANCELLED.set()
yield dict(downloaded="[red]FAILED") yield dict(downloaded="[red]FAILED")
if debug_logger and not isinstance(e, (subprocess.CalledProcessError, ValueError)): if debug_logger and not isinstance(e, ValueError):
debug_logger.log( debug_logger.log(
level="ERROR", level="ERROR",
operation="downloader_aria2c_exception", operation="downloader_aria2c_exception",
@@ -335,8 +400,6 @@ def download(
}, },
) )
raise raise
finally:
rpc(caller=partial(rpc_session.post, url=rpc_uri), secret=rpc_secret, method="aria2.shutdown")
def aria2c( def aria2c(