Overview of optimizers for LLMs

Overview of optimizers for LLMs

Author

Published

May. 04, 2026

PDF

作者提出了 Adam, 一个一阶的优化方法,Adam 更加高效,且具有 scaling invariant 的性质。

Introduction

作者首先回顾了一下已有优化器的进展,其中主要是 SGD. 在本文中,作者提出了 Adam, 一个针对高维参数空间的一阶优化器,Adam 基于 gradient 的一阶和二阶信息为不同的参数安排不同的学习率。Adam 的来源是 adaptive moment estimation. Adam 主要是结合了 AdaGrad 和 RMSProp 两个算法的优点。

Adam 与 RMSProp 的区别在于:

  1. RMSProp 在 rescaled gradient 上进行 momentum 的计算然后更新,而 Adam 直接使用一阶和二阶矩来进行估计
  2. RMSProp 没有 bias-correction 项

Adam 的主要优势为:

  1. 参数更新的量级与 gradient 的 scaling 无关
  2. 步长被 stepsize 超参数限制
  3. 不要求目标函数 stationary
  4. 对于稀疏梯度 work 的比较好
  5. 优化器自带 annealing

Algorithm

Adam 的算法如下图所示

Adam Algorithm

我们优化的目标函数如下

minθf(θ)\min_{\theta}\quad f(\theta)

这里 ff 一般是一个神经网络。我们记 f(θ)f(\theta)θt\theta_t 处的梯度为 gt=θf(θt)g_t=\nabla_{\theta}f(\theta_t).

算法运行时,会更新梯度 mtm_t 以及梯度二阶矩 vtv_t 的 exponential moving average. 超参数 β1,β2\beta_1,\beta_2 负责控制 exponential decay rates. 这里 mtm_tvtv_t 分别是一阶动量(均值)和二阶动量(未中心化的 variance)的估计。由于 mtm_tvtv_t 的初始化都是 0, 因此他们会引入 bias, 作者在后续通过修正解决了这个问题。

假设 ϵ=0\epsilon=0, 如果除了当前时刻 tt 之外,之前所有时刻的梯度 gi=0,i<tg_i=0,i<t, 此时

mt=(1βt)gt,vt=(1β2)gt2m_t = (1-\beta_t)g_t, v_t=(1-\beta_2)g_t^2

修正后的一阶和二阶矩分别为

Δt=α(1β1)1β2t(1β1t)1β2\Delta_t = \alpha \frac{(1-\beta_1)\sqrt{1-\beta_2^t}}{(1-\beta_1^t)\sqrt{1-\beta_2}}

tt 足够大的时候, β1t0,β2t0\beta_1^t\to0, \beta_2^t\to0, 此时

Δt=α1β11β2\Delta_t = \alpha \frac{1-\beta_1}{\sqrt{1-\beta_2}}

如果之前所有时刻的梯度不全为 0, 则依据 Cauchy-Schwarz 不等式,我们有 (E[XY])2E[X2]E[Y2](\mathbb{E}[XY])^2\leq \mathbb{E}[X^2]\mathbb{E}[Y^2]. 令 X=1X=1, Y=gY=g, 则我们有

(E[g])2E[g2]E[g]E[g2]1(\mathbb{E}[g])^2\leq \mathbb{E}[g^2] \Rightarrow \frac{|\mathbb{E}[g]|}{\sqrt{\mathbb{E}[g^2]}}\leq 1

此时,我们有

E[gt]=m^t,E[gt2]=v^t\mathbb{E}[g_t] = \hat{m}_t, \mathbb{E}[g_t^2] = \hat{v}_t

因此,

Δt=αm^tv^t=αE[g]E[g2]α|\Delta_t| = \left|\alpha\frac{\hat{m}_t}{\sqrt{\hat{v}_t}}\right|=\alpha \left|\frac{|\mathbb{E}[g]|}{\sqrt{\mathbb{E}[g^2]}}\right|\leq\alpha

从而我们有

