refactored to load models once and run multiple times

This commit is contained in:
Brett Kuprel
2022-06-29 09:42:12 -04:00
parent 1ef9b0b929
commit ed91ab4a30
11 changed files with 225 additions and 282 deletions
+12 -6
View File
@@ -26,7 +26,8 @@ class DecoderCrossAttentionTorch(AttentionTorch):
class DecoderSelfAttentionTorch(AttentionTorch):
def forward(self,
def forward(
self,
decoder_state: FloatTensor,
keys_values: FloatTensor,
attention_mask: BoolTensor,
@@ -49,7 +50,8 @@ class DecoderSelfAttentionTorch(AttentionTorch):
class DecoderLayerTorch(nn.Module):
def __init__(self,
def __init__(
self,
image_token_count: int,
head_count: int,
embed_count: int,
@@ -69,7 +71,8 @@ class DecoderLayerTorch(nn.Module):
if torch.cuda.is_available():
self.token_indices = self.token_indices.cuda()
def forward(self,
def forward(
self,
decoder_state: FloatTensor,
encoder_state: FloatTensor,
keys_values_state: FloatTensor,
@@ -111,7 +114,8 @@ class DecoderLayerTorch(nn.Module):
class DalleBartDecoderTorch(nn.Module):
def __init__(self,
def __init__(
self,
image_vocab_size: int,
image_token_count: int,
sample_token_count: int,
@@ -158,7 +162,8 @@ class DalleBartDecoderTorch(nn.Module):
self.start_token = self.start_token.cuda()
def decode_step(self,
def decode_step(
self,
text_tokens: LongTensor,
encoder_state: FloatTensor,
keys_values_state: FloatTensor,
@@ -198,7 +203,8 @@ class DalleBartDecoderTorch(nn.Module):
return probs, keys_values
def forward(self,
def forward(
self,
text_tokens: LongTensor,
encoder_state: FloatTensor
) -> LongTensor: