previous commit broke flax model, fixed now

This commit is contained in:
Brett Kuprel
2022-06-28 12:54:58 -04:00
parent 5aa6fe49bf
commit 9d6b6dcc92
4 changed files with 16 additions and 17 deletions
+2 -2
View File
@@ -30,7 +30,7 @@ def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
for i in P:
P[i] = torch.tensor(P[i])
if torch.cuda.is_available(): P[i] = P[i].cuda()
# if torch.cuda.is_available(): P[i] = P[i].cuda()
P['embedding.weight'] = P.pop('quantize.embedding.embedding')
@@ -87,7 +87,7 @@ def convert_dalle_bart_torch_from_flax_params(
for i in P:
P[i] = torch.tensor(P[i])
if torch.cuda.is_available(): P[i] = P[i].cuda()
# if torch.cuda.is_available(): P[i] = P[i].cuda()
for i in list(P):
if 'kernel' in i: