diff --git a/resnet.py b/resnet.py index ce3ddc7..925f058 100644 --- a/resnet.py +++ b/resnet.py @@ -175,7 +175,7 @@ class MaiRes(nn.Module): ) self.avgpool = nn.AvgPool2d(kernel_size=RESNET_KERNEL_SIZE) - self.fc = nn.Linear(in_features=1000, out_features=1) + self.fc = nn.Linear(in_features=8192, out_features=1) def forward(self, x): x = self.conv(x)