sort -> topk, prev_token_and_index -> prev_token, token_index
This commit is contained in:
@@ -2,16 +2,17 @@ import os
|
||||
import numpy
|
||||
from copy import deepcopy
|
||||
from typing import Dict
|
||||
from flax import traverse_util, serialization
|
||||
from flax.traverse_util import flatten_dict
|
||||
from flax.serialization import msgpack_restore
|
||||
import torch
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
|
||||
def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
|
||||
with open(os.path.join(path, 'flax_model.msgpack'), "rb") as f:
|
||||
params: Dict[str, numpy.ndarray] = serialization.msgpack_restore(f.read())
|
||||
params: Dict[str, numpy.ndarray] = msgpack_restore(f.read())
|
||||
|
||||
P: Dict[str, numpy.ndarray] = traverse_util.flatten_dict(params, sep='.')
|
||||
P: Dict[str, numpy.ndarray] = flatten_dict(params, sep='.')
|
||||
|
||||
for i in list(P.keys()):
|
||||
j = i
|
||||
@@ -42,7 +43,7 @@ def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
|
||||
|
||||
def load_dalle_bart_flax_params(path: str) -> Dict[str, numpy.ndarray]:
|
||||
with open(os.path.join(path, "flax_model.msgpack"), "rb") as f:
|
||||
params = serialization.msgpack_restore(f.read())
|
||||
params = msgpack_restore(f.read())
|
||||
|
||||
for codec in ['encoder', 'decoder']:
|
||||
k = 'FlaxBart{}Layers'.format(codec.title())
|
||||
@@ -82,7 +83,7 @@ def convert_dalle_bart_torch_from_flax_params(
|
||||
is_encoder: bool
|
||||
) -> dict:
|
||||
P = deepcopy(params)
|
||||
P: Dict[str, numpy.ndarray] = traverse_util.flatten_dict(P, sep='.')
|
||||
P: Dict[str, numpy.ndarray] = flatten_dict(P, sep='.')
|
||||
|
||||
for i in P:
|
||||
P[i] = torch.tensor(P[i])
|
||||
|
||||
Reference in New Issue
Block a user