Skip to content

Commit

Permalink
fix: Apply image status-based filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
jopemachine committed Feb 11, 2025
1 parent 574e5d7 commit f1da0e0
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 38 deletions.
25 changes: 15 additions & 10 deletions src/ai/backend/manager/api/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from ai.backend.common.docker import ImageRef
from ai.backend.manager.models.container_registry import ContainerRegistryRow
from ai.backend.manager.models.group import GroupRow
from ai.backend.manager.models.image import ImageIdentifier, rescan_images
from ai.backend.manager.models.image import ImageIdentifier, ImageStatus, rescan_images

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection
Expand Down Expand Up @@ -1259,6 +1259,7 @@ async def _commit_and_upload(reporter: ProgressReporter) -> None:
== f"{params.image_visibility.value}:{image_owner_id}"
)
)
.where(ImageRow.status == ImageStatus.ALIVE)
)
existing_image_count = await sess.scalar(query)

Expand All @@ -1275,16 +1276,20 @@ async def _commit_and_upload(reporter: ProgressReporter) -> None:
)

# check if an image with the same name exists and reuse its ID if it does
query = sa.select(ImageRow).where(
ImageRow.name.like(f"{new_canonical}%")
& (
ImageRow.labels["ai.backend.customized-image.owner"].as_string()
== f"{params.image_visibility.value}:{image_owner_id}"
)
& (
ImageRow.labels["ai.backend.customized-image.name"].as_string()
== params.image_name
query = (
sa.select(ImageRow)
.where(
ImageRow.name.like(f"{new_canonical}%")
& (
ImageRow.labels["ai.backend.customized-image.owner"].as_string()
== f"{params.image_visibility.value}:{image_owner_id}"
)
& (
ImageRow.labels["ai.backend.customized-image.name"].as_string()
== params.image_name
)
)
.where(ImageRow.status == ImageStatus.ALIVE)
)
existing_row = await sess.scalar(query)

Expand Down
23 changes: 15 additions & 8 deletions src/ai/backend/manager/cli/image_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ai.backend.common.types import ImageAlias
from ai.backend.logging import BraceStyleAdapter

from ..models.image import ImageAliasRow, ImageIdentifier, ImageRow
from ..models.image import ImageAliasRow, ImageIdentifier, ImageRow, ImageStatus
from ..models.image import rescan_images as rescan_images_func
from ..models.utils import connect_database
from .context import CLIContext, redis_ctx
Expand All @@ -33,6 +33,8 @@ async def list_images(cli_ctx, short, installed_only):
):
displayed_items = []
try:
# TODO QUESTION: Should we display deleted images here?
# Idea: Add a `deleted` option to include deleted images.
items = await ImageRow.list(session)
# NOTE: installed/installed_agents fields are no longer provided in CLI,
# until we finish the epic refactoring of image metadata db.
Expand Down Expand Up @@ -228,20 +230,25 @@ async def validate_image_canonical(
if current or architecture is not None:
if current:
architecture = architecture or CURRENT_ARCH
image_row = await session.scalar(
sa.select(ImageRow).where(
(ImageRow.name == canonical) & (ImageRow.architecture == architecture)
)

# TODO QUESTION: Should we use deleted images here?
assert architecture is not None
image_row = await ImageRow.resolve(
session, [ImageIdentifier(canonical, architecture)]
)
if image_row is None:
raise UnknownImageReference(f"{canonical}/{architecture}")

for key, value in validate_image_labels(image_row.labels).items():
print(f"{key:<40}: ", end="")
if isinstance(value, list):
value = f"{', '.join(value)}"
print(value)
else:
rows = await session.scalars(sa.select(ImageRow).where(ImageRow.name == canonical))
# TODO QUESTION: Should we use deleted images here?
rows = await session.scalars(
sa.select(ImageRow)
.where(ImageRow.name == canonical)
.where(ImageRow.status == ImageStatus.ALIVE)
)
image_rows = rows.fetchall()
if not image_rows:
raise UnknownImageReference(f"{canonical}")
Expand Down
10 changes: 7 additions & 3 deletions src/ai/backend/manager/container_registry/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from ai.backend.manager.models.container_registry import ContainerRegistryRow

from ..defs import INTRINSIC_SLOTS_MIN
from ..models.image import ImageIdentifier, ImageRow, ImageType
from ..models.image import ImageIdentifier, ImageRow, ImageStatus, ImageType
from ..models.utils import ExtendedAsyncSAEngine

log = BraceStyleAdapter(logging.getLogger(__spec__.name))
Expand Down Expand Up @@ -128,10 +128,13 @@ async def commit_rescan_result(self) -> None:
else:
image_identifiers = [(k.canonical, k.architecture) for k in _all_updates.keys()]
async with self.db.begin_session() as session:
# TODO QUESTION: Should we filter out deleted images here?
existing_images = await session.scalars(
sa.select(ImageRow).where(
sa.select(ImageRow)
.where(
sa.func.ROW(ImageRow.name, ImageRow.architecture).in_(image_identifiers),
),
)
.where(ImageRow.status == ImageStatus.ALIVE),
)
is_local = self.registry_name == "local"

Expand Down Expand Up @@ -178,6 +181,7 @@ async def commit_rescan_result(self) -> None:
accelerators=update.get("accels"),
labels=update["labels"],
resources=update["resources"],
status=ImageStatus.ALIVE,
)
)
progress_msg = f"Updated image - {parsed_img.canonical}/{image_identifier.architecture} ({update['config_digest']})"
Expand Down
1 change: 1 addition & 0 deletions src/ai/backend/manager/container_registry/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ async def _read_image_info(
already_exists = 0
config_digest = data["Id"]
async with self.db.begin_readonly_session() as db_session:
# TODO QUESTION: Should we use deleted images here?
already_exists = await db_session.scalar(
sa.select([sa.func.count(ImageRow.id)]).where(
ImageRow.config_digest == config_digest,
Expand Down
28 changes: 18 additions & 10 deletions src/ai/backend/manager/models/gql_models/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ async def batch_load_by_name_and_arch(
graph_ctx: GraphQueryContext,
name_and_arch: Sequence[tuple[str, str]],
) -> Sequence[Sequence[ImageNode]]:
# TODO QUESTION: Should we filter out deleted images here?
query = (
sa.select(ImageRow)
.where(sa.tuple_(ImageRow.name, ImageRow.architecture).in_(name_and_arch))
Expand All @@ -377,6 +378,7 @@ async def batch_load_by_image_identifier(
graph_ctx: GraphQueryContext,
image_ids: Sequence[ImageIdentifier],
) -> Sequence[Sequence[ImageNode]]:
# TODO QUESTION: Should we filter out deleted images here?
name_and_arch_tuples = [(img.canonical, img.architecture) for img in image_ids]
return await cls.batch_load_by_name_and_arch(graph_ctx, name_and_arch_tuples)

Expand Down Expand Up @@ -421,6 +423,7 @@ def from_row(cls, row: ImageRow | None) -> ImageNode | None:
],
supported_accelerators=(row.accelerators or "").split(","),
aliases=[alias_row.alias for alias_row in row.aliases],
status=row.status,
)

@classmethod
Expand All @@ -445,13 +448,15 @@ def from_legacy_image(cls, row: Image) -> ImageNode:
resource_limits=row.resource_limits,
supported_accelerators=row.supported_accelerators,
aliases=row.aliases,
status=row.status,
)

@classmethod
async def get_node(cls, info: graphene.ResolveInfo, id: str) -> ImageNode:
graph_ctx: GraphQueryContext = info.context

_, image_id = AsyncNode.resolve_global_id(info, id)
# TODO QUESTION: Should we filter out deleted images here?
query = (
sa.select(ImageRow)
.where(ImageRow.id == image_id)
Expand Down Expand Up @@ -500,7 +505,9 @@ async def mutate(
client_role = ctx.user["role"]

async with ctx.db.begin_session() as session:
image_row = await ImageRow.get(session, _image_id, load_aliases=True)
image_row = await ImageRow.get(
session, _image_id, load_only_active=True, load_aliases=True
)
if not image_row:
raise ObjectNotFound("image")
if client_role != UserRole.SUPERADMIN:
Expand Down Expand Up @@ -650,7 +657,9 @@ async def mutate(
client_role = ctx.user["role"]

async with ctx.db.begin_session() as session:
image_row = await ImageRow.get(session, _image_id, load_aliases=True)
image_row = await ImageRow.get(
session, _image_id, load_only_active=True, load_aliases=True
)
if not image_row:
raise ObjectNotFound("image")
if client_role != UserRole.SUPERADMIN:
Expand Down Expand Up @@ -704,7 +713,9 @@ async def mutate(
client_role = ctx.user["role"]

async with ctx.db.begin_readonly_session() as session:
image_row = await ImageRow.get(session, _image_id, load_aliases=True)
image_row = await ImageRow.get(
session, _image_id, load_only_active=True, load_aliases=True
)
if not image_row:
raise ImageNotFound
if client_role != UserRole.SUPERADMIN:
Expand Down Expand Up @@ -889,15 +900,12 @@ async def mutate(
ctx: GraphQueryContext = info.context
try:
async with ctx.db.begin_session() as session:
result = await session.execute(
sa.select(ImageRow).where(ImageRow.registry == registry)
)
image_ids = [x.id for x in result.scalars().all()]

await session.execute(
sa.delete(ImageAliasRow).where(ImageAliasRow.image_id.in_(image_ids))
sa.update(ImageRow)
.where(ImageRow.registry == registry)
.where(ImageRow.status != ImageStatus.DELETED)
.values(status=ImageStatus.DELETED)
)
await session.execute(sa.delete(ImageRow).where(ImageRow.registry == registry))
except ValueError as e:
return ClearImages(ok=False, msg=str(e))
return ClearImages(ok=True, msg="")
Expand Down
53 changes: 46 additions & 7 deletions src/ai/backend/manager/models/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ def __init__(
accelerators=None,
labels=None,
resources=None,
status=ImageStatus.ALIVE,
) -> None:
self.name = name
self.project = project
Expand All @@ -394,6 +395,7 @@ def __init__(
self.accelerators = accelerators
self.labels = labels
self.resources = resources
self.status = status

@property
def trimmed_digest(self) -> str:
Expand All @@ -420,6 +422,7 @@ async def from_alias(
session: AsyncSession,
alias: str,
load_aliases: bool = False,
load_only_active: bool = True,
*,
loading_options: Iterable[RelationLoadingOption] = tuple(),
) -> ImageRow:
Expand All @@ -430,6 +433,8 @@ async def from_alias(
)
if load_aliases:
query = query.options(selectinload(ImageRow.aliases))
if load_only_active:
query = query.where(ImageRow.status == ImageStatus.ALIVE)
query = _apply_loading_option(query, loading_options)
result = await session.scalar(query)
if result is not None:
Expand All @@ -443,6 +448,7 @@ async def from_image_identifier(
session: AsyncSession,
identifier: ImageIdentifier,
load_aliases: bool = True,
load_only_active: bool = True,
*,
loading_options: Iterable[RelationLoadingOption] = tuple(),
) -> ImageRow:
Expand All @@ -453,6 +459,8 @@ async def from_image_identifier(

if load_aliases:
query = query.options(selectinload(ImageRow.aliases))
if load_only_active:
query = query.where(ImageRow.status == ImageStatus.ALIVE)
query = _apply_loading_option(query, loading_options)

result = await session.execute(query)
Expand All @@ -471,6 +479,7 @@ async def from_image_ref(
*,
strict_arch: bool = False,
load_aliases: bool = False,
load_only_active: bool = True,
loading_options: Iterable[RelationLoadingOption] = tuple(),
) -> ImageRow:
"""
Expand All @@ -483,6 +492,9 @@ async def from_image_ref(
query = sa.select(ImageRow).where(ImageRow.name == ref.canonical)
if load_aliases:
query = query.options(selectinload(ImageRow.aliases))
if load_only_active:
query = query.where(ImageRow.status == ImageStatus.ALIVE)

query = _apply_loading_option(query, loading_options)

result = await session.execute(query)
Expand All @@ -504,6 +516,7 @@ async def resolve(
reference_candidates: list[ImageAlias | ImageRef | ImageIdentifier],
*,
strict_arch: bool = False,
load_only_active: bool = True,
load_aliases: bool = True,
loading_options: Iterable[RelationLoadingOption] = tuple(),
) -> ImageRow:
Expand Down Expand Up @@ -554,7 +567,11 @@ async def resolve(
searched_refs.append(f"identifier:{reference!r}")
try:
if row := await resolver_func(
session, reference, load_aliases=load_aliases, loading_options=loading_options
session,
reference,
load_aliases=load_aliases,
load_only_active=load_only_active,
loading_options=loading_options,
):
return row
except UnknownImageReference:
Expand All @@ -563,19 +580,31 @@ async def resolve(

@classmethod
async def get(
cls, session: AsyncSession, image_id: UUID, load_aliases=False
cls,
session: AsyncSession,
image_id: UUID,
load_only_active: bool = True,
load_aliases: bool = False,
) -> ImageRow | None:
query = sa.select(ImageRow).where(ImageRow.id == image_id)
if load_aliases:
query = query.options(selectinload(ImageRow.aliases))
if load_only_active:
query = query.where(ImageRow.status == ImageStatus.ALIVE)

result = await session.execute(query)
return result.scalar()

@classmethod
async def list(cls, session: AsyncSession, load_aliases=False) -> List[ImageRow]:
async def list(
cls, session: AsyncSession, load_only_active: bool = True, load_aliases: bool = False
) -> List[ImageRow]:
query = sa.select(ImageRow)
if load_aliases:
query = query.options(selectinload(ImageRow.aliases))
if load_only_active:
query = query.where(ImageRow.status == ImageStatus.ALIVE)

result = await session.execute(query)
return result.scalars().all()

Expand Down Expand Up @@ -873,7 +902,10 @@ async def build_ctx_in_system_scope(
permissions = await self.calculate_permission(ctx, SystemScope())
image_id_permission_map: dict[UUID, frozenset[ImagePermission]] = {}

for image_row in await self.db_session.scalars(sa.select(ImageRow)):
# TODO QUESTION: Should we filter out deleted images here?
for image_row in await self.db_session.scalars(
sa.select(ImageRow).where(ImageRow.status == ImageStatus.ALIVE)
):
image_id_permission_map[image_row.id] = permissions
perm_ctx = ImagePermissionContext(
object_id_to_additional_permission_map=image_id_permission_map
Expand Down Expand Up @@ -909,7 +941,11 @@ async def _in_domain_scope(
raise InvalidScope(f"Domain not found (n:{scope.domain_name})")

allowed_registries: set[str] = set(domain_row.allowed_docker_registries)
_img_query_stmt = sa.select(ImageRow).options(load_only(ImageRow.id, ImageRow.registry))
_img_query_stmt = (
sa.select(ImageRow)
.where(ImageRow.status == ImageStatus.ALIVE)
.options(load_only(ImageRow.id, ImageRow.registry))
)
for row in await self.db_session.scalars(_img_query_stmt):
_row = cast(ImageRow, row)
if _row.registry in allowed_registries:
Expand Down Expand Up @@ -952,8 +988,11 @@ async def _in_user_scope(
permissions = await self.calculate_permission(ctx, scope)
image_id_permission_map: dict[UUID, frozenset[ImagePermission]] = {}
allowed_registries: set[str] = set(user_row.domain.allowed_docker_registries)
_img_query_stmt = sa.select(ImageRow).options(
load_only(ImageRow.id, ImageRow.labels, ImageRow.registry)
# TODO QUESTION: Should we filter out deleted images here?
_img_query_stmt = (
sa.select(ImageRow)
.where(ImageRow.status == ImageStatus.ALIVE)
.options(load_only(ImageRow.id, ImageRow.labels, ImageRow.registry))
)
for row in await self.db_session.scalars(_img_query_stmt):
_row = cast(ImageRow, row)
Expand Down

0 comments on commit f1da0e0

Please sign in to comment.