license and cleanup

This commit is contained in:
Brett Kuprel
2022-06-27 14:34:10 -04:00
parent 32b7aa196b
commit 18e6a9852f
7 changed files with 25 additions and 42 deletions
@@ -23,6 +23,7 @@ class GLUFlax(nn.Module):
z = self.fc2(z)
return z
class AttentionFlax(nn.Module):
head_count: int
embed_count: int
@@ -61,6 +62,7 @@ class AttentionFlax(nn.Module):
attention_output = self.out_proj(attention_output)
return attention_output
class EncoderSelfAttentionFlax(AttentionFlax):
def __call__(
self,
@@ -74,6 +76,7 @@ class EncoderSelfAttentionFlax(AttentionFlax):
queries /= queries.shape[-1] ** 0.5
return self.forward(keys, values, queries, attention_mask)
class DalleBartEncoderLayerFlax(nn.Module):
attention_head_count: int
embed_count: int
@@ -103,6 +106,7 @@ class DalleBartEncoderLayerFlax(nn.Module):
encoder_state = residual + encoder_state
return encoder_state, None
class DalleBartEncoderFlax(nn.Module):
attention_head_count: int
embed_count: int