Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eliminate deep disjunctive patterns. #830

Merged
merged 3 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions engine/backends/fstar/fstar_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1677,6 +1677,7 @@ module TransformToInputLanguage =
|> Phases.Drop_references
|> Phases.Trivialize_assign_lhs
|> Side_effect_utils.Hoist
|> Phases.Hoist_disjunctive_patterns
|> Phases.Simplify_match_return
|> Phases.Drop_needless_returns
|> Phases.Local_mutation
Expand Down
1 change: 1 addition & 0 deletions engine/lib/diagnostics.ml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ module Phase = struct
| SimplifyQuestionMarks
| Specialize
| HoistSideEffects
| HoistDisjunctions
| LocalMutation
| TrivializeAssignLhs
| CfIntoMonads
Expand Down
115 changes: 115 additions & 0 deletions engine/lib/phases/phase_hoist_disjunctive_patterns.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
(* This phase transforms deep disjunctive patterns in equivalent
shallow ones. For example `Some(1 | 2)` becomes `Some(1) | Some(2)` *)

open! Prelude

module Make (F : Features.T) =
Phase_utils.MakeMonomorphicPhase
(F)
(struct
let phase_id = Diagnostics.Phase.HoistDisjunctions

open Ast.Make (F)
module U = Ast_utils.Make (F)
module Visitors = Ast_visitors.Make (F)

module Error = Phase_utils.MakeError (struct
let ctx = Diagnostics.Context.Phase phase_id
end)

let hoist_disjunctions =
object (self)
inherit [_] Visitors.map as super

method! visit_pat () p =
let return_pat p' = { p = p'; span = p.span; typ = p.typ } in

(* When there is a list of subpaterns, we use the distributivity of nested
disjunctions: (a | b, c | d) gives (a, c) | (a, d) | (b, c) | (b,d) *)
let rec treat_args cases = function
| { p = POr { subpats }; _ } :: tail ->
treat_args
(List.concat_map
~f:(fun subpat ->
List.map ~f:(fun args -> subpat :: args) cases)
subpats)
tail
| pat :: tail ->
let pat = self#visit_pat () pat in
treat_args (List.map ~f:(fun args -> pat :: args) cases) tail
| [] -> cases
in
let subpats_to_disj subpats =
match subpats with
| [ pat ] -> pat
| _ -> POr { subpats } |> return_pat
in

(* When there is one subpattern, we check if it is a disjunction,
and if it is, we hoist it. *)
let treat_subpat pat to_pattern =
let subpat = self#visit_pat () pat in
match subpat with
| { p = POr { subpats }; span; _ } ->
return_pat
(POr
{
subpats =
List.map
~f:(fun pat ->
{ p = to_pattern pat; span; typ = p.typ })
subpats;
})
| _ -> p
in

