Neural Computation 10 (1998), pp. 2137-2157

Complexity Issues in Natural Gradient Descent Method for Training Multi-Layer Perceptrons

Howard Hua Yang
Department of Computer Science, Oregon Graduate Institute
PO Box 91000, Portland, OR 97291, USA
[email protected]  FAX: +503 690 1548

Shun-ichi Amari
Lab. for Information Synthesis, RIKEN Brain Science Institute
Hirosawa 2-1, Wako-shi, Saitama 351-01, JAPAN
[email protected]

Abstract
The natural gradient descent method is applied to train an n-m-1 multi-layer perceptron. Based on an efficient scheme to represent the Fisher information matrix for an n-m-1 stochastic multi-layer perceptron, a new algorithm is proposed to calculate the natural gradient without inverting the Fisher information matrix explicitly. When the input dimension n is much larger than the number of hidden neurons m, the time complexity of computing the natural gradient is O(n).
1 Introduction

Amari (1997, 1998) has shown that the natural gradient descent learning rule is statistically efficient. In principle, this learning rule can be used to train any adaptive system, but its complexity depends on the architecture of the learning machine. The main difficulty in implementing this learning rule is to design a fast algorithm to compute the natural gradient. For an n-m-1 multi-layer perceptron, it is shown in this paper that the natural gradient can be computed in $O(n)$ flops when $m \ll n$. Here, a flop is a floating-point operation: an addition or a multiplication. Orr and Leen (1997) used curvature information (the Hessian matrix) in a nonlinear adaptive momentum scheme to optimize the convergence of stochastic gradient descent. They have shown that the complexity of their algorithm is $O(n)$. However, their algorithm is different from natural gradient descent. Instead of using a momentum scheme, we use the inverse of the Fisher information matrix to transform the stochastic gradient to optimize
the learning dynamics. This method is based on statistical inference. A review of learning in artificial neural networks is given in (Yang et al., 1998). The rest of this paper is organized as follows. The stochastic multi-layer perceptron is described in section 2. The natural gradient descent method using the inverse of the Fisher information matrix and its properties are discussed in section 3. A constructive procedure to compute the inverse of the Fisher information matrix is formulated in section 4. For a single-layer perceptron, we calculate an explicit expression of the natural gradient in section 4.1. It is obvious from this expression that the time complexity of computing the natural gradient for the single-layer perceptron is $O(n)$. Importantly, this is also true for a multi-layer perceptron, but its proof is not straightforward. To prove this result, we analyze the structure of the Fisher information matrix in section 4.2. Based on this analysis, we discuss the time complexities of computing the inverse of the Fisher information matrix and the natural gradient for the multi-layer perceptron in section 5. We first discuss the simplest multi-layer perceptron, a committee machine, in section 5.1, and then the general multi-layer perceptron in section 5.2. Finally, the conclusions are summarized in section 6.
2 Stochastic Multi-layer Perceptron

We consider the following stochastic multi-layer perceptron model:
$$ z = \sum_{i=1}^{m} a_i \varphi(w_i^T x + b_i) + \xi \qquad (1) $$
where $(\cdot)^T$ denotes the transpose, $\xi \sim N(0, \sigma^2)$ is Gaussian additive noise with variance $\sigma^2$, and $\varphi(x)$ is a differentiable output function of the hidden neurons. Assume that the multi-layer network has an $n$-dimensional input, $m$ hidden neurons, a one-dimensional output, and $m \ll n$. Denote by $a = (a_1, \ldots, a_m)^T$ the weight vector of the output neuron, by $w_i = (w_{1i}, \ldots, w_{ni})^T$ the weight vector of the $i$-th hidden neuron, and by $b = (b_1, \ldots, b_m)^T$ the vector of thresholds of the hidden neurons. Let $W = [w_1, \ldots, w_m]$ be the matrix formed by the column weight vectors $w_i$; then equation (1) can be rewritten as
$$ z = a^T \varphi(W^T x + b) + \xi. $$
Here, the scalar function $\varphi$ operates on each component of the vector $W^T x + b$.
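To make the model concrete, the following minimal NumPy sketch (added here for illustration; it is not part of the original text) draws one sample from the stochastic n-m-1 perceptron of equation (1). The choice $\varphi(x) = \mathrm{erf}(x/\sqrt{2})$ anticipates the sigmoid used later in section 4.

```python
import numpy as np
from scipy.special import erf

def phi(u):
    # Scaled error function used as the hidden-unit nonlinearity (see section 4).
    return erf(u / np.sqrt(2.0))

def sample_output(x, W, a, b, sigma, rng):
    """Draw z = a^T phi(W^T x + b) + xi with xi ~ N(0, sigma^2) -- equation (1)."""
    hidden = phi(W.T @ x + b)          # m hidden activations
    return float(a @ hidden) + sigma * rng.standard_normal()

# Example: n = 1000 inputs, m = 5 hidden neurons.
rng = np.random.default_rng(0)
n, m = 1000, 5
W = rng.standard_normal((n, m)) / np.sqrt(n)
a, b = rng.standard_normal(m), rng.standard_normal(m)
x = rng.standard_normal(n)
z = sample_output(x, W, a, b, sigma=0.1, rng=rng)
```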
3 Fisher information matrix and natural gradient descent

The stochastic perceptron model enables us to derive learning algorithms for neural networks from statistical inference methods and to evaluate the performance of the algorithms systematically. The joint probability density function (pdf) of the input and the output is
$$ p(x, z; W, a, b) = p(z|x; W, a, b)\, p(x) $$
where $p(z|x; W, a, b)$ is the conditional pdf of the output $z$ given the input $x$, and $p(x)$ is the pdf of the input $x$. The loss function is defined as the negative log-likelihood function
$$ L(x, z; \theta) = -\log p(x, z; \theta) = l(z|x; \theta) - \log p(x) $$
where
$$ \theta = (w_1^T, \ldots, w_m^T, a^T, b^T)^T, $$
and
$$ l(z|x; \theta) = -\log p(z|x; \theta) = \frac{1}{2\sigma^2}\,(z - a^T \varphi(W^T x + b))^2. $$
Since $p(x)$ does not depend on $\theta$, minimizing the loss function $L(x, z; \theta)$ is equivalent to minimizing $l(z|x; \theta)$. Given the training set
$$ D_T = \{(x_t, z_t),\; t = 1, \ldots, T\}, $$
minimizing the loss function $L(D_T, \theta) = \sum_{t=1}^{T} L(x_t, z_t; \theta)$ is equivalent to minimizing the training error
$$ E_{\mathrm{tr}}(\theta; D_T) = \frac{1}{T} \sum_{t=1}^{T} (z_t - a^T \varphi(W^T x_t + b))^2. $$
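As a quick illustration (a minimal sketch added here, reusing `phi` from the previous snippet and assuming a batch `X` of shape `T x n` with targets `Z` of length `T`), the training error $E_{\mathrm{tr}}$ is evaluated from the deterministic part of the model:

```python
def training_error(X, Z, W, a, b):
    """Mean squared error E_tr(theta; D_T) over a training set X (T x n), Z (T,)."""
    preds = phi(X @ W + b) @ a        # a^T phi(W^T x_t + b) for every t
    return float(np.mean((Z - preds) ** 2))
```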
Since $\frac{\partial L}{\partial \theta} = \frac{\partial l}{\partial \theta}$, the Fisher information matrix is defined by
$$ G = G(\theta) = E\!\left[ \frac{\partial L}{\partial \theta} \left( \frac{\partial L}{\partial \theta} \right)^{T} \right] = E\!\left[ \frac{\partial l}{\partial \theta} \left( \frac{\partial l}{\partial \theta} \right)^{T} \right] \qquad (2) $$
where $E[\cdot]$ denotes the expectation with respect to $p(x, z; \theta)$. Let $\Theta$ be a parameter set and $P = \{p(x, z; \theta) : \theta \in \Theta\}$ be a family of pdfs parameterized by $\theta \in \Theta$. The Kullback-Leibler divergence between two pdfs in $P$ is
$$ D(\theta, \theta') = \int dx\, dz\; p(x, z; \theta) \log \frac{p(x, z; \theta)}{p(x, z; \theta')}. $$
Let $G(\theta)$ be the Riemannian metric tensor for the Riemannian space $\Theta$. The squared length of a small $d\theta$ in $\Theta$ is
$$ \| d\theta \|_G^2 = d\theta^T G(\theta)\, d\theta. $$
The two spaces $P$ and $\Theta$ are naturally related by the following equation:
$$ D(\theta, \theta') = \frac{1}{2}\, d\theta^T G(\theta)\, d\theta + O(\| d\theta \|^3) $$
where $\theta' = \theta + d\theta$ and $d\theta$ is small. Amari (1998) showed that in the Riemannian space $\Theta$, for any differentiable loss function $F(\theta)$, the steepest descent direction is given by
$$ -G^{-1}(\theta)\, \frac{\partial F}{\partial \theta}. $$
Based on this result, Amari (1997) proposed the following natural gradient descent algorithm:
$$ \theta_{t+1} = \theta_t - \eta_t\, G^{-1}(\theta_t)\, \frac{\partial F}{\partial \theta} $$
where $\eta_t$ is a learning rate. In particular, when the negative log-likelihood function is chosen as the loss function, the above algorithm is known as the method of scoring in statistics.
Amari (1998) showed that this algorithm gives a Fisher efficient on-line estimator, i.e., the asymptotic variance of $\theta_t$ driven by (6) satisfies
$$ E[(\theta_t - \theta^*)(\theta_t - \theta^*)^T \,|\, \theta^*] \simeq \frac{1}{t}\, G^{-1}(\theta^*) \qquad (3) $$
which gives the mean square error
$$ E[\| \theta_t - \theta^* \|^2 \,|\, \theta^*] \simeq \frac{1}{t}\, \mathrm{Tr}(G^{-1}(\theta^*)). \qquad (4) $$
The above property is verified by the simulation results in (Yang and Amari, 1998). Yang and Amari (1997a) have shown that $G(\theta) = \frac{1}{\sigma^2} A(\theta)$, where $A(\theta)$ does not depend on $\sigma^2$ (the variance of the additive noise). Define
$$ l_1(z|x; \theta) = \frac{1}{2}\,(z - a^T \varphi(W^T x + b))^2 $$
and
$$ \tilde{\nabla} l_1(z|x; \theta) = A^{-1}(\theta)\, \frac{\partial l_1}{\partial \theta}(z|x; \theta). \qquad (5) $$
To minimize the training error, we propose the following natural gradient descent algorithm:
$$ \theta_{t+1} = \theta_t - \eta_t\, \tilde{\nabla} l_1(z_t|x_t; \theta_t) \qquad (6) $$
where the natural gradient $\tilde{\nabla} l_1(z_t|x_t; \theta_t)$ does not depend on $\sigma^2$. To implement this algorithm, we need to compute the natural gradient $\tilde{\nabla} l_1(z_t|x_t; \theta_t)$ in each iteration. If we compute the inverse of $A(\theta_t)$ first and then the natural gradient, the time complexity is $O(n^3)$. We proposed a method in (Yang and Amari, 1997b), based on the conjugate gradient method, to compute the natural gradient without inverting the Fisher information matrix. The idea is to solve the linear equation
$$ A(\theta_t)\, y = \frac{\partial l_1}{\partial \theta}(z_t|x_t; \theta_t) $$
for $y = \tilde{\nabla} l_1(z_t|x_t; \theta_t)$ without inverting the matrix $A(\theta_t)$. Since the matrix $A(\theta)$ is $(n+2)m \times (n+2)m$, it takes at most $(n+2)m$ steps of the conjugate gradient method to compute $\tilde{\nabla} l_1(z_t|x_t; \theta_t)$, and each step needs $O(n)$ flops to compute a matrix-vector product. So the number of flops needed to compute the natural gradient by the conjugate gradient method is $O(n^2)$ when $m \ll n$ and $O(n^3)$ when $m = O(n)$. Using the conjugate gradient method to compute the natural gradient is also useful for other probability families when the method of scoring is used to find the maximum likelihood estimates. However, the conjugate gradient method is still slow in computing the natural gradient for training the stochastic perceptron. We shall exploit the structure of the Fisher information matrix of the stochastic perceptron and develop an algorithm which computes the natural gradient in just $O(n)$ flops.
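The following sketch (an illustration added here, not code from the paper) shows how the linear system $A(\theta_t)\, y = \partial l_1 / \partial \theta$ can be solved by conjugate gradients using only matrix-vector products with $A$, so that $A$ is never inverted or even formed explicitly. The name `matvec_A` is a placeholder for any routine that multiplies $A(\theta_t)$ by a vector, e.g. one built from the low-rank representation of section 4.

```python
import numpy as np

def conjugate_gradient(matvec_A, g, tol=1e-10, max_steps=None):
    """Solve A y = g for y, given only y -> A @ y (A symmetric positive definite).

    Each iteration costs one call to matvec_A plus O(dim) vector work, so the
    total cost is (number of steps) x (cost of one matrix-vector product).
    """
    dim = g.shape[0]
    max_steps = dim if max_steps is None else max_steps
    y = np.zeros(dim)
    r = g.copy()           # residual g - A y
    p = r.copy()           # search direction
    rs = r @ r
    for _ in range(max_steps):
        if np.sqrt(rs) < tol:
            break
        Ap = matvec_A(p)
        alpha = rs / (p @ Ap)
        y += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return y
```

When `matvec_A` exploits the block structure described in section 4, each product costs $O(n)$ flops for $m \ll n$, which is what yields the overall $O(n^2)$ count quoted above.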
4 Computing the natural gradient for the stochastic perceptron

To find an analytic form of the Fisher information matrix, we assume a white Gaussian input $x \sim N(0, I)$ with the identity matrix $I$ as the covariance matrix of the input. Yang and Amari (1997a) proposed an efficient scheme to represent the Fisher information matrix of the multi-layer perceptron. Due to this scheme, when $m \ll n$ the storage space needed for the Fisher information matrix is $O(n)$ units rather than $O(n^2)$ units, where a unit is the memory space needed to keep one parameter or variable. Based on this scheme, we found a constructive procedure to compute the inverse of the Fisher information matrix with $O(n^2)$ flops. This procedure can be improved to compute the natural gradient with only $O(n)$ flops.
4.1 Single-layer perceptron
We shall give explicit expressions for the Fisher information matrix and the natural gradient of the single-layer stochastic perceptron. Since $z = \varphi(w^T x + b) + \xi$, we have
$$ \frac{\partial l}{\partial w} = -\frac{\xi}{\sigma^2}\, \varphi'(w^T x + b)\, x, \qquad \frac{\partial l}{\partial b} = -\frac{\xi}{\sigma^2}\, \varphi'(w^T x + b). $$
The Fisher information matrix is
$$ G(\theta) = \frac{1}{\sigma^2} A(\theta) = \frac{1}{\sigma^2} \begin{bmatrix} A_{11} & a_{12} \\ a_{21} & a_{22} \end{bmatrix} $$
where $\theta = (w^T, b)^T$,
$$ A_{11} = E[(\varphi'(w^T x + b))^2\, x x^T], \quad a_{12} = a_{21}^T = E[(\varphi'(w^T x + b))^2\, x], \quad a_{22} = E[(\varphi'(w^T x + b))^2]. $$
Let $w = \|w\|$ and $u_1 = w/w$, and extend $u_1$ to an orthonormal basis $\{u_1, \ldots, u_n\}$ of $\mathbb{R}^n$. Define
$$ d_0(w, b) = \frac{1}{\sqrt{2\pi}} \int_{-\infty}^{\infty} (\varphi'(wx + b))^2\, e^{-\frac{x^2}{2}}\, dx > 0, \qquad (9) $$
$$ d_1(w, b) = \frac{1}{\sqrt{2\pi}} \int_{-\infty}^{\infty} (\varphi'(wx + b))^2\, x\, e^{-\frac{x^2}{2}}\, dx, \qquad (10) $$
$$ d_2(w, b) = \frac{1}{\sqrt{2\pi}} \int_{-\infty}^{\infty} (\varphi'(wx + b))^2\, x^2\, e^{-\frac{x^2}{2}}\, dx > 0. \qquad (11) $$
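These one-dimensional Gaussian expectations can be evaluated numerically for an arbitrary smooth $\varphi$. The sketch below is an illustration added here (Gauss-Hermite quadrature is an assumed implementation choice, not prescribed by the paper); it computes $d_0$, $d_1$, $d_2$ for given scalars $(w, b)$.

```python
import numpy as np

def d_coefficients(dphi, w, b, order=80):
    """Approximate d_k(w,b) = E[(phi'(w x + b))^2 x^k], x ~ N(0,1), for k = 0, 1, 2.

    Uses Gauss-Hermite quadrature: E[g(x)] ~ (1/sqrt(pi)) sum_i h_i g(sqrt(2) t_i).
    `dphi` is the derivative phi' of the hidden-unit nonlinearity.
    """
    t, h = np.polynomial.hermite.hermgauss(order)
    x = np.sqrt(2.0) * t
    g = dphi(w * x + b) ** 2
    weights = h / np.sqrt(np.pi)
    d0 = np.sum(weights * g)
    d1 = np.sum(weights * g * x)
    d2 = np.sum(weights * g * x ** 2)
    return d0, d1, d2

# Example with phi(x) = erf(x / sqrt(2)), so phi'(x) = sqrt(2/pi) * exp(-x^2 / 2).
dphi = lambda u: np.sqrt(2.0 / np.pi) * np.exp(-u ** 2 / 2.0)
d0, d1, d2 = d_coefficients(dphi, w=1.5, b=0.3)
```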
Then
$$ A_{11} = d_2(w, b)\, u_1 u_1^T + d_0(w, b) \sum_{j=2}^{n} u_j u_j^T, \qquad a_{12} = E[(\varphi'(w x_1 + b))^2 x_1]\, u_1 = d_1(w, b)\, u_1, $$
where $x_1 = u_1^T x \sim N(0, 1)$, and
$$ a_{22} = d_0(w, b). $$
$A_{11}$ can be written as $A_{11} = d_0(w, b) I + (d_2(w, b) - d_0(w, b))\, u_1 u_1^T$ since $\{u_i\}$ are orthonormal and
$$ \sum_{k=2}^{n} u_k u_k^T = I - u_1 u_1^T. $$
Summarizing the previous calculations for the blocks in $G(\theta)$, we have
$$ A_{11} = d_0(w, b)\, I + (d_2(w, b) - d_0(w, b))\, \frac{w w^T}{w^2}, \qquad (12) $$
$$ a_{12} = a_{21}^T = d_1(w, b)\, \frac{w}{w}, \qquad a_{22} = d_0(w, b). \qquad (13) $$
Saad and Solla (1995) used the scaled error function $\varphi(x) = \mathrm{erf}(\frac{x}{\sqrt{2}}) = \frac{2}{\sqrt{\pi}} \int_0^{x/\sqrt{2}} e^{-t^2} dt$ as the sigmoid function of the hidden units in order to obtain an analytic expression of the generalization error. If we choose this $\varphi(x)$ as the sigmoid function of the hidden neurons, we obtain the following closed forms for the integrals (9)-(11):
$$ d_0(w, b) = \frac{\sqrt{2}}{\pi \sqrt{w^2 + 0.5}}\, \exp\Big\{ -b^2 + \frac{w^2 b^2}{w^2 + 0.5} \Big\}, $$
$$ d_2(w, b) = d_0(w, b)\, \frac{1}{w^2 + 0.5} \Big( \frac{1}{2} + \frac{w^2 b^2}{w^2 + 0.5} \Big), $$
$$ d_1(w, b) = -\frac{\sqrt{2}\, w b}{\pi\, (w^2 + 0.5)^{3/2}}\, \exp\Big\{ -b^2 + \frac{w^2 b^2}{w^2 + 0.5} \Big\}. $$
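For this choice of $\varphi$, the closed forms above translate directly into code; the following snippet (added for illustration) can replace the generic quadrature sketch given earlier.

```python
import numpy as np

def d_coefficients_erf(w, b):
    """Closed-form d_0, d_1, d_2 for phi(x) = erf(x / sqrt(2)), scalar w = ||w|| and b."""
    s = w ** 2 + 0.5
    common = np.exp(-b ** 2 + (w ** 2) * (b ** 2) / s)
    d0 = np.sqrt(2.0) / (np.pi * np.sqrt(s)) * common
    d1 = -np.sqrt(2.0) * w * b / (np.pi * s ** 1.5) * common
    d2 = d0 * (0.5 + (w ** 2) * (b ** 2) / s) / s
    return d0, d1, d2
```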
Even for the single-layer perceptron, the size of the Fisher information matrix is $(n+1) \times (n+1)$. However, the Fisher information matrix can be generated on-line with $O(n^2)$ flops by the equations (12)-(13). If one wants to trace the Fisher information matrix, one only needs $O(n)$ units to store $w$ and $b$ in each iteration. To compute the inverse of $A(\theta)$, we need the following well-known inverse formula for a four-block matrix:
Lemma 1
$$ \begin{bmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{bmatrix}^{-1} = \begin{bmatrix} B^{11} & B^{12} \\ B^{21} & B^{22} \end{bmatrix} $$
provided $|B_{11}| \neq 0$ and $|B_{22} - B_{21} B_{11}^{-1} B_{12}| \neq 0$. Here,
$$ B^{11} = B_{11}^{-1} + B_{11}^{-1} B_{12} B_{22\cdot1}^{-1} B_{21} B_{11}^{-1}, \quad B_{22\cdot1} = B_{22} - B_{21} B_{11}^{-1} B_{12}, \quad B^{22} = B_{22\cdot1}^{-1}, \quad B^{12} = (B^{21})^T = -B_{11}^{-1} B_{12} B_{22\cdot1}^{-1}. $$
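A quick numerical sanity check of Lemma 1 can be written as follows (added here as an illustration; the symmetric test matrix is arbitrary).

```python
import numpy as np

def block_inverse(B11, B12, B21, B22):
    """Four-block inverse of a symmetric matrix [[B11, B12], [B21, B22]] via Lemma 1."""
    B11_inv = np.linalg.inv(B11)
    S = B22 - B21 @ B11_inv @ B12          # Schur complement B_{22.1}
    S_inv = np.linalg.inv(S)
    T11 = B11_inv + B11_inv @ B12 @ S_inv @ B21 @ B11_inv
    T12 = -B11_inv @ B12 @ S_inv
    T21 = T12.T                            # valid because the full matrix is symmetric
    return np.block([[T11, T12], [T21, S_inv]])

rng = np.random.default_rng(1)
M = rng.standard_normal((6, 6))
A = M @ M.T + 6 * np.eye(6)                # symmetric positive definite test matrix
inv1 = block_inverse(A[:4, :4], A[:4, 4:], A[4:, :4], A[4:, 4:])
assert np.allclose(inv1, np.linalg.inv(A))
```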
Applying the above inverse formula for a four-block matrix, we obtain
$$ A^{-1}(\theta) = \begin{bmatrix} A^{11} & a^{12} \\ a^{21} & a^{22} \end{bmatrix} $$
where
$$ A^{11} = \frac{1}{d_0}\, I + \Big( \frac{1}{d_2} - \frac{1}{d_0} + \frac{d_1^2}{d_2 (d_0 d_2 - d_1^2)} \Big) \frac{w w^T}{w^2}, \qquad (14) $$
$$ a^{12} = (a^{21})^T = -\frac{d_1}{w\, (d_0 d_2 - d_1^2)}\, w, \qquad (15) $$
$$ a^{22} = \frac{d_2}{d_0 d_2 - d_1^2}. \qquad (16) $$
By using the equations (14)-(16), the time complexity of computing $A^{-1}(\theta)$ is $O(n^2)$. If we compute the inverse of the Fisher information matrix and the gradient of $l_1(z|x; \theta)$ in two separate steps and then multiply them together, the time complexity of computing the natural gradient is $O(n^2)$. Instead of computing the natural gradient in this way, we compute it in a single step by applying the equations (14)-(16):
$$ \tilde{\nabla} l_1(z|x; \theta) = -h(x, z, w, b) \begin{bmatrix} \dfrac{x}{d_0} + \Big( \dfrac{1}{d_2} - \dfrac{1}{d_0} + \dfrac{d_1^2}{d_2 (d_0 d_2 - d_1^2)} \Big) \dfrac{w^T x}{w^2}\, w - \dfrac{d_1}{w (d_0 d_2 - d_1^2)}\, w \\[2mm] -\dfrac{d_1}{w (d_0 d_2 - d_1^2)}\, w^T x + \dfrac{d_2}{d_0 d_2 - d_1^2} \end{bmatrix} \qquad (17) $$
where $h(x, z, w, b) = (z - \varphi(w^T x + b))\, \varphi'(w^T x + b)$. Only $O(n)$ flops are needed to compute the natural gradient $\tilde{\nabla} l_1(z|x; \theta)$ by the equation (17).
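The $O(n)$ cost of equation (17) is easy to see in code: every term is either a scalar or a scalar multiple of $x$ or $w$. The sketch below is an illustration (it assumes the `d_coefficients_erf` helper defined above, or any routine returning $d_0$, $d_1$, $d_2$) and computes the natural gradient of the single-layer perceptron without forming any $(n+1) \times (n+1)$ matrix.

```python
import numpy as np

def natural_gradient_single_layer(x, z, w, b, phi, dphi):
    """Natural gradient (w-part, b-part) of l_1 for the single-layer perceptron, eq. (17).

    Cost is O(n): only inner products and scalar-vector operations are used.
    """
    w_norm = np.linalg.norm(w)
    d0, d1, d2 = d_coefficients_erf(w_norm, b)   # any routine returning d_0, d_1, d_2
    det = d0 * d2 - d1 ** 2
    u = w @ x                                    # w^T x
    h = (z - phi(u + b)) * dphi(u + b)

    coef = 1.0 / d2 - 1.0 / d0 + d1 ** 2 / (d2 * det)
    grad_w = x / d0 + coef * (u / w_norm ** 2) * w - (d1 / (w_norm * det)) * w
    grad_b = -(d1 / (w_norm * det)) * u + d2 / det
    return -h * grad_w, -h * grad_b
```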
4.2 Multi-layer perceptron
The main difficulty in implementing the natural gradient descent algorithm (6) is to compute the natural gradient on-line. The number of parameters in the stochastic perceptron defined by the equation (1) is $m(n+2)$, so the size of the Fisher information matrix is $m(n+2) \times m(n+2)$. In this paper, we assume $m \ll n$. To compute the inverse of an $n \times n$ matrix, one generally needs $O(n^3)$ flops by commonly used methods such as the lower-upper triangular factorization together with the triangular matrix inverse (Stewart, 1973). In the previous section, we have shown procedures to compute the Fisher information matrix and the natural gradient of the single-layer perceptron. Following similar procedures, Yang and Amari (1997a) found a scheme to represent the Fisher information matrix of the multi-layer perceptron. Based on this scheme, we gave a method which requires $O(n^2)$ flops to compute the inverse of the Fisher information matrix. In this paper, we adapt this method to compute the natural gradient $\tilde{\nabla} l_1(z|x; \theta)$. We shall show that the adapted method requires only $O(n)$ flops to compute the natural gradient. We first briefly describe the scheme to represent the Fisher information matrix and then introduce the method to compute the natural gradient.

We need some notation for block matrices. An $m \times n$ block matrix $X$ is denoted by $X = [X_{ij}]_{[m \times n]}$ or $X = [X_{ij}]_{i=1,\ldots,m;\, j=1,\ldots,n}$. Let $A(\theta) = [A_{ij}]_{[(m+2) \times (m+2)]}$ be the partition of $A(\theta)$ corresponding to the partition of
$$ \theta = (w_1^T, \ldots, w_m^T, a^T, b^T)^T. $$
Define
$$ u_i = w_i / \|w_i\|, \quad i = 1, \ldots, m, \qquad U_1 = [u_1, \ldots, u_m], \qquad [v_1, \ldots, v_m] = U_1 (U_1^T U_1)^{-1}. $$
Let $r_{ij} = u_i^T u_j$, $R_1 = (r_{ij})_{m \times m}$ and $R_1^{-1} = (r^{ij})_{m \times m}$. Define $x'_i = u_i^T x$ for $i = 1, \ldots, m$; then
$$ (x'_1, \ldots, x'_m) \sim N(0, R_1). $$
The structure of the blocks in $A(\theta)$ is elucidated by the following two lemmas, proved in (Yang and Amari, 1997a), which give an efficient scheme to represent the Fisher information matrix of the stochastic multi-layer perceptron.

Lemma 2 For $1 \leq i, j \leq m$ and $a_i \neq 0$,
$$ A_{ij} = a_i a_j \Big( c_{ij}\, \Phi_0 + \sum_{l,k=1}^{m} c^{lk}_{ij}\, u_l v_k^T \Big) \qquad (18) $$
where
$$ \Phi_0 = \sum_{k=m+1}^{n} u_k u_k^T = I - \sum_{k=1}^{m} u_k v_k^T, \qquad (19) $$
with $u_{m+1}, \ldots, u_n$ an orthonormal basis of the orthogonal complement of $\mathrm{span}\{u_1, \ldots, u_m\}$,
$$ c_{ij} = E[\varphi'(w_i x'_i + b_i)\, \varphi'(w_j x'_j + b_j)], \qquad (20) $$
$$ c^{lk}_{ij} = E\Big[ \varphi'(w_i x'_i + b_i)\, \varphi'(w_j x'_j + b_j) \Big( \sum_{s=1}^{m} r^{ls} x'_s \Big) x'_k \Big]. \qquad (21) $$
Here and below, $w_i = \|w_i\|$ denotes the length of the $i$-th weight vector in the arguments of $\varphi$ and $\varphi'$.
Lemma 3 For $1 \leq i \leq m$,
$$ A_{i,m+1} = A_{m+1,i}^T = \Big( \sum_{k=1}^{m} c^k_{i1} v_k, \; \ldots, \; \sum_{k=1}^{m} c^k_{im} v_k \Big) \qquad (22) $$
where
$$ c^k_{ij} = E[\varphi'(w_i x'_i + b_i)\, \varphi(w_j x'_j + b_j)\, x'_k], \quad 1 \leq i, j, k \leq m. \qquad (23) $$
$$ A_{i,m+2} = A_{m+2,i}^T = \Big( \sum_{k=1}^{m} \tilde{c}^k_{i1} v_k, \; \ldots, \; \sum_{k=1}^{m} \tilde{c}^k_{im} v_k \Big) \qquad (24) $$
has the same structure as $A_{i,m+1}$, where
$$ \tilde{c}^k_{ij} = a_i a_j\, E[\varphi'(w_i x'_i + b_i)\, \varphi'(w_j x'_j + b_j)\, x'_k], \quad 1 \leq i, j, k \leq m. $$
$$ A_{m+1,m+1} = (b_{ij})_{m \times m} \qquad (25) $$
with $b_{ij} = E[\varphi(w_i x'_i + b_i)\, \varphi(w_j x'_j + b_j)]$, a function of $b_i$, $b_j$, $w_i$, $w_j$ and $r_{ij}$.
$$ A_{m+1,m+2} = A_{m+2,m+1}^T = (\tilde{b}_{ij})_{m \times m} \qquad (26) $$
with $\tilde{b}_{ij} = a_j\, E[\varphi(w_i x'_i + b_i)\, \varphi'(w_j x'_j + b_j)]$.
$$ A_{m+2,m+2} = (b'_{ij})_{m \times m} \qquad (27) $$
with $b'_{ij} = a_i a_j\, E[\varphi'(w_i x'_i + b_i)\, \varphi'(w_j x'_j + b_j)]$.
Assume $\varphi(x) = \mathrm{erf}(x/\sqrt{2})$, where $\mathrm{erf}(u)$ is the error function. It is not difficult to obtain analytic expressions for $c_{ij}$, $b'_{ij}$, $c^{lk}_{ij}$, and $\tilde{c}^k_{ij}$. When $b_i = 0$, $i = 1, \ldots, m$, applying the multivariate Gaussian average method in (Saad and Solla, 1995), we can also obtain analytic expressions for the coefficients $c^k_{ij}$, $b_{ij}$ and $\tilde{b}_{ij}$. But when $b_i \neq 0$, analytic expressions for these coefficients seemingly do not exist. These coefficients are $m$-dimensional integrals which can be approximated off-line by numerical integration methods. Recalling that $n$ is the input dimension and $m$ is the number of hidden neurons, in this paper we assume $m \ll n$. Using Lemma 2 and Lemma 3, instead of using $O(n^2)$ units to store the Fisher information matrix, we only need $O(n)$ units to store the vectors $u_k$ and $v_k$. $O(n^2)$ flops are needed to construct the Fisher information matrix from these vectors.
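The storage claim is easy to make concrete. The sketch below is an illustration added here (the coefficient arrays `c` and `c_lk` are assumed to have been precomputed, e.g. by off-line numerical integration as described above): it stores only the $n$-dimensional vectors $u_k$ and $v_k$ and applies a block $A_{ij}$ of equation (18) to a vector in $O(nm)$ flops, i.e. $O(n)$ when $m \ll n$, without ever forming the $n \times n$ block.

```python
import numpy as np

def dual_vectors(W):
    """Given W = [w_1, ..., w_m] (n x m), return U1 = [u_k] and V1 = [v_k] (both n x m)."""
    U1 = W / np.linalg.norm(W, axis=0)           # u_k = w_k / ||w_k||
    V1 = U1 @ np.linalg.inv(U1.T @ U1)           # [v_1, ..., v_m] = U1 (U1^T U1)^{-1}
    return U1, V1

def block_matvec(i, j, a, c, c_lk, U1, V1, y):
    """Compute A_ij @ y using the representation (18)-(19) in O(n m) flops.

    c[i, j]    -- the scalar coefficient c_ij of equation (20)
    c_lk[i, j] -- the m x m array of coefficients c^{lk}_ij of equation (21)
    """
    Vy = V1.T @ y                                # the m inner products v_k^T y
    phi0_y = y - U1 @ Vy                         # Phi_0 y = y - sum_k u_k (v_k^T y)
    low_rank = U1 @ (c_lk[i, j] @ Vy)            # sum_{l,k} c^{lk}_ij u_l (v_k^T y)
    return a[i] * a[j] * (c[i, j] * phi0_y + low_rank)
```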
5 Complexity issues
We shall discuss the time complexity of computing $A(\theta)$, $A^{-1}(\theta)$ and the natural gradient for the multi-layer perceptron. By Lemma 2 and Lemma 3, to compute the blocks in $A(\theta)$ we need to compute $u_k$ and $v_k$, $k = 1, \ldots, m$. It takes $O(n)$ flops to compute $u_k$ and $v_k$ from their definitions. The time complexity of computing $A(\theta)$ is $O(n^2)$ since each outer product $u_k v_k^T$ needs $O(n^2)$ flops. It is shown in (Yang and Amari, 1997a) that about $O(n^2)$ flops are needed to compute the inverse of the Fisher information matrix. So, if we compute the Fisher information matrix first and then compute the natural gradient, implementing the natural gradient descent takes $O(n^2)$ flops in each iteration. This is of the same order as the conjugate gradient method proposed in (Yang and Amari, 1997b). In the rest of this paper, we show that the time complexity of computing the natural gradient for the multi-layer perceptron can be significantly improved by using the representation scheme in Lemma 2 and Lemma 3 and by computing the natural gradient without inverting the Fisher information matrix. We shall show that we can compute the natural gradient directly with only $O(n)$ flops. To show this, we re-examine the process of computing the inverse of the Fisher information matrix. We need the following notations:
Gl(m;