Skip to content

Commit

Permalink
Merge pull request #388 from hacspec/make-fn-always-be-arrows
Browse files Browse the repository at this point in the history
fix(engine): make `fn` types always arrows
  • Loading branch information
W95Psp authored Dec 11, 2023
2 parents 3765856 + 1ef597e commit 1f17240
Show file tree
Hide file tree
Showing 15 changed files with 162 additions and 51 deletions.
5 changes: 5 additions & 0 deletions engine/lib/ast_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,11 @@ module Make (F : Features.T) = struct

let make_wild_pat (typ : ty) (span : span) : pat = { p = PWild; span; typ }

let make_unit_param (span : span) : param =
let typ = unit_typ in
let pat = make_wild_pat typ span in
{ pat; typ; typ_span = None; attrs = [] }

let make_seq (e1 : expr) (e2 : expr) : expr =
make_let (make_wild_pat e1.typ e1.span) e1 e2

Expand Down
29 changes: 25 additions & 4 deletions engine/lib/import_thir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ end) : EXPR = struct
{ f with e = GlobalVar (def_id (AssociatedItem Value) id) }
| _ -> f
in
let args = if List.is_empty args then [ unit_expr span ] else args in
App { f; args; generic_args }
| Box { value } ->
(U.call Rust_primitives__hax__box_new [ c_expr value ] span typ).e
Expand Down Expand Up @@ -616,6 +617,10 @@ end) : EXPR = struct
let params =
List.filter_map ~f:(fun p -> Option.map ~f:c_pat p.pat) params
in
let params =
if List.is_empty params then [ U.make_wild_pat U.unit_typ span ]
else params
in
let body = c_expr body in
let upvars = List.map ~f:c_expr upvars in
Closure { body; params; captures = upvars }
Expand Down Expand Up @@ -843,7 +848,11 @@ end) : EXPR = struct
| Float k -> TFloat (match k with F32 -> F32 | F64 -> F64)
| Arrow value ->
let ({ inputs; output; _ } : Thir.ty_fn_sig) = value.value in
TArrow (List.map ~f:(c_ty span) inputs, c_ty span output)
let inputs =
if List.is_empty inputs then [ U.unit_typ ]
else List.map ~f:(c_ty span) inputs
in
TArrow (inputs, c_ty span output)
| Adt { def_id = id; generic_args } ->
let ident = def_id Type id in
let args = List.map ~f:(c_generic_value span) generic_args in
Expand Down Expand Up @@ -1022,7 +1031,11 @@ end) : EXPR = struct
| DefaultReturn _span -> unit_typ
| Return ty -> c_ty span ty
in
TIFn (TArrow (List.map ~f:(c_ty span) inputs, output))
let inputs =
if List.is_empty inputs then [ U.unit_typ ]
else List.map ~f:(c_ty span) inputs
in
TIFn (TArrow (inputs, output))
| Type (bounds, None) ->
let bounds = List.filter_map ~f:(c_predicate_kind span) bounds in
TIType bounds
Expand Down Expand Up @@ -1126,14 +1139,18 @@ and c_item_unwrapped ~ident (item : Thir.item) : item list =
ty = c_ty item.span ty;
}
| Fn (generics, { body; params; _ }) ->
let params =
if List.is_empty params then [ U.make_unit_param span ]
else List.map ~f:(c_param item.span) params
in
mk
@@ Fn
{
name =
Concrete_ident.of_def_id Value (Option.value_exn item.def_id);
generics = c_generics generics;
body = c_expr body;
params = List.map ~f:(c_param item.span) params;
params;
}
| Enum (variants, generics) ->
let def_id = Option.value_exn item.def_id in
Expand Down Expand Up @@ -1233,12 +1250,16 @@ and c_item_unwrapped ~ident (item : Thir.item) : item list =
let v =
match (item.kind : Thir.impl_item_kind) with
| Fn { body; params; _ } ->
let params =
if List.is_empty params then [ U.make_unit_param span ]
else List.map ~f:(c_param item.span) params
in
Fn
{
name = item_def_id;
generics = c_generics generics;
body = c_expr body;
params = List.map ~f:(c_param item.span) params;
params;
}
| Const (_ty, e) ->
Fn
Expand Down
2 changes: 1 addition & 1 deletion examples/chacha20/proofs/fstar/extraction/Chacha20.fst
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ let chacha20_key_block0 (key: t_Array u8 (sz 32)) (iv: t_Array u8 (sz 12)) : t_A

