Skip to content

Commit

Permalink
feat: REST API v1 Artifacts and Models service layer
Browse files Browse the repository at this point in the history
Co-authored-by: James K. Glasbrenner <[email protected]>
Co-authored-by: Paul Scemama <[email protected]>
  • Loading branch information
3 people committed Jul 2, 2024
1 parent b8d013d commit 987818b
Show file tree
Hide file tree
Showing 21 changed files with 3,456 additions and 39 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,368 @@
"""Add ml_model_versions_resource
Revision ID: fd786b5377d6
Revises: d2bae5f6d991
Create Date: 2024-06-28 17:13:00.008695
"""

from typing import Annotated

import sqlalchemy as sa
from alembic import op
from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
MappedAsDataclass,
mapped_column,
sessionmaker,
)

# revision identifiers, used by Alembic.
# `revision` is this migration's own ID; `down_revision` is the migration it
# revises (see the "Revision ID" / "Revises" lines in the module docstring).
revision = "fd786b5377d6"
down_revision = "d2bae5f6d991"
# No branch labels and no dependencies on other revisions are declared.
branch_labels = None
depends_on = None


# Migration data models
# Reusable annotated column types for the migration-local ORM models below.

# Integer primary key: BigInteger everywhere except SQLite, where the Integer
# variant is used instead.
intpk = Annotated[
    int,
    mapped_column(sa.BigInteger().with_variant(sa.Integer, "sqlite"), primary_key=True),
]
# Same BigInteger/Integer variant as `intpk`, without the primary-key flag.
bigint = Annotated[
    int, mapped_column(sa.BigInteger().with_variant(sa.Integer, "sqlite"))
]
# Plain text column.
text_ = Annotated[str, mapped_column(sa.Text())]


class UpgradeBase(DeclarativeBase, MappedAsDataclass):
    """Declarative dataclass base for the models used by ``upgrade()``.

    Kept separate from ``DowngradeBase`` so that both sets of models can map
    the same table names without sharing a declarative base.
    """


class DowngradeBase(DeclarativeBase, MappedAsDataclass):
    """Declarative dataclass base for the models used by ``downgrade()``.

    Kept separate from ``UpgradeBase`` so that both sets of models can map
    the same table names without sharing a declarative base.
    """


class ResourceDependencyTypeUpgrade(UpgradeBase):
    """Upgrade-side mapping of the ``resource_dependency_types`` table.

    Each row is a (parent type, child type) pair; the two columns together
    form the primary key.
    """

    __tablename__ = "resource_dependency_types"

    # Both columns reference resource_types.resource_type.
    parent_resource_type: Mapped[text_] = mapped_column(
        sa.ForeignKey("resource_types.resource_type"),
        primary_key=True,
    )
    child_resource_type: Mapped[text_] = mapped_column(
        sa.ForeignKey("resource_types.resource_type"),
        primary_key=True,
    )


class ResourceTypeUpgrade(UpgradeBase):
    """Upgrade-side mapping of the ``resource_types`` table."""

    __tablename__ = "resource_types"

    # The type name itself is the primary key.
    resource_type: Mapped[text_] = mapped_column(primary_key=True)


class ResourceDependencyTypeDowngrade(DowngradeBase):
    """Downgrade-side mapping of the ``resource_dependency_types`` table.

    Each row is a (parent type, child type) pair; the two columns together
    form the primary key.
    """

    __tablename__ = "resource_dependency_types"

    # Both columns reference resource_types.resource_type.
    parent_resource_type: Mapped[text_] = mapped_column(
        sa.ForeignKey("resource_types.resource_type"),
        primary_key=True,
    )
    child_resource_type: Mapped[text_] = mapped_column(
        sa.ForeignKey("resource_types.resource_type"),
        primary_key=True,
    )


class ResourceTypeDowngrade(DowngradeBase):
    """Downgrade-side mapping of the ``resource_types`` table."""

    __tablename__ = "resource_types"

    # The type name itself is the primary key.
    resource_type: Mapped[text_] = mapped_column(primary_key=True)


class ResourceDependencyDowngrade(DowngradeBase):
    """Downgrade-side mapping of the ``resource_dependencies`` table.

    Only the columns that ``downgrade()`` filters on or needs as a key are
    mapped here.
    """

    __tablename__ = "resource_dependencies"

    # Composite primary key: (parent_resource_id, child_resource_id).
    parent_resource_id: Mapped[intpk]
    child_resource_id: Mapped[intpk]

    # Required type labels for each side of the dependency.
    parent_resource_type: Mapped[text_] = mapped_column(nullable=False)
    child_resource_type: Mapped[text_] = mapped_column(nullable=False)


class DraftResourceDowngrade(DowngradeBase):
    """Downgrade-side mapping of the ``draft_resources`` table."""

    __tablename__ = "draft_resources"

    # Database fields
    draft_resource_id: Mapped[intpk]
    # Required; references resource_types.resource_type.
    resource_type: Mapped[text_] = mapped_column(
        sa.ForeignKey("resource_types.resource_type"),
        nullable=False,
    )


class SharedResourceDowngrade(DowngradeBase):
    """Downgrade-side mapping of the ``shared_resources`` table."""

    __tablename__ = "shared_resources"

    # Database fields
    shared_resource_id: Mapped[intpk]
    # Required; references resources.resource_id.
    resource_id: Mapped[bigint] = mapped_column(
        sa.ForeignKey("resources.resource_id"),
        nullable=False,
    )


class ResourceDowngrade(DowngradeBase):
    """Downgrade-side mapping of the ``resources`` table."""

    __tablename__ = "resources"

    # Database fields
    resource_id: Mapped[intpk]
    # Required; references resource_types.resource_type.
    resource_type: Mapped[text_] = mapped_column(
        sa.ForeignKey("resource_types.resource_type"),
        nullable=False,
    )


class ResourceSnapshotDowngrade(DowngradeBase):
    """Downgrade-side mapping of the ``resource_snapshots`` table."""

    __tablename__ = "resource_snapshots"

    # Database fields
    resource_snapshot_id: Mapped[intpk]
    resource_id: Mapped[bigint] = mapped_column(nullable=False)
    # Required; references resource_types.resource_type.
    resource_type: Mapped[text_] = mapped_column(
        sa.ForeignKey("resource_types.resource_type"),
        nullable=False,
    )


class MlModelDowngrade(DowngradeBase):
    """Downgrade-side mapping of the ``ml_models`` table.

    Maps only the snapshot key and the resource id.
    """

    __tablename__ = "ml_models"

    # Database fields
    resource_snapshot_id: Mapped[intpk]
    resource_id: Mapped[bigint] = mapped_column(nullable=False)


class MlModelVersionDowngrade(DowngradeBase):
    """Downgrade-side mapping of the ``ml_model_versions`` table.

    Maps only the snapshot key and the resource id.
    """

    __tablename__ = "ml_model_versions"

    # Database fields
    resource_snapshot_id: Mapped[intpk]
    resource_id: Mapped[bigint] = mapped_column(nullable=False)


class ResourceTagDowngrade(DowngradeBase):
    """Downgrade-side mapping of the ``resource_tags`` table.

    Only ``resource_id`` is mapped; it serves as the primary key of this model
    and references ``resources.resource_id``.
    """

    __tablename__ = "resource_tags"

    # Database fields
    resource_id: Mapped[intpk] = mapped_column(
        sa.ForeignKey("resources.resource_id")
    )


