Pre-converting params to torch allows mega to run in the standard Colab runtime
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
import os
|
||||
import numpy
|
||||
from copy import deepcopy
|
||||
from typing import Dict
|
||||
from flax.traverse_util import flatten_dict
|
||||
from flax.serialization import msgpack_restore
|
||||
@@ -105,4 +104,29 @@ def convert_dalle_bart_torch_from_flax_params(
|
||||
|
||||
P['embed_tokens.weight'] = P.pop('embed_tokens.embedding')
|
||||
P['embed_positions.weight'] = P.pop('embed_positions.embedding')
|
||||
return P
|
||||
return P
|
||||
|
||||
|
||||
def convert_and_save_mega_torch_params(is_mega: bool, model_path: str):
    """Convert the flax checkpoint at `model_path` into float16 torch files.

    Loads the flax parameters via `load_dalle_bart_flax_params`, converts the
    'encoder' and 'decoder' subtrees with
    `convert_dalle_bart_torch_from_flax_params`, casts every converted entry
    to `torch.float16`, and saves the two resulting dicts as `encoder.pt` and
    `decoder.pt` inside `model_path`.

    Args:
        is_mega: Selects the layer count handed to the converter: 24 when
            true, 12 otherwise.
        model_path: Path passed to `load_dalle_bart_flax_params`; the output
            files are written beneath the same path.
    """
    print("converting params to torch")

    if is_mega:
        layer_count = 24
    else:
        layer_count = 12

    flax_params = load_dalle_bart_flax_params(model_path)

    encoder_params = convert_dalle_bart_torch_from_flax_params(
        flax_params['encoder'],
        layer_count=layer_count,
        is_encoder=True
    )
    decoder_params = convert_dalle_bart_torch_from_flax_params(
        flax_params['decoder'],
        layer_count=layer_count,
        is_encoder=False
    )

    # Cast the decoder first, then the encoder. Entries are reassigned in
    # place, key by key, rather than collected into a second dict --
    # presumably so both precisions are not held for a whole dict at once.
    for params in (decoder_params, encoder_params):
        for name, tensor in params.items():
            params[name] = tensor.to(torch.float16)

    encoder_path = os.path.join(model_path, 'encoder.pt')
    decoder_path = os.path.join(model_path, 'decoder.pt')
    torch.save(encoder_params, encoder_path)
    torch.save(decoder_params, decoder_path)
|
||||
Reference in New Issue
Block a user