1 元学习概述
元学习的意思即“学会如何学习” 。 在机器学习中,工作量最大也是最无聊的事情就是调参。我们针对每一个任务从头开始进行这种无聊的调参,然后耗费大量的时间去训练并测试效果。因此,一个直观的想法是:我们是否能让机器自己学会调参,在遇到相似任务时能够触类旁通、举一反三,用不着我们从头开始调参,也用不着大量标签数据重新进行训练。
通常的机器学习是针对一个特定的任务找到一个能够实现这个任务的function,例如猫和狗的分类任务。而元学习的目标就是要找到一个Function能够让机器自动学习原来人为确定的一些超参(Hyper-parameter),如初始化参数、学习速率、网络架构等,元学习的分类就是看学习的是什么超参。这个Function用表示,不是针对某一个特定任务的,而是针对一群类似的任务,例如这些任务可能包括猫和狗的分类、橘子和苹果的分类、自行车和摩托车的分类等等。这个是要帮这一群类似任务找到一个好的超参,在下次再遇到相似任务的时候,初始化参数可以直接用上,用不着我们再调参了。
元学习是跨任务学习(multi-task learning),因此它需要收集多个类似任务的数据集。比如针对图片二分类任务,我们需要收集橙子和苹果训练数据和测试数据、自行车和汽车的训练数据和测试数据等等许多二分类任务的数据集。元学习的目标是:利用找到最优的超参,使各任务在超参的基础上训练出最优参数后测试得到的损失值的和最小。这句话讲起来比较难以理解,举个例子比较好明白:对于苹果和橙子的分类任务,在超参的基础上利用训练数据集进行训练,得到最优参数,然后再利用测试数据集对训练后的模型进行测试,测试得到的损失值使;同理,可以得到自行车和汽车分类任务的测试损失值,以及其他二分类任务的测试损失值;元学习的目标就是要找到最优超参,使所有任务的测试损失值之和最小。所以元学习的损失函数定义为,这里每一个用于训练超参的任务都称为训练任务,上面的N指所有训练任务的总数。如果在拿一个新的任务(该任务未在训练任务中出现过)来测试通过训练找到的超参的效果,那么这个任务就称为测试任务。
我们可以看到在每一个训练任务中包含了训练数据和测试数据,当然在测试任务中也包含了训练数据和测试数据,这和普通机器学习是大不同的。这样听起来很容易让人迷糊,所以有的文献不叫训练数据和测试数据,而是把训练数据叫支持集(support set),把测试数据叫查询集(query set)。
元学习的目标是要找到超参最小化损失函数如果能够计算梯度,那么用梯度下降法求解即可。但是有很多情况使无法求梯度的,例如对网络架构的优化,此时有些文献会采用强化学习或进化算法等方法进行求解。
2 MAML
2.1 MAML概述
在普通机器学习中,初始化参数往往是随机生成的,MAML聚焦于学习一个最好的初始化参数。初始参数不同,对于同一个任务训练得到的最优参数不同,在任务的测试数据集上损失值不同。MAML的目标是找到最优的初始参数,是所有任务的测试损失值最小,在遇到新任务时,只需基于少量标签对初始化参数进行微调就可以获得很好的效果。这和前面提到的预训练有些相似,但也有些不同。
2.2 MAML的训练
MAML的训练使用梯度下降法:具体的数学推导不管它了,我们直接看上面的梯度下降是如何实施的(这里假设batch size是1):
- ①假设刚开始的初始化参数的初始化参数是
- ②随机采样一个训练任务
- ③通过训练任务的支持集(训练数据)求loss,然后更新一次参数得到任务的最优参数,注意此时并没有更新。
- ④通过训练任务的查询集(测试数据)再求一次loss,计算梯度,然后用此梯度的方向更新
- ⑤回到②,重复②③④
上述过程有一个很值得关注的地方是,对于任务,更新一次参数就认为参数最优了,李宏毅认为作者之所以这么设置是因为MAML主要用于小样本学习,更新一次是怕发生过拟合的问题。
之前说MAML和与训练模型很像,但是也有所不同。不同点是与训练模型用任务的训练数据求loss就直接更新了,而不是在测试数据上二次求loss后才更新的。
结合上面的讲解,来看一看MAML原文的算法,如下图所示。首先随机在这个算法里采样一个batch的训练任务,注意这里的batch是任务而不是数据。对于这一个batch的所有任务:第5行对每一个训练任务,通过支持集求loss,计算梯度;第6行根据第5行算出来的梯度更新一次参数得到,并且保存起来。假设一个batch有10个任务,那么这里就保存了10个模型参数。完成了一个batch所有任务的参数更新后,进行第8行:基于更新后的参数和所有任务的查询集计算出各自的loss,将这些loss求和,计算出梯度,利用该梯度更新初始参数。假设一个batch有10个任务,基于更新后的参数和这10个batch任务的查询集计算出10个loss,将这10个loss进行求和,并基于求和结果计算梯度,利用该梯度更新初始参数。
我截了一部分MAML的代码,通过分析代码就可以更好理解上述参数的更新过程了。以下是第一次更新参数的代码
for k in range(1, self.update_step):
# 1.在支持集上计算loss
logits = self.net(x_spt[i], fast_weights, bn_training=True)
loss = F.cross_entropy(logits, y_spt[i])
# 2. 利用上面的loss,计算梯度
grad = torch.autograd.grad(loss, fast_weights)
# 3. 更新参数:theta_pi = theta_pi – train_lr * grad
fast_weights = list(map(lambda p: p[1] – self.update_lr * p[0], zip(grad, fast_weights)))
# 4.基于更新后的参数,在查询集上计算loss
logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
loss_q = F.cross_entropy(logits_q, y_qry[i])
# 5把所有loss加起来,并保存.
losses_q[k + 1] += loss_q
以下是第二次更新参数的代码
# 将所有任务的查询集上的loss的和除以任务数目,求了个平均值
loss_q = losses_q[-1] / task_num
# 利用上面的loss算梯度,并更新初始化参数
self.meta_optim.zero_grad()
loss_q.backward()
self.meta_optim.step()
11.3 元学习在N-ways K-shot上的应用
N-way K-shot是典型的小样本学习问题。所谓N-way K-shot是指在每一个任务里面,有N个类别,每个类别有K个样本。Omniglot是一个典型例子它包含1632个不同的字符,每个字符只有20个样本。从上面1632字符中可以构建N-way K-shot任务。例如通过下面的方式构建一个小样本分类任务:抽出20个字符出来,里面每个字符只有1个样本,我们把这个数据集作为训练集(支持集),那就是20-ways 1-shot的问题。然后再在这20个字符中取1个样本出来,作为测试集(查询集),利用训练出来的模型来判断这个样本属于哪个字符。通过这种方式,可以构建出许多任务来,如果是20 ways的就可以构建出81个任务。这些任务又可以分为训练任务和测试任务,例如将81个任务中的60个任务作为训练任务,21个任务作为测试任务。
拥有了这些数据集后,就可以来测试MAML了。以下是MAML原文的测试结果。从测试结果来看,MAML处理N-way K-shot任务是非常棒的。