Skip to content

Commit

Permalink
Permit opting in to unsafe perturbations (#362)
Browse files Browse the repository at this point in the history
* Add test case

* Update tests to do the right thing

* Fix up add_to_primal

* Bump patch version

* Fix error message

* Improve docstring

* Enable unsafe perturbation for Bijectors

* Fix up Bijectors testing

* Remove accidentally-committed debug info

* Remove redundant code

* Write docstring

* Remove redundant method
  • Loading branch information
willtebbutt authored Nov 11, 2024
1 parent d66f873 commit 9a1bb51
Show file tree
Hide file tree
Showing 16 changed files with 152 additions and 77 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Mooncake"
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.4.39"
version = "0.4.40"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
2 changes: 1 addition & 1 deletion ext/MooncakeCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ TestUtils.has_equal_data(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x == y
increment!!(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x .+= y
__increment_should_allocate(::Type{<:CuArray{<:IEEEFloat}}) = true
set_to_zero!!(x::CuArray{<:IEEEFloat}) = x .= 0
_add_to_primal(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x + y
_add_to_primal(x::P, y::P, ::Bool) where {P<:CuArray{<:IEEEFloat}} = x + y
_diff(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x - y
_dot(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = Float64(dot(x, y))
_scale(x::Float64, y::P) where {T<:IEEEFloat, P<:CuArray{T}} = T(x) * y
Expand Down
4 changes: 2 additions & 2 deletions src/rrules/array_legacy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ function _dot(t::T, s::T) where {T<:Array}
)
end

function _add_to_primal(x::Array{P, N}, t::Array{<:Any, N}) where {P, N}
function _add_to_primal(x::Array{P, N}, t::Array{<:Any, N}, unsafe::Bool) where {P, N}
x′ = Array{P, N}(undef, size(x)...)
return _map_if_assigned!(_add_to_primal, x′, x, t)
return _map_if_assigned!((x, t) -> _add_to_primal(x, t, unsafe), x′, x, t)
end

function _diff(p::P, q::P) where {V, N, P<:Array{V, N}}
Expand Down
4 changes: 2 additions & 2 deletions src/rrules/iddict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ function _scale(a::Float64, t::IdDict{K, V}) where {K, V}
return IdDict{K, V}([k => _scale(a, v) for (k, v) in t])
end
_dot(p::T, q::T) where {T<:IdDict} = sum([_dot(p[k], q[k]) for k in keys(p)]; init=0.0)
function _add_to_primal(p::IdDict{K, V}, t::IdDict{K}) where {K, V}
function _add_to_primal(p::IdDict{K, V}, t::IdDict{K}, unsafe::Bool) where {K, V}
ks = intersect(keys(p), keys(t))
return IdDict{K, V}([k => _add_to_primal(p[k], t[k]) for k in ks])
return IdDict{K, V}([k => _add_to_primal(p[k], t[k], unsafe) for k in ks])
end
function _diff(p::P, q::P) where {K, V, P<:IdDict{K, V}}
@assert union(keys(p), keys(q)) == keys(p)
Expand Down
14 changes: 9 additions & 5 deletions src/rrules/memory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,10 @@ end

set_to_zero!!(x::Memory) = _map_if_assigned!(set_to_zero!!, x, x)

function _add_to_primal(p::Memory{P}, t::Memory) where {P}
return _map_if_assigned!(_add_to_primal, Memory{P}(undef, length(p)), p, t)
# `_add_to_primal` for `Memory`: allocates a fresh uninitialised `Memory{P}` of the same
# length as `p`, then fills it via `_map_if_assigned!`, applying `_add_to_primal` with the
# `unsafe` flag forwarded unchanged to each `(p, t)` element pair.
function _add_to_primal(p::Memory{P}, t::Memory, unsafe::Bool) where {P}
    out = Memory{P}(undef, length(p))
    return _map_if_assigned!(out, p, t) do p_elem, t_elem
        _add_to_primal(p_elem, t_elem, unsafe)
    end
end

function _diff(p::Memory{P}, q::Memory{P}) where {P}
Expand Down Expand Up @@ -172,9 +174,9 @@ function _dot(t::T, s::T) where {T<:Array}
)
end

function _add_to_primal(x::Array{P, N}, t::Array{<:Any, N}) where {P, N}
function _add_to_primal(x::Array{P, N}, t::Array{<:Any, N}, unsafe::Bool) where {P, N}
x′ = Array{P, N}(undef, size(x)...)
return _map_if_assigned!(_add_to_primal, x′, x, t)
return _map_if_assigned!((x, t) -> _add_to_primal(x, t, unsafe), x′, x, t)
end

function _diff(p::P, q::P) where {P<:Array}
Expand Down Expand Up @@ -273,7 +275,9 @@ function set_to_zero!!(x::MemoryRef)
return x
end

_add_to_primal(p::MemoryRef, t::MemoryRef) = construct_ref(p, _add_to_primal(p.mem, t.mem))
# `_add_to_primal` for `MemoryRef`: first perturbs `p.mem` by `t.mem` (forwarding `unsafe`),
# then hands `p` and the perturbed memory to `construct_ref` to build the result.
function _add_to_primal(p::MemoryRef, t::MemoryRef, unsafe::Bool)
    perturbed_mem = _add_to_primal(p.mem, t.mem, unsafe)
    return construct_ref(p, perturbed_mem)
end

function _diff(p::P, q::P) where {P<:MemoryRef}
@assert Core.memoryrefoffset(p) == Core.memoryrefoffset(q)
Expand Down
2 changes: 1 addition & 1 deletion src/rrules/tasks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ increment!!(t::TaskTangent, s::TaskTangent) = t

set_to_zero!!(t::TaskTangent) = t

_add_to_primal(p::Task, t::TaskTangent) = p
_add_to_primal(p::Task, t::TaskTangent, ::Bool) = p

_diff(::Task, ::Task) = TaskTangent()

Expand Down
87 changes: 71 additions & 16 deletions src/tangents.jl
Original file line number Diff line number Diff line change
Expand Up @@ -637,37 +637,92 @@ function _dot(t::T, s::T) where {T<:Union{Tangent, MutableTangent}}
end

"""
_add_to_primal(p::P, t::T) where {P, T}
Required for testing.
_Not_ currently defined by default.
`_containerlike_add_to_primal` is potentially what you want to target when implementing for
a particular primal-tangent pair.
_add_to_primal(p::P, t::T, unsafe::Bool=false) where {P, T}
Adds `t` to `p`, returning a `P`. It must be the case that `tangent_type(P) == T`.
If `unsafe` is `true` and `P` is a composite type, then `_add_to_primal` will construct a
new instance of `P` by directly invoking the `:new` instruction for `P`, rather than
attempting to use the default constructor for `P`. This is fine if you are confident that
the new `P` constructed by adding `t` to `p` will always be a valid instance of `P`, but
could cause problems if you are not confident of this.
This is, for example, fine for the following type:
```julia
struct Foo{T}
x::Vector{T}
y::Vector{T}
function Foo(x::Vector{T}, y::Vector{T}) where {T}
@assert length(x) == length(y)
return new{T}(x, y)
end
end
```
Here, the value returned by `_add_to_primal` will satisfy the invariant asserted in the
inner constructor for `Foo`.
"""
_add_to_primal(x, ::NoTangent) = x
_add_to_primal(x::T, t::T) where {T<:IEEEFloat} = x + t
function _add_to_primal(x::SimpleVector, t::Vector{Any})
return svec(map(n -> _add_to_primal(x[n], t[n]), eachindex(x))...)
_add_to_primal(p, t) = _add_to_primal(p, t, false)
_add_to_primal(x, ::NoTangent, ::Bool) = x
_add_to_primal(x::T, t::T, ::Bool) where {T<:IEEEFloat} = x + t
# `_add_to_primal` for `SimpleVector`: perturbs each entry `x[n]` by `t[n]` (forwarding
# `unsafe`), then repackages the perturbed entries with `svec`.
function _add_to_primal(x::SimpleVector, t::Vector{Any}, unsafe::Bool)
    perturbed = map(eachindex(x)) do n
        _add_to_primal(x[n], t[n], unsafe)
    end
    return svec(perturbed...)
end
# `_add_to_primal` for `Tuple`: applies `_add_to_primal`, with `unsafe` forwarded, across
# `x` and `t` using `_map`.
function _add_to_primal(x::Tuple, t::Tuple, unsafe::Bool)
    return _map(x, t) do x_elem, t_elem
        _add_to_primal(x_elem, t_elem, unsafe)
    end
end
# `_add_to_primal` for `NamedTuple`: applies `_add_to_primal`, with `unsafe` forwarded,
# across `x` and `t` using `_map`.
function _add_to_primal(x::NamedTuple, t::NamedTuple, unsafe::Bool)
    add_field(x_field, t_field) = _add_to_primal(x_field, t_field, unsafe)
    return _map(add_field, x, t)
end
_add_to_primal(x::Tuple, t::Tuple) = _map(_add_to_primal, x, t)
_add_to_primal(x::NamedTuple, t::NamedTuple) = _map(_add_to_primal, x, t)
_add_to_primal(x, ::Tangent{NamedTuple{(), Tuple{}}}) = x

function _add_to_primal(p::P, t::T) where {P, T<:Union{Tangent, MutableTangent}}
"""
    AddToPrimalException(primal_type::Type)

Exception thrown by `_add_to_primal` (with `unsafe == false`) when calling the default
constructor of `primal_type` raises a `MethodError`. `primal_type` is the type whose
construction was attempted; it is interpolated into the message printed by `showerror`.
"""
struct AddToPrimalException <: Exception
    primal_type::Type
end

"""
    Base.showerror(io::IO, err::AddToPrimalException)

Write a description of `err` to `io`, naming the offending primal type and explaining how
to work around the problem (avoid the type in tests, or opt in to unsafe perturbation).
"""
function Base.showerror(io::IO, err::AddToPrimalException)
    msg = "Attempted to construct an instance of $(err.primal_type) using the default " *
        "constructor. In most cases, this error is caused by the lack of existence of the " *
        "default constructor for this type. There are two approaches to dealing with " *
        "this problem. The first is to avoid having to call `_add_to_primal` on this " *
        "type, which can be achieved by avoiding testing functions whose arguments are " *
        "of this type. If this cannot be avoided, you should consider calling " *
        "`Mooncake._add_to_primal` with its third positional argument set to `true`. " *
        "If you are using some of Mooncake's testing functionality, this can be achieved " *
        "by setting the `unsafe_perturb` setting to `true` -- check the docstring " *
        "for `Mooncake._add_to_primal` to ensure that your use case is unlikely to " *
        "cause problems."
    # `showerror` methods conventionally do not emit a trailing newline; the caller that
    # displays the exception is responsible for any line breaks that follow the message.
    print(io, msg)
    return nothing
end

function _add_to_primal(p::P, t::T, unsafe::Bool) where {P, T<:Union{Tangent, MutableTangent}}
Tt = tangent_type(P)
if Tt != typeof(t)
throw(ArgumentError("p of type $P has tangent_type $Tt, but t is of type $T"))
end
tmp = map(fieldnames(P)) do f
tf = getfield(t.fields, f)
isdefined(p, f) && is_init(tf) && return _add_to_primal(getfield(p, f), val(tf))
isdefined(p, f) && is_init(tf) && return _add_to_primal(getfield(p, f), val(tf), unsafe)
!isdefined(p, f) && !is_init(tf) && return FieldUndefined()
throw(error("unable to handle undefined-ness"))
end
i = findfirst(==(FieldUndefined()), tmp)
return i === nothing ? P(tmp...) : P(tmp[1:i-1]...)

# If unsafe mode is enabled, then call `_new_` directly, and avoid the possibility that
# the default inner constructor for `P` does not exist.
if unsafe
return i === nothing ? _new_(P, tmp...) : _new_(P, tmp[1:i-1]...)
end

# If unsafe mode is disabled, try to use the default constructor for `P`. If this does
# not work, then throw an informative error message.
try
return i === nothing ? P(tmp...) : P(tmp[1:i-1]...)
catch e
if e isa MethodError
throw(AddToPrimalException(P))
else
rethrow(e)
end
end
end

"""
Expand Down
5 changes: 5 additions & 0 deletions src/test_resources.jl
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,11 @@ end

tuple_with_union(x::Bool) = (x ? 5.0 : 5, nothing)

# Test resource: a single-field parametric struct whose only constructor is the inner
# constructor below. Because an inner constructor is declared, Julia does not generate the
# default constructors, so `NoDefaultCtor{T}(x)` has no method and raises a `MethodError`.
struct NoDefaultCtor{T}
    x::T
    function NoDefaultCtor(x::T) where {T}
        return new{T}(x)
    end
end

function generate_test_functions()
return Any[
(false, :allocs, nothing, const_tester),
Expand Down
41 changes: 22 additions & 19 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -283,35 +283,34 @@ function address_maps_are_consistent(x::AddressMap, y::AddressMap)
end

# Assumes that the interface has been tested, and we can simply check for numerical issues.
function test_rrule_numerical_correctness(rng::AbstractRNG, f_f̄, x_x̄...; rule)
@nospecialize rng f_f̄ x_x̄
function test_rule_correctness(rng::AbstractRNG, x_x̄...; rule, unsafe_perturb::Bool)
@nospecialize rng x_x̄

x_x̄ = map(_deepcopy, x_x̄) # defensive copy

# Run original function on deep-copies of inputs.
f = primal(f_f̄)
x = map(primal, x_x̄)
= map(tangent, x_x̄)

# Run primal, and ensure that we still have access to mutated inputs afterwards.
x_primal = _deepcopy(x)
y_primal = f(x_primal...)
y_primal = x_primal[1](x_primal[2:end]...)

# Use finite differences to estimate vjps
= map(_x -> randn_tangent(rng, _x), x)
ε = 1e-7
x′ = _add_to_primal(x, _scale(ε, ẋ))
y′ = f(x′...)
x′ = _add_to_primal(x, _scale(ε, ẋ), unsafe_perturb)
y′ = x′[1](x′[2:end]...)
= _scale(1 / ε, _diff(y′, y_primal))
ẋ_post = map((_x′, _x_p) -> _scale(1 / ε, _diff(_x′, _x_p)), x′, x_primal)

# Run `rrule!!` on copies of `f` and `x`. We use randomly generated tangents so that we
# Run rule on copies of `f` and `x`. We use randomly generated tangents so that we
# can later verify that non-zero values do not get propagated by the rule.
x̄_zero = map(zero_tangent, x)
x̄_fwds = map(Mooncake.fdata, x̄_zero)
x_x̄_rule = map((x, x̄_f) -> fcodual_type(_typeof(x))(_deepcopy(x), x̄_f), x, x̄_fwds)
inputs_address_map = populate_address_map(map(primal, x_x̄_rule), map(tangent, x_x̄_rule))
y_ȳ_rule, pb!! = rule(to_fwds(f_f̄), x_x̄_rule...)
y_ȳ_rule, pb!! = rule(x_x̄_rule...)

# Verify that inputs / outputs are the same under `f` and its rrule.
@test has_equal_data(x_primal, map(primal, x_x̄_rule))
Expand All @@ -332,8 +331,8 @@ function test_rrule_numerical_correctness(rng::AbstractRNG, f_f̄, x_x̄...; rul
x̄_init = map(set_to_zero!!, x̄_zero)
= increment!!(ȳ_init, ȳ_delta)
map(increment!!, x̄_init, x̄_delta)
_, x̄_rvs_inc... = pb!!(Mooncake.rdata(ȳ))
x̄_rvs = map((x, x_inc) -> increment!!(rdata(x), x_inc), x̄_delta, x̄_rvs_inc)
x̄_rvs_inc = pb!!(Mooncake.rdata(ȳ))
x̄_rvs = increment!!(map(rdata, x̄_delta), x̄_rvs_inc)
= map(tangent, x̄_fwds, x̄_rvs)

# Check that inputs have been returned to their original value.
Expand Down Expand Up @@ -481,6 +480,7 @@ __get_primals(xs) = map(x -> x isa CoDual ? primal(x) : x, xs)
perf_flag::Symbol=:none,
interp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter(),
debug_mode::Bool=false,
unsafe_perturb::Bool=false,
)
Run standardised tests on the `rule` for `x`.
Expand Down Expand Up @@ -527,6 +527,9 @@ This function uses [`Mooncake.build_rrule`](@ref) to construct a rule. This will
Typically this should be left at its default `false` value, but if you are finding that
the tests are failing for a given rule, you may wish to temporarily set it to `true` in
order to get access to additional information and automated testing.
- `unsafe_perturb::Bool=false`: value passed as the third argument to `_add_to_primal`.
Should usually be left `false` -- consult the docstring for `_add_to_primal` for more
info on when you might wish to set it to `true`.
"""
function test_rule(
rng::AbstractRNG, x...;
Expand All @@ -535,16 +538,16 @@ function test_rule(
perf_flag::Symbol=:none,
interp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter(),
debug_mode::Bool=false,
unsafe_perturb::Bool=false,
)
@nospecialize rng x

# Construct the rule.
rule = Mooncake.build_rrule(interp, _typeof(__get_primals(x)); debug_mode)
sig = _typeof(__get_primals(x))
rule = Mooncake.build_rrule(interp, sig; debug_mode)

# If something is primitive, then the rule should be `rrule!!`.
if is_primitive
@test rule == (debug_mode ? Mooncake.DebugRRule(rrule!!) : rrule!!)
end
is_primitive && @test rule == (debug_mode ? Mooncake.DebugRRule(rrule!!) : rrule!!)

# Generate random tangents for anything that is not already a CoDual.
x_x̄ = map(x -> x isa CoDual ? x : interface_only ? uninit_codual(x) : zero_codual(x), x)
Expand All @@ -553,14 +556,13 @@ function test_rule(
test_rrule_interface(x_x̄...; rule)

# Test that answers are numerically correct / consistent.
interface_only || test_rrule_numerical_correctness(rng, x_x̄...; rule)
interface_only || test_rule_correctness(rng, x_x̄...; rule, unsafe_perturb)

# Test the performance of the rule.
test_rrule_performance(perf_flag, rule, x_x̄...)

# Test the interface again, in order to verify that caching is working correctly.
rule_2 = Mooncake.build_rrule(interp, _typeof(__get_primals(x)); debug_mode)
test_rrule_interface(x_x̄..., rule=rule_2)
test_rrule_interface(x_x̄..., rule=Mooncake.build_rrule(interp, sig; debug_mode))
end


Expand Down Expand Up @@ -790,6 +792,7 @@ function test_tangent_consistency(rng::AbstractRNG, p::P; interface_only=false)
# Verify that operations required for finite difference testing to run, and produce the
# correct output type.
@test _add_to_primal(p, t) isa P
@test _add_to_primal(p, t, true) isa P
@test _diff(p, p) isa T
@test _dot(t, t) isa Float64
@test _scale(11.0, t) isa T
Expand All @@ -798,9 +801,9 @@ function test_tangent_consistency(rng::AbstractRNG, p::P; interface_only=false)
# Run some basic numerical sanity checks on the output the functions required for finite
# difference testing. These are necessary but insufficient conditions.
if !interface_only
@test has_equal_data(_add_to_primal(p, z), p)
@test has_equal_data(_add_to_primal(p, z, true), p)
if !has_equal_data(z, r)
@test !has_equal_data(_add_to_primal(p, r), p)
@test !has_equal_data(_add_to_primal(p, r, true), p)
end
@test has_equal_data(_diff(p, p), zero_tangent(p))
end
Expand Down
2 changes: 1 addition & 1 deletion test/ext/dynamic_ppl/dynamic_ppl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ Pkg.develop(; path=joinpath(@__DIR__, "..", "..", ".."))
using DynamicPPL, Mooncake, Test

@testset "DynamicPPLMooncakeExt" begin
test_rule(sr(123456), DynamicPPL.istrans, DynamicPPL.VarInfo(); interface_only=true)
test_rule(sr(123456), DynamicPPL.istrans, DynamicPPL.VarInfo(); unsafe_perturb=true)
end
10 changes: 5 additions & 5 deletions test/integration_testing/bijectors/bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using Pkg
Pkg.activate(@__DIR__)
Pkg.develop(; path = joinpath(@__DIR__, "..", "..", ".."))

using Bijectors: Bijectors
using Bijectors: Bijectors, inverse
using LinearAlgebra: LinearAlgebra
using Random: randn

Expand All @@ -25,8 +25,7 @@ function b_binv_test_case(bijector, dim; name = nothing, rng = Xoshiro(23))
if name === nothing
name = string(bijector)
end
b_inv = Bijectors.inverse(bijector)
return TestCase(x -> bijector(b_inv(x)), randn(rng, dim); name = name)
return TestCase(x -> bijector(inverse(bijector)(x)), randn(rng, dim); name = name)
end

@testset "Bijectors integration tests" begin
Expand All @@ -43,7 +42,7 @@ end
Bijectors.Coupling(Bijectors.Shift, Bijectors.PartitionMask(3, [1], [2])),
3,
),
b_binv_test_case(Bijectors.InvertibleBatchNorm(3), (3, 3)),
b_binv_test_case(Bijectors.InvertibleBatchNorm(3; eps=1e-5, mtm=1e-1), (3, 3)),
b_binv_test_case(Bijectors.LeakyReLU(0.2), 3),
b_binv_test_case(Bijectors.Logit(0.1, 0.3), 3),
b_binv_test_case(Bijectors.PDBijector(), (3, 3)),
Expand Down Expand Up @@ -128,7 +127,8 @@ end
true
end
else
test_rule(Xoshiro(123456), case.func, case.arg; is_primitive=false)
rng = Xoshiro(123456)
test_rule(rng, case.func, case.arg; is_primitive=false, unsafe_perturb=true)
end
end
end
2 changes: 1 addition & 1 deletion test/integration_testing/diff_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
TestResources.DIFFTESTS_FUNCTIONS[91:end], # skipping sparse_ldiv
))
@info "$n: $(_typeof((f, x...)))"
test_rule(sr(123456), f, x...; interface_only=false, is_primitive=false)
test_rule(sr(123456), f, x...; is_primitive=false)
end
end
Loading

2 comments on commit 9a1bb51

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/119141

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.40 -m "<description of version>" 9a1bb51641a16dfbbe3dfe288e358e19d4996bb0
git push origin v0.4.40

Please sign in to comment.