Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
csaybar committed Oct 28, 2024
1 parent 5c79b22 commit ef6733e
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 32 deletions.
92 changes: 70 additions & 22 deletions supers2/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def predict(

# Check if the models are loaded
if models is None:
models = setmodel(resolution=resolution)
models = setmodel(resolution=resolution, device=X.device)

# if resolution is 10m
if resolution == "10m":
Expand Down Expand Up @@ -232,7 +232,8 @@ def predict_large(
output_fullname (Union[str, pathlib.Path]): The output image with the S2 bands
resolution (Literal["2.5m", "5m", "10m"], optional): The final resolution of the
tensor. Defaults to "2.5m".
models (Optional[dict], optional): The dictionary with the loaded models. Defaults to None.
models (Optional[dict], optional): The dictionary with the loaded models. Defaults
to None.
Returns:
pathlib.Path: The path to the output image
Expand Down Expand Up @@ -263,7 +264,7 @@ def predict_large(
overlap=overlap,
)

# Define the output metadata and create the output image
# Define the output metadata
output_metadata = metadata.copy()
output_metadata["width"] = metadata["width"] * res_n
output_metadata["height"] = metadata["height"] * res_n
Expand All @@ -276,34 +277,81 @@ def predict_large(
metadata["transform"].f,
)
output_metadata["blockxsize"] = 128 * res_n
output_metadata["blockysize"] = 128 * res_n
output_metadata["blockysize"] = 128 * res_n

# Create the output image
with rio.open(output_fullname, "w", **output_metadata) as dst:
data_np = np.zeros(
(metadata["count"], metadata["height"] * res_n, metadata["width"] * res_n),
dtype=np.uint16,
)
dst.write(data_np)

# Check if the models are loaded
if models is None:
models = setmodel(resolution=resolution, device=device)

# Iterate over the image
for index in tqdm.tqdm(nruns):

# Read a block of the image
with rio.open(output_fullname, "r+") as dst:
with rio.open(image_fullname) as src:
window = rio.windows.Window(index[1], index[0], 128, 128)
X = torch.from_numpy(src.read(window=window)).float().to(device)

# Predict the super-resolution
result = predict(X=X / 10_000, models=models, resolution=resolution) * 10_000
result[result < 0] = 0
result = result.cpu().numpy().astype(np.uint16)

# Write the block to the output
with rio.open(output_fullname, "r+") as dst:
# Define your patch (x_off, y_off, width, height)
window = rio.windows.Window(
index[1] * res_n, index[0] * res_n, 128 * res_n, 128 * res_n
)
dst.write(result, window=window)
for index, point in enumerate(tqdm.tqdm(nruns)):

# Read a block of the image
window = rio.windows.Window(point[1], point[0], 128, 128)
X = torch.from_numpy(src.read(window=window)).float().to(device)

# Predict the super-resolution
result = predict(X=X / 10_000, models=models, resolution=resolution) * 10_000
result[result < 0] = 0
result = result.cpu().numpy().astype(np.uint16)

# Define the offset in the output space
# If the point is at the border, the offset is 0
# otherwise consider the overlap
if point[1] == 0:
offset_x = 0
else:
offset_x = point[1] * res_n + overlap * res_n // 2

if point[0] == 0:
offset_y = 0
else:
offset_y = point[0] * res_n + overlap * res_n // 2

# Define the length of the patch
# The patch is always 224x224
# There is three conditions:
# - The patch is at the corner begining (0, *) or (*, 0)
# - The patch is at the corner ending (width, *) or (*, height)
# - The patch is in the middle of the image
if offset_x == 0:
skip = overlap * res_n // 2
length_x = 128 * res_n - skip
result = result[:, :, :length_x]
elif (offset_x + 128) == metadata["width"]:
length_x = 128 * res_n
result = result[:, :, :length_x]
else:
skip = overlap * res_n // 2
length_x = 128 * res_n - skip
result = result[:, :, skip:(128 * res_n)]

# Do the same for the Y axis
if offset_y == 0:
skip = overlap * res_n // 2
length_y = 128 * res_n - skip
result = result[:, :length_y, :]
elif (offset_y + 128) == metadata["height"]:
length_y = 128 * res_n
result = result[:, :length_y, :]
else:
skip = overlap * res_n // 2
length_y = 128 * res_n - overlap * res_n // 2
result = result[:, skip:(128 * res_n), :]

# Write the result in the output image
window = rio.windows.Window(offset_x, offset_y, length_x, length_y)
dst.write(result, window=window)

return pathlib.Path(output_fullname)

Expand Down
43 changes: 33 additions & 10 deletions supers2/xai/lam.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,6 @@ def Path_gradient(
- results_numpy (np.ndarray): Model outputs for each interpolated image.
- image_interpolation (np.ndarray): Interpolated images created by `path_interpolation_func`.
"""
# Set the model to training mode
model = model.train()

# Prepare image for interpolation and initialize gradient accumulation array
image_interpolation, lambda_derivative_interpolation, _ = path_interpolation_func(
Expand Down Expand Up @@ -270,11 +268,13 @@ def lam(
fold: Optional[int] = 25,
kernel_size: Optional[int] = 13,
sigma: Optional[float] = 3.5,
robustness_metric: Optional[str] = True
):
"""
Computes the Local Attribution Map (LAM) for an input tensor using a specified model
and attribution function. The function calculates the path gradient for each band in the
input tensor and combines the results to generate the LAM.
Computes the Local Attribution Map (LAM) for an input tensor using
a specified model and attribution function. The function calculates
the path gradient for each band in the input tensor and combines the
results to generate the LAM.
Args:
X (torch.Tensor): Input tensor of shape (channels, height, width).
Expand All @@ -283,11 +283,28 @@ def lam(
h (int): The top coordinate of the window within the image.
w (int): The left coordinate of the window within the image.
window (int, optional): The size of the square window. Defaults to 16.
fold (int, optional): Number of interpolation steps for the blurring path. Defaults to 10.
fold (int, optional): Number of interpolation steps for the blurring path.
Defaults to 10.
kernel_size (int, optional): Size of the Gaussian kernel. Defaults to 5.
sigma (float, optional): Initial standard deviation for the Gaussian blur.
Defaults to 3.5.
robustness_metric (bool, optional): Whether to return the robustness metric.
Defaults to True.
Returns:
torch.Tensor: The Local Attribution Map (LAM) for the input tensor.
tuple: A tuple containing the following elements:
- kde_map (np.ndarray): KDE estimation of the LAM.
- complexity_metric (float): Gini index of the LAM that
measures the consistency of the attribution. The
larger the value, the more use more complex attribution
patterns to solve the task.
- robustness_metric (np.ndarray): Blurriness sensitivity of the LAM.
The sensitivity measures the average gradient magnitude of the
interpolated images.
- robustness_vector (np.ndarray): Vector of gradient magnitudes for
each interpolated image.
"""

# Get the scale of the results
with torch.no_grad():
output = model(X[None])
Expand All @@ -303,10 +320,10 @@ def lam(
attr_objective = attribution_objective(attr_grad, h, w, window=window)

# Compute the path gradient for the input tensor
grad_accumulate_list, _, _ = Path_gradient(
grad_accumulate_list,results_numpy, image_interpolation = Path_gradient(
X, model, attr_objective, path_interpolation_func
)

# Sum the accumulated gradients across all bands
lam_results = torch.sum(torch.from_numpy(np.abs(grad_accumulate_list)), dim=0)
grad_2d = np.abs(lam_results.sum(axis=0))
Expand All @@ -318,5 +335,11 @@ def lam(

# KDE estimation
kde_map = vis_saliency_kde(grad_norm, scale=scale, bandwidth=1.0)
complexity_metric = (1 - gini_index) * 100

# Estimate blurriness sensitivity
robustness_vector = np.abs(grad_accumulate_list).mean(axis=(1, 2, 3))
robustness_metric = np.trapz(robustness_vector)

return kde_map, (1 - gini_index) * 100
# Return the LAM results
return kde_map, complexity_metric, robustness_metric, robustness_vector

0 comments on commit ef6733e

Please sign in to comment.