]> git.tdb.fi Git - libs/math.git/blobdiff - source/linal/squarematrix.h
Rename the low-level matrix inversion function to gauss_jordan
[libs/math.git] / source / linal / squarematrix.h
index 94157b0f3b92f2e2de9ebbf49e927d17bf8a9ae8..b8531aafcb8359bcda6ea35628a9f561f77b677b 100644 (file)
@@ -1,19 +1,17 @@
 #ifndef MSP_LINAL_SQUAREMATRIX_H_
 #define MSP_LINAL_SQUAREMATRIX_H_
 
-#include <stdexcept>
+#include <cmath>
 #include "matrix.h"
+#include "matrixops.h"
 
 namespace Msp {
 namespace LinAl {
 
-class not_invertible: public std::domain_error
-{
-public:
-       not_invertible(): domain_error(std::string()) { }
-       virtual ~not_invertible() throw() { }
-};
-
+/**
+A mathematical matrix with S rows and columns.  Some operations are provided
+here that are only possible for square matrices.
+*/
 template<typename T, unsigned S>
 class SquareMatrix: public Matrix<T, S, S>
 {
@@ -42,54 +40,23 @@ inline SquareMatrix<T, S> SquareMatrix<T, S>::identity()
 template<typename T, unsigned S>
 SquareMatrix<T, S> &SquareMatrix<T, S>::operator*=(const SquareMatrix<T, S> &m)
 {
-       Matrix<T, S, S>::operator*=(m);
-       return *this;
+       return *this = *this*m;
 }
 
 template<typename T, unsigned S>
 SquareMatrix<T, S> &SquareMatrix<T, S>::invert()
 {
        SquareMatrix<T, S> r = identity();
-       for(unsigned i=0; i<S; ++i)
-       {
-               if(this->element(i, i)==T(0))
-               {
-                       unsigned pivot = i;
-                       for(unsigned j=i+1; j<S; ++j)
-                               if(abs(this->element(j, i))>abs(this->element(pivot, i)))
-                                       pivot = j;
-
-                       if(pivot==i)
-                               throw not_invertible();
-
-                       this->exchange_rows(i, pivot);
-                       r.exchange_rows(i, pivot);
-               }
-
-               for(unsigned j=i+1; j<S; ++j)
-               {
-                       T a = -this->element(j, i)/this->element(i, i);
-                       this->add_row(i, j, a);
-                       r.add_row(i, j, a);
-               }
-
-               T a = T(1)/this->element(i, i);
-               this->multiply_row(i, a);
-               r.multiply_row(i, a);
-       }
-
-       for(unsigned i=S; i-->0; )
-               for(unsigned j=i; j-->0; )
-                       r.add_row(i, j, -this->element(j, i));
-
+       gauss_jordan(*this, r);
        return *this = r;
 }
 
 template<typename T, unsigned S>
 inline SquareMatrix<T, S> invert(const SquareMatrix<T, S> &m)
 {
-       SquareMatrix<T, S> r = m;
-       return r.invert();
+       SquareMatrix<T, S> temp = m;
+       SquareMatrix<T, S> r = SquareMatrix<T, S>::identity();
+       return gauss_jordan(temp, r);
 }
 
 } // namespace LinAl