Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed bug in aiso + checks for grad and diff + forward-mode derivative for split #561

Merged
merged 5 commits into from
Nov 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/base/algodiff/owl_algodiff_generic.ml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ module Make (A : Owl_types_ndarray_algodiff.Sig) = struct

(* derivative of f (scalar -> scalar) at x, forward ad *)
let diff' f x =
if not (is_float x) then failwith "input of `diff` must be a scalar";
let x = make_forward x (pack_flt 1.) (tag ()) in
let y = f x in
primal y, tangent y
Expand All @@ -60,6 +61,7 @@ module Make (A : Owl_types_ndarray_algodiff.Sig) = struct
let grad' f x =
let x = make_reverse x (tag ()) in
let y = f x in
if not (is_float y) then failwith "output of `grad` must be a scalar";
Reverse.reverse_reset y;
Reverse.reverse_push (pack_flt 1.) y;
primal y, x |> adjval
Expand Down
58 changes: 55 additions & 3 deletions src/base/algodiff/owl_algodiff_ops.ml
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,25 @@ module Make (Core : Owl_algodiff_core_sig.Sig) = struct

and get_slice i = Lazy.force _get_slice i

and _get_fancy =
lazy
(fun i ->
build_siso
(module struct
let label = "get_fancy"

let ff_f a = error_uniop label (pack_elt a)

let ff_arr a = Arr A.(get_fancy i a)

let df _cp _ap at = get_fancy i at

let dr a _cp ca = set_fancy i (zero a) !ca
end : Siso))


and get_fancy i = Lazy.force _get_fancy i

