support bfloat16

This commit is contained in:
Brett Kuprel
2022-07-07 08:21:20 -04:00
parent 5f526e2109
commit da62298f06
9 changed files with 108 additions and 96 deletions
Vendored
+48 -41
View File
@@ -1,16 +1,16 @@
min(DALL·E)
===========
|Open In Colab|   |Replicate|   |Join us on Discord|
|Colab|   |Replicate|   |Discord|
This is a fast, minimal implementation of Boris Daymas `DALL·E
This is a fast, minimal port of Boris Daymas `DALL·E
Mega <https://github.com/borisdayma/dalle-mini>`__. It has been stripped
down for inference and converted to PyTorch. The only third party
dependencies are numpy, requests, pillow and torch.
To generate a 4x4 grid of DALL·E Mega images it takes: - 89 sec with a
T4 in Colab - 48 sec with a P100 in Colab - 14 sec with an A100 on
Replicate - TBD with an H100 (@NVIDIA?)
Replicate
The flax model and code for converting it to torch can be found
`here <https://github.com/kuprel/min-dalle-flax>`__.
@@ -32,47 +32,58 @@ images.
from min_dalle import MinDalle
model = MinDalle(is_mega=True, models_root='./pretrained')
model = MinDalle(
is_mega=True,
is_reusable=True,
models_root='./pretrained'
)
The required models will be downloaded to ``models_root`` if they are
not already there. Once everything has finished initializing, call
``generate_image`` with some text and a seed as many times as you want.
``generate_image`` with some text as many times as you want. Use a
positive ``seed`` for reproducible results. Higher values for
``log2_supercondition_factor`` result in better agreement with the text
but a narrower variety of generated images. Every image token is sampled
from the top-:math:`k` most probable tokens.
.. code:: python
text = 'Dali painting of WALL·E'
image = model.generate_image(text, seed=0, grid_size=4)
image = model.generate_image(
text='Nuclear explosion broccoli',
seed=-1,
grid_size=4,
log2_k=6,
log2_supercondition_factor=5,
is_verbose=False
)
display(image)
Interactive
~~~~~~~~~~~
If the model is being used interactively (e.g. in a notebook)
``generate_image_stream`` can be used to generate a stream of images as
the model is decoding. The detokenizer adds a slight delay for each
image. Setting ``log2_mid_count`` to 3 results in a total of
``2 ** 3 = 8`` generated images. The only valid values for
``log2_mid_count`` are 0, 1, 2, 3, and 4. This is implemented in the
colab.
.. code:: python
text = 'Rusty Iron Man suit found abandoned in the woods being reclaimed by nature'
image = model.generate_image(text, seed=0, grid_size=3)
display(image)
image_stream = model.generate_image_stream(
text='Dali painting of WALL·E',
seed=-1,
grid_size=3,
log2_mid_count=3,
log2_k=6,
log2_supercondition_factor=3,
is_verbose=False
)
.. code:: python
text = 'court sketch of godzilla on trial'
image = model.generate_image(text, seed=6, grid_size=3)
display(image)
.. code:: python
text = 'a funeral at Whole Foods'
image = model.generate_image(text, seed=10, grid_size=3)
display(image)
.. code:: python
text = 'Jesus turning water into wine on Americas Got Talent'
image = model.generate_image(text, seed=2, grid_size=3)
display(image)
.. code:: python
text = 'cctv footage of Yoda robbing a liquor store'
image = model.generate_image(text, seed=0, grid_size=3)
display(image)
for image in image_stream:
display(image)
Command Line
~~~~~~~~~~~~
@@ -81,15 +92,11 @@ Use ``image_from_text.py`` to generate images from the command line.
.. code:: bash
$ python image_from_text.py --text='artificial intelligence' --no-mega --seed=7
$ python image_from_text.py --text='artificial intelligence' --no-mega
.. code:: bash
$ python image_from_text.py --text='trail cam footage of gollum eating watermelon' --mega --seed=1 --grid-size=3
.. |Open In Colab| image:: https://colab.research.google.com/assets/colab-badge.svg
.. |Colab| image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb
.. |Replicate| image:: https://replicate.com/kuprel/min-dalle/badge
:target: https://replicate.com/kuprel/min-dalle
.. |Join us on Discord| image:: https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white
:target: https://discord.gg/xBPBXfcFHd
.. |Discord| image:: https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white
:target: https://discord.com/channels/823813159592001537/912729332311556136