refactor(ip_info): simplify lookup and trim cache

Cache only country/country_code (drop full IP/org/asn), bump CACHE_KEY and auto-purge stale cache versions. Dedup the three provider parsers into one normalize(). Use a plain retry-free requests session for lookups instead of the TLS-fingerprinted session, carrying only the proxy over, so a 429 returns directly and hands over to the next provider reliably.
This commit is contained in:
imSp4rky
2026-05-27 22:48:39 -06:00
parent 40104be738
commit fb8dc0bd9d

View File

@@ -8,101 +8,113 @@ import requests
from unshackle.core.cacher import Cacher from unshackle.core.cacher import Cacher
CACHE_KEY = "ip_info_v2" CACHE_KEY = "ip_info_v3"
CACHE_TTL = 86400 # 24 hours CACHE_TTL = 86400 # 24 hours
PROVIDER_STATE_KEY = "ip_provider_state" PROVIDER_STATE_KEY = "ip_provider_state"
RATE_LIMIT_COOLDOWN = 300 # 5 minutes RATE_LIMIT_COOLDOWN = 300 # 5 minutes
REQUEST_TIMEOUT = 10 REQUEST_TIMEOUT = 10
# Only these keys are persisted to the global cache.
GEO_CACHE_KEYS = ("country", "country_code")
Fetcher = Callable[[requests.Session], Optional[dict]] Fetcher = Callable[[requests.Session], Optional[dict]]
log = logging.getLogger("ip_info")
class _RateLimited(Exception):
class RateLimited(Exception):
"""Raised by a provider fetcher when the upstream returns 429.""" """Raised by a provider fetcher when the upstream returns 429."""
def _empty() -> dict: def normalize(
*,
country_code: str,
ip: str = "",
region: str = "",
city: str = "",
org: str = "",
asn: str = "",
as_name: str = "",
continent_code: str = "",
) -> Optional[dict]:
"""Build the canonical IP-info dict, or None if no country code is present."""
code = country_code.strip()
if not code:
return None
return { return {
"ip": "", "ip": ip,
"country": "", "country": code.lower(),
"country_code": "", "country_code": code.upper(),
"region": "", "region": region,
"city": "", "city": city,
"org": "", "org": org,
"asn": "", "asn": asn,
"as_name": "", "as_name": as_name,
"continent_code": "", "continent_code": continent_code.upper(),
} }
def _parse_ipinfo_lite(data: dict) -> Optional[dict]: def parse_ipinfo_lite(data: dict) -> Optional[dict]:
code = (data.get("country_code") or "").strip()
if not code:
return None
asn = (data.get("asn") or "").strip() asn = (data.get("asn") or "").strip()
as_name = (data.get("as_name") or "").strip() as_name = (data.get("as_name") or "").strip()
org = f"{asn} {as_name}".strip() if (asn or as_name) else "" return normalize(
out = _empty() country_code=data.get("country_code") or "",
out.update( ip=data.get("ip") or "",
{ org=f"{asn} {as_name}".strip(),
"ip": data.get("ip") or "", asn=asn,
"country": code.lower(), as_name=as_name,
"country_code": code.upper(), continent_code=data.get("continent_code") or "",
"org": org,
"asn": asn,
"as_name": as_name,
"continent_code": (data.get("continent_code") or "").upper(),
}
) )
return out
def _parse_ipinfo(data: dict) -> Optional[dict]: def parse_ipinfo(data: dict) -> Optional[dict]:
code = (data.get("country") or "").strip() return normalize(
if not code: country_code=data.get("country") or "",
return None ip=data.get("ip") or "",
out = _empty() region=data.get("region") or "",
out.update( city=data.get("city") or "",
{ org=data.get("org") or "",
"ip": data.get("ip") or "",
"country": code.lower(),
"country_code": code.upper(),
"region": data.get("region") or "",
"city": data.get("city") or "",
"org": data.get("org") or "",
}
) )
return out
def _parse_ip_api_in(data: dict) -> Optional[dict]: def parse_ip_api_in(data: dict) -> Optional[dict]:
code = (data.get("country_code") or "").strip()
if not code:
return None
asn = (data.get("asn") or "").strip() asn = (data.get("asn") or "").strip()
org_name = (data.get("organization") or "").strip() org_name = (data.get("organization") or "").strip()
org = f"{asn} {org_name}".strip() if (asn or org_name) else "" return normalize(
out = _empty() country_code=data.get("country_code") or "",
out.update( ip=data.get("ip") or "",
{ region=data.get("region") or "",
"ip": data.get("ip") or "", city=data.get("city") or "",
"country": code.lower(), org=f"{asn} {org_name}".strip(),
"country_code": code.upper(), asn=asn,
"region": data.get("region") or "", as_name=org_name,
"city": data.get("city") or "", continent_code=data.get("continent_code") or "",
"org": org,
"asn": asn,
"as_name": org_name,
"continent_code": (data.get("continent_code") or "").upper(),
}
) )
return out
def _check(response: requests.Response) -> Optional[dict]: def lookup_session(source: Optional[requests.Session]) -> requests.Session:
"""Raise _RateLimited on 429, return parsed JSON on 200, else None.""" """
Build a plain, retry-free requests session for IP geolocation.
Geolocation needs no TLS fingerprinting, so we skip the impersonated rnet
session and the base session's urllib3 retry loop — both retry 429 internally,
which hides the response and defeats fast provider handover. With a bare session
a 429 comes straight back so we can move to the next provider immediately. Only
the proxy is carried over so proxied lookups still report the proxy's exit IP.
"""
sess = requests.Session()
proxies = getattr(source, "proxies", None)
if proxies:
proxy = proxies.get("all") or proxies.get("https") or proxies.get("http")
if proxy:
sess.proxies.update({"http": proxy, "https": proxy})
return sess
def json_or_raise(response: requests.Response) -> Optional[dict]:
"""Raise RateLimited on 429, return parsed JSON on 200, else None."""
if response.status_code == 429: if response.status_code == 429:
raise _RateLimited() raise RateLimited()
if response.status_code != 200: if response.status_code != 200:
return None return None
try: try:
@@ -111,50 +123,62 @@ def _check(response: requests.Response) -> Optional[dict]:
return None return None
def _fetch_ipinfo_lite(token: str) -> Fetcher: def fetch_ipinfo_lite(token: str) -> Fetcher:
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
def fetch(session: requests.Session) -> Optional[dict]: def fetch(session: requests.Session) -> Optional[dict]:
payload = _check(session.get("https://api.ipinfo.io/lite/me", headers=headers, timeout=REQUEST_TIMEOUT)) payload = json_or_raise(session.get("https://api.ipinfo.io/lite/me", headers=headers, timeout=REQUEST_TIMEOUT))
return _parse_ipinfo_lite(payload) if payload else None return parse_ipinfo_lite(payload) if payload else None
return fetch return fetch
def _fetch_ipinfo(session: requests.Session) -> Optional[dict]: def fetch_ipinfo(session: requests.Session) -> Optional[dict]:
payload = _check(session.get("https://ipinfo.io/json", timeout=REQUEST_TIMEOUT)) payload = json_or_raise(session.get("https://ipinfo.io/json", timeout=REQUEST_TIMEOUT))
return _parse_ipinfo(payload) if payload else None return parse_ipinfo(payload) if payload else None
def _fetch_ip_api_in(session: requests.Session) -> Optional[dict]: def fetch_ip_api_in(session: requests.Session) -> Optional[dict]:
"""ip-api.in has no /me endpoint — resolve IP via ipify first, then look it up.""" """ip-api.in has no /me endpoint — resolve IP via ipify first, then look it up."""
ip_resp = session.get("https://api.ipify.org", timeout=REQUEST_TIMEOUT) ip_resp = session.get("https://api.ipify.org", timeout=REQUEST_TIMEOUT)
if ip_resp.status_code == 429: if ip_resp.status_code == 429:
raise _RateLimited() raise RateLimited()
if ip_resp.status_code != 200: ip = (ip_resp.text or "").strip() if ip_resp.status_code == 200 else ""
return None
ip = (ip_resp.text or "").strip()
if not ip: if not ip:
return None return None
payload = _check(session.get(f"https://ip-api.in/api/v1/ip/{ip}", timeout=REQUEST_TIMEOUT)) payload = json_or_raise(session.get(f"https://ip-api.in/api/v1/ip/{ip}", timeout=REQUEST_TIMEOUT))
if not payload or not payload.get("success"): if not payload or not payload.get("success"):
return None return None
return _parse_ip_api_in(payload.get("data") or {}) return parse_ip_api_in(payload.get("data") or {})
def _build_providers() -> list[tuple[str, Fetcher]]: def build_providers() -> list[tuple[str, Fetcher]]:
"""Return ordered (name, fetcher) pairs. Token read at call time.""" """Return ordered (name, fetcher) pairs. Token is read at call time."""
from unshackle.core.config import config from unshackle.core.config import config
providers: list[tuple[str, Fetcher]] = [] providers: list[tuple[str, Fetcher]] = []
token = (getattr(config, "ipinfo_api_key", "") or "").strip() token = (getattr(config, "ipinfo_api_key", "") or "").strip()
if token: if token:
providers.append(("ipinfo_lite", _fetch_ipinfo_lite(token))) providers.append(("ipinfo_lite", fetch_ipinfo_lite(token)))
providers.append(("ipinfo", _fetch_ipinfo)) providers.append(("ipinfo", fetch_ipinfo))
providers.append(("ip_api_in", _fetch_ip_api_in)) providers.append(("ip_api_in", fetch_ip_api_in))
return providers return providers
def purge_stale_cache() -> None:
"""Delete superseded ip_info cache files (older CACHE_KEY versions)."""
from unshackle.core.config import config
global_dir = config.directories.cache / "global"
for stale in global_dir.glob("ip_info_v*.json"):
if stale.stem != CACHE_KEY:
stale.unlink(missing_ok=True)
def load_provider_state(cacher: Cacher) -> dict[str, Any]:
return cacher.data if cacher and not cacher.expired and isinstance(cacher.data, dict) else {}
def get_ip_info( def get_ip_info(
session: Optional[requests.Session] = None, session: Optional[requests.Session] = None,
*, *,
@@ -164,9 +188,10 @@ def get_ip_info(
Look up IP/geolocation info via ipinfo.io (Lite when `ipinfo_api_key` configured) Look up IP/geolocation info via ipinfo.io (Lite when `ipinfo_api_key` configured)
with fallback to ip-api.in. with fallback to ip-api.in.
Returns a normalized dict with keys: `ip`, `country` (lowercase ISO2), Live lookups return a dict with `ip`, `country` (lowercase ISO2), `country_code`
`country_code` (uppercase ISO2), `region`, `city`, `org`, `asn`, `as_name`, (uppercase ISO2), `region`, `city`, `org`, `asn`, `as_name`, `continent_code` and
`continent_code`, and `_provider`. Returns None if every provider fails. `_provider`. Cached lookups return only `country`/`country_code` (see GEO_CACHE_KEYS).
Returns None if every provider fails.
Args: Args:
session: Optional requests session. If a proxied session is passed, the session: Optional requests session. If a proxied session is passed, the
@@ -175,36 +200,29 @@ def get_ip_info(
cached: When True, read/write a 24h Cacher-backed entry. Use only for cached: When True, read/write a 24h Cacher-backed entry. Use only for
local IP lookups — never with a proxied session. local IP lookups — never with a proxied session.
""" """
log = logging.getLogger("ip_info") cache = None
if cached: if cached:
purge_stale_cache()
cache = Cacher("global").get(CACHE_KEY) cache = Cacher("global").get(CACHE_KEY)
if cache and not cache.expired and cache.data: if cache and not cache.expired and cache.data:
return cache.data return cache.data
else:
cache = None
state_cache = Cacher("global").get(PROVIDER_STATE_KEY) state_cache = Cacher("global").get(PROVIDER_STATE_KEY)
state: dict[str, Any] = ( state = load_provider_state(state_cache)
state_cache.data if state_cache and not state_cache.expired and isinstance(state_cache.data, dict) else {}
)
providers = _build_providers()
now = time.time() now = time.time()
def _cooldown_key(item: tuple[str, Fetcher]) -> int: def on_cooldown(item: tuple[str, Fetcher]) -> int:
info = state.get(item[0]) or {} rate_limited_at = (state.get(item[0]) or {}).get("rate_limited_at", 0)
return 1 if (now - info.get("rate_limited_at", 0)) < RATE_LIMIT_COOLDOWN else 0 return 1 if (now - rate_limited_at) < RATE_LIMIT_COOLDOWN else 0
providers.sort(key=_cooldown_key) providers = sorted(build_providers(), key=on_cooldown)
sess = lookup_session(session)
sess = session or requests.Session()
for name, fetcher in providers: for name, fetcher in providers:
log.debug(f"Trying IP provider: {name}") log.debug(f"Trying IP provider: {name}")
try: try:
normalized = fetcher(sess) normalized = fetcher(sess)
except _RateLimited: except RateLimited:
log.warning(f"Provider {name} returned 429 (rate limited), trying next provider") log.warning(f"Provider {name} returned 429 (rate limited), trying next provider")
entry = state.setdefault(name, {}) entry = state.setdefault(name, {})
entry["rate_limited_at"] = now entry["rate_limited_at"] = now
@@ -225,8 +243,8 @@ def get_ip_info(
if name in state and state[name].pop("rate_limited_at", None) is not None: if name in state and state[name].pop("rate_limited_at", None) is not None:
state_cache.set(state, expiration=RATE_LIMIT_COOLDOWN) state_cache.set(state, expiration=RATE_LIMIT_COOLDOWN)
if cached and cache is not None: if cache is not None:
cache.set(normalized, expiration=CACHE_TTL) cache.set({k: normalized.get(k, "") for k in GEO_CACHE_KEYS}, expiration=CACHE_TTL)
return normalized return normalized