delete cache
This commit is contained in:
@@ -148,7 +148,7 @@ class DalleBartDecoderFlax(nn.Module):
|
||||
)
|
||||
self.layers = nn.scan(
|
||||
DalleBartDecoderLayerFlax,
|
||||
variable_axes = { "params": 0, "cache": 0 },
|
||||
variable_axes = { "params": 0 },
|
||||
split_rngs = { "params": True },
|
||||
in_axes = (nn.broadcast, 0, nn.broadcast, nn.broadcast),
|
||||
out_axes = 0,
|
||||
|
||||
Reference in New Issue
Block a user