188 lines
5.1 KiB
Python
Executable file
188 lines
5.1 KiB
Python
Executable file
#!/usr/bin/env python3
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import signal
|
|
import sys
|
|
import time
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
from scapy.all import IP, UDP, sniff
|
|
|
|
# ============================================
|
|
# Config
|
|
# ============================================
|
|
|
|
# All daemon state files live in one directory.
_DAEMON_DIR = Path("/etc/wireguard/wgctl/daemon")

# Watchlist JSON (source IP -> client name) and the event log (one JSON object per line).
WATCHLIST_FILE = _DAEMON_DIR / "watchlist.json"
EVENTS_LOG = _DAEMON_DIR / "events.log"

# Runtime settings, each overridable through the environment.
WG_INTERFACE = os.getenv("WG_INTERFACE", "eth0")  # interface sniff() captures on
WG_PORT = int(os.getenv("WG_PORT", "51820"))  # UDP port used by the capture filter and dport check
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
|
|
|
|
# ============================================
|
|
# Logging
|
|
# ============================================
|
|
|
|
# Resolve LOG_LEVEL case-insensitively ("debug" works like "DEBUG"). An unknown
# name falls back to INFO instead of raising AttributeError at import time,
# which would otherwise stop the daemon before it starts.
_level_name = LOG_LEVEL.strip().upper()
_level = getattr(logging, _level_name, None)
# getattr can also hit non-level module attributes (e.g. BASIC_FORMAT), so
# only accept integer level constants.
_level_valid = isinstance(_level, int)
if not _level_valid:
    _level = logging.INFO

logging.basicConfig(
    level=_level,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)

log = logging.getLogger("wgctl-monitor")

if not _level_valid:
    log.warning(f"Unknown LOG_LEVEL {LOG_LEVEL!r}; falling back to INFO")
|
|
|
|
# ============================================
|
|
# Watchlist
|
|
# ============================================
|
|
|
|
# Last successfully loaded watchlist (source IP -> client name) and the mtime
# of WATCHLIST_FILE at that load; used to skip re-parsing an unchanged file.
_watchlist: dict[str, str] = {}
_watchlist_mtime: float = 0.0
# Message of the most recent load failure. Each distinct failure is logged
# once instead of once per packet that triggers a lookup.
_watchlist_error: str | None = None


def load_watchlist() -> dict[str, str]:
    """Return the watchlist mapping source IP -> client name, reloading on change.

    WATCHLIST_FILE is re-read only when its mtime differs from the last
    successful load. If stat/read/parse fails, or the file does not contain a
    JSON object, the previously loaded watchlist is returned unchanged (empty
    before the first success) and the error is logged once per distinct
    message.

    Returns:
        The current watchlist dict (the module-level cache, not a copy).
    """
    global _watchlist, _watchlist_mtime, _watchlist_error

    try:
        mtime = WATCHLIST_FILE.stat().st_mtime
        if mtime == _watchlist_mtime:
            return _watchlist

        with WATCHLIST_FILE.open() as f:
            data = json.load(f)

        # Callers do dict lookups (.get) on the result, so anything other
        # than a JSON object would crash every later lookup.
        if not isinstance(data, dict):
            raise ValueError(
                f"expected a JSON object, got {type(data).__name__}"
            )

        _watchlist = data
        # Only record the mtime on success so a failed read is retried on the
        # next call (e.g. if we caught the file mid-write).
        _watchlist_mtime = mtime
        _watchlist_error = None
        log.debug(f"Watchlist reloaded: {len(_watchlist)} entries")

    except Exception as e:
        message = str(e)
        if message != _watchlist_error:
            log.error(f"Failed to load watchlist: {e}")
            _watchlist_error = message

    return _watchlist
|
|
|
|
def is_watched(ip: str) -> str | None:
    """Look up *ip* in the watchlist.

    Returns:
        The client name the watchlist maps *ip* to, or None when *ip* is not
        listed. The watchlist is fetched through load_watchlist() on every
        call.
    """
    current = load_watchlist()
    client_name = current.get(ip)
    return client_name
