Skip to content

Commit

Permalink
stretched_multiply functions & tests
Browse files Browse the repository at this point in the history
  • Loading branch information
talonchandler committed Jan 31, 2025
1 parent 1451bf1 commit 07d8f00
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 0 deletions.
42 changes: 42 additions & 0 deletions tests/test_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from waveorder import filter
import torch
import pytest


def test_stretched_multiply():
small_array = torch.tensor([[1, 2], [3, 4]])
large_array = torch.tensor(
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
)
result = filter.stretched_multiply(small_array, large_array)
expected = torch.tensor(
[[1, 2, 6, 8], [5, 6, 14, 16], [27, 30, 44, 48], [39, 42, 60, 64]]
)
assert torch.all(result == expected)
assert torch.all(
filter.stretched_multiply(large_array, large_array) == large_array**2
)

# Test that output dims are correct
rand_array_3x3x3 = torch.rand((3, 3, 3))
rand_array_100x100x100 = torch.rand((100, 100, 100))
result = filter.stretched_multiply(
rand_array_3x3x3, rand_array_100x100x100
)
assert result.shape == (100, 100, 100)


def test_stretched_multiply_incompatible_dims():
# small_array > large_array
small_array = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
large_array = torch.tensor([[1, 2], [3, 4]])
with pytest.raises(ValueError):
filter.stretched_multiply(small_array, large_array)

# Mismatched dims
small_array = torch.tensor([[1, 2], [3, 4]])
large_array = torch.tensor(
[[[1, 2], [4, 5], [7, 8]], [[10, 11], [13, 14], [16, 17]]]
)
with pytest.raises(ValueError):
filter.stretched_multiply(small_array, large_array)
103 changes: 103 additions & 0 deletions waveorder/filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import numpy as np
import torch


def stretched_multiply(
small_array: torch.Tensor, large_array: torch.Tensor
) -> torch.Tensor:
"""
Effectively "stretches" small_array onto large_array before multiplying.
Instead of upsampling small_array, this function uses a "block element-wise"
multiplication by breaking the large_array into blocks before
element-wise multiplication with the small_array.
For example, a `stretched_multiply` of a 3x3 array by a 100x100 array will
break the 100x100 array into 3x3 blocks, with sizes
[[34x34, 33x34, 33x33],
[34x33, 33x33, 33x33],
[34x33, 33x33, 33x33]]
and multiply each block by the corresponding element in the 3x3 array.
Returns an array with the same shape as large_array.
Works for arbitrary dimensions.
Parameters
----------
small_array : torch.Tensor
A smaller array whose elements will be "stretched" onto blocks in the large array.
large_array : torch.Tensor
A larger array that will be divided into blocks and multiplied by the small array.
Returns
-------
torch.Tensor
Resulting tensor with shape matching large_array.
Example
-------
small_array = torch.tensor([[1, 2],
[3, 4]])
large_array = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16]])
stretched_multiply(small_array, large_array) returns
[[ 1, 2, 6, 8],
[ 5, 6, 14, 16],
[ 27, 30, 44, 48],
[ 39, 42, 60, 64]]
"""
# Ensure all dimensions of small_array are smaller than large_array
if any(s > l for s, l in zip(small_array.shape, large_array.shape)):
raise ValueError(
"All dimensions of small_array must be <= large_array"
)

# Ensure the number of dimensions match
if small_array.ndim != large_array.ndim:
raise ValueError(
"small_array and large_array must have the same number of dimensions"
)

# Get shapes
s_shape = small_array.shape
l_shape = large_array.shape

# Compute base block sizes using integer division
# This gives the approximate size of each block before handling remainders
base_block_size = tuple(l // s for l, s in zip(l_shape, s_shape))

# Compute remainder (extra elements that don't fit evenly)
remainder = tuple(l % s for l, s in zip(l_shape, s_shape))

# Compute block boundaries by spreading remainder elements evenly
indices = []
for b, s, r in zip(base_block_size, s_shape, remainder):
# Create an array where the first 'r' blocks get an extra +1
# Example: if b=9, s=11, r=1 -> first remainder block gets size 10 instead of 9
block_sizes = [(b + 1) if i < r else b for i in range(s)]

# Compute cumulative sum to determine start and end indices
# Example: [0, 10, 19, 28, ..., 100] gives the split points
indices.append(np.cumsum([0] + block_sizes))

# Create output array initialized as a copy of the large array
result = large_array.clone()

# Iterate over small_array indices using ndindex
for idx in np.ndindex(s_shape):
# Extract multi-dimensional indices
slices = tuple(
slice(indices[dim][idx[dim]], indices[dim][idx[dim] + 1])
for dim in range(small_array.ndim)
)

# Multiply the corresponding block in the large array by the value from the small array
result[slices] *= small_array[idx]

return result

0 comments on commit 07d8f00

Please sign in to comment.