Skip to content

Commit

Permalink
fix (torch frontends)(non_linear_activation_functions.py): fixing the…
Browse files Browse the repository at this point in the history
… implementation of `torch.nn.functional.prelu` to handle broadcasting when the input is a higher dimensional matrix.
  • Loading branch information
YushaArif99 committed Oct 10, 2024
1 parent 11a1c1a commit ece779c
Showing 1 changed file with 16 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,22 @@ def normalize(input, p=2.0, dim=1, eps=1e-12, out=None):

@to_ivy_arrays_and_back
def prelu(input, weight):
input_dim = input.ndim
weight_dim = weight.ndim

if weight_dim == 0:
pass
elif weight_dim == 1:
if input_dim >= 2:
assert weight.shape[0] == input.shape[1], "Weight size must match input channels"

# Unsqueeze weight to match input shape
weight = weight.expand_dims(axis=0).expand_dims(axis=2).expand_dims(axis=3)
# Add more unsqueeze operations if input has more dimensions
for i in range(4, input_dim):
weight = weight.expand_dims(axis=-1)
else:
raise ValueError("Weight must be a scalar or 1-D tensor")
return ivy.add(ivy.maximum(0, input), ivy.multiply(weight, ivy.minimum(0, input)))


Expand Down

0 comments on commit ece779c

Please sign in to comment.