diff --git a/descreen/networks/resnet.py b/descreen/networks/resnet.py index c46de95..ebc2e75 100644 --- a/descreen/networks/resnet.py +++ b/descreen/networks/resnet.py @@ -21,13 +21,13 @@ def forward(self, x: Tensor) -> Tensor: def input_size_unchecked(self, output_size: int) -> int: size = output_size - for _ in range(len(self.blocks)): - size = input_size(size, 3) + for b in self.blocks: + size = b.input_size(size) return size def output_size_unchecked(self, input_size: int) -> int: size = input_size - for _ in range(len(self.blocks)): - size = output_size(size, 3) + for b in self.blocks: + size = b.output_size(size) return size