save converted detokenizer params

This commit is contained in:
Brett Kuprel
2022-07-01 10:17:29 -04:00
parent 8b5960b687
commit e4c2be54cb
7 changed files with 35 additions and 32 deletions
+2 -11
View File
@@ -3,7 +3,6 @@ import json
import numpy
from .text_tokenizer import TextTokenizer
from .load_params import load_vqgan_torch_params, load_dalle_bart_flax_params
from .models.vqgan_detokenizer import VQGanDetokenizer
class MinDalleBase:
@@ -27,20 +26,12 @@ class MinDalleBase:
self.tokenizer = TextTokenizer(vocab, merges)
def init_detokenizer(self):
print("initializing VQGanDetokenizer")
params = load_vqgan_torch_params('./pretrained/vqgan')
self.detokenizer = VQGanDetokenizer()
self.detokenizer.load_state_dict(params)
del params
def tokenize_text(self, text: str) -> numpy.ndarray:
print("tokenizing text")
tokens = self.tokenizer.tokenize(text)
print("text tokens", tokens)
text_token_count = self.config['max_text_length']
text_tokens = numpy.ones((2, text_token_count), dtype=numpy.int32)
text_tokens[0, :len(tokens)] = tokens
text_tokens[1, :2] = [tokens[0], tokens[-1]]
text_tokens[0, :2] = [tokens[0], tokens[-1]]
text_tokens[1, :len(tokens)] = tokens
return text_tokens