refactored to load models once and run multiple times
This commit is contained in:
@@ -26,7 +26,8 @@ class DecoderCrossAttentionTorch(AttentionTorch):
|
||||
|
||||
|
||||
class DecoderSelfAttentionTorch(AttentionTorch):
|
||||
def forward(self,
|
||||
def forward(
|
||||
self,
|
||||
decoder_state: FloatTensor,
|
||||
keys_values: FloatTensor,
|
||||
attention_mask: BoolTensor,
|
||||
@@ -49,7 +50,8 @@ class DecoderSelfAttentionTorch(AttentionTorch):
|
||||
|
||||
|
||||
class DecoderLayerTorch(nn.Module):
|
||||
def __init__(self,
|
||||
def __init__(
|
||||
self,
|
||||
image_token_count: int,
|
||||
head_count: int,
|
||||
embed_count: int,
|
||||
@@ -69,7 +71,8 @@ class DecoderLayerTorch(nn.Module):
|
||||
if torch.cuda.is_available():
|
||||
self.token_indices = self.token_indices.cuda()
|
||||
|
||||
def forward(self,
|
||||
def forward(
|
||||
self,
|
||||
decoder_state: FloatTensor,
|
||||
encoder_state: FloatTensor,
|
||||
keys_values_state: FloatTensor,
|
||||
@@ -111,7 +114,8 @@ class DecoderLayerTorch(nn.Module):
|
||||
|
||||
|
||||
class DalleBartDecoderTorch(nn.Module):
|
||||
def __init__(self,
|
||||
def __init__(
|
||||
self,
|
||||
image_vocab_size: int,
|
||||
image_token_count: int,
|
||||
sample_token_count: int,
|
||||
@@ -158,7 +162,8 @@ class DalleBartDecoderTorch(nn.Module):
|
||||
self.start_token = self.start_token.cuda()
|
||||
|
||||
|
||||
def decode_step(self,
|
||||
def decode_step(
|
||||
self,
|
||||
text_tokens: LongTensor,
|
||||
encoder_state: FloatTensor,
|
||||
keys_values_state: FloatTensor,
|
||||
@@ -198,7 +203,8 @@ class DalleBartDecoderTorch(nn.Module):
|
||||
return probs, keys_values
|
||||
|
||||
|
||||
def forward(self,
|
||||
def forward(
|
||||
self,
|
||||
text_tokens: LongTensor,
|
||||
encoder_state: FloatTensor
|
||||
) -> LongTensor:
|
||||
|
||||
Reference in New Issue
Block a user