faster inference with cuda/cudnn backends flags
This commit is contained in:
@@ -4,15 +4,21 @@ import numpy
|
||||
from torch import LongTensor, FloatTensor
|
||||
from math import sqrt
|
||||
import torch
|
||||
import torch.backends.cudnn, torch.backends.cuda
|
||||
import json
|
||||
import requests
|
||||
from typing import Iterator
|
||||
torch.set_grad_enabled(False)
|
||||
torch.set_num_threads(os.cpu_count())
|
||||
|
||||
from .text_tokenizer import TextTokenizer
|
||||
from .models import DalleBartEncoder, DalleBartDecoder, VQGanDetokenizer
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
torch.set_num_threads(os.cpu_count())
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
|
||||
|
||||
MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user