元学习(Meta-learning)——让机器学习如何学习

1 元学习概述

元学习的意思即“学会如何学习” 。 在机器学习中,工作量最大也是最无聊的事情就是调参。我们针对每一个任务从头开始进行这种无聊的调参,然后耗费大量的时间去训练并测试效果。因此,一个直观的想法是:我们是否能让机器自己学会调参,在遇到相似任务时能够触类旁通、举一反三,用不着我们从头开始调参,也用不着大量标签数据重新进行训练。
通常的机器学习是针对一个特定的任务找到一个能够实现这个任务的function,例如猫和狗的分类任务。而元学习的目标就是要找到一个Function能够让机器自动学习原来人为确定的一些超参(Hyper-parameter),如初始化参数\theta_0、学习速率\eta、网络架构等,元学习的分类就是看学习的是什么超参。这个Function用F_{\phi}表示,F_{\phi}不是针对某一个特定任务的,而是针对一群类似的任务,例如这些任务可能包括猫和狗的分类、橘子和苹果的分类、自行车和摩托车的分类等等。这个F_{\phi}是要帮这一群类似任务找到一个好的超参,在下次再遇到相似任务的时候,初始化参数可以直接用上,用不着我们再调参了。

image.png

元学习是跨任务学习(multi-task learning),因此它需要收集多个类似任务的数据集。比如针对图片二分类任务,我们需要收集橙子和苹果训练数据和测试数据、自行车和汽车的训练数据和测试数据等等许多二分类任务的数据集。元学习的目标是:利用F_{\phi}找到最优的超参\phi,使各任务在超参\phi的基础上训练出最优参数后测试得到的损失值l^n的和最小。这句话讲起来比较难以理解,举个例子比较好明白:对于苹果和橙子的分类任务,在超参\phi的基础上利用训练数据集进行训练,得到最优参数\theta^{1*},然后再利用测试数据集对训练后的模型进行测试,测试得到的损失值使l^1;同理,可以得到自行车和汽车分类任务的测试损失值l^2,以及其他二分类任务的测试损失值l^n;元学习的目标就是要找到最优超参\phi,使所有任务的测试损失值之和最小。所以元学习的损失函数定义为,L(\phi)=\sum_{n=1}^N{l^n}这里每一个用于训练超参\phi的任务都称为训练任务,上面的N指所有训练任务的总数。如果在拿一个新的任务(该任务未在训练任务中出现过)来测试通过训练找到的超参\phi的效果,那么这个任务就称为测试任务。
我们可以看到在每一个训练任务中包含了训练数据和测试数据,当然在测试任务中也包含了训练数据和测试数据,这和普通机器学习是大不同的。这样听起来很容易让人迷糊,所以有的文献不叫训练数据和测试数据,而是把训练数据叫支持集(support set),把测试数据叫查询集(query set)。

image.png

元学习的目标是要找到超参\phi最小化损失函数L(\phi)如果能够计算梯度,那么用梯度下降法求解即可。但是有很多情况使无法求梯度的,例如对网络架构的优化,此时有些文献会采用强化学习或进化算法等方法进行求解。
image.png

2 MAML

2.1 MAML概述

在普通机器学习中,初始化参数往往是随机生成的,MAML聚焦于学习一个最好的初始化参数\phi。初始参数\phi不同,对于同一个任务n训练得到的最优参数\hat{\theta}^n不同,在任务n的测试数据集上损失值l^n(\hat{\theta}^n)不同。MAML的目标是找到最优的初始参数\phi,是所有任务的测试损失值最小,在遇到新任务时,只需基于少量标签对初始化参数\phi进行微调就可以获得很好的效果。这和前面提到的预训练有些相似,但也有些不同。

image.png

2.2 MAML的训练

