Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add plot animation #91

Merged
merged 2 commits into from
Jul 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SequentialSamplingModels"
uuid = "0e71a2a6-2b30-4447-8742-d083a85e82d1"
authors = ["itsdfish"]
version = "0.11.5"
version = "0.11.6"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand Down
Binary file added docs/src/assets/rdm_animation.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
17 changes: 16 additions & 1 deletion docs/src/plot_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,19 @@ Random.seed!(77)
dist = RDM()
density_kwargs=(;t_range=range(.20, 1.0, length=100),)
plot_model(dist; n_sim=1, add_density=true, density_kwargs, xlims=(0,1.0))
```
```

## Animate

You can animate the evidence accumulation process with the plotting function `animate`, which works similarly to `plot_model`.

```julia
using SequentialSamplingModels
using Plots
using Random
Random.seed!(77)

dist = RDM()
animate(dist)
```
![](assets/rdm_animation.gif)
4 changes: 4 additions & 0 deletions ext/PlotsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import Plots: histogram
import Plots: histogram!
import Plots: plot
import Plots: plot!
import Plots: animate

import SequentialSamplingModels as SSMs
import SequentialSamplingModels: get_plot_defaults
import SequentialSamplingModels: get_model_plot_defaults
Expand All @@ -20,6 +22,7 @@ using KernelDensity
using KernelDensity: Epanechnikov
using LinearAlgebra
using Plots
using Plots: giffn
using SequentialSamplingModels
using SequentialSamplingModels: Approximate
using SequentialSamplingModels: Exact
Expand All @@ -32,4 +35,5 @@ include("plots/plot_model.jl")
include("plots/plot_quantiles.jl")
include("plots/plot_choices.jl")
include("plots/kde.jl")
include("plots/animation.jl")
end
137 changes: 137 additions & 0 deletions ext/plots/animation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""
animate(
model;
add_density = false,
density_kwargs = (),
labels = get_default_labels(model),
density_scale = compute_threshold(model),
model_args = (),
model_kwargs = (),
kwargs...
)

Animates the evidence accumulation process of the specified model.

# Arguments

- `model`: a generic object representing an SSM

# Keywords

- `add_density=false`: add density plot above threshold line if true
- `density_kwargs=()`: pass optional keyword arguments to density plot
- `labels = get_default_labels(model)`: a vector of parameter label options
- `density_scale = compute_threshold(model)`: scale the maximum height of the density
- `file_path = giffn()`: the path in which the animation is saved. By default, it is saved to a temporary folder
called `tmp` using `giffn`.
- `fps = 30`: speed of animation in terms of frames per second
- `model_args = ()`: optional positional arguments passed to the `rand` and `simulate`
- `model_kwargs = ()`: optional keyword arguments passed to the `rand` and `simulate`
- `t_range`: the range of time points over which the probability density is plotted
- `kwargs...`: optional keyword arguments for configuring plot options
"""
function animate(
model;
add_density = false,
density_kwargs = (),
labels = get_default_labels(model),
density_scale = compute_threshold(model),
file_path = giffn(),
fps = 30,
model_args = (),
model_kwargs = (),
t_range = default_range(model),
kwargs...
)
times, evidence = simulate(model, model_args...; model_kwargs...)
y_min = minimum(evidence)

n_subplots = n_options(model)
defaults = get_model_plot_defaults(model)
α = compute_threshold(model)
ylims = (0, maximum(α))
xlims = (0, max(maximum(times) + model.τ * 1.1, maximum(t_range)))

animation = @animate for i ∈ 1:length(times)
model_plot = plot(; defaults..., kwargs...)
add_starting_point!(model, model_plot)

add_threshold!(model, model_plot)
for s ∈ 1:n_subplots
annotate!(labels, subplot = s)
end

model_plot = plot(
model_plot,
times[1:i] .+ model.τ,
evidence[1:i, :];
xlims,
ylims,
defaults...,
kwargs...
)

if add_density
add_density!(
model,
model_plot;
model_args,
model_kwargs,
density_scale,
t_range,
density_kwargs...
)
end
end
return gif(animation, file_path; fps)
end

"""
animate(
model::ContinuousMultivariateSSM;
add_density = false,
density_kwargs = (),
labels = get_default_labels(model),
kwargs...
)

Animates the evidence accumulation process of a continous multivariate sequential sampling model.

# Arguments

- `model::ContinuousMultivariateSSM`: a continous multivariate sequential sampling model

# Keywords

- `add_density=false`: add density plot above threshold line if true
- `density_kwargs=()`: pass optional keyword arguments to density plot
- `labels = get_default_labels(model)`: a vector of parameter label options
- `file_path = giffn()`: the path in which the animation is saved. By default, it is saved to a temporary folder
called `tmp` using `giffn`.
- `fps = 90`: speed of animation in terms of frames per second
- `kwargs...`: optional keyword arguments for configuring plot options
"""
function animate(
model::ContinuousMultivariateSSM;
add_density = false,
density_kwargs = (),
labels = get_default_labels(model),
file_path = giffn(),
fps = 90,
t_range = default_range(model),
kwargs...
)
defaults = get_model_plot_defaults(model)
times, evidence = simulate(model)
animation = @animate for i ∈ 1:4:length(times)
model_plot = plot(
evidence[1:i, 1],
evidence[1:i, 2],
line_z = times;
defaults...,
kwargs...
)
add_threshold!(model, model_plot)
end
return gif(animation, file_path; fps)
end
itsdfish marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion ext/plots/histogram.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
histogram(d::SSM2D; kwargs...)
histogram(d::SSM2D; norm = true, n_sim = 2000, kwargs...)

Plots the histogram of a multi-alternative sequential sampling model.

Expand Down
9 changes: 7 additions & 2 deletions ext/plots/plot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,13 @@ function plot!(
end

"""
plot!([cur_plot], d::ContinuousMultivariateSSM; t_range=default_range(d), kwargs...)

plot!(
cur_plot::Plots.Plot,
d::ContinuousMultivariateSSM;
t_range = default_range(d),
kwargs...
)

Adds the marginal probability density of a multivariate continuous sequential sampling model to an existing plot

# Arguments
Expand Down
17 changes: 13 additions & 4 deletions ext/plots/plot_choices.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
"""
plot_choices(data::AbstractVector{<:Real}, preds::AbstractArray{<:AbstractVector}; kwargs...)
plot_choices(
data::AbstractVector{<:Real},
preds::AbstractArray{<:AbstractVector};
kwargs...
)

Plots choice probability distributions for multi-choice SSMs.

# Arguments

- `data::AbstractVector`: a vector of observed choice proportions
- `data::AbstractVector{<:Real}`: a vector of observed choice proportions
- `preds::AbstractArray{<:AbstractVector}`: an array containing vectors of choice probabilities

# Keywords
Expand All @@ -22,13 +26,18 @@ function plot_choices(
end

"""
plot_choices!(cur_plot::Plots.Plot, data::AbstractVector{<:Real}, preds::AbstractArray{<:AbstractVector}; kwargs...)
plot_choices!(
cur_plot::Plots.Plot,
data::AbstractVector{<:Real},
preds::AbstractArray{<:AbstractVector};
kwargs...
)

Adds to a current plot choice probability distributions for multi-choice SSMs.

# Arguments

- `data::AbstractVector`: a vector of observed choice proportions
- `data::AbstractVector{<:Real}`: a vector of observed choice proportions
- `preds::AbstractArray{<:AbstractVector}`: an array containing vectors of choice probabilities

# Keywords
Expand Down
24 changes: 15 additions & 9 deletions ext/plots/plot_model.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
"""
plot_model(model;
add_density = false,
density_kwargs = (),
labels = get_default_labels(model),
plot_model(
model;
add_density = false,
density_kwargs = (),
labels = get_default_labels(model),
density_scale = compute_threshold(model),
n_sim = 1,
model_args = (),
model_kwargs = (),
model_kwargs = (),
kwargs...
)

Plot the evidence accumulation process of a generic SSM.

# Arguments
Expand Down Expand Up @@ -70,16 +71,21 @@ function plot_model(
model_args,
model_kwargs,
density_scale,
ylims = (y_min, Inf),
density_kwargs...
)
end
return model_plot
end

"""
plot_model(model;
add_density=false, density_kwargs=(), n_sim=1, kwargs...)
plot_model(
model::ContinuousMultivariateSSM;
add_density = false,
density_kwargs = (),
labels = get_default_labels(model),
n_sim = 1,
kwargs...
)

Plot the evidence accumulation process of a continous multivariate sequential sampling model.

Expand Down
30 changes: 24 additions & 6 deletions ext/plots/plot_quantiles.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
"""
plot_quantiles(q_data::AbstractVector{<:AbstractVector}, q_preds::Matrix{<:AbstractVector}; kwargs...)

plot_quantiles(
q_data::AbstractVector{<:AbstractVector},
q_preds::Matrix{<:AbstractVector};
kwargs...
)

Plots the predictive quantile distribution against the quantiles of the data for multi-choice SSMs.

# Arguments
Expand All @@ -27,7 +31,12 @@ function plot_quantiles(
end

"""
plot_quantiles!(cur_plot::Plots.Plot, q_data::AbstractVector{<:AbstractVector}, q_preds::Matrix{<:AbstractVector}; kwargs...)
plot_quantiles!(
cur_plot::Plots.Plot,
q_data::AbstractVector{<:AbstractVector},
q_preds::Matrix{<:AbstractVector};
kwargs...
)

Adds to an existing plot the predictive quantile distribution against the quantiles of the data for multi-choice SSMs.

Expand Down Expand Up @@ -64,7 +73,11 @@ function plot_quantiles!(
end

"""
plot_quantiles(q_data::AbstractVector{<:AbstractVector}, q_preds::Matrix{<:AbstractVector}; kwargs...)
plot_quantiles(
q_data::AbstractVector,
q_preds::AbstractArray{<:AbstractVector};
kwargs...
)

Plots the predictive quantile distribution against the quantiles of the data for single choice SSMs.

Expand All @@ -86,8 +99,13 @@ function plot_quantiles(
end

"""
plot_quantiles!(q_data::AbstractVector{<:AbstractVector}, q_preds::Matrix{<:AbstractVector}; kwargs...)

plot_quantiles!(
cur_plot::Plots.Plot,
q_data::AbstractVector,
q_preds::AbstractArray{<:AbstractVector};
kwargs...
)

Adds to an existing plot the predictive quantile distribution against the quantiles of the data for single choice SSMs.

# Arguments
Expand Down
Loading