KL divergence: from definition to application

Why unbiased KL estimates need not give unbiased KL gradients; forward vs reverse KL, estimators in on/off-policy RL, and experiments.

Author

Published

May. 04, 2026

PDF

Introduction

在强化学习中,KL divergence 常被用作 policy 正则项,但很多不稳定现象并非来自 KL 本身,而是来自其估计方式。本文展示了为什么“无偏的 KL 估计”并不能保证“无偏的 KL 梯度”,并系统分析了不同 KL estimator 在 on-policy 与 off-policy 场景下的行为差异。通过理论推导与实验验证,文章揭示了 KL 作为 loss 与 reward shaping 时的本质区别,并解释了实践中低方差 KL 设计背后的原因。

在本节中,我们先介绍 KL divergence 的基本定义,然后我们介绍 KL divergence 的一般形式,即 f-divergence.

KL-divergence

KL divergence 用于衡量近似概率分布 Q(x)Q(x) 到真实概率分布 P(x)P(x) 的误差,我们可以将其理解为:如果我们用 Q(x)Q(x) 来替换 P(x)P(x), 会有多大的信息损失?

连续概率分布的 KL divergence 的定义如下

DKL(PQ)=ExP[logP(x)Q(x)]=P(x)log(P(x)Q(x))dxD_{KL}(P\parallel Q) =\mathbb{E}_{x\sim P}\left[\log \frac{P(x)}{Q(x)}\right]=\int P(x)\log\left(\frac{P(x)}{Q(x)}\right)dx

离散概率分布的 KL divergence 定义如下

DKL(PQ)=xP(x)log(P(x)Q(x))D_{KL}(P\parallel Q) = \sum_{x} P(x)\log\left(\frac{P(x)}{Q(x)}\right)

KL divergence 有如下几个关键性质:

  1. 非负性:DKL(PQ)0D_{KL}(P\parallel Q)\geq0, 且 DKL(PQ)=0D_{KL}(P\parallel Q)=0 当且仅当 P(x)=Q(x)P(x)=Q(x) 对任意 xx 成立
  2. 非对称性: 一般情况下,DKL(PQ)DKL(QP)D_{KL}(P\parallel Q)\neq D_{KL}(Q\parallel P).
  3. 有限性:如果存在 xx 使得 P(x)>0P(x)>0 但是 Q(x)=0Q(x)=0, 则 DKL(PQ)=D_{\mathrm{KL}}(P\parallel Q)=\infty.

一般我们称 DKL(PQ)D_{KL}(P\parallel Q)forward KL (相对于 QQ), 对应的还有 reverse KL DKL(QP)D_{KL}(Q\parallel P) (相对于 QQ).

F-divergence

KL divergence 是 f-divergence 的一种特殊情况。 f-divergence 是一类衡量不同概率分布 PPQQ 的函数 Df(PQ)D_f(P\parallel Q).

假设函数 f:(0,)Rf:(0,\infty)\to\mathbb{R} 是一个凸函数,且 f(1)=0f(1)=0. PPQQ 是两个概率分布,则 f-divergence 定义如下

Df(PQ)=ExQ[f(P(x)Q(x))]=Q(x)f(P(x)Q(x))dxD_f(P\parallel Q) = \mathbb{E}_{x\sim Q}\left[ f\left(\frac{P(x)}{Q(x)}\right)\right]=\int Q(x)f\left(\frac{P(x)}{Q(x)}\right)dx

我们称 ffDfD_fgenerator.

以下是几种常见的 f-divergence:

Namegenerator
forward KL divergencef(x)=xlogxf(x)=x\log x
reverse KL divergencef(x)=logxf(x)=-\log x
Total variationf(x)=1/2x1f(x)=1/2\vert x-1\vert
χ2\chi^2-divergencef(x)=(x1)2f(x)=(x-1)^2
JS-divergencef(x)=xlog2xx+1+log2x+1f(x)=x\log\frac{2x}{x+1}+\log\frac{2}{x+1}

我们这里推导一下 KL divergence 对应的 generator.

对于 forward KL, 注意到

Df(PQ)=Q(x)(P(x)Q(x)logP(x)Q(x))dx=P(x)logP(x)Q(x)dx=DKL(PQ)D_f(P \parallel Q) = \int Q(x) \left( \frac{P(x)}{Q(x)} \log \frac{P(x)}{Q(x)} \right) dx = \int P(x) \log \frac{P(x)}{Q(x)} dx = D_{KL}(P \parallel Q)

因此 forward KL 对应的 generator 为 f=xlogxf=x\log x.

对于 reverse KL, 注意到

Df(PQ)=Q(x)(logP(x)Q(x))dx=Q(x)logQ(x)P(x)dx=DKL(QP)D_f(P \parallel Q) = \int Q(x) \left( -\log \frac{P(x)}{Q(x)} \right) dx = \int Q(x) \log \frac{Q(x)}{P(x)} dx = D_{KL}(Q \parallel P)

因此 forward KL 对应的 generator 为 f=logxf=-\log x.

Properties of F-divergence

f-divergence 性质如下

  1. linearity: Da1f1+a2f2=a1Df1+a2Df2D_{a_1f_1+a_2f_2}=a_1D_{f_1}+a_2D_{f_2}.
  2. Df=DgD_f=D_g 当且仅当存在 cRc\in\mathbb{R} 使得 f(x)=g(x)+c(x1)f(x)=g(x)+c(x-1).
  3. non-negativity. Df(PQ)0D_f(P\parallel Q)\geq0Df(PQ)=0D_f(P\parallel Q)=0 当且仅当 P=QP=Q.

性质 2 证明如下:

如果 f(x)=g(x)+c(x1)f(x)=g(x)+c(x-1), 则通过定义,我们可以验证得到 Df=DgD_f=D_g.

反之,如果 Df=DgD_f=D_g, 令 h=fgh=f-g, 对任意两个在集合 {0,1}\{0, 1\} 上的概率分布 P,QP,Q, 由于 Df(PQ)Dg(PQ)=0D_f(P\parallel Q) - D_g(P\parallel Q)=0, 我们有

h(P(1)Q(1))=Q(0)Q(1)h(P(0)Q(0))h\left(\frac{P(1)}{Q(1)}\right) = -\frac{Q(0)}{Q(1)}h\left(\frac{P(0)}{Q(0)}\right)

