From 9310a35e1eded121726cfb6c2fbdf5db3b97622c Mon Sep 17 00:00:00 2001 From: Marcello Seri Date: Tue, 15 Sep 2020 17:34:49 +0200 Subject: [PATCH] Fix access (#543) * owl_operator: add also accessor operators To support call with single index Signed-off-by: Marcello Seri * Add minimal tests Signed-off-by: Marcello Seri * Matrix access operators: make more robust and add tuple-based get/set Signed-off-by: Marcello Seri * improve error message Signed-off-by: Marcello Seri --- src/base/compute/owl_computation_optimiser.ml | 6 ++-- src/base/core/owl_operator.ml | 32 +++++++++++++++++-- src/base/core/owl_operator.mli | 12 +++++++ test/unit_slicing_fancy.ml | 25 ++++++++++++++- 4 files changed, 69 insertions(+), 6 deletions(-) diff --git a/src/base/compute/owl_computation_optimiser.ml b/src/base/compute/owl_computation_optimiser.ml index f4d263eed..50cd60edd 100644 --- a/src/base/compute/owl_computation_optimiser.ml +++ b/src/base/compute/owl_computation_optimiser.ml @@ -74,9 +74,9 @@ module Make (Operator : Owl_computation_operator_sig.Sig) = struct | Asinh -> pattern_000 x | Acosh -> pattern_000 x | Atanh -> pattern_000 x - | Min (_keep_dims, _axis) -> pattern_000 x - | Max (_keep_dims, _axis) -> pattern_000 x - | Sum (_keep_dims, _axis) -> pattern_000 x + | Min (_keep_dims, _axis) -> pattern_000 x + | Max (_keep_dims, _axis) -> pattern_000 x + | Sum (_keep_dims, _axis) -> pattern_000 x | SumReduce _axis -> pattern_024 x | Signum -> pattern_000 x | Sigmoid -> pattern_000 x diff --git a/src/base/core/owl_operator.ml b/src/base/core/owl_operator.ml index 959526f04..e676324a3 100644 --- a/src/base/core/owl_operator.ml +++ b/src/base/core/owl_operator.ml @@ -138,17 +138,41 @@ module Make_Extend (M : ExtendSig) = struct let ( .!{;..}<- ) x s = M.set_fancy_ext s x + let ( .${} ) x s = M.get_slice_ext [| s |] x + let ( .${;..} ) x s = M.get_slice_ext s x + let ( .${}<- ) x s = M.set_slice_ext [| s |] x + let ( .${;..}<- ) x s = M.set_slice_ext s x end [@warning "-34"] module Make_Matrix (M : MatrixSig) = struct type ('a, 'b) op_t2 = ('a, 'b) M.t - let ( .%{;..} ) x i = M.get x i.(0) i.(1) + let ( .%{} ) x (i1, i2) = M.get x i1 i2 + + let ( .%{;..} ) x i = + if Array.length i = 2 + then M.get x i.(0) i.(1) + else + failwith + (".%{} on matrices requires exactly two indices but I got " + ^ string_of_int + @@ Array.length i) + + + let ( .%{}<- ) x (i1, i2) = M.set x i1 i2 + + let ( .%{;..}<- ) x i = + if Array.length i = 2 + then M.set x i.(0) i.(1) + else + failwith + (".%{}<- on matrices requires exactly two indices but I got " + ^ string_of_int + @@ Array.length i) - let ( .%{;..}<- ) x i = M.set x i.(0) i.(1) let ( *@ ) a b = M.dot a b end [@warning "-34"] @@ -156,8 +180,12 @@ end [@warning "-34"] module Make_Ndarray (M : NdarraySig) = struct type ('a, 'b) op_t3 = ('a, 'b) M.t + let ( .%{} ) x i = M.get x [| i |] + let ( .%{;..} ) x i = M.get x i + let ( .%{}<- ) x i = M.set x [| i |] + let ( .%{;..}<- ) x i = M.set x i end [@warning "-34"] diff --git a/src/base/core/owl_operator.mli b/src/base/core/owl_operator.mli index cc00c10b3..a19bcab28 100644 --- a/src/base/core/owl_operator.mli +++ b/src/base/core/owl_operator.mli @@ -195,9 +195,13 @@ module Make_Extend (M : ExtendSig) : sig val ( .!{;..}<- ) : ('a, 'b) M.t -> Owl_types.index array -> ('a, 'b) M.t -> unit (** Operator of ``set_fancy`` *) + val ( .${} ) : ('a, 'b) M.t -> int list -> ('a, 'b) M.t + val ( .${;..} ) : ('a, 'b) M.t -> int list array -> ('a, 'b) M.t (** Operator of ``get_slice`` *) + val ( .${}<- ) : ('a, 'b) M.t -> int list -> ('a, 'b) M.t -> unit + val ( .${;..}<- ) : ('a, 'b) M.t -> int list array -> ('a, 'b) M.t -> unit (** Operator of ``set_slice`` *) end @@ -208,9 +212,13 @@ module Make_Matrix (M : MatrixSig) : sig val ( *@ ) : ('a, 'b) M.t -> ('a, 'b) M.t -> ('a, 'b) M.t (** Operator of ``dot a b``, i.e. matrix multiplication ``a * b``. *) + val ( .%{} ) : ('a, 'b) M.t -> int * int -> 'a + val ( .%{;..} ) : ('a, 'b) M.t -> int array -> 'a (** Operator of ``get`` *) + val ( .%{}<- ) : ('a, 'b) M.t -> int * int -> 'a -> unit + val ( .%{;..}<- ) : ('a, 'b) M.t -> int array -> 'a -> unit (** Operator of ``set`` *) end @@ -218,9 +226,13 @@ end (** {6 Ndarray-specific operators} *) module Make_Ndarray (M : NdarraySig) : sig + val ( .%{} ) : ('a, 'b) M.t -> int -> 'a + val ( .%{;..} ) : ('a, 'b) M.t -> int array -> 'a (** Operator of ``get`` *) + val ( .%{}<- ) : ('a, 'b) M.t -> int -> 'a -> unit + val ( .%{;..}<- ) : ('a, 'b) M.t -> int array -> 'a -> unit (** Operator of ``set`` *) end diff --git a/test/unit_slicing_fancy.ml b/test/unit_slicing_fancy.ml index 15ff5ec59..de4bd532a 100644 --- a/test/unit_slicing_fancy.ml +++ b/test/unit_slicing_fancy.ml @@ -280,6 +280,25 @@ module To_test = struct Arr.set z [| 9; 5; 0 |] 1.; Arr.set z [| 9; 6; 0 |] 2.; Arr.(x = z) + + + let test_35 () = + (* This just checks that the extended accessor syntax compiles *) + let open Arr in + let x = of_array [| 0.5; 0.7 |] [| 2 |] in + let y = of_array [| 0.8; 0.9; 1.1; 1.2 |] [| 2; 2 |] in + let v1 = l2norm' (x.${[ 1 ]} - of_array [| 0.7 |] [| 1 |]) in + let v2 = abs_float (x.%{1} -. 0.7) in + let v3 = abs_float (y.%{0; 1} -. 0.9) in + let v4 = l2norm' (y.${[ 0 ]; [ 0; -1 ]} - of_array [| 0.8; 0.9 |] [| 2 |]) in + Stdlib.(v1 < 1e-10 && v2 < 1e-10 && v3 < 1e-10 && v4 < 1e-10) + + + let test_36 () = + (* This just checks that the extended accessor syntax compiles *) + let open Mat in + let x = gaussian 2 2 in + Stdlib.(x.%{0; 1} = x.%{0, 1}) end (* the tests *) @@ -352,6 +371,10 @@ let test_33 () = Alcotest.(check bool) "test 33" true (To_test.test_33 ()) let test_34 () = Alcotest.(check bool) "test 34" true (To_test.test_34 ()) +let test_35 () = Alcotest.(check bool) "test 35" true (To_test.test_35 ()) + +let test_36 () = Alcotest.(check bool) "test 36" true (To_test.test_36 ()) + let test_set = [ "test 01", `Slow, test_01; "test 02", `Slow, test_02; "test 03", `Slow, test_03 ; "test 04", `Slow, test_04; "test 05", `Slow, test_05; "test 06", `Slow, test_06 @@ -364,4 +387,4 @@ let test_set = ; "test 25", `Slow, test_25; "test 26", `Slow, test_26; "test 27", `Slow, test_27 ; "test 28", `Slow, test_28; "test 29", `Slow, test_29; "test 30", `Slow, test_30 ; "test 31", `Slow, test_31; "test 32", `Slow, test_32; "test 33", `Slow, test_33 - ; "test 34", `Slow, test_34 ] + ; "test 34", `Slow, test_34; "test 35", `Slow, test_35; "test 36", `Slow, test_36 ]