# tests/integration/test_when_condition_duplication.py
"""Integration test guarding against over-repeated Ansible ``when`` conditions.

Every YAML file found under a ``tasks`` directory of the repository is parsed,
the ``when`` condition of each task-like mapping is normalized to a single-line
string, and the test fails if one tasks file uses the same normalized condition
on more than ``THRESHOLD`` tasks.
"""
import os  # noqa: F401  (pre-existing import, kept)
import unittest
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

try:
    import yaml  # PyYAML
except ImportError:
    # Do not raise SystemExit here: exiting at import time can take down the
    # whole test session instead of reporting a failure for this one test.
    # The test method fails with _PYYAML_REQUIRED_MSG when yaml is missing.
    yaml = None

# This module exports nothing. An empty __all__ also keeps
# `from <this module> import *` from copying the TestCase into another test
# module, where a test runner would collect it a second time.
__all__: List[str] = []

_PYYAML_REQUIRED_MSG = (
    "PyYAML is required to run this test. Install with: pip install pyyaml"
)

THRESHOLD = 3  # fail if the same when-condition occurs on more than this many tasks

# Keys whose values hold nested task lists inside an Ansible block.
_BLOCK_SECTIONS = ("block", "rescue", "always")

# Keys identifying an include/import task; also used as the location hint.
_INCLUDE_KEYS = (
    "include_tasks",
    "import_tasks",
    "ansible.builtin.include_tasks",
    "ansible.builtin.import_tasks",
)

# A mapping containing any of these keys is treated as "task-like".
_TASK_MARKER_KEYS = ("when", "name") + _BLOCK_SECTIONS + _INCLUDE_KEYS

# Glob patterns (relative to the repo root) selecting tasks files.
_TASKS_GLOBS = (
    "**/tasks/**/*.yml",
    "**/tasks/**/*.yaml",
)

# Directories that hold VCS metadata or third-party code rather than project
# tasks; anything below them is ignored by the scan.
_EXCLUDED_DIR_NAMES = frozenset({".git", ".venv", "venv", ".tox", "node_modules"})


def _find_repo_root_containing(
    marker_names: Iterable[str],
    max_up: int = 8,
    start: Optional[Path] = None,
) -> Path:
    """Walk upwards from ``start`` to find the repository root.

    Args:
        marker_names: Names (e.g. 'roles', '.git', 'playbooks'); the first
            directory containing at least one of them is returned.
        max_up: Maximum number of directories to inspect, counting ``start``
            itself.
        start: Directory to begin the search in. Defaults to the directory
            containing this file.

    Returns:
        The nearest ancestor (or ``start`` itself) containing a marker. If none
        is found, falls back to the directory two levels above ``start``, which
        is the repo root for a file living in ``<root>/tests/integration/``.
    """
    origin = Path(__file__).resolve().parent if start is None else Path(start).resolve()
    # Materialize once: the markers are checked again at every level, which
    # would silently exhaust a generator after the first directory.
    markers = tuple(marker_names)

    cur = origin
    for _ in range(max_up):
        if any((cur / marker).exists() for marker in markers):
            return cur
        if cur.parent == cur:  # reached the filesystem root
            break
        cur = cur.parent

    # Fallback: <root>/tests/integration -> parents[0] is <root>/tests and
    # parents[1] is <root>. Clamp so a shallow path cannot raise IndexError.
    ancestors = origin.parents
    if len(ancestors) >= 2:
        return ancestors[1]
    return ancestors[-1] if len(ancestors) else origin


def _collapse_whitespace(value: Any) -> str:
    """Return ``str(value)`` stripped, with whitespace runs collapsed; None -> ''."""
    if value is None:
        return ""
    return " ".join(str(value).split())


def _normalize_when(value: Any) -> str:
    """Normalize an Ansible 'when' value to a comparable single-line string.

    - A list is joined with ' && ', preserving order.
    - Any other value (str, int, bool, ...) is converted with ``str``.
    - Leading/trailing whitespace is removed and internal whitespace runs are
      collapsed to one space, so formatting differences do not matter.
    - ``None`` becomes the empty string.
    """
    if isinstance(value, list):
        return " && ".join(_collapse_whitespace(item) for item in value)
    return _collapse_whitespace(value)


def _iter_tasks(node: Any) -> Iterable[Dict[str, Any]]:
    """Yield task-like dicts from an arbitrary parsed-YAML structure.

    A dict is task-like when it contains any key in ``_TASK_MARKER_KEYS``
    ('when', 'name', a block section, or an include/import key). Traversal:

    - lists are walked element by element;
    - a task-like dict is yielded first, then its 'block'/'rescue'/'always'
      lists are walked, then every other list/dict value is walked.

    Because all nested mappings are walked, a nested dict that merely has a
    'name' key (for instance module arguments) is yielded as well; callers
    filter on the presence of 'when'. Each dict is yielded at most once.
    """
    if isinstance(node, list):
        for item in node:
            yield from _iter_tasks(item)
        return
    if not isinstance(node, dict):
        return

    if any(key in node for key in _TASK_MARKER_KEYS):
        yield node

    # Block sections are only descended into when they are lists.
    for section in _BLOCK_SECTIONS:
        children = node.get(section)
        if isinstance(children, list):
            yield from _iter_tasks(children)

    # Everything else is traversed conservatively.
    for key, value in node.items():
        if key not in _BLOCK_SECTIONS and isinstance(value, (list, dict)):
            yield from _iter_tasks(value)


