Skip to content

Commit

Permalink
Access color schemes through symbols (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill authored Oct 18, 2023
1 parent c7ec885 commit 3df1717
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/VisionHeatmaps.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module VisionHeatmaps

using ColorSchemes: ColorScheme, get, seismic
using ColorSchemes: ColorScheme, colorschemes, get, seismic
using ImageCore

include("heatmap.jl")
Expand Down
8 changes: 6 additions & 2 deletions src/heatmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Visualize 4D arrays as heatmaps, assuming the WHCN convention for input array di
(width, height, color channels, batch dimension).
## Keyword arguments
- `colorscheme::ColorScheme`: Color scheme from ColorSchemes.jl.
- `colorscheme::Union{ColorScheme,Symbol}`: Color scheme from ColorSchemes.jl.
Defaults to `seismic`.
- `reduce::Symbol`: Selects how color channels are reduced to a single number to apply a color scheme.
The following methods can be selected, which are then applied over the color channels
Expand All @@ -32,14 +32,15 @@ Visualize 4D arrays as heatmaps, assuming the WHCN convention for input array di
"""
function heatmap(
val::AbstractArray{T,N};
colorscheme::ColorScheme=DEFAULT_COLORSCHEME,
colorscheme::Union{ColorScheme,Symbol}=DEFAULT_COLORSCHEME,
reduce::Symbol=DEFAULT_REDUCE,
rangescale::Symbol=DEFAULT_RANGESCALE,
permute::Bool=true,
unpack_singleton::Bool=true,
process_batch::Bool=false,
) where {T,N}
N != 4 && throw(InputDimensionError)
colorscheme = get_colorscheme(colorscheme)
if unpack_singleton && size(val, 4) == 1
return single_heatmap(val[:, :, :, 1], colorscheme, reduce, rangescale, permute)
end
Expand All @@ -58,6 +59,9 @@ const InputDimensionError = ArgumentError(
Please reshape your input to match this format if your model doesn't adhere to this convention.",
)

get_colorscheme(c::ColorScheme) = c
get_colorscheme(s::Symbol)::ColorScheme = colorschemes[s]

# Lower level function, mapped along batch dimension
function single_heatmap(
val, colorscheme::ColorScheme, reduce::Symbol, rangescale::Symbol, permute::Bool
Expand Down
4 changes: 4 additions & 0 deletions test/test_heatmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ end
@testset "ColorSchemes" begin
h = heatmap(A; colorscheme=ColorSchemes.inferno)
@test_reference "references/inferno.txt" h

# Test colorscheme symbols
h = heatmap(A; colorscheme=:inferno)
@test_reference "references/inferno.txt" h
end

@testset "Error handling" begin
Expand Down

0 comments on commit 3df1717

Please sign in to comment.