Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support APISR #8

Merged
merged 3 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 24 additions & 24 deletions ccrestoration/config/edsr_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,30 +38,30 @@ class EDSRConfig(BaseConfig):
scale=4,
),
# Official Large size models
EDSRConfig(
name=ConfigType.EDSR_Lx2_f256b32_DIV2K_official_2x,
url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/EDSR_Lx2_f256b32_DIV2K_official_2x.pth",
hash="be38e77dcff9ec95225cea6326b5a616d57869824688674da317df37f3d87d1b",
scale=2,
num_feat=256,
num_block=32,
),
EDSRConfig(
name=ConfigType.EDSR_Lx3_f256b32_DIV2K_official_3x,
url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/EDSR_Lx3_f256b32_DIV2K_official_3x.pth",
hash="3660f70d306481ef4867731500c4e3d901a1b8547996cf4245a09ffbc151b70b",
scale=3,
num_feat=256,
num_block=32,
),
EDSRConfig(
name=ConfigType.EDSR_Lx4_f256b32_DIV2K_official_4x,
url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/EDSR_Lx4_f256b32_DIV2K_official_4x.pth",
hash="76ee1c8f48813f46024bee8d2700f417f6b2db070e899954ff1552fbae343e93",
scale=4,
num_feat=256,
num_block=32,
),
# EDSRConfig(
# name=ConfigType.EDSR_Lx2_f256b32_DIV2K_official_2x,
# url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/EDSR_Lx2_f256b32_DIV2K_official_2x.pth",
# hash="be38e77dcff9ec95225cea6326b5a616d57869824688674da317df37f3d87d1b",
# scale=2,
# num_feat=256,
# num_block=32,
# ),
# EDSRConfig(
# name=ConfigType.EDSR_Lx3_f256b32_DIV2K_official_3x,
# url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/EDSR_Lx3_f256b32_DIV2K_official_3x.pth",
# hash="3660f70d306481ef4867731500c4e3d901a1b8547996cf4245a09ffbc151b70b",
# scale=3,
# num_feat=256,
# num_block=32,
# ),
# EDSRConfig(
# name=ConfigType.EDSR_Lx4_f256b32_DIV2K_official_4x,
# url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/EDSR_Lx4_f256b32_DIV2K_official_4x.pth",
# hash="76ee1c8f48813f46024bee8d2700f417f6b2db070e899954ff1552fbae343e93",
# scale=4,
# num_feat=256,
# num_block=32,
# ),
]

for cfg in EDSRConfigs:
Expand Down
16 changes: 16 additions & 0 deletions ccrestoration/config/realesrgan_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,22 @@ def act_type_match(cls, v: str) -> str:
arch=ArchType.SRVGG,
scale=2,
),
RealESRGANConfig(
name=ConfigType.RealESRGAN_APISR_RRDB_GAN_generator_2x,
url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/RealESRGAN_APISR_RRDB_GAN_generator_2x.pth",
hash="3b0d2b3a3c0461ac17d00f4f32240666fb832b738ea5a48449b1acf07fbb07e5",
arch=ArchType.RRDB,
scale=2,
num_block=6,
),
RealESRGANConfig(
name=ConfigType.RealESRGAN_APISR_RRDB_GAN_generator_4x,
url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/RealESRGAN_APISR_RRDB_GAN_generator_4x.pth",
hash="6bd14a66224c90d4754011f378ac828b18e221f2d031026ec99cb5facdf40c19",
arch=ArchType.RRDB,
scale=4,
num_block=6,
),
]

for cfg in RealESRGANConfigs:
Expand Down
3 changes: 3 additions & 0 deletions ccrestoration/model/realesrgan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ def load_model(self) -> Any:
state_dict = state_dict["params_ema"]
elif "params" in state_dict:
state_dict = state_dict["params"]
elif "model_state_dict" in state_dict:
# For APISR's model
state_dict = state_dict["model_state_dict"]

