1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
|
/// Primary template of the GEMM dispatcher, keyed on the element type T.
/// It declares no members at all, so it offers no operator(); the call
/// operator is supplied by the explicit specializations for individual
/// element types.
template<typename T>
struct atlas_gemm_function {};
template<>
struct atlas_gemm_function<float>
{
template<class G1,class G2,class G3> inline void operator()(CBLAS_TRANSPOSE TransA,CBLAS_TRANSPOSE TransB, const Matrix<G1> &A, const Matrix<G2> &B, Matrix<G3> &C)
{
int m=C.nrows(), n=C.ncols(), k=A.ncols();
cblas_sgemm(AtlasRowMajor,TransA,TransB,m,n,k,1,&A(0,0),m,&B(0,0),n,0,&C(0,0),k);
}
};
template<>
struct atlas_gemm_function<double>
{
template<class G1,class G2,class G3> inline void operator()(CBLAS_TRANSPOSE TransA,CBLAS_TRANSPOSE TransB, const Matrix<G1> &A, const Matrix<G2> &B, Matrix<G3> &C)
{
int m=C.nrows(), n=C.ncols(), k=A.ncols();
cblas_dgemm(AtlasRowMajor,TransA,TransB,m,n,k,1,&A(0,0),m,&B(0,0),n,0,&C(0,0),k);
}
};
/// Single-precision complex GEMM dispatcher: forwards to cblas_cgemm.
template<>
struct atlas_gemm_function<complex<float> >
{
    typedef float real_type;                 ///< scalar type underlying each complex element
    typedef complex<real_type> value_type;   ///< element type handled by this specialization

    /// Calls cblas_cgemm with alpha = (1,0) and beta = (0,0), writing the
    /// product of (optionally transposed) A and B into C.
    ///
    /// @param TransA  transposition flag for A, forwarded to BLAS unchanged
    /// @param TransB  transposition flag for B, forwarded to BLAS unchanged
    /// @param A       left operand (read through &A(0,0))
    /// @param B       right operand (read through &B(0,0))
    /// @param C       result matrix (written through &C(0,0))
    ///
    /// The data is handed to BLAS as AtlasRowMajor starting at element (0,0);
    /// this assumes Matrix stores its elements densely, one row after
    /// another, with a row stride of ncols() -- TODO confirm against Matrix.
    template<class G1,class G2,class G3> inline void operator()(CBLAS_TRANSPOSE TransA,CBLAS_TRANSPOSE TransB, const Matrix<G1> &A, const Matrix<G2> &B, Matrix<G3> &C)
    {
        // Problem size: C is m x n, k is the shared inner dimension.
        // NOTE(review): A.ncols() is the inner dimension only while A is not
        // transposed; the callers visible in this file always pass
        // AtlasNoTrans for A. Revisit if a transposed-A caller is added.
        int m=C.nrows(), n=C.ncols(), k=A.ncols();

        // In row-major order the leading dimension of an operand is the
        // length of one stored row (its column count, counted in complex
        // elements), regardless of the transposition flags. The previous
        // values (m, n, k) were only correct when every matrix was square.
        int lda=A.ncols(), ldb=B.ncols(), ldc=C.ncols();

        // The complex BLAS entry points take the scalars by address.
        value_type alpha(1,0), beta(0,0);

        // Complex data is passed as a pointer to its first real component;
        // the input operands stay const.
        cblas_cgemm(AtlasRowMajor,TransA,TransB,m,n,k,
                    reinterpret_cast<const real_type*>(&alpha),
                    reinterpret_cast<const real_type*>(&A(0,0)),lda,
                    reinterpret_cast<const real_type*>(&B(0,0)),ldb,
                    reinterpret_cast<const real_type*>(&beta),
                    reinterpret_cast<real_type*>(&C(0,0)),ldc);
    }
};
/// Double-precision complex GEMM dispatcher: forwards to cblas_zgemm.
template<>
struct atlas_gemm_function<complex<double> >
{
    typedef double real_type;                ///< scalar type underlying each complex element
    typedef complex<real_type> value_type;   ///< element type handled by this specialization

    /// Calls cblas_zgemm with alpha = (1,0) and beta = (0,0), writing the
    /// product of (optionally transposed) A and B into C.
    ///
    /// @param TransA  transposition flag for A, forwarded to BLAS unchanged
    /// @param TransB  transposition flag for B, forwarded to BLAS unchanged
    /// @param A       left operand (read through &A(0,0))
    /// @param B       right operand (read through &B(0,0))
    /// @param C       result matrix (written through &C(0,0))
    ///
    /// The data is handed to BLAS as AtlasRowMajor starting at element (0,0);
    /// this assumes Matrix stores its elements densely, one row after
    /// another, with a row stride of ncols() -- TODO confirm against Matrix.
    template<class G1,class G2,class G3> inline void operator()(CBLAS_TRANSPOSE TransA,CBLAS_TRANSPOSE TransB, const Matrix<G1> &A, const Matrix<G2> &B, Matrix<G3> &C)
    {
        // Problem size: C is m x n, k is the shared inner dimension.
        // NOTE(review): A.ncols() is the inner dimension only while A is not
        // transposed; the callers visible in this file always pass
        // AtlasNoTrans for A. Revisit if a transposed-A caller is added.
        int m=C.nrows(), n=C.ncols(), k=A.ncols();

        // In row-major order the leading dimension of an operand is the
        // length of one stored row (its column count, counted in complex
        // elements), regardless of the transposition flags. The previous
        // values (m, n, k) were only correct when every matrix was square.
        int lda=A.ncols(), ldb=B.ncols(), ldc=C.ncols();

        // The complex BLAS entry points take the scalars by address.
        value_type alpha(1,0), beta(0,0);

        // Complex data is passed as a pointer to its first real component;
        // the input operands stay const.
        cblas_zgemm(AtlasRowMajor,TransA,TransB,m,n,k,
                    reinterpret_cast<const real_type*>(&alpha),
                    reinterpret_cast<const real_type*>(&A(0,0)),lda,
                    reinterpret_cast<const real_type*>(&B(0,0)),ldb,
                    reinterpret_cast<const real_type*>(&beta),
                    reinterpret_cast<real_type*>(&C(0,0)),ldc);
    }
};
template<class G1,class G2,class G3>
void atlas_gemm(const Matrix<G1> &A, const Matrix<G2> &B, Matrix<G3> &C)
{
atlas_gemm_function<typename Matrix<G3>::value_type>()(AtlasNoTrans,AtlasNoTrans,A,B,C);
}
template<class G1,class G2,class G3>
void atlas_tgemm(const Matrix<G1> &A, const Matrix<G2> &B, Matrix<G3> &C)
{
atlas_gemm_function<typename Matrix<G3>::value_type>()(AtlasNoTrans,AtlasTrans,A,B,C);
} |