- wgctl-monitor: update _hs_last_logged on ALL handshakes not just new sessions
- wgctl-monitor: fix endpoint_cache.json absolute path
- wgctl-monitor: move script to wgctl/daemon/ (correct location)
- watch: _poll_handshakes sorts by ts descending, endpoint cache fallback
- watch: empty endpoint uses - not em dash (alignment fix)
- logs: newline between fw and wg sections
- monitor::live extracted, cmd::logs::follow no longer calls cmd::watch::run
- ui.sh: UTF-8 extra byte constants
303 lines
8.9 KiB
Python
Executable file
303 lines
8.9 KiB
Python
Executable file
#!/usr/bin/env python3
|
|
|
|
import subprocess
|
|
import threading
|
|
import json
|
|
import logging
|
|
import os
|
|
import signal
|
|
import sys
|
|
import time
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
from scapy.all import IP, UDP, sniff
|
|
|
|
# ============================================
|
|
# Config
|
|
# ============================================
|
|
|
|
# All persistent daemon state lives under a single directory.
_DAEMON_DIR = Path("/etc/wireguard/.wgctl/daemon")

WATCHLIST_FILE = _DAEMON_DIR / "watchlist.json"
EVENTS_LOG = _DAEMON_DIR / "events.log"

# Interface that scapy sniffs on (passed to sniff(iface=...)).
WG_INTERFACE = os.getenv("WG_INTERFACE", "eth0")
# UDP port WireGuard listens on; only packets to this port are inspected.
WG_PORT = int(os.getenv("WG_PORT", "51820"))
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
# Minimum gap (seconds) between two observed handshakes of a peer for the later
# one to count as a new session; the handshake poller runs every half of this.
# NOTE: the environment variable name (…_TIME_SEC) differs from the constant's.
WG_HANDSHAKE_CHECK_SEC = int(os.getenv("WG_HANDSHAKE_CHECK_TIME_SEC", "300"))
# WireGuard device queried with `wg show` (not the capture interface above).
WG_WG_INTERFACE = os.getenv("WG_WG_INTERFACE", "wg0")

HS_CACHE_FILE = _DAEMON_DIR / "hs_cache.json"
ENDPOINT_CACHE_FILE = _DAEMON_DIR / "endpoint_cache.json"
|
|
|
|
# ============================================
|
|
# Logging
|
|
# ============================================
|
|
|
|
# Resolve LOG_LEVEL case-insensitively ("debug", "Info", ...). Unknown names,
# or logging attributes that are not levels (e.g. BASIC_FORMAT), fall back to
# INFO instead of failing at import time.
_log_level = getattr(logging, LOG_LEVEL.upper(), None)
_log_level_valid = isinstance(_log_level, int)

logging.basicConfig(
    level=_log_level if _log_level_valid else logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)

log = logging.getLogger("wgctl-monitor")

if not _log_level_valid:
    log.warning(f"Unknown LOG_LEVEL {LOG_LEVEL!r}, falling back to INFO")
|
|
|
|
# ============================================
|
|
# Watchlist
|
|
# ============================================
|
|
|
|
# Last successfully parsed watchlist: source IP -> client name.
_watchlist: dict[str, str] = {}

# mtime of WATCHLIST_FILE when _watchlist was parsed; 0.0 forces the first load.
_watchlist_mtime: float = 0.0


def load_watchlist() -> dict[str, str]:
    """Return the watchlist, re-parsing WATCHLIST_FILE only when its mtime changes.

    The file must contain a JSON object mapping source IP -> client name
    (is_watched() calls ``.get(ip)`` on the result). If the file is missing,
    unreadable, not valid JSON, or not a JSON object, the error is logged and
    the previously loaded watchlist is returned unchanged. The mtime is not
    recorded on failure, so the next call retries (e.g. after a partial write).
    """
    global _watchlist, _watchlist_mtime

    try:
        mtime = WATCHLIST_FILE.stat().st_mtime
        if mtime == _watchlist_mtime:
            return _watchlist

        with WATCHLIST_FILE.open() as f:
            data = json.load(f)

        # A non-object (e.g. a JSON list) would make is_watched()'s .get()
        # raise inside the packet callback; keep the previous watchlist instead.
        if not isinstance(data, dict):
            raise ValueError(f"expected a JSON object, got {type(data).__name__}")

        _watchlist = data
        _watchlist_mtime = mtime
        log.debug(f"Watchlist reloaded: {len(_watchlist)} entries")

    except (OSError, ValueError) as e:
        log.error(f"Failed to load watchlist: {e}")

    return _watchlist
