save converted detokenizer params
This commit is contained in:
@@ -6,13 +6,11 @@ import torch
|
||||
torch.set_grad_enabled(False)
|
||||
torch.set_num_threads(os.cpu_count())
|
||||
|
||||
from .load_params import (
|
||||
convert_and_save_torch_params,
|
||||
load_dalle_bart_flax_params
|
||||
)
|
||||
from .load_params import convert_and_save_torch_params
|
||||
from .min_dalle_base import MinDalleBase
|
||||
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
|
||||
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
|
||||
from .models.vqgan_detokenizer import VQGanDetokenizer
|
||||
|
||||
|
||||
class MinDalleTorch(MinDalleBase):
|
||||
@@ -26,15 +24,14 @@ class MinDalleTorch(MinDalleBase):
|
||||
super().__init__(is_mega)
|
||||
self.is_reusable = is_reusable
|
||||
self.token_count = token_count
|
||||
|
||||
if not is_mega:
|
||||
self.model_params = load_dalle_bart_flax_params(self.model_path)
|
||||
|
||||
self.encoder_params_path = os.path.join(self.model_path, 'encoder.pt')
|
||||
self.decoder_params_path = os.path.join(self.model_path, 'decoder.pt')
|
||||
self.detoker_params_path = os.path.join('pretrained', 'vqgan', 'detokenizer.pt')
|
||||
|
||||
is_converted = os.path.exists(self.encoder_params_path)
|
||||
is_converted &= os.path.exists(self.decoder_params_path)
|
||||
is_converted &= os.path.exists(self.detoker_params_path)
|
||||
if not is_converted:
|
||||
convert_and_save_torch_params(is_mega, self.model_path)
|
||||
|
||||
@@ -79,11 +76,14 @@ class MinDalleTorch(MinDalleBase):
|
||||
del params
|
||||
if torch.cuda.is_available(): self.decoder = self.decoder.cuda()
|
||||
|
||||
|
||||
|
||||
def init_detokenizer(self):
|
||||
super().init_detokenizer()
|
||||
if torch.cuda.is_available():
|
||||
self.detokenizer = self.detokenizer.cuda()
|
||||
print("initializing VQGanDetokenizer")
|
||||
self.detokenizer = VQGanDetokenizer()
|
||||
params = torch.load(self.detoker_params_path)
|
||||
self.detokenizer.load_state_dict(params)
|
||||
del params
|
||||
if torch.cuda.is_available(): self.detokenizer = self.detokenizer.cuda()
|
||||
|
||||
|
||||
def generate_image_tokens(self, text: str, seed: int) -> LongTensor:
|
||||
|
||||
Reference in New Issue
Block a user