fix ResidualBlock init call

This commit is contained in:
2024-05-18 01:40:11 +01:00
parent 0ca7d5b8ec
commit 8f8da41251

View File

@@ -1,3 +1,4 @@
from torch import Module
import torch.nn as nn
@@ -23,7 +24,11 @@ def projection_shortcut(in_channels, out_channels):
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1, shortcut=None):
def __init__(
self, in_channels, out_channels, stride=1, shortcut=None, *args, **kwargs
):
super().__init__(*args, **kwargs)
self.conv0 = nn.Sequential(
nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=stride, padding=1