refactored to load models once and run multiple times

This commit is contained in:
Brett Kuprel
2022-06-29 09:42:12 -04:00
parent 1ef9b0b929
commit ed91ab4a30
11 changed files with 225 additions and 282 deletions
+8 -4
View File
@@ -26,7 +26,8 @@ class DecoderCrossAttentionFlax(AttentionFlax):
class DecoderSelfAttentionFlax(AttentionFlax):
def __call__(self,
def __call__(
self,
decoder_state: jnp.ndarray,
keys_state: jnp.ndarray,
values_state: jnp.ndarray,
@@ -77,7 +78,8 @@ class DalleBartDecoderLayerFlax(nn.Module):
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
@nn.compact
def __call__(self,
def __call__(
self,
decoder_state: jnp.ndarray,
encoder_state: jnp.ndarray,
keys_state: jnp.ndarray,
@@ -173,7 +175,8 @@ class DalleBartDecoderFlax(nn.Module):
self.final_ln = nn.LayerNorm(use_scale=False)
self.lm_head = nn.Dense(self.image_vocab_count + 1, use_bias=False)
def __call__(self,
def __call__(
self,
encoder_state: jnp.ndarray,
keys_state: jnp.ndarray,
values_state: jnp.ndarray,
@@ -198,7 +201,8 @@ class DalleBartDecoderFlax(nn.Module):
decoder_state = self.lm_head(decoder_state)
return decoder_state, keys_state, values_state
def sample_image_tokens(self,
def sample_image_tokens(
self,
text_tokens: jnp.ndarray,
encoder_state: jnp.ndarray,
prng_key: jax.random.PRNGKey,