5#ifndef LBFGSPP_LINE_SEARCH_NOCEDAL_WRIGHT_H
6#define LBFGSPP_LINE_SEARCH_NOCEDAL_WRIGHT_H
20template <
typename Scalar>
24 using Vector = Eigen::Matrix<Scalar, Eigen::Dynamic, 1>;
29 static Scalar quad_interp(
const Scalar& step_lo,
const Scalar& step_hi,
30 const Scalar& fx_lo,
const Scalar& fx_hi,
const Scalar& dg_lo)
41 const Scalar fdiff = fx_hi - fx_lo;
42 const Scalar sdiff = step_hi - step_lo;
43 const Scalar smid = (step_hi + step_lo) / Scalar(2);
44 Scalar step_candid = fdiff * step_lo - smid * sdiff * dg_lo;
45 step_candid = step_candid / (fdiff - sdiff * dg_lo);
50 const bool candid_nan = !(std::isfinite(step_candid));
51 const Scalar end_dist = std::min(abs(step_candid - step_lo), abs(step_candid - step_hi));
52 const bool near_end = end_dist < Scalar(0.01) * abs(sdiff);
53 const bool bisect = candid_nan ||
54 (step_candid <= std::min(step_lo, step_hi)) ||
55 (step_candid >= std::max(step_lo, step_hi)) ||
57 const Scalar step = bisect ? smid : step_candid;
83 template <
typename Foo>
85 const Vector& xp,
const Vector& drt,
const Scalar& step_max,
86 Scalar& step, Scalar& fx, Vector& grad, Scalar& dg, Vector& x)
89 if (step <= Scalar(0))
90 throw std::invalid_argument(
"'step' must be positive");
93 throw std::invalid_argument(
"'param.linesearch' must be 'LBFGS_LINESEARCH_BACKTRACKING_STRONG_WOLFE' for LineSearchNocedalWright");
106 const Scalar expansion = Scalar(2);
109 const Scalar fx_init = fx;
111 const Scalar dg_init = dg;
113 if (dg_init > Scalar(0))
114 throw std::logic_error(
"the moving direction increases the objective function value");
116 const Scalar test_decr = param.
ftol * dg_init,
117 test_curv = -param.
wolfe * dg_init;
121 Scalar step_hi, fx_hi;
122 Scalar step_lo = Scalar(0), fx_lo = fx_init, dg_lo = dg_init;
125 Vector x_lo = xp, grad_lo = grad;
143 x.noalias() = xp + step * drt;
148 if (fx - fx_init > step * test_decr || (Scalar(0) < step_lo && fx >= fx_lo))
159 if (std::abs(dg) <= test_curv)
213 step = quad_interp(step_lo, step_hi, fx_lo, fx_hi, dg_lo);
216 x.noalias() = xp + step * drt;
221 if (fx - fx_init > step * test_decr || fx >= fx_lo)
224 throw std::runtime_error(
"the line search routine failed, possibly due to insufficient numeric precision");
233 if (std::abs(dg) <= test_curv)
236 if (dg * (step_hi - step_lo) >= Scalar(0))
244 throw std::runtime_error(
"the line search routine failed, possibly due to insufficient numeric precision");
263 if (step_lo <= Scalar(0))
264 throw std::runtime_error(
"the line search routine failed, unable to sufficiently decrease the function value");
static void LineSearch(Foo &f, const LBFGSParam< Scalar > ¶m, const Vector &xp, const Vector &drt, const Scalar &step_max, Scalar &step, Scalar &fx, Vector &grad, Scalar &dg, Vector &x)
@ LBFGS_LINESEARCH_BACKTRACKING_STRONG_WOLFE