refactored to load models once and run multiple times
This commit is contained in:
@@ -34,7 +34,8 @@ class AttentionFlax(nn.Module):
|
||||
self.v_proj = nn.Dense(self.embed_count, use_bias=False)
|
||||
self.out_proj = nn.Dense(self.embed_count, use_bias=False)
|
||||
|
||||
def forward(self,
|
||||
def forward(
|
||||
self,
|
||||
keys: jnp.ndarray,
|
||||
values: jnp.ndarray,
|
||||
queries: jnp.ndarray,
|
||||
@@ -92,7 +93,8 @@ class DalleBartEncoderLayerFlax(nn.Module):
|
||||
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
|
||||
|
||||
@nn.compact
|
||||
def __call__(self,
|
||||
def __call__(
|
||||
self,
|
||||
encoder_state: jnp.ndarray,
|
||||
attention_mask: jnp.ndarray
|
||||
) -> jnp.ndarray:
|
||||
|
||||
Reference in New Issue
Block a user