Skip to content

Commit

Permalink
[red-knot] Implement Type::Tuple Comparisons (#13712)
Browse files Browse the repository at this point in the history
## Summary

This PR implements comparisons for (tuple, tuple).

It will close #13688 and complete an item in #13618 once merged.

## Test Plan

Basic tests are included for (tuple, tuple) comparisons.

---------

Co-authored-by: Carl Meyer <[email protected]>
  • Loading branch information
cake-monotone and carljm authored Oct 16, 2024
1 parent 8f5b2aa commit 2ffc3fa
Show file tree
Hide file tree
Showing 2 changed files with 338 additions and 0 deletions.
205 changes: 205 additions & 0 deletions crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
# Comparison - Tuples

## Heterogeneous

For tuples like `tuple[int, str, Literal[1]]`

### Value Comparisons

"Value Comparisons" refers to the operators: `==`, `!=`, `<`, `<=`, `>`, `>=`

#### Results without Ambiguity

Cases where the result can be definitively inferred as a `BooleanLiteral`.

```py
a = (1, "test", (3, 13), True)
b = (1, "test", (3, 14), False)

reveal_type(a == a) # revealed: Literal[True]
reveal_type(a != a) # revealed: Literal[False]
reveal_type(a < a) # revealed: Literal[False]
reveal_type(a <= a) # revealed: Literal[True]
reveal_type(a > a) # revealed: Literal[False]
reveal_type(a >= a) # revealed: Literal[True]

reveal_type(a == b) # revealed: Literal[False]
reveal_type(a != b) # revealed: Literal[True]
reveal_type(a < b) # revealed: Literal[True]
reveal_type(a <= b) # revealed: Literal[True]
reveal_type(a > b) # revealed: Literal[False]
reveal_type(a >= b) # revealed: Literal[False]
```

Even when tuples have different lengths, comparisons should be handled appropriately.

```py path=different_length.py
a = (1, 2, 3)
b = (1, 2, 3, 4)

reveal_type(a == b) # revealed: Literal[False]
reveal_type(a != b) # revealed: Literal[True]
reveal_type(a < b) # revealed: Literal[True]
reveal_type(a <= b) # revealed: Literal[True]
reveal_type(a > b) # revealed: Literal[False]
reveal_type(a >= b) # revealed: Literal[False]

c = ("a", "b", "c", "d")
d = ("a", "b", "c")

reveal_type(c == d) # revealed: Literal[False]
reveal_type(c != d) # revealed: Literal[True]
reveal_type(c < d) # revealed: Literal[False]
reveal_type(c <= d) # revealed: Literal[False]
reveal_type(c > d) # revealed: Literal[True]
reveal_type(c >= d) # revealed: Literal[True]
```

#### Results with Ambiguity

```py
def bool_instance() -> bool: ...
def int_instance() -> int: ...

a = (bool_instance(),)
b = (int_instance(),)

# TODO: All @Todo should be `bool`
reveal_type(a == a) # revealed: @Todo
reveal_type(a != a) # revealed: @Todo
reveal_type(a < a) # revealed: @Todo
reveal_type(a <= a) # revealed: @Todo
reveal_type(a > a) # revealed: @Todo
reveal_type(a >= a) # revealed: @Todo

reveal_type(a == b) # revealed: @Todo
reveal_type(a != b) # revealed: @Todo
reveal_type(a < b) # revealed: @Todo
reveal_type(a <= b) # revealed: @Todo
reveal_type(a > b) # revealed: @Todo
reveal_type(a >= b) # revealed: @Todo
```

#### Comparison Unsupported

If two tuples contain types that do not support comparison, the result may be `Unknown`.
However, `==` and `!=` are exceptions and can still provide definite results.

```py
a = (1, 2)
b = (1, "hello")

# TODO: should be Literal[False]
reveal_type(a == b) # revealed: @Todo

# TODO: should be Literal[True]
reveal_type(a != b) # revealed: @Todo

# TODO: should be Unknown and add more informative diagnostics
reveal_type(a < b) # revealed: @Todo
reveal_type(a <= b) # revealed: @Todo
reveal_type(a > b) # revealed: @Todo
reveal_type(a >= b) # revealed: @Todo
```

However, if the lexicographic comparison completes without reaching a point where str and int are compared,
Python will still produce a result based on the prior elements.

```py path=short_circuit.py
a = (1, 2)
b = (999999, "hello")

reveal_type(a == b) # revealed: Literal[False]
reveal_type(a != b) # revealed: Literal[True]
reveal_type(a < b) # revealed: Literal[True]
reveal_type(a <= b) # revealed: Literal[True]
reveal_type(a > b) # revealed: Literal[False]
reveal_type(a >= b) # revealed: Literal[False]
```

#### Matryoshka Tuples

```py
a = (1, True, "Hello")
b = (a, a, a)
c = (b, b, b)

reveal_type(c == c) # revealed: Literal[True]
reveal_type(c != c) # revealed: Literal[False]
reveal_type(c < c) # revealed: Literal[False]
reveal_type(c <= c) # revealed: Literal[True]
reveal_type(c > c) # revealed: Literal[False]
reveal_type(c >= c) # revealed: Literal[True]
```

#### Non Boolean Rich Comparisons

```py
class A():
def __eq__(self, o) -> str: ...
def __ne__(self, o) -> int: ...
def __lt__(self, o) -> float: ...
def __le__(self, o) -> object: ...
def __gt__(self, o) -> tuple: ...
def __ge__(self, o) -> list: ...

a = (A(), A())

# TODO: All @Todo should be bool
reveal_type(a == a) # revealed: @Todo
reveal_type(a != a) # revealed: @Todo
reveal_type(a < a) # revealed: @Todo
reveal_type(a <= a) # revealed: @Todo
reveal_type(a > a) # revealed: @Todo
reveal_type(a >= a) # revealed: @Todo
```

### Membership Test Comparisons

"Membership Test Comparisons" refers to the operators `in` and `not in`.

```py
def int_instance() -> int: ...

a = (1, 2)
b = ((3, 4), (1, 2))
c = ((1, 2, 3), (4, 5, 6))
d = ((int_instance(), int_instance()), (int_instance(), int_instance()))

reveal_type(a in b) # revealed: Literal[True]
reveal_type(a not in b) # revealed: Literal[False]

reveal_type(a in c) # revealed: Literal[False]
reveal_type(a not in c) # revealed: Literal[True]

# TODO: All @Todo should be bool
reveal_type(a in d) # revealed: @Todo
reveal_type(a not in d) # revealed: @Todo
```

### Identity Comparisons

"Identity Comparisons" refers to `is` and `is not`.

```py
a = (1, 2)
b = ("a", "b")
c = (1, 2, 3)

reveal_type(a is (1, 2)) # revealed: bool
reveal_type(a is not (1, 2)) # revealed: bool

# TODO: Update to Literal[False] once str == int comparison is implemented
reveal_type(a is b) # revealed: @Todo
# TODO: Update to Literal[True] once str == int comparison is implemented
reveal_type(a is not b) # revealed: @Todo

reveal_type(a is c) # revealed: Literal[False]
reveal_type(a is not c) # revealed: Literal[True]
```

## Homogeneous

For tuples like `tuple[int, ...]`, `tuple[Any, ...]`

// TODO
133 changes: 133 additions & 0 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2831,7 +2831,68 @@ impl<'db> TypeInferenceBuilder<'db> {
(_, Type::BytesLiteral(_)) => {
self.infer_binary_type_comparison(left, op, KnownClass::Bytes.to_instance(self.db))
}
(Type::Tuple(lhs), Type::Tuple(rhs)) => {
// Note: This only works on heterogeneous tuple types.
let lhs_elements = lhs.elements(self.db).as_ref();
let rhs_elements = rhs.elements(self.db).as_ref();

let mut lexicographic_type_comparison =
|op| self.infer_lexicographic_type_comparison(lhs_elements, op, rhs_elements);

match op {
ast::CmpOp::Eq => lexicographic_type_comparison(RichCompareOperator::Eq),
ast::CmpOp::NotEq => lexicographic_type_comparison(RichCompareOperator::Ne),
ast::CmpOp::Lt => lexicographic_type_comparison(RichCompareOperator::Lt),
ast::CmpOp::LtE => lexicographic_type_comparison(RichCompareOperator::Le),
ast::CmpOp::Gt => lexicographic_type_comparison(RichCompareOperator::Gt),
ast::CmpOp::GtE => lexicographic_type_comparison(RichCompareOperator::Ge),
ast::CmpOp::In | ast::CmpOp::NotIn => {
let mut eq_count = 0usize;
let mut not_eq_count = 0usize;

for ty in rhs_elements {
let eq_result = self.infer_binary_type_comparison(
Type::Tuple(lhs),
ast::CmpOp::Eq,
*ty,
).expect("infer_binary_type_comparison should never return None for `CmpOp::Eq`");

match eq_result {
Type::Todo => return Some(Type::Todo),
ty => match ty.bool(self.db) {
Truthiness::AlwaysTrue => eq_count += 1,
Truthiness::AlwaysFalse => not_eq_count += 1,
Truthiness::Ambiguous => (),
},
}
}

if eq_count >= 1 {
Some(Type::BooleanLiteral(op.is_in()))
} else if not_eq_count == rhs_elements.len() {
Some(Type::BooleanLiteral(op.is_not_in()))
} else {
Some(KnownClass::Bool.to_instance(self.db))
}
}
ast::CmpOp::Is | ast::CmpOp::IsNot => {
// - `[ast::CmpOp::Is]`: returns `false` if the elements are definitely unequal, otherwise `bool`
// - `[ast::CmpOp::IsNot]`: returns `true` if the elements are definitely unequal, otherwise `bool`
let eq_result = lexicographic_type_comparison(RichCompareOperator::Eq)
.expect(
"infer_binary_type_comparison should never return None for `CmpOp::Eq`",
);

Some(match eq_result {
Type::Todo => Type::Todo,
ty => match ty.bool(self.db) {
Truthiness::AlwaysFalse => Type::BooleanLiteral(op.is_is_not()),
_ => KnownClass::Bool.to_instance(self.db),
},
})
}
}
}
// Lookup the rich comparison `__dunder__` methods on instances
(Type::Instance(left_class_ty), Type::Instance(right_class_ty)) => match op {
ast::CmpOp::Lt => {
Expand All @@ -2845,6 +2906,55 @@ impl<'db> TypeInferenceBuilder<'db> {
}
}

/// Performs lexicographic comparison between two slices of types.
///
/// For lexicographic comparison, elements from both slices are compared pairwise using
/// `infer_binary_type_comparison`. If a conclusive result cannot be determined as a `BoolLiteral`,
/// it returns `bool`. Returns `None` if the comparison is not supported.
fn infer_lexicographic_type_comparison(
&mut self,
left: &[Type<'db>],
op: RichCompareOperator,
right: &[Type<'db>],
) -> Option<Type<'db>> {
// Compare paired elements from left and right slices
for (l_ty, r_ty) in left.iter().copied().zip(right.iter().copied()) {
let eq_result = self
.infer_binary_type_comparison(l_ty, ast::CmpOp::Eq, r_ty)
.expect("infer_binary_type_comparison should never return None for `CmpOp::Eq`");

match eq_result {
// If propagation is required, return the result as is
Type::Todo => return Some(Type::Todo),
ty => match ty.bool(self.db) {
// Types are equal, continue to the next pair
Truthiness::AlwaysTrue => continue,
// Types are not equal, perform the specified comparison and return the result
Truthiness::AlwaysFalse => {
return self.infer_binary_type_comparison(l_ty, op.into(), r_ty)
}
// If the intermediate result is ambiguous, we cannot determine the final result as BooleanLiteral.
// In this case, we simply return a bool instance.
Truthiness::Ambiguous => return Some(KnownClass::Bool.to_instance(self.db)),
},
}
}

// At this point, the lengths of the two slices may be different, but the prefix of
// left and right slices is entirely identical.
// We return a comparison of the slice lengths based on the operator.
let (left_len, right_len) = (left.len(), right.len());

Some(Type::BooleanLiteral(match op {
RichCompareOperator::Eq => left_len == right_len,
RichCompareOperator::Ne => left_len != right_len,
RichCompareOperator::Lt => left_len < right_len,
RichCompareOperator::Le => left_len <= right_len,
RichCompareOperator::Gt => left_len > right_len,
RichCompareOperator::Ge => left_len >= right_len,
}))
}

fn infer_subscript_expression(&mut self, subscript: &ast::ExprSubscript) -> Type<'db> {
let ast::ExprSubscript {
range: _,
Expand Down Expand Up @@ -3286,6 +3396,29 @@ impl<'db> TypeInferenceBuilder<'db> {
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RichCompareOperator {
Eq,
Ne,
Gt,
Ge,
Lt,
Le,
}

impl From<RichCompareOperator> for ast::CmpOp {
fn from(value: RichCompareOperator) -> Self {
match value {
RichCompareOperator::Eq => ast::CmpOp::Eq,
RichCompareOperator::Ne => ast::CmpOp::NotEq,
RichCompareOperator::Lt => ast::CmpOp::Lt,
RichCompareOperator::Le => ast::CmpOp::LtE,
RichCompareOperator::Gt => ast::CmpOp::Gt,
RichCompareOperator::Ge => ast::CmpOp::GtE,
}
}
}

fn format_import_from_module(level: u32, module: Option<&str>) -> String {
format!(
"{}{}",
Expand Down

0 comments on commit 2ffc3fa

Please sign in to comment.