001 /*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017 package org.apache.commons.math.stat.regression;
018
019 import static org.junit.Assert.assertEquals;
020
021 import org.apache.commons.math.TestUtils;
022 import org.apache.commons.math.linear.DefaultRealMatrixChangingVisitor;
023 import org.apache.commons.math.linear.MatrixUtils;
024 import org.apache.commons.math.linear.MatrixVisitorException;
025 import org.apache.commons.math.linear.RealMatrix;
026 import org.apache.commons.math.linear.Array2DRowRealMatrix;
027 import org.junit.Before;
028 import org.junit.Test;
029
030 public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbstractTest {
031
032 private double[] y;
033 private double[][] x;
034
035 @Before
036 @Override
037 public void setUp(){
038 y = new double[]{11.0, 12.0, 13.0, 14.0, 15.0, 16.0};
039 x = new double[6][];
040 x[0] = new double[]{1.0, 0, 0, 0, 0, 0};
041 x[1] = new double[]{1.0, 2.0, 0, 0, 0, 0};
042 x[2] = new double[]{1.0, 0, 3.0, 0, 0, 0};
043 x[3] = new double[]{1.0, 0, 0, 4.0, 0, 0};
044 x[4] = new double[]{1.0, 0, 0, 0, 5.0, 0};
045 x[5] = new double[]{1.0, 0, 0, 0, 0, 6.0};
046 super.setUp();
047 }
048
049 @Override
050 protected OLSMultipleLinearRegression createRegression() {
051 OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
052 regression.newSampleData(y, x);
053 return regression;
054 }
055
056 @Override
057 protected int getNumberOfRegressors() {
058 return x[0].length;
059 }
060
061 @Override
062 protected int getSampleSize() {
063 return y.length;
064 }
065
066 @Test(expected=IllegalArgumentException.class)
067 public void cannotAddXSampleData() {
068 createRegression().newSampleData(new double[]{}, null);
069 }
070
071 @Test(expected=IllegalArgumentException.class)
072 public void cannotAddNullYSampleData() {
073 createRegression().newSampleData(null, new double[][]{});
074 }
075
076 @Test(expected=IllegalArgumentException.class)
077 public void cannotAddSampleDataWithSizeMismatch() {
078 double[] y = new double[]{1.0, 2.0};
079 double[][] x = new double[1][];
080 x[0] = new double[]{1.0, 0};
081 createRegression().newSampleData(y, x);
082 }
083
084 @Test
085 public void testPerfectFit() {
086 double[] betaHat = regression.estimateRegressionParameters();
087 TestUtils.assertEquals(betaHat,
088 new double[]{ 11.0, 1.0 / 2.0, 2.0 / 3.0, 3.0 / 4.0, 4.0 / 5.0, 5.0 / 6.0 },
089 1e-14);
090 double[] residuals = regression.estimateResiduals();
091 TestUtils.assertEquals(residuals, new double[]{0d,0d,0d,0d,0d,0d},
092 1e-14);
093 RealMatrix errors =
094 new Array2DRowRealMatrix(regression.estimateRegressionParametersVariance(), false);
095 final double[] s = { 1.0, -1.0 / 2.0, -1.0 / 3.0, -1.0 / 4.0, -1.0 / 5.0, -1.0 / 6.0 };
096 RealMatrix referenceVariance = new Array2DRowRealMatrix(s.length, s.length);
097 referenceVariance.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() {
098 @Override
099 public double visit(int row, int column, double value)
100 throws MatrixVisitorException {
101 if (row == 0) {
102 return s[column];
103 }
104 double x = s[row] * s[column];
105 return (row == column) ? 2 * x : x;
106 }
107 });
108 assertEquals(0.0,
109 errors.subtract(referenceVariance).getNorm(),
110 5.0e-16 * referenceVariance.getNorm());
111 }
112
113
114 /**
115 * Test Longley dataset against certified values provided by NIST.
116 * Data Source: J. Longley (1967) "An Appraisal of Least Squares
117 * Programs for the Electronic Computer from the Point of View of the User"
118 * Journal of the American Statistical Association, vol. 62. September,
119 * pp. 819-841.
120 *
121 * Certified values (and data) are from NIST:
122 * http://www.itl.nist.gov/div898/strd/lls/data/LINKS/DATA/Longley.dat
123 */
124 @Test
125 public void testLongly() {
126 // Y values are first, then independent vars
127 // Each row is one observation
128 double[] design = new double[] {
129 60323,83.0,234289,2356,1590,107608,1947,
130 61122,88.5,259426,2325,1456,108632,1948,
131 60171,88.2,258054,3682,1616,109773,1949,
132 61187,89.5,284599,3351,1650,110929,1950,
133 63221,96.2,328975,2099,3099,112075,1951,
134 63639,98.1,346999,1932,3594,113270,1952,
135 64989,99.0,365385,1870,3547,115094,1953,
136 63761,100.0,363112,3578,3350,116219,1954,
137 66019,101.2,397469,2904,3048,117388,1955,
138 67857,104.6,419180,2822,2857,118734,1956,
139 68169,108.4,442769,2936,2798,120445,1957,
140 66513,110.8,444546,4681,2637,121950,1958,
141 68655,112.6,482704,3813,2552,123366,1959,
142 69564,114.2,502601,3931,2514,125368,1960,
143 69331,115.7,518173,4806,2572,127852,1961,
144 70551,116.9,554894,4007,2827,130081,1962
145 };
146
147 // Transform to Y and X required by interface
148 int nobs = 16;
149 int nvars = 6;
150
151 // Estimate the model
152 OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
153 model.newSampleData(design, nobs, nvars);
154
155 // Check expected beta values from NIST
156 double[] betaHat = model.estimateRegressionParameters();
157 TestUtils.assertEquals(betaHat,
158 new double[]{-3482258.63459582, 15.0618722713733,
159 -0.358191792925910E-01,-2.02022980381683,
160 -1.03322686717359,-0.511041056535807E-01,
161 1829.15146461355}, 2E-8); //
162
163 // Check expected residuals from R
164 double[] residuals = model.estimateResiduals();
165 TestUtils.assertEquals(residuals, new double[]{
166 267.340029759711,-94.0139423988359,46.28716775752924,
167 -410.114621930906,309.7145907602313,-249.3112153297231,
168 -164.0489563956039,-13.18035686637081,14.30477260005235,
169 455.394094551857,-17.26892711483297,-39.0550425226967,
170 -155.5499735953195,-85.6713080421283,341.9315139607727,
171 -206.7578251937366},
172 1E-8);
173
174 // Check standard errors from NIST
175 double[] errors = model.estimateRegressionParametersStandardErrors();
176 TestUtils.assertEquals(new double[] {890420.383607373,
177 84.9149257747669,
178 0.334910077722432E-01,
179 0.488399681651699,
180 0.214274163161675,
181 0.226073200069370,
182 455.478499142212}, errors, 1E-6);
183 }
184
185 /**
186 * Test R Swiss fertility dataset against R.
187 * Data Source: R datasets package
188 */
189 @Test
190 public void testSwissFertility() {
191 double[] design = new double[] {
192 80.2,17.0,15,12,9.96,
193 83.1,45.1,6,9,84.84,
194 92.5,39.7,5,5,93.40,
195 85.8,36.5,12,7,33.77,
196 76.9,43.5,17,15,5.16,
197 76.1,35.3,9,7,90.57,
198 83.8,70.2,16,7,92.85,
199 92.4,67.8,14,8,97.16,
200 82.4,53.3,12,7,97.67,
201 82.9,45.2,16,13,91.38,
202 87.1,64.5,14,6,98.61,
203 64.1,62.0,21,12,8.52,
204 66.9,67.5,14,7,2.27,
205 68.9,60.7,19,12,4.43,
206 61.7,69.3,22,5,2.82,
207 68.3,72.6,18,2,24.20,
208 71.7,34.0,17,8,3.30,
209 55.7,19.4,26,28,12.11,
210 54.3,15.2,31,20,2.15,
211 65.1,73.0,19,9,2.84,
212 65.5,59.8,22,10,5.23,
213 65.0,55.1,14,3,4.52,
214 56.6,50.9,22,12,15.14,
215 57.4,54.1,20,6,4.20,
216 72.5,71.2,12,1,2.40,
217 74.2,58.1,14,8,5.23,
218 72.0,63.5,6,3,2.56,
219 60.5,60.8,16,10,7.72,
220 58.3,26.8,25,19,18.46,
221 65.4,49.5,15,8,6.10,
222 75.5,85.9,3,2,99.71,
223 69.3,84.9,7,6,99.68,
224 77.3,89.7,5,2,100.00,
225 70.5,78.2,12,6,98.96,
226 79.4,64.9,7,3,98.22,
227 65.0,75.9,9,9,99.06,
228 92.2,84.6,3,3,99.46,
229 79.3,63.1,13,13,96.83,
230 70.4,38.4,26,12,5.62,
231 65.7,7.7,29,11,13.79,
232 72.7,16.7,22,13,11.22,
233 64.4,17.6,35,32,16.92,
234 77.6,37.6,15,7,4.97,
235 67.6,18.7,25,7,8.65,
236 35.0,1.2,37,53,42.34,
237 44.7,46.6,16,29,50.43,
238 42.8,27.7,22,29,58.33
239 };
240
241 // Transform to Y and X required by interface
242 int nobs = 47;
243 int nvars = 4;
244
245 // Estimate the model
246 OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
247 model.newSampleData(design, nobs, nvars);
248
249 // Check expected beta values from R
250 double[] betaHat = model.estimateRegressionParameters();
251 TestUtils.assertEquals(betaHat,
252 new double[]{91.05542390271397,
253 -0.22064551045715,
254 -0.26058239824328,
255 -0.96161238456030,
256 0.12441843147162}, 1E-12);
257
258 // Check expected residuals from R
259 double[] residuals = model.estimateResiduals();
260 TestUtils.assertEquals(residuals, new double[]{
261 7.1044267859730512,1.6580347433531366,
262 4.6944952770029644,8.4548022690166160,13.6547432343186212,
263 -9.3586864458500774,7.5822446330520386,15.5568995563859289,
264 0.8113090736598980,7.1186762732484308,7.4251378771228724,
265 2.6761316873234109,0.8351584810309354,7.1769991119615177,
266 -3.8746753206299553,-3.1337779476387251,-0.1412575244091504,
267 1.1186809170469780,-6.3588097346816594,3.4039270429434074,
268 2.3374058329820175,-7.9272368576900503,-7.8361010968497959,
269 -11.2597369269357070,0.9445333697827101,6.6544245101380328,
270 -0.9146136301118665,-4.3152449403848570,-4.3536932047009183,
271 -3.8907885169304661,-6.3027643926302188,-7.8308982189289091,
272 -3.1792280015332750,-6.7167298771158226,-4.8469946718041754,
273 -10.6335664353633685,11.1031134362036958,6.0084032641811733,
274 5.4326230830188482,-7.2375578629692230,2.1671550814448222,
275 15.0147574652763112,4.8625103516321015,-7.1597256413907706,
276 -0.4515205619767598,-10.2916870903837587,-15.7812984571900063},
277 1E-12);
278
279 // Check standard errors from R
280 double[] errors = model.estimateRegressionParametersStandardErrors();
281 TestUtils.assertEquals(new double[] {6.94881329475087,
282 0.07360008972340,
283 0.27410957467466,
284 0.19454551679325,
285 0.03726654773803}, errors, 1E-10);
286 }
287
288 /**
289 * Test hat matrix computation
290 *
291 * @throws Exception
292 */
293 @Test
294 public void testHat() throws Exception {
295
296 /*
297 * This example is from "The Hat Matrix in Regression and ANOVA",
298 * David C. Hoaglin and Roy E. Welsch,
299 * The American Statistician, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
300 *
301 */
302 double[] design = new double[] {
303 11.14, .499, 11.1,
304 12.74, .558, 8.9,
305 13.13, .604, 8.8,
306 11.51, .441, 8.9,
307 12.38, .550, 8.8,
308 12.60, .528, 9.9,
309 11.13, .418, 10.7,
310 11.7, .480, 10.5,
311 11.02, .406, 10.5,
312 11.41, .467, 10.7
313 };
314
315 int nobs = 10;
316 int nvars = 2;
317
318 // Estimate the model
319 OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
320 model.newSampleData(design, nobs, nvars);
321
322 RealMatrix hat = model.calculateHat();
323
324 // Reference data is upper half of symmetric hat matrix
325 double[] referenceData = new double[] {
326 .418, -.002, .079, -.274, -.046, .181, .128, .222, .050, .242,
327 .242, .292, .136, .243, .128, -.041, .033, -.035, .004,
328 .417, -.019, .273, .187, -.126, .044, -.153, .004,
329 .604, .197, -.038, .168, -.022, .275, -.028,
330 .252, .111, -.030, .019, -.010, -.010,
331 .148, .042, .117, .012, .111,
332 .262, .145, .277, .174,
333 .154, .120, .168,
334 .315, .148,
335 .187
336 };
337
338 // Check against reference data and verify symmetry
339 int k = 0;
340 for (int i = 0; i < 10; i++) {
341 for (int j = i; j < 10; j++) {
342 assertEquals(referenceData[k], hat.getEntry(i, j), 10e-3);
343 assertEquals(hat.getEntry(i, j), hat.getEntry(j, i), 10e-12);
344 k++;
345 }
346 }
347
348 /*
349 * Verify that residuals computed using the hat matrix are close to
350 * what we get from direct computation, i.e. r = (I - H) y
351 */
352 double[] residuals = model.estimateResiduals();
353 RealMatrix I = MatrixUtils.createRealIdentityMatrix(10);
354 double[] hatResiduals = I.subtract(hat).operate(model.Y).getData();
355 TestUtils.assertEquals(residuals, hatResiduals, 10e-12);
356 }
357 }