diff --git a/mex/common/cli.py b/mex/common/cli.py index e2b6cb00..a74f34d4 100644 --- a/mex/common/cli.py +++ b/mex/common/cli.py @@ -2,11 +2,10 @@ import pdb # noqa: T100 import sys from bdb import BdbQuit -from collections.abc import Callable from functools import partial from textwrap import dedent from traceback import format_exc -from typing import Any +from typing import Any, Callable import click from click import Command, Option @@ -47,7 +46,7 @@ def _field_to_parameters(name: str, field: FieldInfo) -> list[str]: names = [name] + ([field.alias] if field.alias else []) names = [n.replace("_", "-") for n in names] dashes = ["--" if len(n) > 1 else "-" for n in names] - return [f"{d}{n}" for d, n in zip(dashes, names, strict=False)] + return [f"{d}{n}" for d, n in zip(dashes, names)] def _field_to_option(name: str, settings_cls: type[SettingsType]) -> Option: diff --git a/mex/common/connector/base.py b/mex/common/connector/base.py index ad953c07..e667cd55 100644 --- a/mex/common/connector/base.py +++ b/mex/common/connector/base.py @@ -1,7 +1,7 @@ from abc import ABCMeta, abstractmethod from contextlib import ExitStack from types import TracebackType -from typing import TypeVar, cast, final +from typing import Optional, TypeVar, cast, final from mex.common.context import ContextStore @@ -46,9 +46,9 @@ def __enter__(self: ConnectorType) -> ConnectorType: @final def __exit__( self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], ) -> None: """Exit connector by calling `close` method and removing it from context.""" self.close() diff --git a/mex/common/extract.py b/mex/common/extract.py index bbdbb081..a6015fd0 100644 --- a/mex/common/extract.py +++ b/mex/common/extract.py @@ -1,7 +1,6 @@ from collections import defaultdict -from collections.abc import Generator from pathlib import Path -from typing import TYPE_CHECKING, Any, TypeVar, Union +from typing import TYPE_CHECKING, Any, Generator, TypeVar, Union import numpy as np import pandas as pd @@ -26,7 +25,7 @@ def get_dtypes_for_model(model: type["BaseModel"]) -> dict[str, "Dtype"]: """Get the basic dtypes per field for a model from the `PANDAS_DTYPE_MAP`. Args: - model: Model class for which to get pandas data types per field alias + model: Model class for which to get pandas dtypes per field alias Returns: Mapping from field alias to dtype strings diff --git a/mex/common/identity/registry.py b/mex/common/identity/registry.py index b133ea6e..39c514dc 100644 --- a/mex/common/identity/registry.py +++ b/mex/common/identity/registry.py @@ -1,5 +1,4 @@ -from collections.abc import Hashable -from typing import Final +from typing import Final, Hashable from mex.common.identity.base import BaseProvider from mex.common.identity.memory import MemoryIdentityProvider diff --git a/mex/common/ldap/connector.py b/mex/common/ldap/connector.py index e1188b68..bb168180 100644 --- a/mex/common/ldap/connector.py +++ b/mex/common/ldap/connector.py @@ -1,6 +1,5 @@ -from collections.abc import Generator from functools import cache -from typing import TypeVar +from typing import Generator, TypeVar from urllib.parse import urlsplit from ldap3 import AUTO_BIND_NO_TLS, Connection, Server diff --git a/mex/common/ldap/extract.py b/mex/common/ldap/extract.py index 69eafab7..be6070f7 100644 --- a/mex/common/ldap/extract.py +++ b/mex/common/ldap/extract.py @@ -1,5 +1,5 @@ from collections import defaultdict -from collections.abc import Iterable +from typing import Iterable from mex.common.identity import get_provider from mex.common.ldap.models.person import LDAPPerson, LDAPPersonWithQuery diff --git a/mex/common/ldap/transform.py b/mex/common/ldap/transform.py index 01253f31..532fa452 100644 --- a/mex/common/ldap/transform.py +++ b/mex/common/ldap/transform.py @@ -1,7 +1,7 @@ import re -from collections.abc import Generator, Iterable from dataclasses import dataclass from functools import cache +from typing import Generator, Iterable from mex.common.exceptions import MExError from mex.common.ldap.models.actor import LDAPActor diff --git a/mex/common/logging.py b/mex/common/logging.py index 4f61a7fc..809b23d8 100644 --- a/mex/common/logging.py +++ b/mex/common/logging.py @@ -1,9 +1,8 @@ import logging import logging.config -from collections.abc import Callable, Generator from datetime import datetime from functools import wraps -from typing import Any, TypeVar +from typing import Any, Callable, Generator, Optional, TypeVar, Union import click @@ -60,12 +59,12 @@ def wrapper(*args: Any, **kwargs: Any) -> Generator[YieldT, None, None]: return wrapper -def get_ts(ts: datetime | None = None) -> str: +def get_ts(ts: Optional[datetime] = None) -> str: """Get a styled timestamp tag for prefixing log messages.""" return click.style(f"[{ts or datetime.now()}]", fg="bright_yellow") -def echo(text: str | bytes, ts: datetime | None = None, **styles: Any) -> None: +def echo(text: Union[str, bytes], ts: Optional[datetime] = None, **styles: Any) -> None: """Echo the given text with the given styles and the current timestamp prefix. Args: diff --git a/mex/common/models/__init__.py b/mex/common/models/__init__.py index be82521a..0edccea7 100644 --- a/mex/common/models/__init__.py +++ b/mex/common/models/__init__.py @@ -1,4 +1,4 @@ -from typing import Final, get_args +from typing import Final, Union, get_args from mex.common.models.access_platform import ( BaseAccessPlatform, @@ -99,38 +99,37 @@ "MEX_PRIMARY_SOURCE_STABLE_TARGET_ID", ) -AnyBaseModel = ( - BaseAccessPlatform - | BaseActivity - | BaseContactPoint - | BaseDistribution - | BaseOrganization - | BaseOrganizationalUnit - | BasePerson - | BasePrimarySource - | BaseResource - | BaseVariable - | BaseVariableGroup -) - +AnyBaseModel = Union[ + BaseAccessPlatform, + BaseActivity, + BaseContactPoint, + BaseDistribution, + BaseOrganization, + BaseOrganizationalUnit, + BasePerson, + BasePrimarySource, + BaseResource, + BaseVariable, + BaseVariableGroup, +] BASE_MODEL_CLASSES: Final[list[type[AnyBaseModel]]] = list(get_args(AnyBaseModel)) BASE_MODEL_CLASSES_BY_NAME: Final[dict[str, type[AnyBaseModel]]] = { cls.__name__: cls for cls in BASE_MODEL_CLASSES } -AnyExtractedModel = ( - ExtractedAccessPlatform - | ExtractedActivity - | ExtractedContactPoint - | ExtractedDistribution - | ExtractedOrganization - | ExtractedOrganizationalUnit - | ExtractedPerson - | ExtractedPrimarySource - | ExtractedResource - | ExtractedVariable - | ExtractedVariableGroup -) +AnyExtractedModel = Union[ + ExtractedAccessPlatform, + ExtractedActivity, + ExtractedContactPoint, + ExtractedDistribution, + ExtractedOrganization, + ExtractedOrganizationalUnit, + ExtractedPerson, + ExtractedPrimarySource, + ExtractedResource, + ExtractedVariable, + ExtractedVariableGroup, +] EXTRACTED_MODEL_CLASSES: Final[list[type[AnyExtractedModel]]] = list( get_args(AnyExtractedModel) ) @@ -138,19 +137,19 @@ cls.__name__: cls for cls in EXTRACTED_MODEL_CLASSES } -AnyMergedModel = ( - MergedAccessPlatform - | MergedActivity - | MergedContactPoint - | MergedDistribution - | MergedOrganization - | MergedOrganizationalUnit - | MergedPerson - | MergedPrimarySource - | MergedResource - | MergedVariable - | MergedVariableGroup -) +AnyMergedModel = Union[ + MergedAccessPlatform, + MergedActivity, + MergedContactPoint, + MergedDistribution, + MergedOrganization, + MergedOrganizationalUnit, + MergedPerson, + MergedPrimarySource, + MergedResource, + MergedVariable, + MergedVariableGroup, +] MERGED_MODEL_CLASSES: Final[list[type[AnyMergedModel]]] = list(get_args(AnyMergedModel)) MERGED_MODEL_CLASSES_BY_NAME: Final[dict[str, type[AnyMergedModel]]] = { cls.__name__: cls for cls in MERGED_MODEL_CLASSES diff --git a/mex/common/models/base.py b/mex/common/models/base.py index eba44854..56785992 100644 --- a/mex/common/models/base.py +++ b/mex/common/models/base.py @@ -2,11 +2,12 @@ import pickle # nosec from collections.abc import MutableMapping from functools import cache -from types import UnionType from typing import ( Any, TypeVar, Union, + get_args, + get_origin, ) from pydantic import BaseModel as PydanticBaseModel @@ -16,11 +17,11 @@ ValidationError, model_validator, ) +from pydantic.fields import FieldInfo from pydantic.json_schema import DEFAULT_REF_TEMPLATE, JsonSchemaMode from pydantic.json_schema import GenerateJsonSchema as PydanticJsonSchemaGenerator from mex.common.models.schema import JsonSchemaGenerator -from mex.common.utils import get_inner_types RawModelDataT = TypeVar("RawModelDataT") @@ -69,46 +70,52 @@ def model_json_schema( @cache def _get_alias_lookup(cls) -> dict[str, str]: """Build a cached mapping from field alias to field names.""" - return { - field_info.alias or field_name: field_name - for field_name, field_info in cls.model_fields.items() - } + return {field.alias or name: name for name, field in cls.model_fields.items()} @classmethod @cache def _get_list_field_names(cls) -> list[str]: """Build a cached list of fields that look like lists.""" - field_names = [] - for field_name, field_info in cls.model_fields.items(): - field_types = get_inner_types( - field_info.annotation, unpack=(Union, UnionType) - ) - if any( - isinstance(field_type, type) and issubclass(field_type, list) - for field_type in field_types - ): - field_names.append(field_name) - return field_names + + def is_object_subclass_of_list(obj: Any) -> bool: + try: + return issubclass(obj, list) + except TypeError: + return False + + list_fields = [] + for name, field in cls.model_fields.items(): + origin = get_origin(field.annotation) + if is_object_subclass_of_list(origin): + list_fields.append(name) + elif origin is Union: + for arg in get_args(field.annotation): + if is_object_subclass_of_list(get_origin(arg)): + list_fields.append(name) + break + return list_fields @classmethod @cache def _get_field_names_allowing_none(cls) -> list[str]: """Build a cached list of fields can be set to None.""" - field_names: list[str] = [] - for field_name, field_info in cls.model_fields.items(): + fields: list[str] = [] + for name, field_info in cls.model_fields.items(): validator = TypeAdapter(field_info.annotation) try: validator.validate_python(None) except ValidationError: continue - field_names.append(field_name) - return field_names + fields.append(name) + return fields @classmethod - def _convert_non_list_to_list(cls, field_name: str, value: Any) -> list[Any] | None: + def _convert_non_list_to_list( + cls, name: str, field: FieldInfo, value: Any + ) -> list[Any] | None: """Convert a non-list value to a list value by wrapping it in a list.""" if value is None: - if field_name in cls._get_field_names_allowing_none(): + if name in cls._get_field_names_allowing_none(): return None # if a list is required, we interpret None as an empty list return [] @@ -116,7 +123,7 @@ def _convert_non_list_to_list(cls, field_name: str, value: Any) -> list[Any] | N return [value] @classmethod - def _convert_list_to_non_list(cls, field_name: str, value: list[Any]) -> Any: + def _convert_list_to_non_list(cls, name: str, value: list[Any]) -> Any: """Convert a list value to a non-list value by unpacking it if possible.""" length = len(value) if length == 0: @@ -126,17 +133,19 @@ def _convert_list_to_non_list(cls, field_name: str, value: list[Any]) -> Any: # if we have just one entry, we can safely unpack it return value[0] # we cannot unambiguously unpack more than one value - raise ValueError(f"got multiple values for {field_name}") + raise ValueError(f"got multiple values for {name}") @classmethod - def _fix_value_listyness_for_field(cls, field_name: str, value: Any) -> Any: + def _fix_value_listyness_for_field( + cls, name: str, field: FieldInfo, value: Any + ) -> Any: """Check actual and desired shape of a value and fix it if necessary.""" - should_be_list = field_name in cls._get_list_field_names() + should_be_list = name in cls._get_list_field_names() is_list = isinstance(value, list) if not is_list and should_be_list: - return cls._convert_non_list_to_list(field_name, value) + return cls._convert_non_list_to_list(name, field, value) if is_list and not should_be_list: - return cls._convert_list_to_non_list(field_name, value) + return cls._convert_list_to_non_list(name, value) # already desired shape return value @@ -164,8 +173,10 @@ def fix_listyness(cls, data: RawModelDataT) -> RawModelDataT: if isinstance(data, MutableMapping): for name, value in data.items(): field_name = cls._get_alias_lookup().get(name, name) - if field_name in cls.model_fields: - data[name] = cls._fix_value_listyness_for_field(field_name, value) + if field := cls.model_fields.get(field_name): + data[name] = cls._fix_value_listyness_for_field( + field_name, field, value + ) return data def checksum(self) -> str: diff --git a/mex/common/models/filter.py b/mex/common/models/filter.py index e10f3489..3343b64f 100644 --- a/mex/common/models/filter.py +++ b/mex/common/models/filter.py @@ -1,4 +1,4 @@ -from typing import Annotated, Any +from typing import Annotated, Any, Optional from pydantic import BaseModel, Field, create_model @@ -8,18 +8,18 @@ class EntityFilterRule(BaseModel, extra="forbid"): """Entity filter rule model.""" - forValues: list[str] | None = None - rule: str | None = None + forValues: Optional[list[str]] = None + rule: Optional[str] = None class EntityFilter(BaseModel, extra="forbid"): """Entity filter model.""" fieldInPrimarySource: str - locationInPrimarySource: str | None = None - examplesInPrimarySource: list[str] | None = None + locationInPrimarySource: Optional[str] = None + examplesInPrimarySource: Optional[list[str]] = None mappingRules: Annotated[list[EntityFilterRule], Field(min_length=1)] - comment: str | None = None + comment: Optional[str] = None def generate_entity_filter_schema( diff --git a/mex/common/models/mapping.py b/mex/common/models/mapping.py index b4d9f2d5..00b81223 100644 --- a/mex/common/models/mapping.py +++ b/mex/common/models/mapping.py @@ -8,19 +8,19 @@ class GenericRule(BaseModel, extra="forbid"): """Generic mapping rule model.""" - forValues: list[str] | None = None - setValues: list[Any] | None = None - rule: str | None = None + forValues: Optional[list[str]] = None + setValues: Optional[list[Any]] = None + rule: Optional[str] = None class GenericField(BaseModel, extra="forbid"): """Generic Field model.""" fieldInPrimarySource: str - locationInPrimarySource: str | None = None - examplesInPrimarySource: list[str] | None = None + locationInPrimarySource: Optional[str] = None + examplesInPrimarySource: Optional[list[str]] = None mappingRules: Annotated[list[GenericRule], Field(min_length=1)] - comment: str | None = None + comment: Optional[str] = None def generate_mapping_schema_for_mex_class( @@ -44,15 +44,15 @@ def generate_mapping_schema_for_mex_class( continue # first create dynamic rule model if get_origin(field_info.annotation) is list: - rule_type: Any = field_info.annotation + rule_type: object = Optional[field_info.annotation] else: - rule_type = list[field_info.annotation] # type: ignore[name-defined] + rule_type = Optional[list[field_info.annotation]] # type: ignore[name-defined] rule_model: type[GenericRule] = create_model( f"{field_name.capitalize()}MappingRule", __base__=(GenericRule,), setValues=( - Optional[rule_type], # noqa: UP007 + rule_type, None, ), ) diff --git a/mex/common/organigram/extract.py b/mex/common/organigram/extract.py index 173d8728..17ec89fa 100644 --- a/mex/common/organigram/extract.py +++ b/mex/common/organigram/extract.py @@ -1,5 +1,5 @@ import json -from collections.abc import Generator, Iterable +from typing import Generator, Iterable from mex.common.logging import watch from mex.common.models import ExtractedOrganizationalUnit diff --git a/mex/common/organigram/transform.py b/mex/common/organigram/transform.py index eb675a79..1c9e14d4 100644 --- a/mex/common/organigram/transform.py +++ b/mex/common/organigram/transform.py @@ -1,4 +1,4 @@ -from collections.abc import Generator, Iterable +from typing import Generator, Iterable from mex.common.logging import watch from mex.common.models import ExtractedOrganizationalUnit, ExtractedPrimarySource diff --git a/mex/common/primary_source/extract.py b/mex/common/primary_source/extract.py index 8ca95473..d38f2a18 100644 --- a/mex/common/primary_source/extract.py +++ b/mex/common/primary_source/extract.py @@ -1,5 +1,5 @@ import json -from collections.abc import Generator +from typing import Generator from mex.common.logging import watch from mex.common.primary_source.models import SeedPrimarySource diff --git a/mex/common/primary_source/transform.py b/mex/common/primary_source/transform.py index 085a1f66..01c5ccba 100644 --- a/mex/common/primary_source/transform.py +++ b/mex/common/primary_source/transform.py @@ -1,4 +1,4 @@ -from collections.abc import Generator, Iterable +from typing import Generator, Iterable from mex.common.logging import watch from mex.common.models import ( @@ -32,7 +32,7 @@ def transform_seed_primary_sources_to_extracted_primary_sources( identifierInPrimarySource=primary_source.identifier, title=primary_source.title, hadPrimarySource=MEX_PRIMARY_SOURCE_STABLE_TARGET_ID, - **set_stable_target_id, + **set_stable_target_id ) diff --git a/mex/common/settings.py b/mex/common/settings.py index c6f68f5f..7ff1b356 100644 --- a/mex/common/settings.py +++ b/mex/common/settings.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Optional, TypeVar +from typing import Any, Optional, TypeVar, Union from pydantic import AnyUrl, Field, SecretStr, model_validator from pydantic_core import Url @@ -142,7 +142,7 @@ def get(cls: type[SettingsType]) -> SettingsType: description="Backend API key with write access to call POST/PUT endpoints", validation_alias="MEX_BACKEND_API_KEY", ) - verify_session: bool | AssetsPath = Field( + verify_session: Union[bool, AssetsPath] = Field( True, description=( "Either a boolean that controls whether we verify the server's TLS " diff --git a/mex/common/sinks/backend_api.py b/mex/common/sinks/backend_api.py index e4a0906f..76dbe49d 100644 --- a/mex/common/sinks/backend_api.py +++ b/mex/common/sinks/backend_api.py @@ -1,4 +1,4 @@ -from collections.abc import Generator, Iterable +from typing import Generator, Iterable from mex.common.backend_api.connector import BackendApiConnector from mex.common.logging import watch diff --git a/mex/common/sinks/ndjson.py b/mex/common/sinks/ndjson.py index c3d05866..27676892 100644 --- a/mex/common/sinks/ndjson.py +++ b/mex/common/sinks/ndjson.py @@ -1,8 +1,7 @@ import json -from collections.abc import Generator, Iterable from contextlib import ExitStack from pathlib import Path -from typing import IO, Any +from typing import IO, Any, Generator, Iterable from mex.common.logging import echo, watch from mex.common.models import AnyExtractedModel diff --git a/mex/common/testing/plugin.py b/mex/common/testing/plugin.py index f32aadab..e0958d30 100644 --- a/mex/common/testing/plugin.py +++ b/mex/common/testing/plugin.py @@ -5,10 +5,9 @@ """ import os -from collections.abc import Generator from enum import Enum from pathlib import Path -from typing import Any +from typing import Any, Generator from unittest.mock import MagicMock from langdetect import DetectorFactory diff --git a/mex/common/transform.py b/mex/common/transform.py index 8f686beb..f0ef7cca 100644 --- a/mex/common/transform.py +++ b/mex/common/transform.py @@ -1,10 +1,9 @@ import json import re -from collections.abc import Iterable from enum import Enum from functools import cache from pathlib import PurePath -from typing import Any, cast +from typing import Any, Iterable, cast from uuid import UUID from pydantic import AnyUrl, SecretStr diff --git a/mex/common/types/__init__.py b/mex/common/types/__init__.py index 5885a9d9..1f758389 100644 --- a/mex/common/types/__init__.py +++ b/mex/common/types/__init__.py @@ -1,4 +1,4 @@ -from typing import Final, get_args +from typing import Final, Union, get_args from mex.common.types.email import Email from mex.common.types.identifier import ( @@ -119,25 +119,28 @@ "WorkPath", ) -AnyNestedModel = Link | Text +AnyNestedModel = Union[ + Link, + Text, +] NESTED_MODEL_CLASSES: Final[list[type[AnyNestedModel]]] = list(get_args(AnyNestedModel)) NESTED_MODEL_CLASSES_BY_NAME: Final[dict[str, type[AnyNestedModel]]] = { cls.__name__: cls for cls in NESTED_MODEL_CLASSES } -AnyMergedIdentifier = ( - MergedAccessPlatformIdentifier - | MergedActivityIdentifier - | MergedContactPointIdentifier - | MergedDistributionIdentifier - | MergedOrganizationalUnitIdentifier - | MergedOrganizationIdentifier - | MergedPersonIdentifier - | MergedPrimarySourceIdentifier - | MergedResourceIdentifier - | MergedVariableGroupIdentifier - | MergedVariableIdentifier -) +AnyMergedIdentifier = Union[ + MergedAccessPlatformIdentifier, + MergedActivityIdentifier, + MergedContactPointIdentifier, + MergedDistributionIdentifier, + MergedOrganizationalUnitIdentifier, + MergedOrganizationIdentifier, + MergedPersonIdentifier, + MergedPrimarySourceIdentifier, + MergedResourceIdentifier, + MergedVariableGroupIdentifier, + MergedVariableIdentifier, +] MERGED_IDENTIFIER_CLASSES: Final[list[type[AnyMergedIdentifier]]] = list( get_args(AnyMergedIdentifier) ) @@ -145,19 +148,19 @@ cls.__name__: cls for cls in MERGED_IDENTIFIER_CLASSES } -AnyExtractedIdentifier = ( - ExtractedAccessPlatformIdentifier - | ExtractedActivityIdentifier - | ExtractedContactPointIdentifier - | ExtractedDistributionIdentifier - | ExtractedOrganizationalUnitIdentifier - | ExtractedOrganizationIdentifier - | ExtractedPersonIdentifier - | ExtractedPrimarySourceIdentifier - | ExtractedResourceIdentifier - | ExtractedVariableGroupIdentifier - | ExtractedVariableIdentifier -) +AnyExtractedIdentifier = Union[ + ExtractedAccessPlatformIdentifier, + ExtractedActivityIdentifier, + ExtractedContactPointIdentifier, + ExtractedDistributionIdentifier, + ExtractedOrganizationalUnitIdentifier, + ExtractedOrganizationIdentifier, + ExtractedPersonIdentifier, + ExtractedPrimarySourceIdentifier, + ExtractedResourceIdentifier, + ExtractedVariableGroupIdentifier, + ExtractedVariableIdentifier, +] EXTRACTED_IDENTIFIER_CLASSES: Final[list[type[AnyExtractedIdentifier]]] = list( get_args(AnyExtractedIdentifier) ) diff --git a/mex/common/types/email.py b/mex/common/types/email.py index 51fac461..fd4416bd 100644 --- a/mex/common/types/email.py +++ b/mex/common/types/email.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Type from pydantic import GetJsonSchemaHandler from pydantic.json_schema import JsonSchemaValue @@ -11,7 +11,7 @@ class Email(str): """Email address of a person, organization or other entity.""" @classmethod - def __get_pydantic_core_schema__(cls, _source: type[Any]) -> core_schema.CoreSchema: + def __get_pydantic_core_schema__(cls, _source: Type[Any]) -> core_schema.CoreSchema: """Get pydantic core schema.""" return core_schema.str_schema(pattern=EMAIL_PATTERN) diff --git a/mex/common/types/identifier.py b/mex/common/types/identifier.py index c6732324..d5c42b6a 100644 --- a/mex/common/types/identifier.py +++ b/mex/common/types/identifier.py @@ -1,6 +1,6 @@ import re import string -from typing import Any, TypeVar +from typing import Any, Type, TypeVar from uuid import UUID, uuid4 from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler @@ -35,7 +35,7 @@ def generate(cls: type[IdentifierT], seed: int | None = None) -> IdentifierT: @classmethod def validate(cls: type[IdentifierT], value: Any) -> IdentifierT: """Validate a string, UUID or Identifier.""" - if isinstance(value, str | UUID | Identifier): + if isinstance(value, (str, UUID, Identifier)): value = str(value) if re.match(MEX_ID_PATTERN, value): return cls(value) @@ -46,7 +46,7 @@ def validate(cls: type[IdentifierT], value: Any) -> IdentifierT: @classmethod def __get_pydantic_core_schema__( - cls, source: type[Any], handler: GetCoreSchemaHandler + cls, source: Type[Any], handler: GetCoreSchemaHandler ) -> core_schema.CoreSchema: """Modify the schema to add the ID regex.""" identifier_schema = { diff --git a/mex/common/types/path.py b/mex/common/types/path.py index a543e6a2..b85e6d56 100644 --- a/mex/common/types/path.py +++ b/mex/common/types/path.py @@ -1,6 +1,6 @@ from os import PathLike from pathlib import Path -from typing import Any, TypeVar, Union +from typing import Any, Type, TypeVar, Union from pydantic_core import core_schema @@ -51,7 +51,7 @@ def is_relative(self) -> bool: return not self._path.is_absolute() @classmethod - def __get_pydantic_core_schema__(cls, _source: type[Any]) -> core_schema.CoreSchema: + def __get_pydantic_core_schema__(cls, _source: Type[Any]) -> core_schema.CoreSchema: """Set schema to str schema.""" from_str_schema = core_schema.chain_schema( [ @@ -75,7 +75,7 @@ def __get_pydantic_core_schema__(cls, _source: type[Any]) -> core_schema.CoreSch @classmethod def validate(cls: type[PathWrapperT], value: Any) -> PathWrapperT: """Convert a string value to a Text instance.""" - if isinstance(value, str | Path | PathWrapper): + if isinstance(value, (str, Path, PathWrapper)): return cls(value) raise ValueError(f"Cannot parse {type(value)} as {cls.__name__}") diff --git a/mex/common/types/timestamp.py b/mex/common/types/timestamp.py index 13005260..300f87e4 100644 --- a/mex/common/types/timestamp.py +++ b/mex/common/types/timestamp.py @@ -3,7 +3,7 @@ from enum import Enum from functools import total_ordering from itertools import zip_longest -from typing import Any, Literal, Union, cast, overload +from typing import Any, Literal, Optional, Type, Union, cast, overload from pandas._libs.tslibs import parsing from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler @@ -77,13 +77,13 @@ def __init__( def __init__( self, *args: int, - tzinfo: tzinfo | None = None, + tzinfo: Optional[tzinfo] = None, ) -> None: ... # pragma: no cover def __init__( self, *args: Union[int, str, date, datetime, "Timestamp"], - tzinfo: tzinfo | None = None, + tzinfo: Optional[tzinfo] = None, ) -> None: """Create a new timestamp instance. @@ -107,7 +107,7 @@ def __init__( if len(args) > 7: raise TypeError(f"Timestamp takes at most 7 arguments ({len(args)} given)") - if len(args) == 1 and isinstance(args[0], str | date | datetime | Timestamp): + if len(args) == 1 and isinstance(args[0], (str, date, datetime, Timestamp)): if tzinfo: raise TypeError("Timestamp does not accept tzinfo in parsing mode") if isinstance(args[0], Timestamp): @@ -135,7 +135,7 @@ def __init__( @classmethod def __get_pydantic_core_schema__( - cls, _source: type[Any], _handler: GetCoreSchemaHandler + cls, _source: Type[Any], _handler: GetCoreSchemaHandler ) -> core_schema.CoreSchema: """Mutate the field schema for timestamps.""" from_str_schema = core_schema.chain_schema( @@ -176,13 +176,13 @@ def __get_pydantic_json_schema__( @classmethod def validate(cls, value: Any) -> "Timestamp": """Parse any value and try to convert it into a timestamp.""" - if isinstance(value, cls | date | str): + if isinstance(value, (cls, date, str)): return cls(value) raise TypeError(f"Cannot parse {type(value)} as {cls.__name__}") @staticmethod def _parse_args( - *args: int, tzinfo: tzinfo | None = None + *args: int, tzinfo: Optional[tzinfo] = None ) -> tuple[datetime, TimestampPrecision]: """Parse 0-7 integer arguments into a timestamp and deduct the precision.""" if tzinfo is None: diff --git a/mex/common/types/vocabulary.py b/mex/common/types/vocabulary.py index ab9ce769..f519e2f7 100644 --- a/mex/common/types/vocabulary.py +++ b/mex/common/types/vocabulary.py @@ -31,8 +31,8 @@ class Concept(BaseModel): identifier: AnyUrl inScheme: AnyUrl prefLabel: BilingualText - altLabel: BilingualText | None = None - definition: BilingualText | None = None + altLabel: Optional[BilingualText] = None + definition: Optional[BilingualText] = None @cache @@ -93,7 +93,7 @@ def find(cls, search_term: Union[str, "Text"]) -> Optional["VocabularyEnum"]: Enum instance for the found concept or None """ language = getattr(search_term, "language", None) - search_term = normalize(str(search_term)) + search_term = normalize(search_term) for concept in cls.__concepts__: searchable_labels = [] for label in (concept.prefLabel, concept.altLabel): diff --git a/mex/common/utils.py b/mex/common/utils.py index f527bda5..facaf78a 100644 --- a/mex/common/utils.py +++ b/mex/common/utils.py @@ -1,18 +1,9 @@ import re -from collections.abc import Container, Generator, Iterable, Iterator from functools import cache from itertools import zip_longest from random import random from time import sleep -from types import UnionType -from typing import ( - Annotated, - Any, - TypeVar, - Union, - get_args, - get_origin, -) +from typing import Container, Iterable, Iterator, Optional, TypeVar T = TypeVar("T") @@ -25,7 +16,9 @@ def contains_any(base: Container[T], tokens: Iterable[T]) -> bool: return False -def any_contains_any(bases: Iterable[Container[T] | None], tokens: Iterable[T]) -> bool: +def any_contains_any( + bases: Iterable[Optional[Container[T]]], tokens: Iterable[T] +) -> bool: """Check if any of the given bases contains any of the given tokens.""" for base in bases: if base is None: @@ -36,31 +29,13 @@ def any_contains_any(bases: Iterable[Container[T] | None], tokens: Iterable[T]) return False -def get_inner_types( - annotation: Any, unpack: Iterable[Any] = (Union, UnionType, list) -) -> Generator[type, None, None]: - """Yield all inner types from annotations and the types in `unpack`.""" - origin = get_origin(annotation) - if origin == Annotated: - yield from get_inner_types(get_args(annotation)[0], unpack) - elif origin in unpack: - for arg in get_args(annotation): - yield from get_inner_types(arg, unpack) - elif origin is not None: - yield origin - elif annotation is None: - yield type(None) - else: - yield annotation - - @cache def normalize(string: str) -> str: """Normalize the given string to lowercase, numerals and single spaces.""" return " ".join(re.sub(r"[^a-z0-9]", " ", string.lower()).split()) -def grouper(chunk_size: int, iterable: Iterable[T]) -> Iterator[Iterable[T | None]]: +def grouper(chunk_size: int, iterable: Iterable[T]) -> Iterator[Iterable[Optional[T]]]: """Collect data into fixed-length chunks or blocks.""" # https://docs.python.org/3.9/library/itertools.html#itertools-recipes args = [iter(iterable)] * chunk_size diff --git a/mex/common/wikidata/extract.py b/mex/common/wikidata/extract.py index 7c5f33ed..3a15b7e3 100644 --- a/mex/common/wikidata/extract.py +++ b/mex/common/wikidata/extract.py @@ -1,4 +1,4 @@ -from collections.abc import Generator +from typing import Generator import requests diff --git a/mex/common/wikidata/models/organization.py b/mex/common/wikidata/models/organization.py index e7ee72ea..6ffea244 100644 --- a/mex/common/wikidata/models/organization.py +++ b/mex/common/wikidata/models/organization.py @@ -1,4 +1,4 @@ -from typing import Annotated +from typing import Annotated, Optional, Union from pydantic import ConfigDict, Field, model_validator @@ -20,8 +20,8 @@ class DataValue(BaseModel): @model_validator(mode="before") @classmethod def transform_strings_to_dict( - cls, values: dict[str, str | dict[str, str]] - ) -> dict[str, dict[str, str | None]] | dict[str, str | dict[str, str]]: + cls, values: dict[str, Union[str, dict[str, str]]] + ) -> Union[dict[str, dict[str, str | None]], dict[str, Union[str, dict[str, str]]]]: """Transform string and null value to a dict for parsing. Args: @@ -72,8 +72,8 @@ class Label(BaseModel): class Labels(BaseModel): """Model class for Labels.""" - de: Label | None = None - en: Label | None = None + de: Optional[Label] = None + en: Optional[Label] = None class Alias(BaseModel): diff --git a/mex/common/wikidata/transform.py b/mex/common/wikidata/transform.py index 7c6230df..501ce616 100644 --- a/mex/common/wikidata/transform.py +++ b/mex/common/wikidata/transform.py @@ -1,4 +1,4 @@ -from collections.abc import Generator, Iterable +from typing import Generator, Iterable from mex.common.models import ExtractedOrganization, ExtractedPrimarySource from mex.common.types import Text, TextLanguage diff --git a/tests/connector/test_http.py b/tests/connector/test_http.py index 3cd17df3..5716c469 100644 --- a/tests/connector/test_http.py +++ b/tests/connector/test_http.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Optional from unittest.mock import MagicMock, Mock, call import pytest @@ -115,7 +115,7 @@ def test_connector_reset_context() -> None: ) def test_request_success( monkeypatch: MonkeyPatch, - sent_payload: dict[str, Any] | None, + sent_payload: Optional[dict[str, Any]], mocked_response: Response, expected_response: dict[str, Any], expected_kwargs: dict[str, Any], diff --git a/tests/identity/conftest.py b/tests/identity/conftest.py index 38594f59..a5e391a3 100644 --- a/tests/identity/conftest.py +++ b/tests/identity/conftest.py @@ -1,4 +1,4 @@ -from collections.abc import Generator +from typing import Generator import pytest diff --git a/tests/ldap/conftest.py b/tests/ldap/conftest.py index d137e03c..3ec43570 100644 --- a/tests/ldap/conftest.py +++ b/tests/ldap/conftest.py @@ -1,5 +1,4 @@ -from collections.abc import Callable -from typing import Any +from typing import Any, Callable from unittest.mock import MagicMock, Mock import pytest diff --git a/tests/models/test_base.py b/tests/models/test_base.py index 3f6a12f9..88e4416c 100644 --- a/tests/models/test_base.py +++ b/tests/models/test_base.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Annotated, Any, Literal +from typing import Annotated, Any, Literal, Optional, Union import pytest from pydantic import Field, ValidationError @@ -11,9 +11,9 @@ class ComplexDummyModel(BaseModel): """Dummy Model with multiple attributes.""" - optional_str: str | None = None + optional_str: Optional[str] = None required_str: str = "default" - optional_list: list[str] | None = None + optional_list: Optional[list[str]] = None required_list: list[str] = [] @@ -24,13 +24,6 @@ def test_get_field_names_allowing_none() -> None: ] -def test_get_list_field_names() -> None: - assert ComplexDummyModel._get_list_field_names() == [ - "optional_list", - "required_list", - ] - - class Animal(Enum): """Dummy enum to use in tests.""" @@ -66,7 +59,7 @@ class Animal(Enum): ], ) def test_base_model_listyness_fix( - data: dict[str, Any], expected: str | dict[str, Any] + data: dict[str, Any], expected: Union[str, dict[str, Any]] ) -> None: try: model = ComplexDummyModel.model_validate(data) @@ -91,7 +84,7 @@ class Shelter(Pet): class DummyBaseModel(BaseModel): - foo: str | None = None + foo: Optional[str] = None def test_base_model_checksum() -> None: diff --git a/tests/models/test_model_schemas.py b/tests/models/test_model_schemas.py index 7bf4fb42..5c98b974 100644 --- a/tests/models/test_model_schemas.py +++ b/tests/models/test_model_schemas.py @@ -1,10 +1,9 @@ import json import re -from collections.abc import Callable from copy import deepcopy from importlib.resources import files from itertools import zip_longest -from typing import Any +from typing import Any, Callable import pytest diff --git a/tests/test_cli.py b/tests/test_cli.py index 1440af8f..57fa53d9 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,7 +1,7 @@ import logging import re from enum import Enum -from typing import Any +from typing import Any, Union import pytest from click.testing import CliRunner @@ -168,7 +168,7 @@ class MyEnum(Enum): "UnionFieldSettings", __base__=BaseSettings, union_field=( - bool | str, + Union[bool, str], Field(True, description="String or boolean"), ), ), diff --git a/tests/test_utils.py b/tests/test_utils.py index 12d64329..820cc680 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,6 @@ import json import time -from collections.abc import Iterable -from typing import Any +from typing import Any, Iterable import pytest