From 0b9e66ad2ff686a4dcf8a6838f33edb203a1bff5 Mon Sep 17 00:00:00 2001 From: Justin Berger Date: Sat, 17 Mar 2018 09:32:22 -0600 Subject: Fixed gemm --- redist/dclapack.h | 33 +++++++++++++++++---------------- redist/dclhelpers.c | 9 +++++---- redist/test_dcl.c | 23 +++++++++++++++-------- 3 files changed, 37 insertions(+), 28 deletions(-) (limited to 'redist') diff --git a/redist/dclapack.h b/redist/dclapack.h index af8869c..af5035b 100644 --- a/redist/dclapack.h +++ b/redist/dclapack.h @@ -224,22 +224,23 @@ PRINT(Ainv,n,n); \ /* * Matrix Multiply R = alpha * A * B + beta * C * R (n by p) - * A (n by m) - * B (m by p) - * C (n by p) + * A (m by n) + * B (n by p) + * C (m by p) */ -#define GMULADD(R,A,B,C,alpha,beta,n,m,p) { \ - int i,j,k; \ - float sum; \ - for (i=0; i #include "dclapack.h" #include +#include +#include void dclPrint( const DCL_FLOAT * PMATRIX, int PMATRIXc, int n, int m ) { @@ -77,7 +78,7 @@ void dcldgemm( int Cc //must be n ) { - const DCL_FLOAT * ta; + const DCL_FLOAT *ta; const DCL_FLOAT * tb; int tac = Ac; int tbc = Bc; @@ -102,7 +103,7 @@ void dcldgemm( } else tb = B; - - GMULADD(C,ta,tb,C,alpha,beta,n,m,k); + printf("%d %d %d\n", tac, tbc, Cc); + GMULADD(C, ta, tb, C, alpha, beta, m, n, k); } diff --git a/redist/test_dcl.c b/redist/test_dcl.c index 6d49548..42f4fd6 100644 --- a/redist/test_dcl.c +++ b/redist/test_dcl.c @@ -1,5 +1,6 @@ #include "dclhelpers.h" #include +#include #include #include @@ -36,18 +37,24 @@ int main() dclIdentity(A[0], 4, 3); dclPrint(A[0], 4, 3, 4); - FLT x[4] = {7, 8, 9, 10}; - FLT R[4]; + FLT x[4][2] = { + {7, -7}, {8, -8}, {9, -9}, {10, -10}, + }; + FLT R[4][2]; + printf("%p %p %p\n", A, x, R); // dclMul(R, 1, A[0], 4, x, 1, 4, 1, 3); - dcldgemm(0, 0, 4, 1, 3, 1, A[0], 4, x, 1, 0, R, 1); + dcldgemm(0, 0, 3, 4, 2, 1, A[0], 4, x[0], 2, 0, R[0], 2); - dclPrint(x, 1, 4, 1); - dclPrint(R, 1, 4, 1); + dclPrint(x[0], 2, 4, 2); + dclPrint(R[0], 2, 4, 2); - for (int i = 0; i < 3; i++) - assert(R[i] == x[i]); - assert(R[3] == 0.); + for (int j = 0; j < 2; j++) { + for (int i = 0; i < 3; i++) + assert(R[i][j] == x[i][j]); + + assert(fabs(R[3][j]) < .0000001); + } } // void dclTransp( DCL_FLOAT * R, int Rc, const DCL_FLOAT * A, int Ac, int n, int m ); -- cgit v1.2.3