Skip to content

Commit 36837c2

Browse files
Merge pull request #2072 from CliMA/ck/fv_unit
Add fieldvector unit tests, update `rcompare`
2 parents 9a05ed2 + 46edf40 commit 36837c2

File tree

3 files changed

+76
-13
lines changed

3 files changed

+76
-13
lines changed

NEWS.md

+3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ ClimaCore.jl Release Notes
44
main
55
-------
66

7+
- A `strict = true` keyword was added to `rcompare`, which checks that the types match. If `strict = false`, then `rcompare` will return `true` for `FieldVector`s and `NamedTuple`s with the same properties but permuted order. For example:
8+
- `rcompare((;a=1,b=2), (;b=2,a=1); strict = true)` will return `false` and
9+
- `rcompare((;a=1,b=2), (;b=2,a=1); strict = false)` will return `true`
710
- We've added new datalayouts: `VIJHF`,`IJHF`,`IHF`,`VIHF`, to explore their performance compared to our existing datalayouts: `VIJFH`,`IJFH`,`IFH`,`VIFH`. PR [#2055](https://github.com/CliMA/ClimaCore.jl/pull/2053), PR [#2052](https://github.com/CliMA/ClimaCore.jl/pull/2055).
811
- We've refactored some modules to use less internals. PR [#2053](https://github.com/CliMA/ClimaCore.jl/pull/2053), PR [#2052](https://github.com/CliMA/ClimaCore.jl/pull/2052), [#2051](https://github.com/CliMA/ClimaCore.jl/pull/2051), [#2049](https://github.com/CliMA/ClimaCore.jl/pull/2049).
912
- Some work was done in attempt to reduce specializations and compile time. PR [#2042](https://github.com/CliMA/ClimaCore.jl/pull/2042), [#2041](https://github.com/CliMA/ClimaCore.jl/pull/2041)

src/Fields/fieldvector.jl

+31-13
Original file line numberDiff line numberDiff line change
@@ -469,36 +469,54 @@ end
469469

470470

471471
# Recursively compare contents of similar fieldvectors
472-
_rcompare(pass, x::T, y::T) where {T <: Field} =
473-
pass && _rcompare(pass, field_values(x), field_values(y))
474-
_rcompare(pass, x::T, y::T) where {T <: DataLayouts.AbstractData} =
472+
_rcompare(pass, x::T, y::T; strict) where {T <: Field} =
473+
pass && _rcompare(pass, field_values(x), field_values(y); strict)
474+
_rcompare(pass, x::T, y::T; strict) where {T <: DataLayouts.AbstractData} =
475475
pass && (parent(x) == parent(y))
476-
_rcompare(pass, x::T, y::T) where {T} = pass && (x == y)
476+
_rcompare(pass, x::T, y::T; strict) where {T} = pass && (x == y)
477477

478-
function _rcompare(pass, x::T, y::T) where {T <: Union{FieldVector, NamedTuple}}
478+
_rcompare(pass, x::NamedTuple, y::NamedTuple; strict) =
479+
_rcompare_nt(pass, x, y; strict)
480+
_rcompare(pass, x::FieldVector, y::FieldVector; strict) =
481+
_rcompare_nt(pass, x, y; strict)
482+
483+
function _rcompare_nt(pass, x, y; strict)
484+
length(propertynames(x)) length(propertynames(y)) && return false
485+
if strict
486+
typeof(x) == typeof(y) || return false
487+
end
479488
for pn in propertynames(x)
480-
pass &= _rcompare(pass, getproperty(x, pn), getproperty(y, pn))
489+
pass &= _rcompare(pass, getproperty(x, pn), getproperty(y, pn); strict)
481490
end
482491
return pass
483492
end
484493

485494
"""
486-
rcompare(x::T, y::T) where {T <: Union{FieldVector, NamedTuple}}
495+
rcompare(x::T, y::T; strict = true) where {T <: Union{FieldVector, NamedTuple}}
487496
488497
Recursively compare given fieldvectors via `==`.
489498
Returns `true` if `x == y` recursively.
490499
491500
FieldVectors with different types are considered different.
492501
"""
493-
rcompare(x::T, y::T) where {T <: Union{FieldVector, NamedTuple}} =
494-
_rcompare(true, x, y)
502+
rcompare(
503+
x::T,
504+
y::T;
505+
strict = true,
506+
) where {T <: Union{FieldVector, NamedTuple}} = _rcompare(true, x, y; strict)
495507

496-
rcompare(x::T, y::T) where {T <: FieldVector} = _rcompare(true, x, y)
508+
rcompare(x::T, y::T; strict = true) where {T <: FieldVector} =
509+
_rcompare(true, x, y; strict)
497510

498-
rcompare(x::T, y::T) where {T <: NamedTuple} = _rcompare(true, x, y)
511+
rcompare(x::T, y::T; strict = true) where {T <: NamedTuple} =
512+
_rcompare(true, x, y; strict)
499513

500514
# FieldVectors with different types are always different
501-
rcompare(x::FieldVector, y::FieldVector) = false
515+
rcompare(x::FieldVector, y::FieldVector; strict::Bool = true) =
516+
strict ? false : _rcompare(true, x, y; strict)
517+
518+
rcompare(x::NamedTuple, y::NamedTuple; strict::Bool = true) =
519+
strict ? false : _rcompare(true, x, y; strict)
502520

503521
# Define == to call rcompare for two fieldvectors
504-
Base.:(==)(x::FieldVector, y::FieldVector) = rcompare(x, y)
522+
Base.:(==)(x::FieldVector, y::FieldVector) = rcompare(x, y; strict = true)

test/Fields/unit_field.jl

+42
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ ClimaComms.@import_required_backends
1111
using OrderedCollections
1212
using StaticArrays, IntervalSets
1313
import ClimaCore
14+
import ClimaCore.InputOutput
1415
import ClimaCore.Utilities: PlusHalf
1516
import ClimaCore.DataLayouts
1617
import ClimaCore.DataLayouts: IJFH
@@ -330,6 +331,47 @@ end
330331
@test occursin("==================== Difference found:", s)
331332
end
332333

334+
@testset "Nested FieldVector broadcasting with permuted order" begin
335+
FT = Float32
336+
context = ClimaComms.context()
337+
vertdomain = Domains.IntervalDomain(
338+
Geometry.ZPoint{FT}(-3.5),
339+
Geometry.ZPoint{FT}(0);
340+
boundary_names = (:bottom, :top),
341+
)
342+
vertmesh = Meshes.IntervalMesh(vertdomain; nelems = 10)
343+
device = ClimaComms.device()
344+
vert_center_space = Spaces.CenterFiniteDifferenceSpace(device, vertmesh)
345+
horzdomain = Domains.SphereDomain(FT(100))
346+
horzmesh = Meshes.EquiangularCubedSphere(horzdomain, 1)
347+
horztopology = Topologies.Topology2D(context, horzmesh)
348+
quad = Spaces.Quadratures.GLL{2}()
349+
space = Spaces.SpectralElementSpace2D(horztopology, quad)
350+
351+
vars1 = (; # order is different!
352+
bucket = (; # nesting is needed!
353+
T = Fields.Field(FT, space),
354+
W = Fields.Field(FT, space),
355+
)
356+
)
357+
vars2 = (; # order is different!
358+
bucket = (; # nesting is needed!
359+
W = Fields.Field(FT, space),
360+
T = Fields.Field(FT, space),
361+
)
362+
)
363+
Y1 = Fields.FieldVector(; vars1...)
364+
Y1.bucket.T .= 280.0
365+
Y1.bucket.W .= 0.05
366+
367+
Y2 = Fields.FieldVector(; vars2...)
368+
Y2.bucket.T .= 280.0
369+
Y2.bucket.W .= 0.05
370+
371+
Y1 .= Y2 # FieldVector broadcasting
372+
@test Fields.rcompare(Y1, Y2; strict = false)
373+
end
374+
333375
# https://github.com/CliMA/ClimaCore.jl/issues/1465
334376
@testset "Diagonal FieldVector broadcast expressions" begin
335377
FT = Float64

0 commit comments

Comments
 (0)