ResidualBlock: Implementing Residual Connections
import torch
import torch.nn as nn

# ResidualBlock implementation
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Projection shortcut for dimension mismatch (channel count or spatial size changes)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)  # residual connection: add the (possibly projected) input
        out = self.relu(out)
        return out
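A block like this is typically stacked into a stage when building a ResNet-style network. The sketch below is illustrative only; the helper name make_stage, the channel counts, and the stage layout are assumptions, not part of the original code. The idea is that the first block in a stage absorbs any change in channels or stride, and the remaining blocks keep the shape.

def make_stage(in_channels, out_channels, num_blocks, stride):
    # First block may change the channel count and/or downsample;
    # the remaining blocks keep channels and spatial size fixed.
    layers = [ResidualBlock(in_channels, out_channels, stride=stride)]
    for _ in range(num_blocks - 1):
        layers.append(ResidualBlock(out_channels, out_channels, stride=1))
    return nn.Sequential(*layers)

# Example: a stage that goes from 64 to 128 channels and halves the spatial size
stage = make_stage(64, 128, num_blocks=2, stride=2)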
# Test cases
def test_residual_block():
    # Test input (batch_size, channels, height, width)
    x = torch.randn(1, 64, 32, 32)

    # Test 1: same number of channels, stride 1
    block1 = ResidualBlock(64, 64, stride=1)
    out1 = block1(x)
    print(f"Test 1: input shape={x.shape}, output shape={out1.shape}")

    # Test 2: more channels, stride 1
    block2 = ResidualBlock(64, 128, stride=1)
    out2 = block2(x)
    print(f"Test 2: input shape={x.shape}, output shape={out2.shape}")

    # Test 3: same number of channels, stride 2 (spatial size halved)
    block3 = ResidualBlock(64, 64, stride=2)
    out3 = block3(x)
    print(f"Test 3: input shape={x.shape}, output shape={out3.shape}")

    # Test 4: more channels, stride 2 (spatial size halved)
    block4 = ResidualBlock(64, 128, stride=2)
    out4 = block4(x)
    print(f"Test 4: input shape={x.shape}, output shape={out4.shape}")

test_residual_block()
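For reference, the four test cases above should produce the following output shapes, which follow from the conv arithmetic: a 3x3 convolution with padding 1 keeps the spatial size at stride 1 and halves it at stride 2. The small check below restates this as assertions; it assumes the same 1x64x32x32 input used in the test function.

# Expected shapes for the four test cases above
x = torch.randn(1, 64, 32, 32)
cases = [
    (64, 1, (1, 64, 32, 32)),    # Test 1: same channels, stride 1
    (128, 1, (1, 128, 32, 32)),  # Test 2: more channels, stride 1
    (64, 2, (1, 64, 16, 16)),    # Test 3: same channels, stride 2
    (128, 2, (1, 128, 16, 16)),  # Test 4: more channels, stride 2
]
for out_channels, stride, expected in cases:
    block = ResidualBlock(64, out_channels, stride=stride)
    assert block(x).shape == expected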