forked from kenzuya/unshackle
perf(aria2c): improve download performance with singleton manager
- Use singleton _Aria2Manager to reuse one aria2c process via RPC - Add downloads via aria2.addUri instead of stdin input file - Track per-GID byte-level progress (completedLength/totalLength) - Add thread-safe operations with threading.Lock - Enable graceful cancellation by removing individual downloads via RPC
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import subprocess
|
||||
import textwrap
|
||||
import threading
|
||||
import time
|
||||
from functools import partial
|
||||
from http.cookiejar import CookieJar
|
||||
@@ -49,6 +50,138 @@ def rpc(caller: Callable, secret: str, method: str, params: Optional[list[Any]]
|
||||
return
|
||||
|
||||
|
||||
class _Aria2Manager:
|
||||
"""Singleton manager to run one aria2c process and enqueue downloads via RPC."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._proc: Optional[subprocess.Popen] = None
|
||||
self._rpc_port: Optional[int] = None
|
||||
self._rpc_secret: Optional[str] = None
|
||||
self._rpc_uri: Optional[str] = None
|
||||
self._session: Session = Session()
|
||||
self._max_concurrent_downloads: int = 0
|
||||
self._max_connection_per_server: int = 1
|
||||
self._split_default: int = 5
|
||||
self._file_allocation: str = "prealloc"
|
||||
self._proxy: Optional[str] = None
|
||||
self._lock: threading.Lock = threading.Lock()
|
||||
|
||||
def _build_args(self) -> list[str]:
|
||||
args = [
|
||||
"--continue=true",
|
||||
f"--max-concurrent-downloads={self._max_concurrent_downloads}",
|
||||
f"--max-connection-per-server={self._max_connection_per_server}",
|
||||
f"--split={self._split_default}",
|
||||
"--max-file-not-found=5",
|
||||
"--max-tries=5",
|
||||
"--retry-wait=2",
|
||||
"--allow-overwrite=true",
|
||||
"--auto-file-renaming=false",
|
||||
"--console-log-level=warn",
|
||||
"--download-result=default",
|
||||
f"--file-allocation={self._file_allocation}",
|
||||
"--summary-interval=0",
|
||||
"--enable-rpc=true",
|
||||
f"--rpc-listen-port={self._rpc_port}",
|
||||
f"--rpc-secret={self._rpc_secret}",
|
||||
]
|
||||
if self._proxy:
|
||||
args.extend(["--all-proxy", self._proxy])
|
||||
return args
|
||||
|
||||
def ensure_started(
|
||||
self,
|
||||
proxy: Optional[str],
|
||||
max_workers: Optional[int],
|
||||
) -> None:
|
||||
with self._lock:
|
||||
if self._proc and self._proc.poll() is None:
|
||||
return
|
||||
|
||||
if not binaries.Aria2:
|
||||
debug_logger = get_debug_logger()
|
||||
if debug_logger:
|
||||
debug_logger.log(
|
||||
level="ERROR",
|
||||
operation="downloader_aria2c_binary_missing",
|
||||
message="Aria2c executable not found in PATH or local binaries directory",
|
||||
context={"searched_names": ["aria2c", "aria2"]},
|
||||
)
|
||||
raise EnvironmentError("Aria2c executable not found...")
|
||||
|
||||
if not max_workers:
|
||||
max_workers = min(32, (os.cpu_count() or 1) + 4)
|
||||
elif not isinstance(max_workers, int):
|
||||
raise TypeError(f"Expected max_workers to be {int}, not {type(max_workers)}")
|
||||
|
||||
self._rpc_port = get_free_port()
|
||||
self._rpc_secret = get_random_bytes(16).hex()
|
||||
self._rpc_uri = f"http://127.0.0.1:{self._rpc_port}/jsonrpc"
|
||||
|
||||
self._max_concurrent_downloads = int(config.aria2c.get("max_concurrent_downloads", max_workers))
|
||||
self._max_connection_per_server = int(config.aria2c.get("max_connection_per_server", 1))
|
||||
self._split_default = int(config.aria2c.get("split", 5))
|
||||
self._file_allocation = config.aria2c.get("file_allocation", "prealloc")
|
||||
self._proxy = proxy or None
|
||||
|
||||
args = self._build_args()
|
||||
self._proc = subprocess.Popen(
|
||||
[binaries.Aria2, *args], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
|
||||
)
|
||||
# Give aria2c a moment to start up and bind to the RPC port
|
||||
time.sleep(0.5)
|
||||
|
||||
@property
|
||||
def rpc_uri(self) -> str:
|
||||
assert self._rpc_uri
|
||||
return self._rpc_uri
|
||||
|
||||
@property
|
||||
def rpc_secret(self) -> str:
|
||||
assert self._rpc_secret
|
||||
return self._rpc_secret
|
||||
|
||||
@property
|
||||
def session(self) -> Session:
|
||||
return self._session
|
||||
|
||||
def add_uris(self, uris: list[str], options: dict[str, Any]) -> str:
|
||||
"""Add a single download with multiple URIs via RPC."""
|
||||
gid = rpc(
|
||||
caller=partial(self._session.post, url=self.rpc_uri),
|
||||
secret=self.rpc_secret,
|
||||
method="aria2.addUri",
|
||||
params=[uris, options],
|
||||
)
|
||||
return gid or ""
|
||||
|
||||
def get_global_stat(self) -> dict[str, Any]:
|
||||
return rpc(
|
||||
caller=partial(self.session.post, url=self.rpc_uri),
|
||||
secret=self.rpc_secret,
|
||||
method="aria2.getGlobalStat",
|
||||
) or {}
|
||||
|
||||
def tell_status(self, gid: str) -> Optional[dict[str, Any]]:
|
||||
return rpc(
|
||||
caller=partial(self.session.post, url=self.rpc_uri),
|
||||
secret=self.rpc_secret,
|
||||
method="aria2.tellStatus",
|
||||
params=[gid, ["status", "errorCode", "errorMessage", "files", "completedLength", "totalLength"]],
|
||||
)
|
||||
|
||||
def remove(self, gid: str) -> None:
|
||||
rpc(
|
||||
caller=partial(self.session.post, url=self.rpc_uri),
|
||||
secret=self.rpc_secret,
|
||||
method="aria2.forceRemove",
|
||||
params=[gid],
|
||||
)
|
||||
|
||||
|
||||
_manager = _Aria2Manager()
|
||||
|
||||
|
||||
def download(
|
||||
urls: Union[str, list[str], dict[str, Any], list[dict[str, Any]]],
|
||||
output_dir: Path,
|
||||
@@ -58,6 +191,7 @@ def download(
|
||||
proxy: Optional[str] = None,
|
||||
max_workers: Optional[int] = None,
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
"""Enqueue downloads to the singleton aria2c instance via stdin and track per-call progress via RPC."""
|
||||
debug_logger = get_debug_logger()
|
||||
|
||||
if not urls:
|
||||
@@ -92,102 +226,10 @@ def download(
|
||||
if not isinstance(urls, list):
|
||||
urls = [urls]
|
||||
|
||||
if not binaries.Aria2:
|
||||
if debug_logger:
|
||||
debug_logger.log(
|
||||
level="ERROR",
|
||||
operation="downloader_aria2c_binary_missing",
|
||||
message="Aria2c executable not found in PATH or local binaries directory",
|
||||
context={"searched_names": ["aria2c", "aria2"]},
|
||||
)
|
||||
raise EnvironmentError("Aria2c executable not found...")
|
||||
|
||||
if proxy and not proxy.lower().startswith("http://"):
|
||||
raise ValueError("Only HTTP proxies are supported by aria2(c)")
|
||||
|
||||
if cookies and not isinstance(cookies, CookieJar):
|
||||
cookies = cookiejar_from_dict(cookies)
|
||||
|
||||
url_files = []
|
||||
for i, url in enumerate(urls):
|
||||
if isinstance(url, str):
|
||||
url_data = {"url": url}
|
||||
else:
|
||||
url_data: dict[str, Any] = url
|
||||
url_filename = filename.format(i=i, ext=get_extension(url_data["url"]))
|
||||
url_text = url_data["url"]
|
||||
url_text += f"\n\tdir={output_dir}"
|
||||
url_text += f"\n\tout={url_filename}"
|
||||
if cookies:
|
||||
mock_request = requests.Request(url=url_data["url"])
|
||||
cookie_header = get_cookie_header(cookies, mock_request)
|
||||
if cookie_header:
|
||||
url_text += f"\n\theader=Cookie: {cookie_header}"
|
||||
for key, value in url_data.items():
|
||||
if key == "url":
|
||||
continue
|
||||
if key == "headers":
|
||||
for header_name, header_value in value.items():
|
||||
url_text += f"\n\theader={header_name}: {header_value}"
|
||||
else:
|
||||
url_text += f"\n\t{key}={value}"
|
||||
url_files.append(url_text)
|
||||
url_file = "\n".join(url_files)
|
||||
|
||||
rpc_port = get_free_port()
|
||||
rpc_secret = get_random_bytes(16).hex()
|
||||
rpc_uri = f"http://127.0.0.1:{rpc_port}/jsonrpc"
|
||||
rpc_session = Session()
|
||||
|
||||
max_concurrent_downloads = int(config.aria2c.get("max_concurrent_downloads", max_workers))
|
||||
max_connection_per_server = int(config.aria2c.get("max_connection_per_server", 1))
|
||||
split = int(config.aria2c.get("split", 5))
|
||||
file_allocation = config.aria2c.get("file_allocation", "prealloc")
|
||||
if len(urls) > 1:
|
||||
split = 1
|
||||
file_allocation = "none"
|
||||
|
||||
arguments = [
|
||||
# [Basic Options]
|
||||
"--input-file",
|
||||
"-",
|
||||
"--all-proxy",
|
||||
proxy or "",
|
||||
"--continue=true",
|
||||
# [Connection Options]
|
||||
f"--max-concurrent-downloads={max_concurrent_downloads}",
|
||||
f"--max-connection-per-server={max_connection_per_server}",
|
||||
f"--split={split}", # each split uses their own connection
|
||||
"--max-file-not-found=5", # counted towards --max-tries
|
||||
"--max-tries=5",
|
||||
"--retry-wait=2",
|
||||
# [Advanced Options]
|
||||
"--allow-overwrite=true",
|
||||
"--auto-file-renaming=false",
|
||||
"--console-log-level=warn",
|
||||
"--download-result=default",
|
||||
f"--file-allocation={file_allocation}",
|
||||
"--summary-interval=0",
|
||||
# [RPC Options]
|
||||
"--enable-rpc=true",
|
||||
f"--rpc-listen-port={rpc_port}",
|
||||
f"--rpc-secret={rpc_secret}",
|
||||
]
|
||||
|
||||
for header, value in (headers or {}).items():
|
||||
if header.lower() == "cookie":
|
||||
raise ValueError("You cannot set Cookies as a header manually, please use the `cookies` param.")
|
||||
if header.lower() == "accept-encoding":
|
||||
# we cannot set an allowed encoding, or it will return compressed
|
||||
# and the code is not set up to uncompress the data
|
||||
continue
|
||||
if header.lower() == "referer":
|
||||
arguments.extend(["--referer", value])
|
||||
continue
|
||||
if header.lower() == "user-agent":
|
||||
arguments.extend(["--user-agent", value])
|
||||
continue
|
||||
arguments.extend(["--header", f"{header}: {value}"])
|
||||
_manager.ensure_started(proxy=proxy, max_workers=max_workers)
|
||||
|
||||
if debug_logger:
|
||||
first_url = urls[0] if isinstance(urls[0], str) else urls[0].get("url", "")
|
||||
@@ -202,128 +244,151 @@ def download(
|
||||
"first_url": url_display,
|
||||
"output_dir": str(output_dir),
|
||||
"filename": filename,
|
||||
"max_concurrent_downloads": max_concurrent_downloads,
|
||||
"max_connection_per_server": max_connection_per_server,
|
||||
"split": split,
|
||||
"file_allocation": file_allocation,
|
||||
"has_proxy": bool(proxy),
|
||||
"rpc_port": rpc_port,
|
||||
},
|
||||
)
|
||||
|
||||
yield dict(total=len(urls))
|
||||
# Build options for each URI and add via RPC
|
||||
gids: list[str] = []
|
||||
|
||||
for i, url in enumerate(urls):
|
||||
if isinstance(url, str):
|
||||
url_data = {"url": url}
|
||||
else:
|
||||
url_data: dict[str, Any] = url
|
||||
|
||||
url_filename = filename.format(i=i, ext=get_extension(url_data["url"]))
|
||||
|
||||
opts: dict[str, Any] = {
|
||||
"dir": str(output_dir),
|
||||
"out": url_filename,
|
||||
"split": str(1 if len(urls) > 1 else int(config.aria2c.get("split", 5))),
|
||||
}
|
||||
|
||||
# Cookies as header
|
||||
if cookies:
|
||||
mock_request = requests.Request(url=url_data["url"])
|
||||
cookie_header = get_cookie_header(cookies, mock_request)
|
||||
if cookie_header:
|
||||
opts.setdefault("header", []).append(f"Cookie: {cookie_header}")
|
||||
|
||||
# Global headers
|
||||
for header, value in (headers or {}).items():
|
||||
if header.lower() == "cookie":
|
||||
raise ValueError("You cannot set Cookies as a header manually, please use the `cookies` param.")
|
||||
if header.lower() == "accept-encoding":
|
||||
continue
|
||||
if header.lower() == "referer":
|
||||
opts["referer"] = str(value)
|
||||
continue
|
||||
if header.lower() == "user-agent":
|
||||
opts["user-agent"] = str(value)
|
||||
continue
|
||||
opts.setdefault("header", []).append(f"{header}: {value}")
|
||||
|
||||
# Per-url extra args
|
||||
for key, value in url_data.items():
|
||||
if key == "url":
|
||||
continue
|
||||
if key == "headers":
|
||||
for header_name, header_value in value.items():
|
||||
opts.setdefault("header", []).append(f"{header_name}: {header_value}")
|
||||
else:
|
||||
opts[key] = str(value)
|
||||
|
||||
# Add via RPC
|
||||
gid = _manager.add_uris([url_data["url"]], opts)
|
||||
if gid:
|
||||
gids.append(gid)
|
||||
|
||||
yield dict(total=len(gids))
|
||||
|
||||
completed: set[str] = set()
|
||||
|
||||
try:
|
||||
p = subprocess.Popen([binaries.Aria2, *arguments], stdin=subprocess.PIPE, stdout=subprocess.DEVNULL)
|
||||
while len(completed) < len(gids):
|
||||
if DOWNLOAD_CANCELLED.is_set():
|
||||
# Remove tracked downloads on cancel
|
||||
for gid in gids:
|
||||
if gid not in completed:
|
||||
_manager.remove(gid)
|
||||
yield dict(downloaded="[yellow]CANCELLED")
|
||||
raise KeyboardInterrupt()
|
||||
|
||||
p.stdin.write(url_file.encode())
|
||||
p.stdin.close()
|
||||
stats = _manager.get_global_stat()
|
||||
dl_speed = int(stats.get("downloadSpeed", -1))
|
||||
|
||||
while p.poll() is None:
|
||||
global_stats: dict[str, Any] = (
|
||||
rpc(caller=partial(rpc_session.post, url=rpc_uri), secret=rpc_secret, method="aria2.getGlobalStat")
|
||||
or {}
|
||||
)
|
||||
# Aggregate progress across all GIDs for this call
|
||||
total_completed = 0
|
||||
total_size = 0
|
||||
|
||||
number_stopped = int(global_stats.get("numStoppedTotal", 0))
|
||||
download_speed = int(global_stats.get("downloadSpeed", -1))
|
||||
# Check each tracked GID
|
||||
for gid in gids:
|
||||
if gid in completed:
|
||||
continue
|
||||
|
||||
if number_stopped:
|
||||
yield dict(completed=number_stopped)
|
||||
if download_speed != -1:
|
||||
yield dict(downloaded=f"{filesize.decimal(download_speed)}/s")
|
||||
status = _manager.tell_status(gid)
|
||||
if not status:
|
||||
continue
|
||||
|
||||
stopped_downloads: list[dict[str, Any]] = (
|
||||
rpc(
|
||||
caller=partial(rpc_session.post, url=rpc_uri),
|
||||
secret=rpc_secret,
|
||||
method="aria2.tellStopped",
|
||||
params=[0, 999999],
|
||||
)
|
||||
or []
|
||||
)
|
||||
completed_length = int(status.get("completedLength", 0))
|
||||
total_length = int(status.get("totalLength", 0))
|
||||
total_completed += completed_length
|
||||
total_size += total_length
|
||||
|
||||
for dl in stopped_downloads:
|
||||
if dl["status"] == "error":
|
||||
state = status.get("status")
|
||||
if state in ("complete", "error"):
|
||||
completed.add(gid)
|
||||
yield dict(completed=len(completed))
|
||||
|
||||
if state == "error":
|
||||
used_uri = None
|
||||
try:
|
||||
used_uri = next(
|
||||
uri["uri"]
|
||||
for file in dl["files"]
|
||||
if file["selected"] == "true"
|
||||
for uri in file["uris"]
|
||||
if uri["status"] == "used"
|
||||
)
|
||||
error = f"Download Error (#{dl['gid']}): {dl['errorMessage']} ({dl['errorCode']}), {used_uri}"
|
||||
error_pretty = "\n ".join(
|
||||
textwrap.wrap(error, width=console.width - 20, initial_indent="")
|
||||
for file in status.get("files", [])
|
||||
for uri in file.get("uris", [])
|
||||
if uri.get("status") == "used"
|
||||
)
|
||||
except Exception:
|
||||
used_uri = "unknown"
|
||||
error = f"Download Error (#{gid}): {status.get('errorMessage')} ({status.get('errorCode')}), {used_uri}"
|
||||
error_pretty = "\n ".join(textwrap.wrap(error, width=console.width - 20, initial_indent=""))
|
||||
console.log(Text.from_ansi("\n[Aria2c]: " + error_pretty))
|
||||
if debug_logger:
|
||||
debug_logger.log(
|
||||
level="ERROR",
|
||||
operation="downloader_aria2c_download_error",
|
||||
message=f"Aria2c download failed: {dl['errorMessage']}",
|
||||
message=f"Aria2c download failed: {status.get('errorMessage')}",
|
||||
context={
|
||||
"gid": dl["gid"],
|
||||
"error_code": dl["errorCode"],
|
||||
"error_message": dl["errorMessage"],
|
||||
"used_uri": used_uri[:200] + "..." if len(used_uri) > 200 else used_uri,
|
||||
"completed_length": dl.get("completedLength"),
|
||||
"total_length": dl.get("totalLength"),
|
||||
"gid": gid,
|
||||
"error_code": status.get("errorCode"),
|
||||
"error_message": status.get("errorMessage"),
|
||||
"used_uri": used_uri[:200] + "..." if used_uri and len(used_uri) > 200 else used_uri,
|
||||
"completed_length": status.get("completedLength"),
|
||||
"total_length": status.get("totalLength"),
|
||||
},
|
||||
)
|
||||
raise ValueError(error)
|
||||
|
||||
if number_stopped == len(urls):
|
||||
rpc(caller=partial(rpc_session.post, url=rpc_uri), secret=rpc_secret, method="aria2.shutdown")
|
||||
break
|
||||
# Yield aggregate progress for this call's downloads
|
||||
if total_size > 0:
|
||||
# Yield both advance (bytes downloaded this iteration) and total for rich progress
|
||||
if dl_speed != -1:
|
||||
yield dict(downloaded=f"{filesize.decimal(dl_speed)}/s", advance=0, completed=total_completed, total=total_size)
|
||||
else:
|
||||
yield dict(advance=0, completed=total_completed, total=total_size)
|
||||
elif dl_speed != -1:
|
||||
yield dict(downloaded=f"{filesize.decimal(dl_speed)}/s")
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
p.wait()
|
||||
|
||||
if p.returncode != 0:
|
||||
if debug_logger:
|
||||
debug_logger.log(
|
||||
level="ERROR",
|
||||
operation="downloader_aria2c_failed",
|
||||
message=f"Aria2c exited with code {p.returncode}",
|
||||
context={
|
||||
"returncode": p.returncode,
|
||||
"url_count": len(urls),
|
||||
"output_dir": str(output_dir),
|
||||
},
|
||||
)
|
||||
raise subprocess.CalledProcessError(p.returncode, arguments)
|
||||
|
||||
if debug_logger:
|
||||
debug_logger.log(
|
||||
level="DEBUG",
|
||||
operation="downloader_aria2c_complete",
|
||||
message="Aria2c download completed successfully",
|
||||
context={
|
||||
"url_count": len(urls),
|
||||
"output_dir": str(output_dir),
|
||||
"filename": filename,
|
||||
},
|
||||
)
|
||||
|
||||
except ConnectionResetError:
|
||||
# interrupted while passing URI to download
|
||||
raise KeyboardInterrupt()
|
||||
except subprocess.CalledProcessError as e:
|
||||
if e.returncode in (7, 0xC000013A):
|
||||
# 7 is when Aria2(c) handled the CTRL+C
|
||||
# 0xC000013A is when it never got the chance to
|
||||
raise KeyboardInterrupt()
|
||||
raise
|
||||
except KeyboardInterrupt:
|
||||
DOWNLOAD_CANCELLED.set() # skip pending track downloads
|
||||
yield dict(downloaded="[yellow]CANCELLED")
|
||||
DOWNLOAD_CANCELLED.set()
|
||||
raise
|
||||
except Exception as e:
|
||||
DOWNLOAD_CANCELLED.set() # skip pending track downloads
|
||||
DOWNLOAD_CANCELLED.set()
|
||||
yield dict(downloaded="[red]FAILED")
|
||||
if debug_logger and not isinstance(e, (subprocess.CalledProcessError, ValueError)):
|
||||
if debug_logger and not isinstance(e, ValueError):
|
||||
debug_logger.log(
|
||||
level="ERROR",
|
||||
operation="downloader_aria2c_exception",
|
||||
@@ -335,8 +400,6 @@ def download(
|
||||
},
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
rpc(caller=partial(rpc_session.post, url=rpc_uri), secret=rpc_secret, method="aria2.shutdown")
|
||||
|
||||
|
||||
def aria2c(
|
||||
|
||||
Reference in New Issue
Block a user