mirror of
https://github.com/unshackle-dl/unshackle.git
synced 2026-03-12 01:19:02 +00:00
Initial Commit
This commit is contained in:
184
unshackle/vaults/API.py
Normal file
184
unshackle/vaults/API.py
Normal file
@@ -0,0 +1,184 @@
|
||||
from typing import Iterator, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from requests import Session
|
||||
|
||||
from unshackle.core import __version__
|
||||
from unshackle.core.vault import Vault
|
||||
|
||||
|
||||
class API(Vault):
|
||||
"""Key Vault using a simple RESTful HTTP API call."""
|
||||
|
||||
def __init__(self, name: str, uri: str, token: str):
|
||||
super().__init__(name)
|
||||
self.uri = uri.rstrip("/")
|
||||
self.session = Session()
|
||||
self.session.headers.update({"User-Agent": f"unshackle v{__version__}"})
|
||||
self.session.headers.update({"Authorization": f"Bearer {token}"})
|
||||
|
||||
def get_key(self, kid: Union[UUID, str], service: str) -> Optional[str]:
|
||||
if isinstance(kid, UUID):
|
||||
kid = kid.hex
|
||||
|
||||
data = self.session.get(
|
||||
url=f"{self.uri}/{service.lower()}/{kid}", headers={"Accept": "application/json"}
|
||||
).json()
|
||||
|
||||
code = int(data.get("code", 0))
|
||||
message = data.get("message")
|
||||
error = {
|
||||
0: None,
|
||||
1: Exceptions.AuthRejected,
|
||||
2: Exceptions.TooManyRequests,
|
||||
3: Exceptions.ServiceTagInvalid,
|
||||
4: Exceptions.KeyIdInvalid,
|
||||
}.get(code, ValueError)
|
||||
|
||||
if error:
|
||||
raise error(f"{message} ({code})")
|
||||
|
||||
content_key = data.get("content_key")
|
||||
if not content_key:
|
||||
return None
|
||||
|
||||
if not isinstance(content_key, str):
|
||||
raise ValueError(f"Expected {content_key} to be {str}, was {type(content_key)}")
|
||||
|
||||
return content_key
|
||||
|
||||
def get_keys(self, service: str) -> Iterator[tuple[str, str]]:
|
||||
page = 1
|
||||
|
||||
while True:
|
||||
data = self.session.get(
|
||||
url=f"{self.uri}/{service.lower()}",
|
||||
params={"page": page, "total": 10},
|
||||
headers={"Accept": "application/json"},
|
||||
).json()
|
||||
|
||||
code = int(data.get("code", 0))
|
||||
message = data.get("message")
|
||||
error = {
|
||||
0: None,
|
||||
1: Exceptions.AuthRejected,
|
||||
2: Exceptions.TooManyRequests,
|
||||
3: Exceptions.PageInvalid,
|
||||
4: Exceptions.ServiceTagInvalid,
|
||||
}.get(code, ValueError)
|
||||
|
||||
if error:
|
||||
raise error(f"{message} ({code})")
|
||||
|
||||
content_keys = data.get("content_keys")
|
||||
if content_keys:
|
||||
if not isinstance(content_keys, dict):
|
||||
raise ValueError(f"Expected {content_keys} to be {dict}, was {type(content_keys)}")
|
||||
|
||||
for key_id, key in content_keys.items():
|
||||
yield key_id, key
|
||||
|
||||
pages = int(data["pages"])
|
||||
if pages <= page:
|
||||
break
|
||||
|
||||
page += 1
|
||||
|
||||
def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
|
||||
if isinstance(kid, UUID):
|
||||
kid = kid.hex
|
||||
|
||||
data = self.session.post(
|
||||
url=f"{self.uri}/{service.lower()}/{kid}", json={"content_key": key}, headers={"Accept": "application/json"}
|
||||
).json()
|
||||
|
||||
code = int(data.get("code", 0))
|
||||
message = data.get("message")
|
||||
error = {
|
||||
0: None,
|
||||
1: Exceptions.AuthRejected,
|
||||
2: Exceptions.TooManyRequests,
|
||||
3: Exceptions.ServiceTagInvalid,
|
||||
4: Exceptions.KeyIdInvalid,
|
||||
5: Exceptions.ContentKeyInvalid,
|
||||
}.get(code, ValueError)
|
||||
|
||||
if error:
|
||||
raise error(f"{message} ({code})")
|
||||
|
||||
# the kid:key was new to the vault (optional)
|
||||
added = bool(data.get("added"))
|
||||
# the key for kid was changed/updated (optional)
|
||||
updated = bool(data.get("updated"))
|
||||
|
||||
return added or updated
|
||||
|
||||
def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str]) -> int:
|
||||
data = self.session.post(
|
||||
url=f"{self.uri}/{service.lower()}",
|
||||
json={"content_keys": {str(kid).replace("-", ""): key for kid, key in kid_keys.items()}},
|
||||
headers={"Accept": "application/json"},
|
||||
).json()
|
||||
|
||||
code = int(data.get("code", 0))
|
||||
message = data.get("message")
|
||||
error = {
|
||||
0: None,
|
||||
1: Exceptions.AuthRejected,
|
||||
2: Exceptions.TooManyRequests,
|
||||
3: Exceptions.ServiceTagInvalid,
|
||||
4: Exceptions.KeyIdInvalid,
|
||||
5: Exceptions.ContentKeyInvalid,
|
||||
}.get(code, ValueError)
|
||||
|
||||
if error:
|
||||
raise error(f"{message} ({code})")
|
||||
|
||||
# each kid:key that was new to the vault (optional)
|
||||
added = int(data.get("added"))
|
||||
# each key for a kid that was changed/updated (optional)
|
||||
updated = int(data.get("updated"))
|
||||
|
||||
return added + updated
|
||||
|
||||
def get_services(self) -> Iterator[str]:
|
||||
data = self.session.post(url=self.uri, headers={"Accept": "application/json"}).json()
|
||||
|
||||
code = int(data.get("code", 0))
|
||||
message = data.get("message")
|
||||
error = {
|
||||
0: None,
|
||||
1: Exceptions.AuthRejected,
|
||||
2: Exceptions.TooManyRequests,
|
||||
}.get(code, ValueError)
|
||||
|
||||
if error:
|
||||
raise error(f"{message} ({code})")
|
||||
|
||||
service_list = data.get("service_list", [])
|
||||
|
||||
if not isinstance(service_list, list):
|
||||
raise ValueError(f"Expected {service_list} to be {list}, was {type(service_list)}")
|
||||
|
||||
for service in service_list:
|
||||
yield service
|
||||
|
||||
|
||||
class Exceptions:
|
||||
class AuthRejected(Exception):
|
||||
"""Authentication Error Occurred, is your token valid? Do you have permission to make this call?"""
|
||||
|
||||
class TooManyRequests(Exception):
|
||||
"""Rate Limited; Sent too many requests in a given amount of time."""
|
||||
|
||||
class PageInvalid(Exception):
|
||||
"""Requested page does not exist."""
|
||||
|
||||
class ServiceTagInvalid(Exception):
|
||||
"""The Service Tag is invalid."""
|
||||
|
||||
class KeyIdInvalid(Exception):
|
||||
"""The Key ID is invalid."""
|
||||
|
||||
class ContentKeyInvalid(Exception):
|
||||
"""The Content Key is invalid."""
|
||||
326
unshackle/vaults/HTTP.py
Normal file
326
unshackle/vaults/HTTP.py
Normal file
@@ -0,0 +1,326 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Iterator, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from requests import Session
|
||||
|
||||
from unshackle.core import __version__
|
||||
from unshackle.core.vault import Vault
|
||||
|
||||
|
||||
class InsertResult(Enum):
|
||||
FAILURE = 0
|
||||
SUCCESS = 1
|
||||
ALREADY_EXISTS = 2
|
||||
|
||||
|
||||
class HTTP(Vault):
|
||||
"""Key Vault using HTTP API with support for both query parameters and JSON payloads."""
|
||||
|
||||
def __init__(self, name: str, host: str, password: str, username: Optional[str] = None, api_mode: str = "query"):
|
||||
"""
|
||||
Initialize HTTP Vault.
|
||||
|
||||
Args:
|
||||
name: Vault name
|
||||
host: Host URL
|
||||
password: Password for query mode or API token for json mode
|
||||
username: Username (required for query mode, ignored for json mode)
|
||||
api_mode: "query" for query parameters or "json" for JSON API
|
||||
"""
|
||||
super().__init__(name)
|
||||
self.url = host.rstrip("/")
|
||||
self.password = password
|
||||
self.username = username
|
||||
self.api_mode = api_mode.lower()
|
||||
self.current_title = None
|
||||
self.session = Session()
|
||||
self.session.headers.update({"User-Agent": f"unshackle v{__version__}"})
|
||||
self.api_session_id = None
|
||||
|
||||
# Validate configuration based on mode
|
||||
if self.api_mode == "query" and not self.username:
|
||||
raise ValueError("Username is required for query mode")
|
||||
elif self.api_mode not in ["query", "json"]:
|
||||
raise ValueError("api_mode must be either 'query' or 'json'")
|
||||
|
||||
def request(self, method: str, params: dict = None) -> dict:
|
||||
"""Make a request to the JSON API vault."""
|
||||
if self.api_mode != "json":
|
||||
raise ValueError("request method is only available in json mode")
|
||||
|
||||
request_payload = {
|
||||
"method": method,
|
||||
"params": {
|
||||
**(params or {}),
|
||||
"session_id": self.api_session_id,
|
||||
},
|
||||
"token": self.password,
|
||||
}
|
||||
|
||||
r = self.session.post(self.url, json=request_payload)
|
||||
|
||||
if r.status_code == 404:
|
||||
return {"status": "not_found"}
|
||||
|
||||
if not r.ok:
|
||||
raise ValueError(f"API returned HTTP Error {r.status_code}: {r.reason.title()}")
|
||||
|
||||
try:
|
||||
res = r.json()
|
||||
except json.JSONDecodeError:
|
||||
if r.status_code == 404:
|
||||
return {"status": "not_found"}
|
||||
raise ValueError(f"API returned an invalid response: {r.text}")
|
||||
|
||||
if res.get("status_code") != 200:
|
||||
raise ValueError(f"API returned an error: {res['status_code']} - {res['message']}")
|
||||
|
||||
if session_id := res.get("message", {}).get("session_id"):
|
||||
self.api_session_id = session_id
|
||||
|
||||
return res.get("message", res)
|
||||
|
||||
def get_key(self, kid: Union[UUID, str], service: str) -> Optional[str]:
|
||||
if isinstance(kid, UUID):
|
||||
kid = kid.hex
|
||||
|
||||
if self.api_mode == "json":
|
||||
try:
|
||||
title = getattr(self, "current_title", None)
|
||||
response = self.request(
|
||||
"GetKey",
|
||||
{
|
||||
"kid": kid,
|
||||
"service": service.lower(),
|
||||
"title": title,
|
||||
},
|
||||
)
|
||||
if response.get("status") == "not_found":
|
||||
return None
|
||||
keys = response.get("keys", [])
|
||||
for key_entry in keys:
|
||||
if key_entry["kid"] == kid:
|
||||
return key_entry["key"]
|
||||
except Exception as e:
|
||||
print(f"Failed to get key ({e.__class__.__name__}: {e})")
|
||||
return None
|
||||
return None
|
||||
else: # query mode
|
||||
response = self.session.get(
|
||||
self.url,
|
||||
params={"service": service.lower(), "username": self.username, "password": self.password, "kid": kid},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
if data.get("status_code") != 200 or not data.get("keys"):
|
||||
return None
|
||||
|
||||
return data["keys"][0]["key"]
|
||||
|
||||
def get_keys(self, service: str) -> Iterator[tuple[str, str]]:
|
||||
if self.api_mode == "json":
|
||||
# JSON API doesn't support getting all keys, so return empty iterator
|
||||
# This will cause the copy command to rely on the API's internal duplicate handling
|
||||
return iter([])
|
||||
else: # query mode
|
||||
response = self.session.get(
|
||||
self.url, params={"service": service.lower(), "username": self.username, "password": self.password}
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
if data.get("status_code") != 200 or not data.get("keys"):
|
||||
return
|
||||
|
||||
for key_entry in data["keys"]:
|
||||
yield key_entry["kid"], key_entry["key"]
|
||||
|
||||
def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
|
||||
if not key or key.count("0") == len(key):
|
||||
raise ValueError("You cannot add a NULL Content Key to a Vault.")
|
||||
|
||||
if isinstance(kid, UUID):
|
||||
kid = kid.hex
|
||||
|
||||
title = getattr(self, "current_title", None)
|
||||
|
||||
if self.api_mode == "json":
|
||||
try:
|
||||
response = self.request(
|
||||
"InsertKey",
|
||||
{
|
||||
"kid": kid,
|
||||
"key": key,
|
||||
"service": service.lower(),
|
||||
"title": title,
|
||||
},
|
||||
)
|
||||
if response.get("status") == "not_found":
|
||||
return False
|
||||
return response.get("inserted", False)
|
||||
except Exception:
|
||||
return False
|
||||
else: # query mode
|
||||
response = self.session.get(
|
||||
self.url,
|
||||
params={
|
||||
"service": service.lower(),
|
||||
"username": self.username,
|
||||
"password": self.password,
|
||||
"kid": kid,
|
||||
"key": key,
|
||||
"title": title,
|
||||
},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
return data.get("status_code") == 200
|
||||
|
||||
def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str]) -> int:
|
||||
for kid, key in kid_keys.items():
|
||||
if not key or key.count("0") == len(key):
|
||||
raise ValueError("You cannot add a NULL Content Key to a Vault.")
|
||||
|
||||
processed_kid_keys = {
|
||||
str(kid).replace("-", "") if isinstance(kid, UUID) else kid: key for kid, key in kid_keys.items()
|
||||
}
|
||||
|
||||
inserted_count = 0
|
||||
title = getattr(self, "current_title", None)
|
||||
|
||||
if self.api_mode == "json":
|
||||
for kid, key in processed_kid_keys.items():
|
||||
try:
|
||||
response = self.request(
|
||||
"InsertKey",
|
||||
{
|
||||
"kid": kid,
|
||||
"key": key,
|
||||
"service": service.lower(),
|
||||
"title": title,
|
||||
},
|
||||
)
|
||||
if response.get("status") == "not_found":
|
||||
continue
|
||||
if response.get("inserted", False):
|
||||
inserted_count += 1
|
||||
except Exception:
|
||||
continue
|
||||
else: # query mode
|
||||
for kid, key in processed_kid_keys.items():
|
||||
response = self.session.get(
|
||||
self.url,
|
||||
params={
|
||||
"service": service.lower(),
|
||||
"username": self.username,
|
||||
"password": self.password,
|
||||
"kid": kid,
|
||||
"key": key,
|
||||
"title": title,
|
||||
},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
if data.get("status_code") == 200 and data.get("inserted", True):
|
||||
inserted_count += 1
|
||||
|
||||
return inserted_count
|
||||
|
||||
def get_services(self) -> Iterator[str]:
|
||||
if self.api_mode == "json":
|
||||
try:
|
||||
response = self.request("GetServices")
|
||||
services = response.get("services", [])
|
||||
for service in services:
|
||||
yield service
|
||||
except Exception:
|
||||
return iter([])
|
||||
else: # query mode
|
||||
response = self.session.get(
|
||||
self.url, params={"username": self.username, "password": self.password, "list_services": True}
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
if data.get("status_code") != 200:
|
||||
return
|
||||
|
||||
services = data.get("services", [])
|
||||
for service in services:
|
||||
yield service
|
||||
|
||||
def set_title(self, title: str):
|
||||
"""
|
||||
Set a title to be used for the next key insertions.
|
||||
This is optional and will be sent with add_key requests if available.
|
||||
"""
|
||||
self.current_title = title
|
||||
|
||||
def insert_key_with_result(
|
||||
self, service: str, kid: Union[UUID, str], key: str, title: Optional[str] = None
|
||||
) -> InsertResult:
|
||||
"""
|
||||
Insert a key and return detailed result information.
|
||||
This method provides more granular feedback than the standard add_key method.
|
||||
Available in both API modes.
|
||||
"""
|
||||
if not key or key.count("0") == len(key):
|
||||
raise ValueError("You cannot add a NULL Content Key to a Vault.")
|
||||
|
||||
if isinstance(kid, UUID):
|
||||
kid = kid.hex
|
||||
|
||||
if title is None:
|
||||
title = getattr(self, "current_title", None)
|
||||
|
||||
if self.api_mode == "json":
|
||||
try:
|
||||
response = self.request(
|
||||
"InsertKey",
|
||||
{
|
||||
"kid": kid,
|
||||
"key": key,
|
||||
"service": service.lower(),
|
||||
"title": title,
|
||||
},
|
||||
)
|
||||
|
||||
if response.get("status") == "not_found":
|
||||
return InsertResult.FAILURE
|
||||
|
||||
if response.get("inserted", False):
|
||||
return InsertResult.SUCCESS
|
||||
else:
|
||||
return InsertResult.ALREADY_EXISTS
|
||||
|
||||
except Exception:
|
||||
return InsertResult.FAILURE
|
||||
else: # query mode
|
||||
response = self.session.get(
|
||||
self.url,
|
||||
params={
|
||||
"service": service.lower(),
|
||||
"username": self.username,
|
||||
"password": self.password,
|
||||
"kid": kid,
|
||||
"key": key,
|
||||
"title": title,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
data = response.json()
|
||||
if data.get("status_code") == 200:
|
||||
if data.get("inserted", True):
|
||||
return InsertResult.SUCCESS
|
||||
else:
|
||||
return InsertResult.ALREADY_EXISTS
|
||||
else:
|
||||
return InsertResult.FAILURE
|
||||
except Exception:
|
||||
return InsertResult.FAILURE
|
||||
244
unshackle/vaults/MySQL.py
Normal file
244
unshackle/vaults/MySQL.py
Normal file
@@ -0,0 +1,244 @@
|
||||
import threading
|
||||
from typing import Iterator, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
import pymysql
|
||||
from pymysql.cursors import DictCursor
|
||||
|
||||
from unshackle.core.services import Services
|
||||
from unshackle.core.vault import Vault
|
||||
|
||||
|
||||
class MySQL(Vault):
|
||||
"""Key Vault using a remotely-accessed mysql database connection."""
|
||||
|
||||
def __init__(self, name: str, host: str, database: str, username: str, **kwargs):
|
||||
"""
|
||||
All extra arguments provided via **kwargs will be sent to pymysql.connect.
|
||||
This can be used to provide more specific connection information.
|
||||
"""
|
||||
super().__init__(name)
|
||||
self.slug = f"{host}:{database}:{username}"
|
||||
self.conn_factory = ConnectionFactory(
|
||||
dict(host=host, db=database, user=username, cursorclass=DictCursor, **kwargs)
|
||||
)
|
||||
|
||||
self.permissions = self.get_permissions()
|
||||
if not self.has_permission("SELECT"):
|
||||
raise PermissionError(f"MySQL vault {self.slug} has no SELECT permission.")
|
||||
|
||||
def get_key(self, kid: Union[UUID, str], service: str) -> Optional[str]:
|
||||
if not self.has_table(service):
|
||||
# no table, no key, simple
|
||||
return None
|
||||
|
||||
if isinstance(kid, UUID):
|
||||
kid = kid.hex
|
||||
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.execute(
|
||||
# TODO: SQL injection risk
|
||||
f"SELECT `id`, `key_` FROM `{service}` WHERE `kid`=%s AND `key_`!=%s",
|
||||
(kid, "0" * 32),
|
||||
)
|
||||
cek = cursor.fetchone()
|
||||
if not cek:
|
||||
return None
|
||||
return cek["key_"]
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def get_keys(self, service: str) -> Iterator[tuple[str, str]]:
|
||||
if not self.has_table(service):
|
||||
# no table, no keys, simple
|
||||
return None
|
||||
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.execute(
|
||||
# TODO: SQL injection risk
|
||||
f"SELECT `kid`, `key_` FROM `{service}` WHERE `key_`!=%s",
|
||||
("0" * 32,),
|
||||
)
|
||||
for row in cursor.fetchall():
|
||||
yield row["kid"], row["key_"]
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
|
||||
if not key or key.count("0") == len(key):
|
||||
raise ValueError("You cannot add a NULL Content Key to a Vault.")
|
||||
|
||||
if not self.has_permission("INSERT", table=service):
|
||||
raise PermissionError(f"MySQL vault {self.slug} has no INSERT permission.")
|
||||
|
||||
if not self.has_table(service):
|
||||
try:
|
||||
self.create_table(service)
|
||||
except PermissionError:
|
||||
return False
|
||||
|
||||
if isinstance(kid, UUID):
|
||||
kid = kid.hex
|
||||
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.execute(
|
||||
# TODO: SQL injection risk
|
||||
f"SELECT `id` FROM `{service}` WHERE `kid`=%s AND `key_`=%s",
|
||||
(kid, key),
|
||||
)
|
||||
if cursor.fetchone():
|
||||
# table already has this exact KID:KEY stored
|
||||
return True
|
||||
cursor.execute(
|
||||
# TODO: SQL injection risk
|
||||
f"INSERT INTO `{service}` (kid, key_) VALUES (%s, %s)",
|
||||
(kid, key),
|
||||
)
|
||||
finally:
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
|
||||
return True
|
||||
|
||||
def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str]) -> int:
|
||||
for kid, key in kid_keys.items():
|
||||
if not key or key.count("0") == len(key):
|
||||
raise ValueError("You cannot add a NULL Content Key to a Vault.")
|
||||
|
||||
if not self.has_permission("INSERT", table=service):
|
||||
raise PermissionError(f"MySQL vault {self.slug} has no INSERT permission.")
|
||||
|
||||
if not self.has_table(service):
|
||||
try:
|
||||
self.create_table(service)
|
||||
except PermissionError:
|
||||
return 0
|
||||
|
||||
if not isinstance(kid_keys, dict):
|
||||
raise ValueError(f"The kid_keys provided is not a dictionary, {kid_keys!r}")
|
||||
if not all(isinstance(kid, (str, UUID)) and isinstance(key_, str) for kid, key_ in kid_keys.items()):
|
||||
raise ValueError("Expecting dict with Key of str/UUID and value of str.")
|
||||
|
||||
if any(isinstance(kid, UUID) for kid, key_ in kid_keys.items()):
|
||||
kid_keys = {kid.hex if isinstance(kid, UUID) else kid: key_ for kid, key_ in kid_keys.items()}
|
||||
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.executemany(
|
||||
# TODO: SQL injection risk
|
||||
f"INSERT IGNORE INTO `{service}` (kid, key_) VALUES (%s, %s)",
|
||||
kid_keys.items(),
|
||||
)
|
||||
return cursor.rowcount
|
||||
finally:
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
|
||||
def get_services(self) -> Iterator[str]:
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.execute("SHOW TABLES")
|
||||
for table in cursor.fetchall():
|
||||
# each entry has a key named `Tables_in_<db name>`
|
||||
yield Services.get_tag(list(table.values())[0])
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def has_table(self, name: str) -> bool:
|
||||
"""Check if the Vault has a Table with the specified name."""
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.execute(
|
||||
"SELECT count(TABLE_NAME) FROM information_schema.TABLES WHERE TABLE_SCHEMA=%s AND TABLE_NAME=%s",
|
||||
(conn.db, name),
|
||||
)
|
||||
return list(cursor.fetchone().values())[0] == 1
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def create_table(self, name: str):
|
||||
"""Create a Table with the specified name if not yet created."""
|
||||
if self.has_table(name):
|
||||
return
|
||||
|
||||
if not self.has_permission("CREATE"):
|
||||
raise PermissionError(f"MySQL vault {self.slug} has no CREATE permission.")
|
||||
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.execute(
|
||||
# TODO: SQL injection risk
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {name} (
|
||||
id int AUTO_INCREMENT PRIMARY KEY,
|
||||
kid VARCHAR(64) NOT NULL,
|
||||
key_ VARCHAR(64) NOT NULL,
|
||||
UNIQUE(kid, key_)
|
||||
);
|
||||
"""
|
||||
)
|
||||
finally:
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
|
||||
def get_permissions(self) -> list:
|
||||
"""Get and parse Grants to a more easily usable list tuple array."""
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.execute("SHOW GRANTS")
|
||||
grants = cursor.fetchall()
|
||||
grants = [next(iter(x.values())) for x in grants]
|
||||
grants = [tuple(x[6:].split(" TO ")[0].split(" ON ")) for x in list(grants)]
|
||||
grants = [
|
||||
(
|
||||
list(map(str.strip, perms.replace("ALL PRIVILEGES", "*").split(","))),
|
||||
location.replace("`", "").split("."),
|
||||
)
|
||||
for perms, location in grants
|
||||
]
|
||||
return grants
|
||||
finally:
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
|
||||
def has_permission(self, operation: str, database: Optional[str] = None, table: Optional[str] = None) -> bool:
|
||||
"""Check if the current connection has a specific permission."""
|
||||
grants = [x for x in self.permissions if x[0] == ["*"] or operation.upper() in x[0]]
|
||||
if grants and database:
|
||||
grants = [x for x in grants if x[1][0] in (database, "*")]
|
||||
if grants and table:
|
||||
grants = [x for x in grants if x[1][1] in (table, "*")]
|
||||
return bool(grants)
|
||||
|
||||
|
||||
class ConnectionFactory:
|
||||
def __init__(self, con: dict):
|
||||
self._con = con
|
||||
self._store = threading.local()
|
||||
|
||||
def _create_connection(self) -> pymysql.Connection:
|
||||
return pymysql.connect(**self._con)
|
||||
|
||||
def get(self) -> pymysql.Connection:
|
||||
if not hasattr(self._store, "conn"):
|
||||
self._store.conn = self._create_connection()
|
||||
return self._store.conn
|
||||
179
unshackle/vaults/SQLite.py
Normal file
179
unshackle/vaults/SQLite.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from sqlite3 import Connection
|
||||
from typing import Iterator, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from unshackle.core.services import Services
|
||||
from unshackle.core.vault import Vault
|
||||
|
||||
|
||||
class SQLite(Vault):
|
||||
"""Key Vault using a locally-accessed sqlite DB file."""
|
||||
|
||||
def __init__(self, name: str, path: Union[str, Path]):
|
||||
super().__init__(name)
|
||||
self.path = Path(path).expanduser()
|
||||
# TODO: Use a DictCursor or such to get fetches as dict?
|
||||
self.conn_factory = ConnectionFactory(self.path)
|
||||
|
||||
def get_key(self, kid: Union[UUID, str], service: str) -> Optional[str]:
|
||||
if not self.has_table(service):
|
||||
# no table, no key, simple
|
||||
return None
|
||||
|
||||
if isinstance(kid, UUID):
|
||||
kid = kid.hex
|
||||
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.execute(f"SELECT `id`, `key_` FROM `{service}` WHERE `kid`=? AND `key_`!=?", (kid, "0" * 32))
|
||||
cek = cursor.fetchone()
|
||||
if not cek:
|
||||
return None
|
||||
return cek[1]
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def get_keys(self, service: str) -> Iterator[tuple[str, str]]:
|
||||
if not self.has_table(service):
|
||||
# no table, no keys, simple
|
||||
return None
|
||||
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.execute(f"SELECT `kid`, `key_` FROM `{service}` WHERE `key_`!=?", ("0" * 32,))
|
||||
for kid, key_ in cursor.fetchall():
|
||||
yield kid, key_
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
|
||||
if not key or key.count("0") == len(key):
|
||||
raise ValueError("You cannot add a NULL Content Key to a Vault.")
|
||||
|
||||
if not self.has_table(service):
|
||||
self.create_table(service)
|
||||
|
||||
if isinstance(kid, UUID):
|
||||
kid = kid.hex
|
||||
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.execute(
|
||||
# TODO: SQL injection risk
|
||||
f"SELECT `id` FROM `{service}` WHERE `kid`=? AND `key_`=?",
|
||||
(kid, key),
|
||||
)
|
||||
if cursor.fetchone():
|
||||
# table already has this exact KID:KEY stored
|
||||
return True
|
||||
cursor.execute(
|
||||
# TODO: SQL injection risk
|
||||
f"INSERT INTO `{service}` (kid, key_) VALUES (?, ?)",
|
||||
(kid, key),
|
||||
)
|
||||
finally:
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
|
||||
return True
|
||||
|
||||
def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str]) -> int:
|
||||
for kid, key in kid_keys.items():
|
||||
if not key or key.count("0") == len(key):
|
||||
raise ValueError("You cannot add a NULL Content Key to a Vault.")
|
||||
|
||||
if not self.has_table(service):
|
||||
self.create_table(service)
|
||||
|
||||
if not isinstance(kid_keys, dict):
|
||||
raise ValueError(f"The kid_keys provided is not a dictionary, {kid_keys!r}")
|
||||
if not all(isinstance(kid, (str, UUID)) and isinstance(key_, str) for kid, key_ in kid_keys.items()):
|
||||
raise ValueError("Expecting dict with Key of str/UUID and value of str.")
|
||||
|
||||
if any(isinstance(kid, UUID) for kid, key_ in kid_keys.items()):
|
||||
kid_keys = {kid.hex if isinstance(kid, UUID) else kid: key_ for kid, key_ in kid_keys.items()}
|
||||
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.executemany(
|
||||
# TODO: SQL injection risk
|
||||
f"INSERT OR IGNORE INTO `{service}` (kid, key_) VALUES (?, ?)",
|
||||
kid_keys.items(),
|
||||
)
|
||||
return cursor.rowcount
|
||||
finally:
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
|
||||
def get_services(self) -> Iterator[str]:
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
|
||||
for (name,) in cursor.fetchall():
|
||||
if name != "sqlite_sequence":
|
||||
yield Services.get_tag(name)
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def has_table(self, name: str) -> bool:
|
||||
"""Check if the Vault has a Table with the specified name."""
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.execute("SELECT count(name) FROM sqlite_master WHERE type='table' AND name=?", (name,))
|
||||
return cursor.fetchone()[0] == 1
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def create_table(self, name: str):
|
||||
"""Create a Table with the specified name if not yet created."""
|
||||
if self.has_table(name):
|
||||
return
|
||||
|
||||
conn = self.conn_factory.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.execute(
|
||||
# TODO: SQL injection risk
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {name} (
|
||||
"id" INTEGER NOT NULL UNIQUE,
|
||||
"kid" TEXT NOT NULL COLLATE NOCASE,
|
||||
"key_" TEXT NOT NULL COLLATE NOCASE,
|
||||
PRIMARY KEY("id" AUTOINCREMENT),
|
||||
UNIQUE("kid", "key_")
|
||||
);
|
||||
"""
|
||||
)
|
||||
finally:
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
|
||||
|
||||
class ConnectionFactory:
|
||||
def __init__(self, path: Union[str, Path]):
|
||||
self._path = path
|
||||
self._store = threading.local()
|
||||
|
||||
def _create_connection(self) -> Connection:
|
||||
return sqlite3.connect(self._path)
|
||||
|
||||
def get(self) -> Connection:
|
||||
if not hasattr(self._store, "conn"):
|
||||
self._store.conn = self._create_connection()
|
||||
return self._store.conn
|
||||
0
unshackle/vaults/__init__.py
Normal file
0
unshackle/vaults/__init__.py
Normal file
Reference in New Issue
Block a user