KL Divergence between Matrix-Variate Normal Distributions

The Matrix-Variate Normal (MVN) Distribution

Matrix-Variate Gaussian/Normal distributions (MVN distributions) were first treated systematically in [1] and generalize the multivariate Gaussian distribution. They formalize the idea of treating an entire matrix X \in \mathbb{R}^{d_h \times d_w} as a random variable.

Instead of describing each entry X_{(i, j)} as an independent Gaussian-distributed scalar, the parametrisation of the MVN distribution allows for incorporating structural ties between the entries of a matrix.

A (real-valued) MVN distribution \mathcal{MVN}_{d_h, d_w}(M, U, V) for random matrices X \in \mathbb{R}^{d_h \times d_w} is characterised by the matrix height d_h, the matrix width d_w, the mean matrix M \in \mathbb{R}^{d_h \times d_w}, the row covariance matrix U \in \mathbb{R}^{d_h \times d_h} and the column covariance matrix V \in \mathbb{R}^{d_w \times d_w}.

It is equivalent to the multivariate normal distribution

(1)   \begin{align*}     \mathcal{N}(\text{vec}(M), V \otimes U) \text{ , } \end{align*}

where \text{vec}(M) \in \mathbb{R}^{d_h d_w} denotes the concatenation of the columns of M, and V \otimes U \in \mathbb{R}^{d_h d_w \times d_h d_w} denotes the Kronecker product of V and U (the distribution’s covariance matrix). Note that the ordering of the factors is tied to the column-stacking convention of vec:

(2)   \begin{align*}     V \otimes U =     \begin{pmatrix}         V_{(1,1)} U & \cdots & V_{(1,d_w)} U \\         \vdots  & \ddots & \vdots  \\         V_{(d_w,1)} U & \cdots & V_{(d_w,d_w)} U     \end{pmatrix} \end{align*}
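This equivalence is easy to check empirically. The sketch below (plain NumPy; the sampling construction X = M + L_U G L_V^T and all variable names are my own) draws matrix-normal samples and compares the empirical covariance of vec(X) with the Kronecker product of the two covariance factors:

```python
import numpy as np

rng = np.random.default_rng(0)
d_h, d_w = 2, 3

# Arbitrary SPD row covariance U, column covariance V, and a mean matrix M.
Ru = rng.normal(size=(d_h, d_h)); U = Ru @ Ru.T + np.eye(d_h)
Rv = rng.normal(size=(d_w, d_w)); V = Rv @ Rv.T + np.eye(d_w)
M = rng.normal(size=(d_h, d_w))

# X = M + L_U G L_V^T has row covariance U and column covariance V,
# where G has i.i.d. standard-normal entries.
L_U, L_V = np.linalg.cholesky(U), np.linalg.cholesky(V)
n = 200_000
G = rng.normal(size=(n, d_h, d_w))
X = M + L_U @ G @ L_V.T

# vec() stacks columns; NumPy flattens row-major by default,
# so we flatten the transpose instead.
vecX = X.transpose(0, 2, 1).reshape(n, d_h * d_w)
emp_cov = np.cov(vecX, rowvar=False)

# Under the column-stacking vec, the factors combine as V (x) U.
kron_cov = np.kron(V, U)
```

With 200,000 samples, the empirical covariance matches the Kronecker product to within sampling noise (well below one percent relative error per entry).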

The KL Divergence between MVN Distributions

The Kullback-Leibler (KL) divergence (also called “relative entropy”) is an asymmetric measure that, to put it bluntly, describes the “difference” between two probability distributions. When using a probability distribution q as a model where the actual distribution is p (both over the domain \mathcal{X}), the KL divergence D_{KL}\left(p(x) \text{ } || \text{ } q(x)\right) quantifies the expected amount of excess “surprise” [2]. For continuous distributions it can be written as follows, referring to the distributions’ density functions p(x) and q(x):

(3)   \begin{align*}     D_{KL}\left(p(x) \text{ } || \text{ } q(x)\right) =     \int\limits_{x \in \mathcal{X}} p(x) \log \left(\frac{p(x)}{q(x)} \right) \, \text{d}x \end{align*}
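To make equation (3) concrete, here is a small numerical sketch (my own toy example, not from the derivation below): integrating the definition for two univariate Gaussians and comparing against the well-known closed form \log(\sigma_q/\sigma_p) + (\sigma_p^2 + (\mu_p - \mu_q)^2)/(2\sigma_q^2) - 1/2:

```python
import numpy as np

# Two univariate Gaussian densities p and q (toy parameters).
mu_p, s_p = 0.0, 1.0
mu_q, s_q = 1.0, 2.0

def normal_pdf(x, mu, s):
    return np.exp(-0.5 * ((x - mu) / s) ** 2) / (s * np.sqrt(2.0 * np.pi))

# Numerically integrate p(x) * log(p(x) / q(x)) on a wide grid (trapezoidal rule).
x = np.linspace(-20.0, 20.0, 400_001)
f = normal_pdf(x, mu_p, s_p) * np.log(normal_pdf(x, mu_p, s_p) / normal_pdf(x, mu_q, s_q))
kl_numeric = np.sum(0.5 * (f[1:] + f[:-1]) * np.diff(x))

# Closed form for D_KL(N(mu_p, s_p^2) || N(mu_q, s_q^2)).
kl_closed = np.log(s_q / s_p) + (s_p**2 + (mu_p - mu_q)**2) / (2.0 * s_q**2) - 0.5
```

The two values agree to several decimal places, which is a handy pattern for sanity-checking the matrix-variate formulas later on.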

The probability density function of an MVN distribution \mathcal{MVN}_{d_h, d_w} (M, U, V) on X \in \mathbb{R}^{d_h \times d_w} is given by

(4)   \begin{align*}     p(X) = \frac{1}{(2\pi)^{\frac{d_h d_w}{2}} |U|^{(\frac{d_w}{2})} |V|^{(\frac{d_h}{2})}} \exp \left(-\frac{1}{2}\text{tr}\left(V^{-1}(X-M)^T U^{-1}(X-M) \right) \right),  \end{align*}

where \text{tr}() denotes the matrix trace and \exp() denotes the exponential function. Correspondingly, the KL divergence between two MVN distributions \mathcal{MVN}_{d, d} (M_q, U_q ,V_q) and \mathcal{MVN}_{d, d} (M_p, U_p ,V_p) over random square matrices Z \in \mathbb{R}^{d \times d} can be derived as follows:

(5)   \begin{align*}     & D_{KL}\left(q(Z) \text{ } || \text{ } p(Z)\right) = \int\limits_{Z \in \mathcal{Z}} \mathcal{MVN}_{d, d} (M_q, U_q ,V_q) \log \left(\frac{\mathcal{MVN}_{d, d} (M_q, U_q, V_q)}{\mathcal{MVN}_{d, d} (M_p, U_p ,V_p)} \right) \\     &  = \mathbb{E}_{Z \sim q(Z)} \left[\log\left(\frac{\frac{1}{(2\pi)^{\frac{d^2}{2}} |U_q|^{(\frac{d}{2})} |V_q|^{(\frac{d}{2})}} \exp \left(-\frac{1}{2}\text{tr}\left(V_q^{-1}(Z-M_q)^T U_q^{-1}(Z-M_q) \right) \right)}{\frac{1}{(2\pi)^{\frac{d^2}{2}} |U_p|^{(\frac{d}{2})} |V_p|^{(\frac{d}{2})}} \exp \left(-\frac{1}{2}\text{tr}\left(V_p^{-1}(Z-M_p)^T U_p^{-1}(Z-M_p) \right) \right)} \right)\right] \\     &  = \frac{1}{2} \mathbb{E}_{Z \sim q(Z)} \bigg[ \log \left(\frac{|U_p|^d |V_p|^d}{|U_q|^d |V_q|^d} \right) - \text{tr}\left(V_q^{-1} (Z-M_q)^T U_q^{-1} (Z-M_q) \right) \nonumber \\     &  \qquad + \text{tr}\left(V_p^{-1} (Z-M_p)^T U_p^{-1} (Z-M_p) \right) \bigg] \\     &  = \frac{1}{2} \bigg(d \cdot \log \left(\frac{|U_p| |V_p|}{|U_q| |V_q|} \right) - \text{tr}\left(V_q^{-1} \mathbb{E}_{Z \sim q(Z)}\left[(Z-M_q)^T U_q^{-1}(Z-M_q) \right] \right) \nonumber \\     &  \qquad + \text{tr}\left(V_p^{-1} \mathbb{E}_{Z \sim q(Z)}\left[(Z-M_p)^T U_p^{-1}(Z-M_p) \right] \right) \bigg) \\     &  = \frac{1}{2} \bigg(d \cdot \log \left(\frac{|U_p| |V_p|}{|U_q| |V_q|} \right) - \text{tr}\Big(V_q^{-1} \Big( \mathbb{E}_{Z \sim q(Z)}\left[Z^T U_q^{-1} Z \right] - M_q^T U_q^{-1} \mathbb{E}_{Z \sim q(Z)}\left[Z \right] \nonumber \\     &  \qquad \qquad - \left(\mathbb{E}_{Z \sim q(Z)}\left[Z\right]\right)^T U_q^{-1} M_q + M_q^T U_q^{-1} M_q \Big) \Big) \nonumber \\     &  \qquad + \text{tr}\Big(V_p^{-1} \Big( \mathbb{E}_{Z \sim q(Z)}\left[Z^T U_p^{-1} Z \right] - M_p^T U_p^{-1} \mathbb{E}_{Z \sim q(Z)}\left[Z \right] \nonumber \\     &  \qquad \qquad - \left(\mathbb{E}_{Z \sim q(Z)}\left[Z\right]\right)^T U_p^{-1} M_p + M_p^T U_p^{-1} M_p \Big) \Big) \bigg)  \end{align*}

Following [3], the expectation terms can be rewritten as follows:

(6)   \begin{align*}     \mathbb{E}_{Z \sim q(Z)}\left[Z^T A Z \right] &= V_q \text{tr}\left(U_q A^T \right) + M_q^T A M_q \text{, and} \\     \mathbb{E}_{Z \sim q(Z)}\left[Z \right] &= M_q \text{,} \end{align*}

where A \in \mathbb{R}^{d \times d} is an arbitrary matrix. Consequently, by further simplifying equation (5), we obtain the following expectation-free expression for the KL divergence:

(7)   \begin{align*}     & D_{KL}\left(q(Z) \text{ } || \text{ } p(Z)\right) = \frac{1}{2} \bigg(d \cdot \log \left(\frac{|U_p| |V_p|}{|U_q| |V_q|} \right) - \text{tr}\left(V_q^{-1} V_q \text{tr} \left( U_q ( U_q^{-1} )^T \right) \right) \nonumber \\     & \qquad + \text{tr} \left(V_p^{-1} V_q \text{tr} \left(U_q (U_p^{-1} )^T \right) \right) \nonumber \\     & \qquad + \text{tr} \left( V_p^{-1} \left(M_q^T U_p^{-1} M_q - M_p^T U_p^{-1} M_q - M_q^T U_p^{-1} M_p + M_p^T U_p^{-1} M_p \right) \right) \bigg)\\     & = \frac{1}{2} \bigg(d \cdot \log \left(\frac{|U_p| |V_p|}{|U_q| |V_q|} \right) - d^2 + \text{tr} \left(V_p^{-1} V_q \right) \text{tr} \left(U_q (U_p^{-1} )^T \right) \nonumber \\     & \qquad + \text{tr} \left( V_p^{-1} (M_q - M_p)^T U_p^{-1} (M_q - M_p)\right) \bigg) \end{align*}
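Equation (7) translates directly into code. Here is a sketch (function and variable names are my own); as a cross-check, the result should agree with the standard closed-form KL divergence between the equivalent multivariate normals with Kronecker-structured covariances:

```python
import numpy as np

def kl_mvn(Mq, Uq, Vq, Mp, Up, Vp):
    """KL(q || p) between two d x d matrix-variate normals, equation (7)."""
    d = Mq.shape[0]
    _, ldUq = np.linalg.slogdet(Uq)
    _, ldVq = np.linalg.slogdet(Vq)
    _, ldUp = np.linalg.slogdet(Up)
    _, ldVp = np.linalg.slogdet(Vp)
    dM = Mq - Mp
    # tr(V_p^{-1} V_q) tr(U_p^{-1} U_q); U_p is symmetric, so (U_p^{-1})^T = U_p^{-1}.
    tr_term = np.trace(np.linalg.solve(Vp, Vq)) * np.trace(np.linalg.solve(Up, Uq))
    # tr(V_p^{-1} (M_q - M_p)^T U_p^{-1} (M_q - M_p))
    mean_term = np.trace(np.linalg.solve(Vp, dM.T) @ np.linalg.solve(Up, dM))
    return 0.5 * (d * (ldUp + ldVp - ldUq - ldVq) - d**2 + tr_term + mean_term)

def kl_gauss(m0, S0, m1, S1):
    """Standard closed-form KL(N(m0, S0) || N(m1, S1)) for comparison."""
    k = m0.size
    _, ld0 = np.linalg.slogdet(S0)
    _, ld1 = np.linalg.slogdet(S1)
    dm = m1 - m0
    return 0.5 * (ld1 - ld0 - k + np.trace(np.linalg.solve(S1, S0))
                  + dm @ np.linalg.solve(S1, dm))

rng = np.random.default_rng(0)
d = 3

def rand_spd(k):
    R = rng.normal(size=(k, k))
    return R @ R.T + np.eye(k)

Uq, Vq, Up, Vp = rand_spd(d), rand_spd(d), rand_spd(d), rand_spd(d)
Mq, Mp = rng.normal(size=(d, d)), rng.normal(size=(d, d))

kl_matrix = kl_mvn(Mq, Uq, Vq, Mp, Up, Vp)
# Same divergence via the vectorized view (column-stacking vec, Kronecker covariance).
kl_vector = kl_gauss(Mq.flatten(order="F"), np.kron(Vq, Uq),
                     Mp.flatten(order="F"), np.kron(Vp, Up))
```

The two routes give the same number up to floating-point error, and the divergence of a distribution with itself comes out as zero, as it should.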

This is really good, because we can now actually calculate the KL divergence between two arbitrary MVN distributions! But what is that good for?

Time-Series Modeling with Matrix-Shaped Latents

In [4], Hafner and his collaborators trained a Reinforcement Learning agent that consults a so-called Recurrent State Space Model (RSSM) to predict the outcomes of possible actions and thus choose the optimal action to take. The RSSM combines recurrence (implemented with Gated Recurrent Units [5]) and stochastic modeling (by means of a Variational Autoencoder (VAE) [6]) to handle time series of environment observations and predict possible future outcomes. The proposed Planning Network (PlaNet) uses this RSSM as its model and achieves remarkable results on a variety of challenging tasks.

In my Master’s thesis, I had a closer look at the RSSM: The environment observations come in the form of images, which contain spatial information by nature. However, the RSSM’s VAE training procedure assumes that the latent variables follow a diagonal multivariate normal distribution, meaning that spatial information between the (encoded) pixels of the input observations could be lost in the process of latent variable encoding.

Therefore, part of my thesis dealt with the re-modeling of the latent space as a vector of independent random square matrices instead of random scalars. In order to train the resulting “spatial” VAE by means of Variational Inference [7], we need to calculate the KL divergence between the model’s posterior and prior distribution. And that’s where the formula for the KL divergence between two MVN distributions came in handy (albeit under a few factorization assumptions, since matrix inversion is expensive).
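To illustrate why such factorization assumptions help (this is a sketch of one possible simplification, not necessarily the exact assumptions used in the thesis): if U and V are assumed diagonal, every determinant, inverse, and trace in equation (7) collapses to a cheap O(d) elementwise operation instead of an O(d³) matrix factorization:

```python
import numpy as np

def kl_mvn(Mq, Uq, Vq, Mp, Up, Vp):
    """General KL(q || p) between two d x d matrix-variate normals, equation (7)."""
    d = Mq.shape[0]
    _, ldUq = np.linalg.slogdet(Uq)
    _, ldVq = np.linalg.slogdet(Vq)
    _, ldUp = np.linalg.slogdet(Up)
    _, ldVp = np.linalg.slogdet(Vp)
    dM = Mq - Mp
    tr_term = np.trace(np.linalg.solve(Vp, Vq)) * np.trace(np.linalg.solve(Up, Uq))
    mean_term = np.trace(np.linalg.solve(Vp, dM.T) @ np.linalg.solve(Up, dM))
    return 0.5 * (d * (ldUp + ldVp - ldUq - ldVq) - d**2 + tr_term + mean_term)

def kl_mvn_diag(Mq, uq, vq, Mp, up, vp):
    """Same KL for U = diag(u), V = diag(v): only vector operations remain."""
    d = Mq.shape[0]
    dM = Mq - Mp
    logdets = d * (np.sum(np.log(up)) + np.sum(np.log(vp))
                   - np.sum(np.log(uq)) - np.sum(np.log(vq)))
    traces = np.sum(vq / vp) * np.sum(uq / up)
    # tr(V_p^{-1} dM^T U_p^{-1} dM) = sum_ij dM_ij^2 / (u_p[i] * v_p[j])
    mahal = np.sum(dM**2 / np.outer(up, vp))
    return 0.5 * (logdets - d**2 + traces + mahal)

rng = np.random.default_rng(3)
d = 4
uq, vq, up, vp = [rng.uniform(0.5, 2.0, size=d) for _ in range(4)]
Mq, Mp = rng.normal(size=(d, d)), rng.normal(size=(d, d))

kl_full = kl_mvn(Mq, np.diag(uq), np.diag(vq), Mp, np.diag(up), np.diag(vp))
kl_fast = kl_mvn_diag(Mq, uq, vq, Mp, up, vp)
```

Both routes give the same value; the diagonal version simply avoids all linear-algebra solves.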

Code/Link to my thesis might appear here soon, but no guarantees…

References

  1. A. K. Gupta and D. K. Nagar: Matrix Variate Distributions. Chapman and Hall/CRC, 1999. ISBN 978-0-203-74928-9. doi: 10.1201/9780203749289
  2. https://en.wikipedia.org/wiki/Kullback-Leibler_divergence (accessed: Jan 08 2022) (yes, I actually cite Wikipedia here…)
  3. J. Li, H. Yan, J. Gao, D. Kong, L. Wang, S. Wang, and B. Yin. Matrix-variate variational auto-encoder with applications to image process. Journal of Visual Communication and Image Representation, 67:102750, Feb. 2020. ISSN 1047-3203. doi: 10.1016/j.jvcir.2019.102750.
  4. D. Hafner, T. P. Lillicrap, I. Fischer, R. Villegas, D. Ha, H. Lee, and J. Davidson. Learning latent dynamics for planning from pixels. ICML 2019, 9-15 June 2019, Long Beach, California, USA, volume 97 of Proceedings of Machine Learning Research, pages 2555–2565. PMLR, 2019.
  5. K. Cho, B. van Merrienboer, Ç. Gülçehre, D. Bahdanau, F. Bougares, H. Schwenk, and Y. Bengio. Learning phrase representations using RNN encoder-decoder for statistical machine translation. In Proceedings of the 2014 Conference on Empirical Methods in Natural Language Processing, EMNLP 2014, October 25-29, 2014, Doha, Qatar
  6. D. P. Kingma and M. Welling. Auto-encoding variational bayes. In 2nd International Conference on Learning Representations, ICLR 2014, Banff, AB, Canada
  7. https://ermongroup.github.io/cs228-notes/inference/variational/ (accessed: Jan 08 2022)

To the best of my knowledge, I’ve cited all relevant non-trivial scientific contributions used in this post. If you spot any errors or false/missing citations please, let me know!
