00001 #ifndef __VMML__VMMLIB_LAPACK_LINEAR_LEAST_SQUARES__HPP__
00002 #define __VMML__VMMLIB_LAPACK_LINEAR_LEAST_SQUARES__HPP__
00003
00004 #include <vmmlib/matrix.hpp>
00005 #include <vmmlib/vector.hpp>
00006 #include <vmmlib/exception.hpp>
00007
00008 #include <vmmlib/lapack_types.hpp>
00009 #include <vmmlib/lapack_includes.hpp>
00010
00011 #include <string>
00012
00023 namespace vmml
00024 {
00025
00026
00027
00028
00029
00030
00031 namespace lapack
00032 {
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042 template< typename float_t >
00043 struct llsq_params_xgels
00044 {
00045 char trans;
00046 lapack_int m;
00047 lapack_int n;
00048 lapack_int nrhs;
00049 float_t* a;
00050 lapack_int lda;
00051 float_t* b;
00052 lapack_int ldb;
00053 float_t* work;
00054 lapack_int lwork;
00055 lapack_int info;
00056
00057 friend std::ostream& operator << ( std::ostream& os,
00058 const llsq_params_xgels< float_t >& p )
00059 {
00060 os
00061 << " m " << p.m
00062 << " n " << p.n
00063 << " nrhs " << p.nrhs
00064 << " lda " << p.lda
00065 << " ldb " << p.ldb
00066 << " lwork " << p.lwork
00067 << " info " << p.info
00068 << std::endl;
00069 return os;
00070 }
00071
00072 };
00073
00074
00075
00076 #if 0
00077 void dgels_(const char *trans, const int *M, const int *N, const int *nrhs,
00078 double *A, const int *lda, double *b, const int *ldb, double *work,
00079 const int * lwork, int *info);
00080 #endif
00081
00082 template< typename float_t >
00083 inline void
00084 llsq_call_xgels( llsq_params_xgels< float_t >& p )
00085 {
00086 VMMLIB_ERROR( "not implemented for this type.", VMMLIB_HERE );
00087 }
00088
00089 template<>
00090 inline void
00091 llsq_call_xgels( llsq_params_xgels< float >& p )
00092 {
00093 sgels_(
00094 &p.trans,
00095 &p.m,
00096 &p.n,
00097 &p.nrhs,
00098 p.a,
00099 &p.lda,
00100 p.b,
00101 &p.ldb,
00102 p.work,
00103 &p.lwork,
00104 &p.info
00105 );
00106 }
00107
00108 template<>
00109 inline void
00110 llsq_call_xgels( llsq_params_xgels< double >& p )
00111 {
00112 dgels_(
00113 &p.trans,
00114 &p.m,
00115 &p.n,
00116 &p.nrhs,
00117 p.a,
00118 &p.lda,
00119 p.b,
00120 &p.ldb,
00121 p.work,
00122 &p.lwork,
00123 &p.info
00124 );
00125
00126 }
00127
00128
00129 template< size_t M, size_t N, typename float_t >
00130 struct linear_least_squares_xgels
00131 {
00132 void compute(
00133 const matrix< M, N, float_t >& A,
00134 const vector< M, float_t >& B,
00135 vector< N, float_t >& x );
00136
00137 linear_least_squares_xgels();
00138 ~linear_least_squares_xgels();
00139
00140 const lapack::llsq_params_xgels< float_t >& get_params(){ return p; };
00141
00142 matrix< M, N, float_t >& get_factorized_A() { return _A; }
00143
00144 protected:
00145 matrix< M, N, float_t > _A;
00146 vector< M, float_t > _b;
00147
00148 llsq_params_xgels< float_t > p;
00149
00150 };
00151
00152
00153
00154 template< size_t M, size_t N, typename float_t >
00155 void
00156 linear_least_squares_xgels< M, N, float_t >::compute(
00157 const matrix< M, N, float_t >& A,
00158 const vector< M, float_t >& B,
00159 vector< N, float_t >& x )
00160 {
00161 _A = A;
00162 _b = B;
00163
00164 llsq_call_xgels( p );
00165
00166
00167 if ( p.info == 0 )
00168 {
00169 for( size_t index = 0; index < N; ++index )
00170 {
00171 x( index ) = _b( index );
00172 }
00173
00174 return;
00175 }
00176 if ( p.info < 0 )
00177 {
00178 VMMLIB_ERROR( "xGELS - invalid argument.", VMMLIB_HERE );
00179 }
00180 else
00181 {
00182 VMMLIB_ERROR( "least squares solution could not be computed.",
00183 VMMLIB_HERE );
00184 }
00185
00186 }
00187
00188
00189
00190 template< size_t M, size_t N, typename float_t >
00191 linear_least_squares_xgels< M, N, float_t >::
00192 linear_least_squares_xgels()
00193 {
00194 p.trans = 'N';
00195 p.m = M;
00196 p.n = N;
00197 p.nrhs = 1;
00198 p.a = _A.array;
00199 p.lda = M;
00200 p.b = _b.array;
00201 p.ldb = M;
00202 p.work = new float_t();
00203 p.lwork = -1;
00204
00205
00206 llsq_call_xgels( p );
00207
00208 p.lwork = static_cast< lapack_int > ( p.work[0] );
00209 delete p.work;
00210
00211 p.work = new float_t[ p.lwork ];
00212 }
00213
00214
00215
00216 template< size_t M, size_t N, typename float_t >
00217 linear_least_squares_xgels< M, N, float_t >::
00218 ~linear_least_squares_xgels()
00219 {
00220 delete[] p.work;
00221 }
00222
00223
00224
00225
00226
00227
00228
00229
00230
00231 template< typename float_t >
00232 struct llsq_params_xgesv
00233 {
00234 lapack_int n;
00235 lapack_int nrhs;
00236 float_t* a;
00237 lapack_int lda;
00238 lapack_int* ipiv;
00239 float_t* b;
00240 lapack_int ldb;
00241 lapack_int info;
00242
00243 friend std::ostream& operator << ( std::ostream& os,
00244 const llsq_params_xgesv< float_t >& p )
00245 {
00246 os
00247 << "n " << p.n
00248 << " nrhs " << p.nrhs
00249 << " lda " << p.lda
00250 << " ldb " << p.ldvt
00251 << " info " << p.info
00252 << std::endl;
00253 return os;
00254 }
00255
00256 };
00257
00258
00259 #if 0
00260 int dgesv_(integer *n, integer *nrhs, doublereal *a, integer
00261 *lda, integer *ipiv, doublereal *b, integer *ldb, integer *info);
00262 #endif
00263
00264
00265 template< typename float_t >
00266 inline void
00267 llsq_call_xgesv( llsq_params_xgesv< float_t >& p )
00268 {
00269 VMMLIB_ERROR( "not implemented for this type.", VMMLIB_HERE );
00270 }
00271
00272
00273 template<>
00274 inline void
00275 llsq_call_xgesv( llsq_params_xgesv< float >& p )
00276 {
00277 sgesv_(
00278 &p.n,
00279 &p.nrhs,
00280 p.a,
00281 &p.lda,
00282 p.ipiv,
00283 p.b,
00284 &p.ldb,
00285 &p.info
00286 );
00287
00288 }
00289
00290
00291 template<>
00292 inline void
00293 llsq_call_xgesv( llsq_params_xgesv< double >& p )
00294 {
00295 dgesv_(
00296 &p.n,
00297 &p.nrhs,
00298 p.a,
00299 &p.lda,
00300 p.ipiv,
00301 p.b,
00302 &p.ldb,
00303 &p.info
00304 );
00305 }
00306
00307
00308 template< size_t M, size_t N, typename float_t >
00309 struct linear_least_squares_xgesv
00310 {
00311
00312 void compute(
00313 matrix< N, N, float_t >& A,
00314 matrix< N, M, float_t >& b
00315 );
00316
00317 linear_least_squares_xgesv();
00318 ~linear_least_squares_xgesv();
00319
00320 const lapack::llsq_params_xgesv< float_t >& get_params() { return p; }
00321
00322 lapack::llsq_params_xgesv< float_t > p;
00323
00324 };
00325
00326
00327 template< size_t M, size_t N, typename float_t >
00328 void
00329 linear_least_squares_xgesv< M, N, float_t >::
00330 compute(
00331 matrix< N, N, float_t >& A,
00332 matrix< N, M, float_t >& b
00333 )
00334 {
00335 p.a = A.array;
00336 p.b = b.array;
00337
00338 lapack::llsq_call_xgesv( p );
00339
00340 if ( p.info != 0 )
00341 {
00342 if ( p.info < 0 )
00343 VMMLIB_ERROR( "invalid value in input matrix", VMMLIB_HERE );
00344 else
00345 VMMLIB_ERROR( "factor U is exactly singular, solution could not be computed.", VMMLIB_HERE );
00346 }
00347 }
00348
00349
00350
00351 template< size_t M, size_t N, typename float_t >
00352 linear_least_squares_xgesv< M, N, float_t >::
00353 linear_least_squares_xgesv()
00354 {
00355 p.n = N;
00356 p.nrhs = M;
00357 p.lda = N;
00358 p.ldb = N;
00359 p.ipiv = new lapack_int[ N ];
00360
00361 }
00362
00363
00364
00365 template< size_t M, size_t N, typename float_t >
00366 linear_least_squares_xgesv< M, N, float_t >::
00367 ~linear_least_squares_xgesv()
00368 {
00369 delete[] p.ipiv;
00370 }
00371
00372
00373 }
00374
00375 }
00376
00377 #endif
00378