0.3.13, simplified code, specify device when initializing MinDalle
This commit is contained in:
@@ -36,12 +36,13 @@ from min_dalle import MinDalle
|
||||
model = MinDalle(
|
||||
models_root='./pretrained',
|
||||
dtype=torch.float32,
|
||||
device='cuda',
|
||||
is_mega=True,
|
||||
is_reusable=True
|
||||
)
|
||||
```
|
||||
|
||||
The required models will be downloaded to `models_root` if they are not already there. Set the `dtype` to `torch.float16` to save GPU memory. If you have an Ampere architecture GPU you can use `torch.bfloat16`. Once everything has finished initializing, call `generate_image` with some text as many times as you want. Use a positive `seed` for reproducible results. Higher values for `supercondition_factor` result in better agreement with the text but a narrower variety of generated images. Every image token is sampled from the `top_k` most probable tokens. The largest logit is subtracted from the logits to avoid infs. The logits are then divided by the `temperature`.
|
||||
The required models will be downloaded to `models_root` if they are not already there. Set the `dtype` to `torch.float16` to save GPU memory. If you have an Ampere architecture GPU you can use `torch.bfloat16`. Set the `device` to either "cuda" or "cpu". Once everything has finished initializing, call `generate_image` with some text as many times as you want. Use a positive `seed` for reproducible results. Higher values for `supercondition_factor` result in better agreement with the text but a narrower variety of generated images. Every image token is sampled from the `top_k` most probable tokens. The largest logit is subtracted from the logits to avoid infs. The logits are then divided by the `temperature`.
|
||||
|
||||
```python
|
||||
image = model.generate_image(
|
||||
@@ -88,7 +89,7 @@ image.save('image_{}.png'.format(i))
|
||||
|
||||
### Progressive Outputs
|
||||
|
||||
If the model is being used interactively (e.g. in a notebook) `generate_image_stream` can be used to generate a stream of images as the model is decoding. The detokenizer adds a slight delay for each image. Setting `log2_mid_count` to 3 results in a total of `2 ** 3 = 8` generated images. The only valid values for `log2_mid_count` are 0, 1, 2, 3, and 4. This is implemented in the colab.
|
||||
If the model is being used interactively (e.g. in a notebook) `generate_image_stream` can be used to generate a stream of images as the model is decoding. The detokenizer adds a slight delay for each image. Set `progressive_outputs` to `True` to enable this. An example is implemented in the colab.
|
||||
|
||||
```python
|
||||
image_stream = model.generate_image_stream(
|
||||
|
||||
Reference in New Issue
Block a user