Skip to content

Commit 526b2d3

Browse files
committed
extend dss functions for FieldVectors
1 parent 2c12dd3 commit 526b2d3

File tree

4 files changed

+88
-1
lines changed

4 files changed

+88
-1
lines changed

NEWS.md

+3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ ClimaCore.jl Release Notes
44
main
55
-------
66

7+
- Extended `create_dss_buffer` and `weighted_dss!` for `FieldVector`s, rather than
8+
just `Field`s. PR [#2000](https://github.com/CliMA/ClimaCore.jl/pull/2000).
9+
710
v0.14.16
811
-------
912

src/Fields/Fields.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ function interpcoord(elemrange, x::Real)
374374
end
375375

376376
"""
377-
Spaces.weighted_dss!(f::Field[, ghost_buffer = Spaces.create_dss_buffer(field)])
377+
Spaces.weighted_dss!(f::Field, dss_buffer = Spaces.create_dss_buffer(field))
378378
379379
Apply weighted direct stiffness summation (DSS) to `f`. This operates in-place
380380
(i.e. it modifies the `f`). `ghost_buffer` contains the necessary information

src/Fields/fieldvector.jl

+39
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import BlockArrays
2+
import ClimaCore.Utilities.UnrolledFunctions: unrolled_map, unrolled_foreach
23

34

45
"""
@@ -184,6 +185,44 @@ end
184185
return dest
185186
end
186187

188+
"""
189+
Spaces.create_dss_buffer(fv::FieldVector)
190+
191+
Create a NamedTuple of buffers for communicating neighbour information of
192+
each Field in `fv`. In this NamedTuple, the name of each field is mapped
193+
to the buffer.
194+
"""
195+
function Spaces.create_dss_buffer(fv::FieldVector)
196+
NamedTuple{propertynames(fv)}(
197+
unrolled_map(
198+
key -> Spaces.create_dss_buffer(getproperty(fv, key)),
199+
propertynames(fv),
200+
),
201+
)
202+
end
203+
204+
"""
205+
Spaces.weighted_dss!(fv::FieldVector, dss_buffer = Spaces.create_dss_buffer(fv))
206+
207+
Apply weighted direct stiffness summation (DSS) to each field in `fv`.
208+
If a `dss_buffer` object is not provided, a buffer will be created for each
209+
field in `fv`.
210+
"""
211+
function Spaces.weighted_dss!(
212+
fv::FieldVector,
213+
dss_buffer = Spaces.create_dss_buffer(fv),
214+
)
215+
unrolled_foreach(
216+
key -> Spaces.weighted_dss!(
217+
getproperty(fv, key),
218+
getproperty(dss_buffer, key),
219+
),
220+
propertynames(fv),
221+
)
222+
end
223+
# TODO distribute `propertynames(fv)` over processes to parallelize `unrolled_foreach`
224+
225+
187226
# Recursively call transform_bc_args() on broadcast arguments in a way that is statically reducible by the optimizer
188227
# see Base.Broadcast.preprocess_args
189228
@inline transform_bc_args(args::Tuple, inds...) = (

test/Fields/field_opt.jl

+45
Original file line numberDiff line numberDiff line change
@@ -393,4 +393,49 @@ using JET
393393
@test_opt ifelsekernel!(S, ρ)
394394
end
395395

396+
@testset "dss of FieldVectors" begin
397+
function field_vec(center_space, face_space)
398+
Y = Fields.FieldVector(
399+
c = map(Fields.coordinate_field(center_space)) do coord
400+
FT = Spaces.undertype(center_space)
401+
(;
402+
ρ = FT(coord.lat + coord.long),
403+
uₕ = Geometry.Covariant12Vector(
404+
FT(coord.lat),
405+
FT(coord.long),
406+
),
407+
)
408+
end,
409+
f = map(Fields.coordinate_field(face_space)) do coord
410+
FT = Spaces.undertype(face_space)
411+
(; w = Geometry.Covariant3Vector(FT(coord.lat + coord.long)))
412+
end,
413+
)
414+
return Y
415+
end
416+
417+
fv = field_vec(toy_sphere(Float64)...)
418+
419+
c_copy = copy(getproperty(fv, :c))
420+
f_copy = copy(getproperty(fv, :f))
421+
422+
# Test that dss_buffer is created and has the correct keys
423+
dss_buffer = Spaces.create_dss_buffer(fv)
424+
@test haskey(dss_buffer, :c)
425+
@test haskey(dss_buffer, :f)
426+
427+
# Test weighted_dss! with and without preallocated buffer
428+
Spaces.weighted_dss!(fv, dss_buffer)
429+
@test getproperty(fv, :c) Spaces.weighted_dss!(c_copy)
430+
@test getproperty(fv, :f) Spaces.weighted_dss!(f_copy)
431+
432+
fv = field_vec(toy_sphere(Float64)...)
433+
c_copy = copy(getproperty(fv, :c))
434+
f_copy = copy(getproperty(fv, :f))
435+
436+
Spaces.weighted_dss!(fv)
437+
@test getproperty(fv, :c) Spaces.weighted_dss!(c_copy)
438+
@test getproperty(fv, :f) Spaces.weighted_dss!(f_copy)
439+
end
440+
396441
nothing

0 commit comments

Comments
 (0)