works with cuda
This commit is contained in:
@@ -65,6 +65,8 @@ def generate_image_from_text(
|
||||
if image_token_count == config['image_length']:
|
||||
image = detokenize_torch(image_tokens)
|
||||
return Image.fromarray(image)
|
||||
else:
|
||||
print(list(image_tokens.to('cpu').detach().numpy()))
|
||||
else:
|
||||
image_tokens = generate_image_tokens_flax(
|
||||
text_tokens = text_tokens,
|
||||
|
||||
Reference in New Issue
Block a user