FitzHugh-Nagumo Bayesian Parameter Estimation Benchmarks

Vaibhav Dixit, Chris Rackauckas
using DiffEqBayes, BenchmarkTools
using OrdinaryDiffEq, RecursiveArrayTools, Distributions, ParameterizedFunctions, Mamba
using Plots
gr(fmt=:png)
Plots.GRBackend()

Defining the problem.

The FitzHugh-Nagumo model is a simplified version of the Hodgkin-Huxley model and is used to describe an excitable system (e.g., a neuron).
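
Written as equations, the system defined in the `@ode_def` block that follows reads as shown here. The symbols match the code, with $\tau^{-1}$ standing for `τinv`:

```latex
\begin{aligned}
\frac{dv}{dt} &= v - \frac{v^{3}}{3} - w + l, \\
\frac{dw}{dt} &= \tau^{-1}\,(v + a - b\,w).
\end{aligned}
```

Here $v$ is the fast, voltage-like variable, $w$ is the slower recovery variable, and $l$ acts as a constant external stimulus. The parameter vector is ordered as $(a, b, \tau^{-1}, l)$.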

fitz = @ode_def FitzhughNagumo begin
  dv = v - v^3/3 -w + l
  dw = τinv*(v +  a - b*w)
end a b τinv l
(::Main.WeaveSandBox9.FitzhughNagumo{getfield(Main.WeaveSandBox9, Symbol("##1#5")),getfield(Main.WeaveSandBox9, Symbol("##2#6")),getfield(Main.WeaveSandBox9, Symbol("##3#7")),Nothing,Nothing,getfield(Main.WeaveSandBox9, Symbol("##4#8")),Expr,Expr}) (generic function with 2 methods)
prob_ode_fitzhughnagumo = ODEProblem(fitz,[1.0,1.0],(0.0,10.0),[0.7,0.8,1/12.5,0.5])
sol = solve(prob_ode_fitzhughnagumo, Tsit5())
retcode: Success
Interpolation: specialized 4th order "free" interpolation
t: 14-element Array{Float64,1}:
  0.0                
  0.15079562872319327
  0.6663751602069676 
  1.45491551203011   
  2.634172592705079  
  3.7872847149850593 
  5.149289290296318  
  6.764810021639485  
  7.606013876691694  
  8.324324391350054  
  9.040746828687205  
  9.552469682130221  
  9.985006344186914  
 10.0                
u: 14-element Array{Array{Float64,1},1}:
 [1.0, 1.0]          
 [1.02428, 1.01095]  
 [1.09254, 1.04957]  
 [1.14789, 1.11021]  
 [1.13454, 1.19755]  
 [1.04328, 1.27187]  
 [0.844691, 1.3381]  
 [0.313544, 1.36894] 
 [-0.409826, 1.34276]
 [-1.40824, 1.27062] 
 [-1.90978, 1.15634] 
 [-1.96185, 1.06889] 
 [-1.95443, 0.996705]
 [-1.95386, 0.994246]
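
As an independent sanity check on the solver output above, the same system can be integrated with a fixed-step classical Runge-Kutta scheme using only Base Julia. This is a minimal sketch, not how OrdinaryDiffEq works internally; `fhn` and `rk4` are hypothetical helper names, and the step count `n` is an arbitrary choice.

```julia
# FitzHugh-Nagumo right-hand side; p = (a, b, τinv, l), same order as in @ode_def.
function fhn(u, p)
    v, w = u
    a, b, τinv, l = p
    return [v - v^3 / 3 - w + l, τinv * (v + a - b * w)]
end

# Classical fixed-step RK4 (hypothetical helper, not part of OrdinaryDiffEq).
function rk4(f, u0, p, t0, t1; n = 10_000)
    h = (t1 - t0) / n
    u = copy(u0)
    for _ in 1:n
        k1 = f(u, p)
        k2 = f(u .+ (h / 2) .* k1, p)
        k3 = f(u .+ (h / 2) .* k2, p)
        k4 = f(u .+ h .* k3, p)
        u = u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
    end
    return u
end

# Same initial condition, time span, and parameters as prob_ode_fitzhughnagumo.
u_end = rk4(fhn, [1.0, 1.0], (0.7, 0.8, 1 / 12.5, 0.5), 0.0, 10.0)
println(u_end)  # should land near the last entry of `u` printed above
```

Agreement to a couple of decimal places is all that should be expected, since `Tsit5()` was run with its default tolerances.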

Data are generated by adding Gaussian noise (standard deviation `sig = 0.20`) to the solution obtained above, sampled at ten evenly spaced time points.

t = collect(range(1,stop=10,length=10))
sig = 0.20
data = convert(Array, VectorOfArray([(sol(t[i]) + sig*randn(2)) for i in 1:length(t)]))
2×10 Array{Float64,2}:
 1.24661  1.02895  1.00509  1.0614   …  -0.860215  -1.5491   -1.60432
 1.24283  0.60134  1.13515  1.32399      1.62333    1.37337   1.08187
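
Because `randn` is called without a seed, the data matrix above changes from run to run, and the posterior summaries change with it. If reproducible data are wanted, one option is to draw the noise from an explicitly seeded generator. The sketch below assumes that is acceptable for the benchmark; `noisy_data` is a hypothetical helper, the seed value is arbitrary, and `clean` is a placeholder matrix rather than the ODE solution.

```julia
using Random

# Hypothetical helper: add i.i.d. Gaussian noise with standard deviation `sig`
# to a matrix of clean observations, using an explicitly seeded RNG.
function noisy_data(clean::AbstractMatrix, sig::Real; seed::Integer = 1234)
    rng = MersenneTwister(seed)
    return clean .+ sig .* randn(rng, size(clean)...)
end

t = collect(range(1, stop = 10, length = 10))
clean = vcat(sin.(t)', cos.(t)')   # 2×10 placeholder, NOT the ODE solution
data_a = noisy_data(clean, 0.20)
data_b = noisy_data(clean, 0.20)
println(data_a == data_b)          # same seed, so the two draws are identical
```

In the benchmark itself, `clean` would be replaced by the matrix of `sol(t[i])` values.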

Plot of the data and the solution.

scatter(t, data[1,:])
scatter!(t, data[2,:])
plot!(sol)

Priors for the parameters `a`, `b`, `τinv`, and `l` (in that order), which are passed to the Bayesian inference routine.

priors = [Truncated(Normal(1.0,0.5),0,1.5),Truncated(Normal(1.0,0.5),0,1.5),Truncated(Normal(0.0,0.5),-0.5,0.5),Truncated(Normal(0.5,0.5),0,1)]
4-element Array{Distributions.Truncated{Distributions.Normal{Float64},Distributions.Continuous},1}:
 Truncated(Distributions.Normal{Float64}(μ=1.0, σ=0.5), range=(0.0, 1.5)) 
 Truncated(Distributions.Normal{Float64}(μ=1.0, σ=0.5), range=(0.0, 1.5)) 
 Truncated(Distributions.Normal{Float64}(μ=0.0, σ=0.5), range=(-0.5, 0.5))
 Truncated(Distributions.Normal{Float64}(μ=0.5, σ=0.5), range=(0.0, 1.0))
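
For reference, `Truncated(Normal(μ, σ), lo, hi)` restricts the normal distribution to the interval [lo, hi] and renormalizes it. With $\phi$ and $\Phi$ denoting the standard normal density and distribution function, the resulting density is the standard truncated-normal form:

```latex
p(\theta) =
\frac{\dfrac{1}{\sigma}\,\phi\!\left(\dfrac{\theta-\mu}{\sigma}\right)}
     {\Phi\!\left(\dfrac{\mathrm{hi}-\mu}{\sigma}\right)
      - \Phi\!\left(\dfrac{\mathrm{lo}-\mu}{\sigma}\right)},
\qquad \mathrm{lo} \le \theta \le \mathrm{hi},
```

and $p(\theta) = 0$ outside that interval. The values used to generate the data, `[0.7, 0.8, 1/12.5, 0.5]`, all lie inside the corresponding supports.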

Parameter Estimation with Stan.jl backend
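
The `stan_inference` call in this section receives the parameter priors together with `InverseGamma(3,2)` in `vars`. A sketch of the statistical model this plausibly corresponds to is given here. It assumes independent Gaussian observation noise with one standard deviation per state variable (the `sigma1[1]` and `sigma1[2]` rows of the summary) and the inverse-gamma distribution as the prior on those standard deviations. This is an assumption about the generated Stan model, not something stated in the output:

```latex
\begin{aligned}
\theta_k &\sim \mathrm{priors}[k], \quad k = 1,\dots,4,
  \qquad \theta = (a,\, b,\, \tau^{-1},\, l), \\
\sigma_j &\sim \text{Inv-Gamma}(3, 2), \quad j = 1, 2, \\
y_{j,i} \mid \theta, \sigma &\sim
  \mathcal{N}\!\big(u_j(t_i;\theta),\ \sigma_j^{2}\big),
  \quad i = 1,\dots,10,
\end{aligned}
```

where $u_j(t_i;\theta)$ is the $j$-th component of the ODE solution at time $t_i$ and $y_{j,i}$ is the corresponding entry of `data`.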

@time bayesian_result_stan = stan_inference(prob_ode_fitzhughnagumo,t,data,priors;reltol=1e-5,abstol=1e-5,vars =(StanODEData(),InverseGamma(3,2)))
make: `/Users/vaibhav/DiffEqBenchmarks.jl/tmp/parameter_estimation_model' is up to date.

Length of data array is not equal to nchains,
all chains will use the first data dictionary.

Calling /Users/vaibhav/Downloads/cmdstan-2.18.0/bin/stansummary to infer across chains.

Inference for Stan model: parameter_estimation_model_model
4 chains: each with iter=(1000,1000,1000,1000); warmup=(0,0,0,0); thin=(1,1,1,1); 4000 iterations saved.

Warmup took (25, 27, 26, 26) seconds, 1.7 minutes total
Sampling took (18, 23, 21, 27) seconds, 1.5 minutes total

                    Mean     MCSE   StdDev     5%    50%    95%    N_Eff  N_Eff/s    R_hat
lp__             5.6e+00  7.6e-02  1.9e+00    2.1    6.0    8.1  6.2e+02  7.1e+00  1.0e+00
accept_stat__    8.4e-01  1.5e-02  2.4e-01   0.17   0.94   1.00  2.6e+02  2.9e+00  1.0e+00
stepsize__       3.9e-02  3.6e-03  5.1e-03  0.033  0.041  0.047  2.0e+00  2.3e-02  6.6e+13
treedepth__      5.9e+00  1.4e-01  9.1e-01    4.0    6.0    7.0  4.5e+01  5.1e-01  1.0e+00
n_leapfrog__     8.1e+01  8.3e+00  3.9e+01     23     63    127  2.2e+01  2.5e-01  1.1e+00
divergent__      2.5e-04      nan  1.6e-02   0.00   0.00   0.00      nan      nan  1.0e+00
energy__        -2.6e+00  9.1e-02  2.5e+00   -6.1   -3.0    2.1  7.8e+02  8.9e+00  1.0e+00
sigma1[1]        2.9e-01  2.3e-03  8.4e-02   0.18   0.28   0.45  1.3e+03  1.5e+01  1.0e+00
sigma1[2]        3.6e-01  3.8e-03  1.0e-01   0.23   0.34   0.54  7.0e+02  7.9e+00  1.0e+00
theta1           9.2e-01  1.2e-02  3.2e-01   0.34   0.94    1.4  7.7e+02  8.7e+00  1.0e+00
theta2           9.6e-01  7.8e-03  2.8e-01   0.44   1.00    1.4  1.3e+03  1.5e+01  1.0e+00
theta3           1.1e-01  2.4e-03  5.3e-02  0.040   0.10   0.21  5.0e+02  5.7e+00  1.0e+00
theta4           5.6e-01  5.0e-03  1.0e-01   0.41   0.55   0.73  4.0e+02  4.5e+00  1.0e+00
theta[1]         9.2e-01  1.2e-02  3.2e-01   0.34   0.94    1.4  7.7e+02  8.7e+00  1.0e+00
theta[2]         9.6e-01  7.8e-03  2.8e-01   0.44   1.00    1.4  1.3e+03  1.5e+01  1.0e+00
theta[3]         1.1e-01  2.4e-03  5.3e-02  0.040   0.10   0.21  5.0e+02  5.7e+00  1.0e+00
theta[4]         5.6e-01  5.0e-03  1.0e-01   0.41   0.55   0.73  4.0e+02  4.5e+00  1.0e+00

Samples were drawn using hmc with nuts.
For each parameter, N_Eff is a crude measure of effective sample size,
and R_hat is the potential scale reduction factor on split chains (at 
convergence, R_hat=1).
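
To make the `R_hat` column concrete, here is a minimal sketch of a split-chain potential scale reduction factor in plain Julia. It follows the textbook Gelman-Rubin construction and is not claimed to reproduce the exact estimator used by `stansummary`; `split_rhat` is a hypothetical helper and the chains are synthetic.

```julia
using Statistics, Random

# Hypothetical helper: split-chain R-hat for a (draws × chains) matrix.
function split_rhat(draws::AbstractMatrix)
    n = size(draws, 1) ÷ 2
    # Split every chain into halves so within-chain drift shows up as disagreement.
    halves = hcat(draws[1:n, :], draws[(end - n + 1):end, :])
    chain_means = vec(mean(halves, dims = 1))
    chain_vars = vec(var(halves, dims = 1))
    W = mean(chain_vars)              # within-chain variance
    B = n * var(chain_means)          # between-chain variance
    var_plus = (n - 1) / n * W + B / n
    return sqrt(var_plus / W)
end

rng = MersenneTwister(42)
mixed = randn(rng, 1000, 4)           # four chains sampling the same target
stuck = mixed .+ [0.0 0.0 0.0 3.0]    # fourth chain offset from the others
println(split_rhat(mixed))            # close to 1 for well-mixed chains
println(split_rhat(stuck))            # clearly above 1 when one chain disagrees
```

The formula also suggests why the `stepsize__` row shows an enormous `R_hat`: the step size is typically held fixed within each chain after warmup, so the within-chain variance `W` is close to zero and the ratio blows up. That row is not a convergence warning for the model parameters.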

 54.174508 seconds (730.00 k allocations: 33.598 MiB, 0.16% gc time)
DiffEqBayes.StanModel{Int64,Mamba.Chains}(0, Object of type "Mamba.Chains"

Iterations = 1:1000
Thinning interval = 1
Chains = 1,2,3,4
Samples per chain = 1000

[4.83058 0.778715 … 0.0324986 0.341292; 6.11107 0.488424 … 0.0751691 0.396128; … ; 6.0256 0.0204567 … 0.146132 0.503971; 5.59766 0.776708 … 0.129993 0.572213]

[3.79971 0.979581 … 0.180407 0.549928; 5.74962 0.814852 … 0.0812866 0.645731; … ; 6.83501 0.971436 … 0.142341 0.500675; 5.40353 0.824716 … 0.202008 0.655223]

[3.05951 0.848388 … 0.166076 0.711124; 3.83639 0.950364 … 0.130988 0.543624; … ; 5.55305 0.98938 … 0.11494 0.567927; 6.82127 0.967026 … 0.0559778 0.513644]

[6.90678 0.996385 … 0.0850368 0.454131; 5.42693 0.904635 … 0.0659168 0.52636; … ; 7.0709 0.971089 … 0.0796397 0.510921; 5.92015 0.966491 … 0.0824837 0.558359])
plot_chain(bayesian_result_stan)