Skip to content

Commit

Permalink
revised strategy --- prepad, require stretch-multiply to be divisible…
Browse files Browse the repository at this point in the history
…, then crop
  • Loading branch information
talonchandler committed Feb 8, 2025
1 parent f62f909 commit 16c87a7
Showing 1 changed file with 56 additions and 40 deletions.
96 changes: 56 additions & 40 deletions waveorder/filter.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,26 @@
import itertools
import numpy as np
import torch


def apply_transfer_function_filter(
transfer_function: torch.Tensor, input_array: torch.Tensor,
transfer_function: torch.Tensor,
input_array: torch.Tensor,
) -> torch.Tensor:
"""
Applies a transfer function filter to an input array.
transfer_function.shape must be smaller or equal to input_array.shape in all
transfer_function.shape must be smaller or equal to input_array.shape in all
dimensions. When transfer_function is smaller, it is effectively "stretched"
to apply the filter.
transfer_function is in "wrapped" format, i.e., the zero frequency is the
zeroth element.
zeroth element.
input_array and transfer_function must have inverse sample spacing, i.e.,
is input_array contains samples spaced by dx, then transfer_function must
have extent 1/dx. Note that there is no need for transfer_function to have
sample spacing 1/(n*dx) because transfer_function will be stretched.
is input_array contains samples spaced by dx, then transfer_function must
have extent 1/dx. Note that there is no need for transfer_function to have
sample spacing 1/(n*dx) because transfer_function will be stretched.
Parameters
----------
Expand Down Expand Up @@ -45,32 +48,50 @@ def apply_transfer_function_filter(
"transfer_function and input_array must have the same number of dimensions"
)

input_spectrum = torch.fft.fftn(input_array)
output_spectrum = stretched_multiply(transfer_function, input_spectrum)
# Pad input_array until each dimension is divisible by transfer_function
pad_sizes = [
(0, (t - (i % t)) % t)
for t, i in zip(transfer_function.shape[::-1], input_array.shape[::-1])
]
flat_pad_sizes = list(itertools.chain(*pad_sizes))
padded_input_array = torch.nn.functional.pad(input_array, flat_pad_sizes)

# Apply the transfer function in the frequency domain
padded_input_spectrum = torch.fft.fftn(padded_input_array)
padded_output_spectrum = stretched_multiply(
transfer_function, padded_input_spectrum
)

# Casts to input_array dtype, which typically ignores imaginary part
result = torch.fft.ifftn(output_spectrum).type(input_array.dtype)
padded_result = torch.fft.ifftn(padded_output_spectrum).type(
input_array.dtype
)

# Remove padding and return
slices = tuple(slice(0, i) for i in input_array.shape)
return padded_result[slices]

return result

def stretched_multiply(
small_array: torch.Tensor, large_array: torch.Tensor
) -> torch.Tensor:
"""
Effectively "stretches" small_array onto large_array before multiplying.
Each dimension of large_array must be divisible by each dimension of small_array.
Instead of upsampling small_array, this function uses a "block element-wise"
multiplication by breaking the large_array into blocks before
element-wise multiplication with the small_array.
multiplication by breaking the large_array into blocks before element-wise
multiplication with the small_array.
For example, a `stretched_multiply` of a 3x3 array by a 100x100 array will
zero pad the 100x100 array so that it is divisible into 3x3 blocks with sizes
[[34x34, 34x34, 34x34],
[34x34, 34x34, 34x34],
[34x34, 34x34, 34x34]]
and multiply each block by the corresponding element in the 3x3 array.
For example, a `stretched_multiply` of a 3x3 array by a 99x99 array will
divide the 99x99 array into 33x33 blocks
[[33x33, 33x33, 33x33],
[33x33, 33x33, 33x33],
[33x33, 33x33, 33x33]]
and multiply each block by the corresponding element in the 3x3 array.
Returns an array with the same shape as large_array (padding is cropped).
Returns an array with the same shape as large_array.
Works for arbitrary dimensions.
Expand Down Expand Up @@ -103,10 +124,11 @@ def stretched_multiply(
[ 27, 30, 44, 48],
[ 39, 42, 60, 64]]
"""
# Ensure all dimensions of small_array are smaller than large_array
if any(s > l for s, l in zip(small_array.shape, large_array.shape)):

# Ensure each dimension of large_array is divisible by each dimension of small_array
if any(l % s != 0 for s, l in zip(small_array.shape, large_array.shape)):
raise ValueError(
"All dimensions of small_array must be <= large_array"
"Each dimension of large_array must be divisible by each dimension of small_array"
)

# Ensure the number of dimensions match
Expand All @@ -119,25 +141,19 @@ def stretched_multiply(
s_shape = small_array.shape
l_shape = large_array.shape

# Compute padding sizes to make large_array divisible by small_array
pad_sizes = [(0, (s - (l % s)) % s) for l, s in zip(l_shape, s_shape)]
pad_sizes_flat = [size for pair in pad_sizes for size in pair]

# Pad the large array
padded_large_array = torch.nn.functional.pad(large_array, pad_sizes_flat)

# Reshape the padded large array into blocks
new_shape = tuple(p // s for p, s in zip(padded_large_array.shape, s_shape)) + s_shape
reshaped_large_array = padded_large_array.reshape(new_shape)

# Multiply the reshaped large array with the small array
result_blocks = reshaped_large_array * small_array
# Reshape both array into blocks
block_shape = tuple(p // s for p, s in zip(l_shape, s_shape))
new_large_shape = tuple(itertools.chain(*zip(s_shape, block_shape)))
new_small_shape = tuple(
itertools.chain(*zip(s_shape, small_array.ndim * (1,)))
)
reshaped_large_array = large_array.reshape(new_large_shape)
reshaped_small_array = small_array.reshape(new_small_shape)

# Reshape the result back to the padded large array shape
result_padded = result_blocks.reshape(padded_large_array.shape)
# Multiply the reshaped arrays
reshaped_result = reshaped_large_array * reshaped_small_array

# Unpad the result to get back to the original large array shape
slices = tuple(slice(0, l) for l in l_shape)
result = result_padded[slices]
# Reshape the result back to the large array shape
result = reshaped_result.reshape(l_shape)

return result

0 comments on commit 16c87a7

Please sign in to comment.