support bfloat16

This commit is contained in:
Brett Kuprel
2022-07-07 08:21:20 -04:00
parent 5f526e2109
commit da62298f06
9 changed files with 108 additions and 96 deletions
+1 -1
View File
@@ -82,7 +82,7 @@ class Upsample(Module):
self.conv = Conv2d(n, n, 3, padding=1)
def forward(self, x: Tensor) -> Tensor:
x = self.upsample.forward(x)
x = self.upsample.forward(x.to(torch.float32))
x = self.conv.forward(x)
return x