1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17 package org.apache.commons.math.stat.regression;
18
19 import org.apache.commons.math.MathRuntimeException;
20 import org.apache.commons.math.linear.RealMatrix;
21 import org.apache.commons.math.linear.Array2DRowRealMatrix;
22 import org.apache.commons.math.linear.RealVector;
23 import org.apache.commons.math.linear.ArrayRealVector;
24
25 /**
26 * Abstract base class for implementations of MultipleLinearRegression.
27 * @version $Revision: 791244 $ $Date: 2009-07-05 09:29:37 -0400 (Sun, 05 Jul 2009) $
28 * @since 2.0
29 */
30 public abstract class AbstractMultipleLinearRegression implements
31 MultipleLinearRegression {
32
33 /** X sample data. */
34 protected RealMatrix X;
35
36 /** Y sample data. */
37 protected RealVector Y;
38
39 /**
40 * Loads model x and y sample data from a flat array of data, overriding any previous sample.
41 * Assumes that rows are concatenated with y values first in each row.
42 *
43 * @param data input data array
44 * @param nobs number of observations (rows)
45 * @param nvars number of independent variables (columns, not counting y)
46 */
47 public void newSampleData(double[] data, int nobs, int nvars) {
48 double[] y = new double[nobs];
49 double[][] x = new double[nobs][nvars + 1];
50 int pointer = 0;
51 for (int i = 0; i < nobs; i++) {
52 y[i] = data[pointer++];
53 x[i][0] = 1.0d;
54 for (int j = 1; j < nvars + 1; j++) {
55 x[i][j] = data[pointer++];
56 }
57 }
58 this.X = new Array2DRowRealMatrix(x);
59 this.Y = new ArrayRealVector(y);
60 }
61
62 /**
63 * Loads new y sample data, overriding any previous sample
64 *
65 * @param y the [n,1] array representing the y sample
66 */
67 protected void newYSampleData(double[] y) {
68 this.Y = new ArrayRealVector(y);
69 }
70
71 /**
72 * Loads new x sample data, overriding any previous sample
73 *
74 * @param x the [n,k] array representing the x sample
75 */
76 protected void newXSampleData(double[][] x) {
77 this.X = new Array2DRowRealMatrix(x);
78 }
79
80 /**
81 * Validates sample data.
82 *
83 * @param x the [n,k] array representing the x sample
84 * @param y the [n,1] array representing the y sample
85 * @throws IllegalArgumentException if the x and y array data are not
86 * compatible for the regression
87 */
88 protected void validateSampleData(double[][] x, double[] y) {
89 if ((x == null) || (y == null) || (x.length != y.length)) {
90 throw MathRuntimeException.createIllegalArgumentException(
91 "dimension mismatch {0} != {1}",
92 (x == null) ? 0 : x.length,
93 (y == null) ? 0 : y.length);
94 } else if ((x.length > 0) && (x[0].length > x.length)) {
95 throw MathRuntimeException.createIllegalArgumentException(
96 "not enough data ({0} rows) for this many predictors ({1} predictors)",
97 x.length, x[0].length);
98 }
99 }
100
101 /**
102 * Validates sample data.
103 *
104 * @param x the [n,k] array representing the x sample
105 * @param covariance the [n,n] array representing the covariance matrix
106 * @throws IllegalArgumentException if the x sample data or covariance
107 * matrix are not compatible for the regression
108 */
109 protected void validateCovarianceData(double[][] x, double[][] covariance) {
110 if (x.length != covariance.length) {
111 throw MathRuntimeException.createIllegalArgumentException(
112 "dimension mismatch {0} != {1}", x.length, covariance.length);
113 }
114 if (covariance.length > 0 && covariance.length != covariance[0].length) {
115 throw MathRuntimeException.createIllegalArgumentException(
116 "a {0}x{1} matrix was provided instead of a square matrix",
117 covariance.length, covariance[0].length);
118 }
119 }
120
121 /**
122 * {@inheritDoc}
123 */
124 public double[] estimateRegressionParameters() {
125 RealVector b = calculateBeta();
126 return b.getData();
127 }
128
129 /**
130 * {@inheritDoc}
131 */
132 public double[] estimateResiduals() {
133 RealVector b = calculateBeta();
134 RealVector e = Y.subtract(X.operate(b));
135 return e.getData();
136 }
137
138 /**
139 * {@inheritDoc}
140 */
141 public double[][] estimateRegressionParametersVariance() {
142 return calculateBetaVariance().getData();
143 }
144
145 /**
146 * {@inheritDoc}
147 */
148 public double[] estimateRegressionParametersStandardErrors() {
149 double[][] betaVariance = estimateRegressionParametersVariance();
150 double sigma = calculateYVariance();
151 int length = betaVariance[0].length;
152 double[] result = new double[length];
153 for (int i = 0; i < length; i++) {
154 result[i] = Math.sqrt(sigma * betaVariance[i][i]);
155 }
156 return result;
157 }
158
159 /**
160 * {@inheritDoc}
161 */
162 public double estimateRegressandVariance() {
163 return calculateYVariance();
164 }
165
166 /**
167 * Calculates the beta of multiple linear regression in matrix notation.
168 *
169 * @return beta
170 */
171 protected abstract RealVector calculateBeta();
172
173 /**
174 * Calculates the beta variance of multiple linear regression in matrix
175 * notation.
176 *
177 * @return beta variance
178 */
179 protected abstract RealMatrix calculateBetaVariance();
180
181 /**
182 * Calculates the Y variance of multiple linear regression.
183 *
184 * @return Y variance
185 */
186 protected abstract double calculateYVariance();
187
188 /**
189 * Calculates the residuals of multiple linear regression in matrix
190 * notation.
191 *
192 * <pre>
193 * u = y - X * b
194 * </pre>
195 *
196 * @return The residuals [n,1] matrix
197 */
198 protected RealVector calculateResiduals() {
199 RealVector b = calculateBeta();
200 return Y.subtract(X.operate(b));
201 }
202
203 }