#!/usr/bin/env python3
"""wgctl-monitor: log UDP packets aimed at the WireGuard port from watchlisted IPs.

The watchlist (a JSON object mapping source IP -> client name) is re-read
whenever its mtime changes. Every matching packet is appended as one JSON line
to EVENTS_LOG, and the client's last-seen source IP is recorded in
endpoint_cache.json next to the watchlist.
"""
import json
import logging
import os
import signal
import subprocess
import sys
import time
from datetime import datetime, timezone
from pathlib import Path

from scapy.all import IP, UDP, sniff

# ============================================
# Config
# ============================================
WATCHLIST_FILE = Path("/etc/wireguard/wgctl/daemon/watchlist.json")
EVENTS_LOG = Path("/etc/wireguard/wgctl/daemon/events.log")
ENDPOINT_CACHE_FILE = WATCHLIST_FILE.parent / "endpoint_cache.json"
# Interface scapy sniffs on.
WG_INTERFACE = os.environ.get("WG_INTERFACE", "eth0")
# WireGuard device queried with `wg show`. Defaults to WG_INTERFACE so existing
# deployments behave as before. NOTE(review): when sniffing on a physical NIC
# (e.g. eth0) this most likely needs to be set to the wg device (e.g. wg0).
WG_DEVICE = os.environ.get("WG_DEVICE", WG_INTERFACE)
WG_PORT = int(os.environ.get("WG_PORT", "51820"))
LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO")
# Upper bound (seconds) on how long one `wg show` call may stall the packet handler.
WG_SHOW_TIMEOUT = 2.0

# ============================================
# Logging
# ============================================


def _resolve_log_level(name: str) -> int | None:
    """Map a level name ("info", "DEBUG") or number ("10") to a logging level.

    Returns None for unrecognized values so the caller can fall back instead of
    crashing at import time.
    """
    name = name.strip()
    if name.isdigit():
        return int(name)
    level = getattr(logging, name.upper(), None)
    return level if isinstance(level, int) else None


