Skip to content

Commit 17f5a70

Browse files
Define CartesianFieldIndex, use less internals
1 parent 0635ff3 commit 17f5a70

File tree

8 files changed

+82
-75
lines changed

8 files changed

+82
-75
lines changed

.buildkite/pipeline.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ steps:
8888

8989
- label: "Unit: getindex_field"
9090
key: unit_data_getindex_field
91-
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/DataLayouts/unit_getindex_field.jl"
91+
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/DataLayouts/unit_cartesian_field_index.jl"
9292

9393
- label: "Unit: mapreduce"
9494
key: unit_data_mapreduce

ext/cuda/topologies_dss.jl

+21-41
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import ClimaCore: DataLayouts, Topologies, Spaces, Fields
2-
import ClimaCore.DataLayouts: getindex_field, setindex_field!
2+
import ClimaCore.DataLayouts: CartesianFieldIndex
33
using CUDA
44
import ClimaCore.Topologies
55
import ClimaCore.Topologies: perimeter_vertex_node_index
@@ -44,13 +44,13 @@ function dss_load_perimeter_data_kernel!(
4444
(nperimeter, _, _, nlevels, nelems) = size(perimeter_data)
4545
nfidx = DataLayouts.ncomponents(perimeter_data)
4646
sizep = (nlevels, nperimeter, nfidx, nelems) # assume VIFH order
47-
CI = CartesianIndex
47+
CI = CartesianFieldIndex
4848

4949
if gidx prod(sizep)
5050
(level, p, fidx, elem) = cart_ind(sizep, gidx).I
5151
(ip, jp) = perimeter[p]
52-
val = getindex_field(data, CI(ip, jp, fidx, level, elem))
53-
setindex_field!(perimeter_data, val, CI(p, 1, fidx, level, elem))
52+
perimeter_data[CI(p, 1, fidx, level, elem)] =
53+
data[CI(ip, jp, fidx, level, elem)]
5454
end
5555
return nothing
5656
end
@@ -84,13 +84,13 @@ function dss_unload_perimeter_data_kernel!(
8484
(nperimeter, _, _, nlevels, nelems) = size(perimeter_data)
8585
nfidx = DataLayouts.ncomponents(perimeter_data)
8686
sizep = (nlevels, nperimeter, nfidx, nelems) # assume VIFH order
87-
CI = CartesianIndex
87+
CI = CartesianFieldIndex
8888

8989
if gidx prod(sizep)
9090
(level, p, fidx, elem) = cart_ind(sizep, gidx).I
9191
(ip, jp) = perimeter[p]
92-
val = getindex_field(perimeter_data, CI(p, 1, fidx, level, elem))
93-
setindex_field!(data, val, CI(ip, jp, fidx, level, elem))
92+
data[CI(ip, jp, fidx, level, elem)] =
93+
perimeter_data[CI(p, 1, fidx, level, elem)]
9494
end
9595
return nothing
9696
end
@@ -139,7 +139,7 @@ function dss_local_kernel!(
139139
nlocalfaces = length(interior_faces)
140140
(nperimeter, _, _, nlevels, _) = size(perimeter_data)
141141
nfidx = DataLayouts.ncomponents(perimeter_data)
142-
CI = CartesianIndex
142+
CI = CartesianFieldIndex
143143
if gidx nlevels * nfidx * nlocalvertices # local vertices
144144
sizev = (nlevels, nfidx, nlocalvertices)
145145
(level, fidx, vertexid) = cart_ind(sizev, gidx).I
@@ -149,17 +149,12 @@ function dss_local_kernel!(
149149
for idx in st:(en - 1)
150150
(lidx, vert) = local_vertices[idx]
151151
ip = perimeter_vertex_node_index(vert)
152-
sum_data +=
153-
getindex_field(perimeter_data, CI(ip, 1, fidx, level, lidx))
152+
sum_data += perimeter_data[CI(ip, 1, fidx, level, lidx)]
154153
end
155154
for idx in st:(en - 1)
156155
(lidx, vert) = local_vertices[idx]
157156
ip = perimeter_vertex_node_index(vert)
158-
setindex_field!(
159-
perimeter_data,
160-
sum_data,
161-
CI(ip, 1, fidx, level, lidx),
162-
)
157+
perimeter_data[CI(ip, 1, fidx, level, lidx)] = sum_data
163158
end
164159
elseif gidx nlevels * nfidx * (nlocalvertices + nlocalfaces) # interior faces
165160
nfacedof = div(nperimeter - 4, 4)
@@ -176,11 +171,9 @@ function dss_local_kernel!(
176171
ip2 = inc2 == 1 ? first2 + i - 1 : first2 - i + 1
177172
idx1 = CI(ip1, 1, fidx, level, lidx1)
178173
idx2 = CI(ip2, 1, fidx, level, lidx2)
179-
val =
180-
getindex_field(perimeter_data, idx1) +
181-
getindex_field(perimeter_data, idx2)
182-
setindex_field!(perimeter_data, val, idx1)
183-
setindex_field!(perimeter_data, val, idx2)
174+
val = perimeter_data[idx1] + perimeter_data[idx2]
175+
perimeter_data[idx1] = val
176+
perimeter_data[idx2] = val
184177
end
185178
end
186179

@@ -353,7 +346,7 @@ function dss_local_ghost_kernel!(
353346
FT = eltype(parent(perimeter_data))
354347
(nperimeter, _, _, nlevels, _) = size(perimeter_data)
355348
nfidx = DataLayouts.ncomponents(perimeter_data)
356-
CI = CartesianIndex
349+
CI = CartesianFieldIndex
357350
nghostvertices = length(ghost_vertex_offset) - 1
358351
if gidx nlevels * nfidx * nghostvertices
359352
sizev = (nlevels, nfidx, nghostvertices)
@@ -365,19 +358,14 @@ function dss_local_ghost_kernel!(
365358
isghost, lidx, vert = ghost_vertices[idx]
366359
if !isghost
367360
ip = perimeter_vertex_node_index(vert)
368-
sum_data +=
369-
getindex_field(perimeter_data, CI(ip, 1, fidx, level, lidx))
361+
sum_data += perimeter_data[CI(ip, 1, fidx, level, lidx)]
370362
end
371363
end
372364
for idx in st:(en - 1)
373365
isghost, lidx, vert = ghost_vertices[idx]
374366
if !isghost
375367
ip = perimeter_vertex_node_index(vert)
376-
setindex_field!(
377-
perimeter_data,
378-
sum_data,
379-
CI(ip, 1, fidx, level, lidx),
380-
)
368+
perimeter_data[CI(ip, 1, fidx, level, lidx)] = sum_data
381369
end
382370
end
383371
end
@@ -421,14 +409,13 @@ function fill_send_buffer_kernel!(
421409
(_, _, _, nlevels, nelems) = size(perimeter_data)
422410
nfid = DataLayouts.ncomponents(perimeter_data)
423411
sizet = (nlevels, nfid, nsend)
424-
CI = CartesianIndex
412+
CI = CartesianFieldIndex
425413
if gidx nlevels * nfid * nsend
426414
(level, fidx, isend) = cart_ind(sizet, gidx).I
427415
lidx = send_buf_idx[isend, 1]
428416
ip = send_buf_idx[isend, 2]
429417
idx = level + ((fidx - 1) + (isend - 1) * nfid) * nlevels
430-
send_data[idx] =
431-
getindex_field(perimeter_data, CI(ip, 1, fidx, level, lidx))
418+
send_data[idx] = perimeter_data[CI(ip, 1, fidx, level, lidx)]
432419
end
433420
return nothing
434421
end
@@ -527,27 +514,20 @@ function dss_ghost_kernel!(
527514
(_, _, _, nlevels, _) = size(perimeter_data)
528515
nfidx = DataLayouts.ncomponents(perimeter_data)
529516
nghostvertices = length(ghost_vertex_offset) - 1
530-
CI = CartesianIndex
517+
CI = CartesianFieldIndex
531518
if gidx nlevels * nfidx * nghostvertices
532519
(level, fidx, ghostvertexidx) =
533520
cart_ind((nlevels, nfidx, nghostvertices), gidx).I
534521
idxresult, lvertresult = repr_ghost_vertex[ghostvertexidx]
535522
ipresult = perimeter_vertex_node_index(lvertresult)
536-
result = getindex_field(
537-
perimeter_data,
538-
CI(ipresult, 1, fidx, level, idxresult),
539-
)
523+
result = perimeter_data[CI(ipresult, 1, fidx, level, idxresult)]
540524
st, en = ghost_vertex_offset[ghostvertexidx],
541525
ghost_vertex_offset[ghostvertexidx + 1]
542526
for vertexidx in st:(en - 1)
543527
isghost, eidx, lvert = ghost_vertices[vertexidx]
544528
if !isghost
545529
ip = perimeter_vertex_node_index(lvert)
546-
setindex_field!(
547-
perimeter_data,
548-
result,
549-
CI(ip, 1, fidx, level, eidx),
550-
)
530+
perimeter_data[CI(ip, 1, fidx, level, eidx)] = result
551531
end
552532
end
553533
end

lib/ClimaCoreTempestRemap/src/onlineremap.jl

+13-12
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using ClimaCore.DataLayouts
2+
using ClimaCore.DataLayouts: CartesianFieldIndex
23
using ClimaComms
34

45

@@ -42,11 +43,9 @@ function remap!(
4243
R::LinearMap,
4344
source::IJFH{S, Nqs},
4445
) where {S, Nqt, Nqs}
45-
source_array = parent(source)
46-
target_array = parent(target)
47-
48-
fill!(target_array, zero(eltype(target_array)))
49-
Nf = size(target_array, 3)
46+
fill!(target, zero(eltype(target)))
47+
Nf = DataLayouts.ncomponents(target)
48+
CI = CartesianFieldIndex
5049

5150
# ideally we would use the tempestremap dgll (redundant node) representation
5251
# unfortunately, this doesn't appear to work quite as well (for out_type = dgll) as the cgll
@@ -62,7 +61,7 @@ function remap!(
6261
view(R.target_local_idxs[3], n)[1],
6362
)
6463
for f in 1:Nf
65-
target_array[it, jt, f, et] += wt * source_array[is, js, f, es]
64+
target[CI(it, jt, f, 1, et)] += wt * source[CI(is, js, f, 1, es)]
6665
end
6766
end
6867

@@ -91,11 +90,12 @@ function remap!(target::Fields.Field, R::LinearMap, source::Fields.Field)
9190
@assert Spaces.topology(axes(source)).context isa
9291
ClimaComms.SingletonCommsContext
9392

94-
target_array = parent(target)
95-
source_array = parent(source)
93+
CI = CartesianFieldIndex
94+
target_values = Fields.field_values(target)
95+
source_values = Fields.field_values(source)
9696

97-
fill!(target_array, zero(eltype(target_array)))
98-
Nf = size(target_array, 3)
97+
fill!(target, zero(eltype(target)))
98+
Nf = DataLayouts.ncomponents(target)
9999

100100
# ideally we would use the tempestremap dgll (redundant node) representation
101101
# unfortunately, this doesn't appear to work quite as well (for out_type = dgll) as the cgll
@@ -118,8 +118,9 @@ function remap!(target::Fields.Field, R::LinearMap, source::Fields.Field)
118118
# only use local weights - i.e. et, es != 0
119119
if (et != 0)
120120
for f in 1:Nf
121-
target_array[it, jt, f, et] +=
122-
wt * source_array[is, js, f, es]
121+
ci_src = CI(is, js, f, 1, es)
122+
ci_tar = CI(it, jt, f, 1, et)
123+
target_values[ci_tar] += wt * source_values[ci_src]
123124
end
124125
end
125126
end

src/DataLayouts/DataLayouts.jl

+23
Original file line numberDiff line numberDiff line change
@@ -1491,6 +1491,29 @@ end
14911491
)
14921492
end
14931493

1494+
"""
1495+
CartesianFieldIndex{N} <: Base.AbstractCartesianIndex{N}
1496+
1497+
A CartesianIndex wrapper to dispatch `getindex` / `setindex!`
1498+
to call [`getindex_field`](@ref) and [`setindex_field!`](@ref)
1499+
for specific field variables in a datalayout.
1500+
"""
1501+
struct CartesianFieldIndex{N} <: Base.AbstractCartesianIndex{N}
1502+
CI::CartesianIndex{N}
1503+
end
1504+
CartesianFieldIndex(I...) = CartesianFieldIndex(CartesianIndex(I...))
1505+
1506+
Base.ndims(::CartesianFieldIndex{N}) where {N} = N
1507+
Base.@propagate_inbounds Base.getindex(
1508+
data::AbstractData,
1509+
CI::CartesianFieldIndex,
1510+
) = getindex_field(data, CI.CI)
1511+
Base.@propagate_inbounds Base.setindex!(
1512+
data::AbstractData,
1513+
val::Real,
1514+
CI::CartesianFieldIndex,
1515+
) = setindex_field!(data, val, CI.CI)
1516+
14941517
"""
14951518
getindex_field(data, ci::CartesianIndex{5})
14961519

src/Topologies/dss.jl

+5-7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using DocStringExtensions
2-
using .DataLayouts: getindex_field, setindex_field!
2+
using .DataLayouts: CartesianFieldIndex
33

44
"""
55
DSSBuffer{G, D, A, B}
@@ -582,13 +582,12 @@ function fill_send_buffer!(
582582
Nf = DataLayouts.ncomponents(perimeter_data)
583583
nsend = size(send_buf_idx, 1)
584584
ctr = 1
585-
CI = CartesianIndex
585+
CI = CartesianFieldIndex
586586
@inbounds for i in 1:nsend
587587
lidx = send_buf_idx[i, 1]
588588
ip = send_buf_idx[i, 2]
589589
for f in 1:Nf, v in 1:Nv
590-
send_data[ctr] =
591-
getindex_field(perimeter_data, CI(ip, 1, f, v, lidx))
590+
send_data[ctr] = perimeter_data[CI(ip, 1, f, v, lidx)]
592591
ctr += 1
593592
end
594593
end
@@ -612,14 +611,13 @@ function load_from_recv_buffer!(
612611
Nf = DataLayouts.ncomponents(perimeter_data)
613612
nrecv = size(recv_buf_idx, 1)
614613
ctr = 1
615-
CI = CartesianIndex
614+
CI = CartesianFieldIndex
616615
@inbounds for i in 1:nrecv
617616
lidx = recv_buf_idx[i, 1]
618617
ip = recv_buf_idx[i, 2]
619618
for f in 1:Nf, v in 1:Nv
620619
ci = CI(ip, 1, f, v, lidx)
621-
val = getindex_field(perimeter_data, ci) + recv_data[ctr]
622-
setindex_field!(perimeter_data, val, ci)
620+
perimeter_data[ci] += recv_data[ctr]
623621
ctr += 1
624622
end
625623
end

test/DataLayouts/unit_getindex_field.jl test/DataLayouts/unit_cartesian_field_index.jl

+14-11
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
#=
22
julia --project
3-
using Revise; include(joinpath("test", "DataLayouts", "unit_getindex_field.jl"))
3+
using Revise; include(joinpath("test", "DataLayouts", "unit_cartesian_field_index.jl"))
44
=#
55
using Test
66
using ClimaCore.DataLayouts
7-
using ClimaCore.DataLayouts: getindex_field, setindex_field!
7+
using ClimaCore.DataLayouts: CartesianFieldIndex
88
using ClimaCore.DataLayouts: to_data_specific_field, singleton
99
import ClimaCore.Geometry
1010
import ClimaComms
@@ -31,15 +31,18 @@ function test_copyto_float!(data)
3131
ArrayType = ClimaComms.array_type(ClimaComms.device())
3232
FT = eltype(parent(data))
3333
parent(rand_data) .= ArrayType(rand(FT, DataLayouts.farray_size(data)))
34-
# For a float, getindex and getindex_field return the same thing
34+
# For a float, CartesianIndex and CartesianFieldIndex return the same thing
3535
for I in CartesianIndices(universal_axes(data))
36-
@test getindex_field(data, I) == getindex(data, I)
36+
CI = CartesianFieldIndex(I.I)
37+
@test data[CI] == data[I]
3738
end
3839
for I in CartesianIndices(universal_axes(data))
39-
setindex_field!(data, FT(prod(I.I)), I)
40+
CI = CartesianFieldIndex(I.I)
41+
data[CI] = FT(prod(I.I))
4042
end
4143
for I in CartesianIndices(universal_axes(data))
42-
@test getindex_field(data, I) == prod(I.I)
44+
CI = CartesianFieldIndex(I.I)
45+
@test data[CI] == prod(I.I)
4346
end
4447
end
4548

@@ -55,7 +58,7 @@ function test_copyto!(data)
5558
for f in 1:DataLayouts.ncomponents(data)
5659
UFI = universal_field_index(I, f)
5760
DSI = CartesianIndex(to_data_specific_field(singleton(data), UFI.I))
58-
@test getindex_field(data, UFI) == parent(data)[DSI]
61+
@test data[CartesianFieldIndex(UFI)] == parent(data)[DSI]
5962
end
6063
end
6164

@@ -64,13 +67,13 @@ function test_copyto!(data)
6467
UFI = universal_field_index(I, f)
6568
DSI = CartesianIndex(to_data_specific_field(singleton(data), UFI.I))
6669
val = parent(data)[DSI]
67-
setindex_field!(data, val + 1, UFI)
70+
data[CartesianFieldIndex(UFI)] = val + 1
6871
@test parent(data)[DSI] == val + 1
6972
end
7073
end
7174
end
7275

73-
@testset "copyto! with Nf = 1" begin
76+
@testset "CartesianFieldIndex with Nf = 1" begin
7477
device = ClimaComms.device()
7578
ArrayType = ClimaComms.array_type(device)
7679
FT = Float64
@@ -99,7 +102,7 @@ end
99102
# data = DataLayouts.IH1JH2{S}(ArrayType{FT}, zeros; Nij); test_copyto_float!(data) # TODO: test
100103
end
101104

102-
@testset "copyto! with Nf > 1" begin
105+
@testset "CartesianFieldIndex with Nf > 1" begin
103106
device = ClimaComms.device()
104107
ArrayType = ClimaComms.array_type(device)
105108
FT = Float64
@@ -129,7 +132,7 @@ end
129132
# data = DataLayouts.IH1JH2{S}(ArrayType{FT}, zeros; Nij); test_copyto!(data) # TODO: test
130133
end
131134

132-
@testset "copyto! views with Nf > 1" begin
135+
@testset "CartesianFieldIndex views with Nf > 1" begin
133136
device = ClimaComms.device()
134137
ArrayType = ClimaComms.array_type(device)
135138
data_view(data) = DataLayouts.rebuild(

test/Fields/unit_field.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,11 @@ end
5757
n1 = n2 = 1
5858
Nh = n1 * n2
5959
space = spectral_space_2D(n1 = n1, n2 = n2, Nij = Nij)
60+
device = ClimaComms.device(space)
61+
ArrayType = ClimaComms.array_type(device)
6062

61-
field =
62-
Fields.Field(IJFH{ComplexF64, Nij}(ones(Nij, Nij, 2, n1 * n2)), space)
63+
data = IJFH{ComplexF64}(ArrayType{Float64}, ones; Nij, Nh = n1 * n2)
64+
field = Fields.Field(data, space)
6365

6466
@test sum(field) Complex(1.0, 1.0) * 8.0 * 10.0 rtol = 10eps()
6567
@test sum(x -> 3.0, field) 3 * 8.0 * 10.0 rtol = 10eps()

0 commit comments

Comments
 (0)