You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2020/04/11 18:25:45 UTC
[systemml] branch master updated: [SYSTEMDS-12] Additional cleanup
unnecessary hop/lop indirections
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git
The following commit(s) were added to refs/heads/master by this push:
new 8adb8c0 [SYSTEMDS-12] Additional cleanup unnecessary hop/lop indirections
8adb8c0 is described below
commit 8adb8c0ec3ed9a9e867a5695a5390dc86a48413f
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Sat Apr 11 20:25:20 2020 +0200
[SYSTEMDS-12] Additional cleanup unnecessary hop/lop indirections
This patch completes the removal of unnecessary hop/lop indirections,
operation types, and lookup tables for converting between these different
representations of the same operator semantics.
Furthermore, this includes a minor performance improvement for
instruction string concatenation: we now reuse allocated string
builders (and their internal arrays) in a thread-local manner.
---
docs/Tasks.txt | 2 +-
src/main/java/org/apache/sysds/common/Types.java | 152 +++++++++
.../java/org/apache/sysds/hops/AggBinaryOp.java | 16 +-
.../java/org/apache/sysds/hops/AggUnaryOp.java | 16 +-
src/main/java/org/apache/sysds/hops/BinaryOp.java | 80 ++---
src/main/java/org/apache/sysds/hops/DnnOp.java | 1 +
src/main/java/org/apache/sysds/hops/Hop.java | 366 +--------------------
.../java/org/apache/sysds/hops/IndexingOp.java | 12 +-
.../java/org/apache/sysds/hops/LeftIndexingOp.java | 40 ++-
.../java/org/apache/sysds/hops/OptimizerUtils.java | 13 +-
.../apache/sysds/hops/ParameterizedBuiltinOp.java | 2 +
.../java/org/apache/sysds/hops/QuaternaryOp.java | 13 +-
src/main/java/org/apache/sysds/hops/TernaryOp.java | 1 +
src/main/java/org/apache/sysds/hops/UnaryOp.java | 16 +-
.../apache/sysds/hops/codegen/SpoofCompiler.java | 2 +-
.../codegen/opt/PlanSelectionFuseCostBasedV2.java | 2 +-
.../sysds/hops/codegen/template/TemplateCell.java | 2 +-
.../hops/codegen/template/TemplateMultiAgg.java | 2 +-
.../codegen/template/TemplateOuterProduct.java | 6 +-
.../sysds/hops/codegen/template/TemplateRow.java | 24 +-
.../sysds/hops/codegen/template/TemplateUtils.java | 16 +-
.../apache/sysds/hops/ipa/FunctionCallGraph.java | 2 +-
.../hops/ipa/IPAPassRemoveConstantBinaryOps.java | 2 +-
.../ipa/IPAPassRemoveUnnecessaryCheckpoints.java | 2 +-
.../sysds/hops/recompile/LiteralReplacement.java | 2 +-
.../apache/sysds/hops/recompile/Recompiler.java | 2 +-
.../apache/sysds/hops/rewrite/HopRewriteUtils.java | 24 +-
.../RewriteAlgebraicSimplificationDynamic.java | 8 +-
.../RewriteAlgebraicSimplificationStatic.java | 4 +-
.../hops/rewrite/RewriteCompressedReblock.java | 4 +-
.../sysds/hops/rewrite/RewriteConstantFolding.java | 4 +-
.../RewriteElementwiseMultChainOptimization.java | 19 +-
.../hops/rewrite/RewriteForLoopVectorization.java | 4 +-
.../sysds/hops/rewrite/RewriteGPUSpecificOps.java | 4 +-
.../hops/rewrite/RewriteIndexingVectorization.java | 2 +-
.../RewriteMarkLoopVariablesUpdateInPlace.java | 2 +-
.../RewriteRemoveDanglingParentReferences.java | 2 +-
.../rewrite/RewriteRemoveUnnecessaryCasts.java | 2 +-
.../RewriteSplitDagDataDependentOperators.java | 2 +-
src/main/java/org/apache/sysds/lops/Binary.java | 158 +--------
src/main/java/org/apache/sysds/lops/BinaryM.java | 79 +----
.../java/org/apache/sysds/lops/BinaryScalar.java | 100 +-----
.../org/apache/sysds/lops/BinaryUAggChain.java | 11 +-
.../java/org/apache/sysds/lops/CentralMoment.java | 46 +--
.../java/org/apache/sysds/lops/Checkpoint.java | 2 +-
.../java/org/apache/sysds/lops/Compression.java | 2 +-
src/main/java/org/apache/sysds/lops/Ctable.java | 2 +-
.../apache/sysds/lops/CumulativeOffsetBinary.java | 4 +-
.../sysds/lops/CumulativePartialAggregate.java | 2 +-
src/main/java/org/apache/sysds/lops/Data.java | 4 +-
src/main/java/org/apache/sysds/lops/DataGen.java | 101 ++----
.../java/org/apache/sysds/lops/DnnTransform.java | 8 +-
.../org/apache/sysds/lops/GroupedAggregate.java | 6 +-
src/main/java/org/apache/sysds/lops/LeftIndex.java | 12 +-
src/main/java/org/apache/sysds/lops/Lop.java | 1 +
src/main/java/org/apache/sysds/lops/MMCJ.java | 4 +-
src/main/java/org/apache/sysds/lops/MMRJ.java | 4 +-
src/main/java/org/apache/sysds/lops/MapMult.java | 4 +-
src/main/java/org/apache/sysds/lops/MatMultCP.java | 83 +++++
src/main/java/org/apache/sysds/lops/PMapMult.java | 4 +-
.../apache/sysds/lops/ParameterizedBuiltin.java | 2 +-
.../java/org/apache/sysds/lops/PickByCount.java | 4 +-
src/main/java/org/apache/sysds/lops/SortKeys.java | 4 +-
.../org/apache/sysds/lops/TernaryAggregate.java | 5 +-
src/main/java/org/apache/sysds/lops/Transform.java | 2 +-
.../java/org/apache/sysds/lops/UAggOuterChain.java | 16 +-
src/main/java/org/apache/sysds/lops/Unary.java | 226 ++-----------
src/main/java/org/apache/sysds/lops/UnaryCP.java | 124 +------
.../org/apache/sysds/lops/WeightedUnaryMM.java | 8 +-
.../org/apache/sysds/lops/WeightedUnaryMMR.java | 11 +-
.../org/apache/sysds/parser/DMLTranslator.java | 66 ++--
.../apache/sysds/parser/ParForStatementBlock.java | 4 +-
.../runtime/instructions/InstructionUtils.java | 12 +-
.../sysds/runtime/lineage/LineageItemUtils.java | 5 +-
.../sysds/runtime/lineage/LineageRewriteReuse.java | 2 +-
.../runtime/matrix/operators/BinaryOperator.java | 60 ++--
.../codegen/CPlanVectorPrimitivesTest.java | 5 +-
77 files changed, 669 insertions(+), 1399 deletions(-)
diff --git a/docs/Tasks.txt b/docs/Tasks.txt
index 5fdd96d..7bb0c10 100644
--- a/docs/Tasks.txt
+++ b/docs/Tasks.txt
@@ -6,7 +6,7 @@ GENERAL NOTES:
SYSTEMDS-10 Compiler Rework / Misc
* 11 Support DML-bodied builtin functions OK
- * 12 Remove unnecessary HOP/LOP indirections
+ * 12 Remove unnecessary HOP/LOP indirections OK
* 13 Refactoring test cases into component/integration OK
* 14 Complete removal of external functions from all scripts
* 15 Travis integration w/ subset of tests OK
diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java
index 52c8911..cd22f4f 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -186,6 +186,158 @@ public class Types
}
}
+ // Operations that require 1 operand
+ public enum OpOp1 {
+ ABS, ACOS, ASIN, ASSERT, ATAN, CAST_AS_SCALAR, CAST_AS_MATRIX,
+ CAST_AS_FRAME, CAST_AS_DOUBLE, CAST_AS_INT, CAST_AS_BOOLEAN,
+ CEIL, CHOLESKY, COS, COSH, CUMMAX, CUMMIN, CUMPROD, CUMSUM,
+ CUMSUMPROD, DETECTSCHEMA, EIGEN, EXISTS, EXP, FLOOR, INVERSE,
+ IQM, ISNA, ISNAN, ISINF, LENGTH, LINEAGE, LOG, NCOL, NOT, NROW,
+ MEDIAN, PRINT, ROUND, SIN, SINH, SIGN, SOFTMAX, SQRT, STOP, SVD,
+ TAN, TANH, TYPEOF,
+ //fused ML-specific operators for performance
+ SPROP, //sample proportion: P * (1 - P)
+ SIGMOID, //sigmoid function: 1 / (1 + exp(-X))
+ LOG_NZ, //sparse-safe log; ppred(X,0,"!=")*log(X)
+
+ //low-level operators //TODO used?
+ MULT2, MINUS1_MULT, MINUS_RIGHT,
+ POW2, SUBTRACT_NZ;
+
+ @Override
+ public String toString() {
+ switch(this) {
+ case CAST_AS_SCALAR: return "castdts";
+ case CAST_AS_MATRIX: return "castdtm";
+ case CAST_AS_FRAME: return "castdtf";
+ case CAST_AS_DOUBLE: return "castvtd";
+ case CAST_AS_INT: return "castvti";
+ case CAST_AS_BOOLEAN: return "castvtb";
+ case CUMMAX: return "ucummax";
+ case CUMMIN: return "ucummin";
+ case CUMPROD: return "ucum*";
+ case CUMSUM: return "ucumk+";
+ case CUMSUMPROD: return "ucumk+*";
+ case DETECTSCHEMA: return "detectSchema";
+ case MULT2: return "*2";
+ case NOT: return "!";
+ case POW2: return "^2";
+ case TYPEOF: return "typeOf";
+ default: return name().toLowerCase();
+ }
+ }
+
+ //need to be kept consistent with toString
+ public static OpOp1 valueOfByOpcode(String opcode) {
+ switch(opcode) {
+ case "castdts": return CAST_AS_SCALAR;
+ case "castdtm": return CAST_AS_MATRIX;
+ case "castdtf": return CAST_AS_FRAME;
+ case "castvtd": return CAST_AS_DOUBLE;
+ case "castvti": return CAST_AS_INT;
+ case "castvtb": return CAST_AS_BOOLEAN;
+ case "ucummax": return CUMMAX;
+ case "ucummin": return CUMMIN;
+ case "ucum*": return CUMPROD;
+ case "ucumk+": return CUMSUM;
+ case "ucumk+*": return CUMSUMPROD;
+ case "*2": return MULT2;
+ case "!": return OpOp1.NOT;
+ case "^2": return POW2;
+ default: return valueOf(opcode.toUpperCase());
+ }
+ }
+ }
+
+ // Operations that require 2 operands
+ public enum OpOp2 {
+ AND(true), BITWAND(true), BITWOR(true), BITWSHIFTL(true), BITWSHIFTR(true),
+ BITWXOR(true), CBIND(false), CONCAT(false), COV(false), DIV(true),
+ DROP_INVALID(false), EQUAL(true), GREATER(true), GREATEREQUAL(true),
+ INTDIV(true), INTERQUANTILE(false), IQM(false), LESS(true), LESSEQUAL(true),
+ LOG(true), MAX(true), MEDIAN(false), MIN(true), MINUS(true), MODULUS(true),
+ MOMENT(false), MULT(true), NOTEQUAL(true), OR(true), PLUS(true), POW(true),
+ PRINT(false), QUANTILE(false), SOLVE(false), RBIND(false), XOR(true),
+ //fused ML-specific operators for performance
+ MINUS_NZ(false), //sparse-safe minus: X-(mean*ppred(X,0,!=))
+ LOG_NZ(false), //sparse-safe log; ppred(X,0,"!=")*log(X,0.5)
+ MINUS1_MULT(false); //1-X*Y
+
+ private final boolean _validOuter;
+
+ private OpOp2(boolean outer) {
+ _validOuter = outer;
+ }
+
+ public boolean isValidOuter() {
+ return _validOuter;
+ }
+
+ @Override
+ public String toString() {
+ switch(this) {
+ case PLUS: return "+";
+ case MINUS: return "-";
+ case MINUS_NZ: return "-nz";
+ case MINUS1_MULT: return "1-*";
+ case MULT: return "*";
+ case DIV: return "/";
+ case MODULUS: return "%%";
+ case INTDIV: return "%/%";
+ case LESSEQUAL: return "<=";
+ case LESS: return "<";
+ case GREATEREQUAL: return ">=";
+ case GREATER: return ">";
+ case EQUAL: return "==";
+ case NOTEQUAL: return "!=";
+ case OR: return "||";
+ case AND: return "&&";
+ case POW: return "^";
+ case IQM: return "IQM";
+ case MOMENT: return "cm";
+ case BITWAND: return "bitwAnd";
+ case BITWOR: return "bitwOr";
+ case BITWXOR: return "bitwXor";
+ case BITWSHIFTL: return "bitwShiftL";
+ case BITWSHIFTR: return "bitwShiftR";
+ case DROP_INVALID: return "dropInvalid";
+ default: return name().toLowerCase();
+ }
+ }
+
+ //need to be kept consistent with toString
+ public static OpOp2 valueOfByOpcode(String opcode) {
+ switch(opcode) {
+ case "+": return PLUS;
+ case "-": return MINUS;
+ case "-nz": return MINUS_NZ;
+ case "1-*": return MINUS1_MULT;
+ case "*": return MULT;
+ case "/": return DIV;
+ case "%%": return MODULUS;
+ case "%/%": return INTDIV;
+ case "<=": return LESSEQUAL;
+ case "<": return LESS;
+ case ">=": return GREATEREQUAL;
+ case ">": return GREATER;
+ case "==": return EQUAL;
+ case "!=": return NOTEQUAL;
+ case "||": return OR;
+ case "&&": return AND;
+ case "^": return POW;
+ case "IQM": return IQM;
+ case "cm": return MOMENT;
+ case "bitwAnd": return BITWAND;
+ case "bitwOr": return BITWOR;
+ case "bitwXor": return BITWXOR;
+ case "bitwShiftL": return BITWSHIFTL;
+ case "bitwShiftR": return BITWSHIFTR;
+ case "dropInvalid": return DROP_INVALID;
+ default: return valueOf(opcode.toUpperCase());
+ }
+ }
+ }
+
// Operations that require 3 operands
public enum OpOp3 {
QUANTILE, INTERQUANTILE, CTABLE, MOMENT, COV, PLUS_MULT, MINUS_MULT, IFELSE;
diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
index cee663f..b456cc8 100644
--- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -24,10 +24,10 @@ import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.Direction;
import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.ReOrgOp;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
-import org.apache.sysds.lops.Binary;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.lops.MMCJ;
@@ -38,6 +38,7 @@ import org.apache.sysds.lops.MMZip;
import org.apache.sysds.lops.MapMult;
import org.apache.sysds.lops.MapMultChain;
import org.apache.sysds.lops.MapMultChain.ChainType;
+import org.apache.sysds.lops.MatMultCP;
import org.apache.sysds.lops.PMMJ;
import org.apache.sysds.lops.PMapMult;
import org.apache.sysds.lops.Transform;
@@ -256,9 +257,7 @@ public class AggBinaryOp extends MultiThreadedHop
@Override
public String getOpString() {
//ba - binary aggregate, for consistency with runtime
- String s = "ba(" + outerOp.toString() +
- HopsOpOp2String.get(innerOp)+")";
- return s;
+ return "ba(" + outerOp.toString() + innerOp.toString()+")";
}
@Override
@@ -613,8 +612,7 @@ public class AggBinaryOp extends MultiThreadedHop
h1.getInput().get(0).constructLops();
Lop right = !rightTrans ? h2.constructLops() :
h2.getInput().get(0).constructLops();
- matmultCP = new Binary(left, right, Binary.OperationTypes.MATMULT,
- getDataType(), getValueType(), et, leftTrans, rightTrans);
+ matmultCP = new MatMultCP(left, right, getDataType(), getValueType(), et, leftTrans, rightTrans);
setOutputDimensions(matmultCP);
}
else {
@@ -623,8 +621,8 @@ public class AggBinaryOp extends MultiThreadedHop
}
else {
int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
- matmultCP = new Binary(getInput().get(0).constructLops(),getInput().get(1).constructLops(),
- Binary.OperationTypes.MATMULT, getDataType(), getValueType(), et, k);
+ matmultCP = new MatMultCP(getInput().get(0).constructLops(),
+ getInput().get(1).constructLops(), getDataType(), getValueType(), et, k);
}
setOutputDimensions(matmultCP);
}
@@ -648,7 +646,7 @@ public class AggBinaryOp extends MultiThreadedHop
setLineNumbers(tY);
//matrix mult
- Lop mult = new Binary(tY, X.constructLops(), Binary.OperationTypes.MATMULT, getDataType(), getValueType(), ExecType.CP, k);
+ Lop mult = new MatMultCP(tY, X.constructLops(), getDataType(), getValueType(), ExecType.CP, k);
mult.getOutputParameters().setDimensions(Y.getDim2(), X.getDim2(), getBlocksize(), getNnz());
setLineNumbers(mult);
diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
index 9c3a954..f42434b 100644
--- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
@@ -23,10 +23,11 @@ import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.Direction;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.hops.AggBinaryOp.SparkAggType;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
-import org.apache.sysds.lops.Binary;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.lops.PartialAggregate;
@@ -133,11 +134,11 @@ public class AggUnaryOp extends MultiThreadedHop
BinaryOp binput = (BinaryOp)getInput().get(0);
agg1 = new UAggOuterChain( binput.getInput().get(0).constructLops(),
binput.getInput().get(1).constructLops(), _op, _direction,
- HopsOpOp2LopsB.get(binput.getOp()), DataType.MATRIX, getValueType(), ExecType.CP);
+ binput.getOp(), DataType.MATRIX, getValueType(), ExecType.CP);
PartialAggregate.setDimensionsBasedOnDirection(agg1, getDim1(), getDim2(), input.getBlocksize(), _direction);
if (getDataType() == DataType.SCALAR) {
- UnaryCP unary1 = new UnaryCP(agg1, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR),
+ UnaryCP unary1 = new UnaryCP(agg1, OpOp1.CAST_AS_SCALAR,
getDataType(), getValueType());
unary1.getOutputParameters().setDimensions(0, 0, 0, -1);
setLineNumbers(unary1);
@@ -174,15 +175,14 @@ public class AggUnaryOp extends MultiThreadedHop
BinaryOp binput = (BinaryOp)getInput().get(0);
Lop transform1 = new UAggOuterChain( binput.getInput().get(0).constructLops(),
binput.getInput().get(1).constructLops(), _op, _direction,
- HopsOpOp2LopsB.get(binput.getOp()), DataType.MATRIX, getValueType(), ExecType.SPARK);
+ binput.getOp(), DataType.MATRIX, getValueType(), ExecType.SPARK);
PartialAggregate.setDimensionsBasedOnDirection(transform1, getDim1(), getDim2(), input.getBlocksize(), _direction);
setLineNumbers(transform1);
setLops(transform1);
if (getDataType() == DataType.SCALAR) {
UnaryCP unary1 = new UnaryCP(transform1,
- HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR),
- getDataType(), getValueType());
+ OpOp1.CAST_AS_SCALAR, getDataType(), getValueType());
unary1.getOutputParameters().setDimensions(0, 0, 0, -1);
setLineNumbers(unary1);
setLops(unary1);
@@ -202,7 +202,7 @@ public class AggUnaryOp extends MultiThreadedHop
if (getDataType() == DataType.SCALAR) {
UnaryCP unary1 = new UnaryCP(aggregate,
- HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR), getDataType(), getValueType());
+ OpOp1.CAST_AS_SCALAR, getDataType(), getValueType());
unary1.getOutputParameters().setDimensions(0, 0, 0, -1);
setLineNumbers(unary1);
setLops(unary1);
@@ -605,7 +605,7 @@ public class AggUnaryOp extends MultiThreadedHop
et_input = et_input == ExecType.GPU ? ExecType.CP : et_input;
return new TernaryAggregate(in1, in2, in3, AggOp.SUM,
- Binary.OperationTypes.MULTIPLY, _direction, getDataType(), ValueType.FP64, et_input, k);
+ OpOp2.MULT, _direction, getDataType(), ValueType.FP64, et_input, k);
}
@Override
diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index 8212b53..586d675 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -23,6 +23,8 @@ import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.Direction;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOpDnn;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.ConfigurationManager;
@@ -62,7 +64,7 @@ public class BinaryOp extends MultiThreadedHop
//we use the full remote memory budget (but reduced by sort buffer),
public static final double APPEND_MEM_MULTIPLIER = 1.0;
- private Hop.OpOp2 op;
+ private OpOp2 op;
private boolean outer = false;
public static AppendMethod FORCED_APPEND_METHOD = null;
@@ -89,7 +91,7 @@ public class BinaryOp extends MultiThreadedHop
//default constructor for clone
}
- public BinaryOp(String l, DataType dt, ValueType vt, Hop.OpOp2 o,
+ public BinaryOp(String l, DataType dt, ValueType vt, OpOp2 o,
Hop inp1, Hop inp2) {
super(l, dt, vt);
op = o;
@@ -307,22 +309,22 @@ public class BinaryOp extends MultiThreadedHop
// For INTERQUANTILE: 2nd argument is always a scalar
PickByCount.OperationTypes pick_op = null;
- if(op == Hop.OpOp2.QUANTILE)
+ if(op == OpOp2.QUANTILE)
pick_op = PickByCount.OperationTypes.VALUEPICK;
else
pick_op = PickByCount.OperationTypes.RANGEPICK;
SortKeys sort = SortKeys.constructSortByValueLop(
- getInput().get(0).constructLops(),
- SortKeys.OperationTypes.WithoutWeights,
- DataType.MATRIX, ValueType.FP64, et );
+ getInput().get(0).constructLops(),
+ SortKeys.OperationTypes.WithoutWeights,
+ DataType.MATRIX, ValueType.FP64, et );
sort.getOutputParameters().setDimensions(
- getInput().get(0).getDim1(),
- getInput().get(0).getDim2(),
- getInput().get(0).getBlocksize(),
- getInput().get(0).getNnz());
+ getInput().get(0).getDim1(),
+ getInput().get(0).getDim2(),
+ getInput().get(0).getBlocksize(),
+ getInput().get(0).getNnz());
PickByCount pick = new PickByCount( sort, getInput().get(1).constructLops(),
- getDataType(), getValueType(), pick_op, et, true);
+ getDataType(), getValueType(), pick_op, et, true);
setOutputDimensions(pick);
setLineNumbers(pick);
@@ -356,13 +358,11 @@ public class BinaryOp extends MultiThreadedHop
long clen = cbind ? ((getInput().get(0).dimsKnown() && getInput().get(1).dimsKnown()) ?
getInput().get(0).getDim2()+getInput().get(1).getDim2() : -1) : getInput().get(0).getDim2();
- if(et == ExecType.SPARK)
- {
+ if(et == ExecType.SPARK) {
append = constructSPAppendLop(getInput().get(0), getInput().get(1), getDataType(), getValueType(), cbind, this);
append.getOutputParameters().setDimensions(rlen, clen, getBlocksize(), getNnz());
}
- else //CP
- {
+ else { //CP
Lop offset = createOffsetLop( getInput().get(0), cbind ); //offset 1st input
append = new Append(getInput().get(0).constructLops(), getInput().get(1).constructLops(), offset, getDataType(), getValueType(), cbind, et);
append.getOutputParameters().setDimensions(rlen, clen, getBlocksize(), getNnz());
@@ -395,9 +395,8 @@ public class BinaryOp extends MultiThreadedHop
if (dt1 == dt2 && dt1 == DataType.SCALAR) {
// Both operands scalar
- BinaryScalar binScalar1 = new BinaryScalar(getInput().get(0)
- .constructLops(),getInput().get(1).constructLops(),
- HopsOpOp2LopsBS.get(op), getDataType(), getValueType());
+ BinaryScalar binScalar1 = new BinaryScalar(getInput().get(0).constructLops(),
+ getInput().get(1).constructLops(), op, getDataType(), getValueType());
binScalar1.getOutputParameters().setDimensions(0, 0, 0, -1);
setLineNumbers(binScalar1);
setLops(binScalar1);
@@ -410,34 +409,34 @@ public class BinaryOp extends MultiThreadedHop
ExecType et = optFindExecType();
//select specific operator implementations
- Unary.OperationTypes ot = null;
Hop right = getInput().get(1);
- if( op==OpOp2.POW && right instanceof LiteralOp && ((LiteralOp)right).getDoubleValue()==2.0 )
- ot = Unary.OperationTypes.POW2;
- else if( op==OpOp2.MULT && right instanceof LiteralOp && ((LiteralOp)right).getDoubleValue()==2.0 )
- ot = Unary.OperationTypes.MULTIPLY2;
- else //general case
- ot = HopsOpOp2LopsU.get(op);
-
- Unary unary1 = new Unary(getInput().get(0).constructLops(),
- getInput().get(1).constructLops(), ot, getDataType(), getValueType(), et);
-
- setOutputDimensions(unary1);
- setLineNumbers(unary1);
- setLops(unary1);
+ OpOp1 ot = (op==OpOp2.POW && HopRewriteUtils.isLiteralOfValue(right, 2d)) ? OpOp1.POW2 :
+ (op==OpOp2.MULT && HopRewriteUtils.isLiteralOfValue(right, 2d)) ? OpOp1.MULT2 : null;
+ Lop tmp = null;
+ if( ot != null ) {
+ tmp = new Unary(getInput().get(0).constructLops(),
+ getInput().get(1).constructLops(), ot, getDataType(), getValueType(), et);
+ }
+ else { //general case
+ tmp = new Binary(getInput().get(0).constructLops(),
+ getInput().get(1).constructLops(), op, getDataType(), getValueType(), et);
+ }
+ setOutputDimensions(tmp);
+ setLineNumbers(tmp);
+ setLops(tmp);
}
else
{
// Both operands are Matrixes or Tensors
ExecType et = optFindExecType();
- boolean isGPUSoftmax = et == ExecType.GPU && op == Hop.OpOp2.DIV &&
+ boolean isGPUSoftmax = et == ExecType.GPU && op == OpOp2.DIV &&
getInput().get(0) instanceof UnaryOp && getInput().get(1) instanceof AggUnaryOp &&
((UnaryOp)getInput().get(0)).getOp() == OpOp1.EXP && ((AggUnaryOp)getInput().get(1)).getOp() == AggOp.SUM &&
((AggUnaryOp)getInput().get(1)).getDirection() == Direction.Row &&
getInput().get(0) == getInput().get(1).getInput().get(0);
if(isGPUSoftmax) {
- UnaryCP softmax = new UnaryCP(getInput().get(0).getInput().get(0).constructLops(), UnaryCP.OperationTypes.SOFTMAX,
- getDataType(), getValueType(), et);
+ UnaryCP softmax = new UnaryCP(getInput().get(0).getInput().get(0).constructLops(),
+ OpOp1.SOFTMAX, getDataType(), getValueType(), et);
setOutputDimensions(softmax);
setLineNumbers(softmax);
setLops(softmax);
@@ -459,7 +458,8 @@ public class BinaryOp extends MultiThreadedHop
getDataType(), getValueType(), et, OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
}
else
- binary = new Binary(getInput().get(0).constructLops(), getInput().get(1).constructLops(), HopsOpOp2LopsB.get(op),
+ binary = new Binary(getInput().get(0).constructLops(),
+ getInput().get(1).constructLops(), op,
getDataType(), getValueType(), et);
setOutputDimensions(binary);
@@ -477,7 +477,7 @@ public class BinaryOp extends MultiThreadedHop
Lop binary = null;
if( mbin == MMBinaryMethod.MR_BINARY_UAGG_CHAIN ) {
AggUnaryOp uRight = (AggUnaryOp)right;
- binary = new BinaryUAggChain(left.constructLops(), HopsOpOp2LopsB.get(op),
+ binary = new BinaryUAggChain(left.constructLops(), op,
uRight.getOp(), uRight.getDirection(), getDataType(), getValueType(), et);
}
else if (mbin == MMBinaryMethod.MR_BINARY_M) {
@@ -485,11 +485,11 @@ public class BinaryOp extends MultiThreadedHop
(right.getDim2() == 1 && left.getDim1() == right.getDim1());
binary = new BinaryM(left.constructLops(), right.constructLops(),
- HopsOpOp2LopsB.get(op), getDataType(), getValueType(), et, isColVector);
+ op, getDataType(), getValueType(), et, isColVector);
}
else {
binary = new Binary(left.constructLops(), right.constructLops(),
- HopsOpOp2LopsB.get(op), getDataType(), getValueType(), et);
+ op, getDataType(), getValueType(), et);
}
setOutputDimensions(binary);
@@ -501,7 +501,7 @@ public class BinaryOp extends MultiThreadedHop
@Override
public String getOpString() {
- return "b(" + HopsOpOp2String.get(op) + ")";
+ return "b(" + op.toString() + ")";
}
@Override
diff --git a/src/main/java/org/apache/sysds/hops/DnnOp.java b/src/main/java/org/apache/sysds/hops/DnnOp.java
index 102134e..e4eed0e 100644
--- a/src/main/java/org/apache/sysds/hops/DnnOp.java
+++ b/src/main/java/org/apache/sysds/hops/DnnOp.java
@@ -21,6 +21,7 @@ package org.apache.sysds.hops;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOpDnn;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java
index f64b5f0..ba0dd03 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -25,13 +25,13 @@ import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.hops.recompile.Recompiler.ResetType;
-import org.apache.sysds.lops.Binary;
-import org.apache.sysds.lops.BinaryScalar;
import org.apache.sysds.lops.CSVReBlock;
import org.apache.sysds.lops.Checkpoint;
import org.apache.sysds.lops.Compression;
@@ -40,7 +40,6 @@ import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.lops.LopsException;
import org.apache.sysds.lops.ReBlock;
-import org.apache.sysds.lops.Unary;
import org.apache.sysds.lops.UnaryCP;
import org.apache.sysds.parser.ParseInfo;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
@@ -57,8 +56,6 @@ import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
-import java.util.Map.Entry;
-
public abstract class Hop implements ParseInfo
{
@@ -408,27 +405,21 @@ public abstract class Hop implements ParseInfo
}
}
- public static Lop createOffsetLop( Hop hop, boolean repCols )
- {
+ public static Lop createOffsetLop( Hop hop, boolean repCols ) {
Lop offset = null;
-
- if( ConfigurationManager.isDynamicRecompilation() && hop.dimsKnown() )
- {
+ if( ConfigurationManager.isDynamicRecompilation() && hop.dimsKnown() ) {
// If dynamic recompilation is enabled and dims are known, we can replace the ncol with
// a literal in order to increase the piggybacking potential. This is safe because append
// is always marked for recompilation and hence, we have propagated the exact dimensions.
offset = Data.createLiteralLop(ValueType.INT64, String.valueOf(repCols ? hop.getDim2() : hop.getDim1()));
}
- else
- {
+ else {
offset = new UnaryCP(hop.constructLops(),
- repCols ? UnaryCP.OperationTypes.NCOL : UnaryCP.OperationTypes.NROW,
- DataType.SCALAR, ValueType.INT64);
+ repCols ? OpOp1.NCOL : OpOp1.NROW, DataType.SCALAR, ValueType.INT64);
}
offset.getOutputParameters().setDimensions(0, 0, 0, -1);
offset.setAllPositions(hop.getFilename(), hop.getBeginLine(), hop.getBeginColumn(), hop.getEndLine(), hop.getEndColumn());
-
return offset;
}
@@ -1030,343 +1021,6 @@ public abstract class Hop implements ParseInfo
_valueType = vt;
}
- @SuppressWarnings("hiding")
- public enum OpOp1 {
- NOT, ABS, SIN, COS, TAN, ASIN, ACOS, ATAN, SINH, COSH, TANH, SIGN, SQRT, LOG, EXP,
- CAST_AS_SCALAR, CAST_AS_MATRIX, CAST_AS_FRAME, CAST_AS_DOUBLE, CAST_AS_INT, CAST_AS_BOOLEAN,
- PRINT, ASSERT, EIGEN, NROW, NCOL, LENGTH, ROUND, IQM, STOP, CEIL, FLOOR, MEDIAN, INVERSE, CHOLESKY,
- SVD, EXISTS, LINEAGE, TYPEOF, DETECTSCHEMA,
- //cumulative sums, products, extreme values
- CUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD,
- //checks for special values
- ISNA, ISNAN, ISINF,
- //fused ML-specific operators for performance
- SPROP, //sample proportion: P * (1 - P)
- SIGMOID, //sigmoid function: 1 / (1 + exp(-X))
- LOG_NZ, //sparse-safe log; ppred(X,0,"!=")*log(X)
- }
-
- // Operations that require two operands
- @SuppressWarnings("hiding")
- public enum OpOp2 {
- PLUS, MINUS, MULT, DIV, MODULUS, INTDIV, LESS, LESSEQUAL, GREATER, GREATEREQUAL, EQUAL, NOTEQUAL,
- MIN, MAX, AND, OR, XOR, LOG, POW, PRINT, CONCAT, QUANTILE, INTERQUANTILE, IQM,
- MOMENT, COV, CBIND, RBIND, SOLVE, MEDIAN, INVALID,
- //fused ML-specific operators for performance
- MINUS_NZ, //sparse-safe minus: X-(mean*ppred(X,0,!=))
- LOG_NZ, //sparse-safe log; ppred(X,0,"!=")*log(X,0.5)
- MINUS1_MULT, //1-X*Y
- BITWAND, BITWOR, BITWXOR, BITWSHIFTL, BITWSHIFTR, //bitwise operations
- DROP_INVALID, // frame operation for removing cells invalid wrt given data type
- }
-
- public static final HashMap<Hop.OpOp2, Binary.OperationTypes> HopsOpOp2LopsB;
- static {
- HopsOpOp2LopsB = new HashMap<>();
- HopsOpOp2LopsB.put(OpOp2.PLUS, Binary.OperationTypes.ADD);
- HopsOpOp2LopsB.put(OpOp2.MINUS, Binary.OperationTypes.SUBTRACT);
- HopsOpOp2LopsB.put(OpOp2.MULT, Binary.OperationTypes.MULTIPLY);
- HopsOpOp2LopsB.put(OpOp2.DIV, Binary.OperationTypes.DIVIDE);
- HopsOpOp2LopsB.put(OpOp2.MODULUS, Binary.OperationTypes.MODULUS);
- HopsOpOp2LopsB.put(OpOp2.INTDIV, Binary.OperationTypes.INTDIV);
- HopsOpOp2LopsB.put(OpOp2.MINUS1_MULT, Binary.OperationTypes.MINUS1_MULTIPLY);
- HopsOpOp2LopsB.put(OpOp2.LESS, Binary.OperationTypes.LESS_THAN);
- HopsOpOp2LopsB.put(OpOp2.LESSEQUAL, Binary.OperationTypes.LESS_THAN_OR_EQUALS);
- HopsOpOp2LopsB.put(OpOp2.GREATER, Binary.OperationTypes.GREATER_THAN);
- HopsOpOp2LopsB.put(OpOp2.GREATEREQUAL, Binary.OperationTypes.GREATER_THAN_OR_EQUALS);
- HopsOpOp2LopsB.put(OpOp2.EQUAL, Binary.OperationTypes.EQUALS);
- HopsOpOp2LopsB.put(OpOp2.NOTEQUAL, Binary.OperationTypes.NOT_EQUALS);
- HopsOpOp2LopsB.put(OpOp2.MIN, Binary.OperationTypes.MIN);
- HopsOpOp2LopsB.put(OpOp2.MAX, Binary.OperationTypes.MAX);
- HopsOpOp2LopsB.put(OpOp2.AND, Binary.OperationTypes.AND);
- HopsOpOp2LopsB.put(OpOp2.XOR, Binary.OperationTypes.XOR);
- HopsOpOp2LopsB.put(OpOp2.OR, Binary.OperationTypes.OR);
- HopsOpOp2LopsB.put(OpOp2.SOLVE, Binary.OperationTypes.SOLVE);
- HopsOpOp2LopsB.put(OpOp2.POW, Binary.OperationTypes.POW);
- HopsOpOp2LopsB.put(OpOp2.LOG, Binary.OperationTypes.NOTSUPPORTED);
- HopsOpOp2LopsB.put(OpOp2.BITWAND, Binary.OperationTypes.BW_AND);
- HopsOpOp2LopsB.put(OpOp2.BITWOR, Binary.OperationTypes.BW_OR);
- HopsOpOp2LopsB.put(OpOp2.BITWXOR, Binary.OperationTypes.BW_XOR);
- HopsOpOp2LopsB.put(OpOp2.BITWSHIFTL, Binary.OperationTypes.BW_SHIFTL);
- HopsOpOp2LopsB.put(OpOp2.BITWSHIFTR, Binary.OperationTypes.BW_SHIFTR);
- HopsOpOp2LopsB.put(OpOp2.DROP_INVALID, Binary.OperationTypes.DROP_INVALID);
- }
-
- protected static final HashMap<Hop.OpOp2, BinaryScalar.OperationTypes> HopsOpOp2LopsBS;
- static {
- HopsOpOp2LopsBS = new HashMap<>();
- HopsOpOp2LopsBS.put(OpOp2.PLUS, BinaryScalar.OperationTypes.ADD);
- HopsOpOp2LopsBS.put(OpOp2.MINUS, BinaryScalar.OperationTypes.SUBTRACT);
- HopsOpOp2LopsBS.put(OpOp2.MULT, BinaryScalar.OperationTypes.MULTIPLY);
- HopsOpOp2LopsBS.put(OpOp2.DIV, BinaryScalar.OperationTypes.DIVIDE);
- HopsOpOp2LopsBS.put(OpOp2.MODULUS, BinaryScalar.OperationTypes.MODULUS);
- HopsOpOp2LopsBS.put(OpOp2.INTDIV, BinaryScalar.OperationTypes.INTDIV);
- HopsOpOp2LopsBS.put(OpOp2.LESS, BinaryScalar.OperationTypes.LESS_THAN);
- HopsOpOp2LopsBS.put(OpOp2.LESSEQUAL, BinaryScalar.OperationTypes.LESS_THAN_OR_EQUALS);
- HopsOpOp2LopsBS.put(OpOp2.GREATER, BinaryScalar.OperationTypes.GREATER_THAN);
- HopsOpOp2LopsBS.put(OpOp2.GREATEREQUAL, BinaryScalar.OperationTypes.GREATER_THAN_OR_EQUALS);
- HopsOpOp2LopsBS.put(OpOp2.EQUAL, BinaryScalar.OperationTypes.EQUALS);
- HopsOpOp2LopsBS.put(OpOp2.NOTEQUAL, BinaryScalar.OperationTypes.NOT_EQUALS);
- HopsOpOp2LopsBS.put(OpOp2.MIN, BinaryScalar.OperationTypes.MIN);
- HopsOpOp2LopsBS.put(OpOp2.MAX, BinaryScalar.OperationTypes.MAX);
- HopsOpOp2LopsBS.put(OpOp2.AND, BinaryScalar.OperationTypes.AND);
- HopsOpOp2LopsBS.put(OpOp2.OR, BinaryScalar.OperationTypes.OR);
- HopsOpOp2LopsBS.put(OpOp2.XOR, BinaryScalar.OperationTypes.XOR);
- HopsOpOp2LopsBS.put(OpOp2.LOG, BinaryScalar.OperationTypes.LOG);
- HopsOpOp2LopsBS.put(OpOp2.POW, BinaryScalar.OperationTypes.POW);
- HopsOpOp2LopsBS.put(OpOp2.PRINT, BinaryScalar.OperationTypes.PRINT);
- HopsOpOp2LopsBS.put(OpOp2.BITWAND, BinaryScalar.OperationTypes.BW_AND);
- HopsOpOp2LopsBS.put(OpOp2.BITWOR, BinaryScalar.OperationTypes.BW_OR);
- HopsOpOp2LopsBS.put(OpOp2.BITWXOR, BinaryScalar.OperationTypes.BW_XOR);
- HopsOpOp2LopsBS.put(OpOp2.BITWSHIFTL, BinaryScalar.OperationTypes.BW_SHIFTL);
- HopsOpOp2LopsBS.put(OpOp2.BITWSHIFTR, BinaryScalar.OperationTypes.BW_SHIFTR);
- }
-
- protected static final HashMap<Hop.OpOp2, org.apache.sysds.lops.Unary.OperationTypes> HopsOpOp2LopsU;
- static {
- HopsOpOp2LopsU = new HashMap<>();
- HopsOpOp2LopsU.put(OpOp2.PLUS, org.apache.sysds.lops.Unary.OperationTypes.ADD);
- HopsOpOp2LopsU.put(OpOp2.MINUS, org.apache.sysds.lops.Unary.OperationTypes.SUBTRACT);
- HopsOpOp2LopsU.put(OpOp2.MULT, org.apache.sysds.lops.Unary.OperationTypes.MULTIPLY);
- HopsOpOp2LopsU.put(OpOp2.DIV, org.apache.sysds.lops.Unary.OperationTypes.DIVIDE);
- HopsOpOp2LopsU.put(OpOp2.MODULUS, org.apache.sysds.lops.Unary.OperationTypes.MODULUS);
- HopsOpOp2LopsU.put(OpOp2.INTDIV, org.apache.sysds.lops.Unary.OperationTypes.INTDIV);
- HopsOpOp2LopsU.put(OpOp2.MINUS1_MULT, org.apache.sysds.lops.Unary.OperationTypes.MINUS1_MULTIPLY);
- HopsOpOp2LopsU.put(OpOp2.LESSEQUAL, org.apache.sysds.lops.Unary.OperationTypes.LESS_THAN_OR_EQUALS);
- HopsOpOp2LopsU.put(OpOp2.LESS, org.apache.sysds.lops.Unary.OperationTypes.LESS_THAN);
- HopsOpOp2LopsU.put(OpOp2.GREATEREQUAL, org.apache.sysds.lops.Unary.OperationTypes.GREATER_THAN_OR_EQUALS);
- HopsOpOp2LopsU.put(OpOp2.GREATER, org.apache.sysds.lops.Unary.OperationTypes.GREATER_THAN);
- HopsOpOp2LopsU.put(OpOp2.EQUAL, org.apache.sysds.lops.Unary.OperationTypes.EQUALS);
- HopsOpOp2LopsU.put(OpOp2.NOTEQUAL, org.apache.sysds.lops.Unary.OperationTypes.NOT_EQUALS);
- HopsOpOp2LopsU.put(OpOp2.AND, org.apache.sysds.lops.Unary.OperationTypes.AND);
- HopsOpOp2LopsU.put(OpOp2.OR, org.apache.sysds.lops.Unary.OperationTypes.OR);
- HopsOpOp2LopsU.put(OpOp2.XOR, org.apache.sysds.lops.Unary.OperationTypes.XOR);
- HopsOpOp2LopsU.put(OpOp2.MAX, org.apache.sysds.lops.Unary.OperationTypes.MAX);
- HopsOpOp2LopsU.put(OpOp2.MIN, org.apache.sysds.lops.Unary.OperationTypes.MIN);
- HopsOpOp2LopsU.put(OpOp2.LOG, org.apache.sysds.lops.Unary.OperationTypes.LOG);
- HopsOpOp2LopsU.put(OpOp2.POW, org.apache.sysds.lops.Unary.OperationTypes.POW);
- HopsOpOp2LopsU.put(OpOp2.MINUS_NZ, org.apache.sysds.lops.Unary.OperationTypes.SUBTRACT_NZ);
- HopsOpOp2LopsU.put(OpOp2.LOG_NZ, org.apache.sysds.lops.Unary.OperationTypes.LOG_NZ);
- HopsOpOp2LopsU.put(OpOp2.BITWAND, Unary.OperationTypes.BW_AND);
- HopsOpOp2LopsU.put(OpOp2.BITWOR, Unary.OperationTypes.BW_OR);
- HopsOpOp2LopsU.put(OpOp2.BITWXOR, Unary.OperationTypes.BW_XOR);
- HopsOpOp2LopsU.put(OpOp2.BITWSHIFTL, Unary.OperationTypes.BW_SHIFTL);
- HopsOpOp2LopsU.put(OpOp2.BITWSHIFTR, Unary.OperationTypes.BW_SHIFTR);
- }
-
- protected static final HashMap<Hop.OpOp1, org.apache.sysds.lops.Unary.OperationTypes> HopsOpOp1LopsU;
- static {
- HopsOpOp1LopsU = new HashMap<>();
- HopsOpOp1LopsU.put(OpOp1.NOT, org.apache.sysds.lops.Unary.OperationTypes.NOT);
- HopsOpOp1LopsU.put(OpOp1.ABS, org.apache.sysds.lops.Unary.OperationTypes.ABS);
- HopsOpOp1LopsU.put(OpOp1.SIN, org.apache.sysds.lops.Unary.OperationTypes.SIN);
- HopsOpOp1LopsU.put(OpOp1.COS, org.apache.sysds.lops.Unary.OperationTypes.COS);
- HopsOpOp1LopsU.put(OpOp1.TAN, org.apache.sysds.lops.Unary.OperationTypes.TAN);
- HopsOpOp1LopsU.put(OpOp1.ASIN, org.apache.sysds.lops.Unary.OperationTypes.ASIN);
- HopsOpOp1LopsU.put(OpOp1.ACOS, org.apache.sysds.lops.Unary.OperationTypes.ACOS);
- HopsOpOp1LopsU.put(OpOp1.ATAN, org.apache.sysds.lops.Unary.OperationTypes.ATAN);
- HopsOpOp1LopsU.put(OpOp1.SINH, org.apache.sysds.lops.Unary.OperationTypes.SINH);
- HopsOpOp1LopsU.put(OpOp1.COSH, org.apache.sysds.lops.Unary.OperationTypes.COSH);
- HopsOpOp1LopsU.put(OpOp1.TANH, org.apache.sysds.lops.Unary.OperationTypes.TANH);
- HopsOpOp1LopsU.put(OpOp1.SIGN, org.apache.sysds.lops.Unary.OperationTypes.SIGN);
- HopsOpOp1LopsU.put(OpOp1.SQRT, org.apache.sysds.lops.Unary.OperationTypes.SQRT);
- HopsOpOp1LopsU.put(OpOp1.EXP, org.apache.sysds.lops.Unary.OperationTypes.EXP);
- HopsOpOp1LopsU.put(OpOp1.LOG, org.apache.sysds.lops.Unary.OperationTypes.LOG);
- HopsOpOp1LopsU.put(OpOp1.ROUND, org.apache.sysds.lops.Unary.OperationTypes.ROUND);
- HopsOpOp1LopsU.put(OpOp1.CEIL, org.apache.sysds.lops.Unary.OperationTypes.CEIL);
- HopsOpOp1LopsU.put(OpOp1.FLOOR, org.apache.sysds.lops.Unary.OperationTypes.FLOOR);
- HopsOpOp1LopsU.put(OpOp1.CUMSUM, org.apache.sysds.lops.Unary.OperationTypes.CUMSUM);
- HopsOpOp1LopsU.put(OpOp1.CUMPROD, org.apache.sysds.lops.Unary.OperationTypes.CUMPROD);
- HopsOpOp1LopsU.put(OpOp1.CUMMIN, org.apache.sysds.lops.Unary.OperationTypes.CUMMIN);
- HopsOpOp1LopsU.put(OpOp1.CUMMAX, org.apache.sysds.lops.Unary.OperationTypes.CUMMAX);
- HopsOpOp1LopsU.put(OpOp1.CUMSUMPROD, org.apache.sysds.lops.Unary.OperationTypes.CUMSUMPROD);
- HopsOpOp1LopsU.put(OpOp1.INVERSE, org.apache.sysds.lops.Unary.OperationTypes.INVERSE);
- HopsOpOp1LopsU.put(OpOp1.CHOLESKY, org.apache.sysds.lops.Unary.OperationTypes.CHOLESKY);
- HopsOpOp1LopsU.put(OpOp1.ISNA, org.apache.sysds.lops.Unary.OperationTypes.ISNA);
- HopsOpOp1LopsU.put(OpOp1.ISNAN, org.apache.sysds.lops.Unary.OperationTypes.ISNAN);
- HopsOpOp1LopsU.put(OpOp1.ISINF, org.apache.sysds.lops.Unary.OperationTypes.ISINF);
- HopsOpOp1LopsU.put(OpOp1.CAST_AS_SCALAR, org.apache.sysds.lops.Unary.OperationTypes.NOTSUPPORTED);
- HopsOpOp1LopsU.put(OpOp1.CAST_AS_MATRIX, org.apache.sysds.lops.Unary.OperationTypes.NOTSUPPORTED);
- HopsOpOp1LopsU.put(OpOp1.SPROP, org.apache.sysds.lops.Unary.OperationTypes.SPROP);
- HopsOpOp1LopsU.put(OpOp1.SIGMOID, Unary.OperationTypes.SIGMOID);
- HopsOpOp1LopsU.put(OpOp1.TYPEOF, Unary.OperationTypes.TYPEOF);
- HopsOpOp1LopsU.put(OpOp1.DETECTSCHEMA, Unary.OperationTypes.DETECTSCHEMA);
- HopsOpOp1LopsU.put(OpOp1.LOG_NZ, org.apache.sysds.lops.Unary.OperationTypes.LOG_NZ);
- HopsOpOp1LopsU.put(OpOp1.CAST_AS_MATRIX, org.apache.sysds.lops.Unary.OperationTypes.CAST_AS_MATRIX);
- HopsOpOp1LopsU.put(OpOp1.CAST_AS_FRAME, org.apache.sysds.lops.Unary.OperationTypes.CAST_AS_FRAME);
- }
-
- public static final HashMap<Hop.OpOp1, org.apache.sysds.lops.UnaryCP.OperationTypes> HopsOpOp1LopsUS;
- static {
- HopsOpOp1LopsUS = new HashMap<>();
- HopsOpOp1LopsUS.put(OpOp1.NOT, org.apache.sysds.lops.UnaryCP.OperationTypes.NOT);
- HopsOpOp1LopsUS.put(OpOp1.ABS, org.apache.sysds.lops.UnaryCP.OperationTypes.ABS);
- HopsOpOp1LopsUS.put(OpOp1.SIN, org.apache.sysds.lops.UnaryCP.OperationTypes.SIN);
- HopsOpOp1LopsUS.put(OpOp1.COS, org.apache.sysds.lops.UnaryCP.OperationTypes.COS);
- HopsOpOp1LopsUS.put(OpOp1.TAN, org.apache.sysds.lops.UnaryCP.OperationTypes.TAN);
- HopsOpOp1LopsUS.put(OpOp1.ASIN, org.apache.sysds.lops.UnaryCP.OperationTypes.ASIN);
- HopsOpOp1LopsUS.put(OpOp1.ACOS, org.apache.sysds.lops.UnaryCP.OperationTypes.ACOS);
- HopsOpOp1LopsUS.put(OpOp1.ATAN, org.apache.sysds.lops.UnaryCP.OperationTypes.ATAN);
- HopsOpOp1LopsUS.put(OpOp1.SINH, org.apache.sysds.lops.UnaryCP.OperationTypes.SINH);
- HopsOpOp1LopsUS.put(OpOp1.COSH, org.apache.sysds.lops.UnaryCP.OperationTypes.COSH);
- HopsOpOp1LopsUS.put(OpOp1.TANH, org.apache.sysds.lops.UnaryCP.OperationTypes.TANH);
- HopsOpOp1LopsUS.put(OpOp1.SQRT, org.apache.sysds.lops.UnaryCP.OperationTypes.SQRT);
- HopsOpOp1LopsUS.put(OpOp1.EXP, org.apache.sysds.lops.UnaryCP.OperationTypes.EXP);
- HopsOpOp1LopsUS.put(OpOp1.LOG, org.apache.sysds.lops.UnaryCP.OperationTypes.LOG);
- HopsOpOp1LopsUS.put(OpOp1.CAST_AS_SCALAR, org.apache.sysds.lops.UnaryCP.OperationTypes.CAST_AS_SCALAR);
- HopsOpOp1LopsUS.put(OpOp1.CAST_AS_MATRIX, org.apache.sysds.lops.UnaryCP.OperationTypes.CAST_AS_MATRIX);
- HopsOpOp1LopsUS.put(OpOp1.CAST_AS_FRAME, org.apache.sysds.lops.UnaryCP.OperationTypes.CAST_AS_FRAME);
- HopsOpOp1LopsUS.put(OpOp1.CAST_AS_DOUBLE, org.apache.sysds.lops.UnaryCP.OperationTypes.CAST_AS_DOUBLE);
- HopsOpOp1LopsUS.put(OpOp1.CAST_AS_INT, org.apache.sysds.lops.UnaryCP.OperationTypes.CAST_AS_INT);
- HopsOpOp1LopsUS.put(OpOp1.CAST_AS_BOOLEAN, org.apache.sysds.lops.UnaryCP.OperationTypes.CAST_AS_BOOLEAN);
- HopsOpOp1LopsUS.put(OpOp1.NROW, org.apache.sysds.lops.UnaryCP.OperationTypes.NROW);
- HopsOpOp1LopsUS.put(OpOp1.NCOL, org.apache.sysds.lops.UnaryCP.OperationTypes.NCOL);
- HopsOpOp1LopsUS.put(OpOp1.LENGTH, org.apache.sysds.lops.UnaryCP.OperationTypes.LENGTH);
- HopsOpOp1LopsUS.put(OpOp1.EXISTS, org.apache.sysds.lops.UnaryCP.OperationTypes.EXISTS);
- HopsOpOp1LopsUS.put(OpOp1.LINEAGE, org.apache.sysds.lops.UnaryCP.OperationTypes.LINEAGE);
- HopsOpOp1LopsUS.put(OpOp1.PRINT, org.apache.sysds.lops.UnaryCP.OperationTypes.PRINT);
- HopsOpOp1LopsUS.put(OpOp1.ASSERT, org.apache.sysds.lops.UnaryCP.OperationTypes.ASSERT);
- HopsOpOp1LopsUS.put(OpOp1.ROUND, org.apache.sysds.lops.UnaryCP.OperationTypes.ROUND);
- HopsOpOp1LopsUS.put(OpOp1.CEIL, org.apache.sysds.lops.UnaryCP.OperationTypes.CEIL);
- HopsOpOp1LopsUS.put(OpOp1.FLOOR, org.apache.sysds.lops.UnaryCP.OperationTypes.FLOOR);
- HopsOpOp1LopsUS.put(OpOp1.STOP, org.apache.sysds.lops.UnaryCP.OperationTypes.STOP);
- HopsOpOp1LopsUS.put(OpOp1.TYPEOF, UnaryCP.OperationTypes.TYPEOF);
- HopsOpOp1LopsUS.put(OpOp1.DETECTSCHEMA, UnaryCP.OperationTypes.DETECTSCHEMA);
- }
-
- protected static final HashMap<OpOp1, String> HopsOpOp12String;
- protected static final HashMap<String, OpOp1> HopsStringOpOp1;
-
- static {
- HopsOpOp12String = new HashMap<>();
- HopsOpOp12String.put(OpOp1.ABS, "abs");
- HopsOpOp12String.put(OpOp1.CAST_AS_SCALAR, "castAsScalar");
- HopsOpOp12String.put(OpOp1.COS, "cos");
- HopsOpOp12String.put(OpOp1.EIGEN, "eigen");
- HopsOpOp12String.put(OpOp1.SVD, "svd");
- HopsOpOp12String.put(OpOp1.EXP, "exp");
- HopsOpOp12String.put(OpOp1.IQM, "iqm");
- HopsOpOp12String.put(OpOp1.MEDIAN, "median");
- HopsOpOp12String.put(OpOp1.LENGTH, "length");
- HopsOpOp12String.put(OpOp1.LOG, "log");
- HopsOpOp12String.put(OpOp1.NCOL, "ncol");
- HopsOpOp12String.put(OpOp1.NOT, "!");
- HopsOpOp12String.put(OpOp1.NROW, "nrow");
- HopsOpOp12String.put(OpOp1.PRINT, "print");
- HopsOpOp12String.put(OpOp1.ASSERT, "assert");
- HopsOpOp12String.put(OpOp1.ROUND, "round");
- HopsOpOp12String.put(OpOp1.SIN, "sin");
- HopsOpOp12String.put(OpOp1.SQRT, "sqrt");
- HopsOpOp12String.put(OpOp1.TAN, "tan");
- HopsOpOp12String.put(OpOp1.ASIN, "asin");
- HopsOpOp12String.put(OpOp1.ACOS, "acos");
- HopsOpOp12String.put(OpOp1.ATAN, "atan");
- HopsOpOp12String.put(OpOp1.SINH, "sinh");
- HopsOpOp12String.put(OpOp1.COSH, "cosh");
- HopsOpOp12String.put(OpOp1.TANH, "tanh");
- HopsOpOp12String.put(OpOp1.STOP, "stop");
- HopsOpOp12String.put(OpOp1.INVERSE, "inv");
- HopsOpOp12String.put(OpOp1.SPROP, "sprop");
- HopsOpOp12String.put(OpOp1.SIGMOID, "sigmoid");
- HopsOpOp12String.put(OpOp1.TYPEOF, "typeOf");
- HopsOpOp12String.put(OpOp1.DETECTSCHEMA, "detectSchema");
-
- HopsStringOpOp1 = new HashMap<>();
- for( Entry<OpOp1,String> e : HopsOpOp12String.entrySet() )
- HopsStringOpOp1.put(e.getValue(), e.getKey());
- }
-
- public static OpOp1 getUnaryOpCode(String op) {
- return HopsStringOpOp1.get(op);
- }
-
- protected static final HashMap<OpOp2, String> HopsOpOp2String;
- protected static final HashMap<String,OpOp2> HopsStringOpOp2;
- static {
- HopsOpOp2String = new HashMap<>();
- HopsOpOp2String.put(OpOp2.PLUS, "+");
- HopsOpOp2String.put(OpOp2.MINUS, "-");
- HopsOpOp2String.put(OpOp2.MINUS_NZ, "-nz");
- HopsOpOp2String.put(OpOp2.MINUS1_MULT, "-1*");
- HopsOpOp2String.put(OpOp2.MULT, "*");
- HopsOpOp2String.put(OpOp2.DIV, "/");
- HopsOpOp2String.put(OpOp2.MODULUS, "%%");
- HopsOpOp2String.put(OpOp2.INTDIV, "%/%");
- HopsOpOp2String.put(OpOp2.MIN, "min");
- HopsOpOp2String.put(OpOp2.MAX, "max");
- HopsOpOp2String.put(OpOp2.LESSEQUAL, "<=");
- HopsOpOp2String.put(OpOp2.LESS, "<");
- HopsOpOp2String.put(OpOp2.GREATEREQUAL, ">=");
- HopsOpOp2String.put(OpOp2.GREATER, ">");
- HopsOpOp2String.put(OpOp2.EQUAL, "==");
- HopsOpOp2String.put(OpOp2.NOTEQUAL, "!=");
- HopsOpOp2String.put(OpOp2.OR, "|");
- HopsOpOp2String.put(OpOp2.AND, "&");
- HopsOpOp2String.put(OpOp2.LOG, "log");
- HopsOpOp2String.put(OpOp2.LOG_NZ, "log_nz");
- HopsOpOp2String.put(OpOp2.POW, "^");
- HopsOpOp2String.put(OpOp2.CONCAT, "concat");
- HopsOpOp2String.put(OpOp2.INVALID, "?");
- HopsOpOp2String.put(OpOp2.QUANTILE, "quantile");
- HopsOpOp2String.put(OpOp2.INTERQUANTILE, "interquantile");
- HopsOpOp2String.put(OpOp2.IQM, "IQM");
- HopsOpOp2String.put(OpOp2.MEDIAN, "median");
- HopsOpOp2String.put(OpOp2.MOMENT, "cm");
- HopsOpOp2String.put(OpOp2.COV, "cov");
- HopsOpOp2String.put(OpOp2.CBIND, "cbind");
- HopsOpOp2String.put(OpOp2.RBIND, "rbind");
- HopsOpOp2String.put(OpOp2.SOLVE, "solve");
- HopsOpOp2String.put(OpOp2.XOR, "xor");
- HopsOpOp2String.put(OpOp2.BITWAND, "bitwAnd");
- HopsOpOp2String.put(OpOp2.BITWOR, "bitwOr");
- HopsOpOp2String.put(OpOp2.BITWXOR, "bitwXor");
- HopsOpOp2String.put(OpOp2.BITWSHIFTL, "bitwShiftL");
- HopsOpOp2String.put(OpOp2.BITWSHIFTR, "bitwShiftR");
- HopsOpOp2String.put(OpOp2.DROP_INVALID, "dropInvalid");
-
- HopsStringOpOp2 = new HashMap<>();
- for( Entry<OpOp2,String> e : HopsOpOp2String.entrySet() )
- HopsStringOpOp2.put(e.getValue(), e.getKey());
- }
-
- public static String getBinaryOpCode(OpOp2 op) {
- return HopsOpOp2String.get(op);
- }
-
- public static OpOp2 getBinaryOpCode(String op) {
- return HopsStringOpOp2.get(op);
- }
-
- public static OpOp2 getOpOp2ForOuterVectorOperation(String op)
- {
- if( "+".equals(op) ) return OpOp2.PLUS;
- else if( "-".equals(op) ) return OpOp2.MINUS;
- else if( "*".equals(op) ) return OpOp2.MULT;
- else if( "/".equals(op) ) return OpOp2.DIV;
- else if( "%%".equals(op) ) return OpOp2.MODULUS;
- else if( "%/%".equals(op) ) return OpOp2.INTDIV;
- else if( "min".equals(op) ) return OpOp2.MIN;
- else if( "max".equals(op) ) return OpOp2.MAX;
- else if( "<=".equals(op) ) return OpOp2.LESSEQUAL;
- else if( "<".equals(op) ) return OpOp2.LESS;
- else if( ">=".equals(op) ) return OpOp2.GREATEREQUAL;
- else if( ">".equals(op) ) return OpOp2.GREATER;
- else if( "==".equals(op) ) return OpOp2.EQUAL;
- else if( "!=".equals(op) ) return OpOp2.NOTEQUAL;
- else if( "|".equals(op) ) return OpOp2.OR;
- else if( "xor".equals(op) ) return OpOp2.XOR;
- else if( "&".equals(op) ) return OpOp2.AND;
- else if( "log".equals(op) ) return OpOp2.LOG;
- else if( "^".equals(op) ) return OpOp2.POW;
- else if("bitwAnd".equals(op) ) return OpOp2.BITWAND;
- else if("bitwOr".equals(op) ) return OpOp2.BITWOR;
- else if("bitwXor".equals(op) ) return OpOp2.BITWXOR;
- else if("bitwShiftL".equals(op) ) return OpOp2.BITWSHIFTL;
- else if("bitwShiftR".equals(op) ) return OpOp2.BITWSHIFTR;
-
- return null;
- }
-
/////////////////////////////////////
// methods for dynamic re-compilation
/////////////////////////////////////
@@ -1545,12 +1199,12 @@ public abstract class Hop implements ParseInfo
if( input instanceof UnaryOp )
{
- if( ((UnaryOp)input).getOp() == Hop.OpOp1.NROW ) {
+ if( ((UnaryOp)input).getOp() == OpOp1.NROW ) {
DataCharacteristics mc = memo.getAllInputStats(input.getInput().get(0));
if( mc.rowsKnown() )
ret = mc.getRows();
}
- else if ( ((UnaryOp)input).getOp() == Hop.OpOp1.NCOL ) {
+ else if ( ((UnaryOp)input).getOp() == OpOp1.NCOL ) {
DataCharacteristics mc = memo.getAllInputStats(input.getInput().get(0));
if( mc.colsKnown() )
ret = mc.getCols();
@@ -1588,12 +1242,12 @@ public abstract class Hop implements ParseInfo
{
UnaryOp uroot = (UnaryOp) root;
long dim = -1;
- if(uroot.getOp() == Hop.OpOp1.NROW)
+ if(uroot.getOp() == OpOp1.NROW)
{
DataCharacteristics mc = memo.getAllInputStats(uroot.getInput().get(0));
dim = mc.getRows();
}
- else if( uroot.getOp() == Hop.OpOp1.NCOL )
+ else if( uroot.getOp() == OpOp1.NCOL )
{
DataCharacteristics mc = memo.getAllInputStats(uroot.getInput().get(0));
dim = mc.getCols();
diff --git a/src/main/java/org/apache/sysds/hops/IndexingOp.java b/src/main/java/org/apache/sysds/hops/IndexingOp.java
index 17cbd29..55870fd 100644
--- a/src/main/java/org/apache/sysds/hops/IndexingOp.java
+++ b/src/main/java/org/apache/sysds/hops/IndexingOp.java
@@ -21,6 +21,8 @@ package org.apache.sysds.hops;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.hops.AggBinaryOp.SparkAggType;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
@@ -463,10 +465,10 @@ public class IndexingOp extends Hop
return false;
}
- return ( getInput().get(0) == that.getInput().get(0)
- && getInput().get(1) == that.getInput().get(1)
- && getInput().get(2) == that.getInput().get(2)
- && getInput().get(3) == that.getInput().get(3)
- && getInput().get(4) == that.getInput().get(4));
+ return getInput().get(0) == that.getInput().get(0)
+ && getInput().get(1) == that.getInput().get(1)
+ && getInput().get(2) == that.getInput().get(2)
+ && getInput().get(3) == that.getInput().get(3)
+ && getInput().get(4) == that.getInput().get(4);
}
}
diff --git a/src/main/java/org/apache/sysds/hops/LeftIndexingOp.java b/src/main/java/org/apache/sysds/hops/LeftIndexingOp.java
index 149fc68..7aa26bd 100644
--- a/src/main/java/org/apache/sysds/hops/LeftIndexingOp.java
+++ b/src/main/java/org/apache/sysds/hops/LeftIndexingOp.java
@@ -20,6 +20,7 @@
package org.apache.sysds.hops;
import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.lops.LeftIndex;
@@ -27,7 +28,6 @@ import org.apache.sysds.lops.LeftIndex.LixCacheType;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.lops.UnaryCP;
-import org.apache.sysds.lops.UnaryCP.OperationTypes;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
@@ -116,24 +116,24 @@ public class LeftIndexingOp extends Hop
Hop right = getInput().get(1);
LeftIndexingMethod method = getOptMethodLeftIndexingMethod(
- left.getDim1(), left.getDim2(), left.getBlocksize(), left.getNnz(),
- right.getDim1(), right.getDim2(), right.getNnz(), right.getDataType() );
+ left.getDim1(), left.getDim2(), left.getBlocksize(), left.getNnz(),
+ right.getDim1(), right.getDim2(), right.getNnz(), right.getDataType() );
//insert cast to matrix if necessary (for reuse broadcast runtime)
Lop rightInput = right.constructLops();
if (isRightHandSideScalar()) {
rightInput = new UnaryCP(rightInput,
- (left.getDataType()==DataType.MATRIX?OperationTypes.CAST_AS_MATRIX:OperationTypes.CAST_AS_FRAME),
+ (left.getDataType()==DataType.MATRIX?OpOp1.CAST_AS_MATRIX:OpOp1.CAST_AS_FRAME),
left.getDataType(), right.getValueType());
long bsize = ConfigurationManager.getBlocksize();
rightInput.getOutputParameters().setDimensions( 1, 1, bsize, -1);
}
LeftIndex leftIndexLop = new LeftIndex(
- left.constructLops(), rightInput,
- getInput().get(2).constructLops(), getInput().get(3).constructLops(),
- getInput().get(4).constructLops(), getInput().get(5).constructLops(),
- getDataType(), getValueType(), et, getSpLixCacheType(method));
+ left.constructLops(), rightInput,
+ getInput().get(2).constructLops(), getInput().get(3).constructLops(),
+ getInput().get(4).constructLops(), getInput().get(5).constructLops(),
+ getDataType(), getValueType(), et, getSpLixCacheType(method));
setOutputDimensions(leftIndexLop);
setLineNumbers(leftIndexLop);
@@ -142,9 +142,9 @@ public class LeftIndexingOp extends Hop
else
{
LeftIndex left = new LeftIndex(
- getInput().get(0).constructLops(), getInput().get(1).constructLops(), getInput().get(2).constructLops(),
- getInput().get(3).constructLops(), getInput().get(4).constructLops(), getInput().get(5).constructLops(),
- getDataType(), getValueType(), et);
+ getInput().get(0).constructLops(), getInput().get(1).constructLops(), getInput().get(2).constructLops(),
+ getInput().get(3).constructLops(), getInput().get(4).constructLops(), getInput().get(5).constructLops(),
+ getDataType(), getValueType(), et);
setOutputDimensions(left);
setLineNumbers(left);
@@ -186,8 +186,7 @@ public class LeftIndexingOp extends Hop
}
@Override
- public boolean allowsAllExecTypes()
- {
+ public boolean allowsAllExecTypes() {
return false;
}
@@ -426,20 +425,19 @@ public class LeftIndexingOp extends Hop
}
@Override
- public boolean compare( Hop that )
- {
+ public boolean compare( Hop that ) {
if( !(that instanceof LeftIndexingOp)
|| getInput().size() != that.getInput().size() )
{
return false;
}
- return ( getInput().get(0) == that.getInput().get(0)
- && getInput().get(1) == that.getInput().get(1)
- && getInput().get(2) == that.getInput().get(2)
- && getInput().get(3) == that.getInput().get(3)
- && getInput().get(4) == that.getInput().get(4)
- && getInput().get(5) == that.getInput().get(5));
+ return getInput().get(0) == that.getInput().get(0)
+ && getInput().get(1) == that.getInput().get(1)
+ && getInput().get(2) == that.getInput().get(2)
+ && getInput().get(3) == that.getInput().get(3)
+ && getInput().get(4) == that.getInput().get(4)
+ && getInput().get(5) == that.getInput().get(5);
}
}
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index 2ffacef..3b0f694 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -24,6 +24,8 @@ import org.apache.log4j.Logger;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.common.Types.ReOrgOp;
import org.apache.sysds.common.Types.ValueType;
@@ -31,7 +33,6 @@ import org.apache.sysds.conf.CompilerConfig;
import org.apache.sysds.conf.CompilerConfig.ConfigType;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
-import org.apache.sysds.hops.Hop.OpOp2;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.Checkpoint;
import org.apache.sysds.lops.Lop;
@@ -1114,7 +1115,7 @@ public class OptimizerUtils
}
public static long getOuterNonZeros(long n1, long n2, long nnz1, long nnz2, OpOp2 op) {
- if( nnz1 < 0 || nnz2 < 0 )
+ if( nnz1 < 0 || nnz2 < 0 || op == null )
return n1 * n2;
switch(op) {
case PLUS:
@@ -1299,9 +1300,9 @@ public class OptimizerUtils
UnaryOp uroot = (UnaryOp) root;
Hop input = uroot.getInput().get(0);
- if(uroot.getOp() == Hop.OpOp1.NROW)
+ if(uroot.getOp() == OpOp1.NROW)
ret = input.rowsKnown() ? input.getDim1() : Double.MAX_VALUE;
- else if( uroot.getOp() == Hop.OpOp1.NCOL )
+ else if( uroot.getOp() == OpOp1.NCOL )
ret = input.colsKnown() ? input.getDim2() : Double.MAX_VALUE;
else
{
@@ -1337,9 +1338,9 @@ public class OptimizerUtils
UnaryOp uroot = (UnaryOp) root;
Hop input = uroot.getInput().get(0);
- if(uroot.getOp() == Hop.OpOp1.NROW)
+ if(uroot.getOp() == OpOp1.NROW)
ret = input.rowsKnown() ? input.getDim1() : Double.MAX_VALUE;
- else if( uroot.getOp() == Hop.OpOp1.NCOL )
+ else if( uroot.getOp() == OpOp1.NCOL )
ret = input.colsKnown() ? input.getDim2() : Double.MAX_VALUE;
else
{
diff --git a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
index 07624c8..2f2352d 100644
--- a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
@@ -22,6 +22,8 @@ package org.apache.sysds.hops;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.Direction;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.ParamBuiltinOp;
import org.apache.sysds.common.Types.ReOrgOp;
import org.apache.sysds.common.Types.ValueType;
diff --git a/src/main/java/org/apache/sysds/hops/QuaternaryOp.java b/src/main/java/org/apache/sysds/hops/QuaternaryOp.java
index 911dc93..64df97e 100644
--- a/src/main/java/org/apache/sysds/hops/QuaternaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/QuaternaryOp.java
@@ -20,12 +20,13 @@
package org.apache.sysds.hops;
import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOp4;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.lops.LopsException;
-import org.apache.sysds.lops.Unary;
import org.apache.sysds.lops.WeightedCrossEntropy;
import org.apache.sysds.lops.WeightedCrossEntropy.WCeMMType;
import org.apache.sysds.lops.WeightedCrossEntropyR;
@@ -535,9 +536,8 @@ public class QuaternaryOp extends MultiThreadedHop
}
private void constructCPLopsWeightedUMM(WUMMType wtype) {
- Unary.OperationTypes uop = _uop!=null ?
- HopsOpOp1LopsU.get(_uop) : _sop==OpOp2.POW ?
- Unary.OperationTypes.POW2 : Unary.OperationTypes.MULTIPLY2;
+ OpOp1 uop = _uop!=null ? _uop : _sop==OpOp2.POW ?
+ OpOp1.POW2 : OpOp1.MULT2;
WeightedUnaryMM wumm = new WeightedUnaryMM(
getInput().get(0).constructLops(),
@@ -560,9 +560,8 @@ public class QuaternaryOp extends MultiThreadedHop
//supports single block outer products (U/V rank <= blocksize, i.e., 1000 by default); we enforce this
//by applying the hop rewrite for Weighted UnaryMM only if this constraint holds.
- Unary.OperationTypes uop = _uop!=null ?
- HopsOpOp1LopsU.get(_uop) : _sop==OpOp2.POW ?
- Unary.OperationTypes.POW2 : Unary.OperationTypes.MULTIPLY2;
+ OpOp1 uop = _uop!=null ? _uop : _sop==OpOp2.POW ?
+ OpOp1.POW2 : OpOp1.MULT2;
//Notes: Any broadcast needs to fit twice in local memory because we partition the input in cp,
//and needs to fit once in executor broadcast memory. The 2GB broadcast constraint is no longer
diff --git a/src/main/java/org/apache/sysds/hops/TernaryOp.java b/src/main/java/org/apache/sysds/hops/TernaryOp.java
index b218960..6700fe6 100644
--- a/src/main/java/org/apache/sysds/hops/TernaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/TernaryOp.java
@@ -21,6 +21,7 @@ package org.apache.sysds.hops;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOp3;
import org.apache.sysds.common.Types.OpOpDG;
import org.apache.sysds.common.Types.ParamBuiltinOp;
diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java b/src/main/java/org/apache/sysds/hops/UnaryOp.java
index b0e474a..f9b46a8 100644
--- a/src/main/java/org/apache/sysds/hops/UnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java
@@ -22,6 +22,7 @@ package org.apache.sysds.hops;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.Checkpoint;
@@ -134,25 +135,20 @@ public class UnaryOp extends MultiThreadedHop
|| (_op == OpOp1.CAST_AS_MATRIX && getInput().get(0).getDataType()==DataType.SCALAR)
|| (_op == OpOp1.CAST_AS_FRAME && getInput().get(0).getDataType()==DataType.SCALAR))
{
- if (_op == Hop.OpOp1.IQM) //special handling IQM
+ if (_op == OpOp1.IQM) //special handling IQM
{
Lop iqmLop = constructLopsIQM();
setLops(iqmLop);
}
- else if(_op == Hop.OpOp1.MEDIAN) {
+ else if(_op == OpOp1.MEDIAN) {
Lop medianLop = constructLopsMedian();
setLops(medianLop);
}
else //general case SCALAR/CAST (always in CP)
{
- UnaryCP.OperationTypes optype = HopsOpOp1LopsUS.get(_op);
- if( optype == null )
- throw new HopsException("Unknown UnaryCP lop type for UnaryOp operation type '"+_op+"'");
-
- UnaryCP unary1 = new UnaryCP(input.constructLops(), optype, getDataType(), getValueType());
+ UnaryCP unary1 = new UnaryCP(input.constructLops(), _op, getDataType(), getValueType());
setOutputDimensions(unary1);
setLineNumbers(unary1);
-
setLops(unary1);
}
}
@@ -172,7 +168,7 @@ public class UnaryOp extends MultiThreadedHop
int k = isCumulativeUnaryOperation() || isExpensiveUnaryOperation() ?
OptimizerUtils.getConstrainedNumThreads( _maxNumThreads ) : 1;
Unary unary1 = new Unary(input.constructLops(),
- HopsOpOp1LopsU.get(_op), getDataType(), getValueType(), et, k, false);
+ _op, getDataType(), getValueType(), et, k, false);
setOutputDimensions(unary1);
setLineNumbers(unary1);
setLops(unary1);
@@ -296,7 +292,7 @@ public class UnaryOp extends MultiThreadedHop
//in-memory cum sum (of partial aggregates)
if( TEMP.getOutputParameters().getNumRows()!=1 ){
int k = OptimizerUtils.getConstrainedNumThreads( _maxNumThreads );
- Unary unary1 = new Unary( TEMP, HopsOpOp1LopsU.get(_op), DataType.MATRIX, ValueType.FP64, ExecType.CP, k, true);
+ Unary unary1 = new Unary( TEMP, _op, DataType.MATRIX, ValueType.FP64, ExecType.CP, k, true);
unary1.getOutputParameters().setDimensions(TEMP.getOutputParameters().getNumRows(), clen, blen, -1);
setLineNumbers(unary1);
TEMP = unary1;
diff --git a/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
index adbb127..b3b0e32 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
@@ -34,12 +34,12 @@ import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
-import org.apache.sysds.hops.Hop.OpOp1;
import org.apache.sysds.hops.codegen.cplan.CNode;
import org.apache.sysds.hops.codegen.cplan.CNodeCell;
import org.apache.sysds.hops.codegen.cplan.CNodeData;
diff --git a/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
index 05c8d0e..2f8437e 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
@@ -41,6 +41,7 @@ import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.Direction;
import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOpDG;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.common.Types.OpOpN;
@@ -57,7 +58,6 @@ import org.apache.sysds.hops.ParameterizedBuiltinOp;
import org.apache.sysds.hops.ReorgOp;
import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.UnaryOp;
-import org.apache.sysds.hops.Hop.OpOp2;
import org.apache.sysds.hops.codegen.opt.ReachabilityGraph.SubProblem;
import org.apache.sysds.hops.codegen.template.CPlanMemoTable;
import org.apache.sysds.hops.codegen.template.TemplateOuterProduct;
diff --git a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateCell.java
index 4230a6b..a7f92c8 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateCell.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateCell.java
@@ -41,7 +41,6 @@ import org.apache.sysds.hops.ParameterizedBuiltinOp;
import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.common.Types.AggOp;
-import org.apache.sysds.hops.Hop.OpOp2;
import org.apache.sysds.hops.codegen.cplan.CNode;
import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
import org.apache.sysds.hops.codegen.cplan.CNodeCell;
@@ -56,6 +55,7 @@ import org.apache.sysds.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOp3;
import org.apache.sysds.common.Types.OpOpDG;
import org.apache.sysds.common.Types.OpOpDnn;
diff --git a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateMultiAgg.java b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateMultiAgg.java
index b1718dd..31c52ce 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateMultiAgg.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateMultiAgg.java
@@ -26,8 +26,8 @@ import java.util.Set;
import java.util.stream.Collectors;
import org.apache.sysds.common.Types.AggOp;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.hops.Hop;
-import org.apache.sysds.hops.Hop.OpOp2;
import org.apache.sysds.hops.codegen.cplan.CNode;
import org.apache.sysds.hops.codegen.cplan.CNodeData;
import org.apache.sysds.hops.codegen.cplan.CNodeMultiAgg;
diff --git a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateOuterProduct.java b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateOuterProduct.java
index f796215..0b948e7 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateOuterProduct.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateOuterProduct.java
@@ -26,12 +26,12 @@ import java.util.LinkedList;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.Direction;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.UnaryOp;
-import org.apache.sysds.hops.Hop.OpOp2;
import org.apache.sysds.hops.codegen.cplan.CNode;
import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
import org.apache.sysds.hops.codegen.cplan.CNodeData;
@@ -189,14 +189,14 @@ public class TemplateOuterProduct extends TemplateBase {
if(hop instanceof UnaryOp)
{
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
- String primitiveOpName = ((UnaryOp)hop).getOp().toString();
+ String primitiveOpName = ((UnaryOp)hop).getOp().name();
out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
}
else if(hop instanceof BinaryOp)
{
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
- String primitiveOpName = ((BinaryOp)hop).getOp().toString();
+ String primitiveOpName = ((BinaryOp)hop).getOp().name();
if( HopRewriteUtils.isBinarySparseSafe(hop) ) {
if( TemplateUtils.isMatrix(hop.getInput().get(0)) && cdata1 instanceof CNodeData )
diff --git a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java
index 5264f0e..866af1a 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java
@@ -37,8 +37,6 @@ import org.apache.sysds.hops.NaryOp;
import org.apache.sysds.hops.ParameterizedBuiltinOp;
import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.UnaryOp;
-import org.apache.sysds.hops.Hop.OpOp1;
-import org.apache.sysds.hops.Hop.OpOp2;
import org.apache.sysds.hops.codegen.cplan.CNode;
import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
import org.apache.sysds.hops.codegen.cplan.CNodeData;
@@ -57,6 +55,8 @@ import org.apache.sysds.parser.Statement;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.Direction;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOp3;
import org.apache.sysds.common.Types.OpOpDG;
import org.apache.sysds.common.Types.OpOpDnn;
@@ -68,14 +68,14 @@ import org.apache.sysds.runtime.matrix.data.Pair;
public class TemplateRow extends TemplateBase
{
private static final AggOp[] SUPPORTED_ROW_AGG = new AggOp[]{AggOp.SUM, AggOp.MIN, AggOp.MAX, AggOp.MEAN};
- private static final Hop.OpOp1[] SUPPORTED_VECT_UNARY = new OpOp1[]{
- OpOp1.EXP, OpOp1.SQRT, OpOp1.LOG, OpOp1.ABS, OpOp1.ROUND, OpOp1.CEIL, OpOp1.FLOOR, OpOp1.SIGN,
- OpOp1.SIN, OpOp1.COS, OpOp1.TAN, OpOp1.ASIN, OpOp1.ACOS, OpOp1.ATAN, OpOp1.SINH, OpOp1.COSH, OpOp1.TANH,
- OpOp1.CUMSUM, OpOp1.CUMMIN, OpOp1.CUMMAX, OpOp1.SPROP, OpOp1.SIGMOID};
- private static final Hop.OpOp2[] SUPPORTED_VECT_BINARY = new OpOp2[]{
- OpOp2.MULT, OpOp2.DIV, OpOp2.MINUS, OpOp2.PLUS, OpOp2.POW, OpOp2.MIN, OpOp2.MAX, OpOp2.XOR,
- OpOp2.EQUAL, OpOp2.NOTEQUAL, OpOp2.LESS, OpOp2.LESSEQUAL, OpOp2.GREATER, OpOp2.GREATEREQUAL,
- OpOp2.BITWAND,
+ private static final OpOp1[] SUPPORTED_VECT_UNARY = new OpOp1[]{
+ OpOp1.EXP, OpOp1.SQRT, OpOp1.LOG, OpOp1.ABS, OpOp1.ROUND, OpOp1.CEIL, OpOp1.FLOOR, OpOp1.SIGN,
+ OpOp1.SIN, OpOp1.COS, OpOp1.TAN, OpOp1.ASIN, OpOp1.ACOS, OpOp1.ATAN, OpOp1.SINH, OpOp1.COSH, OpOp1.TANH,
+ OpOp1.CUMSUM, OpOp1.CUMMIN, OpOp1.CUMMAX, OpOp1.SPROP, OpOp1.SIGMOID};
+ private static final OpOp2[] SUPPORTED_VECT_BINARY = new OpOp2[]{
+ OpOp2.MULT, OpOp2.DIV, OpOp2.MINUS, OpOp2.PLUS, OpOp2.POW, OpOp2.MIN, OpOp2.MAX, OpOp2.XOR,
+ OpOp2.EQUAL, OpOp2.NOTEQUAL, OpOp2.LESS, OpOp2.LESSEQUAL, OpOp2.GREATER, OpOp2.GREATEREQUAL,
+ OpOp2.BITWAND,
};
public TemplateRow() {
@@ -400,7 +400,7 @@ public class TemplateRow extends TemplateBase
else //general scalar case
{
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
- String primitiveOpName = ((UnaryOp)hop).getOp().toString();
+ String primitiveOpName = ((UnaryOp)hop).getOp().name();
out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
}
}
@@ -450,7 +450,7 @@ public class TemplateRow extends TemplateBase
}
else //one input is a vector/scalar other is a scalar
{
- String primitiveOpName = ((BinaryOp)hop).getOp().toString();
+ String primitiveOpName = ((BinaryOp)hop).getOp().name();
if( TemplateUtils.isColVector(cdata1) )
cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
if( TemplateUtils.isColVector(cdata2) //vector or vector can be inferred from lhs
diff --git a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java
index b92920b..e3920d8 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java
@@ -19,7 +19,12 @@
package org.apache.sysds.hops.codegen.template;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.Map;
+import java.util.Set;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang3.mutable.MutableInt;
@@ -35,9 +40,10 @@ import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.Direction;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOpDnn;
import org.apache.sysds.common.Types.OpOpN;
-import org.apache.sysds.hops.Hop.OpOp1;
import org.apache.sysds.hops.codegen.SpoofCompiler;
import org.apache.sysds.hops.codegen.cplan.CNode;
import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
@@ -326,16 +332,16 @@ public class TemplateUtils
public static LinkedList<Long> findRemovableConditionalPatternInOuterProduct(Hop hop) {
LinkedList<Long> removableHopIDs = new LinkedList<>();
- if(((BinaryOp) hop).getOp() == Hop.OpOp2.MULT) {
+ if(((BinaryOp) hop).getOp() == OpOp2.MULT) {
if (hop.getInput().get(0) instanceof BinaryOp &&
- ((BinaryOp) hop.getInput().get(0)).getOp() == Hop.OpOp2.NOTEQUAL) {
+ ((BinaryOp) hop.getInput().get(0)).getOp() == OpOp2.NOTEQUAL) {
removableHopIDs.add(hop.getHopID());
removableHopIDs.add(hop.getInput().get(0).getHopID());
removableHopIDs.add(hop.getInput().get(0).getInput().get(0).getHopID());
removableHopIDs.add(hop.getInput().get(0).getInput().get(1).getHopID());
}
else if (hop.getInput().get(1) instanceof BinaryOp &&
- ((BinaryOp) hop.getInput().get(1)).getOp() == Hop.OpOp2.NOTEQUAL) {
+ ((BinaryOp) hop.getInput().get(1)).getOp() == OpOp2.NOTEQUAL) {
removableHopIDs.add(hop.getHopID());
removableHopIDs.add(hop.getInput().get(1).getHopID());
removableHopIDs.add(hop.getInput().get(1).getInput().get(0).getHopID());
diff --git a/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java b/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java
index 66576ab..838890a 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java
@@ -29,12 +29,12 @@ import java.util.Set;
import java.util.Stack;
import java.util.stream.Collectors;
+import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.common.Types.OpOpN;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
-import org.apache.sysds.hops.Hop.OpOp1;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
diff --git a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveConstantBinaryOps.java b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveConstantBinaryOps.java
index fecbdee..606c677 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveConstantBinaryOps.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveConstantBinaryOps.java
@@ -27,7 +27,6 @@ import org.apache.sysds.hops.DataGenOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
-import org.apache.sysds.hops.Hop.OpOp2;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
@@ -38,6 +37,7 @@ import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOpDG;
import org.apache.sysds.common.Types.OpOpData;
diff --git a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java
index 7d51744..351d099 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java
@@ -25,13 +25,13 @@ import java.util.HashSet;
import java.util.List;
import java.util.Set;
+import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.UnaryOp;
-import org.apache.sysds.hops.Hop.OpOp1;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatementBlock;
diff --git a/src/main/java/org/apache/sysds/hops/recompile/LiteralReplacement.java b/src/main/java/org/apache/sysds/hops/recompile/LiteralReplacement.java
index f5bda16..8b3798d 100644
--- a/src/main/java/org/apache/sysds/hops/recompile/LiteralReplacement.java
+++ b/src/main/java/org/apache/sysds/hops/recompile/LiteralReplacement.java
@@ -30,12 +30,12 @@ import org.apache.sysds.hops.IndexingOp;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.NaryOp;
import org.apache.sysds.hops.UnaryOp;
-import org.apache.sysds.hops.Hop.OpOp1;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.compile.Dag;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.Direction;
+import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.common.Types.OpOpN;
import org.apache.sysds.runtime.DMLRuntimeException;
diff --git a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
index d8d72a9..2b11c73 100644
--- a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
+++ b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
@@ -26,6 +26,7 @@ import org.apache.sysds.api.DMLScript;
import org.apache.sysds.api.jmlc.JMLCUtils;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.OpOpDG;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.common.Types.ReOrgOp;
@@ -37,7 +38,6 @@ import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.FunctionOp.FunctionType;
import org.apache.sysds.hops.Hop;
-import org.apache.sysds.hops.Hop.OpOp1;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.IndexingOp;
import org.apache.sysds.hops.LiteralOp;
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
index d852e44..91118bf 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
@@ -24,6 +24,8 @@ import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOp3;
import org.apache.sysds.common.Types.OpOpDG;
import org.apache.sysds.common.Types.OpOpData;
@@ -42,8 +44,6 @@ import org.apache.sysds.hops.DnnOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.Direction;
-import org.apache.sysds.hops.Hop.OpOp1;
-import org.apache.sysds.hops.Hop.OpOp2;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.IndexingOp;
import org.apache.sysds.hops.LeftIndexingOp;
@@ -84,11 +84,10 @@ import java.util.List;
public class HopRewriteUtils
{
- public static boolean isValueTypeCast( OpOp1 op )
- {
- return ( op == OpOp1.CAST_AS_BOOLEAN
- || op == OpOp1.CAST_AS_INT
- || op == OpOp1.CAST_AS_DOUBLE );
+ public static boolean isValueTypeCast( OpOp1 op ) {
+ return op == OpOp1.CAST_AS_BOOLEAN
+ || op == OpOp1.CAST_AS_INT
+ || op == OpOp1.CAST_AS_DOUBLE;
}
//////////////////////////////////
@@ -595,7 +594,7 @@ public class HopRewriteUtils
}
public static UnaryOp createUnary(Hop input, String type) {
- return createUnary(input, Hop.getUnaryOpCode(type));
+ return createUnary(input, OpOp1.valueOfByOpcode(type));
}
public static UnaryOp createUnary(Hop input, OpOp1 type)
@@ -622,7 +621,7 @@ public class HopRewriteUtils
}
public static BinaryOp createBinary(Hop input1, Hop input2, String op) {
- return createBinary(input1, input2, Hop.getBinaryOpCode(op), false);
+ return createBinary(input1, input2, OpOp2.valueOfByOpcode(op), false);
}
public static BinaryOp createBinary(Hop input1, Hop input2, OpOp2 op) {
@@ -630,9 +629,9 @@ public class HopRewriteUtils
}
public static BinaryOp createBinary(Hop input1, Hop input2, OpOp2 op, boolean outer) {
- Hop mainInput = input1.getDataType().isMatrix() ? input1 :
+ Hop mainInput = input1.getDataType().isMatrix() ? input1 :
input2.getDataType().isMatrix() ? input2 : input1;
- BinaryOp bop = new BinaryOp(mainInput.getName(), mainInput.getDataType(),
+ BinaryOp bop = new BinaryOp(mainInput.getName(), mainInput.getDataType(),
mainInput.getValueType(), op, input1, input2);
//cleanup value type for relational operations
if( bop.isPPredOperation() && bop.getDataType().isScalar() )
@@ -911,8 +910,7 @@ public class HopRewriteUtils
}
public static boolean isValidOuterBinaryOp( OpOp2 op ) {
- String opcode = Hop.getBinaryOpCode(op);
- return (Hop.getOpOp2ForOuterVectorOperation(opcode) == op);
+ return op.isValidOuter();
}
public static boolean isSparse(Hop hop) {
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 46089ac..e929e0a 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -43,14 +43,14 @@ import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.Direction;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOp3;
import org.apache.sysds.common.Types.OpOp4;
import org.apache.sysds.common.Types.OpOpDG;
import org.apache.sysds.common.Types.OpOpN;
import org.apache.sysds.common.Types.ParamBuiltinOp;
import org.apache.sysds.common.Types.ReOrgOp;
-import org.apache.sysds.hops.Hop.OpOp1;
-import org.apache.sysds.hops.Hop.OpOp2;
import org.apache.sysds.lops.MapMultChain.ChainType;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.common.Types.DataType;
@@ -1679,7 +1679,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
if( !appliedPattern
&& HopRewriteUtils.isBinary(right, LOOKUP_VALID_WDIVMM_BINARY[1]) //DIV
&& HopRewriteUtils.isEqualSize(right.getInput().get(0), right.getInput().get(1)) //prevent mv
- && HopRewriteUtils.isBinary(right.getInput().get(1), Hop.OpOp2.PLUS)
+ && HopRewriteUtils.isBinary(right.getInput().get(1), OpOp2.PLUS)
&& right.getInput().get(1).getInput().get(1).getDataType() == DataType.SCALAR
&& HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0))
&& HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
@@ -1743,7 +1743,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
if( !appliedPattern
&& HopRewriteUtils.isBinary(left, LOOKUP_VALID_WDIVMM_BINARY[1]) //DIV
&& HopRewriteUtils.isEqualSize(left.getInput().get(0), left.getInput().get(1)) //prevent mv
- && HopRewriteUtils.isBinary(left.getInput().get(1), Hop.OpOp2.PLUS)
+ && HopRewriteUtils.isBinary(left.getInput().get(1), OpOp2.PLUS)
&& left.getInput().get(1).getInput().get(1).getDataType() == DataType.SCALAR
&& HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0))
&& HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 4faaaac..f51a8b3 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -41,13 +41,13 @@ import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.Direction;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOp3;
import org.apache.sysds.common.Types.OpOpDG;
import org.apache.sysds.common.Types.OpOpN;
import org.apache.sysds.common.Types.ParamBuiltinOp;
import org.apache.sysds.common.Types.ReOrgOp;
-import org.apache.sysds.hops.Hop.OpOp1;
-import org.apache.sysds.hops.Hop.OpOp2;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.common.Types.DataType;
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java
index 94990b0..b9bd5e4 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java
@@ -31,9 +31,9 @@ import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.common.Types.AggOp;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOpData;
-import org.apache.sysds.hops.Hop.OpOp1;
-import org.apache.sysds.hops.Hop.OpOp2;
import org.apache.sysds.lops.Compression.CompressConfig;
import org.apache.sysds.lops.MMTSJ.MMTSJType;
import org.apache.sysds.parser.DMLProgram;
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java
index cc3e0e7..ec098e6 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java
@@ -28,12 +28,12 @@ import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.UnaryOp;
-import org.apache.sysds.hops.Hop.OpOp1;
-import org.apache.sysds.hops.Hop.OpOp2;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.compile.Dag;
import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
import org.apache.sysds.runtime.controlprogram.Program;
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index 57419df..bbb8f0a 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -30,6 +30,7 @@ import java.util.SortedSet;
import java.util.TreeSet;
import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
@@ -73,7 +74,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
}
private static boolean isBinaryMult(final Hop hop) {
- return hop instanceof BinaryOp && ((BinaryOp)hop).getOp() == Hop.OpOp2.MULT;
+ return hop instanceof BinaryOp && ((BinaryOp)hop).getOp() == OpOp2.MULT;
}
private static Hop rule_RewriteEMult(final Hop root) {
@@ -156,7 +157,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
if( colVectorsScalars == null )
colVectorsScalars = next;
else {
- colVectorsScalars = HopRewriteUtils.createBinary(next, colVectorsScalars, Hop.OpOp2.MULT);
+ colVectorsScalars = HopRewriteUtils.createBinary(next, colVectorsScalars, OpOp2.MULT);
colVectorsScalars.setVisited();
}
next = iterator.hasNext() ? iterator.next() : null;
@@ -170,7 +171,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
if( rowVectors == null )
rowVectors = next;
else {
- rowVectors = HopRewriteUtils.createBinary(rowVectors, next, Hop.OpOp2.MULT);
+ rowVectors = HopRewriteUtils.createBinary(rowVectors, next, OpOp2.MULT);
rowVectors.setVisited();
}
next = iterator.hasNext() ? iterator.next() : null;
@@ -184,7 +185,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
if( matrices == null )
matrices = next;
else {
- matrices = HopRewriteUtils.createBinary(matrices, next, Hop.OpOp2.MULT);
+ matrices = HopRewriteUtils.createBinary(matrices, next, OpOp2.MULT);
matrices.setVisited();
}
next = iterator.hasNext() ? iterator.next() : null;
@@ -197,7 +198,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
if( other == null )
other = next;
else {
- other = HopRewriteUtils.createBinary(other, next, Hop.OpOp2.MULT);
+ other = HopRewriteUtils.createBinary(other, next, OpOp2.MULT);
other.setVisited();
}
next = iterator.hasNext() ? iterator.next() : null;
@@ -211,21 +212,21 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
else if( other != null && matrices == null )
top = other;
else if( other != null ) { //matrices != null
- top = HopRewriteUtils.createBinary(other, matrices, Hop.OpOp2.MULT);
+ top = HopRewriteUtils.createBinary(other, matrices, OpOp2.MULT);
top.setVisited();
}
if( top == null && rowVectors != null )
top = rowVectors;
else if( rowVectors != null ) { //top != null
- top = HopRewriteUtils.createBinary(top, rowVectors, Hop.OpOp2.MULT);
+ top = HopRewriteUtils.createBinary(top, rowVectors, OpOp2.MULT);
top.setVisited();
}
if( top == null && colVectorsScalars != null )
top = colVectorsScalars;
else if( colVectorsScalars != null ) { //top != null
- top = HopRewriteUtils.createBinary(top, colVectorsScalars, Hop.OpOp2.MULT);
+ top = HopRewriteUtils.createBinary(top, colVectorsScalars, OpOp2.MULT);
top.setVisited();
}
@@ -237,7 +238,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
hop.setVisited(); // we will visit the leaves' children next
if (cnt == 1)
return hop;
- final Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
+ final Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), OpOp2.POW);
pow.setVisited();
return pow;
}
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteForLoopVectorization.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteForLoopVectorization.java
index 4ab31bd..92898bb 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteForLoopVectorization.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteForLoopVectorization.java
@@ -30,8 +30,6 @@ import org.apache.sysds.hops.IndexingOp;
import org.apache.sysds.hops.LeftIndexingOp;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.UnaryOp;
-import org.apache.sysds.hops.Hop.OpOp1;
-import org.apache.sysds.hops.Hop.OpOp2;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.IfStatementBlock;
@@ -40,6 +38,8 @@ import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.Direction;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.OpOp2;
/**
* Rule: Simplify program structure by pulling if or else statement body out
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteGPUSpecificOps.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteGPUSpecificOps.java
index d8664a9..9417ea1 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteGPUSpecificOps.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteGPUSpecificOps.java
@@ -37,11 +37,11 @@ import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.hops.FunctionOp.FunctionType;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.Direction;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.common.Types.OpOpDnn;
import org.apache.sysds.common.Types.ReOrgOp;
-import org.apache.sysds.hops.Hop.OpOp1;
-import org.apache.sysds.hops.Hop.OpOp2;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool;
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteIndexingVectorization.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteIndexingVectorization.java
index b72688b..3c10f02 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteIndexingVectorization.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteIndexingVectorization.java
@@ -22,11 +22,11 @@ package org.apache.sysds.hops.rewrite;
import java.util.ArrayList;
import java.util.List;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.IndexingOp;
import org.apache.sysds.hops.LeftIndexingOp;
import org.apache.sysds.hops.LiteralOp;
-import org.apache.sysds.hops.Hop.OpOp2;
/**
* Rule: Indexing vectorization. This rewrite rule set simplifies
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java
index 7e30921..18bde46 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java
@@ -25,11 +25,11 @@ import java.util.List;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
-import org.apache.sysds.hops.Hop.OpOp1;
import org.apache.sysds.hops.LeftIndexingOp;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.parser.ForStatement;
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemoveDanglingParentReferences.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemoveDanglingParentReferences.java
index 6d5fc2e..e786e51 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemoveDanglingParentReferences.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemoveDanglingParentReferences.java
@@ -21,6 +21,7 @@ package org.apache.sysds.hops.rewrite;
import java.util.ArrayList;
+import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.common.Types.OpOpN;
import org.apache.sysds.hops.DataOp;
@@ -28,7 +29,6 @@ import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.NaryOp;
import org.apache.sysds.hops.UnaryOp;
-import org.apache.sysds.hops.Hop.OpOp1;
/**
* This rewrite is a general-purpose cleanup pass that removes any
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemoveUnnecessaryCasts.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemoveUnnecessaryCasts.java
index 74a8c67..341323e 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemoveUnnecessaryCasts.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemoveUnnecessaryCasts.java
@@ -23,7 +23,7 @@ import java.util.ArrayList;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.UnaryOp;
-import org.apache.sysds.hops.Hop.OpOp1;
+import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.ValueType;
/**
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java
index a6951d7..fe00ae0 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java
@@ -27,6 +27,7 @@ import java.util.List;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.OpOp3;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.common.Types.OpOpN;
@@ -40,7 +41,6 @@ import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.ParameterizedBuiltinOp;
import org.apache.sysds.hops.TernaryOp;
-import org.apache.sysds.hops.Hop.OpOp1;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.StatementBlock;
diff --git a/src/main/java/org/apache/sysds/lops/Binary.java b/src/main/java/org/apache/sysds/lops/Binary.java
index 9bda9fa..9ebe551 100644
--- a/src/main/java/org/apache/sysds/lops/Binary.java
+++ b/src/main/java/org/apache/sysds/lops/Binary.java
@@ -20,7 +20,9 @@
package org.apache.sysds.lops;
import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.ValueType;
@@ -31,18 +33,7 @@ import org.apache.sysds.common.Types.ValueType;
public class Binary extends Lop
{
- public enum OperationTypes {
- ADD, SUBTRACT, MULTIPLY, DIVIDE, MINUS1_MULTIPLY, MODULUS, INTDIV, MATMULT,
- LESS_THAN, LESS_THAN_OR_EQUALS, GREATER_THAN, GREATER_THAN_OR_EQUALS, EQUALS, NOT_EQUALS,
- AND, OR, XOR,
- MAX, MIN, POW, SOLVE, NOTSUPPORTED,
- BW_AND, BW_OR, BW_XOR, BW_SHIFTL, BW_SHIFTR, //Bitwise operations
- DROP_INVALID,
- }
-
- private OperationTypes operation;
- private int numThreads = -1;
- boolean isLeftTransposed; boolean isRightTransposed; // Used for GPU matmult operation
+ private OpOp2 operation;
/**
* Constructor to perform a binary operation.
@@ -54,29 +45,15 @@ public class Binary extends Lop
* @param vt value type
* @param et exec type
*/
- public Binary(Lop input1, Lop input2, OperationTypes op, DataType dt, ValueType vt, ExecType et) {
- this(input1, input2, op, dt, vt, et, 1);
- }
-
- public Binary(Lop input1, Lop input2, OperationTypes op, DataType dt, ValueType vt, ExecType et, int k) {
- super(Lop.Type.Binary, dt, vt);
- init(input1, input2, op, dt, vt, et);
- numThreads = k;
- }
-
- public Binary(Lop input1, Lop input2, OperationTypes op, DataType dt, ValueType vt, ExecType et,
- boolean isLeftTransposed, boolean isRightTransposed) {
+ public Binary(Lop input1, Lop input2, OpOp2 op, DataType dt, ValueType vt, ExecType et) {
super(Lop.Type.Binary, dt, vt);
init(input1, input2, op, dt, vt, et);
- this.isLeftTransposed = isLeftTransposed;
- this.isRightTransposed = isRightTransposed;
}
- private void init(Lop input1, Lop input2, OperationTypes op, DataType dt, ValueType vt, ExecType et)
- {
+ private void init(Lop input1, Lop input2, OpOp2 op, DataType dt, ValueType vt, ExecType et) {
operation = op;
- this.addInput(input1);
- this.addInput(input2);
+ addInput(input1);
+ addInput(input2);
input1.addOutput(this);
input2.addOutput(this);
lps.setProperties( inputs, et);
@@ -84,126 +61,23 @@ public class Binary extends Lop
@Override
public String toString() {
-
return " Operation: " + operation;
-
}
-
- /**
- * method to get operation type
- * @return operation type
- */
-
- public OperationTypes getOperationType()
- {
+
+ public OpOp2 getOperationType() {
return operation;
}
- private String getOpcode()
- {
- return getOpcode( operation );
+ private String getOpcode() {
+ return operation.toString();
}
-
- public static String getOpcode( OperationTypes op ) {
- switch(op) {
- /* Arithmetic */
- case ADD:
- return "+";
- case SUBTRACT:
- return "-";
- case MULTIPLY:
- return "*";
- case DIVIDE:
- return "/";
- case MODULUS:
- return "%%";
- case INTDIV:
- return "%/%";
- case MATMULT:
- return "ba+*";
- case MINUS1_MULTIPLY:
- return "1-*";
-
- /* Relational */
- case LESS_THAN:
- return "<";
- case LESS_THAN_OR_EQUALS:
- return "<=";
- case GREATER_THAN:
- return ">";
- case GREATER_THAN_OR_EQUALS:
- return ">=";
- case EQUALS:
- return "==";
- case NOT_EQUALS:
- return "!=";
-
- /* Boolean */
- case AND:
- return "&&";
- case OR:
- return "||";
- /* Binary Builtin Function */
- case XOR:
- return "xor";
- case BW_AND:
- return "bitwAnd";
- case BW_OR:
- return "bitwOr";
- case BW_XOR:
- return "bitwXor";
- case BW_SHIFTL:
- return "bitwShiftL";
- case BW_SHIFTR:
- return "bitwShiftR";
-
- /* Builtin Functions */
- case MIN:
- return "min";
- case MAX:
- return "max";
- case POW:
- return "^";
-
- case SOLVE:
- return "solve";
-
- case DROP_INVALID:
- return "dropInvalid";
-
- default:
- throw new UnsupportedOperationException("Instruction is not defined for Binary operation: " + op);
- }
- }
-
@Override
public String getInstructions(String input1, String input2, String output) {
- StringBuilder sb = new StringBuilder();
- sb.append( getExecType() );
- sb.append( OPERAND_DELIMITOR );
- sb.append( getOpcode() );
- sb.append( OPERAND_DELIMITOR );
-
- sb.append ( getInputs().get(0).prepInputOperand(input1));
- sb.append( OPERAND_DELIMITOR );
-
- sb.append ( getInputs().get(1).prepInputOperand(input2));
- sb.append( OPERAND_DELIMITOR );
-
- sb.append( this.prepOutputOperand(output));
-
- //append degree of parallelism for matrix multiplications
- if( operation == OperationTypes.MATMULT && getExecType()==ExecType.CP ) {
- sb.append( OPERAND_DELIMITOR );
- sb.append( numThreads );
- }
- else if( operation == OperationTypes.MATMULT && getExecType()==ExecType.GPU ) {
- sb.append( OPERAND_DELIMITOR );
- sb.append( isLeftTransposed );
- sb.append( OPERAND_DELIMITOR );
- sb.append( isRightTransposed );
- }
- return sb.toString();
+ return InstructionUtils.concatOperands(
+ getExecType().toString(), getOpcode(),
+ getInputs().get(0).prepInputOperand(input1),
+ getInputs().get(1).prepInputOperand(input2),
+ prepOutputOperand(output));
}
}
diff --git a/src/main/java/org/apache/sysds/lops/BinaryM.java b/src/main/java/org/apache/sysds/lops/BinaryM.java
index 50207c2..6f35d25 100644
--- a/src/main/java/org/apache/sysds/lops/BinaryM.java
+++ b/src/main/java/org/apache/sysds/lops/BinaryM.java
@@ -19,10 +19,10 @@
package org.apache.sysds.lops;
-import org.apache.sysds.lops.Binary.OperationTypes;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.ValueType;
@@ -38,7 +38,7 @@ public class BinaryM extends Lop
ROW_VECTOR,
}
- private OperationTypes _operation;
+ private OpOp2 _operation;
private VectorType _vectorType = null;
/**
@@ -52,14 +52,14 @@ public class BinaryM extends Lop
* @param et exec type
* @param colVector true if colVector
*/
- public BinaryM(Lop input1, Lop input2, OperationTypes op, DataType dt, ValueType vt, ExecType et, boolean colVector ) {
+ public BinaryM(Lop input1, Lop input2, OpOp2 op, DataType dt, ValueType vt, ExecType et, boolean colVector ) {
super(Lop.Type.Binary, dt, vt);
_operation = op;
_vectorType = colVector ? VectorType.COL_VECTOR : VectorType.ROW_VECTOR;
- this.addInput(input1);
- this.addInput(input2);
+ addInput(input1);
+ addInput(input2);
input1.addOutput(this);
input2.addOutput(this);
@@ -71,78 +71,21 @@ public class BinaryM extends Lop
}
}
-
@Override
- public String toString()
- {
+ public String toString() {
return " Operation: " + _operation;
}
- /**
- * method to get operation type
- * @return operation type
- */
-
- public OperationTypes getOperationType()
- {
+ public OpOp2 getOperationType() {
return _operation;
}
- private String getOpcode()
- {
- return getOpcode( _operation );
+ private String getOpcode() {
+ return getOpcode(_operation);
}
- public static String getOpcode( OperationTypes op ) {
- switch(op) {
- /* Arithmetic */
- case ADD:
- return "map+";
- case SUBTRACT:
- return "map-";
- case MULTIPLY:
- return "map*";
- case DIVIDE:
- return "map/";
- case MODULUS:
- return "map%%";
- case INTDIV:
- return "map%/%";
- case MINUS1_MULTIPLY:
- return "map1-*";
-
- /* Relational */
- case LESS_THAN:
- return "map<";
- case LESS_THAN_OR_EQUALS:
- return "map<=";
- case GREATER_THAN:
- return "map>";
- case GREATER_THAN_OR_EQUALS:
- return "map>=";
- case EQUALS:
- return "map==";
- case NOT_EQUALS:
- return "map!=";
-
- /* Boolean */
- case AND:
- return "map&&";
- case OR:
- return "map||";
-
-
- /* Builtin Functions */
- case MIN:
- return "mapmin";
- case MAX:
- return "mapmax";
- case POW:
- return "map^";
-
- default:
- throw new UnsupportedOperationException("Instruction is not defined for Binary operation: " + op);
- }
+ public static String getOpcode(OpOp2 op) {
+ return "map"+op.toString();
}
public static boolean isOpcode(String opcode) {
diff --git a/src/main/java/org/apache/sysds/lops/BinaryScalar.java b/src/main/java/org/apache/sysds/lops/BinaryScalar.java
index c827acf..78af824 100644
--- a/src/main/java/org/apache/sysds/lops/BinaryScalar.java
+++ b/src/main/java/org/apache/sysds/lops/BinaryScalar.java
@@ -23,6 +23,7 @@ package org.apache.sysds.lops;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.ValueType;
/**
@@ -31,16 +32,7 @@ import org.apache.sysds.common.Types.ValueType;
*/
public class BinaryScalar extends Lop
{
- @SuppressWarnings("hiding")
- public enum OperationTypes {
- ADD, SUBTRACT, MULTIPLY, DIVIDE, MODULUS, INTDIV,
- LESS_THAN, LESS_THAN_OR_EQUALS, GREATER_THAN, GREATER_THAN_OR_EQUALS, EQUALS, NOT_EQUALS,
- AND, OR, XOR,
- LOG,POW,MAX,MIN,PRINT,IQSIZE,
- BW_AND, BW_OR, BW_XOR, BW_SHIFTL, BW_SHIFTR, //Bitwise operations
- }
-
- private final OperationTypes operation;
+ private final OpOp2 operation;
/**
* Constructor to perform a scalar operation
@@ -51,11 +43,11 @@ public class BinaryScalar extends Lop
* @param dt data type
* @param vt value type
*/
- public BinaryScalar(Lop input1, Lop input2, OperationTypes op, DataType dt, ValueType vt) {
+ public BinaryScalar(Lop input1, Lop input2, OpOp2 op, DataType dt, ValueType vt) {
super(Lop.Type.BinaryCP, dt, vt);
operation = op;
- this.addInput(input1);
- this.addInput(input2);
+ addInput(input1);
+ addInput(input2);
input1.addOutput(this);
input2.addOutput(this);
lps.setProperties(inputs, ExecType.CP);
@@ -66,7 +58,7 @@ public class BinaryScalar extends Lop
return "Operation: " + operation;
}
- public OperationTypes getOperationType() {
+ public OpOp2 getOperationType() {
return operation;
}
@@ -74,89 +66,11 @@ public class BinaryScalar extends Lop
public Lop.SimpleInstType getSimpleInstructionType() {
return SimpleInstType.Scalar;
}
-
- public static String getOpcode( OperationTypes op )
- {
- if( op == null )
- throw new UnsupportedOperationException("Unable to get opcode for 'null'.");
-
- switch ( op )
- {
- /* Arithmetic */
- case ADD:
- return "+";
- case SUBTRACT:
- return "-";
- case MULTIPLY:
- return "*";
- case DIVIDE:
- return "/";
- case MODULUS:
- return "%%";
- case INTDIV:
- return "%/%";
- case POW:
- return "^";
-
- /* Relational */
- case LESS_THAN:
- return "<";
- case LESS_THAN_OR_EQUALS:
- return "<=";
- case GREATER_THAN:
- return ">";
- case GREATER_THAN_OR_EQUALS:
- return ">=";
- case EQUALS:
- return "==";
- case NOT_EQUALS:
- return "!=";
-
- /* Boolean */
- case AND:
- return "&&";
- case OR:
- return "||";
- /* Boolean built in binary function */
- case XOR:
- return "xor";
- case BW_AND:
- return "bitwAnd";
- case BW_OR:
- return "bitwOr";
- case BW_XOR:
- return "bitwXor";
- case BW_SHIFTL:
- return "bitwShiftL";
- case BW_SHIFTR:
- return "bitwShiftR";
-
- /* Builtin Functions */
- case LOG:
- return "log";
- case MIN:
- return "min";
- case MAX:
- return "max";
-
- case PRINT:
- return "print";
-
- case IQSIZE:
- return "iqsize";
-
- default:
- throw new UnsupportedOperationException("Instruction "
- + "is not defined for BinaryScalar operator: " + op);
- }
- }
-
@Override
public String getInstructions(String input1, String input2, String output) {
return InstructionUtils.concatOperands(
- getExecType().name(),
- getOpcode(operation),
+ getExecType().name(), operation.toString(),
getInputs().get(0).prepScalarInputOperand(getExecType()),
getInputs().get(1).prepScalarInputOperand(getExecType()),
prepOutputOperand(output));
diff --git a/src/main/java/org/apache/sysds/lops/BinaryUAggChain.java b/src/main/java/org/apache/sysds/lops/BinaryUAggChain.java
index cf7f33e..55170a5 100644
--- a/src/main/java/org/apache/sysds/lops/BinaryUAggChain.java
+++ b/src/main/java/org/apache/sysds/lops/BinaryUAggChain.java
@@ -24,16 +24,16 @@ import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.Direction;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.ValueType;
public class BinaryUAggChain extends Lop
{
-
public static final String OPCODE = "binuaggchain";
//outer operation
- private Binary.OperationTypes _binOp = null;
+ private OpOp2 _binOp = null;
//inner operation
private AggOp _uaggOp = null;
private Direction _uaggDir = null;
@@ -51,7 +51,7 @@ public class BinaryUAggChain extends Lop
* @param vt value type
* @param et execution type
*/
- public BinaryUAggChain(Lop input1, Binary.OperationTypes bop, AggOp uaop, Direction uadir, DataType dt, ValueType vt, ExecType et) {
+ public BinaryUAggChain(Lop input1, OpOp2 bop, AggOp uaop, Direction uadir, DataType dt, ValueType vt, ExecType et) {
super(Lop.Type.BinUaggChain, dt, vt);
addInput(input1); //X
input1.addOutput(this);
@@ -71,9 +71,8 @@ public class BinaryUAggChain extends Lop
@Override
public String getInstructions(String input1, String output) {
return InstructionUtils.concatOperands(
- getExecType().name(),
- OPCODE,
- Binary.getOpcode(_binOp), //outer opcode
+ getExecType().name(), OPCODE,
+ _binOp.toString(), //outer opcode
PartialAggregate.getOpcode(_uaggOp, _uaggDir), //inner opcode
getInputs().get(0).prepInputOperand(input1),
prepOutputOperand(output));
diff --git a/src/main/java/org/apache/sysds/lops/CentralMoment.java b/src/main/java/org/apache/sysds/lops/CentralMoment.java
index d1ea217..3f0214f 100644
--- a/src/main/java/org/apache/sysds/lops/CentralMoment.java
+++ b/src/main/java/org/apache/sysds/lops/CentralMoment.java
@@ -20,6 +20,7 @@
package org.apache.sysds.lops;
import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
@@ -29,7 +30,6 @@ import org.apache.sysds.common.Types.ValueType;
*/
public class CentralMoment extends Lop
{
-
/**
* Constructor to perform central moment.
* input1 <- data (weighted or unweighted)
@@ -41,14 +41,14 @@ public class CentralMoment extends Lop
* @param et execution type
*/
private void init(Lop input1, Lop input2, Lop input3, ExecType et) {
- this.addInput(input1);
- this.addInput(input2);
+ addInput(input1);
+ addInput(input2);
input1.addOutput(this);
input2.addOutput(this);
// when executing in CP, this lop takes an optional 3rd input (Weights)
if ( input3 != null ) {
- this.addInput(input3);
+ addInput(input3);
input3.addOutput(this);
}
lps.setProperties(inputs, et);
@@ -77,31 +77,21 @@ public class CentralMoment extends Lop
*/
@Override
public String getInstructions(String input1, String input2, String input3, String output) {
- StringBuilder sb = new StringBuilder();
- sb.append( getExecType() );
- sb.append( Lop.OPERAND_DELIMITOR );
-
- sb.append( "cm" );
- sb.append( OPERAND_DELIMITOR );
-
- // Input data
- sb.append( getInputs().get(0).prepInputOperand(input1) );
- sb.append( OPERAND_DELIMITOR );
-
- // Weights
- if( input3 != null ) {
- sb.append( getInputs().get(1).prepInputOperand(input2) );
- sb.append( OPERAND_DELIMITOR );
+ if( input3 == null ) {
+ return InstructionUtils.concatOperands(
+ getExecType().toString(), "cm",
+ getInputs().get(0).prepInputOperand(input1),
+ getInputs().get((input3!=null)?2:1).prepScalarInputOperand(getExecType()),
+ prepOutputOperand(output));
+ }
+ else {
+ return InstructionUtils.concatOperands(
+ getExecType().toString(), "cm",
+ getInputs().get(0).prepInputOperand(input1),
+ getInputs().get(1).prepInputOperand(input2),
+ getInputs().get((input3!=null)?2:1).prepScalarInputOperand(getExecType()),
+ prepOutputOperand(output));
}
-
- // Order
- sb.append( getInputs().get((input3!=null)?2:1)
- .prepScalarInputOperand(getExecType()) );
- sb.append( OPERAND_DELIMITOR );
-
- sb.append( prepOutputOperand(output));
-
- return sb.toString();
}
/**
diff --git a/src/main/java/org/apache/sysds/lops/Checkpoint.java b/src/main/java/org/apache/sysds/lops/Checkpoint.java
index 97a0b9f..958e71a 100644
--- a/src/main/java/org/apache/sysds/lops/Checkpoint.java
+++ b/src/main/java/org/apache/sysds/lops/Checkpoint.java
@@ -58,7 +58,7 @@ public class Checkpoint extends Lop
*/
public Checkpoint(Lop input, DataType dt, ValueType vt, String level) {
super(Lop.Type.Checkpoint, dt, vt);
- this.addInput(input);
+ addInput(input);
input.addOutput(this);
_storageLevel = StorageLevel.fromString(level);
diff --git a/src/main/java/org/apache/sysds/lops/Compression.java b/src/main/java/org/apache/sysds/lops/Compression.java
index d6e191b..4270e34 100644
--- a/src/main/java/org/apache/sysds/lops/Compression.java
+++ b/src/main/java/org/apache/sysds/lops/Compression.java
@@ -38,7 +38,7 @@ public class Compression extends Lop
public Compression(Lop input, DataType dt, ValueType vt, ExecType et) {
super(Lop.Type.Checkpoint, dt, vt);
- this.addInput(input);
+ addInput(input);
input.addOutput(this);
lps.setProperties(inputs, et);
}
diff --git a/src/main/java/org/apache/sysds/lops/Ctable.java b/src/main/java/org/apache/sysds/lops/Ctable.java
index cb55ca7..93032c9 100644
--- a/src/main/java/org/apache/sysds/lops/Ctable.java
+++ b/src/main/java/org/apache/sysds/lops/Ctable.java
@@ -70,7 +70,7 @@ public class Ctable extends Lop
operation = op;
for(int i=0; i < inputLops.length; i++) {
- this.addInput(inputLops[i]);
+ addInput(inputLops[i]);
inputLops[i].addOutput(this);
}
diff --git a/src/main/java/org/apache/sysds/lops/CumulativeOffsetBinary.java b/src/main/java/org/apache/sysds/lops/CumulativeOffsetBinary.java
index 4cbb126..314b5cf 100644
--- a/src/main/java/org/apache/sysds/lops/CumulativeOffsetBinary.java
+++ b/src/main/java/org/apache/sysds/lops/CumulativeOffsetBinary.java
@@ -53,8 +53,8 @@ public class CumulativeOffsetBinary extends Lop
}
private void init(Lop input1, Lop input2, DataType dt, ValueType vt, ExecType et) {
- this.addInput(input1);
- this.addInput(input2);
+ addInput(input1);
+ addInput(input2);
input1.addOutput(this);
input2.addOutput(this);
lps.setProperties( inputs, et);
diff --git a/src/main/java/org/apache/sysds/lops/CumulativePartialAggregate.java b/src/main/java/org/apache/sysds/lops/CumulativePartialAggregate.java
index 9259745..6e8685f 100644
--- a/src/main/java/org/apache/sysds/lops/CumulativePartialAggregate.java
+++ b/src/main/java/org/apache/sysds/lops/CumulativePartialAggregate.java
@@ -44,7 +44,7 @@ public class CumulativePartialAggregate extends Lop
}
private void init(Lop input, DataType dt, ValueType vt, ExecType et) {
- this.addInput(input);
+ addInput(input);
input.addOutput(this);
lps.setProperties( inputs, et);
}
diff --git a/src/main/java/org/apache/sysds/lops/Data.java b/src/main/java/org/apache/sysds/lops/Data.java
index 43618e8..51bc611 100644
--- a/src/main/java/org/apache/sysds/lops/Data.java
+++ b/src/main/java/org/apache/sysds/lops/Data.java
@@ -99,7 +99,7 @@ public class Data extends Lop
// input Lops as the first element of WRITE input. The parameters of
// WRITE operation are then put as the following input elements.
if(input != null && op.isWrite()) {
- this.addInput(input);
+ addInput(input);
input.addOutput(this);
}
@@ -107,7 +107,7 @@ public class Data extends Lop
if ( _inputParams != null ) {
for (Lop lop : inputParametersLops.values()) {
- this.addInput(lop);
+ addInput(lop);
lop.addOutput(this);
}
if ( inputParametersLops.get(DataExpression.IO_FILENAME)!= null
diff --git a/src/main/java/org/apache/sysds/lops/DataGen.java b/src/main/java/org/apache/sysds/lops/DataGen.java
index 5b90bf7..27a634c 100644
--- a/src/main/java/org/apache/sysds/lops/DataGen.java
+++ b/src/main/java/org/apache/sysds/lops/DataGen.java
@@ -69,7 +69,7 @@ public class DataGen extends Lop
_op = op;
for (Lop lop : inputParametersLops.values()) {
- this.addInput(lop);
+ addInput(lop);
lop.addOutput(this);
}
@@ -227,60 +227,34 @@ public class DataGen extends Lop
+ "Parameter " + DataExpression.RAND_MIN
+ " must be a literal for a Rand operation.");
- //generate instruction
- StringBuilder sb = new StringBuilder( );
- ExecType et = getExecType();
-
- sb.append( et );
- sb.append( Lop.OPERAND_DELIMITOR );
- sb.append(SINIT_OPCODE);
- sb.append(OPERAND_DELIMITOR);
- sb.append(rowsString);
- sb.append(OPERAND_DELIMITOR);
- sb.append(colsString);
- sb.append(OPERAND_DELIMITOR);
- sb.append(blen);
- sb.append(OPERAND_DELIMITOR);
- sb.append(minString);
- sb.append(OPERAND_DELIMITOR);
- sb.append(prepOutputOperand(output));
-
- return sb.toString();
+ return InstructionUtils.concatOperands(
+ getExecType().toString(), SINIT_OPCODE,
+ rowsString, colsString, blen, minString,
+ prepOutputOperand(output));
}
private String getSampleInstructionCPSpark(String output) {
if ( _op != OpOpDG.SAMPLE )
throw new LopsException("Invalid instruction generation for data generation method " + _op);
- //prepare instruction parameters
- Lop lsize = _inputParams.get(DataExpression.RAND_ROWS.toString());
- Lop lrange = _inputParams.get(DataExpression.RAND_MAX.toString());
- Lop lreplace = _inputParams.get(DataExpression.RAND_PDF.toString());
- Lop lseed = _inputParams.get(DataExpression.RAND_SEED.toString());
-
+ ExecType et = getExecType();
return InstructionUtils.concatOperands(
- getExecType().name(),
- "sample",
- lrange.prepScalarLabel(),
- lsize.prepScalarInputOperand(getExecType()),
- lreplace.prepScalarLabel(),
- lseed.prepScalarLabel(),
+ getExecType().name(), "sample",
+ _inputParams.get(DataExpression.RAND_MAX.toString()).prepScalarLabel(),
+ _inputParams.get(DataExpression.RAND_ROWS.toString()).prepScalarInputOperand(et),
+ _inputParams.get(DataExpression.RAND_PDF.toString()).prepScalarLabel(),
+ _inputParams.get(DataExpression.RAND_SEED.toString()).prepScalarLabel(),
String.valueOf(getOutputParameters().getBlocksize()),
prepOutputOperand(output));
}
- private String getTimeInstructionCP(String output)
- {
+ private String getTimeInstructionCP(String output) {
if (_op != OpOpDG.TIME )
throw new LopsException("Invalid instruction generation for data generation method " + _op);
- StringBuilder sb = new StringBuilder();
- sb.append( getExecType() );
- sb.append( Lop.OPERAND_DELIMITOR );
- sb.append( "time" );
- sb.append( Lop.OPERAND_DELIMITOR );
- sb.append( prepOutputOperand(output) );
- return sb.toString();
+ return InstructionUtils.concatOperands(
+ getExecType().toString(), "time",
+ prepOutputOperand(output));
}
/**
@@ -293,43 +267,16 @@ public class DataGen extends Lop
if ( _op != OpOpDG.SEQ )
throw new LopsException("Invalid instruction generation for data generation method " + _op);
- StringBuilder sb = new StringBuilder( );
ExecType et = getExecType();
- sb.append( et );
- sb.append( Lop.OPERAND_DELIMITOR );
-
- Lop iLop = null;
-
- iLop = _inputParams.get(Statement.SEQ_FROM.toString());
- String fromString = iLop.prepScalarInputOperand(et);
-
- iLop = _inputParams.get(Statement.SEQ_TO.toString());
- String toString = iLop.prepScalarInputOperand(et);
-
- iLop = _inputParams.get(Statement.SEQ_INCR.toString());
- String incrString = iLop.prepScalarInputOperand(et);
-
- String rowsString = String.valueOf(this.getOutputParameters().getNumRows());
- String colsString = String.valueOf(this.getOutputParameters().getNumCols());
- String blen = String.valueOf(this.getOutputParameters().getBlocksize());
-
- sb.append( DataGen.SEQ_OPCODE );
- sb.append( OPERAND_DELIMITOR );
- sb.append( rowsString );
- sb.append( OPERAND_DELIMITOR );
- sb.append( colsString );
- sb.append( OPERAND_DELIMITOR );
- sb.append( blen );
- sb.append( OPERAND_DELIMITOR );
- sb.append( fromString );
- sb.append( OPERAND_DELIMITOR );
- sb.append( toString );
- sb.append( OPERAND_DELIMITOR );
- sb.append( incrString );
- sb.append( OPERAND_DELIMITOR );
- sb.append(prepOutputOperand(output));
-
- return sb.toString();
+ return InstructionUtils.concatOperands(
+ et.toString(), DataGen.SEQ_OPCODE,
+ String.valueOf(getOutputParameters().getNumRows()),
+ String.valueOf(getOutputParameters().getNumCols()),
+ String.valueOf(getOutputParameters().getBlocksize()),
+ _inputParams.get(Statement.SEQ_FROM.toString()).prepScalarInputOperand(et),
+ _inputParams.get(Statement.SEQ_TO.toString()).prepScalarInputOperand(et),
+ _inputParams.get(Statement.SEQ_INCR.toString()).prepScalarInputOperand(et),
+ prepOutputOperand(output));
}
@Override
diff --git a/src/main/java/org/apache/sysds/lops/DnnTransform.java b/src/main/java/org/apache/sysds/lops/DnnTransform.java
index 6c6b7c8..3a75a66 100644
--- a/src/main/java/org/apache/sysds/lops/DnnTransform.java
+++ b/src/main/java/org/apache/sysds/lops/DnnTransform.java
@@ -53,7 +53,7 @@ public class DnnTransform extends Lop
super(Lop.Type.Transform, dt, vt);
init(input1, op, dt, vt, et);
numThreads = k;
- this.addInput(input2);
+ addInput(input2);
input2.addOutput(this);
setLevel();
}
@@ -62,16 +62,16 @@ public class DnnTransform extends Lop
super(Lop.Type.Transform, dt, vt);
init(input1, op, dt, vt, et);
numThreads = k;
- this.addInput(input2);
+ addInput(input2);
input2.addOutput(this);
- this.addInput(input3);
+ addInput(input3);
input3.addOutput(this);
setLevel();
}
private void init (Lop input, OpOpDnn op, DataType dt, ValueType vt, ExecType et) {
operation = op;
- this.addInput(input);
+ addInput(input);
input.addOutput(this);
lps.setProperties( inputs, et);
}
diff --git a/src/main/java/org/apache/sysds/lops/GroupedAggregate.java b/src/main/java/org/apache/sysds/lops/GroupedAggregate.java
index 002c874..d46d1ba 100644
--- a/src/main/java/org/apache/sysds/lops/GroupedAggregate.java
+++ b/src/main/java/org/apache/sysds/lops/GroupedAggregate.java
@@ -72,9 +72,9 @@ public class GroupedAggregate extends Lop
private void init(HashMap<String, Lop> inputParameterLops,
DataType dt, ValueType vt, ExecType et) {
// First, add inputs corresponding to "target" and "groups"
- this.addInput(inputParameterLops.get(Statement.GAGG_TARGET));
+ addInput(inputParameterLops.get(Statement.GAGG_TARGET));
inputParameterLops.get(Statement.GAGG_TARGET).addOutput(this);
- this.addInput(inputParameterLops.get(Statement.GAGG_GROUPS));
+ addInput(inputParameterLops.get(Statement.GAGG_GROUPS));
inputParameterLops.get(Statement.GAGG_GROUPS).addOutput(this);
// process remaining parameters
@@ -82,7 +82,7 @@ public class GroupedAggregate extends Lop
String k = e.getKey();
Lop lop = e.getValue();
if ( !k.equalsIgnoreCase(Statement.GAGG_TARGET) && !k.equalsIgnoreCase(Statement.GAGG_GROUPS) ) {
- this.addInput(lop);
+ addInput(lop);
lop.addOutput(this);
}
}
diff --git a/src/main/java/org/apache/sysds/lops/LeftIndex.java b/src/main/java/org/apache/sysds/lops/LeftIndex.java
index 4a535a3..20d23c1 100644
--- a/src/main/java/org/apache/sysds/lops/LeftIndex.java
+++ b/src/main/java/org/apache/sysds/lops/LeftIndex.java
@@ -72,12 +72,12 @@ public class LeftIndex extends Lop
* i,j -> rowL, rowU
* k,l -> colL, colU
*/
- this.addInput(lhsMatrix);
- this.addInput(rhsMatrix);
- this.addInput(rowL);
- this.addInput(rowU);
- this.addInput(colL);
- this.addInput(colU);
+ addInput(lhsMatrix);
+ addInput(rhsMatrix);
+ addInput(rowL);
+ addInput(rowU);
+ addInput(colL);
+ addInput(colU);
lhsMatrix.addOutput(this);
rhsMatrix.addOutput(this);
diff --git a/src/main/java/org/apache/sysds/lops/Lop.java b/src/main/java/org/apache/sysds/lops/Lop.java
index 19e9cc3..fa25000 100644
--- a/src/main/java/org/apache/sysds/lops/Lop.java
+++ b/src/main/java/org/apache/sysds/lops/Lop.java
@@ -39,6 +39,7 @@ public abstract class Lop
public enum Type {
Data, DataGen, //CP/MR read/write/datagen
ReBlock, CSVReBlock, //MR reblock operations
+ MatMultCP,
MMCJ, MMRJ, MMTSJ, PMMJ, MapMult, MapMultChain, //MR matrix multiplications
UnaryCP, UNARY, BinaryCP, Binary, Ternary, Nary, //CP/MR unary/binary/ternary
RightIndex, LeftIndex, ZeroOut, //CP/MR indexing
diff --git a/src/main/java/org/apache/sysds/lops/MMCJ.java b/src/main/java/org/apache/sysds/lops/MMCJ.java
index 96bf1d0..d4f16de 100644
--- a/src/main/java/org/apache/sysds/lops/MMCJ.java
+++ b/src/main/java/org/apache/sysds/lops/MMCJ.java
@@ -57,8 +57,8 @@ public class MMCJ extends Lop
public MMCJ(Lop input1, Lop input2, DataType dt, ValueType vt, MMCJType type, ExecType et)
{
super(Lop.Type.MMCJ, dt, vt);
- this.addInput(input1);
- this.addInput(input2);
+ addInput(input1);
+ addInput(input2);
input1.addOutput(this);
input2.addOutput(this);
diff --git a/src/main/java/org/apache/sysds/lops/MMRJ.java b/src/main/java/org/apache/sysds/lops/MMRJ.java
index 6c8b280..a09d37d 100644
--- a/src/main/java/org/apache/sysds/lops/MMRJ.java
+++ b/src/main/java/org/apache/sysds/lops/MMRJ.java
@@ -44,8 +44,8 @@ public class MMRJ extends Lop
{
//handle inputs and outputs
super(Lop.Type.MMRJ, dt, vt);
- this.addInput(input1);
- this.addInput(input2);
+ addInput(input1);
+ addInput(input2);
input1.addOutput(this);
input2.addOutput(this);
diff --git a/src/main/java/org/apache/sysds/lops/MapMult.java b/src/main/java/org/apache/sysds/lops/MapMult.java
index c564d16..3b30129 100644
--- a/src/main/java/org/apache/sysds/lops/MapMult.java
+++ b/src/main/java/org/apache/sysds/lops/MapMult.java
@@ -72,8 +72,8 @@ public class MapMult extends Lop
*/
public MapMult(Lop input1, Lop input2, DataType dt, ValueType vt, boolean rightCache, boolean partitioned, boolean emptyBlocks, SparkAggType aggtype) {
super(Lop.Type.MapMult, dt, vt);
- this.addInput(input1);
- this.addInput(input2);
+ addInput(input1);
+ addInput(input2);
input1.addOutput(this);
input2.addOutput(this);
diff --git a/src/main/java/org/apache/sysds/lops/MatMultCP.java b/src/main/java/org/apache/sysds/lops/MatMultCP.java
new file mode 100644
index 0000000..4f2c9cd
--- /dev/null
+++ b/src/main/java/org/apache/sysds/lops/MatMultCP.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops;
+
+import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ValueType;
+
+public class MatMultCP extends Lop // lop emitting the "ba+*" matrix multiplication instruction
+{
+	private int numThreads = -1; // emitted right after the output operand; stays -1 unless set via k
+	private boolean isLeftTransposed; // Used for GPU matmult operation
+	private boolean isRightTransposed; // like isLeftTransposed, only emitted in the non-CP branch
+	// Convenience constructor: delegates to the variant below with k=1.
+	public MatMultCP(Lop input1, Lop input2, DataType dt, ValueType vt, ExecType et) {
+		this(input1, input2, dt, vt, et, 1);
+	}
+	// Constructor with an explicit number of threads k.
+	public MatMultCP(Lop input1, Lop input2, DataType dt, ValueType vt, ExecType et, int k) {
+		super(Lop.Type.MatMultCP, dt, vt);
+		init(input1, input2, dt, vt, et);
+		numThreads = k;
+	}
+	// Constructor carrying the transpose flags; numThreads keeps its default of -1.
+	public MatMultCP(Lop input1, Lop input2, DataType dt, ValueType vt, ExecType et,
+		boolean isLeftTransposed, boolean isRightTransposed) {
+		super(Lop.Type.MatMultCP, dt, vt); // must report the same lop type as the other constructors
+		init(input1, input2, dt, vt, et);
+		this.isLeftTransposed = isLeftTransposed;
+		this.isRightTransposed = isRightTransposed;
+	}
+	// Registers both inputs, registers this lop as their output, and sets the lop properties for et.
+	private void init(Lop input1, Lop input2, DataType dt, ValueType vt, ExecType et) {
+		addInput(input1);
+		addInput(input2);
+		input1.addOutput(this);
+		input2.addOutput(this);
+		lps.setProperties( inputs, et);
+	}
+	// Constant description string; this lop has a single operation.
+	@Override
+	public String toString() {
+		return " Operation: ba+*";
+	}
+	// Operands: exec type, opcode, input1, input2, output, numThreads (+ both transpose flags if not CP).
+	@Override
+	public String getInstructions(String input1, String input2, String output) {
+		if( getExecType() == ExecType.CP ) {
+			return InstructionUtils.concatOperands(
+				getExecType().name(), "ba+*",
+				getInputs().get(0).prepInputOperand(input1),
+				getInputs().get(1).prepInputOperand(input2),
+				prepOutputOperand(output), String.valueOf(numThreads));
+		}
+		else { //GPU
+			return InstructionUtils.concatOperands(
+				getExecType().name(), "ba+*",
+				getInputs().get(0).prepInputOperand(input1),
+				getInputs().get(1).prepInputOperand(input2),
+				prepOutputOperand(output), String.valueOf(numThreads),
+				String.valueOf(isLeftTransposed),
+				String.valueOf(isRightTransposed));
+		}
+	}
+}
diff --git a/src/main/java/org/apache/sysds/lops/PMapMult.java b/src/main/java/org/apache/sysds/lops/PMapMult.java
index 4fc8021..c311ff4 100644
--- a/src/main/java/org/apache/sysds/lops/PMapMult.java
+++ b/src/main/java/org/apache/sysds/lops/PMapMult.java
@@ -32,8 +32,8 @@ public class PMapMult extends Lop
public PMapMult(Lop input1, Lop input2, DataType dt, ValueType vt) {
super(Lop.Type.MapMult, dt, vt);
- this.addInput(input1);
- this.addInput(input2);
+ addInput(input1);
+ addInput(input2);
input1.addOutput(this);
input2.addOutput(this);
lps.setProperties( inputs, ExecType.SPARK);
diff --git a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
index fd39f55..698e739 100644
--- a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
+++ b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
@@ -52,7 +52,7 @@ public class ParameterizedBuiltin extends Lop
_operation = op;
for (Lop lop : paramLops.values()) {
- this.addInput(lop);
+ addInput(lop);
lop.addOutput(this);
}
diff --git a/src/main/java/org/apache/sysds/lops/PickByCount.java b/src/main/java/org/apache/sysds/lops/PickByCount.java
index 698ad27..aa7ce1d 100644
--- a/src/main/java/org/apache/sysds/lops/PickByCount.java
+++ b/src/main/java/org/apache/sysds/lops/PickByCount.java
@@ -53,11 +53,11 @@ public class PickByCount extends Lop
private void init(Lop input1, Lop input2, OperationTypes op, ExecType et) {
- this.addInput(input1);
+ addInput(input1);
input1.addOutput(this);
if ( input2 != null ) {
- this.addInput(input2);
+ addInput(input2);
input2.addOutput(this);
}
diff --git a/src/main/java/org/apache/sysds/lops/SortKeys.java b/src/main/java/org/apache/sysds/lops/SortKeys.java
index 7d34d6f..1ed0e21 100644
--- a/src/main/java/org/apache/sysds/lops/SortKeys.java
+++ b/src/main/java/org/apache/sysds/lops/SortKeys.java
@@ -60,7 +60,7 @@ public class SortKeys extends Lop
}
private void init(Lop input1, Lop input2, OperationTypes op, ExecType et) {
- this.addInput(input1);
+ addInput(input1);
input1.addOutput(this);
operation = op;
@@ -68,7 +68,7 @@ public class SortKeys extends Lop
// SortKeys can accept a optional second input only when executing in CP
// Example: sorting with weights inside CP
if ( input2 != null ) {
- this.addInput(input2);
+ addInput(input2);
input2.addOutput(this);
}
lps.setProperties( inputs, et);
diff --git a/src/main/java/org/apache/sysds/lops/TernaryAggregate.java b/src/main/java/org/apache/sysds/lops/TernaryAggregate.java
index 146e664..2d2ca64 100644
--- a/src/main/java/org/apache/sysds/lops/TernaryAggregate.java
+++ b/src/main/java/org/apache/sysds/lops/TernaryAggregate.java
@@ -24,6 +24,7 @@ import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.Direction;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.ValueType;
public class TernaryAggregate extends Lop
@@ -39,11 +40,11 @@ public class TernaryAggregate extends Lop
//optional attribute for cp
private int _numThreads = -1;
- public TernaryAggregate(Lop input1, Lop input2, Lop input3, AggOp aggOp, Binary.OperationTypes binOp, Direction direction, DataType dt, ValueType vt, ExecType et, int k )
+ public TernaryAggregate(Lop input1, Lop input2, Lop input3, AggOp aggOp, OpOp2 binOp, Direction direction, DataType dt, ValueType vt, ExecType et, int k )
{
super(Lop.Type.TernaryAggregate, dt, vt);
- //_aggOp = aggOp;
+ //_aggOp = aggOp;
//_binOp = binOp;
addInput(input1);
diff --git a/src/main/java/org/apache/sysds/lops/Transform.java b/src/main/java/org/apache/sysds/lops/Transform.java
index b1bb0fe..544ea50 100644
--- a/src/main/java/org/apache/sysds/lops/Transform.java
+++ b/src/main/java/org/apache/sysds/lops/Transform.java
@@ -81,7 +81,7 @@ public class Transform extends Lop
private void init (Lop[] input, ReOrgOp op, DataType dt, ValueType vt, ExecType et) {
_operation = op;
for(Lop in : input) {
- this.addInput(in);
+ addInput(in);
in.addOutput(this);
}
lps.setProperties(inputs, et);
diff --git a/src/main/java/org/apache/sysds/lops/UAggOuterChain.java b/src/main/java/org/apache/sysds/lops/UAggOuterChain.java
index 9419bc8..43900f1 100644
--- a/src/main/java/org/apache/sysds/lops/UAggOuterChain.java
+++ b/src/main/java/org/apache/sysds/lops/UAggOuterChain.java
@@ -25,6 +25,7 @@ import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.Direction;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.ValueType;
@@ -39,10 +40,10 @@ public class UAggOuterChain extends Lop
public static final String OPCODE = "uaggouterchain";
//outer operation
- private AggOp _uaggOp = null;
+ private AggOp _uaggOp = null;
private Direction _uaggDir = null;
//inner operation
- private Binary.OperationTypes _binOp = null;
+ private OpOp2 _binOp = null;
/**
* Constructor to setup a unaryagg outer chain
@@ -56,12 +57,12 @@ public class UAggOuterChain extends Lop
* @param vt value type
* @param et execution type
*/
- public UAggOuterChain(Lop input1, Lop input2, AggOp uaop, Direction uadir, Binary.OperationTypes bop, DataType dt, ValueType vt, ExecType et) {
+ public UAggOuterChain(Lop input1, Lop input2, AggOp uaop, Direction uadir, OpOp2 bop, DataType dt, ValueType vt, ExecType et) {
super(Lop.Type.UaggOuterChain, dt, vt);
addInput(input1);
addInput(input2);
- input1.addOutput(this);
- input2.addOutput(this);
+ input1.addOutput(this);
+ input2.addOutput(this);
//setup operator types
_uaggOp = uaop;
@@ -78,10 +79,9 @@ public class UAggOuterChain extends Lop
@Override
public String getInstructions(String input1, String input2, String output) {
return InstructionUtils.concatOperands(
- getExecType().name(),
- OPCODE,
+ getExecType().name(), OPCODE,
PartialAggregate.getOpcode(_uaggOp, _uaggDir), //outer
- Binary.getOpcode(_binOp), //inner
+ _binOp.toString(), //inner
getInputs().get(0).prepInputOperand(input1),
getInputs().get(0).prepInputOperand(input2),
prepOutputOperand(output));
diff --git a/src/main/java/org/apache/sysds/lops/Unary.java b/src/main/java/org/apache/sysds/lops/Unary.java
index f4e6d44..aa51477 100644
--- a/src/main/java/org/apache/sysds/lops/Unary.java
+++ b/src/main/java/org/apache/sysds/lops/Unary.java
@@ -23,6 +23,7 @@ package org.apache.sysds.lops;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.ValueType;
@@ -35,22 +36,7 @@ import org.apache.sysds.common.Types.ValueType;
public class Unary extends Lop
{
- @SuppressWarnings("hiding")
- public enum OperationTypes {
- ADD, SUBTRACT, SUBTRACTRIGHT, MULTIPLY, MULTIPLY2, DIVIDE, MODULUS, INTDIV, MINUS1_MULTIPLY,
- POW, POW2, LOG, MAX, MIN, NOT, ABS, SIN, COS, TAN, ASIN, ACOS, ATAN, SINH, COSH, TANH, SIGN, SQRT, EXP, Over,
- LESS_THAN, LESS_THAN_OR_EQUALS, GREATER_THAN, GREATER_THAN_OR_EQUALS, EQUALS, NOT_EQUALS,
- AND, OR, XOR, BW_AND, BW_OR, BW_XOR, BW_SHIFTL, BW_SHIFTR,
- ROUND, CEIL, FLOOR, MR_IQM, INVERSE, CHOLESKY,
- CUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD,
- ISNA, ISNAN, ISINF,
- SPROP, SIGMOID, SUBTRACT_NZ, LOG_NZ,
- CAST_AS_MATRIX, CAST_AS_FRAME,
- TYPEOF, DETECTSCHEMA,
- NOTSUPPORTED
- }
-
- private OperationTypes operation;
+ private OpOp1 operation;
private Lop valInput;
//cp-specific parameters
@@ -68,12 +54,12 @@ public class Unary extends Lop
* @param vt value type
* @param et execution type
*/
- public Unary(Lop input1, Lop input2, OperationTypes op, DataType dt, ValueType vt, ExecType et) {
+ public Unary(Lop input1, Lop input2, OpOp1 op, DataType dt, ValueType vt, ExecType et) {
super(Lop.Type.UNARY, dt, vt);
init(input1, input2, op, dt, vt, et);
}
- private void init(Lop input1, Lop input2, OperationTypes op, DataType dt, ValueType vt, ExecType et) {
+ private void init(Lop input1, Lop input2, OpOp1 op, DataType dt, ValueType vt, ExecType et) {
operation = op;
if (input1.getDataType() == DataType.MATRIX)
@@ -81,9 +67,9 @@ public class Unary extends Lop
else
valInput = input1;
- this.addInput(input1);
+ addInput(input1);
input1.addOutput(this);
- this.addInput(input2);
+ addInput(input2);
input2.addOutput(this);
lps.setProperties(inputs, et);
}
@@ -99,24 +85,21 @@ public class Unary extends Lop
* @param numThreads number of threads
* @param inplace inplace behavior
*/
- public Unary(Lop input1, OperationTypes op, DataType dt, ValueType vt, ExecType et, int numThreads, boolean inplace) {
+ public Unary(Lop input1, OpOp1 op, DataType dt, ValueType vt, ExecType et, int numThreads, boolean inplace) {
super(Lop.Type.UNARY, dt, vt);
init(input1, op, dt, vt, et);
_numThreads = numThreads;
_inplace = inplace;
}
- private void init(Lop input1, OperationTypes op, DataType dt, ValueType vt, ExecType et) {
+ private void init(Lop input1, OpOp1 op, DataType dt, ValueType vt, ExecType et) {
//sanity check
- if ( (op == OperationTypes.INVERSE || op == OperationTypes.CHOLESKY)
- && et == ExecType.SPARK ) {
+ if ( (op == OpOp1.INVERSE || op == OpOp1.CHOLESKY) && et == ExecType.SPARK )
throw new LopsException("Invalid exection type "+et.toString()+" for operation "+op.toString());
- }
operation = op;
valInput = null;
-
- this.addInput(input1);
+ addInput(input1);
input1.addOutput(this);
lps.setProperties(inputs, et);
}
@@ -133,187 +116,18 @@ public class Unary extends Lop
}
private String getOpcode() {
- return getOpcode(operation);
- }
-
- public static String getOpcode(OperationTypes op) {
- switch (op) {
- case NOT:
- return "!";
- case ABS:
- return "abs";
- case SIN:
- return "sin";
- case COS:
- return "cos";
- case TAN:
- return "tan";
- case ASIN:
- return "asin";
- case ACOS:
- return "acos";
- case ATAN:
- return "atan";
- case SINH:
- return "sinh";
- case COSH:
- return "cosh";
- case TANH:
- return "tanh";
- case SIGN:
- return "sign";
- case SQRT:
- return "sqrt";
- case EXP:
- return "exp";
-
- case LOG:
- return "log";
-
- case LOG_NZ:
- return "log_nz";
-
- case ROUND:
- return "round";
-
- case ADD:
- return "+";
-
- case SUBTRACT:
- return "-";
-
- case SUBTRACT_NZ:
- return "-nz";
-
- case SUBTRACTRIGHT:
- return "s-r";
-
- case MULTIPLY:
- return "*";
-
- case MULTIPLY2:
- return "*2";
-
- case MINUS1_MULTIPLY:
- return "1-*";
-
- case DIVIDE:
- return "/";
-
- case MODULUS:
- return "%%";
-
- case INTDIV:
- return "%/%";
-
- case Over:
- return "so";
-
- case POW:
- return "^";
-
- case POW2:
- return "^2";
-
- case GREATER_THAN:
- return ">";
-
- case GREATER_THAN_OR_EQUALS:
- return ">=";
-
- case LESS_THAN:
- return "<";
-
- case LESS_THAN_OR_EQUALS:
- return "<=";
-
- case EQUALS:
- return "==";
-
- case NOT_EQUALS:
- return "!=";
-
- case MAX:
- return "max";
-
- case MIN:
- return "min";
-
- case CEIL:
- return "ceil";
-
- case FLOOR:
- return "floor";
-
- case CUMSUM:
- return "ucumk+";
-
- case CUMPROD:
- return "ucum*";
-
- case CUMMIN:
- return "ucummin";
-
- case CUMMAX:
- return "ucummax";
-
- case CUMSUMPROD:
- return "ucumk+*";
-
- case INVERSE:
- return "inverse";
-
- case CHOLESKY:
- return "cholesky";
-
- case MR_IQM:
- return "qpick";
-
- case SPROP:
- return "sprop";
-
- case SIGMOID:
- return "sigmoid";
-
- case TYPEOF:
- return "typeOf";
-
- case DETECTSCHEMA:
- return "detectSchema";
-
- case CAST_AS_MATRIX:
- return UnaryCP.CAST_AS_MATRIX_OPCODE;
-
- case CAST_AS_FRAME:
- return UnaryCP.CAST_AS_FRAME_OPCODE;
-
- case AND: return "&&";
- case OR: return "||";
- case XOR: return "xor";
- case BW_AND: return "bitwAnd";
- case BW_OR: return "bitwOr";
- case BW_XOR: return "bitwXor";
- case BW_SHIFTL: return "bitwShiftL";
- case BW_SHIFTR: return "bitwShiftR";
- case ISNA: return "isna";
- case ISNAN: return "isnan";
- case ISINF: return "isinf";
-
- default:
- throw new LopsException(
- "Instruction not defined for Unary operation: " + op);
- }
+ return operation.toString();
}
- public static boolean isMultiThreadedOp(OperationTypes op) {
- return op==OperationTypes.CUMSUM
- || op==OperationTypes.CUMPROD
- || op==OperationTypes.CUMMIN
- || op==OperationTypes.CUMMAX
- || op==OperationTypes.CUMSUMPROD
- || op==OperationTypes.EXP
- || op==OperationTypes.LOG
- || op==OperationTypes.SIGMOID;
+ public static boolean isMultiThreadedOp(OpOp1 op) {
+ return op==OpOp1.CUMSUM
+ || op==OpOp1.CUMPROD
+ || op==OpOp1.CUMMIN
+ || op==OpOp1.CUMMAX
+ || op==OpOp1.CUMSUMPROD
+ || op==OpOp1.EXP
+ || op==OpOp1.LOG
+ || op==OpOp1.SIGMOID;
}
@Override
diff --git a/src/main/java/org/apache/sysds/lops/UnaryCP.java b/src/main/java/org/apache/sysds/lops/UnaryCP.java
index c3cdde9..f395a09 100644
--- a/src/main/java/org/apache/sysds/lops/UnaryCP.java
+++ b/src/main/java/org/apache/sysds/lops/UnaryCP.java
@@ -23,17 +23,11 @@ package org.apache.sysds.lops;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.ValueType;
public class UnaryCP extends Lop
{
- @SuppressWarnings("hiding")
- public enum OperationTypes {
- NOT, ABS, SIN, COS, TAN, ASIN, ACOS, ATAN, SQRT, LOG, EXP, SINH, COSH, TANH,
- CAST_AS_SCALAR, CAST_AS_MATRIX, CAST_AS_FRAME, CAST_AS_DOUBLE, CAST_AS_INT, CAST_AS_BOOLEAN,
- PRINT, ASSERT, NROW, NCOL, LENGTH, EXISTS, LINEAGE, ROUND, STOP, CEIL, FLOOR, CUMSUM, SOFTMAX, TYPEOF, DETECTSCHEMA
- }
-
public static final String CAST_AS_SCALAR_OPCODE = "castdts";
public static final String CAST_AS_MATRIX_OPCODE = "castdtm";
public static final String CAST_AS_FRAME_OPCODE = "castdtf";
@@ -41,9 +35,7 @@ public class UnaryCP extends Lop
public static final String CAST_AS_INT_OPCODE = "castvti";
public static final String CAST_AS_BOOLEAN_OPCODE = "castvtb";
-
-
- OperationTypes operation;
+ private OpOp1 operation;
/**
* Constructor to perform a scalar operation
@@ -54,133 +46,31 @@ public class UnaryCP extends Lop
* @param vt value type
* @param et exec type
*/
- public UnaryCP(Lop input, OperationTypes op, DataType dt, ValueType vt, ExecType et) {
+ public UnaryCP(Lop input, OpOp1 op, DataType dt, ValueType vt, ExecType et) {
super(Lop.Type.UnaryCP, dt, vt);
operation = op;
- this.addInput(input);
+ addInput(input);
input.addOutput(this);
lps.setProperties(inputs, et);
}
- public UnaryCP(Lop input, OperationTypes op, DataType dt, ValueType vt) {
+ public UnaryCP(Lop input, OpOp1 op, DataType dt, ValueType vt) {
this(input, op, dt, vt, ExecType.CP);
}
@Override
public String toString() {
-
return "Operation: " + operation;
-
}
private String getOpCode() {
- return getOpCode(operation);
+ return operation.toString();
}
- public static String getOpCode(OperationTypes op) {
- switch (op) {
- case NOT:
- return "!";
-
- case ABS:
- return "abs";
-
- case SIN:
- return "sin";
-
- case COS:
- return "cos";
-
- case TAN:
- return "tan";
-
- case ASIN:
- return "asin";
-
- case ACOS:
- return "acos";
-
- case ATAN:
- return "atan";
-
- case SINH:
- return "sinh";
-
- case COSH:
- return "cosh";
-
- case TANH:
- return "tanh";
-
- case SQRT:
- return "sqrt";
-
- case LOG:
- return "log";
-
- case ROUND:
- return "round";
-
- case EXP:
- return "exp";
-
- case PRINT:
- return "print";
-
- case ASSERT:
- return "assert";
-
- case CAST_AS_MATRIX:
- return CAST_AS_MATRIX_OPCODE;
-
- case CAST_AS_FRAME:
- return CAST_AS_FRAME_OPCODE;
-
- case STOP:
- return "stop";
-
- case CEIL:
- return "ceil";
-
- case FLOOR:
- return "floor";
-
- case CUMSUM:
- return "ucumk+";
-
- // CAST_AS_SCALAR, NROW, NCOL, LENGTH builtins take matrix as the input
- // and produces a scalar
- case CAST_AS_SCALAR:
- return CAST_AS_SCALAR_OPCODE;
-
- case CAST_AS_DOUBLE:
- return CAST_AS_DOUBLE_OPCODE;
-
- case CAST_AS_INT:
- return CAST_AS_INT_OPCODE;
-
- case CAST_AS_BOOLEAN:
- return CAST_AS_BOOLEAN_OPCODE;
-
- case NROW: return "nrow";
- case NCOL: return "ncol";
- case LENGTH: return "length";
- case EXISTS: return "exists";
- case LINEAGE: return "lineage";
-
- case SOFTMAX:
- return "softmax";
-
- default:
- throw new LopsException("Unknown operation: " + op);
- }
- }
-
@Override
public String getInstructions(String input, String output) {
return InstructionUtils.concatOperands(
- getExecType().name(),
- getOpCode(),
+ getExecType().name(), getOpCode(),
getInputs().get(0).prepScalarInputOperand(getExecType()),
prepOutputOperand(output));
}
diff --git a/src/main/java/org/apache/sysds/lops/WeightedUnaryMM.java b/src/main/java/org/apache/sysds/lops/WeightedUnaryMM.java
index f81ee72..ec934c6 100644
--- a/src/main/java/org/apache/sysds/lops/WeightedUnaryMM.java
+++ b/src/main/java/org/apache/sysds/lops/WeightedUnaryMM.java
@@ -21,9 +21,9 @@ package org.apache.sysds.lops;
import org.apache.sysds.lops.LopProperties.ExecType;
-import org.apache.sysds.lops.Unary.OperationTypes;
import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.ValueType;
public class WeightedUnaryMM extends Lop
@@ -37,10 +37,10 @@ public class WeightedUnaryMM extends Lop
}
private WUMMType _wummType = null;
- private OperationTypes _uop = null;
+ private OpOp1 _uop = null;
private int _numThreads = 1;
- public WeightedUnaryMM(Lop input1, Lop input2, Lop input3, DataType dt, ValueType vt, WUMMType wt, OperationTypes op, ExecType et) {
+ public WeightedUnaryMM(Lop input1, Lop input2, Lop input3, DataType dt, ValueType vt, WUMMType wt, OpOp1 op, ExecType et) {
super(Lop.Type.WeightedUMM, dt, vt);
addInput(input1); //X
addInput(input2); //U
@@ -73,7 +73,7 @@ public class WeightedUnaryMM extends Lop
sb.append(OPCODE);
sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(Unary.getOpcode(_uop));
+ sb.append(_uop.toString());
sb.append(Lop.OPERAND_DELIMITOR);
sb.append( getInputs().get(0).prepInputOperand(input1));
diff --git a/src/main/java/org/apache/sysds/lops/WeightedUnaryMMR.java b/src/main/java/org/apache/sysds/lops/WeightedUnaryMMR.java
index f16a16f..4288bcd 100644
--- a/src/main/java/org/apache/sysds/lops/WeightedUnaryMMR.java
+++ b/src/main/java/org/apache/sysds/lops/WeightedUnaryMMR.java
@@ -21,22 +21,22 @@ package org.apache.sysds.lops;
import org.apache.sysds.lops.LopProperties.ExecType;
-import org.apache.sysds.lops.Unary.OperationTypes;
import org.apache.sysds.lops.WeightedUnaryMM.WUMMType;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.ValueType;
-public class WeightedUnaryMMR extends Lop
+public class WeightedUnaryMMR extends Lop
{
public static final String OPCODE = "redwumm";
private WUMMType _wummType = null;
- private OperationTypes _uop = null;
+ private OpOp1 _uop = null;
private boolean _cacheU = false;
private boolean _cacheV = false;
- public WeightedUnaryMMR(Lop input1, Lop input2, Lop input3, DataType dt, ValueType vt, WUMMType wt, OperationTypes op, boolean cacheU, boolean cacheV, ExecType et) {
+ public WeightedUnaryMMR(Lop input1, Lop input2, Lop input3, DataType dt, ValueType vt, WUMMType wt, OpOp1 op, boolean cacheU, boolean cacheV, ExecType et) {
super(Lop.Type.WeightedUMM, dt, vt);
addInput(input1); //X
addInput(input2); //U
@@ -62,8 +62,7 @@ public class WeightedUnaryMMR extends Lop
public String getInstructions(String input1, String input2, String input3, String output) {
return InstructionUtils.concatOperands(
getExecType().name(),
- OPCODE,
- Unary.getOpcode(_uop),
+ OPCODE, _uop.toString(),
getInputs().get(0).prepInputOperand(input1),
getInputs().get(1).prepInputOperand(input2),
getInputs().get(2).prepInputOperand(input3),
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 0b92e18..9e41f9b 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -40,8 +40,6 @@ import org.apache.sysds.hops.DnnOp;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.FunctionOp.FunctionType;
import org.apache.sysds.hops.Hop;
-import org.apache.sysds.hops.Hop.OpOp1;
-import org.apache.sysds.hops.Hop.OpOp2;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.IndexingOp;
import org.apache.sysds.hops.LeftIndexingOp;
@@ -68,6 +66,8 @@ import org.apache.sysds.common.Builtins;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.Direction;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOp3;
import org.apache.sysds.common.Types.OpOpDG;
import org.apache.sysds.common.Types.OpOpData;
@@ -1027,21 +1027,23 @@ public class DMLTranslator
try {
if (ptype == PRINTTYPE.PRINT) {
- Hop.OpOp1 op = Hop.OpOp1.PRINT;
+ OpOp1 op = OpOp1.PRINT;
Expression source = ps.getExpressions().get(0);
Hop ae = processExpression(source, target, ids);
Hop printHop = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), op, ae);
printHop.setParseInfo(current);
output.add(printHop);
- } else if (ptype == PRINTTYPE.ASSERT) {
- Hop.OpOp1 op = Hop.OpOp1.ASSERT;
+ }
+ else if (ptype == PRINTTYPE.ASSERT) {
+ OpOp1 op = OpOp1.ASSERT;
Expression source = ps.getExpressions().get(0);
Hop ae = processExpression(source, target, ids);
Hop printHop = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), op, ae);
printHop.setParseInfo(current);
output.add(printHop);
- } else if (ptype == PRINTTYPE.STOP) {
- Hop.OpOp1 op = Hop.OpOp1.STOP;
+ }
+ else if (ptype == PRINTTYPE.STOP) {
+ OpOp1 op = OpOp1.STOP;
Expression source = ps.getExpressions().get(0);
Hop ae = processExpression(source, target, ids);
Hop stopHop = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), op, ae);
@@ -1583,7 +1585,7 @@ public class DMLTranslator
if ( target.getDim1() != -1 )
rowUpperHops = new LiteralOp(target.getOrigDim1());
else {
- rowUpperHops = new UnaryOp(target.getName(), DataType.SCALAR, ValueType.INT64, Hop.OpOp1.NROW, hops.get(target.getName()));
+ rowUpperHops = new UnaryOp(target.getName(), DataType.SCALAR, ValueType.INT64, OpOp1.NROW, hops.get(target.getName()));
rowUpperHops.setParseInfo(target);
}
}
@@ -1599,7 +1601,7 @@ public class DMLTranslator
if ( target.getDim2() != -1 )
colUpperHops = new LiteralOp(target.getOrigDim2());
else
- colUpperHops = new UnaryOp(target.getName(), DataType.SCALAR, ValueType.INT64, Hop.OpOp1.NCOL, hops.get(target.getName()));
+ colUpperHops = new UnaryOp(target.getName(), DataType.SCALAR, ValueType.INT64, OpOp1.NCOL, hops.get(target.getName()));
}
// process the source expression to get source Hops
@@ -1644,7 +1646,7 @@ public class DMLTranslator
if ( source.getOrigDim1() != -1 )
rowUpperHops = new LiteralOp(source.getOrigDim1());
else {
- rowUpperHops = new UnaryOp(source.getName(), DataType.SCALAR, ValueType.INT64, Hop.OpOp1.NROW, hops.get(source.getName()));
+ rowUpperHops = new UnaryOp(source.getName(), DataType.SCALAR, ValueType.INT64, OpOp1.NROW, hops.get(source.getName()));
rowUpperHops.setParseInfo(source);
}
}
@@ -1660,7 +1662,7 @@ public class DMLTranslator
if ( source.getOrigDim2() != -1 )
colUpperHops = new LiteralOp(source.getOrigDim2());
else
- colUpperHops = new UnaryOp(source.getName(), DataType.SCALAR, ValueType.INT64, Hop.OpOp1.NCOL, hops.get(source.getName()));
+ colUpperHops = new UnaryOp(source.getName(), DataType.SCALAR, ValueType.INT64, OpOp1.NCOL, hops.get(source.getName()));
}
if (target == null) {
@@ -1804,7 +1806,7 @@ public class DMLTranslator
target.setValueType(ValueType.BOOLEAN);
if (source.getRight() == null) {
- Hop currUop = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.NOT, left);
+ Hop currUop = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp1.NOT, left);
currUop.setParseInfo(source);
return currUop;
}
@@ -2203,7 +2205,7 @@ public class DMLTranslator
currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX,
target.getValueType(), AggOp.VAR, Direction.Col, expr);
currBuiltinOp = new UnaryOp(target.getName(), DataType.MATRIX,
- target.getValueType(), Hop.OpOp1.SQRT, currBuiltinOp);
+ target.getValueType(), OpOp1.SQRT, currBuiltinOp);
break;
case ROWSUM:
@@ -2231,32 +2233,32 @@ public class DMLTranslator
currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX,
target.getValueType(), AggOp.VAR, Direction.Row, expr);
currBuiltinOp = new UnaryOp(target.getName(), DataType.MATRIX,
- target.getValueType(), Hop.OpOp1.SQRT, currBuiltinOp);
+ target.getValueType(), OpOp1.SQRT, currBuiltinOp);
break;
case NROW:
// If the dimensions are available at compile time, then create a LiteralOp (constant propagation)
// Else create a UnaryOp so that a control program instruction is generated
currBuiltinOp = (expr.getDim1()==-1) ? new UnaryOp(target.getName(), target.getDataType(),
- target.getValueType(), Hop.OpOp1.NROW, expr) : new LiteralOp(expr.getDim1());
+ target.getValueType(), OpOp1.NROW, expr) : new LiteralOp(expr.getDim1());
break;
case NCOL:
// If the dimensions are available at compile time, then create a LiteralOp (constant propagation)
// Else create a UnaryOp so that a control program instruction is generated
currBuiltinOp = (expr.getDim2()==-1) ? new UnaryOp(target.getName(), target.getDataType(),
- target.getValueType(), Hop.OpOp1.NCOL, expr) : new LiteralOp(expr.getDim2());
+ target.getValueType(), OpOp1.NCOL, expr) : new LiteralOp(expr.getDim2());
break;
case LENGTH:
// If the dimensions are available at compile time, then create a LiteralOp (constant propagation)
// Else create a UnaryOp so that a control program instruction is generated
currBuiltinOp = (expr.getDim1()==-1 || expr.getDim2()==-1) ? new UnaryOp(target.getName(), target.getDataType(),
- target.getValueType(), Hop.OpOp1.LENGTH, expr) : new LiteralOp(expr.getDim1()*expr.getDim2());
+ target.getValueType(), OpOp1.LENGTH, expr) : new LiteralOp(expr.getDim1()*expr.getDim2());
break;
case LINEAGE:
//construct hop and enable lineage tracing if necessary
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(),
- target.getValueType(), Hop.OpOp1.LINEAGE, expr);
+ target.getValueType(), OpOp1.LINEAGE, expr);
DMLScript.LINEAGE = true;
break;
@@ -2267,7 +2269,7 @@ public class DMLTranslator
case EXISTS:
currBuiltinOp = new UnaryOp(target.getName(), DataType.SCALAR,
- target.getValueType(), Hop.OpOp1.EXISTS, expr);
+ target.getValueType(), OpOp1.EXISTS, expr);
break;
case SUM:
@@ -2298,7 +2300,7 @@ public class DMLTranslator
target.getValueType(), AggOp.VAR, Direction.RowCol, expr);
HopRewriteUtils.setOutputParametersForScalar(currBuiltinOp);
currBuiltinOp = new UnaryOp(target.getName(), DataType.SCALAR,
- target.getValueType(), Hop.OpOp1.SQRT, currBuiltinOp);
+ target.getValueType(), OpOp1.SQRT, currBuiltinOp);
break;
case MIN:
@@ -2404,24 +2406,24 @@ public class DMLTranslator
//data type casts
case CAST_AS_SCALAR:
- currBuiltinOp = new UnaryOp(target.getName(), DataType.SCALAR, target.getValueType(), Hop.OpOp1.CAST_AS_SCALAR, expr);
+ currBuiltinOp = new UnaryOp(target.getName(), DataType.SCALAR, target.getValueType(), OpOp1.CAST_AS_SCALAR, expr);
break;
case CAST_AS_MATRIX:
- currBuiltinOp = new UnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), Hop.OpOp1.CAST_AS_MATRIX, expr);
+ currBuiltinOp = new UnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), OpOp1.CAST_AS_MATRIX, expr);
break;
case CAST_AS_FRAME:
- currBuiltinOp = new UnaryOp(target.getName(), DataType.FRAME, target.getValueType(), Hop.OpOp1.CAST_AS_FRAME, expr);
+ currBuiltinOp = new UnaryOp(target.getName(), DataType.FRAME, target.getValueType(), OpOp1.CAST_AS_FRAME, expr);
break;
//value type casts
case CAST_AS_DOUBLE:
- currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.FP64, Hop.OpOp1.CAST_AS_DOUBLE, expr);
+ currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.FP64, OpOp1.CAST_AS_DOUBLE, expr);
break;
case CAST_AS_INT:
- currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.INT64, Hop.OpOp1.CAST_AS_INT, expr);
+ currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.INT64, OpOp1.CAST_AS_INT, expr);
break;
case CAST_AS_BOOLEAN:
- currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.BOOLEAN, Hop.OpOp1.CAST_AS_BOOLEAN, expr);
+ currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.BOOLEAN, OpOp1.CAST_AS_BOOLEAN, expr);
break;
// Boolean binary
@@ -2468,10 +2470,10 @@ public class DMLTranslator
case LOG:
if (expr2 == null) {
- Hop.OpOp1 mathOp2;
+ OpOp1 mathOp2;
switch (source.getOpCode()) {
case LOG:
- mathOp2 = Hop.OpOp1.LOG;
+ mathOp2 = OpOp1.LOG;
break;
default:
throw new ParseException(source.printErrorLocation() +
@@ -2481,10 +2483,10 @@ public class DMLTranslator
currBuiltinOp = new UnaryOp(target.getName(),
target.getDataType(), target.getValueType(), mathOp2, expr);
} else {
- Hop.OpOp2 mathOp3;
+ OpOp2 mathOp3;
switch (source.getOpCode()) {
case LOG:
- mathOp3 = Hop.OpOp2.LOG;
+ mathOp3 = OpOp2.LOG;
break;
default:
throw new ParseException(source.printErrorLocation() +
@@ -2575,7 +2577,7 @@ public class DMLTranslator
}
case SOLVE:
- currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp2.SOLVE, expr, expr2);
+ currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.SOLVE, expr, expr2);
break;
case INVERSE:
@@ -2589,7 +2591,7 @@ public class DMLTranslator
case OUTER:
if( !(expr3 instanceof LiteralOp) )
throw new HopsException("Operator for outer builtin function must be a constant: "+expr3);
- OpOp2 op = Hop.getOpOp2ForOuterVectorOperation(((LiteralOp)expr3).getStringValue());
+ OpOp2 op = OpOp2.valueOfByOpcode(((LiteralOp)expr3).getStringValue());
if( op == null )
throw new HopsException("Unsupported outer vector binary operation: "+((LiteralOp)expr3).getStringValue());
diff --git a/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java b/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java
index 8392d61..7bd3793 100644
--- a/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java
@@ -34,11 +34,11 @@ import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.IndexingOp;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.OptimizerUtils;
-import org.apache.sysds.hops.Hop.OpOp1;
-import org.apache.sysds.hops.Hop.OpOp2;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.common.Builtins;
import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.parser.Expression.BinaryOp;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index c59c350..a47d6de 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -107,7 +107,14 @@ import org.apache.sysds.runtime.matrix.operators.CMOperator.AggregateOperationTy
public class InstructionUtils
{
-
+ //thread-local string builders for instruction concatenation (avoid allocation)
+ private static ThreadLocal<StringBuilder> _strBuilders = new ThreadLocal<StringBuilder>() {
+ @Override
+ protected StringBuilder initialValue() {
+ return new StringBuilder(64);
+ }
+ };
+
public static int checkNumFields( String str, int expected ) {
//note: split required for empty tokens
int numParts = str.split(Instruction.OPERAND_DELIM).length;
@@ -992,7 +999,8 @@ public class InstructionUtils
}
public static String concatOperands(String... inputs) {
- StringBuilder sb = new StringBuilder(64);
+ StringBuilder sb = _strBuilders.get();
+ sb.setLength(0); //reuse allocated space
for( int i=0; i<inputs.length-1; i++ ) {
sb.append(inputs[i]);
sb.append(Lop.OPERAND_DELIMITOR);
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
index 6a57f38..eade225 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
@@ -50,7 +50,6 @@ import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.hops.codegen.SpoofFusedOp;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
-import org.apache.sysds.lops.Binary;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.PartialAggregate;
import org.apache.sysds.lops.UnaryCP;
@@ -424,13 +423,13 @@ public class LineageItemUtils {
if (root instanceof ReorgOp)
li = new LineageItem(name, "r'", LIinputs);
else if (root instanceof UnaryOp) {
- String opcode = UnaryCP.getOpCode(Hop.HopsOpOp1LopsUS.get(((UnaryOp) root).getOp()));
+ String opcode = ((UnaryOp) root).getOp().toString();
li = new LineageItem(name, opcode, LIinputs);
}
else if (root instanceof AggBinaryOp)
li = new LineageItem(name, "ba+*", LIinputs);
else if (root instanceof BinaryOp)
- li = new LineageItem(name, Binary.getOpcode(Hop.HopsOpOp2LopsB.get(((BinaryOp)root).getOp())), LIinputs);
+ li = new LineageItem(name, ((BinaryOp)root).getOp().toString(), LIinputs);
else if (root instanceof TernaryOp) {
String opcode = ((TernaryOp) root).getOp().toString();
li = new LineageItem(name, opcode, LIinputs);
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
index f1b8c58..d400623 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
@@ -27,6 +27,7 @@ import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOpN;
import org.apache.sysds.common.Types.ParamBuiltinOp;
import org.apache.sysds.common.Types.ValueType;
@@ -34,7 +35,6 @@ import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
-import org.apache.sysds.hops.Hop.OpOp2;
import org.apache.sysds.hops.IndexingOp;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.NaryOp;
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
index 8530adf..beca629 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
@@ -22,7 +22,7 @@ package org.apache.sysds.runtime.matrix.operators;
import java.io.Serializable;
-import org.apache.sysds.hops.Hop.OpOp2;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.runtime.functionobjects.And;
import org.apache.sysds.runtime.functionobjects.BitwAnd;
import org.apache.sysds.runtime.functionobjects.BitwOr;
@@ -74,43 +74,41 @@ public class BinaryOperator extends Operator implements Serializable
*
* @return binary operator type for a function object
*/
- public OpOp2 getBinaryOperatorOpOp2()
- {
- if( fn instanceof Plus ) return OpOp2.PLUS;
- else if( fn instanceof Minus ) return OpOp2.MINUS;
- else if( fn instanceof Multiply ) return OpOp2.MULT;
- else if( fn instanceof Divide ) return OpOp2.DIV;
- else if( fn instanceof Modulus ) return OpOp2.MODULUS;
- else if( fn instanceof IntegerDivide ) return OpOp2.INTDIV;
- else if( fn instanceof LessThan ) return OpOp2.LESS;
- else if( fn instanceof LessThanEquals ) return OpOp2.LESSEQUAL;
- else if( fn instanceof GreaterThan ) return OpOp2.GREATER;
- else if( fn instanceof GreaterThanEquals ) return OpOp2.GREATEREQUAL;
- else if( fn instanceof Equals ) return OpOp2.EQUAL;
- else if( fn instanceof NotEquals ) return OpOp2.NOTEQUAL;
- else if( fn instanceof And ) return OpOp2.AND;
- else if( fn instanceof Or ) return OpOp2.OR;
- else if( fn instanceof Xor ) return OpOp2.XOR;
- else if( fn instanceof BitwAnd ) return OpOp2.BITWAND;
- else if( fn instanceof BitwOr ) return OpOp2.BITWOR;
- else if( fn instanceof BitwXor ) return OpOp2.BITWXOR;
- else if( fn instanceof BitwShiftL ) return OpOp2.BITWSHIFTL;
- else if( fn instanceof BitwShiftR ) return OpOp2.BITWSHIFTR;
- else if( fn instanceof Power ) return OpOp2.POW;
- else if( fn instanceof MinusNz ) return OpOp2.MINUS_NZ;
+ public OpOp2 getBinaryOperatorOpOp2() {
+ if( fn instanceof Plus ) return OpOp2.PLUS;
+ else if( fn instanceof Minus ) return OpOp2.MINUS;
+ else if( fn instanceof Multiply ) return OpOp2.MULT;
+ else if( fn instanceof Divide ) return OpOp2.DIV;
+ else if( fn instanceof Modulus ) return OpOp2.MODULUS;
+ else if( fn instanceof IntegerDivide ) return OpOp2.INTDIV;
+ else if( fn instanceof LessThan ) return OpOp2.LESS;
+ else if( fn instanceof LessThanEquals ) return OpOp2.LESSEQUAL;
+ else if( fn instanceof GreaterThan ) return OpOp2.GREATER;
+ else if( fn instanceof GreaterThanEquals ) return OpOp2.GREATEREQUAL;
+ else if( fn instanceof Equals ) return OpOp2.EQUAL;
+ else if( fn instanceof NotEquals ) return OpOp2.NOTEQUAL;
+ else if( fn instanceof And ) return OpOp2.AND;
+ else if( fn instanceof Or ) return OpOp2.OR;
+ else if( fn instanceof Xor ) return OpOp2.XOR;
+ else if( fn instanceof BitwAnd ) return OpOp2.BITWAND;
+ else if( fn instanceof BitwOr ) return OpOp2.BITWOR;
+ else if( fn instanceof BitwXor ) return OpOp2.BITWXOR;
+ else if( fn instanceof BitwShiftL ) return OpOp2.BITWSHIFTL;
+ else if( fn instanceof BitwShiftR ) return OpOp2.BITWSHIFTR;
+ else if( fn instanceof Power ) return OpOp2.POW;
+ else if( fn instanceof MinusNz ) return OpOp2.MINUS_NZ;
else if( fn instanceof Builtin ) {
BuiltinCode bfc = ((Builtin) fn).getBuiltinCode();
- if( bfc == BuiltinCode.MIN ) return OpOp2.MIN;
- else if( bfc == BuiltinCode.MAX ) return OpOp2.MAX;
- else if( bfc == BuiltinCode.LOG ) return OpOp2.LOG;
- else if( bfc == BuiltinCode.LOG_NZ ) return OpOp2.LOG_NZ;
+ if( bfc == BuiltinCode.MIN ) return OpOp2.MIN;
+ else if( bfc == BuiltinCode.MAX ) return OpOp2.MAX;
+ else if( bfc == BuiltinCode.LOG ) return OpOp2.LOG;
+ else if( bfc == BuiltinCode.LOG_NZ ) return OpOp2.LOG_NZ;
}
//non-supported ops (not required for sparsity estimates):
//PRINT, CONCAT, QUANTILE, INTERQUANTILE, IQM,
//CENTRALMOMENT, COVARIANCE, APPEND, SOLVE, MEDIAN,
-
- return OpOp2.INVALID;
+ return null;
}
@Override
diff --git a/src/test/java/org/apache/sysds/test/component/codegen/CPlanVectorPrimitivesTest.java b/src/test/java/org/apache/sysds/test/component/codegen/CPlanVectorPrimitivesTest.java
index 2592aa5..4e153f8 100644
--- a/src/test/java/org/apache/sysds/test/component/codegen/CPlanVectorPrimitivesTest.java
+++ b/src/test/java/org/apache/sysds/test/component/codegen/CPlanVectorPrimitivesTest.java
@@ -22,8 +22,7 @@ package org.apache.sysds.test.component.codegen;
import java.lang.reflect.Method;
import org.junit.Test;
-import org.apache.sysds.hops.Hop;
-import org.apache.sysds.hops.Hop.OpOp2;
+import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.hops.codegen.cplan.CNodeBinary.BinType;
import org.apache.sysds.hops.codegen.cplan.CNodeUnary.UnaryType;
import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;
@@ -839,7 +838,7 @@ public class CPlanVectorPrimitivesTest extends AutomatedTestBase
inA.getSparseBlock().indexes(i), inA.getSparseBlock().pos(i), i*n, inA.getSparseBlock().size(i), n);
//execute comparison operation
- String opcode = Hop.getBinaryOpCode(OpOp2.valueOf(bintype.name().split("_")[1]));
+ String opcode = OpOp2.valueOf(bintype.name().split("_")[1]).toString();
MatrixBlock in1 = inA.slice(i, i, 0, n-1, new MatrixBlock());
MatrixBlock in2 = inB.slice(i, i, 0, n-1, new MatrixBlock());
double[] ret2 = null;