support bfloat16
This commit is contained in:
@@ -12,7 +12,6 @@ To generate a 4x4 grid of DALL·E Mega images it takes:
|
||||
- 89 sec with a T4 in Colab
|
||||
- 48 sec with a P100 in Colab
|
||||
- 14 sec with an A100 on Replicate
|
||||
- TBD with an H100 (@NVIDIA?)
|
||||
|
||||
The flax model and code for converting it to torch can be found [here](https://github.com/kuprel/min-dalle-flax).
|
||||
|
||||
@@ -30,13 +29,14 @@ Load the model parameters once and reuse the model to generate multiple images.
|
||||
from min_dalle import MinDalle
|
||||
|
||||
model = MinDalle(
|
||||
models_root='./pretrained',
|
||||
dtype=torch.float32,
|
||||
is_mega=True,
|
||||
is_reusable=True,
|
||||
models_root='./pretrained'
|
||||
is_reusable=True
|
||||
)
|
||||
```
|
||||
|
||||
The required models will be downloaded to `models_root` if they are not already there. Once everything has finished initializing, call `generate_image` with some text as many times as you want. Use a positive `seed` for reproducible results. Higher values for `log2_supercondition_factor` result in better agreement with the text but a narrower variety of generated images. Every image token is sampled from the top $k$ most probable tokens.
|
||||
The required models will be downloaded to `models_root` if they are not already there. If you have an Ampere architecture GPU you can set the `dtype=torch.bfloat16` and save GPU memory. There is still an issue with `dtype=torch.float16` that needs to be sorted out. Once everything has finished initializing, call `generate_image` with some text as many times as you want. Use a positive `seed` for reproducible results. Higher values for `log2_supercondition_factor` result in better agreement with the text but a narrower variety of generated images. Every image token is sampled from the top-$k$ most probable tokens.
|
||||
|
||||
```python
|
||||
image = model.generate_image(
|
||||
|
||||
Reference in New Issue
Block a user