Works with CUDA
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from typing import List
|
||||
import torch
|
||||
from torch import nn, BoolTensor, FloatTensor, LongTensor
|
||||
torch.no_grad()
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
|
||||
class GLUTorch(nn.Module):
|
||||
@@ -34,6 +34,8 @@ class AttentionTorch(nn.Module):
|
||||
self.v_proj = nn.Linear(embed_count, embed_count, bias=False)
|
||||
self.q_proj = nn.Linear(embed_count, embed_count, bias=False)
|
||||
self.out_proj = nn.Linear(embed_count, embed_count, bias=False)
|
||||
self.one = torch.ones((1, 1))
|
||||
if torch.cuda.is_available(): self.one = self.one.cuda()
|
||||
|
||||
def forward(self,
|
||||
keys: FloatTensor,
|
||||
@@ -43,8 +45,8 @@ class AttentionTorch(nn.Module):
|
||||
) -> FloatTensor:
|
||||
attention_bias = torch.where(
|
||||
attention_mask,
|
||||
torch.full(attention_mask.shape, 0.0),
|
||||
torch.full(attention_mask.shape, -torch.inf),
|
||||
self.one * 0,
|
||||
self.one * (-torch.inf),
|
||||
)
|
||||
attention_weights: FloatTensor = torch.einsum(
|
||||
'bqhc,bkhc->bhqk',
|
||||
@@ -124,11 +126,14 @@ class DalleBartEncoderTorch(nn.Module):
|
||||
])
|
||||
self.layernorm_embedding = nn.LayerNorm(embed_count)
|
||||
self.final_ln = nn.LayerNorm(embed_count)
|
||||
self.token_indices = torch.arange(text_token_count).to(torch.long)
|
||||
if torch.cuda.is_available():
|
||||
self.token_indices = self.token_indices.cuda()
|
||||
|
||||
def forward(self, text_tokens: LongTensor) -> FloatTensor:
|
||||
attention_mask = text_tokens.not_equal(1)
|
||||
batch_count, token_count = text_tokens.shape
|
||||
pose_tokens = torch.stack([torch.arange(token_count)] * batch_count)
|
||||
batch_count = text_tokens.shape[0]
|
||||
pose_tokens = torch.stack([self.token_indices] * batch_count)
|
||||
encoder_state = (
|
||||
self.embed_tokens.forward(text_tokens) +
|
||||
self.embed_positions.forward(pose_tokens)
|
||||
|
||||
Reference in New Issue
Block a user