Skip to content

Commit

Permalink
Clean up union decoding and encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
johnridesabike committed Nov 18, 2024
1 parent 1ae0ba7 commit 04c87df
Show file tree
Hide file tree
Showing 3 changed files with 213 additions and 230 deletions.
218 changes: 109 additions & 109 deletions lib/instruct.ml
Original file line number Diff line number Diff line change
Expand Up @@ -573,34 +573,17 @@ end = struct

let stack_add debug x = (debug.stack_add @@ x) @@ debug.stack

type 'a union_helper = {
to_data : 'a exp -> Data.t exp;
of_data : Data.t exp -> 'a exp;
to_extern : 'a exp -> External.t exp;
classify : 'a External.classify;
type ('ty, 'extern) decode_union_helper = {
classify : 'extern External.classify;
if_equal :
'extern exp ->
'ty ->
then_:(Data.t exp -> unit stmt) ->
else_:(unit -> unit stmt) ->
unit stmt;
if_open : 'extern exp -> (Data.t exp -> unit stmt) -> unit stmt;
}

let union_helper_string =
{
to_data = Data.string;
of_data = Data.to_string;
to_extern = External.of_string;
classify = String;
}

let union_helper_int =
{
to_data = Data.int;
of_data = Data.to_int;
to_extern = External.of_int;
classify = Int;
}

let external_of_int_bool i = External.of_bool (int_to_bool i)

let union_helper_bool =
{ union_helper_int with to_extern = external_of_int_bool }

