fix flatten code

This commit is contained in:
2024-05-18 18:31:07 +01:00
parent 9f6852ee23
commit 00749b0374

View File

@@ -1,3 +1,4 @@
import torch
import torch.nn as nn
@@ -174,7 +175,6 @@ class MaiRes(nn.Module):
)
self.avgpool = nn.AvgPool2d(kernel_size=RESNET_KERNEL_SIZE)
self.flatten = nn.Flatten()
self.fc = nn.Linear(in_features=1000, out_features=1)
def forward(self, x):
@@ -187,7 +187,7 @@ class MaiRes(nn.Module):
x = self.layer_blue(x)
x = self.avgpool(x)
x = self.flatten(x)
x = torch.flatten(x)
x = self.fc(x)
return x