From 07d8f004524c46286052d6f2e188f78cf8a7dbcf Mon Sep 17 00:00:00 2001 From: Talon Chandler Date: Fri, 31 Jan 2025 11:55:34 -0800 Subject: [PATCH] stretched_multiply functions & tests --- tests/test_filter.py | 42 ++++++++++++++++++ waveorder/filter.py | 103 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 145 insertions(+) create mode 100644 tests/test_filter.py create mode 100644 waveorder/filter.py diff --git a/tests/test_filter.py b/tests/test_filter.py new file mode 100644 index 00000000..f3d42714 --- /dev/null +++ b/tests/test_filter.py @@ -0,0 +1,42 @@ +from waveorder import filter +import torch +import pytest + + +def test_stretched_multiply(): + small_array = torch.tensor([[1, 2], [3, 4]]) + large_array = torch.tensor( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]] + ) + result = filter.stretched_multiply(small_array, large_array) + expected = torch.tensor( + [[1, 2, 6, 8], [5, 6, 14, 16], [27, 30, 44, 48], [39, 42, 60, 64]] + ) + assert torch.all(result == expected) + assert torch.all( + filter.stretched_multiply(large_array, large_array) == large_array**2 + ) + + # Test that output dims are correct + rand_array_3x3x3 = torch.rand((3, 3, 3)) + rand_array_100x100x100 = torch.rand((100, 100, 100)) + result = filter.stretched_multiply( + rand_array_3x3x3, rand_array_100x100x100 + ) + assert result.shape == (100, 100, 100) + + +def test_stretched_multiply_incompatible_dims(): + # small_array > large_array + small_array = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + large_array = torch.tensor([[1, 2], [3, 4]]) + with pytest.raises(ValueError): + filter.stretched_multiply(small_array, large_array) + + # Mismatched dims + small_array = torch.tensor([[1, 2], [3, 4]]) + large_array = torch.tensor( + [[[1, 2], [4, 5], [7, 8]], [[10, 11], [13, 14], [16, 17]]] + ) + with pytest.raises(ValueError): + filter.stretched_multiply(small_array, large_array) diff --git a/waveorder/filter.py b/waveorder/filter.py new file mode 100644 index 00000000..698badce --- /dev/null +++ b/waveorder/filter.py @@ -0,0 +1,103 @@ +import numpy as np +import torch + + +def stretched_multiply( + small_array: torch.Tensor, large_array: torch.Tensor +) -> torch.Tensor: + """ + Effectively "stretches" small_array onto large_array before multiplying. + + Instead of upsampling small_array, this function uses a "block element-wise" + multiplication by breaking the large_array into blocks before + element-wise multiplication with the small_array. + + For example, a `stretched_multiply` of a 3x3 array by a 100x100 array will + break the 100x100 array into 3x3 blocks, with sizes + [[34x34, 33x34, 33x33], + [34x33, 33x33, 33x33], + [34x33, 33x33, 33x33]] + and multiply each block by the corresponding element in the 3x3 array. + + Returns an array with the same shape as large_array. + + Works for arbitrary dimensions. + + Parameters + ---------- + small_array : torch.Tensor + A smaller array whose elements will be "stretched" onto blocks in the large array. + large_array : torch.Tensor + A larger array that will be divided into blocks and multiplied by the small array. + + Returns + ------- + torch.Tensor + Resulting tensor with shape matching large_array. + + Example + ------- + small_array = torch.tensor([[1, 2], + [3, 4]]) + + large_array = torch.tensor([[1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16]]) + + stretched_multiply(small_array, large_array) returns + + [[ 1, 2, 6, 8], + [ 5, 6, 14, 16], + [ 27, 30, 44, 48], + [ 39, 42, 60, 64]] + """ + # Ensure all dimensions of small_array are smaller than large_array + if any(s > l for s, l in zip(small_array.shape, large_array.shape)): + raise ValueError( + "All dimensions of small_array must be <= large_array" + ) + + # Ensure the number of dimensions match + if small_array.ndim != large_array.ndim: + raise ValueError( + "small_array and large_array must have the same number of dimensions" + ) + + # Get shapes + s_shape = small_array.shape + l_shape = large_array.shape + + # Compute base block sizes using integer division + # This gives the approximate size of each block before handling remainders + base_block_size = tuple(l // s for l, s in zip(l_shape, s_shape)) + + # Compute remainder (extra elements that don't fit evenly) + remainder = tuple(l % s for l, s in zip(l_shape, s_shape)) + + # Compute block boundaries by spreading remainder elements evenly + indices = [] + for b, s, r in zip(base_block_size, s_shape, remainder): + # Create an array where the first 'r' blocks get an extra +1 + # Example: if b=9, s=11, r=1 -> first remainder block gets size 10 instead of 9 + block_sizes = [(b + 1) if i < r else b for i in range(s)] + + # Compute cumulative sum to determine start and end indices + # Example: [0, 10, 19, 28, ..., 100] gives the split points + indices.append(np.cumsum([0] + block_sizes)) + + # Create output array initialized as a copy of the large array + result = large_array.clone() + + # Iterate over small_array indices using ndindex + for idx in np.ndindex(s_shape): + # Extract multi-dimensional indices + slices = tuple( + slice(indices[dim][idx[dim]], indices[dim][idx[dim] + 1]) + for dim in range(small_array.ndim) + ) + + # Multiply the corresponding block in the large array by the value from the small array + result[slices] *= small_array[idx] + + return result