simplified flax attention and matched torch attention
This commit is contained in:
@@ -41,6 +41,10 @@ class AttentionFlax(nn.Module):
|
||||
queries: jnp.ndarray,
|
||||
attention_mask: jnp.ndarray
|
||||
) -> jnp.ndarray:
|
||||
keys = keys.reshape(keys.shape[:2] + (self.head_count, -1))
|
||||
values = values.reshape(values.shape[:2] + (self.head_count, -1))
|
||||
queries = queries.reshape(queries.shape[:2] + (self.head_count, -1))
|
||||
queries /= queries.shape[-1] ** 0.5
|
||||
attention_bias: jnp.ndarray = lax.select(
|
||||
attention_mask,
|
||||
jnp.full(attention_mask.shape, 0.0),
|
||||
@@ -70,11 +74,9 @@ class EncoderSelfAttentionFlax(AttentionFlax):
|
||||
encoder_state: jnp.ndarray,
|
||||
attention_mask: jnp.ndarray
|
||||
) -> jnp.ndarray:
|
||||
shape_split = encoder_state.shape[:2] + (self.head_count, -1)
|
||||
keys = self.k_proj(encoder_state).reshape(shape_split)
|
||||
values = self.v_proj(encoder_state).reshape(shape_split)
|
||||
queries = self.q_proj(encoder_state).reshape(shape_split)
|
||||
queries /= queries.shape[-1] ** 0.5
|
||||
keys = self.k_proj(encoder_state)
|
||||
values = self.v_proj(encoder_state)
|
||||
queries = self.q_proj(encoder_state)
|
||||
return self.forward(keys, values, queries, attention_mask)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user