You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2020/11/08 19:28:29 UTC
[systemds] branch master updated: [MINOR] Fix warnings (imports, resources) and wrong code formatting
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new c8a5433 [MINOR] Fix warnings (imports, resources) and wrong code formatting
c8a5433 is described below
commit c8a543317394131463e25a7f95c90a8d0f1c14fb
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Sun Nov 8 20:28:06 2020 +0100
[MINOR] Fix warnings (imports, resources) and wrong code formatting
---
src/main/cuda/ext/jitify | 1 -
.../apache/sysds/conf/ConfigurationManager.java | 1 -
.../apache/sysds/hops/codegen/SpoofCompiler.java | 27 +-
.../org/apache/sysds/hops/codegen/cplan/CNode.java | 2 -
.../sysds/hops/codegen/cplan/cpp/Binary.java | 544 ++++++++++-----------
.../sysds/hops/codegen/cplan/cpp/CellWise.java | 73 +--
.../sysds/hops/codegen/cplan/cpp/Ternary.java | 194 ++++----
.../apache/sysds/hops/codegen/cplan/cpp/Unary.java | 430 ++++++++--------
.../sysds/hops/codegen/cplan/java/Binary.java | 342 ++++++-------
.../sysds/hops/codegen/cplan/java/CellWise.java | 90 ++--
.../sysds/hops/codegen/cplan/java/Ternary.java | 116 ++---
.../sysds/hops/codegen/cplan/java/Unary.java | 230 ++++-----
.../org/apache/sysds/parser/LanguageException.java | 50 +-
.../apache/sysds/runtime/codegen/SpoofCUDA.java | 167 +++----
.../runtime/compress/colgroup/ColGroupConst.java | 450 ++++++++---------
.../instructions/fed/VariableFEDInstruction.java | 64 +--
.../instructions/gpu/MMTSJGPUInstruction.java | 70 +--
.../instructions/gpu/SpoofCUDAInstruction.java | 154 +++---
.../instructions/gpu/context/GPUContextPool.java | 1 -
.../sysds/runtime/io/ReaderWriterFederated.java | 298 +++++------
.../runtime/matrix/data/LibMatrixDatagen.java | 2 +-
21 files changed, 1652 insertions(+), 1654 deletions(-)
diff --git a/src/main/cuda/ext/jitify b/src/main/cuda/ext/jitify
deleted file mode 160000
index 3e96bcc..0000000
--- a/src/main/cuda/ext/jitify
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 3e96bcceb9e42105f6a32315abb2af04585a55b0
diff --git a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
index 516b956..aea8514 100644
--- a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
+++ b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
@@ -20,7 +20,6 @@
package org.apache.sysds.conf;
import org.apache.hadoop.mapred.JobConf;
-import org.apache.sysds.api.DMLScript;
import org.apache.sysds.conf.CompilerConfig.ConfigType;
diff --git a/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
index d388583..46bc481 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
@@ -266,20 +266,21 @@ public class SpoofCompiler {
}
private static void extractCodegenSources(String resource_path, String jar_path) throws IOException {
- JarFile jar_file = new JarFile(jar_path);
- Enumeration<JarEntry> files_in_jar = jar_file.entries();
-
- while (files_in_jar.hasMoreElements()) {
- JarEntry in_file = files_in_jar.nextElement();
- if (in_file.getName().startsWith("cuda/") && !in_file.isDirectory()) {
- File out_file = new File(resource_path, in_file.getName());
- out_file.deleteOnExit();
- File parent = out_file.getParentFile();
- if (parent != null) {
- parent.mkdirs();
- parent.deleteOnExit();
+ try(JarFile jar_file = new JarFile(jar_path)) {
+ Enumeration<JarEntry> files_in_jar = jar_file.entries();
+
+ while (files_in_jar.hasMoreElements()) {
+ JarEntry in_file = files_in_jar.nextElement();
+ if (in_file.getName().startsWith("cuda/") && !in_file.isDirectory()) {
+ File out_file = new File(resource_path, in_file.getName());
+ out_file.deleteOnExit();
+ File parent = out_file.getParentFile();
+ if (parent != null) {
+ parent.mkdirs();
+ parent.deleteOnExit();
+ }
+ IOUtils.copy(jar_file.getInputStream(in_file), FileUtils.openOutputStream(out_file));
}
- IOUtils.copy(jar_file.getInputStream(in_file), FileUtils.openOutputStream(out_file));
}
}
}
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNode.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNode.java
index a2f918e..07abcce 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNode.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNode.java
@@ -27,8 +27,6 @@ import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.hops.codegen.SpoofCompiler.GeneratorAPI;
-import static org.apache.sysds.hops.codegen.SpoofCompiler.GeneratorAPI.CUDA;
-
public abstract class CNode
{
private static final IDSequence _seqVar = new IDSequence();
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/cpp/Binary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/cpp/Binary.java
index 8d78b7b..287a884 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/cpp/Binary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/cpp/Binary.java
@@ -28,295 +28,295 @@ import org.apache.sysds.runtime.codegen.SpoofCellwise;
import static org.apache.sysds.runtime.matrix.data.LibMatrixNative.isSinglePrecision;
public class Binary implements CodeTemplate {
- @Override
- public String getTemplate() {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
+ @Override
+ public String getTemplate() {
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
- @Override
- public String getTemplate(SpoofCellwise.CellType ct) {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
+ @Override
+ public String getTemplate(SpoofCellwise.CellType ct) {
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
- @Override
- public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
+ @Override
+ public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) {
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
- public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, boolean sparseRhs, boolean scalarVector,
- boolean scalarInput) {
+ public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, boolean sparseRhs, boolean scalarVector,
+ boolean scalarInput) {
- if(isSinglePrecision()) {
- switch(type) {
- case DOT_PRODUCT:
- return sparseLhs ? " T %TMP% = LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" : " T %TMP% = LibSpoofPrimitives.dotProduct(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
- case VECT_MATRIXMULT:
- return sparseLhs ? " T[] %TMP% = LibSpoofPrimitives.vectMatrixMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, len);\n" : " T[] %TMP% = LibSpoofPrimitives.vectMatrixMult(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
- case VECT_OUTERMULT_ADD:
- return sparseLhs ? " LibSpoofPrimitives.vectOuterMultAdd(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : sparseRhs ? " LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2v%, %OUT%, %POS1%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : " LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2%, %OUT%, %POS1%, %POS2%, %POSOUT%, %LEN1%, %LEN2%);\n";
+ if(isSinglePrecision()) {
+ switch(type) {
+ case DOT_PRODUCT:
+ return sparseLhs ? " T %TMP% = LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" : " T %TMP% = LibSpoofPrimitives.dotProduct(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
+ case VECT_MATRIXMULT:
+ return sparseLhs ? " T[] %TMP% = LibSpoofPrimitives.vectMatrixMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, len);\n" : " T[] %TMP% = LibSpoofPrimitives.vectMatrixMult(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
+ case VECT_OUTERMULT_ADD:
+ return sparseLhs ? " LibSpoofPrimitives.vectOuterMultAdd(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : sparseRhs ? " LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2v%, %OUT%, %POS1%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : " LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2%, %OUT%, %POS1%, %POS2%, %POSOUT%, %LEN1%, %LEN2%);\n";
- //vector-scalar-add operations
- case VECT_MULT_ADD:
- case VECT_DIV_ADD:
- case VECT_MINUS_ADD:
- case VECT_PLUS_ADD:
- case VECT_POW_ADD:
- case VECT_XOR_ADD:
- case VECT_MIN_ADD:
- case VECT_MAX_ADD:
- case VECT_EQUAL_ADD:
- case VECT_NOTEQUAL_ADD:
- case VECT_LESS_ADD:
- case VECT_LESSEQUAL_ADD:
- case VECT_GREATER_ADD:
- case VECT_GREATEREQUAL_ADD:
- case VECT_CBIND_ADD: {
- String vectName = type.getVectorPrimitiveName();
- if(scalarVector)
- return sparseLhs ? " LibSpoofPrimitives.vect" + vectName + "Add(%IN1%, %IN2v%, %OUT%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN%);\n" : " LibSpoofPrimitives.vect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS2%, %POSOUT%, %LEN%);\n";
- else
- return sparseLhs ? " LibSpoofPrimitives.vect" + vectName + "Add(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POSOUT%, alen, %LEN%);\n" : " LibSpoofPrimitives.vect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS1%, %POSOUT%, %LEN%);\n";
- }
+ //vector-scalar-add operations
+ case VECT_MULT_ADD:
+ case VECT_DIV_ADD:
+ case VECT_MINUS_ADD:
+ case VECT_PLUS_ADD:
+ case VECT_POW_ADD:
+ case VECT_XOR_ADD:
+ case VECT_MIN_ADD:
+ case VECT_MAX_ADD:
+ case VECT_EQUAL_ADD:
+ case VECT_NOTEQUAL_ADD:
+ case VECT_LESS_ADD:
+ case VECT_LESSEQUAL_ADD:
+ case VECT_GREATER_ADD:
+ case VECT_GREATEREQUAL_ADD:
+ case VECT_CBIND_ADD: {
+ String vectName = type.getVectorPrimitiveName();
+ if(scalarVector)
+ return sparseLhs ? " LibSpoofPrimitives.vect" + vectName + "Add(%IN1%, %IN2v%, %OUT%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN%);\n" : " LibSpoofPrimitives.vect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS2%, %POSOUT%, %LEN%);\n";
+ else
+ return sparseLhs ? " LibSpoofPrimitives.vect" + vectName + "Add(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POSOUT%, alen, %LEN%);\n" : " LibSpoofPrimitives.vect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS1%, %POSOUT%, %LEN%);\n";
+ }
- //vector-scalar operations
- case VECT_MULT_SCALAR:
- case VECT_DIV_SCALAR:
- case VECT_MINUS_SCALAR:
- case VECT_PLUS_SCALAR:
- case VECT_POW_SCALAR:
- case VECT_XOR_SCALAR:
- case VECT_BITWAND_SCALAR:
- case VECT_MIN_SCALAR:
- case VECT_MAX_SCALAR:
- case VECT_EQUAL_SCALAR:
- case VECT_NOTEQUAL_SCALAR:
- case VECT_LESS_SCALAR:
- case VECT_LESSEQUAL_SCALAR:
- case VECT_GREATER_SCALAR:
- case VECT_GREATEREQUAL_SCALAR: {
- String vectName = type.getVectorPrimitiveName();
- if(scalarVector)
- return sparseRhs ? " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2v%, %IN2i%, %POS2%, alen, %LEN%);\n" : " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2%, %POS2%, %LEN%);\n";
- else
- return sparseLhs ? " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%);\n" : " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2%, %POS1%, %LEN%);\n";
- }
+ //vector-scalar operations
+ case VECT_MULT_SCALAR:
+ case VECT_DIV_SCALAR:
+ case VECT_MINUS_SCALAR:
+ case VECT_PLUS_SCALAR:
+ case VECT_POW_SCALAR:
+ case VECT_XOR_SCALAR:
+ case VECT_BITWAND_SCALAR:
+ case VECT_MIN_SCALAR:
+ case VECT_MAX_SCALAR:
+ case VECT_EQUAL_SCALAR:
+ case VECT_NOTEQUAL_SCALAR:
+ case VECT_LESS_SCALAR:
+ case VECT_LESSEQUAL_SCALAR:
+ case VECT_GREATER_SCALAR:
+ case VECT_GREATEREQUAL_SCALAR: {
+ String vectName = type.getVectorPrimitiveName();
+ if(scalarVector)
+ return sparseRhs ? " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2v%, %IN2i%, %POS2%, alen, %LEN%);\n" : " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2%, %POS2%, %LEN%);\n";
+ else
+ return sparseLhs ? " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%);\n" : " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2%, %POS1%, %LEN%);\n";
+ }
- case VECT_CBIND:
- if(scalarInput)
- return " T[] %TMP% = LibSpoofPrimitives.vectCbindWrite(%IN1%, %IN2%);\n";
- else
- return sparseLhs ? " T[] %TMP% = LibSpoofPrimitives.vectCbindWrite(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%);\n" : " T[] %TMP% = LibSpoofPrimitives.vectCbindWrite(%IN1%, %IN2%, %POS1%, %LEN%);\n";
+ case VECT_CBIND:
+ if(scalarInput)
+ return " T[] %TMP% = LibSpoofPrimitives.vectCbindWrite(%IN1%, %IN2%);\n";
+ else
+ return sparseLhs ? " T[] %TMP% = LibSpoofPrimitives.vectCbindWrite(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%);\n" : " T[] %TMP% = LibSpoofPrimitives.vectCbindWrite(%IN1%, %IN2%, %POS1%, %LEN%);\n";
- //vector-vector operations
- case VECT_MULT:
- case VECT_DIV:
- case VECT_MINUS:
- case VECT_PLUS:
- case VECT_XOR:
- case VECT_BITWAND:
- case VECT_BIASADD:
- case VECT_BIASMULT:
- case VECT_MIN:
- case VECT_MAX:
- case VECT_EQUAL:
- case VECT_NOTEQUAL:
- case VECT_LESS:
- case VECT_LESSEQUAL:
- case VECT_GREATER:
- case VECT_GREATEREQUAL: {
- String vectName = type.getVectorPrimitiveName();
- return sparseLhs ? " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, %LEN%);\n" : sparseRhs ? " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2v%, %POS1%, %IN2i%, %POS2%, alen, %LEN%);\n" : " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
- }
+ //vector-vector operations
+ case VECT_MULT:
+ case VECT_DIV:
+ case VECT_MINUS:
+ case VECT_PLUS:
+ case VECT_XOR:
+ case VECT_BITWAND:
+ case VECT_BIASADD:
+ case VECT_BIASMULT:
+ case VECT_MIN:
+ case VECT_MAX:
+ case VECT_EQUAL:
+ case VECT_NOTEQUAL:
+ case VECT_LESS:
+ case VECT_LESSEQUAL:
+ case VECT_GREATER:
+ case VECT_GREATEREQUAL: {
+ String vectName = type.getVectorPrimitiveName();
+ return sparseLhs ? " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, %LEN%);\n" : sparseRhs ? " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2v%, %POS1%, %IN2i%, %POS2%, alen, %LEN%);\n" : " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
+ }
- //scalar-scalar operations
- case MULT:
- return " T %TMP% = %IN1% * %IN2%;\n";
- case DIV:
- return " T %TMP% = %IN1% / %IN2%;\n";
- case PLUS:
- return " T %TMP% = %IN1% + %IN2%;\n";
- case MINUS:
- return " T %TMP% = %IN1% - %IN2%;\n";
- case MODULUS:
- return " T %TMP% = modulus(%IN1%, %IN2%);\n";
- case INTDIV:
- return " T %TMP% = intDiv(%IN1%, %IN2%);\n";
- case LESS:
- return " T %TMP% = (%IN1% < %IN2%) ? 1 : 0;\n";
- case LESSEQUAL:
- return " T %TMP% = (%IN1% <= %IN2%) ? 1 : 0;\n";
- case GREATER:
- return " T %TMP% = (%IN1% > %IN2%) ? 1 : 0;\n";
- case GREATEREQUAL:
- return " T %TMP% = (%IN1% >= %IN2%) ? 1 : 0;\n";
- case EQUAL:
- return " T %TMP% = (%IN1% == %IN2%) ? 1 : 0;\n";
- case NOTEQUAL:
- return " T %TMP% = (%IN1% != %IN2%) ? 1 : 0;\n";
+ //scalar-scalar operations
+ case MULT:
+ return " T %TMP% = %IN1% * %IN2%;\n";
+ case DIV:
+ return " T %TMP% = %IN1% / %IN2%;\n";
+ case PLUS:
+ return " T %TMP% = %IN1% + %IN2%;\n";
+ case MINUS:
+ return " T %TMP% = %IN1% - %IN2%;\n";
+ case MODULUS:
+ return " T %TMP% = modulus(%IN1%, %IN2%);\n";
+ case INTDIV:
+ return " T %TMP% = intDiv(%IN1%, %IN2%);\n";
+ case LESS:
+ return " T %TMP% = (%IN1% < %IN2%) ? 1 : 0;\n";
+ case LESSEQUAL:
+ return " T %TMP% = (%IN1% <= %IN2%) ? 1 : 0;\n";
+ case GREATER:
+ return " T %TMP% = (%IN1% > %IN2%) ? 1 : 0;\n";
+ case GREATEREQUAL:
+ return " T %TMP% = (%IN1% >= %IN2%) ? 1 : 0;\n";
+ case EQUAL:
+ return " T %TMP% = (%IN1% == %IN2%) ? 1 : 0;\n";
+ case NOTEQUAL:
+ return " T %TMP% = (%IN1% != %IN2%) ? 1 : 0;\n";
- case MIN:
- return " T %TMP% = fminf(%IN1%, %IN2%);\n";
- case MAX:
- return " T %TMP% = fmaxf(%IN1%, %IN2%);\n";
- case LOG:
- return " T %TMP% = logf(%IN1%)/Math.log(%IN2%);\n";
- case LOG_NZ:
- return " T %TMP% = (%IN1% == 0) ? 0 : logf(%IN1%) / logf(%IN2%);\n";
- case POW:
- return " T %TMP% = powf(%IN1%, %IN2%);\n";
- case MINUS1_MULT:
- return " T %TMP% = 1 - %IN1% * %IN2%;\n";
- case MINUS_NZ:
- return " T %TMP% = (%IN1% != 0) ? %IN1% - %IN2% : 0;\n";
- case XOR:
- return " T %TMP% = ( (%IN1% != 0) != (%IN2% != 0) ) ? 1.0f : 0.0f;\n";
- case BITWAND:
- return " T %TMP% = bwAnd(%IN1%, %IN2%);\n";
- case SEQ_RIX:
- return " T %TMP% = %IN1% + grix * %IN2%;\n"; //0-based global rix
+ case MIN:
+ return " T %TMP% = fminf(%IN1%, %IN2%);\n";
+ case MAX:
+ return " T %TMP% = fmaxf(%IN1%, %IN2%);\n";
+ case LOG:
+ return " T %TMP% = logf(%IN1%)/Math.log(%IN2%);\n";
+ case LOG_NZ:
+ return " T %TMP% = (%IN1% == 0) ? 0 : logf(%IN1%) / logf(%IN2%);\n";
+ case POW:
+ return " T %TMP% = powf(%IN1%, %IN2%);\n";
+ case MINUS1_MULT:
+ return " T %TMP% = 1 - %IN1% * %IN2%;\n";
+ case MINUS_NZ:
+ return " T %TMP% = (%IN1% != 0) ? %IN1% - %IN2% : 0;\n";
+ case XOR:
+ return " T %TMP% = ( (%IN1% != 0) != (%IN2% != 0) ) ? 1.0f : 0.0f;\n";
+ case BITWAND:
+ return " T %TMP% = bwAnd(%IN1%, %IN2%);\n";
+ case SEQ_RIX:
+ return " T %TMP% = %IN1% + grix * %IN2%;\n"; //0-based global rix
- default:
- throw new RuntimeException("Invalid binary type: " + this.toString());
- }
- }
- else {
- switch(type) {
- case DOT_PRODUCT:
- return sparseLhs ? " T %TMP% = LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" : " T %TMP% = LibSpoofPrimitives.dotProduct(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
- case VECT_MATRIXMULT:
- return sparseLhs ? " T[] %TMP% = LibSpoofPrimitives.vectMatrixMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, len);\n" : " T[] %TMP% = LibSpoofPrimitives.vectMatrixMult(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
- case VECT_OUTERMULT_ADD:
- return sparseLhs ? " LibSpoofPrimitives.vectOuterMultAdd(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : sparseRhs ? " LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2v%, %OUT%, %POS1%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : " LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2%, %OUT%, %POS1%, %POS2%, %POSOUT%, %LEN1%, %LEN2%);\n";
+ default:
+ throw new RuntimeException("Invalid binary type: " + this.toString());
+ }
+ }
+ else {
+ switch(type) {
+ case DOT_PRODUCT:
+ return sparseLhs ? " T %TMP% = LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" : " T %TMP% = LibSpoofPrimitives.dotProduct(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
+ case VECT_MATRIXMULT:
+ return sparseLhs ? " T[] %TMP% = LibSpoofPrimitives.vectMatrixMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, len);\n" : " T[] %TMP% = LibSpoofPrimitives.vectMatrixMult(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
+ case VECT_OUTERMULT_ADD:
+ return sparseLhs ? " LibSpoofPrimitives.vectOuterMultAdd(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : sparseRhs ? " LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2v%, %OUT%, %POS1%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : " LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2%, %OUT%, %POS1%, %POS2%, %POSOUT%, %LEN1%, %LEN2%);\n";
- //vector-scalar-add operations
- case VECT_MULT_ADD:
- case VECT_DIV_ADD:
- case VECT_MINUS_ADD:
- case VECT_PLUS_ADD:
- case VECT_POW_ADD:
- case VECT_XOR_ADD:
- case VECT_MIN_ADD:
- case VECT_MAX_ADD:
- case VECT_EQUAL_ADD:
- case VECT_NOTEQUAL_ADD:
- case VECT_LESS_ADD:
- case VECT_LESSEQUAL_ADD:
- case VECT_GREATER_ADD:
- case VECT_GREATEREQUAL_ADD:
- case VECT_CBIND_ADD: {
- String vectName = type.getVectorPrimitiveName();
- if(scalarVector)
- return sparseLhs ? " LibSpoofPrimitives.vect" + vectName + "Add(%IN1%, %IN2v%, %OUT%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN%);\n" : " LibSpoofPrimitives.vect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS2%, %POSOUT%, %LEN%);\n";
- else
- return sparseLhs ? " LibSpoofPrimitives.vect" + vectName + "Add(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POSOUT%, alen, %LEN%);\n" : " LibSpoofPrimitives.vect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS1%, %POSOUT%, %LEN%);\n";
- }
+ //vector-scalar-add operations
+ case VECT_MULT_ADD:
+ case VECT_DIV_ADD:
+ case VECT_MINUS_ADD:
+ case VECT_PLUS_ADD:
+ case VECT_POW_ADD:
+ case VECT_XOR_ADD:
+ case VECT_MIN_ADD:
+ case VECT_MAX_ADD:
+ case VECT_EQUAL_ADD:
+ case VECT_NOTEQUAL_ADD:
+ case VECT_LESS_ADD:
+ case VECT_LESSEQUAL_ADD:
+ case VECT_GREATER_ADD:
+ case VECT_GREATEREQUAL_ADD:
+ case VECT_CBIND_ADD: {
+ String vectName = type.getVectorPrimitiveName();
+ if(scalarVector)
+ return sparseLhs ? " LibSpoofPrimitives.vect" + vectName + "Add(%IN1%, %IN2v%, %OUT%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN%);\n" : " LibSpoofPrimitives.vect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS2%, %POSOUT%, %LEN%);\n";
+ else
+ return sparseLhs ? " LibSpoofPrimitives.vect" + vectName + "Add(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POSOUT%, alen, %LEN%);\n" : " LibSpoofPrimitives.vect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS1%, %POSOUT%, %LEN%);\n";
+ }
- //vector-scalar operations
- case VECT_MULT_SCALAR:
- case VECT_DIV_SCALAR:
- case VECT_MINUS_SCALAR:
- case VECT_PLUS_SCALAR:
- case VECT_POW_SCALAR:
- case VECT_XOR_SCALAR:
- case VECT_BITWAND_SCALAR:
- case VECT_MIN_SCALAR:
- case VECT_MAX_SCALAR:
- case VECT_EQUAL_SCALAR:
- case VECT_NOTEQUAL_SCALAR:
- case VECT_LESS_SCALAR:
- case VECT_LESSEQUAL_SCALAR:
- case VECT_GREATER_SCALAR:
- case VECT_GREATEREQUAL_SCALAR: {
- String vectName = type.getVectorPrimitiveName();
- if(scalarVector)
- return sparseRhs ? " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2v%, %IN2i%, %POS2%, alen, %LEN%);\n" : " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2%, %POS2%, %LEN%);\n";
- else
- return sparseLhs ? " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%);\n" : " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2%, %POS1%, %LEN%);\n";
- }
+ //vector-scalar operations
+ case VECT_MULT_SCALAR:
+ case VECT_DIV_SCALAR:
+ case VECT_MINUS_SCALAR:
+ case VECT_PLUS_SCALAR:
+ case VECT_POW_SCALAR:
+ case VECT_XOR_SCALAR:
+ case VECT_BITWAND_SCALAR:
+ case VECT_MIN_SCALAR:
+ case VECT_MAX_SCALAR:
+ case VECT_EQUAL_SCALAR:
+ case VECT_NOTEQUAL_SCALAR:
+ case VECT_LESS_SCALAR:
+ case VECT_LESSEQUAL_SCALAR:
+ case VECT_GREATER_SCALAR:
+ case VECT_GREATEREQUAL_SCALAR: {
+ String vectName = type.getVectorPrimitiveName();
+ if(scalarVector)
+ return sparseRhs ? " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2v%, %IN2i%, %POS2%, alen, %LEN%);\n" : " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2%, %POS2%, %LEN%);\n";
+ else
+ return sparseLhs ? " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%);\n" : " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2%, %POS1%, %LEN%);\n";
+ }
- case VECT_CBIND:
- if(scalarInput)
- return " T[] %TMP% = LibSpoofPrimitives.vectCbindWrite(%IN1%, %IN2%);\n";
- else
- return sparseLhs ? " T[] %TMP% = LibSpoofPrimitives.vectCbindWrite(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%);\n" : " T[] %TMP% = LibSpoofPrimitives.vectCbindWrite(%IN1%, %IN2%, %POS1%, %LEN%);\n";
+ case VECT_CBIND:
+ if(scalarInput)
+ return " T[] %TMP% = LibSpoofPrimitives.vectCbindWrite(%IN1%, %IN2%);\n";
+ else
+ return sparseLhs ? " T[] %TMP% = LibSpoofPrimitives.vectCbindWrite(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%);\n" : " T[] %TMP% = LibSpoofPrimitives.vectCbindWrite(%IN1%, %IN2%, %POS1%, %LEN%);\n";
- //vector-vector operations
- case VECT_MULT:
- case VECT_DIV:
- case VECT_MINUS:
- case VECT_PLUS:
- case VECT_XOR:
- case VECT_BITWAND:
- case VECT_BIASADD:
- case VECT_BIASMULT:
- case VECT_MIN:
- case VECT_MAX:
- case VECT_EQUAL:
- case VECT_NOTEQUAL:
- case VECT_LESS:
- case VECT_LESSEQUAL:
- case VECT_GREATER:
- case VECT_GREATEREQUAL: {
- String vectName = type.getVectorPrimitiveName();
- return sparseLhs ? " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, %LEN%);\n" : sparseRhs ? " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2v%, %POS1%, %IN2i%, %POS2%, alen, %LEN%);\n" : " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
- }
+ //vector-vector operations
+ case VECT_MULT:
+ case VECT_DIV:
+ case VECT_MINUS:
+ case VECT_PLUS:
+ case VECT_XOR:
+ case VECT_BITWAND:
+ case VECT_BIASADD:
+ case VECT_BIASMULT:
+ case VECT_MIN:
+ case VECT_MAX:
+ case VECT_EQUAL:
+ case VECT_NOTEQUAL:
+ case VECT_LESS:
+ case VECT_LESSEQUAL:
+ case VECT_GREATER:
+ case VECT_GREATEREQUAL: {
+ String vectName = type.getVectorPrimitiveName();
+ return sparseLhs ? " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, %LEN%);\n" : sparseRhs ? " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2v%, %POS1%, %IN2i%, %POS2%, alen, %LEN%);\n" : " T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
+ }
- //scalar-scalar operations
- case MULT:
- return " T %TMP% = %IN1% * %IN2%;\n";
- case DIV:
- return " T %TMP% = %IN1% / %IN2%;\n";
- case PLUS:
- return " T %TMP% = %IN1% + %IN2%;\n";
- case MINUS:
- return " T %TMP% = %IN1% - %IN2%;\n";
- case MODULUS:
- return " T %TMP% = modulus(%IN1%, %IN2%);\n";
- case INTDIV:
- return " T %TMP% = intDiv(%IN1%, %IN2%);\n";
- case LESS:
- return " T %TMP% = (%IN1% < %IN2%) ? 1.0 : 0.0;\n";
- case LESSEQUAL:
- return " T %TMP% = (%IN1% <= %IN2%) ? 1.0 : 0.0;\n";
- case GREATER:
- return " T %TMP% = (%IN1% > (%IN2% + EPSILON)) ? 1.0 : 0.0;\n";
- case GREATEREQUAL:
- return " T %TMP% = (%IN1% >= %IN2%) ? 1.0 : 0.0;\n";
- case EQUAL:
- return " T %TMP% = (%IN1% == %IN2%) ? 1.0 : 0.0;\n";
- case NOTEQUAL:
- return " T %TMP% = (%IN1% != %IN2%) ? 1.0 : 0.0;\n";
+ //scalar-scalar operations
+ case MULT:
+ return " T %TMP% = %IN1% * %IN2%;\n";
+ case DIV:
+ return " T %TMP% = %IN1% / %IN2%;\n";
+ case PLUS:
+ return " T %TMP% = %IN1% + %IN2%;\n";
+ case MINUS:
+ return " T %TMP% = %IN1% - %IN2%;\n";
+ case MODULUS:
+ return " T %TMP% = modulus(%IN1%, %IN2%);\n";
+ case INTDIV:
+ return " T %TMP% = intDiv(%IN1%, %IN2%);\n";
+ case LESS:
+ return " T %TMP% = (%IN1% < %IN2%) ? 1.0 : 0.0;\n";
+ case LESSEQUAL:
+ return " T %TMP% = (%IN1% <= %IN2%) ? 1.0 : 0.0;\n";
+ case GREATER:
+ return " T %TMP% = (%IN1% > (%IN2% + EPSILON)) ? 1.0 : 0.0;\n";
+ case GREATEREQUAL:
+ return " T %TMP% = (%IN1% >= %IN2%) ? 1.0 : 0.0;\n";
+ case EQUAL:
+ return " T %TMP% = (%IN1% == %IN2%) ? 1.0 : 0.0;\n";
+ case NOTEQUAL:
+ return " T %TMP% = (%IN1% != %IN2%) ? 1.0 : 0.0;\n";
- case MIN:
- return " T %TMP% = min(%IN1%, %IN2%);\n";
- case MAX:
- return " T %TMP% = max(%IN1%, %IN2%);\n";
- case LOG:
- return " T %TMP% = log(%IN1%)/Math.log(%IN2%);\n";
- case LOG_NZ:
- return " T %TMP% = (%IN1% == 0) ? 0 : log(%IN1%) / log(%IN2%);\n";
- case POW:
- return " T %TMP% = pow(%IN1%, %IN2%);\n";
- case MINUS1_MULT:
- return " T %TMP% = 1 - %IN1% * %IN2%;\n";
- case MINUS_NZ:
- return " T %TMP% = (%IN1% != 0) ? %IN1% - %IN2% : 0;\n";
- case XOR:
-// return " T %TMP% = ( (%IN1% != 0.0) != (%IN2% != 0.0) ) ? 1.0 : 0.0;\n";
- return " T %TMP% = ( (%IN1% < EPSILON) != (%IN2% < EPSILON) ) ? 1.0 : 0.0;\n";
- case BITWAND:
- return " T %TMP% = bwAnd(%IN1%, %IN2%);\n";
- case SEQ_RIX:
- return " T %TMP% = %IN1% + grix * %IN2%;\n"; //0-based global rix
+ case MIN:
+ return " T %TMP% = min(%IN1%, %IN2%);\n";
+ case MAX:
+ return " T %TMP% = max(%IN1%, %IN2%);\n";
+ case LOG:
+ return " T %TMP% = log(%IN1%)/Math.log(%IN2%);\n";
+ case LOG_NZ:
+ return " T %TMP% = (%IN1% == 0) ? 0 : log(%IN1%) / log(%IN2%);\n";
+ case POW:
+ return " T %TMP% = pow(%IN1%, %IN2%);\n";
+ case MINUS1_MULT:
+ return " T %TMP% = 1 - %IN1% * %IN2%;\n";
+ case MINUS_NZ:
+ return " T %TMP% = (%IN1% != 0) ? %IN1% - %IN2% : 0;\n";
+ case XOR:
+// return " T %TMP% = ( (%IN1% != 0.0) != (%IN2% != 0.0) ) ? 1.0 : 0.0;\n";
+ return " T %TMP% = ( (%IN1% < EPSILON) != (%IN2% < EPSILON) ) ? 1.0 : 0.0;\n";
+ case BITWAND:
+ return " T %TMP% = bwAnd(%IN1%, %IN2%);\n";
+ case SEQ_RIX:
+ return " T %TMP% = %IN1% + grix * %IN2%;\n"; //0-based global rix
- default:
- throw new RuntimeException("Invalid binary type: " + this.toString());
- }
- }
- }
+ default:
+ throw new RuntimeException("Invalid binary type: " + this.toString());
+ }
+ }
+ }
- @Override
- public String getTemplate(CNodeTernary.TernaryType type, boolean sparse) {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
+ @Override
+ public String getTemplate(CNodeTernary.TernaryType type, boolean sparse) {
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
}
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/cpp/CellWise.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/cpp/CellWise.java
index f76f3ec..7c14a40 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/cpp/CellWise.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/cpp/CellWise.java
@@ -19,6 +19,9 @@
package org.apache.sysds.hops.codegen.cplan.cpp;
+import java.io.FileInputStream;
+import java.io.IOException;
+
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
@@ -28,49 +31,47 @@ import org.apache.sysds.hops.codegen.cplan.CodeTemplate;
import org.apache.sysds.runtime.codegen.SpoofCellwise;
import org.apache.sysds.runtime.io.IOUtilFunctions;
-import java.io.*;
-import java.util.stream.Collectors;
// ToDo: clean code template and load from file
public class CellWise implements CodeTemplate {
- private static final String TEMPLATE_PATH = "/cuda/spoof/cellwise.cu";
+ private static final String TEMPLATE_PATH = "/cuda/spoof/cellwise.cu";
- @Override
- public String getTemplate() {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
+ @Override
+ public String getTemplate() {
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
- @Override
- public String getTemplate(SpoofCellwise.CellType ct) {
- try {
- // Change prefix to the code template file if running from jar. File were extracted to a temporary
- // directory in that case. By default we load the template from the source tree.
- if(CellWise.class.getProtectionDomain().getCodeSource().getLocation().getPath().contains(".jar"))
- return(IOUtilFunctions.toString(new FileInputStream(ConfigurationManager.getDMLConfig()
- .getTextValue(DMLConfig.LOCAL_TMP_DIR) + TEMPLATE_PATH)));
- else
- return IOUtilFunctions.toString(new FileInputStream(System.getProperty("user.dir") +
- "/src/main" + TEMPLATE_PATH));
- }
- catch(IOException e) {
- System.out.println(e.getMessage());
- return null;
- }
- }
+ @Override
+ public String getTemplate(SpoofCellwise.CellType ct) {
+ try {
+ // Change prefix to the code template file if running from jar. File were extracted to a temporary
+ // directory in that case. By default we load the template from the source tree.
+ if(CellWise.class.getProtectionDomain().getCodeSource().getLocation().getPath().contains(".jar"))
+ return(IOUtilFunctions.toString(new FileInputStream(ConfigurationManager.getDMLConfig()
+ .getTextValue(DMLConfig.LOCAL_TMP_DIR) + TEMPLATE_PATH)));
+ else
+ return IOUtilFunctions.toString(new FileInputStream(System.getProperty("user.dir") +
+ "/src/main" + TEMPLATE_PATH));
+ }
+ catch(IOException e) {
+ System.out.println(e.getMessage());
+ return null;
+ }
+ }
- @Override
- public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
+ @Override
+ public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) {
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
- @Override
- public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, boolean sparseRhs, boolean scalarVector, boolean scalarInput) {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
+ @Override
+ public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, boolean sparseRhs, boolean scalarVector, boolean scalarInput) {
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
- @Override
- public String getTemplate(CNodeTernary.TernaryType type, boolean sparse) {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
+ @Override
+ public String getTemplate(CNodeTernary.TernaryType type, boolean sparse) {
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
}
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/cpp/Ternary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/cpp/Ternary.java
index 3edfcea..ccce19b 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/cpp/Ternary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/cpp/Ternary.java
@@ -29,102 +29,102 @@ import static org.apache.sysds.runtime.matrix.data.LibMatrixNative.isSinglePreci
public class Ternary implements CodeTemplate {
- @Override
- public String getTemplate(CNodeTernary.TernaryType type, boolean sparse) {
- if(isSinglePrecision()) {
- switch (type) {
- case PLUS_MULT:
- return " T %TMP% = %IN1% + %IN2% * %IN3%;\n";
-
- case MINUS_MULT:
- return " T %TMP% = %IN1% - %IN2% * %IN3%;\n";
-
- case BIASADD:
- return " T %TMP% = %IN1% + getValue(%IN2%, cix/%IN3%);\n";
-
- case BIASMULT:
- return " T %TMP% = %IN1% * getValue(%IN2%, cix/%IN3%);\n";
-
- case REPLACE:
- return " T %TMP% = (%IN1% == %IN2% || (isnan(%IN1%) "
- + "&& isnan(%IN2%))) ? %IN3% : %IN1%;\n";
-
- case REPLACE_NAN:
- return " T %TMP% = isnan(%IN1%) ? %IN3% : %IN1%;\n";
-
- case IFELSE:
- return " T %TMP% = (%IN1% != 0) ? %IN2% : %IN3%;\n";
-
- case LOOKUP_RC1:
- return sparse ?
- " T %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, %IN3%-1);\n" :
- " T %TMP% = getValue(%IN1%, %IN2%, rix, %IN3%-1);\n";
-
- case LOOKUP_RVECT1:
- return " T[] %TMP% = getVector(%IN1%, %IN2%, rix, %IN3%-1);\n";
-
- default:
- throw new RuntimeException("Invalid ternary type: " + this.toString());
- }
- }
- else {
- switch (type) {
- case PLUS_MULT:
- return " T %TMP% = %IN1% + %IN2% * %IN3%;\n";
-
- case MINUS_MULT:
- return " T %TMP% = %IN1% - %IN2% * %IN3%;\n";
-
- case BIASADD:
- return " T %TMP% = %IN1% + getValue(%IN2%, cix/%IN3%);\n";
-
- case BIASMULT:
- return " T %TMP% = %IN1% * getValue(%IN2%, cix/%IN3%);\n";
-
- case REPLACE:
- return " T %TMP% = (%IN1% == %IN2% || (isnan(%IN1%) "
- + "&& isnan(%IN2%))) ? %IN3% : %IN1%;\n";
-
- case REPLACE_NAN:
- return " T %TMP% = isnan(%IN1%) ? %IN3% : %IN1%;\n";
-
- case IFELSE:
- return " T %TMP% = (%IN1% != 0) ? %IN2% : %IN3%;\n";
-
- case LOOKUP_RC1:
- return sparse ?
- " T %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, %IN3%-1);\n" :
- " T %TMP% = getValue(%IN1%, %IN2%, rix, %IN3%-1);\n";
-
- case LOOKUP_RVECT1:
- return " T[] %TMP% = getVector(%IN1%, %IN2%, rix, %IN3%-1);\n";
-
- default:
- throw new RuntimeException("Invalid ternary type: "+this.toString());
- }
-
- }
- }
-
- @Override
- public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, boolean sparseRhs, boolean scalarVector,
- boolean scalarInput) {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
-
- @Override
- public String getTemplate() {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
-
- @Override
- public String getTemplate(SpoofCellwise.CellType ct) {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
-
- @Override
- public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
+ @Override
+ public String getTemplate(CNodeTernary.TernaryType type, boolean sparse) {
+ if(isSinglePrecision()) {
+ switch (type) {
+ case PLUS_MULT:
+ return " T %TMP% = %IN1% + %IN2% * %IN3%;\n";
+
+ case MINUS_MULT:
+ return " T %TMP% = %IN1% - %IN2% * %IN3%;\n";
+
+ case BIASADD:
+ return " T %TMP% = %IN1% + getValue(%IN2%, cix/%IN3%);\n";
+
+ case BIASMULT:
+ return " T %TMP% = %IN1% * getValue(%IN2%, cix/%IN3%);\n";
+
+ case REPLACE:
+ return " T %TMP% = (%IN1% == %IN2% || (isnan(%IN1%) "
+ + "&& isnan(%IN2%))) ? %IN3% : %IN1%;\n";
+
+ case REPLACE_NAN:
+ return " T %TMP% = isnan(%IN1%) ? %IN3% : %IN1%;\n";
+
+ case IFELSE:
+ return " T %TMP% = (%IN1% != 0) ? %IN2% : %IN3%;\n";
+
+ case LOOKUP_RC1:
+ return sparse ?
+ " T %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, %IN3%-1);\n" :
+ " T %TMP% = getValue(%IN1%, %IN2%, rix, %IN3%-1);\n";
+
+ case LOOKUP_RVECT1:
+ return " T[] %TMP% = getVector(%IN1%, %IN2%, rix, %IN3%-1);\n";
+
+ default:
+ throw new RuntimeException("Invalid ternary type: " + this.toString());
+ }
+ }
+ else {
+ switch (type) {
+ case PLUS_MULT:
+ return " T %TMP% = %IN1% + %IN2% * %IN3%;\n";
+
+ case MINUS_MULT:
+ return " T %TMP% = %IN1% - %IN2% * %IN3%;\n";
+
+ case BIASADD:
+ return " T %TMP% = %IN1% + getValue(%IN2%, cix/%IN3%);\n";
+
+ case BIASMULT:
+ return " T %TMP% = %IN1% * getValue(%IN2%, cix/%IN3%);\n";
+
+ case REPLACE:
+ return " T %TMP% = (%IN1% == %IN2% || (isnan(%IN1%) "
+ + "&& isnan(%IN2%))) ? %IN3% : %IN1%;\n";
+
+ case REPLACE_NAN:
+ return " T %TMP% = isnan(%IN1%) ? %IN3% : %IN1%;\n";
+
+ case IFELSE:
+ return " T %TMP% = (%IN1% != 0) ? %IN2% : %IN3%;\n";
+
+ case LOOKUP_RC1:
+ return sparse ?
+ " T %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, %IN3%-1);\n" :
+ " T %TMP% = getValue(%IN1%, %IN2%, rix, %IN3%-1);\n";
+
+ case LOOKUP_RVECT1:
+ return " T[] %TMP% = getVector(%IN1%, %IN2%, rix, %IN3%-1);\n";
+
+ default:
+ throw new RuntimeException("Invalid ternary type: "+this.toString());
+ }
+
+ }
+ }
+
+ @Override
+ public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, boolean sparseRhs, boolean scalarVector,
+ boolean scalarInput) {
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
+
+ @Override
+ public String getTemplate() {
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
+
+ @Override
+ public String getTemplate(SpoofCellwise.CellType ct) {
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
+
+ @Override
+ public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) {
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
}
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/cpp/Unary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/cpp/Unary.java
index ed18779..d50e4b0 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/cpp/Unary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/cpp/Unary.java
@@ -29,230 +29,230 @@ import org.apache.sysds.runtime.codegen.SpoofCellwise;
import static org.apache.sysds.runtime.matrix.data.LibMatrixNative.isSinglePrecision;
public class Unary implements CodeTemplate {
- @Override
- public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) {
- if(isSinglePrecision()) {
- switch( type ) {
- case ROW_SUMS:
- case ROW_SUMSQS:
- case ROW_MINS:
- case ROW_MAXS:
- case ROW_MEANS:
- case ROW_COUNTNNZS: {
- String vectName = StringUtils.capitalize(type.name().substring(4, type.name().length()-1).toLowerCase());
- return sparse ? " T %TMP% = LibSpoofPrimitives.vect"+vectName+"(%IN1v%, %IN1i%, %POS1%, alen, len);\n":
- " T %TMP% = LibSpoofPrimitives.vect"+vectName+"(%IN1%, %POS1%, %LEN%);\n";
- }
+ @Override
+ public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) {
+ if(isSinglePrecision()) {
+ switch( type ) {
+ case ROW_SUMS:
+ case ROW_SUMSQS:
+ case ROW_MINS:
+ case ROW_MAXS:
+ case ROW_MEANS:
+ case ROW_COUNTNNZS: {
+ String vectName = StringUtils.capitalize(type.name().substring(4, type.name().length()-1).toLowerCase());
+ return sparse ? " T %TMP% = LibSpoofPrimitives.vect"+vectName+"(%IN1v%, %IN1i%, %POS1%, alen, len);\n":
+ " T %TMP% = LibSpoofPrimitives.vect"+vectName+"(%IN1%, %POS1%, %LEN%);\n";
+ }
- case VECT_EXP:
- case VECT_POW2:
- case VECT_MULT2:
- case VECT_SQRT:
- case VECT_LOG:
- case VECT_ABS:
- case VECT_ROUND:
- case VECT_CEIL:
- case VECT_FLOOR:
- case VECT_SIGN:
- case VECT_SIN:
- case VECT_COS:
- case VECT_TAN:
- case VECT_ASIN:
- case VECT_ACOS:
- case VECT_ATAN:
- case VECT_SINH:
- case VECT_COSH:
- case VECT_TANH:
- case VECT_CUMSUM:
- case VECT_CUMMIN:
- case VECT_CUMMAX:
- case VECT_SPROP:
- case VECT_SIGMOID: {
- String vectName = type.getVectorPrimitiveName();
- return sparse ? " T[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN1i%, %POS1%, alen, len);\n" :
- " T[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %POS1%, %LEN%);\n";
- }
+ case VECT_EXP:
+ case VECT_POW2:
+ case VECT_MULT2:
+ case VECT_SQRT:
+ case VECT_LOG:
+ case VECT_ABS:
+ case VECT_ROUND:
+ case VECT_CEIL:
+ case VECT_FLOOR:
+ case VECT_SIGN:
+ case VECT_SIN:
+ case VECT_COS:
+ case VECT_TAN:
+ case VECT_ASIN:
+ case VECT_ACOS:
+ case VECT_ATAN:
+ case VECT_SINH:
+ case VECT_COSH:
+ case VECT_TANH:
+ case VECT_CUMSUM:
+ case VECT_CUMMIN:
+ case VECT_CUMMAX:
+ case VECT_SPROP:
+ case VECT_SIGMOID: {
+ String vectName = type.getVectorPrimitiveName();
+ return sparse ? " T[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN1i%, %POS1%, alen, len);\n" :
+ " T[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %POS1%, %LEN%);\n";
+ }
- case EXP:
- return " T %TMP% = expf(%IN1%);\n";
- case LOOKUP_R:
- return sparse ?
- " T %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, 0);\n" :
- " T %TMP% = getValue(%IN1%, rix);\n";
- case LOOKUP_C:
- return " T %TMP% = getValue(%IN1%, n, 0, cix);\n";
- case LOOKUP_RC:
- return " T %TMP% = getValue(%IN1%, n, rix, cix);\n";
- case LOOKUP0:
- return " T %TMP% = %IN1%[0];\n";
- case POW2:
- return " T %TMP% = %IN1% * %IN1%;\n";
- case MULT2:
- return " T %TMP% = %IN1% + %IN1%;\n";
- case ABS:
- return " T %TMP% = fabsf(%IN1%);\n";
- case SIN:
- return " T %TMP% = sinf(%IN1%);\n";
- case COS:
- return " T %TMP% = cosf(%IN1%);\n";
- case TAN:
- return " T %TMP% = tanf(%IN1%);\n";
- case ASIN:
- return " T %TMP% = asinf(%IN1%);\n";
- case ACOS:
- return " T %TMP% = acosf(%IN1%);\n";
- case ATAN:
- return " T %TMP% = atanf(%IN1%);\n";
- case SINH:
- return " T %TMP% = sinhf(%IN1%);\n";
- case COSH:
- return " T %TMP% = coshf(%IN1%);\n";
- case TANH:
- return " T %TMP% = tanhf(%IN1%);\n";
- case SIGN:
- return " T %TMP% = signbit(%IN1%) == 0 ? 1.0f : -1.0f;\n";
- case SQRT:
- return " T %TMP% = sqrtf(%IN1%);\n";
- case LOG:
- return " T %TMP% = logf(%IN1%);\n";
- case ROUND:
- return " T %TMP% = roundf(%IN1%);\n";
- case CEIL:
- return " T %TMP% = ceilf(%IN1%);\n";
- case FLOOR:
- return " T %TMP% = floorf(%IN1%);\n";
- case SPROP:
- return " T %TMP% = %IN1% * (1 - %IN1%);\n";
- case SIGMOID:
- return " T %TMP% = 1 / (1 + expf(-%IN1%));\n";
- case LOG_NZ:
- return " T %TMP% = (%IN1%==0) ? 0 : logf(%IN1%);\n";
+ case EXP:
+ return " T %TMP% = expf(%IN1%);\n";
+ case LOOKUP_R:
+ return sparse ?
+ " T %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, 0);\n" :
+ " T %TMP% = getValue(%IN1%, rix);\n";
+ case LOOKUP_C:
+ return " T %TMP% = getValue(%IN1%, n, 0, cix);\n";
+ case LOOKUP_RC:
+ return " T %TMP% = getValue(%IN1%, n, rix, cix);\n";
+ case LOOKUP0:
+ return " T %TMP% = %IN1%[0];\n";
+ case POW2:
+ return " T %TMP% = %IN1% * %IN1%;\n";
+ case MULT2:
+ return " T %TMP% = %IN1% + %IN1%;\n";
+ case ABS:
+ return " T %TMP% = fabsf(%IN1%);\n";
+ case SIN:
+ return " T %TMP% = sinf(%IN1%);\n";
+ case COS:
+ return " T %TMP% = cosf(%IN1%);\n";
+ case TAN:
+ return " T %TMP% = tanf(%IN1%);\n";
+ case ASIN:
+ return " T %TMP% = asinf(%IN1%);\n";
+ case ACOS:
+ return " T %TMP% = acosf(%IN1%);\n";
+ case ATAN:
+ return " T %TMP% = atanf(%IN1%);\n";
+ case SINH:
+ return " T %TMP% = sinhf(%IN1%);\n";
+ case COSH:
+ return " T %TMP% = coshf(%IN1%);\n";
+ case TANH:
+ return " T %TMP% = tanhf(%IN1%);\n";
+ case SIGN:
+ return " T %TMP% = signbit(%IN1%) == 0 ? 1.0f : -1.0f;\n";
+ case SQRT:
+ return " T %TMP% = sqrtf(%IN1%);\n";
+ case LOG:
+ return " T %TMP% = logf(%IN1%);\n";
+ case ROUND:
+ return " T %TMP% = roundf(%IN1%);\n";
+ case CEIL:
+ return " T %TMP% = ceilf(%IN1%);\n";
+ case FLOOR:
+ return " T %TMP% = floorf(%IN1%);\n";
+ case SPROP:
+ return " T %TMP% = %IN1% * (1 - %IN1%);\n";
+ case SIGMOID:
+ return " T %TMP% = 1 / (1 + expf(-%IN1%));\n";
+ case LOG_NZ:
+ return " T %TMP% = (%IN1%==0) ? 0 : logf(%IN1%);\n";
- default:
- throw new RuntimeException("Invalid unary type: "+this.toString());
- }
- }
- else { /* double precision */
- switch( type ) {
- case ROW_SUMS:
- case ROW_SUMSQS:
- case ROW_MINS:
- case ROW_MAXS:
- case ROW_MEANS:
- case ROW_COUNTNNZS: {
- String vectName = StringUtils.capitalize(type.name().substring(4, type.name().length()-1).toLowerCase());
- return sparse ? " T %TMP% = LibSpoofPrimitives.vect"+vectName+"(%IN1v%, %IN1i%, %POS1%, alen, len);\n":
- " T %TMP% = LibSpoofPrimitives.vect"+vectName+"(%IN1%, %POS1%, %LEN%);\n";
- }
+ default:
+ throw new RuntimeException("Invalid unary type: "+this.toString());
+ }
+ }
+ else { /* double precision */
+ switch( type ) {
+ case ROW_SUMS:
+ case ROW_SUMSQS:
+ case ROW_MINS:
+ case ROW_MAXS:
+ case ROW_MEANS:
+ case ROW_COUNTNNZS: {
+ String vectName = StringUtils.capitalize(type.name().substring(4, type.name().length()-1).toLowerCase());
+ return sparse ? " T %TMP% = LibSpoofPrimitives.vect"+vectName+"(%IN1v%, %IN1i%, %POS1%, alen, len);\n":
+ " T %TMP% = LibSpoofPrimitives.vect"+vectName+"(%IN1%, %POS1%, %LEN%);\n";
+ }
- case VECT_EXP:
- case VECT_POW2:
- case VECT_MULT2:
- case VECT_SQRT:
- case VECT_LOG:
- case VECT_ABS:
- case VECT_ROUND:
- case VECT_CEIL:
- case VECT_FLOOR:
- case VECT_SIGN:
- case VECT_SIN:
- case VECT_COS:
- case VECT_TAN:
- case VECT_ASIN:
- case VECT_ACOS:
- case VECT_ATAN:
- case VECT_SINH:
- case VECT_COSH:
- case VECT_TANH:
- case VECT_CUMSUM:
- case VECT_CUMMIN:
- case VECT_CUMMAX:
- case VECT_SPROP:
- case VECT_SIGMOID: {
- String vectName = type.getVectorPrimitiveName();
- return sparse ? " T[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN1i%, %POS1%, alen, len);\n" :
- " T[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %POS1%, %LEN%);\n";
- }
+ case VECT_EXP:
+ case VECT_POW2:
+ case VECT_MULT2:
+ case VECT_SQRT:
+ case VECT_LOG:
+ case VECT_ABS:
+ case VECT_ROUND:
+ case VECT_CEIL:
+ case VECT_FLOOR:
+ case VECT_SIGN:
+ case VECT_SIN:
+ case VECT_COS:
+ case VECT_TAN:
+ case VECT_ASIN:
+ case VECT_ACOS:
+ case VECT_ATAN:
+ case VECT_SINH:
+ case VECT_COSH:
+ case VECT_TANH:
+ case VECT_CUMSUM:
+ case VECT_CUMMIN:
+ case VECT_CUMMAX:
+ case VECT_SPROP:
+ case VECT_SIGMOID: {
+ String vectName = type.getVectorPrimitiveName();
+ return sparse ? " T[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN1i%, %POS1%, alen, len);\n" :
+ " T[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %POS1%, %LEN%);\n";
+ }
- case EXP:
- return " T %TMP% = exp(%IN1%);\n";
- case LOOKUP_R:
- return sparse ?
- " T %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, 0);\n" :
- " T %TMP% = getValue(%IN1%, rix);\n";
- case LOOKUP_C:
- return " T %TMP% = getValue(%IN1%, n, 0, cix);\n";
- case LOOKUP_RC:
- return " T %TMP% = getValue(%IN1%, n, rix, cix);\n";
- case LOOKUP0:
- return " T %TMP% = %IN1%[0];\n";
- case POW2:
- return " T %TMP% = %IN1% * %IN1%;\n";
- case MULT2:
- return " T %TMP% = %IN1% + %IN1%;\n";
- case ABS:
- return " T %TMP% = fabs(%IN1%);\n";
- case SIN:
- return " T %TMP% = sin(%IN1%);\n";
- case COS:
- return " T %TMP% = cos(%IN1%);\n";
- case TAN:
- return " T %TMP% = tan(%IN1%);\n";
- case ASIN:
- return " T %TMP% = asin(%IN1%);\n";
- case ACOS:
- return " T %TMP% = acos(%IN1%);\n";
- case ATAN:
- return " T %TMP% = atan(%IN1%);\n";
- case SINH:
- return " T %TMP% = sinh(%IN1%);\n";
- case COSH:
- return " T %TMP% = cosh(%IN1%);\n";
- case TANH:
- return " T %TMP% = tanh(%IN1%);\n";
- case SIGN:
- return " T %TMP% = signbit(%IN1%) == 0 ? 1.0f : -1.0f;\n";
- case SQRT:
- return " T %TMP% = sqrt(%IN1%);\n";
- case LOG:
- return " T %TMP% = log(%IN1%);\n";
- case ROUND:
- return " T %TMP% = round(%IN1%);\n";
- case CEIL:
- return " T %TMP% = ceil(%IN1%);\n";
- case FLOOR:
- return " T %TMP% = floor(%IN1%);\n";
- case SPROP:
- return " T %TMP% = %IN1% * (1 - %IN1%);\n";
- case SIGMOID:
- return " T %TMP% = 1 / (1 + exp(-%IN1%));\n";
- case LOG_NZ:
- return " T %TMP% = (%IN1%==0) ? 0 : log(%IN1%);\n";
+ case EXP:
+ return " T %TMP% = exp(%IN1%);\n";
+ case LOOKUP_R:
+ return sparse ?
+ " T %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, 0);\n" :
+ " T %TMP% = getValue(%IN1%, rix);\n";
+ case LOOKUP_C:
+ return " T %TMP% = getValue(%IN1%, n, 0, cix);\n";
+ case LOOKUP_RC:
+ return " T %TMP% = getValue(%IN1%, n, rix, cix);\n";
+ case LOOKUP0:
+ return " T %TMP% = %IN1%[0];\n";
+ case POW2:
+ return " T %TMP% = %IN1% * %IN1%;\n";
+ case MULT2:
+ return " T %TMP% = %IN1% + %IN1%;\n";
+ case ABS:
+ return " T %TMP% = fabs(%IN1%);\n";
+ case SIN:
+ return " T %TMP% = sin(%IN1%);\n";
+ case COS:
+ return " T %TMP% = cos(%IN1%);\n";
+ case TAN:
+ return " T %TMP% = tan(%IN1%);\n";
+ case ASIN:
+ return " T %TMP% = asin(%IN1%);\n";
+ case ACOS:
+ return " T %TMP% = acos(%IN1%);\n";
+ case ATAN:
+ return " T %TMP% = atan(%IN1%);\n";
+ case SINH:
+ return " T %TMP% = sinh(%IN1%);\n";
+ case COSH:
+ return " T %TMP% = cosh(%IN1%);\n";
+ case TANH:
+ return " T %TMP% = tanh(%IN1%);\n";
+ case SIGN:
+ return " T %TMP% = signbit(%IN1%) == 0 ? 1.0f : -1.0f;\n";
+ case SQRT:
+ return " T %TMP% = sqrt(%IN1%);\n";
+ case LOG:
+ return " T %TMP% = log(%IN1%);\n";
+ case ROUND:
+ return " T %TMP% = round(%IN1%);\n";
+ case CEIL:
+ return " T %TMP% = ceil(%IN1%);\n";
+ case FLOOR:
+ return " T %TMP% = floor(%IN1%);\n";
+ case SPROP:
+ return " T %TMP% = %IN1% * (1 - %IN1%);\n";
+ case SIGMOID:
+ return " T %TMP% = 1 / (1 + exp(-%IN1%));\n";
+ case LOG_NZ:
+ return " T %TMP% = (%IN1%==0) ? 0 : log(%IN1%);\n";
- default:
- throw new RuntimeException("Invalid unary type: "+this.toString());
- }
+ default:
+ throw new RuntimeException("Invalid unary type: "+this.toString());
+ }
- }
- }
+ }
+ }
- @Override
- public String getTemplate() {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
+ @Override
+ public String getTemplate() {
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
- @Override
- public String getTemplate(SpoofCellwise.CellType ct) {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
+ @Override
+ public String getTemplate(SpoofCellwise.CellType ct) {
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
- @Override
- public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, boolean sparseRhs, boolean scalarVector, boolean scalarInput) {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
+ @Override
+ public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, boolean sparseRhs, boolean scalarVector, boolean scalarInput) {
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
- @Override
- public String getTemplate(CNodeTernary.TernaryType type, boolean sparse) {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
+ @Override
+ public String getTemplate(CNodeTernary.TernaryType type, boolean sparse) {
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
}
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java
index 39b0f6f..28b970a 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java
@@ -26,175 +26,175 @@ import org.apache.sysds.hops.codegen.cplan.CodeTemplate;
import org.apache.sysds.runtime.codegen.SpoofCellwise;
public class Binary implements CodeTemplate {
- @Override
- public String getTemplate(BinType type, boolean sparseLhs, boolean sparseRhs, boolean scalarVector,
- boolean scalarInput) {
-
- switch (type) {
- case DOT_PRODUCT:
- return sparseLhs ? " double %TMP% = LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" :
- " double %TMP% = LibSpoofPrimitives.dotProduct(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
- case VECT_MATRIXMULT:
- return sparseLhs ? " double[] %TMP% = LibSpoofPrimitives.vectMatrixMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, len);\n" :
- " double[] %TMP% = LibSpoofPrimitives.vectMatrixMult(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
- case VECT_OUTERMULT_ADD:
- return sparseLhs ? " LibSpoofPrimitives.vectOuterMultAdd(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" :
- sparseRhs ? " LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2v%, %OUT%, %POS1%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" :
- " LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2%, %OUT%, %POS1%, %POS2%, %POSOUT%, %LEN1%, %LEN2%);\n";
-
- //vector-scalar-add operations
- case VECT_MULT_ADD:
- case VECT_DIV_ADD:
- case VECT_MINUS_ADD:
- case VECT_PLUS_ADD:
- case VECT_POW_ADD:
- case VECT_XOR_ADD:
- case VECT_MIN_ADD:
- case VECT_MAX_ADD:
- case VECT_EQUAL_ADD:
- case VECT_NOTEQUAL_ADD:
- case VECT_LESS_ADD:
- case VECT_LESSEQUAL_ADD:
- case VECT_GREATER_ADD:
- case VECT_GREATEREQUAL_ADD:
- case VECT_CBIND_ADD: {
- String vectName = type.getVectorPrimitiveName();
- if( scalarVector )
- return sparseLhs ? " LibSpoofPrimitives.vect"+vectName+"Add(%IN1%, %IN2v%, %OUT%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN%);\n" :
- " LibSpoofPrimitives.vect"+vectName+"Add(%IN1%, %IN2%, %OUT%, %POS2%, %POSOUT%, %LEN%);\n";
- else
- return sparseLhs ? " LibSpoofPrimitives.vect"+vectName+"Add(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POSOUT%, alen, %LEN%);\n" :
- " LibSpoofPrimitives.vect"+vectName+"Add(%IN1%, %IN2%, %OUT%, %POS1%, %POSOUT%, %LEN%);\n";
- }
-
- //vector-scalar operations
- case VECT_MULT_SCALAR:
- case VECT_DIV_SCALAR:
- case VECT_MINUS_SCALAR:
- case VECT_PLUS_SCALAR:
- case VECT_POW_SCALAR:
- case VECT_XOR_SCALAR:
- case VECT_BITWAND_SCALAR:
- case VECT_MIN_SCALAR:
- case VECT_MAX_SCALAR:
- case VECT_EQUAL_SCALAR:
- case VECT_NOTEQUAL_SCALAR:
- case VECT_LESS_SCALAR:
- case VECT_LESSEQUAL_SCALAR:
- case VECT_GREATER_SCALAR:
- case VECT_GREATEREQUAL_SCALAR: {
- String vectName = type.getVectorPrimitiveName();
- if( scalarVector )
- return sparseRhs ? " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2v%, %IN2i%, %POS2%, alen, %LEN%);\n" :
- " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS2%, %LEN%);\n";
- else
- return sparseLhs ? " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%);\n" :
- " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS1%, %LEN%);\n";
- }
-
- case VECT_CBIND:
- if( scalarInput )
- return " double[] %TMP% = LibSpoofPrimitives.vectCbindWrite(%IN1%, %IN2%);\n";
- else
- return sparseLhs ?
- " double[] %TMP% = LibSpoofPrimitives.vectCbindWrite(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%);\n" :
- " double[] %TMP% = LibSpoofPrimitives.vectCbindWrite(%IN1%, %IN2%, %POS1%, %LEN%);\n";
-
- //vector-vector operations
- case VECT_MULT:
- case VECT_DIV:
- case VECT_MINUS:
- case VECT_PLUS:
- case VECT_XOR:
- case VECT_BITWAND:
- case VECT_BIASADD:
- case VECT_BIASMULT:
- case VECT_MIN:
- case VECT_MAX:
- case VECT_EQUAL:
- case VECT_NOTEQUAL:
- case VECT_LESS:
- case VECT_LESSEQUAL:
- case VECT_GREATER:
- case VECT_GREATEREQUAL: {
- String vectName = type.getVectorPrimitiveName();
- return sparseLhs ?
- " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, %LEN%);\n" :
- sparseRhs ?
- " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2v%, %POS1%, %IN2i%, %POS2%, alen, %LEN%);\n" :
- " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
- }
-
- //scalar-scalar operations
- case MULT:
- return " double %TMP% = %IN1% * %IN2%;\n";
-
- case DIV:
- return " double %TMP% = %IN1% / %IN2%;\n";
- case PLUS:
- return " double %TMP% = %IN1% + %IN2%;\n";
- case MINUS:
- return " double %TMP% = %IN1% - %IN2%;\n";
- case MODULUS:
- return " double %TMP% = LibSpoofPrimitives.mod(%IN1%, %IN2%);\n";
- case INTDIV:
- return " double %TMP% = LibSpoofPrimitives.intDiv(%IN1%, %IN2%);\n";
- case LESS:
- return " double %TMP% = (%IN1% < %IN2%) ? 1 : 0;\n";
- case LESSEQUAL:
- return " double %TMP% = (%IN1% <= %IN2%) ? 1 : 0;\n";
- case GREATER:
- return " double %TMP% = (%IN1% > %IN2%) ? 1 : 0;\n";
- case GREATEREQUAL:
- return " double %TMP% = (%IN1% >= %IN2%) ? 1 : 0;\n";
- case EQUAL:
- return " double %TMP% = (%IN1% == %IN2%) ? 1 : 0;\n";
- case NOTEQUAL:
- return " double %TMP% = (%IN1% != %IN2%) ? 1 : 0;\n";
-
- case MIN:
- return " double %TMP% = Math.min(%IN1%, %IN2%);\n";
- case MAX:
- return " double %TMP% = Math.max(%IN1%, %IN2%);\n";
- case LOG:
- return " double %TMP% = Math.log(%IN1%)/Math.log(%IN2%);\n";
- case LOG_NZ:
- return " double %TMP% = (%IN1% == 0) ? 0 : Math.log(%IN1%)/Math.log(%IN2%);\n";
- case POW:
- return " double %TMP% = Math.pow(%IN1%, %IN2%);\n";
- case MINUS1_MULT:
- return " double %TMP% = 1 - %IN1% * %IN2%;\n";
- case MINUS_NZ:
- return " double %TMP% = (%IN1% != 0) ? %IN1% - %IN2% : 0;\n";
- case XOR:
- return " double %TMP% = ( (%IN1% != 0) != (%IN2% != 0) ) ? 1 : 0;\n";
- case BITWAND:
- return " double %TMP% = LibSpoofPrimitives.bwAnd(%IN1%, %IN2%);\n";
- case SEQ_RIX:
- return " double %TMP% = %IN1% + grix * %IN2%;\n"; //0-based global rix
-
- default:
- throw new RuntimeException("Invalid binary type: "+this.toString());
- }
- }
-
- @Override
- public String getTemplate() {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
-
- @Override
- public String getTemplate(SpoofCellwise.CellType ct) {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
-
- @Override
- public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
-
- @Override
- public String getTemplate(CNodeTernary.TernaryType type, boolean sparse) {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
+ @Override
+ public String getTemplate(BinType type, boolean sparseLhs, boolean sparseRhs, boolean scalarVector,
+ boolean scalarInput) {
+
+ switch (type) {
+ case DOT_PRODUCT:
+ return sparseLhs ? " double %TMP% = LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" :
+ " double %TMP% = LibSpoofPrimitives.dotProduct(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
+ case VECT_MATRIXMULT:
+ return sparseLhs ? " double[] %TMP% = LibSpoofPrimitives.vectMatrixMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, len);\n" :
+ " double[] %TMP% = LibSpoofPrimitives.vectMatrixMult(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
+ case VECT_OUTERMULT_ADD:
+ return sparseLhs ? " LibSpoofPrimitives.vectOuterMultAdd(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" :
+ sparseRhs ? " LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2v%, %OUT%, %POS1%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" :
+ " LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2%, %OUT%, %POS1%, %POS2%, %POSOUT%, %LEN1%, %LEN2%);\n";
+
+ //vector-scalar-add operations
+ case VECT_MULT_ADD:
+ case VECT_DIV_ADD:
+ case VECT_MINUS_ADD:
+ case VECT_PLUS_ADD:
+ case VECT_POW_ADD:
+ case VECT_XOR_ADD:
+ case VECT_MIN_ADD:
+ case VECT_MAX_ADD:
+ case VECT_EQUAL_ADD:
+ case VECT_NOTEQUAL_ADD:
+ case VECT_LESS_ADD:
+ case VECT_LESSEQUAL_ADD:
+ case VECT_GREATER_ADD:
+ case VECT_GREATEREQUAL_ADD:
+ case VECT_CBIND_ADD: {
+ String vectName = type.getVectorPrimitiveName();
+ if( scalarVector )
+ return sparseLhs ? " LibSpoofPrimitives.vect"+vectName+"Add(%IN1%, %IN2v%, %OUT%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN%);\n" :
+ " LibSpoofPrimitives.vect"+vectName+"Add(%IN1%, %IN2%, %OUT%, %POS2%, %POSOUT%, %LEN%);\n";
+ else
+ return sparseLhs ? " LibSpoofPrimitives.vect"+vectName+"Add(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POSOUT%, alen, %LEN%);\n" :
+ " LibSpoofPrimitives.vect"+vectName+"Add(%IN1%, %IN2%, %OUT%, %POS1%, %POSOUT%, %LEN%);\n";
+ }
+
+ //vector-scalar operations
+ case VECT_MULT_SCALAR:
+ case VECT_DIV_SCALAR:
+ case VECT_MINUS_SCALAR:
+ case VECT_PLUS_SCALAR:
+ case VECT_POW_SCALAR:
+ case VECT_XOR_SCALAR:
+ case VECT_BITWAND_SCALAR:
+ case VECT_MIN_SCALAR:
+ case VECT_MAX_SCALAR:
+ case VECT_EQUAL_SCALAR:
+ case VECT_NOTEQUAL_SCALAR:
+ case VECT_LESS_SCALAR:
+ case VECT_LESSEQUAL_SCALAR:
+ case VECT_GREATER_SCALAR:
+ case VECT_GREATEREQUAL_SCALAR: {
+ String vectName = type.getVectorPrimitiveName();
+ if( scalarVector )
+ return sparseRhs ? " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2v%, %IN2i%, %POS2%, alen, %LEN%);\n" :
+ " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS2%, %LEN%);\n";
+ else
+ return sparseLhs ? " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%);\n" :
+ " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS1%, %LEN%);\n";
+ }
+
+ case VECT_CBIND:
+ if( scalarInput )
+ return " double[] %TMP% = LibSpoofPrimitives.vectCbindWrite(%IN1%, %IN2%);\n";
+ else
+ return sparseLhs ?
+ " double[] %TMP% = LibSpoofPrimitives.vectCbindWrite(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%);\n" :
+ " double[] %TMP% = LibSpoofPrimitives.vectCbindWrite(%IN1%, %IN2%, %POS1%, %LEN%);\n";
+
+ //vector-vector operations
+ case VECT_MULT:
+ case VECT_DIV:
+ case VECT_MINUS:
+ case VECT_PLUS:
+ case VECT_XOR:
+ case VECT_BITWAND:
+ case VECT_BIASADD:
+ case VECT_BIASMULT:
+ case VECT_MIN:
+ case VECT_MAX:
+ case VECT_EQUAL:
+ case VECT_NOTEQUAL:
+ case VECT_LESS:
+ case VECT_LESSEQUAL:
+ case VECT_GREATER:
+ case VECT_GREATEREQUAL: {
+ String vectName = type.getVectorPrimitiveName();
+ return sparseLhs ?
+ " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, %LEN%);\n" :
+ sparseRhs ?
+ " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2v%, %POS1%, %IN2i%, %POS2%, alen, %LEN%);\n" :
+ " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
+ }
+
+ //scalar-scalar operations
+ case MULT:
+ return " double %TMP% = %IN1% * %IN2%;\n";
+
+ case DIV:
+ return " double %TMP% = %IN1% / %IN2%;\n";
+ case PLUS:
+ return " double %TMP% = %IN1% + %IN2%;\n";
+ case MINUS:
+ return " double %TMP% = %IN1% - %IN2%;\n";
+ case MODULUS:
+ return " double %TMP% = LibSpoofPrimitives.mod(%IN1%, %IN2%);\n";
+ case INTDIV:
+ return " double %TMP% = LibSpoofPrimitives.intDiv(%IN1%, %IN2%);\n";
+ case LESS:
+ return " double %TMP% = (%IN1% < %IN2%) ? 1 : 0;\n";
+ case LESSEQUAL:
+ return " double %TMP% = (%IN1% <= %IN2%) ? 1 : 0;\n";
+ case GREATER:
+ return " double %TMP% = (%IN1% > %IN2%) ? 1 : 0;\n";
+ case GREATEREQUAL:
+ return " double %TMP% = (%IN1% >= %IN2%) ? 1 : 0;\n";
+ case EQUAL:
+ return " double %TMP% = (%IN1% == %IN2%) ? 1 : 0;\n";
+ case NOTEQUAL:
+ return " double %TMP% = (%IN1% != %IN2%) ? 1 : 0;\n";
+
+ case MIN:
+ return " double %TMP% = Math.min(%IN1%, %IN2%);\n";
+ case MAX:
+ return " double %TMP% = Math.max(%IN1%, %IN2%);\n";
+ case LOG:
+ return " double %TMP% = Math.log(%IN1%)/Math.log(%IN2%);\n";
+ case LOG_NZ:
+ return " double %TMP% = (%IN1% == 0) ? 0 : Math.log(%IN1%)/Math.log(%IN2%);\n";
+ case POW:
+ return " double %TMP% = Math.pow(%IN1%, %IN2%);\n";
+ case MINUS1_MULT:
+ return " double %TMP% = 1 - %IN1% * %IN2%;\n";
+ case MINUS_NZ:
+ return " double %TMP% = (%IN1% != 0) ? %IN1% - %IN2% : 0;\n";
+ case XOR:
+ return " double %TMP% = ( (%IN1% != 0) != (%IN2% != 0) ) ? 1 : 0;\n";
+ case BITWAND:
+ return " double %TMP% = LibSpoofPrimitives.bwAnd(%IN1%, %IN2%);\n";
+ case SEQ_RIX:
+ return " double %TMP% = %IN1% + grix * %IN2%;\n"; //0-based global rix
+
+ default:
+ throw new RuntimeException("Invalid binary type: "+this.toString());
+ }
+ }
+
+ @Override
+ public String getTemplate() {
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
+
+ @Override
+ public String getTemplate(SpoofCellwise.CellType ct) {
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
+
+ @Override
+ public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) {
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
+
+ @Override
+ public String getTemplate(CNodeTernary.TernaryType type, boolean sparse) {
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
}
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/CellWise.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/CellWise.java
index 319c872..85476a7 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/CellWise.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/CellWise.java
@@ -26,54 +26,54 @@ import org.apache.sysds.hops.codegen.cplan.CodeTemplate;
import org.apache.sysds.runtime.codegen.SpoofCellwise;
public class CellWise implements CodeTemplate {
- public static final String TEMPLATE =
- "package codegen;\n"
- + "import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;\n"
- + "import org.apache.sysds.runtime.codegen.SpoofCellwise;\n"
- + "import org.apache.sysds.runtime.codegen.SpoofCellwise.AggOp;\n"
- + "import org.apache.sysds.runtime.codegen.SpoofCellwise.CellType;\n"
- + "import org.apache.sysds.runtime.codegen.SpoofOperator.SideInput;\n"
- + "import org.apache.commons.math3.util.FastMath;\n"
- + "\n"
- + "public final class %TMP% extends SpoofCellwise {\n"
- + " public %TMP%() {\n"
- + " super(CellType.%TYPE%, %SPARSE_SAFE%, %SEQ%, %AGG_OP_NAME%);\n"
- + " }\n"
- + " protected double genexec(double a, SideInput[] b, double[] scalars, int m, int n, long grix, int rix, int cix) { \n"
- + "%BODY_dense%"
- + " return %OUT%;\n"
- + " }\n"
- + "}\n";
+ public static final String TEMPLATE = // Java source skeleton of a generated SpoofCellwise subclass; %...% tokens are placeholders, presumably substituted by the codegen compiler
+ "package codegen;\n"
+ + "import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;\n"
+ + "import org.apache.sysds.runtime.codegen.SpoofCellwise;\n"
+ + "import org.apache.sysds.runtime.codegen.SpoofCellwise.AggOp;\n"
+ + "import org.apache.sysds.runtime.codegen.SpoofCellwise.CellType;\n"
+ + "import org.apache.sysds.runtime.codegen.SpoofOperator.SideInput;\n"
+ + "import org.apache.commons.math3.util.FastMath;\n"
+ + "\n"
+ + "public final class %TMP% extends SpoofCellwise {\n"
+ + " public %TMP%() {\n"
+ + " super(CellType.%TYPE%, %SPARSE_SAFE%, %SEQ%, %AGG_OP_NAME%);\n"
+ + " }\n"
+ + " protected double genexec(double a, SideInput[] b, double[] scalars, int m, int n, long grix, int rix, int cix) { \n" // per-cell kernel body: %BODY_dense% computes %OUT%
+ + "%BODY_dense%"
+ + " return %OUT%;\n"
+ + " }\n"
+ + "}\n";
- @Override
- public String getTemplate() {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
+ @Override
+ public String getTemplate() { // not applicable: CellWise only serves the CellType-keyed template
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
- @Override
- public String getTemplate(SpoofCellwise.CellType ct) {
- switch(ct) {
- case NO_AGG:
- case FULL_AGG:
- case ROW_AGG:
- case COL_AGG:
- default:
- return TEMPLATE;
- }
- }
+ @Override
+ public String getTemplate(SpoofCellwise.CellType ct) { // every cell type currently maps to the same TEMPLATE
+ switch(ct) {
+ case NO_AGG:
+ case FULL_AGG:
+ case ROW_AGG:
+ case COL_AGG:
+ default: // NO_AGG/FULL_AGG/ROW_AGG/COL_AGG intentionally fall through here
+ return TEMPLATE;
+ }
+ }
- @Override
- public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
+ @Override
+ public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) { // not applicable to the cellwise template
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
- @Override
- public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, boolean sparseRhs, boolean scalarVector, boolean scalarInput) {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
+ @Override
+ public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, boolean sparseRhs, boolean scalarVector, boolean scalarInput) { // not applicable to the cellwise template
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
- @Override
- public String getTemplate(CNodeTernary.TernaryType type, boolean sparse) {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
+ @Override
+ public String getTemplate(CNodeTernary.TernaryType type, boolean sparse) { // not applicable to the cellwise template
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
}
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Ternary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Ternary.java
index af48d05..a499f49 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Ternary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Ternary.java
@@ -27,63 +27,63 @@ import org.apache.sysds.runtime.codegen.SpoofCellwise;
public class Ternary implements CodeTemplate {
- @Override
- public String getTemplate(CNodeTernary.TernaryType type, boolean sparse) {
- switch (type) {
- case PLUS_MULT:
- return " double %TMP% = %IN1% + %IN2% * %IN3%;\n";
-
- case MINUS_MULT:
- return " double %TMP% = %IN1% - %IN2% * %IN3%;\n";
-
- case BIASADD:
- return " double %TMP% = %IN1% + getValue(%IN2%, cix/%IN3%);\n";
-
- case BIASMULT:
- return " double %TMP% = %IN1% * getValue(%IN2%, cix/%IN3%);\n";
-
- case REPLACE:
- return " double %TMP% = (%IN1% == %IN2% || (Double.isNaN(%IN1%) "
- + "&& Double.isNaN(%IN2%))) ? %IN3% : %IN1%;\n";
-
- case REPLACE_NAN:
- return " double %TMP% = Double.isNaN(%IN1%) ? %IN3% : %IN1%;\n";
-
- case IFELSE:
- return " double %TMP% = (%IN1% != 0) ? %IN2% : %IN3%;\n";
-
- case LOOKUP_RC1:
- return sparse ?
- " double %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, %IN3%-1);\n" :
- " double %TMP% = getValue(%IN1%, %IN2%, rix, %IN3%-1);\n";
-
- case LOOKUP_RVECT1:
- return " double[] %TMP% = getVector(%IN1%, %IN2%, rix, %IN3%-1);\n";
-
- default:
- throw new RuntimeException("Invalid ternary type: "+this.toString());
- }
- }
-
- @Override
- public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, boolean sparseRhs, boolean scalarVector,
- boolean scalarInput) {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
-
- @Override
- public String getTemplate() {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
-
- @Override
- public String getTemplate(SpoofCellwise.CellType ct) {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
-
- @Override
- public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
+ @Override
+ public String getTemplate(CNodeTernary.TernaryType type, boolean sparse) { // Java statement template for one ternary op; %TMP%/%INk% are placeholders
+ switch (type) {
+ case PLUS_MULT:
+ return " double %TMP% = %IN1% + %IN2% * %IN3%;\n";
+
+ case MINUS_MULT:
+ return " double %TMP% = %IN1% - %IN2% * %IN3%;\n";
+
+ case BIASADD:
+ return " double %TMP% = %IN1% + getValue(%IN2%, cix/%IN3%);\n";
+
+ case BIASMULT:
+ return " double %TMP% = %IN1% * getValue(%IN2%, cix/%IN3%);\n";
+
+ case REPLACE: // NaN never equals NaN, so NaN patterns are matched explicitly
+ return " double %TMP% = (%IN1% == %IN2% || (Double.isNaN(%IN1%) "
+ + "&& Double.isNaN(%IN2%))) ? %IN3% : %IN1%;\n";
+
+ case REPLACE_NAN:
+ return " double %TMP% = Double.isNaN(%IN1%) ? %IN3% : %IN1%;\n";
+
+ case IFELSE:
+ return " double %TMP% = (%IN1% != 0) ? %IN2% : %IN3%;\n";
+
+ case LOOKUP_RC1: // sparse variant reads the row's value/index arrays (%IN1v%, %IN1i%)
+ return sparse ?
+ " double %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, %IN3%-1);\n" :
+ " double %TMP% = getValue(%IN1%, %IN2%, rix, %IN3%-1);\n"; // %IN3%-1: presumably a 1-based column index made 0-based
+
+ case LOOKUP_RVECT1:
+ return " double[] %TMP% = getVector(%IN1%, %IN2%, rix, %IN3%-1);\n";
+
+ default:
+ throw new RuntimeException("Invalid ternary type: "+type); // report the unsupported op type, not this template object
+ }
+ }
+
+ @Override
+ public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, boolean sparseRhs, boolean scalarVector,
+ boolean scalarInput) { // not applicable to ternary templates
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
+
+ @Override
+ public String getTemplate() { // not applicable to ternary templates
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
+
+ @Override
+ public String getTemplate(SpoofCellwise.CellType ct) { // not applicable to ternary templates
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
+
+ @Override
+ public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) { // not applicable to ternary templates
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
}
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Unary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Unary.java
index 7071e08..5f6a392 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Unary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Unary.java
@@ -27,126 +27,126 @@ import org.apache.sysds.hops.codegen.cplan.CodeTemplate;
import org.apache.sysds.runtime.codegen.SpoofCellwise;
public class Unary implements CodeTemplate {
- @Override
- public String getTemplate(UnaryType type, boolean sparse) {
- switch( type ) {
- case ROW_SUMS:
- case ROW_SUMSQS:
- case ROW_MINS:
- case ROW_MAXS:
- case ROW_MEANS:
- case ROW_COUNTNNZS: {
- String vectName = StringUtils.capitalize(type.name().substring(4, type.name().length()-1).toLowerCase());
- return sparse ? " double %TMP% = LibSpoofPrimitives.vect"+vectName+"(%IN1v%, %IN1i%, %POS1%, alen, len);\n":
- " double %TMP% = LibSpoofPrimitives.vect"+vectName+"(%IN1%, %POS1%, %LEN%);\n";
- }
+ @Override
+ public String getTemplate(UnaryType type, boolean sparse) { // Java statement template for one unary op; %TMP%/%IN1%/%POS1%/%LEN% are placeholders
+ switch( type ) {
+ case ROW_SUMS:
+ case ROW_SUMSQS:
+ case ROW_MINS:
+ case ROW_MAXS:
+ case ROW_MEANS:
+ case ROW_COUNTNNZS: {
+ String vectName = StringUtils.capitalize(type.name().substring(4, type.name().length()-1).toLowerCase()); // e.g. ROW_SUMS -> "Sum": drop "ROW_" and the trailing 'S', then capitalize
+ return sparse ? " double %TMP% = LibSpoofPrimitives.vect"+vectName+"(%IN1v%, %IN1i%, %POS1%, alen, len);\n":
+ " double %TMP% = LibSpoofPrimitives.vect"+vectName+"(%IN1%, %POS1%, %LEN%);\n";
+ }
- case VECT_EXP:
- case VECT_POW2:
- case VECT_MULT2:
- case VECT_SQRT:
- case VECT_LOG:
- case VECT_ABS:
- case VECT_ROUND:
- case VECT_CEIL:
- case VECT_FLOOR:
- case VECT_SIGN:
- case VECT_SIN:
- case VECT_COS:
- case VECT_TAN:
- case VECT_ASIN:
- case VECT_ACOS:
- case VECT_ATAN:
- case VECT_SINH:
- case VECT_COSH:
- case VECT_TANH:
- case VECT_CUMSUM:
- case VECT_CUMMIN:
- case VECT_CUMMAX:
- case VECT_SPROP:
- case VECT_SIGMOID: {
- String vectName = type.getVectorPrimitiveName();
- return sparse ? " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN1i%, %POS1%, alen, len);\n" :
- " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %POS1%, %LEN%);\n";
- }
+ case VECT_EXP:
+ case VECT_POW2:
+ case VECT_MULT2:
+ case VECT_SQRT:
+ case VECT_LOG:
+ case VECT_ABS:
+ case VECT_ROUND:
+ case VECT_CEIL:
+ case VECT_FLOOR:
+ case VECT_SIGN:
+ case VECT_SIN:
+ case VECT_COS:
+ case VECT_TAN:
+ case VECT_ASIN:
+ case VECT_ACOS:
+ case VECT_ATAN:
+ case VECT_SINH:
+ case VECT_COSH:
+ case VECT_TANH:
+ case VECT_CUMSUM:
+ case VECT_CUMMIN:
+ case VECT_CUMMAX:
+ case VECT_SPROP:
+ case VECT_SIGMOID: {
+ String vectName = type.getVectorPrimitiveName();
+ return sparse ? " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN1i%, %POS1%, alen, len);\n" : // sparse: value/index arrays plus alen and len
+ " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %POS1%, %LEN%);\n";
+ }
- case EXP:
- return " double %TMP% = FastMath.exp(%IN1%);\n";
- case LOOKUP_R:
- return sparse ?
- " double %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, 0);\n" :
- " double %TMP% = getValue(%IN1%, rix);\n";
- case LOOKUP_C:
- return " double %TMP% = getValue(%IN1%, n, 0, cix);\n";
- case LOOKUP_RC:
- return " double %TMP% = getValue(%IN1%, n, rix, cix);\n";
- case LOOKUP0:
- return " double %TMP% = %IN1%[0];\n";
- case POW2:
- return " double %TMP% = %IN1% * %IN1%;\n";
- case MULT2:
- return " double %TMP% = %IN1% + %IN1%;\n";
- case ABS:
- return " double %TMP% = Math.abs(%IN1%);\n";
- case SIN:
- return " double %TMP% = FastMath.sin(%IN1%);\n";
- case COS:
- return " double %TMP% = FastMath.cos(%IN1%);\n";
- case TAN:
- return " double %TMP% = FastMath.tan(%IN1%);\n";
- case ASIN:
- return " double %TMP% = FastMath.asin(%IN1%);\n";
- case ACOS:
- return " double %TMP% = FastMath.acos(%IN1%);\n";
- case ATAN:
- return " double %TMP% = Math.atan(%IN1%);\n";
- case SINH:
- return " double %TMP% = FastMath.sinh(%IN1%);\n";
- case COSH:
- return " double %TMP% = FastMath.cosh(%IN1%);\n";
- case TANH:
- return " double %TMP% = FastMath.tanh(%IN1%);\n";
- case SIGN:
- return " double %TMP% = FastMath.signum(%IN1%);\n";
- case SQRT:
- return " double %TMP% = Math.sqrt(%IN1%);\n";
- case LOG:
- return " double %TMP% = Math.log(%IN1%);\n";
- case ROUND:
- return " double %TMP% = Math.round(%IN1%);\n";
- case CEIL:
- return " double %TMP% = FastMath.ceil(%IN1%);\n";
- case FLOOR:
- return " double %TMP% = FastMath.floor(%IN1%);\n";
- case SPROP:
- return " double %TMP% = %IN1% * (1 - %IN1%);\n";
- case SIGMOID:
- return " double %TMP% = 1 / (1 + FastMath.exp(-%IN1%));\n";
- case LOG_NZ:
- return " double %TMP% = (%IN1%==0) ? 0 : Math.log(%IN1%);\n";
+ case EXP:
+ return " double %TMP% = FastMath.exp(%IN1%);\n";
+ case LOOKUP_R:
+ return sparse ?
+ " double %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, 0);\n" :
+ " double %TMP% = getValue(%IN1%, rix);\n";
+ case LOOKUP_C:
+ return " double %TMP% = getValue(%IN1%, n, 0, cix);\n";
+ case LOOKUP_RC:
+ return " double %TMP% = getValue(%IN1%, n, rix, cix);\n";
+ case LOOKUP0:
+ return " double %TMP% = %IN1%[0];\n";
+ case POW2:
+ return " double %TMP% = %IN1% * %IN1%;\n";
+ case MULT2:
+ return " double %TMP% = %IN1% + %IN1%;\n"; // 2*x emitted as x+x
+ case ABS:
+ return " double %TMP% = Math.abs(%IN1%);\n";
+ case SIN:
+ return " double %TMP% = FastMath.sin(%IN1%);\n";
+ case COS:
+ return " double %TMP% = FastMath.cos(%IN1%);\n";
+ case TAN:
+ return " double %TMP% = FastMath.tan(%IN1%);\n";
+ case ASIN:
+ return " double %TMP% = FastMath.asin(%IN1%);\n";
+ case ACOS:
+ return " double %TMP% = FastMath.acos(%IN1%);\n";
+ case ATAN:
+ return " double %TMP% = Math.atan(%IN1%);\n";
+ case SINH:
+ return " double %TMP% = FastMath.sinh(%IN1%);\n";
+ case COSH:
+ return " double %TMP% = FastMath.cosh(%IN1%);\n";
+ case TANH:
+ return " double %TMP% = FastMath.tanh(%IN1%);\n";
+ case SIGN:
+ return " double %TMP% = FastMath.signum(%IN1%);\n";
+ case SQRT:
+ return " double %TMP% = Math.sqrt(%IN1%);\n";
+ case LOG:
+ return " double %TMP% = Math.log(%IN1%);\n";
+ case ROUND:
+ return " double %TMP% = Math.round(%IN1%);\n"; // Math.round(double) yields a long, widened back to double
+ case CEIL:
+ return " double %TMP% = FastMath.ceil(%IN1%);\n";
+ case FLOOR:
+ return " double %TMP% = FastMath.floor(%IN1%);\n";
+ case SPROP:
+ return " double %TMP% = %IN1% * (1 - %IN1%);\n";
+ case SIGMOID:
+ return " double %TMP% = 1 / (1 + FastMath.exp(-%IN1%));\n";
+ case LOG_NZ:
+ return " double %TMP% = (%IN1%==0) ? 0 : Math.log(%IN1%);\n";
- default:
- throw new RuntimeException("Invalid unary type: "+this.toString());
- }
- }
+ default:
+ throw new RuntimeException("Invalid unary type: "+type); // report the unsupported op type, not this template object
+ }
+ }
- @Override
- public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, boolean sparseRhs, boolean scalarVector, boolean scalarInput) {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
+ @Override
+ public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, boolean sparseRhs, boolean scalarVector, boolean scalarInput) { // not applicable to unary templates
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
- @Override
- public String getTemplate() {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
+ @Override
+ public String getTemplate() { // not applicable to unary templates
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
- @Override
- public String getTemplate(SpoofCellwise.CellType ct) {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
+ @Override
+ public String getTemplate(SpoofCellwise.CellType ct) { // not applicable to unary templates
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
- @Override
- public String getTemplate(CNodeTernary.TernaryType type, boolean sparse) {
- throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
- }
+ @Override
+ public String getTemplate(CNodeTernary.TernaryType type, boolean sparse) { // not applicable to unary templates
+ throw new RuntimeException("Calling wrong getTemplate method on " + getClass().getCanonicalName());
+ }
}
diff --git a/src/main/java/org/apache/sysds/parser/LanguageException.java b/src/main/java/org/apache/sysds/parser/LanguageException.java
index 755836a..e25f631 100644
--- a/src/main/java/org/apache/sysds/parser/LanguageException.java
+++ b/src/main/java/org/apache/sysds/parser/LanguageException.java
@@ -29,30 +29,30 @@ public class LanguageException extends DMLException
private static final long serialVersionUID = 1L;
- public LanguageException() {
- super();
- }
-
- public LanguageException(String message) {
- super(message);
- }
-
- public LanguageException(Throwable cause) {
- super(cause);
- }
-
- public LanguageException(String message, Throwable cause) {
- super(message, cause);
- }
-
- public LanguageException(String message, String code) {
- super(code + ERROR_MSG_DELIMITER + message);
- }
-
- public static class LanguageErrorCodes {
- public static final String UNSUPPORTED_EXPRESSION = "Unsupported Expression";
- public static final String INVALID_PARAMETERS = "Invalid Parameters";
- public static final String UNSUPPORTED_PARAMETERS = "Unsupported Parameters";
- }
+ public LanguageException() { // all constructors delegate directly to DMLException
+ super();
+ }
+
+ public LanguageException(String message) {
+ super(message);
+ }
+
+ public LanguageException(Throwable cause) {
+ super(cause);
+ }
+
+ public LanguageException(String message, Throwable cause) {
+ super(message, cause);
+ }
+
+ public LanguageException(String message, String code) { // final message is: code + ERROR_MSG_DELIMITER + message
+ super(code + ERROR_MSG_DELIMITER + message);
+ }
+
+ public static class LanguageErrorCodes { // error-code strings, presumably passed as the 'code' argument of LanguageException(String, String)
+ public static final String UNSUPPORTED_EXPRESSION = "Unsupported Expression";
+ public static final String INVALID_PARAMETERS = "Invalid Parameters";
+ public static final String UNSUPPORTED_PARAMETERS = "Unsupported Parameters";
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDA.java b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDA.java
index ac783c6..127ef36 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDA.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDA.java
@@ -35,87 +35,88 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import static org.apache.sysds.runtime.matrix.data.LibMatrixNative.isSinglePrecision;
public class SpoofCUDA extends SpoofOperator {
-
- private final CNodeTpl cnt;
- public final String name;
-
- public SpoofCUDA(CNodeTpl cnode) {
- name = "codegen." + cnode.getVarname();
- cnt = cnode;
- }
-
- public String getName() {
- return name;
- }
-
- public CNodeTpl getCNodeTemplate() {
- return cnt;
- }
-
- public String getSpoofTemplateType() {
- if (cnt instanceof CNodeCell)
- return "CW";
- else if(cnt instanceof CNodeRow)
- return "RA";
- else if(cnt instanceof CNodeMultiAgg)
- return "MA";
- else if(cnt instanceof CNodeOuterProduct)
- return "OP";
- else
- throw new RuntimeException("unknown spoof operator type");
- }
- @Override
- public MatrixBlock execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out) {
- throw new RuntimeException("method not implemented for SpoofNativeCUDA");
- }
-
- public double execute(ArrayList<MatrixObject> inputs, ArrayList<ScalarObject> scalarObjects, MatrixObject out_obj,
- ExecutionContext ec) {
- double ret = 0;
- long out_ptr = 0;
-
- if(out_obj != null)
- out_ptr = ec.getGPUPointerAddress(out_obj);
-
- int offset = 1;
- if(cnt instanceof CNodeOuterProduct)
- offset = 2;
-
- // only dense input preparation for now
- long[] in_ptrs = new long[offset];
- for(int i = 0; i < offset; ++i)
- in_ptrs[i] = ec.getGPUPointerAddress(inputs.get(i));
-
- long[] side_ptrs = new long[inputs.size() - offset];
- for(int i = offset; i < inputs.size(); ++i)
- side_ptrs[i - offset] = ec.getGPUPointerAddress(inputs.get(i));
-
- if(isSinglePrecision()) {
- float[] scalars = prepInputScalarsFloat(scalarObjects);
-
- // ToDo: handle float
- ret = execute_f(SpoofCompiler.native_contexts.get(SpoofCompiler.GeneratorAPI.CUDA), name.split("\\.")[1],
- in_ptrs, side_ptrs, out_ptr, scalars, inputs.get(0).getNumRows(), inputs.get(0).getNumColumns(), 0);
-
- }
- else {
- double[] scalars = prepInputScalars(scalarObjects);
-
- ret = execute_d(SpoofCompiler.native_contexts.get(SpoofCompiler.GeneratorAPI.CUDA), name.split("\\.")[1],
- in_ptrs, side_ptrs, out_ptr, scalars, inputs.get(0).getNumRows(), inputs.get(0).getNumColumns(), 0);
- }
- return ret;
- }
-
- @Override
- public String getSpoofType() {
- String tmp[] = getClass().getName().split("\\.");
- return tmp[tmp.length-1] + "_" + getSpoofTemplateType() + "_" + name.split("\\.")[1];
- }
-
- private native float execute_f(long ctx, String name, long[] in_ptr, long[] side_ptr,
- long out_ptr, float[] scalars, long m, long n, long grix);
-
- private native double execute_d(long ctx, String name, long[] in_ptr, long[] side_ptr,
- long out_ptr, double[] scalars, long m, long n, long grix);
+ private static final long serialVersionUID = -2161276866245388359L; // NOTE(review): presumably added to silence the missing-serialVersionUID warning (SpoofOperator assumed Serializable)
+
+ private final CNodeTpl cnt;
+ public final String name; // "codegen." + the template's variable name
+
+ public SpoofCUDA(CNodeTpl cnode) {
+ name = "codegen." + cnode.getVarname();
+ cnt = cnode;
+ }
+
+ public String getName() {
+ return name;
+ }
+
+ public CNodeTpl getCNodeTemplate() {
+ return cnt;
+ }
+
+ public String getSpoofTemplateType() { // short code per template node type: CW=cell, RA=row, MA=multi-agg, OP=outer-product
+ if (cnt instanceof CNodeCell)
+ return "CW";
+ else if(cnt instanceof CNodeRow)
+ return "RA";
+ else if(cnt instanceof CNodeMultiAgg)
+ return "MA";
+ else if(cnt instanceof CNodeOuterProduct)
+ return "OP";
+ else
+ throw new RuntimeException("unknown spoof operator type");
+ }
+ @Override
+ public MatrixBlock execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out) { // MatrixBlock-based execution is not supported here; always throws
+ throw new RuntimeException("method not implemented for SpoofNativeCUDA");
+ }
+
+ public double execute(ArrayList<MatrixObject> inputs, ArrayList<ScalarObject> scalarObjects, MatrixObject out_obj, // passes GPU pointer addresses to native execute_f/execute_d and returns their scalar result
+ ExecutionContext ec) {
+ double ret = 0;
+ long out_ptr = 0; // stays 0 when no output object is given
+
+ if(out_obj != null)
+ out_ptr = ec.getGPUPointerAddress(out_obj);
+
+ int offset = 1; // number of leading main inputs; outer-product templates take two
+ if(cnt instanceof CNodeOuterProduct)
+ offset = 2;
+
+ // only dense input preparation for now
+ long[] in_ptrs = new long[offset];
+ for(int i = 0; i < offset; ++i)
+ in_ptrs[i] = ec.getGPUPointerAddress(inputs.get(i));
+
+ long[] side_ptrs = new long[inputs.size() - offset]; // remaining inputs are passed as side inputs
+ for(int i = offset; i < inputs.size(); ++i)
+ side_ptrs[i - offset] = ec.getGPUPointerAddress(inputs.get(i));
+
+ if(isSinglePrecision()) {
+ float[] scalars = prepInputScalarsFloat(scalarObjects);
+
+ // ToDo: handle float
+ ret = execute_f(SpoofCompiler.native_contexts.get(SpoofCompiler.GeneratorAPI.CUDA), name.split("\\.")[1], // strips the "codegen." prefix (assumes the var name contains no '.')
+ in_ptrs, side_ptrs, out_ptr, scalars, inputs.get(0).getNumRows(), inputs.get(0).getNumColumns(), 0); // m/n from the first input, grix = 0
+
+ }
+ else {
+ double[] scalars = prepInputScalars(scalarObjects);
+
+ ret = execute_d(SpoofCompiler.native_contexts.get(SpoofCompiler.GeneratorAPI.CUDA), name.split("\\.")[1],
+ in_ptrs, side_ptrs, out_ptr, scalars, inputs.get(0).getNumRows(), inputs.get(0).getNumColumns(), 0);
+ }
+ return ret;
+ }
+
+ @Override
+ public String getSpoofType() { // "<simple class name>_<template code>_<operator name>"
+ String tmp[] = getClass().getName().split("\\.");
+ return tmp[tmp.length-1] + "_" + getSpoofTemplateType() + "_" + name.split("\\.")[1];
+ }
+
+ private native float execute_f(long ctx, String name, long[] in_ptr, long[] side_ptr, // native entry point, single-precision variant
+ long out_ptr, float[] scalars, long m, long n, long grix);
+
+ private native double execute_d(long ctx, String name, long[] in_ptr, long[] side_ptr, // native entry point, double-precision variant
+ long out_ptr, double[] scalars, long m, long n, long grix);
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java
index 2d28fb2..c4f148a 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java
@@ -36,229 +36,229 @@ import edu.emory.mathcs.backport.java.util.Arrays;
public class ColGroupConst extends ColGroupValue {
- private static final long serialVersionUID = 3204391661346504L;
-
- /**
- * Constructor for serialization
- */
- protected ColGroupConst() {
- super();
- }
-
- /**
- * Constructs an Constant Colum Group, that contains only one tuple, with the given value.
- *
- * @param colIndices The Colum indexes for the column group.
- * @param numRows The number of rows contained in the group.
- * @param dict The dictionary containing one tuple for the entire compression.
- */
- public ColGroupConst(int[] colIndices, int numRows, ADictionary dict) {
- super(colIndices, numRows, dict);
- }
-
- @Override
- public int[] getCounts(int[] out) {
- out[0] = _numRows;
- return out;
- }
-
- @Override
- public int[] getCounts(int rl, int ru, int[] out) {
- out[0] = ru - rl;
- return out;
- }
-
- @Override
- protected void computeSum(double[] c, KahanFunction kplus) {
- c[0] += _dict.sum(getCounts(), _colIndexes.length, kplus);
- }
-
- @Override
- protected void computeRowSums(double[] c, KahanFunction kplus, int rl, int ru, boolean mean) {
- KahanObject kbuff = new KahanObject(0, 0);
- KahanPlus kplus2 = KahanPlus.getKahanPlusFnObject();
- double[] vals = _dict.sumAllRowsToDouble(kplus, kbuff, _colIndexes.length);
- for(int rix = rl; rix < ru; rix++) {
- setandExecute(c, kbuff, kplus2, vals[0], rix * (2 + (mean ? 1 : 0)));
- }
- }
-
- @Override
- protected void computeColSums(double[] c, KahanFunction kplus) {
- _dict.colSum(c, getCounts(), _colIndexes, kplus);
- }
-
- @Override
- protected void computeRowMxx(double[] c, Builtin builtin, int rl, int ru) {
- throw new DMLCompressionException(
- "Row max not supported for Const since Const is used for overlapping ColGroups, You have to materialize rows and then calculate row max");
- }
-
- @Override
- public CompressionType getCompType() {
- return CompressionType.CONST;
- }
-
- @Override
- protected ColGroupType getColGroupType() {
- return ColGroupType.CONST;
- }
-
- @Override
- public long estimateInMemorySize() {
- return ColGroupSizes.estimateInMemorySizeCONST(getNumCols(), getNumValues(), isLossy());
- }
-
- @Override
- public void decompressToBlock(MatrixBlock target, int rl, int ru) {
- final int ncol = getNumCols();
- final double[] values = getValues();
-
- for(int i = rl; i < ru; i++)
- for(int j = 0; j < ncol; j++) {
- double v = target.quickGetValue(i, _colIndexes[j]);
- target.setValue(i, _colIndexes[j], values[j] + v);
- }
- }
-
- @Override
- public void decompressToBlock(MatrixBlock target, int[] colIndexTargets) {
- int ncol = getNumCols();
- double[] values = getValues();
- for(int i = 0; i < _numRows; i++) {
- for(int colIx = 0; colIx < ncol; colIx++) {
- int origMatrixColIx = getColIndex(colIx);
- int col = colIndexTargets[origMatrixColIx];
- double cellVal = values[colIx];
- target.quickSetValue(i, col, target.quickGetValue(i, col) + cellVal);
- }
- }
- }
-
- @Override
- public void decompressToBlock(MatrixBlock target, int colpos) {
- double[] c = target.getDenseBlockValues();
-
- int nnz = 0;
- double v = _dict.getValue(Arrays.binarySearch(_colIndexes, colpos));
- if(v != 0) {
- for(int i = 0; i < c.length; i++)
- c[i] += v;
- nnz = _numRows;
- }
- target.setNonZeros(nnz);
- }
-
- @Override
- public double get(int r, int c) {
- return _dict.getValue(Arrays.binarySearch(_colIndexes, c));
- }
-
- @Override
- public void rightMultByVector(double[] b, double[] c, int rl, int ru, double[] dictVals) {
- double[] vals = preaggValues(1, b, dictVals);
- for(int i = 0; i < c.length; i++) {
- c[i] += vals[0];
- }
- }
-
- @Override
- public void rightMultByMatrix(double[] preAggregatedB, double[] c, int thatNrColumns, int rl, int ru, int cl,
- int cu) {
-
- for(int i = rl * thatNrColumns; i < ru * thatNrColumns; i += thatNrColumns)
- for(int j = i + cl; j < i + cu; j++)
- c[j] += preAggregatedB[j % thatNrColumns];
-
- }
-
- @Override
- public void rightMultBySparseMatrix(SparseRow[] rows, double[] c, int numVals, double[] dictVals, int nrColumns,
- int rl, int ru) {
- throw new DMLCompressionException(
- "Depreciated and not supported right mult by sparse matrix Please preAggregate before calling");
- }
-
- private double preAggregate(double[] a, int aRows) {
- double vals = 0;
- for(int i = 0, off = _numRows * aRows; i < _numRows; i++, off++) {
- vals += a[off];
- }
- return vals;
- }
-
- @Override
- public void leftMultByRowVector(double[] a, double[] c, int numVals) {
- double preAggVals = preAggregate(a, 0);
- double[] dictVals = getValues();
- for(int i = 0; i < _colIndexes.length; i++) {
- c[i] += preAggVals * dictVals[i];
- }
- }
-
- @Override
- public void leftMultByRowVector(double[] a, double[] c, int numVals, double[] values) {
- double preAggVals = preAggregate(a, 0);
- for(int i = 0; i < _colIndexes.length; i++) {
- c[i] += preAggVals * values[i];
- }
- }
-
- @Override
- public void leftMultByMatrix(double[] a, double[] c, double[] values, int numRows, int numCols, int rl, int ru,
- int vOff) {
- for(int i = rl; i < ru; i++) {
- double preAggVals = preAggregate(a, i);
- int offC = i * numCols;
- for(int j = 0; j < _colIndexes.length; j++) {
- c[offC + j] += preAggVals * values[j];
- }
- }
- }
-
- @Override
- public void leftMultBySparseMatrix(int spNrVals, int[] indexes, double[] sparseV, double[] c, int numVals,
- double[] values, int numRows, int numCols, int row, double[] MaterializedRow) {
- double v = 0;
- for(int i = 0; i < spNrVals; i++) {
- v += sparseV[i];
- }
- int offC = row * numCols;
- for(int j = 0; j < _colIndexes.length; j++) {
- c[offC + j] += v * values[j];
- }
- }
-
- @Override
- public ColGroup scalarOperation(ScalarOperator op) {
- return new ColGroupConst(_colIndexes, _numRows, applyScalarOp(op));
- }
-
- @Override
- public ColGroup binaryRowOp(BinaryOperator op, double[] v, boolean sparseSafe) {
- return new ColGroupConst(_colIndexes, _numRows, applyBinaryRowOp(op.fn, v, true));
- }
-
- @Override
- public Iterator<IJV> getIterator(int rl, int ru, boolean inclZeros, boolean rowMajor) {
- throw new DMLCompressionException("Unsupported Iterator of Const ColGroup");
- }
-
- @Override
- public ColGroupRowIterator getRowIterator(int rl, int ru) {
- throw new DMLCompressionException("Unsupported Row iterator of Const ColGroup");
- }
-
- @Override
- public void countNonZerosPerRow(int[] rnnz, int rl, int ru) {
-
- double[] values = _dict.getValues();
- int base = 0;
- for(int i = 0; i < values.length; i++) {
- base += values[i] == 0 ? 0 : 1;
- }
- for(int i = 0; i < ru - rl; i++) {
- rnnz[i] = base;
- }
- }
+ private static final long serialVersionUID = 3204391661346504L;
+
+ /**
+ * Constructor for serialization
+ */
+ protected ColGroupConst() {
+ super();
+ }
+
+ /**
+ * Constructs an Constant Colum Group, that contains only one tuple, with the given value.
+ *
+ * @param colIndices The Colum indexes for the column group.
+ * @param numRows The number of rows contained in the group.
+ * @param dict The dictionary containing one tuple for the entire compression.
+ */
+ public ColGroupConst(int[] colIndices, int numRows, ADictionary dict) {
+ super(colIndices, numRows, dict);
+ }
+
+ @Override
+ public int[] getCounts(int[] out) {
+ out[0] = _numRows;
+ return out;
+ }
+
+ @Override
+ public int[] getCounts(int rl, int ru, int[] out) {
+ out[0] = ru - rl;
+ return out;
+ }
+
+ @Override
+ protected void computeSum(double[] c, KahanFunction kplus) {
+ c[0] += _dict.sum(getCounts(), _colIndexes.length, kplus);
+ }
+
+ @Override
+ protected void computeRowSums(double[] c, KahanFunction kplus, int rl, int ru, boolean mean) {
+ KahanObject kbuff = new KahanObject(0, 0);
+ KahanPlus kplus2 = KahanPlus.getKahanPlusFnObject();
+ double[] vals = _dict.sumAllRowsToDouble(kplus, kbuff, _colIndexes.length);
+ for(int rix = rl; rix < ru; rix++) {
+ setandExecute(c, kbuff, kplus2, vals[0], rix * (2 + (mean ? 1 : 0)));
+ }
+ }
+
+ @Override
+ protected void computeColSums(double[] c, KahanFunction kplus) {
+ _dict.colSum(c, getCounts(), _colIndexes, kplus);
+ }
+
+ @Override
+ protected void computeRowMxx(double[] c, Builtin builtin, int rl, int ru) {
+ throw new DMLCompressionException(
+ "Row max not supported for Const since Const is used for overlapping ColGroups, You have to materialize rows and then calculate row max");
+ }
+
+ @Override
+ public CompressionType getCompType() {
+ return CompressionType.CONST;
+ }
+
+ @Override
+ protected ColGroupType getColGroupType() {
+ return ColGroupType.CONST;
+ }
+
+ @Override
+ public long estimateInMemorySize() {
+ return ColGroupSizes.estimateInMemorySizeCONST(getNumCols(), getNumValues(), isLossy());
+ }
+
+ @Override
+ public void decompressToBlock(MatrixBlock target, int rl, int ru) {
+ final int ncol = getNumCols();
+ final double[] values = getValues();
+
+ for(int i = rl; i < ru; i++)
+ for(int j = 0; j < ncol; j++) {
+ double v = target.quickGetValue(i, _colIndexes[j]);
+ target.setValue(i, _colIndexes[j], values[j] + v);
+ }
+ }
+
+ @Override
+ public void decompressToBlock(MatrixBlock target, int[] colIndexTargets) {
+ int ncol = getNumCols();
+ double[] values = getValues();
+ for(int i = 0; i < _numRows; i++) {
+ for(int colIx = 0; colIx < ncol; colIx++) {
+ int origMatrixColIx = getColIndex(colIx);
+ int col = colIndexTargets[origMatrixColIx];
+ double cellVal = values[colIx];
+ target.quickSetValue(i, col, target.quickGetValue(i, col) + cellVal);
+ }
+ }
+ }
+
+ @Override
+ public void decompressToBlock(MatrixBlock target, int colpos) {
+ double[] c = target.getDenseBlockValues();
+
+ int nnz = 0;
+ double v = _dict.getValue(Arrays.binarySearch(_colIndexes, colpos));
+ if(v != 0) {
+ for(int i = 0; i < c.length; i++)
+ c[i] += v;
+ nnz = _numRows;
+ }
+ target.setNonZeros(nnz);
+ }
+
+ @Override
+ public double get(int r, int c) {
+ return _dict.getValue(Arrays.binarySearch(_colIndexes, c));
+ }
+
+ @Override
+ public void rightMultByVector(double[] b, double[] c, int rl, int ru, double[] dictVals) {
+ double[] vals = preaggValues(1, b, dictVals);
+ for(int i = 0; i < c.length; i++) {
+ c[i] += vals[0];
+ }
+ }
+
+ @Override
+ public void rightMultByMatrix(double[] preAggregatedB, double[] c, int thatNrColumns, int rl, int ru, int cl,
+ int cu) {
+
+ for(int i = rl * thatNrColumns; i < ru * thatNrColumns; i += thatNrColumns)
+ for(int j = i + cl; j < i + cu; j++)
+ c[j] += preAggregatedB[j % thatNrColumns];
+
+ }
+
+ @Override
+ public void rightMultBySparseMatrix(SparseRow[] rows, double[] c, int numVals, double[] dictVals, int nrColumns,
+ int rl, int ru) {
+ throw new DMLCompressionException(
+ "Depreciated and not supported right mult by sparse matrix Please preAggregate before calling");
+ }
+
+ private double preAggregate(double[] a, int aRows) {
+ double vals = 0;
+ for(int i = 0, off = _numRows * aRows; i < _numRows; i++, off++) {
+ vals += a[off];
+ }
+ return vals;
+ }
+
+ @Override
+ public void leftMultByRowVector(double[] a, double[] c, int numVals) {
+ double preAggVals = preAggregate(a, 0);
+ double[] dictVals = getValues();
+ for(int i = 0; i < _colIndexes.length; i++) {
+ c[i] += preAggVals * dictVals[i];
+ }
+ }
+
+ @Override
+ public void leftMultByRowVector(double[] a, double[] c, int numVals, double[] values) {
+ double preAggVals = preAggregate(a, 0);
+ for(int i = 0; i < _colIndexes.length; i++) {
+ c[i] += preAggVals * values[i];
+ }
+ }
+
+ @Override
+ public void leftMultByMatrix(double[] a, double[] c, double[] values, int numRows, int numCols, int rl, int ru,
+ int vOff) {
+ for(int i = rl; i < ru; i++) {
+ double preAggVals = preAggregate(a, i);
+ int offC = i * numCols;
+ for(int j = 0; j < _colIndexes.length; j++) {
+ c[offC + j] += preAggVals * values[j];
+ }
+ }
+ }
+
+ @Override
+ public void leftMultBySparseMatrix(int spNrVals, int[] indexes, double[] sparseV, double[] c, int numVals,
+ double[] values, int numRows, int numCols, int row, double[] MaterializedRow) {
+ double v = 0;
+ for(int i = 0; i < spNrVals; i++) {
+ v += sparseV[i];
+ }
+ int offC = row * numCols;
+ for(int j = 0; j < _colIndexes.length; j++) {
+ c[offC + j] += v * values[j];
+ }
+ }
+
+ @Override
+ public ColGroup scalarOperation(ScalarOperator op) {
+ return new ColGroupConst(_colIndexes, _numRows, applyScalarOp(op));
+ }
+
+ @Override
+ public ColGroup binaryRowOp(BinaryOperator op, double[] v, boolean sparseSafe) {
+ return new ColGroupConst(_colIndexes, _numRows, applyBinaryRowOp(op.fn, v, true));
+ }
+
+ @Override
+ public Iterator<IJV> getIterator(int rl, int ru, boolean inclZeros, boolean rowMajor) {
+ throw new DMLCompressionException("Unsupported Iterator of Const ColGroup");
+ }
+
+ @Override
+ public ColGroupRowIterator getRowIterator(int rl, int ru) {
+ throw new DMLCompressionException("Unsupported Row iterator of Const ColGroup");
+ }
+
+ @Override
+ public void countNonZerosPerRow(int[] rnnz, int rl, int ru) {
+
+ double[] values = _dict.getValues();
+ int base = 0;
+ for(int i = 0; i < values.length; i++) {
+ base += values[i] == 0 ? 0 : 1;
+ }
+ for(int i = 0; i < ru - rl; i++) {
+ rnnz[i] = base;
+ }
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
index 793ba59..91efc2c 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
@@ -30,45 +30,45 @@ import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageTraceable;
public class VariableFEDInstruction extends FEDInstruction implements LineageTraceable {
- private static final Log LOG = LogFactory.getLog(VariableFEDInstruction.class.getName());
+ private static final Log LOG = LogFactory.getLog(VariableFEDInstruction.class.getName());
- private final VariableCPInstruction _in;
+ private final VariableCPInstruction _in;
- protected VariableFEDInstruction(VariableCPInstruction in) {
- super(null, in.getOperator(), in.getOpcode(), in.getInstructionString());
- _in = in;
- }
+ protected VariableFEDInstruction(VariableCPInstruction in) {
+ super(null, in.getOperator(), in.getOpcode(), in.getInstructionString());
+ _in = in;
+ }
- public static VariableFEDInstruction parseInstruction(VariableCPInstruction cpInstruction) {
- return new VariableFEDInstruction(cpInstruction);
- }
+ public static VariableFEDInstruction parseInstruction(VariableCPInstruction cpInstruction) {
+ return new VariableFEDInstruction(cpInstruction);
+ }
- @Override
- public void processInstruction(ExecutionContext ec) {
- VariableOperationCode opcode = _in.getVariableOpcode();
- switch(opcode) {
+ @Override
+ public void processInstruction(ExecutionContext ec) {
+ VariableOperationCode opcode = _in.getVariableOpcode();
+ switch(opcode) {
- case Write:
- processWriteInstruction(ec);
- break;
+ case Write:
+ processWriteInstruction(ec);
+ break;
- default:
- throw new DMLRuntimeException("Unsupported Opcode for federated Variable Instruction : " + opcode);
- }
- }
+ default:
+ throw new DMLRuntimeException("Unsupported Opcode for federated Variable Instruction : " + opcode);
+ }
+ }
- private void processWriteInstruction(ExecutionContext ec) {
- LOG.warn("Processing write command federated");
- // TODO Add write command to the federated site if the matrix has been modified
- // this has to be done while appending some string to the federated output file.
- // furthermore the outputted file on the federated sites path should be returned
- // the controller.
- _in.processInstruction(ec);
- }
+ private void processWriteInstruction(ExecutionContext ec) {
+ LOG.warn("Processing write command federated");
+ // TODO Add write command to the federated site if the matrix has been modified
+ // this has to be done while appending some string to the federated output file.
+ // furthermore the outputted file on the federated sites path should be returned
+ // the controller.
+ _in.processInstruction(ec);
+ }
- @Override
- public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
- return _in.getLineageItem(ec);
- }
+ @Override
+ public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+ return _in.getLineageItem(ec);
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/gpu/MMTSJGPUInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/gpu/MMTSJGPUInstruction.java
index 67e6648..0a435d7 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/gpu/MMTSJGPUInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/gpu/MMTSJGPUInstruction.java
@@ -38,17 +38,17 @@ public class MMTSJGPUInstruction extends GPUInstruction {
* MMTSJGPUInstruction constructor.
*
* @param op
- * operator
+ * operator
* @param in1
- * input
+ * input
* @param type
- * left/right, left-> A' %*% A, right-> A %*% A'
+ * left/right, left-> A' %*% A, right-> A %*% A'
* @param out
- * output
+ * output
* @param opcode
- * the opcode
+ * the opcode
* @param istr
- * ?
+ * ?
*/
private MMTSJGPUInstruction(Operator op, CPOperand in1, MMTSJType type, CPOperand out, String opcode, String istr) {
super(op, opcode, istr);
@@ -58,37 +58,37 @@ public class MMTSJGPUInstruction extends GPUInstruction {
_output = out;
}
- public static MMTSJGPUInstruction parseInstruction ( String str )
- {
- String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
- InstructionUtils.checkNumFields ( parts, 3 );
+ public static MMTSJGPUInstruction parseInstruction ( String str )
+ {
+ String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+ InstructionUtils.checkNumFields ( parts, 3 );
- String opcode = parts[0];
- CPOperand in1 = new CPOperand(parts[1]);
- CPOperand out = new CPOperand(parts[2]);
- MMTSJType titype = MMTSJType.valueOf(parts[3]);
+ String opcode = parts[0];
+ CPOperand in1 = new CPOperand(parts[1]);
+ CPOperand out = new CPOperand(parts[2]);
+ MMTSJType titype = MMTSJType.valueOf(parts[3]);
- if(!opcode.equalsIgnoreCase("tsmm"))
- throw new DMLRuntimeException("Unknown opcode while parsing an MMTSJGPUInstruction: " + str);
- else
- return new MMTSJGPUInstruction(new Operator(true), in1, titype, out, opcode, str);
- }
+ if(!opcode.equalsIgnoreCase("tsmm"))
+ throw new DMLRuntimeException("Unknown opcode while parsing an MMTSJGPUInstruction: " + str);
+ else
+ return new MMTSJGPUInstruction(new Operator(true), in1, titype, out, opcode, str);
+ }
- @Override
- public void processInstruction(ExecutionContext ec) {
- GPUStatistics.incrementNoOfExecutedGPUInst();
- MatrixObject mat = getMatrixInputForGPUInstruction(ec, _input.getName());
- boolean isLeftTransposed = ( _type == MMTSJType.LEFT);
- int rlen = (int) (isLeftTransposed? mat.getNumColumns() : mat.getNumRows());
- int clen = rlen;
- //execute operations
- ec.setMetaData(_output.getName(), rlen, clen);
- LibMatrixCUDA.matmultTSMM(ec, ec.getGPUContext(0), getExtendedOpcode(), mat, _output.getName(), isLeftTransposed);
- ec.releaseMatrixInputForGPUInstruction(_input.getName());
- ec.releaseMatrixOutputForGPUInstruction(_output.getName());
- }
+ @Override
+ public void processInstruction(ExecutionContext ec) {
+ GPUStatistics.incrementNoOfExecutedGPUInst();
+ MatrixObject mat = getMatrixInputForGPUInstruction(ec, _input.getName());
+ boolean isLeftTransposed = ( _type == MMTSJType.LEFT);
+ int rlen = (int) (isLeftTransposed? mat.getNumColumns() : mat.getNumRows());
+ int clen = rlen;
+ //execute operations
+ ec.setMetaData(_output.getName(), rlen, clen);
+ LibMatrixCUDA.matmultTSMM(ec, ec.getGPUContext(0), getExtendedOpcode(), mat, _output.getName(), isLeftTransposed);
+ ec.releaseMatrixInputForGPUInstruction(_input.getName());
+ ec.releaseMatrixOutputForGPUInstruction(_output.getName());
+ }
- public MMTSJType getMMTSJType() {
- return _type;
- }
+ public MMTSJType getMMTSJType() {
+ return _type;
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/gpu/SpoofCUDAInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/gpu/SpoofCUDAInstruction.java
index 8049e87..20f4333 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/gpu/SpoofCUDAInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/gpu/SpoofCUDAInstruction.java
@@ -39,81 +39,81 @@ import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import java.util.ArrayList;
public class SpoofCUDAInstruction extends GPUInstruction implements LineageTraceable {
- private final SpoofCUDA _op;
- private final CPOperand[] _in;
-
- public final CPOperand _out;
-
- private SpoofCUDAInstruction(SpoofOperator op, CPOperand[] in, CPOperand out, String opcode, String istr) {
- super(null, opcode, istr);
-
- if(!(op instanceof SpoofCUDA))
- throw new RuntimeException("SpoofGPUInstruction needs an operator of type SpoofNativeCUDA!");
-
- _op = (SpoofCUDA) op;
- _in = in;
- _out = out;
- instString = istr;
- instOpcode = opcode;
- }
-
- public static SpoofCUDAInstruction parseInstruction(String str) {
- String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
-
- ArrayList<CPOperand> inlist = new ArrayList<>();
- SpoofCUDA op = CodegenUtils.getNativeOpData(parts[2]);
- String opcode = op.getSpoofType();
-
- for( int i=3; i<parts.length-2; i++ )
- inlist.add(new CPOperand(parts[i]));
- CPOperand out = new CPOperand(parts[parts.length-2]);
-
- return new SpoofCUDAInstruction(op, inlist.toArray(new CPOperand[0]), out, opcode, str);
- }
-
- @Override
- public void processInstruction(ExecutionContext ec) {
-
- //get input matrices and scalars, incl pinning of matrices
- ArrayList<MatrixObject> inputs = new ArrayList<>();
- ArrayList<ScalarObject> scalars = new ArrayList<>();
- for (CPOperand input : _in) {
- if(input.getDataType()== Types.DataType.MATRIX)
- inputs.add(ec.getMatrixInputForGPUInstruction(input.getName(), getExtendedOpcode()));
- else if(input.getDataType()== Types.DataType.SCALAR) {
- //note: even if literal, it might be compiled as scalar placeholder
- scalars.add(ec.getScalarInput(input));
- }
- }
-
- // set the output dimensions to the hop node matrix dimensions
- if( _out.getDataType() == Types.DataType.MATRIX) {
- long rows = inputs.get(0).getNumRows();
- long cols = inputs.get(0).getNumColumns();
- if(_op.getSpoofTemplateType().contains("CW"))
- if(((CNodeCell)_op.getCNodeTemplate()).getCellType() == SpoofCellwise.CellType.COL_AGG)
- rows = 1;
- else if(((CNodeCell)_op.getCNodeTemplate()).getCellType() == SpoofCellwise.CellType.ROW_AGG)
- cols = 1;
-
- MatrixObject out_obj = ec.getDenseMatrixOutputForGPUInstruction(_out.getName(), rows, cols).getKey();
- ec.setMetaData(_out.getName(), out_obj.getNumRows(), out_obj.getNumColumns());
- _op.execute(inputs, scalars, out_obj, ec);
- ec.releaseMatrixOutputForGPUInstruction(_out.getName());
- }
- else if (_out.getDataType() == Types.DataType.SCALAR) {
- ScalarObject out = new DoubleObject(_op.execute(inputs, scalars, null, ec));
- ec.setScalarOutput(_out.getName(), out);
- }
-
- for (CPOperand input : _in)
- if(input.getDataType()== Types.DataType.MATRIX)
- ec.releaseMatrixInputForGPUInstruction(input.getName());
- }
-
- @Override
- public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
- return Pair.of(_out.getName(),
- new LineageItem(getOpcode(), LineageItemUtils.getLineage(ec, _in)));
- }
+ private final SpoofCUDA _op;
+ private final CPOperand[] _in;
+
+ public final CPOperand _out;
+
+ private SpoofCUDAInstruction(SpoofOperator op, CPOperand[] in, CPOperand out, String opcode, String istr) {
+ super(null, opcode, istr);
+
+ if(!(op instanceof SpoofCUDA))
+ throw new RuntimeException("SpoofGPUInstruction needs an operator of type SpoofNativeCUDA!");
+
+ _op = (SpoofCUDA) op;
+ _in = in;
+ _out = out;
+ instString = istr;
+ instOpcode = opcode;
+ }
+
+ public static SpoofCUDAInstruction parseInstruction(String str) {
+ String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+
+ ArrayList<CPOperand> inlist = new ArrayList<>();
+ SpoofCUDA op = CodegenUtils.getNativeOpData(parts[2]);
+ String opcode = op.getSpoofType();
+
+ for( int i=3; i<parts.length-2; i++ )
+ inlist.add(new CPOperand(parts[i]));
+ CPOperand out = new CPOperand(parts[parts.length-2]);
+
+ return new SpoofCUDAInstruction(op, inlist.toArray(new CPOperand[0]), out, opcode, str);
+ }
+
+ @Override
+ public void processInstruction(ExecutionContext ec) {
+
+ //get input matrices and scalars, incl pinning of matrices
+ ArrayList<MatrixObject> inputs = new ArrayList<>();
+ ArrayList<ScalarObject> scalars = new ArrayList<>();
+ for (CPOperand input : _in) {
+ if(input.getDataType()== Types.DataType.MATRIX)
+ inputs.add(ec.getMatrixInputForGPUInstruction(input.getName(), getExtendedOpcode()));
+ else if(input.getDataType()== Types.DataType.SCALAR) {
+ //note: even if literal, it might be compiled as scalar placeholder
+ scalars.add(ec.getScalarInput(input));
+ }
+ }
+
+ // set the output dimensions to the hop node matrix dimensions
+ if( _out.getDataType() == Types.DataType.MATRIX) {
+ long rows = inputs.get(0).getNumRows();
+ long cols = inputs.get(0).getNumColumns();
+ if(_op.getSpoofTemplateType().contains("CW"))
+ if(((CNodeCell)_op.getCNodeTemplate()).getCellType() == SpoofCellwise.CellType.COL_AGG)
+ rows = 1;
+ else if(((CNodeCell)_op.getCNodeTemplate()).getCellType() == SpoofCellwise.CellType.ROW_AGG)
+ cols = 1;
+
+ MatrixObject out_obj = ec.getDenseMatrixOutputForGPUInstruction(_out.getName(), rows, cols).getKey();
+ ec.setMetaData(_out.getName(), out_obj.getNumRows(), out_obj.getNumColumns());
+ _op.execute(inputs, scalars, out_obj, ec);
+ ec.releaseMatrixOutputForGPUInstruction(_out.getName());
+ }
+ else if (_out.getDataType() == Types.DataType.SCALAR) {
+ ScalarObject out = new DoubleObject(_op.execute(inputs, scalars, null, ec));
+ ec.setScalarOutput(_out.getName(), out);
+ }
+
+ for (CPOperand input : _in)
+ if(input.getDataType()== Types.DataType.MATRIX)
+ ec.releaseMatrixInputForGPUInstruction(input.getName());
+ }
+
+ @Override
+ public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+ return Pair.of(_out.getName(),
+ new LineageItem(getOpcode(), LineageItemUtils.getLineage(ec, _in)));
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUContextPool.java b/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUContextPool.java
index 25f3059..049f492 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUContextPool.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUContextPool.java
@@ -21,7 +21,6 @@
import static jcuda.driver.JCudaDriver.cuDeviceGetCount;
import static jcuda.driver.JCudaDriver.cuInit;
-import static jcuda.runtime.JCuda.cudaGetDevice;
import static jcuda.runtime.JCuda.cudaGetDeviceProperties;
import java.util.ArrayList;
diff --git a/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java b/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java
index 96435b7..1361637 100644
--- a/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java
+++ b/src/main/java/org/apache/sysds/runtime/io/ReaderWriterFederated.java
@@ -65,153 +65,153 @@ import org.apache.sysds.runtime.meta.DataCharacteristics;
*
*/
public class ReaderWriterFederated {
- private static final Log LOG = LogFactory.getLog(ReaderWriterFederated.class.getName());
-
- /**
- * Read a federated map from disk, It is not initialized before it is used in:
- *
- * org.apache.sysds.runtime.instructions.fed.InitFEDInstruction
- *
- * @param file The file to read (defaults to HDFS)
- * @param mc The data characteristics of the file, that can be read from the mtd file.
- * @return A List of federatedRanges and Federated Data
- */
- public static List<Pair<FederatedRange, FederatedData>> read(String file, DataCharacteristics mc) {
- LOG.debug("Reading federated map from " + file);
- try {
- JobConf job = new JobConf(ConfigurationManager.getCachedJobConf());
- Path path = new Path(file);
- FileSystem fs = IOUtilFunctions.getFileSystem(path, job);
- FSDataInputStream data = fs.open(path);
- ObjectMapper mapper = new ObjectMapper();
- List<FederatedDataAddress> obj = mapper.readValue(data, new TypeReference<List<FederatedDataAddress>>() {
- });
- return obj.stream().map(x -> x.convert()).collect(Collectors.toList());
- }
- catch(Exception e) {
- throw new DMLRuntimeException("Unable to read federated matrix (" + file + ")", e);
- }
- }
-
- /**
- * TODO add writing to each of the federated locations so that they also save their matrices.
- *
- * Currently this would write the federated matrix to disk only locally.
- *
- * @param file The file to save to, (defaults to HDFS paths)
- * @param fedMap The federated map to save.
- */
- public static void write(String file, FederationMap fedMap) {
- LOG.debug("Writing federated map to " + file);
- try {
- JobConf job = new JobConf(ConfigurationManager.getCachedJobConf());
- Path path = new Path(file);
- FileSystem fs = IOUtilFunctions.getFileSystem(path, job);
- DataOutputStream out = fs.create(path, true);
- ObjectMapper mapper = new ObjectMapper();
- FederatedDataAddress[] outObjects = parseMap(fedMap.getFedMapping());
- try(BufferedWriter pw = new BufferedWriter(new OutputStreamWriter(out))) {
- mapper.writeValue(pw, outObjects);
- }
-
- IOUtilFunctions.deleteCrcFilesFromLocalFileSystem(fs, path);
- }
- catch(IOException e) {
- fail("Unable to write test federated matrix to (" + file + "): " + e.getMessage());
- }
- }
-
- private static FederatedDataAddress[] parseMap(Map<FederatedRange, FederatedData> map) {
- FederatedDataAddress[] res = new FederatedDataAddress[map.size()];
- int i = 0;
- for(Entry<FederatedRange, FederatedData> ent : map.entrySet()) {
- res[i++] = new FederatedDataAddress(ent.getKey(), ent.getValue());
- }
- return res;
- }
-
- /**
- * This class is used for easy serialization from json using Jackson. The warnings are suppressed because the
- * setters and getters only is used inside Jackson.
- */
- @SuppressWarnings("unused")
- private static class FederatedDataAddress {
- private Types.DataType _dataType;
- private InetSocketAddress _address;
- private String _filepath;
- private long[] _begin;
- private long[] _end;
-
- public FederatedDataAddress() {
- }
-
- protected FederatedDataAddress(FederatedRange fr, FederatedData fd) {
- _dataType = fd.getDataType();
- _address = fd.getAddress();
- _filepath = fd.getFilepath();
- _begin = fr.getBeginDims();
- _end = fr.getEndDims();
- }
-
- protected Pair<FederatedRange, FederatedData> convert() {
- FederatedRange fr = new FederatedRange(_begin, _end);
- FederatedData fd = new FederatedData(_dataType, _address, _filepath);
- return new ImmutablePair<>(fr, fd);
- }
-
- public String getFilepath() {
- return _filepath;
- }
-
- public void setFilepath(String filePath) {
- _filepath = filePath;
- }
-
- public Types.DataType getDataType() {
- return _dataType;
- }
-
- public void setDataType(Types.DataType dataType) {
- _dataType = dataType;
- }
-
- public InetSocketAddress getAddress() {
- return _address;
- }
-
- public void setAddress(InetSocketAddress address) {
- _address = address;
- }
-
- public long[] getBegin() {
- return _begin;
- }
-
- public void setBegin(long[] begin) {
- _begin = begin;
- }
-
- public long[] getEnd() {
- return _end;
- }
-
- public void setEnd(long[] end) {
- _end = end;
- }
-
- @Override
- public String toString() {
- StringBuilder sb = new StringBuilder();
- sb.append(_dataType);
- sb.append(" ");
- sb.append(_address);
- sb.append(" ");
- sb.append(_filepath);
- sb.append(" ");
- sb.append(Arrays.toString(_begin));
- sb.append(" ");
- sb.append(Arrays.toString(_end));
- return sb.toString();
- }
- }
+ private static final Log LOG = LogFactory.getLog(ReaderWriterFederated.class.getName());
+
+ /**
+ * Read a federated map from disk, It is not initialized before it is used in:
+ *
+ * org.apache.sysds.runtime.instructions.fed.InitFEDInstruction
+ *
+ * @param file The file to read (defaults to HDFS)
+ * @param mc The data characteristics of the file, that can be read from the mtd file.
+ * @return A List of federatedRanges and Federated Data
+ */
+ public static List<Pair<FederatedRange, FederatedData>> read(String file, DataCharacteristics mc) {
+ LOG.debug("Reading federated map from " + file);
+ try {
+ JobConf job = new JobConf(ConfigurationManager.getCachedJobConf());
+ Path path = new Path(file);
+ FileSystem fs = IOUtilFunctions.getFileSystem(path, job);
+ FSDataInputStream data = fs.open(path);
+ ObjectMapper mapper = new ObjectMapper();
+ List<FederatedDataAddress> obj = mapper.readValue(data, new TypeReference<List<FederatedDataAddress>>() {
+ });
+ return obj.stream().map(x -> x.convert()).collect(Collectors.toList());
+ }
+ catch(Exception e) {
+ throw new DMLRuntimeException("Unable to read federated matrix (" + file + ")", e);
+ }
+ }
+
+ /**
+ * TODO add writing to each of the federated locations so that they also save their matrices.
+ *
+ * Currently this would write the federated matrix to disk only locally.
+ *
+ * @param file The file to save to, (defaults to HDFS paths)
+ * @param fedMap The federated map to save.
+ */
+ public static void write(String file, FederationMap fedMap) {
+ LOG.debug("Writing federated map to " + file);
+ try {
+ JobConf job = new JobConf(ConfigurationManager.getCachedJobConf());
+ Path path = new Path(file);
+ FileSystem fs = IOUtilFunctions.getFileSystem(path, job);
+ DataOutputStream out = fs.create(path, true);
+ ObjectMapper mapper = new ObjectMapper();
+ FederatedDataAddress[] outObjects = parseMap(fedMap.getFedMapping());
+ try(BufferedWriter pw = new BufferedWriter(new OutputStreamWriter(out))) {
+ mapper.writeValue(pw, outObjects);
+ }
+
+ IOUtilFunctions.deleteCrcFilesFromLocalFileSystem(fs, path);
+ }
+ catch(IOException e) {
+ fail("Unable to write test federated matrix to (" + file + "): " + e.getMessage());
+ }
+ }
+
+ private static FederatedDataAddress[] parseMap(Map<FederatedRange, FederatedData> map) {
+ FederatedDataAddress[] res = new FederatedDataAddress[map.size()];
+ int i = 0;
+ for(Entry<FederatedRange, FederatedData> ent : map.entrySet()) {
+ res[i++] = new FederatedDataAddress(ent.getKey(), ent.getValue());
+ }
+ return res;
+ }
+
+ /**
+ * This class is used for easy serialization from json using Jackson. The warnings are suppressed because the
+ * setters and getters only is used inside Jackson.
+ */
+ @SuppressWarnings("unused")
+ private static class FederatedDataAddress {
+ private Types.DataType _dataType;
+ private InetSocketAddress _address;
+ private String _filepath;
+ private long[] _begin;
+ private long[] _end;
+
+ public FederatedDataAddress() {
+ }
+
+ protected FederatedDataAddress(FederatedRange fr, FederatedData fd) {
+ _dataType = fd.getDataType();
+ _address = fd.getAddress();
+ _filepath = fd.getFilepath();
+ _begin = fr.getBeginDims();
+ _end = fr.getEndDims();
+ }
+
+ protected Pair<FederatedRange, FederatedData> convert() {
+ FederatedRange fr = new FederatedRange(_begin, _end);
+ FederatedData fd = new FederatedData(_dataType, _address, _filepath);
+ return new ImmutablePair<>(fr, fd);
+ }
+
+ public String getFilepath() {
+ return _filepath;
+ }
+
+ public void setFilepath(String filePath) {
+ _filepath = filePath;
+ }
+
+ public Types.DataType getDataType() {
+ return _dataType;
+ }
+
+ public void setDataType(Types.DataType dataType) {
+ _dataType = dataType;
+ }
+
+ public InetSocketAddress getAddress() {
+ return _address;
+ }
+
+ public void setAddress(InetSocketAddress address) {
+ _address = address;
+ }
+
+ public long[] getBegin() {
+ return _begin;
+ }
+
+ public void setBegin(long[] begin) {
+ _begin = begin;
+ }
+
+ public long[] getEnd() {
+ return _end;
+ }
+
+ public void setEnd(long[] end) {
+ _end = end;
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ sb.append(_dataType);
+ sb.append(" ");
+ sb.append(_address);
+ sb.append(" ");
+ sb.append(_filepath);
+ sb.append(" ");
+ sb.append(Arrays.toString(_begin));
+ sb.append(" ");
+ sb.append(Arrays.toString(_end));
+ return sb.toString();
+ }
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDatagen.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDatagen.java
index 50d12c3..a8fce54 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDatagen.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDatagen.java
@@ -135,7 +135,7 @@ public class LibMatrixDatagen
}
}
- public static RandomMatrixGenerator createRandomMatrixGenerator(String pdfStr, int r, int c, int blen, double sp, double min, double max, String distParams) {
+ public static RandomMatrixGenerator createRandomMatrixGenerator(String pdfStr, int r, int c, int blen, double sp, double min, double max, String distParams) {
RandomMatrixGenerator.PDF pdf = RandomMatrixGenerator.PDF.valueOf(pdfStr.toUpperCase());
RandomMatrixGenerator rgen = null;
switch (pdf) {