aboutsummaryrefslogtreecommitdiff
path: root/redist
diff options
context:
space:
mode:
author Justin Berger <j.david.berger@gmail.com> 2018-03-17 09:32:22 -0600
committer Justin Berger <j.david.berger@gmail.com> 2018-03-17 09:32:22 -0600
commit 0b9e66ad2ff686a4dcf8a6838f33edb203a1bff5 (patch)
tree 2ccc16e01e625901f00c59cc85535286116b1792 /redist
parent 7c97cfe7f63650fc79ce4fa7f081b556ce275475 (diff)
download libsurvive-0b9e66ad2ff686a4dcf8a6838f33edb203a1bff5.tar.gz
libsurvive-0b9e66ad2ff686a4dcf8a6838f33edb203a1bff5.tar.bz2
Fixed gemm
Diffstat (limited to 'redist')
-rw-r--r-- redist/dclapack.h 33
-rw-r--r-- redist/dclhelpers.c 9
-rw-r--r-- redist/test_dcl.c 23
3 files changed, 37 insertions, 28 deletions
diff --git a/redist/dclapack.h b/redist/dclapack.h
index af8869c..af5035b 100644
--- a/redist/dclapack.h
+++ b/redist/dclapack.h
@@ -224,22 +224,23 @@ PRINT(Ainv,n,n); \
/*
* Matrix Multiply R = alpha * A * B + beta * C
* R (m by p)
- * A (n by m)
- * B (m by p)
- * C (n by p)
+ * A (m by n)
+ * B (n by p)
+ * C (m by p)
*/
-#define GMULADD(R,A,B,C,alpha,beta,n,m,p) { \
- int i,j,k; \
- float sum; \
- for (i=0; i<n; i++) { \
- for (j=0; j<p; j++) { \
- sum = 0.0f; \
- for (k=0; k<m; k++) { \
- sum += _(A,i,k) * _(B,k,j); \
- } \
- _(R,i,j) = alpha * sum + beta * _(C,i,j); \
- } \
- } \
-}
+#define GMULADD(R, A, B, C, alpha, beta, m, n, p) \
+ { \
+ int _i, _j, _k; \
+ float sum; \
+ for (_i = 0; _i < m; _i++) { \
+ for (_j = 0; _j < p; _j++) { \
+ sum = 0.0f; \
+ for (_k = 0; _k < n; _k++) { \
+ sum += _(A, _i, _k) * _(B, _k, _j); \
+ } \
+ _(R, _i, _j) = alpha * sum + beta * _(C, _i, _j); \
+ } \
+ } \
+ }
#endif
diff --git a/redist/dclhelpers.c b/redist/dclhelpers.c
index fb6aba6..3e51fd2 100644
--- a/redist/dclhelpers.c
+++ b/redist/dclhelpers.c
@@ -1,9 +1,10 @@
#include "dclhelpers.h"
#define FLOAT DCL_FLOAT
#define DYNAMIC_INDEX
-#include <stdio.h>
#include "dclapack.h"
#include <alloca.h>
+#include <assert.h>
+#include <stdio.h>
void dclPrint( const DCL_FLOAT * PMATRIX, int PMATRIXc, int n, int m )
{
@@ -77,7 +78,7 @@ void dcldgemm(
int Cc //must be n
)
{
- const DCL_FLOAT * ta;
+ const DCL_FLOAT *ta;
const DCL_FLOAT * tb;
int tac = Ac;
int tbc = Bc;
@@ -102,7 +103,7 @@ void dcldgemm(
}
else
tb = B;
-
- GMULADD(C,ta,tb,C,alpha,beta,n,m,k);
+ printf("%d %d %d\n", tac, tbc, Cc);
+ GMULADD(C, ta, tb, C, alpha, beta, m, n, k);
}
diff --git a/redist/test_dcl.c b/redist/test_dcl.c
index 6d49548..42f4fd6 100644
--- a/redist/test_dcl.c
+++ b/redist/test_dcl.c
@@ -1,5 +1,6 @@
#include "dclhelpers.h"
#include <assert.h>
+#include <math.h>
#include <stdint.h>
#include <stdio.h>
@@ -36,18 +37,24 @@ int main()
dclIdentity(A[0], 4, 3);
dclPrint(A[0], 4, 3, 4);
- FLT x[4] = {7, 8, 9, 10};
- FLT R[4];
+ FLT x[4][2] = {
+ {7, -7}, {8, -8}, {9, -9}, {10, -10},
+ };
+ FLT R[4][2];
+ printf("%p %p %p\n", A, x, R);
// dclMul(R, 1, A[0], 4, x, 1, 4, 1, 3);
- dcldgemm(0, 0, 4, 1, 3, 1, A[0], 4, x, 1, 0, R, 1);
+ dcldgemm(0, 0, 3, 4, 2, 1, A[0], 4, x[0], 2, 0, R[0], 2);
- dclPrint(x, 1, 4, 1);
- dclPrint(R, 1, 4, 1);
+ dclPrint(x[0], 2, 4, 2);
+ dclPrint(R[0], 2, 4, 2);
- for (int i = 0; i < 3; i++)
- assert(R[i] == x[i]);
- assert(R[3] == 0.);
+ for (int j = 0; j < 2; j++) {
+ for (int i = 0; i < 3; i++)
+ assert(R[i][j] == x[i][j]);
+
+ assert(fabs(R[3][j]) < .0000001);
+ }
}
// void dclTransp( DCL_FLOAT * R, int Rc, const DCL_FLOAT * A, int Ac, int n, int m );