<!doctype html>
<html lang="en">
    <head>
        <meta charset="utf-8">
        <title>Discriminant analysis</title>
        <link rel="stylesheet" href="../assets/css/default.css" />
        <script data-isso="/comments/"
                data-isso-css="true"
                data-isso-lang="en"
                data-isso-reply-to-self="false"
                data-isso-require-author="true"
                data-isso-require-email="true"
                data-isso-max-comments-top="10"
                data-isso-max-comments-nested="5"
                data-isso-reveal-on-click="5"
                data-isso-avatar="true"
                data-isso-avatar-bg="#f0f0f0"
                data-isso-avatar-fg="#9abf88 #5698c4 #e279a3 #9163b6 ..."
                data-isso-vote="true"
                data-vote-levels=""
                src="/comments/js/embed.min.js"></script>
        <script src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script>
        <script src="../assets/js/analytics.js" type="text/javascript"></script>
    </head>
    <body>
        <header>
            <span class="logo">
                <a href="../blog.html">Yuchen's Blog</a>
            </span>
            <nav>
                <a href="../index.html">About</a><a href="../postlist.html">All posts</a><a href="../blog-feed.xml">Feed</a>
            </nav>
        </header>

        <div class="main">
            <div class="bodyitem">
                <h2> Discriminant analysis </h2>
                <p>Posted on 2019-01-03 | <a href="/posts/2019-01-03-discriminant-analysis.html#isso-thread">Comments</a> </p>
                    <!DOCTYPE html>
<html xmlns="http://www.w3.org/1999/xhtml" lang="" xml:lang="">
<head>
  <meta charset="utf-8" />
  <meta name="generator" content="pandoc" />
  <meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=yes" />
  <title>Untitled</title>
  <style>
      code{white-space: pre-wrap;}
      span.smallcaps{font-variant: small-caps;}
      span.underline{text-decoration: underline;}
      div.column{display: inline-block; vertical-align: top; width: 50%;}
  </style>
  <script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.2/MathJax.js?config=TeX-AMS_CHTML-full" type="text/javascript"></script>
  <!--[if lt IE 9]>
    <script src="//cdnjs.cloudflare.com/ajax/libs/html5shiv/3.7.3/html5shiv-printshiv.min.js"></script>
  <![endif]-->
</head>
<body>
<nav id="TOC">
<ul>
<li><a href="#theory">Theory</a><ul>
<li><a href="#qda">QDA</a></li>
<li><a href="#vanilla-lda">Vanilla LDA</a></li>
<li><a href="#nearest-neighbour-classifier">Nearest neighbour classifier</a></li>
<li><a href="#dimensionality-reduction">Dimensionality reduction</a></li>
<li><a href="#fisher-discriminant-analysis">Fisher discriminant analysis</a></li>
<li><a href="#linear-model">Linear model</a></li>
</ul></li>
<li><a href="#implementation">Implementation</a><ul>
<li><a href="#fun-facts-about-lda">Fun facts about LDA</a></li>
</ul></li>
</ul>
</nav>
<p>In this post I talk about the theory and implementation of linear and quadratic discriminant analysis, classical methods in statistical learning.</p>
<p><strong>Acknowledgement</strong>. Various sources were of great help to my understanding of the subject, including Chapter 4 of <a href="https://web.stanford.edu/~hastie/ElemStatLearn/">The Elements of Statistical Learning</a>, <a href="http://cs229.stanford.edu/notes/cs229-notes2.pdf">Stanford CS229 Lecture notes</a>, and <a href="https://github.com/scikit-learn/scikit-learn/blob/7389dba/sklearn/discriminant_analysis.py">the scikit-learn code</a>. Research was done while working at KTH mathematics department.</p>
<p><em>If you are reading on a mobile device, you may need to “request desktop site” for the equations to be properly displayed. This post is licensed under CC BY-SA and GNU FDL.</em></p>
<h2 id="theory">Theory</h2>
<p>Quadratic discriminant analysis (QDA) is a classical classification algorithm. It assumes that the data is generated by Gaussian distributions, where each class has its own mean and covariance.</p>
<p><span class="math display">\[(x | y = i) \sim N(\mu_i, \Sigma_i).\]</span></p>
<p>It also assumes a categorical class prior:</p>
<p><span class="math display">\[\mathbb P(y = i) = \pi_i\]</span></p>
<p>The log-posterior of class <span class="math inline">\(i\)</span> given <span class="math inline">\(x\)</span> is thus</p>
<p><span class="math display">\[\begin{aligned}
\log \mathbb P(y = i | x) &amp;= \log \mathbb P(x | y = i) + \log \mathbb P(y = i) + C\\
&amp;= - {1 \over 2} \log \det \Sigma_i - {1 \over 2} (x - \mu_i)^T \Sigma_i^{-1} (x - \mu_i) + \log \pi_i + C&#39;, \qquad (0)
\end{aligned}\]</span></p>
<p>where <span class="math inline">\(C\)</span> and <span class="math inline">\(C&#39;\)</span> are constants.</p>
<p>Prediction is then done by taking the argmax of (0) over <span class="math inline">\(i\)</span>.</p>
<p>In training, let <span class="math inline">\(X\)</span>, <span class="math inline">\(y\)</span> be the input data, where <span class="math inline">\(X\)</span> is of shape <span class="math inline">\(m \times n\)</span>, and <span class="math inline">\(y\)</span> of shape <span class="math inline">\(m\)</span>. We adopt the convention that each row of <span class="math inline">\(X\)</span> is a sample <span class="math inline">\(x^{(i)T}\)</span>. So there are <span class="math inline">\(m\)</span> samples and <span class="math inline">\(n\)</span> features. Denote by <span class="math inline">\(m_i = \#\{j: y_j = i\}\)</span> the number of samples in class <span class="math inline">\(i\)</span>. Let <span class="math inline">\(n_c\)</span> be the number of classes.</p>
<p>We estimate <span class="math inline">\(\mu_i\)</span> by the sample means, and <span class="math inline">\(\pi_i\)</span> by the frequencies:</p>
<p><span class="math display">\[\begin{aligned}
\mu_i &amp;:= {1 \over m_i} \sum_{j: y_j = i} x^{(j)}, \\
\pi_i &amp;:= \mathbb P(y = i) = {m_i \over m}.
\end{aligned}\]</span></p>
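<p>A minimal NumPy sketch of these two estimates, on a tiny made-up dataset. The function name <code>class_means_and_priors</code> is chosen here for illustration and does not come from any library.</p>

```python
import numpy as np

def class_means_and_priors(X, y):
    """Return (classes, means, priors); means[k] is the mean of class classes[k]."""
    classes, counts = np.unique(y, return_counts=True)
    means = np.array([X[y == c].mean(axis=0) for c in classes])  # mu_i
    priors = counts / len(y)                                     # pi_i = m_i / m
    return classes, means, priors

# Toy data: m = 5 samples, n = 2 features, n_c = 2 classes.
X = np.array([[0.0, 0.0], [2.0, 0.0], [4.0, 4.0], [6.0, 4.0], [5.0, 7.0]])
y = np.array([0, 0, 1, 1, 1])
classes, means, priors = class_means_and_priors(X, y)
print(means)    # class 0 has mean (1, 0), class 1 has mean (5, 5)
print(priors)   # [0.4 0.6]
```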
<p>Linear discriminant analysis (LDA) is a specialisation of QDA: it assumes all classes share the same covariance, i.e. <span class="math inline">\(\Sigma_i = \Sigma\)</span> for all <span class="math inline">\(i\)</span>.</p>
<p>Gaussian Naive Bayes is a different specialisation of QDA: it assumes that all <span class="math inline">\(\Sigma_i\)</span> are diagonal, since all the features are assumed to be independent.</p>
<h3 id="qda">QDA</h3>
<p>We look at QDA.</p>
<p>We estimate <span class="math inline">\(\Sigma_i\)</span> by the sample covariance of class <span class="math inline">\(i\)</span>:</p>
<p><span class="math display">\[\begin{aligned}
\Sigma_i &amp;= {1 \over m_i - 1} \sum_{j: y_j = i} \hat x^{(j)} \hat x^{(j)T}.
\end{aligned}\]</span></p>
<p>where <span class="math inline">\(\hat x^{(j)} = x^{(j)} - \mu_{y_j}\)</span> are the centred <span class="math inline">\(x^{(j)}\)</span>. Plugging this into (0) we are done.</p>
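<p>The whole QDA procedure can be sketched in a few lines of NumPy. This is an illustration under simplifying assumptions (dense matrices, no input validation, synthetic data), and the names <code>qda_fit</code> and <code>qda_predict</code> are made up for this example.</p>

```python
import numpy as np

def qda_fit(X, y):
    """Estimate (mu_i, Sigma_i, log pi_i) for each class."""
    classes, counts = np.unique(y, return_counts=True)
    params = []
    for c, m_c in zip(classes, counts):
        X_i = X[y == c]
        mu = X_i.mean(axis=0)
        centred = X_i - mu
        sigma = centred.T @ centred / (m_c - 1)   # not defined when m_c == 1
        params.append((mu, sigma, np.log(m_c / len(y))))
    return classes, params

def qda_predict(X_new, classes, params):
    """Label each row of X_new by the argmax of formula (0)."""
    scores = []
    for mu, sigma, log_prior in params:
        diff = X_new - mu
        _, logdet = np.linalg.slogdet(sigma)
        # Row-wise quadratic form (x - mu)^T Sigma^{-1} (x - mu).
        quad = np.sum(diff * np.linalg.solve(sigma, diff.T).T, axis=1)
        scores.append(-0.5 * logdet - 0.5 * quad + log_prior)
    return classes[np.argmax(scores, axis=0)]

# Synthetic data: two well separated Gaussian blobs with different spreads.
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0.0, 1.0, size=(50, 2)),
               rng.normal(6.0, 2.0, size=(50, 2))])
y = np.repeat([0, 1], 50)
classes, params = qda_fit(X, y)
print(qda_predict(np.array([[0.0, 0.0], [6.0, 6.0]]), classes, params))
```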
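<p>There are two problems that can break the algorithm. First, if one of the <span class="math inline">\(m_i\)</span> is <span class="math inline">\(1\)</span>, then <span class="math inline">\(\Sigma_i\)</span> is ill-defined. Second, one of the <span class="math inline">\(\Sigma_i\)</span> might be singular, in which case <span class="math inline">\(\Sigma_i^{-1}\)</span> does not exist and <span class="math inline">\(\log \det \Sigma_i\)</span> is not finite.</p>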
<p>In either case, there is no way around it, and the implementation should throw an exception.</p>
<p>This is not a problem for LDA, though, unless every class has only one sample.</p>
<h3 id="vanilla-lda">Vanilla LDA</h3>
<p>Now let us look at LDA.</p>
<p>Since all classes share the same covariance, we estimate <span class="math inline">\(\Sigma\)</span> using the pooled sample covariance</p>
<p><span class="math display">\[\begin{aligned}
\Sigma &amp;= {1 \over m - n_c} \sum_j \hat x^{(j)} \hat x^{(j)T},
\end{aligned}\]</span></p>
<p>where <span class="math inline">\(\hat x^{(j)} = x^{(j)} - \mu_{y_j}\)</span> and <span class="math inline">\({1 \over m - n_c}\)</span> comes from <a href="https://en.wikipedia.org/wiki/Bessel%27s_correction">Bessel's Correction</a>.</p>
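<p>A short NumPy sketch of this estimate, assuming synthetic data and a hypothetical helper name <code>pooled_covariance</code>. It also checks numerically that the estimate agrees with the average of the per-class sample covariances weighted by <span class="math inline">\(m_i - 1\)</span>.</p>

```python
import numpy as np

def pooled_covariance(X, y):
    """Shared covariance estimate with the 1 / (m - n_c) normalisation."""
    classes, inverse = np.unique(y, return_inverse=True)
    means = np.array([X[y == c].mean(axis=0) for c in classes])
    X_c = X - means[inverse]              # subtract each sample's own class mean
    return X_c.T @ X_c / (len(y) - len(classes))

rng = np.random.default_rng(1)
X = rng.normal(size=(30, 3))
y = np.repeat([0, 1, 2], [8, 10, 12])
sigma = pooled_covariance(X, y)

# Per-class covariances (np.cov uses the 1 / (m_i - 1) normalisation by default),
# weighted by m_i - 1 and divided by m - n_c.
weighted = sum((np.sum(y == c) - 1) * np.cov(X[y == c], rowvar=False)
               for c in np.unique(y)) / (len(y) - 3)
print(np.allclose(sigma, weighted))
```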
<p>Let us write down the decision function (0). We can remove the first term on the right hand side, since all <span class="math inline">\(\Sigma_i\)</span> are the same, and we only care about argmax of that equation. Thus it becomes</p>
<p><span class="math display">\[- {1 \over 2} (x - \mu_i)^T \Sigma^{-1} (x - \mu_i) + \log\pi_i. \qquad (1)\]</span></p>
<p>Notice that we have just avoided the problem of figuring out <span class="math inline">\(\log \det \Sigma\)</span> when <span class="math inline">\(\Sigma\)</span> is singular.</p>
<p>But how about <span class="math inline">\(\Sigma^{-1}\)</span>?</p>
<p>We sidestep this problem by using the pseudoinverse of <span class="math inline">\(\Sigma\)</span> instead. This can be seen as applying a linear transformation to <span class="math inline">\(X\)</span> that turns its covariance matrix into the identity. And thus the model becomes a sort of nearest neighbour classifier.</p>
<h3 id="nearest-neighbour-classifier">Nearest neighbour classifier</h3>
<p>More specifically, we want to transform the first term of (1) into a squared norm to get a classifier based on nearest neighbour modulo <span class="math inline">\(\log \pi_i\)</span>:</p>
<p><span class="math display">\[- {1 \over 2} \|A(x - \mu_i)\|^2 + \log\pi_i\]</span></p>
<p>To compute <span class="math inline">\(A\)</span>, we denote</p>
<p><span class="math display">\[X_c = X - M,\]</span></p>
<p>where the <span class="math inline">\(i\)</span>th row of <span class="math inline">\(M\)</span> is <span class="math inline">\(\mu_{y_i}^T\)</span>, the mean of the class that <span class="math inline">\(x^{(i)}\)</span> belongs to, so that <span class="math inline">\(\Sigma = {1 \over m - n_c} X_c^T X_c\)</span>.</p>
<p>Let</p>
<p><span class="math display">\[{1 \over \sqrt{m - n_c}} X_c = U_x \Sigma_x V_x^T\]</span></p>
<p>be the SVD of <span class="math inline">\({1 \over \sqrt{m - n_c}}X_c\)</span>. Let <span class="math inline">\(D_x = \text{diag} (s_1, ..., s_r)\)</span> be the diagonal matrix with all the nonzero singular values, and rewrite <span class="math inline">\(V_x\)</span> as an <span class="math inline">\(n \times r\)</span> matrix consisting of the first <span class="math inline">\(r\)</span> columns of <span class="math inline">\(V_x\)</span>.</p>
<p>Then with an abuse of notation, the pseudoinverse of <span class="math inline">\(\Sigma\)</span> is</p>
<p><span class="math display">\[\Sigma^{-1} = V_x D_x^{-2} V_x^T.\]</span></p>
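<p>To see why, here is a short check using only the definitions above. Writing out the full SVD and then dropping the zero singular values (so that <span class="math inline">\(V_x\)</span> is truncated in the last expression),</p>
<p><span class="math display">\[\Sigma = {1 \over m - n_c} X_c^T X_c = V_x \Sigma_x^T U_x^T U_x \Sigma_x V_x^T = V_x D_x^2 V_x^T.\]</span></p>
<p>Since the columns of the truncated <span class="math inline">\(V_x\)</span> are orthonormal,</p>
<p><span class="math display">\[(V_x D_x^2 V_x^T)(V_x D_x^{-2} V_x^T) = V_x V_x^T,\]</span></p>
<p>which is the orthogonal projection onto the column space of <span class="math inline">\(\Sigma\)</span>, as required of a Moore-Penrose pseudoinverse.</p>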
<p>So we just need to make <span class="math inline">\(A = D_x^{-1} V_x^T\)</span>. When it comes to prediction, just transform <span class="math inline">\(x\)</span> with <span class="math inline">\(A\)</span>, and find the nearest centroid <span class="math inline">\(A \mu_i\)</span> (again, modulo <span class="math inline">\(\log \pi_i\)</span>) and label the input with <span class="math inline">\(i\)</span>.</p>
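<p>The construction of <span class="math inline">\(A\)</span> and the nearest-centroid prediction can be sketched as follows. This is an illustration only: the names <code>lda_whiten</code> and <code>lda_predict</code> and the relative tolerance <code>tol</code> used to decide which singular values count as nonzero are assumptions of this example. The toy data has a redundant third feature, so <span class="math inline">\(\Sigma\)</span> is singular.</p>

```python
import numpy as np

def lda_whiten(X, y, tol=1e-10):
    """Return A = D_x^{-1} V_x^T together with classes, means and log priors."""
    classes, inverse, counts = np.unique(y, return_inverse=True, return_counts=True)
    means = np.array([X[y == c].mean(axis=0) for c in classes])
    X_c = (X - means[inverse]) / np.sqrt(len(y) - len(classes))
    _, s, Vt = np.linalg.svd(X_c, full_matrices=False)
    r = int(np.sum(s > tol * s[0]))       # number of singular values kept as nonzero
    A = Vt[:r] / s[:r, None]              # shape (r, n)
    return A, classes, means, np.log(counts / len(y))

def lda_predict(X_new, A, classes, means, log_priors):
    """Nearest transformed centroid, modulo log pi_i."""
    diff = X_new[:, None, :] - means[None, :, :]     # (k, n_c, n)
    sq_dist = np.sum((diff @ A.T) ** 2, axis=2)      # (k, n_c)
    return classes[np.argmax(-0.5 * sq_dist + log_priors, axis=1)]

rng = np.random.default_rng(2)
Z = np.vstack([rng.normal(0.0, 1.0, size=(40, 2)),
               rng.normal(5.0, 1.0, size=(40, 2))])
X = np.column_stack([Z, Z.sum(axis=1)])   # third feature = sum of the first two
y = np.repeat([0, 1], 40)

A, classes, means, log_priors = lda_whiten(X, y)
print(A.shape)   # (2, 3)
X_new = np.array([[0.0, 0.0, 0.0], [5.0, 5.0, 10.0]])
print(lda_predict(X_new, A, classes, means, log_priors))
```

<p>In the transformed coordinates the pooled covariance becomes the identity, that is, <span class="math inline">\(A \Sigma A^T = I_r\)</span>, which is a convenient property to assert in a unit test.</p>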
<h3 id="dimensionality-reduction">Dimensionality reduction</h3>
<p>We can further simplify the prediction by dimensionality reduction. Assume <span class="math inline">\(n_c \le n\)</span>. Then the centroids span an affine space of dimension <span class="math inline">\(p\)</span> which is at most <span class="math inline">\(n_c - 1\)</span>. So what we can do is to project both the transformed sample <span class="math inline">\(Ax\)</span> and centroids <span class="math inline">\(A\mu_i\)</span> to the linear subspace parallel to the affine space, and do the nearest neighbour classification there.</p>
<p>So we can perform SVD on the matrix <span class="math inline">\((M - \bar x) V_x D_x^{-1}\)</span> where <span class="math inline">\(\bar x\)</span>, a row vector, is the sample mean of all data i.e. average of rows of <span class="math inline">\(X\)</span>:</p>
<p><span class="math display">\[(M - \bar x) V_x D_x^{-1} = U_m \Sigma_m V_m^T.\]</span></p>
<p>Again, we let <span class="math inline">\(V_m\)</span> be the <span class="math inline">\(r \times p\)</span> matrix by keeping the first <span class="math inline">\(p\)</span> columns of <span class="math inline">\(V_m\)</span>.</p>
<p>The projection operator is thus <span class="math inline">\(V_m\)</span>. And so the final transformation is <span class="math inline">\(V_m^T D_x^{-1} V_x^T\)</span>.</p>
<p>There is no reason to stop here, and we can set <span class="math inline">\(p\)</span> even smaller, which will result in a lossy compression / regularisation equivalent to doing <a href="https://en.wikipedia.org/wiki/Principal_component_analysis">principal component analysis</a> on <span class="math inline">\((M - \bar x) V_x D_x^{-1}\)</span>.</p>
<p>Note that as of 2019-01-04, in the <a href="https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/discriminant_analysis.py">scikit-learn implementation of LDA</a>, the prediction is done without any lossy compression, even if the parameter <code>n_components</code> is set to be smaller than dimension of the affine space spanned by the centroids. In other words, the prediction does not change regardless of <code>n_components</code>.</p>
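<p>The two SVD steps of this section can be sketched as below, on synthetic data with three classes in four dimensions. The function name <code>lda_projection</code>, its optional argument <code>p</code> and the tolerance <code>tol</code> are assumptions made for this example, not the API of any library.</p>

```python
import numpy as np

def lda_projection(X, y, p=None, tol=1e-10):
    """Return B = V_m^T D_x^{-1} V_x^T of shape (p, n), the matrix A, and the class means."""
    classes, inverse = np.unique(y, return_inverse=True)
    means = np.array([X[y == c].mean(axis=0) for c in classes])
    M = means[inverse]                    # row i is the mean of sample i's class
    _, s, Vt = np.linalg.svd((X - M) / np.sqrt(len(y) - len(classes)),
                             full_matrices=False)
    r = int(np.sum(s > tol * s[0]))
    A = Vt[:r] / s[:r, None]              # D_x^{-1} V_x^T
    # Second SVD, on (M - x_bar) V_x D_x^{-1}, which equals (M - x_bar) A^T.
    _, s_m, Vt_m = np.linalg.svd((M - X.mean(axis=0)) @ A.T, full_matrices=False)
    rank = int(np.sum(s_m > tol * s_m[0]))    # at most n_c - 1
    p = rank if p is None else min(p, rank)
    return Vt_m[:p] @ A, A, means

rng = np.random.default_rng(3)
centres = np.array([[0, 0, 0, 0], [4, 0, 1, 0], [0, 4, 0, 1]], dtype=float)
y = np.repeat([0, 1, 2], 30)
X = centres[y] + rng.normal(size=(90, 4))

B, A, means = lda_projection(X, y)
print(B.shape)   # (2, 4)

# With p equal to the full rank, distances between centroids are preserved.
d = means[0] - means[1]
print(np.isclose(np.linalg.norm(A @ d), np.linalg.norm(B @ d)))
```

<p>Passing a smaller <code>p</code>, for example <code>lda_projection(X, y, p=1)</code>, gives the lossy variant described above.</p>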
<h3 id="fisher-discriminant-analysis">Fisher discriminant analysis</h3>
<p>Fisher discriminant analysis involves finding an <span class="math inline">\(n\)</span>-dimensional vector <span class="math inline">\(a\)</span> that maximises the between-class covariance relative to the within-class covariance:</p>
<p><span class="math display">\[{a^T M_c^T M_c a \over a^T X_c^T X_c a},\]</span></p>
<p>where <span class="math inline">\(M_c = M - \bar x\)</span> is the centred sample mean matrix.</p>
<p>As it turns out, this is (almost) equivalent to the derivation above, modulo a constant. In particular, <span class="math inline">\(a = c V_x D_x^{-1} V_m\)</span> where <span class="math inline">\(p = 1\)</span> for arbitrary constant <span class="math inline">\(c\)</span>.</p>
<p>To see this, we can first multiply the denominator with a constant <span class="math inline">\({1 \over m - n_c}\)</span> so that the matrix in the denominator becomes the covariance estimate <span class="math inline">\(\Sigma\)</span>.</p>
<p>We decompose <span class="math inline">\(a\)</span>: <span class="math inline">\(a = V_x D_x^{-1} b + \tilde V_x \tilde b\)</span>, where <span class="math inline">\(\tilde V_x\)</span> consists of column vectors orthogonal to the column space of <span class="math inline">\(V_x\)</span>.</p>
<p>We ignore the second term in the decomposition. In other words, we only consider <span class="math inline">\(a\)</span> in the column space of <span class="math inline">\(V_x\)</span>.</p>
<p>Then the problem is to find an <span class="math inline">\(r\)</span>-dimensional vector <span class="math inline">\(b\)</span> to maximise</p>
<p><span class="math display">\[{b^T (M_c V_x D_x^{-1})^T (M_c V_x D_x^{-1}) b \over b^T b}.\]</span></p>
<p>This is the problem of principal component analysis, and so <span class="math inline">\(b\)</span> is the first column of <span class="math inline">\(V_m\)</span>.</p>
<p>Therefore, the solution to Fisher discriminant analysis is <span class="math inline">\(a = c V_x D_x^{-1} V_m\)</span> with <span class="math inline">\(p = 1\)</span>.</p>
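<p>A numerical sanity check of this claim, on synthetic data where <span class="math inline">\(\Sigma\)</span> is nonsingular (an assumption of this example, so that no singular values need to be dropped). It compares the Fisher ratio at <span class="math inline">\(a = V_x D_x^{-1} V_m\)</span> with the largest eigenvalue of <span class="math inline">\((X_c^T X_c)^{-1} M_c^T M_c\)</span>, which is the maximum of the ratio when the within-class scatter is positive definite.</p>

```python
import numpy as np

rng = np.random.default_rng(4)
centres = np.array([[0.0, 0.0, 0.0], [3.0, 1.0, 0.0], [0.0, 2.0, 4.0]])
y = np.repeat([0, 1, 2], 40)
X = centres[y] + rng.normal(size=(120, 3))

means = np.array([X[y == c].mean(axis=0) for c in range(3)])
X_c = X - means[y]                        # within-class centred data
M_c = means[y] - X.mean(axis=0)           # centred class-mean matrix

def fisher_ratio(a):
    """Between-class over within-class scatter along direction a."""
    return (a @ M_c.T @ M_c @ a) / (a @ X_c.T @ X_c @ a)

# a = V_x D_x^{-1} v, where v is the first column of V_m.
_, s, Vt = np.linalg.svd(X_c / np.sqrt(len(y) - 3), full_matrices=False)
W = Vt.T / s                              # V_x D_x^{-1}
_, _, Vt_m = np.linalg.svd(M_c @ W, full_matrices=False)
a = W @ Vt_m[0]

# Reference value: largest eigenvalue of S_w^{-1} S_b.
top = np.max(np.linalg.eigvals(np.linalg.solve(X_c.T @ X_c, M_c.T @ M_c)).real)
print(np.isclose(fisher_ratio(a), top))
```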
<h3 id="linear-model">Linear model</h3>
<p>The model is called linear discriminant analysis because it is a linear model. To see this, let <span class="math inline">\(B = V_m^T D_x^{-1} V_x^T\)</span> be the matrix of transformation. Now we are comparing</p>
<p><span class="math display">\[- {1 \over 2} \| B x - B \mu_k\|^2 + \log \pi_k\]</span></p>
<p>across all <span class="math inline">\(k\)</span>s. Expanding the norm and removing the common term <span class="math inline">\(\|B x\|^2\)</span>, we see a linear form:</p>
<p><span class="math display">\[\mu_k^T B^T B x - {1 \over 2} \|B \mu_k\|^2 + \log\pi_k\]</span></p>
<p>So the prediction for <span class="math inline">\(X_{\text{new}}\)</span> is</p>
<p><span class="math display">\[\text{argmax}_{\text{axis}=0} \left(K B^T B X_{\text{new}}^T - {1 \over 2} \|K B^T\|_{\text{axis}=1}^2 + \log \pi\right)\]</span></p>
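<p>where the <span class="math inline">\(k\)</span>th row of <span class="math inline">\(K\)</span> is <span class="math inline">\(\mu_k^T\)</span>, so that <span class="math inline">\(K\)</span> is of shape <span class="math inline">\(n_c \times n\)</span>, and the last two terms are vectors of length <span class="math inline">\(n_c\)</span> broadcast over the samples. Thus the decision boundaries are linear.</p>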
<p>This is how scikit-learn implements LDA, by inheriting from <code>LinearClassifierMixin</code> and redirecting the classification there.</p>
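<p>The equivalence between the distance form and the linear form holds for any matrix <span class="math inline">\(B\)</span>, so it can be checked with random inputs. The sketch below does that; the function names <code>scores_distance</code> and <code>scores_linear</code> are made up for this example and are not taken from scikit-learn.</p>

```python
import numpy as np

def scores_distance(X_new, B, K, log_pi):
    """-1/2 ||B x - B mu_k||^2 + log pi_k, shape (n_c, number of samples)."""
    diff = X_new[None, :, :] - K[:, None, :]          # (n_c, k, n)
    return -0.5 * np.sum((diff @ B.T) ** 2, axis=2) + log_pi[:, None]

def scores_linear(X_new, B, K, log_pi):
    """K B^T B X_new^T - 1/2 ||K B^T||^2 + log pi, same shape as above."""
    KB = K @ B.T                                       # row k is (B mu_k)^T
    bias = -0.5 * np.sum(KB ** 2, axis=1) + log_pi
    return KB @ (X_new @ B.T).T + bias[:, None]

rng = np.random.default_rng(5)
B = rng.normal(size=(2, 4))         # any matrix works for this identity
K = rng.normal(size=(3, 4))         # row k plays the role of mu_k^T
log_pi = np.log(np.array([0.2, 0.3, 0.5]))
X_new = rng.normal(size=(6, 4))

dist = scores_distance(X_new, B, K, log_pi)
lin = scores_linear(X_new, B, K, log_pi)
# The two differ only by -1/2 ||B x||^2, which is constant down each column.
print(np.allclose(dist - lin, -0.5 * np.sum((X_new @ B.T) ** 2, axis=1)))
print(np.array_equal(dist.argmax(axis=0), lin.argmax(axis=0)))
```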
<h2 id="implementation">Implementation</h2>
<p>This is where things get interesting. How do I validate my understanding of the theory? By implementing and testing the algorithm.</p>
<p>I try to implement it as closely as possible to the natural language / mathematical descriptions of the model, which means clarity over performance.</p>
<p>How about testing? Numerical experiments are harder to test than combinatorial / discrete algorithms in general because the output is less verifiable by hand. My shortcut solution to this problem is to test against output from the scikit-learn package.</p>
<p>It turned out to be harder than expected, as I had to dig into the code of scikit-learn whenever the outputs mismatched. Their code is quite well-written though.</p>
<p>The result is <a href="https://github.com/ycpei/machine-learning/tree/master/discriminant-analysis">here</a>.</p>
<h3 id="fun-facts-about-lda">Fun facts about LDA</h3>
<p>One property that can be used to test the LDA implementation is the fact that the scatter matrix <span class="math inline">\(B(X - \bar x)^T (X - \bar x) B^T\)</span> of the transformed centred sample is diagonal.</p>
<p>This can be derived by using another fun fact that the sum of the in-class scatter matrix and the between-class scatter matrix is the sample scatter matrix:</p>
<p><span class="math display">\[X_c^T X_c + M_c^T M_c = (X - \bar x)^T (X - \bar x) = (X_c + M_c)^T (X_c + M_c).\]</span></p>
<p>The verification is not very hard and left as an exercise.</p>
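<p>Both facts can also be checked numerically (which is of course not a proof). The sketch below uses synthetic data with three classes and a nonsingular <span class="math inline">\(\Sigma\)</span>, and takes <span class="math inline">\(p = n_c - 1 = 2\)</span>; these choices are assumptions of this example.</p>

```python
import numpy as np

rng = np.random.default_rng(6)
centres = np.array([[0.0, 0.0, 0.0, 0.0],
                    [3.0, 0.0, 1.0, 0.0],
                    [0.0, 3.0, 0.0, 2.0]])
y = np.repeat([0, 1, 2], 25)
X = centres[y] + rng.normal(size=(75, 4))

means = np.array([X[y == c].mean(axis=0) for c in range(3)])
X_c = X - means[y]                  # within-class centred data
M_c = means[y] - X.mean(axis=0)     # centred class-mean matrix
X_t = X - X.mean(axis=0)            # globally centred data

# Within-class scatter + between-class scatter = total scatter.
print(np.allclose(X_c.T @ X_c + M_c.T @ M_c, X_t.T @ X_t))

# B = V_m^T D_x^{-1} V_x^T diagonalises the total scatter.
_, s, Vt = np.linalg.svd(X_c / np.sqrt(len(y) - 3), full_matrices=False)
A = Vt / s[:, None]                 # D_x^{-1} V_x^T
_, _, Vt_m = np.linalg.svd(M_c @ A.T, full_matrices=False)
B = Vt_m[:2] @ A
S = B @ X_t.T @ X_t @ B.T
print(np.allclose(S, np.diag(np.diag(S))))
```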
</body>
</html>

            </div>
        <section id="isso-thread"></section>
        </div>
    </body>
</html>