Posted to commits@systemds.apache.org by ma...@apache.org on 2022/04/20 12:17:55 UTC

[systemds] 01/02: [SYSTEMDS-3352] CUDA code gen support for connected components

This is an automated email from the ASF dual-hosted git repository.

markd pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit fc5b03de84ebd57214a69ca43f63f223dd258c89
Author: Mark Dokter <ma...@dokter.cc>
AuthorDate: Wed Apr 20 13:06:05 2022 +0200

    [SYSTEMDS-3352] CUDA code gen support for connected components
    
    General cleanup and bug fixes to make components.dml run.
    Also contains improvements to handle single-precision execution.
---
 src/main/cuda/headers/Matrix.h                     |  52 ++-
 src/main/cuda/headers/spoof_utils.cuh              | 215 +++++------
 src/main/cuda/headers/vector_write.cuh             |  20 +-
 src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp  |  23 +-
 src/main/cuda/spoof-launcher/SpoofRowwise.h        |   8 +-
 src/main/cuda/spoof-launcher/jni_bridge.cpp        |   4 +-
 src/main/cuda/spoof/rowwise.cu                     |   9 +-
 .../apache/sysds/hops/codegen/cplan/CNodeRow.java  |   7 +-
 .../sysds/hops/codegen/cplan/cuda/Binary.java      | 391 ++++++++-------------
 .../sysds/hops/codegen/cplan/cuda/Ternary.java     |  88 ++---
 .../sysds/hops/codegen/cplan/cuda/Unary.java       | 313 +++++++----------
 .../sysds/runtime/codegen/SpoofCUDACellwise.java   |   4 +-
 .../sysds/runtime/codegen/SpoofCUDARowwise.java    |   4 +-
 13 files changed, 462 insertions(+), 676 deletions(-)

diff --git a/src/main/cuda/headers/Matrix.h b/src/main/cuda/headers/Matrix.h
index 61ef939b83..f02a76c83c 100644
--- a/src/main/cuda/headers/Matrix.h
+++ b/src/main/cuda/headers/Matrix.h
@@ -18,8 +18,6 @@
  */
 
 #pragma once
-#ifndef SYSTEMDS_MATRIX_H
-#define SYSTEMDS_MATRIX_H
 
 using uint32_t = unsigned int;
 using int32_t = int;
@@ -43,22 +41,22 @@ struct Matrix {
 
 #ifdef __CUDACC__
 
-template<typename T>
-uint32_t bin_search(T* values, uint32_t lower, uint32_t upper, T val) {
-	upper -= 1;
-	while(lower <= (upper-1)) {
-		uint32_t idx = (lower + upper) >> 1;
-		uint32_t vi = values[idx];
-		if (vi < val)
-			lower = idx + 1;
-		else {
-			if (vi <= val)
-				return idx;
-			upper = idx - 1;
-		}
-	}
-	return upper + 1;
-}
+//template<typename T>
+//uint32_t bin_search(T* values, uint32_t lower, uint32_t upper, T val) {
+//	upper -= 1;
+//	while(lower <= (upper-1)) {
+//		uint32_t idx = (lower + upper) >> 1;
+//		uint32_t vi = values[idx];
+//		if (vi < val)
+//			lower = idx + 1;
+//		else {
+//			if (vi <= val)
+//				return idx;
+//			upper = idx - 1;
+//		}
+//	}
+//	return upper + 1;
+//}
 
 template<typename T>
 class MatrixAccessor {
@@ -68,11 +66,11 @@ class MatrixAccessor {
 public:
 	MatrixAccessor() = default;
 	
-	__device__ MatrixAccessor(Matrix<T>* mat) : _mat(mat) {}
+	__device__ explicit MatrixAccessor(Matrix<T>* mat) : _mat(mat) {}
 	
 	__device__ void init(Matrix<T>* mat) { _mat = mat; }
 	
-	__device__ uint32_t& nnz() { return return _mat->row_ptr == nullptr ? _mat->rows * _mat->cols : _mat->nnz; }
+//	__device__ uint32_t& nnz() { return _mat->row_ptr == nullptr ? _mat->rows * _mat->cols : _mat->nnz; }
 	__device__ uint32_t cols() { return _mat->cols; }
 	__device__ uint32_t rows() { return _mat->rows; }
 	
@@ -96,14 +94,14 @@ public:
 	}
 	
 	__device__ uint32_t row_len(uint32_t rix) {
-		return _mat->row_ptr == nullptr ? row_len_dense(rix) : row_len_sparse(rix);
+		return _mat->row_ptr == nullptr ? _mat->rows : row_len_sparse(rix);
 	}
 	
 	__device__ uint32_t* col_idxs(uint32_t rix) { return cols_sparse(rix); }
 
 	__device__ void set(uint32_t r, uint32_t c, T v) { set_sparse(r,c,v); }
 	
-	__device__ uint32_t* indexes() {  return _mat->row_ptr;	}
+//	__device__ uint32_t* indexes() {  return _mat->row_ptr;	}
 	
 	__device__ bool hasData() { return _mat->data != nullptr; }
 private:
@@ -127,10 +125,6 @@ private:
 		return &(_mat->data[rix]);
 	}
 	
-	__device__ uint32_t row_len_dense(uint32_t rix) {
-		return _mat->rows;
-	}
-	
 	//ToDo sparse accessors
 	__device__ uint32_t len_sparse() {
 		return _mat->row_ptr[_mat->rows];
@@ -145,8 +139,8 @@ private:
 	}
 	
 	__device__ T& val_sparse_rc(uint32_t r, uint32_t c) {
-//		printf("TBI: val_sparse_rc\n");
-//		asm("trap;");
+		printf("TBI: val_sparse_rc(%d, %d)\n", r, c);
+		asm("trap;");
 
 		return _mat->data[0];
 	}
@@ -228,5 +222,3 @@ public:
 };
 
 #endif // __CUDACC_RTC__
-
-#endif //SYSTEMDS_MATRIX_H
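
In Matrix.h, the SYSTEMDS_MATRIX_H include guard was redundant because the header already uses #pragma once, and bin_search is commented out rather than deleted. If that helper is revived, note that the original stored the probed value in a uint32_t (truncating T) and used wrap-around-prone bounds arithmetic. A minimal lower-bound sketch over a sorted index array, assuming the usual CSR layout; the name lower_bound_idx is hypothetical and not part of this commit:

template<typename T>
__device__ uint32_t lower_bound_idx(const T* values, uint32_t lower, uint32_t upper, T val) {
	// first index in [lower, upper) with values[idx] >= val; returns upper if none exists
	while(lower < upper) {
		uint32_t mid = lower + ((upper - lower) >> 1);
		if(values[mid] < val)
			lower = mid + 1;
		else
			upper = mid;
	}
	return lower;
}
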
diff --git a/src/main/cuda/headers/spoof_utils.cuh b/src/main/cuda/headers/spoof_utils.cuh
index 5d9b1012b2..8ab0fafdb2 100644
--- a/src/main/cuda/headers/spoof_utils.cuh
+++ b/src/main/cuda/headers/spoof_utils.cuh
@@ -18,8 +18,6 @@
  */
 
 #pragma once
-#ifndef SPOOF_UTILS_CUH
-#define SPOOF_UTILS_CUH
 
 #include <math_constants.h>
 #include "vector_add.cuh"
@@ -31,13 +29,8 @@ struct TempStorage;
 #include "Matrix.h"
 #include "vector_write.cuh"
 
-// #include "intellisense_cuda_intrinsics.h"
-
 using uint32_t = unsigned int;
 
-//static __device__  bool debug_row() { return blockIdx.x == 0; };
-//static __device__ bool debug_thread() { return threadIdx.x == 0; }
-
 __constant__ double DOUBLE_EPS = 1.11022E-16; // 2 ^ -53
 __constant__ double FLOAT_EPS = 1.49012E-08; // 2 ^ -26
 __constant__ double EPSILON = 1E-11; // margin for comparisons ToDo: make consistent use of it
@@ -79,12 +72,6 @@ __device__ Vector<T>& getVector(MatrixAccessor<T>& data, uint32_t n, uint32_t ri
 		c[i] = data.val(rix, i);
 		i += blockDim.x;
 	}
-//	if(debug_thread()) {
-//		printf("getVector: c.len=%d rix=%d\n", c.length, rix);
-//		for(auto j = 0; j < c.length; ++j)
-//			printf("%4.3f ", c[j]);
-//		printf("\n");
-//	}
 	return c;
 }
 
@@ -146,122 +133,147 @@ __device__ T BLOCK_ROW_AGG(T *a, T *b, uint32_t len, AggOp agg_op, LoadOp load_o
 	auto sdata = shared_memory_proxy<T>();
 	uint tid = threadIdx.x;
 
-	// Initalize shared mem and leave if tid > row length. 
-//	if(tid >= len) { return sdata[tid] = AggOp::init();; }
-
-	__syncthreads();
-	
-//			 if(blockIdx.x == 0 && threadIdx.x == 0)
-//		   printf("tid=%d sdata[tid + 128]=%f, len=%d\n", tid, len, sdata[tid+128]);
 	uint i = tid;
 	T v = AggOp::init();
-//			 if(blockIdx.x == 0 && threadIdx.x == 0)
-//		   printf("tid=%d sdata[tid + 128]=%f\n", tid, sdata[tid+128]);
-	while (i < len) {
+	while(i < len) {
 		v = agg_op(v, load_op(a[i], b[i]));
 		i += blockDim.x;
 	}
 
-//		 if(blockIdx.x == 0 && threadIdx.x == 0)
-//	if(debug_row() && debug_thread())
-//		   printf("tid=%d sdata[tid + 128]=%f\n", tid, sdata[tid+128]);
-	
 	// each thread puts its local sum into shared memory
 	sdata[tid] = v;
-	// if(blockIdx.x==0)
-		// printf("tid=%d v=%f, len=%d\n", tid, v, len);
 	__syncthreads();
 
-			// if(blockIdx.x == 0 && threadIdx.x == 0)
-		 //  printf("tid=%d sdata[tid + 128]=%f\n", tid, sdata[tid+128]);
-	
 	// do reduction in shared mem
-	if (blockDim.x >= 1024) {
-		if (tid < 512 && (tid+512) < len) {
-				// if(blockIdx.x == 0 && threadIdx.x == 0)
-		  // printf("tid=%d sdata[tid + 512]=%f\n", tid, sdata[tid+512]);
+	if(blockDim.x >= 1024) {
+		if(tid < 512 && (tid+512) < len) {
 			sdata[tid] = v = agg_op(v, sdata[tid + 512]);
 		}
 		__syncthreads();
 	}
-	if (blockDim.x >= 512) {
-		if (tid < 256 && (tid+256) < len) {
-				// if(blockIdx.x == 0 && threadIdx.x == 0)
-		  // printf("tid=%d sdata[tid + 256]=%f\n", tid, sdata[tid+256]);
+	if(blockDim.x >= 512) {
+		if(tid < 256 && (tid+256) < len) {
 			sdata[tid] = v = agg_op(v, sdata[tid + 256]);
 		}
 		__syncthreads();
 	}
-	if (blockDim.x >= 256) {
-		if (tid < 128 && (tid+128) < len) {
-				// if(blockIdx.x == 0 && threadIdx.x == 0)
-		  // printf("tid=%d sdata[tid + 128]=%f\n", tid, sdata[tid+128]);
+	if(blockDim.x >= 256) {
+		if(tid < 128 && (tid+128) < len) {
 			sdata[tid] = v = agg_op(v, sdata[tid + 128]);
 		}
 		__syncthreads();
 	}
-	if (blockDim.x >= 128) {
-		if (tid < 64 && (tid+64) < len) {
-				// if(blockIdx.x == 0 && threadIdx.x == 0)
-		  // printf("tid=%d sdata[tid + 64]=%f\n", tid, sdata[tid+64]);
+	if(blockDim.x >= 128) {
+		if(tid < 64 && (tid+64) < len) {
 			sdata[tid] = v = agg_op(v, sdata[tid + 64]);
 		}
 		__syncthreads();
 	}
- 
-	if (tid < 32) {
+
+	if(tid < 32) {
 		// now that we are using warp-synchronous programming (below)
 		// we need to declare our shared memory volatile so that the compiler
 		// doesn't reorder stores to it and induce incorrect behavior.
 		volatile T *smem = sdata;
-		if (blockDim.x >= 64 && (tid+32) < len) {
+		if(blockDim.x >= 64 && (tid+32) < len) {
 			smem[tid] = v = agg_op(v, smem[tid + 32]);
 		}
-		// if(blockIdx.x==0)
-		  // printf("tid=%d smem[0]=%f\n", tid, smem[0]);
-		if (blockDim.x >= 32 && (tid+16) < len) {
+		if(blockDim.x >= 32 && (tid+16) < len) {
 			smem[tid] = v = agg_op(v, smem[tid + 16]);
 		}
-		// if(blockIdx.x==0)
-		  // printf("tid=%d smem[0]=%f\n", tid, smem[0]);
-		if (blockDim.x >= 16 && (tid+8) < len) {
+		if(blockDim.x >= 16 && (tid+8) < len) {
 			smem[tid] = v = agg_op(v, smem[tid + 8]);
 		}
-		// if(blockIdx.x==0)
-		  // printf("tid=%d smem[0]=%f\n", tid, smem[0]);
-		if (blockDim.x >= 8 && (tid+4) < len) {
+		if(blockDim.x >= 8 && (tid+4) < len) {
 			smem[tid] = v = agg_op(v, smem[tid + 4]);
 		}
-		// if(blockIdx.x==0 && threadIdx.x ==0)
-		  // printf("tid=%d smem[tid + 4]=%f\n", tid, smem[tid+4]);
-		if (blockDim.x >= 4 && (tid+2) < len) {
+		if(blockDim.x >= 4 && (tid+2) < len) {
 			smem[tid] = v = agg_op(v, smem[tid + 2]);
 		}
-		// if(blockIdx.x==0 && threadIdx.x ==0)
-		  // printf("tid=%d smem[0]=%f\n", tid, smem[0]);
-		if (blockDim.x >= 2 && (tid+1) < len) {
-		// if (blockDim.x >= 2) {
+		if(blockDim.x >= 2 && (tid+1) < len) {
 			smem[tid] = v = agg_op(v, smem[tid + 1]);
 		}
-//		 if(blockIdx.x==0 && threadIdx.x ==0)
-//		if(debug_row() && debug_thread())
-//		   printf("tid=%d smem[0]=%f\n", tid, smem[0]);
 	}
-
 	__syncthreads();
 	return sdata[0];
 }
 
+
+template<typename T, typename AggOp, typename LoadOp>
+__device__ T BLOCK_ROW_AGG(T *a, T *b, uint32_t* aix, uint32_t len, AggOp agg_op, LoadOp load_op) {
+    auto sdata = shared_memory_proxy<T>();
+    uint tid = threadIdx.x;
+
+    uint i = tid;
+    T v = AggOp::init();
+    while(i < len) {
+        v = agg_op(v, load_op(a[i], b[aix[i]]));
+        i += blockDim.x;
+    }
+
+    // each thread puts its local sum into shared memory
+    sdata[tid] = v;
+    __syncthreads();
+
+    // do reduction in shared mem
+    if(blockDim.x >= 1024) {
+        if(tid < 512 && (tid+512) < len) {
+            sdata[tid] = v = agg_op(v, sdata[tid + 512]);
+        }
+        __syncthreads();
+    }
+    if(blockDim.x >= 512) {
+        if(tid < 256 && (tid+256) < len) {
+            sdata[tid] = v = agg_op(v, sdata[tid + 256]);
+        }
+        __syncthreads();
+    }
+    if(blockDim.x >= 256) {
+        if(tid < 128 && (tid+128) < len) {
+            sdata[tid] = v = agg_op(v, sdata[tid + 128]);
+        }
+        __syncthreads();
+    }
+    if(blockDim.x >= 128) {
+        if(tid < 64 && (tid+64) < len) {
+            sdata[tid] = v = agg_op(v, sdata[tid + 64]);
+        }
+        __syncthreads();
+    }
+
+    if(tid < 32) {
+        // now that we are using warp-synchronous programming (below)
+        // we need to declare our shared memory volatile so that the compiler
+        // doesn't reorder stores to it and induce incorrect behavior.
+        volatile T *smem = sdata;
+        if(blockDim.x >= 64 && (tid+32) < len) {
+            smem[tid] = v = agg_op(v, smem[tid + 32]);
+        }
+        if(blockDim.x >= 32 && (tid+16) < len) {
+            smem[tid] = v = agg_op(v, smem[tid + 16]);
+        }
+        if(blockDim.x >= 16 && (tid+8) < len) {
+            smem[tid] = v = agg_op(v, smem[tid + 8]);
+        }
+        if(blockDim.x >= 8 && (tid+4) < len) {
+            smem[tid] = v = agg_op(v, smem[tid + 4]);
+        }
+        if(blockDim.x >= 4 && (tid+2) < len) {
+            smem[tid] = v = agg_op(v, smem[tid + 2]);
+        }
+        if(blockDim.x >= 2 && (tid+1) < len) {
+            smem[tid] = v = agg_op(v, smem[tid + 1]);
+        }
+    }
+    __syncthreads();
+    return sdata[0];
+}
+
 template<typename T>
 __device__ T dotProduct(T* a, T* b, uint32_t ai, uint32_t bi, uint32_t len) {
 	SumOp<T> agg_op;
 	ProductOp<T> load_op;
-//	if(debug_row() && debug_thread())
-//		printf("dot len = %d\n", len);
-	T ret =  BLOCK_ROW_AGG(&a[ai], &b[bi], len, agg_op, load_op);
-//	if(debug_row() && debug_thread())
-//		printf("bid=%d, ai=%d, dot=%f\n", blockIdx.x, ai, ret);
-	return ret;
+	return BLOCK_ROW_AGG(&a[ai], &b[bi], len, agg_op, load_op);
 }
 
 template<typename T>
@@ -277,8 +289,6 @@ __device__ T vectSum(T* a, uint32_t ai, uint32_t len) {
 	SumOp<T> agg_op;
 	IdentityOp<T> load_op;
 	T result = BLOCK_ROW_AGG(&a[ai], &a[ai], len, agg_op, load_op);
-//	if(debug_row() && debug_thread())
-//		printf("vectSum: bid=%d, tid=%d ai=%d len=%d result=%4.3f\n", blockIdx.x, threadIdx.x, ai, len, result);
 	return result;
 }
 
@@ -286,10 +296,22 @@ template<typename T>
 __device__ T vectMin(T* a, int ai, int len) {
 	MinOp<T> agg_op;
 	IdentityOp<T> load_op;
-	T result = BLOCK_ROW_AGG(&a[ai], &a[ai], len, agg_op, load_op);
-//	if(debug_row() && debug_thread())
-//		printf("vectMin: bid=%d, tid=%d ai=%d len=%d result=%4.3f\n", blockIdx.x, threadIdx.x, ai, len, result);
-	return result;
+	return BLOCK_ROW_AGG(&a[ai], &a[ai], len, agg_op, load_op);
+}
+
+template<typename T>
+__device__ T rowMaxsVectMult(T* a, T* b, uint32_t ai, uint32_t bi, uint32_t len) {
+    MaxOp<T> agg_op;
+    ProductOp<T> load_op;
+    return BLOCK_ROW_AGG(&a[ai], &b[0], len, agg_op, load_op);
+}
+
+template<typename T>
+__device__ T rowMaxsVectMult(T* a, T* b, uint32_t* aix, uint32_t ai, uint32_t bi, uint32_t len) {
+    MaxOp<T> agg_op;
+    ProductOp<T> load_op;
+
+    return BLOCK_ROW_AGG(&a[ai], &b[0], &aix[ai], len, agg_op, load_op);
 }
 
 template<typename T>
@@ -302,25 +324,7 @@ __device__ T vectMax(T* a, uint32_t ai, uint32_t len) {
 
 template<typename T>
 __device__ T vectMax(T* avals, uint32_t* aix, uint32_t ai, uint32_t alen, uint32_t len) {
-//	if (debug_row() && debug_thread()) {
-//		printf("\naix[i]:\n");
-//		for(auto i = 0; i < alen; ++i)
-//			printf(" %d", aix[i]);
-		
-//		printf("\navals[i]:\n");
-//		for(auto i = 0; i < alen; ++i)
-//			printf(" %4.3f", avals[i]);
-		
-//		printf("\navals[aix[i]]:\n");
-//		for(auto i = 0; i < alen; ++i)
-//			printf(" %4.3f", avals[aix[i]]);
-
-//		printf("\n");
-//	}
-
 	T result = vectMax(avals, ai, alen);
-//	if (blockIdx.x < 5 && debug_thread())
-//		printf("bid=%d, tid=%d, len=%d, alen=%d, ai=%d vectMax=%4.3f\n", blockIdx.x, threadIdx.x, len, alen, ai, result);
 	return alen < len ? MaxOp<T>::exec(result, 0.0) : result;
 }
 
@@ -547,6 +551,12 @@ Vector<T>& vectMultWrite(T* a, T* b, uint32_t ai, uint32_t bi, uint32_t len, Tem
 	return vectWriteBinary<T, ProductOp<T>>(a, b, ai, bi, len, fop, "Mult");
 }
 
+// sparse-dense MxV
+template<typename T>
+Vector<T>& vectMultWrite(T* avals, T* b, uint32_t* aix, uint32_t ai, uint32_t bi, uint32_t alen, uint32_t len, TempStorage<T>* fop) {
+    return vectWriteBinary<T, ProductOp<T>>(avals, b, aix, ai, bi, alen, len, fop, "Mult");
+}
+
 template<typename T>
 Vector<T>& vectDivWrite(T* a, T b, uint32_t ai, uint32_t len, TempStorage<T>* fop) {
 	return vectWriteBinary<T, DivOp<T>>(a, b, ai, len, fop, "Div");
@@ -744,6 +754,3 @@ void vectOuterMultAdd(T* a, T* b, T* c, uint32_t ai, uint32_t bi, uint32_t ci, u
 		i += blockDim.x;
 	}
 }
-
-
-#endif // SPOOF_UTILS_CUH
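
The reworked spoof_utils.cuh keeps the existing dense BLOCK_ROW_AGG (one thread block per row: each thread strides over the row accumulating agg_op(v, load_op(a[i], b[i])), the per-thread partials land in shared memory, and a tree reduction finishes with warp-synchronous steps through a volatile pointer) and adds an overload whose only difference is the load phase: the right-hand operand is gathered through the sparse column indexes, load_op(a[i], b[aix[i]]), so a CSR row can be reduced against a dense vector without materializing it. A condensed sketch of the same dense pattern, assuming blockDim.x is a power of two and blockDim.x elements of dynamic shared memory; block_sum is a hypothetical name, not part of this commit:

template<typename T>
__device__ T block_sum(const T* a, uint32_t len, T* sdata) {
	uint32_t tid = threadIdx.x;
	T v = 0;
	for(uint32_t i = tid; i < len; i += blockDim.x) // strided pass over the row
		v += a[i];
	sdata[tid] = v;                                 // per-thread partials to shared memory
	__syncthreads();
	for(uint32_t s = blockDim.x >> 1; s > 0; s >>= 1) { // tree reduction in shared memory
		if(tid < s)
			sdata[tid] += sdata[tid + s];
		__syncthreads();
	}
	return sdata[0];
}

Unlike this sketch, the kernels above rely on implicit warp synchronization for the last 32 lanes; on architectures with independent thread scheduling (Volta and newer) that idiom generally needs explicit __syncwarp() calls.
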
diff --git a/src/main/cuda/headers/vector_write.cuh b/src/main/cuda/headers/vector_write.cuh
index 3099926167..55241bd8d4 100644
--- a/src/main/cuda/headers/vector_write.cuh
+++ b/src/main/cuda/headers/vector_write.cuh
@@ -18,10 +18,9 @@
  */
 
 #pragma once
-#ifndef SYSTEMDS_VECTOR_WRITE_CUH
-#define SYSTEMDS_VECTOR_WRITE_CUH
 
-__device__ bool debug_row() { return blockIdx.x == 1; };
+#define DEBUG_ROW 2
+__device__ bool debug_row() { return blockIdx.x == DEBUG_ROW; };
 __device__ bool debug_thread() { return threadIdx.x == 0; }
 
 // unary transform vector by OP and write to intermediate vector
@@ -143,6 +142,19 @@ __device__ Vector<T>& vectWriteBinary(T* a, T* b, uint32_t ai, uint32_t bi, uint
 	return c;
 }
 
+// sparse binary vect-vect to intermediate vector
+template<typename T, typename OP>
+__device__ Vector<T>& vectWriteBinary(T* a, T* b, uint32_t* aix, uint32_t ai, uint32_t bi, uint32_t alen, uint32_t len,
+        TempStorage<T>* fop, const char* name = nullptr) {
+    uint32_t i = threadIdx.x;
+    Vector<T>& c = fop->getTempStorage(len);
+    while (i < alen) {
+        c[aix[ai+i]] = OP::exec(a[ai + i], b[aix[ai+i]]);
+        i += blockDim.x;
+    }
+    return c;
+}
+
 // binary vector-scalar to output vector c
 template<typename T, typename OP>
 __device__ void vectWriteBinary(T* a, T b, T* c, uint32_t ai, uint32_t ci, uint32_t len) {
@@ -168,5 +180,3 @@ __device__ void vectWriteBinary(T* a, T* b, T* c, uint32_t ai, uint32_t bi, uint
 		i += blockDim.x;
 	}
 }
-
-#endif //SYSTEMDS_VECTOR_WRITE_CUH
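
The new sparse vectWriteBinary overload scatters OP::exec(a[ai+i], b[aix[ai+i]]) to position aix[ai+i] of the temporary vector, so only positions present in the sparse row are written. For the Mult case this presumably relies on the temp-vector buffer being zero-initialized by the launcher (the cudaMemsetAsync in SpoofRowwise.h below), so that absent positions read as zero. A condensed sketch of the gather/scatter pattern; scatter_mult is a hypothetical name, not part of this commit:

template<typename T>
__device__ void scatter_mult(const T* avals, const uint32_t* aix, const T* b, T* c, uint32_t alen) {
	// one block per row: each thread handles a strided subset of the row's nonzeros
	for(uint32_t i = threadIdx.x; i < alen; i += blockDim.x)
		c[aix[i]] = avals[i] * b[aix[i]];
}
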
diff --git a/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp b/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp
index 2ef482d18d..c4ae1e3dff 100644
--- a/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp
+++ b/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp
@@ -47,11 +47,7 @@ size_t SpoofCUDAContext::initialize_cuda(uint32_t device_id, const char* resourc
 	s1 << "-I" << resource_path << "/cuda/headers";
 	s2 << "-I" << resource_path << "/cuda/spoof";
 	auto ctx = new SpoofCUDAContext(resource_path,{s1.str(), s2.str(), cuda_include_path});
-	// cuda device is handled by jCuda atm
-	//cudaSetDevice(device_id);
-	//cudaSetDeviceFlags(cudaDeviceScheduleBlockingSync);
-	//cudaDeviceSynchronize();
-	
+
 	CHECK_CUDA(cuModuleLoad(&(ctx->reductions), std::string(ctx->resource_path + std::string("/cuda/kernels/reduction.ptx")).c_str()));
 	
 	CUfunction func;
@@ -87,30 +83,13 @@ void SpoofCUDAContext::destroy_cuda(SpoofCUDAContext *ctx, [[maybe_unused]] uint
 	cudaFreeHost(ctx->staging_buffer);
 	cudaFree(ctx->device_buffer);
 	delete ctx;
-	// cuda device is handled by jCuda atm
-	//cudaDeviceReset();
 }
 
 size_t SpoofCUDAContext::compile(std::unique_ptr<SpoofOperator> op, const std::string &src) {
-#ifndef NDEBUG
-//	std::cout << "---=== START source listing of spoof cuda kernel [ " << name << " ]: " << std::endl;
-//    uint32_t line_num = 0;
-//	std::istringstream src_stream(src);
-//    for(std::string line; std::getline(src_stream, line); line_num++)
-//		std::cout << line_num << ": " << line << std::endl;
-//	std::cout << "---=== END source listing of spoof cuda kernel [ " << name << " ]." << std::endl;
-	std::cout << "cwd: " << std::filesystem::current_path() << std::endl;
-	std::cout << "include_paths: ";
-	for_each (include_paths.begin(), include_paths.end(), [](const std::string& line){ std::cout << line << '\n';});
-	std::cout << std::endl;
-#endif
-
-// uncomment all related lines for temporary timing output:
 //	auto compile_start = clk::now();
 	op->program = std::make_unique<jitify::Program>(kernel_cache.program(src, 0, include_paths));
 //	auto compile_end = clk::now();
 //	auto compile_duration = std::chrono::duration_cast<sec>(compile_end - compile_start).count();
-
 	compiled_ops.push_back(std::move(op));
 //	compile_total += compile_duration;
 //	std::cout << name << " compiled in "
diff --git a/src/main/cuda/spoof-launcher/SpoofRowwise.h b/src/main/cuda/spoof-launcher/SpoofRowwise.h
index 4465ac99fa..01ec5206aa 100644
--- a/src/main/cuda/spoof-launcher/SpoofRowwise.h
+++ b/src/main/cuda/spoof-launcher/SpoofRowwise.h
@@ -18,15 +18,13 @@
  */
 
 #pragma once
-#ifndef SYSTEMDS_SPOOFROWWISE_H
-#define SYSTEMDS_SPOOFROWWISE_H
 
 #include "SpoofCUDAContext.h"
 #include <algorithm>
 
 template <typename T>
 struct SpoofRowwise {
-	
+
 	static void exec([[maybe_unused]] SpoofCUDAContext* ctx, SpoofOperator* _op, DataBufferWrapper* dbw)  {
 		uint32_t NT=256;
 		T value_type;
@@ -56,7 +54,7 @@ struct SpoofRowwise {
 			CHECK_CUDART(cudaMalloc(reinterpret_cast<void**>(&d_temp), temp_buf_size));
 			CHECK_CUDART(cudaMemsetAsync(d_temp, 0, temp_buf_size, op->stream));
 		}
-		
+
 		std::string op_name(op->name + "_DENSE");
 		if(sparse_input)
 			op_name = std::string(op->name + "_SPARSE");
@@ -77,5 +75,3 @@ struct SpoofRowwise {
 			CHECK_CUDART(cudaFree(d_temp));
 	}
 };
-
-#endif //SYSTEMDS_SPOOFROWWISE_H
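
SpoofRowwise.h loses its redundant include guard and some stray whitespace; functionally it still zero-initializes the temporary-vector buffer it allocates (when one is needed) and then binds to the dense or sparse specialization of the generated kernel by name. The selection amounts to the following, a sketch equivalent to the two statements in the hunk above:

std::string op_name = op->name + (sparse_input ? "_SPARSE" : "_DENSE");
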
diff --git a/src/main/cuda/spoof-launcher/jni_bridge.cpp b/src/main/cuda/spoof-launcher/jni_bridge.cpp
index 5134d5e292..65f4a5a19f 100644
--- a/src/main/cuda/spoof-launcher/jni_bridge.cpp
+++ b/src/main/cuda/spoof-launcher/jni_bridge.cpp
@@ -165,7 +165,7 @@ int launch_spoof_operator([[maybe_unused]] JNIEnv *jenv, [[maybe_unused]] jclass
 
 [[maybe_unused]] JNIEXPORT jint JNICALL Java_org_apache_sysds_runtime_codegen_SpoofCUDACellwise_execute_1f
 	(JNIEnv *jenv, jclass jobj, jlong ctx) {
-	return launch_spoof_operator<double, SpoofCellwise<double>>(jenv, jobj, ctx);
+	return launch_spoof_operator<float, SpoofCellwise<float>>(jenv, jobj, ctx);
 }
 
 
@@ -177,5 +177,5 @@ int launch_spoof_operator([[maybe_unused]] JNIEnv *jenv, [[maybe_unused]] jclass
 
 [[maybe_unused]] JNIEXPORT jint JNICALL Java_org_apache_sysds_runtime_codegen_SpoofCUDARowwise_execute_1f
 	(JNIEnv *jenv, jclass jobj, jlong ctx) {
-	return launch_spoof_operator<double, SpoofRowwise<double>>(jenv, jobj, ctx);
+	return launch_spoof_operator<float, SpoofRowwise<float>>(jenv, jobj, ctx);
 }
\ No newline at end of file
diff --git a/src/main/cuda/spoof/rowwise.cu b/src/main/cuda/spoof/rowwise.cu
index b31ce0c2ce..917b8e7087 100644
--- a/src/main/cuda/spoof/rowwise.cu
+++ b/src/main/cuda/spoof/rowwise.cu
@@ -48,15 +48,14 @@ struct SpoofRowwiseOp //%HAS_TEMP_VECT%
 		a.init(A);
 		c.init(C);
 		
-		if(B)
-			for(auto i = 0; i < NUM_B; ++i)
-				b[i].init(&(B[i]));
+		if(B) {
+		    for(auto i = 0; i < NUM_B; ++i)
+		        b[i].init(&(B[i]));
+		}
 	}
 
 	__device__  __forceinline__ void exec_dense(uint32_t ai, uint32_t ci, uint32_t rix) {
 //%BODY_dense%
-		if (debug_row() && debug_thread())
-			printf("c[0]=%4.3f\n", c.vals(0)[0]);
 	}
 
 	__device__  __forceinline__ void exec_sparse(uint32_t ai, uint32_t ci, uint32_t rix) {
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java
index 88844d6fbb..13c8e10fca 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java
@@ -56,9 +56,12 @@ public class CNodeRow extends CNodeTpl
 	private static final String TEMPLATE_NOAGG_OUT   = "    LibSpoofPrimitives.vectWrite(%IN%, c, ci, %LEN%);\n";
 	private static final String TEMPLATE_NOAGG_CONST_OUT_CUDA   = "\t\tvectWrite(%IN%, c.vals(0), 0, ci, %LEN%);\n";
 	private static final String TEMPLATE_NOAGG_OUT_CUDA   = "\t\tvectWrite(%IN%, c.vals(0), 0, ci, %LEN%);\n";
-	private static final String TEMPLATE_ROWAGG_OUT_CUDA  = "\t\tif(threadIdx.x == 0){\n\t\t\t*(c.vals(rix)) = %IN%;\n//printf(\"rix=%d TMP7=%f TMP8=%f %IN%=%f\\n\",rix, TMP7, TMP8,%IN%);\n}\n";
+//	private static final String TEMPLATE_ROWAGG_OUT_CUDA  = "\t\tif(threadIdx.x == 0){\n\t\t\t*(c.vals(rix)) = %IN%;\n//printf(\"rix=%d TMP7=%f TMP8=%f %IN%=%f\\n\",rix, TMP7, TMP8,%IN%);\n}\n";
+	private static final String TEMPLATE_ROWAGG_OUT_CUDA  = "\t\tif(threadIdx.x == 0){\n\t\t\t*(c.vals(rix)) = %IN%;\n\t\t}\n";
+//	private static final String TEMPLATE_FULLAGG_OUT_CUDA =
+//		"\t\tif(threadIdx.x == 0) {\n\t\t\tT old = atomicAdd(c.vals(0), %IN%);\n//\t\t\tprintf(\"bid=%d full_agg add %f to %f\\n\",blockIdx.x, %IN%, old);\n\t\t}\n";
 	private static final String TEMPLATE_FULLAGG_OUT_CUDA =
-		"\t\tif(threadIdx.x == 0) {\n\t\t\tT old = atomicAdd(c.vals(0), %IN%);\n//\t\t\tprintf(\"bid=%d full_agg add %f to %f\\n\",blockIdx.x, %IN%, old);\n\t\t}\n";
+		"\t\tif(threadIdx.x == 0) {\n\t\tT old = atomicAdd(c.vals(0), %IN%);\n\t\t}\n";
 
 
 	public CNodeRow(ArrayList<CNode> inputs, CNode output ) {
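
The CUDA row-aggregation and full-aggregation output templates drop their embedded debug printf. After placeholder substitution, the full-aggregation epilogue has thread 0 of each block atomically fold the block-local partial into the single output cell, roughly as follows (TMP3 is a hypothetical stand-in for the substituted %IN%):

		if(threadIdx.x == 0) {
			T old = atomicAdd(c.vals(0), TMP3);
		}

With the printf gone, the returned old value is unused and the assignment could be dropped in a follow-up; also note that atomicAdd on double requires compute capability 6.0 or newer.
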
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Binary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Binary.java
index 6d826b16b4..ec46e1196c 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Binary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Binary.java
@@ -42,267 +42,156 @@ public class Binary extends CodeTemplate
 					"\t\tVector<T>& %TMP% = vectCbindWrite(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, %LEN1%, %LEN2%, this);\n" :
 					"\t\tVector<T>& %TMP% = vectCbindWrite(%IN1%, %IN2%, %POS1%, %POS2%, %LEN1%, %LEN2%, this);\n";
 		}
-		
-		if(isSinglePrecision()) {
-			switch(type) {
-				case DOT_PRODUCT:
-					return sparseLhs ? "	T %TMP% = LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" : "	T %TMP% = LibSpoofPrimitives.dotProduct(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
-				case VECT_MATRIXMULT:
-					return sparseLhs ? "	T[] %TMP% = LibSpoofPrimitives.vectMatrixMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, len);\n" : "	T[] %TMP% = LibSpoofPrimitives.vectMatrixMult(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
-				case VECT_OUTERMULT_ADD:
-					return sparseLhs ? "	LibSpoofPrimitives.vectOuterMultAdd(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : sparseRhs ? "	LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2v%, %OUT%, %POS1%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : "	LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2%, %OUT%, %POS1%, %POS2%, %POSOUT%, %LEN1%, %LEN2%);\n";
 
-				//vector-scalar-add operations
-				case VECT_MULT_ADD:
-				case VECT_DIV_ADD:
-				case VECT_MINUS_ADD:
-				case VECT_PLUS_ADD:
-				case VECT_POW_ADD:
-				case VECT_XOR_ADD:
-				case VECT_MIN_ADD:
-				case VECT_MAX_ADD:
-				case VECT_EQUAL_ADD:
-				case VECT_NOTEQUAL_ADD:
-				case VECT_LESS_ADD:
-				case VECT_LESSEQUAL_ADD:
-				case VECT_GREATER_ADD:
-				case VECT_GREATEREQUAL_ADD:
-				case VECT_CBIND_ADD: {
-					String vectName = type.getVectorPrimitiveName();
-					if(scalarVector)
-						return sparseLhs ? "	LibSpoofPrimitives.vect" + vectName + "Add(%IN1%, %IN2v%, %OUT%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN%);\n" : "	LibSpoofPrimitives.vect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS2%, %POSOUT%, %LEN%);\n";
-					else
-						return sparseLhs ? "	LibSpoofPrimitives.vect" + vectName + "Add(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POSOUT%, alen, %LEN%);\n" : "	LibSpoofPrimitives.vect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS1%, %POSOUT%, %LEN%);\n";
-				}
+		switch(type) {
+			case ROWMAXS_VECTMULT:
+				return sparseLhs ? "\t\tT %TMP% = rowMaxsVectMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" :
+						"\t\tT %TMP% = rowMaxsVectMult(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
+			case DOT_PRODUCT:
+				return sparseLhs ? "\t\tT %TMP% = dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" : "		T %TMP% = dotProduct(%IN1%, %IN2%, %POS1%, static_cast<uint32_t>(%POS2%), %LEN%);\n";
 
-				//vector-scalar operations
-				case VECT_MULT_SCALAR:
-				case VECT_DIV_SCALAR:
-				case VECT_MINUS_SCALAR:
-				case VECT_PLUS_SCALAR:
-				case VECT_POW_SCALAR:
-				case VECT_XOR_SCALAR:
-				case VECT_BITWAND_SCALAR:
-				case VECT_MIN_SCALAR:
-				case VECT_MAX_SCALAR:
-				case VECT_EQUAL_SCALAR:
-				case VECT_NOTEQUAL_SCALAR:
-				case VECT_LESS_SCALAR:
-				case VECT_LESSEQUAL_SCALAR:
-				case VECT_GREATER_SCALAR:
-				case VECT_GREATEREQUAL_SCALAR: {
-					String vectName = type.getVectorPrimitiveName();
-					if(scalarVector)
-						return sparseRhs ? "	T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2v%, %IN2i%, %POS2%, alen, %LEN%);\n" : "	T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2%, %POS2%, %LEN%);\n";
-					else
-						return sparseLhs ? "	T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%);\n" : "	T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2%, %POS1%, %LEN%);\n";
-				}
-				//vector-vector operations
-				case VECT_MULT:
-				case VECT_DIV:
-				case VECT_MINUS:
-				case VECT_PLUS:
-				case VECT_XOR:
-				case VECT_BITWAND:
-				case VECT_BIASADD:
-				case VECT_BIASMULT:
-				case VECT_MIN:
-				case VECT_MAX:
-				case VECT_EQUAL:
-				case VECT_NOTEQUAL:
-				case VECT_LESS:
-				case VECT_LESSEQUAL:
-				case VECT_GREATER:
-				case VECT_GREATEREQUAL: {
-					String vectName = type.getVectorPrimitiveName();
-					return sparseLhs ? "	T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, %LEN%);\n" : sparseRhs ? "	T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2v%, %POS1%, %IN2i%, %POS2%, alen, %LEN%);\n" : "	T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
-				}
+			case VECT_MATRIXMULT:
+				return sparseLhs ? "	T[] %TMP% = vectMatrixMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, len);\n" : "		Vector<T>& %TMP% = vectMatrixMult(%IN1%, %IN2%, %POS1%, static_cast<uint32_t>(%POS2%), %LEN%, this);\n";
+			case VECT_OUTERMULT_ADD:
+				return sparseLhs ? "	LibSpoofPrimitives.vectOuterMultAdd(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : sparseRhs ? "	LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2v%, %OUT%, %POS1%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : "\t\tvectOuterMultAdd(%IN1%, %IN2%, %OUT%, %POS1%, %POS2%, %POSOUT%, %LEN1%, %LEN2%);\n";
 
-				//scalar-scalar operations
-				case MULT:
-					return "	T %TMP% = %IN1% * %IN2%;\n";
-				case DIV:
-					return "	T %TMP% = %IN1% / %IN2%;\n";
-				case PLUS:
-					return "	T %TMP% = %IN1% + %IN2%;\n";
-				case MINUS:
-					return "	T %TMP% = %IN1% - %IN2%;\n";
-				case MODULUS:
-					return "	T %TMP% = modulus(%IN1%, %IN2%);\n";
-				case INTDIV:
-					return "	T %TMP% = intDiv(%IN1%, %IN2%);\n";
-				case LESS:
-					return "	T %TMP% = (%IN1% < %IN2%) ? 1 : 0;\n";
-				case LESSEQUAL:
-					return "	T %TMP% = (%IN1% <= %IN2%) ? 1 : 0;\n";
-				case GREATER:
-					return "	T %TMP% = (%IN1% > %IN2%) ? 1 : 0;\n";
-				case GREATEREQUAL:
-					return "	T %TMP% = (%IN1% >= %IN2%) ? 1 : 0;\n";
-				case EQUAL:
-					return "	T %TMP% = (%IN1% == %IN2%) ? 1 : 0;\n";
-				case NOTEQUAL:
-					return "	T %TMP% = (%IN1% != %IN2%) ? 1 : 0;\n";
-
-				case MIN:
-					return "	T %TMP% = fminf(%IN1%, %IN2%);\n";
-				case MAX:
-					return "	T %TMP% = fmaxf(%IN1%, %IN2%);\n";
-				case LOG:
-					return "	T %TMP% = logf(%IN1%)/Math.log(%IN2%);\n";
-				case LOG_NZ:
-					return "	T %TMP% = (%IN1% == 0) ? 0 : logf(%IN1%) / logf(%IN2%);\n";
-				case POW:
-					return "	T %TMP% = powf(%IN1%, %IN2%);\n";
-				case MINUS1_MULT:
-					return "	T %TMP% = 1 - %IN1% * %IN2%;\n";
-				case MINUS_NZ:
-					return "	T %TMP% = (%IN1% != 0) ? %IN1% - %IN2% : 0;\n";
-				case XOR:
-					return "	T %TMP% = ( (%IN1% != 0) != (%IN2% != 0) ) ? 1.0f : 0.0f;\n";
-				case BITWAND:
-					return "	T %TMP% = bwAnd(%IN1%, %IN2%);\n";
-				case SEQ_RIX:
-					return "	T %TMP% = %IN1% + grix * %IN2%;\n"; //0-based global rix
-
-				default:
-					throw new RuntimeException("Invalid binary type: " + this.toString());
+			//vector-scalar-add operations
+			case VECT_MULT_ADD:
+			case VECT_DIV_ADD:
+			case VECT_MINUS_ADD:
+			case VECT_PLUS_ADD:
+			case VECT_POW_ADD:
+			case VECT_XOR_ADD:
+			case VECT_MIN_ADD:
+			case VECT_MAX_ADD:
+			case VECT_EQUAL_ADD:
+			case VECT_NOTEQUAL_ADD:
+			case VECT_LESS_ADD:
+			case VECT_LESSEQUAL_ADD:
+			case VECT_GREATER_ADD:
+			case VECT_GREATEREQUAL_ADD:
+			case VECT_CBIND_ADD: {
+				String vectName = type.getVectorPrimitiveName();
+				if(scalarVector)
+					return sparseLhs ? "\t\tvect" + vectName + "Add(%IN1%, %IN2v%, %OUT%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN%);\n" : "\t\tvect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS2%, %POSOUT%, %LEN%);\n";
+				else
+					return sparseLhs ? "\t\tvect" + vectName + "Add(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POSOUT%, alen, %LEN%);\n" : "\t\tvect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS1%, static_cast<uint32_t>(%POSOUT%), %LEN%);\n";
 			}
-		}
-		else {
-			switch(type) {
-				case DOT_PRODUCT:
-//					return sparseLhs ? "	T %TMP% = LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" : "	T %TMP% = LibSpoofPrimitives.dotProduct(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
-//					return sparseLhs ? "		T %TMP% = dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" : "		T %TMP% = dotProduct(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n	printf(\"dot=%f, bid=%d, tid=%d\\n\",TMP7,blockIdx.x, threadIdx.x);\n	__syncthreads();\n";
-					return sparseLhs ? "		T %TMP% = dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" : "		T %TMP% = dotProduct(%IN1%, %IN2%, %POS1%, static_cast<uint32_t>(%POS2%), %LEN%);\n";
-				
-				case VECT_MATRIXMULT:
-					return sparseLhs ? "	T[] %TMP% = vectMatrixMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, len);\n" : "		Vector<T>& %TMP% = vectMatrixMult(%IN1%, %IN2%, %POS1%, static_cast<uint32_t>(%POS2%), %LEN%, this);\n";
-				case VECT_OUTERMULT_ADD:
-					return sparseLhs ? "	LibSpoofPrimitives.vectOuterMultAdd(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : sparseRhs ? "	LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2v%, %OUT%, %POS1%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : "\t\tvectOuterMultAdd(%IN1%, %IN2%, %OUT%, %POS1%, %POS2%, %POSOUT%, %LEN1%, %LEN2%);\n";
-
-				//vector-scalar-add operations
-				case VECT_MULT_ADD:
-				case VECT_DIV_ADD:
-				case VECT_MINUS_ADD:
-				case VECT_PLUS_ADD:
-				case VECT_POW_ADD:
-				case VECT_XOR_ADD:
-				case VECT_MIN_ADD:
-				case VECT_MAX_ADD:
-				case VECT_EQUAL_ADD:
-				case VECT_NOTEQUAL_ADD:
-				case VECT_LESS_ADD:
-				case VECT_LESSEQUAL_ADD:
-				case VECT_GREATER_ADD:
-				case VECT_GREATEREQUAL_ADD:
-				case VECT_CBIND_ADD: {
-					String vectName = type.getVectorPrimitiveName();
-					if(scalarVector)
-						return sparseLhs ? "\t\tvect" + vectName + "Add(%IN1%, %IN2v%, %OUT%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN%);\n" : "\t\tvect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS2%, %POSOUT%, %LEN%);\n";
-					else
-						return sparseLhs ? "\t\tvect" + vectName + "Add(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POSOUT%, alen, %LEN%);\n" : "\t\tvect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS1%, static_cast<uint32_t>(%POSOUT%), %LEN%);\n";
-				}
 
-				//vector-scalar operations
-				case VECT_MULT_SCALAR:
-				case VECT_DIV_SCALAR:
-				case VECT_MINUS_SCALAR:
-				case VECT_PLUS_SCALAR:
-				case VECT_POW_SCALAR:
-				case VECT_XOR_SCALAR:
-				case VECT_BITWAND_SCALAR:
-				case VECT_MIN_SCALAR:
-				case VECT_MAX_SCALAR:
-				case VECT_EQUAL_SCALAR:
-				case VECT_NOTEQUAL_SCALAR:
-				case VECT_LESS_SCALAR:
-				case VECT_LESSEQUAL_SCALAR:
-				case VECT_GREATER_SCALAR:
-				case VECT_GREATEREQUAL_SCALAR: {
-					String vectName = type.getVectorPrimitiveName();
-					if(scalarVector)
-						return sparseRhs ? "	T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2v%, %IN2i%, %POS2%, alen, %LEN%);\n" : "		Vector<T>& %TMP% = vect" + vectName + "Write(%IN1%, %IN2%, %POS2%, %LEN%, this);\n";
-					else
+			//vector-scalar operations
+			case VECT_MULT_SCALAR:
+			case VECT_DIV_SCALAR:
+			case VECT_MINUS_SCALAR:
+			case VECT_PLUS_SCALAR:
+			case VECT_POW_SCALAR:
+			case VECT_XOR_SCALAR:
+			case VECT_BITWAND_SCALAR:
+			case VECT_MIN_SCALAR:
+			case VECT_MAX_SCALAR:
+			case VECT_EQUAL_SCALAR:
+			case VECT_NOTEQUAL_SCALAR:
+			case VECT_LESS_SCALAR:
+			case VECT_LESSEQUAL_SCALAR:
+			case VECT_GREATER_SCALAR:
+			case VECT_GREATEREQUAL_SCALAR: {
+				String vectName = type.getVectorPrimitiveName();
+				if(scalarVector)
+					return sparseRhs ? "	T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2v%, %IN2i%, %POS2%, alen, %LEN%);\n" : "		Vector<T>& %TMP% = vect" + vectName + "Write(%IN1%, %IN2%, %POS2%, %LEN%, this);\n";
+				else
 //						return sparseLhs ? "	T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%);\n" : "	T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2%, %POS1%, %LEN%);\n";
-						return sparseLhs ? "		Vector<T>& %TMP% = vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%, this);\n" : "		Vector<T>& %TMP% = vect" + vectName + "Write(%IN1%, %IN2%, static_cast<uint32_t>(%POS1%), %LEN%, this);\n";
-				}
-					
-				//vector-vector operations
-				case VECT_MULT:
-				case VECT_DIV:
-				case VECT_MINUS:
-				case VECT_PLUS:
-				case VECT_XOR:
-				case VECT_BITWAND:
-				case VECT_BIASADD:
-				case VECT_BIASMULT:
-				case VECT_MIN:
-				case VECT_MAX:
-				case VECT_EQUAL:
-				case VECT_NOTEQUAL:
-				case VECT_LESS:
-				case VECT_LESSEQUAL:
-				case VECT_GREATER:
-				case VECT_GREATEREQUAL: {
-					String vectName = type.getVectorPrimitiveName();
-					return sparseLhs ? "		Vector<T>& %TMP% = vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, " +
-							"alen, %LEN%);\n" : sparseRhs ? "		Vector<T>& %TMP% = vect" + vectName + "Write(%IN1%, %IN2v%, " +
-							"%POS1%, %IN2i%, %POS2%, alen, %LEN%);\n" : "		Vector<T>& %TMP% = vect" + vectName + 
-						"Write(%IN1%, %IN2%, static_cast<uint32_t>(%POS1%), static_cast<uint32_t>(%POS2%), %LEN%, this);\n";
-				}
+					return sparseLhs ? "		Vector<T>& %TMP% = vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%, this);\n" : "		Vector<T>& %TMP% = vect" + vectName + "Write(%IN1%, %IN2%, static_cast<uint32_t>(%POS1%), %LEN%, this);\n";
+			}
 
-				//scalar-scalar operations
-				case MULT:
-					return "		T %TMP% = %IN1% * %IN2%;\n";
-				case DIV:
-					return "	T %TMP% = %IN1% / %IN2%;\n";
-				case PLUS:
-					return "		T %TMP% = %IN1% + %IN2%;\n";
-				case MINUS:
-					return "	T %TMP% = %IN1% - %IN2%;\n";
-				case MODULUS:
-					return "	T %TMP% = modulus(%IN1%, %IN2%);\n";
-				case INTDIV:
-					return "	T %TMP% = intDiv(%IN1%, %IN2%);\n";
-				case LESS:
-					return "	T %TMP% = (%IN1% < %IN2%) ? 1.0 : 0.0;\n";
-				case LESSEQUAL:
-					return "	T %TMP% = (%IN1% <= %IN2%) ? 1.0 : 0.0;\n";
-				case GREATER:
-					return "	T %TMP% = (%IN1% > (%IN2% + EPSILON)) ? 1.0 : 0.0;\n";
-				case GREATEREQUAL:
-					return "	T %TMP% = (%IN1% >= %IN2%) ? 1.0 : 0.0;\n";
-				case EQUAL:
-					return "	T %TMP% = (%IN1% == %IN2%) ? 1.0 : 0.0;\n";
-				case NOTEQUAL:
-					return "	T %TMP% = (%IN1% != %IN2%) ? 1.0 : 0.0;\n";
+			//vector-vector operations
+			case VECT_MULT:
+			case VECT_DIV:
+			case VECT_MINUS:
+			case VECT_PLUS:
+			case VECT_XOR:
+			case VECT_BITWAND:
+			case VECT_BIASADD:
+			case VECT_BIASMULT:
+			case VECT_MIN:
+			case VECT_MAX:
+			case VECT_EQUAL:
+			case VECT_NOTEQUAL:
+			case VECT_LESS:
+			case VECT_LESSEQUAL:
+			case VECT_GREATER:
+			case VECT_GREATEREQUAL: {
+				String vectName = type.getVectorPrimitiveName();
+				return sparseLhs ? "		Vector<T>& %TMP% = vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, " +
+					   "%POS1%, %POS2%, alen, %LEN%, this);\n" :
+					   sparseRhs ? "		Vector<T>& %TMP% = vect" + vectName + "Write(%IN1%, %IN2v%, %POS1%, " +
+					   "%IN2i%, %POS2%, alen, %LEN%);\n" :
+					   "		Vector<T>& %TMP% = vect" + vectName + "Write(%IN1%, %IN2%, " +
+					   "static_cast<uint32_t>(%POS1%), static_cast<uint32_t>(%POS2%), %LEN%, this);\n";
+			}
 
-				case MIN:
-					return "	T %TMP% = min(%IN1%, %IN2%);\n";
-				case MAX:
-					return "	T %TMP% = max(%IN1%, %IN2%);\n";
-				case LOG:
-					return "	T %TMP% = log(%IN1%)/Math.log(%IN2%);\n";
-				case LOG_NZ:
-					return "	T %TMP% = (%IN1% == 0) ? 0 : log(%IN1%) / log(%IN2%);\n";
-				case POW:
-					return "	T %TMP% = pow(%IN1%, %IN2%);\n";
-				case MINUS1_MULT:
-					return "	T %TMP% = 1 - %IN1% * %IN2%;\n";
-				case MINUS_NZ:
-					return "	T %TMP% = (%IN1% != 0) ? %IN1% - %IN2% : 0;\n";
-				case XOR:
+			//scalar-scalar operations
+			case MULT:
+				return "		T %TMP% = %IN1% * %IN2%;\n";
+			case DIV:
+				return "\t\tT %TMP% = %IN1% / %IN2%;\n";
+			case PLUS:
+				return "\t\tT %TMP% = %IN1% + %IN2%;\n";
+			case MINUS:
+				return "	T %TMP% = %IN1% - %IN2%;\n";
+			case MODULUS:
+				return "	T %TMP% = modulus(%IN1%, %IN2%);\n";
+			case INTDIV:
+				return "	T %TMP% = intDiv(%IN1%, %IN2%);\n";
+			case LESS:
+				return "	T %TMP% = (%IN1% < %IN2%) ? 1.0 : 0.0;\n";
+			case LESSEQUAL:
+				return "	T %TMP% = (%IN1% <= %IN2%) ? 1.0 : 0.0;\n";
+			case GREATER:
+				return "	T %TMP% = (%IN1% > (%IN2% + EPSILON)) ? 1.0 : 0.0;\n";
+			case GREATEREQUAL:
+				return "	T %TMP% = (%IN1% >= %IN2%) ? 1.0 : 0.0;\n";
+			case EQUAL:
+				return "	T %TMP% = (%IN1% == %IN2%) ? 1.0 : 0.0;\n";
+			case NOTEQUAL:
+				return "\t\tT %TMP% = (%IN1% != %IN2%) ? 1.0 : 0.0;\n";
+			case MIN:
+				if(isSinglePrecision())
+					return "\t\tT %TMP% = fminf(%IN1%, %IN2%);\n";
+				else
+					return "\t\tT %TMP% = min(%IN1%, %IN2%);\n";
+			case MAX:
+				if(isSinglePrecision())
+					return "\t\tT %TMP% = fmaxf(%IN1%, %IN2%);\n";
+				else
+					return "\t\tT %TMP% = max(%IN1%, %IN2%);\n";
+			case LOG:
+				if(isSinglePrecision())
+					return "\t\tT %TMP% = logf(%IN1%) / logf(%IN2%);\n";
+				else
+					return "\t\tT %TMP% = log(%IN1%) / log(%IN2%);\n";
+			case LOG_NZ:
+				if(isSinglePrecision())
+					return "\t\tT %TMP% = (%IN1% == 0) ? 0 : logf(%IN1%) / logf(%IN2%);\n";
+				else
+					return "\t\tT %TMP% = (%IN1% == 0) ? 0 : log(%IN1%) / log(%IN2%);\n";
+			case POW:
+				if(isSinglePrecision())
+					return "\t\tT %TMP% = powf(%IN1%, %IN2%);\n";
+				else
+					return "\t\tT %TMP% = pow(%IN1%, %IN2%);\n";
+			case MINUS1_MULT:
+				return "	T %TMP% = 1 - %IN1% * %IN2%;\n";
+			case MINUS_NZ:
+				return "	T %TMP% = (%IN1% != 0) ? %IN1% - %IN2% : 0;\n";
+			case XOR:
 //					return "	T %TMP% = ( (%IN1% != 0.0) != (%IN2% != 0.0) ) ? 1.0 : 0.0;\n";
-					return "	T %TMP% = ( (%IN1% < EPSILON) != (%IN2% < EPSILON) ) ? 1.0 : 0.0;\n";
-				case BITWAND:
-					return "	T %TMP% = bwAnd(%IN1%, %IN2%);\n";
-				case SEQ_RIX:
-					return "		T %TMP% = %IN1% + grix * %IN2%;\n"; //0-based global rix
+				return "	T %TMP% = ( (%IN1% < EPSILON) != (%IN2% < EPSILON) ) ? 1.0 : 0.0;\n";
+			case BITWAND:
+				return "	T %TMP% = bwAnd(%IN1%, %IN2%);\n";
+			case SEQ_RIX:
+				return "\t\tT %TMP% = %IN1% + grix * %IN2%;\n"; //0-based global rix
 
-				default:
-					throw new RuntimeException("Invalid binary type: " + this.toString());
-			}
+			default:
+				throw new RuntimeException("Invalid binary type: " + this.toString());
 		}
 	}
 }
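
Two things happen in Binary.java: the largely duplicated single- and double-precision switches are merged, with only the scalar MIN, MAX, LOG, LOG_NZ and POW cases still forking on isSinglePrecision() to pick the matching intrinsic (fminf/fmaxf/logf/powf vs. min/max/log/pow), and a new ROWMAXS_VECTMULT case is added. The latter maps to the rowMaxsVectMult primitive introduced in spoof_utils.cuh, which reduces the elementwise products a[ai+i] * b[i] with max; this backs the kind of fused rowMaxs(G * t(c)) pattern produced by components.dml (note the dense variant currently ignores its bi argument and reads b from offset 0). Illustrative expansion of the dense template, where TMP5, a, b, ai, bi and len stand in for whatever the generator substitutes:

		T TMP5 = rowMaxsVectMult(a, b, ai, bi, len);
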
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Ternary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Ternary.java
index dd06d6c004..026fe264f8 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Ternary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Ternary.java
@@ -28,81 +28,41 @@ public class Ternary extends CodeTemplate {
 
 	@Override
 	public String getTemplate(CNodeTernary.TernaryType type, boolean sparse) {
-		if(isSinglePrecision()) {
-			switch (type) {
-				case PLUS_MULT:
-					return "	T %TMP% = %IN1% + %IN2% * %IN3%;\n";
+		switch (type) {
+			case PLUS_MULT:
+				return "	T %TMP% = %IN1% + %IN2% * %IN3%;\n";
 
-				case MINUS_MULT:
-					return "	T %TMP% = %IN1% - %IN2% * %IN3%;\n";
+			case MINUS_MULT:
+				return "	T %TMP% = %IN1% - %IN2% * %IN3%;\n";
 
-				case BIASADD:
-					return "	T %TMP% = %IN1% + getValue(%IN2%, cix/%IN3%);\n";
+			case BIASADD:
+				return "	T %TMP% = %IN1% + getValue(%IN2%, cix/%IN3%);\n";
 
-				case BIASMULT:
-					return "	T %TMP% = %IN1% * getValue(%IN2%, cix/%IN3%);\n";
+			case BIASMULT:
+				return "	T %TMP% = %IN1% * getValue(%IN2%, cix/%IN3%);\n";
 
-				case REPLACE:
-					return "	T %TMP% = (%IN1% == %IN2% || (isnan(%IN1%) "
-							+ "&& isnan(%IN2%))) ? %IN3% : %IN1%;\n";
+			case REPLACE:
+				return "	T %TMP% = (%IN1% == %IN2% || (isnan(%IN1%) "
+						+ "&& isnan(%IN2%))) ? %IN3% : %IN1%;\n";
 
-				case REPLACE_NAN:
-					return "	T %TMP% = isnan(%IN1%) ? %IN3% : %IN1%;\n";
+			case REPLACE_NAN:
+				return "	T %TMP% = isnan(%IN1%) ? %IN3% : %IN1%;\n";
 
-				case IFELSE:
-					return "	T %TMP% = (%IN1% != 0) ? %IN2% : %IN3%;\n";
+			case IFELSE:
+				return "	T %TMP% = (%IN1% != 0) ? %IN2% : %IN3%;\n";
 
-				case LOOKUP_RC1:
-					return sparse ?
-							"	T %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, %IN3%-1);\n" :
+			case LOOKUP_RC1:
+				return sparse ?
+						"	T %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, %IN3%-1);\n" :
 //							"	T %TMP% = getValue(%IN1%, %IN2%, rix, %IN3%-1);\n";
-							"		T %TMP% = %IN1%.val(rix, %IN3%-1);\n";
+						"		T %TMP% = %IN1%.val(rix, %IN3%-1);\n";
 
-				case LOOKUP_RVECT1:
-					return "\t\tVector<T>& %TMP% = getVector(%IN1%, %IN2%, rix, %IN3%-1);\n";
 
-				default:
-					throw new RuntimeException("Invalid ternary type: " + this.toString());
-			}
-		}
-		else {
-			switch (type) {
-				case PLUS_MULT:
-					return "	T %TMP% = %IN1% + %IN2% * %IN3%;\n";
-
-				case MINUS_MULT:
-					return "	T %TMP% = %IN1% - %IN2% * %IN3%;\n";
-
-				case BIASADD:
-					return "	T %TMP% = %IN1% + getValue(%IN2%, cix/%IN3%);\n";
-
-				case BIASMULT:
-					return "	T %TMP% = %IN1% * getValue(%IN2%, cix/%IN3%);\n";
-
-				case REPLACE:
-					return "	T %TMP% = (%IN1% == %IN2% || (isnan(%IN1%) "
-							+ "&& isnan(%IN2%))) ? %IN3% : %IN1%;\n";
-
-				case REPLACE_NAN:
-					return "	T %TMP% = isnan(%IN1%) ? %IN3% : %IN1%;\n";
-
-				case IFELSE:
-					return "	T %TMP% = (%IN1% != 0) ? %IN2% : %IN3%;\n";
-
-				case LOOKUP_RC1:
-					return sparse ?
-							"	T %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, %IN3%-1);\n" :
-//							"	T %TMP% = getValue(%IN1%, %IN2%, rix, %IN3%-1);\n";
-							"		T %TMP% = %IN1%.val(rix, %IN3%-1);\n";
-				
-				
-				case LOOKUP_RVECT1:
-					return "\t\tVector<T>& %TMP% = getVector(%IN1%, %IN2%, rix, %IN3%-1, this);\n";
-
-				default:
-					throw new RuntimeException("Invalid ternary type: "+this.toString());
-			}
+			case LOOKUP_RVECT1:
+				return "\t\tVector<T>& %TMP% = getVector(%IN1%, %IN2%, rix, %IN3%-1, this);\n";
 
+			default:
+				throw new RuntimeException("Invalid ternary type: "+this.toString());
 		}
 	}
 }
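
The single- and double-precision switches in Ternary.java were identical except for LOOKUP_RVECT1, so they collapse into one. The surviving LOOKUP_RVECT1 template is the variant that passes `this` (the operator's TempStorage), so getVector presumably obtains its output vector from the preallocated temp buffer. Illustrative expansion with hypothetical substitutions (a is the row's matrix accessor, n the column count, 7 a compiler-produced column index):

		Vector<T>& TMP2 = getVector(a, n, rix, 7-1, this);
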
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Unary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Unary.java
index f2405d5b5c..405b880715 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Unary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Unary.java
@@ -29,210 +29,161 @@ public class Unary extends CodeTemplate {
 
 	@Override
 	public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) {
-		if(isSinglePrecision()) {
-			switch( type ) {
-				case ROW_SUMS:
-				case ROW_SUMSQS:
-				case ROW_MINS:
-				case ROW_MAXS:
-				case ROW_MEANS:
-				case ROW_COUNTNNZS: {
-					String vectName = StringUtils.capitalize(type.name().substring(4, type.name().length()-1).toLowerCase());
-					return sparse ? "	T %TMP% = LibSpoofPrimitives.vect"+vectName+"(%IN1v%, %IN1i%, %POS1%, alen, len);\n":
-						"	T %TMP% = LibSpoofPrimitives.vect"+vectName+"(%IN1%, %POS1%, %LEN%);\n";
-				}
+		switch( type ) {
+			case ROW_SUMS:
+			case ROW_SUMSQS:
+			case ROW_MINS:
+			case ROW_MAXS:
+			case ROW_MEANS:
+			case ROW_COUNTNNZS: {
+				String vectName = StringUtils.capitalize(type.name().substring(4, type.name().length()-1).toLowerCase());
+				return sparse ? "		T %TMP% = vect"+vectName+"(%IN1v%, %IN1i%, %POS1%, alen, %LEN%);\n":
+					"		T %TMP% = vect"+vectName+"(%IN1%, static_cast<uint32_t>(%POS1%), %LEN%);\n";
 
-				case VECT_EXP:
-				case VECT_POW2:
-				case VECT_MULT2:
-				case VECT_SQRT:
-				case VECT_LOG:
-				case VECT_ABS:
-				case VECT_ROUND:
-				case VECT_CEIL:
-				case VECT_FLOOR:
-				case VECT_SIGN:
-				case VECT_SIN:
-				case VECT_COS:
-				case VECT_TAN:
-				case VECT_ASIN:
-				case VECT_ACOS:
-				case VECT_ATAN:
-				case VECT_SINH:
-				case VECT_COSH:
-				case VECT_TANH:
-				case VECT_CUMSUM:
-				case VECT_CUMMIN:
-				case VECT_CUMMAX:
-				case VECT_SPROP:
-				case VECT_SIGMOID: {
-					String vectName = type.getVectorPrimitiveName();
-					return sparse ? "	T[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN1i%, %POS1%, alen, len);\n" :
-						"	T[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %POS1%, %LEN%);\n";
-				}
-
-				case EXP:
-					return "	T %TMP% = expf(%IN1%);\n";
-				case LOOKUP_R:
-					return sparse ?
-						"	T %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, 0);\n" :
-						"	T %TMP% = getValue(%IN1%, rix);\n";
-				case LOOKUP_C:
-					return "	T %TMP% = getValue(%IN1%, n, 0, cix);\n";
-				case LOOKUP_RC:
-					return "	T %TMP% = getValue(%IN1%, n, rix, cix);\n";
-				case LOOKUP0:
-					return "	T %TMP% = %IN1%[0];\n";
-				case POW2:
-					return "	T %TMP% = %IN1% * %IN1%;\n";
-				case MULT2:
-					return "	T %TMP% = %IN1% + %IN1%;\n";
-				case ABS:
-					return "	T %TMP% = fabsf(%IN1%);\n";
-				case SIN:
-					return "	T %TMP% = sinf(%IN1%);\n";
-				case COS:
-					return "	T %TMP% = cosf(%IN1%);\n";
-				case TAN:
-					return "	T %TMP% = tanf(%IN1%);\n";
-				case ASIN:
-					return "	T %TMP% = asinf(%IN1%);\n";
-				case ACOS:
-					return "	T %TMP% = acosf(%IN1%);\n";
-				case ATAN:
-					return "	T %TMP% = atanf(%IN1%);\n";
-				case SINH:
-					return "	T %TMP% = sinhf(%IN1%);\n";
-				case COSH:
-					return "	T %TMP% = coshf(%IN1%);\n";
-				case TANH:
-					return "	T %TMP% = tanhf(%IN1%);\n";
-				case SIGN:
-					return "	T %TMP% = signbit(%IN1%) == 0 ? 1.0f : -1.0f;\n";
-				case SQRT:
-					return "	T %TMP% = sqrtf(%IN1%);\n";
-				case LOG:
-					return "	T %TMP% = logf(%IN1%);\n";
-				case ROUND:
-					return "	T %TMP% = roundf(%IN1%);\n";
-				case CEIL:
-					return "	T %TMP% = ceilf(%IN1%);\n";
-				case FLOOR:
-					return "	T %TMP% = floorf(%IN1%);\n";
-				case SPROP:
-					return "	T %TMP% = %IN1% * (1 - %IN1%);\n";
-				case SIGMOID:
-					return "	T %TMP% = 1 / (1 + expf(-%IN1%));\n";
-				case LOG_NZ:
-					return "	T %TMP% = (%IN1%==0) ? 0 : logf(%IN1%);\n";
-
-				default:
-					throw new RuntimeException("Invalid unary type: "+this.toString());
 			}
-		}
-		else { /* double precision */
-			switch( type ) {
-				case ROW_SUMS:
-				case ROW_SUMSQS:
-				case ROW_MINS:
-				case ROW_MAXS:
-				case ROW_MEANS:
-				case ROW_COUNTNNZS: {
-					String vectName = StringUtils.capitalize(type.name().substring(4, type.name().length()-1).toLowerCase());
-					return sparse ? "		T %TMP% = vect"+vectName+"(%IN1v%, %IN1i%, %POS1%, alen, %LEN%);\n":
-						"		T %TMP% = vect"+vectName+"(%IN1%, static_cast<uint32_t>(%POS1%), %LEN%);\n";
-					
-				}
 
-				case VECT_EXP:
-				case VECT_POW2:
-				case VECT_MULT2:
-				case VECT_SQRT:
-				case VECT_LOG:
-				case VECT_ABS:
-				case VECT_ROUND:
-				case VECT_CEIL:
-				case VECT_FLOOR:
-				case VECT_SIGN:
-				case VECT_SIN:
-				case VECT_COS:
-				case VECT_TAN:
-				case VECT_ASIN:
-				case VECT_ACOS:
-				case VECT_ATAN:
-				case VECT_SINH:
-				case VECT_COSH:
-				case VECT_TANH:
-				case VECT_CUMSUM:
-				case VECT_CUMMIN:
-				case VECT_CUMMAX:
-				case VECT_SPROP:
-				case VECT_SIGMOID: {
-					String vectName = type.getVectorPrimitiveName();
-					return sparse ? "		Vector<T>& %TMP% = vect"+vectName+"Write(%IN1v%, %IN1i%, %POS1%, alen, %LEN%, this);\n" :
-						"		Vector<T>& %TMP% = vect"+vectName+"Write(%IN1%, static_cast<uint32_t>(%POS1%), %LEN%, this);\n";
-				}
+			case VECT_EXP:
+			case VECT_POW2:
+			case VECT_MULT2:
+			case VECT_SQRT:
+			case VECT_LOG:
+			case VECT_ABS:
+			case VECT_ROUND:
+			case VECT_CEIL:
+			case VECT_FLOOR:
+			case VECT_SIGN:
+			case VECT_SIN:
+			case VECT_COS:
+			case VECT_TAN:
+			case VECT_ASIN:
+			case VECT_ACOS:
+			case VECT_ATAN:
+			case VECT_SINH:
+			case VECT_COSH:
+			case VECT_TANH:
+			case VECT_CUMSUM:
+			case VECT_CUMMIN:
+			case VECT_CUMMAX:
+			case VECT_SPROP:
+			case VECT_SIGMOID: {
+				String vectName = type.getVectorPrimitiveName();
+				return sparse ? "		Vector<T>& %TMP% = vect"+vectName+"Write(%IN1v%, %IN1i%, %POS1%, alen, %LEN%, this);\n" :
+					"		Vector<T>& %TMP% = vect"+vectName+"Write(%IN1%, static_cast<uint32_t>(%POS1%), %LEN%, this);\n";
+			}
 
-				case EXP:
+			case EXP:
+				if(isSinglePrecision())
+					return "	T %TMP% = expf(%IN1%);\n";
+				else
 					return "	T %TMP% = exp(%IN1%);\n";
-				case LOOKUP_R:
-					return sparse ?
-						"	T %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, 0);\n" :
-						"		T %TMP% = %IN1%.val(rix);\n";
-//						"	T %TMP% = getValue(%IN1%, rix);\n";
-				case LOOKUP_C:
-					return "	T %TMP% = getValue(%IN1%, n, 0, cix);\n";
-				case LOOKUP_RC:
-					return "	T %TMP% = getValue(%IN1%, n, rix, cix);\n";
-				case LOOKUP0:
-					return "	T %TMP% = %IN1%[0];\n";
-				case POW2:
-					return "	T %TMP% = %IN1% * %IN1%;\n";
-				case MULT2:
-					return "	T %TMP% = %IN1% + %IN1%;\n";
-				case ABS:
+			case LOOKUP_R:
+				return sparse ?
+					"\t\tT %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, 0);\n" :
+//						"		T %TMP% = %IN1%.val(rix);\n";
+					"\t\tT %TMP% = getValue(%IN1%, rix);\n";
+			case LOOKUP_C:
+				return "\t\tT %TMP% = getValue(%IN1%, n, 0, cix);\n";
+			case LOOKUP_RC:
+				return "\t\tT %TMP% = getValue(%IN1%, n, rix, cix);\n";
+			case LOOKUP0:
+				return "\t\tT %TMP% = %IN1%[0];\n";
+			case POW2:
+				return "	T %TMP% = %IN1% * %IN1%;\n";
+			case MULT2:
+				return "	T %TMP% = %IN1% + %IN1%;\n";
+			case ABS:
+				if(isSinglePrecision())
+					return "	T %TMP% = fabsf(%IN1%);\n";
+				else
 					return "\t\tT %TMP% = fabs(%IN1%);\n";
-				case SIN:
+			case SIN:
+				if(isSinglePrecision())
+					return "	T %TMP% = sinf(%IN1%);\n";
+				else
 					return "	T %TMP% = sin(%IN1%);\n";
-				case COS:
+			case COS:
+				if(isSinglePrecision())
+					return "	T %TMP% = cosf(%IN1%);\n";
+				else
 					return "	T %TMP% = cos(%IN1%);\n";
-				case TAN:
+			case TAN:
+				if(isSinglePrecision())
+					return "	T %TMP% = tanf(%IN1%);\n";
+				else
 					return "	T %TMP% = tan(%IN1%);\n";
-				case ASIN:
+			case ASIN:
+				if(isSinglePrecision())
+					return "	T %TMP% = asinf(%IN1%);\n";
+				else
 					return "	T %TMP% = asin(%IN1%);\n";
-				case ACOS:
+			case ACOS:
+				if(isSinglePrecision())
+					return "	T %TMP% = acosf(%IN1%);\n";
+				else
 					return "	T %TMP% = acos(%IN1%);\n";
-				case ATAN:
+			case ATAN:
+				if(isSinglePrecision())
+					return "	T %TMP% = atanf(%IN1%);\n";
+				else
 					return "	T %TMP% = atan(%IN1%);\n";
-				case SINH:
+			case SINH:
+				if(isSinglePrecision())
+					return "	T %TMP% = sinhf(%IN1%);\n";
+				else
 					return "	T %TMP% = sinh(%IN1%);\n";
-				case COSH:
+			case COSH:
+				if(isSinglePrecision())
+					return "	T %TMP% = coshf(%IN1%);\n";
+				else
 					return "	T %TMP% = cosh(%IN1%);\n";
-				case TANH:
+			case TANH:
+				if(isSinglePrecision())
+					return "	T %TMP% = tanhf(%IN1%);\n";
+				else
 					return "	T %TMP% = tanh(%IN1%);\n";
-				case SIGN:
-					return "	T %TMP% = signbit(%IN1%) == 0 ? 1.0 : -1.0;\n";
-				case SQRT:
+			case SIGN:
+				return "	T %TMP% = signbit(%IN1%) == 0 ? 1.0 : -1.0;\n";
+			case SQRT:
+				if(isSinglePrecision())
+					return "	T %TMP% = sqrtf(%IN1%);\n";
+				else
 					return "	T %TMP% = sqrt(%IN1%);\n";
-				case LOG:
+			case LOG:
+
+				if(isSinglePrecision())
+					return "	T %TMP% = logf(%IN1%);\n";
+				else
 					return "		T %TMP% = log(%IN1%);\n";
-				case ROUND:
+			case ROUND:
+				if(isSinglePrecision())
+					return "	T %TMP% = roundf(%IN1%);\n";
+				else
 					return "\t\tT %TMP% = round(%IN1%);\n";
-				case CEIL:
+			case CEIL:
+				if(isSinglePrecision())
+					return "	T %TMP% = ceilf(%IN1%);\n";
+				else
 					return "	T %TMP% = ceil(%IN1%);\n";
-				case FLOOR:
+			case FLOOR:
+				if(isSinglePrecision())
+					return "	T %TMP% = floorf(%IN1%);\n";
+				else
 					return "	T %TMP% = floor(%IN1%);\n";
-				case SPROP:
-					return "	T %TMP% = %IN1% * (1 - %IN1%);\n";
-				case SIGMOID:
+			case SPROP:
+				return "	T %TMP% = %IN1% * (1 - %IN1%);\n";
+			case SIGMOID:
+				if(isSinglePrecision())
+					return "	T %TMP% = 1 / (1 + expf(-%IN1%));\n";
+				else
 					return "	T %TMP% = 1 / (1 + exp(-%IN1%));\n";
-				case LOG_NZ:
+			case LOG_NZ:
+				if(isSinglePrecision())
+					return "	T %TMP% = (%IN1%==0) ? 0 : logf(%IN1%);\n";
+				else
 					return "	T %TMP% = (%IN1%==0) ? 0 : log(%IN1%);\n";
 
-				default:
-					throw new RuntimeException("Invalid unary type: "+this.toString());
-			}
-
+			default:
+				throw new RuntimeException("Invalid unary type: "+this.toString());
 		}
 	}
 }
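
Unary.java gets the same deduplication as Binary.java: the vector primitives and lookup templates are precision-agnostic, and only the scalar math calls choose between the float intrinsic and the double function at code-generation time. The one behavioral change worth noting is that the dense LOOKUP_R template goes back to getValue(%IN1%, rix) instead of %IN1%.val(rix). For example, the SIGMOID case now emits one of the following (hypothetical temporary and input names):

		// single precision (T == float):  T TMP4 = 1 / (1 + expf(-TMP1));
		// double precision (T == double): T TMP4 = 1 / (1 + exp(-TMP1));
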
diff --git a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDACellwise.java b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDACellwise.java
index 03c35da540..cfe5780326 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDACellwise.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDACellwise.java
@@ -116,9 +116,9 @@ public class SpoofCUDACellwise extends SpoofCellwise implements SpoofCUDAOperato
 	}
 
 	public int execute_dp(long ctx) { return execute_d(ctx); }
-	public int execute_sp(long ctx) { return execute_d(ctx); }
+	public int execute_sp(long ctx) { return execute_f(ctx); }
 	public long getContext() { return ctx; }
 
 	public static native int execute_d(long ctx);
-	public static native int execute_s(long ctx);
+	public static native int execute_f(long ctx);
 }
diff --git a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDARowwise.java b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDARowwise.java
index 47826a9461..0adf2ec605 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDARowwise.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDARowwise.java
@@ -96,9 +96,9 @@ public class SpoofCUDARowwise extends SpoofRowwise implements SpoofCUDAOperator
 		int ci, int alen, int n, long grix, int rix) { }
 
 	public int execute_dp(long ctx) { return execute_d(ctx); }
-	public int execute_sp(long ctx) { return execute_d(ctx); }
+	public int execute_sp(long ctx) { return execute_f(ctx); }
 	public long getContext() { return ctx; }
 
 	public static native int execute_d(long ctx);
-	public static native int execute_s(long ctx);
+	public static native int execute_f(long ctx);
 }
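
Together with the jni_bridge.cpp change, this closes the single-precision dispatch path: previously execute_sp forwarded to the double entry point and the native *_execute_1f bindings instantiated the double templates, so a single-precision call ended up in the double-precision launcher. Renaming the Java native from execute_s to execute_f also makes it match the symbol the bridge actually exports, since JNI escapes the underscore in execute_f as "_1". The resulting chain, sketched as comments:

// Single-precision call chain after this commit (the double path via execute_d is analogous):
//   SpoofCUDARowwise.execute_sp(ctx)
//     -> execute_f(ctx)                                                       (Java native declaration)
//     -> Java_org_apache_sysds_runtime_codegen_SpoofCUDARowwise_execute_1f    (jni_bridge.cpp)
//     -> launch_spoof_operator<float, SpoofRowwise<float>>(jenv, jobj, ctx)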