diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 8df0d39be82e5..e24d2bdc96558 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -380,6 +380,10 @@ impl<'db> Type<'db> { } } + pub fn builtin_str(db: &'db dyn Db) -> Self { + builtins_symbol_ty(db, "str") + } + pub fn is_stdlib_symbol(&self, db: &'db dyn Db, module_name: &str, name: &str) -> bool { match self { Type::Class(class) => class.is_stdlib_symbol(db, module_name, name), @@ -723,41 +727,40 @@ impl<'db> Type<'db> { } /// Return the string representation of this type when converted to string as it would be - /// provided by the `__str__` method. If that can't be determined, return `None`. + /// provided by the `__str__` method. /// /// When not available, this should fall back to the value of `[Type::repr]`. /// Note: this method is used in the builtins `format`, `print`, `str.format` and `f-strings`. - pub fn str(&self, db: &'db dyn Db) -> Option> { - let str_result = match self { - Type::IntLiteral(_) => None, - Type::BooleanLiteral(_) => None, - Type::StringLiteral(_) => Some(*self), + #[must_use] + pub fn str(&self, db: &'db dyn Db) -> Type<'db> { + match self { + Type::IntLiteral(_) | Type::BooleanLiteral(_) => self.repr(db), + Type::StringLiteral(_) | Type::LiteralString => *self, // TODO: handle more complex types - _ => None, - }; - str_result.or_else(|| self.repr(db)) + _ => builtins_symbol_ty(db, "str").to_instance(db), + } } /// Return the string representation of this type as it would be provided by the `__repr__` - /// method at runtime. If that can't be determined, return `None`. - pub fn repr(&self, db: &'db dyn Db) -> Option> { + /// method at runtime. + #[must_use] + pub fn repr(&self, db: &'db dyn Db) -> Type<'db> { match self { - Type::IntLiteral(number) => Some(Type::StringLiteral(StringLiteralType::new(db, { + Type::IntLiteral(number) => Type::StringLiteral(StringLiteralType::new(db, { number.to_string().into_boxed_str() - }))), - Type::BooleanLiteral(true) => Some(Type::StringLiteral(StringLiteralType::new(db, { - "True".into() - }))), - Type::BooleanLiteral(false) => Some(Type::StringLiteral(StringLiteralType::new(db, { - "False".into() - }))), - Type::StringLiteral(literal) => { - Some(Type::StringLiteral(StringLiteralType::new(db, { - format!("'{}'", literal.value(db)).into() - }))) + })), + Type::BooleanLiteral(true) => { + Type::StringLiteral(StringLiteralType::new(db, "True".into())) + } + Type::BooleanLiteral(false) => { + Type::StringLiteral(StringLiteralType::new(db, "False".into())) } + Type::StringLiteral(literal) => Type::StringLiteral(StringLiteralType::new(db, { + format!("'{}'", literal.value(db).escape_default()).into() + })), + Type::LiteralString => Type::LiteralString, // TODO: handle more complex types - _ => None, + _ => builtins_symbol_ty(db, "str").to_instance(db), } } } @@ -1237,7 +1240,7 @@ mod tests { /// A test representation of a type that can be transformed unambiguously into a real Type, /// given a db. - #[derive(Debug)] + #[derive(Debug, Clone)] enum Ty { Never, Unknown, @@ -1373,31 +1376,27 @@ mod tests { assert_eq!(ty.into_type(&db).bool(&db), Truthiness::Ambiguous); } - #[test_case(Ty::IntLiteral(1), Some("1"))] - #[test_case(Ty::BoolLiteral(true), Some("True"))] - #[test_case(Ty::BoolLiteral(false), Some("False"))] - #[test_case(Ty::StringLiteral("hello"), Some("hello"))] // no quotes - #[test_case(Ty::LiteralString, None)] - #[test_case(Ty::BuiltinInstance("int"), None)] - fn has_correct_str(ty: Ty, expected: Option<&str>) { + #[test_case(Ty::IntLiteral(1), Ty::StringLiteral("1"))] + #[test_case(Ty::BoolLiteral(true), Ty::StringLiteral("True"))] + #[test_case(Ty::BoolLiteral(false), Ty::StringLiteral("False"))] + #[test_case(Ty::StringLiteral("ab'cd"), Ty::StringLiteral("ab'cd"))] // no quotes + #[test_case(Ty::LiteralString, Ty::LiteralString)] + #[test_case(Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"))] + fn has_correct_str(ty: Ty, expected: Ty) { let db = setup_db(); - let expected = expected.map(|s| Type::StringLiteral(StringLiteralType::new(&db, s.into()))); - - assert_eq!(ty.into_type(&db).str(&db), expected); + assert_eq!(ty.into_type(&db).str(&db), expected.into_type(&db)); } - #[test_case(Ty::IntLiteral(1), Some("1"))] - #[test_case(Ty::BoolLiteral(true), Some("True"))] - #[test_case(Ty::BoolLiteral(false), Some("False"))] - #[test_case(Ty::StringLiteral("hello"), Some("'hello'"))] // single quotes - #[test_case(Ty::LiteralString, None)] - #[test_case(Ty::BuiltinInstance("int"), None)] - fn has_correct_repr(ty: Ty, expected: Option<&str>) { + #[test_case(Ty::IntLiteral(1), Ty::StringLiteral("1"))] + #[test_case(Ty::BoolLiteral(true), Ty::StringLiteral("True"))] + #[test_case(Ty::BoolLiteral(false), Ty::StringLiteral("False"))] + #[test_case(Ty::StringLiteral("ab'cd"), Ty::StringLiteral("'ab\\'cd'"))] // single quotes + #[test_case(Ty::LiteralString, Ty::LiteralString)] + #[test_case(Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"))] + fn has_correct_repr(ty: Ty, expected: Ty) { let db = setup_db(); - let expected = expected.map(|s| Type::StringLiteral(StringLiteralType::new(&db, s.into()))); - - assert_eq!(ty.into_type(&db).repr(&db), expected); + assert_eq!(ty.into_type(&db).repr(&db), expected.into_type(&db)); } } diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index f5d5897be834c..32484fe2bf604 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -1653,13 +1653,13 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_fstring_expression(&mut self, fstring: &ast::ExprFString) -> Type<'db> { let ast::ExprFString { range: _, value } = fstring; - let mut done = false; - let mut has_expression = false; - let mut concatenated = String::new(); + let mut collector = StringPartsCollector::new(); for part in value { + // Make sure we iter through every parts to infer all sub-expressions. The `collector` + // struct ensures we don't allocate unnecessary strings. match part { ast::FStringPart::Literal(literal) => { - concatenated.push_str(&literal.value); + collector.push_str(&literal.value); } ast::FStringPart::FString(fstring) => { for element in &fstring.elements { @@ -1672,54 +1672,31 @@ impl<'db> TypeInferenceBuilder<'db> { conversion, format_spec, } = expression; - // Always infer sub-expressions, even if we've figured out the type let ty = self.infer_expression(expression); - if !done { - // TODO: handle format specifiers by calling a method - // (`Type::format`?) that handles the `__format__` method. - // Conversion flags should be handled before calling - // `__format__`. - // https://docs.python.org/3/library/string.html#format-string-syntax - if !conversion.is_none() || format_spec.is_some() { - has_expression = true; - done = true; + + // TODO: handle format specifiers by calling a method + // (`Type::format`?) that handles the `__format__` method. + // Conversion flags should be handled before calling `__format__`. + // https://docs.python.org/3/library/string.html#format-string-syntax + if !conversion.is_none() || format_spec.is_some() { + collector.add_expression(); + } else { + if let Type::StringLiteral(literal) = ty.str(self.db) { + collector.push_str(literal.value(self.db)); } else { - if let Some(Type::StringLiteral(literal)) = ty.str(self.db) - { - concatenated.push_str(literal.value(self.db)); - } else { - has_expression = true; - done = true; - } + collector.add_expression(); } } } ast::FStringElement::Literal(literal) => { - if !done { - concatenated.push_str(&literal.value); - } + collector.push_str(&literal.value); } } } } } - if concatenated.len() > Self::MAX_STRING_LITERAL_SIZE { - done = true; - } - } - - if has_expression { - builtins_symbol_ty(self.db, "str").to_instance(self.db) - } else { - if concatenated.len() <= Self::MAX_STRING_LITERAL_SIZE { - Type::StringLiteral(StringLiteralType::new( - self.db, - concatenated.into_boxed_str(), - )) - } else { - Type::LiteralString - } } + collector.ty(self.db) } fn infer_ellipsis_literal_expression( @@ -2682,6 +2659,53 @@ enum ModuleNameResolutionError { TooManyDots, } +/// Struct collecting string parts when inferring a formatted string. Infers a string literal if the +/// concatenated string is small enough, otherwise infers a literal string. +/// +/// If the formatted string contains an expression (with a representation unknown at compile time), +/// infers an instance of `builtins.str`. +struct StringPartsCollector { + concatenated: Option, + expression: bool, +} + +impl StringPartsCollector { + fn new() -> Self { + Self { + concatenated: Some(String::new()), + expression: false, + } + } + + fn push_str(&mut self, literal: &str) { + if let Some(mut concatenated) = self.concatenated.take() { + if concatenated.len().saturating_add(literal.len()) + <= TypeInferenceBuilder::MAX_STRING_LITERAL_SIZE + { + concatenated.push_str(literal); + self.concatenated = Some(concatenated); + } else { + self.concatenated = None; + } + } + } + + fn add_expression(&mut self) { + self.concatenated = None; + self.expression = true; + } + + fn ty(self, db: &dyn Db) -> Type { + if self.expression { + Type::builtin_str(db).to_instance(db) + } else if let Some(concatenated) = self.concatenated { + Type::StringLiteral(StringLiteralType::new(db, concatenated.into_boxed_str())) + } else { + Type::LiteralString + } + } +} + #[cfg(test)] mod tests {