00001 #ifndef __VMML__VMMLIB_LAPACK_SVD__HPP__
00002 #define __VMML__VMMLIB_LAPACK_SVD__HPP__
00003
00004 #include <vmmlib/matrix.hpp>
00005 #include <vmmlib/vector.hpp>
00006 #include <vmmlib/exception.hpp>
00007
00008 #include <vmmlib/lapack_types.hpp>
00009 #include <vmmlib/lapack_includes.hpp>
00010
00011 #include <string>
00012
00035 namespace vmml
00036 {
00037
00038 namespace lapack
00039 {
00040
00041
00042
00043
00044
00045
00046
00047 template< typename float_t >
00048 struct svd_params
00049 {
00050 char jobu;
00051 char jobvt;
00052 lapack_int m;
00053 lapack_int n;
00054 float_t* a;
00055 lapack_int lda;
00056 float_t* s;
00057 float_t* u;
00058 lapack_int ldu;
00059 float_t* vt;
00060 lapack_int ldvt;
00061 float_t* work;
00062 lapack_int lwork;
00063 lapack_int info;
00064
00065 friend std::ostream& operator << ( std::ostream& os,
00066 const svd_params< float_t >& p )
00067 {
00068 os
00069 << "jobu " << p.jobu
00070 << " jobvt " << p.jobvt
00071 << " m " << p.m
00072 << " n " << p.n
00073 << " lda " << p.lda
00074 << " ldu " << p.ldu
00075 << " ldvt " << p.ldvt
00076 << " lwork " << p.lwork
00077 << " info " << p.info
00078 << std::endl;
00079 return os;
00080 }
00081
00082 };
00083
00084
00085 #if 0
00086 int dgesvd_(char *jobu, char *jobvt, integer *m, integer *n,
00087 doublereal *a, integer *lda, doublereal *s, doublereal *u, integer *
00088 ldu, doublereal *vt, integer *ldvt, doublereal *work, integer *lwork,
00089 integer *info);
00090 #endif
00091
00092
00093 template< typename float_t >
00094 inline void
00095 svd_call( svd_params< float_t >& p )
00096 {
00097 VMMLIB_ERROR( "not implemented for this type.", VMMLIB_HERE );
00098 }
00099
00100
00101 template<>
00102 inline void
00103 svd_call( svd_params< float >& p )
00104 {
00105
00106 sgesvd_(
00107 &p.jobu,
00108 &p.jobvt,
00109 &p.m,
00110 &p.n,
00111 p.a,
00112 &p.lda,
00113 p.s,
00114 p.u,
00115 &p.ldu,
00116 p.vt,
00117 &p.ldvt,
00118 p.work,
00119 &p.lwork,
00120 &p.info
00121 );
00122
00123 }
00124
00125
00126 template<>
00127 inline void
00128 svd_call( svd_params< double >& p )
00129 {
00130
00131 dgesvd_(
00132 &p.jobu,
00133 &p.jobvt,
00134 &p.m,
00135 &p.n,
00136 p.a,
00137 &p.lda,
00138 p.s,
00139 p.u,
00140 &p.ldu,
00141 p.vt,
00142 &p.ldvt,
00143 p.work,
00144 &p.lwork,
00145 &p.info
00146 );
00147 }
00148
00149 }
00150
00151
00152
00153 template< size_t M, size_t N, typename float_t >
00154 struct lapack_svd
00155 {
00156 lapack_svd();
00157 ~lapack_svd();
00158
00159
00160 bool compute(
00161 const matrix< M, N, float_t >& A,
00162 matrix< M, N, float_t >& U,
00163 vector< N, float_t >& sigma,
00164 matrix< N, N, float_t >& Vt
00165 );
00166
00167
00168 bool compute_and_overwrite_input(
00169 matrix< M, N, float_t >& A_U,
00170 vector< N, float_t >& sigma
00171 );
00172
00173
00174 bool compute(
00175 const matrix< M, N, float_t >& A,
00176 vector< N, float_t >& sigma
00177 );
00178
00179 inline bool test_success( lapack::lapack_int info );
00180
00181 lapack::svd_params< float_t > p;
00182
00183 const lapack::svd_params< float_t >& get_params(){ return p; };
00184
00185 };
00186
00187
00188 template< size_t M, size_t N, typename float_t >
00189 lapack_svd< M, N, float_t >::lapack_svd()
00190 {
00191 p.jobu = 'N';
00192 p.jobvt = 'N';
00193 p.m = M;
00194 p.n = N;
00195 p.a = 0;
00196 p.lda = M;
00197 p.s = 0;
00198 p.u = 0;
00199 p.ldu = M;
00200 p.vt = 0;
00201 p.ldvt = 1;
00202 p.work = new float_t;
00203 p.lwork = -1;
00204
00205
00206 lapack::svd_call( p );
00207
00208 p.lwork = static_cast< lapack::lapack_int >( p.work[0] );
00209 delete p.work;
00210
00211 p.work = new float_t[ p.lwork ];
00212
00213 }
00214
00215
00216
00217 template< size_t M, size_t N, typename float_t >
00218 lapack_svd< M, N, float_t >::~lapack_svd()
00219 {
00220 delete[] p.work;
00221 }
00222
00223
00224
00225 template< size_t M, size_t N, typename float_t >
00226 bool
00227 lapack_svd< M, N, float_t >::compute(
00228 const matrix< M, N, float_t >& A,
00229 matrix< M, N, float_t >& U,
00230 vector< N, float_t >& S,
00231 matrix< N, N, float_t >& Vt
00232 )
00233 {
00234
00235 matrix< M, N, float_t > AA( A );
00236
00237 p.jobu = 'A';
00238 p.jobvt = 'A';
00239 p.a = AA.array;
00240 p.u = U.array;
00241 p.s = S.array;
00242 p.vt = Vt.array;
00243 p.ldvt = N;
00244
00245 lapack::svd_call< float_t >( p );
00246
00247 return p.info == 0;
00248 }
00249
00250
00251
00252 template< size_t M, size_t N, typename float_t >
00253 bool
00254 lapack_svd< M, N, float_t >::compute_and_overwrite_input(
00255 matrix< M, N, float_t >& A_U,
00256 vector< N, float_t >& S
00257 )
00258 {
00259 p.jobu = 'O';
00260 p.jobvt = 'N';
00261 p.a = A_U.array;
00262 p.s = S.array;
00263 p.ldvt = N;
00264
00265 lapack::svd_call< float_t >( p );
00266
00267 return p.info == 0;
00268 }
00269
00270
00271
00272 template< size_t M, size_t N, typename float_t >
00273 bool
00274 lapack_svd< M, N, float_t >::compute(
00275 const matrix< M, N, float_t >& A,
00276 vector< N, float_t >& S
00277 )
00278 {
00279
00280 matrix< M, N, float_t > AA( A );
00281
00282 p.jobu = 'N';
00283 p.jobvt = 'N';
00284 p.a = AA.array;
00285 p.u = 0;
00286 p.s = S.array;
00287 p.vt = 0;
00288
00289 lapack::svd_call< float_t >( p );
00290
00291 return p.info == 0;
00292 }
00293
00294
00295 }
00296
00297 #endif
00298