Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Define CartesianFieldIndex, use less internals #2051

Merged
merged 1 commit into from
Oct 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ steps:

- label: "Unit: getindex_field"
key: unit_data_getindex_field
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/DataLayouts/unit_getindex_field.jl"
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/DataLayouts/unit_cartesian_field_index.jl"

- label: "Unit: mapreduce"
key: unit_data_mapreduce
Expand Down
62 changes: 21 additions & 41 deletions ext/cuda/topologies_dss.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import ClimaCore: DataLayouts, Topologies, Spaces, Fields
import ClimaCore.DataLayouts: getindex_field, setindex_field!
import ClimaCore.DataLayouts: CartesianFieldIndex
using CUDA
import ClimaCore.Topologies
import ClimaCore.Topologies: perimeter_vertex_node_index
Expand Down Expand Up @@ -44,13 +44,13 @@ function dss_load_perimeter_data_kernel!(
(nperimeter, _, _, nlevels, nelems) = size(perimeter_data)
nfidx = DataLayouts.ncomponents(perimeter_data)
sizep = (nlevels, nperimeter, nfidx, nelems) # assume VIFH order
CI = CartesianIndex
CI = CartesianFieldIndex

if gidx prod(sizep)
(level, p, fidx, elem) = cart_ind(sizep, gidx).I
(ip, jp) = perimeter[p]
val = getindex_field(data, CI(ip, jp, fidx, level, elem))
setindex_field!(perimeter_data, val, CI(p, 1, fidx, level, elem))
perimeter_data[CI(p, 1, fidx, level, elem)] =
data[CI(ip, jp, fidx, level, elem)]
end
return nothing
end
Expand Down Expand Up @@ -84,13 +84,13 @@ function dss_unload_perimeter_data_kernel!(
(nperimeter, _, _, nlevels, nelems) = size(perimeter_data)
nfidx = DataLayouts.ncomponents(perimeter_data)
sizep = (nlevels, nperimeter, nfidx, nelems) # assume VIFH order
CI = CartesianIndex
CI = CartesianFieldIndex

if gidx prod(sizep)
(level, p, fidx, elem) = cart_ind(sizep, gidx).I
(ip, jp) = perimeter[p]
val = getindex_field(perimeter_data, CI(p, 1, fidx, level, elem))
setindex_field!(data, val, CI(ip, jp, fidx, level, elem))
data[CI(ip, jp, fidx, level, elem)] =
perimeter_data[CI(p, 1, fidx, level, elem)]
end
return nothing
end
Expand Down Expand Up @@ -139,7 +139,7 @@ function dss_local_kernel!(
nlocalfaces = length(interior_faces)
(nperimeter, _, _, nlevels, _) = size(perimeter_data)
nfidx = DataLayouts.ncomponents(perimeter_data)
CI = CartesianIndex
CI = CartesianFieldIndex
if gidx nlevels * nfidx * nlocalvertices # local vertices
sizev = (nlevels, nfidx, nlocalvertices)
(level, fidx, vertexid) = cart_ind(sizev, gidx).I
Expand All @@ -149,17 +149,12 @@ function dss_local_kernel!(
for idx in st:(en - 1)
(lidx, vert) = local_vertices[idx]
ip = perimeter_vertex_node_index(vert)
sum_data +=
getindex_field(perimeter_data, CI(ip, 1, fidx, level, lidx))
sum_data += perimeter_data[CI(ip, 1, fidx, level, lidx)]
end
for idx in st:(en - 1)
(lidx, vert) = local_vertices[idx]
ip = perimeter_vertex_node_index(vert)
setindex_field!(
perimeter_data,
sum_data,
CI(ip, 1, fidx, level, lidx),
)
perimeter_data[CI(ip, 1, fidx, level, lidx)] = sum_data
end
elseif gidx nlevels * nfidx * (nlocalvertices + nlocalfaces) # interior faces
nfacedof = div(nperimeter - 4, 4)
Expand All @@ -176,11 +171,9 @@ function dss_local_kernel!(
ip2 = inc2 == 1 ? first2 + i - 1 : first2 - i + 1
idx1 = CI(ip1, 1, fidx, level, lidx1)
idx2 = CI(ip2, 1, fidx, level, lidx2)
val =
getindex_field(perimeter_data, idx1) +
getindex_field(perimeter_data, idx2)
setindex_field!(perimeter_data, val, idx1)
setindex_field!(perimeter_data, val, idx2)
val = perimeter_data[idx1] + perimeter_data[idx2]
perimeter_data[idx1] = val
perimeter_data[idx2] = val
end
end

Expand Down Expand Up @@ -353,7 +346,7 @@ function dss_local_ghost_kernel!(
FT = eltype(parent(perimeter_data))
(nperimeter, _, _, nlevels, _) = size(perimeter_data)
nfidx = DataLayouts.ncomponents(perimeter_data)
CI = CartesianIndex
CI = CartesianFieldIndex
nghostvertices = length(ghost_vertex_offset) - 1
if gidx nlevels * nfidx * nghostvertices
sizev = (nlevels, nfidx, nghostvertices)
Expand All @@ -365,19 +358,14 @@ function dss_local_ghost_kernel!(
isghost, lidx, vert = ghost_vertices[idx]
if !isghost
ip = perimeter_vertex_node_index(vert)
sum_data +=
getindex_field(perimeter_data, CI(ip, 1, fidx, level, lidx))
sum_data += perimeter_data[CI(ip, 1, fidx, level, lidx)]
end
end
for idx in st:(en - 1)
isghost, lidx, vert = ghost_vertices[idx]
if !isghost
ip = perimeter_vertex_node_index(vert)
setindex_field!(
perimeter_data,
sum_data,
CI(ip, 1, fidx, level, lidx),
)
perimeter_data[CI(ip, 1, fidx, level, lidx)] = sum_data
end
end
end
Expand Down Expand Up @@ -421,14 +409,13 @@ function fill_send_buffer_kernel!(
(_, _, _, nlevels, nelems) = size(perimeter_data)
nfid = DataLayouts.ncomponents(perimeter_data)
sizet = (nlevels, nfid, nsend)
CI = CartesianIndex
CI = CartesianFieldIndex
if gidx nlevels * nfid * nsend
(level, fidx, isend) = cart_ind(sizet, gidx).I
lidx = send_buf_idx[isend, 1]
ip = send_buf_idx[isend, 2]
idx = level + ((fidx - 1) + (isend - 1) * nfid) * nlevels
send_data[idx] =
getindex_field(perimeter_data, CI(ip, 1, fidx, level, lidx))
send_data[idx] = perimeter_data[CI(ip, 1, fidx, level, lidx)]
end
return nothing
end
Expand Down Expand Up @@ -527,27 +514,20 @@ function dss_ghost_kernel!(
(_, _, _, nlevels, _) = size(perimeter_data)
nfidx = DataLayouts.ncomponents(perimeter_data)
nghostvertices = length(ghost_vertex_offset) - 1
CI = CartesianIndex
CI = CartesianFieldIndex
if gidx nlevels * nfidx * nghostvertices
(level, fidx, ghostvertexidx) =
cart_ind((nlevels, nfidx, nghostvertices), gidx).I
idxresult, lvertresult = repr_ghost_vertex[ghostvertexidx]
ipresult = perimeter_vertex_node_index(lvertresult)
result = getindex_field(
perimeter_data,
CI(ipresult, 1, fidx, level, idxresult),
)
result = perimeter_data[CI(ipresult, 1, fidx, level, idxresult)]
st, en = ghost_vertex_offset[ghostvertexidx],
ghost_vertex_offset[ghostvertexidx + 1]
for vertexidx in st:(en - 1)
isghost, eidx, lvert = ghost_vertices[vertexidx]
if !isghost
ip = perimeter_vertex_node_index(lvert)
setindex_field!(
perimeter_data,
result,
CI(ip, 1, fidx, level, eidx),
)
perimeter_data[CI(ip, 1, fidx, level, eidx)] = result
end
end
end
Expand Down
25 changes: 13 additions & 12 deletions lib/ClimaCoreTempestRemap/src/onlineremap.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ClimaCore.DataLayouts
using ClimaCore.DataLayouts: CartesianFieldIndex
using ClimaComms


Expand Down Expand Up @@ -42,11 +43,9 @@ function remap!(
R::LinearMap,
source::IJFH{S, Nqs},
) where {S, Nqt, Nqs}
source_array = parent(source)
target_array = parent(target)

fill!(target_array, zero(eltype(target_array)))
Nf = size(target_array, 3)
fill!(target, zero(eltype(target)))
Nf = DataLayouts.ncomponents(target)
CI = CartesianFieldIndex

# ideally we would use the tempestremap dgll (redundant node) representation
# unfortunately, this doesn't appear to work quite as well (for out_type = dgll) as the cgll
Expand All @@ -62,7 +61,7 @@ function remap!(
view(R.target_local_idxs[3], n)[1],
)
for f in 1:Nf
target_array[it, jt, f, et] += wt * source_array[is, js, f, es]
target[CI(it, jt, f, 1, et)] += wt * source[CI(is, js, f, 1, es)]
end
end

Expand Down Expand Up @@ -91,11 +90,12 @@ function remap!(target::Fields.Field, R::LinearMap, source::Fields.Field)
@assert Spaces.topology(axes(source)).context isa
ClimaComms.SingletonCommsContext

target_array = parent(target)
source_array = parent(source)
CI = CartesianFieldIndex
target_values = Fields.field_values(target)
source_values = Fields.field_values(source)

fill!(target_array, zero(eltype(target_array)))
Nf = size(target_array, 3)
fill!(target, zero(eltype(target)))
Nf = DataLayouts.ncomponents(target)

# ideally we would use the tempestremap dgll (redundant node) representation
# unfortunately, this doesn't appear to work quite as well (for out_type = dgll) as the cgll
Expand All @@ -118,8 +118,9 @@ function remap!(target::Fields.Field, R::LinearMap, source::Fields.Field)
# only use local weights - i.e. et, es != 0
if (et != 0)
for f in 1:Nf
target_array[it, jt, f, et] +=
wt * source_array[is, js, f, es]
ci_src = CI(is, js, f, 1, es)
ci_tar = CI(it, jt, f, 1, et)
target_values[ci_tar] += wt * source_values[ci_src]
end
end
end
Expand Down
23 changes: 23 additions & 0 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1491,6 +1491,29 @@ end
)
end

"""
CartesianFieldIndex{N} <: Base.AbstractCartesianIndex{N}
A CartesianIndex wrapper to dispatch `getindex` / `setindex!`
to call [`getindex_field`](@ref) and [`setindex_field!`](@ref)
for specific field variables in a datalayout.
"""
struct CartesianFieldIndex{N} <: Base.AbstractCartesianIndex{N}
CI::CartesianIndex{N}
end
CartesianFieldIndex(I...) = CartesianFieldIndex(CartesianIndex(I...))

Base.ndims(::CartesianFieldIndex{N}) where {N} = N
Base.@propagate_inbounds Base.getindex(
data::AbstractData,
CI::CartesianFieldIndex,
) = getindex_field(data, CI.CI)
Base.@propagate_inbounds Base.setindex!(
data::AbstractData,
val::Real,
CI::CartesianFieldIndex,
) = setindex_field!(data, val, CI.CI)

"""
getindex_field(data, ci::CartesianIndex{5})
Expand Down
12 changes: 5 additions & 7 deletions src/Topologies/dss.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using DocStringExtensions
using .DataLayouts: getindex_field, setindex_field!
using .DataLayouts: CartesianFieldIndex

"""
DSSBuffer{G, D, A, B}
Expand Down Expand Up @@ -582,13 +582,12 @@ function fill_send_buffer!(
Nf = DataLayouts.ncomponents(perimeter_data)
nsend = size(send_buf_idx, 1)
ctr = 1
CI = CartesianIndex
CI = CartesianFieldIndex
@inbounds for i in 1:nsend
lidx = send_buf_idx[i, 1]
ip = send_buf_idx[i, 2]
for f in 1:Nf, v in 1:Nv
send_data[ctr] =
getindex_field(perimeter_data, CI(ip, 1, f, v, lidx))
send_data[ctr] = perimeter_data[CI(ip, 1, f, v, lidx)]
ctr += 1
end
end
Expand All @@ -612,14 +611,13 @@ function load_from_recv_buffer!(
Nf = DataLayouts.ncomponents(perimeter_data)
nrecv = size(recv_buf_idx, 1)
ctr = 1
CI = CartesianIndex
CI = CartesianFieldIndex
@inbounds for i in 1:nrecv
lidx = recv_buf_idx[i, 1]
ip = recv_buf_idx[i, 2]
for f in 1:Nf, v in 1:Nv
ci = CI(ip, 1, f, v, lidx)
val = getindex_field(perimeter_data, ci) + recv_data[ctr]
setindex_field!(perimeter_data, val, ci)
perimeter_data[ci] += recv_data[ctr]
ctr += 1
end
end
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#=
julia --project
using Revise; include(joinpath("test", "DataLayouts", "unit_getindex_field.jl"))
using Revise; include(joinpath("test", "DataLayouts", "unit_cartesian_field_index.jl"))
=#
using Test
using ClimaCore.DataLayouts
using ClimaCore.DataLayouts: getindex_field, setindex_field!
using ClimaCore.DataLayouts: CartesianFieldIndex
using ClimaCore.DataLayouts: to_data_specific_field, singleton
import ClimaCore.Geometry
import ClimaComms
Expand All @@ -31,15 +31,18 @@ function test_copyto_float!(data)
ArrayType = ClimaComms.array_type(ClimaComms.device())
FT = eltype(parent(data))
parent(rand_data) .= ArrayType(rand(FT, DataLayouts.farray_size(data)))
# For a float, getindex and getindex_field return the same thing
# For a float, CartesianIndex and CartesianFieldIndex return the same thing
for I in CartesianIndices(universal_axes(data))
@test getindex_field(data, I) == getindex(data, I)
CI = CartesianFieldIndex(I.I)
@test data[CI] == data[I]
end
for I in CartesianIndices(universal_axes(data))
setindex_field!(data, FT(prod(I.I)), I)
CI = CartesianFieldIndex(I.I)
data[CI] = FT(prod(I.I))
end
for I in CartesianIndices(universal_axes(data))
@test getindex_field(data, I) == prod(I.I)
CI = CartesianFieldIndex(I.I)
@test data[CI] == prod(I.I)
end
end

Expand All @@ -55,7 +58,7 @@ function test_copyto!(data)
for f in 1:DataLayouts.ncomponents(data)
UFI = universal_field_index(I, f)
DSI = CartesianIndex(to_data_specific_field(singleton(data), UFI.I))
@test getindex_field(data, UFI) == parent(data)[DSI]
@test data[CartesianFieldIndex(UFI)] == parent(data)[DSI]
end
end

Expand All @@ -64,13 +67,13 @@ function test_copyto!(data)
UFI = universal_field_index(I, f)
DSI = CartesianIndex(to_data_specific_field(singleton(data), UFI.I))
val = parent(data)[DSI]
setindex_field!(data, val + 1, UFI)
data[CartesianFieldIndex(UFI)] = val + 1
@test parent(data)[DSI] == val + 1
end
end
end

@testset "copyto! with Nf = 1" begin
@testset "CartesianFieldIndex with Nf = 1" begin
device = ClimaComms.device()
ArrayType = ClimaComms.array_type(device)
FT = Float64
Expand Down Expand Up @@ -99,7 +102,7 @@ end
# data = DataLayouts.IH1JH2{S}(ArrayType{FT}, zeros; Nij); test_copyto_float!(data) # TODO: test
end

@testset "copyto! with Nf > 1" begin
@testset "CartesianFieldIndex with Nf > 1" begin
device = ClimaComms.device()
ArrayType = ClimaComms.array_type(device)
FT = Float64
Expand Down Expand Up @@ -129,7 +132,7 @@ end
# data = DataLayouts.IH1JH2{S}(ArrayType{FT}, zeros; Nij); test_copyto!(data) # TODO: test
end

@testset "copyto! views with Nf > 1" begin
@testset "CartesianFieldIndex views with Nf > 1" begin
device = ClimaComms.device()
ArrayType = ClimaComms.array_type(device)
data_view(data) = DataLayouts.rebuild(
Expand Down
6 changes: 4 additions & 2 deletions test/Fields/unit_field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,11 @@ end
n1 = n2 = 1
Nh = n1 * n2
space = spectral_space_2D(n1 = n1, n2 = n2, Nij = Nij)
device = ClimaComms.device(space)
ArrayType = ClimaComms.array_type(device)

field =
Fields.Field(IJFH{ComplexF64, Nij}(ones(Nij, Nij, 2, n1 * n2)), space)
data = IJFH{ComplexF64}(ArrayType{Float64}, ones; Nij, Nh = n1 * n2)
field = Fields.Field(data, space)

@test sum(field) Complex(1.0, 1.0) * 8.0 * 10.0 rtol = 10eps()
@test sum(x -> 3.0, field) 3 * 8.0 * 10.0 rtol = 10eps()
Expand Down
Loading
Loading