improved anim_traj()
162348 committed Oct 21, 2024
1 parent 96b9fcf commit 8724cc4
Showing 17 changed files with 258 additions and 99 deletions.
Binary file added Cauchy1D.gif
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "PDMPFlux"
uuid = "fa921718-cbf9-4165-aaee-7047d51b02b3"
authors = ["Hirofumi Shiba"]
version = "0.2.0"
version = "0.2.1"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
66 changes: 43 additions & 23 deletions README.md
@@ -4,25 +4,42 @@
|:-------------:|:---------:|:-------------:|:-----------------:|
| [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://162348.github.io/PDMPFlux.jl/stable/) [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://162348.github.io/PDMPFlux.jl/dev/) | [![Build Status](https://github.com/162348/PDMPFlux.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/162348/PDMPFlux.jl/actions/workflows/CI.yml?query=branch%3Amain) | [![Coverage](https://codecov.io/gh/162348/PDMPFlux.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/162348/PDMPFlux.jl) | [![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) |

This repository contains a [`Zygote.jl`](https://github.com/FluxML/Zygote.jl) implementation of the PDMP samplers.
## Overview

Currently, only Zig-Zag samplers are implemented.
`PDMPFlux.jl` provides a fast and efficient implementation of **Piecewise Deterministic Markov Process (PDMP)** samplers using a grid-based Poisson thinning approach.

In this version `v0.2.0`, only Zig-Zag samplers are implemented. We will extend the functionality to include other PDMP samplers in the future.

### Key Features

* To sample from a distribution $p(x)$, the *only* required inputs are its dimension $d$ and the negative log density $U(x)=-\log p(x)$ (up to an additive constant).


## Motivation

Markov Chain Monte Carlo (MCMC) methods are standard in sampling from distributions with unknown normalizing constants.

However, PDMPs offer a promising alternative due to their continuous and non-reversible dynamics, particularly in high-dimensional and big data contexts, as discussed in [Bouchard-Côté et al. (2018)](https://arxiv.org/abs/1510.02451) and [Bierkens et al. (2019)](https://arxiv.org/abs/1607.03188).

Despite their potential, practical applications of PDMPs remain limited by a lack of efficient and flexible implementations.

`PDMPFlux.jl` is my attempt to fill this gap, with the aid of the existing automatic differentiation engines.

## Installation

Currently, `julia >= 1.11` is required, due to some compatibility issues.
Currently, `julia >= 1.11` is required for compatibility.

To install `PDMPFlux`, open up a Julia-REPL, type `]` to get into Pkg-mode, and type:
To install the package, use Julia's package manager:

```julia-repl
(@v1.11) pkg> add PDMPFlux
```

which will install the package and all dependencies to your local environment.
## Usage

## Examples
### Basic

The simplest example may be the following:
The following example demonstrates how to sample from a standard Gaussian distribution using a Zig-Zag sampler.

```julia
using PDMPFlux
@@ -40,25 +57,30 @@ samples = sample(sampler, N_sk, N, xinit, vinit, seed=2024)
jointplot(samples)
```

To diagnose the sampler, you can manually break down the `sample()` function into two steps: `sample_skeleton()` and `sample_from_skeleton()`:
### Advanced

For more control, you can manually provide the gradient.
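
As a minimal sketch of what "manually providing the gradient" could look like, the block below writes out a hand-derived gradient for the banana potential used in the example scripts and checks it against central finite differences. The name `∇U_banana` matches the identifier used in the snippet, but the body here is an illustration derived from the potential, not code taken from the package; `fd_gradient` and `x0` are hypothetical helpers introduced only for the check.

```julia
# Banana potential, algebraically the same as the form used in the example scripts:
# U(x) = (x1^2 + (x2 - (x1^2 - 1))^2 + sum_{i>=3} x_i^2) / 2
function U_banana(x::Vector)
    mean_x2 = x[1]^2 - 1
    return (x[1]^2 + (x[2] - mean_x2)^2 + sum(x[3:end].^2)) / 2
end

# Hand-derived gradient of U_banana (illustrative, not the package's code).
function ∇U_banana(x::Vector)
    r = x[2] - (x[1]^2 - 1)        # residual of the second coordinate
    g = float.(x)                  # fresh array; already correct for i >= 3, where dU/dx_i = x_i
    g[1] = x[1] - 2 * x[1] * r     # chain rule through mean_x2
    g[2] = r
    return g
end

# Central finite differences, used only to sanity-check the formula above.
function fd_gradient(f, x::Vector{Float64}; h=1e-6)
    map(eachindex(x)) do i
        e = zeros(length(x))
        e[i] = h
        (f(x + e) - f(x - e)) / (2h)
    end
end

x0 = [0.3, -1.2, 0.7, 2.0]
max_err = maximum(abs.(∇U_banana(x0) - fd_gradient(U_banana, x0)))
```

If `max_err` is small, a function like this can be handed to the sampler in place of an automatically differentiated gradient, for example as the second argument of `ZigZag(dim, ∇U_banana, grid_size=grid_size)`.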

Also, by breaking down the `sample()` function into two steps: `sample_skeleton()` and `sample_from_skeleton()`, you can use `plot_traj()` and `diagnostic()` functions to diagnose the sampler:

```julia
using PDMPFlux
using Zygote

N_sk = 1_000_000 # number of skeleton points
N = 1_000_000 # number of samples

function U_banana(x::Vector{Float64})
function U_banana(x::Vector)
mean_x2 = (x[1]^2 - 1)
return -(- x[1]^2 + -(x[2] - mean_x2)^2 - sum(x[3:end].^2)) / 2
return - (- x[1] + -(x[2] - mean_x2) - sum(x[3:end])) # don't forget the minus sign!
end

dim = 50
xinit = ones(dim)
vinit = ones(dim)
grid_size = 0 # use constant bounds

sampler = ZigZag(dim, ∇U, grid_size=grid_size) # initialize your Zig-Zag sampler
sampler = ZigZag(dim, ∇U_banana, grid_size=grid_size) # manually providing the gradient
output = sample_skeleton(sampler, N_sk, xinit, vinit, verbose = true) # simulate skeleton points
samples = sample_from_skeleton(sampler, N, output) # get samples from the skeleton points

@@ -76,25 +98,25 @@ jointplot(samples)
<td style="width: 25%;"><img src="examples/Funnel/Funnel_GroundTruthSamples.svg"></td>
<td style="width: 25%;"><img src="examples/Funnel/ZigZag_Funnel2D_trajectory.svg"></td>
<td style="width: 25%;"><img src="examples/Funnel/ZigZag_Funnel2D.gif"></td>
<td style="width: 25%;"><img src="examples/Funnel/ZigZag_Funnel3D_2.gif"></td>
<td style="width: 25%;"><img src="examples/Funnel/ZigZag_Funnel3D.gif"></td>
</tr>
<tr>
<td align="center"><a href="examples/ZigZag_Funnel3D.jl"><sup>2D</sup> Funnel Distribution (Ground Truth)</a></td>
<td align="center"><a href="examples/ZigZag_Funnel3D.jl"><sup>2D</sup> Zig-Zag Trajectory (T<sub>max</sub>=10000)</a></td>
<td align="center"><a href="examples/ZigZag_Funnel3D.jl"><sup>2D</sup> Zig-Zag on Funnel</a></td>
<td align="center"><a href="examples/ZigZag_Funnel2D.jl"><sup>2D</sup> Zig-Zag on Funnel</a></td>
<td align="center"><a href="examples/ZigZag_Funnel3D.jl"><sup>3D</sup> Zig-Zag on Funnel</a></td>
</tr>
<tr>
<td style="width: 25%;"><img src="assets/banana_density.svg"></td>
<td style="width: 25%;"><img src="assets/banana_jointplot.svg"></td>
<td style="width: 25%;"><img src="assets/ZigZag_Banana2D_2.gif"></td>
<td style="width: 25%;"><img src="assets/ZigZag_Banana3D.gif"></td>
<td style="width: 25%;"><img src="examples/Banana/ZigZag_Banana2D.gif"></td>
<td style="width: 25%;"><img src="examples/Banana/ZigZag_Banana3D.gif"></td>
</tr>
<tr>
<td align="center"><a href="test/runtests.jl"><sup>2D</sup> Banana Density Contour (Ground Truth)</a></td>
<td align="center"><a href="test/runtests.jl"><sup>2D</sup> Zig-Zag Sample Jointplot</a></td>
<td align="center"><a href="test/runtests.jl"><sup>2D</sup> Zig-Zag on Banana</a></td>
<td align="center"><a href="test/runtests.jl"><sup>3D</sup> Zig-Zag on Banana</a></td>
<td align="center"><a href="examples/Banana/ZigZag_Banana2D.jl"><sup>2D</sup> Zig-Zag on Banana</a></td>
<td align="center"><a href="examples/Banana/ZigZag_Banana3D.jl"><sup>3D</sup> Zig-Zag on Banana</a></td>
</tr>
</tbody>
</table>
@@ -116,9 +138,9 @@ jointplot(samples)

## Remarks

- The implementation of `PDMPFlux.jl` is based on the paper [Andral and Kamatani (2024) Automated Techniques for Efficient Sampling of Piecewise-Deterministic Markov Processes](https://arxiv.org/abs/2408.03682) and its accompanying Python package [`pdmp_jax`](https://github.com/charlyandral/pdmp_jax).
- The automatic Poisson thinning implementation in `PDMPFlux.jl` is based on the paper [Andral and Kamatani (2024) Automated Techniques for Efficient Sampling of Piecewise-Deterministic Markov Processes](https://arxiv.org/abs/2408.03682) and its accompanying Python package [`pdmp_jax`](https://github.com/charlyandral/pdmp_jax).
- [`pdmp_jax`](https://github.com/charlyandral/pdmp_jax) has a [`jax`](https://github.com/jax-ml/jax)-based implementation and is typically about four times faster than the current `PDMPFlux.jl`.
- Automatic differentiation engines I've tried are [`ForwardDiff.jl`](https://github.com/JuliaDiff/ForwardDiff.jl) and [`Zygote.jl`](https://github.com/FluxML/Zygote.jl). Both have pros and cons. I am still learning the trade-offs.
- Both [`ForwardDiff.jl`](https://github.com/JuliaDiff/ForwardDiff.jl) and [`Zygote.jl`](https://github.com/FluxML/Zygote.jl) are used for automatic differentiation, each with their own trade-offs.
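
The Poisson thinning idea mentioned in the remarks above can be sketched in a few lines. This is a generic illustration under stated assumptions, not the package's implementation: `first_event_time` is a hypothetical helper, the rate function is made up, and a single constant `bound` is assumed to dominate the rate everywhere (the grid-based and automated bounds of Andral and Kamatani are not reproduced here).

```julia
using Random

# First event time of an inhomogeneous Poisson process with intensity `rate(t)`,
# obtained by thinning a homogeneous process of intensity `bound`.
# Assumes 0 <= rate(t) <= bound for all t >= 0.
function first_event_time(rng::AbstractRNG, rate::Function, bound::Float64)
    t = 0.0
    while true
        t += randexp(rng) / bound          # next proposal from the dominating process
        if rand(rng) * bound <= rate(t)    # accept with probability rate(t) / bound
            return t
        end
    end
end

rng = MersenneTwister(2024)
rate(t) = 1.0 + 0.5 * sin(t)               # illustrative rate, bounded above by 1.5
times = [first_event_time(rng, rate, 1.5) for _ in 1:10_000]
```

For a Zig-Zag sampler the rate along a trajectory is typically of the form $\max(0, v_i \partial_i U(x + t v))$, so the quality of the upper bound determines how many proposals are rejected.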

## References

@@ -128,7 +150,5 @@ jointplot(samples)
* [Revels, Lubin, and Papamarkou (2016) Forward-Mode Automatic Differentiation in Julia](https://arxiv.org/abs/1607.07892)
* [Innes et al. (2018) Fashionable Modelling with Flux](https://arxiv.org/abs/1811.01457)
* Other PDMP packages:
* Julia
* [`ZigZagBoomerang.jl`](https://github.com/mschauer/ZigZagBoomerang.jl) by [Moritz Schauer](https://github.com/mschauer)
* R
* [`rjpdmp`](https://github.com/matt-sutton/rjpdmp) by [Matthew Sutton](https://github.com/matt-sutton)
* Julia package [`ZigZagBoomerang.jl`](https://github.com/mschauer/ZigZagBoomerang.jl) by [Moritz Schauer](https://github.com/mschauer)
* R package [`rjpdmp`](https://github.com/matt-sutton/rjpdmp) by [Matthew Sutton](https://github.com/matt-sutton)
Binary file added assets/Cauchy1D_10.gif
Binary file added examples/Banana/ZigZag_Banana2D.gif
30 changes: 30 additions & 0 deletions examples/Banana/ZigZag_Banana2D.jl
@@ -0,0 +1,30 @@
using PDMPFlux

# using Test
using Zygote, Random, Plots

N_sk = 10_000 # number of skeleton points
N = 10_000 # number of samples

function runtest(N_sk::Int, N::Int, dim::Int=2)
function U_banana(x::Vector)
mean_x2 = (x[1]^2 - 1)
return -(- x[1]^2 + -(x[2] - mean_x2)^2) / 2
end

∇U(x::Vector) = gradient(U_banana, x)[1]
seed = 8
key = MersenneTwister(seed)
xinit = ones(dim)
vinit = ones(dim)
grid_size = 0 # constant bounds

sampler = ZigZag(dim, ∇U, grid_size=grid_size)
out = sample_skeleton(sampler, N_sk, xinit, vinit, seed=seed, verbose = true)
samples = sample_from_skeleton(sampler, N, out)

return out, samples
end

out, samples = runtest(N_sk, N)
anim_traj(out, 1000; T_start=100, plot_start=100, filename="ZigZag_Banana2D.gif")
Binary file added examples/Banana/ZigZag_Banana3D.gif
30 changes: 30 additions & 0 deletions examples/Banana/ZigZag_Banana3D.jl
@@ -0,0 +1,30 @@
using PDMPFlux

# using Test
using Zygote, Random, Plots

N_sk = 10_000 # number of skeleton points
N = 10_000 # number of samples

function runtest(N_sk::Int, N::Int, dim::Int=3)
function U_banana(x::Vector)
mean_x2 = (x[1]^2 - 1)
return -(- x[1]^2 + -(x[2] - mean_x2)^2 - sum(x[3:end].^2)) / 2
end

∇U(x::Vector) = gradient(U_banana, x)[1]
seed = 8
key = MersenneTwister(seed)
xinit = ones(dim)
vinit = ones(dim)
grid_size = 0 # constant bounds

sampler = ZigZag(dim, ∇U, grid_size=grid_size)
out = sample_skeleton(sampler, N_sk, xinit, vinit, seed=seed, verbose = true)
samples = sample_from_skeleton(sampler, N, out)

return out, samples
end

out, samples = runtest(N_sk, N)
anim_traj(out, 1000; T_start=100, filename="ZigZag_Banana3D.gif", plot_type="3D")
Binary file modified examples/Funnel/ZigZag_Funnel2D.gif
71 changes: 71 additions & 0 deletions examples/Funnel/ZigZag_Funnel2D.jl
@@ -0,0 +1,71 @@
using PDMPFlux

using Random, Distributions, Plots, LaTeXStrings, ForwardDiff, LinearAlgebra

"""
Funnel distribution for testing. Returns energy and sample functions.
For reference, see Neal, R. M. (2003). Slice sampling. The Annals of Statistics, 31(3), 705–767.
"""
function funnel(d::Int=10, σ::Float64=3.0, clip_y::Int=11)

function neg_energy(x::Vector)
v = x[1]
log_density_v = logpdf(Normal(0.0, 3.0), v)
variance_other = exp(v)
other_dim = d - 1
cov_other = I * variance_other
mean_other = zeros(other_dim)
log_density_other = logpdf(MvNormal(mean_other, cov_other), x[2:end])
return - log_density_v - log_density_other
end

function sample_data(n_samples::Int)
# sample from Nd funnel distribution
y = clamp.(σ * randn(n_samples, 1), -clip_y, clip_y)
x = randn(n_samples, d - 1) .* exp.(-y / 2)
return hcat(.- y, x)
end

return neg_energy, sample_data
end

function plot_funnel(d::Int=10, n_samples::Int=10000)
_, sample_data = funnel(d)
data = sample_data(n_samples)

# Extract the first two dimensions (y and x1)
y = data[:, 1]
x1 = data[:, 2]

# Draw the scatter plot
scatter(y, x1, alpha=0.5, markersize=1, xlabel=L"y", ylabel=L"x_1",
title="Funnel Distribution (First Two Dimensions' Ground Truth)", grid=true, legend=false, color="#78C2AD")

# Add xlim and ylim
xlims!(-8, 8) # set the x-axis range to -8..8
ylims!(-7, 7) # set the y-axis range to -7..7
end
plot_funnel()

function run_ZigZag_on_funnel(N_sk::Int=100_000, N::Int=100_000; d::Int=10)
U, _ = funnel(d)
∇U(x::Vector{Float64}) = ForwardDiff.gradient(U, x)
xinit = ones(d)
vinit = ones(d)
seed = 2024
grid_size = 0 # constant bounds
sampler = ZigZag(d, ∇U, grid_size=grid_size)
out = sample_skeleton(sampler, N_sk, xinit, vinit, seed=seed, verbose = true)
samples = sample_from_skeleton(sampler, N, out)
return out, samples
end
output, samples = run_ZigZag_on_funnel(d=2)

# jointplot(samples)
# plot_traj(output, 10000)
# plot_traj(output, 1000, plot_type="3D")

anim_traj(output, 1000; plot_start=100, filename="ZigZag_Funnel2D.gif")
# anim_traj(output, 1000; filename="ZigZag_Funnel2D.gif")

diagnostic(output)
Binary file modified examples/Funnel/ZigZag_Funnel3D.gif
20 changes: 10 additions & 10 deletions examples/Funnel/ZigZag_Funnel3D.jl
@@ -1,14 +1,14 @@
using PDMPFlux

using Random, Distributions, Plots, LaTeXStrings, Zygote, LinearAlgebra
using Random, Distributions, Plots, LaTeXStrings, LinearAlgebra, ForwardDiff

"""
Funnel distribution for testing. Returns energy and sample functions.
For reference, see Neal, R. M. (2003). Slice sampling. The Annals of Statistics, 31(3), 705–767.
"""
function funnel(d::Int=10, σ::Float64=3.0, clip_y::Int=11)

function neg_energy(x::Vector{Float64})
function neg_energy(x::Vector)
v = x[1]
log_density_v = logpdf(Normal(0.0, 3.0), v)
variance_other = exp(v)
@@ -47,9 +47,9 @@ function plot_funnel(d::Int=10, n_samples::Int=10000)
end
plot_funnel()

function run_ZigZag_on_funnel(N_sk::Int=100_000, N::Int=100_000, d::Int=10)
function run_ZigZag_on_funnel(N_sk::Int=100_000, N::Int=100_000; d::Int=10)
U, _ = funnel(d)
∇U(x::Vector{Float64}) = ForwardDiff.gradient(U, x)[1]
∇U(x::Vector{Float64}) = ForwardDiff.gradient(U, x)
xinit = ones(d)
vinit = ones(d)
seed = 2024
@@ -59,13 +59,13 @@ function run_ZigZag_on_funnel(N_sk::Int=100_000, N::Int=100_000, d::Int=10)
samples = sample_from_skeleton(sampler, N, out)
return out, samples
end
output, samples = run_ZigZag_on_funnel()
output, samples = run_ZigZag_on_funnel(d=3)

jointplot(samples)
plot_traj(output, 10000)
plot_traj(output, 1000, plot_type="3D")
# jointplot(samples)
# plot_traj(output, 10000)
# plot_traj(output, 1000, plot_type="3D")

anim_traj(output, 1000, plot_type="3D"; filename="ZigZag_Funnel3D_2.gif", dt=0.1)
anim_traj(output, 1000; filename="ZigZag_Funnel2D.gif")
anim_traj(output, 1000, plot_type="3D"; plot_start=100, filename="ZigZag_Funnel3D.gif")
# anim_traj(output, 1000; filename="ZigZag_Funnel2D.gif")

diagnostic(output)
8 changes: 4 additions & 4 deletions src/Composites.jl
@@ -31,7 +31,7 @@ end
horizon (Float64): horizon
key (Any): random key
integrator (Function): integrator function
grad_U (Function): gradient of the potential function
∇U (Function): gradient of the potential function
rate (Function): rate function
velocity_jump (Function): velocity jump function
upper_bound_func (Function): upper bound function
@@ -56,7 +56,7 @@ mutable struct PDMPState <: Any
horizon::Float64
key::AbstractRNG
integrator::Function
grad_U::Function
∇U::Function
rate::Function
velocity_jump::Function
upper_bound_func::Function
@@ -76,7 +76,7 @@ mutable struct PDMPState <: Any
adaptive::Bool
end

function PDMPState(x::Vector{Float64}, v::Vector{Float64}, t::Float64, horizon::Float64, key::AbstractRNG, integrator::Function, grad_U::Function, rate::Function, velocity_jump::Function, upper_bound_func::Function, upper_bound::Union{Nothing, BoundBox}, adaptive::Bool)
function PDMPState(x::Vector{Float64}, v::Vector{Float64}, t::Float64, horizon::Float64, key::AbstractRNG, integrator::Function, ∇U::Function, rate::Function, velocity_jump::Function, upper_bound_func::Function, upper_bound::Union{Nothing, BoundBox}, adaptive::Bool)
accept = false
indicator = false
tp = 0.0
@@ -89,7 +89,7 @@ function PDMPState(x::Vector{Float64}, v::Vector{Float64}, t::Float64, horizon::
error_value_ar = zeros(5)
rejected = 0
hitting_horizon = 0
return PDMPState(x, v, t, horizon, key, integrator, grad_U, rate, velocity_jump, upper_bound_func, accept, upper_bound, indicator, tp, ts, exp_rv, lambda_bar, lambda_t, ar, error_bound, error_value_ar, rejected, hitting_horizon, adaptive)
return PDMPState(x, v, t, horizon, key, integrator, ∇U, rate, velocity_jump, upper_bound_func, accept, upper_bound, indicator, tp, ts, exp_rv, lambda_bar, lambda_t, ar, error_bound, error_value_ar, rejected, hitting_horizon, adaptive)
end



2 comments on commit 8724cc4

@162348 (Owner, Author) commented on 8724cc4 Oct 23, 2024

@JuliaRegistrator register()

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/117869

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.1 -m "<description of version>" 8724cc4e05e7c1c9ff65bceb25261146d6e51714
git push origin v0.2.1

Also, note the warning: Version 0.2.1 skips over 0.2.0
This can be safely ignored. However, if you want to fix this you can do so. Call register() again after making the fix. This will update the Pull request.
