You are viewing a plain-text version of this content. (The link to the canonical version is not preserved in this plain-text copy.)
Posted to commits@systemds.apache.org by ma...@apache.org on 2021/06/16 22:57:07 UTC

[systemds] branch master updated: [SYSTEMDS-3023] Cuda Codegen Sparse I/O failing (bugfix)

This is an automated email from the ASF dual-hosted git repository.

markd pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new a17b511  [SYSTEMDS-3023] Cuda Codegen Sparse I/O failing (bugfix)
a17b511 is described below

commit a17b5114cef70ab92f8f461326e294b2fda0614f
Author: Mark Dokter <ma...@dokter.cc>
AuthorDate: Thu Jun 17 00:42:42 2021 +0200

    [SYSTEMDS-3023] Cuda Codegen Sparse I/O failing (bugfix)
    
    This patch fixes sparse input/output support in SPOOF CUDA codegen, which had been broken since SYSTEMDS-2930.
    
    Closes #1318
---
 src/main/cuda/headers/Matrix.h      | 12 +++++-----
 src/main/cuda/headers/reduction.cuh | 47 ++++++++++++-------------------------
 src/main/cuda/spoof/cellwise.cu     |  2 --
 3 files changed, 21 insertions(+), 40 deletions(-)

diff --git a/src/main/cuda/headers/Matrix.h b/src/main/cuda/headers/Matrix.h
index 755e764..0446983 100644
--- a/src/main/cuda/headers/Matrix.h
+++ b/src/main/cuda/headers/Matrix.h
@@ -76,27 +76,27 @@ public:
 	__device__ uint32_t cols() { return _mat->cols; }
 	__device__ uint32_t rows() { return _mat->rows; }
 	
-	__device__ uint32_t len() { return _mat->data == nullptr ? len_sparse() : len_dense(); }
+	__device__ uint32_t len() { return _mat->row_ptr == nullptr ? len_dense() : len_sparse(); }
 	
 	__device__ uint32_t pos(uint32_t rix) {
-		return _mat->data == nullptr ? pos_sparse(rix) : pos_dense(rix);
+		return _mat->row_ptr == nullptr ? pos_dense(rix) : pos_sparse(rix);
 	}
 	
 	__device__ T& val(uint32_t r, uint32_t c) {
-		return _mat->data == nullptr ? val_sparse_rc(r, c) : val_dense_rc(r,c);
+		return _mat->row_ptr == nullptr ? val_dense_rc(r,c) : val_sparse_rc(r, c) ;
 	}
 	
 	__device__ T& val(uint32_t i) {
-		return _mat->data == nullptr ? val_sparse_i(i) : val_dense_i(i);
+		return _mat->row_ptr == nullptr ? val_dense_i(i) : val_sparse_i(i);
 	}
 	__device__ T& operator[](uint32_t i) { return val(i); }
 	
 	__device__ T* vals(uint32_t rix) {
-		return _mat->data == nullptr ? vals_sparse(rix) : vals_dense(rix);
+		return _mat->row_ptr == nullptr ? vals_dense(rix) : vals_sparse(rix) ;
 	}
 	
 	__device__ uint32_t row_len(uint32_t rix) {
-		return _mat->data == nullptr ? row_len_sparse(rix) : row_len_dense(rix);
+		return _mat->row_ptr == nullptr ? row_len_dense(rix) : row_len_sparse(rix);
 	}
 	
 	__device__ uint32_t* col_idxs(uint32_t rix) { return cols_sparse(rix); }
diff --git a/src/main/cuda/headers/reduction.cuh b/src/main/cuda/headers/reduction.cuh
index 88cc45c..3cd0a0b 100644
--- a/src/main/cuda/headers/reduction.cuh
+++ b/src/main/cuda/headers/reduction.cuh
@@ -330,47 +330,30 @@ __device__ void NO_AGG_SPARSE(MatrixAccessor<T>* in, MatrixAccessor<T>* out, uin
 {
 	const uint32_t& rix = blockIdx.x;
 	uint32_t tid = threadIdx.x;
-//	uint32_t rix = (gtid * VT) / in->cols();
-//	//uint32_t cix = (gtid % in->cols());// *static_cast<uint32_t>(VT);
-//	uint32_t cix = in->col_idxs(0)[gtid];
-	uint32_t row_start = in->pos(rix);
-	uint32_t row_len = in->row_len(rix);
-
+	uint32_t row_start = 0;
+	uint32_t row_len = 0;
+	if(in->hasData()) {
+		row_start = in->pos(rix);
+		row_len = in->row_len(rix);
+	}
+	else {
+		row_start = rix * in->cols();
+		row_len = in->cols();
+	}
 	while(tid < row_len) {
+		uint32_t idx = row_start + tid;
 		if(in->hasData()) {
 			uint32_t *aix = in->col_idxs(rix);
 			uint32_t cix = aix[tid];
-//		T result = spoof_op(in->val(rix, cix), rix*in->rows()+cix, rix, cix);
-			T result = spoof_op(in->val(row_start + tid), rix * in->rows() + cix, rix, cix);
-			out->set(row_start + tid, cix, result);
-
-//		if(rix > 899 && rix < 903 && cix==0)
-//		if(rix < 10 && cix==0)
-//			printf("rix=%d row_start=%d tid=%d result=%4.3f\n", rix, row_start, tid, result);
+			T result = spoof_op(in->val(idx), idx, rix, cix);
+			out->set(idx, cix, result);
 		}
 		else {
 			uint32_t cix = tid;
-			T result = spoof_op(0, rix * in->rows() + cix, rix, cix);
-			out->set(row_start + tid, cix, result);
+			T result = spoof_op(0, idx, rix, cix);
+			out->set(idx, cix, result);
 		}
 		tid+=blockDim.x;
-
-
-//#pragma unroll
-//		for (auto i = first_idx; i < last_idx; i++) {
-////		out->vals(0)[i] = spoof_op(in->vals(0)[i], i);
-////		out->col_idxs(0)[i] = gtid % blockDim.x;
-//			T result = spoof_op(in->vals(0)[i], i);
-//			out->vals(0)[i] = result;
-//			//out->col_idxs(0)[i] = i % in->cols();
-//			out->col_idxs(0)[i] = in->col_idxs(0)[i];
-//			//out->set(i/in->cols(), i%in->cols(), result);
-//			//out->set(rix, i%in->cols(), result);
-//			if (i > in->nnz() - 10)
-//				printf("i=%d in=%4.3f res=%4.3f out=%4.3f r=%d out->index(i=%d)=%d out->col_idxs()[i=%d]=%d first=%d last=%d gtid=%d\n",
-//					   i, in->vals(0)[i], result, out->vals(0)[i],
-//					   i / in->cols(), i, out->indexes()[i], i, out->col_idxs(0)[i], first_idx, last_idx, gtid);
-//		}
 	}
 }
 
diff --git a/src/main/cuda/spoof/cellwise.cu b/src/main/cuda/spoof/cellwise.cu
index d70cc3d..951f709 100644
--- a/src/main/cuda/spoof/cellwise.cu
+++ b/src/main/cuda/spoof/cellwise.cu
@@ -53,8 +53,6 @@ struct SpoofCellwiseOp {
 	}
 
 	__device__  __forceinline__ T operator()(T a, uint32_t idx, uint32_t rix, uint32_t cix) {
-//%NEED_RIX%
-//%NEED_CIX%
 //%NEED_GRIX%
 %BODY_dense%
 		return %OUT%;