added grid_size parameter to generate a grid of images
This commit is contained in:
@@ -3,8 +3,6 @@ from torch import Tensor
|
||||
from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
BATCH_COUNT: int = 1
|
||||
|
||||
|
||||
class ResnetBlock(Module):
|
||||
def __init__(self, log2_count_in: int, log2_count_out: int):
|
||||
@@ -42,22 +40,22 @@ class AttentionBlock(Module):
|
||||
self.proj_out = Conv2d(n, n, 1)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
n = 2 ** 9
|
||||
n, m = 2 ** 9, x.shape[0]
|
||||
h = x
|
||||
h = self.norm(h)
|
||||
q = self.q.forward(h)
|
||||
k = self.k.forward(h)
|
||||
v = self.v.forward(h)
|
||||
q = q.reshape(BATCH_COUNT, n, 2 ** 8)
|
||||
q = q.reshape(m, n, 2 ** 8)
|
||||
q = q.permute(0, 2, 1)
|
||||
k = k.reshape(BATCH_COUNT, n, 2 ** 8)
|
||||
k = k.reshape(m, n, 2 ** 8)
|
||||
w = torch.bmm(q, k)
|
||||
w /= n ** 0.5
|
||||
w = torch.softmax(w, dim=2)
|
||||
v = v.reshape(BATCH_COUNT, n, 2 ** 8)
|
||||
v = v.reshape(m, n, 2 ** 8)
|
||||
w = w.permute(0, 2, 1)
|
||||
h = torch.bmm(v, w)
|
||||
h = h.reshape(BATCH_COUNT, n, 2 ** 4, 2 ** 4)
|
||||
h = h.reshape(m, n, 2 ** 4, 2 ** 4)
|
||||
h = self.proj_out.forward(h)
|
||||
return x + h
|
||||
|
||||
@@ -169,10 +167,10 @@ class VQGanDetokenizer(Module):
|
||||
|
||||
def forward(self, z: Tensor) -> Tensor:
|
||||
z = self.embedding.forward(z)
|
||||
z = z.view((BATCH_COUNT, 2 ** 4, 2 ** 4, 2 ** 8))
|
||||
z = z.view((z.shape[0], 2 ** 4, 2 ** 4, 2 ** 8))
|
||||
z = z.permute(0, 3, 1, 2).contiguous()
|
||||
z = self.post_quant_conv.forward(z)
|
||||
z = self.decoder.forward(z)
|
||||
z = z.permute(0, 2, 3, 1)
|
||||
z = z.clip(0.0, 1.0) * 255
|
||||
return z[0]
|
||||
return z
|
||||
|
||||
Reference in New Issue
Block a user