Skip to content

Commit

Permalink
use Never type instead of Unknown for unbound public types
Browse files Browse the repository at this point in the history
Uses `Type::Never` instead of `Type::Unknown` for the case where
a publicly available variable is unbound. In the case where it is
unbound, we want a union of its actual type with `Never` instead of
`Unbound`, because the `Unbound` case will cause runtime error.
  • Loading branch information
pilleye committed Oct 15, 2024
1 parent b16f665 commit a18ae67
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 41 deletions.
Original file line number Diff line number Diff line change
@@ -1,17 +1,36 @@
# Unbound

## Maybe unbound
## Maybe never

```py
```py path=maybe_never/maybe_never.py
if flag:
y = 3
x = y
reveal_type(x) # revealed: Unbound | Literal[3]
reveal_type(x) # revealed: Never | Literal[3]
```

```py path=maybe_never/public.py
from .maybe_never import x
reveal_type(x) # revealed: Literal[3]
```

## Maybe never annotated

```py path=maybe_never_annotated/maybe_never_annotated.py
if flag:
y: int = 3
x = y
reveal_type(x) # revealed: Never | int
```

```py path=maybe_never_annotated/public.py
from .maybe_never import x
reveal_type(x) # revealed: int
```

## Unbound

```py
```py path=unbound/
x = foo; foo = 1
reveal_type(x) # revealed: Unbound
```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ reveal_type(x) # revealed: Literal[2, 3]

## Simple if-elif-else

```py
```py path=simple_if_elif_else/simple_if_elif_else.py
y = 1
y = 2
if flag:
Expand All @@ -28,11 +28,19 @@ else:
y = 5
s = y
x = y

reveal_type(x) # revealed: Literal[3, 4, 5]
reveal_type(r) # revealed: Unbound | Literal[2]
reveal_type(s) # revealed: Unbound | Literal[5]
```

```py path=simple_if_elif_else/public.py
from .simple_if_elif_else import r, s

reveal_type(r) # revealed: Literal[2]
reveal_type(s) # revealed: Literal[5]
```

## Single symbol across if-elif-else

```py
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ reveal_type(y) # revealed: Literal[2, 3]

## Without wildcard

```py
```py path=without_wildcard/without_wildcard.py
match 0:
case 1:
y = 2
Expand All @@ -24,6 +24,12 @@ match 0:
reveal_type(y) # revealed: Unbound | Literal[2, 3]
```

```py path=without_wildcard/public.py
from .without_wildcard import y

reveal_type(y) # revealed: Literal[2, 3]
```

## Basic match

```py
Expand Down
32 changes: 28 additions & 4 deletions crates/red_knot_python_semantic/resources/mdtest/loops/for_loop.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

## Basic `for` loop

```py
```py path=basic_for_loop/basic_for_loop.py
class IntIterator:
def __next__(self) -> int:
return 42
Expand All @@ -17,6 +17,12 @@ for x in IntIterable():
reveal_type(x) # revealed: Unbound | int
```

```py path=basic_for_loop/public.py
from .basic_for_loop import x

reveal_type(x) # revealed: int
```

## With previous definition

```py
Expand Down Expand Up @@ -77,7 +83,7 @@ reveal_type(x) # revealed: int | Literal["foo"]

## With old-style iteration protocol

```py
```py path=without_oldstyle_iteration_protocol/without_oldstyle_iteration_protocol.py
class OldStyleIterable:
def __getitem__(self, key: int) -> int:
return 42
Expand All @@ -88,18 +94,30 @@ for x in OldStyleIterable():
reveal_type(x) # revealed: Unbound | int
```

```py path=without_oldstyle_iteration_protocol/public.py
from .without_oldstyle_iteration_protocol import x

reveal_type(x) # revealed: int
```

## With heterogeneous tuple

```py
```py path=with_heterogeneous_tuple/with_heterogeneous_tuple.py
for x in (1, 'a', b'foo'):
pass

reveal_type(x) # revealed: Unbound | Literal[1] | Literal["a"] | Literal[b"foo"]
```

```py path=with_heterogeneous_tuple/public.py
from .with_heterogeneous_tuple import x

reveal_type(x) # revealed: Literal[1] | Literal["a"] | Literal[b"foo"]
```

## With non-callable iterator

