Skip to content

Commit 688e285

Browse files
committed
Use * to denote matrix multiplication
1 parent dd0f95e commit 688e285

30 files changed

+341
-235
lines changed

examples/hybrid/staggered_nonhydrostatic_model.jl

+39-31
Original file line numberDiff line numberDiff line change
@@ -331,11 +331,11 @@ function implicit_equation_jacobian!(j, Y, p, δtγ, t)
331331
# norm_sqr(C123(ᶜuₕ) + C123(ᶜinterp(ᶠw))) / 2 =
332332
# ACT12(ᶜuₕ) * ᶜuₕ / 2 + ACT3(ᶜinterp(ᶠw)) * ᶜinterp(ᶠw) / 2
333333
# ∂(ᶜK)/∂(ᶠw) = ACT3(ᶜinterp(ᶠw)) * ᶜinterp_matrix()
334-
@. ∂ᶜK∂ᶠw = DiagonalMatrixRow(adjoint(CT3(ᶜinterp(ᶠw)))) ᶜinterp_matrix()
334+
@. ∂ᶜK∂ᶠw = DiagonalMatrixRow(adjoint(CT3(ᶜinterp(ᶠw)))) * ᶜinterp_matrix()
335335

336336
# ᶜρₜ = -ᶜdivᵥ(ᶠinterp(ᶜρ) * ᶠw)
337337
# ∂(ᶜρₜ)/∂(ᶠw) = -ᶜdivᵥ_matrix() * ᶠinterp(ᶜρ) * ᶠg³³
338-
@. ∂ᶜρₜ∂ᶠ𝕄 = -(ᶜdivᵥ_matrix()) DiagonalMatrixRow(ᶠinterp(ᶜρ) * g³³(ᶠgⁱʲ))
338+
@. ∂ᶜρₜ∂ᶠ𝕄 = -(ᶜdivᵥ_matrix()) * DiagonalMatrixRow(ᶠinterp(ᶜρ) * g³³(ᶠgⁱʲ))
339339

340340
if :ρθ in propertynames(Y.c)
341341
ᶜρθ = Y.c.ρθ
@@ -349,14 +349,14 @@ function implicit_equation_jacobian!(j, Y, p, δtγ, t)
349349
# ᶜρθₜ = -ᶜdivᵥ(ᶠinterp(ᶜρθ) * ᶠw)
350350
# ∂(ᶜρθₜ)/∂(ᶠw) = -ᶜdivᵥ_matrix() * ᶠinterp(ᶜρθ) * ᶠg³³
351351
@. ∂ᶜ𝔼ₜ∂ᶠ𝕄 =
352-
-(ᶜdivᵥ_matrix()) DiagonalMatrixRow(ᶠinterp(ᶜρθ) * g³³(ᶠgⁱʲ))
352+
-(ᶜdivᵥ_matrix()) * DiagonalMatrixRow(ᶠinterp(ᶜρθ) * g³³(ᶠgⁱʲ))
353353
else
354354
# ᶜρθₜ = -ᶜdivᵥ(ᶠinterp(ᶜρ) * ᶠupwind_product(ᶠw, ᶜρθ / ᶜρ))
355355
# ∂(ᶜρθₜ)/∂(ᶠw) =
356356
# -ᶜdivᵥ_matrix() * ᶠinterp(ᶜρ) *
357357
# ∂(ᶠupwind_product(ᶠw, ᶜρθ / ᶜρ))/∂(ᶠw)
358358
@. ∂ᶜ𝔼ₜ∂ᶠ𝕄 =
359-
-(ᶜdivᵥ_matrix()) DiagonalMatrixRow(
359+
-(ᶜdivᵥ_matrix()) * DiagonalMatrixRow(
360360
ᶠinterp(ᶜρ) *
361361
vec_data(ᶠno_flux(ᶠupwind_product(ᶠw + εw, ᶜρθ / ᶜρ))) /
362362
vec_data(CT3(ᶠw + εw)) * g³³(ᶠgⁱʲ),
@@ -381,10 +381,12 @@ function implicit_equation_jacobian!(j, Y, p, δtγ, t)
381381
# ∂(ᶜp)/∂(ᶠw) = ∂(ᶜp)/∂(ᶜK) * ∂(ᶜK)/∂(ᶠw)
382382
# ∂(ᶜp)/∂(ᶜK) = -ᶜρ * R_d / cv_d
383383
@. ∂ᶜ𝔼ₜ∂ᶠ𝕄 =
384-
-(ᶜdivᵥ_matrix()) (
384+
-(ᶜdivᵥ_matrix()) * (
385385
DiagonalMatrixRow(ᶠinterp(ᶜρe + ᶜp) * g³³(ᶠgⁱʲ)) +
386-
DiagonalMatrixRow(CT3(ᶠw)) ᶠinterp_matrix()
387-
DiagonalMatrixRow(-(ᶜρ * R_d / cv_d)) ∂ᶜK∂ᶠw
386+
DiagonalMatrixRow(CT3(ᶠw)) *
387+
ᶠinterp_matrix() *
388+
DiagonalMatrixRow(-(ᶜρ * R_d / cv_d)) *
389+
∂ᶜK∂ᶠw
388390
)
389391
else
390392
# ᶜρeₜ =
@@ -397,15 +399,17 @@ function implicit_equation_jacobian!(j, Y, p, δtγ, t)
397399
# ∂(ᶜp)/∂(ᶠw) = ∂(ᶜp)/∂(ᶜK) * ∂(ᶜK)/∂(ᶠw)
398400
# ∂(ᶜp)/∂(ᶜK) = -ᶜρ * R_d / cv_d
399401
@. ∂ᶜ𝔼ₜ∂ᶠ𝕄 =
400-
-(ᶜdivᵥ_matrix()) DiagonalMatrixRow(ᶠinterp(ᶜρ)) (
402+
-(ᶜdivᵥ_matrix()) *
403+
DiagonalMatrixRow(ᶠinterp(ᶜρ)) *
404+
(
401405
DiagonalMatrixRow(
402406
vec_data(
403407
ᶠno_flux(
404408
ᶠupwind_product(ᶠw + εw, (ᶜρe + ᶜp) / ᶜρ),
405409
),
406410
) / vec_data(CT3(ᶠw + εw)) * g³³(ᶠgⁱʲ),
407411
) +
408-
ᶠno_flux_row(ᶠupwind_product_matrix(ᶠw))
412+
ᶠno_flux_row(ᶠupwind_product_matrix(ᶠw)) *
409413
(-R_d / cv_d * ∂ᶜK∂ᶠw)
410414
)
411415
end
@@ -414,11 +418,11 @@ function implicit_equation_jacobian!(j, Y, p, δtγ, t)
414418
# ∂ᶜ𝔼ₜ∂ᶠ𝕄 has 3 diagonals instead of 5
415419
if isnothing(ᶠupwind_product)
416420
@. ∂ᶜ𝔼ₜ∂ᶠ𝕄 =
417-
-(ᶜdivᵥ_matrix())
421+
-(ᶜdivᵥ_matrix()) *
418422
DiagonalMatrixRow(ᶠinterp(ᶜρe + ᶜp) * g³³(ᶠgⁱʲ))
419423
else
420424
@. ∂ᶜ𝔼ₜ∂ᶠ𝕄 =
421-
-(ᶜdivᵥ_matrix()) DiagonalMatrixRow(
425+
-(ᶜdivᵥ_matrix()) * DiagonalMatrixRow(
422426
ᶠinterp(ᶜρ) * vec_data(
423427
ᶠno_flux(ᶠupwind_product(ᶠw + εw, (ᶜρe + ᶜp) / ᶜρ)),
424428
) / vec_data(CT3(ᶠw + εw)) * g³³(ᶠgⁱʲ),
@@ -443,9 +447,9 @@ function implicit_equation_jacobian!(j, Y, p, δtγ, t)
443447
# -ᶜdivᵥ_matrix() * ᶠinterp(ᶜρe_int + ᶜp) * ᶠg³³ +
444448
# ᶜinterp_matrix() * adjoint(ᶠgradᵥ(ᶜp)) * ᶠg³³
445449
@. ∂ᶜ𝔼ₜ∂ᶠ𝕄 =
446-
-(ᶜdivᵥ_matrix())
450+
-(ᶜdivᵥ_matrix()) *
447451
DiagonalMatrixRow(ᶠinterp(ᶜρe_int + ᶜp) * g³³(ᶠgⁱʲ)) +
448-
ᶜinterp_matrix()
452+
ᶜinterp_matrix() *
449453
DiagonalMatrixRow(adjoint(ᶠgradᵥ(ᶜp)) * g³³(ᶠgⁱʲ))
450454
else
451455
# ᶜρe_intₜ =
@@ -456,12 +460,12 @@ function implicit_equation_jacobian!(j, Y, p, δtγ, t)
456460
# ∂(ᶠupwind_product(ᶠw, (ᶜρe_int + ᶜp) / ᶜρ))/∂(ᶠw) +
457461
# ᶜinterp_matrix() * adjoint(ᶠgradᵥ(ᶜp)) * ᶠg³³
458462
@. ∂ᶜ𝔼ₜ∂ᶠ𝕄 =
459-
-(ᶜdivᵥ_matrix()) DiagonalMatrixRow(
463+
-(ᶜdivᵥ_matrix()) * DiagonalMatrixRow(
460464
ᶠinterp(ᶜρ) * vec_data(
461465
ᶠno_flux(ᶠupwind_product(ᶠw + εw, (ᶜρe_int + ᶜp) / ᶜρ)),
462466
) / vec_data(CT3(ᶠw + εw)) * g³³(ᶠgⁱʲ),
463467
) +
464-
ᶜinterp_matrix()
468+
ᶜinterp_matrix() *
465469
DiagonalMatrixRow(adjoint(ᶠgradᵥ(ᶜp)) * g³³(ᶠgⁱʲ))
466470
end
467471
end
@@ -480,7 +484,8 @@ function implicit_equation_jacobian!(j, Y, p, δtγ, t)
480484
# ∂(ᶠgradᵥ(ᶜp))/∂(ᶜρθ) =
481485
# ᶠgradᵥ_matrix() * γ * R_d * (ᶜρθ * R_d / p_0)^(γ - 1)
482486
@. ∂ᶠ𝕄ₜ∂ᶜ𝔼 =
483-
-DiagonalMatrixRow(1 / ᶠinterp(ᶜρ)) ᶠgradᵥ_matrix()
487+
-DiagonalMatrixRow(1 / ᶠinterp(ᶜρ)) *
488+
ᶠgradᵥ_matrix() *
484489
DiagonalMatrixRow* R_d * (ᶜρθ * R_d / p_0)^- 1))
485490

486491
if flags.∂ᶠ𝕄ₜ∂ᶜρ_mode == :exact
@@ -489,20 +494,20 @@ function implicit_equation_jacobian!(j, Y, p, δtγ, t)
489494
# ∂(ᶠwₜ)/∂(ᶠinterp(ᶜρ)) = ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ)^2
490495
# ∂(ᶠinterp(ᶜρ))/∂(ᶜρ) = ᶠinterp_matrix()
491496
@. ∂ᶠ𝕄ₜ∂ᶜρ =
492-
DiagonalMatrixRow(ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ)^2) ᶠinterp_matrix()
497+
DiagonalMatrixRow(ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ)^2) * ᶠinterp_matrix()
493498
elseif flags.∂ᶠ𝕄ₜ∂ᶜρ_mode == :hydrostatic_balance
494499
# same as above, but we assume that ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ) =
495500
# -ᶠgradᵥ(ᶜΦ)
496501
@. ∂ᶠ𝕄ₜ∂ᶜρ =
497-
-DiagonalMatrixRow(ᶠgradᵥ(ᶜΦ) / ᶠinterp(ᶜρ)) ᶠinterp_matrix()
502+
-DiagonalMatrixRow(ᶠgradᵥ(ᶜΦ) / ᶠinterp(ᶜρ)) * ᶠinterp_matrix()
498503
end
499504
elseif :ρe in propertynames(Y.c)
500505
# ᶠwₜ = -ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ) - ᶠgradᵥ(ᶜK + ᶜΦ)
501506
# ∂(ᶠwₜ)/∂(ᶜρe) = ∂(ᶠwₜ)/∂(ᶠgradᵥ(ᶜp)) * ∂(ᶠgradᵥ(ᶜp))/∂(ᶜρe)
502507
# ∂(ᶠwₜ)/∂(ᶠgradᵥ(ᶜp)) = -1 / ᶠinterp(ᶜρ)
503508
# ∂(ᶠgradᵥ(ᶜp))/∂(ᶜρe) = ᶠgradᵥ_matrix() * R_d / cv_d
504509
@. ∂ᶠ𝕄ₜ∂ᶜ𝔼 =
505-
-DiagonalMatrixRow(1 / ᶠinterp(ᶜρ)) (ᶠgradᵥ_matrix() * R_d / cv_d)
510+
-DiagonalMatrixRow(1 / ᶠinterp(ᶜρ)) * (ᶠgradᵥ_matrix() * R_d / cv_d)
506511

507512
if flags.∂ᶠ𝕄ₜ∂ᶜρ_mode == :exact
508513
# ᶠwₜ = -ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ) - ᶠgradᵥ(ᶜK + ᶜΦ)
@@ -515,24 +520,26 @@ function implicit_equation_jacobian!(j, Y, p, δtγ, t)
515520
# ∂(ᶠwₜ)/∂(ᶠinterp(ᶜρ)) = ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ)^2
516521
# ∂(ᶠinterp(ᶜρ))/∂(ᶜρ) = ᶠinterp_matrix()
517522
@. ∂ᶠ𝕄ₜ∂ᶜρ =
518-
-DiagonalMatrixRow(1 / ᶠinterp(ᶜρ)) ᶠgradᵥ_matrix()
523+
-DiagonalMatrixRow(1 / ᶠinterp(ᶜρ)) *
524+
ᶠgradᵥ_matrix() *
519525
DiagonalMatrixRow(R_d * (-(ᶜK + ᶜΦ) / cv_d + T_tri)) +
520-
DiagonalMatrixRow(ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ)^2) ᶠinterp_matrix()
526+
DiagonalMatrixRow(ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ)^2) * ᶠinterp_matrix()
521527
elseif flags.∂ᶠ𝕄ₜ∂ᶜρ_mode == :hydrostatic_balance
522528
# same as above, but we assume that ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ) =
523529
# -ᶠgradᵥ(ᶜΦ) and that ᶜK is negligible compared ot ᶜΦ
524530
@. ∂ᶠ𝕄ₜ∂ᶜρ =
525-
-DiagonalMatrixRow(1 / ᶠinterp(ᶜρ)) ᶠgradᵥ_matrix()
531+
-DiagonalMatrixRow(1 / ᶠinterp(ᶜρ)) *
532+
ᶠgradᵥ_matrix() *
526533
DiagonalMatrixRow(R_d * (-(ᶜΦ) / cv_d + T_tri)) -
527-
DiagonalMatrixRow(ᶠgradᵥ(ᶜΦ) / ᶠinterp(ᶜρ)) ᶠinterp_matrix()
534+
DiagonalMatrixRow(ᶠgradᵥ(ᶜΦ) / ᶠinterp(ᶜρ)) * ᶠinterp_matrix()
528535
end
529536
elseif :ρe_int in propertynames(Y.c)
530537
# ᶠwₜ = -ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ) - ᶠgradᵥ(ᶜK + ᶜΦ)
531538
# ∂(ᶠwₜ)/∂(ᶜρe_int) = ∂(ᶠwₜ)/∂(ᶠgradᵥ(ᶜp)) * ∂(ᶠgradᵥ(ᶜp))/∂(ᶜρe_int)
532539
# ∂(ᶠwₜ)/∂(ᶠgradᵥ(ᶜp)) = -1 / ᶠinterp(ᶜρ)
533540
# ∂(ᶠgradᵥ(ᶜp))/∂(ᶜρe_int) = ᶠgradᵥ_matrix() * R_d / cv_d
534541
@. ∂ᶠ𝕄ₜ∂ᶜ𝔼 =
535-
DiagonalMatrixRow(-1 / ᶠinterp(ᶜρ)) (ᶠgradᵥ_matrix() * R_d / cv_d)
542+
DiagonalMatrixRow(-1 / ᶠinterp(ᶜρ)) * (ᶠgradᵥ_matrix() * R_d / cv_d)
536543

537544
if flags.∂ᶠ𝕄ₜ∂ᶜρ_mode == :exact
538545
# ᶠwₜ = -ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ) - ᶠgradᵥ(ᶜK + ᶜΦ)
@@ -544,16 +551,16 @@ function implicit_equation_jacobian!(j, Y, p, δtγ, t)
544551
# ∂(ᶠwₜ)/∂(ᶠinterp(ᶜρ)) = ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ)^2
545552
# ∂(ᶠinterp(ᶜρ))/∂(ᶜρ) = ᶠinterp_matrix()
546553
@. ∂ᶠ𝕄ₜ∂ᶜρ =
547-
-DiagonalMatrixRow(1 / ᶠinterp(ᶜρ))
554+
-DiagonalMatrixRow(1 / ᶠinterp(ᶜρ)) *
548555
(ᶠgradᵥ_matrix() * R_d * T_tri) +
549-
DiagonalMatrixRow(ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ)^2) ᶠinterp_matrix()
556+
DiagonalMatrixRow(ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ)^2) * ᶠinterp_matrix()
550557
elseif flags.∂ᶠ𝕄ₜ∂ᶜρ_mode == :hydrostatic_balance
551558
# same as above, but we assume that ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ) =
552559
# -ᶠgradᵥ(ᶜΦ)
553560
@. ∂ᶠ𝕄ₜ∂ᶜρ =
554-
DiagonalMatrixRow(-1 / ᶠinterp(ᶜρ))
561+
DiagonalMatrixRow(-1 / ᶠinterp(ᶜρ)) *
555562
(ᶠgradᵥ_matrix() * R_d * T_tri) -
556-
DiagonalMatrixRow(ᶠgradᵥ(ᶜΦ) / ᶠinterp(ᶜρ)) ᶠinterp_matrix()
563+
DiagonalMatrixRow(ᶠgradᵥ(ᶜΦ) / ᶠinterp(ᶜρ)) * ᶠinterp_matrix()
557564
end
558565
end
559566

