diff --git a/src/lustre/lustreArrayDependencies.ml b/src/lustre/lustreArrayDependencies.ml index abbec76d0..430468c14 100644 --- a/src/lustre/lustreArrayDependencies.ml +++ b/src/lustre/lustreArrayDependencies.ml @@ -29,46 +29,62 @@ module Chk = LustreTypeChecker module StringMap = HString.HStringMap +type index = + | Val of int + | Unk of Lib.position * A.expr (* Unknown *) + +let is_unknown = function + | Val _ -> false + | Unk _ -> true + let lexiographic_order idx1 idx2 = let idx1_len = List.length idx1 in let idx2_len = List.length idx2 in if Int.equal idx1_len idx2_len then - let compared_elements = List.map (fun (i, j) -> i - j) (List.combine idx1 idx2) in + let compared_elements = + List.map + (fun (i, j) -> + match i, j with + | Val i, Val j -> i - j + | Val _, Unk _ -> -1 + | Unk _, Val _ -> 1 + | Unk _, Unk _ -> 0 + ) + (List.combine idx1 idx2) in match (List.filter (fun i -> i != 0) compared_elements) with | x :: _ -> x | [] -> 0 else idx1_len - idx2_len module G = Graph.Make(struct - type t = A.ident * int list + type t = A.ident * index list let compare (id1, idx1) (id2, idx2) = match HString.compare id1 id2 with | 0 -> lexiographic_order idx1 idx2 | x -> x + let print_r_index fmt = function + | Val i -> Format.pp_print_int fmt i + | Unk _ -> Format.pp_print_string fmt "?" + let pp_print_t = Lib.pp_print_pair (A.pp_print_ident) - (Lib.pp_print_list Format.pp_print_int " ") + (Lib.pp_print_list print_r_index " ") " " end) type error_kind = Unknown of string | ComplicatedExpr of LustreAst.expr - | ExprMissingIndex of HString.t * LustreAst.expr | Cycle of HString.t list let error_message = function | Unknown e -> e - | ComplicatedExpr e -> "The expression '" - ^ (Lib.string_of_t A.pp_print_expr e) - ^ "' is too complicated in definition of inductive array" + | ComplicatedExpr e -> + "Couldn't determine well-foundedness of array definition because of index expression '" ^ + (Lib.string_of_t A.pp_print_expr e) ^ "'" | Cycle ids -> "Cyclic dependency detected in definition of identifiers: " ^ (Lib.string_of_t (Lib.pp_print_list A.pp_print_ident " -> ") ids) - | ExprMissingIndex (i, e) -> "The index expression '" - ^ (Lib.string_of_t A.pp_print_expr e) - ^ "' must contain the index variable '" - ^ (HString.string_of_hstring i) ^ "'" type error = [ | `LustreArrayDependencies of Lib.position * error_kind @@ -113,7 +129,7 @@ and process_equation ctx ns = function and process_lhs ctx ns proj expr = function | (A.ArrayDef (pos, id, indices) :: tail) -> - let zero_list = List.map (fun _ -> 0) indices in + let zero_list = List.map (fun _ -> Val 0) indices in let* expr_graph = process_expr (Some (List.rev indices)) ctx ns proj [] expr in let expr_graph = G.connect expr_graph (id, zero_list) in let* (tail_graph, tail_pos_map, count, len) = process_lhs ctx ns (proj + 1) expr tail in @@ -123,7 +139,7 @@ and process_lhs ctx ns proj expr = function R.ok (graph, map, count + 1, max len (List.length indices)) | (A.SingleIdent (pos, id) :: tail) -> let* expr_graph = process_expr None ctx ns proj [] expr in - let expr_graph = G.connect expr_graph (id, 0 :: []) in + let expr_graph = G.connect expr_graph (id, [Val 0]) in let* (tail_graph, tail_pos_map, count, len) = process_lhs ctx ns (proj + 1) expr tail in let graph = union expr_graph tail_graph in let map = StringMap.singleton id pos in @@ -173,6 +189,7 @@ and process_expr ind_vars ctx ns proj indices expr = (match ind_vars with | Some ind_vars' -> let ind_var = List.hd ind_vars' in + let ind_vars = Some (List.tl ind_vars') in let idx' = AH.substitute_naive ind_var zero idx in (match AIC.eval_int_expr ctx idx' with | Ok idx' -> @@ -181,18 +198,20 @@ and process_expr ind_vars ctx ns proj indices expr = A.SI.mem ind_var idx_vars && not (AH.expr_contains_call idx) then - let ind_vars = Some (List.tl ind_vars') in - process_expr ind_vars ctx ns proj (idx' :: indices) e + process_expr ind_vars ctx ns proj (Val idx' :: indices) e else - mk_error p (ExprMissingIndex (ind_var, idx)) - | Error _ -> mk_error p (ComplicatedExpr idx)) + process_expr ind_vars ctx ns proj (Unk (p, idx) :: indices) e + | Error _ -> ( + process_expr ind_vars ctx ns proj (Unk (p, idx) :: indices) e + ) + ) | None -> r e) else r e (* Quantified expressions *) | Quantifier (_, _, vars, e) -> let* graph = r e in let graph = List.fold_left - (fun acc (_, id, _) -> G.remove_vertex acc (id, 0 :: [])) + (fun acc (_, id, _) -> G.remove_vertex acc (id, [Val 0])) graph vars in @@ -229,6 +248,21 @@ and process_expr ind_vars ctx ns proj indices expr = empty_ dep_args +let extract_unknown ids = + let unknowns = + List.filter_map + (fun (_, indices) -> + match List.find_opt is_unknown indices with + | Some (Unk (p, e)) -> Some (p, e) + | None -> None + | _ -> assert false + ) + ids + in + match unknowns with + | [] -> None + | unk :: _ -> Some unk + let rec check_inductive_array_dependencies ctx ns = function | (A.NodeDecl (_, decl)) :: tail | (A.FuncDecl (_, decl)) :: tail -> check_node_decl ctx ns decl @@ -269,17 +303,20 @@ and check_node_decl ctx ns decl = (* Format.eprintf "After offsets: %a@." G.pp_print_graph graph; *) let graph = add_wellfounded_edges idx_len graph in (* Format.eprintf "After wellfounded: %a@." G.pp_print_graph graph; *) - let* _ = (try (Res.ok (G.topological_sort graph)) with - | G.CyclicGraphException ids -> - let (id, _) = List.hd ids in - let pos = StringMap.find id pos_map in - let ids = List.map (fun (id, idx) -> - let idxs = List.map (fun x -> HString.mk_hstring (string_of_int x)) idx in - let idxs = HString.concat (HString.mk_hstring " ") idxs in - HString.concat (HString.mk_hstring " ") [id;idxs]) - ids - in - mk_error pos (Cycle ids)) + let* _ = + try + Res.ok (G.topological_sort graph) + with + | G.CyclicGraphException ids -> ( + match extract_unknown ids with + | Some (p, e) -> mk_error p (ComplicatedExpr e) + | None -> + let (id, _) = List.hd ids in + let pos = StringMap.find id pos_map in + let ids = List.map (fun idx -> + Format.asprintf "%a" G.pp_print_vertex idx |> HString.mk_hstring) ids in + mk_error pos (Cycle ids) + ) in R.ok () @@ -351,17 +388,27 @@ and add_wellfounded_edges idx_len graph = and add_offset_edges count graph = let vertices = G.get_vertices graph |> G.to_vertex_list in let non_zero_vertices = List.filter - (fun (_, idxs) -> List.fold_left (&&) true (List.map (fun i -> i != 0) idxs)) + (fun (_, idxs) -> + List.fold_left (&&) true + (List.map (function | Val i -> i != 0 | Unk _ -> true) idxs) + ) vertices in - let mk_edges ((id, offsets) as v) = + let mk_edges ((id, offsets) as v : G.vertex) = let offsets_len = List.length offsets in let nhbd = G.children graph - (id, List.mapi (fun _ _ -> 0) offsets) + (id, List.mapi (fun _ _ -> Val 0) offsets) in let nhbd_offset = List.map (fun (id', offsets') -> (id', - List.mapi (fun i e -> if i < offsets_len then e + (List.nth offsets i) else e) offsets')) + List.mapi (fun i e -> + if i < offsets_len then + match e, List.nth offsets i with + | Val e', Val o_i -> Val (e' + o_i) + | Unk (p, e), _ -> Unk (p, e) + | Val _, Unk (p, e) -> Unk (p, e) + else e) + offsets')) nhbd in let graph = List.fold_left G.add_vertex G.empty nhbd_offset in @@ -372,7 +419,7 @@ and add_offset_edges count graph = let vertices = G.get_vertices graph |> G.to_vertex_list in List.filter (fun v -> not (List.mem v old)) vertices in - let rec loop n vertices graph = + let rec loop n (vertices : G.vertex list) graph = if n <= 0 then graph else let new_edges = List.fold_left @@ -402,7 +449,7 @@ and add_init_edges idx_len graph = let filled_offset = List.init idx_len (fun i -> match (List.nth_opt offset i) with | Some x -> x - | None -> 0) + | None -> Val 0) in let filled_graph = G.singleton (id, filled_offset) in G.connect filled_graph v diff --git a/src/lustre/lustreArrayDependencies.mli b/src/lustre/lustreArrayDependencies.mli index b52233196..aec8a73e5 100644 --- a/src/lustre/lustreArrayDependencies.mli +++ b/src/lustre/lustreArrayDependencies.mli @@ -21,7 +21,6 @@ type error_kind = Unknown of string | ComplicatedExpr of LustreAst.expr - | ExprMissingIndex of HString.t * LustreAst.expr | Cycle of HString.t list val error_message: error_kind -> string diff --git a/tests/ounit/lustre/testLustreFrontend.ml b/tests/ounit/lustre/testLustreFrontend.ml index 23196319b..df389bd0b 100644 --- a/tests/ounit/lustre/testLustreFrontend.ml +++ b/tests/ounit/lustre/testLustreFrontend.ml @@ -190,7 +190,7 @@ let _ = run_test_tt_main ("frontend lustreArrayDependencies error tests" >::: [ | _ -> false); mk_test "test invalid inductive array def 3" (fun () -> match load_file "./lustreArrayDependencies/inductive_array3.lus" with - | Error (`LustreArrayDependencies (_, ExprMissingIndex _)) -> true + | Error (`LustreArrayDependencies (_, ComplicatedExpr _)) -> true | _ -> false); mk_test "test invalid inductive array def 4" (fun () -> match load_file "./lustreArrayDependencies/inductive_array4.lus" with @@ -198,7 +198,7 @@ let _ = run_test_tt_main ("frontend lustreArrayDependencies error tests" >::: [ | _ -> false); mk_test "test invalid inductive array def 5" (fun () -> match load_file "./lustreArrayDependencies/inductive_array5.lus" with - | Error (`LustreArrayDependencies (_, ExprMissingIndex _)) -> true + | Error (`LustreArrayDependencies (_, ComplicatedExpr _)) -> true | _ -> false); mk_test "test invalid inductive array def 6" (fun () -> match load_file "./lustreArrayDependencies/inductive_array6.lus" with @@ -699,4 +699,4 @@ let _ = run_test_tt_main ("frontend LustreAbstractInterpretation error tests" >: match load_file "./lustreAbstractInterpretation/subrange_bug6.lus" with | Error (`LustreAbstractInterpretationError (_, ConstantOutOfSubrange _)) -> true | _ -> false); -]) \ No newline at end of file +]) diff --git a/tests/regression/success/test_recursive_array_def.lus b/tests/regression/success/test_recursive_array_def.lus index a5b5134a0..49d5fd562 100644 --- a/tests/regression/success/test_recursive_array_def.lus +++ b/tests/regression/success/test_recursive_array_def.lus @@ -1,3 +1,17 @@ + +node sum (const n: int; A: int ^ n) returns (s: int); +var cumul: int ^ n; +let + cumul[i] = if i = 0 then A[0] else A[i] + cumul[i-1]; + s = cumul[n-1]; +tel + +node slice (const n: int; A: int ^ n; const low: int; const up: int) +returns (B : int ^ (up-low)); +let + B[i] = A[low + i]; +tel + node a(A: int^5) returns (B: int^5) let B[i] = A[i] + 1;