You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/02/27 18:36:04 UTC
[2/9] incubator-systemml git commit: [SYSTEMML-1285] New basic code
generator for operator fusion
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/template/CplanRegister.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CplanRegister.java b/src/main/java/org/apache/sysml/hops/codegen/template/CplanRegister.java
new file mode 100644
index 0000000..a4bcffe
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/CplanRegister.java
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.template;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.Map.Entry;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.codegen.cplan.CNode;
+import org.apache.sysml.hops.codegen.cplan.CNodeCell;
+import org.apache.sysml.hops.codegen.cplan.CNodeData;
+import org.apache.sysml.hops.codegen.cplan.CNodeRowAggVector;
+import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
+import org.apache.sysml.hops.codegen.template.BaseTpl.TemplateType;
+import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.runtime.matrix.data.Pair;
+import org.apache.sysml.utils.Statistics;
+
+/**
+ * Registry of candidate cplans (code generation plans) per template type. Besides
+ * collecting the plans, it resolves overlaps between template types and extracts
+ * the top-level (subsuming) plan of every registered operator chain.
+ */
+public class CplanRegister {
+
+    //HashMap: key: TemplateType - Value: List of all the patterns fused by that template
+    //LinkedHashMap: key: HopID of the original hop to be fused, Value: input hops to the fused operation and its cplan
+    //Note: the LinkedHashMap holds intermediate cplans as well (e.g., log(exp(round(X)))). We store in the LinkedHashMap three keys
+    //for the three hops (log, exp and round). The key that was inserted last is the key of the hop to be fused
+
+    private final HashMap<TemplateType, ArrayList<LinkedHashMap<Long, Pair<Hop[],CNodeTpl>>>> _cplans;
+
+    /** Creates an empty register. */
+    public CplanRegister() {
+        _cplans = new HashMap<TemplateType, ArrayList<LinkedHashMap<Long, Pair<Hop[],CNodeTpl>>>>();
+    }
+
+    /**
+     * Registers the cplans of one fused operator chain (including its intermediate
+     * plans) under the given template type.
+     *
+     * @param type template type that produced the plans
+     * @param cplans plans keyed by hop id; the plan of the hop to be fused is inserted last
+     */
+    public void insertCpplans(TemplateType type, LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> cplans) {
+        ArrayList<LinkedHashMap<Long, Pair<Hop[],CNodeTpl>>> plans = _cplans.get(type);
+        if( plans == null ) {
+            plans = new ArrayList<LinkedHashMap<Long, Pair<Hop[],CNodeTpl>>>();
+            _cplans.put(type, plans);
+        }
+        plans.add(cplans);
+
+        //count one compiled cplan per operator chain; cplans.size() would also
+        //count all intermediate (subsumed) plans of that chain
+        if( DMLScript.STATISTICS )
+            Statistics.incrementCodegenCPlanCompile(1);
+    }
+
+    /**
+     * Indicates if any registered plan of the given template type covers the given hop.
+     *
+     * @param type template type
+     * @param hopID hop id
+     * @return true if a registered plan of that type has the hop id as a key
+     */
+    public boolean containsHop(TemplateType type, long hopID) {
+        ArrayList<LinkedHashMap<Long, Pair<Hop[],CNodeTpl>>> plans = _cplans.get(type);
+        if( plans == null )
+            return false;
+        for( LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> cpplans : plans )
+            if( cpplans.containsKey(hopID) )
+                return true;
+        return false;
+    }
+
+    /**
+     * Resolves overlaps between template types, extracts the top-level plan of every
+     * registered operator chain, and merges cellwise plans into consuming row
+     * aggregate plans where possible. Note that conflict resolution modifies the
+     * registered plans in place.
+     *
+     * @return top-level plans keyed by hop id (empty if nothing was registered)
+     */
+    public LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> getTopLevelCplans()
+    {
+        if( _cplans.isEmpty() )
+            return new LinkedHashMap<Long, Pair<Hop[],CNodeTpl>>();
+
+        //resolve conflicts, i.e., overlap, between template types
+        resolvePlanConflicts();
+
+        //extract top level (subsuming) cplans per type and operator chain
+        LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> ret = new LinkedHashMap<Long, Pair<Hop[],CNodeTpl>>();
+        for( ArrayList<LinkedHashMap<Long, Pair<Hop[],CNodeTpl>>> plans : _cplans.values() ) {
+            for( LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> intermediateCplans : plans ) {
+                Entry<Long, Pair<Hop[],CNodeTpl>> cplan = TemplateUtils.getTopLevelCpplan(intermediateCplans);
+                if( cplan != null )
+                    ret.put(cplan.getKey(), cplan.getValue());
+            }
+        }
+
+        //merge top level plans if possible //TODO move to rowagg template
+        return mergeRowAggregateCellwisePlans(ret);
+    }
+
+    /**
+     * Resolves conflicts between overlapping cplans of different types: outer
+     * product and row aggregate plans are preferred over cellwise plans, i.e.,
+     * their hop ids are removed from all cellwise plans.
+     */
+    private void resolvePlanConflicts()
+    {
+        //get different plan categories
+        ArrayList<LinkedHashMap<Long, Pair<Hop[], CNodeTpl>>> cellwisePlans = _cplans.get(TemplateType.CellTpl);
+        ArrayList<LinkedHashMap<Long, Pair<Hop[], CNodeTpl>>> outerprodPlans = _cplans.get(TemplateType.OuterProductTpl);
+        ArrayList<LinkedHashMap<Long, Pair<Hop[], CNodeTpl>>> rowaggPlans = _cplans.get(TemplateType.RowAggTpl);
+
+        //prefer outer product plans over cellwise plans -> remove overlap
+        removeOverlap(cellwisePlans, outerprodPlans);
+
+        //prefer row aggregate plans over cellwise plans -> remove overlap
+        removeOverlap(cellwisePlans, rowaggPlans);
+    }
+
+    /**
+     * Removes every hop id covered by any of the preferred plans from all target plans.
+     * Does nothing if either list is null (i.e., no plans of that type were registered).
+     *
+     * @param target plans to shrink (modified in place)
+     * @param preferred plans that win the conflict (not modified)
+     */
+    private static void removeOverlap(ArrayList<LinkedHashMap<Long, Pair<Hop[], CNodeTpl>>> target,
+        ArrayList<LinkedHashMap<Long, Pair<Hop[], CNodeTpl>>> preferred)
+    {
+        if( target == null || preferred == null )
+            return;
+        for( LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> preferredPlan : preferred )
+            for( LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> targetPlan : target )
+                targetPlan.keySet().removeAll(preferredPlan.keySet());
+    }
+
+    /**
+     * Merges cellwise plans that produce secondary (non-first) inputs of row aggregate
+     * plans into those row aggregate plans and drops the merged cellwise plans.
+     * Cellwise plans with multiple consumers are not merged.
+     *
+     * @param plans top-level plans keyed by hop id; the map itself is not modified, but
+     *        the contained row aggregate plans and cell templates are updated in place
+     * @return copy of plans without the merged cellwise plans
+     */
+    private static LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> mergeRowAggregateCellwisePlans(LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> plans)
+    {
+        LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> ret = new LinkedHashMap<Long, Pair<Hop[],CNodeTpl>>(plans);
+
+        //extract row aggregate templates
+        HashMap<Long, Pair<Hop[],CNodeTpl>> rowaggPlans = new HashMap<Long, Pair<Hop[],CNodeTpl>>();
+        for( Entry<Long, Pair<Hop[],CNodeTpl>> e : plans.entrySet() )
+            if( e.getValue().getValue() instanceof CNodeRowAggVector )
+                rowaggPlans.put(e.getKey(), e.getValue());
+
+        //probe and merge row aggregate secondary inputs (by definition vectors)
+        for( Entry<Long, Pair<Hop[],CNodeTpl>> e : rowaggPlans.entrySet() ) {
+            Pair<Hop[],CNodeTpl> rowaggPlan = e.getValue();
+            CNodeRowAggVector rowaggtpl = (CNodeRowAggVector) rowaggPlan.getValue();
+
+            //iterate over a snapshot of the original inputs; the plan's input
+            //array itself is replaced on every merge below
+            Hop[] inputs = rowaggPlan.getKey();
+            for( int i=1; i<inputs.length; i++ ) {
+                long inhopID = inputs[i].getHopID();
+                Pair<Hop[],CNodeTpl> cellPlan = ret.get(inhopID);
+                if( cellPlan == null || !(cellPlan.getValue() instanceof CNodeCell)
+                    || ((CNodeCell)cellPlan.getValue()).hasMultipleConsumers() )
+                    continue;
+
+                //merge cell template into row agg template
+                CNodeCell celltpl = (CNodeCell) cellPlan.getValue();
+                celltpl.getInput().get(0).setDataType(DataType.MATRIX);
+                rowaggtpl.rReplaceDataNode(rowaggtpl.getOutput(), inhopID, celltpl.getOutput());
+                rowaggtpl.rInsertLookupNode(rowaggtpl.getOutput(),
+                    ((CNodeData)celltpl.getInput().get(0)).getHopID(), new HashMap<Long, CNode>());
+                for( CNode input : celltpl.getInput() )
+                    rowaggtpl.addInput(input);
+                HashSet<Long> inputIDs = TemplateUtils.rGetInputHopIDs(rowaggtpl.getOutput(), new HashSet<Long>());
+
+                //merge with the current (possibly already extended) input hops rather than
+                //the original snapshot, so inputs of previously merged cell plans are kept
+                Hop[] hops = TemplateUtils.mergeDistinct(inputIDs, rowaggPlan.getKey(), cellPlan.getKey());
+                rowaggPlan.setKey(hops);
+
+                //remove cell template
+                ret.remove(inhopID);
+            }
+        }
+
+        return ret;
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/template/OuterProductTpl.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/OuterProductTpl.java b/src/main/java/org/apache/sysml/hops/codegen/template/OuterProductTpl.java
new file mode 100644
index 0000000..c202d3c
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/OuterProductTpl.java
@@ -0,0 +1,489 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.template;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.List;
+
+import org.apache.sysml.api.DMLException;
+import org.apache.sysml.hops.AggBinaryOp;
+import org.apache.sysml.hops.AggUnaryOp;
+import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.ReorgOp;
+import org.apache.sysml.hops.UnaryOp;
+import org.apache.sysml.hops.Hop.AggOp;
+import org.apache.sysml.hops.Hop.Direction;
+import org.apache.sysml.hops.Hop.ReOrgOp;
+import org.apache.sysml.hops.codegen.cplan.CNode;
+import org.apache.sysml.hops.codegen.cplan.CNodeBinary;
+import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType;
+import org.apache.sysml.hops.codegen.cplan.CNodeData;
+import org.apache.sysml.hops.codegen.cplan.CNodeOuterProduct;
+import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
+import org.apache.sysml.hops.codegen.cplan.CNodeUnary;
+import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType;
+import org.apache.sysml.hops.rewrite.HopRewriteUtils;
+import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.runtime.codegen.SpoofOuterProduct.OutProdType;
+import org.apache.sysml.runtime.matrix.data.Pair;
+
+/**
+ * Template for fusing operations around an outer-product matrix multiply
+ * (e.g., U %*% t(V)) whose output has more rows and columns than the common
+ * dimension. The fused chain requires a matrix X that is combined with the outer
+ * product via a sparse-safe binary operation (mult, div), and it ends in a cellwise
+ * operation, a left/right matrix multiply with U or V, or a sum over rows and columns.
+ */
+public class OuterProductTpl extends BaseTpl {
+
+ /** Creates a new outer-product template. */
+ public OuterProductTpl() {
+ super(TemplateType.OuterProductTpl);
+ }
+
+ //binary operations whose other (matrix) operand is registered as the sparse driver X
+ private List<OpOp2> sparseDrivers = new ArrayList<OpOp2>(Arrays.asList(OpOp2.MULT, OpOp2.DIV));
+ //outer product type detected during the boundary search; if still null, it is
+ //inferred from the output dimensions in getOuterProductType
+ private OutProdType _outerProductType = null;
+ //set for a left outer product t(U) %*% ... that is not consumed by a transpose
+ private boolean _transposeOutput = false;
+ //set if the right input of the outer product is not t(V), i.e., V needs to be
+ //transposed; reset if a right matrix multiply with t(V) is found
+ private boolean _transposeInput = false;
+
+ /**
+ * Opens the template for a matrix multiply with known input and output dimensions
+ * whose output has more rows and more columns than the common dimension.
+ */
+ @Override
+ public boolean openTpl(Hop hop) {
+ // outer product (output dimensions are greater than the common dimension)
+ return ( hop instanceof AggBinaryOp && ((AggBinaryOp)hop).isMatrixMultiply() && hop.dimsKnown()
+ && hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown()
+ && (hop.getDim1() > hop.getInput().get(0).getDim2() && hop.getDim2() > hop.getInput().get(1).getDim1()) );
+ }
+
+ /**
+ * Finds the boundaries of the template for the outer product h. U is the left
+ * input of h and V the right input (or the input under its transpose). The search
+ * walks up the parents of h to find the sparse driver X and the initial (top-most)
+ * hop, and then collects the matrix inputs in the order [X or its common child with
+ * the added matrices, U, V or a newly created t(V), added matrices, inputs of h].
+ *
+ * @param h outer-product matrix multiply, which becomes the end hop
+ * @param cplanRegister not used by this implementation
+ * @return true if X, U, and V were found and the initial hop differs from h
+ */
+ @Override
+ public boolean findTplBoundaries(Hop h, CplanRegister cplanRegister) {
+ _endHop = h; //the outer product is the end hop; the boundary search walks upward from it
+ HashMap<String,Hop> uniqueMatrixInputs = new HashMap<String,Hop>();
+ uniqueMatrixInputs.put("U", h.getInput().get(0));
+ if( h.getInput().get(1) instanceof ReorgOp && ((ReorgOp)h.getInput().get(1)).getOp() == ReOrgOp.TRANSPOSE )
+ uniqueMatrixInputs.put("V", h.getInput().get(1).getInput().get(0));
+ else
+ {
+ _transposeInput = true; // we need to transpose V to be tall and skinny
+ uniqueMatrixInputs.put("V", h.getInput().get(1));
+ }
+ rfindOuterProduct(_endHop, _endHop, uniqueMatrixInputs, h.getDim1(), h.getDim2(), new HashSet<Long>());
+
+ if(uniqueMatrixInputs.size() == 3 && _initialHop != null && _initialHop != _endHop ) //sanity check: X, U, and V were found
+ {
+ //check if added matrices can be inferred from the input matrices, e.g., (X!=0) or abs(X) are derived from X
+ Hop commonChild = null;
+ if(! _adddedMatrices.isEmpty() ) {
+ //if addedMatrices does not have a common child with input X then do not compile
+ commonChild = TemplateUtils.commonChild(_adddedMatrices,uniqueMatrixInputs.get("X"));
+ if(commonChild == null ) // there are multiple matrices involved other than X
+ return false;
+ }
+ if(commonChild != null) {
+ _matrixInputs.add(commonChild); //add common child as the major input matrix
+ _adddedMatrices.add(uniqueMatrixInputs.get("X")); // put unique matrix as one of the additional matrices that is a chain of cell wise operations for the input matrix
+ }
+ else {
+ _matrixInputs.add(uniqueMatrixInputs.get("X")); //major matrix is the sparse driver
+ }
+ _matrixInputs.add(uniqueMatrixInputs.get("U"));
+
+ if(_transposeInput) {
+ ReorgOp transposeV = HopRewriteUtils.createTranspose(uniqueMatrixInputs.get("V"));
+ //ReorgOp transposeV = new ReorgOp("", uniqueMatrixInputs.get("V").getDataType(), uniqueMatrixInputs.get("V").getValueType(), ReOrgOp.TRANSPOSE, uniqueMatrixInputs.get("V"));
+ _matrixInputs.add(transposeV);
+ }
+ else {
+ _matrixInputs.add(uniqueMatrixInputs.get("V"));
+ }
+
+
+ //add also added matrices so that they can be interpreted as inputs
+ for(Hop addedMatrix : _adddedMatrices)
+ if(!_matrixInputs.contains(addedMatrix))
+ _matrixInputs.add(addedMatrix);
+
+ //add the inputs of _endHop (this handles the wdivmm right case, where both t(V) and V are needed as inputs)
+ //NOTE(review): the first input of _endHop is U, which is thereby added a second time; confirm this is intended
+ for (Hop hop: _endHop.getInput())
+ _matrixInputs.add(hop);
+
+ return true;
+ }
+ else
+ return false;
+
+ }
+ /**
+ * Recursively walks from h up through its parents and classifies the operations
+ * that consume the outer product. Cellwise unary/binary operations of the
+ * outer-product size register the sparse driver X or additional matrices. A
+ * subsequent matrix multiply with U or V marks a left/right outer product. A sum
+ * over rows and columns marks an aggregate outer product. Sets _initialHop and
+ * _outerProductType as they are found.
+ *
+ * @param child the hop from which h was reached (h itself on the initial call)
+ * @param h current hop
+ * @param uniqueMatrixInputs hops of "X", "U", and "V", updated in place
+ * @param outerProductDim1 number of rows of the outer product
+ * @param outerProductDim2 number of columns of the outer product
+ * @param memo ids of hops whose parents have already been processed
+ */
+ private void rfindOuterProduct(Hop child, Hop h, HashMap<String,Hop> uniqueMatrixInputs, long outerProductDim1, long outerProductDim2, HashSet<Long> memo)
+ {
+ if(memo.contains(h.getHopID()))
+ return;
+
+ if( ( h instanceof UnaryOp || h instanceof BinaryOp ) //unary operation or binary operation
+ && h.getDataType() == DataType.MATRIX // Output is a matrix
+ && h.getDim1() == outerProductDim1 && h.getDim2() == outerProductDim2 // output is the same size as the matrix
+ && TemplateUtils.isOperationSupported(h)) // operation is supported in codegen
+ {
+ if(h instanceof BinaryOp)
+ {
+
+ // find the other child rather than the one that called the parent
+ Hop otherChild = h.getInput().get(0) != child ? h.getInput().get(0) : h.getInput().get(1);
+
+ //if scalar or vector then we fuse it similar to the way we fuse celltpl,
+ if(TemplateUtils.isVectorOrScalar(otherChild))
+ {
+ _initialHop = h;
+ _outerProductType = OutProdType.CELLWISE_OUTER_PRODUCT;
+
+ }
+ // other child is a matrix
+ else
+ {
+ //if the binary operation is sparse safe (mult, div)
+ if(sparseDrivers.contains(((BinaryOp)h).getOp()) )
+ {
+ if(!uniqueMatrixInputs.containsKey("X"))
+ {
+ //extra sanity check
+ if(otherChild.getDim1() == outerProductDim1 && otherChild.getDim2() == outerProductDim2) {
+ uniqueMatrixInputs.put("X", otherChild);
+ _initialHop = h;
+ }
+ else { //matrix size does not match what is expected for X
+ return;
+ }
+ }
+ }
+ else {
+ _adddedMatrices.add(otherChild);
+ }
+ }
+ }
+ }
+
+ if( h instanceof AggBinaryOp && ((AggBinaryOp) h).isMatrixMultiply() && h != child) //make sure that the AggBinaryOp is not the same as the outerproduct that triggered this method
+ {
+ if(memo.contains(h.getInput().get(0).getHopID())) { // if current node is the parent for the left child then it is right matrix multiply
+
+ if (h.getInput().get(1) == uniqueMatrixInputs.get("V") )//right operand is V
+ {
+ _initialHop = h;
+ _outerProductType = OutProdType.RIGHT_OUTER_PRODUCT;
+ return;
+ }
+ //right operand is t(V)
+ else if(h.getInput().get(1) instanceof ReorgOp && ((ReorgOp)h.getInput().get(1)).getOp() == ReOrgOp.TRANSPOSE && h.getInput().get(1).getInput().get(0) == uniqueMatrixInputs.get("V") )
+ {
+ //replace V with T(V)
+ uniqueMatrixInputs.put("V", h.getInput().get(1));
+ _transposeInput = false; //no need to transpose Input
+ _initialHop = h;
+ _outerProductType = OutProdType.RIGHT_OUTER_PRODUCT;
+ return;
+ }
+ else
+ {
+ _initialHop = h.getInput().get(0); // set the child that was processed
+ return;
+ }
+ }
+ else {//left matrix multiply
+
+ //left is T(U)
+ if (h.getInput().get(0) instanceof ReorgOp && ((ReorgOp)h.getInput().get(0)).getOp() == ReOrgOp.TRANSPOSE && h.getInput().get(0).getInput().get(0) == uniqueMatrixInputs.get("U") )
+ {
+ _initialHop = h;
+ _outerProductType = OutProdType.LEFT_OUTER_PRODUCT;
+ //T(T(U) %*% ..): if a parent transposes the result, it becomes the initial hop
+ for(Hop hParent : h.getParent())
+ if(hParent instanceof ReorgOp && ((ReorgOp)hParent).getOp() == ReOrgOp.TRANSPOSE) {
+ _initialHop = hParent; // set the transpose hop
+ return;
+ }
+ _transposeOutput = true;
+ return;
+ }
+ else {
+ _initialHop = h.getInput().get(1); // set the child that was processed
+ return;
+ }
+ }
+ }
+
+ if( h instanceof AggUnaryOp && ((AggUnaryOp) h).getOp() == AggOp.SUM
+ && ((AggUnaryOp) h).getDirection() == Direction.RowCol)
+ {
+ _initialHop = h;
+ _outerProductType = OutProdType.AGG_OUTER_PRODUCT;
+ return;
+ }
+
+ memo.add(h.getHopID());
+ //process parents recursively
+ for( Hop parent : h.getParent())
+ rfindOuterProduct(h, parent,uniqueMatrixInputs, outerProductDim1,outerProductDim2, memo);
+ }
+
+ ////////////////Helper methods for finding boundaries
+ /**
+ * Returns the outer product type found during the boundary search. If none was
+ * found, the type is inferred from the dimensions of the output hop and cached.
+ * Returns null if no inference rule matches. Parameter X is currently unused.
+ */
+ private OutProdType getOuterProductType(Hop X, Hop U, Hop V, Hop out)
+ {
+ if (_outerProductType != null)
+ return _outerProductType;
+
+
+ //extra checks to infer type
+ if (out.getDataType() == DataType.SCALAR) // sum
+ {
+ _outerProductType = OutProdType.AGG_OUTER_PRODUCT;
+ }
+ else if( isDimsEqual(out,V) && out instanceof ReorgOp) // the second condition is added because sometimes V and U might be same dimensions if the dims of X are equal
+ {
+ _outerProductType = OutProdType.LEFT_OUTER_PRODUCT;
+ }
+ else if( isDimsEqual(out,U))
+ {
+ _outerProductType = OutProdType.RIGHT_OUTER_PRODUCT;
+ }
+ else if ( isDimsEqual(out,X) )
+ {
+ _outerProductType = OutProdType.CELLWISE_OUTER_PRODUCT;
+ }
+
+ return _outerProductType;
+ }
+
+ //returns true if both hops have the same number of rows and the same number of columns
+ private static boolean isDimsEqual(Hop hop1, Hop hop2)
+ {
+ if(hop1.getDim1() == hop2.getDim1() && hop1.getDim2() == hop2.getDim2())
+ return true;
+ return false;
+ }
+
+ /**
+ * Constructs the outer-product cplan(s), starting from the initial hop. The three
+ * main inputs are registered with the shapes seen by the generated code: X as a
+ * 1x1 scalar, and U and V as 1 x ncol row vectors.
+ *
+ * @param compileLiterals passed through to operand fetching and cellwise fusion
+ * @return the constructed plans keyed by hop id
+ */
+ @Override
+ public LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> constructTplCplan(boolean compileLiterals) throws DMLException {
+
+ //re-assign the dimensions of inputs to match the generated code dimensions
+
+ //matrix X is a scalar in the generated code
+ _initialCnodes.add(new CNodeData(_matrixInputs.get(0), 1,1,DataType.SCALAR));
+
+ //matrix U (row vector in the generated code)
+ _initialCnodes.add(new CNodeData(_matrixInputs.get(1), 1,(int)_matrixInputs.get(1).getDim2(), DataType.MATRIX));
+
+ //matrix V, or t(V) if it was created in findTplBoundaries (row vector in the generated code)
+ _initialCnodes.add(new CNodeData(_matrixInputs.get(2), 1,(int)_matrixInputs.get(2).getDim2(),DataType.MATRIX));
+
+ rConstructOuterProdCplan(_initialHop, _initialHop, new HashSet<Long>(), compileLiterals);
+ return _cpplans;
+ }
+
+ /**
+ * Recursively constructs the cnodes of the outer-product template, visiting the
+ * inputs of a hop before the hop itself. A hop is only translated after the end
+ * hop has been visited and once TemplateUtils.inputsAreGenerated accepts it. Only
+ * hops that pass this check and do not abort early are added to memo.
+ *
+ * @param root the initial (top-most) hop of the template
+ * @param hop current hop
+ * @param memo ids of hops already translated
+ * @param compileLiterals passed through to operand fetching and cellwise fusion
+ */
+ private void rConstructOuterProdCplan(Hop root, Hop hop, HashSet<Long> memo, boolean compileLiterals) throws DMLException
+ {
+ if( memo.contains(hop.getHopID()) )
+ return;
+ //process children recursively
+ for( Hop c : hop.getInput() )
+ rConstructOuterProdCplan(root, c, memo, compileLiterals);
+
+ //organize the main inputs
+ Hop X, U, V;
+ X = _matrixInputs.get(0);
+ U = _matrixInputs.get(1);
+ V = _matrixInputs.get(2);
+ if(hop==_endHop)
+ _endHopReached = true;
+
+ // first hop to enter here should be _endHop
+ if(TemplateUtils.inputsAreGenerated(hop,_matrixInputs,_cpplans) && _endHopReached) // if direct children are DataGenOps, literals, or already in the cpplans then we are ready to generate code
+ {
+ CNodeOuterProduct outerProdTmpl = null;
+
+ //Fetch operands
+ CNode out = null;
+ ArrayList<CNode> addedCNodes = new ArrayList<CNode>();
+ ArrayList<Hop> addedHops = new ArrayList<Hop>();
+ ArrayList<CNode> cnodeData = TemplateUtils.fetchOperands(hop, _cpplans, addedCNodes, addedHops, _initialCnodes, compileLiterals);
+
+ //if operands are scalar or independent from X
+ boolean independentOperands = hop != root && (hop.getDataType() == DataType.SCALAR || TemplateUtils.isOperandsIndependent(cnodeData, addedHops, new String[]{_matrixInputs.get(0).getName(),_matrixInputs.get(1).getName(),_matrixInputs.get(2).getName()}));
+ if(!independentOperands)
+ {
+ if(hop instanceof UnaryOp)
+ {
+ CNode cdata1 = cnodeData.get(0);
+
+ //Primitive Operation has the same name as Hop Type OpOp1
+ String primitiveOpName = ((UnaryOp)hop).getOp().toString();
+ out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
+ }
+ else if(hop instanceof BinaryOp)
+ {
+ CNode cdata1 = cnodeData.get(0);
+ CNode cdata2 = cnodeData.get(1);
+
+ //Primitive Operation has the same name as Hop Type OpOp2
+ String primitiveOpName = ((BinaryOp)hop).getOp().toString();
+
+ //cdata1 is a row or column vector
+ if( (cdata1.getNumRows() > 1 && cdata1.getNumCols() == 1) || (cdata1.getNumRows() == 1 && cdata1.getNumCols() > 1) )
+ {
+ //second argument is always the vector
+ cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP);
+ //out = new CNodeBinary(tmp, cdata2, BinType.valueOf(primitiveOpName));
+ }
+ //cdata1 is a matrix
+ else if ( (cdata1.getNumRows() > 1 && cdata1.getNumCols() > 1) )
+ {
+ CellTpl cellTpl = new CellTpl();
+ cdata1 = cellTpl.fuseCellWise(hop.getInput().get(0), _matrixInputs.get(0), compileLiterals); // second argument is always matrix X
+ if (cdata1 == null)
+ return;
+ }
+ //cdata2 is vector
+ //else if( cdata2 instanceof CNodeData && (((CNodeData)cdata2).getNumRows() > 1 && ((CNodeData)cdata2).getNumCols() == 1) || ( ((CNodeData)cdata2).getNumRows() == 1 && ((CNodeData)cdata2).getNumCols() > 1 ))
+ if( (cdata2.getNumRows() > 1 && cdata2.getNumCols() == 1) || (cdata2.getNumRows() == 1 && cdata2.getNumCols() > 1) )
+ {
+ cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP);
+ //out = new CNodeBinary(cdata1, tmp, BinType.valueOf(primitiveOpName));
+ }
+ //cdata2 is a matrix
+ else if ( (cdata2.getNumRows() > 1 && cdata2.getNumCols() > 1) )
+ {
+ CellTpl cellTpl = new CellTpl();
+ cdata2 = cellTpl.fuseCellWise(hop.getInput().get(1), _matrixInputs.get(0), compileLiterals); // second argument is always matrix X
+ if (cdata2 == null)
+ return;
+ }
+ out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
+ }
+ else if(hop instanceof AggBinaryOp)
+ {
+ CNode cdata1 = cnodeData.get(0);
+ CNode cdata2 = cnodeData.get(1); // remember that we already fetched what is under transpose
+
+ //outerproduct U%*%t(V) then we should have passed in V as the input
+ if(hop.getInput().get(0) == U && hop.getInput().get(1) instanceof ReorgOp && hop.getInput().get(1).getInput().get(0) == V)
+ {
+ //re-assign cdata2 to read V instead of t(V)
+ cdata2 = _initialCnodes.get(2); // the initialCNodes holds V
+ out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT);
+ }
+
+ //outerproduct U%*%V then we should have passed in t(V) as the input
+ else if(hop.getInput().get(0) == U && V instanceof ReorgOp && V.getInput().get(0)== hop.getInput().get(1))
+ {
+ //re-assign cdata2 to read t(V) instead of V
+ cdata2 = _initialCnodes.get(2); // the initialCNodes holds transpose of V
+ out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT);
+ }
+ //outerproduct U%*%V but not right wdivmm so we did not pass T(V)
+ else if(hop.getInput().get(0) == U && hop.getInput().get(1) == V )
+ {
+ //re-assign cdata2 to read t(V) instead of V
+ cdata2 = _initialCnodes.get(2); // the initialCNodes holds transpose of V
+ out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT);
+ }
+
+ //left outerproduct (i.e., left operand is T(U) )
+ else if(hop.getInput().get(0) instanceof ReorgOp && hop.getInput().get(0).getInput().get(0) == U)
+ {
+ //scalar is cdata2
+ out = new CNodeBinary(cdata2, cdata1, BinType.VECT_MULT_ADD);
+ }
+
+ //right outerproduct (i.e., right operand is V )
+ else if(hop.getInput().get(1) != U && hop.getInput().get(1) == V)
+ {
+ cdata2 = _initialCnodes.get(2);
+ out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MULT_ADD);
+ }
+
+ //right outerproduct (i.e., right operand is t(V) )
+ else if(hop.getInput().get(1) instanceof ReorgOp && hop.getInput().get(1).getInput().get(0) == V)
+ {
+ cdata2 = _initialCnodes.get(2);
+ out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MULT_ADD);
+ }
+ }
+ else if ( hop instanceof ReorgOp && ((ReorgOp)hop).getOp() == ReOrgOp.TRANSPOSE && root == hop) // if transpose, wire the input in T( T(U) ...)
+ {
+ out = cnodeData.get(0);
+ }
+ else if (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getOp() == AggOp.SUM && root == hop
+ && ((AggUnaryOp)hop).getDirection() == Direction.RowCol )
+ {
+ //full sum at the root: pass through the input cnode
+ out = cnodeData.get(0);
+ }
+ }
+ // wire output to the template
+ if(out != null || independentOperands)
+ {
+ if(_cpplans.isEmpty())
+ {
+ //first initialization has to have the first variable as input
+ ArrayList<CNode> initialInputs = new ArrayList<CNode>();
+
+ if(independentOperands) // pass the hop itself as an input instead of its children
+ {
+ CNode c = new CNodeData(hop);
+ initialInputs.addAll(_initialCnodes);
+ initialInputs.add(c);
+ outerProdTmpl = new CNodeOuterProduct(initialInputs, c);
+ outerProdTmpl.setOutProdType(getOuterProductType(X, U, V, root));
+ outerProdTmpl.setTransposeOutput(_transposeOutput);
+ _cpplans.put(hop.getHopID(), new Pair<Hop[],CNodeTpl>(new Hop[] {X,U,V,hop} ,outerProdTmpl));
+ }
+ else
+ {
+ initialInputs.addAll(_initialCnodes);
+ initialInputs.addAll(cnodeData);
+ outerProdTmpl = new CNodeOuterProduct(initialInputs, out);
+ outerProdTmpl.setOutProdType(getOuterProductType(X, U, V, root));
+ outerProdTmpl.setTransposeOutput(_transposeOutput);
+
+ //input hops: X, U, V followed by the additionally fetched hops
+ Hop[] hopArray = new Hop[addedHops.size()+3];
+ hopArray[0] = X;
+ hopArray[1] = U;
+ hopArray[2] = V;
+
+ System.arraycopy( addedHops.toArray(), 0, hopArray, 3, addedHops.size());
+
+ _cpplans.put(hop.getHopID(), new Pair<Hop[],CNodeTpl>(hopArray,outerProdTmpl));
+ }
+ }
+ else
+ {
+ if(independentOperands)
+ {
+ CNode c = new CNodeData(hop);
+ //clear Operands
+ addedCNodes.clear();
+ addedHops.clear();
+
+ //added the current hop as the input
+ addedCNodes.add(c);
+ addedHops.add(hop);
+ out = c;
+ }
+
+ //wire the output to existing or new template
+ TemplateUtils.setOutputToExistingTemplate(hop, out, _cpplans, addedCNodes, addedHops);
+ }
+ }
+ //note: only reached for hops that were translated (see method comment)
+ memo.add(hop.getHopID());
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/template/RowAggTpl.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/RowAggTpl.java b/src/main/java/org/apache/sysml/hops/codegen/template/RowAggTpl.java
new file mode 100644
index 0000000..0aff9ae
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/RowAggTpl.java
@@ -0,0 +1,321 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.template;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+
+import org.apache.sysml.api.DMLException;
+import org.apache.sysml.hops.AggBinaryOp;
+import org.apache.sysml.hops.AggUnaryOp;
+import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.DataGenOp;
+import org.apache.sysml.hops.DataOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.hops.ReorgOp;
+import org.apache.sysml.hops.UnaryOp;
+import org.apache.sysml.hops.codegen.cplan.CNode;
+import org.apache.sysml.hops.codegen.cplan.CNodeBinary;
+import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType;
+import org.apache.sysml.hops.codegen.cplan.CNodeData;
+import org.apache.sysml.hops.codegen.cplan.CNodeRowAggVector;
+import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
+import org.apache.sysml.hops.codegen.cplan.CNodeUnary;
+import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType;
+import org.apache.sysml.hops.Hop.AggOp;
+import org.apache.sysml.hops.Hop.Direction;
+import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.Hop.ReOrgOp;
+import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.runtime.matrix.data.Pair;
+
+public class RowAggTpl extends BaseTpl {
+
+	public RowAggTpl() {
+		super(TemplateType.RowAggTpl);
+	}
+
+	/**
+	 * Opens a row aggregate template at aggregate operations (matrix
+	 * multiplication or unary aggregate) whose output is a vector.
+	 *
+	 * @param hop candidate root hop
+	 * @return true if a row aggregate template can start at this hop
+	 */
+	@Override
+	public boolean openTpl(Hop hop) {
+		boolean isAggregate = hop instanceof AggBinaryOp || hop instanceof AggUnaryOp;
+		boolean isVectorOutput = (hop.getDim1()==1 && hop.getDim2()!=1) || (hop.getDim1()!=1 && hop.getDim2()==1);
+		return isAggregate && isVectorOutput;
+	}
+
+	/**
+	 * Searches the DAG below the initial hop for a supported row aggregate pattern,
+	 * i.e., t(X) %*% (... X %*% v ...) or colSums(... rowSums(X) ...).
+	 *
+	 * @param initialHop root hop of the candidate template
+	 * @param cplanRegister register of already constructed cplans
+	 * @return true if an end hop was found and no row aggregate cplan is registered for the root yet
+	 */
+	@Override
+	public boolean findTplBoundaries(Hop initialHop, CplanRegister cplanRegister) {
+		_initialHop = initialHop;
+		if(initialHop instanceof AggBinaryOp) {
+			// for simplicity we assume that the first operand should be t(X) however, it could be later on W.T(X)
+			Hop first = initialHop.getInput().get(0);
+			if( first instanceof ReorgOp && ((ReorgOp)first).getOp() == ReOrgOp.TRANSPOSE )
+				_matrixInputs.add(first.getInput().get(0)); //add what is under the transpose
+			else
+				return false;
+		}
+		rFindRowAggPattern(initialHop, new HashSet<Long>());
+
+		if(cplanRegister.containsHop(TemplateType.RowAggTpl, initialHop.getHopID()))
+			return false;
+
+		return (_endHop != null);
+	}
+
+	/** Returns true if the given hop is a row-wise sum aggregate, i.e., rowSums(...). */
+	private static boolean isRowSums(Hop hop) {
+		return hop instanceof AggUnaryOp
+			&& ((AggUnaryOp)hop).getDirection() == Direction.Row
+			&& ((AggUnaryOp)hop).getOp() == AggOp.SUM;
+	}
+
+	/** Returns true if the given hop is a column-wise sum aggregate, i.e., colSums(...). */
+	private static boolean isColSums(Hop hop) {
+		return hop instanceof AggUnaryOp
+			&& ((AggUnaryOp)hop).getDirection() == Direction.Col
+			&& ((AggUnaryOp)hop).getOp() == AggOp.SUM;
+	}
+
+	/**
+	 * Recursively traverses supported cell-wise unary/binary operations below the root
+	 * and records the end of the pattern (X %*% v or rowSums(X)) in _endHop.
+	 *
+	 * @param h current hop
+	 * @param memo ids of already processed hops
+	 */
+	private void rFindRowAggPattern(Hop h, HashSet<Long> memo)
+	{
+		if(memo.contains(h.getHopID()) || h.getDataType() == DataType.SCALAR
+			|| h instanceof DataOp || h instanceof DataGenOp || h instanceof LiteralOp) {
+			return;
+		}
+
+		boolean continueTraversing = false;
+		if (h instanceof AggBinaryOp)
+		{
+			if(h != _initialHop) {
+				//t(X) %*% ..... X %*% v, check that X is the same as what we saw previously under transpose
+				//(_matrixInputs is still empty if the root is a colSums aggregate)
+				if( !_matrixInputs.isEmpty() && h.getInput().get(0).equals(_matrixInputs.get(0))
+					&& TemplateUtils.isVector(h.getInput().get(1)) ) {
+					_endHop = h;
+				}
+			}
+			else {
+				continueTraversing = true;
+			}
+		}
+		// if initial hop is colSums continue
+		// (check identity first: the root may be an AggBinaryOp, which must not be cast to AggUnaryOp)
+		else if( h == _initialHop && isColSums(h) )
+		{
+			continueTraversing = true;
+		}
+		//rowSums(X)
+		else if( isRowSums(h) )
+		{
+			// check if root pattern is colsums
+			if( isColSums(_initialHop) )
+			{
+				//TODO Now the pattern is limited to finding rowSums
+				_matrixInputs.add(h.getInput().get(0));
+				_endHop = h;
+			}
+		}
+		// unary operation || binary operation with first input as a matrix || binary operation with second input as a matrix
+		else if( ( h instanceof UnaryOp
+				|| (h instanceof BinaryOp && h.getInput().get(0).getDataType() == DataType.MATRIX && TemplateUtils.isVectorOrScalar(h.getInput().get(1)))
+				|| (h instanceof BinaryOp && TemplateUtils.isVectorOrScalar(h.getInput().get(0)) && h.getInput().get(1).getDataType() == DataType.MATRIX) )
+			&& h.getDataType() == DataType.MATRIX // output is a matrix
+			&& TemplateUtils.isOperationSupported(h) ) // operation is supported in codegen
+		{
+			continueTraversing = true;
+		}
+
+		//stop traversing if the conditions do not apply
+		if( !continueTraversing )
+			return;
+
+		//process children recursively
+		for( Hop in : h.getInput() )
+			rFindRowAggPattern(in, memo);
+		memo.add(h.getHopID());
+	}
+
+	/**
+	 * Constructs the cplans of the row aggregate template, with the main matrix
+	 * input X as the first template input.
+	 *
+	 * @param compileLiterals if true, all literals are compiled into the generated code
+	 * @return cplans keyed by hop id
+	 * @throws DMLException if an operation inside the pattern is not supported
+	 */
+	@Override
+	public LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> constructTplCplan(boolean compileLiterals)
+		throws DMLException {
+
+		//re-assign the dimensions of inputs to match the generated code dimensions
+		_initialCnodes.add(new CNodeData(_matrixInputs.get(0)));
+
+		rConstructRowAggCplan(_initialHop,_initialHop,new HashSet<Long>(), compileLiterals);
+		return _cpplans;
+	}
+
+	/**
+	 * Constructs the cplan bottom-up: children are processed first, and a hop is
+	 * compiled once the end hop has been reached and all its inputs are available.
+	 *
+	 * @param root root hop of the template
+	 * @param hop current hop
+	 * @param memo ids of already processed hops
+	 * @param compileLiterals if true, all literals are compiled into the generated code
+	 * @throws DMLException if a dependent operation is not supported by this template
+	 */
+	private void rConstructRowAggCplan(Hop root, Hop hop, HashSet<Long> memo, boolean compileLiterals) throws DMLException
+	{
+		if( memo.contains(hop.getHopID()) )
+			return;
+		//process children recursively
+		for( Hop c : hop.getInput() )
+			rConstructRowAggCplan(root, c, memo, compileLiterals);
+		if(hop == _endHop)
+			_endHopReached = true;
+
+		// first hop to enter here should be _endHop; code generation starts once the
+		// direct children are DataGenOps, literals, or already in the cpplans
+		if( TemplateUtils.inputsAreGenerated(hop,_matrixInputs,_cpplans) && _endHopReached )
+		{
+			//fetch operands
+			CNode out = null;
+			ArrayList<CNode> addedCNodes = new ArrayList<CNode>();
+			ArrayList<Hop> addedHops = new ArrayList<Hop>();
+			ArrayList<CNode> cnodeData = TemplateUtils.fetchOperands(hop, _cpplans, addedCNodes, addedHops, _initialCnodes, compileLiterals);
+
+			//if operands are scalar or independent from X
+			boolean independentOperands = hop.getDataType() == DataType.SCALAR
+				|| TemplateUtils.isOperandsIndependent(cnodeData, addedHops, new String[] {_matrixInputs.get(0).getName()});
+
+			if( !independentOperands )
+			{
+				if( hop instanceof AggUnaryOp )
+				{
+					CNode cdata1 = cnodeData.get(0);
+					//set the out cnode based on the operation
+					if( isRowSums(hop) )
+					{
+						if(hop.getInput().get(0).getDim2()==1)
+							out = (cdata1.getDataType()==DataType.SCALAR) ? cdata1 : new CNodeUnary(cdata1,UnaryType.LOOKUP);
+						else
+							out = new CNodeUnary(cdata1, UnaryType.ROW_SUMS);
+					}
+					// if colsums is the root hop, wire the input to the out because the column
+					// aggregation is done automatically by the template
+					else if( isColSums(hop) && root == hop )
+					{
+						//vector div add without temporary copy
+						if(cdata1 instanceof CNodeBinary && ((CNodeBinary)cdata1).getType()==BinType.VECT_DIV_SCALAR)
+							out = new CNodeBinary(cdata1.getInput().get(0), cdata1.getInput().get(1), BinType.VECT_DIV_ADD);
+						else
+							out = cdata1;
+					}
+				}
+				else if( hop instanceof AggBinaryOp )
+				{
+					//fetch operands specific to the operation
+					CNode cdata1 = cnodeData.get(0);
+					CNode cdata2 = cnodeData.get(1);
+
+					//choose the operation based on the transpose
+					//(fetchOperands already resolved t(X) to the data node of X)
+					if( hop.getInput().get(0) instanceof ReorgOp && ((ReorgOp)hop.getInput().get(0)).getOp()==ReOrgOp.TRANSPOSE )
+					{
+						out = new CNodeBinary(cdata2, cdata1, BinType.VECT_MULT_ADD);
+					}
+					else if( hop.getInput().get(0).getDim2()==1 && hop.getInput().get(1).getDim2()==1 )
+					{
+						out = new CNodeBinary((cdata1.getDataType()==DataType.SCALAR)? cdata1 : new CNodeUnary(cdata1, UnaryType.LOOKUP0),
+							(cdata2.getDataType()==DataType.SCALAR)? cdata2 : new CNodeUnary(cdata2, UnaryType.LOOKUP0), BinType.MULT);
+					}
+					else
+					{
+						out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT);
+					}
+				}
+				else if( hop instanceof BinaryOp )
+				{
+					CNode cdata1 = cnodeData.get(0);
+					CNode cdata2 = cnodeData.get(1);
+
+					// if one input is a matrix then we need to do vector by scalar operations
+					if( hop.getInput().get(0).getDim1() > 1 && hop.getInput().get(0).getDim2() > 1 )
+					{
+						//note: only division is supported for a matrix input
+						if( ((BinaryOp)hop).getOp() == OpOp2.DIV )
+							out = new CNodeBinary(cdata1, cdata2, BinType.VECT_DIV_SCALAR);
+					}
+					else //one input is a vector/scalar other is a scalar
+					{
+						//primitive operation has the same name as Hop Type OpOp2
+						String primitiveOpName = ((BinaryOp)hop).getOp().toString();
+
+						//vector operands are wrapped into a LOOKUP unary node
+						if( (cdata1.getNumRows() > 1 && cdata1.getNumCols() == 1) || (cdata1.getNumRows() == 1 && cdata1.getNumCols() > 1) )
+							cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP);
+						if( (cdata2.getNumRows() > 1 && cdata2.getNumCols() == 1) || (cdata2.getNumRows() == 1 && cdata2.getNumCols() > 1) )
+							cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP);
+						out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
+					}
+				}
+
+				//out remains unset for operations without row aggregate support (e.g., unary
+				//operations or non-division binary operations with a matrix input); fail with
+				//a descriptive error instead of a NullPointerException
+				if( out == null )
+					throw new DMLException("Unsupported operation in row aggregate template: "
+						+ hop.getClass().getSimpleName() + " (hop id " + hop.getHopID() + ")");
+
+				if( out.getDataType().isMatrix() ) {
+					out.setNumRows(hop.getDim1());
+					out.setNumCols(hop.getDim2());
+				}
+			}
+
+			// wire output to the template
+			if( out != null || independentOperands )
+			{
+				if( _cpplans.isEmpty() )
+				{
+					//first initialization has to have the first variable as input
+					ArrayList<CNode> initialInputs = new ArrayList<CNode>();
+
+					if( independentOperands ) // pass the hop itself as an input instead of its children
+					{
+						CNode c = new CNodeData(hop);
+						initialInputs.addAll(_initialCnodes);
+						initialInputs.add(c);
+						CNodeRowAggVector rowTmpl = new CNodeRowAggVector(initialInputs, c);
+						_cpplans.put(hop.getHopID(), new Pair<Hop[],CNodeTpl>(new Hop[] {_matrixInputs.get(0),hop} ,rowTmpl));
+					}
+					else
+					{
+						initialInputs.addAll(_initialCnodes);
+						initialInputs.addAll(cnodeData);
+						CNodeRowAggVector rowTmpl = new CNodeRowAggVector(initialInputs, out);
+
+						//template input hops: main matrix input followed by the added operand hops
+						Hop[] hopArray = new Hop[addedHops.size()+1];
+						hopArray[0] = _matrixInputs.get(0);
+						System.arraycopy( addedHops.toArray(), 0, hopArray, 1, addedHops.size());
+
+						_cpplans.put(hop.getHopID(), new Pair<Hop[],CNodeTpl>(hopArray,rowTmpl));
+					}
+				}
+				else
+				{
+					if( independentOperands )
+					{
+						//replace the operands by the current hop as the single new input
+						CNode c = new CNodeData(hop);
+						addedCNodes.clear();
+						addedHops.clear();
+						addedCNodes.add(c);
+						addedHops.add(hop);
+						out = c;
+					}
+					//wire the output to existing or new template
+					TemplateUtils.setOutputToExistingTemplate(hop, out, _cpplans, addedCNodes, addedHops);
+				}
+			}
+			memo.add(hop.getHopID());
+		}
+	}
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
new file mode 100644
index 0000000..fd8a960
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
@@ -0,0 +1,313 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.template;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.LinkedHashSet;
+import java.util.Map.Entry;
+
+import org.apache.sysml.hops.AggUnaryOp;
+import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.DataGenOp;
+import org.apache.sysml.hops.DataOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.hops.ReorgOp;
+import org.apache.sysml.hops.Hop.AggOp;
+import org.apache.sysml.hops.Hop.Direction;
+import org.apache.sysml.hops.Hop.ReOrgOp;
+import org.apache.sysml.hops.UnaryOp;
+import org.apache.sysml.hops.codegen.cplan.CNode;
+import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType;
+import org.apache.sysml.hops.codegen.cplan.CNodeCell;
+import org.apache.sysml.hops.codegen.cplan.CNodeData;
+import org.apache.sysml.hops.codegen.cplan.CNodeOuterProduct;
+import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
+import org.apache.sysml.hops.codegen.cplan.CNodeUnary;
+import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType;
+import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType;
+import org.apache.sysml.runtime.matrix.data.Pair;
+import org.apache.sysml.runtime.util.UtilFunctions;
+
+public class TemplateUtils
+{
+ public static boolean inputsAreGenerated(Hop parent, ArrayList<Hop> inputs, HashMap<Long, Pair<Hop[],CNodeTpl>> cpplans)
+ {
+ if( parent instanceof DataOp || parent instanceof DataGenOp || parent instanceof LiteralOp || inputs.contains(parent) )
+ return false;
+
+ for(Hop hop : parent.getInput() )
+ if(!inputs.contains(hop) && !(hop instanceof DataOp) && !(hop instanceof DataGenOp) && !(hop.getDataType()==DataType.SCALAR) && !isVector(hop) && !(cpplans.containsKey(hop.getHopID())) && !( hop instanceof ReorgOp && ((ReorgOp)hop).getOp() == ReOrgOp.TRANSPOSE && inputsAreGenerated(hop,inputs, cpplans) ))
+ return false;
+ return true;
+ }
+
+ public static ArrayList<CNode> fetchOperands(Hop hop, HashMap<Long, Pair<Hop[],CNodeTpl>> cpplans, ArrayList<CNode> addedCNodes, ArrayList<Hop> addedHops, ArrayList<CNodeData> initialCNodes, boolean compileLiterals)
+ {
+ ArrayList<CNode> cnodeData = new ArrayList<CNode>();
+ for (Hop h: hop.getInput())
+ {
+ CNode cdata = null;
+
+ //CNodeData already in template inputs
+ for(CNodeData c : initialCNodes) {
+ if( c.getHopID() == h.getHopID() ) {
+ cdata = c;
+ break;
+ }
+ }
+
+ if(cdata != null)
+ {
+ cnodeData.add(cdata);
+ continue;
+ }
+ //hop already in the cplan
+ else if(cpplans.containsKey(h.getHopID()))
+ {
+ cdata = cpplans.get(h.getHopID()).getValue().getOutput();
+ }
+ else if(h instanceof ReorgOp && ((ReorgOp)h).getOp()==ReOrgOp.TRANSPOSE )
+ {
+ //fetch what is under the transpose
+ Hop in = h.getInput().get(0);
+ cdata = new CNodeData(in);
+ if(in instanceof DataOp || in instanceof DataGenOp ) {
+ addedCNodes.add(cdata);
+ addedHops.add(in);
+ }
+ }
+ else
+ {
+ //note: only compile literals if forced or integer literals (likely constants)
+ //to increase reuse potential on literal replacement during recompilation
+ cdata = new CNodeData(h);
+ cdata.setLiteral(h instanceof LiteralOp && (compileLiterals
+ || UtilFunctions.isIntegerNumber(((LiteralOp)h).getStringValue())));
+ if( !cdata.isLiteral() ) {
+ addedCNodes.add(cdata);
+ addedHops.add(h);
+ }
+ }
+
+ cnodeData.add(cdata);
+ }
+ return cnodeData;
+ }
+
+ public static void setOutputToExistingTemplate(Hop hop, CNode out, HashMap<Long, Pair<Hop[],CNodeTpl>> cpplans, ArrayList<CNode> addedCNodes, ArrayList<Hop> addedHops)
+ {
+ //get the toplevel rowTemp
+ Entry<Long, Pair<Hop[],CNodeTpl>> cplan = null;
+ Iterator<Entry<Long, Pair<Hop[],CNodeTpl>>> iterator = cpplans.entrySet().iterator();
+ while (iterator.hasNext())
+ cplan = iterator.next();
+
+ CNodeTpl tmpl = cplan.getValue().getValue().clone();
+ tmpl.setDataType(hop.getDataType());
+
+ if(tmpl instanceof CNodeOuterProduct) {
+ ((CNodeOuterProduct) tmpl).setOutProdType( ((CNodeOuterProduct)cplan.getValue().getValue()).getOutProdType());
+ ((CNodeOuterProduct) tmpl).setTransposeOutput(((CNodeOuterProduct)cplan.getValue().getValue()).isTransposeOutput() );
+ }
+ else if( tmpl instanceof CNodeCell ) {
+ ((CNodeCell)tmpl).setCellType(getCellType(hop));
+ ((CNodeCell)tmpl).setMultipleConsumers(hop.getParent().size()>1);
+ }
+
+ //add extra inputs
+ for(CNode c : addedCNodes)
+ tmpl.addInput(c);
+
+ //modify addedHops if they exist
+
+ Hop[] currentInputHops = cplan.getValue().getKey();
+ for (Hop h : currentInputHops)
+ if (addedHops.contains(h))
+ addedHops.remove(h);
+
+ Hop[] extendedHopInputs = new Hop[cplan.getValue().getKey().length + addedHops.size()];
+ System.arraycopy(cplan.getValue().getKey(), 0, extendedHopInputs, 0, cplan.getValue().getKey().length);
+ for(int j=addedHops.size(); j > 0; j--)
+ extendedHopInputs[extendedHopInputs.length-j] = addedHops.get(addedHops.size() - j); //append the added hops to the end of the array
+
+ //set the template output and add it to the cpplans
+ Pair<Hop[],CNodeTpl> pair = new Pair<Hop[],CNodeTpl>(extendedHopInputs,tmpl);
+ pair.getValue().setOutput(out);
+ cpplans.put(hop.getHopID(), pair);
+
+ }
+
+ public static boolean isOperandsIndependent(ArrayList<CNode> cnodeData, ArrayList<Hop> addedHops, String[] varNames)
+ {
+ for(CNode c : cnodeData) {
+ // it is some variable inside the cplan // TODO needs to be modified because sometimes the varname is not null but the variable is in the cplan
+ if(c.getVarname() == null)
+ return false;
+ //if one of the operands is is any of the varnames // if one of the operands is T(X) this condition will apply as well because during fetch operands we fetch what is inside transpose
+ for(String varName : varNames)
+ if(c.getVarname().equals(varName))
+ return false;
+ }
+ return true;
+ }
+
+ public static Entry<Long, Pair<Hop[],CNodeTpl>> getTopLevelCpplan(HashMap<Long, Pair<Hop[],CNodeTpl>> cplans)
+ {
+ Entry<Long, Pair<Hop[],CNodeTpl>> ret = null;
+
+ //get last entry (most fused operators) or special handling
+ boolean hasExp = false;
+ for( Entry<Long, Pair<Hop[],CNodeTpl>> e : cplans.entrySet() )
+ {
+ ret = e; //keep last seen entry
+
+ //special handling overlapping fused operators with exp
+ hasExp |= (ret.getValue().getValue().getOutput() instanceof CNodeUnary
+ && ((CNodeUnary)ret.getValue().getValue().getOutput()).getType()==UnaryType.EXP);
+
+ if( hasExp && ret.getValue().getValue() instanceof CNodeCell
+ && ((CNodeCell)ret.getValue().getValue()).hasMultipleConsumers() )
+ break;
+ }
+
+ return ret;
+ }
+
+ public static boolean isVector(Hop hop) {
+ return (hop.getDataType() == DataType.MATRIX
+ && (hop.getDim1() != 1 && hop.getDim2() == 1
+ || hop.getDim1() == 1 && hop.getDim2() != 1 ) );
+ }
+
+ public static boolean isColVector(CNode hop) {
+ return (hop.getDataType() == DataType.MATRIX
+ && hop.getNumRows() != 1 && hop.getNumCols() == 1);
+ }
+
+ public static boolean isRowVector(CNode hop) {
+ return (hop.getDataType() == DataType.MATRIX
+ && hop.getNumRows() == 1 && hop.getNumCols() != 1);
+ }
+
+ public static boolean isMatrix(Hop hop) {
+ return (hop.getDataType() == DataType.MATRIX && hop.getDim1() != 1 && hop.getDim2()!=1);
+ }
+
+ public static boolean isVectorOrScalar(Hop hop) {
+ return hop.dimsKnown() && (hop.getDataType() == DataType.SCALAR || isVector(hop) );
+ }
+
+ public static boolean isBinaryMatrixRowVector(Hop hop) {
+ if( !(hop instanceof BinaryOp) )
+ return false;
+ Hop left = hop.getInput().get(0);
+ Hop right = hop.getInput().get(1);
+ return left.dimsKnown() && right.dimsKnown()
+ && left.getDataType().isMatrix() && right.getDataType().isMatrix()
+ && left.getDim1() > right.getDim1();
+ }
+
+ public static boolean isOperationSupported(Hop h) {
+ if(h instanceof UnaryOp)
+ return UnaryType.contains(((UnaryOp)h).getOp().toString());
+ else if(h instanceof BinaryOp)
+ return BinType.contains(((BinaryOp)h).getOp().toString());
+ else
+ return false;
+ }
+
+ private static void rfindChildren(Hop hop, HashSet<Hop> children ) {
+ if( hop instanceof UnaryOp || (hop instanceof BinaryOp && hop.getInput().get(0).getDataType() == DataType.MATRIX && TemplateUtils.isVectorOrScalar( hop.getInput().get(1))) || (hop instanceof BinaryOp && TemplateUtils.isVectorOrScalar( hop.getInput().get(0)) && hop.getInput().get(1).getDataType() == DataType.MATRIX) //unary operation or binary operaiton with one matrix and a scalar
+ && hop.getDataType() == DataType.MATRIX )
+ {
+ if(!children.contains(hop))
+ children.add(hop);
+ Hop matrix = TemplateUtils.isMatrix(hop.getInput().get(0)) ? hop.getInput().get(0) : hop.getInput().get(1);
+ rfindChildren(matrix,children);
+ }
+ else
+ children.add(hop);
+ }
+
+ private static Hop findCommonChild(Hop hop1, Hop hop2) {
+ //this method assumes that each two nodes have at most one common child
+ LinkedHashSet<Hop> children1 = new LinkedHashSet<Hop>();
+ LinkedHashSet<Hop> children2 = new LinkedHashSet<Hop>();
+
+ rfindChildren(hop1, children1 );
+ rfindChildren(hop2, children2 );
+
+ //iterate on one set and find the first common child in the other set
+ Iterator<Hop> iter = children1.iterator();
+ while (iter.hasNext()) {
+ Hop candidate = iter.next();
+ if(children2.contains(candidate))
+ return candidate;
+ }
+ return null;
+ }
+
+ public static Hop commonChild(ArrayList<Hop> _adddedMatrices, Hop input) {
+ Hop currentChild = null;
+ //loop on every added matrix and find its common child with the input, if all of them have the same common child then return it, otherwise null
+ for(Hop addedMatrix : _adddedMatrices)
+ {
+ Hop child = findCommonChild(addedMatrix,input);
+ if(child == null) // did not find a common child
+ return null;
+ if(currentChild == null) // first common child to be seen
+ currentChild = child;
+ else if(child.getHopID() != currentChild.getHopID())
+ return null;
+ }
+ return currentChild;
+ }
+
+ public static HashSet<Long> rGetInputHopIDs( CNode node, HashSet<Long> ids ) {
+ if( node instanceof CNodeData && !node.isLiteral() )
+ ids.add(((CNodeData)node).getHopID());
+
+ for( CNode c : node.getInput() )
+ rGetInputHopIDs(c, ids);
+
+ return ids;
+ }
+
+ public static Hop[] mergeDistinct(HashSet<Long> ids, Hop[] input1, Hop[] input2) {
+ Hop[] ret = new Hop[ids.size()];
+ int pos = 0;
+ for( Hop[] input : new Hop[][]{input1, input2} )
+ for( Hop c : input )
+ if( ids.contains(c.getHopID()) )
+ ret[pos++] = c;
+ return ret;
+ }
+
+ private static CellType getCellType(Hop hop) {
+ return (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getOp() == AggOp.SUM) ?
+ ((((AggUnaryOp) hop).getDirection() == Direction.RowCol) ?
+ CellType.FULL_AGG : CellType.ROW_AGG) : CellType.NO_AGG;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 7f65ddd..802a382 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -253,6 +253,12 @@ public class HopRewriteUtils
child.getParent().add( parent );
}
+	/**
+	 * Redirects every parent of hold to reference hnew instead.
+	 * Iterates over a snapshot of the parent list, since replaceChildReference
+	 * detaches child references (and presumably also updates hold's parent list).
+	 */
+	public static void rewireAllParentChildReferences( Hop hold, Hop hnew ) {
+		for( Hop parent : new ArrayList<Hop>(hold.getParent()) )
+			replaceChildReference(parent, hold, hnew);
+	}
+
public static void replaceChildReference( Hop parent, Hop inOld, Hop inNew ) {
int pos = getChildReferencePos(parent, inOld);
removeChildReferenceByPos(parent, inOld, pos);
@@ -491,10 +497,12 @@ public class HopRewriteUtils
input2.getDataType().isMatrix() ? input2 : input1;
BinaryOp bop = new BinaryOp(mainInput.getName(), mainInput.getDataType(),
mainInput.getValueType(), op, input1, input2);
+ //cleanup value type for relational operations
+ if( bop.isPPredOperation() && bop.getDataType().isScalar() )
+ bop.setValueType(ValueType.BOOLEAN);
bop.setOutputBlocksizes(mainInput.getRowsInBlock(), mainInput.getColsInBlock());
copyLineNumbers(mainInput, bop);
bop.refreshSizeInformation();
-
return bop;
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java b/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java
index 558deb3..cea2c93 100644
--- a/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java
+++ b/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java
@@ -166,48 +166,21 @@ public class ConvolutionTransform extends Lop
}
}
- // Used by maxpool
- public String getInstructions(String input, String stride1, String stride2, String padding1, String padding2,
- String input_shape1, String input_shape2, String input_shape3, String input_shape4,
- String filter_shape1, String filter_shape2, String filter_shape3, String filter_shape4,
- String output) throws LopsException {
- StringBuilder sb = new StringBuilder();
- appendOpcode(sb);
- sb.append( getInputs().get(0).prepInputOperand(input));
- appendOperands(1, 13, output, sb);
- return sb.toString();
- }
-
- // Used by conv2d*, maxpool_bwd
- public String getInstructions(String input, String dout, String stride1, String stride2, String padding1, String padding2,
- String input_shape1, String input_shape2, String input_shape3, String input_shape4,
- String filter_shape1, String filter_shape2, String filter_shape3, String filter_shape4,
- String output) throws LopsException {
- StringBuilder sb = new StringBuilder();
- appendOpcode(sb);
- sb.append( getInputs().get(0).prepInputOperand(input));
- sb.append( OPERAND_DELIMITOR );
- sb.append( getInputs().get(1).prepInputOperand(dout));
- appendOperands(2, 14, output, sb);
- return sb.toString();
- }
-
- // Used by fused conv2d+bias_add
- public String getInstructions(String input, String bias, String filter, String stride1, String stride2, String padding1, String padding2,
- String input_shape1, String input_shape2, String input_shape3, String input_shape4,
- String filter_shape1, String filter_shape2, String filter_shape3, String filter_shape4,
- String output) throws LopsException {
+ /**
+ * Generates the instruction string for convolution-related operations with a
+ * variable number of leading data operands (e.g., input, dout, bias, filter),
+ * followed by exactly 12 parameters: stride (2), padding (2), input shape (4),
+ * and filter shape (4).
+ *
+ * @param inputs data operands followed by the 12 parameters
+ * @param output output operand
+ * @return instructions as string
+ * @throws LopsException if instruction generation fails
+ */
+ @Override
+ public String getInstructions(String[] inputs, String output) throws LopsException {
StringBuilder sb = new StringBuilder();
appendOpcode(sb);
- sb.append( getInputs().get(0).prepInputOperand(input));
- sb.append( OPERAND_DELIMITOR );
- sb.append( getInputs().get(1).prepInputOperand(bias));
- sb.append( OPERAND_DELIMITOR );
- sb.append( getInputs().get(2).prepInputOperand(filter));
- appendOperands(3, 15, output, sb);
+
+ //leading data operands (all but the trailing 12 parameters), delimiter-separated
+ for( int i=0; i<inputs.length-12; i++ ) {
+ if( i > 0 )
+ sb.append( OPERAND_DELIMITOR );
+ sb.append( getInputs().get(i).prepInputOperand(inputs[i]));
+ }
+ //the trailing 12 parameters and the output are appended by appendOperands
+ appendOperands(inputs.length-12, inputs.length, output, sb);
+
return sb.toString();
}
-
+
public void appendOpcode(StringBuilder sb) {
sb.append( getExecType() );
sb.append( OPERAND_DELIMITOR );
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/lops/Lop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/Lop.java b/src/main/java/org/apache/sysml/lops/Lop.java
index 567b0be..24f7ba3 100644
--- a/src/main/java/org/apache/sysml/lops/Lop.java
+++ b/src/main/java/org/apache/sysml/lops/Lop.java
@@ -59,6 +59,7 @@ public abstract class Lop
SortKeys, PickValues,
Checkpoint, //Spark persist into storage level
PlusMult, MinusMult, //CP
+ SpoofFused, //CP/SP generated fused operator
/** CP operation on a variable number of operands */
MULTIPLE_CP
};
@@ -418,6 +419,40 @@ public abstract class Lop
return outParams;
}
+
+	/**
+	 * Generates the instruction string of this lop from its output operand only.
+	 * Subclasses compiling to instructions of this form must override this method.
+	 *
+	 * @param output output operand
+	 * @return instructions as string
+	 * @throws LopsException always, as the base class provides no implementation
+	 */
+	public String getInstructions(String output) throws LopsException {
+		throw new LopsException(printErrorLocation() + "Should never be invoked in Baseclass");
+	}
+
+	/**
+	 * Generates the instruction string of this lop from one input operand.
+	 * Subclasses compiling to instructions of this form must override this method.
+	 *
+	 * @param input1 input 1
+	 * @param output output operand
+	 * @return instructions as string
+	 * @throws LopsException always, as the base class provides no implementation
+	 */
+	public String getInstructions(String input1, String output) throws LopsException {
+		throw new LopsException(printErrorLocation() + "Should never be invoked in Baseclass");
+	}
+
+	/**
+	 * Generates the instruction string of this lop from two input operands.
+	 * Subclasses compiling to instructions of this form must override this method.
+	 *
+	 * @param input1 input 1
+	 * @param input2 input 2
+	 * @param output output operand
+	 * @return instructions as string
+	 * @throws LopsException always, as the base class provides no implementation
+	 */
+	public String getInstructions(String input1, String input2, String output) throws LopsException {
+		throw new LopsException(printErrorLocation() + "Should never be invoked in Baseclass");
+	}
+
/**
* Method should be overridden if needed
*
@@ -478,6 +513,15 @@ public abstract class Lop
public String getInstructions(String input1, String input2, String input3, String input4, String input5, String input6, String output) throws LopsException {
throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass");
}
+
+	/** Method should be overridden if needed
+	 *
+	 * @param input1 input 1
+	 * @param input2 input 2
+	 * @param input3 input 3
+	 * @param input4 input 4
+	 * @param input5 input 5
+	 * @param input6 input 6
+	 * @param input7 input 7
+	 * @param output output
+	 * @return instructions as string
+	 * @throws LopsException always, as the base class provides no implementation
+	 */
+	public String getInstructions(String input1, String input2, String input3, String input4, String input5, String input6, String input7, String output) throws LopsException {
+		throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass");
+	}
+
+	/** Method should be overridden if needed, e.g., by lops with a variable number of inputs
+	 *
+	 * @param inputs input operands
+	 * @param output output (a single operand, consistent with all other overloads)
+	 * @return instructions as string
+	 * @throws LopsException always, as the base class provides no implementation
+	 */
+	public String getInstructions(String[] inputs, String output) throws LopsException {
+		throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass");
+	}
+
public String getInstructions(int output_index) throws LopsException {
throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass. Lop Type: " + this.getType());
@@ -541,38 +585,6 @@ public abstract class Lop
throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass");
}
- /** Method should be overridden if needed
- *
- * @param input1 input 1
- * @param input2 input 2
- * @param output output
- * @return instructions as string
- * @throws LopsException if LopsException occurs
- */
- public String getInstructions(String input1, String input2, String output) throws LopsException {
- throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass");
- }
-
- /** Method should be overridden if needed
- *
- * @param input1 input 1
- * @param output output
- * @return instructions as string
- * @throws LopsException if LopsException occurs
- */
- public String getInstructions(String input1, String output) throws LopsException {
- throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass");
- }
-
- /** Method should be overridden if needed
- *
- * @param output output
- * @return instructions as string
- * @throws LopsException if LopsException occurs
- */
- public String getInstructions(String output) throws LopsException {
- throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass");
- }
/** Method should be overridden if needed
*
@@ -630,37 +642,6 @@ public abstract class Lop
return "ERROR: line " + _beginLine + ", column " + _beginColumn + " -- ";
}
- //TODO: Leo This might get confused with Rand.getInstructions
- public String getInstructions(String input, String rowl, String rowu,
- String coll, String colu, String leftRowDim,
- String leftColDim, String output) throws LopsException {
- throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass");
- }
-
- // stride1, stride2, padding1, padding2
- // input_shape1, input_shape2, input_shape3, input_shape4,
- // filter_shape1, filter_shape2, filter_shape3, filter_shape4,
- public String getInstructions(String input, String stride1, String stride2, String padding1, String padding2,
- String input_shape1, String input_shape2, String input_shape3, String input_shape4,
- String filter_shape1, String filter_shape2, String filter_shape3, String filter_shape4,
- String output) throws LopsException {
- throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass");
- }
-
- public String getInstructions(String input, String dout, String stride1, String stride2, String padding1, String padding2,
- String input_shape1, String input_shape2, String input_shape3, String input_shape4,
- String filter_shape1, String filter_shape2, String filter_shape3, String filter_shape4,
- String output) throws LopsException {
- throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass");
- }
-
- public String getInstructions(String input, String bias, String dout, String stride1, String stride2, String padding1, String padding2,
- String input_shape1, String input_shape2, String input_shape3, String input_shape4,
- String filter_shape1, String filter_shape2, String filter_shape3, String filter_shape4,
- String output) throws LopsException {
- throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass");
- }
-
public String getInstructions(int input, int rowl, int rowu,
int coll, int colu, int leftRowDim,
int leftColDim, int output) throws LopsException {
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/lops/SpoofFused.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/SpoofFused.java b/src/main/java/org/apache/sysml/lops/SpoofFused.java
new file mode 100644
index 0000000..3f0ec59
--- /dev/null
+++ b/src/main/java/org/apache/sysml/lops/SpoofFused.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.lops;
+
+import java.util.ArrayList;
+
+import org.apache.sysml.lops.LopProperties.ExecLocation;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.lops.compile.JobType;
+import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.parser.Expression.ValueType;
+
+/**
+ * Lop of a fused operator implemented by a generated class. Its instruction
+ * consists of the exec type, the opcode "spoof", the generated class name,
+ * one operand per input, the output operand, and the number of threads,
+ * separated by OPERAND_DELIMITOR.
+ */
+public class SpoofFused extends Lop
+{
+	//generated operator class and number of threads written into the instruction
+	private final Class<?> _class;
+	private final int _numThreads;
+
+	public SpoofFused( ArrayList<Lop> inputs, DataType dt, ValueType vt, Class<?> cla, int k, ExecType etype) {
+		super(Type.SpoofFused, dt, vt);
+		_class = cla;
+		_numThreads = k;
+
+		//connect each input in both directions of the lop DAG
+		for( int i = 0; i < inputs.size(); i++ ) {
+			Lop in = inputs.get(i);
+			addInput(in);
+			in.addOutput(this);
+		}
+
+		lps.addCompatibility(JobType.INVALID);
+		lps.setProperties( inputs, etype, ExecLocation.ControlProgram, false, false, false );
+	}
+
+	@Override
+	public String toString() {
+		return "spoof(" + _class.getSimpleName() + ")";
+	}
+
+	@Override
+	public String getInstructions(String input1, String output) throws LopsException {
+		return getInstructions(asArray(input1), asArray(output));
+	}
+
+	@Override
+	public String getInstructions(String input1, String input2, String output) throws LopsException {
+		return getInstructions(asArray(input1, input2), asArray(output));
+	}
+
+	@Override
+	public String getInstructions(String input1, String input2, String input3, String output) throws LopsException {
+		return getInstructions(asArray(input1, input2, input3), asArray(output));
+	}
+
+	@Override
+	public String getInstructions(String input1, String input2, String input3, String input4, String output) throws LopsException {
+		return getInstructions(asArray(input1, input2, input3, input4), asArray(output));
+	}
+
+	@Override
+	public String getInstructions(String input1, String input2, String input3, String input4, String input5, String output) throws LopsException {
+		return getInstructions(asArray(input1, input2, input3, input4, input5), asArray(output));
+	}
+
+	@Override
+	public String getInstructions(String input1, String input2, String input3, String input4, String input5, String input6, String output) throws LopsException {
+		return getInstructions(asArray(input1, input2, input3, input4, input5, input6), asArray(output));
+	}
+
+	@Override
+	public String getInstructions(String input1, String input2, String input3, String input4, String input5, String input6, String input7, String output) throws LopsException {
+		return getInstructions(asArray(input1, input2, input3, input4, input5, input6, input7), asArray(output));
+	}
+
+	@Override
+	public String getInstructions(String[] inputs, String output) throws LopsException {
+		return getInstructions(inputs, asArray(output));
+	}
+
+	@Override
+	public String getInstructions(String[] inputs, String[] outputs)
+		throws LopsException
+	{
+		StringBuilder inst = new StringBuilder();
+		inst.append( getExecType() ).append( OPERAND_DELIMITOR )
+			.append( "spoof" ).append( OPERAND_DELIMITOR )
+			.append( _class.getName() );
+
+		//input operands, matched to this lop's inputs by position
+		int pos = 0;
+		for( String label : inputs ) {
+			Lop producer = getInputs().get(pos++);
+			inst.append( OPERAND_DELIMITOR ).append( producer.prepInputOperand(label) );
+		}
+
+		//only the first output is used, followed by the number of threads
+		inst.append( OPERAND_DELIMITOR ).append( prepOutputOperand(outputs[0]) );
+		inst.append( OPERAND_DELIMITOR ).append( _numThreads );
+		return inst.toString();
+	}
+
+	//wraps the given operand labels into an array (a new array per call)
+	private static String[] asArray(String... labels) {
+		return labels;
+	}
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/lops/compile/Dag.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/compile/Dag.java b/src/main/java/org/apache/sysml/lops/compile/Dag.java
index 898f4ec..b513951 100644
--- a/src/main/java/org/apache/sysml/lops/compile/Dag.java
+++ b/src/main/java/org/apache/sysml/lops/compile/Dag.java
@@ -1491,65 +1491,12 @@ public class Dag<N extends Lop>
node.getInputs().get(6).getOutputParameters().getLabel(),
node.getOutputParameters().getLabel());
}
- else if (node.getInputs().size() == 13) {
- // Used for im2col and reshape_col
- inst_string = node.getInstructions(
- node.getInputs().get(0).getOutputParameters().getLabel(),
- node.getInputs().get(1).getOutputParameters().getLabel(),
- node.getInputs().get(2).getOutputParameters().getLabel(),
- node.getInputs().get(3).getOutputParameters().getLabel(),
- node.getInputs().get(4).getOutputParameters().getLabel(),
- node.getInputs().get(5).getOutputParameters().getLabel(),
- node.getInputs().get(6).getOutputParameters().getLabel(),
- node.getInputs().get(7).getOutputParameters().getLabel(),
- node.getInputs().get(8).getOutputParameters().getLabel(),
- node.getInputs().get(9).getOutputParameters().getLabel(),
- node.getInputs().get(10).getOutputParameters().getLabel(),
- node.getInputs().get(11).getOutputParameters().getLabel(),
- node.getInputs().get(12).getOutputParameters().getLabel(),
- node.getOutputParameters().getLabel());
- }
- else if (node.getInputs().size() == 14) {
- // Used for pooling_backward
- inst_string = node.getInstructions(
- node.getInputs().get(0).getOutputParameters().getLabel(),
- node.getInputs().get(1).getOutputParameters().getLabel(),
- node.getInputs().get(2).getOutputParameters().getLabel(),
- node.getInputs().get(3).getOutputParameters().getLabel(),
- node.getInputs().get(4).getOutputParameters().getLabel(),
- node.getInputs().get(5).getOutputParameters().getLabel(),
- node.getInputs().get(6).getOutputParameters().getLabel(),
- node.getInputs().get(7).getOutputParameters().getLabel(),
- node.getInputs().get(8).getOutputParameters().getLabel(),
- node.getInputs().get(9).getOutputParameters().getLabel(),
- node.getInputs().get(10).getOutputParameters().getLabel(),
- node.getInputs().get(11).getOutputParameters().getLabel(),
- node.getInputs().get(12).getOutputParameters().getLabel(),
- node.getInputs().get(13).getOutputParameters().getLabel(),
- node.getOutputParameters().getLabel());
- }
- else if (node.getInputs().size() == 15) {
- // Used for fused conv2d_bias_add
- inst_string = node.getInstructions(
- node.getInputs().get(0).getOutputParameters().getLabel(),
- node.getInputs().get(1).getOutputParameters().getLabel(),
- node.getInputs().get(2).getOutputParameters().getLabel(),
- node.getInputs().get(3).getOutputParameters().getLabel(),
- node.getInputs().get(4).getOutputParameters().getLabel(),
- node.getInputs().get(5).getOutputParameters().getLabel(),
- node.getInputs().get(6).getOutputParameters().getLabel(),
- node.getInputs().get(7).getOutputParameters().getLabel(),
- node.getInputs().get(8).getOutputParameters().getLabel(),
- node.getInputs().get(9).getOutputParameters().getLabel(),
- node.getInputs().get(10).getOutputParameters().getLabel(),
- node.getInputs().get(11).getOutputParameters().getLabel(),
- node.getInputs().get(12).getOutputParameters().getLabel(),
- node.getInputs().get(13).getOutputParameters().getLabel(),
- node.getInputs().get(14).getOutputParameters().getLabel(),
- node.getOutputParameters().getLabel());
- }
else {
- throw new LopsException(node.printErrorLocation() + "Node with " + node.getInputs().size() + " inputs is not supported in CP yet! \n");
+ String[] inputs = new String[node.getInputs().size()];
+ for( int j=0; j<node.getInputs().size(); j++ )
+ inputs[j] = node.getInputs().get(j).getOutputParameters().getLabel();
+ inst_string = node.getInstructions(inputs,
+ node.getOutputParameters().getLabel());
}
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/runtime/codegen/ByteClassLoader.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/codegen/ByteClassLoader.java b/src/main/java/org/apache/sysml/runtime/codegen/ByteClassLoader.java
new file mode 100644
index 0000000..27263d3
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/codegen/ByteClassLoader.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.codegen;
+
+import java.net.URL;
+import java.net.URLClassLoader;
+
+/**
+ * Class loader that defines a class from in-memory class file bytes (as used
+ * by CodegenUtils.loadClass) and otherwise searches the given URLs.
+ */
+public class ByteClassLoader extends URLClassLoader
+{
+	//class file bytes to define, or null to only search the URLs
+	private final byte[] _classBytes;
+
+	/**
+	 * @param urls URLs to search for classes if no class bytes are given
+	 * @param parent parent class loader for delegation
+	 * @param classBytes class file bytes to define, or null
+	 */
+	public ByteClassLoader(URL[] urls, ClassLoader parent, byte[] classBytes) {
+		super(urls, parent);
+		_classBytes = classBytes;
+	}
+
+	/**
+	 * Defines the class from the given bytes, or searches the URLs if there
+	 * are no bytes. A class already defined by this loader is returned as is,
+	 * because defining the same name twice throws a LinkageError.
+	 *
+	 * @param className binary name of the class
+	 * @return the defined or found class
+	 * @throws ClassNotFoundException if no bytes are given and the URLs do not contain the class
+	 */
+	@Override
+	public Class<?> findClass(String className) throws ClassNotFoundException {
+		Class<?> ret = findLoadedClass(className);
+		if( ret != null )
+			return ret;
+		if( _classBytes != null )
+			return defineClass(className, _classBytes, 0, _classBytes.length);
+		//search only this loader's URLs; calling loadClass here would delegate back
+		//into findClass for classes the parent cannot find and recurse endlessly
+		return super.findClass(className);
+	}
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/runtime/codegen/CodegenUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/codegen/CodegenUtils.java b/src/main/java/org/apache/sysml/runtime/codegen/CodegenUtils.java
new file mode 100644
index 0000000..fdad9bd
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/codegen/CodegenUtils.java
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.codegen;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.URL;
+import java.net.URLClassLoader;
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.ConcurrentHashMap;
+
+import javax.tools.Diagnostic;
+import javax.tools.Diagnostic.Kind;
+import javax.tools.DiagnosticCollector;
+import javax.tools.JavaCompiler;
+import javax.tools.JavaCompiler.CompilationTask;
+import javax.tools.JavaFileObject;
+import javax.tools.StandardJavaFileManager;
+import javax.tools.ToolProvider;
+
+import org.apache.commons.io.IOUtils;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.io.IOUtilFunctions;
+import org.apache.sysml.runtime.util.LocalFileUtils;
+import org.apache.sysml.utils.Statistics;
+
+/**
+ * Utilities to compile and load generated operator classes and to obtain
+ * their class file bytes, with a process-wide cache of loaded classes.
+ */
+public class CodegenUtils
+{
+	//cache to reuse compiled and loaded classes (this is also a workaround for classes
+	//compiled during initial compilation and subsequently loaded, as the working directory
+	//is cleaned up just before the actual execution)
+	private static final ConcurrentHashMap<String, Class<?>> _cache = new ConcurrentHashMap<String,Class<?>>();
+	private static String _workingDir = null;
+
+	/**
+	 * Compiles the given source code and loads the class {@code "codegen."+name}.
+	 * A class already cached under {@code name} is returned without recompilation.
+	 *
+	 * @param name class name without the {@code codegen} package, also used as cache key
+	 * @param src java source code of the class
+	 * @return the compiled and loaded class
+	 * @throws DMLRuntimeException if writing, compiling, or loading the class fails
+	 */
+	public static Class<?> compileClass(String name, String src)
+		throws DMLRuntimeException
+	{
+		//reuse existing compiled class
+		Class<?> ret = _cache.get(name);
+		if( ret != null )
+			return ret;
+
+		long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
+
+		StandardJavaFileManager fileManager = null;
+		URLClassLoader classLoader = null;
+		try
+		{
+			//create working dir on demand
+			if( _workingDir == null )
+				createWorkingDir();
+
+			//write input file (for debugging / classpath handling)
+			File ftmp = new File(_workingDir+"/codegen/"+name+".java");
+			if( !ftmp.getParentFile().exists() )
+				ftmp.getParentFile().mkdirs();
+			LocalFileUtils.writeTextFile(ftmp, src);
+
+			//get system java compiler (null if no compiler is available, e.g., on a plain JRE)
+			JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
+			if( compiler == null )
+				throw new RuntimeException("Unable to obtain system java compiler.");
+
+			//prepare file manager (closed in finally to release its resources)
+			DiagnosticCollector<JavaFileObject> diagnostics = new DiagnosticCollector<JavaFileObject>();
+			fileManager = compiler.getStandardFileManager(diagnostics, null, null);
+
+			//prepare input source code
+			Iterable<? extends JavaFileObject> sources = fileManager
+				.getJavaFileObjectsFromFiles(Arrays.asList(ftmp));
+
+			//prepare class path
+			URL runDir = getRunDir();
+			String classpath = System.getProperty("java.class.path") +
+				File.pathSeparator + runDir.getPath();
+			List<String> options = Arrays.asList("-classpath",classpath);
+
+			//compile source code
+			CompilationTask task = compiler.getTask(null, fileManager, diagnostics, options, null, sources);
+			Boolean success = task.call();
+
+			//output diagnostics and error handling
+			for(Diagnostic<? extends JavaFileObject> tmp : diagnostics.getDiagnostics())
+				if( tmp.getKind()==Kind.ERROR )
+					System.err.println("ERROR: "+tmp.toString());
+			if( success == null || !success )
+				throw new RuntimeException("Failed to compile class "+name);
+
+			//dynamically load compiled class
+			classLoader = createClassLoader();
+			ret = classLoader.loadClass("codegen."+name);
+		}
+		catch(DMLRuntimeException ex) {
+			throw ex;
+		}
+		catch(Exception ex) {
+			throw new DMLRuntimeException(ex);
+		}
+		finally {
+			//close also on failures to avoid leaking the loader and file manager
+			IOUtilFunctions.closeSilently(classLoader);
+			IOUtilFunctions.closeSilently(fileManager);
+		}
+
+		//keep compiled class for reuse
+		_cache.put(name, ret);
+
+		if( DMLScript.STATISTICS ) {
+			Statistics.incrementCodegenClassCompile();
+			Statistics.incrementCodegenClassCompileTime(System.nanoTime()-t0);
+		}
+
+		return ret;
+	}
+
+	/**
+	 * Loads the class of the given name, defined from the given class file bytes
+	 * or, if no bytes are given, from {@link #getUrls()}. A class already cached
+	 * under {@code name} is returned directly.
+	 *
+	 * @param name binary class name, also used as cache key
+	 * @param classBytes class file bytes, or null to load from {@link #getUrls()}
+	 * @return the loaded class
+	 * @throws DMLRuntimeException if the class cannot be defined or found
+	 */
+	public static Class<?> loadClass(String name, byte[] classBytes) throws DMLRuntimeException {
+		//reuse existing compiled class
+		Class<?> ret = _cache.get(name);
+		if( ret != null )
+			return ret;
+
+		URLClassLoader loader = null;
+		try {
+			if( classBytes != null ) {
+				//define class using the bytes
+				ByteClassLoader byteLoader = new ByteClassLoader(
+					new URL[]{}, CodegenUtils.class.getClassLoader(), classBytes);
+				loader = byteLoader;
+				ret = byteLoader.findClass(name);
+			}
+			else {
+				//dynamically load compiled class
+				loader = createClassLoader();
+				ret = loader.loadClass(name);
+			}
+		}
+		catch(DMLRuntimeException ex) {
+			throw ex;
+		}
+		catch(Exception ex) {
+			throw new DMLRuntimeException(ex);
+		}
+		finally {
+			//close also on failures to avoid leaking the loader
+			IOUtilFunctions.closeSilently(loader);
+		}
+
+		//keep loaded class for reuse
+		_cache.put(name, ret);
+		return ret;
+	}
+
+	/**
+	 * Creates a new instance of the given class via its no-argument constructor.
+	 *
+	 * @param cla class to instantiate
+	 * @return the new instance
+	 * @throws DMLRuntimeException if the instance cannot be created
+	 */
+	public static Object createInstance(Class<?> cla)
+		throws DMLRuntimeException
+	{
+		try {
+			return cla.newInstance();
+		}
+		catch( Exception ex ) {
+			throw new DMLRuntimeException(ex);
+		}
+	}
+
+	/**
+	 * Returns the class file bytes of the class cached under the given name or, if not
+	 * cached, reads the class file of binary name {@code name} via {@link #getUrls()}.
+	 *
+	 * @param name cache key or binary class name
+	 * @return class file bytes
+	 * @throws DMLRuntimeException if the class file cannot be found or read
+	 */
+	public static byte[] getClassAsByteArray(String name)
+		throws DMLRuntimeException
+	{
+		//reuse existing compiled class
+		Class<?> cls = _cache.get(name);
+		if( cls != null )
+			return getClassAsByteArray(cls);
+		return readClassFile(name);
+	}
+
+	/**
+	 * Returns the class file bytes (bytecode) of the given class, as accepted by
+	 * {@link #loadClass(String, byte[])}. Serializing the Class object with an
+	 * ObjectOutputStream is not an alternative: it writes only a class descriptor.
+	 *
+	 * @param cls class whose class file is read
+	 * @return class file bytes
+	 * @throws DMLRuntimeException if the class file cannot be found or read
+	 */
+	public static byte[] getClassAsByteArray(Class<?> cls)
+		throws DMLRuntimeException
+	{
+		return readClassFile(cls.getName());
+	}
+
+	//reads the class file of the given binary class name via a loader over getUrls()
+	private static byte[] readClassFile(String className)
+		throws DMLRuntimeException
+	{
+		String classAsPath = className.replace('.', '/') + ".class";
+		URLClassLoader classLoader = null;
+		InputStream stream = null;
+		try {
+			classLoader = createClassLoader();
+			stream = classLoader.getResourceAsStream(classAsPath);
+			if( stream == null )
+				throw new DMLRuntimeException("Class file not found: "+classAsPath);
+			return IOUtils.toByteArray(stream);
+		}
+		catch(IOException e) {
+			throw new DMLRuntimeException(e);
+		}
+		finally {
+			IOUtilFunctions.closeSilently(stream);
+			IOUtilFunctions.closeSilently(classLoader);
+		}
+	}
+
+	//creates the codegen working directory once and remembers its path
+	private static void createWorkingDir() throws DMLRuntimeException {
+		if( _workingDir != null )
+			return;
+		String tmp = LocalFileUtils.getWorkingDir(LocalFileUtils.CATEGORY_CODEGEN);
+		LocalFileUtils.createLocalFileIfNotExist(tmp);
+		_workingDir = tmp;
+	}
+
+	/**
+	 * Returns the URLs used to load generated classes: the codegen working
+	 * directory and the code location of this class.
+	 *
+	 * @return working directory URL and code location URL
+	 * @throws DMLRuntimeException if the URLs cannot be created (e.g., no working directory yet)
+	 */
+	public static URL[] getUrls() throws DMLRuntimeException {
+		try {
+			URL runDir = getRunDir();
+			return new URL[]{new File(_workingDir).toURI().toURL(), runDir};
+		}
+		catch(Exception e) {
+			throw new DMLRuntimeException(e);
+		}
+	}
+
+	//code location of this class (from its protection domain's code source)
+	private static URL getRunDir() {
+		return CodegenUtils.class.getProtectionDomain().getCodeSource().getLocation();
+	}
+
+	//class loader over getUrls() with this class's loader as parent; must be closed by the caller
+	private static URLClassLoader createClassLoader() throws DMLRuntimeException {
+		return new URLClassLoader(getUrls(), CodegenUtils.class.getClassLoader());
+	}
+
+	/**
+	 * Returns a short type tag of a generated class: a template prefix
+	 * ({@code Cell}, {@code OP}, {@code RA}) plus the class name without package,
+	 * or {@code UNKNOWN} if the superclass is none of these templates.
+	 *
+	 * @param cls generated class
+	 * @return type tag
+	 */
+	public static String getSpoofType(Class<?> cls) {
+		//class name without package (also for names with zero or several dots)
+		String simpleName = cls.getName().substring(cls.getName().lastIndexOf('.')+1);
+		if(cls.getSuperclass() == SpoofCellwise.class)
+			return "Cell" + simpleName;
+		else if(cls.getSuperclass() == SpoofOuterProduct.class)
+			return "OP" + simpleName;
+		else if(cls.getSuperclass() == SpoofRowAggregate.class)
+			return "RA" + simpleName;
+		else
+			return "UNKNOWN";
+	}
+}