mirror of
				https://github.com/kevinveenbirkenbach/computer-playbook.git
				synced 2025-10-31 10:19:09 +00:00 
			
		
		
		
	Optimized cert speed, testing etc.
This commit is contained in:
		| @@ -1,36 +1,7 @@ | ||||
| #!/usr/bin/python | ||||
|  | ||||
| from __future__ import absolute_import, division, print_function | ||||
| __metaclass__ = type | ||||
|  | ||||
| import os | ||||
| from ansible.module_utils.basic import AnsibleModule | ||||
| from ansible.module_utils.cert_utils import CertUtils | ||||
|  | ||||
| def cert_exists(domain, cert_files, debug=False): | ||||
|     for cert_path in cert_files: | ||||
|         cert_text = CertUtils.run_openssl(cert_path) | ||||
|         if not cert_text: | ||||
|             continue | ||||
|         sans = CertUtils.extract_sans(cert_text) | ||||
|         if debug: | ||||
|             print(f"Checking {cert_path}: {sans}") | ||||
|         for entry in sans: | ||||
|             if CertUtils.matches(domain, entry): | ||||
|                 return True | ||||
|     return False | ||||
|  | ||||
| def cert_check_exists(module): | ||||
|     domain = module.params['domain'] | ||||
|     cert_base_path = module.params['cert_base_path'] | ||||
|     debug = module.params['debug'] | ||||
|  | ||||
|     cert_files = CertUtils.list_cert_files(cert_base_path) | ||||
|  | ||||
|     exists = cert_exists(domain, cert_files, debug) | ||||
|  | ||||
|     module.exit_json(exists=exists) | ||||
|  | ||||
| def main(): | ||||
|     module_args = dict( | ||||
|         domain=dict(type='str', required=True), | ||||
| @@ -39,11 +10,17 @@ def main(): | ||||
|     ) | ||||
|  | ||||
|     module = AnsibleModule( | ||||
|         argument_spec=module_args, | ||||
|         supports_check_mode=True | ||||
|         argument_spec=module_args | ||||
|     ) | ||||
|  | ||||
|     cert_check_exists(module) | ||||
|     domain = module.params['domain'] | ||||
|     cert_base_path = module.params['cert_base_path'] | ||||
|     debug = module.params['debug'] | ||||
|  | ||||
|     folder = CertUtils.find_cert_for_domain(domain, cert_base_path, debug) | ||||
|     exists = folder is not None | ||||
|  | ||||
|     module.exit_json(exists=exists) | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     main() | ||||
|     main() | ||||
| @@ -1,48 +1,6 @@ | ||||
| #!/usr/bin/python | ||||
|  | ||||
| from __future__ import absolute_import, division, print_function | ||||
| __metaclass__ = type | ||||
|  | ||||
| import os | ||||
| from ansible.module_utils.basic import AnsibleModule | ||||
| from ansible.module_utils.cert_utils import CertUtils | ||||
|  | ||||
| def cert_folder_find(module): | ||||
|     domain = module.params['domain'] | ||||
|     cert_base_path = module.params['cert_base_path'] | ||||
|     debug = module.params['debug'] | ||||
|  | ||||
|     cert_files = CertUtils.list_cert_files(cert_base_path) | ||||
|  | ||||
|     if debug: | ||||
|         print(f"Found {len(cert_files)} cert.pem files under {cert_base_path}") | ||||
|  | ||||
|     matching_folders = [] | ||||
|  | ||||
|     for cert_path in cert_files: | ||||
|         cert_text = CertUtils.run_openssl(cert_path) | ||||
|         if not cert_text: | ||||
|             continue | ||||
|         sans = CertUtils.extract_sans(cert_text) | ||||
|         if debug: | ||||
|             print(f"Checking {cert_path}: {sans}") | ||||
|         for entry in sans: | ||||
|             if CertUtils.matches(domain, entry): | ||||
|                 folder = os.path.basename(os.path.dirname(cert_path)) | ||||
|                 matching_folders.append(folder) | ||||
|                 if debug: | ||||
|                     print(f"Match found in folder: {folder}") | ||||
|                 break  # No need to check further SANs for this cert | ||||
|  | ||||
|     if not matching_folders: | ||||
|         # No matching cert found | ||||
|         module.exit_json(folder=None) | ||||
|  | ||||
|     # Prefer shortest and least-dashed folder name (SAN bundles often have more dashes) | ||||
|     matching_folders = sorted(matching_folders, key=lambda f: (f.count('-'), len(f))) | ||||
|  | ||||
|     module.exit_json(folder=matching_folders[0]) | ||||
|  | ||||
| def main(): | ||||
|     module_args = dict( | ||||
|         domain=dict(type='str', required=True), | ||||
| @@ -51,11 +9,19 @@ def main(): | ||||
|     ) | ||||
|  | ||||
|     module = AnsibleModule( | ||||
|         argument_spec=module_args, | ||||
|         supports_check_mode=True | ||||
|         argument_spec=module_args | ||||
|     ) | ||||
|  | ||||
|     cert_folder_find(module) | ||||
|     domain = module.params['domain'] | ||||
|     cert_base_path = module.params['cert_base_path'] | ||||
|     debug = module.params['debug'] | ||||
|  | ||||
|     folder = CertUtils.find_cert_for_domain(domain, cert_base_path, debug) | ||||
|  | ||||
|     if folder is None: | ||||
|         module.fail_json(msg=f"No certificate covering domain {domain} found.") | ||||
|     else: | ||||
|         module.exit_json(folder=folder) | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     main() | ||||
|     main() | ||||
|   | ||||
							
								
								
									
										22
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										22
									
								
								main.py
									
									
									
									
									
								
							| @@ -9,7 +9,7 @@ def run_ansible_vault(action, filename, password_file): | ||||
