Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: Materialization of the "main" keypair concept #1761

Merged
merged 13 commits into from
Dec 7, 2023
1 change: 1 addition & 0 deletions changes/1761.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implement the concept of the "main" keypair to make it clear which keypair to use by default and which one holds the user-level resource limits
10 changes: 10 additions & 0 deletions src/ai/backend/client/cli/admin/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def info(ctx: CLIContext, email: str) -> None:
user_fields["groups"],
user_fields["allowed_client_ip"],
user_fields["sudo_session_enabled"],
user_fields["main_access_key"],
]
with Session() as session:
try:
Expand Down Expand Up @@ -148,6 +149,7 @@ def list(ctx: CLIContext, status, group, filter_, order, offset, limit) -> None:
user_fields["groups"],
user_fields["allowed_client_ip"],
user_fields["sudo_session_enabled"],
user_fields["main_access_key"],
]
try:
with Session() as session:
Expand Down Expand Up @@ -348,6 +350,12 @@ def add(
"Note that this feature does not automatically install sudo for the session."
),
)
@click.option(
"--main-access-key",
type=OptionalType(str),
default=undefined,
help="Set main access key which works as default.",
)
def update(
ctx: CLIContext,
email: str,
Expand All @@ -361,6 +369,7 @@ def update(
allowed_ip: Sequence[str] | Undefined,
description: str | Undefined,
sudo_session_enabled: bool | Undefined,
main_access_key: str | Undefined,
):
"""
Update an existing user.
Expand All @@ -382,6 +391,7 @@ def update(
allowed_client_ip=allowed_ip,
description=description,
sudo_session_enabled=sudo_session_enabled,
main_access_key=main_access_key,
)
except Exception as e:
ctx.output.print_mutation_error(
Expand Down
4 changes: 4 additions & 0 deletions src/ai/backend/client/func/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
user_fields["allowed_client_ip"],
user_fields["totp_activated"],
user_fields["sudo_session_enabled"],
user_fields["main_access_key"],
)

_default_detail_fields = (
Expand All @@ -52,6 +53,7 @@
user_fields["allowed_client_ip"],
user_fields["totp_activated"],
user_fields["sudo_session_enabled"],
user_fields["main_access_key"],
)


Expand Down Expand Up @@ -322,6 +324,7 @@ async def update(
totp_activated: bool | Undefined = undefined,
group_ids: Iterable[str] | Undefined = undefined,
sudo_session_enabled: bool | Undefined = undefined,
main_access_key: str | Undefined = undefined,
fields: Iterable[FieldSpec | str] | None = None,
) -> dict:
"""
Expand All @@ -348,6 +351,7 @@ async def update(
set_if_set(inputs, "totp_activated", totp_activated)
set_if_set(inputs, "group_ids", group_ids)
set_if_set(inputs, "sudo_session_enabled", sudo_session_enabled)
set_if_set(inputs, "main_access_key", main_access_key)
variables = {
"email": email,
"input": inputs,
Expand Down
1 change: 1 addition & 0 deletions src/ai/backend/client/output/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@
FieldSpec("allowed_client_ip"),
FieldSpec("totp_activated"),
FieldSpec("sudo_session_enabled"),
FieldSpec("main_access_key"),
]
)

Expand Down
3 changes: 3 additions & 0 deletions src/ai/backend/manager/defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,6 @@ class LockID(enum.IntEnum):


SERVICE_MAX_RETRIES = 5 # FIXME: make configurable

DEFAULT_KEYPAIR_RESOURCE_POLICY_NAME: Final = "default"
DEFAULT_KEYPAIR_RATE_LIMIT: Final = 10000
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
"""user_main_keypair

Revision ID: d3f8c74bf148
Revises: 308bcecec5c2
Create Date: 2023-12-06 12:20:11.537908

"""
import enum
from typing import Any

import sqlalchemy as sa
from alembic import op
from sqlalchemy.orm import registry, relationship, selectinload, sessionmaker

from ai.backend.manager.defs import DEFAULT_KEYPAIR_RATE_LIMIT, DEFAULT_KEYPAIR_RESOURCE_POLICY_NAME
from ai.backend.manager.models.base import GUID, EnumValueType, convention
from ai.backend.manager.models.keypair import generate_keypair, generate_ssh_keypair

# revision identifiers, used by Alembic.
revision = "d3f8c74bf148"
down_revision = "308bcecec5c2"
branch_labels = None
depends_on = None


# Standalone metadata / registry / declarative base for the table mappings
# that upgrade() declares locally. The naming convention is the one imported
# from ai.backend.manager.models.base.
metadata = sa.MetaData(naming_convention=convention)
mapper_registry = registry(metadata=metadata)
Base = mapper_registry.generate_base()

# Maximum number of user rows fetched per iteration of the backfill loop in upgrade().
PAGE_SIZE = 100


class UserRole(str, enum.Enum):
    """
    Role values stored in the ``users.role`` column.

    This is a copy local to the migration; upgrade() reads it to decide
    whether a newly generated keypair is flagged as an admin keypair.
    """

    # Members are plain strings because the enum also derives from ``str``.
    SUPERADMIN = "superadmin"
    ADMIN = "admin"
    USER = "user"
    MONITOR = "monitor"


# Length limit used for the keypairs.dotfiles and keypairs.bootstrap_script columns below.
MAXIMUM_DOTFILE_SIZE = 64 * 1024  # 64 KiB


def upgrade():
    """Add ``users.main_access_key`` and backfill it for every existing user.

    Steps:

    1. Add the nullable ``main_access_key`` column to ``users`` and a foreign
       key from it to ``keypairs.access_key``.
    2. Page through the users whose ``main_access_key`` is still NULL,
       ``PAGE_SIZE`` rows at a time. For each user the oldest keypair (by
       ``created_at``) becomes the main keypair; a user with no keypair at
       all gets a freshly generated one.
    """
    op.add_column("users", sa.Column("main_access_key", sa.String(length=20), nullable=True))
    op.create_foreign_key(
        op.f("fk_users_main_access_key_keypairs"),
        "users",
        "keypairs",
        ["main_access_key"],
        ["access_key"],
        # Keep the DDL in agreement with the ForeignKey declared on
        # UserRow.main_access_key below.
        ondelete="RESTRICT",
    )

    # Update all user's main_access_key
    # The oldest keypair of a user will be main_access_key
    class KeyPairRow(Base):  # type: ignore[valid-type, misc]
        __tablename__ = "keypairs"
        __table_args__ = {"extend_existing": True}

        user_id = sa.Column("user_id", sa.String(length=256))
        access_key = sa.Column("access_key", sa.String(length=20), primary_key=True)
        secret_key = sa.Column("secret_key", sa.String(length=40))
        is_active = sa.Column("is_active", sa.Boolean, index=True)
        is_admin = sa.Column(
            "is_admin", sa.Boolean, index=True, default=False, server_default=sa.false()
        )
        created_at = sa.Column("created_at", sa.DateTime(timezone=True))
        rate_limit = sa.Column("rate_limit", sa.Integer)
        num_queries = sa.Column("num_queries", sa.Integer, server_default="0")
        ssh_public_key = sa.Column("ssh_public_key", sa.Text, nullable=True)
        ssh_private_key = sa.Column("ssh_private_key", sa.Text, nullable=True)
        resource_policy = sa.Column(
            "resource_policy",
            sa.String(length=256),
            sa.ForeignKey("keypair_resource_policies.name"),
            nullable=False,
        )
        dotfiles = sa.Column(
            "dotfiles", sa.LargeBinary(length=MAXIMUM_DOTFILE_SIZE), nullable=False, default=b"\x90"
        )
        bootstrap_script = sa.Column(
            "bootstrap_script", sa.String(length=MAXIMUM_DOTFILE_SIZE), nullable=False, default=""
        )
        user = sa.Column("user", GUID, sa.ForeignKey("users.uuid"))
        user_row = relationship("UserRow", back_populates="keypairs", foreign_keys=user)

    class UserRow(Base):  # type: ignore[valid-type, misc]
        __tablename__ = "users"
        __table_args__ = {"extend_existing": True}

        uuid = sa.Column("uuid", GUID, primary_key=True)
        role = sa.Column("role", EnumValueType(UserRole), default=UserRole.USER)
        email = sa.Column("email", sa.String(length=64))
        main_access_key = sa.Column(
            "main_access_key",
            sa.String(length=20),
            sa.ForeignKey("keypairs.access_key", ondelete="RESTRICT"),
            nullable=True,
        )
        keypairs = relationship(
            "KeyPairRow", back_populates="user_row", foreign_keys=KeyPairRow.user
        )
        main_keypair = relationship("KeyPairRow", foreign_keys=main_access_key)

    def pick_main_keypair(keypair_list: list[KeyPairRow]) -> KeyPairRow | None:
        """Return the oldest keypair in *keypair_list*, or None when it is empty."""
        # The (is-None, value) key orders rows without a creation timestamp
        # after every dated row, so a NULL created_at is never compared
        # against a datetime (which would raise TypeError).
        return min(
            keypair_list,
            key=lambda k: (k.created_at is None, k.created_at),
            default=None,
        )

    def prepare_keypair(
        user_email,
        user_id,
        user_role,
    ) -> dict[str, Any]:
        """Build the column values of a brand-new keypair for the given user."""
        ak, sk = generate_keypair()
        pubkey, privkey = generate_ssh_keypair()
        return {
            "user_id": user_email,
            "user": user_id,
            "access_key": ak,
            "secret_key": sk,
            "is_active": True,
            "is_admin": user_role in (UserRole.SUPERADMIN, UserRole.ADMIN),
            "resource_policy": DEFAULT_KEYPAIR_RESOURCE_POLICY_NAME,
            "rate_limit": DEFAULT_KEYPAIR_RATE_LIMIT,
            "num_queries": 0,
            "ssh_public_key": pubkey,
            "ssh_private_key": privkey,
        }

    connection = op.get_bind()
    sess_factory = sessionmaker(connection)
    # NOTE(review): the session is never committed or closed explicitly;
    # presumably the statements ride on the transaction Alembic manages for
    # this connection -- confirm before changing the session handling.
    db_session = sess_factory()
    while True:
        user_id_pk_maps = []
        # Every processed user ends up with a non-NULL main_access_key, so
        # re-running the same filtered query walks through all users without
        # needing an offset.
        user_query = (
            sa.select(UserRow)
            .where(UserRow.main_access_key.is_(sa.null()))
            .limit(PAGE_SIZE)
            .options(selectinload(UserRow.keypairs))
        )
        user_rows: list[UserRow] = db_session.scalars(user_query).all()

        if not user_rows:
            break

        for row in user_rows:
            primary = pick_main_keypair(row.keypairs)
            if primary is None:
                # Create new keypair when the user has no keypair
                kp_data = prepare_keypair(row.email, row.uuid, row.role)
                db_session.execute(sa.insert(KeyPairRow).values(**kp_data))
                main_access_key = kp_data["access_key"]
            else:
                main_access_key = primary.access_key
            user_id_pk_maps.append({"user_id": row.uuid, "main_access_key": main_access_key})

        # One parameterized UPDATE executed with the whole page of
        # (user_id, main_access_key) parameter sets.
        update_query = (
            sa.update(UserRow)
            .where(UserRow.uuid == sa.bindparam("user_id"))
            .values(main_access_key=sa.bindparam("main_access_key"))
        )
        db_session.execute(update_query, user_id_pk_maps)


def downgrade():
    """Revert upgrade(): drop the foreign key first, then the column it constrains."""
    fk_name = op.f("fk_users_main_access_key_keypairs")
    op.drop_constraint(fk_name, "users", type_="foreignkey")
    op.drop_column("users", "main_access_key")
14 changes: 9 additions & 5 deletions src/ai/backend/manager/models/keypair.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ class KeyPairRow(Base):
back_populates="keypairs",
)

user_row = relationship("UserRow", back_populates="keypairs", foreign_keys=keypairs.c.user)


class UserInfo(graphene.ObjectType):
email = graphene.String()
Expand Down Expand Up @@ -630,11 +632,13 @@ async def mutate(
) -> DeleteKeyPair:
ctx: GraphQueryContext = info.context
delete_query = sa.delete(keypairs).where(keypairs.c.access_key == access_key)
await redis_helper.execute(
ctx.redis_stat,
lambda r: r.delete(f"keypair.concurrency_used.{access_key}"),
)
return await simple_db_mutate(cls, ctx, delete_query)
result = await simple_db_mutate(cls, ctx, delete_query)
if result.ok:
await redis_helper.execute(
ctx.redis_stat,
lambda r: r.delete(f"keypair.concurrency_used.{access_key}"),
)
return result


class Dotfile(TypedDict):
Expand Down
Loading
Loading