works with cuda
This commit is contained in:
@@ -4,7 +4,7 @@ from copy import deepcopy
|
||||
from typing import Dict
|
||||
from flax import traverse_util, serialization
|
||||
import torch
|
||||
torch.no_grad()
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
|
||||
def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
|
||||
@@ -30,7 +30,6 @@ def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
|
||||
|
||||
for i in P:
|
||||
P[i] = torch.tensor(P[i])
|
||||
# if torch.cuda.is_available(): P[i] = P[i].cuda()
|
||||
|
||||
P['embedding.weight'] = P.pop('quantize.embedding.embedding')
|
||||
|
||||
@@ -87,7 +86,6 @@ def convert_dalle_bart_torch_from_flax_params(
|
||||
|
||||
for i in P:
|
||||
P[i] = torch.tensor(P[i])
|
||||
# if torch.cuda.is_available(): P[i] = P[i].cuda()
|
||||
|
||||
for i in list(P):
|
||||
if 'kernel' in i:
|
||||
|
||||
Reference in New Issue
Block a user