remove config.json dependency, default to torch in image_from_text.py

This commit is contained in:
Brett Kuprel
2022-07-01 12:03:37 -04:00
parent 4404e70764
commit 85f5866eff
10 changed files with 52 additions and 64 deletions
+14 -15
View File
@@ -37,12 +37,12 @@ class MinDalleTorch(MinDalleBase):
def init_encoder(self):
print("initializing DalleBartEncoderTorch")
self.encoder = DalleBartEncoderTorch(
layer_count = self.config['encoder_layers'],
embed_count = self.config['d_model'],
attention_head_count = self.config['encoder_attention_heads'],
text_vocab_count = self.config['encoder_vocab_size'],
text_token_count = self.config['max_text_length'],
glu_embed_count = self.config['encoder_ffn_dim']
attention_head_count = 32 if self.is_mega else 16,
embed_count = 2048 if self.is_mega else 1024,
glu_embed_count = 4096 if self.is_mega else 2730,
text_token_count = 64,
text_vocab_count = 50272 if self.is_mega else 50264,
layer_count = 24 if self.is_mega else 12
)
params = torch.load(self.encoder_params_path)
self.encoder.load_state_dict(params, strict=False)
@@ -53,16 +53,15 @@ class MinDalleTorch(MinDalleBase):
def init_decoder(self):
print("initializing DalleBartDecoderTorch")
self.decoder = DalleBartDecoderTorch(
image_vocab_size = self.config['image_vocab_size'],
image_token_count = self.config['image_length'],
sample_token_count = self.token_count,
embed_count = self.config['d_model'],
attention_head_count = self.config['decoder_attention_heads'],
glu_embed_count = self.config['decoder_ffn_dim'],
layer_count = self.config['decoder_layers'],
batch_count = 2,
start_token = self.config['decoder_start_token_id'],
is_verbose = True
image_token_count = 256,
image_vocab_count = 16415 if self.is_mega else 16384,
attention_head_count = 32 if self.is_mega else 16,
embed_count = 2048 if self.is_mega else 1024,
glu_embed_count = 4096 if self.is_mega else 2730,
layer_count = 24 if self.is_mega else 12,
start_token = 16415 if self.is_mega else 16384,
batch_count = 2
)
params = torch.load(self.decoder_params_path)
self.decoder.load_state_dict(params, strict=False)