added grid_size parameter to generate a grid of images

This commit is contained in:
Brett Kuprel
2022-07-02 08:45:49 -04:00
parent e0386f991c
commit 1eb56737d8
6 changed files with 87 additions and 69 deletions
+2 -4
View File
@@ -1,4 +1,3 @@
from typing import List
import torch
from torch import nn, BoolTensor, FloatTensor, LongTensor
torch.set_grad_enabled(False)
@@ -121,7 +120,7 @@ class DalleBartEncoder(nn.Module):
super().__init__()
self.embed_tokens = nn.Embedding(text_vocab_count, embed_count)
self.embed_positions = nn.Embedding(text_token_count, embed_count)
self.layers: List[EncoderLayer] = nn.ModuleList([
self.layers: list[EncoderLayer] = nn.ModuleList([
EncoderLayer(
embed_count = embed_count,
head_count = attention_head_count,
@@ -137,8 +136,7 @@ class DalleBartEncoder(nn.Module):
def forward(self, text_tokens: LongTensor) -> FloatTensor:
attention_mask = text_tokens.not_equal(1)
batch_count = text_tokens.shape[0]
pose_tokens = torch.stack([self.token_indices] * batch_count)
pose_tokens = self.token_indices[None][[0] * text_tokens.shape[0]]
encoder_state = (
self.embed_tokens.forward(text_tokens) +
self.embed_positions.forward(pose_tokens)