- daemon: update_peer_history() tracks all endpoints per peer - daemon: endpoint_index.json for O(1) IP -> peer name lookup - daemon: poll_handshakes updates history on every cycle - json_helper: peer_history_lookup() uses index, falls back to scan - resolve::endpoint_parts: step 3 checks peer history index - json.sh: json::peer_history_lookup wrapper - resolve: mobile peer IPs now resolve to peer name via history
385 lines
12 KiB
Python
Executable file
385 lines
12 KiB
Python
Executable file
#!/usr/bin/env python3
|
|
|
|
import subprocess
|
|
import threading
|
|
import json
|
|
import logging
|
|
import os
|
|
import signal
|
|
import sys
|
|
import time
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
from scapy.all import IP, UDP, sniff
|
|
|
|
# ============================================
# Config
# ============================================

# --- Runtime settings (overridable through the environment) ---

# Interface scapy sniffs on for inbound UDP traffic to the WireGuard port.
WG_INTERFACE = os.environ.get("WG_INTERFACE", "eth0")
# WireGuard device queried with `wg show` (not the capture interface).
WG_WG_INTERFACE = os.environ.get("WG_WG_INTERFACE", "wg0")
WG_PORT = int(os.environ.get("WG_PORT", "51820"))
LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO")
# Minimum gap (seconds) between consecutive handshakes for the later one to be
# logged as a new session; the poller also sleeps half of this between polls.
WG_HANDSHAKE_CHECK_SEC = int(
    os.environ.get("WG_HANDSHAKE_CHECK_TIME_SEC", "300")
)

# --- On-disk state ---

WATCHLIST_FILE = Path("/etc/wireguard/.wgctl/daemon") / "watchlist.json"
EVENTS_LOG = Path("/etc/wireguard/.wgctl/daemon") / "events.log"
HS_CACHE_FILE = Path("/etc/wireguard/.wgctl/daemon") / "hs_cache.json"
ENDPOINT_CACHE_FILE = Path("/etc/wireguard/.wgctl/daemon") / "endpoint_cache.json"

PEER_HISTORY_DIR = Path("/etc/wireguard/.wgctl/peer-history")
ENDPOINT_INDEX_FILE = PEER_HISTORY_DIR / "endpoint_index.json"
|
|
|
|
# ============================================
# Logging
# ============================================

# Resolve LOG_LEVEL case-insensitively ("debug" works as well as "DEBUG").
# An unknown name falls back to INFO instead of aborting start-up with the
# AttributeError that a bare getattr(logging, LOG_LEVEL) would raise.
_log_level = getattr(logging, LOG_LEVEL.upper(), None)
_log_level_valid = isinstance(_log_level, int)
if not _log_level_valid:
    _log_level = logging.INFO

logging.basicConfig(
    level=_log_level,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)

log = logging.getLogger("wgctl-monitor")

if not _log_level_valid:
    log.warning(f"Unknown LOG_LEVEL {LOG_LEVEL!r}; falling back to INFO")
|
|
|
|
# ============================================
# Watchlist
# ============================================

# Cached watchlist (source IP -> client name) and the file mtime it was read at.
_watchlist: dict[str, str] = {}
_watchlist_mtime: float = 0.0


def load_watchlist() -> dict[str, str]:
    """Return the watchlist, re-reading WATCHLIST_FILE only when its mtime changes.

    The parsed mapping is cached in module globals, so the per-packet caller
    (``is_watched``) normally pays only for a ``stat()``.

    On any failure (missing or unreadable file, invalid JSON, or JSON that is
    not an object) the error is logged and the previously cached watchlist is
    returned unchanged. The mtime is not recorded in that case, so the next
    call retries the read.
    """
    global _watchlist, _watchlist_mtime

    try:
        mtime = WATCHLIST_FILE.stat().st_mtime
        if mtime == _watchlist_mtime:
            return _watchlist

        with WATCHLIST_FILE.open() as f:
            data = json.load(f)

        # Callers use dict.get(); caching a list or scalar here would make the
        # scapy packet callback raise instead of reporting a bad watchlist.
        if not isinstance(data, dict):
            raise ValueError(
                f"expected a JSON object, got {type(data).__name__}"
            )

        _watchlist = data
        _watchlist_mtime = mtime
        log.debug(f"Watchlist reloaded: {len(_watchlist)} entries")

    except Exception as e:
        log.error(f"Failed to load watchlist: {e}")

    return _watchlist
