# attention.py (forked from etched-ai/open-oasis)
"""
Based on https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/attention.py
"""
from typing import Optional
from collections import namedtuple
import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
from embeddings import TimestepEmbedding, Timesteps, Positions2d
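
# TemporalAxialAttention: self-attention along the time axis T, computed
# independently at every spatial location (H, W) of the latent grid.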
class TemporalAxialAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        heads: int = 4,
        dim_head: int = 32,
        is_causal: bool = True,
        rotary_emb: Optional[RotaryEmbedding] = None,
    ):
        super().__init__()
        self.heads = heads
        self.head_dim = dim_head
        self.inner_dim = dim_head * heads
        self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
        self.to_out = nn.Linear(self.inner_dim, dim)

        self.rotary_emb = rotary_emb
        self.time_pos_embedding = (
            nn.Sequential(
                Timesteps(dim),
                TimestepEmbedding(in_channels=dim, time_embed_dim=dim * 4, out_dim=dim),
            )
            if rotary_emb is None
            else None
        )
        self.is_causal = is_causal

    def forward(self, x: torch.Tensor):
        B, T, H, W, D = x.shape

        if self.time_pos_embedding is not None:
            time_emb = self.time_pos_embedding(torch.arange(T, device=x.device))
            x = x + rearrange(time_emb, "t d -> 1 t 1 1 d")
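        # fold every spatial position into the batch so attention mixes information
        # along the time axis only: (B, T, H, W, D) -> (B*H*W, heads, T, head_dim)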
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        q = rearrange(q, "B T H W (h d) -> (B H W) h T d", h=self.heads)
        k = rearrange(k, "B T H W (h d) -> (B H W) h T d", h=self.heads)
        v = rearrange(v, "B T H W (h d) -> (B H W) h T d", h=self.heads)

        if self.rotary_emb is not None:
            q = self.rotary_emb.rotate_queries_or_keys(q, self.rotary_emb.freqs)
            k = self.rotary_emb.rotate_queries_or_keys(k, self.rotary_emb.freqs)

        q, k, v = map(lambda t: t.contiguous(), (q, k, v))

        x = F.scaled_dot_product_attention(query=q, key=k, value=v, is_causal=self.is_causal)
        x = rearrange(x, "(B H W) h T d -> B T H W (h d)", B=B, H=H, W=W)
        x = x.to(q.dtype)

        # linear proj
        x = self.to_out(x)
        return x
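
# SpatialAxialAttention: self-attention across the H*W spatial positions of each
# frame; batch and time are folded together, so no information crosses frames here.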
class SpatialAxialAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        heads: int = 4,
        dim_head: int = 32,
        rotary_emb: Optional[RotaryEmbedding] = None,
    ):
        super().__init__()
        self.heads = heads
        self.head_dim = dim_head
        self.inner_dim = dim_head * heads
        self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
        self.to_out = nn.Linear(self.inner_dim, dim)

        self.rotary_emb = rotary_emb
        self.space_pos_embedding = (
            nn.Sequential(
                Positions2d(dim),
                TimestepEmbedding(in_channels=dim, time_embed_dim=dim * 4, out_dim=dim),
            )
            if rotary_emb is None
            else None
        )

    def forward(self, x: torch.Tensor):
        B, T, H, W, D = x.shape

        if self.space_pos_embedding is not None:
            h_steps = torch.arange(H, device=x.device)
            w_steps = torch.arange(W, device=x.device)
            grid = torch.meshgrid(h_steps, w_steps, indexing="ij")
            space_emb = self.space_pos_embedding(grid)
            x = x + rearrange(space_emb, "h w d -> 1 1 h w d")
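        # fold batch and time together so attention runs over the H*W positions of
        # each frame: (B, T, H, W, D) -> (B*T, heads, H*W, head_dim)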
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        q = rearrange(q, "B T H W (h d) -> (B T) h H W d", h=self.heads)
        k = rearrange(k, "B T H W (h d) -> (B T) h H W d", h=self.heads)
        v = rearrange(v, "B T H W (h d) -> (B T) h H W d", h=self.heads)

        if self.rotary_emb is not None:
            freqs = self.rotary_emb.get_axial_freqs(H, W)
            q = apply_rotary_emb(freqs, q)
            k = apply_rotary_emb(freqs, k)

        # prepare for attn
        q = rearrange(q, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)
        k = rearrange(k, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)
        v = rearrange(v, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)
        q, k, v = map(lambda t: t.contiguous(), (q, k, v))

        x = F.scaled_dot_product_attention(query=q, key=k, value=v, is_causal=False)
        x = rearrange(x, "(B T) h (H W) d -> B T H W (h d)", B=B, H=H, W=W)
        x = x.to(q.dtype)

        # linear proj
        x = self.to_out(x)
        return x
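

# Minimal usage sketch (not part of the original file): the module/head sizes below
# are illustrative, and the rotary settings are an assumption modeled on how the
# parent repo wires these blocks elsewhere; it also assumes a rotary-embedding-torch
# version compatible with the rotate_queries_or_keys(q, freqs) call used above.
if __name__ == "__main__":
    dim, heads, dim_head = 64, 4, 16
    temporal_attn = TemporalAxialAttention(
        dim, heads=heads, dim_head=dim_head, is_causal=True,
        rotary_emb=RotaryEmbedding(dim=dim_head),
    )
    spatial_attn = SpatialAxialAttention(
        dim, heads=heads, dim_head=dim_head,
        rotary_emb=RotaryEmbedding(dim=dim_head // 2, freqs_for="pixel", max_freq=256),
    )
    x = torch.randn(2, 8, 16, 16, dim)  # (B, T, H, W, D)
    out = spatial_attn(temporal_attn(x))
    print(out.shape)  # expected: torch.Size([2, 8, 16, 16, 64])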