previous commit broke flax model, fixed now
This commit is contained in:
@@ -3,6 +3,7 @@ import json
|
||||
import numpy
|
||||
from PIL import Image
|
||||
from typing import Tuple, List
|
||||
import torch
|
||||
|
||||
from min_dalle.load_params import load_dalle_bart_flax_params
|
||||
from min_dalle.text_tokenizer import TextTokenizer
|
||||
@@ -53,25 +54,23 @@ def generate_image_from_text(
|
||||
text_tokens = tokenize_text(text, config, vocab, merges)
|
||||
params_dalle_bart = load_dalle_bart_flax_params(model_path)
|
||||
|
||||
image_tokens = numpy.zeros(config['image_length'])
|
||||
if is_torch:
|
||||
image_tokens[:image_token_count] = generate_image_tokens_torch(
|
||||
image_tokens = generate_image_tokens_torch(
|
||||
text_tokens = text_tokens,
|
||||
seed = seed,
|
||||
config = config,
|
||||
params = params_dalle_bart,
|
||||
image_token_count = image_token_count
|
||||
)
|
||||
if image_token_count == config['image_length']:
|
||||
image = detokenize_torch(image_tokens)
|
||||
return Image.fromarray(image)
|
||||
else:
|
||||
image_tokens[...] = generate_image_tokens_flax(
|
||||
image_tokens = generate_image_tokens_flax(
|
||||
text_tokens = text_tokens,
|
||||
seed = seed,
|
||||
config = config,
|
||||
params = params_dalle_bart,
|
||||
)
|
||||
|
||||
if image_token_count == config['image_length']:
|
||||
image = detokenize_torch(image_tokens)
|
||||
return Image.fromarray(image)
|
||||
else:
|
||||
return None
|
||||
image = detokenize_torch(torch.tensor(image_tokens))
|
||||
return Image.fromarray(image)
|
||||
Reference in New Issue
Block a user