From 4cc2eb95023178d4a1639d831d845385b80bf5ca Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Mon, 4 Mar 2024 11:46:30 -0500 Subject: [PATCH 01/18] Add product distribution combinator --- src/modeling_library/modeling_library.jl | 3 + src/modeling_library/product.jl | 99 ++++++++++++++++++++++++ 2 files changed, 102 insertions(+) create mode 100644 src/modeling_library/product.jl diff --git a/src/modeling_library/modeling_library.jl b/src/modeling_library/modeling_library.jl index 13d6e488..7e2a782f 100644 --- a/src/modeling_library/modeling_library.jl +++ b/src/modeling_library/modeling_library.jl @@ -62,6 +62,9 @@ include("dist_dsl/dist_dsl.jl") # mixtures of distributions include("mixture.jl") +# products of distributions +include("product.jl") + ############### # combinators # ############### diff --git a/src/modeling_library/product.jl b/src/modeling_library/product.jl new file mode 100644 index 00000000..05249e9d --- /dev/null +++ b/src/modeling_library/product.jl @@ -0,0 +1,99 @@ +######################################################################## +# ProductDistribution: product of fixed distributions of similar types # +######################################################################## + +""" +ProductDistribution(distributions::Vararg{<:Distribution}) + +Define new distribution that is the product of the given nonempty list of distributions having a common type. + +The arguments comprise the list of base distributions. + +Example: +```julia +normal_strip = ProductDistribution(uniform, normal) +``` + +The resulting product distribution takes `n` arguments, where `n` is the sum of the numbers of arguments taken by each distribution in the list. +These arguments are the arguments to each component distribution, in the order in which the distributions are passed to the constructor. + +Example: +```julia +@gen function unit_strip_and_near_seven() + x ~ flip_and_number(0.0, 0.1, 7.0, 0.01) +end +``` +""" +struct ProductDistribution{T} <: Distribution{T} + K::Int + distributions::Vector{<:Distribution} + has_output_grad::Bool + has_argument_grads::Tuple + is_discrete::Bool + num_args::Vector{Int} + starting_args::Vector{Int} +end + +(dist::ProductDistribution)(args...) = random(dist, args...) + +Gen.has_output_grad(dist::ProductDistribution) = dist.has_output_grad +Gen.has_argument_grads(dist::ProductDistribution) = dist.has_argument_grads +Gen.is_discrete(dist::ProductDistribution) = dist.is_discrete + +function ProductDistribution(distributions::Vararg{<:Distribution}) + types = Type[] + _has_argument_grads = Bool[] + _num_args = Int[] + _starting_args = Int[] + start_pos = 1 + + for dist in distributions + type = typeof(dist) + while supertype(type) != Any + type = supertype(type) + end + push!(types, type.parameters[1]) + + grads_data = has_argument_grads(dist) + append!(_has_argument_grads, grads_data) + push!(_num_args, length(grads_data)) + push!(_starting_args, start_pos) + start_pos += length(grads_data) + end + + return ProductDistribution{Tuple{types...}}( + length(distributions), + collect(distributions), + all(has_output_grad(dist) for dist in distributions), + Tuple(_has_argument_grads), + all(is_discrete(dist) for dist in distributions), + _num_args, + _starting_args) +end + +function Gen.random(dist::ProductDistribution, factor_args_flat...) + factor_args = [factor_args_flat[dist.starting_args[i]:dist.starting_args[i]+dist.num_args[i]-1] for i in 1:dist.K] + return [random(dist.distributions[i], factor_args[i]...) for i in 1:dist.K] +end + +function Gen.logpdf(dist::ProductDistribution, x, factor_args_flat...) + factor_args = [factor_args_flat[dist.starting_args[i]:dist.starting_args[i]+dist.num_args[i]-1] for i in 1:dist.K] + return sum(Gen.logpdf(dist.distributions[i], x[i], factor_args[i]...) for i in 1:dist.K) +end + +function Gen.logpdf_grad(dist::ProductDistribution, x, factor_args_flat...) + factor_args = [factor_args_flat[dist.starting_args[i]:(dist.starting_args[i]+dist.num_args[i]-1)] for i in 1:dist.K] + logpdf_grads = [Gen.logpdf_grad(dist.distributions[i], x[i], factor_args[i]...) for i in 1:dist.K] + + x_grad = if dist.has_output_grad + [grads[1] for grads in logpdf_grads] + else + nothing + end + + arg_grads = vcat((collect(grads[2:end]) for grads in logpdf_grads)...) + + return (x_grad, arg_grads...) +end + +export ProductDistribution From 10016813e6df6c80eb60e1c4a4278a9a55965c93 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Mon, 4 Mar 2024 11:49:16 -0500 Subject: [PATCH 02/18] Add docs --- docs/src/ref/distributions.md | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/docs/src/ref/distributions.md b/docs/src/ref/distributions.md index 134c84e8..f36a56dd 100644 --- a/docs/src/ref/distributions.md +++ b/docs/src/ref/distributions.md @@ -1,6 +1,6 @@ # Probability Distributions -Gen provides a library of built-in probability distributions, and three ways of +Gen provides a library of built-in probability distributions, and four ways of defining custom distributions, each of which are explained below: 1. The [`@dist` constructor](@ref dist_dsl), for a distribution that can be expressed as a @@ -11,7 +11,10 @@ defining custom distributions, each of which are explained below: 2. The [`HeterogeneousMixture`](@ref) and [`HomogeneousMixture`](@ref) constructors for distributions that are mixtures of other distributions. -3. An API for defining arbitrary [custom distributions](@ref +3. The [`ProductDistribution`](@ref) constructor for distributions that are products of + other distributions. + +4. An API for defining arbitrary [custom distributions](@ref custom_distributions) in plain Julia code. ## Built-In Distributions @@ -219,6 +222,13 @@ HomogeneousMixture HeterogeneousMixture ``` +## Product Distribution Constructors + +There is a built-in constructor for defining product distributions: +```@docs +ProductDistribution +``` + ## Defining New Distributions From Scratch For distributions that cannot be expressed in the `@dist` DSL, users can define From 1cb51d0e42a283186a811bf429246ef0a0d3a3ed Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Mon, 4 Mar 2024 11:49:42 -0500 Subject: [PATCH 03/18] Add tests --- test/modeling_library/modeling_library.jl | 1 + test/modeling_library/product.jl | 89 +++++++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 test/modeling_library/product.jl diff --git a/test/modeling_library/modeling_library.jl b/test/modeling_library/modeling_library.jl index e5e8d95a..7facd8c1 100644 --- a/test/modeling_library/modeling_library.jl +++ b/test/modeling_library/modeling_library.jl @@ -8,3 +8,4 @@ include("recurse.jl") include("switch.jl") include("dist_dsl.jl") include("mixture.jl") +include("product.jl") diff --git a/test/modeling_library/product.jl b/test/modeling_library/product.jl new file mode 100644 index 00000000..be82aebd --- /dev/null +++ b/test/modeling_library/product.jl @@ -0,0 +1,89 @@ +discrete_product = ProductDistribution(bernoulli, binom) + +@testset "product of discrete distributions" begin + @test is_discrete(discrete_product) + grad_bools = (has_output_grad(discrete_product), has_argument_grads(discrete_product)...) + @test grad_bools == (false, true, false, true) + + p1 = 0.5 + (n, p2) = (3, 0.9) + + # random + x = unrelated_pair_product(p1, n, p2) + + # logpdf + x = (true, 2) + actual = logpdf(unrelated_pair_product, x, p1, n, p2) + expected = logpdf(bernoulli, x[1], p1) + logpdf(binom, x[2], n, p2) + @test isapprox(actual, expected) + + # test logpdf_grad against finite differencing + f = (x, p, mu, std) -> logpdf(unrelated_pair_product, x, p, mu, std) + args = (x, p, mu, std) + actual = logpdf_grad(unrelated_pair_product, args...) + for (i, b) in enumerate(grad_bools) + if b + @test isapprox(actual[i], finite_diff(f, args, i, dx)) + end + end +end + +continuous_product = ProductDistribution(uniform, normal) + +@testset "product of continuous distributions" begin + @test !is_discrete(continuous_product) + grad_bools = (has_output_grad(continuous_product), has_argument_grads(continuous_product)...) + @test grad_bools == (true, true, true, true, true) + + (low, high) = (-0.5, 0.5) + (mu, std) = (0.0, 1.0) + + # random + x = continuous_product(low, high, mu, std) + + # logpdf + x = (0.1, 0.7) + actual = logpdf(continuous_product, x, low, high, mu, std) + expected = logpdf(uniform, x[1], low, high) + logpdf(normal, x[2], mu, std) + @test isapprox(actual, expected) + + # test logpdf_grad against finite differencing + f = (x, low, high, mu, std) -> logpdf(continuous_product, x, low, high, mu, std) + args = (x, low, high, mu, std) + actual = logpdf_grad(continuous_product, args...) + for (i, b) in enumerate(grad_bools) + if b + @test isapprox(actual[i], finite_diff(f, args, i, dx)) + end + end +end + +dissimilar_product = ProductDistribution(bernoulli, normal) + +@testset "product of dissimilarly-typed distributions" begin + @test !is_discrete(dissimilar_product) + grad_bools = (has_output_grad(dissimilar_product), has_argument_grads(dissimilar_product)...) + @test grad_bools == (false, true, true, true) + + p = 0.5 + (mu, std) = (0.0, 1.0) + + # random + x = dissimilar_product(p, mu, std) + + # logpdf + x = (false, 0.3) + actual = logpdf(dissimilar_product, x, p, mu, std) + expected = logpdf(bernoulli, x[1], p) + logpdf(normal, x[2], mu, std) + @test isapprox(actual, expected) + + # test logpdf_grad against finite differencing + f = (x, p, mu, std) -> logpdf(dissimilar_product, x, p, mu, std) + args = (x, p, mu, std) + actual = logpdf_grad(dissimilar_product, args...) + for (i, b) in enumerate(grad_bools) + if b + @test isapprox(actual[i], finite_diff(f, args, i, dx)) + end + end +end From d0eac110a6ff5ece34cbf4f2cd06be1a27a9f089 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Mon, 4 Mar 2024 12:13:27 -0500 Subject: [PATCH 04/18] Fix `logpdf_grad` output component type --- src/modeling_library/product.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modeling_library/product.jl b/src/modeling_library/product.jl index 05249e9d..58440e31 100644 --- a/src/modeling_library/product.jl +++ b/src/modeling_library/product.jl @@ -86,7 +86,7 @@ function Gen.logpdf_grad(dist::ProductDistribution, x, factor_args_flat...) logpdf_grads = [Gen.logpdf_grad(dist.distributions[i], x[i], factor_args[i]...) for i in 1:dist.K] x_grad = if dist.has_output_grad - [grads[1] for grads in logpdf_grads] + tuple((grads[1] for grads in logpdf_grads)...) else nothing end From 7169f744f77bf1bbd4b9b77d94078da27b5baea1 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Mon, 4 Mar 2024 15:26:21 -0500 Subject: [PATCH 05/18] Refactor `has_output_grad` and `is_discrete --- src/modeling_library/product.jl | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/modeling_library/product.jl b/src/modeling_library/product.jl index 58440e31..c7f8a898 100644 --- a/src/modeling_library/product.jl +++ b/src/modeling_library/product.jl @@ -41,7 +41,11 @@ Gen.has_argument_grads(dist::ProductDistribution) = dist.has_argument_grads Gen.is_discrete(dist::ProductDistribution) = dist.is_discrete function ProductDistribution(distributions::Vararg{<:Distribution}) + _has_output_grads = true + _is_discrete = true + types = Type[] + _has_argument_grads = Bool[] _num_args = Int[] _starting_args = Int[] @@ -54,6 +58,9 @@ function ProductDistribution(distributions::Vararg{<:Distribution}) end push!(types, type.parameters[1]) + _has_output_grads = _has_output_grads && has_output_grad(dist) + _is_discrete = _is_discrete && is_discrete(dist) + grads_data = has_argument_grads(dist) append!(_has_argument_grads, grads_data) push!(_num_args, length(grads_data)) @@ -64,9 +71,9 @@ function ProductDistribution(distributions::Vararg{<:Distribution}) return ProductDistribution{Tuple{types...}}( length(distributions), collect(distributions), - all(has_output_grad(dist) for dist in distributions), + _has_output_grads, Tuple(_has_argument_grads), - all(is_discrete(dist) for dist in distributions), + _is_discrete, _num_args, _starting_args) end From 22055c26fd3154291a869d152b741fc624cfd510 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Mon, 4 Mar 2024 15:40:48 -0500 Subject: [PATCH 06/18] DRY / crib from mixture combinators --- src/modeling_library/product.jl | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/src/modeling_library/product.jl b/src/modeling_library/product.jl index c7f8a898..26aff393 100644 --- a/src/modeling_library/product.jl +++ b/src/modeling_library/product.jl @@ -78,28 +78,26 @@ function ProductDistribution(distributions::Vararg{<:Distribution}) _starting_args) end -function Gen.random(dist::ProductDistribution, factor_args_flat...) - factor_args = [factor_args_flat[dist.starting_args[i]:dist.starting_args[i]+dist.num_args[i]-1] for i in 1:dist.K] - return [random(dist.distributions[i], factor_args[i]...) for i in 1:dist.K] +function extract_args_for_component(dist::ProductDistribution, component_args_flat, k::Int) + start_arg = dist.starting_args[k] + n = dist.num_args[k] + return component_args_flat[start_arg:start_arg+n-1] end -function Gen.logpdf(dist::ProductDistribution, x, factor_args_flat...) - factor_args = [factor_args_flat[dist.starting_args[i]:dist.starting_args[i]+dist.num_args[i]-1] for i in 1:dist.K] - return sum(Gen.logpdf(dist.distributions[i], x[i], factor_args[i]...) for i in 1:dist.K) -end +Gen.random(dist::ProductDistribution, component_args_flat...) = + [random(dist.distributions[k], extract_args_for_component(dist, component_args_flat, k)...) for k in 1:dist.K] -function Gen.logpdf_grad(dist::ProductDistribution, x, factor_args_flat...) - factor_args = [factor_args_flat[dist.starting_args[i]:(dist.starting_args[i]+dist.num_args[i]-1)] for i in 1:dist.K] - logpdf_grads = [Gen.logpdf_grad(dist.distributions[i], x[i], factor_args[i]...) for i in 1:dist.K] +Gen.logpdf(dist::ProductDistribution, x, component_args_flat...) = + sum(Gen.logpdf(dist.distributions[k], x[k], extract_args_for_component(dist, component_args_flat, k)...) for k in 1:dist.K) +function Gen.logpdf_grad(dist::ProductDistribution, x, component_args_flat...) + logpdf_grads = [Gen.logpdf_grad(dist.distributions[k], x[k], extract_args_for_component(dist, component_args_flat, k)...) for k in 1:dist.K] x_grad = if dist.has_output_grad tuple((grads[1] for grads in logpdf_grads)...) else nothing end - arg_grads = vcat((collect(grads[2:end]) for grads in logpdf_grads)...) - return (x_grad, arg_grads...) end From 7ab98480fc58264b5981b40246f99f49e6cb2e72 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Mon, 11 Mar 2024 17:50:42 -0400 Subject: [PATCH 07/18] Fix test --- test/modeling_library/product.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/modeling_library/product.jl b/test/modeling_library/product.jl index be82aebd..81ddf349 100644 --- a/test/modeling_library/product.jl +++ b/test/modeling_library/product.jl @@ -9,18 +9,18 @@ discrete_product = ProductDistribution(bernoulli, binom) (n, p2) = (3, 0.9) # random - x = unrelated_pair_product(p1, n, p2) + x = discrete_product(p1, n, p2) # logpdf x = (true, 2) - actual = logpdf(unrelated_pair_product, x, p1, n, p2) + actual = logpdf(discrete_product, x, p1, n, p2) expected = logpdf(bernoulli, x[1], p1) + logpdf(binom, x[2], n, p2) @test isapprox(actual, expected) # test logpdf_grad against finite differencing - f = (x, p, mu, std) -> logpdf(unrelated_pair_product, x, p, mu, std) + f = (x, p, mu, std) -> logpdf(discrete_product, x, p, mu, std) args = (x, p, mu, std) - actual = logpdf_grad(unrelated_pair_product, args...) + actual = logpdf_grad(discrete_product, args...) for (i, b) in enumerate(grad_bools) if b @test isapprox(actual[i], finite_diff(f, args, i, dx)) From 7d585f11a03f7c0a810a82b4b1bc86f015a871a0 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Thu, 14 Mar 2024 16:14:45 -0500 Subject: [PATCH 08/18] Add type assertions --- test/modeling_library/product.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/modeling_library/product.jl b/test/modeling_library/product.jl index 81ddf349..bfb9e0c9 100644 --- a/test/modeling_library/product.jl +++ b/test/modeling_library/product.jl @@ -10,6 +10,7 @@ discrete_product = ProductDistribution(bernoulli, binom) # random x = discrete_product(p1, n, p2) + @assert typeof(x) == get_return_type(discrete_product) == Tuple{Bool, Int} # logpdf x = (true, 2) @@ -40,6 +41,7 @@ continuous_product = ProductDistribution(uniform, normal) # random x = continuous_product(low, high, mu, std) + @asssert typeof(x) == get_return_type(continuous_product) == Typle{Float64, Float64} # logpdf x = (0.1, 0.7) @@ -70,6 +72,7 @@ dissimilar_product = ProductDistribution(bernoulli, normal) # random x = dissimilar_product(p, mu, std) + @assert typeof(x) == get_return_type(dissimilar_product) == Tuple{Bool, Float64} # logpdf x = (false, 0.3) From a2058f2b16eeb6f9be3983281333788821d61935 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Thu, 14 Mar 2024 16:15:29 -0500 Subject: [PATCH 09/18] Bring type inference in line with canon --- src/modeling_library/product.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/modeling_library/product.jl b/src/modeling_library/product.jl index 26aff393..0335c354 100644 --- a/src/modeling_library/product.jl +++ b/src/modeling_library/product.jl @@ -52,11 +52,7 @@ function ProductDistribution(distributions::Vararg{<:Distribution}) start_pos = 1 for dist in distributions - type = typeof(dist) - while supertype(type) != Any - type = supertype(type) - end - push!(types, type.parameters[1]) + push!(types, get_return_type(dist)) _has_output_grads = _has_output_grads && has_output_grad(dist) _is_discrete = _is_discrete && is_discrete(dist) From b61eee71b2c035c490cff5a6d93f215b18bf083d Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Thu, 14 Mar 2024 16:15:39 -0500 Subject: [PATCH 10/18] Reduce type dispatch --- src/modeling_library/product.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/modeling_library/product.jl b/src/modeling_library/product.jl index 0335c354..a8e1a44d 100644 --- a/src/modeling_library/product.jl +++ b/src/modeling_library/product.jl @@ -24,9 +24,9 @@ Example: end ``` """ -struct ProductDistribution{T} <: Distribution{T} +struct ProductDistribution{T, Ds} <: Distribution{T} K::Int - distributions::Vector{<:Distribution} + distributions::Ds has_output_grad::Bool has_argument_grads::Tuple is_discrete::Bool @@ -64,9 +64,9 @@ function ProductDistribution(distributions::Vararg{<:Distribution}) start_pos += length(grads_data) end - return ProductDistribution{Tuple{types...}}( + return ProductDistribution{Tuple{types...}, typeof(distributions)}( length(distributions), - collect(distributions), + distributions, _has_output_grads, Tuple(_has_argument_grads), _is_discrete, From 6973cfd2cbc3454b99f20bf08a1c581712aec868 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Thu, 14 Mar 2024 16:19:02 -0500 Subject: [PATCH 11/18] Fix/improve `random` and `logpdf` --- src/modeling_library/product.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/modeling_library/product.jl b/src/modeling_library/product.jl index a8e1a44d..515562a3 100644 --- a/src/modeling_library/product.jl +++ b/src/modeling_library/product.jl @@ -80,11 +80,11 @@ function extract_args_for_component(dist::ProductDistribution, component_args_fl return component_args_flat[start_arg:start_arg+n-1] end -Gen.random(dist::ProductDistribution, component_args_flat...) = - [random(dist.distributions[k], extract_args_for_component(dist, component_args_flat, k)...) for k in 1:dist.K] +Gen.random(dist::ProductDistribution, args...) = + Tuple(random(d, extract_args_for_component(dist, args, k)...) for (k, d) in enumerate(dist.distributions)) -Gen.logpdf(dist::ProductDistribution, x, component_args_flat...) = - sum(Gen.logpdf(dist.distributions[k], x[k], extract_args_for_component(dist, component_args_flat, k)...) for k in 1:dist.K) +Gen.logpdf(dist::ProductDistribution, x, args...) = + sum(Gen.logpdf(d, x[k], extract_args_for_component(dist, args, k)...) for (k, d) in enumerate(dist.distributions)) function Gen.logpdf_grad(dist::ProductDistribution, x, component_args_flat...) logpdf_grads = [Gen.logpdf_grad(dist.distributions[k], x[k], extract_args_for_component(dist, component_args_flat, k)...) for k in 1:dist.K] From 2b62bd049f4e9bd379e916353aa14421a0f40284 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Thu, 14 Mar 2024 16:20:00 -0500 Subject: [PATCH 12/18] Rewrite `logpdf_grad` --- src/modeling_library/product.jl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/modeling_library/product.jl b/src/modeling_library/product.jl index 515562a3..29e15e57 100644 --- a/src/modeling_library/product.jl +++ b/src/modeling_library/product.jl @@ -86,14 +86,15 @@ Gen.random(dist::ProductDistribution, args...) = Gen.logpdf(dist::ProductDistribution, x, args...) = sum(Gen.logpdf(d, x[k], extract_args_for_component(dist, args, k)...) for (k, d) in enumerate(dist.distributions)) -function Gen.logpdf_grad(dist::ProductDistribution, x, component_args_flat...) - logpdf_grads = [Gen.logpdf_grad(dist.distributions[k], x[k], extract_args_for_component(dist, component_args_flat, k)...) for k in 1:dist.K] - x_grad = if dist.has_output_grad - tuple((grads[1] for grads in logpdf_grads)...) - else - nothing +function Gen.logpdf_grad(dist::ProductDistribution, x, args...) + x_grad = () + arg_grads = () + for (k, d) in enumerate(dist.distributions) + grads = Gen.logpdf_grad(d, x[k], extract_args_for_component(dist, args, k)...) + x_grad = (x_grad..., grads[1]) + arg_grads = (arg_grads..., grads[2:end]...) end - arg_grads = vcat((collect(grads[2:end]) for grads in logpdf_grads)...) + x_grad = dist.has_output_grad ? x_grad : nothing return (x_grad, arg_grads...) end From e1d5b86f1cb6c7d746f852563c5ac37bdb4ea5f3 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Wed, 20 Mar 2024 13:07:25 -0400 Subject: [PATCH 13/18] Change `get_return_value` to `Gen.get_return_value` --- src/modeling_library/product.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modeling_library/product.jl b/src/modeling_library/product.jl index 29e15e57..e061b497 100644 --- a/src/modeling_library/product.jl +++ b/src/modeling_library/product.jl @@ -52,7 +52,7 @@ function ProductDistribution(distributions::Vararg{<:Distribution}) start_pos = 1 for dist in distributions - push!(types, get_return_type(dist)) + push!(types, Gen.get_return_type(dist)) _has_output_grads = _has_output_grads && has_output_grad(dist) _is_discrete = _is_discrete && is_discrete(dist) From 893adcb95afe60a6fe63e2ed671a6ee43793ea08 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Wed, 20 Mar 2024 21:55:28 -0400 Subject: [PATCH 14/18] Fix more `get_return_type`s --- test/modeling_library/product.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/modeling_library/product.jl b/test/modeling_library/product.jl index bfb9e0c9..2ce7b679 100644 --- a/test/modeling_library/product.jl +++ b/test/modeling_library/product.jl @@ -10,7 +10,7 @@ discrete_product = ProductDistribution(bernoulli, binom) # random x = discrete_product(p1, n, p2) - @assert typeof(x) == get_return_type(discrete_product) == Tuple{Bool, Int} + @assert typeof(x) == Gen.get_return_type(discrete_product) == Tuple{Bool, Int} # logpdf x = (true, 2) @@ -41,7 +41,7 @@ continuous_product = ProductDistribution(uniform, normal) # random x = continuous_product(low, high, mu, std) - @asssert typeof(x) == get_return_type(continuous_product) == Typle{Float64, Float64} + @asssert typeof(x) == Gen.get_return_type(continuous_product) == Typle{Float64, Float64} # logpdf x = (0.1, 0.7) @@ -72,7 +72,7 @@ dissimilar_product = ProductDistribution(bernoulli, normal) # random x = dissimilar_product(p, mu, std) - @assert typeof(x) == get_return_type(dissimilar_product) == Tuple{Bool, Float64} + @assert typeof(x) == Gen.get_return_type(dissimilar_product) == Tuple{Bool, Float64} # logpdf x = (false, 0.3) From 07246b56ad1fbe09d094b82315b7469b75176fd9 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Thu, 21 Mar 2024 13:21:28 -0400 Subject: [PATCH 15/18] Fix test --- test/modeling_library/product.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/modeling_library/product.jl b/test/modeling_library/product.jl index 2ce7b679..ab26836c 100644 --- a/test/modeling_library/product.jl +++ b/test/modeling_library/product.jl @@ -19,8 +19,8 @@ discrete_product = ProductDistribution(bernoulli, binom) @test isapprox(actual, expected) # test logpdf_grad against finite differencing - f = (x, p, mu, std) -> logpdf(discrete_product, x, p, mu, std) - args = (x, p, mu, std) + f = (x, p1, n, p2) -> logpdf(discrete_product, x, p1, n, p2) + args = (x, p1, n, p2) actual = logpdf_grad(discrete_product, args...) for (i, b) in enumerate(grad_bools) if b From 4a89a6dc7ddc3e8316530ea9da98c316f69f4d4c Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Fri, 22 Mar 2024 06:53:41 -0400 Subject: [PATCH 16/18] Fix typo `asssert` --- test/modeling_library/product.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/modeling_library/product.jl b/test/modeling_library/product.jl index ab26836c..b1047ebf 100644 --- a/test/modeling_library/product.jl +++ b/test/modeling_library/product.jl @@ -41,7 +41,7 @@ continuous_product = ProductDistribution(uniform, normal) # random x = continuous_product(low, high, mu, std) - @asssert typeof(x) == Gen.get_return_type(continuous_product) == Typle{Float64, Float64} + @assert typeof(x) == Gen.get_return_type(continuous_product) == Typle{Float64, Float64} # logpdf x = (0.1, 0.7) From e74607fd1b64e13c8f136d6ea9f52bee5186b6ea Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Fri, 22 Mar 2024 11:36:12 -0400 Subject: [PATCH 17/18] Fix typo `Typle` --- test/modeling_library/product.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/modeling_library/product.jl b/test/modeling_library/product.jl index b1047ebf..ef27aa81 100644 --- a/test/modeling_library/product.jl +++ b/test/modeling_library/product.jl @@ -41,7 +41,7 @@ continuous_product = ProductDistribution(uniform, normal) # random x = continuous_product(low, high, mu, std) - @assert typeof(x) == Gen.get_return_type(continuous_product) == Typle{Float64, Float64} + @assert typeof(x) == Gen.get_return_type(continuous_product) == Tuple{Float64, Float64} # logpdf x = (0.1, 0.7) From 6dc63a8f9b2dea16284b7072cb390b8d10ef2763 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Thu, 28 Mar 2024 14:10:47 -0400 Subject: [PATCH 18/18] Fix tests --- test/modeling_library/product.jl | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/test/modeling_library/product.jl b/test/modeling_library/product.jl index ef27aa81..e4de7a04 100644 --- a/test/modeling_library/product.jl +++ b/test/modeling_library/product.jl @@ -22,10 +22,8 @@ discrete_product = ProductDistribution(bernoulli, binom) f = (x, p1, n, p2) -> logpdf(discrete_product, x, p1, n, p2) args = (x, p1, n, p2) actual = logpdf_grad(discrete_product, args...) - for (i, b) in enumerate(grad_bools) - if b - @test isapprox(actual[i], finite_diff(f, args, i, dx)) - end + for i in [2, 4] + @test isapprox(actual[i], finite_diff(f, args, i, dx)) end end @@ -51,12 +49,13 @@ continuous_product = ProductDistribution(uniform, normal) # test logpdf_grad against finite differencing f = (x, low, high, mu, std) -> logpdf(continuous_product, x, low, high, mu, std) - args = (x, low, high, mu, std) + # A mutable indexable is required by `finite_diff_vec`, hence the `collect` here: + args = (collect(x), low, high, mu, std) actual = logpdf_grad(continuous_product, args...) - for (i, b) in enumerate(grad_bools) - if b - @test isapprox(actual[i], finite_diff(f, args, i, dx)) - end + @test isapprox(actual[1][1], finite_diff_vec(f, args, 1, 1, dx)) + @test isapprox(actual[1][2], finite_diff_vec(f, args, 1, 2, dx)) + for i in 2:5 + @test isapprox(actual[i], finite_diff(f, args, i, dx)) end end @@ -84,9 +83,7 @@ dissimilar_product = ProductDistribution(bernoulli, normal) f = (x, p, mu, std) -> logpdf(dissimilar_product, x, p, mu, std) args = (x, p, mu, std) actual = logpdf_grad(dissimilar_product, args...) - for (i, b) in enumerate(grad_bools) - if b - @test isapprox(actual[i], finite_diff(f, args, i, dx)) - end + for i in 2:4 + @test isapprox(actual[i], finite_diff(f, args, i, dx)) end end