Skip to content

Commit

Permalink
Lift restriction on array accesses
Browse files Browse the repository at this point in the history
Array accesses of the form A[0] and A[n-i] are allowed if they do not occur
(directly or indirectly) in the definition of the array A itself.
  • Loading branch information
daniel-larraz committed Nov 6, 2023
1 parent 9566bde commit dbd2c27
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 39 deletions.
117 changes: 82 additions & 35 deletions src/lustre/lustreArrayDependencies.ml
Original file line number Diff line number Diff line change
Expand Up @@ -29,46 +29,62 @@ module Chk = LustreTypeChecker

module StringMap = HString.HStringMap

type index =
| Val of int
| Unk of Lib.position * A.expr (* Unknown *)

let is_unknown = function
| Val _ -> false
| Unk _ -> true

let lexiographic_order idx1 idx2 =
let idx1_len = List.length idx1 in
let idx2_len = List.length idx2 in
if Int.equal idx1_len idx2_len then
let compared_elements = List.map (fun (i, j) -> i - j) (List.combine idx1 idx2) in
let compared_elements =
List.map
(fun (i, j) ->
match i, j with
| Val i, Val j -> i - j
| Val _, Unk _ -> -1
| Unk _, Val _ -> 1
| Unk _, Unk _ -> 0
)
(List.combine idx1 idx2) in
match (List.filter (fun i -> i != 0) compared_elements) with
| x :: _ -> x
| [] -> 0
else idx1_len - idx2_len

module G = Graph.Make(struct
type t = A.ident * int list
type t = A.ident * index list

let compare (id1, idx1) (id2, idx2) =
match HString.compare id1 id2 with
| 0 -> lexiographic_order idx1 idx2
| x -> x

let print_r_index fmt = function
| Val i -> Format.pp_print_int fmt i
| Unk _ -> Format.pp_print_string fmt "?"

let pp_print_t = Lib.pp_print_pair
(A.pp_print_ident)
(Lib.pp_print_list Format.pp_print_int " ")
(Lib.pp_print_list print_r_index " ")
" "
end)

type error_kind = Unknown of string
| ComplicatedExpr of LustreAst.expr
| ExprMissingIndex of HString.t * LustreAst.expr
| Cycle of HString.t list

let error_message = function
| Unknown e -> e
| ComplicatedExpr e -> "The expression '"
^ (Lib.string_of_t A.pp_print_expr e)
^ "' is too complicated in definition of inductive array"
| ComplicatedExpr e ->
"Couldn't determine well-foundedness of array definition because of index expression '" ^
(Lib.string_of_t A.pp_print_expr e) ^ "'"
| Cycle ids -> "Cyclic dependency detected in definition of identifiers: "
^ (Lib.string_of_t (Lib.pp_print_list A.pp_print_ident " -> ") ids)
| ExprMissingIndex (i, e) -> "The index expression '"
^ (Lib.string_of_t A.pp_print_expr e)
^ "' must contain the index variable '"
^ (HString.string_of_hstring i) ^ "'"

