ResidualBlock 实现残差连接
ResidualBlock 的具体作用:
缓解梯度消失:通过直接将输入信息传递到输出,残差连接为梯度提供了额外的传播路径,使网络能够更有效地学习。
促进特征复用:残差块允许网络保留输入中的重要特征,同时学习增量变化,有助于保留低级特征。
提高训练稳定性:实验表明,带有残差连接的网络通常比没有残差连接的网络更容易训练,并且在相同深度下表现更好
import torch
import torch.nn as nn
# 你的ResidualBlock实现
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
# Handle cases of dimension mismatch
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
out = self.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = self.relu(out)
return out
# 测试用例
def test_residual_block():
# 测试用输入 (batch_size, channels, height, width)
x = torch.randn(1, 64, 32, 32)
# 测试1: 通道数不变,步长为1
block1 = ResidualBlock(64, 64, stride=1)
out1 = block1(x)
print(f"测试1: 输入尺寸={x.shape}, 输出尺寸={out1.shape}")
# 测试2: 通道数增加,步长为1
block2 = ResidualBlock(64, 128, stride=1)
out2 = block2(x)
print(f"测试2: 输入尺寸={x.shape}, 输出尺寸={out2.shape}")
# 测试3: 通道数不变,步长为2 (空间尺寸减半)
block3 = ResidualBlock(64, 64, stride=2)
out3 = block3(x)
print(f"测试3: 输入尺寸={x.shape}, 输出尺寸={out3.shape}")
# 测试4: 通道数增加,步长为2 (空间尺寸减半)
block4 = ResidualBlock(64, 128, stride=2)
out4 = block4(x)
print(f"测试4: 输入尺寸={x.shape}, 输出尺寸={out4.shape}")
test_residual_block()
测试1: 输入尺寸=torch.Size([1, 64, 32, 32]), 输出尺寸=torch.Size([1, 64, 32, 32]) 测试2: 输入尺寸=torch.Size([1, 64, 32, 32]), 输出尺寸=torch.Size([1, 128, 32, 32]) 测试3: 输入尺寸=torch.Size([1, 64, 32, 32]), 输出尺寸=torch.Size([1, 64, 16, 16]) 测试4: 输入尺寸=torch.Size([1, 64, 32, 32]), 输出尺寸=torch.Size([1, 128, 16, 16])
Last updated
Was this helpful?