previous commit broke flax model, fixed now

This commit is contained in:
Brett Kuprel
2022-06-28 12:54:58 -04:00
parent 5aa6fe49bf
commit 9d6b6dcc92
4 changed files with 16 additions and 17 deletions
+1 -1
View File
@@ -85,7 +85,7 @@ def generate_image_tokens_torch(
image_token_count: int
) -> LongTensor:
text_tokens = torch.tensor(text_tokens).to(torch.long)
if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
# if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
encoder_state = encode_torch(
text_tokens,
config,