You are viewing a plain-text version of this content. The canonical link to it was a hyperlink in the original page and is not preserved in this copy.
Posted to commits@systemds.apache.org by ma...@apache.org on 2022/04/20 12:17:55 UTC
[systemds] 01/02: [SYSTEMDS-3352] CUDA code gen support for connected components
This is an automated email from the ASF dual-hosted git repository.
markd pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
commit fc5b03de84ebd57214a69ca43f63f223dd258c89
Author: Mark Dokter <ma...@dokter.cc>
AuthorDate: Wed Apr 20 13:06:05 2022 +0200
[SYSTEMDS-3352] CUDA code gen support for connected components
General cleanup and bug fixes so that components.dml runs.
Also improves support for single-precision execution.
---
src/main/cuda/headers/Matrix.h | 52 ++-
src/main/cuda/headers/spoof_utils.cuh | 215 +++++------
src/main/cuda/headers/vector_write.cuh | 20 +-
src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp | 23 +-
src/main/cuda/spoof-launcher/SpoofRowwise.h | 8 +-
src/main/cuda/spoof-launcher/jni_bridge.cpp | 4 +-
src/main/cuda/spoof/rowwise.cu | 9 +-
.../apache/sysds/hops/codegen/cplan/CNodeRow.java | 7 +-
.../sysds/hops/codegen/cplan/cuda/Binary.java | 391 ++++++++-------------
.../sysds/hops/codegen/cplan/cuda/Ternary.java | 88 ++---
.../sysds/hops/codegen/cplan/cuda/Unary.java | 313 +++++++----------
.../sysds/runtime/codegen/SpoofCUDACellwise.java | 4 +-
.../sysds/runtime/codegen/SpoofCUDARowwise.java | 4 +-
13 files changed, 462 insertions(+), 676 deletions(-)
diff --git a/src/main/cuda/headers/Matrix.h b/src/main/cuda/headers/Matrix.h
index 61ef939b83..f02a76c83c 100644
--- a/src/main/cuda/headers/Matrix.h
+++ b/src/main/cuda/headers/Matrix.h
@@ -18,8 +18,6 @@
*/
#pragma once
-#ifndef SYSTEMDS_MATRIX_H
-#define SYSTEMDS_MATRIX_H
using uint32_t = unsigned int;
using int32_t = int;
@@ -43,22 +41,22 @@ struct Matrix {
#ifdef __CUDACC__
-template<typename T>
-uint32_t bin_search(T* values, uint32_t lower, uint32_t upper, T val) {
- upper -= 1;
- while(lower <= (upper-1)) {
- uint32_t idx = (lower + upper) >> 1;
- uint32_t vi = values[idx];
- if (vi < val)
- lower = idx + 1;
- else {
- if (vi <= val)
- return idx;
- upper = idx - 1;
- }
- }
- return upper + 1;
-}
+//template<typename T>
+//uint32_t bin_search(T* values, uint32_t lower, uint32_t upper, T val) {
+// upper -= 1;
+// while(lower <= (upper-1)) {
+// uint32_t idx = (lower + upper) >> 1;
+// uint32_t vi = values[idx];
+// if (vi < val)
+// lower = idx + 1;
+// else {
+// if (vi <= val)
+// return idx;
+// upper = idx - 1;
+// }
+// }
+// return upper + 1;
+//}
template<typename T>
class MatrixAccessor {
@@ -68,11 +66,11 @@ class MatrixAccessor {
public:
MatrixAccessor() = default;
- __device__ MatrixAccessor(Matrix<T>* mat) : _mat(mat) {}
+ __device__ explicit MatrixAccessor(Matrix<T>* mat) : _mat(mat) {}
__device__ void init(Matrix<T>* mat) { _mat = mat; }
- __device__ uint32_t& nnz() { return return _mat->row_ptr == nullptr ? _mat->rows * _mat->cols : _mat->nnz; }
+// __device__ uint32_t& nnz() { return _mat->row_ptr == nullptr ? _mat->rows * _mat->cols : _mat->nnz; }
__device__ uint32_t cols() { return _mat->cols; }
__device__ uint32_t rows() { return _mat->rows; }
@@ -96,14 +94,14 @@ public:
}
__device__ uint32_t row_len(uint32_t rix) {
- return _mat->row_ptr == nullptr ? row_len_dense(rix) : row_len_sparse(rix);
+ return _mat->row_ptr == nullptr ? _mat->rows : row_len_sparse(rix);
}
__device__ uint32_t* col_idxs(uint32_t rix) { return cols_sparse(rix); }
__device__ void set(uint32_t r, uint32_t c, T v) { set_sparse(r,c,v); }
- __device__ uint32_t* indexes() { return _mat->row_ptr; }
+// __device__ uint32_t* indexes() { return _mat->row_ptr; }
__device__ bool hasData() { return _mat->data != nullptr; }
private:
@@ -127,10 +125,6 @@ private:
return &(_mat->data[rix]);
}
- __device__ uint32_t row_len_dense(uint32_t rix) {
- return _mat->rows;
- }
-
//ToDo sparse accessors
__device__ uint32_t len_sparse() {
return _mat->row_ptr[_mat->rows];
@@ -145,8 +139,8 @@ private:
}
__device__ T& val_sparse_rc(uint32_t r, uint32_t c) {
-// printf("TBI: val_sparse_rc\n");
-// asm("trap;");
+ printf("TBI: val_sparse_rc(%d, %d)\n", r, c);
+ asm("trap;");
return _mat->data[0];
}
@@ -228,5 +222,3 @@ public:
};
#endif // __CUDACC_RTC__
-
-#endif //SYSTEMDS_MATRIX_H
diff --git a/src/main/cuda/headers/spoof_utils.cuh b/src/main/cuda/headers/spoof_utils.cuh
index 5d9b1012b2..8ab0fafdb2 100644
--- a/src/main/cuda/headers/spoof_utils.cuh
+++ b/src/main/cuda/headers/spoof_utils.cuh
@@ -18,8 +18,6 @@
*/
#pragma once
-#ifndef SPOOF_UTILS_CUH
-#define SPOOF_UTILS_CUH
#include <math_constants.h>
#include "vector_add.cuh"
@@ -31,13 +29,8 @@ struct TempStorage;
#include "Matrix.h"
#include "vector_write.cuh"
-// #include "intellisense_cuda_intrinsics.h"
-
using uint32_t = unsigned int;
-//static __device__ bool debug_row() { return blockIdx.x == 0; };
-//static __device__ bool debug_thread() { return threadIdx.x == 0; }
-
__constant__ double DOUBLE_EPS = 1.11022E-16; // 2 ^ -53
__constant__ double FLOAT_EPS = 1.49012E-08; // 2 ^ -26
__constant__ double EPSILON = 1E-11; // margin for comparisons ToDo: make consistent use of it
@@ -79,12 +72,6 @@ __device__ Vector<T>& getVector(MatrixAccessor<T>& data, uint32_t n, uint32_t ri
c[i] = data.val(rix, i);
i += blockDim.x;
}
-// if(debug_thread()) {
-// printf("getVector: c.len=%d rix=%d\n", c.length, rix);
-// for(auto j = 0; j < c.length; ++j)
-// printf("%4.3f ", c[j]);
-// printf("\n");
-// }
return c;
}
@@ -146,122 +133,147 @@ __device__ T BLOCK_ROW_AGG(T *a, T *b, uint32_t len, AggOp agg_op, LoadOp load_o
auto sdata = shared_memory_proxy<T>();
uint tid = threadIdx.x;
- // Initalize shared mem and leave if tid > row length.
-// if(tid >= len) { return sdata[tid] = AggOp::init();; }
-
- __syncthreads();
-
-// if(blockIdx.x == 0 && threadIdx.x == 0)
-// printf("tid=%d sdata[tid + 128]=%f, len=%d\n", tid, len, sdata[tid+128]);
uint i = tid;
T v = AggOp::init();
-// if(blockIdx.x == 0 && threadIdx.x == 0)
-// printf("tid=%d sdata[tid + 128]=%f\n", tid, sdata[tid+128]);
- while (i < len) {
+ while(i < len) {
v = agg_op(v, load_op(a[i], b[i]));
i += blockDim.x;
}
-// if(blockIdx.x == 0 && threadIdx.x == 0)
-// if(debug_row() && debug_thread())
-// printf("tid=%d sdata[tid + 128]=%f\n", tid, sdata[tid+128]);
-
// each thread puts its local sum into shared memory
sdata[tid] = v;
- // if(blockIdx.x==0)
- // printf("tid=%d v=%f, len=%d\n", tid, v, len);
__syncthreads();
- // if(blockIdx.x == 0 && threadIdx.x == 0)
- // printf("tid=%d sdata[tid + 128]=%f\n", tid, sdata[tid+128]);
-
// do reduction in shared mem
- if (blockDim.x >= 1024) {
- if (tid < 512 && (tid+512) < len) {
- // if(blockIdx.x == 0 && threadIdx.x == 0)
- // printf("tid=%d sdata[tid + 512]=%f\n", tid, sdata[tid+512]);
+ if(blockDim.x >= 1024) {
+ if(tid < 512 && (tid+512) < len) {
sdata[tid] = v = agg_op(v, sdata[tid + 512]);
}
__syncthreads();
}
- if (blockDim.x >= 512) {
- if (tid < 256 && (tid+256) < len) {
- // if(blockIdx.x == 0 && threadIdx.x == 0)
- // printf("tid=%d sdata[tid + 256]=%f\n", tid, sdata[tid+256]);
+ if(blockDim.x >= 512) {
+ if(tid < 256 && (tid+256) < len) {
sdata[tid] = v = agg_op(v, sdata[tid + 256]);
}
__syncthreads();
}
- if (blockDim.x >= 256) {
- if (tid < 128 && (tid+128) < len) {
- // if(blockIdx.x == 0 && threadIdx.x == 0)
- // printf("tid=%d sdata[tid + 128]=%f\n", tid, sdata[tid+128]);
+ if(blockDim.x >= 256) {
+ if(tid < 128 && (tid+128) < len) {
sdata[tid] = v = agg_op(v, sdata[tid + 128]);
}
__syncthreads();
}
- if (blockDim.x >= 128) {
- if (tid < 64 && (tid+64) < len) {
- // if(blockIdx.x == 0 && threadIdx.x == 0)
- // printf("tid=%d sdata[tid + 64]=%f\n", tid, sdata[tid+64]);
+ if(blockDim.x >= 128) {
+if(tid < 64 && (tid+64) < len) {
sdata[tid] = v = agg_op(v, sdata[tid + 64]);
}
__syncthreads();
}
-
- if (tid < 32) {
+
+ if(tid < 32) {
// now that we are using warp-synchronous programming (below)
// we need to declare our shared memory volatile so that the compiler
// doesn't reorder stores to it and induce incorrect behavior.
volatile T *smem = sdata;
- if (blockDim.x >= 64 && (tid+32) < len) {
+ if(blockDim.x >= 64 && (tid+32) < len) {
smem[tid] = v = agg_op(v, smem[tid + 32]);
}
- // if(blockIdx.x==0)
- // printf("tid=%d smem[0]=%f\n", tid, smem[0]);
- if (blockDim.x >= 32 && (tid+16) < len) {
+ if(blockDim.x >= 32 && (tid+16) < len) {
smem[tid] = v = agg_op(v, smem[tid + 16]);
}
- // if(blockIdx.x==0)
- // printf("tid=%d smem[0]=%f\n", tid, smem[0]);
- if (blockDim.x >= 16 && (tid+8) < len) {
+ if(blockDim.x >= 16 && (tid+8) < len) {
smem[tid] = v = agg_op(v, smem[tid + 8]);
}
- // if(blockIdx.x==0)
- // printf("tid=%d smem[0]=%f\n", tid, smem[0]);
- if (blockDim.x >= 8 && (tid+4) < len) {
+ if(blockDim.x >= 8 && (tid+4) < len) {
smem[tid] = v = agg_op(v, smem[tid + 4]);
}
- // if(blockIdx.x==0 && threadIdx.x ==0)
- // printf("tid=%d smem[tid + 4]=%f\n", tid, smem[tid+4]);
- if (blockDim.x >= 4 && (tid+2) < len) {
+ if(blockDim.x >= 4 && (tid+2) < len) {
smem[tid] = v = agg_op(v, smem[tid + 2]);
}
- // if(blockIdx.x==0 && threadIdx.x ==0)
- // printf("tid=%d smem[0]=%f\n", tid, smem[0]);
- if (blockDim.x >= 2 && (tid+1) < len) {
- // if (blockDim.x >= 2) {
+ if(blockDim.x >= 2 && (tid+1) < len) {
smem[tid] = v = agg_op(v, smem[tid + 1]);
}
-// if(blockIdx.x==0 && threadIdx.x ==0)
-// if(debug_row() && debug_thread())
-// printf("tid=%d smem[0]=%f\n", tid, smem[0]);
}
-
__syncthreads();
return sdata[0];
}
+
+template<typename T, typename AggOp, typename LoadOp>
+__device__ T BLOCK_ROW_AGG(T *a, T *b, uint32_t* aix, uint32_t len, AggOp agg_op, LoadOp load_op) {
+ auto sdata = shared_memory_proxy<T>();
+ uint tid = threadIdx.x;
+
+ uint i = tid;
+ T v = AggOp::init();
+ while(i < len) {
+ v = agg_op(v, load_op(a[i], b[aix[i]]));
+ i += blockDim.x;
+ }
+
+ // each thread puts its local sum into shared memory
+ sdata[tid] = v;
+ __syncthreads();
+
+ // do reduction in shared mem
+ if(blockDim.x >= 1024) {
+ if(tid < 512 && (tid+512) < len) {
+ sdata[tid] = v = agg_op(v, sdata[tid + 512]);
+ }
+ __syncthreads();
+ }
+ if(blockDim.x >= 512) {
+ if(tid < 256 && (tid+256) < len) {
+ sdata[tid] = v = agg_op(v, sdata[tid + 256]);
+ }
+ __syncthreads();
+ }
+ if(blockDim.x >= 256) {
+ if(tid < 128 && (tid+128) < len) {
+ sdata[tid] = v = agg_op(v, sdata[tid + 128]);
+ }
+ __syncthreads();
+ }
+ if(blockDim.x >= 128) {
+ if(tid < 64 && (tid+64) < len) {
+ sdata[tid] = v = agg_op(v, sdata[tid + 64]);
+ }
+ __syncthreads();
+ }
+
+ if(tid < 32) {
+ // now that we are using warp-synchronous programming (below)
+ // we need to declare our shared memory volatile so that the compiler
+ // doesn't reorder stores to it and induce incorrect behavior.
+ volatile T *smem = sdata;
+ if(blockDim.x >= 64 && (tid+32) < len) {
+ smem[tid] = v = agg_op(v, smem[tid + 32]);
+ }
+ if(blockDim.x >= 32 && (tid+16) < len) {
+ smem[tid] = v = agg_op(v, smem[tid + 16]);
+ }
+ if(blockDim.x >= 16 && (tid+8) < len) {
+ smem[tid] = v = agg_op(v, smem[tid + 8]);
+ }
+ if(blockDim.x >= 8 && (tid+4) < len) {
+ smem[tid] = v = agg_op(v, smem[tid + 4]);
+ }
+ if(blockDim.x >= 4 && (tid+2) < len) {
+ smem[tid] = v = agg_op(v, smem[tid + 2]);
+ }
+ if(blockDim.x >= 2 && (tid+1) < len) {
+ smem[tid] = v = agg_op(v, smem[tid + 1]);
+ }
+ }
+ __syncthreads();
+ return sdata[0];
+}
+
template<typename T>
__device__ T dotProduct(T* a, T* b, uint32_t ai, uint32_t bi, uint32_t len) {
SumOp<T> agg_op;
ProductOp<T> load_op;
-// if(debug_row() && debug_thread())
-// printf("dot len = %d\n", len);
- T ret = BLOCK_ROW_AGG(&a[ai], &b[bi], len, agg_op, load_op);
-// if(debug_row() && debug_thread())
-// printf("bid=%d, ai=%d, dot=%f\n", blockIdx.x, ai, ret);
- return ret;
+ return BLOCK_ROW_AGG(&a[ai], &b[bi], len, agg_op, load_op);
}
template<typename T>
@@ -277,8 +289,6 @@ __device__ T vectSum(T* a, uint32_t ai, uint32_t len) {
SumOp<T> agg_op;
IdentityOp<T> load_op;
T result = BLOCK_ROW_AGG(&a[ai], &a[ai], len, agg_op, load_op);
-// if(debug_row() && debug_thread())
-// printf("vectSum: bid=%d, tid=%d ai=%d len=%d result=%4.3f\n", blockIdx.x, threadIdx.x, ai, len, result);
return result;
}
@@ -286,10 +296,22 @@ template<typename T>
__device__ T vectMin(T* a, int ai, int len) {
MinOp<T> agg_op;
IdentityOp<T> load_op;
- T result = BLOCK_ROW_AGG(&a[ai], &a[ai], len, agg_op, load_op);
-// if(debug_row() && debug_thread())
-// printf("vectMin: bid=%d, tid=%d ai=%d len=%d result=%4.3f\n", blockIdx.x, threadIdx.x, ai, len, result);
- return result;
+ return BLOCK_ROW_AGG(&a[ai], &a[ai], len, agg_op, load_op);
+}
+
+template<typename T>
+__device__ T rowMaxsVectMult(T* a, T* b, uint32_t ai, uint32_t bi, uint32_t len) {
+ MaxOp<T> agg_op;
+ ProductOp<T> load_op;
+ return BLOCK_ROW_AGG(&a[ai], &b[0], len, agg_op, load_op);
+}
+
+template<typename T>
+__device__ T rowMaxsVectMult(T* a, T* b, uint32_t* aix, uint32_t ai, uint32_t bi, uint32_t len) {
+ MaxOp<T> agg_op;
+ ProductOp<T> load_op;
+
+ return BLOCK_ROW_AGG(&a[ai], &b[0], &aix[ai], len, agg_op, load_op);
}
template<typename T>
@@ -302,25 +324,7 @@ __device__ T vectMax(T* a, uint32_t ai, uint32_t len) {
template<typename T>
__device__ T vectMax(T* avals, uint32_t* aix, uint32_t ai, uint32_t alen, uint32_t len) {
-// if (debug_row() && debug_thread()) {
-// printf("\naix[i]:\n");
-// for(auto i = 0; i < alen; ++i)
-// printf(" %d", aix[i]);
-
-// printf("\navals[i]:\n");
-// for(auto i = 0; i < alen; ++i)
-// printf(" %4.3f", avals[i]);
-
-// printf("\navals[aix[i]]:\n");
-// for(auto i = 0; i < alen; ++i)
-// printf(" %4.3f", avals[aix[i]]);
-
-// printf("\n");
-// }
-
T result = vectMax(avals, ai, alen);
-// if (blockIdx.x < 5 && debug_thread())
-// printf("bid=%d, tid=%d, len=%d, alen=%d, ai=%d vectMax=%4.3f\n", blockIdx.x, threadIdx.x, len, alen, ai, result);
return alen < len ? MaxOp<T>::exec(result, 0.0) : result;
}
@@ -547,6 +551,12 @@ Vector<T>& vectMultWrite(T* a, T* b, uint32_t ai, uint32_t bi, uint32_t len, Tem
return vectWriteBinary<T, ProductOp<T>>(a, b, ai, bi, len, fop, "Mult");
}
+// sparse-dense MxV
+template<typename T>
+Vector<T>& vectMultWrite(T* avals, T* b, uint32_t* aix, uint32_t ai, uint32_t bi, uint32_t alen, uint32_t len, TempStorage<T>* fop) {
+ return vectWriteBinary<T, ProductOp<T>>(avals, b, aix, ai, bi, alen, len, fop, "Mult");
+}
+
template<typename T>
Vector<T>& vectDivWrite(T* a, T b, uint32_t ai, uint32_t len, TempStorage<T>* fop) {
return vectWriteBinary<T, DivOp<T>>(a, b, ai, len, fop, "Div");
@@ -744,6 +754,3 @@ void vectOuterMultAdd(T* a, T* b, T* c, uint32_t ai, uint32_t bi, uint32_t ci, u
i += blockDim.x;
}
}
-
-
-#endif // SPOOF_UTILS_CUH
diff --git a/src/main/cuda/headers/vector_write.cuh b/src/main/cuda/headers/vector_write.cuh
index 3099926167..55241bd8d4 100644
--- a/src/main/cuda/headers/vector_write.cuh
+++ b/src/main/cuda/headers/vector_write.cuh
@@ -18,10 +18,9 @@
*/
#pragma once
-#ifndef SYSTEMDS_VECTOR_WRITE_CUH
-#define SYSTEMDS_VECTOR_WRITE_CUH
-__device__ bool debug_row() { return blockIdx.x == 1; };
+#define DEBUG_ROW 2
+__device__ bool debug_row() { return blockIdx.x == DEBUG_ROW; };
__device__ bool debug_thread() { return threadIdx.x == 0; }
// unary transform vector by OP and write to intermediate vector
@@ -143,6 +142,19 @@ __device__ Vector<T>& vectWriteBinary(T* a, T* b, uint32_t ai, uint32_t bi, uint
return c;
}
+// sparse binary vect-vect to intermediate vector
+template<typename T, typename OP>
+__device__ Vector<T>& vectWriteBinary(T* a, T* b, uint32_t* aix, uint32_t ai, uint32_t bi, uint32_t alen, uint32_t len,
+ TempStorage<T>* fop, const char* name = nullptr) {
+ uint32_t i = threadIdx.x;
+ Vector<T>& c = fop->getTempStorage(len);
+ while (i < alen) {
+ c[aix[ai+i]] = OP::exec(a[ai + i], b[aix[ai+i]]);
+ i += blockDim.x;
+ }
+ return c;
+}
+
// binary vector-scalar to output vector c
template<typename T, typename OP>
__device__ void vectWriteBinary(T* a, T b, T* c, uint32_t ai, uint32_t ci, uint32_t len) {
@@ -168,5 +180,3 @@ __device__ void vectWriteBinary(T* a, T* b, T* c, uint32_t ai, uint32_t bi, uint
i += blockDim.x;
}
}
-
-#endif //SYSTEMDS_VECTOR_WRITE_CUH
diff --git a/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp b/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp
index 2ef482d18d..c4ae1e3dff 100644
--- a/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp
+++ b/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp
@@ -47,11 +47,7 @@ size_t SpoofCUDAContext::initialize_cuda(uint32_t device_id, const char* resourc
s1 << "-I" << resource_path << "/cuda/headers";
s2 << "-I" << resource_path << "/cuda/spoof";
auto ctx = new SpoofCUDAContext(resource_path,{s1.str(), s2.str(), cuda_include_path});
- // cuda device is handled by jCuda atm
- //cudaSetDevice(device_id);
- //cudaSetDeviceFlags(cudaDeviceScheduleBlockingSync);
- //cudaDeviceSynchronize();
-
+
CHECK_CUDA(cuModuleLoad(&(ctx->reductions), std::string(ctx->resource_path + std::string("/cuda/kernels/reduction.ptx")).c_str()));
CUfunction func;
@@ -87,30 +83,13 @@ void SpoofCUDAContext::destroy_cuda(SpoofCUDAContext *ctx, [[maybe_unused]] uint
cudaFreeHost(ctx->staging_buffer);
cudaFree(ctx->device_buffer);
delete ctx;
- // cuda device is handled by jCuda atm
- //cudaDeviceReset();
}
size_t SpoofCUDAContext::compile(std::unique_ptr<SpoofOperator> op, const std::string &src) {
-#ifndef NDEBUG
-// std::cout << "---=== START source listing of spoof cuda kernel [ " << name << " ]: " << std::endl;
-// uint32_t line_num = 0;
-// std::istringstream src_stream(src);
-// for(std::string line; std::getline(src_stream, line); line_num++)
-// std::cout << line_num << ": " << line << std::endl;
-// std::cout << "---=== END source listing of spoof cuda kernel [ " << name << " ]." << std::endl;
- std::cout << "cwd: " << std::filesystem::current_path() << std::endl;
- std::cout << "include_paths: ";
- for_each (include_paths.begin(), include_paths.end(), [](const std::string& line){ std::cout << line << '\n';});
- std::cout << std::endl;
-#endif
-
-// uncomment all related lines for temporary timing output:
// auto compile_start = clk::now();
op->program = std::make_unique<jitify::Program>(kernel_cache.program(src, 0, include_paths));
// auto compile_end = clk::now();
// auto compile_duration = std::chrono::duration_cast<sec>(compile_end - compile_start).count();
-
compiled_ops.push_back(std::move(op));
// compile_total += compile_duration;
// std::cout << name << " compiled in "
diff --git a/src/main/cuda/spoof-launcher/SpoofRowwise.h b/src/main/cuda/spoof-launcher/SpoofRowwise.h
index 4465ac99fa..01ec5206aa 100644
--- a/src/main/cuda/spoof-launcher/SpoofRowwise.h
+++ b/src/main/cuda/spoof-launcher/SpoofRowwise.h
@@ -18,15 +18,13 @@
*/
#pragma once
-#ifndef SYSTEMDS_SPOOFROWWISE_H
-#define SYSTEMDS_SPOOFROWWISE_H
#include "SpoofCUDAContext.h"
#include <algorithm>
template <typename T>
struct SpoofRowwise {
-
+
static void exec([[maybe_unused]] SpoofCUDAContext* ctx, SpoofOperator* _op, DataBufferWrapper* dbw) {
uint32_t NT=256;
T value_type;
@@ -56,7 +54,7 @@ struct SpoofRowwise {
CHECK_CUDART(cudaMalloc(reinterpret_cast<void**>(&d_temp), temp_buf_size));
CHECK_CUDART(cudaMemsetAsync(d_temp, 0, temp_buf_size, op->stream));
}
-
+
std::string op_name(op->name + "_DENSE");
if(sparse_input)
op_name = std::string(op->name + "_SPARSE");
@@ -77,5 +75,3 @@ struct SpoofRowwise {
CHECK_CUDART(cudaFree(d_temp));
}
};
-
-#endif //SYSTEMDS_SPOOFROWWISE_H
diff --git a/src/main/cuda/spoof-launcher/jni_bridge.cpp b/src/main/cuda/spoof-launcher/jni_bridge.cpp
index 5134d5e292..65f4a5a19f 100644
--- a/src/main/cuda/spoof-launcher/jni_bridge.cpp
+++ b/src/main/cuda/spoof-launcher/jni_bridge.cpp
@@ -165,7 +165,7 @@ int launch_spoof_operator([[maybe_unused]] JNIEnv *jenv, [[maybe_unused]] jclass
[[maybe_unused]] JNIEXPORT jint JNICALL Java_org_apache_sysds_runtime_codegen_SpoofCUDACellwise_execute_1f
(JNIEnv *jenv, jclass jobj, jlong ctx) {
- return launch_spoof_operator<double, SpoofCellwise<double>>(jenv, jobj, ctx);
+ return launch_spoof_operator<float, SpoofCellwise<float>>(jenv, jobj, ctx);
}
@@ -177,5 +177,5 @@ int launch_spoof_operator([[maybe_unused]] JNIEnv *jenv, [[maybe_unused]] jclass
[[maybe_unused]] JNIEXPORT jint JNICALL Java_org_apache_sysds_runtime_codegen_SpoofCUDARowwise_execute_1f
(JNIEnv *jenv, jclass jobj, jlong ctx) {
- return launch_spoof_operator<double, SpoofRowwise<double>>(jenv, jobj, ctx);
+ return launch_spoof_operator<float, SpoofRowwise<float>>(jenv, jobj, ctx);
}
\ No newline at end of file
diff --git a/src/main/cuda/spoof/rowwise.cu b/src/main/cuda/spoof/rowwise.cu
index b31ce0c2ce..917b8e7087 100644
--- a/src/main/cuda/spoof/rowwise.cu
+++ b/src/main/cuda/spoof/rowwise.cu
@@ -48,15 +48,14 @@ struct SpoofRowwiseOp //%HAS_TEMP_VECT%
a.init(A);
c.init(C);
- if(B)
- for(auto i = 0; i < NUM_B; ++i)
- b[i].init(&(B[i]));
+ if(B) {
+ for(auto i = 0; i < NUM_B; ++i)
+ b[i].init(&(B[i]));
+ }
}
__device__ __forceinline__ void exec_dense(uint32_t ai, uint32_t ci, uint32_t rix) {
//%BODY_dense%
- if (debug_row() && debug_thread())
- printf("c[0]=%4.3f\n", c.vals(0)[0]);
}
__device__ __forceinline__ void exec_sparse(uint32_t ai, uint32_t ci, uint32_t rix) {
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java
index 88844d6fbb..13c8e10fca 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java
@@ -56,9 +56,12 @@ public class CNodeRow extends CNodeTpl
private static final String TEMPLATE_NOAGG_OUT = " LibSpoofPrimitives.vectWrite(%IN%, c, ci, %LEN%);\n";
private static final String TEMPLATE_NOAGG_CONST_OUT_CUDA = "\t\tvectWrite(%IN%, c.vals(0), 0, ci, %LEN%);\n";
private static final String TEMPLATE_NOAGG_OUT_CUDA = "\t\tvectWrite(%IN%, c.vals(0), 0, ci, %LEN%);\n";
- private static final String TEMPLATE_ROWAGG_OUT_CUDA = "\t\tif(threadIdx.x == 0){\n\t\t\t*(c.vals(rix)) = %IN%;\n//printf(\"rix=%d TMP7=%f TMP8=%f %IN%=%f\\n\",rix, TMP7, TMP8,%IN%);\n}\n";
+// private static final String TEMPLATE_ROWAGG_OUT_CUDA = "\t\tif(threadIdx.x == 0){\n\t\t\t*(c.vals(rix)) = %IN%;\n//printf(\"rix=%d TMP7=%f TMP8=%f %IN%=%f\\n\",rix, TMP7, TMP8,%IN%);\n}\n";
+private static final String TEMPLATE_ROWAGG_OUT_CUDA = "\t\tif(threadIdx.x == 0){\n\t\t\t*(c.vals(rix)) = %IN%;\n\t\t}\n";
+// private static final String TEMPLATE_FULLAGG_OUT_CUDA =
+// "\t\tif(threadIdx.x == 0) {\n\t\t\tT old = atomicAdd(c.vals(0), %IN%);\n//\t\t\tprintf(\"bid=%d full_agg add %f to %f\\n\",blockIdx.x, %IN%, old);\n\t\t}\n";
private static final String TEMPLATE_FULLAGG_OUT_CUDA =
- "\t\tif(threadIdx.x == 0) {\n\t\t\tT old = atomicAdd(c.vals(0), %IN%);\n//\t\t\tprintf(\"bid=%d full_agg add %f to %f\\n\",blockIdx.x, %IN%, old);\n\t\t}\n";
+ "\t\tif(threadIdx.x == 0) {\n\t\tT old = atomicAdd(c.vals(0), %IN%);\n\t\t}\n";
public CNodeRow(ArrayList<CNode> inputs, CNode output ) {
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Binary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Binary.java
index 6d826b16b4..ec46e1196c 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Binary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Binary.java
@@ -42,267 +42,156 @@ public class Binary extends CodeTemplate
"\t\tVector<T>& %TMP% = vectCbindWrite(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, %LEN1%, %LEN2%, this);\n" :
"\t\tVector<T>& %TMP% = vectCbindWrite(%IN1%, %IN2%, %POS1%, %POS2%, %LEN1%, %LEN2%, this);\n";
}
-
- if(isSinglePrecision()) {
- switch(type) {
- case DOT_PRODUCT:
- return sparseLhs ? " T %TMP% = LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" : " T %TMP% = LibSpoofPrimitives.dotProduct(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
- case VECT_MATRIXMULT:
- return sparseLhs ? " T[] %TMP% = LibSpoofPrimitives.vectMatrixMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, len);\n" : " T[] %TMP% = LibSpoofPrimitives.vectMatrixMult(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
- case VECT_OUTERMULT_ADD:
- return sparseLhs ? " LibSpoofPrimitives.vectOuterMultAdd(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : sparseRhs ? " LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2v%, %OUT%, %POS1%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : " LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2%, %OUT%, %POS1%, %POS2%, %POSOUT%, %LEN1%, %LEN2%);\n";
- //vector-scalar-add operations
- case VECT_MULT_ADD:
- case VECT_DIV_ADD:
- case VECT_MINUS_ADD:
- case VECT_PLUS_ADD:
- case VECT_POW_ADD:
- case VECT_XOR_ADD:
- case VECT_MIN_ADD:
- case VECT_MAX_ADD:
- case VECT_EQUAL_ADD:
- case VECT_NOTEQUAL_ADD:
- case VECT_LESS_ADD:
- case VECT_LESSEQUAL_ADD:
- case VECT_GREATER_ADD:
- case VECT_GREATEREQUAL_ADD:
- case VECT_CBIND_ADD: {
- String vectName = type.getVectorPrimitiveName();
- if(scalarVector)
- return sparseLhs ? " LibSpoofPrimitives.vect" + vectName + "Add(%IN1%, %IN2v%, %OUT%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN%);\n" : " LibSpoofPrimitives.vect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS2%, %POSOUT%, %LEN%);\n";
- else
- return sparseLhs ? " LibSpoofPrimitives.vect" + vectName + "Add(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POSOUT%, alen, %LEN%);\n" : " LibSpoofPrimitives.vect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS1%, %POSOUT%, %LEN%);\n";
- }
+ switch(type) {
+ case ROWMAXS_VECTMULT:
+ return sparseLhs ? "\t\tT %TMP% = rowMaxsVectMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" :
+ "\t\tT %TMP% = rowMaxsVectMult(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
+ case DOT_PRODUCT:
+ return sparseLhs ? "\t\tT %TMP% = dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" : " T %TMP% = dotProduct(%IN1%, %IN2%, %POS1%, static_cast<uint32_t>(%POS2%), %LEN%);\n";
- //vector-scalar operations
- case VECT_MULT_SCALAR:
- case VECT_DIV_SCALAR:
- case VECT_MINUS_SCALAR:
- case VECT_PLUS_SCALAR:
- case VECT_POW_SCALAR:
- case VECT_XOR_SCALAR:
- case VECT_BITWAND_SCALAR:
- case VECT_MIN_SCALAR:
- case VECT_MAX_SCALAR:
- case VECT_EQUAL_SCALAR:
- case VECT_NOTEQUAL_SCALAR:
- case VECT_LESS_SCALAR:
- case VECT_LESSEQUAL_SCALAR:
- case VECT_GREATER_SCALAR:
- case VECT_GREATEREQUAL_SCALAR: {
- String vectName = type.getVectorPrimitiveName();
- if(scalarVector)
- return sparseRhs ? " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2v%, %IN2i%, %POS2%, alen, %LEN%);\n" : " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2%, %POS2%, %LEN%);\n";
- else
- return sparseLhs ? " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%);\n" : " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2%, %POS1%, %LEN%);\n";
- }
- //vector-vector operations
- case VECT_MULT:
- case VECT_DIV:
- case VECT_MINUS:
- case VECT_PLUS:
- case VECT_XOR:
- case VECT_BITWAND:
- case VECT_BIASADD:
- case VECT_BIASMULT:
- case VECT_MIN:
- case VECT_MAX:
- case VECT_EQUAL:
- case VECT_NOTEQUAL:
- case VECT_LESS:
- case VECT_LESSEQUAL:
- case VECT_GREATER:
- case VECT_GREATEREQUAL: {
- String vectName = type.getVectorPrimitiveName();
- return sparseLhs ? " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, %LEN%);\n" : sparseRhs ? " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2v%, %POS1%, %IN2i%, %POS2%, alen, %LEN%);\n" : " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
- }
+ case VECT_MATRIXMULT:
+ return sparseLhs ? " T[] %TMP% = vectMatrixMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, len);\n" : " Vector<T>& %TMP% = vectMatrixMult(%IN1%, %IN2%, %POS1%, static_cast<uint32_t>(%POS2%), %LEN%, this);\n";
+ case VECT_OUTERMULT_ADD:
+ return sparseLhs ? " LibSpoofPrimitives.vectOuterMultAdd(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : sparseRhs ? " LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2v%, %OUT%, %POS1%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : "\t\tvectOuterMultAdd(%IN1%, %IN2%, %OUT%, %POS1%, %POS2%, %POSOUT%, %LEN1%, %LEN2%);\n";
- //scalar-scalar operations
- case MULT:
- return " T %TMP% = %IN1% * %IN2%;\n";
- case DIV:
- return " T %TMP% = %IN1% / %IN2%;\n";
- case PLUS:
- return " T %TMP% = %IN1% + %IN2%;\n";
- case MINUS:
- return " T %TMP% = %IN1% - %IN2%;\n";
- case MODULUS:
- return " T %TMP% = modulus(%IN1%, %IN2%);\n";
- case INTDIV:
- return " T %TMP% = intDiv(%IN1%, %IN2%);\n";
- case LESS:
- return " T %TMP% = (%IN1% < %IN2%) ? 1 : 0;\n";
- case LESSEQUAL:
- return " T %TMP% = (%IN1% <= %IN2%) ? 1 : 0;\n";
- case GREATER:
- return " T %TMP% = (%IN1% > %IN2%) ? 1 : 0;\n";
- case GREATEREQUAL:
- return " T %TMP% = (%IN1% >= %IN2%) ? 1 : 0;\n";
- case EQUAL:
- return " T %TMP% = (%IN1% == %IN2%) ? 1 : 0;\n";
- case NOTEQUAL:
- return " T %TMP% = (%IN1% != %IN2%) ? 1 : 0;\n";
-
- case MIN:
- return " T %TMP% = fminf(%IN1%, %IN2%);\n";
- case MAX:
- return " T %TMP% = fmaxf(%IN1%, %IN2%);\n";
- case LOG:
- return " T %TMP% = logf(%IN1%)/Math.log(%IN2%);\n";
- case LOG_NZ:
- return " T %TMP% = (%IN1% == 0) ? 0 : logf(%IN1%) / logf(%IN2%);\n";
- case POW:
- return " T %TMP% = powf(%IN1%, %IN2%);\n";
- case MINUS1_MULT:
- return " T %TMP% = 1 - %IN1% * %IN2%;\n";
- case MINUS_NZ:
- return " T %TMP% = (%IN1% != 0) ? %IN1% - %IN2% : 0;\n";
- case XOR:
- return " T %TMP% = ( (%IN1% != 0) != (%IN2% != 0) ) ? 1.0f : 0.0f;\n";
- case BITWAND:
- return " T %TMP% = bwAnd(%IN1%, %IN2%);\n";
- case SEQ_RIX:
- return " T %TMP% = %IN1% + grix * %IN2%;\n"; //0-based global rix
-
- default:
- throw new RuntimeException("Invalid binary type: " + this.toString());
+ //vector-scalar-add operations
+ case VECT_MULT_ADD:
+ case VECT_DIV_ADD:
+ case VECT_MINUS_ADD:
+ case VECT_PLUS_ADD:
+ case VECT_POW_ADD:
+ case VECT_XOR_ADD:
+ case VECT_MIN_ADD:
+ case VECT_MAX_ADD:
+ case VECT_EQUAL_ADD:
+ case VECT_NOTEQUAL_ADD:
+ case VECT_LESS_ADD:
+ case VECT_LESSEQUAL_ADD:
+ case VECT_GREATER_ADD:
+ case VECT_GREATEREQUAL_ADD:
+ case VECT_CBIND_ADD: {
+ String vectName = type.getVectorPrimitiveName();
+ if(scalarVector)
+ return sparseLhs ? "\t\tvect" + vectName + "Add(%IN1%, %IN2v%, %OUT%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN%);\n" : "\t\tvect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS2%, %POSOUT%, %LEN%);\n";
+ else
+ return sparseLhs ? "\t\tvect" + vectName + "Add(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POSOUT%, alen, %LEN%);\n" : "\t\tvect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS1%, static_cast<uint32_t>(%POSOUT%), %LEN%);\n";
}
- }
- else {
- switch(type) {
- case DOT_PRODUCT:
-// return sparseLhs ? " T %TMP% = LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" : " T %TMP% = LibSpoofPrimitives.dotProduct(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
-// return sparseLhs ? " T %TMP% = dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" : " T %TMP% = dotProduct(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n printf(\"dot=%f, bid=%d, tid=%d\\n\",TMP7,blockIdx.x, threadIdx.x);\n __syncthreads();\n";
- return sparseLhs ? " T %TMP% = dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" : " T %TMP% = dotProduct(%IN1%, %IN2%, %POS1%, static_cast<uint32_t>(%POS2%), %LEN%);\n";
-
- case VECT_MATRIXMULT:
- return sparseLhs ? " T[] %TMP% = vectMatrixMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, len);\n" : " Vector<T>& %TMP% = vectMatrixMult(%IN1%, %IN2%, %POS1%, static_cast<uint32_t>(%POS2%), %LEN%, this);\n";
- case VECT_OUTERMULT_ADD:
- return sparseLhs ? " LibSpoofPrimitives.vectOuterMultAdd(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : sparseRhs ? " LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2v%, %OUT%, %POS1%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : "\t\tvectOuterMultAdd(%IN1%, %IN2%, %OUT%, %POS1%, %POS2%, %POSOUT%, %LEN1%, %LEN2%);\n";
-
- //vector-scalar-add operations
- case VECT_MULT_ADD:
- case VECT_DIV_ADD:
- case VECT_MINUS_ADD:
- case VECT_PLUS_ADD:
- case VECT_POW_ADD:
- case VECT_XOR_ADD:
- case VECT_MIN_ADD:
- case VECT_MAX_ADD:
- case VECT_EQUAL_ADD:
- case VECT_NOTEQUAL_ADD:
- case VECT_LESS_ADD:
- case VECT_LESSEQUAL_ADD:
- case VECT_GREATER_ADD:
- case VECT_GREATEREQUAL_ADD:
- case VECT_CBIND_ADD: {
- String vectName = type.getVectorPrimitiveName();
- if(scalarVector)
- return sparseLhs ? "\t\tvect" + vectName + "Add(%IN1%, %IN2v%, %OUT%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN%);\n" : "\t\tvect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS2%, %POSOUT%, %LEN%);\n";
- else
- return sparseLhs ? "\t\tvect" + vectName + "Add(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POSOUT%, alen, %LEN%);\n" : "\t\tvect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS1%, static_cast<uint32_t>(%POSOUT%), %LEN%);\n";
- }
- //vector-scalar operations
- case VECT_MULT_SCALAR:
- case VECT_DIV_SCALAR:
- case VECT_MINUS_SCALAR:
- case VECT_PLUS_SCALAR:
- case VECT_POW_SCALAR:
- case VECT_XOR_SCALAR:
- case VECT_BITWAND_SCALAR:
- case VECT_MIN_SCALAR:
- case VECT_MAX_SCALAR:
- case VECT_EQUAL_SCALAR:
- case VECT_NOTEQUAL_SCALAR:
- case VECT_LESS_SCALAR:
- case VECT_LESSEQUAL_SCALAR:
- case VECT_GREATER_SCALAR:
- case VECT_GREATEREQUAL_SCALAR: {
- String vectName = type.getVectorPrimitiveName();
- if(scalarVector)
- return sparseRhs ? " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2v%, %IN2i%, %POS2%, alen, %LEN%);\n" : " Vector<T>& %TMP% = vect" + vectName + "Write(%IN1%, %IN2%, %POS2%, %LEN%, this);\n";
- else
+ //vector-scalar operations
+ case VECT_MULT_SCALAR:
+ case VECT_DIV_SCALAR:
+ case VECT_MINUS_SCALAR:
+ case VECT_PLUS_SCALAR:
+ case VECT_POW_SCALAR:
+ case VECT_XOR_SCALAR:
+ case VECT_BITWAND_SCALAR:
+ case VECT_MIN_SCALAR:
+ case VECT_MAX_SCALAR:
+ case VECT_EQUAL_SCALAR:
+ case VECT_NOTEQUAL_SCALAR:
+ case VECT_LESS_SCALAR:
+ case VECT_LESSEQUAL_SCALAR:
+ case VECT_GREATER_SCALAR:
+ case VECT_GREATEREQUAL_SCALAR: {
+ String vectName = type.getVectorPrimitiveName();
+ if(scalarVector)
+ return sparseRhs ? " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2v%, %IN2i%, %POS2%, alen, %LEN%);\n" : " Vector<T>& %TMP% = vect" + vectName + "Write(%IN1%, %IN2%, %POS2%, %LEN%, this);\n";
+ else
// return sparseLhs ? " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%);\n" : " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2%, %POS1%, %LEN%);\n";
- return sparseLhs ? " Vector<T>& %TMP% = vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%, this);\n" : " Vector<T>& %TMP% = vect" + vectName + "Write(%IN1%, %IN2%, static_cast<uint32_t>(%POS1%), %LEN%, this);\n";
- }
-
- //vector-vector operations
- case VECT_MULT:
- case VECT_DIV:
- case VECT_MINUS:
- case VECT_PLUS:
- case VECT_XOR:
- case VECT_BITWAND:
- case VECT_BIASADD:
- case VECT_BIASMULT:
- case VECT_MIN:
- case VECT_MAX:
- case VECT_EQUAL:
- case VECT_NOTEQUAL:
- case VECT_LESS:
- case VECT_LESSEQUAL:
- case VECT_GREATER:
- case VECT_GREATEREQUAL: {
- String vectName = type.getVectorPrimitiveName();
- return sparseLhs ? " Vector<T>& %TMP% = vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, " +
- "alen, %LEN%);\n" : sparseRhs ? " Vector<T>& %TMP% = vect" + vectName + "Write(%IN1%, %IN2v%, " +
- "%POS1%, %IN2i%, %POS2%, alen, %LEN%);\n" : " Vector<T>& %TMP% = vect" + vectName +
- "Write(%IN1%, %IN2%, static_cast<uint32_t>(%POS1%), static_cast<uint32_t>(%POS2%), %LEN%, this);\n";
- }
+ return sparseLhs ? " Vector<T>& %TMP% = vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%, this);\n" : " Vector<T>& %TMP% = vect" + vectName + "Write(%IN1%, %IN2%, static_cast<uint32_t>(%POS1%), %LEN%, this);\n";
+ }
- //scalar-scalar operations
- case MULT:
- return " T %TMP% = %IN1% * %IN2%;\n";
- case DIV:
- return " T %TMP% = %IN1% / %IN2%;\n";
- case PLUS:
- return " T %TMP% = %IN1% + %IN2%;\n";
- case MINUS:
- return " T %TMP% = %IN1% - %IN2%;\n";
- case MODULUS:
- return " T %TMP% = modulus(%IN1%, %IN2%);\n";
- case INTDIV:
- return " T %TMP% = intDiv(%IN1%, %IN2%);\n";
- case LESS:
- return " T %TMP% = (%IN1% < %IN2%) ? 1.0 : 0.0;\n";
- case LESSEQUAL:
- return " T %TMP% = (%IN1% <= %IN2%) ? 1.0 : 0.0;\n";
- case GREATER:
- return " T %TMP% = (%IN1% > (%IN2% + EPSILON)) ? 1.0 : 0.0;\n";
- case GREATEREQUAL:
- return " T %TMP% = (%IN1% >= %IN2%) ? 1.0 : 0.0;\n";
- case EQUAL:
- return " T %TMP% = (%IN1% == %IN2%) ? 1.0 : 0.0;\n";
- case NOTEQUAL:
- return " T %TMP% = (%IN1% != %IN2%) ? 1.0 : 0.0;\n";
+ //vector-vector operations
+ case VECT_MULT:
+ case VECT_DIV:
+ case VECT_MINUS:
+ case VECT_PLUS:
+ case VECT_XOR:
+ case VECT_BITWAND:
+ case VECT_BIASADD:
+ case VECT_BIASMULT:
+ case VECT_MIN:
+ case VECT_MAX:
+ case VECT_EQUAL:
+ case VECT_NOTEQUAL:
+ case VECT_LESS:
+ case VECT_LESSEQUAL:
+ case VECT_GREATER:
+ case VECT_GREATEREQUAL: {
+ String vectName = type.getVectorPrimitiveName();
+ return sparseLhs ? " Vector<T>& %TMP% = vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, " +
+ "%POS1%, %POS2%, alen, %LEN%, this);\n" :
+ sparseRhs ? " Vector<T>& %TMP% = vect" + vectName + "Write(%IN1%, %IN2v%, %POS1%, " +
+ "%IN2i%, %POS2%, alen, %LEN%);\n" :
+ " Vector<T>& %TMP% = vect" + vectName + "Write(%IN1%, %IN2%, " +
+ "static_cast<uint32_t>(%POS1%), static_cast<uint32_t>(%POS2%), %LEN%, this);\n";
+ }
- case MIN:
- return " T %TMP% = min(%IN1%, %IN2%);\n";
- case MAX:
- return " T %TMP% = max(%IN1%, %IN2%);\n";
- case LOG:
- return " T %TMP% = log(%IN1%)/Math.log(%IN2%);\n";
- case LOG_NZ:
- return " T %TMP% = (%IN1% == 0) ? 0 : log(%IN1%) / log(%IN2%);\n";
- case POW:
- return " T %TMP% = pow(%IN1%, %IN2%);\n";
- case MINUS1_MULT:
- return " T %TMP% = 1 - %IN1% * %IN2%;\n";
- case MINUS_NZ:
- return " T %TMP% = (%IN1% != 0) ? %IN1% - %IN2% : 0;\n";
- case XOR:
+ //scalar-scalar operations
+ case MULT:
+ return " T %TMP% = %IN1% * %IN2%;\n";
+ case DIV:
+ return "\t\tT %TMP% = %IN1% / %IN2%;\n";
+ case PLUS:
+ return "\t\tT %TMP% = %IN1% + %IN2%;\n";
+ case MINUS:
+ return " T %TMP% = %IN1% - %IN2%;\n";
+ case MODULUS:
+ return " T %TMP% = modulus(%IN1%, %IN2%);\n";
+ case INTDIV:
+ return " T %TMP% = intDiv(%IN1%, %IN2%);\n";
+ case LESS:
+ return " T %TMP% = (%IN1% < %IN2%) ? 1.0 : 0.0;\n";
+ case LESSEQUAL:
+ return " T %TMP% = (%IN1% <= %IN2%) ? 1.0 : 0.0;\n";
+ case GREATER:
+ return " T %TMP% = (%IN1% > (%IN2% + EPSILON)) ? 1.0 : 0.0;\n";
+ case GREATEREQUAL:
+ return " T %TMP% = (%IN1% >= %IN2%) ? 1.0 : 0.0;\n";
+ case EQUAL:
+ return " T %TMP% = (%IN1% == %IN2%) ? 1.0 : 0.0;\n";
+ case NOTEQUAL:
+ return "\t\tT %TMP% = (%IN1% != %IN2%) ? 1.0 : 0.0;\n";
+ case MIN:
+ if(isSinglePrecision())
+ return "\t\tT %TMP% = fminf(%IN1%, %IN2%);\n";
+ else
+ return "\t\tT %TMP% = min(%IN1%, %IN2%);\n";
+ case MAX:
+ if(isSinglePrecision())
+ return "\t\tT %TMP% = fmaxf(%IN1%, %IN2%);\n";
+ else
+ return "\t\tT %TMP% = max(%IN1%, %IN2%);\n";
+ case LOG:
+ if(isSinglePrecision())
+ return "\t\tT %TMP% = logf(%IN1%) / logf(%IN2%);\n";
+ else
+ return "\t\tT %TMP% = log(%IN1%) / log(%IN2%);\n";
+ case LOG_NZ:
+ if(isSinglePrecision())
+ return "\t\tT %TMP% = (%IN1% == 0) ? 0 : logf(%IN1%) / logf(%IN2%);\n";
+ else
+ return "\t\tT %TMP% = (%IN1% == 0) ? 0 : log(%IN1%) / log(%IN2%);\n";
+ case POW:
+ if(isSinglePrecision())
+ return "\t\tT %TMP% = powf(%IN1%, %IN2%);\n";
+ else
+ return "\t\tT %TMP% = pow(%IN1%, %IN2%);\n";
+ case MINUS1_MULT:
+ return " T %TMP% = 1 - %IN1% * %IN2%;\n";
+ case MINUS_NZ:
+ return " T %TMP% = (%IN1% != 0) ? %IN1% - %IN2% : 0;\n";
+ case XOR:
// return " T %TMP% = ( (%IN1% != 0.0) != (%IN2% != 0.0) ) ? 1.0 : 0.0;\n";
- return " T %TMP% = ( (%IN1% < EPSILON) != (%IN2% < EPSILON) ) ? 1.0 : 0.0;\n";
- case BITWAND:
- return " T %TMP% = bwAnd(%IN1%, %IN2%);\n";
- case SEQ_RIX:
- return " T %TMP% = %IN1% + grix * %IN2%;\n"; //0-based global rix
+ return " T %TMP% = ( (%IN1% < EPSILON) != (%IN2% < EPSILON) ) ? 1.0 : 0.0;\n";
+ case BITWAND:
+ return " T %TMP% = bwAnd(%IN1%, %IN2%);\n";
+ case SEQ_RIX:
+ return "\t\tT %TMP% = %IN1% + grix * %IN2%;\n"; //0-based global rix
- default:
- throw new RuntimeException("Invalid binary type: " + this.toString());
- }
+ default:
+ throw new RuntimeException("Invalid binary type: " + this.toString());
}
}
}
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Ternary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Ternary.java
index dd06d6c004..026fe264f8 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Ternary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Ternary.java
@@ -28,81 +28,41 @@ public class Ternary extends CodeTemplate {
@Override
public String getTemplate(CNodeTernary.TernaryType type, boolean sparse) {
- if(isSinglePrecision()) {
- switch (type) {
- case PLUS_MULT:
- return " T %TMP% = %IN1% + %IN2% * %IN3%;\n";
+ switch (type) {
+ case PLUS_MULT:
+ return " T %TMP% = %IN1% + %IN2% * %IN3%;\n";
- case MINUS_MULT:
- return " T %TMP% = %IN1% - %IN2% * %IN3%;\n";
+ case MINUS_MULT:
+ return " T %TMP% = %IN1% - %IN2% * %IN3%;\n";
- case BIASADD:
- return " T %TMP% = %IN1% + getValue(%IN2%, cix/%IN3%);\n";
+ case BIASADD:
+ return " T %TMP% = %IN1% + getValue(%IN2%, cix/%IN3%);\n";
- case BIASMULT:
- return " T %TMP% = %IN1% * getValue(%IN2%, cix/%IN3%);\n";
+ case BIASMULT:
+ return " T %TMP% = %IN1% * getValue(%IN2%, cix/%IN3%);\n";
- case REPLACE:
- return " T %TMP% = (%IN1% == %IN2% || (isnan(%IN1%) "
- + "&& isnan(%IN2%))) ? %IN3% : %IN1%;\n";
+ case REPLACE:
+ return " T %TMP% = (%IN1% == %IN2% || (isnan(%IN1%) "
+ + "&& isnan(%IN2%))) ? %IN3% : %IN1%;\n";
- case REPLACE_NAN:
- return " T %TMP% = isnan(%IN1%) ? %IN3% : %IN1%;\n";
+ case REPLACE_NAN:
+ return " T %TMP% = isnan(%IN1%) ? %IN3% : %IN1%;\n";
- case IFELSE:
- return " T %TMP% = (%IN1% != 0) ? %IN2% : %IN3%;\n";
+ case IFELSE:
+ return " T %TMP% = (%IN1% != 0) ? %IN2% : %IN3%;\n";
- case LOOKUP_RC1:
- return sparse ?
- " T %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, %IN3%-1);\n" :
+ case LOOKUP_RC1:
+ return sparse ?
+ " T %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, %IN3%-1);\n" :
// " T %TMP% = getValue(%IN1%, %IN2%, rix, %IN3%-1);\n";
- " T %TMP% = %IN1%.val(rix, %IN3%-1);\n";
+ " T %TMP% = %IN1%.val(rix, %IN3%-1);\n";
- case LOOKUP_RVECT1:
- return "\t\tVector<T>& %TMP% = getVector(%IN1%, %IN2%, rix, %IN3%-1);\n";
- default:
- throw new RuntimeException("Invalid ternary type: " + this.toString());
- }
- }
- else {
- switch (type) {
- case PLUS_MULT:
- return " T %TMP% = %IN1% + %IN2% * %IN3%;\n";
-
- case MINUS_MULT:
- return " T %TMP% = %IN1% - %IN2% * %IN3%;\n";
-
- case BIASADD:
- return " T %TMP% = %IN1% + getValue(%IN2%, cix/%IN3%);\n";
-
- case BIASMULT:
- return " T %TMP% = %IN1% * getValue(%IN2%, cix/%IN3%);\n";
-
- case REPLACE:
- return " T %TMP% = (%IN1% == %IN2% || (isnan(%IN1%) "
- + "&& isnan(%IN2%))) ? %IN3% : %IN1%;\n";
-
- case REPLACE_NAN:
- return " T %TMP% = isnan(%IN1%) ? %IN3% : %IN1%;\n";
-
- case IFELSE:
- return " T %TMP% = (%IN1% != 0) ? %IN2% : %IN3%;\n";
-
- case LOOKUP_RC1:
- return sparse ?
- " T %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, %IN3%-1);\n" :
-// " T %TMP% = getValue(%IN1%, %IN2%, rix, %IN3%-1);\n";
- " T %TMP% = %IN1%.val(rix, %IN3%-1);\n";
-
-
- case LOOKUP_RVECT1:
- return "\t\tVector<T>& %TMP% = getVector(%IN1%, %IN2%, rix, %IN3%-1, this);\n";
-
- default:
- throw new RuntimeException("Invalid ternary type: "+this.toString());
- }
+ case LOOKUP_RVECT1:
+ return "\t\tVector<T>& %TMP% = getVector(%IN1%, %IN2%, rix, %IN3%-1, this);\n";
+ default:
+ throw new RuntimeException("Invalid ternary type: "+this.toString());
}
}
}
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Unary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Unary.java
index f2405d5b5c..405b880715 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Unary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Unary.java
@@ -29,210 +29,161 @@ public class Unary extends CodeTemplate {
@Override
public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) {
- if(isSinglePrecision()) {
- switch( type ) {
- case ROW_SUMS:
- case ROW_SUMSQS:
- case ROW_MINS:
- case ROW_MAXS:
- case ROW_MEANS:
- case ROW_COUNTNNZS: {
- String vectName = StringUtils.capitalize(type.name().substring(4, type.name().length()-1).toLowerCase());
- return sparse ? " T %TMP% = LibSpoofPrimitives.vect"+vectName+"(%IN1v%, %IN1i%, %POS1%, alen, len);\n":
- " T %TMP% = LibSpoofPrimitives.vect"+vectName+"(%IN1%, %POS1%, %LEN%);\n";
- }
+ switch( type ) {
+ case ROW_SUMS:
+ case ROW_SUMSQS:
+ case ROW_MINS:
+ case ROW_MAXS:
+ case ROW_MEANS:
+ case ROW_COUNTNNZS: {
+ String vectName = StringUtils.capitalize(type.name().substring(4, type.name().length()-1).toLowerCase());
+ return sparse ? " T %TMP% = vect"+vectName+"(%IN1v%, %IN1i%, %POS1%, alen, %LEN%);\n":
+ " T %TMP% = vect"+vectName+"(%IN1%, static_cast<uint32_t>(%POS1%), %LEN%);\n";
- case VECT_EXP:
- case VECT_POW2:
- case VECT_MULT2:
- case VECT_SQRT:
- case VECT_LOG:
- case VECT_ABS:
- case VECT_ROUND:
- case VECT_CEIL:
- case VECT_FLOOR:
- case VECT_SIGN:
- case VECT_SIN:
- case VECT_COS:
- case VECT_TAN:
- case VECT_ASIN:
- case VECT_ACOS:
- case VECT_ATAN:
- case VECT_SINH:
- case VECT_COSH:
- case VECT_TANH:
- case VECT_CUMSUM:
- case VECT_CUMMIN:
- case VECT_CUMMAX:
- case VECT_SPROP:
- case VECT_SIGMOID: {
- String vectName = type.getVectorPrimitiveName();
- return sparse ? " T[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN1i%, %POS1%, alen, len);\n" :
- " T[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %POS1%, %LEN%);\n";
- }
-
- case EXP:
- return " T %TMP% = expf(%IN1%);\n";
- case LOOKUP_R:
- return sparse ?
- " T %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, 0);\n" :
- " T %TMP% = getValue(%IN1%, rix);\n";
- case LOOKUP_C:
- return " T %TMP% = getValue(%IN1%, n, 0, cix);\n";
- case LOOKUP_RC:
- return " T %TMP% = getValue(%IN1%, n, rix, cix);\n";
- case LOOKUP0:
- return " T %TMP% = %IN1%[0];\n";
- case POW2:
- return " T %TMP% = %IN1% * %IN1%;\n";
- case MULT2:
- return " T %TMP% = %IN1% + %IN1%;\n";
- case ABS:
- return " T %TMP% = fabsf(%IN1%);\n";
- case SIN:
- return " T %TMP% = sinf(%IN1%);\n";
- case COS:
- return " T %TMP% = cosf(%IN1%);\n";
- case TAN:
- return " T %TMP% = tanf(%IN1%);\n";
- case ASIN:
- return " T %TMP% = asinf(%IN1%);\n";
- case ACOS:
- return " T %TMP% = acosf(%IN1%);\n";
- case ATAN:
- return " T %TMP% = atanf(%IN1%);\n";
- case SINH:
- return " T %TMP% = sinhf(%IN1%);\n";
- case COSH:
- return " T %TMP% = coshf(%IN1%);\n";
- case TANH:
- return " T %TMP% = tanhf(%IN1%);\n";
- case SIGN:
- return " T %TMP% = signbit(%IN1%) == 0 ? 1.0f : -1.0f;\n";
- case SQRT:
- return " T %TMP% = sqrtf(%IN1%);\n";
- case LOG:
- return " T %TMP% = logf(%IN1%);\n";
- case ROUND:
- return " T %TMP% = roundf(%IN1%);\n";
- case CEIL:
- return " T %TMP% = ceilf(%IN1%);\n";
- case FLOOR:
- return " T %TMP% = floorf(%IN1%);\n";
- case SPROP:
- return " T %TMP% = %IN1% * (1 - %IN1%);\n";
- case SIGMOID:
- return " T %TMP% = 1 / (1 + expf(-%IN1%));\n";
- case LOG_NZ:
- return " T %TMP% = (%IN1%==0) ? 0 : logf(%IN1%);\n";
-
- default:
- throw new RuntimeException("Invalid unary type: "+this.toString());
}
- }
- else { /* double precision */
- switch( type ) {
- case ROW_SUMS:
- case ROW_SUMSQS:
- case ROW_MINS:
- case ROW_MAXS:
- case ROW_MEANS:
- case ROW_COUNTNNZS: {
- String vectName = StringUtils.capitalize(type.name().substring(4, type.name().length()-1).toLowerCase());
- return sparse ? " T %TMP% = vect"+vectName+"(%IN1v%, %IN1i%, %POS1%, alen, %LEN%);\n":
- " T %TMP% = vect"+vectName+"(%IN1%, static_cast<uint32_t>(%POS1%), %LEN%);\n";
-
- }
- case VECT_EXP:
- case VECT_POW2:
- case VECT_MULT2:
- case VECT_SQRT:
- case VECT_LOG:
- case VECT_ABS:
- case VECT_ROUND:
- case VECT_CEIL:
- case VECT_FLOOR:
- case VECT_SIGN:
- case VECT_SIN:
- case VECT_COS:
- case VECT_TAN:
- case VECT_ASIN:
- case VECT_ACOS:
- case VECT_ATAN:
- case VECT_SINH:
- case VECT_COSH:
- case VECT_TANH:
- case VECT_CUMSUM:
- case VECT_CUMMIN:
- case VECT_CUMMAX:
- case VECT_SPROP:
- case VECT_SIGMOID: {
- String vectName = type.getVectorPrimitiveName();
- return sparse ? " Vector<T>& %TMP% = vect"+vectName+"Write(%IN1v%, %IN1i%, %POS1%, alen, %LEN%, this);\n" :
- " Vector<T>& %TMP% = vect"+vectName+"Write(%IN1%, static_cast<uint32_t>(%POS1%), %LEN%, this);\n";
- }
+ case VECT_EXP:
+ case VECT_POW2:
+ case VECT_MULT2:
+ case VECT_SQRT:
+ case VECT_LOG:
+ case VECT_ABS:
+ case VECT_ROUND:
+ case VECT_CEIL:
+ case VECT_FLOOR:
+ case VECT_SIGN:
+ case VECT_SIN:
+ case VECT_COS:
+ case VECT_TAN:
+ case VECT_ASIN:
+ case VECT_ACOS:
+ case VECT_ATAN:
+ case VECT_SINH:
+ case VECT_COSH:
+ case VECT_TANH:
+ case VECT_CUMSUM:
+ case VECT_CUMMIN:
+ case VECT_CUMMAX:
+ case VECT_SPROP:
+ case VECT_SIGMOID: {
+ String vectName = type.getVectorPrimitiveName();
+ return sparse ? " Vector<T>& %TMP% = vect"+vectName+"Write(%IN1v%, %IN1i%, %POS1%, alen, %LEN%, this);\n" :
+ " Vector<T>& %TMP% = vect"+vectName+"Write(%IN1%, static_cast<uint32_t>(%POS1%), %LEN%, this);\n";
+ }
- case EXP:
+ case EXP:
+ if(isSinglePrecision())
+ return " T %TMP% = expf(%IN1%);\n";
+ else
return " T %TMP% = exp(%IN1%);\n";
- case LOOKUP_R:
- return sparse ?
- " T %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, 0);\n" :
- " T %TMP% = %IN1%.val(rix);\n";
-// " T %TMP% = getValue(%IN1%, rix);\n";
- case LOOKUP_C:
- return " T %TMP% = getValue(%IN1%, n, 0, cix);\n";
- case LOOKUP_RC:
- return " T %TMP% = getValue(%IN1%, n, rix, cix);\n";
- case LOOKUP0:
- return " T %TMP% = %IN1%[0];\n";
- case POW2:
- return " T %TMP% = %IN1% * %IN1%;\n";
- case MULT2:
- return " T %TMP% = %IN1% + %IN1%;\n";
- case ABS:
+ case LOOKUP_R:
+ return sparse ?
+ "\t\tT %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, 0);\n" :
+// " T %TMP% = %IN1%.val(rix);\n";
+ "\t\tT %TMP% = getValue(%IN1%, rix);\n";
+ case LOOKUP_C:
+ return "\t\tT %TMP% = getValue(%IN1%, n, 0, cix);\n";
+ case LOOKUP_RC:
+ return "\t\tT %TMP% = getValue(%IN1%, n, rix, cix);\n";
+ case LOOKUP0:
+ return "\t\tT %TMP% = %IN1%[0];\n";
+ case POW2:
+ return " T %TMP% = %IN1% * %IN1%;\n";
+ case MULT2:
+ return " T %TMP% = %IN1% + %IN1%;\n";
+ case ABS:
+ if(isSinglePrecision())
+ return " T %TMP% = fabsf(%IN1%);\n";
+ else
return "\t\tT %TMP% = fabs(%IN1%);\n";
- case SIN:
+ case SIN:
+ if(isSinglePrecision())
+ return " T %TMP% = sinf(%IN1%);\n";
+ else
return " T %TMP% = sin(%IN1%);\n";
- case COS:
+ case COS:
+ if(isSinglePrecision())
+ return " T %TMP% = cosf(%IN1%);\n";
+ else
return " T %TMP% = cos(%IN1%);\n";
- case TAN:
+ case TAN:
+ if(isSinglePrecision())
+ return " T %TMP% = tanf(%IN1%);\n";
+ else
return " T %TMP% = tan(%IN1%);\n";
- case ASIN:
+ case ASIN:
+ if(isSinglePrecision())
+ return " T %TMP% = asinf(%IN1%);\n";
+ else
return " T %TMP% = asin(%IN1%);\n";
- case ACOS:
+ case ACOS:
+ if(isSinglePrecision())
+ return " T %TMP% = acosf(%IN1%);\n";
+ else
return " T %TMP% = acos(%IN1%);\n";
- case ATAN:
+ case ATAN:
+ if(isSinglePrecision())
+ return " T %TMP% = atanf(%IN1%);\n";
+ else
return " T %TMP% = atan(%IN1%);\n";
- case SINH:
+ case SINH:
+ if(isSinglePrecision())
+ return " T %TMP% = sinhf(%IN1%);\n";
+ else
return " T %TMP% = sinh(%IN1%);\n";
- case COSH:
+ case COSH:
+ if(isSinglePrecision())
+ return " T %TMP% = coshf(%IN1%);\n";
+ else
return " T %TMP% = cosh(%IN1%);\n";
- case TANH:
+ case TANH:
+ if(isSinglePrecision())
+ return " T %TMP% = tanhf(%IN1%);\n";
+ else
return " T %TMP% = tanh(%IN1%);\n";
- case SIGN:
- return " T %TMP% = signbit(%IN1%) == 0 ? 1.0 : -1.0;\n";
- case SQRT:
+ case SIGN:
+ return " T %TMP% = signbit(%IN1%) == 0 ? 1.0 : -1.0;\n";
+ case SQRT:
+ if(isSinglePrecision())
+ return " T %TMP% = sqrtf(%IN1%);\n";
+ else
return " T %TMP% = sqrt(%IN1%);\n";
- case LOG:
+ case LOG:
+
+ if(isSinglePrecision())
+ return " T %TMP% = logf(%IN1%);\n";
+ else
return " T %TMP% = log(%IN1%);\n";
- case ROUND:
+ case ROUND:
+ if(isSinglePrecision())
+ return " T %TMP% = roundf(%IN1%);\n";
+ else
return "\t\tT %TMP% = round(%IN1%);\n";
- case CEIL:
+ case CEIL:
+ if(isSinglePrecision())
+ return " T %TMP% = ceilf(%IN1%);\n";
+ else
return " T %TMP% = ceil(%IN1%);\n";
- case FLOOR:
+ case FLOOR:
+ if(isSinglePrecision())
+ return " T %TMP% = floorf(%IN1%);\n";
+ else
return " T %TMP% = floor(%IN1%);\n";
- case SPROP:
- return " T %TMP% = %IN1% * (1 - %IN1%);\n";
- case SIGMOID:
+ case SPROP:
+ return " T %TMP% = %IN1% * (1 - %IN1%);\n";
+ case SIGMOID:
+ if(isSinglePrecision())
+ return " T %TMP% = 1 / (1 + expf(-%IN1%));\n";
+ else
return " T %TMP% = 1 / (1 + exp(-%IN1%));\n";
- case LOG_NZ:
+ case LOG_NZ:
+ if(isSinglePrecision())
+ return " T %TMP% = (%IN1%==0) ? 0 : logf(%IN1%);\n";
+ else
return " T %TMP% = (%IN1%==0) ? 0 : log(%IN1%);\n";
- default:
- throw new RuntimeException("Invalid unary type: "+this.toString());
- }
-
+ default:
+ throw new RuntimeException("Invalid unary type: "+this.toString());
}
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDACellwise.java b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDACellwise.java
index 03c35da540..cfe5780326 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDACellwise.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDACellwise.java
@@ -116,9 +116,9 @@ public class SpoofCUDACellwise extends SpoofCellwise implements SpoofCUDAOperato
}
public int execute_dp(long ctx) { return execute_d(ctx); }
- public int execute_sp(long ctx) { return execute_d(ctx); }
+ public int execute_sp(long ctx) { return execute_f(ctx); }
public long getContext() { return ctx; }
public static native int execute_d(long ctx);
- public static native int execute_s(long ctx);
+ public static native int execute_f(long ctx);
}
diff --git a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDARowwise.java b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDARowwise.java
index 47826a9461..0adf2ec605 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDARowwise.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDARowwise.java
@@ -96,9 +96,9 @@ public class SpoofCUDARowwise extends SpoofRowwise implements SpoofCUDAOperator
int ci, int alen, int n, long grix, int rix) { }
public int execute_dp(long ctx) { return execute_d(ctx); }
- public int execute_sp(long ctx) { return execute_d(ctx); }
+ public int execute_sp(long ctx) { return execute_f(ctx); }
public long getContext() { return ctx; }
public static native int execute_d(long ctx);
- public static native int execute_s(long ctx);
+ public static native int execute_f(long ctx);
}