文章

KL散度的估计方法

2025年4月24日 Harry

现在各个训练框架中的KL散度计算都参考了John Schulman的博客

如果要准确的计算两个分布之间的KL散度,离散分布需要遍历所有的 xx,而连续分布需要求积分,在复杂的深度网络中几乎是无法计算的。因此在训练过程中,更多的是通过蒙特卡洛,对训练数据进行采样。

DKL(PQ)=P(x)log(P(x)Q(x))=ExP[log(P(x)Q(x))]D_{KL}(P||Q)=\sum P(x)log(\frac{P(x)}{Q(x)}) = \mathbb{E}_{x\sim P}\left[log(\frac{P(x)}{Q(x)}) \right]

k1k_1#

最简单的估计 k1k_1log(P(x)Q(x))=log(r)log(\frac{P(x)}{Q(x)})=-log(r), 其中r=Q(x)P(x)r=\frac{Q(x)}{P(x)} ,它是无偏的,有正确的平均值,但方差极高,尤其是 log(r)-log(r) 的方向有正有负,而KL散度是不会小于0的,因此这个简单的估计在工程中会引入极大的不稳定。

这里提前算一下rr的期望,由于我们是在P(x)P(x)的视角下求期望,因此:

ExP(r)=ExP[Q(x)P(x)]=x[P(x)Q(x)P(x)]=xQ(x)=1\mathbb{E}_{x\sim P}(r)=\mathbb{E}_{x\sim P}\left[ \frac{Q(x)}{P(x)} \right]=\sum_{x} \left[ P(x) \frac{Q(x)}{P(x)}\right] =\sum_{x} Q(x)=1

k2k_2#

估计量 k2k_212(log(r))2\frac{1}{2}(log(r))^2 是从f-散度(KL散度是它的一个特例)推导出来的,它始终>0,这让它的方差很低,不会像k1k_1那样在正负之间来回横跳。根据泰勒展开,k2k_2k1k_1的期望值在二阶近似上是完全相等的,因此当P和Q两个分布接近时,k2k_2可以用来估计KL散度,但由于在三阶近似上存在差异,因此k2k_2是有偏估计,但它低方差、非负

k3k_3#

那么怎么能得到一个无偏且低方差,又始终为正的估计呢?

k1k_1打个补丁,这个补丁有两个要求:

  1. 保持k1k_1的无偏:补丁的期望为0,在多次采样中,这个补丁的平均值为0
  2. 消除负数k1k_1加上这个补丁后,必须始终为正

John Schulman找到了一个补丁:r1r-1log(r)-log(r)是一个开口向上的凸函数,在几何学中它的曲线始终在切线的上方,而 1r1-r 正是 log(r)-log(r)r=1r=1 处的切线(注意这里的logloglnln)。因此:

log(r)1r-log(r) \ge 1-r log(r)+r10-log(r) + r - 1 \ge 0

rr 的期望是1,因此 r1r-1 的期望是0,完美的满足上述两个条件,因此估计量 k3k_3 是一个无偏、低方差、非负 的KL散度估计量。

John Schulman的博客中举了两个例子,在 P=N(0,1)P=N(0, 1)Q=N(0.1,1)Q=N(0.1, 1) 两个分布相差不大时(此时的KL散度是0.005),可以看出k1k_1的方差非常大,k2k_2的偏差较小:

biasstdev
k1k_1020
k2k_20.0021.42
k3k_301.42

而在 P=N(0,1)P=N(0, 1)Q=N(0.5,1)Q=N(0.5, 1) 两个分布相差较大时(此时KL散度是0.5),可以看出 k2k_2此时的偏差就较大了。

biasstdev
k1k_102
k2k_20.251.73
k3k_301.7

尽管 k3k_3 这个估计量有这么多优秀的特征,但 k3k_3 存在一个工程上的计算问题,训练框架在实际算logit的概率时,为了防止极小值的下溢出,训练框架中通常只会计算 log(p(x))log(p(x)),而不是 p(x)p(x),所以在计算 k3k_3 时,我们只有 log(r)log(r),为了还原k3=r1logrk_3 = r - 1 - \log r 中的rr,必须要加一个 ee 的底数 r=elog(r)r=e^{log (r)} ,这导致当 PPQQ相差较远时,例如当 log(r)=20log(r)=20r=elogr=e20r=e^{log r}=e^{20},是一个非常夸张的数字,因此在工程实践中,k2k_2 仍然是常见的估计方法。而使用 k3k_3时,则通常会加上一个clamp避免数值问题(GPRO和PPO的常见做法)。