def upgrade():
    """Apply this revision.

    Performs, in order:

    1. Registers the ``ml_model_version`` resource type and the
       ``ml_model -> ml_model_version`` resource dependency type.
    2. Creates the ``ml_model_versions`` table and its two indexes.
    3. Rebuilds the ``ml_models`` unique index on
       ``(resource_snapshot_id, resource_id)`` and drops the
       ``artifact_resource_snapshot_id`` column together with its foreign key.
    """

    def _big_integer():
        # A fresh type object per column: BigInteger, or Integer on SQLite.
        return sa.BigInteger().with_variant(sa.Integer(), "sqlite")

    connection = op.get_bind()
    session_factory = sessionmaker(bind=connection)

    # Update the list of allowed resource types and resource dependency types
    with session_factory() as db_session:
        db_session.add(ResourceTypeUpgrade(resource_type="ml_model_version"))
        # Flush so the resource type row is emitted before the dependency type
        # row, whose columns reference resource_types.resource_type.
        db_session.flush()
        dependency_type = ResourceDependencyTypeUpgrade(
            parent_resource_type="ml_model",
            child_resource_type="ml_model_version",
        )
        db_session.add(dependency_type)
        db_session.commit()

    # ### commands auto generated by Alembic - please adjust! ###
    # All four columns share the same type and are NOT NULL.
    version_columns = [
        sa.Column(column_name, _big_integer(), nullable=False)
        for column_name in (
            "resource_snapshot_id",
            "resource_id",
            "artifact_resource_snapshot_id",
            "version_number",
        )
    ]
    op.create_table(
        "ml_model_versions",
        *version_columns,
        sa.ForeignKeyConstraint(
            ["artifact_resource_snapshot_id"],
            ["artifacts.resource_snapshot_id"],
            name=op.f("fk_ml_model_versions_artifact_resource_snapshot_id_artifacts"),
        ),
        sa.ForeignKeyConstraint(
            ["resource_snapshot_id", "resource_id"],
            [
                "resource_snapshots.resource_snapshot_id",
                "resource_snapshots.resource_id",
            ],
            name=op.f("fk_ml_model_versions_resource_snapshot_id_resource_snapshots"),
        ),
        sa.PrimaryKeyConstraint(
            "resource_snapshot_id", name=op.f("pk_ml_model_versions")
        ),
        # NOTE(review): the concatenated name below is 86 characters long;
        # confirm every targeted backend accepts identifiers of that length.
        sa.UniqueConstraint(
            "resource_snapshot_id",
            "resource_id",
            "artifact_resource_snapshot_id",
            "version_number",
            name=op.f(
                "uq_ml_model_versions_resource_snapshot_id_artifact_resource_snapshot"
                "_id_version_number"
            ),
        ),
        sa.UniqueConstraint(
            "resource_snapshot_id",
            "resource_id",
            name=op.f("uq_ml_model_versions_resource_snapshot_id"),
        ),
    )

    with op.batch_alter_table("ml_model_versions", schema=None) as versions_op:
        versions_op.create_index(
            versions_op.f("ix_ml_model_versions_resource_id"),
            ["resource_id"],
            unique=False,
        )
        versions_op.create_index(
            versions_op.f("ix_ml_model_versions_resource_snapshot_id"),
            ["resource_snapshot_id", "resource_id", "artifact_resource_snapshot_id"],
            unique=True,
        )

    with op.batch_alter_table("ml_models", schema=None) as models_op:
        # Recreate the unique index without the artifact column, then remove
        # the column's foreign key and the column itself.
        models_index_name = models_op.f("ix_ml_models_resource_snapshot_id")
        models_op.drop_index(models_index_name)
        models_op.create_index(
            models_op.f("ix_ml_models_resource_snapshot_id"),
            ["resource_snapshot_id", "resource_id"],
            unique=True,
        )
        models_op.drop_constraint(
            "fk_ml_models_artifact_resource_snapshot_id_artifacts",
            type_="foreignkey",
        )
        models_op.drop_column("artifact_resource_snapshot_id")

    # ### end Alembic commands ###


