MAML 及其优化改进

1. 什么是元学习?

进入智能化时代,人工智能的终极目标也正是让机器拥有人的智能。在人类拥有的众多能力中,最重要的能力当属学习能力,正是学习能力让人类能够不断掌握新的知识和技能,支持生命活动。因此,让机器学会学习是实现人工智能的主要目标,在机器学习中出现的元学习概念正试图实现这一目标。元学习由不同层次的学习抽象而成,使得目标机器学习系统能够通过学习自己的主要组件(例如优化器、损失函数、初始化方式、模型架构等)对自己的学习能力进行改进,故而元学习也被称为"学会如何学习 (learning to learn)"。

元学习通常将学习过程抽象为两个或更多层次:在最内层,模型学习任务相关的知识(例如在新的数据集上对模型进行精调);在最外层,模型则需要学习"跨任务"知识(例如通过学习在不同任务之间进行迁移)。若最内层的组件存在可学习参数,则最外层优化时可以通过对这些组件进行"元学习",进而能够对这些最内层的组件进行自动化学习。

2. 元学习工作流程

以 MAML, Model-Agnostic Meta-Learning 为例,MAML 是一个经典的元学习框架,模型学习神经网络的参数 $\theta$(初始化为 $\theta=\theta_0$)。针对特定任务的支持集(Support Set)$\mathcal{S}=\{x_S,y_S\}$,元学习先利用较少次数($N=1\ldots 5$)的标准SGD算法对模型进行优化,然后通过二次学习得到对任务的目标集(Target Set)$\mathcal{T}=\{x_T,y_T\}$ 泛化性能较好的模型参数。

具体来说,在元学习任务中,给定一个任务,任务由支持集和目标集两个集合组成,其中的支持集由若干批量的输入/输出对(${x_S,y_S}$)组成,目标集则是一个小数量的验证集合,同样由输入/输出对(${x_T,y_T}$)组成,元学习的执行过程如下:

  • 初始化模型参数 $\theta_0=\theta$。
  • 开始进行内层循环的优化(inner loop optimization)
    • 从第一个步骤开始,在每个步骤 $i$, 神经网络 $f$ 的参数设置为 $\theta_{i-1}$
    • 输入样本,得到支持集预测结果 $f(x_s;\theta_{i-1})$
    • 根据支持集标签 $y_S$,由损失函数 $L$ 计算任务的支持集损失 $L_{i-1}^S$
    • 由 SGD 更新当前步骤的模型参数 $\theta_{i}=\theta_{i-1}-\alpha\nabla_{\theta_{i-1}}L_{i-1}^S$
    • 重复上述内层循环 $N$ 次,直到得到模型参数 $\theta_N$
  • 由内层优化更新后的模型参数 $\theta_N$ 得到目标集预测结果 $f(x_T,\theta_N)$
  • 根据目标集标签 $y_T$,由损失函数 $L$ 计算任务的目标集损失 $L_N^T$
  • 计算目标集损失对 $\theta=\theta_0$ 的梯度 $\nabla_{\theta_{0}}L_{N}^T$

在上述更新过程中,最后对目标梯度的计算包括了内层循环的梯度计算和更新,如此学习得到的初始化参数对特定任务有更好的泛化性,即:通过优化过程本身得到的梯度进行反向传播,可以得到更精确、信息更丰富的梯度,使模型学习更高效。

MAML 实现中,通常会评估一批任务,并使用这些任务的损失之和或均值更新模型。元学习可以看作是对模型初始化过程的改进,其目的在于使模型获得"学习调参"的能力,让模型能根据已有知识快速学习新的任务。例如,对于预训练方法,通常是人为调参之后获得初始化模型,之后再根据特定任务对模型进行训练/调优;元学习则通过内层循环,由支持集对应的任务先学习得到较好的参数,然后利用该参数对特定任务进行训练。普通方法中,数据通常分为训练集、验证集和测试集,通过数据对模型进行优化。元学习在此基础上将学习任务分为训练任务和测试任务,其中的训练任务由许多子任务组成,目的在于学习一个较好的参数,测试任务在该参数的基础上针对特定任务进行再优化。

