Notes on "Auto-Encoding Variational Bayes" (2014) and Implementation in PyTorch
$$ \newcommand{\bzero}{\mathbf{0}} \newcommand{\ba}{\mathbf{a}}\newcommand{\bA}{\mathbf{A}} \newcommand{\bb}{\mathbf{b}}\newcommand{\bB}{\mathbf{B}} \newcommand{\bc}{\mathbf{c}}\newcommand{\bC}{\mathbf{C}} \newcommand{\bd}{\mathbf{d}}\newcommand{\bD}{\mathbf{D}} \newcommand{\be}{\mathbf{e}}\newcommand{\bE}{\mathbf{E}} \newcommand{\bff}{\mathbf{f}}\newcommand{\bF}{\mathbf{F}} \newcommand{\bg}{\mathbf{g}}\newcommand{\bG}{\mathbf{G}} \newcommand{\bh}{\mathbf{h}}\newcommand{\bH}{\mathbf{H}} \newcommand{\bi}{\mathbf{i}}\newcommand{\bI}{\mathbf{I}} \newcommand{\bj}{\mathbf{j}}\newcommand{\bJ}{\mathbf{J}} \newcommand{\bk}{\mathbf{k}}\newcommand{\bK}{\mathbf{K}} \newcommand{\bl}{\mathbf{l}}\newcommand{\bL}{\mathbf{L}} \newcommand{\bm}{\mathbf{m}}\newcommand{\bM}{\mathbf{M}} \newcommand{\bn}{\mathbf{n}}\newcommand{\bN}{\mathbf{N}} \newcommand{\bo}{\mathbf{o}}\newcommand{\bO}{\mathbf{O}} \newcommand{\bp}{\mathbf{p}}\newcommand{\bP}{\mathbf{P}} \newcommand{\bq}{\mathbf{q}}\newcommand{\bQ}{\mathbf{Q}} \newcommand{\br}{\mathbf{r}}\newcommand{\bR}{\mathbf{R}} \newcommand{\bs}{\mathbf{s}}\newcommand{\bS}{\mathbf{S}} \newcommand{\bt}{\mathbf{t}}\newcommand{\bT}{\mathbf{T}} \newcommand{\bu}{\mathbf{u}}\newcommand{\bU}{\mathbf{U}} \newcommand{\bv}{\mathbf{v}}\newcommand{\bV}{\mathbf{V}} \newcommand{\bw}{\mathbf{w}}\newcommand{\bW}{\mathbf{W}} \newcommand{\bx}{\mathbf{x}}\newcommand{\bX}{\mathbf{X}} \newcommand{\by}{\mathbf{y}}\newcommand{\bY}{\mathbf{Y}} \newcommand{\bz}{\mathbf{z}}\newcommand{\bZ}{\mathbf{Z}} \newcommand{\balpha}{\boldsymbol{\alpha}}\newcommand{\bAlpha}{\boldsymbol{\Alpha}} \newcommand{\bbeta}{\boldsymbol{\beta}}\newcommand{\bBeta}{\boldsymbol{\Beta}} \newcommand{\bgamma}{\boldsymbol{\gamma}}\newcommand{\bGamma}{\boldsymbol{\Gamma}} \newcommand{\bdelta}{\boldsymbol{\delta}}\newcommand{\bDelta}{\boldsymbol{\Delta}} \newcommand{\bepsilon}{\boldsymbol{\epsilon}}\newcommand{\bEpsilon}{\boldsymbol{\Epsilon}} 
\newcommand{\bzeta}{\boldsymbol{\zeta}}\newcommand{\bZeta}{\boldsymbol{\Zeta}} \newcommand{\beeta}{\boldsymbol{\eta}}\newcommand{\bEta}{\boldsymbol{\Eta}} % \beta already taken
\newcommand{\btheta}{\boldsymbol{\theta}}\newcommand{\bTheta}{\boldsymbol{\Theta}} \newcommand{\biota}{\boldsymbol{\iota}}\newcommand{\bIota}{\boldsymbol{\Iota}} \newcommand{\bkappa}{\boldsymbol{\kappa}}\newcommand{\bKappa}{\boldsymbol{\Kappa}} \newcommand{\blambda}{\boldsymbol{\lambda}}\newcommand{\bLambda}{\boldsymbol{\Lambda}} \newcommand{\bmu}{\boldsymbol{\mu}}\newcommand{\bMu}{\boldsymbol{\Mu}} \newcommand{\bnu}{\boldsymbol{\nu}}\newcommand{\bNu}{\boldsymbol{\Nu}} \newcommand{\bxi}{\boldsymbol{\xi}}\newcommand{\bXi}{\boldsymbol{\Xi}} \newcommand{\bomikron}{\boldsymbol{\omikron}}\newcommand{\bOmikron}{\boldsymbol{\Omikron}} \newcommand{\bpi}{\boldsymbol{\pi}}\newcommand{\bPi}{\boldsymbol{\Pi}} \newcommand{\brho}{\boldsymbol{\rho}}\newcommand{\bRho}{\boldsymbol{\Rho}} \newcommand{\bsigma}{\boldsymbol{\sigma}}\newcommand{\bSigma}{\boldsymbol{\Sigma}} \newcommand{\btau}{\boldsymbol{\tau}}\newcommand{\bTau}{\boldsymbol{\Tau}} \newcommand{\bypsilon}{\boldsymbol{\ypsilon}}\newcommand{\bYpsilon}{\boldsymbol{\Ypsilon}} \newcommand{\bphi}{\boldsymbol{\phi}}\newcommand{\bPhi}{\boldsymbol{\Phi}} \newcommand{\bchi}{\boldsymbol{\chi}}\newcommand{\bChi}{\boldsymbol{\Chi}} \newcommand{\bpsi}{\boldsymbol{\psi}}\newcommand{\bPsi}{\boldsymbol{\Psi}} \newcommand{\bomega}{\boldsymbol{\omega}}\newcommand{\bOmega}{\boldsymbol{\Omega}} \newcommand{\nA}{\mathbb{A}} \newcommand{\nB}{\mathbb{B}} \newcommand{\nC}{\mathbb{C}} \newcommand{\nD}{\mathbb{D}} \newcommand{\nE}{\mathbb{E}} \newcommand{\nF}{\mathbb{F}} \newcommand{\nG}{\mathbb{G}} \newcommand{\nH}{\mathbb{H}} \newcommand{\nI}{\mathbb{I}} \newcommand{\nJ}{\mathbb{J}} \newcommand{\nK}{\mathbb{K}} \newcommand{\nL}{\mathbb{L}} \newcommand{\nM}{\mathbb{M}} \newcommand{\nN}{\mathbb{N}} \newcommand{\nO}{\mathbb{O}} \newcommand{\nP}{\mathbb{P}} 
\newcommand{\nQ}{\mathbb{Q}} \newcommand{\nR}{\mathbb{R}} \newcommand{\nS}{\mathbb{S}} \newcommand{\nT}{\mathbb{T}} \newcommand{\nU}{\mathbb{U}} \newcommand{\nV}{\mathbb{V}} \newcommand{\nW}{\mathbb{W}} \newcommand{\nX}{\mathbb{X}} \newcommand{\nY}{\mathbb{Y}} \newcommand{\nZ}{\mathbb{Z}} \newcommand{\cA}{\mathcal{A}} \newcommand{\cB}{\mathcal{B}} \newcommand{\cC}{\mathcal{C}} \newcommand{\cD}{\mathcal{D}} \newcommand{\cE}{\mathcal{E}} \newcommand{\cF}{\mathcal{F}} \newcommand{\cG}{\mathcal{G}} \newcommand{\cH}{\mathcal{H}} \newcommand{\cI}{\mathcal{I}} \newcommand{\cJ}{\mathcal{J}} \newcommand{\cK}{\mathcal{K}} \newcommand{\cL}{\mathcal{L}} \newcommand{\cM}{\mathcal{M}} \newcommand{\cN}{\mathcal{N}} \newcommand{\cO}{\mathcal{O}} \newcommand{\cP}{\mathcal{P}} \newcommand{\cQ}{\mathcal{Q}} \newcommand{\cR}{\mathcal{R}} \newcommand{\cS}{\mathcal{S}} \newcommand{\cT}{\mathcal{T}} \newcommand{\cU}{\mathcal{U}} \newcommand{\cV}{\mathcal{V}} \newcommand{\cW}{\mathcal{W}} \newcommand{\cX}{\mathcal{X}} \newcommand{\cY}{\mathcal{Y}} \newcommand{\cZ}{\mathcal{Z}} \DeclareMathOperator*{\argmax}{argmax~} \DeclareMathOperator*{\argmin}{argmin~} \DeclareMathOperator*{\Tr}{Tr} \DeclareMathOperator*{\Bias}{Bias} \DeclareMathOperator*{\Var}{Var} \newcommand{\Perp}{\perp\!\!\! 
\perp} \let\dsad\d \renewcommand\d{\mathrm{d}} \newcommand{\R}{\mathbb{R}} \newcommand{\N}{\mathbb{N}} \newcommand{\E}{\mathbb{E}} \newcommand{\Eb}[1]{\mathbb{E} \left[ #1 \right]} \newcommand{\F}{\mathcal{F}} \newcommand{\X}{\mathcal{X}} \newcommand{\vocab}[1]{\textbf{\color{blue} #1}} \newcommand{\norm}[1]{\left\lVert#1\right\rVert} \newcommand{\inner}[2]{ \left \langle #1, #2 \right \rangle } \newcommand{\ibra}[1]{\llbracket #1 \rrbracket} \newenvironment{nalign}{ \begin{equation} \begin{aligned} }{ \end{aligned} \end{equation} \ignorespacesafterend } \newcommand{\icol}[1]{ \left(\begin{smallmatrix}#1\end{smallmatrix}\right) } \newcommand{\cc}[1]{\overline{#1}} \newcommand{\tr}[1]{\text{tr} \left( #1 \right)} \newcommand{\aff}{\text{aff} \,} \newcommand{\ri}{\text{ri} \, } \newcommand{\rb}{\text{rb} \, } \newcommand{\cl}{\text{cl} \, } \newcommand{\conv}{\text{conv} \, } \newcommand{\irow}[1]{ \begin{smallmatrix}(#1)\end{smallmatrix} } \newcommand\abs[1]{\lvert #1 \rvert} $$
The goal of this post is to digest the famous VAE paper (Kingma & Welling, 2014), to get some intuition as to how such a method could have been discovered, and to understand the motivation behind it.
Motivation
Suppose we have tons of unlabeled data. We might guess that this data is actually the result of some generative process that first generates a set of features (or latents) and then, through some mechanism, constructs the data that we currently hold. For instance, in the visual domain, given a set of features such as (head pose, facial expression), the generated data would have all of these attributes, just in pixel space. Of course, this generative process does not have to be causal; it is just an alternative representation.
We are interested in these latents for a variety of reasons, such as having a compressed representation of the data. But the point is more general: given a generative model, we want to recover the model. In fact, we want to recover 2 models: the forward and the backward one, which we will introduce later.
The forward (or generative, or decoding) model takes in a latent representation and produces the corresponding datapoint. Continuing the example from before, given a description of a person’s face, say, (hair color, nose shape, facial expression), the generative model would produce an image in pixel space. If these three features are all we consider, then the latent space is clearly not discriminative enough to produce unique faces: many different faces share the same description. How do we deal with this? Resigning ourselves to point estimates, we could take the classical autoencoder route and simply do MLE, but we would lose the distribution of faces under the latent variable.
The backward (or inference, or encoding) model converts the data into latent space; it is arguably the model we are most interested in. Continuing the example, in this step we observe an image and want to decompose it into features. There are many features that would capture the pixel space well. In fact, if we wanted to list all of them, we would have to consider observational noise too. Since this is not what we want, we will restrict the dimension of the latents to be smaller than that of the input space[^2]. And once again, as in the forward model, the problem of representation comes in – there might not be a single latent variable that captures the data well. Two types of uncertainty are at play here:
- Inherent observational noise. If there is measurement noise in the data, distinct latents can always yield identical recordings.
- The latent space might not be discriminative enough, especially if we restrict it to have a particular structure (more on this later). Simply put, in the context of our example: what if the mouth of the person happens to be covered (say, due to noise)? Excluding other cues, there is not enough evidence to conclude whether the person is smiling or not, and opting for the statistically most likely answer might be undesirable. However, I think that if this is not the case and there are no ambiguities, this backward step might be omitted. We will experiment with this later.
So that is the motivation: given a general model that is defined through a forward/generative and a backward/inference phase, we want to recover the distributions of both – that is, do variational inference.
*Interlude: How Autoencoders “Fail”, or the Need for Uncertainty
To illustrate the need for a probabilistic treatment in this case, take the example of two Gaussians in 2D. In particular, construct an i.i.d. dataset $\{\bx_i \}$ by sampling from the Gaussians, e.g. $\bx_i \sim \cN\left(\icol{0 \\ 0}, \nI \right)$ for one of the clusters, and construct a classical autoencoder with a 1-dimensional bottleneck.
What should we expect? Since the bottleneck’s dimension is 1, I would expect to see the points projected onto the closest cluster’s mean.
# Train the deterministic autoencoder on the 2-D toy data with a plain MSE loss.
# NOTE(review): `AutoEncoder`, `joint_data` and `tqdm` come from outside this snippet.
autoencoder = AutoEncoder().train()
epochs = 5000
batch_size = 5000
optimizer = optim.Adam(autoencoder.parameters(), lr=0.001)
# torch.from_numpy keeps the NumPy dtype (float64 for default NumPy float arrays),
# while nn.Linear parameters default to float32, so cast explicitly.
torch_data = torch.from_numpy(joint_data.copy()).float()
n_samples = torch_data.shape[0]
# Ceiling division: keep the final partial batch instead of dropping it, and never
# end up with zero steps (and an undefined `loss`) when the dataset is smaller
# than batch_size. Computed once, since the dataset size does not change.
steps = (n_samples + batch_size - 1) // batch_size
loss_step = []  # one MSE value per optimisation step
for epoch in tqdm(range(epochs)):
    # Reshuffle the data at each epoch so every epoch sees fresh random batches.
    torch_data = torch_data[torch.randperm(n_samples)]
    for step in range(steps):
        batch = torch_data[step * batch_size : (step + 1) * batch_size]
        output = autoencoder(batch)["decoded"]
        loss = ((output - batch) ** 2).mean()  # already a scalar
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_step.append(loss.item())
In Figure 2, the cluster points are encoded and then decoded (depicted in red). Clearly, this approach is not suited for data generation.
Formal Problem Statement
We assume there is a generative probabilistic model, from a parametric family of distributions described by the set of parameters $\boldsymbol{\theta}$, in which we are interested; it is described by its joint over the data $\{ \bx_i \}$ and the latent variables $\{\bz_i \}$. We assume the data is i.i.d., so $p_\btheta(X, Z) = \prod_i p_\btheta(\bz_i) p_\btheta(\bx_i |\bz_i)$. Additionally, we assume the families of the prior $p_\btheta(\bz)$ and the likelihood $p_\btheta(\bx | \bz)$ are known (or simply assumed), and we want to infer $\btheta$. The likelihood and prior define the forward step completely. For the inference step, we want to know the posterior $p_\btheta(\bz | \bx)$. This is where the main problem lies. Writing out the posterior
\[\begin{equation} \begin{split} p_\btheta(\bz | \bx) = \frac{p_\btheta(\bz) p_\btheta(\bx | \bz)}{ p_\btheta(\bx)} \end{split} \end{equation}\]we see that the denominator – the marginal distribution of $\bx$ (also called the marginal likelihood or the evidence) – is obtained by marginalizing the latents out of the joint:
\[\begin{equation} \begin{split} p_\btheta(\bx) = \int p_\btheta(\bz) p_\btheta(\bx | \bz) \, d \bz \end{split} \end{equation}\]But this integral has no analytical form for sufficiently complex distributions and dependencies between variables (e.g. when $p_\btheta(\bx | \bz)$ is parameterized by a neural network).
If we could somehow get around this problem of intractability, we could do the inference step.
Objective
Defining the Objective: Going Around the Problem of Intractable $P_\btheta(\bx)$ – Maximizing the Evidence Lower Bound (ELBO)
To get around the problem of computing the exact evidence term, we will resort to an approximation. The key steps to arrive at the final form are:
- Exploit the easy-to-work-with terms. We want to make use of as many of the quantities we know as possible: the prior over $\bz$ and the likelihood;
- Use $\log$ to separate the terms into managable chunks;
- Isolate the intractable terms that remain by introducing a recognition network – the approximate posterior $q_\bphi(\bz | \bx)$ – which lets us group them into a term that is nonnegative;
- Note that $\text{D}_{\text{KL}}$ is always non-negative.
Putting it all together, we arrive at the ELBO term: \(\begin{equation} \begin{split} \quad & p_\btheta(\bx) &&= \frac{p_\btheta(\bz) p_\btheta(\bx | \bz)}{p_\btheta(\bz | \bx)} &&& \quad (\text{for any } \bz) \\ \Rightarrow \quad & \log p_\btheta(\bx) &&= \log p_\btheta(\bz) + \log p_\btheta(\bx | \bz) - \log p_\btheta(\bz | \bx) \\ & && = \E_{\bz \sim q_\bphi(\bz | \bx)}\left[ \log p_\btheta(\bz) + \log p_\btheta(\bx | \bz) - \log p_\btheta(\bz | \bx) \right] &&& \quad (\text{expectation over all } \bz \text{; the left-hand side does not depend on } \bz) \\ & && = \E_{\bz \sim q_\bphi(\bz | \bx)}\left[ \underbrace{\log q_\bphi(\bz | \bx) - \log p_\btheta(\bz |\bx)}_{[\text{up to expectation}] = \text{D}_{\text{KL}}(q_\bphi(\bz | \bx) || p_\btheta(\bz | \bx))} \underbrace{- \log q_\bphi(\bz | \bx) + \log p_\btheta(\bz)}_{[\text{up to expectation}]= -\text{D}_{\text{KL}}(q_\bphi(\bz|\bx) || p_\btheta(\bz)) } + \log p_\btheta(\bx | \bz) \right] &&& \quad (\text{add and subtract } \log q_\bphi(\bz | \bx)) \\ & && \geq \E_{\bz \sim q_\bphi(\bz | \bx)} \left[\log p_\btheta(\bx | \bz) \right] -\text{D}_{\text{KL}} \left(q_\bphi(\bz|\bx) || p_\btheta(\bz) \right) =: \cL(\btheta, \bphi ; \bx) \quad \text{(ELBO)} &&& \quad (\text{since } \text{D}_{\text{KL}} \geq 0) \end{split} \end{equation}\) So what remains are two interpretable terms that should be maximized in the ELBO in order to push the bound up towards the log-evidence $\log p_\btheta(\bx)$:
- The expected log-likelihood term, $\E_{\bz \sim q_\bphi(\bz \mid \bx)}[\log p_\btheta(\bx \mid \bz)]$, taken over $\bz$ (so it is different from the cross-entropy between two probability distributions, because $p_\btheta(\bx \mid \bz)$ is not a probability distribution over $\bz$). Intuitively, we want the recognition network to output latents that yield a high likelihood of the data $\bx$. As an example, if the considered datapoint $\bx$ is an image of a smiling person, the recognition network should output latents that encode this feature, which consequently maximizes the likelihood of reconstructing this face.
- The negative KL divergence, $-D_{\text{KL}} \left ( q_{\bphi} (\bz \mid \bx) \mid \mid p_{\btheta} (\bz ) \right )$. Maximizing it means we want to minimize the “distance” from the approximate posterior network $q_\bphi$ to the prior over the latents. Why?
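I don’t have a good understanding of this. But if $\bz$ are sampled from $p_\btheta(\bz)$, then certainly over a large batch of data $\bx$ we would expect the approximate posterior, averaged over the data, to be close to the prior. Another way to see it: to generate new data we sample $\bz \sim p_\btheta(\bz)$ and decode it, so the decoder has to be trained on latents that look like samples from the prior. The KL term acts as a regularizer that stops the encoder from mapping each datapoint to its own isolated, arbitrarily narrow region of latent space.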
All that is left is to differentiate $\cL(\btheta, \bphi; \bx)$ with respect to $\btheta$ and $\bphi$. Since NNs parameterize $p_\btheta$, the gradient with respect to $\btheta$ is easy (assuming we can calculate the KL divergence in closed form, which we can). Differentiating with respect to $\bphi$ is trickier, because the parameters of the recognition NN sit inside the non-differentiable sampling operator.
Optimizing the objective
The Divergence Term
The KL divergence term between the prior $p_\btheta(\bz)$ and the approximate posterior $q_\bphi(\bz | \bx)$ has a closed-form solution if we assume some structure about the family of the distributions. One convenient family is the Gaussian distributions. As in the paper, we assume the prior is an isotropic standard Gaussian, $p_\btheta(\bz) = \cN(\bzero, \nI)$. We assume the approximate posterior is another Gaussian, with diagonal covariance and parameters from a NN: $q_\bphi(\bz | \bx) = \cN( \bff_{\bz_\bmu} (\bx), \operatorname{diag}(\bff_{\bz_\bsigma}( \bx )^2))$. These choices are convenient, since the KL divergence then has a closed-form solution (derived for the general case here), and they also allow for interpretability. Because the prior over $\bz$ is an independent Gaussian, knowing about one feature reveals no information about other features. This means we assume that, overall, the datapoints generated by the latents are based on independent sampling of features. Additionally, because of the form of the approximate posterior, we also assume that the latents are independent given $\bx$. The KL term has this form:
\[\begin{equation} \begin{split} -\text{D}_{\text{KL}}(q_\bphi(\bz | \bx) \, || \, p_\btheta(\bz)) &= -\text{D}_{\text{KL}}\left(\cN(\bff_{\bz_\bmu}(\bx), \operatorname{diag}(\bff_{\bz_\bsigma}(\bx)^2)) \, || \, \cN(\bzero, \nI)\right) \\ &= \frac12 \sum_{j=1}^{\text{dim } \bz} \left(1 + \log (\bff_{\bz_\bsigma}^{(j)}(\bx)^2) - \bff_{\bz_\bmu}^{(j)}(\bx)^2 - \bff_{\bz_\bsigma}^{(j)}(\bx)^2\right) \end{split} \end{equation}\]Note: when there is no closed-form solution for the divergence term, we can always fall back to sampling. To do so, rewrite the ELBO as the entropy of the approximate posterior plus the expected log-joint of the generative model under the approximate posterior, $\cL = \cH[q_\bphi(\bz | \bx)] + \E_{\bz \sim q_\bphi(\bz | \bx)}[\log p_\btheta(\bx, \bz)]$.
The Log-likelihood Term Under the Approximate Posterior
The trickier part is the log-likelihood expectation under the approximate posterior. Optimizing it with respect to $\btheta$ is easy, but optimizing it with respect to $\bphi$ requires differentiating through a sampling operation. The simplest solution is to use the high-variance but unbiased REINFORCE (score-function) estimator \(\begin{equation} \begin{split} \nabla_\bphi \, \E_{\bz \sim q_\bphi(\bz | \bx)}[\log p_\btheta(\bx | \bz) ] &= \int \nabla_\bphi q_\bphi(\bz | \bx) \log p_\btheta(\bx | \bz) \, \d\bz \\ &= \int \nabla_\bphi \log q_\bphi(\bz | \bx) \, q_\bphi(\bz | \bx) \log p_\btheta(\bx | \bz) \, \d\bz \\ &= \E_{\bz \sim q_\bphi(\bz | \bx)} \left [\nabla_\bphi \log q_\bphi(\bz | \bx) \log p_\btheta(\bx | \bz) \right] \end{split} \end{equation}\)
It is easy to implement and unbiased. However, because of its high variance it needs many samples to give a useful gradient estimate, which makes it computationally unfriendly.
An alternative, lower-variance approach is the reparameterization trick (it still samples, but differentiates through the samples). Suppose once again we are interested in the general objective of the gradient of an expectation \(\begin{equation} \begin{split} \nabla_\xi \E_{z \sim \cD_\xi(z)}[f(z)]. \end{split} \end{equation}\)
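Additionally, assume that the distribution over $z$ with parameter $\xi$ can be described through two ingredients: some other distribution $\cS$ that does not depend on $\xi$, and a differentiable transformation $h_\xi(\cdot)$. Together they must satisfy that transforming a draw from $\cS$ gives a draw from $\cD_\xi$: \(\begin{equation} \begin{split} u \sim \cS \quad \Rightarrow \quad h_\xi(u) \sim \cD_\xi. \end{split} \end{equation}\) (For an affine $h_\xi$, as in the example below, this is equivalent to $\cD_\xi(h_\xi(u)) \propto \cS(u)$ for all $u$.)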
In that case, we can replace the original objective of taking the gradient of an expectation with finding the expected value of a gradient, where the expectation is now over a distribution that is independent of the optimized parameters:
\[\begin{equation} \begin{split} \nabla_\xi \E_{z \sim \cD_\xi(z)}[f(z)] = \E_{\epsilon \sim \cS(\epsilon)}[\nabla_{\xi} f(h_\xi(\epsilon))]. \end{split} \end{equation}\]This is very neat. Previously, once some $\bz$ had been sampled in the forward pass, there was no way to backpropagate through that particular $\bz$, and estimating the gradient required extra sampling (as in REINFORCE). Now, for any particular sample $\epsilon$ (and hence $\bz = h_\xi(\epsilon)$), the gradient $\nabla_\xi f(h_\xi(\epsilon))$ can be computed exactly by ordinary backpropagation. To put it another way: before, if we needed to estimate an expectation in the forward pass, we had to estimate an extra expectation in the backward pass. Now this is unnecessary (I think. It can be done, but I’m unsure how much benefit it gives).
Example: Take a univariate Gaussian $\cD_{\mu, \sigma} = \cN(\mu, \sigma^2)$, then taking $\cS = \cN(0, 1)$ and $h_{\mu, \sigma}(\epsilon) = \epsilon \sigma + \mu$ is an equivalent reparameterization, i.e. $\cN(z\sigma + \mu; \mu, \sigma^2) \propto \cN(z; 0, 1)$.
Thus using our assumptions about the family of distributions being Gaussian, we can write the log-likelihood under the approximate posterior as an MSE term by sampling $\bz$ $n_z$ times:
\[\begin{equation} \begin{split} \E_{\bz \sim q_\bphi(\bz | \bx)}[\log p_\btheta(\bx | \bz)] &= \E_{\bz \sim \cN \left(\bff_{\bz_\bmu}(\bx), \operatorname{diag}(\bff_{\bz_\bsigma}(\bx)^2) \right)}\left [ \log \cN \left(\bx; \bff_{\bx_\bmu}(\bz), \operatorname{diag}(\bff_{\bx_\bsigma} (\bz)^2)\right) \right] \\ & \approx \frac{1}{n_z} \sum_{l = 1}^{n_z} \log \cN \left(\bx; \bff_{\bx_\bmu}(\bz^{(l)}), \operatorname{diag}(\bff_{\bx_\bsigma} (\bz^{(l)})^2)\right), \qquad \bz^{(l)} \sim q_\bphi(\bz | \bx) \\ &= \frac{1}{n_z} \sum_{l = 1}^{n_z} \left[ \log \frac{1}{(2\pi)^{\text{dim } \bx / 2}} - \frac{1}{2} \sum_{j=1}^{\text{dim } \bx} \log \left(\bff_{\bx_\bsigma}^{(j)}(\bz^{(l)}) \right)^2 -\frac12 (\bx - \bff_{\bx_\bmu}(\bz^{(l)}))^\intercal \operatorname{diag}(\bff_{\bx_\bsigma}(\bz^{(l)})^2)^{-1} (\bx - \bff_{\bx_\bmu}(\bz^{(l)})) \right] \\ &= - \frac{\text{dim }\bx }{2} \log (2\pi) - \frac{1}{n_z} \sum_{l = 1}^{n_z} \sum_{j=1}^{\text{dim } \bx} \left [ \log \abs{\bff_{\bx_\bsigma}^{(j)}(\bz^{(l)})} + \frac{\left(\bx^{(j)} - \bff_{\bx_\bmu}^{(j)}(\bz^{(l)}) \right)^2}{ 2 \left(\bff_{\bx_\bsigma}^{(j)}(\bz^{(l)})\right)^2} \right] \end{split} \end{equation}\]where we use the fact that the determinant of a diagonal matrix is the product of its diagonal entries (so the log-determinant is the sum of the log-variances). Note that we do not need to sample $\bx$ from the decoder; all we need to do is sample $\bz$ and construct the distribution $p_\btheta(\bx | \bz)$.
Since this expectation should be maximized in the ELBO, the bracketed (non-constant) negative log-likelihood term should be minimized. The reconstruction error is weighted by the inverse of the variance in each dimension of $\bx$. The quadratic term can be reduced both by decreasing the reconstruction error and by increasing the variance. However, the log-standard-deviation term penalizes high variance, so there is a tradeoff between high-variance solutions and highly accurate means. For a fixed residual $r_j = \bx^{(j)} - \bff_{\bx_\bmu}^{(j)}(\bz)$, the bracketed term is minimized at $(\bff_{\bx_\bsigma}^{(j)}(\bz))^2 = r_j^2$. Hence the log-likelihood is maximized when the prediction is peaked at the mean, with the variance approaching zero as the reconstruction error does.
Implementation
Assumptions:
- The prior over features is a Gaussian $p_\btheta(\bz) = \cN(\bzero, \nI)$;
- The likelihood is a Gaussian with parameters from a NN (the decoder): $p_\btheta(\bx | \bz) = \cN(\bd_\mu(\bz), \bd_\sigma(\bz)^2)$;
- The approximate posterior (recognition model) is also a Gaussian with parameters from a NN (the encoder): $q_\bphi(\bz | \bx) = \cN(\be_\mu(\bx), \be_\sigma(\bx)^2)$.
The procedure is then as follows:
- For a batch of samples $\cB$, for each $\bx \in \cB$:
  - **Encoding** – for each of the $n_z$ samplings:
    - Encode the sample deterministically by taking the mean and standard deviation from the NN: $\bmu_\bz = \be_\mu(\bx), \bsigma_\bz = \be_\sigma(\bx)$;
    - Sample a latent from the Gaussian with the parameters from the previous step, using the reparameterization trick: $\bepsilon \sim \cN(\bzero, \nI), \bz = \bsigma_\bz \odot \bepsilon + \bmu_\bz$;
  - **Decoding** – for each of the $n_x$ samplings:
    - Take the sample $\bz$ from the encoding step and decode it deterministically by taking the mean and standard deviation from the NN: $\bmu_\bx = \bd_\mu(\bz), \bsigma_\bx = \bd_\sigma(\bz)$;
- Compute the loss for the batch (and all the samplings):
The code is a bit messy due to the option of sampling a variable number of times in the forward and backward passes.
Note: To keep the variance positive, we use the common log-exp trick: the network predicts the log-variance, which we then exponentiate where needed. This adds stability, because the scale of the network outputs does not get out of proportion when large variances need to be predicted.
Anyway, a simple one-layer VAE is as follows
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
class VariationalAutoEncoder(nn.Module):
    """One-hidden-layer VAE with diagonal-Gaussian encoder and decoder.

    Encoder (approximate posterior): q(z|x) = N(mu_z(x), diag(sigma_z(x)^2)).
    Decoder (likelihood):            p(x|z) = N(mu_x(z), diag(sigma_x(z)^2)).
    Each Gaussian head outputs interleaved (mean, log-variance) pairs; the
    log-variance is exponentiated so standard deviations are strictly positive.

    Args:
        input_size: dimensionality of a datapoint x.
        hidden_size: width of the single hidden layer in encoder and decoder.
        bottleneck_size: dimensionality of the latent z.
        n_sample_z: number of latent samples drawn per datapoint in forward().
    """

    def __init__(self, input_size=2, hidden_size=10, bottleneck_size=1, n_sample_z=3):
        super().__init__()
        # Encoder: x -> hidden -> (mean, log-variance) per latent dimension.
        self.input_hidden = nn.Linear(input_size, hidden_size)
        self.hidden_bottleneck = nn.Linear(hidden_size, 2 * bottleneck_size)
        # Decoder: z -> hidden -> (mean, log-variance) per input dimension.
        self.bottleneck_hidden = nn.Linear(bottleneck_size, hidden_size)
        self.hidden_output = nn.Linear(hidden_size, input_size * 2)
        self.bottleneck_size = bottleneck_size
        self.input_size = input_size
        self.n_sample_z = n_sample_z

    @staticmethod
    def _mean_and_std(params):
        """Split a (..., 2) tensor of (mean, log-variance) pairs into (mean, std)."""
        mu, log_var = params[..., 0], params[..., 1]
        # exp(0.5 * log_var) is always > 0 (unlike abs()), so log(sigma) in the
        # ELBO stays finite and the gradient has no kink at zero.
        return mu, torch.exp(0.5 * log_var)

    def forward(self, x):
        """Encode x, sample latents with the reparameterization trick, decode.

        Args:
            x: batch of shape (B, input_size).

        Returns:
            dict with
              "encoded": sampled latents z, (B, n_sample_z, bottleneck_size)
              "mu_z", "sigma_z": posterior parameters, (B, 1, bottleneck_size)
              "decoded": one sample of x per latent sample, (B, n_sample_z, input_size)
              "mu_x", "sigma_x": likelihood parameters, (B, n_sample_z, input_size)
        """
        B = x.shape[0]
        hidden = F.relu(self.input_hidden(x))
        # Approximate posterior parameters; the singleton axis broadcasts over samples.
        q_params = self.hidden_bottleneck(hidden).view(B, 1, self.bottleneck_size, 2)
        mu_z, sigma_z = self._mean_and_std(q_params)

        # Reparameterization: z = mu + sigma * eps, eps ~ N(0, I). With a diagonal
        # covariance these are B * n_sample_z * bottleneck_size univariate draws.
        # Draw the noise on the input's device/dtype so GPU or float64 inputs work.
        eps_z = torch.randn(
            B, self.n_sample_z, self.bottleneck_size, device=x.device, dtype=mu_z.dtype
        )
        sampled_z = eps_z * sigma_z + mu_z

        # Decode all B * n_sample_z latents in one batched pass.
        flat_z = sampled_z.reshape(B * self.n_sample_z, self.bottleneck_size)
        hidden = F.relu(self.bottleneck_hidden(flat_z))
        p_params = self.hidden_output(hidden).view(B, self.n_sample_z, self.input_size, 2)
        mu_x, sigma_x = self._mean_and_std(p_params)

        # One reparameterized sample of x per latent sample.
        sampled_x = torch.randn_like(mu_x) * sigma_x + mu_x
        return {
            "encoded": sampled_z,
            "mu_z": mu_z,
            "sigma_z": sigma_z,
            "decoded": sampled_x,
            "mu_x": mu_x,
            "sigma_x": sigma_x,
        }
The ELBO calculation is implemented as follows:
def get_ELBO(self, input_x, forward_pass_output):
    """
    Returns the evidence lower bound E_q(z|x)[log p(x|z)] - D_KL(q||p)
    assuming Gaussians both in forward and backward models, per datapoint.

    Args:
        self: object providing ``bottleneck_size`` and ``input_size``
            (the VAE instance when used as a method).
        input_x: data batch of shape (B, input_size).
        forward_pass_output: dict produced by the VAE forward pass. Reads
            "mu_z"/"sigma_z" (B * bottleneck_size elements each) and
            "mu_x"/"sigma_x" of shape (B, n_sample_z, input_size).

    Returns:
        Tensor of shape (B,) holding the ELBO of each datapoint; the expected
        log-likelihood is a Monte-Carlo average over the n_sample_z latent
        samples. Training should maximize it (minimize the negative mean).
    """
    import math

    B = input_x.shape[0]
    # reshape rather than view: "mu_z" may be a non-contiguous slice.
    mu_z = forward_pass_output["mu_z"].reshape(B, self.bottleneck_size)
    sigma_z = forward_pass_output["sigma_z"].reshape(B, self.bottleneck_size)
    mu_x = forward_pass_output["mu_x"]
    sigma_x = forward_pass_output["sigma_x"]

    # -KL(N(mu_z, diag(sigma_z^2)) || N(0, I)): closed form for diagonal covariances -> (B,)
    negative_d_kl_term = 0.5 * torch.sum(
        1.0 + torch.log(sigma_z ** 2) - mu_z ** 2 - sigma_z ** 2, dim=1
    )

    # log N(x; mu_x, diag(sigma_x^2)) for every latent sample -> (B, n_sample_z)
    const_term = -0.5 * self.input_size * math.log(2.0 * math.pi)
    log_sigma_term = -torch.sum(torch.log(sigma_x), dim=2)
    # Squared error weighted by the inverse variance: (x - mu)^2 / (2 sigma^2).
    quadratic_term = -torch.sum(
        (input_x.reshape(B, 1, -1) - mu_x) ** 2 / (2.0 * sigma_x ** 2), dim=2
    )
    # Monte-Carlo average over the latent samples -> (B,)
    likelihood_under_appx_posterior = const_term + torch.mean(
        log_sigma_term + quadratic_term, dim=1
    )

    # const_term is already part of the likelihood; add the KL term exactly once.
    return negative_d_kl_term + likelihood_under_appx_posterior
Result
We replicate the setting that was used for the autoencoder with a 1-dimensional bottleneck, this time with a VAE. Even though the dimensionality does not permit encoding the input information perfectly, the network still manages to produce realistic-looking results (see Figure 3).
Enjoy Reading This Article?
Here are some more articles you might like to read next: