From 9f4bd7ec8fb9f89a1eb412741e8b898ab264de7c Mon Sep 17 00:00:00 2001 From: Aron Date: Tue, 19 Mar 2024 18:10:25 +0100 Subject: [PATCH] Fix global group pooling --- groco/layers/pooling.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/groco/layers/pooling.py b/groco/layers/pooling.py index 2df7ed4..94bdcc4 100644 --- a/groco/layers/pooling.py +++ b/groco/layers/pooling.py @@ -15,6 +15,7 @@ MaxPooling3D, ) +from groco import utils from groco.layers.group_transforms import GroupTransforms @@ -92,8 +93,10 @@ def __init__(self, dimensions: int, pool_type: str, **kwargs): if "data_format" in kwargs and kwargs["data_format"] == "channels_first": self.group_axis = self.dimensions + 2 + self.channels_axis = 1 else: self.group_axis = self.dimensions + 1 + self.channels_axis = -1 def call(self, inputs): inputs = self.pool_group(inputs) @@ -113,9 +116,13 @@ def restore_group_axis(self, outputs): return outputs def build(self, input_shape): - reshaped_input = self.group_transforms.build(input_shape) + if len(input_shape) == self.dimensions + 3: + reshaped_input = utils.merge_shapes( + input_shape, + merged_axis=self.group_axis, + target_axis=self.channels_axis, + ) self.pooling.build(reshaped_input) - self.group_transforms.build_pool() def get_config(self): config = self.pooling.get_config()