@@ -571,13 +578,14 @@ function implicit_equation_jacobian!(j, Y, p, δtγ, t)
571578
# ∂(ᶠwₜ)/∂(ᶠgradᵥ(ᶜK + ᶜΦ)) = -1
572579
# ∂(ᶠgradᵥ(ᶜK + ᶜΦ))/∂(ᶜK) = ᶠgradᵥ_matrix()
573580
if :ρθ in propertynames(Y.c) || :ρe_int in propertynames(Y.c)
574-
@. ∂ᶠ𝕄ₜ∂ᶠ𝕄 = -(ᶠgradᵥ_matrix()) ∂ᶜK∂ᶠw
581+
@. ∂ᶠ𝕄ₜ∂ᶠ𝕄 = -(ᶠgradᵥ_matrix()) * ∂ᶜK∂ᶠw
575582
elseif :ρe in propertynames(Y.c)
576583
@. ∂ᶠ𝕄ₜ∂ᶠ𝕄 =
577584
-(
578-
DiagonalMatrixRow(1 / ᶠinterp(ᶜρ)) ᶠgradᵥ_matrix()
585+
DiagonalMatrixRow(1 / ᶠinterp(ᶜρ)) *
586+
ᶠgradᵥ_matrix() *
579587
DiagonalMatrixRow(-(ᶜρ * R_d / cv_d)) + ᶠgradᵥ_matrix()
580-
) ∂ᶜK∂ᶠw
588+
) * ∂ᶜK∂ᶠw
581589
end
582590

