Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fieldvector unit tests, update rcompare #2072

Merged
merged 1 commit into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ ClimaCore.jl Release Notes
main
-------

- A `strict = true` keyword was added to `rcompare`, which checks that the types match. If `strict = false`, then `rcompare` will return `true` for `FieldVector`s and `NamedTuple`s with the same properties but permuted order. For example:
- `rcompare((;a=1,b=2), (;b=2,a=1); strict = true)` will return `false` and
- `rcompare((;a=1,b=2), (;b=2,a=1); strict = false)` will return `true`
- We've added new datalayouts: `VIJHF`,`IJHF`,`IHF`,`VIHF`, to explore their performance compared to our existing datalayouts: `VIJFH`,`IJFH`,`IFH`,`VIFH`. PR [#2055](https://github.com/CliMA/ClimaCore.jl/pull/2053), PR [#2052](https://github.com/CliMA/ClimaCore.jl/pull/2055).
- We've refactored some modules to use less internals. PR [#2053](https://github.com/CliMA/ClimaCore.jl/pull/2053), PR [#2052](https://github.com/CliMA/ClimaCore.jl/pull/2052), [#2051](https://github.com/CliMA/ClimaCore.jl/pull/2051), [#2049](https://github.com/CliMA/ClimaCore.jl/pull/2049).
- Some work was done in attempt to reduce specializations and compile time. PR [#2042](https://github.com/CliMA/ClimaCore.jl/pull/2042), [#2041](https://github.com/CliMA/ClimaCore.jl/pull/2041)
Expand Down
44 changes: 31 additions & 13 deletions src/Fields/fieldvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -469,36 +469,54 @@ end


# Recursively compare contents of similar fieldvectors
_rcompare(pass, x::T, y::T) where {T <: Field} =
pass && _rcompare(pass, field_values(x), field_values(y))
_rcompare(pass, x::T, y::T) where {T <: DataLayouts.AbstractData} =
_rcompare(pass, x::T, y::T; strict) where {T <: Field} =
pass && _rcompare(pass, field_values(x), field_values(y); strict)
_rcompare(pass, x::T, y::T; strict) where {T <: DataLayouts.AbstractData} =
pass && (parent(x) == parent(y))
_rcompare(pass, x::T, y::T) where {T} = pass && (x == y)
_rcompare(pass, x::T, y::T; strict) where {T} = pass && (x == y)

function _rcompare(pass, x::T, y::T) where {T <: Union{FieldVector, NamedTuple}}
_rcompare(pass, x::NamedTuple, y::NamedTuple; strict) =
_rcompare_nt(pass, x, y; strict)
_rcompare(pass, x::FieldVector, y::FieldVector; strict) =
_rcompare_nt(pass, x, y; strict)

function _rcompare_nt(pass, x, y; strict)
length(propertynames(x)) ≠ length(propertynames(y)) && return false
if strict
typeof(x) == typeof(y) || return false
end
for pn in propertynames(x)
pass &= _rcompare(pass, getproperty(x, pn), getproperty(y, pn))
pass &= _rcompare(pass, getproperty(x, pn), getproperty(y, pn); strict)
end
return pass
end

"""
rcompare(x::T, y::T) where {T <: Union{FieldVector, NamedTuple}}
rcompare(x::T, y::T; strict = true) where {T <: Union{FieldVector, NamedTuple}}

Recursively compare given fieldvectors via `==`.
Returns `true` if `x == y` recursively.

FieldVectors with different types are considered different.
"""
rcompare(x::T, y::T) where {T <: Union{FieldVector, NamedTuple}} =
_rcompare(true, x, y)
rcompare(
x::T,
y::T;
strict = true,
) where {T <: Union{FieldVector, NamedTuple}} = _rcompare(true, x, y; strict)

rcompare(x::T, y::T) where {T <: FieldVector} = _rcompare(true, x, y)
rcompare(x::T, y::T; strict = true) where {T <: FieldVector} =
_rcompare(true, x, y; strict)

rcompare(x::T, y::T) where {T <: NamedTuple} = _rcompare(true, x, y)
rcompare(x::T, y::T; strict = true) where {T <: NamedTuple} =
_rcompare(true, x, y; strict)

# FieldVectors with different types are always different
rcompare(x::FieldVector, y::FieldVector) = false
rcompare(x::FieldVector, y::FieldVector; strict::Bool = true) =
strict ? false : _rcompare(true, x, y; strict)

rcompare(x::NamedTuple, y::NamedTuple; strict::Bool = true) =
strict ? false : _rcompare(true, x, y; strict)

# Define == to call rcompare for two fieldvectors
Base.:(==)(x::FieldVector, y::FieldVector) = rcompare(x, y)
Base.:(==)(x::FieldVector, y::FieldVector) = rcompare(x, y; strict = true)
42 changes: 42 additions & 0 deletions test/Fields/unit_field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ ClimaComms.@import_required_backends
using OrderedCollections
using StaticArrays, IntervalSets
import ClimaCore
import ClimaCore.InputOutput
import ClimaCore.Utilities: PlusHalf
import ClimaCore.DataLayouts
import ClimaCore.DataLayouts: IJFH
Expand Down Expand Up @@ -330,6 +331,47 @@ end
@test occursin("==================== Difference found:", s)
end

@testset "Nested FieldVector broadcasting with permuted order" begin
FT = Float32
context = ClimaComms.context()
vertdomain = Domains.IntervalDomain(
Geometry.ZPoint{FT}(-3.5),
Geometry.ZPoint{FT}(0);
boundary_names = (:bottom, :top),
)
vertmesh = Meshes.IntervalMesh(vertdomain; nelems = 10)
device = ClimaComms.device()
vert_center_space = Spaces.CenterFiniteDifferenceSpace(device, vertmesh)
horzdomain = Domains.SphereDomain(FT(100))
horzmesh = Meshes.EquiangularCubedSphere(horzdomain, 1)
horztopology = Topologies.Topology2D(context, horzmesh)
quad = Spaces.Quadratures.GLL{2}()
space = Spaces.SpectralElementSpace2D(horztopology, quad)

vars1 = (; # order is different!
bucket = (; # nesting is needed!
T = Fields.Field(FT, space),
W = Fields.Field(FT, space),
)
)
vars2 = (; # order is different!
bucket = (; # nesting is needed!
W = Fields.Field(FT, space),
T = Fields.Field(FT, space),
)
)
Y1 = Fields.FieldVector(; vars1...)
Y1.bucket.T .= 280.0
Y1.bucket.W .= 0.05

Y2 = Fields.FieldVector(; vars2...)
Y2.bucket.T .= 280.0
Y2.bucket.W .= 0.05

Y1 .= Y2 # FieldVector broadcasting
@test Fields.rcompare(Y1, Y2; strict = false)
end

# https://github.com/CliMA/ClimaCore.jl/issues/1465
@testset "Diagonal FieldVector broadcast expressions" begin
FT = Float64
Expand Down
Loading