From bd6ff3dd2f1e4c26a4681c780c54f1ac9fe52e0d Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 10 Mar 2024 09:05:23 +0100 Subject: [PATCH 01/19] Rename VecCholeskyBijector to VecCorrCholeskyBijector --- src/bijectors/corr.jl | 30 ++++++++++++++++-------------- test/bijectors/corr.jl | 4 ++-- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 1e68fce8..9c7ed115 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -95,7 +95,7 @@ A bijector to transform a correlation matrix to an unconstrained vector. # Reference https://mc-stan.org/docs/reference-manual/correlation-matrix-transform.html -See also: [`CorrBijector`](@ref) and ['VecCholeskyBijector'](@ref) +See also: [`CorrBijector`](@ref) and ['VecCorrCholeskyBijector'](@ref) # Example @@ -151,7 +151,7 @@ function output_size(::Inverse{VecCorrBijector}, sz::Tuple{Int}) end """ - VecCholeskyBijector <: Bijector + VecCorrCholeskyBijector <: Bijector A bijector to transform a Cholesky factor of a correlation matrix to an unconstrained vector. @@ -172,7 +172,7 @@ julia> using LinearAlgebra julia> using StableRNGs; rng = StableRNG(42); -julia> b = Bijectors.VecCholeskyBijector(:U); +julia> b = Bijectors.VecCorrCholeskyBijector(:U); julia> X = rand(rng, LKJCholesky(3, 1, :U)) # Sample a correlation matrix. Cholesky{Float64, Matrix{Float64}} @@ -194,9 +194,9 @@ true julia> X_inv.L ≈ X.L # (✓) Also works for the lower triangular factor. true """ -struct VecCholeskyBijector <: Bijector +struct VecCorrCholeskyBijector <: Bijector mode::Symbol - function VecCholeskyBijector(uplo) + function VecCorrCholeskyBijector(uplo) s = Symbol(uplo) if (s === :U) || (s === :L) new(s) @@ -210,39 +210,41 @@ struct VecCholeskyBijector <: Bijector end end +Base.@deprecate_binding VecCholeskyBijector VecCorrCholeskyBijector + # TODO: Implement directly to make use of shared computations. -with_logabsdet_jacobian(b::VecCholeskyBijector, x) = transform(b, x), logabsdetjac(b, x) +with_logabsdet_jacobian(b::VecCorrCholeskyBijector, x) = transform(b, x), logabsdetjac(b, x) -function transform(b::VecCholeskyBijector, X) +function transform(b::VecCorrCholeskyBijector, X) return if b.mode === :U _link_chol_lkj_from_upper(cholesky_upper(X)) - else # No need to check for === :L, as it is checked in the VecCholeskyBijector constructor. + else # No need to check for === :L, as it is checked in the VecCorrCholeskyBijector constructor. _link_chol_lkj_from_lower(cholesky_lower(X)) end end -function logabsdetjac(b::VecCholeskyBijector, x) +function logabsdetjac(b::VecCorrCholeskyBijector, x) return -logabsdetjac(inverse(b), b(x)) end -function transform(b::Inverse{VecCholeskyBijector}, y::AbstractVector{<:Real}) +function transform(b::Inverse{VecCorrCholeskyBijector}, y::AbstractVector{<:Real}) if b.orig.mode === :U # This Cholesky constructor is compatible with Julia v1.6 # for later versions Cholesky(::UpperTriangular) works return Cholesky(_inv_link_chol_lkj(y), 'U', 0) - else # No need to check for === :L, as it is checked in the VecCholeskyBijector constructor. + else # No need to check for === :L, as it is checked in the VecCorrCholeskyBijector constructor. # HACK: Need to make materialize the transposed matrix to avoid numerical instabilities. # If we don't, the return-type can be both `Matrix` and `Transposed`. return Cholesky(transpose_eager(_inv_link_chol_lkj(y)), 'L', 0) end end -function logabsdetjac(::Inverse{VecCholeskyBijector}, y::AbstractVector{<:Real}) +function logabsdetjac(::Inverse{VecCorrCholeskyBijector}, y::AbstractVector{<:Real}) return _logabsdetjac_inv_chol(y) end -output_size(::VecCholeskyBijector, sz::Tuple{Int,Int}) = output_size(VecCorrBijector(), sz) -function output_size(::Inverse{<:VecCholeskyBijector}, sz::Tuple{Int}) +output_size(::VecCorrCholeskyBijector, sz::Tuple{Int,Int}) = output_size(VecCorrBijector(), sz) +function output_size(::Inverse{<:VecCorrCholeskyBijector}, sz::Tuple{Int}) return output_size(inverse(VecCorrBijector()), sz) end diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index 8a423bc3..03dff5d2 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -1,5 +1,5 @@ using Bijectors, DistributionsAD, LinearAlgebra, Test -using Bijectors: VecCorrBijector, VecCholeskyBijector, CorrBijector +using Bijectors: VecCorrBijector, VecCorrCholeskyBijector, CorrBijector @testset "CorrBijector & VecCorrBijector" begin for d in [1, 2, 5] @@ -45,7 +45,7 @@ using Bijectors: VecCorrBijector, VecCholeskyBijector, CorrBijector end end -@testset "VecCholeskyBijector" begin +@testset "VecCorrCholeskyBijector" begin for d in [2, 5] for dist in [LKJCholesky(d, 1, 'U'), LKJCholesky(d, 1, 'L')] b = bijector(dist) From 2d5ad8281ebce03a2e6d27f9bda507b782267791 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 27 May 2024 23:33:48 +0200 Subject: [PATCH 02/19] Compute corr logdetjac during transform --- src/bijectors/corr.jl | 36 +++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 9c7ed115..bcc65056 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -338,44 +338,54 @@ function _inv_link_chol_lkj(Y::AbstractMatrix) K = LinearAlgebra.checksquare(Y) W = similar(Y) + T = typeof(log(one(eltype(W)))) + logJ = zero(T) + idx = 1 @inbounds for j in 1:K - W[1, j] = 1 - for i in 2:j - z = tanh(Y[i - 1, j]) - tmp = W[i - 1, j] - W[i - 1, j] = z * tmp - W[i, j] = tmp * sqrt(1 - z^2) + log_remainder = zero(T) # log of proportion of unit vector remaining + for i in 1:(j - 1) + z = tanh(Y[i, j]) + idx += 1 + W[i, j] = z * exp(log_remainder) + log_remainder += log1p(-z^2) / 2 + logJ += log_remainder end + logJ += log_remainder + W[j, j] = exp(log_remainder) for i in (j + 1):K W[i, j] = 0 end end - return W + return W, logJ end function _inv_link_chol_lkj(y::AbstractVector) K = _triu1_dim_from_length(length(y)) W = similar(y, K, K) + T = typeof(log(one(eltype(W)))) + logJ = zero(T) idx = 1 @inbounds for j in 1:K - W[1, j] = 1 - for i in 2:j + log_remainder = zero(T) # log of proportion of unit vector remaining + for i in 1:(j - 1) z = tanh(y[idx]) idx += 1 - tmp = W[i - 1, j] - W[i - 1, j] = z * tmp - W[i, j] = tmp * sqrt(1 - z^2) + W[i, j] = z * exp(log_remainder) + log_remainder += log1p(-z^2) / 2 + logJ += log_remainder end + logJ += log_remainder + W[j, j] = exp(log_remainder) for i in (j + 1):K W[i, j] = 0 end end - return W + return W, logJ end function _logabsdetjac_inv_corr(Y::AbstractMatrix) From 9c3e22bdbcaf3d78a82366356501243320de5036 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 27 May 2024 23:34:25 +0200 Subject: [PATCH 03/19] Enforce one-based indexing --- src/bijectors/corr.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index bcc65056..40ca72c2 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -335,6 +335,7 @@ _link_chol_lkj_from_lower(W::AbstractMatrix) = _link_chol_lkj_from_upper(transpo Inverse link function for cholesky factor. """ function _inv_link_chol_lkj(Y::AbstractMatrix) + LinearAlgebra.require_one_based_indexing(Y) K = LinearAlgebra.checksquare(Y) W = similar(Y) @@ -362,6 +363,7 @@ function _inv_link_chol_lkj(Y::AbstractMatrix) end function _inv_link_chol_lkj(y::AbstractVector) + LinearAlgebra.require_one_based_indexing(y) K = _triu1_dim_from_length(length(y)) W = similar(y, K, K) From 5180c193844e3a856598b9b5368ec7a7f2a5f659 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 27 May 2024 23:39:47 +0200 Subject: [PATCH 04/19] Add with_logabsdet_jacobian for correlation transforms --- src/bijectors/corr.jl | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 40ca72c2..27f0608e 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -71,9 +71,13 @@ function transform(b::CorrBijector, X::AbstractMatrix{<:Real}) return r end -function transform(ib::Inverse{CorrBijector}, y::AbstractMatrix{<:Real}) - w = _inv_link_chol_lkj(y) - return pd_from_upper(w) +function with_logabsdet_jacobian(ib::Inverse{CorrBijector}, y::AbstractMatrix{<:Real}) + U, logJ = _inv_link_chol_lkj(y) + K = size(U, 1) + for j in 2:(K - 1) + logJ += (K - j) * log(U[j, j]) + end + return pd_from_upper(U), logJ end logabsdetjac(::Inverse{CorrBijector}, Y::AbstractMatrix{<:Real}) = _logabsdetjac_inv_corr(Y) @@ -131,8 +135,13 @@ function logabsdetjac(b::VecCorrBijector, x) return -logabsdetjac(inverse(b), b(x)) end -function transform(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) - return pd_from_upper(_inv_link_chol_lkj(y)) +function with_logabsdet_jacobian(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) + U, logJ = _inv_link_chol_lkj(y) + K = size(U, 1) + for j in 2:(K - 1) + logJ += (K - j) * log(U[j, j]) + end + return pd_from_upper(U), logJ end function logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) @@ -227,15 +236,16 @@ function logabsdetjac(b::VecCorrCholeskyBijector, x) return -logabsdetjac(inverse(b), b(x)) end -function transform(b::Inverse{VecCorrCholeskyBijector}, y::AbstractVector{<:Real}) +function with_logabsdet_jacobian(b::Inverse{VecCorrCholeskyBijector}, y::AbstractVector{<:Real}) + factors, logJ = _inv_link_chol_lkj(y) if b.orig.mode === :U # This Cholesky constructor is compatible with Julia v1.6 # for later versions Cholesky(::UpperTriangular) works - return Cholesky(_inv_link_chol_lkj(y), 'U', 0) + return Cholesky(factors, 'U', 0), logJ else # No need to check for === :L, as it is checked in the VecCorrCholeskyBijector constructor. # HACK: Need to make materialize the transposed matrix to avoid numerical instabilities. # If we don't, the return-type can be both `Matrix` and `Transposed`. - return Cholesky(transpose_eager(_inv_link_chol_lkj(y)), 'L', 0) + return Cholesky(transpose_eager(factors), 'L', 0), logJ end end From 8186ee9828185061f3fe6905db9419c4c52f0456 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 27 May 2024 23:41:36 +0200 Subject: [PATCH 05/19] Add rrule for non-mutating ADs --- src/bijectors/corr.jl | 63 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 27f0608e..5f465acc 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -400,6 +400,69 @@ function _inv_link_chol_lkj(y::AbstractVector) return W, logJ end +# shared reverse-mode AD rule code +function _inv_link_chol_lkj_rrule(y::AbstractVector) + LinearAlgebra.require_one_based_indexing(y) + K = _triu1_dim_from_length(length(y)) + + W = similar(y, K, K) + T = typeof(log(one(eltype(W)))) + logJ = zero(T) + + z_vec = tanh.(y) + + idx = 1 + W[1, 1] = 1 + @inbounds for j in 2:K + log_remainder = zero(T) # log of proportion of unit vector remaining + for i in 1:(j - 1) + z = z_vec[idx] + idx += 1 + W[i, j] = z * exp(log_remainder) + log_remainder += log1p(-z^2) / 2 + logJ += log_remainder + end + logJ += log_remainder + W[j, j] = exp(log_remainder) + for i in (j + 1):K + W[i, j] = 0 + end + end + + function pullback_inv_link_chol_lkj((ΔW, ΔlogJ)) + LinearAlgebra.require_one_based_indexing(ΔW) + Δy = similar(y) + + idx = lastindex(y) + @inbounds for j in K:-1:2 + Δlog_remainder = W[j, j] * ΔW[j, j] + 2ΔlogJ + for i in (j - 1):-1:1 + W_ΔW = W[i, j] * ΔW[i, j] + z = z_vec[idx] + Δy[idx] = (inv(z) - z) * W_ΔW - z * Δlog_remainder + idx -= 1 + Δlog_remainder += ΔlogJ + W_ΔW + end + end + + return Δy + end + + return (W, logJ), pullback_inv_link_chol_lkj +end + +function _inv_link_chol_lkj_rrule(y::AbstractMatrix) + K = LinearAlgebra.checksquare(y) + y_vec = Bijectors._triu_to_vec(y, 1) + W_logJ, back = _inv_link_chol_lkj_reverse(y_vec) + + function pullback_inv_link_chol_lkj(ΔW_ΔlogJ) + return update_triu_from_vec(_triu_to_vec(back(ΔW_ΔlogJ), 1), 1, K) + end + + return W_logJ, pullback_inv_link_chol_lkj +end + function _logabsdetjac_inv_corr(Y::AbstractMatrix) K = LinearAlgebra.checksquare(Y) From c0d8aa1b66a25be8708e02c9fcc8c0efe8c5ec7c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 27 May 2024 23:41:55 +0200 Subject: [PATCH 06/19] Update ChainRules to use manual rrule --- src/chainrules.jl | 49 ++++------------------------------------------- 1 file changed, 4 insertions(+), 45 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 3f598634..f15e1c22 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -267,55 +267,14 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj_from_lower), W::AbstractMa end function ChainRulesCore.rrule(::typeof(_inv_link_chol_lkj), y::AbstractVector) - K = _triu1_dim_from_length(length(y)) - - W = similar(y, K, K) - - z_vec = similar(y) - tmp_vec = similar(y) - - idx = 1 - @inbounds for j in 1:K - W[1, j] = 1 - for i in 2:j - z = tanh(y[idx]) - tmp = W[i - 1, j] - - z_vec[idx] = z - tmp_vec[idx] = tmp - idx += 1 - - W[i - 1, j] = z * tmp - W[i, j] = tmp * sqrt(1 - z^2) - end - for i in (j + 1):K - W[i, j] = 0 - end - end - - function pullback_inv_link_chol_lkj(ΔW_thunked) - ΔW = ChainRulesCore.unthunk(ΔW_thunked) - - Δy = zero(y) - - @inbounds for j in 1:K - idx_up_to_prev_column = ((j - 1) * (j - 2) ÷ 2) - Δtmp = ΔW[j, j] - for i in j:-1:2 - idx = idx_up_to_prev_column + i - 1 - tmp = tmp_vec[idx] - z = z_vec[idx] - - Δz = ΔW[i - 1, j] * tmp - Δtmp * tmp / sqrt(1 - z^2) * z - Δy[idx] = Δz / cosh(y[idx])^2 - Δtmp = ΔW[i - 1, j] * z + Δtmp * sqrt(1 - z^2) - end - end + W_logJ, back = _inv_link_chol_lkj_rrule(y) + function pullback_inv_link_chol_lkj(ΔW_ΔlogJ) + Δy = back(ChainRulesCore.unthunk(ΔW_ΔlogJ)) return ChainRulesCore.NoTangent(), Δy end - return W, pullback_inv_link_chol_lkj + return W_logJ, pullback_inv_link_chol_lkj end function ChainRulesCore.rrule(::typeof(pd_from_upper), X::AbstractMatrix) From e47a946e6e93a16a2c8223c099e354528304b9e0 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 27 May 2024 23:42:02 +0200 Subject: [PATCH 07/19] Update Tracker to use manual rrule --- ext/BijectorsTrackerExt.jl | 100 ++----------------------------------- 1 file changed, 5 insertions(+), 95 deletions(-) diff --git a/ext/BijectorsTrackerExt.jl b/ext/BijectorsTrackerExt.jl index b44cf3a3..6943a518 100644 --- a/ext/BijectorsTrackerExt.jl +++ b/ext/BijectorsTrackerExt.jl @@ -338,106 +338,16 @@ Bijectors.lower_triangular(A::TrackedMatrix) = track(Bijectors.lower_triangular, end Bijectors._inv_link_chol_lkj(y::TrackedVector) = track(Bijectors._inv_link_chol_lkj, y) -@grad function Bijectors._inv_link_chol_lkj(y_tracked::TrackedVector) - y = data(y_tracked) - K = _triu1_dim_from_length(length(y)) - - W = similar(y, K, K) - - z_vec = similar(y) - tmp_vec = similar(y) - - idx = 1 - @inbounds for j in 1:K - W[1, j] = 1 - for i in 2:j - z = tanh(y[idx]) - tmp = W[i - 1, j] - - z_vec[idx] = z - tmp_vec[idx] = tmp - idx += 1 - - W[i - 1, j] = z * tmp - W[i, j] = tmp * sqrt(1 - z^2) - end - for i in (j + 1):K - W[i, j] = 0 - end - end - - function pullback_inv_link_chol_lkj(ΔW) - LinearAlgebra.checksquare(ΔW) - - Δy = zero(y) - - @inbounds for j in 1:K - idx_up_to_prev_column = ((j - 1) * (j - 2) ÷ 2) - Δtmp = ΔW[j, j] - for i in j:-1:2 - idx = idx_up_to_prev_column + i - 1 - Δz = - ΔW[i - 1, j] * tmp_vec[idx] - - Δtmp * tmp_vec[idx] / sqrt(1 - z_vec[idx]^2) * z_vec[idx] - Δy[idx] = Δz / cosh(y[idx])^2 - Δtmp = ΔW[i - 1, j] * z_vec[idx] + Δtmp * sqrt(1 - z_vec[idx]^2) - end - end - - return (Δy,) - end - - return W, pullback_inv_link_chol_lkj -end - Bijectors._inv_link_chol_lkj(y::TrackedMatrix) = track(Bijectors._inv_link_chol_lkj, y) -@grad function Bijectors._inv_link_chol_lkj(y_tracked::TrackedMatrix) +@grad function Bijectors._inv_link_chol_lkj(y_tracked::Union{TrackedVector,TrackedMatrix}) y = data(y_tracked) + W_logJ, back = _inv_link_chol_lkj_rrule(y) - K = LinearAlgebra.checksquare(y) - - w = similar(y) - - z_mat = similar(y) # cache for adjoint - tmp_mat = similar(y) - - @inbounds for j in 1:K - w[1, j] = 1 - for i in 2:j - z = tanh(y[i - 1, j]) - tmp = w[i - 1, j] - - z_mat[i, j] = z - tmp_mat[i, j] = tmp - - w[i - 1, j] = z * tmp - w[i, j] = tmp * sqrt(1 - z^2) - end - for i in (j + 1):K - w[i, j] = 0 - end - end - - function pullback_inv_link_chol_lkj(Δw) - LinearAlgebra.checksquare(Δw) - - Δy = zero(y) - - @inbounds for j in 1:K - Δtmp = Δw[j, j] - for i in j:-1:2 - Δz = - Δw[i - 1, j] * tmp_mat[i, j] - - Δtmp * tmp_mat[i, j] / sqrt(1 - z_mat[i, j]^2) * z_mat[i, j] - Δy[i - 1, j] = Δz / cosh(y[i - 1, j])^2 - Δtmp = Δw[i - 1, j] * z_mat[i, j] + Δtmp * sqrt(1 - z_mat[i, j]^2) - end - end - - return (Δy,) + function pullback_inv_link_chol_lkj(ΔW_ΔlogJ) + return (back(ΔW_ΔlogJ),) end - return w, pullback_inv_link_chol_lkj + return W_logJ, pullback_inv_link_chol_lkj end Bijectors._link_chol_lkj(w::TrackedMatrix) = track(Bijectors._link_chol_lkj, w) From a3fe7bbdf4e8ba2e512693007cc5b9d2efa9f60e Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 27 May 2024 23:43:04 +0200 Subject: [PATCH 08/19] Remove rrule for ReverseDiff `@grad_from_chainrules` can't handle multi-output functions, see https://github.com/JuliaDiff/ReverseDiff.jl/issues/221. In this case it can AD through the primal just fine. --- ext/BijectorsReverseDiffExt.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/ext/BijectorsReverseDiffExt.jl b/ext/BijectorsReverseDiffExt.jl index a733bd71..4489cb26 100644 --- a/ext/BijectorsReverseDiffExt.jl +++ b/ext/BijectorsReverseDiffExt.jl @@ -268,7 +268,6 @@ end @grad_from_chainrules _link_chol_lkj(x::TrackedMatrix) @grad_from_chainrules _link_chol_lkj_from_upper(x::TrackedMatrix) @grad_from_chainrules _link_chol_lkj_from_lower(x::TrackedMatrix) -@grad_from_chainrules _inv_link_chol_lkj(x::TrackedVector) cholesky_lower(X::TrackedMatrix) = track(cholesky_lower, X) @grad function cholesky_lower(X_tracked::TrackedMatrix) From dd065825d4d372893fd142416e3d32b3299f29cd Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 27 May 2024 23:57:51 +0200 Subject: [PATCH 09/19] Add module --- ext/BijectorsTrackerExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/BijectorsTrackerExt.jl b/ext/BijectorsTrackerExt.jl index 6943a518..9d04ea32 100644 --- a/ext/BijectorsTrackerExt.jl +++ b/ext/BijectorsTrackerExt.jl @@ -341,7 +341,7 @@ Bijectors._inv_link_chol_lkj(y::TrackedVector) = track(Bijectors._inv_link_chol_ Bijectors._inv_link_chol_lkj(y::TrackedMatrix) = track(Bijectors._inv_link_chol_lkj, y) @grad function Bijectors._inv_link_chol_lkj(y_tracked::Union{TrackedVector,TrackedMatrix}) y = data(y_tracked) - W_logJ, back = _inv_link_chol_lkj_rrule(y) + W_logJ, back = Bijectors._inv_link_chol_lkj_rrule(y) function pullback_inv_link_chol_lkj(ΔW_ΔlogJ) return (back(ΔW_ΔlogJ),) From 9f34b9c861770adbeb17600b327349554406d9f3 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 27 May 2024 23:59:16 +0200 Subject: [PATCH 10/19] Make CorrBijector more numerically stable Also use consistent notation with inverse transform --- src/bijectors/corr.jl | 36 ++++++++++++++++-------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 5f465acc..7d115ade 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -293,48 +293,44 @@ which is the above implementation. function _link_chol_lkj(W::AbstractMatrix) K = LinearAlgebra.checksquare(W) - z = similar(W) # z is also UpperTriangular. + y = similar(W) # z is also UpperTriangular. # Some zero filling can be avoided. Though diagnoal is still needed to be filled with zero. - # This block can't be integrated with loop below, because W[1,1] != 0. - @inbounds z[:, 1] .= 0 - - @inbounds for j in 2:K - z[1, j] = atanh(W[1, j]) - tmp = sqrt(1 - W[1, j]^2) - for i in 2:(j - 1) - p = W[i, j] / tmp - tmp *= sqrt(1 - p^2) - z[i, j] = atanh(p) + @inbounds for j in 1:K + remainder_sq = one(eltype(W)) + for i in 1:(j - 1) + z = W[i, j] / sqrt(remainder_sq) + y[i, j] = atanh(z) + remainder_sq -= W[i, j]^2 end for i in j:K - z[i, j] = 0 + y[i, j] = 0 end end - return z + return y end function _link_chol_lkj_from_upper(W::AbstractMatrix) K = LinearAlgebra.checksquare(W) N = ((K - 1) * K) ÷ 2 # {K \choose 2} free parameters - z = similar(W, N) + y = similar(W, N) idx = 1 @inbounds for j in 2:K - z[idx] = atanh(W[1, j]) + y[idx] = atanh(W[1, j]) idx += 1 - tmp = sqrt(1 - W[1, j]^2) + remainder_sq = 1 - W[1, j]^2 for i in 2:(j - 1) - p = W[i, j] / tmp - tmp *= sqrt(1 - p^2) - z[idx] = atanh(p) + z = W[i, j] / sqrt(remainder_sq) + y[idx] = atanh(z) + remainder_sq -= W[i, j]^2 idx += 1 end end - return z + return y end _link_chol_lkj_from_lower(W::AbstractMatrix) = _link_chol_lkj_from_upper(transpose_eager(W)) From f417c32ab37e037272bb0af927b438227a36b4a4 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 28 May 2024 00:15:00 +0200 Subject: [PATCH 11/19] Increment patch number --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d26d2f45..f3903688 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.13.12" +version = "0.13.13" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" From 121e3cdd87eb1e7085e405562c754e56e4e213c4 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 28 May 2024 00:23:12 +0200 Subject: [PATCH 12/19] Revert "Rename VecCholeskyBijector to VecCorrCholeskyBijector" This reverts commit bd6ff3dd2f1e4c26a4681c780c54f1ac9fe52e0d. --- src/bijectors/corr.jl | 30 ++++++++++++++---------------- test/bijectors/corr.jl | 4 ++-- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 7d115ade..3832bf2a 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -99,7 +99,7 @@ A bijector to transform a correlation matrix to an unconstrained vector. # Reference https://mc-stan.org/docs/reference-manual/correlation-matrix-transform.html -See also: [`CorrBijector`](@ref) and ['VecCorrCholeskyBijector'](@ref) +See also: [`CorrBijector`](@ref) and ['VecCholeskyBijector'](@ref) # Example @@ -160,7 +160,7 @@ function output_size(::Inverse{VecCorrBijector}, sz::Tuple{Int}) end """ - VecCorrCholeskyBijector <: Bijector + VecCholeskyBijector <: Bijector A bijector to transform a Cholesky factor of a correlation matrix to an unconstrained vector. @@ -181,7 +181,7 @@ julia> using LinearAlgebra julia> using StableRNGs; rng = StableRNG(42); -julia> b = Bijectors.VecCorrCholeskyBijector(:U); +julia> b = Bijectors.VecCholeskyBijector(:U); julia> X = rand(rng, LKJCholesky(3, 1, :U)) # Sample a correlation matrix. Cholesky{Float64, Matrix{Float64}} @@ -203,9 +203,9 @@ true julia> X_inv.L ≈ X.L # (✓) Also works for the lower triangular factor. true """ -struct VecCorrCholeskyBijector <: Bijector +struct VecCholeskyBijector <: Bijector mode::Symbol - function VecCorrCholeskyBijector(uplo) + function VecCholeskyBijector(uplo) s = Symbol(uplo) if (s === :U) || (s === :L) new(s) @@ -219,42 +219,40 @@ struct VecCorrCholeskyBijector <: Bijector end end -Base.@deprecate_binding VecCholeskyBijector VecCorrCholeskyBijector - # TODO: Implement directly to make use of shared computations. -with_logabsdet_jacobian(b::VecCorrCholeskyBijector, x) = transform(b, x), logabsdetjac(b, x) +with_logabsdet_jacobian(b::VecCholeskyBijector, x) = transform(b, x), logabsdetjac(b, x) -function transform(b::VecCorrCholeskyBijector, X) +function transform(b::VecCholeskyBijector, X) return if b.mode === :U _link_chol_lkj_from_upper(cholesky_upper(X)) - else # No need to check for === :L, as it is checked in the VecCorrCholeskyBijector constructor. + else # No need to check for === :L, as it is checked in the VecCholeskyBijector constructor. _link_chol_lkj_from_lower(cholesky_lower(X)) end end -function logabsdetjac(b::VecCorrCholeskyBijector, x) +function logabsdetjac(b::VecCholeskyBijector, x) return -logabsdetjac(inverse(b), b(x)) end -function with_logabsdet_jacobian(b::Inverse{VecCorrCholeskyBijector}, y::AbstractVector{<:Real}) +function with_logabsdet_jacobian(b::Inverse{VecCholeskyBijector}, y::AbstractVector{<:Real}) factors, logJ = _inv_link_chol_lkj(y) if b.orig.mode === :U # This Cholesky constructor is compatible with Julia v1.6 # for later versions Cholesky(::UpperTriangular) works return Cholesky(factors, 'U', 0), logJ - else # No need to check for === :L, as it is checked in the VecCorrCholeskyBijector constructor. + else # No need to check for === :L, as it is checked in the VecCholeskyBijector constructor. # HACK: Need to make materialize the transposed matrix to avoid numerical instabilities. # If we don't, the return-type can be both `Matrix` and `Transposed`. return Cholesky(transpose_eager(factors), 'L', 0), logJ end end -function logabsdetjac(::Inverse{VecCorrCholeskyBijector}, y::AbstractVector{<:Real}) +function logabsdetjac(::Inverse{VecCholeskyBijector}, y::AbstractVector{<:Real}) return _logabsdetjac_inv_chol(y) end -output_size(::VecCorrCholeskyBijector, sz::Tuple{Int,Int}) = output_size(VecCorrBijector(), sz) -function output_size(::Inverse{<:VecCorrCholeskyBijector}, sz::Tuple{Int}) +output_size(::VecCholeskyBijector, sz::Tuple{Int,Int}) = output_size(VecCorrBijector(), sz) +function output_size(::Inverse{<:VecCholeskyBijector}, sz::Tuple{Int}) return output_size(inverse(VecCorrBijector()), sz) end diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index 03dff5d2..8a423bc3 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -1,5 +1,5 @@ using Bijectors, DistributionsAD, LinearAlgebra, Test -using Bijectors: VecCorrBijector, VecCorrCholeskyBijector, CorrBijector +using Bijectors: VecCorrBijector, VecCholeskyBijector, CorrBijector @testset "CorrBijector & VecCorrBijector" begin for d in [1, 2, 5] @@ -45,7 +45,7 @@ using Bijectors: VecCorrBijector, VecCorrCholeskyBijector, CorrBijector end end -@testset "VecCorrCholeskyBijector" begin +@testset "VecCholeskyBijector" begin for d in [2, 5] for dist in [LKJCholesky(d, 1, 'U'), LKJCholesky(d, 1, 'L')] b = bijector(dist) From c1dbb30ccd01841f0b5765d7ff10e41486edbb4a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 3 Jun 2024 08:47:30 +0100 Subject: [PATCH 13/19] Update src/bijectors/corr.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/bijectors/corr.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 3832bf2a..9a35480a 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -449,7 +449,6 @@ function _inv_link_chol_lkj_rrule(y::AbstractMatrix) K = LinearAlgebra.checksquare(y) y_vec = Bijectors._triu_to_vec(y, 1) W_logJ, back = _inv_link_chol_lkj_reverse(y_vec) - function pullback_inv_link_chol_lkj(ΔW_ΔlogJ) return update_triu_from_vec(_triu_to_vec(back(ΔW_ΔlogJ), 1), 1, K) end From 289c689c722fc50ae7560d08691ea2e90dbed254 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 3 Jun 2024 11:05:09 +0200 Subject: [PATCH 14/19] Apply suggestions from code review Co-authored-by: Tor Erlend Fjelde --- src/bijectors/corr.jl | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 9a35480a..95c0c03f 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -343,7 +343,7 @@ function _inv_link_chol_lkj(Y::AbstractMatrix) K = LinearAlgebra.checksquare(Y) W = similar(Y) - T = typeof(log(one(eltype(W)))) + T = float(eltype(W)) logJ = zero(T) idx = 1 @@ -351,7 +351,6 @@ function _inv_link_chol_lkj(Y::AbstractMatrix) log_remainder = zero(T) # log of proportion of unit vector remaining for i in 1:(j - 1) z = tanh(Y[i, j]) - idx += 1 W[i, j] = z * exp(log_remainder) log_remainder += log1p(-z^2) / 2 logJ += log_remainder @@ -427,14 +426,14 @@ function _inv_link_chol_lkj_rrule(y::AbstractVector) LinearAlgebra.require_one_based_indexing(ΔW) Δy = similar(y) - idx = lastindex(y) + idx_local = lastindex(y) @inbounds for j in K:-1:2 Δlog_remainder = W[j, j] * ΔW[j, j] + 2ΔlogJ for i in (j - 1):-1:1 W_ΔW = W[i, j] * ΔW[i, j] - z = z_vec[idx] - Δy[idx] = (inv(z) - z) * W_ΔW - z * Δlog_remainder - idx -= 1 + z = z_vec[idx_local] + Δy[idx_local] = (inv(z) - z) * W_ΔW - z * Δlog_remainder + idx_local -= 1 Δlog_remainder += ΔlogJ + W_ΔW end end From f0ad1218dc47bbe564efd57b3a9c962979815f99 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 3 Jun 2024 11:05:31 +0200 Subject: [PATCH 15/19] Apply suggestions from code review --- src/bijectors/corr.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 95c0c03f..a347fa3d 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -370,7 +370,7 @@ function _inv_link_chol_lkj(y::AbstractVector) K = _triu1_dim_from_length(length(y)) W = similar(y, K, K) - T = typeof(log(one(eltype(W)))) + T = float(eltype(W)) logJ = zero(T) idx = 1 From a2eac95e47b1ff235a8f5699e888f962cc241eaa Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 3 Jun 2024 13:37:10 +0200 Subject: [PATCH 16/19] Work around issues with Tracker --- src/bijectors/corr.jl | 4 +++- src/utils.jl | 10 ++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index a347fa3d..a7fe0b92 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -136,7 +136,9 @@ function logabsdetjac(b::VecCorrBijector, x) end function with_logabsdet_jacobian(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) - U, logJ = _inv_link_chol_lkj(y) + U_logJ = _inv_link_chol_lkj(y) + # workaround for `Tracker.TrackedTuple` not supporting iteration + U, logJ = U_logJ[1], U_logJ[2] K = size(U, 1) for j in 2:(K - 1) logJ += (K - j) * log(U[j, j]) diff --git a/src/utils.jl b/src/utils.jl index 82c15de6..9fd6c65c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -11,8 +11,14 @@ _vec(x::Real) = x lower_triangular(A::AbstractMatrix) = convert(typeof(A), LowerTriangular(A)) upper_triangular(A::AbstractMatrix) = convert(typeof(A), UpperTriangular(A)) -pd_from_lower(X) = LowerTriangular(X) * LowerTriangular(X)' -pd_from_upper(X) = UpperTriangular(X)' * UpperTriangular(X) +function pd_from_lower(X) + L = lower_triangular(X) + return L * L' +end +function pd_from_upper(X) + U = upper_triangular(X) + return U' * U +end # HACK: Allows us to define custom chain rules while we wait for upstream fixes. transpose_eager(X::AbstractMatrix) = permutedims(X) From 21189fa0ff9f8e7c9b72ee0a82c8829633f77a9e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 3 Jun 2024 23:40:33 +0100 Subject: [PATCH 17/19] import `stack` from Compat.jl (#314) --- src/Bijectors.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 0ef63a2d..50a8f07e 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -82,6 +82,10 @@ if VERSION < v"1.1" using Compat: eachcol end +if VERSION < v"1.9" + using Compat: stack +end + const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_BIJECTORS", "0"))) _debug(str) = @debug str From f9a86874548f5ce910a1720454436e8f3bebfca1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 4 Jun 2024 21:26:37 +0100 Subject: [PATCH 18/19] import `stack` in tests too --- Project.toml | 2 +- test/Project.toml | 2 ++ test/runtests.jl | 4 ++++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 959115c7..c53a2ce6 100644 --- a/Project.toml +++ b/Project.toml @@ -43,7 +43,7 @@ ArgCheck = "1, 2" ChainRules = "1" ChainRulesCore = "0.10.11, 1" ChangesOfVariables = "0.1" -Compat = "3, 4" +Compat = "3.46, 4.2" Distributions = "0.25.33" ForwardDiff = "0.10" DistributionsAD = "0.6" diff --git a/test/Project.toml b/test/Project.toml index a7a089c3..6f156b8b 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,6 +2,7 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" @@ -21,6 +22,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" ChainRulesTestUtils = "0.7, 1" ChangesOfVariables = "0.1" Combinatorics = "1.0.2" +Compat = "3.46, 4.2" DistributionsAD = "0.6.3" FillArrays = "1" FiniteDifferences = "0.11, 0.12" diff --git a/test/runtests.jl b/test/runtests.jl index 01b81fb8..2ce01010 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -29,6 +29,10 @@ using ChangesOfVariables: ChangesOfVariables using InverseFunctions: InverseFunctions using LazyArrays: LazyArrays +if VERSION < v"1.9" + using Compat: stack +end + const GROUP = get(ENV, "GROUP", "All") # Always include this since it can be useful for other tests. From b6e7fa3654dbb5d04bc99311696ffd17a0d88a63 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 5 Jun 2024 08:08:01 +0100 Subject: [PATCH 19/19] disable certain tests for ProductBijector on Julia versions with older `eachslice` impls --- test/bijectors/product_bijector.jl | 59 ++++++++++++++++++++++-------- 1 file changed, 43 insertions(+), 16 deletions(-) diff --git a/test/bijectors/product_bijector.jl b/test/bijectors/product_bijector.jl index 818f89c0..78310572 100644 --- a/test/bijectors/product_bijector.jl +++ b/test/bijectors/product_bijector.jl @@ -33,14 +33,27 @@ has_square_jacobian(b, x) = Bijectors.output_size(b, x) == size(x) end y, logjac = stack(map(first, results)), sum(last, results) - test_bijector( - b_prod, - x; - y, - logjac, - changes_of_variables_test=has_square_jacobian(b, xs[1]), - test_not_identity=!isidentity, - ) + if VERSION < v"1.9" && length(size(d)) > 0 + # `eachslice`, which is used by `ProductBijector`, is type-unstable + # for multivariate cases on Julia < 1.9. Hence the type-inference fails. + @test_broken test_bijector( + b_prod, + x; + y, + logjac, + changes_of_variables_test=has_square_jacobian(b, xs[1]), + test_not_identity=!isidentity, + ) + else + test_bijector( + b_prod, + x; + y, + logjac, + changes_of_variables_test=has_square_jacobian(b, xs[1]), + test_not_identity=!isidentity, + ) + end end @testset "Two-dim stack: $(nameof(typeof(d)))" for (d, isidentity) in ds @@ -57,13 +70,27 @@ has_square_jacobian(b, x) = Bijectors.output_size(b, x) == size(x) results = map(Base.Fix1(with_logabsdet_jacobian, b), xs) y, logjac = stack(map(first, results)), sum(last, results) - test_bijector( - b_prod, - x; - y, - logjac, - changes_of_variables_test=has_square_jacobian(b, xs[1]), - test_not_identity=!isidentity, - ) + if VERSION < v"1.9" && length(size(d)) > 0 + # `eachslice`, which is used by `ProductBijector`, does not support + # `dims` with more than one value. As a result, stacking anything that + # isn't univariate won't work here. + @test_broken test_bijector( + b_prod, + x; + y, + logjac, + changes_of_variables_test=has_square_jacobian(b, xs[1]), + test_not_identity=!isidentity, + ) + else + test_bijector( + b_prod, + x; + y, + logjac, + changes_of_variables_test=has_square_jacobian(b, xs[1]), + test_not_identity=!isidentity, + ) + end end end