现在各个训练框架中的KL散度计算都参考了John Schulman的博客。
如果要准确的计算两个分布之间的KL散度,离散分布需要遍历所有的 x,而连续分布需要求积分,在复杂的深度网络中几乎是无法计算的。因此在训练过程中,更多的是通过蒙特卡洛,对训练数据进行采样。
DKL(P∣∣Q)=∑P(x)log(Q(x)P(x))=Ex∼P[log(Q(x)P(x))]
k1#
最简单的估计 k1 为 log(Q(x)P(x))=−log(r), 其中r=P(x)Q(x) ,它是无偏的,有正确的平均值,但方差极高,尤其是 −log(r) 的方向有正有负,而KL散度是不会小于0的,因此这个简单的估计在工程中会引入极大的不稳定。
这里提前算一下r的期望,由于我们是在P(x)的视角下求期望,因此:
Ex∼P(r)=Ex∼P[P(x)Q(x)]=x∑[P(x)P(x)Q(x)]=x∑Q(x)=1
k2#
估计量 k2:21(log(r))2 是从f-散度(KL散度是它的一个特例)推导出来的,它始终>0,这让它的方差很低,不会像k1那样在正负之间来回横跳。根据泰勒展开,k2和k1的期望值在二阶近似上是完全相等的,因此当P和Q两个分布接近时,k2可以用来估计KL散度,但由于在三阶近似上存在差异,因此k2是有偏估计,但它低方差、非负。
k3#
那么怎么能得到一个无偏且低方差,又始终为正的估计呢?
给 k1打个补丁,这个补丁有两个要求:
- 保持k1的无偏:补丁的期望为0,在多次采样中,这个补丁的平均值为0
- 消除负数:k1加上这个补丁后,必须始终为正
John Schulman找到了一个补丁:r−1,−log(r)是一个开口向上的凸函数,在几何学中它的曲线始终在切线的上方,而 1−r 正是 −log(r) 在 r=1 处的切线(注意这里的log是ln)。因此:
−log(r)≥1−r
−log(r)+r−1≥0
而 r 的期望是1,因此 r−1 的期望是0,完美的满足上述两个条件,因此估计量 k3 是一个无偏、低方差、非负 的KL散度估计量。
John Schulman的博客中举了两个例子,在 P=N(0,1) 和 Q=N(0.1,1) 两个分布相差不大时(此时的KL散度是0.005),可以看出k1的方差非常大,k2的偏差较小:
| bias | stdev |
|---|
| k1 | 0 | 20 |
| k2 | 0.002 | 1.42 |
| k3 | 0 | 1.42 |
而在 P=N(0,1) 和 Q=N(0.5,1) 两个分布相差较大时(此时KL散度是0.5),可以看出 k2此时的偏差就较大了。
| bias | stdev |
|---|
| k1 | 0 | 2 |
| k2 | 0.25 | 1.73 |
| k3 | 0 | 1.7 |
尽管 k3 这个估计量有这么多优秀的特征,但 k3 存在一个工程上的计算问题,训练框架在实际算logit的概率时,为了防止极小值的下溢出,训练框架中通常只会计算 log(p(x)),而不是 p(x),所以在计算 k3 时,我们只有 log(r),为了还原k3=r−1−logr 中的r,必须要加一个 e 的底数 r=elog(r) ,这导致当 P 和Q相差较远时,例如当 log(r)=20,r=elogr=e20,是一个非常夸张的数字,因此在工程实践中,k2 仍然是常见的估计方法。而使用 k3时,则通常会加上一个clamp避免数值问题(GPRO和PPO的常见做法)。