diff --git a/ccrestoration/config/edsr_config.py b/ccrestoration/config/edsr_config.py index 0c8b4ab..33cc0a4 100644 --- a/ccrestoration/config/edsr_config.py +++ b/ccrestoration/config/edsr_config.py @@ -38,30 +38,30 @@ class EDSRConfig(BaseConfig): scale=4, ), # Official Large size models - EDSRConfig( - name=ConfigType.EDSR_Lx2_f256b32_DIV2K_official_2x, - url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/EDSR_Lx2_f256b32_DIV2K_official_2x.pth", - hash="be38e77dcff9ec95225cea6326b5a616d57869824688674da317df37f3d87d1b", - scale=2, - num_feat=256, - num_block=32, - ), - EDSRConfig( - name=ConfigType.EDSR_Lx3_f256b32_DIV2K_official_3x, - url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/EDSR_Lx3_f256b32_DIV2K_official_3x.pth", - hash="3660f70d306481ef4867731500c4e3d901a1b8547996cf4245a09ffbc151b70b", - scale=3, - num_feat=256, - num_block=32, - ), - EDSRConfig( - name=ConfigType.EDSR_Lx4_f256b32_DIV2K_official_4x, - url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/EDSR_Lx4_f256b32_DIV2K_official_4x.pth", - hash="76ee1c8f48813f46024bee8d2700f417f6b2db070e899954ff1552fbae343e93", - scale=4, - num_feat=256, - num_block=32, - ), + # EDSRConfig( + # name=ConfigType.EDSR_Lx2_f256b32_DIV2K_official_2x, + # url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/EDSR_Lx2_f256b32_DIV2K_official_2x.pth", + # hash="be38e77dcff9ec95225cea6326b5a616d57869824688674da317df37f3d87d1b", + # scale=2, + # num_feat=256, + # num_block=32, + # ), + # EDSRConfig( + # name=ConfigType.EDSR_Lx3_f256b32_DIV2K_official_3x, + # url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/EDSR_Lx3_f256b32_DIV2K_official_3x.pth", + # hash="3660f70d306481ef4867731500c4e3d901a1b8547996cf4245a09ffbc151b70b", + # scale=3, + # num_feat=256, + # num_block=32, + # ), + # EDSRConfig( + # name=ConfigType.EDSR_Lx4_f256b32_DIV2K_official_4x, + # url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/EDSR_Lx4_f256b32_DIV2K_official_4x.pth", + # hash="76ee1c8f48813f46024bee8d2700f417f6b2db070e899954ff1552fbae343e93", + # scale=4, + # num_feat=256, + # num_block=32, + # ), ] for cfg in EDSRConfigs: diff --git a/ccrestoration/config/realesrgan_config.py b/ccrestoration/config/realesrgan_config.py index b93177a..ef5d3a6 100644 --- a/ccrestoration/config/realesrgan_config.py +++ b/ccrestoration/config/realesrgan_config.py @@ -84,6 +84,22 @@ def act_type_match(cls, v: str) -> str: arch=ArchType.SRVGG, scale=2, ), + RealESRGANConfig( + name=ConfigType.RealESRGAN_APISR_RRDB_GAN_generator_2x, + url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/RealESRGAN_APISR_RRDB_GAN_generator_2x.pth", + hash="3b0d2b3a3c0461ac17d00f4f32240666fb832b738ea5a48449b1acf07fbb07e5", + arch=ArchType.RRDB, + scale=2, + num_block=6, + ), + RealESRGANConfig( + name=ConfigType.RealESRGAN_APISR_RRDB_GAN_generator_4x, + url="https://github.com/TensoRaws/ccrestoration/releases/download/model_zoo/RealESRGAN_APISR_RRDB_GAN_generator_4x.pth", + hash="6bd14a66224c90d4754011f378ac828b18e221f2d031026ec99cb5facdf40c19", + arch=ArchType.RRDB, + scale=4, + num_block=6, + ), ] for cfg in RealESRGANConfigs: diff --git a/ccrestoration/model/realesrgan_model.py b/ccrestoration/model/realesrgan_model.py index f164a2d..6dbff62 100644 --- a/ccrestoration/model/realesrgan_model.py +++ b/ccrestoration/model/realesrgan_model.py @@ -17,6 +17,9 @@ def load_model(self) -> Any: state_dict = state_dict["params_ema"] elif "params" in state_dict: state_dict = state_dict["params"] + elif "model_state_dict" in state_dict: + # For APISR's model + state_dict = state_dict["model_state_dict"] if cfg.arch == ArchType.RRDB: model = RRDBNet( diff --git a/ccrestoration/model/tile.py b/ccrestoration/model/tile.py index abc18b4..d686172 100644 --- a/ccrestoration/model/tile.py +++ b/ccrestoration/model/tile.py @@ -15,8 +15,8 @@ def calculate_pad_img_size(width: int, height: int, tile: Tuple[int, int], tile_ :param tile_pad: The padding size for each tile :return: The size of the padded image as a tuple (padded_width, padded_height) """ - pad_w = math.ceil(min(tile[0] + 2 * tile_pad, width) / 8) * 8 - pad_h = math.ceil(min(tile[1] + 2 * tile_pad, height) / 8) * 8 + pad_w = math.ceil(min(tile[0] + 2 * tile_pad, width) / tile_pad) * tile_pad + pad_h = math.ceil(min(tile[1] + 2 * tile_pad, height) / tile_pad) * tile_pad return pad_w, pad_h diff --git a/ccrestoration/type/config.py b/ccrestoration/type/config.py index 401817b..c438f1a 100644 --- a/ccrestoration/type/config.py +++ b/ccrestoration/type/config.py @@ -12,6 +12,8 @@ class ConfigType(str, Enum): RealESRGAN_AnimeJaNai_HD_V3_Compact_2x = "RealESRGAN_AnimeJaNai_HD_V3_Compact_2x.pth" RealESRGAN_AniScale_2_Compact_2x = "RealESRGAN_AniScale_2_Compact_2x.pth" RealESRGAN_Ani4Kv2_Compact_2x = "RealESRGAN_Ani4Kv2_Compact_2x.pth" + RealESRGAN_APISR_RRDB_GAN_generator_2x = "RealESRGAN_APISR_RRDB_GAN_generator_2x.pth" + RealESRGAN_APISR_RRDB_GAN_generator_4x = "RealESRGAN_APISR_RRDB_GAN_generator_4x.pth" # RealCUGAN RealCUGAN_Conservative_2x = "RealCUGAN_Conservative_2x.pth" @@ -36,9 +38,6 @@ class ConfigType(str, Enum): EDSR_Mx2_f64b16_DIV2K_official_2x = "EDSR_Mx2_f64b16_DIV2K_official_2x.pth" EDSR_Mx3_f64b16_DIV2K_official_3x = "EDSR_Mx3_f64b16_DIV2K_official_3x.pth" EDSR_Mx4_f64b16_DIV2K_official_4x = "EDSR_Mx4_f64b16_DIV2K_official_4x.pth" - EDSR_Lx2_f256b32_DIV2K_official_2x = "EDSR_Lx2_f256b32_DIV2K_official_2x.pth" - EDSR_Lx3_f256b32_DIV2K_official_3x = "EDSR_Lx3_f256b32_DIV2K_official_3x.pth" - EDSR_Lx4_f256b32_DIV2K_official_4x = "EDSR_Lx4_f256b32_DIV2K_official_4x.pth" # SwinIR SwinIR_classicalSR_DF2K_s64w8_SwinIR_M_2x = "SwinIR_classicalSR_DF2K_s64w8_SwinIR_M_2x.pth" @@ -46,4 +45,5 @@ class ConfigType(str, Enum): SwinIR_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR_L_GAN_4x = "SwinIR_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR_L_GAN_4x.pth" SwinIR_realSR_BSRGAN_DFO_s64w8_SwinIR_M_GAN_2x = "SwinIR_realSR_BSRGAN_DFO_s64w8_SwinIR_M_GAN_2x.pth" SwinIR_realSR_BSRGAN_DFO_s64w8_SwinIR_M_GAN_4x = "SwinIR_realSR_BSRGAN_DFO_s64w8_SwinIR_M_GAN_4x.pth" + SwinIR_Bubble_AnimeScale_SwinIR_Small_v1_2x = "SwinIR_Bubble_AnimeScale_SwinIR_Small_v1_2x.pth" diff --git a/tests/test_edsr.py b/tests/test_edsr.py index 8b76f19..d8c8fc9 100644 --- a/tests/test_edsr.py +++ b/tests/test_edsr.py @@ -1,5 +1,4 @@ import cv2 -import pytest from ccrestoration import AutoConfig, AutoModel, BaseConfig, ConfigType from ccrestoration.model import SRBaseModel @@ -26,23 +25,3 @@ def test_official_M(self) -> None: assert calculate_image_similarity(img1, img2) assert compare_image_size(img1, img2, cfg.scale) - - @pytest.mark.skip("Skip because it's too large") - def test_official_L(self) -> None: - img1 = load_image() - - for k in [ - ConfigType.EDSR_Lx2_f256b32_DIV2K_official_2x, - ConfigType.EDSR_Lx3_f256b32_DIV2K_official_3x, - ConfigType.EDSR_Lx4_f256b32_DIV2K_official_4x, - ]: - print(f"Testing {k}") - cfg: BaseConfig = AutoConfig.from_pretrained(k) - model: SRBaseModel = AutoModel.from_config(config=cfg, fp16=False, device=get_device()) - print(model.device) - - img2 = model.inference_image(img1) - cv2.imwrite(str(ASSETS_PATH / f"test_{k}_out.jpg"), img2) - - assert calculate_image_similarity(img1, img2) - assert compare_image_size(img1, img2, cfg.scale) diff --git a/tests/test_realesrgan.py b/tests/test_realesrgan.py index 6789d4a..4d58d6c 100644 --- a/tests/test_realesrgan.py +++ b/tests/test_realesrgan.py @@ -34,6 +34,8 @@ def test_custom(self) -> None: ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x, ConfigType.RealESRGAN_AniScale_2_Compact_2x, ConfigType.RealESRGAN_Ani4Kv2_Compact_2x, + ConfigType.RealESRGAN_APISR_RRDB_GAN_generator_2x, + ConfigType.RealESRGAN_APISR_RRDB_GAN_generator_4x, ]: print(f"Testing {k}") cfg: BaseConfig = AutoConfig.from_pretrained(k) @@ -43,5 +45,5 @@ def test_custom(self) -> None: img2 = model.inference_image(img1) cv2.imwrite(str(ASSETS_PATH / f"test_{k}_out.jpg"), img2) - assert calculate_image_similarity(img1, img2) + assert calculate_image_similarity(img1, img2, 0.8) assert compare_image_size(img1, img2, cfg.scale) diff --git a/tests/util.py b/tests/util.py index 8a7faff..ee8cd02 100644 --- a/tests/util.py +++ b/tests/util.py @@ -24,12 +24,13 @@ def load_image() -> np.ndarray: return img -def calculate_image_similarity(image1: np.ndarray, image2: np.ndarray) -> bool: +def calculate_image_similarity(image1: np.ndarray, image2: np.ndarray, similarity: float = 0.85) -> bool: """ calculate image similarity, check SR is correct :param image1: original image :param image2: upscale image + :param similarity: similarity threshold :return: """ # Resize the two images to the same size @@ -41,7 +42,7 @@ def calculate_image_similarity(image1: np.ndarray, image2: np.ndarray) -> bool: # Calculate the Structural Similarity Index (SSIM) between the two images (score, diff) = structural_similarity(grayscale_image1, grayscale_image2, full=True) print("SSIM: {}".format(score)) - return score > 0.85 + return score > similarity def compare_image_size(image1: np.ndarray, image2: np.ndarray, scale: int) -> bool: