diff --git a/changes/1761.feature.md b/changes/1761.feature.md new file mode 100644 index 00000000000..62188a87a30 --- /dev/null +++ b/changes/1761.feature.md @@ -0,0 +1 @@ +Implement the concept of the "main" keypair to make it clear which keypair to use by default and which one holds the user-level resource limits diff --git a/src/ai/backend/client/cli/admin/user.py b/src/ai/backend/client/cli/admin/user.py index 802c39b09cf..2063aecea8e 100644 --- a/src/ai/backend/client/cli/admin/user.py +++ b/src/ai/backend/client/cli/admin/user.py @@ -47,6 +47,7 @@ def info(ctx: CLIContext, email: str) -> None: user_fields["groups"], user_fields["allowed_client_ip"], user_fields["sudo_session_enabled"], + user_fields["main_access_key"], ] with Session() as session: try: @@ -148,6 +149,7 @@ def list(ctx: CLIContext, status, group, filter_, order, offset, limit) -> None: user_fields["groups"], user_fields["allowed_client_ip"], user_fields["sudo_session_enabled"], + user_fields["main_access_key"], ] try: with Session() as session: @@ -348,6 +350,12 @@ def add( "Note that this feature does not automatically install sudo for the session." ), ) +@click.option( + "--main-access-key", + type=OptionalType(str), + default=undefined, + help="Set main access key which works as default.", +) def update( ctx: CLIContext, email: str, @@ -361,6 +369,7 @@ def update( allowed_ip: Sequence[str] | Undefined, description: str | Undefined, sudo_session_enabled: bool | Undefined, + main_access_key: str | Undefined, ): """ Update an existing user. 
@@ -382,6 +391,7 @@ def update( allowed_client_ip=allowed_ip, description=description, sudo_session_enabled=sudo_session_enabled, + main_access_key=main_access_key, ) except Exception as e: ctx.output.print_mutation_error( diff --git a/src/ai/backend/client/func/user.py b/src/ai/backend/client/func/user.py index 942bba75e02..bd4eb1a5a1b 100644 --- a/src/ai/backend/client/func/user.py +++ b/src/ai/backend/client/func/user.py @@ -36,6 +36,7 @@ user_fields["allowed_client_ip"], user_fields["totp_activated"], user_fields["sudo_session_enabled"], + user_fields["main_access_key"], ) _default_detail_fields = ( @@ -52,6 +53,7 @@ user_fields["allowed_client_ip"], user_fields["totp_activated"], user_fields["sudo_session_enabled"], + user_fields["main_access_key"], ) @@ -322,6 +324,7 @@ async def update( totp_activated: bool | Undefined = undefined, group_ids: Iterable[str] | Undefined = undefined, sudo_session_enabled: bool | Undefined = undefined, + main_access_key: str | Undefined = undefined, fields: Iterable[FieldSpec | str] | None = None, ) -> dict: """ @@ -348,6 +351,7 @@ async def update( set_if_set(inputs, "totp_activated", totp_activated) set_if_set(inputs, "group_ids", group_ids) set_if_set(inputs, "sudo_session_enabled", sudo_session_enabled) + set_if_set(inputs, "main_access_key", main_access_key) variables = { "email": email, "input": inputs, diff --git a/src/ai/backend/client/output/fields.py b/src/ai/backend/client/output/fields.py index 95f96e43cf3..4b5219fdc69 100644 --- a/src/ai/backend/client/output/fields.py +++ b/src/ai/backend/client/output/fields.py @@ -278,6 +278,7 @@ FieldSpec("allowed_client_ip"), FieldSpec("totp_activated"), FieldSpec("sudo_session_enabled"), + FieldSpec("main_access_key"), ] ) diff --git a/src/ai/backend/manager/defs.py b/src/ai/backend/manager/defs.py index f1de1723e38..c5558c9d890 100644 --- a/src/ai/backend/manager/defs.py +++ b/src/ai/backend/manager/defs.py @@ -77,3 +77,6 @@ class LockID(enum.IntEnum): SERVICE_MAX_RETRIES = 5 
"""user_main_keypair

Revision ID: d3f8c74bf148
Revises: 308bcecec5c2
Create Date: 2023-12-06 12:20:11.537908

"""
import enum
from typing import Any

import sqlalchemy as sa
from alembic import op
from sqlalchemy.orm import registry, relationship, selectinload, sessionmaker

from ai.backend.manager.defs import DEFAULT_KEYPAIR_RATE_LIMIT, DEFAULT_KEYPAIR_RESOURCE_POLICY_NAME
from ai.backend.manager.models.base import GUID, EnumValueType, convention
from ai.backend.manager.models.keypair import generate_keypair, generate_ssh_keypair

# revision identifiers, used by Alembic.
revision = "d3f8c74bf148"
down_revision = "308bcecec5c2"
branch_labels = None
depends_on = None


metadata = sa.MetaData(naming_convention=convention)
mapper_registry = registry(metadata=metadata)
Base = mapper_registry.generate_base()

# Number of users processed per pass of the backfill loop in upgrade().
PAGE_SIZE = 100


class UserRole(str, enum.Enum):
    """
    User's role.
    """

    SUPERADMIN = "superadmin"
    ADMIN = "admin"
    USER = "user"
    MONITOR = "monitor"


MAXIMUM_DOTFILE_SIZE = 64 * 1024  # 64 KiB


def upgrade():
    """
    Add ``users.main_access_key`` and backfill it for every existing user.

    The oldest keypair of each user becomes the main keypair.  A user who has
    no keypair at all gets a freshly generated one, which is then set as the
    main keypair.
    """
    op.add_column("users", sa.Column("main_access_key", sa.String(length=20), nullable=True))
    op.create_foreign_key(
        op.f("fk_users_main_access_key_keypairs"),
        "users",
        "keypairs",
        ["main_access_key"],
        ["access_key"],
        # Keep the DDL in sync with the ForeignKey declared on the
        # main_access_key column of the UserRow model below.
        ondelete="RESTRICT",
    )

    # Update all user's main_access_key
    # The oldest keypair of a user will be main_access_key
    class KeyPairRow(Base):  # type: ignore[valid-type, misc]
        __tablename__ = "keypairs"
        __table_args__ = {"extend_existing": True}

        user_id = sa.Column("user_id", sa.String(length=256))
        access_key = sa.Column("access_key", sa.String(length=20), primary_key=True)
        secret_key = sa.Column("secret_key", sa.String(length=40))
        is_active = sa.Column("is_active", sa.Boolean, index=True)
        is_admin = sa.Column(
            "is_admin", sa.Boolean, index=True, default=False, server_default=sa.false()
        )
        created_at = sa.Column("created_at", sa.DateTime(timezone=True))
        rate_limit = sa.Column("rate_limit", sa.Integer)
        num_queries = sa.Column("num_queries", sa.Integer, server_default="0")
        ssh_public_key = sa.Column("ssh_public_key", sa.Text, nullable=True)
        ssh_private_key = sa.Column("ssh_private_key", sa.Text, nullable=True)
        resource_policy = sa.Column(
            "resource_policy",
            sa.String(length=256),
            sa.ForeignKey("keypair_resource_policies.name"),
            nullable=False,
        )
        dotfiles = sa.Column(
            "dotfiles", sa.LargeBinary(length=MAXIMUM_DOTFILE_SIZE), nullable=False, default=b"\x90"
        )
        bootstrap_script = sa.Column(
            "bootstrap_script", sa.String(length=MAXIMUM_DOTFILE_SIZE), nullable=False, default=""
        )
        user = sa.Column("user", GUID, sa.ForeignKey("users.uuid"))
        user_row = relationship("UserRow", back_populates="keypairs", foreign_keys=user)

    class UserRow(Base):  # type: ignore[valid-type, misc]
        __tablename__ = "users"
        __table_args__ = {"extend_existing": True}

        uuid = sa.Column("uuid", GUID, primary_key=True)
        role = sa.Column("role", EnumValueType(UserRole), default=UserRole.USER)
        email = sa.Column("email", sa.String(length=64))
        main_access_key = sa.Column(
            "main_access_key",
            sa.String(length=20),
            sa.ForeignKey("keypairs.access_key", ondelete="RESTRICT"),
            nullable=True,
        )
        keypairs = relationship(
            "KeyPairRow", back_populates="user_row", foreign_keys=KeyPairRow.user
        )
        main_keypair = relationship("KeyPairRow", foreign_keys=main_access_key)

    def pick_main_keypair(keypair_list: list[KeyPairRow]) -> KeyPairRow | None:
        """Return the oldest keypair in the list, or None when the list is empty."""
        # created_at is declared nullable above; ordering by the tuple puts
        # rows with a NULL timestamp last instead of raising TypeError when
        # None is compared with a datetime.
        return min(
            keypair_list,
            key=lambda k: (k.created_at is None, k.created_at),
            default=None,
        )

    def prepare_keypair(
        user_email,
        user_id,
        user_role,
    ) -> dict[str, Any]:
        """Build the column values of a new, active keypair for the given user."""
        ak, sk = generate_keypair()
        pubkey, privkey = generate_ssh_keypair()
        return {
            "user_id": user_email,
            "user": user_id,
            "access_key": ak,
            "secret_key": sk,
            "is_active": True,
            "is_admin": user_role in (UserRole.SUPERADMIN, UserRole.ADMIN),
            "resource_policy": DEFAULT_KEYPAIR_RESOURCE_POLICY_NAME,
            "rate_limit": DEFAULT_KEYPAIR_RATE_LIMIT,
            "num_queries": 0,
            "ssh_public_key": pubkey,
            "ssh_private_key": privkey,
        }

    # The ORM session is bound to the connection returned by op.get_bind().
    connection = op.get_bind()
    sess_factory = sessionmaker(connection)
    db_session = sess_factory()
    # Every pass assigns a main_access_key to each fetched user, so the
    # "main_access_key IS NULL" filter returns a new page each time and the
    # loop ends once no such user is left.
    while True:
        user_id_kp_maps = []
        user_query = (
            sa.select(UserRow)
            .where(UserRow.main_access_key.is_(sa.null()))
            .limit(PAGE_SIZE)
            .options(selectinload(UserRow.keypairs))
        )
        user_rows: list[UserRow] = db_session.scalars(user_query).all()

        if not user_rows:
            break

        for row in user_rows:
            main_kp = pick_main_keypair(row.keypairs)
            if main_kp is None:
                # Create new keypair when the user has no keypair
                kp_data = prepare_keypair(row.email, row.uuid, row.role)
                db_session.execute(sa.insert(KeyPairRow).values(**kp_data))
                user_id_kp_maps.append(
                    {"user_id": row.uuid, "main_access_key": kp_data["access_key"]}
                )
            else:
                user_id_kp_maps.append({"user_id": row.uuid, "main_access_key": main_kp.access_key})

        # One UPDATE statement executed with a list of parameter sets,
        # one set per user of this page.
        update_query = (
            sa.update(UserRow)
            .where(UserRow.uuid == sa.bindparam("user_id"))
            .values(main_access_key=sa.bindparam("main_access_key"))
        )
        db_session.execute(update_query, user_id_kp_maps)


def downgrade():
    """
    Drop the foreign key and the ``users.main_access_key`` column.

    Keypairs that upgrade() created for users without any keypair are not
    removed here.
    """
    op.drop_constraint(op.f("fk_users_main_access_key_keypairs"), "users", type_="foreignkey")
    op.drop_column("users", "main_access_key")
+from sqlalchemy.orm import joinedload, load_only, noload, relationship from sqlalchemy.sql.expression import bindparam from sqlalchemy.types import VARCHAR, TypeDecorator @@ -26,6 +27,7 @@ from ai.backend.common.types import RedisConnectionInfo, VFolderID from ..api.exceptions import VFolderOperationFailed +from ..defs import DEFAULT_KEYPAIR_RATE_LIMIT, DEFAULT_KEYPAIR_RESOURCE_POLICY_NAME from .base import ( Base, EnumValueType, @@ -153,15 +155,26 @@ class UserStatus(str, enum.Enum): default=False, nullable=False, ), + sa.Column( + "main_access_key", + sa.String(length=20), + sa.ForeignKey("keypairs.access_key", ondelete="RESTRICT"), + nullable=True, # keypairs.user is non-nullable + ), ) class UserRow(Base): __table__ = users + # from .keypair import KeyPairRow + sessions = relationship("SessionRow", back_populates="user") domain = relationship("DomainRow", back_populates="users") groups = relationship("AssocGroupUserRow", back_populates="user") resource_policy_row = relationship("UserResourcePolicyRow", back_populates="users") + keypairs = relationship("KeyPairRow", back_populates="user_row", foreign_keys="KeyPairRow.user") + + main_keypair = relationship("KeyPairRow", foreign_keys=users.c.main_access_key) class UserGroup(graphene.ObjectType): @@ -221,6 +234,13 @@ class Meta: totp_activated = graphene.Boolean() totp_activated_at = GQLDateTime() sudo_session_enabled = graphene.Boolean() + main_access_key = graphene.String( + description=( + "Added in 24.03.0. Used as the default authentication credential for password-based" + " logins and sets the user's total resource usage limit. User's main_access_key cannot" + " be deleted, and only super-admin can replace main_access_key." 
+ ) + ) groups = graphene.List(lambda: UserGroup) @@ -259,6 +279,7 @@ def from_row( totp_activated=row["totp_activated"], totp_activated_at=row["totp_activated_at"], sudo_session_enabled=row["sudo_session_enabled"], + main_access_key=row["main_access_key"], ) @classmethod @@ -315,6 +336,7 @@ async def load_all( "totp_activated": ("totp_activated", None), "totp_activated_at": ("totp_activated_at", dtparse), "sudo_session_enabled": ("sudo_session_enabled", None), + "main_access_key": ("main_access_key", None), } _queryorder_colmap: Mapping[str, OrderSpecItem] = { @@ -334,6 +356,7 @@ async def load_all( "totp_activated": ("totp_activated", None), "totp_activated_at": ("totp_activated_at", None), "sudo_session_enabled": ("sudo_session_enabled", None), + "main_access_key": ("main_access_key", None), } @classmethod @@ -544,6 +567,7 @@ class ModifyUserInput(graphene.InputObjectType): totp_activated = graphene.Boolean(required=False, default=False) resource_policy = graphene.String(required=False) sudo_session_enabled = graphene.Boolean(required=False, default=False) + main_access_key = graphene.String(required=False) class PurgeUserInput(graphene.InputObjectType): @@ -610,8 +634,8 @@ async def _post_func(conn: SAConnection, result: Result) -> Row: { "is_active": _status == UserStatus.ACTIVE, "is_admin": user_data["role"] in [UserRole.SUPERADMIN, UserRole.ADMIN], - "resource_policy": "default", - "rate_limit": 10000, + "resource_policy": DEFAULT_KEYPAIR_RESOURCE_POLICY_NAME, + "rate_limit": DEFAULT_KEYPAIR_RATE_LIMIT, }, ) kp_insert_query = sa.insert(keypairs).values( @@ -619,6 +643,16 @@ async def _post_func(conn: SAConnection, result: Result) -> Row: user=created_user.uuid, ) await conn.execute(kp_insert_query) + + # Update user main_keypair + main_ak = kp_data["access_key"] + update_query = ( + sa.update(users) + .where(users.c.uuid == created_user.uuid) + .values(main_access_key=main_ak) + ) + await conn.execute(update_query) + model_store_query = 
sa.select([groups.c.id]).where( groups.c.type == ProjectType.MODEL_STORE ) @@ -671,6 +705,8 @@ async def mutate( email: str, props: ModifyUserInput, ) -> ModifyUser: + from .keypair import KeyPairRow + graph_ctx: GraphQueryContext = info.context data: Dict[str, Any] = {} set_if_set(props, data, "username") @@ -685,6 +721,7 @@ async def mutate( set_if_set(props, data, "totp_activated") set_if_set(props, data, "resource_policy") set_if_set(props, data, "sudo_session_enabled") + set_if_set(props, data, "main_access_key") if not data and not props.group_ids: return cls(ok=False, msg="nothing to update", user=None) if data.get("status") is None and props.is_active is not None: @@ -693,12 +730,13 @@ async def mutate( if data.get("password") is not None: data["password_changed_at"] = sa.func.now() + main_access_key: str | None = data.get("main_access_key") user_update_data: Dict[str, Any] = {} prev_domain_name: str prev_role: UserRole async def _pre_func(conn: SAConnection) -> None: - nonlocal user_update_data, prev_domain_name, prev_role + nonlocal user_update_data, prev_domain_name, prev_role, main_access_key result = await conn.execute( sa.select([users.c.domain_name, users.c.role, users.c.status]) .select_from(users) @@ -712,6 +750,28 @@ async def _pre_func(conn: SAConnection) -> None: user_update_data["status_info"] = ( "admin-requested" # user mutation is only for admin ) + if main_access_key is not None: + db_session = SASession(conn) + keypair_query = ( + sa.select(KeyPairRow) + .where(KeyPairRow.access_key == main_access_key) + .options( + noload("*"), + joinedload(KeyPairRow.user_row).options(load_only(UserRow.email)), + ) + ) + keypair_row: KeyPairRow | None = (await db_session.scalars(keypair_query)).first() + if keypair_row is None: + raise RuntimeError("Cannot set non-existing access key as the main access key.") + if keypair_row.user_row.email != email: + raise RuntimeError( + "Cannot set another user's access key as the main access key." 
async def get_user_occupancy(self, user_id, *, db_sess=None):
    """
    Sum the resource slots occupied by the given user's kernels.

    Only kernels whose status is in USER_RESOURCE_OCCUPYING_KERNEL_STATUSES
    and whose role is not in PRIVATE_KERNEL_ROLES are counted.  Slot types
    missing from the currently configured resource slots are dropped from
    the result.

    :param user_id: The value matched against ``KernelRow.user_uuid``.
    :param db_sess: An optional session passed on to ``reenter_txn_session()``.
    :return: The summed occupancy as a ``ResourceSlot``.
    """
    known_slot_types = await self.shared_config.get_resource_slots()

    async def _fetch_occupancy() -> ResourceSlot:
        async with reenter_txn_session(self.db, db_sess) as txn_sess:
            stmt = (
                sa.select(KernelRow.occupied_slots)
                .where(KernelRow.user_uuid == user_id)
                .where(KernelRow.status.in_(USER_RESOURCE_OCCUPYING_KERNEL_STATUSES))
                .where(KernelRow.role.not_in(PRIVATE_KERNEL_ROLES))
            )
            occupied = ResourceSlot()
            async for kernel in await txn_sess.stream(stmt):
                occupied = occupied + kernel.occupied_slots
            # drop no-longer used slot types
            return ResourceSlot(
                {
                    slot_name: amount
                    for slot_name, amount in occupied.items()
                    if slot_name in known_slot_types
                }
            )

    return await execute_with_retry(_fetch_occupancy)
async def check_user_resource_limit(
    db_sess: SASession,
    sched_ctx: SchedulingContext,
    sess_ctx: SessionRow,
) -> PredicateResult:
    """
    Check the session's request against the user-level resource limit.

    The limit comes from the keypair resource policy attached to the user's
    main keypair (``users.main_access_key``).  The predicate passes when the
    user's current occupancy plus the slots requested by the session stays
    within that limit.

    :param db_sess: The database session used for the policy lookup and
        passed on to the occupancy query.
    :param sched_ctx: Provides ``known_slot_types`` and the ``registry``.
    :param sess_ctx: The pending session; ``user_uuid`` and
        ``requested_slots`` are read from it.
    :return: A passing ``PredicateResult``, or a failing one carrying a
        message when the quota is exceeded or no policy can be resolved.
    """
    main_ak = (
        sa.select(UserRow.main_access_key)
        .where(UserRow.uuid == sess_ctx.user_uuid)
        .scalar_subquery()
    )
    resource_policy_q = sa.select(KeyPairRow.resource_policy).where(
        KeyPairRow.access_key == main_ak
    )
    select_query = sa.select(KeyPairResourcePolicyRow).where(
        KeyPairResourcePolicyRow.name == resource_policy_q.scalar_subquery()
    )
    resource_policy: KeyPairResourcePolicyRow | None = (
        await db_sess.scalars(select_query)
    ).first()
    if resource_policy is None:
        # users.main_access_key is nullable, so the lookup can come back
        # empty.  Report that explicitly instead of failing with an
        # AttributeError on the attribute accesses below.
        return PredicateResult(
            False,
            "Cannot check the user resource limit: the user has no main keypair"
            " or its keypair resource policy does not exist.",
        )

    resource_policy_map = {
        "total_resource_slots": resource_policy.total_resource_slots,
        "default_for_unspecified": resource_policy.default_for_unspecified,
    }
    total_main_keypair_allowed = ResourceSlot.from_policy(
        resource_policy_map, sched_ctx.known_slot_types
    )
    user_occupied = await sched_ctx.registry.get_user_occupancy(sess_ctx.user_uuid, db_sess=db_sess)
    log.debug("user:{} current-occupancy: {}", sess_ctx.user_uuid, user_occupied)
    log.debug("user:{} total-allowed: {}", sess_ctx.user_uuid, total_main_keypair_allowed)
    if not (user_occupied + sess_ctx.requested_slots <= total_main_keypair_allowed):
        return PredicateResult(
            False,
            "Your main-keypair resource quota is exceeded. ({})".format(
                " ".join(
                    f"{k}={v}"
                    for k, v in total_main_keypair_allowed.to_humanized(
                        sched_ctx.known_slot_types
                    ).items()
                )
            ),
        )
    return PredicateResult(True)