updated readme
This commit is contained in:
@@ -107,7 +107,7 @@ def convert_dalle_bart_torch_from_flax_params(
|
||||
return P
|
||||
|
||||
|
||||
def convert_and_save_mega_torch_params(is_mega: bool, model_path: str):
|
||||
def convert_and_save_torch_params(is_mega: bool, model_path: str):
|
||||
print("converting params to torch")
|
||||
layer_count = 24 if is_mega else 12
|
||||
flax_params = load_dalle_bart_flax_params(model_path)
|
||||
|
||||
Reference in New Issue
Block a user