remove config.json dependency, default to torch in image_from_text.py
This commit is contained in:
+13
-14
@@ -26,26 +26,25 @@ class MinDalleFlax(MinDalleBase):
|
||||
def init_encoder(self):
|
||||
print("initializing DalleBartEncoderFlax")
|
||||
self.encoder: DalleBartEncoderFlax = DalleBartEncoderFlax(
|
||||
attention_head_count = self.config['encoder_attention_heads'],
|
||||
embed_count = self.config['d_model'],
|
||||
glu_embed_count = self.config['encoder_ffn_dim'],
|
||||
text_token_count = self.config['max_text_length'],
|
||||
text_vocab_count = self.config['encoder_vocab_size'],
|
||||
layer_count = self.config['encoder_layers']
|
||||
attention_head_count = 32 if self.is_mega else 16,
|
||||
embed_count = 2048 if self.is_mega else 1024,
|
||||
glu_embed_count = 4096 if self.is_mega else 2730,
|
||||
text_token_count = 64,
|
||||
text_vocab_count = 50272 if self.is_mega else 50264,
|
||||
layer_count = 24 if self.is_mega else 12
|
||||
).bind({'params': self.model_params.pop('encoder')})
|
||||
|
||||
|
||||
def init_decoder(self):
|
||||
print("initializing DalleBartDecoderFlax")
|
||||
self.decoder = DalleBartDecoderFlax(
|
||||
image_token_count = self.config['image_length'],
|
||||
text_token_count = self.config['max_text_length'],
|
||||
image_vocab_count = self.config['image_vocab_size'],
|
||||
attention_head_count = self.config['decoder_attention_heads'],
|
||||
embed_count = self.config['d_model'],
|
||||
glu_embed_count = self.config['decoder_ffn_dim'],
|
||||
layer_count = self.config['decoder_layers'],
|
||||
start_token = self.config['decoder_start_token_id']
|
||||
image_token_count = 256,
|
||||
image_vocab_count = 16415 if self.is_mega else 16384,
|
||||
attention_head_count = 32 if self.is_mega else 16,
|
||||
embed_count = 2048 if self.is_mega else 1024,
|
||||
glu_embed_count = 4096 if self.is_mega else 2730,
|
||||
layer_count = 24 if self.is_mega else 12,
|
||||
start_token = 16415 if self.is_mega else 16384
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user