0.3.13, simplified code, specify device when initializing MinDalle

This commit is contained in:
Brett Kuprel
2022-07-15 17:18:23 -04:00
parent a2dca41623
commit 3c28b1059b
9 changed files with 139 additions and 188 deletions
+7 -15
View File
@@ -4,7 +4,7 @@ from torch import nn, BoolTensor, FloatTensor, LongTensor
class GLU(nn.Module):
def __init__(self, count_in_out, count_middle):
def __init__(self, count_in_out: int, count_middle: int):
super().__init__()
self.gelu = nn.GELU()
self.ln0 = nn.LayerNorm(count_in_out)
@@ -33,8 +33,6 @@ class AttentionBase(nn.Module):
self.v_proj = nn.Linear(embed_count, embed_count, bias=False)
self.q_proj = nn.Linear(embed_count, embed_count, bias=False)
self.out_proj = nn.Linear(embed_count, embed_count, bias=False)
self.one = torch.ones((1, 1))
if torch.cuda.is_available(): self.one = self.one.cuda()
def forward(
self,
@@ -48,11 +46,7 @@ class AttentionBase(nn.Module):
queries = queries.reshape(queries.shape[:2] + (self.head_count, -1))
queries /= queries.shape[-1] ** 0.5
attention_bias = torch.where(
attention_mask,
self.one * 0,
self.one * (-torch.inf),
)
attention_bias = (1 - attention_mask.to(torch.float32)) * -1e12
attention_weights: FloatTensor = torch.einsum(
'bqhc,bkhc->bhqk',
queries,
@@ -115,7 +109,8 @@ class DalleBartEncoder(nn.Module):
attention_head_count: int,
text_vocab_count: int,
text_token_count: int,
glu_embed_count: int
glu_embed_count: int,
device: str
):
super().__init__()
self.text_vocab_count = text_vocab_count
@@ -131,17 +126,14 @@ class DalleBartEncoder(nn.Module):
])
self.layernorm_embedding = nn.LayerNorm(embed_count)
self.final_ln = nn.LayerNorm(embed_count)
self.token_indices = torch.arange(text_token_count).to(torch.long)
if torch.cuda.is_available():
self.token_indices = self.token_indices.cuda()
token_indices = torch.arange(text_token_count, device=device)
self.pose_tokens = torch.stack([token_indices] * 2)
def forward(self, text_tokens: LongTensor) -> FloatTensor:
attention_mask = text_tokens.not_equal(1)
pose_tokens = self.token_indices[None][[0] * text_tokens.shape[0]]
text_tokens.clamp_(0, self.text_vocab_count - 1)
encoder_state = (
self.embed_tokens.forward(text_tokens) +
self.embed_positions.forward(pose_tokens)
self.embed_positions.forward(self.pose_tokens)
)
encoder_state = self.layernorm_embedding.forward(encoder_state)
for layer in self.layers: