Lower Bounds for High-Dimensional Statistical Learning Problems via a Direct-Sum Theorem

Ankit Garg∗    Tengyu Ma†    Huy L. Nguyễn‡

May 7, 2014
Abstract

We explore the connection between dimensionality and communication cost in distributed learning problems. Specifically, we study the problem of estimating the mean $\vec{\theta}$ of an unknown $d$-dimensional normal distribution in the distributed setting, where the samples from the unknown distribution are spread across $m$ different machines. The goal is to estimate the mean $\vec{\theta}$ at the optimal minimax rate while communicating as few bits as possible. We show that even in this simple setting, the communication cost must scale linearly in the number of dimensions, i.e., one essentially needs to deal with each dimension individually.
1 Introduction
The last decade has witnessed a tremendous growth in the amount of data involved in machine learning tasks. In many cases, data volume has outgrown the capacity of a single machine, and it is increasingly common that learning tasks are performed in a distributed fashion on many machines. Besides traditional aspects of computation such as running time and memory usage, communication has emerged as an important resource and is sometimes the bottleneck of the whole system. Many recent works in machine learning are devoted to understanding the amount of communication needed in distributed learning tasks [BBFM12, IPSV12b, IPSV12a, ZDJW13].

In this paper, we study the relation between the dimensionality and the communication complexity of statistical estimation problems. Most modern statistical problems are characterized by high dimensionality. Thus, it is natural to ask the following meta question:

How does the communication cost scale with the dimensionality?

We study this question via the problem of estimating the mean $\vec{\theta}$ of an unknown $d$-dimensional normal distribution in the distributed setting. In this problem, the samples from the unknown distribution are distributed among $m$ different machines. The goal is to estimate the mean $\vec{\theta}$ at the optimal minimax rate while communicating as few bits as possible. We show that in this simplest setting, one really needs to deal with each dimension individually.

Theorem 1.1. [Informal] To estimate the mean of a $d$-dimensional Gaussian in the distributed setting with error R, one must pay $\Omega(d)$ times the minimum communication cost needed for estimating the mean of a one-dimensional Gaussian with error $R/d$.
∗ Department of Computer Science, Princeton University, email: [email protected].
† Department of Computer Science, Princeton University, email: [email protected].
‡ Department of Computer Science, Princeton University, email: [email protected].
The work [ZDJW13] showed a lower bound on the communication cost for this problem when d = 1. Our technique, when applied to their theorem, immediately yields a lower bound equal to d times the lower bound for the one-dimensional problem, for any choice of d. See Theorem 3.3 for the precise statement.

We use tools from recent developments in communication complexity and information complexity. There has been a lot of work on the paradigm of studying communication complexity via the notion of information complexity [CSWY01, BYJKS04, BR11, BBCR13, BEO+13]. Information complexity can be thought of as a proxy for communication complexity that is especially accurate for solving multiple copies of the same problem simultaneously [BR11]. It has become a standard tool for proving so-called "direct-sum" results, namely that the amount of resources required for solving d copies of a problem in parallel is equal to d times the amount required for one copy. In other words, there is no saving from solving many copies of the same problem in batch, and the trivial solution of solving each of them separately is optimal. Our result can be viewed as a direct-sum theorem for the communication complexity of statistical estimation problems: the amount of communication needed for solving an estimation problem in d dimensions is d times the amount of communication needed for the same problem in one dimension. The proof technique is directly inspired by the notion of conditional information complexity [BYJKS04], which was used to prove direct-sum theorems and lower bounds for streaming algorithms. We believe this is a fruitful connection that can lead to more lower bounds in statistical machine learning.

Our techniques. Consider the problem of estimating some parameter $\theta \in \mathbb{R}^d$ of a $d$-dimensional distribution from samples. To prove a lower bound for the $d$-dimensional problem using an existing lower bound for the one-dimensional problem, we exhibit a reduction that uses a (hypothetical) protocol for d dimensions to construct a better protocol for the one-dimensional problem. The new protocol for the one-dimensional problem works as follows: the one-dimensional problem is embedded into a random coordinate of the $d$-dimensional problem, and the remaining coordinates are filled in independently according to the prior of θ. Since the other coordinates are independent of the input, they can be shared among the machines before the protocol starts. When the machines receive their samples, they simulate the protocol for the $d$-dimensional problem.

The first question to ask is: what is this "simulation" protocol for the one-dimensional problem good for? It has the same communication cost as the protocol for the $d$-dimensional problem, which by itself does not seem to help. It turns out that while the communication cost remains the same, the information cost goes down by a factor of d: the information is smeared across all the dimensions. This is consistent with a general paradigm in mathematics where certain discrete quantities are not well behaved, but their continuous relaxations are, and the relaxations are used as proxies to study the discrete quantities. Here we use information cost as a proxy for communication cost. Using the "simulation" protocol, we prove that the information cost of the $d$-dimensional problem is d times the information cost of the one-dimensional problem.
Now, using the fact that the information cost is at most the communication cost, we obtain a lower bound on the communication cost of the $d$-dimensional problem in terms of the information cost of the one-dimensional problem. The work [ZDJW13] already gives a lower bound on the information cost of the one-dimensional problem.
2 Notation and Setup
Statistical parameter estimation: Let $\mathcal{P}$ be a family of distributions over $\mathcal{X}$. Let $\theta : \mathcal{P} \to \Theta$ denote a function defined on $\mathcal{P}$. We are given samples $X_1, \dots, X_n$ from some $P \in \mathcal{P}$, and are asked to estimate $\theta(P)$. Let $\hat{\theta} : \mathcal{X}^n \to \Theta$ be such an estimator, and $\hat{\theta}(X_1, \dots, X_n)$ the corresponding estimate. Define the squared loss R of the estimator to be
\[
R(\hat{\theta}, \theta) = \mathbb{E}_{\hat{\theta}, X, \theta}\left[ \| \hat{\theta}(X_1, \dots, X_n) - \theta(P) \|_2^2 \right].
\]
In the high-dimensional case, let $\mathcal{P}^d := \{\vec{P} = P_1 \times \dots \times P_d : P_i \in \mathcal{P}\}$ be the family of product distributions over $\mathcal{X}^d$. Let $\vec{\theta} : \mathcal{P}^d \to \Theta^d \subset \mathbb{R}^d$ be the $d$-dimensional function obtained by applying θ coordinate-wise: $\vec{\theta}(P_1 \times \dots \times P_d) = (\theta(P_1), \dots, \theta(P_d))$.

Throughout this paper, we consider the case $\mathcal{P} = \{N(\theta, \sigma^2) : \theta \in [-1, 1]\}$ for some fixed σ. Therefore, in the high-dimensional case, $\mathcal{P}^d = \{N(\vec{\theta}, \sigma^2 I_d) : \vec{\theta} \in [-1, 1]^d\}$. We use $\hat{\vec{\theta}}$ to denote the $d$-dimensional estimator. For clarity, in this paper we always use $\vec{\cdot}$ to indicate a vector in high dimensions.

Multi-machine setting: There are m machines. Machine j receives n samples $\vec{X}^{(j,1)}, \dots, \vec{X}^{(j,n)} \in \mathcal{X}^d$ from the distribution $\vec{P}$. The machines communicate via a publicly shown blackboard: when a machine writes a message on the blackboard, all other machines can see its content. Note that this model captures both point-to-point communication and broadcast communication; therefore, our lower bounds in this model apply to both the message-passing setting and the broadcast setting. We denote the transcript of the communication by Y. A deterministic function $\hat{\vec{\theta}}$ is then applied to the transcript Y to obtain the estimate $\hat{\vec{\theta}}(Y)$ of the mean. The letter j is reserved for the index of the machine, k for the index of the sample, and i for the dimension. In other words, $\vec{X}^{(j,k)}_i$ is the i-th coordinate of the k-th sample of machine j.

Private/public randomness: We allow the protocol to use both private and public randomness, which is crucial. The public randomness is used purely for convenience in the proof and is not counted toward the total communication, because it can be shared among the machines before the start of the protocol. Alternatively, because the protocol works well on average over the public randomness, there exists a fixing of the public randomness under which the protocol performs at least as well as the average; this fixing yields a protocol that uses no public randomness at all while performing just as well. The use of private randomness, on the other hand, is extremely crucial: machines can use private randomness to hide information from other machines. Indeed, we will see that in the direct-sum argument, private randomness plays a very important role in the simulation protocol for one dimension. Let us denote the public and private randomness of the protocol by $R_{pub}$ and $R_{priv}$, respectively. We define the squared loss of a protocol Π by
\[
R\big((\Pi, \hat{\vec{\theta}}), \vec{\theta}\big) = \mathbb{E}_{\vec{\theta}, \vec{X}, Y, R_{pub}, R_{priv}}\left[ \| \hat{\vec{\theta}}(Y) - \vec{\theta} \|^2 \right].
\]
Information cost: We define the information cost IC(Π) of a protocol Π as
\[
IC(\Pi) = I(\vec{X}; Y \mid \vec{\theta}, R_{pub}).
\]
Private randomness does not explicitly appear in the definition of the information cost, but it affects it. Note that the information cost is a lower bound on the communication cost:
\[
IC(\Pi) = I(\vec{X}; Y \mid \vec{\theta}, R_{pub}) \le H(Y \mid R_{pub}) \le H(Y) \le \text{length of } Y.
\]
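For concreteness, the following is a minimal Python sketch (our illustration, not part of the paper) of the multi-machine setting with a naive protocol: every machine writes its $d$-dimensional sample mean on the blackboard, and the estimator averages the reported means. The prior on $\vec{\theta}$, the 32-bit encoding of each coordinate, and all names in the snippet are our own assumptions; the point is only to make the squared loss and the (trivial) transcript length concrete — note how the transcript length already scales linearly in d.

```python
import numpy as np

rng = np.random.default_rng(0)

d, m, n, sigma = 8, 16, 32, 1.0            # dimensions, machines, samples per machine, noise level
theta = rng.uniform(-1.0, 1.0, size=d)     # unknown mean, drawn from an illustrative prior on [-1,1]^d

# Each machine j observes n i.i.d. samples from N(theta, sigma^2 I_d).
samples = theta + sigma * rng.standard_normal(size=(m, n, d))

# Naive protocol: machine j writes its local sample mean on the blackboard.
transcript = samples.mean(axis=1)          # shape (m, d); the "blackboard" contents Y

# Estimator applied to the transcript: average of the machines' local means.
theta_hat = transcript.mean(axis=0)

loss = np.sum((theta_hat - theta) ** 2)    # squared loss ||theta_hat - theta||^2
comm_bits = transcript.size * 32           # communication if each coordinate is sent as a 32-bit float

print(f"squared loss  : {loss:.4f}  (minimax rate d*sigma^2/(m*n) = {d * sigma**2 / (m * n):.4f})")
print(f"communication : {comm_bits} bits  (scales linearly with d)")
```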
3 Distributed Statistical Learning
We start by formally defining our task and the mean-squared loss and information cost of a protocol.

Definition 1. We say a protocol and estimator pair $(\Pi, \hat{\vec{\theta}})$ solves the task $T(d, m, n, \sigma^2, \mathcal{D}_\theta^d)$ with information cost C and mean-squared loss R if, for $\vec{\theta}$ randomly chosen from $\mathcal{D}_\theta^d$, the m machines, each of which takes n samples from $N(\vec{\theta}, \sigma^2 I_d)$ as input, can run the protocol Π and obtain a transcript Y such that
\[
\mathbb{E}\big[ \| \hat{\vec{\theta}}(Y) - \vec{\theta} \|^2 \big] = R, \tag{1}
\]
\[
I(\vec{X}; Y \mid \vec{\theta}) = C. \tag{2}
\]
Theorem 3.1. [Direct-sum theorem] If $(\Pi, \hat{\vec{\theta}})$ solves the task $T(d, m, n, \sigma^2, \mathcal{V}^d)$ with information cost C and squared loss R, then there exists $(\Pi', \hat{\theta})$ that solves the task $T(1, m, n, \sigma^2, \mathcal{V})$ with information cost at most 4C/d and squared loss 4R/d.

Proof. For each i ∈ [d], we define the following protocol $\Pi_i$ and estimator $\hat{\theta}_i$ induced by the $d$-dimensional estimator $\hat{\vec{\theta}}$. $\Pi_i$ is described as Protocol 1.

Inputs: Machine j gets samples $X^{(j,1)}, \dots, X^{(j,n)}$ distributed according to $N(\theta, \sigma^2)$, where $\theta \sim \mathcal{V}$.

1. All machines publicly sample $\breve{\theta}_{-i}$ distributed according to $\mathcal{V}^{d-1}$.

2. Machine j privately samples $\breve{X}^{(j,1)}_{-i}, \dots, \breve{X}^{(j,n)}_{-i}$ distributed according to $N(\breve{\theta}_{-i}, \sigma^2 I_{d-1})$. Let $\breve{X}^{(j,k)} = (\breve{X}^{(j,k)}_1, \dots, \breve{X}^{(j,k)}_{i-1}, X^{(j,k)}, \breve{X}^{(j,k)}_{i+1}, \dots, \breve{X}^{(j,k)}_d)$.

3. All machines run protocol Π on the data $\breve{X}$ and get transcript $Y_i$. The estimator $\hat{\theta}_i$ is $\hat{\theta}_i(Y_i) = \hat{\vec{\theta}}(Y_i)_i$, i.e., the i-th coordinate of the $d$-dimensional estimator.

Protocol 1: $\Pi_i$

The role of private randomness can be seen crucially here: it is very important that the machines privately generate the samples in the coordinates other than i, for the information cost to go down by a factor of d. An illustrative sketch of this embedding is given below.
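To make the embedding concrete, here is a Python sketch (our own illustration, not part of the paper) of how a single machine would assemble its $d$-dimensional input in Protocol 1: the genuine one-dimensional samples are placed in the publicly chosen coordinate i, while the remaining coordinates are filled with privately generated samples around the publicly shared $\breve{\theta}_{-i}$. Names such as `embed_samples`, and the choice $\mathcal{V} = \text{uniform}\{\pm\delta\}$, are hypothetical illustrative choices.

```python
import numpy as np

def embed_samples(x, i, theta_rest, sigma, rng):
    """Protocol 1, step 2, for one machine (illustrative sketch).

    x          : (n,) real samples of the one-dimensional problem, ~ N(theta, sigma^2)
    i          : coordinate (0-based) in which the 1-D problem is embedded
    theta_rest : (d-1,) publicly sampled means for the other coordinates (theta_breve_{-i})
    Returns an (n, d) array of fabricated d-dimensional samples X_breve.
    """
    n, d = len(x), len(theta_rest) + 1
    # Privately sample the other d-1 coordinates around the publicly shared means.
    x_rest = theta_rest + sigma * rng.standard_normal(size=(n, d - 1))
    x_full = np.empty((n, d))
    x_full[:, :i] = x_rest[:, :i]      # coordinates 1, ..., i-1
    x_full[:, i] = x                   # the genuine 1-D samples go into coordinate i
    x_full[:, i + 1:] = x_rest[:, i:]  # coordinates i+1, ..., d
    return x_full

# Illustrative usage for a single machine:
rng = np.random.default_rng(1)
d, n, sigma, delta = 8, 32, 1.0, 0.1
i = rng.integers(d)                                  # coordinate chosen (publicly) at random
theta = delta * rng.choice([-1, 1])                  # true 1-D mean, theta ~ V = uniform{+-delta}
theta_rest = delta * rng.choice([-1, 1], size=d - 1) # public sample from V^{d-1}
x = theta + sigma * rng.standard_normal(n)           # the machine's real 1-D samples
x_breve = embed_samples(x, i, theta_rest, sigma, rng)
# The machines would now run the d-dimensional protocol Pi on x_breve and keep coordinate i
# of the d-dimensional estimate as their one-dimensional estimate.
```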
Let us denote the private and public randomness of the protocol $\Pi_i$ by $R_{priv}$ and $R_{pub}$, respectively. We prove that $\Pi_i$ does a good job in the average sense via the following two lemmas.

Lemma 1. If $\theta \sim \mathcal{V}$ and $\vec{\theta} \sim \mathcal{V}^d$, then
\[
\sum_{i=1}^d R\big((\Pi_i, \hat{\theta}_i), \theta\big) = R\big((\Pi, \hat{\vec{\theta}}), \vec{\theta}\big).
\]
Proof. Note that
\[
R\big((\Pi_i, \hat{\theta}_i), \theta\big) = \mathbb{E}_{\theta, X, Y_i, R_{priv}, R_{pub}}\big[ (\hat{\theta}_i(Y_i) - \theta)^2 \big] = \mathbb{E}_{\theta, X, Y_i, R_{priv}, R_{pub}}\big[ (\hat{\vec{\theta}}(Y_i)_i - \theta)^2 \big].
\]
Hence
\[
\sum_{i=1}^d R\big((\Pi_i, \hat{\theta}_i), \theta\big)
= \sum_{i=1}^d \mathbb{E}_{\theta, X, Y_i, R_{priv}, R_{pub}}\big[ (\hat{\vec{\theta}}(Y_i)_i - \theta)^2 \big]
= \mathbb{E}_{\vec{\theta}, \vec{X}, Y}\Big[ \sum_{i=1}^d (\hat{\vec{\theta}}(Y)_i - \vec{\theta}_i)^2 \Big]
= \mathbb{E}_{\vec{\theta}, \vec{X}, Y}\big[ \| \hat{\vec{\theta}}(Y) - \vec{\theta} \|^2 \big].
\]
The second equality follows from the fact that, for each i, the joint distribution of $(\theta, X, Y_i, R_{priv}, R_{pub})$ under $\Pi_i$ is the same as the distribution of $(\vec{\theta}, \vec{X}, Y)$ under Π, with θ playing the role of coordinate i of $\vec{\theta}$; in particular, the marginal distribution of each $\vec{\theta}_i$ is the same as the distribution of θ.
Lemma 2. If $\theta \sim \mathcal{V}$ and $\vec{\theta} \sim \mathcal{V}^d$, then
\[
\sum_{i=1}^d IC(\Pi_i) \le IC(\Pi).
\]

Proof. Recall that under $(\Pi_i, \hat{\theta}_i)$, the machines prepare $\breve{X}$, which has the same distribution as $\vec{X}$ in the problem $T(d, m, n, \sigma^2, \mathcal{V}^d)$. Moreover, the joint distribution of $(\vec{X}_i, Y, \vec{\theta})$ is the same as that of $(X, Y_i, \theta, \breve{\theta}_{-i})$. Therefore, we have
\[
I(\vec{X}_i; Y \mid \vec{\theta}) = I(X; Y_i \mid \theta, \breve{\theta}_{-i}).
\]
Since $IC(\Pi_i) = I(X; Y_i \mid \theta, R_{pub}) = I(X; Y_i \mid \theta, \breve{\theta}_{-i})$, we have
\[
\sum_{i=1}^d IC(\Pi_i) = \sum_{i=1}^d I(X; Y_i \mid \theta, \breve{\theta}_{-i}) = \sum_{i=1}^d I(\vec{X}_i; Y \mid \vec{\theta}).
\]
Since the distribution of $\vec{X}$ conditioned on $\vec{\theta}$ is $N(\vec{\theta}, \sigma^2 I_d)$, the coordinates $\vec{X}_1, \dots, \vec{X}_d$ are independent conditioned on $\vec{\theta}$. Hence
\[
\sum_{i=1}^d I(\vec{X}_i; Y \mid \vec{\theta}) \le I(\vec{X}; Y \mid \vec{\theta}) = IC(\Pi).
\]
The inequality is true because of the following:
\[
\begin{aligned}
I(\vec{X}; Y \mid \vec{\theta}) &= \sum_{i=1}^d I(\vec{X}_i; Y \mid \vec{\theta}, \vec{X}_1, \dots, \vec{X}_{i-1}) \\
&= \sum_{i=1}^d \Big( H(\vec{X}_i \mid \vec{\theta}, \vec{X}_1, \dots, \vec{X}_{i-1}) - H(\vec{X}_i \mid Y, \vec{\theta}, \vec{X}_1, \dots, \vec{X}_{i-1}) \Big) \\
&= \sum_{i=1}^d \Big( H(\vec{X}_i \mid \vec{\theta}) - H(\vec{X}_i \mid Y, \vec{\theta}, \vec{X}_1, \dots, \vec{X}_{i-1}) \Big) \\
&\ge \sum_{i=1}^d \Big( H(\vec{X}_i \mid \vec{\theta}) - H(\vec{X}_i \mid Y, \vec{\theta}) \Big) \\
&= \sum_{i=1}^d I(\vec{X}_i; Y \mid \vec{\theta}).
\end{aligned}
\]
The third equality is true because $\vec{X}_1, \dots, \vec{X}_d$ are independent conditioned on $\vec{\theta}$, and the inequality follows from the fact that conditioning decreases entropy.
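The inequality above is the standard superadditivity of mutual information over (conditionally) independent coordinates, and it can be far from tight. As a sanity check (our toy example, not from the paper), the following exact computation over a discrete example shows that the two sides can differ maximally: for independent uniform bits $X_1, X_2$ and "transcript" $Y = X_1 \oplus X_2$, each coordinate alone carries zero information about Y, while jointly they determine it.

```python
import itertools
import math
from collections import Counter

def mutual_information(pairs):
    """Exact mutual information I(A;B) in bits from a list of equiprobable (a, b) outcomes."""
    n = len(pairs)
    p_ab = Counter(pairs)
    p_a = Counter(a for a, _ in pairs)
    p_b = Counter(b for _, b in pairs)
    return sum((c / n) * math.log2((c / n) / ((p_a[a] / n) * (p_b[b] / n)))
               for (a, b), c in p_ab.items())

# Toy example: X = (X1, X2) uniform over {0,1}^2, transcript Y = X1 XOR X2.
outcomes = [((x1, x2), x1 ^ x2) for x1, x2 in itertools.product([0, 1], repeat=2)]

I_joint = mutual_information(outcomes)                      # I(X1, X2 ; Y) = 1 bit
I_1 = mutual_information([(x[0], y) for x, y in outcomes])  # I(X1 ; Y) = 0 bits
I_2 = mutual_information([(x[1], y) for x, y in outcomes])  # I(X2 ; Y) = 0 bits

print(I_1 + I_2, "<=", I_joint)  # the sum of per-coordinate informations lower-bounds
                                 # the joint information, here with maximal slack
```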
By Lemma 1 and Lemma 2 and a Markov argument, there exists an i ∈ {1, . . . , d} such that
\[
R\big((\Pi_i, \hat{\theta}_i), \theta\big) \le \frac{4}{d} \cdot R\big((\Pi, \hat{\vec{\theta}}), \vec{\theta}\big) \quad \text{and} \quad IC(\Pi_i) \le \frac{4}{d} \cdot IC(\Pi).
\]
Indeed, by Markov's inequality, fewer than d/4 indices can violate the first bound and fewer than d/4 can violate the second, so some index satisfies both. The pair $(\Pi', \hat{\theta}) = (\Pi_i, \hat{\theta}_i)$ then solves the task $T(1, m, n, \sigma^2, \mathcal{V})$ with information cost at most 4C/d and squared loss 4R/d.

We are going to apply the theorem above to the one-dimensional lower bound of [ZDJW13]. This lower bound is not explicitly stated in their paper but is implicit in the proof of their Theorem 1. Although they do not mention it, their techniques are general enough to prove lower bounds on the information cost of protocols with private randomness. Their definition of information cost is also slightly different: they do not condition on the prior of θ. However, since in the one-dimensional case this prior is supported on {±δ}, conditioning on it can reduce the mutual information by at most 1 bit:
\[
I(X; Y \mid \theta, R_{pub}) \ge I(X; Y \mid R_{pub}) - 1.
\]
Theorem 3.2. [ZDJW13] Let $\mathcal{V}$ be the uniform distribution over $\{\pm\delta\}$, where $\delta^2 \le \min\big\{1, \frac{\sigma^2 \log m}{n}\big\}$. If $(\Pi, \hat{\theta})$ solves the task $T(1, m, n, \sigma^2, \mathcal{V})$ with information cost C and squared loss R, then either
\[
C \ge \Omega\Big( \frac{\sigma^2}{\delta^2 n \log m} \Big) \quad \text{or} \quad R \ge \delta^2/10.
\]

The corollary below directly follows from Theorem 3.2 and Theorem 3.1.
Corollary 3.1. Let $\mathcal{V}$ be the uniform distribution over $\{\pm\delta\}$, where $\delta^2 \le \min\big\{1, \frac{\sigma^2 \log m}{n}\big\}$. If $(\Pi, \hat{\vec{\theta}})$ solves the task $T(d, m, n, \sigma^2, \mathcal{V}^d)$ with information cost C and squared loss R, then either
\[
C \ge \Omega\Big( \frac{d\sigma^2}{\delta^2 n \log m} \Big) \quad \text{or} \quad R \ge d\delta^2/40.
\]
This immediately proves the main theorem of the paper.

Theorem 3.3. Suppose $(\Pi, \hat{\vec{\theta}})$ estimates the mean of $N(\vec{\theta}, \sigma^2 I_d)$, where $\vec{\theta} \in [-1,1]^d$, with mean-squared loss R and communication cost B. Then
\[
R \ge \Omega\left( \min\left\{ \frac{d^2\sigma^2}{nB\log m},\ \frac{d\sigma^2}{n\log m},\ d \right\} \right).
\]
As a corollary, to achieve the optimal mean-squared loss $R = \frac{d\sigma^2}{mn}$ of the case when all the data is on a single machine, the communication cost must be at least $B = \Omega\big( \frac{dm}{\log m} \big)$.

Proof. Apply Corollary 3.1 with the trivial bound C ≤ B. We divide into two cases depending on whether $B \ge \frac{1}{c} \cdot \max\big\{ \frac{d\sigma^2}{n\log m}, \frac{d}{\log^2 m} \big\}$ or not, where c > 1 is a constant to be specified later.

If $B \ge \frac{1}{c} \cdot \max\big\{ \frac{d\sigma^2}{n\log m}, \frac{d}{\log^2 m} \big\}$, choose $\delta^2 = \frac{1}{c} \cdot \frac{d\sigma^2}{nB\log m}$. Then $\delta^2 \le \min\big\{1, \frac{\sigma^2\log m}{n}\big\}$, and hence we can apply Corollary 3.1. Also,
\[
C \le B = \frac{1}{c} \cdot \frac{d\sigma^2}{\delta^2 n \log m}.
\]
Choosing c large enough, this violates the lower bound on C in Corollary 3.1. Thus we must have $R \ge d\delta^2/40 \ge \Omega\big( \frac{d^2\sigma^2}{nB\log m} \big)$.

On the other hand, if $B \le \frac{1}{c} \cdot \max\big\{ \frac{d\sigma^2}{n\log m}, \frac{d}{\log^2 m} \big\}$, choose
\[
\delta^2 = \frac{d\sigma^2}{n\log m \cdot \max\big\{ \frac{d\sigma^2}{n\log m}, \frac{d}{\log^2 m} \big\}} = \min\left\{ 1, \frac{\sigma^2\log m}{n} \right\}.
\]
Again $\delta^2 \le \min\big\{1, \frac{\sigma^2\log m}{n}\big\}$, and
\[
C \le B \le \frac{1}{c} \cdot \max\left\{ \frac{d\sigma^2}{n\log m}, \frac{d}{\log^2 m} \right\} = \frac{1}{c} \cdot \frac{d\sigma^2}{\delta^2 n \log m}.
\]
Hence $R \ge d\delta^2/40 \ge \Omega\big( \min\big\{ \frac{d\sigma^2}{n\log m}, d \big\} \big)$. Combining the two cases, we get
\[
R \ge \Omega\left( \min\left\{ \frac{d^2\sigma^2}{nB\log m},\ \frac{d\sigma^2}{n\log m},\ d \right\} \right).
\]
References

[BBCR13] Boaz Barak, Mark Braverman, Xi Chen, and Anup Rao. How to compress interactive communication. SIAM J. Comput., 42(3):1327–1363, 2013.

[BBFM12] Maria-Florina Balcan, Avrim Blum, Shai Fine, and Yishay Mansour. Distributed learning, communication complexity and privacy. In COLT, pages 26.1–26.22, 2012.

[BEO+13] Mark Braverman, Faith Ellen, Rotem Oshman, Toniann Pitassi, and Vinod Vaikuntanathan. A tight bound for set disjointness in the message-passing model. In FOCS, pages 668–677, 2013.

[BR11] Mark Braverman and Anup Rao. Information equals amortized communication. In FOCS, pages 748–757, 2011.

[BYJKS04] Ziv Bar-Yossef, T. S. Jayram, Ravi Kumar, and D. Sivakumar. An information statistics approach to data stream and communication complexity. J. Comput. Syst. Sci., 68(4), 2004.

[CSWY01] Amit Chakrabarti, Yaoyun Shi, Anthony Wirth, and Andrew Chi-Chih Yao. Informational complexity and the direct sum problem for simultaneous message complexity. In FOCS, pages 270–278, 2001.

[IPSV12a] Hal Daumé III, Jeff M. Phillips, Avishek Saha, and Suresh Venkatasubramanian. Efficient protocols for distributed classification and optimization. In ALT, pages 154–168, 2012.

[IPSV12b] Hal Daumé III, Jeff M. Phillips, Avishek Saha, and Suresh Venkatasubramanian. Protocols for learning classifiers on distributed data. In AISTATS, pages 282–290, 2012.

[ZDJW13] Yuchen Zhang, John C. Duchi, Michael I. Jordan, and Martin J. Wainwright. Information-theoretic lower bounds for distributed statistical estimation with communication constraints. In NIPS, pages 2328–2336, 2013.