You are viewing a plain-text version of this content; the link to its canonical version was lost in this plain-text rendering.
Posted to commits@systemml.apache.org by ni...@apache.org on 2017/02/08 19:15:33 UTC
[3/3] incubator-systemml git commit: [SYSTEMML-1039] Added variance, row/col variance
[SYSTEMML-1039] Added variance, row/col variance
Closes #383.
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/ad009d81
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/ad009d81
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/ad009d81
Branch: refs/heads/master
Commit: ad009d81f759caed7ed134771fc6236d7cf21866
Parents: f8d7077
Author: Nakul Jindal <na...@gmail.com>
Authored: Wed Feb 8 11:14:43 2017 -0800
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Wed Feb 8 11:14:43 2017 -0800
----------------------------------------------------------------------
src/main/cpp/kernels/SystemML.cu | 105 +-
src/main/cpp/kernels/SystemML.ptx | 2772 ++++++++----------
.../java/org/apache/sysml/hops/AggUnaryOp.java | 1 +
.../instructions/GPUInstructionParser.java | 4 +-
.../MatrixMatrixArithmeticGPUInstruction.java | 2 +-
.../ScalarMatrixArithmeticGPUInstruction.java | 2 +-
.../runtime/matrix/data/LibMatrixCUDA.java | 225 +-
7 files changed, 1479 insertions(+), 1632 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ad009d81/src/main/cpp/kernels/SystemML.cu
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu
index 4ce6fb2..cda28ba 100644
--- a/src/main/cpp/kernels/SystemML.cu
+++ b/src/main/cpp/kernels/SystemML.cu
@@ -40,7 +40,7 @@ __global__ void copyUpperToLowerTriangleDense(double* ret, int dim, int N) {
}
extern "C"
-__device__ double getBoolean(int val) {
+__forceinline__ __device__ double getBoolean(int val) {
if(val == 0)
return 0.0;
else
@@ -51,39 +51,23 @@ __device__ double getBoolean(int val) {
// 5=less, 6=lessequal, 7=greater, 8=greaterequal, 9=equal, 10=notequal,
// 11=min, 12=max, 13=and, 14=or, 15=log}
extern "C"
-__device__ double binaryOp(double x, double y, int op) {
- // 0=plus, 1=minus, 2=multiply, 3=divide, 4=power
- if(op == 0)
- return x + y;
- else if(op == 1)
- return x - y;
- else if(op == 2)
- return x * y;
- else if(op == 3)
- return x / y;
- else if(op == 4)
- return pow(x, y);
- // 5=less, 6=lessequal, 7=greater, 8=greaterequal, 9=equal, 10=notequal,
- else if(op == 5)
- return getBoolean(x < y);
- else if(op == 6)
- return getBoolean(x <= y);
- else if(op == 7)
- return getBoolean(x > y);
- else if(op == 8)
- return getBoolean(x >= y);
- else if(op == 9)
- return getBoolean(x == y);
- else if(op == 10)
- return getBoolean(x != y);
- // 11=min, 12=max, 13=and, 14=or, 15=log
- else if(op == 11) {
- return min(x, y);
- }
- else if(op == 12) {
- return max(x, y);
- }
- return -999;
+__forceinline__ __device__ double binaryOp(double x, double y, int op) {
+ switch(op) {
+ case 0 : return x + y;
+ case 1 : return x - y;
+ case 2 : return x * y;
+ case 3 : return x / y;
+ case 4 : return pow(x, y);
+ case 5 : return getBoolean(x < y);
+ case 6 : return getBoolean(x <= y);
+ case 7 : return getBoolean(x > y);
+ case 8 : return getBoolean(x >= y);
+ case 9 : return getBoolean(x == y);
+ case 10 : return getBoolean(x != y);
+ case 11 : return min(x, y);
+ case 12 : return max(x, y);
+ default : return DBL_MAX;
+ }
}
extern "C"
@@ -158,8 +142,22 @@ __global__ void compareAndSet(double* A, double* ret, int rlen, int clen, doubl
}
}
+
+/**
+ * Performs a binary cellwise arithmetic operation on 2 matrices.
+ * Either both matrices are of equal size or one of them is a vector or both are.
+ * @param A first input matrix allocated on GPU
+ * @param B second input matrix allocated on GPU
+ * @param C output allocated on GPU
+ * @param maxRlen maximum of the row lengths of A and B
+ * @param maxClen maximum of the column lengths of A and B
+ * @param vectorAStatus if A is a row vector, column vector or neither
+ * @param vectorBStatus if B is a row vector, column vector or neither
+ * @param op the numeric code of the arithmetic operation to perform
+ *
+ */
extern "C"
-__global__ void binCellOp(double* A, double* B, double* C,
+__global__ void matrix_matrix_cellwise_op(double* A, double* B, double* C,
int maxRlen, int maxClen, int vectorAStatus, int vectorBStatus, int op) {
int ix = blockIdx.x * blockDim.x + threadIdx.x;
int iy = blockIdx.y * blockDim.y + threadIdx.y;
@@ -177,21 +175,32 @@ __global__ void binCellOp(double* A, double* B, double* C,
else if(vectorBStatus == 2)
bIndex = iy; // rlen == 1
C[outIndex] = binaryOp(A[aIndex], B[bIndex], op);
- // printf("C[%d] = A[%d](%f) B[%d](%f) (%d %d)\n", outIndex, aIndex, A[aIndex], bIndex, B[bIndex], (ix+1), (iy+1));
+ //printf("C[%d] = A[%d](%f) B[%d](%f) (%d %d)\n", outIndex, aIndex, A[aIndex], bIndex, B[bIndex], (ix+1), (iy+1));
+ __syncthreads();
}
}
+/**
+ * Performs an arithmetic operation between a matrix and a scalar.
+ * C = s op A or C = A op s (where A is the matrix, s is the scalar and op is the operation)
+ * @param A input matrix allocated on GPU
+ * @param scalar scalar input
+ * @param C output matrix allocated on GPU
+ * @param size number of elements in matrix A
+ * @param op number code of the arithmetic operation to perform
+ * @param isLeftScalar whether the scalar is on the left side
+ */
extern "C"
-__global__ void binCellScalarOp(double* A, double scalar, double* C, int rlenA, int clenA, int op, int isLeftScalar) {
- int ix = blockIdx.x * blockDim.x + threadIdx.x;
- int iy = blockIdx.y * blockDim.y + threadIdx.y;
- int index = ix * clenA + iy;
- if(index < rlenA*clenA) {
- if(isLeftScalar)
+__global__ void matrix_scalar_op(double* A, double scalar, double* C, int size, int op, int isLeftScalar) {
+ int index = blockIdx.x *blockDim.x + threadIdx.x;
+ if(index < size) {
+ if(isLeftScalar) {
C[index] = binaryOp(scalar, A[index], op);
- else
+ } else {
C[index] = binaryOp(A[index], scalar, op);
+ }
}
+ __syncthreads();
}
@@ -475,7 +484,7 @@ typedef struct {
extern "C"
__global__ void reduce_max(double *g_idata, double *g_odata, unsigned int n){
MaxOp op;
- reduce<MaxOp>(g_idata, g_odata, n, op, DBL_MIN);
+ reduce<MaxOp>(g_idata, g_odata, n, op, -DBL_MAX);
}
/**
@@ -489,7 +498,7 @@ extern "C"
__global__ void reduce_row_max(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols){
MaxOp op;
IdentityOp aop;
- reduce_row<MaxOp, IdentityOp>(g_idata, g_odata, rows, cols, op, aop, DBL_MIN);
+ reduce_row<MaxOp, IdentityOp>(g_idata, g_odata, rows, cols, op, aop, -DBL_MAX);
}
/**
@@ -503,7 +512,7 @@ extern "C"
__global__ void reduce_col_max(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols){
MaxOp op;
IdentityOp aop;
- reduce_col<MaxOp, IdentityOp>(g_idata, g_odata, rows, cols, op, aop, DBL_MIN);
+ reduce_col<MaxOp, IdentityOp>(g_idata, g_odata, rows, cols, op, aop, -DBL_MAX);
}
/**
@@ -602,7 +611,7 @@ struct MeanOp {
extern "C"
__global__ void reduce_row_mean(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols){
SumOp op;
- MeanOp aop(rows*cols);
+ MeanOp aop(cols);
reduce_row<SumOp, MeanOp>(g_idata, g_odata, rows, cols, op, aop, 0.0);
}
@@ -616,6 +625,6 @@ __global__ void reduce_row_mean(double *g_idata, double *g_odata, unsigned int r
extern "C"
__global__ void reduce_col_mean(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols){
SumOp op;
- MeanOp aop(rows*cols);
+ MeanOp aop(rows);
reduce_col<SumOp, MeanOp>(g_idata, g_odata, rows, cols, op, aop, 0.0);
}