works with cuda

This commit is contained in:
Brett Kuprel
2022-06-28 21:28:36 -04:00
parent 9d6b6dcc92
commit 17c96fe110
6 changed files with 43 additions and 33 deletions
+1 -3
View File
@@ -4,7 +4,7 @@ from copy import deepcopy
from typing import Dict
from flax import traverse_util, serialization
import torch
torch.no_grad()
torch.set_grad_enabled(False)
def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
@@ -30,7 +30,6 @@ def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
for i in P:
P[i] = torch.tensor(P[i])
# if torch.cuda.is_available(): P[i] = P[i].cuda()
P['embedding.weight'] = P.pop('quantize.embedding.embedding')
@@ -87,7 +86,6 @@ def convert_dalle_bart_torch_from_flax_params(
for i in P:
P[i] = torch.tensor(P[i])
# if torch.cuda.is_available(): P[i] = P[i].cuda()
for i in list(P):
if 'kernel' in i: