use cuda if available
This commit is contained in:
@@ -73,7 +73,6 @@ def decode_torch(
|
||||
|
||||
print("sampling image tokens")
|
||||
torch.manual_seed(seed)
|
||||
text_tokens = torch.tensor(text_tokens).to(torch.long)
|
||||
image_tokens = decoder.forward(text_tokens, encoder_state)
|
||||
return image_tokens
|
||||
|
||||
@@ -84,10 +83,9 @@ def generate_image_tokens_torch(
|
||||
config: dict,
|
||||
params: dict,
|
||||
image_token_count: int
|
||||
) -> numpy.ndarray:
|
||||
) -> LongTensor:
|
||||
text_tokens = torch.tensor(text_tokens).to(torch.long)
|
||||
if torch.cuda.is_available():
|
||||
text_tokens = text_tokens.cuda()
|
||||
if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
|
||||
encoder_state = encode_torch(
|
||||
text_tokens,
|
||||
config,
|
||||
@@ -101,16 +99,15 @@ def generate_image_tokens_torch(
|
||||
params,
|
||||
image_token_count
|
||||
)
|
||||
return image_tokens.detach().numpy()
|
||||
return image_tokens
|
||||
|
||||
|
||||
def detokenize_torch(image_tokens: numpy.ndarray) -> numpy.ndarray:
|
||||
def detokenize_torch(image_tokens: LongTensor) -> numpy.ndarray:
|
||||
print("detokenizing image")
|
||||
model_path = './pretrained/vqgan'
|
||||
params = load_vqgan_torch_params(model_path)
|
||||
detokenizer = VQGanDetokenizer()
|
||||
detokenizer.load_state_dict(params)
|
||||
image_tokens = torch.tensor(image_tokens).to(torch.long)
|
||||
image = detokenizer.forward(image_tokens).to(torch.uint8)
|
||||
return image.detach().numpy()
|
||||
|
||||
Reference in New Issue
Block a user