save converted detokenizer params
This commit is contained in:
@@ -6,8 +6,9 @@ import torch
|
||||
from .min_dalle_base import MinDalleBase
|
||||
from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax
|
||||
from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax
|
||||
from .models.vqgan_detokenizer import VQGanDetokenizer
|
||||
|
||||
from .load_params import load_dalle_bart_flax_params
|
||||
from .load_params import load_dalle_bart_flax_params, load_vqgan_torch_params
|
||||
|
||||
|
||||
class MinDalleFlax(MinDalleBase):
|
||||
@@ -32,7 +33,7 @@ class MinDalleFlax(MinDalleBase):
|
||||
text_vocab_count = self.config['encoder_vocab_size'],
|
||||
layer_count = self.config['encoder_layers']
|
||||
).bind({'params': self.model_params.pop('encoder')})
|
||||
|
||||
|
||||
|
||||
def init_decoder(self):
|
||||
print("initializing DalleBartDecoderFlax")
|
||||
@@ -46,7 +47,14 @@ class MinDalleFlax(MinDalleBase):
|
||||
layer_count = self.config['decoder_layers'],
|
||||
start_token = self.config['decoder_start_token_id']
|
||||
)
|
||||
|
||||
|
||||
|
||||
def init_detokenizer(self):
|
||||
print("initializing VQGanDetokenizer")
|
||||
params = load_vqgan_torch_params('./pretrained/vqgan')
|
||||
self.detokenizer = VQGanDetokenizer()
|
||||
self.detokenizer.load_state_dict(params)
|
||||
del params
|
||||
|
||||
def generate_image(self, text: str, seed: int) -> Image.Image:
|
||||
text_tokens = self.tokenize_text(text)
|
||||
|
||||
Reference in New Issue
Block a user