diff --git a/src/kili/adapters/kili_api_gateway/project/operations.py b/src/kili/adapters/kili_api_gateway/project/operations.py index 22df2f22a..73e715573 100644 --- a/src/kili/adapters/kili_api_gateway/project/operations.py +++ b/src/kili/adapters/kili_api_gateway/project/operations.py @@ -36,3 +36,9 @@ def get_update_properties_in_project_mutation(fragment: str) -> str: } } """ + +GQL_COPY_PROJECT = """ +mutation CopyProject($data: CopyProjectData!) { + data: copyProject(data: $data) +} +""" diff --git a/src/kili/adapters/kili_api_gateway/project/operations_mixin.py b/src/kili/adapters/kili_api_gateway/project/operations_mixin.py index 7abf76b4f..af48555de 100644 --- a/src/kili/adapters/kili_api_gateway/project/operations_mixin.py +++ b/src/kili/adapters/kili_api_gateway/project/operations_mixin.py @@ -21,11 +21,12 @@ from .common import get_project from .mappers import project_data_mapper, project_where_mapper from .operations import ( + GQL_COPY_PROJECT, GQL_COUNT_PROJECTS, GQL_CREATE_PROJECT, get_update_properties_in_project_mutation, ) -from .types import ProjectDataKiliAPIGatewayInput +from .types import CopyProjectInput, ProjectDataKiliAPIGatewayInput class ProjectOperationMixin(BaseOperationMixin): @@ -104,3 +105,20 @@ def update_properties_in_project( variables = {"data": data, "where": {"id": project_id}} result = self.graphql_client.execute(mutation, variables) return load_project_json_fields(result["data"], fields) + + def copy_project( + self, + project_id: ProjectId, + project_data: CopyProjectInput, + ) -> ProjectId: + """Copy a project.""" + variables = { + "data": { + "projectId": project_id, + "shouldCopyAssets": project_data.should_copy_assets, + "shouldCopyUsers": project_data.should_copy_members, + } + } + + result = self.graphql_client.execute(GQL_COPY_PROJECT, variables) + return ProjectId(result.get("data", "")) diff --git a/src/kili/adapters/kili_api_gateway/project/types.py b/src/kili/adapters/kili_api_gateway/project/types.py index a5ead2a82..83e870f46 100644 --- a/src/kili/adapters/kili_api_gateway/project/types.py +++ b/src/kili/adapters/kili_api_gateway/project/types.py @@ -36,3 +36,11 @@ class ProjectDataKiliAPIGatewayInput: should_relaunch_kpi_computation: Optional[bool] title: Optional[str] use_honeypot: Optional[bool] + + +@dataclass +class CopyProjectInput: + """Copy project input data for Kili API Gateway.""" + + should_copy_members: Optional[bool] + should_copy_assets: Optional[bool] diff --git a/src/kili/entrypoints/mutations/project/__init__.py b/src/kili/entrypoints/mutations/project/__init__.py index 798980082..db5c98deb 100644 --- a/src/kili/entrypoints/mutations/project/__init__.py +++ b/src/kili/entrypoints/mutations/project/__init__.py @@ -183,8 +183,8 @@ def copy_project( # pylint: disable=too-many-arguments title if `None` is provided. description: Description for the new project. Defaults to empty string if `None` is provided. - copy_json_interface: Include json interface in the copy. - copy_quality_settings: Include quality settings in the copy. + copy_json_interface: Deprecated. Always include json interface in the copy. + copy_quality_settings: Deprecated. Always include quality settings in the copy. copy_members: Include members in the copy. copy_assets: Include assets in the copy. copy_labels: Include labels in the copy. @@ -196,12 +196,15 @@ def copy_project( # pylint: disable=too-many-arguments Examples: >>> kili.copy_project(from_project_id="clbqn56b331234567890l41c0") """ + if (not copy_json_interface) or (not copy_quality_settings): + raise ValueError( + "The 'copy_json_interface' and 'copy_quality_settings' arguments are deprecated." + ) + return ProjectCopier(self).copy_project( # pyright: ignore[reportGeneralTypeIssues] from_project_id, title, description, - copy_json_interface, - copy_quality_settings, copy_members, copy_assets, copy_labels, diff --git a/src/kili/services/copy_project/__init__.py b/src/kili/services/copy_project/__init__.py index 212ee9cf8..c831c167a 100644 --- a/src/kili/services/copy_project/__init__.py +++ b/src/kili/services/copy_project/__init__.py @@ -1,21 +1,14 @@ """Copy project implementation.""" import itertools -import json import logging -from pathlib import Path -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING, Optional from kili.adapters.kili_api_gateway.helpers.queries import QueryOptions -from kili.core.constants import QUERY_BATCH_SIZE -from kili.core.utils.pagination import batcher +from kili.adapters.kili_api_gateway.project.types import CopyProjectInput from kili.domain.asset import AssetFilters from kili.domain.label import LabelFilters from kili.domain.project import ProjectId -from kili.domain.types import ListOrTuple -from kili.use_cases.asset.media_downloader import get_download_assets_function -from kili.utils.tempfile import TemporaryDirectory -from kili.utils.tqdm import tqdm if TYPE_CHECKING: from kili.client import Kili @@ -26,32 +19,20 @@ class ProjectCopier: # pylint: disable=too-few-public-methods FIELDS_PROJECT = ( "title", - "inputType", "description", "id", "dataConnections.id", ) - FIELDS_JSON_INTERFACE = ("jsonInterface",) - FIELDS_QUALITY_SETTINGS = ( - "canSkipAsset", - "consensusTotCoverage", - "minConsensusSize", - "reviewCoverage", - "secondsToLabelBeforeAutoAssign", - "useHoneyPot", - ) def __init__(self, kili: "Kili") -> None: self.disable_tqdm = False self.kili = kili - def copy_project( # pylint: disable=too-many-arguments,too-many-locals + def copy_project( # pylint: disable=too-many-arguments self, from_project_id: str, title: Optional[str], description: Optional[str], - copy_json_interface: bool, - copy_quality_settings: bool, copy_members: bool, copy_assets: bool, copy_labels: bool, @@ -64,17 +45,7 @@ def copy_project( # pylint: disable=too-many-arguments,too-many-locals logger = logging.getLogger("kili.services.copy_project") logger.setLevel(logging.INFO) - if not any( - (copy_json_interface, copy_quality_settings, copy_members, copy_assets, copy_labels) - ): - raise ValueError("At least one element has to be copied.") - if copy_labels: - if not copy_json_interface: - raise ValueError( - "`copy_json_interface` must be set to `True` for copying the source project" - " labels." - ) if not copy_assets: raise ValueError( "`copy_assets` must be set to `True` for copying the source project labels." @@ -85,41 +56,31 @@ def copy_project( # pylint: disable=too-many-arguments,too-many-locals ) fields = self.FIELDS_PROJECT - if copy_json_interface: - fields += self.FIELDS_JSON_INTERFACE - if copy_quality_settings: - fields += self.FIELDS_QUALITY_SETTINGS src_project = self.kili.kili_api_gateway.get_project(ProjectId(from_project_id), fields) if src_project["dataConnections"] and copy_assets: raise NotImplementedError("Copying projects with cloud storage is not supported.") - new_project_title = title or self._generate_project_title(src_title=src_project["title"]) - - new_project_description = description or "" + logger.info("Copying new project...") - json_interface = src_project["jsonInterface"] if copy_json_interface else {"jobs": {}} - - new_project_id = self.kili.create_project( - input_type=src_project["inputType"], - json_interface=json_interface, - title=new_project_title, - description=new_project_description, - )["id"] - logger.info(f"Creating new project with id: '{new_project_id}'") + new_project_id = self.kili.kili_api_gateway.copy_project( + ProjectId(from_project_id), + CopyProjectInput( + should_copy_members=copy_members, + should_copy_assets=copy_assets, + ), + ) - if copy_members: - logger.info("Copying members...") - self._copy_members(from_project_id, new_project_id) + logger.info(f"Created new project {new_project_id}") - if copy_quality_settings: - logger.info("Copying quality settings...") - self._copy_quality_settings(new_project_id, src_project) + self.kili.update_properties_in_project( + project_id=new_project_id, + title=title or self._generate_project_title(src_project["title"]), + description=description, + ) - if copy_assets: - logger.info("Copying assets...") - self._copy_assets(from_project_id, new_project_id) + logger.info("Updated title/description") if copy_labels: logger.info("Copying labels...") @@ -139,118 +100,6 @@ def _generate_project_title(self, src_title: str) -> str: i += 1 return new_title - def _copy_members(self, from_project_id: str, new_project_id: str) -> None: - members = self.kili.project_users( - project_id=from_project_id, - fields=["activated", "role", "user.email", "status", "id"], - disable_tqdm=True, - ) - - members = [memb for memb in members if memb["status"] == "ACTIVATED" and memb["activated"]] - - for member in tqdm(members, disable=self.disable_tqdm): - self.kili.append_to_roles( - project_id=new_project_id, - user_email=member["user"]["email"], - role=member["role"], - ) - - def _copy_quality_settings(self, new_project_id: str, src_project: Dict) -> None: - self.kili.update_properties_in_project( - project_id=new_project_id, - can_skip_asset=src_project["canSkipAsset"], - consensus_tot_coverage=src_project["consensusTotCoverage"], - min_consensus_size=src_project["minConsensusSize"], - use_honeypot=src_project["useHoneyPot"], - review_coverage=src_project["reviewCoverage"], - seconds_to_label_before_auto_assign=src_project["secondsToLabelBeforeAutoAssign"], - ) - - def _copy_assets(self, from_project_id: str, new_project_id: str) -> None: - """Copy assets from a project to another. - - Fetches assets by batch since `content` urls expire. - """ - filters = AssetFilters(project_id=ProjectId(from_project_id)) - options = QueryOptions(disable_tqdm=False) - fields = ( - "content", - "ocrMetadata", - "externalId", - "isHoneypot", - "jsonContent", - "jsonMetadata", - ) - - assets_gen = self.kili.kili_api_gateway.list_assets(filters, fields, options) - - with TemporaryDirectory() as tmp_dir: - # TODO: modify download_media function so it can take a generator of assets - for assets_batch in batcher(assets_gen, QUERY_BATCH_SIZE): - downloaded_assets = self._download_assets( - from_project_id, fields, tmp_dir, assets_batch - ) - self._upload_assets(new_project_id, downloaded_assets) - - def _download_assets( - self, from_project_id: str, fields: ListOrTuple[str], tmp_dir: Path, assets: List[Dict] - ) -> List[Dict]: - download_function, _ = get_download_assets_function( - self.kili.kili_api_gateway, - download_media=True, - fields=fields, - project_id=ProjectId(from_project_id), - local_media_dir=str(tmp_dir.resolve()), - ) - assert download_function - return download_function(assets) - - def _upload_assets(self, new_project_id: str, assets: List[Dict]) -> List[Dict]: - # ocrMetadata field of assets need to be merged with jsonMetadata field - for asset in assets: - if isinstance(asset["jsonMetadata"], str): - try: - asset["jsonMetadata"] = json.loads(asset["jsonMetadata"]) - except json.JSONDecodeError: - asset["jsonMetadata"] = {} - if asset["ocrMetadata"]: - asset["jsonMetadata"] = {**asset["jsonMetadata"], **asset["ocrMetadata"]} - - # we cannot send None values in the content_array or json_content_array fields of - # kili.append_many_to_assets. So we need to sort and group the assets by the presence of - # content and jsonContent. - assets = sorted( - assets, - key=lambda asset: (bool(asset["content"]), bool(asset["jsonContent"])), - ) - assets_iterator = itertools.groupby( - assets, - key=lambda asset: (bool(asset["content"]), bool(asset["jsonContent"])), - ) - - for key, group in assets_iterator: - has_content, has_jsoncontent = key - group = list(group) - - content_array = [asset["content"] for asset in group] if has_content else None - external_id_array = [asset["externalId"] for asset in group] - is_honeypot_array = [asset["isHoneypot"] for asset in group] - json_content_array = ( - [asset["jsonContent"] for asset in group] if has_jsoncontent else None - ) - json_metadata_array = [asset["jsonMetadata"] for asset in group] - - self.kili.append_many_to_dataset( - project_id=new_project_id, - content_array=content_array, - external_id_array=external_id_array, - is_honeypot_array=is_honeypot_array, - json_content_array=json_content_array, - json_metadata_array=json_metadata_array, - disable_tqdm=True, - ) - return assets - # pylint: disable=too-many-locals def _copy_labels(self, from_project_id: str, new_project_id: str) -> None: assets_new_project = self.kili.kili_api_gateway.list_assets(