#!/usr/bin/env python3
"""wgctl-monitor: record blocked-peer connection attempts and peer sessions.

Two activities run concurrently:

* The main thread sniffs UDP traffic on the capture interface
  (``WG_INTERFACE``). For each packet addressed to ``WG_PORT`` whose source
  IP is on the watchlist, it appends an ``attempt`` event to ``EVENTS_LOG``.
* A daemon thread polls ``wg show <WG_WG_INTERFACE> latest-handshakes``.
  It maintains per-peer endpoint history under ``PEER_HISTORY_DIR``. It
  appends a ``handshake`` event when a peer's handshake timestamp jumps by
  at least ``WG_HANDSHAKE_CHECK_SEC`` (treated as a new session).
"""
import json
import logging
import os
import signal
import subprocess
import sys
import threading
import time
from datetime import datetime, timezone
from pathlib import Path

from scapy.all import IP, UDP, sniff

# ============================================
# Config
# ============================================

WATCHLIST_FILE = Path("/etc/wireguard/.wgctl/daemon/watchlist.json")
EVENTS_LOG = Path("/etc/wireguard/.wgctl/daemon/events.log")
# Interface scapy captures on -- NOT the WireGuard interface (see WG_WG_INTERFACE).
WG_INTERFACE = os.environ.get("WG_INTERFACE", "eth0")
WG_PORT = int(os.environ.get("WG_PORT", "51820"))
LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO")
# Minimum jump (seconds) between handshake timestamps that counts as a new
# session; the poller also wakes every half of this interval.
WG_HANDSHAKE_CHECK_SEC = int(os.environ.get("WG_HANDSHAKE_CHECK_TIME_SEC", "300"))
WG_WG_INTERFACE = os.environ.get("WG_WG_INTERFACE", "wg0")  # WireGuard interface, not capture interface
CLIENTS_DIR = Path("/etc/wireguard/clients")
HS_CACHE_FILE = Path("/etc/wireguard/.wgctl/daemon/hs_cache.json")
ENDPOINT_CACHE_FILE = Path("/etc/wireguard/.wgctl/daemon/endpoint_cache.json")
PEER_HISTORY_DIR = Path("/etc/wireguard/.wgctl/peer-history")
ENDPOINT_INDEX_FILE = PEER_HISTORY_DIR / "endpoint_index.json"

# Upper bound for any single `wg` invocation so a wedged tool can't stall us.
WG_CMD_TIMEOUT_SEC = 10

# ============================================
# Logging
# ============================================


def _resolve_log_level(name: str) -> int:
    """Map a level name such as "info" or "DEBUG" to a logging constant.

    Unknown names fall back to INFO instead of crashing at import time.
    """
    level = logging.getLevelName(name.strip().upper())
    return level if isinstance(level, int) else logging.INFO


