Skip to content

Commit

Permalink
feat(cli): add bind-key option to CLI (#339)
Browse files Browse the repository at this point in the history
Adds a new `bind-key` option to the CLI for specifying which engine configuration to use for migrations.
  • Loading branch information
cofin authored Jan 13, 2025
1 parent 2caeeda commit ed9f5cd
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 89 deletions.
152 changes: 83 additions & 69 deletions advanced_alchemy/cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Sequence, cast
from typing import TYPE_CHECKING, Sequence, Union, cast

if TYPE_CHECKING:
from click import Group
Expand Down Expand Up @@ -43,7 +43,7 @@ def alchemy_group(ctx: click.Context, config: str) -> None:
ctx.ensure_object(dict)
try:
config_instance = module_loader.import_string(config)
if isinstance(config_instance, (list, tuple)):
if isinstance(config_instance, Sequence):
ctx.obj["configs"] = config_instance
else:
ctx.obj["configs"] = [config_instance]
Expand Down Expand Up @@ -72,138 +72,163 @@ def add_migration_commands(database_group: Group | None = None) -> Group: # noq
if database_group is None:
database_group = get_alchemy_group()

bind_key_option = click.option(
"--bind-key",
help="Specify which SQLAlchemy config to use by bind key",
type=str,
default=None,
)
verbose_option = click.option(
"--verbose",
help="Enable verbose output.",
type=bool,
default=False,
is_flag=True,
)
no_prompt_option = click.option(
"--no-prompt",
help="Do not prompt for confirmation before executing the command.",
type=bool,
default=False,
required=False,
show_default=True,
is_flag=True,
)

def get_config_by_bind_key(
ctx: click.Context, bind_key: str | None
) -> SQLAlchemyAsyncConfig | SQLAlchemySyncConfig:
"""Get the SQLAlchemy config for the specified bind key."""
configs = ctx.obj["configs"]
if bind_key is None:
return cast("Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]", configs[0])

for config in configs:
if config.bind_key == bind_key:
return cast("Union[SQLAlchemyAsyncConfig, SQLAlchemySyncConfig]", config)

console.print(f"[red]No config found for bind key: {bind_key}[/]")
ctx.exit(1) # noqa: RET503

@database_group.command(
name="show-current-revision",
help="Shows the current revision for the database.",
)
@click.option("--verbose", type=bool, help="Enable verbose output.", default=False, is_flag=True)
@click.pass_context
def show_database_revision(ctx: click.Context, verbose: bool) -> None: # pyright: ignore[reportUnusedFunction]
@bind_key_option
@verbose_option
def show_database_revision(bind_key: str | None, verbose: bool) -> None: # pyright: ignore[reportUnusedFunction]
"""Show current database revision."""
from advanced_alchemy.alembic.commands import AlembicCommands

ctx = click.get_current_context()
console.rule("[yellow]Listing current revision[/]", align="left")
sqlalchemy_config = ctx.obj["configs"][0]
sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
alembic_commands.current(verbose=verbose)

@database_group.command(
name="downgrade",
help="Downgrade database to a specific revision.",
)
@bind_key_option
@click.option("--sql", type=bool, help="Generate SQL output for offline migrations.", default=False, is_flag=True)
@click.option(
"--tag",
help="an arbitrary 'tag' that can be intercepted by custom env.py scripts via the .EnvironmentContext.get_tag_argument method.",
type=str,
default=None,
)
@click.option(
"--no-prompt",
help="Do not prompt for confirmation before downgrading.",
type=bool,
default=False,
required=False,
show_default=True,
is_flag=True,
)
@no_prompt_option
@click.argument(
"revision",
type=str,
default="-1",
)
@click.pass_context
def downgrade_database(ctx: click.Context, revision: str, sql: bool, tag: str | None, no_prompt: bool) -> None: # pyright: ignore[reportUnusedFunction]
def downgrade_database( # pyright: ignore[reportUnusedFunction]
bind_key: str | None, revision: str, sql: bool, tag: str | None, no_prompt: bool
) -> None:
"""Downgrade the database to the latest revision."""
from rich.prompt import Confirm

from advanced_alchemy.alembic.commands import AlembicCommands

ctx = click.get_current_context()
console.rule("[yellow]Starting database downgrade process[/]", align="left")
input_confirmed = (
True
if no_prompt
else Confirm.ask(f"Are you sure you want to downgrade the database to the `{revision}` revision?")
)
if input_confirmed:
sqlalchemy_config = ctx.obj["configs"][0]
sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
alembic_commands.downgrade(revision=revision, sql=sql, tag=tag)

@database_group.command(
name="upgrade",
help="Upgrade database to a specific revision.",
)
@bind_key_option
@click.option("--sql", type=bool, help="Generate SQL output for offline migrations.", default=False, is_flag=True)
@click.option(
"--tag",
help="an arbitrary 'tag' that can be intercepted by custom env.py scripts via the .EnvironmentContext.get_tag_argument method.",
type=str,
default=None,
)
@click.option(
"--no-prompt",
help="Do not prompt for confirmation before upgrading.",
type=bool,
default=False,
required=False,
show_default=True,
is_flag=True,
)
@no_prompt_option
@click.argument(
"revision",
type=str,
default="head",
)
@click.pass_context
def upgrade_database(ctx: click.Context, revision: str, sql: bool, tag: str | None, no_prompt: bool) -> None: # pyright: ignore[reportUnusedFunction]
def upgrade_database( # pyright: ignore[reportUnusedFunction]
bind_key: str | None, revision: str, sql: bool, tag: str | None, no_prompt: bool
) -> None:
"""Upgrade the database to the latest revision."""
from rich.prompt import Confirm

from advanced_alchemy.alembic.commands import AlembicCommands

ctx = click.get_current_context()
console.rule("[yellow]Starting database upgrade process[/]", align="left")
input_confirmed = (
True
if no_prompt
else Confirm.ask(f"[bold]Are you sure you want migrate the database to the `{revision}` revision?[/]")
)
if input_confirmed:
sqlalchemy_config = ctx.obj["configs"][0]
sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
alembic_commands.upgrade(revision=revision, sql=sql, tag=tag)

@database_group.command(
name="init",
help="Initialize migrations for the project.",
)
@bind_key_option
@click.argument("directory", default=None)
@click.option("--multidb", is_flag=True, default=False, help="Support multiple databases")
@click.option("--package", is_flag=True, default=True, help="Create `__init__.py` for created folder")
@click.option(
"--no-prompt",
help="Do not prompt for confirmation before initializing.",
type=bool,
default=False,
required=False,
show_default=True,
is_flag=True,
)
@click.pass_context
def init_alembic(ctx: click.Context, directory: str | None, multidb: bool, package: bool, no_prompt: bool) -> None: # pyright: ignore[reportUnusedFunction]
@no_prompt_option
def init_alembic( # pyright: ignore[reportUnusedFunction]
bind_key: str | None, directory: str | None, multidb: bool, package: bool, no_prompt: bool
) -> None:
"""Initialize the database migrations."""
from rich.prompt import Confirm

from advanced_alchemy.alembic.commands import AlembicCommands

ctx = click.get_current_context()
console.rule("[yellow]Initializing database migrations.", align="left")
input_confirmed = (
True
if no_prompt
else Confirm.ask(f"[bold]Are you sure you want initialize the project in `{directory}`?[/]")
)
if input_confirmed:
for config in ctx.obj["configs"]:
configs = [get_config_by_bind_key(ctx, bind_key)] if bind_key is not None else ctx.obj["configs"]
for config in configs:
directory = config.alembic_config.script_location if directory is None else directory
alembic_commands = AlembicCommands(sqlalchemy_config=config)
alembic_commands.init(directory=cast("str", directory), multidb=multidb, package=package)
Expand All @@ -212,6 +237,7 @@ def init_alembic(ctx: click.Context, directory: str | None, multidb: bool, packa
name="make-migrations",
help="Create a new migration revision.",
)
@bind_key_option
@click.option("-m", "--message", default=None, help="Revision message")
@click.option(
"--autogenerate/--no-autogenerate", default=True, help="Automatically populate revision with detected changes"
Expand All @@ -224,18 +250,9 @@ def init_alembic(ctx: click.Context, directory: str | None, multidb: bool, packa
@click.option("--branch-label", default=None, help="Specify a branch label to apply to the new revision")
@click.option("--version-path", default=None, help="Specify specific path from config for version file")
@click.option("--rev-id", default=None, help="Specify a ID to use for revision.")
@click.option(
"--no-prompt",
help="Do not prompt for a migration message.",
type=bool,
default=False,
required=False,
show_default=True,
is_flag=True,
)
@click.pass_context
@no_prompt_option
def create_revision( # pyright: ignore[reportUnusedFunction]
ctx: click.Context,
bind_key: str | None,
message: str | None,
autogenerate: bool,
sql: bool,
Expand Down Expand Up @@ -275,11 +292,12 @@ def process_revision_directives(
)
directives.clear()

ctx = click.get_current_context()
console.rule("[yellow]Starting database upgrade process[/]", align="left")
if message is None:
message = "autogenerated" if no_prompt else Prompt.ask("Please enter a message describing this revision")

sqlalchemy_config = ctx.obj["configs"][0]
sqlalchemy_config = get_config_by_bind_key(ctx, bind_key)
alembic_commands = AlembicCommands(sqlalchemy_config=sqlalchemy_config)
alembic_commands.revision(
message=message,
Expand All @@ -294,24 +312,17 @@ def process_revision_directives(
)

@database_group.command(name="drop-all", help="Drop all tables from the database.")
@click.option(
"--no-prompt",
help="Do not prompt for confirmation before upgrading.",
type=bool,
default=False,
required=False,
show_default=True,
is_flag=True,
)
@click.pass_context
def drop_all(ctx: click.Context, no_prompt: bool) -> None: # pyright: ignore[reportUnusedFunction]
@bind_key_option
@no_prompt_option
def drop_all(bind_key: str | None, no_prompt: bool) -> None: # pyright: ignore[reportUnusedFunction]
"""Drop all tables from the database."""
from anyio import run
from rich.prompt import Confirm

from advanced_alchemy.alembic.utils import drop_all
from advanced_alchemy.base import metadata_registry

ctx = click.get_current_context()
console.rule("[yellow]Dropping all tables from the database[/]", align="left")
input_confirmed = no_prompt or Confirm.ask(
"[bold red]Are you sure you want to drop all tables from the database?"
Expand All @@ -325,9 +336,11 @@ async def _drop_all(
await drop_all(engine, config.alembic_config.version_table_name, metadata_registry.get(config.bind_key))

if input_confirmed:
run(_drop_all, ctx.obj["configs"])
configs = [get_config_by_bind_key(ctx, bind_key)] if bind_key is not None else ctx.obj["configs"]
run(_drop_all, configs)

@database_group.command(name="dump-data", help="Dump specified tables from the database to JSON files.")
@bind_key_option
@click.option(
"--table",
"table_names",
Expand All @@ -344,15 +357,15 @@ async def _drop_all(
default=Path.cwd() / "fixtures",
required=False,
)
@click.pass_context
def dump_table_data(ctx: click.Context, table_names: tuple[str, ...], dump_dir: Path) -> None: # pyright: ignore[reportUnusedFunction]
def dump_table_data(bind_key: str | None, table_names: tuple[str, ...], dump_dir: Path) -> None: # pyright: ignore[reportUnusedFunction]
"""Dump table data to JSON files."""
from anyio import run
from rich.prompt import Confirm

from advanced_alchemy.alembic.utils import dump_tables
from advanced_alchemy.base import metadata_registry, orm_registry

ctx = click.get_current_context()
all_tables = "*" in table_names

if all_tables and not Confirm.ask(
Expand All @@ -361,7 +374,8 @@ def dump_table_data(ctx: click.Context, table_names: tuple[str, ...], dump_dir:
return console.rule("[red bold]No data was dumped.", style="red", align="left")

async def _dump_tables() -> None:
for config in ctx.obj["configs"]:
configs = [get_config_by_bind_key(ctx, bind_key)] if bind_key is not None else ctx.obj["configs"]
for config in configs:
target_tables = set(metadata_registry.get(config.bind_key).tables)

if not all_tables:
Expand Down
16 changes: 6 additions & 10 deletions advanced_alchemy/extensions/litestar/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
from contextlib import suppress
from typing import TYPE_CHECKING

from litestar.cli._utils import LitestarGroup

from advanced_alchemy.cli import add_migration_commands

try:
import rich_click as click
except ImportError:
import click # type: ignore[no-redef]
from litestar.cli._utils import LitestarGroup

from advanced_alchemy.cli import add_migration_commands

if TYPE_CHECKING:
from litestar import Litestar
Expand All @@ -18,11 +19,7 @@


def get_database_migration_plugin(app: Litestar) -> SQLAlchemyInitPlugin:
"""Retrieve a database migration plugin from the Litestar application's plugins.
This function attempts to find and return either the SQLAlchemyPlugin or SQLAlchemyInitPlugin.
If neither plugin is found, it raises an ImproperlyConfiguredException.
"""
"""Retrieve a database migration plugin from the Litestar application's plugins."""
from advanced_alchemy.exceptions import ImproperConfigurationError
from advanced_alchemy.extensions.litestar.plugins import SQLAlchemyInitPlugin

Expand All @@ -33,10 +30,9 @@ def get_database_migration_plugin(app: Litestar) -> SQLAlchemyInitPlugin:


@click.group(cls=LitestarGroup, name="database")
@click.pass_context
def database_group(ctx: click.Context) -> None:
"""Manage SQLAlchemy database components."""
ctx.obj = get_database_migration_plugin(ctx.obj.app).config
ctx.obj = {"app": ctx.obj, "configs": get_database_migration_plugin(ctx.obj.app).config}


add_migration_commands(database_group)
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ sqlite = ["aiosqlite>=0.20.0"]
test = [
"pydantic-extra-types < 2.9.0; python_version < \"3.9\"",
"pydantic-extra-types; python_version >= \"3.9\"",
"rich-click",
"coverage>=7.6.1",
"pytest>=7.4.4",
"pytest-asyncio>=0.23.8",
Expand Down Expand Up @@ -468,6 +469,7 @@ exclude = [
include = ["advanced_alchemy"]
pythonVersion = "3.8"
reportUnnecessaryTypeIgnoreComments = true
reportUnusedFunction = false
strict = ["advanced_alchemy/**/*"]
venv = ".venv"
venvPath = "."
Expand Down
Loading

0 comments on commit ed9f5cd

Please sign in to comment.