RuntimeError: Triton Error [CUDA]: context is destroyed #28

Open
huxiaopang666 opened this issue Mar 2, 2025 · 1 comment

Comments

@huxiaopang666

When I use your KAT_Group as a module, it runs normally on cuda:0; however, when I run it on any other GPU, I get the error in the title. I tested the code you provided and hit the same issue. Could you please let me know how to fix it?
```python
import torch
import torch.nn as nn
from kat_rational import KAT_Group


class KAN(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_cfg=dict(type="KAT", act_init=["identity", "gelu"]),
        bias=True,
        drop=0.,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        # Rational activations from kat_rational, initialized per act_init
        self.act1 = KAT_Group(mode=act_cfg['act_init'][0])
        self.drop1 = nn.Dropout(drop)
        self.act2 = KAT_Group(mode=act_cfg['act_init'][1])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        # act1 (initialized as identity) is applied before fc1
        x = self.act1(x)
        x = self.drop1(x)
        x = self.fc1(x)
        x = self.act2(x)
        x = self.drop2(x)
        x = self.fc2(x)
        return x


# Reproduction: runs on cuda:0, fails on cuda:1
N, C = 8, 64
input_tensor = torch.randn(N, C).to('cuda:1')
model = KAN(in_features=C, hidden_features=128, out_features=C).to('cuda:1')
output = model(input_tensor)
print(output.shape)
```
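
One quick check for the likely cause is to compare the tensor's GPU with PyTorch's current CUDA device. This sketch assumes (it is not confirmed in this thread) that the Triton kernels behind KAT_Group launch on the current device, which stays cuda:0 unless it is changed; `current_device_mismatch` is a hypothetical helper, not part of kat_rational.

```python
import torch


def current_device_mismatch(t: torch.Tensor) -> bool:
    """Hypothetical helper: True if `t` lives on a GPU other than the
    current CUDA device.

    Assumption: kernels that launch on the *current* device would then run
    in a different CUDA context from the one that owns `t`'s memory.
    """
    if not t.is_cuda:
        # CPU tensors have no CUDA context to mismatch.
        return False
    return t.device.index != torch.cuda.current_device()


if torch.cuda.device_count() > 1:
    x = torch.randn(8, 64, device="cuda:1")
    # Unless torch.cuda.set_device has been called, the current device is 0.
    print(torch.cuda.current_device(), x.device, current_device_mismatch(x))
```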

@Adamdad
Owner

Adamdad commented Mar 2, 2025

This is a commonly reported Triton error; see facebookresearch/xformers#681.

I found a solution there; can you please try it?
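
The linked xformers thread is not reproduced here, so the following is only a sketch of one common workaround for this class of Triton error, under the assumption that the kernels launch on the current CUDA device: make the input's GPU current for the duration of the forward pass. `CurrentDeviceGuard` is a hypothetical wrapper, not part of kat_rational.

```python
import torch
import torch.nn as nn


class CurrentDeviceGuard(nn.Module):
    """Hypothetical wrapper: runs the wrapped module with the input's GPU
    set as the current CUDA device.

    Assumption: Triton kernels inside the module launch on the current
    device, so switching it to x.device keeps launch and data in one context.
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.is_cuda:
            # torch.cuda.device temporarily switches the current device
            # and restores the previous one on exit.
            with torch.cuda.device(x.device):
                return self.module(x)
        # CPU input: there is no current CUDA device to switch.
        return self.module(x)
```

For example, `model = CurrentDeviceGuard(KAN(in_features=C, hidden_features=128, out_features=C)).to('cuda:1')`. A simpler alternative when each process uses a single GPU is to call `torch.cuda.set_device(1)` before building the model, so that cuda:1 is the current device throughout.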
