license and cleanup

This commit is contained in:
Brett Kuprel
2022-06-27 14:34:10 -04:00
parent 32b7aa196b
commit 18e6a9852f
7 changed files with 25 additions and 42 deletions
+1 -29
View File
@@ -132,7 +132,7 @@ def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
return a * logits[0, -1] + (1 - a) * logits[1, -1]
def keep_top_k(logits: jnp.ndarray, k: int) -> jnp.ndarray:
top_logits, top_tokens = lax.top_k(logits, k)
top_logits, _ = lax.top_k(logits, k)
suppressed = -jnp.inf * jnp.ones_like(logits)
return lax.select(logits < top_logits[-1], suppressed, logits)
@@ -198,34 +198,6 @@ class DalleBartDecoderFlax(nn.Module):
decoder_state = self.lm_head(decoder_state)
return decoder_state, keys_state, values_state
def compute_logits(self,
text_tokens: jnp.ndarray,
encoder_state: jnp.ndarray,
params: dict
) -> jnp.ndarray:
batch_count = encoder_state.shape[0]
state_shape = (
self.layer_count,
batch_count,
self.image_token_count,
self.attention_head_count,
self.embed_count // self.attention_head_count
)
keys_state = jnp.zeros(state_shape)
values_state = jnp.zeros(state_shape)
logits, _, _ = self.apply(
{ 'params': params },
encoder_state = encoder_state,
keys_state = keys_state,
values_state = values_state,
attention_mask = jnp.not_equal(text_tokens, 1),
prev_token = self.start_token,
token_index = 0
)
return super_conditioned(logits, 10.0)
def sample_image_tokens(self,
text_tokens: jnp.ndarray,
encoder_state: jnp.ndarray,