unshackle/unshackle/core/downloaders/requests.py
imSp4rky 8bdb942234 feat(dl): add download resume support via HTTP Range headers
Partial downloads are now preserved across interruptions and retries. When a control file and partial data exist, the downloader sends a Range header to resume from where the partial file left off, and falls back to a full re-download if the server does not answer with 206 Partial Content.
2026-04-12 11:40:15 -06:00
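
In outline, the resume handshake this commit relies on looks like the following standalone sketch. The URL, file path and offset are placeholder values, not code from the module below:

import requests

resume_offset = 1_048_576  # bytes already written to disk (placeholder value)
resp = requests.get(
    "https://example.com/video.bin",  # placeholder URL
    headers={"Range": f"bytes={resume_offset}-"},
    stream=True,
)
if resp.status_code == 206:
    mode = "ab"  # server honoured the Range header: append the remaining bytes
else:
    mode = "wb"  # no 206 Partial Content: the server sent the full body, start over
with open("video.partial", mode) as f:  # placeholder path
    for chunk in resp.iter_content(chunk_size=524_288):
        f.write(chunk)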


import math
import os
import time
from concurrent.futures import FIRST_COMPLETED, wait
from concurrent.futures.thread import ThreadPoolExecutor
from http.cookiejar import CookieJar
from pathlib import Path
from queue import Empty, Queue
from typing import Any, Generator, MutableMapping, Optional, Union
from requests import Session
from requests.adapters import HTTPAdapter
from rich import filesize
from unshackle.core.constants import DOWNLOAD_CANCELLED
from unshackle.core.utilities import get_debug_logger, get_extension
MAX_ATTEMPTS = 5
RETRY_WAIT = 2
PROGRESS_WINDOW = 2
# Adaptive chunk sizing — benchmarked optimal range
MIN_CHUNK = 524_288 # 512KB
MAX_CHUNK = 4_194_304 # 4MB
DEFAULT_CHUNK = 524_288 # 512KB
SPEED_ROLLING_WINDOW = 10 # seconds of history to keep for speed calculation
def _adaptive_chunk_size(content_length: int) -> int:
"""Pick chunk size based on content length. Benchmarked sweet spot: 512KB-4MB."""
if content_length <= 0:
return DEFAULT_CHUNK
return min(MAX_CHUNK, max(MIN_CHUNK, content_length // 4))
def _is_requests_session(session: Any) -> bool:
"""Check if the session is a standard requests.Session (supports resp.raw)."""
return isinstance(session, Session)
def _is_rnet_session(session: Any) -> bool:
"""Check if the session is an RnetSession (uses resp.stream())."""
from unshackle.core.session import RnetSession
return isinstance(session, RnetSession)
def download(
url: str,
save_path: Path,
session: Optional[Any] = None,
segmented: bool = False,
**kwargs: Any,
) -> Generator[dict[str, Any], None, None]:
"""
Download a file with optimized I/O.
Supports both requests.Session and RnetSession for TLS fingerprinting.
Uses raw socket reads for requests.Session and native rnet streaming for RnetSession.
    Yields the following download status updates while the file is downloading:
    - {total: 123} (the file is 123 bytes in total)
    - {total: None} (the total size is unknown, e.g. no Content-Length or a compressed response)
    - {advance: 1024} (1,024 more bytes have been written)
    - {downloaded: "10.1 MB/s"} (currently downloading at a rate of 10.1 MB/s)
    - {file_downloaded: Path(...), written: 1024} (download finished, has the save path and size)
    When `segmented` is True, the byte-level total/advance/speed updates are suppressed and a
    single {advance: 1} is yielded once the segment completes.
Parameters:
url: Web URL of a file to download.
save_path: The path to save the file to. If the save path's directory does not
exist then it will be made automatically.
session: A requests.Session or RnetSession to make HTTP requests with.
RnetSession preserves TLS fingerprinting for services that need it.
        segmented: Whether this download is a segment of one bigger file. When True,
            byte-level progress updates are suppressed in favour of one advance per segment.
        kwargs: Any extra keyword arguments to pass to the session.get() call. Use this
            for one-time request changes like a header, cookie, or proxy. For example,
            to request a byte range use e.g. `headers={"Range": "bytes=0-128"}`. Note that
            any caller-supplied Range header is overridden when a partial download is resumed.
"""
session = session or Session()
save_dir = save_path.parent
control_file = save_path.with_name(f"{save_path.name}.!dev")
save_dir.mkdir(parents=True, exist_ok=True)
resume_offset = 0
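    # Three possible starting states:
    #   control file + partial data -> resume from the bytes already on disk
    #   control file only           -> stale marker from a failed start, discard it
    #   data only, no control file  -> a previous run already finished, report it and return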
if control_file.exists() and save_path.exists():
resume_offset = save_path.stat().st_size
elif control_file.exists():
control_file.unlink()
elif save_path.exists():
yield dict(file_downloaded=save_path, written=save_path.stat().st_size)
return
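    # Create (or reset) the marker; while it exists, the data file is considered incomplete.
    # It is removed in the finally block below once the download completes.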
control_file.write_bytes(b"")
_time = time.time
use_raw = _is_requests_session(session)
attempts = 1
completed = False
try:
while True:
written = 0
last_speed_refresh = _time()
try:
use_rnet = _is_rnet_session(session)
request_kwargs = dict(kwargs)
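                # Resuming: ask the server for only the bytes we do not have yet.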
if resume_offset > 0:
req_headers = dict(request_kwargs.get("headers", {}) or {})
req_headers["Range"] = f"bytes={resume_offset}-"
request_kwargs["headers"] = req_headers
stream = session.get(url, stream=True, **request_kwargs)
stream.raise_for_status()
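                # Only a 206 Partial Content response proves the Range was honoured; on a
                # plain 200 the server sent the whole body, so restart from byte zero.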
resumed = resume_offset > 0 and stream.status_code == 206
if resume_offset > 0 and not resumed:
resume_offset = 0
if use_rnet:
content_length = stream.content_length or 0
else:
try:
content_length = int(stream.headers.get("Content-Length", "0"))
if stream.headers.get("Content-Encoding", "").lower() in ["gzip", "deflate", "br"]:
content_length = 0
except ValueError:
content_length = 0
chunk_size = _adaptive_chunk_size(content_length)
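                # On a resumed transfer, Content-Length covers only the remaining tail, so the
                # progress total has to add back the bytes already on disk.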
total_size = (resume_offset + content_length) if resumed and content_length > 0 else content_length
if not segmented:
if total_size > 0:
yield dict(total=total_size)
else:
yield dict(total=None)
if resumed and resume_offset > 0:
yield dict(advance=resume_offset)
file_mode = "ab" if resumed else "wb"
with open(save_path, file_mode, buffering=1_048_576) as f:
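                    # Preallocate the final size, then rewind; if the transfer ends short,
                    # the file is trimmed back to the bytes actually written further below.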
if not resumed and content_length > 0:
f.truncate(content_length)
f.seek(0)
_write = f.write
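                    # Pick a chunk source: rnet's native stream, raw socket reads for
                    # requests.Session, or iter_content() as the generic fallback.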
if use_rnet:
chunks = stream.stream()
elif use_raw:
_read = stream.raw.read
def _chunks() -> Generator[bytes, None, None]:
while True:
chunk = _read(chunk_size)
if not chunk:
break
yield chunk
stream.close()
chunks = _chunks()
else:
def _chunks_iter() -> Generator[bytes, None, None]:
yield from stream.iter_content(chunk_size=chunk_size)
stream.close()
chunks = _chunks_iter()
_data_accumulated = 0
_bytes_since_yield = 0
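                    # Progress updates are batched: advance/speed events are emitted at most
                    # once per PROGRESS_WINDOW seconds to avoid flooding the progress bar.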
for chunk in chunks:
if DOWNLOAD_CANCELLED.is_set():
break
_write(chunk)
download_size = len(chunk)
written += download_size
if not segmented:
_bytes_since_yield += download_size
_data_accumulated += download_size
now = _time()
time_since = now - last_speed_refresh
if time_since > PROGRESS_WINDOW:
yield dict(advance=_bytes_since_yield)
_bytes_since_yield = 0
download_speed = math.ceil(_data_accumulated / (time_since or 1))
yield dict(downloaded=f"{filesize.decimal(download_speed)}/s")
last_speed_refresh = now
_data_accumulated = 0
if not segmented and _bytes_since_yield > 0:
yield dict(advance=_bytes_since_yield)
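                    # The file was preallocated above; trim it back if the transfer ended short.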
if not resumed and content_length > 0 and written != content_length:
f.truncate(written)
if not segmented and content_length and written < content_length:
raise IOError(f"Failed to read {content_length} bytes from the track URI.")
yield dict(file_downloaded=save_path, written=resume_offset + written)
if segmented:
yield dict(advance=1)
completed = True
break
            except Exception:
                if DOWNLOAD_CANCELLED.is_set():
                    return
                if attempts == MAX_ATTEMPTS:
                    # Out of retries: re-raise so the caller sees the failure
                    # instead of a silently incomplete file.
                    raise
if save_path.exists():
resume_offset = save_path.stat().st_size
time.sleep(RETRY_WAIT)
attempts += 1
finally:
if completed:
control_file.unlink(missing_ok=True)
def requests(
urls: Union[str, list[str], dict[str, Any], list[dict[str, Any]]],
output_dir: Path,
filename: str,
headers: Optional[MutableMapping[str, Union[str, bytes]]] = None,
cookies: Optional[Union[MutableMapping[str, str], CookieJar]] = None,
proxy: Optional[str] = None,
max_workers: Optional[int] = None,
session: Optional[Any] = None,
) -> Generator[dict[str, Any], None, None]:
"""
Download files with optimized I/O and adaptive chunk sizing.
Supports both requests.Session and RnetSession. When a RnetSession is
provided (e.g. from a service's get_session()), TLS fingerprinting is preserved
on all segment downloads.
    Yields the following download status updates while the files are downloading:
    - {total: 123} (single URL only: the file is 123 bytes in total)
    - {total: None} (single URL only: the total size is unknown)
    - {advance: N} (single URL: N more bytes were written; multiple URLs: N more segments finished)
    - {downloaded: "10.1 MB/s"} (currently downloading at a rate of 10.1 MB/s)
    - {file_downloaded: Path(...), written: 1024} (a file finished, has the save path and size)
    The data is in the same format accepted by rich's progress.update() function.
    However, the `downloaded`, `file_downloaded` and `written` keys are custom and not
    natively accepted by rich progress bars.
Parameters:
urls: Web URL(s) to file(s) to download. You can use a dictionary with the key
"url" for the URI, and other keys for extra arguments to use per-URL.
output_dir: The folder to save the file into. If the save path's directory does
not exist then it will be made automatically.
filename: The filename or filename template to use for each file. The variables
you can use are `i` for the URL index and `ext` for the URL extension.
headers: A mapping of HTTP Header Key/Values to use for all downloads.
cookies: A mapping of Cookie Key/Values or a Cookie Jar to use for all downloads.
proxy: An optional proxy URI to route connections through for all downloads.
        max_workers: The maximum amount of threads to use for downloads. Defaults to
            min(16, cpu_count + 4).
session: An optional requests.Session or RnetSession to use. If provided,
it will be used directly (preserving TLS fingerprinting). If None, a new
requests.Session with HTTPAdapter connection pooling will be created.
"""
if not urls:
raise ValueError("urls must be provided and not empty")
elif not isinstance(urls, (str, dict, list)):
raise TypeError(f"Expected urls to be {str} or {dict} or a list of one of them, not {type(urls)}")
if not output_dir:
raise ValueError("output_dir must be provided")
elif not isinstance(output_dir, Path):
raise TypeError(f"Expected output_dir to be {Path}, not {type(output_dir)}")
if not filename:
raise ValueError("filename must be provided")
elif not isinstance(filename, str):
raise TypeError(f"Expected filename to be {str}, not {type(filename)}")
if not isinstance(headers, (MutableMapping, type(None))):
raise TypeError(f"Expected headers to be {MutableMapping}, not {type(headers)}")
if not isinstance(cookies, (MutableMapping, CookieJar, type(None))):
raise TypeError(f"Expected cookies to be {MutableMapping} or {CookieJar}, not {type(cookies)}")
if not isinstance(proxy, (str, type(None))):
raise TypeError(f"Expected proxy to be {str}, not {type(proxy)}")
if not isinstance(max_workers, (int, type(None))):
raise TypeError(f"Expected max_workers to be {int}, not {type(max_workers)}")
debug_logger = get_debug_logger()
if not isinstance(urls, list):
urls = [urls]
if not max_workers:
max_workers = min(16, (os.cpu_count() or 1) + 4)
urls = [
dict(save_path=save_path, **url) if isinstance(url, dict) else dict(url=url, save_path=save_path)
for i, url in enumerate(urls)
for save_path in [
output_dir / filename.format(i=i, ext=get_extension(url["url"] if isinstance(url, dict) else url))
]
]
# Use provided session or create a new optimized requests.Session
# When a session is provided (e.g., service's RnetSession), don't mutate headers/cookies/proxy —
# they're already set and the session may be shared across tracks.
if session is None:
session = Session()
if headers:
headers = {k: v for k, v in headers.items() if k.lower() != "accept-encoding"}
session.headers.update(headers)
if cookies:
session.cookies.update(cookies)
if proxy:
session.proxies.update({"all": proxy})
# Mount HTTPAdapter with connection pooling sized to worker count.
# Safe to do on any requests.Session — improves connection reuse for parallel downloads.
if _is_requests_session(session):
adapter = HTTPAdapter(pool_connections=max_workers, pool_maxsize=max_workers, pool_block=True)
session.mount("https://", adapter)
session.mount("http://", adapter)
if debug_logger:
first_url = urls[0].get("url", "") if urls else ""
url_display = first_url[:200] + "..." if len(first_url) > 200 else first_url
debug_logger.log(
level="DEBUG",
operation="downloader_start",
message="Starting download",
context={
"url_count": len(urls),
"first_url": url_display,
"output_dir": str(output_dir),
"filename": filename,
"max_workers": max_workers,
"has_proxy": bool(proxy),
"session_type": type(session).__name__,
},
)
segmented_batch = len(urls) > 1
# Fast path: single URL — no thread pool overhead
if len(urls) == 1:
try:
yield from download(
session=session,
segmented=segmented_batch,
**urls[0],
)
except KeyboardInterrupt:
DOWNLOAD_CANCELLED.set()
yield dict(downloaded="[yellow]CANCELLED")
raise
else:
# Segmented download with thread pool
# Speed is tracked here on the main thread, not in workers
total_bytes = 0
start_time = time.time()
last_speed_report = start_time
pool = ThreadPoolExecutor(max_workers=max_workers)
event_queue: Queue[dict[str, Any]] = Queue()
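        # Workers only push events onto the queue; this generator (the calling thread) is the
        # sole consumer and the only place that yields, so progress updates stay single-threaded.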
def _download_worker(url_item: dict[str, Any]) -> None:
for event in download(
session=session,
segmented=segmented_batch,
**url_item,
):
event_queue.put(event)
futures = [pool.submit(_download_worker, url) for url in urls]
pending = set(futures)
pending_advance = 0
try:
while pending:
# Drain queued events — batch advances, track bytes for speed
while True:
try:
event = event_queue.get_nowait()
except Empty:
break
# Accumulate advance events for batched yield
advance = event.get("advance")
if advance:
pending_advance += advance
continue
# Track bytes from completed segments for speed calculation
written = event.get("written")
if written:
total_bytes += written
# Pass through other events (file_downloaded, total, etc.)
yield event
# Yield batched advances every drain cycle for responsive progress bar
if pending_advance > 0:
yield dict(advance=pending_advance)
pending_advance = 0
# Yield speed every 0.5s (throttled to avoid spamming Rich)
now = time.time()
if now - last_speed_report > 0.5 and total_bytes > 0:
elapsed = now - start_time
if elapsed > 0:
download_speed = math.ceil(total_bytes / elapsed)
yield dict(downloaded=f"{filesize.decimal(download_speed)}/s")
last_speed_report = now
# Wait efficiently for next future completion (OS condition variable)
completed, pending = wait(pending, timeout=0.1, return_when=FIRST_COMPLETED)
for future in completed:
exc = future.exception()
if isinstance(exc, KeyboardInterrupt):
raise KeyboardInterrupt()
elif exc:
DOWNLOAD_CANCELLED.set()
yield dict(downloaded="[red]FAILING")
pool.shutdown(wait=False, cancel_futures=True)
yield dict(downloaded="[red]FAILED")
if debug_logger:
debug_logger.log(
level="ERROR",
operation="downloader_failed",
message=f"Download failed: {exc}",
error=exc,
context={
"url_count": len(urls),
"output_dir": str(output_dir),
},
)
raise exc
except KeyboardInterrupt:
DOWNLOAD_CANCELLED.set()
yield dict(downloaded="[yellow]CANCELLING")
pool.shutdown(wait=False, cancel_futures=True)
yield dict(downloaded="[yellow]CANCELLED")
raise
finally:
pool.shutdown(wait=False, cancel_futures=True)
# Drain remaining events
while True:
try:
event = event_queue.get_nowait()
except Empty:
break
advance = event.get("advance")
if advance:
pending_advance += advance
continue
written = event.get("written")
if written:
total_bytes += written
yield event
# Flush remaining advances and final speed
if pending_advance > 0:
yield dict(advance=pending_advance)
elapsed = time.time() - start_time
if elapsed > 0 and total_bytes > 0:
download_speed = math.ceil(total_bytes / elapsed)
yield dict(downloaded=f"{filesize.decimal(download_speed)}/s")
if debug_logger:
debug_logger.log(
level="DEBUG",
operation="downloader_complete",
message="Download completed successfully",
context={
"url_count": len(urls),
"output_dir": str(output_dir),
"filename": filename,
},
)
__all__ = ("requests",)