Skip to content

Commit

Permalink
debugging progress
Browse files Browse the repository at this point in the history
  • Loading branch information
talonchandler committed Oct 7, 2024
1 parent b3ec8c6 commit 8e1f292
Showing 1 changed file with 27 additions and 6 deletions.
33 changes: 27 additions & 6 deletions waveorder/models/inplane_oriented_thick_pol3d_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np

from torch import Tensor
from typing import Literal
from torch.nn.functional import avg_pool3d
from waveorder import optics, sampling, stokes, util
from waveorder.visuals.napari_visuals import add_transfer_function_to_viewer
Expand All @@ -16,8 +17,8 @@ def generate_test_phantom(zyx_shape):
margin=50,
)
c00 = yx_star
c2_2 = -torch.sin(2 * yx_theta) * yx_star
c22 = -torch.cos(2 * yx_theta) * yx_star
c2_2 = -torch.sin(2 * yx_theta) * yx_star #torch.zeros_like(c00)
c22 = -torch.cos(2 * yx_theta) * yx_star #torch.zeros_like(c00) #

# Put in a center slices of a 3D object
center_slice_object = torch.stack((c00, c2_2, c22), dim=0)
Expand Down Expand Up @@ -117,7 +118,8 @@ def _calculate_wrap_unsafe_transfer_function(
swing, scheme=scheme
)

input_jones = torch.tensor([0.0 + 1.0j, 1.0 + 0j]) # circular
input_jones = torch.tensor([0.0 - 1.0j, 1.0 + 0j]) # circular
# input_jones = torch.tensor([0 + 0j, 1 + 0j]) # linear

# Calculate frequencies
y_frequencies, x_frequencies = util.generate_frequencies(
Expand All @@ -131,6 +133,7 @@ def _calculate_wrap_unsafe_transfer_function(
)
if invert_phase_contrast:
z_position_list = torch.flip(z_position_list, dims=(0,))
z_frequencies = torch.fft.fftfreq(z_total, d=z_pixel_size)

# 2D pupils
ill_pupil = optics.generate_pupil(
Expand Down Expand Up @@ -195,13 +198,28 @@ def _calculate_wrap_unsafe_transfer_function(
G_3D = torch.abs(torch.fft.ifft(G, dim=-3)) * (-1j)
S_3D = torch.fft.ifft(S, dim=-3)

# cleanup
freq_shape = z_position_list.shape + x_frequencies.shape

z_broadcast = torch.broadcast_to(z_frequencies[:, None, None], freq_shape)
y_broadcast = torch.broadcast_to(y_frequencies[None, :, :], freq_shape)
x_broadcast = torch.broadcast_to(x_frequencies[None, :, :], freq_shape)

nu_rr = torch.sqrt(z_broadcast**2 + y_broadcast**2 + x_broadcast**2)
wavelength = wavelength_illumination / index_of_refraction_media
nu_max = (17 / 16) / (wavelength)
nu_min = (15 / 16) / (wavelength)

mask = torch.logical_and(nu_rr < nu_max, nu_rr > nu_min)

P_3D *= mask
G_3D *= mask
S_3D *= mask

# Main part
PG_3D = torch.einsum("zyx,ipzyx->ipzyx", P_3D, G_3D)
PS_3D = torch.einsum("zyx,jzyx,kzyx->jkzyx", P_3D, S_3D, torch.conj(S_3D))

PG_3D /= torch.amax(torch.abs(PG_3D))
PS_3D /= torch.amax(torch.abs(PS_3D))

pg = torch.fft.fftn(PG_3D, dim=(-3, -2, -1))
ps = torch.fft.fftn(PS_3D, dim=(-3, -2, -1))
Expand Down Expand Up @@ -267,14 +285,17 @@ def apply_transfer_function(
)
szyx_data = torch.fft.ifftn(sZYX_data, dim=(1, 2, 3))

return (50 * szyx_data) + 0.1 * torch.randn(szyx_data.shape)
return 50 * szyx_data # + 0.1 * torch.randn(szyx_data.shape)


def apply_inverse_transfer_function(
szyx_data: Tensor,
singular_system: tuple[Tensor],
intensity_to_stokes_matrix: Tensor,
reconstruction_algorithm: Literal["Tikhonov", "TV"] = "Tikhonov",
regularization_strength: float = 1e-3,
TV_rho_strength: float = 1e-3,
TV_iterations: int = 10,
):
sZYX_data = torch.fft.fftn(szyx_data, dim=(1, 2, 3))

Expand Down

0 comments on commit 8e1f292

Please sign in to comment.