Skip to content

Commit

Permalink
Revert #114
Browse files Browse the repository at this point in the history
  • Loading branch information
cutoffthetop committed Mar 12, 2024
1 parent 1e3f5fa commit 5e7ffe9
Show file tree
Hide file tree
Showing 38 changed files with 197 additions and 227 deletions.
5 changes: 2 additions & 3 deletions mex/common/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
import pdb # noqa: T100
import sys
from bdb import BdbQuit
from collections.abc import Callable
from functools import partial
from textwrap import dedent
from traceback import format_exc
from typing import Any
from typing import Any, Callable

import click
from click import Command, Option
Expand Down Expand Up @@ -47,7 +46,7 @@ def _field_to_parameters(name: str, field: FieldInfo) -> list[str]:
names = [name] + ([field.alias] if field.alias else [])
names = [n.replace("_", "-") for n in names]
dashes = ["--" if len(n) > 1 else "-" for n in names]
return [f"{d}{n}" for d, n in zip(dashes, names, strict=False)]
return [f"{d}{n}" for d, n in zip(dashes, names)]


def _field_to_option(name: str, settings_cls: type[SettingsType]) -> Option:
Expand Down
8 changes: 4 additions & 4 deletions mex/common/connector/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABCMeta, abstractmethod
from contextlib import ExitStack
from types import TracebackType
from typing import TypeVar, cast, final
from typing import Optional, TypeVar, cast, final

from mex.common.context import ContextStore

Expand Down Expand Up @@ -46,9 +46,9 @@ def __enter__(self: ConnectorType) -> ConnectorType:
@final
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
"""Exit connector by calling `close` method and removing it from context."""
self.close()
Expand Down
5 changes: 2 additions & 3 deletions mex/common/extract.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from collections import defaultdict
from collections.abc import Generator
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar, Union
from typing import TYPE_CHECKING, Any, Generator, TypeVar, Union

import numpy as np
import pandas as pd
Expand All @@ -26,7 +25,7 @@ def get_dtypes_for_model(model: type["BaseModel"]) -> dict[str, "Dtype"]:
"""Get the basic dtypes per field for a model from the `PANDAS_DTYPE_MAP`.
Args:
model: Model class for which to get pandas data types per field alias
model: Model class for which to get pandas dtypes per field alias
Returns:
Mapping from field alias to dtype strings
Expand Down
3 changes: 1 addition & 2 deletions mex/common/identity/registry.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from collections.abc import Hashable
from typing import Final
from typing import Final, Hashable

from mex.common.identity.base import BaseProvider
from mex.common.identity.memory import MemoryIdentityProvider
Expand Down
3 changes: 1 addition & 2 deletions mex/common/ldap/connector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from collections.abc import Generator
from functools import cache
from typing import TypeVar
from typing import Generator, TypeVar
from urllib.parse import urlsplit

from ldap3 import AUTO_BIND_NO_TLS, Connection, Server
Expand Down
2 changes: 1 addition & 1 deletion mex/common/ldap/extract.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from collections.abc import Iterable
from typing import Iterable

from mex.common.identity import get_provider
from mex.common.ldap.models.person import LDAPPerson, LDAPPersonWithQuery
Expand Down
2 changes: 1 addition & 1 deletion mex/common/ldap/transform.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from collections.abc import Generator, Iterable
from dataclasses import dataclass
from functools import cache
from typing import Generator, Iterable

from mex.common.exceptions import MExError
from mex.common.ldap.models.actor import LDAPActor
Expand Down
7 changes: 3 additions & 4 deletions mex/common/logging.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import logging
import logging.config
from collections.abc import Callable, Generator
from datetime import datetime
from functools import wraps
from typing import Any, TypeVar
from typing import Any, Callable, Generator, Optional, TypeVar, Union

import click

Expand Down Expand Up @@ -60,12 +59,12 @@ def wrapper(*args: Any, **kwargs: Any) -> Generator[YieldT, None, None]:
return wrapper


def get_ts(ts: datetime | None = None) -> str:
def get_ts(ts: Optional[datetime] = None) -> str:
"""Get a styled timestamp tag for prefixing log messages."""
return click.style(f"[{ts or datetime.now()}]", fg="bright_yellow")


def echo(text: str | bytes, ts: datetime | None = None, **styles: Any) -> None:
def echo(text: Union[str, bytes], ts: Optional[datetime] = None, **styles: Any) -> None:
"""Echo the given text with the given styles and the current timestamp prefix.
Args:
Expand Down
81 changes: 40 additions & 41 deletions mex/common/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Final, get_args
from typing import Final, Union, get_args

from mex.common.models.access_platform import (
BaseAccessPlatform,
Expand Down Expand Up @@ -99,58 +99,57 @@
"MEX_PRIMARY_SOURCE_STABLE_TARGET_ID",
)

AnyBaseModel = (
BaseAccessPlatform
| BaseActivity
| BaseContactPoint
| BaseDistribution
| BaseOrganization
| BaseOrganizationalUnit
| BasePerson
| BasePrimarySource
| BaseResource
| BaseVariable
| BaseVariableGroup
)

AnyBaseModel = Union[
BaseAccessPlatform,
BaseActivity,
BaseContactPoint,
BaseDistribution,
BaseOrganization,
BaseOrganizationalUnit,
BasePerson,
BasePrimarySource,
BaseResource,
BaseVariable,
BaseVariableGroup,
]
BASE_MODEL_CLASSES: Final[list[type[AnyBaseModel]]] = list(get_args(AnyBaseModel))
BASE_MODEL_CLASSES_BY_NAME: Final[dict[str, type[AnyBaseModel]]] = {
cls.__name__: cls for cls in BASE_MODEL_CLASSES
}

AnyExtractedModel = (
ExtractedAccessPlatform
| ExtractedActivity
| ExtractedContactPoint
| ExtractedDistribution
| ExtractedOrganization
| ExtractedOrganizationalUnit
| ExtractedPerson
| ExtractedPrimarySource
| ExtractedResource
| ExtractedVariable
| ExtractedVariableGroup
)
AnyExtractedModel = Union[
ExtractedAccessPlatform,
ExtractedActivity,
ExtractedContactPoint,
ExtractedDistribution,
ExtractedOrganization,
ExtractedOrganizationalUnit,
ExtractedPerson,
ExtractedPrimarySource,
ExtractedResource,
ExtractedVariable,
ExtractedVariableGroup,
]
EXTRACTED_MODEL_CLASSES: Final[list[type[AnyExtractedModel]]] = list(
get_args(AnyExtractedModel)
)
EXTRACTED_MODEL_CLASSES_BY_NAME: Final[dict[str, type[AnyExtractedModel]]] = {
cls.__name__: cls for cls in EXTRACTED_MODEL_CLASSES
}

AnyMergedModel = (
MergedAccessPlatform
| MergedActivity
| MergedContactPoint
| MergedDistribution
| MergedOrganization
| MergedOrganizationalUnit
| MergedPerson
| MergedPrimarySource
| MergedResource
| MergedVariable
| MergedVariableGroup
)
AnyMergedModel = Union[
MergedAccessPlatform,
MergedActivity,
MergedContactPoint,
MergedDistribution,
MergedOrganization,
MergedOrganizationalUnit,
MergedPerson,
MergedPrimarySource,
MergedResource,
MergedVariable,
MergedVariableGroup,
]
MERGED_MODEL_CLASSES: Final[list[type[AnyMergedModel]]] = list(get_args(AnyMergedModel))
MERGED_MODEL_CLASSES_BY_NAME: Final[dict[str, type[AnyMergedModel]]] = {
cls.__name__: cls for cls in MERGED_MODEL_CLASSES
Expand Down
73 changes: 42 additions & 31 deletions mex/common/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import pickle # nosec
from collections.abc import MutableMapping
from functools import cache
from types import UnionType
from typing import (
Any,
TypeVar,
Union,
get_args,
get_origin,
)

from pydantic import BaseModel as PydanticBaseModel
Expand All @@ -16,11 +17,11 @@
ValidationError,
model_validator,
)
from pydantic.fields import FieldInfo
from pydantic.json_schema import DEFAULT_REF_TEMPLATE, JsonSchemaMode
from pydantic.json_schema import GenerateJsonSchema as PydanticJsonSchemaGenerator

from mex.common.models.schema import JsonSchemaGenerator
from mex.common.utils import get_inner_types

RawModelDataT = TypeVar("RawModelDataT")