|
|
|
|
def is_watched(ip: str) -> str | None:
    """Look up *ip* in the current watchlist.

    Returns the client name the IP maps to, or None when it is not watched.
    """
    return load_watchlist().get(ip)
|
|
|
|
# ============================================
# Endpoint Resolution
# ============================================


def get_endpoint(public_key: str) -> str | None:
    """Return the endpoint IP (port stripped) that WireGuard reports for a peer.

    Runs ``wg show <WG_WG_INTERFACE> endpoints`` and looks for *public_key*.

    Returns None when:
    - the peer is not listed;
    - the peer has no endpoint yet (``wg`` prints ``(none)``);
    - ``wg`` cannot be run.

    IPv6 endpoints (``[addr]:port``) are returned without their brackets.
    """
    try:
        # Query the WireGuard device, not the packet-capture NIC. WG_INTERFACE
        # is the interface scapy sniffs on (default eth0), so `wg show eth0`
        # fails and every lookup would come back empty.
        result = subprocess.run(
            ["wg", "show", WG_WG_INTERFACE, "endpoints"],
            capture_output=True, text=True,
        )
    except Exception as e:
        log.debug(f"Failed to get endpoint: {e}")
        return None

    for line in result.stdout.splitlines():
        parts = line.split()
        if len(parts) != 2 or parts[0] != public_key:
            continue

        raw = parts[1]
        if raw == "(none)":
            # Peer is configured but has never connected; "(none)" is not an
            # address and must not leak into caches or peer history.
            return None

        # Drop the port; for "[v6addr]:port" also drop the brackets.
        host = raw.rsplit(":", 1)[0].strip("[]")
        return host or None

    return None
|
|
|
|
def get_client_public_key(client_name: str) -> str | None:
|
|
key_file = Path(f"/etc/wireguard/clients/{client_name}_public.key")
|
|
try:
|
|
return key_file.read_text().strip()
|
|
except Exception:
|
|
return None
|
|
|
|
# ============================================
# Event Logging
# ============================================


def _update_endpoint_cache(client: str, ip: str) -> None:
    """Record *ip* as the latest packet source for *client* in ENDPOINT_CACHE_FILE.

    Best-effort: failures are logged, never raised, so a cache problem cannot
    abort the packet callback. The new JSON goes to a temp file that is then
    moved over the cache with os.replace(). The handshake-poller thread reads
    the same file, and this ensures it never sees a half-written document.
    """
    try:
        cache = json.loads(ENDPOINT_CACHE_FILE.read_text())
        if not isinstance(cache, dict):
            cache = {}
    except (OSError, ValueError):
        # Missing or corrupt cache: start fresh. Deliberately narrow; a bare
        # `except:` here would also swallow SystemExit raised by the SIGTERM
        # handler and keep the daemon running.
        cache = {}

    cache[client] = ip

    try:
        tmp = ENDPOINT_CACHE_FILE.with_name(ENDPOINT_CACHE_FILE.name + ".tmp")
        tmp.write_text(json.dumps(cache, indent=2))
        os.replace(tmp, ENDPOINT_CACHE_FILE)
    except OSError as e:
        log.error(f"Failed to update endpoint cache: {e}")


def log_event(ip: str, client: str, event: str, endpoint: str | None = None):
    """Append one JSON event line to EVENTS_LOG and refresh the endpoint cache.

    Args:
        ip: Packet source IP that triggered the event.
        client: Client name the IP maps to.
        event: Event type, e.g. ``"attempt"``.
        endpoint: Resolved endpoint; included in the record only when truthy.

    Write failures are logged, not raised.
    """
    entry = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "ip": ip,
        "client": client,
        "event": event,
    }
    if endpoint:
        entry["endpoint"] = endpoint

    # Remember this client's latest source IP; the handshake poller falls
    # back to this cache when `wg show` reports no endpoint.
    _update_endpoint_cache(client, ip)

    try:
        with EVENTS_LOG.open("a") as f:
            f.write(json.dumps(entry) + "\n")
        log.debug(f"Event logged: {entry}")
    except Exception as e:
        log.error(f"Failed to write event: {e}")
