diff --git a/resnet.py b/resnet.py index 1cf8d3d..579b8b9 100644 --- a/resnet.py +++ b/resnet.py @@ -1,3 +1,4 @@ +from torch import Module import torch.nn as nn @@ -23,7 +24,11 @@ def projection_shortcut(in_channels, out_channels): class ResidualBlock(nn.Module): - def __init__(self, in_channels, out_channels, stride=1, shortcut=None): + def __init__( + self, in_channels, out_channels, stride=1, shortcut=None, *args, **kwargs + ): + super().__init__(*args, **kwargs) + self.conv0 = nn.Sequential( nn.Conv2d( in_channels, out_channels, kernel_size=3, stride=stride, padding=1