diff --git a/resnet.py b/resnet.py index 0405114..ce3ddc7 100644 --- a/resnet.py +++ b/resnet.py @@ -1,3 +1,4 @@ +import torch import torch.nn as nn @@ -174,7 +175,6 @@ class MaiRes(nn.Module): ) self.avgpool = nn.AvgPool2d(kernel_size=RESNET_KERNEL_SIZE) - self.flatten = nn.Flatten() self.fc = nn.Linear(in_features=1000, out_features=1) def forward(self, x): @@ -187,7 +187,7 @@ class MaiRes(nn.Module): x = self.layer_blue(x) x = self.avgpool(x) - x = self.flatten(x) + x = torch.flatten(x) x = self.fc(x) return x