def _load_yaml_documents(path: Path) -> List[Any]:
    """Load all YAML documents from ``path`` (read as UTF-8).

    Parsing errors are not swallowed: a file that cannot be parsed should be
    fixed in the repo. An empty document is returned as ``None``.

    Raises:
        RuntimeError: If PyYAML is not installed.
    """
    if yaml is None:
        raise RuntimeError(_PYYAML_REQUIRED_MSG)
    text = path.read_text(encoding="utf-8")
    return list(yaml.safe_load_all(text))


def _collect_when_counts(yaml_docs: List[Any]) -> Dict[str, List[Tuple[str, str]]]:
    """Group tasks by their normalized 'when' condition.

    Returns:
        Mapping ``normalized_when -> [(task_name, hint_location), ...]`` with
        one entry per task that carries a non-empty 'when'. ``task_name`` is
        '' for unnamed tasks; ``hint_location`` is ``"<key>: <value>"`` for an
        include/import task and ``"task"`` otherwise.
    """
    counts: Dict[str, List[Tuple[str, str]]] = {}
    for doc in yaml_docs:
        for task in _iter_tasks(doc):
            # A missing 'when' normalizes to '' exactly like an empty one.
            normalized = _normalize_when(task.get("when"))
            if not normalized:
                continue
            task_name = str(task.get("name") or "")
            hint_loc = next(
                (f"{key}: {task[key]}" for key in _INCLUDE_KEYS if key in task),
                "task",
            )
            counts.setdefault(normalized, []).append((task_name, hint_loc))
    return counts


def _iter_task_files(repo_root: Path) -> List[Path]:
    """Return the tasks files below ``repo_root`` in a stable, sorted order.

    Files are selected with ``_TASKS_GLOBS``. Paths with a directory component
    listed in ``_EXCLUDED_DIR_NAMES`` are skipped so that VCS metadata and
    third-party code checked out inside the workspace are not linted.
    """
    found = set()
    for pattern in _TASKS_GLOBS:
        for path in repo_root.glob(pattern):
            if not path.is_file():
                continue
            parent_parts = path.relative_to(repo_root).parts[:-1]
            if any(part in _EXCLUDED_DIR_NAMES for part in parent_parts):
                continue
            found.add(path)
    # Glob order depends on the filesystem; sort for reproducible reports.
    return sorted(found)


def _format_violation(
    path: Path,
    normalized_when: str,
    occurrences: List[Tuple[str, str]],
    threshold: int = THRESHOLD,
) -> str:
    """Build the failure text for one over-used condition in one file.

    At most the first five occurrences are listed as samples.
    """
    sample = "\n".join(f" - {tname} ({hint})" for tname, hint in occurrences[:5])
    return (
        f"{path} uses the same 'when' condition more than {threshold} times "
        f"({len(occurrences)} occurrences):\n"
        f" WHEN: {normalized_when}\n"
        f" Sample tasks:\n{sample}\n"
        f"Suggestion: Group these tasks into a separate file and call it with "
        f"`include_tasks`, or use a single `block` guarded by this 'when' to avoid "
        f"re-evaluating the condition repeatedly."
    )


class WhenConditionDuplicationTest(unittest.TestCase):
    """Ensure a tasks file does not repeat one 'when' condition too often.

    Rationale:
        Repeating an identical 'when' across many tasks makes Ansible evaluate
        the same condition over and over. Prefer factoring those tasks into a
        dedicated file called with `include_tasks`, or a block with a single
        'when'.

    NOTE(review): per the Ansible documentation a block-level 'when' is
    inherited by every task in the block, so a block mainly removes the
    duplication in the source; confirm whether it also reduces evaluations
    before relying on that part of the suggestion.
    """

    def test_excessive_repeated_when_in_tasks_files(self):
        """Fail if any tasks file uses one condition on more than THRESHOLD tasks."""
        if yaml is None:
            self.fail(_PYYAML_REQUIRED_MSG)

        repo_root = _find_repo_root_containing(
            marker_names=(".git", "roles", "playbooks")
        )
        violations: List[str] = []

        for path in _iter_task_files(repo_root):
            try:
                docs = _load_yaml_documents(path)
            except Exception as exc:
                # Any unreadable/unparsable tasks file is itself a failure.
                self.fail(f"Failed to parse YAML file: {path}\n{exc}")

            counts = _collect_when_counts(docs)
            violations.extend(
                _format_violation(path, normalized_when, occurrences)
                for normalized_when, occurrences in counts.items()
                if len(occurrences) > THRESHOLD
            )

        if violations:
            self.fail(
                "Excessive duplicate 'when' conditions detected (hurts performance):\n\n"
                + "\n\n".join(violations)
            )


if __name__ == "__main__":
    unittest.main()