FastSo3 provides a fast implementation of the equivariant layers introduced in the eSCN paper. These layers properly handle 3D rigid transformations (rotation and translation), so they are useful for modelling 3D data such as point clouds or molecules. FastSo3 is `torch.compile` compatible.

How is FastSo3 different from the official implementation?

FastSo3 is faster, and it can be even faster when combined with `torch.compile`. Computation time (µs):
| [b, c, l] | escn | so2_efficient | so2_efficient_compile |
|---|---|---|---|
| 1000 128 1 | 1545.0 | 776.4 | 246.8 |
| 1000 128 2 | 2143.3 | 724.7 | 430.5 |
| 1000 128 3 | 1848.9 | 724.0 | 664.3 |
| 1000 128 4 | 2324.5 | 962.0 | 996.3 |
| 1000 128 5 | 2838.1 | 1480.7 | 1534.6 |
| 1000 128 6 | 3142.7 | 2314.4 | 2033.4 |
- The code is run on one NVIDIA A100 GPU.
- `escn`: the official eSCN implementation.
- `so2_efficient`: the FastSo3 layer `SO2_conv_e`.
- `so2_efficient_compile`: the compiled `SO2_conv_e` layer.
- `b`: number of edges; `c`: number of channels; `l`: number of degrees. The hidden channel count is set to `c`. For `escn`, all possible `m`s are used.
FastSo3 pads features of different degrees to the same size so that they can be processed efficiently. Specifically,

- There are no for-loops in FastSo3, so the loop overhead is avoided.
- There is no dynamic indexing, which means that `torch.compile` can be used to accelerate the modules.
For a graph containing `B` nodes in 3D space, a feature `x` with `C` channels and maximum degree `L` is represented by a tensor of shape `(B, C, L, 2L+1)`. For every degree `l` smaller than `L`, the corresponding feature is zero-padded to length `2L+1`. For example, let `L = 3`:

- A degree-1 feature `f` (length 3) is padded as `[0, 0, f, 0, 0]` (length 7);
- A degree-2 feature `g` (length 5) is padded as `[0, g, 0]` (length 7).

The degree-0 feature, i.e., the invariant feature `x_0`, takes the shape `(B, C)`.
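The snippet below is a minimal sketch of this padding layout. The helper `degree_mask` and its row indexing are illustrative assumptions only; the library's own `so2.get_mask` is what is used in the example further below.

```python
import torch

def degree_mask(L: int) -> torch.Tensor:
    # Illustrative mask of shape (L, 2L+1): the row for degree l (l = 1..L)
    # keeps the central 2l+1 entries and leaves the padding at zero.
    M = 2 * L + 1
    mask = torch.zeros(L, M)
    for l in range(1, L + 1):
        pad = L - l  # zeros on each side of the degree-l block
        mask[l - 1, pad:M - pad] = 1.0
    return mask

print(degree_mask(3))
# degree 1 -> [0., 0., 1., 1., 1., 0., 0.]
# degree 2 -> [0., 1., 1., 1., 1., 1., 0.]
# degree 3 -> [1., 1., 1., 1., 1., 1., 1.]
```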
To compute the message from node `i` to node `j`, a rotation matrix `R` is first used to rotate the edge `ij` onto the y-axis. The rotated feature `Rx` is then used as input to the convolution layer, and the output is finally rotated back to give the message.
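As an illustration of the first step, the sketch below builds a rotation that maps an edge vector onto the y-axis using Rodrigues' formula. It is a minimal sketch, not the library's implementation: rotating features of degree `l > 1` additionally requires the corresponding Wigner-D matrices, and the degenerate case of an edge already parallel to the y-axis is not handled.

```python
import torch

def rotation_to_y(v: torch.Tensor) -> torch.Tensor:
    # Returns a 3x3 rotation R such that R @ v points along the y-axis.
    v = v / v.norm()
    y = torch.tensor([0.0, 1.0, 0.0])
    axis = torch.cross(v, y, dim=0)
    s = axis.norm()            # sin(theta)
    c = torch.dot(v, y)        # cos(theta)
    k = axis / s               # unit rotation axis
    K = torch.zeros(3, 3)      # cross-product (skew-symmetric) matrix of k
    K[0, 1], K[0, 2] = -k[2], k[1]
    K[1, 0], K[1, 2] = k[2], -k[0]
    K[2, 0], K[2, 1] = -k[1], k[0]
    return torch.eye(3) + s * K + (1.0 - c) * (K @ K)

edge = torch.randn(3)
R = rotation_to_y(edge)
print(R @ (edge / edge.norm()))  # approximately [0., 1., 0.]
```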
This library provides several ways to perform the convolution in the second step, as described below.
In our computation, both `L` and `C` are treated as channels, so we use the notation `h = CL` for simplicity. We also use `M = 2L` as the index of the last dimension.
Specifically,

- `SO2_conv`: the standard convolution. The kernel size is `(B, h_in, h_out, M)`. It can be expensive for large `C` or `L`.
- `SO2_conv_c`: the depth-wise convolution. Each channel is processed independently, and the output feature shape is the same as the input feature shape. The kernel size is `(B, h_in, M)`.
- `SO2_mix_c`: the linear channel-mixing layer. The parameter size is `(h_in, h_out)`.
- `SO2_conv_e`: implemented based on the eSCN paper. It first mixes the channels, then applies a depth-wise convolution, and finally maps the channels back (see the sketch after this list).

Note that these layers use weights of different shapes. See the source code for details.
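As a rough illustration of the `SO2_conv_e` factorization, the sketch below shows the channel-mix / depth-wise / channel-mix-back structure on the padded features. It is a simplified stand-in, not the library layer: it applies one scalar weight per hidden channel and per entry of the padded last axis, whereas the actual SO(2) convolution couples the ±m components and uses per-edge kernels of size `M = 2L` as stated above.

```python
import torch
import torch.nn as nn

class SO2ConvESketch(nn.Module):
    # Conceptual sketch: mix channels down, apply a per-channel (depth-wise)
    # weight over the padded last axis, then mix channels back up.
    def __init__(self, c_in: int, c_hidden: int, c_out: int, L: int):
        super().__init__()
        M = 2 * L + 1
        self.L = L
        self.mix_in = nn.Parameter(torch.randn(c_in * L, c_hidden * L))    # h_in -> h_hidden
        self.depth = nn.Parameter(torch.randn(c_hidden * L, M))            # depth-wise weights
        self.mix_out = nn.Parameter(torch.randn(c_hidden * L, c_out * L))  # h_hidden -> h_out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C_in, L, 2L+1) padded features
        B, _, _, M = x.shape
        h = x.reshape(B, -1, M)                          # merge (C, L) into h = C * L
        h = torch.einsum('bhm,hk->bkm', h, self.mix_in)  # mix channels down
        h = h * self.depth                               # each hidden channel weighted independently
        h = torch.einsum('bhm,hk->bkm', h, self.mix_out) # mix channels back
        return h.reshape(B, -1, self.L, M)

layer = SO2ConvESketch(c_in=4, c_hidden=2, c_out=2, L=2)
out = layer(torch.randn(8, 4, 2, 5))  # -> shape (8, 2, 2, 5)
```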
We provide the following use case as an example.
import FastSo3.so3 as so3
import FastSo3.so2 as so2
import torch
c_in = 4 # input channel
c_out = 2 # output channel
c_hidden = 2 # hidden channel
L = 2 # degree
c_L0_in = 2 # additional channels for the degree-0 input
c_L0_out = 4 # additional channels for the degree-0 output
b = 100 # number of edges
so2_conv = so2.SO2_conv(c_in, c_out, L, c_L0_in=c_L0_in, c_L0_out=c_L0_out)
# prepare input
x_ = torch.randn((b, c_in, L, L * 2 + 1))
M = so2.get_mask(L) # mask that zeroes the padded entries
x = x_ * M
# extra degree 0 input
x_L0 = torch.randn((b, c_L0_in))
# weight
w = torch.randn((b, c_in, L, L * 2, c_out, L))
w_L0 = torch.randn((b, c_L0_in + c_in * L, c_L0_out + c_out * L))
# edge vector
edges_vec = torch.randn((b, 3))
# a forward pass
message, message_0 = so2.get_message(x, x_L0, w, w_L0, edges_vec, so2_conv, L)
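Since the layers avoid dynamic indexing, the module from the example above can also be wrapped with `torch.compile`. This is a sketch, assuming `so2.get_message` accepts the compiled module in place of the original one:

```python
# Compile the convolution module; the first call triggers compilation.
so2_conv_compiled = torch.compile(so2_conv)
message_c, message_0_c = so2.get_message(x, x_L0, w, w_L0, edges_vec, so2_conv_compiled, L)
```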
The equivariance of the layer can be checked by running the check_equi.py
script.
Uncomment different blocks to check different functions.
$ python check_equi.py
# output
# The equivariance error: 5.020955995860277e-06
# The invariance error: 3.5846128412231337e-06
FastSo3 is available under an MIT license.