works with cuda
This commit is contained in:
@@ -2,7 +2,7 @@ import numpy
|
||||
from typing import Dict
|
||||
from torch import LongTensor, FloatTensor
|
||||
import torch
|
||||
torch.no_grad()
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
from .models.vqgan_detokenizer import VQGanDetokenizer
|
||||
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
|
||||
@@ -35,6 +35,7 @@ def encode_torch(
|
||||
)
|
||||
encoder.load_state_dict(encoder_params, strict=False)
|
||||
del encoder_params
|
||||
if torch.cuda.is_available(): encoder = encoder.cuda()
|
||||
|
||||
print("encoding text tokens")
|
||||
encoder_state = encoder(text_tokens)
|
||||
@@ -70,6 +71,7 @@ def decode_torch(
|
||||
)
|
||||
decoder.load_state_dict(decoder_params, strict=False)
|
||||
del decoder_params
|
||||
if torch.cuda.is_available(): decoder = decoder.cuda()
|
||||
|
||||
print("sampling image tokens")
|
||||
torch.manual_seed(seed)
|
||||
@@ -85,7 +87,7 @@ def generate_image_tokens_torch(
|
||||
image_token_count: int
|
||||
) -> LongTensor:
|
||||
text_tokens = torch.tensor(text_tokens).to(torch.long)
|
||||
# if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
|
||||
if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
|
||||
encoder_state = encode_torch(
|
||||
text_tokens,
|
||||
config,
|
||||
@@ -108,6 +110,8 @@ def detokenize_torch(image_tokens: LongTensor) -> numpy.ndarray:
|
||||
params = load_vqgan_torch_params(model_path)
|
||||
detokenizer = VQGanDetokenizer()
|
||||
detokenizer.load_state_dict(params)
|
||||
if torch.cuda.is_available(): detokenizer = detokenizer.cuda()
|
||||
image = detokenizer.forward(image_tokens).to(torch.uint8)
|
||||
return image.detach().numpy()
|
||||
del detokenizer, params
|
||||
return image.to('cpu').detach().numpy()
|
||||
|
||||
Reference in New Issue
Block a user