From db2c12183cbb7eb06a720c5e7804b89306b07cbe Mon Sep 17 00:00:00 2001 From: Cristhian Zanforlin Lousa Date: Mon, 14 Oct 2024 18:02:15 -0300 Subject: [PATCH] fix: Union type on components (#4137) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🐛 (type_extraction.py): fix condition to correctly handle UnionType objects in type extraction process * ✨ (test_schema.py): add support for additional data types and nested structures in post_process_type function to improve type handling and flexibility * ✅ (test_schema.py): add additional test cases for post_process_type function to cover various Union types and combinations for better test coverage and accuracy --- .../type_extraction/type_extraction.py | 4 +- src/backend/tests/unit/test_schema.py | 59 ++++++++++++++++++- 2 files changed, 59 insertions(+), 4 deletions(-) diff --git a/src/backend/base/langflow/type_extraction/type_extraction.py b/src/backend/base/langflow/type_extraction/type_extraction.py index 6a060f116ef..9a2725c03af 100644 --- a/src/backend/base/langflow/type_extraction/type_extraction.py +++ b/src/backend/base/langflow/type_extraction/type_extraction.py @@ -55,7 +55,9 @@ def post_process_type(_type): # If the return type is not a Union, then we just return it as a list inner_type = _type[0] if isinstance(_type, list) else _type - if not hasattr(inner_type, "__origin__") or inner_type.__origin__ != Union: + if (not hasattr(inner_type, "__origin__") or inner_type.__origin__ != Union) and ( + not hasattr(inner_type, "__class__") or inner_type.__class__.__name__ != "UnionType" + ): return _type if isinstance(_type, list) else [_type] # If the return type is a Union, then we need to parse it _type = extract_union_types_from_generic_alias(_type) diff --git a/src/backend/tests/unit/test_schema.py b/src/backend/tests/unit/test_schema.py index b101ce60813..68128ef28e3 100644 --- a/src/backend/tests/unit/test_schema.py +++ b/src/backend/tests/unit/test_schema.py @@ -1,12 +1,14 @@ -from collections.abc import Sequence +from types import NoneType from typing import Union +from langflow.schema.data import Data import pytest from pydantic import ValidationError from langflow.template import Input, Output from langflow.template.field.base import UNDEFINED from langflow.type_extraction.type_extraction import post_process_type +from collections.abc import Sequence as SequenceABC @pytest.fixture(name="client", autouse=True) @@ -40,11 +42,62 @@ def test_validate_type_class(self): assert input_obj.field_type == "int" def test_post_process_type_function(self): + # Basic types assert set(post_process_type(int)) == {int} + assert set(post_process_type(float)) == {float} + + # List and Sequence types assert set(post_process_type(list[int])) == {int} + assert set(post_process_type(SequenceABC[float])) == {float} + + # Union types assert set(post_process_type(Union[int, str])) == {int, str} - assert set(post_process_type(Union[int, Sequence[str]])) == {int, str} - assert set(post_process_type(Union[int, Sequence[int]])) == {int} + assert set(post_process_type(Union[int, SequenceABC[str]])) == {int, str} + assert set(post_process_type(Union[int, SequenceABC[int]])) == {int} + + # Nested Union with lists + assert set(post_process_type(Union[list[int], list[str]])) == {int, str} + assert set(post_process_type(Union[int, list[str], list[float]])) == {int, str, float} + + # Custom data types + assert set(post_process_type(Data)) == {Data} + assert set(post_process_type(list[Data])) == {Data} + + # Union with custom types + assert set(post_process_type(Union[Data, str])) == {Data, str} + assert set(post_process_type(Union[Data, int, list[str]])) == {Data, int, str} + + # Empty lists and edge cases + assert set(post_process_type(list)) == {list} + assert set(post_process_type(Union[int, None])) == {int, NoneType} + assert set(post_process_type(Union[None, list[None]])) == {None, NoneType} + + # Handling complex nested structures + assert set(post_process_type(Union[SequenceABC[Union[int, str]], list[float]])) == {int, str, float} + assert set(post_process_type(Union[Union[Union[int, list[str]], list[float]], str])) == {int, str, float} + + # Non-generic types should return as is + assert set(post_process_type(dict)) == {dict} + assert set(post_process_type(tuple)) == {tuple} + + # Union with custom types + assert set(post_process_type(Union[Data, str])) == {Data, str} + assert set(post_process_type(Data | str)) == {Data, str} + assert set(post_process_type(Data | int | list[str])) == {Data, int, str} + + # More complex combinations with Data + assert set(post_process_type(Data | list[float])) == {Data, float} + assert set(post_process_type(Data | Union[int, str])) == {Data, int, str} + assert set(post_process_type(Data | list[int] | None)) == {Data, int, type(None)} + assert set(post_process_type(Data | Union[float, None])) == {Data, float, type(None)} + + # Multiple Data types combined + assert set(post_process_type(Union[Data, Union[str, float]])) == {Data, str, float} + assert set(post_process_type(Union[Data | float | str, int])) == {Data, int, float, str} + + # Testing with nested unions and lists + assert set(post_process_type(Union[list[Data], list[Union[int, str]]])) == {Data, int, str} + assert set(post_process_type(Data | list[Union[float, str]])) == {Data, float, str} def test_input_to_dict(self): input_obj = Input(field_type="str")