Skip to content

Commit

Permalink
feat(engine): add ast_destruct module
Browse files Browse the repository at this point in the history
Fixes #941
  • Loading branch information
W95Psp committed Oct 3, 2024
1 parent bf1f3d8 commit a7ddc69
Show file tree
Hide file tree
Showing 8 changed files with 176 additions and 96 deletions.
13 changes: 13 additions & 0 deletions engine/lib/ast_destruct.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
open! Prelude
open! Ast

module Make (F : Features.T) = struct
include Ast_destruct_generated.Make (F)

let list_0 = function [] -> Some () | _ -> None
let list_1 = function [ a ] -> Some a | _ -> None
let list_2 = function [ a; b ] -> Some (a, b) | _ -> None
let list_3 = function [ a; b; c ] -> Some (a, b, c) | _ -> None
let list_4 = function [ a; b; c; d ] -> Some (a, b, c, d) | _ -> None
let list_5 = function [ a; b; c; d; e ] -> Some (a, b, c, d, e) | _ -> None
end
1 change: 1 addition & 0 deletions engine/lib/ast_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ module Make (F : Features.T) = struct
module TypedLocalIdent = TypedLocalIdent (AST)
module Visitors = Ast_visitors.Make (F)
module M = Ast_builder.Make (F)
module D = Ast_destruct.Make (F)

module Expect = struct
let mut_borrow (e : expr) : expr option =
Expand Down
11 changes: 11 additions & 0 deletions engine/lib/dune
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,17 @@
%{ast}
(run generate_from_ast visitors)))))

(rule
(target ast_destruct_generated.ml)
(deps
(:ast ast.ml))
(action
(with-stdout-to
ast_destruct_generated.ml
(with-stdin-from
%{ast}
(run generate_from_ast ast_destruct)))))

