001 /*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018 package org.apache.commons.math.optimization;
019
020 import static org.junit.Assert.assertEquals;
021 import static org.junit.Assert.assertTrue;
022
023 import java.awt.geom.Point2D;
024 import java.util.ArrayList;
025
026 import org.apache.commons.math.FunctionEvaluationException;
027 import org.apache.commons.math.analysis.DifferentiableMultivariateRealFunction;
028 import org.apache.commons.math.analysis.MultivariateRealFunction;
029 import org.apache.commons.math.analysis.MultivariateVectorialFunction;
030 import org.apache.commons.math.analysis.solvers.BrentSolver;
031 import org.apache.commons.math.optimization.general.ConjugateGradientFormula;
032 import org.apache.commons.math.optimization.general.NonLinearConjugateGradientOptimizer;
033 import org.apache.commons.math.random.GaussianRandomGenerator;
034 import org.apache.commons.math.random.JDKRandomGenerator;
035 import org.apache.commons.math.random.RandomVectorGenerator;
036 import org.apache.commons.math.random.UncorrelatedRandomVectorGenerator;
037 import org.junit.Test;
038
039 public class MultiStartDifferentiableMultivariateRealOptimizerTest {
040
041 @Test
042 public void testCircleFitting() throws FunctionEvaluationException, OptimizationException {
043 Circle circle = new Circle();
044 circle.addPoint( 30.0, 68.0);
045 circle.addPoint( 50.0, -6.0);
046 circle.addPoint(110.0, -20.0);
047 circle.addPoint( 35.0, 15.0);
048 circle.addPoint( 45.0, 97.0);
049 NonLinearConjugateGradientOptimizer underlying =
050 new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.POLAK_RIBIERE);
051 JDKRandomGenerator g = new JDKRandomGenerator();
052 g.setSeed(753289573253l);
053 RandomVectorGenerator generator =
054 new UncorrelatedRandomVectorGenerator(new double[] { 50.0, 50.0 }, new double[] { 10.0, 10.0 },
055 new GaussianRandomGenerator(g));
056 MultiStartDifferentiableMultivariateRealOptimizer optimizer =
057 new MultiStartDifferentiableMultivariateRealOptimizer(underlying, 10, generator);
058 optimizer.setMaxIterations(100);
059 assertEquals(100, optimizer.getMaxIterations());
060 optimizer.setMaxEvaluations(100);
061 assertEquals(100, optimizer.getMaxEvaluations());
062 optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-10, 1.0e-10));
063 BrentSolver solver = new BrentSolver();
064 solver.setAbsoluteAccuracy(1.0e-13);
065 solver.setRelativeAccuracy(1.0e-15);
066 RealPointValuePair optimum =
067 optimizer.optimize(circle, GoalType.MINIMIZE, new double[] { 98.680, 47.345 });
068 RealPointValuePair[] optima = optimizer.getOptima();
069 for (RealPointValuePair o : optima) {
070 Point2D.Double center = new Point2D.Double(o.getPointRef()[0], o.getPointRef()[1]);
071 assertEquals(69.960161753, circle.getRadius(center), 1.0e-8);
072 assertEquals(96.075902096, center.x, 1.0e-8);
073 assertEquals(48.135167894, center.y, 1.0e-8);
074 }
075 assertTrue(optimizer.getGradientEvaluations() > 650);
076 assertTrue(optimizer.getGradientEvaluations() < 700);
077 assertTrue(optimizer.getEvaluations() > 70);
078 assertTrue(optimizer.getEvaluations() < 90);
079 assertTrue(optimizer.getIterations() > 70);
080 assertTrue(optimizer.getIterations() < 90);
081 assertEquals(3.1267527, optimum.getValue(), 1.0e-8);
082 }
083
084 private static class Circle implements DifferentiableMultivariateRealFunction {
085
086 private ArrayList<Point2D.Double> points;
087
088 public Circle() {
089 points = new ArrayList<Point2D.Double>();
090 }
091
092 public void addPoint(double px, double py) {
093 points.add(new Point2D.Double(px, py));
094 }
095
096 public double getRadius(Point2D.Double center) {
097 double r = 0;
098 for (Point2D.Double point : points) {
099 r += point.distance(center);
100 }
101 return r / points.size();
102 }
103
104 private double[] gradient(double[] point) {
105
106 // optimal radius
107 Point2D.Double center = new Point2D.Double(point[0], point[1]);
108 double radius = getRadius(center);
109
110 // gradient of the sum of squared residuals
111 double dJdX = 0;
112 double dJdY = 0;
113 for (Point2D.Double pk : points) {
114 double dk = pk.distance(center);
115 dJdX += (center.x - pk.x) * (dk - radius) / dk;
116 dJdY += (center.y - pk.y) * (dk - radius) / dk;
117 }
118 dJdX *= 2;
119 dJdY *= 2;
120
121 return new double[] { dJdX, dJdY };
122
123 }
124
125 public double value(double[] variables)
126 throws IllegalArgumentException, FunctionEvaluationException {
127
128 Point2D.Double center = new Point2D.Double(variables[0], variables[1]);
129 double radius = getRadius(center);
130
131 double sum = 0;
132 for (Point2D.Double point : points) {
133 double di = point.distance(center) - radius;
134 sum += di * di;
135 }
136
137 return sum;
138
139 }
140
141 public MultivariateVectorialFunction gradient() {
142 return new MultivariateVectorialFunction() {
143 public double[] value(double[] point) {
144 return gradient(point);
145 }
146 };
147 }
148
149 public MultivariateRealFunction partialDerivative(final int k) {
150 return new MultivariateRealFunction() {
151 public double value(double[] point) {
152 return gradient(point)[k];
153 }
154 };
155 }
156
157 }
158
159 }