diff --git a/filter_plugins/timeout_start_sec_for_domains.py b/filter_plugins/timeout_start_sec_for_domains.py index fba98dc2..144dedfa 100644 --- a/filter_plugins/timeout_start_sec_for_domains.py +++ b/filter_plugins/timeout_start_sec_for_domains.py @@ -1,12 +1,7 @@ +# filter_plugins/timeout_start_sec_for_domains.py (nur Kern geƤndert) from ansible.errors import AnsibleFilterError class FilterModule(object): - """ - Compute a max TimeoutStartSec for systemd services that iterate over many domains. - The timeout scales with the number of unique domains (optionally including www.* clones) - and is clamped between configurable min/max bounds. - """ - def filters(self): return { "timeout_start_sec_for_domains": self.timeout_start_sec_for_domains, @@ -23,28 +18,17 @@ class FilterModule(object): ): """ Args: - domains_dict (dict): Same structure you pass to generate_all_domains - (values can be str | list[str] | dict[str,str]). - include_www (bool): If true, also count "www." variants. - per_domain_seconds (int): Budget per domain (default 25s). - overhead_seconds (int): Fixed overhead on top (default 30s). - min_seconds (int): Lower clamp (default 120s). - max_seconds (int): Upper clamp (default 3600s). - - Returns: - int: TimeoutStartSec in seconds (integer). - - Raises: - AnsibleFilterError: On invalid input types or unexpected failures. + domains_dict (dict | list[str] | str): Either the domain mapping dict + (values can be str | list[str] | dict[str,str]) or an already + flattened list of domains, or a single domain string. + include_www (bool): If true, add 'www.' for non-www entries. + ... """ try: - if not isinstance(domains_dict, dict): - raise AnsibleFilterError("Expected 'domains_dict' to be a dict.") - - # Local flatten similar to your generate_all_domains - def _flatten(domains): + # Local flattener for dict inputs (like your generate_all_domains source) + def _flatten_from_dict(domains_map): flat = [] - for v in (domains or {}).values(): + for v in (domains_map or {}).values(): if isinstance(v, str): flat.append(v) elif isinstance(v, list): @@ -53,18 +37,26 @@ class FilterModule(object): flat.extend(v.values()) return flat - flat = _flatten(domains_dict) + # Accept dict | list | str + if isinstance(domains_dict, dict): + flat = _flatten_from_dict(domains_dict) + elif isinstance(domains_dict, list): + flat = list(domains_dict) + elif isinstance(domains_dict, str): + flat = [domains_dict] + else: + raise AnsibleFilterError( + "Expected 'domains_dict' to be dict | list | str." + ) if include_www: - # dedupe first so we don't generate duplicate www-variants base_unique = sorted(set(flat)) - www_variants = [f"www.{d}" for d in base_unique if not str(d).startswith("www.")] + www_variants = [f"www.{d}" for d in base_unique if not str(d).lower().startswith("www.")] flat.extend(www_variants) unique_domains = sorted(set(flat)) count = len(unique_domains) - # Compute and clamp raw = overhead_seconds + per_domain_seconds * count clamped = max(min_seconds, min(max_seconds, int(raw))) return clamped diff --git a/tests/unit/filter_plugins/test_timeout_start_sec_for_domains.py b/tests/unit/filter_plugins/test_timeout_start_sec_for_domains.py index d9a1c6ad..c75fa794 100644 --- a/tests/unit/filter_plugins/test_timeout_start_sec_for_domains.py +++ b/tests/unit/filter_plugins/test_timeout_start_sec_for_domains.py @@ -78,10 +78,28 @@ class TestTimeoutStartSecForDomains(unittest.TestCase): # raw = 30 + 25*4 = 130 self.assertEqual(result, 130) - def test_raises_on_non_dict_input(self): + def test_raises_on_invalid_type_int(self): with self.assertRaises(AnsibleFilterError): - _f()(["not-a-dict"]) + _f()(123) + def test_raises_on_invalid_type_none(self): + with self.assertRaises(AnsibleFilterError): + _f()(None) + + def test_accepts_list_input(self): + domains_list = ["a.com", "www.a.com", "b.com"] + result = _f()(domains_list, include_www=True, + per_domain_seconds=25, overhead_seconds=30, + min_seconds=1, max_seconds=10000) + # unique + www for b.com -> {"a.com","www.a.com","b.com","www.b.com"} = 4 + self.assertEqual(result, 30 + 25*4) + + def test_accepts_str_input(self): + result = _f()("a.com", include_www=True, + per_domain_seconds=25, overhead_seconds=30, + min_seconds=1, max_seconds=10000) + # {"a.com","www.a.com"} = 2 + self.assertEqual(result, 30 + 25*2) if __name__ == "__main__": unittest.main()