license and cleanup
This commit is contained in:
@@ -132,7 +132,7 @@ def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
|
||||
return a * logits[0, -1] + (1 - a) * logits[1, -1]
|
||||
|
||||
def keep_top_k(logits: jnp.ndarray, k: int) -> jnp.ndarray:
|
||||
top_logits, top_tokens = lax.top_k(logits, k)
|
||||
top_logits, _ = lax.top_k(logits, k)
|
||||
suppressed = -jnp.inf * jnp.ones_like(logits)
|
||||
return lax.select(logits < top_logits[-1], suppressed, logits)
|
||||
|
||||
@@ -198,34 +198,6 @@ class DalleBartDecoderFlax(nn.Module):
|
||||
decoder_state = self.lm_head(decoder_state)
|
||||
return decoder_state, keys_state, values_state
|
||||
|
||||
def compute_logits(self,
|
||||
text_tokens: jnp.ndarray,
|
||||
encoder_state: jnp.ndarray,
|
||||
params: dict
|
||||
) -> jnp.ndarray:
|
||||
batch_count = encoder_state.shape[0]
|
||||
state_shape = (
|
||||
self.layer_count,
|
||||
batch_count,
|
||||
self.image_token_count,
|
||||
self.attention_head_count,
|
||||
self.embed_count // self.attention_head_count
|
||||
)
|
||||
keys_state = jnp.zeros(state_shape)
|
||||
values_state = jnp.zeros(state_shape)
|
||||
|
||||
logits, _, _ = self.apply(
|
||||
{ 'params': params },
|
||||
encoder_state = encoder_state,
|
||||
keys_state = keys_state,
|
||||
values_state = values_state,
|
||||
attention_mask = jnp.not_equal(text_tokens, 1),
|
||||
prev_token = self.start_token,
|
||||
token_index = 0
|
||||
)
|
||||
|
||||
return super_conditioned(logits, 10.0)
|
||||
|
||||
def sample_image_tokens(self,
|
||||
text_tokens: jnp.ndarray,
|
||||
encoder_state: jnp.ndarray,
|
||||
|
||||
Reference in New Issue
Block a user