我们不妨假设 P(0)=aQ(0)P(0)=aQ(0), P(1)=bQ(1)P(1)=bQ(1), 结合 P(0)+P(1)=1P(0)+P(1)=1Q(0)+Q(1)=1Q(0)+Q(1)=1 我们有

Q(0)=1aba,Q(1)=b1baQ(0) = \frac{1-a}{b-a}, Q(1) = \frac{b-1}{b-a}

从而

h(b)b1=h(a)a1\frac{h(b)}{b-1}=\frac{h(a)}{a-1}

由于我们可以任意选定 PPQQ, 因此 hh 是一个线性函数,形式为 h(x)=c(x1)h(x)=c(x-1). \blacksquare

Approximation

本节中,我们将介绍针对 KL divergence 的三种近似形式。

在实际计算 KL divergence 时,由于:

  1. 完整计算 KL divergence 需要的算力或内存过高
  2. 没有闭式解
  3. 我们可以仅保存 log-probability, 而不是整个概率分布

因此,我们假设我们只能计算输入 xx 对应的概率 P(x)P(x)Q(x)Q(x). 一般来说,我们会通过 Monte Carlo estimate 来进行近似。即我们先对 PP 进行采样得到 x1,,xNPx_1,\dots,x_N\sim P, 然后我们构建估计量。

一个高的估计量应该是无偏 (unbiased) 并且方差低 (low variance) 的。John Schulman 给出了三种 estimator. 我们分别针对 forward KL 和 reverse KL 进行介绍。这里我们定义

r=P(x)Q(x)r = \frac{P(x)}{Q(x)}

Forward KL Estimation

对于 forward KL DKL(PQ)D_{KL}(P\parallel Q), 其对应的 generator 为 f(x)=xlogxf(x)=x\log x, 注意到 ExQ[r]=1\mathbb{E}_{x\sim Q}[r]=1, 且 ff 是一个凸函数,因此我们有 f(r)f(1)(r1)0f(r)-f'(1)(r-1)\geq0, 从而我们可以得到一个新的估计为 k=rlogr(r1)\boxed{k=r\log r - (r-1)}.

Reverse KL Estimation

对于 reverse KL DKL(QP)D_{KL}(Q\parallel P), 其对应的 generator 为 f(x)=logxf(x)=-\log x, 由概率性质,k1=logr\boxed{k_1=-\log r}DKL(QP)D_{KL}(Q\parallel P) 的一个无偏估计。但是 k1k_1 的问题在于 当 rr 非常小时,k1k_1 会变得非常大。也就是说,k1k_1 的 variance 比较高。

John Schulman 基于 f-divergence 泰勒展开给出了一个新的估计 k2k_2, 其定义为

k2=12(logr)2\boxed{k_2 = \frac12(\log r)^2}

其期望为

EQ[k2]=EQ[12(logr)2]\mathbb{E}_Q[k_2] = \mathbb{E}_Q\left[\frac12(\log r)^2\right]

这是一个 f-divergence, 对应的 generator 为 fk2(x)=1/2(logx)2f_{k_2}(x)=1/2(\log x)^2, 而 DKL(QP)D_{KL}(Q\parallel P) 对应的 generator 为 fk1(x)=logxf_{k_1}(x)=-\log x.

PPQQ 比较靠近时,我们记 θ=r1\theta=r-1, 对 Df(PQ)D_{f}(P\parallel Q)x=1x=1 处进行展开得到

