fixed typing error for older python versions

This commit is contained in:
Brett Kuprel
2022-07-02 09:06:22 -04:00
parent 2dadfdfb31
commit 313635e914
5 changed files with 16 additions and 14 deletions
+2 -1
View File
@@ -1,3 +1,4 @@
from typing import List
import torch
from torch import nn, BoolTensor, FloatTensor, LongTensor
torch.set_grad_enabled(False)
@@ -120,7 +121,7 @@ class DalleBartEncoder(nn.Module):
super().__init__()
self.embed_tokens = nn.Embedding(text_vocab_count, embed_count)
self.embed_positions = nn.Embedding(text_token_count, embed_count)
self.layers: list[EncoderLayer] = nn.ModuleList([
self.layers: List[EncoderLayer] = nn.ModuleList([
EncoderLayer(
embed_count = embed_count,
head_count = attention_head_count,