Attention

import math,torch
from torch import nn
from fastAIcourse.activations import *

import matplotlib.pyplot as plt
from diffusers.models.attention import Attention as AttentionBlock

set_seed(42)
x = torch.randn(64,32,16,16)
t = x.view(*x.shape[:2], -1).transpose(1, 2)
t.shape
torch.Size([64, 256, 32])
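Flattening the 16×16 feature map into a sequence of 256 tokens with 32 channels is lossless; transposing back and reshaping recovers x exactly. A quick check (not in the original cells):

(t.transpose(1, 2).reshape(64, 32, 16, 16) == x).all()   # should be tensor(True)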
ni = 32

sk = nn.Linear(ni, ni)
sq = nn.Linear(ni, ni)
sv = nn.Linear(ni, ni)

k = sk(t)
q = sq(t)
v = sv(t)

(q@k.transpose(1,2)).shape
torch.Size([64, 256, 256])
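To finish the attention computation by hand (a sketch, not in the original cells; the names a and ni are as defined above): scale the scores, softmax over the key dimension, then use the result to weight the values. The output has the same shape as the input sequence.

a = (q@k.transpose(1,2)) / math.sqrt(ni)   # scaled attention scores
(a.softmax(dim=-1) @ v).shape              # torch.Size([64, 256, 32])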
class SelfAttention(nn.Module):
    def __init__(self, ni):
        super().__init__()
        self.scale = math.sqrt(ni)
        self.norm = nn.GroupNorm(1, ni)
        self.q = nn.Linear(ni, ni)
        self.k = nn.Linear(ni, ni)
        self.v = nn.Linear(ni, ni)
        self.proj = nn.Linear(ni, ni)

    def forward(self, x):
        inp = x
        n,c,h,w = x.shape
        x = self.norm(x)
        x = x.view(n, c, -1).transpose(1, 2)    # (n, c, h, w) -> (n, h*w, c)
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        s = (q@k.transpose(1,2))/self.scale     # scaled dot-product scores
        x = s.softmax(dim=-1)@v
        x = self.proj(x)
        x = x.transpose(1,2).reshape(n,c,h,w)   # back to the image layout
        return x+inp                            # residual connection
sa = SelfAttention(32)
ra = sa(x)
ra.shape
torch.Size([64, 32, 16, 16])

ra[0,0,0]
tensor([ 1.91, 1.42, 0.84, -2.16, 0.63, -1.24, -0.08, -1.68, -0.79, 1.61, -0.39, -1.43, -0.75, -0.60, -0.83, 0.75],
grad_fn=<SelectBackward0>)
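As a cross-check (not in the original; requires PyTorch 2.0+, and the names xn, q2, k2, v2 are introduced here), F.scaled_dot_product_attention computes the same softmax(qkᵀ/√d)·v as the manual version, so feeding it the same projections should reproduce the hand-rolled result.

import torch.nn.functional as F
xn = sa.norm(x).view(64, 32, -1).transpose(1, 2)
q2,k2,v2 = sa.q(xn), sa.k(xn), sa.v(xn)
out = F.scaled_dot_product_attention(q2, k2, v2)
torch.allclose(out, (q2@k2.transpose(1,2)/sa.scale).softmax(dim=-1)@v2, atol=1e-4)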
def cp_parms(a,b):
    b.weight = a.weight
    b.bias = a.bias

at = AttentionBlock(32, norm_num_groups=1)
src = sa.q,sa.k,sa.v,sa.proj,sa.norm
dst = at.query,at.key,at.value,at.proj_attn,at.group_norm
for s,d in zip(src,dst): cp_parms(s,d)   # copy our weights into the diffusers block
rb = at(x)
rb[0,0,0]
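With the parameters copied across, the diffusers block should produce essentially the same output as our module. A quick way to confirm (not in the original, and dependent on the installed diffusers version keeping this AttentionBlock layout):

torch.allclose(ra, rb, atol=1e-5)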
sqkv = nn.Linear(ni, ni*3)
st = sqkv(t)
st.shape
torch.Size([64, 256, 96])

q,k,v = torch.chunk(st, 3, dim=-1)
q.shape
(k@q.transpose(1,2)).shape
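Chunking works because the fused layer is just three separate projections stacked along the output dimension, so slicing out the first ni rows of its weight (and the first ni bias entries) reproduces the q chunk. A small check, not in the original (q2 is introduced here):

q2 = t @ sqkv.weight[:ni].T + sqkv.bias[:ni]   # same as the first chunk of sqkv(t)
torch.allclose(q, q2, atol=1e-6)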
class SelfAttention(nn.Module):
    def __init__(self, ni):
        super().__init__()
        self.scale = math.sqrt(ni)
        self.norm = nn.BatchNorm2d(ni)
        self.qkv = nn.Linear(ni, ni*3)   # fused q, k and v projections
        self.proj = nn.Linear(ni, ni)

    def forward(self, inp):
        n,c,h,w = inp.shape
        x = self.norm(inp).view(n, c, -1).transpose(1, 2)
        q,k,v = torch.chunk(self.qkv(x), 3, dim=-1)
        s = (q@k.transpose(1,2))/self.scale
        x = s.softmax(dim=-1)@v
        x = self.proj(x).transpose(1,2).reshape(n,c,h,w)
        return x+inp
class SelfAttention(nn.Module):
    def __init__(self, ni):
        super().__init__()
        self.scale = math.sqrt(ni)
        self.norm = nn.BatchNorm1d(ni)   # 1d norm: this variant's forward takes (n, c, seq) tensors
        self.qkv = nn.Linear(ni, ni*3)
        self.proj = nn.Linear(ni, ni)

    def forward(self, x):
        x = self.norm(x).transpose(1, 2)               # (n, c, seq) -> (n, seq, c)
        q,k,v = torch.chunk(self.qkv(x), 3, dim=-1)
        s = (q@k.transpose(1,2))/self.scale
        x = s.softmax(dim=-1)@v
        return self.proj(x).transpose(1,2)             # back to (n, c, seq)
sa = SelfAttention(32)
sa(x).shape
torch.Size([64, 32, 16, 16])
sa(x).std()
tensor(1.0047, grad_fn=<StdBackward0>)
def heads_to_batch(x, heads):
    n,sl,d = x.shape
    x = x.reshape(n, sl, heads, -1)                  # (n, seq, heads, d_head)
    return x.transpose(2, 1).reshape(n*heads,sl,-1)  # fold heads into the batch dim

def batch_to_heads(x, heads):
    n,sl,d = x.shape
    x = x.reshape(-1, heads, sl, d)
    return x.transpose(2, 1).reshape(-1,sl,d*heads)  # merge heads back into channels
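A round trip through the two helpers should give back the original tensor exactly. A quick check, not in the original:

(batch_to_heads(heads_to_batch(t, 8), 8) == t).all()   # should be tensor(True)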
from einops import rearrange
t2 = rearrange(t , 'n s (h d) -> (n h) s d', h=8)
t.shape, t2.shape
(torch.Size([64, 256, 32]), torch.Size([512, 256, 4]))

t3 = rearrange(t2, '(n h) s d -> n s (h d)', h=8)
t2.shape,t3.shape
(torch.Size([512, 256, 4]), torch.Size([64, 256, 32]))

(t==t3).all()
tensor(True)
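The hand-written helpers and the einops patterns express the same reshapes, so they should agree element for element (a check that is not in the original):

(heads_to_batch(t, 8) == t2).all(), (batch_to_heads(t2, 8) == t3).all()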
class SelfAttentionMultiHead(nn.Module):
    def __init__(self, ni, nheads):
        super().__init__()
        self.nheads = nheads
        self.scale = math.sqrt(ni/nheads)   # scale by sqrt of the per-head dimension
        self.norm = nn.BatchNorm2d(ni)
        self.qkv = nn.Linear(ni, ni*3)
        self.proj = nn.Linear(ni, ni)

    def forward(self, inp):
        n,c,h,w = inp.shape
        x = self.norm(inp).view(n, c, -1).transpose(1, 2)
        x = self.qkv(x)
        x = rearrange(x, 'n s (h d) -> (n h) s d', h=self.nheads)   # heads into the batch dim
        q,k,v = torch.chunk(x, 3, dim=-1)
        s = (q@k.transpose(1,2))/self.scale
        x = s.softmax(dim=-1)@v
        x = rearrange(x, '(n h) s d -> n s (h d)', h=self.nheads)   # merge heads back
        x = self.proj(x).transpose(1,2).reshape(n,c,h,w)
        return x+inp
sa = SelfAttentionMultiHead(32, 4)
sx = sa(x)
sx.shape
torch.Size([64, 32, 16, 16])
sx.mean(),sx.std()
(tensor(0.0248, grad_fn=<MeanBackward0>),
tensor(1.0069, grad_fn=<StdBackward0>))
nm = nn.MultiheadAttention(32, num_heads=8, batch_first=True)
nmx,nmw = nm(t,t,t)
nmx = nmx+t
nmx.mean(),nmx.std()
(tensor(-0.0021, grad_fn=<MeanBackward0>),
tensor(1.0015, grad_fn=<StdBackward0>))
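For completeness, here is a sketch (not from the original notebook; the class name SelfAttentionMH and its structure are assumptions) of how the built-in layer could be wrapped so it drops straight into a conv net on NCHW feature maps, mirroring SelfAttentionMultiHead above.

class SelfAttentionMH(nn.Module):
    # Sketch only: nn.MultiheadAttention applied to a flattened feature map,
    # with the same norm-then-residual pattern used above.
    def __init__(self, ni, nheads):
        super().__init__()
        self.norm = nn.BatchNorm2d(ni)
        self.attn = nn.MultiheadAttention(ni, nheads, batch_first=True)

    def forward(self, inp):
        n,c,h,w = inp.shape
        x = self.norm(inp).view(n, c, -1).transpose(1, 2)   # (n, h*w, c)
        x = self.attn(x, x, x, need_weights=False)[0]
        return x.transpose(1, 2).reshape(n, c, h, w) + inp

SelfAttentionMH(32, 8)(x).shape   # expected torch.Size([64, 32, 16, 16])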