sort -> topk, prev_token_and_index -> prev_token, token_index
This commit is contained in:
@@ -61,6 +61,7 @@ class AttentionBlock(Module):
|
||||
h = self.proj_out.forward(h)
|
||||
return x + h
|
||||
|
||||
|
||||
class MiddleLayer(Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -74,6 +75,7 @@ class MiddleLayer(Module):
|
||||
h = self.block_2.forward(h)
|
||||
return h
|
||||
|
||||
|
||||
class Upsample(Module):
|
||||
def __init__(self, log2_count):
|
||||
super().__init__()
|
||||
@@ -86,6 +88,7 @@ class Upsample(Module):
|
||||
x = self.conv.forward(x)
|
||||
return x
|
||||
|
||||
|
||||
class UpsampleBlock(Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -124,6 +127,7 @@ class UpsampleBlock(Module):
|
||||
h = self.upsample.forward(h)
|
||||
return h
|
||||
|
||||
|
||||
class Decoder(Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -154,6 +158,7 @@ class Decoder(Module):
|
||||
z = self.conv_out.forward(z)
|
||||
return z
|
||||
|
||||
|
||||
class VQGanDetokenizer(Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
Reference in New Issue
Block a user