diff --git a/cog.yaml b/cog.yaml index 42e8f9b..d835b67 100644 --- a/cog.yaml +++ b/cog.yaml @@ -9,4 +9,4 @@ build: - "torch==1.10.1" - "flax==0.5.2" -predict: "predict.py:Predictor" +predict: "replicate/predict.py:Predictor" diff --git a/replicate/predict.py b/replicate/predict.py index 794d563..395d6f6 100644 --- a/replicate/predict.py +++ b/replicate/predict.py @@ -13,10 +13,12 @@ class Predictor(BasePredictor): description="Text for generating images.", ), seed: int = Input( - description="Specify a random seed.", + description="Specify a random seed." ), grid_size: int = Input( description="Specify the grid size.", + ge=1, + le=4 ) ) -> Path: image = self.model.generate_image(text, seed, grid_size=grid_size)