license and cleanup
This commit is contained in:
@@ -23,6 +23,7 @@ class GLUFlax(nn.Module):
|
||||
z = self.fc2(z)
|
||||
return z
|
||||
|
||||
|
||||
class AttentionFlax(nn.Module):
|
||||
head_count: int
|
||||
embed_count: int
|
||||
@@ -61,6 +62,7 @@ class AttentionFlax(nn.Module):
|
||||
attention_output = self.out_proj(attention_output)
|
||||
return attention_output
|
||||
|
||||
|
||||
class EncoderSelfAttentionFlax(AttentionFlax):
|
||||
def __call__(
|
||||
self,
|
||||
@@ -74,6 +76,7 @@ class EncoderSelfAttentionFlax(AttentionFlax):
|
||||
queries /= queries.shape[-1] ** 0.5
|
||||
return self.forward(keys, values, queries, attention_mask)
|
||||
|
||||
|
||||
class DalleBartEncoderLayerFlax(nn.Module):
|
||||
attention_head_count: int
|
||||
embed_count: int
|
||||
@@ -103,6 +106,7 @@ class DalleBartEncoderLayerFlax(nn.Module):
|
||||
encoder_state = residual + encoder_state
|
||||
return encoder_state, None
|
||||
|
||||
|
||||
class DalleBartEncoderFlax(nn.Module):
|
||||
attention_head_count: int
|
||||
embed_count: int
|
||||
|
||||
Reference in New Issue
Block a user