What is the Wake-Sleep Algorithm 🌱

Helmholtz Machine


We have two networks:

  1. Recognition network with weights \(\phi\) maps input data \(\mathbf{x}\) to latent representations \(\mathbf{z}\) in its hidden states.
  2. Generative network with weights \(\theta\) reconstructs the data from the latent states.

The Helmholtz machine tries to learn \(\phi\) and \(\theta\) such that \(q_{\phi}(\mathbf{z} \mid \mathbf{x}) \approx p_{\theta}(\mathbf{z} \mid \mathbf{x}) \propto p_{\theta}(\mathbf{z},\mathbf{x})\).

  • \(q_{\phi}(\mathbf{z} \mid \mathbf{x})\) is the variational distribution approximating the posterior \(p_{\theta}(\mathbf{z}\mid\mathbf{x})\)
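To make the two-network setup concrete, here is a minimal sketch in numpy. The architecture (a single linear layer per network) and the toy dimensions are assumptions for illustration, not from the note; any recognition network outputting \(\mu_{\phi}, \Sigma_{\phi}\) and any generator outputting Bernoulli means would fit the same template.

```python
import numpy as np

rng = np.random.default_rng(42)
D, K = 5, 2  # observed dim, latent dim (toy sizes, assumed)

# Recognition network (weights phi): x -> parameters of q_phi(z | x),
# here the mean and log-variance of a diagonal Gaussian.
W_mu, W_logvar = rng.normal(size=(K, D)), rng.normal(size=(K, D)) * 0.1
def recognize(x):
    return W_mu @ x, W_logvar @ x

# Generative network (weights theta): z -> Bernoulli means f_theta(z)
# for p_theta(x | z), via a sigmoid so probabilities stay in (0, 1).
W_gen = rng.normal(size=(D, K)) * 0.1
def generate(z):
    return 1.0 / (1.0 + np.exp(-(W_gen @ z)))

# Round trip: encode a binary datapoint, sample a latent, decode it.
x = rng.integers(0, 2, size=D).astype(float)
mu, logvar = recognize(x)
z = mu + np.exp(0.5 * logvar) * rng.normal(size=K)  # sample from q_phi(z | x)
x_probs = generate(z)                               # Bernoulli means for x
```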

Wake Phase

  1. Feed \(\mathbf{x}^{(i)}\) into the recognition network to get \(\mu_{\phi}\left(\mathbf{x}^{(i)}\right)\) and \(\Sigma_{\phi}\left(\mathbf{x}^{(i)}\right)\)
  2. Draw \(L\) samples \(\mathbf{z}_{1}^{(i)}, \ldots, \mathbf{z}_{L}^{(i)} \sim q_{\phi}\left(\mathbf{z} \mid \mathbf{x}^{(i)}\right)=N\left(\mathbf{z} ; \mu_{\phi}\left(\mathbf{x}^{(i)}\right), \Sigma_{\phi}\left(\mathbf{x}^{(i)}\right)\right)\)
  3. For each \(l \in[L],\) feed \(\mathbf{z}_{l}^{(i)}\) into the generative network to get \(f_{\theta}\left(\mathbf{z}_{l}^{(i)}\right)\) for the likelihood \(p_{\theta}\left(\mathbf{x} \mid \mathbf{z}_{l}^{(i)}\right)=\text{Bernoulli}\left(\mathbf{x} ; f_{\theta}\left(\mathbf{z}_{l}^{(i)}\right)\right)\)
  4. Optimize for \(\max _{\theta} \sum_{i=1}^{N} \frac{1}{L} \sum_{l=1}^{L} \log p_{\theta}\left(\mathbf{x}^{(i)} \mid \mathbf{z}_{l}^{(i)}\right)\)
  • Simulate latent states by feeding input data through the recognition network, then maximize how well the generator's Bernoulli probabilities for those latents fit the actual data.
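The wake-phase steps above can be sketched for a single datapoint \(\mathbf{x}^{(i)}\). The linear recognition and generative layers are illustrative assumptions; the Monte Carlo objective in step 4 is what the note prescribes.

```python
import numpy as np

rng = np.random.default_rng(0)
D, K, L = 5, 2, 10  # data dim, latent dim, samples per datapoint (assumed)
x = rng.integers(0, 2, size=D).astype(float)  # one binary datapoint x^(i)

# Step 1: recognition network phi gives mu_phi(x) and (log-)variance.
W_mu, W_logvar = rng.normal(size=(K, D)), rng.normal(size=(K, D)) * 0.1
mu, logvar = W_mu @ x, W_logvar @ x

# Step 2: draw L samples z_l ~ N(mu, diag(exp(logvar))).
z = mu + np.exp(0.5 * logvar) * rng.normal(size=(L, K))

# Step 3: generative network theta gives Bernoulli means f_theta(z_l).
W_gen = rng.normal(size=(D, K)) * 0.1
probs = 1.0 / (1.0 + np.exp(-(z @ W_gen.T)))  # shape (L, D)

# Step 4: Monte Carlo average of log p_theta(x | z_l); in training this
# is ascended with respect to theta (e.g. by gradient ascent).
log_lik = (x * np.log(probs) + (1 - x) * np.log(1 - probs)).sum(axis=1)
objective = log_lik.mean()
```

Note the latents are sampled from \(q_{\phi}\) but only \(\theta\) is updated in this phase.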

Sleep Phase

  1. Draw \(\mathbf{z}^{l} \sim N(0, I)\)
  2. Sample \(\mathbf{x}^{l}\) from the generative network \(p_{\theta}\left(\mathbf{x} \mid \mathbf{z}^{l}\right)=\text{Bernoulli}\left(\mathbf{x} ; f_{\theta}\left(\mathbf{z}^{l}\right)\right)\)
  3. Feed \(\mathbf{x}^{l}\) into the recognition network to get \(\mu_{\phi}\left(\mathbf{x}^{l}\right)\) and \(\Sigma_{\phi}\left(\mathbf{x}^{l}\right)\)
  4. Compute \(q_{\phi}\left(\mathbf{z}^{l} \mid \mathbf{x}^{l}\right)=N\left(\mathbf{z}^{l} ; \mu_{\phi}\left(\mathbf{x}^{l}\right), \Sigma_{\phi}\left(\mathbf{x}^{l}\right)\right)\)
  5. Optimize \(\max _{\phi} \frac{1}{L} \sum_{l=1}^{L} \log q_{\phi}\left(\mathbf{z}^{l} \mid \mathbf{x}^{l}\right)\)
  • Simulate random \(\mathbf{x}\) data by following the generator. Then maximize the probability that the recognition network suggests the correct latent states given the simulated \(\mathbf{x}\).
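A matching sketch of the sleep phase, under the same assumed linear-layer networks and toy dimensions as before; step 5's log-density uses the standard diagonal-Gaussian formula.

```python
import numpy as np

rng = np.random.default_rng(1)
D, K, L = 5, 2, 10  # data dim, latent dim, number of dreams (assumed)

# Step 1: dream latents z^l ~ N(0, I).
z = rng.normal(size=(L, K))

# Step 2: sample binary "dream" data x^l from the Bernoulli generator.
W_gen = rng.normal(size=(D, K)) * 0.1
probs = 1.0 / (1.0 + np.exp(-(z @ W_gen.T)))
x = (rng.random(size=(L, D)) < probs).astype(float)

# Steps 3-4: recognition network maps each x^l to mu and log-variance.
W_mu, W_logvar = rng.normal(size=(K, D)), rng.normal(size=(K, D)) * 0.1
mu, logvar = x @ W_mu.T, x @ W_logvar.T

# Step 5: log q_phi(z^l | x^l) under a diagonal Gaussian, averaged over
# the L dreams; in training this is ascended with respect to phi.
log_q = -0.5 * (logvar + (z - mu) ** 2 / np.exp(logvar)
                + np.log(2 * np.pi)).sum(axis=1)
objective = log_q.mean()
```

Here the data are dreamed from \(p_{\theta}\) but only \(\phi\) is updated, mirroring how the wake phase holds \(\phi\) fixed while updating \(\theta\).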
