remove config.json dependency, default to torch in image_from_text.py
This commit is contained in:
@@ -37,12 +37,12 @@ class MinDalleTorch(MinDalleBase):
|
||||
def init_encoder(self):
|
||||
print("initializing DalleBartEncoderTorch")
|
||||
self.encoder = DalleBartEncoderTorch(
|
||||
layer_count = self.config['encoder_layers'],
|
||||
embed_count = self.config['d_model'],
|
||||
attention_head_count = self.config['encoder_attention_heads'],
|
||||
text_vocab_count = self.config['encoder_vocab_size'],
|
||||
text_token_count = self.config['max_text_length'],
|
||||
glu_embed_count = self.config['encoder_ffn_dim']
|
||||
attention_head_count = 32 if self.is_mega else 16,
|
||||
embed_count = 2048 if self.is_mega else 1024,
|
||||
glu_embed_count = 4096 if self.is_mega else 2730,
|
||||
text_token_count = 64,
|
||||
text_vocab_count = 50272 if self.is_mega else 50264,
|
||||
layer_count = 24 if self.is_mega else 12
|
||||
)
|
||||
params = torch.load(self.encoder_params_path)
|
||||
self.encoder.load_state_dict(params, strict=False)
|
||||
@@ -53,16 +53,15 @@ class MinDalleTorch(MinDalleBase):
|
||||
def init_decoder(self):
|
||||
print("initializing DalleBartDecoderTorch")
|
||||
self.decoder = DalleBartDecoderTorch(
|
||||
image_vocab_size = self.config['image_vocab_size'],
|
||||
image_token_count = self.config['image_length'],
|
||||
sample_token_count = self.token_count,
|
||||
embed_count = self.config['d_model'],
|
||||
attention_head_count = self.config['decoder_attention_heads'],
|
||||
glu_embed_count = self.config['decoder_ffn_dim'],
|
||||
layer_count = self.config['decoder_layers'],
|
||||
batch_count = 2,
|
||||
start_token = self.config['decoder_start_token_id'],
|
||||
is_verbose = True
|
||||
image_token_count = 256,
|
||||
image_vocab_count = 16415 if self.is_mega else 16384,
|
||||
attention_head_count = 32 if self.is_mega else 16,
|
||||
embed_count = 2048 if self.is_mega else 1024,
|
||||
glu_embed_count = 4096 if self.is_mega else 2730,
|
||||
layer_count = 24 if self.is_mega else 12,
|
||||
start_token = 16415 if self.is_mega else 16384,
|
||||
batch_count = 2
|
||||
)
|
||||
params = torch.load(self.decoder_params_path)
|
||||
self.decoder.load_state_dict(params, strict=False)
|
||||
|
||||
Reference in New Issue
Block a user