-
Notifications
You must be signed in to change notification settings - Fork 162
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #521 from sharlaon/product
Add product distribution combinator
- Loading branch information
Showing
5 changed files
with
206 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
######################################################################## | ||
# ProductDistribution: product of fixed distributions of similar types # | ||
######################################################################## | ||
|
||
""" | ||
ProductDistribution(distributions::Vararg{<:Distribution}) | ||
Define new distribution that is the product of the given nonempty list of distributions having a common type. | ||
The arguments comprise the list of base distributions. | ||
Example: | ||
```julia | ||
normal_strip = ProductDistribution(uniform, normal) | ||
``` | ||
The resulting product distribution takes `n` arguments, where `n` is the sum of the numbers of arguments taken by each distribution in the list. | ||
These arguments are the arguments to each component distribution, in the order in which the distributions are passed to the constructor. | ||
Example: | ||
```julia | ||
@gen function unit_strip_and_near_seven() | ||
x ~ flip_and_number(0.0, 0.1, 7.0, 0.01) | ||
end | ||
``` | ||
""" | ||
struct ProductDistribution{T, Ds} <: Distribution{T} | ||
K::Int | ||
distributions::Ds | ||
has_output_grad::Bool | ||
has_argument_grads::Tuple | ||
is_discrete::Bool | ||
num_args::Vector{Int} | ||
starting_args::Vector{Int} | ||
end | ||
|
||
(dist::ProductDistribution)(args...) = random(dist, args...) | ||
|
||
Gen.has_output_grad(dist::ProductDistribution) = dist.has_output_grad | ||
Gen.has_argument_grads(dist::ProductDistribution) = dist.has_argument_grads | ||
Gen.is_discrete(dist::ProductDistribution) = dist.is_discrete | ||
|
||
function ProductDistribution(distributions::Vararg{<:Distribution}) | ||
_has_output_grads = true | ||
_is_discrete = true | ||
|
||
types = Type[] | ||
|
||
_has_argument_grads = Bool[] | ||
_num_args = Int[] | ||
_starting_args = Int[] | ||
start_pos = 1 | ||
|
||
for dist in distributions | ||
push!(types, Gen.get_return_type(dist)) | ||
|
||
_has_output_grads = _has_output_grads && has_output_grad(dist) | ||
_is_discrete = _is_discrete && is_discrete(dist) | ||
|
||
grads_data = has_argument_grads(dist) | ||
append!(_has_argument_grads, grads_data) | ||
push!(_num_args, length(grads_data)) | ||
push!(_starting_args, start_pos) | ||
start_pos += length(grads_data) | ||
end | ||
|
||
return ProductDistribution{Tuple{types...}, typeof(distributions)}( | ||
length(distributions), | ||
distributions, | ||
_has_output_grads, | ||
Tuple(_has_argument_grads), | ||
_is_discrete, | ||
_num_args, | ||
_starting_args) | ||
end | ||
|
||
function extract_args_for_component(dist::ProductDistribution, component_args_flat, k::Int) | ||
start_arg = dist.starting_args[k] | ||
n = dist.num_args[k] | ||
return component_args_flat[start_arg:start_arg+n-1] | ||
end | ||
|
||
Gen.random(dist::ProductDistribution, args...) = | ||
Tuple(random(d, extract_args_for_component(dist, args, k)...) for (k, d) in enumerate(dist.distributions)) | ||
|
||
Gen.logpdf(dist::ProductDistribution, x, args...) = | ||
sum(Gen.logpdf(d, x[k], extract_args_for_component(dist, args, k)...) for (k, d) in enumerate(dist.distributions)) | ||
|
||
function Gen.logpdf_grad(dist::ProductDistribution, x, args...) | ||
x_grad = () | ||
arg_grads = () | ||
for (k, d) in enumerate(dist.distributions) | ||
grads = Gen.logpdf_grad(d, x[k], extract_args_for_component(dist, args, k)...) | ||
x_grad = (x_grad..., grads[1]) | ||
arg_grads = (arg_grads..., grads[2:end]...) | ||
end | ||
x_grad = dist.has_output_grad ? x_grad : nothing | ||
return (x_grad, arg_grads...) | ||
end | ||
|
||
export ProductDistribution |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,3 +8,4 @@ include("recurse.jl") | |
include("switch.jl") | ||
include("dist_dsl.jl") | ||
include("mixture.jl") | ||
include("product.jl") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
discrete_product = ProductDistribution(bernoulli, binom) | ||
|
||
@testset "product of discrete distributions" begin | ||
@test is_discrete(discrete_product) | ||
grad_bools = (has_output_grad(discrete_product), has_argument_grads(discrete_product)...) | ||
@test grad_bools == (false, true, false, true) | ||
|
||
p1 = 0.5 | ||
(n, p2) = (3, 0.9) | ||
|
||
# random | ||
x = discrete_product(p1, n, p2) | ||
@assert typeof(x) == Gen.get_return_type(discrete_product) == Tuple{Bool, Int} | ||
|
||
# logpdf | ||
x = (true, 2) | ||
actual = logpdf(discrete_product, x, p1, n, p2) | ||
expected = logpdf(bernoulli, x[1], p1) + logpdf(binom, x[2], n, p2) | ||
@test isapprox(actual, expected) | ||
|
||
# test logpdf_grad against finite differencing | ||
f = (x, p1, n, p2) -> logpdf(discrete_product, x, p1, n, p2) | ||
args = (x, p1, n, p2) | ||
actual = logpdf_grad(discrete_product, args...) | ||
for i in [2, 4] | ||
@test isapprox(actual[i], finite_diff(f, args, i, dx)) | ||
end | ||
end | ||
|
||
continuous_product = ProductDistribution(uniform, normal) | ||
|
||
@testset "product of continuous distributions" begin | ||
@test !is_discrete(continuous_product) | ||
grad_bools = (has_output_grad(continuous_product), has_argument_grads(continuous_product)...) | ||
@test grad_bools == (true, true, true, true, true) | ||
|
||
(low, high) = (-0.5, 0.5) | ||
(mu, std) = (0.0, 1.0) | ||
|
||
# random | ||
x = continuous_product(low, high, mu, std) | ||
@assert typeof(x) == Gen.get_return_type(continuous_product) == Tuple{Float64, Float64} | ||
|
||
# logpdf | ||
x = (0.1, 0.7) | ||
actual = logpdf(continuous_product, x, low, high, mu, std) | ||
expected = logpdf(uniform, x[1], low, high) + logpdf(normal, x[2], mu, std) | ||
@test isapprox(actual, expected) | ||
|
||
# test logpdf_grad against finite differencing | ||
f = (x, low, high, mu, std) -> logpdf(continuous_product, x, low, high, mu, std) | ||
# A mutable indexable is required by `finite_diff_vec`, hence the `collect` here: | ||
args = (collect(x), low, high, mu, std) | ||
actual = logpdf_grad(continuous_product, args...) | ||
@test isapprox(actual[1][1], finite_diff_vec(f, args, 1, 1, dx)) | ||
@test isapprox(actual[1][2], finite_diff_vec(f, args, 1, 2, dx)) | ||
for i in 2:5 | ||
@test isapprox(actual[i], finite_diff(f, args, i, dx)) | ||
end | ||
end | ||
|
||
dissimilar_product = ProductDistribution(bernoulli, normal) | ||
|
||
@testset "product of dissimilarly-typed distributions" begin | ||
@test !is_discrete(dissimilar_product) | ||
grad_bools = (has_output_grad(dissimilar_product), has_argument_grads(dissimilar_product)...) | ||
@test grad_bools == (false, true, true, true) | ||
|
||
p = 0.5 | ||
(mu, std) = (0.0, 1.0) | ||
|
||
# random | ||
x = dissimilar_product(p, mu, std) | ||
@assert typeof(x) == Gen.get_return_type(dissimilar_product) == Tuple{Bool, Float64} | ||
|
||
# logpdf | ||
x = (false, 0.3) | ||
actual = logpdf(dissimilar_product, x, p, mu, std) | ||
expected = logpdf(bernoulli, x[1], p) + logpdf(normal, x[2], mu, std) | ||
@test isapprox(actual, expected) | ||
|
||
# test logpdf_grad against finite differencing | ||
f = (x, p, mu, std) -> logpdf(dissimilar_product, x, p, mu, std) | ||
args = (x, p, mu, std) | ||
actual = logpdf_grad(dissimilar_product, args...) | ||
for i in 2:4 | ||
@test isapprox(actual[i], finite_diff(f, args, i, dx)) | ||
end | ||
end |