save converted detokenizer params

This commit is contained in:
Brett Kuprel
2022-07-01 10:17:29 -04:00
parent 8b5960b687
commit e4c2be54cb
7 changed files with 35 additions and 32 deletions
+1 -1
View File
@@ -184,7 +184,7 @@ class DalleBartDecoderTorch(nn.Module):
decoder_state = self.final_ln(decoder_state)
logits = self.lm_head(decoder_state)
a = self.condition_factor
logits: FloatTensor = a * logits[0, -1] + (1 - a) * logits[1, -1]
logits: FloatTensor = (1 - a) * logits[0, -1] + a * logits[1, -1]
top_logits, _ = logits.topk(50, dim=-1)
probs = torch.where(