works with cuda

This commit is contained in:
Brett Kuprel
2022-06-28 21:28:36 -04:00
parent 9d6b6dcc92
commit 17c96fe110
6 changed files with 43 additions and 33 deletions
+7 -3
View File
@@ -2,7 +2,7 @@ import numpy
from typing import Dict
from torch import LongTensor, FloatTensor
import torch
torch.no_grad()
torch.set_grad_enabled(False)
from .models.vqgan_detokenizer import VQGanDetokenizer
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
@@ -35,6 +35,7 @@ def encode_torch(
)
encoder.load_state_dict(encoder_params, strict=False)
del encoder_params
if torch.cuda.is_available(): encoder = encoder.cuda()
print("encoding text tokens")
encoder_state = encoder(text_tokens)
@@ -70,6 +71,7 @@ def decode_torch(
)
decoder.load_state_dict(decoder_params, strict=False)
del decoder_params
if torch.cuda.is_available(): decoder = decoder.cuda()
print("sampling image tokens")
torch.manual_seed(seed)
@@ -85,7 +87,7 @@ def generate_image_tokens_torch(
image_token_count: int
) -> LongTensor:
text_tokens = torch.tensor(text_tokens).to(torch.long)
# if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
encoder_state = encode_torch(
text_tokens,
config,
@@ -108,6 +110,8 @@ def detokenize_torch(image_tokens: LongTensor) -> numpy.ndarray:
params = load_vqgan_torch_params(model_path)
detokenizer = VQGanDetokenizer()
detokenizer.load_state_dict(params)
if torch.cuda.is_available(): detokenizer = detokenizer.cuda()
image = detokenizer.forward(image_tokens).to(torch.uint8)
return image.detach().numpy()
del detokenizer, params
return image.to('cpu').detach().numpy()