source: foam/trunk/vision/src/Matrix.h @ 96

Revision 96, 11.8 KB checked in by dave, 10 years ago (diff)

Benchmark testing added for the eigen face recognition

RevLine 
[85]1// Copyright (C) 2009 foam
2//
3// This program is free software; you can redistribute it and/or modify
4// it under the terms of the GNU General Public License as published by
5// the Free Software Foundation; either version 2 of the License, or
6// (at your option) any later version.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program; if not, write to the Free Software
15// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
16
17#include <assert.h>
18#include <iostream>
[86]19#include "Vector.h"
[85]20
21#ifndef FOAM_MATRIX
22#define FOAM_MATRIX
23
[86]24template<class T>
[85]25class Matrix
26{
27public:
[86]28        Matrix();
[85]29        Matrix(unsigned int r, unsigned int c);
30        ~Matrix();
31        Matrix(const Matrix &other);
[86]32        Matrix(unsigned int r, unsigned int c, float *data);
33
[85]34        // Row proxy classes to allow matrix[r][c] notation
35        class Row
36        {
37        public:
38                Row(Matrix *owner, unsigned int r)
39                {
40                        m_Data=&owner->GetRawData()[r*owner->GetCols()];
41                        m_Cols=owner->GetCols();
42                }
43               
44                T &operator[](unsigned int c)
45                {
46                        assert(c<m_Cols);
47                        return m_Data[c];
48                }
49               
50        private:
51                T *m_Data;
52                unsigned int m_Cols;
53        };
54
55        class ConstRow
56        {
57        public:
58                ConstRow(const Matrix *owner, unsigned int r)
59                {
60                        m_Data=&owner->GetRawDataConst()[r*owner->GetCols()];
61                        m_Cols=owner->GetCols();
62                }
63               
64                const T &operator[](unsigned int c) const
65                {
66                        assert(c<m_Cols);
67                        return m_Data[c];
68                }
69               
70        private:
71                const T *m_Data;
72                unsigned int m_Cols;
73        };
74
75       
76        Row operator[](unsigned int r)
77        {
78                assert(r<m_Rows);
79                return Row(this,r);
80        }
81
82        ConstRow operator[](unsigned int r) const
83        {
84                assert(r<m_Rows);
85                return ConstRow(this,r);
86        }
87 
88        unsigned int GetRows() const { return m_Rows; }
89        unsigned int GetCols() const { return m_Cols; }
90        T *GetRawData() { return m_Data; }
91        const T *GetRawDataConst() const { return m_Data; }
[86]92        Vector<T> GetRowVector(unsigned int r) const;
93        Vector<T> GetColVector(unsigned int c) const;
[89]94        void SetRowVector(unsigned int r, const Vector<T> &row);
95        void SetColVector(unsigned int c, const Vector<T> &col);
[85]96         
97        void Print() const;
98        void SetAll(T s);
[86]99        void Zero() { SetAll(0); }
100        bool IsInf();
[90]101        Matrix Transposed() const;
[96]102        Matrix Inverted() const;
[85]103
104        Matrix &operator=(const Matrix &other);
105        Matrix operator+(const Matrix &other) const;
106        Matrix operator-(const Matrix &other) const;
107        Matrix operator*(const Matrix &other) const;
[90]108        Vector<T> operator*(const Vector<T> &other) const;
109        Vector<T> VecMulTransposed(const Vector<T> &other) const;
[85]110        Matrix &operator+=(const Matrix &other);
111        Matrix &operator-=(const Matrix &other);
112        Matrix &operator*=(const Matrix &other);
[96]113        bool operator==(const Matrix &other) const;
[90]114
[89]115        void SortRows(Vector<T> &v);
116        void SortCols(Vector<T> &v);
[90]117       
118        Matrix CropRows(unsigned int s, unsigned int e);
119        Matrix CropCols(unsigned int s, unsigned int e);
[89]120
[90]121        void Save(FILE *f);
122        void Load(FILE *f);
123
[85]124        static void RunTests();
125       
126private:
127
128        unsigned int m_Rows;
129        unsigned int m_Cols;
130       
131        T *m_Data;
132       
133};
134
135template<class T>
136Matrix<T>::Matrix(unsigned int r, unsigned int c) :
137m_Rows(r),
138m_Cols(c)
139{
140        m_Data=new T[r*c];
141}
142
143template<class T>
[86]144Matrix<T>::Matrix() :
145m_Rows(0),
146m_Cols(0),
147m_Data(NULL)
148{
149}
150
151template<class T>
152Matrix<T>::Matrix(unsigned int r, unsigned int c, float *data) :
153m_Rows(r),
154m_Cols(c),
155m_Data(data)
156{
157}
158
159template<class T>
[85]160Matrix<T>::~Matrix()
161{
162        delete[] m_Data;
163}
164
165template<class T>
166Matrix<T>::Matrix(const Matrix &other)
167{
168        m_Rows = other.m_Rows;
169        m_Cols = other.m_Cols;
170        m_Data=new T[m_Rows*m_Cols];
171        memcpy(m_Data,other.m_Data,m_Rows*m_Cols*sizeof(T));
172}
173
174template<class T>
175Matrix<T> &Matrix<T>::operator=(const Matrix &other)
176{
177        if (m_Data!=NULL)
178        {
179                delete[] m_Data;
180        }
181       
182        m_Rows = other.m_Rows;
183        m_Cols = other.m_Cols;
184        m_Data=new T[m_Rows*m_Cols];
185        memcpy(m_Data,other.m_Data,m_Rows*m_Cols*sizeof(T));
186       
187        return *this;
188}
189
190template<class T>
191void Matrix<T>::Print() const
192{
193        for (unsigned int i=0; i<m_Rows; i++)
194        {
195                for (unsigned int j=0; j<m_Cols; j++)
196                {
197                        std::cerr<<(*this)[i][j]<<" ";
198                }
199                std::cerr<<std::endl;
200        }
201}
202
203template<class T>
204void Matrix<T>::SetAll(T s)
205{
206        for (unsigned int i=0; i<m_Rows; i++)
207        {
208                for (unsigned int j=0; j<m_Cols; j++)
209                {
210                        (*this)[i][j]=s;
211                }
212        }
213}
214
215template<class T>
[86]216bool Matrix<T>::IsInf()
217{
218        for (unsigned int i=0; i<m_Rows; i++)
219        {
220                for (unsigned int j=0; j<m_Cols; j++)
221                {
222                        if (isinf((*this)[i][j])) return true;
223                        if (isnan((*this)[i][j])) return true;
224                }
225        }
226        return false;
227}
228
229template<class T>
[90]230Matrix<T> Matrix<T>::Transposed() const
[85]231{
232        Matrix<T> copy(*this);
233        for (unsigned int i=0; i<m_Rows; i++)
234        {
235                for (unsigned int j=0; j<m_Cols; j++)
236                {
237                        copy[i][j]=(*this)[j][i];
238                }
239        }
240        return copy;
241}
242
[96]243void matrix_inverse(const float *Min, float *Mout, int actualsize);
[85]244
245template<class T>
[96]246Matrix<T> Matrix<T>::Inverted() const
247{
248        assert(m_Rows==m_Cols); // only works for square matrices
249        Matrix<T> ret(m_Rows,m_Cols);
250        matrix_inverse(GetRawDataConst(),ret.GetRawData(),m_Rows);
251        return ret;
252}
253
254template<class T>
[85]255Matrix<T> Matrix<T>::operator+(const Matrix &other) const
256{
257        assert(m_Rows=other.m_Rows);
258        assert(m_Cols=other.m_Cols);
259       
260        Matrix<T> ret(m_Rows,m_Cols);
261        for (unsigned int i=0; i<m_Rows; i++)
262        {
263                for (unsigned int j=0; j<m_Cols; j++)
264                {
265                        ret[i][j]=(*this)[i][j]+other[i][j];
266                }
267        }
268        return ret;
269}
270
271template<class T>
272Matrix<T> Matrix<T>::operator-(const Matrix &other) const
273{
274        assert(m_Rows=other.m_Rows);
275        assert(m_Cols=other.m_Cols);
276       
277        Matrix<T> ret(m_Rows,m_Cols);
278        for (unsigned int i=0; i<m_Rows; i++)
279        {
280                for (unsigned int j=0; j<m_Cols; j++)
281                {
282                        ret[i][j]=(*this)[i][j]-other[i][j];
283                }
284        }
285        return ret;
286}
287
288template<class T>
289Matrix<T> Matrix<T>::operator*(const Matrix &other) const
290{
291        assert(m_Cols==other.m_Rows);
292       
293        Matrix<T> ret(m_Rows,other.m_Cols);
294       
295        for (unsigned int i=0; i<m_Rows; i++)
296        {
297                for (unsigned int j=0; j<other.m_Cols; j++)
298                {
299                        ret[i][j]=0;
300                        for (unsigned int k=0; k<m_Cols; k++)
301                        {
302                                ret[i][j]+=(*this)[i][k]*other[k][j];
303                        }
304                }
305        }
306        return ret;
307}
308
309template<class T>
[90]310Vector<T> Matrix<T>::operator*(const Vector<T> &other) const
311{
312        assert(m_Cols==other.Size());
313       
314        Vector<T> ret(m_Rows);
315       
316        for (unsigned int i=0; i<m_Rows; i++)
317        {
318                for (unsigned int j=0; j<other.Size(); j++)
319                {
320                        ret[i]=0;
321                        for (unsigned int k=0; k<m_Cols; k++)
322                        {
323                                ret[i]+=(*this)[i][k]*other[k];
324                        }
325                }
326        }
327        return ret;
328}
329
330template<class T>
331Vector<T> Matrix<T>::VecMulTransposed(const Vector<T> &other) const
332{
333        assert(m_Rows==other.Size());
334       
335        Vector<T> ret(m_Cols);
336       
337        for (unsigned int i=0; i<m_Cols; i++)
338        {
339                for (unsigned int j=0; j<other.Size(); j++)
340                {
341                        ret[i]=0;
342                        for (unsigned int k=0; k<m_Rows; k++)
343                        {
344                                ret[i]+=(*this)[k][i]*other[k];
345                        }
346                }
347        }
348        return ret;
349}
350
351template<class T>
[85]352Matrix<T> &Matrix<T>::operator+=(const Matrix &other)
353{
354        (*this)=(*this)+other;
355        return *this;
356}
357
358template<class T>
359Matrix<T> &Matrix<T>::operator-=(const Matrix &other)
360{
361        (*this)=(*this)-other;
362        return *this;
363}
364
365template<class T>
366Matrix<T> &Matrix<T>::operator*=(const Matrix &other)
367{
368        (*this)=(*this)*other;
369        return *this;
370}
371
[96]372template<class T>
373bool Matrix<T>::operator==(const Matrix &other) const
374{
375        if (m_Rows != other.m_Rows ||
376                m_Cols != other.m_Cols)
377        {
378                return false;
379        }
380       
381        for (unsigned int i=0; i<m_Cols; i++)
382        {
383                for (unsigned int j=0; j<m_Rows; j++)
384                {
385                        if (!feq((*this)[i][j],other[i][j])) return false;
386                }
387        }
[89]388
[96]389        return true;
390}
391
[89]392//todo: use memcpy for these 4 functions
[86]393template<class T>
394Vector<T> Matrix<T>::GetRowVector(unsigned int r) const
395{
396        assert(r<m_Rows);
397        Vector<T> ret(m_Cols);
398        for (unsigned int j=0; j<m_Cols; j++)
399        {
400                ret[j]=(*this)[r][j];
401        }
402        return ret;
403}
[85]404
405template<class T>
[86]406Vector<T> Matrix<T>::GetColVector(unsigned int c) const
407{
408        assert(c<m_Cols);
409        Vector<T> ret(m_Rows);
410        for (unsigned int i=0; i<m_Rows; i++)
411        {
412                ret[i]=(*this)[i][c];
413        }
414        return ret;
415}
416
417template<class T>
[89]418void Matrix<T>::SetRowVector(unsigned int r, const Vector<T> &row)
419{
420        assert(r<m_Rows);
421        assert(row.Size()==m_Cols);
422        for (unsigned int j=0; j<m_Cols; j++)
423        {
424                (*this)[r][j]=row[j];
425        }
426}
427
428template<class T>
429void Matrix<T>::SetColVector(unsigned int c, const Vector<T> &col)
430{
431        assert(c<m_Cols);
432        assert(col.Size()==m_Rows);
433        for (unsigned int i=0; i<m_Rows; i++)
434        {
435                (*this)[i][c]=col[i];
436        }
437}
438
439// sort rows by v
440template<class T>
441void Matrix<T>::SortRows(Vector<T> &v)
442{
443        assert(v.Size()==m_Rows);
444       
445        bool sorted=false;
446        while(!sorted)
447        {
448                sorted=true;
449               
450                for (unsigned int i=0; i<v.Size()-1; i++)
451                {
452                        if (v[i]<v[i+1])
453                        {
454                                sorted=false;
455                                float vtmp = v[i];
456                                v[i]=v[i+1];
457                                v[i+1]=vtmp;
458                               
459                                Vector<float> rtmp = GetRowVector(i);
460                                SetRowVector(i,GetRowVector(i+1));
461                                SetRowVector(i+1,rtmp);
462                        }
463                }
464        }
465}
466
467// sort cols by v
468template<class T>
469void Matrix<T>::SortCols(Vector<T> &v)
470{
471        assert(v.Size()==m_Cols);
472       
473        bool sorted=false;
474        while(!sorted)
475        {
476                sorted=true;
477               
478                for (unsigned int i=0; i<v.Size()-1; i++)
479                {
480                        if (v[i]<v[i+1])
481                        {
482                                sorted=false;
483                                float vtmp = v[i];
484                                v[i]=v[i+1];
485                                v[i+1]=vtmp;
486                               
487                                Vector<float> rtmp = GetColVector(i);
488                                SetColVector(i,GetColVector(i+1));
489                                SetColVector(i+1,rtmp);
490                        }
491                }
492        }
493}
[90]494       
495template<class T>
496Matrix<T> Matrix<T>::CropRows(unsigned int s, unsigned int e)
497{
498        assert(s<e);
499        assert(s<m_Rows);
500        assert(e<=m_Rows);
501       
502        Matrix r(e-s,m_Cols);
503        unsigned int c=0;
504        for(unsigned int i=s; i<e; i++)
505        {
506                r.SetRowVector(c,GetRowVector(i));
507                c++;
508        }
509       
510        return r;
511}
[89]512
[90]513template<class T>
514Matrix<T> Matrix<T>::CropCols(unsigned int s, unsigned int e)
515{
516        assert(s<e);
517        assert(s<m_Cols);
518        assert(e<=m_Cols);
519       
520        Matrix r(m_Rows,e-s);
521        unsigned int c=0;
522        for(unsigned int i=s; i<e; i++)
523        {
524                r.SetColVector(c,GetColVector(i));
525                c++;
526        }
527       
528        return r;
529}
[89]530
531template<class T>
[90]532void Matrix<T>::Save(FILE* f)
533{
534        int version = 1;       
535        fwrite(&version,1,sizeof(version),f);
536        fwrite(&m_Rows,1,sizeof(m_Rows),f);
537        fwrite(&m_Cols,1,sizeof(m_Cols),f);
538        fwrite(m_Data,1,sizeof(T)*m_Rows*m_Cols,f);
539}
540
541template<class T>
542void Matrix<T>::Load(FILE* f)
543{
544        int version;   
545        fread(&version,sizeof(version),1,f);
546        fread(&m_Rows,sizeof(m_Rows),1,f);
547        fread(&m_Cols,sizeof(m_Cols),1,f);
548        m_Data=new T[m_Rows*m_Cols];
549        fread(m_Data,sizeof(T)*m_Rows*m_Cols,1,f);
550}
551
552template<class T>
[85]553void Matrix<T>::RunTests()
554{
[90]555        Vector<T>::RunTests();
556
[85]557        Matrix<T> m(10,10);
558        m.SetAll(0);
559        assert(m[0][0]==0);
560        m[5][2]=0.5;
561        assert(m[5][2]==0.5);
562        Matrix<T> om(m);
563        assert(om[5][2]==0.5);
564        Matrix<T> a(2,3);
565        a[0][0]=1; a[0][1]=2; a[0][2]=3;
566        a[1][0]=4; a[1][1]=5; a[1][2]=6;
567        Matrix<T> b(3,1);
568        b[0][0]=3;
569        b[1][0]=1;
570        b[2][0]=2;
571        Matrix<T> c=a*b;
572        assert(c[0][0]==11 && c[1][0]==29);
[90]573       
574        // test matrix * vector
575        Vector<T> d(3);
576        d[0]=3;
577        d[1]=1;
578        d[2]=2;
579        Vector<T> e=a*d;
580        assert(e[0]==11 && e[1]==29);
581       
582        Matrix<T> f=a.CropCols(1,3);
583        assert(f.GetRows()==2 && f.GetCols()==2 && f[0][0]==2);
584        Matrix<T> g=a.CropRows(0,1);
585        assert(g.GetRows()==1 && g.GetCols()==3 && g[0][0]==1);
[96]586       
587        // test matrix invert
588        Matrix<T> h(3,3);
589        h.Zero();
590        h[0][0]=1;
591        h[1][1]=1;
592        h[2][2]=1;
593        Matrix<T> i=h.Inverted();
594        i==h;
595       
596        // some transforms from fluxus
597        Matrix<T> j(4,4);       
598        j[0][0]=1.0;                   
599        j[0][1]=0.0 ;                           
600        j[0][2]=0.0;                           
601        j[0][3]=0.0;                           
602       
603        j[1][0]=0.0                     ;       
604        j[1][1]=0.7071067690849304 ;
605        j[1][2]=0.7071067690849304 ;
606        j[1][3]=0.0                             ;
607       
608        j[2][0]=0.0                             ;
609        j[2][1]=-0.7071067690849304 ;
610        j[2][2]=0.7071067690849304  ;
611        j[2][3]=0.0                             ;
612       
613        j[3][0]=1.0                             ;
614        j[3][1]=2.0                             ;
615        j[3][2]=3.0                             ;
616        j[3][3]=1.0                             ;
[90]617
[96]618        Matrix<T> k(4,4);
619        k[0][0]=1.0                              ;
620        k[0][1]=0.0                              ;
621        k[0][2]=0.0                              ;
622        k[0][3]=0.0                              ;
623
624        k[1][0]=0.0                              ;
625        k[1][1]=0.7071068286895752   ;
626        k[1][2]=-0.7071068286895752  ;
627        k[1][3]=0.0                              ;
628
629        k[2][0]=0.0                              ;
630        k[2][1]=0.7071068286895752   ;
631        k[2][2]=0.7071068286895752   ;
632        k[2][3]=0.0                              ;
633
634        k[3][0]=-0.9999999403953552  ;
635        k[3][1]=-3.535533905029297   ;
636        k[3][2]=-0.7071067690849304  ;
637        k[3][3]=0.9999999403953552       ;
638       
639        assert(j.Inverted()==k);
[85]640}
641
642#endif
Note: See TracBrowser for help on using the repository browser.