refactored to load models once and run multiple times

This commit is contained in:
Brett Kuprel
2022-06-29 09:42:12 -04:00
parent 1ef9b0b929
commit ed91ab4a30
11 changed files with 225 additions and 282 deletions
+4 -2
View File
@@ -37,7 +37,8 @@ class AttentionTorch(nn.Module):
self.one = torch.ones((1, 1))
if torch.cuda.is_available(): self.one = self.one.cuda()
def forward(self,
def forward(
self,
keys: FloatTensor,
values: FloatTensor,
queries: FloatTensor,
@@ -105,7 +106,8 @@ class EncoderLayerTorch(nn.Module):
class DalleBartEncoderTorch(nn.Module):
def __init__(self,
def __init__(
self,
layer_count: int,
embed_count: int,
attention_head_count: int,