type error = [
| `LustreArrayDependencies of Lib.position * error_kind
Expand Down Expand Up @@ -113,7 +129,7 @@ and process_equation ctx ns = function

and process_lhs ctx ns proj expr = function
| (A.ArrayDef (pos, id, indices) :: tail) ->
let zero_list = List.map (fun _ -> 0) indices in
let zero_list = List.map (fun _ -> Val 0) indices in
let* expr_graph = process_expr (Some (List.rev indices)) ctx ns proj [] expr in
let expr_graph = G.connect expr_graph (id, zero_list) in
let* (tail_graph, tail_pos_map, count, len) = process_lhs ctx ns (proj + 1) expr tail in
Expand All @@ -123,7 +139,7 @@ and process_lhs ctx ns proj expr = function
R.ok (graph, map, count + 1, max len (List.length indices))
| (A.SingleIdent (pos, id) :: tail) ->
let* expr_graph = process_expr None ctx ns proj [] expr in
let expr_graph = G.connect expr_graph (id, 0 :: []) in
let expr_graph = G.connect expr_graph (id, [Val 0]) in
let* (tail_graph, tail_pos_map, count, len) = process_lhs ctx ns (proj + 1) expr tail in
let graph = union expr_graph tail_graph in
let map = StringMap.singleton id pos in
Expand Down Expand Up @@ -173,6 +189,7 @@ and process_expr ind_vars ctx ns proj indices expr =
(match ind_vars with
| Some ind_vars' ->
let ind_var = List.hd ind_vars' in
let ind_vars = Some (List.tl ind_vars') in
let idx' = AH.substitute_naive ind_var zero idx in
(match AIC.eval_int_expr ctx idx' with
| Ok idx' ->
Expand All @@ -181,18 +198,20 @@ and process_expr ind_vars ctx ns proj indices expr =
A.SI.mem ind_var idx_vars &&
not (AH.expr_contains_call idx)
then
let ind_vars = Some (List.tl ind_vars') in
process_expr ind_vars ctx ns proj (idx' :: indices) e
process_expr ind_vars ctx ns proj (Val idx' :: indices) e
else
mk_error p (ExprMissingIndex (ind_var, idx))
| Error _ -> mk_error p (ComplicatedExpr idx))
process_expr ind_vars ctx ns proj (Unk (p, idx) :: indices) e
| Error _ -> (
process_expr ind_vars ctx ns proj (Unk (p, idx) :: indices) e
)
)
| None -> r e)
else r e
(* Quantified expressions *)
| Quantifier (_, _, vars, e) ->
let* graph = r e in
let graph = List.fold_left
(fun acc (_, id, _) -> G.remove_vertex acc (id, 0 :: []))
(fun acc (_, id, _) -> G.remove_vertex acc (id, [Val 0]))
graph
vars
in
Expand Down Expand Up @@ -229,6 +248,21 @@ and process_expr ind_vars ctx ns proj indices expr =
empty_
dep_args

let extract_unknown ids =
let unknowns =
List.filter_map
(fun (_, indices) ->
match List.find_opt is_unknown indices with
| Some (Unk (p, e)) -> Some (p, e)
| None -> None
| _ -> assert false
)
ids
in
match unknowns with
| [] -> None
| unk :: _ -> Some unk

let rec check_inductive_array_dependencies ctx ns = function
| (A.NodeDecl (_, decl)) :: tail | (A.FuncDecl (_, decl)) :: tail ->
check_node_decl ctx ns decl
Expand Down Expand Up @@ -269,17 +303,20 @@ and check_node_decl ctx ns decl =
(* Format.eprintf "After offsets: %a@." G.pp_print_graph graph; *)
let graph = add_wellfounded_edges idx_len graph in
(* Format.eprintf "After wellfounded: %a@." G.pp_print_graph graph; *)
let* _ = (try (Res.ok (G.topological_sort graph)) with
| G.CyclicGraphException ids ->
let (id, _) = List.hd ids in
let pos = StringMap.find id pos_map in
let ids = List.map (fun (id, idx) ->
let idxs = List.map (fun x -> HString.mk_hstring (string_of_int x)) idx in
let idxs = HString.concat (HString.mk_hstring " ") idxs in
HString.concat (HString.mk_hstring " ") [id;idxs])
ids
in
mk_error pos (Cycle ids))
let* _ =
try
Res.ok (G.topological_sort graph)
with
| G.CyclicGraphException ids -> (
match extract_unknown ids with
| Some (p, e) -> mk_error p (ComplicatedExpr e)
| None ->
let (id, _) = List.hd ids in
let pos = StringMap.find id pos_map in
let ids = List.map (fun idx ->
Format.asprintf "%a" G.pp_print_vertex idx |> HString.mk_hstring) ids in
mk_error pos (Cycle ids)
)
in
R.ok ()

Expand Down Expand Up @@ -351,17 +388,27 @@ and add_wellfounded_edges idx_len graph =
and add_offset_edges count graph =
let vertices = G.get_vertices graph |> G.to_vertex_list in
let non_zero_vertices = List.filter
(fun (_, idxs) -> List.fold_left (&&) true (List.map (fun i -> i != 0) idxs))
(fun (_, idxs) ->
List.fold_left (&&) true
(List.map (function | Val i -> i != 0 | Unk _ -> true) idxs)
)
vertices
in
let mk_edges ((id, offsets) as v) =
let mk_edges ((id, offsets) as v : G.vertex) =
let offsets_len = List.length offsets in
let nhbd = G.children graph
(id, List.mapi (fun _ _ -> 0) offsets)
(id, List.mapi (fun _ _ -> Val 0) offsets)
in
let nhbd_offset = List.map
(fun (id', offsets') -> (id',
List.mapi (fun i e -> if i < offsets_len then e + (List.nth offsets i) else e) offsets'))
List.mapi (fun i e ->
if i < offsets_len then
match e, List.nth offsets i with
| Val e', Val o_i -> Val (e' + o_i)
| Unk (p, e), _ -> Unk (p, e)
| Val _, Unk (p, e) -> Unk (p, e)
else e)
offsets'))
nhbd
in
let graph = List.fold_left G.add_vertex G.empty nhbd_offset in
Expand All @@ -372,7 +419,7 @@ and add_offset_edges count graph =
let vertices = G.get_vertices graph |> G.to_vertex_list in
List.filter (fun v -> not (List.mem v old)) vertices
in
let rec loop n vertices graph =
let rec loop n (vertices : G.vertex list) graph =
if n <= 0 then graph
else
let new_edges = List.fold_left
Expand Down Expand Up @@ -402,7 +449,7 @@ and add_init_edges idx_len graph =
let filled_offset = List.init idx_len (fun i ->
match (List.nth_opt offset i) with
| Some x -> x
| None -> 0)
| None -> Val 0)
in
let filled_graph = G.singleton (id, filled_offset) in
G.connect filled_graph v
Expand Down
1 change: 0 additions & 1 deletion src/lustre/lustreArrayDependencies.mli
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

type error_kind = Unknown of string
| ComplicatedExpr of LustreAst.expr
| ExprMissingIndex of HString.t * LustreAst.expr
| Cycle of HString.t list

val error_message: error_kind -> string
Expand Down
6 changes: 3 additions & 3 deletions tests/ounit/lustre/testLustreFrontend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -190,15 +190,15 @@ let _ = run_test_tt_main ("frontend lustreArrayDependencies error tests" >::: [
| _ -> false);
mk_test "test invalid inductive array def 3" (fun () ->
match load_file "./lustreArrayDependencies/inductive_array3.lus" with
| Error (`LustreArrayDependencies (_, ExprMissingIndex _)) -> true
| Error (`LustreArrayDependencies (_, ComplicatedExpr _)) -> true
| _ -> false);
mk_test "test invalid inductive array def 4" (fun () ->
match load_file "./lustreArrayDependencies/inductive_array4.lus" with
| Error (`LustreArrayDependencies (_, Cycle _)) -> true
| _ -> false);
mk_test "test invalid inductive array def 5" (fun () ->
match load_file "./lustreArrayDependencies/inductive_array5.lus" with
| Error (`LustreArrayDependencies (_, ExprMissingIndex _)) -> true
| Error (`LustreArrayDependencies (_, ComplicatedExpr _)) -> true
| _ -> false);
mk_test "test invalid inductive array def 6" (fun () ->
match load_file "./lustreArrayDependencies/inductive_array6.lus" with
Expand Down Expand Up @@ -699,4 +699,4 @@ let _ = run_test_tt_main ("frontend LustreAbstractInterpretation error tests" >:
match load_file "./lustreAbstractInterpretation/subrange_bug6.lus" with
| Error (`LustreAbstractInterpretationError (_, ConstantOutOfSubrange _)) -> true
| _ -> false);
])
])
14 changes: 14 additions & 0 deletions tests/regression/success/test_recursive_array_def.lus
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@

node sum (const n: int; A: int ^ n) returns (s: int);
var cumul: int ^ n;
let
cumul[i] = if i = 0 then A[0] else A[i] + cumul[i-1];
s = cumul[n-1];
tel

node slice (const n: int; A: int ^ n; const low: int; const up: int)
returns (B : int ^ (up-low));
let
B[i] = A[low + i];
tel

node a(A: int^5) returns (B: int^5)
let
B[i] = A[i] + 1;
Expand Down

0 comments on commit dbd2c27

Please sign in to comment.