Skip to content

Commit 7edba4c

Browse files
committed
extend dss functions for FieldVectors
1 parent bd20629 commit 7edba4c

File tree

3 files changed

+76
-1
lines changed

3 files changed

+76
-1
lines changed

src/Fields/Fields.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ function interpcoord(elemrange, x::Real)
374374
end
375375

376376
"""
377-
Spaces.weighted_dss!(f::Field[, ghost_buffer = Spaces.create_dss_buffer(field)])
377+
Spaces.weighted_dss!(f::Field, dss_buffer = Spaces.create_dss_buffer(field))
378378
379379
Apply weighted direct stiffness summation (DSS) to `f`. This operates in-place
380380
(i.e. it modifies the `f`). `ghost_buffer` contains the necessary information

src/Fields/fieldvector.jl

+38
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import BlockArrays
2+
import ClimaCore.Utilities.UnrolledFunctions: unrolled_map, unrolled_foreach
23

34

45
"""
@@ -184,6 +185,43 @@ end
184185
return dest
185186
end
186187

188+
"""
189+
Spaces.create_dss_buffer(fv::FieldVector)
190+
191+
Create a NamedTuple of buffers for communicating neighbour information of
192+
each Field in `fv`. In this NamedTuple, the name of each field is mapped
193+
to the buffer.
194+
"""
195+
function Spaces.create_dss_buffer(fv::FieldVector)
196+
NamedTuple{propertynames(fv)}(
197+
unrolled_map(
198+
key -> Spaces.create_dss_buffer(getproperty(fv, key)),
199+
propertynames(fv),
200+
),
201+
)
202+
end
203+
204+
"""
205+
Spaces.weighted_dss!(fv::FieldVector, dss_buffer = Spaces.create_dss_buffer(fv))
206+
207+
Apply weighted direct stiffness summation (DSS) to each field in `fv`.
208+
Reuse the same `dss_buffer` for all fields in `fv`, which is constructed using
209+
the first field in `fv` if it isn't passed explicitly.
210+
"""
211+
# TODO distribute `propertynames(fv)` over processes to parallelize `unrolled_foreach`
212+
function Spaces.weighted_dss!(
213+
fv::FieldVector,
214+
dss_buffer = Spaces.create_dss_buffer(fv),
215+
)
216+
unrolled_foreach(
217+
key -> Spaces.weighted_dss!(
218+
getproperty(fv, key),
219+
getproperty(dss_buffer, key),
220+
),
221+
propertynames(fv),
222+
)
223+
end
224+
187225
# Recursively call transform_bc_args() on broadcast arguments in a way that is statically reducible by the optimizer
188226
# see Base.Broadcast.preprocess_args
189227
@inline transform_bc_args(args::Tuple, inds...) = (

test/Fields/field_opt.jl

+37
Original file line numberDiff line numberDiff line change
@@ -393,4 +393,41 @@ using JET
393393
@test_opt ifelsekernel!(S, ρ)
394394
end
395395

396+
@testset "dss of FieldVectors" begin
397+
function field_vec(center_space, face_space)
398+
Y = Fields.FieldVector(
399+
c = map(Fields.coordinate_field(center_space)) do coord
400+
FT = Spaces.undertype(center_space)
401+
(; ρ = FT(0), uₕ = Geometry.Covariant12Vector(FT(0), FT(0)))
402+
end,
403+
f = map(Fields.coordinate_field(face_space)) do coord
404+
FT = Spaces.undertype(face_space)
405+
(; w = Geometry.Covariant3Vector(FT(0)))
406+
end,
407+
)
408+
return Y
409+
end
410+
411+
fv = field_vec(toy_sphere(Float64)...)
412+
413+
# Test that dss_buffer is created and has the correct keys and buffer types
414+
dss_buffer = Spaces.create_dss_buffer(fv)
415+
@test haskey(dss_buffer, :c)
416+
@test haskey(dss_buffer, :f)
417+
@test getproperty(dss_buffer, :c) isa Topologies.DSSBuffer
418+
@test getproperty(dss_buffer, :f) isa Topologies.DSSBuffer
419+
420+
c_copy = copy(getproperty(fv, :c))
421+
f_copy = copy(getproperty(fv, :f))
422+
423+
# Test weighted_dss! with and without preallocated buffer
424+
p = @allocated Spaces.weighted_dss!(fv, dss_buffer) # DSS2
425+
@test getproperty(fv, :c) c_copy
426+
@test getproperty(fv, :f) f_copy
427+
428+
Spaces.weighted_dss!(fv)
429+
@test getproperty(fv, :c) c_copy
430+
@test getproperty(fv, :f) f_copy
431+
end
432+
396433
nothing

0 commit comments

Comments
 (0)