Skip to content

Commit

Permalink
added more examples
Browse files Browse the repository at this point in the history
  • Loading branch information
162348 committed Oct 20, 2024
1 parent 6aaf381 commit aa26d58
Show file tree
Hide file tree
Showing 21 changed files with 11,725 additions and 32 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/codecov.yml

This file was deleted.

57 changes: 47 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

| Documentation | Workflows | Code Coverage | Quality Assurance |
|:-------------:|:---------:|:-------------:|:-----------------:|
| [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://162348.github.io/PDMPFlux.jl/stable/) | [![Build Status](https://github.com/162348/PDMPFlux.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/162348/PDMPFlux.jl/actions/workflows/CI.yml?query=branch%3Amain) | [![Coverage](https://codecov.io/gh/162348/PDMPFlux.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/162348/PDMPFlux.jl) | [![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) |
| [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://162348.github.io/PDMPFlux.jl/stable/) [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://162348.github.io/PDMPFlux.jl/dev/) | [![Build Status](https://github.com/162348/PDMPFlux.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/162348/PDMPFlux.jl/actions/workflows/CI.yml?query=branch%3Amain) | [![Coverage](https://codecov.io/gh/162348/PDMPFlux.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/162348/PDMPFlux.jl) | [![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) |

This repository contains a [`Zygote.jl`](https://github.com/FluxML/Zygote.jl) implementation of the PDMP samplers.

Expand Down Expand Up @@ -49,19 +49,56 @@ diagnostic(output)

## Gallery

![](assets/banana_density.svg)

![](assets/banana_jointplot.svg)

![](assets/Cauchy1D.gif)
<table>
<tbody>
<tr>
<td style="width: 25%;"><img src="examples/Funnel/Funnel_GroundTruthSamples.svg"></td>
<td style="width: 25%;"><img src="examples/Funnel/ZigZag_Funnel2D_trajectory.svg"></td>
<td style="width: 25%;"><img src="examples/Funnel/ZigZag_Funnel2D.gif"></td>
<td style="width: 25%;"><img src="examples/Funnel/ZigZag_Funnel3D_2.gif"></td>
</tr>
<tr>
<td align="center"><a href="examples/ZigZag_Funnel3D.jl"><sup>2D</sup> Funnel Distribution (Ground Truth)</a></td>
<td align="center"><a href="examples/ZigZag_Funnel3D.jl"><sup>2D</sup> Zig-Zag Trajectory (T<sub>max</sub>=10000)</a></td>
<td align="center"><a href="examples/ZigZag_Funnel3D.jl"><sup>2D</sup> Zig-Zag on Funnel</a></td>
<td align="center"><a href="examples/ZigZag_Funnel3D.jl"><sup>3D</sup> Zig-Zag on Funnel</a></td>
</tr>
<tr>
<td style="width: 25%;"><img src="assets/banana_density.svg"></td>
<td style="width: 25%;"><img src="assets/banana_jointplot.svg"></td>
<td style="width: 25%;"><img src="assets/ZigZag_Banana2D_2.gif"></td>
<td style="width: 25%;"><img src="assets/ZigZag_Banana3D.gif"></td>
</tr>
<tr>
<td align="center"><a href="test/runtests.jl"><sup>2D</sup> Banana Density Contour (Ground Truth)</a></td>
<td align="center"><a href="test/runtests.jl"><sup>2D</sup> Zig-Zag Sample Jointplot</a></td>
<td align="center"><a href="test/runtests.jl"><sup>2D</sup> Zig-Zag on Banana</a></td>
<td align="center"><a href="test/runtests.jl"><sup>3D</sup> Zig-Zag on Banana</a></td>
</tr>
</tbody>
</table>

<table>
<tbody>
<tr>
<td style="width: 50%;"><img src="assets/Cauchy1D.gif"></td>
<td style="width: 50%;"><img src="assets/Gauss1D.gif"></td>
</tr>
<tr>
<td align="center"><a href="test/1d_test.jl"><sup>1D</sup> Zig-Zag on Cauchy</a></td>
<td align="center"><a href="test/1d_test.jl"><sup>1D</sup> Zig-Zag on Gaussian</a></td>
</tr>
</tbody>
</table>

## Remarks

- The implementation of the PDMP samplers is based on the paper [Andral and Kamatani (2024) Automated Techniques for Efficient Sampling of Piecewise-Deterministic Markov Processes](https://arxiv.org/abs/2408.03682) and its implementation in [`pdmp_jax`](https://github.com/charlyandral/pdmp_jax).
- `pdmp_jax` has
- `pdmp_jax` has a `jax` based implementation, and typically about four times faster than current `PDMPFlux.jl`.

## References

* [`pdmp_jax.jl`](https://github.com/charlyandral/pdmp_jax): This repository is based on this repository.
* [Andral and Kamatani (2024) Automated Techniques for Efficient Sampling of Piecewise-Deterministic Markov Processes](https://arxiv.org/abs/2408.03682)
* [`Zygote.jl`](https://github.com/FluxML/Zygote.jl) is used for automatic differentiation.
* [`pdmp_jax.jl`](https://github.com/charlyandral/pdmp_jax) by [Charly Andral](https://github.com/charlyandral): This repository is based on this repository.
* [Andral and Kamatani (2024) Automated Techniques for Efficient Sampling of Piecewise-Deterministic Markov Processes](https://arxiv.org/abs/2408.03682)
* [`ForwardDiff.jl`](https://github.com/JuliaDiff/ForwardDiff.jl) is used for automatic differentiation.
* [Revels, Lubin, and Papamarkou (2016) Forward-Mode Automatic Differentiation in Julia](https://arxiv.org/abs/1607.07892)
File renamed without changes
Binary file added assets/ZigZag_Banana2D_1000.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/ZigZag_Banana2D_2.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/ZigZag_Banana3D.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed assets/test1.gif
Binary file not shown.
Binary file removed assets/test2.gif
Binary file not shown.
9,233 changes: 9,233 additions & 0 deletions examples/Funnel/Funnel_GroundTruthSamples.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/Funnel/ZigZag_Funnel2D.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
43 changes: 43 additions & 0 deletions examples/Funnel/ZigZag_Funnel2D_trajectory.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/Funnel/ZigZag_Funnel3D.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
71 changes: 71 additions & 0 deletions examples/Funnel/ZigZag_Funnel3D.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
using PDMPFlux

using Random, Distributions, Plots, LaTeXStrings, Zygote, LinearAlgebra

"""
Funnel distribution for testing. Returns energy and sample functions.
For reference, see Neal, R. M. (2003). Slice sampling. The Annals of Statistics, 31(3), 705–767.
"""
function funnel(d::Int=10, σ::Float64=3.0, clip_y::Int=11)

function neg_energy(x::Vector{Float64})
v = x[1]
log_density_v = logpdf(Normal(0.0, 3.0), v)
variance_other = exp(v)
other_dim = d - 1
cov_other = I * variance_other
mean_other = zeros(other_dim)
log_density_other = logpdf(MvNormal(mean_other, cov_other), x[2:end])
return - log_density_v - log_density_other
end

function sample_data(n_samples::Int)
# sample from Nd funnel distribution
y = clamp.(σ * randn(n_samples, 1), -clip_y, clip_y)
x = randn(n_samples, d - 1) .* exp.(-y / 2)
return hcat(.- y, x)
end

return neg_energy, sample_data
end

function plot_funnel(d::Int=10, n_samples::Int=10000)
_, sample_data = funnel(d)
data = sample_data(n_samples)

# 最初の2次元を抽出(yとx1)
y = data[:, 1]
x1 = data[:, 2]

# 散布図をプロット
scatter(y, x1, alpha=0.5, markersize=1, xlabel=L"y", ylabel=L"x_1",
title="Funnel Distribution (First Two Dimensions' Ground Truth)", grid=true, legend=false, color="#78C2AD")

# xlim と ylim を追加
xlims!(-8, 8) # x軸の範囲を -8 から 8 に設定
ylims!(-7, 7) # y軸の範囲を -7 から 7 に設定
end
plot_funnel()

function run_ZigZag_on_funnel(N_sk::Int=100_000, N::Int=100_000, d::Int=10)
U, _ = funnel(d)
grad_U(x::Vector{Float64}) = gradient(U, x)[1]
xinit = ones(d)
vinit = ones(d)
seed = 2024
grid_size = 0 # constant bounds
sampler = ZigZag(d, grad_U, grid_size=grid_size)
out = sample_skeleton(sampler, N_sk, xinit, vinit, seed=seed, verbose = true)
samples = sample_from_skeleton(sampler, N, out)
return out, samples
end
output, samples = run_ZigZag_on_funnel()

jointplot(samples)
plot_traj(output, 10000)
plot_traj(output, 1000, plot_type="3D")

anim_traj(output, 1000, plot_type="3D"; filename="ZigZag_Funnel3D_2.gif", dt=0.1)
anim_traj(output, 1000; filename="ZigZag_Funnel2D.gif")

diagnostic(output)
Binary file added examples/Funnel/ZigZag_Funnel3D_2.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/Funnel/ZigZag_Funnel3D_Dynamic.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2,273 changes: 2,273 additions & 0 deletions examples/Funnel/diagnostic.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 0 additions & 1 deletion examples/ZigZag_Funnel3D.jl

This file was deleted.

37 changes: 35 additions & 2 deletions src/UpperBound.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,41 @@ end

## TODO: Implement upper_bound_grid and upper_bound_grid_vect functions

function upper_bound_grid(func, start, horizon, grid_size, refresh_rate::Union{Float64,Int} = 0.0)
throw(NotImplementedError("upper_bound_grid is not implemented yet."))
"""
upper_bound_grid(func, start, horizon, n_grid, refresh_rate)
Compute the upper bound as a piecewise constant function using a grid mechanism.
Args:
func: the function for which the upper bound is computed
a (Float64): the lower bound of the interval
b (Float64): the upper bound of the interval
n_grid (Int64, optional): size of the grid for the upperbound of func. Defaults to 100.
refresh_rate (Float64, optional): refresh rate for the upper bound. Defaults to 0.
Returns:
BoundBox: An object containing the upper bound constant information.
"""
function upper_bound_grid(func::Function, start::Float64, horizon::Float64, n_grid::Int=100, refresh_rate::Float64 = 0.0)
t = range(start, stop=horizon, length=n_grid)
step_size = t[2] - t[1]

values = [func(x) for x in t]
grads = [ForwardDiff.derivative(func, x) for x in t]

intersection_pos = (values[1:end-1] .- values[2:end] .+ grads[2:end] .* step_size) ./ (grads[2:end] .- grads[1:end-1])
intersection_pos = replace(intersection_pos, NaN => 0.0)
intersection_pos = clamp.(intersection_pos, 0.0, step_size)

intersection = values[1:end-1] .+ grads[1:end-1] .* intersection_pos
box_max = max.(values[1:end-1], values[2:end])
box_max = max.(box_max, intersection)
box_max = max.(box_max, 0.0)
box_max .+= refresh_rate

cum_sum = zeros(Float64, n_grid)
cum_sum[2:end] .= cumsum(box_max) .* step_size

return BoundBox(collect(t), box_max, cum_sum, step_size)
end

function upper_bound_grid_vect(func, start, horizon, grid_size::Union{Float64,Int} = 10)
Expand Down
13 changes: 6 additions & 7 deletions src/diagnostic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ function anim_traj(history::PDMPHistory, T_max::Int; filename::Union{String, Not
background=background,
linewidth=linewidth
)
args = dynamic_range ? args[3:end] : args
args = dynamic_range ? (; args..., xlims=nothing, ylims=nothing) : args
traj = traj[1,:] # Vector に変換しないと @animate に掛かる時間が 10 倍くらいになる
times = collect(Float64, 1:length(traj)) # なぜか Float64 にしないと @animate 内の push! エラー oundsError: attempt to access 2-element Vector{Plots.Series} at index [3] が出る
p = plot(times[1:1], traj[1:1]; args...)
Expand Down Expand Up @@ -81,7 +81,7 @@ function anim_traj(history::PDMPHistory, T_max::Int; filename::Union{String, Not
background=background,
linewidth=linewidth
)
args = dynamic_range ? args[3:end] : args
args = dynamic_range ? (; args..., xlims=nothing, ylims=nothing) : args
traj_x = traj[1,:]
traj_y = traj[2,:]
p = plot(traj_x[1:1], traj_y[1:1]; args...)
Expand Down Expand Up @@ -110,7 +110,7 @@ function anim_traj(history::PDMPHistory, T_max::Int; filename::Union{String, Not
background=background,
linewidth=linewidth
)
args = dynamic_range ? args[3:end] : args
args = dynamic_range ? (; args..., xlims=nothing, ylims=nothing, zlims=nothing) : args
traj_x = traj[1,:]
traj_y = traj[2,:]
traj_z = traj[3,:]
Expand All @@ -127,10 +127,9 @@ function anim_traj(history::PDMPHistory, T_max::Int; filename::Union{String, Not
end
end

if filename !== nothing
filename = endswith(filename, ".gif") ? filename : filename * ".gif"
gif(anim, filename, fps=fps)
end
filename = isnothing(filename) ? "PDMPFlux_Animation.gif" : filename
filename = endswith(filename, ".gif") ? filename : filename * ".gif"
gif(anim, filename, fps=fps)

return anim
end
Expand Down
13 changes: 12 additions & 1 deletion test/1d_test.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using PDMPFlux

using Zygote, Random
using Zygote, Random, Plots, LaTeXStrings

function U_Gauss(x::Union{Float64, Int})
x = Float64(x)
Expand All @@ -12,6 +12,17 @@ function U_Cauchy(x::Union{Float64, Int})
return log(1 + x^2)
end

function plot_densities()
# プロット範囲の設定
x_values = -10:0.1:10

# U_Gauss と U_Cauchy の値を計算
y_gauss = [exp(-U_Gauss(x)) for x in x_values]
y_cauchy = [exp(-U_Cauchy(x)) for x in x_values]
plot(x_values, y_gauss, label="Gaussian density", xlabel=L"x", ylabel=L"p(x)", title="Gaussian vs Cauchy", color="#78C2AD")
plot!(x_values, y_cauchy, label="Cauchy density", color="#E95420")
end

dim = 1
seed = 8
key = MersenneTwister(seed)
Expand Down
14 changes: 5 additions & 9 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,15 @@ function ground_truth()
contourf(x_range, y_range, z, xlabel="x2", ylabel="x1", title="Banana Density Contour", color=:summer)
end

ground_truth()
# ground_truth()
out, samples = runtest(N_sk, N)
jointplot(samples, coordinate_numbers=[2,1])

anim_traj(out, 10000; filename="ZigZag_Banana2D.gif", dt=0.1)
anim_traj(out, 10000; filename="ZigZag_Banana3D.gif", dt=0.1, plot_type="3D")
plot_traj(out, 10000)
diagnostic(out)
jointplot(samples, coordinate_numbers=[2,3])
# diagnostic(out)
# jointplot(samples, coordinate_numbers=[2,3])

# @testset "PDMPFlux.jl" begin
# @test sampler.dim == dim
# @test sampler.grad_U == grad_U
# @test sampler.grid_size == grid_size
# end



Expand Down

0 comments on commit aa26d58

Please sign in to comment.