]> git.tdb.fi Git - libs/math.git/blob - source/linal/squarematrix.h
Another big batch of stuff
[libs/math.git] / source / linal / squarematrix.h
1 #ifndef MSP_LINAL_SQUAREMATRIX_H_
2 #define MSP_LINAL_SQUAREMATRIX_H_
3
#include <cmath>
#include <cstdlib>
#include <stdexcept>
#include <string>
#include "matrix.h"
6
7 namespace Msp {
8 namespace LinAl {
9
/**
Thrown by SquareMatrix::invert() when elimination finds a column with no
nonzero pivot candidate, i.e. the matrix has no inverse.

The default constructor supplies a descriptive message so that what() is
useful in diagnostics; a custom message may be given instead.
*/
class not_invertible: public std::domain_error
{
public:
	explicit not_invertible(const std::string &w = "matrix is not invertible"): domain_error(w) { }
	virtual ~not_invertible() throw() { }
};
16
17 template<typename T, unsigned S>
18 class SquareMatrix: public Matrix<T, S, S>
19 {
20 public:
21         SquareMatrix() { }
22         SquareMatrix(const T *);
23         template<typename U>
24         SquareMatrix(const Matrix<U, S, S> &);
25
26         static SquareMatrix identity();
27
28         SquareMatrix &operator*=(const SquareMatrix &);
29
30         SquareMatrix &invert();
31 };
32
33 template<typename T, unsigned S>
34 SquareMatrix<T, S>::SquareMatrix(const T *d):
35         Matrix<T, S, S>(d)
36 { }
37
38 template<typename T, unsigned S>
39 template<typename U>
40 SquareMatrix<T, S>::SquareMatrix(const Matrix<U, S, S> &m):
41         Matrix<T, S, S>(m)
42 { }
43
44 template<typename T, unsigned S>
45 inline SquareMatrix<T, S> SquareMatrix<T, S>::identity()
46 {
47         SquareMatrix<T, S> m;
48         for(unsigned i=0; i<S; ++i)
49                 m(i, i) = T(1);
50         return m;
51 }
52
53 template<typename T, unsigned S>
54 SquareMatrix<T, S> &SquareMatrix<T, S>::operator*=(const SquareMatrix<T, S> &m)
55 {
56         Matrix<T, S, S>::operator*=(m);
57         return *this;
58 }
59
60 template<typename T, unsigned S>
61 SquareMatrix<T, S> &SquareMatrix<T, S>::invert()
62 {
63         SquareMatrix<T, S> r = identity();
64         for(unsigned i=0; i<S; ++i)
65         {
66                 if(this->element(i, i)==T(0))
67                 {
68                         unsigned pivot = i;
69                         for(unsigned j=i+1; j<S; ++j)
70                                 if(abs(this->element(j, i))>abs(this->element(pivot, i)))
71                                         pivot = j;
72
73                         if(pivot==i)
74                                 throw not_invertible();
75
76                         this->exchange_rows(i, pivot);
77                         r.exchange_rows(i, pivot);
78                 }
79
80                 for(unsigned j=i+1; j<S; ++j)
81                 {
82                         T a = -this->element(j, i)/this->element(i, i);
83                         this->add_row(i, j, a);
84                         r.add_row(i, j, a);
85                 }
86
87                 T a = T(1)/this->element(i, i);
88                 this->multiply_row(i, a);
89                 r.multiply_row(i, a);
90         }
91
92         for(unsigned i=S; i-->0; )
93                 for(unsigned j=i; j-->0; )
94                         r.add_row(i, j, -this->element(j, i));
95
96         return *this = r;
97 }
98
99 template<typename T, unsigned S>
100 inline SquareMatrix<T, S> invert(const SquareMatrix<T, S> &m)
101 {
102         SquareMatrix<T, S> r = m;
103         return r.invert();
104 }
105
106 } // namespace LinAl
107 } // namespace Msp
108
109 #endif