KL Divergence of Gaussians
Preliminary: KL Divergence
Kullback–Leibler (KL) Divergence, also known as relative entropy or I-divergence, quantifies the difference between two probability distributions. Note that it is not a true distance metric, since it is not symmetric. We denote the KL Divergence of $P$ from $Q$ with:
$$
D_\mathrm{KL} \left( P \Vert Q \right) = \sum_{x} P(x) \lg \frac{P(x)}{Q(x)}
$$
For distributions of a continuous random variable:
$$
D_\mathrm{KL} \left( P \Vert Q \right) = \int p(x) \lg \frac{p(x)}{q(x)} \, \mathrm{d}x
$$
Clearly the KL Divergence is not symmetric: in general, $D_\mathrm{KL}(P \Vert Q) \neq D_\mathrm{KL}(Q \Vert P)$.
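The asymmetry is easy to observe numerically. A minimal sketch in plain Python, using $\lg = \log_2$ as in the definition above (the function name is just illustrative):

```python
import math

def kl_divergence(p, q):
    """D_KL(P || Q) in bits (log base 2) for discrete distributions."""
    return sum(pi * math.log2(pi / qi) for pi, qi in zip(p, q) if pi > 0)

p = [0.5, 0.5]
q = [0.9, 0.1]

print(kl_divergence(p, q))  # D_KL(P || Q) ~ 0.737
print(kl_divergence(q, p))  # D_KL(Q || P) ~ 0.531 -- not symmetric
```

Swapping the arguments changes the value, confirming that the divergence is not a distance metric.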
KL Divergence of Gaussians
Gaussian Distributions
Recall that

$$
\mathcal{N}(x; \mu, \Sigma) = \frac{1}{\sqrt{(2 \pi)^k \vert \Sigma \vert}} \exp \left\{ - \frac{1}{2} (x - \mu)^T \Sigma^{-1} (x - \mu) \right\}
$$
And here’s also a trick for processing the exponent term: we know that $(x - \mu)^T \Sigma^{-1} (x - \mu) \in \mathbb{R}$, so it equals its own trace, and since the trace is invariant under cyclic permutation we have:
$$
(x - \mu)^T \Sigma^{-1} (x - \mu) = \mathrm{tr} \left( (x - \mu)^T \Sigma^{-1} (x - \mu) \right) = \mathrm{tr} \left( (x - \mu) (x - \mu)^T \Sigma^{-1} \right)
$$
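A quick numerical check of this identity, sketched in PyTorch (which the later snippets also assume; `Sigma` here is just a randomly generated positive-definite matrix):

```python
import torch

torch.manual_seed(0)
k = 4
x = torch.randn(k, 1)
mu = torch.randn(k, 1)

# Random symmetric positive-definite covariance matrix
A = torch.randn(k, k)
Sigma = A @ A.T + k * torch.eye(k)
Sigma_inv = torch.linalg.inv(Sigma)

d = x - mu
quad = (d.T @ Sigma_inv @ d).item()               # scalar (x-mu)^T Sigma^{-1} (x-mu)
trace = torch.trace(d @ d.T @ Sigma_inv).item()   # tr((x-mu)(x-mu)^T Sigma^{-1})

print(quad, trace)  # the two forms agree up to floating-point error
```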
KL Divergence of two Gaussians
Consider two Gaussian distributions:
$$
p(x) = \mathcal{N}(x; \mu_1, \Sigma_1), \quad q(x) = \mathcal{N}(x; \mu_2, \Sigma_2)
$$
We have
$$
\begin{aligned}
D_\mathrm{KL} (p \Vert q)
&= \mathbb{E}_p \left[ \log p - \log q \right] \\
&= \frac{1}{2} \mathbb{E}_p \left[ \log \frac{\vert \Sigma_2 \vert}{\vert \Sigma_1 \vert} - (x - \mu_1)^T \Sigma_1^{-1} (x - \mu_1) + (x - \mu_2)^T \Sigma_2^{-1} (x - \mu_2) \right] \\
&= \frac{1}{2} \log \frac{\vert \Sigma_2 \vert}{\vert \Sigma_1 \vert} + \frac{1}{2} \mathbb{E}_p \left[ - (x - \mu_1)^T \Sigma_1^{-1} (x - \mu_1) + (x - \mu_2)^T \Sigma_2^{-1} (x - \mu_2) \right]
\end{aligned}
$$
Using the trick mentioned above we have
$$
\begin{aligned}
D_\mathrm{KL} (p \Vert q)
&= \frac{1}{2} \log \frac{\vert \Sigma_2 \vert}{\vert \Sigma_1 \vert} + \frac{1}{2} \mathbb{E}_p \left[ \mathrm{tr} \left( (x - \mu_2) (x - \mu_2)^T \Sigma_2^{-1} \right) \right] - \frac{1}{2} \mathbb{E}_p \left[ \mathrm{tr} \left( (x - \mu_1) (x - \mu_1)^T \Sigma_1^{-1} \right) \right] \\
&= \frac{1}{2} \log \frac{\vert \Sigma_2 \vert}{\vert \Sigma_1 \vert} + \frac{1}{2} \mathrm{tr} \left( \mathbb{E}_p \left[ (x - \mu_2) (x - \mu_2)^T \right] \Sigma_2^{-1} \right) - \frac{1}{2} \mathrm{tr} \left( \mathbb{E}_p \left[ (x - \mu_1) (x - \mu_1)^T \right] \Sigma_1^{-1} \right)
\end{aligned}
$$
Interestingly, $\mathbb{E}_p \left[ (x - \mu_1) (x - \mu_1)^T \right] = \Sigma_1$, so the trace in the last term is equal to:

$$
\mathrm{tr} \left( \Sigma_1 \Sigma_1^{-1} \right) = \mathrm{tr}(I_k) = k
$$
But the second term involves two different distributions. Expanding $(x - \mu_2) = (x - \mu_1) + (\mu_1 - \mu_2)$ and using $\mathbb{E}_p \left[ (x - \mu_1)(x - \mu_1)^T \right] = \Sigma_1$ (see the Matrix Cookbook):

$$
\begin{aligned}
\mathrm{tr} \left( \mathbb{E}_p \left[ (x - \mu_2) (x - \mu_2)^T \right] \Sigma_2^{-1} \right)
&= (\mu_1 - \mu_2)^T \Sigma_2^{-1} (\mu_1 - \mu_2) + \mathrm{tr} \left( \Sigma_2^{-1} \Sigma_1 \right)
\end{aligned}
$$
And finally:
$$
\begin{aligned}
D_\mathrm{KL} (p \Vert q)
&= \frac{1}{2} \left[ (\mu_1 - \mu_2)^T \Sigma_2^{-1} (\mu_1 - \mu_2) - k + \log \frac{\vert \Sigma_2 \vert}{\vert \Sigma_1 \vert} + \mathrm{tr} \left( \Sigma_2^{-1} \Sigma_1 \right) \right]
\end{aligned}
$$
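This closed form translates directly to code. A minimal sketch in PyTorch, using the natural log (`gaussian_kl` is a hypothetical helper name, not from the original post):

```python
import torch

def gaussian_kl(mu1, Sigma1, mu2, Sigma2):
    """Closed-form D_KL(N(mu1, Sigma1) || N(mu2, Sigma2)), natural log."""
    k = mu1.shape[0]
    Sigma2_inv = torch.linalg.inv(Sigma2)
    diff = mu1 - mu2
    quad = diff @ Sigma2_inv @ diff                       # Mahalanobis term
    logdet = torch.logdet(Sigma2) - torch.logdet(Sigma1)  # log(|S2|/|S1|)
    trace = torch.trace(Sigma2_inv @ Sigma1)
    return 0.5 * (quad - k + logdet + trace)

# Sanity check: identical Gaussians give zero divergence
mu = torch.tensor([1.0, -2.0])
Sigma = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
print(gaussian_kl(mu, Sigma, mu, Sigma).item())  # ~ 0
```

With identical inputs, the quadratic and log-determinant terms vanish and the trace term cancels $-k$ exactly.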
Further, in programming we usually consider diagonal covariances $\Sigma = \mathrm{diag} \{ \sigma_1, \dots, \sigma_k \}$ (each $\sigma_i$ denoting a per-dimension variance) and average over a batch $\mathcal{X}$:

$$
\begin{aligned}
D_\mathrm{KL} (p \Vert q)
&= \frac{1}{2 \vert \mathcal{X} \vert} \sum_{x \in \mathcal{X}} \sum_{i=1}^k \left[ -1 + \frac{(\mu_{1, i} - \mu_{2, i})^2}{\sigma_{2, i}} + \log \sigma_{2, i} - \log \sigma_{1, i} + \frac{\sigma_{1, i}}{\sigma_{2, i}} \right]
\end{aligned}
$$
```python
mu1, logvar1, mu2, logvar2 = output
kl_loss = 0.5 * torch.sum((mu1 - mu2).pow(2) / logvar2.exp()
                          + logvar2 - logvar1
                          + (logvar1 - logvar2).exp() - 1,
                          dim=-1).mean()
```
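If in doubt, this diagonal-covariance formula can be cross-checked against `torch.distributions`, which ships a registered Gaussian-to-Gaussian KL. Note that `Normal` expects standard deviations, hence `exp(logvar / 2)`; the variable names are illustrative:

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu1, logvar1 = torch.randn(5), torch.randn(5)
mu2, logvar2 = torch.randn(5), torch.randn(5)

# Per-dimension closed form (logvar = log of the per-dimension variance)
ours = 0.5 * torch.sum((mu1 - mu2).pow(2) / logvar2.exp()
                       + logvar2 - logvar1
                       + (logvar1 - logvar2).exp() - 1)

# Reference: torch.distributions takes standard deviations
ref = kl_divergence(Normal(mu1, (0.5 * logvar1).exp()),
                    Normal(mu2, (0.5 * logvar2).exp())).sum()

print(torch.allclose(ours, ref))  # True
```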
KL Divergence from $\mathcal{N}(0, I)$
Recall that in models like VAE and cVAE, the variational lower bound (VLB, also called the evidence lower bound, ELBO) is
$$
\mathrm{VLB} = \mathbb{E}_{q_\phi(z \vert x)} \log p_\theta (x \vert z) - D_\mathrm{KL} \left( q_\phi(z \vert x) \Vert p(z) \right)
$$
We want the optimal parameters

$$
(\phi^*, \theta^*) = \argmax_{\phi, \theta} \mathrm{VLB}
$$
Note that to align with most literature, here the KL Divergence is taken in the direction $D_\mathrm{KL}(q \Vert p)$.
And we assume that $p(z) = \mathcal{N} (z; 0, I)$ and $q_\phi(z \vert x) = \mathcal{N} (z; \mu, \Sigma)$, where $\Sigma = \mathrm{diag} \{ \sigma_1, \dots, \sigma_k \}$, for which:
$$
\begin{aligned}
D_\mathrm{KL} \left( q_\phi(z \vert x) \Vert p(z) \right)
&= \frac{1}{2} \left[ \Vert \mu \Vert^2_2 - k - \log \vert \Sigma \vert + \mathrm{tr}(\Sigma) \right] \\
&= - \frac{1}{2 \vert \mathcal{X} \vert} \sum_{x \in \mathcal{X}} \sum_{i=1}^k \left[ 1 + \log \sigma_{i} - \mu_{i}^2 - \sigma_{i} \right]
\end{aligned}
$$
In coding, we get `mu` and `logvar` (which is $\log \Sigma$, i.e. the per-dimension $\log \sigma_i$), and the KL Divergence loss term can be computed with:
```python
mu, logvar = output
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1).mean()
```
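As a sanity check, this expression agrees with the general closed form specialized to $\mu_2 = 0$, $\Sigma_2 = I$. A sketch (variable names are illustrative):

```python
import torch

torch.manual_seed(0)
k = 3
mu = torch.randn(k)
logvar = torch.randn(k)

# Standard VAE KL term (single sample, no batch mean)
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# General closed form with mu2 = 0, Sigma2 = I, diagonal Sigma1:
# 0.5 * [mu^T mu - k - log|Sigma| + tr(Sigma)]
Sigma = torch.diag(logvar.exp())
general = 0.5 * (mu @ mu - k - torch.logdet(Sigma) + torch.trace(Sigma))

print(torch.allclose(kld_loss, general))  # True
```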
KL Divergence by Sampling
Pending…
References
Mr.Esay’s Blog: KL Divergence between 2 Gaussian Distributions
Kingma, D. P., & Welling, M. (2013). Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114.
Sohn, K., Lee, H., & Yan, X. (2015). Learning structured output representation using deep conditional generative models. Advances in neural information processing systems, 28.
Petersen, K. B., & Pedersen, M. S. (2008). The matrix cookbook. Technical University of Denmark, 7(15), 510.
Kullback, S., & Leibler, R. A. (1951). On information and sufficiency. The annals of mathematical statistics, 22(1), 79-86.
Csiszár, I. (1975). I-divergence geometry of probability distributions and minimization problems. The annals of probability, 146-158.