Skip to content

Commit

Permalink
feat(engine): recognize iterator combinators and loop invariants
Browse files Browse the repository at this point in the history
  • Loading branch information
W95Psp committed Aug 19, 2024
1 parent 34b9977 commit 58ab7eb
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 11 deletions.
120 changes: 109 additions & 11 deletions engine/lib/phases/phase_functionalize_loops.ml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,95 @@ struct
include Features.SUBTYPE.Id
end

type body_and_invariant = { body : B.expr; invariant : B.expr option }

let extract_loop_invariant (body : B.expr) : body_and_invariant =
match body.e with
| Let
{
monadic = None;
lhs = { p = PWild; _ };
rhs =
{
e = App { f = { e = GlobalVar f; _ }; args = [ invariant ]; _ };
_;
};
body;
}
when Global_ident.eq_name Hax_lib__loop_invariant f ->
{ body; invariant = Some invariant }
| _ -> { body; invariant = None }

type iterator =
| Range of { start : B.expr; end_ : B.expr }
| Slice of B.expr
| ChunksExact of { size : B.expr; slice : B.expr }
| Enumerate of iterator
| StepBy of { n : B.expr; it : iterator }
[@@deriving show]

let rec as_iterator (e : B.expr) : iterator option =
match e.e with
| Construct
{
constructor = `Concrete range_ctor;
is_record = true;
is_struct = true;
fields =
[ (`Concrete start_field, start); (`Concrete end_field, end_) ];
base = None;
}
when Concrete_ident.eq_name Core__ops__range__Range__start start_field
&& Concrete_ident.eq_name Core__ops__range__Range range_ctor
&& Concrete_ident.eq_name Core__ops__range__Range__end end_field ->
Some (Range { start; end_ })
| _ -> meth_as_iterator e

and meth_as_iterator (e : B.expr) : iterator option =
let* f, args =
match e.e with
| App { f = { e = GlobalVar f; _ }; args; _ } -> Some (f, args)
| _ -> None
in
let f_eq n = Global_ident.eq_name n f in
let one_arg () = match args with [ x ] -> Some x | _ -> None in
let two_args () = match args with [ x; y ] -> Some (x, y) | _ -> None in
if f_eq Core__iter__traits__iterator__Iterator__step_by then
let* it, n = two_args () in
let* it = as_iterator it in
Some (StepBy { n; it })
else if
f_eq Core__iter__traits__collect__IntoIterator__into_iter
|| f_eq Core__slice__Impl__iter
then
let* iterable = one_arg () in
match iterable.typ with
| TSlice _ -> Some (Slice iterable)
| _ -> as_iterator iterable
else if f_eq Core__iter__traits__iterator__Iterator__enumerate then
let* iterable = one_arg () in
let* iterator = as_iterator iterable in
Some (Enumerate iterator)
else if f_eq Core__slice__Impl__chunks_exact then
let* slice, size = two_args () in
Some (ChunksExact { size; slice })
else None

let fn_args_of_iterator (it : iterator) :
(Concrete_ident.name * B.expr list) option =
let open Concrete_ident_generated in
match it with
| Enumerate (ChunksExact { size; slice }) ->
Some
( Rust_primitives__hax__folds__fold_enumerated_chunked_slice,
[ size; slice ] )
| Enumerate (Slice slice) ->
Some (Rust_primitives__hax__folds__fold_enumerated_slice, [ slice ])
| StepBy { n; it = Range { start; end_ } } ->
Some
(Rust_primitives__hax__folds__fold_range_step_by, [ start; end_; n ])
| _ -> None

[%%inline_defs dmutability]

let rec dexpr_unwrapped (expr : A.expr) : B.expr =
Expand All @@ -46,23 +135,32 @@ struct
_;
} ->
let body = dexpr body in
let { body; invariant } = extract_loop_invariant body in
let it = dexpr it in
let pat = dpat pat in
let bpat = dpat bpat in
let fn : B.expr' =
Closure { params = [ bpat; pat ]; body; captures = [] }
in
let fn : B.expr =
let as_lhs_closure e : B.expr =
{
e = fn;
typ = TArrow ([ bpat.typ; pat.typ ], body.typ);
span = body.span;
e = Closure { params = [ bpat; pat ]; body = e; captures = [] };
typ = TArrow ([ bpat.typ; pat.typ ], e.typ);
span = e.span;
}
in
UB.call ~kind:(AssociatedItem Value)
Core__iter__traits__iterator__Iterator__fold
[ it; dexpr init; fn ]
span (dty span expr.typ)
let fn : B.expr = as_lhs_closure body in
let invariant : B.expr =
let default : B.expr =
{ e = Literal (Bool true); typ = TBool; span = expr.span }
in
Option.value ~default invariant |> as_lhs_closure
in
let init = dexpr init in
let f, args =
match as_iterator it |> Option.bind ~f:fn_args_of_iterator with
| Some (f, args) -> (f, args @ [ init; invariant; fn ])
| None ->
(Core__iter__traits__iterator__Iterator__fold, [ it; init; fn ])
in
UB.call ~kind:(AssociatedItem Value) f args span (dty span expr.typ)
| Loop
{
body;
Expand Down
16 changes: 16 additions & 0 deletions engine/names/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ fn dummy_hax_concrete_ident_wrapper<I: core::iter::Iterator<Item = u8>>(x: I, mu
assert!(true);
assert_eq!(1, 1);
hax_lib::assert!(true);
hax_lib::loop_invariant(true);

let _ = [()].into_iter();
let _: u16 = 6u8.into();
Expand All @@ -37,6 +38,14 @@ fn dummy_hax_concrete_ident_wrapper<I: core::iter::Iterator<Item = u8>>(x: I, mu
let _ = ..;
let _ = ..1;

fn iterator_functions<It: Iterator + Clone>(it: It) {
let _ = it.clone().step_by(2);
let _ = it.clone().enumerate();
let _ = [()].chunks_exact(2);
let _ = [()].iter();
let _ = (&[()] as &[()]).iter();
}

{
use hax_lib::int::*;
let a: Int = 3u8.lift();
Expand Down Expand Up @@ -163,6 +172,13 @@ mod hax {
fn array_of_list() {}
fn never_to_any() {}

mod folds {
fn fold_range() {}
fn fold_range_step_by() {}
fn fold_enumerated_slice() {}
fn fold_enumerated_chunked_slice() {}
}

/// The engine uses this `dropped_body` symbol as a marker value
/// to signal that a item was extracted without body.
fn dropped_body() {}
Expand Down
3 changes: 3 additions & 0 deletions hax-lib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ pub fn inline_unsafe<T>(_: &str) -> T {
unreachable!()
}

#[doc(hidden)]
pub fn loop_invariant(_: bool) {}

/// A type that implements `Refinement` should be a newtype for a
/// type `T`. The field holding the value of type `T` should be
/// private, and `Refinement` should be the only interface to the
Expand Down

0 comments on commit 58ab7eb

Please sign in to comment.