583591
I = one(∂R∂Y)

src/MatrixFields/MatrixFields.jl

+31-2
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ matrices, one for each column of the `Field`. Such `Field`s are called
1212
for them:
1313
- Constructors, e.g., `matrix_field = @. BidiagonalMatrixRow(field1, field2)`
1414
- Linear combinations, e.g., `@. 3 * matrix_field1 + matrix_field2 / 3`
15-
- Matrix-vector multiplication, e.g., `@. matrix_field field`
16-
- Matrix-matrix multiplication, e.g., `@. matrix_field1 matrix_field2`
15+
- Matrix-vector multiplication, e.g., `@. matrix_field * field`
16+
- Matrix-matrix multiplication, e.g., `@. matrix_field1 * matrix_field2`
1717
- Compatibility with `LinearAlgebra.I`, e.g., `@. matrix_field = (4I,)` or
1818
`@. matrix_field - (4I,)`
1919
- Integration with `RecursiveApply`, e.g., the entries of `matrix_field` can be
@@ -107,6 +107,35 @@ include("field_matrix_solver.jl")
107107
include("field_matrix_iterative_solver.jl")
108108
include("field_matrix_with_solver.jl")
109109

110+
const FieldOrStencilStyleType = Union{
111+
Fields.Field,
112+
Base.Broadcast.Broadcasted{<:Fields.AbstractFieldStyle},
113+
Operators.StencilBroadcasted{<:Operators.AbstractStencilStyle},
114+
}
115+
116+
Base.Broadcast.broadcasted(
117+
::typeof(*),
118+
field_or_broadcasted::FieldOrStencilStyleType,
119+
args...,
120+
) =
121+
unrolled_reduce(args; init = field_or_broadcasted) do arg1, arg2
122+
is_matrix_multiplication =
123+
eltype(arg1) <: BandMatrixRow && arg2 isa FieldOrStencilStyleType
124+
op = is_matrix_multiplication ? MultiplyColumnwiseBandMatrixField() :
125+
Base.Broadcast.broadcasted(op, arg1, arg2)
126+
end
127+
Base.Broadcast.broadcasted(
128+
::typeof(*),
129+
single_value_or_broadcasted::SingleValueStyleType,
130+
field_or_broadcasted::FieldOrStencilStyleType,
131+
args...,
132+
) = Base.Broadcast.broadcasted(
133+
,
134+
single_value_or_broadcasted,
135+
Base.Broadcast.broadcasted(*, field_or_broadcasted, args...),
136+
)
137+
# TODO: Generalize this to handle, e.g., @. scalar * scalar * matrix * matrix.
138+
110139
function Base.show(io::IO, field::ColumnwiseBandMatrixField)
111140
print(io, eltype(field), "-valued Field")
112141
if eltype(eltype(field)) <: Number

src/MatrixFields/field_name_dict.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ function Base.Broadcast.broadcasted(
553553
elseif entry2 isa ScalingFieldMatrixEntry
554554
Base.Broadcast.broadcasted(*, entry1, (scaling_value(entry2),))
555555
else
556-
Base.Broadcast.broadcasted(, entry1, entry2)
556+
Base.Broadcast.broadcasted(*, entry1, entry2)
557557
end
558558
end
559559
length(summand_bcs) == 1 ? summand_bcs[1] :

0 commit comments

Comments
 (0)