Df(PQ)=ExQ[f(r)]=ExQ[f(1)+f(1)θ+f(1)2f(1+λ)θ2+O(θ3)]=f(1)2Fθ2+O(θ3)\begin{aligned} D_f(P\parallel Q) &= \mathbb{E}_{x\sim Q}\left[ f(r)\right]\\ &= \mathbb{E}_{x\sim Q}\left[ f(1) + f'(1)\theta + \frac{f''(1)}{2}f(1+\lambda)\theta^2+O(\theta^3)\right]\\ &= \frac{f''(1)}{2}F\theta^2+O(\theta^3) \end{aligned}

这里我们应用了 f(1)=0f(1)=0, E[θ]=0\mathbb{E}[\theta]=0, F=E[f(1+λθ)F=\mathbb{E}[f(1+\lambda\theta) 是 Fisher information matrix.

我们分别带入 fk1(x)f_{k_1}(x)fk2(x)f_{k_2}(x) 得到 fk1(1)=fk2(1)=1f_{k_1}''(1)=f_{k_2}''(1)=1, 即 k1k_1k2k_2PPQQ 比较靠近时二阶近似是相同的。因此,**k2k_2 表面上是一个二阶近似,在分布接近时有效,但本质上是在优化 另一个 f-divergence*

John Schulman 还构造了第三种估计。回顾前面 f-divergence 的性质 2,即当 f(x)=g(x)+c(x1)f(x)=g(x)+c(x-1) 时,我们有 Df=DgD_f=D_g, 因此我们可以选取合适的 cc 来降低估计的 variance. 注意到 k1k_1 的主要问题在于存在负数的可能性,因此我们就构建一个对应的估计量来解决这个问题。注意到 logxx1\log x \leq x -1, 因此我们可以令 c=1c=1, 此时就得到了新的估计

k3=(r1)logr\boxed{k_3 =(r-1)- \log r }

k3k_3 继承了 k1k_1 的无偏性,并且 k3k_3 通过 f-divergence 等价类消除了负值,兼顾无偏与低方差,解决了 k1k_1 variance 过大的问题

Experiments on Approximation

对于分布 P=N(0,1)P=\mathcal{N}(0,1) 以及 Q=N(0.1,1)Q=\mathcal{N}(0.1, 1), 真实的 KV divergence 为 0.005, 三个 estimator 的误差如下表所示

MethodBiasStd Dev
k1k_10.000120.0005
k2k_20.00251.4175
k3k_30.00001.4163

P=N(1,1)P=\mathcal{N}(1,1), Q=N(0.1,1)Q=\mathcal{N}(0.1, 1) 时, 真实的 KV divergence 为 0.405, 三个 estimator 的误差如下表所示

MethodBiasStd Dev
k1k_1-0.00002.2223
k2k_20.20251.6762
k3k_30.00001.6342

可以看到 k1k_1 的 variance 非常大,k2k_2 是一个有偏估计,k3k_3 既满足了无偏又满足了 low variance.

Summary

我们接下来总结 reverse KL DKL(QP)D_{KL}(Q\parallel P) 的近似 k1k_1, k2k_2k3k_3 的性质如下 (r=P(x)/Q(x)r=P(x)/Q(x))

estimationdefinitionmotivationbiasvariance
k1k_1logr-\log rnaive estimationunbiasedhigh
k2k_212(logr)2\frac12(\log r)^2f-divergence, taylor expansionbiasedlow
k3k_3(r1)logr(r-1)- \log rf-divergence, non-negativityunbiasedlow

Applications to ML

Remark 本节内容主要参考了 KL Divergence for Machine Learning

我们假设真实目标分布和近似的目标分布分别记为 pdata(x)p_{data}(x)pθ(x)p_\theta(x). 由于 KL divergence 的非对称性,因此我们需要考虑两种目标函数:

  1. forward KL: argminθDKL(pdatapθ)\arg\min_\theta D_{KL}(p_{data}\parallel p_\theta)
  2. reverse KL: argminθDKL(pθpdata)\arg\min_\theta D_{KL}(p_\theta \parallel p_{data})

我们将会看到,这两种不同的目标函数导致的结果也不尽相同

Forward KL

对目标函数进行简化得到

argminθDKL(pdatapθ)=argmaxθExpdata[logpθ(x)]\arg\min_\theta D_{KL}(p_{data}\parallel p_\theta) = \arg\max_\theta \mathbb{E}_{x\sim p_{data}}\left[\log p_\theta(x)\right]

实际在计算时,我们会使用 Monte Carlo 的方式对真实分布进行采样然后进行估计。

Forward KL 其代表的含义为,我们从分布 pdatap_{data} 中进行采样,然后求 pθp_\theta 的最大似然估计。最终的结果满足:pdata(x)p_{data}(x) 概率很高时,pθ(x)p_\theta(x) 的概率也需要很高. 这是一种 mean-seeking behavior, 因为 pθp_\theta 必须覆盖 pdatap_{data} 的所有 modes.

一般来说,supervised learning 对应的就是 forward KL. 我们可以证明 forward KL divergence 和 MLE 是等价的。也就是说,最大似然估计得到的分布就是 KL divergence 最小的近似分布。我们将 pdata(x)p_{data}(x)pθ(x)p_\theta(x) 对应的 KL divergence 进行展开得到

θKL=argminθDKL(pdata(x)pθ(x))=argminθpdata(x)pdata(x)pθ(x)dx=argminθpdata(x)logpdata(x)dxpdata(x)logpθ(x)dx=argminθpdata(x)logpθ(x)dx=argmaxθpdata(x)logpθ(x)dx\begin{aligned} \theta_{KL}^* &= \arg\min_{\theta}D_{KL}(p_{data}(x)\parallel p_\theta(x))\\ &= \arg\min_{\theta} \int p_{data}(x)\frac{p_{data}(x)}{p_\theta(x)} dx\\ &= \arg\min_{\theta}\int p_{data}(x)\log p_{data}(x) dx - \int p_{data}(x)\log p_\theta(x)dx \\ &= \arg\min_{\theta} - \int p_{data}(x)\log p_\theta(x)dx \\ &= \arg\max_{\theta} \int p_{data}(x)\log p_\theta(x)dx \end{aligned}

实际上,真实的数据分布 pdata(x)p_{data}(x) 是未知的,我们只有从 pdata(x)p_{data}(x) 采样得到的一批数据 X={x1,,xn}pdata(x)X=\{x_1,\dots,x_n\}\sim p_{data}(x). 基于大数定律,我们有

1ni=1nlogp(θiθ)=Expdata[logpθ(x)]=pdata(x)logpθ(x)dx,n\frac{1}{n}\sum_{i=1}^n\log p(\theta_i\mid \theta)=\mathbb{E}_{x\sim p_{data}}[\log p_\theta(x)] = \int p_{data}(x)\log p_\theta(x)dx, n\to \infty

这样,最大似然估计就与最小化 KL divergence 构建起了联系:

θMLE=argmaxθi=1nlogp(xiθ)=argmaxθpdata(x)logpθ(x)dx=θKL,n.\begin{aligned} \theta_{MLE}^*&=\arg\max_{\theta} \sum_{i=1}^n \log p(x_i\mid \theta)\\ &= \arg\max_{\theta} \int p_{data}(x)\log p_\theta(x)dx\\ &= \theta_{KL}^*, n\to\infty. \end{aligned}

也就是说,当采样样本足够多的时候,最大似然估计和最小 KL divergence 是等价的。监督学习中,我们先从真实分布 pdata(x,y)p_{data}(x,y) 中收集一个数据集 D={(xi,yi)}\mathcal{D}=\{(x_i,y_i)\}, 然后我们会基于模型 fθ:XYf_\theta:\mathcal{X}\to\mathcal{Y} 和损失函数 L:Y×YR\mathcal{L}:\mathcal{Y}\times\mathcal{Y}\to\mathbb{R} 来优化模型参数 θ\theta:

argminθE(xi,yi)D[L(fθ(xi),yi)]\arg\min_\theta \mathbb{E}_{(x_i,y_i)\sim\mathcal{D}}[\mathcal{L}(f_\theta(x_i), y_i)]

对于使用 cross-entropy loss 的分类问题以及 MSE loss 的回归问题,其目标函数实际上都是最小化 KL divergence.

Reverse KL

对目标函数进行简化,得到

argminθDKL(Qθpdata)=argmaxθExQθ[logpdata(x)]ExQθ[logQθ(x)]\arg\min_\theta D_{KL}(Q_\theta\parallel p_{data}) = \arg\max_\theta \mathbb{E}_{x\sim Q_\theta}\left[\log p_{data}(x)\right] - \mathbb{E}_{x\sim Q_\theta}\left[\log Q_\theta(x)\right]

实际在计算时,我们需要知道真实概率分布在采样点上的概率值 pdata(x)p_{data}(x).

Reverse KL 代表的含义为,我们从分布 pθ(x)p_\theta(x) 中进行采样,然后最大化采样点在 pdata(x)p_{data}(x) 中的概率分布。entropy item 鼓励 pθp_\theta 尽可能均匀分布(覆盖广),从而最终结果满足:pθ(x)p_\theta(x) 概率很高时,pdata(x)p_{data}(x) 的概率也需要很高。注意到与 forward KL 不同,Reverse KL 中包含 entropy 项,其避免了 pθp_\theta 收缩到 pdatap_{data} 的某一个 非常窄的 mode 上,最终结果是 pθp_\theta 会找到 pdatap_{data} 的一个 high probability 以及 wide support 的 mode, 然后进行覆盖。

一般来说,reinforcement learning 对应的就是 reverse KL, 这是因为我们希望 policy model 不要离 reference model 太远,并不一定要 cover 所有的 mode.

Experiments on forward and Reverse KL

我们通过概率分布来可视化 forward KL 与 reverse KL 的区别,验证二者不同的模式。

我们假设 pdata=w1N(μ1,σ12)+w2N(μ2,σ22)p_{data}=w_1\mathcal{N}(\mu_1, \sigma_1^2)+w_2\mathcal{N}(\mu_2, \sigma_2^2), 然后我们用一个 normal distribution pθ=N(μ,σ2)p_\theta=\mathcal{N}(\mu, \sigma^2) 来近似 pdatap_{data}, 这里 θ=(μ,σ2)\theta=(\mu, \sigma^2). 对于 forward KL, 我们可以从理论上得出最优解,对应的 μ=w1μ1+w2μ2\mu=w_1\mu_1+w_2\mu_2, 而 reverse KL 则只能通过优化的方式进行求解,并且解与初始化条件相关,下面是相关的实验结果

首先我们令 w1=w2=0.5w_1=w_2=0.5, μ1=μ2=4.0\mu_1=\mu_2=4.0, σ1=σ2=1\sigma_1=\sigma_2=1, reverse KL 的初始化条件为 θ0=(2,1)\theta_0=(2,1), 对应的结果为

visualization of forward KL v.s. reverse KL (1)

接下来我们改变 reverse KL 的初始化条件为 θ0=(2,1)\theta_0=(-2,1), 对应的结果为

visualization of forward KL v.s. reverse KL (2)

可以看到,与前面分析一致,使用 forward KL 时,最终得到的 pθp_\theta 会倾向于拟合分布的中心 (mean seeking), 即 μ(pθ)=μ(pdata)\mu(p_\theta)=\mu(p_{data}), 而使用 reverse KL 时,最终得到的 PP 会倾向于拟合分布的 mode (mode seeking).

Applications to RL

Remark 本节内容主要参考了 Understanding KL Divergence Estimators in RL: From Value Approximation to Gradient Estimation

在本节中,我们将基于 RL 来推导 KL 的相关性质。为了统一,这里我们使用 RL 中常见的 notation 来进行计算

notationdescription
πθ\pi_\thetapolicy model with parameter θ\theta
πref\pi_{ref}reference model
πold\pi_{old}behavior model to sample from
sθ(x)=θlogπθ(x)s_\theta(x)=\nabla_\theta \log \pi_\theta(x)score function
ρ(x)=πθ(x)/πold(x)\rho(x)=\pi_\theta(x)/\pi_{old}(x)importance weight
sg()\mathrm{sg}(\cdot)stop gradient operation

首先 score function 有一个期望为 0 的性质:

Exπθ[sθ(x)]=xπθ(x)θlogπθ(x)dx=xθπθ(x)dx=θxπθ(x)dx=θ1=0\mathbb{E}_{x\sim\pi_\theta}[s_\theta(x)]=\int_x \pi_\theta(x)\nabla_\theta \log \pi_\theta(x)dx = \int_x\nabla_\theta \pi_\theta(x)dx= \nabla_\theta\int_x \pi_\theta(x)dx =\nabla_\theta1 = 0

接下来,我们分别推导 forward KL 和 reverse KL 的梯度。对于 forward KL, 我们有

θDKL(πrefπθ)=πrefθlogπθdx=Eπref[sθ]=Eπθ[πrefπθsθ]\nabla_\theta D_{KL}(\pi_{ref}\parallel \pi_\theta) = -\int \pi_{ref}\nabla_\theta \log \pi_\theta dx=-\mathbb{E}_{\pi_{ref}}[s_\theta] = \boxed{-\mathbb{E}_{\pi_\theta}\left[\frac{\pi_{ref}}{\pi_\theta}s_\theta\right]}

对于 reverse KL,我们有

θDKL(πθπref)=[θπθlogπθπref+πθθlogπθπref]dx=πθsθlogπθπrefdx+πθsθdx=Eπθ[sθlogπθπref]+Eπθ[sθ]=Eπθ[sθlogπθπref]\begin{aligned} \nabla_\theta D_{KL}(\pi_\theta\parallel \pi_{ref})& = \int\left[\nabla_\theta \pi_\theta\cdot\log\frac{\pi_\theta}{\pi_{ref}} + \pi_\theta \nabla_\theta\log \frac{\pi_\theta}{\pi_{ref}}\right]dx\\ &= \int \pi_\theta s_\theta\log \frac{\pi_\theta}{\pi_{ref}}dx + \int \pi_\theta s_\theta dx\\ &= \mathbb{E}_{\pi_\theta}\left[s_\theta\log \frac{\pi_\theta}{\pi_{ref}}\right]+\mathbb{E}_{\pi_\theta}[s_\theta]\\ &= \boxed{\mathbb{E}_{\pi_\theta}\left[s_\theta\log \frac{\pi_\theta}{\pi_{ref}}\right]} \end{aligned}

这里我们使用了 θπθ=πθsθ\nabla_\theta\pi_\theta=\pi_\theta s_\theta , θlogπθ=sθ\nabla_\theta\log\pi_\theta=s_\theta 以及 前面推导的 Eπθ[sθ]=0\mathbb{E}_{\pi_\theta}[s_\theta]=0 的结论.

RL 的目标函数如下

J(θ)=Eτπθ[t=0Tγtr(st,at)]βDKL(πθπref)\mathcal{J}(\theta) = \mathbb{E}_{\tau\sim \pi_\theta}\left[\sum_{t=0}^T\gamma^tr(s_t,a_t)\right] - \beta D_{KL}(\pi_\theta\parallel \pi_{ref})

Ki as Loss

由于 KL divergcne 不能直接计算(或者计算难度较大),因此,基于前面对 KL divergence estimation 的分析,我们可以使用如下代理损失函数来优化我们的模型:

J1(θ)=Eτπθ[t=0Tγtr(st,at)]βki(πθ,πref)\mathcal{J}_1(\theta) = \mathbb{E}_{\tau\sim \pi_\theta}\left[\sum_{t=0}^T\gamma^tr(s_t,a_t)\right] - \beta k_i(\pi_\theta, \pi_{ref})

这里 i{1,2,3}i\in\{1,2,3\} 代表了我们使用的估计。从直觉上来说,这样做是没问题的,但是我们将从数学分析上说明,k1,k3k_1,k_3 作为损失函数都存在问题。其核心问题在于

E[DKL^]=DKLE[θDKL^]=θDKL\mathbb{E}[\widehat{D_{KL}}]=D_{KL} \nRightarrow \mathbb{E}[\nabla_\theta \widehat{D_{KL}}] =\nabla_\theta D_{KL}

也就是说,KL divergence estimation 的无偏性不能推导出 KL divergence estimation gradient 的无偏性,这是因为我们在求期望时,对应的概率分布可能也与参数相关。实际上,我们有

θDKL(πθπref)=θExπθ[DKL^(πθπref)]=Exπθ[θDKL^(πθπref)]+Exπθ[DKL^(πθπref)θπθ(x)]Exπθ[θDKL^(πθπref)]\begin{aligned} \nabla_\theta D_{KL}(\pi_\theta\parallel \pi_{ref}) &= \nabla_\theta \mathbb{E}_{x\sim\pi_\theta}[\widehat{D_{KL}}(\pi_\theta\parallel \pi_{ref})]\\ &= \mathbb{E}_{x\sim\pi_\theta}[\nabla_\theta \widehat{D_{KL}}(\pi_\theta\parallel \pi_{ref})] + \mathbb{E}_{x\sim\pi_\theta}[\widehat{D_{KL}}(\pi_\theta\parallel \pi_{ref})\nabla_\theta \pi_\theta(x)]\\ &\neq \mathbb{E}_{x\sim\pi_\theta}[\nabla_\theta \widehat{D_{KL}}(\pi_\theta\parallel \pi_{ref})] \end{aligned}

因此 θDKL^\nabla_\theta \widehat{D_{KL}}θDKL\nabla_\theta D_{KL} 的一个有偏估计。

我们分别来分析一下 k1,k2,k3k_1,k_2,k_3 梯度,

θk1=θ[logπrefπθ]=sθθk2=θ[12(logπrefπθ)2]=logπrefπθsθθk3=θ[πrefπθ1logπrefπθ]=(1πrefπθ)sθ\begin{aligned} \nabla_\theta k_1 &= \nabla_\theta\left[-\log \frac{\pi_{ref}}{\pi_\theta}\right] = s_\theta\\ \nabla_\theta k_2 &= \nabla_\theta\left[\frac12\left(\log \frac{\pi_{ref}}{\pi_\theta}\right)^2\right] = -\log \frac{\pi_{ref}}{\pi_\theta}s_\theta\\ \nabla_\theta k_3 &= \nabla_\theta\left[\frac{\pi_{ref}}{\pi_\theta}-1- \log \frac{\pi_{ref}}{\pi_\theta}\right] = \left(1 - \frac{\pi_{ref}}{\pi_\theta}\right)s_\theta \end{aligned}

此时对应的梯度的期望为

Eπθ[θk1]=Eπθ[sθ]=0Eπθ[θk2]=Eπθ[logπrefπθsθ]=θDKL(πθπref)Eπθ[θk3]=Eπθ[(1πrefπθ)sθ]=θDKL(πrefπθ)\begin{aligned} \mathbb{E}_{\pi_{\theta}}[\nabla_\theta k_1] &= \mathbb{E}_{\pi_{\theta}}[s_\theta]=0\\ \mathbb{E}_{\pi_{\theta}}[\nabla_\theta k_2] &= \mathbb{E}_{\pi_{\theta}}\left[-\log \frac{\pi_{ref}}{\pi_\theta}s_\theta\right]=\nabla_\theta D_{KL}(\pi_\theta\parallel \pi_{ref})\\ \mathbb{E}_{\pi_{\theta}}[\nabla_\theta k_3] &= \mathbb{E}_{\pi_{\theta}}\left[\left(1 - \frac{\pi_{ref}}{\pi_\theta}\right)s_\theta\right]=\nabla_\theta D_{KL}(\pi_{ref}\parallel \pi_\theta)\\ \end{aligned}

也就是说,k1k_1 估计的梯度的期望为 0,对整体训练没有任何帮助,k3k_3 估计的梯度的期望等价于优化 forward KL, **只有 k2k_2 估计的梯度的期望等价于优化 reverse KL.


在实际代码实现的时候,KL divergence 有两种不同的实现形式:

第一种是根据定义将 KL divergence 作为损失函数的一部分,此时我们的 KL divergence 参与反向传播,对应的实现方式如下

loss = -advantage * log_prob + beta * kl

第二种是只调整 reward, 而不参与反向传播(通过 sg()\mathrm{sg}(\cdot) 实现),对应的实现方式如下所示

shaped_reward = reward - beta * kl.detach()

这两者对于模型的训练影响很大,下面我们分别来进行介绍

KL as Loss

为了统一 on-policy 和 off-policy 两种形式,我们使用一个统一的表达形式,即

L=ρkiL=\rho k_i

此时对应的 RL 目标函数为

J2(θ)=Eτπθ[t=0Tγtr(st,at)]βρki(πθ,πref)\mathcal{J}_2(\theta) = \mathbb{E}_{\tau\sim \pi_\theta}\left[\sum_{t=0}^T\gamma^tr(s_t,a_t)\right] - \beta\rho k_i(\pi_\theta, \pi_{ref})

这里

ρ=πθsg(πold)\rho = \frac{\pi_\theta}{\mathrm{sg}(\pi_{old})}

是 importance weight,

  1. 当算法为 on-policy 时,πθ=πold\pi_\theta=\pi_{old}, ρ1\rho\equiv1.
  2. 当算法为 off-policy 时,ρ=πθ/πold\rho=\pi_\theta/\pi_{old}, θρ=ρsθ\nabla_\theta \rho=\rho s_\theta.

通过这种方式,我们使得参数分布本身不会对梯度计算产生影响,从而使得对期望进行求导和对导数求期望相等,即

θEπold[k]=πold(x)θkdx=Eπold[θk]\nabla_\theta\mathbb{E}_{\pi_{old}}[k] = \int \pi_{old}(x)\nabla_\theta kdx= \mathbb{E}_{\pi_{old}}[\nabla_\theta k]

接下来我们来计算对应估计的梯度的期望,即 E[θ(ρki)]\mathbb{E}[\nabla_\theta(\rho k_i)], 首先我们计算对应的梯度

θ(ρk1)=ρsθk1+rρθ=ρsθ(k1+1)θ(ρk2)=ρsθk2+ρ(logπrefπθsθ)=ρsθ(k1+k2)θ(ρk3)=ρsθk3+ρ(1πrefπθ)sθ=ρsθ(k3+1πrefπθ)=ρsθk1\begin{aligned} \nabla_\theta (\rho k_1) &= \rho s_\theta k_1+r\rho_\theta=\rho s_\theta(k_1+1)\\ \nabla_\theta (\rho k_2) &= \rho s_\theta k_2+\rho\left(-\log \frac{\pi_{ref}}{\pi_\theta}s_\theta\right)=\rho s_\theta(k_1+k_2)\\ \nabla_\theta (\rho k_3) &= \rho s_\theta k_3+\rho\left(1 - \frac{\pi_{ref}}{\pi_\theta}\right)s_\theta=\rho s_\theta\left(k_3+1-\frac{\pi_{ref}}{\pi_\theta}\right)=\rho s_\theta k_1 \end{aligned}

注意到 Eπold[ρki]=Eπθ[ki]\mathbb{E}_{\pi_{old}} [\rho k_i]=\mathbb{E}_{\pi_{\theta}}[k_i] 以及 Eπθ[sθ]=0\mathbb{E}_{\pi_{\theta}}[s_\theta]=0, 我们对上述梯度求期望得到

Eπold[θ(ρk1)]=Eπold[ρsθ(k1+1)]=Eπθ[sθk1]=θDKL(πθπref)Eπold[θ(ρk2)]=Eπold[ρsθ(k1+k2)]=θEπθ[k2]Eπold[θ(ρk3)]=Eπold[ρsθk1]=θDKL(πθπref)\begin{aligned} \mathbb{E}_{\pi_{old}}[\nabla_\theta (\rho k_1)] &= \mathbb{E}_{\pi_{old}}[\rho s_\theta(k_1+1)]=\mathbb{E}_{\pi_{\theta}}[s_\theta k_1]=\nabla_\theta D_{KL}(\pi_\theta\parallel \pi_{ref})\\ \mathbb{E}_{\pi_{old}}[\nabla_\theta (\rho k_2)] &= \mathbb{E}_{\pi_{old}}[\rho s_\theta(k_1+k_2)]=\nabla_\theta \mathbb{E}_{\pi_\theta}[k_2]\\ \mathbb{E}_{\pi_{old}}[\nabla_\theta (\rho k_3)] &= \mathbb{E}_{\pi_{old}}[\rho s_\theta k_1]=\nabla_\theta D_{KL}(\pi_\theta\parallel \pi_{ref}) \end{aligned}

这里在计算 Eπold[θ(ρk2)]\mathbb{E}_{\pi_{old}}[\nabla_\theta (\rho k_2)] 时,我们使用了 Leibniz 乘法法则:

Eπold[ρsθ(k1+k2)]=Eπθ[sθk2]+Eπθ[θk2]=θEπθ[k2] \mathbb{E}_{\pi_{old}}[\rho s_\theta(k_1+k_2)]= \mathbb{E}_{\pi_{\theta}}[s_\theta k_2]+\mathbb{E}_{\pi_{\theta}}[\nabla_\theta k_2]=\nabla_\theta\mathbb{E}_{\pi_{\theta}}[k_2]

可以看到,ρk1\rho k_1ρk3\rho k_3 都满足梯度与期望的可交换性,而 ρk2\rho k_2 不满足,为了解决这个问题,我们可以使用 stop gradient, 即 sg(ρ)l2\mathrm{sg}(\rho)l_2, 此时,我们有

θ(sg(ρ)k2)=sg(ρ)θk2=ρsθk1\nabla_\theta(\mathrm{sg}(\rho) k_2) = \mathrm{sg}(\rho)\nabla_\theta k_2 = \rho s_\theta k_1

对其求期望有

Eπold[θ(sg(ρ)k2)]=Eπold[ρsθk1]=Eπθ[sθk1]=θDKL(πθπref)\mathbb{E}_{\pi_{old}}[\nabla_\theta(\mathrm{sg}(\rho) k_2)] = \mathbb{E}_{\pi_{old}}[\rho s_\theta k_1] = \mathbb{E}_{\pi_{\theta}}[s_\theta k_1]=\nabla_\theta D_{KL}(\pi_\theta\parallel \pi_{ref})

我们将如上结果总结为下表

Lossgradientexpected gradientobjective
ρk1\rho k_1ρsθ(k1+1)\rho s_\theta (k_1+1)θDKL(πθπref)\nabla_\theta D_{KL}(\pi_\theta\parallel \pi_{ref})reverse KL
ρk2\rho k_2ρsθ(k1+k2)\rho s_\theta (k_1+k_2)θEπθ[k2]\nabla_\theta\mathbb{E}_{\pi_{\theta}}[k_2]f-divergence
sg(ρ)k2\mathrm{sg}(\rho) k_2ρsθk1\rho s_\theta k_1θDKL(πθπref)\nabla_\theta D_{KL}(\pi_\theta\parallel \pi_{ref})reverse KL
ρk3\rho k_3ρsθk1\rho s_\theta k_1θDKL(πθπref)\nabla_\theta D_{KL}(\pi_\theta\parallel \pi_{ref})reverse KL

接下来,我们就可以分析在 on-policy 和 off-policy 场景下分析不同 estimator 的性质了。

如果说,我们显式加入 ρ\rho, 则根据上表我们可以使用上表的 ρk1\rho k_1, sg(ρ)k2\mathrm{sg}(\rho) k_2 以及 ρk3\rho k_3 都可以作为损失函数的代替。

注 实际上 on-policy 场景下使用 k2k_2 也有用的原因在于 θk2=sθk1\nabla_\theta k_2=s_\theta k_1, 也就是 k2k_2ρk3\rho k_3 的梯度相同,其本质上是一个等效梯度。但是其收敛得到的 policy 与 target optimal policy 不同

接下来,我们来分析一下 ρk1,sg(ρ)k2,ρk3\rho k_1, \mathrm{sg}(\rho)k_2, \rho k_3 这三种估计的梯度的 variance, 为了避免混淆,【2】使用了 “projection variance in any direction” 的概念,即任意取一个向量 uu, 然后计算 ρk1\rho k_1 和后两者之间对应的 variance 的差(由于 sg(ρ)k2\mathrm{sg}(\rho)k_2 的梯度与 ρk3\rho k_3 相同,因此这里我们仅计算 ρk3\rho k_3),得到:

var[θ(ρk1)Tu]var[θ(ρk3)Tu]=(Eπold[(θ(ρk1)Tu)2]Eπold2[θ(ρk1)Tu])(Eπold[(θ(ρk3)Tu)2]Eπold2[θ(ρk3)Tu])=Eπold[(θ(ρk1)Tu)2]Eπold[(θ(ρk3)Tu)2]=Eπold[ρ(x)2(s(θ)(x)Tu)2(2k1(x)+1)]\begin{aligned} \mathrm{var}[\nabla_\theta (\rho k_1)^Tu] - \mathrm{var}[\nabla_\theta (\rho k_3)^Tu] &= (\mathbb{E}_{\pi_{old}}[(\nabla_\theta (\rho k_1)^Tu)^2] -\mathbb{E}_{\pi_{old}}^2[\nabla_\theta (\rho k_1)^Tu] ) - (\mathbb{E}_{\pi_{old}}[(\nabla_\theta (\rho k_3)^Tu)^2] -\mathbb{E}_{\pi_{old}}^2[\nabla_\theta (\rho k_3)^Tu] ) \\ &= \mathbb{E}_{\pi_{old}}[(\nabla_\theta (\rho k_1)^Tu)^2] - \mathbb{E}_{\pi_{old}}[(\nabla_\theta (\rho k_3)^Tu)^2]\\ &= \mathbb{E}_{\pi_{old}}[\rho(x)^2(s(\theta)(x)^Tu)^2(2k_1(x)+1)] \end{aligned}

πθ\pi_\thetaπref\pi_{ref} 比较接近时,我们有

πref(x)πθ(x)=1+ϵ(x), where ϵ(x)1\frac{\pi_{ref}(x)}{\pi_\theta(x)} = 1+\epsilon(x), \text{ where } |\epsilon(x)| \ll 1

此时

2k1(x)+1=12log(1+ϵ(x))12ϵ(x)02k_1(x) + 1 = 1-2\log(1+\epsilon(x))\approx 1-2\epsilon(x) \geq 0

从而我们有

var[θ(ρk1)]var[θ(ρk3)]=var[θ(sg(ρ)k2)]\boxed{\mathrm{var}[\nabla_\theta (\rho k_1)]\geq \mathrm{var}[\nabla_\theta (\rho k_3)]=\mathrm{var}[\nabla_\theta (\mathrm{sg}(\rho)k_2)]}

即当 πθ\pi_\thetaπref\pi_{ref} 比较接近时,ρk3\rho k_3 的 variance 比 ρk1\rho k_1 更小,这是由于 ρsθ(k1+1)\rho s_\theta (k_1+1) 额外包含了一个 期望为零的项,这导致了其 variance 比较高。在 DeepSeek-V3.2 中,作者就使用了 ρk3\rho k_3 来降低梯度的 variance, 提高训练的稳定性。

【3】将相关的估计总结为了下表的形式

TypeLossGradientExpected gradientObjectiveBiasedVariance
on/off-policyρk1\rho k_1ρsθ(k1+1)\rho s_\theta (k_1+1)θDKL(πθπref)\nabla_\theta D_{KL}(\pi_\theta\parallel \pi_{ref})reverse KLunbiasedhigh
on/off-policyρk2\rho k_2ρsθ(k1+k2)\rho s_\theta (k_1+k_2)θEπθ[k2]\nabla_\theta\mathbb{E}_{\pi_{\theta}}[k_2]f-divergencebiased-
on/off-policysg(ρ)k2\mathrm{sg}(\rho) k_2ρsθk1\rho s_\theta k_1θDKL(πθπref)\nabla_\theta D_{KL}(\pi_\theta\parallel \pi_{ref})reverse KLunbiasedlow
on/off-policyρk3\rho k_3ρsθk1\rho s_\theta k_1θDKL(πθπref)\nabla_\theta D_{KL}(\pi_\theta\parallel \pi_{ref})reverse KLunbiasedlow

【3】还强调了一点就是我们的损失函数必须显式包含 ρ\rho, 在 on-policy 场景下,虽然 ρ1\rho\equiv1, 但是在反向传播时我们通过 θρ=sθ\nabla_\theta \rho=s_\theta 保留了采样信息从而避免了梯度估计期望的错配问题。

对于 ρk1\rho k_1 variance 比较高的特点,我们还可以采用 variance reduction 的方法来降低不同估计的 variance. 【TODO】

analytic gradient 当 action space 有限时,我们还可以使用解析梯度【TODO】

As a Reward Reshaping Item

接下来我们来探究一下第二种形式,即 KL divergence 只影响最终的 reward, 而不参与反向传播。对应的代理目标函数形式为

J3(θ)=Eτπθ[R]β sg(ki(πθ,πref))\mathcal{J}_3(\theta) = \mathbb{E}_{\tau\sim \pi_\theta}\left[R\right] - \beta\ \mathrm{sg}(k_i(\pi_\theta, \pi_{ref}))

这里 R=t=0Tγtr(st,at)R=\sum_{t=0}^T\gamma^tr(s_t,a_t) 为 accumulative reward

首先,基于前面分析,我们可以得到原始目标函数的梯度为

θJ(θ)=θEπθ[R]βθDKL(πθ,πref)=Eπθ[sθR]βEπθ[sθlogπθπref]=Eπθ[sθ(Rβk1)]\begin{aligned} \nabla_\theta \mathcal{J}(\theta) &= \nabla_\theta\mathbb{E}_{\pi_\theta}\left[R\right] - \beta \nabla_\theta D_{KL}(\pi_\theta, \pi_{ref})\\ &= \mathbb{E}_{\pi_\theta}\left[s_\theta R\right]-\beta \mathbb{E}_{\pi_\theta}\left[s_\theta\log \frac{\pi_\theta}{\pi_{ref}}\right]\\ &= \mathbb{E}_{\pi_\theta}\left[s_\theta(R-\beta k_1) \right] \end{aligned}

代理目标函数的梯度为

θJ3(θ)=Eπθ[sθ(Rβki)]\nabla_\theta \mathcal{J}_3(\theta) = \mathbb{E}_{\pi_\theta}\left[s_\theta(R-\beta k_i) \right]

显然,当我们使用 k1k_1 时,我们有 θJ(θ)=θJ3(θ)\nabla_\theta \mathcal{J}(\theta)=\nabla_\theta \mathcal{J}_3(\theta).

当我们使用 k2k_2 时,带入 k2k_2 表达式易知 θJ3(θ)θJ(θ)\nabla_\theta \mathcal{J}_3(\theta)\neq \nabla_\theta \mathcal{J}(\theta),

当我们使用 k3k_3 时,

Eπθ[sθki]=Eπθ[sθ(πrefπθ1logπrefπθ)]=Eπθ[sθπrefπθ]Eπθ[sθ]Eπθ[sθlogπrefπθ]=sθk1θDKL(πrefπθ)\begin{aligned} \mathbb{E}_{\pi_\theta}\left[s_\theta k_i \right] &= \mathbb{E}_{\pi_\theta}\left[s_\theta \left(\frac{\pi_{ref}}{\pi_\theta}-1- \log \frac{\pi_{ref}}{\pi_\theta} \right)\right]\\ &= \mathbb{E}_{\pi_\theta}\left[s_\theta \frac{\pi_{ref}}{\pi_\theta}\right] - \mathbb{E}_{\pi_\theta}\left[s_\theta \right] - \mathbb{E}_{\pi_\theta}\left[s_\theta \log \frac{\pi_{ref}}{\pi_\theta} \right]\\ &=s_\theta k_1 -\nabla_\theta D_{KL}(\pi_{ref}\parallel \pi_\theta) \end{aligned}

此时,θJ3(θ)θJ(θ)\nabla_\theta \mathcal{J}_3(\theta)\neq \nabla_\theta \mathcal{J}(\theta). 因此,在 on-policy 场景下,只有 k1k_1 对应的梯度是无偏的

在 off-policy 场景下,由于 Off-policy 只影响 RR 的计算,因此原始目标函数和代理目标函数的梯度仍然保持不变,on-policy 场景的结论也适用。

总之,当我们将 KL divergence 作为 reward reshaping item 时,只有 k1k_1 产生的梯度是无偏的。

Comparison of Two Paradigms

接下来我们来比较一下 KL divergence 作为 loss 和 reward shaping item 的异同之处。首先,两者对于梯度的贡献分别为

ρsθk1Eπold[ρsθk1]\begin{align} &\rho s_\theta k_1\tag{loss}\\ & \mathbb{E}_{\pi_{old}}[\rho s_\theta k_1]\tag{reward shaping} \end{align}

即两者在期望上时一致的。但是两者也存在不一致的地方,即 KL divergence 作为 loss 时不会影响 RR, 而作为 reward shaping item 时会影响。因此这就导致两者的优化方向不一致。

Experiments

首先,我们来验证前面的结论,我们构造一个包含 100100 个 arms 的 multi-arm bandits, 然后令

πref=ϵ1,π=ϵ1+ϵ2\pi_{ref}=\epsilon_1, \pi= \epsilon_1+\epsilon_2

其中 ϵ1,ϵ2N(0,1)\epsilon_1,\epsilon_2\sim\mathcal{N}(0,1), 我们实验 100 次然后取平均值,然后分别计算 estimator 与真实 KL divergence 之间的 MSE 和 estimator gradient 与真实 kl divergence gradient 的 RMSE, 结果如下图所示

bias of KL divergence estimators and their gradients

可以看到,这验证了我们之前分析的结论,即 k1k_1k3k_3 是无偏估计,而在计算梯度时,只有 k2k_2 梯度的期望与真实 KL divergence 的梯度相同。

Overview

我们在本节总结前面的分析,如下表所示

TypeLossGradientExpected gradientObjectiveBiasedVariance
on-policyk1k_1sθs_\theta00constantsbiased-
on-policyk2k_2logrsθ-\log r s_\thetaθDKL(πθπref)\nabla_\theta D_{KL}(\pi_\theta\parallel \pi_{ref})reverse KLunbiased
on-policyk3k_3(1r)sθ(1-r)s_\thetaθDKL(πrefπθ)\nabla_\theta D_{KL}(\pi_{ref}\parallel \pi_\theta)forward KLbiased-
on/off-policyρk1\rho k_1ρsθ(k1+1)\rho s_\theta (k_1+1)θDKL(πθπref)\nabla_\theta D_{KL}(\pi_\theta\parallel \pi_{ref})reverse KLunbiasedhigh
on/off-policyρk2\rho k_2ρsθ(k1+k2)\rho s_\theta (k_1+k_2)θEπθ[k2]\nabla_\theta\mathbb{E}_{\pi_{\theta}}[k_2]f-divergencebiased-
on/off-policysg(ρ)k2\mathrm{sg}(\rho) k_2ρsθk1\rho s_\theta k_1θDKL(πθπref)\nabla_\theta D_{KL}(\pi_\theta\parallel \pi_{ref})reverse KLunbiasedlow
on/off-policyρk3\rho k_3ρsθk1\rho s_\theta k_1θDKL(πθπref)\nabla_\theta D_{KL}(\pi_\theta\parallel \pi_{ref})reverse KLunbiasedlow
on/off-policyρsg(k1)\rho\mathrm{sg}(k_1)---unbiased-
on/off-policyρsg(k2)\rho \mathrm{sg}(k_2)---biased-
on/off-policyρsg(k3)\rho \mathrm{sg}(k_3)---biased-

Conclusion

在本文中,我们详细介绍了 KL-divergence 的基本性质,相关估计方法以及在机器学习特别是 RL 领域中的应用。最终结论为:

  1. 如果希望稳定可控,则将 KL divergence 作为 loss item; 如果希望更灵活,与奖励信号结合的话,则将其作为 reward shaping item.
  2. 使用 KL divergence 作为 loss item 时,on-policy 场景下使用 k2k_2 近似 KL divergence 效果最好;off-policy 场景下,使用 sg(ρ)k2,ρk3\mathrm{sg}(\rho)k_2, \rho k_3 效果最好
  3. 使用 KL divergence 作为 reward shaping item 时,k1k_1 的效果最好