0.3.13, simplified code, specify device when initializing MinDalle
This commit is contained in:
@@ -4,7 +4,7 @@ from torch import nn, BoolTensor, FloatTensor, LongTensor
|
||||
|
||||
|
||||
class GLU(nn.Module):
|
||||
def __init__(self, count_in_out, count_middle):
|
||||
def __init__(self, count_in_out: int, count_middle: int):
|
||||
super().__init__()
|
||||
self.gelu = nn.GELU()
|
||||
self.ln0 = nn.LayerNorm(count_in_out)
|
||||
@@ -33,8 +33,6 @@ class AttentionBase(nn.Module):
|
||||
self.v_proj = nn.Linear(embed_count, embed_count, bias=False)
|
||||
self.q_proj = nn.Linear(embed_count, embed_count, bias=False)
|
||||
self.out_proj = nn.Linear(embed_count, embed_count, bias=False)
|
||||
self.one = torch.ones((1, 1))
|
||||
if torch.cuda.is_available(): self.one = self.one.cuda()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -48,11 +46,7 @@ class AttentionBase(nn.Module):
|
||||
queries = queries.reshape(queries.shape[:2] + (self.head_count, -1))
|
||||
queries /= queries.shape[-1] ** 0.5
|
||||
|
||||
attention_bias = torch.where(
|
||||
attention_mask,
|
||||
self.one * 0,
|
||||
self.one * (-torch.inf),
|
||||
)
|
||||
attention_bias = (1 - attention_mask.to(torch.float32)) * -1e12
|
||||
attention_weights: FloatTensor = torch.einsum(
|
||||
'bqhc,bkhc->bhqk',
|
||||
queries,
|
||||
@@ -115,7 +109,8 @@ class DalleBartEncoder(nn.Module):
|
||||
attention_head_count: int,
|
||||
text_vocab_count: int,
|
||||
text_token_count: int,
|
||||
glu_embed_count: int
|
||||
glu_embed_count: int,
|
||||
device: str
|
||||
):
|
||||
super().__init__()
|
||||
self.text_vocab_count = text_vocab_count
|
||||
@@ -131,17 +126,14 @@ class DalleBartEncoder(nn.Module):
|
||||
])
|
||||
self.layernorm_embedding = nn.LayerNorm(embed_count)
|
||||
self.final_ln = nn.LayerNorm(embed_count)
|
||||
self.token_indices = torch.arange(text_token_count).to(torch.long)
|
||||
if torch.cuda.is_available():
|
||||
self.token_indices = self.token_indices.cuda()
|
||||
token_indices = torch.arange(text_token_count, device=device)
|
||||
self.pose_tokens = torch.stack([token_indices] * 2)
|
||||
|
||||
def forward(self, text_tokens: LongTensor) -> FloatTensor:
|
||||
attention_mask = text_tokens.not_equal(1)
|
||||
pose_tokens = self.token_indices[None][[0] * text_tokens.shape[0]]
|
||||
text_tokens.clamp_(0, self.text_vocab_count - 1)
|
||||
encoder_state = (
|
||||
self.embed_tokens.forward(text_tokens) +
|
||||
self.embed_positions.forward(pose_tokens)
|
||||
self.embed_positions.forward(self.pose_tokens)
|
||||
)
|
||||
encoder_state = self.layernorm_embedding.forward(encoder_state)
|
||||
for layer in self.layers:
|
||||
|
||||
Reference in New Issue
Block a user