0.3.13, simplified code, specify device when initializing MinDalle
This commit is contained in:
Vendored
+12
-5
@@ -135,6 +135,7 @@
|
||||
"\n",
|
||||
"model = MinDalle(\n",
|
||||
" dtype=getattr(torch, dtype),\n",
|
||||
" device='cuda',\n",
|
||||
" is_mega=True, \n",
|
||||
" is_reusable=True\n",
|
||||
")"
|
||||
@@ -196,14 +197,13 @@
|
||||
"grid_size = 5 #@param {type:\"integer\"}\n",
|
||||
"temperature = 2 #@param {type:\"slider\", min:0.01, max:3, step:0.01}\n",
|
||||
"supercondition_factor = 16 #@param {type:\"number\"}\n",
|
||||
"top_k = 256 #@param {type:\"integer\"}\n",
|
||||
"log2_mid_count = 3 if progressive_outputs else 0\n",
|
||||
"top_k = 128 #@param {type:\"integer\"}\n",
|
||||
"\n",
|
||||
"image_stream = model.generate_image_stream(\n",
|
||||
" text=text,\n",
|
||||
" seed=-1,\n",
|
||||
" grid_size=grid_size,\n",
|
||||
" log2_mid_count=log2_mid_count,\n",
|
||||
" progressive_outputs=progressive_outputs,\n",
|
||||
" temperature=temperature,\n",
|
||||
" top_k=int(top_k),\n",
|
||||
" supercondition_factor=float(supercondition_factor)\n",
|
||||
@@ -229,11 +229,18 @@
|
||||
},
|
||||
"gpuClass": "standard",
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "Python 3.9.13 64-bit",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
"name": "python",
|
||||
"version": "3.9.13"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
Reference in New Issue
Block a user