When I use your kat_group as a module, it runs normally on cuda:0. However, when I run it on any other GPU, I get the error mentioned in the title. I tested the example code you provided and hit the same issue. Could you please let me know how to fix it?
```python
import torch
import torch.nn as nn
from kat_rational import KAT_Group


class KAN(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_cfg=dict(type="KAT", act_init=["identity", "gelu"]),
        bias=True,
        drop=0.,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act1 = KAT_Group(mode=act_cfg['act_init'][0])
        self.drop1 = nn.Dropout(drop)
        self.act2 = KAT_Group(mode=act_cfg['act_init'][1])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        x = self.act1(x)
        x = self.drop1(x)
        x = self.fc1(x)
        x = self.act2(x)
        x = self.drop2(x)
        x = self.fc2(x)
        return x


N, C = 8, 64
# Works with 'cuda:0'; fails with 'cuda:1' (or any other non-default GPU).
input_tensor = torch.randn(N, C).to('cuda:1')
model = KAN(in_features=C, hidden_features=128, out_features=C).to('cuda:1')
output = model(input_tensor)
print(output.shape)
```
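A possible workaround to try, assuming (not confirmed here, since the error text from the title is not shown) that the extension launches its CUDA kernels on the *current* CUDA device rather than on the input tensor's device: make the tensor's GPU current before the forward pass. By default the current device is `cuda:0`, which would explain why only `cuda:0` works. The helper `forward_on_tensor_device` below is hypothetical and not part of `kat_rational`; it only wraps the call in PyTorch's `torch.cuda.device(...)` context manager.

```python
import contextlib

import torch
import torch.nn as nn


def forward_on_tensor_device(module: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Run module(x) with x's GPU set as the current CUDA device.

    Hypothetical helper (not from kat_rational). Assumes the custom kernel
    uses the current CUDA device instead of x.device.
    """
    if x.is_cuda:
        # Temporarily make x's GPU the current device; restored on exit.
        ctx = torch.cuda.device(x.device)
    else:
        # CPU tensors need no device switch.
        ctx = contextlib.nullcontext()
    with ctx:
        return module(x)


if __name__ == "__main__":
    # Demo with a plain layer; falls back to CPU if there is no second GPU.
    device = "cuda:1" if torch.cuda.device_count() > 1 else "cpu"
    layer = nn.Linear(64, 64).to(device)
    x = torch.randn(8, 64, device=device)
    print(forward_on_tensor_device(layer, x).shape)
```

With the model from the report this would be `output = forward_on_tensor_device(model, input_tensor)`. Calling `torch.cuda.set_device(1)` once before the forward pass, or launching the script with `CUDA_VISIBLE_DEVICES=1` so that the chosen GPU appears as `cuda:0`, should have the same effect if the assumption above holds.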