Title: Learning Mixtures of Product Distributions
1. Learning Mixtures of Product Distributions
- Jon Feldman, Columbia University
- Ryan O'Donnell, IAS
- Rocco Servedio, Columbia University
2. Learning Distributions
- There is an unknown distribution P over R^n, or maybe just over {0,1}^n.
- An algorithm gets access to random samples from P.
- In time polynomial in n/ε it should output a hypothesis distribution Q which (w.h.p.) is ε-close to P.
- Technical details later.
3. Learning Distributions
4. Learning Classes of Distributions
- Since this is hopeless in general, one assumes that P comes from a class of distributions C.
- We speak of whether C is polynomial-time learnable or not; this means that there is one algorithm that learns every P in C.
- Some easily learnable classes:
  - C = Gaussians over R^n
  - C = product distributions over {0,1}^n
5. Learning product distributions over {0,1}^n
- E.g. n = 3. Samples:
- 0 1 0
- 0 1 1
- 0 1 1
- 1 1 1
- 0 1 0
- 0 1 1
- 0 1 0
- 0 1 0
- 1 1 1
- 0 0 0
- Hypothesis: (.2, .9, .5)
6. Mixtures of product distributions
- Fix k ≥ 2 and let p_1 + p_2 + ... + p_k = 1.
- The p-mixture of distributions P^1, ..., P^k is:
  - Draw i according to the mixture weights p_i.
  - Draw from P^i.
- In the case of product distributions over {0,1}^n, the mixture is given by a table of coordinate means:
    p_1:  µ^1_1  µ^1_2  µ^1_3  ...  µ^1_n
    p_2:  µ^2_1  µ^2_2  µ^2_3  ...  µ^2_n
    ...
    p_k:  µ^k_1  µ^k_2  µ^k_3  ...  µ^k_n
7. Learning mixture example
- E.g. n = 4. Samples:
- 1 1 0 0
- 0 0 0 1
- 0 1 0 1
- 0 1 1 0
- 0 0 0 1
- 1 1 1 0
- 0 1 0 1
- 0 0 1 1
- 1 1 1 0
- 1 0 1 0
- True distribution:
  - 60%: .8 .8 .6 .2
  - 40%: .2 .4 .3 .8
8. Prior work
- [KMRRSS94] learned in time poly(n/ε, 2^k) in the special case that there is a number p < ½ such that every mean µ^i_j is either p or 1 − p.
- [FM99] learned mixtures of 2 product distributions over {0,1}^n in polynomial time (with a few minor technical deficiencies).
- [CGG98] learned a generalization of mixtures of 2 product distributions over {0,1}^n, with no deficiencies.
- The latter two leave mixtures of 3 as an open problem; there is a qualitative difference between 2 and 3. [FM99] also leaves open learning mixtures of Gaussians and other R^n distributions.
9. Our results
- A poly(n/ε) time algorithm learning a mixture of k product distributions over {0,1}^n for any constant k.
- Evidence that getting a poly(n/ε) algorithm for k = ω(1), even in the case where the µ's are in {0, ½, 1}, will be very hard (if possible).
- Generalizations:
  - Let C_1, ..., C_n be nice classes of distributions over R (definable in terms of O(1) moments). The algorithm learns a mixture of O(1) distributions in C_1 × ... × C_n.
  - Only pairwise independence of the coordinates is used.
10. Technical definitions
- When is a hypothesis distribution Q ε-close to the target distribution P?
  - L_1 distance: Σ_x |P(x) − Q(x)|.
  - KL divergence: KL(P || Q) = Σ_x P(x) log( P(x)/Q(x) ).
- Getting a KL-close hypothesis is more stringent:
  - fact: L_1 ≤ O(KL^½).
- We learn under KL divergence, which leads to some technical advantages (and some technical difficulties).
11. Learning distributions summary
- Learning a class of distributions C.
- Let P be any distribution in the class.
- Given ε and δ > 0.
- Get samples and do poly(n/ε, log(1/δ)) work.
- With probability at least 1 − δ, output a hypothesis Q which satisfies KL(P || Q) < ε.
12. Some intuition for k = 2
- Idea: find two coordinates j and j' to "key off".
- Suppose you notice that the bits in coordinates j and j' are very frequently different.
- Then probably most of the "01" examples come from one mixture component and most of the "10" examples come from the other.
- Use this separation to estimate all the other means.
13. More details for the intuition
- Suppose you somehow know the following three things:
  - The mixture weights are 60 / 40.
  - There are j and j' such that the 2×2 matrix of means
      [ p_j   p_j' ]
      [ q_j   q_j' ]
    has determinant > ε.
  - The values p_j, p_j', q_j, q_j' themselves.
14. More details for the intuition
- Main algorithmic idea:
  - For each coordinate m, estimate (to within ε²) the correlations of coordinate m with j and with j':
      corr(j, m)  = (.6 p_j)·p_m + (.4 q_j)·q_m
      corr(j', m) = (.6 p_j')·p_m + (.4 q_j')·q_m
  - Solve this system of equations for p_m, q_m (sketched below). Done!
  - Since the determinant is > ε, error in the correlation estimates does not blow up too much.
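A minimal Python/NumPy sketch of this step, assuming the mixture weights (here .6/.4) and the four guessed means are already in hand; the function name, argument names, and the clipping to [0,1] are illustrative choices, not part of the original algorithm.

```python
import numpy as np

def solve_remaining_means(w1, w2, pj, pjp, qj, qjp, corr_jm, corr_jpm):
    """Solve the 2x2 linear system
         w1*pj *p_m + w2*qj *q_m = corr(j , m)
         w1*pjp*p_m + w2*qjp*q_m = corr(j', m)
       for the unknown means (p_m, q_m) of coordinate m."""
    M = np.array([[w1 * pj,  w2 * qj],
                  [w1 * pjp, w2 * qjp]])
    rhs = np.array([corr_jm, corr_jpm])
    p_m, q_m = np.linalg.solve(M, rhs)
    # Clip to [0,1]: estimation error can push the solution slightly outside.
    return float(np.clip(p_m, 0, 1)), float(np.clip(q_m, 0, 1))
```

The coefficient matrix has determinant .6 · .4 · (p_j·q_j' − q_j·p_j'), so when the 2×2 matrix of means is far from singular, small errors in the estimated correlations cause only small errors in (p_m, q_m).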
15. Two questions
- 1. This assumes that there is some 2×2 submatrix which is far from singular. In general, there is no reason to believe this is the case.
  - But if not, then one set of means is very nearly a multiple of the other set, and the problem becomes very easy.
- 2. How did we know p_1, p_2? How did we know which j and j' were good? How did we know the 4 means p_j, p_j', q_j, q_j'?
16. Guessing
- Just guess. I.e., try all possibilities (enumeration sketched below).
  - Guess whether the 2 × n matrix of means is essentially rank 1 or not.
  - Guess p_1, p_2 to within ε². (Time 1/ε⁴.)
  - Guess the correct j, j'. (Time n².)
  - Guess p_j, p_j', q_j, q_j' to within ε². (Time 1/ε⁸.)
- Solve the system of equations in every case.
- Time: poly(n/ε).
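For concreteness, a Python sketch of what "try all possibilities" means here; setting p_2 = 1 − p_1 and the exact grid spacing are simplifications of the slide's accounting, and every name is illustrative:

```python
import itertools
import numpy as np

def grid(eps):
    """Candidate parameter values on an eps^2-spaced grid in [0, 1]."""
    return np.arange(0.0, 1.0 + 1e-12, eps ** 2)

def enumerate_guesses(n, eps):
    """Yield every (mixture weights, key coordinates, four key means) guess.
       The count is poly(n, 1/eps) for fixed k = 2; each guess is later turned
       into a full candidate hypothesis and checked by the ML test."""
    for p1 in grid(eps):                                   # mixture weight guess
        for j, jp in itertools.combinations(range(n), 2):  # which coords to key off
            for pj, pjp, qj, qjp in itertools.product(grid(eps), repeat=4):
                yield (p1, 1.0 - p1), (j, jp), (pj, pjp, qj, qjp)
```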
17. Checking guesses
- After this we get a whole bunch of candidate hypotheses.
- When we get lucky and make all the right guesses, the resulting candidate hypothesis will be a good one; say, it will be ε-close in KL to the truth.
- Can we pick the (or, a) candidate hypothesis which is KL-close to the truth? I.e., can we guess and check?
- Yes: use a maximum-likelihood test.
18. Checking with ML
- Suppose Q is a candidate hypothesis for P.
- Estimate its log-likelihood on a sample S:
    log Π_{x ∈ S} Q(x)
    = Σ_{x ∈ S} log Q(x)
    ≈ |S| · E[ log Q(x) ]
    = |S| · Σ_x P(x) log Q(x)
    = |S| · ( Σ_x P(x) log P(x) − KL(P || Q) ).
19. Checking with ML, cont'd
- By Chernoff bounds, if we take enough samples, all candidate hypotheses Q will have their estimated log-likelihoods close to their expectations.
- Any KL-close Q will look very good in the ML test.
- Anything which looks good in the ML test is KL-close.
- Thus, assuming there is an ε-close candidate hypothesis among the guesses, we find an O(ε)-close candidate hypothesis (sketch below).
- I.e., we can guess and check.
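A minimal Python/NumPy sketch of the maximum-likelihood check for candidate mixtures of product distributions over {0,1}^n; the clipping of means away from 0 and 1 (to avoid log 0) is a stand-in for the technical handling of KL near the boundary, and all names are illustrative.

```python
import numpy as np

def log_likelihood(samples, weights, means):
    """Empirical log-likelihood sum_{x in S} log Q(x) of a candidate k-mixture
       Q(x) = sum_i weights[i] * prod_j means[i,j]^x_j * (1 - means[i,j])^(1 - x_j).
       samples: (m, n) 0/1 array; weights: length k; means: (k, n)."""
    X = np.asarray(samples)
    M = np.clip(np.asarray(means, dtype=float), 1e-9, 1 - 1e-9)  # avoid log(0)
    # Probability of each sample under each mixture component, shape (m, k).
    comp = np.prod(np.where(X[:, None, :] == 1, M, 1 - M), axis=2)
    return float(np.sum(np.log(comp @ np.asarray(weights))))

def best_candidate(samples, candidates):
    """ML check: among candidate (weights, means) pairs, return the one whose
       empirical log-likelihood on the sample is largest."""
    return max(candidates, key=lambda cand: log_likelihood(samples, *cand))
```

In use, best_candidate would be run on a fresh sample and its winner returned as the final hypothesis.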
20. Overview of the algorithm
- We now give the precise algorithm for learning a mixture of k product distributions, along with intuition for why it works.
- Intuitively:
  - Estimate all the pairwise correlations of the bits.
  - Guess a number of parameters of the mixture distribution.
  - Use the guesses and correlation estimates to solve for the remaining parameters.
  - Show that whenever the guesses are close, the resulting parameter estimates give a close-in-KL candidate hypothesis.
  - Check the candidates with the ML test and pick the best one.
21. The algorithm
- 1. Estimate all pairwise correlations corr(j, j') to within (ε/n)^k. (Time (n/ε)^k; estimator sketched below.)
  - Note corr(j, j') = Σ_{i=1..k} p_i µ^i_j µ^i_j' = ⟨µ_j, µ_j'⟩,
    where µ_j = ( (p_i)^½ µ^i_j )_{i=1..k}.
- 2. Guess all p_i to within (ε/n)^k. (Time (n/ε)^(k²).)
- Now it suffices to estimate all the vectors µ_j, j = 1..n.
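Since estimating these pairwise correlations is the only way the algorithm touches the samples, here is a minimal NumPy sketch of that estimator (the function name is illustrative):

```python
import numpy as np

def pairwise_correlations(samples):
    """Empirical estimates of corr(j, j') = E[x_j x_j'] = sum_i p_i mu^i_j mu^i_j'
       for all pairs j != j'.  samples: (m, n) array with 0/1 entries.
       Note the diagonal estimates E[x_j], NOT the unavailable quantities
       sum_i p_i (mu^i_j)^2 discussed on the 'Two remarks' slide."""
    X = np.asarray(samples, dtype=float)
    return (X.T @ X) / X.shape[0]
```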
22. Mixtures of product distributions (recap)
- Fix k ≥ 2 and let p_1 + p_2 + ... + p_k = 1.
- The p-mixture of distributions P^1, ..., P^k is:
  - Draw i according to the mixture weights p_i.
  - Draw from P^i.
- In the case of product distributions over {0,1}^n, the mixture is given by a table of coordinate means:
    p_1:  µ^1_1  µ^1_2  µ^1_3  ...  µ^1_n
    p_2:  µ^2_1  µ^2_2  µ^2_3  ...  µ^2_n
    ...
    p_k:  µ^k_1  µ^k_2  µ^k_3  ...  µ^k_n
23. Guessing matrices from most of their Gram matrices
- Let A be the k × n matrix whose columns are the vectors µ_j:
    A = [ µ_1  µ_2  ...  µ_n ].
- After estimating all correlations, we know all dot products of distinct columns of A to high accuracy.
- Goal: determine all entries of A, making only O(1) guesses.
24. Two remarks
- This is the final problem, where all the main action and technical challenge lies. Note that all we ever do with the samples is estimate pairwise correlations.
- If we knew the dot products of the columns of A with themselves, we'd have the whole matrix A^T A. That would be great: we could just factor it and recover A exactly. Unfortunately, there doesn't seem to be any way to get at these quantities Σ_{i=1..k} p_i (µ^i_j)².
25. Keying off a nonsingular submatrix
- Idea: find a nonsingular k × k submatrix to "key off".
- As before, the usual case is that A has full rank.
- Then there is a k × k nonsingular submatrix A_J (columns of A indexed by J).
- Guess which columns these are (time n^k) and all of the submatrix's entries to within (ε/n)^k (time (n/ε)^(k³), which gives the final running time).
- Now use this submatrix and the correlation estimates to find all other entries of A (sketch below):
    for all m:  A_J^T A_m = ( corr(m, j) )_{j ∈ J}.
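A sketch of this recovery step in NumPy, assuming the guessed submatrix A_J, its column index set J, and the estimated correlation matrix are given (all names illustrative):

```python
import numpy as np

def recover_columns(A_J, J, corr):
    """Recover every column of A from the guessed nonsingular submatrix A_J
       (whose columns are the columns of A indexed by J) and the estimated
       dot products corr[m, j] ~ <A_m, A_j>, by solving A_J^T A_m = rhs."""
    k, n = A_J.shape[0], corr.shape[0]
    J = list(J)
    A = np.zeros((k, n))
    A[:, J] = A_J                                   # the guessed columns
    for m in set(range(n)) - set(J):
        rhs = np.array([corr[m, j] for j in J])     # <A_m, A_j> for j in J
        A[:, m] = np.linalg.solve(A_J.T, rhs)       # A_J nonsingular by assumption
    return A
```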
26. Non-full-rank case
- But what if A is not full rank? (Or, in the actual analysis, if A is extremely close to being rank-deficient.) A genuine problem.
- Then A's column space has an orthogonal complement of dimension 0 < d ≤ k, spanned by some orthonormal vectors u_1, ..., u_d.
- Guess d and the vectors u_1, ..., u_d.
- Now adjoin these as columns to A, getting a full-rank matrix:
    A' = [ A | u_1 u_2 ... u_d ].
27. Non-full-rank case, cont'd
- Now A' has full rank and we can do the full-rank case!
- Why do we still know all pairwise dot products of the columns of A'?
  - Dot products of the u's with A's columns are 0!
  - Dot products of the u's with each other are known, since they are orthonormal. (Don't need this.)
- 4. Guess a k × k submatrix of A' and all its entries. Use these to solve for all the other entries (sketch below).
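The same recovery works after adjoining the guessed u's, because the extra dot products are known to be 0. A sketch under the same assumptions as before, where A_J_prime is the guessed k × k submatrix of A' whose first columns are original columns of A (indices J_orig) and whose remaining columns are u's:

```python
import numpy as np

def recover_columns_rank_deficient(A_J_prime, J_orig, corr):
    """Non-full-rank case: solve A'_J^T A_m = rhs for each remaining column of A,
       where rhs lists the estimated dot products with the original columns in
       J_orig followed by zeros (the u's are orthogonal to all columns of A)."""
    k, n = A_J_prime.shape[0], corr.shape[0]
    J_orig = list(J_orig)
    zeros = [0.0] * (k - len(J_orig))
    cols = {}
    for m in set(range(n)) - set(J_orig):
        rhs = np.array([corr[m, j] for j in J_orig] + zeros)
        cols[m] = np.linalg.solve(A_J_prime.T, rhs)
    return cols   # columns indexed by J_orig are read off the guess directly
```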
28. The actual analysis
- The actual analysis of this algorithm is quite delicate.
- There are some linear-algebra / numerical-analysis ideas.
- The main issue: the degree to which A is "essentially" of rank k − d has to match the degree to which the guessed vectors u really do have dot product 0 with A's original columns.
- The key is to find a large multiplicative gap between A's singular values, and treat its location as the "essential rank" of A (sketch below).
- This is where the necessary accuracy of (ε/n)^k comes in.
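The gap idea in isolation, as a NumPy sketch: in the actual algorithm this has to be inferred from guessed and estimated quantities at accuracy (ε/n)^k, so the function below only illustrates the rank-selection rule, not the analysis itself.

```python
import numpy as np

def essential_rank(M):
    """Treat the position of the largest multiplicative gap between consecutive
       singular values of M as its 'essential rank'."""
    s = np.linalg.svd(M, compute_uv=False)
    if len(s) < 2:
        return len(s)
    ratios = s[:-1] / np.maximum(s[1:], 1e-300)   # guard against division by zero
    return int(np.argmax(ratios)) + 1             # keep singular values before the gap
```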
29. Can we learn a mixture of ω(1)?
- Claim: Let T be a decision tree on {0,1}^n with k leaves. Then the uniform distribution over the inputs which make T output 1 is a mixture of at most k product distributions.
- Indeed, all of the component product distributions have coordinate means 0, ½, or 1.
[Figure: example decision tree querying x_1, x_2, x_3; the uniform distribution over its 1-inputs is the mixture 2/3 · (0, 0, ½, ½, ½) + 1/3 · (1, 1, 0, ½, ½).]
30. Learning DTs under uniform
- Cor: If one can learn a mixture of k product distributions over {0,1}^n (even 0/½/1 ones) in poly(n) time, one can PAC-learn k-leaf decision trees under uniform in poly(n) time.
- PAC-learning ω(1)-size DTs under uniform is an extremely notorious problem:
  - easier than learning ω(1)-term DNF under uniform, a 20-year-old problem;
  - essentially equivalent to learning ω(1)-juntas under uniform, worth $1000 from A. Blum to solve.
31. Generalizations
- We gave an algorithm that guessed the means of an unknown mixture of k product distributions.
- What assumptions did we really need?
  - pairwise independence of the coordinates;
  - means fall in a bounded range [−poly(n), poly(n)];
  - the 1-d distributions (and pairwise products of same) are samplable, so the true correlations can be found by estimation;
  - the means define the 1-d distributions.
- The last of these is rarely true. But...
32. Higher moments
- Suppose we ran the algorithm and got N guesses for the means of all the distributions.
- Now run the algorithm again, but whenever you get the sample point ⟨x_1, ..., x_n⟩, treat it as ⟨x_1², ..., x_n²⟩.
- You will get N guesses for the second moments!
- Cross-product the two lists to get N² guesses for the ⟨mean, second moment⟩ pairs (sketch below).
- Guess and check, as always.
33. Generalizations
- Let C_1, ..., C_n be families of distributions on R which have the following "niceness" properties:
  - means bounded in [−poly(n), poly(n)];
  - sharp tail bounds / samplability;
  - defined by O(1) moments, with closeness in moments ⇒ closeness in KL;
  - more technical concerns.
- Should be able to learn O(1)-mixtures from C_1 × ... × C_n in the same time.
- Definitely can learn mixtures of axis-aligned Gaussians and mixtures of distributions on O(1)-sized sets.
34. Open questions
- Quantify some nice properties of families of distributions over R which this algorithm can learn.
- Simplify the algorithm:
  - Simpler analysis?
  - Faster? n^(k²) → n^k → n^(log k)???
- Specific fast results for k = 2, 3.
- Solve other distribution-learning problems.