Refactored

This commit is contained in:
2025-05-20 12:53:10 +02:00
parent 969a176be1
commit dfb67918c8
13 changed files with 210 additions and 190 deletions

0
cli/utils/__init__.py Normal file
View File

View File

View File

@@ -0,0 +1,50 @@
import subprocess
from typing import Any, Dict
from yaml.loader import SafeLoader
from yaml.dumper import SafeDumper
class VaultScalar(str):
"""A subclass of str to represent vault-encrypted strings."""
pass
def _vault_constructor(loader, node):
"""Custom constructor to handle !vault tag as plain text."""
return node.value
def _vault_representer(dumper, data):
"""Custom representer to dump VaultScalar as literal blocks."""
return dumper.represent_scalar('!vault', data, style='|')
SafeLoader.add_constructor('!vault', _vault_constructor)
SafeDumper.add_representer(VaultScalar, _vault_representer)
class VaultHandler:
def __init__(self, vault_password_file: str):
self.vault_password_file = vault_password_file
def encrypt_string(self, value: str, name: str) -> str:
"""Encrypt a string using ansible-vault."""
cmd = [
"ansible-vault", "encrypt_string",
value, f"--name={name}",
"--vault-password-file", self.vault_password_file
]
proc = subprocess.run(cmd, capture_output=True, text=True)
if proc.returncode != 0:
raise RuntimeError(f"ansible-vault encrypt_string failed:\n{proc.stderr}")
return proc.stdout
def encrypt_leaves(self, branch: Dict[str, Any], vault_pw: str):
"""Recursively encrypt all leaves (plain text values) under the credentials section."""
for key, value in branch.items():
if isinstance(value, dict):
self.encrypt_leaves(value, vault_pw) # Recurse into nested dictionaries
else:
# Skip if already vaulted (i.e., starts with $ANSIBLE_VAULT)
if isinstance(value, str) and not value.lstrip().startswith("$ANSIBLE_VAULT"):
snippet = self.encrypt_string(value, key)
lines = snippet.splitlines()
indent = len(lines[1]) - len(lines[1].lstrip())
body = "\n".join(line[indent:] for line in lines[1:])
branch[key] = VaultScalar(body) # Store encrypted value as VaultScalar

23
cli/utils/handler/yaml.py Normal file
View File

@@ -0,0 +1,23 @@
import yaml
from yaml.loader import SafeLoader
from typing import Any, Dict
from utils.handler.vault import VaultScalar
class YamlHandler:
@staticmethod
def load_yaml(path) -> Dict:
"""Load the YAML file and wrap existing !vault entries."""
text = path.read_text()
data = yaml.load(text, Loader=SafeLoader) or {}
return YamlHandler.wrap_existing_vaults(data)
@staticmethod
def wrap_existing_vaults(node: Any) -> Any:
"""Recursively wrap any str that begins with '$ANSIBLE_VAULT' in a VaultScalar so it dumps as a literal block."""
if isinstance(node, dict):
return {k: YamlHandler.wrap_existing_vaults(v) for k, v in node.items()}
if isinstance(node, list):
return [YamlHandler.wrap_existing_vaults(v) for v in node]
if isinstance(node, str) and node.lstrip().startswith("$ANSIBLE_VAULT"):
return VaultScalar(node)
return node

View File

@@ -0,0 +1,99 @@
import secrets
import hashlib
import bcrypt
from pathlib import Path
from typing import Dict
from utils.handler.yaml import YamlHandler
from utils.handler.vault import VaultHandler, VaultScalar
class InventoryManager:
def __init__(self, role_path: Path, inventory_path: Path, vault_pw: str, overrides: Dict[str, str]):
"""Initialize the Inventory Manager."""
self.role_path = role_path
self.inventory_path = inventory_path
self.vault_pw = vault_pw
self.overrides = overrides
self.inventory = YamlHandler.load_yaml(inventory_path)
self.schema = YamlHandler.load_yaml(role_path / "meta" / "schema.yml")
self.app_id = self.load_application_id(role_path)
self.vault_handler = VaultHandler(vault_pw)
def load_application_id(self, role_path: Path) -> str:
"""Load the application ID from the role's vars/main.yml file."""
vars_file = role_path / "vars" / "main.yml"
data = YamlHandler.load_yaml(vars_file)
app_id = data.get("application_id")
if not app_id:
print(f"ERROR: 'application_id' missing in {vars_file}", file=sys.stderr)
sys.exit(1)
return app_id
def apply_schema(self) -> Dict:
"""Apply the schema and return the updated inventory."""
apps = self.inventory.setdefault("applications", {})
target = apps.setdefault(self.app_id, {})
# Load the data from vars/main.yml
vars_file = self.role_path / "vars" / "main.yml"
data = YamlHandler.load_yaml(vars_file)
# Check if 'central-database' is enabled in the features section of data
if "features" in data and \
"central-database" in data["features"] and \
data["features"]["central_database"]:
# Add 'database_password' to credentials if 'central-database' is True
target.setdefault("credentials", {})["database_password"] = {
"value": self.generate_value("alphanumeric") # Generate the password value
}
self.recurse(self.schema, target)
return self.inventory
def recurse(self, branch: dict, dest: dict, prefix: str = ""):
"""Recursively process the schema and generate values."""
for key, meta in branch.items():
full_key = f"{prefix}.{key}" if prefix else key
if isinstance(meta, dict) and all(k in meta for k in ("description", "algorithm", "validation")):
alg = meta["algorithm"]
if alg == "plain":
# Must be supplied via --set
if full_key not in self.overrides:
print(f"ERROR: Plain algorithm for '{full_key}' requires override via --set {full_key}=<value>", file=sys.stderr)
sys.exit(1)
plain = self.overrides[full_key]
else:
plain = self.overrides.get(full_key, self.generate_value(alg))
snippet = self.vault_handler.encrypt_string(plain, key)
lines = snippet.splitlines()
indent = len(lines[1]) - len(lines[1].lstrip())
body = "\n".join(line[indent:] for line in lines[1:])
dest[key] = VaultScalar(body)
elif isinstance(meta, dict):
sub = dest.setdefault(key, {})
self.recurse(meta, sub, full_key)
else:
dest[key] = meta
def generate_secure_alphanumeric(length: int) -> str:
"""Generate a cryptographically secure random alphanumeric string of the given length."""
characters = string.ascii_letters + string.digits # a-zA-Z0-9
return ''.join(secrets.choice(characters) for _ in range(length))
def generate_value(self, algorithm: str) -> str:
"""Generate a value based on the provided algorithm."""
if algorithm == "random_hex":
return secrets.token_hex(64)
if algorithm == "sha256":
return hashlib.sha256(secrets.token_bytes(32)).hexdigest()
if algorithm == "sha1":
return hashlib.sha1(secrets.token_bytes(20)).hexdigest()
if algorithm == "bcrypt":
pw = secrets.token_urlsafe(16).encode()
return bcrypt.hashpw(pw, bcrypt.gensalt()).decode()
if algorithm == "alphanumeric":
return generate_secure_alphanumeric(64)
return "undefined"