pre converting params to torch allows mega to run in standard colab runtime

This commit is contained in:
Brett Kuprel
2022-06-30 14:54:08 -04:00
parent de97fcf06b
commit b913b58353
5 changed files with 54 additions and 21 deletions
+5 -6
View File
@@ -10,12 +10,12 @@ class MinDalleBase:
def __init__(self, is_mega: bool):
self.is_mega = is_mega
model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
model_path = os.path.join('pretrained', model_name)
self.model_path = os.path.join('pretrained', model_name)
print("reading files from {}".format(model_path))
config_path = os.path.join(model_path, 'config.json')
vocab_path = os.path.join(model_path, 'vocab.json')
merges_path = os.path.join(model_path, 'merges.txt')
print("reading files from {}".format(self.model_path))
config_path = os.path.join(self.model_path, 'config.json')
vocab_path = os.path.join(self.model_path, 'vocab.json')
merges_path = os.path.join(self.model_path, 'merges.txt')
with open(config_path, 'r', encoding='utf8') as f:
self.config = json.load(f)
@@ -24,7 +24,6 @@ class MinDalleBase:
with open(merges_path, 'r', encoding='utf8') as f:
merges = f.read().split("\n")[1:-1]
self.model_params = load_dalle_bart_flax_params(model_path)
self.tokenizer = TextTokenizer(vocab, merges)