Upsampling/Unpooling for Transpose Convolutions #894
Comments
I think you're the first one to ask for them! :) I'd be happy to take a pull request that adds these.
I can try my hand at it, @JadM133. Currently I've used JAX's image resize:
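(The snippet referenced above isn't shown in the thread. As a rough illustration only, here is a minimal sketch of an upsampling layer built on jax.image.resize; the Upsample2d name, the scale_factor parameter, and the module structure are assumptions for the example, not code from this issue.)

```python
import equinox as eqx
import jax
import jax.numpy as jnp


class Upsample2d(eqx.Module):
    """Hypothetical upsampling layer for (channels, height, width) inputs."""

    scale_factor: int = eqx.field(static=True)
    method: str = eqx.field(static=True)

    def __init__(self, scale_factor: int, method: str = "nearest"):
        self.scale_factor = scale_factor
        self.method = method  # e.g. "nearest" or "linear"

    def __call__(self, x: jax.Array) -> jax.Array:
        c, h, w = x.shape
        new_shape = (c, h * self.scale_factor, w * self.scale_factor)
        # Delegate the actual interpolation to jax.image.resize.
        return jax.image.resize(x, new_shape, method=self.method)


x = jnp.arange(16.0).reshape(1, 4, 4)
y = Upsample2d(scale_factor=2)(x)
print(y.shape)  # (1, 8, 8)
```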
@TugdualKerjan thanks! I won't be able to work on it anytime soon, so feel free to take the lead :)
It seems like there isn't a straightforward implementation of upsampling in JAX's library. The closest thing I can find is this: jax-ml/jax#862. I'm still learning more about JAX and Equinox, so I guess this would be a way of being thrown in the deep end! From what I gather, if we want to do a PR on Equinox for upsampling, it's best if we first PR JAX with an upsampling solution and then come back to this? @patrick-kidger would love to have your opinion on this. Edit: Maybe working off this could work for at least the linear and nearest cases? https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.interpolate.RegularGridInterpolator.html#jax.scipy.interpolate.RegularGridInterpolator
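(As a rough sketch of the RegularGridInterpolator idea for the linear case; the way the finer query grid is built here is an illustrative assumption, not taken from the linked discussion.)

```python
import jax.numpy as jnp
from jax.scipy.interpolate import RegularGridInterpolator

# Original 4x4 grid of values and its coordinates.
h, w = 4, 4
values = jnp.arange(h * w, dtype=jnp.float32).reshape(h, w)
points = (jnp.arange(h, dtype=jnp.float32), jnp.arange(w, dtype=jnp.float32))

interp = RegularGridInterpolator(points, values, method="linear")

# Query a finer 8x8 grid spanning the same coordinate range.
new_h, new_w = 8, 8
ys = jnp.linspace(0.0, h - 1, new_h)
xs = jnp.linspace(0.0, w - 1, new_w)
grid = jnp.stack(jnp.meshgrid(ys, xs, indexing="ij"), axis=-1)  # (8, 8, 2)

upsampled = interp(grid.reshape(-1, 2)).reshape(new_h, new_w)
print(upsampled.shape)  # (8, 8)
```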
Do you know how this operation is defined (mathematically) in other frameworks, e.g. PyTorch?
Let me share with you what I know. The most popular unpooling technique is max unpooling, which is basically the transpose of a max pooling layer. When a max pooling layer is applied, the positions of the maximums for each sliding window are memorized as indices. The corresponding unpooling layer takes those indices and places each value back at the corresponding position; the rest of the sliding window is set to zeros (for example, MaxUnpool2d in torch). The operation is quite simple, but as @TugdualKerjan mentioned, we are missing an operation from JAX: the pooling layers in Equinox are applied using lax.reduce_window, which returns only the reduced values, not the argmax indices.
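(As a minimal 1D illustration of that index bookkeeping, assuming non-overlapping windows; this is just a sketch of the idea, not the lax.reduce_window-based implementation Equinox would actually need.)

```python
import jax.numpy as jnp


def max_pool_with_indices(x, window):
    """Non-overlapping 1D max pooling that also returns the argmax positions."""
    n = x.shape[0] // window
    windows = x[: n * window].reshape(n, window)
    idx_in_window = jnp.argmax(windows, axis=1)
    flat_indices = jnp.arange(n) * window + idx_in_window  # positions in the input
    return windows.max(axis=1), flat_indices


def max_unpool(values, indices, output_size):
    """Scatter pooled values back to their original positions, zeros elsewhere."""
    out = jnp.zeros(output_size, dtype=values.dtype)
    return out.at[indices].set(values)


x = jnp.array([1.0, 3.0, 2.0, 0.0, 5.0, 4.0])
pooled, idx = max_pool_with_indices(x, window=2)
restored = max_unpool(pooled, idx, x.shape[0])
print(pooled)    # [3. 2. 5.]
print(restored)  # [0. 3. 2. 0. 5. 0.]
```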
So basically what needs to be done is two things: 1- In JAX, define a window reducer using argmax; this also requires defining its primitives, lowerings, etc. 2- In Equinox, add an unpooling layer that consumes those indices. @patrick-kidger what do you think? @TugdualKerjan I won't have the time to tackle this before January, so if you could give it a go that would be great! Otherwise I could try my best starting next year.
I think it's best if we break the problem down into unpooling and upsampling, which are separate matters - should this be split into two issues? Concerning the upsampling: I think we can base ourselves on https://en.wikipedia.org/wiki/Multivariate_interpolation and build brick by brick, starting with nearest neighbour and linear? I guess this doesn't require doing PRs on JAX, as we can use https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.interp.html, like what torch does by calling F.interpolate behind the scenes: https://pytorch.org/docs/stable/_modules/torch/nn/modules/upsampling.html#Upsample Concerning max unpooling: should we open a JAX ticket to request it?
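(A tiny sketch of what the jnp.interp route looks like for 1D linear upsampling; the coordinate choices here are illustrative.)

```python
import jax.numpy as jnp

x = jnp.array([0.0, 1.0, 4.0, 9.0])              # signal sampled at 4 points
old_coords = jnp.arange(x.shape[0])
new_coords = jnp.linspace(0.0, x.shape[0] - 1, 7)  # upsample to 7 points

# jnp.interp does piecewise-linear interpolation at the new coordinates.
upsampled = jnp.interp(new_coords, old_coords, x)
print(upsampled)  # [0.  0.5 1.  2.5 4.  6.5 9. ]
```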
If this is given by a transpose, could this be implemented using a transpose of the pooling operation?
I think it would be possible, although this wouldn't be as efficient an implementation as modifying JAX to return the input indices for the unpooling. Concerning upsampling, what do you think of using jax.numpy.interp, @patrick-kidger?
No opinions as I've not really used it, I'm afraid :) Why would transposing be less efficient? XLA actually has a custom op (https://openxla.org/xla/operation_semantics#selectandscatter) for specifically the operation you're trying to do, so I suspect JAX is smart enough to lower to this.
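(A small sketch of the transpose idea, assuming the VJP of a lax.reduce_window max pool is used to scatter values back to the window maxima. That this lowers to the SelectAndScatter-style op mentioned above is an assumption about XLA's lowering, not something verified in this thread.)

```python
import jax
import jax.numpy as jnp
from jax import lax


def max_pool_2d(x, window=2):
    """2x2 max pooling with stride 2 over a (height, width) array."""
    return lax.reduce_window(
        x,
        -jnp.inf,
        lax.max,
        window_dimensions=(window, window),
        window_strides=(window, window),
        padding="VALID",
    )


x = jnp.arange(16.0).reshape(4, 4)
pooled, vjp_fn = jax.vjp(max_pool_2d, x)

# Pulling the pooled values back through the VJP scatters each one to the
# position of the corresponding window maximum, with zeros elsewhere.
(unpooled,) = vjp_fn(pooled)
print(unpooled)
```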
Hello all,
Great work, the library is amazing. I am having trouble finding unpooling classes. Did I miss them, or are they not available yet?
Thanks in advance!