works with cuda

This commit is contained in:
Brett Kuprel
2022-06-28 21:28:36 -04:00
parent 9d6b6dcc92
commit 17c96fe110
6 changed files with 43 additions and 33 deletions
+1 -1
View File
@@ -1,7 +1,7 @@
import torch
from torch import Tensor
from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
torch.no_grad()
torch.set_grad_enabled(False)
BATCH_COUNT: int = 1