```py
```py path=with_noncallable_iterator/with_noncallable_iterator.py
class NotIterable:
if flag:
__iter__ = 1
Expand All @@ -112,6 +130,12 @@ for x in NotIterable(): # error: "Object of type `NotIterable` is not iterable"
reveal_type(x) # revealed: Unbound | Unknown
```

```py path=with_noncallable_iterator/with_noncallable_iterator.py
from .with_noncallable_iterator import x

reveal_type(x) # revealed: Unknown | int
```

## Invalid iterable

```py
Expand Down
45 changes: 27 additions & 18 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ fn symbol_ty_by_id<'db>(db: &'db dyn Db, scope: ScopeId<'db>, symbol: ScopedSymb
let _span = tracing::trace_span!("symbol_ty_by_id", ?symbol).entered();

let use_def = use_def_map(db, scope);
let unbound_ty = || use_def.public_may_be_unbound(symbol).then_some(Type::Never);

// If the symbol is declared, the public type is based on declarations; otherwise, it's based
// on inference from bindings.
Expand All @@ -58,9 +59,7 @@ fn symbol_ty_by_id<'db>(db: &'db dyn Db, scope: ScopeId<'db>, symbol: ScopedSymb
Some(bindings_ty(
db,
use_def.public_bindings(symbol),
use_def
.public_may_be_unbound(symbol)
.then_some(Type::Unknown),
unbound_ty(),
))
} else {
None
Expand All @@ -69,17 +68,11 @@ fn symbol_ty_by_id<'db>(db: &'db dyn Db, scope: ScopeId<'db>, symbol: ScopedSymb
// problem of the module we are importing from.
declarations_ty(db, declarations, undeclared_ty).unwrap_or_else(|(ty, _)| ty)
} else {
bindings_ty(
db,
use_def.public_bindings(symbol),
use_def
.public_may_be_unbound(symbol)
.then_some(Type::Unbound),
)
bindings_ty(db, use_def.public_bindings(symbol), unbound_ty())
}
}

/// Shorthand for `symbol_ty` that takes a symbol name instead of an ID.
/// Shorthand for `symbol_ty_by_id` that takes a symbol name instead of an ID.
fn symbol_ty<'db>(db: &'db dyn Db, scope: ScopeId<'db>, name: &str) -> Type<'db> {
let table = symbol_table(db, scope);
table
Expand Down Expand Up @@ -374,17 +367,33 @@ impl<'db> Type<'db> {
}
}

#[must_use]
pub fn replace_unbound_with(&self, db: &'db dyn Db, replacement: Type<'db>) -> Type<'db> {
fn replace_type_with(
&self,
db: &'db dyn Db,
target: Type<'db>,
replacement: Type<'db>,
) -> Type<'db> {
if self.is_equivalent_to(db, target) {
return replacement;
}

match self {
Type::Unbound => replacement,
Type::Union(union) => {
union.map(db, |element| element.replace_unbound_with(db, replacement))
}
ty => *ty,
Type::Union(union) => union.map(db, |element| {
element.replace_type_with(db, target, replacement)
}),
_ => *self,
}
}

#[must_use]
pub fn replace_unbound_with(&self, db: &'db dyn Db, replacement: Type<'db>) -> Type<'db> {
self.replace_type_with(db, Type::Unbound, replacement)
}

fn replace_never_with(&self, db: &'db dyn Db, replacement: Type<'db>) -> Type<'db> {
self.replace_type_with(db, Type::Never, replacement)
}

