license and cleanup

This commit is contained in:
Brett Kuprel
2022-06-27 14:34:10 -04:00
parent 32b7aa196b
commit 18e6a9852f
7 changed files with 25 additions and 42 deletions
+7 -10
View File
@@ -2,7 +2,8 @@ import torch
from torch import Tensor
from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
batch_size: int = 1
BATCH_COUNT: int = 1
class ResnetBlock(Module):
def __init__(self, log2_count_in: int, log2_count_out: int):
@@ -46,16 +47,16 @@ class AttentionBlock(Module):
q = self.q.forward(h)
k = self.k.forward(h)
v = self.v.forward(h)
q = q.reshape(batch_size, n, 2 ** 8)
q = q.reshape(BATCH_COUNT, n, 2 ** 8)
q = q.permute(0, 2, 1)
k = k.reshape(batch_size, n, 2 ** 8)
k = k.reshape(BATCH_COUNT, n, 2 ** 8)
w = torch.bmm(q, k)
w /= n ** 0.5
w = torch.softmax(w, dim=2)
v = v.reshape(batch_size, n, 2 ** 8)
v = v.reshape(BATCH_COUNT, n, 2 ** 8)
w = w.permute(0, 2, 1)
h = torch.bmm(v, w)
h = h.reshape(batch_size, n, 2 ** 4, 2 ** 4)
h = h.reshape(BATCH_COUNT, n, 2 ** 4, 2 ** 4)
h = self.proj_out.forward(h)
return x + h
@@ -162,14 +163,10 @@ class VQGanDetokenizer(Module):
def forward(self, z: Tensor) -> Tensor:
z = self.embedding.forward(z)
z = z.view((batch_size, 2 ** 4, 2 ** 4, 2 ** 8))
z = z.view((BATCH_COUNT, 2 ** 4, 2 ** 4, 2 ** 8))
z = z.permute(0, 3, 1, 2).contiguous()
z = self.post_quant_conv.forward(z)
z = self.decoder.forward(z)
z = z.permute(0, 2, 3, 1)
# z = torch.concat((
# torch.concat((z[0], z[1]), axis=1),
# torch.concat((z[2], z[3]), axis=1)
# ), axis=0)
z = z.clip(0.0, 1.0) * 255
return z[0]