What is the Wake-Sleep Algorithm?
Helmholtz Machine
We have two networks:
- Recognition network with weights \(\phi\) maps input data \(\mathbf{x}\) to latent representations \(\mathbf{z}\) in the successive hidden layers.
- Generative network reconstructs data from the latent states using weights \(\theta\).
The Helmholtz machine tries to learn \(\phi\) and \(\theta\) such that \(q_{\phi}(\mathbf{z} \mid \mathbf{x}) \approx p_{\theta}(\mathbf{z} \mid \mathbf{x}) \propto p_{\theta}(\mathbf{z},\mathbf{x})\)
- \(q_{\phi}(\mathbf{z} \mid \mathbf{x})\) is the variational distribution approximating the posterior \(p_{\theta}(\mathbf{z}\mid\mathbf{x})\)
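The two networks above can be sketched in a few lines of NumPy. This is a minimal illustration, not the note's implementation: the networks are assumed to be single linear layers, the dimensions are arbitrary, and \(q_{\phi}\) is assumed to have a diagonal covariance.

```python
import numpy as np

rng = np.random.default_rng(0)
x_dim, z_dim = 4, 2  # illustrative sizes, not from the note

# Hypothetical linear recognition network: weights phi map x to the
# mean and (diagonal) variance of q_phi(z | x).
W_mu = rng.normal(size=(z_dim, x_dim))
W_logvar = rng.normal(size=(z_dim, x_dim))

def recognition(x):
    """q_phi(z | x) = N(z; mu_phi(x), diag(sigma_phi(x)^2))."""
    return W_mu @ x, np.exp(W_logvar @ x)  # mean, diagonal variance

# Hypothetical linear generative network: weights theta map z to
# Bernoulli success probabilities f_theta(z) for each pixel of x.
W_gen = rng.normal(size=(x_dim, z_dim))

def generative(z):
    """p_theta(x | z) = Bernoulli(x; f_theta(z)), elementwise sigmoid."""
    return 1.0 / (1.0 + np.exp(-(W_gen @ z)))

x = rng.integers(0, 2, size=x_dim).astype(float)  # a toy binary datapoint
mu, var = recognition(x)        # encode x into a distribution over z
probs = generative(mu)          # reconstruct x from the posterior mean
```

The sigmoid output of the generative network is what makes the Bernoulli likelihood well defined: each entry of \(f_{\theta}(\mathbf{z})\) is a probability in \((0,1)\).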
Wake Phase
- Feed \(\mathbf{x}^{(i)}\) into the recognition network to get \(\mu_{\phi}\left(\mathbf{x}^{(i)}\right)\) and \(\Sigma_{\phi}\left(\mathbf{x}^{(i)}\right)\)
- Draw \(L\) samples \(\mathbf{z}_{1}^{(i)}, \ldots, \mathbf{z}_{L}^{(i)} \sim q_{\phi}\left(\mathbf{z} \mid \mathbf{x}^{(i)}\right)=N\left(\mathbf{z} ; \mu_{\phi}\left(\mathbf{x}^{(i)}\right), \Sigma_{\phi}\left(\mathbf{x}^{(i)}\right)\right)\)
- For each \(l \in [L]\), feed \(\mathbf{z}_{l}^{(i)}\) into the generative network to get \(f_{\theta}\left(\mathbf{z}_{l}^{(i)}\right)\) for the likelihood \(p_{\theta}\left(\mathbf{x} \mid \mathbf{z}_{l}^{(i)}\right)=\text{Bernoulli}\left(\mathbf{x} ; f_{\theta}\left(\mathbf{z}_{l}^{(i)}\right)\right)\)
- Optimize for \(\max _{\theta} \sum_{i=1}^{N} \frac{1}{L} \sum_{l=1}^{L} \log p_{\theta}\left(\mathbf{x}^{(i)} \mid \mathbf{z}_{l}^{(i)}\right)\)
- Intuition: simulate latent states by feeding input data to the recognition network, then maximize how well the generator's probabilities for those latent states fit the actual data.
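The wake-phase steps above can be sketched as follows for a single datapoint. This is a minimal NumPy illustration under assumed linear networks and a fixed diagonal covariance; it only evaluates the Monte Carlo objective rather than running a full optimizer over \(\theta\).

```python
import numpy as np

rng = np.random.default_rng(1)
x_dim, z_dim, L = 4, 2, 5  # illustrative sizes and sample count

# Hypothetical recognition output for one datapoint x^{(i)}; in practice
# mu_phi and Sigma_phi come from the recognition network.
mu_phi = rng.normal(size=z_dim)
sigma_phi = np.ones(z_dim)               # diagonal covariance for simplicity

W_gen = rng.normal(size=(x_dim, z_dim))  # generative weights theta
x_i = rng.integers(0, 2, size=x_dim).astype(float)

def f_theta(z):
    """Bernoulli means f_theta(z) via an elementwise sigmoid."""
    return 1.0 / (1.0 + np.exp(-(W_gen @ z)))

# Wake phase: draw L samples z_l ~ q_phi(z | x^{(i)}) = N(mu_phi, Sigma_phi)
z_samples = mu_phi + sigma_phi * rng.normal(size=(L, z_dim))

def log_bernoulli(x, p):
    """log p_theta(x | z) for binary x under Bernoulli means p."""
    eps = 1e-9  # numerical guard
    return np.sum(x * np.log(p + eps) + (1 - x) * np.log(1 - p + eps))

# Monte Carlo objective (1/L) sum_l log p_theta(x^{(i)} | z_l), the quantity
# that the wake phase maximizes over theta.
objective = np.mean([log_bernoulli(x_i, f_theta(z)) for z in z_samples])
```

In a real training loop this objective would be summed over all \(N\) datapoints and ascended with gradients in \(\theta\), keeping \(\phi\) fixed.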
Sleep Phase
- Draw \(\mathbf{z}^{l} \sim N(0, I)\)
- Sample \(\mathbf{x}^{l}\) from the generative network \(p_{\theta}\left(\mathbf{x} \mid \mathbf{z}^{l}\right)=\text{Bernoulli}\left(\mathbf{x} ; f_{\theta}\left(\mathbf{z}^{l}\right)\right)\)
- Feed \(\mathbf{x}^{l}\) into the recognition network to get \(\mu_{\phi}\left(\mathbf{x}^{l}\right)\) and \(\Sigma_{\phi}\left(\mathbf{x}^{l}\right)\)
- Compute \(q_{\phi}\left(\mathbf{z}^{l} \mid \mathbf{x}^{l}\right)=N\left(\mathbf{z}^{l} ; \mu_{\phi}\left(\mathbf{x}^{l}\right), \Sigma_{\phi}\left(\mathbf{x}^{l}\right)\right)\)
- Optimize \(\max _{\phi} \frac{1}{L} \sum_{l=1}^{L} \log q_{\phi}\left(\mathbf{z}^{l} \mid \mathbf{x}^{l}\right)\)
- Intuition: "dream" random \(\mathbf{x}\) data by sampling from the generator, then maximize the probability that the recognition network infers the latent states that actually produced the simulated \(\mathbf{x}\).
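The sleep-phase steps can be sketched similarly. Again a minimal NumPy illustration with assumed linear networks and a fixed unit variance for \(q_{\phi}\); it evaluates the objective that the sleep phase maximizes over \(\phi\), without the actual gradient step.

```python
import numpy as np

rng = np.random.default_rng(2)
x_dim, z_dim, L = 4, 2, 5  # illustrative sizes and sample count

W_gen = rng.normal(size=(x_dim, z_dim))  # generative weights theta
W_mu = rng.normal(size=(z_dim, x_dim))   # recognition weights phi
sigma = np.ones(z_dim)                   # fixed unit variances for q_phi

# Sleep phase: dream up data from the generative model ...
z = rng.normal(size=(L, z_dim))                 # z^l ~ N(0, I)
p = 1.0 / (1.0 + np.exp(-(z @ W_gen.T)))        # Bernoulli means f_theta(z^l)
x = (rng.random((L, x_dim)) < p).astype(float)  # x^l ~ Bernoulli(f_theta(z^l))

# ... then score the dreamed (z^l, x^l) pairs under the recognition model
# q_phi(z | x) = N(z; mu_phi(x), diag(sigma^2)).
mu = x @ W_mu.T                                  # mu_phi(x^l) for each sample
log_q = -0.5 * np.sum(
    (z - mu) ** 2 / sigma**2 + np.log(2 * np.pi * sigma**2), axis=1
)

# Objective (1/L) sum_l log q_phi(z^l | x^l), maximized over phi.
objective = log_q.mean()
```

Note the asymmetry with the wake phase: here the latent states \(\mathbf{z}^{l}\) are the ground truth (they generated the dreamed \(\mathbf{x}^{l}\)), so the recognition network is trained as a supervised predictor of them.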