forked from kenzuya/unshackle
fix(monalisa): harden wasm calls and license handling
- Validate _monalisa_context_alloc return and cleanup on init failure - Derive deterministic KID when DCID missing to avoid collisions - Ensure stackRestore always runs via try/finally in _ccall - Log base64 decode failures without leaking license contents - Add bounds/alignment checks for i32 memory writes
This commit is contained in:
@@ -7,8 +7,11 @@ a WebAssembly module that runs locally via wasmtime.
|
|||||||
|
|
||||||
import base64
|
import base64
|
||||||
import ctypes
|
import ctypes
|
||||||
|
import hashlib
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import re
|
import re
|
||||||
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Optional, Union
|
||||||
@@ -17,6 +20,8 @@ import wasmtime
|
|||||||
|
|
||||||
from unshackle.core import binaries
|
from unshackle.core import binaries
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MonaLisaCDM:
|
class MonaLisaCDM:
|
||||||
"""
|
"""
|
||||||
@@ -128,10 +133,27 @@ class MonaLisaCDM:
|
|||||||
}
|
}
|
||||||
|
|
||||||
self.exports["___wasm_call_ctors"](self.store)
|
self.exports["___wasm_call_ctors"](self.store)
|
||||||
self.ctx = self.exports["_monalisa_context_alloc"](self.store)
|
ctx = self.exports["_monalisa_context_alloc"](self.store)
|
||||||
|
self.ctx = ctx
|
||||||
|
|
||||||
|
# _monalisa_context_alloc is expected to return a positive pointer/handle.
|
||||||
|
# Treat 0/negative/non-int-like values as allocation failure.
|
||||||
|
try:
|
||||||
|
ctx_int = int(ctx)
|
||||||
|
except Exception:
|
||||||
|
ctx_int = None
|
||||||
|
|
||||||
|
if ctx_int is None or ctx_int <= 0:
|
||||||
|
# Ensure we don't leave a partially-initialized instance around.
|
||||||
|
self.close()
|
||||||
|
raise RuntimeError(f"Failed to allocate MonaLisa context (ctx={ctx!r})")
|
||||||
return 1
|
return 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Failed to initialize session: {e}")
|
# Clean up partial state (e.g., store/memory/instance) before propagating failure.
|
||||||
|
self.close()
|
||||||
|
if isinstance(e, RuntimeError):
|
||||||
|
raise
|
||||||
|
raise RuntimeError(f"Failed to initialize session: {e}") from e
|
||||||
|
|
||||||
def close(self, session_id: int = 1) -> None:
|
def close(self, session_id: int = 1) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -188,7 +210,9 @@ class MonaLisaCDM:
|
|||||||
# Extract DCID from license to generate KID
|
# Extract DCID from license to generate KID
|
||||||
try:
|
try:
|
||||||
decoded = base64.b64decode(license_b64).decode("ascii", errors="ignore")
|
decoded = base64.b64decode(license_b64).decode("ascii", errors="ignore")
|
||||||
except Exception:
|
except Exception as e:
|
||||||
|
# Avoid logging raw license content; log only safe metadata.
|
||||||
|
logger.exception("Failed to base64-decode MonaLisa license (len=%s): %s", len(license_b64), e)
|
||||||
decoded = ""
|
decoded = ""
|
||||||
|
|
||||||
m = re.search(
|
m = re.search(
|
||||||
@@ -198,7 +222,14 @@ class MonaLisaCDM:
|
|||||||
if m:
|
if m:
|
||||||
kid_bytes = uuid.uuid5(uuid.NAMESPACE_DNS, m.group()).bytes
|
kid_bytes = uuid.uuid5(uuid.NAMESPACE_DNS, m.group()).bytes
|
||||||
else:
|
else:
|
||||||
kid_bytes = uuid.UUID(int=0).bytes
|
# No DCID in the license: derive a deterministic per-license KID to avoid collisions.
|
||||||
|
try:
|
||||||
|
license_raw = base64.b64decode(license_b64)
|
||||||
|
except Exception:
|
||||||
|
license_raw = license_b64.encode("utf-8", errors="replace")
|
||||||
|
|
||||||
|
license_hash = hashlib.sha256(license_raw).hexdigest()
|
||||||
|
kid_bytes = uuid.uuid5(uuid.NAMESPACE_DNS, f"monalisa:license:{license_hash}").bytes
|
||||||
|
|
||||||
return {"kid": kid_bytes.hex(), "key": key_bytes.hex(), "type": "CONTENT"}
|
return {"kid": kid_bytes.hex(), "key": key_bytes.hex(), "type": "CONTENT"}
|
||||||
|
|
||||||
@@ -221,21 +252,29 @@ class MonaLisaCDM:
|
|||||||
stack = 0
|
stack = 0
|
||||||
converted_args = []
|
converted_args = []
|
||||||
|
|
||||||
for arg in args:
|
try:
|
||||||
if isinstance(arg, str):
|
for arg in args:
|
||||||
if stack == 0:
|
if isinstance(arg, str):
|
||||||
stack = self.exports["stackSave"](self.store)
|
if stack == 0:
|
||||||
max_length = (len(arg) << 2) + 1
|
stack = self.exports["stackSave"](self.store)
|
||||||
ptr = self.exports["stackAlloc"](self.store, max_length)
|
max_length = (len(arg) << 2) + 1
|
||||||
self._string_to_utf8(arg, ptr, max_length)
|
ptr = self.exports["stackAlloc"](self.store, max_length)
|
||||||
converted_args.append(ptr)
|
self._string_to_utf8(arg, ptr, max_length)
|
||||||
else:
|
converted_args.append(ptr)
|
||||||
converted_args.append(arg)
|
else:
|
||||||
|
converted_args.append(arg)
|
||||||
|
|
||||||
result = self.exports[func_name](self.store, *converted_args)
|
result = self.exports[func_name](self.store, *converted_args)
|
||||||
|
finally:
|
||||||
if stack != 0:
|
# stackAlloc pointers live on the WASM stack; always restore even if the call throws.
|
||||||
self.exports["stackRestore"](self.store, stack)
|
if stack != 0:
|
||||||
|
exc = sys.exc_info()[1]
|
||||||
|
try:
|
||||||
|
self.exports["stackRestore"](self.store, stack)
|
||||||
|
except Exception:
|
||||||
|
# If we're already failing, don't mask the original exception.
|
||||||
|
if exc is None:
|
||||||
|
raise
|
||||||
|
|
||||||
if return_type is bool:
|
if return_type is bool:
|
||||||
return bool(result)
|
return bool(result)
|
||||||
@@ -243,6 +282,13 @@ class MonaLisaCDM:
|
|||||||
|
|
||||||
def _write_i32(self, addr: int, value: int) -> None:
|
def _write_i32(self, addr: int, value: int) -> None:
|
||||||
"""Write a 32-bit integer to WASM memory."""
|
"""Write a 32-bit integer to WASM memory."""
|
||||||
|
if addr % 4 != 0:
|
||||||
|
raise ValueError(f"Unaligned i32 write: addr={addr} (must be 4-byte aligned)")
|
||||||
|
|
||||||
|
data_len = self.memory.data_len(self.store)
|
||||||
|
if addr < 0 or addr + 4 > data_len:
|
||||||
|
raise IndexError(f"i32 write out of bounds: addr={addr}, mem_len={data_len}")
|
||||||
|
|
||||||
data = self.memory.data_ptr(self.store)
|
data = self.memory.data_ptr(self.store)
|
||||||
mem_ptr = ctypes.cast(data, ctypes.POINTER(ctypes.c_int32))
|
mem_ptr = ctypes.cast(data, ctypes.POINTER(ctypes.c_int32))
|
||||||
mem_ptr[addr >> 2] = value
|
mem_ptr[addr >> 2] = value
|
||||||
|
|||||||
Reference in New Issue
Block a user