_log_level = _resolve_log_level(LOG_LEVEL)
logging.basicConfig(
    level=_log_level if _log_level is not None else logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
log = logging.getLogger("wgctl-monitor")
if _log_level is None:
    log.warning("Unknown LOG_LEVEL %r; using INFO", LOG_LEVEL)

# ============================================
# Watchlist
# ============================================
_watchlist: dict[str, str] = {}
_watchlist_mtime: float = 0.0  # mtime of the last successfully loaded file; 0.0 = never loaded


def load_watchlist() -> dict[str, str]:
    """Return the watchlist, re-reading WATCHLIST_FILE only when its mtime changed.

    On a read/parse error, or if the file is not a JSON object, the previously
    loaded watchlist is kept and returned. The mtime is not recorded in that
    case, so the next call retries.
    """
    global _watchlist, _watchlist_mtime
    try:
        mtime = WATCHLIST_FILE.stat().st_mtime
        if mtime == _watchlist_mtime:
            return _watchlist
        with WATCHLIST_FILE.open(encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError / UnicodeDecodeError
        log.error("Failed to load watchlist: %s", e)
        return _watchlist
    if not isinstance(data, dict):
        # A list/str would make is_watched()'s .get() raise inside the sniff callback.
        log.error(
            "Failed to load watchlist: expected a JSON object, got %s",
            type(data).__name__,
        )
        return _watchlist
    _watchlist = data
    _watchlist_mtime = mtime
    log.debug("Watchlist reloaded: %d entries", len(_watchlist))
    return _watchlist


def is_watched(ip: str) -> str | None:
    """Return the client name the watchlist maps *ip* to, or None."""
    return load_watchlist().get(ip)

# ============================================
# Endpoint Resolution
# ============================================


def _parse_endpoint_host(endpoint: str) -> str | None:
    """Strip the port from a `wg show` endpoint ("1.2.3.4:51820", "[::1]:51820").

    Returns None for peers without an endpoint, which wg prints as "(none)".
    """
    if endpoint == "(none)":
        return None
    host = endpoint.rsplit(":", 1)[0]
    # IPv6 endpoints are bracketed; return the bare address.
    if host.startswith("[") and host.endswith("]"):
        return host[1:-1]
    return host


def get_endpoint(public_key: str) -> str | None:
    """Return the current endpoint IP (without port) of peer *public_key*.

    Queries `wg show <WG_DEVICE> endpoints`. Returns None when the peer is not
    listed, has no endpoint yet, or `wg` fails or times out.
    """
    try:
        result = subprocess.run(
            ["wg", "show", WG_DEVICE, "endpoints"],
            capture_output=True,
            text=True,
            timeout=WG_SHOW_TIMEOUT,
        )
    except (OSError, subprocess.SubprocessError) as e:  # includes TimeoutExpired
        log.debug("Failed to get endpoint: %s", e)
        return None
    if result.returncode != 0:
        log.debug(
            "Failed to get endpoint: wg exited %d: %s",
            result.returncode,
            result.stderr.strip(),
        )
        return None
    for line in result.stdout.splitlines():
        parts = line.split()
        if len(parts) == 2 and parts[0] == public_key:
            return _parse_endpoint_host(parts[1])
    return None


def get_client_public_key(client_name: str) -> str | None:
    """Return the stored public key for *client_name*, or None if missing/unreadable/empty."""
    key_file = Path(f"/etc/wireguard/clients/{client_name}_public.key")
    try:
        return key_file.read_text().strip() or None
    except (OSError, UnicodeDecodeError):
        return None

# ============================================
# Event Logging
# ============================================


def _update_endpoint_cache(client: str, ip: str) -> None:
    """Record *ip* as the last source IP seen for *client* in ENDPOINT_CACHE_FILE.

    Best-effort: a missing or corrupt cache is started afresh, and write
    failures are logged rather than raised so they cannot abort packet capture.
    """
    try:
        with ENDPOINT_CACHE_FILE.open(encoding="utf-8") as f:
            cache = json.load(f)
    except (OSError, ValueError):
        cache = {}
    if not isinstance(cache, dict):
        cache = {}
    cache[client] = ip
    # Write a sibling temp file and rename it over the cache so a crash or full
    # disk never leaves a truncated JSON file behind.
    tmp_file = ENDPOINT_CACHE_FILE.with_name(ENDPOINT_CACHE_FILE.name + ".tmp")
    try:
        with tmp_file.open("w", encoding="utf-8") as f:
            json.dump(cache, f, indent=2)
        os.replace(tmp_file, ENDPOINT_CACHE_FILE)
    except OSError as e:
        log.error("Failed to update endpoint cache: %s", e)


def log_event(ip: str, client: str, event: str, endpoint: str | None = None):
    """Append one JSON-line event to EVENTS_LOG and refresh the endpoint cache.

    Args:
        ip: Source IP of the packet.
        client: Watchlisted client name that *ip* maps to.
        event: Event label (handle_packet passes "attempt").
        endpoint: Resolved endpoint IP; left out of the record when falsy.
    """
    entry = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "ip": ip,
        "client": client,
        "event": event,
    }
    # Update endpoint cache when we see a packet
    _update_endpoint_cache(client, ip)
    if endpoint:
        entry["endpoint"] = endpoint
    try:
        with EVENTS_LOG.open("a", encoding="utf-8") as f:
            f.write(json.dumps(entry) + "\n")
    except OSError as e:
        log.error("Failed to write event: %s", e)
        return
    log.debug("Event logged: %s", entry)

# ============================================
# Packet Handler
# ============================================


def _process_packet(pkt) -> None:
    """Log *pkt* if it is UDP to WG_PORT from a watchlisted source IP."""
    if IP not in pkt or UDP not in pkt:
        return
    # Only care about packets targeting WireGuard port
    if pkt[UDP].dport != WG_PORT:
        return
    src_ip = pkt[IP].src
    client = is_watched(src_ip)
    if not client:
        return
    # Prefer the endpoint wg reports for the peer; fall back to the packet source IP.
    public_key = get_client_public_key(client)
    endpoint = (get_endpoint(public_key) if public_key else None) or src_ip
    log_event(src_ip, client, "attempt", endpoint)
    log.info("Blocked attempt: %s (%s) from endpoint %s", client, src_ip, endpoint)


def handle_packet(pkt):
    """scapy `prn` callback.

    Processes one packet and logs any error instead of letting it propagate
    out of sniff() and stop the capture. SystemExit from the signal handler is
    a BaseException and still passes through.
    """
    try:
        _process_packet(pkt)
    except Exception:
        log.exception("Error while handling packet")

# ============================================
# Signal Handling
# ============================================


def handle_signal(signum, frame):
    """Log shutdown and exit with status 0; the resulting SystemExit is expected to unwind sniff()."""
    log.info("Shutting down wgctl-monitor")
    sys.exit(0)


signal.signal(signal.SIGTERM, handle_signal)
signal.signal(signal.SIGINT, handle_signal)


# ============================================
# Main
# ============================================
def main():
    """Start the monitor: require the watchlist file, load it, then sniff until signalled."""
    log.info(f"wgctl-monitor starting on interface {WG_INTERFACE} port {WG_PORT}")
    if not WATCHLIST_FILE.exists():
        log.error(f"Watchlist not found: {WATCHLIST_FILE}")
        sys.exit(1)
    load_watchlist()
    if _watchlist_mtime == 0.0:
        # load_watchlist() logs and swallows parse errors and leaves the mtime
        # unset, so don't claim success here; it retries on each matching packet.
        log.warning("Watchlist could not be loaded; starting packet capture anyway (will retry)")
    else:
        log.info("Watchlist loaded, starting packet capture...")
    sniff(
        iface=WG_INTERFACE,
        # handle_packet only acts on dport == WG_PORT, so let BPF drop the
        # reverse direction instead of handing it to Python.
        filter=f"udp dst port {WG_PORT}",
        prn=handle_packet,
        store=False,
    )


if __name__ == "__main__":
    main()