LBFGS++
LineSearchBacktracking.h
1// Copyright (C) 2016-2023 Yixuan Qiu <yixuan.qiu@cos.name>
2// Under MIT license
3
4#ifndef LBFGSPP_LINE_SEARCH_BACKTRACKING_H
5#define LBFGSPP_LINE_SEARCH_BACKTRACKING_H
6
7#include <Eigen/Core>
8#include <stdexcept> // std::runtime_error
9
10namespace LBFGSpp {
11
15template <typename Scalar>
17{
18private:
19 using Vector = Eigen::Matrix<Scalar, Eigen::Dynamic, 1>;
20
21public:
43 template <typename Foo>
44 static void LineSearch(Foo& f, const LBFGSParam<Scalar>& param,
45 const Vector& xp, const Vector& drt, const Scalar& step_max,
46 Scalar& step, Scalar& fx, Vector& grad, Scalar& dg, Vector& x)
47 {
48 // Decreasing and increasing factors
49 const Scalar dec = 0.5;
50 const Scalar inc = 2.1;
51
52 // Check the value of step
53 if (step <= Scalar(0))
54 throw std::invalid_argument("'step' must be positive");
55
56 // Save the function value at the current x
57 const Scalar fx_init = fx;
58 // Projection of gradient on the search direction
59 const Scalar dg_init = grad.dot(drt);
60 // Make sure d points to a descent direction
61 if (dg_init > 0)
62 throw std::logic_error("the moving direction increases the objective function value");
63
64 const Scalar test_decr = param.ftol * dg_init;
65 Scalar width;
66
67 int iter;
68 for (iter = 0; iter < param.max_linesearch; iter++)
69 {
70 // x_{k+1} = x_k + step * d_k
71 x.noalias() = xp + step * drt;
72 // Evaluate this candidate
73 fx = f(x, grad);
74
75 if (fx > fx_init + step * test_decr || (fx != fx))
76 {
77 width = dec;
78 }
79 else
80 {
81 dg = grad.dot(drt);
82
83 // Armijo condition is met
85 break;
86
87 if (dg < param.wolfe * dg_init)
88 {
89 width = inc;
90 }
91 else
92 {
93 // Regular Wolfe condition is met
95 break;
96
97 if (dg > -param.wolfe * dg_init)
98 {
99 width = dec;
100 }
101 else
102 {
103 // Strong Wolfe condition is met
104 break;
105 }
106 }
107 }
108
109 if (step < param.min_step)
110 throw std::runtime_error("the line search step became smaller than the minimum value allowed");
111
112 if (step > param.max_step)
113 throw std::runtime_error("the line search step became larger than the maximum value allowed");
114
115 step *= width;
116 }
117
118 if (iter >= param.max_linesearch)
119 throw std::runtime_error("the line search routine reached the maximum number of iterations");
120 }
121};
122
123} // namespace LBFGSpp
124
125#endif // LBFGSPP_LINE_SEARCH_BACKTRACKING_H
Scalar max_step
Definition: Param.h:147
Scalar min_step
Definition: Param.h:141
static void LineSearch(Foo &f, const LBFGSParam< Scalar > &param, const Vector &xp, const Vector &drt, const Scalar &step_max, Scalar &step, Scalar &fx, Vector &grad, Scalar &dg, Vector &x)
@ LBFGS_LINESEARCH_BACKTRACKING_ARMIJO
Definition: Param.h:35
@ LBFGS_LINESEARCH_BACKTRACKING_WOLFE
Definition: Param.h:51