# Mendes Multistate Model

##### Samuel Isaacson, Chris Rackauckas

Taken from Gupta and Mendes, "An Overview of Network-Based and -Free Approaches for Stochastic Simulation of Biochemical Systems", *Computation*, 6 (9), 2018.

using DiffEqBase, DiffEqBiological, DiffEqJump, DiffEqProblemLibrary.JumpProblemLibrary, Plots, Statistics
# Select the GR plotting backend.
gr()
# Output format passed to every plot call below.
fmt = :png
# Bring the library's predefined jump problems into scope; presumably this is what
# defines `prob_jump_multistate` used below — verify against JumpProblemLibrary.
JumpProblemLibrary.importjumpproblems()


# Plot the solution computed by each method

# Jump aggregators (SSA variants) to compare; the first one is the baseline.
methods = (Direct(),DirectFW(),FRM(),FRMFW(),SortingDirect(),NRM(),DirectCR(),RSSA())
# Short display name of each method, taken from its type name (e.g. "Direct").
# Using `nameof` avoids slicing the printed representation, whose exact form
# (with or without a module prefix) is not something to rely on.
shortlabels = [string(nameof(typeof(m))) for m in methods]
jprob   = prob_jump_multistate
# Simulate to ten times the problem's `tstop`.
tf      = 10.0*jprob.tstop
prob    = DiscreteProblem(jprob.u0, (0.0,tf), jprob.rates)
rn      = jprob.network
varlegs = ["A_P", "A_bound_P", "A_unbound_P", "RLA_P"]
# Species summed to form each plotted observable (one entry per label in `varlegs`).
varsyms = [
    [:S7,:S8,:S9],
    [:S9],
    [:S7,:S8],
    [:S7]
]
# Index of every species of every observable within the network's species list.
# Fails early with a clear message if a symbol is not a species of `rn`.
varidxs = map(varsyms) do vars
    map(vars) do sym
        idx = findfirst(isequal(sym), rn.syms)
        idx === nothing &&
            throw(ArgumentError("species $sym not found in the reaction network"))
        idx
    end
end

# Simulate each method once and plot the summed observables over time,
# one subplot per method.
p = []
for (k, method) in enumerate(methods)
    jump_prob = JumpProblem(prob, method, rn, save_positions=(false,false))
    sol = solve(jump_prob, SSAStepper(), saveat=tf/1000.)
    # One row per saved time point, one column per observable. Sizing from the
    # solution avoids relying on `saveat` producing exactly 1001 points.
    solv = zeros(length(sol.t), length(varidxs))
    for (j, varidx) in enumerate(varidxs)
        # Sum the selected species' counts at each saved time.
        solv[:, j] = vec(sum(sol[varidx, :], dims=1))
    end
    # Only the last subplot carries a legend. Plots takes per-series labels
    # as a 1×n row, hence `permutedims`.
    if k < length(methods)
        push!(p, plot(sol.t, solv, title=shortlabels[k], legend=false, format=fmt))
    else
        push!(p, plot(sol.t, solv, title=shortlabels[k], legend=true,
                      labels=permutedims(varlegs), format=fmt))
    end
end
plot(p..., format=fmt)


# Benchmarking performance of the methods

"""
    run_benchmark!(t, jump_prob, stepper)

Time `solve(jump_prob, stepper)` once per element of `t`, storing each wall-clock
duration (seconds, from `@elapsed`) in `t`. One untimed solve runs first so that
first-call compilation is not included in the timings. Returns `t`.
"""
function run_benchmark!(t, jump_prob, stepper)
    # Warm-up call: the first call to a Julia method compiles it; keep that out of `t`.
    solve(jump_prob, stepper)
    for i in eachindex(t)
        t[i] = @elapsed solve(jump_prob, stepper)
    end
    return t
end

# run_benchmark! (generic function with 1 method)

nsims = 100
# For each method (in order), a vector of `nsims` solve timings in seconds.
benchmarks = Vector{Float64}[]
for method in methods
    timings = Vector{Float64}(undef, nsims)
    run_benchmark!(timings,
                   JumpProblem(prob, method, rn, save_positions=(false,false)),
                   SSAStepper())
    push!(benchmarks, timings)
end

# Per-method summary statistics of the timing samples, one entry per method.
medtimes = median.(benchmarks)
stdtimes = std.(benchmarks)
avgtimes = mean.(benchmarks)
using DataFrames

# Timing summary table. `relmedtimes` is each median divided by the first
# method's median; `cv` is the coefficient of variation (std / mean).
df = DataFrame(
    names       = shortlabels,
    medtimes    = medtimes,
    relmedtimes = medtimes ./ first(medtimes),
    avgtimes    = avgtimes,
    std         = stdtimes,
    cv          = stdtimes ./ avgtimes,
)

# Bar chart of relative median times, each bar annotated with its absolute
# median time in seconds.
sa = map(df.medtimes) do mt
    text(string(round(mt, digits=3), "s"), :center, 12)
end
bar(df.names, df.relmedtimes, legend=false, fmt=fmt)
# Invisible markers just above each bar carry the text annotations.
scatter!(df.names, df.relmedtimes .+ 0.05, markeralpha=0, series_annotations=sa, fmt=fmt)
ylabel!("median relative to Direct")
title!("Multistate Model")

using DiffEqBenchmarks
## Appendix

These benchmarks are a part of the DiffEqBenchmarks.jl repository, found at: https://github.com/JuliaDiffEq/DiffEqBenchmarks.jl

To run this benchmark locally, execute the following commands:

using DiffEqBenchmarks
# Weave this benchmark's source, `Jumps/Mendes_multistate_example.jmd`.
DiffEqBenchmarks.weave_file("Jumps","Mendes_multistate_example.jmd")