let chacha20_update (st0: t_Array u32 (sz 16)) (m: t_Slice u8)
: Alloc.Vec.t_Vec u8 Alloc.Alloc.t_Global =
let blocks_out:Alloc.Vec.t_Vec u8 Alloc.Alloc.t_Global = Alloc.Vec.impl__new in
let blocks_out:Alloc.Vec.t_Vec u8 Alloc.Alloc.t_Global = Alloc.Vec.impl__new () in
let num_blocks:usize = (Core.Slice.impl__len m <: usize) /! sz 64 in
let remainder_len:usize = (Core.Slice.impl__len m <: usize) %! sz 64 in
let blocks_out:Alloc.Vec.t_Vec u8 Alloc.Alloc.t_Global =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ let process_order
(other_side: Alloc.Collections.Binary_heap.t_BinaryHeap v_T)
: (Alloc.Collections.Binary_heap.t_BinaryHeap v_T &
(Alloc.Vec.t_Vec t_Match Alloc.Alloc.t_Global & Core.Option.t_Option t_Order)) =
let matches:Alloc.Vec.t_Vec t_Match Alloc.Alloc.t_Global = Alloc.Vec.impl__new in
let matches:Alloc.Vec.t_Vec t_Match Alloc.Alloc.t_Global = Alloc.Vec.impl__new () in
let done:bool = false in
let done, matches, order, other_side:(bool & Alloc.Vec.t_Vec t_Match Alloc.Alloc.t_Global &
t_Order &
Expand Down Expand Up @@ -139,15 +139,16 @@ let process_order
(bool & Alloc.Vec.t_Vec t_Match Alloc.Alloc.t_Global & t_Order &
Alloc.Collections.Binary_heap.t_BinaryHeap v_T))
in
let output:(Alloc.Vec.t_Vec t_Match Alloc.Alloc.t_Global & Core.Option.t_Option t_Order) =
let hax_temp_output:(Alloc.Vec.t_Vec t_Match Alloc.Alloc.t_Global & Core.Option.t_Option t_Order)
=
matches,
(if order.f_quantity >. 0uL
then Core.Option.Option_Some order <: Core.Option.t_Option t_Order
else Core.Option.Option_None <: Core.Option.t_Option t_Order)
<:
(Alloc.Vec.t_Vec t_Match Alloc.Alloc.t_Global & Core.Option.t_Option t_Order)
in
other_side, output
other_side, hax_temp_output
<:
(Alloc.Collections.Binary_heap.t_BinaryHeap v_T &
(Alloc.Vec.t_Vec t_Match Alloc.Alloc.t_Global & Core.Option.t_Option t_Order))
52 changes: 36 additions & 16 deletions examples/sha256/proofs/fstar/extraction/Sha256.fst
Original file line number Diff line number Diff line change
Expand Up @@ -125,17 +125,33 @@ let shuffle (ws: t_Array u32 (sz 64)) (hashi: t_Array u32 (sz 8)) : t_Array u32
Core.Num.impl__u32__wrapping_add (sigma a0 (sz 0) (sz 1) <: u32) (maj a0 b0 c0 <: u32)
in
let h:t_Array u32 (sz 8) =
Rust_primitives.Hax.update_at h (sz 0) (Core.Num.impl__u32__wrapping_add t1 t2 <: u32)
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize h
(sz 0)
(Core.Num.impl__u32__wrapping_add t1 t2 <: u32)
in
let h:t_Array u32 (sz 8) = Rust_primitives.Hax.update_at h (sz 1) a0 in
let h:t_Array u32 (sz 8) = Rust_primitives.Hax.update_at h (sz 2) b0 in
let h:t_Array u32 (sz 8) = Rust_primitives.Hax.update_at h (sz 3) c0 in
let h:t_Array u32 (sz 8) =
Rust_primitives.Hax.update_at h (sz 4) (Core.Num.impl__u32__wrapping_add d0 t1 <: u32)
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize h (sz 1) a0
in
let h:t_Array u32 (sz 8) =
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize h (sz 2) b0
in
let h:t_Array u32 (sz 8) =
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize h (sz 3) c0
in
let h:t_Array u32 (sz 8) =
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize h
(sz 4)
(Core.Num.impl__u32__wrapping_add d0 t1 <: u32)
in
let h:t_Array u32 (sz 8) =
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize h (sz 5) e0
in
let h:t_Array u32 (sz 8) =
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize h (sz 6) f0
in
let h:t_Array u32 (sz 8) =
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize h (sz 7) g0
in
let h:t_Array u32 (sz 8) = Rust_primitives.Hax.update_at h (sz 5) e0 in
let h:t_Array u32 (sz 8) = Rust_primitives.Hax.update_at h (sz 6) f0 in
let h:t_Array u32 (sz 8) = Rust_primitives.Hax.update_at h (sz 7) g0 in
h)
in
h
Expand Down Expand Up @@ -189,7 +205,9 @@ let schedule (block: t_Array u8 (sz 64)) : t_Array u32 (sz 64) =
let i:usize = i in
if i <. sz 16 <: bool
then
let s:t_Array u32 (sz 64) = Rust_primitives.Hax.update_at s i (b.[ i ] <: u32) in
let s:t_Array u32 (sz 64) =
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize s i (b.[ i ] <: u32)
in
s
else
let t16:u32 = s.[ i -! sz 16 <: usize ] in
Expand All @@ -199,7 +217,7 @@ let schedule (block: t_Array u8 (sz 64)) : t_Array u32 (sz 64) =
let s1:u32 = sigma t2 (sz 3) (sz 0) in
let s0:u32 = sigma t15 (sz 2) (sz 0) in
let s:t_Array u32 (sz 64) =
Rust_primitives.Hax.update_at s
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize s
i
(Core.Num.impl__u32__wrapping_add (Core.Num.impl__u32__wrapping_add (Core.Num.impl__u32__wrapping_add
s1
Expand Down Expand Up @@ -233,7 +251,7 @@ let compress (block: t_Array u8 (sz 64)) (h_in: t_Array u32 (sz 8)) : t_Array u3
(fun h i ->
let h:t_Array u32 (sz 8) = h in
let i:usize = i in
Rust_primitives.Hax.update_at h
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize h
i
(Core.Num.impl__u32__wrapping_add (h.[ i ] <: u32) (h_in.[ i ] <: u32) <: u32)
<:
Expand Down Expand Up @@ -270,7 +288,7 @@ let u32s_to_be_bytes (state: t_Array u32 (sz 8)) : t_Array u8 (sz 32) =
(fun out j ->
let out:t_Array u8 (sz 32) = out in
let j:usize = j in
Rust_primitives.Hax.update_at out
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize out
((i *! sz 4 <: usize) +! j <: usize)
(tmp.[ j ] <: u8)
<:
Expand Down Expand Up @@ -312,7 +330,9 @@ let hash (msg: t_Slice u8) : t_Array u8 (sz 32) =
(fun last_block i ->
let last_block:t_Array u8 (sz 64) = last_block in
let i:usize = i in
Rust_primitives.Hax.update_at last_block i (block.[ i ] <: u8)
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize last_block
i
(block.[ i ] <: u8)
<:
t_Array u8 (sz 64))
in
Expand All @@ -330,7 +350,7 @@ let hash (msg: t_Slice u8) : t_Array u8 (sz 32) =
h, last_block, last_block_len <: (t_Array u32 (sz 8) & t_Array u8 (sz 64) & usize))
in
let last_block:t_Array u8 (sz 64) =
Rust_primitives.Hax.update_at last_block last_block_len 128uy
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize last_block last_block_len 128uy
in
let len_bist:u64 = cast ((Core.Slice.impl__len msg <: usize) *! sz 8 <: usize) <: u64 in
let len_bist_bytes:t_Array u8 (sz 8) = Core.Num.impl__u64__to_be_bytes len_bist in
Expand All @@ -350,7 +370,7 @@ let hash (msg: t_Slice u8) : t_Array u8 (sz 32) =
(fun last_block i ->
let last_block:t_Array u8 (sz 64) = last_block in
let i:usize = i in
Rust_primitives.Hax.update_at last_block
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize last_block
((v_BLOCK_SIZE -! v_LEN_SIZE <: usize) +! i <: usize)
(len_bist_bytes.[ i ] <: u8)
<:
Expand All @@ -375,7 +395,7 @@ let hash (msg: t_Slice u8) : t_Array u8 (sz 32) =
(fun pad_block i ->
let pad_block:t_Array u8 (sz 64) = pad_block in
let i:usize = i in
Rust_primitives.Hax.update_at pad_block
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize pad_block
((v_BLOCK_SIZE -! v_LEN_SIZE <: usize) +! i <: usize)
(len_bist_bytes.[ i ] <: u8)
<:
Expand Down
4 changes: 2 additions & 2 deletions proof-libs/fstar/core/Alloc.Vec.fst
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ open Rust_primitives

unfold type t_Vec t (_: unit) = s:t_Slice t

let impl__new #t: t_Vec t () = FStar.Seq.empty
let impl__new #t (): t_Vec t () = FStar.Seq.empty

let impl_2__extend_from_slice #t (self: t_Vec t ()) (other: t_Slice t{Seq.length self + Seq.length other <= max_usize}): t_Vec t ()
= FStar.Seq.append self other

let impl__with_capacity (_capacity: usize) = impl__new
let impl__with_capacity (_capacity: usize) = impl__new ()

// TODO: missing precondition For now, `impl_1__push` has a wrong
// semantics: pushing on a "full" vector does nothing. It should panic
Expand Down
57 changes: 57 additions & 0 deletions test-harness/src/snapshots/toolchain__attributes into-fstar.snap
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
---
source: test-harness/src/harness.rs
expression: snapshot
info:
kind:
Translate:
backend: fstar
info:
name: attributes
manifest: attributes/Cargo.toml
description: ~
spec:
optional: false
broken: false
issue_id: ~
positive: true
snapshot:
stderr: false
stdout: true
---
exit = 0

[stdout]
diagnostics = []

[stdout.files]
"Attributes.fst" = '''
module Attributes
#set-options "--fuel 0 --ifuel 1 --z3rlimit 15"
open Core
open FStar.Mul

let add3_lemma (x: u32)
: Lemma Prims.l_True
(ensures
x <=. 10ul || x >=. (u32_max /! 3ul <: u32) || (add3 x x x <: u32) =. (x *! 3ul <: u32)) =
()

let u32_max: u32 = 90000ul

let add3 (x y z: u32)
: Prims.Pure u32
(requires x >. 10ul && y >. 10ul && z >. 10ul && ((x +! y <: u32) +! z <: u32) <. u32_max)
(ensures
fun result ->
let result:u32 = result in
Hax_lib.implies true
(fun temp_0_ ->
let _:Prims.unit = temp_0_ in
result >. 32ul <: bool)) = (x +! y <: u32) +! z

type t_Foo = {
f_x:u32;
f_y:f_y: u32{f_y >. 3ul};
f_z:f_z: u32{((f_y +! f_x <: u32) +! f_z <: u32) >. 3ul}
}
'''
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ let array (x: t_Array u8 (sz 10)) : t_Array u8 (sz 10) =
in
x

let f: Alloc.Vec.t_Vec u8 Alloc.Alloc.t_Global =
let vec:Alloc.Vec.t_Vec u8 Alloc.Alloc.t_Global = Alloc.Vec.impl__new in
let f (_: Prims.unit) : Alloc.Vec.t_Vec u8 Alloc.Alloc.t_Global =
let vec:Alloc.Vec.t_Vec u8 Alloc.Alloc.t_Global = Alloc.Vec.impl__new () in
let vec:Alloc.Vec.t_Vec u8 Alloc.Alloc.t_Global = Alloc.Vec.impl_1__push vec 1uy in
let vec:Alloc.Vec.t_Vec u8 Alloc.Alloc.t_Global = Alloc.Vec.impl_1__push vec 2uy in
let vec:Alloc.Vec.t_Vec u8 Alloc.Alloc.t_Global = Core.Slice.impl__swap vec (sz 0) (sz 1) in
Expand All @@ -58,7 +58,7 @@ let h (x: u8) : u8 =
let x:u8 = x +! 10uy in
x

let build_vec: Alloc.Vec.t_Vec u8 Alloc.Alloc.t_Global =
let build_vec (_: Prims.unit) : Alloc.Vec.t_Vec u8 Alloc.Alloc.t_Global =
Alloc.Slice.impl__into_vec (Rust_primitives.unsize (Rust_primitives.Hax.box_new (let list =
[1uy; 2uy; 3uy]
in
Expand Down Expand Up @@ -115,8 +115,8 @@ let index_mutation_unsize (x: t_Array u8 (sz 12)) : u8 =
in
42uy

let test_append: Alloc.Vec.t_Vec u8 Alloc.Alloc.t_Global =
let vec1:Alloc.Vec.t_Vec u8 Alloc.Alloc.t_Global = Alloc.Vec.impl__new in
let test_append (_: Prims.unit) : Alloc.Vec.t_Vec u8 Alloc.Alloc.t_Global =
let vec1:Alloc.Vec.t_Vec u8 Alloc.Alloc.t_Global = Alloc.Vec.impl__new () in
let vec2:Alloc.Vec.t_Vec u8 Alloc.Alloc.t_Global =
Alloc.Slice.impl__into_vec (Rust_primitives.unsize (Rust_primitives.Hax.box_new (let list =
[1uy; 2uy; 3uy]
Expand All @@ -136,7 +136,7 @@ let test_append: Alloc.Vec.t_Vec u8 Alloc.Alloc.t_Global =
let vec2:Alloc.Vec.t_Vec u8 Alloc.Alloc.t_Global = tmp1 in
let _:Prims.unit = () in
let vec1:(Alloc.Vec.t_Vec u8 Alloc.Alloc.t_Global & Alloc.Vec.t_Vec u8 Alloc.Alloc.t_Global) =
Alloc.Vec.impl_1__append vec1 (build_vec <: Alloc.Vec.t_Vec u8 Alloc.Alloc.t_Global)
Alloc.Vec.impl_1__append vec1 (build_vec () <: Alloc.Vec.t_Vec u8 Alloc.Alloc.t_Global)
in
vec1

Expand Down
Loading

0 comments on commit 1f17240

Please sign in to comment.