previous commit broke flax model, fixed now
This commit is contained in:
@@ -30,7 +30,7 @@ def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
|
||||
|
||||
for i in P:
|
||||
P[i] = torch.tensor(P[i])
|
||||
if torch.cuda.is_available(): P[i] = P[i].cuda()
|
||||
# if torch.cuda.is_available(): P[i] = P[i].cuda()
|
||||
|
||||
P['embedding.weight'] = P.pop('quantize.embedding.embedding')
|
||||
|
||||
@@ -87,7 +87,7 @@ def convert_dalle_bart_torch_from_flax_params(
|
||||
|
||||
for i in P:
|
||||
P[i] = torch.tensor(P[i])
|
||||
if torch.cuda.is_available(): P[i] = P[i].cuda()
|
||||
# if torch.cuda.is_available(): P[i] = P[i].cuda()
|
||||
|
||||
for i in list(P):
|
||||
if 'kernel' in i:
|
||||
|
||||
Reference in New Issue
Block a user