From 8273a6e800ccce2a1cd67ec26f9940a552312dfb Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 21 Apr 2024 09:39:19 +0200 Subject: [PATCH] [EnzymeTestUtils] Vectorize function for FiniteDifferencesCalls (#1327) * Add to_vec * Use to_vec for tangent generation * Fix incorrect call to test_reverse * Use to_vec in calls to FiniteDifferences * Increment patch number * Add more cases to test_approx * Handle cases where constructorof not implemented but needed * Correctly handle case where ret activity is batched and all else const * Replace NamedTuple method with Dict * Add function for structured array testing * Add structured array test * Add tests for to_vec * Add to_vec * Use to_vec for tangent generation * Fix incorrect call to test_reverse * Use to_vec in calls to FiniteDifferences * Increment patch number * Add more cases to test_approx * Handle cases where constructorof not implemented but needed * Correctly handle case where ret activity is batched and all else const * Replace NamedTuple method with Dict * Add function for structured array testing * Add structured array test * Add tests for to_vec * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Add LinearAlgebra to test env * Run formatter on finitedifferences calls * Introduce AliasDict for checking for aliased arrays * Refactor to_vec to handle aliased arrays correctly * Test new to_vec behavior * Note difference between zero_tangent and make_zero * Restore deleted code * Don't treat immutable structs as equivalent * Remove obsolete limitation * Test cases where arrays alias * Document remaining limitation * Also test aliasing in when batching * Also test aliasing in forward-mode * Skip test that hits Julia GC bug pre v1.8 * Change mutating test to support returned arg * Clarify documentation of limitations * Skip structured array test for v1.7 * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Run formatter * Fix random seed in tests * Increment patch number --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- lib/EnzymeTestUtils/Project.toml | 5 +- lib/EnzymeTestUtils/src/EnzymeTestUtils.jl | 1 + .../src/compatible_activities.jl | 2 +- .../src/finite_difference_calls.jl | 43 +++-- lib/EnzymeTestUtils/src/generate_tangent.jl | 48 ++++- lib/EnzymeTestUtils/src/test_approx.jl | 20 ++ lib/EnzymeTestUtils/src/test_forward.jl | 4 +- lib/EnzymeTestUtils/src/test_reverse.jl | 17 +- lib/EnzymeTestUtils/src/to_vec.jl | 155 ++++++++++++++++ lib/EnzymeTestUtils/test/helpers.jl | 25 +++ lib/EnzymeTestUtils/test/runtests.jl | 1 + lib/EnzymeTestUtils/test/test_approx.jl | 31 ++++ lib/EnzymeTestUtils/test/test_forward.jl | 48 ++++- lib/EnzymeTestUtils/test/test_reverse.jl | 64 ++++++- lib/EnzymeTestUtils/test/to_vec.jl | 175 ++++++++++++++++++ 15 files changed, 591 insertions(+), 48 deletions(-) create mode 100644 lib/EnzymeTestUtils/src/to_vec.jl create mode 100644 lib/EnzymeTestUtils/test/to_vec.jl diff --git a/lib/EnzymeTestUtils/Project.toml b/lib/EnzymeTestUtils/Project.toml index 5069878de3..38b783facc 100644 --- a/lib/EnzymeTestUtils/Project.toml +++ b/lib/EnzymeTestUtils/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeTestUtils" uuid = "12d8515a-0907-448a-8884-5fe00fdf1c5a" authors = ["Seth Axen ", "William Moses ", "Valentin Churavy "] -version = "0.1.6" +version = "0.1.7" [deps] ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" @@ -21,8 +21,9 @@ Quaternions = "0.7" julia = "1.6" [extras] +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MetaTesting = "9e32d19f-1e4f-477a-8631-b16c78aa0f56" Quaternions = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0" [targets] -test = ["MetaTesting", "Quaternions"] +test = ["LinearAlgebra", "MetaTesting", "Quaternions"] diff --git a/lib/EnzymeTestUtils/src/EnzymeTestUtils.jl b/lib/EnzymeTestUtils/src/EnzymeTestUtils.jl index cc4266cdbd..56a050455b 100644 --- a/lib/EnzymeTestUtils/src/EnzymeTestUtils.jl +++ b/lib/EnzymeTestUtils/src/EnzymeTestUtils.jl @@ -10,6 +10,7 @@ using Test export test_forward, test_reverse, are_activities_compatible include("output_control.jl") +include("to_vec.jl") include("test_approx.jl") include("compatible_activities.jl") include("finite_difference_calls.jl") diff --git a/lib/EnzymeTestUtils/src/compatible_activities.jl b/lib/EnzymeTestUtils/src/compatible_activities.jl index dcb584e067..48ee1a24df 100644 --- a/lib/EnzymeTestUtils/src/compatible_activities.jl +++ b/lib/EnzymeTestUtils/src/compatible_activities.jl @@ -20,7 +20,7 @@ _batch_size(::Type{BatchDuplicated{T,N}}) where {T,N} = N _batch_size(::Type{<:Annotation}) = nothing function _batch_size(activities...) sizes = filter(!isnothing, map(_batch_size, activities)) - isempty(sizes) && return nothing + isempty(sizes) && return 1 @assert all(==(sizes[1]), sizes) return sizes[1] end diff --git a/lib/EnzymeTestUtils/src/finite_difference_calls.jl b/lib/EnzymeTestUtils/src/finite_difference_calls.jl index 56dec44569..7433b9ccd9 100644 --- a/lib/EnzymeTestUtils/src/finite_difference_calls.jl +++ b/lib/EnzymeTestUtils/src/finite_difference_calls.jl @@ -22,17 +22,25 @@ function _fd_forward(fdm, f, rettype, y, activities) xs = map(x -> x.val, activities) ẋs = map(a -> a isa Const ? nothing : a.dval, activities) ignores = map(a -> a isa Const, activities) - f2 = _wrap_forward_function(f, xs, ignores) + f_sig_args = _wrap_forward_function(f, xs, ignores) ignores = collect(ignores) + _, from_vec_out = to_vec(y) + sig_arg_val_vec, from_vec_in = to_vec(xs[.!ignores]) + # vectorize inputs and outputs of function + f_vec = first ∘ to_vec ∘ Base.splat(f_sig_args) ∘ from_vec_in if rettype <: Union{Duplicated,DuplicatedNoNeed} all(ignores) && return zero_tangent(y) - sigargs = zip(xs[.!ignores], ẋs[.!ignores]) - return FiniteDifferences.jvp(fdm, f2, sigargs...) + sig_arg_dval_vec, _ = to_vec(ẋs[.!ignores]) + ret_deval_vec = FiniteDifferences.jvp(fdm, f_vec, + (sig_arg_val_vec, sig_arg_dval_vec)) + return from_vec_out(ret_deval_vec) elseif rettype <: Union{BatchDuplicated,BatchDuplicatedNoNeed} all(ignores) && return (var"1"=zero_tangent(y),) - sig_arg_vals = xs[.!ignores] ret_dvals = map(ẋs[.!ignores]...) do sig_args_dvals... - FiniteDifferences.jvp(fdm, f2, zip(sig_arg_vals, sig_args_dvals)...) + sig_args_dvals_vec, _ = to_vec(sig_args_dvals) + ret_dval_vec = FiniteDifferences.jvp(fdm, f_vec, + (sig_arg_val_vec, sig_args_dvals_vec)) + return from_vec_out(ret_dval_vec) end return NamedTuple{ntuple(Symbol, length(ret_dvals))}(ret_dvals) else @@ -58,7 +66,7 @@ Call `FiniteDifferences.j′vp` on `f` with the arguments `xs` determined by `ac function _fd_reverse(fdm, f, ȳ, activities, active_return) xs = map(x -> x.val, activities) ignores = map(a -> a isa Const, activities) - f2 = _wrap_reverse_function(active_return, f, xs, ignores) + f_sig_args = _wrap_reverse_function(active_return, f, xs, ignores) all(ignores) && return map(zero_tangent, xs) ignores = collect(ignores) is_batch = _any_batch_duplicated(map(typeof, activities)...) @@ -74,18 +82,21 @@ function _fd_reverse(fdm, f, ȳ, activities, active_return) sigargs = xs[.!ignores] s̄igargs = x̄s[.!ignores] sigarginds = eachindex(x̄s)[.!ignores] + sigargs_vec, from_vec_in = to_vec(sigargs) + # vectorize inputs and outputs of function + f_vec = first ∘ to_vec ∘ Base.splat(f_sig_args) ∘ from_vec_in if !is_batch - fd = FiniteDifferences.j′vp(fdm, f2, (ȳ, s̄igargs...), sigargs...) + ȳ_extended = (ȳ, s̄igargs...) + ȳ_extended_vec, _ = to_vec(ȳ_extended) + fd_vec = only(FiniteDifferences.j′vp(fdm, f_vec, ȳ_extended_vec, sigargs_vec)) + fd = from_vec_in(fd_vec) else - fd = Tuple( - zip( - map(ȳ, s̄igargs...) do y_dval, sigargs_dvals... - FiniteDifferences.j′vp( - fdm, f2, (y_dval, sigargs_dvals...), sigargs... - ) - end..., - ), - ) + fd = Tuple(zip(map(ȳ, s̄igargs...) do ȳ_extended... + ȳ_extended_vec, _ = to_vec(ȳ_extended) + fd_vec = only(FiniteDifferences.j′vp(fdm, f_vec, ȳ_extended_vec, + sigargs_vec)) + return from_vec_in(fd_vec) + end...)) end @assert length(fd) == length(sigarginds) x̄s[sigarginds] = collect(fd) diff --git a/lib/EnzymeTestUtils/src/generate_tangent.jl b/lib/EnzymeTestUtils/src/generate_tangent.jl index e5ae0dd7e3..d774036e7e 100644 --- a/lib/EnzymeTestUtils/src/generate_tangent.jl +++ b/lib/EnzymeTestUtils/src/generate_tangent.jl @@ -4,9 +4,9 @@ function map_fields_recursive(f, x::T...) where {T} fields = map(ConstructionBase.getfields, x) all(isempty, fields) && return first(x) new_fields = map(fields...) do xi... - map_fields_recursive(f, xi...) + return map_fields_recursive(f, xi...) end - return ConstructionBase.constructorof(T)(new_fields...) + return _construct(T, new_fields...) end function map_fields_recursive(f, x::T...) where {T<:Union{Array,Tuple,NamedTuple}} map(x...) do xi... @@ -17,14 +17,20 @@ map_fields_recursive(f, x::T...) where {T<:AbstractFloat} = f(x...) map_fields_recursive(f, x::Array{<:Number}...) = f(x...) rand_tangent(x) = rand_tangent(Random.default_rng(), x) -rand_tangent(rng, x) = map_fields_recursive(Base.Fix1(rand_tangent, rng), x) -# make numbers prettier sometimes when errors are printed. -rand_tangent(rng, ::T) where {T<:AbstractFloat} = rand(rng, -9:T(0.01):9) -rand_tangent(rng, x::T) where {T<:Array{<:Number}} = rand_tangent.(rng, x) +function rand_tangent(rng, x) + v, from_vec = to_vec(x) + T = eltype(v) + # make numbers prettier sometimes when errors are printed. + v_new = rand(rng, -9:T(0.01):9, length(v)) + return from_vec(v_new) +end -zero_tangent(x) = map_fields_recursive(zero_tangent, x) -zero_tangent(::T) where {T<:AbstractFloat} = zero(T) -zero_tangent(x::T) where {T<:Array{<:Number}} = zero_tangent.(x) +# differs from Enzyme.make_zero primarily in that reshaped Arrays in the argument will share +# the same memory in the output. +function zero_tangent(x) + v, from_vec = to_vec(x) + return from_vec(zero(v)) +end auto_activity(arg) = auto_activity(Random.default_rng(), arg) function auto_activity(rng, arg::Tuple) @@ -47,3 +53,27 @@ end function _build_activity(rng, primal, T::Type{<:Annotation}) throw(ArgumentError("Unsupported activity type: $T")) end + +# below code is adapted from https://github.com/JuliaDiff/FiniteDifferences.jl/blob/99ad77f05bdf6c023b249025dbb8edc746d52b4f/src/to_vec.jl +# MIT Expat License +# Copyright (c) 2018 Invenia Technical Computing + +# get around the constructors and make the type directly +# Note this is moderately evil accessing julia's internals +if VERSION >= v"1.3" + @generated function _force_construct(T, args...) + return Expr(:splatnew, :T, :args) + end +else + @generated function _force_construct(T, args...) + return Expr(:new, :T, Any[:(args[$i]) for i in 1:length(args)]...) + end +end + +function _construct(T, args...) + try + return ConstructionBase.constructorof(T)(args...) + catch MethodError + return _force_construct(T, args...) + end +end diff --git a/lib/EnzymeTestUtils/src/test_approx.jl b/lib/EnzymeTestUtils/src/test_approx.jl index c36657827e..305bef4021 100644 --- a/lib/EnzymeTestUtils/src/test_approx.jl +++ b/lib/EnzymeTestUtils/src/test_approx.jl @@ -21,6 +21,26 @@ function test_approx(x::AbstractArray, y::AbstractArray, msg; kwargs...) end return nothing end +function test_approx(x::Tuple, y::Tuple, msg; kwargs...) + @test_msg "$msg: lengths must match" length(x) == length(y) + for i in eachindex(x) + msg_new = "$msg: ::$(typeof(x))[$i]" + test_approx(x[i], y[i], msg_new; kwargs...) + end + return nothing +end +function test_approx(x::Dict, y::Dict, msg; kwargs...) + @test_msg "$msg: keys must match" issetequal(keys(x), keys(y)) + for k in keys(x) + msg_new = "$msg: ::$(typeof(x))[$k]" + test_approx(x[k], y[k], msg_new; kwargs...) + end + return nothing +end +function test_approx(x::Type, y::Type, msg; kwargs...) + @test_msg "$msg: types must match" x === y + return nothing +end test_approx(x, y, msg; kwargs...) = _test_fields_approx(x, y, msg; kwargs...) function _test_fields_approx(x, y, msg; kwargs...) diff --git a/lib/EnzymeTestUtils/src/test_forward.jl b/lib/EnzymeTestUtils/src/test_forward.jl index eaef915a4d..e57a5c7e34 100644 --- a/lib/EnzymeTestUtils/src/test_forward.jl +++ b/lib/EnzymeTestUtils/src/test_forward.jl @@ -3,8 +3,8 @@ Test `Enzyme.autodiff` of `f` in `Forward`-mode against finite differences. -`f` has all constraints of the same argument passed to `Enzyme.autodiff`, with several -additional constraints: +`f` has all constraints of the same argument passed to `Enzyme.autodiff`, with additional +constraints: - If it mutates one of its arguments, it _must_ return that argument. # Arguments diff --git a/lib/EnzymeTestUtils/src/test_reverse.jl b/lib/EnzymeTestUtils/src/test_reverse.jl index 1f36a04a5a..f204b00a7b 100644 --- a/lib/EnzymeTestUtils/src/test_reverse.jl +++ b/lib/EnzymeTestUtils/src/test_reverse.jl @@ -8,7 +8,7 @@ for N in 1:30 function call_with_kwargs(fkwargs::NT, f::FT, $(argexprs...)) where {NT, FT} Base.@_inline_meta @static if VERSION ≤ v"1.8" - # callsite inline syntax unsupported in <= 1.8 + # callsite inline syntax unsupported in <= 1.8 f($(argexprs...); fkwargs...) else @inline f($(argexprs...); fkwargs...) @@ -23,11 +23,10 @@ end Test `Enzyme.autodiff_thunk` of `f` in `ReverseSplitWithPrimal`-mode against finite differences. -`f` has all constraints of the same argument passed to `Enzyme.autodiff_thunk`, with several +`f` has all constraints of the same argument passed to `Enzyme.autodiff_thunk`, with additional constraints: -- If it mutates one of its arguments, it must not also return that argument. -- If the return value is a struct, then all floating point numbers contained in the struct - or its fields must be in arrays. +- If an `Array{<:AbstractFloat}` appears in the input/output, then a reshaped version of it + may not also appear in the input/output. # Arguments @@ -96,13 +95,13 @@ function test_reverse( args_copy = deepcopy(Base.tail(primals)) y = fcopy(args_copy...; deepcopy(fkwargs)...) # generate tangent for output - if !_any_batch_duplicated(map(typeof, activities)...) + if !_any_batch_duplicated(ret_activity, map(typeof, activities)...) ȳ = ret_activity <: Const ? zero_tangent(y) : rand_tangent(rng, y) else - batch_size = _batch_size(map(typeof, activities)...) + batch_size = _batch_size(ret_activity, map(typeof, activities)...) ks = ntuple(Symbol ∘ string, batch_size) ȳ = ntuple(batch_size) do _ - return ret_activity <: Const ? zero_tangent(y) : rand_tangent(rng, y) + return ret_activity <: Const ? zero_tangent(y) : rand_tangent(y) end end # call finitedifferences, avoid mutating original arguments @@ -137,7 +136,7 @@ function test_reverse( else # if there's a shadow result, then we need to set it to our random adjoint if !(shadow_result === nothing) - if !_any_batch_duplicated(map(typeof, activities)...) + if !_any_batch_duplicated(ret_activity, map(typeof, activities)...) map_fields_recursive(copyto!, shadow_result, ȳ) else for (sr, dy) in zip(shadow_result, ȳ) diff --git a/lib/EnzymeTestUtils/src/to_vec.jl b/lib/EnzymeTestUtils/src/to_vec.jl new file mode 100644 index 0000000000..412c6efb1b --- /dev/null +++ b/lib/EnzymeTestUtils/src/to_vec.jl @@ -0,0 +1,155 @@ +# Like an IdDict, but also handles cases where 2 arrays share the same memory due to +# reshaping +struct AliasDict{K,V} <: AbstractDict{K,V} + id_dict::IdDict{K,V} + dataids_dict::IdDict{Tuple{UInt,Vararg{UInt}},V} +end +AliasDict() = AliasDict(IdDict(), IdDict{Tuple{UInt,Vararg{UInt}},Any}()) + +function Base.haskey(d::AliasDict, key) + haskey(d.id_dict, key) && return true + key isa Array && haskey(d.dataids_dict, Base.dataids(key)) && return true + return false +end + +Base.getindex(d::AliasDict, key) = d.id_dict[key] +function Base.getindex(d::AliasDict, key::Array) + haskey(d.id_dict, key) && return d.id_dict[key] + dataids = Base.dataids(key) + return d.dataids_dict[dataids] +end + +function Base.setindex!(d::AliasDict, val, key) + d.id_dict[key] = val + if key isa Array + dataids = Base.dataids(key) + d.dataids_dict[dataids] = val + end + return d +end + +# alternative to FiniteDifferences.to_vec to use Enzyme's semantics for arrays instead of +# ChainRules': Enzyme treats tangents of AbstractArrays the same as tangents of any other +# struct (i.e. with a container of the same type as the original), while ChainRules +# represents the tangent with an array of some type that is tangent to the subspace defined +# by the original array type. +# We take special care that floats that occupy the same memory in the argument only appear +# once in the vector, and that the reconstructed object shares the same memory pattern + +function to_vec(x) + x_vec, from_vec_inner = to_vec(x, AliasDict()) + from_vec(x_vec::Vector{<:AbstractFloat}) = from_vec_inner(x_vec, AliasDict()) + return x_vec, from_vec +end + +# base case: we've unwrapped to a number, so we break the recursion +function to_vec(x::AbstractFloat, seen_vecs::AliasDict) + AbstractFloat_from_vec(v::Vector{<:AbstractFloat}, _) = oftype(x, only(v)) + return [x], AbstractFloat_from_vec +end + +# basic containers: loop over defined elements, recursively converting them to vectors +function to_vec(x::RT, seen_vecs::AliasDict) where {RT<:Array} + has_seen = haskey(seen_vecs, x) + is_const = Enzyme.Compiler.guaranteed_const_nongen(RT, nothing) + if has_seen || is_const + x_vec = Float32[] + else + x_vecs = Vector{<:AbstractFloat}[] + from_vecs = [] + subvec_inds = UnitRange{Int}[] + l = 0 + for i in eachindex(x) + isassigned(x, i) || continue + xi_vec, xi_from_vec = to_vec(x[i], seen_vecs) + push!(x_vecs, xi_vec) + push!(from_vecs, xi_from_vec) + push!(subvec_inds, (l + 1):(l + length(xi_vec))) + l += length(xi_vec) + end + x_vec = reduce(vcat, x_vecs; init=Float32[]) + seen_vecs[x] = x_vec + end + function Array_from_vec(x_vec_new::Vector{<:AbstractFloat}, seen_xs::AliasDict) + if xor(has_seen, haskey(seen_xs, x)) + throw(ErrorException("Arrays must be reconstructed in the same order as they are vectorized.")) + end + has_seen && return reshape(seen_xs[x], size(x)) + is_const && return x + x_new = typeof(x)(undef, size(x)) + k = 1 + for i in eachindex(x) + isassigned(x, i) || continue + xi = from_vecs[k](x_vec_new[subvec_inds[k]], seen_xs) + x_new[i] = xi + k += 1 + end + seen_xs[x] = x_new + return x_new + end + return x_vec, Array_from_vec +end +function to_vec(x::Tuple, seen_vecs::AliasDict) + x_vec, from_vec = to_vec(collect(x), seen_vecs) + function Tuple_from_vec(x_vec_new::Vector{<:AbstractFloat}, seen_xs::AliasDict) + return typeof(x)(Tuple(from_vec(x_vec_new, seen_xs))) + end + return x_vec, Tuple_from_vec +end +function to_vec(x::NamedTuple, seen_vecs::AliasDict) + x_vec, from_vec = to_vec(values(x), seen_vecs) + function NamedTuple_from_vec(x_vec_new::Vector{<:AbstractFloat}, seen_xs::AliasDict) + return NamedTuple{keys(x)}(from_vec(x_vec_new, seen_xs)) + end + return x_vec, NamedTuple_from_vec +end + +# fallback: for any other struct, loop over fields, recursively converting them to vectors +function to_vec(x::RT, seen_vecs::AliasDict) where {RT} + has_seen = haskey(seen_vecs, x) + is_const = Enzyme.Compiler.guaranteed_const_nongen(RT, nothing) + if has_seen || is_const + x_vec = Float32[] + else + @assert !Base.isabstracttype(RT) + @assert Base.isconcretetype(RT) + nf = fieldcount(RT) + flds = Vector{Any}(undef, nf) + for i in 1:nf + if isdefined(x, i) + flds[i] = xi = getfield(x, i) + elseif !ismutable(x) + nf = i - 1 # rest of tail must be undefined values + break + end + end + x_vec, fields_from_vec = to_vec(flds, seen_vecs) + if ismutable(x) + seen_vecs[x] = x_vec + end + end + function Struct_from_vec(x_vec_new::Vector{<:AbstractFloat}, seen_xs::AliasDict) + if xor(has_seen, haskey(seen_xs, x)) + throw(ErrorException("Objects must be reconstructed in the same order as they are vectorized.")) + end + has_seen && return seen_xs[x] + (is_const || nf == 0) && return x + flds_new = fields_from_vec(x_vec_new, seen_xs) + if ismutable(x) + x_new = ccall(:jl_new_struct_uninit, Any, (Any,), RT) + for i in 1:nf + if isdefined(x, i) + xi = flds_new[i] + ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), x_new, i - 1, xi) + end + end + else + x_new = ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds_new, nf) + end + if ismutable(x) + seen_xs[x] = x_new + end + return x_new + end + return x_vec, Struct_from_vec +end diff --git a/lib/EnzymeTestUtils/test/helpers.jl b/lib/EnzymeTestUtils/test/helpers.jl index c3e3ece134..6754b0a935 100644 --- a/lib/EnzymeTestUtils/test/helpers.jl +++ b/lib/EnzymeTestUtils/test/helpers.jl @@ -1,8 +1,22 @@ +using LinearAlgebra + struct TestStruct{X,A} x::X a::A end +struct TestStruct2 + x::Any + a::Any + TestStruct2(x) = new(x) +end + +mutable struct MutableTestStruct + x::Any + a::Any + MutableTestStruct() = new() +end + struct MutatedCallable{T} x::T end @@ -14,3 +28,14 @@ end f_array(x) = sum(abs2, x) f_multiarg(x::AbstractArray, a) = abs2.(a .* x) + +function f_structured_array(x::Hermitian) + y = x * 3 + # mutate the unused triangle, which ensures that our Jacobian differs from FiniteDifferences + if y.uplo == 'U' + LowerTriangular(y.data) .*= 2 + else + UpperTriangular(y.data) .*= 2 + end + return y +end diff --git a/lib/EnzymeTestUtils/test/runtests.jl b/lib/EnzymeTestUtils/test/runtests.jl index 7785fe151a..8883ee78ef 100644 --- a/lib/EnzymeTestUtils/test/runtests.jl +++ b/lib/EnzymeTestUtils/test/runtests.jl @@ -8,6 +8,7 @@ Random.seed!(0) include("helpers.jl") include("test_approx.jl") include("compatible_activities.jl") + include("to_vec.jl") include("generate_tangent.jl") include("test_forward.jl") include("test_reverse.jl") diff --git a/lib/EnzymeTestUtils/test/test_approx.jl b/lib/EnzymeTestUtils/test/test_approx.jl index 57f1145576..99b8070bc1 100644 --- a/lib/EnzymeTestUtils/test/test_approx.jl +++ b/lib/EnzymeTestUtils/test/test_approx.jl @@ -25,6 +25,37 @@ end @test fails(() -> test_approx([0, 1], [0, 1 + 1e-9]; rtol=1e-9)) @test errors(() -> test_approx([1, 2], [1, 2, 3])) end + @testset "tuples" begin + test_approx((1, 2), (1, 2)) + test_approx((1, 2), (1, 2 + 1e-9); atol=1.1e-9) + @test fails(() -> test_approx((1, 2), (1, 2 + 1e-9); atol=1e-9)) + test_approx((0, 1), (0, 1 + 1e-9); rtol=1.1e-9) + @test fails(() -> test_approx((0, 1), (0, 1 + 1e-9); rtol=1e-9)) + @test fails(() -> test_approx((1, 2), (1, 2, 3))) + end + @testset "type" begin + test_approx(Bool, Bool) + test_approx(String, String) + @test fails(() -> test_approx(Bool, String)) + end + @testset "dict" begin + x1 = Dict(:x => randn(3), :y => randn(2)) + x2 = Dict(:x => copy(x1[:x]), :y => copy(x1[:y])) + test_approx(x1, x2) + for i in eachindex(x2[:x]), err in (1e-2, 1e-9) + y = copy(x1[:x]) + y[i] += rand((-1, 1)) * err + x2[:x] = y + test_approx(x1, x2; atol=err * 1.1) + @test fails() do + return test_approx(x1, x2; atol=err * 0.9) + end + end + x2[:x] = vcat(x1[:x], 1.0) + @test errors() do + return test_approx(x1, x2; atol=err * 0.9) + end + end @testset "non-numeric types" begin test_approx(:x, :x) @test fails(() -> test_approx(:x, :y)) diff --git a/lib/EnzymeTestUtils/test/test_forward.jl b/lib/EnzymeTestUtils/test/test_forward.jl index a2ab010042..7f870af7bf 100644 --- a/lib/EnzymeTestUtils/test/test_forward.jl +++ b/lib/EnzymeTestUtils/test/test_forward.jl @@ -1,5 +1,6 @@ using Enzyme using EnzymeTestUtils +using LinearAlgebra using MetaTesting using Test @@ -133,6 +134,49 @@ end end end + VERSION >= v"1.8" && @testset "structured array inputs/outputs" begin + @testset for Tret in (Const, Duplicated, BatchDuplicated), + Tx in (Const, Duplicated, BatchDuplicated), + T in (Float32, Float64, ComplexF32, ComplexF64) + + # if some are batch, none must be duplicated + are_activities_compatible(Tret, Tx) || continue + + x = Hermitian(randn(T, 5, 5)) + + atol = rtol = sqrt(eps(real(T))) + test_forward(f_structured_array, Tret, (x, Tx); atol, rtol) + end + end + + @testset "equivalent arrays in output" begin + function f(x) + z = x * 2 + return (z, z) + end + x = randn(2, 3) + @testset for Tret in (Const, Duplicated, BatchDuplicated), + Tx in (Const, Duplicated, BatchDuplicated) + + are_activities_compatible(Tret, Tx) || continue + test_forward(f, Tret, (x, Tx)) + end + end + + @testset "arrays sharing memory in output" begin + function f(x) + z = x * 2 + return (z, z) + end + x = randn(2, 3) + @testset for Tret in (Const, Duplicated, BatchDuplicated), + Tx in (Const, Duplicated, BatchDuplicated) + + are_activities_compatible(Tret, Tx) || continue + test_forward(f, Tret, (x, Tx)) + end + end + @testset "mutating function" begin Enzyme.API.runtimeActivity!(true) sz = (2, 3) @@ -163,10 +207,10 @@ end x = randn(3) a = randn() - test_reverse(f_kwargs_fwd!, Const, (x, Tx); fkwargs=(; a)) + test_forward(f_kwargs_fwd!, Const, (x, Tx); fkwargs=(; a)) fkwargs = (; a, incorrect_primal=true) @test fails() do - test_forward(f_kwargs_fwd!, Const, (x, Tx); fkwargs) + return test_forward(f_kwargs_fwd!, Const, (x, Tx); fkwargs) end end end diff --git a/lib/EnzymeTestUtils/test/test_reverse.jl b/lib/EnzymeTestUtils/test/test_reverse.jl index f73f3eaed3..b394fa171d 100644 --- a/lib/EnzymeTestUtils/test/test_reverse.jl +++ b/lib/EnzymeTestUtils/test/test_reverse.jl @@ -1,11 +1,12 @@ using Enzyme using EnzymeTestUtils +using LinearAlgebra using MetaTesting using Test function f_mut_rev!(y, x, a) map!(xi -> xi * a, y, x) - return nothing + return y end f_kwargs_rev(x; a=3.0, kwargs...) = a .* x .^ 2 @@ -90,23 +91,72 @@ end end end + VERSION >= v"1.8" && @testset "structured array inputs/outputs" begin + @testset for Tret in (Const, Duplicated, BatchDuplicated), + Tx in (Const, Duplicated, BatchDuplicated), + T in (Float32, Float64, ComplexF32, ComplexF64) + + # if some are batch, none must be duplicated + are_activities_compatible(Tret, Tx) || continue + + x = Hermitian(randn(T, 5, 5)) + + atol = rtol = sqrt(eps(real(T))) + test_reverse(f_structured_array, Tret, (x, Tx); atol, rtol) + end + end + + @testset "equivalent arrays in output" begin + function f(x) + z = x * 2 + return (z, z) + end + x = randn(2, 3) + + @testset for Tret in (Const, Duplicated, BatchDuplicated), + Tx in (Const, Duplicated, BatchDuplicated) + + are_activities_compatible(Tret, Tx) || continue + test_reverse(f, Tret, (x, Tx)) + end + end + + @testset "arrays sharing memory in output" begin + function f(x) + z = x * 2 + return (z, vec(z)) + end + x = randn(2, 3) + @testset for Tret in (Const, Duplicated, BatchDuplicated), + Tx in (Const, Duplicated, BatchDuplicated) + + are_activities_compatible(Tret, Tx) || continue + if Tx <: Const + test_reverse(f, Tret, (x, Tx)) + else + @test_broken !fails() do + return test_reverse(f, Tret, (x, Tx)) + end + end + end + end + @testset "mutating function" begin sz = (2, 3) @testset for Ty in (Const, Duplicated, BatchDuplicated), - Tx in (Const, Duplicated, BatchDuplicated), - Ta in (Const, Active), - Tret in (Const,), # return value is nothing - T in (Float32, Float64, ComplexF32, ComplexF64) + Tx in (Const, Duplicated, BatchDuplicated), + Ta in (Const, Active), + T in (Float32, Float64, ComplexF32, ComplexF64) # if some are batch, none must be duplicated - are_activities_compatible(Tret, Ty, Tx, Ta) || continue + are_activities_compatible(Ty, Tx, Ta) || continue x = randn(T, sz) y = zeros(T, sz) a = randn(T) atol = rtol = sqrt(eps(real(T))) - test_reverse(f_mut_rev!, Tret, (y, Ty), (x, Tx), (a, Ta); atol, rtol) + test_reverse(f_mut_rev!, Ty, (y, Ty), (x, Tx), (a, Ta); atol, rtol) end end diff --git a/lib/EnzymeTestUtils/test/to_vec.jl b/lib/EnzymeTestUtils/test/to_vec.jl new file mode 100644 index 0000000000..3f7609d47a --- /dev/null +++ b/lib/EnzymeTestUtils/test/to_vec.jl @@ -0,0 +1,175 @@ +using EnzymeTestUtils +using EnzymeTestUtils: to_vec +using Test + +function test_to_vec(x) + x_vec, from_vec = to_vec(x) + @test x_vec isa Vector{<:AbstractFloat} + x2 = from_vec(x_vec) + @test typeof(x2) === typeof(x) + return EnzymeTestUtils.test_approx(x2, x) +end + +@testset "to_vec" begin + @testset "BLAS floats" begin + @testset for T in (Float32, Float64, ComplexF32, ComplexF64) + x = randn(T) + test_to_vec(x) + if T <: Real + @test to_vec(x)[1] == [x] + else + @test to_vec(x)[1] == [reim(x)...] + end + end + end + + @testset "non-vectorizable cases" begin + @testset for x in [Bool, (), true, 1, [2], (3, "string")] + test_to_vec(x) + @test isempty(to_vec(x)[1]) + end + end + + @testset "array of floats" begin + @testset for T in (Float32, Float64, ComplexF32, ComplexF64), + sz in (2, (2, 3), (2, 3, 4)) + + test_to_vec(randn(T, sz)) + end + end + + @testset "struct" begin + v = randn(2, 3) + x = TestStruct(1, TestStruct("foo", v)) + test_to_vec(x) + @test to_vec(x)[1] == vec(v) + + x = (TestStruct(1.0, 2.0), TestStruct(1.0, 2.0)) + v, from_vec = to_vec(x) + @test v == [1.0, 2.0, 1.0, 2.0] + @test from_vec(v) === x + end + + @testset "incompletely initialized struct" begin + x = randn(2, 3) + y = TestStruct2(x) + v, from_vec = to_vec(y) + @test v == vec(x) + v2 = randn(size(v)) + y2 = from_vec(v2) + @test y2.x == reshape(v2, size(x)) + @test !isdefined(y2, :a) + end + + @testset "mutable struct" begin + @testset for k in (:a, :x) + x = randn(2, 3) + y = MutableTestStruct() + setfield!(y, k, x) + @test isdefined(y, k) + @test getfield(y, k) == x + v, from_vec = to_vec(y) + @test v == vec(x) + v2 = randn(size(v)) + y2 = from_vec(v2) + @test getfield(y2, k) == reshape(v2, size(x)) + @test !isdefined(y2, k === :a ? :x : :a) + end + + y = MutableTestStruct() + y.x = randn() + t = (y, y) + v, from_vec = to_vec(t) + @test v == [y.x] + t2 = from_vec(v) + @test t2[1] === t2[2] + + t = (y, deepcopy(y)) + v, from_vec = to_vec(t) + @test v == [y.x, y.x] + t2 = from_vec(v) + @test t2[1].x == t2[2].x + @test t2[1] !== t2[2] + end + + @testset "nested array" begin + @testset for T in (Float32, Float64, ComplexF32, ComplexF64), + sz in (2, (2, 3), (2, 3, 4)) + + test_to_vec([randn(T, sz) for _ in 1:10]) + end + end + + @testset "partially defined array" begin + @testset for i in 1:2 + x = Vector{Vector{Float64}}(undef, 2) + x[i] = randn(5) + v, from_vec = to_vec(x) + @test v == x[i] + v2 = randn(size(v)) + x2 = from_vec(v2) + @test x2[i] == v2 + @test !isassigned(x2, 3 - i) + end + end + + @testset "tuple" begin + v = randn(3) + x = ("foo", 1, false, String, TestStruct(3.0, v)) + test_to_vec(x) + @test to_vec(x)[1] == vcat(3.0, v) + end + + @testset "namedtuple" begin + x = (x="bar", y=randn(3), z=randn(), w=TestStruct(4.0, randn(2))) + test_to_vec(x) + @test to_vec(x)[1] == vcat(x.y, x.z, x.w.x, x.w.a) + end + + @testset "dict" begin + x = Dict(:a => randn(2), :b => randn(3)) + test_to_vec(x) + end + + @testset "views of arrays" begin + x = randn(2, 3) + test_to_vec(reshape(x, 3, 2)) + test_to_vec(view(x, :, 1)) + test_to_vec(PermutedDimsArray(x, (2, 1))) + end + + @testset "subarrays" begin + x = randn(2, 3) + # note: bottom right 2x2 submatrix ommited from y but will be present in v + y = @views (x[:, 1], x[1, :]) + test_to_vec(y) + v, from_vec = to_vec(y) + @test v == vec(x) + v2 = randn(size(v)) + y2 = from_vec(v2) + @test y2[1] == reshape(v2, size(x))[:, 1] + @test y2[2] == reshape(v2, size(x))[1, :] + @test Base.dataids(y2[1]) == Base.dataids(y2[2]) + end + + @testset "reshaped arrays share memory" begin + struct MyContainer1 + a::Any + b::Any + end + mutable struct MyContainer2 + a::Any + b::Any + end + @testset for T in (MyContainer1, MyContainer2) + x = randn(2, 3) + x2 = vec(x) + y = T(x, x2) + test_to_vec(y) + v, from_vec = to_vec(y) + @test v == x2 + y2 = from_vec(v) + @test Base.dataids(y2.a) == Base.dataids(y2.b) + end + end +end