match p.p with
| PConstruct { name; args; is_record; is_struct } ->
let args_as_pat =
List.rev_map args ~f:(fun arg -> self#visit_pat () arg.pat)
in
let subpats =
List.map (treat_args [ [] ] args_as_pat)
~f:(fun args_as_pat ->
let args =
List.map2_exn args_as_pat args
~f:(fun pat { field; _ } -> { field; pat })
in
PConstruct { name; args; is_record; is_struct }
|> return_pat)
in

subpats_to_disj subpats
| PArray { args } ->
let subpats =
List.map
~f:(fun args -> PArray { args } |> return_pat)
(treat_args [ [] ]
(List.rev_map args ~f:(fun arg -> self#visit_pat () arg)))
in
subpats_to_disj subpats
| POr { subpats } ->
let subpats = List.map ~f:(self#visit_pat ()) subpats in
POr
{
subpats =
List.concat_map
~f:(function
| { p = POr { subpats }; _ } -> subpats | p -> [ p ])
subpats;
}
|> return_pat
| PAscription { typ; typ_span; pat } ->
treat_subpat pat (fun pat -> PAscription { typ; typ_span; pat })
| PBinding { subpat = Some (pat, as_pat); mut; mode; typ; var } ->
treat_subpat pat (fun pat ->
PBinding
{ subpat = Some (pat, as_pat); mut; mode; typ; var })
| PDeref { subpat; witness } ->
treat_subpat subpat (fun subpat -> PDeref { subpat; witness })
| PWild | PConstant _ | PBinding { subpat = None; _ } ->
super#visit_pat () p
W95Psp marked this conversation as resolved.
Show resolved Hide resolved
end

let ditems = List.map ~f:(hoist_disjunctions#visit_item ())
end)
5 changes: 5 additions & 0 deletions engine/lib/phases/phase_hoist_disjunctive_patterns.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
(** This phase eliminates nested disjunctive patterns (leaving
only shallow disjunctions). It moves the disjunctions up
to the top-level pattern. *)

module Make : Phase_utils.UNCONSTRAINTED_MONOMORPHIC_PHASE
26 changes: 26 additions & 0 deletions test-harness/src/snapshots/toolchain__pattern-or into-coq.snap
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,30 @@ Definition bar (x : t_E_t) : unit :=
| E_A | E_B =>
tt
end.

Definition deep (x : int32 × t_Option_t int32) : int32 :=
match x with
| '((@repr WORDSIZE32 1) | (@repr WORDSIZE32 2),Option_Some (@repr WORDSIZE32 3) | (@repr WORDSIZE32 4)) =>
(@repr WORDSIZE32 0)
| '(x,_) =>
x
end.

Definition equivalent (x : int32 × t_Option_t int32) : int32 :=
match x with
| '((@repr WORDSIZE32 1),Option_Some (@repr WORDSIZE32 3)) | '((@repr WORDSIZE32 1),Option_Some (@repr WORDSIZE32 4)) | '((@repr WORDSIZE32 2),Option_Some (@repr WORDSIZE32 3)) | '((@repr WORDSIZE32 2),Option_Some (@repr WORDSIZE32 4)) =>
(@repr WORDSIZE32 0)
| '(x,_) =>
x
end.

Definition nested (x : t_Option_t int32) : int32 :=
match x with
| Option_Some (@repr WORDSIZE32 1) | (@repr WORDSIZE32 2) =>
(@repr WORDSIZE32 1)
| Option_Some x =>
x
| Option_None =>
(@repr WORDSIZE32 0)
end.
'''
22 changes: 22 additions & 0 deletions test-harness/src/snapshots/toolchain__pattern-or into-fstar.snap
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,26 @@ let t_E_cast_to_repr (x: t_E) : isize =
| E_B -> isz 1

let bar (x: t_E) : Prims.unit = match x with | E_A | E_B -> () <: Prims.unit

let deep (x: (i32 & Core.Option.t_Option i32)) : i32 =
match x with
| 1l, Core.Option.Option_Some 3l
| 1l, Core.Option.Option_Some 4l
| 2l, Core.Option.Option_Some 3l
| 2l, Core.Option.Option_Some 4l -> 0l
| x, _ -> x

let equivalent (x: (i32 & Core.Option.t_Option i32)) : i32 =
match x with
| 1l, Core.Option.Option_Some 3l
| 1l, Core.Option.Option_Some 4l
| 2l, Core.Option.Option_Some 3l
| 2l, Core.Option.Option_Some 4l -> 0l
| x, _ -> x

let nested (x: Core.Option.t_Option i32) : i32 =
match x with
| Core.Option.Option_Some 1l | Core.Option.Option_Some 2l -> 1l
| Core.Option.Option_Some x -> x
| Core.Option.Option_None -> 0l
'''
21 changes: 21 additions & 0 deletions tests/pattern-or/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,24 @@ pub fn bar(x: E) {
E::A | E::B => (),
}
}
pub fn nested(x: Option<i32>) -> i32 {
match x {
Some(1 | 2) => 1,
Some(x) => x,
None => 0,
}
}

pub fn deep(x: (i32, Option<i32>)) -> i32 {
match x {
(1 | 2, Some(3 | 4)) => 0,
(x, _) => x,
}
}

pub fn equivalent(x: (i32, Option<i32>)) -> i32 {
match x {
(1, Some(3)) | (1, Some(4)) | (2, Some(3)) | (2, Some(4)) => 0,
(x, _) => x,
}
}
Loading