# FitzHugh-Nagumo Bayesian Parameter Estimation Benchmarks

##### Vaibhav Dixit, Chris Rackauckas
using DiffEqBayes, BenchmarkTools

using OrdinaryDiffEq, RecursiveArrayTools, Distributions, ParameterizedFunctions, Mamba
using Plots

gr(fmt=:png)

Plots.GRBackend()


### Defining the problem.

The FitzHugh-Nagumo model is a simplified version of the Hodgkin-Huxley model and is used to describe an excitable system (e.g., a neuron).
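
For reference, the system defined in this benchmark can be written out as a pair of equations. The notation follows the `@ode_def` block in this section: $v$ and $w$ are the two state variables, and $a$, $b$, $\tau_{\mathrm{inv}}$ and $l$ are the four parameters to be estimated.

$$
\begin{aligned}
\frac{dv}{dt} &= v - \frac{v^3}{3} - w + l, \\
\frac{dw}{dt} &= \tau_{\mathrm{inv}}\,(v + a - b\,w).
\end{aligned}
$$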

fitz = @ode_def FitzhughNagumo begin
dv = v - v^3/3 -w + l
dw = τinv*(v +  a - b*w)
end a b τinv l

(::Main.WeaveSandBox9.FitzhughNagumo{getfield(Main.WeaveSandBox9, Symbol("##1#5")),getfield(Main.WeaveSandBox9, Symbol("##2#6")),getfield(Main.WeaveSandBox9, Symbol("##3#7")),Nothing,Nothing,getfield(Main.WeaveSandBox9, Symbol("##4#8")),Expr,Expr}) (generic function with 2 methods)
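
To make the dynamics explicit without the macro, the same right-hand side can be sketched as a plain in-place Julia function. This is only an illustration: the name `fitz!` is hypothetical and does not appear in the original benchmark, which passes the `@ode_def` version to the inference routine. The initial condition and parameter values are the ones used for `prob_ode_fitzhughnagumo` in this section.

```julia
# Plain in-place right-hand side mirroring the @ode_def block above.
# `fitz!` is a hypothetical name used only for this illustration.
function fitz!(du, u, p, t)
    v, w = u                  # state variables
    a, b, τinv, l = p         # parameters, in the same order as in @ode_def
    du[1] = v - v^3/3 - w + l
    du[2] = τinv * (v + a - b*w)
    return nothing
end

u0 = [1.0, 1.0]               # initial condition used in the benchmark
p  = [0.7, 0.8, 1/12.5, 0.5]  # parameter values used to generate the data
du = zeros(2)
fitz!(du, u0, p, 0.0)         # evaluate the derivative at t = 0
println(du)
```

Evaluating the derivative at the initial condition is a quick sanity check that the parameter order matches the one declared after `end` in the `@ode_def` block.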

prob_ode_fitzhughnagumo = ODEProblem(fitz,[1.0,1.0],(0.0,10.0),[0.7,0.8,1/12.5,0.5])
sol = solve(prob_ode_fitzhughnagumo, Tsit5())

retcode: Success
Interpolation: specialized 4th order "free" interpolation
t: 14-element Array{Float64,1}:
0.0
0.15079562872319327
0.6663751602069676
1.45491551203011
2.634172592705079
3.7872847149850593
5.149289290296318
6.764810021639485
7.606013876691694
8.324324391350054
9.040746828687205
9.552469682130221
9.985006344186914
10.0
u: 14-element Array{Array{Float64,1},1}:
[1.0, 1.0]
[1.02428, 1.01095]
[1.09254, 1.04957]
[1.14789, 1.11021]
[1.13454, 1.19755]
[1.04328, 1.27187]
[0.844691, 1.3381]
[0.313544, 1.36894]
[-0.409826, 1.34276]
[-1.40824, 1.27062]
[-1.90978, 1.15634]
[-1.96185, 1.06889]
[-1.95443, 0.996705]
[-1.95386, 0.994246]


The data are generated by adding Gaussian noise to the solution obtained above.

t = collect(range(1,stop=10,length=10))
sig = 0.20
data = convert(Array, VectorOfArray([(sol(t[i]) + sig*randn(2)) for i in 1:length(t)]))

2×10 Array{Float64,2}:
1.24661  1.02895  1.00509  1.0614   …  -0.860215  -1.5491   -1.60432
1.24283  0.60134  1.13515  1.32399      1.62333    1.37337   1.08187
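
Written out, the data-generating step above corresponds to additive Gaussian noise with standard deviation `sig` applied independently to both components at each of the ten sample times. Here $u(t) = (v(t), w(t))$ denotes the interpolated solution `sol(t)`, and $y_i$ is the $i$-th column of `data`:

$$
y_i = u(t_i) + \sigma\,\varepsilon_i, \qquad \varepsilon_i \sim \mathcal{N}(0, I_2), \qquad \sigma = 0.20, \qquad t_i = 1, 2, \dots, 10.
$$

Because `randn` is called without a fixed seed, the numerical values of `data` will differ from run to run.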


### Plot of the data and the solution.

scatter(t, data[1,:])
scatter!(t, data[2,:])
plot!(sol)


### Priors for the parameters passed to the Bayesian inference

priors = [Truncated(Normal(1.0,0.5),0,1.5),Truncated(Normal(1.0,0.5),0,1.5),Truncated(Normal(0.0,0.5),-0.5,0.5),Truncated(Normal(0.5,0.5),0,1)]

4-element Array{Distributions.Truncated{Distributions.Normal{Float64},Distributions.Continuous},1}:
Truncated(Distributions.Normal{Float64}(μ=1.0, σ=0.5), range=(0.0, 1.5))
Truncated(Distributions.Normal{Float64}(μ=1.0, σ=0.5), range=(0.0, 1.5))
Truncated(Distributions.Normal{Float64}(μ=0.0, σ=0.5), range=(-0.5, 0.5))
Truncated(Distributions.Normal{Float64}(μ=0.5, σ=0.5), range=(0.0, 1.0))
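
Since truncated priors assign zero density outside their bounds, it is worth checking that the parameter values used to generate the data fall inside each prior's support. The following is a minimal sketch, assuming a recent Distributions.jl release in which the lowercase `truncated` constructor is available (the benchmark itself uses the older `Truncated(...)` form). The names `check_priors`, `true_p`, `in_support` and `prior_density` are hypothetical.

```julia
using Distributions

# Same priors as in the benchmark, built with the lowercase `truncated`
# constructor (assumption: a Distributions.jl version that provides it).
check_priors = [truncated(Normal(1.0, 0.5), 0.0, 1.5),
                truncated(Normal(1.0, 0.5), 0.0, 1.5),
                truncated(Normal(0.0, 0.5), -0.5, 0.5),
                truncated(Normal(0.5, 0.5), 0.0, 1.0)]

# Parameter values used to generate the data: a, b, τinv, l.
true_p = [0.7, 0.8, 1/12.5, 0.5]

# Each true value should lie between the lower and upper truncation bounds.
in_support = [minimum(d) <= x <= maximum(d) for (d, x) in zip(check_priors, true_p)]

# Prior density at the true values; a zero here would rule the value out.
prior_density = [pdf(d, x) for (d, x) in zip(check_priors, true_p)]

println(in_support)
println(prior_density)
```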


## Parameter Estimation with Stan.jl backend

@time bayesian_result_stan = stan_inference(prob_ode_fitzhughnagumo,t,data,priors;reltol=1e-5,abstol=1e-5,vars =(StanODEData(),InverseGamma(3,2)))

make: /Users/vaibhav/DiffEqBenchmarks.jl/tmp/parameter_estimation_model' is up to date.

Length of data array is not equal to nchains,
all chains will use the first data dictionary.

Inference for Stan model: parameter_estimation_model_model
4 chains: each with iter=(1000,1000,1000,1000); warmup=(0,0,0,0); thin=(1,1,1,1); 4000 iterations saved.

Warmup took (25, 27, 26, 26) seconds, 1.7 minutes total
Sampling took (18, 23, 21, 27) seconds, 1.5 minutes total

                    Mean     MCSE   StdDev     5%    50%    95%    N_Eff  N_Eff/s    R_hat
lp__             5.6e+00  7.6e-02  1.9e+00    2.1    6.0    8.1  6.2e+02  7.1e+00  1.0e+00
accept_stat__    8.4e-01  1.5e-02  2.4e-01   0.17   0.94   1.00  2.6e+02  2.9e+00  1.0e+00
stepsize__       3.9e-02  3.6e-03  5.1e-03  0.033  0.041  0.047  2.0e+00  2.3e-02  6.6e+13
treedepth__      5.9e+00  1.4e-01  9.1e-01    4.0    6.0    7.0  4.5e+01  5.1e-01  1.0e+00
n_leapfrog__     8.1e+01  8.3e+00  3.9e+01     23     63    127  2.2e+01  2.5e-01  1.1e+00
divergent__      2.5e-04      nan  1.6e-02   0.00   0.00   0.00      nan      nan  1.0e+00
energy__        -2.6e+00  9.1e-02  2.5e+00   -6.1   -3.0    2.1  7.8e+02  8.9e+00  1.0e+00
sigma1[1]        2.9e-01  2.3e-03  8.4e-02   0.18   0.28   0.45  1.3e+03  1.5e+01  1.0e+00
sigma1[2]        3.6e-01  3.8e-03  1.0e-01   0.23   0.34   0.54  7.0e+02  7.9e+00  1.0e+00
theta1           9.2e-01  1.2e-02  3.2e-01   0.34   0.94    1.4  7.7e+02  8.7e+00  1.0e+00
theta2           9.6e-01  7.8e-03  2.8e-01   0.44   1.00    1.4  1.3e+03  1.5e+01  1.0e+00
theta3           1.1e-01  2.4e-03  5.3e-02  0.040   0.10   0.21  5.0e+02  5.7e+00  1.0e+00
theta4           5.6e-01  5.0e-03  1.0e-01   0.41   0.55   0.73  4.0e+02  4.5e+00  1.0e+00
theta[1]         9.2e-01  1.2e-02  3.2e-01   0.34   0.94    1.4  7.7e+02  8.7e+00  1.0e+00
theta[2]         9.6e-01  7.8e-03  2.8e-01   0.44   1.00    1.4  1.3e+03  1.5e+01  1.0e+00
theta[3]         1.1e-01  2.4e-03  5.3e-02  0.040   0.10   0.21  5.0e+02  5.7e+00  1.0e+00
theta[4]         5.6e-01  5.0e-03  1.0e-01   0.41   0.55   0.73  4.0e+02  4.5e+00  1.0e+00

Samples were drawn using hmc with nuts.
For each parameter, N_Eff is a crude measure of effective sample size,
and R_hat is the potential scale reduction factor on split chains (at
convergence, R_hat=1).

54.174508 seconds (730.00 k allocations: 33.598 MiB, 0.16% gc time)
DiffEqBayes.StanModel{Int64,Mamba.Chains}(0, Object of type "Mamba.Chains"

Iterations = 1:1000
Thinning interval = 1
Chains = 1,2,3,4
Samples per chain = 1000

[4.83058 0.778715 … 0.0324986 0.341292; 6.11107 0.488424 … 0.0751691 0.396128; … ; 6.0256 0.0204567 … 0.146132 0.503971; 5.59766 0.776708 … 0.129993 0.572213]

[3.79971 0.979581 … 0.180407 0.549928; 5.74962 0.814852 … 0.0812866 0.645731; … ; 6.83501 0.971436 … 0.142341 0.500675; 5.40353 0.824716 … 0.202008 0.655223]

[3.05951 0.848388 … 0.166076 0.711124; 3.83639 0.950364 … 0.130988 0.543624; … ; 5.55305 0.98938 … 0.11494 0.567927; 6.82127 0.967026 … 0.0559778 0.513644]

[6.90678 0.996385 … 0.0850368 0.454131; 5.42693 0.904635 … 0.0659168 0.52636; … ; 7.0709 0.971089 … 0.0796397 0.510921; 5.92015 0.966491 … 0.0824837 0.558359])
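
One way to read the Stan summary is to compare the reported posterior means with the parameter values used to generate the data. The sketch below does this arithmetic with the numbers copied from the summary table, assuming that `theta1` to `theta4` follow the parameter order `a`, `b`, `τinv`, `l` of the model definition. The variable names are hypothetical, and the result is specific to this one run and its random noise draw.

```julia
# Values used to generate the data: a, b, τinv, l.
true_p    = [0.7, 0.8, 1/12.5, 0.5]

# Posterior means and standard deviations as reported for theta1..theta4
# in the Stan summary (assumed to be in the same parameter order).
post_mean = [0.92, 0.96, 0.11, 0.56]
post_std  = [0.32, 0.28, 0.053, 0.10]

# Absolute error of the posterior mean, and the same error expressed
# in units of the posterior standard deviation.
abs_err = abs.(post_mean .- true_p)
z_score = abs_err ./ post_std

println(abs_err)
println(z_score)
```

With these numbers, each reported posterior mean lies within one posterior standard deviation of the corresponding true value.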

plot_chain(bayesian_result_stan)
