Add Plotting Recipes
Micki-D committed Apr 3, 2024
1 parent 823f61f commit 1175a80
Showing 5 changed files with 247 additions and 20 deletions.
1 change: 1 addition & 0 deletions Project.toml
@@ -16,6 +16,7 @@ Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MonotonicSplines = "568f7cb4-8305-41bc-b90d-d32b39cc99d1"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
ValueShapes = "136a8f8c-c49b-4edb-8b98-f3d64d48be8f"
2 changes: 2 additions & 0 deletions src/AdaptiveFlows.jl
@@ -20,13 +20,15 @@ using Lux
using MonotonicSplines
using Optimisers
using Random
using RecipesBase
using Statistics
using StatsFuns
using ValueShapes
using Zygote

include("adaptive_flows.jl")
include("optimize_flow.jl")
include("plotting.jl")
include("rqspline_coupling.jl")
include("utils.jl")
end # module
2 changes: 1 addition & 1 deletion src/adaptive_flows.jl
@@ -215,7 +215,7 @@ function build_flow(target_samples::AbstractArray, modules::Vector = [InvMulAdd,
stds = vec(std(flat_samples, dims = 2))
means = vec(mean(flat_samples, dims = 2))

flow_modules[1] = modules[1] isa Function ? typeof(modules[1])(Diagonal(stds), means) : modules[1](Diagonal(stds), means)
flow_modules[1] = modules[1] isa Function ? typeof(modules[1])(Matrix(Diagonal(stds)), means) : modules[1](Matrix(Diagonal(stds)), means)
end

for (i, flow_module) in enumerate(modules[2:end])
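The single functional change above swaps Diagonal(stds) for Matrix(Diagonal(stds)), so the first flow module is built from a dense matrix. A minimal sketch of the difference (plain LinearAlgebra, nothing AdaptiveFlows-specific; the motivation is presumably to avoid special-casing structured matrix types in downstream AD rules and optimizer state):

using LinearAlgebra

stds = [0.5, 2.0]
D = Diagonal(stds)           # 2×2 Diagonal{Float64, Vector{Float64}}: structured, lazy
M = Matrix(Diagonal(stds))   # 2×2 Matrix{Float64}: dense copy [0.5 0.0; 0.0 2.0]

# Both behave identically as linear maps ...
D * [1.0, 1.0] == M * [1.0, 1.0]   # true
# ... but M is a plain dense array, so no Diagonal-specific code paths are hit.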
78 changes: 59 additions & 19 deletions src/optimize_flow.jl
@@ -3,32 +3,57 @@
std_normal_logpdf(x::Real) = -(abs2(x) + log2π)/2
std_normal_logpdf(x::AbstractArray) = vec(sum(std_normal_logpdf.(flatview(x)), dims = 1))
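As a quick sanity check, this helper agrees with the standard-normal logpdf that StatsFuns already exports (normlogpdf):

using StatsFuns
std_normal_logpdf(1.3) ≈ normlogpdf(1.3)   # true: both evaluate -(x^2 + log(2π))/2 at x = 1.3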

function negll_flow_loss(flow::F, x::AbstractMatrix{<:Real}, logpdf::Function) where F<:AbstractFlow
function negll_flow_loss(flow::F, x::AbstractMatrix{<:Real}, cum_ladj::AbstractVector, logpdf::Function) where F<:AbstractFlow
nsamples = size(x, 2)
flow_corr = fchain(flow, logpdf.f)
y, ladj = with_logabsdet_jacobian(flow_corr, x)
y, ladj_tmp = with_logabsdet_jacobian(flow_corr, x)
ladj = cum_ladj + vec(ladj_tmp)
ll = (sum(logpdf.logdensity(y)) + sum(ladj)) / nsamples
return -ll
end

function negll_flow(flow::F, x::AbstractMatrix{<:Real}, logd_orig::AbstractVector, logpdf::Tuple{Function, Function}) where F<:AbstractFlow
negll, back = Zygote.pullback(negll_flow_loss, flow, x, logpdf[2])
function negll_flow(flow::F, x::AbstractMatrix{<:Real}, logd_orig::AbstractVector, cum_ladj::AbstractVector, logpdf::Tuple{Function, Function}) where F<:AbstractFlow
negll, back = Zygote.pullback(negll_flow_loss, flow, x, cum_ladj, logpdf[2])
d_flow = back(one(eltype(x)))[1]
return negll, d_flow
end
export negll_flow

function KLDiv_flow_loss(flow::F, x::AbstractMatrix{<:Real}, logd_orig::AbstractVector, logpdfs::Tuple{Function, Function}) where F<:AbstractFlow
function KLDiv_flow_loss(flow::F, x::AbstractMatrix{<:Real}, logd_orig::AbstractVector, cum_ladj::AbstractVector, logpdfs::Tuple{Function, Function}) where F<:AbstractFlow
nsamples = size(x, 2)
flow_corr = fchain(flow, logpdfs[2].f)
logpdf_y = logpdfs[2].logdensity
y, ladj = with_logabsdet_jacobian(flow_corr, x)
KLDiv = sum(exp.(logd_orig - vec(ladj)) .* (logd_orig - vec(ladj) - logpdf_y(y))) / nsamples
y, ladj_tmp = with_logabsdet_jacobian(flow_corr, x)
ladj = cum_ladj + vec(ladj_tmp)

q = logd_orig - vec(ladj)
p = logpdf_y(y)

# Variant (1), the active estimator: weight exp.(q) by the log density ratio (q - p)
KLDiv = sum(exp.(q) .* (q - p)) / nsamples

# Alternative estimators kept for reference:
# KLDiv = (sum(exp.(q) .* (q - p)) + sum(exp.(p) .* (p - q))) / nsamples # composite (symmetrized)
# KLDiv = sum(exp.(p + vec(ladj)) .* (p + vec(ladj) - logd_orig)) / nsamples # reverse form (MALA paper); (2)/(3) with logpdfs[2] = target
# KLDiv = sum(exp.(logd_orig) .* (q - p)) / nsamples # (too tight)
# KLDiv = sum(exp.(p + vec(ladj)) .* (vec(ladj) - logd_orig)) / nsamples
# KLDiv = sum(exp.(q) .* (q - p - logpdfs[1].logdensity(x))) / nsamples
return KLDiv
end

function KLDiv_flow(flow::F, x::AbstractMatrix{<:Real}, logd_orig::AbstractVector, logpdfs::Tuple{Function, Function}) where F<:AbstractFlow
KLDiv, back = Zygote.pullback(KLDiv_flow_loss, flow, x, logd_orig, logpdfs)
function KLDiv_flow(flow::F, x::AbstractMatrix{<:Real}, logd_orig::AbstractVector, cum_ladj::AbstractVector, logpdfs::Tuple{Function, Function}) where F<:AbstractFlow
KLDiv, back = Zygote.pullback(KLDiv_flow_loss, flow, x, logd_orig, cum_ladj, logpdfs)
d_flow = back(one(eltype(x)))[1]
return KLDiv, d_flow
end
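The signatures above now thread a cumulative log-abs-det-Jacobian (cum_ladj) through both losses: the Jacobian terms of components trained earlier are added to the current component's ladj_tmp. A toy scalar sketch of why these terms simply add along a chain (illustrative only, not package API):

# log|det J| of a composition is the sum of the components' log|det J|s,
# which is exactly what cum_ladj accumulates during sequential training.
f1(x) = 2x;  ladj1(x) = log(abs(2))      # |det J_f1| = 2
f2(x) = x^3; ladj2(x) = log(abs(3x^2))   # |det J_f2| = 3x^2

x = 0.7
chained = ladj2(f1(x)) + ladj1(x)        # cum_ladj + ladj_tmp, in miniature
direct = log(abs(24x^2))                 # d/dx (2x)^3 = 24x^2
chained ≈ direct                         # true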
@@ -88,7 +113,7 @@ function optimize_flow(samples::Union{AbstractArray, Tuple{AbstractArray, Abstra
flow, state, loss_hist = _train_flow(samples, initial_flow, optimizer, nepochs, nbatches, loss, pushfwd_logpdf, logd_orig, shuffle_samples)
end

return (result = flow, optimizer_state = state, loss_hist = vcat(loss_history, loss_hist))
return (result = flow, optimizer_state = state, loss_hist = vcat(loss_history, loss_hist), training_metadata = Dict(:nepochs => nepochs, :nbatches => nbatches, :shuffle_samples => shuffle_samples, :sequential => sequential, :optimizer => optimizer, :loss => loss))
end
export optimize_flow

@@ -101,7 +126,9 @@ function _train_flow_sequentially(samples::Union{AbstractAr
pushfwd_logpdf::Union{Function,
Tuple{Function, Function}},
logd_orig::AbstractVector,
shuffle_samples::Bool)
shuffle_samples::Bool;
cum_ladj::AbstractVector = zeros(length(logd_orig))
)

if !_is_trainable(initial_flow)
return initial_flow, nothing, nothing
@@ -112,7 +139,6 @@ function _train_flow_sequentially(samples::Union{AbstractAr
component_optstates = Vector{Any}()
component_loss_hists = Vector{Any}()
intermediate_samples = samples
logd_orig_intermediate = logd_orig

for flow_component in initial_flow.flow.fs
trained_flow_component, component_opt_state, component_loss_hist = _train_flow_sequentially(intermediate_samples,
@@ -122,8 +148,10 @@
nbatches,
loss,
pushfwd_logpdf,
logd_orig_intermediate,
shuffle_samples)
logd_orig,
shuffle_samples;
cum_ladj
)
push!(trained_components, trained_flow_component)
push!(component_optstates, component_opt_state)
push!(component_loss_hists, component_loss_hist)
@@ -133,16 +161,16 @@
intermediate_samples = (x_int, trained_flow_component(intermediate_samples[2]))
# workaround until AffineMaps returns a per-sample (row-matrix) ladj
ladj = ladj isa Real ? fill(ladj, length(logd_orig)) : vec(ladj)
logd_orig_intermediate -= ladj
cum_ladj += ladj
else
intermediate_samples, ladj = with_logabsdet_jacobian(trained_flow_component, intermediate_samples)
ladj = ladj isa Real ? fill(ladj, length(logd_orig)) : vec(ladj)
logd_orig_intermediate -= ladj
cum_ladj += ladj
end
end
return typeof(initial_flow)(trained_components), component_optstates, component_loss_hists
end
_train_flow(samples, initial_flow, optimizer, nepochs, nbatches, loss, pushfwd_logpdf, logd_orig, shuffle_samples)
_train_flow(samples, initial_flow, optimizer, nepochs, nbatches, loss, pushfwd_logpdf, logd_orig, shuffle_samples; cum_ladj)
end


@@ -154,7 +182,9 @@ function _train_flow(samples::Union{AbstractArray, Tuple{AbstractArray, Abstract
loss::Function,
pushfwd_logpdf::Union{Function, Tuple{Function, Function}},
logd_orig::AbstractVector,
shuffle_samples::Bool)
shuffle_samples::Bool;
cum_ladj::AbstractVector = zeros(length(logd_orig))
)

if !_is_trainable(initial_flow)
return initial_flow, nothing, nothing
@@ -163,13 +193,23 @@
batchsize = round(Int, n_samples / nbatches)
batches = samples isa Tuple ? collect.(Iterators.partition.(samples, batchsize)) : collect(Iterators.partition(samples, batchsize))
logd_orig_batches = collect(Iterators.partition(logd_orig, batchsize))
cum_ladj_batches = collect(Iterators.partition(cum_ladj, batchsize))
flow = deepcopy(initial_flow)
state = Optimisers.setup(optimizer, deepcopy(initial_flow))
loss_hist = Vector{Float64}()
for i in 1:nepochs
for j in 1:nbatches
training_samples = batches isa Tuple ? (Matrix(flatview(batches[1][j])), Matrix(flatview(batches[2][j]))) : Matrix(flatview(batches[j]))
loss_val, d_flow = loss(flow, training_samples, logd_orig_batches[j], pushfwd_logpdf)
loss_val, d_flow = loss(flow, training_samples, logd_orig_batches[j], cum_ladj_batches[j], pushfwd_logpdf)
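# stash first- and second-batch gradients in globals for debugging/inspection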
if i == 1 && j == 1 && flow.mask[1]
global g_state_gradient_1 = (loss_val, d_flow)
end

if i == 1 && j == 2 && flow.mask[1]
global g_state_gradient_2 = (loss_val, d_flow)
end

state, flow = Optimisers.update(state, flow, d_flow)
push!(loss_hist, loss_val)
end
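With the extended return value of optimize_flow, the training configuration now travels with the result. A hypothetical usage sketch (the optimizer choice and the argument order beyond samples are assumptions, not shown in this diff):

using AdaptiveFlows, Optimisers

samples = randn(4, 10_000)   # toy 4-dimensional target samples
res = optimize_flow(samples, build_flow(samples), Optimisers.Adam(3f-4))

res.result              # the trained flow
res.loss_hist           # per-batch loss values
res.training_metadata   # Dict(:nepochs => ..., :nbatches => ..., :loss => ..., ...)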
184 changes: 184 additions & 0 deletions src/plotting.jl
@@ -0,0 +1,184 @@
# This file is a part of AdaptiveFlows.jl, licensed under the MIT License (MIT).

@recipe function plot_flow(flow::AbstractFlow, samples::AbstractMatrix;
d_sel = nothing,
n_bins_1D = 100,
n_bins_2D = 100,
colorbar = false,
h_x = true,
h_y = true,
fs = nothing,
size_plot = (1000, 1000),
p = x -> 1/sqrt(2pi) * exp(-x^2/2),
training_metadata = nothing
)
n_dims = size(samples, 1)

if isnothing(d_sel)
d_sel = 1:min(6, n_dims)
end
n_plots = length(d_sel)

if isnothing(fs)
fs = n_plots > 5 ? 6 : n_plots > 3 ? 7 : 11
end

samples_transformed = flow(samples)

stds_in = vec(std(samples, dims = 2))
means_in = vec(mean(samples, dims = 2))
samples_in = InvMulAdd(Diagonal(stds_in), means_in)(samples)

stds_out = vec(std(samples_transformed, dims = 2))
means_out = vec(mean(samples_transformed, dims = 2))
samples_out = InvMulAdd(Diagonal(stds_out), means_out)(samples_transformed)
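# Standardize inputs and outputs (zero mean, unit variance per dimension)
# so the 1D histograms line up with the reference density p on the diagonal.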

if !isnothing(training_metadata)
training_metadata_labels = ["[Training metadata] ",
"Loss: $(training_metadata[:loss]) ",
"Optimizer: ",
"$(training_metadata[:optimizer]) ",
"# Batches: $(training_metadata[:nbatches]) ",
"# Epochs: $(training_metadata[:nepochs]) ",
"Sequential?: $(training_metadata[:sequential]) ",
"Shuffle samples?: $(training_metadata[:shuffle_samples]) "]
end
samples_metadata = ["[Samples metadata] ",
"# Samples: $(size(samples,2)) ",
"# Dimensions: $n_dims ",
"Displayed Dimensions: ",
"$(d_sel) ",
" ", " ", " "]

layout --> (n_plots + 1, n_plots)
size --> size_plot
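# Grid: an n_plots × n_plots matrix of pairwise panels plus one extra row
# (subplots n_plots^2 + 1 ... n_plots^2 + n_plots), used below for the text legends.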

for i in 1:n_plots
lo = min(minimum(samples_in[d_sel[i],:]), minimum(samples_out[d_sel[i],:]))
hi = max(maximum(samples_in[d_sel[i],:]), maximum(samples_out[d_sel[i],:]))
bin_range_1D = range(lo, stop = hi, length = n_bins_1D)

# Diagonal
subplot := i + (i-1) * n_plots
tickfontsize --> fs
labelfontsize --> fs
if h_x
@series begin
seriestype --> :stephist
bins --> bin_range_1D
normalize --> :pdf
label --> false
color --> :blue
fill --> true
alpha --> 0.3
samples_in[d_sel[i],:]
end
end
if h_y
@series begin
seriestype --> :stephist
bins --> bin_range_1D
normalize --> :pdf
label --> false
color --> :red
fill --> true
alpha --> 0.3
samples_out[d_sel[i],:]
end
end
@series begin
seriestype --> :line
lw --> 1.5
color --> :black
label --> false
p
end

for j in i+1:n_plots
# Input, lower off-diagonal
subplot := i + (j - 1) * n_plots
@series begin
seriestype --> :histogram2d
bins --> n_bins_2D
color --> :blues
colorbar --> colorbar
background --> :white
aspect_ratio --> :equal
tickfontsize --> fs
labelfontsize --> fs
xlabel --> "x$((d_sel[i]))"
ylabel --> "x$((d_sel[j]))"
(samples_in[d_sel[i],:], samples_in[d_sel[j],:])
end

# Output upper off-diagonal
subplot := j + (i - 1) * n_plots

@series begin
seriestype --> :histogram2d
bins --> n_bins_2D
color --> :reds
colorbar --> colorbar
aspect_ratio --> :equal
tickfontsize --> fs
labelfontsize --> fs
xlabel --> "y$(d_sel[i])"
ylabel --> "y$(d_sel[j])"
(samples_out[d_sel[i],:], samples_out[d_sel[j],:])
end
end

subplot := n_plots^2 + n_plots + 1 - i

if i == n_plots && !isnothing(training_metadata)
for k in 1:8
@series begin
seriestype --> :scatter
ticks --> false
framestyle --> :none
legend --> :bottomleft
legendfontsize --> fs
label --> training_metadata_labels[k]
markeralpha --> 0
[0]
end
end
elseif i == 1
for k in 1:8
@series begin
seriestype --> :scatter
ticks --> false
framestyle --> :none
legend --> :bottomright
legendfontsize --> fs
label --> samples_metadata[k]
markeralpha --> 0
[0]
end
end
else
@series begin
seriestype --> :scatter
ticks --> false
framestyle --> :none
legend --> false
markeralpha --> 0
[0]
end
end
end
end

@recipe function plot_flow_res(res::Union{NamedTuple{(:result, :optimizer_state, :loss_hist, :training_metadata), Tuple{CompositeFlow, Vector{Any}, Vector{Any}, Dict{Symbol, Any}}},
NamedTuple{(:result, :optimizer_state, :loss_hist, :training_metadata), Tuple{CompositeFlow, NamedTuple{(:flow,), Tuple{NamedTuple{(:fs,), Tuple{Vector{NamedTuple}}}}}, Vector{Float64}, Dict{Symbol, Any}}}},
samples::AbstractMatrix;
d_sel = nothing,
n_bins_1D = 100,
n_bins_2D = 100,
colorbar = false,
h_x = true,
h_y = true,
fs = nothing,
size_plot = (1000, 1000),
p = x -> 1/sqrt(2pi) * exp(-x^2/2)
)
@series begin
n_bins_1D := n_bins_1D
n_bins_2D := n_bins_2D
d_sel := d_sel
colorbar := colorbar
h_x := h_x
h_y := h_y
fs := fs
size := size_plot
p := p
training_metadata := res.training_metadata
(res.result, samples)
end
end
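A hypothetical usage sketch for the two recipes (assuming a Plots.jl backend; res constructed as in the optimize_flow sketch above):

using Plots, Optimisers, AdaptiveFlows

samples = randn(4, 10_000)
res = optimize_flow(samples, build_flow(samples), Optimisers.Adam(3f-4))

plot(res, samples)                       # full result: panel grid plus training metadata
plot(res.result, samples; d_sel = 1:3)   # flow only, restricted to the first three dimensions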
