Merge pull request #40 from Abhishek-1Bhatt/lvex
Updated the implementation
ChrisRackauckas authored Jun 15, 2022
2 parents 9fd95d7 + 2739afe · commit faa5481
Showing 1 changed file with 27 additions and 21 deletions.
48 changes: 27 additions & 21 deletions LotkaVolterra/scenario_1.jl
@@ -7,13 +7,16 @@ using OrdinaryDiffEq
using ModelingToolkit
using DataDrivenDiffEq
using LinearAlgebra, Optim
-using DiffEqFlux, Flux
+using Optimization, OptimizationFlux, OptimizationOptimJL #OptimizationFlux for ADAM and OptimizationOptimJL for BFGS
+using DiffEqSensitivity
+using Lux
using Plots
gr()
using JLD2, FileIO
using Statistics
# Set a random seed for reproducible behaviour
using Random
+rng = Random.default_rng()
Random.seed!(1234)

#### NOTE
@@ -43,7 +46,7 @@ t = solution.t

# Add noise in terms of the mean
x̄ = mean(X, dims = 2)
-noise_magnitude = Float32(5e-2)
+noise_magnitude = Float32(5e-3)
Xₙ = X .+ (noise_magnitude*x̄) .* randn(eltype(X), size(X))

plot(solution, alpha = 0.75, color = :black, label = ["True Data" nothing])
@@ -53,15 +56,15 @@ scatter!(t, transpose(Xₙ), color = :red, label = ["Noisy Data" nothing])
rbf(x) = exp.(-(x.^2))

# Multilayer FeedForward
-U = FastChain(
-    FastDense(2,5,rbf), FastDense(5,5, rbf), FastDense(5,5, rbf), FastDense(5,2)
+U = Lux.Chain(
+    Lux.Dense(2,5,rbf), Lux.Dense(5,5, rbf), Lux.Dense(5,5, rbf), Lux.Dense(5,2)
)
-# Get the initial parameters
-p = initial_params(U)
+# Get the initial parameters and state variables of the model
+p, st = Lux.setup(rng, U)

# Define the hybrid model
function ude_dynamics!(du,u, p, t, p_true)
-    û = U(u, p) # Network prediction
+    û = U(u, p, st)[1] # Network prediction
du[1] = p_true[1]*u[1] + û[1]
du[2] = -p_true[4]*u[2] + û[2]
end
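Note on the Lux call convention: `Lux.setup(rng, U)` draws the initial weights from the seeded RNG and returns the parameters and state separately, and every forward pass `U(u, p, st)` returns a `(prediction, state)` tuple, which is why the network output is indexed with `[1]` above. A minimal self-contained sketch of that convention on a toy network (names here are illustrative only):

    using Lux, Random

    rng = Random.default_rng()
    Random.seed!(1234)                       # seeded RNG -> reproducible initial weights
    net = Lux.Chain(Lux.Dense(2, 5, tanh), Lux.Dense(5, 2))
    ps, st = Lux.setup(rng, net)             # parameters and state live outside the model
    ŷ, st′ = net(Float32[1.0, 2.0], ps, st)  # forward pass returns (output, state)
    @assert ŷ == net(Float32[1.0, 2.0], ps, st)[1]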
@@ -75,7 +78,7 @@ prob_nn = ODEProblem(nn_dynamics!,Xₙ[:, 1], tspan, p)
# Define a predictor
function predict(θ, X = Xₙ[:,1], T = t)
Array(solve(prob_nn, Vern7(), u0 = X, p=θ,
-                tspan = (T[1], T[end]), saveat = T,
+                saveat = T,
abstol=1e-6, reltol=1e-6,
sensealg = ForwardDiffSensitivity()
))
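Since `solve` accepts `u0` and `p` overrides, the single `prob_nn` serves every candidate parameter vector during training, with `ForwardDiffSensitivity` supplying the gradients. The reconstruction code further down uses `X̂` and `ts`, which are presumably produced in a collapsed region by calls along these lines (a hypothetical sketch, not the verbatim file):

    # Hypothetical: evaluate the trained predictor on a dense time grid.
    ts = t[1]:0.05f0:t[end]
    X̂ = predict(p_trained, Xₙ[:, 1], ts)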
@@ -90,23 +93,26 @@ end
# Container to track the losses
losses = Float32[]

# Callback to show the loss during training
-callback(θ,l) = begin
-    push!(losses, l)
-    if length(losses)%50==0
-        println("Current loss after $(length(losses)) iterations: $(losses[end])")
-    end
-    false
+callback = function (p, l)
+    push!(losses, l)
+    if length(losses)%50==0
+        println("Current loss after $(length(losses)) iterations: $(losses[end])")
+    end
+    return false
end

## Training

# First train with ADAM for better convergence -> move the parameters into a
# favourable starting position for BFGS
-res1 = DiffEqFlux.sciml_train(loss, p, ADAM(0.1f0), cb=callback, maxiters = 200)
+adtype = Optimization.AutoZygote()
+optf = Optimization.OptimizationFunction((x,p)->loss(x), adtype)
+optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(p))
+res1 = Optimization.solve(optprob, ADAM(0.1), callback=callback, maxiters = 200)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
# Train with BFGS
-res2 = DiffEqFlux.sciml_train(loss, res1.minimizer, BFGS(initial_stepnorm=0.01f0), cb=callback, maxiters = 10000)
+optprob2 = Optimization.OptimizationProblem(optf, res1.minimizer)
+res2 = Optimization.solve(optprob2, Optim.BFGS(initial_stepnorm=0.01), callback=callback, maxiters = 10000)
println("Final training loss after $(length(losses)) iterations: $(losses[end])")

# Plot the losses
@@ -128,7 +134,7 @@ savefig(pl_trajectory, joinpath(pwd(), "plots", "$(svname)_trajectory_reconstruc
# Ideal unknown interactions of the predictor
Ȳ = [-p_[2]*(X̂[1,:].*X̂[2,:])';p_[3]*(X̂[1,:].*X̂[2,:])']
# Neural network guess
-Ŷ = U(X̂,p_trained)
+Ŷ = U(X̂,p_trained,st)[1]

pl_reconstruction = plot(ts, transpose(Ŷ), xlabel = "t", ylabel ="U(x,y)", color = :red, label = ["UDE Approximation" nothing])
plot!(ts, transpose(Ȳ), color = :black, label = ["True Interaction" nothing])
@@ -147,14 +153,14 @@ savefig(pl_overall, joinpath(pwd(), "plots", "$(svname)_reconstruction.pdf"))
# Generate the basis functions, multivariate polynomials up to deg 5
# and sine
b = [polynomial_basis(u, 5); sin.(u)]
-basis = Basis(b, u)
+basis = Basis(b,u);

# Create the thresholds which should be used in the search process
λ = Float32.(exp10.(-7:0.1:0))
# Create an optimizer for the SINDy problem
opt = STLSQ(λ)
# Define different problems for the recovery
-full_problem = ContinuousDataDrivenProblem(solution)
+full_problem = ContinuousDataDrivenProblem(X, t)
ideal_problem = ContinuousDataDrivenProblem(X̂, ts, DX = Ȳ)
nn_problem = ContinuousDataDrivenProblem(X̂, ts, DX = Ŷ)
# Test on ideal derivative data for unknown function (not available)
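Three recovery problems are posed: `full_problem` regresses the complete dynamics from the raw data, with `ContinuousDataDrivenProblem(X, t)` estimating derivatives from the trajectory itself, while `ideal_problem` and `nn_problem` target only the interaction terms by passing the true terms `Ȳ` and the UDE-learned terms `Ŷ` as `DX`. The solve calls sit in a collapsed part of the diff; a hypothetical sketch of the usual DataDrivenDiffEq pattern of this release (the exact keywords are an assumption):

    # Hypothetical: sparse regression on the network's learned interaction terms.
    nn_res = solve(nn_problem, basis, opt, maxiter = 10_000, progress = true)
    println(nn_res)   # summary of the recovered sparse equations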
@@ -190,7 +196,7 @@ plot!(estimate)

# Look at long term prediction
t_long = (0.0f0, 50.0f0)
-estimation_prob = ODEProblem(estimated_dynamics!, u0, t_long, )
+estimation_prob = ODEProblem(recovered_dynamics!, u0, t_long, parameters(nn_res))
estimate_long = solve(estimation_prob, Tsit5(), saveat = 0.1) # Using higher tolerances here causes Julia to exit
plot(estimate_long)
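Integrating the recovered model out to t = 50, well past the short training window, is the real test of the symbolic recovery. A hypothetical comparison against the ground-truth model could look like this (assuming the true right-hand side is named `lotka!` with parameters `p_`, which is an assumption, not verbatim from the file):

    true_prob = ODEProblem(lotka!, u0, t_long, p_)   # `lotka!` is an assumed name
    true_long = solve(true_prob, Tsit5(), saveat = 0.1)
    plot(true_long, color = :black, label = ["True x" "True y"])
    plot!(estimate_long, color = :red, label = ["Recovered x" "Recovered y"])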

