fixed bug with cuda in detokenizer

This commit is contained in:
Brett Kuprel
2022-06-28 22:02:35 -04:00
parent 764b0bc685
commit 1fbb209623
3 changed files with 5 additions and 5 deletions
+2 -2
View File
@@ -104,13 +104,13 @@ def generate_image_tokens_torch(
return image_tokens
def detokenize_torch(image_tokens: LongTensor) -> numpy.ndarray:
def detokenize_torch(image_tokens: LongTensor, is_torch: bool) -> numpy.ndarray:
print("detokenizing image")
model_path = './pretrained/vqgan'
params = load_vqgan_torch_params(model_path)
detokenizer = VQGanDetokenizer()
detokenizer.load_state_dict(params)
# if torch.cuda.is_available(): detokenizer = detokenizer.cuda()
if torch.cuda.is_available() and is_torch: detokenizer = detokenizer.cuda()
image = detokenizer.forward(image_tokens).to(torch.uint8)
del detokenizer, params
return image.to('cpu').detach().numpy()