diff --git a/library/cert_check_exists.py b/library/cert_check_exists.py index e6d34ff0..56237f14 100644 --- a/library/cert_check_exists.py +++ b/library/cert_check_exists.py @@ -1,36 +1,7 @@ -#!/usr/bin/python -from __future__ import absolute_import, division, print_function -__metaclass__ = type - -import os from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.cert_utils import CertUtils -def cert_exists(domain, cert_files, debug=False): - for cert_path in cert_files: - cert_text = CertUtils.run_openssl(cert_path) - if not cert_text: - continue - sans = CertUtils.extract_sans(cert_text) - if debug: - print(f"Checking {cert_path}: {sans}") - for entry in sans: - if CertUtils.matches(domain, entry): - return True - return False - -def cert_check_exists(module): - domain = module.params['domain'] - cert_base_path = module.params['cert_base_path'] - debug = module.params['debug'] - - cert_files = CertUtils.list_cert_files(cert_base_path) - - exists = cert_exists(domain, cert_files, debug) - - module.exit_json(exists=exists) - def main(): module_args = dict( domain=dict(type='str', required=True), @@ -39,11 +10,17 @@ def main(): ) module = AnsibleModule( - argument_spec=module_args, - supports_check_mode=True + argument_spec=module_args ) - cert_check_exists(module) + domain = module.params['domain'] + cert_base_path = module.params['cert_base_path'] + debug = module.params['debug'] + + folder = CertUtils.find_cert_for_domain(domain, cert_base_path, debug) + exists = folder is not None + + module.exit_json(exists=exists) if __name__ == '__main__': - main() + main() \ No newline at end of file diff --git a/library/cert_folder_find.py b/library/cert_folder_find.py index ff0bd633..3d15edc8 100644 --- a/library/cert_folder_find.py +++ b/library/cert_folder_find.py @@ -1,48 +1,6 @@ -#!/usr/bin/python - -from __future__ import absolute_import, division, print_function -__metaclass__ = type - -import os from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.cert_utils import CertUtils -def cert_folder_find(module): - domain = module.params['domain'] - cert_base_path = module.params['cert_base_path'] - debug = module.params['debug'] - - cert_files = CertUtils.list_cert_files(cert_base_path) - - if debug: - print(f"Found {len(cert_files)} cert.pem files under {cert_base_path}") - - matching_folders = [] - - for cert_path in cert_files: - cert_text = CertUtils.run_openssl(cert_path) - if not cert_text: - continue - sans = CertUtils.extract_sans(cert_text) - if debug: - print(f"Checking {cert_path}: {sans}") - for entry in sans: - if CertUtils.matches(domain, entry): - folder = os.path.basename(os.path.dirname(cert_path)) - matching_folders.append(folder) - if debug: - print(f"Match found in folder: {folder}") - break # No need to check further SANs for this cert - - if not matching_folders: - # No matching cert found - module.exit_json(folder=None) - - # Prefer shortest and least-dashed folder name (SAN bundles often have more dashes) - matching_folders = sorted(matching_folders, key=lambda f: (f.count('-'), len(f))) - - module.exit_json(folder=matching_folders[0]) - def main(): module_args = dict( domain=dict(type='str', required=True), @@ -51,11 +9,19 @@ def main(): ) module = AnsibleModule( - argument_spec=module_args, - supports_check_mode=True + argument_spec=module_args ) - cert_folder_find(module) + domain = module.params['domain'] + cert_base_path = module.params['cert_base_path'] + debug = module.params['debug'] + + folder = CertUtils.find_cert_for_domain(domain, cert_base_path, debug) + + if folder is None: + module.fail_json(msg=f"No certificate covering domain {domain} found.") + else: + module.exit_json(folder=folder) if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/main.py b/main.py index b4d06cfd..73f1e01c 100755 --- a/main.py +++ b/main.py @@ -9,7 +9,7 @@ def run_ansible_vault(action, filename, password_file): cmd = ["ansible-vault", action, filename, "--vault-password-file", password_file] subprocess.run(cmd, check=True) -def run_ansible_playbook(inventory: str, playbook: str, modes: dict, limit: str = None, password_file: str = None, verbose: int = 0): +def run_ansible_playbook(inventory: str, playbook: str, modes: dict, limit: str = None, password_file: str = None, verbose: int = 0, skip_tests: bool = False): """Execute an ansible-playbook command with optional parameters.""" cmd = ["ansible-playbook", "-i", inventory, playbook] @@ -18,7 +18,6 @@ def run_ansible_playbook(inventory: str, playbook: str, modes: dict, limit: str if modes: for key, value in modes.items(): - # Convert boolean values to lowercase strings arg_value = f"{str(value).lower()}" if isinstance(value, bool) else f"{value}" cmd.extend(["-e", f"{key}={arg_value}"]) @@ -28,9 +27,12 @@ def run_ansible_playbook(inventory: str, playbook: str, modes: dict, limit: str cmd.extend(["--ask-vault-pass"]) if verbose: - # Append a single flag with multiple "v"s (e.g. -vvv) cmd.append("-" + "v" * verbose) - subprocess.run(['make','build'], check=True) + + if not skip_tests: + subprocess.run(["make", "test"], check=True) + + subprocess.run(["make", "build"], check=True) subprocess.run(cmd, check=True) def main(): @@ -60,6 +62,7 @@ def main(): playbook_parser.add_argument("--cleanup", action="store_true", help="Enable cleanup mode") playbook_parser.add_argument("--debug", action="store_true", help="Enable debugging output") playbook_parser.add_argument("--password-file", help="Path to the Vault password file") + playbook_parser.add_argument("--skip-tests", action="store_true", help="Skip running make test before executing the playbook") playbook_parser.add_argument("-v", "--verbose", action="count", default=0, help=("Increase verbosity. This option can be specified multiple times " "to increase the verbosity level (e.g., -vvv for more detailed debug output).")) @@ -79,8 +82,15 @@ def main(): "host_type": args.host_type } - # Use a fixed playbook file "playbook.yml" - run_ansible_playbook(args.inventory, f"{script_dir}/playbook.yml", modes, args.limit, args.password_file, args.verbose) + run_ansible_playbook( + inventory=args.inventory, + playbook=f"{script_dir}/playbook.yml", + modes=modes, + limit=args.limit, + password_file=args.password_file, + verbose=args.verbose, + skip_tests=args.skip_tests + ) if __name__ == "__main__": main() diff --git a/module_utils/__init__.py b/module_utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/module_utils/cert_utils.py b/module_utils/cert_utils.py index c2d711e8..9d9a9748 100644 --- a/module_utils/cert_utils.py +++ b/module_utils/cert_utils.py @@ -1,9 +1,16 @@ #!/usr/bin/python +from __future__ import absolute_import, division, print_function +__metaclass__ = type + import os import subprocess +import time class CertUtils: + _domain_cert_mapping = None + _cert_snapshot = None + @staticmethod def run_openssl(cert_path): try: @@ -40,12 +47,93 @@ class CertUtils: @staticmethod def matches(domain, san): - """Check if the SAN entry matches the domain according to wildcard rules.""" if san.startswith('*.'): base = san[2:] - # Check if domain is direct subdomain (one label only) - if domain.count('.') == base.count('.') + 1 and domain.endswith('.' + base): - return True + # Wildcard does NOT cover the base domain itself + if domain == base: + return False + if domain.endswith('.' + base): + # Check if the domain has exactly one label more than the base + domain_labels = domain.split('.') + base_labels = base.split('.') + if len(domain_labels) == len(base_labels) + 1: + return True return False else: + # Exact match required for non-wildcard SAN entries return domain == san + + + @classmethod + def build_snapshot(cls, cert_base_path): + snapshot = [] + for cert_file in cls.list_cert_files(cert_base_path): + try: + stat = os.stat(cert_file) + snapshot.append((cert_file, stat.st_mtime, stat.st_size)) + except FileNotFoundError: + continue + snapshot.sort() + return snapshot + + @classmethod + def snapshot_changed(cls, cert_base_path): + current_snapshot = cls.build_snapshot(cert_base_path) + if cls._cert_snapshot != current_snapshot: + cls._cert_snapshot = current_snapshot + return True + return False + + @classmethod + def refresh_cert_mapping(cls, cert_base_path, debug=False): + cert_files = cls.list_cert_files(cert_base_path) + mapping = {} + for cert_path in cert_files: + cert_text = cls.run_openssl(cert_path) + if not cert_text: + continue + sans = cls.extract_sans(cert_text) + folder = os.path.basename(os.path.dirname(cert_path)) + for san in sans: + if san not in mapping: + mapping[san] = folder + cls._domain_cert_mapping = mapping + if debug: + print(f"[DEBUG] Refreshed domain-to-cert mapping: {mapping}") + + @classmethod + def ensure_cert_mapping(cls, cert_base_path, debug=False): + if cls._domain_cert_mapping is None or cls.snapshot_changed(cert_base_path): + cls.refresh_cert_mapping(cert_base_path, debug) + + @classmethod + def find_cert_for_domain(cls, domain, cert_base_path, debug=False): + cls.ensure_cert_mapping(cert_base_path, debug) + + exact_match = None + wildcard_match = None + + for san, folder in cls._domain_cert_mapping.items(): + if san == domain: + exact_match = folder + break + if san.startswith('*.'): + base = san[2:] + if domain.count('.') == base.count('.') + 1 and domain.endswith('.' + base): + wildcard_match = folder + + if exact_match: + if debug: + print(f"[DEBUG] Exact match for {domain} found in {exact_match}") + return exact_match + + if wildcard_match: + if debug: + print(f"[DEBUG] Wildcard match for {domain} found in {wildcard_match}") + return wildcard_match + + if debug: + print(f"[DEBUG] No certificate folder found for {domain}") + + return None + diff --git a/roles/nginx-https-get-cert/tasks/flavors/san.yml b/roles/nginx-https-get-cert/tasks/flavors/san.yml index 685ad684..1fda3626 100644 --- a/roles/nginx-https-get-cert/tasks/flavors/san.yml +++ b/roles/nginx-https-get-cert/tasks/flavors/san.yml @@ -20,6 +20,7 @@ {% endif %} {{ '--mode-test' if mode_test | bool else '' }} register: certbundle_result + changed_when: "'Certificate not yet due for renewal' not in certbundle_result.stdout" when: run_once_san_certs is not defined - name: run the san tasks once diff --git a/roles/nginx-https-get-cert/tasks/main.yml b/roles/nginx-https-get-cert/tasks/main.yml index 5cdcb153..99cd234a 100644 --- a/roles/nginx-https-get-cert/tasks/main.yml +++ b/roles/nginx-https-get-cert/tasks/main.yml @@ -27,4 +27,9 @@ - name: Set fact set_fact: - ssl_cert_folder: "{{ cert_folder_result.folder }}" \ No newline at end of file + ssl_cert_folder: "{{ cert_folder_result.folder }}" + +- name: Ensure ssl_cert_folder is set + fail: + msg: "No certificate folder found for domain {{ domain }}" + when: ssl_cert_folder is undefined or ssl_cert_folder is none \ No newline at end of file diff --git a/tests/unit/test_cert_utils.py b/tests/unit/test_cert_utils.py new file mode 100644 index 00000000..3c117592 --- /dev/null +++ b/tests/unit/test_cert_utils.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 + +import os +import sys + +# Add module_utils/ to the import path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..", "module_utils"))) + +from module_utils.cert_utils import CertUtils + +def test_matches(): + tests = [ + # Exact matches + ("example.com", "example.com", True), + ("www.example.com", "www.example.com", True), + ("api.example.com", "api.example.com", True), + + # Wildcard matches + ("sub.example.com", "*.example.com", True), + ("www.example.com", "*.example.com", True), + + # Wildcard non-matches + ("example.com", "*.example.com", False), # base domain is not covered + ("deep.sub.example.com", "*.example.com", False), # too deep + ("sub.deep.example.com", "*.deep.example.com", True), # correct: one level below + + # Special cases + ("deep.api.example.com", "*.api.example.com", True), + ("api.example.com", "*.api.example.com", False), # base not covered by wildcard + + # Completely different domains + ("test.other.com", "*.example.com", False), + ] + + passed = 0 + failed = 0 + + for domain, san, expected in tests: + result = CertUtils.matches(domain, san) + if result == expected: + print(f"✅ PASS: {domain} vs {san} -> {result}") + passed += 1 + else: + print(f"❌ FAIL: {domain} vs {san} -> {result} (expected {expected})") + failed += 1 + + print(f"\nSummary: {passed} passed, {failed} failed") + +if __name__ == "__main__": + test_matches()