diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index de3384e4e54cd3..336830aa267c99 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -1654,48 +1654,47 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_fstring_expression(&mut self, fstring: &ast::ExprFString) -> Type<'db> { let ast::ExprFString { range: _, value } = fstring; + // When we infer an fstring, there are only 2 outcomes: + // - The fstring contains *any* expression, and we infer `builtins.str` + // - The fstring contains *only* literals, and we use the same logic as + // `infer_string_literal_expression` + let mut has_expression = false; + let mut literals = Vec::new(); for part in value { match part { - ast::FStringPart::Literal(_) => { - // TODO string literal type + ast::FStringPart::Literal(literal) => { + literals.push(&literal.value); } ast::FStringPart::FString(fstring) => { - let ast::FString { - range: _, - elements, - flags: _, - } = fstring; - for element in elements { - self.infer_fstring_element(element); + for element in fstring.elements.into_iter() { + match element { + ast::FStringElement::Expression(_) => { + // We can short-circuit on any found expression + has_expression = true; + break; + } + ast::FStringElement::Literal(literal) => { + literals.push(&literal.value); + } + } } } } } - // TODO str type - Type::Unknown - } - - fn infer_fstring_element(&mut self, element: &ast::FStringElement) { - match element { - ast::FStringElement::Literal(_) => { - // TODO string literal type - } - ast::FStringElement::Expression(expr_element) => { - let ast::FStringExpressionElement { - range: _, - expression, - debug_text: _, - conversion: _, - format_spec, - } = expr_element; - self.infer_expression(expression); - - if let Some(format_spec) = format_spec { - for spec_element in &format_spec.elements { - self.infer_fstring_element(spec_element); - } - } + if has_expression { + builtins_symbol_ty(self.db, "str").to_instance(self.db) + } else { + if literals.iter().fold(0, |acc, box_str| acc + box_str.len()) + <= Self::MAX_STRING_LITERAL_SIZE + { + let concatenated: String = literals.into_iter().map(Box::as_ref).collect(); + Type::StringLiteral(StringLiteralType::new( + self.db, + concatenated.into_boxed_str(), + )) + } else { + Type::LiteralString } } } @@ -2574,10 +2573,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_call_expression(call_expr); Type::Unknown } - ast::Expr::FString(fstring) => { - self.infer_fstring_expression(fstring); - Type::Unknown - } + ast::Expr::FString(_) => Type::Unknown, // ast::Expr::Attribute(attribute) => { self.infer_attribute_expression(attribute); @@ -3362,6 +3358,29 @@ mod tests { Ok(()) } + #[test] + fn fstring_expression() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + x = 0 + a = f'hello' + b = f'hello {x}' + c = 'one ' f'single ' f'literal' + d = 'first ' f'second({x})' f'third' + ", + )?; + + assert_public_ty(&db, "src/a.py", "a", "Literal[\"hello\"]"); + assert_public_ty(&db, "src/a.py", "b", "str"); + assert_public_ty(&db, "src/a.py", "c", "Literal[\"one single literal\"]"); + assert_public_ty(&db, "src/a.py", "d", "str"); + + Ok(()) + } + #[test] fn basic_call_expression() -> anyhow::Result<()> { let mut db = setup_db();