logging.basicConfig(
    level=_resolve_log_level(LOG_LEVEL),
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
log = logging.getLogger("wgctl-monitor")

# ============================================
# JSON / file helpers
# ============================================


def _read_json_dict(path: Path) -> dict:
    """Load a JSON object from *path*.

    Returns {} when the file is missing or unreadable, is not valid JSON,
    or holds something other than a JSON object.
    """
    try:
        data = json.loads(path.read_text())
    except (OSError, ValueError):
        return {}
    return data if isinstance(data, dict) else {}


def _write_json_atomic(path: Path, data, indent: int | None = None) -> None:
    """Serialize *data* to *path* via a sibling temp file and ``os.replace``.

    Readers (the other thread here, or external tools) see either the old
    or the new content, never a truncated file. When *path* already exists,
    its permission bits are copied to the replacement. Raises on failure,
    after removing the temp file.
    """
    tmp = path.with_name(f".{path.name}.{os.getpid()}.{threading.get_ident()}.tmp")
    try:
        with tmp.open("w") as f:
            json.dump(data, f, indent=indent)
        try:
            os.chmod(tmp, path.stat().st_mode & 0o7777)
        except FileNotFoundError:
            pass  # first write: keep the default (umask-derived) mode
        os.replace(tmp, path)
    except BaseException:
        try:
            tmp.unlink()
        except OSError:
            pass
        raise


def _append_event(entry: dict, what: str = "event") -> bool:
    """Append *entry* as one JSON line to EVENTS_LOG.

    Returns True on success. On failure it logs "Failed to write <what>"
    and returns False.
    """
    try:
        with EVENTS_LOG.open("a") as f:
            f.write(json.dumps(entry) + "\n")
    except (OSError, TypeError, ValueError) as e:
        log.error(f"Failed to write {what}: {e}")
        return False
    return True

# ============================================
# Watchlist
# ============================================

_watchlist: dict[str, str] = {}
_watchlist_mtime: float = 0.0


def load_watchlist() -> dict[str, str]:
    """Return the IP -> client-name watchlist.

    The file is re-read only when its mtime changes. On a read or parse
    error the previously loaded watchlist is kept, the error is logged, and
    the next call retries.
    """
    global _watchlist, _watchlist_mtime
    try:
        mtime = WATCHLIST_FILE.stat().st_mtime
        if mtime == _watchlist_mtime:
            return _watchlist
        with WATCHLIST_FILE.open() as f:
            data = json.load(f)
        if not isinstance(data, dict):
            raise ValueError(f"expected a JSON object, got {type(data).__name__}")
        _watchlist = data
        _watchlist_mtime = mtime
        log.debug(f"Watchlist reloaded: {len(_watchlist)} entries")
    except (OSError, ValueError) as e:
        log.error(f"Failed to load watchlist: {e}")
    return _watchlist


def is_watched(ip: str) -> str | None:
    """Return the client name for *ip* if it is on the watchlist, else None."""
    return load_watchlist().get(ip)

# ============================================
# Endpoint Resolution
# ============================================


def _read_endpoints() -> dict[str, str]:
    """Map pubkey -> endpoint IP (port stripped) via ``wg show <WG_WG_INTERFACE> endpoints``.

    Peers whose endpoint is reported as ``(none)`` are omitted.
    Raises OSError / subprocess.SubprocessError if ``wg`` cannot be run.
    """
    result = subprocess.run(
        ["wg", "show", WG_WG_INTERFACE, "endpoints"],
        capture_output=True, text=True, timeout=WG_CMD_TIMEOUT_SEC,
    )
    if result.returncode != 0:
        log.debug(f"wg show {WG_WG_INTERFACE} endpoints failed: {result.stderr.strip()}")
    endpoints: dict[str, str] = {}
    for line in result.stdout.splitlines():
        parts = line.split()
        if len(parts) == 2 and parts[1] != "(none)":
            # Return just the IP without port
            endpoints[parts[0]] = parts[1].rsplit(":", 1)[0]
    return endpoints


def get_endpoint(public_key: str) -> str | None:
    """Return the current endpoint IP of the peer with *public_key*.

    Returns None when the peer is unknown, has no endpoint yet, or ``wg``
    fails.
    """
    try:
        return _read_endpoints().get(public_key)
    except (OSError, subprocess.SubprocessError) as e:
        log.debug(f"Failed to get endpoint: {e}")
        return None


def get_client_public_key(client_name: str) -> str | None:
    """Read ``CLIENTS_DIR/<client_name>_public.key``.

    Returns None if the file is unreadable.
    """
    key_file = CLIENTS_DIR / f"{client_name}_public.key"
    try:
        return key_file.read_text().strip()
    except (OSError, ValueError):
        return None

# ============================================
# Event Logging
# ============================================


def log_event(ip: str, client: str, event: str, endpoint: str | None = None):
    """Append *event* for *client* to EVENTS_LOG and cache *ip* as the client's endpoint.

    The endpoint cache (client -> last source IP) is the fallback the
    handshake poller uses when ``wg show`` has no endpoint for a peer.
    Errors are logged rather than raised, so disk problems do not
    propagate out of the packet handler.
    """
    entry = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "ip": ip,
        "client": client,
        "event": event,
    }
    if endpoint:
        entry["endpoint"] = endpoint

    # Update endpoint cache when we see a packet
    cache = _read_json_dict(ENDPOINT_CACHE_FILE)
    cache[client] = ip
    try:
        _write_json_atomic(ENDPOINT_CACHE_FILE, cache, indent=2)
    except (OSError, TypeError, ValueError) as e:
        log.error(f"Failed to update endpoint cache: {e}")

    if _append_event(entry, "event"):
        log.debug(f"Event logged: {entry}")

# ============================================
# Handshake Poller
# ============================================

# Last handshake timestamp seen per pubkey (persisted to HS_CACHE_FILE)
_hs_last_logged: dict[str, int] = {}


def load_hs_cache() -> dict[str, int]:
    """Load the persisted pubkey -> last handshake timestamp map ({} if unusable)."""
    try:
        return {k: int(v) for k, v in _read_json_dict(HS_CACHE_FILE).items()}
    except (TypeError, ValueError):
        return {}


def save_hs_cache(cache: dict[str, int]) -> None:
    """Persist the handshake map.

    Without this, a restart would start from an empty map and re-log every
    peer as a new session.
    """
    try:
        _write_json_atomic(HS_CACHE_FILE, cache)
    except (OSError, TypeError, ValueError) as e:
        log.warning(f"Failed to save handshake cache: {e}")


