diff --git a/min_dalle.ipynb b/min_dalle.ipynb index 99d4b87..f3f25a3 100644 --- a/min_dalle.ipynb +++ b/min_dalle.ipynb @@ -130,9 +130,14 @@ "source": [ "from IPython.display import display, update_display\n", "from math import log2\n", + "import torch\n", "from min_dalle import MinDalle\n", "\n", - "model = MinDalle(is_mega=True, is_reusable=True)" + "model = MinDalle(\n", + " dtype=torch.float32,\n", + " is_mega=True, \n", + " is_reusable=True\n", + ")" ] }, { diff --git a/replicate_predictor.py b/replicate_predictor.py index 839e953..d9ab027 100644 --- a/replicate_predictor.py +++ b/replicate_predictor.py @@ -19,7 +19,7 @@ class ReplicatePredictor(BasePredictor): default=True ), grid_size: int = Input( - description='Size of the image grid. 4x4 takes about 15 seconds, 8x8 takes about 45 seconds', + description='Size of the image grid. 4x4 takes about 15 seconds, 8x8 takes about 35 seconds', ge=1, le=8, default=4