back to linear attention
This commit is contained in:
@@ -101,10 +101,6 @@ def convert_dalle_bart_torch_from_flax_params(
|
||||
k = i.replace(j, 'layers.' + str(l))
|
||||
P[k] = P[i][l]
|
||||
P.pop(i)
|
||||
|
||||
for i in list(P):
|
||||
if '_proj' in i:
|
||||
P[i] = P[i][:, :, None, None]
|
||||
|
||||
P['embed_tokens.weight'] = P.pop('embed_tokens.embedding')
|
||||
P['embed_positions.weight'] = P.pop('embed_positions.embedding')
|
||||
|
||||
Reference in New Issue
Block a user