From 00749b0374210db17335bd59b5c5398d7bd4dca6 Mon Sep 17 00:00:00 2001 From: Kenneth Date: Sat, 18 May 2024 18:31:07 +0100 Subject: [PATCH] fix flatten code --- resnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/resnet.py b/resnet.py index 0405114..ce3ddc7 100644 --- a/resnet.py +++ b/resnet.py @@ -1,3 +1,4 @@ +import torch import torch.nn as nn @@ -174,7 +175,6 @@ class MaiRes(nn.Module): ) self.avgpool = nn.AvgPool2d(kernel_size=RESNET_KERNEL_SIZE) - self.flatten = nn.Flatten() self.fc = nn.Linear(in_features=1000, out_features=1) def forward(self, x): @@ -187,7 +187,7 @@ class MaiRes(nn.Module): x = self.layer_blue(x) x = self.avgpool(x) - x = self.flatten(x) + x = torch.flatten(x) x = self.fc(x) return x