4#ifndef LBFGSPP_LINE_SEARCH_MORE_THUENTE_H
5#define LBFGSPP_LINE_SEARCH_MORE_THUENTE_H
9#include "LBFGSpp/Param.h"
// Template head of the enclosing Moré–Thuente line-search class. Scalar is the
// numeric type of steps, function values and derivatives below. It is used with
// std::numeric_limits<Scalar>::epsilon()/infinity() and std::isfinite, so it is
// presumably a floating-point type.
// NOTE(review): the class declaration line itself is missing from this extract.
25template <
typename Scalar>
// Dense column vector of Scalar whose length is set at run time
// (Eigen::Dynamic rows, 1 column). Used below for xp, drt, grad and x.
29 using Vector = Eigen::Matrix<Scalar, Eigen::Dynamic, 1>;
// Stationary point of the quadratic q(x) that interpolates q(a) = fa,
// q'(a) = ga and q(b) = fb.
//
// Writing q(x) = fa + ga*(x - a) + c*(x - a)^2 with
// c = (fb - fa - ga*(b - a)) / (b - a)^2, the stationary point is
// a - ga / (2c) = a + w*(b - a), where w is the value computed below.
// NOTE(review): this is a minimizer only when c > 0, and it requires
// fa - fb + (b - a)*ga != 0. No guard is visible here.
// NOTE(review): the function's return statement is not part of this extract.
34 static Scalar quadratic_minimizer(
const Scalar& a,
const Scalar& b,
const Scalar& fa,
const Scalar& ga,
const Scalar& fb)
36 const Scalar ba = b - a;
// Fraction of (b - a), measured from a, at which the interpolant's
// stationary point lies.
37 const Scalar w = Scalar(0.5) * ba * ga / (fa - fb + ba * ga);
// Stationary point of the quadratic whose derivative is the straight line
// through (a, ga) and (b, gb). This is a secant step on the derivative, so no
// function values are needed.
//
// The derivative line vanishes at a + w*(b - a) with w = ga / (ga - gb). That
// point is a minimizer only when the derivative line is increasing, i.e.
// (gb - ga) / (b - a) > 0.
// NOTE(review): no guard against ga == gb (division by zero) is visible here.
46 static Scalar quadratic_minimizer(
const Scalar& a,
const Scalar& b,
const Scalar& ga,
const Scalar& gb)
48 const Scalar w = ga / (ga - gb);
49 return a + w * (b - a);
// Local minimizer of the cubic p(x) = c0 + c1*x + c2*x^2 + c3*x^3 that
// interpolates p(a) = fa, p'(a) = ga, p(b) = fb and p'(b) = gb.
// `exists` reports whether that local minimizer exists.
//
// NOTE(review): the body in this extract is incomplete. Missing pieces:
//   - the early return after `exists` is set in the cubic branch;
//   - the r1/r2 assignments;
//   - the `else` of the |u| >= |v| test.
// `abs` and `sqrt` are called unqualified; presumably `using std::abs;` and
// `using std::sqrt;` are among the missing lines — confirm.
55 static Scalar cubic_minimizer(
const Scalar& a,
const Scalar& b,
const Scalar& fa,
const Scalar& fb,
56 const Scalar& ga,
const Scalar& gb,
bool& exists)
61 const Scalar apb = a + b;
62 const Scalar ba = b - a;
63 const Scalar ba2 = ba * ba;
64 const Scalar fba = fb - fa;
65 const Scalar gba = gb - ga;
// z3, z2 and z1 equal c3, c2 and c1 multiplied by (b - a)^3. With this scaling,
// no division by powers of (b - a) is needed.
67 const Scalar z3 = (ga + gb) * ba - Scalar(2) * fba;
68 const Scalar z2 = Scalar(0.5) * (gba * ba2 - Scalar(3) * apb * z3);
69 const Scalar z1 = fba * ba2 - apb * z2 - (a * apb + b * b) * z3;
// When c3 is negligible relative to c2 or c1, treat p as a quadratic and solve
// c1 + 2*c2*x = 0. Its minimizer -c1/(2*c2) = -0.5*z1/z2 exists iff c2 > 0,
// which is equivalent to z2*(b - a) > 0. Otherwise the end point b is returned.
73 const Scalar eps = std::numeric_limits<Scalar>::epsilon();
74 if (abs(z3) < eps * abs(z2) || abs(z3) < eps * abs(z1))
77 exists = (z2 * ba > Scalar(0));
79 return exists ? (-Scalar(0.5) * z1 / z2) : b;
// Critical points solve 3*c3*x^2 + 2*c2*x + c1 = 0, i.e.
// x = -u +/- sqrt(u^2 - u*v), with u = c2/(3*c3) and v = c1/c2.
// They are real (so a local minimizer exists) iff v/u <= 1.
88 const Scalar u = z2 / (Scalar(3) * z3), v = z1 / z2;
89 const Scalar vu = v / u;
90 exists = (vu <= Scalar(1));
// Numerically stable root formulas:
//   - If |u| >= |v|, let w = 1 + sqrt(1 - v/u). The roots are -u*w and -v/w
//     (their product is u*v).
//   - Otherwise (given v/u <= 1), u*v < 0 and the roots are
//     -u +/- sqrt(|u|)*sqrt(|v|)*sqrt(1 - u/v).
102 Scalar r1 = Scalar(0), r2 = Scalar(0);
103 if (abs(u) >= abs(v))
105 const Scalar w = Scalar(1) + sqrt(Scalar(1) - vu);
111 const Scalar sqrtd = sqrt(abs(u)) * sqrt(abs(v)) * sqrt(1 - u / v);
// If c3 > 0 (equivalently z3*(b - a) > 0), the local minimizer is the larger
// critical point; otherwise it is the smaller one.
115 return (z3 * ba > Scalar(0)) ? ((std::max)(r1, r2)) : ((std::min)(r1, r2));
// Chooses the next trial step for the Moré–Thuente search.
//
// Inputs:
//   - interval end points al (value fl, derivative gl) and au (value fu,
//     derivative gu);
//   - the current trial step at (value ft, derivative gt).
// It distinguishes the cases visible below.
//
// NOTE(review): several lines are missing from this extract:
//   - the declarations of ac_exists / ae_exists;
//   - the branch conditions of cases 1 and 2;
//   - the `ac : as` arms of the case-3 ternary;
//   - the `return (...) ?` heads of the bounded returns in cases 3 and 4.
120 static Scalar step_selection(
121 const Scalar& al,
const Scalar& au,
const Scalar& at,
122 const Scalar& fl,
const Scalar& fu,
const Scalar& ft,
123 const Scalar& gl,
const Scalar& gu,
const Scalar& gt)
// A non-finite trial value or derivative cannot be interpolated, so bisect
// between al and at instead.
131 if (!std::isfinite(ft) || !std::isfinite(gt))
132 return (al + at) / Scalar(2);
// Case 1 (condition not in this extract; presumably ft > fl).
// ac is the cubic minimizer through (al, fl, gl) and (at, ft, gt).
// aq is the quadratic minimizer matching fl, gl and ft.
// Return ac if it is closer to al than aq is; otherwise return the midpoint of
// aq and ac.
138 const Scalar ac = cubic_minimizer(al, at, fl, ft, gl, gt, ac_exists);
139 const Scalar aq = quadratic_minimizer(al, at, fl, gl, ft);
148 return (abs(ac - al) < abs(aq - al)) ? ac : ((aq + ac) / Scalar(2));
// as is the secant step: minimizer of the quadratic matching gl and gt.
152 const Scalar as = quadratic_minimizer(al, at, gl, gt);
// Case 2: the derivatives at al and at have opposite signs. Return whichever
// of ac and as is farther from at.
154 if (gt * gl < Scalar(0))
155 return (abs(ac - at) >= abs(as - at)) ? ac : as;
// deltal: extrapolation factor for case 4's fallback when au is not finite.
// deltau: bounds how far toward au the chosen step may go.
158 const Scalar deltal = Scalar(1.1), deltau = Scalar(0.66);
// Case 3 (|gt| < |gl|): res is chosen under the condition below. The condition
// requires ac to exist, lie beyond at on the side away from al, and be closer
// to at than as. Presumably res = ac when it holds and as otherwise.
// res is then bounded by at + deltau*(au - at), via min or max depending on
// the direction (the selecting condition is missing here).
159 if (abs(gt) < abs(gl))
166 const Scalar res = (ac_exists &&
167 (ac - at) * (at - al) > Scalar(0) &&
168 abs(ac - at) < abs(as - at)) ?
173 std::min(at + deltau * (au - at), res) :
174 std::max(at + deltau * (au - at), res);
// Case 4 (otherwise): if au, fu or gu is not finite, extrapolate beyond at by
// deltal*(at - al).
178 if ((!std::isfinite(au)) || (!std::isfinite(fu)) || (!std::isfinite(gu)))
179 return at + deltal * (at - al);
// Otherwise use the cubic minimizer ae through (at, ft, gt) and (au, fu, gu),
// bounded by at + deltau*(au - at).
183 const Scalar ae = cubic_minimizer(at, au, ft, fu, gt, gu, ae_exists);
187 std::min(at + deltau * (au - at), ae) :
188 std::max(at + deltau * (au - at), ae);
213 template <
typename Foo,
typename SolverParam>
215 const Vector& xp,
const Vector& drt,
const Scalar& step_max,
216 Scalar& step, Scalar& fx, Vector& grad, Scalar& dg, Vector& x)
222 if (step <= Scalar(0))
223 throw std::invalid_argument(
"'step' must be positive");
225 throw std::invalid_argument(
"'step' exceeds 'step_max'");
228 const Scalar fx_init = fx;
230 const Scalar dg_init = dg;
235 if (dg_init >= Scalar(0))
236 throw std::logic_error(
"the moving direction does not decrease the objective function value");
240 const Scalar test_decr = param.ftol * dg_init;
242 const Scalar test_curv = -param.wolfe * dg_init;
245 Scalar I_lo = Scalar(0), I_hi = std::numeric_limits<Scalar>::infinity();
246 Scalar fI_lo = Scalar(0), fI_hi = std::numeric_limits<Scalar>::infinity();
247 Scalar gI_lo = (Scalar(1) - param.ftol) * dg_init, gI_hi = std::numeric_limits<Scalar>::infinity();
250 Vector x_lo = xp, grad_lo = grad;
251 Scalar fx_lo = fx_init, dg_lo = dg_init;
254 x.noalias() = xp + step * drt;
261 if (fx <= fx_init + step * test_decr && abs(dg) <= test_curv)
269 const Scalar delta = Scalar(1.1);
271 for (iter = 0; iter < param.max_linesearch; iter++)
276 const Scalar ft = fx - fx_init - step * test_decr;
277 const Scalar gt = dg - param.ftol * dg_init;
284 new_step = step_selection(I_lo, I_hi, step, fI_lo, fI_hi, ft, gI_lo, gI_hi, gt);
287 if (new_step <= param.min_step)
288 new_step = (I_lo + step) / Scalar(2);
296 else if (gt * (I_lo - step) > Scalar(0))
302 new_step = std::min(step_max, step + delta * (step - I_lo));
329 new_step = step_selection(I_lo, I_hi, step, fI_lo, fI_hi, ft, gI_lo, gI_hi, gt);
356 if (step == step_max && new_step >= step_max)
369 if (step < param.min_step)
370 throw std::runtime_error(
"the line search step became smaller than the minimum value allowed");
372 if (step > param.max_step)
373 throw std::runtime_error(
"the line search step became larger than the maximum value allowed");
376 x.noalias() = xp + step * drt;
383 if (fx <= fx_init + step * test_decr && abs(dg) <= test_curv)
401 if (step >= step_max)
403 const Scalar ft_bound = fx - fx_init - step * test_decr;
404 if (ft_bound <= fI_lo)
416 if (iter >= param.max_linesearch)
422 const Scalar ft = fx - fx_init - step * test_decr;
427 if (I_lo <= Scalar(0))
428 throw std::runtime_error(
"the line search routine is unable to sufficiently decrease the function value");
static void LineSearch(Foo &f, const SolverParam ¶m, const Vector &xp, const Vector &drt, const Scalar &step_max, Scalar &step, Scalar &fx, Vector &grad, Scalar &dg, Vector &x)