remove config.json dependency, default to torch in image_from_text.py

This commit is contained in:
Brett Kuprel
2022-07-01 12:03:37 -04:00
parent 4404e70764
commit 85f5866eff
10 changed files with 52 additions and 64 deletions
+13 -14
View File
@@ -26,26 +26,25 @@ class MinDalleFlax(MinDalleBase):
def init_encoder(self):
print("initializing DalleBartEncoderFlax")
self.encoder: DalleBartEncoderFlax = DalleBartEncoderFlax(
attention_head_count = self.config['encoder_attention_heads'],
embed_count = self.config['d_model'],
glu_embed_count = self.config['encoder_ffn_dim'],
text_token_count = self.config['max_text_length'],
text_vocab_count = self.config['encoder_vocab_size'],
layer_count = self.config['encoder_layers']
attention_head_count = 32 if self.is_mega else 16,
embed_count = 2048 if self.is_mega else 1024,
glu_embed_count = 4096 if self.is_mega else 2730,
text_token_count = 64,
text_vocab_count = 50272 if self.is_mega else 50264,
layer_count = 24 if self.is_mega else 12
).bind({'params': self.model_params.pop('encoder')})
def init_decoder(self):
print("initializing DalleBartDecoderFlax")
self.decoder = DalleBartDecoderFlax(
image_token_count = self.config['image_length'],
text_token_count = self.config['max_text_length'],
image_vocab_count = self.config['image_vocab_size'],
attention_head_count = self.config['decoder_attention_heads'],
embed_count = self.config['d_model'],
glu_embed_count = self.config['decoder_ffn_dim'],
layer_count = self.config['decoder_layers'],
start_token = self.config['decoder_start_token_id']
image_token_count = 256,
image_vocab_count = 16415 if self.is_mega else 16384,
attention_head_count = 32 if self.is_mega else 16,
embed_count = 2048 if self.is_mega else 1024,
glu_embed_count = 4096 if self.is_mega else 2730,
layer_count = 24 if self.is_mega else 12,
start_token = 16415 if self.is_mega else 16384
)