diff --git a/unshackle/vaults/API.py b/unshackle/vaults/API.py index 0cc52fe..d627ecc 100644 --- a/unshackle/vaults/API.py +++ b/unshackle/vaults/API.py @@ -114,32 +114,49 @@ class API(Vault): return added or updated def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str]) -> int: - data = self.session.post( - url=f"{self.uri}/{service.lower()}", - json={"content_keys": {str(kid).replace("-", ""): key for kid, key in kid_keys.items()}}, - headers={"Accept": "application/json"}, - ).json() + # Normalize keys + normalized_keys = {str(kid).replace("-", ""): key for kid, key in kid_keys.items()} + kid_list = list(normalized_keys.keys()) - code = int(data.get("code", 0)) - message = data.get("message") - error = { - 0: None, - 1: Exceptions.AuthRejected, - 2: Exceptions.TooManyRequests, - 3: Exceptions.ServiceTagInvalid, - 4: Exceptions.KeyIdInvalid, - 5: Exceptions.ContentKeyInvalid, - }.get(code, ValueError) + if not kid_list: + return 0 - if error: - raise error(f"{message} ({code})") + # Batch requests to avoid server limits + batch_size = 500 + total_added = 0 - # each kid:key that was new to the vault (optional) - added = int(data.get("added")) - # each key for a kid that was changed/updated (optional) - updated = int(data.get("updated")) + for i in range(0, len(kid_list), batch_size): + batch_kids = kid_list[i : i + batch_size] + batch_keys = {kid: normalized_keys[kid] for kid in batch_kids} - return added + updated + data = self.session.post( + url=f"{self.uri}/{service.lower()}", + json={"content_keys": batch_keys}, + headers={"Accept": "application/json"}, + ).json() + + code = int(data.get("code", 0)) + message = data.get("message") + error = { + 0: None, + 1: Exceptions.AuthRejected, + 2: Exceptions.TooManyRequests, + 3: Exceptions.ServiceTagInvalid, + 4: Exceptions.KeyIdInvalid, + 5: Exceptions.ContentKeyInvalid, + }.get(code, ValueError) + + if error: + raise error(f"{message} ({code})") + + # each kid:key that was new to the vault (optional) + added = int(data.get("added", 0)) + # each key for a kid that was changed/updated (optional) + updated = int(data.get("updated", 0)) + + total_added += added + updated + + return total_added def get_services(self) -> Iterator[str]: data = self.session.post(url=self.uri, headers={"Accept": "application/json"}).json() diff --git a/unshackle/vaults/SQLite.py b/unshackle/vaults/SQLite.py index f1922d7..a3f6447 100644 --- a/unshackle/vaults/SQLite.py +++ b/unshackle/vaults/SQLite.py @@ -119,9 +119,25 @@ class SQLite(Vault): cursor = conn.cursor() try: - placeholders = ",".join(["?"] * len(kid_keys)) - cursor.execute(f"SELECT kid FROM `{service}` WHERE kid IN ({placeholders})", list(kid_keys.keys())) - existing_kids = {row[0] for row in cursor.fetchall()} + # Query existing KIDs in batches to avoid SQLite variable limit + # Try larger batch first (newer SQLite supports 32766), fall back to 500 if needed + existing_kids: set[str] = set() + kid_list = list(kid_keys.keys()) + batch_size = 32000 + + i = 0 + while i < len(kid_list): + batch = kid_list[i : i + batch_size] + placeholders = ",".join(["?"] * len(batch)) + try: + cursor.execute(f"SELECT kid FROM `{service}` WHERE kid IN ({placeholders})", batch) + existing_kids.update(row[0] for row in cursor.fetchall()) + i += batch_size + except sqlite3.OperationalError as e: + if "too many SQL variables" in str(e) and batch_size > 500: + batch_size = 500 + continue + raise new_keys = {kid: key for kid, key in kid_keys.items() if kid not in existing_kids}