pre converting params to torch allows mega to run in standard colab runtime

This commit is contained in:
Brett Kuprel
2022-06-30 14:54:08 -04:00
parent de97fcf06b
commit b913b58353
5 changed files with 54 additions and 21 deletions
+3
View File
@@ -7,12 +7,15 @@ from .min_dalle_base import MinDalleBase
from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax
from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax
from .load_params import load_dalle_bart_flax_params
class MinDalleFlax(MinDalleBase):
def __init__(self, is_mega: bool, is_reusable: bool = True):
super().__init__(is_mega)
self.is_reusable = is_reusable
print("initializing MinDalleFlax")
self.model_params = load_dalle_bart_flax_params(self.model_path)
if is_reusable:
self.init_encoder()
self.init_decoder()