Clamp previous image tokens to the valid embedding index range [0, image_vocab_count]
This commit is contained in:
@@ -117,6 +117,7 @@ class DalleBartDecoder(nn.Module):
|
||||
super().__init__()
|
||||
self.layer_count = layer_count
|
||||
self.embed_count = embed_count
|
||||
self.image_vocab_count = image_vocab_count
|
||||
self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
|
||||
self.embed_positions = nn.Embedding(IMAGE_TOKEN_COUNT, embed_count)
|
||||
self.layers: List[DecoderLayer] = nn.ModuleList([
|
||||
@@ -152,6 +153,7 @@ class DalleBartDecoder(nn.Module):
|
||||
image_count = encoder_state.shape[0] // 2
|
||||
token_index_batched = token_index[[0] * image_count * 2]
|
||||
prev_tokens = prev_tokens[list(range(image_count)) * 2]
|
||||
prev_tokens = prev_tokens.clamp(0, self.image_vocab_count)
|
||||
decoder_state = self.embed_tokens.forward(prev_tokens)
|
||||
decoder_state += self.embed_positions.forward(token_index_batched)
|
||||
decoder_state = self.layernorm_embedding.forward(decoder_state)
|
||||
|
||||
Reference in New Issue
Block a user