#!/usr/bin/env python3
"""wgctl-monitor: log WireGuard traffic from watchlisted source IPs.

Sniffs UDP packets addressed to the WireGuard port on the capture interface
and records an event for each one whose source IP is on the watchlist; a
background thread also records new WireGuard sessions from handshake times.
"""
import json
import logging
import os
import signal
import subprocess
import sys
import threading
import time
from datetime import datetime, timezone
from pathlib import Path

from scapy.all import IP, UDP, sniff

# ============================================
# Config
# ============================================

WATCHLIST_FILE = Path("/etc/wireguard/.wgctl/daemon/watchlist.json")
EVENTS_LOG = Path("/etc/wireguard/.wgctl/daemon/events.log")
WG_INTERFACE = os.environ.get("WG_INTERFACE", "eth0")  # capture interface passed to sniff()
WG_PORT = int(os.environ.get("WG_PORT", "51820"))
LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO")
WG_HANDSHAKE_CHECK_SEC = int(os.environ.get("WG_HANDSHAKE_CHECK_TIME_SEC", "300"))
WG_WG_INTERFACE = os.environ.get("WG_WG_INTERFACE", "wg0")  # WireGuard interface, not capture interface
HS_CACHE_FILE = Path("/etc/wireguard/.wgctl/daemon/hs_cache.json")
ENDPOINT_CACHE_FILE = Path("/etc/wireguard/.wgctl/daemon/endpoint_cache.json")

# ============================================
# Logging
# ============================================


def _resolve_log_level(name: str) -> int:
    """Map a level name in any case (e.g. "debug") to a logging level.

    Unknown names fall back to INFO instead of crashing at start-up, which
    is what a bare ``getattr(logging, name)`` would do.
    """
    level = getattr(logging, name.strip().upper(), None)
    return level if isinstance(level, int) else logging.INFO


logging.basicConfig(
    level=_resolve_log_level(LOG_LEVEL),
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
log = logging.getLogger("wgctl-monitor")

# ============================================
# Watchlist
# ============================================

# Cached watchlist (source IP -> client name) and the mtime it was read at.
_watchlist: dict[str, str] = {}
_watchlist_mtime: float = 0.0


def load_watchlist() -> dict[str, str]:
    """Return the IP -> client-name watchlist, re-reading it when the file changes.

    The file is parsed only when its mtime differs from the last successful
    load. On any error (missing file, bad JSON, non-object document) the
    error is logged and the previously loaded watchlist is returned.
    """
    global _watchlist, _watchlist_mtime
    try:
        mtime = WATCHLIST_FILE.stat().st_mtime
        if mtime == _watchlist_mtime:
            return _watchlist
        with WATCHLIST_FILE.open() as f:
            data = json.load(f)
        # Anything other than a JSON object would make is_watched() fail on
        # .get() inside the packet callback, so keep the old list instead.
        if not isinstance(data, dict):
            raise ValueError(f"expected a JSON object, got {type(data).__name__}")
        _watchlist = data
        _watchlist_mtime = mtime
        log.debug(f"Watchlist reloaded: {len(_watchlist)} entries")
    except Exception as e:
        log.error(f"Failed to load watchlist: {e}")
    return _watchlist


def is_watched(ip: str) -> str | None:
    """Return the client name mapped to source IP *ip*, or None if not watched."""
    return load_watchlist().get(ip)
============================================ def get_endpoint(public_key: str) -> str | None: try: import subprocess result = subprocess.run( ["wg", "show", WG_INTERFACE, "endpoints"], capture_output=True, text=True ) for line in result.stdout.splitlines(): parts = line.split() if len(parts) == 2 and parts[0] == public_key: # Return just the IP without port return parts[1].rsplit(":", 1)[0] except Exception as e: log.debug(f"Failed to get endpoint: {e}") return None def get_client_public_key(client_name: str) -> str | None: key_file = Path(f"/etc/wireguard/clients/{client_name}_public.key") try: return key_file.read_text().strip() except Exception: return None # ============================================ # Event Logging # ============================================ def log_event(ip: str, client: str, event: str, endpoint: str | None = None): entry = { "timestamp": datetime.now(timezone.utc).isoformat(), "ip": ip, "client": client, "event": event, } # Update endpoint cache when we see a packet cache_file = ENDPOINT_CACHE_FILE try: with open(cache_file) as f: cache = json.load(f) except: cache = {} cache[client] = ip with open(cache_file, 'w') as f: json.dump(cache, f, indent=2) if endpoint: entry["endpoint"] = endpoint try: with EVENTS_LOG.open("a") as f: f.write(json.dumps(entry) + "\n") log.debug(f"Event logged: {entry}") except Exception as e: log.error(f"Failed to write event: {e}") # ============================================ # Handshake Poller # ============================================ # Tracks last logged handshake ts per pubkey _hs_last_logged: dict[str, int] = {} def load_hs_cache(): try: with HS_CACHE_FILE.open() as f: return {k: int(v) for k, v in json.load(f).items()} except Exception: return {} def save_hs_cache(cache): try: with HS_CACHE_FILE.open('w') as f: json.dump(cache, f) except Exception: pass def build_pubkey_to_name() -> dict[str, str]: """Build pubkey -> client name map from public key files.""" mapping = {} clients_dir = 
Path("/etc/wireguard/clients") for kf in clients_dir.glob("*_public.key"): name = kf.stem.replace("_public", "") try: mapping[kf.read_text().strip()] = name except Exception: pass return mapping def poll_handshakes(): """ Poll wg show latest-handshakes periodically. Log a handshake event only when gap > WG_HANDSHAKE_CHECK_SEC (new session). """ global _hs_last_logged _hs_last_logged = load_hs_cache() pubkey_to_name = build_pubkey_to_name() log.info(f"Handshake poller started — {len(pubkey_to_name)} peers, " f"session threshold {WG_HANDSHAKE_CHECK_SEC}s") while True: try: result = subprocess.run( ["wg", "show", WG_WG_INTERFACE, "latest-handshakes"], capture_output=True, text=True ) for line in result.stdout.strip().splitlines(): parts = line.split() if len(parts) != 2: continue pubkey, ts_str = parts try: ts = int(ts_str) except ValueError: continue if ts == 0: continue client = pubkey_to_name.get(pubkey) if not client: continue last = _hs_last_logged.get(pubkey, 0) gap = ts - last # Always update last seen _hs_last_logged[pubkey] = ts if gap < WG_HANDSHAKE_CHECK_SEC: continue # keepalive, skip # Get endpoint endpoint = get_endpoint(pubkey) or '' if not endpoint: try: cache = json.loads(ENDPOINT_CACHE_FILE.read_text()) endpoint = cache.get(client, '') except Exception: pass # New session, log it entry = { "timestamp": datetime.fromtimestamp(ts, tz=timezone.utc).isoformat(), "ip": "", "client": client, "event": "handshake", "endpoint": endpoint, } try: with EVENTS_LOG.open("a") as f: f.write(json.dumps(entry) + "\n") log.info(f"New session: {client} from {endpoint}") except Exception as e: log.error(f"Failed to write handshake event: {e}") log.debug(f"Gap for {client}: {gap}s (threshold: {WG_HANDSHAKE_CHECK_SEC}s)") save_hs_cache(_hs_last_logged) except Exception as e: log.error(f"Handshake poll error: {e}") time.sleep(WG_HANDSHAKE_CHECK_SEC // 2) # poll at half the threshold # ============================================ # Packet Handler # 
# ============================================
# Packet Handler
# ============================================

def handle_packet(pkt):
    """Log an "attempt" event for a UDP packet to WG_PORT from a watchlisted IP.

    Used as the sniff() callback. All lookups and writes are guarded so that
    a failure on one packet is logged instead of propagating out of the
    callback. NOTE(review): scapy's sniff() does not appear to catch prn
    exceptions, so an escaping error would end the capture.
    """
    if not (IP in pkt and UDP in pkt):
        return

    # The BPF filter matches the port in either direction; only care about
    # packets targeting the WireGuard port.
    if pkt[UDP].dport != WG_PORT:
        return

    src_ip = pkt[IP].src
    try:
        client = is_watched(src_ip)
        if not client:
            return

        # Resolve the peer's endpoint as reported by wg, if any.
        public_key = get_client_public_key(client)
        endpoint = get_endpoint(public_key) if public_key else None
        # If no endpoint from wg show, use packet source IP.
        endpoint = endpoint or src_ip

        log_event(src_ip, client, "attempt", endpoint)
        log.info(f"Blocked attempt: {client} ({src_ip}) from endpoint {endpoint}")
    except Exception as e:
        log.error(f"Failed to handle packet from {src_ip}: {e}")


# ============================================
# Signal Handling
# ============================================

def handle_signal(signum, frame):
    """Log and exit cleanly on SIGTERM/SIGINT."""
    log.info("Shutting down wgctl-monitor")
    sys.exit(0)


signal.signal(signal.SIGTERM, handle_signal)
signal.signal(signal.SIGINT, handle_signal)

# ============================================
# Main
# ============================================

def main():
    """Load the watchlist, start the handshake poller, and sniff until signalled.

    Exits with status 1 if WATCHLIST_FILE does not exist.
    """
    log.info(f"wgctl-monitor starting on interface {WG_INTERFACE} port {WG_PORT}")

    if not WATCHLIST_FILE.exists():
        log.error(f"Watchlist not found: {WATCHLIST_FILE}")
        sys.exit(1)

    load_watchlist()
    log.info("Watchlist loaded, starting packet capture...")

    # Daemon thread: it must not keep the process alive after sniff() returns.
    hs_thread = threading.Thread(target=poll_handshakes, daemon=True)
    hs_thread.start()

    sniff(
        iface=WG_INTERFACE,
        filter=f"udp port {WG_PORT}",
        prn=handle_packet,
        store=0,
    )


if __name__ == "__main__":
    main()