works with cuda

This commit is contained in:
Brett Kuprel
2022-06-28 21:28:36 -04:00
parent 9d6b6dcc92
commit 17c96fe110
6 changed files with 43 additions and 33 deletions
+10 -5
View File
@@ -1,7 +1,7 @@
from typing import List
import torch
from torch import nn, BoolTensor, FloatTensor, LongTensor
torch.no_grad()
torch.set_grad_enabled(False)
class GLUTorch(nn.Module):
@@ -34,6 +34,8 @@ class AttentionTorch(nn.Module):
self.v_proj = nn.Linear(embed_count, embed_count, bias=False)
self.q_proj = nn.Linear(embed_count, embed_count, bias=False)
self.out_proj = nn.Linear(embed_count, embed_count, bias=False)
self.one = torch.ones((1, 1))
if torch.cuda.is_available(): self.one = self.one.cuda()
def forward(self,
keys: FloatTensor,
@@ -43,8 +45,8 @@ class AttentionTorch(nn.Module):
) -> FloatTensor:
attention_bias = torch.where(
attention_mask,
torch.full(attention_mask.shape, 0.0),
torch.full(attention_mask.shape, -torch.inf),
self.one * 0,
self.one * (-torch.inf),
)
attention_weights: FloatTensor = torch.einsum(
'bqhc,bkhc->bhqk',
@@ -124,11 +126,14 @@ class DalleBartEncoderTorch(nn.Module):
])
self.layernorm_embedding = nn.LayerNorm(embed_count)
self.final_ln = nn.LayerNorm(embed_count)
self.token_indices = torch.arange(text_token_count).to(torch.long)
if torch.cuda.is_available():
self.token_indices = self.token_indices.cuda()
def forward(self, text_tokens: LongTensor) -> FloatTensor:
attention_mask = text_tokens.not_equal(1)
batch_count, token_count = text_tokens.shape
pose_tokens = torch.stack([torch.arange(token_count)] * batch_count)
batch_count = text_tokens.shape[0]
pose_tokens = torch.stack([self.token_indices] * batch_count)
encoder_state = (
self.embed_tokens.forward(text_tokens) +
self.embed_positions.forward(pose_tokens)