Skip to content

Commit

Permalink
Fix global group pooling
Browse files Browse the repository at this point in the history
  • Loading branch information
APJansen committed Mar 19, 2024
1 parent c020cf8 commit 9f4bd7e
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions groco/layers/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
MaxPooling3D,
)

from groco import utils
from groco.layers.group_transforms import GroupTransforms


Expand Down Expand Up @@ -92,8 +93,10 @@ def __init__(self, dimensions: int, pool_type: str, **kwargs):

if "data_format" in kwargs and kwargs["data_format"] == "channels_first":
self.group_axis = self.dimensions + 2
self.channels_axis = 1
else:
self.group_axis = self.dimensions + 1
self.channels_axis = -1

def call(self, inputs):
inputs = self.pool_group(inputs)
Expand All @@ -113,9 +116,13 @@ def restore_group_axis(self, outputs):
return outputs

def build(self, input_shape):
reshaped_input = self.group_transforms.build(input_shape)
if len(input_shape) == self.dimensions + 3:
reshaped_input = utils.merge_shapes(
input_shape,
merged_axis=self.group_axis,
target_axis=self.channels_axis,
)
self.pooling.build(reshaped_input)
self.group_transforms.build_pool()

def get_config(self):
config = self.pooling.get_config()
Expand Down

0 comments on commit 9f4bd7e

Please sign in to comment.