consteig
Compile-time eigenvalue and eigenvector computation for C++17
Loading...
Searching...
No Matches
lu.hpp
1#ifndef LU_DECOMP_HPP
2#define LU_DECOMP_HPP
3
4#include "../../math/constmath.hpp"
5#include "../matrix.hpp"
6#include "../operations.hpp"
7
8namespace consteig
9{
10
13
26template <typename T, Size S> struct LUMatrix
27{
28 Matrix<T, S, S> _l;
29 Matrix<T, S, S> _u;
30 Size _p[S];
31};
32
46// Algorithm: LU Decomposition with Partial Pivoting
47// Factors a square matrix A such that PA = LU, where P is a permutation matrix,
48// L is unit lower triangular, and U is upper triangular. Partial pivoting
49// selects the largest magnitude entry in each column as the pivot, which
50// controls the growth of rounding errors during elimination.
51//
52// Reference: Golub & Van Loan, "Matrix Computations" (4th ed.), sec. 3.4
53template <typename T, Size S>
54constexpr LUMatrix<T, S> lu(const Matrix<T, S, S> &a)
55{
56 LUMatrix<T, S> res{};
57 res._u = a;
58 res._l = eye<T, S>();
59 for (Size row = 0; row < S; ++row)
60 {
61 res._p[row] = row;
62 }
63
64 for (Size diag = 0; diag < S; ++diag)
65 {
66 // Pivot
67 Size max_row = diag;
68 auto max_val = abs(res._u(diag, diag));
69 for (Size search_row = diag + 1; search_row < S; ++search_row)
70 {
71 auto val = abs(res._u(search_row, diag));
72 if (val > max_val)
73 {
74 max_val = val;
76 }
77 }
78
79 if (max_row != diag)
80 {
81 // Swap rows in U
82 for (Size col = 0; col < S; ++col)
83 {
84 T tmp = res._u(diag, col);
85 res._u(diag, col) = res._u(max_row, col);
86 res._u(max_row, col) = tmp;
87 }
88 // Swap rows in L (elements below diagonal)
89 for (Size col = 0; col < diag; ++col)
90 {
91 T tmp = res._l(diag, col);
92 res._l(diag, col) = res._l(max_row, col);
93 res._l(max_row, col) = tmp;
94 }
95 // Swap pivots
96 Size tmp_p = res._p[diag];
97 res._p[diag] = res._p[max_row];
98 res._p[max_row] = tmp_p;
99 }
100
101 for (Size row = diag + 1; row < S; ++row)
102 {
103 // Note: In inverse iteration, we might encounter nearly singular
104 // matrices. We use a small epsilon to avoid exact division by zero.
105 auto pivot_abs = abs(res._u(diag, diag));
106 if (pivot_abs > 1e-30)
107 {
108 res._l(row, diag) = res._u(row, diag) / res._u(diag, diag);
109 for (Size col = diag; col < S; ++col)
110 {
111 res._u(row, col) = res._u(row, col) -
112 res._l(row, diag) * res._u(diag, col);
113 }
114 }
115 }
116 }
117 return res;
118}
119
134// Solves Ax = b using the LU factorization PA = LU.
135// 1. Solve Ly = Pb for y (Forward Substitution)
136// 2. Solve Ux = y for x (Backward Substitution)
137template <typename T, Size S>
138constexpr Matrix<T, S, 1> lu_solve(const LUMatrix<T, S> &lu,
139 const Matrix<T, S, 1> &b)
140{
141 // Solve Ly = Pb
143 for (Size row = 0; row < S; ++row)
144 {
145 pb(row, 0) = b(lu._p[row], 0);
146 }
147
149 for (Size row = 0; row < S; ++row)
150 {
151 T sum = static_cast<T>(0);
152 for (Size col = 0; col < row; ++col)
153 {
154 sum = sum + lu._l(row, col) * y(col, 0);
155 }
156 y(row, 0) = pb(row, 0) - sum;
157 }
158
159 // Solve Ux = y
161 for (Size i = S; i > 0; --i)
162 {
163 Size row = i - 1;
164 T sum = static_cast<T>(0);
165 for (Size col = row + 1; col < S; ++col)
166 {
167 sum = sum + lu._u(row, col) * x(col, 0);
168 }
169
170 auto diag_abs = abs(lu._u(row, row));
171 if (diag_abs > 1e-30)
172 {
173 x(row, 0) = (y(row, 0) - sum) / lu._u(row, row);
174 }
175 else
176 {
177 // If nearly singular, we use a very small value instead of zero to
178 // encourage the "explosion" required by Inverse Iteration.
179 x(row, 0) = (y(row, 0) - sum) / static_cast<T>(1e-30);
180 }
181 }
182 return x;
183}
184
186
187} // namespace consteig
188
189#endif
Fixed-size matrix with compile-time dimensions.
Definition matrix.hpp:56
constexpr Matrix< T, S, 1 > lu_solve(const LUMatrix< T, S > &lu, const Matrix< T, S, 1 > &b)
Solve the linear system Ax = b given the LU factorization of A.
Definition lu.hpp:138
constexpr LUMatrix< T, S > lu(const Matrix< T, S, S > &a)
LU decomposition with partial pivoting.
Definition lu.hpp:54
constexpr T abs(const T x)
Absolute value of a real number.
Definition abs.hpp:17
constexpr T epsilon()
Machine epsilon for type T.
Definition utilities.hpp:82