Overview of optimizers for LLMs

Overview of optimizers for LLMs

Author

Updated

Jun, 22, 2026

Category

Adam

Adam 是一个针对高维参数空间的一阶优化器,Adam 基于 gradient 的一阶和二阶信息为不同的参数安排不同的学习率。 Adam 的来源是 adaptive moment estimation. Adam 主要是结合了 AdaGrad 和 RMSProp 两个算法的优点。

Adam 与 RMSProp 的区别在于:

  1. RMSProp 在 rescaled gradient 上进行 momentum 的计算然后更新,而 Adam 直接使用一阶和二阶矩来进行估计
  2. RMSProp 没有 bias-correction 项

Adam 的主要优势为:

  1. 参数更新的量级与 gradient 的 scaling 无关
  2. 步长被 stepsize 超参数限制
  3. 不要求目标函数 stationary
  4. 对于稀疏梯度 work 的比较好
  5. 优化器自带 annealing

Algorithm

Adam 的算法如下图所示

Adam Algorithm

我们优化的目标函数如下

minθf(θ)\min_{\theta}\quad f(\theta)

这里 ff 一般是一个神经网络。我们记 f(θ)f(\theta)θt\theta_t 处的梯度为 gt=θf(θt)g_t=\nabla_{\theta}f(\theta_t).

算法运行时,会更新梯度 mtm_t 以及梯度二阶矩 vtv_t 的 exponential moving average. 超参数 β1,β2\beta_1,\beta_2 负责控制 exponential decay rates. 这里 mtm_tvtv_t 分别是一阶动量(均值)和二阶动量(未中心化的 variance)的估计。由于 mtm_tvtv_t 的初始化都是 0, 因此他们会引入 bias, 作者在后续通过修正解决了这个问题。

假设 ϵ=0\epsilon=0, 如果除了当前时刻 tt 之外,之前所有时刻的梯度 gi=0,i<tg_i=0,i<t, 此时

mt=(1βt)gt,vt=(1β2)gt2m_t = (1-\beta_t)g_t, v_t=(1-\beta_2)g_t^2

修正后的一阶和二阶矩分别为

Δt=α(1β1)1β2t(1β1t)1β2\Delta_t = \alpha \frac{(1-\beta_1)\sqrt{1-\beta_2^t}}{(1-\beta_1^t)\sqrt{1-\beta_2}}

tt 足够大的时候, β1t0,β2t0\beta_1^t\to0, \beta_2^t\to0, 此时

Δt=α1β11β2\Delta_t = \alpha \frac{1-\beta_1}{\sqrt{1-\beta_2}}

如果之前所有时刻的梯度不全为 0, 则依据 Cauchy-Schwarz 不等式,我们有 (E[XY])2E[X2]E[Y2](\mathbb{E}[XY])^2\leq \mathbb{E}[X^2]\mathbb{E}[Y^2]. 令 X=1X=1, Y=gY=g, 则我们有

(E[g])2E[g2]E[g]E[g2]1(\mathbb{E}[g])^2\leq \mathbb{E}[g^2] \Rightarrow \frac{|\mathbb{E}[g]|}{\sqrt{\mathbb{E}[g^2]}}\leq 1

此时,我们有

E[gt]=m^t,E[gt2]=v^t\mathbb{E}[g_t] = \hat{m}_t, \mathbb{E}[g_t^2] = \hat{v}_t

因此,

Δt=αm^tv^t=αE[g]E[g2]α|\Delta_t| = \left|\alpha\frac{\hat{m}_t}{\sqrt{\hat{v}_t}}\right|=\alpha \left|\frac{|\mathbb{E}[g]|}{\sqrt{\mathbb{E}[g^2]}}\right|\leq\alpha

从而我们有

