diff --git a/src/p2pkg/__main__.py b/src/p2pkg/__main__.py index 26ab636..ca519be 100644 --- a/src/p2pkg/__main__.py +++ b/src/p2pkg/__main__.py @@ -3,6 +3,7 @@ from __future__ import annotations import argparse import pathlib import subprocess +from dataclasses import dataclass INIT_TEMPLATE = """\ @@ -24,6 +25,13 @@ __all__ = getattr(_main, "__all__", [n for n in dir(_main) if not n.startswith(" """ +@dataclass(frozen=True) +class MigrationPlan: + source: pathlib.Path + target_main: pathlib.Path + target_init: pathlib.Path + + def _run(cmd: list[str], cwd: pathlib.Path | None = None) -> None: subprocess.run(cmd, check=True, cwd=str(cwd) if cwd else None) @@ -52,31 +60,119 @@ def _fs_mv(src: pathlib.Path, dst: pathlib.Path) -> None: src.rename(dst) -def migrate_one(py_file: pathlib.Path, use_git: bool, repo_root: pathlib.Path) -> None: - """Migrate a single flat module file (foo.py) into a package (foo/__main__.py).""" - py_file = py_file.resolve() - if py_file.suffix != ".py": - raise ValueError(f"Not a .py file: {py_file}") +def _is_candidate_py_file(path: pathlib.Path) -> bool: + """ + Decide if a file should be migrated. + We skip: + - non-.py files + - __init__.py and __main__.py + - files that are inside a package which already has a __main__.py + (i.e. that package is already runnable via `python -m package`) + """ + if path.suffix != ".py": + return False + if path.name in {"__init__.py", "__main__.py"}: + return False + + # Find the nearest (innermost) package dir this file belongs to, if any. + pkg_dir: pathlib.Path | None = None + for parent in [path.parent, *path.parents]: + if (parent / "__init__.py").exists(): + pkg_dir = parent + break + + # If file is inside a package and that package already has __main__.py -> skip + if pkg_dir is not None and (pkg_dir / "__main__.py").exists(): + return False + + return True + +def _build_plan(py_file: pathlib.Path) -> MigrationPlan: + py_file = py_file.resolve() name = py_file.stem pkg_dir = py_file.parent / name - - if pkg_dir.exists() and not pkg_dir.is_dir(): - raise RuntimeError(f"Target exists but is not a directory: {pkg_dir}") - target_main = pkg_dir / "__main__.py" target_init = pkg_dir / "__init__.py" + return MigrationPlan(source=py_file, target_main=target_main, target_init=target_init) - if target_main.exists(): - raise RuntimeError(f"Refusing to overwrite existing: {target_main}") + +def _discover_py_files_recursive(root: pathlib.Path) -> list[pathlib.Path]: + root = root.resolve() + files: list[pathlib.Path] = [] + for p in root.rglob("*.py"): + if p.is_file() and _is_candidate_py_file(p): + files.append(p.resolve()) + return sorted(set(files)) + + +def _rel(path: pathlib.Path, repo_root: pathlib.Path) -> str: + try: + return str(path.relative_to(repo_root)) + except Exception: + return str(path) + + +def _print_plan(plans: list[MigrationPlan], repo_root: pathlib.Path) -> None: + if not plans: + print("No candidates found.") + return + + print("Planned migrations:") + for plan in plans: + print(f" - {_rel(plan.source, repo_root)} -> {_rel(plan.target_main, repo_root)}") + + +def _confirm_apply() -> bool: + """ + Ask user for confirmation. Default is NO. + Accepts: y, yes (case-insensitive). + """ + try: + ans = input("Apply these changes? [y/N] ").strip().lower() + except EOFError: + return False + return ans in {"y", "yes"} + + +def migrate_one(py_file: pathlib.Path, use_git: bool, repo_root: pathlib.Path) -> None: + """Migrate a single flat module file (foo.py) into a package (foo/__main__.py).""" + plan = _build_plan(py_file) + + if plan.source.suffix != ".py": + raise ValueError(f"Not a .py file: {plan.source}") + + if plan.target_main.exists(): + raise RuntimeError(f"Refusing to overwrite existing: {plan.target_main}") + + if plan.target_main.parent.exists() and not plan.target_main.parent.is_dir(): + raise RuntimeError(f"Target exists but is not a directory: {plan.target_main.parent}") if use_git: - _git_mv(py_file, target_main, cwd=repo_root) + _git_mv(plan.source, plan.target_main, cwd=repo_root) else: - _fs_mv(py_file, target_main) + _fs_mv(plan.source, plan.target_main) - if not target_init.exists(): - target_init.write_text(INIT_TEMPLATE.format(name=name), encoding="utf-8") + if not plan.target_init.exists(): + plan.target_init.write_text(INIT_TEMPLATE.format(name=plan.source.stem), encoding="utf-8") + + +def _apply_plans(plans: list[MigrationPlan], use_git: bool, repo_root: pathlib.Path) -> None: + # Preflight checks: fail fast before doing partial work + for plan in plans: + if plan.target_main.exists(): + raise RuntimeError(f"Refusing to overwrite existing: {plan.target_main}") + if plan.target_main.parent.exists() and not plan.target_main.parent.is_dir(): + raise RuntimeError(f"Target exists but is not a directory: {plan.target_main.parent}") + + for plan in plans: + if use_git: + _git_mv(plan.source, plan.target_main, cwd=repo_root) + else: + _fs_mv(plan.source, plan.target_main) + + if not plan.target_init.exists(): + plan.target_init.write_text(INIT_TEMPLATE.format(name=plan.source.stem), encoding="utf-8") def main(argv: list[str] | None = None) -> int: @@ -84,10 +180,29 @@ def main(argv: list[str] | None = None) -> int: prog="p2pkg", description="Migrate foo.py -> foo/__main__.py and generate foo/__init__.py that re-exports public API.", ) + parser.add_argument( - "files", + "paths", nargs="+", - help="Python module files to migrate (e.g. roles_list.py other.py).", + help="Python module files OR directories (with -R) to migrate (e.g. roles_list.py src/).", + ) + parser.add_argument( + "-R", + "--recursive", + action="store_true", + help="Treat given paths as directories and discover *.py recursively.", + ) + parser.add_argument( + "-p", + "--preview", + action="store_true", + help="Preview only (do not change anything).", + ) + parser.add_argument( + "-f", + "--force", + action="store_true", + help="Apply without asking for confirmation.", ) parser.add_argument( "--no-git", @@ -104,15 +219,71 @@ def main(argv: list[str] | None = None) -> int: repo_root = pathlib.Path(ns.repo_root).resolve() use_git = (not ns.no_git) and _have_git_repo(repo_root) - for f in ns.files: - path = pathlib.Path(f) + if ns.recursive: + dirs: list[pathlib.Path] = [] + for raw in ns.paths: + p = pathlib.Path(raw) + if not p.is_absolute(): + p = (repo_root / p).resolve() + dirs.append(p) + + for d in dirs: + if not d.exists(): + raise FileNotFoundError(str(d)) + if not d.is_dir(): + raise NotADirectoryError(str(d)) + + candidates: list[pathlib.Path] = [] + for d in dirs: + candidates.extend(_discover_py_files_recursive(d)) + + candidates = sorted(set(candidates)) + plans = [_build_plan(p) for p in candidates] + + _print_plan(plans, repo_root=repo_root) + + if ns.preview: + return 0 + + if not plans: + return 0 + + if ns.force or _confirm_apply(): + _apply_plans(plans, use_git=use_git, repo_root=repo_root) + return 0 + + print("Aborted.") + return 1 + + # Non-recursive: treat paths as files (existing behavior) + files: list[pathlib.Path] = [] + for raw in ns.paths: + path = pathlib.Path(raw) if not path.is_absolute(): path = (repo_root / path).resolve() if not path.exists(): raise FileNotFoundError(str(path)) - migrate_one(path, use_git=use_git, repo_root=repo_root) + if path.is_dir(): + raise IsADirectoryError( + f"{path} is a directory. Use -R/--recursive to migrate directories recursively." + ) + files.append(path) - return 0 + plans = [_build_plan(p) for p in files] + _print_plan(plans, repo_root=repo_root) + + if ns.preview: + return 0 + + if not plans: + return 0 + + if ns.force or _confirm_apply(): + _apply_plans(plans, use_git=use_git, repo_root=repo_root) + return 0 + + print("Aborted.") + return 1 if __name__ == "__main__": diff --git a/tests/test_migration.py b/tests/test_migration.py index a3d3fd6..cf38045 100644 --- a/tests/test_migration.py +++ b/tests/test_migration.py @@ -1,34 +1,42 @@ from __future__ import annotations +import io import sys import tempfile import textwrap import unittest +from contextlib import redirect_stdout from pathlib import Path +from unittest.mock import patch -from p2pkg.__main__ import migrate_one + +# Ensure we import THIS repo's src/ implementation, not an installed site-packages one. +REPO_ROOT = Path(__file__).resolve().parents[1] +SRC_DIR = REPO_ROOT / "src" +sys.path.insert(0, str(SRC_DIR)) + +from p2pkg.__main__ import main, migrate_one # noqa: E402 class TestMigration(unittest.TestCase): - def test_migrate_creates_package_and_exports_public_api(self) -> None: + def test_migrate_one_creates_package_and_exports_public_api(self) -> None: with tempfile.TemporaryDirectory() as td: root = Path(td) mod = root / "roles_list.py" mod.write_text( - textwrap.dedent("""\ - __all__ = ["add", "PUBLIC_CONST"] - PUBLIC_CONST = 123 - _PRIVATE_CONST = 999 + textwrap.dedent( + """\ + __all__ = ["add", "PUBLIC_CONST"] + PUBLIC_CONST = 123 + _PRIVATE_CONST = 999 - def add(a: int, b: int) -> int: - return a + b + def add(a: int, b: int) -> int: + return a + b - def _hidden() -> int: - return 1 - - if __name__ == "__main__": - print("running as script") - """), + def _hidden() -> int: + return 1 + """ + ), encoding="utf-8", ) @@ -39,7 +47,6 @@ class TestMigration(unittest.TestCase): self.assertTrue((pkg / "__init__.py").exists()) self.assertFalse(mod.exists()) - # Import the package and ensure re-exports work sys.path.insert(0, str(root)) try: import roles_list # type: ignore @@ -49,45 +56,102 @@ class TestMigration(unittest.TestCase): self.assertTrue(hasattr(roles_list, "PUBLIC_CONST")) self.assertEqual(roles_list.PUBLIC_CONST, 123) - # __all__ should be preserved self.assertEqual(set(roles_list.__all__), {"add", "PUBLIC_CONST"}) self.assertFalse(hasattr(roles_list, "_hidden")) finally: sys.path.remove(str(root)) - if "roles_list" in sys.modules: - del sys.modules["roles_list"] + sys.modules.pop("roles_list", None) - def test_migrate_without_explicit_all_exports_public_names(self) -> None: + def test_main_non_recursive_prompts_and_applies_on_yes(self) -> None: with tempfile.TemporaryDirectory() as td: root = Path(td) mod = root / "foo.py" - mod.write_text( - textwrap.dedent("""\ - VALUE = "ok" + mod.write_text("X = 1\n", encoding="utf-8") - def hello() -> str: - return "hi" + buf = io.StringIO() + with redirect_stdout(buf), patch("builtins.input", return_value="y"): + rc = main(["--no-git", "--repo-root", str(root), str(mod)]) - def _private() -> str: - return "no" - """), - encoding="utf-8", - ) - migrate_one(mod, use_git=False, repo_root=root) + self.assertEqual(rc, 0) + self.assertFalse(mod.exists()) + self.assertTrue((root / "foo" / "__main__.py").exists()) + self.assertTrue((root / "foo" / "__init__.py").exists()) - sys.path.insert(0, str(root)) - try: - import foo # type: ignore - self.assertEqual(foo.hello(), "hi") - self.assertEqual(foo.VALUE, "ok") - self.assertIn("hello", foo.__all__) - self.assertIn("VALUE", foo.__all__) - self.assertNotIn("_private", foo.__all__) - finally: - sys.path.remove(str(root)) - if "foo" in sys.modules: - del sys.modules["foo"] + def test_main_recursive_preview_does_not_apply_and_does_not_prompt(self) -> None: + with tempfile.TemporaryDirectory() as td: + root = Path(td) + target = root / "tree" + target.mkdir() + + (target / "x.py").write_text("X = 1\n", encoding="utf-8") + (target / "y.py").write_text("Y = 2\n", encoding="utf-8") + + buf = io.StringIO() + with redirect_stdout(buf), patch("builtins.input") as mocked_input: + rc = main(["-R", "-p", "--no-git", "--repo-root", str(root), str(target)]) + + self.assertEqual(rc, 0) + mocked_input.assert_not_called() + self.assertTrue((target / "x.py").exists()) + self.assertTrue((target / "y.py").exists()) + self.assertFalse((target / "x").exists()) + self.assertFalse((target / "y").exists()) + + def test_main_recursive_force_applies_without_prompt(self) -> None: + with tempfile.TemporaryDirectory() as td: + root = Path(td) + target = root / "tree" + target.mkdir() + + (target / "x.py").write_text("X = 1\n", encoding="utf-8") + (target / "y.py").write_text("Y = 2\n", encoding="utf-8") + + buf = io.StringIO() + with redirect_stdout(buf), patch("builtins.input") as mocked_input: + rc = main(["-R", "-f", "--no-git", "--repo-root", str(root), str(target)]) + + self.assertEqual(rc, 0) + mocked_input.assert_not_called() + self.assertFalse((target / "x.py").exists()) + self.assertFalse((target / "y.py").exists()) + self.assertTrue((target / "x" / "__main__.py").exists()) + self.assertTrue((target / "y" / "__main__.py").exists()) + + def test_main_recursive_prompts_and_aborts_on_no(self) -> None: + with tempfile.TemporaryDirectory() as td: + root = Path(td) + target = root / "tree" + target.mkdir() + + (target / "x.py").write_text("X = 1\n", encoding="utf-8") + + buf = io.StringIO() + with redirect_stdout(buf), patch("builtins.input", return_value="n"): + rc = main(["-R", "--no-git", "--repo-root", str(root), str(target)]) + + self.assertEqual(rc, 1) + self.assertTrue((target / "x.py").exists()) + self.assertFalse((target / "x").exists()) + + def test_main_recursive_prompts_and_applies_on_yes(self) -> None: + with tempfile.TemporaryDirectory() as td: + root = Path(td) + target = root / "tree" + target.mkdir() + + (target / "x.py").write_text("X = 1\n", encoding="utf-8") + (target / "y.py").write_text("Y = 2\n", encoding="utf-8") + + buf = io.StringIO() + with redirect_stdout(buf), patch("builtins.input", return_value="y"): + rc = main(["-R", "--no-git", "--repo-root", str(root), str(target)]) + + self.assertEqual(rc, 0) + self.assertFalse((target / "x.py").exists()) + self.assertFalse((target / "y.py").exists()) + self.assertTrue((target / "x" / "__main__.py").exists()) + self.assertTrue((target / "y" / "__main__.py").exists()) if __name__ == "__main__":