def build_pubkey_to_name() -> dict[str, str]:
    """Build pubkey -> client name map from ``*_public.key`` files in CLIENTS_DIR."""
    mapping: dict[str, str] = {}
    for key_file in CLIENTS_DIR.glob("*_public.key"):
        # Strip only the trailing "_public" so names containing it survive.
        name = key_file.stem.removesuffix("_public")
        try:
            pubkey = key_file.read_text().strip()
        except (OSError, ValueError):
            continue
        if pubkey:
            mapping[pubkey] = name
    return mapping


def _read_latest_handshakes() -> list[tuple[str, int]]:
    """Return (pubkey, unix_ts) pairs from ``wg show <WG_WG_INTERFACE> latest-handshakes``.

    Malformed lines and peers with timestamp 0 are skipped.
    Raises OSError / subprocess.SubprocessError if ``wg`` cannot be run.
    """
    result = subprocess.run(
        ["wg", "show", WG_WG_INTERFACE, "latest-handshakes"],
        capture_output=True, text=True, timeout=WG_CMD_TIMEOUT_SEC,
    )
    if result.returncode != 0:
        log.warning(f"wg show {WG_WG_INTERFACE} latest-handshakes failed: {result.stderr.strip()}")
    pairs: list[tuple[str, int]] = []
    for line in result.stdout.splitlines():
        parts = line.split()
        if len(parts) != 2:
            continue
        pubkey, ts_str = parts
        try:
            ts = int(ts_str)
        except ValueError:
            continue
        if ts != 0:
            pairs.append((pubkey, ts))
    return pairs


def _log_new_session(client: str, endpoint: str, ts: int) -> None:
    """Append a "handshake" event stamped with the handshake time *ts*."""
    entry = {
        "timestamp": datetime.fromtimestamp(ts, tz=timezone.utc).isoformat(),
        "ip": "",
        "client": client,
        "event": "handshake",
        "endpoint": endpoint,
    }
    if _append_event(entry, "handshake event"):
        log.info(f"New session: {client} from {endpoint}")


def _poll_once() -> None:
    """Run one poll cycle: process every known peer whose handshake timestamp changed."""
    # Re-scan key files each cycle so clients added after startup are tracked.
    pubkey_to_name = build_pubkey_to_name()
    changed = []
    for pubkey, ts in _read_latest_handshakes():
        client = pubkey_to_name.get(pubkey)
        if client and ts != _hs_last_logged.get(pubkey, 0):
            changed.append((pubkey, client, ts))
    if not changed:
        return  # nothing new; history and cache are already current

    # One `wg` call and one cache read per cycle instead of per peer.
    try:
        endpoints = _read_endpoints()
    except (OSError, subprocess.SubprocessError) as e:
        log.debug(f"Failed to get endpoint: {e}")
        endpoints = {}
    endpoint_cache = _read_json_dict(ENDPOINT_CACHE_FILE)

    for pubkey, client, ts in changed:
        gap = ts - _hs_last_logged.get(pubkey, 0)
        _hs_last_logged[pubkey] = ts
        endpoint = endpoints.get(pubkey) or endpoint_cache.get(client, "")
        log.debug(f"Gap for {client}: {gap}s (threshold: {WG_HANDSHAKE_CHECK_SEC}s)")

        if endpoint:
            update_peer_history(client, endpoint, ts)
        if gap >= WG_HANDSHAKE_CHECK_SEC:
            _log_new_session(client, endpoint, ts)
        # Smaller gaps: rekey/keepalive within the same session, not logged.

    save_hs_cache(_hs_last_logged)