def downgrade():
    """Revert this revision.

    Performs, in order:

    1. Deletes every row tied to ``ml_model`` / ``ml_model_version`` resources.
    2. Re-adds ``ml_models.artifact_resource_snapshot_id`` with its foreign key
       and the three-column unique index, then makes the column NOT NULL.
    3. Drops the ``ml_model_versions`` indexes and table.
    4. Removes the ``ml_model -> ml_model_version`` dependency type and the
       ``ml_model_version`` resource type.

    Rows are deleted child-first, with an explicit flush after every group.
    The session autoflushes before each query, so deleting ``resources`` early
    would empty the resource-id CTE used to find the matching
    ``shared_resources`` and ``resource_tags`` rows, and could violate foreign
    keys that still point at the deleted rows.
    """
    bind = op.get_bind()
    Session = sessionmaker(bind=bind)
    ml_resource_types = ["ml_model", "ml_model_version"]

    # Remove all traces of MlModel and MlModelVersion resources
    with Session() as session:
        cte_resource_ids = (
            sa.select(ResourceDowngrade.resource_id)
            .where(ResourceDowngrade.resource_type.in_(ml_resource_types))
            .cte()
        )

        # Ordered so that referencing rows go before the rows they point to.
        purge_stmts = [
            # These two depend on the CTE over `resources` and reference
            # resources.resource_id, so they must run while the resources
            # rows still exist.
            sa.select(ResourceTagDowngrade).where(
                ResourceTagDowngrade.resource_id.in_(sa.select(cte_resource_ids))
            ),
            sa.select(SharedResourceDowngrade).where(
                SharedResourceDowngrade.resource_id.in_(sa.select(cte_resource_ids))
            ),
            sa.select(ResourceDependencyDowngrade).where(
                ResourceDependencyDowngrade.parent_resource_type == "ml_model",
                ResourceDependencyDowngrade.child_resource_type == "ml_model_version",
            ),
            # ml_model_versions references resource_snapshots; ml_models is
            # presumed to do the same (confirm against the live schema).
            sa.select(MlModelVersionDowngrade),
            sa.select(MlModelDowngrade),
            sa.select(DraftResourceDowngrade).where(
                DraftResourceDowngrade.resource_type.in_(ml_resource_types)
            ),
            # Snapshots carry a resource_id, so they go before resources
            # (assumed to be a foreign key in the live schema -- confirm).
            sa.select(ResourceSnapshotDowngrade).where(
                ResourceSnapshotDowngrade.resource_type.in_(ml_resource_types)
            ),
            sa.select(ResourceDowngrade).where(
                ResourceDowngrade.resource_type.in_(ml_resource_types)
            ),
        ]

        for purge_stmt in purge_stmts:
            for row in session.scalars(purge_stmt):
                session.delete(row)
            # Emit this group's DELETEs before the next group is queried.
            session.flush()

        session.commit()

    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table("ml_models", schema=None) as batch_op:
        # Added as nullable first; tightened to NOT NULL in the next step.
        batch_op.add_column(
            sa.Column(
                "artifact_resource_snapshot_id",
                sa.BigInteger().with_variant(sa.Integer(), "sqlite"),
                nullable=True,
            )
        )
        batch_op.create_foreign_key(
            batch_op.f("fk_ml_models_artifact_resource_snapshot_id_artifacts"),
            "artifacts",
            ["artifact_resource_snapshot_id"],
            ["resource_snapshot_id"],
        )
        batch_op.drop_index(batch_op.f("ix_ml_models_resource_snapshot_id"))
        batch_op.create_index(
            batch_op.f("ix_ml_models_resource_snapshot_id"),
            ["resource_snapshot_id", "resource_id", "artifact_resource_snapshot_id"],
            unique=True,
        )

    # Workaround to ensure the migration won't fail (table should be empty)
    with op.batch_alter_table("ml_models", schema=None) as batch_op:
        batch_op.alter_column("artifact_resource_snapshot_id", nullable=False)

    with op.batch_alter_table("ml_model_versions", schema=None) as batch_op:
        batch_op.drop_index(batch_op.f("ix_ml_model_versions_resource_snapshot_id"))
        batch_op.drop_index(batch_op.f("ix_ml_model_versions_resource_id"))

    op.drop_table("ml_model_versions")
    # ### end Alembic commands ###

    # Downgrade the list of allowed resource types and resource dependency types
    with Session() as session:
        # The dependency type references resource_types.resource_type, and no
        # relationship() exists to order the two deletes within one flush, so
        # remove and flush the dependency type first.
        resource_dependency_type = session.scalar(
            sa.select(ResourceDependencyTypeDowngrade).where(
                ResourceDependencyTypeDowngrade.parent_resource_type == "ml_model",
                ResourceDependencyTypeDowngrade.child_resource_type
                == "ml_model_version",
            )
        )
        if resource_dependency_type is not None:
            session.delete(resource_dependency_type)
            session.flush()

        ml_model_version_type = session.scalar(
            sa.select(ResourceTypeDowngrade).where(
                ResourceTypeDowngrade.resource_type == "ml_model_version"
            )
        )
        if ml_model_version_type is not None:
            session.delete(ml_model_version_type)

        session.commit()
3 changes: 2 additions & 1 deletion src/dioptra/restapi/db/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
resource_lock_types_table,
user_lock_types_table,
)
from .ml_models import MlModel
from .ml_models import MlModel, MlModelVersion
from .plugins import (
Plugin,
PluginFile,
Expand Down Expand Up @@ -81,6 +81,7 @@
"Job",
"JobMlflowRun",
"MlModel",
"MlModelVersion",
"Plugin",
"PluginFile",
"PluginTask",
Expand Down
Loading

0 comments on commit 987818b

Please sign in to comment.