support bfloat16
This commit is contained in:
@@ -82,7 +82,7 @@ class Upsample(Module):
|
||||
self.conv = Conv2d(n, n, 3, padding=1)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = self.upsample.forward(x)
|
||||
x = self.upsample.forward(x.to(torch.float32))
|
||||
x = self.conv.forward(x)
|
||||
return x
|
||||
|
||||
|
||||
Reference in New Issue
Block a user