def poll_handshakes():
    """Poll WireGuard handshakes forever (intended for a daemon thread).

    For every known peer whose latest-handshake timestamp changed since the
    previous cycle:

    * its endpoint history is updated;
    * a "handshake" event is logged when the new timestamp is at least
      WG_HANDSHAKE_CHECK_SEC after the previous one.

    The first sighting of a peer (no cached timestamp) always counts as a new
    session. Cycles run every WG_HANDSHAKE_CHECK_SEC // 2 seconds (minimum 1).
    Per-cycle errors are logged and the loop continues.
    """
    global _hs_last_logged, _endpoint_index
    _hs_last_logged = load_hs_cache()
    _endpoint_index = load_endpoint_index()
    pubkey_to_name = build_pubkey_to_name()
    log.info(f"Handshake poller started — {len(pubkey_to_name)} peers, "
             f"session threshold {WG_HANDSHAKE_CHECK_SEC}s")
    log.info(f"Endpoint index loaded — {len(_endpoint_index)} known endpoints")

    # Clamp so a tiny threshold can't turn this into a busy loop.
    interval = max(1, WG_HANDSHAKE_CHECK_SEC // 2)
    while True:
        try:
            _poll_once()
        except Exception as e:
            log.error(f"Handshake poll error: {e}")
        time.sleep(interval)

# ============================================
# Peer History
# ============================================


def load_endpoint_index() -> dict:
    """Load endpoint -> peer name index.

    Returns {} if the file is missing, unreadable, or not a JSON object.
    """
    return _read_json_dict(ENDPOINT_INDEX_FILE)


def save_endpoint_index(index: dict):
    """Save endpoint -> peer name index (errors are logged, not raised)."""
    try:
        PEER_HISTORY_DIR.mkdir(parents=True, exist_ok=True)
        _write_json_atomic(ENDPOINT_INDEX_FILE, index, indent=2)
    except (OSError, TypeError, ValueError) as e:
        log.error(f"Failed to save endpoint index: {e}")


# In-memory index: loaded by poll_handshakes, persisted whenever an endpoint's owner changes
_endpoint_index: dict = {}


def update_peer_history(client: str, endpoint: str, ts: int):
    """Record that *client* completed a handshake from *endpoint* at unix time *ts*.

    Maintains ``PEER_HISTORY_DIR/<client>.json``, shaped as
    ``{"peer": client, "endpoints": {endpoint: {"first_seen", "last_seen",
    "count"}}}``, where count is the number of recorded handshakes.

    Also points the endpoint -> client index at *client*, persisting the
    index only when that mapping changes. Errors are logged, not raised.
    """
    if not endpoint:
        return
    try:
        PEER_HISTORY_DIR.mkdir(parents=True, exist_ok=True)
        history_file = PEER_HISTORY_DIR / f"{client}.json"
        data = _read_json_dict(history_file) or {"peer": client, "endpoints": {}}

        ts_iso = datetime.fromtimestamp(ts, tz=timezone.utc).isoformat()
        eps = data.setdefault("endpoints", {})
        record = eps.get(endpoint)
        if record is None:
            eps[endpoint] = {
                "first_seen": ts_iso,
                "last_seen": ts_iso,
                "count": 1,
            }
            log.debug(f"New endpoint for {client}: {endpoint}")
        else:
            record["last_seen"] = ts_iso
            record["count"] = record.get("count", 0) + 1

        _write_json_atomic(history_file, data, indent=2)
    except Exception as e:  # best-effort bookkeeping; never kill the poller
        log.error(f"Failed to update peer history for {client}: {e}")

    # Keep the index pointing at whichever peer most recently used this endpoint.
    if _endpoint_index.get(endpoint) != client:
        _endpoint_index[endpoint] = client
        save_endpoint_index(_endpoint_index)

# ============================================
# Packet Handler
# ============================================


def handle_packet(pkt):
    """Handle one sniffed packet (scapy ``prn`` callback).

    If the packet is UDP addressed to WG_PORT from a watched source IP, log
    an "attempt" event.
    """
    if IP not in pkt or UDP not in pkt:
        return

    # The BPF "udp port" filter matches both directions; only inbound-to-port counts.
    if pkt[UDP].dport != WG_PORT:
        return

    src_ip = pkt[IP].src
    client = is_watched(src_ip)
    if not client:
        return

    # Prefer the endpoint WireGuard reports for this peer; else use the packet source.
    public_key = get_client_public_key(client)
    endpoint = (get_endpoint(public_key) if public_key else None) or src_ip

    log_event(src_ip, client, "attempt", endpoint)
    log.info(f"Blocked attempt: {client} ({src_ip}) from endpoint {endpoint}")

# ============================================
# Signal Handling
# ============================================


def handle_signal(signum, frame):
    """Log and exit cleanly on SIGTERM/SIGINT."""
    log.info("Shutting down wgctl-monitor")
    sys.exit(0)


signal.signal(signal.SIGTERM, handle_signal)
signal.signal(signal.SIGINT, handle_signal)

# ============================================
# Main
# ============================================


def main():
    """Entry point.

    Requires the watchlist file, starts the handshake poller thread, then
    sniffs UDP traffic to WG_PORT on WG_INTERFACE in the main thread.
    """
    log.info(f"wgctl-monitor starting on interface {WG_INTERFACE} port {WG_PORT}")

    if not WATCHLIST_FILE.exists():
        log.error(f"Watchlist not found: {WATCHLIST_FILE}")
        sys.exit(1)

    load_watchlist()
    log.info("Watchlist loaded, starting packet capture...")

    # Start handshake poller in background thread
    hs_thread = threading.Thread(target=poll_handshakes, daemon=True)
    hs_thread.start()

    sniff(
        iface=WG_INTERFACE,
        filter=f"udp port {WG_PORT}",
        prn=handle_packet,
        store=0,
    )


if __name__ == "__main__":
    main()