previous commit broke flax model, fixed now
This commit is contained in:
@@ -85,7 +85,7 @@ def generate_image_tokens_torch(
|
||||
image_token_count: int
|
||||
) -> LongTensor:
|
||||
text_tokens = torch.tensor(text_tokens).to(torch.long)
|
||||
if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
|
||||
# if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
|
||||
encoder_state = encode_torch(
|
||||
text_tokens,
|
||||
config,
|
||||
|
||||
Reference in New Issue
Block a user