diff --git a/src/lustre/lustreTypeChecker.ml b/src/lustre/lustreTypeChecker.ml index 73acccfac..842e0db5a 100644 --- a/src/lustre/lustreTypeChecker.ml +++ b/src/lustre/lustreTypeChecker.ml @@ -839,9 +839,14 @@ let rec infer_type_expr: tc_context -> HString.t option -> LA.expr -> (tc_type * (* Quantified expressions *) | LA.Quantifier (_, _, qs, e) -> + let* warnings1 = + R.seq (List.map (fun (_, _, ty) -> + check_type_well_formed ctx Local nname true ty) qs) + in let extn_ctx = List.fold_left union ctx (List.map (fun (_, i, ty) -> singleton_ty i ty) qs) in - infer_type_expr extn_ctx nname e + let* ty, warnings2 = infer_type_expr extn_ctx nname e in + R.ok (ty, List.flatten warnings1 @ warnings2) | AnyOp _ -> assert false (* Already desugared in lustreDesugarAnyOps *) @@ -1087,7 +1092,11 @@ and check_type_expr: tc_context -> HString.t option -> LA.expr -> tc_type -> ([> else type_error pos (ExpectedIntegerTypeForArrayIndex index_type) (* Quantified expressions *) - | Quantifier (_, _, qs, e) -> + | Quantifier (_, _, qs, e) -> ( + let* warnings1 = + R.seq (List.map (fun (_, _, ty) -> + check_type_well_formed ctx Local nname true ty) qs) + in (* Disallow quantification over abstract types *) let* _ = R.seq_ (List.map (fun (pos, id, ty) -> if type_contains_abstract ctx ty @@ -1098,8 +1107,9 @@ and check_type_expr: tc_context -> HString.t option -> LA.expr -> tc_type -> ([> ) qs) in let extn_ctx = List.fold_left union ctx (List.map (fun (_, i, ty) -> singleton_ty i ty) qs) in - check_type_expr extn_ctx nname e exp_ty - + let* warnings2 = check_type_expr extn_ctx nname e exp_ty in + R.ok (List.flatten warnings1 @ warnings2) + ) | AnyOp _ -> assert false (* Already desugared in lustreDesugarAnyOps *) (*let extn_ctx = union ctx (singleton_ty i ty) in diff --git a/tests/ounit/lustre/lustreTypeChecker/bad_bound_var_type.lus b/tests/ounit/lustre/lustreTypeChecker/bad_bound_var_type.lus new file mode 100644 index 000000000..230c9592f --- /dev/null +++ b/tests/ounit/lustre/lustreTypeChecker/bad_bound_var_type.lus @@ -0,0 +1,5 @@ + +node N() returns (ok: bool); +let + check forall (x: subrange [0.0,5.0] of int) x>=0; +tel diff --git a/tests/ounit/lustre/testLustreFrontend.ml b/tests/ounit/lustre/testLustreFrontend.ml index c4fa89d0b..13b34b1bc 100644 --- a/tests/ounit/lustre/testLustreFrontend.ml +++ b/tests/ounit/lustre/testLustreFrontend.ml @@ -571,6 +571,10 @@ let _ = run_test_tt_main ("frontend LustreTypeChecker error tests" >::: [ match load_file "./lustreTypeChecker/type-grammer.lus" with | Error (`LustreTypeCheckerError (_, ExpectedIntegerExpression _)) -> true | _ -> false); + mk_test "test invalid type for bound variable" (fun () -> + match load_file "./lustreTypeChecker/bad_bound_var_type.lus" with + | Error (`LustreTypeCheckerError (_, ExpectedIntegerExpression _)) -> true + | _ -> false); mk_test "test invalid expression for array size 1" (fun () -> match load_file "./lustreTypeChecker/node_call_in_array_size_expr.lus" with | Error (`LustreTypeCheckerError (_, ExpectedConstant _)) -> true