|
|
|
|
# ============================================
|
|
# Handshake Poller
|
|
# ============================================
|
|
|
|
# Tracks last logged handshake ts per pubkey
|
|
_hs_last_logged: dict[str, int] = {}
|
|
|
|
def load_hs_cache():
|
|
try:
|
|
with HS_CACHE_FILE.open() as f:
|
|
return {k: int(v) for k, v in json.load(f).items()}
|
|
except Exception:
|
|
return {}
|
|
|
|
def save_hs_cache(cache):
    """Persist the pubkey -> handshake-timestamp map to HS_CACHE_FILE.

    Best-effort: a failure is logged and never raised, so the poller keeps
    running.

    The data is written to a temp file and renamed over the target, so a
    crash mid-write cannot leave a truncated file. load_hs_cache() would
    discard a truncated file entirely, and every peer's next handshake would
    then be logged as a new session.
    """
    try:
        tmp = HS_CACHE_FILE.with_name(HS_CACHE_FILE.name + ".tmp")
        tmp.write_text(json.dumps(cache))
        os.replace(tmp, HS_CACHE_FILE)
    except Exception as e:
        log.warning(f"Failed to save handshake cache: {e}")
|
|
|
|
def build_pubkey_to_name() -> dict[str, str]:
|
|
"""Build pubkey -> client name map from public key files."""
|
|
mapping = {}
|
|
clients_dir = Path("/etc/wireguard/clients")
|
|
for kf in clients_dir.glob("*_public.key"):
|
|
name = kf.stem.replace("_public", "")
|
|
try:
|
|
mapping[kf.read_text().strip()] = name
|
|
except Exception:
|
|
pass
|
|
return mapping
|
|
|
|
|
|
def _read_latest_handshakes() -> list[tuple[str, int]]:
    """Return ``(pubkey, unix_ts)`` pairs from ``wg show <iface> latest-handshakes``.

    Malformed lines and peers with a zero timestamp are skipped. A non-zero
    ``wg`` exit status is logged; whatever was printed on stdout is still
    parsed.
    """
    result = subprocess.run(
        ["wg", "show", WG_WG_INTERFACE, "latest-handshakes"],
        capture_output=True, text=True,
    )
    if result.returncode != 0:
        log.warning(f"wg show {WG_WG_INTERFACE} latest-handshakes failed: "
                    f"{result.stderr.strip()}")

    pairs: list[tuple[str, int]] = []
    for line in result.stdout.splitlines():
        parts = line.split()
        if len(parts) != 2:
            continue
        pubkey, ts_str = parts
        try:
            ts = int(ts_str)
        except ValueError:
            continue
        if ts != 0:
            pairs.append((pubkey, ts))
    return pairs


def _load_endpoint_cache() -> dict:
    """Return the client -> last-source-IP map from ENDPOINT_CACHE_FILE, or {} if unusable."""
    try:
        cache = json.loads(ENDPOINT_CACHE_FILE.read_text())
    except (OSError, ValueError):
        return {}
    return cache if isinstance(cache, dict) else {}


def _write_session_event(client: str, endpoint: str, ts: int) -> None:
    """Append a ``handshake`` event for a new session to EVENTS_LOG (errors are logged)."""
    entry = {
        "timestamp": datetime.fromtimestamp(ts, tz=timezone.utc).isoformat(),
        "ip": "",
        "client": client,
        "event": "handshake",
        "endpoint": endpoint,
    }
    try:
        with EVENTS_LOG.open("a") as f:
            f.write(json.dumps(entry) + "\n")
        log.info(f"New session: {client} from {endpoint}")
    except Exception as e:
        log.error(f"Failed to write handshake event: {e}")


