🔧 Custom Layers
PyTorch lets you create custom layers and modules to implement functionality that the built-in layers do not cover.
📝 Basic Custom Layers
Subclassing nn.Module
import torch
import torch.nn as nn

class MyLinear(nn.Module):
    """Custom linear layer"""
    def __init__(self, in_features, out_features):
        super().__init__()
        # Define the learnable parameters
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        return x @ self.weight.t() + self.bias

# Usage
layer = MyLinear(10, 5)
x = torch.randn(3, 10)
y = layer(x)
print(f"Output shape: {y.shape}")  # [3, 5]

# Parameters are registered automatically
for name, param in layer.named_parameters():
    print(f"{name}: {param.shape}")
The Role of nn.Parameter
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        # Wrapped in nn.Parameter: registered as a learnable parameter
        self.learnable = nn.Parameter(torch.randn(3, 3))
        # Plain tensor: not registered, so it will never be optimized
        self.not_learnable = torch.randn(3, 3)

    def forward(self, x):
        return x @ self.learnable

model = Demo()
print("Learnable parameters:")
for name, param in model.named_parameters():
    print(f"  {name}: {param.shape}")
# Only `learnable` is printed
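For non-learnable state that should still travel with the module (into state_dict, onto the GPU with .to()), register_buffer is the usual tool. A minimal sketch; the class and buffer names here are illustrative, not part of any standard API:

class DemoWithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        self.learnable = nn.Parameter(torch.randn(3, 3))
        # A buffer: saved in state_dict and moved by .to()/.cuda(), but not trained
        self.register_buffer("running_stat", torch.zeros(3))

    def forward(self, x):
        return x @ self.learnable

m = DemoWithBuffer()
print([name for name, _ in m.named_buffers()])  # ['running_stat']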
🎯 Practical Custom Layer Examples
1. Flatten Layer
class Flatten(nn.Module):
    """Flattens every dimension except the batch dimension"""
    def forward(self, x):
        return x.view(x.size(0), -1)

# Usage
flatten = Flatten()
x = torch.randn(2, 3, 4, 5)
y = flatten(x)
print(f"Input: {x.shape} → Output: {y.shape}")  # [2, 3, 4, 5] → [2, 60]
2. Squeeze-and-Excitation Block
class SEBlock(nn.Module):
    """Squeeze-and-Excitation block for channel attention"""
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        # Squeeze: global average pooling
        y = self.squeeze(x).view(b, c)
        # Excitation: learn per-channel weights
        y = self.excitation(y).view(b, c, 1, 1)
        # Scale: reweight the input channels
        return x * y

# Usage
se = SEBlock(64)
x = torch.randn(2, 64, 32, 32)
y = se(x)
print(f"Output shape: {y.shape}")  # [2, 64, 32, 32]
3. Residual Block
class ResidualBlock(nn.Module):
    """Residual block with two 3x3 convolutions"""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual  # residual (skip) connection
        return torch.relu(out)

# Usage
res_block = ResidualBlock(64)
x = torch.randn(2, 64, 32, 32)
y = res_block(x)
print(f"Output shape: {y.shape}")
4. Self-Attention Layer
class SelfAttention(nn.Module):
    """Simple (single-head) self-attention layer"""
    def __init__(self, embed_dim):
        super().__init__()
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.scale = embed_dim ** 0.5

    def forward(self, x):
        # x: (batch, seq_len, embed_dim)
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)
        # Scaled dot-product attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        attention = torch.softmax(scores, dim=-1)
        # Weighted sum of the values
        output = torch.matmul(attention, V)
        return output

# Usage
attention = SelfAttention(256)
x = torch.randn(2, 10, 256)  # 2 samples, 10 tokens, 256 dimensions
y = attention(x)
print(f"Output shape: {y.shape}")  # [2, 10, 256]
🔄 Custom Activation Functions
Using torch.autograd.Function
import torch
import torch.nn as nn
from torch.autograd import Function

class CustomReLU(Function):
    """Custom forward and backward pass for ReLU"""
    @staticmethod
    def forward(ctx, input):
        # ctx stores whatever the backward pass will need
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

# Wrap it in an nn.Module
class CustomReLUModule(nn.Module):
    def forward(self, x):
        return CustomReLU.apply(x)

# Usage
relu = CustomReLUModule()
x = torch.randn(5, requires_grad=True)
y = relu(x)
y.sum().backward()
print(f"Gradient: {x.grad}")
Swish Activation
class Swish(nn.Module):
    """Swish activation: x * sigmoid(x)"""
    def forward(self, x):
        return x * torch.sigmoid(x)

# Or as a lambda
swish = lambda x: x * torch.sigmoid(x)
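Recent PyTorch versions ship this function built in as nn.SiLU (also exposed as torch.nn.functional.silu). A quick equivalence check against the class above:

x = torch.randn(4)
print(torch.allclose(nn.SiLU()(x), Swish()(x)))  # True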
📦 Composing Complex Modules
class ConvBlock(nn.Module):
    """Convolution block: Conv + BN + ReLU"""
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.block(x)

class DownBlock(nn.Module):
    """Downsampling block: returns the pooled output plus the pre-pool feature map,
    which the decoder needs as a skip connection (the pooled tensor is half the size)"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            ConvBlock(in_channels, out_channels),
            ConvBlock(out_channels, out_channels)
        )
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        skip = self.conv(x)
        return self.pool(skip), skip
class UpBlock(nn.Module):
    """Upsampling block"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, 2, stride=2)
        self.conv = nn.Sequential(
            ConvBlock(out_channels * 2, out_channels),
            ConvBlock(out_channels, out_channels)
        )

    def forward(self, x, skip):
        x = self.up(x)
        x = torch.cat([x, skip], dim=1)
        return self.conv(x)
# A simple U-Net
class SimpleUNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.down1 = DownBlock(in_channels, 64)
        self.down2 = DownBlock(64, 128)
        self.down3 = DownBlock(128, 256)
        self.bottom = nn.Sequential(
            ConvBlock(256, 512),
            ConvBlock(512, 512)
        )
        self.up3 = UpBlock(512, 256)
        self.up2 = UpBlock(256, 128)
        self.up1 = UpBlock(128, 64)
        self.final = nn.Conv2d(64, out_channels, 1)

    def forward(self, x):
        # Encoder: each DownBlock returns (pooled output, pre-pool skip feature map)
        d1, s1 = self.down1(x)
        d2, s2 = self.down2(d1)
        d3, s3 = self.down3(d2)
        # Bottleneck
        bottom = self.bottom(d3)
        # Decoder (with skip connections at matching resolutions)
        u3 = self.up3(bottom, s3)
        u2 = self.up2(u3, s2)
        u1 = self.up1(u2, s1)
        return self.final(u1)

# Test
model = SimpleUNet()
x = torch.randn(1, 1, 256, 256)
y = model(x)
print(f"Input: {x.shape} → Output: {y.shape}")
📝 Best Practices
1. Initialize Weights in reset_parameters
import math

class MyLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.empty(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return x @ self.weight.t() + self.bias
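Keeping initialization inside reset_parameters also makes it easy to re-initialize an entire model later via Module.apply, which visits every submodule recursively. A small sketch; reinit is a hypothetical helper, not a PyTorch function:

def reinit(module):
    # Call reset_parameters on every submodule that defines it
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()

model = nn.Sequential(MyLayer(10, 20), nn.ReLU(), MyLayer(20, 5))
model.apply(reinit)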
2. Add extra_repr for Readable Printing
class MyLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.randn(out_features, in_features))

    def forward(self, x):
        return x @ self.weight.t()

    def extra_repr(self):
        return f'in_features={self.in_features}, out_features={self.out_features}'

layer = MyLayer(10, 5)
print(layer)
# MyLayer(in_features=10, out_features=5)
🏋️ Exercise
# Exercise: implement a Layer Normalization layer
# Formula: y = (x - mean) / sqrt(var + eps) * gamma + beta
# where gamma and beta are learnable parameters
# Your code:

Click to reveal the answer
class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        # Biased variance (divide by N), with eps inside the square root,
        # matching the formula above and the behavior of nn.LayerNorm
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        return self.gamma * (x - mean) / torch.sqrt(var + self.eps) + self.beta

# Test
ln = LayerNorm(256)
x = torch.randn(2, 10, 256)
y = ln(x)
print(f"Output shape: {y.shape}")