From 987818be935a6220510d7d383f6e9705324e81bd Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Tue, 2 Jul 2024 10:18:05 -0400 Subject: [PATCH] feat: REST API v1 Artifacts and Models service layer Co-authored-by: James K. Glasbrenner Co-authored-by: Paul Scemama --- ...6b5377d6_add_ml_model_versions_resource.py | 368 ++++++++ src/dioptra/restapi/db/models/__init__.py | 3 +- src/dioptra/restapi/db/models/ml_models.py | 72 +- src/dioptra/restapi/errors.py | 2 + src/dioptra/restapi/v1/artifacts/__init__.py | 3 + .../restapi/v1/artifacts/controller.py | 126 ++- src/dioptra/restapi/v1/artifacts/errors.py | 44 + src/dioptra/restapi/v1/artifacts/schema.py | 13 +- src/dioptra/restapi/v1/artifacts/service.py | 381 ++++++++ src/dioptra/restapi/v1/entrypoints/schema.py | 2 +- src/dioptra/restapi/v1/models/__init__.py | 3 + src/dioptra/restapi/v1/models/controller.py | 284 +++++- src/dioptra/restapi/v1/models/errors.py | 54 ++ src/dioptra/restapi/v1/models/schema.py | 103 ++- src/dioptra/restapi/v1/models/service.py | 822 ++++++++++++++++++ src/dioptra/restapi/v1/utils.py | 144 ++- tests/unit/restapi/lib/actions.py | 95 +- tests/unit/restapi/lib/db/setup.py | 2 + tests/unit/restapi/v1/conftest.py | 105 +++ tests/unit/restapi/v1/test_artifact.py | 321 +++++++ tests/unit/restapi/v1/test_model.py | 548 ++++++++++++ 21 files changed, 3456 insertions(+), 39 deletions(-) create mode 100644 src/dioptra/restapi/db/alembic/versions/fd786b5377d6_add_ml_model_versions_resource.py create mode 100644 src/dioptra/restapi/v1/artifacts/errors.py create mode 100644 src/dioptra/restapi/v1/artifacts/service.py create mode 100644 src/dioptra/restapi/v1/models/errors.py create mode 100644 src/dioptra/restapi/v1/models/service.py create mode 100644 tests/unit/restapi/v1/test_artifact.py create mode 100644 tests/unit/restapi/v1/test_model.py diff --git a/src/dioptra/restapi/db/alembic/versions/fd786b5377d6_add_ml_model_versions_resource.py 
"""Add ml_model_versions_resource

Revision ID: fd786b5377d6
Revises: d2bae5f6d991
Create Date: 2024-06-28 17:13:00.008695

"""

from typing import Annotated

import sqlalchemy as sa
from alembic import op
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    MappedAsDataclass,
    mapped_column,
    sessionmaker,
)

# revision identifiers, used by Alembic.
revision = "fd786b5377d6"
down_revision = "d2bae5f6d991"
branch_labels = None
depends_on = None


# Migration data models
#
# Reusable column annotations. The integer types are BIGINT everywhere except
# SQLite, where the variant falls back to INTEGER.
intpk = Annotated[
    int,
    mapped_column(sa.BigInteger().with_variant(sa.Integer, "sqlite"), primary_key=True),
]
bigint = Annotated[
    int, mapped_column(sa.BigInteger().with_variant(sa.Integer, "sqlite"))
]
text_ = Annotated[str, mapped_column(sa.Text())]


# Two independent declarative bases are used because the upgrade and downgrade
# paths both map the same table names (e.g. "resource_types"); a table name can
# only be declared once per base.
class UpgradeBase(DeclarativeBase, MappedAsDataclass):
    pass


class DowngradeBase(DeclarativeBase, MappedAsDataclass):
    pass


# -- Mappings used by upgrade() --------------------------------------------------------


# Allowed (parent type, child type) pairs for resource dependencies.
class ResourceDependencyTypeUpgrade(UpgradeBase):
    __tablename__ = "resource_dependency_types"

    parent_resource_type: Mapped[text_] = mapped_column(
        sa.ForeignKey("resource_types.resource_type"), primary_key=True
    )
    child_resource_type: Mapped[text_] = mapped_column(
        sa.ForeignKey("resource_types.resource_type"), primary_key=True
    )


# Allowed resource type names.
class ResourceTypeUpgrade(UpgradeBase):
    __tablename__ = "resource_types"

    resource_type: Mapped[text_] = mapped_column(primary_key=True)


# -- Mappings used by downgrade() ------------------------------------------------------
# NOTE(review): these declare only a few columns per table, presumably just the ones
# needed to select and delete rows; confirm against the full application models.


class ResourceDependencyTypeDowngrade(DowngradeBase):
    __tablename__ = "resource_dependency_types"

    parent_resource_type: Mapped[text_] = mapped_column(
        sa.ForeignKey("resource_types.resource_type"), primary_key=True
    )
    child_resource_type: Mapped[text_] = mapped_column(
        sa.ForeignKey("resource_types.resource_type"), primary_key=True
    )


class ResourceTypeDowngrade(DowngradeBase):
    __tablename__ = "resource_types"

    resource_type: Mapped[text_] = mapped_column(primary_key=True)


# Individual parent/child links between resources, keyed by the pair of ids.
class ResourceDependencyDowngrade(DowngradeBase):
    __tablename__ = "resource_dependencies"

    parent_resource_id: Mapped[intpk]
    child_resource_id: Mapped[intpk]
    parent_resource_type: Mapped[text_] = mapped_column(nullable=False)
    child_resource_type: Mapped[text_] = mapped_column(nullable=False)


class DraftResourceDowngrade(DowngradeBase):
    __tablename__ = "draft_resources"

    # Database fields
    draft_resource_id: Mapped[intpk]
    resource_type: Mapped[text_] = mapped_column(
        sa.ForeignKey("resource_types.resource_type"), nullable=False
    )


class SharedResourceDowngrade(DowngradeBase):
    __tablename__ = "shared_resources"

    # Database fields
    shared_resource_id: Mapped[intpk]
    resource_id: Mapped[bigint] = mapped_column(
        sa.ForeignKey("resources.resource_id"), nullable=False
    )


class ResourceDowngrade(DowngradeBase):
    __tablename__ = "resources"

    # Database fields
    resource_id: Mapped[intpk]
    resource_type: Mapped[text_] = mapped_column(
        sa.ForeignKey("resource_types.resource_type"), nullable=False
    )


class ResourceSnapshotDowngrade(DowngradeBase):
    __tablename__ = "resource_snapshots"

    # Database fields
    resource_snapshot_id: Mapped[intpk]
    resource_id: Mapped[bigint] = mapped_column(nullable=False)
    resource_type: Mapped[text_] = mapped_column(
        sa.ForeignKey("resource_types.resource_type"), nullable=False
    )


class MlModelDowngrade(DowngradeBase):
    __tablename__ = "ml_models"

    # Database fields
    resource_snapshot_id: Mapped[intpk]
    resource_id: Mapped[bigint] = mapped_column(nullable=False)


class MlModelVersionDowngrade(DowngradeBase):
    __tablename__ = "ml_model_versions"

    # Database fields
    resource_snapshot_id: Mapped[intpk]
    resource_id: Mapped[bigint] = mapped_column(nullable=False)


class ResourceTagDowngrade(DowngradeBase):
    __tablename__ = "resource_tags"

    # Database fields
    resource_id: Mapped[intpk] = mapped_column(sa.ForeignKey("resources.resource_id"))


def upgrade() -> None:
    """Apply the revision.

    Steps, in order:

    1. Register the "ml_model_version" resource type and the
       ("ml_model" -> "ml_model_version") dependency type.
    2. Create the "ml_model_versions" table with its constraints and indexes.
    3. Rebuild the "ml_models" unique index without the artifact column, then drop
       the artifact foreign key and the "artifact_resource_snapshot_id" column.
    """
    bind = op.get_bind()
    Session = sessionmaker(bind=bind)

    # Update the list of allowed resource types and resource dependency types
    with Session() as session:
        session.add(ResourceTypeUpgrade(resource_type="ml_model_version"))
        # Flush so the new resource type row is written before the dependency type
        # row that references it through its foreign key.
        session.flush()
        session.add(
            ResourceDependencyTypeUpgrade(
                parent_resource_type="ml_model", child_resource_type="ml_model_version"
            )
        )
        session.commit()

    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "ml_model_versions",
        sa.Column(
            "resource_snapshot_id",
            sa.BigInteger().with_variant(sa.Integer(), "sqlite"),
            nullable=False,
        ),
        sa.Column(
            "resource_id",
            sa.BigInteger().with_variant(sa.Integer(), "sqlite"),
            nullable=False,
        ),
        sa.Column(
            "artifact_resource_snapshot_id",
            sa.BigInteger().with_variant(sa.Integer(), "sqlite"),
            nullable=False,
        ),
        sa.Column(
            "version_number",
            sa.BigInteger().with_variant(sa.Integer(), "sqlite"),
            nullable=False,
        ),
        sa.ForeignKeyConstraint(
            ["artifact_resource_snapshot_id"],
            ["artifacts.resource_snapshot_id"],
            name=op.f("fk_ml_model_versions_artifact_resource_snapshot_id_artifacts"),
        ),
        # Composite foreign key onto the (snapshot id, resource id) pair.
        sa.ForeignKeyConstraint(
            ["resource_snapshot_id", "resource_id"],
            [
                "resource_snapshots.resource_snapshot_id",
                "resource_snapshots.resource_id",
            ],
            name=op.f("fk_ml_model_versions_resource_snapshot_id_resource_snapshots"),
        ),
        sa.PrimaryKeyConstraint(
            "resource_snapshot_id", name=op.f("pk_ml_model_versions")
        ),
        sa.UniqueConstraint(
            "resource_snapshot_id",
            "resource_id",
            "artifact_resource_snapshot_id",
            "version_number",
            name=op.f(
                "uq_ml_model_versions_resource_snapshot_id_artifact_resource_snapshot"
                "_id_version_number"
            ),
        ),
        sa.UniqueConstraint(
            "resource_snapshot_id",
            "resource_id",
            name=op.f("uq_ml_model_versions_resource_snapshot_id"),
        ),
    )
    with op.batch_alter_table("ml_model_versions", schema=None) as batch_op:
        batch_op.create_index(
            batch_op.f("ix_ml_model_versions_resource_id"),
            ["resource_id"],
            unique=False,
        )
        batch_op.create_index(
            batch_op.f("ix_ml_model_versions_resource_snapshot_id"),
            ["resource_snapshot_id", "resource_id", "artifact_resource_snapshot_id"],
            unique=True,
        )

    with op.batch_alter_table("ml_models", schema=None) as batch_op:
        # The existing index is dropped and recreated under the same name without
        # the artifact column, which is removed at the end of this batch.
        batch_op.drop_index(batch_op.f("ix_ml_models_resource_snapshot_id"))
        batch_op.create_index(
            batch_op.f("ix_ml_models_resource_snapshot_id"),
            ["resource_snapshot_id", "resource_id"],
            unique=True,
        )
        batch_op.drop_constraint(
            "fk_ml_models_artifact_resource_snapshot_id_artifacts", type_="foreignkey"
        )
        batch_op.drop_column("artifact_resource_snapshot_id")

    # ### end Alembic commands ###


def downgrade() -> None:
    """Revert the revision.

    This is destructive: every row selected below for the "ml_model" and
    "ml_model_version" resource types is deleted before the schema is changed.

    Steps, in order:

    1. Delete ml_models and ml_model_versions rows plus the related drafts,
       resources, snapshots, dependencies, shared-resource rows and tags.
    2. Re-add "artifact_resource_snapshot_id" to "ml_models" (nullable first, then
       NOT NULL) and restore its foreign key and three-column unique index.
    3. Drop the "ml_model_versions" indexes and table.
    4. Remove the "ml_model_version" resource type and its dependency type.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    bind = op.get_bind()
    Session = sessionmaker(bind=bind)

    # Remove all traces of MlModel and MlModelVersion resources
    with Session() as session:
        ml_models_stmt = sa.select(MlModelDowngrade)
        ml_model_versions_stmt = sa.select(MlModelVersionDowngrade)
        drafts_stmt = sa.select(DraftResourceDowngrade).where(
            DraftResourceDowngrade.resource_type.in_(["ml_model", "ml_model_version"])
        )
        resources_stmt = sa.select(ResourceDowngrade).where(
            ResourceDowngrade.resource_type.in_(["ml_model", "ml_model_version"])
        )
        resource_snapshots_stmt = sa.select(ResourceSnapshotDowngrade).where(
            ResourceSnapshotDowngrade.resource_type.in_(
                ["ml_model", "ml_model_version"]
            )
        )
        resource_dependencies_stmt = sa.select(ResourceDependencyDowngrade).where(
            ResourceDependencyDowngrade.parent_resource_type == "ml_model",
            ResourceDependencyDowngrade.child_resource_type == "ml_model_version",
        )
        # Ids of the affected resources, reused to find rows in tables that carry
        # only a resource_id and no resource_type column in these mappings.
        cte_resource_ids = (
            sa.select(ResourceDowngrade.resource_id)
            .where(
                ResourceDowngrade.resource_type.in_(["ml_model", "ml_model_version"])
            )
            .cte()
        )
        shared_resources_stmt = sa.select(SharedResourceDowngrade).where(
            SharedResourceDowngrade.resource_id.in_(sa.select(cte_resource_ids))
        )
        resource_tags_stmt = sa.select(ResourceTagDowngrade).where(
            ResourceTagDowngrade.resource_id.in_(sa.select(cte_resource_ids))
        )

        # Deletions are only queued here; nothing is committed until the single
        # commit() at the end of this block.
        for ml_model in session.scalars(ml_models_stmt):
            session.delete(ml_model)

        for ml_model_version in session.scalars(ml_model_versions_stmt):
            session.delete(ml_model_version)

        for draft in session.scalars(drafts_stmt):
            session.delete(draft)

        for resource in session.scalars(resources_stmt):
            session.delete(resource)

        for resource_snapshot in session.scalars(resource_snapshots_stmt):
            session.delete(resource_snapshot)

        for resource_dependency in session.scalars(resource_dependencies_stmt):
            session.delete(resource_dependency)

        for shared_resource in session.scalars(shared_resources_stmt):
            session.delete(shared_resource)

        for resource_tag in session.scalars(resource_tags_stmt):
            session.delete(resource_tag)

        session.commit()

    with op.batch_alter_table("ml_models", schema=None) as batch_op:
        # Added as nullable first; it is tightened to NOT NULL in the next batch.
        batch_op.add_column(
            sa.Column(
                "artifact_resource_snapshot_id",
                sa.BigInteger().with_variant(sa.Integer(), "sqlite"),
                nullable=True,
            )
        )
        batch_op.create_foreign_key(
            batch_op.f("fk_ml_models_artifact_resource_snapshot_id_artifacts"),
            "artifacts",
            ["artifact_resource_snapshot_id"],
            ["resource_snapshot_id"],
        )
        batch_op.drop_index(batch_op.f("ix_ml_models_resource_snapshot_id"))
        batch_op.create_index(
            batch_op.f("ix_ml_models_resource_snapshot_id"),
            ["resource_snapshot_id", "resource_id", "artifact_resource_snapshot_id"],
            unique=True,
        )

    # Workaround to ensure the migration won't fail (table should be empty)
    with op.batch_alter_table("ml_models", schema=None) as batch_op:
        batch_op.alter_column("artifact_resource_snapshot_id", nullable=False)

    with op.batch_alter_table("ml_model_versions", schema=None) as batch_op:
        batch_op.drop_index(batch_op.f("ix_ml_model_versions_resource_snapshot_id"))
        batch_op.drop_index(batch_op.f("ix_ml_model_versions_resource_id"))

    op.drop_table("ml_model_versions")

    # Downgrade the list of allowed resource types and resource dependency types
    with Session() as session:
        ml_model_version_type_stmt = sa.select(ResourceTypeDowngrade).where(
            ResourceTypeDowngrade.resource_type == "ml_model_version"
        )
        resource_dependency_types_stmt = sa.select(
            ResourceDependencyTypeDowngrade
        ).where(
            ResourceDependencyTypeDowngrade.parent_resource_type == "ml_model",
            ResourceDependencyTypeDowngrade.child_resource_type == "ml_model_version",
        )
        ml_model_version_type = session.scalar(ml_model_version_type_stmt)
        resource_dependency_type = session.scalar(resource_dependency_types_stmt)

        # Either row may already be absent, so each delete is guarded.
        if ml_model_version_type is not None:
            session.delete(ml_model_version_type)

        if resource_dependency_type is not None:
            session.delete(resource_dependency_type)

        session.commit()

    # ### end Alembic commands ###
__table_args__ = ( # type: ignore[assignment] + Index( + None, + "resource_snapshot_id", + "resource_id", + unique=True, + ), + ForeignKeyConstraint( + ["resource_snapshot_id", "resource_id"], + [ + "resource_snapshots.resource_snapshot_id", + "resource_snapshots.resource_id", + ], + ), + UniqueConstraint("resource_snapshot_id", "resource_id"), + ) + __mapper_args__ = { + "polymorphic_identity": "ml_model", + } + + +class MlModelVersion(ResourceSnapshot): + __tablename__ = "ml_model_versions" + # Database fields resource_snapshot_id: Mapped[intpk] = mapped_column(init=False) resource_id: Mapped[bigint] = mapped_column(init=False, nullable=False, index=True) artifact_resource_snapshot_id: Mapped[bigint] = mapped_column( ForeignKey("artifacts.resource_snapshot_id"), init=False, nullable=False ) - name: Mapped[text_] = mapped_column(nullable=False) + version_number: Mapped[bigint] = mapped_column(nullable=False) # Relationships artifact: Mapped[Artifact] = relationship( primaryjoin=( - "Artifact.resource_snapshot_id == MlModel.artifact_resource_snapshot_id" + "Artifact.resource_snapshot_id " + "== MlModelVersion.artifact_resource_snapshot_id" ), ) + # Derived fields (read-only) + model_id: Mapped[bigint] = column_property( + select(Resource.resource_id) + .join( + resource_dependencies_table, + and_( + Resource.resource_id + == resource_dependencies_table.c.parent_resource_id, + resource_dependencies_table.c.child_resource_id == resource_id, + ), + ) + .limit(1) + .correlate_except(Resource) + .scalar_subquery() + ) + # Additional settings __table_args__ = ( # type: ignore[assignment] Index( @@ -59,8 +113,14 @@ class MlModel(ResourceSnapshot): "resource_snapshots.resource_id", ], ), + UniqueConstraint( + "resource_snapshot_id", + "resource_id", + "artifact_resource_snapshot_id", + "version_number", + ), UniqueConstraint("resource_snapshot_id", "resource_id"), ) __mapper_args__ = { - "polymorphic_identity": "ml_model", + "polymorphic_identity": "ml_model_version", } 
diff --git a/src/dioptra/restapi/errors.py b/src/dioptra/restapi/errors.py index 4f4393dab..22f633101 100644 --- a/src/dioptra/restapi/errors.py +++ b/src/dioptra/restapi/errors.py @@ -130,10 +130,12 @@ def register_v1_error_handlers(api: Api) -> None: from dioptra.restapi import v1 register_base_v1_error_handlers(api) + v1.artifacts.errors.register_error_handlers(api) v1.entrypoints.errors.register_error_handlers(api) v1.experiments.errors.register_error_handlers(api) v1.groups.errors.register_error_handlers(api) v1.jobs.errors.register_error_handlers(api) + v1.models.errors.register_error_handlers(api) v1.plugin_parameter_types.errors.register_error_handlers(api) v1.plugins.errors.register_error_handlers(api) v1.queues.errors.register_error_handlers(api) diff --git a/src/dioptra/restapi/v1/artifacts/__init__.py b/src/dioptra/restapi/v1/artifacts/__init__.py index ab0a41a34..11ce655e6 100644 --- a/src/dioptra/restapi/v1/artifacts/__init__.py +++ b/src/dioptra/restapi/v1/artifacts/__init__.py @@ -14,3 +14,6 @@ # # ACCESS THE FULL CC BY 4.0 LICENSE HERE: # https://creativecommons.org/licenses/by/4.0/legalcode +from . 
@inject
def __init__(self, artifact_service: ArtifactService, *args, **kwargs) -> None:
    """Initialize the artifact resource.

    All arguments are provided via dependency injection.

    Args:
        artifact_service: An ArtifactService object. The endpoint's handlers call
            its `get` and `create` methods.
        *args: Positional arguments forwarded unchanged to the parent class.
        **kwargs: Keyword arguments forwarded unchanged to the parent class.
    """
    self._artifact_service = artifact_service
    super().__init__(*args, **kwargs)
+ """ + self._artifact_id_service = artifact_id_service + super().__init__(*args, **kwargs) + @login_required @responds(schema=ArtifactSchema, api=api) def get(self, id: int): @@ -68,4 +146,46 @@ def get(self, id: int): log = LOGGER.new( request_id=str(uuid.uuid4()), resource="Artifact", request_type="GET", id=id ) - log.debug("Request received") + artifact = cast( + models.Artifact, + self._artifact_id_service.get(id, error_if_not_found=True, log=log), + ) + return utils.build_artifact(artifact) + + @login_required + @accepts(schema=ArtifactMutableFieldsSchema, api=api) + @responds(schema=ArtifactSchema, api=api) + def put(self, id: int): + """Modifies an Artifact resource.""" + log = LOGGER.new( + request_id=str(uuid.uuid4()), resource="Artifact", request_type="PUT", id=id + ) + parsed_obj = request.parsed_obj # type: ignore + artifact = cast( + models.Artifact, + self._artifact_id_service.modify( + id, + description=parsed_obj["description"], + error_if_not_found=True, + log=log, + ), + ) + return utils.build_artifact(artifact) + + +ArtifactSnapshotsResource = generate_resource_snapshots_endpoint( + api=api, + resource_model=models.Artifact, + resource_name=RESOURCE_TYPE, + route_prefix=V1_ARTIFACTS_ROUTE, + searchable_fields=SEARCHABLE_FIELDS, + page_schema=ArtifactPageSchema, + build_fn=utils.build_artifact, +) +ArtifactSnapshotsIdResource = generate_resource_snapshots_id_endpoint( + api=api, + resource_model=models.Artifact, + resource_name=RESOURCE_TYPE, + response_schema=ArtifactSchema, + build_fn=utils.build_artifact, +) diff --git a/src/dioptra/restapi/v1/artifacts/errors.py b/src/dioptra/restapi/v1/artifacts/errors.py new file mode 100644 index 000000000..827dbc1ec --- /dev/null +++ b/src/dioptra/restapi/v1/artifacts/errors.py @@ -0,0 +1,44 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. 
class ArtifactAlreadyExistsError(Exception):
    """An artifact with the requested uri already exists."""


class ArtifactDoesNotExistError(Exception):
    """The requested artifact does not exist."""


def register_error_handlers(api: Api) -> None:
    """Attach the artifact error handlers to a Flask-RESTX Api object.

    Args:
        api: The Api object whose `errorhandler` hook receives the handlers.
    """

    def handle_artifact_does_not_exist_error(error):
        # A missing artifact is reported as 404.
        payload = {"message": "Not Found - The requested artifact does not exist"}
        return payload, 404

    def handle_artifact_already_exists_error(error):
        # A duplicate artifact uri is reported as 400.
        payload = {
            "message": "Bad Request - The artifact uri on the registration form "
            "already exists. Please select another and resubmit."
        }
        return payload, 400

    # Registration order matches the order in which the handlers are defined.
    api.errorhandler(ArtifactDoesNotExistError)(handle_artifact_does_not_exist_error)
    api.errorhandler(ArtifactAlreadyExistsError)(handle_artifact_already_exists_error)
class ArtifactService(object):
    """The service methods for registering artifacts and listing them with paging."""

    @inject
    def __init__(
        self,
        artifact_uri_service: ArtifactUriService,
        # job_id_service: JobIdService,
        group_id_service: GroupIdService,
    ) -> None:
        """Initialize the artifact service.

        All arguments are provided via dependency injection.

        Args:
            artifact_uri_service: An ArtifactUriService object.
            group_id_service: A GroupIdService object.
        """
        self._artifact_uri_service = artifact_uri_service
        # self._job_id_service = job_id_service
        self._group_id_service = group_id_service

    def create(
        self,
        uri: str,
        description: str,
        group_id: int,
        job_id: int,
        commit: bool = True,
        **kwargs,
    ) -> utils.ArtifactDict:
        """Create a new artifact.

        Args:
            uri: The uri of the artifact. This must be globally unique.
            description: The description of the artifact.
            group_id: The group that will own the artifact.
            job_id: The id of the job that produced the artifact. It is accepted
                but not used yet; see the TODO comments in the body.
            commit: If True, commit the transaction. Defaults to True.

        Returns:
            The newly created artifact object.

        Raises:
            ArtifactAlreadyExistsError: If an artifact with this uri already exists.
        """
        log: BoundLogger = kwargs.get("log", LOGGER.new())

        if self._artifact_uri_service.get(uri, log=log) is not None:
            log.debug("Artifact uri already exists", uri=uri)
            raise ArtifactAlreadyExistsError

        # TODO: link the artifact to its job once JobIdService is wired in.
        # job = self._job_id_service.get(job_id, error_if_not_found=True)
        group = self._group_id_service.get(group_id, error_if_not_found=True)

        resource = models.Resource(resource_type="artifact", owner=group)
        new_artifact = models.Artifact(
            uri=uri,
            description=description,
            resource=resource,
            creator=current_user,
        )
        # job.children.append(new_artifact.resource)
        db.session.add(new_artifact)

        if commit:
            db.session.commit()
            log.debug(
                "Artifact registration successful", artifact_id=new_artifact.resource_id
            )

        return utils.ArtifactDict(artifact=new_artifact, has_draft=False)

    def get(
        self,
        group_id: int | None,
        search_string: str,
        page_index: int,
        page_length: int,
        **kwargs,
    ) -> tuple[list[utils.ArtifactDict], int]:
        """Fetch a page of artifacts, optionally filtered by group and search string.

        Args:
            group_id: A group ID used to filter results, or None for no group filter.
            search_string: A search string used to filter results. An empty string
                disables search filtering.
            page_index: The index of the first artifact to be returned.
            page_length: The maximum number of artifacts to be returned.

        Returns:
            A tuple containing a list of artifacts and the total number of artifacts
            matching the query.

        Raises:
            BackendDatabaseError: If the database query returns a None when counting
                the number of artifacts.
        """
        log: BoundLogger = kwargs.get("log", LOGGER.new())
        log.debug("Get full list of artifacts")

        filters = []

        if group_id is not None:
            filters.append(models.Resource.group_id == group_id)

        if search_string:
            filters.append(
                construct_sql_query_filters(search_string, SEARCHABLE_FIELDS)
            )

        # Shared by the count and page queries so both see the same rows: only the
        # latest snapshot of resources that are not deleted.
        latest_filters = [
            *filters,
            models.Resource.is_deleted == False,  # noqa: E712
            models.Resource.latest_snapshot_id == models.Artifact.resource_snapshot_id,
        ]

        count_stmt = (
            select(func.count(models.Artifact.resource_id))
            .join(models.Resource)
            .where(*latest_filters)
        )
        total_num_artifacts = db.session.scalars(count_stmt).first()

        if total_num_artifacts is None:
            log.error(
                "The database query returned a None when counting the number of "
                "artifacts when it should return a number.",
                sql=str(count_stmt),
            )
            raise BackendDatabaseError

        if total_num_artifacts == 0:
            return [], total_num_artifacts

        # OFFSET/LIMIT without ORDER BY leaves the row order up to the database, so
        # pages could overlap or skip rows. Ordering by resource id fixes the order.
        latest_artifacts_stmt = (
            select(models.Artifact)
            .join(models.Resource)
            .where(*latest_filters)
            .order_by(models.Artifact.resource_id)
            .offset(page_index)
            .limit(page_length)
        )
        artifacts = db.session.scalars(latest_artifacts_stmt).all()

        if not artifacts:
            # The requested page starts past the last match; skip the draft lookup.
            return [], total_num_artifacts

        artifacts_dict: dict[int, utils.ArtifactDict] = {
            artifact.resource_id: utils.ArtifactDict(artifact=artifact, has_draft=False)
            for artifact in artifacts
        }

        # Flag the artifacts on this page that the current user has a draft for.
        draft_resource_id = (
            models.DraftResource.payload["resource_id"].as_string().cast(Integer)
        )
        drafts_stmt = select(draft_resource_id).where(
            draft_resource_id.in_(tuple(artifacts_dict)),
            models.DraftResource.user_id == current_user.user_id,
        )
        for resource_id in db.session.scalars(drafts_stmt):
            artifacts_dict[resource_id]["has_draft"] = True

        return list(artifacts_dict.values()), total_num_artifacts


class ArtifactUriService(object):
    """The service methods for managing artifacts by their uri."""

    def get(
        self,
        artifact_uri: str,
        error_if_not_found: bool = False,
        **kwargs,
    ) -> models.Artifact | None:
        """Fetch an artifact by its unique uri.

        Args:
            artifact_uri: The unique uri of the artifact.
            error_if_not_found: If True, raise an error if the artifact is not found.
                Defaults to False.

        Returns:
            The latest snapshot of the artifact if found, otherwise None.

        Raises:
            ArtifactDoesNotExistError: If the artifact is not found and
                `error_if_not_found` is True.
        """
        log: BoundLogger = kwargs.get("log", LOGGER.new())
        log.debug("Get artifact by uri", artifact_uri=artifact_uri)

        stmt = (
            select(models.Artifact)
            .join(models.Resource)
            .where(
                models.Artifact.uri == artifact_uri,
                models.Resource.is_deleted == False,  # noqa: E712
                models.Resource.latest_snapshot_id
                == models.Artifact.resource_snapshot_id,
            )
        )
        artifact = db.session.scalars(stmt).first()

        if artifact is None:
            if error_if_not_found:
                log.debug("Artifact not found", artifact_uri=artifact_uri)
                raise ArtifactDoesNotExistError

            return None

        return artifact
+ error_if_not_found: If True, raise an error if the artifact is not found. + Defaults to False. + + Returns: + The artifact object if found, otherwise None. + + Raises: + ArtifactDoesNotExistError: If the artifact is not found and + `error_if_not_found` is True. + """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) + log.debug("Get artifact by id", artifact_id=artifact_id) + + stmt = ( + select(models.Artifact) + .join(models.Resource) + .where( + models.Artifact.resource_id == artifact_id, + models.Artifact.resource_snapshot_id + == models.Resource.latest_snapshot_id, + models.Resource.is_deleted == False, # noqa: E712 + ) + ) + artifact = db.session.scalars(stmt).first() + + if artifact is None: + if error_if_not_found: + log.debug("Artifact not found", artifact_id=artifact_id) + raise ArtifactDoesNotExistError + + return None + + drafts_stmt = ( + select(models.DraftResource.draft_resource_id) + .where( + models.DraftResource.payload["resource_id"].as_string().cast(Integer) + == artifact.resource_id, + models.DraftResource.user_id == current_user.user_id, + ) + .exists() + .select() + ) + has_draft = db.session.scalar(drafts_stmt) + + return utils.ArtifactDict(artifact=artifact, has_draft=has_draft) + + def modify( + self, + artifact_id: int, + description: str, + error_if_not_found: bool = False, + commit: bool = True, + **kwargs, + ) -> utils.ArtifactDict | None: + """Modify a artifact. + + Args: + artifact_id: The unique id of the artifact. + description: The new description of the artifact. + error_if_not_found: If True, raise an error if the group is not found. + Defaults to False. + commit: If True, commit the transaction. Defaults to True. + + Returns: + The updated artifact object. + + Raises: + ArtifactDoesNotExistError: If the artifact is not found and + `error_if_not_found` is True. + ArtifactAlreadyExistsError: If the artifact name already exists. 
+ """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) + + artifact_dict = self.get( + artifact_id, error_if_not_found=error_if_not_found, log=log + ) + + if artifact_dict is None: + return None + + artifact = artifact_dict["artifact"] + has_draft = artifact_dict["has_draft"] + + new_artifact = models.Artifact( + uri=artifact.uri, + description=description, + resource=artifact.resource, + creator=current_user, + ) + db.session.add(new_artifact) + + if commit: + db.session.commit() + log.debug( + "Artifact modification successful", + artifact_id=artifact_id, + description=description, + ) + + return utils.ArtifactDict(artifact=new_artifact, has_draft=has_draft) diff --git a/src/dioptra/restapi/v1/entrypoints/schema.py b/src/dioptra/restapi/v1/entrypoints/schema.py index f9c85bffb..07573aed8 100644 --- a/src/dioptra/restapi/v1/entrypoints/schema.py +++ b/src/dioptra/restapi/v1/entrypoints/schema.py @@ -180,7 +180,7 @@ class EntrypointSchema( EntrypointPluginSchema, attribute="plugins", many=True, - metadata=dict(description="List of plugin files for the entrypoint."), + metadata=dict(description="List of plugins for the entrypoint."), dump_only=True, ) queues = fields.Nested( diff --git a/src/dioptra/restapi/v1/models/__init__.py b/src/dioptra/restapi/v1/models/__init__.py index ab0a41a34..11ce655e6 100644 --- a/src/dioptra/restapi/v1/models/__init__.py +++ b/src/dioptra/restapi/v1/models/__init__.py @@ -14,3 +14,6 @@ # # ACCESS THE FULL CC BY 4.0 LICENSE HERE: # https://creativecommons.org/licenses/by/4.0/legalcode +from . 
import errors + +__all__ = ["errors"] diff --git a/src/dioptra/restapi/v1/models/controller.py b/src/dioptra/restapi/v1/models/controller.py index a1bc6a24e..16d53e24b 100644 --- a/src/dioptra/restapi/v1/models/controller.py +++ b/src/dioptra/restapi/v1/models/controller.py @@ -18,21 +18,52 @@ from __future__ import annotations import uuid +from typing import cast +from urllib.parse import unquote import structlog from flask import request from flask_accepts import accepts, responds from flask_login import login_required from flask_restx import Namespace, Resource +from injector import inject from structlog.stdlib import BoundLogger +from dioptra.restapi.db import models +from dioptra.restapi.routes import V1_MODELS_ROUTE +from dioptra.restapi.v1 import utils from dioptra.restapi.v1.schemas import IdStatusResponseSchema +from dioptra.restapi.v1.shared.drafts.controller import ( + generate_resource_drafts_endpoint, + generate_resource_drafts_id_endpoint, + generate_resource_id_draft_endpoint, +) +from dioptra.restapi.v1.shared.snapshots.controller import ( + generate_resource_snapshots_endpoint, + generate_resource_snapshots_id_endpoint, +) +from dioptra.restapi.v1.shared.tags.controller import ( + generate_resource_tags_endpoint, + generate_resource_tags_id_endpoint, +) from .schema import ( ModelGetQueryParameters, ModelMutableFieldsSchema, ModelPageSchema, ModelSchema, + ModelVersionGetQueryParameters, + ModelVersionMutableFieldsSchema, + ModelVersionPageSchema, + ModelVersionSchema, +) +from .service import ( + MODEL_RESOURCE_TYPE, + MODEL_SEARCHABLE_FIELDS, + ModelIdService, + ModelIdVersionsNumberService, + ModelIdVersionsService, + ModelService, ) LOGGER: BoundLogger = structlog.stdlib.get_logger() @@ -42,16 +73,51 @@ @api.route("/") class ModelEndpoint(Resource): + @inject + def __init__(self, model_service: ModelService, *args, **kwargs) -> None: + """Initialize the model resource. + + All arguments are provided via dependency injection. 
+ + Args: + model_service: A ModelService object. + """ + self._model_service = model_service + super().__init__(*args, **kwargs) + @login_required @accepts(query_params_schema=ModelGetQueryParameters, api=api) @responds(schema=ModelPageSchema, api=api) def get(self): """Gets a list of all Model resources.""" log = LOGGER.new( - request_id=str(uuid.uuid4()), resource="Model", request_type="GET" + request_id=str(uuid.uuid4()), resource="Models", request_type="GET" ) - log.debug("Request received") + parsed_query_params = request.parsed_query_params # noqa: F841 + group_id = parsed_query_params["group_id"] + search_string = unquote(parsed_query_params["search"]) + page_index = parsed_query_params["index"] + page_length = parsed_query_params["page_length"] + + models, total_num_models = self._model_service.get( + group_id=group_id, + search_string=search_string, + page_index=page_index, + page_length=page_length, + log=log, + ) + return utils.build_paging_envelope( + "models", + build_fn=utils.build_model, + data=models, + group_id=group_id, + query=search_string, + draft_type=None, + index=page_index, + length=page_length, + total_num_elements=total_num_models, + ) @login_required @accepts(schema=ModelSchema, api=api) @@ -61,13 +127,31 @@ def post(self): log = LOGGER.new( request_id=str(uuid.uuid4()), resource="Model", request_type="POST" ) - log.debug("Request received") parsed_obj = request.parsed_obj # noqa: F841 + model = self._model_service.create( + name=parsed_obj["name"], + description=parsed_obj["description"], + group_id=parsed_obj["group_id"], + log=log, + ) + return utils.build_model(model) @api.route("/") @api.param("id", "ID for the Model resource.") class ModelIdEndpoint(Resource): + @inject + def __init__(self, model_id_service: ModelIdService, *args, **kwargs) -> None: + """Initialize the model resource. + + All arguments are provided via dependency injection. + + Args: + model_id_service: A ModelIdService object. 
+ """ + self._model_id_service = model_id_service + super().__init__(*args, **kwargs) + @login_required @responds(schema=ModelSchema, api=api) def get(self, id: int): @@ -75,7 +159,11 @@ def get(self, id: int): log = LOGGER.new( request_id=str(uuid.uuid4()), resource="Model", request_type="GET", id=id ) - log.debug("Request received") + model = cast( + models.MlModel, + self._model_id_service.get(id, error_if_not_found=True, log=log), + ) + return utils.build_model(model) @login_required @responds(schema=IdStatusResponseSchema, api=api) @@ -84,7 +172,7 @@ def delete(self, id: int): log = LOGGER.new( request_id=str(uuid.uuid4()), resource="Model", request_type="DELETE", id=id ) - log.debug("Request received") + return self._model_id_service.delete(model_id=id, log=log) @login_required @accepts(schema=ModelMutableFieldsSchema, api=api) @@ -94,5 +182,189 @@ def put(self, id: int): log = LOGGER.new( request_id=str(uuid.uuid4()), resource="Model", request_type="PUT", id=id ) - log.debug("Request received") parsed_obj = request.parsed_obj # type: ignore # noqa: F841 + model = cast( + models.MlModel, + self._model_id_service.modify( + id, + name=parsed_obj["name"], + description=parsed_obj["description"], + error_if_not_found=True, + log=log, + ), + ) + return utils.build_model(model) + + +@api.route("//versions") +@api.param("id", "ID for the Models resource.") +class ModelIdVersionsEndpoint(Resource): + @inject + def __init__( + self, model_id_versions_service: ModelIdVersionsService, *args, **kwargs + ) -> None: + """Initialize the model resource. + + All arguments are provided via dependency injection. + + Args: + model_id_versions_service: A ModelIdVersionsService object. 
+ """ + self._model_id_versions_service = model_id_versions_service + super().__init__(*args, **kwargs) + + @login_required + @accepts(query_params_schema=ModelVersionGetQueryParameters, api=api) + @responds(schema=ModelVersionPageSchema, api=api) + def get(self, id: int): + """Gets a list of versions of this Model resource.""" + log = LOGGER.new( + request_id=str(uuid.uuid4()), resource="Model", request_type="GET" + ) + + parsed_query_params = request.parsed_query_params # type: ignore + search_string = unquote(parsed_query_params["search"]) + page_index = parsed_query_params["index"] + page_length = parsed_query_params["page_length"] + + versions, total_num_versions = cast( + tuple[list[utils.ModelWithVersionDict], int], + self._model_id_versions_service.get( + model_id=id, + search_string=search_string, + page_index=page_index, + page_length=page_length, + error_if_not_found=True, + log=log, + ), + ) + return utils.build_paging_envelope( + f"models/{id}/versions", + build_fn=utils.build_model_version, + data=versions, + group_id=None, + query=search_string, + draft_type=None, + index=page_index, + length=page_length, + total_num_elements=total_num_versions, + ) + + @login_required + @accepts(schema=ModelVersionSchema, api=api) + @responds(schema=ModelVersionSchema, api=api) + def post(self, id: int): + """Creates a new version of the model resource.""" + log = LOGGER.new( + request_id=str(uuid.uuid4()), resource="Model", request_type="POST" + ) + parsed_obj = request.parsed_obj # type: ignore + model = self._model_id_versions_service.create( + id, + description=parsed_obj["description"], + artifact_id=parsed_obj["artifact_id"], + log=log, + ) + return utils.build_model_version(model) + + +@api.route("//versions/") +@api.param("id", "ID for the Model resource.") +@api.param("versionNumber", "Version number for the Model resource.") +class ModelIdVersionsNumberEndpoint(Resource): + @inject + def __init__( + self, + model_id_versions_number_service: 
ModelIdVersionsNumberService,
+        *args,
+        **kwargs,
+    ) -> None:
+        """Initialize the model resource.
+
+        All arguments are provided via dependency injection.
+
+        Args:
+            model_id_versions_number_service: A ModelIdVersionsNumberService object.
+        """
+        self._model_id_versions_number_service = model_id_versions_number_service
+        super().__init__(*args, **kwargs)
+
+    @login_required
+    @responds(schema=ModelVersionSchema, api=api)
+    def get(self, id: int, versionNumber: int):
+        """Gets a specific version of a model by version number."""
+        log = LOGGER.new(
+            request_id=str(uuid.uuid4()), resource="Models", request_type="GET"
+        )
+        model = cast(
+            utils.ModelWithVersionDict,
+            self._model_id_versions_number_service.get(
+                id, version_number=versionNumber, error_if_not_found=True, log=log
+            ),
+        )
+        return utils.build_model_version(model)
+
+    @login_required
+    @accepts(schema=ModelVersionMutableFieldsSchema, api=api)
+    @responds(schema=ModelVersionSchema, api=api)
+    def put(self, id: int, versionNumber: int):
+        """Modifies a specific version of the model resource."""
+        log = LOGGER.new(
+            request_id=str(uuid.uuid4()), resource="Model", request_type="PUT"
+        )
+        parsed_obj = request.parsed_obj  # type: ignore
+        model = cast(
+            utils.ModelWithVersionDict,
+            self._model_id_versions_number_service.modify(
+                id,
+                versionNumber,
+                description=parsed_obj["description"],
+                error_if_not_found=True,
+                log=log,
+            ),
+        )
+        return utils.build_model_version(model)
+
+
+ModelDraftResource = generate_resource_drafts_endpoint(
+    api,
+    resource_name=MODEL_RESOURCE_TYPE,
+    route_prefix=V1_MODELS_ROUTE,
+    request_schema=ModelSchema,
+)
+ModelDraftIdResource = generate_resource_drafts_id_endpoint(
+    api,
+    resource_name=MODEL_RESOURCE_TYPE,
+    request_schema=ModelSchema(exclude=["groupId"]),
+)
+ModelIdDraftResource = generate_resource_id_draft_endpoint(
+    api,
+    resource_name=MODEL_RESOURCE_TYPE,
+    request_schema=ModelSchema(exclude=["groupId"]),
+)
+
+ModelSnapshotsResource = 
generate_resource_snapshots_endpoint(
+    api=api,
+    resource_model=models.MlModel,
+    resource_name=MODEL_RESOURCE_TYPE,
+    route_prefix=V1_MODELS_ROUTE,
+    searchable_fields=MODEL_SEARCHABLE_FIELDS,
+    page_schema=ModelPageSchema,
+    build_fn=utils.build_model,
+)
+ModelSnapshotsIdResource = generate_resource_snapshots_id_endpoint(
+    api=api,
+    resource_model=models.MlModel,
+    resource_name=MODEL_RESOURCE_TYPE,
+    response_schema=ModelSchema,
+    build_fn=utils.build_model,
+)
+
+ModelTagsResource = generate_resource_tags_endpoint(
+    api=api,
+    resource_name=MODEL_RESOURCE_TYPE,
+)
+ModelTagsIdResource = generate_resource_tags_id_endpoint(
+    api=api,
+    resource_name=MODEL_RESOURCE_TYPE,
+)
diff --git a/src/dioptra/restapi/v1/models/errors.py b/src/dioptra/restapi/v1/models/errors.py
new file mode 100644
index 000000000..fb87d885c
--- /dev/null
+++ b/src/dioptra/restapi/v1/models/errors.py
@@ -0,0 +1,54 @@
+# This Software (Dioptra) is being made available as a public service by the
+# National Institute of Standards and Technology (NIST), an Agency of the United
+# States Department of Commerce. This software was developed in part by employees of
+# NIST and in part by NIST contractors. Copyright in portions of this software that
+# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant
+# to Title 17 United States Code Section 105, works of NIST employees are not
+# subject to copyright protection in the United States. However, NIST may hold
+# international copyright in software created by its employees and domestic
+# copyright (or licensing rights) in portions of software that were assigned or
+# licensed to NIST. To the extent that NIST holds copyright in this software, it is
+# being made available under the Creative Commons Attribution 4.0 International
+# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts
+# of the software developed or licensed by NIST.
+# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""Error handlers for the model endpoints.""" +from __future__ import annotations + +from flask_restx import Api + + +class ModelAlreadyExistsError(Exception): + """The model name already exists.""" + + +class ModelDoesNotExistError(Exception): + """The requested model does not exist.""" + + +class ModelVersionDoesNotExistError(Exception): + """The requested version of the model does not exist.""" + + +def register_error_handlers(api: Api) -> None: + @api.errorhandler(ModelDoesNotExistError) + def handle_model_does_not_exist_error(error): + return {"message": "Not Found - The requested model does not exist"}, 404 + + @api.errorhandler(ModelVersionDoesNotExistError) + def handle_model_version_does_not_exist_error(error): + return { + "message": "Not Found - The requested model version does not exist" + }, 404 + + @api.errorhandler(ModelAlreadyExistsError) + def handle_model_already_exists_error(error): + return ( + { + "message": "Bad Request - The model name on the registration form " + "already exists. Please select another and resubmit." 
+ }, + 400, + ) diff --git a/src/dioptra/restapi/v1/models/schema.py b/src/dioptra/restapi/v1/models/schema.py index d321a5cbe..ae23cd7b7 100644 --- a/src/dioptra/restapi/v1/models/schema.py +++ b/src/dioptra/restapi/v1/models/schema.py @@ -20,6 +20,7 @@ from dioptra.restapi.v1.artifacts.schema import ArtifactRefSchema from dioptra.restapi.v1.schemas import ( BasePageSchema, + GroupIdQueryParametersSchema, PagingQueryParametersSchema, SearchQueryParametersSchema, generate_base_resource_ref_schema, @@ -29,26 +30,86 @@ ModelBaseRefSchema = generate_base_resource_ref_schema("Model") -class ModelMutableFieldsSchema(Schema): - """The fields schema for the mutable data in a Model resource.""" +class ModelRefSchema(ModelBaseRefSchema): # type: ignore + """The reference schema for the data stored in a Model resource.""" name = fields.String( attribute="name", metadata=dict(description="Name of the Model resource."), required=True, ) + + +class ModelVersionRefSchema(Schema): + """The reference schema for the data stored in a Model Version.""" + + versionNumber = fields.Integer( + attribute="version_number", + metadata=dict(description="The version number of the Model."), + dump_only=True, + ) + url = fields.Url( + attribute="url", + metadata=dict(description="URL for accessing the full Model Version."), + relative=True, + ) + + +class ModelVersionMutableFieldsSchema(Schema): description = fields.String( attribute="description", - metadata=dict(description="Description of the Model resource."), + metadata=dict(description="Description of the Model Version."), load_default=None, ) + + +class ModelVersionSchema(ModelVersionMutableFieldsSchema): + """The schema for the data stored in a ModelVersion resource.""" + + model = fields.Nested( + ModelRefSchema, + attribute="model", + metadata=dict(description="The Model resource."), + dump_only=True, + ) artifactId = fields.Integer( attribute="artifact_id", data_key="artifact", - metadata=dict(description="The artifact representing 
the Model."), + metadata=dict(description="The artifact representing the Model Version."), load_only=True, required=True, ) + artifact = fields.Nested( + ArtifactRefSchema, + attribute="artifact", + metadata=dict(description="The artifact representing the Model Version."), + dump_only=True, + ) + versionNumber = fields.Integer( + attribute="version_number", + metadata=dict(description="The version number of the Model."), + dump_only=True, + ) + createdOn = fields.DateTime( + attribute="created_on", + metadata=dict(description="Timestamp when the Model Version was created."), + dump_only=True, + ) + + +class ModelMutableFieldsSchema(Schema): + """The fields schema for the mutable data in a Model resource.""" + + name = fields.String( + attribute="name", + metadata=dict(description="Name of the Model resource."), + required=True, + ) + description = fields.String( + attribute="description", + metadata=dict(description="Description of the Model resource."), + load_default=None, + ) ModelBaseSchema = generate_base_resource_schema("Model", snapshot=True) @@ -57,15 +118,17 @@ class ModelMutableFieldsSchema(Schema): class ModelSchema(ModelMutableFieldsSchema, ModelBaseSchema): # type: ignore """The schema for the data stored in a Model resource.""" - artifact = fields.Nested( - ArtifactRefSchema, - attribute="artifact", - metadata=dict(description="The artifact representing the Model."), + versions = fields.Nested( + ModelVersionRefSchema, + many=True, + attribute="versions", + metadata=dict(description="The details of this model version."), dump_only=True, ) - versionNumber = fields.Integer( - attribute="version_number", - metadata=dict(description="The version number of the Model."), + latestVersion = fields.Nested( + ModelVersionSchema, + attribute="latest_version", + metadata=dict(description="The details of latest version of this model."), dump_only=True, ) @@ -80,8 +143,26 @@ class ModelPageSchema(BasePageSchema): ) +class ModelVersionPageSchema(BasePageSchema): + 
"""The paged schema for the data stored in a Model resource.""" + + data = fields.Nested( + ModelVersionSchema, + many=True, + metadata=dict(description="List of Model resources in the current page."), + ) + + class ModelGetQueryParameters( PagingQueryParametersSchema, + GroupIdQueryParametersSchema, SearchQueryParametersSchema, ): """The query parameters for the GET method of the /models endpoint.""" + + +class ModelVersionGetQueryParameters( + PagingQueryParametersSchema, + SearchQueryParametersSchema, +): + """The query parameters for the GET method of the /models/{id}/versions endpoint.""" diff --git a/src/dioptra/restapi/v1/models/service.py b/src/dioptra/restapi/v1/models/service.py new file mode 100644 index 000000000..0404ae22b --- /dev/null +++ b/src/dioptra/restapi/v1/models/service.py @@ -0,0 +1,822 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. 
+# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The server-side functions that perform model endpoint operations.""" +from __future__ import annotations + +from typing import Any, Final, cast + +import structlog +from flask_login import current_user +from injector import inject +from sqlalchemy import Integer, func, select +from sqlalchemy.orm import aliased +from structlog.stdlib import BoundLogger + +from dioptra.restapi.db import db, models +from dioptra.restapi.errors import BackendDatabaseError +from dioptra.restapi.v1 import utils +from dioptra.restapi.v1.artifacts.service import ArtifactIdService +from dioptra.restapi.v1.groups.service import GroupIdService +from dioptra.restapi.v1.shared.search_parser import construct_sql_query_filters + +from .errors import ( + ModelAlreadyExistsError, + ModelDoesNotExistError, + ModelVersionDoesNotExistError, +) + +LOGGER: BoundLogger = structlog.stdlib.get_logger() + +MODEL_RESOURCE_TYPE: Final[str] = "ml_model" +MODEL_VERSION_RESOURCE_TYPE: Final[str] = "ml_model_version" +MODEL_SEARCHABLE_FIELDS: Final[dict[str, Any]] = { + "name": lambda x: models.MlModel.name.like(x, escape="/"), + "description": lambda x: models.MlModel.description.like(x, escape="/"), +} +MODEL_VERSION_SEARCHABLE_FIELDS: Final[dict[str, Any]] = { + "description": lambda x: models.MlModelVersion.description.like(x, escape="/"), +} + + +class ModelService(object): + """The service methods for registering and managing models by their unique id.""" + + @inject + def __init__( + self, + model_name_service: ModelNameService, + group_id_service: GroupIdService, + ) -> None: + """Initialize the model service. + + All arguments are provided via dependency injection. + + Args: + model_name_service: A ModelNameService object. + group_id_service: A GroupIdService object. 
+ """ + self._model_name_service = model_name_service + self._group_id_service = group_id_service + + def create( + self, + name: str, + description: str, + group_id: int, + commit: bool = True, + **kwargs, + ) -> utils.ModelWithVersionDict: + """Create a new model. + + Args: + name: The name of the model. The combination of name and group_id must be + unique. + description: The description of the model. + group_id: The group that will own the model. + commit: If True, commit the transaction. Defaults to True. + + Returns: + The newly created model object. + + Raises: + ModelAlreadyExistsError: If a model with the given name already exists. + """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) + + if self._model_name_service.get(name, group_id=group_id, log=log) is not None: + log.debug("Model name already exists", name=name, group_id=group_id) + raise ModelAlreadyExistsError + + group = self._group_id_service.get(group_id, error_if_not_found=True) + + resource = models.Resource(resource_type=MODEL_RESOURCE_TYPE, owner=group) + + ml_model = models.MlModel( + name=name, + description=description, + resource=resource, + creator=current_user, + ) + + db.session.add(ml_model) + + if commit: + db.session.commit() + log.debug( + "Model registration successful", + model_id=ml_model.resource_id, + name=ml_model.name, + ) + + return utils.ModelWithVersionDict(model=ml_model, version=None, has_draft=False) + + def get( + self, + group_id: int | None, + search_string: str, + page_index: int, + page_length: int, + **kwargs, + ) -> Any: + """Fetch a list of models, optionally filtering by search string and paging + parameters. + + Args: + group_id: A group ID used to filter results. + search_string: A search string used to filter results. + page_index: The index of the first group to be returned. + page_length: The maximum number of models to be returned. + + Returns: + A tuple containing a list of models and the total number of models matching + the query. 
+
+        Raises:
+            SearchNotImplementedError: If a search string is provided.
+            BackendDatabaseError: If the database query returns a None when counting
+                the number of models.
+        """
+        log: BoundLogger = kwargs.get("log", LOGGER.new())
+        log.debug("Get full list of models")
+
+        filters = []
+
+        if group_id is not None:
+            filters.append(models.Resource.group_id == group_id)
+
+        if search_string:
+            filters.append(
+                construct_sql_query_filters(search_string, MODEL_SEARCHABLE_FIELDS)
+            )
+
+        stmt = (
+            select(func.count(models.MlModel.resource_id))
+            .join(models.Resource)
+            .where(
+                *filters,
+                models.Resource.is_deleted == False,  # noqa: E712
+                models.Resource.latest_snapshot_id
+                == models.MlModel.resource_snapshot_id,
+            )
+        )
+        total_num_ml_models = db.session.scalars(stmt).first()
+
+        if total_num_ml_models is None:
+            log.error(
+                "The database query returned a None when counting the number of "
+                "models when it should return a number.",
+                sql=str(stmt),
+            )
+            raise BackendDatabaseError
+
+        if total_num_ml_models == 0:
+            return [], total_num_ml_models
+
+        latest_ml_models_stmt = (
+            select(models.MlModel)
+            .join(models.Resource)
+            .where(
+                *filters,
+                models.Resource.is_deleted == False,  # noqa: E712
+                models.Resource.latest_snapshot_id
+                == models.MlModel.resource_snapshot_id,
+            )
+            .offset(page_index)
+            .limit(page_length)
+        )
+        ml_models = db.session.scalars(latest_ml_models_stmt).all()
+
+        # extract list of model ids
+        model_ids = [ml_model.resource_id for ml_model in ml_models]
+
+        # Build CTE that retrieves all snapshot ids for the list of ml model versions
+        # associated with retrieved ml models
+        parent_model = aliased(models.MlModel)
+        model_version_snapshot_ids_cte = (
+            select(models.MlModelVersion.resource_snapshot_id)
+            .join(
+                models.resource_dependencies_table,
+                models.MlModelVersion.resource_id
+                == models.resource_dependencies_table.c.child_resource_id,
+            )
+            .join(
+                parent_model,
+                parent_model.resource_id
+                == 
models.resource_dependencies_table.c.parent_resource_id, + ) + .where(parent_model.resource_id.in_(model_ids)) + .cte() + ) + + # get the latest model version snapshots associated with the retrieved models + latest_model_versions_stmt = ( + select(models.MlModelVersion) + .join(models.Resource) + .where( + models.MlModelVersion.resource_snapshot_id.in_( + select(model_version_snapshot_ids_cte) + ), + models.Resource.latest_snapshot_id + == models.MlModelVersion.resource_snapshot_id, + models.Resource.is_deleted == False, # noqa: E712 + ) + .order_by(models.MlModelVersion.created_on) + ) + model_versions = db.session.scalars(latest_model_versions_stmt).unique().all() + + # build a dictionary structure to re-associate models with model versions + models_dict: dict[int, utils.ModelWithVersionDict] = { + model.resource_id: utils.ModelWithVersionDict( + model=model, version=None, has_draft=False + ) + for model in ml_models + } + for model_version in model_versions: + models_dict[model_version.model_id]["version"] = model_version + + drafts_stmt = select( + models.DraftResource.payload["resource_id"].as_string().cast(Integer) + ).where( + models.DraftResource.payload["resource_id"] + .as_string() + .cast(Integer) + .in_(tuple(model["model"].resource_id for model in models_dict.values())), + models.DraftResource.user_id == current_user.user_id, + ) + for resource_id in db.session.scalars(drafts_stmt): + models_dict[resource_id]["has_draft"] = True + + return list(models_dict.values()), total_num_ml_models + + +class ModelIdService(object): + """The service methods for registering and managing models by their unique id.""" + + @inject + def __init__( + self, + model_name_service: ModelNameService, + ) -> None: + """Initialize the model service. + + All arguments are provided via dependency injection. + + Args: + model_name_service: A ModelNameService object. 
+ """ + self._model_name_service = model_name_service + + def get( + self, + model_id: int, + error_if_not_found: bool = False, + **kwargs, + ) -> utils.ModelWithVersionDict | None: + """Fetch a model by its unique id. + + Args: + model_id: The unique id of the model. + error_if_not_found: If True, raise an error if the model is not found. + Defaults to False. + + Returns: + The model object if found, otherwise None. + + Raises: + ModelDoesNotExistError: If the model is not found and `error_if_not_found` + is True. + """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) + log.debug("Get model by id", model_id=model_id) + + stmt = ( + select(models.MlModel) + .join(models.Resource) + .where( + models.MlModel.resource_id == model_id, + models.MlModel.resource_snapshot_id + == models.Resource.latest_snapshot_id, + models.Resource.is_deleted == False, # noqa: E712 + ) + ) + ml_model = db.session.scalars(stmt).first() + + if ml_model is None: + if error_if_not_found: + log.debug("Model not found", model_id=model_id) + raise ModelDoesNotExistError + + return None + + latest_ml_model_version_stmt = ( + select(models.MlModelVersion) + .join(models.Resource) + .where( + models.MlModelVersion.model_id == model_id, + models.Resource.is_deleted == False, # noqa: E712 + models.Resource.latest_snapshot_id + == models.MlModelVersion.resource_snapshot_id, + ) + .order_by(models.MlModelVersion.created_on.desc()) + .limit(1) + ) + latest_version = db.session.scalar(latest_ml_model_version_stmt) + + drafts_stmt = ( + select(models.DraftResource.draft_resource_id) + .where( + models.DraftResource.payload["resource_id"].as_string().cast(Integer) + == ml_model.resource_id, + models.DraftResource.user_id == current_user.user_id, + ) + .exists() + .select() + ) + has_draft = db.session.scalar(drafts_stmt) + + return utils.ModelWithVersionDict( + model=ml_model, version=latest_version, has_draft=has_draft + ) + + def modify( + self, + model_id: int, + name: str, + description: str, + 
error_if_not_found: bool = False, + commit: bool = True, + **kwargs, + ) -> utils.ModelWithVersionDict | None: + """Modify a model. + + Args: + model_id: The unique id of the model. + name: The new name of the model. + description: The new description of the model. + error_if_not_found: If True, raise an error if the group is not found. + Defaults to False. + commit: If True, commit the transaction. Defaults to True. + + Returns: + The updated model object. + + Raises: + ModelDoesNotExistError: If the model is not found and `error_if_not_found` + is True. + ModelAlreadyExistsError: If the model name already exists. + """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) + + ml_model_dict = self.get( + model_id, error_if_not_found=error_if_not_found, log=log + ) + + if ml_model_dict is None: + if error_if_not_found: + raise ModelDoesNotExistError + + return None + + ml_model = ml_model_dict["model"] + version = ml_model_dict["version"] + has_draft = ml_model_dict["has_draft"] + + group_id = ml_model.resource.group_id + if ( + name != ml_model.name + and self._model_name_service.get(name, group_id=group_id, log=log) + is not None + ): + log.debug("Model name already exists", name=name, group_id=group_id) + raise ModelAlreadyExistsError + + new_ml_model = models.MlModel( + name=name, + description=description, + resource=ml_model.resource, + creator=current_user, + ) + db.session.add(new_ml_model) + + if commit: + db.session.commit() + log.debug( + "Model modification successful", + model_id=model_id, + name=name, + description=description, + ) + + return utils.ModelWithVersionDict( + model=new_ml_model, + version=version, + has_draft=has_draft, + ) + + def delete(self, model_id: int, **kwargs) -> dict[str, Any]: + """Delete a model. + + Args: + model_id: The unique id of the model. + + Returns: + A dictionary reporting the status of the request. 
+ """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) + + stmt = select(models.Resource).filter_by( + resource_id=model_id, resource_type=MODEL_RESOURCE_TYPE, is_deleted=False + ) + model_resource = db.session.scalars(stmt).first() + + if model_resource is None: + raise ModelDoesNotExistError + + deleted_resource_lock = models.ResourceLock( + resource_lock_type="delete", + resource=model_resource, + ) + db.session.add(deleted_resource_lock) + db.session.commit() + log.debug("Model deleted", model_id=model_id) + + return {"status": "Success", "model_id": model_id} + + +class ModelIdVersionsService(object): + @inject + def __init__( + self, + artifact_id_service: ArtifactIdService, + model_id_service: ModelIdService, + ) -> None: + """Initialize the model service. + + All arguments are provided via dependency injection. + + Args: + artifact_id_service: A ArtifactIdService object. + model_id_service: A ModelIdService object. + """ + self._artifact_id_service = artifact_id_service + self._model_id_service = model_id_service + + def create( + self, + model_id: int, + description: str, + artifact_id: int, + commit: bool = True, + **kwargs, + ) -> utils.ModelWithVersionDict: + """Create a new model version. + + Args: + model_id: The unique id of the model. + description: The description of the model version. + artifact_id: The artifact for the model version. + commit: If True, commit the transaction. Defaults to True. + + Returns: + The newly created model object. 
+ """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) + + ml_model_dict = cast( + utils.ModelWithVersionDict, + self._model_id_service.get(model_id, error_if_not_found=True, log=log), + ) + ml_model = ml_model_dict["model"] + has_draft = ml_model_dict["has_draft"] + version = ml_model_dict["version"] + next_version_number = version.version_number + 1 if version is not None else 1 + group = ml_model.resource.owner + artifact_dict = cast( + utils.ArtifactDict, + self._artifact_id_service.get( + artifact_id, error_if_not_found=True, log=log + ), + ) + artifact = artifact_dict["artifact"] + + resource = models.Resource( + resource_type=MODEL_VERSION_RESOURCE_TYPE, owner=group + ) + new_version = models.MlModelVersion( + description=description, + artifact=artifact, + version_number=next_version_number, + resource=resource, + creator=current_user, + ) + + ml_model.resource.children.append(new_version.resource) + db.session.add(new_version) + + if commit: + db.session.commit() + log.debug( + "Model registration successful", + model_id=ml_model.resource_id, + name=ml_model.name, + ) + + return utils.ModelWithVersionDict( + model=ml_model, + version=new_version, + has_draft=has_draft, + ) + + def get( + self, + model_id: int, + search_string: str, + page_index: int, + page_length: int, + error_if_not_found: bool = False, + **kwargs, + ) -> tuple[list[utils.ModelWithVersionDict], int] | None: + """Fetch a list of versions of a model. + + Args: + model_id: The unique id of the resource. + search_string: A search string used to filter results. + page_index: The index of the first snapshot to be returned. + page_length: The maximum number of versions to be returned. + error_if_not_found: If True, raise an error if the resource is not found. + Defaults to False. + + Returns: + The list of resource versions of the resource object if found, otherwise + None. + + Raises: + ResourceDoesNotExistError: If the resource is not found and + `error_if_not_found` is True. 
+ """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) + log.debug("Get model versions by id", model_id=model_id) + + ml_model_dict = cast( + utils.ModelWithVersionDict, + self._model_id_service.get(model_id, error_if_not_found=True, log=log), + ) + ml_model = ml_model_dict["model"] + model_id = ml_model.resource_id + + filters = construct_sql_query_filters( + search_string, MODEL_VERSION_SEARCHABLE_FIELDS + ) + + latest_model_versions_count_stmt = ( + select(func.count(models.MlModelVersion.resource_id)) + .join(models.Resource) + .where( + filters, + models.MlModelVersion.model_id == model_id, + models.Resource.is_deleted == False, # noqa: E712 + models.Resource.latest_snapshot_id + == models.MlModelVersion.resource_snapshot_id, + ) + ) + total_num_model_versions = db.session.scalar(latest_model_versions_count_stmt) + + if total_num_model_versions is None: + log.error( + "The database query returned a None when counting the number of " + "groups when it should return a number.", + sql=str(total_num_model_versions), + ) + raise BackendDatabaseError + + if total_num_model_versions == 0: + return [], total_num_model_versions + + latest_model_versions_stmt = ( + select(models.MlModelVersion) + .join(models.Resource) + .where( + filters, + models.MlModelVersion.model_id == model_id, + models.Resource.is_deleted == False, # noqa: E712 + models.Resource.latest_snapshot_id + == models.MlModelVersion.resource_snapshot_id, + ) + .offset(page_index) + .limit(page_length) + ) + + drafts_stmt = select( + models.DraftResource.payload["resource_id"].as_string().cast(Integer) + ).where( + models.DraftResource.payload["resource_id"].as_string().cast(Integer) + == model_id, + models.DraftResource.user_id == current_user.user_id, + ) + has_draft = db.session.scalar(drafts_stmt) + + return [ + utils.ModelWithVersionDict( + model=ml_model, version=version, has_draft=has_draft + ) + for version in db.session.scalars(latest_model_versions_stmt).unique().all() + ], 
total_num_model_versions + + +class ModelIdVersionsNumberService(object): + @inject + def __init__( + self, + model_id_service: ModelIdService, + ) -> None: + """Initialize the model service. + + All arguments are provided via dependency injection. + + Args: + model_id_service: A ModelIdService object. + """ + self._model_id_service = model_id_service + + def get( + self, + model_id: int, + version_number: int, + error_if_not_found: bool = False, + **kwargs, + ) -> utils.ModelWithVersionDict | None: + """Fetch a specific version of a Model resource. + + Args: + model_id: The unique id of the Model resource. + version_number: The version number of the Model + error_if_not_found: If True, raise an error if the resource is not found. + Defaults to False. + + Returns: + The requested version the resource object if found, otherwise None. + + Raises: + ResourceDoesNotExistError: If the resource is not found and + `error_if_not_found` is True. + """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) + log.debug("Get resource snaphot by id", model_id=model_id) + + ml_model_dict = self._model_id_service.get( + model_id, error_if_not_found=error_if_not_found, log=log + ) + + if ml_model_dict is None: + return None + + ml_model = ml_model_dict["model"] + + ml_model_version_stmt = ( + select(models.MlModelVersion) + .join(models.Resource) + .where( + models.MlModelVersion.model_id == model_id, + models.MlModelVersion.version_number == version_number, + models.Resource.is_deleted == False, # noqa: E712 + models.Resource.latest_snapshot_id + == models.MlModelVersion.resource_snapshot_id, + ) + ) + latest_version = db.session.scalar(ml_model_version_stmt) + + if latest_version is None: + if error_if_not_found: + log.debug("Model version not found", version_number=version_number) + raise ModelVersionDoesNotExistError + + return None + + return utils.ModelWithVersionDict( + model=ml_model, version=latest_version, has_draft=False + ) + + def modify( + self, + model_id: int, + 
version_number: int, + description: str, + commit: bool = True, + **kwargs, + ) -> utils.ModelWithVersionDict | None: + """Modify a model version. + + Args: + model_id: The unique id of the model. + version_number: The version number of the model. + description: The new description of the model version. + error_if_not_found: If True, raise an error if the group is not found. + Defaults to False. + commit: If True, commit the transaction. Defaults to True. + + Returns: + The updated model object. + + Raises: + ModelDoesNotExistError: If the model is not found and `error_if_not_found` + is True. + ModelAlreadyExistsError: If the model name already exists. + """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) + + ml_model_dict = cast( + utils.ModelWithVersionDict, + self.get(model_id, version_number, error_if_not_found=True, log=log), + ) + + ml_model = ml_model_dict["model"] + version = cast(models.MlModelVersion, ml_model_dict["version"]) + + new_version = models.MlModelVersion( + description=description, + artifact=version.artifact, + version_number=version.version_number, + resource=version.resource, + creator=current_user, + ) + db.session.add(new_version) + + if commit: + db.session.commit() + log.debug( + "Model Version modification successful", + model_id=model_id, + version_number=version_number, + description=description, + ) + + return utils.ModelWithVersionDict( + model=ml_model, + version=new_version, + has_draft=False, + ) + + +class ModelNameService(object): + """The service methods for managing models by their name.""" + + def get( + self, + name: str, + group_id: int, + error_if_not_found: bool = False, + **kwargs, + ) -> models.MlModel | None: + """Fetch a model by its name. + + Args: + name: The name of the model. + group_id: The the group id of the model. + error_if_not_found: If True, raise an error if the model is not found. + Defaults to False. + + Returns: + The model object if found, otherwise None. 
+ + Raises: + ModelDoesNotExistError: If the model is not found and `error_if_not_found` + is True. + """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) + log.debug("Get model by name", model_name=name, group_id=group_id) + + stmt = ( + select(models.MlModel) + .join(models.Resource) + .where( + models.MlModel.name == name, + models.Resource.group_id == group_id, + models.Resource.is_deleted == False, # noqa: E712 + models.Resource.latest_snapshot_id + == models.MlModel.resource_snapshot_id, + ) + ) + ml_model = db.session.scalars(stmt).first() + + if ml_model is None: + if error_if_not_found: + log.debug("Model not found", name=name) + raise ModelDoesNotExistError + + return None + + return ml_model diff --git a/src/dioptra/restapi/v1/utils.py b/src/dioptra/restapi/v1/utils.py index b51a5ff2e..859681075 100644 --- a/src/dioptra/restapi/v1/utils.py +++ b/src/dioptra/restapi/v1/utils.py @@ -15,7 +15,7 @@ # ACCESS THE FULL CC BY 4.0 LICENSE HERE: # https://creativecommons.org/licenses/by/4.0/legalcode """Utility functions to help in building responses from ORM models""" -from typing import Any, Callable, Final, TypedDict +from typing import Any, Callable, Final, TypedDict, cast from urllib.parse import urlencode, urlunparse from marshmallow import Schema @@ -23,9 +23,11 @@ from dioptra.restapi.db import models from dioptra.restapi.routes import V1_ROOT +ARTIFACTS: Final[str] = "artifacts" ENTRYPOINTS: Final[str] = "entrypoints" EXPERIMENTS: Final[str] = "experiments" GROUPS: Final[str] = "groups" +MODELS: Final[str] = "models" PLUGINS: Final[str] = "plugins" PLUGIN_FILES: Final[str] = "files" PLUGIN_PARAMETER_TYPES: Final[str] = "pluginParameterTypes" @@ -90,6 +92,11 @@ class PluginParameterTypeDict(TypedDict): has_draft: bool | None +class ArtifactDict(TypedDict): + artifact: models.Artifact + has_draft: bool | None + + class QueueDict(TypedDict): queue: models.Queue has_draft: bool | None @@ -107,6 +114,12 @@ class JobDict(TypedDict): has_draft: bool | None 
+class ModelWithVersionDict(TypedDict): + model: models.MlModel + version: models.MlModelVersion | None + has_draft: bool | None + + # -- Ref Types ----------------------------------------------------------------- @@ -338,6 +351,23 @@ def build_entrypoint_snapshot_ref(entrypoint: models.EntryPoint) -> dict[str, An } +def build_model_ref(model: models.MlModel) -> dict[str, Any]: + """Build a MlModel dictionary. + + Args: + model: The Model object to convert into a MlModel dictionary. + + Returns: + The MlModel dictionary. + """ + return { + "id": model.resource_id, + "name": model.name, + "group": build_group_ref(model.resource.owner), + "url": f"{MODELS}/{model.resource_id}", + } + + def build_queue_ref(queue: models.Queue) -> dict[str, Any]: """Build a QueueRef dictionary. @@ -394,6 +424,23 @@ def build_plugin_parameter_type_ref( } +def build_artifact_ref(artifact: models.Artifact) -> dict[str, Any]: + """Build a ArtifactRef dictionary. + + Args: + artifact: The Artifact object to convert into a ArtifactRef dictionary. + + Returns: + The ArtifactRef dictionary. + """ + return { + "id": artifact.resource_id, + "artifact_uri": artifact.uri, + "group": build_group_ref(artifact.resource.owner), + "url": f"/{ARTIFACTS}/{artifact.resource_id}", + } + + # -- Full Types ---------------------------------------------------------------- @@ -624,6 +671,101 @@ def build_job(job_dict: JobDict) -> dict[str, Any]: return data +def build_model(model_dict: ModelWithVersionDict) -> dict[str, Any]: + """Build a Model response dictionary for the latest version. + + Args: + model: The Model object to convert into a Model response dictionary. + + Returns: + The Model response dictionary. 
+ """ + model = model_dict["model"] + version = model_dict.get("version", None) + has_draft = model_dict.get("has_draft", None) + + latest_version = build_model_version(model_dict) if version is not None else None + latest_version_number = version.version_number if version is not None else 0 + + # WARNING: this assumes versions cannot be deleted and that no details of + # the version # are needed in constructing the ref type. If either of these + # assumptions change, the service and controller layers will need to be + # updated to return the MlModelVersion # ORM objects. + versions = [ + { + "version_number": version_number, + "url": build_url(f"{MODELS}/{model.resource_id}/versions/{version_number}"), + } + for version_number in range(1, latest_version_number + 1) + ] + + data = { + "id": model.resource_id, + "snapshot_id": model.resource_snapshot_id, + "name": model.name, + "description": model.description, + "user": build_user_ref(model.creator), + "group": build_group_ref(model.resource.owner), + "created_on": model.created_on, + "last_modified_on": model.resource.last_modified_on, + "latest_snapshot": model.resource.latest_snapshot_id + == model.resource_snapshot_id, + "tags": [build_tag_ref(tag) for tag in model.tags], + "latest_version": latest_version, + "versions": versions, + } + + if has_draft is not None: + data["has_draft"] = has_draft + + return data + + +def build_model_version(model_dict: ModelWithVersionDict) -> dict[str, Any]: + """Build a ModelVersion response dictionary. + + Args: + model: The ModelVersion object to convert into a ModelVersion response + dictionary. + + Returns: + The ModelVersion response dictionary. 
+ """ + model = model_dict["model"] + version = cast(models.MlModelVersion, model_dict["version"]) + return { + "model": build_model_ref(model), + "description": version.description, + "version_number": version.version_number, + "artifact": build_artifact_ref(version.artifact), + "created_on": version.created_on, + } + + +def build_artifact(artifact_dict: ArtifactDict) -> dict[str, Any]: + artifact = artifact_dict["artifact"] + has_draft = artifact_dict.get("has_draft") + + data = { + "id": artifact.resource_id, + "snapshot_id": artifact.resource_snapshot_id, + "uri": artifact.uri, + "description": artifact.description, + "user": build_user_ref(artifact.creator), + "group": build_group_ref(artifact.resource.owner), + "created_on": artifact.created_on, + "last_modified_on": artifact.resource.last_modified_on, + "latest_snapshot": artifact.resource.latest_snapshot_id + == artifact.resource_snapshot_id, + "tags": [build_tag_ref(tag) for tag in artifact.tags], + } + + if has_draft is not None: + data["has_draft"] = has_draft + + return data + + def build_queue(queue_dict: QueueDict) -> dict[str, Any]: """Build a Queue response dictionary. diff --git a/tests/unit/restapi/lib/actions.py b/tests/unit/restapi/lib/actions.py index 447863c41..d9914b13b 100644 --- a/tests/unit/restapi/lib/actions.py +++ b/tests/unit/restapi/lib/actions.py @@ -19,17 +19,19 @@ This module contains shared actions used across test suites for each of the REST API endpoints. 
""" -from typing import Any, List +from typing import Any from flask.testing import FlaskClient from werkzeug.test import TestResponse from dioptra.restapi.routes import ( + V1_ARTIFACTS_ROUTE, V1_AUTH_ROUTE, V1_ENTRYPOINTS_ROUTE, V1_EXPERIMENTS_ROUTE, V1_GROUPS_ROUTE, V1_JOBS_ROUTE, + V1_MODELS_ROUTE, V1_PLUGIN_PARAMETER_TYPES_ROUTE, V1_PLUGINS_ROUTE, V1_QUEUES_ROUTE, @@ -223,6 +225,89 @@ def register_tag( ) +def register_artifact( + client: FlaskClient, + uri: int, + job_id: int, + group_id: int, + description: str | None = None, +) -> TestResponse: + """Register an artifact using the API. + + Args: + client: The Flask test client. + uri: The URI of the artifact + job_id: The job to create the new artifact. + group_id: The group to create the new artifact in. + description: The description of the new artifact. + + Returns: + The response from the API. + """ + payload = {"uri": uri, "job": job_id, "group": group_id} + + if description is not None: + payload["description"] = description + + return client.post( + f"/{V1_ROOT}/{V1_ARTIFACTS_ROUTE}/", + json=payload, + follow_redirects=True, + ) + + +def register_model( + client: FlaskClient, + name: str, + group_id: int, + description: str | None = None, +) -> TestResponse: + """Register a model using the API. + + Args: + client: The Flask test client. + name: The name to assign to the new model. + group_id: The group to create the new model in. + description: The description of the new model. + + Returns: + The response from the API. + """ + payload = {"name": name, "group": group_id, "description": description} + + return client.post( + f"/{V1_ROOT}/{V1_MODELS_ROUTE}/", + json=payload, + follow_redirects=True, + ) + + +def register_model_version( + client: FlaskClient, + model_id: int, + artifact_id: int, + description: str | None = None, +) -> TestResponse: + """Register a model version using the API. + + Args: + client: The Flask test client. + model_id: The id of the model to create a new version of. 
+ artifact_id: The id of artifact representing the new model version. + description: The description of the new model version. + + Returns: + The response from the API. + """ + payload = {"artifact": artifact_id, "description": description} + + return client.post( + f"/{V1_ROOT}/{V1_MODELS_ROUTE}/{model_id}/versions", + json=payload, + follow_redirects=True, + ) + + def register_plugin( client: FlaskClient, name: str, @@ -378,6 +463,14 @@ def get_public_group(client: FlaskClient) -> TestResponse: return client.get(f"/{V1_ROOT}/{V1_GROUPS_ROUTE}/1", follow_redirects=True) +def get_model(client: FlaskClient, model_id: int) -> TestResponse: + response = client.get( + f"/{V1_ROOT}/{V1_MODELS_ROUTE}/{model_id}", + follow_redirects=True, + ) + return response + + def get_plugin_parameter_types(client: FlaskClient) -> TestResponse: response = client.get( f"/{V1_ROOT}/{V1_PLUGIN_PARAMETER_TYPES_ROUTE}", diff --git a/tests/unit/restapi/lib/db/setup.py b/tests/unit/restapi/lib/db/setup.py index 69937acac..61ca93ad5 100644 --- a/tests/unit/restapi/lib/db/setup.py +++ b/tests/unit/restapi/lib/db/setup.py @@ -55,6 +55,7 @@ {"resource_type": "plugin_task_parameter_type"}, {"resource_type": "queue"}, {"resource_type": "resource_snapshot"}, + {"resource_type": "ml_model_version"}, ] RESOURCE_DEPENDENCY_TYPES: Final[list[dict[str, str]]] = [ {"parent_resource_type": "experiment", "child_resource_type": "entry_point"}, @@ -63,6 +64,7 @@ {"parent_resource_type": "plugin", "child_resource_type": "plugin_file"}, {"parent_resource_type": "job", "child_resource_type": "artifact"}, {"parent_resource_type": "job", "child_resource_type": "job"}, + {"parent_resource_type": "ml_model", "child_resource_type": "ml_model_version"}, ] LEGACY_JOB_STATUS_TYPES: Final[list[dict[str, str]]] = [ {"status": "queued"}, diff --git a/tests/unit/restapi/v1/conftest.py b/tests/unit/restapi/v1/conftest.py index 1b4aad0a0..6d582296a 100644 --- a/tests/unit/restapi/v1/conftest.py +++ 
b/tests/unit/restapi/v1/conftest.py @@ -105,6 +105,111 @@ def registered_tags( } +@pytest.fixture +def registered_artifacts( + client: FlaskClient, + db: SQLAlchemy, + auth_account: dict[str, Any], + # registered_jobs: dict[str, Any], +) -> dict[str, Any]: + # TODO: get the job ID from a registered jobs once implemented + job_id = 0 # registered_jobs["job1"]["id"] + group_id = auth_account["groups"][0]["id"] + artifact1_response = actions.register_artifact( + client, + uri="s3://bucket/model_v1.artifact", + description="Model artifact.", + job_id=job_id, + group_id=group_id, + ).get_json() + artifact2_response = actions.register_artifact( + client, + uri="s3://bucket/cnn.artifact", + description="Trained conv net model artifact.", + job_id=job_id, + group_id=group_id, + ).get_json() + artifact3_response = actions.register_artifact( + client, + uri="s3://bucket/model.artifact", + description="Another model", + job_id=job_id, + group_id=group_id, + ).get_json() + artifact4_response = actions.register_artifact( + client, + uri="s3://bucket/model_v2.artifact", + description="Fine-tuned model.", + job_id=job_id, + group_id=group_id, + ).get_json() + + return { + "artifact1": artifact1_response, + "artifact2": artifact2_response, + "artifact3": artifact3_response, + "artifact4": artifact4_response, + } + + +@pytest.fixture +def registered_models( + client: FlaskClient, + db: SQLAlchemy, + auth_account: dict[str, Any], + registered_artifacts: dict[str, Any], +) -> dict[str, Any]: + group_id = auth_account["groups"][0]["id"] + model1_response = actions.register_model( + client, + name="my_tensorflow_model", + description="Trained model", + group_id=group_id, + ).get_json() + model2_response = actions.register_model( + client, + name="model2", + description="Trained model", + group_id=group_id, + ).get_json() + model3_response = actions.register_model( + client, + name="model3", + description="", + group_id=group_id, + ).get_json() + + actions.register_model_version( + 
client, + model_id=model1_response["id"], + artifact_id=registered_artifacts["artifact1"]["id"], + description="initial version", + ).get_json() + actions.register_model_version( + client, + model_id=model2_response["id"], + artifact_id=registered_artifacts["artifact2"]["id"], + description="initial version", + ).get_json() + actions.register_model_version( + client, + model_id=model3_response["id"], + artifact_id=registered_artifacts["artifact3"]["id"], + ).get_json() + actions.register_model_version( + client, + model_id=model1_response["id"], + artifact_id=registered_artifacts["artifact4"]["id"], + description="new version", + ).get_json() + + return { + "model1": actions.get_model(client, model1_response["id"]).get_json(), + "model2": actions.get_model(client, model2_response["id"]).get_json(), + "model3": actions.get_model(client, model3_response["id"]).get_json(), + } + + @pytest.fixture def registered_plugins( client: FlaskClient, db: SQLAlchemy, auth_account: dict[str, Any] diff --git a/tests/unit/restapi/v1/test_artifact.py b/tests/unit/restapi/v1/test_artifact.py new file mode 100644 index 000000000..bc1e6793d --- /dev/null +++ b/tests/unit/restapi/v1/test_artifact.py @@ -0,0 +1,321 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. 
To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""Test suite for artifact operations. + +This module contains a set of tests that validate the registration and retrieval +operations for the artifact entity. The tests ensure that artifacts can be registered, +retrieved, and queried as expected through the REST API. +""" + +from typing import Any + +from flask.testing import FlaskClient +from flask_sqlalchemy import SQLAlchemy +from werkzeug.test import TestResponse + +from dioptra.restapi.routes import V1_ARTIFACTS_ROUTE, V1_ROOT + +from ..lib import actions, helpers + +# -- Assertions --------------------------------------------------------------------------- + + +def assert_artifact_response_contents_matches_expectations( + response: dict[str, Any], + expected_contents: dict[str, Any], +) -> None: + """Assert that artifact response contents is valid. + + Args: + response: The actual response from the API. + expected_contents: The expected response from the API. + + Raises: + AssertionError: If the response status code is not 200 or if the API response + does not match the expected response or if the response contents is not + valid.
+ """ + expected_keys = { + "id", + "snapshot", + "group", + "user", + "createdOn", + "lastModifiedOn", + "latestSnapshot", + "hasDraft", + "uri", + "description", + "tags", + } + assert set(response.keys()) == expected_keys + + # Validate the non-Ref fields + assert isinstance(response["id"], int) + assert isinstance(response["snapshot"], int) + assert isinstance(response["uri"], str) + assert isinstance(response["description"], str) + assert isinstance(response["createdOn"], str) + assert isinstance(response["lastModifiedOn"], str) + assert isinstance(response["latestSnapshot"], bool) + assert isinstance(response["hasDraft"], bool) + + assert response["uri"] == expected_contents["uri"] + assert response["description"] == expected_contents["description"] + + assert helpers.is_iso_format(response["createdOn"]) + assert helpers.is_iso_format(response["lastModifiedOn"]) + + # Validate the UserRef structure + assert isinstance(response["user"]["id"], int) + assert isinstance(response["user"]["username"], str) + assert isinstance(response["user"]["url"], str) + assert response["user"]["id"] == expected_contents["user_id"] + + # Validate the GroupRef structure + assert isinstance(response["group"]["id"], int) + assert isinstance(response["group"]["name"], str) + assert isinstance(response["group"]["url"], str) + assert response["group"]["id"] == expected_contents["group_id"] + + # Validate the TagRef structure + for tag in response["tags"]: + assert isinstance(tag["id"], int) + assert isinstance(tag["name"], str) + assert isinstance(tag["url"], str) + + +def assert_retrieving_artifact_by_id_works( + client: FlaskClient, artifact_id: int, expected: dict[str, Any] +) -> None: + """Assert that retrieving a artifact by id works. + + Args: + client: The Flask test client. + artifact_id: The id of the artifact to retrieve. + expected: The expected response from the API. 
+ + Raises: + AssertionError: If the response status code is not 200 or if the API response + does not match the expected response. + """ + response = client.get( + f"/{V1_ROOT}/{V1_ARTIFACTS_ROUTE}/{artifact_id}", follow_redirects=True + ) + assert response.status_code == 200 and response.get_json() == expected + + +def assert_retrieving_artifacts_works( + client: FlaskClient, + expected: list[dict[str, Any]], + group_id: int | None = None, + search: str | None = None, + paging_info: dict[str, Any] | None = None, +) -> None: + """Assert that retrieving all artifacts works. + + Args: + client: The Flask test client. + expected: The expected response from the API. + group_id: The group ID used in query parameters. + search: The search string used in query parameters. + paging_info: The paging information used in query parameters. + + Raises: + AssertionError: If the response status code is not 200 or if the API response + does not match the expected response. + """ + query_string: dict[str, Any] = {} + + if group_id is not None: + query_string["groupId"] = group_id + + if search is not None: + query_string["search"] = search + + if paging_info is not None: + query_string["index"] = paging_info["index"] + query_string["pageLength"] = paging_info["page_length"] + + response = client.get( + f"/{V1_ROOT}/{V1_ARTIFACTS_ROUTE}", + query_string=query_string, + follow_redirects=True, + ) + assert response.status_code == 200 and response.get_json()["data"] == expected + + +def assert_registering_existing_artifact_uri_fails( + client: FlaskClient, + uri: str, + group_id: int, + job_id: int, +) -> None: + """Assert that registering an artifact with an existing uri fails. + + Args: + client: The Flask test client. + uri: The uri to assign to the new artifact. + + Raises: + AssertionError: If the response status code is not 400. 
+ """ + response = actions.register_artifact( + client, + uri=uri, + description="", + group_id=group_id, + job_id=job_id, + ) + assert response.status_code == 400 + + +# -- Tests ----------------------------------------------------------------------------- + + +def test_create_artifact( + client: FlaskClient, + db: SQLAlchemy, + auth_account: dict[str, Any], +) -> None: + """Test that artifacts can be correctly registered and retrieved using the API. + + Given an authenticated user, this test validates the following sequence of actions: + + - The user registers an artifact with uri "s3://bucket/model_v1.artifact". + - The response is valid and matches the expected values given the registration request. + - The user is able to retrieve information about the artifact using the artifact id. + """ + uri = "s3://bucket/model_v1.artifact" + description = "The first artifact." + job_id = 0 # TODO: fill in once jobs are good. + user_id = auth_account["id"] + group_id = auth_account["groups"][0]["id"] + artifact_response = actions.register_artifact( + client, + uri=uri, + job_id=job_id, + group_id=group_id, + description=description, + ) + + artifact_expected = artifact_response.get_json() + + assert_artifact_response_contents_matches_expectations( + response=artifact_expected, + expected_contents={ + "uri": uri, + "description": description, + "user_id": user_id, + "group_id": group_id, + }, + ) + assert_retrieving_artifact_by_id_works( + client, artifact_id=artifact_expected["id"], expected=artifact_expected + ) + + +def test_artifacts_get_all( + client: FlaskClient, + db: SQLAlchemy, + auth_account: dict[str, Any], + registered_artifacts: dict[str, Any], +) -> None: + """Test that all artifacts can be retrieved.
+ + Given an authenticated user and registered artifacts, this test validates the following + sequence of actions: + + - A user registers artifacts with the following uris: + - "s3://bucket/model_v1.artifact" + - "s3://bucket/cnn.artifact" + - "s3://bucket/model.artifact" + - "s3://bucket/model_v2.artifact" + - The user is able to retrieve a list of all registered artifacts. + - The returned list of artifacts matches the full list of registered artifacts. + """ + artifacts_expected_list = list(registered_artifacts.values()) + assert_retrieving_artifacts_works(client, expected=artifacts_expected_list) + + +def test_artifact_search_query( + client: FlaskClient, + db: SQLAlchemy, + auth_account: dict[str, Any], + registered_artifacts: dict[str, Any], +) -> None: + """Test that artifacts can be queried with a search term. + + Given an authenticated user and registered artifacts, this test validates the following + sequence of actions: + + - The user is able to retrieve a list of all registered artifacts with 'artifact' in + their description. + - The returned list of artifacts matches the expected matches from the query. + """ + artifacts_expected_list = list(registered_artifacts.values())[:2] + assert_retrieving_artifacts_works( + client, + expected=artifacts_expected_list, + search="description:*artifact*", + ) + + +def test_artifact_group_query( + client: FlaskClient, + db: SQLAlchemy, + auth_account: dict[str, Any], + registered_artifacts: dict[str, Any], +) -> None: + """Test that artifacts can be retrieved using a group filter. + + Given an authenticated user and registered artifacts, this test validates the following + sequence of actions: + + - The user is able to retrieve a list of all registered artifacts that are owned by the + default group. + - The returned list of artifacts matches the expected list owned by the default group.
+ """ + artifacts_expected_list = list(registered_artifacts.values()) + assert_retrieving_artifacts_works( + client, + expected=artifacts_expected_list, + group_id=auth_account["groups"][0]["id"], + ) + + +def test_cannot_register_existing_artifact_uri( + client: FlaskClient, + db: SQLAlchemy, + auth_account: dict[str, Any], + registered_artifacts: dict[str, Any], +) -> None: + """Test that registering an artifact with an existing uri fails. + + Given an authenticated user and registered artifacts, this test validates the following + sequence of actions: + + - The user attempts to register a second artifact with the same uri. + - The request fails with an appropriate error message and response code. + """ + existing_artifact = registered_artifacts["artifact1"] + assert_registering_existing_artifact_uri_fails( + client, + uri=existing_artifact["uri"], + group_id=existing_artifact["group"]["id"], + job_id=0, # TODO: fill in once job stuff is done. + ) diff --git a/tests/unit/restapi/v1/test_model.py b/tests/unit/restapi/v1/test_model.py new file mode 100644 index 000000000..368f792cf --- /dev/null +++ b/tests/unit/restapi/v1/test_model.py @@ -0,0 +1,548 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST.
To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""Test suite for model operations. + +This module contains a set of tests that validate the CRUD operations and additional +functionalities for the model entity. The tests ensure that the models can be +registered, renamed, deleted, and locked/unlocked as expected through the REST API. +""" + +from typing import Any + +from flask.testing import FlaskClient +from flask_sqlalchemy import SQLAlchemy +from werkzeug.test import TestResponse + +from dioptra.restapi.routes import V1_MODELS_ROUTE, V1_ROOT + +from ..lib import actions, helpers + +# -- Actions --------------------------------------------------------------------------- + + +def modify_model( + client: FlaskClient, + model_id: int, + new_name: str, + new_description: str | None, +) -> TestResponse: + """Modify the name and description of a model using the API. + + Args: + client: The Flask test client. + model_id: The id of the model to modify. + new_name: The new name to assign to the model. + new_description: The new description to assign to the model. + + Returns: + The response from the API. + """ + payload = { + "name": new_name, + "description": new_description, + } + + return client.put( + f"/{V1_ROOT}/{V1_MODELS_ROUTE}/{model_id}", + json=payload, + follow_redirects=True, + ) + + +def delete_model_with_id( + client: FlaskClient, + model_id: int, +) -> TestResponse: + """Delete a model using the API. + + Args: + client: The Flask test client. + model_id: The id of the model to delete. + + Returns: + The response from the API.
+ """ + + return client.delete( + f"/{V1_ROOT}/{V1_MODELS_ROUTE}/{model_id}", + follow_redirects=True, + ) + + +# -- Assertions ------------------------------------------------------------------------ + + +def assert_model_response_contents_matches_expectations( + response: dict[str, Any], expected_contents: dict[str, Any] +) -> None: + """Assert that model response contents is valid. + + Args: + response: The actual response from the API. + expected_contents: The expected values, keyed by "name", "description", + "versions", "user_id", "group_id", and "latest_version". + + Raises: + AssertionError: If the API response does not match the expected contents or + if the response contents is not valid. + """ + expected_keys = { + "id", + "snapshot", + "group", + "user", + "versions", + "createdOn", + "lastModifiedOn", + "latestSnapshot", + "hasDraft", + "name", + "description", + "latestVersion", + "tags", + } + assert set(response.keys()) == expected_keys + + # Validate the non-Ref fields + assert isinstance(response["id"], int) + assert isinstance(response["snapshot"], int) + assert isinstance(response["name"], str) + assert isinstance(response["description"], str) + assert isinstance(response["createdOn"], str) + assert isinstance(response["lastModifiedOn"], str) + assert isinstance(response["latestSnapshot"], bool) + assert isinstance(response["hasDraft"], bool) + + assert response["name"] == expected_contents["name"] + assert response["description"] == expected_contents["description"] + assert response["versions"] == expected_contents["versions"] + + assert helpers.is_iso_format(response["createdOn"]) + assert helpers.is_iso_format(response["lastModifiedOn"]) + + # Validate the UserRef structure + assert isinstance(response["user"]["id"], int) + assert isinstance(response["user"]["username"], str) + assert isinstance(response["user"]["url"], str) + assert response["user"]["id"] == expected_contents["user_id"] + + # Validate the GroupRef structure + assert isinstance(response["group"]["id"], int)
+ assert isinstance(response["group"]["name"], str) + assert isinstance(response["group"]["url"], str) + assert response["group"]["id"] == expected_contents["group_id"] + + # Validate the latest version + assert response["latestVersion"] == expected_contents["latest_version"] + + # Validate the TagRef structure + for tag in response["tags"]: + assert isinstance(tag["id"], int) + assert isinstance(tag["name"], str) + assert isinstance(tag["url"], str) + + +def assert_model_version_response_contents_matches_expectations( + response: dict[str, Any], expected_contents: dict[str, Any] +) -> None: + """Assert that model version response contents is valid. + + Args: + response: The actual response from the API. + expected_contents: The expected values, keyed by "version_number", + "description", "model_id", and "artifact_id". + + Raises: + AssertionError: If the API response does not match the expected contents or + if the response contents is not valid. + """ + expected_keys = { + "model", + "artifact", + "versionNumber", + "description", + "createdOn", + } + assert set(response.keys()) == expected_keys + + # Validate the non-Ref fields + assert isinstance(response["description"], str) + assert isinstance(response["createdOn"], str) + assert isinstance(response["versionNumber"], int) + + assert response["versionNumber"] == expected_contents["version_number"] + assert response["description"] == expected_contents["description"] + assert response["model"]["id"] == expected_contents["model_id"] + assert response["artifact"]["id"] == expected_contents["artifact_id"] + + assert helpers.is_iso_format(response["createdOn"]) + + +def assert_retrieving_model_by_id_works( + client: FlaskClient, + model_id: int, + expected: dict[str, Any], +) -> None: + """Assert that retrieving a model by id works. + + Args: + client: The Flask test client. + model_id: The id of the model to retrieve. + expected: The expected response from the API.
+ + Raises: + AssertionError: If the response status code is not 200 or if the API response + does not match the expected response. + """ + response = client.get( + f"/{V1_ROOT}/{V1_MODELS_ROUTE}/{model_id}", follow_redirects=True + ) + assert response.status_code == 200 and response.get_json() == expected + + +def assert_retrieving_models_works( + client: FlaskClient, + expected: list[dict[str, Any]], + group_id: int | None = None, + search: str | None = None, + paging_info: dict[str, Any] | None = None, +) -> None: + """Assert that retrieving all models works. + + Args: + client: The Flask test client. + expected: The expected contents of the "data" field in the API response. + group_id: The group ID used in query parameters. + search: The search string used in query parameters. + paging_info: The paging information used in query parameters. Must contain + the keys "index" and "page_length" when provided. + + Raises: + AssertionError: If the response status code is not 200 or if the API response + does not match the expected response. + """ + + query_string: dict[str, Any] = {} + + if group_id is not None: + query_string["groupId"] = group_id + + if search is not None: + query_string["search"] = search + + if paging_info is not None: + query_string["index"] = paging_info["index"] + query_string["pageLength"] = paging_info["page_length"] + + response = client.get( + f"/{V1_ROOT}/{V1_MODELS_ROUTE}", + query_string=query_string, + follow_redirects=True, + ) + assert response.status_code == 200 and response.get_json()["data"] == expected + + +def assert_registering_existing_model_name_fails( + client: FlaskClient, name: str, group_id: int +) -> None: + """Assert that registering a model with an existing name fails. + + Args: + client: The Flask test client. + name: The name to assign to the new model. + group_id: The group ID sent with the registration request. + + Raises: + AssertionError: If the response status code is not 400.
+ """ + response = actions.register_model( + client, name=name, description="", group_id=group_id + ) + assert response.status_code == 400 + + +def assert_model_name_matches_expected_name( + client: FlaskClient, model_id: int, expected_name: str +) -> None: + """Assert that the name of a model matches the expected name. + + Args: + client: The Flask test client. + model_id: The id of the model to retrieve. + expected_name: The expected name of the model. + + Raises: + AssertionError: If the response status code is not 200 or if the name of the + model does not match the expected name. + """ + response = client.get( + f"/{V1_ROOT}/{V1_MODELS_ROUTE}/{model_id}", + follow_redirects=True, + ) + assert response.status_code == 200 and response.get_json()["name"] == expected_name + + +def assert_model_is_not_found( + client: FlaskClient, + model_id: int, +) -> None: + """Assert that a model is not found. + + Args: + client: The Flask test client. + model_id: The id of the model to retrieve. + + Raises: + AssertionError: If the response status code is not 404. + """ + response = client.get( + f"/{V1_ROOT}/{V1_MODELS_ROUTE}/{model_id}", + follow_redirects=True, + ) + assert response.status_code == 404 + + +def assert_cannot_rename_model_with_existing_name( + client: FlaskClient, + model_id: int, + existing_name: str, + existing_description: str, +) -> None: + """Assert that renaming a model with an existing name fails. + + Args: + client: The Flask test client. + model_id: The id of the model to rename. + existing_name: The name of an existing model. + existing_description: The description sent along with the rename request. + + Raises: + AssertionError: If the response status code is not 400.
+ """ + response = modify_model( + client=client, + model_id=model_id, + new_name=existing_name, + new_description=existing_description, + ) + assert response.status_code == 400 + + +# -- Model Tests --------------------------------------------------------------- + + +def test_create_model( + client: FlaskClient, + db: SQLAlchemy, + auth_account: dict[str, Any], +) -> None: + """Test that models can be correctly registered and retrieved using the API. + + Given an authenticated user, this test validates the following sequence of actions: + + - The user registers a model named "my_model". + - The response is valid and matches the expected values given the registration + request. + - The user is able to retrieve information about the model using the model id. + """ + name = "my_model" + description = "The first model." + user_id = auth_account["id"] + group_id = auth_account["groups"][0]["id"] + model_response = actions.register_model( + client, name=name, group_id=group_id, description=description + ) + + model_expected = model_response.get_json() + assert_model_response_contents_matches_expectations( + response=model_expected, + expected_contents={ + "name": name, + "description": description, + "user_id": user_id, + "group_id": group_id, + "versions": [], + "latest_version": None, + }, + ) + + assert_retrieving_model_by_id_works( + client, model_id=model_expected["id"], expected=model_expected + ) + + +def test_model_get_all( + client: FlaskClient, + db: SQLAlchemy, + auth_account: dict[str, Any], + registered_models: dict[str, Any], +) -> None: + """Test that all models can be retrieved. + + Given an authenticated user and registered models, this test validates the following + sequence of actions: + + - A user registers three models, "tensorflow_cpu", "tensorflow_gpu", "pytorch_cpu". + - The user is able to retrieve a list of all registered models. + - The returned list of models matches the full list of registered models. 
+ """ + model_expected_list = list(registered_models.values()) + assert_retrieving_models_works(client, expected=model_expected_list) + + +def test_model_search_query( + client: FlaskClient, + db: SQLAlchemy, + auth_account: dict[str, Any], + registered_models: dict[str, Any], +) -> None: + """Test that models can be queried with a search term.
+ + Given an authenticated user and registered models, this test validates the following + sequence of actions: + + - The user is able to retrieve a list of all registered models with various queries. + - The returned list of models matches the expected matches from the query. + """ + model_expected_list = list(registered_models.values())[:2] + assert_retrieving_models_works( + client, expected=model_expected_list, search="description:*model*" + ) + model_expected_list = list(registered_models.values())[:1] + assert_retrieving_models_works( + client, expected=model_expected_list, search="*model*, name:*tensorflow*" + ) + model_expected_list = list(registered_models.values()) + assert_retrieving_models_works(client, expected=model_expected_list, search="*") + + +def test_model_group_query( + client: FlaskClient, + db: SQLAlchemy, + auth_account: dict[str, Any], + registered_models: dict[str, Any], +) -> None: + """Test that models can be retrieved using a group filter. + + Given an authenticated user and registered models, this test validates the following + sequence of actions: + + - The user is able to retrieve a list of all registered models that are owned by the + default group. + - The returned list of models matches the expected list owned by the default group. + """ + model_expected_list = list(registered_models.values()) + assert_retrieving_models_works( + client, + expected=model_expected_list, + group_id=auth_account["groups"][0]["id"], + ) + + +def test_cannot_register_existing_model_name( + client: FlaskClient, + db: SQLAlchemy, + auth_account: dict[str, Any], + registered_models: dict[str, Any], +) -> None: + """Test that registering a model with an existing name fails. + + Given an authenticated user and registered models, this test validates the following + sequence of actions: + + - The user attempts to register a second model with the same name. + - The request fails with an appropriate error message and response code.
+ """ + existing_model = registered_models["model1"] + + assert_registering_existing_model_name_fails( + client, + name=existing_model["name"], + group_id=existing_model["group"]["id"], + ) + + +def test_rename_model( + client: FlaskClient, + db: SQLAlchemy, + auth_account: dict[str, Any], + registered_models: dict[str, Any], +) -> None: + """Test that a model can be renamed. + + Given an authenticated user and registered models, this test validates the following + sequence of actions: + + - The user issues a request to change the name of a model. + - The user retrieves information about the same model and it reflects the name + change. + - The user issues a request to change the name of the model to the existing name. + - The user retrieves information about the same model and verifies the name remains + unchanged. + - The user issues a request to change the name of a model to an existing model's + name. + - The request fails with an appropriate error message and response code. + """ + updated_model_name = "tensorflow_tpu" + model_to_rename = registered_models["model1"] + existing_model = registered_models["model2"] + + # Rename the model and confirm the new name is returned when it is retrieved. + modified_model = modify_model( + client, + model_id=model_to_rename["id"], + new_name=updated_model_name, + new_description=model_to_rename["description"], + ).get_json() + assert_model_name_matches_expected_name( + client, model_id=model_to_rename["id"], expected_name=updated_model_name + ) + # The full listing should contain the modified model in place of the original. + model_expected_list = [ + modified_model, + registered_models["model2"], + registered_models["model3"], + ] + assert_retrieving_models_works(client, expected=model_expected_list) + + # Renaming the model to its current name should leave the name unchanged. 
+ modified_model = modify_model( + client, + model_id=model_to_rename["id"], + new_name=updated_model_name, + new_description=model_to_rename["description"], + ).get_json() + assert_model_name_matches_expected_name( + client, model_id=model_to_rename["id"], expected_name=updated_model_name + ) + + # Renaming the model to another model's name must be rejected. + assert_cannot_rename_model_with_existing_name( + client, +
model_id=model_to_rename["id"], + existing_name=existing_model["name"], + existing_description=model_to_rename["description"], + ) + + +def test_delete_model_by_id( + client: FlaskClient, + db: SQLAlchemy, + auth_account: dict[str, Any], + registered_models: dict[str, Any], +) -> None: + """Test that a model can be deleted by referencing its id. + + Given an authenticated user and registered models, this test validates the following + sequence of actions: + + - The user deletes a model by referencing its id. + - The user attempts to retrieve information about the deleted model. + - The request fails with an appropriate error message and response code. + """ + model_to_delete = registered_models["model1"] + + # NOTE(review): the response of the delete request itself is not checked here; + # only the follow-up lookup is asserted to return 404. + delete_model_with_id(client, model_id=model_to_delete["id"]) + assert_model_is_not_found(client, model_id=model_to_delete["id"]) + + +# -- Tests Model Versions ----------------------------------------------------------------