Skip to content

Commit

Permalink
Changed the behavior of matrix multiplication to ignore zero-valued e…
Browse files Browse the repository at this point in the history
…lements to save memory. (#89)
  • Loading branch information
byrdie authored Nov 2, 2024
1 parent 09ccd0a commit 67a862c
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 3 deletions.
6 changes: 6 additions & 0 deletions named_arrays/_scalars/scalars.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,12 @@ def __array_matmul__(
if result is not NotImplemented:
return result

if out is None:
if np.all(x1 == 0) or np.all(x2 == 0):
unit_1 = na.unit(x1, unit_dimensionless=1)
unit_2 = na.unit(x2, unit_dimensionless=1)
return 0 * unit_1 * unit_2

return np.multiply(
x1,
x2,
Expand Down
2 changes: 1 addition & 1 deletion named_arrays/_scalars/tests/test_scalars.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def test_matmul(

result_expected = np.multiply(array, array_2)

out = 0 * result
out = 0 * array * array_2
result_out = np.matmul(array, array_2, out=out)

assert np.all(result == result_expected)
Expand Down
2 changes: 1 addition & 1 deletion named_arrays/_vectors/cartesian/vectors_cartesian_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _sold_angle(
d3 = (b @ c) * a_
denomerator = d0 + d1 + d2 + d3

unit = numerator.unit
unit = na.unit(numerator)

if unit is not None:
numerator = numerator.to(unit).value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ class TestTemporalSpectralVectorLinearSpace(
wavelength=1 * u.nm,
),
pc=na.SpectralMatrixArray(
wavelength=na.CartesianNdVectorArray(dict(wavelength=1, x=0, y=0)),
wavelength=na.CartesianNdVectorArray(dict(wavelength=1, x=0, y=0.1)),
),
shape_wcs=dict(wavelength=5, x=_num_x, y=_num_y),
),
Expand Down

0 comments on commit 67a862c

Please sign in to comment.