You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/02/27 18:36:05 UTC
[3/9] incubator-systemml git commit: [SYSTEMML-1285] New basic code
generator for operator fusion
[SYSTEMML-1285] New basic code generator for operator fusion
This patch introduces a cleaned-up version of SPOOF's basic code
generator, covering its core compiler and runtime operators as well as
its basic integration into the stats and explain tools (SYSTEMML-1296
and SYSTEMML-1297).
Furthermore, this also includes the following minor fixes and
improvements of existing components:
* Fix of rewrite utils for creating binary scalar operations with
boolean outputs
* Cleanup of instruction generation for the convolution lop
* Fix lop dag compilation (removed constraint of max 7 input lops)
* Improved value type handling of scalar comparison instructions
* Fix various gpu-related src and javadoc warnings
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/d7fd5879
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/d7fd5879
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/d7fd5879
Branch: refs/heads/master
Commit: d7fd58795c06dea8db6fb55a045a8b312547f398
Parents: b78c125
Author: Matthias Boehm <mb...@gmail.com>
Authored: Sun Feb 26 18:53:46 2017 -0800
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Sun Feb 26 18:53:46 2017 -0800
----------------------------------------------------------------------
src/main/java/org/apache/sysml/hops/Hop.java | 5 +
.../sysml/hops/codegen/SpoofCompiler.java | 407 ++++++++++++++
.../apache/sysml/hops/codegen/SpoofFusedOp.java | 212 ++++++++
.../apache/sysml/hops/codegen/cplan/CNode.java | 167 ++++++
.../sysml/hops/codegen/cplan/CNodeBinary.java | 260 +++++++++
.../sysml/hops/codegen/cplan/CNodeCell.java | 144 +++++
.../sysml/hops/codegen/cplan/CNodeData.java | 94 ++++
.../hops/codegen/cplan/CNodeOuterProduct.java | 165 ++++++
.../hops/codegen/cplan/CNodeRowAggVector.java | 111 ++++
.../sysml/hops/codegen/cplan/CNodeTpl.java | 201 +++++++
.../sysml/hops/codegen/cplan/CNodeUnary.java | 206 +++++++
.../sysml/hops/codegen/template/BaseTpl.java | 63 +++
.../sysml/hops/codegen/template/CellTpl.java | 289 ++++++++++
.../hops/codegen/template/CplanRegister.java | 168 ++++++
.../hops/codegen/template/OuterProductTpl.java | 489 +++++++++++++++++
.../sysml/hops/codegen/template/RowAggTpl.java | 321 +++++++++++
.../hops/codegen/template/TemplateUtils.java | 313 +++++++++++
.../sysml/hops/rewrite/HopRewriteUtils.java | 10 +-
.../apache/sysml/lops/ConvolutionTransform.java | 49 +-
src/main/java/org/apache/sysml/lops/Lop.java | 107 ++--
.../java/org/apache/sysml/lops/SpoofFused.java | 119 ++++
.../java/org/apache/sysml/lops/compile/Dag.java | 63 +--
.../sysml/runtime/codegen/ByteClassLoader.java | 40 ++
.../sysml/runtime/codegen/CodegenUtils.java | 268 +++++++++
.../runtime/codegen/LibSpoofPrimitives.java | 257 +++++++++
.../sysml/runtime/codegen/SpoofCellwise.java | 430 +++++++++++++++
.../sysml/runtime/codegen/SpoofOperator.java | 74 +++
.../runtime/codegen/SpoofOuterProduct.java | 541 +++++++++++++++++++
.../runtime/codegen/SpoofRowAggregate.java | 188 +++++++
.../controlprogram/parfor/util/IDSequence.java | 21 +-
.../cp/RelationalBinaryCPInstruction.java | 52 +-
.../cp/ScalarScalarRelationalCPInstruction.java | 22 +-
.../instructions/gpu/context/GPUContext.java | 2 +
.../instructions/gpu/context/GPUObject.java | 1 +
.../instructions/gpu/context/JCudaObject.java | 2 +
.../runtime/matrix/data/LibMatrixMult.java | 18 +-
.../sysml/runtime/util/LocalFileUtils.java | 24 +
.../java/org/apache/sysml/utils/Statistics.java | 73 +++
38 files changed, 5742 insertions(+), 234 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java
index 3aa3dab..4021a1a 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -789,6 +789,11 @@ public abstract class Hop
public ArrayList<Hop> getInput() {
return _input;
}
+
+ /**
+ * Appends the given hop to the inputs of this hop and registers this
+ * hop in the parent list of the given hop, so both link directions
+ * are updated together.
+ *
+ * @param h hop to add as input of this hop
+ */
+ public void addInput( Hop h ) {
+ _input.add(h);
+ h._parent.add(this);
+ }
public long getRowsInBlock() {
return _rows_in_block;
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
new file mode 100644
index 0000000..dd24703
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.Map.Entry;
+import java.util.concurrent.ConcurrentHashMap;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysml.api.DMLException;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.hops.codegen.cplan.CNode;
+import org.apache.sysml.hops.codegen.cplan.CNodeCell;
+import org.apache.sysml.hops.codegen.cplan.CNodeData;
+import org.apache.sysml.hops.codegen.cplan.CNodeOuterProduct;
+import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
+import org.apache.sysml.hops.codegen.cplan.CNodeUnary;
+import org.apache.sysml.hops.codegen.template.BaseTpl;
+import org.apache.sysml.hops.codegen.template.CellTpl;
+import org.apache.sysml.hops.codegen.template.CplanRegister;
+import org.apache.sysml.hops.codegen.template.OuterProductTpl;
+import org.apache.sysml.hops.codegen.template.RowAggTpl;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.HopsException;
+import org.apache.sysml.hops.rewrite.HopRewriteUtils;
+import org.apache.sysml.parser.DMLProgram;
+import org.apache.sysml.parser.ForStatement;
+import org.apache.sysml.parser.ForStatementBlock;
+import org.apache.sysml.parser.FunctionStatement;
+import org.apache.sysml.parser.FunctionStatementBlock;
+import org.apache.sysml.parser.IfStatement;
+import org.apache.sysml.parser.IfStatementBlock;
+import org.apache.sysml.parser.LanguageException;
+import org.apache.sysml.parser.StatementBlock;
+import org.apache.sysml.parser.WhileStatement;
+import org.apache.sysml.parser.WhileStatementBlock;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.codegen.CodegenUtils;
+import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType;
+import org.apache.sysml.runtime.matrix.data.Pair;
+import org.apache.sysml.utils.Explain;
+import org.apache.sysml.utils.Explain.ExplainType;
+import org.apache.sysml.utils.Statistics;
+
+public class SpoofCompiler
+{
+ private static final Log LOG = LogFactory.getLog(SpoofCompiler.class.getName());
+
+ public static boolean OPTIMIZE = true;
+
+ //internal configuration flags
+ public static final boolean LDEBUG = false;
+ public static final boolean SUM_PRODUCT = false;
+ public static final boolean RECOMPILE = true;
+ public static boolean USE_PLAN_CACHE = true;
+ public static boolean ALWAYS_COMPILE_LITERALS = false;
+ public static final boolean ALLOW_SPARK_OPS = false;
+
+ //plan cache for cplan->compiled source to avoid unnecessary codegen/source code compile
+ //for equal operators from (1) different hop dags and (2) repeated recompilation
+ private static ConcurrentHashMap<CNode, Class<?>> planCache = new ConcurrentHashMap<CNode, Class<?>>();
+
+ public static void generateCode(DMLProgram dmlp)
+ throws LanguageException, HopsException, DMLRuntimeException
+ {
+ // cleanup static plan cache
+ planCache.clear();
+
+ // for each namespace, handle function statement blocks
+ for (String namespaceKey : dmlp.getNamespaces().keySet()) {
+ for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet()) {
+ FunctionStatementBlock fsblock = dmlp.getFunctionStatementBlock(namespaceKey,fname);
+ generateCodeFromStatementBlock(fsblock);
+ }
+ }
+
+ // handle regular statement blocks in "main" method
+ for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) {
+ StatementBlock current = dmlp.getStatementBlock(i);
+ generateCodeFromStatementBlock(current);
+ }
+ }
+
+ public static void generateCodeFromStatementBlock(StatementBlock current)
+ throws HopsException, DMLRuntimeException
+ {
+ if (current instanceof FunctionStatementBlock)
+ {
+ FunctionStatementBlock fsb = (FunctionStatementBlock)current;
+ FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0);
+ for (StatementBlock sb : fstmt.getBody())
+ generateCodeFromStatementBlock(sb);
+ }
+ else if (current instanceof WhileStatementBlock)
+ {
+ WhileStatementBlock wsb = (WhileStatementBlock) current;
+ WhileStatement wstmt = (WhileStatement)wsb.getStatement(0);
+ wsb.setPredicateHops(optimize(wsb.getPredicateHops(), true));
+ for (StatementBlock sb : wstmt.getBody())
+ generateCodeFromStatementBlock(sb);
+ }
+ else if (current instanceof IfStatementBlock)
+ {
+ IfStatementBlock isb = (IfStatementBlock) current;
+ IfStatement istmt = (IfStatement)isb.getStatement(0);
+ isb.setPredicateHops(optimize(isb.getPredicateHops(), true));
+ for (StatementBlock sb : istmt.getIfBody())
+ generateCodeFromStatementBlock(sb);
+ for (StatementBlock sb : istmt.getElseBody())
+ generateCodeFromStatementBlock(sb);
+ }
+ else if (current instanceof ForStatementBlock) //incl parfor
+ {
+ ForStatementBlock fsb = (ForStatementBlock) current;
+ ForStatement fstmt = (ForStatement)fsb.getStatement(0);
+ fsb.setFromHops(optimize(fsb.getFromHops(), true));
+ fsb.setToHops(optimize(fsb.getToHops(), true));
+ fsb.setIncrementHops(optimize(fsb.getIncrementHops(), true));
+ for (StatementBlock sb : fstmt.getBody())
+ generateCodeFromStatementBlock(sb);
+ }
+ else //generic (last-level)
+ {
+ current.set_hops( generateCodeFromHopDAGs(current.get_hops()) );
+ current.updateRecompilationFlag();
+ }
+ }
+
+ /**
+  * Optimizes the given hop dags with literals compiled as constants and
+  * resets the visit status of both the original and the returned roots.
+  *
+  * @param roots dag root nodes, may be null
+  * @return optimized dag root nodes, or null if the input was null
+  * @throws HopsException declared for callers; not thrown by the visible code
+  * @throws DMLRuntimeException if optimization failed
+  */
+ public static ArrayList<Hop> generateCodeFromHopDAGs(ArrayList<Hop> roots)
+     throws HopsException, DMLRuntimeException
+ {
+     ArrayList<Hop> result = roots;
+     if( roots != null ) {
+         result = SpoofCompiler.optimize(roots, true);
+         Hop.resetVisitStatus(roots);
+         Hop.resetVisitStatus(result);
+     }
+     return result;
+ }
+
+
+ /**
+  * Main interface of sum-product optimizer for a single predicate dag.
+  * The root is wrapped into a single-element list, passed to the list-based
+  * optimizer, and the first element of the result is returned.
+  *
+  * @param root dag root node, may be null
+  * @param compileLiterals if true literals compiled as constants, otherwise as scalar variables
+  * @return dag root node of modified dag, or null if the input was null
+  * @throws DMLRuntimeException if optimization failed
+  */
+ public static Hop optimize( Hop root, boolean compileLiterals ) throws DMLRuntimeException {
+     if( root == null )
+         return null;
+
+     ArrayList<Hop> single = new ArrayList<Hop>(Arrays.asList(root));
+     ArrayList<Hop> optimized = optimize(single, compileLiterals);
+     return optimized.get(0);
+ }
+
+ /**
+ * Main interface of sum-product optimizer, statement block dag.
+ *
+ * The input is returned unchanged if it is null or empty, or if the
+ * OPTIMIZE flag is disabled. Otherwise, codegen plans are constructed and
+ * cleaned up, each plan is compiled into a class (or its class is taken
+ * from the plan cache), and the hop dag is modified to use the generated
+ * operators. Any exception raised during these steps is wrapped into a
+ * DMLRuntimeException.
+ *
+ * @param roots dag root nodes
+ * @param compileLiterals if true literals compiled as constants, otherwise as scalar variables
+ * @return dag root nodes of modified dag
+ * @throws DMLRuntimeException if optimization failed
+ */
+ //NOTE(review): the warning suppression is presumably needed because
+ //LDEBUG is a compile-time constant, which makes some branches dead code
+ @SuppressWarnings("unused")
+ public static ArrayList<Hop> optimize(ArrayList<Hop> roots, boolean compileLiterals)
+ throws DMLRuntimeException
+ {
+ if( roots == null || roots.isEmpty() || !OPTIMIZE )
+ return roots;
+
+ //timing is only taken if statistics are enabled
+ long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
+ ArrayList<Hop> ret = roots;
+
+ try
+ {
+ //construct codegen plans
+ HashMap<Long, Pair<Hop[],CNodeTpl>> cplans = constructCPlans(roots, compileLiterals);
+
+ //cleanup codegen plans (remove unnecessary inputs, fix hop-cnodedata mapping,
+ //remove empty templates with single cnodedata input)
+ cplans = cleanupCPlans(cplans);
+
+ //explain before modification
+ if( LDEBUG && cplans.size() > 0 ) { //existing cplans
+ LOG.info("Codegen EXPLAIN (before optimize): \n"+Explain.explainHops(roots));
+ }
+
+ //source code generation for all cplans
+ HashMap<Long, Pair<Hop[],Class<?>>> clas = new HashMap<Long, Pair<Hop[],Class<?>>>();
+ for( Entry<Long, Pair<Hop[],CNodeTpl>> cplan : cplans.entrySet() ) {
+ Pair<Hop[],CNodeTpl> tmp = cplan.getValue();
+
+ //compile if the plan cache is disabled or has no entry for an equal cplan
+ if( !USE_PLAN_CACHE || !planCache.containsKey(tmp.getValue()) ) {
+ //generate java source code
+ String src = tmp.getValue().codegen(false);
+
+ //explain debug output generated source code
+ if( LDEBUG || DMLScript.EXPLAIN != ExplainType.NONE ) {
+ LOG.info("Codegen EXPLAIN (generated code for HopID: " + cplan.getKey() +"):");
+ LOG.info(src);
+ }
+
+ //compile generated java source code
+ Class<?> cla = CodegenUtils.compileClass(tmp.getValue().getClassname(), src);
+ planCache.put(tmp.getValue(), cla);
+ }
+ else if( LDEBUG || DMLScript.STATISTICS ) {
+ Statistics.incrementCodegenPlanCacheHits();
+ }
+
+ //the compiled class is always read back from the plan cache; plans
+ //without a cached class are not registered and hence not replaced
+ Class<?> cla = planCache.get(tmp.getValue());
+ if(cla != null)
+ clas.put(cplan.getKey(), new Pair<Hop[],Class<?>>(tmp.getKey(),cla));
+
+ if( LDEBUG || DMLScript.STATISTICS )
+ Statistics.incrementCodegenPlanCacheTotal();
+ }
+
+ //generate final hop dag
+ ret = constructModifiedHopDag(roots, cplans, clas);
+
+ //explain after modification
+ if( LDEBUG && cplans.size() > 0 ) { //existing cplans
+ LOG.info("Codegen EXPLAIN (after optimize): \n"+Explain.explainHops(roots));
+ }
+ }
+ catch( Exception ex ) {
+ //wrap all failures of plan construction, compilation, and dag modification
+ throw new DMLRuntimeException(ex);
+ }
+
+ if( DMLScript.STATISTICS ) {
+ Statistics.incrementCodegenDAGCompile();
+ Statistics.incrementCodegenCompileTime(System.nanoTime()-t0);
+ }
+
+ return ret;
+ }
+
+
+ ////////////////////
+ // Codegen plan construction
+
+ /**
+  * Constructs the codegen plans for all given dag roots. Each root is
+  * processed with its own plan register and memo; the top-level plans of
+  * every root are collected in insertion order, where the first plan
+  * registered for a hop ID wins.
+  *
+  * @param roots dag root nodes
+  * @param compileLiterals if true literals compiled as constants, otherwise as scalar variables
+  * @return map from hop ID to pair of input hops and codegen plan
+  * @throws DMLException if plan construction fails
+  */
+ private static HashMap<Long, Pair<Hop[],CNodeTpl>> constructCPlans(ArrayList<Hop> roots, boolean compileLiterals) throws DMLException
+ {
+     LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> allPlans = new LinkedHashMap<Long, Pair<Hop[],CNodeTpl>>();
+     for( Hop root : roots ) {
+         //fresh register and memo per root
+         CplanRegister register = new CplanRegister();
+         HashSet<Long> visited = new HashSet<Long>();
+         rConstructCPlans(root, register, visited, compileLiterals);
+
+         //merge top-level plans without overwriting existing entries
+         for( Entry<Long, Pair<Hop[],CNodeTpl>> plan : register.getTopLevelCplans().entrySet() ) {
+             if( allPlans.containsKey(plan.getKey()) )
+                 continue;
+             allPlans.put(plan.getKey(), plan.getValue());
+         }
+     }
+     return allPlans;
+ }
+
+ /**
+  * Recursively tries to open all templates at the given hop and registers
+  * the resulting codegen plans, then descends into the inputs of the hop.
+  * Hops whose IDs are already in the memo are skipped.
+  *
+  * @param hop current hop
+  * @param cplanReg register collecting the constructed plans
+  * @param memo IDs of already processed hops
+  * @param compileLiterals if true literals compiled as constants, otherwise as scalar variables
+  * @throws DMLException if plan construction fails
+  */
+ private static void rConstructCPlans(Hop hop, CplanRegister cplanReg, HashSet<Long> memo, boolean compileLiterals) throws DMLException
+ {
+     //already processed
+     if( memo.contains(hop.getHopID()) )
+         return;
+
+     //fresh template instances for this hop
+     BaseTpl[] candidates = new BaseTpl[]{
+         new RowAggTpl(), new CellTpl(), new OuterProductTpl()};
+
+     //try every template and register plans of those that apply
+     for( BaseTpl candidate : candidates ) {
+         if( !candidate.openTpl(hop) )
+             continue;
+         if( !candidate.findTplBoundaries(hop,cplanReg) )
+             continue;
+         cplanReg.insertCpplans(candidate.getType(),
+             candidate.constructTplCplan(compileLiterals));
+     }
+
+     //mark as processed and descend into the inputs
+     memo.add(hop.getHopID());
+     for( Hop child : hop.getInput() )
+         rConstructCPlans(child, cplanReg, memo, compileLiterals);
+ }
+
+ ////////////////////
+ // Codegen hop dag construction
+
+ /**
+  * Replaces sub-dags covered by compiled codegen plans with generated
+  * operators, starting from each of the given roots. A single memo is
+  * shared across all roots.
+  *
+  * @param orig dag root nodes
+  * @param cplans map from hop ID to pair of input hops and codegen plan
+  * @param cla map from hop ID to pair of input hops and compiled class
+  * @return the given list of root nodes
+  */
+ private static ArrayList<Hop> constructModifiedHopDag(ArrayList<Hop> orig,
+     HashMap<Long, Pair<Hop[],CNodeTpl>> cplans, HashMap<Long, Pair<Hop[],Class<?>>> cla)
+ {
+     HashSet<Long> visited = new HashSet<Long>();
+     //index-based access w/o iterator because modified
+     int pos = 0;
+     while( pos < orig.size() ) {
+         rConstructModifiedHopDag(orig.get(pos), cplans, cla, visited);
+         pos++;
+     }
+     return orig;
+ }
+
+ /**
+ * Recursively replaces hops for which a compiled class exists with a
+ * generated fused operator and then continues with the inputs of the
+ * (potentially new) hop.
+ *
+ * @param hop current hop
+ * @param cplans map from hop ID to pair of input hops and codegen plan
+ * @param clas map from hop ID to pair of input hops and compiled class
+ * @param memo IDs of already processed hops
+ */
+ private static void rConstructModifiedHopDag(Hop hop, HashMap<Long, Pair<Hop[],CNodeTpl>> cplans,
+ HashMap<Long, Pair<Hop[],Class<?>>> clas, HashSet<Long> memo)
+ {
+ if( memo.contains(hop.getHopID()) )
+ return; //already processed
+
+ Hop hnew = hop;
+ if( clas.containsKey(hop.getHopID()) )
+ {
+ //replace sub-dag with generated operator
+ Pair<Hop[], Class<?>> tmpCla = clas.get(hop.getHopID());
+ CNodeTpl tmpCNode = cplans.get(hop.getHopID()).getValue();
+ //the fused operator takes over name, data type, and value type of the replaced hop
+ hnew = new SpoofFusedOp(hop.getName(), hop.getDataType(), hop.getValueType(),
+ tmpCla.getValue(), false, tmpCNode.getOutputDimType());
+ for( Hop in : tmpCla.getKey() ) {
+ hnew.addInput(in); //add inputs
+ }
+ //block sizes and dimensions are copied from the replaced hop
+ hnew.setOutputBlocksizes(hop.getRowsInBlock() , hop.getColsInBlock());
+ hnew.setDim1(hop.getDim1());
+ hnew.setDim2(hop.getDim2());
+ //outer product plans with transposed output get an additional transpose on top
+ if(tmpCNode instanceof CNodeOuterProduct && ((CNodeOuterProduct)tmpCNode).isTransposeOutput() ) {
+ hnew = HopRewriteUtils.createTranspose(hnew);
+ }
+
+ //redirect the references between the replaced hop and its parents to the new hop
+ HopRewriteUtils.rewireAllParentChildReferences(hop, hnew);
+ memo.add(hnew.getHopID());
+ }
+
+ //process hops recursively (parent-child links modified)
+ for( int i=0; i<hnew.getInput().size(); i++ ) {
+ Hop c = hnew.getInput().get(i);
+ rConstructModifiedHopDag(c, cplans, clas, memo);
+ }
+ memo.add(hnew.getHopID());
+ }
+
+ /**
+  * Cleanup generated cplans in order to remove unnecessary inputs created
+  * during incremental construction. This is important as it avoids unnecessary
+  * redundant computation. Plans that consist only of a single unary operation
+  * over a data node without aggregation, as well as empty plans whose output
+  * is a data node, are dropped.
+  *
+  * @param cplans set of cplans
+  * @return new map of cleaned cplans, without the dropped plans
+  */
+ private static HashMap<Long, Pair<Hop[],CNodeTpl>> cleanupCPlans(HashMap<Long, Pair<Hop[],CNodeTpl>> cplans) {
+     HashMap<Long, Pair<Hop[],CNodeTpl>> cleaned = new HashMap<Long, Pair<Hop[],CNodeTpl>>();
+     for( Entry<Long, Pair<Hop[],CNodeTpl>> entry : cplans.entrySet() ) {
+         Pair<Hop[],CNodeTpl> plan = entry.getValue();
+         CNodeTpl tpl = plan.getValue();
+         Hop[] inHops = plan.getKey();
+
+         //collect hop IDs of all cplan leaf nodes
+         HashSet<Long> leafIDs = new HashSet<Long>();
+         rCollectLeafIDs(tpl.getOutput(), leafIDs);
+
+         //restrict the plan to the inputs actually referenced by leafs
+         if( inHops.length != leafIDs.size() ) {
+             tpl.cleanupInputs(leafIDs);
+             ArrayList<Hop> usedHops = new ArrayList<Hop>();
+             for( Hop in : inHops ) {
+                 if( leafIDs.contains(in.getHopID()) )
+                     usedHops.add(in);
+             }
+             plan = new Pair<Hop[],CNodeTpl>(
+                 usedHops.toArray(new Hop[0]),tpl);
+         }
+         cleaned.put(entry.getKey(), plan);
+
+         //cplan w/ single op and w/o agg
+         boolean singleOpNoAgg = tpl instanceof CNodeCell
+             && ((CNodeCell)tpl).getCellType()==CellType.NO_AGG
+             && tpl.getOutput() instanceof CNodeUnary
+             && tpl.getOutput().getInput().get(0) instanceof CNodeData;
+
+         //empty cplan
+         boolean emptyPlan = tpl.getOutput() instanceof CNodeData;
+
+         if( singleOpNoAgg || emptyPlan )
+             cleaned.remove(entry.getKey());
+     }
+
+     return cleaned;
+ }
+
+ /**
+  * Recursively collects the hop IDs of all data nodes reachable from the
+  * given cplan node, including the node itself.
+  *
+  * @param node current cplan node
+  * @param leafs output set of collected hop IDs
+  */
+ private static void rCollectLeafIDs(CNode node, HashSet<Long> leafs) {
+     //data nodes contribute their hop ID
+     if( node instanceof CNodeData ) {
+         CNodeData data = (CNodeData) node;
+         leafs.add(data.getHopID());
+     }
+
+     //descend into all inputs
+     for( CNode input : node.getInput() ) {
+         rCollectLeafIDs(input, leafs);
+     }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java
new file mode 100644
index 0000000..357d41c
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen;
+
+import java.util.ArrayList;
+
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.Hop.MultiThreadedHop;
+import org.apache.sysml.hops.HopsException;
+import org.apache.sysml.hops.MemoTable;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.lops.Lop;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.lops.LopsException;
+import org.apache.sysml.lops.SpoofFused;
+import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.parser.Expression.ValueType;
+
+public class SpoofFusedOp extends Hop implements MultiThreadedHop
+{
+ /**
+ * Rule for deriving the output dimensions of the fused operator from
+ * its inputs, as applied by refreshSizeInformation.
+ */
+ public enum SpoofOutputDimsType {
+ //rows and columns of the first input
+ INPUT_DIMS,
+ //rows of the first input, one column
+ ROW_DIMS,
+ //columns of the first input as number of rows, one column
+ COLUMN_DIMS_ROWS,
+ //one row, columns of the first input
+ COLUMN_DIMS_COLS,
+ //both dimensions set to zero
+ SCALAR,
+ //rows of the first input, columns of the second input
+ ROW_RANK_DIMS, // right wdivmm
+ //columns of the first input as number of rows, columns of the second input
+ COLUMN_RANK_DIMS // left wdivmm
+ }
+
+ private Class<?> _class = null;
+ private boolean _distSupported = false;
+ private int _numThreads = -1;
+ private SpoofOutputDimsType _dimsType;
+
+ public SpoofFusedOp ( ) {
+
+ }
+
+ public SpoofFusedOp( String name, DataType dt, ValueType vt, Class<?> cla, boolean dist, SpoofOutputDimsType type ) {
+ super(name, dt, vt);
+ _class = cla;
+ _distSupported = dist;
+ _dimsType = type;
+ }
+
+ @Override
+ public void setMaxNumThreads(int k) {
+ _numThreads = k;
+ }
+
+ @Override
+ public int getMaxNumThreads() {
+ return _numThreads;
+ }
+
+ /**
+ * Returns the distributed-support flag passed at construction time.
+ */
+ @Override
+ public boolean allowsAllExecTypes() {
+ return _distSupported;
+ }
+
+ /**
+ * Estimates the output memory from the two dimensions only; the given
+ * number of non-zeros is not used.
+ */
+ @Override
+ protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
+ return OptimizerUtils.estimateSize(dim1, dim2);
+ }
+
+ @Override
+ protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) {
+ return 0;
+ }
+
+ /**
+ * No output characteristics are inferred for fused operators; always
+ * returns null.
+ */
+ @Override
+ protected long[] inferOutputCharacteristics(MemoTable memo) {
+ return null;
+ }
+
+ /**
+  * Constructs the lop of this fused operator, or returns the existing lop
+  * if one was already constructed. The lops of all inputs are constructed
+  * first and passed to the new lop together with the generated class, the
+  * constrained number of threads, and the selected execution type.
+  *
+  * @return lop of this hop
+  * @throws HopsException if execution type selection or input lop construction fails
+  * @throws LopsException if lop construction fails
+  */
+ @Override
+ public Lop constructLops() throws HopsException, LopsException {
+     //reuse previously constructed lop
+     Lop existing = getLops();
+     if( existing != null )
+         return existing;
+
+     ExecType et = optFindExecType();
+
+     //construct lops of all inputs
+     ArrayList<Lop> inputLops = new ArrayList<Lop>();
+     for( Hop in : getInput() ) {
+         Lop inLop = in.constructLops();
+         inputLops.add(inLop);
+     }
+
+     int k = OptimizerUtils.getConstrainedNumThreads(_numThreads);
+     SpoofFused fused = new SpoofFused(inputLops, getDataType(), getValueType(), _class, k, et);
+     setOutputDimensions(fused);
+     setLineNumbers(fused);
+     setLops(fused);
+
+     return fused;
+ }
+
+ /**
+  * Selects the execution type of this operator: a forced execution type
+  * takes precedence, otherwise the type is chosen by memory estimate with a
+  * subsequent check for invalid CP dimensions and size. MR is never
+  * returned; it is replaced by CP.
+  *
+  * @return selected execution type
+  * @throws HopsException if execution type selection fails
+  */
+ @Override
+ protected ExecType optFindExecType() throws HopsException {
+
+     checkAndSetForcedPlatform();
+
+     if( _etypeForced == null ) {
+         _etype = findExecTypeByMemEstimate();
+         checkAndSetInvalidCPDimsAndSize();
+     }
+     else {
+         _etype = _etypeForced;
+     }
+
+     //ensure valid execution plans
+     if( _etype == ExecType.MR ) {
+         _etype = ExecType.CP;
+     }
+
+     return _etype;
+ }
+
+ @Override
+ public String getOpString() {
+ return "spoof("+_class.getSimpleName()+")";
+ }
+
+ /**
+ * Recomputes the output dimensions of this operator from the current
+ * dimensions of its inputs according to the output dimension type.
+ * Throws a RuntimeException for dimension types without a case below.
+ */
+ @Override
+ public void refreshSizeInformation() {
+ switch(_dimsType)
+ {
+ case ROW_DIMS:
+ //column vector with one entry per input row
+ setDim1(getInput().get(0).getDim1());
+ setDim2(1);
+ break;
+ case COLUMN_DIMS_ROWS:
+ //column vector with one entry per input column
+ setDim1(getInput().get(0).getDim2());
+ setDim2(1);
+ break;
+ case COLUMN_DIMS_COLS:
+ //row vector with one entry per input column
+ setDim1(1);
+ setDim2(getInput().get(0).getDim2());
+ break;
+ case INPUT_DIMS:
+ //same dimensions as the first input
+ setDim1(getInput().get(0).getDim1());
+ setDim2(getInput().get(0).getDim2());
+ break;
+ case SCALAR:
+ setDim1(0);
+ setDim2(0);
+ break;
+ case ROW_RANK_DIMS:
+ //rows of the first input, columns of the second input
+ setDim1(getInput().get(0).getDim1());
+ setDim2(getInput().get(1).getDim2());
+ break;
+ case COLUMN_RANK_DIMS:
+ //columns of the first input, columns of the second input
+ setDim1(getInput().get(0).getDim2());
+ setDim2(getInput().get(1).getDim2());
+ break;
+ default:
+ throw new RuntimeException("Failed to refresh size information "
+ + "for type: "+_dimsType.toString());
+ }
+ }
+
+ /**
+  * Creates a copy of this fused operator, including the generic hop
+  * attributes as well as the generated class, distributed-support flag,
+  * number of threads, and output dimension type.
+  *
+  * @return copy of this operator
+  * @throws CloneNotSupportedException if copying the generic attributes fails
+  */
+ @Override
+ public Object clone() throws CloneNotSupportedException
+ {
+     SpoofFusedOp copy = new SpoofFusedOp();
+
+     //generic hop attributes
+     copy.clone(this, false);
+
+     //attributes specific to the fused operator
+     copy._dimsType = _dimsType;
+     copy._numThreads = _numThreads;
+     copy._distSupported = _distSupported;
+     copy._class = _class;
+     return copy;
+ }
+
+ /**
+  * Compares this fused operator with another hop. Two fused operators
+  * match if they use an equal generated class, agree on the
+  * distributed-support flag and number of threads, and reference the
+  * identical input hops in the same order.
+  *
+  * @param that hop to compare with
+  * @return true if both operators match, false otherwise
+  */
+ @Override
+ public boolean compare( Hop that )
+ {
+     if( !(that instanceof SpoofFusedOp) )
+         return false;
+
+     SpoofFusedOp other = (SpoofFusedOp)that;
+
+     //operator-level attributes
+     if( !_class.equals(other._class) )
+         return false;
+     if( _distSupported != other._distSupported )
+         return false;
+     if( _numThreads != other._numThreads )
+         return false;
+     if( getInput().size() != other.getInput().size() )
+         return false;
+
+     //inputs are compared by reference
+     for( int i=0; i<getInput().size(); i++ ) {
+         if( getInput().get(i) != other.getInput().get(i) )
+             return false;
+     }
+
+     return true;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java
new file mode 100644
index 0000000..46637cc
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.cplan;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+
+import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
+
+public abstract class CNode
+{
+ private static final IDSequence _seq = new IDSequence();
+
+ protected ArrayList<CNode> _inputs = null;
+ protected CNode _output = null;
+ protected boolean _generated = false;
+ protected String _genVar = null;
+ protected long _rows = -1;
+ protected long _cols = -1;
+ protected DataType _dataType;
+ protected boolean _literal = false;
+
+ //cached hash to allow memoization in DAG structures and to avoid repeated
+ //recursive hash computation over all inputs (w/ reset on updates)
+ protected int _hash = 0;
+
+ public CNode() {
+ _inputs = new ArrayList<CNode>();
+ _generated = false;
+ }
+
+ public ArrayList<CNode> getInput() {
+ return _inputs;
+ }
+
+ /**
+ * Creates a new variable name of the form "TMP" followed by the next ID
+ * of the sequence shared by all cplan nodes, stores it as the variable
+ * name of this node, and returns it.
+ *
+ * @return newly created variable name
+ */
+ public String createVarname() {
+ _genVar = "TMP"+_seq.getNextID();
+ return _genVar;
+ }
+
+ /**
+ * Returns "TMP" followed by the current ID of the shared sequence minus
+ * one, without modifying the sequence or this node.
+ * NOTE(review): which previously created name this refers to depends on
+ * the semantics of IDSequence.getCurrentID - confirm against IDSequence.
+ *
+ * @return variable name derived from the current sequence ID
+ */
+ protected String getCurrentVarName() {
+ return "TMP"+(_seq.getCurrentID()-1);
+ }
+
+ public String getVarname() {
+ return _genVar;
+ }
+
+ public String getClassname() {
+ return getVarname();
+ }
+
+ /**
+  * Clears the generated flag of this node. If the flag was set, the
+  * inputs are reset recursively first.
+  */
+ public void resetGenerated() {
+     if( _generated ) {
+         for( CNode input : _inputs ) {
+             input.resetGenerated();
+         }
+     }
+     _generated = false;
+ }
+
+ public void setNumRows(long rows) {
+ _rows = rows;
+ }
+
+ public long getNumRows() {
+ return _rows;
+ }
+
+ public void setNumCols(long cols) {
+ _cols = cols;
+ }
+
+ public long getNumCols() {
+ return _cols;
+ }
+
+ public DataType getDataType() {
+ return _dataType;
+ }
+
+ public void setDataType(DataType dt) {
+ _dataType = dt;
+ _hash = 0;
+ }
+
+ public boolean isLiteral() {
+ return _literal;
+ }
+
+ public void setLiteral(boolean literal) {
+ _literal = literal;
+ _hash = 0;
+ }
+
+ public CNode getOutput() {
+ return _output;
+ }
+
+ public void setOutput(CNode output) {
+ _output = output;
+ _hash = 0;
+ }
+
+ public abstract String codegen(boolean sparse) ;
+
+ public abstract void setOutputDims();
+
+ ///////////////////////////////////////
+ // Functionality for plan cache
+
+ //note: genvar/generated are changed during codegen and hence not considered;
+ //rows and cols are also not included in order to increase reuse potential
+
+ /**
+ * Computes the hash over the hashes of all inputs, the output, the data
+ * type, and the literal flag. The result is cached in _hash, where a
+ * value of 0 means "not computed yet"; a computed hash of 0 is therefore
+ * recomputed on every call.
+ */
+ @Override
+ public int hashCode() {
+ if( _hash == 0 ) {
+ int numIn = _inputs.size();
+ //layout: input hashes first, followed by output, data type, and literal flag
+ int[] tmp = new int[numIn + 3];
+ //include inputs, partitioned by matrices and scalars to increase
+ //reuse in case of interleaved inputs (see CNodeTpl.renameInputs)
+ int pos = 0;
+ for( CNode c : _inputs )
+ if( c.getDataType()==DataType.MATRIX )
+ tmp[pos++] = c.hashCode();
+ for( CNode c : _inputs )
+ if( c.getDataType()!=DataType.MATRIX )
+ tmp[pos++] = c.hashCode();
+ tmp[numIn+0] = (_output!=null)?_output.hashCode():0;
+ tmp[numIn+1] = (_dataType!=null)?_dataType.hashCode():0;
+ tmp[numIn+2] = Boolean.hashCode(_literal);
+ _hash = Arrays.hashCode(tmp);
+ }
+ return _hash;
+ }
+
+ /**
+  * Compares this node with another object for structural equality, as
+  * required for lookups in the plan cache. Two nodes are equal if they have
+  * the same number of inputs, all inputs are pairwise equal in order, their
+  * outputs are equal (or both null), and they agree on data type and
+  * literal flag. Generated variable names and dimensions are not compared.
+  *
+  * @param that object to compare with
+  * @return true if both nodes are structurally equal, false otherwise
+  */
+ @Override
+ public boolean equals(Object that) {
+     if( !(that instanceof CNode) )
+         return false;
+
+     CNode cthat = (CNode) that;
+     boolean ret = _inputs.size() == cthat._inputs.size();
+     //compare each input against the corresponding input of the other node
+     //(comparing an input with itself would always succeed)
+     for( int i=0; i<_inputs.size() && ret; i++ )
+         ret &= _inputs.get(i).equals(cthat._inputs.get(i));
+     //null-safe output comparison: equal if same reference (incl both null)
+     //or if this output is non-null and equals the other output
+     return ret
+         && (_output == cthat._output
+             || (_output != null && _output.equals(cthat._output)))
+         && _dataType == cthat._dataType
+         && _literal == cthat._literal;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
new file mode 100644
index 0000000..1bfaab4
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
@@ -0,0 +1,260 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.cplan;
+
+import java.util.Arrays;
+
+import org.apache.sysml.parser.Expression.DataType;
+
+
+/**
+ * Code generation node for binary operations. It holds exactly two inputs
+ * and a {@link BinType} that selects the code template.
+ */
+public class CNodeBinary extends CNode
+{
+    /** Supported binary operation types. */
+    public enum BinType {
+        DOT_PRODUCT,
+        VECT_MULT_ADD, VECT_DIV_ADD,
+        VECT_MULT_SCALAR, VECT_DIV_SCALAR,
+        MULT, DIV, PLUS, MINUS, MODULUS, INTDIV,
+        LESS, LESSEQUAL, GREATER, GREATEREQUAL, EQUAL,NOTEQUAL,
+        MIN, MAX, AND, OR, LOG, POW,
+        MINUS1_MULT;
+
+        /**
+         * Checks if the given string is the name of a binary type.
+         *
+         * @param value candidate type name
+         * @return true if a type with this name exists
+         */
+        public static boolean contains(String value) {
+            for( BinType bt : values() )
+                if( bt.toString().equals(value) )
+                    return true;
+            return false;
+        }
+
+        /**
+         * @return true if the operands of this operation can be swapped
+         */
+        public boolean isCommutative() {
+            return ( this == EQUAL || this == NOTEQUAL
+                || this == PLUS || this == MULT
+                || this == MIN || this == MAX );
+        }
+
+        /**
+         * Returns the code template of this operation. The placeholders
+         * %TMP%, %INj%, %INjv%, %INji%, and %POSj% are substituted in
+         * {@link CNodeBinary#codegen(boolean)}; any other placeholders
+         * are left in the returned string.
+         *
+         * @param sparse if true, return the sparse variant where one exists
+         * @return code template string
+         * @throws RuntimeException for types without a template (AND, OR)
+         */
+        public String getTemplate(boolean sparse) {
+            switch (this) {
+                case DOT_PRODUCT:
+                    return sparse ? " double %TMP% = LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, %LEN%);\n" :
+                        " double %TMP% = LibSpoofPrimitives.dotProduct(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
+
+                case VECT_MULT_ADD:
+                    return sparse ? " LibSpoofPrimitives.vectMultiplyAdd(%IN1%, %IN2v%, %OUT%, %IN2i%, %POS2%, %POSOUT%, %LEN%);\n" :
+                        " LibSpoofPrimitives.vectMultiplyAdd(%IN1%, %IN2%, %OUT%, %POS2%, %POSOUT%, %LEN%);\n";
+
+                case VECT_DIV_ADD:
+                    return sparse ? " LibSpoofPrimitives.vectDivAdd(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POSOUT%, %LEN%);\n" :
+                        " LibSpoofPrimitives.vectDivAdd(%IN1%, %IN2%, %OUT%, %POS1%, %POSOUT%, %LEN%);\n";
+
+                case VECT_DIV_SCALAR:
+                    return sparse ? " LibSpoofPrimitives.vectDivWrite(%IN1v%, %IN1i%, %IN2%, %OUT%, %POS1%, %POSOUT%, %LEN%);\n" :
+                        " LibSpoofPrimitives.vectDivWrite(%IN1%, %IN2%, %OUT%, %POS1%, %POSOUT%, %LEN%);\n";
+
+                case VECT_MULT_SCALAR:
+                    return " LibSpoofPrimitives.vectMultiplyWrite(%IN2%, %IN1%, %POS1%, %OUT%, 0, %LEN%);\n";
+
+
+                /*Can be replaced by function objects*/
+                case MULT:
+                    return " double %TMP% = %IN1% * %IN2%;\n" ;
+
+                case DIV:
+                    return " double %TMP% = %IN1% / %IN2%;\n" ;
+                case PLUS:
+                    return " double %TMP% = %IN1% + %IN2%;\n" ;
+                case MINUS:
+                    return " double %TMP% = %IN1% - %IN2%;\n" ;
+                case MODULUS:
+                    return " double %TMP% = %IN1% % %IN2%;\n" ;
+                case INTDIV:
+                    //parenthesize the quotient: a cast binds tighter than '/',
+                    //so "(int) a / b" would only truncate the first operand
+                    return " double %TMP% = (int) (%IN1% / %IN2%);\n" ;
+                case LESS:
+                    return " double %TMP% = (%IN1% < %IN2%) ? 1 : 0;\n" ;
+                case LESSEQUAL:
+                    return " double %TMP% = (%IN1% <= %IN2%) ? 1 : 0;\n" ;
+                case GREATER:
+                    return " double %TMP% = (%IN1% > %IN2%) ? 1 : 0;\n" ;
+                case GREATEREQUAL:
+                    return " double %TMP% = (%IN1% >= %IN2%) ? 1 : 0;\n" ;
+                case EQUAL:
+                    return " double %TMP% = (%IN1% == %IN2%) ? 1 : 0;\n" ;
+                case NOTEQUAL:
+                    return " double %TMP% = (%IN1% != %IN2%) ? 1 : 0;\n" ;
+
+                case MIN:
+                    return " double %TMP% = Math.min(%IN1%, %IN2%);\n" ;
+                case MAX:
+                    return " double %TMP% = Math.max(%IN1%, %IN2%);\n" ;
+                case LOG:
+                    return " double %TMP% = Math.log(%IN1%)/Math.log(%IN2%);\n" ;
+                case POW:
+                    return " double %TMP% = Math.pow(%IN1%, %IN2%);\n" ;
+                case MINUS1_MULT:
+                    return " double %TMP% = 1 - %IN1% * %IN2%;\n" ;
+
+                //note: AND and OR currently have no template and end up here
+                default:
+                    throw new RuntimeException("Invalid binary type: "+this.toString());
+            }
+        }
+    }
+
+    private final BinType _type;
+
+    /**
+     * Creates a binary node over the two given inputs.
+     *
+     * @param in1 first input
+     * @param in2 second input
+     * @param type binary operation type
+     */
+    public CNodeBinary( CNode in1, CNode in2, BinType type ) {
+        //canonicalize commutative matrix-scalar operations
+        //to increase reuse potential
+        if( type.isCommutative() && in1 instanceof CNodeData
+            && in1.getDataType()==DataType.SCALAR ) {
+            CNode tmp = in1;
+            in1 = in2;
+            in2 = tmp;
+        }
+
+        _inputs.add(in1);
+        _inputs.add(in2);
+        _type = type;
+        setOutputDims();
+    }
+
+    public BinType getType() {
+        return _type;
+    }
+
+    /**
+     * Generates the code of both inputs followed by the code of this
+     * operation. Returns an empty string if this node has already been
+     * generated.
+     *
+     * @param sparse if true, use sparse templates where available
+     * @return generated code
+     */
+    @Override
+    public String codegen(boolean sparse) {
+        if( _generated )
+            return "";
+
+        StringBuilder sb = new StringBuilder();
+
+        //generate children
+        sb.append(_inputs.get(0).codegen(sparse));
+        sb.append(_inputs.get(1).codegen(sparse));
+
+        //generate binary operation
+        //(placeholders are plain text, hence literal instead of regex replace)
+        String var = createVarname();
+        String tmp = _type.getTemplate(sparse);
+        tmp = tmp.replace("%TMP%", var);
+        for( int j=1; j<=2; j++ ) {
+            String varj = _inputs.get(j-1).getVarname();
+            if( sparse && !tmp.contains("%IN"+j+"%") ) {
+                //sparse input: separate value and index arrays
+                tmp = tmp.replace("%IN"+j+"v%", varj+"vals");
+                tmp = tmp.replace("%IN"+j+"i%", varj+"ix");
+            }
+            else
+                tmp = tmp.replace("%IN"+j+"%", varj );
+
+            if(varj.startsWith("_b") ) //i.e. b.get(index)
+                tmp = tmp.replace("%POS"+j+"%", "_bi");
+            else
+                tmp = tmp.replace("%POS"+j+"%", varj+"i");
+        }
+        sb.append(tmp);
+
+        //mark as generated
+        _generated = true;
+
+        return sb.toString();
+    }
+
+    @Override
+    public String toString() {
+        switch(_type) {
+            case DOT_PRODUCT: return "b(dot)";
+            case VECT_MULT_ADD: return "b(vma)";
+            case VECT_DIV_ADD: return "b(vda)";
+            case MULT: return "b(*)";
+            case DIV: return "b(/)";
+            case VECT_DIV_SCALAR: return "b(vector/)";
+            case VECT_MULT_SCALAR: return "b(vector*)";
+            default:
+                return super.toString();
+        }
+    }
+
+    /**
+     * Derives rows, cols, and data type from the operation type: vector
+     * operations take the dimensions of one input and yield a matrix, all
+     * other operations yield a scalar.
+     */
+    @Override
+    public void setOutputDims()
+    {
+        switch(_type) {
+            //VECT
+            case VECT_MULT_ADD:
+            case VECT_DIV_ADD:
+                _rows = _inputs.get(1)._rows;
+                _cols = _inputs.get(1)._cols;
+                _dataType= DataType.MATRIX;
+                break;
+
+            case VECT_DIV_SCALAR:
+            case VECT_MULT_SCALAR:
+                _rows = _inputs.get(0)._rows;
+                _cols = _inputs.get(0)._cols;
+                _dataType= DataType.MATRIX;
+                break;
+
+
+            case DOT_PRODUCT:
+
+            //SCALAR Arithmetic
+            case MULT:
+            case DIV:
+            case PLUS:
+            case MINUS:
+            case MINUS1_MULT:
+            case MODULUS:
+            case INTDIV:
+            //SCALAR Comparison
+            case LESS:
+            case LESSEQUAL:
+            case GREATER:
+            case GREATEREQUAL:
+            case EQUAL:
+            case NOTEQUAL:
+            //SCALAR LOGIC
+            case MIN:
+            case MAX:
+            case AND:
+            case OR:
+            case LOG:
+            case POW:
+                _rows = 0;
+                _cols = 0;
+                _dataType= DataType.SCALAR;
+                break;
+        }
+    }
+
+    @Override
+    public int hashCode() {
+        if( _hash == 0 ) {
+            //combine the base hash with the operation type
+            int h1 = super.hashCode();
+            int h2 = _type.hashCode();
+            _hash = Arrays.hashCode(new int[]{h1,h2});
+        }
+        return _hash;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if( !(o instanceof CNodeBinary) )
+            return false;
+
+        CNodeBinary that = (CNodeBinary) o;
+        return super.equals(that)
+            && _type == that._type;
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java
new file mode 100644
index 0000000..a9408ca
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.cplan;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+
+import org.apache.sysml.hops.codegen.SpoofFusedOp.SpoofOutputDimsType;
+import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType;
+
+/**
+ * Template node for cell-wise operations. Its codegen fills the
+ * SpoofCellwise class template below with the code of the output node.
+ */
+public class CNodeCell extends CNodeTpl
+{
+    private static final String TEMPLATE =
+        "package codegen;\n"
+        + "import java.util.Arrays;\n"
+        + "import java.io.Serializable;\n"
+        + "import java.util.ArrayList;\n"
+        + "import org.apache.sysml.runtime.codegen.LibSpoofPrimitives;\n"
+        + "import org.apache.sysml.runtime.codegen.SpoofCellwise;\n"
+        + "import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType;\n"
+        + "import org.apache.commons.math3.util.FastMath;\n"
+        + "\n"
+        + "public final class %TMP% extends SpoofCellwise {\n"
+        + " public %TMP%() {\n"
+        + " _type = CellType.%TYPE%;\n"
+        + " }\n"
+        + " protected double genexecDense( double _a, double[][] _b, double[] _scalars, int _n, int _m, int _rowIndex, int _colIndex) { \n"
+        + "%BODY_dense%"
+        + " return %OUT%;\n"
+        + " } \n"
+        + "}";
+
+    private CellType _type = null;
+    private boolean _multipleConsumers = false;
+
+    public CNodeCell(ArrayList<CNode> inputs, CNode output ) {
+        super(inputs,output);
+    }
+
+    public void setMultipleConsumers(boolean flag) {
+        _multipleConsumers = flag;
+    }
+
+    public boolean hasMultipleConsumers() {
+        return _multipleConsumers;
+    }
+
+    /**
+     * Sets the cell type and clears the cached hash, which depends on it.
+     *
+     * @param type cell type
+     */
+    public void setCellType(CellType type) {
+        _type = type;
+        _hash = 0;
+    }
+
+    public CellType getCellType() {
+        return _type;
+    }
+
+    /**
+     * Generates the class source for this template. The sparse flag is
+     * not used; only the dense body is generated.
+     *
+     * @param sparse ignored
+     * @return generated class source
+     */
+    @Override
+    public String codegen(boolean sparse) {
+        String tmp = TEMPLATE;
+
+        //rename inputs
+        rReplaceDataNode(_output, _inputs.get(0), "_a");
+        renameInputs(_inputs, 1);
+
+        //generate dense/sparse bodies
+        String tmpDense = _output.codegen(false);
+        _output.resetGenerated();
+
+        //note: placeholders are plain text, so use literal replacement
+        //that does not interpret '$' or '\' in the generated code
+        tmp = tmp.replace("%TMP%", createVarname());
+        tmp = tmp.replace("%BODY_dense%", tmpDense);
+
+        //return last TMP
+        tmp = tmp.replace("%OUT%", getCurrentVarName());
+
+        //replace aggregate information
+        tmp = tmp.replace("%TYPE%", getCellType().toString());
+
+        return tmp;
+    }
+
+    @Override
+    public void setOutputDims() {
+        //nothing to derive for this template
+    }
+
+    /**
+     * Creates a new template over the same inputs and output and copies
+     * data type, cell type, and the multiple-consumers flag.
+     */
+    @Override
+    public CNodeTpl clone() {
+        CNodeCell tmp = new CNodeCell(_inputs, _output);
+        tmp.setDataType(getDataType());
+        tmp.setCellType(getCellType());
+        tmp.setMultipleConsumers(hasMultipleConsumers());
+        return tmp;
+    }
+
+    @Override
+    public SpoofOutputDimsType getOutputDimType() {
+        switch( _type ) {
+            case NO_AGG: return SpoofOutputDimsType.INPUT_DIMS;
+            case ROW_AGG: return SpoofOutputDimsType.ROW_DIMS;
+            case FULL_AGG: return SpoofOutputDimsType.SCALAR;
+            default:
+                throw new RuntimeException("Unsupported cell type: "+_type.toString());
+        }
+    }
+
+    @Override
+    public int hashCode() {
+        if( _hash == 0 ) {
+            int h1 = super.hashCode();
+            //null-safe like equals: the cell type is unset until setCellType
+            int h2 = (_type != null) ? _type.hashCode() : 0;
+            //note: _multipleConsumers irrelevant for plan comparison
+            _hash = Arrays.hashCode(new int[]{h1,h2});
+        }
+        return _hash;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if(!(o instanceof CNodeCell))
+            return false;
+
+        CNodeCell that = (CNodeCell)o;
+        return super.equals(that)
+            && _type == that._type;
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeData.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeData.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeData.java
new file mode 100644
index 0000000..d5457e8
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeData.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.cplan;
+
+import java.util.Arrays;
+
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.parser.Expression.DataType;
+
+public class CNodeData extends CNode
+{
+ protected final String _name;
+ protected final long _hopID;
+
+ public CNodeData(Hop hop) {
+ this(hop, hop.getDim1(), hop.getDim2(), hop.getDataType());
+ }
+
+ public CNodeData(Hop hop, long rows, long cols, DataType dt) {
+ //note: previous rewrites might have created hops with equal name
+ //hence, we also keep the hopID to uniquely identify inputs
+ _name = hop.getName();
+ _hopID = hop.getHopID();
+ _rows = rows;
+ _cols = cols;
+ _dataType = dt;
+ }
+
+ public CNodeData(CNodeData node, String newName) {
+ _name = newName;
+ _hopID = node.getHopID();
+ _rows = node.getNumRows();
+ _cols = node.getNumCols();
+ _dataType = node.getDataType();
+ }
+
+ @Override
+ public String getVarname() {
+ return _name;
+ }
+
+ public long getHopID() {
+ return _hopID;
+ }
+
+ @Override
+ public String codegen(boolean sparse) {
+ return "";
+ }
+
+ @Override
+ public void setOutputDims() {
+
+ }
+
+ @Override
+ public String toString() {
+ return "CdataNode[name="+_name+", id="+_hopID+"]";
+ }
+
+ @Override
+ public int hashCode() {
+ if( _hash == 0 ) {
+ int h1 = super.hashCode();
+ int h2 = isLiteral() ? _name.hashCode() : 0;
+ _hash = Arrays.hashCode(new int[]{h1,h2});
+ }
+ return _hash;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ return (o instanceof CNodeData
+ && super.equals(o)
+ && (!isLiteral() || _name.equals(((CNodeData)o)._name)));
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeOuterProduct.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeOuterProduct.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeOuterProduct.java
new file mode 100644
index 0000000..8c2e38c
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeOuterProduct.java
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.cplan;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+
+import org.apache.sysml.hops.codegen.SpoofFusedOp.SpoofOutputDimsType;
+import org.apache.sysml.runtime.codegen.SpoofOuterProduct.OutProdType;
+
+
+/**
+ * Template node for outer-product operations. Its codegen fills the
+ * SpoofOuterProduct class template below with the code of the output node.
+ */
+public class CNodeOuterProduct extends CNodeTpl
+{
+    private static final String TEMPLATE =
+        "package codegen;\n"
+        + "import java.util.Arrays;\n"
+        + "import java.util.ArrayList;\n"
+        + "import org.apache.sysml.runtime.codegen.LibSpoofPrimitives;\n"
+        + "import org.apache.sysml.runtime.codegen.SpoofOuterProduct;\n"
+        + "import org.apache.sysml.runtime.codegen.SpoofOuterProduct.OutProdType;\n"
+        + "import org.apache.commons.math3.util.FastMath;\n"
+        + "\n"
+        + "public final class %TMP% extends SpoofOuterProduct { \n"
+        + " public %TMP%() {\n"
+        + " _outerProductType = OutProdType.%TYPE%;\n"
+        + " }\n"
+        + " protected void genexecDense( double _a, double[] _a1, int _a1i, double[] _a2, int _a2i, double[][] _b, double[] _scalars, double[] _c, int _ci, int _n, int _m, int _k, int _rowIndex, int _colIndex) { \n"
+        + "%BODY_dense%"
+        + " } \n"
+        + " protected double genexecCellwise( double _a, double[] _a1, int _a1i, double[] _a2, int _a2i, double[][] _b, double[] _scalars, int _n, int _m, int _k, int _rowIndex, int _colIndex) { \n"
+        + "%BODY_cellwise%"
+        + " return %OUT_cellwise%;\n"
+        + " } \n"
+
+        + "}";
+
+    private OutProdType _type = null;
+    private boolean _transposeOutput = false;
+
+    public CNodeOuterProduct(ArrayList<CNode> inputs, CNode output ) {
+        super(inputs,output);
+    }
+
+    /**
+     * Generates the class source for this template. For left/right outer
+     * products the generated code becomes the dense body; otherwise it
+     * becomes the cellwise body.
+     *
+     * @param sparse ignored
+     * @return generated class source
+     */
+    @Override
+    public String codegen(boolean sparse) {
+        // note: ignore sparse flag, generate both
+        String tmp = TEMPLATE;
+
+        //rename inputs
+        rReplaceDataNode(_output, _inputs.get(0), "_a");
+        rReplaceDataNode(_output, _inputs.get(1), "_a1"); // u
+        rReplaceDataNode(_output, _inputs.get(2), "_a2"); // v
+        renameInputs(_inputs, 3);
+
+        //generate dense/sparse bodies
+        String tmpDense = _output.codegen(false);
+        _output.resetGenerated();
+
+        //note: placeholders are plain text, so use literal replacement
+        //that does not interpret '$' or '\' in the generated code
+        tmp = tmp.replace("%TMP%", createVarname());
+
+        if(_type == OutProdType.LEFT_OUTER_PRODUCT || _type == OutProdType.RIGHT_OUTER_PRODUCT) {
+            tmp = tmp.replace("%BODY_dense%", tmpDense);
+            tmp = tmp.replace("%OUT%", "_c");
+            tmp = tmp.replace("%BODY_cellwise%", "");
+            tmp = tmp.replace("%OUT_cellwise%", "0");
+        }
+        else {
+            tmp = tmp.replace("%BODY_dense%", "");
+            tmp = tmp.replace("%BODY_cellwise%", tmpDense);
+            tmp = tmp.replace("%OUT_cellwise%", getCurrentVarName());
+        }
+        //replace size information
+        tmp = tmp.replace("%LEN%", "_k");
+
+        tmp = tmp.replace("%POSOUT%", "_ci");
+
+        tmp = tmp.replace("%TYPE%", _type.toString());
+
+        return tmp;
+    }
+
+    /**
+     * Sets the outer product type and clears the cached hash.
+     *
+     * @param type outer product type
+     */
+    public void setOutProdType(OutProdType type) {
+        _type = type;
+        _hash = 0;
+    }
+
+    public OutProdType getOutProdType() {
+        return _type;
+    }
+
+    @Override
+    public void setOutputDims() {
+        //nothing to derive for this template
+    }
+
+    /**
+     * Sets the transpose-output flag and clears the cached hash.
+     *
+     * @param transposeOutput new flag value
+     */
+    public void setTransposeOutput(boolean transposeOutput) {
+        _transposeOutput = transposeOutput;
+        _hash = 0;
+    }
+
+
+    public boolean isTransposeOutput() {
+        return _transposeOutput;
+    }
+
+    @Override
+    public SpoofOutputDimsType getOutputDimType() {
+        switch( _type ) {
+            case LEFT_OUTER_PRODUCT:
+                return SpoofOutputDimsType.COLUMN_RANK_DIMS;
+            case RIGHT_OUTER_PRODUCT:
+                return SpoofOutputDimsType.ROW_RANK_DIMS;
+            case CELLWISE_OUTER_PRODUCT:
+                return SpoofOutputDimsType.INPUT_DIMS;
+            case AGG_OUTER_PRODUCT:
+                return SpoofOutputDimsType.SCALAR;
+            default:
+                throw new RuntimeException("Unsupported outer product type: "+_type.toString());
+        }
+    }
+
+    /**
+     * Creates a new template over the same inputs and output and copies
+     * data type, outer product type, and the transpose-output flag, so
+     * that the copy compares equal to this template.
+     */
+    @Override
+    public CNodeTpl clone() {
+        CNodeOuterProduct tmp = new CNodeOuterProduct(_inputs, _output);
+        tmp.setDataType(getDataType());
+        tmp.setOutProdType(_type);
+        tmp.setTransposeOutput(_transposeOutput);
+        return tmp;
+    }
+
+    @Override
+    public int hashCode() {
+        if( _hash == 0 ) {
+            int h1 = super.hashCode();
+            //null-safe like equals: the type is unset until setOutProdType
+            int h2 = (_type != null) ? _type.hashCode() : 0;
+            int h3 = Boolean.hashCode(_transposeOutput);
+            _hash = Arrays.hashCode(new int[]{h1,h2,h3});
+        }
+        return _hash;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if(!(o instanceof CNodeOuterProduct))
+            return false;
+
+        CNodeOuterProduct that = (CNodeOuterProduct)o;
+        return super.equals(that)
+            && _type == that._type
+            && _transposeOutput == that._transposeOutput;
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRowAggVector.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRowAggVector.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRowAggVector.java
new file mode 100644
index 0000000..147615f
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRowAggVector.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.cplan;
+
+import java.util.ArrayList;
+
+import org.apache.sysml.hops.codegen.SpoofFusedOp.SpoofOutputDimsType;
+
+/**
+ * Template node for row-wise aggregation. Its codegen fills the
+ * SpoofRowAggregate class template below with a dense and a sparse body.
+ */
+public class CNodeRowAggVector extends CNodeTpl
+{
+    private static final String TEMPLATE =
+        "package codegen;\n"
+        + "import java.util.Arrays;\n"
+        + "import java.util.ArrayList;\n"
+        + "import org.apache.sysml.runtime.codegen.LibSpoofPrimitives;\n"
+        + "import org.apache.sysml.runtime.codegen.SpoofRowAggregate;\n"
+        + "\n"
+        + "public final class %TMP% extends SpoofRowAggregate { \n"
+        + " public %TMP%() {\n"
+        + " _colVector = %FLAG%;\n"
+        + " }\n"
+        + " protected void genexecRowDense( double[] _a, int _ai, double[][] _b, double[] _scalars, double[] _c, int _len, int _rowIndex ) { \n"
+        + "%BODY_dense%"
+        + " } \n"
+        + " protected void genexecRowSparse( double[] _avals, int[] _aix, int _ai, double[][] _b, double[] _scalars, double[] _c, int _len, int _rowIndex ) { \n"
+        + "%BODY_sparse%"
+        + " } \n"
+        + "}\n";
+
+    public CNodeRowAggVector(ArrayList<CNode> inputs, CNode output ) {
+        super(inputs, output);
+    }
+
+
+    /**
+     * Generates the class source for this template, including both the
+     * dense and the sparse row body.
+     *
+     * @param sparse ignored, both variants are generated
+     * @return generated class source
+     */
+    @Override
+    public String codegen(boolean sparse) {
+        //the first input is the main input matrix, the rest are side inputs
+        rReplaceDataNode(_output, _inputs.get(0), "_a");
+        renameInputs(_inputs, 1);
+
+        //dense body first, then reset the generated flags for the sparse body
+        final String denseBody = _output.codegen(false);
+        _output.resetGenerated();
+        final String sparseBody = _output.codegen(true);
+
+        //class name and bodies
+        String code = TEMPLATE.replaceAll("%TMP%", createVarname());
+        code = code.replaceAll("%BODY_dense%", denseBody);
+        code = code.replaceAll("%BODY_sparse%", sparseBody);
+
+        //output array, output position, and row length
+        code = code.replaceAll("%OUT%", "_c");
+        code = code.replaceAll("%POSOUT%", "0");
+        code = code.replaceAll("%LEN%", "_len");
+
+        //column vector flag and start position of side inputs
+        final boolean colVector = (_output._cols==1);
+        code = code.replaceAll("%FLAG%", String.valueOf(colVector));
+        return code.replaceAll("_bi", "0");
+    }
+
+    @Override
+    public void setOutputDims() {
+        //nothing to derive for this template
+    }
+
+    /**
+     * @return COLUMN_DIMS_ROWS if the output has a single column,
+     *         COLUMN_DIMS_COLS otherwise
+     */
+    @Override
+    public SpoofOutputDimsType getOutputDimType() {
+        if( _output._cols==1 )
+            return SpoofOutputDimsType.COLUMN_DIMS_ROWS; //column vector
+        return SpoofOutputDimsType.COLUMN_DIMS_COLS; //row vector
+    }
+
+    @Override
+    public CNodeTpl clone() {
+        return new CNodeRowAggVector(_inputs, _output);
+    }
+
+    @Override
+    public int hashCode() {
+        return super.hashCode();
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if( !(o instanceof CNodeRowAggVector) )
+            return false;
+        return super.equals(o);
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java
new file mode 100644
index 0000000..719770b
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.cplan;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+
+import org.apache.sysml.hops.codegen.SpoofFusedOp.SpoofOutputDimsType;
+import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType;
+import org.apache.sysml.parser.Expression.DataType;
+
+/**
+ * Base class of all template nodes. It holds the template inputs and
+ * provides helpers to rename and replace data nodes below the output node.
+ */
+public abstract class CNodeTpl extends CNode implements Cloneable
+{
+    /**
+     * Creates a template over the given inputs and output. Duplicate and
+     * literal inputs are dropped (see {@link #addInput(CNode)}).
+     *
+     * @param inputs template inputs, must not be empty
+     * @param output root node of the template
+     * @throws RuntimeException if inputs is empty
+     */
+    public CNodeTpl(ArrayList<CNode> inputs, CNode output ) {
+        if( inputs.isEmpty() )
+            throw new RuntimeException("Cannot pass empty inputs to the CNodeTpl");
+
+        for(CNode input : inputs)
+            addInput(input);
+        _output = output;
+    }
+
+    /**
+     * Adds an input unless it is a literal or already contained.
+     *
+     * @param in input node
+     */
+    public void addInput(CNode in) {
+        //check for duplicate entries or literals
+        if( containsInput(in) || in.isLiteral() )
+            return;
+
+        _inputs.add(in);
+    }
+
+    /**
+     * Retains only those data-node inputs whose hop ID is in the filter.
+     *
+     * @param filter hop IDs of inputs to keep
+     */
+    public void cleanupInputs(HashSet<Long> filter) {
+        ArrayList<CNode> tmp = new ArrayList<>();
+        for( CNode in : _inputs )
+            if( in instanceof CNodeData && filter.contains(((CNodeData) in).getHopID()) )
+                tmp.add(in);
+        _inputs = tmp;
+    }
+
+    public String codegen() {
+        return codegen(false);
+    }
+
+    public abstract CNodeTpl clone();
+
+    public abstract SpoofOutputDimsType getOutputDimType();
+
+    /**
+     * Renames all non-literal data inputs from startIndex onwards: scalars
+     * become "_scalars[i]" and all other inputs become "_b[i]", each with
+     * its own running index. The new names are applied below the output.
+     *
+     * @param inputs template inputs
+     * @param startIndex index of the first input to rename
+     */
+    protected void renameInputs(ArrayList<CNode> inputs, int startIndex) {
+        //create map of hopID to data nodes with new names, used for CSE
+        HashMap<Long, CNode> nodes = new HashMap<>();
+        for(int i=startIndex, matrixPos=0, scalarPos=0; i < inputs.size(); i++) {
+            CNode cnode = inputs.get(i);
+            if( !(cnode instanceof CNodeData) || ((CNodeData)cnode).isLiteral())
+                continue;
+            CNodeData cdata = (CNodeData)cnode;
+            if( cdata.getDataType() == DataType.SCALAR || ( cdata.getNumCols() == 0 && cdata.getNumRows() == 0) )
+                nodes.put(cdata.getHopID(), new CNodeData(cdata, "_scalars["+ scalarPos++ +"]"));
+            else
+                nodes.put(cdata.getHopID(), new CNodeData(cdata, "_b["+ matrixPos++ +"]"));
+        }
+
+        //single pass to replace all names
+        rReplaceDataNode(_output, nodes, new HashMap<Long, CNode>());
+    }
+
+    /**
+     * Renames a single data input below the given root. Does nothing if
+     * the input is not a data node.
+     *
+     * @param root root of the sub-DAG to process
+     * @param input input node to rename
+     * @param newName new variable name
+     */
+    protected void rReplaceDataNode( CNode root, CNode input, String newName ) {
+        if( !(input instanceof CNodeData) )
+            return;
+
+        //create temporary name mapping
+        HashMap<Long, CNode> names = new HashMap<>();
+        CNodeData tmp = (CNodeData)input;
+        names.put(tmp.getHopID(), new CNodeData(tmp, newName));
+
+        rReplaceDataNode(root, names, new HashMap<Long,CNode>());
+    }
+
+    /**
+     * Recursively searches for data nodes and replaces them if found.
+     *
+     * @param node current node in recursive descend
+     * @param dnodes prepared data nodes, identified by own hop id
+     * @param lnodes memoized lookup nodes, identified by data node hop id
+     */
+    protected void rReplaceDataNode( CNode node, HashMap<Long, CNode> dnodes, HashMap<Long, CNode> lnodes )
+    {
+        for( int i=0; i<node._inputs.size(); i++ ) {
+            //recursively process children
+            rReplaceDataNode(node._inputs.get(i), dnodes, lnodes);
+
+            //replace leaf data node
+            if( node._inputs.get(i) instanceof CNodeData ) {
+                CNodeData tmp = (CNodeData)node._inputs.get(i);
+                if( dnodes.containsKey(tmp.getHopID()) )
+                    node._inputs.set(i, dnodes.get(tmp.getHopID()));
+            }
+
+            //replace lookup on top of leaf data node; lookups over other
+            //node types are left untouched instead of failing on the cast
+            if( node._inputs.get(i) instanceof CNodeUnary
+                && ((CNodeUnary)node._inputs.get(i)).getType()==UnaryType.LOOKUP
+                && node._inputs.get(i)._inputs.get(0) instanceof CNodeData ) {
+                CNodeData tmp = (CNodeData)node._inputs.get(i)._inputs.get(0);
+                if( !lnodes.containsKey(tmp.getHopID()) )
+                    lnodes.put(tmp.getHopID(), node._inputs.get(i));
+                else
+                    node._inputs.set(i, lnodes.get(tmp.getHopID()));
+            }
+        }
+    }
+
+    /**
+     * Recursively replaces all data nodes with the given hop ID by the new
+     * node and removes lookups whose input is a scalar.
+     *
+     * @param node current node in recursive descend
+     * @param hopID hop ID of the data nodes to replace
+     * @param newNode replacement node
+     */
+    public void rReplaceDataNode( CNode node, long hopID, CNode newNode )
+    {
+        for( int i=0; i<node._inputs.size(); i++ ) {
+            //replace leaf node
+            if( node._inputs.get(i) instanceof CNodeData ) {
+                CNodeData tmp = (CNodeData)node._inputs.get(i);
+                if( tmp.getHopID() == hopID )
+                    node._inputs.set(i, newNode);
+            }
+            //recursively process children
+            rReplaceDataNode(node._inputs.get(i), hopID, newNode);
+
+            //remove unnecessary lookups
+            if( node._inputs.get(i) instanceof CNodeUnary
+                && ((CNodeUnary)node._inputs.get(i)).getType()==UnaryType.LOOKUP
+                && node._inputs.get(i)._inputs.get(0).getDataType()==DataType.SCALAR)
+                node._inputs.set(i, node._inputs.get(i)._inputs.get(0));
+        }
+    }
+
+    /**
+     * Recursively places a LOOKUP node on top of all data nodes with the
+     * given hop ID.
+     *
+     * @param node current node in recursive descend
+     * @param hopID hop ID of the data nodes to wrap
+     * @param memo created lookup nodes, identified by hop ID
+     */
+    public void rInsertLookupNode( CNode node, long hopID, HashMap<Long, CNode> memo )
+    {
+        for( int i=0; i<node._inputs.size(); i++ ) {
+            //recursively process children
+            rInsertLookupNode(node._inputs.get(i), hopID, memo);
+
+            //replace leaf node
+            if( node._inputs.get(i) instanceof CNodeData ) {
+                CNodeData tmp = (CNodeData)node._inputs.get(i);
+                if( tmp.getHopID() == hopID ) {
+                    //use memo structure to retain DAG structure
+                    CNode lookup = memo.get(hopID);
+                    if( lookup == null ) {
+                        lookup = new CNodeUnary(tmp, UnaryType.LOOKUP);
+                        memo.put(hopID, lookup);
+                    }
+                    node._inputs.set(i, lookup);
+                }
+            }
+        }
+    }
+
+    /**
+     * Checks for duplicates among the data-node inputs, i.e., an existing
+     * input with the same name and the same hop ID.
+     *
+     * @param input new input node
+     * @return true if duplicate, false otherwise
+     */
+    private boolean containsInput(CNode input) {
+        if( !(input instanceof CNodeData) )
+            return false;
+
+        CNodeData input2 = (CNodeData)input;
+        for( CNode cnode : _inputs ) {
+            if( !(cnode instanceof CNodeData) )
+                continue;
+            CNodeData cnode2 = (CNodeData)cnode;
+            if( cnode2._name.equals(input2._name) && cnode2._hopID==input2._hopID )
+                return true;
+        }
+
+        return false;
+    }
+
+    @Override
+    public int hashCode() {
+        return super.hashCode();
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        return (o instanceof CNodeTpl
+            && super.equals(o));
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
new file mode 100644
index 0000000..f08769e
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.cplan;
+
+import java.util.Arrays;
+
+import org.apache.sysml.parser.Expression.DataType;
+
+
+ /**
+  * Plan node of a unary operation over a single input node. Every operation
+  * type maps to one line of Java source, which is instantiated and emitted by
+  * {@link #codegen(boolean)}.
+  */
+ public class CNodeUnary extends CNode
+ {
+     /**
+      * Unary operation types along with their source code templates.
+      */
+     public enum UnaryType {
+         ROW_SUMS, LOOKUP, LOOKUP0,
+         EXP, POW2, MULT2, SQRT, LOG,
+         ABS, ROUND, CEIL,FLOOR, SIGN,
+         SIN, COS, TAN, ASIN, ACOS, ATAN,
+         IQM, STOP,
+         DOTPRODUCT_ROW_SUMS; //row sums via dot product for debugging purposes
+
+         /**
+          * Checks whether an operation type of the given name is defined.
+          *
+          * @param value name to look up (case-sensitive)
+          * @return true if such a type exists, false otherwise
+          */
+         public static boolean contains(String value) {
+             for( UnaryType ut : values() )
+                 if( ut.toString().equals(value) )
+                     return true;
+             return false;
+         }
+
+         /**
+          * Returns the source code template of this operation. Templates use
+          * the placeholders {@code %TMP%} (output variable), {@code %IN1%}
+          * (input variable) and, for row sums, {@code %IN1v%}, {@code %IN1i%},
+          * {@code %POS1%} and {@code %LEN%}.
+          *
+          * @param sparse if true, the sparse variant is returned where one
+          *               exists (currently only for {@code ROW_SUMS})
+          * @return template string terminated by a newline
+          * @throws RuntimeException if no template is defined for this type
+          */
+         public String getTemplate(boolean sparse) {
+             switch (this) {
+                 case ROW_SUMS:
+                     return sparse ? " double %TMP% = LibSpoofPrimitives.vectSum( %IN1v%, %IN1i%, %POS1%, %LEN%);\n":
+                         " double %TMP% = LibSpoofPrimitives.vectSum( %IN1%, %POS1%, %LEN%);\n";
+                 case EXP:
+                     return " double %TMP% = FastMath.exp(%IN1%);\n";
+                 case LOOKUP:
+                     return " double %TMP% = %IN1%[_rowIndex];\n" ;
+                 case LOOKUP0:
+                     return " double %TMP% = %IN1%[0];\n" ;
+                 case POW2:
+                     return " double %TMP% = %IN1% * %IN1%;\n" ;
+                 case MULT2:
+                     return " double %TMP% = %IN1% + %IN1%;\n" ;
+                 case ABS:
+                     return " double %TMP% = Math.abs(%IN1%);\n";
+                 case SIN:
+                     return " double %TMP% = Math.sin(%IN1%);\n";
+                 case COS:
+                     return " double %TMP% = Math.cos(%IN1%);\n";
+                 case TAN:
+                     return " double %TMP% = Math.tan(%IN1%);\n";
+                 case ASIN:
+                     return " double %TMP% = Math.asin(%IN1%);\n";
+                 case ACOS:
+                     return " double %TMP% = Math.acos(%IN1%);\n";
+                 case ATAN:
+                     return " double %TMP% = Math.atan(%IN1%);\n";
+                 case SIGN:
+                     return " double %TMP% = Math.signum(%IN1%);\n";
+                 case SQRT:
+                     return " double %TMP% = Math.sqrt(%IN1%);\n";
+                 case LOG:
+                     return " double %TMP% = FastMath.log(%IN1%);\n";
+                 case ROUND:
+                     return " double %TMP% = Math.round(%IN1%);\n";
+                 case CEIL:
+                     return " double %TMP% = Math.ceil(%IN1%);\n";
+                 case FLOOR:
+                     return " double %TMP% = Math.floor(%IN1%);\n";
+                 default:
+                     //types without a template (e.g., IQM, STOP, DOTPRODUCT_ROW_SUMS)
+                     throw new RuntimeException("Invalid unary type: "+this.toString());
+             }
+         }
+     }
+
+     private final UnaryType _type;
+
+     /**
+      * Creates a unary node over the given input and initializes its output
+      * dimensions via {@link #setOutputDims()}.
+      *
+      * @param in1  input node
+      * @param type unary operation type
+      */
+     public CNodeUnary( CNode in1, UnaryType type ) {
+         _inputs.add(in1);
+         _type = type;
+         setOutputDims();
+     }
+
+     public UnaryType getType() {
+         return _type;
+     }
+
+     /**
+      * Generates the source code of the input followed by the statement of
+      * this operation. Code is emitted only once per node: subsequent calls
+      * return an empty string.
+      *
+      * @param sparse if true, the sparse template variant is used where available
+      * @return generated source code, or an empty string if already generated
+      */
+     @Override
+     public String codegen(boolean sparse) {
+         if( _generated )
+             return "";
+
+         StringBuilder sb = new StringBuilder();
+
+         //generate children
+         sb.append(_inputs.get(0).codegen(sparse));
+
+         //generate unary operation; placeholders are substituted literally
+         //(String.replace) because variable names are not regular expressions
+         String var = createVarname();
+         String tmp = _type.getTemplate(sparse);
+         tmp = tmp.replace("%TMP%", var);
+
+         String varj = _inputs.get(0).getVarname();
+         if( sparse && !tmp.contains("%IN1%") ) {
+             //sparse-only template: values and indexes are separate variables
+             tmp = tmp.replace("%IN1v%", varj+"vals");
+             tmp = tmp.replace("%IN1i%", varj+"ix");
+         }
+         else
+             tmp = tmp.replace("%IN1%", varj );
+
+         //position variable: inputs named _b* (i.e. b.get(index)) share _bi
+         String pos = varj.startsWith("_b") ? "_bi" : varj+"i";
+         tmp = tmp.replace("%POS1%", pos);
+         tmp = tmp.replace("%POS2%", pos);
+
+         sb.append(tmp);
+
+         //mark as generated
+         _generated = true;
+
+         return sb.toString();
+     }
+
+     @Override
+     public String toString() {
+         switch(_type) {
+             case ROW_SUMS: return "u(R+)";
+             default:
+                 return super.toString();
+         }
+     }
+
+     /**
+      * Sets the output characteristics of this node. All listed operation
+      * types produce a scalar, for which rows and columns are set to 0.
+      *
+      * @throws RuntimeException for types that are not listed below
+      *                          (currently {@code DOTPRODUCT_ROW_SUMS})
+      */
+     @Override
+     public void setOutputDims() {
+         switch(_type)
+         {
+             case ROW_SUMS:
+             case EXP:
+             case LOOKUP:
+             case LOOKUP0:
+             case POW2:
+             case MULT2:
+             case ABS:
+             case SIN:
+             case COS:
+             case TAN:
+             case ASIN:
+             case ACOS:
+             case ATAN:
+             case SIGN:
+             case SQRT:
+             case LOG:
+             case ROUND:
+             case IQM:
+             case STOP:
+             case CEIL:
+             case FLOOR:
+                 _rows = 0;
+                 _cols = 0;
+                 _dataType= DataType.SCALAR;
+                 break;
+             default:
+                 throw new RuntimeException("Operation " + _type.toString() + " has no "
+                     + "output dimensions, dimensions needs to be specified for the CNode " );
+         }
+
+     }
+
+     /**
+      * Combines the hash of the super class with the operation type. The
+      * result is cached in {@code _hash}, where 0 means "not yet computed".
+      */
+     @Override
+     public int hashCode() {
+         if( _hash == 0 ) {
+             int h1 = super.hashCode();
+             int h2 = _type.hashCode();
+             _hash = Arrays.hashCode(new int[]{h1,h2});
+         }
+         return _hash;
+     }
+
+     /**
+      * Two unary nodes are equal if the super class considers them equal and
+      * both have the same operation type.
+      */
+     @Override
+     public boolean equals(Object o) {
+         if( !(o instanceof CNodeUnary) )
+             return false;
+
+         CNodeUnary that = (CNodeUnary) o;
+         return super.equals(that)
+             && _type == that._type;
+     }
+ }
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/template/BaseTpl.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/BaseTpl.java b/src/main/java/org/apache/sysml/hops/codegen/template/BaseTpl.java
new file mode 100644
index 0000000..4b7ecbf
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/BaseTpl.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.template;
+
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+
+import org.apache.sysml.api.DMLException;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.codegen.cplan.CNodeData;
+import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
+import org.apache.sysml.runtime.matrix.data.Pair;
+
+ /**
+  * Base class of all code generation templates. It holds the state shared by
+  * the concrete templates and defines the three steps every template has to
+  * implement: opening, finding its boundaries, and constructing its plans.
+  */
+ public abstract class BaseTpl
+ {
+     /** Available template types, one per concrete template. */
+     public enum TemplateType {
+         CellTpl,
+         OuterProductTpl,
+         RowAggTpl
+     }
+
+     //template type, fixed at construction time
+     private final TemplateType _type;
+
+     //matrix inputs of the template (filled by the concrete templates)
+     protected ArrayList<Hop> _matrixInputs = new ArrayList<Hop>();
+     //hop at which the template starts and hop at which it ends
+     protected Hop _initialHop;
+     protected Hop _endHop;
+     //data nodes that are passed as initial inputs of the generated plans
+     protected ArrayList<CNodeData> _initialCnodes = new ArrayList<CNodeData>();
+     //NOTE(review): not referenced in the visible part of this patch; the name
+     //(sic, three d's) is kept because subclasses may access it
+     protected ArrayList<Hop> _adddedMatrices = new ArrayList<Hop>();
+     protected boolean _endHopReached = false;
+
+     //constructed plans by hop ID: pair of input hops and the template node
+     protected LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> _cpplans = new LinkedHashMap<Long, Pair<Hop[],CNodeTpl>>();
+
+     /**
+      * @param type type of the concrete template
+      */
+     protected BaseTpl(TemplateType type) {
+         _type = type;
+     }
+
+     public TemplateType getType() {
+         return _type;
+     }
+
+     /**
+      * Indicates whether this template can be opened at the given hop.
+      *
+      * @param hop candidate hop
+      * @return true if the template can be opened at this hop
+      */
+     public abstract boolean openTpl(Hop hop);
+
+     /**
+      * Determines the boundaries of the template starting from the given hop.
+      *
+      * @param initialHop    hop from which the search starts
+      * @param cplanRegister register of already constructed plans
+      * @return true if valid boundaries were found, false otherwise
+      */
+     public abstract boolean findTplBoundaries(Hop initialHop, CplanRegister cplanRegister);
+
+     /**
+      * Constructs the plans of this template.
+      *
+      * @param compileLiterals flag passed on to the operand construction
+      *                        (presumably whether literals are compiled into
+      *                        the generated code; confirm against TemplateUtils)
+      * @return constructed plans by hop ID
+      * @throws DMLException if the plan construction fails
+      */
+     public abstract LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> constructTplCplan(boolean compileLiterals) throws DMLException;
+ }
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7fd5879/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java b/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java
new file mode 100644
index 0000000..0c841e8
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.template;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.Map.Entry;
+
+import org.apache.sysml.api.DMLException;
+import org.apache.sysml.hops.AggUnaryOp;
+import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.UnaryOp;
+import org.apache.sysml.hops.Hop.AggOp;
+import org.apache.sysml.hops.Hop.Direction;
+import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.codegen.cplan.CNode;
+import org.apache.sysml.hops.codegen.cplan.CNodeBinary;
+import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType;
+import org.apache.sysml.hops.codegen.cplan.CNodeCell;
+import org.apache.sysml.hops.codegen.cplan.CNodeData;
+import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
+import org.apache.sysml.hops.codegen.cplan.CNodeUnary;
+import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType;
+import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType;
+import org.apache.sysml.runtime.matrix.data.Pair;
+
+ /**
+ * Template for chains of cell-wise operations: unary operations as well as
+ * binary matrix-scalar and matrix-vector operations with matrix output
+ * (see isValidOperation). A full or row-wise sum on top of such a chain
+ * can become the root of the template (see findTplBoundaries).
+ */
+ public class CellTpl extends BaseTpl
+ {
+
+ public CellTpl() {
+ super(TemplateType.CellTpl);
+ }
+
+ /**
+ * A cell template can be opened at any hop that is a valid cell-wise operation.
+ */
+ @Override
+ public boolean openTpl(Hop hop) {
+ return isValidOperation(hop);
+ }
+
+ /**
+ * Searches for the end hop of the cell-wise chain below the given hop and,
+ * if a chain was found, registers its main matrix input.
+ *
+ * @param initialHop hop at which the template was opened
+ * @param cplanRegister register of already constructed plans
+ * @return true if a chain with an end hop different from the (possibly
+ * re-assigned) initial hop was found; false otherwise, including the case
+ * that the register already contains the initial hop
+ */
+ @Override
+ public boolean findTplBoundaries(Hop initialHop, CplanRegister cplanRegister) {
+ _initialHop = initialHop;
+ rFindCellwisePattern(initialHop, new HashMap<Long, Hop>());
+
+ //if cplanRegister has the initial hop then no need to reconstruct
+ if(cplanRegister.containsHop(TemplateType.CellTpl, _initialHop.getHopID()))
+ return false;
+
+ //re-assign initialHop to fuse the sum/rowsums (before checking for chains)
+ //note: if several parents qualify, the last one in the parent list wins
+ for (Hop h : _initialHop.getParent())
+ if( h instanceof AggUnaryOp && ((AggUnaryOp) h).getOp() == AggOp.SUM
+ && ((AggUnaryOp) h).getDirection()!= Direction.Col ) {
+ _initialHop = h;
+ }
+
+ //unary matrix && endHop found && endHop is not direct child of the initialHop (i.e., chain of operators)
+ if(_endHop != null && _endHop != _initialHop)
+ {
+
+ // if final hop is unary add its child to the input
+ if(_endHop instanceof UnaryOp)
+ _matrixInputs.add(_endHop.getInput().get(0));
+ //if one input is scalar then add the other as major input
+ else if(_endHop.getInput().get(0).getDataType() == DataType.SCALAR)
+ _matrixInputs.add(_endHop.getInput().get(1));
+ else if(_endHop.getInput().get(1).getDataType() == DataType.SCALAR)
+ _matrixInputs.add(_endHop.getInput().get(0));
+ //if one is matrix and the other is vector add the matrix
+ else if(TemplateUtils.isMatrix(_endHop.getInput().get(0)) && TemplateUtils.isVector(_endHop.getInput().get(1)) )
+ _matrixInputs.add(_endHop.getInput().get(0));
+ else if(TemplateUtils.isMatrix(_endHop.getInput().get(1)) && TemplateUtils.isVector(_endHop.getInput().get(0)) )
+ _matrixInputs.add(_endHop.getInput().get(1));
+ //both are vectors (add any of them)
+ else
+ _matrixInputs.add(_endHop.getInput().get(0));
+
+ return true;
+ }
+
+ return false;
+ }
+
+ /**
+ * Recursively descends through valid, non-scalar cell-wise operations and
+ * tracks the current end hop in _endHop.
+ *
+ * @param h current hop
+ * @param memo hop ID to the value of _endHop after the hop was processed
+ */
+ private void rFindCellwisePattern(Hop h, HashMap<Long,Hop> memo)
+ {
+ if(memo.containsKey(h.getHopID()))
+ return;
+
+ //stop recursion if stopping operator
+ if(h.getDataType() == DataType.SCALAR || !isValidOperation(h))
+ return;
+
+ //process children recursively
+ _endHop = h;
+ for( Hop in : h.getInput() )
+ {
+ //propagate the _endHop from bottom to top
+ if(memo.containsKey(in.getHopID()))
+ _endHop=memo.get(in.getHopID());
+ else
+ rFindCellwisePattern(in,memo);
+ }
+
+ memo.put(h.getHopID(), _endHop);
+ }
+
+ /**
+ * Constructs the plans for the region between the initial hop and the
+ * inputs. Reads the first entry of _matrixInputs, i.e., requires that a
+ * matrix input was registered before (by findTplBoundaries or fuseCellWise).
+ *
+ * @param compileLiterals flag passed on to TemplateUtils.fetchOperands
+ * @return constructed plans by hop ID
+ */
+ @Override
+ public LinkedHashMap<Long, Pair<Hop[],CNodeTpl>> constructTplCplan(boolean compileLiterals)
+ throws DMLException {
+ //re-assign the dimensions of inputs to match the generated code dimensions
+ //(the main matrix input is represented as a 1x1 scalar data node)
+ _initialCnodes.add(new CNodeData(_matrixInputs.get(0), 1, 1, DataType.SCALAR));
+
+ rConstructCellCplan(_initialHop,_initialHop, new HashSet<Long>(), compileLiterals);
+ return _cpplans;
+ }
+
+ /**
+ * Constructs the cell-wise plan for an explicitly given root hop and matrix
+ * input, without a prior boundary search.
+ *
+ * @param initialHop root hop of the plan
+ * @param matrixInput main matrix input
+ * @param compileLiterals flag passed on to the plan construction
+ * @return output node of the top-level plan, or null if there is none
+ */
+ public CNode fuseCellWise(Hop initialHop,Hop matrixInput, boolean compileLiterals)
+ throws DMLException {
+ //re-assign the dimensions of inputs to match the generated code dimensions
+ _initialHop = initialHop;
+ _matrixInputs.add(matrixInput);
+
+ constructTplCplan(compileLiterals);
+ Entry<Long, Pair<Hop[],CNodeTpl>> toplevel = TemplateUtils.getTopLevelCpplan(_cpplans);
+ if(toplevel != null)
+ return toplevel.getValue().getValue().getOutput();
+ else
+ return null;
+ }
+
+ /**
+ * Recursively constructs the plan nodes bottom-up (inputs before the hop
+ * itself) and wires each result into a new or existing template in _cpplans.
+ *
+ * @param root root hop of the template
+ * @param hop current hop
+ * @param memo IDs of hops for which code generation was already attempted
+ * @param compileLiterals flag passed on to TemplateUtils.fetchOperands
+ */
+ private void rConstructCellCplan(Hop root, Hop hop, HashSet<Long> memo, boolean compileLiterals)
+ throws DMLException
+ {
+ if( memo.contains(hop.getHopID()) )
+ return;
+
+
+ //process children recursively
+ for( Hop c : hop.getInput() )
+ rConstructCellCplan(root, c, memo, compileLiterals);
+
+ // first hop to enter here should be _endHop
+ if(TemplateUtils.inputsAreGenerated(hop,_matrixInputs,_cpplans))
+ // if direct children are DataGenOps, literals, or already in the cpplans then we are ready to generate code
+ {
+ CNodeCell cellTmpl = null;
+
+ //Fetch operands
+ CNode out = null;
+ ArrayList<CNode> addedCNodes = new ArrayList<CNode>();
+ ArrayList<Hop> addedHops = new ArrayList<Hop>();
+ ArrayList<CNode> cnodeData = TemplateUtils.fetchOperands(hop, _cpplans, addedCNodes, addedHops, _initialCnodes, compileLiterals);
+
+ //if operands are scalar or independent from X
+ //(never the case for the root, which always has to be generated)
+ boolean independentOperands = hop != root && (hop.getDataType() == DataType.SCALAR || TemplateUtils.isOperandsIndependent(cnodeData, addedHops, new String[] {_matrixInputs.get(0).getName()}));
+ if(!independentOperands)
+ {
+ if(hop instanceof UnaryOp)
+ {
+ CNode cdata1 = cnodeData.get(0);
+
+ //Primitive Operation has the same name as Hop Type OpOp1
+ String primitiveOpName = ((UnaryOp)hop).getOp().toString();
+ out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
+ }
+ else if(hop instanceof BinaryOp)
+ {
+ BinaryOp bop = (BinaryOp) hop;
+ CNode cdata1 = cnodeData.get(0);
+ CNode cdata2 = cnodeData.get(1);
+
+ //Primitive Operation has the same name as Hop Type OpOp2
+ String primitiveOpName = bop.getOp().toString();
+
+ //cdata1 is vector
+ if( TemplateUtils.isColVector(cdata1) )
+ cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP);
+
+ //cdata2 is vector
+ if( TemplateUtils.isColVector(cdata2) )
+ cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP);
+
+
+ //special cases: X^2 and X*2 are mapped to dedicated unary operations
+ if( bop.getOp()==OpOp2.POW && cdata2.isLiteral() && cdata2.getVarname().equals("2") )
+ out = new CNodeUnary(cdata1, UnaryType.POW2);
+ else if( bop.getOp()==OpOp2.MULT && cdata2.isLiteral() && cdata2.getVarname().equals("2") )
+ out = new CNodeUnary(cdata1, UnaryType.MULT2);
+ else //default binary
+ out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
+ }
+ //full or row-wise sum at the root: the output is the aggregated input itself
+ else if (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getOp() == AggOp.SUM
+ && (((AggUnaryOp) hop).getDirection() == Direction.RowCol
+ || ((AggUnaryOp) hop).getDirection() == Direction.Row) && root == hop)
+ {
+ out = cnodeData.get(0);
+ }
+ }
+ // wire output to the template
+ if(out != null || independentOperands)
+ {
+ if(_cpplans.isEmpty())
+ {
+ //first initialization has to have the first variable as input
+ ArrayList<CNode> initialInputs = new ArrayList<CNode>();
+
+ if(independentOperands) // pass the hop itself as an input instead of its children
+ {
+ CNode c = new CNodeData(hop);
+ initialInputs.addAll(_initialCnodes);
+ initialInputs.add(c);
+ cellTmpl = new CNodeCell(initialInputs, c);
+ cellTmpl.setDataType(hop.getDataType());
+ cellTmpl.setCellType(CellType.NO_AGG);
+ cellTmpl.setMultipleConsumers(hop.getParent().size()>1);
+
+ _cpplans.put(hop.getHopID(), new Pair<Hop[],CNodeTpl>(new Hop[] {_matrixInputs.get(0),hop} ,cellTmpl));
+ }
+ else
+ {
+ initialInputs.addAll(_initialCnodes);
+ initialInputs.addAll(cnodeData);
+ cellTmpl = new CNodeCell(initialInputs, out);
+ cellTmpl.setDataType(hop.getDataType());
+ cellTmpl.setCellType(CellType.NO_AGG);
+ cellTmpl.setMultipleConsumers(hop.getParent().size()>1);
+
+ //input hops: main matrix input first, followed by the added hops
+ //Hop[] hopArray = new Hop[hop.getInput().size()+1];
+ Hop[] hopArray = new Hop[addedHops.size()+1];
+ hopArray[0] = _matrixInputs.get(0);
+
+ //System.arraycopy( hop.getInput().toArray(), 0, hopArray, 1, hop.getInput().size());
+ System.arraycopy( addedHops.toArray(), 0, hopArray, 1, addedHops.size());
+
+ _cpplans.put(hop.getHopID(), new Pair<Hop[],CNodeTpl>(hopArray,cellTmpl));
+ }
+ }
+ else
+ {
+ if(independentOperands)
+ {
+ CNode c = new CNodeData(hop);
+ //clear Operands
+ addedCNodes.clear();
+ addedHops.clear();
+
+ //added the current hop as the input
+ addedCNodes.add(c);
+ addedHops.add(hop);
+ out = c;
+ }
+ //wire the output to existing or new template
+ TemplateUtils.setOutputToExistingTemplate(hop, out, _cpplans, addedCNodes, addedHops);
+ }
+ }
+ //note: only hops whose inputs were generated are memoized
+ memo.add(hop.getHopID());
+ }
+ }
+
+ /**
+ * Checks whether the given hop is a supported cell-wise operation, i.e., a
+ * supported operation with matrix output that is either unary, binary
+ * matrix-scalar, or binary matrix-vector (excluding operations for which
+ * TemplateUtils.isBinaryMatrixRowVector holds).
+ *
+ * @param hop hop to check
+ * @return true if the hop can be part of a cell template
+ */
+ private boolean isValidOperation(Hop hop) {
+ boolean isBinaryMatrixScalar = hop instanceof BinaryOp && hop.getDataType()==DataType.MATRIX &&
+ (hop.getInput().get(0).getDataType()==DataType.SCALAR || hop.getInput().get(1).getDataType()==DataType.SCALAR);
+ boolean isBinaryMatrixVector = hop instanceof BinaryOp && hop.dimsKnown() &&
+ ((hop.getInput().get(0).getDataType() == DataType.MATRIX
+ && TemplateUtils.isVectorOrScalar(hop.getInput().get(1)) && !TemplateUtils.isBinaryMatrixRowVector(hop))
+ ||(TemplateUtils.isVectorOrScalar( hop.getInput().get(0))
+ && hop.getInput().get(1).getDataType() == DataType.MATRIX && !TemplateUtils.isBinaryMatrixRowVector(hop)) );
+ return hop.getDataType() == DataType.MATRIX && TemplateUtils.isOperationSupported(hop)
+ && (hop instanceof UnaryOp || isBinaryMatrixScalar || isBinaryMatrixVector);
+ }
+ }