def poll_handshakes():
    """Poll ``wg show latest-handshakes`` forever, tracking sessions and endpoints.

    Each cycle, for every known peer with a non-zero handshake timestamp:
      * the peer's endpoint is resolved, either from ``wg show`` or, failing
        that, from the endpoint cache written by the packet handler. It is
        then fed to update_peer_history().
      * a ``handshake`` event is appended to EVENTS_LOG only when the gap
        since the previously seen handshake is >= WG_HANDSHAKE_CHECK_SEC,
        i.e. a new session; smaller gaps are treated as the same session.

    The pubkey -> name map is rebuilt every cycle, so peers added while the
    daemon is running are picked up without a restart. Errors in one cycle
    are logged and the loop continues. Meant to run in a background thread.
    """
    global _hs_last_logged, _endpoint_index

    _hs_last_logged = load_hs_cache()
    _endpoint_index = load_endpoint_index()

    pubkey_to_name = build_pubkey_to_name()
    log.info(f"Handshake poller started — {len(pubkey_to_name)} peers, "
             f"session threshold {WG_HANDSHAKE_CHECK_SEC}s")
    log.info(f"Endpoint index loaded — {len(_endpoint_index)} known endpoints")

    # Poll at half the session threshold; never 0s, which would busy-loop
    # when the threshold is configured below 2 seconds.
    interval = max(1, WG_HANDSHAKE_CHECK_SEC // 2)

    while True:
        try:
            pubkey_to_name = build_pubkey_to_name()
            endpoint_cache: dict | None = None  # read lazily, at most once per cycle

            for pubkey, ts in _read_latest_handshakes():
                client = pubkey_to_name.get(pubkey)
                if not client:
                    continue

                gap = ts - _hs_last_logged.get(pubkey, 0)
                # Always remember the latest handshake, session or not.
                _hs_last_logged[pubkey] = ts

                endpoint = get_endpoint(pubkey) or ''
                if not endpoint:
                    if endpoint_cache is None:
                        endpoint_cache = _load_endpoint_cache()
                    endpoint = endpoint_cache.get(client, '')

                # Peer history and the endpoint index are refreshed every cycle.
                if endpoint:
                    update_peer_history(client, endpoint, ts)

                log.debug(f"Gap for {client}: {gap}s (threshold: {WG_HANDSHAKE_CHECK_SEC}s)")
                if gap < WG_HANDSHAKE_CHECK_SEC:
                    continue  # same session (keepalive), skip

                _write_session_event(client, endpoint, ts)

            save_hs_cache(_hs_last_logged)

        except Exception as e:
            log.error(f"Handshake poll error: {e}")

        time.sleep(interval)
|
|
|
|
|
|
# ============================================
# Peer History
# ============================================


def load_endpoint_index() -> dict:
    """Load the endpoint -> peer-name index from ENDPOINT_INDEX_FILE.

    Returns {} when the file does not exist. The contents are logged as a
    warning and {} is also returned when the file:
    - cannot be read,
    - is not valid JSON, or
    - does not hold a JSON object (callers assign items into the result).
    """
    try:
        data = json.loads(ENDPOINT_INDEX_FILE.read_text())
    except FileNotFoundError:
        return {}
    except (OSError, ValueError) as e:
        log.warning(f"Ignoring unreadable endpoint index {ENDPOINT_INDEX_FILE}: {e}")
        return {}

    if not isinstance(data, dict):
        log.warning(f"Ignoring endpoint index {ENDPOINT_INDEX_FILE}: not a JSON object")
        return {}
    return data
|
|
|
|
def save_endpoint_index(index: dict):
    """Write the endpoint -> peer-name index to ENDPOINT_INDEX_FILE.

    The JSON is written to a sibling temp file and moved into place with
    os.replace(). Anything reading the index concurrently therefore sees
    either the old or the new file, never a partial one. Failures are
    logged, not raised.
    """
    try:
        PEER_HISTORY_DIR.mkdir(parents=True, exist_ok=True)
        tmp = ENDPOINT_INDEX_FILE.with_name(ENDPOINT_INDEX_FILE.name + ".tmp")
        tmp.write_text(json.dumps(index, indent=2))
        os.replace(tmp, ENDPOINT_INDEX_FILE)
    except Exception as e:
        log.error(f"Failed to save endpoint index: {e}")
|
|
|
|
# In-memory endpoint -> peer-name index. poll_handshakes() replaces it with
# the on-disk copy at startup; update_peer_history() adds entries and
# persists them through save_endpoint_index().
_endpoint_index: dict = {}


def update_peer_history(client: str, endpoint: str, ts: int):
    """Record that *client* was seen at *endpoint* with handshake time *ts*.

    Maintains ``PEER_HISTORY_DIR/<client>.json`` in the form::

        {"peer": <client>,
         "endpoints": {<endpoint>: {"first_seen": iso, "last_seen": iso, "count": n}}}

    ``count`` is the number of distinct handshake timestamps observed from
    that endpoint. The poller calls this every cycle, even when no new
    handshake happened; calls repeating the stored ``last_seen`` therefore
    change nothing and do not rewrite the file.

    The endpoint is added to the endpoint index when it is new for this
    client, or when the index has no entry for it at all (e.g. after the
    index file was lost). Errors are logged, never raised.
    """
    global _endpoint_index
    if not endpoint:
        return

    try:
        PEER_HISTORY_DIR.mkdir(parents=True, exist_ok=True)
        history_file = PEER_HISTORY_DIR / f"{client}.json"

        try:
            data = json.loads(history_file.read_text())
        except FileNotFoundError:
            data = None
        except (OSError, ValueError) as e:
            log.warning(f"Resetting unreadable peer history {history_file}: {e}")
            data = None
        if not isinstance(data, dict):
            data = {"peer": client, "endpoints": {}}

        eps = data.get("endpoints")
        if not isinstance(eps, dict):
            eps = data["endpoints"] = {}

        ts_iso = datetime.fromtimestamp(ts, tz=timezone.utc).isoformat()
        record = eps.get(endpoint)
        is_new = not isinstance(record, dict)
        changed = is_new

        if is_new:
            eps[endpoint] = {
                "first_seen": ts_iso,
                "last_seen": ts_iso,
                "count": 1,
            }
            log.debug(f"New endpoint for {client}: {endpoint}")
        elif record.get("last_seen") != ts_iso:
            # A handshake we have not counted yet from a known endpoint.
            record["last_seen"] = ts_iso
            record["count"] = record.get("count", 0) + 1
            changed = True

        # Index the endpoint when it is new for this peer, or when the index
        # lost track of it, so lookups keep resolving it to a peer name.
        if is_new or endpoint not in _endpoint_index:
            _endpoint_index[endpoint] = client
            save_endpoint_index(_endpoint_index)

        if changed:
            # Write via temp file + rename so readers never see partial JSON.
            tmp = history_file.with_name(history_file.name + ".tmp")
            tmp.write_text(json.dumps(data, indent=2))
            os.replace(tmp, history_file)
    except Exception as e:
        log.error(f"Failed to update peer history for {client}: {e}")
|
|
|
|
|
|
# ============================================
# Packet Handler
# ============================================


def handle_packet(pkt):
    """Scapy callback: log UDP packets to WG_PORT whose source IP is watched.

    For a watched source IP, the client's endpoint is looked up through
    ``wg show`` using the client's public key. When that yields nothing, the
    packet's own source IP is used as the endpoint. An ``attempt`` event is
    then recorded.
    """
    has_ip_udp = IP in pkt and UDP in pkt
    if not has_ip_udp or pkt[UDP].dport != WG_PORT:
        return

    source = pkt[IP].src
    client = is_watched(source)
    if not client:
        return

    pubkey = get_client_public_key(client)
    endpoint = (get_endpoint(pubkey) if pubkey else None) or source

    log_event(source, client, "attempt", endpoint)
    log.info(f"Blocked attempt: {client} ({source}) from endpoint {endpoint}")
|
|
|
|
# ============================================
# Signal Handling
# ============================================


def handle_signal(signum, frame):
    """Log shutdown and exit with status 0; installed for SIGTERM and SIGINT below."""
    log.info("Shutting down wgctl-monitor")
    sys.exit(0)


for _sig in (signal.SIGTERM, signal.SIGINT):
    signal.signal(_sig, handle_signal)
del _sig
|
|
|
|
# ============================================
# Main
# ============================================


def main():
    """Start the monitor: check the watchlist, start the handshake poller, sniff.

    Exits with status 1 when WATCHLIST_FILE does not exist. Otherwise it:
    - runs poll_handshakes() in a daemon thread;
    - blocks in scapy's sniff() on WG_INTERFACE with the BPF filter
      ``udp port WG_PORT``;
    - hands each captured packet to handle_packet().
    """
    log.info(f"wgctl-monitor starting on interface {WG_INTERFACE} port {WG_PORT}")

    if not WATCHLIST_FILE.exists():
        log.error(f"Watchlist not found: {WATCHLIST_FILE}")
        sys.exit(1)

    load_watchlist()
    log.info("Watchlist loaded, starting packet capture...")

    # daemon=True: the poller thread does not keep the process alive on exit.
    threading.Thread(target=poll_handshakes, daemon=True).start()

    sniff(
        store=0,
        prn=handle_packet,
        filter=f"udp port {WG_PORT}",
        iface=WG_INTERFACE,
    )


if __name__ == "__main__":
    main()
|