@@ -475,30 +475,44 @@ _rcompare(pass, x::T, y::T) where {T <: DataLayouts.AbstractData} =
475
475
pass && (parent (x) == parent (y))
476
476
_rcompare (pass, x:: T , y:: T ) where {T} = pass && (x == y)
477
477
478
- function _rcompare (pass, x:: T , y:: T ) where {T <: Union{FieldVector, NamedTuple} }
478
+ _rcompare (pass, x:: NamedTuple , y:: NamedTuple ; strict) =
479
+ _rcompare_nt (pass, x, y; strict)
480
+ _rcompare (pass, x:: FieldVector , y:: FieldVector ; strict) =
481
+ _rcompare_nt (pass, x, y; strict)
482
+
483
+ function _rcompare_nt (pass, x, y; strict)
484
+ if strict
485
+ typeof (x) == typeof (y) || return false
486
+ end
479
487
for pn in propertynames (x)
480
488
pass &= _rcompare (pass, getproperty (x, pn), getproperty (y, pn))
481
489
end
482
490
return pass
483
491
end
484
492
485
493
"""
486
- rcompare(x::T, y::T) where {T <: Union{FieldVector, NamedTuple}}
494
+ rcompare(x::T, y::T; strict = true ) where {T <: Union{FieldVector, NamedTuple}}
487
495
488
496
Recursively compare given fieldvectors via `==`.
489
497
Returns `true` if `x == y` recursively.
490
498
491
499
FieldVectors with different types are considered different.
492
500
"""
493
- rcompare (x:: T , y:: T ) where {T <: Union{FieldVector, NamedTuple} } =
494
- _rcompare (true , x, y)
501
+ rcompare (
502
+ x:: T ,
503
+ y:: T ;
504
+ strict = true ,
505
+ ) where {T <: Union{FieldVector, NamedTuple} } = _rcompare (true , x, y; strict)
495
506
496
- rcompare (x:: T , y:: T ) where {T <: FieldVector } = _rcompare (true , x, y)
507
+ rcompare (x:: T , y:: T ; strict = true ) where {T <: FieldVector } =
508
+ _rcompare (true , x, y; strict)
497
509
498
- rcompare (x:: T , y:: T ) where {T <: NamedTuple } = _rcompare (true , x, y)
510
+ rcompare (x:: T , y:: T ; strict = true ) where {T <: NamedTuple } =
511
+ _rcompare (true , x, y; strict)
499
512
500
513
# FieldVectors with different types are always different
501
- rcompare (x:: FieldVector , y:: FieldVector ) = false
514
+ rcompare (x:: FieldVector , y:: FieldVector ; strict:: Bool = true ) =
515
+ strict ? false : _rcompare (true , x, y; strict)
502
516
503
517
# Define == to call rcompare for two fieldvectors
504
- Base.:(== )(x:: FieldVector , y:: FieldVector ) = rcompare (x, y)
518
+ Base.:(== )(x:: FieldVector , y:: FieldVector ) = rcompare (x, y; true )
0 commit comments