深度炼丹如同炖排骨一般,需要先大火全局加热,紧接着中火炖出营养,最后转小火收汁。
本文给出炼丹中的 “火候控制器”— 学习率的几种调节方法,框架基于 pytorch.
1. 自定义根据 epoch 改变学习率
PyTorch官网动态调整学习率1
2
3
4
5def adjust_learning_rate(optimizer, epoch):
"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
lr = args.lr * (0.1 ** (epoch // 30))
for param_group in optimizer.param_groups:
param_group['lr'] = lr
或者1
2
3
4def adjust_learning_rate(optimizer, epoch):
"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
for param_group in optimizer.param_groups:
param_group['lr'] = args.lr * (0.1 ** (epoch // 30))
注释:在调用此函数时需要输入所用的 optimizer 以及对应的 epoch ,并且 args.lr 作为初始化的学习率也需要给出。
使用代码示例:1
2
3
4
5optimizer = torch.optim.SGD(model.parameters(),lr = args.lr,momentum = 0.9)
for epoch in range(10):
adjust_learning_rate(optimizer,epoch)
train(...)
validate(...)
2. 针对模型的不同层设置不同的学习率
当我们在使用预训练的模型时,需要对分类层进行单独修改并进行初始化,其他层的参数采用预训练的模型参数进行初始化,这个时候我们希望在进行训练过程中,除分类层以外的层只进行微调,不需要过多改变参数,因此需要设置较小的学习率。而改正后的分类层则需要以较大的步子去收敛,学习率往往要设置大一点以 resnet101 为例,分层设置学习率。
1 | model = torchvision.models.resnet101(pretrained=True) |
注:large_lr_layers 学习率为 1e-2,small_lr_layers 学习率为 1e-4,两部分参数共用一个 momenum
3. 根据具体需要改变 lr
pytorch 中的ReduceLROnPlateau
1
class torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)
以 acc 为例,当 mode 设置为 “max” 时,如果 acc 在给定 patience 内没有提升,则以 factor 的倍率降低 lr。
使用方法示例:1
2
3
4
5
6
7optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'max',verbose=1,patience=3)
for epoch in range(10):
train(...)
val_acc = validate(...)
# 降低学习率需要在给出 val_acc 之后
scheduler.step(val_acc)
4. 手动设置 lr 衰减区间
使用方法示例1
2
3
4
5
6
7
8
9
10
11
12
13
14
15def adjust_learning_rate(optimizer, lr):
for param_group in optimizer.param_groups:
param_group['lr'] = lr
for epoch in range(60):
lr = 30e-5
if epoch > 25:
lr = 15e-5
if epoch > 30:
lr = 7.5e-5
if epoch > 35:
lr = 3e-5
if epoch > 40:
lr = 1e-5
adjust_learning_rate(optimizer, lr)
5. 余弦退火
1 | epochs = 60 |