feat(wg-mtu-auto): add --prefer-wg-egress, --auto-pmtu-from-wg, and --set-wg-mtu; refine egress detection & PMTU logic
Refactor helpers; allow preferring wg* as egress when default route uses WireGuard; auto-discover peer endpoints from `wg show`/showconf as PMTU targets; add explicit `--set-wg-mtu` override with clamping; improve default-route parsing and dedup of targets. Update unit tests to cover prefer-wg egress selection, auto-pmtu-from-wg, median/min policies, all-fail fallback, and explicit override behavior. Conversation context: https://chatgpt.com/share/68efc179-1a10-800f-9656-1e8731b40546
This commit is contained in:
Binary file not shown.
Binary file not shown.
225
main.py
225
main.py
@@ -4,72 +4,119 @@ wg_mtu_auto.py — Auto-detect egress IF, optionally probe Path MTU to one or mo
|
|||||||
compute the correct WireGuard MTU, and apply it.
|
compute the correct WireGuard MTU, and apply it.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
sudo ./wg_mtu_auto.py
|
sudo ./main.py
|
||||||
sudo ./wg_mtu_auto.py --force-egress-mtu 1452
|
sudo ./main.py --force-egress-mtu 1452
|
||||||
sudo ./wg_mtu_auto.py --pmtu-target 46.4.224.77 --pmtu-target 2a01:4f8:2201:4695::2
|
sudo ./main.py --pmtu-target 46.4.224.77 --pmtu-target 2a01:4f8:2201:4695::2
|
||||||
sudo ./wg_mtu_auto.py --pmtu-target 46.4.224.77,2a01:4f8:2201:4695::2 --pmtu-policy min
|
sudo ./main.py --pmtu-target 46.4.224.77,2a01:4f8:2201:4695::2 --pmtu-policy min
|
||||||
./wg_mtu_auto.py --dry-run
|
sudo ./main.py --prefer-wg-egress --auto-pmtu-from-wg
|
||||||
|
./main.py --dry-run
|
||||||
"""
|
"""
|
||||||
import argparse, os, re, subprocess, sys, pathlib, ipaddress, statistics
|
import argparse
|
||||||
|
import ipaddress
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import re
|
||||||
|
import statistics
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------- helpers -----------------
|
||||||
|
|
||||||
def run(cmd): # -> str
|
def run(cmd): # -> str
|
||||||
return subprocess.run(cmd, check=False, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True).stdout.strip()
|
return subprocess.run(
|
||||||
|
cmd, check=False, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True
|
||||||
|
).stdout.strip()
|
||||||
|
|
||||||
|
|
||||||
def rc(cmd): # -> int
|
def rc(cmd): # -> int
|
||||||
return subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode
|
return subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode
|
||||||
|
|
||||||
|
|
||||||
def exists_iface(iface): # -> bool
|
def exists_iface(iface): # -> bool
|
||||||
return pathlib.Path(f"/sys/class/net/{iface}").exists()
|
return pathlib.Path(f"/sys/class/net/{iface}").exists()
|
||||||
|
|
||||||
def get_default_ifaces(): # -> list[str]
|
|
||||||
devs = []
|
|
||||||
for cmd in (["ip","-4","route","show","default"], ["ip","-6","route","show","default"]):
|
|
||||||
out = run(cmd)
|
|
||||||
for line in out.splitlines():
|
|
||||||
m = re.search(r"\bdev\s+(\S+)", line)
|
|
||||||
if m: devs.append(m.group(1))
|
|
||||||
if not devs:
|
|
||||||
for cmd in (["ip","route","get","1.1.1.1"], ["ip","-6","route","get","2606:4700:4700::1111"]):
|
|
||||||
out = run(cmd)
|
|
||||||
m = re.search(r"\bdev\s+(\S+)", out)
|
|
||||||
if m: devs.append(m.group(1))
|
|
||||||
uniq = []
|
|
||||||
for d in devs:
|
|
||||||
if not d or d == "lo" or re.match(r"^(wg|tun)\d*$", d) or not exists_iface(d): continue
|
|
||||||
if d not in uniq: uniq.append(d)
|
|
||||||
return uniq
|
|
||||||
|
|
||||||
def read_mtu(iface): # -> int
|
def read_mtu(iface): # -> int
|
||||||
with open(f"/sys/class/net/{iface}/mtu","r") as f:
|
with open(f"/sys/class/net/{iface}/mtu", "r") as f:
|
||||||
return int(f.read().strip())
|
return int(f.read().strip())
|
||||||
|
|
||||||
|
|
||||||
def set_mtu(iface, mtu, dry):
|
def set_mtu(iface, mtu, dry):
|
||||||
if dry:
|
if dry:
|
||||||
print(f"[wg-mtu] DRY-RUN: ip link set mtu {mtu} dev {iface}")
|
print(f"[wg-mtu] DRY-RUN: ip link set mtu {mtu} dev {iface}")
|
||||||
else:
|
else:
|
||||||
subprocess.run(["ip","link","set","mtu",str(mtu),"dev",iface], check=True)
|
subprocess.run(["ip", "link", "set", "mtu", str(mtu), "dev", iface], check=True)
|
||||||
|
|
||||||
|
|
||||||
def require_root(dry):
|
def require_root(dry):
|
||||||
if not dry and os.geteuid() != 0:
|
if not dry and os.geteuid() != 0:
|
||||||
print("[wg-mtu][ERROR] Please run as root (sudo) or use --dry-run.", file=sys.stderr)
|
print("[wg-mtu][ERROR] Please run as root (sudo) or use --dry-run.", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def is_ipv6(addr): # -> bool
|
def is_ipv6(addr): # -> bool
|
||||||
try:
|
try:
|
||||||
return isinstance(ipaddress.ip_address(addr), ipaddress.IPv6Address)
|
return isinstance(ipaddress.ip_address(addr), ipaddress.IPv6Address)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
# best-effort for hostnames (contains ':')
|
||||||
return ":" in addr
|
return ":" in addr
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------- route & iface selection -----------------
|
||||||
|
|
||||||
|
def default_route_lines():
|
||||||
|
lines = []
|
||||||
|
for cmd in (["ip", "-4", "route", "show", "default"], ["ip", "-6", "route", "show", "default"]):
|
||||||
|
out = run(cmd)
|
||||||
|
if out:
|
||||||
|
lines.extend(out.splitlines())
|
||||||
|
return lines
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_ifaces(ignore_vpn=True): # -> list[str]
|
||||||
|
devs = []
|
||||||
|
for line in default_route_lines():
|
||||||
|
m = re.search(r"\bdev\s+(\S+)", line)
|
||||||
|
if m:
|
||||||
|
devs.append(m.group(1))
|
||||||
|
# fallback via route get
|
||||||
|
if not devs:
|
||||||
|
for cmd in (["ip", "route", "get", "1.1.1.1"], ["ip", "-6", "route", "get", "2606:4700:4700::1111"]):
|
||||||
|
out = run(cmd)
|
||||||
|
m = re.search(r"\bdev\s+(\S+)", out)
|
||||||
|
if m:
|
||||||
|
devs.append(m.group(1))
|
||||||
|
uniq = []
|
||||||
|
for d in devs:
|
||||||
|
if not d or d == "lo" or not exists_iface(d):
|
||||||
|
continue
|
||||||
|
if ignore_vpn and re.match(r"^(wg|tun)\d*$", d):
|
||||||
|
continue
|
||||||
|
if d not in uniq:
|
||||||
|
uniq.append(d)
|
||||||
|
return uniq
|
||||||
|
|
||||||
|
|
||||||
|
def wg_default_is_active(wg_if: str) -> bool:
|
||||||
|
# check if any default route is via wg_if
|
||||||
|
return any(re.search(rf"\bdev\s+{re.escape(wg_if)}\b", line) for line in default_route_lines())
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------- PMTU probing -----------------
|
||||||
|
|
||||||
def ping_ok(payload, target, timeout_s): # -> bool
|
def ping_ok(payload, target, timeout_s): # -> bool
|
||||||
base = ["ping","-M","do","-c","1","-s",str(payload),"-W",str(max(1, int(round(timeout_s))))]
|
base = ["ping", "-M", "do", "-c", "1", "-s", str(payload), "-W", str(max(1, int(round(timeout_s))))]
|
||||||
if is_ipv6(target):
|
if is_ipv6(target):
|
||||||
base.insert(1, "-6")
|
base.insert(1, "-6")
|
||||||
return rc(base + [target]) == 0
|
return rc(base + [target]) == 0
|
||||||
|
|
||||||
|
|
||||||
def probe_pmtu(target, lo_payload=1200, hi_payload=1472, timeout=1.0): # -> int|None
|
def probe_pmtu(target, lo_payload=1200, hi_payload=1472, timeout=1.0): # -> int|None
|
||||||
"""Binary-search the largest payload that passes with DF. Return Path-MTU (payload + hdr) or None."""
|
"""Binary-search the largest payload that passes with DF. Return Path MTU (payload + hdr) or None.
|
||||||
|
Header: +28 (IPv4), +48 (IPv6)."""
|
||||||
hdr = 48 if is_ipv6(target) else 28
|
hdr = 48 if is_ipv6(target) else 28
|
||||||
# ensure the lower bound works; if not, try slightly smaller floors
|
# ensure lower bound works; if not, try smaller floors
|
||||||
if not ping_ok(lo_payload, target, timeout):
|
if not ping_ok(lo_payload, target, timeout):
|
||||||
for p in (1180, 1160, 1140):
|
for p in (1180, 1160, 1140):
|
||||||
if ping_ok(p, target, timeout):
|
if ping_ok(p, target, timeout):
|
||||||
@@ -87,8 +134,8 @@ def probe_pmtu(target, lo_payload=1200, hi_payload=1472, timeout=1.0): # -> int
|
|||||||
hi = mid - 1
|
hi = mid - 1
|
||||||
return (best + hdr) if best is not None else None
|
return (best + hdr) if best is not None else None
|
||||||
|
|
||||||
|
|
||||||
def choose_effective(pmtus, policy="min"): # -> int
|
def choose_effective(pmtus, policy="min"): # -> int
|
||||||
"""Pick an effective PMTU from a list of successful PMTUs."""
|
|
||||||
if not pmtus:
|
if not pmtus:
|
||||||
raise ValueError("no PMTUs to choose from")
|
raise ValueError("no PMTUs to choose from")
|
||||||
if policy == "min":
|
if policy == "min":
|
||||||
@@ -99,35 +146,103 @@ def choose_effective(pmtus, policy="min"): # -> int
|
|||||||
return int(statistics.median(sorted(pmtus)))
|
return int(statistics.median(sorted(pmtus)))
|
||||||
raise ValueError(f"unknown policy {policy}")
|
raise ValueError(f"unknown policy {policy}")
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------- WireGuard helpers (opt-in) -----------------
|
||||||
|
|
||||||
|
def wg_is_active(wg_if: str) -> bool:
|
||||||
|
if not exists_iface(wg_if):
|
||||||
|
return False
|
||||||
|
return rc(["wg", "show", wg_if]) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def wg_peer_endpoints(wg_if: str) -> list[str]:
|
||||||
|
"""Return list of peer endpoints (hostnames/IPs) – port stripped."""
|
||||||
|
targets = []
|
||||||
|
|
||||||
|
# 1) Try: wg show <if> endpoints
|
||||||
|
out = run(["wg", "show", wg_if, "endpoints"])
|
||||||
|
for line in out.splitlines():
|
||||||
|
# format: <peer_public_key>\t<endpoint or (none)>
|
||||||
|
parts = line.strip().split()
|
||||||
|
if len(parts) >= 2 and parts[-1] != "(none)":
|
||||||
|
ep = parts[-1] # host:port
|
||||||
|
host = ep.rsplit(":", 1)[0]
|
||||||
|
# IPv6 endpoint may be like [2001:db8::1]:51820
|
||||||
|
host = host.strip("[]")
|
||||||
|
targets.append(host)
|
||||||
|
|
||||||
|
# 2) Fallback: wg showconf (root may be required)
|
||||||
|
if not targets:
|
||||||
|
conf = run(["wg", "showconf", wg_if])
|
||||||
|
if conf:
|
||||||
|
for m in re.finditer(r"^Endpoint\s*=\s*(.+)$", conf, flags=re.MULTILINE):
|
||||||
|
ep = m.group(1).strip()
|
||||||
|
host = ep.rsplit(":", 1)[0].strip("[]")
|
||||||
|
targets.append(host)
|
||||||
|
|
||||||
|
# dedupe & sanity
|
||||||
|
cleaned = []
|
||||||
|
for t in targets:
|
||||||
|
if t and t not in cleaned:
|
||||||
|
cleaned.append(t)
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------- main -----------------
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
ap = argparse.ArgumentParser(description="Compute/apply WireGuard MTU based on egress MTU and optional multi-target PMTU probing.")
|
ap = argparse.ArgumentParser(description="Compute/apply WireGuard MTU based on egress MTU and optional multi-target PMTU probing.")
|
||||||
ap.add_argument("--egress-if", help="Explicit egress interface (auto-detected if omitted).")
|
ap.add_argument("--egress-if", help="Explicit egress interface (auto-detected if omitted).")
|
||||||
ap.add_argument("--force-egress-mtu", type=int, help="Force this MTU on the egress interface before computing wg MTU.")
|
ap.add_argument("--prefer-wg-egress", action="store_true",
|
||||||
ap.add_argument("--wg-if", default=os.environ.get("WG_IF","wg0"), help="WireGuard interface name (default: wg0).")
|
help="Allow/consider wg* as egress and prefer it if default route uses wg (default: disabled).")
|
||||||
ap.add_argument("--wg-overhead", type=int, default=int(os.environ.get("WG_OVERHEAD","80")), help="Bytes of WG overhead to subtract (default: 80).")
|
ap.add_argument("--auto-pmtu-from-wg", action="store_true",
|
||||||
ap.add_argument("--wg-min", type=int, default=int(os.environ.get("WG_MIN","1280")), help="Minimum allowed WG MTU (default: 1280).")
|
help="Automatically add WireGuard peer endpoints as PMTU targets (default: disabled).")
|
||||||
|
ap.add_argument("--wg-if", default=os.environ.get("WG_IF", "wg0"), help="WireGuard interface name (default: wg0).")
|
||||||
|
ap.add_argument("--wg-overhead", type=int, default=int(os.environ.get("WG_OVERHEAD", "80")), help="Bytes of WG overhead to subtract (default: 80).")
|
||||||
|
ap.add_argument("--wg-min", type=int, default=int(os.environ.get("WG_MIN", "1280")), help="Minimum allowed WG MTU (default: 1280).")
|
||||||
# PMTU (multi-target)
|
# PMTU (multi-target)
|
||||||
ap.add_argument("--pmtu-target", action="append", help="Target hostname/IP to probe PMTU. Can be given multiple times OR comma-separated.")
|
ap.add_argument("--pmtu-target", action="append", help="Target hostname/IP to probe PMTU. Can be given multiple times OR comma-separated.")
|
||||||
ap.add_argument("--pmtu-timeout", type=float, default=1.0, help="Timeout (seconds) per ping probe (default: 1.0).")
|
ap.add_argument("--pmtu-timeout", type=float, default=1.0, help="Timeout (seconds) per ping probe (default: 1.0).")
|
||||||
ap.add_argument("--pmtu-min-payload", type=int, default=1200, help="Lower bound payload for PMTU search (default: 1200).")
|
ap.add_argument("--pmtu-min-payload", type=int, default=1200, help="Lower bound payload for PMTU search (default: 1200).")
|
||||||
ap.add_argument("--pmtu-max-payload", type=int, default=1472, help="Upper bound payload for PMTU search (default: 1472 ~ 1500-28).")
|
ap.add_argument("--pmtu-max-payload", type=int, default=1472, help="Upper bound payload for PMTU search (default: 1472 ~ 1500-28).")
|
||||||
ap.add_argument("--pmtu-policy", choices=["min","median","max"], default="min",
|
ap.add_argument("--pmtu-policy", choices=["min", "median", "max"], default="min",
|
||||||
help="How to choose effective PMTU across multiple targets (default: min).")
|
help="How to choose effective PMTU across multiple targets (default: min).")
|
||||||
ap.add_argument("--dry-run", action="store_true", help="Show actions without applying changes.")
|
ap.add_argument("--dry-run", action="store_true", help="Show actions without applying changes.")
|
||||||
|
# NEW: force a specific WireGuard MTU (overrides computed value)
|
||||||
|
ap.add_argument("--set-wg-mtu", type=int, help="Force a specific MTU to apply on the WireGuard interface (overrides computed value).")
|
||||||
|
# (legacy / optional) force egress MTU
|
||||||
|
ap.add_argument("--force-egress-mtu", type=int, help="Force this MTU on the egress interface before computing wg MTU.")
|
||||||
args = ap.parse_args()
|
args = ap.parse_args()
|
||||||
|
|
||||||
require_root(args.dry_run)
|
require_root(args.dry_run)
|
||||||
|
|
||||||
# Detect egress
|
# Egress detection (wg ignored by default)
|
||||||
egress = args.egress_if or (get_default_ifaces()[0] if get_default_ifaces() else None)
|
if args.egress_if:
|
||||||
|
egress = args.egress_if
|
||||||
|
else:
|
||||||
|
ignore_vpn = not args.prefer_wg_egress
|
||||||
|
cands = get_default_ifaces(ignore_vpn=ignore_vpn)
|
||||||
|
# If we allow wg and default route is via wg-if, prefer it first
|
||||||
|
if args.prefer_wg_egress and wg_is_active(args.wg_if) and wg_default_is_active(args.wg_if):
|
||||||
|
if args.wg_if in cands:
|
||||||
|
cands.remove(args.wg_if)
|
||||||
|
cands.insert(0, args.wg_if)
|
||||||
|
egress = cands[0] if cands else None
|
||||||
|
|
||||||
if not egress:
|
if not egress:
|
||||||
print("[wg-mtu][ERROR] Could not detect egress interface (use --egress-if).", file=sys.stderr)
|
print("[wg-mtu][ERROR] Could not detect egress interface (use --egress-if).", file=sys.stderr)
|
||||||
sys.exit(2)
|
sys.exit(2)
|
||||||
if not exists_iface(egress):
|
if not exists_iface(egress):
|
||||||
print(f"[wg-mtu][ERROR] Interface {egress} does not exist.", file=sys.stderr); sys.exit(3)
|
print(f"[wg-mtu][ERROR] Interface {egress} does not exist.", file=sys.stderr)
|
||||||
|
sys.exit(3)
|
||||||
print(f"[wg-mtu] Detected egress interface: {egress}")
|
print(f"[wg-mtu] Detected egress interface: {egress}")
|
||||||
|
|
||||||
# Egress MTU
|
# Egress MTU
|
||||||
|
if args.prefer_wg_egress and egress == args.wg_if:
|
||||||
|
print(f"[wg-mtu] Using WireGuard interface {args.wg_if} as egress basis.")
|
||||||
|
if args.wg_if == egress and not wg_is_active(args.wg_if):
|
||||||
|
print(f"[wg-mtu][WARN] {args.wg_if} selected as egress but WireGuard is not active.", file=sys.stderr)
|
||||||
|
|
||||||
if args.force_egress_mtu:
|
if args.force_egress_mtu:
|
||||||
print(f"[wg-mtu] Forcing egress MTU {args.force_egress_mtu} on {egress}")
|
print(f"[wg-mtu] Forcing egress MTU {args.force_egress_mtu} on {egress}")
|
||||||
set_mtu(egress, args.force_egress_mtu, args.dry_run)
|
set_mtu(egress, args.force_egress_mtu, args.dry_run)
|
||||||
@@ -136,24 +251,37 @@ def main():
|
|||||||
base_mtu = read_mtu(egress)
|
base_mtu = read_mtu(egress)
|
||||||
print(f"[wg-mtu] Egress base MTU: {base_mtu}")
|
print(f"[wg-mtu] Egress base MTU: {base_mtu}")
|
||||||
|
|
||||||
# PMTU over multiple targets
|
# Build PMTU target list
|
||||||
effective_mtu = base_mtu
|
|
||||||
pmtu_targets = []
|
pmtu_targets = []
|
||||||
if args.pmtu_target:
|
if args.pmtu_target:
|
||||||
# flatten comma-separated + repeated flags
|
|
||||||
for item in args.pmtu_target:
|
for item in args.pmtu_target:
|
||||||
pmtu_targets.extend([x.strip() for x in item.split(",") if x.strip()])
|
pmtu_targets.extend([x.strip() for x in item.split(",") if x.strip()])
|
||||||
|
|
||||||
|
if args.auto_pmtu_from_wg:
|
||||||
|
if wg_is_active(args.wg_if):
|
||||||
|
wg_targets = wg_peer_endpoints(args.wg_if)
|
||||||
|
if wg_targets:
|
||||||
|
print(f"[wg-mtu] Auto-added WG peer endpoints as PMTU targets: {', '.join(wg_targets)}")
|
||||||
|
pmtu_targets.extend(wg_targets)
|
||||||
|
else:
|
||||||
|
print("[wg-mtu] INFO: No WG peer endpoints discovered (wg show/showconf).")
|
||||||
|
else:
|
||||||
|
print(f"[wg-mtu] INFO: {args.wg_if} is not active; skipping auto PMTU targets from WG.")
|
||||||
|
|
||||||
|
# Deduplicate PMTU targets
|
||||||
|
if pmtu_targets:
|
||||||
|
pmtu_targets = list(dict.fromkeys(pmtu_targets))
|
||||||
|
|
||||||
|
# PMTU probing
|
||||||
|
effective_mtu = base_mtu
|
||||||
if pmtu_targets:
|
if pmtu_targets:
|
||||||
results = {}
|
|
||||||
good = []
|
good = []
|
||||||
print(f"[wg-mtu] Probing Path MTU for: {', '.join(pmtu_targets)} (policy={args.pmtu_policy})")
|
print(f"[wg-mtu] Probing Path MTU for: {', '.join(pmtu_targets)} (policy={args.pmtu_policy})")
|
||||||
for t in pmtu_targets:
|
for t in pmtu_targets:
|
||||||
p = probe_pmtu(t, args.pmtu_min_payload, args.pmtu_max_payload, args.pmtu_timeout)
|
p = probe_pmtu(t, args.pmtu_min_payload, args.pmtu_max_payload, args.pmtu_timeout)
|
||||||
results[t] = p
|
print(f"[wg-mtu] - {t}: {p if p else 'probe failed'}")
|
||||||
if p:
|
if p:
|
||||||
good.append(p)
|
good.append(p)
|
||||||
print(f"[wg-mtu] - {t}: {'%s' % p if p else 'probe failed'}")
|
|
||||||
if good:
|
if good:
|
||||||
chosen = choose_effective(good, args.pmtu_policy)
|
chosen = choose_effective(good, args.pmtu_policy)
|
||||||
print(f"[wg-mtu] Selected Path MTU (policy={args.pmtu_policy}): {chosen}")
|
print(f"[wg-mtu] Selected Path MTU (policy={args.pmtu_policy}): {chosen}")
|
||||||
@@ -165,6 +293,14 @@ def main():
|
|||||||
wg_mtu = max(args.wg_min, effective_mtu - args.wg_overhead)
|
wg_mtu = max(args.wg_min, effective_mtu - args.wg_overhead)
|
||||||
print(f"[wg-mtu] Computed {args.wg_if} MTU: {wg_mtu} (overhead={args.wg_overhead}, min={args.wg_min})")
|
print(f"[wg-mtu] Computed {args.wg_if} MTU: {wg_mtu} (overhead={args.wg_overhead}, min={args.wg_min})")
|
||||||
|
|
||||||
|
# --- NEW: override with --set-wg-mtu if provided
|
||||||
|
if args.set_wg_mtu is not None:
|
||||||
|
if args.set_wg_mtu < args.wg_min:
|
||||||
|
print(f"[wg-mtu][WARN] --set-wg-mtu {args.set_wg_mtu} is below wg-min {args.wg_min}; clamping to {args.wg_min}.")
|
||||||
|
args.set_wg_mtu = args.wg_min
|
||||||
|
wg_mtu = args.set_wg_mtu
|
||||||
|
print(f"[wg-mtu] Forcing WireGuard MTU (override): {wg_mtu}")
|
||||||
|
|
||||||
# Apply
|
# Apply
|
||||||
if exists_iface(args.wg_if):
|
if exists_iface(args.wg_if):
|
||||||
set_mtu(args.wg_if, wg_mtu, args.dry_run)
|
set_mtu(args.wg_if, wg_mtu, args.dry_run)
|
||||||
@@ -174,5 +310,6 @@ def main():
|
|||||||
|
|
||||||
print(f"[wg-mtu] Done. Summary: egress={egress} mtu={base_mtu}, effective_mtu={effective_mtu}, {args.wg_if}_mtu={wg_mtu}")
|
print(f"[wg-mtu] Done. Summary: egress={egress} mtu={base_mtu}, effective_mtu={effective_mtu}, {args.wg_if}_mtu={wg_mtu}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
183
test.py
183
test.py
@@ -1,14 +1,15 @@
|
|||||||
import io
|
import io
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import patch, call
|
from unittest.mock import patch
|
||||||
from contextlib import redirect_stdout
|
from contextlib import redirect_stdout
|
||||||
|
|
||||||
# Import the script as a module
|
# Import the script as a module
|
||||||
import main as automtu
|
import main as automtu
|
||||||
|
|
||||||
|
|
||||||
class TestWgMtuAuto(unittest.TestCase):
|
class TestWgMtuAutoExtended(unittest.TestCase):
|
||||||
|
# ---------- Baseline behavior (unchanged) ----------
|
||||||
|
|
||||||
@patch("main.set_mtu")
|
@patch("main.set_mtu")
|
||||||
@patch("main.read_mtu", return_value=1500)
|
@patch("main.read_mtu", return_value=1500)
|
||||||
@@ -16,12 +17,8 @@ class TestWgMtuAuto(unittest.TestCase):
|
|||||||
@patch("main.get_default_ifaces", return_value=["eth0"])
|
@patch("main.get_default_ifaces", return_value=["eth0"])
|
||||||
@patch("main.require_root", return_value=None)
|
@patch("main.require_root", return_value=None)
|
||||||
def test_no_pmtu_uses_egress_minus_overhead(
|
def test_no_pmtu_uses_egress_minus_overhead(
|
||||||
self, _req_root, _get_def, _exists, _read_mtu, mock_set_mtu
|
self, _req_root, mock_get_def, _exists, _read_mtu, mock_set_mtu
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
Without PMTU probing, wg MTU should be base_mtu - overhead (clamped by min).
|
|
||||||
With base=1500, overhead=80 ⇒ wg_mtu=1420.
|
|
||||||
"""
|
|
||||||
argv = ["main.py", "--dry-run"]
|
argv = ["main.py", "--dry-run"]
|
||||||
with patch.object(sys, "argv", argv):
|
with patch.object(sys, "argv", argv):
|
||||||
buf = io.StringIO()
|
buf = io.StringIO()
|
||||||
@@ -32,74 +29,100 @@ class TestWgMtuAuto(unittest.TestCase):
|
|||||||
self.assertIn("Detected egress interface: eth0", out)
|
self.assertIn("Detected egress interface: eth0", out)
|
||||||
self.assertIn("Egress base MTU: 1500", out)
|
self.assertIn("Egress base MTU: 1500", out)
|
||||||
self.assertIn("Computed wg0 MTU: 1420", out)
|
self.assertIn("Computed wg0 MTU: 1420", out)
|
||||||
|
|
||||||
# dry-run still calls set_mtu (but prints DRY-RUN); ensure it targeted wg0 with 1420
|
|
||||||
mock_set_mtu.assert_any_call("wg0", 1420, True)
|
mock_set_mtu.assert_any_call("wg0", 1420, True)
|
||||||
|
# get_default_ifaces should be called with ignore_vpn=True by default
|
||||||
|
mock_get_def.assert_called_with(ignore_vpn=True)
|
||||||
|
|
||||||
|
# ---------- prefer-wg-egress selection ----------
|
||||||
|
|
||||||
|
@patch("main.wg_default_is_active", return_value=True)
|
||||||
|
@patch("main.wg_is_active", return_value=True)
|
||||||
@patch("main.set_mtu")
|
@patch("main.set_mtu")
|
||||||
|
@patch("main.read_mtu", return_value=1420)
|
||||||
@patch("main.exists_iface", return_value=True)
|
@patch("main.exists_iface", return_value=True)
|
||||||
@patch("main.get_default_ifaces", return_value=["eth0"])
|
@patch("main.get_default_ifaces", return_value=["eth0", "wg0"])
|
||||||
@patch("main.require_root", return_value=None)
|
@patch("main.require_root", return_value=None)
|
||||||
def test_force_egress_mtu_and_pmtu_multiple_targets_min_policy(
|
def test_prefer_wg_egress_picks_wg0_when_default_route_via_wg(
|
||||||
self, _req_root, _get_def, _exists, mock_set_mtu
|
self, _req_root, mock_get_def, _exists, _read_mtu, _set_mtu, _wg_is_active, _wg_def_active
|
||||||
):
|
):
|
||||||
"""
|
argv = ["main.py", "--dry-run", "--prefer-wg-egress", "--wg-if", "wg0"]
|
||||||
base_mtu forced=1452; PMTU results: 1452, 1420 -> policy=min => 1420 chosen.
|
with patch.object(sys, "argv", argv):
|
||||||
effective=min(1452,1420)=1420; wg_mtu=1420-80=1340
|
buf = io.StringIO()
|
||||||
"""
|
with redirect_stdout(buf):
|
||||||
with patch("main.read_mtu", return_value=9999): # should be ignored because we force
|
automtu.main()
|
||||||
with patch("main.probe_pmtu", side_effect=[1452, 1420]):
|
|
||||||
argv = [
|
|
||||||
"main.py",
|
|
||||||
"--dry-run",
|
|
||||||
"--force-egress-mtu", "1452",
|
|
||||||
"--pmtu-target", "t1",
|
|
||||||
"--pmtu-target", "t2",
|
|
||||||
"--pmtu-policy", "min",
|
|
||||||
]
|
|
||||||
with patch.object(sys, "argv", argv):
|
|
||||||
buf = io.StringIO()
|
|
||||||
with redirect_stdout(buf):
|
|
||||||
automtu.main()
|
|
||||||
|
|
||||||
out = buf.getvalue()
|
out = buf.getvalue()
|
||||||
self.assertIn("Forcing egress MTU 1452 on eth0", out)
|
# When prefer-wg is set AND wg default route is active, wg0 should be chosen as egress
|
||||||
self.assertIn("Probing Path MTU for: t1, t2 (policy=min)", out)
|
self.assertIn("Detected egress interface: wg0", out)
|
||||||
self.assertIn("Selected Path MTU (policy=min): 1420", out)
|
self.assertIn("Using WireGuard interface wg0 as egress basis.", out)
|
||||||
|
# Computed MTU: base 1420 - 80 = 1340 (clamped by min=1280)
|
||||||
self.assertIn("Computed wg0 MTU: 1340", out)
|
self.assertIn("Computed wg0 MTU: 1340", out)
|
||||||
mock_set_mtu.assert_any_call("wg0", 1340, True)
|
# get_default_ifaces should be called with ignore_vpn=False (because prefer-wg)
|
||||||
|
mock_get_def.assert_called_with(ignore_vpn=False)
|
||||||
|
|
||||||
|
# ---------- auto-pmtu-from-wg adds peer endpoints ----------
|
||||||
|
|
||||||
|
@patch("main.wg_peer_endpoints", return_value=["46.4.224.77", "2a01:db8::1"])
|
||||||
|
@patch("main.wg_is_active", return_value=True)
|
||||||
|
@patch("main.probe_pmtu", side_effect=[1452, 1420]) # results for two peers
|
||||||
@patch("main.set_mtu")
|
@patch("main.set_mtu")
|
||||||
@patch("main.read_mtu", return_value=1500)
|
@patch("main.read_mtu", return_value=1500)
|
||||||
@patch("main.exists_iface", return_value=True)
|
@patch("main.exists_iface", return_value=True)
|
||||||
@patch("main.get_default_ifaces", return_value=["eth0"])
|
@patch("main.get_default_ifaces", return_value=["eth0"])
|
||||||
@patch("main.require_root", return_value=None)
|
@patch("main.require_root", return_value=None)
|
||||||
def test_pmtu_policy_median(
|
def test_auto_pmtu_from_wg_adds_targets_and_uses_min_policy(
|
||||||
self, _req_root, _get_def, _exists, _read_mtu, mock_set_mtu
|
self, _req_root, _get_def, _exists, _read_mtu, _set_mtu, _probe_pmtu, _wg_active, _wg_peers
|
||||||
):
|
):
|
||||||
"""
|
argv = ["main.py", "--dry-run", "--auto-pmtu-from-wg", "--wg-if", "wg0"]
|
||||||
base=1500; PMTUs: 1500, 1452, 1472 -> median=1472.
|
with patch.object(sys, "argv", argv):
|
||||||
effective=min(1500,1472)=1472; wg_mtu=1472-80=1392
|
buf = io.StringIO()
|
||||||
"""
|
with redirect_stdout(buf):
|
||||||
with patch("main.probe_pmtu", side_effect=[1500, 1452, 1472]):
|
automtu.main()
|
||||||
argv = [
|
|
||||||
"main.py",
|
|
||||||
"--dry-run",
|
|
||||||
"--pmtu-target", "a",
|
|
||||||
"--pmtu-target", "b",
|
|
||||||
"--pmtu-target", "c",
|
|
||||||
"--pmtu-policy", "median",
|
|
||||||
]
|
|
||||||
with patch.object(sys, "argv", argv):
|
|
||||||
buf = io.StringIO()
|
|
||||||
with redirect_stdout(buf):
|
|
||||||
automtu.main()
|
|
||||||
|
|
||||||
out = buf.getvalue()
|
out = buf.getvalue()
|
||||||
self.assertIn("Probing Path MTU for: a, b, c (policy=median)", out)
|
# Confirm WG peers were added
|
||||||
|
self.assertIn("Auto-added WG peer endpoints as PMTU targets: 46.4.224.77, 2a01:db8::1", out)
|
||||||
|
# The policy default is 'min', so chosen PMTU should be 1420
|
||||||
|
self.assertIn("Selected Path MTU (policy=min): 1420", out)
|
||||||
|
# Computed wg0 MTU: 1420 - 80 = 1340
|
||||||
|
self.assertIn("Computed wg0 MTU: 1340", out)
|
||||||
|
# Ensure probe was called twice (for both peers)
|
||||||
|
self.assertEqual(_probe_pmtu.call_count, 2)
|
||||||
|
|
||||||
|
# ---------- manual PMTU still works with prefer-wg-egress ----------
|
||||||
|
|
||||||
|
@patch("main.wg_default_is_active", return_value=True)
|
||||||
|
@patch("main.wg_is_active", return_value=True)
|
||||||
|
@patch("main.probe_pmtu", side_effect=[1472, 1452, 1500])
|
||||||
|
@patch("main.set_mtu")
|
||||||
|
@patch("main.read_mtu", return_value=1500)
|
||||||
|
@patch("main.exists_iface", return_value=True)
|
||||||
|
@patch("main.get_default_ifaces", return_value=["eth0"])
|
||||||
|
@patch("main.require_root", return_value=None)
|
||||||
|
def test_prefer_wg_egress_with_manual_targets_and_median_policy(
|
||||||
|
self, _req_root, _get_def, _exists, _read_mtu, _set_mtu, _probe_pmtu, _wg_is_active, _wg_def_active
|
||||||
|
):
|
||||||
|
argv = [
|
||||||
|
"main.py", "--dry-run",
|
||||||
|
"--prefer-wg-egress", "--wg-if", "wg0",
|
||||||
|
"--pmtu-target", "a", "--pmtu-target", "b", "--pmtu-target", "c",
|
||||||
|
"--pmtu-policy", "median"
|
||||||
|
]
|
||||||
|
with patch.object(sys, "argv", argv):
|
||||||
|
buf = io.StringIO()
|
||||||
|
with redirect_stdout(buf):
|
||||||
|
automtu.main()
|
||||||
|
|
||||||
|
out = buf.getvalue()
|
||||||
|
# As default route via wg is active, wg0 should be used
|
||||||
|
self.assertIn("Detected egress interface: wg0", out)
|
||||||
|
# PMTU values: 1472, 1452, 1500 -> median = 1472
|
||||||
self.assertIn("Selected Path MTU (policy=median): 1472", out)
|
self.assertIn("Selected Path MTU (policy=median): 1472", out)
|
||||||
|
# Computed WG MTU: 1472 - 80 = 1392
|
||||||
self.assertIn("Computed wg0 MTU: 1392", out)
|
self.assertIn("Computed wg0 MTU: 1392", out)
|
||||||
mock_set_mtu.assert_any_call("wg0", 1392, True)
|
self.assertEqual(_probe_pmtu.call_count, 3)
|
||||||
|
|
||||||
|
# ---------- PMTU all fail fallback ----------
|
||||||
|
|
||||||
@patch("main.set_mtu")
|
@patch("main.set_mtu")
|
||||||
@patch("main.read_mtu", return_value=1500)
|
@patch("main.read_mtu", return_value=1500)
|
||||||
@@ -109,16 +132,8 @@ class TestWgMtuAuto(unittest.TestCase):
|
|||||||
def test_pmtu_all_fail_falls_back_to_base(
|
def test_pmtu_all_fail_falls_back_to_base(
|
||||||
self, _req_root, _get_def, _exists, _read_mtu, mock_set_mtu
|
self, _req_root, _get_def, _exists, _read_mtu, mock_set_mtu
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
If all PMTU probes fail, fall back to base MTU (1500) => wg_mtu=1420.
|
|
||||||
"""
|
|
||||||
with patch("main.probe_pmtu", side_effect=[None, None]):
|
with patch("main.probe_pmtu", side_effect=[None, None]):
|
||||||
argv = [
|
argv = ["main.py", "--dry-run", "--pmtu-target", "bad1", "--pmtu-target", "bad2"]
|
||||||
"main.py",
|
|
||||||
"--dry-run",
|
|
||||||
"--pmtu-target", "bad1",
|
|
||||||
"--pmtu-target", "bad2",
|
|
||||||
]
|
|
||||||
with patch.object(sys, "argv", argv):
|
with patch.object(sys, "argv", argv):
|
||||||
buf = io.StringIO()
|
buf = io.StringIO()
|
||||||
with redirect_stdout(buf):
|
with redirect_stdout(buf):
|
||||||
@@ -126,9 +141,47 @@ class TestWgMtuAuto(unittest.TestCase):
|
|||||||
|
|
||||||
out = buf.getvalue()
|
out = buf.getvalue()
|
||||||
self.assertIn("WARNING: All PMTU probes failed. Falling back to egress MTU.", out)
|
self.assertIn("WARNING: All PMTU probes failed. Falling back to egress MTU.", out)
|
||||||
self.assertIn("Computed wg0 MTU: 1420", out)
|
self.assertIn("Computed wg0 MTU: 1420", out) # 1500 - 80
|
||||||
mock_set_mtu.assert_any_call("wg0", 1420, True)
|
mock_set_mtu.assert_any_call("wg0", 1420, True)
|
||||||
|
|
||||||
|
# ---------- NEW: --set-wg-mtu overrides computed ----------
|
||||||
|
|
||||||
|
@patch("main.set_mtu")
|
||||||
|
@patch("main.read_mtu", return_value=1500)
|
||||||
|
@patch("main.exists_iface", return_value=True)
|
||||||
|
@patch("main.get_default_ifaces", return_value=["eth0"])
|
||||||
|
@patch("main.require_root", return_value=None)
|
||||||
|
def test_force_set_wg_mtu_overrides_computed(
|
||||||
|
self, _req_root, _get_def, _exists, _read_mtu, mock_set_mtu
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
--set-wg-mtu must override the computed value.
|
||||||
|
Base=1500 -> computed 1420 (1500-80), but we force 1300.
|
||||||
|
"""
|
||||||
|
argv = ["main.py", "--dry-run", "--set-wg-mtu", "1300"]
|
||||||
|
with patch.object(sys, "argv", argv):
|
||||||
|
buf = io.StringIO()
|
||||||
|
with redirect_stdout(buf):
|
||||||
|
automtu.main()
|
||||||
|
|
||||||
|
out = buf.getvalue()
|
||||||
|
# Computation is printed first
|
||||||
|
self.assertIn("Computed wg0 MTU: 1420", out)
|
||||||
|
# Then override message appears and applied value is 1300
|
||||||
|
self.assertIn("Forcing WireGuard MTU (override): 1300", out)
|
||||||
|
mock_set_mtu.assert_any_call("wg0", 1300, True)
|
||||||
|
|
||||||
|
# also test clamping below wg-min
|
||||||
|
argv2 = ["main.py", "--dry-run", "--set-wg-mtu", "1200"] # below default wg_min=1280
|
||||||
|
with patch.object(sys, "argv", argv2):
|
||||||
|
out2 = io.StringIO()
|
||||||
|
with redirect_stdout(out2):
|
||||||
|
automtu.main()
|
||||||
|
s = out2.getvalue()
|
||||||
|
self.assertIn("[wg-mtu][WARN] --set-wg-mtu 1200 is below wg-min 1280; clamping to 1280.", s)
|
||||||
|
self.assertIn("Forcing WireGuard MTU (override): 1280", s)
|
||||||
|
mock_set_mtu.assert_any_call("wg0", 1280, True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main(verbosity=2)
|
unittest.main(verbosity=2)
|
||||||
|
|||||||
Reference in New Issue
Block a user