fix typing

This commit is contained in:
Brett Kuprel
2022-07-07 17:18:30 -04:00
parent 2cac9220b5
commit 736904ef2f
3 changed files with 14 additions and 14 deletions
+5 -5
View File
@@ -2,7 +2,7 @@ import os
from PIL import Image
from matplotlib.pyplot import grid
import numpy
from torch import LongTensor
from torch import LongTensor, FloatTensor
from math import sqrt
import torch
import json
@@ -148,7 +148,7 @@ class MinDalle:
self,
image_tokens: LongTensor,
is_verbose: bool = False
) -> LongTensor:
) -> FloatTensor:
if not self.is_reusable: del self.decoder
if torch.cuda.is_available(): torch.cuda.empty_cache()
if not self.is_reusable: self.init_detokenizer()
@@ -158,7 +158,7 @@ class MinDalle:
return images
def grid_from_images(self, images: LongTensor) -> Image.Image:
def grid_from_images(self, images: FloatTensor) -> Image.Image:
grid_size = int(sqrt(images.shape[0]))
images = images.reshape([grid_size] * 2 + list(images.shape[1:]))
image = images.flatten(1, 2).transpose(0, 1).flatten(1, 2)
@@ -175,7 +175,7 @@ class MinDalle:
log2_k: int = 6,
log2_supercondition_factor: int = 3,
is_verbose: bool = False
) -> Iterator[LongTensor]:
) -> Iterator[FloatTensor]:
assert(log2_mid_count in range(5))
if is_verbose: print("tokenizing text")
tokens = self.tokenizer.tokenize(text, is_verbose=is_verbose)
@@ -260,7 +260,7 @@ class MinDalle:
log2_k: int = 6,
log2_supercondition_factor: int = 3,
is_verbose: bool = False
) -> LongTensor:
) -> FloatTensor:
log2_mid_count = 0
images_stream = self.generate_images_stream(
text,