pre converting params to torch allows mega to run in standard colab runtime

This commit is contained in:
Brett Kuprel
2022-06-30 14:54:08 -04:00
parent de97fcf06b
commit b913b58353
5 changed files with 54 additions and 21 deletions
+26 -2
View File
@@ -1,6 +1,5 @@
import os
import numpy
from copy import deepcopy
from typing import Dict
from flax.traverse_util import flatten_dict
from flax.serialization import msgpack_restore
@@ -105,4 +104,29 @@ def convert_dalle_bart_torch_from_flax_params(
P['embed_tokens.weight'] = P.pop('embed_tokens.embedding')
P['embed_positions.weight'] = P.pop('embed_positions.embedding')
return P
return P
def convert_and_save_mega_torch_params(is_mega: bool, model_path: str):
print("converting params to torch")
layer_count = 24 if is_mega else 12
flax_params = load_dalle_bart_flax_params(model_path)
encoder_params = convert_dalle_bart_torch_from_flax_params(
flax_params['encoder'],
layer_count=layer_count,
is_encoder=True
)
decoder_params = convert_dalle_bart_torch_from_flax_params(
flax_params['decoder'],
layer_count=layer_count,
is_encoder=False
)
for i in decoder_params:
decoder_params[i] = decoder_params[i].to(torch.float16)
for i in encoder_params:
encoder_params[i] = encoder_params[i].to(torch.float16)
torch.save(encoder_params, os.path.join(model_path, 'encoder.pt'))
torch.save(decoder_params, os.path.join(model_path, 'decoder.pt'))