diff --git a/src/base/algodiff/owl_algodiff_generic.ml b/src/base/algodiff/owl_algodiff_generic.ml index f812e9592..94abe180f 100644 --- a/src/base/algodiff/owl_algodiff_generic.ml +++ b/src/base/algodiff/owl_algodiff_generic.ml @@ -48,6 +48,7 @@ module Make (A : Owl_types_ndarray_algodiff.Sig) = struct (* derivative of f (scalar -> scalar) at x, forward ad *) let diff' f x = + if not (is_float x) then failwith "input of `diff` must be a scalar"; let x = make_forward x (pack_flt 1.) (tag ()) in let y = f x in primal y, tangent y @@ -60,6 +61,7 @@ module Make (A : Owl_types_ndarray_algodiff.Sig) = struct let grad' f x = let x = make_reverse x (tag ()) in let y = f x in + if not (is_float y) then failwith "output of `grad` must be a scalar"; Reverse.reverse_reset y; Reverse.reverse_push (pack_flt 1.) y; primal y, x |> adjval diff --git a/src/base/algodiff/owl_algodiff_ops.ml b/src/base/algodiff/owl_algodiff_ops.ml index 0247d0ec6..4d5ade44d 100644 --- a/src/base/algodiff/owl_algodiff_ops.ml +++ b/src/base/algodiff/owl_algodiff_ops.ml @@ -503,6 +503,25 @@ module Make (Core : Owl_algodiff_core_sig.Sig) = struct and get_slice i = Lazy.force _get_slice i + and _get_fancy = + lazy + (fun i -> + build_siso + (module struct + let label = "get_fancy" + + let ff_f a = error_uniop label (pack_elt a) + + let ff_arr a = Arr A.(get_fancy i a) + + let df _cp _ap at = get_fancy i at + + let dr a _cp ca = set_fancy i (zero a) !ca + end : Siso)) + + + and get_fancy i = Lazy.force _get_fancy i + and _sum' = lazy (build_siso @@ -1303,6 +1322,41 @@ module Make (Core : Owl_algodiff_core_sig.Sig) = struct and set_slice i = Lazy.force _set_slice i + and _set_fancy = + lazy + (fun i -> + build_piso + (module struct + let label = "set_fancy" + + let ff_aa a _b = error_uniop label (pack_elt a) + + let ff_ab a _b = error_uniop label (pack_elt a) + + let ff_ba _a b = error_uniop label (pack_elt b) + + let ff_bb a b = + let a = A.copy a in + A.(set_fancy i a b); + Arr a + + + let df_da _cp _ap at bp = set_fancy i at (zero bp) + + let df_db _cp ap _bp bt = set_fancy i (zero ap) bt + + let df_dab _cp _ap at _bp bt = set_fancy i at bt + + let dr_ab _a b _cp ca = set_fancy i !ca (zero b), get_fancy i !ca + + let dr_a _a b _cp ca = set_fancy i !ca (zero b) + + let dr_b _a _b _cp ca = get_fancy i !ca + end : Piso)) + + + and set_fancy i = Lazy.force _set_fancy i + and ( *@ ) a b = dot a b and _dot = @@ -1569,9 +1623,7 @@ module Make (Core : Owl_algodiff_core_sig.Sig) = struct let ff_arr a = A.(split ~axis parts a) |> Array.map (fun x -> Arr x) - let df _cp _ap _at = - raise (Owl_exception.NOT_IMPLEMENTED "owl_algodiff_ops.split") - + let df _cp _ap at = split ~axis parts at let dr _a _cp _cp_ref_arr ca_ref_arr = concatenate ~axis (Array.map (fun ca -> !ca) ca_ref_arr) diff --git a/src/base/algodiff/owl_algodiff_ops_builder.ml b/src/base/algodiff/owl_algodiff_ops_builder.ml index 0ceb2d09a..9f9d3a717 100644 --- a/src/base/algodiff/owl_algodiff_ops_builder.ml +++ b/src/base/algodiff/owl_algodiff_ops_builder.ml @@ -195,7 +195,7 @@ module Make (Core : Owl_algodiff_core_sig.Sig) = struct val ff_arr : A.arr -> t array - val df : t -> t -> t -> t + val df : t array -> t -> t -> t array val dr : t -> t -> t ref array -> t ref array -> t end @@ -206,7 +206,8 @@ module Make (Core : Owl_algodiff_core_sig.Sig) = struct match a with | DF (ap, at, ai) -> let cp_arr = fd ap in - Array.map (fun cp -> DF (cp, df cp ap at, ai)) cp_arr + let ct_arr = df cp_arr ap at in + Array.map2 (fun cp ct -> DF (cp, ct, ai)) cp_arr ct_arr | DR (ap, _, _, _, ai, _) -> let cp_arr = fd ap in let cp_arr_ref = Array.map (fun cp -> ref cp) cp_arr in @@ -399,34 +400,34 @@ module Make (Core : Owl_algodiff_core_sig.Sig) = struct | `normal, DR (_, _, _, _, t', _) -> succ i, t', `reverse, [ i ] | `forward, DR (_, _, _, _, t', _) -> if t' > t - then succ i, t', `reverse, [] + then succ i, t', `reverse, [ i ] else if t' = t then failwith "error: forward and reverse clash on the same level" else succ i, t, `forward, idxs | `reverse, DR (_, _, _, _, t', _) -> if t' > t - then succ i, t', `reverse, [] + then succ i, t', `reverse, [ i ] else if t' = t then succ i, t', `reverse, i :: idxs else succ i, t, m, idxs | `normal, DF (_, _, t') -> succ i, t', `forward, [ i ] | `forward, DF (_, _, t') -> if t' > t - then succ i, t', `forward, [] + then succ i, t', `forward, [ i ] else if t' = t then succ i, t', `forward, i :: idxs else succ i, t, `forward, idxs | `reverse, DF (_, _, t') -> if t' > t - then succ i, t', `forward, [] + then succ i, t', `forward, [ i ] else if t' = t then failwith "error: forward and reverse clash on the same level" else succ i, t, `reverse, idxs) - (0, -10000, `normal, []) + (0, -50000, `normal, []) in fun (module S : Aiso) -> let rec f a = - let _, t, mode, idxs = build_info a in + let _, max_t, mode, idxs = build_info a in let idxs = idxs |> List.rev in match mode with | `normal -> S.ff a @@ -435,7 +436,12 @@ module Make (Core : Owl_algodiff_core_sig.Sig) = struct Array.map (fun x -> match x with - | DF (p, _, t') -> if t = t' then p else x + | DF (p, _, t') -> + if max_t = t' + then p + else if t' > max_t + then failwith "no tags should be higher than max_t" + else x | x -> x) a in @@ -445,13 +451,18 @@ module Make (Core : Owl_algodiff_core_sig.Sig) = struct List.iter (fun k -> at.(k) <- tangent a.(k)) idxs; S.df idxs cp ap at in - DF (cp, at, t) + DF (cp, at, max_t) | `reverse -> let ap = Array.map (fun x -> match x with - | DR (p, _, _, _, t', _) -> if t = t' then p else x + | DR (p, _, _, _, t', _) -> + if max_t = t' + then p + else if t' > max_t + then failwith "no tags should be higher than max_t" + else x | x -> x) a in @@ -463,7 +474,7 @@ module Make (Core : Owl_algodiff_core_sig.Sig) = struct in let register t = List.fold_left (fun t i -> a.(i) :: t) t idxs in let label = S.label, List.(map (fun i -> a.(i)) idxs) in - DR (cp, ref (zero cp), (adjoint, register, label), ref 0, t, ref 0) + DR (cp, ref (zero cp), (adjoint, register, label), ref 0, max_t, ref 0) in f end diff --git a/src/base/algodiff/owl_algodiff_ops_builder_sig.ml b/src/base/algodiff/owl_algodiff_ops_builder_sig.ml index 07935f47b..368afd43f 100644 --- a/src/base/algodiff/owl_algodiff_ops_builder_sig.ml +++ b/src/base/algodiff/owl_algodiff_ops_builder_sig.ml @@ -64,7 +64,7 @@ module type Sig = sig val ff_arr : arr -> t array - val df : t -> t -> t -> t + val df : t array -> t -> t -> t array val dr : t -> t -> t ref array -> t ref array -> t end diff --git a/src/base/algodiff/owl_algodiff_ops_sig.ml b/src/base/algodiff/owl_algodiff_ops_sig.ml index b7b77af33..9507cfe0d 100644 --- a/src/base/algodiff/owl_algodiff_ops_sig.ml +++ b/src/base/algodiff/owl_algodiff_ops_sig.ml @@ -238,6 +238,12 @@ module type Sig = sig val set_slice : int list list -> t -> t -> t (** Refer to :doc:`owl_dense_ndarray_generic` *) + val get_fancy : index list -> t -> t + (** Refer to :doc:`owl_dense_ndarray_generic` *) + + val set_fancy : index list -> t -> t -> t + (** Refer to :doc:`owl_dense_ndarray_generic` *) + val diag : ?k:int -> t -> t (** Refer to :doc:`owl_dense_ndarray_generic` *) diff --git a/src/base/compute/owl_computation_cpu_init.ml b/src/base/compute/owl_computation_cpu_init.ml index 0c99c8d27..7d4a2fd8c 100644 --- a/src/base/compute/owl_computation_cpu_init.ml +++ b/src/base/compute/owl_computation_cpu_init.ml @@ -61,6 +61,8 @@ module Make (Graph : Owl_computation_graph_sig.Sig) = struct | Set _i -> split_01 p | GetSlice _slice -> split_00 p (* ? *) | SetSlice _slice -> split_00 p (* ? *) + | GetFancy _indices -> split_00 p (* ? *) + | SetFancy _indices -> split_00 p (* ? *) | Copy -> split_01 p | Reset -> split_01 p | Reshape _shape -> split_01 p diff --git a/src/base/compute/owl_computation_operator.ml b/src/base/compute/owl_computation_operator.ml index b8d3ea9cf..960a946c4 100644 --- a/src/base/compute/owl_computation_operator.ml +++ b/src/base/compute/owl_computation_operator.ml @@ -109,10 +109,14 @@ module Make (Symbol : Owl_computation_symbol_sig.Sig) = struct let get_slice slice x = make_then_connect (GetSlice slice) [| arr_to_node x |] |> node_to_arr - let set_slice slice x y = make_then_connect (SetSlice slice) [| arr_to_node x; arr_to_node y |] |> ignore + let get_fancy indices x = + make_then_connect (GetFancy indices) [| arr_to_node x |] |> node_to_arr + + let set_fancy indices x y = + make_then_connect (SetFancy indices) [| arr_to_node x; arr_to_node y |] |> ignore let copy x = make_then_connect Copy [| arr_to_node x |] |> node_to_arr diff --git a/src/base/compute/owl_computation_operator_sig.ml b/src/base/compute/owl_computation_operator_sig.ml index ccb91d346..2c5ea3dcf 100644 --- a/src/base/compute/owl_computation_operator_sig.ml +++ b/src/base/compute/owl_computation_operator_sig.ml @@ -65,6 +65,12 @@ module type Sig = sig val set_slice : int list list -> arr -> arr -> unit (** TODO *) + val get_fancy : index list -> arr -> arr + (** TODO *) + + val set_fancy : index list -> arr -> arr -> unit + (** TODO *) + val copy : arr -> arr (** TODO *) diff --git a/src/base/compute/owl_computation_symbol.ml b/src/base/compute/owl_computation_symbol.ml index a038b5f39..5187d782a 100644 --- a/src/base/compute/owl_computation_symbol.ml +++ b/src/base/compute/owl_computation_symbol.ml @@ -32,6 +32,8 @@ module Make (Shape : Owl_computation_shape_sig.Sig) = struct | Set _i -> "Set" | GetSlice _slice -> "GetSlice" | SetSlice _slice -> "SetSlice" + | GetFancy _ -> "GetFancy" + | SetFancy _ -> "SetFancy" | Copy -> "Copy" | Reset -> "Reset" | Reshape _shape -> "Reshape" diff --git a/src/base/compute/owl_computation_type.ml b/src/base/compute/owl_computation_type.ml index 13c026489..548a04e84 100644 --- a/src/base/compute/owl_computation_type.ml +++ b/src/base/compute/owl_computation_type.ml @@ -73,6 +73,8 @@ module Make (Device : Owl_types_computation_device.Sig) = struct | Set of int array | GetSlice of int list list | SetSlice of int list list + | GetFancy of index list + | SetFancy of index list | Copy | Reset | Reshape of int array diff --git a/src/base/compute/owl_computation_type_sig.ml b/src/base/compute/owl_computation_type_sig.ml index 709607237..3c6f11ed3 100644 --- a/src/base/compute/owl_computation_type_sig.ml +++ b/src/base/compute/owl_computation_type_sig.ml @@ -3,7 +3,7 @@ * Copyright (c) 2016-2020 Liang Wang *) -open Owl_types +open Owl_types_common (* Functor of making the symbols of a computation graph. *) @@ -75,6 +75,8 @@ module type Sig = sig | Set of int array | GetSlice of int list list | SetSlice of int list list + | GetFancy of index list + | SetFancy of index list | Copy | Reset | Reshape of int array diff --git a/src/base/dense/owl_base_dense_ndarray_generic.ml b/src/base/dense/owl_base_dense_ndarray_generic.ml index dad142366..312aface4 100644 --- a/src/base/dense/owl_base_dense_ndarray_generic.ml +++ b/src/base/dense/owl_base_dense_ndarray_generic.ml @@ -214,6 +214,14 @@ let set_slice index_list varr slice_varr = done +(*TODO: optimise, test *) +let get_fancy _indices _varr = raise (Owl_exception.NOT_IMPLEMENTED "base: get_fancy") + +(*TODO: optimise, test *) +let set_fancy _indices _target _input = + raise (Owl_exception.NOT_IMPLEMENTED "base: set_fancy") + + (* The result shares the underlying buffer with original, not a copy *) let reshape x d = let minus_one = Owl_utils.Array.count d (-1) in @@ -2619,11 +2627,10 @@ let conv2d ?(padding = SAME) input kernel stride = let in_col = (i * col_stride) + di - pad_left in let in_row = (j * row_stride) + dj - pad_top in let in_val = - if - 0 <= in_col - && in_col < input_cols - && 0 <= in_row - && in_row < input_rows + if 0 <= in_col + && in_col < input_cols + && 0 <= in_row + && in_row < input_rows then get input [| b; in_col; in_row; q |] else 0. in @@ -2792,13 +2799,12 @@ let conv3d ?(padding = SAME) input kernel stride = let in_row = (j * row_stride) + dj - pad_top in let in_dpt = (dpt * dpt_stride) + d_dpt - pad_shallow in let in_val = - if - 0 <= in_col - && in_col < input_cols - && 0 <= in_row - && in_row < input_rows - && 0 <= in_dpt - && in_dpt < input_dpts + if 0 <= in_col + && in_col < input_cols + && 0 <= in_row + && in_row < input_rows + && 0 <= in_dpt + && in_dpt < input_dpts then get input [| b; in_col; in_row; in_dpt; q |] else 0. in @@ -2981,13 +2987,12 @@ let _pool3d let in_col = (i * col_stride) + di - pad_left in let in_row = (j * row_stride) + dj - pad_top in let in_dpt = (dpt * dpt_stride) + d_dpt - pad_shallow in - if - 0 <= in_col - && in_col < input_cols - && 0 <= in_row - && in_row < input_rows - && 0 <= in_dpt - && in_dpt < input_dpts + if 0 <= in_col + && in_col < input_cols + && 0 <= in_row + && in_row < input_rows + && 0 <= in_dpt + && in_dpt < input_dpts then add_val_pool_fun (get input [| b; in_col; in_row; in_dpt; k |]) done (*d_dpt*) @@ -3219,17 +3224,15 @@ let conv2d_backward_input input kernel stride output' = let sum = ref 0. in for di = 0 to kernel_cols - 1 do for dj = 0 to kernel_rows - 1 do - if - Stdlib.( mod ) (in_i + pad_left - di) col_stride = 0 - && Stdlib.( mod ) (in_j + pad_top - dj) row_stride = 0 + if Stdlib.( mod ) (in_i + pad_left - di) col_stride = 0 + && Stdlib.( mod ) (in_j + pad_top - dj) row_stride = 0 then ( let out_col = (in_i + pad_left - di) / col_stride in let out_row = (in_j + pad_top - dj) / row_stride in - if - 0 <= out_col - && out_col < output_cols - && 0 <= out_row - && out_row < output_rows + if 0 <= out_col + && out_col < output_cols + && 0 <= out_row + && out_row < output_rows then for k = 0 to out_channel - 1 do let out_grad = get output' [| b; out_col; out_row; k |] in @@ -3344,8 +3347,10 @@ let conv2d_backward_kernel input kernel stride output' = for j = 0 to output_rows - 1 do let in_col = (i * col_stride) + di - pad_left in let in_row = (j * row_stride) + dj - pad_top in - if - 0 <= in_col && in_col < input_cols && 0 <= in_row && in_row < input_rows + if 0 <= in_col + && in_col < input_cols + && 0 <= in_row + && in_row < input_rows then ( let out_grad = get output' [| b; i; j; k |] in let input_val = get input [| b; in_col; in_row; q |] in @@ -4006,21 +4011,19 @@ let conv3d_backward_input input kernel stride output' = for di = 0 to kernel_cols - 1 do for dj = 0 to kernel_rows - 1 do for d_dpt = 0 to kernel_dpts - 1 do - if - Stdlib.( mod ) (in_i + pad_left - di) col_stride = 0 - && Stdlib.( mod ) (in_j + pad_top - dj) row_stride = 0 - && Stdlib.( mod ) (in_dpt + pad_shallow - d_dpt) dpt_stride = 0 + if Stdlib.( mod ) (in_i + pad_left - di) col_stride = 0 + && Stdlib.( mod ) (in_j + pad_top - dj) row_stride = 0 + && Stdlib.( mod ) (in_dpt + pad_shallow - d_dpt) dpt_stride = 0 then ( let out_col = (in_i + pad_left - di) / col_stride in let out_row = (in_j + pad_top - dj) / row_stride in let out_dpt = (in_dpt + pad_shallow - d_dpt) / dpt_stride in - if - 0 <= out_col - && out_col < output_cols - && 0 <= out_row - && out_row < output_rows - && 0 <= out_dpt - && out_dpt < output_dpts + if 0 <= out_col + && out_col < output_cols + && 0 <= out_row + && out_row < output_rows + && 0 <= out_dpt + && out_dpt < output_dpts then for k = 0 to out_channel - 1 do let out_grad = @@ -4152,13 +4155,12 @@ let conv3d_backward_kernel input kernel stride output' = let in_col = (i * col_stride) + di - pad_left in let in_row = (j * row_stride) + dj - pad_top in let in_dpt = (dpt * dpt_stride) + d_dpt - pad_shallow in - if - 0 <= in_col - && in_col < input_cols - && 0 <= in_row - && in_row < input_rows - && 0 <= in_dpt - && in_dpt < input_dpts + if 0 <= in_col + && in_col < input_cols + && 0 <= in_row + && in_row < input_rows + && 0 <= in_dpt + && in_dpt < input_dpts then ( let out_grad = get output' [| b; i; j; dpt; k |] in let input_val = get input [| b; in_col; in_row; in_dpt; q |] in @@ -4342,10 +4344,9 @@ let transpose_conv3d_backward_input input kernel stride output' = dpt_stride in let p = - if - output_cols_same = output_cols - && output_rows_same = output_rows - && output_dpts_same = output_dpts + if output_cols_same = output_cols + && output_rows_same = output_rows + && output_dpts_same = output_dpts then SAME else VALID in @@ -4611,13 +4612,12 @@ let _pool3d_backward let in_col = (i * col_stride) + di - pad_left in let in_row = (j * row_stride) + dj - pad_top in let in_dpt = (dpt * dpt_stride) + dk - pad_shallow in - if - 0 <= in_col - && in_col < input_cols - && 0 <= in_row - && in_row < input_rows - && 0 <= in_dpt - && in_dpt < input_dpts + if 0 <= in_col + && in_col < input_cols + && 0 <= in_row + && in_row < input_rows + && 0 <= in_dpt + && in_dpt < input_dpts then add_val_pool_fun (get input [| b; in_col; in_row; in_dpt; k |]) done (*dk*) @@ -4633,13 +4633,12 @@ let _pool3d_backward let in_col = (i * col_stride) + di - pad_left in let in_row = (j * row_stride) + dj - pad_top in let in_dpt = (dpt * dpt_stride) + dk - pad_shallow in - if - 0 <= in_col - && in_col < input_cols - && 0 <= in_row - && in_row < input_rows - && 0 <= in_dpt - && in_dpt < input_dpts + if 0 <= in_col + && in_col < input_cols + && 0 <= in_row + && in_row < input_rows + && 0 <= in_dpt + && in_dpt < input_dpts then ( let input_val = get input [| b; in_col; in_row; in_dpt; k |] in let input_grad = get input' [| b; in_col; in_row; in_dpt; k |] in diff --git a/src/base/dense/owl_base_dense_ndarray_generic.mli b/src/base/dense/owl_base_dense_ndarray_generic.mli index b3306e411..e5931b814 100644 --- a/src/base/dense/owl_base_dense_ndarray_generic.mli +++ b/src/base/dense/owl_base_dense_ndarray_generic.mli @@ -101,6 +101,12 @@ val get_slice : int list list -> ('a, 'b) t -> ('a, 'b) t val set_slice : int list list -> ('a, 'b) t -> ('a, 'b) t -> unit (** Refer to :doc:`owl_dense_ndarray_generic` *) +val get_fancy : index list -> ('a, 'b) t -> ('a, 'b) t +(** Refer to :doc:`owl_dense_ndarray_generic` *) + +val set_fancy : index list -> ('a, 'b) t -> ('a, 'b) t -> unit +(** Refer to :doc:`owl_dense_ndarray_generic` *) + val reset : ('a, 'b) t -> unit (** Refer to :doc:`owl_dense_ndarray_generic` *) diff --git a/src/base/dense/owl_base_dense_ndarray_intf.ml b/src/base/dense/owl_base_dense_ndarray_intf.ml index 5c8c2c81e..a3657a996 100644 --- a/src/base/dense/owl_base_dense_ndarray_intf.ml +++ b/src/base/dense/owl_base_dense_ndarray_intf.ml @@ -54,6 +54,10 @@ module type Common = sig val set_slice : int list list -> arr -> arr -> unit + val get_fancy : index list -> arr -> arr + + val set_fancy : index list -> arr -> arr -> unit + val copy : arr -> arr val copy_ : out:arr -> arr -> unit diff --git a/src/base/types/owl_types_ndarray_basic.ml b/src/base/types/owl_types_ndarray_basic.ml index 3680eb23b..055f35e92 100644 --- a/src/base/types/owl_types_ndarray_basic.ml +++ b/src/base/types/owl_types_ndarray_basic.ml @@ -48,6 +48,10 @@ module type Sig = sig val set_slice : int list list -> arr -> arr -> unit + val get_fancy : index list -> arr -> arr + + val set_fancy : index list -> arr -> arr -> unit + val copy : arr -> arr val copy_ : out:arr -> arr -> unit (* FIXME: move to mutable? *) diff --git a/test/unit_algodiff_matrix_generic.ml b/test/unit_algodiff_matrix_generic.ml index f8a2d30d5..49f30f5e2 100644 --- a/test/unit_algodiff_matrix_generic.ml +++ b/test/unit_algodiff_matrix_generic.ml @@ -91,6 +91,24 @@ module Make (M : Ndarray_Algodiff with type elt = float) = struct let l2norm_sqr' () = test_func Maths.l2norm_sqr' + let get_slice () = + let f x = + let y1 = Maths.get_slice [ [ 0 ] ] x in + let y2 = Maths.get_slice [ [ 2 ] ] x in + Maths.(sum' (sin (y1 * y2))) + in + test_func f + + + let get_fancy () = + let f x = + let y1 = Maths.get_fancy [ R [ 0; 1 ]; L [ 0; 2 ] ] x in + let y2 = Maths.get_fancy [ R [ 1; 2 ]; L [ 0; 1 ] ] x in + Maths.(sum' (tan (y1 * y2))) + in + test_func f + + let tril () = test_func Maths.tril let triu () = test_func Maths.triu @@ -172,7 +190,7 @@ module Make (M : Ndarray_Algodiff with type elt = float) = struct let f = let y1 = Mat.gaussian 10 n in let y2 = Mat.gaussian 15 n in - let h x = Maths.(y1 *@ x) in + let h x = Maths.(sum' (y1 *@ x)) in let h' = grad h in fun x -> let y = Maths.concatenate ~axis:0 [| y1; x; y2; h' x |] in @@ -185,7 +203,7 @@ module Make (M : Ndarray_Algodiff with type elt = float) = struct let f = let y1 = Mat.gaussian n n in let y2 = Mat.gaussian n n in - let h x = Maths.(y1 *@ x) in + let h x = Maths.(sum' (y1 *@ x)) in let h' = grad h in fun x -> Maths.stack ~axis:(-1) [| y1; x; y2; h' x |] in @@ -316,7 +334,7 @@ module Make (M : Ndarray_Algodiff with type elt = float) = struct test_func f - let nested_grad1 () = + let nested_grad1 = let x = Mat.gaussian 1 (n * n) in let r ~theta x = Maths.(sum' (sqr x *@ transpose (theta * theta))) in let quad ~theta x = @@ -328,7 +346,25 @@ module Make (M : Ndarray_Algodiff with type elt = float) = struct let theta = Arr.reshape theta [| 1; n * n |] in test_theta x theta in - test_func f + fun () -> test_func f + + + let nested_grad2 = + (* check aiao build_info: when inputs include DFs and DRs at different levels *) + let ff = + let z = Mat.gaussian n n in + let hfwd = Mat.gaussian n n in + let hrev = Mat.gaussian (3 * n) n in + fun x -> + let zfwd = make_forward z hfwd (tag ()) in + let zrev = make_reverse z (tag ()) in + let f1 = Maths.(concatenate ~axis:0 [| x; zfwd + sqr x; zfwd |]) |> Maths.sin in + let f2 = Maths.(concatenate ~axis:0 [| x; zrev + sqr x; zrev |]) |> Maths.sin in + let df = tangent f1 in + reverse_prop hrev f2; + Maths.(sum' df + sum' (adjval zrev)) + in + fun () -> test_func ff let test = @@ -337,15 +373,16 @@ module Make (M : Ndarray_Algodiff with type elt = float) = struct ; "cos", cos; "tan", tan; "sinh", sinh; "cosh", cosh; "tanh", tanh ; "sigmoid", sigmoid; "relu", relu; "dawsn", dawsn; "exp", exp ; "transpose", transpose; "diag", diag; "diagm", diagm; "trace", trace - ; "l1norm'", l1norm'; "l2norm'", l2norm'; "l2norm_sqr'", l2norm_sqr'; "tril", tril - ; "triu", triu; "inv", inv; "logdet", logdet; "chol", chol; "qr", qr; "lq", lq - ; "split", split; "concat", concat; "concatenate", concatenate; "stack", stack - ; "svd", svd; "of_arrays", of_arrays; "to_arrays", to_arrays; "init_2d", init_2d + ; "l1norm'", l1norm'; "l2norm'", l2norm'; "l2norm_sqr'", l2norm_sqr' + ; "get_slice", get_slice; "get_fancy", get_fancy; "tril", tril; "triu", triu + ; "inv", inv; "logdet", logdet; "chol", chol; "qr", qr; "lq", lq; "split", split + ; "concat", concat; "concatenate", concatenate; "stack", stack; "svd", svd + ; "of_arrays", of_arrays; "to_arrays", to_arrays; "init_2d", init_2d ; "sylvester", sylvester; "lyapunov", lyapunov ; "discrete_lyapunov", discrete_lyapunov; "linsolve", linsolve ; "linsolve_triangular", linsolve_triangular; "care", care ; "log_sum_exp'", log_sum_exp'; "log_sum_exp", log_sum_exp - ; "nested_grad1", nested_grad1 ] + ; "nested_grad1", nested_grad1; "nested_grad2", nested_grad2 ] |> List.fold_left (fun (b, error_msg) (s, f) -> let b', c =