Skip to content

Commit

Permalink
fix: Union type on components (#4137)
Browse files Browse the repository at this point in the history
* 🐛 (type_extraction.py): fix condition to correctly handle UnionType objects in type extraction process

* ✨ (test_schema.py): add support for additional data types and nested structures in post_process_type function to improve type handling and flexibility

* ✅ (test_schema.py): add additional test cases for post_process_type function to cover various Union types and combinations for better test coverage and accuracy
  • Loading branch information
Cristhianzl authored Oct 14, 2024
1 parent e043964 commit db2c121
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 4 deletions.
4 changes: 3 additions & 1 deletion src/backend/base/langflow/type_extraction/type_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def post_process_type(_type):

# If the return type is not a Union, then we just return it as a list
inner_type = _type[0] if isinstance(_type, list) else _type
if not hasattr(inner_type, "__origin__") or inner_type.__origin__ != Union:
if (not hasattr(inner_type, "__origin__") or inner_type.__origin__ != Union) and (
not hasattr(inner_type, "__class__") or inner_type.__class__.__name__ != "UnionType"
):
return _type if isinstance(_type, list) else [_type]
# If the return type is a Union, then we need to parse it
_type = extract_union_types_from_generic_alias(_type)
Expand Down
59 changes: 56 additions & 3 deletions src/backend/tests/unit/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from collections.abc import Sequence
from types import NoneType
from typing import Union

from langflow.schema.data import Data
import pytest
from pydantic import ValidationError

from langflow.template import Input, Output
from langflow.template.field.base import UNDEFINED
from langflow.type_extraction.type_extraction import post_process_type
from collections.abc import Sequence as SequenceABC


@pytest.fixture(name="client", autouse=True)
Expand Down Expand Up @@ -40,11 +42,62 @@ def test_validate_type_class(self):
assert input_obj.field_type == "int"

def test_post_process_type_function(self):
# Basic types
assert set(post_process_type(int)) == {int}
assert set(post_process_type(float)) == {float}

# List and Sequence types
assert set(post_process_type(list[int])) == {int}
assert set(post_process_type(SequenceABC[float])) == {float}

# Union types
assert set(post_process_type(Union[int, str])) == {int, str}
assert set(post_process_type(Union[int, Sequence[str]])) == {int, str}
assert set(post_process_type(Union[int, Sequence[int]])) == {int}
assert set(post_process_type(Union[int, SequenceABC[str]])) == {int, str}
assert set(post_process_type(Union[int, SequenceABC[int]])) == {int}

# Nested Union with lists
assert set(post_process_type(Union[list[int], list[str]])) == {int, str}
assert set(post_process_type(Union[int, list[str], list[float]])) == {int, str, float}

# Custom data types
assert set(post_process_type(Data)) == {Data}
assert set(post_process_type(list[Data])) == {Data}

# Union with custom types
assert set(post_process_type(Union[Data, str])) == {Data, str}
assert set(post_process_type(Union[Data, int, list[str]])) == {Data, int, str}

# Empty lists and edge cases
assert set(post_process_type(list)) == {list}
assert set(post_process_type(Union[int, None])) == {int, NoneType}
assert set(post_process_type(Union[None, list[None]])) == {None, NoneType}

# Handling complex nested structures
assert set(post_process_type(Union[SequenceABC[Union[int, str]], list[float]])) == {int, str, float}
assert set(post_process_type(Union[Union[Union[int, list[str]], list[float]], str])) == {int, str, float}

# Non-generic types should return as is
assert set(post_process_type(dict)) == {dict}
assert set(post_process_type(tuple)) == {tuple}

# Union with custom types
assert set(post_process_type(Union[Data, str])) == {Data, str}
assert set(post_process_type(Data | str)) == {Data, str}
assert set(post_process_type(Data | int | list[str])) == {Data, int, str}

# More complex combinations with Data
assert set(post_process_type(Data | list[float])) == {Data, float}
assert set(post_process_type(Data | Union[int, str])) == {Data, int, str}
assert set(post_process_type(Data | list[int] | None)) == {Data, int, type(None)}
assert set(post_process_type(Data | Union[float, None])) == {Data, float, type(None)}

# Multiple Data types combined
assert set(post_process_type(Union[Data, Union[str, float]])) == {Data, str, float}
assert set(post_process_type(Union[Data | float | str, int])) == {Data, int, float, str}

# Testing with nested unions and lists
assert set(post_process_type(Union[list[Data], list[Union[int, str]]])) == {Data, int, str}
assert set(post_process_type(Data | list[Union[float, str]])) == {Data, float, str}

def test_input_to_dict(self):
input_obj = Input(field_type="str")
Expand Down

0 comments on commit db2c121

Please sign in to comment.