Skip to content

Commit

Permalink
Add Lambda nodes to the ast and start parsing.
Browse files Browse the repository at this point in the history
github.com//issues/1671

PiperOrigin-RevId: 721032937
  • Loading branch information
erinzmoore authored and copybara-github committed Jan 29, 2025
1 parent e109d24 commit 2917733
Show file tree
Hide file tree
Showing 17 changed files with 195 additions and 2 deletions.
4 changes: 4 additions & 0 deletions xls/dslx/bytecode/bytecode_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1307,6 +1307,10 @@ absl::Status BytecodeEmitter::DestructureLet(
return absl::OkStatus();
}

absl::Status BytecodeEmitter::HandleLambda(const Lambda* node) {
return absl::UnimplementedError("lambdas not yet supported");
}

absl::Status BytecodeEmitter::HandleLet(const Let* node) {
XLS_RETURN_IF_ERROR(node->rhs()->AcceptExpr(this));
std::optional<Type*> type = type_info_->GetItem(node->rhs());
Expand Down
1 change: 1 addition & 0 deletions xls/dslx/bytecode/bytecode_emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ class BytecodeEmitter : public ExprVisitor {
absl::Status HandleWidthSlice(const Index* node, WidthSlice* width_slice);

absl::Status HandleInvocation(const Invocation* node) override;
absl::Status HandleLambda(const Lambda* node) override;
absl::Status HandleLet(const Let* node) override;
absl::Status HandleMatch(const Match* node) override;
absl::Status HandleNameRef(const NameRef* node) override;
Expand Down
4 changes: 4 additions & 0 deletions xls/dslx/constexpr_evaluator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,10 @@ absl::Status ConstexprEvaluator::HandleInvocation(const Invocation* expr) {
return absl::OkStatus();
}

absl::Status ConstexprEvaluator::HandleLambda(const Lambda* expr) {
return absl::OkStatus();
}

absl::Status ConstexprEvaluator::HandleMatch(const Match* expr) {
EVAL_AS_CONSTEXPR_OR_RETURN(expr->matched());

Expand Down
1 change: 1 addition & 0 deletions xls/dslx/constexpr_evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class ConstexprEvaluator : public xls::dslx::ExprVisitor {
absl::Status HandleAllOnesMacro(const AllOnesMacro* expr) override;
absl::Status HandleIndex(const Index* expr) override;
absl::Status HandleInvocation(const Invocation* expr) override;
absl::Status HandleLambda(const Lambda* expr) override;
absl::Status HandleLet(const Let* expr) override { return absl::OkStatus(); }
absl::Status HandleMatch(const Match* expr) override;
absl::Status HandleNameRef(const NameRef* expr) override;
Expand Down
40 changes: 40 additions & 0 deletions xls/dslx/frontend/ast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,8 @@ std::string_view AstNodeKindToString(AstNodeKind kind) {
return "cast";
case AstNodeKind::kConstantDef:
return "constant definition";
case AstNodeKind::kLambda:
return "lambda";
case AstNodeKind::kLet:
return "let";
case AstNodeKind::kChannelDecl:
Expand Down Expand Up @@ -2088,6 +2090,44 @@ std::vector<std::string> Function::GetFreeParametricKeys() const {

TestFunction::~TestFunction() = default;

// -- class Lambda

Lambda::Lambda(Module* owner, Span span, std::vector<Param*> params,
TypeAnnotation* return_type, StatementBlock* body)
: Expr(owner, std::move(span)),
params_(std::move(params)),
return_type_(return_type),
body_(body) {}

Lambda::~Lambda() = default;

std::vector<AstNode*> Lambda::GetChildren(bool want_types) const {
std::vector<AstNode*> results;
for (Param* p : params()) {
results.push_back(p);
}
if (return_type_ != nullptr && want_types) {
results.push_back(return_type_);
}
results.push_back(body_);
return results;
}

std::string Lambda::ToStringInternal() const {
std::string params_str =
absl::StrJoin(params(), ", ", [](std::string* out, Param* param) {
absl::StrAppend(out, param->ToString());
});

std::string return_str = return_type_ != nullptr
? absl::StrCat(" -> ", return_type_->ToString())
: "";
std::string body_str =
body_->size() > 1 ? body_->ToString() : body_->ToInlineString();

return absl::StrFormat("|%s|%s %s", params_str, return_str, body_str);
}

// -- class MatchArm

MatchArm::MatchArm(Module* owner, Span span, std::vector<NameDefTree*> patterns,
Expand Down
43 changes: 43 additions & 0 deletions xls/dslx/frontend/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
X(FunctionRef) \
X(Index) \
X(Invocation) \
X(Lambda) \
X(Let) \
X(Match) \
X(NameRef) \
Expand Down Expand Up @@ -2084,6 +2085,48 @@ class Function : public AstNode {
bool disable_format_ = false;
};

// A lambda expression.
// Syntax: `|<PARAM>[: <TYPE], ... | [-> <RETURN_TYPE>] { <BODY> }`
//
// Parameter types and return type are optional.
//
// Example: `let squares = map(range(u32:0, u32:5), |x| { x * x });`
//
// Attributes:
// * params: The explicit parameters of the lambda.
// * return_type: The return type of the lambda.
// * body: The body of the lambda.
class Lambda : public Expr {
public:
Lambda(Module* owner, Span span, std::vector<Param*> params,
TypeAnnotation* return_type, StatementBlock* body);

~Lambda() override;

AstNodeKind kind() const override { return AstNodeKind::kLambda; }
absl::Status Accept(AstNodeVisitor* v) const override {
return v->HandleLambda(this);
}
absl::Status AcceptExpr(ExprVisitor* v) const override {
return v->HandleLambda(this);
}
std::string_view GetNodeTypeName() const override { return "Lambda"; }
std::vector<AstNode*> GetChildren(bool want_types) const override;

const std::vector<Param*>& params() const { return params_; }

private:
std::vector<Param*> params_;
TypeAnnotation* return_type_; // May be null.
StatementBlock* body_;

std::string ToStringInternal() const final;

Precedence GetPrecedenceWithoutParens() const final {
return Precedence::kStrongest;
}
};

// Represents a single arm in a match expression.
//
// Attributes:
Expand Down
4 changes: 4 additions & 0 deletions xls/dslx/frontend/ast_cloner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,10 @@ class AstCloner : public AstNodeVisitor {
return absl::OkStatus();
}

absl::Status HandleLambda(const Lambda* n) override {
return absl::UnimplementedError("lambdas not yet supported");
}

absl::Status HandleLet(const Let* n) override {
XLS_RETURN_IF_ERROR(VisitChildren(n));

Expand Down
1 change: 1 addition & 0 deletions xls/dslx/frontend/ast_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ enum class AstNodeKind : uint8_t {
kInstantiation,
kInvocation,
kJoin,
kLambda,
kLet,
kMatch,
kMatchArm,
Expand Down
30 changes: 30 additions & 0 deletions xls/dslx/frontend/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,32 @@ absl::StatusOr<Function*> Parser::ParseFunction(
return f;
}

// Lambda syntax: | <PARAM>[: <TYPE>], ... | [-> <RETURN_TYPE>] { <BODY> }
absl::StatusOr<Lambda*> Parser::ParseLambda(Bindings& bindings) {
Pos start_pos = GetPos();
VLOG(5) << "ParseLambda @ " << start_pos;
XLS_ASSIGN_OR_RETURN(const Token* peek, PeekToken());
std::vector<Param*> params;
if (peek->kind() == TokenKind::kBar) {
XLS_RETURN_IF_ERROR(DropTokenOrError(TokenKind::kBar));
auto parse_param = [&] { return ParseParam(bindings); };
XLS_ASSIGN_OR_RETURN(params,
ParseCommaSeq<Param*>(parse_param, TokenKind::kBar));
} else {
XLS_RETURN_IF_ERROR(DropTokenOrError(TokenKind::kDoubleBar));
}

XLS_ASSIGN_OR_RETURN(bool dropped_arrow, TryDropToken(TokenKind::kArrow));
TypeAnnotation* return_type = nullptr;
if (dropped_arrow) {
XLS_ASSIGN_OR_RETURN(return_type, ParseTypeAnnotation(bindings));
}

XLS_ASSIGN_OR_RETURN(StatementBlock * body, ParseBlockExpression(bindings));
return module_->Make<Lambda>(Span(start_pos, GetPos()), params, return_type,
body);
}

absl::Status Parser::ParseModuleAttribute() {
XLS_RETURN_IF_ERROR(DropTokenOrError(TokenKind::kOBrack));
Span attribute_span;
Expand Down Expand Up @@ -666,6 +692,10 @@ absl::StatusOr<Expr*> Parser::ParseExpression(Bindings& bindings,
if (peek->kind() == TokenKind::kOBrace) {
return ParseBlockExpression(bindings);
}
if (peek->kind() == TokenKind::kBar ||
peek->kind() == TokenKind::kDoubleBar) {
return ParseLambda(bindings);
}
return ParseConditionalExpression(bindings, restrictions);
}

Expand Down
2 changes: 2 additions & 0 deletions xls/dslx/frontend/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ class Parser : public TokenParser {
const Pos& start_pos, bool is_public, Bindings& bindings,
absl::flat_hash_map<std::string, Function*>* name_to_fn = nullptr);

absl::StatusOr<Lambda*> ParseLambda(Bindings& bindings);

absl::StatusOr<ModuleMember> ParseProc(const Pos& start_pos, bool is_public,
Bindings& bindings);

Expand Down
42 changes: 42 additions & 0 deletions xls/dslx/frontend/parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3354,6 +3354,48 @@ TEST_F(ParserTest, ParseParametricProcWithConstAssert) {
EXPECT_EQ(p->ToString(), text);
}

TEST_F(ParserTest, ParseMapWithLambdaParamAnnotation) {
RoundTrip(R"(const ARR = map(range(0, u16:5), |i: u16| { 2 * i });)");
}

TEST_F(ParserTest, LambdaInLetNoParams) {
RoundTrip(R"(fn uses_lambda(i: u32) -> u32 {
let X = || { u32:2 * i };
X()
})");
}

TEST_F(ParserTest, LambdaInLetNoParamsWithReturn) {
RoundTrip(R"(fn uses_lambda(i: u32) -> u32 {
let X = || -> u32 { u32:2 * i };
X()
})");
}

TEST_F(ParserTest, ParseMapWithLambdaCapture) {
RoundTrip(R"(const X = u16:3;
const ARR = map(range(0, u16:5), |i: u16| { X * i });)");
}

TEST_F(ParserTest, ParseMapWithLambdaWithReturnType) {
RoundTrip(
R"(const ARR = map(range(0, u16:5), |i: u16| -> u32 { 2 * i as u32 });)");
}

TEST_F(ParserTest, ParseMapWithLambdaMultilineBody) {
RoundTrip(
R"(const ARR = map(range(0, u16:5), |i: u16| {
let x = 2 * i as u32;
x
});)");
}

// TODO(https://github.com/google/xls/issues/1671): Support once
// `AnyTypeAnnotation` is available.
TEST_F(ParserTest, DISABLED_ParseMapWithLambdaNoParamAnnotation) {
RoundTrip(R"(const ARR = map(range(0, u16:5), |i| { 2 * i });)");
}

TEST_F(ParserTest, ParseParametricInMapBuiltin) {
constexpr std::string_view kProgram = R"(
fn truncate<OUT: u32, IN: u32>(x: bits[IN]) -> bits[OUT] {
Expand Down
4 changes: 4 additions & 0 deletions xls/dslx/ir_convert/extract_conversion_order.cc
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,10 @@ class InvocationVisitor : public ExprVisitor {
return absl::OkStatus();
}

absl::Status HandleLambda(const Lambda* expr) override {
return absl::UnimplementedError("lambdas not yet supported");
}

absl::Status HandleLet(const Let* expr) override {
XLS_RETURN_IF_ERROR(expr->rhs()->AcceptExpr(this));
return absl::OkStatus();
Expand Down
9 changes: 7 additions & 2 deletions xls/dslx/ir_convert/function_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,8 +417,7 @@ class FunctionConverterVisitor : public AstNodeVisitor {
INVALID(Function)
INVALID(Impl)
INVALID(Import)
INVALID(Use)
INVALID(UseTreeEntry)
INVALID(Lambda)
INVALID(Module)
INVALID(Proc)
INVALID(ProcMember)
Expand All @@ -428,6 +427,8 @@ class FunctionConverterVisitor : public AstNodeVisitor {
INVALID(StructMemberNode)
INVALID(ProcDef)
INVALID(TypeAlias)
INVALID(Use)
INVALID(UseTreeEntry)
// keep-sorted end

private:
Expand Down Expand Up @@ -849,6 +850,10 @@ absl::Status FunctionConverter::HandleConstantDef(const ConstantDef* node) {
return DefAlias(node->value(), /*to=*/node->name_def());
}

absl::Status FunctionConverter::HandleLambda(const Lambda* node) {
return absl::UnimplementedError("lambdas not yet supported");
}

absl::Status FunctionConverter::HandleLet(const Let* node) {
VLOG(5) << "FunctionConverter::HandleLet: `" << node->ToString()
<< "`; rhs: `" << node->rhs()->ToString() << "`";
Expand Down
1 change: 1 addition & 0 deletions xls/dslx/ir_convert/function_converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ class FunctionConverter {
absl::Status HandleFormatMacro(const FormatMacro* node);
absl::Status HandleIndex(const Index* node);
absl::Status HandleInvocation(const Invocation* node);
absl::Status HandleLambda(const Lambda* node);
absl::Status HandleLet(const Let* node);
absl::Status HandleMatch(const Match* node);
absl::Status HandleRange(const Range* node);
Expand Down
6 changes: 6 additions & 0 deletions xls/dslx/type_system/deduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,11 @@ absl::StatusOr<std::unique_ptr<Type>> DeduceLet(const Let* node,
return Type::MakeUnit();
}

absl::StatusOr<std::unique_ptr<Type>> DeduceLambda(const Lambda* node,
DeduceCtx* ctx) {
return absl::UnimplementedError("lambdas not yet supported in type system");
}

// The types that need to be deduced for `for`-like loops (including
// `unroll_for!`).
struct ForLoopTypes {
Expand Down Expand Up @@ -2057,6 +2062,7 @@ class DeduceVisitor : public AstNodeVisitor {
DEDUCE_DISPATCH(Binop, DeduceBinop)
DEDUCE_DISPATCH(EnumDef, DeduceEnumDef)
DEDUCE_DISPATCH(Let, DeduceLet)
DEDUCE_DISPATCH(Lambda, DeduceLambda)
DEDUCE_DISPATCH(For, DeduceFor)
DEDUCE_DISPATCH(Cast, DeduceCast)
DEDUCE_DISPATCH(ConstAssert, DeduceConstAssert)
Expand Down
1 change: 1 addition & 0 deletions xls/dslx/type_system/type_info.proto
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ enum AstNodeKindProto {
AST_NODE_KIND_USE = 67;
AST_NODE_KIND_USE_TREE_ENTRY = 68;
AST_NODE_KIND_STRUCT_MEMBER = 69;
AST_NODE_KIND_LAMBDA = 70;
}

message BitsValueProto {
Expand Down
4 changes: 4 additions & 0 deletions xls/dslx/type_system/type_info_to_proto.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ AstNodeKindProto ToProto(AstNodeKind kind) {
return AST_NODE_KIND_USE;
case AstNodeKind::kUseTreeEntry:
return AST_NODE_KIND_USE_TREE_ENTRY;
case AstNodeKind::kLambda:
return AST_NODE_KIND_LAMBDA;
}
// Fatal since enum class values should not be out of range.
LOG(FATAL) << "Out of range AstNodeKind: " << static_cast<int64_t>(kind);
Expand Down Expand Up @@ -832,6 +834,8 @@ absl::StatusOr<AstNodeKind> FromProto(AstNodeKindProto p) {
return AstNodeKind::kUse;
case AST_NODE_KIND_USE_TREE_ENTRY:
return AstNodeKind::kUseTreeEntry;
case AST_NODE_KIND_LAMBDA:
return AstNodeKind::kLambda;
// Note: since this is a proto enum there are sentinel values defined in
// addition to the "real" above. Return an invalid argument error.
case AST_NODE_KIND_INVALID:
Expand Down

0 comments on commit 2917733

Please sign in to comment.