diff --git a/src/lustre/lustreAbstractInterpretation.ml b/src/lustre/lustreAbstractInterpretation.ml index d07895654..d51b51b08 100644 --- a/src/lustre/lustreAbstractInterpretation.ml +++ b/src/lustre/lustreAbstractInterpretation.ml @@ -458,6 +458,7 @@ and interpret_structured_expr f node_id ctx ty_ctx ty proj expr = let g = interpret_structured_expr f node_id ctx ty_ctx ty in Ctx.traverse_group_expr_list g ty_ctx proj es ) + | StructUpdate (_, e, _, _) -> interpret_expr_by_type node_id ctx ty_ctx ty proj e | _ -> assert false) and interpret_int_expr node_id ctx ty_ctx proj expr = diff --git a/src/lustre/lustreAstHelpers.ml b/src/lustre/lustreAstHelpers.ml index a285a2c9f..8ba9d2241 100644 --- a/src/lustre/lustreAstHelpers.ml +++ b/src/lustre/lustreAstHelpers.ml @@ -1649,3 +1649,8 @@ let name_of_prop pos name k = in Format.asprintf "%sProp%a" kind_str Lib.pp_print_line_and_column pos |> HString.mk_hstring + +let get_const_num_value = function + | Const (_, Num x) -> + int_of_string_opt (HString.string_of_hstring x) + | _ -> None \ No newline at end of file diff --git a/src/lustre/lustreAstHelpers.mli b/src/lustre/lustreAstHelpers.mli index 9b9eaaa49..ce9c0eac4 100644 --- a/src/lustre/lustreAstHelpers.mli +++ b/src/lustre/lustreAstHelpers.mli @@ -186,3 +186,5 @@ val rename_contract_vars : expr -> expr val name_of_prop : Lib.position -> HString.t option -> LustreAst.prop_kind -> HString.t (** Get the name associated with a property *) + +val get_const_num_value : expr -> int option \ No newline at end of file diff --git a/src/lustre/lustreTypeChecker.ml b/src/lustre/lustreTypeChecker.ml index 89d9a8444..c4cb44760 100644 --- a/src/lustre/lustreTypeChecker.ml +++ b/src/lustre/lustreTypeChecker.ml @@ -57,8 +57,10 @@ type error_kind = Unknown of string | Unsupported of string | UnequalArrayExpressionType | TypeMismatchOfRecordLabel of HString.t * tc_type * tc_type - | IlltypedRecordUpdate of tc_type + | IlltypedUpdateWithLabel of tc_type + | IlltypedUpdateWithIndex of tc_type | ExpectedLabel of LA.expr + | ExpectedIntegerLiteral of LA.expr | IlltypedArraySlice of tc_type | ExpectedIntegerTypeForSlice | IlltypedArrayIndex of tc_type @@ -127,8 +129,10 @@ let error_message kind = match kind with | UnequalArrayExpressionType -> "All expressions must be of the same type in an Array" | TypeMismatchOfRecordLabel (label, ty1, ty2) -> "Type mismatch. Type of record label '" ^ (HString.string_of_hstring label) ^ "' is of type " ^ string_of_tc_type ty1 ^ " but the type of the expression is " ^ string_of_tc_type ty2 - | IlltypedRecordUpdate ty -> "Cannot do an update on non-record type " ^ string_of_tc_type ty + | IlltypedUpdateWithLabel ty -> "Expected a record type but found " ^ string_of_tc_type ty + | IlltypedUpdateWithIndex ty -> "Expected a tuple type or an array type but found " ^ string_of_tc_type ty | ExpectedLabel e -> "Only labels can be used for record expressions but found " ^ LA.string_of_expr e + | ExpectedIntegerLiteral e -> "Expected an integer literal but found " ^ LA.string_of_expr e | IlltypedArraySlice ty -> "Slicing can only be done on an array type but found " ^ string_of_tc_type ty | ExpectedIntegerTypeForSlice -> "Slicing should have integer types" | IlltypedArrayIndex ty -> "Indexing can only be done on an array type but found " ^ string_of_tc_type ty @@ -759,13 +763,13 @@ let rec infer_type_expr: tc_context -> HString.t option -> LA.expr -> (tc_type * check_array_size_expr ctx nname sup_expr >> R.ok (LA.ArrayType (pos, (b_ty, sup_expr)), warnings) ) - | LA.StructUpdate (pos, r, i_or_ls, e) -> + | LA.StructUpdate (pos, ue, i_or_ls, e) -> if List.length i_or_ls != 1 then type_error pos (Unsupported ("List of labels or indices for structure update is not supported")) else (match List.hd i_or_ls with | LA.Label (pos, l) -> - infer_type_expr ctx nname r + infer_type_expr ctx nname ue >>= (function | RecordType (_, _, flds) as r_ty, warnings1 -> (let typed_fields = List.map (fun (_, i, ty) -> (i, ty)) flds in @@ -776,8 +780,34 @@ let rec infer_type_expr: tc_context -> HString.t option -> LA.expr -> (tc_type * (R.ok (r_ty, warnings1 @ warnings2)) (type_error pos (TypeMismatchOfRecordLabel (l, f_ty, e_ty))) | None -> type_error pos (NotAFieldOfRecord l))) - | r_ty, _ -> type_error pos (IlltypedRecordUpdate r_ty)) - | LA.Index (_, e) -> type_error pos (ExpectedLabel e)) + | r_ty, _ -> type_error pos (IlltypedUpdateWithLabel r_ty)) + | LA.Index (pos, i) -> + let* ue_ty, warnings1 = infer_type_expr ctx nname ue in + (match ue_ty with + | TupleType _ -> ( + let* idx = + match LH.get_const_num_value i with + | Some n -> Ok n + | None -> type_error pos (ExpectedIntegerLiteral i) + in + let* e_ty, warnings2 = infer_type_expr ctx nname e in + let* warnings3 = check_type_tuple_proj pos ctx nname ue idx e_ty in + R.ok (ue_ty, warnings1 @ warnings2 @ warnings3) + ) + | ArrayType (_, (b_ty, _)) -> ( + let* index_type, warnings1 = infer_type_expr ctx nname i in + let* index_type = expand_type_syn_reftype_history ctx index_type in + if is_expr_int_type ctx nname i then + let* e_ty, warnings2 = infer_type_expr ctx nname e in + R.ifM (eq_lustre_type ctx b_ty e_ty) + (R.ok (ue_ty, warnings1 @ warnings2)) + (type_error pos (ExpectedType (e_ty, b_ty))) + else + type_error pos (ExpectedIntegerTypeForArrayIndex index_type) + ) + | _ -> type_error pos (IlltypedUpdateWithIndex ue_ty) + ) + ) | LA.ArrayIndex (pos, e, i) -> let* index_type, warnings1 = infer_type_expr ctx nname i in let* index_type = expand_type_syn_reftype_history ctx index_type in @@ -976,22 +1006,46 @@ and check_type_expr: tc_context -> HString.t option -> LA.expr -> tc_type -> ([> (type_error pos UnequalArrayExpressionType)) (* Update of structured expressions *) - | StructUpdate (pos, r, i_or_ls, e) -> + | StructUpdate (pos, ue, i_or_ls, e) -> if List.length i_or_ls != 1 then type_error pos (Unsupported ("List of labels or indices for structure update is not supported")) - else (match List.hd i_or_ls with - | LA.Label (pos, l) -> - let* r_ty, warnings1 = infer_type_expr ctx nname r in ( - match r_ty with - | RecordType (_, _, flds) -> - (let typed_fields = List.map (fun (_, i, ty) -> (i, ty)) flds in - (match (List.assoc_opt l typed_fields) with - | Some ty -> - let* warnings2 = check_type_expr ctx nname e ty in - R.ok (warnings1 @ warnings2) - | None -> type_error pos (NotAFieldOfRecord l))) - | _ -> type_error pos (IlltypedRecordUpdate r_ty)) - | LA.Index (_, e) -> type_error pos (ExpectedLabel e)) + else + (match List.hd i_or_ls with + | LA.Label (pos, l) -> + let* r_ty, warnings1 = infer_type_expr ctx nname ue in ( + match r_ty with + | RecordType (_, _, flds) -> + (let typed_fields = List.map (fun (_, i, ty) -> (i, ty)) flds in + (match (List.assoc_opt l typed_fields) with + | Some ty -> + let* warnings2 = check_type_expr ctx nname e ty in + R.ok (warnings1 @ warnings2) + | None -> type_error pos (NotAFieldOfRecord l))) + | _ -> type_error pos (IlltypedUpdateWithLabel r_ty)) + | LA.Index (_, i) -> + let* ue_ty, warnings1 = infer_type_expr ctx nname ue in + (match ue_ty with + | TupleType _ -> ( + let* idx = + match LH.get_const_num_value i with + | Some n -> Ok n + | None -> type_error pos (ExpectedIntegerLiteral i) + in + let* e_ty, warnings2 = infer_type_expr ctx nname e in + let* warnings3 = check_type_tuple_proj pos ctx nname ue idx e_ty in + R.ok (warnings1 @ warnings2 @ warnings3) + ) + | ArrayType (_, (b_ty, _)) -> ( + let* index_type, warnings1 = infer_type_expr ctx nname i in + let* index_type = expand_type_syn_reftype_history ctx index_type in + if is_expr_int_type ctx nname i then + let* warnings2 = check_type_expr ctx nname e b_ty in + R.ok (warnings1 @ warnings2) + else + type_error pos (ExpectedIntegerTypeForArrayIndex index_type) + ) + | _ -> type_error pos (IlltypedUpdateWithIndex ue_ty) + )) (* Array constructor*) | ArrayConstr (pos, b_exp, sup_exp) -> diff --git a/src/lustre/lustreTypeChecker.mli b/src/lustre/lustreTypeChecker.mli index 436b1fa0c..770ec52eb 100644 --- a/src/lustre/lustreTypeChecker.mli +++ b/src/lustre/lustreTypeChecker.mli @@ -42,8 +42,10 @@ type error_kind = Unknown of string | Unsupported of string | UnequalArrayExpressionType | TypeMismatchOfRecordLabel of HString.t * tc_type * tc_type - | IlltypedRecordUpdate of tc_type + | IlltypedUpdateWithLabel of tc_type + | IlltypedUpdateWithIndex of tc_type | ExpectedLabel of LA.expr + | ExpectedIntegerLiteral of LA.expr | IlltypedArraySlice of tc_type | ExpectedIntegerTypeForSlice | IlltypedArrayIndex of tc_type diff --git a/tests/regression/success/struct_update.lus b/tests/regression/success/struct_update.lus new file mode 100644 index 000000000..9f8e7fc55 --- /dev/null +++ b/tests/regression/success/struct_update.lus @@ -0,0 +1,10 @@ + +type MyPair = [int, bool]; + +node N(A:int^3; p1: MyPair) returns (B: int^3; p2: MyPair); +let + B = (A with [0] = 1); + p2 = (p1 with .%1 = true); + check B[0] = 1; + check p2.%1; +tel \ No newline at end of file