Skip to content

Commit

Permalink
Add function setladj
Browse files Browse the repository at this point in the history
Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
oschulz and devmotion committed Jul 3, 2023
1 parent bdcf78c commit f659510
Show file tree
Hide file tree
Showing 10 changed files with 180 additions and 5 deletions.
11 changes: 10 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,24 @@ uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
version = "0.1.7"

[deps]
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"

[extensions]
ChangesOfVariablesInverseFunctionsExt = "InverseFunctions"

[compat]
InverseFunctions = "0.1"
julia = "1"

[extras]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"

[targets]
test = ["Documenter", "ForwardDiff"]
test = ["Documenter", "InverseFunctions", "ForwardDiff"]
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ changes for functions that perform a change of variables (like coordinate
transformations).

`ChangesOfVariables` is a very lightweight package and has no dependencies
beyond `Base`, `LinearAlgebra`, `Test`.
beyond `Base`, `LinearAlgebra` and `Test` (plus a weak depdendency on
`InverseFunctions`).

## Documentation

Expand Down
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"

Expand Down
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using ChangesOfVariables
DocMeta.setdocmeta!(
ChangesOfVariables,
:DocTestSetup,
:(using ChangesOfVariables);
:(using ChangesOfVariables, InverseFunctions);
recursive=true,
)

Expand Down
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
```@docs
with_logabsdet_jacobian
NoLogAbsDetJacobian
setladj
```

## Test utility
Expand Down
32 changes: 32 additions & 0 deletions ext/ChangesOfVariablesInverseFunctionsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
module ChangesOfVariablesInverseFunctionsExt

using ChangesOfVariables
using InverseFunctions


struct InverseFunctionWithLADJ{InvF,LADJF} <: Function
inv_f::InvF
ladjf::LADJF
end
InverseFunctionWithLADJ(::Type{InvF}, ladjf::LADJF) where {InvF,LADJF} = InverseFunctionWithLADJ{Type{InvF},LADJF}(InvF,ladjf)
InverseFunctionWithLADJ(inv_f::InvF, ::Type{LADJF}) where {InvF,LADJF} = InverseFunctionWithLADJ{InvF,Type{LADJF}}(inv_f,LADJF)
InverseFunctionWithLADJ(::Type{InvF}, ::Type{LADJF}) where {InvF,LADJF} = InverseFunctionWithLADJ{Type{InvF},Type{LADJF}}(InvF,LADJF)

(f::InverseFunctionWithLADJ)(y) = f.inv_f(y)

function ChangesOfVariables.with_logabsdet_jacobian(f::InverseFunctionWithLADJ, y)
x = f.inv_f(y)
return x, -f.ladjf(x)
end

InverseFunctions.inverse(f::ChangesOfVariables.FunctionWithLADJ) = InverseFunctionWithLADJ(inverse(f.f), f.ladjf)
InverseFunctions.inverse(f::InverseFunctionWithLADJ) = ChangesOfVariables.FunctionWithLADJ(inverse(f.inv_f), f.ladjf)


@static if isdefined(InverseFunctions, :FunctionWithInverse)
function ChangesOfVariables.with_logabsdet_jacobian(f::InverseFunctions.FunctionWithInverse, x)
ChangesOfVariables.with_logabsdet_jacobian(f.f, x)
end
end

end # module ChangesOfVariablesInverseFunctionsExt
5 changes: 5 additions & 0 deletions src/ChangesOfVariables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ using LinearAlgebra
using Test

include("with_ladj.jl")
include("setladj.jl")
include("test.jl")

@static if !isdefined(Base, :get_extension)
include("../ext/ChangesOfVariablesInverseFunctionsExt.jl")
end

end # module
71 changes: 71 additions & 0 deletions src/setladj.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# This file is a part of ChangesOfVariables.jl, licensed under the MIT License (MIT).


"""
struct FunctionWithLADJ{F,LADJF} <: Function
A function with an separate function to compute it's `logabddet(J)`.
Do not construct directly, use [`setladj(f, ladjf)`](@ref) instead.
"""
struct FunctionWithLADJ{F,LADJF} <: Function
f::F
ladjf::LADJF
end
FunctionWithLADJ(::Type{F}, ladjf::LADJF) where {F,LADJF} = FunctionWithLADJ{Type{F},LADJF}(F,ladjf)
FunctionWithLADJ(f::F, ::Type{LADJF}) where {F,LADJF} = FunctionWithLADJ{F,Type{LADJF}}(f,LADJF)
FunctionWithLADJ(::Type{F}, ::Type{LADJF}) where {F,LADJF} = FunctionWithLADJ{Type{F},Type{LADJF}}(F,LADJF)

(f::FunctionWithLADJ)(x) = f.f(x)

with_logabsdet_jacobian(f::FunctionWithLADJ, x) = f.f(x), f.ladjf(x)


