Categorical Reparameterization with Gumbel-Softmax


The Gumbel Distribution

Notation: \(X\sim\text{Gumbel}(\mu, \beta)\), where \(\mu\in\mathbb{R}\) is the location parameter and \(\beta>0\) is the scale parameter.


\[f_X(x)=\frac{1}{\beta}e^{-(z+e^{-z})}, \text{ where } z=\frac{x-\mu}{\beta}.\]


\[F_X(x)=e^{-e^{-z}}, \text{ where } z=\frac{x-\mu}{\beta}.\]
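Since the CDF has the closed form above, inverting \(F_X(x)=u\) gives \(x=\mu-\beta\log(-\log u)\), so Gumbel samples can be drawn from uniform noise. A minimal NumPy sketch (the function name is illustrative):

```python
import numpy as np

def sample_gumbel(mu=0.0, beta=1.0, size=None, rng=None):
    """Sample from Gumbel(mu, beta) by inverting the CDF F(x) = exp(-exp(-(x-mu)/beta)).

    Solving F(x) = u for u ~ Uniform(0, 1) gives x = mu - beta * log(-log(u)).
    """
    rng = np.random.default_rng() if rng is None else rng
    u = rng.uniform(size=size)
    return mu - beta * np.log(-np.log(u))

# Sanity check: the Gumbel(0, 1) mean is the Euler-Mascheroni constant (~0.5772).
samples = sample_gumbel(size=100_000, rng=np.random.default_rng(0))
print(samples.mean())
```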

See the Wikipedia article on the Gumbel distribution for more details.

The Gumbel-Max Trick

Let \(\pi=(\pi_1,\dots,\pi_k)\) be a \(k\)-dimensional nonnegative vector, where not all elements are zero, and let \(g_1,\dots,g_k\) be \(k\) i.i.d. samples from \(\text{Gumbel}(0,1)\). Then \(\arg\max_i\{g_i+\log\pi_i\}\) is a sample from the categorical distribution with probabilities \(\pi_i/\sum_j\pi_j\).



Proof. Write \(G_i\) for the Gumbel random variables, and let \(I = \arg\max_i\{G_i + \log\pi_i\}\) and \(M = \max_i\{G_i + \log\pi_i\}\).

\[\begin{aligned} \mathbb{P}(I=i)&=\mathbb{P}(G_j + \log\pi_j < G_i + \log\pi_i, \forall j\neq i) \\ & = \int_{-\infty}^\infty f_{G_i}(m-\log\pi_i) \prod_{j\neq i} F_{G_j}(m-\log\pi_j) dm \\ & = \int_{-\infty}^\infty \exp(\log\pi_i-m-\exp(\log\pi_i-m)) \prod_{j\neq i} \exp(-\exp(\log\pi_j-m)) dm \\ & = \int_{-\infty}^\infty \exp(\log\pi_i-m)\exp(-\exp(\log\pi_i-m)) \prod_{j\neq i} \exp(-\exp(\log\pi_j-m)) dm \\ & = \int_{-\infty}^\infty \exp(\log\pi_i-m) \prod_{j} \exp(-\exp(\log\pi_j-m)) dm \\ & = \int_{-\infty}^\infty \exp(-\sum_{j}\exp(\log\pi_j-m)) dm \\ & = \int_{-\infty}^\infty \exp(\log\pi_i)\exp(-m) \exp(-\exp(-m)\sum_{j}\exp(\log\pi_j)) dm \\ & = \int_{-\infty}^\infty \pi_i\exp(-m) \exp(-\exp(-m)\sum_{j}\pi_j) dm \\ & = \int_{0}^\infty \pi_i \exp(-x\sum_{j}\pi_j) dx \quad (\text{substituting } x = e^{-m}) \\ & = \frac{\pi_i}{\sum_j\pi_j} \end{aligned}\]
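The result \(\mathbb{P}(I=i)=\pi_i/\sum_j\pi_j\) can be checked empirically by drawing Gumbel noise, taking the argmax, and comparing selection frequencies against the normalized weights (a minimal NumPy sketch; the weight vector is an arbitrary example):

```python
import numpy as np

rng = np.random.default_rng(0)
pi = np.array([1.0, 2.0, 3.0])   # unnormalized, nonnegative weights
n = 200_000

# Gumbel-Max trick: argmax_i { g_i + log(pi_i) } with g_i ~ Gumbel(0, 1)
g = rng.gumbel(size=(n, len(pi)))
idx = np.argmax(g + np.log(pi), axis=1)

# Empirical selection frequencies should approach pi / pi.sum() = [1/6, 1/3, 1/2]
empirical = np.bincount(idx, minlength=len(pi)) / n
print(empirical)
```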

The Gumbel-Softmax Distribution

Relax the Gumbel-Max trick by replacing the argmax with a softmax (continuous and differentiable) at temperature \(\tau>0\), generating \(k\)-dimensional sample vectors \(y\) on the simplex with components

\[y_i = \frac{\exp((\log(\pi_i)+g_i)/\tau)}{\sum_{j=1}^k\exp((\log(\pi_j)+g_j)/\tau)}.\]


\[f_{Y_1,\dots,Y_k}(y_1,\dots,y_k;\pi,\tau)=\Gamma(k)\tau^{k-1}\left( \sum_{i=1}^k \pi_i/y_i^\tau \right)^{-k}\prod_{i=1}^k(\pi_i/y_i^{\tau+1}).\]
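Drawing one Gumbel-Softmax sample is a one-liner on top of the Gumbel-Max setup; a minimal NumPy sketch (the function name and the max-subtraction for numerical stability are my additions):

```python
import numpy as np

def gumbel_softmax_sample(pi, tau, rng=None):
    """Draw one Gumbel-Softmax sample y for unnormalized weights pi at temperature tau."""
    rng = np.random.default_rng() if rng is None else rng
    g = rng.gumbel(size=len(pi))          # g_i ~ Gumbel(0, 1)
    logits = (np.log(pi) + g) / tau
    logits -= logits.max()                # stabilize exp() without changing the softmax
    e = np.exp(logits)
    return e / e.sum()

y = gumbel_softmax_sample(np.array([1.0, 2.0, 3.0]), tau=0.5,
                          rng=np.random.default_rng(0))
print(y)  # a point on the simplex: nonnegative components summing to 1
```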

The Gumbel-Softmax Estimator

The Gumbel-Softmax distribution is smooth for \(\tau > 0\), and therefore has a well-defined gradient \(\partial y/\partial \pi\) with respect to the parameters \(\pi\). Thus, by replacing categorical samples with Gumbel-Softmax samples we can use backpropagation to compute gradients.

Denote the procedure of replacing non-differentiable categorical samples with a differentiable approximation during training as the Gumbel-Softmax estimator.

There is a tradeoff between small and large temperatures: as \(\tau \to 0\), samples approach one-hot vectors and the Gumbel-Softmax distribution approaches the categorical distribution, but the variance of the gradients grows; as \(\tau \to \infty\), samples become smooth (approaching uniform over the simplex) and gradients have low variance, but the samples are far from categorical.

In practice, training typically starts at a high temperature, which is annealed toward a small positive value.
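The effect of the temperature can be seen by evaluating the same draw of Gumbel noise at two temperatures (a minimal NumPy sketch; the specific weights and temperatures are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
pi = np.array([1.0, 2.0, 3.0])
g = rng.gumbel(size=len(pi))    # fix one draw of Gumbel noise

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

y_cold = softmax((np.log(pi) + g) / 0.1)   # small tau: sharper, close to one-hot
y_hot = softmax((np.log(pi) + g) / 10.0)   # large tau: smoother, close to uniform
print(y_cold, y_hot)
```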

The Straight-Through Gumbel-Softmax Estimator

For scenarios that are constrained to sampling discrete values (e.g., when downstream computation requires a hard one-hot vector), discretize \(y\) to \(z=\text{one\_hot}(\arg\max_i y_i)\) in the forward pass, but approximate the gradient in the backward pass with the continuous relaxation, \(\nabla_\theta z \approx \nabla_\theta y\).

Call this the Straight-Through (ST) Gumbel-Softmax estimator.
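A minimal NumPy sketch of the ST forward pass (NumPy has no autodiff, so only the forward discretization is shown; in an autodiff framework one would return `y_hard + (y - stop_gradient(y))`, whose value is `y_hard` while gradients flow through the soft sample `y`):

```python
import numpy as np

def st_gumbel_softmax(pi, tau, rng=None):
    """Straight-Through Gumbel-Softmax: one-hot output forward, soft sample for gradients."""
    rng = np.random.default_rng() if rng is None else rng
    g = rng.gumbel(size=len(pi))
    logits = (np.log(pi) + g) / tau
    e = np.exp(logits - logits.max())
    y = e / e.sum()                          # soft Gumbel-Softmax sample
    y_hard = np.eye(len(pi))[np.argmax(y)]   # one-hot discretization used forward
    return y_hard, y

y_hard, y = st_gumbel_softmax(np.array([1.0, 2.0, 3.0]), tau=1.0,
                              rng=np.random.default_rng(0))
print(y_hard)  # exact one-hot vector; y is the soft sample used for the backward pass
```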