From ed91ab4a30ffdcf7e8773e6b434816f79f5fead8 Mon Sep 17 00:00:00 2001 From: Brett Kuprel Date: Wed, 29 Jun 2022 09:42:12 -0400 Subject: [PATCH] refactored to load models once and run multiple times --- README.md | 10 +- image_from_text.py | 53 ++++-- min_dalle/generate_image.py | 78 --------- min_dalle/min_dalle.py | 38 +++++ min_dalle/min_dalle_flax.py | 113 +++++------- min_dalle/min_dalle_torch.py | 171 ++++++++----------- min_dalle/models/dalle_bart_decoder_flax.py | 12 +- min_dalle/models/dalle_bart_decoder_torch.py | 18 +- min_dalle/models/dalle_bart_encoder_flax.py | 6 +- min_dalle/models/dalle_bart_encoder_torch.py | 6 +- min_dalle/text_tokenizer.py | 2 +- 11 files changed, 225 insertions(+), 282 deletions(-) delete mode 100644 min_dalle/generate_image.py create mode 100644 min_dalle/min_dalle.py diff --git a/README.md b/README.md index a268946..1f6509c 100644 --- a/README.md +++ b/README.md @@ -2,18 +2,18 @@ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb) -This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are `numpy` and `torch` for the torch model and `flax` for the flax model. +This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are `numpy`, `torch`, and `flax`. ### Setup -Run `sh setup.sh` to install dependencies and download pretrained models. The models can also be downloaded manually: +Run `sh setup.sh` to install dependencies and download pretrained models. The `wandb` python package is installed to download DALL·E mini and DALL·E mega. Alternatively, the models can be downloaded manually here: [VQGan](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384), [DALL·E Mini](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mini-1/v0/files), [DALL·E Mega](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mega-1-fp16/v14/files) ### Usage -Use the command line python script `image_from_text.py` to generate images. Here are some examples: +The simplest way to get started is the command line python script `image_from_text.py` provided. Here are some examples runs: ``` python image_from_text.py --text='alien life' --seed=7 @@ -32,3 +32,7 @@ python image_from_text.py --text='court sketch of godzilla on trial' --mega --se ``` ![Godzilla Trial](examples/godzilla_trial.png) + +### Load once run multiple times + +The command line script loads the models and parameters each time. The colab notebook demonstrates how to load the models once and run multiple times. \ No newline at end of file diff --git a/image_from_text.py b/image_from_text.py index a56522d..31dbc4c 100644 --- a/image_from_text.py +++ b/image_from_text.py @@ -2,8 +2,8 @@ import argparse import os from PIL import Image -from min_dalle.generate_image import generate_image_from_text - +from min_dalle.min_dalle_torch import MinDalleTorch +from min_dalle.min_dalle_flax import MinDalleFlax parser = argparse.ArgumentParser() parser.add_argument('--mega', action='store_true') @@ -15,7 +15,7 @@ parser.set_defaults(torch=False) parser.add_argument('--text', type=str) parser.add_argument('--seed', type=int, default=0) parser.add_argument('--image_path', type=str, default='generated') -parser.add_argument('--image_token_count', type=int, default=256) # for debugging +parser.add_argument('--sample_token_count', type=int, default=256) # for debugging def ascii_from_image(image: Image.Image, size: int) -> str: @@ -36,19 +36,40 @@ def save_image(image: Image.Image, path: str): return image +def generate_image( + is_torch: bool, + is_mega: bool, + text: str, + seed: int, + image_path: str, + sample_token_count: int +): + if is_torch: + image_generator = MinDalleTorch(is_mega, sample_token_count) + image_tokens = image_generator.generate_image_tokens(text, seed) + + if sample_token_count < image_generator.config['image_length']: + print('image tokens', list(image_tokens.to('cpu').detach().numpy())) + return + else: + image = image_generator.generate_image(text, seed) + + else: + image_generator = MinDalleFlax(is_mega) + image = image_generator.generate_image(text, seed) + + save_image(image, image_path) + print(ascii_from_image(image, size=128)) + + if __name__ == '__main__': args = parser.parse_args() - print(args) - - image = generate_image_from_text( - text = args.text, - is_mega = args.mega, - is_torch = args.torch, - seed = args.seed, - image_token_count = args.image_token_count - ) - - if image != None: - save_image(image, args.image_path) - print(ascii_from_image(image, size=128)) \ No newline at end of file + generate_image( + is_torch=args.torch, + is_mega=args.mega, + text=args.text, + seed=args.seed, + image_path=args.image_path, + sample_token_count=args.sample_token_count + ) \ No newline at end of file diff --git a/min_dalle/generate_image.py b/min_dalle/generate_image.py deleted file mode 100644 index f7f63cb..0000000 --- a/min_dalle/generate_image.py +++ /dev/null @@ -1,78 +0,0 @@ -import os -import json -import numpy -from PIL import Image -from typing import Tuple, List -import torch - -from min_dalle.load_params import load_dalle_bart_flax_params -from min_dalle.text_tokenizer import TextTokenizer -from min_dalle.min_dalle_flax import generate_image_tokens_flax -from min_dalle.min_dalle_torch import ( - generate_image_tokens_torch, - detokenize_torch -) - -def load_dalle_bart_metadata(path: str) -> Tuple[dict, dict, List[str]]: - print("parsing metadata from {}".format(path)) - for f in ['config.json', 'flax_model.msgpack', 'vocab.json', 'merges.txt']: - assert(os.path.exists(os.path.join(path, f))) - with open(path + '/config.json', 'r') as f: - config = json.load(f) - with open(path + '/vocab.json') as f: - vocab = json.load(f) - with open(path + '/merges.txt') as f: - merges = f.read().split("\n")[1:-1] - return config, vocab, merges - - -def tokenize_text( - text: str, - config: dict, - vocab: dict, - merges: List[str] -) -> numpy.ndarray: - print("tokenizing text") - tokens = TextTokenizer(vocab, merges)(text) - print("text tokens", tokens) - text_tokens = numpy.ones((2, config['max_text_length']), dtype=numpy.int32) - text_tokens[0, :len(tokens)] = tokens - text_tokens[1, :2] = [tokens[0], tokens[-1]] - return text_tokens - - -def generate_image_from_text( - text: str, - is_mega: bool = False, - is_torch: bool = False, - seed: int = 0, - image_token_count: int = 256 -) -> Image.Image: - model_name = 'mega' if is_mega else 'mini' - model_path = './pretrained/dalle_bart_{}'.format(model_name) - config, vocab, merges = load_dalle_bart_metadata(model_path) - text_tokens = tokenize_text(text, config, vocab, merges) - params_dalle_bart = load_dalle_bart_flax_params(model_path) - - if is_torch: - image_tokens = generate_image_tokens_torch( - text_tokens = text_tokens, - seed = seed, - config = config, - params = params_dalle_bart, - image_token_count = image_token_count - ) - if image_token_count == config['image_length']: - image = detokenize_torch(image_tokens, is_torch=True) - return Image.fromarray(image) - else: - print(list(image_tokens.to('cpu').detach().numpy())) - else: - image_tokens = generate_image_tokens_flax( - text_tokens = text_tokens, - seed = seed, - config = config, - params = params_dalle_bart, - ) - image = detokenize_torch(torch.tensor(image_tokens), is_torch=False) - return Image.fromarray(image) \ No newline at end of file diff --git a/min_dalle/min_dalle.py b/min_dalle/min_dalle.py new file mode 100644 index 0000000..2607026 --- /dev/null +++ b/min_dalle/min_dalle.py @@ -0,0 +1,38 @@ +import os +import json +import numpy + +from .text_tokenizer import TextTokenizer +from .load_params import load_vqgan_torch_params, load_dalle_bart_flax_params +from .models.vqgan_detokenizer import VQGanDetokenizer + +class MinDalle: + def __init__(self, is_mega: bool): + self.is_mega = is_mega + model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini') + model_path = os.path.join('pretrained', model_name) + + print("reading files from {}".format(model_path)) + with open(os.path.join(model_path, 'config.json'), 'r') as f: + self.config = json.load(f) + with open(os.path.join(model_path, 'vocab.json'), 'r') as f: + vocab = json.load(f) + with open(os.path.join(model_path, 'merges.txt'), 'r') as f: + merges = f.read().split("\n")[1:-1] + self.model_params = load_dalle_bart_flax_params(model_path) + + self.tokenizer = TextTokenizer(vocab, merges) + self.detokenizer = VQGanDetokenizer() + vqgan_params = load_vqgan_torch_params('./pretrained/vqgan') + self.detokenizer.load_state_dict(vqgan_params) + + + def tokenize_text(self, text: str) -> numpy.ndarray: + print("tokenizing text") + tokens = self.tokenizer.tokenize(text) + print("text tokens", tokens) + text_token_count = self.config['max_text_length'] + text_tokens = numpy.ones((2, text_token_count), dtype=numpy.int32) + text_tokens[0, :len(tokens)] = tokens + text_tokens[1, :2] = [tokens[0], tokens[-1]] + return text_tokens \ No newline at end of file diff --git a/min_dalle/min_dalle_flax.py b/min_dalle/min_dalle_flax.py index 884f271..100d4ab 100644 --- a/min_dalle/min_dalle_flax.py +++ b/min_dalle/min_dalle_flax.py @@ -1,79 +1,58 @@ import jax -from jax import numpy as jnp import numpy +from PIL import Image +import torch +from .min_dalle import MinDalle from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax -def encode_flax( - text_tokens: numpy.ndarray, - config: dict, - params: dict -) -> jnp.ndarray: - print("loading flax encoder") - encoder: DalleBartEncoderFlax = DalleBartEncoderFlax( - attention_head_count = config['encoder_attention_heads'], - embed_count = config['d_model'], - glu_embed_count = config['encoder_ffn_dim'], - text_token_count = config['max_text_length'], - text_vocab_count = config['encoder_vocab_size'], - layer_count = config['encoder_layers'] - ).bind({'params': params.pop('encoder')}) +class MinDalleFlax(MinDalle): + def __init__(self, is_mega: bool): + super().__init__(is_mega) + print("initializing MinDalleFlax") - print("encoding text tokens") - encoder_state = encoder(text_tokens) - del encoder - return encoder_state + print("loading encoder") + self.encoder = DalleBartEncoderFlax( + attention_head_count = self.config['encoder_attention_heads'], + embed_count = self.config['d_model'], + glu_embed_count = self.config['encoder_ffn_dim'], + text_token_count = self.config['max_text_length'], + text_vocab_count = self.config['encoder_vocab_size'], + layer_count = self.config['encoder_layers'] + ).bind({'params': self.model_params.pop('encoder')}) + print("loading decoder") + self.decoder = DalleBartDecoderFlax( + image_token_count = self.config['image_length'], + text_token_count = self.config['max_text_length'], + image_vocab_count = self.config['image_vocab_size'], + attention_head_count = self.config['decoder_attention_heads'], + embed_count = self.config['d_model'], + glu_embed_count = self.config['decoder_ffn_dim'], + layer_count = self.config['decoder_layers'], + start_token = self.config['decoder_start_token_id'] + ) + -def decode_flax( - text_tokens: jnp.ndarray, - encoder_state: jnp.ndarray, - config: dict, - seed: int, - params: dict -) -> jnp.ndarray: - print("loading flax decoder") - decoder = DalleBartDecoderFlax( - image_token_count = config['image_length'], - text_token_count = config['max_text_length'], - image_vocab_count = config['image_vocab_size'], - attention_head_count = config['decoder_attention_heads'], - embed_count = config['d_model'], - glu_embed_count = config['decoder_ffn_dim'], - layer_count = config['decoder_layers'], - start_token = config['decoder_start_token_id'] - ) - print("sampling image tokens") - image_tokens = decoder.sample_image_tokens( - text_tokens, - encoder_state, - jax.random.PRNGKey(seed), - params.pop('decoder') - ) - del decoder - return image_tokens + def generate_image(self, text: str, seed: int) -> Image.Image: + text_tokens = self.tokenize_text(text) + print("encoding text tokens") + encoder_state = self.encoder(text_tokens) -def generate_image_tokens_flax( - text_tokens: numpy.ndarray, - seed: int, - config: dict, - params: dict -) -> numpy.ndarray: - encoder_state = encode_flax( - text_tokens, - config, - params - ) - image_tokens = decode_flax( - text_tokens, - encoder_state, - config, - seed, - params - ) - image_tokens = numpy.array(image_tokens) - print("image tokens", list(image_tokens)) - return image_tokens \ No newline at end of file + print("sampling image tokens") + image_tokens = self.decoder.sample_image_tokens( + text_tokens, + encoder_state, + jax.random.PRNGKey(seed), + self.model_params['decoder'] + ) + + image_tokens = torch.tensor(numpy.array(image_tokens)) + + print("detokenizing image") + image = self.detokenizer.forward(image_tokens).to(torch.uint8) + image = Image.fromarray(image.to('cpu').detach().numpy()) + return image \ No newline at end of file diff --git a/min_dalle/min_dalle_torch.py b/min_dalle/min_dalle_torch.py index 228c601..6bf71af 100644 --- a/min_dalle/min_dalle_torch.py +++ b/min_dalle/min_dalle_torch.py @@ -1,118 +1,83 @@ +from random import sample import numpy import os +from PIL import Image from typing import Dict -from torch import LongTensor, FloatTensor +from torch import LongTensor import torch torch.set_grad_enabled(False) torch.set_num_threads(os.cpu_count()) -from .models.vqgan_detokenizer import VQGanDetokenizer +from .load_params import convert_dalle_bart_torch_from_flax_params +from .min_dalle import MinDalle from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch -from .load_params import ( - load_vqgan_torch_params, - convert_dalle_bart_torch_from_flax_params -) + +class MinDalleTorch(MinDalle): + def __init__(self, is_mega: bool, sample_token_count: int = 256): + super().__init__(is_mega) + print("initializing MinDalleTorch") + + print("loading encoder") + self.encoder = DalleBartEncoderTorch( + layer_count = self.config['encoder_layers'], + embed_count = self.config['d_model'], + attention_head_count = self.config['encoder_attention_heads'], + text_vocab_count = self.config['encoder_vocab_size'], + text_token_count = self.config['max_text_length'], + glu_embed_count = self.config['encoder_ffn_dim'] + ) + encoder_params = convert_dalle_bart_torch_from_flax_params( + self.model_params.pop('encoder'), + layer_count=self.config['encoder_layers'], + is_encoder=True + ) + self.encoder.load_state_dict(encoder_params, strict=False) + + print("loading decoder") + self.decoder = DalleBartDecoderTorch( + image_vocab_size = self.config['image_vocab_size'], + image_token_count = self.config['image_length'], + sample_token_count = sample_token_count, + embed_count = self.config['d_model'], + attention_head_count = self.config['decoder_attention_heads'], + glu_embed_count = self.config['decoder_ffn_dim'], + layer_count = self.config['decoder_layers'], + batch_count = 2, + start_token = self.config['decoder_start_token_id'], + is_verbose = True + ) + decoder_params = convert_dalle_bart_torch_from_flax_params( + self.model_params.pop('decoder'), + layer_count=self.config['decoder_layers'], + is_encoder=False + ) + self.decoder.load_state_dict(decoder_params, strict=False) + + if torch.cuda.is_available(): + self.encoder = self.encoder.cuda() + self.decoder = self.decoder.cuda() + self.detokenizer = self.detokenizer.cuda() -def encode_torch( - text_tokens: LongTensor, - config: dict, - params: dict -) -> FloatTensor: - print("loading torch encoder") - encoder = DalleBartEncoderTorch( - layer_count = config['encoder_layers'], - embed_count = config['d_model'], - attention_head_count = config['encoder_attention_heads'], - text_vocab_count = config['encoder_vocab_size'], - text_token_count = config['max_text_length'], - glu_embed_count = config['encoder_ffn_dim'] - ) - encoder_params = convert_dalle_bart_torch_from_flax_params( - params.pop('encoder'), - layer_count=config['encoder_layers'], - is_encoder=True - ) - encoder.load_state_dict(encoder_params, strict=False) - del encoder_params - if torch.cuda.is_available(): encoder = encoder.cuda() + def generate_image_tokens(self, text: str, seed: int) -> LongTensor: + text_tokens = self.tokenize_text(text) + text_tokens = torch.tensor(text_tokens).to(torch.long) + if torch.cuda.is_available(): text_tokens = text_tokens.cuda() - print("encoding text tokens") - encoder_state = encoder(text_tokens) - del encoder - return encoder_state + print("encoding text tokens") + encoder_state = self.encoder.forward(text_tokens) + print("sampling image tokens") + torch.manual_seed(seed) + image_tokens = self.decoder.forward(text_tokens, encoder_state) + return image_tokens + -def decode_torch( - text_tokens: LongTensor, - encoder_state: FloatTensor, - config: dict, - seed: int, - params: dict, - image_token_count: int -) -> LongTensor: - print("loading torch decoder") - decoder = DalleBartDecoderTorch( - image_vocab_size = config['image_vocab_size'], - image_token_count = config['image_length'], - sample_token_count = image_token_count, - embed_count = config['d_model'], - attention_head_count = config['decoder_attention_heads'], - glu_embed_count = config['decoder_ffn_dim'], - layer_count = config['decoder_layers'], - batch_count = 2, - start_token = config['decoder_start_token_id'], - is_verbose = True - ) - decoder_params = convert_dalle_bart_torch_from_flax_params( - params.pop('decoder'), - layer_count=config['decoder_layers'], - is_encoder=False - ) - decoder.load_state_dict(decoder_params, strict=False) - del decoder_params - if torch.cuda.is_available(): decoder = decoder.cuda() - - print("sampling image tokens") - torch.manual_seed(seed) - image_tokens = decoder.forward(text_tokens, encoder_state) - return image_tokens - - -def generate_image_tokens_torch( - text_tokens: numpy.ndarray, - seed: int, - config: dict, - params: dict, - image_token_count: int -) -> LongTensor: - text_tokens = torch.tensor(text_tokens).to(torch.long) - if torch.cuda.is_available(): text_tokens = text_tokens.cuda() - encoder_state = encode_torch( - text_tokens, - config, - params - ) - image_tokens = decode_torch( - text_tokens, - encoder_state, - config, - seed, - params, - image_token_count - ) - return image_tokens - - -def detokenize_torch(image_tokens: LongTensor, is_torch: bool) -> numpy.ndarray: - print("detokenizing image") - model_path = './pretrained/vqgan' - params = load_vqgan_torch_params(model_path) - detokenizer = VQGanDetokenizer() - detokenizer.load_state_dict(params) - if torch.cuda.is_available() and is_torch: detokenizer = detokenizer.cuda() - image = detokenizer.forward(image_tokens).to(torch.uint8) - del detokenizer, params - return image.to('cpu').detach().numpy() + def generate_image(self, text: str, seed: int) -> Image.Image: + image_tokens = self.generate_image_tokens(text, seed) + print("detokenizing image") + image = self.detokenizer.forward(image_tokens).to(torch.uint8) + image = Image.fromarray(image.to('cpu').detach().numpy()) + return image \ No newline at end of file diff --git a/min_dalle/models/dalle_bart_decoder_flax.py b/min_dalle/models/dalle_bart_decoder_flax.py index caf28ec..fa2d457 100644 --- a/min_dalle/models/dalle_bart_decoder_flax.py +++ b/min_dalle/models/dalle_bart_decoder_flax.py @@ -26,7 +26,8 @@ class DecoderCrossAttentionFlax(AttentionFlax): class DecoderSelfAttentionFlax(AttentionFlax): - def __call__(self, + def __call__( + self, decoder_state: jnp.ndarray, keys_state: jnp.ndarray, values_state: jnp.ndarray, @@ -77,7 +78,8 @@ class DalleBartDecoderLayerFlax(nn.Module): self.glu = GLUFlax(self.embed_count, self.glu_embed_count) @nn.compact - def __call__(self, + def __call__( + self, decoder_state: jnp.ndarray, encoder_state: jnp.ndarray, keys_state: jnp.ndarray, @@ -173,7 +175,8 @@ class DalleBartDecoderFlax(nn.Module): self.final_ln = nn.LayerNorm(use_scale=False) self.lm_head = nn.Dense(self.image_vocab_count + 1, use_bias=False) - def __call__(self, + def __call__( + self, encoder_state: jnp.ndarray, keys_state: jnp.ndarray, values_state: jnp.ndarray, @@ -198,7 +201,8 @@ class DalleBartDecoderFlax(nn.Module): decoder_state = self.lm_head(decoder_state) return decoder_state, keys_state, values_state - def sample_image_tokens(self, + def sample_image_tokens( + self, text_tokens: jnp.ndarray, encoder_state: jnp.ndarray, prng_key: jax.random.PRNGKey, diff --git a/min_dalle/models/dalle_bart_decoder_torch.py b/min_dalle/models/dalle_bart_decoder_torch.py index f4555ab..bce3bff 100644 --- a/min_dalle/models/dalle_bart_decoder_torch.py +++ b/min_dalle/models/dalle_bart_decoder_torch.py @@ -26,7 +26,8 @@ class DecoderCrossAttentionTorch(AttentionTorch): class DecoderSelfAttentionTorch(AttentionTorch): - def forward(self, + def forward( + self, decoder_state: FloatTensor, keys_values: FloatTensor, attention_mask: BoolTensor, @@ -49,7 +50,8 @@ class DecoderSelfAttentionTorch(AttentionTorch): class DecoderLayerTorch(nn.Module): - def __init__(self, + def __init__( + self, image_token_count: int, head_count: int, embed_count: int, @@ -69,7 +71,8 @@ class DecoderLayerTorch(nn.Module): if torch.cuda.is_available(): self.token_indices = self.token_indices.cuda() - def forward(self, + def forward( + self, decoder_state: FloatTensor, encoder_state: FloatTensor, keys_values_state: FloatTensor, @@ -111,7 +114,8 @@ class DecoderLayerTorch(nn.Module): class DalleBartDecoderTorch(nn.Module): - def __init__(self, + def __init__( + self, image_vocab_size: int, image_token_count: int, sample_token_count: int, @@ -158,7 +162,8 @@ class DalleBartDecoderTorch(nn.Module): self.start_token = self.start_token.cuda() - def decode_step(self, + def decode_step( + self, text_tokens: LongTensor, encoder_state: FloatTensor, keys_values_state: FloatTensor, @@ -198,7 +203,8 @@ class DalleBartDecoderTorch(nn.Module): return probs, keys_values - def forward(self, + def forward( + self, text_tokens: LongTensor, encoder_state: FloatTensor ) -> LongTensor: diff --git a/min_dalle/models/dalle_bart_encoder_flax.py b/min_dalle/models/dalle_bart_encoder_flax.py index 71bbef3..3d159f0 100644 --- a/min_dalle/models/dalle_bart_encoder_flax.py +++ b/min_dalle/models/dalle_bart_encoder_flax.py @@ -34,7 +34,8 @@ class AttentionFlax(nn.Module): self.v_proj = nn.Dense(self.embed_count, use_bias=False) self.out_proj = nn.Dense(self.embed_count, use_bias=False) - def forward(self, + def forward( + self, keys: jnp.ndarray, values: jnp.ndarray, queries: jnp.ndarray, @@ -92,7 +93,8 @@ class DalleBartEncoderLayerFlax(nn.Module): self.glu = GLUFlax(self.embed_count, self.glu_embed_count) @nn.compact - def __call__(self, + def __call__( + self, encoder_state: jnp.ndarray, attention_mask: jnp.ndarray ) -> jnp.ndarray: diff --git a/min_dalle/models/dalle_bart_encoder_torch.py b/min_dalle/models/dalle_bart_encoder_torch.py index 92bf775..afd6295 100644 --- a/min_dalle/models/dalle_bart_encoder_torch.py +++ b/min_dalle/models/dalle_bart_encoder_torch.py @@ -37,7 +37,8 @@ class AttentionTorch(nn.Module): self.one = torch.ones((1, 1)) if torch.cuda.is_available(): self.one = self.one.cuda() - def forward(self, + def forward( + self, keys: FloatTensor, values: FloatTensor, queries: FloatTensor, @@ -105,7 +106,8 @@ class EncoderLayerTorch(nn.Module): class DalleBartEncoderTorch(nn.Module): - def __init__(self, + def __init__( + self, layer_count: int, embed_count: int, attention_head_count: int, diff --git a/min_dalle/text_tokenizer.py b/min_dalle/text_tokenizer.py index 1d601e6..1d06349 100644 --- a/min_dalle/text_tokenizer.py +++ b/min_dalle/text_tokenizer.py @@ -8,7 +8,7 @@ class TextTokenizer: pairs = [tuple(pair.split()) for pair in merges] self.rank_from_pair = dict(zip(pairs, range(len(pairs)))) - def __call__(self, text: str) -> List[int]: + def tokenize(self, text: str) -> List[int]: sep_token = self.token_from_subword[''] cls_token = self.token_from_subword[''] unk_token = self.token_from_subword['']