001 /*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017 package org.apache.commons.math.stat.descriptive;
018
019
020 import java.util.Locale;
021
022 import junit.framework.Test;
023 import junit.framework.TestCase;
024 import junit.framework.TestSuite;
025
026 import org.apache.commons.math.DimensionMismatchException;
027 import org.apache.commons.math.TestUtils;
028 import org.apache.commons.math.stat.descriptive.moment.Mean;
029
030 /**
031 * Test cases for the {@link MultivariateSummaryStatistics} class.
032 *
033 * @version $Revision: 797744 $ $Date: 2009-07-25 07:09:14 -0400 (Sat, 25 Jul 2009) $
034 */
035
036 public class MultivariateSummaryStatisticsTest extends TestCase {
037
038 public MultivariateSummaryStatisticsTest(String name) {
039 super(name);
040 }
041
042 public static Test suite() {
043 TestSuite suite = new TestSuite(MultivariateSummaryStatisticsTest.class);
044 suite.setName("MultivariateSummaryStatistics tests");
045 return suite;
046 }
047
048 protected MultivariateSummaryStatistics createMultivariateSummaryStatistics(int k, boolean isCovarianceBiasCorrected) {
049 return new MultivariateSummaryStatistics(k, isCovarianceBiasCorrected);
050 }
051
052 public void testSetterInjection() throws Exception {
053 MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true);
054 u.setMeanImpl(new StorelessUnivariateStatistic[] {
055 new sumMean(), new sumMean()
056 });
057 u.addValue(new double[] { 1, 2 });
058 u.addValue(new double[] { 3, 4 });
059 assertEquals(4, u.getMean()[0], 1E-14);
060 assertEquals(6, u.getMean()[1], 1E-14);
061 u.clear();
062 u.addValue(new double[] { 1, 2 });
063 u.addValue(new double[] { 3, 4 });
064 assertEquals(4, u.getMean()[0], 1E-14);
065 assertEquals(6, u.getMean()[1], 1E-14);
066 u.clear();
067 u.setMeanImpl(new StorelessUnivariateStatistic[] {
068 new Mean(), new Mean()
069 }); // OK after clear
070 u.addValue(new double[] { 1, 2 });
071 u.addValue(new double[] { 3, 4 });
072 assertEquals(2, u.getMean()[0], 1E-14);
073 assertEquals(3, u.getMean()[1], 1E-14);
074 assertEquals(2, u.getDimension());
075 }
076
077 public void testSetterIllegalState() throws Exception {
078 MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true);
079 u.addValue(new double[] { 1, 2 });
080 u.addValue(new double[] { 3, 4 });
081 try {
082 u.setMeanImpl(new StorelessUnivariateStatistic[] {
083 new sumMean(), new sumMean()
084 });
085 fail("Expecting IllegalStateException");
086 } catch (IllegalStateException ex) {
087 // expected
088 }
089 }
090
091 public void testToString() throws DimensionMismatchException {
092 MultivariateSummaryStatistics stats = createMultivariateSummaryStatistics(2, true);
093 stats.addValue(new double[] {1, 3});
094 stats.addValue(new double[] {2, 2});
095 stats.addValue(new double[] {3, 1});
096 Locale d = Locale.getDefault();
097 Locale.setDefault(Locale.US);
098 assertEquals("MultivariateSummaryStatistics:\n" +
099 "n: 3\n" +
100 "min: 1.0, 1.0\n" +
101 "max: 3.0, 3.0\n" +
102 "mean: 2.0, 2.0\n" +
103 "geometric mean: 1.817..., 1.817...\n" +
104 "sum of squares: 14.0, 14.0\n" +
105 "sum of logarithms: 1.791..., 1.791...\n" +
106 "standard deviation: 1.0, 1.0\n" +
107 "covariance: Array2DRowRealMatrix{{1.0,-1.0},{-1.0,1.0}}\n",
108 stats.toString().replaceAll("([0-9]+\\.[0-9][0-9][0-9])[0-9]+", "$1..."));
109 Locale.setDefault(d);
110 }
111
112 public void testShuffledStatistics() throws DimensionMismatchException {
113 // the purpose of this test is only to check the get/set methods
114 // we are aware shuffling statistics like this is really not
115 // something sensible to do in production ...
116 MultivariateSummaryStatistics reference = createMultivariateSummaryStatistics(2, true);
117 MultivariateSummaryStatistics shuffled = createMultivariateSummaryStatistics(2, true);
118
119 StorelessUnivariateStatistic[] tmp = shuffled.getGeoMeanImpl();
120 shuffled.setGeoMeanImpl(shuffled.getMeanImpl());
121 shuffled.setMeanImpl(shuffled.getMaxImpl());
122 shuffled.setMaxImpl(shuffled.getMinImpl());
123 shuffled.setMinImpl(shuffled.getSumImpl());
124 shuffled.setSumImpl(shuffled.getSumsqImpl());
125 shuffled.setSumsqImpl(shuffled.getSumLogImpl());
126 shuffled.setSumLogImpl(tmp);
127
128 for (int i = 100; i > 0; --i) {
129 reference.addValue(new double[] {i, i});
130 shuffled.addValue(new double[] {i, i});
131 }
132
133 TestUtils.assertEquals(reference.getMean(), shuffled.getGeometricMean(), 1.0e-10);
134 TestUtils.assertEquals(reference.getMax(), shuffled.getMean(), 1.0e-10);
135 TestUtils.assertEquals(reference.getMin(), shuffled.getMax(), 1.0e-10);
136 TestUtils.assertEquals(reference.getSum(), shuffled.getMin(), 1.0e-10);
137 TestUtils.assertEquals(reference.getSumSq(), shuffled.getSum(), 1.0e-10);
138 TestUtils.assertEquals(reference.getSumLog(), shuffled.getSumSq(), 1.0e-10);
139 TestUtils.assertEquals(reference.getGeometricMean(), shuffled.getSumLog(), 1.0e-10);
140
141 }
142
143 /**
144 * Bogus mean implementation to test setter injection.
145 * Returns the sum instead of the mean.
146 */
147 static class sumMean implements StorelessUnivariateStatistic {
148 private double sum = 0;
149 private long n = 0;
150 public double evaluate(double[] values, int begin, int length) {
151 return 0;
152 }
153 public double evaluate(double[] values) {
154 return 0;
155 }
156 public void clear() {
157 sum = 0;
158 n = 0;
159 }
160 public long getN() {
161 return n;
162 }
163 public double getResult() {
164 return sum;
165 }
166 public void increment(double d) {
167 sum += d;
168 n++;
169 }
170 public void incrementAll(double[] values, int start, int length) {
171 }
172 public void incrementAll(double[] values) {
173 }
174 public StorelessUnivariateStatistic copy() {
175 return new sumMean();
176 }
177 }
178
179 public void testDimension() {
180 try {
181 createMultivariateSummaryStatistics(2, true).addValue(new double[3]);
182 } catch (DimensionMismatchException dme) {
183 // expected behavior
184 } catch (Exception e) {
185 fail("wrong exception caught");
186 }
187 }
188
189 /** test stats */
190 public void testStats() throws DimensionMismatchException {
191 MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true);
192 assertEquals(0, u.getN());
193 u.addValue(new double[] { 1, 2 });
194 u.addValue(new double[] { 2, 3 });
195 u.addValue(new double[] { 2, 3 });
196 u.addValue(new double[] { 3, 4 });
197 assertEquals( 4, u.getN());
198 assertEquals( 8, u.getSum()[0], 1.0e-10);
199 assertEquals(12, u.getSum()[1], 1.0e-10);
200 assertEquals(18, u.getSumSq()[0], 1.0e-10);
201 assertEquals(38, u.getSumSq()[1], 1.0e-10);
202 assertEquals( 1, u.getMin()[0], 1.0e-10);
203 assertEquals( 2, u.getMin()[1], 1.0e-10);
204 assertEquals( 3, u.getMax()[0], 1.0e-10);
205 assertEquals( 4, u.getMax()[1], 1.0e-10);
206 assertEquals(2.4849066497880003102, u.getSumLog()[0], 1.0e-10);
207 assertEquals( 4.276666119016055311, u.getSumLog()[1], 1.0e-10);
208 assertEquals( 1.8612097182041991979, u.getGeometricMean()[0], 1.0e-10);
209 assertEquals( 2.9129506302439405217, u.getGeometricMean()[1], 1.0e-10);
210 assertEquals( 2, u.getMean()[0], 1.0e-10);
211 assertEquals( 3, u.getMean()[1], 1.0e-10);
212 assertEquals(Math.sqrt(2.0 / 3.0), u.getStandardDeviation()[0], 1.0e-10);
213 assertEquals(Math.sqrt(2.0 / 3.0), u.getStandardDeviation()[1], 1.0e-10);
214 assertEquals(2.0 / 3.0, u.getCovariance().getEntry(0, 0), 1.0e-10);
215 assertEquals(2.0 / 3.0, u.getCovariance().getEntry(0, 1), 1.0e-10);
216 assertEquals(2.0 / 3.0, u.getCovariance().getEntry(1, 0), 1.0e-10);
217 assertEquals(2.0 / 3.0, u.getCovariance().getEntry(1, 1), 1.0e-10);
218 u.clear();
219 assertEquals(0, u.getN());
220 }
221
222 public void testN0andN1Conditions() throws Exception {
223 MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(1, true);
224 assertTrue(Double.isNaN(u.getMean()[0]));
225 assertTrue(Double.isNaN(u.getStandardDeviation()[0]));
226
227 /* n=1 */
228 u.addValue(new double[] { 1 });
229 assertEquals(1.0, u.getMean()[0], 1.0e-10);
230 assertEquals(1.0, u.getGeometricMean()[0], 1.0e-10);
231 assertEquals(0.0, u.getStandardDeviation()[0], 1.0e-10);
232
233 /* n=2 */
234 u.addValue(new double[] { 2 });
235 assertTrue(u.getStandardDeviation()[0] > 0);
236
237 }
238
239 public void testNaNContracts() throws DimensionMismatchException {
240 MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(1, true);
241 assertTrue(Double.isNaN(u.getMean()[0]));
242 assertTrue(Double.isNaN(u.getMin()[0]));
243 assertTrue(Double.isNaN(u.getStandardDeviation()[0]));
244 assertTrue(Double.isNaN(u.getGeometricMean()[0]));
245
246 u.addValue(new double[] { 1.0 });
247 assertFalse(Double.isNaN(u.getMean()[0]));
248 assertFalse(Double.isNaN(u.getMin()[0]));
249 assertFalse(Double.isNaN(u.getStandardDeviation()[0]));
250 assertFalse(Double.isNaN(u.getGeometricMean()[0]));
251
252 }
253
254 public void testSerialization() throws DimensionMismatchException {
255 MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true);
256 // Empty test
257 TestUtils.checkSerializedEquality(u);
258 MultivariateSummaryStatistics s = (MultivariateSummaryStatistics) TestUtils.serializeAndRecover(u);
259 assertEquals(u, s);
260
261 // Add some data
262 u.addValue(new double[] { 2d, 1d });
263 u.addValue(new double[] { 1d, 1d });
264 u.addValue(new double[] { 3d, 1d });
265 u.addValue(new double[] { 4d, 1d });
266 u.addValue(new double[] { 5d, 1d });
267
268 // Test again
269 TestUtils.checkSerializedEquality(u);
270 s = (MultivariateSummaryStatistics) TestUtils.serializeAndRecover(u);
271 assertEquals(u, s);
272
273 }
274
275 public void testEqualsAndHashCode() throws DimensionMismatchException {
276 MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true);
277 MultivariateSummaryStatistics t = null;
278 int emptyHash = u.hashCode();
279 assertTrue(u.equals(u));
280 assertFalse(u.equals(t));
281 assertFalse(u.equals(Double.valueOf(0)));
282 t = createMultivariateSummaryStatistics(2, true);
283 assertTrue(t.equals(u));
284 assertTrue(u.equals(t));
285 assertEquals(emptyHash, t.hashCode());
286
287 // Add some data to u
288 u.addValue(new double[] { 2d, 1d });
289 u.addValue(new double[] { 1d, 1d });
290 u.addValue(new double[] { 3d, 1d });
291 u.addValue(new double[] { 4d, 1d });
292 u.addValue(new double[] { 5d, 1d });
293 assertFalse(t.equals(u));
294 assertFalse(u.equals(t));
295 assertTrue(u.hashCode() != t.hashCode());
296
297 //Add data in same order to t
298 t.addValue(new double[] { 2d, 1d });
299 t.addValue(new double[] { 1d, 1d });
300 t.addValue(new double[] { 3d, 1d });
301 t.addValue(new double[] { 4d, 1d });
302 t.addValue(new double[] { 5d, 1d });
303 assertTrue(t.equals(u));
304 assertTrue(u.equals(t));
305 assertEquals(u.hashCode(), t.hashCode());
306
307 // Clear and make sure summaries are indistinguishable from empty summary
308 u.clear();
309 t.clear();
310 assertTrue(t.equals(u));
311 assertTrue(u.equals(t));
312 assertEquals(emptyHash, t.hashCode());
313 assertEquals(emptyHash, u.hashCode());
314 }
315
316 }