]> git.tdb.fi Git - libs/math.git/blobdiff - source/linal/squarematrix.h
Another big batch of stuff
[libs/math.git] / source / linal / squarematrix.h
index 6cda86a5baa6b202ec00dba93acb0a152b03a985..4cdc9e88f7288654d4dcbae9d2ec006517bfb734 100644 (file)
 #ifndef MSP_LINAL_SQUAREMATRIX_H_
 #define MSP_LINAL_SQUAREMATRIX_H_
 
+#include <stdexcept>
 #include "matrix.h"
 
 namespace Msp {
 namespace LinAl {
 
+class not_invertible: public std::domain_error
+{
+public:
+       not_invertible(): domain_error(std::string()) { }
+       virtual ~not_invertible() throw() { }
+};
+
 template<typename T, unsigned S>
 class SquareMatrix: public Matrix<T, S, S>
 {
 public:
-       SquareMatrix();
+       SquareMatrix() { }
        SquareMatrix(const T *);
        template<typename U>
        SquareMatrix(const Matrix<U, S, S> &);
+
        static SquareMatrix identity();
 
        SquareMatrix &operator*=(const SquareMatrix &);
 
-       void invert();
+       SquareMatrix &invert();
 };
 
+template<typename T, unsigned S>
+SquareMatrix<T, S>::SquareMatrix(const T *d):
+       Matrix<T, S, S>(d)
+{ }
+
+template<typename T, unsigned S>
+template<typename U>
+SquareMatrix<T, S>::SquareMatrix(const Matrix<U, S, S> &m):
+       Matrix<T, S, S>(m)
+{ }
+
+template<typename T, unsigned S>
+inline SquareMatrix<T, S> SquareMatrix<T, S>::identity()
+{
+       SquareMatrix<T, S> m;
+       for(unsigned i=0; i<S; ++i)
+               m(i, i) = T(1);
+       return m;
+}
+
+template<typename T, unsigned S>
+SquareMatrix<T, S> &SquareMatrix<T, S>::operator*=(const SquareMatrix<T, S> &m)
+{
+       Matrix<T, S, S>::operator*=(m);
+       return *this;
+}
+
+template<typename T, unsigned S>
+SquareMatrix<T, S> &SquareMatrix<T, S>::invert()
+{
+       SquareMatrix<T, S> r = identity();
+       for(unsigned i=0; i<S; ++i)
+       {
+               if(this->element(i, i)==T(0))
+               {
+                       unsigned pivot = i;
+                       for(unsigned j=i+1; j<S; ++j)
+                               if(abs(this->element(j, i))>abs(this->element(pivot, i)))
+                                       pivot = j;
+
+                       if(pivot==i)
+                               throw not_invertible();
+
+                       this->exchange_rows(i, pivot);
+                       r.exchange_rows(i, pivot);
+               }
+
+               for(unsigned j=i+1; j<S; ++j)
+               {
+                       T a = -this->element(j, i)/this->element(i, i);
+                       this->add_row(i, j, a);
+                       r.add_row(i, j, a);
+               }
+
+               T a = T(1)/this->element(i, i);
+               this->multiply_row(i, a);
+               r.multiply_row(i, a);
+       }
+
+       for(unsigned i=S; i-->0; )
+               for(unsigned j=i; j-->0; )
+                       r.add_row(i, j, -this->element(j, i));
+
+       return *this = r;
+}
+
+template<typename T, unsigned S>
+inline SquareMatrix<T, S> invert(const SquareMatrix<T, S> &m)
+{
+       SquareMatrix<T, S> r = m;
+       return r.invert();
+}
+
 } // namespace LinAl
 } // namespace Msp