save converted detokenizer params
This commit is contained in:
@@ -37,7 +37,8 @@ class DecoderSelfAttentionFlax(AttentionFlax):
|
||||
state_index
|
||||
)
|
||||
batch_count = decoder_state.shape[0]
|
||||
- keys, values = attention_state[:batch_count], attention_state[batch_count:]
+ keys = attention_state[:batch_count]
+ values = attention_state[batch_count:]
|
||||
|
||||
decoder_state = self.forward(
|
||||
keys,
|
||||
@@ -120,7 +121,7 @@ class SampleState:
|
||||
attention_state: jnp.ndarray
|
||||
|
||||
def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
|
||||
-     return a * logits[0, -1] + (1 - a) * logits[1, -1]
+     return (1 - a) * logits[0, -1] + a * logits[1, -1]
|
||||
|
||||
def keep_top_k(logits: jnp.ndarray, k: int) -> jnp.ndarray:
|
||||
top_logits, _ = lax.top_k(logits, k)
|
||||
|
||||
Reference in New Issue
Block a user