Skip to content

Commit 67dd505

Browse files
authoredOct 21, 2024
Merge pull request #2052 from CliMA/ck/refactor
Refactor and use less internals
2 parents 2e0c495 + e220729 commit 67dd505

File tree

9 files changed

+37
-23
lines changed

9 files changed

+37
-23
lines changed
 

‎.buildkite/pipeline.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ steps:
8686
key: unit_data_copyto
8787
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/DataLayouts/unit_copyto.jl"
8888

89-
- label: "Unit: getindex_field"
90-
key: unit_data_getindex_field
89+
- label: "Unit: cartesian_field_index"
90+
key: unit_data_cartesian_field_index
9191
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/DataLayouts/unit_cartesian_field_index.jl"
9292

9393
- label: "Unit: mapreduce"

‎src/DataLayouts/DataLayouts.jl

+12
Original file line numberDiff line numberDiff line change
@@ -1653,6 +1653,18 @@ device_dispatch(x::MArray) = ToCPU()
16531653
@inline singleton(@nospecialize(::IH1JH2)) = IH1JH2Singleton()
16541654
@inline singleton(@nospecialize(::IV1JH2)) = IV1JH2Singleton()
16551655

1656+
@inline singleton(::Type{IJKFVH}) = IJKFVHSingleton()
1657+
@inline singleton(::Type{IJFH}) = IJFHSingleton()
1658+
@inline singleton(::Type{IFH}) = IFHSingleton()
1659+
@inline singleton(::Type{DataF}) = DataFSingleton()
1660+
@inline singleton(::Type{IJF}) = IJFSingleton()
1661+
@inline singleton(::Type{IF}) = IFSingleton()
1662+
@inline singleton(::Type{VF}) = VFSingleton()
1663+
@inline singleton(::Type{VIJFH}) = VIJFHSingleton()
1664+
@inline singleton(::Type{VIFH}) = VIFHSingleton()
1665+
@inline singleton(::Type{IH1JH2}) = IH1JH2Singleton()
1666+
@inline singleton(::Type{IV1JH2}) = IV1JH2Singleton()
1667+
16561668

16571669
include("copyto.jl")
16581670
include("fused_copyto.jl")

‎src/InputOutput/readers.jl

+5-3
Original file line numberDiff line numberDiff line change
@@ -467,16 +467,18 @@ function read_field(reader::HDF5Reader, name::AbstractString)
467467
end
468468
topology = Spaces.topology(space)
469469
ArrayType = ClimaComms.array_type(topology)
470+
data_layout = attrs(obj)["data_layout"]
471+
DataLayout = _scan_data_layout(data_layout)
472+
h_dim = DataLayouts.h_dim(DataLayouts.singleton(DataLayout))
470473
if topology isa Topologies.Topology2D
471474
nd = ndims(obj)
472-
localidx = ntuple(d -> d < nd ? (:) : topology.local_elem_gidx, nd)
475+
localidx =
476+
ntuple(d -> d == h_dim ? topology.local_elem_gidx : (:), nd)
473477
data = ArrayType(obj[localidx...])
474478
else
475479
data = ArrayType(read(obj))
476480
end
477-
data_layout = attrs(obj)["data_layout"]
478481
Nij = size(data, findfirst("I", data_layout)[1])
479-
DataLayout = _scan_data_layout(data_layout)
480482
# For when `Nh` is added back to the type space
481483
# Nhd = Nh_dim(data_layout)
482484
# Nht = Nhd == -1 ? () : (size(data, Nhd),)

‎src/InputOutput/writers.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,7 @@ end
423423

