From 2ffc3fad47556960064f1d150f38ad5330fa0e4d Mon Sep 17 00:00:00 2001 From: cake-monotone Date: Wed, 16 Oct 2024 20:39:55 +0900 Subject: [PATCH] [red-knot] Implement `Type::Tuple` Comparisons (#13712) ## Summary This PR implements comparisons for (tuple, tuple). It will close #13688 and complete an item in #13618 once merged. ## Test Plan Basic tests are included for (tuple, tuple) comparisons. --------- Co-authored-by: Carl Meyer --- .../resources/mdtest/comparison/tuples.md | 205 ++++++++++++++++++ .../src/types/infer.rs | 133 ++++++++++++ 2 files changed, 338 insertions(+) create mode 100644 crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md diff --git a/crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md b/crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md new file mode 100644 index 0000000000000..7497e2300075a --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md @@ -0,0 +1,205 @@ +# Comparison - Tuples + +## Heterogeneous + +For tuples like `tuple[int, str, Literal[1]]` + +### Value Comparisons + +"Value Comparisons" refers to the operators: `==`, `!=`, `<`, `<=`, `>`, `>=` + +#### Results without Ambiguity + +Cases where the result can be definitively inferred as a `BooleanLiteral`. + +```py +a = (1, "test", (3, 13), True) +b = (1, "test", (3, 14), False) + +reveal_type(a == a) # revealed: Literal[True] +reveal_type(a != a) # revealed: Literal[False] +reveal_type(a < a) # revealed: Literal[False] +reveal_type(a <= a) # revealed: Literal[True] +reveal_type(a > a) # revealed: Literal[False] +reveal_type(a >= a) # revealed: Literal[True] + +reveal_type(a == b) # revealed: Literal[False] +reveal_type(a != b) # revealed: Literal[True] +reveal_type(a < b) # revealed: Literal[True] +reveal_type(a <= b) # revealed: Literal[True] +reveal_type(a > b) # revealed: Literal[False] +reveal_type(a >= b) # revealed: Literal[False] +``` + +Even when tuples have different lengths, comparisons should be handled appropriately. + +```py path=different_length.py +a = (1, 2, 3) +b = (1, 2, 3, 4) + +reveal_type(a == b) # revealed: Literal[False] +reveal_type(a != b) # revealed: Literal[True] +reveal_type(a < b) # revealed: Literal[True] +reveal_type(a <= b) # revealed: Literal[True] +reveal_type(a > b) # revealed: Literal[False] +reveal_type(a >= b) # revealed: Literal[False] + +c = ("a", "b", "c", "d") +d = ("a", "b", "c") + +reveal_type(c == d) # revealed: Literal[False] +reveal_type(c != d) # revealed: Literal[True] +reveal_type(c < d) # revealed: Literal[False] +reveal_type(c <= d) # revealed: Literal[False] +reveal_type(c > d) # revealed: Literal[True] +reveal_type(c >= d) # revealed: Literal[True] +``` + +#### Results with Ambiguity + +```py +def bool_instance() -> bool: ... +def int_instance() -> int: ... + +a = (bool_instance(),) +b = (int_instance(),) + +# TODO: All @Todo should be `bool` +reveal_type(a == a) # revealed: @Todo +reveal_type(a != a) # revealed: @Todo +reveal_type(a < a) # revealed: @Todo +reveal_type(a <= a) # revealed: @Todo +reveal_type(a > a) # revealed: @Todo +reveal_type(a >= a) # revealed: @Todo + +reveal_type(a == b) # revealed: @Todo +reveal_type(a != b) # revealed: @Todo +reveal_type(a < b) # revealed: @Todo +reveal_type(a <= b) # revealed: @Todo +reveal_type(a > b) # revealed: @Todo +reveal_type(a >= b) # revealed: @Todo +``` + +#### Comparison Unsupported + +If two tuples contain types that do not support comparison, the result may be `Unknown`. +However, `==` and `!=` are exceptions and can still provide definite results. + +```py +a = (1, 2) +b = (1, "hello") + +# TODO: should be Literal[False] +reveal_type(a == b) # revealed: @Todo + +# TODO: should be Literal[True] +reveal_type(a != b) # revealed: @Todo + +# TODO: should be Unknown and add more informative diagnostics +reveal_type(a < b) # revealed: @Todo +reveal_type(a <= b) # revealed: @Todo +reveal_type(a > b) # revealed: @Todo +reveal_type(a >= b) # revealed: @Todo +``` + +However, if the lexicographic comparison completes without reaching a point where str and int are compared, +Python will still produce a result based on the prior elements. + +```py path=short_circuit.py +a = (1, 2) +b = (999999, "hello") + +reveal_type(a == b) # revealed: Literal[False] +reveal_type(a != b) # revealed: Literal[True] +reveal_type(a < b) # revealed: Literal[True] +reveal_type(a <= b) # revealed: Literal[True] +reveal_type(a > b) # revealed: Literal[False] +reveal_type(a >= b) # revealed: Literal[False] +``` + +#### Matryoshka Tuples + +```py +a = (1, True, "Hello") +b = (a, a, a) +c = (b, b, b) + +reveal_type(c == c) # revealed: Literal[True] +reveal_type(c != c) # revealed: Literal[False] +reveal_type(c < c) # revealed: Literal[False] +reveal_type(c <= c) # revealed: Literal[True] +reveal_type(c > c) # revealed: Literal[False] +reveal_type(c >= c) # revealed: Literal[True] +``` + +#### Non Boolean Rich Comparisons + +```py +class A(): + def __eq__(self, o) -> str: ... + def __ne__(self, o) -> int: ... + def __lt__(self, o) -> float: ... + def __le__(self, o) -> object: ... + def __gt__(self, o) -> tuple: ... + def __ge__(self, o) -> list: ... + +a = (A(), A()) + +# TODO: All @Todo should be bool +reveal_type(a == a) # revealed: @Todo +reveal_type(a != a) # revealed: @Todo +reveal_type(a < a) # revealed: @Todo +reveal_type(a <= a) # revealed: @Todo +reveal_type(a > a) # revealed: @Todo +reveal_type(a >= a) # revealed: @Todo +``` + +### Membership Test Comparisons + +"Membership Test Comparisons" refers to the operators `in` and `not in`. + +```py +def int_instance() -> int: ... + +a = (1, 2) +b = ((3, 4), (1, 2)) +c = ((1, 2, 3), (4, 5, 6)) +d = ((int_instance(), int_instance()), (int_instance(), int_instance())) + +reveal_type(a in b) # revealed: Literal[True] +reveal_type(a not in b) # revealed: Literal[False] + +reveal_type(a in c) # revealed: Literal[False] +reveal_type(a not in c) # revealed: Literal[True] + +# TODO: All @Todo should be bool +reveal_type(a in d) # revealed: @Todo +reveal_type(a not in d) # revealed: @Todo +``` + +### Identity Comparisons + +"Identity Comparisons" refers to `is` and `is not`. + +```py +a = (1, 2) +b = ("a", "b") +c = (1, 2, 3) + +reveal_type(a is (1, 2)) # revealed: bool +reveal_type(a is not (1, 2)) # revealed: bool + +# TODO: Update to Literal[False] once str == int comparison is implemented +reveal_type(a is b) # revealed: @Todo +# TODO: Update to Literal[True] once str == int comparison is implemented +reveal_type(a is not b) # revealed: @Todo + +reveal_type(a is c) # revealed: Literal[False] +reveal_type(a is not c) # revealed: Literal[True] +``` + +## Homogeneous + +For tuples like `tuple[int, ...]`, `tuple[Any, ...]` + +// TODO diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 06b24d7ae0838..1622f96178361 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -2831,7 +2831,68 @@ impl<'db> TypeInferenceBuilder<'db> { (_, Type::BytesLiteral(_)) => { self.infer_binary_type_comparison(left, op, KnownClass::Bytes.to_instance(self.db)) } + (Type::Tuple(lhs), Type::Tuple(rhs)) => { + // Note: This only works on heterogeneous tuple types. + let lhs_elements = lhs.elements(self.db).as_ref(); + let rhs_elements = rhs.elements(self.db).as_ref(); + let mut lexicographic_type_comparison = + |op| self.infer_lexicographic_type_comparison(lhs_elements, op, rhs_elements); + + match op { + ast::CmpOp::Eq => lexicographic_type_comparison(RichCompareOperator::Eq), + ast::CmpOp::NotEq => lexicographic_type_comparison(RichCompareOperator::Ne), + ast::CmpOp::Lt => lexicographic_type_comparison(RichCompareOperator::Lt), + ast::CmpOp::LtE => lexicographic_type_comparison(RichCompareOperator::Le), + ast::CmpOp::Gt => lexicographic_type_comparison(RichCompareOperator::Gt), + ast::CmpOp::GtE => lexicographic_type_comparison(RichCompareOperator::Ge), + ast::CmpOp::In | ast::CmpOp::NotIn => { + let mut eq_count = 0usize; + let mut not_eq_count = 0usize; + + for ty in rhs_elements { + let eq_result = self.infer_binary_type_comparison( + Type::Tuple(lhs), + ast::CmpOp::Eq, + *ty, + ).expect("infer_binary_type_comparison should never return None for `CmpOp::Eq`"); + + match eq_result { + Type::Todo => return Some(Type::Todo), + ty => match ty.bool(self.db) { + Truthiness::AlwaysTrue => eq_count += 1, + Truthiness::AlwaysFalse => not_eq_count += 1, + Truthiness::Ambiguous => (), + }, + } + } + + if eq_count >= 1 { + Some(Type::BooleanLiteral(op.is_in())) + } else if not_eq_count == rhs_elements.len() { + Some(Type::BooleanLiteral(op.is_not_in())) + } else { + Some(KnownClass::Bool.to_instance(self.db)) + } + } + ast::CmpOp::Is | ast::CmpOp::IsNot => { + // - `[ast::CmpOp::Is]`: returns `false` if the elements are definitely unequal, otherwise `bool` + // - `[ast::CmpOp::IsNot]`: returns `true` if the elements are definitely unequal, otherwise `bool` + let eq_result = lexicographic_type_comparison(RichCompareOperator::Eq) + .expect( + "infer_binary_type_comparison should never return None for `CmpOp::Eq`", + ); + + Some(match eq_result { + Type::Todo => Type::Todo, + ty => match ty.bool(self.db) { + Truthiness::AlwaysFalse => Type::BooleanLiteral(op.is_is_not()), + _ => KnownClass::Bool.to_instance(self.db), + }, + }) + } + } + } // Lookup the rich comparison `__dunder__` methods on instances (Type::Instance(left_class_ty), Type::Instance(right_class_ty)) => match op { ast::CmpOp::Lt => { @@ -2845,6 +2906,55 @@ impl<'db> TypeInferenceBuilder<'db> { } } + /// Performs lexicographic comparison between two slices of types. + /// + /// For lexicographic comparison, elements from both slices are compared pairwise using + /// `infer_binary_type_comparison`. If a conclusive result cannot be determined as a `BoolLiteral`, + /// it returns `bool`. Returns `None` if the comparison is not supported. + fn infer_lexicographic_type_comparison( + &mut self, + left: &[Type<'db>], + op: RichCompareOperator, + right: &[Type<'db>], + ) -> Option> { + // Compare paired elements from left and right slices + for (l_ty, r_ty) in left.iter().copied().zip(right.iter().copied()) { + let eq_result = self + .infer_binary_type_comparison(l_ty, ast::CmpOp::Eq, r_ty) + .expect("infer_binary_type_comparison should never return None for `CmpOp::Eq`"); + + match eq_result { + // If propagation is required, return the result as is + Type::Todo => return Some(Type::Todo), + ty => match ty.bool(self.db) { + // Types are equal, continue to the next pair + Truthiness::AlwaysTrue => continue, + // Types are not equal, perform the specified comparison and return the result + Truthiness::AlwaysFalse => { + return self.infer_binary_type_comparison(l_ty, op.into(), r_ty) + } + // If the intermediate result is ambiguous, we cannot determine the final result as BooleanLiteral. + // In this case, we simply return a bool instance. + Truthiness::Ambiguous => return Some(KnownClass::Bool.to_instance(self.db)), + }, + } + } + + // At this point, the lengths of the two slices may be different, but the prefix of + // left and right slices is entirely identical. + // We return a comparison of the slice lengths based on the operator. + let (left_len, right_len) = (left.len(), right.len()); + + Some(Type::BooleanLiteral(match op { + RichCompareOperator::Eq => left_len == right_len, + RichCompareOperator::Ne => left_len != right_len, + RichCompareOperator::Lt => left_len < right_len, + RichCompareOperator::Le => left_len <= right_len, + RichCompareOperator::Gt => left_len > right_len, + RichCompareOperator::Ge => left_len >= right_len, + })) + } + fn infer_subscript_expression(&mut self, subscript: &ast::ExprSubscript) -> Type<'db> { let ast::ExprSubscript { range: _, @@ -3286,6 +3396,29 @@ impl<'db> TypeInferenceBuilder<'db> { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum RichCompareOperator { + Eq, + Ne, + Gt, + Ge, + Lt, + Le, +} + +impl From for ast::CmpOp { + fn from(value: RichCompareOperator) -> Self { + match value { + RichCompareOperator::Eq => ast::CmpOp::Eq, + RichCompareOperator::Ne => ast::CmpOp::NotEq, + RichCompareOperator::Lt => ast::CmpOp::Lt, + RichCompareOperator::Le => ast::CmpOp::LtE, + RichCompareOperator::Gt => ast::CmpOp::Gt, + RichCompareOperator::Ge => ast::CmpOp::GtE, + } + } +} + fn format_import_from_module(level: u32, module: Option<&str>) -> String { format!( "{}{}",