From ece779c5abdce4bfcdb6c4c87689bd2452421999 Mon Sep 17 00:00:00 2001 From: Yusha Arif Date: Thu, 10 Oct 2024 06:18:39 +0000 Subject: [PATCH] fix (torch frontends)(non_linear_activation_functions.py): fixing the implementation of `torch.nn.functional.prelu` to handle broadcasting when the input is a higher dimensional matrix. --- .../non_linear_activation_functions.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/ivy/functional/frontends/torch/nn/functional/non_linear_activation_functions.py b/ivy/functional/frontends/torch/nn/functional/non_linear_activation_functions.py index 55a5f5d35371f..6a3ddf8d6db5f 100644 --- a/ivy/functional/frontends/torch/nn/functional/non_linear_activation_functions.py +++ b/ivy/functional/frontends/torch/nn/functional/non_linear_activation_functions.py @@ -179,6 +179,22 @@ def normalize(input, p=2.0, dim=1, eps=1e-12, out=None): @to_ivy_arrays_and_back def prelu(input, weight): + input_dim = input.ndim + weight_dim = weight.ndim + + if weight_dim == 0: + pass + elif weight_dim == 1: + if input_dim >= 2: + assert weight.shape[0] == input.shape[1], "Weight size must match input channels" + + # Unsqueeze weight to match input shape + weight = weight.expand_dims(axis=0).expand_dims(axis=2).expand_dims(axis=3) + # Add more unsqueeze operations if input has more dimensions + for i in range(4, input_dim): + weight = weight.expand_dims(axis=-1) + else: + raise ValueError("Weight must be a scalar or 1-D tensor") return ivy.add(ivy.maximum(0, input), ivy.multiply(weight, ivy.minimum(0, input)))