Reweighted Wake-Sleep
Jörg Bornschein & Yoshua Bengio
University of Montreal & CIFAR
2015/05/08
Helmholtz Machines

We want to train a directed generative model p(x, h) together with an inference network q(h|x):

generative network: $p(x, h) = p(x|h_1)\,p(h_1|h_2) \cdots p(h_L)$, with factors $p(h_l|h_{l+1})$ and top-level prior $p(h_L)$
inference network: $q(h|x) = q(h_1|x)\,q(h_2|h_1) \cdots q(h_L|h_{L-1})$, with factors $q(h_{l+1}|h_l)$
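As a concrete toy instance of these two factorizations, the sketch below samples top-down from p and bottom-up from q for a sigmoid belief network with binary units; all parameter names (prior_logits, W_g, b_g, W_q, b_q) are illustrative assumptions, not the authors' code.

import numpy as np

rng = np.random.default_rng(0)

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def sample_p(prior_logits, W_g, b_g):
    # Top-down: h_L ~ p(h_L), then h_l ~ p(h_l|h_{l+1}), finally x ~ p(x|h_1).
    h = rng.binomial(1, sigmoid(prior_logits)).astype(float)
    for W, b in zip(W_g, b_g):
        h = rng.binomial(1, sigmoid(h @ W + b)).astype(float)
    return h  # the visible sample x

def sample_q(x, W_q, b_q):
    # Bottom-up: h_1 ~ q(h_1|x), then h_{l+1} ~ q(h_{l+1}|h_l).
    h, hs = x, []
    for W, b in zip(W_q, b_q):
        h = rng.binomial(1, sigmoid(h @ W + b)).astype(float)
        hs.append(h)
    return hs  # [h_1, ..., h_L]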
Reweighted Wake-Sleep
Approaches:
- Wake-Sleep algorithm (Hinton, Dayan, Frey, Neal; 1995)
- Neural Variational Inference and Learning (Mnih & Gregor; 2014)
- Variational Autoencoder (Kingma & Welling; 2014 / Rezende et al.; 2014)
- ... many more ...

These are typically based on the variational bound of the log-likelihood:

$\log p(x) \ge \sum_h q(h|x) \log p(x, h) + H(q(h|x))$
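For reference, the bound follows from a single application of Jensen's inequality to an importance-weighted form of p(x):

$\log p(x) = \log \sum_h q(h|x) \frac{p(x,h)}{q(h|x)} \ge \sum_h q(h|x) \log \frac{p(x,h)}{q(h|x)} = \sum_h q(h|x) \log p(x,h) + H(q(h|x))$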
Log-Likelihood Estimation

We start with an importance sampling estimate:

$p(x) = \sum_h q(h|x)\,\frac{p(x, h)}{q(h|x)} = \mathbb{E}_{h \sim q(h|x)}\!\left[\frac{p(x, h)}{q(h|x)}\right] \simeq \frac{1}{K} \sum_{k=1}^{K} \frac{p(x, h^{(k)})}{q(h^{(k)}|x)}, \quad h^{(k)} \sim q(h|x)$

- unbiased estimator
- variance depends on the proposal distribution q(h|x)
- minimum variance is obtained with q(h|x) = p(h|x) (zero variance)
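A minimal sketch of this estimator, with hypothetical, assumed interfaces: sample_q(x) draws h ∼ q(h|x), log_p_xh(x, h) returns log p(x, h), and log_q(h, x) returns log q(h|x). The average of the importance weights is computed in log space for numerical stability.

import numpy as np
from scipy.special import logsumexp

def log_px_estimate(x, K, sample_q, log_p_xh, log_q):
    # h^(k) ~ q(h|x), k = 1..K
    hs = [sample_q(x) for _ in range(K)]
    # log omega_k = log p(x, h^(k)) - log q(h^(k)|x)
    log_w = np.array([log_p_xh(x, h) - log_q(h, x) for h in hs])
    # log( (1/K) * sum_k omega_k ), computed stably via log-sum-exp
    return logsumexp(log_w) - np.log(K)

While the estimate of p(x) itself is unbiased, taking its logarithm yields a downward-biased estimate of log p(x), which is why the LL curves in the figures below approach their asymptote from below as the number of samples grows.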
Reweighted Wake-Sleep
From this we can derive an estimator for the parameter gradient:

$\frac{\partial}{\partial \theta} \mathcal{L}_p \simeq \sum_{k=1}^{K} \tilde{\omega}_k \frac{\partial}{\partial \theta} \log p(x, h^{(k)})$

with $h^{(k)} \sim q(h|x)$, $\omega_k = \frac{p(x, h^{(k)})}{q(h^{(k)}|x)}$, and $\tilde{\omega}_k = \frac{\omega_k}{\sum_{k'=1}^{K} \omega_{k'}}$

⇒ No variational approximation!
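A hedged PyTorch sketch of this p-update (the interfaces q_sample(x), log_p(x, h), log_q(h, x), and the optimizer opt_p are assumptions, not the authors' code). The normalized weights are treated as constants, so backpropagation only flows through log p(x, h^(k)) and no reparameterization of the sampling step is needed.

import torch

def p_update(x, K, q_sample, log_p, log_q, opt_p):
    # The importance weights are constants w.r.t. theta: compute them without grad.
    with torch.no_grad():
        hs = [q_sample(x) for _ in range(K)]
        log_w = torch.stack([log_p(x, h) - log_q(h, x) for h in hs])
        w = torch.softmax(log_w, dim=0)  # normalized weights omega-tilde
    # Minimizing - sum_k omega-tilde_k log p(x, h^(k)) ascends the estimator above.
    loss = -sum(w[k] * log_p(x, hs[k]) for k in range(K))
    opt_p.zero_grad()
    loss.backward()
    opt_p.step()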
What is the objective for q(h|x)? Minimize variance! Train q(h|x) to approximate p(h|x)!

What x do we use when training q(h|x)?
- from the training data set, x ∼ D (wake phase update), or
- from the current model, x, h ∼ p(x, h) (sleep phase update)
Sleep phase q-update:
- consider x, h ∼ p(x, h) a fully observed sample
- calculate the gradient $\frac{\partial}{\partial \phi} \mathcal{L}_q = \frac{\partial}{\partial \phi} \log q(h|x)$

Wake phase q-update:

$\frac{\partial}{\partial \phi} \mathcal{L}_q \simeq \sum_{k=1}^{K} \tilde{\omega}_k \frac{\partial}{\partial \phi} \log q(h^{(k)}|x)$

with the same weights $\tilde{\omega}_k$ used during the p-update!
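Under the same assumed interfaces as the p-update sketch, both q-updates can be sketched in PyTorch as follows (p_sample(), drawing a dreamed pair (x, h) ∼ p(x, h), is likewise an assumption).

import torch

def q_wake_update(x, K, q_sample, log_p, log_q, opt_q):
    # Reuse the normalized importance weights from the p-update.
    with torch.no_grad():
        hs = [q_sample(x) for _ in range(K)]
        log_w = torch.stack([log_p(x, h) - log_q(h, x) for h in hs])
        w = torch.softmax(log_w, dim=0)
    # Minimize - sum_k omega-tilde_k log q(h^(k)|x).
    loss = -sum(w[k] * log_q(hs[k], x) for k in range(K))
    opt_q.zero_grad()
    loss.backward()
    opt_q.step()

def q_sleep_update(p_sample, log_q, opt_q):
    x, h = p_sample()    # fully observed dreamed sample (x, h) ~ p(x, h)
    loss = -log_q(h, x)  # maximize log q(h|x)
    opt_q.zero_grad()
    loss.backward()
    opt_q.step()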
(Reweighted) Wake-Sleep q-update: $\arg\min_\Phi \mathrm{KL}(p_\Theta(h|x) \,\|\, q_\Phi(h|x))$

Variational approaches: $\arg\min_\Phi \mathrm{KL}(q_\Phi(h|x) \,\|\, p_\Theta(h|x))$

RWS with K=1 sample and only the sleep phase update is equivalent to classical WS.
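To see why the sleep-phase update targets this direction of the KL divergence: holding Θ fixed, the entropy of p_Θ(h|x) does not depend on Φ, so

$\nabla_\Phi \, \mathbb{E}_{x \sim p_\Theta}\!\left[\mathrm{KL}(p_\Theta(h|x) \,\|\, q_\Phi(h|x))\right] = -\,\mathbb{E}_{(x,h) \sim p_\Theta(x,h)}\!\left[\nabla_\Phi \log q_\Phi(h|x)\right]$

which is exactly the negative of the sleep-phase gradient above, estimated with a single dreamed sample.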
Empirical Results

- using 5-10 samples during training gives significantly better results than classical WS
- combining wake and sleep phase q-updates consistently gives the best results
- applied to binarized MNIST, RWS with 5-6 layers results in competitive models, e.g. an SBN/SBN model with 5 hidden layers of 10, 100, 200, 300, 400 units on top of the 784 visible units ⇒ NLL ≈ 85.5
[Figure: Sensitivity to the number of test samples. Left panel: LL estimate vs. number of samples (10^0 to 10^3) for three architectures (10-100-200-300-400-784, 10-200-200-784, 200-784). Right panel: bias and std.-dev. of the estimate vs. number of samples, each at epoch 50 and at the last epoch.]
[Figure: Sensitivity to the number of samples during training. Left panel: final LL estimate vs. number of training samples (10^0 to 10^3) for the same three architectures (10-100-200-300-400-784, 10-200-200-784, 200-784). Right panel: bias and std.-dev. of the estimate vs. number of training samples, each at epoch 50 and at the last epoch.]