pub fn is_stdlib_symbol(&self, db: &'db dyn Db, module_name: &str, name: &str) -> bool {
match self {
Type::Class(class) => class.is_stdlib_symbol(db, module_name, name),
Expand Down
54 changes: 41 additions & 13 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1725,8 +1725,8 @@ impl<'db> TypeInferenceBuilder<'db> {

let member_ty = module_ty.member(self.db, &ast::name::Name::new(&name.id));

// TODO: What if it's a union where one of the elements is `Unbound`?
if member_ty.is_unbound() {
// TODO: What if it's a union where one of the elements is `Unbound` or `Never`?
if member_ty.is_unbound() || member_ty.is_never() {
self.add_diagnostic(
AnyNodeRef::Alias(alias),
"unresolved-import",
Expand All @@ -1743,7 +1743,9 @@ impl<'db> TypeInferenceBuilder<'db> {
// the runtime error will occur immediately (rather than when the symbol is *used*,
// as would be the case for a symbol with type `Unbound`), so it's appropriate to
// think of the type of the imported symbol as `Unknown` rather than `Unbound`
let ty = member_ty.replace_unbound_with(self.db, Type::Unknown);
let ty = member_ty
.replace_never_with(self.db, Type::Unknown)
.replace_unbound_with(self.db, Type::Unknown);

self.add_declaration_with_binding(alias.into(), definition, ty, ty);
}
Expand Down Expand Up @@ -2353,13 +2355,15 @@ impl<'db> TypeInferenceBuilder<'db> {
return symbol_ty(self.db, enclosing_scope_id, name);
}
}

// No nonlocal binding, check module globals. Avoid infinite recursion if `self.scope`
// already is module globals.
let ty = if file_scope_id.is_global() {
Type::Unbound
} else {
global_symbol_ty(self.db, self.file, name)
};

// Fallback to builtins (without infinite recursion if we're already in builtins.)
if ty.may_be_unbound(self.db) && Some(self.scope) != builtins_module_scope(self.db) {
let mut builtin_ty = builtins_symbol_ty(self.db, name);
Expand Down Expand Up @@ -3424,6 +3428,7 @@ mod tests {
use ruff_db::system::{DbWithTestSystem, SystemPathBuf};
use ruff_db::testing::assert_function_query_was_not_run;
use ruff_python_ast::name::Name;
use test_case::test_case;

use super::TypeInferenceBuilder;

Expand Down Expand Up @@ -3523,6 +3528,24 @@ mod tests {
assert_diagnostic_messages(&diagnostics, expected);
}

#[test]
fn imported_unbound_symbol_is_unknown() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_files([
("src/package/__init__.py", ""),
("src/package/foo.py", "x"),
("src/package/bar.py", "from package.foo import x"),
])?;

// the type as seen from external modules (`Unknown`)
// is different from the type inside the module itself (`Never`):
assert_public_ty(&db, "src/package/foo.py", "x", "Never");
assert_public_ty(&db, "src/package/bar.py", "x", "Unknown");

Ok(())
}

#[test]
fn from_import_with_no_module_name() -> anyhow::Result<()> {
// This test checks that invalid syntax in a `StmtImportFrom` node
Expand Down Expand Up @@ -3745,7 +3768,7 @@ mod tests {
)?;

// TODO: sys.version_info, and need to understand @final and @type_check_only
assert_public_ty(&db, "src/a.py", "x", "Unknown | EllipsisType");
assert_public_ty(&db, "src/a.py", "x", "EllipsisType | Unknown");

Ok(())
}
Expand Down Expand Up @@ -3856,24 +3879,29 @@ mod tests {
let y_ty = symbol_ty(&db, function_scope, "y");
let x_ty = symbol_ty(&db, function_scope, "x");

assert_eq!(x_ty.display(&db).to_string(), "Unbound");
assert_eq!(x_ty.display(&db).to_string(), "Never");
assert_eq!(y_ty.display(&db).to_string(), "Literal[1]");

Ok(())
}

#[test]
fn conditionally_global_or_builtin() -> anyhow::Result<()> {
#[test_case(""; "unannotated")]
// Tests that we only use the definition of a symbol instead of its declaration when we are
// checking module globals without a nonlocal binding.
#[test_case(": int"; "annotated")]
fn conditionally_global_or_builtin(annotation: &'static str) -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"/src/a.py",
"
if flag:
copyright = 1
def f():
y = copyright
&format!(
"
if flag:
copyright{annotation} = 1
def f():
y = copyright
",
),
)?;

let file = system_path_to_file(&db, "src/a.py").expect("file to exist");
Expand Down Expand Up @@ -4404,7 +4432,7 @@ mod tests {
",
)?;

assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "z", "Unbound");
assert_scope_ty(&db, "src/a.py", &["foo", "<listcomp>"], "z", "Never");

// (There is a diagnostic for invalid syntax that's emitted, but it's not listed by `assert_file_diagnostics`)
assert_file_diagnostics(&db, "src/a.py", &[]);
Expand Down

0 comments on commit a18ae67

Please sign in to comment.