Skip to content

Commit

Permalink
Merge pull request #131 from SciCatProject/fix-default-field-values
Browse files Browse the repository at this point in the history
Dont init wrong fields in download
  • Loading branch information
jl-wynen authored Aug 16, 2023
2 parents 1cd7b27 + bd5b819 commit 7bcff94
Show file tree
Hide file tree
Showing 11 changed files with 483 additions and 36 deletions.
5 changes: 3 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@

intersphinx_mapping = {
"fabric": ("https://docs.fabfile.org/en/latest", None),
"hypothesis": ("https://hypothesis.readthedocs.io/en/latest/", None),
"python": ("https://docs.python.org/3", None),
"hypothesis": ("https://hypothesis.readthedocs.io/en/latest", None),
"paramiko": ("https://docs.paramiko.org/en/latest", None),
"pydantic": ("https://docs.pydantic.dev/latest", None),
"python": ("https://docs.python.org/3", None),
}

# autodocs includes everything, even irrelevant API internals. autosummary
Expand Down
2 changes: 2 additions & 0 deletions docs/release-notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ Breaking changes
Bugfixes
~~~~~~~~

* Fields of derived datasets are no longer initialized when downloading raw datasets and vice versa.

Documentation
~~~~~~~~~~~~~

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ warn_unreachable = true
[[tool.mypy.overrides]]
module = "tests.*"
disallow_untyped_defs = false
disallow_untyped_calls = false

[tool.ruff]
line-length = 88
Expand Down
44 changes: 43 additions & 1 deletion src/scitacean/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,31 @@ def _delete_ignored_args(self, args: Dict[str, Any]) -> None:
for key in self._masked_fields:
args.pop(key, None)

@classmethod
def user_model_type(cls) -> Optional[type]:
"""Return the user model type for this model.
Returns ``None`` if there is no user model, e.g., for ``Dataset``
where there is a custom class instead of a plain model.
"""
return None

@classmethod
def upload_model_type(cls) -> Optional[type]:
"""Return the upload model type for this model.
Returns ``None`` if the model cannot be uploaded or this is an upload model.
"""
return None

@classmethod
def download_model_type(cls) -> Optional[type]:
"""Return the download model type for this model.
Returns ``None`` if this is a download model.
"""
return None

if is_pydantic_v1():

@classmethod
Expand Down Expand Up @@ -143,7 +168,23 @@ def _upload_model_dict(self) -> Dict[str, Any]:

@classmethod
def from_download_model(cls, download_model: Any) -> BaseUserModel:
raise NotImplementedError()
raise NotImplementedError("Function does not exist for BaseUserModel")

@classmethod
def upload_model_type(cls) -> Optional[type]:
"""Return the upload model type for this user model.
Returns ``None`` if the model cannot be uploaded.
"""
return None

@classmethod
def download_model_type(cls) -> type:
"""Return the download model type for this user model."""
# There is no sensible default value here as there always exists a download
# model.
# All child classes must implement this function.
raise NotImplementedError("Function does not exist for BaseUserModel")


def construct(
Expand All @@ -159,6 +200,7 @@ def construct(
-------
If the model is created without validation, no fields will be converted
to their proper type but will simply be whatever arguments are passed.
See ``model_construct`` or :class:`pydantic.BaseModel` for more information.
A warning will be emitted in this case.
Expand Down
48 changes: 42 additions & 6 deletions src/scitacean/_dataset_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from __future__ import annotations

from datetime import datetime, timezone
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union

import dateutil.parser

Expand All @@ -21,7 +21,11 @@
from .datablock import OrigDatablock
from .filesystem import RemotePath
from .model import (
construct,
BaseModel,
BaseUserModel,
DownloadDataset,
DownloadLifecycle,
Lifecycle,
Relationship,
Technique,
Expand Down Expand Up @@ -1062,12 +1066,44 @@ def _prepare_fields_from_download(
init_args[field.name] = getattr(download_model, field.scicat_name)

init_args["meta"] = download_model.scientificMetadata
DatasetBase._convert_readonly_fields_in_place(read_only)
_convert_download_fields_in_place(init_args, read_only)

return init_args, read_only

@staticmethod
def _convert_readonly_fields_in_place(read_only: Dict[str, Any]) -> Dict[str, Any]:
if "_pid" in read_only:
read_only["_pid"] = _parse_pid(read_only["_pid"])
return read_only
def _convert_readonly_fields_in_place(read_only: Dict[str, Any]) -> None:
if (pid := read_only.get("_pid")) is not None:
read_only["_pid"] = _parse_pid(pid)


def _convert_download_fields_in_place(
init_args: Dict[str, Any], read_only: Dict[str, Any]
) -> None:
for mod, key in ((Technique, "techniques"), (Relationship, "relationships")):
init_args[key] = _list_field_from_download(mod, init_args.get(key))

DatasetBase._convert_readonly_fields_in_place(read_only)
if (lifecycle := read_only.get("_lifecycle")) is not None:
read_only["_lifecycle"] = Lifecycle.from_download_model(
_as_model(DownloadLifecycle, lifecycle)
)


def _list_field_from_download(
mod: Type[BaseUserModel], value: Optional[List[Any]]
) -> Optional[List[BaseUserModel]]:
if value is None:
return None
return [
mod.from_download_model(_as_model(mod.download_model_type(), item))
for item in value
]


# If validation fails, sub models are not converted automatically by Pydantic.
def _as_model(
mod: Type[BaseModel], value: Union[BaseModel, Dict[str, Any]]
) -> BaseModel:
if isinstance(value, dict):
return construct(mod, **value, _strict_validation=False)
return value
13 changes: 0 additions & 13 deletions src/scitacean/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,9 @@
from .datablock import OrigDatablock
from .file import File
from .model import (
BaseUserModel,
DatasetType,
DownloadDataset,
DownloadOrigDatablock,
Relationship,
Technique,
UploadDerivedDataset,
UploadOrigDatablock,
UploadRawDataset,
Expand Down Expand Up @@ -60,8 +57,6 @@ def from_download_models(
A new Dataset instance.
"""
init_args, read_only = DatasetBase._prepare_fields_from_download(dataset_model)
for mod, key in ((Technique, "techniques"), (Relationship, "relationships")):
init_args[key] = _list_field_from_download(mod, init_args[key])
dset = cls(**init_args)
for key, val in read_only.items():
setattr(dset, key, val)
Expand Down Expand Up @@ -459,11 +454,3 @@ def _list_field_for_upload(value: Optional[List[Any]]) -> Optional[List[Any]]:
if value is None:
return None
return [item.make_upload_model() for item in value]


def _list_field_from_download(
mod: Type[BaseUserModel], value: Optional[List[Any]]
) -> Optional[List[Any]]:
if value is None:
return None
return [mod.from_download_model(item) for item in value]
Loading

0 comments on commit 7bcff94

Please sign in to comment.