diff --git a/resnet.py b/resnet.py index 925f058..e3b4402 100644 --- a/resnet.py +++ b/resnet.py @@ -187,7 +187,7 @@ class MaiRes(nn.Module): x = self.layer_blue(x) x = self.avgpool(x) - x = torch.flatten(x) + x = x.view(x.size(0), -1) x = self.fc(x) return x