diff --git a/cli/create_credentials.py b/cli/create_credentials.py index 869d36d3..6c9c619b 100644 --- a/cli/create_credentials.py +++ b/cli/create_credentials.py @@ -12,7 +12,9 @@ from yaml.dumper import SafeDumper def ask_for_confirmation(key: str) -> bool: """Prompt the user for confirmation to overwrite an existing value.""" - confirmation = input(f"Are you sure you want to overwrite the value for '{key}'? (y/n): ").strip().lower() + confirmation = input( + f"Are you sure you want to overwrite the value for '{key}'? (y/n): " + ).strip().lower() return confirmation == 'y' @@ -20,17 +22,31 @@ def main(): parser = argparse.ArgumentParser( description="Selectively vault credentials + become-password in your inventory." ) - parser.add_argument("--role-path", required=True, help="Path to your role") - parser.add_argument("--inventory-file", required=True, help="Host vars file to update") - parser.add_argument("--vault-password-file", required=True, help="Vault password file") - parser.add_argument("--set", nargs="*", default=[], help="Override values key.subkey=VALUE") - parser.add_argument("-f", "--force", action="store_true", help="Force overwrite without confirmation") + parser.add_argument( + "--role-path", required=True, help="Path to your role" + ) + parser.add_argument( + "--inventory-file", required=True, help="Host vars file to update" + ) + parser.add_argument( + "--vault-password-file", required=True, help="Vault password file" + ) + parser.add_argument( + "--set", nargs="*", default=[], help="Override values key.subkey=VALUE" + ) + parser.add_argument( + "-f", "--force", action="store_true", + help="Force overwrite without confirmation" + ) args = parser.parse_args() - # Parsing overrides - overrides = {k.strip(): v.strip() for pair in args.set for k, v in [pair.split("=", 1)]} + # Parse overrides + overrides = { + k.strip(): v.strip() + for pair in args.set for k, v in [pair.split("=", 1)] + } - # Initialize the Inventory Manager + # Initialize inventory manager manager = InventoryManager( role_path=Path(args.role_path), inventory_path=Path(args.inventory_file), @@ -38,34 +54,57 @@ def main(): overrides=overrides ) - # 1) Apply schema and update inventory + # Load existing credentials to preserve + existing_apps = manager.inventory.get("applications", {}) + existing_creds = {} + if manager.app_id in existing_apps: + existing_creds = existing_apps[manager.app_id].get("credentials", {}).copy() + + # Apply schema (may generate defaults) updated_inventory = manager.apply_schema() - # 2) Apply vault encryption ONLY to 'credentials' fields (we no longer apply it globally) - credentials = updated_inventory.get("applications", {}).get(manager.app_id, {}).get("credentials", {}) - for key, value in credentials.items(): - if not value.lstrip().startswith("$ANSIBLE_VAULT"): # Only apply encryption if the value is not already vaulted - if key in credentials and not args.force: - if not ask_for_confirmation(key): # Ask for confirmation before overwriting - print(f"Skipping overwrite of '{key}'.") - continue - encrypted_value = manager.vault_handler.encrypt_string(value, key) - lines = encrypted_value.splitlines() - indent = len(lines[1]) - len(lines[1].lstrip()) - body = "\n".join(line[indent:] for line in lines[1:]) - credentials[key] = VaultScalar(body) # Store encrypted value as VaultScalar + # Restore existing database_password if present + apps = updated_inventory.setdefault("applications", {}) + app_block = apps.setdefault(manager.app_id, {}) + creds = app_block.setdefault("credentials", {}) + if "database_password" in existing_creds: + creds["database_password"] = existing_creds["database_password"] - # 3) Vault top-level ansible_become_password if present + # Store original plaintext values + original_plain = {key: str(val) for key, val in creds.items()} + + for key, raw_val in list(creds.items()): + # Skip if already vaulted + if isinstance(raw_val, VaultScalar) or str(raw_val).lstrip().startswith("$ANSIBLE_VAULT"): + continue + + # Determine plaintext + plain = original_plain.get(key, "") + if key in overrides and (args.force or ask_for_confirmation(key)): + plain = overrides[key] + + # Encrypt the plaintext + encrypted = manager.vault_handler.encrypt_string(plain, key) + lines = encrypted.splitlines() + indent = len(lines[1]) - len(lines[1].lstrip()) + body = "\n".join(line[indent:] for line in lines[1:]) + creds[key] = VaultScalar(body) + + # Vault top-level become password if present if "ansible_become_password" in updated_inventory: val = str(updated_inventory["ansible_become_password"]) - if not val.lstrip().startswith("$ANSIBLE_VAULT"): - snippet = manager.vault_handler.encrypt_string(val, "ansible_become_password") + if val.lstrip().startswith("$ANSIBLE_VAULT"): + updated_inventory["ansible_become_password"] = VaultScalar(val) + else: + snippet = manager.vault_handler.encrypt_string( + val, "ansible_become_password" + ) lines = snippet.splitlines() indent = len(lines[1]) - len(lines[1].lstrip()) body = "\n".join(line[indent:] for line in lines[1:]) updated_inventory["ansible_become_password"] = VaultScalar(body) - # 4) Save the updated inventory to file + # Write back to file with open(args.inventory_file, "w", encoding="utf-8") as f: yaml.dump(updated_inventory, f, sort_keys=False, Dumper=SafeDumper) diff --git a/tests/unit/test_generate_vaulted_credentials.py b/tests/unit/test_generate_vaulted_credentials.py deleted file mode 100644 index 964960b8..00000000 --- a/tests/unit/test_generate_vaulted_credentials.py +++ /dev/null @@ -1,77 +0,0 @@ -import pytest -import sys, os -from pathlib import Path - -sys.path.insert( - 0, - os.path.abspath( - os.path.join(os.path.dirname(__file__), "../../cli") - ), -) - -# 2) Import from the cli package -import cli.create_credentials as gvc - -class DummyProc: - def __init__(self, returncode, stdout, stderr=''): - self.returncode = returncode - self.stdout = stdout - self.stderr = stderr - -# Monkeypatch subprocess.run for encrypt_with_vault -@pytest.fixture(autouse=True) -def mock_subprocess_run(monkeypatch): - def fake_run(cmd, capture_output, text): - name = None - # find --name= in args - for arg in cmd: - if arg.startswith("--name="): - name = arg.split("=",1)[1] - val = cmd[ cmd.index(name) - 1 ] if name else "key" - # simulate Ansible output - snippet = f"{name or 'key'}: !vault |\n encrypted_{val}" - return DummyProc(0, snippet) - monkeypatch.setattr(gvc.subprocess, 'run', fake_run) - -def test_wrap_existing_vaults(): - data = { - 'a': '$ANSIBLE_VAULT;1.1;AES256...blob', - 'b': {'c': 'normal', 'd': '$ANSIBLE_VAULT;1.1;AES256...other'}, - 'e': ['x', '$ANSIBLE_VAULT;1.1;AES256...list'] - } - wrapped = gvc.wrap_existing_vaults(data) - assert isinstance(wrapped['a'], gvc.VaultScalar) - assert isinstance(wrapped['b']['d'], gvc.VaultScalar) - assert isinstance(wrapped['e'][1], gvc.VaultScalar) - assert wrapped['b']['c'] == 'normal' - assert wrapped['e'][0] == 'x' - -@pytest.mark.parametrize("pairs,expected", [ - (['k=v'], {'k': 'v'}), - (['a.b=1', 'c=two'], {'a.b': '1', 'c': 'two'}), - (['noeq'], {}), -]) -def test_parse_overrides(pairs, expected): - assert gvc.parse_overrides(pairs) == expected - -def test_apply_schema_and_vault(tmp_path): - schema = { - 'cred': {'description':'d','algorithm':'plain','validation':{}}, - 'nested': {'inner': {'description':'d2','algorithm':'plain','validation':{}}} - } - inv = {} - updated = gvc.apply_schema(schema, inv, 'app', {}, 'pwfile') - apps = updated['applications']['app'] - assert isinstance(apps['cred'], gvc.VaultScalar) - assert isinstance(apps['nested']['inner'], gvc.VaultScalar) - -def test_encrypt_leaves_and_credentials(): - branch = {'p':'v','nested':{'q':'u'}} - gvc.encrypt_leaves(branch, 'pwfile') - assert isinstance(branch['p'], gvc.VaultScalar) - assert isinstance(branch['nested']['q'], gvc.VaultScalar) - - inv = {'credentials':{'a':'b'}, 'x':{'credentials':{'c':'d'}}} - gvc.encrypt_credentials_branch(inv, 'pwfile') - assert isinstance(inv['credentials']['a'], gvc.VaultScalar) - assert isinstance(inv['x']['credentials']['c'], gvc.VaultScalar)