works with cuda

This commit is contained in:
Brett Kuprel
2022-06-28 21:28:36 -04:00
parent 9d6b6dcc92
commit 17c96fe110
6 changed files with 43 additions and 33 deletions
+2
View File
@@ -65,6 +65,8 @@ def generate_image_from_text(
if image_token_count == config['image_length']:
image = detokenize_torch(image_tokens)
return Image.fromarray(image)
else:
print(list(image_tokens.to('cpu').detach().numpy()))
else:
image_tokens = generate_image_tokens_flax(
text_tokens = text_tokens,