Merge pull request #125 from SciCatProject/fix-type-hints
Fix type hints
jl-wynen authored Aug 10, 2023
2 parents df63497 + 113c753 commit c7d6272
Showing 41 changed files with 372 additions and 179 deletions.
6 changes: 5 additions & 1 deletion pyproject.toml
@@ -67,7 +67,7 @@ filterwarnings = [
[tool.mypy]
plugins = "pydantic.mypy"
mypy_path = "src"
exclude = ["venv"]
exclude = ["docs/conf.py", "tools/model-generation", "venv"]
ignore_missing_imports = true
enable_error_code = [
"ignore-without-code",
@@ -78,6 +78,10 @@ strict = true
show_error_codes = true
warn_unreachable = true

[[tool.mypy.overrides]]
module = "tests.*"
disallow_untyped_defs = false

[tool.ruff]
line-length = 88
select = ["B", "D", "E", "F", "I", "S", "T20", "PGH", "FBT003", "RUF100"]
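Note on the new [[tool.mypy.overrides]] table: with strict = true at the top level, every function needs full annotations, and the override relaxes only disallow_untyped_defs for modules matching tests.*. A minimal sketch of the effect (file paths are hypothetical, not part of the diff):

# src/scitacean/example.py -- checked under strict = true,
# so parameters and the return type must be annotated.
def add(a: int, b: int) -> int:
    return a + b

# tests/test_example.py -- matches the "tests.*" override, so an
# unannotated test function no longer triggers an untyped-def error.
def test_add():
    assert add(1, 2) == 3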
61 changes: 37 additions & 24 deletions src/scitacean/_base_model.py
@@ -2,23 +2,11 @@
# Copyright (c) 2023 SciCat Project (https://github.com/SciCatProject/scitacean)

"""Types and functions to implement models for communication with SciCat."""

try:
# Python 3.11+
from enum import StrEnum as _StrEnum

_DatasetTypeBases = (_StrEnum,)
except ImportError:
from enum import Enum as _Enum

_DatasetTypeBases = (
str,
_Enum,
)
from __future__ import annotations

import dataclasses
from datetime import datetime
from typing import Any, Dict, Iterable, Optional, Type, TypeVar, Union
from typing import Any, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union

import pydantic
from dateutil.parser import parse as parse_datetime
@@ -31,12 +19,28 @@

ModelType = TypeVar("ModelType", bound=pydantic.BaseModel)

try:
# Python 3.11+
from enum import StrEnum

class DatasetType(StrEnum):
"""Type of Dataset."""

RAW = "raw"
DERIVED = "derived"

del StrEnum

except ImportError:
from enum import Enum

class DatasetType(*_DatasetTypeBases):
"""Type of Dataset."""
class DatasetType(str, Enum): # type: ignore[no-redef]
"""Type of Dataset."""

RAW = "raw"
DERIVED = "derived"
RAW = "raw"
DERIVED = "derived"

del Enum


class BaseModel(pydantic.BaseModel):
@@ -56,6 +60,8 @@ class Config:
extra="forbid",
)

_masked_fields: Tuple[str, ...]

# Some schemas contain fields that we don't want to use in Scitacean.
# Normally, omitting them from the model would result in an error when
# building a model from the JSON returned by SciCat.
@@ -83,21 +89,23 @@ def _delete_ignored_args(self, args: Dict[str, Any]) -> None:
if is_pydantic_v1():

@classmethod
def get_model_fields(cls) -> Dict[str, pydantic.fields.ModelField]:
return cls.__fields__
def get_model_fields(cls) -> Dict[str, Any]:
return cls.__fields__ # type: ignore[return-value]

def model_dump(self, *args, **kwargs) -> Dict[str, Any]:
def model_dump(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
return self.dict(*args, **kwargs)

def model_dump_json(self, *args, **kwargs) -> str:
def model_dump_json(self, *args: Any, **kwargs: Any) -> str:
return self.json(*args, **kwargs)

@classmethod
def model_construct(cls: Type[ModelType], *args, **kwargs) -> ModelType:
def model_construct(
cls: Type[ModelType], *args: Any, **kwargs: Any
) -> ModelType:
return cls.construct(*args, **kwargs)

@classmethod
def model_rebuild(cls, *args, **kwargs) -> Optional[bool]:
def model_rebuild(cls, *args: Any, **kwargs: Any) -> Optional[bool]:
return cls.update_forward_refs(*args, **kwargs)

else:
@@ -107,6 +115,7 @@ def get_model_fields(cls) -> Dict[str, pydantic.fields.FieldInfo]:
return cls.model_fields


@dataclasses.dataclass
class BaseUserModel:
"""Base class for user models.
@@ -132,6 +141,10 @@ def _upload_model_dict(self) -> Dict[str, Any]:
if not field.name.startswith("_")
}

@classmethod
def from_download_model(cls, download_model: Any) -> BaseUserModel:
raise NotImplementedError()


def construct(
model: Type[ModelType],
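The rewritten DatasetType definition above replaces the old tuple-of-bases trick with two explicit class statements, so type checkers see a real class in both branches. A standalone sketch of the same pattern, using a sys.version_info check instead of try/except purely for illustration:

import sys
from enum import Enum

if sys.version_info >= (3, 11):
    from enum import StrEnum

    class DatasetType(StrEnum):
        """Type of Dataset."""

        RAW = "raw"
        DERIVED = "derived"

else:

    class DatasetType(str, Enum):
        """Type of Dataset."""

        RAW = "raw"
        DERIVED = "derived"

# Both variants compare equal to plain strings and accept them as input:
assert DatasetType.RAW == "raw"
assert DatasetType("derived") is DatasetType.DERIVED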
47 changes: 23 additions & 24 deletions src/scitacean/_dataset_fields.py
@@ -16,14 +16,13 @@

import dateutil.parser

from ._base_model import DatasetType
from ._internal.dataclass_wrapper import dataclass_optional_args
from .datablock import OrigDatablock
from .filesystem import RemotePath
from .model import (
DatasetType,
DownloadDataset,
Lifecycle,
History,
Relationship,
Technique,
)
@@ -80,6 +79,16 @@ def used_by(self, dataset_type: DatasetType) -> bool:
)

_FIELD_SPEC = [
Field(
name="type",
description="Characterize type of dataset, either 'raw' or 'derived'. Autofilled when choosing the proper inherited models.",
read_only=False,
required=True,
scicat_name="type",
type=DatasetType,
used_by_derived=True,
used_by_raw=True,
),
Field(
name="access_groups",
description="Optional additional groups which have read access to the data. Users which are members in one of the groups listed here are allowed to access this data. The special group 'public' makes data available to all users.",
@@ -216,7 +225,7 @@ def used_by(self, dataset_type: DatasetType) -> bool:
read_only=True,
required=False,
scicat_name="history",
type=None,
type=type(None),
used_by_derived=True,
used_by_raw=True,
),
@@ -460,16 +469,6 @@ def used_by(self, dataset_type: DatasetType) -> bool:
used_by_derived=True,
used_by_raw=True,
),
Field(
name="type",
description="Characterize type of dataset, either 'raw' or 'derived'. Autofilled when choosing the proper inherited models.",
read_only=False,
required=True,
scicat_name="type",
type=DatasetType,
used_by_derived=True,
used_by_raw=True,
),
Field(
name="updated_at",
description="Date and time when this record was updated last. This property is added and maintained by mongoose.",
@@ -551,12 +550,12 @@ def used_by(self, dataset_type: DatasetType) -> bool:
"_source_folder",
"_source_folder_host",
"_techniques",
"_type",
"_updated_at",
"_updated_by",
"_used_software",
"_validation_status",
"_meta",
"_type",
"_default_checksum_algorithm",
"_orig_datablocks",
)
@@ -1000,16 +999,6 @@ def techniques(self, techniques: Optional[List[Technique]]) -> None:
"""Stores the metadata information for techniques."""
self._techniques = techniques

@property
def type(self) -> Optional[DatasetType]:
"""Characterize type of dataset, either 'raw' or 'derived'. Autofilled when choosing the proper inherited models."""
return self._type

@type.setter
def type(self, type: Optional[DatasetType]) -> None:
"""Characterize type of dataset, either 'raw' or 'derived'. Autofilled when choosing the proper inherited models."""
self._type = type

@property
def updated_at(self) -> Optional[datetime]:
"""Date and time when this record was updated last. This property is added and maintained by mongoose."""
@@ -1050,6 +1039,16 @@ def meta(self, meta: Dict[str, Any]) -> None:
"""Dict of scientific metadata."""
self._meta = meta

@property
def type(self) -> DatasetType:
"""Characterize type of dataset, either 'raw' or 'derived'. Autofilled when choosing the proper inherited models."""
return self._type

@type.setter
def type(self, type: Union[DatasetType, Literal["raw", "derived"]]) -> None:
"""Characterize type of dataset, either 'raw' or 'derived'. Autofilled when choosing the proper inherited models."""
self._type = DatasetType(type)

@staticmethod
def _prepare_fields_from_download(
download_model: DownloadDataset,
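In this file the type field is now declared first in _FIELD_SPEC and exposed through a property whose setter accepts either a DatasetType or the literal strings and normalizes them. A stripped-down, hypothetical sketch of that setter pattern (the local DatasetType here stands in for the enum defined in _base_model.py):

from enum import Enum
from typing import Literal, Union

class DatasetType(str, Enum):  # stand-in for scitacean's enum
    RAW = "raw"
    DERIVED = "derived"

class ExampleDataset:
    def __init__(self, type: Union[DatasetType, Literal["raw", "derived"]]) -> None:
        self._type = DatasetType(type)

    @property
    def type(self) -> DatasetType:
        return self._type

    @type.setter
    def type(self, type: Union[DatasetType, Literal["raw", "derived"]]) -> None:
        # Plain strings are coerced so the attribute is always a DatasetType;
        # unknown values raise ValueError immediately.
        self._type = DatasetType(type)

dset = ExampleDataset("raw")
dset.type = "derived"
assert dset.type is DatasetType.DERIVED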
10 changes: 5 additions & 5 deletions src/scitacean/_html_repr/__init__.py
@@ -5,7 +5,7 @@
import html
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Iterable, List, Optional

import pydantic

@@ -81,7 +81,7 @@ def _format_metadata_value(value: Any) -> str:
_VALUE_UNIT_KEYS = {"value", "unit", "valueSI", "unitSI"}


def _has_value_unit_encoding(meta_value: Any):
def _has_value_unit_encoding(meta_value: Any) -> bool:
if (
isinstance(meta_value, dict)
and "value" in meta_value
@@ -104,7 +104,7 @@ class Field:


def _format_field(field: Field) -> str:
def format_value(val) -> str:
def format_value(val: Any) -> str:
if isinstance(val, datetime):
return val.strftime("%Y-%m-%d %H:%M:%S%z")
return html.escape(str(val))
@@ -175,15 +175,15 @@ def _get_fields(dset: Dataset) -> List[Field]:
)


def _check_error(field: Field, validation: Dict[str, str]) -> Optional[str]:
def _check_error(field: Dataset.Field, validation: Dict[str, str]) -> Optional[str]:
if field.name in validation:
# TODO validation uses model names (camelCase)
return validation[field.name]
return None


def _validate(dset: Dataset) -> Dict[str, str]:
def single_elem(xs):
def single_elem(xs: Iterable[Any]) -> Any:
(x,) = xs
return x

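The _has_value_unit_encoding helper above gained an explicit -> bool return annotation. Only the first half of its body is visible in the hunk; here is a hedged reconstruction of what such a check plausibly does with the _VALUE_UNIT_KEYS set (the subset test is an assumption, not the library's exact code):

from typing import Any

_VALUE_UNIT_KEYS = {"value", "unit", "valueSI", "unitSI"}

def has_value_unit_encoding(meta_value: Any) -> bool:
    # Metadata entries shaped like {"value": 3.2, "unit": "mm"} can be
    # rendered as "3.2 mm"; anything with unexpected keys cannot.
    return (
        isinstance(meta_value, dict)
        and "value" in meta_value
        and set(meta_value) <= _VALUE_UNIT_KEYS
    )

assert has_value_unit_encoding({"value": 3.2, "unit": "mm"})
assert not has_value_unit_encoding({"value": 3.2, "flavour": "vanilla"})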
27 changes: 25 additions & 2 deletions src/scitacean/_internal/dataclass_wrapper.py
@@ -3,13 +3,36 @@
"""Python version-independent dataclasses."""

import dataclasses
from typing import Callable, Type, TypeVar
from typing import Any, Callable, Type, TypeVar

T = TypeVar("T")


try:
from typing import dataclass_transform
except ImportError:
from typing import Tuple, Union

F = TypeVar("F")

def dataclass_transform(
*,
eq_default: bool = True,
order_default: bool = False,
kw_only_default: bool = False,
frozen_default: bool = False,
field_specifiers: Tuple[Union[Type[Any], Callable[..., Any]], ...] = (),
**kwargs: Any,
) -> Callable[[T], T]:
def impl(f: F) -> F:
return f

return impl


@dataclass_transform()
def dataclass_optional_args(
kw_only: bool = False, slots: bool = False, **kwargs
kw_only: bool = False, slots: bool = False, **kwargs: Any
) -> Callable[[Type[T]], Type[T]]:
"""Create a dataclass with modern arguments."""
try:
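typing.dataclass_transform only exists from Python 3.11 (PEP 681), so the new shim above supplies a no-op decorator with the same keyword-only signature on older interpreters; this keeps @dataclass_transform() usable for type checkers while changing nothing at runtime. A hypothetical use of the decorated factory, assuming it forwards kw_only/slots to dataclasses.dataclass where the running interpreter supports them and drops them otherwise (the try/except body is truncated in the hunk above):

from scitacean._internal.dataclass_wrapper import dataclass_optional_args

@dataclass_optional_args(kw_only=True, slots=True)
class Point:
    x: float
    y: float

# On Python versions that support them, kw_only and slots are honoured:
p = Point(x=1.0, y=2.0)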
4 changes: 2 additions & 2 deletions src/scitacean/_internal/pydantic_compat.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 SciCat Project (https://github.com/SciCatProject/scitacean)
from typing import Any, Literal
from typing import Any, Callable, Literal

import pydantic

@@ -13,7 +13,7 @@ def field_validator(
*args: Any,
mode: Literal["before", "after", "wrap", "plain"] = "after",
**kwargs: Any,
) -> Any:
) -> Callable[[Any], Any]:
if is_pydantic_v1():
return pydantic.validator(*args, pre=(mode == "before"), **kwargs)
return pydantic.field_validator(*args, mode=mode, **kwargs)
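The compat wrapper now advertises that it returns a decorator (Callable[[Any], Any]) instead of Any. A hypothetical model showing how one call works against both pydantic major versions, since mode="before" is translated to pre=True under v1 (model and field names are made up):

from typing import Any

import pydantic

from scitacean._internal.pydantic_compat import field_validator

class Measurement(pydantic.BaseModel):
    length_mm: float

    @field_validator("length_mm", mode="before")
    @classmethod
    def coerce_length(cls, value: Any) -> float:
        # Accept numeric strings such as "3.5" as well as plain numbers.
        return float(value)

assert Measurement(length_mm="3.5").length_mm == 3.5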
16 changes: 8 additions & 8 deletions src/scitacean/client.py
@@ -179,7 +179,7 @@ def get_dataset(
dataset_model=self.scicat.get_dataset_model(
pid, strict_validation=strict_validation
),
orig_datablock_models=orig_datablocks,
orig_datablock_models=orig_datablocks or [],
)

def upload_new_dataset_now(self, dataset: Dataset) -> Dataset:
@@ -236,7 +236,7 @@ def upload_new_dataset_now(self, dataset: Dataset) -> Dataset:

return Dataset.from_download_models(
dataset_model=finalized_model,
orig_datablock_models=finalized_orig_datablocks,
orig_datablock_models=finalized_orig_datablocks or [],
)

def _upload_orig_datablocks(
@@ -387,7 +387,7 @@ def download_files(
for f in to_download
if (p := f.remote_access_path(dataset.source_folder)) is not None
],
local=[f.local_path for f in to_download],
local=[f.local_path for f in to_download], # type: ignore[misc]
)
for f in to_download:
f.validate_after_download()
@@ -687,11 +687,11 @@ def _url_concat(a: str, b: str) -> str:


def _strip_token(error: Any, token: str) -> str:
error = str(error)
error = re.sub(r"token=[\w\-./]+", "token=<HIDDEN>", error)
err = str(error)
err = re.sub(r"token=[\w\-./]+", "token=<HIDDEN>", err)
if token: # token can be ""
error = error.replace(token, "<HIDDEN>")
return error
err = err.replace(token, "<HIDDEN>")
return err


def _make_orig_datablock(
@@ -811,7 +811,7 @@ def _remove_up_to_date_local_files(
file
for file in files
if not (
file.local_path.exists()
file.local_path.exists() # type: ignore[union-attr]
and dataclasses.replace(
file, checksum_algorithm=checksum_algorithm
).local_is_up_to_date()
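The _strip_token change above simply introduces a separate local err for the sanitized string instead of re-assigning the error parameter. The same logic, lifted out nearly verbatim, with a small usage example (the URL and token are made up):

import re
from typing import Any

def strip_token(error: Any, token: str) -> str:
    err = str(error)
    err = re.sub(r"token=[\w\-./]+", "token=<HIDDEN>", err)
    if token:  # token can be ""
        err = err.replace(token, "<HIDDEN>")
    return err

print(strip_token("GET /datasets?token=abc.def-123 failed", "abc.def-123"))
# -> GET /datasets?token=<HIDDEN> failed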