424424
# write fields
425425
function write!(writer::HDF5Writer, field::Fields.Field, name::AbstractString)
426+
values = Fields.field_values(field)
426427
space = axes(field)
427428
staggering = Spaces.staggering(space)
428429
grid = Spaces.grid(space)
@@ -434,8 +435,9 @@ function write!(writer::HDF5Writer, field::Fields.Field, name::AbstractString)
434435
if topology isa Topologies.Topology2D &&
435436
!(writer.context isa ClimaComms.SingletonCommsContext)
436437
nelems = Topologies.nelems(topology)
437-
dims = ntuple(d -> d == nd ? nelems : size(array, d), nd)
438-
localidx = ntuple(d -> d < nd ? (:) : topology.local_elem_gidx, nd)
438+
h_dim = DataLayouts.h_dim(DataLayouts.singleton(values))
439+
dims = ntuple(d -> d == h_dim ? nelems : size(array, d), nd)
440+
localidx = ntuple(d -> d == h_dim ? topology.local_elem_gidx : (:), nd)
439441
dataset = create_dataset(
440442
writer.file,
441443
"fields/$name",

‎src/Topologies/dss_transform.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -269,17 +269,17 @@ function create_ghost_buffer(
269269
if data isa DataLayouts.IJFH
270270
send_data = DataLayouts.IJFH{S, Nij}(typeof(parent(data)), Nhsend)
271271
recv_data = DataLayouts.IJFH{S, Nij}(typeof(parent(data)), Nhrec)
272-
k = stride(parent(send_data), 4)
273272
else
274-
Nv, _, _, Nf, _ = DataLayouts.farray_size(data)
273+
Nv = DataLayouts.nlevels(data)
274+
Nf = DataLayouts.ncomponents(data)
275275
send_data = DataLayouts.VIJFH{S, Nv, Nij}(
276276
similar(parent(data), (Nv, Nij, Nij, Nf, Nhsend)),
277277
)
278278
recv_data = DataLayouts.VIJFH{S, Nv, Nij}(
279279
similar(parent(data), (Nv, Nij, Nij, Nf, Nhrec)),
280280
)
281-
k = stride(parent(send_data), 5)
282281
end
282+
k = stride(parent(send_data), DataLayouts.h_dim(data))
283283

284284
graph_context = ClimaComms.graph_context(
285285
topology.context,

‎test/Fields/unit_field.jl

+9-6
Original file line numberDiff line numberDiff line change
@@ -251,12 +251,15 @@ end
251251
Nh = n1 * n2
252252
space = spectral_space_2D(n1 = n1, n2 = n2, Nij = Nij)
253253

254-
nt_field = Fields.Field(
255-
IJFH{NamedTuple{(:a, :b), Tuple{Float64, Float64}}, Nij}(
256-
ones(Nij, Nij, 2, Nh),
257-
),
258-
space,
259-
)
254+
S = NamedTuple{(:a, :b), Tuple{Float64, Float64}}
255+
context = ClimaComms.context(space)
256+
device = ClimaComms.device(context)
257+
ArrayType = ClimaComms.array_type(device)
258+
FT = Spaces.undertype(space)
259+
data = IJFH{S}(ArrayType{FT}, ones; Nij, Nh)
260+
261+
nt_field = Fields.Field(data, space)
262+
260263
nt_sum = sum(nt_field)
261264
@test nt_sum isa NamedTuple{(:a, :b), Tuple{Float64, Float64}}
262265
@test nt_sum.a 8.0 * 10.0 rtol = 10eps()

‎test/InputOutput/hybrid3dcubedsphere.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ end
5555
ᶜlocal_geometry = Fields.local_geometry_field(center_space)
5656
ᶠlocal_geometry = Fields.local_geometry_field(face_space)
5757

58-
Y = Fields.FieldVector(c = ᶜlocal_geometry, f = ᶠlocal_geometry)
58+
Y = Fields.FieldVector(; c = ᶜlocal_geometry, f = ᶠlocal_geometry)
5959

6060
# write field vector to hdf5 file
6161
writer = InputOutput.HDF5Writer(filename, comms_ctx)

‎test/InputOutput/hybrid3dcubedsphere_topography.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ end
7777
ᶜlocal_geometry = Fields.local_geometry_field(center_space)
7878
ᶠlocal_geometry = Fields.local_geometry_field(face_space)
7979

80-
Y = Fields.FieldVector(c = ᶜlocal_geometry, f = ᶠlocal_geometry)
80+
Y = Fields.FieldVector(; c = ᶜlocal_geometry, f = ᶠlocal_geometry)
8181

8282
# write field vector to hdf5 file
8383
writer = InputOutput.HDF5Writer(filename, comms_ctx)

‎test/runtests.jl

-5
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,10 @@ UnitTest("Spectral elem - sphere diffusion" ,"Operators/spectralelement/s
5252
UnitTest("Spectral elem - sphere diffusion vec" ,"Operators/spectralelement/sphere_diffusion_vec.jl"),
5353
UnitTest("Spectral elem - sphere hyperdiff" ,"Operators/spectralelement/unit_sphere_hyperdiffusion.jl"),
5454
UnitTest("Spectral elem - sphere hyperdiff vec" ,"Operators/spectralelement/unit_sphere_hyperdiffusion_vec.jl"),
55-
# UnitTest("Spectral elem - sphere hyperdiff vec" ,"Operators/spectralelement/sphere_geometry_distributed.jl"), # MPI-only
5655
UnitTest("FD ops - column" ,"Operators/finitedifference/unit_column.jl"),
5756
UnitTest("FD ops - opt" ,"Operators/finitedifference/opt.jl"),
5857
UnitTest("FD ops - wfact" ,"Operators/finitedifference/wfact.jl"),
5958
UnitTest("FD ops - linsolve" ,"Operators/finitedifference/linsolve.jl"),
60-
# UnitTest("FD ops - examples" ,"Operators/finitedifference/opt_examples.jl"), # only opt tests? (check coverage)
6159
UnitTest("Hybrid - 2D" ,"Operators/hybrid/unit_2d.jl"),
6260
UnitTest("Hybrid - 3D" ,"Operators/hybrid/unit_3d.jl"),
6361
UnitTest("Hybrid - dss opt" ,"Operators/hybrid/dss_opt.jl"),
@@ -89,15 +87,12 @@ UnitTest("MatrixFields - non-scalar broadcasting (3)" ,"MatrixFields/matrix_fiel
8987
UnitTest("MatrixFields - non-scalar broadcasting (4)" ,"MatrixFields/matrix_fields_broadcasting/test_non_scalar_4.jl"),
9088
UnitTest("MatrixFields - non-scalar broadcasting (5)" ,"MatrixFields/matrix_fields_broadcasting/test_non_scalar_5.jl"),
9189
UnitTest("MatrixFields - flat spaces" ,"MatrixFields/flat_spaces.jl"),
92-
93-
# UnitTest("MatrixFields - matrix field broadcast" ,"MatrixFields/matrix_field_broadcasting.jl"), # too long
9490
# UnitTest("MatrixFields - operator matrices" ,"MatrixFields/operator_matrices.jl"), # too long
9591
# UnitTest("MatrixFields - field matrix solvers" ,"MatrixFields/field_matrix_solvers.jl"), # too long
9692
UnitTest("Hypsography - 2d" ,"Hypsography/2d.jl"),
9793
UnitTest("Hypsography - 3d sphere" ,"Hypsography/3dsphere.jl"),
9894
UnitTest("Remapping" ,"Operators/remapping.jl"),
9995
UnitTest("Limiter" ,"Limiters/limiter.jl"),
100-
# UnitTest("Limiter" ,"Limiters/distributed/dlimiter.jl"), # requires MPI
10196
UnitTest("InputOutput - hdf5" ,"InputOutput/hdf5.jl"),
10297
UnitTest("InputOutput - spectralelement2d" ,"InputOutput/spectralelement2d.jl"),
10398
UnitTest("InputOutput - hybrid2dbox" ,"InputOutput/hybrid2dbox.jl"),

0 commit comments

Comments
 (0)
Failed to load comments.