From 29177339bc8115d1ffba348f2b7c5be0d5a8df8a Mon Sep 17 00:00:00 2001 From: Erin Moore Date: Wed, 29 Jan 2025 10:45:07 -0800 Subject: [PATCH] Add Lambda nodes to the ast and start parsing. github.com/google/xls/issues/1671 PiperOrigin-RevId: 721032937 --- xls/dslx/bytecode/bytecode_emitter.cc | 4 ++ xls/dslx/bytecode/bytecode_emitter.h | 1 + xls/dslx/constexpr_evaluator.cc | 4 ++ xls/dslx/constexpr_evaluator.h | 1 + xls/dslx/frontend/ast.cc | 40 +++++++++++++++++ xls/dslx/frontend/ast.h | 43 +++++++++++++++++++ xls/dslx/frontend/ast_cloner.cc | 4 ++ xls/dslx/frontend/ast_node.h | 1 + xls/dslx/frontend/parser.cc | 30 +++++++++++++ xls/dslx/frontend/parser.h | 2 + xls/dslx/frontend/parser_test.cc | 42 ++++++++++++++++++ .../ir_convert/extract_conversion_order.cc | 4 ++ xls/dslx/ir_convert/function_converter.cc | 9 +++- xls/dslx/ir_convert/function_converter.h | 1 + xls/dslx/type_system/deduce.cc | 6 +++ xls/dslx/type_system/type_info.proto | 1 + xls/dslx/type_system/type_info_to_proto.cc | 4 ++ 17 files changed, 195 insertions(+), 2 deletions(-) diff --git a/xls/dslx/bytecode/bytecode_emitter.cc b/xls/dslx/bytecode/bytecode_emitter.cc index 5ff54b15df..850b3f8472 100644 --- a/xls/dslx/bytecode/bytecode_emitter.cc +++ b/xls/dslx/bytecode/bytecode_emitter.cc @@ -1307,6 +1307,10 @@ absl::Status BytecodeEmitter::DestructureLet( return absl::OkStatus(); } +absl::Status BytecodeEmitter::HandleLambda(const Lambda* node) { + return absl::UnimplementedError("lambdas not yet supported"); +} + absl::Status BytecodeEmitter::HandleLet(const Let* node) { XLS_RETURN_IF_ERROR(node->rhs()->AcceptExpr(this)); std::optional type = type_info_->GetItem(node->rhs()); diff --git a/xls/dslx/bytecode/bytecode_emitter.h b/xls/dslx/bytecode/bytecode_emitter.h index 0edad791c8..dee4680e58 100644 --- a/xls/dslx/bytecode/bytecode_emitter.h +++ b/xls/dslx/bytecode/bytecode_emitter.h @@ -119,6 +119,7 @@ class BytecodeEmitter : public ExprVisitor { absl::Status HandleWidthSlice(const Index* node, WidthSlice* width_slice); absl::Status HandleInvocation(const Invocation* node) override; + absl::Status HandleLambda(const Lambda* node) override; absl::Status HandleLet(const Let* node) override; absl::Status HandleMatch(const Match* node) override; absl::Status HandleNameRef(const NameRef* node) override; diff --git a/xls/dslx/constexpr_evaluator.cc b/xls/dslx/constexpr_evaluator.cc index 325c67be07..1e2c12b80e 100644 --- a/xls/dslx/constexpr_evaluator.cc +++ b/xls/dslx/constexpr_evaluator.cc @@ -432,6 +432,10 @@ absl::Status ConstexprEvaluator::HandleInvocation(const Invocation* expr) { return absl::OkStatus(); } +absl::Status ConstexprEvaluator::HandleLambda(const Lambda* expr) { + return absl::OkStatus(); +} + absl::Status ConstexprEvaluator::HandleMatch(const Match* expr) { EVAL_AS_CONSTEXPR_OR_RETURN(expr->matched()); diff --git a/xls/dslx/constexpr_evaluator.h b/xls/dslx/constexpr_evaluator.h index 37bcdc0d52..a2400ae1a6 100644 --- a/xls/dslx/constexpr_evaluator.h +++ b/xls/dslx/constexpr_evaluator.h @@ -83,6 +83,7 @@ class ConstexprEvaluator : public xls::dslx::ExprVisitor { absl::Status HandleAllOnesMacro(const AllOnesMacro* expr) override; absl::Status HandleIndex(const Index* expr) override; absl::Status HandleInvocation(const Invocation* expr) override; + absl::Status HandleLambda(const Lambda* expr) override; absl::Status HandleLet(const Let* expr) override { return absl::OkStatus(); } absl::Status HandleMatch(const Match* expr) override; absl::Status HandleNameRef(const NameRef* expr) override; diff --git a/xls/dslx/frontend/ast.cc b/xls/dslx/frontend/ast.cc index 3b62233778..f994644b5f 100644 --- a/xls/dslx/frontend/ast.cc +++ b/xls/dslx/frontend/ast.cc @@ -306,6 +306,8 @@ std::string_view AstNodeKindToString(AstNodeKind kind) { return "cast"; case AstNodeKind::kConstantDef: return "constant definition"; + case AstNodeKind::kLambda: + return "lambda"; case AstNodeKind::kLet: return "let"; case AstNodeKind::kChannelDecl: @@ -2088,6 +2090,44 @@ std::vector Function::GetFreeParametricKeys() const { TestFunction::~TestFunction() = default; +// -- class Lambda + +Lambda::Lambda(Module* owner, Span span, std::vector params, + TypeAnnotation* return_type, StatementBlock* body) + : Expr(owner, std::move(span)), + params_(std::move(params)), + return_type_(return_type), + body_(body) {} + +Lambda::~Lambda() = default; + +std::vector Lambda::GetChildren(bool want_types) const { + std::vector results; + for (Param* p : params()) { + results.push_back(p); + } + if (return_type_ != nullptr && want_types) { + results.push_back(return_type_); + } + results.push_back(body_); + return results; +} + +std::string Lambda::ToStringInternal() const { + std::string params_str = + absl::StrJoin(params(), ", ", [](std::string* out, Param* param) { + absl::StrAppend(out, param->ToString()); + }); + + std::string return_str = return_type_ != nullptr + ? absl::StrCat(" -> ", return_type_->ToString()) + : ""; + std::string body_str = + body_->size() > 1 ? body_->ToString() : body_->ToInlineString(); + + return absl::StrFormat("|%s|%s %s", params_str, return_str, body_str); +} + // -- class MatchArm MatchArm::MatchArm(Module* owner, Span span, std::vector patterns, diff --git a/xls/dslx/frontend/ast.h b/xls/dslx/frontend/ast.h index 0367acc2fe..e7cc49ab24 100644 --- a/xls/dslx/frontend/ast.h +++ b/xls/dslx/frontend/ast.h @@ -59,6 +59,7 @@ X(FunctionRef) \ X(Index) \ X(Invocation) \ + X(Lambda) \ X(Let) \ X(Match) \ X(NameRef) \ @@ -2084,6 +2085,48 @@ class Function : public AstNode { bool disable_format_ = false; }; +// A lambda expression. +// Syntax: `|[: ] { }` +// +// Parameter types and return type are optional. +// +// Example: `let squares = map(range(u32:0, u32:5), |x| { x * x });` +// +// Attributes: +// * params: The explicit parameters of the lambda. +// * return_type: The return type of the lambda. +// * body: The body of the lambda. +class Lambda : public Expr { + public: + Lambda(Module* owner, Span span, std::vector params, + TypeAnnotation* return_type, StatementBlock* body); + + ~Lambda() override; + + AstNodeKind kind() const override { return AstNodeKind::kLambda; } + absl::Status Accept(AstNodeVisitor* v) const override { + return v->HandleLambda(this); + } + absl::Status AcceptExpr(ExprVisitor* v) const override { + return v->HandleLambda(this); + } + std::string_view GetNodeTypeName() const override { return "Lambda"; } + std::vector GetChildren(bool want_types) const override; + + const std::vector& params() const { return params_; } + + private: + std::vector params_; + TypeAnnotation* return_type_; // May be null. + StatementBlock* body_; + + std::string ToStringInternal() const final; + + Precedence GetPrecedenceWithoutParens() const final { + return Precedence::kStrongest; + } +}; + // Represents a single arm in a match expression. // // Attributes: diff --git a/xls/dslx/frontend/ast_cloner.cc b/xls/dslx/frontend/ast_cloner.cc index a2bf0c2555..372bba41d9 100644 --- a/xls/dslx/frontend/ast_cloner.cc +++ b/xls/dslx/frontend/ast_cloner.cc @@ -471,6 +471,10 @@ class AstCloner : public AstNodeVisitor { return absl::OkStatus(); } + absl::Status HandleLambda(const Lambda* n) override { + return absl::UnimplementedError("lambdas not yet supported"); + } + absl::Status HandleLet(const Let* n) override { XLS_RETURN_IF_ERROR(VisitChildren(n)); diff --git a/xls/dslx/frontend/ast_node.h b/xls/dslx/frontend/ast_node.h index c291a51a93..4c8d9989a8 100644 --- a/xls/dslx/frontend/ast_node.h +++ b/xls/dslx/frontend/ast_node.h @@ -55,6 +55,7 @@ enum class AstNodeKind : uint8_t { kInstantiation, kInvocation, kJoin, + kLambda, kLet, kMatch, kMatchArm, diff --git a/xls/dslx/frontend/parser.cc b/xls/dslx/frontend/parser.cc index bcc23260b2..2b085f3cf6 100644 --- a/xls/dslx/frontend/parser.cc +++ b/xls/dslx/frontend/parser.cc @@ -238,6 +238,32 @@ absl::StatusOr Parser::ParseFunction( return f; } +// Lambda syntax: | [: ], ... | [-> ] { } +absl::StatusOr Parser::ParseLambda(Bindings& bindings) { + Pos start_pos = GetPos(); + VLOG(5) << "ParseLambda @ " << start_pos; + XLS_ASSIGN_OR_RETURN(const Token* peek, PeekToken()); + std::vector params; + if (peek->kind() == TokenKind::kBar) { + XLS_RETURN_IF_ERROR(DropTokenOrError(TokenKind::kBar)); + auto parse_param = [&] { return ParseParam(bindings); }; + XLS_ASSIGN_OR_RETURN(params, + ParseCommaSeq(parse_param, TokenKind::kBar)); + } else { + XLS_RETURN_IF_ERROR(DropTokenOrError(TokenKind::kDoubleBar)); + } + + XLS_ASSIGN_OR_RETURN(bool dropped_arrow, TryDropToken(TokenKind::kArrow)); + TypeAnnotation* return_type = nullptr; + if (dropped_arrow) { + XLS_ASSIGN_OR_RETURN(return_type, ParseTypeAnnotation(bindings)); + } + + XLS_ASSIGN_OR_RETURN(StatementBlock * body, ParseBlockExpression(bindings)); + return module_->Make(Span(start_pos, GetPos()), params, return_type, + body); +} + absl::Status Parser::ParseModuleAttribute() { XLS_RETURN_IF_ERROR(DropTokenOrError(TokenKind::kOBrack)); Span attribute_span; @@ -666,6 +692,10 @@ absl::StatusOr Parser::ParseExpression(Bindings& bindings, if (peek->kind() == TokenKind::kOBrace) { return ParseBlockExpression(bindings); } + if (peek->kind() == TokenKind::kBar || + peek->kind() == TokenKind::kDoubleBar) { + return ParseLambda(bindings); + } return ParseConditionalExpression(bindings, restrictions); } diff --git a/xls/dslx/frontend/parser.h b/xls/dslx/frontend/parser.h index 9d9f3b8e06..6f9b5a19e8 100644 --- a/xls/dslx/frontend/parser.h +++ b/xls/dslx/frontend/parser.h @@ -120,6 +120,8 @@ class Parser : public TokenParser { const Pos& start_pos, bool is_public, Bindings& bindings, absl::flat_hash_map* name_to_fn = nullptr); + absl::StatusOr ParseLambda(Bindings& bindings); + absl::StatusOr ParseProc(const Pos& start_pos, bool is_public, Bindings& bindings); diff --git a/xls/dslx/frontend/parser_test.cc b/xls/dslx/frontend/parser_test.cc index 71007eba44..7a7fd270f3 100644 --- a/xls/dslx/frontend/parser_test.cc +++ b/xls/dslx/frontend/parser_test.cc @@ -3354,6 +3354,48 @@ TEST_F(ParserTest, ParseParametricProcWithConstAssert) { EXPECT_EQ(p->ToString(), text); } +TEST_F(ParserTest, ParseMapWithLambdaParamAnnotation) { + RoundTrip(R"(const ARR = map(range(0, u16:5), |i: u16| { 2 * i });)"); +} + +TEST_F(ParserTest, LambdaInLetNoParams) { + RoundTrip(R"(fn uses_lambda(i: u32) -> u32 { + let X = || { u32:2 * i }; + X() +})"); +} + +TEST_F(ParserTest, LambdaInLetNoParamsWithReturn) { + RoundTrip(R"(fn uses_lambda(i: u32) -> u32 { + let X = || -> u32 { u32:2 * i }; + X() +})"); +} + +TEST_F(ParserTest, ParseMapWithLambdaCapture) { + RoundTrip(R"(const X = u16:3; +const ARR = map(range(0, u16:5), |i: u16| { X * i });)"); +} + +TEST_F(ParserTest, ParseMapWithLambdaWithReturnType) { + RoundTrip( + R"(const ARR = map(range(0, u16:5), |i: u16| -> u32 { 2 * i as u32 });)"); +} + +TEST_F(ParserTest, ParseMapWithLambdaMultilineBody) { + RoundTrip( + R"(const ARR = map(range(0, u16:5), |i: u16| { + let x = 2 * i as u32; + x +});)"); +} + +// TODO(https://github.com/google/xls/issues/1671): Support once +// `AnyTypeAnnotation` is available. +TEST_F(ParserTest, DISABLED_ParseMapWithLambdaNoParamAnnotation) { + RoundTrip(R"(const ARR = map(range(0, u16:5), |i| { 2 * i });)"); +} + TEST_F(ParserTest, ParseParametricInMapBuiltin) { constexpr std::string_view kProgram = R"( fn truncate(x: bits[IN]) -> bits[OUT] { diff --git a/xls/dslx/ir_convert/extract_conversion_order.cc b/xls/dslx/ir_convert/extract_conversion_order.cc index 017cae5742..e425b95ba8 100644 --- a/xls/dslx/ir_convert/extract_conversion_order.cc +++ b/xls/dslx/ir_convert/extract_conversion_order.cc @@ -359,6 +359,10 @@ class InvocationVisitor : public ExprVisitor { return absl::OkStatus(); } + absl::Status HandleLambda(const Lambda* expr) override { + return absl::UnimplementedError("lambdas not yet supported"); + } + absl::Status HandleLet(const Let* expr) override { XLS_RETURN_IF_ERROR(expr->rhs()->AcceptExpr(this)); return absl::OkStatus(); diff --git a/xls/dslx/ir_convert/function_converter.cc b/xls/dslx/ir_convert/function_converter.cc index ba18535b08..59517cb6f4 100644 --- a/xls/dslx/ir_convert/function_converter.cc +++ b/xls/dslx/ir_convert/function_converter.cc @@ -417,8 +417,7 @@ class FunctionConverterVisitor : public AstNodeVisitor { INVALID(Function) INVALID(Impl) INVALID(Import) - INVALID(Use) - INVALID(UseTreeEntry) + INVALID(Lambda) INVALID(Module) INVALID(Proc) INVALID(ProcMember) @@ -428,6 +427,8 @@ class FunctionConverterVisitor : public AstNodeVisitor { INVALID(StructMemberNode) INVALID(ProcDef) INVALID(TypeAlias) + INVALID(Use) + INVALID(UseTreeEntry) // keep-sorted end private: @@ -849,6 +850,10 @@ absl::Status FunctionConverter::HandleConstantDef(const ConstantDef* node) { return DefAlias(node->value(), /*to=*/node->name_def()); } +absl::Status FunctionConverter::HandleLambda(const Lambda* node) { + return absl::UnimplementedError("lambdas not yet supported"); +} + absl::Status FunctionConverter::HandleLet(const Let* node) { VLOG(5) << "FunctionConverter::HandleLet: `" << node->ToString() << "`; rhs: `" << node->rhs()->ToString() << "`"; diff --git a/xls/dslx/ir_convert/function_converter.h b/xls/dslx/ir_convert/function_converter.h index cc9f285d58..4f60cdb0ab 100644 --- a/xls/dslx/ir_convert/function_converter.h +++ b/xls/dslx/ir_convert/function_converter.h @@ -345,6 +345,7 @@ class FunctionConverter { absl::Status HandleFormatMacro(const FormatMacro* node); absl::Status HandleIndex(const Index* node); absl::Status HandleInvocation(const Invocation* node); + absl::Status HandleLambda(const Lambda* node); absl::Status HandleLet(const Let* node); absl::Status HandleMatch(const Match* node); absl::Status HandleRange(const Range* node); diff --git a/xls/dslx/type_system/deduce.cc b/xls/dslx/type_system/deduce.cc index 247b99e8f6..28b4308d41 100644 --- a/xls/dslx/type_system/deduce.cc +++ b/xls/dslx/type_system/deduce.cc @@ -467,6 +467,11 @@ absl::StatusOr> DeduceLet(const Let* node, return Type::MakeUnit(); } +absl::StatusOr> DeduceLambda(const Lambda* node, + DeduceCtx* ctx) { + return absl::UnimplementedError("lambdas not yet supported in type system"); +} + // The types that need to be deduced for `for`-like loops (including // `unroll_for!`). struct ForLoopTypes { @@ -2057,6 +2062,7 @@ class DeduceVisitor : public AstNodeVisitor { DEDUCE_DISPATCH(Binop, DeduceBinop) DEDUCE_DISPATCH(EnumDef, DeduceEnumDef) DEDUCE_DISPATCH(Let, DeduceLet) + DEDUCE_DISPATCH(Lambda, DeduceLambda) DEDUCE_DISPATCH(For, DeduceFor) DEDUCE_DISPATCH(Cast, DeduceCast) DEDUCE_DISPATCH(ConstAssert, DeduceConstAssert) diff --git a/xls/dslx/type_system/type_info.proto b/xls/dslx/type_system/type_info.proto index 50d829e526..4877c92c13 100644 --- a/xls/dslx/type_system/type_info.proto +++ b/xls/dslx/type_system/type_info.proto @@ -101,6 +101,7 @@ enum AstNodeKindProto { AST_NODE_KIND_USE = 67; AST_NODE_KIND_USE_TREE_ENTRY = 68; AST_NODE_KIND_STRUCT_MEMBER = 69; + AST_NODE_KIND_LAMBDA = 70; } message BitsValueProto { diff --git a/xls/dslx/type_system/type_info_to_proto.cc b/xls/dslx/type_system/type_info_to_proto.cc index 1ca1d7498b..0384b3598c 100644 --- a/xls/dslx/type_system/type_info_to_proto.cc +++ b/xls/dslx/type_system/type_info_to_proto.cc @@ -189,6 +189,8 @@ AstNodeKindProto ToProto(AstNodeKind kind) { return AST_NODE_KIND_USE; case AstNodeKind::kUseTreeEntry: return AST_NODE_KIND_USE_TREE_ENTRY; + case AstNodeKind::kLambda: + return AST_NODE_KIND_LAMBDA; } // Fatal since enum class values should not be out of range. LOG(FATAL) << "Out of range AstNodeKind: " << static_cast(kind); @@ -832,6 +834,8 @@ absl::StatusOr FromProto(AstNodeKindProto p) { return AstNodeKind::kUse; case AST_NODE_KIND_USE_TREE_ENTRY: return AstNodeKind::kUseTreeEntry; + case AST_NODE_KIND_LAMBDA: + return AstNodeKind::kLambda; // Note: since this is a proto enum there are sentinel values defined in // addition to the "real" above. Return an invalid argument error. case AST_NODE_KIND_INVALID: