previous commit broke flax model, fixed now

This commit is contained in:
Brett Kuprel
2022-06-28 12:54:58 -04:00
parent 5aa6fe49bf
commit 9d6b6dcc92
4 changed files with 16 additions and 17 deletions
+8 -9
View File
@@ -3,6 +3,7 @@ import json
import numpy
from PIL import Image
from typing import Tuple, List
import torch
from min_dalle.load_params import load_dalle_bart_flax_params
from min_dalle.text_tokenizer import TextTokenizer
@@ -53,25 +54,23 @@ def generate_image_from_text(
text_tokens = tokenize_text(text, config, vocab, merges)
params_dalle_bart = load_dalle_bart_flax_params(model_path)
image_tokens = numpy.zeros(config['image_length'])
if is_torch:
image_tokens[:image_token_count] = generate_image_tokens_torch(
image_tokens = generate_image_tokens_torch(
text_tokens = text_tokens,
seed = seed,
config = config,
params = params_dalle_bart,
image_token_count = image_token_count
)
if image_token_count == config['image_length']:
image = detokenize_torch(image_tokens)
return Image.fromarray(image)
else:
image_tokens[...] = generate_image_tokens_flax(
image_tokens = generate_image_tokens_flax(
text_tokens = text_tokens,
seed = seed,
config = config,
params = params_dalle_bart,
)
if image_token_count == config['image_length']:
image = detokenize_torch(image_tokens)
return Image.fromarray(image)
else:
return None
image = detokenize_torch(torch.tensor(image_tokens))
return Image.fromarray(image)