ICML, July 6, 2008

Fast Solvers and Efficient Implementations for Distance Metric Learning

Kilian Q. Weinberger
Lawrence K. Saul
Mahalanobis distance

Euclidean distance:
$\|\vec{x}_i - \vec{x}_j\|_2 = \sqrt{(\vec{x}_i - \vec{x}_j)^\top (\vec{x}_i - \vec{x}_j)}$

Mahalanobis distance:
$\|\vec{x}_i - \vec{x}_j\|_M = \sqrt{(\vec{x}_i - \vec{x}_j)^\top M (\vec{x}_i - \vec{x}_j)}$,
where $M \succeq 0$ is positive semi-definite.

[Figure: the unit circle under the Euclidean metric, and the ellipse it becomes under a Mahalanobis metric]
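A minimal sketch of both distances in Python (NumPy assumed); the matrix `M` below is an arbitrary illustrative PSD matrix, not one learned from data:

```python
import numpy as np

def euclidean_dist(xi, xj):
    # ||xi - xj||_2 = sqrt((xi - xj)^T (xi - xj))
    d = xi - xj
    return np.sqrt(d @ d)

def mahalanobis_dist(xi, xj, M):
    # ||xi - xj||_M = sqrt((xi - xj)^T M (xi - xj)); M must be PSD
    d = xi - xj
    return np.sqrt(d @ M @ d)

# Example with an illustrative PSD matrix M = A^T A
rng = np.random.default_rng(0)
A = rng.standard_normal((2, 2))
M = A.T @ A
xi, xj = np.array([1.0, 0.0]), np.array([0.0, 1.0])
print(euclidean_dist(xi, xj), mahalanobis_dist(xi, xj, M))
```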
k-NN classification: adapt the metric to the data

• Amplify informative directions
• Squash non-informative directions

Prior work: Shental et al. [ECCV 2002], Xing et al. [NIPS 2003], Bilenko et al. [ICML 2004], Shalev-Shwartz et al. [ICML 2004], Goldberger et al. [NIPS 2005], Weinberger et al. [NIPS 2006], Globerson and Roweis [NIPS 2006], Davis et al. [ICML 2007]
Large margin nearest neighbor (LMNN) [Weinberger et al., NIPS 2006]

Goal: minimize the leave-one-out k-NN error.

1. Pick target neighbors $\vec{x}_j \rightsquigarrow \vec{x}_i$ (same-class neighbors that should stay close); see the sketch after this list.
2. Learn a metric such that:
   • no impostors $\vec{x}_l$ are closer,
   • by a large margin.

[Figure: point $\vec{x}_i$ with target neighbor $\vec{x}_j$ and impostor $\vec{x}_l$]
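Target neighbors are fixed before learning; a common choice, and the one the LMNN paper uses, is each point's k nearest same-class neighbors under the Euclidean metric. A minimal sketch (function name illustrative):

```python
import numpy as np

def pick_target_neighbors(X, y, k=3):
    """For each point, return indices of its k nearest same-class neighbors.
    Assumes every class has more than k members."""
    n = X.shape[0]
    targets = np.zeros((n, k), dtype=int)
    for i in range(n):
        same = np.flatnonzero((y == y[i]) & (np.arange(n) != i))
        dists = np.linalg.norm(X[same] - X[i], axis=1)
        targets[i] = same[np.argsort(dists)[:k]]
    return targets
```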
Large margin nearest neighbor (LMNN) [Weinberger et al., NIPS 2006]

Margin constraints: for every point $\vec{x}_i$, every target neighbor $\vec{x}_j \rightsquigarrow \vec{x}_i$, and every differently labeled point $l \notin C_i$,

$\|\vec{x}_i - \vec{x}_j\|_M^2 + 1 \le \|\vec{x}_i - \vec{x}_l\|_M^2 + \xi_{ijl}$
Large margin nearest neighbor (LMNN) [Weinberger et al., NIPS 2006]

Semi-definite program:

$\min_M \sum_{i,\, j \rightsquigarrow i} \|\vec{x}_i - \vec{x}_j\|_M^2 + \mu \sum_{i,j,l} \xi_{ijl}$

subject to: $\forall i,\ j \rightsquigarrow i,\ l \notin C_i$,
$\|\vec{x}_i - \vec{x}_j\|_M^2 + 1 \le \|\vec{x}_i - \vec{x}_l\|_M^2 + \xi_{ijl}$
$M \succeq 0,\quad \xi_{ijl} \ge 0$
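At the optimum the slacks satisfy the hinge identity $\xi_{ijl} = \max(0, \|\vec{x}_i - \vec{x}_j\|_M^2 + 1 - \|\vec{x}_i - \vec{x}_l\|_M^2)$, so the objective can be evaluated without explicit slack variables. A hedged sketch of that evaluation for a fixed PSD matrix M; the triplet enumeration is deliberately naive and not the solver used in the paper:

```python
import numpy as np

def lmnn_loss(M, X, y, targets, mu=1.0):
    """Hinge-loss form of the LMNN objective for a fixed PSD matrix M."""
    def sqdist(i, j):
        d = X[i] - X[j]
        return d @ M @ d

    pull, push = 0.0, 0.0
    n = X.shape[0]
    for i in range(n):
        for j in targets[i]:                        # target neighbors j ~> i
            pull += sqdist(i, j)
            for l in np.flatnonzero(y != y[i]):     # candidate impostors
                # hinge: max(0, ||xi-xj||_M^2 + 1 - ||xi-xl||_M^2)
                push += max(0.0, sqdist(i, j) + 1.0 - sqdist(i, l))
    return pull + mu * push
```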
Three Questions

1. How can one apply LMNN to larger data sets?
2. How can one make LMNN classification faster?
3. How can one make LMNN classification more accurate?
1. How can one apply LMNN to larger data sets?

Recall the semi-definite program:

$\min_M \sum_{i,\, j \rightsquigarrow i} \|\vec{x}_i - \vec{x}_j\|_M^2 + \mu \sum_{i,j,l} \xi_{ijl}$

subject to: $\forall i,\ j \rightsquigarrow i,\ l \notin C_i$,
$\|\vec{x}_i - \vec{x}_j\|_M^2 + 1 \le \|\vec{x}_i - \vec{x}_l\|_M^2 + \xi_{ijl}$
$M \succeq 0,\quad \xi_{ijl} \ge 0$
Three speed-ups:
• active set method
• cheap gradient update
• non-convex initialization
(details in the paper / at the poster; a sketch of the active-set idea follows)
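A hedged sketch of the active-set idea, under the assumption that only a small fraction of the triplet constraints are ever violated: keep a working set of violated triplets, optimize over that set only, and periodically rescan for new violations. This is schematic, not the authors' exact solver; `optimize` and `find_violations` are assumed subroutines:

```python
def active_set_solve(optimize, find_violations, max_iters=50):
    """Schematic active-set loop for LMNN.

    optimize(active)    -- assumed subroutine: minimize the objective using
                           only the triplets in `active`, return the metric M
    find_violations(M)  -- assumed subroutine: return the set of triplets
                           whose margin constraint M currently violates
    """
    active = set()
    M = None
    for _ in range(max_iters):
        M = optimize(active)
        new = find_violations(M) - active
        if not new:          # no newly violated constraints: done
            break
        active |= new        # grow the working set and re-solve
    return M
```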
Optimization on MNIST

  N      | time  | |active set| | |total set| | train error | test error
  60     | 9s    | 844         | 3.2K        | 0%          | 29.37%
  600    | 37s   | 6,169       | 323K        | 0%          | 10.79%
  6,000  | 4m    | 50,345      | 32M         | 0.48%       | 3.13%
  60,000 | 3h25m | 540,037     | 3.2B        | 1.19%       | 1.72%
Three Questions

1. How can one apply LMNN to larger data sets?
   Use a special-purpose solver (come to the poster).
2. How can one make LMNN classification faster?
3. How can one make LMNN classification more accurate?
2. How can one make LMNN classification faster?
Ball-trees [Deng and Moore, 1995]

[Figure: the original data recursively partitioned into 2 balls, 4 balls, and eventually 155 balls]
Pruning principle

Keep the nearest neighbor found so far, at distance $d_{best}$ from the test point. Let $d_{ball}$ be the distance from the test point to a ball of points.

If $d_{best} \le d_{ball}$, we can prune the entire ball!
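A minimal sketch of this pruning test inside a 1-NN search, assuming a simple ball-tree node layout (center, radius, children, points) invented here for illustration:

```python
import numpy as np

def nn_search(node, q, best_dist=np.inf, best_pt=None):
    """1-NN search with ball pruning; `node` layout is illustrative."""
    # Distance from query to the closest possible point inside this ball
    d_ball = max(0.0, np.linalg.norm(q - node.center) - node.radius)
    if best_dist <= d_ball:
        return best_dist, best_pt           # prune: the ball cannot help
    if node.children:                       # internal node: nearer ball first
        for child in sorted(node.children,
                            key=lambda c: np.linalg.norm(q - c.center)):
            best_dist, best_pt = nn_search(child, q, best_dist, best_pt)
    else:                                   # leaf: scan its points
        for p in node.points:
            d = np.linalg.norm(q - p)
            if d < best_dist:
                best_dist, best_pt = d, p
    return best_dist, best_pt
```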
Catch: dimensionality (1-NN search of MNIST images)

[Figure: ball-tree speedup vs. PCA dimensionality; the speedup falls monotonically from about 8.9x at d=10 to about 2.4x at d=100]

To make kNN fast, we need to reduce the dimensionality!
LMNN-svd

1. Solve the SDP:
$\min_M \sum_{i,\, j \rightsquigarrow i} \|\vec{x}_i - \vec{x}_j\|_M^2 + \mu \sum_{i,j,l} \xi_{ijl}$
subject to: $\forall i,\ j \rightsquigarrow i,\ l \notin C_i$,
$\|\vec{x}_i - \vec{x}_j\|_M^2 + 1 \le \|\vec{x}_i - \vec{x}_l\|_M^2 + \xi_{ijl}$, $M \succeq 0$, $\xi_{ijl} \ge 0$

2. Decompose $M = L^\top L$, so that
$\|\vec{x}_i - \vec{x}_j\|_M^2 = (\vec{x}_i - \vec{x}_j)^\top L^\top L (\vec{x}_i - \vec{x}_j) = \|L(\vec{x}_i - \vec{x}_j)\|_2^2$,
i.e. the metric acts as a linear map $\vec{x} \to L\vec{x}$.

3. Apply SVD to truncate: $L \in \mathbb{R}^{d \times d} \to L \in \mathbb{R}^{r \times d}$.
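A minimal sketch of steps 2 and 3, assuming M has already been solved for: factor M via its eigendecomposition (equivalent to an SVD for a symmetric PSD matrix) and keep only the top r directions. Function name is illustrative:

```python
import numpy as np

def truncate_metric(M, r):
    """Factor PSD M = L^T L and truncate L to its top r rows."""
    # Eigendecomposition of a symmetric PSD matrix (eigenvalues ascending)
    vals, vecs = np.linalg.eigh(M)
    vals = np.clip(vals, 0.0, None)          # guard against tiny negatives
    order = np.argsort(vals)[::-1]           # strongest directions first
    # Square factor: rows of L are sqrt(lambda_i) * v_i^T, so L^T L = M
    L = np.sqrt(vals[order])[:, None] * vecs[:, order].T
    return L[:r]                             # rectangular L in R^{r x d}

# Usage: project the data before building the ball-tree
# X_low = X @ truncate_metric(M, r=30).T
```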
LMNN-rect

We learn the linear transformation $\vec{x} \to L\vec{x}$ directly:

$\min_L \sum_{i,\, j \rightsquigarrow i} \|L(\vec{x}_i - \vec{x}_j)\|^2 + \mu \sum_{i,j,l} \xi_{ijl}$

subject to: $\forall i,\ j \rightsquigarrow i,\ l \notin C_i$,
$\|L(\vec{x}_i - \vec{x}_j)\|^2 + 1 \le \|L(\vec{x}_i - \vec{x}_l)\|^2 + \xi_{ijl}$, $\xi_{ijl} \ge 0$

(no longer convex)
LMNN-rect

The same optimization over a rectangular $L \in \mathbb{R}^{r \times d}$ learns the metric and the dimensionality reduction jointly [Torresani and Lee, NIPS 2006].
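Since the problem in L is non-convex but becomes unconstrained once the slacks are folded into a hinge, it can be attacked with plain (sub)gradient descent. A hedged sketch of one step, with naive triplet enumeration and an illustrative step size, not the authors' exact implementation:

```python
import numpy as np

def lmnn_rect_step(L, X, y, targets, mu=1.0, lr=1e-3):
    """One subgradient step on the hinge-loss form of LMNN-rect.
    Uses d/dL ||L d||^2 = 2 (L d) d^T."""
    grad = np.zeros_like(L)
    n = X.shape[0]
    for i in range(n):
        for j in targets[i]:
            dij = X[i] - X[j]
            grad += 2.0 * np.outer(L @ dij, dij)        # pull term
            for l in np.flatnonzero(y != y[i]):
                dil = X[i] - X[l]
                margin = (np.sum((L @ dij) ** 2) + 1.0
                          - np.sum((L @ dil) ** 2))
                if margin > 0:                           # active hinge
                    grad += 2.0 * mu * (np.outer(L @ dij, dij)
                                        - np.outer(L @ dil, dil))
    return L - lr * grad
```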
LMNN for ball-trees

[Figure: MNIST 3-NN classification error (%) after dimensionality reduction to d=15..40, comparing PCA, LDA, LMNN-svd, and LMNN-rect. Baselines at d=350: PCA 2.38%, LMNN 2.09% (1.72% on the full data). LMNN-rect stays near 1.8% error even at low dimensionality, with ball-tree speedups ranging from 15x at d=15 down to 3x at d=40.]

Input: train=60K, test=10K, d=350
Three Questions

1. How can one apply LMNN to larger data sets?
   Use a special-purpose solver (come to the poster).
2. How can one make LMNN classification faster?
   Learn a rectangular matrix, then use ball-trees.
3. How can one make LMNN classification more accurate?
3. How can one make LMNN classification more accurate?
Limits of a global linear metric

[Figure: a data set where a single global metric yields 100% 1-NN error, but three local metrics yield 0% 1-NN error]

1 metric: 1-NN error 100%. 3 metrics: 1-NN error 0%.
Can we apply LMNN locally?

[Figure: data partitioned into clusters with local metrics $M_1, \dots, M_4$]

If each local metric is trained separately, we cannot compare the local metrics: is $\|\vec{x}_i - \vec{x}_l\|_{M_2}$ larger or smaller than $\|\vec{x}_i - \vec{x}_j\|_{M_1}$?

Solution: we train all metrics jointly and enforce large-margin distance constraints across them!
Multiple Metrics-LMNN

We compute distances under local metrics:

$\forall i,\ j \rightsquigarrow i,\ l \notin C_i$,
$\|\vec{x}_i - \vec{x}_j\|_{M_{c(j)}}^2 + 1 \le \|\vec{x}_i - \vec{x}_l\|_{M_{c(l)}}^2 + \xi_{ijl}$
$M_1, \dots, M_m \succeq 0,\quad \xi_{ijl} \ge 0$, where $c(j)$ = cluster of point $\vec{x}_j$
Multiple Metrics-LMNN

Semi-definite program:

$\min_{M_1, \dots, M_m} \sum_{i,\, j \rightsquigarrow i} \|\vec{x}_i - \vec{x}_j\|_{M_{c(j)}}^2 + \mu \sum_{i,j,l} \xi_{ijl}$

subject to: $\forall i,\ j \rightsquigarrow i,\ l \notin C_i$,
$\|\vec{x}_i - \vec{x}_j\|_{M_{c(j)}}^2 + 1 \le \|\vec{x}_i - \vec{x}_l\|_{M_{c(l)}}^2 + \xi_{ijl}$
$M_1, \dots, M_m \succeq 0,\quad \xi_{ijl} \ge 0$, where $c(j)$ = cluster of point $\vec{x}_j$
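A hedged sketch of test-time classification once the local metrics are learned: the distance from a query to a stored point is measured under the metric of that point's cluster, which is exactly what the joint training makes comparable. Names are illustrative:

```python
import numpy as np

def mm_lmnn_knn(q, X, y, metrics, cluster, k=3):
    """k-NN under multiple local metrics.
    metrics[c] is the PSD matrix for cluster c; cluster[j] = c(j)."""
    diffs = X - q
    # Squared distance to each training point under that point's own metric
    d2 = np.array([d @ metrics[cluster[j]] @ d for j, d in enumerate(diffs)])
    nn = np.argsort(d2)[:k]
    labels, counts = np.unique(y[nn], return_counts=True)
    return labels[np.argmax(counts)]          # majority vote
```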
Multiple Metrics

[Figure: the same clusters with local metrics $M_1, \dots, M_4$]

Because the metrics were trained jointly, we can now compare distances under any two metrics, e.g. $\|\vec{x}_i - \vec{x}_l\|_{M_2}$ against $\|\vec{x}_i - \vec{x}_j\|_{M_1}$.
How should one divide the training data?

[Figure: % classification error vs. number of clusters on MNIST (n=10,000) and ISOLET, showing mean and standard deviation, with "1 metric per class" and "1 global metric" baselines]

Answer: one metric per class.
Comparison with LMNN

[Bar chart: classification error (%) of PCA, LMNN, and MM-LMNN on mnist, letters, 20news, isolet, and yalefaces; MM-LMNN obtains the lowest error on each data set shown]
Comparison with SVMs

[Bar chart: classification error (%) of MCSVM [Crammer and Singer, 2001] vs. MM-LMNN on mnist, letters, 20news, isolet, and yalefaces; MM-LMNN is competitive with the multi-class SVM throughout]
Three Questions

1. How can one apply LMNN to larger data sets?
   Use a special-purpose solver (come to the poster).
2. How can one make LMNN classification faster?
   Learn a rectangular matrix, then use ball-trees.
3. How can one make LMNN classification more accurate?
   Use multiple locally linear metrics.
Conclusion

LMNN ...
• ... scales to large data sets,
• ... has fast test-time performance,
• ... obtains state-of-the-art accuracies.
Thank you!