diff --git a/advanced_alchemy/cli.py b/advanced_alchemy/cli.py index 45ae135e..e40c4206 100644 --- a/advanced_alchemy/cli.py +++ b/advanced_alchemy/cli.py @@ -1,7 +1,7 @@ from __future__ import annotations from pathlib import Path -from typing import TYPE_CHECKING, Sequence, cast +from typing import TYPE_CHECKING, Sequence, Union, cast if TYPE_CHECKING: from click import Group @@ -43,7 +43,7 @@ def alchemy_group(ctx: click.Context, config: str) -> None: ctx.ensure_object(dict) try: config_instance = module_loader.import_string(config) - if isinstance(config_instance, (list, tuple)): + if isinstance(config_instance, Sequence): ctx.obj["configs"] = config_instance else: ctx.obj["configs"] = [config_instance] @@ -72,18 +72,57 @@ def add_migration_commands(database_group: Group | None = None) -> Group: # noq if database_group is None: database_group = get_alchemy_group() + bind_key_option = click.option( + "--bind-key", + help="Specify which SQLAlchemy config to use by bind key", + type=str, + default=None, + ) + verbose_option = click.option( + "--verbose", + help="Enable verbose output.", + type=bool, + default=False, + is_flag=True, + ) + no_prompt_option = click.option( + "--no-prompt", + help="Do not prompt for confirmation before executing the command.", + type=bool, + default=False, + required=False, + show_default=True, + is_flag=True, + ) + + def get_config_by_bind_key( + ctx: click.Context, bind_key: str | None + ) -> SQLAlchemyAsyncConfig | SQLAlchemySyncConfig: + """Get the SQLAlchemy config for the specified bind key.""" + configs = ctx.obj["configs"] + if bind_key is None: + return cast("Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]", configs[0]) + + for config in configs: + if config.bind_key == bind_key: + return cast("Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]", config) + + console.print(f"[red]No config found for bind key: {bind_key}[/]") + ctx.exit(1) # noqa: RET503 + @database_group.command( name="show-current-revision", help="Shows the current revision for the database.", ) - @click.option("--verbose", type=bool, help="Enable verbose output.", default=False, is_flag=True) - @click.pass_context - def show_database_revision(ctx: click.Context, verbose: bool) -> None: # pyright: ignore[reportUnusedFunction] + @bind_key_option + @verbose_option + def show_database_revision(bind_key: str | None, verbose: bool) -> None: # pyright: ignore[reportUnusedFunction] """Show current database revision.""" from advanced_alchemy.alembic.commands import AlembicCommands + ctx = click.get_current_context() console.rule("[yellow]Listing current revision[/]", align="left") - sqlalchemy_config = ctx.obj["configs"][0] + sqlalchemy_config = get_config_by_bind_key(ctx, bind_key) alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config) alembic_commands.current(verbose=verbose) @@ -91,6 +130,7 @@ def show_database_revision(ctx: click.Context, verbose: bool) -> None: # pyrigh name="downgrade", help="Downgrade database to a specific revision.", ) + @bind_key_option @click.option("--sql", type=bool, help="Generate SQL output for offline migrations.", default=False, is_flag=True) @click.option( "--tag", @@ -98,27 +138,21 @@ def show_database_revision(ctx: click.Context, verbose: bool) -> None: # pyrigh type=str, default=None, ) - @click.option( - "--no-prompt", - help="Do not prompt for confirmation before downgrading.", - type=bool, - default=False, - required=False, - show_default=True, - is_flag=True, - ) + @no_prompt_option @click.argument( "revision", type=str, default="-1", ) - @click.pass_context - def downgrade_database(ctx: click.Context, revision: str, sql: bool, tag: str | None, no_prompt: bool) -> None: # pyright: ignore[reportUnusedFunction] + def downgrade_database( # pyright: ignore[reportUnusedFunction] + bind_key: str | None, revision: str, sql: bool, tag: str | None, no_prompt: bool + ) -> None: """Downgrade the database to the latest revision.""" from rich.prompt import Confirm from advanced_alchemy.alembic.commands import AlembicCommands + ctx = click.get_current_context() console.rule("[yellow]Starting database downgrade process[/]", align="left") input_confirmed = ( True @@ -126,7 +160,7 @@ def downgrade_database(ctx: click.Context, revision: str, sql: bool, tag: str | else Confirm.ask(f"Are you sure you want to downgrade the database to the `{revision}` revision?") ) if input_confirmed: - sqlalchemy_config = ctx.obj["configs"][0] + sqlalchemy_config = get_config_by_bind_key(ctx, bind_key) alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config) alembic_commands.downgrade(revision=revision, sql=sql, tag=tag) @@ -134,6 +168,7 @@ def downgrade_database(ctx: click.Context, revision: str, sql: bool, tag: str | name="upgrade", help="Upgrade database to a specific revision.", ) + @bind_key_option @click.option("--sql", type=bool, help="Generate SQL output for offline migrations.", default=False, is_flag=True) @click.option( "--tag", @@ -141,27 +176,21 @@ def downgrade_database(ctx: click.Context, revision: str, sql: bool, tag: str | type=str, default=None, ) - @click.option( - "--no-prompt", - help="Do not prompt for confirmation before upgrading.", - type=bool, - default=False, - required=False, - show_default=True, - is_flag=True, - ) + @no_prompt_option @click.argument( "revision", type=str, default="head", ) - @click.pass_context - def upgrade_database(ctx: click.Context, revision: str, sql: bool, tag: str | None, no_prompt: bool) -> None: # pyright: ignore[reportUnusedFunction] + def upgrade_database( # pyright: ignore[reportUnusedFunction] + bind_key: str | None, revision: str, sql: bool, tag: str | None, no_prompt: bool + ) -> None: """Upgrade the database to the latest revision.""" from rich.prompt import Confirm from advanced_alchemy.alembic.commands import AlembicCommands + ctx = click.get_current_context() console.rule("[yellow]Starting database upgrade process[/]", align="left") input_confirmed = ( True @@ -169,7 +198,7 @@ def upgrade_database(ctx: click.Context, revision: str, sql: bool, tag: str | No else Confirm.ask(f"[bold]Are you sure you want migrate the database to the `{revision}` revision?[/]") ) if input_confirmed: - sqlalchemy_config = ctx.obj["configs"][0] + sqlalchemy_config = get_config_by_bind_key(ctx, bind_key) alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config) alembic_commands.upgrade(revision=revision, sql=sql, tag=tag) @@ -177,25 +206,20 @@ def upgrade_database(ctx: click.Context, revision: str, sql: bool, tag: str | No name="init", help="Initialize migrations for the project.", ) + @bind_key_option @click.argument("directory", default=None) @click.option("--multidb", is_flag=True, default=False, help="Support multiple databases") @click.option("--package", is_flag=True, default=True, help="Create `__init__.py` for created folder") - @click.option( - "--no-prompt", - help="Do not prompt for confirmation before initializing.", - type=bool, - default=False, - required=False, - show_default=True, - is_flag=True, - ) - @click.pass_context - def init_alembic(ctx: click.Context, directory: str | None, multidb: bool, package: bool, no_prompt: bool) -> None: # pyright: ignore[reportUnusedFunction] + @no_prompt_option + def init_alembic( # pyright: ignore[reportUnusedFunction] + bind_key: str | None, directory: str | None, multidb: bool, package: bool, no_prompt: bool + ) -> None: """Initialize the database migrations.""" from rich.prompt import Confirm from advanced_alchemy.alembic.commands import AlembicCommands + ctx = click.get_current_context() console.rule("[yellow]Initializing database migrations.", align="left") input_confirmed = ( True @@ -203,7 +227,8 @@ def init_alembic(ctx: click.Context, directory: str | None, multidb: bool, packa else Confirm.ask(f"[bold]Are you sure you want initialize the project in `{directory}`?[/]") ) if input_confirmed: - for config in ctx.obj["configs"]: + configs = [get_config_by_bind_key(ctx, bind_key)] if bind_key is not None else ctx.obj["configs"] + for config in configs: directory = config.alembic_config.script_location if directory is None else directory alembic_commands = AlembicCommands(sqlalchemy_config=config) alembic_commands.init(directory=cast("str", directory), multidb=multidb, package=package) @@ -212,6 +237,7 @@ def init_alembic(ctx: click.Context, directory: str | None, multidb: bool, packa name="make-migrations", help="Create a new migration revision.", ) + @bind_key_option @click.option("-m", "--message", default=None, help="Revision message") @click.option( "--autogenerate/--no-autogenerate", default=True, help="Automatically populate revision with detected changes" @@ -224,18 +250,9 @@ def init_alembic(ctx: click.Context, directory: str | None, multidb: bool, packa @click.option("--branch-label", default=None, help="Specify a branch label to apply to the new revision") @click.option("--version-path", default=None, help="Specify specific path from config for version file") @click.option("--rev-id", default=None, help="Specify a ID to use for revision.") - @click.option( - "--no-prompt", - help="Do not prompt for a migration message.", - type=bool, - default=False, - required=False, - show_default=True, - is_flag=True, - ) - @click.pass_context + @no_prompt_option def create_revision( # pyright: ignore[reportUnusedFunction] - ctx: click.Context, + bind_key: str | None, message: str | None, autogenerate: bool, sql: bool, @@ -275,11 +292,12 @@ def process_revision_directives( ) directives.clear() + ctx = click.get_current_context() console.rule("[yellow]Starting database upgrade process[/]", align="left") if message is None: message = "autogenerated" if no_prompt else Prompt.ask("Please enter a message describing this revision") - sqlalchemy_config = ctx.obj["configs"][0] + sqlalchemy_config = get_config_by_bind_key(ctx, bind_key) alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config) alembic_commands.revision( message=message, @@ -294,17 +312,9 @@ def process_revision_directives( ) @database_group.command(name="drop-all", help="Drop all tables from the database.") - @click.option( - "--no-prompt", - help="Do not prompt for confirmation before upgrading.", - type=bool, - default=False, - required=False, - show_default=True, - is_flag=True, - ) - @click.pass_context - def drop_all(ctx: click.Context, no_prompt: bool) -> None: # pyright: ignore[reportUnusedFunction] + @bind_key_option + @no_prompt_option + def drop_all(bind_key: str | None, no_prompt: bool) -> None: # pyright: ignore[reportUnusedFunction] """Drop all tables from the database.""" from anyio import run from rich.prompt import Confirm @@ -312,6 +322,7 @@ def drop_all(ctx: click.Context, no_prompt: bool) -> None: # pyright: ignore[re from advanced_alchemy.alembic.utils import drop_all from advanced_alchemy.base import metadata_registry + ctx = click.get_current_context() console.rule("[yellow]Dropping all tables from the database[/]", align="left") input_confirmed = no_prompt or Confirm.ask( "[bold red]Are you sure you want to drop all tables from the database?" @@ -325,9 +336,11 @@ async def _drop_all( await drop_all(engine, config.alembic_config.version_table_name, metadata_registry.get(config.bind_key)) if input_confirmed: - run(_drop_all, ctx.obj["configs"]) + configs = [get_config_by_bind_key(ctx, bind_key)] if bind_key is not None else ctx.obj["configs"] + run(_drop_all, configs) @database_group.command(name="dump-data", help="Dump specified tables from the database to JSON files.") + @bind_key_option @click.option( "--table", "table_names", @@ -344,8 +357,7 @@ async def _drop_all( default=Path.cwd() / "fixtures", required=False, ) - @click.pass_context - def dump_table_data(ctx: click.Context, table_names: tuple[str, ...], dump_dir: Path) -> None: # pyright: ignore[reportUnusedFunction] + def dump_table_data(bind_key: str | None, table_names: tuple[str, ...], dump_dir: Path) -> None: # pyright: ignore[reportUnusedFunction] """Dump table data to JSON files.""" from anyio import run from rich.prompt import Confirm @@ -353,6 +365,7 @@ def dump_table_data(ctx: click.Context, table_names: tuple[str, ...], dump_dir: from advanced_alchemy.alembic.utils import dump_tables from advanced_alchemy.base import metadata_registry, orm_registry + ctx = click.get_current_context() all_tables = "*" in table_names if all_tables and not Confirm.ask( @@ -361,7 +374,8 @@ def dump_table_data(ctx: click.Context, table_names: tuple[str, ...], dump_dir: return console.rule("[red bold]No data was dumped.", style="red", align="left") async def _dump_tables() -> None: - for config in ctx.obj["configs"]: + configs = [get_config_by_bind_key(ctx, bind_key)] if bind_key is not None else ctx.obj["configs"] + for config in configs: target_tables = set(metadata_registry.get(config.bind_key).tables) if not all_tables: diff --git a/advanced_alchemy/extensions/litestar/cli.py b/advanced_alchemy/extensions/litestar/cli.py index 1aefe3cd..a72b56c9 100644 --- a/advanced_alchemy/extensions/litestar/cli.py +++ b/advanced_alchemy/extensions/litestar/cli.py @@ -3,13 +3,14 @@ from contextlib import suppress from typing import TYPE_CHECKING +from litestar.cli._utils import LitestarGroup + +from advanced_alchemy.cli import add_migration_commands + try: import rich_click as click except ImportError: import click # type: ignore[no-redef] -from litestar.cli._utils import LitestarGroup - -from advanced_alchemy.cli import add_migration_commands if TYPE_CHECKING: from litestar import Litestar @@ -18,11 +19,7 @@ def get_database_migration_plugin(app: Litestar) -> SQLAlchemyInitPlugin: - """Retrieve a database migration plugin from the Litestar application's plugins. - - This function attempts to find and return either the SQLAlchemyPlugin or SQLAlchemyInitPlugin. - If neither plugin is found, it raises an ImproperlyConfiguredException. - """ + """Retrieve a database migration plugin from the Litestar application's plugins.""" from advanced_alchemy.exceptions import ImproperConfigurationError from advanced_alchemy.extensions.litestar.plugins import SQLAlchemyInitPlugin @@ -33,10 +30,9 @@ def get_database_migration_plugin(app: Litestar) -> SQLAlchemyInitPlugin: @click.group(cls=LitestarGroup, name="database") -@click.pass_context def database_group(ctx: click.Context) -> None: """Manage SQLAlchemy database components.""" - ctx.obj = get_database_migration_plugin(ctx.obj.app).config + ctx.obj = {"app": ctx.obj, "configs": get_database_migration_plugin(ctx.obj.app).config} add_migration_commands(database_group) diff --git a/pyproject.toml b/pyproject.toml index dedbcaa4..b4149923 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -147,6 +147,7 @@ sqlite = ["aiosqlite>=0.20.0"] test = [ "pydantic-extra-types < 2.9.0; python_version < \"3.9\"", "pydantic-extra-types; python_version >= \"3.9\"", + "rich-click", "coverage>=7.6.1", "pytest>=7.4.4", "pytest-asyncio>=0.23.8", @@ -468,6 +469,7 @@ exclude = [ include = ["advanced_alchemy"] pythonVersion = "3.8" reportUnnecessaryTypeIgnoreComments = true +reportUnusedFunction = false strict = ["advanced_alchemy/**/*"] venv = ".venv" venvPath = "." diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 8f1bde42..c34ff0de 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -9,7 +9,7 @@ from click.testing import CliRunner from sqlalchemy.ext.asyncio import AsyncEngine -from advanced_alchemy.cli import add_migration_commands +from advanced_alchemy.cli import add_migration_commands, get_alchemy_group if TYPE_CHECKING: from click import Group @@ -42,8 +42,9 @@ def mock_context(mock_config: MagicMock) -> Generator[MagicMock, None, None]: @pytest.fixture def database_cli(mock_context: MagicMock) -> Generator[Group, None, None]: """Create the database CLI group.""" + cli_group = get_alchemy_group() cli_group = add_migration_commands() - cli_group.context = mock_context # pyright: ignore[reportAttributeAccessIssue] + cli_group.ctx = mock_context # pyright: ignore[reportAttributeAccessIssue] yield cli_group @@ -69,7 +70,7 @@ def test_downgrade_database( if no_prompt: args.append("--no-prompt") - result = cli_runner.invoke(database_cli, args, obj=mock_context.obj) + result = cli_runner.invoke(database_cli, args) if no_prompt: assert result.exit_code == 0 @@ -90,7 +91,7 @@ def test_upgrade_database(cli_runner: CliRunner, database_cli: Group, mock_conte if no_prompt: args.append("--no-prompt") - result = cli_runner.invoke(database_cli, args, obj=mock_context.obj) + result = cli_runner.invoke(database_cli, args) if no_prompt: assert result.exit_code == 0 @@ -108,7 +109,6 @@ def test_init_alembic(cli_runner: CliRunner, database_cli: Group, mock_context: result = cli_runner.invoke( database_cli, ["--config", "tests.unit.fixtures.configs", "init", "--no-prompt", "migrations"], - obj=mock_context.obj, ) assert result.exit_code == 0 mock_alembic.assert_called_once() @@ -121,7 +121,6 @@ def test_make_migrations(cli_runner: CliRunner, database_cli: Group, mock_contex result = cli_runner.invoke( database_cli, ["--config", "tests.unit.fixtures.configs", "make-migrations", "--no-prompt", "-m", "test migration"], - obj=mock_context.obj, ) assert result.exit_code == 0 mock_alembic.assert_called_once() @@ -131,9 +130,7 @@ def test_make_migrations(cli_runner: CliRunner, database_cli: Group, mock_contex def test_drop_all(cli_runner: CliRunner, database_cli: Group, mock_context: MagicMock) -> None: """Test the drop-all command.""" - result = cli_runner.invoke( - database_cli, ["--config", "tests.unit.fixtures.configs", "drop-all", "--no-prompt"], obj=mock_context.obj - ) + result = cli_runner.invoke(database_cli, ["--config", "tests.unit.fixtures.configs", "drop-all", "--no-prompt"]) assert result.exit_code == 0 @@ -143,7 +140,6 @@ def test_dump_data(cli_runner: CliRunner, database_cli: Group, mock_context: Mag result = cli_runner.invoke( database_cli, ["--config", "tests.unit.fixtures.configs", "dump-data", "--table", "test_table", "--dir", str(tmp_path)], - obj=mock_context.obj, ) assert result.exit_code == 0 diff --git a/uv.lock b/uv.lock index 3cc19edf..0f0fbc58 100644 --- a/uv.lock +++ b/uv.lock @@ -85,6 +85,7 @@ dev = [ { name = "pytest-sugar" }, { name = "pytest-xdist" }, { name = "pytz" }, + { name = "rich-click" }, { name = "ruff" }, { name = "sanic", version = "24.6.0", source = { registry = "https://pypi.org/simple" }, extra = ["ext"], marker = "python_full_version < '3.9'" }, { name = "sanic", version = "24.12.0", source = { registry = "https://pypi.org/simple" }, extra = ["ext"], marker = "python_full_version >= '3.9'" }, @@ -228,6 +229,7 @@ test = [ { name = "pytest-mock" }, { name = "pytest-sugar" }, { name = "pytest-xdist" }, + { name = "rich-click" }, { name = "time-machine", version = "2.15.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, { name = "time-machine", version = "2.16.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, ] @@ -289,6 +291,7 @@ dev = [ { name = "pytest-sugar", specifier = ">=1.0.0" }, { name = "pytest-xdist", specifier = ">=3.6.1" }, { name = "pytz", specifier = ">=2024.2" }, + { name = "rich-click" }, { name = "ruff", specifier = ">=0.7.1" }, { name = "sanic", marker = "python_full_version < '3.9'", specifier = "<24.12" }, { name = "sanic", marker = "python_full_version >= '3.9'" }, @@ -406,6 +409,7 @@ test = [ { name = "pytest-mock", specifier = ">=3.14.0" }, { name = "pytest-sugar", specifier = ">=1.0.0" }, { name = "pytest-xdist", specifier = ">=3.6.1" }, + { name = "rich-click" }, { name = "time-machine", specifier = ">=2.15.0" }, ]