|
|
|
|
def is_watched(ip: str) -> str | None:
    """Return the client name the watchlist assigns to *ip*, or None if unlisted.

    The watchlist is fetched through load_watchlist(), which re-reads the file
    only when it has changed.
    """
    return load_watchlist().get(ip)
|
|
|
|
# ============================================
|
|
# Endpoint Resolution
|
|
# ============================================
|
|
|
|
def get_endpoint(public_key: str) -> str | None:
|
|
try:
|
|
import subprocess
|
|
result = subprocess.run(
|
|
["wg", "show", WG_INTERFACE, "endpoints"],
|
|
capture_output=True, text=True
|
|
)
|
|
for line in result.stdout.splitlines():
|
|
parts = line.split()
|
|
if len(parts) == 2 and parts[0] == public_key:
|
|
# Return just the IP without port
|
|
return parts[1].rsplit(":", 1)[0]
|
|
except Exception as e:
|
|
log.debug(f"Failed to get endpoint: {e}")
|
|
return None
|
|
|
|
def get_client_public_key(client_name: str) -> str | None:
|
|
key_file = Path(f"/etc/wireguard/clients/{client_name}_public.key")
|
|
try:
|
|
return key_file.read_text().strip()
|
|
except Exception:
|
|
return None
|
|
|
|
# ============================================
|
|
# Event Logging
|
|
# ============================================
|
|
|
|
def log_event(ip: str, client: str, event: str, endpoint: str | None = None):
|
|
entry = {
|
|
"timestamp": datetime.now(timezone.utc).isoformat(),
|
|
"ip": ip,
|
|
"client": client,
|
|
"event": event,
|
|
}
|
|
|
|
# Update endpoint cache when we see a packet
|
|
|
|
cache_file = ENDPOINT_CACHE_FILE
|
|
try:
|
|
with open(cache_file) as f:
|
|
cache = json.load(f)
|
|
except:
|
|
cache = {}
|
|
cache[client] = ip
|
|
with open(cache_file, 'w') as f:
|
|
json.dump(cache, f, indent=2)
|
|
|
|
if endpoint:
|
|
entry["endpoint"] = endpoint
|
|
|
|
try:
|
|
with EVENTS_LOG.open("a") as f:
|
|
f.write(json.dumps(entry) + "\n")
|
|
log.debug(f"Event logged: {entry}")
|
|
except Exception as e:
|
|
log.error(f"Failed to write event: {e}")
|
|
|
|
# ============================================
|
|
# Handshake Poller
|
|
# ============================================
|
|
|
|
# Most recent handshake timestamp observed per peer public key. poll_handshakes
# updates it on every handshake and persists it via save_hs_cache().
_hs_last_logged: dict[str, int] = {}


def load_hs_cache():
    """Return the persisted {pubkey: handshake_ts} map stored in HS_CACHE_FILE.

    Values are coerced with int(). Any failure yields an empty dict: a missing
    or unreadable file, invalid JSON, a non-object, or an unconvertible value.
    """
    try:
        with HS_CACHE_FILE.open() as fh:
            stored = json.load(fh)
        restored = {}
        for pubkey, ts in stored.items():
            restored[pubkey] = int(ts)
        return restored
    except Exception:
        return {}