(rule
(target ast_builder_generated.ml)
(deps
Expand Down
114 changes: 28 additions & 86 deletions engine/lib/phases/phase_reconstruct_asserts.ml
Original file line number Diff line number Diff line change
Expand Up @@ -19,92 +19,34 @@ module Make (F : Features.T) =
inherit [_] Visitors.map as super

method! visit_expr () e =
match e with
| {
e =
If
{
cond;
then_ =
{
e =
( App
{
f = { e = GlobalVar nta; _ };
args =
[
{
e =
Let
{
body =
{
e =
Block
{
e =
{
e =
App
{
f =
{
e =
GlobalVar
panic;
_;
};
_;
};
_;
};
_;
};
_;
};
_;
};
_;
};
];
_;
}
| Block
{
e =
{
e =
App
{
f = { e = GlobalVar nta; _ };
args =
[
{
e =
App
{
f =
{
e = GlobalVar panic;
_;
};
_;
};
_;
};
];
_;
};
_;
};
_;
} );
_;
};
_;
};
_;
}
let extract_block e =
let* { e; _ } = U.D.expr_Block e in
let* { f; args; _ } = U.D.expr_App e in
let* nta = U.D.expr_GlobalVar f in
match args with
| [ { e = App { f = { e = GlobalVar panic; _ }; _ }; _ } ] ->
Some (nta, panic)
| _ -> None
in
let extract_app e =
let* { f; args; _ } = U.D.expr_App e in
let* nta = U.D.expr_GlobalVar f in
let* arg = U.D.list_1 args in
let* { body; _ } = U.D.expr_Let arg in
let* { e; _ } = U.D.expr_Block body in
let* { f; _ } = U.D.expr_App e in
let* panic = U.D.expr_GlobalVar f in
Some (nta, panic)
in
let extract e =
let* { cond; then_; _ } = U.D.expr_If e in
let* nta, panic =
extract_app then_ <|> fun _ -> extract_block then_
in
Some (panic, nta, cond)
in
match extract e with
| Some (panic, nta, cond)
when Ast.Global_ident.eq_name Rust_primitives__hax__never_to_any
nta
&& (Ast.Global_ident.eq_name Core__panicking__panic panic
Expand Down
1 change: 1 addition & 0 deletions engine/lib/utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ let sequence (l : 'a option list) : 'a list option =
match (acc, x) with Some acc, Some x -> Some (x :: acc) | _ -> None)
~init:(Some []) l

let ( <|> ) x f = match x with Some x -> Some x | None -> f ()
let tabsize = 2
let newline_indent depth : string = "\n" ^ String.make (tabsize * depth) ' '

Expand Down
2 changes: 2 additions & 0 deletions engine/utils/generate_from_ast/codegen_ast_builder.ml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ let mk datatypes =
(find "pat", find "pat'");
(find "item", find "item'");
(find "guard", find "guard'");
(find "trait_item", find "trait_item'");
(find "impl_expr", find "impl_expr_kind");
]
in
let body = data |> List.map ~f:(mk_builder []) |> String.concat ~sep:"\n\n" in
Expand Down
105 changes: 105 additions & 0 deletions engine/utils/generate_from_ast/codegen_ast_destruct.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
open Base
open Utils
open Types

let rec print_ty (t : Type.t) =
if String.is_prefix t.typ ~prefix:"prim___tuple_" then
"(" ^ String.concat ~sep:" * " (List.map t.args ~f:print_ty) ^ ")"
else
"("
^ (if List.is_empty t.args then ""
else "(" ^ String.concat ~sep:", " (List.map t.args ~f:print_ty) ^ ") ")
^ t.typ ^ ")"

let print_record_or_tuple is_record x =
let l, sep, r = if is_record then ("{", ";", "}") else ("(", ",", ")") in
l ^ String.concat ~sep (List.map ~f:fst x) ^ r

let print_record = print_record_or_tuple true
let print_tuple = print_record_or_tuple false

let print_record_type_or_tuple is_record x =
let l, sep, r = if is_record then ("{", ";", "}") else ("(", "*", ")") in
l
^ String.concat ~sep
(List.map
~f:(fun (name, ty) ->
(if is_record then name ^ ":" else "") ^ print_ty ty)
x)
^ r

let print_record_type = print_record_type_or_tuple true

let print_tuple_type =
List.map ~f:(fun ty -> ("", ty)) >> print_record_type_or_tuple false

let mk_builder ((record, enum) : Datatype.t * Datatype.t) =
let ty = record.name in
let record, variants =
match (record.kind, enum.kind) with
| Record record, Variant variants -> (record, variants)
| _ -> failwith "mk_builder: bad kinds of datatypes"
in
let field_name_raw, _ =
List.find ~f:(fun (_, ty) -> [%eq: string] ty.Type.typ enum.name) record
|> Option.value_exn
in
List.map
~f:(fun Variant.{ name; payload } ->
let id = ty ^ "_" ^ name in
let inline_record = id in
let type_decl =
"\ntype " ^ inline_record ^ " = "
^
match payload with
| Record record -> print_record_type record
| Tuple types -> types |> print_tuple_type
| None -> "unit"
in
let head =
"\nlet " ^ id ^ " (value: " ^ ty ^ ")" ^ ": " ^ inline_record
^ " option ="
in
let spayload =
match payload with
| Record record -> print_record record
| Tuple types ->
List.mapi ~f:(fun i ty -> ("x" ^ Int.to_string i, ty)) types
|> print_tuple
| None -> ""
in
type_decl ^ head ^ "\n match value." ^ field_name_raw ^ " with\n | "
^ name ^ " " ^ spayload ^ " -> Some "
^ (if String.is_empty spayload then "()" else spayload)
^ if List.length variants |> [%eq: int] 1 then "" else "\n | _ -> None")
variants
|> String.concat ~sep:"\n\n"

let mk datatypes =
let find name =
List.find ~f:(fun dt -> [%eq: string] dt.Datatype.name name) datatypes
|> Option.value_exn
in
let data =
[
(find "expr", find "expr'");
(find "pat", find "pat'");
(find "item", find "item'");
(find "guard", find "guard'");
(find "trait_item", find "trait_item'");
(find "impl_expr", find "impl_expr_kind");
]
in
let body = data |> List.map ~f:mk_builder |> String.concat ~sep:"\n\n" in
{|
open! Prelude
open! Ast

module Make (F : Features.T) = struct
open Ast.Make(F)

|}
^ body ^ {|

end
|}
25 changes: 15 additions & 10 deletions engine/utils/generate_from_ast/generate_from_ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,18 @@ let _main =
| _ -> None)
in

datatypes
|> (match Sys.get_argv () with
| [| _; "visitors" |] -> Codegen_visitor.mk
| [| _; "ast_builder" |] -> Codegen_ast_builder.mk
| [| _; "json" |] ->
[%yojson_of: Datatype.t list] >> Yojson.Safe.pretty_to_string
| [| _; verb |] ->
failwith ("`generate_from_ast`: unknown action `" ^ verb ^ "`")
| _ -> failwith "`generate_from_ast`: expected one argument")
|> Stdio.print_endline
let data =
datatypes
|>
match Sys.get_argv () with
| [| _; "visitors" |] -> Codegen_visitor.mk
| [| _; "ast_builder" |] -> Codegen_ast_builder.mk
| [| _; "ast_destruct" |] -> Codegen_ast_destruct.mk
| [| _; "json" |] ->
[%yojson_of: Datatype.t list] >> Yojson.Safe.pretty_to_string
| [| _; verb |] ->
failwith ("`generate_from_ast`: unknown action `" ^ verb ^ "`")
| _ -> failwith "`generate_from_ast`: expected one argument"
in
(* Stdio.Out_channel.write_all "/tmp/debug-generated-code.ml" ~data; *)
Stdio.print_endline data

0 comments on commit a7ddc69

Please sign in to comment.