
Commit

deploy: 2362b47
csaybar committed Oct 24, 2024
1 parent a75a168 commit 299768a
Showing 15 changed files with 5,432 additions and 129 deletions.
Binary file modified sitemap.xml.gz
Binary file not shown.
226 changes: 111 additions & 115 deletions supers2/main.py
@@ -9,57 +9,55 @@

def setmodel(
    resolution: Literal["2.5m", "5m", "10m"] = "2.5m",
    SR_model_name: Literal["cnn", "swin", "mamba", "diffusion"] = "cnn",
    SR_model_size: Literal["lightweight", "small", "medium", "expanded", "large"] = "small",
    SR_model_loss: Literal["l1", "superloss", "adversarial"] = "l1",
    Fusionx2_model_name: Literal["cnn", "swin", "mamba"] = "cnn",
    Fusionx2_model_size: Literal["lightweight", "small", "medium", "expanded", "large"] = "lightweight",
    Fusionx4_model_name: Literal["cnn", "swin", "mamba"] = "cnn",
    Fusionx4_model_size: Literal["lightweight", "small", "medium", "expanded", "large"] = "lightweight",
    weights_path: Union[str, pathlib.Path, None] = None,
    device: str = "cpu",
    **kwargs
) -> dict:
"""
Sets up models for super-resolution and fusion tasks based on the specified parameters.
Args:
resolution (Literal["2.5m", "5m", "10m"], optional):
Target spatial resolution. Determines which models to load.
resolution (Literal["2.5m", "5m", "10m"], optional):
Target spatial resolution. Determines which models to load.
Defaults to "2.5m".
SR_model_name (Literal["cnn", "swin", "mamba"], optional):
The super-resolution model to use.
SR_model_name (Literal["cnn", "swin", "mamba"], optional):
The super-resolution model to use.
Options: "cnn", "swin", "mamba". Defaults to "cnn".
SR_model_size (Literal["lightweight", "small", "medium", "expanded", "large"], optional):
Size of the super-resolution model.
Options: "lightweight", "small", "medium", "expanded", "large".
SR_model_size (Literal["lightweight", "small", "medium", "expanded", "large"], optional):
Size of the super-resolution model.
Options: "lightweight", "small", "medium", "expanded", "large".
Defaults to "small".
SR_model_loss (Literal["l1", "superloss", "adversarial"], optional):
Loss function used in training the super-resolution model.
SR_model_loss (Literal["l1", "superloss", "adversarial"], optional):
Loss function used in training the super-resolution model.
Options: "l1", "superloss", "adversarial". Defaults to "l1".
Fusionx2_model_name (Literal["cnn", "swin", "mamba"], optional):
Fusionx2_model_name (Literal["cnn", "swin", "mamba"], optional):
Model for Fusion X2 (e.g., 20m -> 10m resolution).
Options: "cnn", "swin", "mamba". Defaults to "cnn".
Fusionx2_model_size (Literal["lightweight", "small", "medium", "expanded", "large"], optional):
Size of the Fusion X2 model.
Options: "lightweight", "small", "medium", "expanded", "large".
Fusionx2_model_size (Literal["lightweight", "small", "medium", "expanded", "large"], optional):
Size of the Fusion X2 model.
Options: "lightweight", "small", "medium", "expanded", "large".
Defaults to "lightweight".
Fusionx4_model_name (Literal["cnn", "swin", "mamba"], optional):
Fusionx4_model_name (Literal["cnn", "swin", "mamba"], optional):
Model for Fusion X4 (e.g., 10m -> 2.5m resolution).
Options: "cnn", "swin", "mamba". Defaults to "cnn".
Fusionx4_model_size (Literal["lightweight", "small", "medium", "expanded", "large"], optional):
Size of the Fusion X4 model.
Options: "lightweight", "small", "medium", "expanded", "large".
Fusionx4_model_size (Literal["lightweight", "small", "medium", "expanded", "large"], optional):
Size of the Fusion X4 model.
Options: "lightweight", "small", "medium", "expanded", "large".
Defaults to "lightweight".
weights_path (Union[str, pathlib.Path, None], optional):
Path to the pre-trained model weights.
weights_path (Union[str, pathlib.Path, None], optional):
Path to the pre-trained model weights.
Can be a string or pathlib.Path object. Defaults to None.
If None, the code will try to retrieve the weights from the
official repository.
device (str, optional): Device to use for the models. Defaults to "cpu".
**kwargs: Additional keyword arguments to pass to the models.
Returns:
dict: A dictionary containing the loaded models for super-resolution and fusion tasks.
@@ -72,15 +70,15 @@ def setmodel(
    weights_path = pathlib.Path.home() / ".config" / "supers2"
    weights_path.mkdir(parents=True, exist_ok=True)

    # If the resolution is 10m we only run the FusionX2 model that
    # converts 20m bands to 10m
    if resolution == 10:
        return {
            "FusionX2": load_fusionx2_model(
                model_name=Fusionx2_model_name,
                model_size=Fusionx2_model_size,
                model_loss="l1",
                weights_path=weights_path
            ),
            "FusionX4": None,
            "SR": None,
@@ -92,33 +90,35 @@ def setmodel(
            model_name=Fusionx2_model_name,
            model_size=Fusionx2_model_size,
            model_loss="l1",
            weights_path=weights_path
        ),
        "FusionX4": load_fusionx4_model(
            model_name=Fusionx4_model_name,
            model_size=Fusionx4_model_size,
            model_loss="l1",
            weights_path=weights_path
        ),
        "SR": load_srx4_model(
            model_name=SR_model_name,
            model_size=SR_model_size,
            model_loss=SR_model_loss,
            weights_path=weights_path,
            device=device,
            **kwargs
        )
    }
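For readers unfamiliar with the API, here is a minimal usage sketch of the function above. It is not part of this commit; it assumes supers2 is installed, that setmodel and predict are importable from supers2.main, and that pre-trained weights can be fetched from the official repository or found under ~/.config/supers2.

# Hypothetical usage sketch (editorial illustration, not code from this diff).
import torch
from supers2.main import setmodel, predict

# Load the default CNN models for 2.5 m output on CPU.
models = setmodel(resolution="2.5m", SR_model_name="cnn", device="cpu")

# X is assumed to be a 10-band Sentinel-2 chip on the 10 m grid, shape (C, H, W).
X = torch.rand(10, 128, 128)
superX = predict(X, resolution="2.5m", models=models)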


def predict(
    X: torch.Tensor,
    resolution: Literal["2.5m", "5m", "10m"] = "2.5m",
    models: Optional[dict] = None
) -> torch.Tensor:
    """Generate a new S2 tensor with all the bands on the same resolution

    Args:
        X (torch.Tensor): The input tensor with the S2 bands
        resolution (Literal["2.5m", "5m", "10m"], optional): The final resolution of the
            tensor. Defaults to "2.5m".
        models (Optional[dict], optional): The dictionary of preloaded models,
            as returned by setmodel. Defaults to None.
@@ -141,72 +141,68 @@ def predict(
        raise ValueError("Invalid resolution. Please select 2.5m, 5m, or 10m.")
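The error branch above implies that predict dispatches on the resolution string. The mapping below is an inference from the function names and docstrings further down (the actual dispatch code is collapsed in this diff), not something shown in this hunk:

# Inferred dispatch (assumption):
#   "10m"  -> fusionx2(X, models)   # fuse 20 m bands onto the native 10 m grid
#   "5m"   -> fusionx4(X, models)   # 2.5 m result downsampled by 2
#   "2.5m" -> fusionx8(X, models)   # full super-resolution to 2.5 m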


def fusionx2(X: torch.Tensor, models: dict) -> torch.Tensor:
    """Converts 20m bands to 10m resolution

    Args:
        X (torch.Tensor): The input tensor with the S2 bands
        models (dict): The dictionary with the loaded models

    Returns:
        torch.Tensor: The tensor with the same resolution for all the bands
    """

    # Obtain the device of X
    device = X.device

    # Band Selection
    bands_20m = [3, 4, 5, 7, 8, 9]
    bands_10m = [0, 1, 2, 6]

    # Set the model
    fusionmodelx2 = models["FusionX2"].to(device)

    # Select the 20m bands
    bands_20m_data = X[bands_20m]

    bands_20m_data_real = torch.nn.functional.interpolate(
        bands_20m_data[None],
        scale_factor=0.5,
        mode="nearest"
    ).squeeze(0)

    bands_20m_data = torch.nn.functional.interpolate(
        bands_20m_data_real[None],
        scale_factor=2,
        mode="bilinear",
        antialias=True
    ).squeeze(0)

    # Select the 10m bands
    bands_10m_data = X[bands_10m]

    # Concatenate the 20m and 10m bands
    input_data = torch.cat([bands_20m_data, bands_10m_data], dim=0)
    bands_20m_data_to_10 = fusionmodelx2(input_data[None]).squeeze(0)

    # Order the channels back
    results = torch.stack([
        bands_10m_data[0],
        bands_10m_data[1],
        bands_10m_data[2],
        bands_20m_data_to_10[0],
        bands_20m_data_to_10[1],
        bands_20m_data_to_10[2],
        bands_10m_data[3],
        bands_20m_data_to_10[3],
        bands_20m_data_to_10[4],
        bands_20m_data_to_10[5],
    ], dim=0)

    return results
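The index lists in fusionx2 assume a fixed channel layout for X. The sketch below spells out that assumption; the band names are inferred from the 10 m/20 m split and the RGBNIR comment in fusionx8, and are not stated anywhere in this commit:

# Assumed channel order of the 10-band input (all bands resampled to the 10 m grid):
#   0: B02  1: B03  2: B04  3: B05  4: B06  5: B07
#   6: B08  7: B8A  8: B11  9: B12
bands_10m = [0, 1, 2, 6]        # native 10 m bands: B02, B03, B04, B08
bands_20m = [3, 4, 5, 7, 8, 9]  # native 20 m bands: B05, B06, B07, B8A, B11, B12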


def fusionx8(X: torch.Tensor, models: dict) -> torch.Tensor:
    """Converts all the bands to 2.5m resolution

    Args:
@@ -224,54 +220,51 @@ def fusionx8(
    superX: torch.Tensor = fusionx2(X, models)

    # Band Selection
    bands_20m = [3, 4, 5, 7, 8, 9]
    bands_10m = [2, 1, 0, 6]  # WARNING: The SR model needs RGBNIR bands

    # Set the SR resolution and x4 fusion model
    fusionmodelx4 = models["FusionX4"].to(device)
    srmodelx4 = models["SR"].to(device)

    # Convert the SWIR bands to 2.5m
    bands_20m_data = superX[bands_20m]
    bands_20m_data_up = torch.nn.functional.interpolate(
        bands_20m_data[None],
        scale_factor=4,
        mode="bilinear",
        antialias=True
    ).squeeze(0)

    # Run super-resolution on the 10m bands
    rgbn_bands_10m_data = superX[bands_10m]
    tensor_x4_rgbnir = srmodelx4(rgbn_bands_10m_data[None]).squeeze(0)

    # Reorder the bands from RGBNIR to BGRNIR
    tensor_x4_rgbnir = tensor_x4_rgbnir[[2, 1, 0, 3]]

    # Run the fusion x4 model in the SWIR bands (10m to 2.5m)
    input_data = torch.cat([bands_20m_data_up, tensor_x4_rgbnir], dim=0)
    bands_20m_data_to_25m = fusionmodelx4(input_data[None]).squeeze(0)

    # Order the channels back
    results = torch.stack([
        tensor_x4_rgbnir[0],
        tensor_x4_rgbnir[1],
        tensor_x4_rgbnir[2],
        bands_20m_data_to_25m[0],
        bands_20m_data_to_25m[1],
        bands_20m_data_to_25m[2],
        tensor_x4_rgbnir[3],
        bands_20m_data_to_25m[3],
        bands_20m_data_to_25m[4],
        bands_20m_data_to_25m[5],
    ], dim=0)

    return results
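As a sanity check on the resampling arithmetic used above and in fusionx4, this editorial illustration traces the scale factors with plain interpolation only (no learned models involved):

import torch
import torch.nn.functional as F

x_10m = torch.rand(1, 10, 128, 128)                 # 10 m grid
x_2_5m = F.interpolate(x_10m, scale_factor=4,
                       mode="bilinear", antialias=True)   # (1, 10, 512, 512), 2.5 m grid
x_5m = F.interpolate(x_2_5m, scale_factor=0.5,
                     mode="bilinear", antialias=True)      # (1, 10, 256, 256), 5 m grid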


def fusionx4(X: torch.Tensor, models: dict) -> torch.Tensor:
    """Converts all the bands to 5m resolution

    Args:
@@ -287,5 +280,8 @@ def fusionx4(

    # From 2.5m to 5m resolution
    return torch.nn.functional.interpolate(
        superX[None],
        scale_factor=0.5,
        mode="bilinear",
        antialias=True
    ).squeeze(0)
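Putting the three entry points together, the expected spatial sizes for a (10, 128, 128) input would be as follows. This is an assumption derived from the scale factors in this file, not something asserted by the commit:

#   predict(X, "10m",  models) -> (10, 128, 128)   # same grid, 20 m bands fused
#   predict(X, "5m",   models) -> (10, 256, 256)
#   predict(X, "2.5m", models) -> (10, 512, 512)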