diff --git a/unshackle/vaults/API.py b/unshackle/vaults/API.py index d627ecc..0cc52fe 100644 --- a/unshackle/vaults/API.py +++ b/unshackle/vaults/API.py @@ -114,49 +114,32 @@ class API(Vault): return added or updated def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str]) -> int: - # Normalize keys - normalized_keys = {str(kid).replace("-", ""): key for kid, key in kid_keys.items()} - kid_list = list(normalized_keys.keys()) + data = self.session.post( + url=f"{self.uri}/{service.lower()}", + json={"content_keys": {str(kid).replace("-", ""): key for kid, key in kid_keys.items()}}, + headers={"Accept": "application/json"}, + ).json() - if not kid_list: - return 0 + code = int(data.get("code", 0)) + message = data.get("message") + error = { + 0: None, + 1: Exceptions.AuthRejected, + 2: Exceptions.TooManyRequests, + 3: Exceptions.ServiceTagInvalid, + 4: Exceptions.KeyIdInvalid, + 5: Exceptions.ContentKeyInvalid, + }.get(code, ValueError) - # Batch requests to avoid server limits - batch_size = 500 - total_added = 0 + if error: + raise error(f"{message} ({code})") - for i in range(0, len(kid_list), batch_size): - batch_kids = kid_list[i : i + batch_size] - batch_keys = {kid: normalized_keys[kid] for kid in batch_kids} + # each kid:key that was new to the vault (optional) + added = int(data.get("added")) + # each key for a kid that was changed/updated (optional) + updated = int(data.get("updated")) - data = self.session.post( - url=f"{self.uri}/{service.lower()}", - json={"content_keys": batch_keys}, - headers={"Accept": "application/json"}, - ).json() - - code = int(data.get("code", 0)) - message = data.get("message") - error = { - 0: None, - 1: Exceptions.AuthRejected, - 2: Exceptions.TooManyRequests, - 3: Exceptions.ServiceTagInvalid, - 4: Exceptions.KeyIdInvalid, - 5: Exceptions.ContentKeyInvalid, - }.get(code, ValueError) - - if error: - raise error(f"{message} ({code})") - - # each kid:key that was new to the vault (optional) - added = int(data.get("added", 0)) - # each key for a kid that was changed/updated (optional) - updated = int(data.get("updated", 0)) - - total_added += added + updated - - return total_added + return added + updated def get_services(self) -> Iterator[str]: data = self.session.post(url=self.uri, headers={"Accept": "application/json"}).json() diff --git a/unshackle/vaults/SQLite.py b/unshackle/vaults/SQLite.py index a3f6447..f1922d7 100644 --- a/unshackle/vaults/SQLite.py +++ b/unshackle/vaults/SQLite.py @@ -119,25 +119,9 @@ class SQLite(Vault): cursor = conn.cursor() try: - # Query existing KIDs in batches to avoid SQLite variable limit - # Try larger batch first (newer SQLite supports 32766), fall back to 500 if needed - existing_kids: set[str] = set() - kid_list = list(kid_keys.keys()) - batch_size = 32000 - - i = 0 - while i < len(kid_list): - batch = kid_list[i : i + batch_size] - placeholders = ",".join(["?"] * len(batch)) - try: - cursor.execute(f"SELECT kid FROM `{service}` WHERE kid IN ({placeholders})", batch) - existing_kids.update(row[0] for row in cursor.fetchall()) - i += batch_size - except sqlite3.OperationalError as e: - if "too many SQL variables" in str(e) and batch_size > 500: - batch_size = 500 - continue - raise + placeholders = ",".join(["?"] * len(kid_keys)) + cursor.execute(f"SELECT kid FROM `{service}` WHERE kid IN ({placeholders})", list(kid_keys.keys())) + existing_kids = {row[0] for row in cursor.fetchall()} new_keys = {kid: key for kid, key in kid_keys.items() if kid not in existing_kids}