Improved performance by avoiding the load of roles which are anyhow just protected by one condition

This commit is contained in:
Kevin Veen-Birkenbach 2025-08-11 13:52:24 +02:00
parent c7badc608a
commit 5709935c92
No known key found for this signature in database
GPG Key ID: 44D8F11FD62F878E
3 changed files with 327 additions and 1 deletions

View File

@ -7,6 +7,7 @@
public: false
# Explicit set to guaranty that application_id will not be overwritten.
# Should be anyhow the default case
when: run_once_svc_prx_openresty is not defined
- name: "reset (if enabled)"
include_tasks: 02_reset.yml

View File

@ -17,7 +17,9 @@
- name: "Update Docker Images"
include_role:
name: update-docker
when: docker_compose_directory_stat.stat.exists
when:
- docker_compose_directory_stat.stat.exists
- run_once_update_docker is not defined
- name: "Check if yay is installed"
command: which yay

View File

@ -0,0 +1,323 @@
import os
import glob
import unittest
from typing import Any, Dict, List, Tuple, Optional
try:
import yaml
except ImportError: # pragma: no cover
raise SystemExit("Please `pip install pyyaml` to run this test.")
# ---------- Helpers: repo + YAML parsing ----------
def _find_repo_root_containing(relative: str, max_depth: int = 8) -> str:
"""Walk upwards from this file to find the repo root that contains `relative`."""
here = os.path.abspath(os.path.dirname(__file__))
cur = here
for _ in range(max_depth):
candidate = os.path.join(cur, relative)
if os.path.exists(candidate):
return cur
parent = os.path.dirname(cur)
if parent == cur:
break
cur = parent
raise FileNotFoundError(f"Could not find {relative!r} upwards from {here}")
def _load_yaml_file(path: str) -> List[Dict[str, Any]]:
"""
Load a tasks YAML file.
Returns a list of top-level task dicts. If the file is empty, returns [].
Supports multi-doc YAML.
"""
with open(path, "r", encoding="utf-8") as f:
content = f.read()
docs = list(yaml.safe_load_all(content)) or []
tasks: List[Dict[str, Any]] = []
for doc in docs:
if doc is None:
continue
if isinstance(doc, list):
tasks.extend([t for t in doc if isinstance(t, dict)])
elif isinstance(doc, dict):
if "tasks" in doc and isinstance(doc["tasks"], list):
tasks.extend([t for t in doc["tasks"] if isinstance(t, dict)])
else:
tasks.append(doc)
return tasks
# ---------- Helpers: when / structure checks ----------
def _normalize_when(value: Any) -> List[str]:
"""
Normalize a 'when' value (string | list | bool | None) to a list of strings.
Non-string entries are ignored.
"""
if value is None or isinstance(value, bool):
return []
if isinstance(value, str):
v = value.strip()
return [v] if v else []
if isinstance(value, list):
out: List[str] = []
for item in value:
if isinstance(item, str):
s = item.strip()
if s:
out.append(s)
return out
return []
def _task_has_block_with_when(task: Dict[str, Any]) -> bool:
return "block" in task and bool(_normalize_when(task.get("when")))
def _is_pure_guarded_tasks_file(tasks: List[Dict[str, Any]]) -> Tuple[List[str], bool]:
"""
A "pure guarded" tasks file has EXACTLY ONE top-level task,
that task contains a 'block', and that task has a 'when' condition.
Returns (guard_conditions, is_pure_guarded).
"""
if len(tasks) != 1:
return [], False
only_task = tasks[0]
if not _task_has_block_with_when(only_task):
return [], False
return _normalize_when(only_task.get("when")), True
# ---------- Helpers: discovery ----------
def _iter_all_tasks_files(repo_root: str) -> List[str]:
"""
Return all tasks/*.yml|*.yaml files in the project (recursively).
"""
patterns = [
os.path.join(repo_root, "**", "tasks", "*.yml"),
os.path.join(repo_root, "**", "tasks", "*.yaml"),
]
files: List[str] = []
for pat in patterns:
files.extend(glob.glob(pat, recursive=True))
# Deduplicate while keeping order
seen = set()
ordered: List[str] = []
for p in files:
if p not in seen:
ordered.append(p)
seen.add(p)
return ordered
def _get_include_role_name(task: Dict[str, Any]) -> Optional[str]:
"""
If task is an include_role task, return the role 'name'.
Supports 'include_role' and 'ansible.builtin.include_role'.
"""
for key in ("include_role", "ansible.builtin.include_role"):
if key in task and isinstance(task[key], dict):
role_name = task[key].get("name")
if isinstance(role_name, str) and role_name.strip():
return role_name.strip()
return None
def _get_include_tasks_target(task: Dict[str, Any]) -> Optional[str]:
"""
If task is an include_tasks, return the path string as-is (could be relative).
Supports 'include_tasks' and 'ansible.builtin.include_tasks'.
Returns None if not found or not a string.
"""
for key in ("include_tasks", "ansible.builtin.include_tasks"):
if key in task:
val = task[key]
if isinstance(val, str):
return val.strip()
return None
def _contains_jinja(s: str) -> bool:
return "{{" in s or "{%" in s or "}}" in s or "%}" in s
def _resolve_include_tasks_path(include_value: str, including_file: str) -> Optional[str]:
"""
Resolve an include_tasks path relative to the including file.
If it contains Jinja or does not resolve to an existing file, return None.
Tries exact path, then adds .yml / .yaml if extension missing.
"""
if _contains_jinja(include_value):
return None
# Absolute path?
candidates: List[str] = []
if os.path.isabs(include_value):
candidates.append(include_value)
else:
base = os.path.dirname(including_file)
candidates.append(os.path.join(base, include_value))
final_candidates: List[str] = []
for c in candidates:
final_candidates.append(c)
root, ext = os.path.splitext(c)
if ext == "":
final_candidates.append(root + ".yml")
final_candidates.append(root + ".yaml")
for c in final_candidates:
if os.path.isfile(c):
return c
return None
class PureGuardedIncludeTest(unittest.TestCase):
"""
Enforce short-circuit includes ONLY for "pure guarded" targets:
- Exactly one top-level task
- That task is a 'block'
- That task has a 'when'
Apply to both:
- include_role (roles/<role>/tasks/main.yml must be pure guarded)
- include_tasks (target tasks file must be pure guarded)
"""
@classmethod
def setUpClass(cls):
cls.repo_root = _find_repo_root_containing("roles")
# Map pure guarded roles: role_name -> (guards, main_path)
cls.pure_guarded_roles: Dict[str, Tuple[List[str], str]] = {}
role_main_glob = os.path.join(cls.repo_root, "roles", "*", "tasks", "main.yml")
for main_path in glob.glob(role_main_glob):
role_name = os.path.basename(os.path.dirname(os.path.dirname(main_path))) # roles/<role>/tasks/main.yml
try:
tasks = _load_yaml_file(main_path)
guards, pure = _is_pure_guarded_tasks_file(tasks)
if pure and guards:
cls.pure_guarded_roles[role_name] = (guards, main_path)
except Exception:
# If parsing fails, ignore here; will be caught when scanning all files if relevant
pass
# Cache of parsed tasks files for include_tasks: path -> (guards, pure)
cls.tasks_file_cache: Dict[str, Tuple[List[str], bool]] = {}
# All tasks files across repo
cls.all_tasks_files = _iter_all_tasks_files(cls.repo_root)
# ---------- Tests ----------
def test_include_role_short_circuits_when_target_is_pure_guarded(self):
failures: List[str] = []
if not self.pure_guarded_roles:
self.skipTest("No pure guarded roles found; nothing to validate for include_role.")
for path in self.all_tasks_files:
try:
tasks = _load_yaml_file(path)
except Exception as e:
failures.append(f"[PARSE ERROR] {path}: {e}")
continue
for idx, task in enumerate(tasks):
role_name = _get_include_role_name(task)
if not role_name:
continue
# Only enforce when the included role is pure guarded
role_entry = self.pure_guarded_roles.get(role_name)
if not role_entry:
continue
guards, main_path = role_entry
include_when = _normalize_when(task.get("when"))
if not include_when:
failures.append(
f"{path} (task #{idx+1}) includes role '{role_name}' "
f"but lacks a 'when'. The role is pure guarded by {guards} in {main_path}. "
f"Add at least one of those guard expressions to the include to avoid loading the role unnecessarily."
)
continue
if not any(req in include_when for req in guards):
failures.append(
f"{path} (task #{idx+1}) includes role '{role_name}' but its 'when' "
f"does not contain the role's guard.\n"
f"Role guard(s) from {main_path}: {guards}\n"
f"Include 'when': {include_when}\n"
"Add the role's guard condition to short-circuit when false."
)
if failures:
self.fail("Some include_role calls are missing pure-guard short-circuiting:\n\n" + "\n\n".join(failures))
def test_include_tasks_short_circuits_when_target_is_pure_guarded(self):
failures: List[str] = []
for including_path in self.all_tasks_files:
try:
including_tasks = _load_yaml_file(including_path)
except Exception as e:
failures.append(f"[PARSE ERROR] {including_path}: {e}")
continue
for idx, task in enumerate(including_tasks):
include_value = _get_include_tasks_target(task)
if not include_value:
continue
resolved = _resolve_include_tasks_path(include_value, including_path)
if not resolved:
# Could not resolve (Jinja path or file not found). Skip enforcing.
continue
# Load/inspect included tasks file (with cache)
if resolved not in self.tasks_file_cache:
try:
target_tasks = _load_yaml_file(resolved)
guards, pure = _is_pure_guarded_tasks_file(target_tasks)
self.tasks_file_cache[resolved] = (guards, pure)
except Exception as e:
failures.append(f"[PARSE ERROR] included by {including_path} (task #{idx+1}): {resolved}: {e}")
# mark as non-pure to avoid repeated parsing attempts
self.tasks_file_cache[resolved] = ([], False)
guards, pure = self.tasks_file_cache.get(resolved, ([], False))
if not (pure and guards):
# Only enforce for pure guarded task files
continue
include_when = _normalize_when(task.get("when"))
if not include_when:
failures.append(
f"{including_path} (task #{idx+1}) includes tasks '{include_value}' "
f"-> {resolved}, which is PURE GUARDED by {guards}. "
f"Add at least one of those guard expressions to the include to avoid loading the file unnecessarily."
)
continue
if not any(req in include_when for req in guards):
failures.append(
f"{including_path} (task #{idx+1}) includes tasks '{include_value}' "
f"-> {resolved}, but its 'when' does not contain the target's guard.\n"
f"Target guard(s): {guards}\n"
f"Include 'when': {include_when}\n"
"Add the guard condition to short-circuit when false."
)
if failures:
self.fail("Some include_tasks calls are missing pure-guard short-circuiting:\n\n" + "\n\n".join(failures))
if __name__ == "__main__":
unittest.main()