Refactored to load the models once and run them multiple times

This commit is contained in:
Brett Kuprel
2022-06-29 09:42:12 -04:00
parent 1ef9b0b929
commit ed91ab4a30
11 changed files with 225 additions and 282 deletions
+4 -2
View File
@@ -34,7 +34,8 @@ class AttentionFlax(nn.Module):
self.v_proj = nn.Dense(self.embed_count, use_bias=False)
self.out_proj = nn.Dense(self.embed_count, use_bias=False)
def forward(self,
def forward(
self,
keys: jnp.ndarray,
values: jnp.ndarray,
queries: jnp.ndarray,
@@ -92,7 +93,8 @@ class DalleBartEncoderLayerFlax(nn.Module):
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
@nn.compact
def __call__(self,
def __call__(
self,
encoder_state: jnp.ndarray,
attention_mask: jnp.ndarray
) -> jnp.ndarray: