back to linear attention

This commit is contained in:
Brett Kuprel
2022-06-27 13:19:03 -04:00
parent 018414a5c3
commit c936d26102
3 changed files with 45 additions and 69 deletions
-4
View File
@@ -101,10 +101,6 @@ def convert_dalle_bart_torch_from_flax_params(
k = i.replace(j, 'layers.' + str(l))
P[k] = P[i][l]
P.pop(i)
for i in list(P):
if '_proj' in i:
P[i] = P[i][:, :, None, None]
P['embed_tokens.weight'] = P.pop('embed_tokens.embedding')
P['embed_positions.weight'] = P.pop('embed_positions.embedding')