|
|
|
|
def save_hs_cache(cache):
    """Persist the {pubkey: handshake_ts} map to HS_CACHE_FILE.

    The JSON is written to a sibling temp file and renamed into place. A crash
    or SIGTERM mid-write therefore cannot leave a truncated cache. With a
    truncated cache, load_hs_cache() returns {} and every peer is reported as
    a new session on the next start.

    Failures are logged, not raised, so the poller keeps running without
    persistence.
    """
    tmp = HS_CACHE_FILE.with_name(f"{HS_CACHE_FILE.name}.{os.getpid()}.tmp")
    try:
        with tmp.open("w") as f:
            json.dump(cache, f)
        os.replace(tmp, HS_CACHE_FILE)
    except (OSError, TypeError, ValueError) as e:
        log.warning(f"Failed to save handshake cache: {e}")
        try:
            tmp.unlink()
        except OSError:
            pass  # temp file was never created or is already gone
|
|
|
|
def build_pubkey_to_name(clients_dir: Path = Path("/etc/wireguard/clients")) -> dict[str, str]:
    """Build a public key -> client name map from ``<name>_public.key`` files.

    Args:
        clients_dir: Directory holding the key files. The default is the same
            directory get_client_public_key() reads from.

    Returns:
        Mapping of stripped key text to client name. Empty or unreadable key
        files are skipped (unreadable ones with a warning). A missing directory
        yields an empty dict.
    """
    mapping: dict[str, str] = {}
    if not clients_dir.is_dir():
        return mapping

    for key_file in clients_dir.glob("*_public.key"):
        # Strip only the trailing "_public": replace() would also remove it
        # from the middle of a name such as "my_public_vpn".
        name = key_file.stem.removesuffix("_public")
        try:
            pubkey = key_file.read_text().strip()
        except (OSError, UnicodeDecodeError) as e:
            log.warning(f"Skipping unreadable key file {key_file}: {e}")
            continue
        if pubkey:
            mapping[pubkey] = name
    return mapping
