use cuda if available

This commit is contained in:
Brett Kuprel
2022-06-28 12:47:11 -04:00
parent 8544f59576
commit 5aa6fe49bf
3 changed files with 13 additions and 8 deletions
+4 -1
View File
@@ -2,8 +2,9 @@ import os
import numpy
from copy import deepcopy
from typing import Dict
import torch
from flax import traverse_util, serialization
import torch
torch.no_grad()
def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
@@ -29,6 +30,7 @@ def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
for i in P:
P[i] = torch.tensor(P[i])
if torch.cuda.is_available(): P[i] = P[i].cuda()
P['embedding.weight'] = P.pop('quantize.embedding.embedding')
@@ -85,6 +87,7 @@ def convert_dalle_bart_torch_from_flax_params(
for i in P:
P[i] = torch.tensor(P[i])
if torch.cuda.is_available(): P[i] = P[i].cuda()
for i in list(P):
if 'kernel' in i: