You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2015/11/23 04:53:02 UTC
[7/8] incubator-systemml git commit: New wumm quaternary op (rewrite,
cp/mr/sp compiler/runtime, tests, docs)
New wumm quaternary op (rewrite, cp/mr/sp compiler/runtime, tests, docs)
This change adds a new quaternary operation 'wumm' for the pattern
X*uop(L%*%t(R)), where uop is an arbitrary unary operator (with a few
exceptions) or a matrix-scalar/scalar-matrix operation that is internally
mapped to a unary operator (e.g., X^2, 2*X).
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/d70c4524
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/d70c4524
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/d70c4524
Branch: refs/heads/master
Commit: d70c4524726386ab6dec80b21914e60f80e52af1
Parents: e52e0c0
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Sat Nov 21 22:00:15 2015 -0800
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Sun Nov 22 19:38:28 2015 -0800
----------------------------------------------------------------------
docs/devdocs/MatrixMultiplicationOperators.txt | 18 +-
src/main/java/com/ibm/bi/dml/hops/Hop.java | 1 +
.../java/com/ibm/bi/dml/hops/QuaternaryOp.java | 265 +++++++++++++-
.../RewriteAlgebraicSimplificationDynamic.java | 112 ++++++
src/main/java/com/ibm/bi/dml/lops/Lop.java | 2 +-
src/main/java/com/ibm/bi/dml/lops/Unary.java | 27 +-
.../ibm/bi/dml/lops/WeightedCrossEntropy.java | 30 +-
.../ibm/bi/dml/lops/WeightedCrossEntropyR.java | 36 +-
.../com/ibm/bi/dml/lops/WeightedSigmoid.java | 30 +-
.../com/ibm/bi/dml/lops/WeightedSigmoidR.java | 36 +-
.../com/ibm/bi/dml/lops/WeightedUnaryMM.java | 165 +++++++++
.../com/ibm/bi/dml/lops/WeightedUnaryMMR.java | 162 +++++++++
.../dml/runtime/functionobjects/Multiply2.java | 6 +-
.../bi/dml/runtime/functionobjects/Power2.java | 5 +
.../runtime/functionobjects/ValueFunction.java | 1 -
.../instructions/CPInstructionParser.java | 1 +
.../runtime/instructions/InstructionUtils.java | 8 +-
.../instructions/MRInstructionParser.java | 4 +
.../instructions/SPInstructionParser.java | 4 +
.../cp/QuaternaryCPInstruction.java | 16 +-
.../instructions/mr/QuaternaryInstruction.java | 34 +-
.../spark/QuaternarySPInstruction.java | 34 +-
.../dml/runtime/matrix/data/LibMatrixMult.java | 341 +++++++++++++++++++
.../bi/dml/runtime/matrix/data/MatrixBlock.java | 8 +-
.../matrix/operators/QuaternaryOperator.java | 25 +-
.../quaternary/WeightedUnaryMatrixMultTest.java | 284 +++++++++++++++
.../quaternary/WeightedUnaryMMExpDiv.R | 33 ++
.../quaternary/WeightedUnaryMMExpDiv.dml | 27 ++
.../quaternary/WeightedUnaryMMExpMult.R | 33 ++
.../quaternary/WeightedUnaryMMExpMult.dml | 27 ++
.../functions/quaternary/WeightedUnaryMMMult2.R | 33 ++
.../quaternary/WeightedUnaryMMMult2.dml | 27 ++
.../functions/quaternary/WeightedUnaryMMPow2.R | 33 ++
.../quaternary/WeightedUnaryMMPow2.dml | 27 ++
.../functions/quaternary/ZPackageSuite.java | 3 +-
35 files changed, 1764 insertions(+), 134 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d70c4524/docs/devdocs/MatrixMultiplicationOperators.txt
----------------------------------------------------------------------
diff --git a/docs/devdocs/MatrixMultiplicationOperators.txt b/docs/devdocs/MatrixMultiplicationOperators.txt
index 54cb5ec..7bc8a9c 100644
--- a/docs/devdocs/MatrixMultiplicationOperators.txt
+++ b/docs/devdocs/MatrixMultiplicationOperators.txt
@@ -1,6 +1,6 @@
#####################################################################
# TITLE: An Overview of Matrix Multiplication Operators in SystemML #
-# DATE MODIFIED: 09/26/2015 #
+# DATE MODIFIED: 11/21/2015 #
#####################################################################
In the following, we give an overview of backend-specific physical matrix multiplication operators in SystemML as well as their internally used matrix multiplication block operations.
@@ -48,6 +48,7 @@ A QuaternaryOp hop can be compiled into the following physical operators. Note t
- WSigmoid (weighted sigmoid) --> wsigmoid
- WDivMM (weighted divide matrix multiplication) --> wdivmm
- WCeMM (weighted cross entropy matrix multiplication) --> wcemm
+ - WuMM (weighted unary op matrix multiplication) --> wumm
* 2) Physical Operator in MR (distributed, mapreduce)
- MapWSLoss (map-side weighted squared loss) --> wsloss
@@ -58,6 +59,8 @@ A QuaternaryOp hop can be compiled into the following physical operators. Note t
- RedWDivMM (reduce-side weighted divide matrix mult) --> wdivmm
- MapWCeMM (map-side weighted cross entr. matrix mult) --> wcemm
- RedWCeMM (reduce-side w. cross entr. matrix mult) --> wcemm
+ - MapWuMM (map-side weighted unary op matrix mult) --> wumm
+ - RedWuMM (reduce-side weighted unary op matrix mult) --> wumm
* 3) Physical Operators in SPARK (distributed, spark)
- MapWSLoss (see MR, mappartitions + reduce) --> wsloss
@@ -70,8 +73,11 @@ A QuaternaryOp hop can be compiled into the following physical operators. Note t
- RedWDivMM (see MR, 1/2x flatmaptopair + 1/2x join + --> wdivmm
maptopair + reducebykey)
- MapWCeMM (see MR, mappartitions + reduce) --> wcemm
- - RedWDivMM (see MR, 1/2x flatmaptopair + 1/2x join + --> wcemm
+ - RedWCeMM (see MR, 1/2x flatmaptopair + 1/2x join + --> wcemm
maptopair + reduce)
+ - MapWuMM (see MR, mappartitions) --> wumm
+ - RedWuMM (see MR, 1/2x flatmaptopair + --> wumm
+ 1/2x join + maptopair)
C) CORE MATRIX MULT PRIMITIVES LibMatrixMult (incl related script patterns)
@@ -112,9 +118,11 @@ C) CORE MATRIX MULT PRIMITIVES LibMatrixMult (incl related script patterns)
- sequential / multi-threaded (same block ops, par over rows in X)
- all dense, sparse-dense factors, sparse/dense-* x 7 patterns
-* 8) wcemm (sum(X*log(U%*%t(V))))
+* 8) wcemm (sum(X*log(U%*%t(V))))
- sequential / multi-threaded (same block ops, par over rows in X)
- all dense, sparse-dense factors, sparse/dense-*, 1 pattern
-
-
\ No newline at end of file
+* 9) wumm ((a) X*uop(U%*%t(V)), (b) X/uop(U%*%t(V)))
+ - any unary operator, e.g., X*exp(U%*%t(V)) or X*(U%*%t(V))^2
+ - sequential / multi-threaded (same block ops, par over rows in X)
+ - all dense, sparse-dense factors, sparse/dense-*, 2 patterns
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d70c4524/src/main/java/com/ibm/bi/dml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/hops/Hop.java b/src/main/java/com/ibm/bi/dml/hops/Hop.java
index ab328a8..b968952 100644
--- a/src/main/java/com/ibm/bi/dml/hops/Hop.java
+++ b/src/main/java/com/ibm/bi/dml/hops/Hop.java
@@ -1054,6 +1054,7 @@ public abstract class Hop
WSIGMOID, //weighted sigmoid mm
WDIVMM, //weighted divide mm
WCEMM, //weighted cross entropy mm
+ WUMM, //weighted unary mm
INVALID
};
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d70c4524/src/main/java/com/ibm/bi/dml/hops/QuaternaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/hops/QuaternaryOp.java b/src/main/java/com/ibm/bi/dml/hops/QuaternaryOp.java
index db0df09..fd329f0 100644
--- a/src/main/java/com/ibm/bi/dml/hops/QuaternaryOp.java
+++ b/src/main/java/com/ibm/bi/dml/hops/QuaternaryOp.java
@@ -25,6 +25,7 @@ import com.ibm.bi.dml.lops.Lop;
import com.ibm.bi.dml.lops.LopsException;
import com.ibm.bi.dml.lops.RepMat;
import com.ibm.bi.dml.lops.Transform;
+import com.ibm.bi.dml.lops.Unary;
import com.ibm.bi.dml.lops.UnaryCP;
import com.ibm.bi.dml.lops.LopProperties.ExecType;
import com.ibm.bi.dml.lops.PartialAggregate.CorrectionLocationType;
@@ -40,6 +41,9 @@ import com.ibm.bi.dml.lops.WeightedSigmoidR;
import com.ibm.bi.dml.lops.WeightedSquaredLoss;
import com.ibm.bi.dml.lops.WeightedSquaredLoss.WeightsType;
import com.ibm.bi.dml.lops.WeightedSquaredLossR;
+import com.ibm.bi.dml.lops.WeightedUnaryMM;
+import com.ibm.bi.dml.lops.WeightedUnaryMM.WUMMType;
+import com.ibm.bi.dml.lops.WeightedUnaryMMR;
import com.ibm.bi.dml.parser.Expression.DataType;
import com.ibm.bi.dml.parser.Expression.ValueType;
import com.ibm.bi.dml.runtime.controlprogram.ParForProgramBlock.PDataPartitionFormat;
@@ -72,6 +76,11 @@ public class QuaternaryOp extends Hop implements MultiThreadedHop
private boolean _mult = false;
private boolean _minus = false;
+ //wumm-specific attributes
+ private boolean _umult = false;
+ private OpOp1 _uop = null;
+ private OpOp2 _sop = null;
+
private QuaternaryOp() {
//default constructor for clone
}
@@ -131,6 +140,16 @@ public class QuaternaryOp extends Hop implements MultiThreadedHop
_minus = flag2;
}
+ public QuaternaryOp(String l, DataType dt, ValueType vt, Hop.OpOp4 o,
+ Hop inW, Hop inU, Hop inV, boolean umult, OpOp1 uop, OpOp2 sop)
+ {
+ this(l, dt, vt, o, inW, inU, inV);
+
+ _umult = umult;
+ _uop = uop;
+ _sop = sop;
+ }
+
/**
*
* @param l
@@ -235,6 +254,20 @@ public class QuaternaryOp extends Hop implements MultiThreadedHop
break;
}
+ case WUMM:{
+ WUMMType wtype = _umult ? WUMMType.MULT : WUMMType.DIV;
+
+ if( et == ExecType.CP )
+ constructCPLopsWeightedUMM(wtype);
+ else if( et == ExecType.MR )
+ constructMRLopsWeightedUMM(wtype);
+ else if( et == ExecType.SPARK )
+ constructSparkLopsWeightedUMM(wtype);
+ else
+ throw new HopsException("Unsupported quaternaryop-wumm exec type: "+et);
+ break;
+ }
+
default:
throw new HopsException(this.printErrorLocation() + "Unknown QuaternaryOp (" + _op + ") while constructing Lops");
}
@@ -1165,6 +1198,225 @@ public class QuaternaryOp extends Hop implements MultiThreadedHop
setLops(wcemm);
}
}
+
+ /**
+ *
+ * @param wtype
+ * @throws HopsException
+ * @throws LopsException
+ */
+ private void constructCPLopsWeightedUMM(WUMMType wtype)
+ throws HopsException, LopsException
+ {
+ Unary.OperationTypes uop = _uop!=null ?
+ HopsOpOp1LopsU.get(_uop) : _sop==OpOp2.POW ?
+ Unary.OperationTypes.POW2 : Unary.OperationTypes.MULTIPLY2;
+
+ WeightedUnaryMM wsig = new WeightedUnaryMM(
+ getInput().get(0).constructLops(),
+ getInput().get(1).constructLops(),
+ getInput().get(2).constructLops(),
+ getDataType(), getValueType(), wtype, uop, ExecType.CP);
+
+ //set degree of parallelism
+ int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
+ wsig.setNumThreads(k);
+
+ setOutputDimensions( wsig );
+ setLineNumbers( wsig );
+ setLops( wsig );
+ }
+
+ /**
+ *
+ * @param wtype
+ * @throws HopsException
+ * @throws LopsException
+ */
+ private void constructMRLopsWeightedUMM( WUMMType wtype )
+ throws HopsException, LopsException
+ {
+ //NOTE: the common case for wumm is factors U/V with a rank of 10s to 100s; the current runtime only
+ //supports single block outer products (U/V rank <= blocksize, i.e., 1000 by default); we enforce this
+ //by applying the hop rewrite for weighted unary mm only if this constraint holds.
+
+ Unary.OperationTypes uop = _uop!=null ?
+ HopsOpOp1LopsU.get(_uop) : _sop==OpOp2.POW ?
+ Unary.OperationTypes.POW2 : Unary.OperationTypes.MULTIPLY2;
+
+ Hop X = getInput().get(0);
+ Hop U = getInput().get(1);
+ Hop V = getInput().get(2);
+
+ //MR operator selection, part1
+ double m1Size = OptimizerUtils.estimateSize(U.getDim1(), U.getDim2()); //size U
+ double m2Size = OptimizerUtils.estimateSize(V.getDim1(), V.getDim2()); //size V
+ boolean isMapWsloss = (m1Size+m2Size < OptimizerUtils.getRemoteMemBudgetMap(true));
+
+ if( !FORCE_REPLICATION && isMapWsloss ) //broadcast
+ {
+ //partitioning of U
+ boolean needPartU = !U.dimsKnown() || U.getDim1() * U.getDim2() > DistributedCacheInput.PARTITION_SIZE;
+ Lop lU = U.constructLops();
+ if( needPartU ){ //requires partitioning
+ lU = new DataPartition(lU, DataType.MATRIX, ValueType.DOUBLE, (m1Size>OptimizerUtils.getLocalMemBudget())?ExecType.MR:ExecType.CP, PDataPartitionFormat.ROW_BLOCK_WISE_N);
+ lU.getOutputParameters().setDimensions(U.getDim1(), U.getDim2(), getRowsInBlock(), getColsInBlock(), U.getNnz());
+ setLineNumbers(lU);
+ }
+
+ //partitioning of V
+ boolean needPartV = !V.dimsKnown() || V.getDim1() * V.getDim2() > DistributedCacheInput.PARTITION_SIZE;
+ Lop lV = V.constructLops();
+ if( needPartV ){ //requires partitioning
+ lV = new DataPartition(lV, DataType.MATRIX, ValueType.DOUBLE, (m2Size>OptimizerUtils.getLocalMemBudget())?ExecType.MR:ExecType.CP, PDataPartitionFormat.ROW_BLOCK_WISE_N);
+ lV.getOutputParameters().setDimensions(V.getDim1(), V.getDim2(), getRowsInBlock(), getColsInBlock(), V.getNnz());
+ setLineNumbers(lV);
+ }
+
+ //map-side wumm always with broadcast
+ Lop wumm = new WeightedUnaryMM( X.constructLops(), lU, lV,
+ DataType.MATRIX, ValueType.DOUBLE, wtype, uop, ExecType.MR);
+ setOutputDimensions(wumm);
+ setLineNumbers(wumm);
+ setLops( wumm );
+
+ //in contrast to wsloss no aggregation required
+ }
+ else //general case
+ {
+ //MR operator selection part 2
+ boolean cacheU = !FORCE_REPLICATION && (m1Size < OptimizerUtils.getRemoteMemBudgetReduce());
+ boolean cacheV = !FORCE_REPLICATION && ((!cacheU && m2Size < OptimizerUtils.getRemoteMemBudgetReduce())
+ || (cacheU && m1Size+m2Size < OptimizerUtils.getRemoteMemBudgetReduce()));
+
+ Group grpX = new Group(X.constructLops(), Group.OperationTypes.Sort, DataType.MATRIX, ValueType.DOUBLE);
+ grpX.getOutputParameters().setDimensions(X.getDim1(), X.getDim2(), X.getRowsInBlock(), X.getColsInBlock(), X.getNnz());
+ setLineNumbers(grpX);
+
+ Lop lU = null;
+ if( cacheU ) {
+ //partitioning of U for read through distributed cache
+ boolean needPartU = !U.dimsKnown() || U.getDim1() * U.getDim2() > DistributedCacheInput.PARTITION_SIZE;
+ lU = U.constructLops();
+ if( needPartU ){ //requires partitioning
+ lU = new DataPartition(lU, DataType.MATRIX, ValueType.DOUBLE, (m1Size>OptimizerUtils.getLocalMemBudget())?ExecType.MR:ExecType.CP, PDataPartitionFormat.ROW_BLOCK_WISE_N);
+ lU.getOutputParameters().setDimensions(U.getDim1(), U.getDim2(), getRowsInBlock(), getColsInBlock(), U.getNnz());
+ setLineNumbers(lU);
+ }
+ }
+ else {
+ //replication of U for shuffle to target block
+ Lop offset = createOffsetLop(V, false); //ncol of t(V) -> nrow of V determines num replicates
+ lU = new RepMat(U.constructLops(), offset, true, V.getDataType(), V.getValueType());
+ lU.getOutputParameters().setDimensions(U.getDim1(), U.getDim2(),
+ U.getRowsInBlock(), U.getColsInBlock(), U.getNnz());
+ setLineNumbers(lU);
+
+ Group grpU = new Group(lU, Group.OperationTypes.Sort, DataType.MATRIX, ValueType.DOUBLE);
+ grpU.getOutputParameters().setDimensions(U.getDim1(), U.getDim2(), U.getRowsInBlock(), U.getColsInBlock(), -1);
+ setLineNumbers(grpU);
+ lU = grpU;
+ }
+
+ Lop lV = null;
+ if( cacheV ) {
+ //partitioning of V for read through distributed cache
+ boolean needPartV = !V.dimsKnown() || V.getDim1() * V.getDim2() > DistributedCacheInput.PARTITION_SIZE;
+ lV = V.constructLops();
+ if( needPartV ){ //requires partitioning
+ lV = new DataPartition(lV, DataType.MATRIX, ValueType.DOUBLE, (m2Size>OptimizerUtils.getLocalMemBudget())?ExecType.MR:ExecType.CP, PDataPartitionFormat.ROW_BLOCK_WISE_N);
+ lV.getOutputParameters().setDimensions(V.getDim1(), V.getDim2(), getRowsInBlock(), getColsInBlock(), V.getNnz());
+ setLineNumbers(lV);
+ }
+ }
+ else {
+ //replication of t(V) for shuffle to target block
+ Transform ltV = new Transform( V.constructLops(), HopsTransf2Lops.get(ReOrgOp.TRANSPOSE), getDataType(), getValueType(), ExecType.MR);
+ ltV.getOutputParameters().setDimensions(V.getDim2(), V.getDim1(),
+ V.getColsInBlock(), V.getRowsInBlock(), V.getNnz());
+ setLineNumbers(ltV);
+
+ Lop offset = createOffsetLop(U, false); //nrow of U determines num replicates
+ lV = new RepMat(ltV, offset, false, V.getDataType(), V.getValueType());
+ lV.getOutputParameters().setDimensions(V.getDim2(), V.getDim1(),
+ V.getColsInBlock(), V.getRowsInBlock(), V.getNnz());
+ setLineNumbers(lV);
+
+ Group grpV = new Group(lV, Group.OperationTypes.Sort, DataType.MATRIX, ValueType.DOUBLE);
+ grpV.getOutputParameters().setDimensions(V.getDim2(), V.getDim1(), V.getColsInBlock(), V.getRowsInBlock(), -1);
+ setLineNumbers(grpV);
+ lV = grpV;
+ }
+
+ //reduce-side wumm w/ or without broadcast
+ Lop wumm = new WeightedUnaryMMR(
+ grpX, lU, lV, DataType.MATRIX, ValueType.DOUBLE, wtype, uop, cacheU, cacheV, ExecType.MR);
+ setOutputDimensions(wumm);
+ setLineNumbers(wumm);
+ setLops(wumm);
+
+ //in contrast to wsloss no aggregation required
+ }
+ }
+
+ /**
+ *
+ * @param wtype
+ * @throws HopsException
+ * @throws LopsException
+ */
+ private void constructSparkLopsWeightedUMM( WUMMType wtype )
+ throws HopsException, LopsException
+ {
+ //NOTE: the common case for wumm is factors U/V with a rank of 10s to 100s; the current runtime only
+ //supports single block outer products (U/V rank <= blocksize, i.e., 1000 by default); we enforce this
+ //by applying the hop rewrite for weighted unary mm only if this constraint holds.
+
+ Unary.OperationTypes uop = _uop!=null ?
+ HopsOpOp1LopsU.get(_uop) : _sop==OpOp2.POW ?
+ Unary.OperationTypes.POW2 : Unary.OperationTypes.MULTIPLY2;
+
+ //Notes: Any broadcast needs to fit twice in local memory because we partition the input in cp,
+ //and needs to fit once in executor broadcast memory. The 2GB broadcast constraint is no longer
+ //required because the max_int byte buffer constraint has been fixed in Spark 1.4
+ double memBudgetExec = SparkExecutionContext.getBroadcastMemoryBudget();
+ double memBudgetLocal = OptimizerUtils.getLocalMemBudget();
+
+ Hop X = getInput().get(0);
+ Hop U = getInput().get(1);
+ Hop V = getInput().get(2);
+
+ //Spark operator selection, part1
+ double m1Size = OptimizerUtils.estimateSize(U.getDim1(), U.getDim2()); //size U
+ double m2Size = OptimizerUtils.estimateSize(V.getDim1(), V.getDim2()); //size V
+ boolean isMapWsloss = (m1Size+m2Size < memBudgetExec
+ && 2*m1Size<memBudgetLocal && 2*m2Size<memBudgetLocal);
+
+ if( !FORCE_REPLICATION && isMapWsloss ) //broadcast
+ {
+ //map-side wumm always with broadcast
+ Lop wsigmoid = new WeightedUnaryMM( X.constructLops(), U.constructLops(), V.constructLops(),
+ DataType.MATRIX, ValueType.DOUBLE, wtype, uop, ExecType.SPARK);
+ setOutputDimensions(wsigmoid);
+ setLineNumbers(wsigmoid);
+ setLops( wsigmoid );
+ }
+ else //general case
+ {
+ //Spark operator selection part 2
+ boolean cacheU = !FORCE_REPLICATION && (m1Size < memBudgetExec && 2*m1Size < memBudgetLocal);
+ boolean cacheV = !FORCE_REPLICATION && ((!cacheU && m2Size < memBudgetExec )
+ || (cacheU && m1Size+m2Size < memBudgetExec)) && 2*m2Size < memBudgetLocal;
+
+ //reduce-side wumm w/ or without broadcast
+ Lop wsigmoid = new WeightedUnaryMMR(
+ X.constructLops(), U.constructLops(), V.constructLops(),
+ DataType.MATRIX, ValueType.DOUBLE, wtype, uop, cacheU, cacheV, ExecType.SPARK);
+ setOutputDimensions(wsigmoid);
+ setLineNumbers(wsigmoid);
+ setLops(wsigmoid);
+ }
+ }
/**
*
@@ -1238,6 +1490,7 @@ public class QuaternaryOp extends Hop implements MultiThreadedHop
case WSIGMOID:
case WDIVMM:
+ case WUMM:
double sp = OptimizerUtils.getSparsity(dim1, dim2, nnz);
return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sp);
@@ -1263,7 +1516,8 @@ public class QuaternaryOp extends Hop implements MultiThreadedHop
ret = null;
break;
- case WSIGMOID: {
+ case WSIGMOID:
+ case WUMM: {
MatrixCharacteristics mcW = memo.getAllInputStats(getInput().get(0));
ret = new long[]{mcW.getRows(), mcW.getCols(), mcW.getNonZeros()};
break;
@@ -1336,7 +1590,8 @@ public class QuaternaryOp extends Hop implements MultiThreadedHop
//do nothing: always scalar
break;
- case WSIGMOID: {
+ case WSIGMOID:
+ case WUMM: {
Hop inW = getInput().get(0);
setDim1( inW.getDim1() );
setDim2( inW.getDim2() );
@@ -1385,6 +1640,9 @@ public class QuaternaryOp extends Hop implements MultiThreadedHop
ret._baseType = _baseType;
ret._mult = _mult;
ret._minus = _minus;
+ ret._umult = _umult;
+ ret._uop = _uop;
+ ret._sop = _sop;
ret._maxNumThreads = _maxNumThreads;
return ret;
@@ -1416,6 +1674,9 @@ public class QuaternaryOp extends Hop implements MultiThreadedHop
ret &= _baseType == that2._baseType;
ret &= _mult == that2._mult;
ret &= _minus == that2._minus;
+ ret &= _umult == that2._umult;
+ ret &= _uop == that2._uop;
+ ret &= _sop == that2._sop;
ret &= _maxNumThreads == that2._maxNumThreads;
return ret;
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d70c4524/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 10dbedf..1a0710f 100644
--- a/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -74,6 +74,9 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//valid pseudo-sparse-safe binary operators for wdivmm
private static OpOp2[] LOOKUP_VALID_WDIVMM_BINARY = new OpOp2[]{OpOp2.MULT, OpOp2.DIV};
+ //valid unary and binary operators for wumm
+ private static OpOp1[] LOOKUP_VALID_WUMM_UNARY = new OpOp1[]{OpOp1.ABS, OpOp1.ROUND, OpOp1.CEIL, OpOp1.FLOOR, OpOp1.EXP, OpOp1.LOG, OpOp1.SQRT, OpOp1.SIGMOID, OpOp1.SPROP};
+ private static OpOp2[] LOOKUP_VALID_WUMM_BINARY = new OpOp2[]{OpOp2.MULT, OpOp2.POW};
@Override
public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state)
@@ -166,6 +169,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
hi = simplifyWeightedSigmoidMMChains(hop, hi, i); //e.g., W * sigmoid(Y%*%t(X)) -> wsigmoid(W, Y, t(X), type)
hi = simplifyWeightedDivMM(hop, hi, i); //e.g., t(U) %*% (X/(U%*%t(V))) -> wdivmm(X, U, t(V), left)
hi = simplifyWeightedCrossEntropy(hop, hi, i); //e.g., sum(X*log(U%*%t(V))) -> wcemm(X, U, t(V))
+ hi = simplifyWeightedUnaryMM(hop, hi, i); //e.g., X*exp(U%*%t(V)) -> wumm(X, U, t(V), exp)
hi = simplifyDotProductSum(hop, hi, i); //e.g., sum(v^2) -> t(v)%*%v if ncol(v)==1
hi = fuseSumSquared(hop, hi, i); //e.g., sum(X^2) -> sumSq(X), if ncol(X)>1
hi = reorderMinusMatrixMult(hop, hi, i); //e.g., (-t(X))%*%y->-(t(X)%*%y), TODO size
@@ -1946,6 +1950,114 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
}
/**
+ *
+ * @param parent
+ * @param hi
+ * @param pos
+ * @return
+ * @throws HopsException
+ */
+ private Hop simplifyWeightedUnaryMM(Hop parent, Hop hi, int pos)
+ throws HopsException
+ {
+ Hop hnew = null;
+ boolean appliedPattern = false;
+
+ //Pattern 1) (W*uop(U%*%t(V)))
+ if( hi instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)hi).getOp(),LOOKUP_VALID_WDIVMM_BINARY)
+ && HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) //prevent mv
+ && hi.getDim2() > 1 //not applied for vector-vector mult
+ && hi.getInput().get(0).getDataType() == DataType.MATRIX
+ && hi.getInput().get(0).getDim2() > hi.getInput().get(0).getColsInBlock()
+ && hi.getInput().get(1) instanceof UnaryOp
+ && HopRewriteUtils.isValidOp(((UnaryOp)hi.getInput().get(1)).getOp(), LOOKUP_VALID_WUMM_UNARY)
+ && hi.getInput().get(1).getInput().get(0) instanceof AggBinaryOp
+ && HopRewriteUtils.isSingleBlock(hi.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
+ {
+ Hop W = hi.getInput().get(0);
+ Hop U = hi.getInput().get(1).getInput().get(0).getInput().get(0);
+ Hop V = hi.getInput().get(1).getInput().get(0).getInput().get(1);
+ boolean mult = ((BinaryOp)hi).getOp()==OpOp2.MULT;
+ OpOp1 op = ((UnaryOp)hi.getInput().get(1)).getOp();
+
+ if( !HopRewriteUtils.isTransposeOperation(V) )
+ V = HopRewriteUtils.createTranspose(V);
+ else
+ V = V.getInput().get(0);
+
+ hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
+ OpOp4.WUMM, W, U, V, mult, op, null);
+ HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+
+ appliedPattern = true;
+ LOG.debug("Applied simplifyWeightedUnaryMM1 (line "+hi.getBeginLine()+")");
+ }
+
+ //Pattern 2) (W*sop(U%*%t(V),c)) for known sop translating to unary ops
+ if( !appliedPattern
+ && hi instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)hi).getOp(),LOOKUP_VALID_WDIVMM_BINARY)
+ && HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) //prevent mv
+ && hi.getDim2() > 1 //not applied for vector-vector mult
+ && hi.getInput().get(0).getDataType() == DataType.MATRIX
+ && hi.getInput().get(0).getDim2() > hi.getInput().get(0).getColsInBlock()
+ && hi.getInput().get(1) instanceof BinaryOp
+ && HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput().get(1)).getOp(), LOOKUP_VALID_WUMM_BINARY) )
+ {
+ Hop left = hi.getInput().get(1).getInput().get(0);
+ Hop right = hi.getInput().get(1).getInput().get(1);
+ Hop abop = null;
+
+ //pattern 2a) matrix-scalar operations
+ if( right.getDataType()==DataType.SCALAR && right instanceof LiteralOp
+ && HopRewriteUtils.getDoubleValue((LiteralOp)right)==2 //pow2, mult2
+ && left instanceof AggBinaryOp
+ && HopRewriteUtils.isSingleBlock(left.getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
+ {
+ abop = left;
+ }
+ //pattern 2b) scalar-matrix operations
+ else if( left.getDataType()==DataType.SCALAR && left instanceof LiteralOp
+ && HopRewriteUtils.getDoubleValue((LiteralOp)left)==2 //mult2
+ && ((BinaryOp)hi.getInput().get(1)).getOp() == OpOp2.MULT
+ && right instanceof AggBinaryOp
+ && HopRewriteUtils.isSingleBlock(right.getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
+ {
+ abop = right;
+ }
+
+ if( abop != null ) {
+ Hop W = hi.getInput().get(0);
+ Hop U = abop.getInput().get(0);
+ Hop V = abop.getInput().get(1);
+ boolean mult = ((BinaryOp)hi).getOp()==OpOp2.MULT;
+ OpOp2 op = ((BinaryOp)hi.getInput().get(1)).getOp();
+
+ if( !HopRewriteUtils.isTransposeOperation(V) )
+ V = HopRewriteUtils.createTranspose(V);
+ else
+ V = V.getInput().get(0);
+
+ hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
+ OpOp4.WUMM, W, U, V, mult, null, op);
+ HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+
+ appliedPattern = true;
+ LOG.debug("Applied simplifyWeightedUnaryMM2 (line "+hi.getBeginLine()+")");
+ }
+ }
+
+
+ //relink new hop into original position
+ if( hnew != null ) {
+ HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+ HopRewriteUtils.addChildReference(parent, hnew, pos);
+ hi = hnew;
+ }
+
+ return hi;
+ }
+
+ /**
* NOTE: dot-product-sum could be also applied to sum(a*b). However, we
* restrict ourselfs to sum(a^2) and transitively sum(a*a) since a general mm
* a%*%b on MR can be also counter-productive (e.g., MMCJ) while tsmm is always
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d70c4524/src/main/java/com/ibm/bi/dml/lops/Lop.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/lops/Lop.java b/src/main/java/com/ibm/bi/dml/lops/Lop.java
index 1de8f60..751ca1f 100644
--- a/src/main/java/com/ibm/bi/dml/lops/Lop.java
+++ b/src/main/java/com/ibm/bi/dml/lops/Lop.java
@@ -53,7 +53,7 @@ public abstract class Lop
ParameterizedBuiltin, //CP/MR parameterized ops (name/value)
FunctionCallCP, //CP function calls
CumulativePartialAggregate, CumulativeSplitAggregate, CumulativeOffsetBinary, //MR cumsum/cumprod/cummin/cummax
- WeightedSquaredLoss, WeightedSigmoid, WeightedDivMM, WeightedCeMM,
+ WeightedSquaredLoss, WeightedSigmoid, WeightedDivMM, WeightedCeMM, WeightedUMM,
SortKeys, PickValues,
Checkpoint, //Spark persist into storage level
};
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d70c4524/src/main/java/com/ibm/bi/dml/lops/Unary.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/lops/Unary.java b/src/main/java/com/ibm/bi/dml/lops/Unary.java
index 4f5cd0e..b29e822 100644
--- a/src/main/java/com/ibm/bi/dml/lops/Unary.java
+++ b/src/main/java/com/ibm/bi/dml/lops/Unary.java
@@ -160,8 +160,27 @@ public class Unary extends Lop
return "Operation: " + operation + " " + "Label: N/A";
}
- private String getOpcode() throws LopsException {
- switch (operation) {
+ /**
+ *
+ * @return
+ * @throws LopsException
+ */
+ private String getOpcode()
+ throws LopsException
+ {
+ return getOpcode(operation);
+ }
+
+ /**
+ *
+ * @param op
+ * @return
+ * @throws LopsException
+ */
+ public static String getOpcode(OperationTypes op)
+ throws LopsException
+ {
+ switch (op) {
case NOT:
return "!";
case ABS:
@@ -289,8 +308,8 @@ public class Unary extends Lop
return "sel+";
default:
- throw new LopsException(this.printErrorLocation() +
- "Instruction not defined for Unary operation: " + operation);
+ throw new LopsException(
+ "Instruction not defined for Unary operation: " + op);
}
}
public String getInstructions(String input1, String output)
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d70c4524/src/main/java/com/ibm/bi/dml/lops/WeightedCrossEntropy.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/lops/WeightedCrossEntropy.java b/src/main/java/com/ibm/bi/dml/lops/WeightedCrossEntropy.java
index ea8a160..df131c9 100644
--- a/src/main/java/com/ibm/bi/dml/lops/WeightedCrossEntropy.java
+++ b/src/main/java/com/ibm/bi/dml/lops/WeightedCrossEntropy.java
@@ -86,31 +86,13 @@ public class WeightedCrossEntropy extends Lop
}
@Override
- public String getInstructions(int input_index1, int input_index2, int input_index3, int output_index)
+ public String getInstructions(int input1, int input2, int input3, int output)
{
- StringBuilder sb = new StringBuilder();
-
- sb.append(getExecType());
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(OPCODE);
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append( getInputs().get(0).prepInputOperand(input_index1));
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append( getInputs().get(1).prepInputOperand(input_index2));
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append( getInputs().get(2).prepInputOperand(input_index3));
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append( prepOutputOperand(output_index));
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(_wcemmType);
-
- return sb.toString();
+ return getInstructions(
+ String.valueOf(input1),
+ String.valueOf(input2),
+ String.valueOf(input3),
+ String.valueOf(output));
}
@Override
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d70c4524/src/main/java/com/ibm/bi/dml/lops/WeightedCrossEntropyR.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/lops/WeightedCrossEntropyR.java b/src/main/java/com/ibm/bi/dml/lops/WeightedCrossEntropyR.java
index 5b40628..58f3cc7 100644
--- a/src/main/java/com/ibm/bi/dml/lops/WeightedCrossEntropyR.java
+++ b/src/main/java/com/ibm/bi/dml/lops/WeightedCrossEntropyR.java
@@ -89,37 +89,13 @@ public class WeightedCrossEntropyR extends Lop
}
@Override
- public String getInstructions(int input_index1, int input_index2, int input_index3, int output_index)
+ public String getInstructions(int input1, int input2, int input3, int output)
{
- StringBuilder sb = new StringBuilder();
-
- sb.append(getExecType());
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(OPCODE);
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append( getInputs().get(0).prepInputOperand(input_index1));
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append( getInputs().get(1).prepInputOperand(input_index2));
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append( getInputs().get(2).prepInputOperand(input_index3));
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append( prepOutputOperand(output_index));
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(_wcemmType);
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(_cacheU);
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(_cacheV);
-
- return sb.toString();
+ return getInstructions(
+ String.valueOf(input1),
+ String.valueOf(input2),
+ String.valueOf(input3),
+ String.valueOf(output));
}
@Override
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d70c4524/src/main/java/com/ibm/bi/dml/lops/WeightedSigmoid.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/lops/WeightedSigmoid.java b/src/main/java/com/ibm/bi/dml/lops/WeightedSigmoid.java
index 2f187b9..310ae37 100644
--- a/src/main/java/com/ibm/bi/dml/lops/WeightedSigmoid.java
+++ b/src/main/java/com/ibm/bi/dml/lops/WeightedSigmoid.java
@@ -90,31 +90,13 @@ public class WeightedSigmoid extends Lop
}
@Override
- public String getInstructions(int input_index1, int input_index2, int input_index3, int output_index)
+ public String getInstructions(int input1, int input2, int input3, int output)
{
- StringBuilder sb = new StringBuilder();
-
- sb.append(getExecType());
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(OPCODE);
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append( getInputs().get(0).prepInputOperand(input_index1));
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append( getInputs().get(1).prepInputOperand(input_index2));
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append( getInputs().get(2).prepInputOperand(input_index3));
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append( prepOutputOperand(output_index));
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(_wsigmoidType);
-
- return sb.toString();
+ return getInstructions(
+ String.valueOf(input1),
+ String.valueOf(input2),
+ String.valueOf(input3),
+ String.valueOf(output));
}
@Override
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d70c4524/src/main/java/com/ibm/bi/dml/lops/WeightedSigmoidR.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/lops/WeightedSigmoidR.java b/src/main/java/com/ibm/bi/dml/lops/WeightedSigmoidR.java
index 6aff6d2..de99bfd 100644
--- a/src/main/java/com/ibm/bi/dml/lops/WeightedSigmoidR.java
+++ b/src/main/java/com/ibm/bi/dml/lops/WeightedSigmoidR.java
@@ -89,37 +89,13 @@ public class WeightedSigmoidR extends Lop
}
@Override
- public String getInstructions(int input_index1, int input_index2, int input_index3, int output_index)
+ public String getInstructions(int input1, int input2, int input3, int output)
{
- StringBuilder sb = new StringBuilder();
-
- sb.append(getExecType());
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(OPCODE);
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append( getInputs().get(0).prepInputOperand(input_index1));
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append( getInputs().get(1).prepInputOperand(input_index2));
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append( getInputs().get(2).prepInputOperand(input_index3));
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append( prepOutputOperand(output_index));
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(_wsType);
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(_cacheU);
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(_cacheV);
-
- return sb.toString();
+ return getInstructions(
+ String.valueOf(input1),
+ String.valueOf(input2),
+ String.valueOf(input3),
+ String.valueOf(output));
}
@Override
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d70c4524/src/main/java/com/ibm/bi/dml/lops/WeightedUnaryMM.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/lops/WeightedUnaryMM.java b/src/main/java/com/ibm/bi/dml/lops/WeightedUnaryMM.java
new file mode 100644
index 0000000..48269df
--- /dev/null
+++ b/src/main/java/com/ibm/bi/dml/lops/WeightedUnaryMM.java
@@ -0,0 +1,165 @@
+/**
+ * (C) Copyright IBM Corp. 2010, 2015
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package com.ibm.bi.dml.lops;
+
+import com.ibm.bi.dml.lops.LopProperties.ExecLocation;
+import com.ibm.bi.dml.lops.LopProperties.ExecType;
+import com.ibm.bi.dml.lops.Unary.OperationTypes;
+import com.ibm.bi.dml.lops.compile.JobType;
+import com.ibm.bi.dml.parser.Expression.DataType;
+import com.ibm.bi.dml.parser.Expression.ValueType;
+
+/**
+ * Lop for the weighted unary matrix multiplication 'wumm' over inputs X, U, V; emits opcode "wumm" for CP and "mapwumm" otherwise.
+ */
+public class WeightedUnaryMM extends Lop
+{
+ public static final String OPCODE = "mapwumm";
+ public static final String OPCODE_CP = "wumm";
+
+ public enum WUMMType {
+ MULT,
+ DIV,
+ }
+
+ private WUMMType _wummType = null;
+ private OperationTypes _uop = null;
+ private int _numThreads = 1;
+
+ public WeightedUnaryMM(Lop input1, Lop input2, Lop input3, DataType dt, ValueType vt, WUMMType wt, OperationTypes op, ExecType et)
+ throws LopsException
+ {
+ super(Lop.Type.WeightedUMM, dt, vt);
+ addInput(input1); //X
+ addInput(input2); //U
+ addInput(input3); //V
+ input1.addOutput(this);
+ input2.addOutput(this);
+ input3.addOutput(this);
+
+ //setup wumm parameters (weighting type, unary operator)
+ _wummType = wt;
+ _uop = op;
+ setupLopProperties(et);
+ }
+
+ /**
+ * Configures the lop properties: map-side MR lop (GMR/DATAGEN compatible, breaks alignment) for MR, control-program lop otherwise.
+ * @param et execution type of this lop
+ */
+ private void setupLopProperties( ExecType et )
+ {
+ if( et == ExecType.MR )
+ {
+ //setup MR parameters
+ boolean breaksAlignment = true;
+ boolean aligner = false;
+ boolean definesMRJob = false;
+ lps.addCompatibility(JobType.GMR);
+ lps.addCompatibility(JobType.DATAGEN);
+ lps.setProperties( inputs, ExecType.MR, ExecLocation.Map, breaksAlignment, aligner, definesMRJob );
+ }
+ else //Spark/CP
+ {
+ //setup Spark/CP parameters (no valid MR job type)
+ boolean breaksAlignment = false;
+ boolean aligner = false;
+ boolean definesMRJob = false;
+ lps.addCompatibility(JobType.INVALID);
+ lps.setProperties( inputs, et, ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob );
+ }
+ }
+
+ public String toString() {
+ return "Operation = WeightedUMM";
+ }
+
+ @Override
+ public String getInstructions(int input1, int input2, int input3, int output)
+ throws LopsException
+ {
+ return getInstructions(
+ String.valueOf(input1),
+ String.valueOf(input2),
+ String.valueOf(input3),
+ String.valueOf(output));
+ }
+
+ @Override
+ public String getInstructions(String input1, String input2, String input3, String output)
+ throws LopsException
+ {
+ StringBuilder sb = new StringBuilder();
+
+ sb.append(getExecType());
+
+ sb.append(Lop.OPERAND_DELIMITOR);
+ if( getExecType() == ExecType.CP )
+ sb.append(OPCODE_CP);
+ else
+ sb.append(OPCODE);
+
+ sb.append(Lop.OPERAND_DELIMITOR);
+ sb.append(Unary.getOpcode(_uop));
+
+ sb.append(Lop.OPERAND_DELIMITOR);
+ sb.append( getInputs().get(0).prepInputOperand(input1));
+
+ sb.append(Lop.OPERAND_DELIMITOR);
+ sb.append( getInputs().get(1).prepInputOperand(input2));
+
+ sb.append(Lop.OPERAND_DELIMITOR);
+ sb.append( getInputs().get(2).prepInputOperand(input3));
+
+ sb.append(Lop.OPERAND_DELIMITOR);
+ sb.append( prepOutputOperand(output));
+
+ sb.append(Lop.OPERAND_DELIMITOR);
+ sb.append(_wummType);
+
+ //append degree of parallelism (CP instructions only)
+ if( getExecType()==ExecType.CP ) {
+ sb.append( OPERAND_DELIMITOR );
+ sb.append( _numThreads );
+ }
+
+ return sb.toString();
+ }
+
+ @Override
+ public boolean usesDistributedCache()
+ {
+ if( getExecType()==ExecType.MR )
+ return true;
+ else
+ return false;
+ }
+
+ @Override
+ public int[] distributedCacheInputIndex()
+ {
+ if( getExecType()==ExecType.MR )
+ return new int[]{2,3};
+ else
+ return new int[]{-1};
+ }
+
+ public void setNumThreads(int k) {
+ _numThreads = k;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d70c4524/src/main/java/com/ibm/bi/dml/lops/WeightedUnaryMMR.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/lops/WeightedUnaryMMR.java b/src/main/java/com/ibm/bi/dml/lops/WeightedUnaryMMR.java
new file mode 100644
index 0000000..362cb4d
--- /dev/null
+++ b/src/main/java/com/ibm/bi/dml/lops/WeightedUnaryMMR.java
@@ -0,0 +1,162 @@
+/**
+ * (C) Copyright IBM Corp. 2010, 2015
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package com.ibm.bi.dml.lops;
+
+import com.ibm.bi.dml.lops.LopProperties.ExecLocation;
+import com.ibm.bi.dml.lops.LopProperties.ExecType;
+import com.ibm.bi.dml.lops.Unary.OperationTypes;
+import com.ibm.bi.dml.lops.WeightedUnaryMM.WUMMType;
+import com.ibm.bi.dml.lops.compile.JobType;
+import com.ibm.bi.dml.parser.Expression.DataType;
+import com.ibm.bi.dml.parser.Expression.ValueType;
+
+/**
+ * Lop for the reduce-side weighted unary matrix multiplication 'redwumm' over inputs X, U, V; cacheU/cacheV flag which of U and V are reported as distributed-cache inputs.
+ */
+public class WeightedUnaryMMR extends Lop
+{
+ public static final String OPCODE = "redwumm";
+
+ private WUMMType _wummType = null;
+ private OperationTypes _uop = null;
+ private boolean _cacheU = false;
+ private boolean _cacheV = false;
+
+ public WeightedUnaryMMR(Lop input1, Lop input2, Lop input3, DataType dt, ValueType vt, WUMMType wt, OperationTypes op, boolean cacheU, boolean cacheV, ExecType et)
+ throws LopsException
+ {
+ super(Lop.Type.WeightedUMM, dt, vt);
+ addInput(input1); //X
+ addInput(input2); //U
+ addInput(input3); //V
+ input1.addOutput(this);
+ input2.addOutput(this);
+ input3.addOutput(this);
+
+ //setup wumm parameters (weighting type, unary operator, cache flags)
+ _wummType = wt;
+ _uop = op;
+ _cacheU = cacheU;
+ _cacheV = cacheV;
+ setupLopProperties(et);
+ }
+
+ /**
+ * Configures the lop properties: reduce-side MR lop (GMR/DATAGEN compatible, breaks alignment) for MR, control-program lop otherwise.
+ * @param et execution type of this lop
+ * @throws LopsException declared only; no explicit throw is visible in this method
+ */
+ private void setupLopProperties( ExecType et )
+ throws LopsException
+ {
+ if( et == ExecType.MR )
+ {
+ //setup MR parameters
+ boolean breaksAlignment = true;
+ boolean aligner = false;
+ boolean definesMRJob = false;
+ lps.addCompatibility(JobType.GMR);
+ lps.addCompatibility(JobType.DATAGEN);
+ lps.setProperties( inputs, ExecType.MR, ExecLocation.Reduce, breaksAlignment, aligner, definesMRJob );
+ }
+ else //Spark/CP
+ {
+ //setup Spark/CP parameters (no valid MR job type)
+ boolean breaksAlignment = false;
+ boolean aligner = false;
+ boolean definesMRJob = false;
+ lps.addCompatibility(JobType.INVALID);
+ lps.setProperties( inputs, et, ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob );
+ }
+ }
+
+ public String toString() {
+ return "Operation = WeightedUMMR";
+ }
+
+ @Override
+ public String getInstructions(int input1, int input2, int input3, int output)
+ throws LopsException
+ {
+ return getInstructions(
+ String.valueOf(input1),
+ String.valueOf(input2),
+ String.valueOf(input3),
+ String.valueOf(output));
+ }
+
+ @Override
+ public String getInstructions(String input1, String input2, String input3, String output)
+ throws LopsException
+ {
+ StringBuilder sb = new StringBuilder();
+
+ sb.append(getExecType());
+
+ sb.append(Lop.OPERAND_DELIMITOR);
+ sb.append(OPCODE);
+
+ sb.append(Lop.OPERAND_DELIMITOR);
+ sb.append(Unary.getOpcode(_uop));
+
+ sb.append(Lop.OPERAND_DELIMITOR);
+ sb.append( getInputs().get(0).prepInputOperand(input1));
+
+ sb.append(Lop.OPERAND_DELIMITOR);
+ sb.append( getInputs().get(1).prepInputOperand(input2));
+
+ sb.append(Lop.OPERAND_DELIMITOR);
+ sb.append( getInputs().get(2).prepInputOperand(input3));
+
+ sb.append(Lop.OPERAND_DELIMITOR);
+ sb.append( prepOutputOperand(output));
+
+ sb.append(Lop.OPERAND_DELIMITOR);
+ sb.append(_wummType);
+
+ sb.append(Lop.OPERAND_DELIMITOR);
+ sb.append(_cacheU);
+
+ sb.append(Lop.OPERAND_DELIMITOR);
+ sb.append(_cacheV);
+
+ return sb.toString();
+ }
+
+ @Override
+ public boolean usesDistributedCache()
+ {
+ if( _cacheU || _cacheV )
+ return true;
+ else
+ return false;
+ }
+
+ @Override
+ public int[] distributedCacheInputIndex()
+ {
+ if( !_cacheU && !_cacheV )
+ return new int[]{-1};
+ else if( _cacheU && !_cacheV )
+ return new int[]{2};
+ else if( !_cacheU && _cacheV )
+ return new int[]{3};
+ else
+ return new int[]{2,3};
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d70c4524/src/main/java/com/ibm/bi/dml/runtime/functionobjects/Multiply2.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/runtime/functionobjects/Multiply2.java b/src/main/java/com/ibm/bi/dml/runtime/functionobjects/Multiply2.java
index af07c0b..01af490 100644
--- a/src/main/java/com/ibm/bi/dml/runtime/functionobjects/Multiply2.java
+++ b/src/main/java/com/ibm/bi/dml/runtime/functionobjects/Multiply2.java
@@ -40,6 +40,11 @@ public class Multiply2 extends ValueFunction
}
@Override
+ public double execute(double in1) {
+ return in1 + in1; //unary form: the constant factor 2 is implicit, computed as in1+in1
+ }
+
+ @Override
public double execute(double in1, double in2) {
return in1 + in1; //ignore in2 because always 2;
}
@@ -63,5 +68,4 @@ public class Multiply2 extends ValueFunction
return in1 + in1; //ignore in2 because always 2;
}
-
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d70c4524/src/main/java/com/ibm/bi/dml/runtime/functionobjects/Power2.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/runtime/functionobjects/Power2.java b/src/main/java/com/ibm/bi/dml/runtime/functionobjects/Power2.java
index 1b13923..87d91b8 100644
--- a/src/main/java/com/ibm/bi/dml/runtime/functionobjects/Power2.java
+++ b/src/main/java/com/ibm/bi/dml/runtime/functionobjects/Power2.java
@@ -42,6 +42,11 @@ public class Power2 extends ValueFunction
}
@Override
+ public double execute(double in1) {
+ return in1*in1; //unary form: the constant exponent 2 is implicit, computed as in1*in1
+ }
+
+ @Override
public double execute(double in1, double in2) {
return in1*in1; //ignore in2 because always 2;
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d70c4524/src/main/java/com/ibm/bi/dml/runtime/functionobjects/ValueFunction.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/runtime/functionobjects/ValueFunction.java b/src/main/java/com/ibm/bi/dml/runtime/functionobjects/ValueFunction.java
index fd6455c..af42955 100644
--- a/src/main/java/com/ibm/bi/dml/runtime/functionobjects/ValueFunction.java
+++ b/src/main/java/com/ibm/bi/dml/runtime/functionobjects/ValueFunction.java
@@ -21,6 +21,5 @@ import java.io.Serializable;
public class ValueFunction extends FunctionObject implements Serializable
{
-
private static final long serialVersionUID = -4985988545393861058L;
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d70c4524/src/main/java/com/ibm/bi/dml/runtime/instructions/CPInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/runtime/instructions/CPInstructionParser.java b/src/main/java/com/ibm/bi/dml/runtime/instructions/CPInstructionParser.java
index cf3dcaf..c9744a9 100644
--- a/src/main/java/com/ibm/bi/dml/runtime/instructions/CPInstructionParser.java
+++ b/src/main/java/com/ibm/bi/dml/runtime/instructions/CPInstructionParser.java
@@ -203,6 +203,7 @@ public class CPInstructionParser extends InstructionParser
String2CPInstructionType.put( "wsigmoid", CPINSTRUCTION_TYPE.Quaternary);
String2CPInstructionType.put( "wdivmm" , CPINSTRUCTION_TYPE.Quaternary);
String2CPInstructionType.put( "wcemm" , CPINSTRUCTION_TYPE.Quaternary);
+ String2CPInstructionType.put( "wumm" , CPINSTRUCTION_TYPE.Quaternary);
// User-defined function Opcodes
String2CPInstructionType.put( "extfunct" , CPINSTRUCTION_TYPE.External);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d70c4524/src/main/java/com/ibm/bi/dml/runtime/instructions/InstructionUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/runtime/instructions/InstructionUtils.java b/src/main/java/com/ibm/bi/dml/runtime/instructions/InstructionUtils.java
index 4f0fc1c..d32ca61 100644
--- a/src/main/java/com/ibm/bi/dml/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/com/ibm/bi/dml/runtime/instructions/InstructionUtils.java
@@ -34,6 +34,8 @@ import com.ibm.bi.dml.lops.WeightedSigmoid;
import com.ibm.bi.dml.lops.WeightedSigmoidR;
import com.ibm.bi.dml.lops.WeightedSquaredLoss;
import com.ibm.bi.dml.lops.WeightedSquaredLossR;
+import com.ibm.bi.dml.lops.WeightedUnaryMM;
+import com.ibm.bi.dml.lops.WeightedUnaryMMR;
import com.ibm.bi.dml.runtime.DMLRuntimeException;
import com.ibm.bi.dml.runtime.DMLUnsupportedOperationException;
import com.ibm.bi.dml.runtime.functionobjects.And;
@@ -799,7 +801,9 @@ public class InstructionUtils
|| WeightedSigmoidR.OPCODE.equalsIgnoreCase(opcode) //redwsigmoid
|| WeightedDivMM.OPCODE.equalsIgnoreCase(opcode) //mapwdivmm
|| WeightedDivMMR.OPCODE.equalsIgnoreCase(opcode) //redwdivmm
- || WeightedCrossEntropy.OPCODE.equalsIgnoreCase(opcode) //mapwdcemm
- || WeightedCrossEntropyR.OPCODE.equalsIgnoreCase(opcode); //redwdcemm
+ || WeightedCrossEntropy.OPCODE.equalsIgnoreCase(opcode) //mapwcemm
+ || WeightedCrossEntropyR.OPCODE.equalsIgnoreCase(opcode) //redwcemm
+ || WeightedUnaryMM.OPCODE.equalsIgnoreCase(opcode) //mapwumm
+ || WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode); //redwumm
}
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d70c4524/src/main/java/com/ibm/bi/dml/runtime/instructions/MRInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/runtime/instructions/MRInstructionParser.java b/src/main/java/com/ibm/bi/dml/runtime/instructions/MRInstructionParser.java
index 51e5c84..932cb08 100644
--- a/src/main/java/com/ibm/bi/dml/runtime/instructions/MRInstructionParser.java
+++ b/src/main/java/com/ibm/bi/dml/runtime/instructions/MRInstructionParser.java
@@ -31,6 +31,8 @@ import com.ibm.bi.dml.lops.WeightedSigmoid;
import com.ibm.bi.dml.lops.WeightedSigmoidR;
import com.ibm.bi.dml.lops.WeightedSquaredLoss;
import com.ibm.bi.dml.lops.WeightedSquaredLossR;
+import com.ibm.bi.dml.lops.WeightedUnaryMM;
+import com.ibm.bi.dml.lops.WeightedUnaryMMR;
import com.ibm.bi.dml.runtime.DMLRuntimeException;
import com.ibm.bi.dml.runtime.DMLUnsupportedOperationException;
import com.ibm.bi.dml.runtime.instructions.mr.AggregateBinaryInstruction;
@@ -224,6 +226,8 @@ public class MRInstructionParser extends InstructionParser
String2MRInstructionType.put( WeightedDivMMR.OPCODE, MRINSTRUCTION_TYPE.Quaternary);
String2MRInstructionType.put( WeightedCrossEntropy.OPCODE, MRINSTRUCTION_TYPE.Quaternary);
String2MRInstructionType.put( WeightedCrossEntropyR.OPCODE,MRINSTRUCTION_TYPE.Quaternary);
+ String2MRInstructionType.put( WeightedUnaryMM.OPCODE, MRINSTRUCTION_TYPE.Quaternary);
+ String2MRInstructionType.put( WeightedUnaryMMR.OPCODE, MRINSTRUCTION_TYPE.Quaternary);
// Combine Instruction Opcodes
String2MRInstructionType.put( "combinebinary" , MRINSTRUCTION_TYPE.CombineBinary);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d70c4524/src/main/java/com/ibm/bi/dml/runtime/instructions/SPInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/runtime/instructions/SPInstructionParser.java b/src/main/java/com/ibm/bi/dml/runtime/instructions/SPInstructionParser.java
index d6ba5ec..95102be 100644
--- a/src/main/java/com/ibm/bi/dml/runtime/instructions/SPInstructionParser.java
+++ b/src/main/java/com/ibm/bi/dml/runtime/instructions/SPInstructionParser.java
@@ -29,6 +29,8 @@ import com.ibm.bi.dml.lops.WeightedSigmoid;
import com.ibm.bi.dml.lops.WeightedSigmoidR;
import com.ibm.bi.dml.lops.WeightedSquaredLoss;
import com.ibm.bi.dml.lops.WeightedSquaredLossR;
+import com.ibm.bi.dml.lops.WeightedUnaryMM;
+import com.ibm.bi.dml.lops.WeightedUnaryMMR;
import com.ibm.bi.dml.runtime.DMLRuntimeException;
import com.ibm.bi.dml.runtime.DMLUnsupportedOperationException;
import com.ibm.bi.dml.runtime.instructions.spark.AggregateTernarySPInstruction;
@@ -224,6 +226,8 @@ public class SPInstructionParser extends InstructionParser {
String2SPInstructionType.put( WeightedDivMMR.OPCODE, SPINSTRUCTION_TYPE.Quaternary);
String2SPInstructionType.put( WeightedCrossEntropy.OPCODE, SPINSTRUCTION_TYPE.Quaternary);
String2SPInstructionType.put( WeightedCrossEntropyR.OPCODE,SPINSTRUCTION_TYPE.Quaternary);
+ String2SPInstructionType.put( WeightedUnaryMM.OPCODE, SPINSTRUCTION_TYPE.Quaternary);
+ String2SPInstructionType.put( WeightedUnaryMMR.OPCODE, SPINSTRUCTION_TYPE.Quaternary);
//cumsum/cumprod/cummin/cummax
String2SPInstructionType.put( "ucumack+" , SPINSTRUCTION_TYPE.CumsumAggregate);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d70c4524/src/main/java/com/ibm/bi/dml/runtime/instructions/cp/QuaternaryCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/runtime/instructions/cp/QuaternaryCPInstruction.java b/src/main/java/com/ibm/bi/dml/runtime/instructions/cp/QuaternaryCPInstruction.java
index 4188235..42889fc 100644
--- a/src/main/java/com/ibm/bi/dml/runtime/instructions/cp/QuaternaryCPInstruction.java
+++ b/src/main/java/com/ibm/bi/dml/runtime/instructions/cp/QuaternaryCPInstruction.java
@@ -21,6 +21,7 @@ import com.ibm.bi.dml.lops.WeightedDivMM.WDivMMType;
import com.ibm.bi.dml.lops.WeightedSigmoid.WSigmoidType;
import com.ibm.bi.dml.lops.WeightedSquaredLoss.WeightsType;
import com.ibm.bi.dml.lops.WeightedCrossEntropy.WCeMMType;
+import com.ibm.bi.dml.lops.WeightedUnaryMM.WUMMType;
import com.ibm.bi.dml.runtime.DMLRuntimeException;
import com.ibm.bi.dml.runtime.DMLUnsupportedOperationException;
import com.ibm.bi.dml.runtime.controlprogram.context.ExecutionContext;
@@ -93,6 +94,19 @@ public class QuaternaryCPInstruction extends ComputationCPInstruction
else if( opcode.equalsIgnoreCase("wcemm") )
return new QuaternaryCPInstruction(new QuaternaryOperator(WCeMMType.valueOf(parts[5])), in1, in2, in3, null, out, k, opcode, inst);
}
+ else if( opcode.equalsIgnoreCase("wumm") )
+ {
+ InstructionUtils.checkNumFields ( parts, 7 );
+
+ String uopcode = parts[1];
+ CPOperand in1 = new CPOperand(parts[2]);
+ CPOperand in2 = new CPOperand(parts[3]);
+ CPOperand in3 = new CPOperand(parts[4]);
+ CPOperand out = new CPOperand(parts[5]);
+ int k = Integer.parseInt(parts[7]);
+
+ return new QuaternaryCPInstruction(new QuaternaryOperator(WUMMType.valueOf(parts[6]),uopcode), in1, in2, in3, null, out, k, opcode, inst);
+ }
throw new DMLRuntimeException("Unexpected opcode in QuaternaryCPInstruction: " + inst);
}
@@ -124,7 +138,7 @@ public class QuaternaryCPInstruction extends ComputationCPInstruction
ec.releaseMatrixInput(input4.getName());
ec.setVariable(output.getName(), new DoubleObject(out.getValue(0, 0)));
}
- else { //wsigmoid / wdivmm
+ else { //wsigmoid / wdivmm / wumm
ec.setMatrixOutput(output.getName(), (MatrixBlock)out);
}
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d70c4524/src/main/java/com/ibm/bi/dml/runtime/instructions/mr/QuaternaryInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/runtime/instructions/mr/QuaternaryInstruction.java b/src/main/java/com/ibm/bi/dml/runtime/instructions/mr/QuaternaryInstruction.java
index 814150f..087f554 100644
--- a/src/main/java/com/ibm/bi/dml/runtime/instructions/mr/QuaternaryInstruction.java
+++ b/src/main/java/com/ibm/bi/dml/runtime/instructions/mr/QuaternaryInstruction.java
@@ -26,6 +26,9 @@ import com.ibm.bi.dml.lops.WeightedSigmoid.WSigmoidType;
import com.ibm.bi.dml.lops.WeightedSquaredLoss;
import com.ibm.bi.dml.lops.WeightedSquaredLoss.WeightsType;
import com.ibm.bi.dml.lops.WeightedSquaredLossR;
+import com.ibm.bi.dml.lops.WeightedUnaryMM;
+import com.ibm.bi.dml.lops.WeightedUnaryMM.WUMMType;
+import com.ibm.bi.dml.lops.WeightedUnaryMMR;
import com.ibm.bi.dml.runtime.DMLRuntimeException;
import com.ibm.bi.dml.runtime.DMLUnsupportedOperationException;
import com.ibm.bi.dml.runtime.functionobjects.SwapIndex;
@@ -111,7 +114,7 @@ public class QuaternaryInstruction extends MRInstruction implements IDistributed
//output size independent of chain type (scalar)
dimOut.set(1, 1, mc1.getRowsPerBlock(), mc1.getColsPerBlock());
}
- else if( qop.wtype2 != null ) { //wsigmoid
+ else if( qop.wtype2 != null || qop.wtype5 != null ) { //wsigmoid/wumm
//output size determined by main input
dimOut.set(mc1.getRows(), mc1.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock());
}
@@ -169,6 +172,33 @@ public class QuaternaryInstruction extends MRInstruction implements IDistributed
return new QuaternaryInstruction(new QuaternaryOperator(wtype), in1, in2, in3, in4, out, cacheU, cacheV, str);
}
+ else if( WeightedUnaryMM.OPCODE.equalsIgnoreCase(opcode) //wumm
+ || WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode) )
+ {
+ boolean isRed = WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode);
+
+ //check number of fields (uop, 3 inputs, output, type; red additionally cacheU, cacheV)
+ if( isRed )
+ InstructionUtils.checkNumFields ( str, 8 );
+ else
+ InstructionUtils.checkNumFields ( str, 6 );
+
+ //parse instruction parts (without exec type)
+ String[] parts = InstructionUtils.getInstructionParts(str);
+
+ String uopcode = parts[1];
+ byte in1 = Byte.parseByte(parts[2]);
+ byte in2 = Byte.parseByte(parts[3]);
+ byte in3 = Byte.parseByte(parts[4]);
+ byte out = Byte.parseByte(parts[5]);
+ WUMMType wtype = WUMMType.valueOf(parts[6]);
+
+ //in mappers always through distcache, in reducers through distcache/shuffle
+ boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true;
+ boolean cacheV = isRed ? Boolean.parseBoolean(parts[8]) : true;
+
+ return new QuaternaryInstruction(new QuaternaryOperator(wtype,uopcode), in1, in2, in3, (byte)-1, out, cacheU, cacheV, str);
+ }
else //wsigmoid / wdivmm / wcemm
{
boolean isRed = opcode.startsWith("red");
@@ -303,7 +333,7 @@ public class QuaternaryInstruction extends MRInstruction implements IDistributed
if( qop.wtype1 != null || qop.wtype4 != null)
outIx.setIndexes(1, 1); //wsloss
- else if ( qop.wtype2 != null || qop.wtype3!=null && qop.wtype3.isBasic() )
+ else if ( qop.wtype2 != null || qop.wtype5 != null || qop.wtype3!=null && qop.wtype3.isBasic() )
outIx.setIndexes(inIx); //wsigmoid/wdivmm-basic
else { //wdivmm
boolean left = qop.wtype3.isLeft();
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d70c4524/src/main/java/com/ibm/bi/dml/runtime/instructions/spark/QuaternarySPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/runtime/instructions/spark/QuaternarySPInstruction.java b/src/main/java/com/ibm/bi/dml/runtime/instructions/spark/QuaternarySPInstruction.java
index e7effbb..ac8dce8 100644
--- a/src/main/java/com/ibm/bi/dml/runtime/instructions/spark/QuaternarySPInstruction.java
+++ b/src/main/java/com/ibm/bi/dml/runtime/instructions/spark/QuaternarySPInstruction.java
@@ -39,6 +39,9 @@ import com.ibm.bi.dml.lops.WeightedSquaredLossR;
import com.ibm.bi.dml.lops.WeightedSigmoid.WSigmoidType;
import com.ibm.bi.dml.lops.WeightedSquaredLoss.WeightsType;
import com.ibm.bi.dml.lops.WeightedCrossEntropy.WCeMMType;
+import com.ibm.bi.dml.lops.WeightedUnaryMM;
+import com.ibm.bi.dml.lops.WeightedUnaryMM.WUMMType;
+import com.ibm.bi.dml.lops.WeightedUnaryMMR;
import com.ibm.bi.dml.runtime.DMLRuntimeException;
import com.ibm.bi.dml.runtime.DMLUnsupportedOperationException;
import com.ibm.bi.dml.runtime.controlprogram.context.ExecutionContext;
@@ -117,6 +120,30 @@ public class QuaternarySPInstruction extends ComputationSPInstruction
return new QuaternarySPInstruction(new QuaternaryOperator(wtype), in1, in2, in3, in4, out, cacheU, cacheV, opcode, str);
}
+ else if( WeightedUnaryMM.OPCODE.equalsIgnoreCase(opcode) //wumm
+ || WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode) )
+ {
+ boolean isRed = WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode);
+
+ //check number of fields (uop, 3 inputs, output, type; red additionally cacheU, cacheV)
+ if( isRed )
+ InstructionUtils.checkNumFields ( parts, 8 );
+ else
+ InstructionUtils.checkNumFields ( parts, 6 );
+
+ String uopcode = parts[1];
+ CPOperand in1 = new CPOperand(parts[2]);
+ CPOperand in2 = new CPOperand(parts[3]);
+ CPOperand in3 = new CPOperand(parts[4]);
+ CPOperand out = new CPOperand(parts[5]);
+ WUMMType wtype = WUMMType.valueOf(parts[6]);
+
+ //in mappers always through distcache, in reducers through distcache/shuffle
+ boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true;
+ boolean cacheV = isRed ? Boolean.parseBoolean(parts[8]) : true;
+
+ return new QuaternarySPInstruction(new QuaternaryOperator(wtype, uopcode), in1, in2, in3, null, out, cacheU, cacheV, opcode, str);
+ }
else //map/redwsigmoid, map/redwdivmm, map/redwcemm
{
boolean isRed = opcode.startsWith("red");
@@ -171,7 +198,8 @@ public class QuaternarySPInstruction extends ComputationSPInstruction
if( WeightedSquaredLoss.OPCODE.equalsIgnoreCase(getOpcode())
|| WeightedSigmoid.OPCODE.equalsIgnoreCase(getOpcode())
|| WeightedDivMM.OPCODE.equalsIgnoreCase(getOpcode())
- || WeightedCrossEntropy.OPCODE.equalsIgnoreCase(getOpcode()) )
+ || WeightedCrossEntropy.OPCODE.equalsIgnoreCase(getOpcode())
+ || WeightedUnaryMM.OPCODE.equalsIgnoreCase(getOpcode()))
{
PartitionedBroadcastMatrix bc1 = sec.getBroadcastForVariable( input2.getName() );
PartitionedBroadcastMatrix bc2 = sec.getBroadcastForVariable( input3.getName() );
@@ -242,7 +270,7 @@ public class QuaternarySPInstruction extends ComputationSPInstruction
DoubleObject ret = new DoubleObject(tmp.getValue(0, 0));
sec.setVariable(output.getName(), ret);
}
- else //map/redwsigmoid, map/redwdivmm
+ else //map/redwsigmoid, map/redwdivmm, map/redwumm
{
//aggregation if required (map/redwdivmm)
if( qop.wtype3 != null && !qop.wtype3.isBasic() )
@@ -275,7 +303,7 @@ public class QuaternarySPInstruction extends ComputationSPInstruction
MatrixCharacteristics mcIn3 = sec.getMatrixCharacteristics(input3.getName());
MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());
- if( qop.wtype2 != null ) {
+ if( qop.wtype2 != null || qop.wtype5 != null ) {
//output size determined by main input
mcOut.set(mcIn1.getRows(), mcIn1.getCols(), mcIn1.getRowsPerBlock(), mcIn1.getColsPerBlock());
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d70c4524/src/main/java/com/ibm/bi/dml/runtime/matrix/data/LibMatrixMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/runtime/matrix/data/LibMatrixMult.java b/src/main/java/com/ibm/bi/dml/runtime/matrix/data/LibMatrixMult.java
index ebab2c6..6bae2f8 100644
--- a/src/main/java/com/ibm/bi/dml/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/com/ibm/bi/dml/runtime/matrix/data/LibMatrixMult.java
@@ -31,9 +31,11 @@ import com.ibm.bi.dml.lops.WeightedCrossEntropy.WCeMMType;
import com.ibm.bi.dml.lops.WeightedDivMM.WDivMMType;
import com.ibm.bi.dml.lops.WeightedSigmoid.WSigmoidType;
import com.ibm.bi.dml.lops.WeightedSquaredLoss.WeightsType;
+import com.ibm.bi.dml.lops.WeightedUnaryMM.WUMMType;
import com.ibm.bi.dml.runtime.DMLRuntimeException;
import com.ibm.bi.dml.runtime.DMLUnsupportedOperationException;
import com.ibm.bi.dml.runtime.functionobjects.SwapIndex;
+import com.ibm.bi.dml.runtime.functionobjects.ValueFunction;
import com.ibm.bi.dml.runtime.matrix.operators.ReorgOperator;
import com.ibm.bi.dml.runtime.util.UtilFunctions;
@@ -887,6 +889,101 @@ public class LibMatrixMult
//System.out.println("MMWCe "+wt.toString()+" k="+k+" ("+mW.isInSparseFormat()+","+mW.getNumRows()+","+mW.getNumColumns()+","+mW.getNonZeros()+")x" +
// "("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop());
}
+
+ /**
+ * Single-threaded weighted unary matrix multiply over all rows of mW.
+ * The output block takes over the sparse/dense representation of mW and
+ * the matching kernel is selected from the representation of the inputs.
+ *
+ * @param mW weights matrix; if it is empty, no kernel is invoked
+ * @param mU left factor
+ * @param mV right factor
+ * @param ret output block, allocated here
+ * @param wt weighted unary mm type, handed to the kernel
+ * @param fn unary function, handed to the kernel
+ * @throws DMLRuntimeException
+ */
+ public static void matrixMultWuMM(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WUMMType wt, ValueFunction fn)
+ throws DMLRuntimeException
+ {
+ //nothing to compute for empty weights
+ if( mW.isEmptyBlock(false) ) {
+ ret.examSparsity(); //turn empty dense into sparse
+ return;
+ }
+
+ //Timing time = new Timing(true);
+
+ //output mirrors the representation of the weights
+ ret.sparse = mW.sparse;
+ ret.allocateDenseOrSparseBlock();
+
+ //core weighted unary mm computation: the specialized kernels need
+ //dense, non-empty factors; every other combination is handled generically
+ final boolean denseFactors = !mU.sparse && !mV.sparse && !mU.isEmptyBlock() && !mV.isEmptyBlock();
+ final int numRows = mW.rlen;
+ if( !denseFactors )
+ matrixMultWuMMGeneric(mW, mU, mV, ret, wt, fn, 0, numRows);
+ else if( mW.sparse )
+ matrixMultWuMMSparseDense(mW, mU, mV, ret, wt, fn, 0, numRows);
+ else
+ matrixMultWuMMDense(mW, mU, mV, ret, wt, fn, 0, numRows);
+
+ //post-processing
+ ret.recomputeNonZeros();
+ ret.examSparsity();
+
+ //System.out.println("MMWu "+wt.toString()+" ("+mW.isInSparseFormat()+","+mW.getNumRows()+","+mW.getNumColumns()+","+mW.getNonZeros()+")x" +
+ // "("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop());
+ }
+
+ /**
+ * Multi-threaded weighted unary matrix multiply. The rows of mW are split
+ * into at most k contiguous ranges, one MatrixMultWuTask is run per range,
+ * and the per-range non-zero counts are summed into ret.
+ *
+ * @param mW weights matrix; if it is empty, no task is run
+ * @param mU left factor, passed through to the tasks
+ * @param mV right factor, passed through to the tasks
+ * @param ret output block, allocated here
+ * @param wt weighted unary mm type, passed through to the tasks
+ * @param fn unary function, passed through to the tasks
+ * @param k degree of parallelism; values <= 1 use the sequential implementation
+ * @throws DMLRuntimeException if the calling thread is interrupted or any task fails
+ */
+ public static void matrixMultWuMM(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WUMMType wt, ValueFunction fn, int k)
+ throws DMLRuntimeException
+ {
+ //check for empty result
+ if( mW.isEmptyBlock(false) ) {
+ ret.examSparsity(); //turn empty dense into sparse
+ return;
+ }
+
+ //check no parallelization benefit (fallback to sequential); k<=1 also
+ //guards against an invalid pool size and a division by zero in blklen
+ if (mW.rlen == 1 || k <= 1) {
+ matrixMultWuMM(mW, mU, mV, ret, wt, fn);
+ return;
+ }
+
+ //Timing time = new Timing(true);
+
+ //pre-processing
+ ret.sparse = mW.sparse;
+ ret.allocateDenseOrSparseBlock();
+
+ ExecutorService pool = Executors.newFixedThreadPool(k);
+ try
+ {
+ ArrayList<MatrixMultWuTask> tasks = new ArrayList<MatrixMultWuTask>();
+ int blklen = (int)(Math.ceil((double)mW.rlen/k));
+ for( int i=0; i<k && i*blklen<mW.rlen; i++ )
+ tasks.add(new MatrixMultWuTask(mW, mU, mV, ret, wt, fn, i*blklen, Math.min((i+1)*blklen, mW.rlen)));
+
+ //invokeAll only waits for completion; get() is what surfaces a task
+ //failure, which would otherwise be dropped and leave nnz at -1
+ for( java.util.concurrent.Future<Object> rtask : pool.invokeAll(tasks) )
+ rtask.get();
+
+ ret.nonZeros = 0; //reset after execute
+ for( MatrixMultWuTask task : tasks )
+ ret.nonZeros += task.getPartialNnz();
+ }
+ catch (InterruptedException e) {
+ Thread.currentThread().interrupt(); //keep the interrupt status visible to callers
+ throw new DMLRuntimeException(e);
+ }
+ catch (java.util.concurrent.ExecutionException e) {
+ throw new DMLRuntimeException(e);
+ }
+ finally {
+ pool.shutdown(); //release the worker threads on every exit path
+ }
+
+ //post-processing (nnz maintained in parallel)
+ ret.examSparsity();
+
+ //System.out.println("MMWu "+wt.toString()+" k="+k+" ("+mW.isInSparseFormat()+","+mW.getNumRows()+","+mW.getNumColumns()+","+mW.getNonZeros()+")x" +
+ // "("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+") in "+time.stop() + ".");
+ }
//////////////////////////////////////////
// optimized matrix mult implementation //
@@ -2704,7 +2801,149 @@ public class LibMatrixMult
ret.quickSetValue(0, 0, wceval);
}
+
+ /**
+ * Dense kernel of the weighted unary mm: for every non-zero cell (i,j) of mW
+ * in rows [rl,ru), writes wumm(w_ij, row i of mU, row j of mV) into the same
+ * cell of ret. Cells with a zero weight are not written.
+ *
+ * @param mW weights matrix, read via its dense block
+ * @param mU left factor, read via its dense block
+ * @param mV right factor, read via its dense block (row j is paired with output column j)
+ * @param ret output block, written via its dense block
+ * @param wt weighted unary mm type; MULT selects multiplication, anything else division
+ * @param fn unary function applied to each dot product
+ * @param rl first row to process (inclusive)
+ * @param ru last row to process (exclusive)
+ * @throws DMLRuntimeException
+ */
+ private static void matrixMultWuMMDense(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WUMMType wt, ValueFunction fn, int rl, int ru)
+ throws DMLRuntimeException
+ {
+ double[] w = mW.denseBlock;
+ double[] c = ret.denseBlock;
+ double[] u = mU.denseBlock;
+ double[] v = mV.denseBlock;
+ final int n = mW.clen;
+ final int cd = mU.clen;
+
+ //note: cannot compute U %*% t(V) in-place of result w/ regular mm because
+ //t(V) comes in transformed form and hence would require additional memory
+ boolean flagmult = (wt==WUMMType.MULT);
+
+ //approach: iterate over non-zeros of w, selective mm computation
+ //cache-conscious blocking: due to blocksize constraint (default 1000),
+ //a blocksize of 16 allows to fit blocks of UV into L2 cache (256KB)
+
+ final int blocksizeIJ = 16; //u/v block (max at typical L2 size)
+
+ //blocked execution over tiles of at most blocksizeIJ x blocksizeIJ cells
+ for( int bi = rl; bi < ru; bi+=blocksizeIJ )
+ for( int bj = 0, bimin = Math.min(ru, bi+blocksizeIJ); bj < n; bj+=blocksizeIJ )
+ {
+ int bjmin = Math.min(n, bj+blocksizeIJ);
+
+ //core wumm computation (ix: row offset in w/c, uix/vix: row offsets in u/v)
+ for( int i=bi, ix=bi*n, uix=bi*cd; i<bimin; i++, ix+=n, uix+=cd )
+ for( int j=bj, vix=bj*cd; j<bjmin; j++, vix+=cd) {
+ double wij = w[ix+j];
+ if( wij != 0 )
+ c[ix+j] = wumm(wij, u, v, uix, vix, flagmult, fn, cd);
+ }
+ }
+ }
+
+ /**
+ * Kernel for sparse weights with dense factors: for every non-empty sparse
+ * row of mW in [rl,ru), creates the matching sparse row in ret and appends
+ * one wumm value per stored entry of the weights row.
+ *
+ * @param mW weights matrix, read via its sparse rows
+ * @param mU left factor, read via its dense block
+ * @param mV right factor, read via its dense block
+ * @param ret output block, written via its sparse rows
+ * @param wt weighted unary mm type; MULT selects multiplication, anything else division
+ * @param fn unary function applied to each dot product
+ * @param rl first row to process (inclusive)
+ * @param ru last row to process (exclusive)
+ * @throws DMLRuntimeException
+ */
+ private static void matrixMultWuMMSparseDense(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WUMMType wt, ValueFunction fn, int rl, int ru)
+ throws DMLRuntimeException
+ {
+ final SparseRow[] wrows = mW.sparseRows;
+ final SparseRow[] crows = ret.sparseRows;
+ final double[] udata = mU.denseBlock;
+ final double[] vdata = mV.denseBlock;
+ final int ncol = mW.clen;
+ final int rank = mU.clen;
+ final boolean mult = (wt==WUMMType.MULT);
+
+ //approach: iterate over non-zeros of w, selective mm computation
+ for( int r=rl; r<ru; r++ )
+ {
+ final SparseRow wrow = wrows[r];
+ if( wrow == null || wrow.isEmpty() )
+ continue; //no weights in this row, output row stays unset
+
+ final int nnz = wrow.size();
+ final int[] cols = wrow.getIndexContainer();
+ final double[] vals = wrow.getValueContainer();
+ final SparseRow crow = new SparseRow(nnz, ncol);
+ crows[r] = crow;
+
+ //offsets of row r in u and of row cols[p] in v
+ final int uoff = r*rank;
+ for( int p=0; p<nnz; p++ ) {
+ double cval = wumm(vals[p], udata, vdata, uoff, cols[p]*rank, mult, fn, rank);
+ crow.append(cols[p], cval);
+ }
+ }
+ }
+
+ /**
+ * Fallback kernel for all input combinations not covered by the specialized
+ * kernels. The weights are read in their own representation (sparse rows or
+ * dense block) and the factors are passed as blocks to the MatrixBlock-based
+ * wumm, using row indices instead of array offsets.
+ *
+ * @param mW weights matrix, sparse or dense
+ * @param mU left factor
+ * @param mV right factor
+ * @param ret output block, accessed in the same representation as mW
+ * @param wt weighted unary mm type; MULT selects multiplication, anything else division
+ * @param fn unary function applied to each dot product
+ * @param rl first row to process (inclusive)
+ * @param ru last row to process (exclusive)
+ * @throws DMLRuntimeException
+ */
+ private static void matrixMultWuMMGeneric (MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WUMMType wt, ValueFunction fn, int rl, int ru)
+ throws DMLRuntimeException
+ {
+ final int ncol = mW.clen;
+ final int rank = mU.clen;
+ final boolean mult = (wt==WUMMType.MULT);
+
+ //approach: iterate over non-zeros of w, selective mm computation
+ if( !mW.sparse ) //DENSE
+ {
+ //w and c always in same representation
+ final double[] wdata = mW.denseBlock;
+ final double[] cdata = ret.denseBlock;
+
+ int pos = rl*ncol; //running cell position in w and c
+ for( int r=rl; r<ru; r++ )
+ for( int col=0; col<ncol; col++, pos++ ) {
+ final double wval = wdata[pos];
+ if( wval != 0 )
+ cdata[pos] = wumm(wval, mU, mV, r, col, mult, fn, rank);
+ }
+ return;
+ }
+
+ //SPARSE
+ //w and c always in same representation
+ final SparseRow[] wrows = mW.sparseRows;
+ final SparseRow[] crows = ret.sparseRows;
+
+ for( int r=rl; r<ru; r++ )
+ {
+ final SparseRow wrow = wrows[r];
+ if( wrow == null || wrow.isEmpty() )
+ continue; //no weights in this row, output row stays unset
+
+ final int nnz = wrow.size();
+ final int[] cols = wrow.getIndexContainer();
+ final double[] vals = wrow.getValueContainer();
+ final SparseRow crow = new SparseRow(nnz, ncol);
+ crows[r] = crow;
+
+ for( int p=0; p<nnz; p++ ) {
+ double cval = wumm(vals[p], mU, mV, r, cols[p], mult, fn, rank);
+ crow.append(cols[p], cval);
+ }
+ }
+ }
////////////////////////////////////////////
// performance-relevant utility functions //
@@ -3312,6 +3551,58 @@ public class LibMatrixMult
/**
*
+ * @param wij
+ * @param u
+ * @param v
+ * @param uix
+ * @param vix
+ * @param flagmult
+ * @param fn
+ * @param len
+ * @return
+ * @throws DMLRuntimeException
+ */
+ private static double wumm( final double wij, double[] u, double[] v, final int uix, final int vix, final boolean flagmult, ValueFunction fn, final int len )
+ throws DMLRuntimeException
+ {
+ //compute dot product over ui vj
+ double uvij = dotProduct(u, v, uix, vix, len);
+
+ //compute unary operations
+ double cval = fn.execute(uvij);
+
+ //compute weighted output
+ return flagmult ? wij * cval : wij / cval;
+ }
+
+ /**
+ *
+ * @param wij
+ * @param u
+ * @param v
+ * @param uix
+ * @param vix
+ * @param flagminus
+ * @param flaglog
+ * @param len
+ * @return
+ * @throws DMLRuntimeException
+ */
+ private static double wumm( final double wij, MatrixBlock u, MatrixBlock v, final int uix, final int vix, final boolean flagmult, ValueFunction fn, final int len )
+ throws DMLRuntimeException
+ {
+ //compute dot product over ui vj
+ double uvij = dotProductGeneric(u, v, uix, vix, len);
+
+ //compute unary operations
+ double cval = fn.execute(uvij);
+
+ //compute weighted output
+ return flagmult ? wij * cval : wij / cval;
+ }
+
+ /**
+ *
* @param a
* @param b
* @param ai
@@ -3910,4 +4201,54 @@ public class LibMatrixMult
return _ret.quickGetValue(0, 0);
}
}
+
+ /**
+ * Task that runs the weighted unary mm kernels on the row range [rl,ru) of
+ * the weights matrix and records the number of non-zeros it produced in
+ * that range of the output block.
+ */
+ private static class MatrixMultWuTask implements Callable<Object>
+ {
+ private final MatrixBlock _mW;
+ private final MatrixBlock _mU;
+ private final MatrixBlock _mV;
+ private final MatrixBlock _ret;
+ private final WUMMType _wt;
+ private final ValueFunction _fn;
+ private final int _rl;
+ private final int _ru;
+ private long _nnz = -1; //set by call()
+
+ protected MatrixMultWuTask(MatrixBlock mW, MatrixBlock mU, MatrixBlock mV, MatrixBlock ret, WUMMType wt, ValueFunction fn, int rl, int ru)
+ throws DMLRuntimeException
+ {
+ _mW = mW;
+ _mU = mU;
+ _mV = mV;
+ _ret = ret;
+ _wt = wt;
+ _fn = fn;
+ _rl = rl;
+ _ru = ru;
+ }
+
+ @Override
+ public Object call() throws DMLRuntimeException
+ {
+ //core weighted unary mm computation: the specialized kernels need
+ //dense, non-empty factors; every other combination is handled generically
+ final boolean denseFactors = !_mU.sparse && !_mV.sparse && !_mU.isEmptyBlock() && !_mV.isEmptyBlock();
+ if( !denseFactors )
+ matrixMultWuMMGeneric(_mW, _mU, _mV, _ret, _wt, _fn, _rl, _ru);
+ else if( _mW.sparse )
+ matrixMultWuMMSparseDense(_mW, _mU, _mV, _ret, _wt, _fn, _rl, _ru);
+ else
+ matrixMultWuMMDense(_mW, _mU, _mV, _ret, _wt, _fn, _rl, _ru);
+
+ //maintain block nnz (upper bounds inclusive)
+ _nnz = _ret.recomputeNonZeros(_rl, _ru-1, 0, _ret.getNumColumns()-1);
+
+ return null;
+ }
+
+ public long getPartialNnz() {
+ return _nnz;
+ }
+ }
}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d70c4524/src/main/java/com/ibm/bi/dml/runtime/matrix/data/MatrixBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/runtime/matrix/data/MatrixBlock.java b/src/main/java/com/ibm/bi/dml/runtime/matrix/data/MatrixBlock.java
index 38ab456..18fce81 100644
--- a/src/main/java/com/ibm/bi/dml/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/com/ibm/bi/dml/runtime/matrix/data/MatrixBlock.java
@@ -5914,7 +5914,7 @@ public class MatrixBlock extends MatrixValue implements Externalizable
//prepare intermediates and output
if( qop.wtype1 != null || qop.wtype4 != null )
R.reset(1, 1, false);
- else if( qop.wtype2 != null )
+ else if( qop.wtype2 != null || qop.wtype5 != null )
R.reset(rlen, clen, sparse);
else if( qop.wtype3 != null ) {
MatrixCharacteristics mc = qop.wtype3.computeOutputCharacteristics(X.rlen, X.clen, U.clen);
@@ -5948,6 +5948,12 @@ public class MatrixBlock extends MatrixValue implements Externalizable
else
LibMatrixMult.matrixMultWCeMM(X, U, V, R, qop.wtype4);
}
+ else if( qop.wtype5 != null ){ //wumm
+ if( k > 1 )
+ LibMatrixMult.matrixMultWuMM(X, U, V, R, qop.wtype5, qop.fn, k);
+ else
+ LibMatrixMult.matrixMultWuMM(X, U, V, R, qop.wtype5, qop.fn);
+ }
return R;
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d70c4524/src/main/java/com/ibm/bi/dml/runtime/matrix/operators/QuaternaryOperator.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/runtime/matrix/operators/QuaternaryOperator.java b/src/main/java/com/ibm/bi/dml/runtime/matrix/operators/QuaternaryOperator.java
index ada6f7d..c0b1412 100644
--- a/src/main/java/com/ibm/bi/dml/runtime/matrix/operators/QuaternaryOperator.java
+++ b/src/main/java/com/ibm/bi/dml/runtime/matrix/operators/QuaternaryOperator.java
@@ -22,8 +22,11 @@ import com.ibm.bi.dml.lops.WeightedDivMM.WDivMMType;
import com.ibm.bi.dml.lops.WeightedSigmoid.WSigmoidType;
import com.ibm.bi.dml.lops.WeightedSquaredLoss.WeightsType;
import com.ibm.bi.dml.lops.WeightedCrossEntropy.WCeMMType;
+import com.ibm.bi.dml.lops.WeightedUnaryMM.WUMMType;
import com.ibm.bi.dml.runtime.functionobjects.Builtin;
-import com.ibm.bi.dml.runtime.functionobjects.FunctionObject;
+import com.ibm.bi.dml.runtime.functionobjects.Multiply2;
+import com.ibm.bi.dml.runtime.functionobjects.Power2;
+import com.ibm.bi.dml.runtime.functionobjects.ValueFunction;
public class QuaternaryOperator extends Operator
{
@@ -34,8 +37,9 @@ public class QuaternaryOperator extends Operator
public WSigmoidType wtype2 = null;
public WDivMMType wtype3 = null;
public WCeMMType wtype4 = null;
+ public WUMMType wtype5 = null;
- public FunctionObject fn;
+ public ValueFunction fn;
/**
* wsloss
@@ -73,4 +77,21 @@ public class QuaternaryOperator extends Operator
public QuaternaryOperator( WCeMMType wt ) {
wtype4 = wt;
}
+
+ /**
+ * wumm
+ *
+ * @param wt
+ * @param op
+ */
+ public QuaternaryOperator( WUMMType wt, String op ) {
+ wtype5 = wt;
+
+ if( op.equals("^2") )
+ fn = Power2.getPower2FnObject();
+ else if( op.equals("*2") )
+ fn = Multiply2.getMultiply2FnObject();
+ else
+ fn = Builtin.getBuiltinFnObject(op);
+ }
}