|
|
|
|
# ============================================
|
|
# Endpoint Resolution
|
|
# ============================================
|
|
|
|
def get_endpoint(public_key: str) -> str | None:
    """Return the current endpoint address (without port) of a WireGuard peer.

    Runs ``wg show <device> endpoints`` and scans its ``<public-key> <endpoint>``
    lines for *public_key*.

    The device defaults to WG_INTERFACE, which is also the interface sniff()
    captures on (default "eth0"). ``wg show`` on a non-WireGuard interface
    exits with an error, so set the WG_DEVICE environment variable (e.g.
    "wg0") when the capture interface and the WireGuard device differ.

    Args:
        public_key: The peer's base64 public key.

    Returns:
        The endpoint host with the port removed; IPv6 brackets are stripped.
        Returns None if the peer is not listed, has no endpoint yet (``wg``
        prints ``(none)``), or ``wg`` fails or times out.
    """
    import subprocess

    device = os.environ.get("WG_DEVICE", WG_INTERFACE)

    try:
        # Bounded so a stuck `wg` cannot stall the packet callback indefinitely.
        result = subprocess.run(
            ["wg", "show", device, "endpoints"],
            capture_output=True,
            text=True,
            timeout=5,
            check=False,
        )
    except Exception as e:
        log.debug(f"Failed to get endpoint: {e}")
        return None

    if result.returncode != 0:
        log.debug(
            f"Failed to get endpoint: wg exited with {result.returncode}: "
            f"{result.stderr.strip()}"
        )
        return None

    for line in result.stdout.splitlines():
        parts = line.split()
        if len(parts) != 2 or parts[0] != public_key:
            continue

        endpoint = parts[1]
        # Peers that have never connected are listed with "(none)". That is
        # not an address, and returning it would defeat the caller's fallback.
        if endpoint == "(none)":
            return None

        # Drop the port, then unwrap "[v6addr]" so IPv6 results match the
        # bare-address form of a packet source IP.
        host = endpoint.rsplit(":", 1)[0]
        return host.removeprefix("[").removesuffix("]")

    return None
|
|
|
|
def get_client_public_key(client_name: str) -> str | None:
|
|
key_file = Path(f"/etc/wireguard/clients/{client_name}_public.key")
|
|
try:
|
|
return key_file.read_text().strip()
|
|
except Exception:
|
|
return None
|
|
|
|
# ============================================
|
|
# Event Logging
|
|
# ============================================
|
|
|
|
def _update_endpoint_cache(client: str, ip: str) -> None:
    """Store *ip* under *client* in endpoint_cache.json next to WATCHLIST_FILE.

    This is best-effort: every error is logged and never raised, so a cache
    problem cannot abort packet handling. The file is written to a temporary
    file and swapped in with os.replace(). An interrupted write (e.g. SIGTERM
    mid-dump) therefore cannot leave a truncated cache behind.
    """
    import tempfile

    cache_file = WATCHLIST_FILE.parent / "endpoint_cache.json"

    try:
        with cache_file.open() as f:
            cache = json.load(f)
    except FileNotFoundError:
        cache = {}
    except (OSError, ValueError) as e:
        log.warning(f"Endpoint cache unreadable, rebuilding it: {e}")
        cache = {}
    if not isinstance(cache, dict):
        log.warning("Endpoint cache is not a JSON object, rebuilding it")
        cache = {}

    cache[client] = ip

    # Keep the existing file's permission bits. A new file gets 0o644, which
    # presumably matches what open(..., "w") produced under a 022 umask.
    try:
        mode = cache_file.stat().st_mode & 0o777
    except OSError:
        mode = 0o644

    tmp_name = None
    try:
        with tempfile.NamedTemporaryFile(
            "w",
            dir=cache_file.parent,
            prefix=".endpoint_cache.",
            suffix=".tmp",
            delete=False,
        ) as tmp:
            tmp_name = tmp.name
            json.dump(cache, tmp, indent=2)
        os.chmod(tmp_name, mode)
        os.replace(tmp_name, cache_file)
    except (OSError, TypeError, ValueError) as e:
        log.error(f"Failed to update endpoint cache: {e}")
        if tmp_name is not None:
            try:
                os.unlink(tmp_name)
            except OSError:
                pass


def log_event(ip: str, client: str, event: str, endpoint: str | None = None):
    """Record an event for a watched client.

    Updates the endpoint cache with *ip*, then appends one JSON object per
    line to EVENTS_LOG. Each object has a UTC ISO-8601 "timestamp", "ip",
    "client", "event", and "endpoint" when *endpoint* is truthy. Failures are
    logged, not raised.

    Args:
        ip: Source IP the event was observed from.
        client: Client name from the watchlist.
        event: Event label (handle_packet passes "attempt").
        endpoint: Optional resolved endpoint address.
    """
    entry = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "ip": ip,
        "client": client,
        "event": event,
    }

    # NOTE(review): the cache stores the packet source *ip*, not *endpoint*,
    # despite the file name. Kept as-is; confirm what readers of the cache expect.
    _update_endpoint_cache(client, ip)

    if endpoint:
        entry["endpoint"] = endpoint

    try:
        with EVENTS_LOG.open("a") as f:
            f.write(json.dumps(entry) + "\n")
        log.debug(f"Event logged: {entry}")
    except Exception as e:
        log.error(f"Failed to write event: {e}")
|
|
|
|
# ============================================
|
|
# Packet Handler
|
|
# ============================================
|
|
|
|
def handle_packet(pkt):
    """Scapy ``prn`` callback: record inbound WireGuard-port packets from watched IPs.

    The packet is ignored if it lacks IP/UDP layers, is not addressed to
    WG_PORT, or has a source IP that is not in the watchlist. For a watched
    source, the peer endpoint is resolved via the client's public key,
    falling back to the packet's source IP. An "attempt" event is then
    recorded and logged.

    Any exception is logged and swallowed. Otherwise it would propagate from
    the callback into scapy's capture loop and could end monitoring.
    """
    try:
        if not (IP in pkt and UDP in pkt):
            return

        # The capture filter ("udp port N") matches both directions; only
        # packets *to* the WireGuard port count.
        if pkt[UDP].dport != WG_PORT:
            return

        src_ip = pkt[IP].src
        client = is_watched(src_ip)
        if not client:
            return

        # Resolve the peer's endpoint; fall back to the packet source IP when
        # there is no key file or no endpoint reported by wg.
        public_key = get_client_public_key(client)
        endpoint = get_endpoint(public_key) if public_key else None
        endpoint = endpoint or src_ip

        log_event(src_ip, client, "attempt", endpoint)
        # Lazy %-formatting: this runs per matching packet.
        log.info(
            "Blocked attempt: %s (%s) from endpoint %s", client, src_ip, endpoint
        )
    except Exception:
        log.exception("Error while handling packet")
|
|
|
|
# ============================================
|
|
# Signal Handling
|
|
# ============================================
|
|
|
|
def handle_signal(signum, frame):
    """Signal handler: log shutdown and exit with status 0.

    Raising SystemExit unwinds out of whatever the main thread is running,
    including the blocking sniff() call.
    """
    del signum, frame  # required by the signal.signal() handler signature; unused
    log.info("Shutting down wgctl-monitor")
    raise SystemExit(0)
|
|
|
|
# Install the same shutdown handler for SIGTERM and SIGINT (in that order).
for _sig in (signal.SIGTERM, signal.SIGINT):
    signal.signal(_sig, handle_signal)
del _sig
|
|
|
|
# ============================================
|
|
# Main
|
|
# ============================================
|
|
|
|
def main():
    """Start the monitor.

    Exits with status 1 if WATCHLIST_FILE does not exist. Otherwise the
    watchlist is loaded and capture runs on WG_INTERFACE for UDP traffic on
    WG_PORT, passing each packet to handle_packet without storing it.
    """
    log.info(f"wgctl-monitor starting on interface {WG_INTERFACE} port {WG_PORT}")

    if WATCHLIST_FILE.exists():
        load_watchlist()
        log.info("Watchlist loaded, starting packet capture...")
    else:
        log.error(f"Watchlist not found: {WATCHLIST_FILE}")
        sys.exit(1)

    capture_filter = f"udp port {WG_PORT}"
    sniff(iface=WG_INTERFACE, filter=capture_filter, prn=handle_packet, store=0)
|
|
|
|
# Script entry point. main() returns None on normal completion, and
# sys.exit(None) exits with status 0, same as falling off the end.
if __name__ == "__main__":
    sys.exit(main())
|