fix flatten code
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
@@ -174,7 +175,6 @@ class MaiRes(nn.Module):
|
||||
)
|
||||
|
||||
self.avgpool = nn.AvgPool2d(kernel_size=RESNET_KERNEL_SIZE)
|
||||
self.flatten = nn.Flatten()
|
||||
self.fc = nn.Linear(in_features=1000, out_features=1)
|
||||
|
||||
def forward(self, x):
|
||||
@@ -187,7 +187,7 @@ class MaiRes(nn.Module):
|
||||
x = self.layer_blue(x)
|
||||
|
||||
x = self.avgpool(x)
|
||||
x = self.flatten(x)
|
||||
x = torch.flatten(x)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
|
Reference in New Issue
Block a user