let rec decode ~set ~debug input ty =
let$ ty_str = ("type", show_type ty) in
match ty.contents with
Expand Down Expand Up @@ -746,74 +729,80 @@ end = struct
set (Data.hashtbl decoded))
a)
~error:(fun () -> push_error debug ty_str input)
| T.Union_int (key, { cases; row = _ }, Bool) ->
let key = string key in
External.classify Assoc input
~ok:(fun input' ->
let aux i v () =
match MapInt.find_opt i cases with
| Some tys ->
let$ decoded = ("decoded", hashtbl_create ()) in
let| () = decoded.%{key} <- v in
let| () =
decode_record_aux ~debug decoded input' tys.contents ty_str
in
set (Data.hashtbl decoded)
| None -> push_error debug ty_str input
in
if_else
(External.assoc_mem key input')
~then_:(fun () ->
External.classify Bool
(External.assoc_find key input')
~ok:(fun b ->
if_else b ~then_:(aux 1 true_value)
~else_:(aux 0 false_value))
~error:(fun () -> push_error debug ty_str input))
~else_:(fun () -> push_error debug ty_str input))
~error:(fun () -> push_error debug ty_str input)
| T.Union_int (key, { cases; row }, Bool) ->
decode_union
{
classify = Bool;
if_equal =
(fun extern ty ~then_ ~else_ ->
match ty with
| 0 ->
if_else (not extern)
~then_:(fun () -> then_ false_value)
~else_
| _ -> if_else extern ~then_:(fun () -> then_ true_value) ~else_);
if_open =
(fun x f ->
if_else x
~then_:(fun () -> f true_value)
~else_:(fun () -> f false_value));
}
(MapInt.to_seq cases) ~set ~debug (string key) input row ty_str
| T.Union_int (key, { cases; row }, Not_bool) ->
decode_union union_helper_int
(MapInt.to_seq cases |> Seq.map (fun (k, v) -> (int k, v)))
~set ~debug key input row ty_str
decode_union
{
classify = Int;
if_equal =
(fun extern ty ~then_ ~else_ ->
let ty = int ty in
if_else (extern = ty)
~then_:(fun () -> then_ (Data.int ty))
~else_);
if_open = (fun x f -> f (Data.int x));
}
(MapInt.to_seq cases) ~set ~debug (string key) input row ty_str
| T.Union_string (key, { cases; row }) ->
decode_union union_helper_string
(MapString.to_seq cases |> Seq.map (fun (k, v) -> (string k, v)))
~set ~debug key input row ty_str

and decode_union : 'a. 'a union_helper -> ('a exp * T.record) Seq.t -> _ =
fun { to_data; classify; _ } seq ~set ~debug key input row ty_str ->
let key' = string key in
decode_union
{
classify = String;
if_equal =
(fun extern ty ~then_ ~else_ ->
let ty = string ty in
if_else (extern = ty)
~then_:(fun () -> then_ (Data.string ty))
~else_);
if_open = (fun x f -> f (Data.string x));
}
(MapString.to_seq cases) ~set ~debug (string key) input row ty_str

and decode_union :
'a 'b. ('a, 'b) decode_union_helper -> ('a * T.record) Seq.t -> _ =
fun helper seq ~set ~debug key input row ty_str ->
External.classify Assoc input
~ok:(fun input' ->
if_else
(External.assoc_mem key' input')
(External.assoc_mem key input')
~then_:(fun () ->
External.classify classify
(External.assoc_find key' input')
~ok:(fun i ->
External.classify helper.classify
(External.assoc_find key input')
~ok:(fun x ->
let$ decoded = ("decoded", hashtbl_create ()) in
let rec aux seq =
match seq () with
| Seq.Nil -> (
match row with
| `Open ->
let$ decoded = ("decoded", hashtbl_create ()) in
let| () = decoded.%{key'} <- to_data i in
set (Data.hashtbl decoded)
| `Open -> helper.if_open x (fun x -> decoded.%{key} <- x)
| `Closed -> push_error debug ty_str input)
| Seq.Cons ((ty_i, tys), seq) ->
if_else (i = ty_i)
~then_:(fun () ->
let$ decoded = ("decoded", hashtbl_create ()) in
let| () = decoded.%{key'} <- to_data i in
let| () =
decode_record_aux ~debug decoded input' tys.contents
ty_str
in
set (Data.hashtbl decoded))
| Seq.Cons ((ty_x, tys), seq) ->
helper.if_equal x ty_x
~then_:(fun x ->
let| () = decoded.%{key} <- x in
decode_record_aux ~debug decoded input' tys.contents
ty_str)
~else_:(fun () -> aux seq)
in
aux seq)
let| () = aux seq in
set (Data.hashtbl decoded))
~error:(fun () -> push_error debug ty_str input))
~else_:(fun () -> push_error debug ty_str input))
~error:(fun () -> push_error debug ty_str input)
Expand Down Expand Up @@ -845,6 +834,13 @@ end = struct
(not (buffer_length missing_keys = int 0))
~then_:(fun () -> push_key_error debug ty_str missing_keys)

type 'a encode_union_helper = {
of_data : Data.t exp -> 'a exp;
to_extern : 'a exp -> External.t exp;
}

let external_of_int_bool i = External.of_bool (int_to_bool i)

let rec encode ~set props ty =
match ty.contents with
| T.Unknown _ -> set (Data.to_external_untyped props)
Expand Down Expand Up @@ -908,43 +904,47 @@ end = struct
in
set (External.of_hashtbl encoded)
| T.Union_int (key, { cases; row }, Bool) ->
encode_union union_helper_bool
encode_union
{ of_data = Data.to_int; to_extern = external_of_int_bool }
(MapInt.to_seq cases |> Seq.map (fun (k, v) -> (int k, v)))
row ~set key props
row ~set (string key) props
| T.Union_int (key, { cases; row }, Not_bool) ->
encode_union union_helper_int
encode_union
{ of_data = Data.to_int; to_extern = External.of_int }
(MapInt.to_seq cases |> Seq.map (fun (k, v) -> (int k, v)))
row ~set key props
row ~set (string key) props
| T.Union_string (key, { cases; row }) ->
encode_union union_helper_string
encode_union
{ of_data = Data.to_string; to_extern = External.of_string }
(MapString.to_seq cases |> Seq.map (fun (k, v) -> (string k, v)))
row ~set key props
row ~set (string key) props

and encode_union : 'a. 'a union_helper -> ('a exp * T.record) Seq.t -> _ =
fun { of_data; to_extern; _ } cases row ~set key props ->
let key = string key in
and encode_union :
'a. 'a encode_union_helper -> ('a exp * T.record) Seq.t -> _ =
fun { of_data; to_extern } cases row ~set key props ->
let$ props = ("props", Data.to_hashtbl props) in
let$ tag = ("tag", props.%{key} |> of_data) in
let rec aux (tag', tys) seq =
if_else (tag = tag')
~then_:(fun () ->
let$ encoded = ("encoded", hashtbl_create ()) in
let| () = encoded.%{key} <- to_extern tag in
let| () = encode_record_aux encoded props tys.contents in
set (External.of_hashtbl encoded))
~else_:
(match seq () with
| Seq.Cons (hd, seq) -> fun () -> aux hd seq
| Seq.Nil -> (
match row with
| `Closed -> fun () -> unit
| `Open ->
fun () ->
let$ encoded = ("encoded", hashtbl_create ()) in
let| () = encoded.%{key} <- to_extern tag in
set (External.of_hashtbl encoded)))
let$ encoded = ("encoded", hashtbl_create ()) in
let| () =
match cases () with
| Seq.Nil -> unit
| Seq.Cons (hd, seq) ->
let rec aux (tag', tys) seq =
if_else (tag = tag')
~then_:(fun () ->
let| () = encoded.%{key} <- to_extern tag in
encode_record_aux encoded props tys.contents)
~else_:(fun () ->
match seq () with
| Seq.Nil -> (
match row with
| `Closed -> unit
| `Open -> encoded.%{key} <- to_extern tag)
| Seq.Cons (hd, seq) -> aux hd seq)
in
aux hd seq
in
match cases () with Seq.Nil -> unit | Seq.Cons (hd, seq) -> aux hd seq
set (External.of_hashtbl encoded)

and encode_record_aux encoded props tys =
MapString.to_seq tys
Expand Down
67 changes: 35 additions & 32 deletions test/parse-test.t/run.t
Original file line number Diff line number Diff line change
Expand Up @@ -1530,40 +1530,43 @@ Print the runtime instructions
(External.classify (bool)
(External.assoc_find "tag" classified/52)
(ok classified/53
(if_else classified/53
(let$ decoded/12 = (hashtbl_create))
(if_else (not classified/53)
(then
(let$ decoded/13 = (hashtbl_create))
(decoded/13.%{"tag"} <- (Data.int 1))
(decoded/12.%{"tag"} <- (Data.int 0))
(let$ missing_keys/6 = (buffer_create))
(if_else (External.assoc_mem "a" classified/52)
(then
(let$ input/35 = (External.assoc_find "a" classified/52))
(let$ stack/43 = ((stack_add/0 @@ "a") @@ stack/42))
(let$ type/47 = "string")
(External.classify (string) input/35
(ok classified/54
(decoded/13.%{"a"} <- (Data.string classified/54)))
(error
(stmt
(((decode_error/0 @@ input/35) @@ stack/43) @@ type/47)))))
(else
(stmt
(((buffer_add_sep/0 @@ missing_keys/6) @@ ", ") @@ "a"))))
(unit)
(if (not ((buffer_length missing_keys/6) = 0))
(then
(stmt
(((key_error/0 @@ missing_keys/6) @@ stack/42) @@ type/46))))
(props/0.%{"tagged"} <- (Data.hashtbl decoded/13)))
(((key_error/0 @@ missing_keys/6) @@ stack/42) @@ type/46)))))
(else
(let$ decoded/12 = (hashtbl_create))
(decoded/12.%{"tag"} <- (Data.int 0))
(let$ missing_keys/5 = (buffer_create))
(unit)
(if (not ((buffer_length missing_keys/5) = 0))
(if_else classified/53
(then
(decoded/12.%{"tag"} <- (Data.int 1))
(let$ missing_keys/5 = (buffer_create))
(if_else (External.assoc_mem "a" classified/52)
(then
(let$ input/35 = (External.assoc_find "a" classified/52))
(let$ stack/43 = ((stack_add/0 @@ "a") @@ stack/42))
(let$ type/47 = "string")
(External.classify (string) input/35
(ok classified/54
(decoded/12.%{"a"} <- (Data.string classified/54)))
(error
(stmt
(((decode_error/0 @@ input/35) @@ stack/43) @@ type/47)))))
(else
(stmt
(((buffer_add_sep/0 @@ missing_keys/5) @@ ", ") @@ "a"))))
(if (not ((buffer_length missing_keys/5) = 0))
(then
(stmt
(((key_error/0 @@ missing_keys/5) @@ stack/42) @@ type/46)))))
(else
(stmt
(((key_error/0 @@ missing_keys/5) @@ stack/42) @@ type/46))))
(props/0.%{"tagged"} <- (Data.hashtbl decoded/12)))))
(((decode_error/0 @@ input/34) @@ stack/42) @@ type/46))))))
(props/0.%{"tagged"} <- (Data.hashtbl decoded/12)))
(error
(stmt (((decode_error/0 @@ input/34) @@ stack/42) @@ type/46)))))
(else
Expand Down Expand Up @@ -1657,7 +1660,7 @@ Print the runtime instructions
(ok classified/62
(if_else ((External.length classified/62) = 3)
(then
(let$ decoded/14 = (array_make 3 (Data.int 0)))
(let$ decoded/13 = (array_make 3 (Data.int 0)))
(External.iteri classified/62 key/5 value/5
(let$ stack/52 =
((stack_add/0 @@ (string_of_int key/5)) @@ stack/51))
Expand All @@ -1666,7 +1669,7 @@ Print the runtime instructions
(let$ type/58 = "int")
(External.classify (int) value/5
(ok classified/66
(decoded/14.%(key/5) <- (Data.int classified/66)))
(decoded/13.%(key/5) <- (Data.int classified/66)))
(error
(stmt (((decode_error/0 @@ value/5) @@ stack/52) @@ type/58)))))
(else
Expand All @@ -1675,11 +1678,11 @@ Print the runtime instructions
(let$ type/57 = "float")
(External.classify (float) value/5
(ok classified/64
(decoded/14.%(key/5) <- (Data.float classified/64)))
(decoded/13.%(key/5) <- (Data.float classified/64)))
(error
(External.classify (int) value/5
(ok classified/65
(decoded/14.%(key/5) <-
(decoded/13.%(key/5) <-
(Data.float (float_of_int classified/65))))
(error
(stmt
Expand All @@ -1690,14 +1693,14 @@ Print the runtime instructions
(let$ type/56 = "string")
(External.classify (string) value/5
(ok classified/63
(decoded/14.%(key/5) <- (Data.string classified/63)))
(decoded/13.%(key/5) <- (Data.string classified/63)))
(error
(stmt
(((decode_error/0 @@ value/5) @@ stack/52) @@ type/56)))))
(else
(stmt
(((decode_error/0 @@ value/5) @@ stack/52) @@ type/55))))))))
(props/0.%{"tuple"} <- (Data.array decoded/14))))
(props/0.%{"tuple"} <- (Data.array decoded/13))))
(else
(stmt (((decode_error/0 @@ input/43) @@ stack/51) @@ type/55)))))
(error (stmt (((decode_error/0 @@ input/43) @@ stack/51) @@ type/55)))))
Expand Down
Loading

0 comments on commit 04c87df

Please sign in to comment.