Skip to content

Commit

Permalink
Merge pull request #521 from sharlaon/product
Browse files Browse the repository at this point in the history
Add product distribution combinator
  • Loading branch information
ztangent authored Mar 28, 2024
2 parents 73d3790 + 6dc63a8 commit 18c06fd
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 2 deletions.
14 changes: 12 additions & 2 deletions docs/src/ref/distributions.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Probability Distributions

Gen provides a library of built-in probability distributions, and three ways of
Gen provides a library of built-in probability distributions, and four ways of
defining custom distributions, each of which are explained below:

1. The [`@dist` constructor](@ref dist_dsl), for a distribution that can be expressed as a
Expand All @@ -11,7 +11,10 @@ defining custom distributions, each of which are explained below:
2. The [`HeterogeneousMixture`](@ref) and [`HomogeneousMixture`](@ref) constructors
for distributions that are mixtures of other distributions.

3. An API for defining arbitrary [custom distributions](@ref
3. The [`ProductDistribution`](@ref) constructor for distributions that are products of
other distributions.

4. An API for defining arbitrary [custom distributions](@ref
custom_distributions) in plain Julia code.

## Built-In Distributions
Expand Down Expand Up @@ -220,6 +223,13 @@ HomogeneousMixture
HeterogeneousMixture
```

## Product Distribution Constructors

There is a built-in constructor for defining product distributions:
```@docs
ProductDistribution
```

## Defining New Distributions From Scratch

For distributions that cannot be expressed in the `@dist` DSL, users can define
Expand Down
3 changes: 3 additions & 0 deletions src/modeling_library/modeling_library.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ include("dist_dsl/dist_dsl.jl")
# mixtures of distributions
include("mixture.jl")

# products of distributions
include("product.jl")

###############
# combinators #
###############
Expand Down
101 changes: 101 additions & 0 deletions src/modeling_library/product.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
########################################################################
# ProductDistribution: product of fixed distributions of similar types #
########################################################################

"""
ProductDistribution(distributions::Vararg{<:Distribution})
Define new distribution that is the product of the given nonempty list of distributions having a common type.
The arguments comprise the list of base distributions.
Example:
```julia
normal_strip = ProductDistribution(uniform, normal)
```
The resulting product distribution takes `n` arguments, where `n` is the sum of the numbers of arguments taken by each distribution in the list.
These arguments are the arguments to each component distribution, in the order in which the distributions are passed to the constructor.
Example:
```julia
@gen function unit_strip_and_near_seven()
x ~ flip_and_number(0.0, 0.1, 7.0, 0.01)
end
```
"""
struct ProductDistribution{T, Ds} <: Distribution{T}
K::Int
distributions::Ds
has_output_grad::Bool
has_argument_grads::Tuple
is_discrete::Bool
num_args::Vector{Int}
starting_args::Vector{Int}
end

(dist::ProductDistribution)(args...) = random(dist, args...)

Gen.has_output_grad(dist::ProductDistribution) = dist.has_output_grad
Gen.has_argument_grads(dist::ProductDistribution) = dist.has_argument_grads
Gen.is_discrete(dist::ProductDistribution) = dist.is_discrete

function ProductDistribution(distributions::Vararg{<:Distribution})
_has_output_grads = true
_is_discrete = true

types = Type[]

_has_argument_grads = Bool[]
_num_args = Int[]
_starting_args = Int[]
start_pos = 1

for dist in distributions
push!(types, Gen.get_return_type(dist))

_has_output_grads = _has_output_grads && has_output_grad(dist)
_is_discrete = _is_discrete && is_discrete(dist)

grads_data = has_argument_grads(dist)
append!(_has_argument_grads, grads_data)
push!(_num_args, length(grads_data))
push!(_starting_args, start_pos)
start_pos += length(grads_data)
end

return ProductDistribution{Tuple{types...}, typeof(distributions)}(
length(distributions),
distributions,
_has_output_grads,
Tuple(_has_argument_grads),
_is_discrete,
_num_args,
_starting_args)
end

function extract_args_for_component(dist::ProductDistribution, component_args_flat, k::Int)
start_arg = dist.starting_args[k]
n = dist.num_args[k]
return component_args_flat[start_arg:start_arg+n-1]
end

Gen.random(dist::ProductDistribution, args...) =
Tuple(random(d, extract_args_for_component(dist, args, k)...) for (k, d) in enumerate(dist.distributions))

Gen.logpdf(dist::ProductDistribution, x, args...) =
sum(Gen.logpdf(d, x[k], extract_args_for_component(dist, args, k)...) for (k, d) in enumerate(dist.distributions))

function Gen.logpdf_grad(dist::ProductDistribution, x, args...)
x_grad = ()
arg_grads = ()
for (k, d) in enumerate(dist.distributions)
grads = Gen.logpdf_grad(d, x[k], extract_args_for_component(dist, args, k)...)
x_grad = (x_grad..., grads[1])
arg_grads = (arg_grads..., grads[2:end]...)
end
x_grad = dist.has_output_grad ? x_grad : nothing
return (x_grad, arg_grads...)
end

export ProductDistribution
1 change: 1 addition & 0 deletions test/modeling_library/modeling_library.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ include("recurse.jl")
include("switch.jl")
include("dist_dsl.jl")
include("mixture.jl")
include("product.jl")
89 changes: 89 additions & 0 deletions test/modeling_library/product.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
discrete_product = ProductDistribution(bernoulli, binom)

@testset "product of discrete distributions" begin
@test is_discrete(discrete_product)
grad_bools = (has_output_grad(discrete_product), has_argument_grads(discrete_product)...)
@test grad_bools == (false, true, false, true)

p1 = 0.5
(n, p2) = (3, 0.9)

# random
x = discrete_product(p1, n, p2)
@assert typeof(x) == Gen.get_return_type(discrete_product) == Tuple{Bool, Int}

# logpdf
x = (true, 2)
actual = logpdf(discrete_product, x, p1, n, p2)
expected = logpdf(bernoulli, x[1], p1) + logpdf(binom, x[2], n, p2)
@test isapprox(actual, expected)

# test logpdf_grad against finite differencing
f = (x, p1, n, p2) -> logpdf(discrete_product, x, p1, n, p2)
args = (x, p1, n, p2)
actual = logpdf_grad(discrete_product, args...)
for i in [2, 4]
@test isapprox(actual[i], finite_diff(f, args, i, dx))
end
end

continuous_product = ProductDistribution(uniform, normal)

@testset "product of continuous distributions" begin
@test !is_discrete(continuous_product)
grad_bools = (has_output_grad(continuous_product), has_argument_grads(continuous_product)...)
@test grad_bools == (true, true, true, true, true)

(low, high) = (-0.5, 0.5)
(mu, std) = (0.0, 1.0)

# random
x = continuous_product(low, high, mu, std)
@assert typeof(x) == Gen.get_return_type(continuous_product) == Tuple{Float64, Float64}

# logpdf
x = (0.1, 0.7)
actual = logpdf(continuous_product, x, low, high, mu, std)
expected = logpdf(uniform, x[1], low, high) + logpdf(normal, x[2], mu, std)
@test isapprox(actual, expected)

# test logpdf_grad against finite differencing
f = (x, low, high, mu, std) -> logpdf(continuous_product, x, low, high, mu, std)
# A mutable indexable is required by `finite_diff_vec`, hence the `collect` here:
args = (collect(x), low, high, mu, std)
actual = logpdf_grad(continuous_product, args...)
@test isapprox(actual[1][1], finite_diff_vec(f, args, 1, 1, dx))
@test isapprox(actual[1][2], finite_diff_vec(f, args, 1, 2, dx))
for i in 2:5
@test isapprox(actual[i], finite_diff(f, args, i, dx))
end
end

dissimilar_product = ProductDistribution(bernoulli, normal)

@testset "product of dissimilarly-typed distributions" begin
@test !is_discrete(dissimilar_product)
grad_bools = (has_output_grad(dissimilar_product), has_argument_grads(dissimilar_product)...)
@test grad_bools == (false, true, true, true)

p = 0.5
(mu, std) = (0.0, 1.0)

# random
x = dissimilar_product(p, mu, std)
@assert typeof(x) == Gen.get_return_type(dissimilar_product) == Tuple{Bool, Float64}

# logpdf
x = (false, 0.3)
actual = logpdf(dissimilar_product, x, p, mu, std)
expected = logpdf(bernoulli, x[1], p) + logpdf(normal, x[2], mu, std)
@test isapprox(actual, expected)

# test logpdf_grad against finite differencing
f = (x, p, mu, std) -> logpdf(dissimilar_product, x, p, mu, std)
args = (x, p, mu, std)
actual = logpdf_grad(dissimilar_product, args...)
for i in 2:4
@test isapprox(actual[i], finite_diff(f, args, i, dx))
end
end

0 comments on commit 18c06fd

Please sign in to comment.