1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *      http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.commons.math.optimization.general;
19
20import org.apache.commons.math.ConvergenceException;
21import org.apache.commons.math.FunctionEvaluationException;
22import org.apache.commons.math.analysis.UnivariateRealFunction;
23import org.apache.commons.math.analysis.solvers.BrentSolver;
24import org.apache.commons.math.analysis.solvers.UnivariateRealSolver;
25import org.apache.commons.math.exception.util.LocalizedFormats;
26import org.apache.commons.math.optimization.GoalType;
27import org.apache.commons.math.optimization.OptimizationException;
28import org.apache.commons.math.optimization.RealPointValuePair;
29import org.apache.commons.math.util.FastMath;
30
31/**
32 * Non-linear conjugate gradient optimizer.
33 * <p>
34 * This class supports both the Fletcher-Reeves and the Polak-Ribi&egrave;re
35 * update formulas for the conjugate search directions. It also supports
36 * optional preconditioning.
37 * </p>
38 *
39 * @version $Revision: 1070725 $ $Date: 2011-02-15 02:31:12 +0100 (mar. 15 févr. 2011) $
40 * @since 2.0
41 *
42 */
43
44public class NonLinearConjugateGradientOptimizer
45    extends AbstractScalarDifferentiableOptimizer {
46
47    /** Update formula for the beta parameter. */
48    private final ConjugateGradientFormula updateFormula;
49
50    /** Preconditioner (may be null). */
51    private Preconditioner preconditioner;
52
53    /** solver to use in the line search (may be null). */
54    private UnivariateRealSolver solver;
55
56    /** Initial step used to bracket the optimum in line search. */
57    private double initialStep;
58
59    /** Simple constructor with default settings.
60     * <p>The convergence check is set to a {@link
61     * org.apache.commons.math.optimization.SimpleVectorialValueChecker}
62     * and the maximal number of iterations is set to
63     * {@link AbstractScalarDifferentiableOptimizer#DEFAULT_MAX_ITERATIONS}.
64     * @param updateFormula formula to use for updating the &beta; parameter,
65     * must be one of {@link ConjugateGradientFormula#FLETCHER_REEVES} or {@link
66     * ConjugateGradientFormula#POLAK_RIBIERE}
67     */
68    public NonLinearConjugateGradientOptimizer(final ConjugateGradientFormula updateFormula) {
69        this.updateFormula = updateFormula;
70        preconditioner     = null;
71        solver             = null;
72        initialStep        = 1.0;
73    }
74
75    /**
76     * Set the preconditioner.
77     * @param preconditioner preconditioner to use for next optimization,
78     * may be null to remove an already registered preconditioner
79     */
80    public void setPreconditioner(final Preconditioner preconditioner) {
81        this.preconditioner = preconditioner;
82    }
83
84    /**
85     * Set the solver to use during line search.
86     * @param lineSearchSolver solver to use during line search, may be null
87     * to remove an already registered solver and fall back to the
88     * default {@link BrentSolver Brent solver}.
89     */
90    public void setLineSearchSolver(final UnivariateRealSolver lineSearchSolver) {
91        this.solver = lineSearchSolver;
92    }
93
94    /**
95     * Set the initial step used to bracket the optimum in line search.
96     * <p>
97     * The initial step is a factor with respect to the search direction,
98     * which itself is roughly related to the gradient of the function
99     * </p>
100     * @param initialStep initial step used to bracket the optimum in line search,
101     * if a non-positive value is used, the initial step is reset to its
102     * default value of 1.0
103     */
104    public void setInitialStep(final double initialStep) {
105        if (initialStep <= 0) {
106            this.initialStep = 1.0;
107        } else {
108            this.initialStep = initialStep;
109        }
110    }
111
112    /** {@inheritDoc} */
113    @Override
114    protected RealPointValuePair doOptimize()
115        throws FunctionEvaluationException, OptimizationException, IllegalArgumentException {
116        try {
117
118            // initialization
119            if (preconditioner == null) {
120                preconditioner = new IdentityPreconditioner();
121            }
122            if (solver == null) {
123                solver = new BrentSolver();
124            }
125            final int n = point.length;
126            double[] r = computeObjectiveGradient(point);
127            if (goal == GoalType.MINIMIZE) {
128                for (int i = 0; i < n; ++i) {
129                    r[i] = -r[i];
130                }
131            }
132
133            // initial search direction
134            double[] steepestDescent = preconditioner.precondition(point, r);
135            double[] searchDirection = steepestDescent.clone();
136
137            double delta = 0;
138            for (int i = 0; i < n; ++i) {
139                delta += r[i] * searchDirection[i];
140            }
141
142            RealPointValuePair current = null;
143            while (true) {
144
145                final double objective = computeObjectiveValue(point);
146                RealPointValuePair previous = current;
147                current = new RealPointValuePair(point, objective);
148                if (previous != null) {
149                    if (checker.converged(getIterations(), previous, current)) {
150                        // we have found an optimum
151                        return current;
152                    }
153                }
154
155                incrementIterationsCounter();
156
157                double dTd = 0;
158                for (final double di : searchDirection) {
159                    dTd += di * di;
160                }
161
162                // find the optimal step in the search direction
163                final UnivariateRealFunction lsf = new LineSearchFunction(searchDirection);
164                final double step = solver.solve(lsf, 0, findUpperBound(lsf, 0, initialStep));
165
166                // validate new point
167                for (int i = 0; i < point.length; ++i) {
168                    point[i] += step * searchDirection[i];
169                }
170                r = computeObjectiveGradient(point);
171                if (goal == GoalType.MINIMIZE) {
172                    for (int i = 0; i < n; ++i) {
173                        r[i] = -r[i];
174                    }
175                }
176
177                // compute beta
178                final double deltaOld = delta;
179                final double[] newSteepestDescent = preconditioner.precondition(point, r);
180                delta = 0;
181                for (int i = 0; i < n; ++i) {
182                    delta += r[i] * newSteepestDescent[i];
183                }
184
185                final double beta;
186                if (updateFormula == ConjugateGradientFormula.FLETCHER_REEVES) {
187                    beta = delta / deltaOld;
188                } else {
189                    double deltaMid = 0;
190                    for (int i = 0; i < r.length; ++i) {
191                        deltaMid += r[i] * steepestDescent[i];
192                    }
193                    beta = (delta - deltaMid) / deltaOld;
194                }
195                steepestDescent = newSteepestDescent;
196
197                // compute conjugate search direction
198                if ((getIterations() % n == 0) || (beta < 0)) {
199                    // break conjugation: reset search direction
200                    searchDirection = steepestDescent.clone();
201                } else {
202                    // compute new conjugate search direction
203                    for (int i = 0; i < n; ++i) {
204                        searchDirection[i] = steepestDescent[i] + beta * searchDirection[i];
205                    }
206                }
207
208            }
209
210        } catch (ConvergenceException ce) {
211            throw new OptimizationException(ce);
212        }
213    }
214
215    /**
216     * Find the upper bound b ensuring bracketing of a root between a and b
217     * @param f function whose root must be bracketed
218     * @param a lower bound of the interval
219     * @param h initial step to try
220     * @return b such that f(a) and f(b) have opposite signs
221     * @exception FunctionEvaluationException if the function cannot be computed
222     * @exception OptimizationException if no bracket can be found
223     */
224    private double findUpperBound(final UnivariateRealFunction f,
225                                  final double a, final double h)
226        throws FunctionEvaluationException, OptimizationException {
227        final double yA = f.value(a);
228        double yB = yA;
229        for (double step = h; step < Double.MAX_VALUE; step *= FastMath.max(2, yA / yB)) {
230            final double b = a + step;
231            yB = f.value(b);
232            if (yA * yB <= 0) {
233                return b;
234            }
235        }
236        throw new OptimizationException(LocalizedFormats.UNABLE_TO_BRACKET_OPTIMUM_IN_LINE_SEARCH);
237    }
238
239    /** Default identity preconditioner. */
240    private static class IdentityPreconditioner implements Preconditioner {
241
242        /** {@inheritDoc} */
243        public double[] precondition(double[] variables, double[] r) {
244            return r.clone();
245        }
246
247    }
248
249    /** Internal class for line search.
250     * <p>
251     * The function represented by this class is the dot product of
252     * the objective function gradient and the search direction. Its
253     * value is zero when the gradient is orthogonal to the search
254     * direction, i.e. when the objective function value is a local
255     * extremum along the search direction.
256     * </p>
257     */
258    private class LineSearchFunction implements UnivariateRealFunction {
259        /** Search direction. */
260        private final double[] searchDirection;
261
262        /** Simple constructor.
263         * @param searchDirection search direction
264         */
265        public LineSearchFunction(final double[] searchDirection) {
266            this.searchDirection = searchDirection;
267        }
268
269        /** {@inheritDoc} */
270        public double value(double x) throws FunctionEvaluationException {
271
272            // current point in the search direction
273            final double[] shiftedPoint = point.clone();
274            for (int i = 0; i < shiftedPoint.length; ++i) {
275                shiftedPoint[i] += x * searchDirection[i];
276            }
277
278            // gradient of the objective function
279            final double[] gradient;
280            gradient = computeObjectiveGradient(shiftedPoint);
281
282            // dot product with the search direction
283            double dotProduct = 0;
284            for (int i = 0; i < gradient.length; ++i) {
285                dotProduct += gradient[i] * searchDirection[i];
286            }
287
288            return dotProduct;
289
290        }
291
292    }
293
294}
295