Bayesian Adaptive Nearest Neighbor
Ruixin Guo and Sounak Chakraborty
Department of Statistics, University of Missouri-Columbia, Columbia, MO 65211-6100, USA
Received 21 January 2009; revised 15 October 2009; accepted 24 November 2009 DOI:10.1002/sam.10067 Published online 24 February 2010 in Wiley InterScience (www.interscience.wiley.com).
Abstract: The k nearest neighbor classification (k-NN) is a very simple and popular method for classification. However, it suffers from a major drawback: it assumes a constant local class posterior probability. It is also highly dependent on and sensitive to the choice of the number of neighbors k. In addition, it lacks the desired probabilistic formulation. In this article, we propose a Bayesian adaptive nearest neighbor method (BANN) that can adaptively select the shape of the neighborhood and the number of neighbors k. The shape of the neighborhood is automatically selected according to the concentration of the data around each query point with the help of discriminants. The neighborhood size is not predetermined and is kept free using a prior distribution; thus, the model is able to select the appropriate neighborhood size. The model is fitted using Markov Chain Monte Carlo (MCMC), so we do not use exactly one neighborhood size but a mixture over k. Our BANN model is highly flexible, detecting any local pattern in the data-generating process and adapting to it to give improved predictions. We have applied our model to four simulated data sets with special structures and five real-life benchmark data sets. Our proposed BANN method demonstrates substantial improvement over k-NN and discriminant adaptive nearest neighbor (DANN) in all nine case studies. It also outperforms the probabilistic nearest neighbor (PNN) in most of the data analyses.
Keywords: Bayesian prediction; discriminant adaptive nearest neighbor; Markov Chain Monte Carlo; nearest neighbor; probabilistic nearest neighbor; supervised learning
1. INTRODUCTION

Class prediction is a supervised learning method where the model or the algorithm learns from a set of samples whose class labels are known in a training set and establishes a prediction rule to classify new samples whose labels are unknown in a test set. Development of class prediction methods generally consists of three steps: first, selection of predictors; second, fitting the prediction model to develop the classification rule; and third, performance assessment. The first two steps build a prediction model, and the third step assesses the performance of the model. Several model-based and algorithmic classification techniques, such as the nearest neighbor method (k-NN), linear discriminant analysis (LDA), logistic regression, neural networks (NNET) [1], support vector machines (SVM) [2], classification trees (CART) [3], and random forests [4], are popularly used to predict the unknown class labels for unlabeled samples. Apart from the nearest neighbor model, all other methods require a parametric model structure, linear or nonlinear.
Hence, the prediction often becomes highly sensitive to the choices of the tuning parameters. Along with that, most modern classification techniques, like neural networks and random forests, are not easily interpreted. The SVM provides some geometric formulation in terms of the maximal separating hyperplane, but it becomes difficult to visualize when the feature space has more than three dimensions. Furthermore, none of the above methods is formulated under a probabilistic framework, and hence they fail to quantify the uncertainty of the predictions. In the last few years, the Bayesian formulation of classification models [5] has equipped us to measure the uncertainty of the predicted class label using the posterior predictive distribution. A number of complex machine learning models, such as neural networks, classification trees, and SVM, have been developed under the Bayesian framework [6–8]. However, all these Bayesian methods need a parametric formulation, are highly sensitive to the choice of priors, and often are unequipped to use the complex patterns of the data-generating process in the feature space to their advantage. All the methods mentioned above usually overparameterize the model, which often creates problems in the convergence of the Markov Chain Monte Carlo (MCMC).
Fig. 1 DANN metric compared with k-NN.
The computational complexity and implementational difficulty of these models are always considerable, which makes them unattractive and infrequently used in practice. The k nearest neighbor classification (k-NN) is one of the most popular and simple methods for classification. It classifies a new point x0 to the most frequent class among the k nearest neighbors of x0, which are determined by some distance measure on the training set. Most commonly, the Euclidean distance is used for determining the neighbors. The most appealing fact about the k nearest neighbor method is its ease of interpretation. The k-NN method is completely non-parametric because it does not make any distributional assumption on the data-generating process. For k = 1, Cover and Hart [9] demonstrated that the asymptotic error rate of 1-NN is at most twice the Bayes error rate, which is optimal. Even so, there are major drawbacks to this algorithm. First, the k nearest neighbor method assumes a constant local class posterior probability, which causes severe bias in high dimensions. As discussed in Hastie and Tibshirani [10], 'in finite samples the curse of dimensionality can severely hurt the nearest neighbor rule'. Secondly, in the k nearest neighbor algorithm, the number of neighbors (k) is fixed. However, there is no formal way to choose k; the general practice is to use cross-validation. Thirdly, the standard formulation of the k-NN method does not allow any probabilistic interpretation, so we cannot assess the quality of each prediction. Fourthly, the Euclidean metric cannot capture the true pattern of the observations around the point of interest, because it does not take into account the correlations in the data set; if the features are strongly correlated with different variances, it will fail to capture the pattern. As the dimension of the feature space or the number of covariates increases, the data become extremely sparse, so that even the closest neighbor is too far away in the feature space. For example, if the data are generated uniformly from the unit hypercube [-1/2, 1/2]^p, then to capture 1% of the data we must cover a fraction (0.01)^{1/p} of the range of each input variable, which is about 63% when p = 10
and about 95% when p = 100. In order to avoid this curse of dimensionality problem, Hastie and Tibshirani [10] proposed a discriminant adaptive nearest neighbor classification (DANN), which uses local discrimination information to determine an effective metric for computing neighbors. The modified neighborhoods shrink in the direction orthogonal to the local decision boundary and stretch out in the direction parallel to the decision boundary. See Fig. 1 (Fig. 2 in Hastie and Tibshirani [10]). The neighborhood in the left panel of Fig. 1 is obtained by standard k-NN, and the one in the right panel is determined by the DANN method, which extends in the directions in which the class probabilities change the least. That means we are stretching the neighborhoods in the direction of the data points where the data are more homogeneous in nature. In their initial study, Hastie and Tibshirani [10] applied DANN to a number of examples, and in most of them the DANN procedure showed substantial improvement over the standard k nearest neighbor classification by adapting to the concentration of data in the neighborhoods. Nevertheless, DANN still suffers from the drawback that it requires a fixed number of neighbors (km) to update the distance metric and a fixed number of neighbors (k) to predict the class label, and there is no formal framework to choose k and km. Moreover, this method can only provide discrete predictions without any probabilistic formulation. A probabilistic nearest neighbor (PNN) method, which was first formulated by Holmes and Adams [11], overcame the difficulty of choosing k by putting a prior on k. The PNN model leads to a predictive distribution for a new point with the same form as the likelihood function. This method overcomes the second drawback above, but it still suffers from the curse of dimensionality. That is, in higher dimensions, when the data become sparse, PNN fails to exploit the concentration of the data as DANN does and thus performs poorly for prediction.
Fig. 2 Simulated data sets. Panels (a)-(d) show the generated point patterns of Simulations 1-4 in the (x1, x2) plane.
In this article, we propose a Bayesian adaptive nearest neighbor method (BANN), which adaptively selects the neighborhood shape and automatically selects the neighborhood size. Our proposed model combines the advantages of the PNN and the DANN without any assumptions on the distribution of the predictor variables. In Section 2, we give a brief review and some geometric intuition concerning the DANN method. In Section 3, we describe our BANN model, implementation algorithm, and prediction procedure. In Section 4, we illustrate the effectiveness of our BANN method on four simulated data sets and five real-life benchmark data sets. In the last section, we provide some discussion and future possibilities.
2. DISCRIMINANT ADAPTIVE NEAREST NEIGHBOR
Consider a data set D = {(y1, X1), (y2, X2)} with n = n1 + n2 observations, where (y1, X1) = {(y11, x11), (y12, x12), . . . , (y1n1, x1n1)} is the training set with n1 observations, and (y2, X2) = {(y21, x21), (y22, x22), . . . , (y2n2, x2n2)} is the test set with n2 observations. The X1 = (x11, . . . , x1n1) and X2 = (x21, . . . , x2n2) are n1 × p and
n2 × p data matrices for the training and test sets, respectively. The subscript 1 indicates the training set, whereas the subscript 2 indicates the test set. y1 is an n1 × 1 vector of known class labels for the training set, whereas y2 is an n2 × 1 unknown vector and needs to be predicted. The dimension of the predictors is p and n is the total number of observations. Suppose that there are J classes, yli ∈ {1, 2, . . . , J}, l = 1, 2, i = 1, 2, . . . , nl. We wish to classify the points in the test set to one of the J classes, based on the known information in the training set. In short, we need to predict the unknown class memberships y2. Suppose that x0 is a query point to be classified. The k-NN method determines the k nearest neighbors of x0 in the training set, denoted as Nk(x0), according to the Euclidean distance. The squared Euclidean distance between any point x1i in the training set and the query point x0 is calculated as di = (x1i − x0)^T (x1i − x0). Then x0 is classified to the most frequently occurring class among the k nearest neighbors, Nk(x0). In this section, we give a brief sketch of the DANN method proposed by Hastie and Tibshirani [10]. The DANN classification [10] uses a local LDA to calculate an effective local weight for determining the neighbors for each query point. Then the square of the
distance between any point in the training set, x1i, and the query point x0 is defined as,

d_i = (x_{1i} - x_0)^T \Sigma (x_{1i} - x_0),   (1)

with the DANN metric \Sigma proposed as,

\Sigma = W^{-1/2} [ W^{-1/2} B W^{-1/2} + \epsilon I ] W^{-1/2}   (2)
       = W^{-1/2} [ B^{*} + \epsilon I ] W^{-1/2},   (3)

where W and B denote the p × p local weighted matrices of within and between sum-of-squares, and B* = W^{-1/2} B W^{-1/2} can be interpreted as the between sum-of-squares in the sphered space using W. Here, ε is a small tuning parameter, which avoids the neighborhood extending infinitely in the direction of the null space of B*. Thus, the DANN metric effectively stretches the neighborhood in the direction of the zero eigenvalues, or least information, and limits the stretching with the tuning parameter ε. This is typically useful in large dimensions when data are sparse or when the data may be concentrated along some specific predictor variables. In the DANN methodology, two different neighborhood sizes are used, km for calculating the Σ matrix and k for the classification of a new sample. The DANN classification algorithm is as follows:

(1) Initialize the distance metric Σ = I.
(2) Use Eq. (1) to find the km nearest neighbors around the test point x0, denoted as Nkm(x0).
(3) Calculate the weighted between and within matrices B and W using the points in Nkm(x0),

B(x_0; k_m) = \sum_{j=1}^{J} \pi_j (\bar{x}_{1j} - \bar{x}_1)(\bar{x}_{1j} - \bar{x}_1)^T,   (4)

W(x_0; k_m) = \sum_{j=1}^{J} \sum_{y_{1i} = j} w_i (x_{1i} - \bar{x}_{1j})(x_{1i} - \bar{x}_{1j})^T \Big/ \sum_{i=1}^{n_1} w_i,   (5)

where

\pi_j = \sum_{y_{1i} = j} w_i \Big/ \sum_{i=1}^{n_1} w_i,   (6)

w_i = \left[ 1 - \left( d_i / h \right)^3 \right]^3 I(|d_i| < h),  \quad  h = \max_{i \in N_{k_m}(x_0)} d_i,   (7)

\bar{x}_1 is the weighted mean of all the training points in Nkm(x0), \bar{x}_{1j} is the weighted mean of the observations belonging to class j in Nkm(x0), and d_i is defined in Eq. (1).
(4) Update the metric as \Sigma = W^{-1/2} [ W^{-1/2} B W^{-1/2} + \epsilon I ] W^{-1/2}.
(5) Iterate (2), (3), and (4).
(6) Classify x0 by using the k nearest neighbors based on the limiting metric Σ.
On the basis of the results of a number of examples in Hastie and Tibshirani [10], they mentioned that 'on average there seems to be no advantage in carrying out more than one iteration of the DANN procedure'. Hence, in the following part of this article, for simplicity, we use only one iteration to calculate the local weight matrix Σ in our BANN algorithm. In fact, Σ in Eq. (2) implicitly depends on the parameter km and can be treated as a function of km. From now on, we write Σ(x0; km) to reveal its dependence on km and its local property (it is calculated for each query point x0). Σ(x0; km) is local since it is calculated by using the points in the neighborhood of x0. Once Σ(x0; km) is calculated, the new point x0 can be classified based on this new distance metric and k nearest neighbors. For the choice of km and k, Hastie and Tibshirani [10] suggested using km = max(n/5, 50) and k = 5. However, just like the standard k-NN, there is no formal framework to choose the number of neighbors, and often the performance of the predictions is highly sensitive to the choice of km and k. Another drawback is that the predictions made by the DANN method have no probabilistic interpretation.
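For concreteness, the following minimal R sketch carries out one pass of steps (2)-(4) above for a single query point x0: it computes the tricube weights, the weighted between and within matrices B and W, and the resulting local metric Σ(x0; km). This is only an illustrative sketch under our own naming (dann_metric) and default ε = 1, not the authors' released code.

dann_metric <- function(x0, X, y, km, eps = 1) {
  # X: n x p numeric matrix of training predictors, y: length-n vector of class labels
  d2 <- colSums((t(X) - x0)^2)                     # squared Euclidean distances, Eq. (1) with Sigma = I
  nb <- order(d2)[1:km]                            # the km nearest neighbours N_km(x0)
  h  <- max(d2[nb])
  w  <- (1 - (d2[nb] / h)^3)^3                     # tricube weights of Eq. (7); the farthest neighbour gets weight 0
  Xn <- as.matrix(X[nb, , drop = FALSE]); yn <- y[nb]
  p  <- ncol(Xn)
  xbar <- colSums(w * Xn) / sum(w)                 # overall weighted mean
  B <- matrix(0, p, p); W <- matrix(0, p, p)
  for (j in unique(yn)) {
    idx <- (yn == j); sw <- sum(w[idx])
    if (sw == 0) next                              # skip a class that carries no weight in the neighbourhood
    xj <- colSums(w[idx] * Xn[idx, , drop = FALSE]) / sw
    B  <- B + (sw / sum(w)) * outer(xj - xbar, xj - xbar)                               # Eq. (4), with pi_j of Eq. (6)
    W  <- W + crossprod(sqrt(w[idx]) * sweep(Xn[idx, , drop = FALSE], 2, xj)) / sum(w)  # Eq. (5)
  }
  e  <- eigen(W, symmetric = TRUE)
  Wi <- e$vectors %*% diag(1 / sqrt(pmax(e$values, 1e-8)), p) %*% t(e$vectors)          # W^(-1/2)
  Wi %*% (Wi %*% B %*% Wi + eps * diag(p)) %*% Wi                                       # Eq. (2)
}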
3. BAYESIAN ADAPTIVE NEAREST NEIGHBOR

In this section, we propose our BANN, which is a probabilistic formulation of the nearest neighbor along with the ability to automatically select the neighborhood size and adaptively choose the neighborhood shape according to the distribution of data points.
3.1. Model Formulation
Similar to Holmes and Adams [11], we define the joint likelihood of our data points, p(y1 | X1, km, k, β), by

p(y_1 \mid X_1, k_m, k, \beta) = \prod_{i=1}^{n_1} \frac{\exp\{\beta (1/k) \sum_{j \sim_{k,k_m} i} \delta_{y_{1i} y_{1j}}\}}{\sum_{q=1}^{J} \exp\{\beta (1/k) \sum_{j \sim_{k,k_m} i} \delta_{q y_{1j}}\}},   (8)
where β acts as the parameter for the strength of association among nearest neighbors, k is the neighborhood size
required to construct a k-NN classifier, and km is the neighborhood size used to construct the distance metric Σ as discussed in Section 2. The δ is a Dirac delta function defined as δab = 1 if a = b and zero otherwise. In the likelihood (8), the term (1/k) Σ_{j ∼ i} δ_{q y1j} calculates the proportion of training points, among the k nearest neighbors of x1i, belonging to class q. The dependence on km is through the local distance metric Σ(x1i; km). That is, rather than using the Euclidean distance as in Holmes and Adams [11], we adopt the local DANN metric of Eq. (2), which depends on km. The squared distance between any training point x1j and the point x1i is defined as (x1j − x1i)^T Σ(x1i; km)(x1j − x1i), and we use the DANN procedure as described in Section 2 to derive Σ(x1i; km) locally. That is,

\Sigma(x_{1i}; k_m) = W^{-1/2} [ W^{-1/2} B W^{-1/2} + \epsilon I ] W^{-1/2},   (9)

where B = B(x1i; km) and W = W(x1i; km) are the weighted between and within matrices calculated similarly to Eqs. (4) and (5) by using the points in the neighborhood Nkm(x1i). The parameter km is the neighborhood size used to obtain the discriminant adaptive weight matrix Σ(x1i; km). Reviewing the likelihood (8) carefully, we see that it is of the 'multinomial form' and also a normalized joint probability mass function. In fact, the normalizing constant is independent of β, k, and km. This normalization property is very useful in the analysis when β, k, and km are treated as random. The multinomial type formulation in Eq. (8) utilizes the idea of local specification through the local conditional distribution to determine the joint distribution [12]. This procedure is quite popular in spatial statistics and is referred to as a Markov random field (MRF) [13]. A rigorous proof of this argument can be found in Ref. [14], where they established this fact by proving the converse of the Hammersley-Clifford Theorem [15].

3.2. Priors and Posterior

In the likelihood described in Eq. (8), we have three unknown parameters β, km, and k. The parameter β acts like a regression coefficient inside a local logistic regression on the number of the neighboring classes, and it can be viewed as an interaction or association parameter. Since the neighboring classes should behave similarly, there must be a positive association between them, so we restrict the association parameter β to be positive. In other words, a positive value of β ensures that the probability of a sample belonging to a class increases if it has a larger number of samples in its neighborhood from that corresponding class. We put a truncated normal density with a large variance on the association parameter β as follows:

\beta \sim 2\, N(0, b)\, I(\beta > 0).   (10)
Instead of fixing km and k, we place prior densities on them. In our model, km is the number of neighbors needed to obtain the local Σ(xi; km) matrix for each training point xi, and k is the neighborhood size for predicting the class labels in the test set. As mentioned by Hastie and Tibshirani [10], in order to capture the local data pattern, we need a larger number of neighbors (km) for calculating Σ(xi; km) than for the final prediction (k). Thus, it is intuitive to restrict k and km in the prior formulation by k ≤ km. The conditional prior on k is assigned as discrete uniform,

k \mid k_m \sim DU[1, 2, \ldots, k_m].   (11)

We use a discrete uniform prior on km as,

k_m \sim DU[k_{m.min}, k_{m.min} + 1, \ldots, k_{m.max}],   (12)

where km.min is usually chosen to be a small number such as 5 and km.max is usually chosen to be the total number of observations n1 in the training set. Although these choices of non-informative priors for k and km are justifiable from the objective Bayes viewpoint, setting km.max = n1 often slows down the computation if n1 is very large. In Section 4, we will discuss some special choices of km.min and km.max which provide faster results without any compromise in the prediction accuracy or the sensitivity of our model. We can write the joint posterior density as

\pi(k_m, k, \beta \mid y_1, X_1) \propto p(y_1 \mid X_1, k_m, k, \beta) \times \pi(\beta) \times \pi(k \mid k_m) \times \pi(k_m)
\propto \left[ \prod_{i=1}^{n_1} \frac{\exp\{\beta (1/k) \sum_{j \sim_{k,k_m} i} \delta_{y_{1i} y_{1j}}\}}{\sum_{q=1}^{J} \exp\{\beta (1/k) \sum_{j \sim_{k,k_m} i} \delta_{q y_{1j}}\}} \right] \times \left[ e^{-\beta^2 / 2b} I(\beta > 0) \right] \times \frac{1}{k_m}.   (13)

3.3. Prediction

The predictive distribution for a new point x0 is given by,

p(y_0 \mid x_0, y_1, X_1, k_m, k, \beta) = \frac{\exp\{\beta (1/k) \sum_{j \sim_{k,k_m} 0} \delta_{y_0 y_{1j}}\}}{\sum_{q=1}^{J} \exp\{\beta (1/k) \sum_{j \sim_{k,k_m} 0} \delta_{q y_{1j}}\}}.   (14)

In the predictive distribution (14), the term (1/k) Σ_{j ∼ 0} δ_{q y1j} calculates the proportion of training points, in the k nearest neighbors of x0 (a point from the test set), belonging to class q, where the distance is determined based on Σ(x0; km). We can calculate the marginal predictive distribution as

p(y_0 \mid x_0, y_1, X_1) = \sum_{k_m} \sum_{k} \int_{\beta} p(y_0 \mid x_0, y_1, X_1, k_m, k, \beta)\, \pi(k_m, k, \beta \mid y_1, X_1)\, d\beta,   (15)

where π(km, k, β | y1, X1) is given in Eq. (13). Since the integral above has no closed form, it cannot be computed analytically. We use the MCMC method to generate from the posterior π(km, k, β | y1, X1) and estimate p(y0 | x0, y1, X1) as

\hat{p}(y_0 \mid x_0, y_1, X_1) = \frac{1}{M} \sum_{g=1}^{M} p(y_0 \mid x_0, y_1, X_1, k_m^{(g)}, k^{(g)}, \beta^{(g)}),   (16)

where M is the generated MCMC sample size after the initial burn in, and {(km^(g), k^(g), β^(g)), g = 1, 2, . . . , M} are the generated MCMC samples. We calculate these probabilities for y0 = 1, 2, . . . , J and assign the new sample to the class for which the corresponding posterior predictive probability is maximum.
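To make Eq. (14) concrete, here is a minimal R sketch of the class-probability calculation for a single query point x0, assuming a helper like the dann_metric() sketch shown in Section 2 is available to supply Σ(x0; km). The function and argument names are illustrative, not the authors' code.

bann_predict_probs <- function(x0, X1, y1, km, k, beta) {
  classes <- sort(unique(y1))
  Sigma <- dann_metric(x0, X1, y1, km)             # local DANN metric Sigma(x0; km)
  diffs <- sweep(as.matrix(X1), 2, x0)             # x_{1j} - x0 for every training point
  d2    <- rowSums((diffs %*% Sigma) * diffs)      # squared distances (x_{1j} - x0)' Sigma (x_{1j} - x0)
  nb    <- y1[order(d2)[1:k]]                      # class labels of the k nearest neighbours of x0
  prop  <- sapply(classes, function(q) mean(nb == q))   # (1/k) * sum of delta_{q, y_{1j}} over the neighbourhood
  probs <- exp(beta * prop) / sum(exp(beta * prop))     # Eq. (14)
  setNames(probs, classes)
}

Averaging these probabilities over the MCMC draws of (km, k, β), as in Eq. (16), gives the estimated marginal predictive probabilities.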
3.4. Justification Against Putting a Prior on Σ

As discussed in the introduction, one drawback of k-NN and PNN is that they cannot capture the local data pattern. In order to overcome this problem, the distance metric needs to be local, depending on each query point, which is one of our motivations. That is why we adopt the DANN method. In fact, our initial thinking was to treat the local weight matrices as unknown, assign priors to these matrices, and make it a full Bayesian model. Suppose that Σ_T = {Σ11, . . . , Σ1n1, Σ21, . . . , Σ2n2}, which is unknown. Then the joint likelihood becomes

p(y_1 \mid X_1, k, \beta, \Sigma_T) = \prod_{i=1}^{n_1} \frac{\exp\{\beta (1/k) \sum_{j \sim_{k,\Sigma_{1i}} i} \delta_{y_{1i} y_{1j}}\}}{\sum_{q=1}^{J} \exp\{\beta (1/k) \sum_{j \sim_{k,\Sigma_{1i}} i} \delta_{q y_{1j}}\}},   (17)

where Wishart priors can be assigned on Σ_T, and we can use the DANN metric as the scale matrix of the Wishart distribution to capture the local pattern. However, as we can see from Eq. (17), the data can only update Σ1i, i = 1, 2, . . . , n1, for the training set in the posterior. When classifying a test point x2m, m = 1, 2, . . . , n2, the predictive distribution depends on Σ2m, which cannot be updated through the data. Therefore, there is no advantage in treating Σ_T as unknown over adopting the local DANN metric directly. There might be other solutions, but the computational complexity is far more than the achieved predictive accuracy. That is why, in this article, we did not put any prior structure on Σ_T.

3.5. BANN Algorithm

To fit our BANN model and predict classes in the test set, we use the random walk Metropolis-Hastings algorithm [16] to generate samples from π(km, k, β | y1, X1) based on the training set. Using the generated samples in Eq. (15), we do the prediction for the test set. Our detailed step-by-step MCMC algorithm is as follows.

Step 1. Start with initial values km^(0), k^(0), β^(0).

Step 2. Let g denote the gth iteration.

Step 3. Update km by using the random walk Metropolis-Hastings method.

(1) Propose km^(prop) from km^(prop) ∼ km^(g−1) ± DU[0, 1, 2, . . . , d_km]. The proposed km^(prop) must be within (km.min, km.max).
(2) The full conditional distribution for km is

\pi(k_m \mid k, \beta, y_1, X_1) \propto k_m^{-1} \prod_{i=1}^{n_1} \frac{\exp\{\beta (1/k) \sum_{j \sim_{k,k_m} i} \delta_{y_{1i} y_{1j}}\}}{\sum_{q=1}^{J} \exp\{\beta (1/k) \sum_{j \sim_{k,k_m} i} \delta_{q y_{1j}}\}}.   (18)

Note that when determining the neighborhood Nkm(x1i), the distance between any training point x1j and the query point x1i in the training set is determined by

d(x_{1i}, x_{1j}; k_m) = (x_{1j} - x_{1i})^T \Sigma(x_{1i}; k_m)(x_{1j} - x_{1i}).   (19)

For each point x1i in the training set, calculate the local distance metrics Σ(x1i; km^(prop)) and Σ(x1i; km^(g−1)). By using these local distance metrics, we compute the full conditional posteriors π(km^(prop) | k^(g−1), β^(g−1), y1, X1) and π(km^(g−1) | k^(g−1), β^(g−1), y1, X1).
(3) Accept the proposal km^(prop) with acceptance probability

r_{k_m} = \min\left\{1, \frac{\pi(k_m^{(prop)} \mid k^{(g-1)}, \beta^{(g-1)}, y_1, X_1)}{\pi(k_m^{(g-1)} \mid k^{(g-1)}, \beta^{(g-1)}, y_1, X_1)}\right\}.   (20)

The proposal step d_km is controlled to obtain around a 30% acceptance rate.

Step 4. For each observation x1i in the training set, use the updated value of km (which is km^(g)) to calculate the local distance metric Σ(x1i; km^(g)) as defined in Eq. (9).

Step 5. Update k by using the random walk Metropolis-Hastings algorithm.

(1) Propose k from k^(prop) ∼ k^(g−1) ± DU[0, 1, 2, . . . , d_k].
(2) The full conditional posterior distribution for k is

\pi(k \mid k_m, \beta, y_1, X_1) \propto \prod_{i=1}^{n_1} \frac{\exp\{\beta (1/k) \sum_{j \sim_{k,k_m} i} \delta_{y_{1i} y_{1j}}\}}{\sum_{q=1}^{J} \exp\{\beta (1/k) \sum_{j \sim_{k,k_m} i} \delta_{q y_{1j}}\}}.   (21)

Compute π(k^(prop) | km^(g), β^(g−1), y1, X1) and π(k^(g−1) | km^(g), β^(g−1), y1, X1), where for each point x1i in the training set, the distance between any point x1j in the training set and x1i is determined by using the local distance metric obtained from Step 4.
(3) The acceptance probability for k^(prop) is

r_k = \min\left\{1, \frac{\pi(k^{(prop)} \mid k_m^{(g)}, \beta^{(g-1)}, y_1, X_1)}{\pi(k^{(g-1)} \mid k_m^{(g)}, \beta^{(g-1)}, y_1, X_1)}\right\}.   (22)

The parameter d_k is controlled to get around a 30% acceptance rate. As mentioned in Section 3.2, it is intuitive to restrict k < km. In fact, in the calculations we found that there is no advantage in considering values of k beyond k^(prop) < km.max/2, so we restrict our attention to k^(prop) < km.max/2, which speeds up the algorithm.

Step 6. Update β by using the random walk Metropolis method.

(1) Propose β from β^(prop) ∼ N(β^(g−1), d_β); if β^(prop) < 0, set β^(prop) = −β^(prop).
(2) The full conditional probability for β is

\pi(\beta \mid k_m, k, y_1, X_1) \propto \left[ \prod_{i=1}^{n_1} \frac{\exp\{\beta (1/k) \sum_{j \sim_{k,k_m} i} \delta_{y_{1i} y_{1j}}\}}{\sum_{q=1}^{J} \exp\{\beta (1/k) \sum_{j \sim_{k,k_m} i} \delta_{q y_{1j}}\}} \right] \times e^{-\beta^2 / 2b} I(\beta > 0).   (23)

Compute π(β^(prop) | km^(g), k^(g), y1, X1) and π(β^(g−1) | km^(g), k^(g), y1, X1), where for each point x1i in the training set its neighbors are determined by using the local distance metric obtained from Step 4.
(3) Accept β^(prop) with probability

r_\beta = \min\left\{1, \frac{\pi(\beta^{(prop)} \mid k_m^{(g)}, k^{(g)}, y_1, X_1)}{\pi(\beta^{(g-1)} \mid k_m^{(g)}, k^{(g)}, y_1, X_1)}\right\}.   (24)

The normal variance d_β is controlled to get around a 30% acceptance rate.

Step 7. If g < G, where G is the total number of MCMC samples required, set g = g + 1 and go back to Step 3; otherwise, go to the next step.

Step 8. Let (km^(g), k^(g), β^(g)), g = 1, 2, . . . , M, denote the generated MCMC samples from the posterior distribution π(km, k, β | y1, X1) after the initial burn in. For each test point x2i in the test set X2, compute the posterior predictive distribution p(y0 | x2i, y1, X1, km^(g), k^(g), β^(g)) for each MCMC sample (km^(g), k^(g), β^(g)), g = 1, 2, . . . , M, where the distance between any point x1j in the training set and x2i is determined based on the DANN metric Σ(x2i; km^(g)).

Step 9. Estimate the marginal predictive distribution (15) using the Monte Carlo method in Eq. (16).

Step 10. Assign each test point x2i to the class q for which the corresponding estimated posterior predictive probability \hat{p}(y2i = q | x2i, y1, X1) is the maximum.
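The following R skeleton indicates how Steps 2-7 might be organized in code. It is only a sketch under our own naming assumptions: log_post(km, k, beta) is a hypothetical helper that returns the logarithm of the unnormalized full conditionals built from the likelihood (8), the local DANN metrics Σ(x1i; km), and the priors (10)-(12); the proposal widths d_km, d_k, and d_beta would be tuned to reach roughly 30% acceptance.

run_bann_mcmc <- function(log_post, G, km0, k0, beta0,
                          d_km = 2, d_k = 2, d_beta = 0.5,
                          km_min = 5, km_max = 100) {
  km <- km0; k <- k0; beta <- beta0
  out <- matrix(NA, G, 3, dimnames = list(NULL, c("km", "k", "beta")))
  for (g in 1:G) {
    # Step 3: integer random walk proposal for km, kept inside (km_min, km_max)
    km_prop <- km + sample(-d_km:d_km, 1)
    if (km_prop > km_min && km_prop < km_max &&
        log(runif(1)) < log_post(km_prop, k, beta) - log_post(km, k, beta)) km <- km_prop
    # Step 5: integer random walk proposal for k, restricted to 1 <= k <= km
    k_prop <- k + sample(-d_k:d_k, 1)
    if (k_prop >= 1 && k_prop <= km &&
        log(runif(1)) < log_post(km, k_prop, beta) - log_post(km, k, beta)) k <- k_prop
    # Step 6: normal random walk for beta, reflected at zero so that beta > 0
    beta_prop <- abs(rnorm(1, beta, d_beta))
    if (log(runif(1)) < log_post(km, k, beta_prop) - log_post(km, k, beta)) beta <- beta_prop
    out[g, ] <- c(km, k, beta)
  }
  out   # discard the burn-in rows before using the draws in Eq. (16)
}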
The MCMC steps in the algorithm are simple. However, in terms of computational complexity, Step 3 is more complicated than it appears. Because we use the DANN procedure to determine the neighborhood, it depends on the parameter km. So once we propose a new km, we need to use the DANN procedure to update all the local distance metrics and recalculate the neighborhoods. The local determination and update of the distance metric is a novel part of our likelihood-based nearest neighbor approach. For each query point, our local updating scheme of Σ(x1i; km) helps us to stretch the neighborhood along the direction of minimum class probability change. For a given km and a query point x1i, our method stretches the shape of the neighborhood in the direction of the small eigenvalues. Thus, our formulation allows the model to change adaptively: if we have less information along a direction, then our method will try to capture more information from that direction by expanding the neighborhood shape along that particular direction.

4. SIMULATIONS AND CASE STUDIES
We apply our proposed BANN algorithm to four simulated data sets with special structures and five real-life benchmark data sets. In all examples, we compare the performance of our BANN with other nearest neighbor methods, namely (i) the standard k nearest neighbor (k-NN), (ii) DANN, and (iii) PNN. For k-NN and DANN, we select k by a tenfold cross-validation procedure. Each training set is randomly partitioned into ten sub-samples. For k = 1, 5, 10, 15, 25, 50, we fit the data using nine sub-samples and predict in the remaining sub-sample. We do this for all choices of nine sub-samples as the fitting set and one sub-sample as the predicting set, and calculate the average misclassification error for each choice of k. For the final prediction in the test set, we pick the k that gave us the lowest cross-validation error. In most of the data sets, the best result is obtained by setting k = 5. In Tables 1 and 2, we report the optimal k for each data set that we found out through the tenfold cross-validation.
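As an indication of how this selection might be coded (a sketch only, assuming the class package for the k-NN baseline rather than the authors' own implementation):

cv_choose_k <- function(X, y, k_grid = c(1, 5, 10, 15, 25, 50), n_folds = 10) {
  y <- factor(y)
  folds <- sample(rep(1:n_folds, length.out = nrow(X)))       # random partition into ten sub-samples
  cv_error <- sapply(k_grid, function(k) {
    mean(sapply(1:n_folds, function(f) {
      pred <- class::knn(train = X[folds != f, , drop = FALSE],
                         test  = X[folds == f, , drop = FALSE],
                         cl    = y[folds != f], k = k)
      mean(pred != y[folds == f])                              # misclassification rate on the held-out fold
    }))
  })
  k_grid[which.min(cv_error)]                                  # the k with the lowest average CV error
}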
Table 1. Average misclassification percentage for simulated data sets.

Simulations    BANN           k-NN                DANN                PNN
Simulation 1   19.3 (0.026)   21.0 (0.032) [5]    19.7 (0.024) [5]    20.4 (0.029)
Simulation 2    8.5 (0.021)   10.4 (0.023) [5]     9.9 (0.020) [5]    10.3 (0.021)
Simulation 3   11.7 (0.012)   14.2 (0.029) [5]    12.5 (0.014) [5]    21.4 (0.030)
Simulation 4   11.0 (0.023)   13.1 (0.023) [10]   11.2 (0.021) [10]   14.8 (0.026)

Notes: The lowest misclassification errors are highlighted in bold. The values in ( ) are the corresponding standard deviations of the misclassification percentages of the ten splits. The number in [ ] in the k-NN and DANN columns is the optimal value of k (the number of nearest neighbors).
Table 2. Misclassification percentage for benchmark data sets.

Data sets   BANN           k-NN                DANN                PNN
Sonar       13.3 (0.021)   21.9 (0.036) [15]   14.4 (0.019) [5]    18.2 (0.033)
Glass       36.2 (0.040)   40.0 (0.049) [15]   37.0 (0.046) [10]   39.4 (0.043)
Vowel       39.6           48.3 [5]            42.4 [10]           49.4
Pima        25.3 (0.013)   27.7 (0.013) [10]   26.2 (0.012) [5]    25.9 (0.010)
Heart       18.0 (0.030)   18.4 (0.026) [10]   18.4 (0.034) [10]   16.2 (0.021)

Notes: The lowest misclassification errors are highlighted in bold. The values in ( ) are the corresponding standard deviations of the misclassification percentages of the ten splits. The number in [ ] in the k-NN and DANN columns is the optimal value of k (the number of nearest neighbors).
We fit the PNN by using the available MATLAB function from Chris Holmes' website (http://www.stats.ox.ac.uk/cholmes/). All other computations are completed in R. For our BANN method, an easy-to-use R function with no user-set parameters has been created; it is available on request from the author. To achieve objectivity, we use a near diffuse but proper prior for β, and discrete uniform priors for k and km, as discussed in Section 3.2. Throughout the examples, we choose b = 1000, a large variance for the prior of β. For km, we use the prior given in Eq. (12). We recommend choosing km.min and km.max so that the mean of the uniform prior distribution of km is equal to max{n1/5, 50}, the value suggested in Hastie and Tibshirani [10] as the fixed value for km. We denote km0 = max{n1/5, 50} and fix km.min and km.max as

k_{m.min} = \max\{k_{m0} - w_{k_m},\, 5\}, \quad k_{m.max} = \min\{k_{m0} + w_{k_m},\, n_1\},   (25)

where w_km is used to control the computation time. Simply fixing km.max = n1 slows down the computation if n1 is very large. For a small training set, we can use a larger w_km relative to n1, whereas for large data sets, we can choose a relatively smaller w_km compared with n1. This is reasonable since for large data sets, for example with 500 training points, we actually do not need a large km, such as 200 or more, to determine a local distance metric. Furthermore, if km is too small, then we will not get a good metric. In most of the cases, we use w_km = 50. This choice of prior settings for km ensures objectivity of our BANN model and at the same time produces a stable adaptive local metric Σ(·; km). We put a flat prior on the neighborhood size parameter k as discussed in Eq. (11). The support of k is restricted to k < km to avoid bias. This choice is reasonable, since we need fewer samples to do the classification than to determine Σ. The parameter ε in Eq. (3) avoids too much stretching of the neighborhood around a query point. In all our calculations, we fixed ε = 1. Hastie and Tibshirani [10] tried other values of ε in the set {0.01, 0.1, 0.2, 0.5, 1, 2, 5}; there seemed to be no obvious difference among the different values of ε, and ε = 1 appears to dominate. For all the examples, we ran two independent MCMC chains for 50 000 iterations, with the first half of each chain discarded as burn in. The convergence of the chains is checked by studying the trace plots and also by calculating the Gelman-Rubin diagnostics (R package coda). For all examples, the calculated Gelman-Rubin scale reduction factors are found to be close to 1, which indicates the convergence of our MCMC chains.
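As a small illustration of Eq. (25), the prior support for km can be computed as follows; km_support is our own helper name (not the authors' code), with the w_km = 50 default used for most of the data sets, and the rounding of n1/5 is our assumption.

km_support <- function(n1, w_km = 50) {
  km0 <- max(ceiling(n1 / 5), 50)          # centre of the prior, km0 = max{n1/5, 50}
  c(km.min = max(km0 - w_km, 5), km.max = min(km0 + w_km, n1))
}
km_support(250)   # for a training set of 250 points this gives km.min = 5 and km.max = 100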
4.1. Simulation Studies

In this section, we describe four simulated data sets and demonstrate how our BANN model can exploit the special pattern of the data sets for better prediction accuracy. Under each simulation framework, we generate ten independent data sets and report the average misclassification percentage over these ten independent sets. The detailed simulation frameworks are as follows:

• Simulation 1: Two classes with two covariates (X1, X2), that is, J = 2 and p = 2. For class 1, we generated (X1, X2) uniformly such that X1^2 + X2^2 = 1 (on the circumference of a circle with radius 1) and then added independent N(0, 0.08) errors to X1 and X2, respectively. For class 2, we generated (X1, X2) uniformly such that X1^2 + X2^2 < 1 (inside the circle with radius 1) and then added independent N(0, 0.08) errors to X1 and X2, respectively. We have n1 = 250 (training set) and n2 = 250 (test set). The generated point pattern is shown in Fig. 2(a).

• Simulation 2: Two classes with two covariates (X1, X2), which implies J = 2 and p = 2. For class 1, we generated (X1, X2) uniformly such that X1^2 + 2X2^2 = 1 (the circumference of an ellipse with semimajor axis of length 1 and semiminor axis of length 1/2) and then added independent N(0, 0.05) errors to X1 and X2, respectively. For class 2, we generated (X1, X2) uniformly such that X1^2 + 2X2^2 < 1 (inside the ellipse) and then added independent N(0, 0.05) errors to X1 and X2, respectively. We have n1 = 200 (training set) and n2 = 200 (test set). The generated point pattern is shown in Fig. 2(b).

• Simulation 3: Three classes with two covariates (X1, X2), which implies J = 3 and p = 2. For class 1, we generated (X1, X2) uniformly such that X1^2 + X2^2 = (0.6)^2, that is, on the circumference of a circle with radius 0.6. For class 2, we generated (X1, X2) uniformly such that X1^2 + X2^2 = (0.5)^2, that is, on the circumference of a circle with radius 0.5. For class 3, we generated (X1, X2) uniformly such that X1^2 + X2^2 < (0.5)^2 (inside the circle with radius 0.5). In all three classes we added independent N(0, 0.04) errors to X1 and X2, respectively. We have n1 = 300 (training set) and n2 = 240 (test set). The generated point pattern is shown in Fig. 2(c).

• Simulation 4: Five classes with two covariates (X1, X2), which implies J = 5 and p = 2. For class 1, we generated (X1, X2) uniformly such that X1 = 0 or X2 = 0, that is, on the horizontal and vertical axes of the coordinate system. For class 2, we generated points uniformly from the first quadrant; for class 3, uniformly from the second quadrant; for class 4, uniformly from the third quadrant; and for class 5, uniformly from the fourth quadrant. In all cases, we added independent N(0, 0.01) errors to X1 and X2, respectively. We have n1 = 250 (training set) and n2 = 250 (test set). The generated point pattern is shown in Fig. 2(d).

Fig. 3 Boxplots of misclassification rates over ten splits for four simulations (B = BANN, k = k-NN, D = DANN, P = PNN).
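For illustration, a minimal R sketch of the Simulation 1 data-generating process described above might look as follows. How the 250 points of each set are divided between the two classes, and whether 0.08 denotes a variance or a standard deviation, are not spelled out above, so the choices below (equal class sizes, sd = 0.08) are our own assumptions.

sim1_data <- function(n_per_class = 125, sd = 0.08) {
  th1 <- runif(n_per_class, 0, 2 * pi)                    # class 1: uniform on the unit circle
  th2 <- runif(n_per_class, 0, 2 * pi)
  r2  <- sqrt(runif(n_per_class))                         # sqrt(U) radius gives points uniform inside the circle
  X   <- rbind(cbind(cos(th1), sin(th1)),
               r2 * cbind(cos(th2), sin(th2)))            # class 2: uniform inside the unit circle
  X   <- X + matrix(rnorm(4 * n_per_class, 0, sd), ncol = 2)   # independent N(0, 0.08) errors on X1 and X2
  data.frame(x1 = X[, 1], x2 = X[, 2], class = rep(1:2, each = n_per_class))
}
train <- sim1_data()   # one call per data set; a second call gives the test set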
From Fig. 2, we can see that in all cases there is a special data pattern, so we need an adaptive neighborhood selection to obtain a better class prediction. The average misclassification percentage for each simulated data set is reported in Table 1. The numbers in the parentheses are the corresponding standard deviations of the misclassification percentages. From Table 1 and Fig. 3, we can see that our BANN method consistently outperforms the other nearest neighbor methods and gives the lowest misclassification rate. Compared with DANN, our BANN method provides a probabilistic interpretation without losing the ability to catch the local data pattern. Unlike DANN, km and k are not fixed, and we select km and k adaptively using the data. This is an added flexibility in our model. Often the performance of k-NN and DANN is dependent on the choice of the neighborhood sizes. Our BANN model avoids this problem and automatically chooses the best neighborhood size by maximizing the posterior. In Fig. 4, we plot the posterior histograms of k and km. Figure 5 illustrates the neighborhoods of BANN compared with those of k-NN by using the posterior medians of km and k. The BANN neighborhoods capture the local pattern of the simulated data, extending along the boundaries and shrinking in the direction orthogonal to the boundaries. For illustration purposes, Fig. 6 gives the trace plots for one of the ten simulated data sets of Simulation 1. From the plot, we can see that the MCMC chain converges very well for the parameters km, k, and β. The posterior medians of km and k for this MCMC sample are 18 and 7, which are used in the neighborhood plot (Fig. 5).

4.2. Case Studies

In this section, we apply our BANN model to five real-life benchmark data sets and compare the results from BANN with k-NN, DANN, and PNN. For all the examples in this section, the data sets are standardized before applying the algorithm.

4.2.1. Sonar data

This data set is included in the dprep package in R. It contains 208 observations, 60 predictors (p = 60), and two classes (J = 2). The two classes are 'mine' and 'rock', coded as 1 and 0, respectively. We split this data set into training and test sets with 104 observations in each of them, which implies n1 = n2 = 104. We split the data set randomly ten times and compute the average misclassification error. The misclassification result for the sonar data under the BANN model is listed in the first row of Table 2, along with the other nearest neighbor methods. The performance of our BANN method is 8.6% better than the conventional five-nearest neighbor, and around 5% better than the PNN.
Fig. 4 Histograms of posterior km and k for four simulations.
Fig. 5 BANN neighborhoods compared with k-NN neighborhoods. The left panel is for Simulation 1, using the posterior medians km = 18 and k = 7. The right panel is for Simulation 4, using the posterior medians km = 42 and k = 6. km is used for estimating the local metrics and k is the number of neighbors used to build up the neighborhoods. The little triangles are the query points in the test set. The navy-blue ellipses represent the BANN neighborhoods and the orange circles are the k-NN neighborhoods.
Moreover, our BANN also outperforms DANN. Although the margin is relatively small (1%), our BANN is more appealing because it has a nice probabilistic interpretation. The posterior plot of k (Fig. 7) indicates that a specific neighborhood size is detected through the data.

4.2.2. Glass data

This data set is obtained from the UCI repository [17]. The original data set contains 214 observations with seven types of glass. There are p = 9 attributes. We follow Ripley [18] to regroup the data into four classes and omit 29 observations, which implies J = 4 and n = 185. We randomly split the 185 samples into a training set of size n1 = 89 and a test set of size n2 = 96. We repeat the split ten times and compute the average misclassification rate. The misclassification results from our BANN and the other nearest neighbor methods are listed in Table 2. We can see that BANN and DANN have around 4% lower misclassification errors than k-NN and PNN. Moreover, our BANN has a slightly lower misclassification error than DANN, so for the glass data set our BANN is again the winner. The posterior plot (Fig. 7) suggests a posterior median of k = 3, and hence for this data set the best neighborhood size is k = 3.
Fig. 6 Illustrative trace plots for Simulation 1 and the sonar data set.
Fig. 7 Histograms of posterior km and k for five benchmark data sets.
4.2.3. Vowel data

This data set is also taken from the UCI repository [17]. It includes a training set of size n1 = 528 and a test set of size n2 = 462. There are p = 10 predictors and J = 11 classes. We keep the original training and test set split for the vowel data, as done by other authors. The third row of Table 2 shows the misclassification percentages for the vowel data. Our BANN achieves around a 10% lower misclassification rate than 5-NN and PNN, and around a 3% lower misclassification rate than DANN, which is very encouraging. The vowel data is one of the most difficult data sets to classify. The effectiveness of our BANN model is demonstrated by the fact that we are able to achieve the lowest misclassification error (39.6%), beating all other methods.
4.2.4. Pima Indians Diabetes

This data set consists of 768 observations, two classes, and eight attributes (which implies n = 768, J = 2, and p = 8). It is obtained from the UCI repository [17]. We randomly split it into a training set with n1 = 368 points and a test set with n2 = 400 points, and we repeated this ten times. The average misclassification results are given in Table 2, showing a significant improvement of our BANN over the other nearest neighbor methods. For this data set, our BANN also achieves, on average, the highest correct prediction rate.
4.2.5. Heart data

Holmes and Adams [11] used this data set to illustrate their PNN model. There are n = 270 observations, p = 13 predictors, and J = 2 classes; the data are also obtained from the UCI repository [17]. It is randomly split into a training set (n1 = 135) and a test set (n2 = 135) ten times. The average misclassification rates for our BANN model and the three other competing nearest neighbor methods are reported in Table 2. For this data set, the PNN gives the lowest misclassification rate; compared with the BANN, the PNN achieves a 1.8% lower misclassification rate. Our BANN does better than DANN and k-NN, which indicates the strength of having an adaptive selection of the neighborhood sizes k and km. The posterior histograms of km and k are plotted for each real data set in Fig. 7. In terms of computation time, for the heart, glass, and sonar data sets (with around 100 samples in the training set), our R code took less than 4 h to run two independent MCMC chains of 50 000 iterations. For the vowel and the Pima Indians Diabetes data sets (with around 400 samples in the training set), it took approximately 10 h.
5. DISCUSSION
In this article, we proposed a BANN method, which provides us with a way to overcome the drawbacks of k-NN, DANN, and PNN. Our BANN method is highly flexible in detecting any local pattern in the data and can adapt to it, as illustrated in Fig. 5. Instead of fixing km as in the DANN method, we assign a near diffuse but proper prior on km, which solves the difficulty of choosing km. Similarly, the neighborhood size k is also kept free by putting a prior on it. Thus, we select the optimal km and k adaptively by maximizing the posterior through the MCMC steps. Interestingly, for class prediction, the posterior predictive probability is calculated by averaging over different values of k and km. Our ability to make use of several values of k and km gives us added flexibility. Moreover, the BANN algorithm provides a probabilistic interpretation, where the uncertainty is propagated from the prior distributions of (km, k, β) and our likelihood (8) based on the nearest neighbor formulation. There is no strong assumption about the distribution of the predictors. In Section 4, our proposed BANN method demonstrates substantial improvement over all other nearest neighbor methods in nearly all the applications. From all the simulations and case studies, we can see that our BANN method is consistently better than the other nearest neighbor methods; only for the heart data does PNN have a lower misclassification rate than BANN. Furthermore, from the posterior plots, we can get a good idea about the values of k and km. Although we put non-informative priors on k and km, from Figs. 4 and 7 we notice that the posterior is highly informative. This implies that the data have strong information about the best choice of the neighborhood sizes k and km. In this article, we have used the DANN metric for calculating distances, which cannot be used directly for 'large p, small n' data sets like microarrays. All the applications in Section 4 are small p, large n. For high-dimensional data sets, like microarray data, we can first perform dimension reduction and then use the DANN metric and apply our BANN method. It would be interesting to investigate a better distance metric which can be used for high-dimensional data.
Acknowledgment

The authors would like to thank the University of Missouri Bioinformatics Consortium for letting us use their Dell EM64T cluster system (Lewis) for all computational needs.
REFERENCES

[1] B. D. Ripley, Pattern Recognition and Neural Networks, Cambridge, Cambridge University Press, 1996.
[2] V. Vapnik, The Nature of Statistical Learning Theory, New York, Springer, 1995.
[3] L. Breiman, J. Friedman, C. J. Stone, and R. A. Olshen, Classification and Regression Trees (1st ed.), New York, Chapman & Hall/CRC, 1984.
[4] L. Breiman, Random forests, Machine Learn 45(1) (2001), 5–32.
[5] D. Denison, C. Holmes, B. Mallick, and A. F. M. Smith, Bayesian Methods for Nonlinear Classification and Regression, New York, Wiley, 2002.
[6] M. Ghosh, T. Maiti, D. Kim, S. Chakraborty, and A. Tewari, Hierarchical Bayesian neural networks: an application to prostate cancer study, J Am Stat Assoc 99 (2004), 601–608.
[7] S. Chakraborty, M. Ghosh, T. Maiti, and A. Tewari, Hierarchical Bayesian neural networks for bivariate binary data: an application to prostate cancer study, Stat Med 24(23) (2005), 3645–3662.
[8] B. K. Mallick, D. Ghosh, and M. Ghosh, Bayesian classification of tumors using gene expression data, J R Stat Soc, Series B 67 (2005), 219–232.
[9] T. Hastie and R. Tibshirani, Discriminant adaptive nearest neighbor classification, IEEE Trans Pattern Anal Machine Intelligence 18 (1996), 607–616.
[10] C. C. Holmes and N. M. Adams, A probabilistic nearest neighbour method for statistical pattern recognition, J R Stat Soc, Series B: Stat Methodol 64 (2002), 295–306.
[11] S. Banerjee, B. P. Carlin, and A. E. Gelfand, Hierarchical Modeling and Analysis for Spatial Data, Boca Raton, Chapman & Hall/CRC, 2003.
[12] S. Geman and D. Geman, Stochastic relaxation, Gibbs distributions, and the Bayesian restoration of images, IEEE Trans Pattern Anal Machine Intelligence 6 (1984), 721–741.
[13] J. Besag, Spatial interaction and the statistical analysis of lattice systems (with discussions), J R Stat Soc, Series B 36 (1974), 192–236.
[14] C. P. Robert and G. Casella, Monte Carlo Statistical Methods, Springer-Verlag, 1999.
[15] A. Asuncion and D. Newman, UCI machine learning repository, University of California, Irvine, School of Information and Computer Sciences, 2007. http://www.ics.uci.edu/~mlearn/MLRepository.html.
[16] B. D. Ripley, Neural networks and related methods for classification (Disc: P437–456), J R Stat Soc, Series B: Methodological 56 (1994), 409–437.
[17] S. Chakraborty, Simultaneous cancer classification and gene selection with Bayesian nearest neighbor method: an integrated approach, Computational Statistics and Data Analysis 53(4) (2008), 1462–1474.
[18] D. Gamerman and H. F. Lopes, Markov Chain Monte Carlo: Stochastic Simulation for Bayesian Inference (2nd ed.), New York, Chapman & Hall/CRC, 2006.