vqgan needs to be float32

This commit is contained in:
Brett Kuprel
2022-07-07 08:53:27 -04:00
parent da62298f06
commit c199507a7a
4 changed files with 7 additions and 6 deletions
Vendored
+1 -1
View File
@@ -36,7 +36,7 @@ model = MinDalle(
)
```
The required models will be downloaded to `models_root` if they are not already there. If you have an Ampere architecture GPU you can set the `dtype=torch.bfloat16` and save GPU memory. There is still an issue with `dtype=torch.float16` that needs to be sorted out. Once everything has finished initializing, call `generate_image` with some text as many times as you want. Use a positive `seed` for reproducible results. Higher values for `log2_supercondition_factor` result in better agreement with the text but a narrower variety of generated images. Every image token is sampled from the top-$k$ most probable tokens.
The required models will be downloaded to `models_root` if they are not already there. If you have an Ampere architecture GPU you can set the `dtype=torch.bfloat16` and save GPU memory. There is still an issue with `torch.float16` that needs to be sorted out. Once everything has finished initializing, call `generate_image` with some text as many times as you want. Use a positive `seed` for reproducible results. Higher values for `log2_supercondition_factor` result in better agreement with the text but a narrower variety of generated images. Every image token is sampled from the top-$k$ most probable tokens.
```python
image = model.generate_image(