MAML的训练使用梯度下降法:\phi \leftarrow \phi-\eta \nabla_{\phi} L(\phi)具体的数学推导不管它了,我们直接看上面的梯度下降是如何实施的(这里假设batch size是1):

  • ①假设刚开始的初始化参数的初始化参数是\phi_0
  • ②随机采样一个训练任务m
  • ③通过训练任务m的支持集(训练数据)求loss,然后更新一次参数得到任务m的最优参数\hat{\theta}^m,注意此时并没有更新\phi
  • ④通过训练任务m的查询集(测试数据)再求一次loss,计算梯度,然后用此梯度的方向更新\phi
  • ⑤回到②,重复②③④
    上述过程有一个很值得关注的地方是,对于任务m,更新一次参数就认为参数最优了,李宏毅认为作者之所以这么设置是因为MAML主要用于小样本学习,更新一次是怕发生过拟合的问题。
    之前说MAML和与训练模型很像,但是也有所不同。不同点是与训练模型用任务的训练数据求loss就直接更新\phi了,而不是在测试数据上二次求loss后才更新的。
image.png

结合上面的讲解,来看一看MAML原文的算法,如下图所示。首先随机在这个算法里采样一个batch的训练任务,注意这里的batch是任务而不是数据。对于这一个batch的所有任务:第5行对每一个训练任务T_i,通过支持集求loss,计算梯度;第6行根据第5行算出来的梯度更新一次参数得到\theta'_i,并且保存起来。假设一个batch有10个任务,那么这里就保存了10个模型参数\theta'_i。完成了一个batch所有任务的参数更新后,进行第8行:基于更新后的参数和所有任务的查询集计算出各自的loss,将这些loss求和,计算出梯度,利用该梯度更新初始参数。假设一个batch有10个任务,基于更新后的参数和这10个batch任务的查询集计算出10个loss,将这10个loss进行求和,并基于求和结果计算梯度,利用该梯度更新初始参数。

image.png

我截了一部分MAML的代码,通过分析代码就可以更好理解上述参数的更新过程了。以下是第一次更新参数的代码

for k in range(1, self.update_step):
       # 1.在支持集上计算loss
       logits = self.net(x_spt[i], fast_weights, bn_training=True)
       loss = F.cross_entropy(logits, y_spt[i])
       # 2. 利用上面的loss,计算梯度
       grad = torch.autograd.grad(loss, fast_weights)
       # 3. 更新参数:theta_pi = theta_pi – train_lr * grad
       fast_weights = list(map(lambda p: p[1] – self.update_lr * p[0], zip(grad, fast_weights)))
       # 4.基于更新后的参数,在查询集上计算loss
       logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
       loss_q = F.cross_entropy(logits_q, y_qry[i])
       # 5把所有loss加起来,并保存.
       losses_q[k + 1] += loss_q

以下是第二次更新参数的代码

# 将所有任务的查询集上的loss的和除以任务数目,求了个平均值
loss_q = losses_q[-1] / task_num
# 利用上面的loss算梯度,并更新初始化参数
self.meta_optim.zero_grad()
loss_q.backward()
self.meta_optim.step()

11.3 元学习在N-ways K-shot上的应用

N-way K-shot是典型的小样本学习问题。所谓N-way K-shot是指在每一个任务里面,有N个类别,每个类别有K个样本。Omniglot是一个典型例子它包含1632个不同的字符,每个字符只有20个样本。从上面1632字符中可以构建N-way K-shot任务。例如通过下面的方式构建一个小样本分类任务:抽出20个字符出来,里面每个字符只有1个样本,我们把这个数据集作为训练集(支持集),那就是20-ways 1-shot的问题。然后再在这20个字符中取1个样本出来,作为测试集(查询集),利用训练出来的模型来判断这个样本属于哪个字符。通过这种方式,可以构建出许多任务来,如果是20 ways的就可以构建出81个任务。这些任务又可以分为训练任务和测试任务,例如将81个任务中的60个任务作为训练任务,21个任务作为测试任务。

image.png

拥有了这些数据集后,就可以来测试MAML了。以下是MAML原文的测试结果。从测试结果来看,MAML处理N-way K-shot任务是非常棒的。
image.png

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 202,980评论 5 476
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,178评论 2 380
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 149,868评论 0 336
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,498评论 1 273
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,492评论 5 364
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,521评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,910评论 3 395
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,569评论 0 256
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,793评论 1 296
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,559评论 2 319
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,639评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,342评论 4 318
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,931评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,904评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,144评论 1 259
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,833评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,350评论 2 342

推荐阅读更多精彩内容