
Commit 601bb57

Add nodal mask
Add mask support for pointwise operations Remove support for SpectralElementSpaceSlab, support stencil ops Apply formatter CUDA Fixes More fixes Updates + reviewer comments Fixes Improve threading for masked cuda operations Add more docs More docs, + fixes, need to fix doc example + cpu unit test dispatch, test, and gpu fixes Fixed UB in tests, update docs Update docs Fix git conflic artifact Update docs
1 parent c139e78 commit 601bb57

25 files changed

+767
-82
lines changed

.buildkite/pipeline.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -265,8 +265,8 @@ steps:
265265
key: "gpu_cuda_spaces"
266266
command:
267267
- "julia --project=.buildkite -e 'using CUDA; CUDA.versioninfo()'"
268-
- "srun julia --color=yes --check-bounds=yes --project=.buildkite test/Spaces/unit_spaces.jl"
269-
- "srun julia --color=yes --check-bounds=yes --project=.buildkite test/Spaces/opt_spaces.jl"
268+
- "julia --color=yes --check-bounds=yes --project=.buildkite test/Spaces/unit_spaces.jl"
269+
- "julia --color=yes --check-bounds=yes --project=.buildkite test/Spaces/opt_spaces.jl"
270270
- "julia --color=yes --check-bounds=yes --project=.buildkite test/Spaces/unit_high_resolution_space.jl"
271271
env:
272272
CLIMACOMMS_DEVICE: "CUDA"

NEWS.md

+2
@@ -4,6 +4,8 @@ ClimaCore.jl Release Notes
44
main
55
-------
66

7+
- `SpectralElementSpace2D` constructors now support nodal masks. PR [2201](https://github.com/CliMA/ClimaCore.jl/pull/2201). See its documentation [here](https://clima.github.io/ClimaCore.jl/dev/masks). Note that it does not yet support restarts.
8+
79
- Added support for InputOutput with PointSpaces
810
PR [2162](https://github.com/CliMA/ClimaCore.jl/pull/2162).
911

docs/make.jl

+1
@@ -86,6 +86,7 @@ withenv("GKSwstype" => "nul") do
8686
tutorial in TUTORIALS
8787
],
8888
"Examples" => "examples.md",
89+
"Masks" => "masks.md",
8990
"Debugging" => "debugging.md",
9091
"Libraries" => [
9192
joinpath("lib", "ClimaCorePlots.md"),

docs/src/masks.md

+199
@@ -0,0 +1,199 @@
1+
# Masks
2+
3+
## Motivation
4+
5+
ClimaCore spaces, `SpectralElementSpace2D`s in particular, support masks, where
6+
users can set horizontal nodal locations where operations are skipped.
7+
8+
This is especially helpful for the land model, which may have degrees of
9+
freedom over the ocean but should not evaluate expressions in regions where
10+
data is missing.
11+
12+
Masks in ClimaCore solve this by prescribing, ahead of time, the
13+
regions to skip. This improves both ergonomics and performance.
14+
15+
## User interface
16+
17+
There are two user-facing parts for ClimaCore masks:
18+
19+
- set the `enable_mask = true` keyword in the space constructor (when available),
20+
which is currently any constructor that returns/contains a `SpectralElementSpace2D`.
21+
- use `set_mask!` to set where the mask is `true` (where compute should occur)
22+
and `false` (where compute should be skipped)
23+
24+
Here is an example:
25+
26+
```julia
27+
using ClimaComms
28+
ClimaComms.@import_required_backends
29+
import ClimaCore: Spaces, Fields
30+
using ClimaCore.CommonSpaces
31+
using Test
32+
33+
FT = Float64
34+
ᶜspace = ExtrudedCubedSphereSpace(FT;
35+
z_elem = 10,
36+
z_min = 0,
37+
z_max = 1,
38+
radius = 10,
39+
h_elem = 10,
40+
n_quad_points = 4,
41+
staggering = CellCenter(),
42+
enable_mask = true,
43+
)
44+
45+
# How to set the mask
46+
Spaces.set_mask!(ᶜspace) do coords
47+
coords.lat > 0.5
48+
end
49+
# Or
50+
mask = Fields.Field(FT, ᶜspace)
51+
mask .= map(cf -> cf.lat > 0.5 ? 1.0 : 0.0, Fields.coordinate_field(mask)) # 1 = active, matching the predicate above
52+
Spaces.set_mask!(ᶜspace, mask)
53+
```
54+
55+
Finally, operations over fields will be skipped where `mask == 0`, and applied
56+
where `mask == 1`:
57+
58+
```julia
59+
@. f = 1 # only applied where the mask is equal to 1
60+
```
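
The skip-or-apply semantics can be pictured with plain arrays. The sketch below does not use ClimaCore: `masked_fill!` is a hypothetical helper and `is_active` is an ordinary `Bool` vector standing in for the mask, both assumed only for illustration.

```julia
# Toy model of a mask-aware assignment on plain arrays (not ClimaCore API).
# `is_active[k] == true` means "compute here"; `false` means "skip".
function masked_fill!(dest::AbstractVector, value, is_active::AbstractVector{Bool})
    @assert length(dest) == length(is_active)
    for k in eachindex(dest, is_active)
        if is_active[k]
            dest[k] = value   # written only at active locations
        end
    end
    return dest
end

f = zeros(6)
is_active = [true, false, true, true, false, false]
masked_fill!(f, 1.0, is_active)   # inactive entries keep their previous value (0.0)
```

Entries at inactive locations are left untouched rather than reset, which is why the contents of masked-out regions depend on whatever was stored there before.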
61+
62+
## Example script
63+
64+
Here is a more complex script where the mask is used:
65+
66+
```julia
67+
using ClimaComms
68+
ClimaComms.@import_required_backends
69+
import ClimaCore: Spaces, Fields, DataLayouts, Geometry, Operators
70+
using ClimaCore.CommonSpaces
71+
using Test
72+
73+
FT = Float64
74+
ᶜspace = ExtrudedCubedSphereSpace(FT;
75+
z_elem = 10,
76+
z_min = 0,
77+
z_max = 1,
78+
radius = 10,
79+
h_elem = 10,
80+
n_quad_points = 4,
81+
staggering = CellCenter(),
82+
enable_mask = true,
83+
)
84+
ᶠspace = Spaces.face_space(ᶜspace)
85+
ᶠcoords = Fields.coordinate_field(ᶠspace)
86+
87+
# How to set the mask
88+
Spaces.set_mask!(ᶜspace) do coords
89+
coords.lat > 0.5
90+
end
91+
92+
# We also support the syntax `Spaces.set_mask!(::AbstractSpace, ::Field)`
93+
94+
# We can check the mask directly: (internals, only for demonstrative purposes)
95+
mask = Spaces.get_mask(ᶜspace)
96+
@test count(parent(mask.is_active)) == 4640
97+
@test length(parent(mask.is_active)) == 9600
98+
99+
# Let's skip operations that use fill!
100+
ᶜf = zeros(ᶜspace) # ignores mask
101+
@. ᶜf = 1 # tests fill! # abides by mask
102+
103+
# Let's show that 4640 columns were impacted:
104+
@test count(x->x==1, parent(ᶜf)) == 4640 * Spaces.nlevels(axes(ᶜf))
105+
@test length(parent(ᶜf)) == 9600 * Spaces.nlevels(axes(ᶜf))
106+
107+
# Let's skip operations that use copyto!
108+
ᶜz = Fields.coordinate_field(ᶜspace).z
109+
ᶜf = zeros(ᶜspace)
110+
@. ᶜf = 1 + 0 * ᶜz # tests copyto!
111+
112+
# Let's again show that 4640 columns were impacted:
113+
@test count(x->x==1, parent(ᶜf)) == 4640 * Spaces.nlevels(axes(ᶜf))
114+
@test length(parent(ᶜf)) == 9600 * Spaces.nlevels(axes(ᶜf))
115+
116+
# Let's skip operations in FiniteDifference operators
117+
ᶠf = zeros(ᶠspace)
118+
c = Fields.Field(FT, ᶜspace)
119+
div = Operators.DivergenceF2C()
120+
foo(f, cf) = cf.lat > 0.5 ? zero(f) : oftype(f, NaN) # NaN in masked-out regions (`sqrt(-1)` would throw a DomainError on the CPU)
121+
@. c = div(Geometry.WVector(foo(ᶠf, ᶠcoords)))
122+
123+
# Check that this field should never yield NaNs
124+
@test count(isnan, parent(c)) == 0
125+
126+
# Doing the same thing with a space without a mask will yield NaNs:
127+
ᶜspace_no_mask = ExtrudedCubedSphereSpace(FT;
128+
z_elem = 10,
129+
z_min = 0,
130+
z_max = 1,
131+
radius = 10,
132+
h_elem = 10,
133+
n_quad_points = 4,
134+
staggering = CellCenter(),
135+
)
136+
ᶠspace_no_mask = Spaces.face_space(ᶜspace_no_mask)
137+
ᶠcoords_no_mask = Fields.coordinate_field(ᶠspace_no_mask)
138+
c_no_mask = Fields.Field(FT, ᶜspace_no_mask)
139+
ᶠf_no_mask = Fields.Field(FT, ᶠspace_no_mask)
140+
@. c_no_mask = div(Geometry.WVector(foo(ᶠf_no_mask, ᶠcoords_no_mask)))
141+
@test count(isnan, parent(c_no_mask)) == 49600
142+
```
143+
144+
## Supported operations and caveats
145+
146+
Currently, masked _operations_ are only supported for `Fields` (and not
147+
`DataLayouts`) with `SpectralElementSpace2D`s. We do not yet have support for
148+
masked `SpectralElementSpace1D`s, and we will likely never offer masked
149+
operation support for `DataLayouts`, as they do not have the space, and can
150+
therefore not use the mask.
151+
152+
In addition, some operations with masked fields skip masked regions
153+
(i.e., mask-aware), and other operations execute everywhere
154+
(i.e., mask-unaware), effectively ignoring the mask. Here is a list of
155+
mask-aware and mask-unaware operations:
156+
157+
- `DataLayout` operations (`Fields.field_values(f) = 1`) mask-unaware (will likely never be mask-aware).
158+
- `fill!` (`@. f = 1`) mask-aware
159+
- point-wise `copyto!` (`@. f = 1 + z`) mask-aware
160+
- stencil `copyto!` (`@. ᶜf = 1 + DivergenceF2C()(Geometry.WVector(ᶠf))`) mask-aware (vertical derivatives and interpolations)
161+
- spectral element operations `copyto!` (`@. f = 1 + Operators.Divergence()(f)`, where `Operators.Divergence` computes the divergence in the horizontal directions) mask-unaware
162+
- fieldvector operations `copyto!` (`@. Y += 1`) mask-unaware
163+
- reductions:
164+
- `sum` (mask-unaware, warning is thrown)
165+
- `extrema` (mask-unaware, warning is thrown)
166+
- `min` (mask-unaware, warning is thrown)
167+
- `max` (mask-unaware, warning is thrown)
168+
- field constructors (`copy`, `Fields.Field`, `ones`, `zeros`) are mask-unaware.
169+
This is an implementation detail of the current design; users should not depend on the results where `mask == 0`, since this may change in the future.
170+
- internal array operations (`fill!(parent(field), 0)`) mask-unaware.
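
Because the reductions above are mask-unaware, whatever is stored in masked-out regions still contributes to the result. A plain-array sketch of the difference (hypothetical names, not ClimaCore API):

```julia
# Toy illustration: a mask-unaware reduction versus a reduction over active
# entries only. `data` and `is_active` are ordinary arrays, not ClimaCore types.
data = [1.0, 99.0, 2.0, 99.0]          # 99.0 sits at masked-out locations
is_active = [true, false, true, false]

sum_everywhere = sum(data)             # includes masked-out entries
sum_active = sum(data[is_active])      # restricted to active entries by hand
```

If a masked result is needed, one option under these assumptions is to restrict the data manually before reducing, as in the last line.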
171+
172+
## Developer docs
173+
174+
In order to support masks, we define their types in `DataLayouts`, since
175+
we need access to them from within kernels in `DataLayouts`. We could have made
176+
an API and kept them completely orthogonal, but that would have been a bit more
177+
complicated. It was also convenient to make the masks themselves data layouts,
178+
so it seemed most natural for them to live there.
179+
180+
We have three types:
181+
182+
- abstract `AbstractMask` for subtyping masks and use for generic interface
183+
methods
184+
- `NoMask` (the default), which is a lazy object that should effectively result
185+
in a no-op, without any loss of runtime performance
186+
- `IJHMask` currently the only supported horizontal mask, which contains
187+
`is_active` (defined in `set_mask!`), `N` (the number of active columns),
188+
and maps containing indices to the `i, j, h` locations where `is_active` is
189+
true. The maps, defined in `set_mask_maps!`, allow us to launch CUDA
189+
kernels that target only the active columns, so threads are not wasted on
191+
non-existent columns. The logic to handle this is relatively thin, and
192+
extends our current `ext/cuda/datalayouts_threadblock.jl` api
193+
(via `masked_partition` and `masked_universal_index`).
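
The idea behind the index maps can be sketched with plain arrays. Everything below is a toy model: `ToyIJHMask` and `masked_fill_columns!` are hypothetical names, the `(v, i, j, h)` array layout is an assumption, and the loop index `n` merely stands in for a per-column GPU thread index.

```julia
# Toy version of an IJH mask built from a Bool array (not the ClimaCore types).
struct ToyIJHMask
    is_active::Array{Bool, 3}   # indexed by (i, j, h)
    N::Int                      # number of active columns
    i_map::Vector{Int}          # i index of the n-th active column
    j_map::Vector{Int}          # j index of the n-th active column
    h_map::Vector{Int}          # h index of the n-th active column
end

# Build the maps by scanning `is_active` once.
function ToyIJHMask(is_active::Array{Bool, 3})
    i_map, j_map, h_map = Int[], Int[], Int[]
    for I in CartesianIndices(is_active)
        if is_active[I]
            i, j, h = Tuple(I)
            push!(i_map, i)
            push!(j_map, j)
            push!(h_map, h)
        end
    end
    return ToyIJHMask(is_active, length(i_map), i_map, j_map, h_map)
end

# Visit only the N active columns; inactive columns are never touched.
function masked_fill_columns!(dest::Array{Float64, 4}, value, mask::ToyIJHMask)
    Nv = size(dest, 1)          # dest is indexed by (v, i, j, h)
    for n in 1:mask.N, v in 1:Nv
        dest[v, mask.i_map[n], mask.j_map[n], mask.h_map[n]] = value
    end
    return dest
end

is_active = fill(false, 2, 2, 3)
is_active[1, 1, 1] = true
is_active[2, 1, 3] = true
mask = ToyIJHMask(is_active)
dest = masked_fill_columns!(zeros(4, 2, 2, 3), 1.0, mask)
```

The work done scales with `mask.N` rather than with the total number of columns, which is the property the maps are meant to provide.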
194+
195+
An important note is that when we set the mask maps for active columns, the
196+
order that they are assigned can be permuted without impacting correctness, but
197+
this could have a big impact on performance on the GPU. We should investigate
198+
this.
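
The correctness half of this claim can be checked on a toy example. The sketch below uses plain arrays and hypothetical names (`fill_active!` is not ClimaCore API); it says nothing about GPU performance, which would have to be measured.

```julia
using Random

# Toy check: visiting the active columns in a permuted order writes the
# same entries, so the result is unchanged.
i_map, j_map, h_map = [1, 2, 1], [1, 1, 2], [1, 3, 2]

function fill_active!(dest, value, i_map, j_map, h_map)
    for n in eachindex(i_map)
        dest[i_map[n], j_map[n], h_map[n]] = value
    end
    return dest
end

a = fill_active!(zeros(2, 2, 3), 1.0, i_map, j_map, h_map)

# Apply the same permutation to all three maps so (i, j, h) triples stay paired.
p = randperm(MersenneTwister(0), length(i_map))
b = fill_active!(zeros(2, 2, 3), 1.0, i_map[p], j_map[p], h_map[p])
```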
199+

ext/ClimaCoreCUDAExt.jl

+1
@@ -13,6 +13,7 @@ import StaticArrays: SVector, SMatrix, SArray
1313
import ClimaCore.DebugOnly: call_post_op_callback, post_op_callback
1414
import ClimaCore.DataLayouts: mapreduce_cuda
1515
import ClimaCore.DataLayouts: ToCUDA
16+
import ClimaCore.DataLayouts: NoMask, IJHMask
1617
import ClimaCore.DataLayouts: slab, column
1718
import ClimaCore.Utilities: half
1819
import ClimaCore.Utilities: cart_ind, linear_ind

ext/cuda/adapt.jl

+10
@@ -30,6 +30,7 @@ Adapt.adapt_structure(
3030
Adapt.adapt(to, grid.quadrature_style),
3131
Adapt.adapt(to, grid.global_geometry),
3232
Adapt.adapt(to, grid.local_geometry),
33+
Adapt.adapt(to, grid.mask),
3334
)
3435

3536
Adapt.adapt_structure(to::CUDA.KernelAdaptor, space::Spaces.PointSpace) =
@@ -53,3 +54,12 @@ Adapt.adapt_structure(
5354
lim.rtol,
5455
Limiters.NoConvergenceStats(),
5556
)
57+
58+
Adapt.adapt_structure(to::CUDA.KernelAdaptor, mask::DataLayouts.IJHMask) =
59+
DataLayouts.IJHMask(
60+
Adapt.adapt(to, mask.is_active),
61+
nothing,
62+
Adapt.adapt(to, mask.i_map),
63+
Adapt.adapt(to, mask.j_map),
64+
Adapt.adapt(to, mask.h_map),
65+
)

ext/cuda/data_layouts_copyto.jl

+27-13
@@ -1,7 +1,11 @@
11
DataLayouts.device_dispatch(x::CUDA.CuArray) = ToCUDA()
22

3-
function knl_copyto!(dest, src, us)
4-
I = universal_index(dest)
3+
function knl_copyto!(dest, src, us, mask)
4+
I = if mask isa NoMask
5+
universal_index(dest)
6+
else
7+
masked_universal_index(mask)
8+
end
59
if is_valid_index(dest, I, us)
610
@inbounds dest[I] = src[I]
711
end
@@ -24,31 +28,36 @@ if VERSION ≥ v"1.11.0-beta"
2428
# special-case fixes for https://github.com/JuliaLang/julia/issues/28126
2529
# (including the GPU-variant related issue resolution efforts:
2630
# JuliaGPU/GPUArrays.jl#454, JuliaGPU/GPUArrays.jl#464).
27-
function Base.copyto!(dest::AbstractData, bc, to::ToCUDA)
31+
function Base.copyto!(dest::AbstractData, bc, to::ToCUDA, mask = NoMask())
2832
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
2933
us = DataLayouts.UniversalSize(dest)
3034
if Nv > 0 && Nh > 0
31-
args = (dest, bc, us)
35+
args = (dest, bc, us, mask)
3236
threads = threads_via_occupancy(knl_copyto!, args)
3337
n_max_threads = min(threads, get_N(us))
34-
p = partition(dest, n_max_threads)
38+
p = if mask isa NoMask
39+
partition(dest, n_max_threads)
40+
else
41+
masked_partition(us, n_max_threads, mask)
42+
end
3543
auto_launch!(
3644
knl_copyto!,
3745
args;
3846
threads_s = p.threads,
3947
blocks_s = p.blocks,
4048
)
4149
end
42-
call_post_op_callback() && post_op_callback(dest, dest, bc, to)
50+
call_post_op_callback() && post_op_callback(dest, dest, bc, to, mask)
4351
return dest
4452
end
4553
else
46-
function Base.copyto!(dest::AbstractData, bc, to::ToCUDA)
54+
function Base.copyto!(dest::AbstractData, bc, to::ToCUDA, mask = NoMask())
4755
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
4856
us = DataLayouts.UniversalSize(dest)
4957
if Nv > 0 && Nh > 0
5058
if DataLayouts.has_uniform_datalayouts(bc) &&
51-
dest isa DataLayouts.EndsWithField
59+
dest isa DataLayouts.EndsWithField &&
60+
mask isa NoMask
5261
bc′ = Base.Broadcast.instantiate(
5362
DataLayouts.to_non_extruded_broadcasted(bc),
5463
)
@@ -63,10 +72,14 @@ else
6372
blocks_s = p.blocks,
6473
)
6574
else
66-
args = (dest, bc, us)
75+
args = (dest, bc, us, mask)
6776
threads = threads_via_occupancy(knl_copyto!, args)
6877
n_max_threads = min(threads, get_N(us))
69-
p = partition(dest, n_max_threads)
78+
p = if mask isa NoMask
79+
partition(dest, n_max_threads)
80+
else
81+
masked_partition(us, n_max_threads, mask)
82+
end
7083
auto_launch!(
7184
knl_copyto!,
7285
args;
@@ -75,7 +88,7 @@ else
7588
)
7689
end
7790
end
78-
call_post_op_callback() && post_op_callback(dest, dest, bc, to)
91+
call_post_op_callback() && post_op_callback(dest, dest, bc, to, mask)
7992
return dest
8093
end
8194
end
@@ -88,6 +101,7 @@ function Base.copyto!(
88101
dest::AbstractData,
89102
bc::Base.Broadcast.Broadcasted{Style},
90103
to::ToCUDA,
104+
mask = NoMask(),
91105
) where {
92106
Style <:
93107
Union{Base.Broadcast.AbstractArrayStyle{0}, Base.Broadcast.Style{Tuple}},
@@ -96,8 +110,8 @@ function Base.copyto!(
96110
Base.Broadcast.Broadcasted{Style}(bc.f, bc.args, ()),
97111
)
98112
@inbounds bc0 = bc[]
99-
fill!(dest, bc0)
100-
call_post_op_callback() && post_op_callback(dest, dest, bc, to)
113+
fill!(dest, bc0, mask)
114+
call_post_op_callback() && post_op_callback(dest, dest, bc, to, mask)
101115
end
102116

103117
# For field-vector operations