Δt{α1β11β2 if 1β1>1β2α otherwise|\Delta_t| \leq\begin{cases} \alpha \frac{1-\beta_1}{\sqrt{1-\beta_2}} & \text{ if }1-\beta_1>\sqrt{1-\beta_2}\\ \alpha &\text{ otherwise} \end{cases}

实际上,Δt\Delta_t 可以理解为一个 trust region, 可以用来保证当前更新的参数不会离原始参数太远。

作者定义 signal-noise ratio (SNR) 为

SNR=m^tv^tSNR = \frac{\hat{m}_t}{\sqrt{\hat{v}_t}}

当 SNR 较小时,说明此时的不确定性比较大,因此 Δt\Delta_t 也比较小。这就避免了模型朝错误的方向更新。也就是automatic annealing.

Δt\Delta_t 还对 gradient 的 scaling 有不变的性质,这是因为,

cm^tc2v^t=m^tv^t\frac{c\cdot\hat{m}_t}{\sqrt{c^2\cdot\hat{v}_t}} = \frac{\hat{m}_t}{\sqrt{\hat{v}_t}}

Bias Correction

上一节提到,Adam 算法的初始化是存在 bias 的,作者在本届就对齐进行了分析。令 gg 为目标函数 ff 的梯度,我们希望估计其二阶动量的期望.令 g1,,gTg_1,\dots,g_T 分别为 θ1,,θT\theta_1,\dots,\theta_T 处的梯度估计,其中 gtp(gt)g_t\sim p(g_t) 是对应时刻梯度的分布。令 v0=0v_0=0, 在 tt 时刻,我们有

vt=β2vt1+(1β2)gt2=(1β2)i=1tβ2tigi2v_t = \beta_2 v_{t-1} + (1-\beta_2)g_t^2 = (1-\beta_2)\sum_{i=1}^t\beta_2^{t-i} g_i^2

我们希望计算 E[vt]\mathbb{E}[v_t]E[gt2]\mathbb{E}[g_t^2] 之间的关系,我们有

E[vt]=[(1β2)i=1tβ2tigi2]=E[gt2](1β2)i=1tβ2ti+ζ=E[gt2](1β2t)+ζ\begin{aligned} \mathbb{E}[v_t] &= \left[(1-\beta_2)\sum_{i=1}^t\beta_2^{t-i} g_i^2\right]\\ &= \mathbb{E}[g_t^2]\cdot (1-\beta_2)\sum_{i=1}^t\beta_2^{t-i}+\zeta\\ &= \mathbb{E}[g_t^2](1-\beta_2^t)+\zeta \end{aligned}

其中当 E[gi2]\mathbb{E}[g_i^2] 为 stationary 时,ζ=0\zeta=0, 否则我们可以通过控制 β2\beta_2 来让 past gradient 保持在一个较小的规模。最后,我们剩下的就是 1β2t1-\beta_2^t, 这也是我们在算法中进行修正的地方。

对于一阶动量 mtm_t 的修正也是同理。

Convergence Analysis

作者在本节中使用了 online learning framework 来分写 Adam 的收敛性。给定一系列 convex cost function f1(θ),,fT(θ)f_1(\theta),\dots,f_T(\theta). 在 tt 时刻,我们的目标是基于上一个 cost function ft(θ)f_t(\theta) 来预测 θt\theta_t.

作者在这里使用 regret 来分析,记 ft(θ)f_t(\theta^*)tt 时刻最优的参数对应的 cost function, regret 定义为

R(T)=t=1T[ft(θt)ft(θ)]R(T) = \sum_{t=1}^T [f_t(\theta_t) - f_t(\theta^*)]

其中,

θ=argminθXt=1Tft(θ)\theta^* = \arg\min_{\theta\in\mathcal{X}}\sum_{t=1}^Tf_t(\theta)

则我们有如下的结论

Theorem 1 假设

  1. 函数 ftf_t 的梯度是有界的,即 ft(θ)2G\|\nabla f_t(\theta)\|_2\leq G, ft(θ)G\|\nabla f_t(\theta)\|_{\infty}\leq G_{\infty} 对任意 θRd\theta\in\mathbb{R}^d 都成立
  2. {θ1,,θT}\{\theta_1,\dots,\theta_T\} 中任意两个参数的距离都是有界的,即 θmθm2D\|\theta_m-\theta_m\|_2\leq D, θmθnD\|\theta_m-\theta_n\|_{\infty}\leq D_{\infty} 对任意 m,n{1,,T}m,n\in\{1,\dots,T\} 都成立
  3. β1,β2[0,1)\beta_1,\beta_2\in[0,1) 满足 β12β2<1\frac{\beta_1^2}{\sqrt{\beta_2}}<1αt=α/t\alpha_t=\alpha/\sqrt{t}, β1,t=β1λt1\beta_{1,t}=\beta_1\lambda^{t-1}, λ(0,1)\lambda\in(0,1), 则我们有
R(T)D22α(1β1)i=1dTv^T,i+α(1+β1)G(1β1)1β2(1γ)2i=1dg1:T,i2+i=1dD2G1β22α(1β1)(1λ)2R(T)\leq \frac{D^2}{2\alpha(1-\beta_1)}\sum_{i=1}^d\sqrt{T\hat{v}_{T,i}}+\frac{\alpha(1+\beta_1)G_{\infty}}{(1-\beta_1)\sqrt{1-\beta_2}(1-\gamma)^2}\sum_{i=1}^d\|g_{1:T,i}\|_2+\sum_{i=1}^d\frac{D_{\infty}^2G_{\infty}\sqrt{1-\beta_2}}{2\alpha(1-\beta_1)(1-\lambda)^2}

结果说明,当我们的 data feature 稀疏且梯度有界时我们有

i=1dg1:T,i2dGT\sum_{i=1}^d\|g_{1:T,i}\|_2 \ll dG_{\infty}\sqrt{T}

以及

i=1dTv^T,idGT\sum_{i=1}^d\sqrt{T\hat{v}_{T,i}} \ll dG_{\infty}\sqrt{T}

实际上,对于 Adam 以及 Adamgrad,这个上界可以优化到 O(logdT)O(\log d\sqrt{T}).

最终,我们可以证明 Adam 的收敛性

Corollary 1 假设

  1. 函数 ftf_t 的梯度是有界的,即 ft(θ)2G\|\nabla f_t(\theta)\|_2\leq G, ft(θ)G\|\nabla f_t(\theta)\|_{\infty}\leq G_{\infty} 对任意 θRd\theta\in\mathbb{R}^d 都成立
  2. {θ1,,θT}\{\theta_1,\dots,\theta_T\} 中任意两个参数的距离都是有界的,即 θmθm2D\|\theta_m-\theta_m\|_2\leq D, θmθnD\|\theta_m-\theta_n\|_{\infty}\leq D_{\infty} 对任意 m,n{1,,T}m,n\in\{1,\dots,T\} 都成立 则对 T1T\geq1, 我们有
R(T)T=O(1T) \frac{R(T)}{T}=O\left(\frac{1}{\sqrt{T}}\right)

Experiment

作者在 logistic regression, MLP, CNN 等三种模型架构上进行了实验。

Conclusion

作者在本文中提出了 Adam optimizer, 一个基于 AdaGrad 和 RMSProp 优点的优化器,作者通过理论验证了 Adam 的收敛性,然后通过实验验证了 Adam 的有效性。

    作者提出了一个针对 Adam 优化器的 weight decay 方法

    Introduction

    作者首先回顾了动态梯度算法如 AdaGrad, RMSProp, Adam 的进展。已有工作表明动态梯度算法的泛化性要比 SGD with momentum 要差。作者在本文中探究了在 SGD 和 Adam 中使用 L2 regularization 和 weight decay 对最终模型表现的影响。结果表明,模型泛化性较差的原因在于对于 Adam, L2 regularization 的效果要比 SGD 差。

    作者有如下发现:

    1. L2 regularization 和 weight decay 不等价。在 SGD 中,L2 regularization 是等价的,但是在 Adam 中这个结论不成立。具体来说,L2 regularization 对历史参数的惩罚要小于 weight decay
    2. L2 regularization 对 Adam 效果提升有效
    3. weight decay 对于 SGD 和 AdamW 都很有效,在 SGD 中,weight decay 与 L2 regularization 等价
    4. 最优的 weight decay 取决于 batch, batch 越大,最优的 weight decay 越小
    5. 通过 learning rate scheduler 可以进一步提高 Adam 的表现

    作者在本文中的主要贡献是通过解耦梯度更新中的 weight decay 来提高 Adam 的 regularization.

    作者的主要 motivation 是提升 Adam 表现,让其可以和 SGD with momentum 相比

    Method

    Weight decay 的定义如下

    θt+1=(1λ)θtαft(θt)(1)\theta_{t+1} = (1-\lambda)\theta_t - \alpha \nabla f_t(\theta_t) \tag{1}

    其中 λ\lambda 是 weight decay rate, ft(θt)\nabla f_t(\theta_t) 是第 tt 个 batch 的梯度,α\alpha 是学习率。

    首先,对于标准的 SGD 来说,weight decay 与 L2 regularization 等价

    Proposition 1 对于标准的 SGD 来说,对损失函数 ft(θt)f_t(\theta_t) 执行 weight decay (公式 (1)(1))与对损失函数 ft(θt)+λ/2θt22f_t(\theta_t)+\lambda'/2\|\theta_t\|_2^2 执行梯度下降算法是等价的,这里 λ=λ/α\lambda'=\lambda/\alpha

    证明比较简单,只需要写出损失函数的梯度下降更新公式即可。

    基于这个结论,大部分优化算法都将 L2 regularization 和 weight decay 看做是等价的。但实际上,这个结论对于 adaptive gradient 方法来说是不成立的。结论如下

    Proposition 2 令 OO 为一个 optimizer, 其目标函数为 ft(θ)f_t(\theta), 当不考虑 weight decay 时,梯度更新过程为 θt+1θtαMtft(θt)\theta_{t+1}\gets \theta_t-\alpha M_t\nabla f_t(\theta_t). 当考虑 weight decay 时,梯度更新过程为 θt+1(1λ)θtαMtft(θt)\theta_{t+1}\gets (1-\lambda)\theta_t-\alpha M_t\nabla f_t(\theta_t). 如果 MtkIM_t\neq kI, 则不存在 λ\lambda', 使得 OO 在优化目标函数 ftreg(θ)=ft(θ)+λ/2θ22f_t^{reg}(\theta)=f_t(\theta)+\lambda'/2\|\theta\|_2^2 时,不考虑 weight decay 的梯度更新与 OO 在优化目标函数 ft(θ)f_t(\theta) 时,考虑 weight decay 的梯度更新等价。

    证明比较简单,只需要写出两个目标函数对应的梯度更新公式即可。

    作者通过分析发现,在 adaptive gradient 方法中,对于 L2 regularization,梯度和 regularization 是打包在一起考虑的。而 weight decay 是分开考虑的。这就导致了对于梯度比较大的权重,L2 regularization 的学习率较小,从而 regularization 效应减弱。而 weight decay 中,这种效应则不存在。因此 weight decay 的 regularization 效应更强。

    作者通过这个分析,给出了一个 weight decay 与 L2 regularization 相等的条件

    Proposition 3 令 OO 为一个 optimizer, 其目标函数为 ft(θ)f_t(\theta), 当不考虑 weight decay 时,梯度更新过程为 θt+1θtαMtft(θt)\theta_{t+1}\gets \theta_t-\alpha M_t\nabla f_t(\theta_t). 当考虑 weight decay 时,梯度更新过程为 θt+1(1λ)θtαMtft(θt)\theta_{t+1}\gets (1-\lambda)\theta_t-\alpha M_t\nabla f_t(\theta_t). 如果 Mt=diag(s)1M_t= \mathrm{diag}(s)^{-1} (si>0,is_i>0,\forall i), 则 OO 在优化目标函数

    ftreg(θ)=ft(θ)+λ2αθs22 f_t^{reg}(\theta)=f_t(\theta)+\frac{\lambda'}{2\alpha}\|\theta\odot \sqrt{s}\|_2^2

    时,不考虑 weight decay 的梯度更新与 OO 在优化目标函数 ft(θ)f_t(\theta) 时,考虑 weight decay 的梯度更新等价。

    上面的结论显示,对于比较大的 preconditioner sis_i, 其在相比于 L2 regularization 被 regularized 的效应更强。

    为了解耦这两个参数,作者提出了 SGDW 算法,其 weight decay 和梯度更新同时进行,算法如下图所示

    SGDW algorithm

    在算法中,为了支持同时给 α\alphaλ\lambda 做 scheduling, 作者提出了一个 scaling factor ηt\eta_t, ηt\eta_t 由用户定义的 scheduler SetScheduleMultiplier(t) 决定。此时,针对 SGD with momentum 的 weight decay 与 L2 regularization 是等价的

    同理,我们也可以对 Adam 算法实行同样的操作,算法如下图所示

    AdamW algorithm

    Conclusion

    作者在本文中分析了 adaptive gradient 方法中 L2 regularization 与 weight decay 的不一致性。基于分析,作者提出了 SGDW 和 AdamW 两个优化算法。

      Muon (MomentUm Orthogonalized by Newton-Schulz) 是一个针对二维神经网络的优化器,它基于 SGD-momentum 改进,增加了一个 Newton-Schulz 的后处理步骤

      Method

      Newton-Schulz (NS) 的目的是用一个正交矩阵近似一个给定矩阵,即

      Ortho(G)=argminO{OGF:either OTO=I or OOT=I}\mathrm{Ortho}(G) = \arg\min_{O} \{\|O-G\|_F: \text{either } O^TO=I\text{ or } OO^T=I\}

      也就是说,NS iteration 将 SDG-moment 的更新矩阵替换为了“最近的” semi-orthogonal matrix. 这等价于将更新矩阵替换为 UVTUV^T, 其中 USVTUSV^T 是更新矩阵的 SVD 分解。

      [!tip] 作者观察到,对于 SGD-momentum 和 Adam 来说,其在基于 transformer 的神经网络里有非常高的 condition number, 也就是 optimizer 仅在少数几个方向上进行优化。作者认为,通过正交化,可以有效提高模型在其他方向上的更新速度,进而提高模型表现

      Newton-Schulz

      作者提到,正交化矩阵的方法有很多,比如 SVD 分解,但是其问题是非常慢,还有 Coupled Newton iteration, 但是其精度要求非常高,必须要在 float32 以上。

      作者因此使用了 Newton-Schulz iteration.

      G=USVTG=USV^T 是 SGD-momentum 更新矩阵的 SVD 分解,则基于系数 (a,b,c)(a,b,c) 的 NS iteration 定义如下:

      G=aG+b(GGT)G+c(GGT)2G=(aI+b(GGT)+c(GGT)2)G=(aI+bUS2UT+cUS4UT)USVT=U(aS+bS3+cS5)VT\begin{aligned} G' &= aG + b(GG^T)G + c(GG^T)^2G\\ &= (aI+b(GG^T)+c(GG^T)^2)G\\ &= (aI+bUS^2U^T+cUS^4U^T)USV^T\\ &= U(aS+bS^3+cS^5)V^T \end{aligned}

      也就是说,如果我们定义五次多项式函数 ϕ(x)=ax+bx3+cx5\phi(x)=ax+bx^3+cx^5, 然后执行 NN 次 NS iteration, 则我们得到 UϕN(S)VTU\phi^N(S)V^T, 其中 ϕN\phi^N 代表 ϕ\phi 复合 NN 次。

      为了保证 NS iteration 收敛到 Ortho(G)=UVT\mathrm{Ortho}(G) = UV^T, 我们必须保证两点:

      1. SS 的值,也就是 GG 的奇异值必须在区间 [0,1][0,1]
      2. ϕ\phi 必须满足 ϕN1\phi^N\to 1, NN\to\infty, x[0,1]\forall x\in[0,1].

      为了满足第一个条件,我们可以对 GG 进行 rescale, 即 GG/GFG\gets G/\|G\|_F, rescale 不影响最终的结果,即 Ortho(G)=Ortho(cG)\mathrm{Ortho}(G) = \mathrm{Ortho}(cG).

      对于 ϕ(x)\phi(x), 我们有很多选择,比如我们定义 (a,b,c):=(2,1.5,0.5)(a,b,c):=(2,-1.5,0.5) 就得到如下结果

      Plot of f(x)=2x-1.5x^3+0.5x^5

      Coefficient Optimization

      尽管 (a,b,c):=(2,1.5,0.5)(a,b,c):=(2,-1.5,0.5) 已经满足了第二个条件,但是我们还是想进一步优化,优化的方向主要有两个:

      1. aa 尽可能大,这是因为 ϕ(0)=a\phi'(0)=a 控制了较小奇异值的收敛速率。
      2. 对于所有的 x[0,1]x\in[0,1], 我们希望 ϕN(x)[1ϵ,1+ϵ]\phi^N(x)\in[1-\epsilon, 1+\epsilon], NN\to\infty. 这样 NS iteration 的结果与 Ortho(G)\mathrm{Ortho}(G) 不会相差太远。

      作者发现, ϵ\epsilon 可以设置为 0.30.3 而不影响 Muon optimizer 的收敛性。因此,作者的目标现在是

      maxas.t.limNϕN(x)[0.7,1.3]\begin{aligned} \max\quad &a\\ \mathrm{s.t.}\quad &\lim_{N\to\infty}\phi^N(x)\in[0.7, 1.3] \end{aligned}

      作者通过 ad-hoc gradient 方法求解得到一组数值解为 (a,b,c)=(3.4445,4.7750,2.0315)(a,b,c)=(3.4445, 4.7750, 2.0315), 作者将这组数值应用于 Muon optimizer 中。迭代结果如下图,可以看到,当 x0x\approx0 时,函数变得更加陡峭。

      Plot of f(x)=3.4445x-4.7750x^3+2.0315x^5

      实验中,作者发现,仅需迭代五次,最终的结果就 work 的很好。作者还尝试了不同的多项式,结果发现并没有太大的提升。

      Algorithm

      最终,Muon Optimizer 的算法如下

      Muon Algorithm

      其中, NewtonSchulz5 算法伪代码定义如下

      def newtonschulz5(G, steps=5, eps=1e-7):
          assert G.ndim=2
          a, b, c = (3.4445, -4.7750, 2.0315)
          X = G.bfloat16()
          X /= (X.norm() + eps)
          if G.size(0) > G.size(1):
              X = X.T
          for _ in range(steps):
              A = X @ X.T
              B = b * A + c * A @ A
              X = a * X + B @ X
          if G.size(0) > G.size(1):
              X = X.T
          return X
      

      Analysis

      本节作者分析了以下 Muon 的内存占用和算力开销。

      在 NS iteration 之前,Muon optimizer 和 SGD-moment 是一样的。

      对于 n×mn\times m 的矩阵(假设 mnm\leq n), 首先 NS iteration 会进行转置,NS iteration 的每一步需要 2(2nm2+m3)2(2nm^2+m^3) FLOPs, 其中括号前面的系数 22 代表精度。因此,Muon 相比于 SGD momentum 需要的额外 FLOPs 为 2T(2nm2+m3)2T(2nm^2+m^3), 其中 TT 是迭代次数。

      使用 baseline 进行一次训练(前向 + 后向),所需要的 FLOPS 为 6nmB6nmB, 其中 BB 是 batch size. 因此,Muon 的 FLOP 开销至多为 Tm/BTm/B, 其中 mm 是模型的 hidden size, BB 是 batch size, TT 是 NS iteration 的步数。

      作者分别基于 nanoGPT 和 LLaMA-405B 进行验证,结果发现,Muon optimizer 带来的额外开销不足 1%1\%.

      作者发信啊,使用 Nesterov-style momentum 可以比普通的 SGD-momentum 效果更好,因此作者在 muon 中使用了前者。

      作者还发现,对于 QKV layer,分别进行优化效果会更好。

      Experiments

      Optimizer comparison by tokens

      Limitation and Future Work

      Muon 仅被设计用于优化 2D 参数(因为涉及矩阵计算),其余的参数仍然需要 AdamW 等优化器参与。

      作者认为未来的工作有:

      1. 能否 scale up Muon Optimizer
      2. 分布式优化
      3. 在 fine-tuning 和 RL 阶段使用 Muon Optimizer

      Conclusion

      作者提出了 Muon optimizer,该优化器在 nanoGPT speedrun 上取得了 SOTA 的结果,作者详细介绍了优化器的工作原理。

      Moonlight

      Kimi 提出了 Moonlight, 一个基于 Muon optimizer 训练得到的 16B-A3B MoE LLM. 作者详细介绍了如何 scale up muon optimizer.

      Introduction

      Muon 验证了 Muon optimizer 在小语言模型 nanoGPT 上的表现,但是对于更大规模 LLM 的表现,尚未有人探究。因此 Kimi 就希望在大规模 LLM 上验证 Muon optimizer 的表现。作者主要进行了两点改进:

      1. 加入 weight decay
      2. 调整了不同参数更新的 scale

      基于改进后的 Muon optimizer, 其训练效率相比于 AdamW 提升了 2 倍。作者基于 Muon Optimizer 训练得到了 Moonlight, 一个 16B-A3B 的 MoE LLM.

      作者主要作出了三点贡献:

      1. 探究了 weight decay 在 scaling Muon 时的作用
      2. 分布式 Muon optimizer 的实现
      3. 验证了 Muon optimizer 的 scaling law

      Method

      Background

      作者首先介绍了一下 Muon optimizer, 给定步数 tt, 参数矩阵 Wt1W_{t-1}, momentum μ\mu, 学习率 ηt\eta_t 以及目标函数 Lt\mathcal{L}_t, Muon optimizer 的更新方式如下:

      Mt=μMt1+Lt(Wt1)Ot=NewtonSchulz(Mt)Wt=Wt1ηtOt\begin{aligned} M_t &= \mu M_{t-1} + \nabla\mathcal{L}_t(W_{t-1})\\ O_t &= \mathrm{Newton-Schulz}(M_t)\\ W_t &= W_{t-1} - \eta_t O_t \end{aligned}

      这里 MtM_t 是 gradient 的 momentum, 初始化为 M0=0M_0=0. 在上面的更新公式中,Newton-Schulz 的作用是求解 (MtMtT)1/2Mt(M_tM_t^T)^{-1/2}M_t. 令 Mt=UΣVTM_t=U\Sigma V^T 为 SVD 分解, 我们有

      (MtMtT)1/2Mt=UVT(M_tM_t^T)^{-1/2}M_t = UV^T

      这是一个半正交矩阵,即 (UVT)T(UVT)=I(UV^T)^T(UV^T)=I.

      Newton-Schulz 迭代的具体公式如下:

      X0=MtMtF,Xk=aXk1+b(Xk1Xk1T)Xk1+c(Xk1Xk1T)2Xk1X_0 = \frac{M_t}{\|M_t\|_F},\quad X_k = aX_{k-1} + b(X_{k-1}X_{k-1}^T)X_{k-1} + c(X_{k-1}X_{k-1}^T)^2X_{k-1}

      其中,normalization 是为了保证 Newton-Schulz 的收敛性。 a,b,ca,b,c 是三个超参数,在 Muon 中设置为 (a,b,c)=(3.4445,4.7750,2.0315)(a,b,c)=(3.4445, 4.7750, 2.0315).

      Scaling up Muon

      作者发现,尽管 Muon 在小规模场景下 work 的很好,但是大规模性场景下的收益就非常有限了。作者发现,这是因为模型的参数以及每一层输出的 RMS 变得很大,这可能会影响模型的性能。因此,作者就和 AdamW 一样使用 weight dacay 来避免这个问题,即

      Wt=Wt1ηt(Ot+λWt1)W_t =W_{t-1} - \eta_t(O_t + \lambda W_{t-1})

      作者通过实验对比了 AdamW, vanilla Muon 和 Muon w/ weigth decay 三者的表现,实验结果如下图所示

      实验结果显示,尽管 vanilla Muon 手链最快,但是由于其权重增长很快,因此最后模型的表现不如 AdamW 和 Muon w/ weigth decay.

      接下来,作者分析了以下更新矩阵的 Root Mean Square (RMS), 结论是 Muon optimizer 的 RMS 与参数矩阵的形状相关:

      Lemma For a full-rank matrix parameter of shape [A,B][A, B], its theoretical Muon update RMS is 1/max(A,B)\sqrt{1/\max(A, B)}.

      证明如下:通过 Newton-Schulz 迭代,我们得到 Ot=UVTO_t=UV^T, 其中 Mt=UΣVTM_t=U\Sigma V^T 是 SVD 分解,我们有

      RMS(Ot)=i=1Aj=1BOt,i,j2AB=rAB\mathrm{RMS}(O_t) = \sqrt{\frac{\sum_{i=1}^A\sum_{j=1}^BO_{t,i,j}^2}{AB}}=\sqrt{\frac{r}{AB}}

      其中, r=rank(Mt)r=\mathrm{rank}(M_t) , 这样就完成了证明。

      而 Adam 和 AdamW 的 RMS 都在 11 附近。作者认为 RMS 也会影响模型表现:

      1. max(A,B)\max(A,B) 过大时,如 dense MLP matrix, 其更新就会变得很小,限制了模型的表现
      2. max(A,B)\max(A,B) 过小时,如 GQA 中的 KV head 或者 DeepSeek-V3 中的 MLA, 更新又会变得很大,导致训练不稳定。

      因此,作者就提出了一个 rescaling 的技巧,来消除 Muon optimizer 的影响。

      作者通过实验发现,AdamW 的 RMS 通常在 0.20.40.2\sim0.4 左右,因此,作者将 Muon optimizer 的更新设置如下

      Wt=Wt1ηt(0.2Otmax(A,B)+λWt1)W_t = W_{t-1} - \eta_t(0.2\cdot O_t\cdot \sqrt{\max(A,B)} + \lambda W_{t-1})

      基于这个改变, Muon 和 AdamW 可以共享学习率以及 weight decay 参数。

      Distributed Muon

      ZeRO-1 天然适合 AdamW, 因为 AdamW 都是 element-wise 进行计算的。但是 Muon 则需要梯度矩阵的全部信息。因此,作者就针对 ZeRO-1 进行适配, 提出了 Distributed Muon, 分布式版本将优化器的状态进行切分,然后加入了两个额外的操作:

      1. DP gather: 将 ZeRO-1 切分的梯度矩阵 gather 为一个完整的矩阵
      2. Calculate Full Update: 对完整的梯度矩阵执行 Newton-Schulz 迭代

      最终,Distributed Muon 的算法如下图所示

      Distributed Muon

      最后,作者分析了一下 distributed Muon 和 distributed AdamW 的内存和算力占用:

      1. 内存开销:Muon 只有一阶矩,而 AdamW 有二阶矩,因此 Muon 的额外内存开销为 AdamW 的一半。
      2. 通信开销:对于 ZeRO-1,通信开销来源于三个过程:All-Gather 参数 PP 用于前向传播, Reduce-Scatter 梯度 GG 用于反向传播, All-Gather 更新后的参数 PP 用于下一轮的前向传播。AdamW 不引入额外通信,所以其每个参数的通信量为 4+4=84+4=8, 分别代表 GGPP 的通信量。而 Muon 则需要额外的一次通信来得到 full matrix, 因此每个参数通信量为 4+4+2=104+4+2=10, 分别代表 P,GP, G 和 full matrix. 也就是说,分布式 Muon 的通信量最高为 AdamW 的 1.251.25 倍。实际上由于我们使用 multiple DP, 这个比例会更接近于 1.01.0.
      3. latency:Distributed Muon 相比于 AdamW latency 更高,这是因为 Muon 需要进行 DP gather 以及计算 Newton-Schulz 迭代。但实际上,latency 很小,因为 Newton-Schulz 迭代只需要迭代 5 次,并且 optimizer 的 end-to-end latency 相比于 forward-backward 过程是可以忽略的。一些额外的技巧也可以降低 latency.

      实际在训练的过程中,作者发现 Distributed Muon 相比于 AdamW 并没有太明显的 latency.

      Experiments

      Scaling Law of Muon

      作者分析了一下 Muon Optimizer 的 scaling law, 实验结果如下图所示

      Scaling law for Muon and AdamW

      实验结果表明,在最优设置下,Muon Optimizer 只需要 52%52\% 的 FLOPs 就可以达到 AdamW 的表现

      Pretraining with Muon

      作者分贝使用 AdamW 和 Muon 训练模型,然后评测了以下模型在不同 benchmark 上的表现,结果如下图所示

      Pretraining performance of different optimizers

      可以看到,在相同的设置下,Muon optimizer 的表现更好。

      Dynamics of Singular Spectrum

      Muon optimizer 的核心思想就是让比较难更新的方向也能被更新到,本节作者就探究了 Muon 是否满足这个性质,作者对参数矩阵进行 SVD 分解,然后定义 SVD entropy 如下

      H(σ)=1logni=1nσi2j=1nσj2logσi2j=1nσj2H(\sigma) = -\frac{1}{\log n}\sum_{i=1}^n\frac{\sigma_i^2}{\sum_{j=1}^n\sigma_j^2}\log\frac{\sigma_i^2}{\sum_{j=1}^n\sigma_j^2}

      作者对 SVD entropy 可视化如下

      Visualization of SVD entropy

      可以看到,Muon optimizer 的 SVD entropy 比 AdamW 更大,这说明 AdamW 的更新方向更多更广,验证了 Muon optimizer 的核心思想

      SFT with Muon

      作者还在 SFT 阶段验证了 Muon optimizer 的有效性。实验结果如下图所示

      Performance of Muon on SFT stage

      结论主要有两个:

      1. 预训练阶段与 SFT 阶段使用不同的优化器时,模型表现没有明显区别
      2. SFT 阶段使用 Muon 可以达到与 AdamW 差不多的表现,但是最好还是在 pre-training 阶段使用 Muon

      Conclusion

      作者探究了如何 scale up Muon Optimizer. 通过改进,作者在 16B-A3B 的 MoE LLM 上验证了 Muon Optimizer 的性能。实验结果发现,Muon Optimizer 的训练效率比 AdamW 提升了 2 倍左右。

      作者提出了三个未来可行的研究方向:

      1. 目前 Muon 只能针对 2D 参数进行优化,其他参数仍然依赖于 AdamW 优化器,是否可以使用 Muon 优化所有参数?
      2. Muon optimizer 可以理解是 spectral norm 下的 steepest descent 方法,如何将其扩展到 Schatten norm 是一个可以研究的方向
      3. 实验里提到,预训练和 SFT 阶段使用不同的 optimizer, 表现不是最优的,如何解决这个因为不同 optimizer 导致的性能差距是一个需要解决的问题。