license and cleanup
This commit is contained in:
@@ -2,6 +2,7 @@ from typing import List
|
||||
import torch
|
||||
from torch import nn, BoolTensor, FloatTensor, LongTensor
|
||||
|
||||
|
||||
class GLUTorch(nn.Module):
|
||||
def __init__(self, count_in_out, count_middle):
|
||||
super().__init__()
|
||||
@@ -21,6 +22,7 @@ class GLUTorch(nn.Module):
|
||||
z = self.fc2.forward(z)
|
||||
return z
|
||||
|
||||
|
||||
class AttentionTorch(nn.Module):
|
||||
def __init__(self, head_count: int, embed_count: int):
|
||||
super().__init__()
|
||||
|
||||
Reference in New Issue
Block a user