/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math.stat.regression;

import org.apache.commons.math.linear.Array2DRowRealMatrix;
import org.apache.commons.math.linear.LUDecompositionImpl;
import org.apache.commons.math.linear.QRDecomposition;
import org.apache.commons.math.linear.QRDecompositionImpl;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.RealVector;
import org.apache.commons.math.stat.StatUtils;
import org.apache.commons.math.stat.descriptive.moment.SecondMoment;

/**
 * <p>Implements ordinary least squares (OLS) to estimate the parameters of a
 * multiple linear regression model.</p>
 *
 * <p>The regression coefficients, <code>b</code>, satisfy the normal equations:
 * <pre><code> X<sup>T</sup> X b = X<sup>T</sup> y </code></pre></p>
 *
 * <p>To solve the normal equations, this implementation uses QR decomposition
 * of the <code>X</code> matrix. (See {@link QRDecompositionImpl} for details on the
 * decomposition algorithm.) The <code>X</code> matrix, also known as the <i>design matrix</i>,
 * has rows corresponding to sample observations and columns corresponding to independent
 * variables.  When the model is estimated using an intercept term (i.e., when
 * {@link #isNoIntercept() isNoIntercept} is false, as it is by default), the <code>X</code>
 * matrix includes an initial column identically equal to 1.  We solve the normal equations
 * as follows:
 * <pre><code> X<sup>T</sup>X b = X<sup>T</sup> y
 * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup> y
 * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y
 * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y
 * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y
 * R b = Q<sup>T</sup> y </code></pre></p>
 *
 * <p>Given <code>Q</code> and <code>R</code>, the last equation is solved by back-substitution.</p>
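 *
 * <p>A minimal usage sketch (the sample data and variable names below are
 * illustrative only):
 * <pre><code>
 * OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
 * double[] y = new double[] {11.0, 12.0, 13.0, 14.0, 15.0, 16.0};
 * double[][] x = new double[][] {
 *     {0, 0, 0, 0, 0},
 *     {2.0, 0, 0, 0, 0},
 *     {0, 3.0, 0, 0, 0},
 *     {0, 0, 4.0, 0, 0},
 *     {0, 0, 0, 5.0, 0},
 *     {0, 0, 0, 0, 6.0}};
 * regression.newSampleData(y, x);
 *
 * double[] beta = regression.estimateRegressionParameters();
 * double[] residuals = regression.estimateResiduals();
 * double rSquared = regression.calculateRSquared();
 * </code></pre></p>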
 *
 * @version $Revision: 1073464 $ $Date: 2011-02-22 20:35:02 +0100 (Tue, 22 Feb 2011) $
 * @since 2.0
 */
public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {

    /** Cached QR decomposition of X matrix */
    private QRDecomposition qr = null;

    /**
     * Loads model x and y sample data, overriding any previous sample.
     *
     * Computes and caches QR decomposition of the X matrix.
     * @param y the [n,1] array representing the y sample
     * @param x the [n,k] array representing the x sample
     * @throws IllegalArgumentException if the x and y array data are not
     *             compatible for the regression
     */
    public void newSampleData(double[] y, double[][] x) {
        validateSampleData(x, y);
        newYSampleData(y);
        newXSampleData(x);
    }

    /**
     * {@inheritDoc}
     * <p>This implementation computes and caches the QR decomposition of the X matrix.</p>
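     *
     * <p>As a worked example of the flat-array form (assuming the y-first,
     * observation-major packing documented by the superclass), the call
     * <pre><code>
     * // two observations (nobs = 2) on two regressors (nvars = 2);
     * // each observation contributes y, x1, x2 in sequence
     * regression.newSampleData(new double[] {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, 2, 2);
     * </code></pre>
     * loads y values {1.0, 4.0} and regressor rows {2.0, 3.0} and {5.0, 6.0}.</p>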
     */
    @Override
    public void newSampleData(double[] data, int nobs, int nvars) {
        super.newSampleData(data, nobs, nvars);
        qr = new QRDecompositionImpl(X);
    }

    /**
     * <p>Compute the "hat" matrix.
     * </p>
     * <p>The hat matrix is defined in terms of the design matrix X
     *  by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup>
     * </p>
     * <p>The implementation here uses the QR decomposition to compute the
     * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the
     * p-dimensional identity matrix augmented by 0's.  This computational
     * formula is from "The Hat Matrix in Regression and ANOVA",
     * David C. Hoaglin and Roy E. Welsch,
     * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
     * </p>
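     *
     * <p>For example, the diagonal entries of the hat matrix give the leverage of
     * each observation (a sketch; {@code regression} is assumed to hold previously
     * loaded sample data and {@code i} a valid observation index):
     * <pre><code>
     * RealMatrix hat = regression.calculateHat();
     * // leverage of observation i is the i-th diagonal entry
     * double leverage = hat.getEntry(i, i);
     * </code></pre></p>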
     *
     * @return the hat matrix
     */
    public RealMatrix calculateHat() {
        // Create augmented identity matrix
        RealMatrix Q = qr.getQ();
        final int p = qr.getR().getColumnDimension();
        final int n = Q.getColumnDimension();
        Array2DRowRealMatrix augI = new Array2DRowRealMatrix(n, n);
        double[][] augIData = augI.getDataRef();
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (i == j && i < p) {
                    augIData[i][j] = 1d;
                } else {
                    augIData[i][j] = 0d;
                }
            }
        }

        // Compute and return Hat matrix
        return Q.multiply(augI).multiply(Q.transpose());
    }

    /**
     * <p>Returns the sum of squared deviations of Y from its mean.</p>
     *
     * <p>If the model has no intercept term, <code>0</code> is used for the
     * mean of Y - i.e., what is returned is the sum of the squared Y values.</p>
     *
     * <p>The value returned by this method is the SSTO value used in
     * the {@link #calculateRSquared() R-squared} computation.</p>
     *
     * @return SSTO - the total sum of squares
     * @see #isNoIntercept()
     * @since 2.2
     */
    public double calculateTotalSumOfSquares() {
        if (isNoIntercept()) {
            return StatUtils.sumSq(Y.getData());
        } else {
            return new SecondMoment().evaluate(Y.getData());
        }
    }

    /**
     * Returns the sum of squared residuals.
     *
     * @return residual sum of squares
     * @since 2.2
     */
    public double calculateResidualSumOfSquares() {
        final RealVector residuals = calculateResiduals();
        return residuals.dotProduct(residuals);
    }

    /**
     * Returns the R-squared statistic, defined by the formula <pre>
     * R<sup>2</sup> = 1 - SSR / SSTO
     * </pre>
     * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals}
     * and SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}.
     *
     * @return R-squared statistic
     * @since 2.2
     */
    public double calculateRSquared() {
        return 1 - calculateResidualSumOfSquares() / calculateTotalSumOfSquares();
    }

    /**
     * <p>Returns the adjusted R-squared statistic, defined by the formula <pre>
     * R<sup>2</sup><sub>adj</sub> = 1 - [SSR (n - 1)] / [SSTO (n - p)]
     * </pre>
     * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals},
     * SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}, n is the number
     * of observations and p is the number of parameters estimated (including the intercept).</p>
     *
     * <p>If the regression is estimated without an intercept term, what is returned is <pre>
     * <code> 1 - (1 - {@link #calculateRSquared()}) * (n / (n - p)) </code>
     * </pre></p>
     *
     * @return adjusted R-squared statistic
     * @see #isNoIntercept()
     * @since 2.2
     */
    public double calculateAdjustedRSquared() {
        final double n = X.getRowDimension();
        if (isNoIntercept()) {
            return 1 - (1 - calculateRSquared()) * (n / (n - X.getColumnDimension()));
        } else {
            return 1 - (calculateResidualSumOfSquares() * (n - 1)) /
                (calculateTotalSumOfSquares() * (n - X.getColumnDimension()));
        }
    }

    /**
     * {@inheritDoc}
     * <p>This implementation computes and caches the QR decomposition of the X matrix
     * once it is successfully loaded.</p>
     */
    @Override
    protected void newXSampleData(double[][] x) {
        super.newXSampleData(x);
        qr = new QRDecompositionImpl(X);
    }

    /**
     * Calculates the regression coefficients using OLS.
     *
     * @return beta
     */
    @Override
    protected RealVector calculateBeta() {
        return qr.getSolver().solve(Y);
    }

    /**
     * <p>Calculates the variance-covariance matrix of the regression parameters.
     * </p>
     * <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup>
     * </p>
     * <p>Uses QR decomposition to reduce (X<sup>T</sup>X)<sup>-1</sup>
     * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of
     * R included, where p = the length of the beta vector.
     * Since Q has orthonormal columns, X<sup>T</sup>X = R<sup>T</sup>Q<sup>T</sup>QR = R<sup>T</sup>R,
     * so (X<sup>T</sup>X)<sup>-1</sup> = (R<sup>T</sup>R)<sup>-1</sup> =
     * R<sup>-1</sup>(R<sup>-1</sup>)<sup>T</sup>, which is the product returned by this method.</p>
     *
     * @return The beta variance-covariance matrix
     */
    @Override
    protected RealMatrix calculateBetaVariance() {
        int p = X.getColumnDimension();
        RealMatrix Raug = qr.getR().getSubMatrix(0, p - 1, 0, p - 1);
        RealMatrix Rinv = new LUDecompositionImpl(Raug).getSolver().getInverse();
        return Rinv.multiply(Rinv.transpose());
    }

}