]> git.tdb.fi Git - libs/math.git/blobdiff - source/linal/squarematrix.h
Always pivot rows when inverting for better numerical stability
[libs/math.git] / source / linal / squarematrix.h
index 4cdc9e88f7288654d4dcbae9d2ec006517bfb734..5117bb3fa189ee4339d0ce52cdbb4d2ff141b83b 100644 (file)
@@ -1,6 +1,7 @@
 #ifndef MSP_LINAL_SQUAREMATRIX_H_
 #define MSP_LINAL_SQUAREMATRIX_H_
 
+#include <cmath>
 #include <stdexcept>
 #include "matrix.h"
 
@@ -14,14 +15,18 @@ public:
        virtual ~not_invertible() throw() { }
 };
 
+/**
+A mathematical matrix with S rows and columns.  Some operations are provided
+here that are only possible for square matrices.
+*/
 template<typename T, unsigned S>
 class SquareMatrix: public Matrix<T, S, S>
 {
 public:
        SquareMatrix() { }
-       SquareMatrix(const T *);
+       SquareMatrix(const T *d): Matrix<T, S, S>(d) { }
        template<typename U>
-       SquareMatrix(const Matrix<U, S, S> &);
+       SquareMatrix(const Matrix<U, S, S> &m): Matrix<T, S, S>(m) { }
 
        static SquareMatrix identity();
 
@@ -30,17 +35,6 @@ public:
        SquareMatrix &invert();
 };
 
-template<typename T, unsigned S>
-SquareMatrix<T, S>::SquareMatrix(const T *d):
-       Matrix<T, S, S>(d)
-{ }
-
-template<typename T, unsigned S>
-template<typename U>
-SquareMatrix<T, S>::SquareMatrix(const Matrix<U, S, S> &m):
-       Matrix<T, S, S>(m)
-{ }
-
 template<typename T, unsigned S>
 inline SquareMatrix<T, S> SquareMatrix<T, S>::identity()
 {
@@ -53,26 +47,27 @@ inline SquareMatrix<T, S> SquareMatrix<T, S>::identity()
 template<typename T, unsigned S>
 SquareMatrix<T, S> &SquareMatrix<T, S>::operator*=(const SquareMatrix<T, S> &m)
 {
-       Matrix<T, S, S>::operator*=(m);
-       return *this;
+       return *this = *this*m;
 }
 
 template<typename T, unsigned S>
 SquareMatrix<T, S> &SquareMatrix<T, S>::invert()
 {
+       using std::abs;
+
        SquareMatrix<T, S> r = identity();
        for(unsigned i=0; i<S; ++i)
        {
-               if(this->element(i, i)==T(0))
-               {
-                       unsigned pivot = i;
-                       for(unsigned j=i+1; j<S; ++j)
-                               if(abs(this->element(j, i))>abs(this->element(pivot, i)))
-                                       pivot = j;
+               unsigned pivot = i;
+               for(unsigned j=i+1; j<S; ++j)
+                       if(abs(this->element(j, i))>abs(this->element(pivot, i)))
+                               pivot = j;
 
-                       if(pivot==i)
-                               throw not_invertible();
+               if(this->element(pivot, i)==T(0))
+                       throw not_invertible();
 
+               if(pivot!=i)
+               {
                        this->exchange_rows(i, pivot);
                        r.exchange_rows(i, pivot);
                }