diff --git a/Project.toml b/Project.toml index fe116e2e..3e19b1cf 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "AbstractGPs" uuid = "99985d1d-32ba-4be9-9821-2ec096f28918" authors = ["JuliaGaussianProcesses Team"] -version = "0.4.0" +version = "0.5.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/docs/Project.toml b/docs/Project.toml index 367379a9..5667f86a 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,5 +4,5 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] -AbstractGPs = "0.4" +AbstractGPs = "0.4, 0.5" Documenter = "0.27" diff --git a/examples/regression-1d/Project.toml b/examples/regression-1d/Project.toml index c9b9bcb4..4e921146 100644 --- a/examples/regression-1d/Project.toml +++ b/examples/regression-1d/Project.toml @@ -13,7 +13,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [compat] -AbstractGPs = "0.4" +AbstractGPs = "0.4, 0.5" AdvancedHMC = "0.2" Distributions = "0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25" DynamicHMC = "2.2, 3.1" diff --git a/examples/regression-1d/script.jl b/examples/regression-1d/script.jl index e8d86d1d..8fa5e1bf 100644 --- a/examples/regression-1d/script.jl +++ b/examples/regression-1d/script.jl @@ -212,19 +212,19 @@ mean(logpdf(gp_posterior(x_train, y_train, p)(x_test), y_test) for p in samples) # We sample 5 functions from each posterior GP given by the final 100 samples of kernel # parameters. -plt = scatter( - x_train, - y_train; - xlim=(0, 1), - xlabel="x", - ylabel="y", - title="posterior (AdvancedHMC)", - label="Train Data", -) -for p in samples[(end - 100):end] - sampleplot!(plt, 0:0.02:1, gp_posterior(x_train, y_train, p); samples=5) +plt = plot(; xlim=(0, 1), xlabel="x", ylabel="y", title="posterior (AdvancedHMC)") +for (i, p) in enumerate(samples[(end - 100):end]) + sampleplot!( + plt, + 0:0.02:1, + gp_posterior(x_train, y_train, p); + samples=5, + seriescolor="red", + label=(i == 1 ? "samples" : nothing), + ) end -scatter!(plt, x_test, y_test; label="Test Data") +scatter!(plt, x_train, y_train; label="Train Data", markercolor=1) +scatter!(plt, x_test, y_test; label="Test Data", markercolor=2) plt # #### DynamicHMC @@ -290,18 +290,11 @@ mean(logpdf(gp_posterior(x_train, y_train, p)(x_test), y_test) for p in samples) # We sample a function from the posterior GP for the final 100 samples of kernel # parameters. -plt = scatter( - x_train, - y_train; - xlim=(0, 1), - xlabel="x", - ylabel="y", - title="posterior (DynamicHMC)", - label="Train Data", -) +plt = plot(; xlim=(0, 1), xlabel="x", ylabel="y", title="posterior (DynamicHMC)") +scatter!(plt, x_train, y_train; label="Train Data") scatter!(plt, x_test, y_test; label="Test Data") for p in samples[(end - 100):end] - sampleplot!(plt, 0:0.02:1, gp_posterior(x_train, y_train, p)) + sampleplot!(plt, 0:0.02:1, gp_posterior(x_train, y_train, p); seriescolor="red") end plt @@ -349,18 +342,13 @@ mean(logpdf(gp_posterior(x_train, y_train, p)(x_test), y_test) for p in samples) # We sample a function from the posterior GP for the final 100 samples of kernel # parameters. -plt = scatter( - x_train, - y_train; - xlim=(0, 1), - xlabel="x", - ylabel="y", - title="posterior (EllipticalSliceSampling)", - label="Train Data", +plt = plot(; + xlim=(0, 1), xlabel="x", ylabel="y", title="posterior (EllipticalSliceSampling)" ) +scatter!(plt, x_train, y_train; label="Train Data") scatter!(plt, x_test, y_test; label="Test Data") for p in samples[(end - 100):end] - sampleplot!(plt, 0:0.02:1, gp_posterior(x_train, y_train, p)) + sampleplot!(plt, 0:0.02:1, gp_posterior(x_train, y_train, p); seriescolor="red") end plt diff --git a/src/util/plotting.jl b/src/util/plotting.jl index b04d2e20..dc6aa211 100644 --- a/src/util/plotting.jl +++ b/src/util/plotting.jl @@ -4,7 +4,7 @@ length(x) == length(gp.x) || throw(DimensionMismatch("length of `x` and `gp.x` has to be equal")) scale::Float64 = pop!(plotattributes, :ribbon_scale, 1.0) - scale > 0.0 || error("`bandwidth` keyword argument must be non-negative") + scale >= 0.0 || error("`ribbon_scale` keyword argument must be non-negative") # compute marginals μ, σ2 = mean_and_var(gp) @@ -82,16 +82,19 @@ Plot samples from the projection `f` of a Gaussian process versus `x`. Make sure to load [Plots.jl](https://github.com/JuliaPlots/Plots.jl) before you use this function. +When plotting multiple samples, these are treated as a _single_ series (i.e., +only a single entry will be added to the legend when providing a `label`). + # Example ```julia using Plots gp = GP(SqExponentialKernel()) -sampleplot(gp(rand(5)); samples=10, markersize=5) +sampleplot(gp(rand(5)); samples=10, linealpha=1.0) ``` -The given example plots 10 samples from the projection of the GP `gp`. The `markersize` is modified -from default of 0.5 to 5. +The given example plots 10 samples from the projection of the GP `gp`. +The `linealpha` is modified from default of 0.35 to 1. --- sampleplot(x::AbstractVector, gp::AbstractGP; samples=1, kwargs...) @@ -115,18 +118,15 @@ SamplePlot((f,)::Tuple{<:FiniteGP}) = SamplePlot((f.x, f)) SamplePlot((x, gp)::Tuple{<:AbstractVector,<:AbstractGP}) = SamplePlot((gp(x, 1e-9),)) @recipe function f(sp::SamplePlot) - nsamples::Int = get(plotattributes, :samples, 1) + nsamples::Int = pop!(plotattributes, :samples, 1) samples = rand(sp.f, nsamples) + flat_x = repeat(vcat(sp.x, NaN), nsamples) + flat_f = vec(vcat(samples, fill(NaN, 1, nsamples))) + # Set default attributes - seriestype --> :line - linealpha --> 0.2 - markershape --> :circle - markerstrokewidth --> 0.0 - markersize --> 0.5 - markeralpha --> 0.3 - seriescolor --> "red" + linealpha --> 0.35 label --> "" - return sp.x, samples + return flat_x, flat_f end diff --git a/test/deprecations.jl b/test/deprecations.jl index 95576086..58273158 100644 --- a/test/deprecations.jl +++ b/test/deprecations.jl @@ -4,11 +4,11 @@ gp = f(x, 0.1) plt = @test_deprecated sampleplot(gp, 10) - @test plt.n == 10 + @test plt.n == 1 @test_deprecated sampleplot!(gp, 4) - @test plt.n == 14 + @test plt.n == 2 @test_deprecated sampleplot!(Plots.current(), gp, 3) - @test plt.n == 17 + @test plt.n == 3 end diff --git a/test/util/plotting.jl b/test/util/plotting.jl index 0cc3b88e..22f510ac 100644 --- a/test/util/plotting.jl +++ b/test/util/plotting.jl @@ -6,18 +6,20 @@ z = rand(10) plt1 = sampleplot(z, gp) @test plt1.n == 1 - @test plt1.series_list[1].plotattributes[:x] == sort(z) + @test isequal(plt1.series_list[1].plotattributes[:x], vcat(z, NaN)) - plt2 = sampleplot(gp; samples=10) - @test plt2.n == 10 - sort_x = sort(x) - @test all(series.plotattributes[:x] == sort_x for series in plt2.series_list) + plt2 = sampleplot(gp; samples=3) + @test plt2.n == 1 + plt2_x = plt2.series_list[1].plotattributes[:x] + plt2_y = plt2.series_list[1].plotattributes[:y] + @test isequal(plt2_x, vcat(x, NaN, x, NaN, x, NaN)) + @test length(plt2_y) == length(plt2_x) + @test isnan(plt2_y[length(z) + 1]) && isnan(plt2_y[2length(z) + 2]) - z = rand(7) - plt3 = sampleplot(z, f; samples=8) - @test plt3.n == 8 - sort_z = sort(z) - @test all(series.plotattributes[:x] == sort_z for series in plt3.series_list) + z3 = rand(7) + plt3 = sampleplot(z3, f; samples=2) + @test plt3.n == 1 + @test isequal(plt3.series_list[1].plotattributes[:x], vcat(z3, NaN, z3, NaN)) # Check recipe dispatches for `FiniteGP`s rec = RecipesBase.apply_recipe(Dict{Symbol,Any}(), gp)