faster inference with cuda/cudnn backends flags

This commit is contained in:
Brett Kuprel
2022-07-09 06:48:51 -04:00
parent 703bfb231d
commit dba3f11b3f
8 changed files with 19 additions and 14 deletions
+1 -2
View File
@@ -1,8 +1,6 @@
from typing import Tuple, List
import torch
from torch import nn, LongTensor, FloatTensor, BoolTensor
torch.set_grad_enabled(False)
from .dalle_bart_encoder import GLU, AttentionBase
IMAGE_TOKEN_COUNT = 256
@@ -180,6 +178,7 @@ class DalleBartDecoder(nn.Module):
self.zero_prob,
torch.exp(logits - top_logits[:, [0]])
)
probs[:, 2 ** 14:] = 0 # vqgan vocab_count is only 2 ** 14
return probs, attention_state