Skip to content

Commit

Permalink
[EnzymeTestUtils] Vectorize function for FiniteDifferencesCalls (Enzy…
Browse files Browse the repository at this point in the history
…meAD#1327)

* Add to_vec

* Use to_vec for tangent generation

* Fix incorrect call to test_reverse

* Use to_vec in calls to FiniteDifferences

* Increment patch number

* Add more cases to test_approx

* Handle cases where constructorof not implemented but needed

* Correctly handle case where ret activity is batched and all else const

* Replace NamedTuple method with Dict

* Add function for structured array testing

* Add structured array test

* Add tests for to_vec

* Add to_vec

* Use to_vec for tangent generation

* Fix incorrect call to test_reverse

* Use to_vec in calls to FiniteDifferences

* Increment patch number

* Add more cases to test_approx

* Handle cases where constructorof not implemented but needed

* Correctly handle case where ret activity is batched and all else const

* Replace NamedTuple method with Dict

* Add function for structured array testing

* Add structured array test

* Add tests for to_vec

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Add LinearAlgebra to test env

* Run formatter on finitedifferences calls

* Introduce AliasDict for checking for aliased arrays

* Refactor to_vec to handle aliased arrays correctly

* Test new to_vec behavior

* Note difference between zero_tangent and make_zero

* Restore deleted code

* Don't treat immutable structs as equivalent

* Remove obsolete limitation

* Test cases where arrays alias

* Document remaining limitation

* Also test aliasing in when batching

* Also test aliasing in forward-mode

* Skip test that hits Julia GC bug pre v1.8

* Change mutating test to support returned arg

* Clarify documentation of limitations

* Skip structured array test for v1.7

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Run formatter

* Fix random seed in tests

* Increment patch number

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
sethaxen and github-actions[bot] authored Apr 21, 2024
1 parent 5ae36e5 commit 8273a6e
Show file tree
Hide file tree
Showing 15 changed files with 591 additions and 48 deletions.
5 changes: 3 additions & 2 deletions lib/EnzymeTestUtils/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EnzymeTestUtils"
uuid = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
authors = ["Seth Axen <[email protected]>", "William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
version = "0.1.6"
version = "0.1.7"

[deps]
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Expand All @@ -21,8 +21,9 @@ Quaternions = "0.7"
julia = "1.6"

[extras]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MetaTesting = "9e32d19f-1e4f-477a-8631-b16c78aa0f56"
Quaternions = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0"

[targets]
test = ["MetaTesting", "Quaternions"]
test = ["LinearAlgebra", "MetaTesting", "Quaternions"]
1 change: 1 addition & 0 deletions lib/EnzymeTestUtils/src/EnzymeTestUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using Test
export test_forward, test_reverse, are_activities_compatible

include("output_control.jl")
include("to_vec.jl")
include("test_approx.jl")
include("compatible_activities.jl")
include("finite_difference_calls.jl")
Expand Down
2 changes: 1 addition & 1 deletion lib/EnzymeTestUtils/src/compatible_activities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ _batch_size(::Type{BatchDuplicated{T,N}}) where {T,N} = N
_batch_size(::Type{<:Annotation}) = nothing
function _batch_size(activities...)
sizes = filter(!isnothing, map(_batch_size, activities))
isempty(sizes) && return nothing
isempty(sizes) && return 1
@assert all(==(sizes[1]), sizes)
return sizes[1]
end
43 changes: 27 additions & 16 deletions lib/EnzymeTestUtils/src/finite_difference_calls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,25 @@ function _fd_forward(fdm, f, rettype, y, activities)
xs = map(x -> x.val, activities)
ẋs = map(a -> a isa Const ? nothing : a.dval, activities)
ignores = map(a -> a isa Const, activities)
f2 = _wrap_forward_function(f, xs, ignores)
f_sig_args = _wrap_forward_function(f, xs, ignores)
ignores = collect(ignores)
_, from_vec_out = to_vec(y)
sig_arg_val_vec, from_vec_in = to_vec(xs[.!ignores])
# vectorize inputs and outputs of function
f_vec = first to_vec Base.splat(f_sig_args) from_vec_in
if rettype <: Union{Duplicated,DuplicatedNoNeed}
all(ignores) && return zero_tangent(y)
sigargs = zip(xs[.!ignores], ẋs[.!ignores])
return FiniteDifferences.jvp(fdm, f2, sigargs...)
sig_arg_dval_vec, _ = to_vec(ẋs[.!ignores])
ret_deval_vec = FiniteDifferences.jvp(fdm, f_vec,
(sig_arg_val_vec, sig_arg_dval_vec))
return from_vec_out(ret_deval_vec)
elseif rettype <: Union{BatchDuplicated,BatchDuplicatedNoNeed}
all(ignores) && return (var"1"=zero_tangent(y),)
sig_arg_vals = xs[.!ignores]
ret_dvals = map(ẋs[.!ignores]...) do sig_args_dvals...
FiniteDifferences.jvp(fdm, f2, zip(sig_arg_vals, sig_args_dvals)...)
sig_args_dvals_vec, _ = to_vec(sig_args_dvals)
ret_dval_vec = FiniteDifferences.jvp(fdm, f_vec,
(sig_arg_val_vec, sig_args_dvals_vec))
return from_vec_out(ret_dval_vec)
end
return NamedTuple{ntuple(Symbol, length(ret_dvals))}(ret_dvals)
else
Expand All @@ -58,7 +66,7 @@ Call `FiniteDifferences.j′vp` on `f` with the arguments `xs` determined by `ac
function _fd_reverse(fdm, f, ȳ, activities, active_return)
xs = map(x -> x.val, activities)
ignores = map(a -> a isa Const, activities)
f2 = _wrap_reverse_function(active_return, f, xs, ignores)
f_sig_args = _wrap_reverse_function(active_return, f, xs, ignores)
all(ignores) && return map(zero_tangent, xs)
ignores = collect(ignores)
is_batch = _any_batch_duplicated(map(typeof, activities)...)
Expand All @@ -74,18 +82,21 @@ function _fd_reverse(fdm, f, ȳ, activities, active_return)
sigargs = xs[.!ignores]
s̄igargs = x̄s[.!ignores]
sigarginds = eachindex(x̄s)[.!ignores]
sigargs_vec, from_vec_in = to_vec(sigargs)
# vectorize inputs and outputs of function
f_vec = first to_vec Base.splat(f_sig_args) from_vec_in
if !is_batch
fd = FiniteDifferences.j′vp(fdm, f2, (ȳ, s̄igargs...), sigargs...)
ȳ_extended = (ȳ, s̄igargs...)
ȳ_extended_vec, _ = to_vec(ȳ_extended)
fd_vec = only(FiniteDifferences.j′vp(fdm, f_vec, ȳ_extended_vec, sigargs_vec))
fd = from_vec_in(fd_vec)
else
fd = Tuple(
zip(
map(ȳ, s̄igargs...) do y_dval, sigargs_dvals...
FiniteDifferences.j′vp(
fdm, f2, (y_dval, sigargs_dvals...), sigargs...
)
end...,
),
)
fd = Tuple(zip(map(ȳ, s̄igargs...) do ȳ_extended...
ȳ_extended_vec, _ = to_vec(ȳ_extended)
fd_vec = only(FiniteDifferences.j′vp(fdm, f_vec, ȳ_extended_vec,
sigargs_vec))
return from_vec_in(fd_vec)
end...))
end
@assert length(fd) == length(sigarginds)
x̄s[sigarginds] = collect(fd)
Expand Down
48 changes: 39 additions & 9 deletions lib/EnzymeTestUtils/src/generate_tangent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ function map_fields_recursive(f, x::T...) where {T}
fields = map(ConstructionBase.getfields, x)
all(isempty, fields) && return first(x)
new_fields = map(fields...) do xi...
map_fields_recursive(f, xi...)
return map_fields_recursive(f, xi...)
end
return ConstructionBase.constructorof(T)(new_fields...)
return _construct(T, new_fields...)
end
function map_fields_recursive(f, x::T...) where {T<:Union{Array,Tuple,NamedTuple}}
map(x...) do xi...
Expand All @@ -17,14 +17,20 @@ map_fields_recursive(f, x::T...) where {T<:AbstractFloat} = f(x...)
map_fields_recursive(f, x::Array{<:Number}...) = f(x...)

rand_tangent(x) = rand_tangent(Random.default_rng(), x)
rand_tangent(rng, x) = map_fields_recursive(Base.Fix1(rand_tangent, rng), x)
# make numbers prettier sometimes when errors are printed.
rand_tangent(rng, ::T) where {T<:AbstractFloat} = rand(rng, -9:T(0.01):9)
rand_tangent(rng, x::T) where {T<:Array{<:Number}} = rand_tangent.(rng, x)
function rand_tangent(rng, x)
v, from_vec = to_vec(x)
T = eltype(v)
# make numbers prettier sometimes when errors are printed.
v_new = rand(rng, -9:T(0.01):9, length(v))
return from_vec(v_new)
end

zero_tangent(x) = map_fields_recursive(zero_tangent, x)
zero_tangent(::T) where {T<:AbstractFloat} = zero(T)
zero_tangent(x::T) where {T<:Array{<:Number}} = zero_tangent.(x)
# differs from Enzyme.make_zero primarily in that reshaped Arrays in the argument will share
# the same memory in the output.
function zero_tangent(x)
v, from_vec = to_vec(x)
return from_vec(zero(v))
end

auto_activity(arg) = auto_activity(Random.default_rng(), arg)
function auto_activity(rng, arg::Tuple)
Expand All @@ -47,3 +53,27 @@ end
function _build_activity(rng, primal, T::Type{<:Annotation})
throw(ArgumentError("Unsupported activity type: $T"))
end

# below code is adapted from https://github.com/JuliaDiff/FiniteDifferences.jl/blob/99ad77f05bdf6c023b249025dbb8edc746d52b4f/src/to_vec.jl
# MIT Expat License
# Copyright (c) 2018 Invenia Technical Computing

# get around the constructors and make the type directly
# Note this is moderately evil accessing julia's internals
if VERSION >= v"1.3"
@generated function _force_construct(T, args...)
return Expr(:splatnew, :T, :args)
end
else
@generated function _force_construct(T, args...)
return Expr(:new, :T, Any[:(args[$i]) for i in 1:length(args)]...)
end
end

function _construct(T, args...)
try
return ConstructionBase.constructorof(T)(args...)
catch MethodError
return _force_construct(T, args...)
end
end
20 changes: 20 additions & 0 deletions lib/EnzymeTestUtils/src/test_approx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,26 @@ function test_approx(x::AbstractArray, y::AbstractArray, msg; kwargs...)
end
return nothing
end
function test_approx(x::Tuple, y::Tuple, msg; kwargs...)
@test_msg "$msg: lengths must match" length(x) == length(y)
for i in eachindex(x)
msg_new = "$msg: ::$(typeof(x))[$i]"
test_approx(x[i], y[i], msg_new; kwargs...)
end
return nothing
end
function test_approx(x::Dict, y::Dict, msg; kwargs...)
@test_msg "$msg: keys must match" issetequal(keys(x), keys(y))
for k in keys(x)
msg_new = "$msg: ::$(typeof(x))[$k]"
test_approx(x[k], y[k], msg_new; kwargs...)
end
return nothing
end
function test_approx(x::Type, y::Type, msg; kwargs...)
@test_msg "$msg: types must match" x === y
return nothing
end
test_approx(x, y, msg; kwargs...) = _test_fields_approx(x, y, msg; kwargs...)

function _test_fields_approx(x, y, msg; kwargs...)
Expand Down
4 changes: 2 additions & 2 deletions lib/EnzymeTestUtils/src/test_forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
Test `Enzyme.autodiff` of `f` in `Forward`-mode against finite differences.
`f` has all constraints of the same argument passed to `Enzyme.autodiff`, with several
additional constraints:
`f` has all constraints of the same argument passed to `Enzyme.autodiff`, with additional
constraints:
- If it mutates one of its arguments, it _must_ return that argument.
# Arguments
Expand Down
17 changes: 8 additions & 9 deletions lib/EnzymeTestUtils/src/test_reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ for N in 1:30
function call_with_kwargs(fkwargs::NT, f::FT, $(argexprs...)) where {NT, FT}
Base.@_inline_meta
@static if VERSION v"1.8"
# callsite inline syntax unsupported in <= 1.8
# callsite inline syntax unsupported in <= 1.8
f($(argexprs...); fkwargs...)
else
@inline f($(argexprs...); fkwargs...)
Expand All @@ -23,11 +23,10 @@ end
Test `Enzyme.autodiff_thunk` of `f` in `ReverseSplitWithPrimal`-mode against finite
differences.
`f` has all constraints of the same argument passed to `Enzyme.autodiff_thunk`, with several
`f` has all constraints of the same argument passed to `Enzyme.autodiff_thunk`, with
additional constraints:
- If it mutates one of its arguments, it must not also return that argument.
- If the return value is a struct, then all floating point numbers contained in the struct
or its fields must be in arrays.
- If an `Array{<:AbstractFloat}` appears in the input/output, then a reshaped version of it
may not also appear in the input/output.
# Arguments
Expand Down Expand Up @@ -96,13 +95,13 @@ function test_reverse(
args_copy = deepcopy(Base.tail(primals))
y = fcopy(args_copy...; deepcopy(fkwargs)...)
# generate tangent for output
if !_any_batch_duplicated(map(typeof, activities)...)
if !_any_batch_duplicated(ret_activity, map(typeof, activities)...)
= ret_activity <: Const ? zero_tangent(y) : rand_tangent(rng, y)
else
batch_size = _batch_size(map(typeof, activities)...)
batch_size = _batch_size(ret_activity, map(typeof, activities)...)
ks = ntuple(Symbol string, batch_size)
= ntuple(batch_size) do _
return ret_activity <: Const ? zero_tangent(y) : rand_tangent(rng, y)
return ret_activity <: Const ? zero_tangent(y) : rand_tangent(y)
end
end
# call finitedifferences, avoid mutating original arguments
Expand Down Expand Up @@ -137,7 +136,7 @@ function test_reverse(
else
# if there's a shadow result, then we need to set it to our random adjoint
if !(shadow_result === nothing)
if !_any_batch_duplicated(map(typeof, activities)...)
if !_any_batch_duplicated(ret_activity, map(typeof, activities)...)
map_fields_recursive(copyto!, shadow_result, ȳ)
else
for (sr, dy) in zip(shadow_result, ȳ)
Expand Down
Loading

0 comments on commit 8273a6e

Please sign in to comment.