001 /*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018 package org.apache.commons.math.optimization.direct;
019
020 import static org.junit.Assert.assertEquals;
021 import static org.junit.Assert.assertNotNull;
022 import static org.junit.Assert.assertNull;
023 import static org.junit.Assert.assertTrue;
024 import static org.junit.Assert.fail;
025
026 import org.apache.commons.math.ConvergenceException;
027 import org.apache.commons.math.FunctionEvaluationException;
028 import org.apache.commons.math.MathException;
029 import org.apache.commons.math.MaxEvaluationsExceededException;
030 import org.apache.commons.math.MaxIterationsExceededException;
031 import org.apache.commons.math.analysis.MultivariateRealFunction;
032 import org.apache.commons.math.analysis.MultivariateVectorialFunction;
033 import org.apache.commons.math.linear.Array2DRowRealMatrix;
034 import org.apache.commons.math.linear.RealMatrix;
035 import org.apache.commons.math.optimization.GoalType;
036 import org.apache.commons.math.optimization.LeastSquaresConverter;
037 import org.apache.commons.math.optimization.OptimizationException;
038 import org.apache.commons.math.optimization.RealPointValuePair;
039 import org.apache.commons.math.optimization.SimpleRealPointChecker;
040 import org.apache.commons.math.optimization.SimpleScalarValueChecker;
041 import org.junit.Test;
042
043 public class NelderMeadTest {
044
045 @Test
046 public void testFunctionEvaluationExceptions() {
047 MultivariateRealFunction wrong =
048 new MultivariateRealFunction() {
049 private static final long serialVersionUID = 4751314470965489371L;
050 public double value(double[] x) throws FunctionEvaluationException {
051 if (x[0] < 0) {
052 throw new FunctionEvaluationException(x, "{0}", "oops");
053 } else if (x[0] > 1) {
054 throw new FunctionEvaluationException(new RuntimeException("oops"), x);
055 } else {
056 return x[0] * (1 - x[0]);
057 }
058 }
059 };
060 try {
061 NelderMead optimizer = new NelderMead(0.9, 1.9, 0.4, 0.6);
062 optimizer.optimize(wrong, GoalType.MINIMIZE, new double[] { -1.0 });
063 fail("an exception should have been thrown");
064 } catch (FunctionEvaluationException ce) {
065 // expected behavior
066 assertNull(ce.getCause());
067 } catch (Exception e) {
068 fail("wrong exception caught: " + e.getMessage());
069 }
070 try {
071 NelderMead optimizer = new NelderMead(0.9, 1.9, 0.4, 0.6);
072 optimizer.optimize(wrong, GoalType.MINIMIZE, new double[] { +2.0 });
073 fail("an exception should have been thrown");
074 } catch (FunctionEvaluationException ce) {
075 // expected behavior
076 assertNotNull(ce.getCause());
077 } catch (Exception e) {
078 fail("wrong exception caught: " + e.getMessage());
079 }
080 }
081
082 @Test
083 public void testMinimizeMaximize()
084 throws FunctionEvaluationException, ConvergenceException {
085
086 // the following function has 4 local extrema:
087 final double xM = -3.841947088256863675365;
088 final double yM = -1.391745200270734924416;
089 final double xP = 0.2286682237349059125691;
090 final double yP = -yM;
091 final double valueXmYm = 0.2373295333134216789769; // local maximum
092 final double valueXmYp = -valueXmYm; // local minimum
093 final double valueXpYm = -0.7290400707055187115322; // global minimum
094 final double valueXpYp = -valueXpYm; // global maximum
095 MultivariateRealFunction fourExtrema = new MultivariateRealFunction() {
096 private static final long serialVersionUID = -7039124064449091152L;
097 public double value(double[] variables) throws FunctionEvaluationException {
098 final double x = variables[0];
099 final double y = variables[1];
100 return ((x == 0) || (y == 0)) ? 0 : (Math.atan(x) * Math.atan(x + 2) * Math.atan(y) * Math.atan(y) / (x * y));
101 }
102 };
103
104 NelderMead optimizer = new NelderMead();
105 optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-10, 1.0e-30));
106 optimizer.setMaxIterations(100);
107 optimizer.setStartConfiguration(new double[] { 0.2, 0.2 });
108 RealPointValuePair optimum;
109
110 // minimization
111 optimum = optimizer.optimize(fourExtrema, GoalType.MINIMIZE, new double[] { -3.0, 0 });
112 assertEquals(xM, optimum.getPoint()[0], 2.0e-7);
113 assertEquals(yP, optimum.getPoint()[1], 2.0e-5);
114 assertEquals(valueXmYp, optimum.getValue(), 6.0e-12);
115 assertTrue(optimizer.getEvaluations() > 60);
116 assertTrue(optimizer.getEvaluations() < 90);
117
118 optimum = optimizer.optimize(fourExtrema, GoalType.MINIMIZE, new double[] { +1, 0 });
119 assertEquals(xP, optimum.getPoint()[0], 5.0e-6);
120 assertEquals(yM, optimum.getPoint()[1], 6.0e-6);
121 assertEquals(valueXpYm, optimum.getValue(), 1.0e-11);
122 assertTrue(optimizer.getEvaluations() > 60);
123 assertTrue(optimizer.getEvaluations() < 90);
124
125 // maximization
126 optimum = optimizer.optimize(fourExtrema, GoalType.MAXIMIZE, new double[] { -3.0, 0.0 });
127 assertEquals(xM, optimum.getPoint()[0], 1.0e-5);
128 assertEquals(yM, optimum.getPoint()[1], 3.0e-6);
129 assertEquals(valueXmYm, optimum.getValue(), 3.0e-12);
130 assertTrue(optimizer.getEvaluations() > 60);
131 assertTrue(optimizer.getEvaluations() < 90);
132
133 optimum = optimizer.optimize(fourExtrema, GoalType.MAXIMIZE, new double[] { +1, 0 });
134 assertEquals(xP, optimum.getPoint()[0], 4.0e-6);
135 assertEquals(yP, optimum.getPoint()[1], 5.0e-6);
136 assertEquals(valueXpYp, optimum.getValue(), 7.0e-12);
137 assertTrue(optimizer.getEvaluations() > 60);
138 assertTrue(optimizer.getEvaluations() < 90);
139
140 }
141
142 @Test
143 public void testRosenbrock()
144 throws FunctionEvaluationException, ConvergenceException {
145
146 Rosenbrock rosenbrock = new Rosenbrock();
147 NelderMead optimizer = new NelderMead();
148 optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1, 1.0e-3));
149 optimizer.setMaxIterations(100);
150 optimizer.setStartConfiguration(new double[][] {
151 { -1.2, 1.0 }, { 0.9, 1.2 } , { 3.5, -2.3 }
152 });
153 RealPointValuePair optimum =
154 optimizer.optimize(rosenbrock, GoalType.MINIMIZE, new double[] { -1.2, 1.0 });
155
156 assertEquals(rosenbrock.getCount(), optimizer.getEvaluations());
157 assertTrue(optimizer.getEvaluations() > 40);
158 assertTrue(optimizer.getEvaluations() < 50);
159 assertTrue(optimum.getValue() < 8.0e-4);
160
161 }
162
163 @Test
164 public void testPowell()
165 throws FunctionEvaluationException, ConvergenceException {
166
167 Powell powell = new Powell();
168 NelderMead optimizer = new NelderMead();
169 optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-3));
170 optimizer.setMaxIterations(200);
171 RealPointValuePair optimum =
172 optimizer.optimize(powell, GoalType.MINIMIZE, new double[] { 3.0, -1.0, 0.0, 1.0 });
173 assertEquals(powell.getCount(), optimizer.getEvaluations());
174 assertTrue(optimizer.getEvaluations() > 110);
175 assertTrue(optimizer.getEvaluations() < 130);
176 assertTrue(optimum.getValue() < 2.0e-3);
177
178 }
179
180 @Test
181 public void testLeastSquares1()
182 throws FunctionEvaluationException, ConvergenceException {
183
184 final RealMatrix factors =
185 new Array2DRowRealMatrix(new double[][] {
186 { 1.0, 0.0 },
187 { 0.0, 1.0 }
188 }, false);
189 LeastSquaresConverter ls = new LeastSquaresConverter(new MultivariateVectorialFunction() {
190 public double[] value(double[] variables) {
191 return factors.operate(variables);
192 }
193 }, new double[] { 2.0, -3.0 });
194 NelderMead optimizer = new NelderMead();
195 optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-6));
196 optimizer.setMaxIterations(200);
197 RealPointValuePair optimum =
198 optimizer.optimize(ls, GoalType.MINIMIZE, new double[] { 10.0, 10.0 });
199 assertEquals( 2.0, optimum.getPointRef()[0], 3.0e-5);
200 assertEquals(-3.0, optimum.getPointRef()[1], 4.0e-4);
201 assertTrue(optimizer.getEvaluations() > 60);
202 assertTrue(optimizer.getEvaluations() < 80);
203 assertTrue(optimum.getValue() < 1.0e-6);
204 }
205
206 @Test
207 public void testLeastSquares2()
208 throws FunctionEvaluationException, ConvergenceException {
209
210 final RealMatrix factors =
211 new Array2DRowRealMatrix(new double[][] {
212 { 1.0, 0.0 },
213 { 0.0, 1.0 }
214 }, false);
215 LeastSquaresConverter ls = new LeastSquaresConverter(new MultivariateVectorialFunction() {
216 public double[] value(double[] variables) {
217 return factors.operate(variables);
218 }
219 }, new double[] { 2.0, -3.0 }, new double[] { 10.0, 0.1 });
220 NelderMead optimizer = new NelderMead();
221 optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-6));
222 optimizer.setMaxIterations(200);
223 RealPointValuePair optimum =
224 optimizer.optimize(ls, GoalType.MINIMIZE, new double[] { 10.0, 10.0 });
225 assertEquals( 2.0, optimum.getPointRef()[0], 5.0e-5);
226 assertEquals(-3.0, optimum.getPointRef()[1], 8.0e-4);
227 assertTrue(optimizer.getEvaluations() > 60);
228 assertTrue(optimizer.getEvaluations() < 80);
229 assertTrue(optimum.getValue() < 1.0e-6);
230 }
231
232 @Test
233 public void testLeastSquares3()
234 throws FunctionEvaluationException, ConvergenceException {
235
236 final RealMatrix factors =
237 new Array2DRowRealMatrix(new double[][] {
238 { 1.0, 0.0 },
239 { 0.0, 1.0 }
240 }, false);
241 LeastSquaresConverter ls = new LeastSquaresConverter(new MultivariateVectorialFunction() {
242 public double[] value(double[] variables) {
243 return factors.operate(variables);
244 }
245 }, new double[] { 2.0, -3.0 }, new Array2DRowRealMatrix(new double [][] {
246 { 1.0, 1.2 }, { 1.2, 2.0 }
247 }));
248 NelderMead optimizer = new NelderMead();
249 optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-6));
250 optimizer.setMaxIterations(200);
251 RealPointValuePair optimum =
252 optimizer.optimize(ls, GoalType.MINIMIZE, new double[] { 10.0, 10.0 });
253 assertEquals( 2.0, optimum.getPointRef()[0], 2.0e-3);
254 assertEquals(-3.0, optimum.getPointRef()[1], 8.0e-4);
255 assertTrue(optimizer.getEvaluations() > 60);
256 assertTrue(optimizer.getEvaluations() < 80);
257 assertTrue(optimum.getValue() < 1.0e-6);
258 }
259
260 @Test(expected = MaxIterationsExceededException.class)
261 public void testMaxIterations() throws MathException {
262 try {
263 Powell powell = new Powell();
264 NelderMead optimizer = new NelderMead();
265 optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-3));
266 optimizer.setMaxIterations(20);
267 optimizer.optimize(powell, GoalType.MINIMIZE, new double[] { 3.0, -1.0, 0.0, 1.0 });
268 } catch (OptimizationException oe) {
269 if (oe.getCause() instanceof ConvergenceException) {
270 throw (ConvergenceException) oe.getCause();
271 }
272 throw oe;
273 }
274 }
275
276 @Test(expected = MaxEvaluationsExceededException.class)
277 public void testMaxEvaluations() throws MathException {
278 try {
279 Powell powell = new Powell();
280 NelderMead optimizer = new NelderMead();
281 optimizer.setConvergenceChecker(new SimpleRealPointChecker(-1.0, 1.0e-3));
282 optimizer.setMaxEvaluations(20);
283 optimizer.optimize(powell, GoalType.MINIMIZE, new double[] { 3.0, -1.0, 0.0, 1.0 });
284 } catch (FunctionEvaluationException fee) {
285 if (fee.getCause() instanceof ConvergenceException) {
286 throw (ConvergenceException) fee.getCause();
287 }
288 throw fee;
289 }
290 }
291
292 private static class Rosenbrock implements MultivariateRealFunction {
293
294 private int count;
295
296 public Rosenbrock() {
297 count = 0;
298 }
299
300 public double value(double[] x) throws FunctionEvaluationException {
301 ++count;
302 double a = x[1] - x[0] * x[0];
303 double b = 1.0 - x[0];
304 return 100 * a * a + b * b;
305 }
306
307 public int getCount() {
308 return count;
309 }
310
311 }
312
313 private static class Powell implements MultivariateRealFunction {
314
315 private int count;
316
317 public Powell() {
318 count = 0;
319 }
320
321 public double value(double[] x) throws FunctionEvaluationException {
322 ++count;
323 double a = x[0] + 10 * x[1];
324 double b = x[2] - x[3];
325 double c = x[1] - 2 * x[2];
326 double d = x[0] - x[3];
327 return a * a + 5 * b * b + c * c * c * c + 10 * d * d * d * d;
328 }
329
330 public int getCount() {
331 return count;
332 }
333
334 }
335
336 }