Skip to content

Commit

Permalink
wip tiling
Browse files Browse the repository at this point in the history
  • Loading branch information
talonchandler committed Feb 18, 2025
1 parent 9028e16 commit b7499cb
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 3 deletions.
11 changes: 11 additions & 0 deletions tests/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,17 @@
import torch
import pytest

def test_apply_transfer_function_filter_tiled():
input_array = torch.ones((100, 100))
transfer_function = torch.tensor([[1, 0], [0, 0]])
tile_size = (35, 35)
overlap_size = (5, 5)

result = filter.apply_transfer_function_filter_tiled(
transfer_function, input_array, tile_size, overlap_size
)
assert result == 0


def test_apply_transfer_function_filter():
input_array = torch.tensor([[[1.0, 2.0], [3.0, 4.0]]])
Expand Down
75 changes: 72 additions & 3 deletions waveorder/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,73 @@
import torch


def apply_transfer_function_filter_tiled(
transfer_function: torch.Tensor,
input_array: torch.Tensor,
tile_size: tuple,
overlap_size: tuple,
) -> torch.Tensor:

# Get the shape of the input array
input_shape = input_array.shape

# Calculate the number of tiles in each dimension
num_tiles = tuple(
(input_shape[dim] + tile_size[dim] - 1) // tile_size[dim]
for dim in range(len(tile_size))
)

# Initialize the output array
output_array = torch.zeros_like(input_array)

# Iterate over each tile
for tile_idx in np.ndindex(num_tiles):
# Calculate the start and end indices for each dimension
start_indices = tuple(
max(0, tile_idx[dim] * tile_size[dim] - overlap_size[dim])
for dim in range(len(tile_size))
)
end_indices = tuple(
min(
input_shape[dim],
(tile_idx[dim] + 1) * tile_size[dim] + overlap_size[dim],
)
for dim in range(len(tile_size))
)

# Extract the tile from the input array
tile_slices = tuple(
slice(start, end) for start, end in zip(start_indices, end_indices)
)
input_tile = input_array[tile_slices]

# Apply the transfer function filter to the tile
filtered_tile = apply_filter_bank(transfer_function, input_tile)

# Calculate the region to add the filtered tile to the output array
output_start_indices = tuple(
tile_idx[dim] * tile_size[dim] for dim in range(len(tile_size))
)
output_end_indices = tuple(
min(
input_shape[dim],
output_start_indices[dim] + filtered_tile.shape[dim],
)
for dim in range(len(tile_size))
)
output_slices = tuple(
slice(start, end)
for start, end in zip(output_start_indices, output_end_indices)
)
import pdb

pdb.set_trace()
# Add the filtered tile to the output array
output_array[output_slices] += filtered_tile

return output_array


def apply_filter_bank(
io_filter_bank: torch.Tensor,
i_input_array: torch.Tensor,
Expand Down Expand Up @@ -49,7 +116,7 @@ def apply_filter_bank(
torch.Tensor
The filtered output array with the same shape and dtype as input_array.
"""

# Ensure all dimensions of transfer_function are smaller than or equal to input_array
if any(
t > i
Expand Down Expand Up @@ -83,7 +150,7 @@ def apply_filter_bank(
]
flat_pad_sizes = list(itertools.chain(*pad_sizes))
padded_input_array = torch.nn.functional.pad(i_input_array, flat_pad_sizes)

# Apply the transfer function in the frequency domain
fft_dims = [d for d in range(1, i_input_array.ndim)]
padded_input_spectrum = torch.fft.fftn(padded_input_array, dim=fft_dims)
Expand All @@ -92,7 +159,9 @@ def apply_filter_bank(
# If this is a bottleneck, consider extending `stretched_multiply` to
# a `stretched_matrix_multiply` that uses an call like
# torch.einsum('io..., i... -> o...', io_filter_bank, padded_input_spectrum)
padded_output_spectrum = torch.zeros((O,) + spatial_dims, dtype=padded_input_spectrum.dtype)
padded_output_spectrum = torch.zeros(
(O,) + spatial_dims, dtype=padded_input_spectrum.dtype
)
for i_idx in range(I):
for o_idx in range(O):
padded_output_spectrum[o_idx] += stretched_multiply(
Expand Down

0 comments on commit b7499cb

Please sign in to comment.