Skip to content

Commit

Permalink
fixup! [red-knot] feat: implement and use repr/str for f-strings
Browse files Browse the repository at this point in the history
  • Loading branch information
Slyces committed Sep 27, 2024
1 parent 04ed1b5 commit 36a2fa4
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 85 deletions.
89 changes: 44 additions & 45 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,10 @@ impl<'db> Type<'db> {
}
}

pub fn builtin_str(db: &'db dyn Db) -> Self {
builtins_symbol_ty(db, "str")
}

pub fn is_stdlib_symbol(&self, db: &'db dyn Db, module_name: &str, name: &str) -> bool {
match self {
Type::Class(class) => class.is_stdlib_symbol(db, module_name, name),
Expand Down Expand Up @@ -723,41 +727,40 @@ impl<'db> Type<'db> {
}

/// Return the string representation of this type when converted to string as it would be
/// provided by the `__str__` method. If that can't be determined, return `None`.
/// provided by the `__str__` method.
///
/// When not available, this should fall back to the value of `[Type::repr]`.
/// Note: this method is used in the builtins `format`, `print`, `str.format` and `f-strings`.
pub fn str(&self, db: &'db dyn Db) -> Option<Type<'db>> {
let str_result = match self {
Type::IntLiteral(_) => None,
Type::BooleanLiteral(_) => None,
Type::StringLiteral(_) => Some(*self),
#[must_use]
pub fn str(&self, db: &'db dyn Db) -> Type<'db> {
match self {
Type::IntLiteral(_) | Type::BooleanLiteral(_) => self.repr(db),
Type::StringLiteral(_) | Type::LiteralString => *self,
// TODO: handle more complex types
_ => None,
};
str_result.or_else(|| self.repr(db))
_ => builtins_symbol_ty(db, "str").to_instance(db),
}
}

/// Return the string representation of this type as it would be provided by the `__repr__`
/// method at runtime. If that can't be determined, return `None`.
pub fn repr(&self, db: &'db dyn Db) -> Option<Type<'db>> {
/// method at runtime.
#[must_use]
pub fn repr(&self, db: &'db dyn Db) -> Type<'db> {
match self {
Type::IntLiteral(number) => Some(Type::StringLiteral(StringLiteralType::new(db, {
Type::IntLiteral(number) => Type::StringLiteral(StringLiteralType::new(db, {
number.to_string().into_boxed_str()
}))),
Type::BooleanLiteral(true) => Some(Type::StringLiteral(StringLiteralType::new(db, {
"True".into()
}))),
Type::BooleanLiteral(false) => Some(Type::StringLiteral(StringLiteralType::new(db, {
"False".into()
}))),
Type::StringLiteral(literal) => {
Some(Type::StringLiteral(StringLiteralType::new(db, {
format!("'{}'", literal.value(db)).into()
})))
})),
Type::BooleanLiteral(true) => {
Type::StringLiteral(StringLiteralType::new(db, "True".into()))
}
Type::BooleanLiteral(false) => {
Type::StringLiteral(StringLiteralType::new(db, "False".into()))
}
Type::StringLiteral(literal) => Type::StringLiteral(StringLiteralType::new(db, {
format!("'{}'", literal.value(db).escape_default()).into()
})),
Type::LiteralString => Type::LiteralString,
// TODO: handle more complex types
_ => None,
_ => builtins_symbol_ty(db, "str").to_instance(db),
}
}
}
Expand Down Expand Up @@ -1237,7 +1240,7 @@ mod tests {

/// A test representation of a type that can be transformed unambiguously into a real Type,
/// given a db.
#[derive(Debug)]
#[derive(Debug, Clone)]
enum Ty {
Never,
Unknown,
Expand Down Expand Up @@ -1373,31 +1376,27 @@ mod tests {
assert_eq!(ty.into_type(&db).bool(&db), Truthiness::Ambiguous);
}

#[test_case(Ty::IntLiteral(1), Some("1"))]
#[test_case(Ty::BoolLiteral(true), Some("True"))]
#[test_case(Ty::BoolLiteral(false), Some("False"))]
#[test_case(Ty::StringLiteral("hello"), Some("hello"))] // no quotes
#[test_case(Ty::LiteralString, None)]
#[test_case(Ty::BuiltinInstance("int"), None)]
fn has_correct_str(ty: Ty, expected: Option<&str>) {
#[test_case(Ty::IntLiteral(1), Ty::StringLiteral("1"))]
#[test_case(Ty::BoolLiteral(true), Ty::StringLiteral("True"))]
#[test_case(Ty::BoolLiteral(false), Ty::StringLiteral("False"))]
#[test_case(Ty::StringLiteral("ab'cd"), Ty::StringLiteral("ab'cd"))] // no quotes
#[test_case(Ty::LiteralString, Ty::LiteralString)]
#[test_case(Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"))]
fn has_correct_str(ty: Ty, expected: Ty) {
let db = setup_db();

let expected = expected.map(|s| Type::StringLiteral(StringLiteralType::new(&db, s.into())));

assert_eq!(ty.into_type(&db).str(&db), expected);
assert_eq!(ty.into_type(&db).str(&db), expected.into_type(&db));
}

#[test_case(Ty::IntLiteral(1), Some("1"))]
#[test_case(Ty::BoolLiteral(true), Some("True"))]
#[test_case(Ty::BoolLiteral(false), Some("False"))]
#[test_case(Ty::StringLiteral("hello"), Some("'hello'"))] // single quotes
#[test_case(Ty::LiteralString, None)]
#[test_case(Ty::BuiltinInstance("int"), None)]
fn has_correct_repr(ty: Ty, expected: Option<&str>) {
#[test_case(Ty::IntLiteral(1), Ty::StringLiteral("1"))]
#[test_case(Ty::BoolLiteral(true), Ty::StringLiteral("True"))]
#[test_case(Ty::BoolLiteral(false), Ty::StringLiteral("False"))]
#[test_case(Ty::StringLiteral("ab'cd"), Ty::StringLiteral("'ab\\'cd'"))] // single quotes
#[test_case(Ty::LiteralString, Ty::LiteralString)]
#[test_case(Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"))]
fn has_correct_repr(ty: Ty, expected: Ty) {
let db = setup_db();

let expected = expected.map(|s| Type::StringLiteral(StringLiteralType::new(&db, s.into())));

assert_eq!(ty.into_type(&db).repr(&db), expected);
assert_eq!(ty.into_type(&db).repr(&db), expected.into_type(&db));
}
}
104 changes: 64 additions & 40 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1653,13 +1653,13 @@ impl<'db> TypeInferenceBuilder<'db> {
fn infer_fstring_expression(&mut self, fstring: &ast::ExprFString) -> Type<'db> {
let ast::ExprFString { range: _, value } = fstring;

let mut done = false;
let mut has_expression = false;
let mut concatenated = String::new();
let mut collector = StringPartsCollector::new();
for part in value {
// Make sure we iter through every parts to infer all sub-expressions. The `collector`
// struct ensures we don't allocate unnecessary strings.
match part {
ast::FStringPart::Literal(literal) => {
concatenated.push_str(&literal.value);
collector.push_str(&literal.value);
}
ast::FStringPart::FString(fstring) => {
for element in &fstring.elements {
Expand All @@ -1672,54 +1672,31 @@ impl<'db> TypeInferenceBuilder<'db> {
conversion,
format_spec,
} = expression;
// Always infer sub-expressions, even if we've figured out the type
let ty = self.infer_expression(expression);
if !done {
// TODO: handle format specifiers by calling a method
// (`Type::format`?) that handles the `__format__` method.
// Conversion flags should be handled before calling
// `__format__`.
// https://docs.python.org/3/library/string.html#format-string-syntax
if !conversion.is_none() || format_spec.is_some() {
has_expression = true;
done = true;

// TODO: handle format specifiers by calling a method
// (`Type::format`?) that handles the `__format__` method.
// Conversion flags should be handled before calling `__format__`.
// https://docs.python.org/3/library/string.html#format-string-syntax
if !conversion.is_none() || format_spec.is_some() {
collector.add_expression();
} else {
if let Type::StringLiteral(literal) = ty.str(self.db) {
collector.push_str(literal.value(self.db));
} else {
if let Some(Type::StringLiteral(literal)) = ty.str(self.db)
{
concatenated.push_str(literal.value(self.db));
} else {
has_expression = true;
done = true;
}
collector.add_expression();
}
}
}
ast::FStringElement::Literal(literal) => {
if !done {
concatenated.push_str(&literal.value);
}
collector.push_str(&literal.value);
}
}
}
}
}
if concatenated.len() > Self::MAX_STRING_LITERAL_SIZE {
done = true;
}
}

if has_expression {
builtins_symbol_ty(self.db, "str").to_instance(self.db)
} else {
if concatenated.len() <= Self::MAX_STRING_LITERAL_SIZE {
Type::StringLiteral(StringLiteralType::new(
self.db,
concatenated.into_boxed_str(),
))
} else {
Type::LiteralString
}
}
collector.ty(self.db)
}

fn infer_ellipsis_literal_expression(
Expand Down Expand Up @@ -2682,6 +2659,53 @@ enum ModuleNameResolutionError {
TooManyDots,
}

/// Struct collecting string parts when inferring a formatted string. Infers a string literal if the
/// concatenated string is small enough, otherwise infers a literal string.
///
/// If the formatted string contains an expression (with a representation unknown at compile time),
/// infers an instance of `builtins.str`.
struct StringPartsCollector {
concatenated: Option<String>,
expression: bool,
}

impl StringPartsCollector {
fn new() -> Self {
Self {
concatenated: Some(String::new()),
expression: false,
}
}

fn push_str(&mut self, literal: &str) {
if let Some(mut concatenated) = self.concatenated.take() {
if concatenated.len().saturating_add(literal.len())
<= TypeInferenceBuilder::MAX_STRING_LITERAL_SIZE
{
concatenated.push_str(literal);
self.concatenated = Some(concatenated);
} else {
self.concatenated = None;
}
}
}

fn add_expression(&mut self) {
self.concatenated = None;
self.expression = true;
}

fn ty(self, db: &dyn Db) -> Type {
if self.expression {
Type::builtin_str(db).to_instance(db)
} else if let Some(concatenated) = self.concatenated {
Type::StringLiteral(StringLiteralType::new(db, concatenated.into_boxed_str()))
} else {
Type::LiteralString
}
}
}

#[cfg(test)]
mod tests {

Expand Down

0 comments on commit 36a2fa4

Please sign in to comment.