Notes on "Semi-supervised Learning with Deep Generative Models" (2014)

$$ \newcommand{\bzero}{\mathbf{0}} \newcommand{\ba}{\mathbf{a}}\newcommand{\bA}{\mathbf{A}} \newcommand{\bb}{\mathbf{b}}\newcommand{\bB}{\mathbf{B}} \newcommand{\bc}{\mathbf{c}}\newcommand{\bC}{\mathbf{C}} \newcommand{\bd}{\mathbf{d}}\newcommand{\bD}{\mathbf{D}} \newcommand{\be}{\mathbf{e}}\newcommand{\bE}{\mathbf{E}} \newcommand{\bff}{\mathbf{f}}\newcommand{\bF}{\mathbf{F}} \newcommand{\bg}{\mathbf{g}}\newcommand{\bG}{\mathbf{G}} \newcommand{\bh}{\mathbf{h}}\newcommand{\bH}{\mathbf{H}} \newcommand{\bi}{\mathbf{i}}\newcommand{\bI}{\mathbf{I}} \newcommand{\bj}{\mathbf{j}}\newcommand{\bJ}{\mathbf{J}} \newcommand{\bk}{\mathbf{k}}\newcommand{\bK}{\mathbf{K}} \newcommand{\bl}{\mathbf{l}}\newcommand{\bL}{\mathbf{L}} \newcommand{\bm}{\mathbf{m}}\newcommand{\bM}{\mathbf{M}} \newcommand{\bn}{\mathbf{n}}\newcommand{\bN}{\mathbf{N}} \newcommand{\bo}{\mathbf{o}}\newcommand{\bO}{\mathbf{O}} \newcommand{\bp}{\mathbf{p}}\newcommand{\bP}{\mathbf{P}} \newcommand{\bq}{\mathbf{q}}\newcommand{\bQ}{\mathbf{Q}} \newcommand{\br}{\mathbf{r}}\newcommand{\bR}{\mathbf{R}} \newcommand{\bs}{\mathbf{s}}\newcommand{\bS}{\mathbf{S}} \newcommand{\bt}{\mathbf{t}}\newcommand{\bT}{\mathbf{T}} \newcommand{\bu}{\mathbf{u}}\newcommand{\bU}{\mathbf{U}} \newcommand{\bv}{\mathbf{v}}\newcommand{\bV}{\mathbf{V}} \newcommand{\bw}{\mathbf{w}}\newcommand{\bW}{\mathbf{W}} \newcommand{\bx}{\mathbf{x}}\newcommand{\bX}{\mathbf{X}} \newcommand{\by}{\mathbf{y}}\newcommand{\bY}{\mathbf{Y}} \newcommand{\bz}{\mathbf{z}}\newcommand{\bZ}{\mathbf{Z}} \newcommand{\balpha}{\boldsymbol{\alpha}}\newcommand{\bAlpha}{\boldsymbol{\Alpha}} \newcommand{\bbeta}{\boldsymbol{\beta}}\newcommand{\bBeta}{\boldsymbol{\Beta}} \newcommand{\bgamma}{\boldsymbol{\gamma}}\newcommand{\bGamma}{\boldsymbol{\Gamma}} \newcommand{\bdelta}{\boldsymbol{\delta}}\newcommand{\bDelta}{\boldsymbol{\Delta}} \newcommand{\bepsilon}{\boldsymbol{\epsilon}}\newcommand{\bEpsilon}{\boldsymbol{\Epsilon}} 
\newcommand{\bzeta}{\boldsymbol{\zeta}}\newcommand{\bZeta}{\boldsymbol{\Zeta}} \newcommand{\beeta}{\boldsymbol{\eta}}\newcommand{\bEta}{\boldsymbol{\Eta}} \newcommand{\btheta}{\boldsymbol{\theta}}\newcommand{\bTheta}{\boldsymbol{\Theta}} \newcommand{\biota}{\boldsymbol{\iota}}\newcommand{\bIota}{\boldsymbol{\Iota}} \newcommand{\bkappa}{\boldsymbol{\kappa}}\newcommand{\bKappa}{\boldsymbol{\Kappa}} \newcommand{\blambda}{\boldsymbol{\lambda}}\newcommand{\bLambda}{\boldsymbol{\Lambda}} \newcommand{\bmu}{\boldsymbol{\mu}}\newcommand{\bMu}{\boldsymbol{\Mu}} \newcommand{\bnu}{\boldsymbol{\nu}}\newcommand{\bNu}{\boldsymbol{\Nu}} \newcommand{\bxi}{\boldsymbol{\xi}}\newcommand{\bXi}{\boldsymbol{\Xi}} \newcommand{\bomikron}{\boldsymbol{\omikron}}\newcommand{\bOmikron}{\boldsymbol{\Omikron}} \newcommand{\bpi}{\boldsymbol{\pi}}\newcommand{\bPi}{\boldsymbol{\Pi}} \newcommand{\brho}{\boldsymbol{\rho}}\newcommand{\bRho}{\boldsymbol{\Rho}} \newcommand{\bsigma}{\boldsymbol{\sigma}}\newcommand{\bSigma}{\boldsymbol{\Sigma}} \newcommand{\btau}{\boldsymbol{\tau}}\newcommand{\bTau}{\boldsymbol{\Tau}} \newcommand{\bypsilon}{\boldsymbol{\ypsilon}}\newcommand{\bYpsilon}{\boldsymbol{\Ypsilon}} \newcommand{\bphi}{\boldsymbol{\phi}}\newcommand{\bPhi}{\boldsymbol{\Phi}} \newcommand{\bchi}{\boldsymbol{\chi}}\newcommand{\bChi}{\boldsymbol{\Chi}} \newcommand{\bpsi}{\boldsymbol{\psi}}\newcommand{\bPsi}{\boldsymbol{\Psi}} \newcommand{\bomega}{\boldsymbol{\omega}}\newcommand{\bOmega}{\boldsymbol{\Omega}} \newcommand{\nA}{\mathbb{A}} \newcommand{\nB}{\mathbb{B}} \newcommand{\nC}{\mathbb{C}} \newcommand{\nD}{\mathbb{D}} \newcommand{\nE}{\mathbb{E}} \newcommand{\nF}{\mathbb{F}} \newcommand{\nG}{\mathbb{G}} \newcommand{\nH}{\mathbb{H}} \newcommand{\nI}{\mathbb{I}} \newcommand{\nJ}{\mathbb{J}} \newcommand{\nK}{\mathbb{K}} \newcommand{\nL}{\mathbb{L}} \newcommand{\nM}{\mathbb{M}} \newcommand{\nN}{\mathbb{N}} \newcommand{\nO}{\mathbb{O}} \newcommand{\nP}{\mathbb{P}} % \beeta is used because \beta is already taken
\newcommand{\nQ}{\mathbb{Q}} \newcommand{\nR}{\mathbb{R}} \newcommand{\nS}{\mathbb{S}} \newcommand{\nT}{\mathbb{T}} \newcommand{\nU}{\mathbb{U}} \newcommand{\nV}{\mathbb{V}} \newcommand{\nW}{\mathbb{W}} \newcommand{\nX}{\mathbb{X}} \newcommand{\nY}{\mathbb{Y}} \newcommand{\nZ}{\mathbb{Z}} \newcommand{\cA}{\mathcal{A}} \newcommand{\cB}{\mathcal{B}} \newcommand{\cC}{\mathcal{C}} \newcommand{\cD}{\mathcal{D}} \newcommand{\cE}{\mathcal{E}} \newcommand{\cF}{\mathcal{F}} \newcommand{\cG}{\mathcal{G}} \newcommand{\cH}{\mathcal{H}} \newcommand{\cI}{\mathcal{I}} \newcommand{\cJ}{\mathcal{J}} \newcommand{\cK}{\mathcal{K}} \newcommand{\cL}{\mathcal{L}} \newcommand{\cM}{\mathcal{M}} \newcommand{\cN}{\mathcal{N}} \newcommand{\cO}{\mathcal{O}} \newcommand{\cP}{\mathcal{P}} \newcommand{\cQ}{\mathcal{Q}} \newcommand{\cR}{\mathcal{R}} \newcommand{\cS}{\mathcal{S}} \newcommand{\cT}{\mathcal{T}} \newcommand{\cU}{\mathcal{U}} \newcommand{\cV}{\mathcal{V}} \newcommand{\cW}{\mathcal{W}} \newcommand{\cX}{\mathcal{X}} \newcommand{\cY}{\mathcal{Y}} \newcommand{\cZ}{\mathcal{Z}} \DeclareMathOperator*{\argmax}{argmax~} \DeclareMathOperator*{\argmin}{argmin~} \DeclareMathOperator*{\Tr}{Tr} \DeclareMathOperator*{\Bias}{Bias} \DeclareMathOperator*{\Var}{Var} \newcommand{\Perp}{\perp\!\!\! 
\perp} \let\dsad\d \renewcommand\d{\mathrm{d}} \newcommand{\R}{\mathbb{R}} \newcommand{\N}{\mathbb{N}} \newcommand{\E}{\mathbb{E}} \newcommand{\Eb}[1]{\mathbb{E} \left[ #1 \right]} \newcommand{\F}{\mathcal{F}} \newcommand{\X}{\mathcal{X}} \newcommand{\vocab}[1]{\textbf{\color{blue} #1}} \newcommand{\norm}[1]{\left\lVert#1\right\rVert} \newcommand{\inner}[2]{ \left \langle #1, #2 \right \rangle } \newcommand{\ibra}[1]{\llbracket #1 \rrbracket} \newenvironment{nalign}{ \begin{equation} \begin{aligned} }{ \end{aligned} \end{equation} \ignorespacesafterend } \newcommand{\icol}[1]{ \left(\begin{smallmatrix}#1\end{smallmatrix}\right)% } \newcommand{\cc}[1]{\overline{#1}} \newcommand{\tr}[1]{\text{tr} \left( #1 \right)} \newcommand{\aff}{\text{aff} \,} \newcommand{\ri}{\text{ri} \, } \newcommand{\rb}{\text{rb} \, } \newcommand{\cl}{\text{cl} \, } \newcommand{\conv}{\text{conv} \, } \newcommand{\irow}[1]{ \begin{smallmatrix}(#1)\end{smallmatrix}% } \newcommand\abs[1]{\lvert #1 \rvert} $$

In this post I continue exploring works that apply VAEs by looking into a paper that extends (missing reference) to allow a more flexible graphical model, which in turn makes it possible to learn with fewer labels.


High-Level Overview

There are two main results in the paper, both of which arise naturally when trying to recover a graphical model that incorporates information from supervised data.

The first is the application of the original VAE method to a bigger graphical model with mixed families of distributions. Here, the assumption is that each datapoint is generated from two pieces of information: a continuous and a discrete latent variable. This poses no new problems compared to the continuous-only generation process, but gives greater control when generating samples and allows for discrimination between samples in feature space thanks to the additional discrete variable. As an example, in the original AEVB paper, no label information was used in the MNIST setting. This meant that both the identity of a digit and its appearance had to be captured by a single continuous variable, which limits control and does not reflect the actual generative process of digit formation.

The second is a natural consequence of the first: learning from partially (and sparsely) labeled data. Having decomposed the model into pieces reflecting the generative process, the label of a datapoint is just another random variable. As such, when it is absent, it can be integrated out. This is especially simple because the label is discrete, with only 10 classes in MNIST, so the integral becomes a sum of 10 terms!
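Integrating out a discrete label is just a finite sum, $p(\bx) = \sum_{c} p(y = c)\, p(\bx \mid y = c)$. The sketch below (a minimal illustration, not the paper's code) does this in log space with SciPy's `logsumexp`, assuming the per-class log-likelihoods are already available as an array; in the actual model each of them involves an integral over $\bz$ and is replaced by a lower bound later on. The helper name `log_marginal_over_labels` and the numbers are made up for illustration.

```python
import numpy as np
from scipy.special import logsumexp

def log_marginal_over_labels(log_prior_y, log_lik_given_y):
    """log p(x) = log sum_c p(y=c) p(x | y=c), computed stably in log space.

    log_prior_y: shape (n_y,), log p(y=c) for each class c.
    log_lik_given_y: shape (n_y,), log p(x | y=c) for each class c (assumed given).
    """
    return logsumexp(np.asarray(log_prior_y) + np.asarray(log_lik_given_y))

n_y = 10
log_prior = np.full(n_y, -np.log(n_y))            # uniform prior over the 10 digits
log_lik = np.array([-900.0] * 9 + [-880.0])       # illustrative: class 9 explains x best
print(log_marginal_over_labels(log_prior, log_lik))
```

Working in log space matters here: image log-likelihoods are large negative numbers, and `np.exp(-900.0)` underflows to zero in double precision.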

Defining the Models

There are three models considered in the paper: an unsupervised one, M1, and two semi-supervised ones, M2 and M1+M2.

M1

The M1 model is equivalent to the original VAE (missing reference). It simply discards the labels $y$ of the data and only tries to recover “style” features $\bz$. The generative model reads $p(\bx , \bz) = p(\bz) p(\bx \mid \bz) = \cN(\bz; \bzero, \nI)\, \cN(\bx; \bmu_\bz, \bsigma_\bz^2 \nI)$, where we assume a standard normal prior over $\bz$, and $\bmu_\bz, \bsigma_\bz$ are outputs of a neural network (the decoder) evaluated at $\bz$. The approximate posterior is assumed to be $q(\bz \mid \bx) = \cN(\bz; \bmu_\bx, \bsigma_\bx^2 \nI)$, with $\bmu_\bx, \bsigma_\bx$ given by the encoder network.
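A minimal NumPy sketch of both directions of M1, assuming single linear layers with random weights (`W_mu`, `W_logvar`, `W_dec`, all hypothetical) in place of trained encoder and decoder networks. Posterior samples use the reparameterization $\bz = \bmu_\bx + \bsigma_\bx \odot \bepsilon$ with $\bepsilon \sim \cN(\bzero, \nI)$, as in the original VAE, while generation starts from the prior $p(\bz) = \cN(\bzero, \nI)$.

```python
import numpy as np

rng = np.random.default_rng(0)
x_dim, z_dim = 784, 2

# Hypothetical single linear layers standing in for trained encoder/decoder networks
W_mu, W_logvar = 0.01 * rng.standard_normal((2, z_dim, x_dim))
W_dec = 0.01 * rng.standard_normal((x_dim, z_dim))

def encode(x):
    """Parameters (mu_x, log sigma_x^2) of q(z | x) = N(z; mu_x, sigma_x^2 I)."""
    return W_mu @ x, W_logvar @ x

def decode_mean(z):
    """Mean mu_z of p(x | z) = N(x; mu_z, sigma_z^2 I); works for one z or a batch."""
    return z @ W_dec.T

def sample_posterior(x, n_samples):
    """Reparameterized samples z = mu_x + sigma_x * eps with eps ~ N(0, I)."""
    mu, log_var = encode(x)
    eps = rng.standard_normal((n_samples, z_dim))
    return mu + np.exp(0.5 * log_var) * eps

x = rng.random(x_dim)                          # a fake "image" with pixels in [0, 1)
z = sample_posterior(x, n_samples=5)           # inference: z ~ q(z | x), shape (5, 2)
x_recon = decode_mean(z)                       # reconstruction means, shape (5, 784)
z_prior = rng.standard_normal((5, z_dim))      # generation: z ~ p(z) = N(0, I)
x_gen = decode_mean(z_prior)                   # shape (5, 784)
```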

M2

The next model is M2, a semi-supervised generative model (displayed below). Its joint distribution is defined over the data $\bx$, a “style” $\bz$ and a label $y$, and is parameterized by $\btheta$. In sum, the generative model takes the form

\[\begin{equation} \begin{split} p_\btheta(\bx, \bz, y) &= p(y) p(\bz) p(\bx\mid\bz, y) \\ \end{split} \end{equation}\]

where we assume that the label $y$ follows a categorical distribution over the discrete labels, $y \sim \text{Cat}(\bpi)$, and that $p(\bz) = \cN(\bz; \bzero, \nI)$ as in M1.
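Ancestral sampling from M2 follows the factorization directly: draw $y$ from the categorical prior, draw $\bz$ from the standard Normal independently of $y$, then decode. The sketch assumes a uniform $\bpi$ and a hypothetical linear decoder `W_dec` acting on $[\bz, \text{one-hot}(y)]$, and only returns the mean of $p(\bx \mid \bz, y)$. The second helper fixes a style $\bz$ and sweeps over all labels, which is the kind of control a shared style space gives.

```python
import numpy as np

rng = np.random.default_rng(0)
n_y, z_dim, x_dim = 10, 2, 784
pi = np.full(n_y, 1.0 / n_y)  # assumed uniform class prior

# Hypothetical linear decoder acting on the concatenation [z, one_hot(y)]
W_dec = 0.1 * rng.standard_normal((x_dim, z_dim + n_y))

def sample_m2(n, y=None):
    """Ancestral sampling from p(y) p(z) p(x | z, y); pass an int y to fix the class."""
    if y is None:
        y = rng.choice(n_y, size=n, p=pi)          # y ~ Cat(pi)
    else:
        y = np.full(n, y)
    z = rng.standard_normal((n, z_dim))            # z ~ N(0, I), independent of y
    h = np.concatenate([z, np.eye(n_y)[y]], axis=1)
    x_mean = h @ W_dec.T                           # mean of p(x | z, y)
    return x_mean, z, y

def same_style_all_classes(z):
    """Decode one style vector z with every label: rows share z and differ only in y."""
    h = np.concatenate([np.tile(z, (n_y, 1)), np.eye(n_y)], axis=1)
    return h @ W_dec.T
```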

The inference model $q_\bphi$, on the other hand, is a point of contention. In the paper, the authors state that a factorized distribution is assumed, of the form

\[q(\bz, y\mid\bx) = q(\bz\mid\bx) q(y\mid\bx),\]

but later define a different conditional structure. From looking around on the internet (citation needed), it seems that the actual factorization of the approximate posterior is

\[q(\bz, y\mid\bx) = q(y\mid\bx)\, q(\bz \mid y, \bx).\]

This is consistent with their concrete specification of the model in equation (4) of the paper and makes sense intuitively: the style of a datapoint should depend on its type (in the MNIST case, it makes sense to assume that the digit “1” is usually narrow, so knowing the digit gives information about the style). However, one can also argue that, given a datapoint $\bx$, this conditional structure is not necessary. As such, I will keep the inference model simple by sticking to full factorization; the derivations below keep the general form $q(\bz \mid \bx, y)$, which reduces to $q(\bz \mid \bx)$ under full factorization.
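A sketch of the two candidate inference models, assuming hypothetical linear weights `W_cls` and `W_z` in place of trained networks. One simple way (an assumption, not necessarily how the paper implements it) to switch between $q(\bz \mid \bx, y)$ and the fully factorized $q(\bz \mid \bx)$ with the same network is to feed the one-hot label into the $\bz$-encoder or replace it by zeros.

```python
import numpy as np

rng = np.random.default_rng(0)
n_y, z_dim, x_dim = 10, 2, 784

# Hypothetical linear weights standing in for the classifier and the z-encoder
W_cls = 0.01 * rng.standard_normal((n_y, x_dim))
W_z = 0.01 * rng.standard_normal((2 * z_dim, x_dim + n_y))

def q_y_given_x(x):
    """Classifier q(y | x): softmax over class logits."""
    logits = W_cls @ x
    logits = logits - logits.max()        # subtract the max for numerical stability
    probs = np.exp(logits)
    return probs / probs.sum()

def q_z_params(x, y, condition_on_y=True):
    """Mean and log-variance of q(z | x, y); with condition_on_y=False this is q(z | x)."""
    y_input = np.eye(n_y)[y] if condition_on_y else np.zeros(n_y)
    out = W_z @ np.concatenate([x, y_input])
    return out[:z_dim], out[z_dim:]
```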

While there are other (and IMO more natural) ways to define the generative model (say, letting the style depend on the label), this one is particularly attractive because it assumes that the style and the label are independent. This means that, by assumption, all classes share a common space of styles.

M1 + M2

Finally, the last model (and, as it turns out, the most powerful one) is M1+M2: the stacked generative semi-supervised model. It combines the models defined above: instead of generating the data $\bx$ directly from the hidden variables, it assumes that the label $y$ and a latent $\bz_2$ are produced first, which together give rise to $\bz_1$, which only then generates $\bx$:

\[\begin{equation} \begin{split} p(\bx, y, \bz_2, \bz_1) = p(y) p(\bz_2) p(\bz_1 \mid y, \bz_2) p(\bx\mid\bz_1). \end{split} \end{equation}\]

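In all cases, we assume that the distributions over continuous variables (involving $\btheta$ and $\bphi$) are Gaussians whose parameters are given by neural networks, while the distributions over the label $y$ are categorical.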
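Ancestral sampling from M1+M2, following the paper's stacked model in which $\bz_1$ is generated from both $y$ and $\bz_2$. This is a sketch: the linear maps `W_z1` and `W_x` and the fixed standard deviation `sigma_z1` are hypothetical stand-ins for the M2 and M1 decoder networks, and the class prior is assumed uniform.

```python
import numpy as np

rng = np.random.default_rng(0)
n_y, z2_dim, z1_dim, x_dim = 10, 2, 50, 784

# Hypothetical linear maps standing in for the M2 decoder (y, z2 -> z1) and the M1 decoder (z1 -> x)
W_z1 = 0.1 * rng.standard_normal((z1_dim, z2_dim + n_y))
W_x = 0.1 * rng.standard_normal((x_dim, z1_dim))
sigma_z1 = 0.1  # assumed fixed standard deviation of p(z1 | y, z2)

def sample_stacked(n):
    """Ancestral sampling (y, z2) -> z1 -> x for the M1+M2 model."""
    y = rng.integers(n_y, size=n)                     # y ~ Cat(uniform)
    z2 = rng.standard_normal((n, z2_dim))             # z2 ~ N(0, I)
    h = np.concatenate([z2, np.eye(n_y)[y]], axis=1)
    z1 = h @ W_z1.T + sigma_z1 * rng.standard_normal((n, z1_dim))  # z1 ~ p(z1 | y, z2)
    x_mean = z1 @ W_x.T                               # mean of p(x | z1)
    return x_mean, z1, z2, y
```

As I understand the paper, M1 is trained first, and M2 is then trained on the M1 latent codes $\bz_1$ instead of on the raw pixels.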

Interlude: M1 and the Manifold of MNIST Digits

One interesting illustration in the original VAE (missing reference) paper is the 2D manifold of MNIST digits, where a 2D grid of latent variables is decoded into pixel space. As argued above, the unsupervised VAE mixes the “style” and the class of the digits. I think it’s instructive to reproduce that plot to see this for ourselves.

Not many ingredients are required to carry this out: a trained unsupervised VAE model, and a way to map a square 2D grid onto the standard Normal prior over $\bz$, $p(\bz) = \cN(\bzero, \nI)$, whose density has circular contours and is concentrated near the origin.

Transforming a Uniform over 2D Unit Square to a Standard Normal

The main idea is to divide the 2D prior space into a grid such that each cell has roughly equal probability mass. If the prior were uniform, this would be trivial, as all points would be equally likely. Under the Gaussian prior, however, points close to the center are more likely. Luckily, the standard Normal prior factorizes across dimensions, so we can divide each axis independently while preserving the mass per interval: take equally spaced quantiles and map them through the inverse CDF (`norm.ppf` in SciPy). This leads to denser divisions near zero and sparser ones farther away. In code:

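import numpy as np
from scipy.stats import norm

interval_size = 20  # number of grid points per axis

# Equally spaced quantiles mapped through the inverse CDF of N(0, 1):
# consecutive points enclose equal probability mass along each axis
pnts_z_1d = norm.ppf(np.linspace(1.0 / interval_size, 1.0 - 1.0 / interval_size, interval_size))

# 2D grid of latent points: crossed[i, j] = (pnts_z_1d[i], pnts_z_1d[j])
crossed = np.array([[(x, y) for y in pnts_z_1d] for x in pnts_z_1d])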

To confirm that this works, we can compute the prior mass covered by each 2D rectangle of the grid from the bivariate CDF, using inclusion-exclusion over the four corners of the rectangle:

import numpy as np
from scipy.stats import multivariate_normal

rv = multivariate_normal(mean=[0.0, 0.0])  # standard bivariate Normal prior

block_areas = []
for i in range(1, crossed.shape[0]): # x
    for j in range(1, crossed.shape[1]): # y
        # Bivariate CDF at the corners of the cell [x_{i-1}, x_i] x [y_{j-1}, y_j]
        cdf_lst_left = rv.cdf(crossed[i-1, j])
        cdf_lst_down = rv.cdf(crossed[i, j-1])
        cdf_lst_left_down = rv.cdf(crossed[i-1, j-1])

        cdf_here = rv.cdf(crossed[i, j])

        # Inclusion-exclusion yields the prior mass inside the cell
        this_block_area = cdf_here - cdf_lst_down - cdf_lst_left + cdf_lst_left_down
        block_areas.append(this_block_area)

# Total mass of the interior cells (the strips beyond the outermost grid lines are excluded)
print(np.sum(block_areas))
assert np.max(block_areas) - np.min(block_areas) <= 1e-4, "Grid cells do not carry equal prior mass"

MNIST manifold

Finally, we can display the decoded 2D latent space (Figure 2).

Figure 2. Decoded MNIST 2-dimensional latent space. Note that samples near each other are similar in terms of semantics.
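A minimal sketch of how such a figure can be produced: decode every point of the equal-mass grid from the previous section and tile the resulting images into one canvas. `decode_grid` and the `decode` callable are hypothetical; `decode` should wrap the trained decoder and return a 28x28 image (the MNIST size) for a latent point of shape `(2,)`.

```python
import numpy as np
from scipy.stats import norm

def decode_grid(decode, interval_size=20, img_shape=(28, 28)):
    """Decode an equal-mass grid over the 2D latent space and tile the images.

    decode: hypothetical callable standing in for a trained decoder; it maps a
            latent point of shape (2,) to pixel values of shape img_shape.
    """
    # Same equal-mass grid as above: quantiles mapped through the inverse Normal CDF
    pnts_z_1d = norm.ppf(np.linspace(1.0 / interval_size, 1.0 - 1.0 / interval_size, interval_size))
    h, w = img_shape
    canvas = np.zeros((interval_size * h, interval_size * w))
    for i, z1 in enumerate(pnts_z_1d):        # rows follow the first latent coordinate
        for j, z2 in enumerate(pnts_z_1d):    # columns follow the second one
            img = np.asarray(decode(np.array([z1, z2]))).reshape(h, w)
            canvas[i * h:(i + 1) * h, j * w:(j + 1) * w] = img
    return canvas
```

The canvas can then be displayed with, e.g., matplotlib's `plt.imshow(canvas, cmap="gray")`.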

Samples “far away” from the mean

I found it quite cool to look at samples “far away” from the mean of the Gaussian. To visualize this, I simply decode a square grid over $[-5, 5]^2$ with equidistant spacing. This leads to very sharp, caricature-like reconstructions (see Figure 3), which is cool!

Figure 3. Decoded MNIST 2-dimensional latent space on an "improbable" range $[-5,5]^2$. The latent space was constrained to be close to an isotropic Gaussian, so most of this grid lies in low-probability regions and the plot mostly shows extreme cases.
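To quantify how "improbable" this range is, the sketch below compares the equidistant grid over $[-5, 5]^2$ with the extent of the equal-mass grid used for Figure 2 (assumed to be the 20-points-per-axis grid from the code above). Only 6 of the 20 points per axis, i.e. 9% of the 2D grid, fall inside the square spanned by the equal-mass grid, while the region between the two squares carries only about 19% of the prior mass.

```python
import numpy as np
from scipy.stats import norm

interval_size = 20
far_1d = np.linspace(-5.0, 5.0, interval_size)   # equidistant grid used for Figure 3
# Outermost coordinate of the equal-mass grid from the previous section (about 1.645)
edge = norm.ppf(1.0 - 1.0 / interval_size)

inside_1d = np.abs(far_1d) <= edge
frac_inside_2d = inside_1d.mean() ** 2           # fraction of 2D grid points inside [-edge, edge]^2
# Prior mass of [-5, 5]^2 minus the prior mass of [-edge, edge]^2
mass_outside = (norm.cdf(5) - norm.cdf(-5)) ** 2 - (norm.cdf(edge) - norm.cdf(-edge)) ** 2
print(frac_inside_2d, mass_outside)
```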

Deriving the Objective (M2 Case)

General Variational Principle

Following the variational principle, a lower bound on the log-evidence $\log p(\bx)$ (the ELBO) can be derived, as was done in the Vanilla-VAE paper [1]. For a general generative model with latent variables $\bl$ and data $\bx$, it takes the following form:

\[\begin{equation} \begin{split} \log p(\bx) & \geq \E_{\bl \sim q(\bl|\bx)} \left [ \log p(\bx\mid\bl) - \log q(\bl\mid\bx) + \log p(\bl) \right ] \\ &= \E_{\bl \sim q(\bl|\bx)} \left [ \log p(\bx\mid\bl) \right] - \text{D}_{\text{KL}}(q(\bl|\bx) || p(\bl)) \\ \end{split} \end{equation}\]

so later we can replace $\bl$ with whichever set of variables is hidden/unknown.
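The bound can be checked numerically on a toy model where everything is tractable. The sketch below assumes a one-dimensional linear-Gaussian model, $p(l) = \cN(0, 1)$ and $p(x \mid l) = \cN(l, s^2)$, so that $p(x) = \cN(0, 1 + s^2)$ and the exact posterior is Gaussian; the values of `s2` and `x` are arbitrary. For a Gaussian $q(l \mid x) = \cN(m, v)$ both terms of the ELBO in (3) have closed forms.

```python
import numpy as np

# Toy linear-Gaussian model (an assumption for illustration):
#   p(l) = N(0, 1),  p(x | l) = N(l, s2)  =>  p(x) = N(0, 1 + s2)
s2 = 0.5

def log_evidence(x):
    """Exact log p(x) under the toy model."""
    return -0.5 * np.log(2 * np.pi * (1 + s2)) - x**2 / (2 * (1 + s2))

def elbo(x, m, v):
    """ELBO for q(l | x) = N(m, v), computed in closed form."""
    # E_q[log p(x | l)], using E_q[(x - l)^2] = (x - m)^2 + v
    expected_loglik = -0.5 * np.log(2 * np.pi * s2) - ((x - m) ** 2 + v) / (2 * s2)
    # KL(N(m, v) || N(0, 1))
    kl = 0.5 * (m**2 + v - 1.0 - np.log(v))
    return expected_loglik - kl

x = 1.3
post_m, post_v = x / (1 + s2), s2 / (1 + s2)  # exact posterior p(l | x)
print(log_evidence(x), elbo(x, post_m, post_v), elbo(x, 0.0, 1.0))
```

The gap $\log p(x) - \text{ELBO}$ equals $\text{D}_{\text{KL}}(q(l \mid x) \,||\, p(l \mid x))$, so the bound is tight exactly when $q$ is the true posterior, and any other $m, v$ gives a smaller value.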

Case #1: Labels are Given

Following the general VAE principle, we treat $(\bx, y)$ as the observed data and $\bz$ as the latent variable; substituting these quantities into (3) gives the ELBO:

\[\begin{equation} \begin{split} \cL_{\text{labeled}}(\bx, y) &= \E_{\bz \sim q(\bz|\bx, y)} \left [ \log p(\bx, y\mid\bz) - \log q(\bz\mid\bx, y) + \log p(\bz) \right ] \\ &= \E_{\bz \sim q(\bz|\bx, y)} \left [ \log p(y\mid\bz) + \log p(\bx\mid\bz, y) - \log q(\bz\mid\bx, y) + \log p(\bz) \right ] \\ &= \log p(y) + \E_{\bz \sim q(\bz\mid\bx, y)} \left [\log p(\bx\mid\bz, y) \right ] - \text{D}_{\text{KL}}(q(\bz\mid\bx, y) || p(\bz)) \end{split} \end{equation}\]

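No problems here: the KL term has a closed-form expression, the expectation can be estimated by (reparameterized) sampling, and the prior over $y$ adds no extra cost (at least on the surface) compared to the unsupervised VAE. Note that the last step uses $p(y \mid \bz) = p(y)$, since $y$ and $\bz$ are independent under the generative model.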
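A sketch of a single-datapoint estimator of $\cL_{\text{labeled}}$, assuming a diagonal Gaussian $q(\bz \mid \bx, y)$ parameterized by `mu` and `log_var` (e.g. produced by an encoder network) and a hypothetical `log_lik(x, z, y)` callable that returns $\log p(\bx \mid \bz, y)$. The KL term uses the standard closed form for a diagonal Gaussian against $\cN(\bzero, \nI)$; the reconstruction term is a Monte Carlo average over reparameterized samples.

```python
import numpy as np

def gaussian_kl_std_normal(mu, log_var):
    """Closed-form KL(N(mu, diag(exp(log_var))) || N(0, I)), summed over dimensions."""
    return 0.5 * np.sum(mu**2 + np.exp(log_var) - 1.0 - log_var)

def labeled_elbo(x, y, mu, log_var, log_lik, log_prior_y, n_samples=1, rng=None):
    """Monte Carlo estimate of L_labeled(x, y) for one datapoint.

    mu, log_var: parameters of q(z | x, y).
    log_lik: hypothetical callable returning log p(x | z, y) for one sample z.
    log_prior_y: log p(y) for the given label.
    """
    rng = np.random.default_rng() if rng is None else rng
    std = np.exp(0.5 * log_var)
    # Reparameterization: z = mu + std * eps with eps ~ N(0, I)
    recon = np.mean([log_lik(x, mu + std * rng.standard_normal(mu.shape), y)
                     for _ in range(n_samples)])
    return log_prior_y + recon - gaussian_kl_std_normal(mu, log_var)
```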

Case #2: Labels not Given

It will be useful to recall the law of total expectation: for any joint probability distribution over random variables $x, y$, it holds that

\[\E_{(x, y) \sim p(x) p(y\mid x)}[f(x,y)] = \E_{x \sim p(x)}\left[\E_{y \sim p(y\mid x)} [f(x,y) \mid X = x] \right].\]
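A quick numerical check of this identity on a small discrete joint distribution, with arbitrary random $p(x)$, $p(y \mid x)$ and $f(x, y)$: computing the expectation under the joint directly and as an inner expectation over $y$ followed by an outer one over $x$ gives the same number.

```python
import numpy as np

rng = np.random.default_rng(0)
p_x = rng.dirichlet(np.ones(3))                  # p(x) over 3 values
p_y_given_x = rng.dirichlet(np.ones(4), size=3)  # row k is p(y | x = k) over 4 values
f = rng.normal(size=(3, 4))                      # arbitrary f(x, y)

# Left-hand side: expectation under the joint p(x) p(y | x)
lhs = np.sum(p_x[:, None] * p_y_given_x * f)
# Right-hand side: inner expectation over y for each x, then outer expectation over x
inner = np.sum(p_y_given_x * f, axis=1)
rhs = np.sum(p_x * inner)
print(lhs, rhs)
```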

When the labels $y$ are not given, the set of latents becomes $\bl = (\bz, y)$. Writing down the ELBO and using the identity above together with the assumed factorization $q(\bz, y\mid\bx) = q(y\mid\bx) q(\bz\mid\bx, y)$ yields:

\[\begin{equation} \begin{split} \cL_{\text{unlabeled}}(\bx) &= \E_{\bz, y \sim q(\bz, y|\bx)} \left [ \log p(\bx\mid\bz, y) - \log q(\bz, y\mid\bx) + \log p(\bz, y) \right ] \\ &= \E_{\bz, y \sim q(\bz, y|\bx)} \left [ \log p(\bx\mid\bz, y) - \log q(y\mid\bx) - \log q(\bz \mid\bx, y) + \log p(\bz) + \log p(y) \right ] \\ &= \E_{y \sim q(y \mid\bx)} \left [ \E_{\bz \sim q(\bz\mid\bx, y)} \left [ \log p(\bx\mid\bz, y) - \log q(y\mid\bx) - \log q(\bz\mid\bx, y) + \log p(\bz) + \log p(y) \right ] \right] \\ &= \E_{y \sim q(y \mid\bx)} \left [\E_{\bz \sim q(\bz\mid\bx, y)} \left [ \log p(\bx\mid\bz, y) - \log q(\bz\mid\bx, y) + \log p(\bz) \right ] - \log q(y\mid\bx) + \log p(y) \right] \\ &= \sum_{i=1}^{n_y} q(y = c_i | \bx) \big [\E_{\bz \sim q(\bz\mid\bx, y = c_i)} \left [ \log p(\bx\mid\bz, y = c_i) - \log q(\bz\mid\bx, y = c_i) + \log p(\bz) \right ] \\ & \qquad - \log q(y=c_i\mid\bx) + \log p(y = c_i) \big] \\ &= \sum_{i=1}^{n_y} q(y = c_i | \bx) \big [ \cL_{\text{labeled}}(\bx, y = c_i) - \log q(y = c_i | \bx) \big], \end{split} \end{equation}\]

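where $n_y$ is the number of possible classes. The unlabeled bound is thus a $q(y \mid \bx)$-weighted average of the labeled bounds over all classes, plus the entropy of the classifier, $-\sum_{i=1}^{n_y} q(y = c_i \mid \bx) \log q(y = c_i \mid \bx)$.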
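In code, the last line of the derivation needs the labeled bound for every possible class, so each unlabeled datapoint costs $n_y$ evaluations of $\cL_{\text{labeled}}$ (10 for MNIST). A minimal sketch, assuming those per-class bounds and the classifier probabilities are given as arrays; `unlabeled_elbo` is a hypothetical helper name.

```python
import numpy as np

def unlabeled_elbo(labeled_elbos, q_y):
    """L_unlabeled(x) = sum_i q(y=c_i | x) * [L_labeled(x, y=c_i) - log q(y=c_i | x)].

    labeled_elbos: shape (n_y,), L_labeled(x, y=c_i) for every class c_i.
    q_y: shape (n_y,), classifier probabilities q(y=c_i | x), summing to 1.
    """
    q_y = np.asarray(q_y, dtype=float)
    # Convention 0 * log 0 = 0 for classes the classifier rules out entirely
    safe_log_q = np.log(np.where(q_y > 0, q_y, 1.0))
    return np.sum(q_y * (np.asarray(labeled_elbos) - safe_log_q))
```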

What is $p(y)$ exactly?

$p(y)$ denotes the prior probability of each class occurring in the data. In all the experiments in the paper, the authors split the data across classes equally, but I’m not sure whether they actually assumed this by setting $p(y) = \frac{1}{n_y}$, or made it a learnable parameter. For now, let’s assume that it is known.

Putting It All Together

One curious thing about the evidence lower bound is that in the labeled case (4) the classifier $q(y \mid \bx)$ does not appear at all, so it does not learn from labeled data. An ad-hoc solution is to take the part of the unlabeled loss that depends on the classifier $q(y \mid \bx)$ and inject it into the labeled case, which is exactly what the authors did. Concretely, the augmented labeled objective is defined as

\[\begin{equation} \begin{split} \cL^{\alpha}_{\text{labeled}}(\bx, y) = \cL_{\text{labeled}}(\bx, y) + \alpha \log q(y \mid \bx). \end{split} \end{equation}\]

Because my notation differs slightly from the paper’s, it is sufficient to set $\alpha = 1.0$ to get the same effect as the authors did in the paper.

Alternatively, this additional classification term can be derived by assuming a (symmetric) Dirichlet prior over $y$ (derived here), though this requires a somewhat puzzling change in the inference model.

The classification model $q(y \mid \bx)$ plays structurally different roles in the bounds for the labeled and unlabeled cases. In the unlabeled case, the entropy term $-\sum_{i} q(y = c_i \mid \bx) \log q(y = c_i \mid \bx)$ encourages the classifier to spread its probability mass, i.e. to be as unconfident as possible, counterbalanced only by the weighting of the per-class bounds. In the labeled case the opposite happens: the term $\alpha \log q(y \mid \bx)$ encourages the model to classify correctly.
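Putting the pieces together, the quantity to maximize over a dataset is the sum of $\cL^{\alpha}_{\text{labeled}}$ over labeled points and $\cL_{\text{unlabeled}}$ over unlabeled points. A sketch, assuming the per-point terms have already been computed (e.g. with estimators of the bounds derived above); `semi_supervised_objective` is a hypothetical name, and in practice one would minimize its negative with a stochastic gradient method on minibatches.

```python
import numpy as np

def semi_supervised_objective(labeled_terms, log_q_true, unlabeled_terms, alpha=1.0):
    """Objective to maximize: sum of L^alpha over labeled points plus sum of L_unlabeled.

    labeled_terms: L_labeled(x, y) for each labeled point.
    log_q_true: log q(y | x) of the true label for each labeled point.
    unlabeled_terms: L_unlabeled(x) for each unlabeled point.
    """
    labeled_alpha = np.asarray(labeled_terms) + alpha * np.asarray(log_q_true)
    return np.sum(labeled_alpha) + np.sum(unlabeled_terms)
```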