|     cmd = ["ansible-vault", action, filename, "--vault-password-file", password_file] | ||||
|     subprocess.run(cmd, check=True)  | ||||
|  | ||||
| def run_ansible_playbook(inventory: str, playbook: str, modes: dict, limit: str = None, password_file: str = None, verbose: int = 0): | ||||
| def run_ansible_playbook(inventory: str, playbook: str, modes: dict, limit: str = None, password_file: str = None, verbose: int = 0, skip_tests: bool = False): | ||||
|     """Execute an ansible-playbook command with optional parameters.""" | ||||
|     cmd = ["ansible-playbook", "-i", inventory, playbook] | ||||
|      | ||||
| @@ -18,7 +18,6 @@ def run_ansible_playbook(inventory: str, playbook: str, modes: dict, limit: str | ||||
|      | ||||
|     if modes: | ||||
|         for key, value in modes.items(): | ||||
|             # Convert boolean values to lowercase strings | ||||
|             arg_value = f"{str(value).lower()}" if isinstance(value, bool) else f"{value}" | ||||
|             cmd.extend(["-e", f"{key}={arg_value}"]) | ||||
|      | ||||
| @@ -28,9 +27,12 @@ def run_ansible_playbook(inventory: str, playbook: str, modes: dict, limit: str | ||||
|         cmd.extend(["--ask-vault-pass"]) | ||||
|      | ||||
|     if verbose: | ||||
|         # Append a single flag with multiple "v"s (e.g. -vvv) | ||||
|         cmd.append("-" + "v" * verbose) | ||||
|     subprocess.run(['make','build'], check=True) | ||||
|  | ||||
|     if not skip_tests: | ||||
|         subprocess.run(["make", "test"], check=True) | ||||
|      | ||||
|     subprocess.run(["make", "build"], check=True) | ||||
|     subprocess.run(cmd, check=True) | ||||
|  | ||||
| def main(): | ||||
| @@ -60,6 +62,7 @@ def main(): | ||||
|     playbook_parser.add_argument("--cleanup", action="store_true", help="Enable cleanup mode") | ||||
|     playbook_parser.add_argument("--debug", action="store_true", help="Enable debugging output") | ||||
|     playbook_parser.add_argument("--password-file", help="Path to the Vault password file") | ||||
|     playbook_parser.add_argument("--skip-tests", action="store_true", help="Skip running make test before executing the playbook") | ||||
|     playbook_parser.add_argument("-v", "--verbose", action="count", default=0, | ||||
|                                  help=("Increase verbosity. This option can be specified multiple times " | ||||
|                                        "to increase the verbosity level (e.g., -vvv for more detailed debug output).")) | ||||
| @@ -79,8 +82,15 @@ def main(): | ||||
|             "host_type": args.host_type | ||||
|         } | ||||
|  | ||||
|         # Use a fixed playbook file "playbook.yml" | ||||
|         run_ansible_playbook(args.inventory, f"{script_dir}/playbook.yml", modes, args.limit, args.password_file, args.verbose) | ||||
|         run_ansible_playbook( | ||||
|             inventory=args.inventory, | ||||
|             playbook=f"{script_dir}/playbook.yml", | ||||
|             modes=modes, | ||||
|             limit=args.limit, | ||||
|             password_file=args.password_file, | ||||
|             verbose=args.verbose, | ||||
|             skip_tests=args.skip_tests | ||||
|         ) | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     main() | ||||
|   | ||||
							
								
								
									
										0
									
								
								module_utils/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								module_utils/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -1,9 +1,16 @@ | ||||
| #!/usr/bin/python | ||||
|  | ||||
| from __future__ import absolute_import, division, print_function | ||||
| __metaclass__ = type | ||||
|  | ||||
| import os | ||||
| import subprocess | ||||
| import time | ||||
|  | ||||
| class CertUtils: | ||||
|     _domain_cert_mapping = None | ||||
|     _cert_snapshot = None | ||||
|  | ||||
|     @staticmethod | ||||
|     def run_openssl(cert_path): | ||||
|         try: | ||||
| @@ -40,12 +47,93 @@ class CertUtils: | ||||
|  | ||||
|     @staticmethod | ||||
|     def matches(domain, san): | ||||
|         """Check if the SAN entry matches the domain according to wildcard rules.""" | ||||
|         if san.startswith('*.'): | ||||
|             base = san[2:] | ||||
|             # Check if domain is direct subdomain (one label only) | ||||
|             if domain.count('.') == base.count('.') + 1 and domain.endswith('.' + base): | ||||
|                 return True | ||||
|             # Wildcard does NOT cover the base domain itself | ||||
|             if domain == base: | ||||
|                 return False | ||||
|             if domain.endswith('.' + base): | ||||
|                 # Check if the domain has exactly one label more than the base | ||||
|                 domain_labels = domain.split('.') | ||||
|                 base_labels = base.split('.') | ||||
|                 if len(domain_labels) == len(base_labels) + 1: | ||||
|                     return True | ||||
|             return False | ||||
|         else: | ||||
|             # Exact match required for non-wildcard SAN entries | ||||
|             return domain == san | ||||
|  | ||||
|  | ||||
|     @classmethod | ||||
|     def build_snapshot(cls, cert_base_path): | ||||
|         snapshot = [] | ||||
|         for cert_file in cls.list_cert_files(cert_base_path): | ||||
|             try: | ||||
|                 stat = os.stat(cert_file) | ||||
|                 snapshot.append((cert_file, stat.st_mtime, stat.st_size)) | ||||
|             except FileNotFoundError: | ||||
|                 continue | ||||
|         snapshot.sort() | ||||
|         return snapshot | ||||
|  | ||||
|     @classmethod | ||||
|     def snapshot_changed(cls, cert_base_path): | ||||
|         current_snapshot = cls.build_snapshot(cert_base_path) | ||||
|         if cls._cert_snapshot != current_snapshot: | ||||
|             cls._cert_snapshot = current_snapshot | ||||
|             return True | ||||
|         return False | ||||
|  | ||||
|     @classmethod | ||||
|     def refresh_cert_mapping(cls, cert_base_path, debug=False): | ||||
|         cert_files = cls.list_cert_files(cert_base_path) | ||||
|         mapping = {} | ||||
|         for cert_path in cert_files: | ||||
|             cert_text = cls.run_openssl(cert_path) | ||||
|             if not cert_text: | ||||
|                 continue | ||||
|             sans = cls.extract_sans(cert_text) | ||||
|             folder = os.path.basename(os.path.dirname(cert_path)) | ||||
|             for san in sans: | ||||
|                 if san not in mapping: | ||||
|                     mapping[san] = folder | ||||
|         cls._domain_cert_mapping = mapping | ||||
|         if debug: | ||||
|             print(f"[DEBUG] Refreshed domain-to-cert mapping: {mapping}") | ||||
|  | ||||
|     @classmethod | ||||
|     def ensure_cert_mapping(cls, cert_base_path, debug=False): | ||||
|         if cls._domain_cert_mapping is None or cls.snapshot_changed(cert_base_path): | ||||
|             cls.refresh_cert_mapping(cert_base_path, debug) | ||||
|  | ||||
|     @classmethod | ||||
|     def find_cert_for_domain(cls, domain, cert_base_path, debug=False): | ||||
|         cls.ensure_cert_mapping(cert_base_path, debug) | ||||
|  | ||||
|         exact_match = None | ||||
|         wildcard_match = None | ||||
|  | ||||
|         for san, folder in cls._domain_cert_mapping.items(): | ||||
|             if san == domain: | ||||
|                 exact_match = folder | ||||
|                 break | ||||
|             if san.startswith('*.'): | ||||
|                 base = san[2:] | ||||
|                 if domain.count('.') == base.count('.') + 1 and domain.endswith('.' + base): | ||||
|                     wildcard_match = folder | ||||
|  | ||||
|         if exact_match: | ||||
|             if debug: | ||||
|                 print(f"[DEBUG] Exact match for {domain} found in {exact_match}") | ||||
|             return exact_match | ||||
|  | ||||
|         if wildcard_match: | ||||
|             if debug: | ||||
|                 print(f"[DEBUG] Wildcard match for {domain} found in {wildcard_match}") | ||||
|             return wildcard_match | ||||
|  | ||||
|         if debug: | ||||
|             print(f"[DEBUG] No certificate folder found for {domain}") | ||||
|  | ||||
|         return None | ||||
|  | ||||
|   | ||||
| @@ -20,6 +20,7 @@ | ||||
|     {% endif %} | ||||
|     {{ '--mode-test' if mode_test | bool else '' }} | ||||
|   register: certbundle_result | ||||
|   changed_when: "'Certificate not yet due for renewal' not in certbundle_result.stdout" | ||||
|   when: run_once_san_certs is not defined | ||||
|  | ||||
| - name: run the san tasks once | ||||
|   | ||||
| @@ -27,4 +27,9 @@ | ||||
|  | ||||
| - name: Set fact | ||||
|   set_fact: | ||||
|     ssl_cert_folder: "{{ cert_folder_result.folder }}" | ||||
|     ssl_cert_folder: "{{ cert_folder_result.folder }}" | ||||
|  | ||||
| - name: Ensure ssl_cert_folder is set | ||||
|   fail: | ||||
|     msg: "No certificate folder found for domain {{ domain }}" | ||||
|   when: ssl_cert_folder is undefined or ssl_cert_folder is none | ||||
							
								
								
									
										50
									
								
								tests/unit/test_cert_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								tests/unit/test_cert_utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,50 @@ | ||||
| #!/usr/bin/env python3 | ||||
|  | ||||
| import os | ||||
| import sys | ||||
|  | ||||
| # Add module_utils/ to the import path | ||||
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..", "module_utils"))) | ||||
|  | ||||
| from module_utils.cert_utils import CertUtils | ||||
|  | ||||
| def test_matches(): | ||||
|     tests = [ | ||||
|         # Exact matches | ||||
|         ("example.com", "example.com", True), | ||||
|         ("www.example.com", "www.example.com", True), | ||||
|         ("api.example.com", "api.example.com", True), | ||||
|  | ||||
|         # Wildcard matches | ||||
|         ("sub.example.com", "*.example.com", True), | ||||
|         ("www.example.com", "*.example.com", True), | ||||
|  | ||||
|         # Wildcard non-matches | ||||
|         ("example.com", "*.example.com", False),  # base domain is not covered | ||||
|         ("deep.sub.example.com", "*.example.com", False),  # too deep | ||||
|         ("sub.deep.example.com", "*.deep.example.com", True),  # correct: one level below | ||||
|  | ||||
|         # Special cases | ||||
|         ("deep.api.example.com", "*.api.example.com", True), | ||||
|         ("api.example.com", "*.api.example.com", False),  # base not covered by wildcard | ||||
|  | ||||
|         # Completely different domains | ||||
|         ("test.other.com", "*.example.com", False), | ||||
|     ] | ||||
|  | ||||
|     passed = 0 | ||||
|     failed = 0 | ||||
|  | ||||
|     for domain, san, expected in tests: | ||||
|         result = CertUtils.matches(domain, san) | ||||
|         if result == expected: | ||||
|             print(f"✅ PASS: {domain} vs {san} -> {result}") | ||||
|             passed += 1 | ||||
|         else: | ||||
|             print(f"❌ FAIL: {domain} vs {san} -> {result} (expected {expected})") | ||||
|             failed += 1 | ||||
|  | ||||
|     print(f"\nSummary: {passed} passed, {failed} failed") | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     test_matches() | ||||
		Reference in New Issue
	
	Block a user