perf(aria2c): improve download performance with singleton manager

- Use singleton _Aria2Manager to reuse one aria2c process via RPC
- Add downloads via aria2.addUri instead of stdin input file
- Track per-GID byte-level progress (completedLength/totalLength)
- Add thread-safe operations with threading.Lock
- Enable graceful cancellation by removing individual downloads via RPC
This commit is contained in:
Andy
2026-01-23 17:14:12 -07:00
parent b8e2f3da3f
commit 6b90a19632

View File

@@ -1,6 +1,7 @@
import os
import subprocess
import textwrap
import threading
import time
from functools import partial
from http.cookiejar import CookieJar
@@ -49,6 +50,138 @@ def rpc(caller: Callable, secret: str, method: str, params: Optional[list[Any]]
return
class _Aria2Manager:
    """Singleton manager to run one aria2c process and enqueue downloads via RPC.

    The process is started lazily by :meth:`ensure_started` and is reused for
    as long as it stays alive. The internal lock serialises startup and
    shutdown only; the RPC helper methods do not take it.
    """

    # Seconds to wait after spawning aria2c before its RPC port is assumed bound.
    _STARTUP_GRACE_SECONDS = 0.5
    # Seconds to wait for aria2c to exit after terminate() before killing it.
    _SHUTDOWN_TIMEOUT_SECONDS = 5

    def __init__(self) -> None:
        self._proc: Optional[subprocess.Popen] = None
        self._rpc_port: Optional[int] = None
        self._rpc_secret: Optional[str] = None
        self._rpc_uri: Optional[str] = None
        self._session: Session = Session()
        self._max_concurrent_downloads: int = 0
        self._max_connection_per_server: int = 1
        self._split_default: int = 5
        self._file_allocation: str = "prealloc"
        self._proxy: Optional[str] = None
        self._lock: threading.Lock = threading.Lock()
        # Guards against registering the atexit hook again when aria2c is restarted.
        self._atexit_registered: bool = False

    def _build_args(self) -> list[str]:
        """Return the aria2c command-line arguments for the current settings."""
        args = [
            "--continue=true",
            f"--max-concurrent-downloads={self._max_concurrent_downloads}",
            f"--max-connection-per-server={self._max_connection_per_server}",
            f"--split={self._split_default}",
            "--max-file-not-found=5",
            "--max-tries=5",
            "--retry-wait=2",
            "--allow-overwrite=true",
            "--auto-file-renaming=false",
            "--console-log-level=warn",
            "--download-result=default",
            f"--file-allocation={self._file_allocation}",
            "--summary-interval=0",
            "--enable-rpc=true",
            f"--rpc-listen-port={self._rpc_port}",
            f"--rpc-secret={self._rpc_secret}",
        ]
        if self._proxy:
            args.extend(["--all-proxy", self._proxy])
        return args

    def ensure_started(
        self,
        proxy: Optional[str],
        max_workers: Optional[int],
    ) -> None:
        """Start the shared aria2c process unless one is already running.

        If a live process exists this returns immediately, so ``proxy`` and
        ``max_workers`` only take effect on the call that actually spawns
        aria2c.

        Args:
            proxy: Proxy passed to aria2c as ``--all-proxy``; falsy disables it.
            max_workers: Fallback for ``max_concurrent_downloads`` when the
                config does not set it. Falsy selects
                ``min(32, cpu_count + 4)``.

        Raises:
            EnvironmentError: The aria2c binary could not be located.
            TypeError: ``max_workers`` is truthy but not an ``int``.
            RuntimeError: aria2c exited during the startup grace period.
        """
        import atexit

        with self._lock:
            if self._proc and self._proc.poll() is None:
                return

            if not binaries.Aria2:
                debug_logger = get_debug_logger()
                if debug_logger:
                    debug_logger.log(
                        level="ERROR",
                        operation="downloader_aria2c_binary_missing",
                        message="Aria2c executable not found in PATH or local binaries directory",
                        context={"searched_names": ["aria2c", "aria2"]},
                    )
                raise EnvironmentError("Aria2c executable not found...")

            if not max_workers:
                max_workers = min(32, (os.cpu_count() or 1) + 4)
            elif not isinstance(max_workers, int):
                raise TypeError(f"Expected max_workers to be {int}, not {type(max_workers)}")

            self._rpc_port = get_free_port()
            self._rpc_secret = get_random_bytes(16).hex()
            self._rpc_uri = f"http://127.0.0.1:{self._rpc_port}/jsonrpc"

            self._max_concurrent_downloads = int(config.aria2c.get("max_concurrent_downloads", max_workers))
            self._max_connection_per_server = int(config.aria2c.get("max_connection_per_server", 1))
            self._split_default = int(config.aria2c.get("split", 5))
            self._file_allocation = config.aria2c.get("file_allocation", "prealloc")
            self._proxy = proxy or None

            args = self._build_args()
            self._proc = subprocess.Popen(
                [binaries.Aria2, *args], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
            )

            # Nothing else visible here stops the child, so make sure it does
            # not outlive the interpreter as an orphaned RPC daemon.
            if not self._atexit_registered:
                atexit.register(self.shutdown)
                self._atexit_registered = True

            # Give aria2c a moment to start up and bind to the RPC port
            time.sleep(self._STARTUP_GRACE_SECONDS)

            # Fail here with a clear message rather than later on the first RPC call.
            returncode = self._proc.poll()
            if returncode is not None:
                self._proc = None
                raise RuntimeError(f"aria2c exited during startup with code {returncode}")

    def shutdown(self) -> None:
        """Terminate the managed aria2c process, if one is running.

        Safe to call repeatedly. The process is asked to terminate first and
        is killed only if it has not exited within the shutdown timeout.
        """
        with self._lock:
            proc, self._proc = self._proc, None

        if proc is None or proc.poll() is not None:
            return

        proc.terminate()
        try:
            proc.wait(timeout=self._SHUTDOWN_TIMEOUT_SECONDS)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()

    @property
    def rpc_uri(self) -> str:
        """JSON-RPC endpoint of the running aria2c process.

        Raises:
            RuntimeError: ``ensure_started`` has not set an endpoint yet.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not self._rpc_uri:
            raise RuntimeError("aria2c RPC URI is not set; call ensure_started() first")
        return self._rpc_uri

    @property
    def rpc_secret(self) -> str:
        """Secret token sent with every RPC call.

        Raises:
            RuntimeError: ``ensure_started`` has not generated a secret yet.
        """
        if not self._rpc_secret:
            raise RuntimeError("aria2c RPC secret is not set; call ensure_started() first")
        return self._rpc_secret

    @property
    def session(self) -> "Session":
        """HTTP session reused for every RPC request."""
        return self._session

    def _call(self, method: str, *params: Any) -> Any:
        """Invoke ``method`` on aria2c via ``rpc``; ``params`` is omitted when empty."""
        extra: dict[str, Any] = {"params": list(params)} if params else {}
        return rpc(
            caller=partial(self.session.post, url=self.rpc_uri),
            secret=self.rpc_secret,
            method=method,
            **extra,
        )

    def add_uris(self, uris: list[str], options: dict[str, Any]) -> str:
        """Add a single download with multiple URIs via RPC.

        Returns:
            The GID reported by aria2c, or ``""`` when the call yields nothing.
        """
        gid = self._call("aria2.addUri", uris, options)
        return gid or ""

    def get_global_stat(self) -> dict[str, Any]:
        """Return the result of ``aria2.getGlobalStat``, or ``{}`` when it is empty."""
        return self._call("aria2.getGlobalStat") or {}

    def tell_status(self, gid: str) -> Optional[dict[str, Any]]:
        """Return selected ``aria2.tellStatus`` fields for ``gid``."""
        return self._call(
            "aria2.tellStatus",
            gid,
            ["status", "errorCode", "errorMessage", "files", "completedLength", "totalLength"],
        )

    def remove(self, gid: str) -> None:
        """Request ``aria2.forceRemove`` for the download identified by ``gid``."""
        self._call("aria2.forceRemove", gid)
# Module-level singleton: download() starts aria2c and enqueues its downloads through this one instance.
_manager = _Aria2Manager()
def download(
urls: Union[str, list[str], dict[str, Any], list[dict[str, Any]]],
output_dir: Path,
@@ -58,6 +191,7 @@ def download(
proxy: Optional[str] = None,
max_workers: Optional[int] = None,
) -> Generator[dict[str, Any], None, None]:
"""Enqueue downloads to the singleton aria2c instance via RPC (aria2.addUri) and track per-call progress via RPC."""
debug_logger = get_debug_logger()
if not urls:
@@ -92,102 +226,10 @@ def download(
if not isinstance(urls, list):
urls = [urls]
if not binaries.Aria2:
if debug_logger:
debug_logger.log(
level="ERROR",
operation="downloader_aria2c_binary_missing",
message="Aria2c executable not found in PATH or local binaries directory",
context={"searched_names": ["aria2c", "aria2"]},
)
raise EnvironmentError("Aria2c executable not found...")
if proxy and not proxy.lower().startswith("http://"):
raise ValueError("Only HTTP proxies are supported by aria2(c)")
if cookies and not isinstance(cookies, CookieJar):
cookies = cookiejar_from_dict(cookies)
url_files = []
for i, url in enumerate(urls):
if isinstance(url, str):
url_data = {"url": url}
else:
url_data: dict[str, Any] = url
url_filename = filename.format(i=i, ext=get_extension(url_data["url"]))
url_text = url_data["url"]
url_text += f"\n\tdir={output_dir}"
url_text += f"\n\tout={url_filename}"
if cookies:
mock_request = requests.Request(url=url_data["url"])
cookie_header = get_cookie_header(cookies, mock_request)
if cookie_header:
url_text += f"\n\theader=Cookie: {cookie_header}"
for key, value in url_data.items():
if key == "url":
continue
if key == "headers":
for header_name, header_value in value.items():
url_text += f"\n\theader={header_name}: {header_value}"
else:
url_text += f"\n\t{key}={value}"
url_files.append(url_text)
url_file = "\n".join(url_files)
rpc_port = get_free_port()
rpc_secret = get_random_bytes(16).hex()
rpc_uri = f"http://127.0.0.1:{rpc_port}/jsonrpc"
rpc_session = Session()
max_concurrent_downloads = int(config.aria2c.get("max_concurrent_downloads", max_workers))
max_connection_per_server = int(config.aria2c.get("max_connection_per_server", 1))
split = int(config.aria2c.get("split", 5))
file_allocation = config.aria2c.get("file_allocation", "prealloc")
if len(urls) > 1:
split = 1
file_allocation = "none"
arguments = [
# [Basic Options]
"--input-file",
"-",
"--all-proxy",
proxy or "",
"--continue=true",
# [Connection Options]
f"--max-concurrent-downloads={max_concurrent_downloads}",
f"--max-connection-per-server={max_connection_per_server}",
f"--split={split}", # each split uses their own connection
"--max-file-not-found=5", # counted towards --max-tries
"--max-tries=5",
"--retry-wait=2",
# [Advanced Options]
"--allow-overwrite=true",
"--auto-file-renaming=false",
"--console-log-level=warn",
"--download-result=default",
f"--file-allocation={file_allocation}",
"--summary-interval=0",
# [RPC Options]
"--enable-rpc=true",
f"--rpc-listen-port={rpc_port}",
f"--rpc-secret={rpc_secret}",
]
for header, value in (headers or {}).items():
if header.lower() == "cookie":
raise ValueError("You cannot set Cookies as a header manually, please use the `cookies` param.")
if header.lower() == "accept-encoding":
# we cannot set an allowed encoding, or it will return compressed
# and the code is not set up to uncompress the data
continue
if header.lower() == "referer":
arguments.extend(["--referer", value])
continue
if header.lower() == "user-agent":
arguments.extend(["--user-agent", value])
continue
arguments.extend(["--header", f"{header}: {value}"])
_manager.ensure_started(proxy=proxy, max_workers=max_workers)
if debug_logger:
first_url = urls[0] if isinstance(urls[0], str) else urls[0].get("url", "")
@@ -202,128 +244,151 @@ def download(
"first_url": url_display,
"output_dir": str(output_dir),
"filename": filename,
"max_concurrent_downloads": max_concurrent_downloads,
"max_connection_per_server": max_connection_per_server,
"split": split,
"file_allocation": file_allocation,
"has_proxy": bool(proxy),
"rpc_port": rpc_port,
},
)
yield dict(total=len(urls))
# Build options for each URI and add via RPC
gids: list[str] = []
for i, url in enumerate(urls):
if isinstance(url, str):
url_data = {"url": url}
else:
url_data: dict[str, Any] = url
url_filename = filename.format(i=i, ext=get_extension(url_data["url"]))
opts: dict[str, Any] = {
"dir": str(output_dir),
"out": url_filename,
"split": str(1 if len(urls) > 1 else int(config.aria2c.get("split", 5))),
}
# Cookies as header
if cookies:
mock_request = requests.Request(url=url_data["url"])
cookie_header = get_cookie_header(cookies, mock_request)
if cookie_header:
opts.setdefault("header", []).append(f"Cookie: {cookie_header}")
# Global headers
for header, value in (headers or {}).items():
if header.lower() == "cookie":
raise ValueError("You cannot set Cookies as a header manually, please use the `cookies` param.")
if header.lower() == "accept-encoding":
continue
if header.lower() == "referer":
opts["referer"] = str(value)
continue
if header.lower() == "user-agent":
opts["user-agent"] = str(value)
continue
opts.setdefault("header", []).append(f"{header}: {value}")
# Per-url extra args
for key, value in url_data.items():
if key == "url":
continue
if key == "headers":
for header_name, header_value in value.items():
opts.setdefault("header", []).append(f"{header_name}: {header_value}")
else:
opts[key] = str(value)
# Add via RPC
gid = _manager.add_uris([url_data["url"]], opts)
if gid:
gids.append(gid)
yield dict(total=len(gids))
completed: set[str] = set()
try:
p = subprocess.Popen([binaries.Aria2, *arguments], stdin=subprocess.PIPE, stdout=subprocess.DEVNULL)
while len(completed) < len(gids):
if DOWNLOAD_CANCELLED.is_set():
# Remove tracked downloads on cancel
for gid in gids:
if gid not in completed:
_manager.remove(gid)
yield dict(downloaded="[yellow]CANCELLED")
raise KeyboardInterrupt()
p.stdin.write(url_file.encode())
p.stdin.close()
stats = _manager.get_global_stat()
dl_speed = int(stats.get("downloadSpeed", -1))
while p.poll() is None:
global_stats: dict[str, Any] = (
rpc(caller=partial(rpc_session.post, url=rpc_uri), secret=rpc_secret, method="aria2.getGlobalStat")
or {}
)
# Aggregate progress across all GIDs for this call
total_completed = 0
total_size = 0
number_stopped = int(global_stats.get("numStoppedTotal", 0))
download_speed = int(global_stats.get("downloadSpeed", -1))
# Check each tracked GID
for gid in gids:
if gid in completed:
continue
if number_stopped:
yield dict(completed=number_stopped)
if download_speed != -1:
yield dict(downloaded=f"{filesize.decimal(download_speed)}/s")
status = _manager.tell_status(gid)
if not status:
continue
stopped_downloads: list[dict[str, Any]] = (
rpc(
caller=partial(rpc_session.post, url=rpc_uri),
secret=rpc_secret,
method="aria2.tellStopped",
params=[0, 999999],
)
or []
)
completed_length = int(status.get("completedLength", 0))
total_length = int(status.get("totalLength", 0))
total_completed += completed_length
total_size += total_length
for dl in stopped_downloads:
if dl["status"] == "error":
state = status.get("status")
if state in ("complete", "error"):
completed.add(gid)
yield dict(completed=len(completed))
if state == "error":
used_uri = None
try:
used_uri = next(
uri["uri"]
for file in dl["files"]
if file["selected"] == "true"
for uri in file["uris"]
if uri["status"] == "used"
)
error = f"Download Error (#{dl['gid']}): {dl['errorMessage']} ({dl['errorCode']}), {used_uri}"
error_pretty = "\n ".join(
textwrap.wrap(error, width=console.width - 20, initial_indent="")
for file in status.get("files", [])
for uri in file.get("uris", [])
if uri.get("status") == "used"
)
except Exception:
used_uri = "unknown"
error = f"Download Error (#{gid}): {status.get('errorMessage')} ({status.get('errorCode')}), {used_uri}"
error_pretty = "\n ".join(textwrap.wrap(error, width=console.width - 20, initial_indent=""))
console.log(Text.from_ansi("\n[Aria2c]: " + error_pretty))
if debug_logger:
debug_logger.log(
level="ERROR",
operation="downloader_aria2c_download_error",
message=f"Aria2c download failed: {dl['errorMessage']}",
message=f"Aria2c download failed: {status.get('errorMessage')}",
context={
"gid": dl["gid"],
"error_code": dl["errorCode"],
"error_message": dl["errorMessage"],
"used_uri": used_uri[:200] + "..." if len(used_uri) > 200 else used_uri,
"completed_length": dl.get("completedLength"),
"total_length": dl.get("totalLength"),
"gid": gid,
"error_code": status.get("errorCode"),
"error_message": status.get("errorMessage"),
"used_uri": used_uri[:200] + "..." if used_uri and len(used_uri) > 200 else used_uri,
"completed_length": status.get("completedLength"),
"total_length": status.get("totalLength"),
},
)
raise ValueError(error)
if number_stopped == len(urls):
rpc(caller=partial(rpc_session.post, url=rpc_uri), secret=rpc_secret, method="aria2.shutdown")
break
# Yield aggregate progress for this call's downloads
if total_size > 0:
# Yield both advance (bytes downloaded this iteration) and total for rich progress
if dl_speed != -1:
yield dict(downloaded=f"{filesize.decimal(dl_speed)}/s", advance=0, completed=total_completed, total=total_size)
else:
yield dict(advance=0, completed=total_completed, total=total_size)
elif dl_speed != -1:
yield dict(downloaded=f"{filesize.decimal(dl_speed)}/s")
time.sleep(1)
p.wait()
if p.returncode != 0:
if debug_logger:
debug_logger.log(
level="ERROR",
operation="downloader_aria2c_failed",
message=f"Aria2c exited with code {p.returncode}",
context={
"returncode": p.returncode,
"url_count": len(urls),
"output_dir": str(output_dir),
},
)
raise subprocess.CalledProcessError(p.returncode, arguments)
if debug_logger:
debug_logger.log(
level="DEBUG",
operation="downloader_aria2c_complete",
message="Aria2c download completed successfully",
context={
"url_count": len(urls),
"output_dir": str(output_dir),
"filename": filename,
},
)
except ConnectionResetError:
# interrupted while passing URI to download
raise KeyboardInterrupt()
except subprocess.CalledProcessError as e:
if e.returncode in (7, 0xC000013A):
# 7 is when Aria2(c) handled the CTRL+C
# 0xC000013A is when it never got the chance to
raise KeyboardInterrupt()
raise
except KeyboardInterrupt:
DOWNLOAD_CANCELLED.set() # skip pending track downloads
yield dict(downloaded="[yellow]CANCELLED")
DOWNLOAD_CANCELLED.set()
raise
except Exception as e:
DOWNLOAD_CANCELLED.set() # skip pending track downloads
DOWNLOAD_CANCELLED.set()
yield dict(downloaded="[red]FAILED")
if debug_logger and not isinstance(e, (subprocess.CalledProcessError, ValueError)):
if debug_logger and not isinstance(e, ValueError):
debug_logger.log(
level="ERROR",
operation="downloader_aria2c_exception",
@@ -335,8 +400,6 @@ def download(
},
)
raise
finally:
rpc(caller=partial(rpc_session.post, url=rpc_uri), secret=rpc_secret, method="aria2.shutdown")
def aria2c(