Reweighted Wake-Sleep - ICLR

Reweighted Wake-Sleep
Jörg Bornschein & Yoshua Bengio
University of Montreal & CIFAR

2015/05/08


Helmholtz Machines

We want to train a directed generative model p, together with an inference network q.

Generative network (top-down, with layer-wise conditionals p(h_l | h_{l+1}) and top-level prior p(h_L)):

    p(x, h) = p(x | h_1) \, p(h_1 | h_2) \cdots p(h_L)

Inference network (bottom-up, with layer-wise conditionals q(h_{l+1} | h_l)):

    q(h | x) = q(h_1 | x) \, q(h_2 | h_1) \cdots q(h_L | h_{L-1})
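As a concrete illustration, here is a minimal NumPy sketch of the two networks for binary sigmoid-belief-network (SBN) layers, the setting used in the paper's experiments; the architecture, parameter names, and helper functions are illustrative choices, not the authors' code:

    import numpy as np

    def sigmoid(a):
        return 1.0 / (1.0 + np.exp(-a))

    rng = np.random.default_rng(0)
    sizes = [784, 200, 10]          # x, h1, h2 (hypothetical architecture)

    # One weight matrix + bias per conditional, in each direction.
    p_params = [(rng.normal(0, 0.01, (m, n)), np.zeros(m))
                for m, n in zip(sizes[:-1], sizes[1:])]    # p(x|h1), p(h1|h2)
    p_top = np.zeros(sizes[-1])                            # logits of the prior p(h2)
    q_params = [(rng.normal(0, 0.01, (n, m)), np.zeros(n))
                for m, n in zip(sizes[:-1], sizes[1:])]    # q(h1|x), q(h2|h1)

    def sample_q(x):
        """Sample h = (h1, ..., hL) ~ q(h|x), bottom-up."""
        hs, below = [], x
        for W, b in q_params:
            below = (rng.random(W.shape[0]) < sigmoid(W @ below + b)).astype(float)
            hs.append(below)
        return hs

    def log_bernoulli(v, mean):
        return np.sum(v * np.log(mean) + (1 - v) * np.log(1 - mean))

    def log_p(x, hs):
        """log p(x,h) = log p(x|h1) + log p(h1|h2) + ... + log p(hL)."""
        logp = log_bernoulli(hs[-1], sigmoid(p_top))
        layers = [x] + hs
        for (W, b), below, above in zip(p_params, layers[:-1], layers[1:]):
            logp += log_bernoulli(below, sigmoid(W @ above + b))
        return logp

    def log_q(x, hs):
        """log q(h|x) = log q(h1|x) + ... + log q(hL|h_{L-1})."""
        logq, below = 0.0, x
        for (W, b), h in zip(q_params, hs):
            logq += log_bernoulli(h, sigmoid(W @ below + b))
            below = h
        return logq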


Helmholtz Machines

Approaches:
  - Wake-Sleep algorithm (Hinton, Dayan, Frey, Neal; 1995)
  - Neural Variational Inference and Learning (Mnih & Gregor; 2014)
  - Variational Autoencoder (Kingma & Welling; 2014 / Rezende et al.; 2014)
  - ... many more ...

...are typically based on the variational bound of the log-likelihood:

    \log p(x) \ge \sum_h q(h|x) \log p(x, h) + H(q(h|x))
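The bound follows from Jensen's inequality applied to an expectation under q:

    \log p(x) = \log \sum_h q(h|x) \frac{p(x, h)}{q(h|x)}
              \ge \sum_h q(h|x) \log \frac{p(x, h)}{q(h|x)}
              = \sum_h q(h|x) \log p(x, h) + H(q(h|x)),

with equality exactly when q(h|x) = p(h|x).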


Log-Likelihood Estimation

We start with an importance sampling estimate:

    p(x) = \sum_h p(x, h)
         = \sum_h q(h|x) \frac{p(x, h)}{q(h|x)}
         = \mathbb{E}_{h \sim q(h|x)} \left[ \frac{p(x, h)}{q(h|x)} \right]
         \simeq \frac{1}{K} \sum_{k=1}^{K} \frac{p(x, h^{(k)})}{q(h^{(k)} | x)},
           \quad h^{(k)} \sim q(h|x)

  - unbiased estimator
  - variance depends on the proposal distribution q(h|x)
  - minimum variance is obtained with q(h|x) = p(h|x) (zero variance)
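A minimal sketch of this estimator, reusing the hypothetical log_p, log_q, and sample_q helpers from above and working in log space for numerical stability:

    import numpy as np
    from scipy.special import logsumexp

    def estimate_log_px(x, K=100):
        """Importance-sampling estimate of log p(x) with proposal q(h|x)."""
        log_w = np.empty(K)
        for k in range(K):
            hs = sample_q(x)                        # h^(k) ~ q(h|x)
            log_w[k] = log_p(x, hs) - log_q(x, hs)  # log w_k
        return logsumexp(log_w) - np.log(K)         # log((1/K) sum_k w_k), stable

Note that while the estimate of p(x) itself is unbiased, its logarithm is biased downward (by Jensen's inequality); this is the bias examined in the empirical results below.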




Reweighted Wake-Sleep

From this we can derive an estimator for the parameter gradient:

    \frac{\partial}{\partial \theta} L_p \simeq \sum_{k=1}^{K} \tilde{\omega}_k \, \frac{\partial}{\partial \theta} \log p(x, h^{(k)})

with

    h^{(k)} \sim q(h|x), \quad \omega_k = \frac{p(x, h^{(k)})}{q(h^{(k)} | x)}, \quad \tilde{\omega}_k = \frac{\omega_k}{\sum_{k'=1}^{K} \omega_{k'}}

=> No variational approximation!
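The normalized weights are best computed in log space; a small continuation of the sketch above:

    import numpy as np
    from scipy.special import logsumexp

    def normalized_weights(x, samples):
        """omega~_k = omega_k / sum_k' omega_k',  omega_k = p(x,h^(k)) / q(h^(k)|x)."""
        log_w = np.array([log_p(x, hs) - log_q(x, hs) for hs in samples])
        return np.exp(log_w - logsumexp(log_w))

    # The p-gradient is then sum_k omega~_k * d/dtheta log p(x, h^(k)); in an
    # autodiff framework, treat the weights as constants (stop-gradient).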


Reweighted Wake-Sleep

What is the objective for q(h|x)?
  Minimize variance! Train q(h|x) to approximate p(h|x)!

What x do we use when training q(h|x)?
  - from the training data set, x ~ D (wake phase update), or
  - from the current model, x, h ~ p(x, h) (sleep phase update)


Reweighted Wake-Sleep

Sleep phase q-update: consider x, h ~ p(x, h) a fully observed sample and calculate the gradient

    \frac{\partial}{\partial \phi} L_q = \frac{\partial}{\partial \phi} \log q(h | x)

Wake phase q-update:

    \frac{\partial}{\partial \phi} L_q \simeq \sum_{k=1}^{K} \tilde{\omega}_k \, \frac{\partial}{\partial \phi} \log q(h^{(k)} | x)

with the same weights \tilde{\omega}_k used during the p-update!
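A sketch of both updates, continuing the hypothetical helpers above; the manual Bernoulli gradient (h - mean, times the layer input) stands in for what an autodiff framework would compute:

    def sample_p():
        """Ancestral sample (x, h) ~ p(x, h), top-down through the generative net."""
        h = (rng.random(p_top.shape[0]) < sigmoid(p_top)).astype(float)
        hs = [h]
        for W, b in reversed(p_params):             # p(h1|h2), then p(x|h1)
            h = (rng.random(W.shape[0]) < sigmoid(W @ h + b)).astype(float)
            hs.insert(0, h)
        return hs[0], hs[1:]                        # x, (h1, ..., hL)

    def grad_log_q(x, hs):
        """Per-layer (dW, db) of log q(h|x); for Bernoulli units d/dlogits = h - mean."""
        grads, below = [], x
        for (W, b), h in zip(q_params, hs):
            delta = h - sigmoid(W @ below + b)
            grads.append((np.outer(delta, below), delta))
            below = h
        return grads

    def sleep_q_grad():
        x, hs = sample_p()            # "dream" a fully observed sample from the model
        return grad_log_q(x, hs)      # push q(h|x) toward the dreamed h

    def wake_q_grad(x, K=5):
        samples = [sample_q(x) for _ in range(K)]   # x comes from the training set
        w = normalized_weights(x, samples)          # same omega~_k as the p-update
        grads = [grad_log_q(x, hs) for hs in samples]
        return [(sum(wk * g[l][0] for wk, g in zip(w, grads)),
                 sum(wk * g[l][1] for wk, g in zip(w, grads)))
                for l in range(len(q_params))]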


Reweighted Wake-Sleep

(Reweighted) Wake-Sleep q-update:

    \arg\min_\Phi \, \mathrm{KL}(p_\Theta(h|x) \,\|\, q_\Phi(h|x))

Variational approaches:

    \arg\min_\Phi \, \mathrm{KL}(q_\Phi(h|x) \,\|\, p_\Theta(h|x))

RWS with K=1 sample and sleep phase update only is equivalent to classical WS.
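The direction matters: KL(p||q) is mass-covering (q is punished wherever the true posterior has mass that q misses), while KL(q||p) is mode-seeking (q may collapse onto one posterior mode). A tiny numeric illustration; the two distributions below are made up for this example:

    import numpy as np

    def kl(p, q):
        return np.sum(p * np.log(p / q))

    p_post = np.array([0.50, 0.50, 1e-6])   # bimodal "true posterior" p(h|x)
    p_post /= p_post.sum()
    q_fit  = np.array([0.98, 0.01, 0.01])   # a q(h|x) that covers only one mode

    print(kl(p_post, q_fit))   # ~1.62: the WS/RWS direction objects strongly
    print(kl(q_fit, p_post))   # ~0.71: the variational direction is more tolerant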


Empirical Results

  - using 5-10 samples during training gives significantly better results than classical WS
  - combining wake and sleep phase q-updates consistently gives the best results
  - applied to binarized MNIST, RWS with 5-6 layers yields competitive models,
    e.g. an SBN/SBN with 5 hidden layers of 10, 100, 200, 300, 400 units
    (784 visible units; architecture 10-100-200-300-400) => NLL ~ 85.5


Empirical Results

[Figure: Sensitivity to the number of test samples. Left panel: LL estimate vs. number of samples (10^0 to 10^3) for the architectures 10-100-200-300-400-784, 10-200-200-784, and 200-784. Right panel: bias and std. dev. of the estimate, at epoch 50 and at the last epoch, vs. number of samples.]

Empirical Results

[Figure: Sensitivity to the number of samples used during training. Left panel: final LL estimate vs. number of training samples (10^0 to 10^3) for the same three architectures. Right panel: bias and std. dev., at epoch 50 and at the last epoch, vs. number of training samples.]