"""
setladj(f, ladjf)::Function
Return a function that behaves like `f` in general and which has
`with_logabsdet_jacobian(f, x) = f(x), ladjf(x)`.
Useful in cases where [`with_logabsdet_jacobian`](@ref) is not defined
for `f`, or if `f` needs to be assigned a LADJ-calculation that is
only valid within a given context, e.g. only for a
limited argument type/range that is guaranteed by the use case but
not in general, or that is optimized to a custom use case.
For example, `CUDA.CuArray` has no `with_logabsdet_jacobian` defined,
but may be used to switch computing device for a part of a
heterogenous computing function chain. Likewise, one may want to
switch numerical precision for a part of a calculation.
The function (wrapper) returned by `setladj` supports
[`InverseFunctions.inverse`](https://github.com/JuliaMath/InverseFunctions.jl)
if `f` does so.
Example:
```jldoctest setladj
VERSION < v"1.6" || begin # Support for ∘ requires Julia >= v1.6
# Increases precition before calculation exp:
foo = exp ∘ setladj(setinverse(Float64, Float32), _ -> 0)
# A log-value from some low-precision (e.g. GPU) computation:
log_x = Float32(100)
# f(log_x) would return Inf32 without going to Float64:
y, ladj = with_logabsdet_jacobian(foo, log_x)
r_log_x, ladj_inv = with_logabsdet_jacobian(inverse(foo), y)
ladj ≈ 100 ≈ -ladj_inv && r_log_x ≈ log_x
end
# output
true
```
"""
setladj(f, ladjf) = FunctionWithLADJ(_unwrap_f(f), ladjf)
export setladj

_unwrap_f(f) = f
_unwrap_f(f::FunctionWithLADJ) = f.f
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ import Documenter
Test.@testset "Package ChangesOfVariables" begin
include("test_test.jl")
include("test_with_ladj.jl")
include("test_setladj.jl")

# doctests
Documenter.DocMeta.setdocmeta!(
ChangesOfVariables,
:DocTestSetup,
:(using ChangesOfVariables);
:(using ChangesOfVariables, InverseFunctions);
recursive=true,
)
Documenter.doctest(ChangesOfVariables)
end # testset

55 changes: 55 additions & 0 deletions test/test_setladj.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# This file is a part of ChangesOfVariables.jl, licensed under the MIT License (MIT).

using Test
using ChangesOfVariables
using InverseFunctions

const ChangesOfVariablesInverseFunctionsExt = if isdefined(Base, :get_extension)
Base.get_extension(ChangesOfVariables, :ChangesOfVariablesInverseFunctionsExt)
else
ChangesOfVariables.ChangesOfVariablesInverseFunctionsExt
end
const InverseFunctionWithLADJ = ChangesOfVariablesInverseFunctionsExt.InverseFunctionWithLADJ

include("getjacobian.jl")


# Dummy testing type that looks like something that represents abstract zeros:
struct _Zero{T} end
_Zero(::T) where {T} = _Zero{T}()


@testset "setladj" begin
@test @inferred(setladj(Real, _Zero)) isa ChangesOfVariables.FunctionWithLADJ{Type{Real},Type{_Zero}}
@test @inferred(ChangesOfVariables.FunctionWithLADJ(Real, _Zero)) isa ChangesOfVariables.FunctionWithLADJ{Type{Real},Type{_Zero}}
@test @inferred(ChangesOfVariables.FunctionWithLADJ(widen, _Zero)) isa ChangesOfVariables.FunctionWithLADJ{typeof(widen),Type{_Zero}}
@test @inferred(ChangesOfVariables.FunctionWithLADJ(Real, zero)) isa ChangesOfVariables.FunctionWithLADJ{Type{Real},typeof(zero)}
@test @inferred(ChangesOfVariables.FunctionWithLADJ(widen, zero)) isa ChangesOfVariables.FunctionWithLADJ{typeof(widen),typeof(zero)}

@test @inferred(InverseFunctionWithLADJ(Real, _Zero)) isa InverseFunctionWithLADJ{Type{Real},Type{_Zero}}
@test @inferred(InverseFunctionWithLADJ(widen, _Zero)) isa InverseFunctionWithLADJ{typeof(widen),Type{_Zero}}
@test @inferred(InverseFunctionWithLADJ(Real, zero)) isa InverseFunctionWithLADJ{Type{Real},typeof(zero)}
@test @inferred(InverseFunctionWithLADJ(widen, zero)) isa InverseFunctionWithLADJ{typeof(widen),typeof(zero)}

x = 4.2
y = x^2

f_fwd = setladj(x -> x^2, x -> log(2*x))
f_inv = setladj(y -> sqrt(y), y -> log(inv(2*sqrt(y))))
ChangesOfVariables.test_with_logabsdet_jacobian(f_fwd, x, getjacobian)
ChangesOfVariables.test_with_logabsdet_jacobian(f_inv, y, getjacobian)

f = @inferred setladj(setinverse(x -> x^2, x -> sqrt(x)), x -> log(2*x))
@test @inferred(f(x)) == y
ChangesOfVariables.test_with_logabsdet_jacobian(f, x, getjacobian)
ChangesOfVariables.test_with_logabsdet_jacobian(inverse(f), y, getjacobian)
ChangesOfVariables.test_with_logabsdet_jacobian(inverse(inverse(f)), x, getjacobian)
@inferred(inverse(inverse(f))) isa ChangesOfVariables.FunctionWithLADJ

@static if isdefined(InverseFunctions, :setinverse)
g = setinverse(f_fwd, f_inv)
ChangesOfVariables.test_with_logabsdet_jacobian(g, x, getjacobian)
ChangesOfVariables.test_with_logabsdet_jacobian(inverse(g), y, getjacobian)
ChangesOfVariables.test_with_logabsdet_jacobian(inverse(inverse(g)), x, getjacobian)
end
end

0 comments on commit f659510

Please sign in to comment.