fixed typing error for older python versions
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from typing import List
|
||||
import torch
|
||||
from torch import nn, BoolTensor, FloatTensor, LongTensor
|
||||
torch.set_grad_enabled(False)
|
||||
@@ -120,7 +121,7 @@ class DalleBartEncoder(nn.Module):
|
||||
super().__init__()
|
||||
self.embed_tokens = nn.Embedding(text_vocab_count, embed_count)
|
||||
self.embed_positions = nn.Embedding(text_token_count, embed_count)
|
||||
self.layers: list[EncoderLayer] = nn.ModuleList([
|
||||
self.layers: List[EncoderLayer] = nn.ModuleList([
|
||||
EncoderLayer(
|
||||
embed_count = embed_count,
|
||||
head_count = attention_head_count,
|
||||
|
||||
Reference in New Issue
Block a user