3. 改进的 MAML 实现

元学习的理念简单、优雅、有效,但是,若针对复杂的系统, 使用 MAML 这样的原型系统作为基础时,研究人员发现 MAML 对超参数、模型架构改变很敏感,使训练过程不稳定。故 Antreas Antoniou 等分析了 MAML,并实现其训练过程的稳定化。

在其分析中,作者发现使用卷积步长、添加更多层网络等修改都会使模型训练过程不稳定(训练损失振荡),造成模型需要更长时间才能收敛,最终的泛化性能也比稳定收敛情况下的性能差。针对这个问题,作者研究认为 DNN 训练不稳定最常见的原因是梯度退化题,例如梯度爆照、梯度消失等,实验还发现内层循环的次数也会影响训练稳定性。针对该问题,作者提出解决方案:

  • 在每步内层循环之后引入显式的梯度计算目标集损失
  • 将每步损失的加权平均作为最终的优化损失
  • 添加 MAML 损失使用的隐式梯度,使得改进的模型仍然重点关注主要目标

3.1 提升训练稳定性 (MSL,多步学习)

改进的 MAML 和原始 MAML 区别在于计算每一步梯度用于更新 $\theta_{i}=\theta_{i-1}-\alpha\nabla_{\theta_{i-1}}L_{i-1}^S$ 后,在目标集使用当前参数 $\theta_i$ 计算目标集损失, $N$ 步之后得到的 $N$ 个目标集损失加权组合得到 $L_{0 \ldots N}^{T}=\sum_{i=0}^N{w_iL_{i}^T}$,外层循环使用该组合损失优化参数 $\theta$。训练初期,所有步骤的权重 $w_i$ 均等,后期则使得偏后的步骤权重更大。

3.2 提升收敛速度和泛化性能 (BN)

原始的 MAML 实现中,作者未在 BN 层中保存运行过程中的统计数据,相反,只用运行中每个 Batch 对应的统计数据做归一化处理。由于元学习过程中要对所有可能的均值和方差共享参数,这样的处理使得批归一化参数 $\beta$ 和 $\gamma$ 参数的优化非常复杂,且此时用于归一化的均值和标准差和真实值相去甚远,使得模型的泛化性能大幅下降,并降低了模型的收敛速度。

考虑到这些问题,MAML 为何不使用标准的 BN 做法,保存运行时的统计数据呢?作者用标准的 BN 实验了多次之后发现并没有实质效果,且 MAML 作者发现,保持 $\gamma$ 值不变,学习 $\beta$ 值,这样的做法也不甚有用。实际上,标准 BN 的做法假设初始化的模型以及后来更新迭代后的模型都有相似的特征分布,但是该假设在元学习中并不一定成立,由于模型需要快速学习以适应新任务,模型参数变化也会很快。因此,解决这个问题需要学习每一步更新中的 $\beta$ 和 $\gamma$ 参数,并在更新后将每次模型迭代的统计数据都进行保存,这样可以改进模型的收敛速度和泛化性能。

3.3 提升收敛速度和泛化性能 (LSLR,逐步可学习学习率)

内层循环的学习率设置并不简单,但是由于本身元学习策略就在于学习这些参数,因此可以将学习率的设置也作为优化对象。改进的做法使用 Meta-SGD 进行优化,但是对每个参数都学习一个学习率计算复杂,且需要占据大量存储空间。

实际上,在多步骤更新过程中,可以在每一次更新步后进行学习率的学习以及优化方向的学习。改进做法使用每一层、每一步的学习率和优化方向,允许网络在 N 步内层循环中针对不同网络层学习不同的学习率,甚至允许"负学习"情况出现,即学习率为负的情况,这样可以改进收敛速度以及泛化性能。

参考引用