-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: David Widmann <[email protected]>
- Loading branch information
Showing
10 changed files
with
180 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
```@docs | ||
with_logabsdet_jacobian | ||
NoLogAbsDetJacobian | ||
setladj | ||
``` | ||
|
||
## Test utility | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
module ChangesOfVariablesInverseFunctionsExt | ||
|
||
using ChangesOfVariables | ||
using InverseFunctions | ||
|
||
|
||
struct InverseFunctionWithLADJ{InvF,LADJF} <: Function | ||
inv_f::InvF | ||
ladjf::LADJF | ||
end | ||
InverseFunctionWithLADJ(::Type{InvF}, ladjf::LADJF) where {InvF,LADJF} = InverseFunctionWithLADJ{Type{InvF},LADJF}(InvF,ladjf) | ||
InverseFunctionWithLADJ(inv_f::InvF, ::Type{LADJF}) where {InvF,LADJF} = InverseFunctionWithLADJ{InvF,Type{LADJF}}(inv_f,LADJF) | ||
InverseFunctionWithLADJ(::Type{InvF}, ::Type{LADJF}) where {InvF,LADJF} = InverseFunctionWithLADJ{Type{InvF},Type{LADJF}}(InvF,LADJF) | ||
|
||
(f::InverseFunctionWithLADJ)(y) = f.inv_f(y) | ||
|
||
function ChangesOfVariables.with_logabsdet_jacobian(f::InverseFunctionWithLADJ, y) | ||
x = f.inv_f(y) | ||
return x, -f.ladjf(x) | ||
end | ||
|
||
InverseFunctions.inverse(f::ChangesOfVariables.FunctionWithLADJ) = InverseFunctionWithLADJ(inverse(f.f), f.ladjf) | ||
InverseFunctions.inverse(f::InverseFunctionWithLADJ) = ChangesOfVariables.FunctionWithLADJ(inverse(f.inv_f), f.ladjf) | ||
|
||
|
||
@static if isdefined(InverseFunctions, :FunctionWithInverse) | ||
function ChangesOfVariables.with_logabsdet_jacobian(f::InverseFunctions.FunctionWithInverse, x) | ||
ChangesOfVariables.with_logabsdet_jacobian(f.f, x) | ||
end | ||
end | ||
|
||
end # module ChangesOfVariablesInverseFunctionsExt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# This file is a part of ChangesOfVariables.jl, licensed under the MIT License (MIT). | ||
|
||
|
||
""" | ||
struct FunctionWithLADJ{F,LADJF} <: Function | ||
A function with an separate function to compute it's `logabddet(J)`. | ||
Do not construct directly, use [`setladj(f, ladjf)`](@ref) instead. | ||
""" | ||
struct FunctionWithLADJ{F,LADJF} <: Function | ||
f::F | ||
ladjf::LADJF | ||
end | ||
FunctionWithLADJ(::Type{F}, ladjf::LADJF) where {F,LADJF} = FunctionWithLADJ{Type{F},LADJF}(F,ladjf) | ||
FunctionWithLADJ(f::F, ::Type{LADJF}) where {F,LADJF} = FunctionWithLADJ{F,Type{LADJF}}(f,LADJF) | ||
FunctionWithLADJ(::Type{F}, ::Type{LADJF}) where {F,LADJF} = FunctionWithLADJ{Type{F},Type{LADJF}}(F,LADJF) | ||
|
||
(f::FunctionWithLADJ)(x) = f.f(x) | ||
|
||
with_logabsdet_jacobian(f::FunctionWithLADJ, x) = f.f(x), f.ladjf(x) | ||
|
||
|
||
""" | ||
setladj(f, ladjf)::Function | ||
Return a function that behaves like `f` in general and which has | ||
`with_logabsdet_jacobian(f, x) = f(x), ladjf(x)`. | ||
Useful in cases where [`with_logabsdet_jacobian`](@ref) is not defined | ||
for `f`, or if `f` needs to be assigned a LADJ-calculation that is | ||
only valid within a given context, e.g. only for a | ||
limited argument type/range that is guaranteed by the use case but | ||
not in general, or that is optimized to a custom use case. | ||
For example, `CUDA.CuArray` has no `with_logabsdet_jacobian` defined, | ||
but may be used to switch computing device for a part of a | ||
heterogenous computing function chain. Likewise, one may want to | ||
switch numerical precision for a part of a calculation. | ||
The function (wrapper) returned by `setladj` supports | ||
[`InverseFunctions.inverse`](https://github.com/JuliaMath/InverseFunctions.jl) | ||
if `f` does so. | ||
Example: | ||
```jldoctest setladj | ||
VERSION < v"1.6" || begin # Support for ∘ requires Julia >= v1.6 | ||
# Increases precition before calculation exp: | ||
foo = exp ∘ setladj(setinverse(Float64, Float32), _ -> 0) | ||
# A log-value from some low-precision (e.g. GPU) computation: | ||
log_x = Float32(100) | ||
# f(log_x) would return Inf32 without going to Float64: | ||
y, ladj = with_logabsdet_jacobian(foo, log_x) | ||
r_log_x, ladj_inv = with_logabsdet_jacobian(inverse(foo), y) | ||
ladj ≈ 100 ≈ -ladj_inv && r_log_x ≈ log_x | ||
end | ||
# output | ||
true | ||
``` | ||
""" | ||
setladj(f, ladjf) = FunctionWithLADJ(_unwrap_f(f), ladjf) | ||
export setladj | ||
|
||
_unwrap_f(f) = f | ||
_unwrap_f(f::FunctionWithLADJ) = f.f |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# This file is a part of ChangesOfVariables.jl, licensed under the MIT License (MIT). | ||
|
||
using Test | ||
using ChangesOfVariables | ||
using InverseFunctions | ||
|
||
const ChangesOfVariablesInverseFunctionsExt = if isdefined(Base, :get_extension) | ||
Base.get_extension(ChangesOfVariables, :ChangesOfVariablesInverseFunctionsExt) | ||
else | ||
ChangesOfVariables.ChangesOfVariablesInverseFunctionsExt | ||
end | ||
const InverseFunctionWithLADJ = ChangesOfVariablesInverseFunctionsExt.InverseFunctionWithLADJ | ||
|
||
include("getjacobian.jl") | ||
|
||
|
||
# Dummy testing type that looks like something that represents abstract zeros: | ||
struct _Zero{T} end | ||
_Zero(::T) where {T} = _Zero{T}() | ||
|
||
|
||
@testset "setladj" begin | ||
@test @inferred(setladj(Real, _Zero)) isa ChangesOfVariables.FunctionWithLADJ{Type{Real},Type{_Zero}} | ||
@test @inferred(ChangesOfVariables.FunctionWithLADJ(Real, _Zero)) isa ChangesOfVariables.FunctionWithLADJ{Type{Real},Type{_Zero}} | ||
@test @inferred(ChangesOfVariables.FunctionWithLADJ(widen, _Zero)) isa ChangesOfVariables.FunctionWithLADJ{typeof(widen),Type{_Zero}} | ||
@test @inferred(ChangesOfVariables.FunctionWithLADJ(Real, zero)) isa ChangesOfVariables.FunctionWithLADJ{Type{Real},typeof(zero)} | ||
@test @inferred(ChangesOfVariables.FunctionWithLADJ(widen, zero)) isa ChangesOfVariables.FunctionWithLADJ{typeof(widen),typeof(zero)} | ||
|
||
@test @inferred(InverseFunctionWithLADJ(Real, _Zero)) isa InverseFunctionWithLADJ{Type{Real},Type{_Zero}} | ||
@test @inferred(InverseFunctionWithLADJ(widen, _Zero)) isa InverseFunctionWithLADJ{typeof(widen),Type{_Zero}} | ||
@test @inferred(InverseFunctionWithLADJ(Real, zero)) isa InverseFunctionWithLADJ{Type{Real},typeof(zero)} | ||
@test @inferred(InverseFunctionWithLADJ(widen, zero)) isa InverseFunctionWithLADJ{typeof(widen),typeof(zero)} | ||
|
||
x = 4.2 | ||
y = x^2 | ||
|
||
f_fwd = setladj(x -> x^2, x -> log(2*x)) | ||
f_inv = setladj(y -> sqrt(y), y -> log(inv(2*sqrt(y)))) | ||
ChangesOfVariables.test_with_logabsdet_jacobian(f_fwd, x, getjacobian) | ||
ChangesOfVariables.test_with_logabsdet_jacobian(f_inv, y, getjacobian) | ||
|
||
f = @inferred setladj(setinverse(x -> x^2, x -> sqrt(x)), x -> log(2*x)) | ||
@test @inferred(f(x)) == y | ||
ChangesOfVariables.test_with_logabsdet_jacobian(f, x, getjacobian) | ||
ChangesOfVariables.test_with_logabsdet_jacobian(inverse(f), y, getjacobian) | ||
ChangesOfVariables.test_with_logabsdet_jacobian(inverse(inverse(f)), x, getjacobian) | ||
@inferred(inverse(inverse(f))) isa ChangesOfVariables.FunctionWithLADJ | ||
|
||
@static if isdefined(InverseFunctions, :setinverse) | ||
g = setinverse(f_fwd, f_inv) | ||
ChangesOfVariables.test_with_logabsdet_jacobian(g, x, getjacobian) | ||
ChangesOfVariables.test_with_logabsdet_jacobian(inverse(g), y, getjacobian) | ||
ChangesOfVariables.test_with_logabsdet_jacobian(inverse(inverse(g)), x, getjacobian) | ||
end | ||
end |