Skip to content

Commit

Permalink
Check union args to don't consider Optional fields as complex Union
Browse files Browse the repository at this point in the history
  • Loading branch information
hramezani committed Aug 7, 2023
1 parent 04ec4c7 commit 8d2b3d9
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
13 changes: 11 additions & 2 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pydantic._internal._typing_extra import origin_is_union
from pydantic._internal._utils import deep_update, lenient_issubclass
from pydantic.fields import FieldInfo
from typing_extensions import get_origin
from typing_extensions import get_args, get_origin

from pydantic_settings.utils import path_type_label

Expand Down Expand Up @@ -441,13 +441,22 @@ def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, val
# simplest case, field is not complex, we only need to add the value if it was found
return value

def _union_is_complex(self, annotation: type[Any] | None, metadata: list[Any]) -> bool:
for arg in get_args(annotation):
if arg is type(None): # Optional
continue
elif _annotation_is_complex(arg, metadata):
return True

return False

def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]:
"""
Find out if a field is complex, and if so whether JSON errors should be ignored
"""
if self.field_is_complex(field):
allow_parse_failure = False
elif origin_is_union(get_origin(field.annotation)):
elif origin_is_union(get_origin(field.annotation)) and self._union_is_complex(field.annotation, field.metadata):
allow_parse_failure = True
else:
return False, False
Expand Down
25 changes: 25 additions & 0 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -1791,3 +1791,28 @@ class Settings(BaseSettings):
env.set('z', '[{"x": 1, "y": {"foo": 1}}, {"x": 2, "y": {"foo": 2}}]')
s = Settings()
assert s.model_dump() == {'z': [{'x': 1, 'y': {'foo': 1}}, {'x': 2, 'y': {'foo': 2}}]}


def test_optional_field_from_env(env):
class Settings(BaseSettings):
x: Optional[str] = None

env.set('x', '123')

s = Settings()
assert s.x == '123'


@pytest.mark.skipif(not dotenv, reason='python-dotenv not installed')
def test_dotenv_optional_json_field(tmp_path):
p = tmp_path / '.env'
p.write_text("""DATA='{"foo":"bar"}'""")

class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file=p)

data: Optional[Json[dict[str, str]]] = Field(default=None)

s = Settings()
print(s.data)
assert s.data == {'foo': 'bar'}

0 comments on commit 8d2b3d9

Please sign in to comment.