]> git.tdb.fi Git - libs/math.git/blob - source/linal/squarematrix.h
Always pivot rows when inverting for better numerical stability
[libs/math.git] / source / linal / squarematrix.h
1 #ifndef MSP_LINAL_SQUAREMATRIX_H_
2 #define MSP_LINAL_SQUAREMATRIX_H_
3
4 #include <cmath>
5 #include <stdexcept>
6 #include "matrix.h"
7
8 namespace Msp {
9 namespace LinAl {
10
/**
Exception thrown by SquareMatrix::invert when a zero pivot is found, i.e. the
matrix has no inverse.  It derives from std::domain_error and carries an empty
what() message.
*/
class not_invertible: public std::domain_error
{
public:
	not_invertible(): std::domain_error("") { }
	virtual ~not_invertible() throw() { }
};
17
18 /**
19 A mathematical matrix with S rows and columns.  Some operations are provided
20 here that are only possible for square matrices.
21 */
22 template<typename T, unsigned S>
23 class SquareMatrix: public Matrix<T, S, S>
24 {
25 public:
26         SquareMatrix() { }
27         SquareMatrix(const T *d): Matrix<T, S, S>(d) { }
28         template<typename U>
29         SquareMatrix(const Matrix<U, S, S> &m): Matrix<T, S, S>(m) { }
30
31         static SquareMatrix identity();
32
33         SquareMatrix &operator*=(const SquareMatrix &);
34
35         SquareMatrix &invert();
36 };
37
38 template<typename T, unsigned S>
39 inline SquareMatrix<T, S> SquareMatrix<T, S>::identity()
40 {
41         SquareMatrix<T, S> m;
42         for(unsigned i=0; i<S; ++i)
43                 m(i, i) = T(1);
44         return m;
45 }
46
47 template<typename T, unsigned S>
48 SquareMatrix<T, S> &SquareMatrix<T, S>::operator*=(const SquareMatrix<T, S> &m)
49 {
50         return *this = *this*m;
51 }
52
53 template<typename T, unsigned S>
54 SquareMatrix<T, S> &SquareMatrix<T, S>::invert()
55 {
56         using std::abs;
57
58         SquareMatrix<T, S> r = identity();
59         for(unsigned i=0; i<S; ++i)
60         {
61                 unsigned pivot = i;
62                 for(unsigned j=i+1; j<S; ++j)
63                         if(abs(this->element(j, i))>abs(this->element(pivot, i)))
64                                 pivot = j;
65
66                 if(this->element(pivot, i)==T(0))
67                         throw not_invertible();
68
69                 if(pivot!=i)
70                 {
71                         this->exchange_rows(i, pivot);
72                         r.exchange_rows(i, pivot);
73                 }
74
75                 for(unsigned j=i+1; j<S; ++j)
76                 {
77                         T a = -this->element(j, i)/this->element(i, i);
78                         this->add_row(i, j, a);
79                         r.add_row(i, j, a);
80                 }
81
82                 T a = T(1)/this->element(i, i);
83                 this->multiply_row(i, a);
84                 r.multiply_row(i, a);
85         }
86
87         for(unsigned i=S; i-->0; )
88                 for(unsigned j=i; j-->0; )
89                         r.add_row(i, j, -this->element(j, i));
90
91         return *this = r;
92 }
93
94 template<typename T, unsigned S>
95 inline SquareMatrix<T, S> invert(const SquareMatrix<T, S> &m)
96 {
97         SquareMatrix<T, S> r = m;
98         return r.invert();
99 }
100
101 } // namespace LinAl
102 } // namespace Msp
103
104 #endif