clamp tokens to appropriate bounds

This commit is contained in:
Brett Kuprel
2022-07-08 09:19:28 -04:00
parent e409c120d0
commit 9eb5633931
5 changed files with 13 additions and 7 deletions
+2
View File
@@ -119,6 +119,7 @@ class DalleBartEncoder(nn.Module):
glu_embed_count: int
):
super().__init__()
self.text_vocab_count = text_vocab_count
self.embed_tokens = nn.Embedding(text_vocab_count, embed_count)
self.embed_positions = nn.Embedding(text_token_count, embed_count)
self.layers: List[EncoderLayer] = nn.ModuleList([
@@ -138,6 +139,7 @@ class DalleBartEncoder(nn.Module):
def forward(self, text_tokens: LongTensor) -> FloatTensor:
attention_mask = text_tokens.not_equal(1)
pose_tokens = self.token_indices[None][[0] * text_tokens.shape[0]]
text_tokens = text_tokens.clamp(0, self.text_vocab_count - 1)
encoder_state = (
self.embed_tokens.forward(text_tokens) +
self.embed_positions.forward(pose_tokens)