From 1fcf0722579a36e8de66036f825c2bf4bcde0164 Mon Sep 17 00:00:00 2001 From: Kevin Veen-Birkenbach Date: Mon, 11 Aug 2025 10:50:42 +0200 Subject: [PATCH] Added performance violation test for blocks --- tests/integration/test_block_when_size.py | 169 ++++++++++++++++++++++ 1 file changed, 169 insertions(+) create mode 100644 tests/integration/test_block_when_size.py diff --git a/tests/integration/test_block_when_size.py b/tests/integration/test_block_when_size.py new file mode 100644 index 00000000..4bd5b95a --- /dev/null +++ b/tests/integration/test_block_when_size.py @@ -0,0 +1,169 @@ +# tests/integration/test_block_when_size.py +import unittest +from pathlib import Path +import yaml +from typing import Any, Dict, List, Tuple, Union + +# -------- YAML loader that's tolerant of Ansible-specific tags (e.g. !vault) ----- +class AnsibleTolerantLoader(yaml.SafeLoader): + pass + +def _ansible_tag_passthrough(loader: yaml.Loader, tag_prefix: str, node: yaml.Node): + # Treat unknown/Ansible custom tags as plain YAML + if isinstance(node, yaml.ScalarNode): + return loader.construct_scalar(node) + if isinstance(node, yaml.SequenceNode): + return loader.construct_sequence(node) + if isinstance(node, yaml.MappingNode): + return loader.construct_mapping(node) + return None + +yaml.add_multi_constructor('!', _ansible_tag_passthrough, Loader=AnsibleTolerantLoader) + +# ------------------------------------------------------------------------------- + +def _repo_root() -> Path: + """Find a plausible project root by walking upward from this file.""" + here = Path(__file__).resolve() + cur = here.parent + markers = {'.git', 'ansible.cfg', 'roles', 'playbook.yml'} + for _ in range(7): + if any((cur / m).exists() for m in markers): + return cur + parent = cur.parent + if parent == cur: + break + cur = parent + # Fallback: project/tests/integration/ -> pick parent of 'tests' + return here.parents[2] if len(here.parents) >= 3 else here.parent + +Yaml = Union[Dict[str, Any], List[Any], Any] + +def _iter_yaml_files(root: Path) -> List[Path]: + """Return all *.yml files in the repository (excluding common junk dirs).""" + ignore_dirs = {'.git', '.venv', 'venv', '.tox', '.idea', '.pytest_cache', '__pycache__'} + files: List[Path] = [] + for p in root.rglob('*.yml'): + if any(part in ignore_dirs for part in p.parts): + continue + files.append(p) + return files + +def _safe_load_all(path: Path) -> List[Yaml]: + """Load all YAML documents from a file, tolerating Ansible tags; return list of docs.""" + try: + with path.open('r', encoding='utf-8') as fh: + return list(yaml.load_all(fh, Loader=AnsibleTolerantLoader)) + except Exception: + # If a file is completely unparsable, treat as empty (so test won't crash). + return [] + +def _find_blocks_with_when(node: Yaml, path: str = "") -> List[Tuple[str, Dict[str, Any]]]: + """ + Recursively find mappings that represent an Ansible block with a block-level `when`. + Returns list of (yaml_path, block_mapping). + """ + found: List[Tuple[str, Dict[str, Any]]] = [] + if isinstance(node, dict): + if 'block' in node and 'when' in node and isinstance(node['block'], list): + found.append((path or '/', node)) + # Recurse into all values to catch nested blocks + for k, v in node.items(): + child_path = f"{path}/{k}" if path else f"/{k}" + found.extend(_find_blocks_with_when(v, child_path)) + elif isinstance(node, list): + for i, item in enumerate(node): + child_path = f"{path}[{i}]" + found.extend(_find_blocks_with_when(item, child_path)) + return found + +def _len_if_list(x: Any) -> int: + return len(x) if isinstance(x, list) else 0 + +class BlockWhenSizeTest(unittest.TestCase): + MAX_TASKS = 3 # performance threshold + + def test_blocks_with_when_and_sections_have_max_three_tasks(self): + root = _repo_root() + violations: List[str] = [] + + for yml in _iter_yaml_files(root): + docs = _safe_load_all(yml) + if not docs: + continue + + for doc_idx, doc in enumerate(docs): + blocks = _find_blocks_with_when(doc, path=f"{yml}:{doc_idx}") + for yaml_path, mapping in blocks: + name = mapping.get('name') or '' + when_expr = mapping.get('when') + + # Check main block size + block_tasks = mapping.get('block', []) + block_count = _len_if_list(block_tasks) + if block_count > self.MAX_TASKS: + violations.append( + f"[PERFORMANCE VIOLATION] {yml} :: {name} :: section=block " + f":: tasks={block_count} (> {self.MAX_TASKS}) " + f":: when={when_expr!r} :: at {yaml_path}" + ) + + # Check rescue size (if present) + rescue_tasks = mapping.get('rescue', []) + rescue_count = _len_if_list(rescue_tasks) + if rescue_count > self.MAX_TASKS: + violations.append( + f"[PERFORMANCE VIOLATION] {yml} :: {name} :: section=rescue " + f":: tasks={rescue_count} (> {self.MAX_TASKS}) " + f":: parent-when={when_expr!r} :: at {yaml_path}/rescue" + ) + + # Check always size (if present) + always_tasks = mapping.get('always', []) + always_count = _len_if_list(always_tasks) + if always_count > self.MAX_TASKS: + violations.append( + f"[PERFORMANCE VIOLATION] {yml} :: {name} :: section=always " + f":: tasks={always_count} (> {self.MAX_TASKS}) " + f":: parent-when={when_expr!r} :: at {yaml_path}/always" + ) + + if violations: + self.fail( + "Blocks with a block-level 'when' must contain at most 3 tasks per section " + "('block', 'rescue', 'always') for performance reasons.\n" + "Rationale:\n" + " - A block-level 'when' does NOT prevent parsing of tasks inside the block; " + " all sections ('block', 'rescue', 'always') are parsed and then may be skipped at runtime, " + " causing parse time and module redirect overhead.\n" + "Recommendation:\n" + " - Keep the block structure (for rescue/always, become, vars, etc.), but place a single " + " `include_tasks` INSIDE the respective section and put the condition on that include " + " (where feasible). When the condition is false, the included file is not parsed at all.\n" + "Example:\n" + " - Bad (heavy tasks directly in sections):\n" + " - block:\n" + " - community.general.pacman: name=foo state=present\n" + " - template: src=a.j2 dest=/a\n" + " rescue:\n" + " - debug: msg='rollback-1'\n" + " - debug: msg='rollback-2'\n" + " always:\n" + " - debug: msg='cleanup-1'\n" + " when: feature_enabled\n\n" + " - Good (preserve semantics, avoid parsing when skipped):\n" + " - block:\n" + " - include_tasks: heavy_setup.yml\n" + " when: feature_enabled | bool\n" + " rescue:\n" + " - include_tasks: rollback.yml\n" + " when: feature_enabled | bool # add a guard if rollback is heavy/rare\n" + " always:\n" + " - include_tasks: cleanup.yml\n" + " when: feature_enabled | bool # add a guard if cleanup is heavy/rare\n" + " when: feature_enabled\n\n" + "Violations:\n" + "\n".join(violations) + ) + +if __name__ == "__main__": + unittest.main()