and _sum' =
lazy
(build_siso
Expand Down Expand Up @@ -1303,6 +1322,41 @@ module Make (Core : Owl_algodiff_core_sig.Sig) = struct

and set_slice i = Lazy.force _set_slice i

and _set_fancy =
lazy
(fun i ->
build_piso
(module struct
let label = "set_fancy"

let ff_aa a _b = error_uniop label (pack_elt a)

let ff_ab a _b = error_uniop label (pack_elt a)

let ff_ba _a b = error_uniop label (pack_elt b)

let ff_bb a b =
let a = A.copy a in
A.(set_fancy i a b);
Arr a


let df_da _cp _ap at bp = set_fancy i at (zero bp)

let df_db _cp ap _bp bt = set_fancy i (zero ap) bt

let df_dab _cp _ap at _bp bt = set_fancy i at bt

let dr_ab _a b _cp ca = set_fancy i !ca (zero b), get_fancy i !ca

let dr_a _a b _cp ca = set_fancy i !ca (zero b)

let dr_b _a _b _cp ca = get_fancy i !ca
end : Piso))


and set_fancy i = Lazy.force _set_fancy i

and ( *@ ) a b = dot a b

and _dot =
Expand Down Expand Up @@ -1569,9 +1623,7 @@ module Make (Core : Owl_algodiff_core_sig.Sig) = struct

let ff_arr a = A.(split ~axis parts a) |> Array.map (fun x -> Arr x)

let df _cp _ap _at =
raise (Owl_exception.NOT_IMPLEMENTED "owl_algodiff_ops.split")

let df _cp _ap at = split ~axis parts at

let dr _a _cp _cp_ref_arr ca_ref_arr =
concatenate ~axis (Array.map (fun ca -> !ca) ca_ref_arr)
Expand Down
35 changes: 23 additions & 12 deletions src/base/algodiff/owl_algodiff_ops_builder.ml
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ module Make (Core : Owl_algodiff_core_sig.Sig) = struct

val ff_arr : A.arr -> t array

val df : t -> t -> t -> t
val df : t array -> t -> t -> t array

val dr : t -> t -> t ref array -> t ref array -> t
end
Expand All @@ -206,7 +206,8 @@ module Make (Core : Owl_algodiff_core_sig.Sig) = struct
match a with
| DF (ap, at, ai) ->
let cp_arr = fd ap in
Array.map (fun cp -> DF (cp, df cp ap at, ai)) cp_arr
let ct_arr = df cp_arr ap at in
Array.map2 (fun cp ct -> DF (cp, ct, ai)) cp_arr ct_arr
| DR (ap, _, _, _, ai, _) ->
let cp_arr = fd ap in
let cp_arr_ref = Array.map (fun cp -> ref cp) cp_arr in
Expand Down Expand Up @@ -399,34 +400,34 @@ module Make (Core : Owl_algodiff_core_sig.Sig) = struct
| `normal, DR (_, _, _, _, t', _) -> succ i, t', `reverse, [ i ]
| `forward, DR (_, _, _, _, t', _) ->
if t' > t
then succ i, t', `reverse, []
then succ i, t', `reverse, [ i ]
else if t' = t
then failwith "error: forward and reverse clash on the same level"
else succ i, t, `forward, idxs
| `reverse, DR (_, _, _, _, t', _) ->
if t' > t
then succ i, t', `reverse, []
then succ i, t', `reverse, [ i ]
else if t' = t
then succ i, t', `reverse, i :: idxs
else succ i, t, m, idxs
| `normal, DF (_, _, t') -> succ i, t', `forward, [ i ]
| `forward, DF (_, _, t') ->
if t' > t
then succ i, t', `forward, []
then succ i, t', `forward, [ i ]
else if t' = t
then succ i, t', `forward, i :: idxs
else succ i, t, `forward, idxs
| `reverse, DF (_, _, t') ->
if t' > t
then succ i, t', `forward, []
then succ i, t', `forward, [ i ]
else if t' = t
then failwith "error: forward and reverse clash on the same level"
else succ i, t, `reverse, idxs)
(0, -10000, `normal, [])
(0, -50000, `normal, [])
in
fun (module S : Aiso) ->
let rec f a =
let _, t, mode, idxs = build_info a in
let _, max_t, mode, idxs = build_info a in
let idxs = idxs |> List.rev in
match mode with
| `normal -> S.ff a
Expand All @@ -435,7 +436,12 @@ module Make (Core : Owl_algodiff_core_sig.Sig) = struct
Array.map
(fun x ->
match x with
| DF (p, _, t') -> if t = t' then p else x
| DF (p, _, t') ->
if max_t = t'
then p
else if t' > max_t
then failwith "no tags should be higher than max_t"
else x
| x -> x)
a
in
Expand All @@ -445,13 +451,18 @@ module Make (Core : Owl_algodiff_core_sig.Sig) = struct
List.iter (fun k -> at.(k) <- tangent a.(k)) idxs;
S.df idxs cp ap at
in
DF (cp, at, t)
DF (cp, at, max_t)
| `reverse ->
let ap =
Array.map
(fun x ->
match x with
| DR (p, _, _, _, t', _) -> if t = t' then p else x
| DR (p, _, _, _, t', _) ->
if max_t = t'
then p
else if t' > max_t
then failwith "no tags should be higher than max_t"
else x
| x -> x)
a
in
Expand All @@ -463,7 +474,7 @@ module Make (Core : Owl_algodiff_core_sig.Sig) = struct
in
let register t = List.fold_left (fun t i -> a.(i) :: t) t idxs in
let label = S.label, List.(map (fun i -> a.(i)) idxs) in
DR (cp, ref (zero cp), (adjoint, register, label), ref 0, t, ref 0)
DR (cp, ref (zero cp), (adjoint, register, label), ref 0, max_t, ref 0)
in
f
end
2 changes: 1 addition & 1 deletion src/base/algodiff/owl_algodiff_ops_builder_sig.ml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ module type Sig = sig

val ff_arr : arr -> t array

val df : t -> t -> t -> t
val df : t array -> t -> t -> t array

val dr : t -> t -> t ref array -> t ref array -> t
end
Expand Down
6 changes: 6 additions & 0 deletions src/base/algodiff/owl_algodiff_ops_sig.ml
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,12 @@ module type Sig = sig
val set_slice : int list list -> t -> t -> t
(** Refer to :doc:`owl_dense_ndarray_generic` *)

val get_fancy : index list -> t -> t
(** Refer to :doc:`owl_dense_ndarray_generic` *)

val set_fancy : index list -> t -> t -> t
(** Refer to :doc:`owl_dense_ndarray_generic` *)

val diag : ?k:int -> t -> t
(** Refer to :doc:`owl_dense_ndarray_generic` *)

Expand Down
2 changes: 2 additions & 0 deletions src/base/compute/owl_computation_cpu_init.ml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ module Make (Graph : Owl_computation_graph_sig.Sig) = struct
| Set _i -> split_01 p
| GetSlice _slice -> split_00 p (* ? *)
| SetSlice _slice -> split_00 p (* ? *)
| GetFancy _indices -> split_00 p (* ? *)
| SetFancy _indices -> split_00 p (* ? *)
| Copy -> split_01 p
| Reset -> split_01 p
| Reshape _shape -> split_01 p
Expand Down
6 changes: 5 additions & 1 deletion src/base/compute/owl_computation_operator.ml
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,14 @@ module Make (Symbol : Owl_computation_symbol_sig.Sig) = struct
let get_slice slice x =
make_then_connect (GetSlice slice) [| arr_to_node x |] |> node_to_arr


let set_slice slice x y =
make_then_connect (SetSlice slice) [| arr_to_node x; arr_to_node y |] |> ignore

let get_fancy indices x =
make_then_connect (GetFancy indices) [| arr_to_node x |] |> node_to_arr

let set_fancy indices x y =
make_then_connect (SetFancy indices) [| arr_to_node x; arr_to_node y |] |> ignore

let copy x = make_then_connect Copy [| arr_to_node x |] |> node_to_arr

Expand Down
6 changes: 6 additions & 0 deletions src/base/compute/owl_computation_operator_sig.ml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ module type Sig = sig
val set_slice : int list list -> arr -> arr -> unit
(** TODO *)

val get_fancy : index list -> arr -> arr
(** TODO *)

val set_fancy : index list -> arr -> arr -> unit
(** TODO *)

val copy : arr -> arr
(** TODO *)

Expand Down
2 changes: 2 additions & 0 deletions src/base/compute/owl_computation_symbol.ml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ module Make (Shape : Owl_computation_shape_sig.Sig) = struct
| Set _i -> "Set"
| GetSlice _slice -> "GetSlice"
| SetSlice _slice -> "SetSlice"
| GetFancy _ -> "GetFancy"
| SetFancy _ -> "SetFancy"
| Copy -> "Copy"
| Reset -> "Reset"
| Reshape _shape -> "Reshape"
Expand Down
2 changes: 2 additions & 0 deletions src/base/compute/owl_computation_type.ml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ module Make (Device : Owl_types_computation_device.Sig) = struct
| Set of int array
| GetSlice of int list list
| SetSlice of int list list
| GetFancy of index list
| SetFancy of index list
| Copy
| Reset
| Reshape of int array
Expand Down
4 changes: 3 additions & 1 deletion src/base/compute/owl_computation_type_sig.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* Copyright (c) 2016-2020 Liang Wang <[email protected]>
*)

open Owl_types
open Owl_types_common

(* Functor of making the symbols of a computation graph. *)

Expand Down Expand Up @@ -75,6 +75,8 @@ module type Sig = sig
| Set of int array
| GetSlice of int list list
| SetSlice of int list list
| GetFancy of index list
| SetFancy of index list
| Copy
| Reset
| Reshape of int array
Expand Down
Loading