1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *      http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.commons.math.optimization.direct;
19
20import java.util.Comparator;
21
22import org.apache.commons.math.FunctionEvaluationException;
23import org.apache.commons.math.optimization.OptimizationException;
24import org.apache.commons.math.optimization.RealPointValuePair;
25
26/**
27 * This class implements the Nelder-Mead direct search method.
28 *
29 * @version $Revision: 1070725 $ $Date: 2011-02-15 02:31:12 +0100 (mar. 15 févr. 2011) $
30 * @see MultiDirectional
31 * @since 1.2
32 */
33public class NelderMead extends DirectSearchOptimizer {
34
35    /** Reflection coefficient. */
36    private final double rho;
37
38    /** Expansion coefficient. */
39    private final double khi;
40
41    /** Contraction coefficient. */
42    private final double gamma;
43
44    /** Shrinkage coefficient. */
45    private final double sigma;
46
47    /** Build a Nelder-Mead optimizer with default coefficients.
48     * <p>The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
49     * for both gamma and sigma.</p>
50     */
51    public NelderMead() {
52        this.rho   = 1.0;
53        this.khi   = 2.0;
54        this.gamma = 0.5;
55        this.sigma = 0.5;
56    }
57
58    /** Build a Nelder-Mead optimizer with specified coefficients.
59     * @param rho reflection coefficient
60     * @param khi expansion coefficient
61     * @param gamma contraction coefficient
62     * @param sigma shrinkage coefficient
63     */
64    public NelderMead(final double rho, final double khi,
65                      final double gamma, final double sigma) {
66        this.rho   = rho;
67        this.khi   = khi;
68        this.gamma = gamma;
69        this.sigma = sigma;
70    }
71
72    /** {@inheritDoc} */
73    @Override
74    protected void iterateSimplex(final Comparator<RealPointValuePair> comparator)
75        throws FunctionEvaluationException, OptimizationException {
76
77        incrementIterationsCounter();
78
79        // the simplex has n+1 point if dimension is n
80        final int n = simplex.length - 1;
81
82        // interesting values
83        final RealPointValuePair best       = simplex[0];
84        final RealPointValuePair secondBest = simplex[n-1];
85        final RealPointValuePair worst      = simplex[n];
86        final double[] xWorst = worst.getPointRef();
87
88        // compute the centroid of the best vertices
89        // (dismissing the worst point at index n)
90        final double[] centroid = new double[n];
91        for (int i = 0; i < n; ++i) {
92            final double[] x = simplex[i].getPointRef();
93            for (int j = 0; j < n; ++j) {
94                centroid[j] += x[j];
95            }
96        }
97        final double scaling = 1.0 / n;
98        for (int j = 0; j < n; ++j) {
99            centroid[j] *= scaling;
100        }
101
102        // compute the reflection point
103        final double[] xR = new double[n];
104        for (int j = 0; j < n; ++j) {
105            xR[j] = centroid[j] + rho * (centroid[j] - xWorst[j]);
106        }
107        final RealPointValuePair reflected = new RealPointValuePair(xR, evaluate(xR), false);
108
109        if ((comparator.compare(best, reflected) <= 0) &&
110            (comparator.compare(reflected, secondBest) < 0)) {
111
112            // accept the reflected point
113            replaceWorstPoint(reflected, comparator);
114
115        } else if (comparator.compare(reflected, best) < 0) {
116
117            // compute the expansion point
118            final double[] xE = new double[n];
119            for (int j = 0; j < n; ++j) {
120                xE[j] = centroid[j] + khi * (xR[j] - centroid[j]);
121            }
122            final RealPointValuePair expanded = new RealPointValuePair(xE, evaluate(xE), false);
123
124            if (comparator.compare(expanded, reflected) < 0) {
125                // accept the expansion point
126                replaceWorstPoint(expanded, comparator);
127            } else {
128                // accept the reflected point
129                replaceWorstPoint(reflected, comparator);
130            }
131
132        } else {
133
134            if (comparator.compare(reflected, worst) < 0) {
135
136                // perform an outside contraction
137                final double[] xC = new double[n];
138                for (int j = 0; j < n; ++j) {
139                    xC[j] = centroid[j] + gamma * (xR[j] - centroid[j]);
140                }
141                final RealPointValuePair outContracted = new RealPointValuePair(xC, evaluate(xC), false);
142
143                if (comparator.compare(outContracted, reflected) <= 0) {
144                    // accept the contraction point
145                    replaceWorstPoint(outContracted, comparator);
146                    return;
147                }
148
149            } else {
150
151                // perform an inside contraction
152                final double[] xC = new double[n];
153                for (int j = 0; j < n; ++j) {
154                    xC[j] = centroid[j] - gamma * (centroid[j] - xWorst[j]);
155                }
156                final RealPointValuePair inContracted = new RealPointValuePair(xC, evaluate(xC), false);
157
158                if (comparator.compare(inContracted, worst) < 0) {
159                    // accept the contraction point
160                    replaceWorstPoint(inContracted, comparator);
161                    return;
162                }
163
164            }
165
166            // perform a shrink
167            final double[] xSmallest = simplex[0].getPointRef();
168            for (int i = 1; i < simplex.length; ++i) {
169                final double[] x = simplex[i].getPoint();
170                for (int j = 0; j < n; ++j) {
171                    x[j] = xSmallest[j] + sigma * (x[j] - xSmallest[j]);
172                }
173                simplex[i] = new RealPointValuePair(x, Double.NaN, false);
174            }
175            evaluateSimplex(comparator);
176
177        }
178
179    }
180
181}
182