Skip to content

Commit

Permalink
Fix access (#543)
Browse files Browse the repository at this point in the history
* owl_operator: add also accessor operators

To support call with single index

Signed-off-by: Marcello Seri <[email protected]>

* Add minimal tests

Signed-off-by: Marcello Seri <[email protected]>

* Matrix access operators: make more robust and add tuple-based get/set

Signed-off-by: Marcello Seri <[email protected]>

* improve error message

Signed-off-by: Marcello Seri <[email protected]>
  • Loading branch information
mseri authored Sep 15, 2020
1 parent e3519b8 commit 9310a35
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 6 deletions.
6 changes: 3 additions & 3 deletions src/base/compute/owl_computation_optimiser.ml
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ module Make (Operator : Owl_computation_operator_sig.Sig) = struct
| Asinh -> pattern_000 x
| Acosh -> pattern_000 x
| Atanh -> pattern_000 x
| Min (_keep_dims, _axis) -> pattern_000 x
| Max (_keep_dims, _axis) -> pattern_000 x
| Sum (_keep_dims, _axis) -> pattern_000 x
| Min (_keep_dims, _axis) -> pattern_000 x
| Max (_keep_dims, _axis) -> pattern_000 x
| Sum (_keep_dims, _axis) -> pattern_000 x
| SumReduce _axis -> pattern_024 x
| Signum -> pattern_000 x
| Sigmoid -> pattern_000 x
Expand Down
32 changes: 30 additions & 2 deletions src/base/core/owl_operator.ml
Original file line number Diff line number Diff line change
Expand Up @@ -138,26 +138,54 @@ module Make_Extend (M : ExtendSig) = struct

let ( .!{;..}<- ) x s = M.set_fancy_ext s x

let ( .${} ) x s = M.get_slice_ext [| s |] x

let ( .${;..} ) x s = M.get_slice_ext s x

let ( .${}<- ) x s = M.set_slice_ext [| s |] x

let ( .${;..}<- ) x s = M.set_slice_ext s x
end [@warning "-34"]

module Make_Matrix (M : MatrixSig) = struct
type ('a, 'b) op_t2 = ('a, 'b) M.t

let ( .%{;..} ) x i = M.get x i.(0) i.(1)
let ( .%{} ) x (i1, i2) = M.get x i1 i2

let ( .%{;..} ) x i =
if Array.length i = 2
then M.get x i.(0) i.(1)
else
failwith
(".%{} on matrices requires exactly two indices but I got "
^ string_of_int
@@ Array.length i)


let ( .%{}<- ) x (i1, i2) = M.set x i1 i2

let ( .%{;..}<- ) x i =
if Array.length i = 2
then M.set x i.(0) i.(1)
else
failwith
(".%{}<- on matrices requires exactly two indices but I got "
^ string_of_int
@@ Array.length i)

let ( .%{;..}<- ) x i = M.set x i.(0) i.(1)

let ( *@ ) a b = M.dot a b
end [@warning "-34"]

module Make_Ndarray (M : NdarraySig) = struct
type ('a, 'b) op_t3 = ('a, 'b) M.t

let ( .%{} ) x i = M.get x [| i |]

let ( .%{;..} ) x i = M.get x i

let ( .%{}<- ) x i = M.set x [| i |]

let ( .%{;..}<- ) x i = M.set x i
end [@warning "-34"]

Expand Down
12 changes: 12 additions & 0 deletions src/base/core/owl_operator.mli
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,13 @@ module Make_Extend (M : ExtendSig) : sig
val ( .!{;..}<- ) : ('a, 'b) M.t -> Owl_types.index array -> ('a, 'b) M.t -> unit
(** Operator of ``set_fancy`` *)

val ( .${} ) : ('a, 'b) M.t -> int list -> ('a, 'b) M.t

val ( .${;..} ) : ('a, 'b) M.t -> int list array -> ('a, 'b) M.t
(** Operator of ``get_slice`` *)

val ( .${}<- ) : ('a, 'b) M.t -> int list -> ('a, 'b) M.t -> unit

val ( .${;..}<- ) : ('a, 'b) M.t -> int list array -> ('a, 'b) M.t -> unit
(** Operator of ``set_slice`` *)
end
Expand All @@ -208,19 +212,27 @@ module Make_Matrix (M : MatrixSig) : sig
val ( *@ ) : ('a, 'b) M.t -> ('a, 'b) M.t -> ('a, 'b) M.t
(** Operator of ``dot a b``, i.e. matrix multiplication ``a * b``. *)

val ( .%{} ) : ('a, 'b) M.t -> int * int -> 'a

val ( .%{;..} ) : ('a, 'b) M.t -> int array -> 'a
(** Operator of ``get`` *)

val ( .%{}<- ) : ('a, 'b) M.t -> int * int -> 'a -> unit

val ( .%{;..}<- ) : ('a, 'b) M.t -> int array -> 'a -> unit
(** Operator of ``set`` *)
end

(** {6 Ndarray-specific operators} *)

module Make_Ndarray (M : NdarraySig) : sig
val ( .%{} ) : ('a, 'b) M.t -> int -> 'a

val ( .%{;..} ) : ('a, 'b) M.t -> int array -> 'a
(** Operator of ``get`` *)

val ( .%{}<- ) : ('a, 'b) M.t -> int -> 'a -> unit

val ( .%{;..}<- ) : ('a, 'b) M.t -> int array -> 'a -> unit
(** Operator of ``set`` *)
end
Expand Down
25 changes: 24 additions & 1 deletion test/unit_slicing_fancy.ml
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,25 @@ module To_test = struct
Arr.set z [| 9; 5; 0 |] 1.;
Arr.set z [| 9; 6; 0 |] 2.;
Arr.(x = z)


let test_35 () =
(* This just checks that the extended accessor syntax compiles *)
let open Arr in
let x = of_array [| 0.5; 0.7 |] [| 2 |] in
let y = of_array [| 0.8; 0.9; 1.1; 1.2 |] [| 2; 2 |] in
let v1 = l2norm' (x.${[ 1 ]} - of_array [| 0.7 |] [| 1 |]) in
let v2 = abs_float (x.%{1} -. 0.7) in
let v3 = abs_float (y.%{0; 1} -. 0.9) in
let v4 = l2norm' (y.${[ 0 ]; [ 0; -1 ]} - of_array [| 0.8; 0.9 |] [| 2 |]) in
Stdlib.(v1 < 1e-10 && v2 < 1e-10 && v3 < 1e-10 && v4 < 1e-10)


let test_36 () =
(* This just checks that the extended accessor syntax compiles *)
let open Mat in
let x = gaussian 2 2 in
Stdlib.(x.%{0; 1} = x.%{0, 1})
end

(* the tests *)
Expand Down Expand Up @@ -352,6 +371,10 @@ let test_33 () = Alcotest.(check bool) "test 33" true (To_test.test_33 ())

let test_34 () = Alcotest.(check bool) "test 34" true (To_test.test_34 ())

let test_35 () = Alcotest.(check bool) "test 35" true (To_test.test_35 ())

let test_36 () = Alcotest.(check bool) "test 36" true (To_test.test_36 ())

let test_set =
[ "test 01", `Slow, test_01; "test 02", `Slow, test_02; "test 03", `Slow, test_03
; "test 04", `Slow, test_04; "test 05", `Slow, test_05; "test 06", `Slow, test_06
Expand All @@ -364,4 +387,4 @@ let test_set =
; "test 25", `Slow, test_25; "test 26", `Slow, test_26; "test 27", `Slow, test_27
; "test 28", `Slow, test_28; "test 29", `Slow, test_29; "test 30", `Slow, test_30
; "test 31", `Slow, test_31; "test 32", `Slow, test_32; "test 33", `Slow, test_33
; "test 34", `Slow, test_34 ]
; "test 34", `Slow, test_34; "test 35", `Slow, test_35; "test 36", `Slow, test_36 ]

0 comments on commit 9310a35

Please sign in to comment.