# logistic_trajectory_by_week_example.jl (148 lines, 128 loc, 4.22 KB)
# Turing logistic example
#
# We need a logistic function, which is provided by StatsFuns.
# (Kept as a comment: a string literal directly before `using` would be parsed
# as a docstring for the import statement, which Julia cannot document.)
using Random
using StatsFuns: logistic, logit
using Turing, MCMCChains, BayesTesting
# Turing.jl model definitions:
#
# Bayesian logistic regression with an intercept `a` and a single slope `b`
# (one covariate: log time in the example below).
#   a ~ Normal(0, 5),  b ~ Normal(0, 5)
#   y[i] ~ Bernoulli(logistic(a + b * x[i]))
# `y` is observed as Bernoulli outcomes, so it is expected to hold 0/1 (or Bool).
# Iterating with `eachindex(x, y)` requires x and y to share indices: a length
# mismatch raises a DimensionMismatch instead of silently ignoring extra x values.
@model logistic_regression(x, y) = begin
    a ~ Normal(0, 5)
    b ~ Normal(0, 5)
    for i in eachindex(x, y)
        v = logistic(a + b * x[i])
        y[i] ~ Bernoulli(v)
    end
end
# Bayesian logistic regression with multiple covariates, given as the columns
# of the n×k matrix `x`. A column of ones is prepended, so there are k+1
# coefficients: b[1] is the intercept and b[2..k+1] are the slopes for the
# columns of `x`, each with an independent Normal(0, 6) prior.
# `TV` is the container type used for `b` (default Vector{Float64}); presumably
# this lets Turing substitute an AD-compatible element type — TODO confirm
# against the Turing version in use.
@model logistic_multi_regression(x, y, ::Type{TV}=Vector{Float64}) where {TV} = begin
    n, k = size(x)
    b = TV(undef, k + 1)
    for i in 1:(k + 1)
        b[i] ~ Normal(0, 6)
    end
    # linear predictor for every observation: intercept + x * slopes
    mu = [ones(n) x] * b
    # eachindex(mu, y) throws a DimensionMismatch if y does not have one
    # outcome per row of x, rather than silently skipping or over-running.
    for i in eachindex(mu, y)
        v = logistic(mu[i])
        y[i] ~ Bernoulli(v)
    end
end
# load data
using CSV, DataFrames, Plots, StatsPlots
# `CSV.read(path)` without a sink argument was deprecated and newer CSV.jl
# releases require one; materializing a `CSV.File` into a `DataFrame` works on
# both older and newer CSV.jl versions.
df = DataFrame(CSV.File("cgi_example.csv"))
# model CGI-I as a function of log(time)
y = df.cgi      # Bernoulli outcome; expected to be coded 0/1 — confirm in the CSV
x = df.ltime    # log(time) covariate
n = length(y)
# Sample using HMC: three independent chains, concatenated with `chainscat`.
# HMC(3000, 0.05, 10) presumably follows the older Turing signature
# HMC(n_iters, step_size, n_leapfrog) — confirm against the installed Turing.
Random.seed!(1359)
@time chain = mapreduce(c -> sample(logistic_regression(x, y), HMC(3000, 0.05, 10)),
    chainscat, 1:3)
# Sample using NUTS
#chain = mapreduce(c -> sample(logistic_regression(x, y), NUTS(3000,1000, 0.65)),
#    chainscat, 1:3)
plot(chain)
# Drop the first `nburn` iterations before summarising. With fixed-step HMC
# these are burn-in draws (there is no adaptation phase); with the NUTS variant
# above they would be the adaptation draws.
nburn = 1000
cc = chain[(nburn + 1):end]
@show(describe(cc))
plot(cc)
bdraws = Array(cc["b"])
plot(bdraws, st=:density, label="b",fill=true)
adraws = Array(cc["a"])
## Generating the plots:
# Posterior predictive probability of success at the five time points.
# `z` holds the covariate values fed to the model (log.(1:5)); `wk` holds the
# week number used to label each of them in the plot.
z = log.(1:5)
wk = [0 2 4 6 8]
# v[:, i] = logistic(a + b*z[i]) for every posterior draw (one row per draw).
# Broadcasting against the 1×5 row z' fills all five columns in one pass.
v = zeros(length(bdraws), 5)
v .= logistic.(adraws .+ bdraws .* z')
## Corrected figure (the one in the presentation is incorrect)
plt = plot()
# The baseline (first time point) carries the title and a line at its mean.
plot!(v[:, 1], st=:density, label="Prob success for baseline", fill=true, title="Predictive Probabilities")
vline!([mean(v[:, 1])], linewidth=2, label="Baseline mean")
for i in 2:5
    w = wk[i]
    plot!(v[:, i], st=:density, label="Prob success for x = $w", fill=true)
end
plt
#savefig("predictive_probs.png")
# Summarise each time point (one column of `v`) by its posterior mean and the
# central 95% interval (2.5% and 97.5% quantiles), then plot against weeks.
yt = [mean(v[:, j]) for j in 1:5]
ylb = [quantile(v[:, j], 0.025) for j in 1:5]
yub = [quantile(v[:, j], 0.975) for j in 1:5]
t = [0, 2, 4, 6, 8]
plot(t, yt, st=:scatter, color=:blue, label="Mean probability", legend=:topleft,
    xlabel="Weeks", ylabel="Probability")
plot!(t, ylb, st=:scatter, color=:green, alpha=0.6, label="0.95 interval")
plot!(t, yub, st=:scatter, color=:green, alpha=0.6, label="")
#savefig("trajectory_plot.png")
# An example with more RHS variables (covariates) - not in the presentation.
# Interaction terms (covariate × log time) allow the time trajectory to differ
# between groups.
df.htr_ltime = df.ltime .* df.htr_over
df.c19_ltime = df.c19 .* df.ltime
df.slc_ltime = df.slc_ss .* df.ltime
# NOTE(review): dropmissing(df) removes any row with a missing value in ANY
# column of df, including columns not used in this model (e.g. c19, slc_ss).
# Pass a column list to dropmissing if the estimation sample should depend only
# on cgi, ltime, htr_over and htr_ltime.
dd = dropmissing(df)
y = dd.cgi
x = hcat(dd.ltime, dd.htr_over, dd.htr_ltime)
## The following takes approx 3 minutes on my machine for 5 chains
nchains = 5 # number of MCMC chains
Random.seed!(1359)
Turing.setadbackend(:reverse_diff)
# Alternative sampler (NUTS):
# @time chain = mapreduce(c -> sample(logistic_multi_regression(x, y), NUTS(4000,2000, 0.65)),
#     chainscat, 1:nchains)
# HMC: one independent run per chain, concatenated with `chainscat`
@time chain = mapreduce(
    c -> sample(logistic_multi_regression(x, y), HMC(5000, 0.05, 10)),
    chainscat,
    1:nchains,
)
plot(chain)
nburn = 1000 # iterations dropped from the start of each chain (burn-in)
cc = chain[(nburn + 1):end]
plot(cc)
@show(describe(cc))
# Overlay one parameter's posterior density from each chain, to spot unstable
# chains that might be excluded.
# NOTE(review): `param` indexes the second column of the chain's value array;
# which parameter that is depends on MCMCChains' column ordering — confirm with
# the chain's parameter names before interpreting the plot.
param = 2 # looking at the second parameter
plt = plot()
foreach(1:nchains) do i
    plot!(Array(cc[:, param, i]), st=:density, label="chain $i")
end
plt
# Posterior draws of b[2], the coefficient on the first column of x (dd.ltime);
# b[1] is the intercept added inside logistic_multi_regression.
b2_draws = Array(cc["b[2]"]) # get the b[2] draws
plot(b2_draws,st=:density,fill=true,label="b2")
vline!([0.0],label="",linecolor=:black,linewidth=2)
# Judge "statistical significance" by where zero falls in this density
# (in the author's run this coefficient was not significant).
using BayesTesting
import BayesTesting.post_odds_pval
# Bayesian posterior density ratio and Bayesian probability in tail.
# All three results are @show-n so they are printed when the file is run with
# include(), not only when the last expression is echoed at the REPL.
@show(mcodds(b2_draws))
@show(bayespval(b2_draws))
@show(pdr_pval(b2_draws))
####################################