mirror of
https://github.com/kevinveenbirkenbach/computer-playbook.git
synced 2025-11-05 20:58:21 +00:00
Enhance CertUtils to return the newest matching certificate and add comprehensive unit tests
- Added run_openssl_dates() to extract notBefore/notAfter timestamps. - Modified mapping logic to store multiple cert entries per SAN with metadata. - find_cert_for_domain() now selects the newest certificate based on notBefore and mtime. - Exact SAN matches take precedence over wildcard matches. - Added new unit tests (test_cert_utils_newest.py) verifying freshness logic, fallback handling, and wildcard behavior. Reference: https://chatgpt.com/share/68ef4b4c-41d4-800f-9e50-5da4b6be1105
This commit is contained in:
@@ -6,6 +6,7 @@ __metaclass__ = type
|
|||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
class CertUtils:
|
class CertUtils:
|
||||||
_domain_cert_mapping = None
|
_domain_cert_mapping = None
|
||||||
@@ -22,6 +23,30 @@ class CertUtils:
|
|||||||
except subprocess.CalledProcessError:
|
except subprocess.CalledProcessError:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def run_openssl_dates(cert_path):
|
||||||
|
"""
|
||||||
|
Returns (not_before_ts, not_after_ts) as POSIX timestamps or (None, None) on failure.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
output = subprocess.check_output(
|
||||||
|
['openssl', 'x509', '-in', cert_path, '-noout', '-startdate', '-enddate'],
|
||||||
|
universal_newlines=True
|
||||||
|
)
|
||||||
|
nb, na = None, None
|
||||||
|
for line in output.splitlines():
|
||||||
|
line = line.strip()
|
||||||
|
if line.startswith('notBefore='):
|
||||||
|
nb = line.split('=', 1)[1].strip()
|
||||||
|
elif line.startswith('notAfter='):
|
||||||
|
na = line.split('=', 1)[1].strip()
|
||||||
|
def _parse(openssl_dt):
|
||||||
|
# OpenSSL format example: "Oct 10 12:34:56 2025 GMT"
|
||||||
|
return int(datetime.strptime(openssl_dt, "%b %d %H:%M:%S %Y %Z").timestamp())
|
||||||
|
return (_parse(nb) if nb else None, _parse(na) if na else None)
|
||||||
|
except Exception:
|
||||||
|
return (None, None)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def extract_sans(cert_text):
|
def extract_sans(cert_text):
|
||||||
dns_entries = []
|
dns_entries = []
|
||||||
@@ -59,7 +84,6 @@ class CertUtils:
|
|||||||
else:
|
else:
|
||||||
return domain == san
|
return domain == san
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_snapshot(cls, cert_base_path):
|
def build_snapshot(cls, cert_base_path):
|
||||||
snapshot = []
|
snapshot = []
|
||||||
@@ -82,6 +106,17 @@ class CertUtils:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def refresh_cert_mapping(cls, cert_base_path, debug=False):
|
def refresh_cert_mapping(cls, cert_base_path, debug=False):
|
||||||
|
"""
|
||||||
|
Build mapping: SAN -> list of entries
|
||||||
|
entry = {
|
||||||
|
'folder': str,
|
||||||
|
'cert_path': str,
|
||||||
|
'mtime': float,
|
||||||
|
'not_before': int|None,
|
||||||
|
'not_after': int|None,
|
||||||
|
'is_wildcard': bool
|
||||||
|
}
|
||||||
|
"""
|
||||||
cert_files = cls.list_cert_files(cert_base_path)
|
cert_files = cls.list_cert_files(cert_base_path)
|
||||||
mapping = {}
|
mapping = {}
|
||||||
for cert_path in cert_files:
|
for cert_path in cert_files:
|
||||||
@@ -90,46 +125,82 @@ class CertUtils:
|
|||||||
continue
|
continue
|
||||||
sans = cls.extract_sans(cert_text)
|
sans = cls.extract_sans(cert_text)
|
||||||
folder = os.path.basename(os.path.dirname(cert_path))
|
folder = os.path.basename(os.path.dirname(cert_path))
|
||||||
|
try:
|
||||||
|
mtime = os.stat(cert_path).st_mtime
|
||||||
|
except FileNotFoundError:
|
||||||
|
mtime = 0.0
|
||||||
|
nb, na = cls.run_openssl_dates(cert_path)
|
||||||
|
|
||||||
for san in sans:
|
for san in sans:
|
||||||
if san not in mapping:
|
entry = {
|
||||||
mapping[san] = folder
|
'folder': folder,
|
||||||
|
'cert_path': cert_path,
|
||||||
|
'mtime': mtime,
|
||||||
|
'not_before': nb,
|
||||||
|
'not_after': na,
|
||||||
|
'is_wildcard': san.startswith('*.'),
|
||||||
|
}
|
||||||
|
mapping.setdefault(san, []).append(entry)
|
||||||
|
|
||||||
cls._domain_cert_mapping = mapping
|
cls._domain_cert_mapping = mapping
|
||||||
if debug:
|
if debug:
|
||||||
print(f"[DEBUG] Refreshed domain-to-cert mapping: {mapping}")
|
print(f"[DEBUG] Refreshed domain-to-cert mapping (counts): "
|
||||||
|
f"{ {k: len(v) for k, v in mapping.items()} }")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def ensure_cert_mapping(cls, cert_base_path, debug=False):
|
def ensure_cert_mapping(cls, cert_base_path, debug=False):
|
||||||
if cls._domain_cert_mapping is None or cls.snapshot_changed(cert_base_path):
|
if cls._domain_cert_mapping is None or cls.snapshot_changed(cert_base_path):
|
||||||
cls.refresh_cert_mapping(cert_base_path, debug)
|
cls.refresh_cert_mapping(cert_base_path, debug)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _score_entry(entry):
|
||||||
|
"""
|
||||||
|
Return tuple used for sorting newest-first:
|
||||||
|
(not_before or -inf, mtime)
|
||||||
|
"""
|
||||||
|
nb = entry.get('not_before')
|
||||||
|
mtime = entry.get('mtime', 0.0)
|
||||||
|
return (nb if nb is not None else -1, mtime)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def find_cert_for_domain(cls, domain, cert_base_path, debug=False):
|
def find_cert_for_domain(cls, domain, cert_base_path, debug=False):
|
||||||
cls.ensure_cert_mapping(cert_base_path, debug)
|
cls.ensure_cert_mapping(cert_base_path, debug)
|
||||||
|
|
||||||
exact_match = None
|
candidates_exact = []
|
||||||
wildcard_match = None
|
candidates_wild = []
|
||||||
|
|
||||||
for san, folder in cls._domain_cert_mapping.items():
|
for san, entries in cls._domain_cert_mapping.items():
|
||||||
if san == domain:
|
if san == domain:
|
||||||
exact_match = folder
|
candidates_exact.extend(entries)
|
||||||
break
|
elif san.startswith('*.'):
|
||||||
if san.startswith('*.'):
|
|
||||||
base = san[2:]
|
base = san[2:]
|
||||||
if domain.count('.') == base.count('.') + 1 and domain.endswith('.' + base):
|
if domain.count('.') == base.count('.') + 1 and domain.endswith('.' + base):
|
||||||
wildcard_match = folder
|
candidates_wild.extend(entries)
|
||||||
|
|
||||||
if exact_match:
|
def _pick_newest(entries):
|
||||||
if debug:
|
if not entries:
|
||||||
print(f"[DEBUG] Exact match for {domain} found in {exact_match}")
|
return None
|
||||||
return exact_match
|
# newest by (not_before, mtime)
|
||||||
|
best = max(entries, key=cls._score_entry)
|
||||||
|
return best
|
||||||
|
|
||||||
if wildcard_match:
|
best_exact = _pick_newest(candidates_exact)
|
||||||
if debug:
|
best_wild = _pick_newest(candidates_wild)
|
||||||
print(f"[DEBUG] Wildcard match for {domain} found in {wildcard_match}")
|
|
||||||
return wildcard_match
|
if best_exact and debug:
|
||||||
|
print(f"[DEBUG] Best exact match for {domain}: {best_exact['folder']} "
|
||||||
|
f"(not_before={best_exact['not_before']}, mtime={best_exact['mtime']})")
|
||||||
|
if best_wild and debug:
|
||||||
|
print(f"[DEBUG] Best wildcard match for {domain}: {best_wild['folder']} "
|
||||||
|
f"(not_before={best_wild['not_before']}, mtime={best_wild['mtime']})")
|
||||||
|
|
||||||
|
# Prefer exact if it exists; otherwise wildcard
|
||||||
|
chosen = best_exact or best_wild
|
||||||
|
|
||||||
|
if chosen:
|
||||||
|
return chosen['folder']
|
||||||
|
|
||||||
if debug:
|
if debug:
|
||||||
print(f"[DEBUG] No certificate folder found for {domain}")
|
print(f"[DEBUG] No certificate folder found for {domain}")
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
172
tests/unit/module_utils/test_cert_utils_newest.py
Normal file
172
tests/unit/module_utils/test_cert_utils_newest.py
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
# Add the project root/module_utils to the import path
|
||||||
|
CURRENT_DIR = os.path.dirname(__file__)
|
||||||
|
PROJECT_ROOT = os.path.abspath(os.path.join(CURRENT_DIR, "../../.."))
|
||||||
|
sys.path.insert(0, PROJECT_ROOT)
|
||||||
|
|
||||||
|
from module_utils.cert_utils import CertUtils
|
||||||
|
|
||||||
|
|
||||||
|
def _san_block(*entries):
|
||||||
|
"""
|
||||||
|
Helper: builds a minimal OpenSSL text snippet that contains SAN entries.
|
||||||
|
Example: _san_block('example.com', '*.example.com')
|
||||||
|
"""
|
||||||
|
sans = ", ".join(f"DNS:{e}" for e in entries)
|
||||||
|
return f"""
|
||||||
|
Certificate:
|
||||||
|
Data:
|
||||||
|
Version: 3 (0x2)
|
||||||
|
...
|
||||||
|
X509v3 extensions:
|
||||||
|
X509v3 Subject Alternative Name:
|
||||||
|
{sans}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class TestCertUtilsFindNewest(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
# Reset internal caches before each test
|
||||||
|
CertUtils._domain_cert_mapping = None
|
||||||
|
CertUtils._cert_snapshot = None
|
||||||
|
|
||||||
|
def _mock_stat_map(self, mtime_map, size_map=None):
|
||||||
|
size_map = size_map or {}
|
||||||
|
def _stat_side_effect(path):
|
||||||
|
return SimpleNamespace(
|
||||||
|
st_mtime=mtime_map.get(path, 0.0),
|
||||||
|
st_size=size_map.get(path, 1234),
|
||||||
|
)
|
||||||
|
return _stat_side_effect
|
||||||
|
|
||||||
|
def test_prefers_newest_by_not_before(self):
|
||||||
|
"""
|
||||||
|
Two certs with the same SAN 'www.example.com':
|
||||||
|
- a/cert.pem: older notBefore
|
||||||
|
- b/cert.pem: newer notBefore -> should be selected
|
||||||
|
"""
|
||||||
|
files = [
|
||||||
|
"/etc/letsencrypt/live/a/cert.pem",
|
||||||
|
"/etc/letsencrypt/live/b/cert.pem",
|
||||||
|
]
|
||||||
|
san_text = _san_block("www.example.com")
|
||||||
|
|
||||||
|
with patch.object(CertUtils, "list_cert_files", return_value=files), \
|
||||||
|
patch.object(CertUtils, "run_openssl", return_value=san_text), \
|
||||||
|
patch.object(CertUtils, "run_openssl_dates") as mock_dates, \
|
||||||
|
patch("os.stat", side_effect=self._mock_stat_map({
|
||||||
|
files[0]: 1000,
|
||||||
|
files[1]: 1001,
|
||||||
|
})):
|
||||||
|
|
||||||
|
mock_dates.side_effect = [(10, 100000), (20, 100000)] # older/newer
|
||||||
|
|
||||||
|
folder = CertUtils.find_cert_for_domain("www.example.com", "/etc/letsencrypt/live", debug=False)
|
||||||
|
self.assertEqual(folder, "b", "Should return the folder with the newest notBefore date.")
|
||||||
|
|
||||||
|
def test_fallback_to_mtime_when_not_before_missing(self):
|
||||||
|
"""
|
||||||
|
When not_before is missing, mtime should be used as a fallback.
|
||||||
|
"""
|
||||||
|
files = [
|
||||||
|
"/etc/letsencrypt/live/a/cert.pem",
|
||||||
|
"/etc/letsencrypt/live/b/cert.pem",
|
||||||
|
]
|
||||||
|
san_text = _san_block("www.example.com")
|
||||||
|
|
||||||
|
with patch.object(CertUtils, "list_cert_files", return_value=files), \
|
||||||
|
patch.object(CertUtils, "run_openssl", return_value=san_text), \
|
||||||
|
patch.object(CertUtils, "run_openssl_dates", return_value=(None, None)), \
|
||||||
|
patch("os.stat", side_effect=self._mock_stat_map({
|
||||||
|
files[0]: 1000,
|
||||||
|
files[1]: 2000,
|
||||||
|
})):
|
||||||
|
|
||||||
|
folder = CertUtils.find_cert_for_domain("www.example.com", "/etc/letsencrypt/live", debug=False)
|
||||||
|
self.assertEqual(folder, "b", "Should fall back to mtime and select the newest file.")
|
||||||
|
|
||||||
|
def test_exact_beats_wildcard_even_if_wildcard_newer(self):
|
||||||
|
"""
|
||||||
|
Exact matches must take precedence over wildcard matches,
|
||||||
|
even if the wildcard certificate is newer.
|
||||||
|
"""
|
||||||
|
files = [
|
||||||
|
"/etc/letsencrypt/live/exact/cert.pem",
|
||||||
|
"/etc/letsencrypt/live/wild/cert.pem",
|
||||||
|
]
|
||||||
|
text_exact = _san_block("api.example.com")
|
||||||
|
text_wild = _san_block("*.example.com")
|
||||||
|
|
||||||
|
with patch.object(CertUtils, "list_cert_files", return_value=files), \
|
||||||
|
patch.object(CertUtils, "run_openssl") as mock_text, \
|
||||||
|
patch.object(CertUtils, "run_openssl_dates") as mock_dates, \
|
||||||
|
patch("os.stat", side_effect=self._mock_stat_map({
|
||||||
|
files[0]: 1000, # exact is older
|
||||||
|
files[1]: 5000, # wildcard is much newer
|
||||||
|
})):
|
||||||
|
|
||||||
|
mock_text.side_effect = [text_exact, text_wild]
|
||||||
|
mock_dates.side_effect = [(10, 100000), (99, 100000)]
|
||||||
|
|
||||||
|
folder = CertUtils.find_cert_for_domain("api.example.com", "/etc/letsencrypt/live", debug=False)
|
||||||
|
self.assertEqual(
|
||||||
|
folder, "exact",
|
||||||
|
"Exact match must win even if the wildcard certificate is newer."
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_wildcard_one_label_only(self):
|
||||||
|
"""
|
||||||
|
Wildcards (*.example.com) must only match one additional label.
|
||||||
|
"""
|
||||||
|
files = ["/etc/letsencrypt/live/wild/cert.pem"]
|
||||||
|
text_wild = _san_block("*.example.com")
|
||||||
|
|
||||||
|
with patch.object(CertUtils, "list_cert_files", return_value=files), \
|
||||||
|
patch.object(CertUtils, "run_openssl", return_value=text_wild), \
|
||||||
|
patch.object(CertUtils, "run_openssl_dates", return_value=(50, 100000)), \
|
||||||
|
patch("os.stat", side_effect=self._mock_stat_map({files[0]: 1000})):
|
||||||
|
|
||||||
|
# should match
|
||||||
|
self.assertEqual(
|
||||||
|
CertUtils.find_cert_for_domain("api.example.com", "/etc/letsencrypt/live"),
|
||||||
|
"wild"
|
||||||
|
)
|
||||||
|
# too deep -> should not match
|
||||||
|
self.assertIsNone(
|
||||||
|
CertUtils.find_cert_for_domain("deep.api.example.com", "/etc/letsencrypt/live"),
|
||||||
|
"Wildcard must not match multiple labels."
|
||||||
|
)
|
||||||
|
# base domain not covered
|
||||||
|
self.assertIsNone(
|
||||||
|
CertUtils.find_cert_for_domain("example.com", "/etc/letsencrypt/live"),
|
||||||
|
"Base domain is not covered by *.example.com."
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_snapshot_refresh_rebuilds_mapping(self):
|
||||||
|
"""
|
||||||
|
ensure_cert_mapping() should rebuild mapping when snapshot changes.
|
||||||
|
"""
|
||||||
|
CertUtils._domain_cert_mapping = {"www.example.com": [{"folder": "old", "mtime": 1, "not_before": 1}]}
|
||||||
|
|
||||||
|
with patch.object(CertUtils, "snapshot_changed", return_value=True), \
|
||||||
|
patch.object(CertUtils, "refresh_cert_mapping") as mock_refresh:
|
||||||
|
|
||||||
|
def _set_new_mapping(cert_base_path, debug=False):
|
||||||
|
CertUtils._domain_cert_mapping = {
|
||||||
|
"www.example.com": [{"folder": "new", "mtime": 999, "not_before": 999}]
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_refresh.side_effect = _set_new_mapping
|
||||||
|
|
||||||
|
folder = CertUtils.find_cert_for_domain("www.example.com", "/etc/letsencrypt/live", debug=False)
|
||||||
|
self.assertEqual(folder, "new", "Mapping must be refreshed when snapshot changes.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user