diff --git a/module_utils/cert_utils.py b/module_utils/cert_utils.py index 22e7970c..2ab7d032 100644 --- a/module_utils/cert_utils.py +++ b/module_utils/cert_utils.py @@ -6,6 +6,7 @@ __metaclass__ = type import os import subprocess import time +from datetime import datetime class CertUtils: _domain_cert_mapping = None @@ -22,6 +23,30 @@ class CertUtils: except subprocess.CalledProcessError: return "" + @staticmethod + def run_openssl_dates(cert_path): + """ + Returns (not_before_ts, not_after_ts) as POSIX timestamps or (None, None) on failure. + """ + try: + output = subprocess.check_output( + ['openssl', 'x509', '-in', cert_path, '-noout', '-startdate', '-enddate'], + universal_newlines=True + ) + nb, na = None, None + for line in output.splitlines(): + line = line.strip() + if line.startswith('notBefore='): + nb = line.split('=', 1)[1].strip() + elif line.startswith('notAfter='): + na = line.split('=', 1)[1].strip() + def _parse(openssl_dt): + # OpenSSL format example: "Oct 10 12:34:56 2025 GMT" + return int(datetime.strptime(openssl_dt, "%b %d %H:%M:%S %Y %Z").timestamp()) + return (_parse(nb) if nb else None, _parse(na) if na else None) + except Exception: + return (None, None) + @staticmethod def extract_sans(cert_text): dns_entries = [] @@ -59,7 +84,6 @@ class CertUtils: else: return domain == san - @classmethod def build_snapshot(cls, cert_base_path): snapshot = [] @@ -82,6 +106,17 @@ class CertUtils: @classmethod def refresh_cert_mapping(cls, cert_base_path, debug=False): + """ + Build mapping: SAN -> list of entries + entry = { + 'folder': str, + 'cert_path': str, + 'mtime': float, + 'not_before': int|None, + 'not_after': int|None, + 'is_wildcard': bool + } + """ cert_files = cls.list_cert_files(cert_base_path) mapping = {} for cert_path in cert_files: @@ -90,46 +125,82 @@ class CertUtils: continue sans = cls.extract_sans(cert_text) folder = os.path.basename(os.path.dirname(cert_path)) + try: + mtime = os.stat(cert_path).st_mtime + except FileNotFoundError: + mtime = 0.0 + nb, na = cls.run_openssl_dates(cert_path) + for san in sans: - if san not in mapping: - mapping[san] = folder + entry = { + 'folder': folder, + 'cert_path': cert_path, + 'mtime': mtime, + 'not_before': nb, + 'not_after': na, + 'is_wildcard': san.startswith('*.'), + } + mapping.setdefault(san, []).append(entry) + cls._domain_cert_mapping = mapping if debug: - print(f"[DEBUG] Refreshed domain-to-cert mapping: {mapping}") + print(f"[DEBUG] Refreshed domain-to-cert mapping (counts): " + f"{ {k: len(v) for k, v in mapping.items()} }") @classmethod def ensure_cert_mapping(cls, cert_base_path, debug=False): if cls._domain_cert_mapping is None or cls.snapshot_changed(cert_base_path): cls.refresh_cert_mapping(cert_base_path, debug) + @staticmethod + def _score_entry(entry): + """ + Return tuple used for sorting newest-first: + (not_before or -inf, mtime) + """ + nb = entry.get('not_before') + mtime = entry.get('mtime', 0.0) + return (nb if nb is not None else -1, mtime) + @classmethod def find_cert_for_domain(cls, domain, cert_base_path, debug=False): cls.ensure_cert_mapping(cert_base_path, debug) - exact_match = None - wildcard_match = None + candidates_exact = [] + candidates_wild = [] - for san, folder in cls._domain_cert_mapping.items(): + for san, entries in cls._domain_cert_mapping.items(): if san == domain: - exact_match = folder - break - if san.startswith('*.'): + candidates_exact.extend(entries) + elif san.startswith('*.'): base = san[2:] if domain.count('.') == base.count('.') + 1 and domain.endswith('.' + base): - wildcard_match = folder + candidates_wild.extend(entries) - if exact_match: - if debug: - print(f"[DEBUG] Exact match for {domain} found in {exact_match}") - return exact_match + def _pick_newest(entries): + if not entries: + return None + # newest by (not_before, mtime) + best = max(entries, key=cls._score_entry) + return best - if wildcard_match: - if debug: - print(f"[DEBUG] Wildcard match for {domain} found in {wildcard_match}") - return wildcard_match + best_exact = _pick_newest(candidates_exact) + best_wild = _pick_newest(candidates_wild) + + if best_exact and debug: + print(f"[DEBUG] Best exact match for {domain}: {best_exact['folder']} " + f"(not_before={best_exact['not_before']}, mtime={best_exact['mtime']})") + if best_wild and debug: + print(f"[DEBUG] Best wildcard match for {domain}: {best_wild['folder']} " + f"(not_before={best_wild['not_before']}, mtime={best_wild['mtime']})") + + # Prefer exact if it exists; otherwise wildcard + chosen = best_exact or best_wild + + if chosen: + return chosen['folder'] if debug: print(f"[DEBUG] No certificate folder found for {domain}") return None - diff --git a/tests/unit/module_utils/test_cert_utils_newest.py b/tests/unit/module_utils/test_cert_utils_newest.py new file mode 100644 index 00000000..aca77d13 --- /dev/null +++ b/tests/unit/module_utils/test_cert_utils_newest.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 +import os +import sys +import unittest +from types import SimpleNamespace +from unittest.mock import patch + +# Add the project root/module_utils to the import path +CURRENT_DIR = os.path.dirname(__file__) +PROJECT_ROOT = os.path.abspath(os.path.join(CURRENT_DIR, "../../..")) +sys.path.insert(0, PROJECT_ROOT) + +from module_utils.cert_utils import CertUtils + + +def _san_block(*entries): + """ + Helper: builds a minimal OpenSSL text snippet that contains SAN entries. + Example: _san_block('example.com', '*.example.com') + """ + sans = ", ".join(f"DNS:{e}" for e in entries) + return f""" +Certificate: + Data: + Version: 3 (0x2) + ... + X509v3 extensions: + X509v3 Subject Alternative Name: + {sans} + """ + + +class TestCertUtilsFindNewest(unittest.TestCase): + def setUp(self): + # Reset internal caches before each test + CertUtils._domain_cert_mapping = None + CertUtils._cert_snapshot = None + + def _mock_stat_map(self, mtime_map, size_map=None): + size_map = size_map or {} + def _stat_side_effect(path): + return SimpleNamespace( + st_mtime=mtime_map.get(path, 0.0), + st_size=size_map.get(path, 1234), + ) + return _stat_side_effect + + def test_prefers_newest_by_not_before(self): + """ + Two certs with the same SAN 'www.example.com': + - a/cert.pem: older notBefore + - b/cert.pem: newer notBefore -> should be selected + """ + files = [ + "/etc/letsencrypt/live/a/cert.pem", + "/etc/letsencrypt/live/b/cert.pem", + ] + san_text = _san_block("www.example.com") + + with patch.object(CertUtils, "list_cert_files", return_value=files), \ + patch.object(CertUtils, "run_openssl", return_value=san_text), \ + patch.object(CertUtils, "run_openssl_dates") as mock_dates, \ + patch("os.stat", side_effect=self._mock_stat_map({ + files[0]: 1000, + files[1]: 1001, + })): + + mock_dates.side_effect = [(10, 100000), (20, 100000)] # older/newer + + folder = CertUtils.find_cert_for_domain("www.example.com", "/etc/letsencrypt/live", debug=False) + self.assertEqual(folder, "b", "Should return the folder with the newest notBefore date.") + + def test_fallback_to_mtime_when_not_before_missing(self): + """ + When not_before is missing, mtime should be used as a fallback. + """ + files = [ + "/etc/letsencrypt/live/a/cert.pem", + "/etc/letsencrypt/live/b/cert.pem", + ] + san_text = _san_block("www.example.com") + + with patch.object(CertUtils, "list_cert_files", return_value=files), \ + patch.object(CertUtils, "run_openssl", return_value=san_text), \ + patch.object(CertUtils, "run_openssl_dates", return_value=(None, None)), \ + patch("os.stat", side_effect=self._mock_stat_map({ + files[0]: 1000, + files[1]: 2000, + })): + + folder = CertUtils.find_cert_for_domain("www.example.com", "/etc/letsencrypt/live", debug=False) + self.assertEqual(folder, "b", "Should fall back to mtime and select the newest file.") + + def test_exact_beats_wildcard_even_if_wildcard_newer(self): + """ + Exact matches must take precedence over wildcard matches, + even if the wildcard certificate is newer. + """ + files = [ + "/etc/letsencrypt/live/exact/cert.pem", + "/etc/letsencrypt/live/wild/cert.pem", + ] + text_exact = _san_block("api.example.com") + text_wild = _san_block("*.example.com") + + with patch.object(CertUtils, "list_cert_files", return_value=files), \ + patch.object(CertUtils, "run_openssl") as mock_text, \ + patch.object(CertUtils, "run_openssl_dates") as mock_dates, \ + patch("os.stat", side_effect=self._mock_stat_map({ + files[0]: 1000, # exact is older + files[1]: 5000, # wildcard is much newer + })): + + mock_text.side_effect = [text_exact, text_wild] + mock_dates.side_effect = [(10, 100000), (99, 100000)] + + folder = CertUtils.find_cert_for_domain("api.example.com", "/etc/letsencrypt/live", debug=False) + self.assertEqual( + folder, "exact", + "Exact match must win even if the wildcard certificate is newer." + ) + + def test_wildcard_one_label_only(self): + """ + Wildcards (*.example.com) must only match one additional label. + """ + files = ["/etc/letsencrypt/live/wild/cert.pem"] + text_wild = _san_block("*.example.com") + + with patch.object(CertUtils, "list_cert_files", return_value=files), \ + patch.object(CertUtils, "run_openssl", return_value=text_wild), \ + patch.object(CertUtils, "run_openssl_dates", return_value=(50, 100000)), \ + patch("os.stat", side_effect=self._mock_stat_map({files[0]: 1000})): + + # should match + self.assertEqual( + CertUtils.find_cert_for_domain("api.example.com", "/etc/letsencrypt/live"), + "wild" + ) + # too deep -> should not match + self.assertIsNone( + CertUtils.find_cert_for_domain("deep.api.example.com", "/etc/letsencrypt/live"), + "Wildcard must not match multiple labels." + ) + # base domain not covered + self.assertIsNone( + CertUtils.find_cert_for_domain("example.com", "/etc/letsencrypt/live"), + "Base domain is not covered by *.example.com." + ) + + def test_snapshot_refresh_rebuilds_mapping(self): + """ + ensure_cert_mapping() should rebuild mapping when snapshot changes. + """ + CertUtils._domain_cert_mapping = {"www.example.com": [{"folder": "old", "mtime": 1, "not_before": 1}]} + + with patch.object(CertUtils, "snapshot_changed", return_value=True), \ + patch.object(CertUtils, "refresh_cert_mapping") as mock_refresh: + + def _set_new_mapping(cert_base_path, debug=False): + CertUtils._domain_cert_mapping = { + "www.example.com": [{"folder": "new", "mtime": 999, "not_before": 999}] + } + + mock_refresh.side_effect = _set_new_mapping + + folder = CertUtils.find_cert_for_domain("www.example.com", "/etc/letsencrypt/live", debug=False) + self.assertEqual(folder, "new", "Mapping must be refreshed when snapshot changes.") + + +if __name__ == "__main__": + unittest.main()