// LBFGS++: BKLDLT.h
// Copyright (C) 2020-2023 Yixuan Qiu <yixuan.qiu@cos.name>
// Under MIT license

#ifndef LBFGSPP_BK_LDLT_H
#define LBFGSPP_BK_LDLT_H

#include <algorithm>  // std::copy, std::swap_ranges
#include <cmath>      // std::sqrt, std::abs
#include <stdexcept>
#include <utility>    // std::swap, std::pair, std::make_pair
#include <vector>
#include <Eigen/Core>
10
12
13namespace LBFGSpp {
14
// Status codes stored by BKLDLT and returned from BKLDLT::info()
enum COMPUTATION_INFO
{
    SUCCESSFUL = 0,      // no zero pivot block was encountered
    NOT_COMPUTED = 1,    // initial status assigned by the BKLDLT constructors
    NUMERICAL_ISSUE = 2  // a pivot block was exactly singular during factorization
};
21
22// Bunch-Kaufman LDLT decomposition
23// References:
24// 1. Bunch, J. R., & Kaufman, L. (1977). Some stable methods for calculating inertia and solving symmetric linear systems.
25// Mathematics of computation, 31(137), 163-179.
26// 2. Golub, G. H., & Van Loan, C. F. (2012). Matrix computations (Vol. 3). JHU press. Section 4.4.
27// 3. Bunch-Parlett diagonal pivoting <http://oz.nthu.edu.tw/~d947207/Chap13_GE3.ppt>
28// 4. Ashcraft, C., Grimes, R. G., & Lewis, J. G. (1998). Accurate symmetric indefinite linear equation solvers.
29// SIAM Journal on Matrix Analysis and Applications, 20(2), 513-561.
30template <typename Scalar = double>
31class BKLDLT
32{
33private:
34 using Index = Eigen::Index;
35 using Matrix = Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>;
36 using Vector = Eigen::Matrix<Scalar, Eigen::Dynamic, 1>;
37 using MapVec = Eigen::Map<Vector>;
38 using MapConstVec = Eigen::Map<const Vector>;
39
40 using IntVector = Eigen::Matrix<Index, Eigen::Dynamic, 1>;
41 using GenericVector = Eigen::Ref<Vector>;
42 using GenericMatrix = Eigen::Ref<Matrix>;
43 using ConstGenericMatrix = const Eigen::Ref<const Matrix>;
44 using ConstGenericVector = const Eigen::Ref<const Vector>;
45
46 Index m_n;
47 Vector m_data; // storage for a lower-triangular matrix
48 std::vector<Scalar*> m_colptr; // pointers to columns
49 IntVector m_perm; // [-2, -1, 3, 1, 4, 5]: 0 <-> 2, 1 <-> 1, 2 <-> 3, 3 <-> 1, 4 <-> 4, 5 <-> 5
50 std::vector<std::pair<Index, Index> > m_permc; // compressed version of m_perm: [(0, 2), (2, 3), (3, 1)]
51
52 bool m_computed;
53 int m_info;
54
55 // Access to elements
56 // Pointer to the k-th column
57 Scalar* col_pointer(Index k) { return m_colptr[k]; }
58 // A[i, j] -> m_colptr[j][i - j], i >= j
59 Scalar& coeff(Index i, Index j) { return m_colptr[j][i - j]; }
60 const Scalar& coeff(Index i, Index j) const { return m_colptr[j][i - j]; }
61 // A[i, i] -> m_colptr[i][0]
62 Scalar& diag_coeff(Index i) { return m_colptr[i][0]; }
63 const Scalar& diag_coeff(Index i) const { return m_colptr[i][0]; }
64
65 // Compute column pointers
66 void compute_pointer()
67 {
68 m_colptr.clear();
69 m_colptr.reserve(m_n);
70 Scalar* head = m_data.data();
71
72 for (Index i = 0; i < m_n; i++)
73 {
74 m_colptr.push_back(head);
75 head += (m_n - i);
76 }
77 }
78
79 // Copy mat - shift * I to m_data
80 void copy_data(ConstGenericMatrix& mat, int uplo, const Scalar& shift)
81 {
82 if (uplo == Eigen::Lower)
83 {
84 for (Index j = 0; j < m_n; j++)
85 {
86 const Scalar* begin = &mat.coeffRef(j, j);
87 const Index len = m_n - j;
88 std::copy(begin, begin + len, col_pointer(j));
89 diag_coeff(j) -= shift;
90 }
91 }
92 else
93 {
94 Scalar* dest = m_data.data();
95 for (Index i = 0; i < m_n; i++)
96 {
97 for (Index j = i; j < m_n; j++, dest++)
98 {
99 *dest = mat.coeff(i, j);
100 }
101 diag_coeff(i) -= shift;
102 }
103 }
104 }
105
106 // Compute compressed permutations
107 void compress_permutation()
108 {
109 for (Index i = 0; i < m_n; i++)
110 {
111 // Recover the permutation action
112 const Index perm = (m_perm[i] >= 0) ? (m_perm[i]) : (-m_perm[i] - 1);
113 if (perm != i)
114 m_permc.push_back(std::make_pair(i, perm));
115 }
116 }
117
118 // Working on the A[k:end, k:end] submatrix
119 // Exchange k <-> r
120 // Assume r >= k
121 void pivoting_1x1(Index k, Index r)
122 {
123 // No permutation
124 if (k == r)
125 {
126 m_perm[k] = r;
127 return;
128 }
129
130 // A[k, k] <-> A[r, r]
131 std::swap(diag_coeff(k), diag_coeff(r));
132
133 // A[(r+1):end, k] <-> A[(r+1):end, r]
134 std::swap_ranges(&coeff(r + 1, k), col_pointer(k + 1), &coeff(r + 1, r));
135
136 // A[(k+1):(r-1), k] <-> A[r, (k+1):(r-1)]
137 Scalar* src = &coeff(k + 1, k);
138 for (Index j = k + 1; j < r; j++, src++)
139 {
140 std::swap(*src, coeff(r, j));
141 }
142
143 m_perm[k] = r;
144 }
145
146 // Working on the A[k:end, k:end] submatrix
147 // Exchange [k+1, k] <-> [r, p]
148 // Assume p >= k, r >= k+1
149 void pivoting_2x2(Index k, Index r, Index p)
150 {
151 pivoting_1x1(k, p);
152 pivoting_1x1(k + 1, r);
153
154 // A[k+1, k] <-> A[r, k]
155 std::swap(coeff(k + 1, k), coeff(r, k));
156
157 // Use negative signs to indicate a 2x2 block
158 // Also minus one to distinguish a negative zero from a positive zero
159 m_perm[k] = -m_perm[k] - 1;
160 m_perm[k + 1] = -m_perm[k + 1] - 1;
161 }
162
163 // A[r1, c1:c2] <-> A[r2, c1:c2]
164 // Assume r2 >= r1 > c2 >= c1
165 void interchange_rows(Index r1, Index r2, Index c1, Index c2)
166 {
167 if (r1 == r2)
168 return;
169
170 for (Index j = c1; j <= c2; j++)
171 {
172 std::swap(coeff(r1, j), coeff(r2, j));
173 }
174 }
175
176 // lambda = |A[r, k]| = max{|A[k+1, k]|, ..., |A[end, k]|}
177 // Largest (in magnitude) off-diagonal element in the first column of the current reduced matrix
178 // r is the row index
179 // Assume k < end
180 Scalar find_lambda(Index k, Index& r)
181 {
182 using std::abs;
183
184 const Scalar* head = col_pointer(k); // => A[k, k]
185 const Scalar* end = col_pointer(k + 1);
186 // Start with r=k+1, lambda=A[k+1, k]
187 r = k + 1;
188 Scalar lambda = abs(head[1]);
189 // Scan remaining elements
190 for (const Scalar* ptr = head + 2; ptr < end; ptr++)
191 {
192 const Scalar abs_elem = abs(*ptr);
193 if (lambda < abs_elem)
194 {
195 lambda = abs_elem;
196 r = k + (ptr - head);
197 }
198 }
199
200 return lambda;
201 }
202
203 // sigma = |A[p, r]| = max {|A[k, r]|, ..., |A[end, r]|} \ {A[r, r]}
204 // Largest (in magnitude) off-diagonal element in the r-th column of the current reduced matrix
205 // p is the row index
206 // Assume k < r < end
207 Scalar find_sigma(Index k, Index r, Index& p)
208 {
209 using std::abs;
210
211 // First search A[r+1, r], ..., A[end, r], which has the same task as find_lambda()
212 // If r == end, we skip this search
213 Scalar sigma = Scalar(-1);
214 if (r < m_n - 1)
215 sigma = find_lambda(r, p);
216
217 // Then search A[k, r], ..., A[r-1, r], which maps to A[r, k], ..., A[r, r-1]
218 for (Index j = k; j < r; j++)
219 {
220 const Scalar abs_elem = abs(coeff(r, j));
221 if (sigma < abs_elem)
222 {
223 sigma = abs_elem;
224 p = j;
225 }
226 }
227
228 return sigma;
229 }
230
231 // Generate permutations and apply to A
232 // Return true if the resulting pivoting is 1x1, and false if 2x2
233 bool permutate_mat(Index k, const Scalar& alpha)
234 {
235 using std::abs;
236
237 Index r = k, p = k;
238 const Scalar lambda = find_lambda(k, r);
239
240 // If lambda=0, no need to interchange
241 if (lambda > Scalar(0))
242 {
243 const Scalar abs_akk = abs(diag_coeff(k));
244 // If |A[k, k]| >= alpha * lambda, no need to interchange
245 if (abs_akk < alpha * lambda)
246 {
247 const Scalar sigma = find_sigma(k, r, p);
248
249 // If sigma * |A[k, k]| >= alpha * lambda^2, no need to interchange
250 if (sigma * abs_akk < alpha * lambda * lambda)
251 {
252 if (abs_akk >= alpha * sigma)
253 {
254 // Permutation on A
255 pivoting_1x1(k, r);
256
257 // Permutation on L
258 interchange_rows(k, r, 0, k - 1);
259 return true;
260 }
261 else
262 {
263 // There are two versions of permutation here
264 // 1. A[k+1, k] <-> A[r, k]
265 // 2. A[k+1, k] <-> A[r, p], where p >= k and r >= k+1
266 //
267 // Version 1 and 2 are used by Ref[1] and Ref[2], respectively
268
269 // Version 1 implementation
270 p = k;
271
272 // Version 2 implementation
273 // [r, p] and [p, r] are symmetric, but we need to make sure
274 // p >= k and r >= k+1, so it is safe to always make r > p
275 // One exception is when min{r,p} == k+1, in which case we make
276 // r = k+1, so that only one permutation needs to be performed
277 /* const Index rp_min = std::min(r, p);
278 const Index rp_max = std::max(r, p);
279 if(rp_min == k + 1)
280 {
281 r = rp_min; p = rp_max;
282 } else {
283 r = rp_max; p = rp_min;
284 } */
285
286 // Right now we use Version 1 since it reduces the overhead of interchange
287
288 // Permutation on A
289 pivoting_2x2(k, r, p);
290 // Permutation on L
291 interchange_rows(k, p, 0, k - 1);
292 interchange_rows(k + 1, r, 0, k - 1);
293 return false;
294 }
295 }
296 }
297 }
298
299 return true;
300 }
301
302 // E = [e11, e12]
303 // [e21, e22]
304 // Overwrite E with inv(E)
305 void inverse_inplace_2x2(Scalar& e11, Scalar& e21, Scalar& e22) const
306 {
307 // inv(E) = [d11, d12], d11 = e22/delta, d21 = -e21/delta, d22 = e11/delta
308 // [d21, d22]
309 const Scalar delta = e11 * e22 - e21 * e21;
310 std::swap(e11, e22);
311 e11 /= delta;
312 e22 /= delta;
313 e21 = -e21 / delta;
314 }
315
316 // Return value is the status, SUCCESSFUL/NUMERICAL_ISSUE
317 int gaussian_elimination_1x1(Index k)
318 {
319 // D = 1 / A[k, k]
320 const Scalar akk = diag_coeff(k);
321 // Return NUMERICAL_ISSUE if not invertible
322 if (akk == Scalar(0))
323 return NUMERICAL_ISSUE;
324
325 diag_coeff(k) = Scalar(1) / akk;
326
327 // B -= l * l' / A[k, k], B := A[(k+1):end, (k+1):end], l := L[(k+1):end, k]
328 Scalar* lptr = col_pointer(k) + 1;
329 const Index ldim = m_n - k - 1;
330 MapVec l(lptr, ldim);
331 for (Index j = 0; j < ldim; j++)
332 {
333 MapVec(col_pointer(j + k + 1), ldim - j).noalias() -= (lptr[j] / akk) * l.tail(ldim - j);
334 }
335
336 // l /= A[k, k]
337 l /= akk;
338
339 return SUCCESSFUL;
340 }
341
342 // Return value is the status, SUCCESSFUL/NUMERICAL_ISSUE
343 int gaussian_elimination_2x2(Index k)
344 {
345 // D = inv(E)
346 Scalar& e11 = diag_coeff(k);
347 Scalar& e21 = coeff(k + 1, k);
348 Scalar& e22 = diag_coeff(k + 1);
349 // Return NUMERICAL_ISSUE if not invertible
350 if (e11 * e22 - e21 * e21 == Scalar(0))
351 return NUMERICAL_ISSUE;
352
353 inverse_inplace_2x2(e11, e21, e22);
354
355 // X = l * inv(E), l := L[(k+2):end, k:(k+1)]
356 Scalar* l1ptr = &coeff(k + 2, k);
357 Scalar* l2ptr = &coeff(k + 2, k + 1);
358 const Index ldim = m_n - k - 2;
359 MapVec l1(l1ptr, ldim), l2(l2ptr, ldim);
360
361 Eigen::Matrix<Scalar, Eigen::Dynamic, 2> X(ldim, 2);
362 X.col(0).noalias() = l1 * e11 + l2 * e21;
363 X.col(1).noalias() = l1 * e21 + l2 * e22;
364
365 // B -= l * inv(E) * l' = X * l', B = A[(k+2):end, (k+2):end]
366 for (Index j = 0; j < ldim; j++)
367 {
368 MapVec(col_pointer(j + k + 2), ldim - j).noalias() -= (X.col(0).tail(ldim - j) * l1ptr[j] + X.col(1).tail(ldim - j) * l2ptr[j]);
369 }
370
371 // l = X
372 l1.noalias() = X.col(0);
373 l2.noalias() = X.col(1);
374
375 return SUCCESSFUL;
376 }
377
378public:
379 BKLDLT() :
380 m_n(0), m_computed(false), m_info(NOT_COMPUTED)
381 {}
382
383 // Factorize mat - shift * I
384 BKLDLT(ConstGenericMatrix& mat, int uplo = Eigen::Lower, const Scalar& shift = Scalar(0)) :
385 m_n(mat.rows()), m_computed(false), m_info(NOT_COMPUTED)
386 {
387 compute(mat, uplo, shift);
388 }
389
390 void compute(ConstGenericMatrix& mat, int uplo = Eigen::Lower, const Scalar& shift = Scalar(0))
391 {
392 using std::abs;
393
394 m_n = mat.rows();
395 if (m_n != mat.cols())
396 throw std::invalid_argument("BKLDLT: matrix must be square");
397
398 m_perm.setLinSpaced(m_n, 0, m_n - 1);
399 m_permc.clear();
400
401 // Copy data
402 m_data.resize((m_n * (m_n + 1)) / 2);
403 compute_pointer();
404 copy_data(mat, uplo, shift);
405
406 const Scalar alpha = (1.0 + std::sqrt(17.0)) / 8.0;
407 Index k = 0;
408 for (k = 0; k < m_n - 1; k++)
409 {
410 // 1. Interchange rows and columns of A, and save the result to m_perm
411 bool is_1x1 = permutate_mat(k, alpha);
412
413 // 2. Gaussian elimination
414 if (is_1x1)
415 {
416 m_info = gaussian_elimination_1x1(k);
417 }
418 else
419 {
420 m_info = gaussian_elimination_2x2(k);
421 k++;
422 }
423
424 // 3. Check status
425 if (m_info != SUCCESSFUL)
426 break;
427 }
428 // Invert the last 1x1 block if it exists
429 if (k == m_n - 1)
430 {
431 const Scalar akk = diag_coeff(k);
432 if (akk == Scalar(0))
433 m_info = NUMERICAL_ISSUE;
434
435 diag_coeff(k) = Scalar(1) / diag_coeff(k);
436 }
437
438 compress_permutation();
439
440 m_computed = true;
441 }
442
443 // Solve Ax=b
444 void solve_inplace(GenericVector b) const
445 {
446 if (!m_computed)
447 throw std::logic_error("BKLDLT: need to call compute() first");
448
449 // PAP' = LDL'
450 // 1. b -> Pb
451 Scalar* x = b.data();
452 MapVec res(x, m_n);
453 Index npermc = m_permc.size();
454 for (Index i = 0; i < npermc; i++)
455 {
456 std::swap(x[m_permc[i].first], x[m_permc[i].second]);
457 }
458
459 // 2. Lz = Pb
460 // If m_perm[end] < 0, then end with m_n - 3, otherwise end with m_n - 2
461 const Index end = (m_perm[m_n - 1] < 0) ? (m_n - 3) : (m_n - 2);
462 for (Index i = 0; i <= end; i++)
463 {
464 const Index b1size = m_n - i - 1;
465 const Index b2size = b1size - 1;
466 if (m_perm[i] >= 0)
467 {
468 MapConstVec l(&coeff(i + 1, i), b1size);
469 res.segment(i + 1, b1size).noalias() -= l * x[i];
470 }
471 else
472 {
473 MapConstVec l1(&coeff(i + 2, i), b2size);
474 MapConstVec l2(&coeff(i + 2, i + 1), b2size);
475 res.segment(i + 2, b2size).noalias() -= (l1 * x[i] + l2 * x[i + 1]);
476 i++;
477 }
478 }
479
480 // 3. Dw = z
481 for (Index i = 0; i < m_n; i++)
482 {
483 const Scalar e11 = diag_coeff(i);
484 if (m_perm[i] >= 0)
485 {
486 x[i] *= e11;
487 }
488 else
489 {
490 const Scalar e21 = coeff(i + 1, i), e22 = diag_coeff(i + 1);
491 const Scalar wi = x[i] * e11 + x[i + 1] * e21;
492 x[i + 1] = x[i] * e21 + x[i + 1] * e22;
493 x[i] = wi;
494 i++;
495 }
496 }
497
498 // 4. L'y = w
499 // If m_perm[end] < 0, then start with m_n - 3, otherwise start with m_n - 2
500 Index i = (m_perm[m_n - 1] < 0) ? (m_n - 3) : (m_n - 2);
501 for (; i >= 0; i--)
502 {
503 const Index ldim = m_n - i - 1;
504 MapConstVec l(&coeff(i + 1, i), ldim);
505 x[i] -= res.segment(i + 1, ldim).dot(l);
506
507 if (m_perm[i] < 0)
508 {
509 MapConstVec l2(&coeff(i + 1, i - 1), ldim);
510 x[i - 1] -= res.segment(i + 1, ldim).dot(l2);
511 i--;
512 }
513 }
514
515 // 5. x = P'y
516 for (i = npermc - 1; i >= 0; i--)
517 {
518 std::swap(x[m_permc[i].first], x[m_permc[i].second]);
519 }
520 }
521
522 Vector solve(ConstGenericVector& b) const
523 {
524 Vector res = b;
525 solve_inplace(res);
526 return res;
527 }
528
529 int info() const { return m_info; }
530};
531
532} // namespace LBFGSpp
533
535
536#endif // LBFGSPP_BK_LDLT_H