Fast Solvers and Efficient Implementations for Distance Metric Learning

ICML July 6th, 2008

Kilian Q. Weinberger and Lawrence K. Saul

Mahalanobis distance

Euclidean distance:

  \|\vec{x}_i - \vec{x}_j\|_2 = \sqrt{(\vec{x}_i - \vec{x}_j)^\top (\vec{x}_i - \vec{x}_j)}

Mahalanobis distance:

  \|\vec{x}_i - \vec{x}_j\|_M = \sqrt{(\vec{x}_i - \vec{x}_j)^\top M (\vec{x}_i - \vec{x}_j)},   with M \succeq 0 (positive semi-definite)

Under the Mahalanobis metric, the Euclidean unit circle becomes an ellipse.
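The two distances differ only in the inner product they use. A minimal NumPy illustration (not from the talk; the names and example matrix are ours):

```python
import numpy as np

def euclidean_dist(xi, xj):
    """Euclidean distance: inner product with the identity matrix."""
    d = xi - xj
    return np.sqrt(d @ d)

def mahalanobis_dist(xi, xj, M):
    """Mahalanobis distance under a positive semi-definite matrix M."""
    d = xi - xj
    return np.sqrt(d @ M @ d)

xi, xj = np.array([1.0, 2.0]), np.array([3.0, 0.0])
M = np.array([[2.0, 0.0],   # amplify direction 0
              [0.0, 0.5]])  # squash direction 1
print(euclidean_dist(xi, xj))       # sqrt(8) ~ 2.83
print(mahalanobis_dist(xi, xj, M))  # sqrt(2*4 + 0.5*4) = sqrt(10)
```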

k-NN classification: adapt the metric to the data by amplifying informative directions and squashing non-informative ones.

Prior work: Shental et al. [ECCV 2002], Xing et al. [NIPS 2003], Bilenko et al. [ICML 2004], Shalev-Shwartz et al. [ICML 2004], Goldberger et al. [NIPS 2005], Weinberger et al. [NIPS 2006], Globerson and Roweis [NIPS 2006], Davis et al. [ICML 2007].

Large margin nearest neighbor (LMNN) [Weinberger et al., NIPS 2006]

Minimize the leave-one-out k-NN error:
1. Pick target neighbors \vec{x}_j for each input \vec{x}_i.
2. Learn a metric such that no impostors \vec{x}_l are closer than the target neighbors, by a large margin.

This gives one margin constraint per triplet (where j \rightsquigarrow i means \vec{x}_j is a target neighbor of \vec{x}_i, and C_i denotes the class of \vec{x}_i):

  \forall i,\ j \rightsquigarrow i,\ l \notin C_i:
    \|\vec{x}_i - \vec{x}_j\|_M^2 + 1 \le \|\vec{x}_i - \vec{x}_l\|_M^2 + \xi_{ijl}

Putting objective and constraints together yields a semidefinite program:

  \min_M \sum_{i,\, j \rightsquigarrow i} \|\vec{x}_i - \vec{x}_j\|_M^2 + \mu \sum_{i,j,l} \xi_{ijl}

  subject to: \forall i,\ j \rightsquigarrow i,\ l \notin C_i:
    \|\vec{x}_i - \vec{x}_j\|_M^2 + 1 \le \|\vec{x}_i - \vec{x}_l\|_M^2 + \xi_{ijl}
    M \succeq 0,\quad \xi_{ijl} \ge 0
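To make the slack variables concrete: at the optimum, \xi_{ijl} equals the hinge max(0, \|\vec{x}_i - \vec{x}_j\|_M^2 + 1 - \|\vec{x}_i - \vec{x}_l\|_M^2), so the objective can be evaluated as an unconstrained hinge loss in M. A minimal NumPy sketch (illustrative only; `targets` and `impostors` are assumed index lists, not the talk's solver):

```python
import numpy as np

def lmnn_loss(M, X, targets, impostors, mu=1.0):
    """Hinge-loss form of the LMNN objective.

    X         : (n, d) data matrix
    targets   : list of (i, j) pairs, j a target neighbor of i
    impostors : list of (i, j, l) triplets with l in a different class than i
    """
    def d2(a, b):
        diff = X[a] - X[b]
        return diff @ M @ diff

    pull = sum(d2(i, j) for i, j in targets)            # pull target neighbors close
    push = sum(max(0.0, d2(i, j) + 1.0 - d2(i, l))      # hinge on margin violations
               for i, j, l in impostors)
    return pull + mu * push
```

Projected subgradient descent on this loss, re-projecting M onto the PSD cone after each step, is one standard way to solve such a program.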

Three Questions

1. How can one apply LMNN to larger data sets?
2. How can one make LMNN classification faster?
3. How can one make LMNN classification more accurate?

1. How can one apply LMNN to larger data sets?

The semidefinite program:

  \min_M \sum_{i,\, j \rightsquigarrow i} \|\vec{x}_i - \vec{x}_j\|_M^2 + \mu \sum_{i,j,l} \xi_{ijl}

  subject to: \forall i,\ j \rightsquigarrow i,\ l \notin C_i:
    \|\vec{x}_i - \vec{x}_j\|_M^2 + 1 \le \|\vec{x}_i - \vec{x}_l\|_M^2 + \xi_{ijl},
    M \succeq 0,\  \xi_{ijl} \ge 0

Three speed-ups (details in the paper / at the poster; the active-set idea is sketched below):
- active set method
- cheap gradient update
- non-convex initialization
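A rough sketch of the active-set pattern under our own assumptions (the paper has the actual solver; this only illustrates maintaining and rescanning a working set of violated triplets):

```python
import numpy as np

def project_psd(M):
    """Project a symmetric matrix onto the PSD cone (clip negative eigenvalues)."""
    w, V = np.linalg.eigh((M + M.T) / 2.0)
    return (V * np.maximum(w, 0.0)) @ V.T

def active_set_lmnn(X, targets, candidate_triplets, mu=1.0, lr=1e-3,
                    outer=10, inner=50):
    """Optimize over currently violated triplets only, then rescan for new ones."""
    n, d = X.shape
    M = np.eye(d)
    active = set()
    for _ in range(outer):
        # rescan: add triplets that violate the margin under the current M
        for (i, j, l) in candidate_triplets:
            dij = (X[i] - X[j]) @ M @ (X[i] - X[j])
            dil = (X[i] - X[l]) @ M @ (X[i] - X[l])
            if dij + 1.0 > dil:
                active.add((i, j, l))
        # projected subgradient steps over the (much smaller) active set
        for _ in range(inner):
            # pull-term gradient (constant; recomputed here for clarity)
            G = sum(np.outer(X[i] - X[j], X[i] - X[j]) for i, j in targets)
            for (i, j, l) in active:
                dij = (X[i] - X[j]) @ M @ (X[i] - X[j])
                dil = (X[i] - X[l]) @ M @ (X[i] - X[l])
                if dij + 1.0 > dil:  # hinge is active for this triplet
                    G = G + mu * (np.outer(X[i] - X[j], X[i] - X[j])
                                  - np.outer(X[i] - X[l], X[i] - X[l]))
            M = project_psd(M - lr * G)
    return M
```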

Optimization on MNIST

  N        time     |active set|   |total set|   train error   test error
  60       9s       844            3.2K          0%             29.37%
  600      37s      6,169          323K          0%             10.79%
  6,000    4m       50,345         32M           0.48%          3.13%
  60,000   3h25m    540,037        3.2B          1.19%          1.72%

Three Questions

1. How can one apply LMNN to larger data sets? Use a special-purpose solver (come to the poster).
2. How can one make LMNN classification faster?
3. How can one make LMNN classification more accurate?

2. How can one make LMNN classification faster?

Ball-Trees [Deng and Moore, 1995]

Recursively partition the data into nested balls (in the figure: the original data, then 2 balls, 4 balls, ..., 155 balls).

Pruning Principle

Let d_best be the distance from the test point to the nearest neighbor found so far, and d_ball the distance from the test point to a ball of points.

If d_best ≤ d_ball, we can prune the entire ball!
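A minimal sketch of this rule inside a recursive 1-NN search, assuming a toy ball-tree node with `center`, `radius`, leaf `points`, and `children` (our own construction, not the [Deng and Moore, 1995] implementation):

```python
import numpy as np

class Ball:
    def __init__(self, center, radius, points=None, children=()):
        self.center, self.radius = center, radius
        self.points = points if points is not None else []  # leaf payload
        self.children = children

def nn_search(node, query, best=(np.inf, None)):
    """1-NN search with ball pruning: skip any ball that cannot contain
    a point closer than the best distance found so far."""
    d_ball = max(np.linalg.norm(query - node.center) - node.radius, 0.0)
    if best[0] <= d_ball:          # pruning principle: d_best <= d_ball
        return best
    for p in node.points:          # leaf: check points directly
        d = np.linalg.norm(query - p)
        if d < best[0]:
            best = (d, p)
    # visit the child whose center is nearer first (tightens d_best sooner)
    for child in sorted(node.children,
                        key=lambda c: np.linalg.norm(query - c.center)):
        best = nn_search(child, query, best)
    return best
```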

Catch: Dimensionality (1-NN search of MNIST images)

[Plot: ball-tree speedup vs. input dimensionality (after PCA); the speedup decays from 8.9x at 10 dimensions to 2.4x at 100 dimensions.]

To make kNN fast, we need to reduce the dimensionality!

LMNN-svd

1. Solve the SDP:

  \min_M \sum_{i,\, j \rightsquigarrow i} \|\vec{x}_i - \vec{x}_j\|_M^2 + \mu \sum_{i,j,l} \xi_{ijl}
  subject to: \forall i,\ j \rightsquigarrow i,\ l \notin C_i:
    \|\vec{x}_i - \vec{x}_j\|_M^2 + 1 \le \|\vec{x}_i - \vec{x}_l\|_M^2 + \xi_{ijl},   M \succeq 0,\  \xi_{ijl} \ge 0

2. Decompose M = L^\top L, so that

  \|\vec{x}_i - \vec{x}_j\|_M^2 = (\vec{x}_i - \vec{x}_j)^\top L^\top L (\vec{x}_i - \vec{x}_j) = \|L(\vec{x}_i - \vec{x}_j)\|_2^2,

  i.e. the metric acts as a linear map \vec{x} \to L\vec{x}.

3. Apply SVD to truncate L \in R^{d \times d} to L \in R^{r \times d}.
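Steps 2 and 3 in a short NumPy sketch, assuming M has already been obtained from the SDP (since M is symmetric PSD, an eigendecomposition serves as both the factorization and the SVD):

```python
import numpy as np

def metric_to_projection(M, r):
    """Factor a PSD metric M = L^T L and truncate L to its r leading rows.

    Distances under M equal Euclidean distances after mapping x -> L x,
    so the truncated L gives an r-dimensional embedding for ball-trees.
    """
    w, V = np.linalg.eigh(M)              # M = V diag(w) V^T, w ascending
    w, V = w[::-1], V[:, ::-1]            # sort eigenvalues descending
    L = np.sqrt(np.maximum(w, 0.0))[:, None] * V.T  # full L with M = L^T L
    return L[:r]                          # keep the r most informative rows

# usage: project the data, then run plain Euclidean k-NN / ball-trees on Z
# Z = X @ metric_to_projection(M, r=40).T
```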

LMNN-rect

Instead of truncating after the fact, we learn a rectangular linear transformation \vec{x} \to L\vec{x} with L \in R^{r \times d} directly, substituting \|\vec{x}_i - \vec{x}_j\|_M^2 = (\vec{x}_i - \vec{x}_j)^\top L^\top L (\vec{x}_i - \vec{x}_j) = \|L(\vec{x}_i - \vec{x}_j)\|_2^2 into the program:

  \min_L \sum_{i,\, j \rightsquigarrow i} \|L(\vec{x}_i - \vec{x}_j)\|_2^2 + \mu \sum_{i,j,l} \xi_{ijl}
  subject to: \forall i,\ j \rightsquigarrow i,\ l \notin C_i:
    \|L(\vec{x}_i - \vec{x}_j)\|_2^2 + 1 \le \|L(\vec{x}_i - \vec{x}_l)\|_2^2 + \xi_{ijl},   \xi_{ijl} \ge 0

(no longer convex)   [Torresani and Lee, NIPS 2006]
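Because the problem in L is non-convex but the slacks reduce to hinges, it can be attacked with plain (sub)gradient descent; a hedged sketch of one step under our own assumptions, not the Torresani and Lee solver:

```python
import numpy as np

def lmnn_rect_step(L, X, targets, impostors, mu=1.0, lr=1e-3):
    """One (sub)gradient step on the rectangular LMNN objective.
    L is r x d; the hinge replaces the slack variables xi_ijl."""
    G = np.zeros_like(L)
    for i, j in targets:                       # pull term
        d = X[i] - X[j]
        G += 2.0 * np.outer(L @ d, d)          # grad of ||L d||^2 is 2 (L d) d^T
    for i, j, l in impostors:                  # push term, only if hinge is active
        dj, dl = X[i] - X[j], X[i] - X[l]
        if (L @ dj) @ (L @ dj) + 1.0 > (L @ dl) @ (L @ dl):
            G += 2.0 * mu * (np.outer(L @ dj, dj) - np.outer(L @ dl, dl))
    return L - lr * G
```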

LMNN for Ball-trees

[Plot: MNIST 3-NN classification error (%) after dimensionality reduction (train=60K, test=10K, input d=350), comparing pca, lda, LMNN-svd, and LMNN-rec at 15-40 input dimensions against the full-dimensional baselines pca (d=350) and lmnn (d=350, 1.72% error). The best reduced-dimension curves stay near 1.8% error, close to the full lmnn baseline, while a second panel shows the corresponding ball-tree speedups ranging from 15x at 15 dimensions down to 3x at 40.]

Three Questions

1. How can one apply LMNN to larger data sets? Use a special-purpose solver (come to the poster).
2. How can one make LMNN classification faster? Learn a rectangular matrix, then use ball-trees.
3. How can one make LMNN classification more accurate?

3. How can one make LMNN classification more accurate?

Limits of a global linear metric

1 metric: 1-NN error 100%. 3 metrics: 1-NN error 0%.

Can we apply LMNN locally?

Suppose we partition the input space into regions with local metrics M_1, ..., M_4. The problem: distances measured under different local metrics are not comparable, e.g. \|\vec{x}_i - \vec{x}_l\|_{M_2} vs. \|\vec{x}_i - \vec{x}_j\|_{M_1}.

Solution: train all metrics jointly and enforce large-margin distance constraints across them!

Multiple Metrics-LMNN

We compute distances under local metrics, where c(j) denotes the cluster of point \vec{x}_j:

  \forall i,\ j \rightsquigarrow i,\ l \notin C_i:
    \|\vec{x}_i - \vec{x}_j\|_{M_{c(j)}}^2 + 1 \le \|\vec{x}_i - \vec{x}_l\|_{M_{c(l)}}^2 + \xi_{ijl}
    M_1, \ldots, M_m \succeq 0,\  \xi_{ijl} \ge 0

The joint problem is again a semidefinite program:

  \min_{M_1, \ldots, M_m} \sum_{i,\, j \rightsquigarrow i} \|\vec{x}_i - \vec{x}_j\|_{M_{c(j)}}^2 + \mu \sum_{i,j,l} \xi_{ijl}

  subject to: \forall i,\ j \rightsquigarrow i,\ l \notin C_i:
    \|\vec{x}_i - \vec{x}_j\|_{M_{c(j)}}^2 + 1 \le \|\vec{x}_i - \vec{x}_l\|_{M_{c(l)}}^2 + \xi_{ijl}
    M_1, \ldots, M_m \succeq 0,\  \xi_{ijl} \ge 0
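A small sketch of how distances are evaluated under cluster-specific metrics, assuming `cluster` maps a point index to its metric index (illustrative only):

```python
import numpy as np

def mm_dist2(X, M_list, cluster, i, j):
    """Squared distance from x_i to x_j under the metric of x_j's cluster."""
    d = X[i] - X[j]
    return d @ M_list[cluster[j]] @ d

def margin_violation(X, M_list, cluster, i, j, l):
    """Hinge value of the MM-LMNN constraint for triplet (i, j, l):
    positive means the impostor x_l is not yet a full margin farther
    away than the target neighbor x_j."""
    return max(0.0, mm_dist2(X, M_list, cluster, i, j) + 1.0
                    - mm_dist2(X, M_list, cluster, i, l))
```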

Multiple Metrics

Because all metrics are trained jointly under shared margin constraints, distances such as \|\vec{x}_i - \vec{x}_l\|_{M_2} and \|\vec{x}_i - \vec{x}_j\|_{M_1} can now be compared.

How should one divide the training data?

[Plots: % classification error vs. number of clusters (1-30) on MNIST (n=10000) and ISOLET, showing mean and standard deviation over runs, with reference lines for 1 global metric and 1 metric per class.]

Answer: one metric per class.

Comparison with LMNN

[Bar chart: classification error (%) of PCA, LMNN, and MM-LMNN on mnist, letters, 20news, isolet, and yalefaces; MM-LMNN obtains the lowest error on every data set.]

Comparison with SVMs

[Bar chart: classification error (%) of MCSVM [Crammer and Singer 2001] vs. MM-LMNN on mnist, letters, 20news, isolet, and yalefaces; MM-LMNN is competitive with the multiclass SVM across data sets.]

Three Questions

1. How can one apply LMNN to larger data sets? Use a special-purpose solver (come to the poster).
2. How can one make LMNN classification faster? Learn a rectangular matrix, then use ball-trees.
3. How can one make LMNN classification more accurate? Use multiple locally linear metrics.

Conclusion

LMNN ...
• ... scales to large data sets
• ... has fast test-time performance
• ... obtains state-of-the-art accuracies.

Thank you!