|
|
|
|
|
|
def poll_handshakes():
    """
    Poll `wg show <WG_WG_INTERFACE> latest-handshakes` forever (meant for a daemon thread).

    Every observed handshake timestamp is stored in _hs_last_logged. The map is
    saved to HS_CACHE_FILE once per round, so a restart resumes from the last
    seen handshakes instead of re-reporting them.

    A "handshake" event is appended to EVENTS_LOG only when the gap since the
    previously observed handshake of that peer is >= WG_HANDSHAKE_CHECK_SEC
    (new session). Shorter gaps are treated as the same session.

    Polls every WG_HANDSHAKE_CHECK_SEC // 2 seconds (at least 1 second). Key
    files are re-scanned each round, so clients added after startup are
    picked up.
    """
    global _hs_last_logged

    _hs_last_logged = load_hs_cache()

    pubkey_to_name = build_pubkey_to_name()
    log.info(f"Handshake poller started — {len(pubkey_to_name)} peers, "
             f"session threshold {WG_HANDSHAKE_CHECK_SEC}s")

    # A threshold below 2s would give an interval of 0 and busy-loop on `wg show`.
    interval = max(1, WG_HANDSHAKE_CHECK_SEC // 2)

    while True:
        try:
            # Re-scan key files so clients created after startup are tracked too.
            pubkey_to_name = build_pubkey_to_name()

            for pubkey, ts in _read_latest_handshakes():
                client = pubkey_to_name.get(pubkey)
                if not client:
                    continue

                last = _hs_last_logged.get(pubkey, 0)
                gap = ts - last
                log.debug(f"Gap for {client}: {gap}s (threshold: {WG_HANDSHAKE_CHECK_SEC}s)")

                # Always update last seen, so the next gap is measured from here.
                _hs_last_logged[pubkey] = ts

                if gap < WG_HANDSHAKE_CHECK_SEC:
                    continue  # same session, skip

                _log_new_session(client, ts, get_endpoint(pubkey) or '')

            # Persist every round (not only on new sessions) so in-session
            # handshakes survive a restart as well.
            save_hs_cache(_hs_last_logged)

        except Exception as e:
            log.error(f"Handshake poll error: {e}")

        time.sleep(interval)


def _read_latest_handshakes() -> list[tuple[str, int]]:
    """Return (pubkey, unix_ts) pairs from `wg show`, skipping peers with no handshake (ts 0)."""
    result = subprocess.run(
        ["wg", "show", WG_WG_INTERFACE, "latest-handshakes"],
        capture_output=True,
        text=True,
        timeout=30,
    )
    if result.returncode != 0:
        log.warning(f"wg show {WG_WG_INTERFACE} latest-handshakes failed: "
                    f"{result.stderr.strip()}")
        return []

    pairs = []
    for line in result.stdout.splitlines():
        parts = line.split()
        if len(parts) != 2:
            continue
        pubkey, ts_str = parts
        try:
            ts = int(ts_str)
        except ValueError:
            continue
        if ts != 0:
            pairs.append((pubkey, ts))
    return pairs


def _log_new_session(client: str, ts: int, endpoint: str) -> None:
    """Append a "handshake" event, timestamped at the handshake itself, to EVENTS_LOG."""
    entry = {
        "timestamp": datetime.fromtimestamp(ts, tz=timezone.utc).isoformat(),
        "ip": "",
        "client": client,
        "event": "handshake",
        "endpoint": endpoint,
    }
    try:
        with EVENTS_LOG.open("a") as f:
            f.write(json.dumps(entry) + "\n")
        log.info(f"New session: {client} from {endpoint}")
    except OSError as e:
        log.error(f"Failed to write handshake event: {e}")
|
|
|
|
# ============================================
|
|
# Packet Handler
|
|
# ============================================
|
|
|
|
def handle_packet(pkt):
    """scapy ``prn`` callback: record an "attempt" event for packets from watched IPs.

    Only packets with IP and UDP layers whose UDP destination port is WG_PORT
    are considered. Any error while processing a packet is logged, not raised.
    An exception escaping a ``prn`` callback propagates out of scapy's sniff()
    and would stop packet capture. SystemExit from the signal handler is not
    an Exception subclass, so shutdown still works.
    """
    if not (IP in pkt and UDP in pkt):
        return

    # Only care about packets targeting WireGuard port
    if pkt[UDP].dport != WG_PORT:
        return

    src_ip = pkt[IP].src
    try:
        _record_attempt(src_ip)
    except Exception:
        log.exception(f"Failed to handle packet from {src_ip}")


def _record_attempt(src_ip: str) -> None:
    """If *src_ip* is on the watchlist, resolve the client's endpoint and log an attempt."""
    client = is_watched(src_ip)
    if not client:
        return

    # Resolve real endpoint IP via `wg show`, if the client's public key is known.
    public_key = get_client_public_key(client)
    endpoint = get_endpoint(public_key) if public_key else None

    # If no endpoint from wg show, use packet source IP
    endpoint = endpoint or src_ip

    log_event(src_ip, client, "attempt", endpoint)
    log.info(f"Blocked attempt: {client} ({src_ip}) from endpoint {endpoint}")
|
|
|
|
# ============================================
|
|
# Signal Handling
|
|
# ============================================
|
|
|
|
def handle_signal(signum, frame):
    """Signal handler: log the shutdown and exit the process with status 0."""
    log.info("Shutting down wgctl-monitor")
    raise SystemExit(0)


# Install the same handler for both termination signals.
for _signum in (signal.SIGTERM, signal.SIGINT):
    signal.signal(_signum, handle_signal)
del _signum
|
|
|
|
# ============================================
|
|
# Main
|
|
# ============================================
|
|
|
|
def main():
    """Start the monitor: check the watchlist, launch the handshake poller, then sniff."""
    log.info(f"wgctl-monitor starting on interface {WG_INTERFACE} port {WG_PORT}")

    # Refuse to start without a watchlist file; later reloads tolerate errors.
    if not WATCHLIST_FILE.exists():
        log.error(f"Watchlist not found: {WATCHLIST_FILE}")
        sys.exit(1)

    load_watchlist()
    log.info("Watchlist loaded, starting packet capture...")

    # Background poller; daemon=True so it never keeps the process alive on its own.
    poller = threading.Thread(target=poll_handshakes, daemon=True)
    poller.start()

    # Capture on the uplink interface; each matching packet goes to handle_packet.
    sniff(iface=WG_INTERFACE, filter=f"udp port {WG_PORT}", prn=handle_packet, store=0)


if __name__ == "__main__":
    main()
|