if cfg.arch == ArchType.RRDB:
model = RRDBNet(
Expand Down
4 changes: 2 additions & 2 deletions ccrestoration/model/tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def calculate_pad_img_size(width: int, height: int, tile: Tuple[int, int], tile_
:param tile_pad: The padding size for each tile
:return: The size of the padded image as a tuple (padded_width, padded_height)
"""
pad_w = math.ceil(min(tile[0] + 2 * tile_pad, width) / 8) * 8
pad_h = math.ceil(min(tile[1] + 2 * tile_pad, height) / 8) * 8
pad_w = math.ceil(min(tile[0] + 2 * tile_pad, width) / tile_pad) * tile_pad
pad_h = math.ceil(min(tile[1] + 2 * tile_pad, height) / tile_pad) * tile_pad

return pad_w, pad_h

Expand Down
6 changes: 3 additions & 3 deletions ccrestoration/type/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ class ConfigType(str, Enum):
RealESRGAN_AnimeJaNai_HD_V3_Compact_2x = "RealESRGAN_AnimeJaNai_HD_V3_Compact_2x.pth"
RealESRGAN_AniScale_2_Compact_2x = "RealESRGAN_AniScale_2_Compact_2x.pth"
RealESRGAN_Ani4Kv2_Compact_2x = "RealESRGAN_Ani4Kv2_Compact_2x.pth"
RealESRGAN_APISR_RRDB_GAN_generator_2x = "RealESRGAN_APISR_RRDB_GAN_generator_2x.pth"
RealESRGAN_APISR_RRDB_GAN_generator_4x = "RealESRGAN_APISR_RRDB_GAN_generator_4x.pth"

# RealCUGAN
RealCUGAN_Conservative_2x = "RealCUGAN_Conservative_2x.pth"
Expand All @@ -36,14 +38,12 @@ class ConfigType(str, Enum):
EDSR_Mx2_f64b16_DIV2K_official_2x = "EDSR_Mx2_f64b16_DIV2K_official_2x.pth"
EDSR_Mx3_f64b16_DIV2K_official_3x = "EDSR_Mx3_f64b16_DIV2K_official_3x.pth"
EDSR_Mx4_f64b16_DIV2K_official_4x = "EDSR_Mx4_f64b16_DIV2K_official_4x.pth"
EDSR_Lx2_f256b32_DIV2K_official_2x = "EDSR_Lx2_f256b32_DIV2K_official_2x.pth"
EDSR_Lx3_f256b32_DIV2K_official_3x = "EDSR_Lx3_f256b32_DIV2K_official_3x.pth"
EDSR_Lx4_f256b32_DIV2K_official_4x = "EDSR_Lx4_f256b32_DIV2K_official_4x.pth"

# SwinIR
SwinIR_classicalSR_DF2K_s64w8_SwinIR_M_2x = "SwinIR_classicalSR_DF2K_s64w8_SwinIR_M_2x.pth"
SwinIR_lightweightSR_DIV2K_s64w8_SwinIR_S_2x = "SwinIR_lightweightSR_DIV2K_s64w8_SwinIR_S_2x.pth"
SwinIR_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR_L_GAN_4x = "SwinIR_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR_L_GAN_4x.pth"
SwinIR_realSR_BSRGAN_DFO_s64w8_SwinIR_M_GAN_2x = "SwinIR_realSR_BSRGAN_DFO_s64w8_SwinIR_M_GAN_2x.pth"
SwinIR_realSR_BSRGAN_DFO_s64w8_SwinIR_M_GAN_4x = "SwinIR_realSR_BSRGAN_DFO_s64w8_SwinIR_M_GAN_4x.pth"

SwinIR_Bubble_AnimeScale_SwinIR_Small_v1_2x = "SwinIR_Bubble_AnimeScale_SwinIR_Small_v1_2x.pth"
21 changes: 0 additions & 21 deletions tests/test_edsr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import cv2
import pytest

from ccrestoration import AutoConfig, AutoModel, BaseConfig, ConfigType
from ccrestoration.model import SRBaseModel
Expand All @@ -26,23 +25,3 @@ def test_official_M(self) -> None:

assert calculate_image_similarity(img1, img2)
assert compare_image_size(img1, img2, cfg.scale)

@pytest.mark.skip("Skip because it's too large")
def test_official_L(self) -> None:
img1 = load_image()

for k in [
ConfigType.EDSR_Lx2_f256b32_DIV2K_official_2x,
ConfigType.EDSR_Lx3_f256b32_DIV2K_official_3x,
ConfigType.EDSR_Lx4_f256b32_DIV2K_official_4x,
]:
print(f"Testing {k}")
cfg: BaseConfig = AutoConfig.from_pretrained(k)
model: SRBaseModel = AutoModel.from_config(config=cfg, fp16=False, device=get_device())
print(model.device)

img2 = model.inference_image(img1)
cv2.imwrite(str(ASSETS_PATH / f"test_{k}_out.jpg"), img2)

assert calculate_image_similarity(img1, img2)
assert compare_image_size(img1, img2, cfg.scale)
4 changes: 3 additions & 1 deletion tests/test_realesrgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def test_custom(self) -> None:
ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x,
ConfigType.RealESRGAN_AniScale_2_Compact_2x,
ConfigType.RealESRGAN_Ani4Kv2_Compact_2x,
ConfigType.RealESRGAN_APISR_RRDB_GAN_generator_2x,
ConfigType.RealESRGAN_APISR_RRDB_GAN_generator_4x,
]:
print(f"Testing {k}")
cfg: BaseConfig = AutoConfig.from_pretrained(k)
Expand All @@ -43,5 +45,5 @@ def test_custom(self) -> None:
img2 = model.inference_image(img1)
cv2.imwrite(str(ASSETS_PATH / f"test_{k}_out.jpg"), img2)

assert calculate_image_similarity(img1, img2)
assert calculate_image_similarity(img1, img2, 0.8)
assert compare_image_size(img1, img2, cfg.scale)
5 changes: 3 additions & 2 deletions tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@ def load_image() -> np.ndarray:
return img


def calculate_image_similarity(image1: np.ndarray, image2: np.ndarray) -> bool:
def calculate_image_similarity(image1: np.ndarray, image2: np.ndarray, similarity: float = 0.85) -> bool:
"""
calculate image similarity, check SR is correct

:param image1: original image
:param image2: upscale image
:param similarity: similarity threshold
:return:
"""
# Resize the two images to the same size
Expand All @@ -41,7 +42,7 @@ def calculate_image_similarity(image1: np.ndarray, image2: np.ndarray) -> bool:
# Calculate the Structural Similarity Index (SSIM) between the two images
(score, diff) = structural_similarity(grayscale_image1, grayscale_image2, full=True)
print("SSIM: {}".format(score))
return score > 0.85
return score > similarity


def compare_image_size(image1: np.ndarray, image2: np.ndarray, scale: int) -> bool:
Expand Down
Loading