链接: https://arxiv.org/abs/2011.00209
从基于梯度的元学习入手, 这篇文章将重点放到了inner-loop optimization的过程上.
以往基于MAML的工作大多都是面向不同的任务, 寻找一个合适的初始化值. 然而这篇工作主要通过调整inner-loop optimization来改善元学习性能.
具体方案为采取一个额外的模型g, 以base learner的参数和inner loop optimization中的梯度作为模型输入, 输出梯度更新中的学习率α和权重衰减参数β(也就是l2正则项求梯度后再整理得到的θ前面的系数).
模型g的训练就是在outer-loop上用不同任务通过inner-loop finetune后得到的loss更新.
文章细节和总结
首先出发点来说这篇文章是比较新的, 但是这个出发点不算是凭空出来的, 作者Sungyong Baik之前有一篇工作learn2forget, 也是在inner-loop optimization过程中引入模型作优化, 取得了不错的效果. 可能在此基础上, 他才锚定了inner-loop optimization的基础优化过程, 进一步提出了本篇工作.
本篇论文核心想法和设计框架都比较简单, 不涉及理论, 是一篇以实验为主导的论文, 论文正文就4面半, 剩下实验又写了4面半.
简述论文思想:
传统的基于梯度的元学习在学习过程当中分为外循环和内循环. 内循环更新方式大部分都是卡死的例如采取SGD更新base learner, 不会针对不同任务做出调整. 另外有针对任务做出调整的, 但是采取的是直接输出更新后模型参数的方法(也就是Model-based meta-learning). 那么这篇论文就是取道中间采取一种按传统梯度更新公式来, 但是又考虑了不同任务的更新策略.
这种更新策略说白了, 就是通过学习率和权重衰减参数两个超参来影响更新过程, 而这两个超参由单独的模型g另外根据任务生成(实际上是基于base learner的参数和反馈的梯度, 直觉上这两个组合起来应该可以反应出特定任务信息, 但是严谨的对等或等价证明有待考证). 通过模型g, inner-loop optimization的每一个任务的每一步更新都会产生新的学习率和权重衰减参数.
模型g, 文章中采取的是三层感知机每层加ReLu. 模型具体输入为base learner每层内的参数和梯度求均值, 那么n层base learner最后得到2*n个输入值. 三层感知机隐藏层大小和输入一样. 最后输出也是2*n, 分别代表base learner每层参数更新会采用的学习率和权重衰减参数. 所以base learner同一层参数在每一步更新中采取的是同一个学习率和权重衰减, 而不同层之间则采取不同的学习率和权重衰减.
实验部分:
大体上分, 该篇文章做了6组实验(3组有效性证明+3组消融):
一, ALFA+MAML/Random Init+L2F 在miniImageNet和tiredImageNet
1. L2F也是他的工作, 并且严格意义上它不算是MAML的变体, 它没有严格定义内循环外循环, 而是借用MAML或其他元学习算法的结构, 更类似于一种插件.
2. 对比当中, 附加了LEO和MetaOpt, 虽然效果赶不上但是下面又附加了理由, 很可能是审稿人质疑缺少其他元学习对比后添加的.
二, cross domain 在miniImageNet上训练, 在CUB上测试
和第一组实验差距不大, 但是采取cross domain. 从摘要出发, 作者考虑过与任务挂钩的更新方式具有更强的灵活性, 在train和test差距大的时候具有更好的表现.
三, Meta-Dataset上的表现
Meta-Dataset应该已经成为meta-learning性能检验的主流了, 这里除了MAML还增加了Proto-MAML. 跨元学习算法难度是比较大的, 因为复现难度很大. 所以这里可能是直接用的别人的库.
四, 消融实验中验证per step和per layer
主要存在两个对比, 一个是对比每步都产生新的α和β与每个inner-loop只产生一个α和β, 一个是对比每一层用不同的α和β和所有层都用一样的α和β.
其余还有MAML和Random Init对比 以及α和β对结果影响对比.
五, 对比不同步数下inner-loop optimization结果的影响.
每种步数(1~5)下ALFA+MAML都比MAML 5步强
六, 考虑输入的影响
只输入base learner权重/ 只输入gradient/ 全部都输入
最后当然是全部都输入最好
最后还有两组额外型补充实验
Few-shot regression 上ALFA表现不错
还把α和β具体取值做统计, 按步数排列
文章不错 但是代码基于MAML++, 也可以看出来尽可能用别人效果已经比较好的代码会更容易实现一些.