Expand Down Expand Up @@ -69,54 +70,60 @@ def model_json_schema(
@cache
def _get_alias_lookup(cls) -> dict[str, str]:
"""Build a cached mapping from field alias to field names."""
return {
field_info.alias or field_name: field_name
for field_name, field_info in cls.model_fields.items()
}
return {field.alias or name: name for name, field in cls.model_fields.items()}

@classmethod
@cache
def _get_list_field_names(cls) -> list[str]:
"""Build a cached list of fields that look like lists."""
field_names = []
for field_name, field_info in cls.model_fields.items():
field_types = get_inner_types(
field_info.annotation, unpack=(Union, UnionType)
)
if any(
isinstance(field_type, type) and issubclass(field_type, list)
for field_type in field_types
):
field_names.append(field_name)
return field_names

def is_object_subclass_of_list(obj: Any) -> bool:
    """Tell whether `obj` is a class that derives from `list`.

    Anything that `issubclass` rejects with a TypeError is reported as False.
    """
    try:
        derives_from_list = issubclass(obj, list)
    except TypeError:
        # `issubclass` raises TypeError when its first argument is not a class
        derives_from_list = False
    return derives_from_list

list_fields = []
for name, field in cls.model_fields.items():
origin = get_origin(field.annotation)
if is_object_subclass_of_list(origin):
list_fields.append(name)
elif origin is Union:
for arg in get_args(field.annotation):
if is_object_subclass_of_list(get_origin(arg)):
list_fields.append(name)
break
return list_fields

@classmethod
@cache
def _get_field_names_allowing_none(cls) -> list[str]:
"""Build a cached list of fields that can be set to None."""
field_names: list[str] = []
for field_name, field_info in cls.model_fields.items():
fields: list[str] = []
for name, field_info in cls.model_fields.items():
validator = TypeAdapter(field_info.annotation)
try:
validator.validate_python(None)
except ValidationError:
continue
field_names.append(field_name)
return field_names
fields.append(name)
return fields

@classmethod
def _convert_non_list_to_list(cls, field_name: str, value: Any) -> list[Any] | None:
def _convert_non_list_to_list(
cls, name: str, field: FieldInfo, value: Any
) -> list[Any] | None:
"""Convert a non-list value to a list value by wrapping it in a list."""
if value is None:
if field_name in cls._get_field_names_allowing_none():
if name in cls._get_field_names_allowing_none():
return None
# if a list is required, we interpret None as an empty list
return []
# if the value is non-None, wrap it in a list
return [value]

@classmethod
def _convert_list_to_non_list(cls, field_name: str, value: list[Any]) -> Any:
def _convert_list_to_non_list(cls, name: str, value: list[Any]) -> Any:
"""Convert a list value to a non-list value by unpacking it if possible."""
length = len(value)
if length == 0:
Expand All @@ -126,17 +133,19 @@ def _convert_list_to_non_list(cls, field_name: str, value: list[Any]) -> Any:
# if we have just one entry, we can safely unpack it
return value[0]
# we cannot unambiguously unpack more than one value
raise ValueError(f"got multiple values for {field_name}")
raise ValueError(f"got multiple values for {name}")

@classmethod
def _fix_value_listyness_for_field(cls, field_name: str, value: Any) -> Any:
def _fix_value_listyness_for_field(
cls, name: str, field: FieldInfo, value: Any
) -> Any:
"""Check actual and desired shape of a value and fix it if necessary."""
should_be_list = field_name in cls._get_list_field_names()
should_be_list = name in cls._get_list_field_names()
is_list = isinstance(value, list)
if not is_list and should_be_list:
return cls._convert_non_list_to_list(field_name, value)
return cls._convert_non_list_to_list(name, field, value)
if is_list and not should_be_list:
return cls._convert_list_to_non_list(field_name, value)
return cls._convert_list_to_non_list(name, value)
# already desired shape
return value

Expand Down Expand Up @@ -164,8 +173,10 @@ def fix_listyness(cls, data: RawModelDataT) -> RawModelDataT:
if isinstance(data, MutableMapping):
for name, value in data.items():
field_name = cls._get_alias_lookup().get(name, name)
if field_name in cls.model_fields:
data[name] = cls._fix_value_listyness_for_field(field_name, value)
if field := cls.model_fields.get(field_name):
data[name] = cls._fix_value_listyness_for_field(
field_name, field, value
)
return data

def checksum(self) -> str:
Expand Down
Loading

0 comments on commit 5e7ffe9

Please sign in to comment.