Upsampling/Unpooling for Transpose Convolutions #894

Open
JadM133 opened this issue Nov 10, 2024 · 10 comments
Labels
feature New feature

Comments

@JadM133

JadM133 commented Nov 10, 2024

Hello all,
Great work; the library is amazing. I am having trouble finding unpooling classes. Did I miss them, or are they not available yet?
Thanks in advance!

@patrick-kidger
Owner

I think you're the first one to ask for them! :)

I'd be happy to take a pull request that adds these.

@patrick-kidger patrick-kidger added the feature New feature label Nov 11, 2024
@TugdualKerjan
Contributor

I can try my hand at it. @JadM133, for now I've been using JAX's image resize:

return jax.image.resize(x, (x.shape[0], target_length), method="linear")
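For reference, here is a minimal sketch of that workaround wrapped in a function. The name `upsample_1d` is hypothetical, and it assumes a channels-first `(channels, length)` input without a batch axis, as Equinox's 1D layers use. As far as I know, `jax.image.resize` samples at half-pixel centers, so its "linear" mode should behave roughly like PyTorch's `align_corners=False`; that is worth double-checking before relying on it.

```python
import jax
import jax.numpy as jnp


def upsample_1d(x: jax.Array, target_length: int, method: str = "linear") -> jax.Array:
    """Resize a (channels, length) array along its length axis only."""
    # The channel axis keeps its size; only the spatial (length) axis is resized.
    return jax.image.resize(x, (x.shape[0], target_length), method=method)


x = jnp.array([[0.0, 1.0, 2.0, 3.0]])
up = upsample_1d(x, 8)                        # linear interpolation, shape (1, 8)
near = upsample_1d(x, 8, method="nearest")    # nearest neighbour, shape (1, 8)
```

With an integer factor, the "nearest" mode amounts to repeating each sample.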

@JadM133
Author

JadM133 commented Dec 11, 2024

@TugdualKerjan thanks! I won't be able to work on it anytime soon, so feel free to take the lead :)

@TugdualKerjan
Contributor

TugdualKerjan commented Dec 12, 2024

It seems like there isn't a straightforward implementation of upsampling in JAX's library. The closest thing I can find is this: jax-ml/jax#862. I'm still learning more about JAX and Equinox, so I guess this would be a way of being thrown in at the deep end!

From what I gather, if we want to open a PR on Equinox for upsampling, is it best to first open a PR on JAX with an upsampling solution and then come back to this? @patrick-kidger, I would love to have your opinion on this.

Edit: Maybe working off this could work for at least the linear and nearest cases? https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.interpolate.RegularGridInterpolator.html#jax.scipy.interpolate.RegularGridInterpolator
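A rough sketch of how `RegularGridInterpolator` could drive 2D upsampling for the linear and nearest cases it supports. The function name `upsample_2d`, the channels-first `(C, H, W)` layout, the integer `scale` factor, and the "align corners" sampling grid are all my assumptions, not an existing Equinox or JAX API. `fill_value=None` is passed so that query points landing a hair outside the grid due to rounding are extrapolated rather than turned into NaN.

```python
import jax
import jax.numpy as jnp
from jax.scipy.interpolate import RegularGridInterpolator


def upsample_2d(x: jax.Array, scale: int, method: str = "linear") -> jax.Array:
    """Upsample a float (C, H, W) array by an integer factor (hypothetical helper).

    Uses "align corners" sampling: the first and last output pixels coincide
    with the first and last input pixels along each axis.
    """
    _, h, w = x.shape
    # Input pixel positions along each spatial axis.
    grid = (jnp.arange(h, dtype=x.dtype), jnp.arange(w, dtype=x.dtype))
    # Output sample positions, expressed in input-pixel coordinates.
    rows = jnp.linspace(0, h - 1, h * scale)
    cols = jnp.linspace(0, w - 1, w * scale)
    rr, cc = jnp.meshgrid(rows, cols, indexing="ij")
    query = jnp.stack([rr, cc], axis=-1)  # shape (H*scale, W*scale, 2)

    def one_channel(channel: jax.Array) -> jax.Array:
        interp = RegularGridInterpolator(grid, channel, method=method, fill_value=None)
        return interp(query)  # shape (H*scale, W*scale)

    # Interpolate each channel independently.
    return jax.vmap(one_channel)(x)


x = jnp.arange(4.0).reshape(1, 2, 2)          # one channel: [[0, 1], [2, 3]]
linear = upsample_2d(x, 2)                    # shape (1, 4, 4)
nearest = upsample_2d(x, 2, method="nearest")
```

Since each interpolator holds one channel's values here, `jax.vmap` maps it over the channel axis.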

@patrick-kidger
Owner

Do you know how this operation is defined (mathematically) in other frameworks, e.g. PyTorch?

@JadM133
Author

JadM133 commented Dec 14, 2024

Let me share with you what I know.

The most popular unpooling technique is max unpooling, which is essentially the transpose of a max pooling layer.

When a max pooling layer is applied, the position of the maximum in each sliding window is recorded as an index. The corresponding unpooling layer takes these indices and writes each pooled value back to its recorded position, setting the rest of the window to zero, as done in torch, for example.
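To make the index bookkeeping concrete, here is a minimal sketch of max pooling that also returns indices, plus the matching unpooling. The helper names are hypothetical, and it deliberately sidesteps `lax.reduce_window`. It assumes a single-channel `(H, W)` input, non-overlapping `k x k` windows (stride equal to window size), no padding, and `H`, `W` divisible by `k`. The indices here are flat positions within each window; PyTorch, as far as I know, stores indices into the flattened input plane instead.

```python
import jax
import jax.numpy as jnp


def max_pool_with_indices(x: jax.Array, k: int):
    """Non-overlapping k x k max pooling of an (H, W) array, returning argmax indices too."""
    h, w = x.shape
    # Gather the pixels of each k x k window into the last axis: (H/k, W/k, k*k).
    windows = (
        x.reshape(h // k, k, w // k, k)
        .transpose(0, 2, 1, 3)
        .reshape(h // k, w // k, k * k)
    )
    idx = jnp.argmax(windows, axis=-1)  # flat position of the max inside each window
    pooled = jnp.take_along_axis(windows, idx[..., None], axis=-1)[..., 0]
    return pooled, idx


def max_unpool(pooled: jax.Array, idx: jax.Array, k: int) -> jax.Array:
    """Write each pooled value back to its recorded position; all other entries are zero."""
    hp, wp = pooled.shape
    # One-hot mask of the max position in each window, scaled by the pooled value.
    windows = jax.nn.one_hot(idx, k * k, dtype=pooled.dtype) * pooled[..., None]
    # Undo the window grouping: (H/k, W/k, k*k) -> (H, W).
    return windows.reshape(hp, wp, k, k).transpose(0, 2, 1, 3).reshape(hp * k, wp * k)


x = jnp.array([[1.0, 2.0, 5.0, 0.0],
               [3.0, 4.0, 1.0, 6.0],
               [0.0, 0.0, 7.0, 8.0],
               [9.0, 1.0, 2.0, 3.0]])
pooled, idx = max_pool_with_indices(x, 2)   # pooled == [[4, 6], [9, 8]]
restored = max_unpool(pooled, idx, 2)       # maxima back in place, zeros elsewhere
```

`jnp.argmax` returns the first maximum, so ties within a window resolve to the earliest position.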

The operation is quite simple, but as @TugdualKerjan mentioned, we are missing an operation from JAX.

The pooling layers in Equinox are implemented using lax.reduce_window, which takes a monoid operation. However, JAX only has dedicated window reducers for the specific monoid cases of max, min, and add, as follows:

This part of the code is from jax._src.lax.

def _get_monoid_window_reducer(
    monoid_op, xs: Sequence[Array]
) -> Callable | None:
  if len(xs) != 1:
    return None
  x, = xs
  aval = core.get_aval(x)
  if (type(aval) is ConcreteArray) and aval.shape == ():
    if monoid_op is lax.add:
      return aval.val == 0 and _reduce_window_sum
    elif monoid_op is lax.max:
      return (aval.val == lax._get_max_identity(aval.dtype)
              and _reduce_window_max)
    elif monoid_op is lax.min:
      return (aval.val == lax._get_min_identity(aval.dtype)
              and _reduce_window_min)
  return None
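To illustrate those branches, here is a small sketch calling `lax.reduce_window` directly with each monoid's identity element as `init_value`: `-inf` for `lax.max` and `0` for `lax.add`. My understanding is that these match the `lax.max` and `lax.add` branches above (and I believe Equinox's max pooling passes `-jnp.inf` with `lax.max`). An argmax-style reduction has no such branch, which is the gap described here.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(16.0).reshape(4, 4)

# Max pooling: init is max's identity (-inf), matching the lax.max branch.
max_pooled = lax.reduce_window(x, -jnp.inf, lax.max, (2, 2), (2, 2), "VALID")

# Sum pooling: init is add's identity (0); dividing by the window size gives average pooling.
avg_pooled = lax.reduce_window(x, 0.0, lax.add, (2, 2), (2, 2), "VALID") / 4.0
```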

So basically, two things need to be done:

1- In JAX, define an argmax window reducer; this also requires defining its primitive, lowerings, etc.
2- In Equinox, add an argument to the max pooling layers that lets users get back the indices, so that they can be passed to the unpooling layer when necessary.

@patrick-kidger what do you think?

@TugdualKerjan I won't have the time to tackle this before January, so if you could give it a go, that would be great! Otherwise, I could try my best starting next year.

@TugdualKerjan
Contributor

I think it's best to break the problem down into unpooling and upsampling, which are separate matters. Should this be split into two issues?

Concerning upsampling: I think we can base ourselves on https://en.wikipedia.org/wiki/Multivariate_interpolation and build it up brick by brick, starting with nearest-neighbor and linear interpolation. I guess this doesn't require any PRs to JAX, as we can use https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.interp.html, similar to how torch calls F.interpolate behind the scenes: https://pytorch.org/docs/stable/_modules/torch/nn/modules/upsampling.html#Upsample
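Here is a minimal sketch of the `jax.numpy.interp` route for 1D linear upsampling. `jnp.interp` only works on 1D data, so the hypothetical `upsample_linear_1d` below maps it over the channel axis with `jax.vmap`. It uses "align corners" spacing, which is one choice among several; matching `torch.nn.Upsample` exactly would also mean handling its `align_corners` option.

```python
import jax
import jax.numpy as jnp


def upsample_linear_1d(x: jax.Array, scale: int) -> jax.Array:
    """Linearly upsample a float (channels, length) array along its last axis.

    The first and last output samples coincide with the first and last input samples.
    """
    length = x.shape[-1]
    src = jnp.arange(length, dtype=x.dtype)            # input sample positions
    dst = jnp.linspace(0, length - 1, length * scale)  # output sample positions
    # jnp.interp is 1-D only, so apply it to each channel separately.
    return jax.vmap(lambda row: jnp.interp(dst, src, row))(x)


x = jnp.array([[0.0, 1.0, 2.0, 3.0],
               [3.0, 2.0, 1.0, 0.0]])
up = upsample_linear_1d(x, 2)   # shape (2, 8)
```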

Concerning max unpooling: should we open a JAX issue to request it?

@patrick-kidger
Owner

If this is given by a transpose, could this be implemented using jax.linear_transpose, without requiring any changes in JAX?
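Here is a sketch of what that could look like, under my reading of the suggestion. Max pooling is not linear, but once `jax.linearize` fixes the point `x`, its derivative is the linear map that picks out each window's maximum, and the transpose of that map scatters values back to the argmax positions. `max_pool` and `max_unpool` are hypothetical names; the window shape, stride, and `"VALID"` padding are arbitrary choices, and I have not checked how ties within a window are resolved.

```python
import jax
import jax.numpy as jnp
from jax import lax


def max_pool(x: jax.Array) -> jax.Array:
    # 2x2 max pooling with stride 2 on an (H, W) array.
    return lax.reduce_window(x, -jnp.inf, lax.max, (2, 2), (2, 2), "VALID")


def max_unpool(x: jax.Array, y: jax.Array) -> jax.Array:
    # Linearize max pooling around x: the result is the linear map that
    # selects the element at each window's argmax.
    _, pool_lin = jax.linearize(max_pool, x)
    # Its transpose scatters y back to those argmax positions (zeros elsewhere).
    unpool = jax.linear_transpose(pool_lin, x)
    (out,) = unpool(y)  # linear_transpose returns a tuple matching the primals
    return out


x = jnp.array([[1.0, 2.0, 5.0, 0.0],
               [3.0, 4.0, 1.0, 6.0],
               [0.0, 0.0, 7.0, 8.0],
               [9.0, 1.0, 2.0, 3.0]])
y = max_pool(x)            # [[4, 6], [9, 8]]
restored = max_unpool(x, y)
```

`jax.vjp(max_pool, x)[1](y)` should compute the same result, since a VJP is a linearization followed by a transpose. Note that this unpools relative to the original input `x`, so `x` (or enough information to recompute the argmax) has to be kept around, much like the stored indices in the explicit approach.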

@TugdualKerjan
Contributor

I think it would be possible, although it wouldn't be as efficient as modifying JAX to return the input indices for unpooling.

Concerning upsampling, what do you think of using jax.numpy.interp, @patrick-kidger?

@patrick-kidger
Owner

No opinions as I've not really used it I'm afraid :)

Why would transposing be less efficient? XLA actually has a dedicated op (https://openxla.org/xla/operation_semantics#selectandscatter) for exactly the operation you're trying to do, so I suspect JAX is smart enough to lower to it.

3 participants