support bfloat16

This commit is contained in:
Brett Kuprel
2022-07-07 08:21:20 -04:00
parent 5f526e2109
commit da62298f06
9 changed files with 108 additions and 96 deletions
+2 -1
View File
@@ -40,7 +40,8 @@ class DecoderSelfAttention(AttentionBase):
queries = self.q_proj.forward(decoder_state)
attn_mask = self.token_indices < token_index + 1
attn_mask = attn_mask[None][[0] * decoder_state.shape[0]]
-attention_state[:, token_index] = torch.cat([keys, values])
+attn_state_new = torch.cat([keys, values]).to(attention_state.dtype)
+attention_state[:, token_index] = attn_state_new
batch_count = decoder_state.shape[0]
keys = attention_state[:batch_count]
values = attention_state[batch_count:]