Learning Overcomplete Latent Variable Models through Tensor Methods

Anima Anandkumar, UC Irvine

Joint work with

Majid Janzamin UC Irvine

Rong Ge Microsoft Research


Latent Variable Probabilistic Models

Latent (hidden) variable $h \in \mathbb{R}^k$, observed variable $x \in \mathbb{R}^d$.

Multiview linear mixture models: categorical hidden variable $h$; views conditionally independent given $h$; linear model $\mathbb{E}[x_1 | h] = a_h$, $\mathbb{E}[x_2 | h] = b_h$, $\mathbb{E}[x_3 | h] = c_h$.

Gaussian mixture: categorical hidden variable $h$; $x | h \sim \mathcal{N}(\mu_h, \Sigma_h)$.

Also: ICA, sparse coding, HMMs, topic modeling, . . .

How can the parameters $a_h, \mu_h, \ldots$ be learned efficiently?

[Figure: graphical model with hidden node h and conditionally independent observed views x1, x2, x3, ...]


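Method-of-Moments (Spectral Methods)

Multivariate observed moments: $M_1 := \mathbb{E}[x]$, $M_2 := \mathbb{E}[x \otimes x]$, $M_3 := \mathbb{E}[x \otimes x \otimes x]$.

The matrix $\mathbb{E}[x \otimes x] \in \mathbb{R}^{d \times d}$ is a second-order tensor: $\mathbb{E}[x \otimes x]_{i_1, i_2} = \mathbb{E}[x_{i_1} x_{i_2}]$. For matrices, $\mathbb{E}[x \otimes x] = \mathbb{E}[x x^\top]$.

The tensor $\mathbb{E}[x \otimes x \otimes x] \in \mathbb{R}^{d \times d \times d}$ is a third-order tensor: $\mathbb{E}[x \otimes x \otimes x]_{i_1, i_2, i_3} = \mathbb{E}[x_{i_1} x_{i_2} x_{i_3}]$.

What information do these moments carry for learning LVMs?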
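As a minimal sketch (assuming i.i.d. samples stacked as the rows of a NumPy array; the helper name `empirical_moments` is hypothetical), the sample versions of $M_1$, $M_2$, $M_3$ follow directly from the index definitions above, with `np.einsum` spelling out each entry as an average of products of coordinates:

```python
import numpy as np

def empirical_moments(X):
    """Sample estimates of M1 = E[x], M2 = E[x ⊗ x], M3 = E[x ⊗ x ⊗ x].

    X has shape (n, d), one observation x per row.
    """
    n = X.shape[0]
    M1 = X.mean(axis=0)                            # shape (d,)
    M2 = np.einsum("ni,nj->ij", X, X) / n          # shape (d, d); equals X.T @ X / n
    M3 = np.einsum("ni,nj,nl->ijl", X, X, X) / n   # shape (d, d, d)
    return M1, M2, M3

rng = np.random.default_rng(0)
X = rng.standard_normal((1000, 4))   # synthetic data, for illustration only
M1, M2, M3 = empirical_moments(X)
# Entry (i1, i2, i3) of M3 is the sample average of x_{i1} * x_{i2} * x_{i3}.
print(M3[0, 1, 2], np.mean(X[:, 0] * X[:, 1] * X[:, 2]))
```

The third moment costs $O(n d^3)$ time and $O(d^3)$ memory when formed explicitly, which is one reason tensor methods often work with contractions such as $T(I, v, v)$ instead of the full tensor.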


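Multiview Mixture Model

$[k] := \{1, \ldots, k\}$.

Multiview linear mixture models: categorical hidden variable $h \in [k]$ with $w_j := \Pr[h = j]$. The views are conditionally independent given $h$, with linear model $\mathbb{E}[x_1 | h] = a_h$, $\mathbb{E}[x_2 | h] = b_h$, $\mathbb{E}[x_3 | h] = c_h$.

By conditional independence, $\mathbb{E}[x_1 \otimes x_2 | h] = \mathbb{E}[x_1 | h] \otimes \mathbb{E}[x_2 | h] = a_h \otimes b_h$, so

$$\mathbb{E}_x[\overbrace{x_1 \otimes x_2}^{x_1 x_2^\top}] = \mathbb{E}_h\big[\mathbb{E}_x[x_1 \otimes x_2 | h]\big] = \mathbb{E}_h[a_h \otimes b_h] = \sum_{j \in [k]} w_j\, a_j \otimes b_j.$$

[Figure: graphical model with hidden node h and observed views x1, x2, x3, ...]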


Multiview Mixture Model

The same argument applied to three views gives the third-order moment:

$$\mathbb{E}[x_1 \otimes x_2] = \sum_{j \in [k]} w_j\, a_j \otimes b_j, \qquad \mathbb{E}[x_1 \otimes x_2 \otimes x_3] = \sum_{j \in [k]} w_j\, a_j \otimes b_j \otimes c_j.$$

Learning LVMs thus reduces to tensor (matrix) factorization.
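The moment identities can be checked numerically. The sketch below assumes a hypothetical instance of the multiview model: random unit-norm means $a_j, b_j, c_j$, fixed weights $w$, and isotropic Gaussian view noise (the noise model is an assumption, not part of the slides; any zero-mean noise independent across views would do). It compares exact and empirical cross moments:

```python
import numpy as np

rng = np.random.default_rng(1)
d, k, n, sigma = 5, 3, 200_000, 0.5

# Hypothetical parameters: mixing weights w_j and unit-norm view means a_j, b_j, c_j.
w = np.array([0.5, 0.3, 0.2])
A, B, C = (rng.standard_normal((d, k)) for _ in range(3))
A, B, C = (M / np.linalg.norm(M, axis=0) for M in (A, B, C))

# Sample h ~ w, then three views that are conditionally independent given h.
h = rng.choice(k, size=n, p=w)
x1 = A[:, h].T + sigma * rng.standard_normal((n, d))
x2 = B[:, h].T + sigma * rng.standard_normal((n, d))
x3 = C[:, h].T + sigma * rng.standard_normal((n, d))

# Exact moments predicted by the model.
M2_exact = np.einsum("j,aj,bj->ab", w, A, B)          # sum_j w_j a_j ⊗ b_j
M3_exact = np.einsum("j,aj,bj,cj->abc", w, A, B, C)   # sum_j w_j a_j ⊗ b_j ⊗ c_j

# Empirical cross moments across the views.
M2_hat = np.einsum("na,nb->ab", x1, x2) / n
M3_hat = np.einsum("na,nb,nc->abc", x1, x2, x3) / n

print(np.abs(M2_hat - M2_exact).max(), np.abs(M3_hat - M3_exact).max())
```

Because the noise terms are zero-mean and independent across views, every cross term involving noise vanishes in expectation, so only the sampling error remains in the printed gaps.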


Tensor Rank and Tensor Decomposition

Rank-1 tensor: $T = w \cdot a \otimes b \otimes c \iff T(i, j, l) = w \cdot a(i) \cdot b(j) \cdot c(l)$.

CANDECOMP/PARAFAC (CP) decomposition:

$$T = \sum_{j \in [k]} w_j\, a_j \otimes b_j \otimes c_j \in \mathbb{R}^{d \times d \times d}, \qquad a_j, b_j, c_j \in \mathbb{S}^{d-1}.$$

[Figure: tensor T drawn as the sum w1 · a1 ⊗ b1 ⊗ c1 + w2 · a2 ⊗ b2 ⊗ c2 + ...]

$k$: tensor rank; $d$: ambient dimension. $k \le d$: undercomplete; $k > d$: overcomplete.

This talk: guarantees for overcomplete tensor decomposition.


Challenges in Tensor Decomposition

Symmetric tensor $T \in \mathbb{R}^{d \times d \times d}$: $T = \sum_{i \in [k]} \lambda_i\, v_i \otimes v_i \otimes v_i$.

Challenges: a decomposition may not always exist for general tensors, and finding the decomposition is NP-hard in general.

Tractable case: orthogonal tensor decomposition ($\langle v_i, v_j \rangle = 0$ for $i \ne j$).

Algorithm: tensor power method, $v \mapsto \frac{T(I, v, v)}{\|T(I, v, v)\|}$.

• The $v_i$'s are the only robust fixed points.

• All other eigenvectors are saddle points.

For an orthogonal tensor, there are no spurious local optima!
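A minimal sketch of the tensor power method on a synthetic orthogonal tensor. To recover all components, it deflates $T \leftarrow T - \lambda\, v \otimes v \otimes v$ after each converged run; deflation is a common companion technique assumed here, not something stated on the slide, and the helper names are hypothetical:

```python
import numpy as np

def tensor_apply(T, v):
    """Compute T(I, v, v): the vector with entries sum_{j,l} T[i, j, l] v[j] v[l]."""
    return np.einsum("ijl,j,l->i", T, v, v)

def power_method(T, n_iter=100, rng=None):
    """Tensor power iteration v -> T(I, v, v) / ||T(I, v, v)|| from a random unit start."""
    rng = np.random.default_rng() if rng is None else rng
    v = rng.standard_normal(T.shape[0])
    v /= np.linalg.norm(v)
    for _ in range(n_iter):
        u = tensor_apply(T, v)
        v = u / np.linalg.norm(u)
    lam = tensor_apply(T, v) @ v   # eigenvalue T(v, v, v)
    return lam, v

def orthogonal_decompose(T, k, rng=None):
    """Recover (lambda_i, v_i) one at a time, deflating T after each component."""
    T = T.copy()
    pairs = []
    for _ in range(k):
        lam, v = power_method(T, rng=rng)
        pairs.append((lam, v))
        T -= lam * np.einsum("i,j,l->ijl", v, v, v)
    return pairs

rng = np.random.default_rng(3)
d = 5
V, _ = np.linalg.qr(rng.standard_normal((d, d)))   # orthonormal columns v_i
lams = np.array([5.0, 4.0, 3.0, 2.0, 1.0])
T = np.einsum("i,ai,bi,ci->abc", lams, V, V, V)
pairs = orthogonal_decompose(T, d, rng=rng)
print(sorted(round(float(lam), 6) for lam, _ in pairs))
```

Each run converges to whichever $v_i$ maximizes $\lambda_i |\langle v_i, v^{(0)} \rangle|$ for its random start, and convergence is quadratic, so 100 iterations are far more than needed here.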


Beyond Orthogonal Tensor Decomposition

Limitations: not all tensors have an orthogonal decomposition (unlike matrices).

Undercomplete tensors ($k \le d$) with full-rank components: the non-orthogonal decomposition $T_1 = \sum_i w_i\, a_i \otimes a_i \otimes a_i$ is mapped to an orthogonal one by a whitening matrix $W$, through the multilinear transform $T_2 = T_1(W, W, W)$.

[Figure: whitening W maps the non-orthogonal components a1, a2, a3 of tensor T1 to orthogonal components v1, v2, v3 of tensor T2]

This talk: guarantees for overcomplete tensor decomposition.
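One standard way to build $W$, assumed here because the slide does not say how $W$ is obtained, uses the second moment $M_2 = \sum_i w_i\, a_i a_i^\top$: take its top-$k$ eigenpairs $(U, s)$ and set $W = U\,\mathrm{diag}(s)^{-1/2}$, so that $W^\top M_2 W = I_k$. The sketch checks that the whitened components $v_i = \sqrt{w_i}\, W^\top a_i$ are orthonormal and that $T_2 = \sum_i w_i^{-1/2} v_i^{\otimes 3}$:

```python
import numpy as np

rng = np.random.default_rng(4)
d, k = 6, 3                                    # undercomplete: k <= d
A = rng.standard_normal((d, k))
A /= np.linalg.norm(A, axis=0)                 # non-orthogonal unit components a_i
w = np.array([0.5, 0.3, 0.2])

M2 = A @ np.diag(w) @ A.T                      # sum_i w_i a_i a_i^T (rank k)
T1 = np.einsum("i,ai,bi,ci->abc", w, A, A, A)  # sum_i w_i a_i ⊗ a_i ⊗ a_i

# Whitening matrix W (d x k) from the top-k eigenpairs of M2, so that W^T M2 W = I_k.
evals, evecs = np.linalg.eigh(M2)              # eigenvalues in ascending order
U, s = evecs[:, -k:], evals[-k:]
W = U / np.sqrt(s)

# Multilinear transform T2 = T1(W, W, W), a k x k x k tensor.
T2 = np.einsum("abc,ai,bj,cl->ijl", T1, W, W, W)

# Whitened components v_i = sqrt(w_i) W^T a_i, one per column.
Vw = np.sqrt(w) * (W.T @ A)
print(np.round(Vw.T @ Vw, 6))
```

The printed Gram matrix should be the identity, which is why the orthogonal power method can then be run on $T_2$. This construction needs $M_2$ to have rank $k$, so it does not extend to the overcomplete case $k > d$.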

Outline

1 Introduction

2 Overcomplete tensor decomposition

3 Sample Complexity Analysis

4 Conclusion


Our Setup

So far: general tensor decomposition is NP-hard, and orthogonal tensors are too limiting. Are there tractable cases that also cover overcomplete tensors?

Our framework: incoherent components, $|\langle a_i, a_j \rangle| = O(1/\sqrt{d})$ for $i \ne j$, and similarly for $b$ and $c$.

• Can handle overcomplete tensors.

• Satisfied by random vectors.

• Guaranteed recovery for alternating minimization?
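A quick numerical check of the "satisfied by random vectors" claim, as a sketch assuming components drawn uniformly on the sphere (normalized Gaussians) with $k = 2d$. The hypothetical helper `coherence` returns $\max_{i \ne j} |\langle a_i, a_j \rangle|$; the rescaled value $\sqrt{d}\cdot$coherence stays bounded, growing only like $\sqrt{\log k}$ for random vectors:

```python
import numpy as np

def coherence(A):
    """Largest |<a_i, a_j>| over distinct unit-norm columns of A."""
    G = A.T @ A
    np.fill_diagonal(G, 0.0)
    return np.abs(G).max()

rng = np.random.default_rng(5)
results = []
for d in (100, 200, 400):
    k = 2 * d                                  # overcomplete: k > d
    A = rng.standard_normal((d, k))
    A /= np.linalg.norm(A, axis=0)             # random unit vectors
    mu = coherence(A)
    results.append((d, k, mu, mu * np.sqrt(d)))
    print(d, k, round(mu, 3), round(mu * np.sqrt(d), 2))
```

The maximum over all $\binom{k}{2}$ pairs is what introduces the logarithmic factor; a typical single pair has $|\langle a_i, a_j \rangle| \approx 1/\sqrt{d}$.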


Alternating Minimization

$$\min_{a, b, c \in \mathbb{S}^{d-1},\ w \in \mathbb{R}} \|T - w \cdot a \otimes b \otimes c\|_F.$$

Rank-1 ALS iteration (power iteration): initialize $a^{(0)}, b^{(0)}, c^{(0)}$. In step $t$, fix $a^{(t)}, b^{(t)}$ and update $c^{(t+1)} \propto T(a^{(t)}, b^{(t)}, I)$, and analogously for $a$ and $b$. After (approximate) convergence, restart.

The update is simple and trivially parallelizable, hence scalable; computation is linear in the dimension, the rank, and the number of different runs.

Rank-1 ALS iteration ≡ asymmetric power iteration.
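A minimal sketch of the alternating rank-1 updates on a synthetic overcomplete tensor ($d = 60$, $k = 120 < d^{1.5}$) with random unit components and unit weights. The simultaneous update of $a$, $b$, $c$ from the step-$t$ iterates, the perturbed starting point, and the names `rank1_als` and `unit` are all illustrative assumptions:

```python
import numpy as np

def unit(v):
    return v / np.linalg.norm(v)

def rank1_als(T, a, b, n_iter=30):
    """Alternating rank-1 updates (asymmetric power iteration) on a d x d x d tensor T.

    Each step uses the current iterates to form
    a <- T(I, b, c), b <- T(a, I, c), c <- T(a, b, I), each normalized.
    """
    a, b = unit(a), unit(b)
    c = unit(np.einsum("ijl,i,j->l", T, a, b))              # c <- T(a, b, I)
    for _ in range(n_iter):
        a, b, c = (unit(np.einsum("ijl,j,l->i", T, b, c)),  # T(I, b, c)
                   unit(np.einsum("ijl,i,l->j", T, a, c)),  # T(a, I, c)
                   unit(np.einsum("ijl,i,j->l", T, a, b)))  # T(a, b, I)
    w = np.einsum("ijl,i,j,l->", T, a, b, c)                # weight T(a, b, c)
    return w, a, b, c

rng = np.random.default_rng(6)
d, k = 60, 120                                   # overcomplete: d < k < d**1.5
A, B, C = (rng.standard_normal((d, k)) for _ in range(3))
A, B, C = (M / np.linalg.norm(M, axis=0) for M in (A, B, C))
T = np.einsum("ij,lj,mj->ilm", A, B, C)          # unit weights w_j = 1

# Hypothetical initialization within distance about 0.2 of (a_1, b_1).
a0 = A[:, 0] + 0.2 * unit(rng.standard_normal(d))
b0 = B[:, 0] + 0.2 * unit(rng.standard_normal(d))
w_hat, a_hat, b_hat, c_hat = rank1_als(T, a0, b0)
print(abs(a_hat @ A[:, 0]), abs(c_hat @ C[:, 0]))
```

Each update is one tensor contraction, so independent restarts can run in parallel. Even without noise, the fixed point sits slightly off $a_1$ because of the cross terms from the other $k - 1$ incoherent components.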


Main Result: Local Convergence

Initialization: $\max\{\|a_1 - \hat{a}^{(0)}\|, \|b_1 - \hat{b}^{(0)}\|\} \le \epsilon_0$, with $\epsilon_0$ smaller than a constant.

Noise: $\hat{T} := T + E$, with $\|E\| \le 1/\mathrm{polylog}(d)$.

Rank: $k = o(d^{1.5})$.

Theorem (Local Convergence) [AGJ2014]: after $N = O(\log(1/\|E\|))$ steps of alternating rank-1 updates, $\|a_1 - \hat{a}^{(N)}\| = O(\|E\|)$.

• Linear convergence, up to the approximation error.

• Guarantees for overcomplete tensors: $k = o(d^{1.5})$, and for $p$th-order tensors $k = o(d^{p/2})$.

• Requires a good initialization. What about global convergence?


Global Convergence for $k = O(d)$

SVD initialization: find the top singular vectors of $T(I, I, \theta)$ for $\theta \sim \mathcal{N}(0, I)$ and use them for initialization; repeat for $L$ trials.

Assumptions:

• Number of initializations: $L \ge k^{\Omega(k/d)^2}$.

• Tensor rank: $k = O(d)$.

• Number of iterations: $N = \Theta(\log(1/\|E\|))$. Recall that $\|E\|$ is the noise level, which governs the recovery error.

Theorem (Global Convergence) [AGJ2014]: $\|a_1 - \hat{a}^{(N)}\| \le O(\epsilon_R)$.
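A minimal sketch of the SVD initialization step on a synthetic tensor with $k = d/3$ random unit components (an illustrative regime). The helper `svd_initializations` is hypothetical; each candidate $(a^{(0)}, b^{(0)})$ would then be refined by the alternating rank-1 updates, which this sketch does not repeat:

```python
import numpy as np

def svd_initializations(T, L, rng):
    """L candidate initializations (a0, b0) for rank-1 ALS.

    For each trial draw theta ~ N(0, I), form the d x d matrix T(I, I, theta),
    and keep its top left and right singular vectors.
    """
    d = T.shape[2]
    inits = []
    for _ in range(L):
        theta = rng.standard_normal(d)
        M = np.einsum("ijl,l->ij", T, theta)   # T(I, I, theta)
        U, s, Vt = np.linalg.svd(M)
        inits.append((U[:, 0], Vt[0]))         # top singular vector pair
    return inits

rng = np.random.default_rng(7)
d, k, L = 30, 10, 20
A, B, C = (rng.standard_normal((d, k)) for _ in range(3))
A, B, C = (M / np.linalg.norm(M, axis=0) for M in (A, B, C))
T = np.einsum("ij,lj,mj->ilm", A, B, C)

inits = svd_initializations(T, L, rng)
# Best alignment of any candidate a0 with any true component a_j.
best = max(np.abs(A.T @ a0).max() for a0, _ in inits)
print(best)
```

Here $T(I, I, \theta) = \sum_j \langle c_j, \theta \rangle\, a_j b_j^\top$, so a trial yields a good start when one coefficient $\langle c_j, \theta \rangle$ clearly dominates the others. Running $L$ random trials makes that likely for some $j$.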



High-level Intuition for Sample Bounds

Multi-view model: $x_1 = A h + z_1$, where $z_1$ is noise.

Exact moment: $T = \sum_i w_i\, a_i \otimes b_i \otimes c_i$.

Sample moment: $\hat{T} = \frac{1}{n} \sum_i x_1^i \otimes x_2^i \otimes x_3^i$.

Naïve idea: bound $\|\hat{T} - T\| \le \|\mathrm{mat}(\hat{T}) - \mathrm{mat}(T)\|$ and apply the matrix Bernstein inequality.

Our idea: a careful $\epsilon$-net covering argument for $\hat{T} - T$.

$\hat{T} - T$ has many terms, e.g., the all-noise term $\frac{1}{n} \sum_i z_1^i \otimes z_2^i \otimes z_3^i$ and signal-noise terms. We need to bound

$$\frac{1}{n} \sum_i \langle z_1^i, u \rangle \langle z_2^i, v \rangle \langle z_3^i, w \rangle \quad \text{for all } u, v, w \in \mathbb{S}^{d-1}.$$

Classify the inner products into buckets and bound them separately.

This yields tight sample bounds for a range of latent variable models.
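The naïve bound rests on one fact: for unit $u, v, w$, $|D(u, v, w)| = |u^\top \mathrm{mat}(D)\,(v \otimes w)| \le \|\mathrm{mat}(D)\|$, because $\|v \otimes w\| = 1$. The sketch below illustrates this on the all-noise term, assuming i.i.d. $\mathcal{N}(0, \sigma^2 I)$ noise in each view (an assumption made only for illustration). It compares the unfolding's spectral norm with a crude lower estimate of the tensor spectral norm obtained from random unit directions:

```python
import numpy as np

def tensor_form(T, u, v, w):
    """Evaluate the trilinear form T(u, v, w) = sum_{ijl} T[i, j, l] u_i v_j w_l."""
    return np.einsum("ijl,i,j,l->", T, u, v, w)

def rand_unit(rng, dim):
    x = rng.standard_normal(dim)
    return x / np.linalg.norm(x)

rng = np.random.default_rng(8)
d, n, sigma = 10, 2000, 1.0
# All-noise term (1/n) sum_i z1^i ⊗ z2^i ⊗ z3^i; its expectation is the zero tensor.
Z1, Z2, Z3 = (sigma * rng.standard_normal((n, d)) for _ in range(3))
D = np.einsum("na,nb,nc->abc", Z1, Z2, Z3) / n

# Naive bound: spectral norm of the d x d^2 unfolding mat(D).
mat_norm = np.linalg.norm(D.reshape(d, d * d), 2)

# Lower estimate of sup |D(u, v, w)| over unit u, v, w; it can never exceed mat_norm.
probe = max(abs(tensor_form(D, rand_unit(rng, d), rand_unit(rng, d), rand_unit(rng, d)))
            for _ in range(500))
print(probe, mat_norm)
```

The random probe only lower-bounds the tensor norm, so the gap it shows is indicative rather than a proof. Quantifying that gap tightly over all directions is exactly what the $\epsilon$-net and bucketing argument is for.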

Unsupervised Learning of Gaussian Mixtures

Number of mixture components: $k = C \cdot d$.

Number of unlabeled samples: $n = \tilde{\Omega}(k \cdot d)$.

Computational complexity: $\tilde{O}(k^{C^2})$.

Our result: with $n$ unlabeled samples, the achieved error is $\max_j \|\hat{a}_j - a_j\| = \tilde{O}\big(\sqrt{k/n}\big)$.

• Linear convergence.

• Error: the same as in the semi-supervised setting.

• Computational complexity: polynomial when $k = \Theta(d)$.

Semi-supervised Learning of Gaussian Mixtures

$n$ unlabeled samples; $m_j$ labeled samples for component $j$.

Number of mixture components: $k = o(d^{1.5})$.

Number of labeled samples: $m_j = \tilde{\Omega}(1)$.

Number of unlabeled samples: $n = \tilde{\Omega}(k)$.

Our result: with $n$ unlabeled samples, the achieved error is $\max_j \|\hat{a}_j - a_j\| = \tilde{O}\big(\sqrt{k/n}\big)$.

• Linear convergence.

• Can handle (polynomially) overcomplete mixtures.

• Extremely small number of labeled samples: $\mathrm{polylog}(d)$.

• Sample complexity is tight: $\tilde{\Omega}(k)$ samples are needed!



Conclusion

Learning overcomplete latent variable models:
⋆ Method of moments.
⋆ Tensor power iteration.

Robustness to noise.

Sample complexity bounds for a range of LVMs:
⋆ Unsupervised setting.
⋆ Semi-supervised setting.

Latest result: improved initialization for tensors with Gaussian components.

Thank you!
