00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 #ifndef __UT_SparseMatrix_H__
00026 #define __UT_SparseMatrix_H__
00027
00028 #include "UT_API.h"
00029
00030 #include "UT_ThreadedAlgorithm.h"
00031 #include "UT_Vector.h"
00032
00033 #include <SYS/SYS_Types.h>
00034
00035 #include <iterator>
00036
00037 class UT_IStream;
00038
00039 template <typename T> class UT_SparseMatrixRowT;
00040
00041 template <typename T, bool IsPaged>
00042 class UT_API UT_SparseMatrixT
00043 {
00044 class ut_MatrixCell
00045 {
00046 public:
00047 int myRow;
00048 int myCol;
00049 T myValue;
00050
00051 inline bool operator<(const ut_MatrixCell &o) const
00052 {
00053 if (myRow < o.myRow)
00054 return true;
00055 if (myRow == o.myRow && myCol < o.myCol)
00056 return true;
00057 return false;
00058 }
00059 };
00060
00061
00062
00063 static const int CELL_PAGESIZE = 1024;
00064 static const int CELL_PAGEMASK = 1023;
00065 static const int CELL_PAGEBITS = 10;
00066
00067 class ut_CellIterator : public std::iterator<random_access_iterator_tag, ut_MatrixCell>
00068 {
00069 public:
00070 ut_CellIterator(const UT_SparseMatrixT<T, IsPaged> *matrix, int idx)
00071 { myMatrix = matrix; myPos = idx; repage(); }
00072
00073 ut_CellIterator operator+(ptrdiff_t n) { return ut_CellIterator(myMatrix, myPos+n); }
00074 ut_CellIterator operator-(ptrdiff_t n) { return ut_CellIterator(myMatrix, myPos-n); }
00075 ut_CellIterator &operator+=(ptrdiff_t n) { myPos += n; repage(); return *this; }
00076 ut_CellIterator &operator-=(ptrdiff_t n) { myPos -= n; repage(); return *this; }
00077 ut_CellIterator &operator++() { myPos++; myOffset++; myData++; if (myOffset >= CELL_PAGESIZE) repage(); return *this; }
00078 ut_CellIterator &operator--() { myPos--; myOffset--; myData--; if (myOffset < 0) repage(); return *this; }
00079 ut_CellIterator operator++(int) { ut_CellIterator result = *this; myPos++; myOffset++; myData++; if (myOffset >= CELL_PAGESIZE) repage(); return result; }
00080 ut_CellIterator operator--(int) { ut_CellIterator result = *this; myPos--; myOffset--; myData--; if (myOffset < 0) repage(); return result; }
00081 int operator-(ut_CellIterator b) { return myPos - b.myPos; }
00082 ut_MatrixCell &operator[](ptrdiff_t idx) { return myMatrix->getCell(idx + myPos); }
00083
00084 bool operator<(ut_CellIterator b) { return myPos < b.myPos; }
00085 bool operator==(ut_CellIterator b) { return myPos == b.myPos; }
00086 bool operator!=(ut_CellIterator b) { return myPos != b.myPos; }
00087
00088 ut_MatrixCell &operator*() const { return *myData; }
00089 ut_MatrixCell *operator->() const { return myData; }
00090
00091 protected:
00092
00093 void repage()
00094 { myPage = myPos >> CELL_PAGEBITS; myOffset = myPos & CELL_PAGEMASK;
00095 myData = 0;
00096 if (myPage < myMatrix->myCellPages.entries())
00097 myData = &myMatrix->myCellPages(myPage)[myOffset];
00098 }
00099
00100 const UT_SparseMatrixT<T, IsPaged> *myMatrix;
00101 ut_MatrixCell *myData;
00102 ptrdiff_t myPos;
00103 int myPage, myOffset;
00104 };
00105
00106
00107
00108 static const int CELLBITS = 2;
00109 static const int CELLSIZE = 1 << CELLBITS;
00110 static const int CELLMASK = CELLSIZE-1;
00111
00112 class ut_4MatrixCell
00113 {
00114 public:
00115 T myValue[CELLSIZE];
00116 int myRow;
00117 int myCol;
00118 } SYS_ALIGN16;
00119
00120 public:
00121 UT_SparseMatrixT();
00122
00123
00124 UT_SparseMatrixT(int rows, int cols);
00125 UT_SparseMatrixT(const UT_SparseMatrixT<T, IsPaged> &m);
00126 ~UT_SparseMatrixT();
00127
00128 int getNumRows() const { return myNumRows; }
00129 int getNumCols() const { return myNumCols; }
00130
00131
00132 int64 getMemoryUsage() const;
00133 int64 getIdealMemoryUsage() const;
00134
00135
00136
00137 void reserve(int numcells);
00138
00139
00140 void shrinkToFit();
00141
00142
00143 void init(int rows, int cols);
00144
00145
00146 void zero();
00147
00148
00149
00150 bool shouldMultiThread() const
00151 {
00152 #ifdef CELLBE
00153 return false;
00154 #else
00155 return getNumRows() > 5000;
00156 #endif
00157 }
00158
00159
00160 bool addToElement(int row, int col, T value);
00161
00162
00163
00164
00165
00166 int findCellFromRow(int row) const;
00167
00168
00169 THREADED_METHOD2_CONST(UT_SparseMatrixT, shouldMultiThread(), multVec,
00170 const UT_VectorT<T> &, v,
00171 UT_VectorT<T> &, result)
00172 void multVecPartial(const UT_VectorT<T> &v, UT_VectorT<T> &result,
00173 const UT_JobInfo &info) const;
00174
00175
00176 THREADED_METHOD2_CONST(UT_SparseMatrixT, shouldMultiThread(),
00177 subtractMultVec,
00178 const UT_VectorT<T>&, v,
00179 UT_VectorT<T>&, result)
00180 void subtractMultVecPartial(const UT_VectorT<T>& v, UT_VectorT<T>& result,
00181 const UT_JobInfo& info) const;
00182
00183
00184 void transposeMultVec(const UT_VectorT<T> &v, UT_VectorT<T> &result) const;
00185
00186
00187
00188
00189 void allColNorm2(UT_VectorT<T> &result) const;
00190
00191
00192
00193
00194
00195
00196
00197 void extractSubMatrix(UT_SparseMatrixT<T, IsPaged> &out,
00198 int rowstart, int rowend,
00199 int colstart, int colend) const;
00200
00201
00202 void extractSubMatrixUncompiled(UT_SparseMatrixT<T, IsPaged> &out,
00203 int rowstart, int rowend,
00204 int colstart, int colend) const;
00205
00206
00207
00208
00209
00210 void extractDiagonal(UT_VectorT<T> &out, int idx = 0) const;
00211
00212
00213
00214
00215 void extractNondiagonal(UT_SparseMatrixT<T, IsPaged> &out) const;
00216
00217
00218
00219
00220
00221
00222
00223
00224
00225
00226 int incompleteCholeskyFactorization(T tol=1e-5);
00227
00228
00229
00230
00231
00232 int modifiedIncompleteCholesky(T tau = 0.97,
00233 T mindiagratio = 0.25,
00234 T tol=1e-5);
00235
00236
00237
00238
00239
00240
00241
00242
00243 int solveLowerTriangular(UT_VectorT<T> &x, const UT_VectorT<T> &b,
00244 T tol=1e-5) const;
00245 int solveUpperTriangular(UT_VectorT<T> &x, const UT_VectorT<T> &b,
00246 T tol=1e-5) const;
00247
00248
00249 int solveLowerTriangularTransposeNegate(UT_VectorT<T> &x,
00250 const UT_VectorT<T> &b,
00251 T tol=1e-5) const;
00252
00253
00254
00255 bool solveConjugateGradient(UT_VectorT<T> &x, const UT_VectorT<T> &b,
00256 bool (*callback_func)(void *) = 0,
00257 void *callback_data = 0, T tol=1e-5,
00258 int max_iters = -1) const;
00259
00260
00261 THREADED_METHOD(UT_SparseMatrixT, shouldMultiThread(), transpose)
00262 void transposePartial(const UT_JobInfo &info);
00263
00264
00265
00266
00267 void transposeCompiled(const UT_SparseMatrixT<T, IsPaged> &src);
00268
00269
00270 THREADED_METHOD(UT_SparseMatrixT, shouldMultiThread(), negate)
00271 void negatePartial(const UT_JobInfo &info);
00272
00273 UT_SparseMatrixT<T, IsPaged> &operator=(const UT_SparseMatrixT<T, IsPaged> &m);
00274 UT_SparseMatrixT<T, IsPaged> &operator*=(T scalar);
00275 UT_SparseMatrixT<T, IsPaged> &operator+=(const UT_SparseMatrixT<T, IsPaged> &m);
00276
00277
00278
00279 void printFull(ostream &os) const;
00280
00281
00282 void printSparse(ostream &os) const;
00283
00284
00285 void save(ostream &os) const;
00286 void load(UT_IStream &is);
00287
00288
00289
00290
00291
00292
00293
00294
00295
00296 void compile() const;
00297 bool isCompiled() const { return myCompiledFlag; }
00298 bool isStillSorted() const { return myStillSortedFlag; }
00299
00300 private:
00301
00302 void forceCompile() const;
00303
00304
00305 inline ut_MatrixCell &getCell(int idx) const
00306 {
00307 if (IsPaged)
00308 return myCellPages(idx >> CELL_PAGEBITS)[idx & CELL_PAGEMASK];
00309 else
00310 return myCells[idx];
00311 }
00312
00313
00314
00315 void compile4() const;
00316
00317 int myNumRows;
00318 int myNumCols;
00319 mutable UT_PtrArray<ut_MatrixCell *> myCellPages;
00320 mutable ut_MatrixCell *myCells;
00321 mutable int myCount;
00322 mutable int *my4RowOffsets;
00323 mutable ut_4MatrixCell *my4Cells;
00324 mutable int my4Count;
00325 int myMaxSize;
00326 mutable bool myCompiledFlag;
00327 mutable bool myStillSortedFlag;
00328 mutable bool my4CompiledFlag;
00329
00330 friend class UT_SparseMatrixRowT<T>;
00331 };
00332
00333 typedef UT_SparseMatrixT<fpreal32, false> UT_SparseMatrixF;
00334 typedef UT_SparseMatrixT<fpreal64, false> UT_SparseMatrixD;
00335 typedef UT_SparseMatrixT<fpreal64, false> UT_SparseMatrix;
00336
00337
00338
00339
00340
00341
00342
00343
00344 template <typename T>
00345 class UT_API UT_SparseMatrixRowT
00346 {
00347 class ut_MatrixCell
00348 {
00349 public:
00350 int myCol;
00351 T myValue;
00352 };
00353
00354 public:
00355 UT_SparseMatrixRowT();
00356 ~UT_SparseMatrixRowT();
00357
00358
00359
00360
00361 void buildFrom(UT_SparseMatrixT<T, false> &m,
00362 bool invertdiag = false,
00363 T tol=1e-5f);
00364
00365 int getNumRows() const { return myNumRows; }
00366 int getNumCols() const { return myNumCols; }
00367
00368
00369 int64 getMemoryUsage() const;
00370
00371
00372 bool shouldMultiThread() const
00373 {
00374 return getNumRows() > 5000;
00375 }
00376
00377
00378
00379
00380
00381 int findCellFromRow(int row) const
00382 { return myRowOffsets(row); }
00383
00384
00385 THREADED_METHOD2_CONST(UT_SparseMatrixRowT, shouldMultiThread(), multVec,
00386 const UT_VectorT<T> &, v,
00387 UT_VectorT<T> &, result)
00388 void multVecPartial(const UT_VectorT<T> &v, UT_VectorT<T> &result,
00389 const UT_JobInfo &info) const;
00390
00391
00392 THREADED_METHOD3_CONST(UT_SparseMatrixRowT, shouldMultiThread(),
00393 multVecAndDot,
00394 const UT_VectorT<T> &, v,
00395 UT_VectorT<T> &, result,
00396 fpreal64 *, dotpq)
00397 void multVecAndDotPartial(const UT_VectorT<T> &v, UT_VectorT<T> &result,
00398 fpreal64 *dotpq,
00399 const UT_JobInfo &info) const;
00400
00401
00402
00403
00404
00405
00406
00407
00408 int solveUpperTriangular(UT_VectorT<T> &x, const UT_VectorT<T> &b,
00409 T tol=1e-5) const;
00410
00411
00412 int solveLowerTriangularTransposeNegate(UT_VectorT<T> &x,
00413 const UT_VectorT<T> &b,
00414 T tol=1e-5) const;
00415
00416
00417
00418
00419
00420
00421
00422 float solveConjugateGradient(UT_VectorT<T> &x, const UT_VectorT<T> &b,
00423 UT_SparseMatrixRowT<T> *GT, T tol=1e-5,
00424 int max_iters = -1) const;
00425 private:
00426
00427 inline ut_MatrixCell &getCell(int idx) const
00428 {
00429 return myCells[idx];
00430 }
00431
00432 int myNumRows;
00433 int myNumCols;
00434 ut_MatrixCell *myCells;
00435 int myCount;
00436 UT_IntArray myRowOffsets;
00437 UT_VectorT<T> myDiagonal;
00438 };
00439
00440 typedef UT_SparseMatrixRowT<fpreal32> UT_SparseMatrixRowF;
00441 typedef UT_SparseMatrixRowT<fpreal64> UT_SparseMatrixRowD;
00442
00443 #endif