From f071b31bddc61f57648cd7874ec2ac8a08e30861 Mon Sep 17 00:00:00 2001 From: Brett Kuprel Date: Tue, 5 Jul 2022 22:14:19 -0400 Subject: [PATCH] properly limit input to 64 tokens --- cog.yaml | 4 ++-- min_dalle/min_dalle.py | 2 ++ replicate_predictor.py | 2 +- setup.py | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/cog.yaml b/cog.yaml index 371a187..5891bff 100644 --- a/cog.yaml +++ b/cog.yaml @@ -1,12 +1,12 @@ build: - cuda: "11.0" + cuda: "11.3" gpu: true python_version: "3.8" system_packages: - "libgl1-mesa-glx" - "libglib2.0-0" python_packages: - - "min-dalle==0.2.28" + - "min-dalle==0.2.29" run: - pip install torch==1.10.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html diff --git a/min_dalle/min_dalle.py b/min_dalle/min_dalle.py index 3320b74..b0661d0 100644 --- a/min_dalle/min_dalle.py +++ b/min_dalle/min_dalle.py @@ -172,6 +172,8 @@ class MinDalle: assert(log2_mid_count in range(5)) if is_verbose: print("tokenizing text") tokens = self.tokenizer.tokenize(text, is_verbose=is_verbose) + if len(tokens) > self.text_token_count: + tokens = tokens[:self.text_token_count] if is_verbose: print("text tokens", tokens) text_tokens = numpy.ones((2, 64), dtype=numpy.int32) text_tokens[0, :2] = [tokens[0], tokens[-1]] diff --git a/replicate_predictor.py b/replicate_predictor.py index a19283d..4e8e263 100644 --- a/replicate_predictor.py +++ b/replicate_predictor.py @@ -11,7 +11,7 @@ class ReplicatePredictor(BasePredictor): def predict( self, text: str = Input( - description='Text', + description='For long prompts, only the first 64 tokens will be used to generate the image.', default='Dali painting of WALL·E' ), intermediate_outputs: bool = Input( diff --git a/setup.py b/setup.py index 4fe58cc..e45b90e 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name='min-dalle', description = 'min(DALL·E)', long_description=(Path(__file__).parent / "README.rst").read_text(), - version='0.2.28', + version='0.2.29', author='Brett Kuprel', author_email='brkuprel@gmail.com', url='https://github.com/kuprel/min-dalle',