From a7a44dc9cc0f4a74adbc22ff86ad081ecf2383ba Mon Sep 17 00:00:00 2001 From: Justin Berger Date: Thu, 21 Jun 2018 16:22:24 +0000 Subject: Expanded minimal opencv functionality to properly use transpose flags --- redist/minimal_opencv.c | 43 ++++++++++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 15 deletions(-) (limited to 'redist/minimal_opencv.c') diff --git a/redist/minimal_opencv.c b/redist/minimal_opencv.c index d569d96..c9cacf3 100644 --- a/redist/minimal_opencv.c +++ b/redist/minimal_opencv.c @@ -30,27 +30,39 @@ SURVIVE_LOCAL_ONLY void cvCopyTo(const CvMat *srcarr, CvMat *dstarr) { SURVIVE_LOCAL_ONLY void cvGEMM(const CvMat *src1, const CvMat *src2, double alpha, const CvMat *src3, double beta, CvMat *dst, int tABC) { - lapack_int rows1 = src1->rows; - lapack_int cols1 = src1->cols; - lapack_int rows2 = src2->rows; - lapack_int cols2 = src2->cols; + int rows1 = (tABC & GEMM_1_T) ? src1->cols : src1->rows; + int cols1 = (tABC & GEMM_1_T) ? src1->rows : src1->cols; - lapack_int lda = cols1; - lapack_int ldb = cols2; + int rows2 = (tABC & GEMM_2_T) ? src2->cols : src2->rows; + int cols2 = (tABC & GEMM_2_T) ? src2->rows : src2->cols; - assert(src1->cols == src2->rows); - assert(src1->rows == dst->rows); - assert(src2->cols == dst->cols); + assert(cols1 == rows2); + assert(rows1 == dst->rows); + assert(cols2 == dst->cols); + lapack_int lda = src1->cols; + lapack_int ldb = src2->cols; + if (src3) cvCopyTo(src3, dst); else beta = 0; - cblas_dgemm(CblasRowMajor, (tABC & GEMM_1_T) ? CblasTrans : CblasNoTrans, - (tABC & GEMM_2_T) ? CblasTrans : CblasNoTrans, src1->rows, dst->cols, src1->cols, alpha, src1->data.db, - lda, src2->data.db, ldb, beta, dst->data.db, dst->cols); + cblas_dgemm(CblasRowMajor, + (tABC & GEMM_1_T) ? CblasTrans : CblasNoTrans, + (tABC & GEMM_2_T) ? CblasTrans : CblasNoTrans, + dst->rows, + dst->cols, + cols1, + alpha, + src1->data.db, + lda, + src2->data.db, + ldb, + beta, + dst->data.db, + dst->cols); } SURVIVE_LOCAL_ONLY void cvMulTransposed(const CvMat *src, CvMat *dst, int order, const CvMat *delta, double scale) { @@ -296,10 +308,11 @@ SURVIVE_LOCAL_ONLY void cvTranspose(const CvMat *M, CvMat *dst) { if (inPlace) { tmp = cvCloneMat(dst); src = tmp->data.db; + } else { + assert(M->rows == dst->cols); + assert(M->cols == dst->rows); } - assert(M->rows == dst->cols); - assert(M->cols == dst->rows); for (unsigned i = 0; i < M->rows; i++) { for (unsigned j = 0; j < M->cols; j++) { dst->data.db[j * M->rows + i] = src[i * M->cols + j]; @@ -329,7 +342,7 @@ SURVIVE_LOCAL_ONLY void cvSVD(CvMat *aarr, CvMat *warr, CvMat *uarr, CvMat *varr lapack_int arows = aarr->rows, acols = aarr->cols; lapack_int ulda = uarr ? uarr->cols : 1; lapack_int plda = varr ? varr->cols : acols; - + double *superb = malloc(sizeof(double) * MIN(arows, acols)); inf = LAPACKE_dgesvd(LAPACK_ROW_MAJOR, jobu, jobvt, arows, acols, aarr->data.db, acols, warr ? warr->data.db : 0, uarr ? uarr->data.db : 0, ulda, varr ? varr->data.db : 0, plda, superb); -- cgit v1.2.3