clamp tokens to appropriate bounds
This commit is contained in:
@@ -119,6 +119,7 @@ class DalleBartEncoder(nn.Module):
|
||||
glu_embed_count: int
|
||||
):
|
||||
super().__init__()
|
||||
self.text_vocab_count = text_vocab_count
|
||||
self.embed_tokens = nn.Embedding(text_vocab_count, embed_count)
|
||||
self.embed_positions = nn.Embedding(text_token_count, embed_count)
|
||||
self.layers: List[EncoderLayer] = nn.ModuleList([
|
||||
@@ -138,6 +139,7 @@ class DalleBartEncoder(nn.Module):
|
||||
def forward(self, text_tokens: LongTensor) -> FloatTensor:
|
||||
attention_mask = text_tokens.not_equal(1)
|
||||
pose_tokens = self.token_indices[None][[0] * text_tokens.shape[0]]
|
||||
text_tokens = text_tokens.clamp(0, self.text_vocab_count - 1)
|
||||
encoder_state = (
|
||||
self.embed_tokens.forward(text_tokens) +
|
||||
self.embed_positions.forward(pose_tokens)
|
||||
|
||||
Reference in New Issue
Block a user