Δt{α1β11β2 if 1β1>1β2α otherwise|\Delta_t| \leq\begin{cases} \alpha \frac{1-\beta_1}{\sqrt{1-\beta_2}} & \text{ if }1-\beta_1>\sqrt{1-\beta_2}\\ \alpha &\text{ otherwise} \end{cases}

实际上,Δt\Delta_t 可以理解为一个 trust region, 可以用来保证当前更新的参数不会离原始参数太远。

作者定义 signal-noise ratio (SNR) 为

SNR=m^tv^tSNR = \frac{\hat{m}_t}{\sqrt{\hat{v}_t}}

当 SNR 较小时,说明此时的不确定性比较大,因此 Δt\Delta_t 也比较小。这就避免了模型朝错误的方向更新。也就是automatic annealing.

Δt\Delta_t 还对 gradient 的 scaling 有不变的性质,这是因为,

cm^tc2v^t=m^tv^t\frac{c\cdot\hat{m}_t}{\sqrt{c^2\cdot\hat{v}_t}} = \frac{\hat{m}_t}{\sqrt{\hat{v}_t}}

Bias Correction

上一节提到,Adam 算法的初始化是存在 bias 的,作者在本届就对齐进行了分析。令 gg 为目标函数 ff 的梯度,我们希望估计其二阶动量的期望.令 g1,,gTg_1,\dots,g_T 分别为 θ1,,θT\theta_1,\dots,\theta_T 处的梯度估计,其中 gtp(gt)g_t\sim p(g_t) 是对应时刻梯度的分布。令 v0=0v_0=0, 在 tt 时刻,我们有

vt=β2vt1+(1β2)gt2=(1β2)i=1tβ2tigi2v_t = \beta_2 v_{t-1} + (1-\beta_2)g_t^2 = (1-\beta_2)\sum_{i=1}^t\beta_2^{t-i} g_i^2

我们希望计算 E[vt]\mathbb{E}[v_t]E[gt2]\mathbb{E}[g_t^2] 之间的关系,我们有

E[vt]=[(1β2)i=1tβ2tigi2]=E[gt2](1β2)i=1tβ2ti+ζ=E[gt2](1β2t)+ζ\begin{aligned} \mathbb{E}[v_t] &= \left[(1-\beta_2)\sum_{i=1}^t\beta_2^{t-i} g_i^2\right]\\ &= \mathbb{E}[g_t^2]\cdot (1-\beta_2)\sum_{i=1}^t\beta_2^{t-i}+\zeta\\ &= \mathbb{E}[g_t^2](1-\beta_2^t)+\zeta \end{aligned}

其中当 E[gi2]\mathbb{E}[g_i^2] 为 stationary 时,ζ=0\zeta=0, 否则我们可以通过控制 β2\beta_2 来让 past gradient 保持在一个较小的规模。最后,我们剩下的就是 1β2t1-\beta_2^t, 这也是我们在算法中进行修正的地方。

对于一阶动量 mtm_t 的修正也是同理。

Convergence Analysis

作者在本节中使用了 online learning framework 来分写 Adam 的收敛性。给定一系列 convex cost function f1(θ),,fT(θ)f_1(\theta),\dots,f_T(\theta). 在 tt 时刻,我们的目标是基于上一个 cost function ft(θ)f_t(\theta) 来预测 θt\theta_t.

作者在这里使用 regret 来分析,记 ft(θ)f_t(\theta^*)tt 时刻最优的参数对应的 cost function, regret 定义为

R(T)=t=1T[ft(θt)ft(θ)]R(T) = \sum_{t=1}^T [f_t(\theta_t) - f_t(\theta^*)]

其中,

θ=argminθXt=1Tft(θ)\theta^* = \arg\min_{\theta\in\mathcal{X}}\sum_{t=1}^Tf_t(\theta)

则我们有如下的结论

Theorem 1 假设

  1. 函数 ftf_t 的梯度是有界的,即 ft(θ)2G\|\nabla f_t(\theta)\|_2\leq G, ft(θ)G\|\nabla f_t(\theta)\|_{\infty}\leq G_{\infty} 对任意 θRd\theta\in\mathbb{R}^d 都成立
  2. {θ1,,θT}\{\theta_1,\dots,\theta_T\} 中任意两个参数的距离都是有界的,即 θmθm2D\|\theta_m-\theta_m\|_2\leq D, θmθnD\|\theta_m-\theta_n\|_{\infty}\leq D_{\infty} 对任意 m,n{1,,T}m,n\in\{1,\dots,T\} 都成立
  3. β1,β2[0,1)\beta_1,\beta_2\in[0,1) 满足 β12β2<1\frac{\beta_1^2}{\sqrt{\beta_2}}<1αt=α/t\alpha_t=\alpha/\sqrt{t}, β1,t=β1λt1\beta_{1,t}=\beta_1\lambda^{t-1}, λ(0,1)\lambda\in(0,1), 则我们有
R(T)D22α(1β1)i=1dTv^T,i+α(1+β1)G(1β1)1β2(1γ)2i=1dg1:T,i2+i=1dD2G1β22α(1β1)(1λ)2R(T)\leq \frac{D^2}{2\alpha(1-\beta_1)}\sum_{i=1}^d\sqrt{T\hat{v}_{T,i}}+\frac{\alpha(1+\beta_1)G_{\infty}}{(1-\beta_1)\sqrt{1-\beta_2}(1-\gamma)^2}\sum_{i=1}^d\|g_{1:T,i}\|_2+\sum_{i=1}^d\frac{D_{\infty}^2G_{\infty}\sqrt{1-\beta_2}}{2\alpha(1-\beta_1)(1-\lambda)^2}

结果说明,当我们的 data feature 稀疏且梯度有界时我们有

i=1dg1:T,i2dGT\sum_{i=1}^d\|g_{1:T,i}\|_2 \ll dG_{\infty}\sqrt{T}

以及

i=1dTv^T,idGT\sum_{i=1}^d\sqrt{T\hat{v}_{T,i}} \ll dG_{\infty}\sqrt{T}

实际上,对于 Adam 以及 Adamgrad,这个上界可以优化到 O(logdT)O(\log d\sqrt{T}).

最终,我们可以证明 Adam 的收敛性

Corollary 1 假设

  1. 函数 ftf_t 的梯度是有界的,即 ft(θ)2G\|\nabla f_t(\theta)\|_2\leq G, ft(θ)G\|\nabla f_t(\theta)\|_{\infty}\leq G_{\infty} 对任意 θRd\theta\in\mathbb{R}^d 都成立
  2. {θ1,,θT}\{\theta_1,\dots,\theta_T\} 中任意两个参数的距离都是有界的,即 θmθm2D\|\theta_m-\theta_m\|_2\leq D, θmθnD\|\theta_m-\theta_n\|_{\infty}\leq D_{\infty} 对任意 m,n{1,,T}m,n\in\{1,\dots,T\} 都成立 则对 T1T\geq1, 我们有
R(T)T=O(1T) \frac{R(T)}{T}=O\left(\frac{1}{\sqrt{T}}\right)

Experiment

作者在 logistic regression, MLP, CNN 等三种模型架构上进行了实验。

Conclusion

作者在本文中提出了 Adam optimizer, 一个基于 AdaGrad 和 RMSProp 优点的优化器,作者通过理论验证了 Adam 的收敛性,然后通过实验验证了 Adam 的有效性。

    作者提出了一个针对 Adam 优化器的 weight decay 方法

    Introduction

    作者首先回顾了动态梯度算法如 AdaGrad, RMSProp, Adam 的进展。已有工作表明动态梯度算法的泛化性要比 SGD with momentum 要差。作者在本文中探究了在 SGD 和 Adam 中使用 L2 regularization 和 weight decay 对最终模型表现的影响。结果表明,模型泛化性较差的原因在于对于 Adam, L2 regularization 的效果要比 SGD 差。

    作者有如下发现:

    1. L2 regularization 和 weight decay 不等价。在 SGD 中,L2 regularization 是等价的,但是在 Adam 中这个结论不成立。具体来说,L2 regularization 对历史参数的惩罚要小于 weight decay
    2. L2 regularization 对 Adam 效果提升有效
    3. weight decay 对于 SGD 和 AdamW 都很有效,在 SGD 中,weight decay 与 L2 regularization 等价
    4. 最优的 weight decay 取决于 batch, batch 越大,最优的 weight decay 越小
    5. 通过 learning rate scheduler 可以进一步提高 Adam 的表现

    作者在本文中的主要贡献是通过解耦梯度更新中的 weight decay 来提高 Adam 的 regularization.

    作者的主要 motivation 是提升 Adam 表现,让其可以和 SGD with momentum 相比

    Method

    Weight decay 的定义如下

    θt+1=(1λ)θtαft(θt)(1)\theta_{t+1} = (1-\lambda)\theta_t - \alpha \nabla f_t(\theta_t) \tag{1}

    其中 λ\lambda 是 weight decay rate, ft(θt)\nabla f_t(\theta_t) 是第 tt 个 batch 的梯度,α\alpha 是学习率。

    首先,对于标准的 SGD 来说,weight decay 与 L2 regularization 等价

    Proposition 1 对于标准的 SGD 来说,对损失函数 ft(θt)f_t(\theta_t) 执行 weight decay (公式 (1)(1))与对损失函数 ft(θt)+λ/2θt22f_t(\theta_t)+\lambda'/2\|\theta_t\|_2^2 执行梯度下降算法是等价的,这里 λ=λ/α\lambda'=\lambda/\alpha

    证明比较简单,只需要写出损失函数的梯度下降更新公式即可。

    基于这个结论,大部分优化算法都将 L2 regularization 和 weight decay 看做是等价的。但实际上,这个结论对于 adaptive gradient 方法来说是不成立的。结论如下

    Proposition 2 令 OO 为一个 optimizer, 其目标函数为 ft(θ)f_t(\theta), 当不考虑 weight decay 时,梯度更新过程为 θt+1θtαMtft(θt)\theta_{t+1}\gets \theta_t-\alpha M_t\nabla f_t(\theta_t). 当考虑 weight decay 时,梯度更新过程为 θt+1(1λ)θtαMtft(θt)\theta_{t+1}\gets (1-\lambda)\theta_t-\alpha M_t\nabla f_t(\theta_t). 如果 MtkIM_t\neq kI, 则不存在 λ\lambda', 使得 OO 在优化目标函数 ftreg(θ)=ft(θ)+λ/2θ22f_t^{reg}(\theta)=f_t(\theta)+\lambda'/2\|\theta\|_2^2 时,不考虑 weight decay 的梯度更新与 OO 在优化目标函数 ft(θ)f_t(\theta) 时,考虑 weight decay 的梯度更新等价。

    证明比较简单,只需要写出两个目标函数对应的梯度更新公式即可。

    作者通过分析发现,在 adaptive gradient 方法中,对于 L2 regularization,梯度和 regularization 是打包在一起考虑的。而 weight decay 是分开考虑的。这就导致了对于梯度比较大的权重,L2 regularization 的学习率较小,从而 regularization 效应减弱。而 weight decay 中,这种效应则不存在。因此 weight decay 的 regularization 效应更强。

    作者通过这个分析,给出了一个 weight decay 与 L2 regularization 相等的条件

    Proposition 3 令 OO 为一个 optimizer, 其目标函数为 ft(θ)f_t(\theta), 当不考虑 weight decay 时,梯度更新过程为 θt+1θtαMtft(θt)\theta_{t+1}\gets \theta_t-\alpha M_t\nabla f_t(\theta_t). 当考虑 weight decay 时,梯度更新过程为 θt+1(1λ)θtαMtft(θt)\theta_{t+1}\gets (1-\lambda)\theta_t-\alpha M_t\nabla f_t(\theta_t). 如果 Mt=diag(s)1M_t= \mathrm{diag}(s)^{-1} (si>0,is_i>0,\forall i), 则 OO 在优化目标函数

    ftreg(θ)=ft(θ)+λ2αθs22 f_t^{reg}(\theta)=f_t(\theta)+\frac{\lambda'}{2\alpha}\|\theta\odot \sqrt{s}\|_2^2

    时,不考虑 weight decay 的梯度更新与 OO 在优化目标函数 ft(θ)f_t(\theta) 时,考虑 weight decay 的梯度更新等价。

    上面的结论显示,对于比较大的 preconditioner sis_i, 其在相比于 L2 regularization 被 regularized 的效应更强。

    为了解耦这两个参数,作者提出了 SGDW 算法,其 weight decay 和梯度更新同时进行,算法如下图所示

    SGDW algorithm

    在算法中,为了支持同时给 α\alphaλ\lambda 做 scheduling, 作者提出了一个 scaling factor ηt\eta_t, ηt\eta_t 由用户定义的 scheduler SetScheduleMultiplier(t) 决定。此时,针对 SGD with momentum 的 weight decay 与 L2 regularization 是等价的

    同理,我们也可以对 Adam 算法实行同样的操作,算法如下图所示

    AdamW algorithm

    Conclusion

    作者在本文中分析了 adaptive gradient 方法中 L2 regularization 与 weight decay 的不一致性。基于分析,作者提出了 SGDW 和 AdamW 两个优化算法。

      Muon (MomentUm Orthogonalized by Newton-Schulz) 是一个针对二维神经网络的优化器,它基于 SGD-momentum 改进,增加了一个 Newton-Schulz 的后处理步骤

      Method

      Newton-Schulz (NS) 的目的是用一个正交矩阵近似一个给定矩阵,即

      Ortho(G)=argminO{OGF:either OTO=I or OOT=I}\mathrm{Ortho}(G) = \arg\min_{O} \{\|O-G\|_F: \text{either } O^TO=I\text{ or } OO^T=I\}

      也就是说,NS iteration 将 SDG-moment 的更新矩阵替换为了“最近的” semi-orthogonal matrix. 这等价于将更新矩阵替换为 UVTUV^T, 其中 USVTUSV^T 是更新矩阵的 SVD 分解。

      [!tip] 作者观察到,对于 SGD-momentum 和 Adam 来说,其在基于 transformer 的神经网络里有非常高的 condition number, 也就是 optimizer 仅在少数几个方向上进行优化。作者认为,通过正交化,可以有效提高模型在其他方向上的更新速度,进而提高模型表现

      Newton-Schulz

      作者提到,正交化矩阵的方法有很多,比如 SVD 分解,但是其问题是非常慢,还有 Coupled Newton iteration, 但是其精度要求非常高,必须要在 float32 以上。

      作者因此使用了 Newton-Schulz iteration.

      G=USVTG=USV^T 是 SGD-momentum 更新矩阵的 SVD 分解,则基于系数 (a,b,c)(a,b,c) 的 NS iteration 定义如下:

      G=aG+b(GGT)G+c(GGT)2G=(aI+b(GGT)+c(GGT)2)G=(aI+bUS2UT+cUS4UT)USVT=U(aS+bS3+cS5)VT\begin{aligned} G' &= aG + b(GG^T)G + c(GG^T)^2G\\ &= (aI+b(GG^T)+c(GG^T)^2)G\\ &= (aI+bUS^2U^T+cUS^4U^T)USV^T\\ &= U(aS+bS^3+cS^5)V^T \end{aligned}

      也就是说,如果我们定义五次多项式函数 ϕ(x)=ax+bx3+cx5\phi(x)=ax+bx^3+cx^5, 然后执行 NN 次 NS iteration, 则我们得到 UϕN(S)VTU\phi^N(S)V^T, 其中 ϕN\phi^N 代表 ϕ\phi 复合 NN 次。

      为了保证 NS iteration 收敛到 Ortho(G)=UVT\mathrm{Ortho}(G) = UV^T, 我们必须保证两点:

      1. SS 的值,也就是 GG 的奇异值必须在区间 [0,1][0,1]
      2. ϕ\phi 必须满足 ϕN1\phi^N\to 1, NN\to\infty, x[0,1]\forall x\in[0,1].

      为了满足第一个条件,我们可以对 GG 进行 rescale, 即 GG/GFG\gets G/\|G\|_F, rescale 不影响最终的结果,即 Ortho(G)=Ortho(cG)\mathrm{Ortho}(G) = \mathrm{Ortho}(cG).

      对于 ϕ(x)\phi(x), 我们有很多选择,比如我们定义 (a,b,c):=(2,1.5,0.5)(a,b,c):=(2,-1.5,0.5) 就得到如下结果

      Plot of f(x)=2x-1.5x^3+0.5x^5

      Coefficient Optimization

      尽管 (a,b,c):=(2,1.5,0.5)(a,b,c):=(2,-1.5,0.5) 已经满足了第二个条件,但是我们还是想进一步优化,优化的方向主要有两个:

      1. aa 尽可能大,这是因为 ϕ(0)=a\phi'(0)=a 控制了较小奇异值的收敛速率。
      2. 对于所有的 x[0,1]x\in[0,1], 我们希望 ϕN(x)[1ϵ,1+ϵ]\phi^N(x)\in[1-\epsilon, 1+\epsilon], NN\to\infty. 这样 NS iteration 的结果与 Ortho(G)\mathrm{Ortho}(G) 不会相差太远。

      作者发现, ϵ\epsilon 可以设置为 0.30.3 而不影响 Muon optimizer 的收敛性。因此,作者的目标现在是

      maxas.t.limNϕN(x)[0.7,1.3]\begin{aligned} \max\quad &a\\ \mathrm{s.t.}\quad &\lim_{N\to\infty}\phi^N(x)\in[0.7, 1.3] \end{aligned}

      作者通过 ad-hoc gradient 方法求解得到一组数值解为 (a,b,c)=(3.4445,4.7750,2.0315)(a,b,c)=(3.4445, 4.7750, 2.0315), 作者将这组数值应用于 Muon optimizer 中。迭代结果如下图,可以看到,当 x0x\approx0 时,函数变得更加陡峭。

      Plot of f(x)=3.4445x-4.7750x^3+2.0315x^5

      实验中,作者发现,仅需迭代五次,最终的结果就 work 的很好。作者还尝试了不同的多项式,结果发现并没有太大的提升。

      Algorithm

      最终,Muon Optimizer 的算法如下

      Muon Algorithm

      其中, NewtonSchulz5 算法伪代码定义如下

      def newtonschulz5(G, steps=5, eps=1e-7):
          assert G.ndim=2
          a, b, c = (3.4445, -4.7750, 2.0315)
          X = G.bfloat16()
          X /= (X.norm() + eps)
          if G.size(0) > G.size(1):
              X = X.T
          for _ in range(steps):
              A = X @ X.T
              B = b * A + c * A @ A
              X = a * X + B @ X
          if G.size(0) > G.size(1):
              X = X.T
          return X
      

      Analysis

      本节作者分析了以下 Muon 的内存占用和算力开销。

      在 NS iteration 之前,Muon optimizer 和 SGD-moment 是一样的。

      对于 n×mn\times m 的矩阵(假设 mnm\leq n), 首先 NS iteration 会进行转置,NS iteration 的每一步需要 2(2nm2+m3)2(2nm^2+m^3) FLOPs, 其中括号前面的系数 22 代表精度。因此,Muon 相比于 SGD momentum 需要的额外 FLOPs 为 2T(2nm2+m3)2T(2nm^2+m^3), 其中 TT 是迭代次数。

      使用 baseline 进行一次训练(前向 + 后向),所需要的 FLOPS 为 6nmB6nmB, 其中 BB 是 batch size. 因此,Muon 的 FLOP 开销至多为 Tm/BTm/B, 其中 mm 是模型的 hidden size, BB 是 batch size, TT 是 NS iteration 的步数。

      作者分别基于 nanoGPT 和 LLaMA-405B 进行验证,结果发现,Muon optimizer 带来的额外开销不足 1%1\%.

      作者发信啊,使用 Nesterov-style momentum 可以比普通的 SGD-momentum 效果更好,因此作者在 muon 中使用了前者。

      作者还发现,对于 QKV layer,分别进行优化效果会更好。

      Experiments

      Optimizer comparison by tokens

      Limitation and Future Work

      Muon 仅被设计用于优化 2D 参数(因为涉及矩阵计算),其余的参数仍然需要 AdamW 等优化器参与。

      作者认为未来的工作有:

      1. 能否 scale up Muon Optimizer
      2. 分布式优化
      3. 在 fine-tuning 和 RL 阶段使用 Muon Optimizer

      Conclusion

      作者提出了 Muon optimizer,该优化器在 nanoGPT speedrun 上取得了 SOTA 的结果,作者详细介绍了优化器的工作原理。

      Moonlight

      Kimi 提出了 Moonlight, 一个基于 Muon optimizer 训练得到的 16B-A3B MoE LLM. 作者详细介绍了如何 scale up muon optimizer.

      Introduction

      Muon 验证了 Muon optimizer 在小语言模型 nanoGPT 上的表现,但是对于更大规模 LLM 的表现,尚未有人探究。因此 Kimi 就希望在大规模 LLM 上验证 Muon optimizer 的表现。作者主要进行了两点改进:

      1. 加入 weight decay
      2. 调整了不同参数更新的 scale

      基于改进后的 Muon optimizer, 其训练效率相比于 AdamW 提升了 2 倍。作者基于 Muon Optimizer 训练得到了 Moonlight, 一个 16B-A3B 的 MoE LLM.

      作者主要作出了三点贡献:

      1. 探究了 weight decay 在 scaling Muon 时的作用
      2. 分布式 Muon optimizer 的实现
      3. 验证了 Muon optimizer 的 scaling law

      Method

      Background

      作者首先介绍了一下 Muon optimizer, 给定步数 tt, 参数矩阵 Wt1W_{t-1}, momentum μ\mu, 学习率 ηt\eta_t 以及目标函数 Lt\mathcal{L}_t, Muon optimizer 的更新方式如下:

      Mt=μMt1+Lt(Wt1)Ot=NewtonSchulz(Mt)Wt=Wt1ηtOt\begin{aligned} M_t &= \mu M_{t-1} + \nabla\mathcal{L}_t(W_{t-1})\\ O_t &= \mathrm{Newton-Schulz}(M_t)\\ W_t &= W_{t-1} - \eta_t O_t \end{aligned}

      这里 MtM_t 是 gradient 的 momentum, 初始化为 M0=0M_0=0. 在上面的更新公式中,Newton-Schulz 的作用是求解 (MtMtT)1/2Mt(M_tM_t^T)^{-1/2}M_t. 令 Mt=UΣVTM_t=U\Sigma V^T 为 SVD 分解, 我们有

      (MtMtT)1/2Mt=UVT(M_tM_t^T)^{-1/2}M_t = UV^T

      这是一个半正交矩阵,即 (UVT)T(UVT)=I(UV^T)^T(UV^T)=I.

      Newton-Schulz 迭代的具体公式如下:

      X0=MtMtF,Xk=aXk1+b(Xk1Xk1T)Xk1+c(Xk1Xk1T)2Xk1X_0 = \frac{M_t}{\|M_t\|_F},\quad X_k = aX_{k-1} + b(X_{k-1}X_{k-1}^T)X_{k-1} + c(X_{k-1}X_{k-1}^T)^2X_{k-1}

      其中,normalization 是为了保证 Newton-Schulz 的收敛性。 a,b,ca,b,c 是三个超参数,在 Muon 中设置为 (a,b,c)=(3.4445,4.7750,2.0315)(a,b,c)=(3.4445, 4.7750, 2.0315).

      Scaling up Muon

      作者发现,尽管 Muon 在小规模场景下 work 的很好,但是大规模性场景下的收益就非常有限了。作者发现,这是因为模型的参数以及每一层输出的 RMS 变得很大,这可能会影响模型的性能。因此,作者就和 AdamW 一样使用 weight dacay 来避免这个问题,即

      Wt=Wt1ηt(Ot+λWt1)W_t =W_{t-1} - \eta_t(O_t + \lambda W_{t-1})

      作者通过实验对比了 AdamW, vanilla Muon 和 Muon w/ weigth decay 三者的表现,实验结果如下图所示

      实验结果显示,尽管 vanilla Muon 手链最快,但是由于其权重增长很快,因此最后模型的表现不如 AdamW 和 Muon w/ weigth decay.

      接下来,作者分析了以下更新矩阵的 Root Mean Square (RMS), 结论是 Muon optimizer 的 RMS 与参数矩阵的形状相关:

      Lemma For a full-rank matrix parameter of shape [A,B][A, B], its theoretical Muon update RMS is 1/max(A,B)\sqrt{1/\max(A, B)}.

      证明如下:通过 Newton-Schulz 迭代,我们得到 Ot=UVTO_t=UV^T, 其中 Mt=UΣVTM_t=U\Sigma V^T 是 SVD 分解,我们有

      RMS(Ot)=i=1Aj=1BOt,i,j2AB=rAB\mathrm{RMS}(O_t) = \sqrt{\frac{\sum_{i=1}^A\sum_{j=1}^BO_{t,i,j}^2}{AB}}=\sqrt{\frac{r}{AB}}

      其中, r=rank(Mt)r=\mathrm{rank}(M_t) , 这样就完成了证明。

      而 Adam 和 AdamW 的 RMS 都在 11 附近。作者认为 RMS 也会影响模型表现:

      1. max(A,B)\max(A,B) 过大时,如 dense MLP matrix, 其更新就会变得很小,限制了模型的表现
      2. max(A,B)\max(A,B) 过小时,如 GQA 中的 KV head 或者 DeepSeek-V3 中的 MLA, 更新又会变得很大,导致训练不稳定。

      因此,作者就提出了一个 rescaling 的技巧,来消除 Muon optimizer 的影响。

      作者通过实验发现,AdamW 的 RMS 通常在 0.20.40.2\sim0.4 左右,因此,作者将 Muon optimizer 的更新设置如下

      Wt=Wt1ηt(0.2Otmax(A,B)+λWt1)W_t = W_{t-1} - \eta_t(0.2\cdot O_t\cdot \sqrt{\max(A,B)} + \lambda W_{t-1})

      基于这个改变, Muon 和 AdamW 可以共享学习率以及 weight decay 参数。

      Distributed Muon

      ZeRO-1 天然适合 AdamW, 因为 AdamW 都是 element-wise 进行计算的。但是 Muon 则需要梯度矩阵的全部信息。因此,作者就针对 ZeRO-1 进行适配, 提出了 Distributed Muon, 分布式版本将优化器的状态进行切分,然后加入了两个额外的操作:

      1. DP gather: 将 ZeRO-1 切分的梯度矩阵 gather 为一个完整的矩阵
      2. Calculate Full Update: 对完整的梯度矩阵执行 Newton-Schulz 迭代

      最终,Distributed Muon 的算法如下图所示

      Distributed Muon

      最后,作者分析了一下 distributed Muon 和 distributed AdamW 的内存和算力占用:

      1. 内存开销:Muon 只有一阶矩,而 AdamW 有二阶矩,因此 Muon 的额外内存开销为 AdamW 的一半。
      2. 通信开销:对于 ZeRO-1,通信开销来源于三个过程:All-Gather 参数 PP 用于前向传播, Reduce-Scatter 梯度 GG 用于反向传播, All-Gather 更新后的参数 PP 用于下一轮的前向传播。AdamW 不引入额外通信,所以其每个参数的通信量为 4+4=84+4=8, 分别代表 GGPP 的通信量。而 Muon 则需要额外的一次通信来得到 full matrix, 因此每个参数通信量为 4+4+2=104+4+2=10, 分别代表 P,GP, G 和 full matrix. 也就是说,分布式 Muon 的通信量最高为 AdamW 的 1.251.25 倍。实际上由于我们使用 multiple DP, 这个比例会更接近于 1.01.0.
      3. latency:Distributed Muon 相比于 AdamW latency 更高,这是因为 Muon 需要进行 DP gather 以及计算 Newton-Schulz 迭代。但实际上,latency 很小,因为 Newton-Schulz 迭代只需要迭代 5 次,并且 optimizer 的 end-to-end latency 相比于 forward-backward 过程是可以忽略的。一些额外的技巧也可以降低 latency.

      实际在训练的过程中,作者发现 Distributed Muon 相比于 AdamW 并没有太明显的 latency.

      Experiments

      Scaling Law of Muon

      作者分析了一下 Muon Optimizer 的 scaling law, 实验结果如下图所示

      Scaling law for Muon and AdamW

      实验结果表明,在最优设置下,Muon Optimizer 只需要 52%52\% 的 FLOPs 就可以达到 AdamW 的表现

      Pretraining with Muon

      作者分贝使用 AdamW 和 Muon 训练模型,然后评测了以下模型在不同 benchmark 上的表现,结果如下图所示

      Pretraining performance of different optimizers

      可以看到,在相同的设置下,Muon optimizer 的表现更好。

      Dynamics of Singular Spectrum

      Muon optimizer 的核心思想就是让比较难更新的方向也能被更新到,本节作者就探究了 Muon 是否满足这个性质,作者对参数矩阵进行 SVD 分解,然后定义 SVD entropy 如下

      H(σ)=1logni=1nσi2j=1nσj2logσi2j=1nσj2H(\sigma) = -\frac{1}{\log n}\sum_{i=1}^n\frac{\sigma_i^2}{\sum_{j=1}^n\sigma_j^2}\log\frac{\sigma_i^2}{\sum_{j=1}^n\sigma_j^2}

      作者对 SVD entropy 可视化如下

      Visualization of SVD entropy

      可以看到,Muon optimizer 的 SVD entropy 比 AdamW 更大,这说明 AdamW 的更新方向更多更广,验证了 Muon optimizer 的核心思想

      SFT with Muon

      作者还在 SFT 阶段验证了 Muon optimizer 的有效性。实验结果如下图所示

      Performance of Muon on SFT stage

      结论主要有两个:

      1. 预训练阶段与 SFT 阶段使用不同的优化器时,模型表现没有明显区别
      2. SFT 阶段使用 Muon 可以达到与 AdamW 差不多的表现,但是最好还是在 pre-training 阶段使用 Muon

      Conclusion

      作者探究了如何 scale up Muon Optimizer. 通过改进,作者在 16B-A3B 的 MoE LLM 上验证了 Muon Optimizer 的性能。实验结果发现,Muon Optimizer 的训练效率比 AdamW 提升了 2 倍左右。

      作者提出了三个未来可行的研究方向:

      1. 目前 Muon 只能针对 2D 参数进行优化,其他参数仍然依赖于 AdamW 优化器,是否可以使用 Muon 优化所有参数?
      2. Muon optimizer 可以理解是 spectral norm 下的 steepest descent 方法,如何将其扩展到 Schatten norm 是一个可以研究的方向
      3. 实验里提到,预训练和 SFT 阶段使用不同的 optimizer, 表现不是最优的,如何解决这个因为不同 optimizer 导致的性能差距是一个需要解决的问题。