From 415544775b6cb437347f70f7fc60abcc76b60a31 Mon Sep 17 00:00:00 2001 From: Andy Date: Wed, 14 Jan 2026 23:04:54 +0000 Subject: [PATCH] fix(vaults): adaptive batch sizing for bulk key operations --- unshackle/vaults/API.py | 83 ++++++++++++++++++++++++++++---------- unshackle/vaults/SQLite.py | 22 ++++++++-- 2 files changed, 80 insertions(+), 25 deletions(-) diff --git a/unshackle/vaults/API.py b/unshackle/vaults/API.py index 0cc52fe..dad9607 100644 --- a/unshackle/vaults/API.py +++ b/unshackle/vaults/API.py @@ -114,32 +114,71 @@ class API(Vault): return added or updated def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str]) -> int: - data = self.session.post( - url=f"{self.uri}/{service.lower()}", - json={"content_keys": {str(kid).replace("-", ""): key for kid, key in kid_keys.items()}}, - headers={"Accept": "application/json"}, - ).json() + # Normalize keys + normalized_keys = {str(kid).replace("-", ""): key for kid, key in kid_keys.items()} + kid_list = list(normalized_keys.keys()) - code = int(data.get("code", 0)) - message = data.get("message") - error = { - 0: None, - 1: Exceptions.AuthRejected, - 2: Exceptions.TooManyRequests, - 3: Exceptions.ServiceTagInvalid, - 4: Exceptions.KeyIdInvalid, - 5: Exceptions.ContentKeyInvalid, - }.get(code, ValueError) + if not kid_list: + return 0 - if error: - raise error(f"{message} ({code})") + # Try batches starting at 500, stepping down by 100 on failure, fallback to 1 + batch_size = 500 + total_added = 0 + i = 0 - # each kid:key that was new to the vault (optional) - added = int(data.get("added")) - # each key for a kid that was changed/updated (optional) - updated = int(data.get("updated")) + while i < len(kid_list): + batch_kids = kid_list[i : i + batch_size] + batch_keys = {kid: normalized_keys[kid] for kid in batch_kids} - return added + updated + try: + response = self.session.post( + url=f"{self.uri}/{service.lower()}", + json={"content_keys": batch_keys}, + headers={"Accept": "application/json"}, + ) + + # Check for HTTP errors that suggest batch is too large + if response.status_code in (413, 414, 400) and batch_size > 1: + if batch_size > 100: + batch_size -= 100 + else: + batch_size = 1 + continue + + data = response.json() + except Exception: + # JSON decode error or connection issue - try smaller batch + if batch_size > 1: + if batch_size > 100: + batch_size -= 100 + else: + batch_size = 1 + continue + raise + + code = int(data.get("code", 0)) + message = data.get("message") + error = { + 0: None, + 1: Exceptions.AuthRejected, + 2: Exceptions.TooManyRequests, + 3: Exceptions.ServiceTagInvalid, + 4: Exceptions.KeyIdInvalid, + 5: Exceptions.ContentKeyInvalid, + }.get(code, ValueError) + + if error: + raise error(f"{message} ({code})") + + # each kid:key that was new to the vault (optional) + added = int(data.get("added", 0)) + # each key for a kid that was changed/updated (optional) + updated = int(data.get("updated", 0)) + + total_added += added + updated + i += batch_size + + return total_added def get_services(self) -> Iterator[str]: data = self.session.post(url=self.uri, headers={"Accept": "application/json"}).json() diff --git a/unshackle/vaults/SQLite.py b/unshackle/vaults/SQLite.py index f1922d7..a3f6447 100644 --- a/unshackle/vaults/SQLite.py +++ b/unshackle/vaults/SQLite.py @@ -119,9 +119,25 @@ class SQLite(Vault): cursor = conn.cursor() try: - placeholders = ",".join(["?"] * len(kid_keys)) - cursor.execute(f"SELECT kid FROM `{service}` WHERE kid IN ({placeholders})", list(kid_keys.keys())) - existing_kids = {row[0] for row in cursor.fetchall()} + # Query existing KIDs in batches to avoid SQLite variable limit + # Try larger batch first (newer SQLite supports 32766), fall back to 500 if needed + existing_kids: set[str] = set() + kid_list = list(kid_keys.keys()) + batch_size = 32000 + + i = 0 + while i < len(kid_list): + batch = kid_list[i : i + batch_size] + placeholders = ",".join(["?"] * len(batch)) + try: + cursor.execute(f"SELECT kid FROM `{service}` WHERE kid IN ({placeholders})", batch) + existing_kids.update(row[0] for row in cursor.fetchall()) + i += batch_size + except sqlite3.OperationalError as e: + if "too many SQL variables" in str(e) and batch_size > 500: + batch_size = 500 + continue + raise new_keys = {kid: key for kid, key in kid_keys.items() if kid not in existing_kids}