sort -> topk, prev_token_and_index -> prev_token, token_index

This commit is contained in:
Brett Kuprel
2022-06-30 09:04:11 -04:00
parent fb97ba5e20
commit df9aa6f915
3 changed files with 21 additions and 14 deletions
+6 -5
View File
@@ -2,16 +2,17 @@ import os
import numpy
from copy import deepcopy
from typing import Dict
from flax import traverse_util, serialization
from flax.traverse_util import flatten_dict
from flax.serialization import msgpack_restore
import torch
torch.set_grad_enabled(False)
def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
with open(os.path.join(path, 'flax_model.msgpack'), "rb") as f:
params: Dict[str, numpy.ndarray] = serialization.msgpack_restore(f.read())
params: Dict[str, numpy.ndarray] = msgpack_restore(f.read())
P: Dict[str, numpy.ndarray] = traverse_util.flatten_dict(params, sep='.')
P: Dict[str, numpy.ndarray] = flatten_dict(params, sep='.')
for i in list(P.keys()):
j = i
@@ -42,7 +43,7 @@ def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
def load_dalle_bart_flax_params(path: str) -> Dict[str, numpy.ndarray]:
with open(os.path.join(path, "flax_model.msgpack"), "rb") as f:
params = serialization.msgpack_restore(f.read())
params = msgpack_restore(f.read())
for codec in ['encoder', 'decoder']:
k = 'FlaxBart{}Layers'.format(codec.title())
@@ -82,7 +83,7 @@ def convert_dalle_bart_torch_from_flax_params(
is_encoder: bool
) -> dict:
P = deepcopy(params)
P: Dict[str, numpy.ndarray] = traverse_util.flatten_dict(P, sep='.')
P: Dict[str, numpy.ndarray] = flatten_dict(P, sep='.')
for i in P:
P[i] = torch.tensor(P[i])