refactored to load models once and run multiple times
This commit is contained in:
@@ -26,7 +26,8 @@ class DecoderCrossAttentionFlax(AttentionFlax):
|
||||
|
||||
|
||||
class DecoderSelfAttentionFlax(AttentionFlax):
|
||||
def __call__(self,
|
||||
def __call__(
|
||||
self,
|
||||
decoder_state: jnp.ndarray,
|
||||
keys_state: jnp.ndarray,
|
||||
values_state: jnp.ndarray,
|
||||
@@ -77,7 +78,8 @@ class DalleBartDecoderLayerFlax(nn.Module):
|
||||
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
|
||||
|
||||
@nn.compact
|
||||
def __call__(self,
|
||||
def __call__(
|
||||
self,
|
||||
decoder_state: jnp.ndarray,
|
||||
encoder_state: jnp.ndarray,
|
||||
keys_state: jnp.ndarray,
|
||||
@@ -173,7 +175,8 @@ class DalleBartDecoderFlax(nn.Module):
|
||||
self.final_ln = nn.LayerNorm(use_scale=False)
|
||||
self.lm_head = nn.Dense(self.image_vocab_count + 1, use_bias=False)
|
||||
|
||||
def __call__(self,
|
||||
def __call__(
|
||||
self,
|
||||
encoder_state: jnp.ndarray,
|
||||
keys_state: jnp.ndarray,
|
||||
values_state: jnp.ndarray,
|
||||
@@ -198,7 +201,8 @@ class DalleBartDecoderFlax(nn.Module):
|
||||
decoder_state = self.lm_head(decoder_state)
|
||||
return decoder_state, keys_state, values_state
|
||||
|
||||
def sample_image_tokens(self,
|
||||
def sample_image_tokens(
|
||||
self,
|
||||
text_tokens: jnp.ndarray,
|
||||
encoder_state: jnp.ndarray,
|
||||
prng_key: jax.random.PRNGKey,
|
||||
|
||||
Reference in New Issue
Block a user