You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2017/08/03 17:06:41 UTC
systemml git commit: [SYSTEMML-1808] [SYSTEMML-1658] Visualize Hop
DAG for explaining the optimizer
Repository: systemml
Updated Branches:
refs/heads/master 2c9694dec -> ac1cf093a
[SYSTEMML-1808] [SYSTEMML-1658] Visualize Hop DAG for explaining the optimizer
- Also added a utility to print Java output in a notebook.
- Fixed a bug in dmlFromResource.
Closes #596.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ac1cf093
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ac1cf093
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ac1cf093
Branch: refs/heads/master
Commit: ac1cf093ad0b47cb6a0f0d48c4deb276b4ae1fa6
Parents: 2c9694d
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Thu Aug 3 09:06:24 2017 -0800
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Thu Aug 3 10:06:24 2017 -0700
----------------------------------------------------------------------
.../apache/sysml/api/mlcontext/MLContext.java | 25 +-
.../sysml/api/mlcontext/MLContextUtil.java | 108 +++++++
.../sysml/api/mlcontext/ScriptExecutor.java | 46 ++-
.../context/SparkExecutionContext.java | 2 +-
.../sysml/runtime/instructions/Instruction.java | 20 ++
.../java/org/apache/sysml/utils/Explain.java | 301 +++++++++++++++++++
src/main/python/systemml/mlcontext.py | 100 +++++-
.../scala/org/apache/sysml/api/ml/Utils.scala | 61 ++++
8 files changed, 646 insertions(+), 17 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/ac1cf093/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
index b35faa6..f74d593 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
@@ -46,7 +46,6 @@ import org.apache.sysml.runtime.matrix.MatrixFormatMetaData;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
import org.apache.sysml.utils.Explain.ExplainType;
import org.apache.sysml.utils.MLContextProxy;
-
/**
* The MLContext API offers programmatic access to SystemML on Spark from
* languages such as Scala, Java, and Python.
@@ -287,6 +286,8 @@ public class MLContext {
public void resetConfig() {
MLContextUtil.setDefaultConfig();
}
+
+
/**
* Set configuration property, such as
@@ -305,7 +306,8 @@ public class MLContext {
throw new MLContextException(e);
}
}
-
+
+
/**
* Execute a DML or PYDML Script.
*
@@ -357,6 +359,16 @@ public class MLContext {
throw new MLContextException("Exception when executing script", e);
}
}
+
+ /**
+ * Sets the script that is being executed
+ *
+ * @param executionScript
+ * script that is being executed
+ */
+ public void setExecutionScript(Script executionScript) {
+ this.executionScript = executionScript;
+ }
/**
* Set SystemML configuration based on a configuration file.
@@ -489,6 +501,15 @@ public class MLContext {
}
/**
+ * Whether or not the "force" GPU mode is enabled.
+ *
+ * @return true if enabled, false otherwise
+ */
+ public boolean isForceGPU() {
+ return this.forceGPU;
+ }
+
+ /**
* Used internally by MLContextProxy.
*
*/
http://git-wip-us.apache.org/repos/asf/systemml/blob/ac1cf093/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
index 2c9566c..51d38a5 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
@@ -23,6 +23,7 @@ import java.io.File;
import java.io.FileNotFoundException;
import java.net.URL;
import java.util.ArrayList;
+import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
@@ -35,6 +36,7 @@ import javax.xml.parsers.DocumentBuilderFactory;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.text.WordUtils;
+import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
@@ -52,8 +54,11 @@ import org.apache.sysml.conf.CompilerConfig;
import org.apache.sysml.conf.CompilerConfig.ConfigType;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.conf.DMLConfig;
+import org.apache.sysml.hops.HopsException;
+import org.apache.sysml.parser.LanguageException;
import org.apache.sysml.parser.ParseException;
import org.apache.sysml.parser.Statement;
+import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.ForProgramBlock;
import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.IfProgramBlock;
@@ -63,6 +68,8 @@ import org.apache.sysml.runtime.controlprogram.ProgramBlock;
import org.apache.sysml.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysml.runtime.controlprogram.caching.FrameObject;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.cp.BooleanObject;
import org.apache.sysml.runtime.instructions.cp.Data;
@@ -73,6 +80,7 @@ import org.apache.sysml.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysml.runtime.matrix.data.FrameBlock;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysml.utils.Explain;
import org.apache.sysml.utils.MLContextProxy;
import org.w3c.dom.Document;
import org.w3c.dom.Node;
@@ -83,6 +91,106 @@ import org.w3c.dom.NodeList;
*
*/
public final class MLContextUtil {
+
+ /**
+ * Get HOP DAG in dot format for a DML or PYDML Script.
+ *
+ * @param mlCtx
+ * MLContext object.
+ * @param script
+ * The DML or PYDML Script object to execute.
+ * @param lines
+ * Only display the hops that have begin and end line number
+ * equals to the given integers.
+ * @param performHOPRewrites
+ * should perform static rewrites, perform
+ * intra-/inter-procedural analysis to propagate size information
+ * into functions and apply dynamic rewrites
+ * @param withSubgraph
+ * If false, the dot graph will be created without subgraphs for
+ * statement blocks.
+ * @return hop DAG in dot format
+ * @throws LanguageException
+ * if error occurs
+ * @throws DMLRuntimeException
+ * if error occurs
+ * @throws HopsException
+ * if error occurs
+ */
+ public static String getHopDAG(MLContext mlCtx, Script script, ArrayList<Integer> lines,
+ boolean performHOPRewrites, boolean withSubgraph) throws HopsException, DMLRuntimeException,
+ LanguageException {
+ return getHopDAG(mlCtx, script, lines, null, performHOPRewrites, withSubgraph);
+ }
+
+ /**
+  * Get HOP DAG in dot format for a DML or PYDML Script.
+  * <p>
+  * The script is compiled, not executed. If {@code newConf} is given, the
+  * SystemML Spark cluster configuration and the local max memory are
+  * temporarily derived from it and restored before this method returns.
+  *
+  * @param mlCtx
+  *            MLContext object.
+  * @param script
+  *            The DML or PYDML Script object to compile.
+  * @param lines
+  *            Only display the hops whose begin and end line numbers are
+  *            equal to one of the given integers. {@code null}, an empty
+  *            list, or the single-element list {@code [-1]} displays all
+  *            hops.
+  * @param newConf
+  *            Spark Configuration to compile against, or {@code null} to
+  *            keep the current one. Its {@code spark.driver.memory} value
+  *            is used as the local max memory.
+  * @param performHOPRewrites
+  *            should perform static rewrites, perform
+  *            intra-/inter-procedural analysis to propagate size information
+  *            into functions and apply dynamic rewrites
+  * @param withSubgraph
+  *            If false, the dot graph will be created without subgraphs for
+  *            statement blocks.
+  * @return hop DAG in dot format
+  * @throws LanguageException
+  *             if error occurs
+  * @throws DMLRuntimeException
+  *             if error occurs
+  * @throws HopsException
+  *             if error occurs
+  */
+ public static String getHopDAG(MLContext mlCtx, Script script, ArrayList<Integer> lines, SparkConf newConf,
+         boolean performHOPRewrites, boolean withSubgraph) throws HopsException, DMLRuntimeException,
+         LanguageException {
+     // Remember the current settings so that they can be restored in the finally block.
+     SparkConf oldConf = mlCtx.getSparkSession().sparkContext().getConf();
+     SparkExecutionContext.SparkClusterConfig systemmlConf = SparkExecutionContext.getSparkClusterConfig();
+     long oldMaxMemory = InfrastructureAnalyzer.getLocalMaxMemory();
+     try {
+         if (newConf != null) {
+             systemmlConf.analyzeSparkConfiguation(newConf);
+             InfrastructureAnalyzer.setLocalMaxMemory(newConf.getSizeAsBytes("spark.driver.memory"));
+         }
+         ScriptExecutor scriptExecutor = new ScriptExecutor();
+         scriptExecutor.setExecutionType(mlCtx.getExecutionType());
+         scriptExecutor.setGPU(mlCtx.isGPU());
+         scriptExecutor.setForceGPU(mlCtx.isForceGPU());
+         scriptExecutor.setInit(mlCtx.isInitBeforeExecution());
+         if (mlCtx.isInitBeforeExecution()) {
+             mlCtx.setInitBeforeExecution(false);
+         }
+         scriptExecutor.setMaintainSymbolTable(mlCtx.isMaintainSymbolTable());
+
+         // Unnamed scripts are named after the current timestamp.
+         if ((script.getName() == null) || script.getName().isEmpty()) {
+             script.setName(Long.toString(new Date().getTime()));
+         }
+
+         mlCtx.setExecutionScript(script);
+         scriptExecutor.compile(script, performHOPRewrites);
+         Explain.reset();
+         // To deal with potential Py4J issues, [-1] is accepted as a sentinel for
+         // "no filter"; null is treated the same way so that Java callers do not
+         // hit a NullPointerException.
+         boolean noFilter = (lines == null) || (lines.size() == 1 && lines.get(0) == -1);
+         ArrayList<Integer> selectedLines = noFilter ? new ArrayList<Integer>() : lines;
+         return Explain.getHopDAG(scriptExecutor.dmlProgram, selectedLines, withSubgraph);
+     } catch (RuntimeException e) {
+         throw new MLContextException("Exception when compiling script", e);
+     } finally {
+         if (newConf != null) {
+             systemmlConf.analyzeSparkConfiguation(oldConf);
+             InfrastructureAnalyzer.setLocalMaxMemory(oldMaxMemory);
+         }
+     }
+ }
/**
* Basic data types supported by the MLContext API
http://git-wip-us.apache.org/repos/asf/systemml/blob/ac1cf093/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
index 467e94e..7e78891 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
@@ -263,9 +263,16 @@ public class ScriptExecutor {
DMLScript.USE_ACCELERATOR = oldGPU;
DMLScript.STATISTICS_COUNT = DMLOptions.defaultOptions.statsCount;
}
-
+
+ public void compile(Script script) {
+ compile(script, true);
+ }
+
/**
- * Execute a DML or PYDML script. This is broken down into the following
+ * Compile a DML or PYDML script. This will help analysis of DML programs
+ * that have dynamic recompilation flag set to false without actually executing it.
+ *
+ * This is broken down into the following
* primary methods:
*
* <ol>
@@ -283,16 +290,14 @@ public class ScriptExecutor {
* <li>{@link #countCompiledMRJobsAndSparkInstructions()}</li>
* <li>{@link #initializeCachingAndScratchSpace()}</li>
* <li>{@link #cleanupRuntimeProgram()}</li>
- * <li>{@link #createAndInitializeExecutionContext()}</li>
- * <li>{@link #executeRuntimeProgram()}</li>
- * <li>{@link #cleanupAfterExecution()}</li>
* </ol>
*
* @param script
- * the DML or PYDML script to execute
- * @return the results as a MLResults object
+ * the DML or PYDML script to compile
+ * @param performHOPRewrites
+ * should perform static rewrites, perform intra-/inter-procedural analysis to propagate size information into functions and apply dynamic rewrites
*/
- public MLResults execute(Script script) {
+ public void compile(Script script, boolean performHOPRewrites) {
// main steps in script execution
setup(script);
@@ -303,7 +308,8 @@ public class ScriptExecutor {
liveVariableAnalysis();
validateScript();
constructHops();
- rewriteHops();
+ if(performHOPRewrites)
+ rewriteHops();
rewritePersistentReadsAndWrites();
constructLops();
generateRuntimeProgram();
@@ -315,6 +321,28 @@ public class ScriptExecutor {
if (statistics) {
Statistics.stopCompileTimer();
}
+ }
+
+
+ /**
+ * Execute a DML or PYDML script. This is broken down into the following
+ * primary methods:
+ *
+ * <ol>
+ * <li>{@link #compile(Script)}</li>
+ * <li>{@link #createAndInitializeExecutionContext()}</li>
+ * <li>{@link #executeRuntimeProgram()}</li>
+ * <li>{@link #cleanupAfterExecution()}</li>
+ * </ol>
+ *
+ * @param script
+ * the DML or PYDML script to execute
+ * @return the results as a MLResults object
+ */
+ public MLResults execute(Script script) {
+
+ // main steps in script execution
+ compile(script);
try {
createAndInitializeExecutionContext();
http://git-wip-us.apache.org/repos/asf/systemml/blob/ac1cf093/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
index 6f2f766..d1ff7d8 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
@@ -1378,7 +1378,7 @@ public class SparkExecutionContext extends ExecutionContext
* degree of parallelism. This configuration abstracts legacy (< Spark 1.6) and current
* configurations and provides a unified view.
*/
- private static class SparkClusterConfig
+ public static class SparkClusterConfig
{
//broadcasts are stored in mem-and-disk in data space, this config
//defines the fraction of data space to be used as broadcast budget
http://git-wip-us.apache.org/repos/asf/systemml/blob/ac1cf093/src/main/java/org/apache/sysml/runtime/instructions/Instruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/Instruction.java b/src/main/java/org/apache/sysml/runtime/instructions/Instruction.java
index 6db8c7f..374f81c 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/Instruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/Instruction.java
@@ -63,6 +63,26 @@ public abstract class Instruction
protected int beginCol = -1;
protected int endCol = -1;
+ public String getFilename() {
+ return filename;
+ }
+
+ public int getBeginLine() {
+ return beginLine;
+ }
+
+ public int getEndLine() {
+ return endLine;
+ }
+
+ public int getBeginColumn() {
+ return beginCol;
+ }
+
+ public int getEndColumn() {
+ return endCol;
+ }
+
public void setType (INSTRUCTION_TYPE tp ) {
type = tp;
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/ac1cf093/src/main/java/org/apache/sysml/utils/Explain.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/utils/Explain.java b/src/main/java/org/apache/sysml/utils/Explain.java
index a2e843a..01b59f7 100644
--- a/src/main/java/org/apache/sysml/utils/Explain.java
+++ b/src/main/java/org/apache/sysml/utils/Explain.java
@@ -26,10 +26,16 @@ import java.util.HashSet;
import java.util.Map;
import java.util.Map.Entry;
+import org.apache.sysml.hops.AggBinaryOp;
+import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.Hop.DataOpTypes;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.hops.ReorgOp;
+import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.codegen.cplan.CNode;
import org.apache.sysml.hops.codegen.cplan.CNodeMultiAgg;
import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
@@ -266,6 +272,50 @@ public class Explain
return sb.toString();
}
+
+ public static String getHopDAG(DMLProgram prog, ArrayList<Integer> lines, boolean withSubgraph)
+ throws HopsException, DMLRuntimeException, LanguageException {
+ StringBuilder sb = new StringBuilder();
+ StringBuilder nodes = new StringBuilder();
+
+ // create header
+ sb.append("digraph {");
+
+ // Explain functions (if exists)
+ if (prog.hasFunctionStatementBlocks()) {
+
+ // show function call graph
+ // FunctionCallGraph fgraph = new FunctionCallGraph(prog);
+ // sb.append(explainFunctionCallGraph(fgraph, new HashSet<String>(),
+ // null, 3));
+
+ // show individual functions
+ for (String namespace : prog.getNamespaces().keySet()) {
+ for (String fname : prog.getFunctionStatementBlocks(namespace).keySet()) {
+ FunctionStatementBlock fsb = prog.getFunctionStatementBlock(namespace, fname);
+ FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
+ String fkey = DMLProgram.constructFunctionKey(namespace, fname);
+
+ if (!(fstmt instanceof ExternalFunctionStatement)) {
+ addSubGraphHeader(sb, withSubgraph);
+ for (StatementBlock current : fstmt.getBody())
+ sb.append(getHopDAG(current, nodes, lines, withSubgraph));
+ String label = "FUNCTION " + fkey + " recompile=" + fsb.isRecompileOnce() + "\n";
+ addSubGraphFooter(sb, withSubgraph, label);
+ }
+ }
+ }
+ }
+
+ // Explain main program
+ for (StatementBlock sblk : prog.getStatementBlocks())
+ sb.append(getHopDAG(sblk, nodes, lines, withSubgraph));
+
+ sb.append(nodes);
+ sb.append("rankdir = \"BT\"\n");
+ sb.append("}\n");
+ return sb.toString();
+ }
public static String explain( Program rtprog ) throws HopsException {
return explain(rtprog, null);
@@ -466,6 +516,128 @@ public class Explain
//////////////
// internal explain HOPS
+ private static int clusterID = 0;
+
+ public static void reset() {
+ clusterID = 0;
+ }
+
+ private static void addSubGraphHeader(StringBuilder builder, boolean withSubgraph) {
+ if (withSubgraph) {
+ builder.append("subgraph cluster_" + (clusterID++) + " {\n");
+ }
+ }
+
+ private static void addSubGraphFooter(StringBuilder builder, boolean withSubgraph, String label) {
+ if (withSubgraph) {
+ builder.append("label = \"" + label + "\";\n");
+ builder.append("}\n");
+ }
+ }
+
+ /**
+  * Builds the dot fragment for one statement block.
+  * <p>
+  * While, if/else, for/parfor and function blocks are (optionally) wrapped in
+  * a labelled cluster subgraph and their bodies are processed recursively.
+  * For any other statement block, the edges of its hop DAG are returned and
+  * the node declarations are appended to {@code nodes}.
+  *
+  * @param sb
+  *            statement block to explain
+  * @param nodes
+  *            shared buffer that collects the node declarations
+  * @param lines
+  *            line numbers used to filter the displayed hops (empty list
+  *            displays all hops)
+  * @param withSubgraph
+  *            if false, no cluster subgraphs are emitted
+  * @return the dot fragment (cluster wrappers and edges) for this block
+  * @throws HopsException
+  *             if error occurs
+  * @throws DMLRuntimeException
+  *             if error occurs
+  */
+ private static StringBuilder getHopDAG(StatementBlock sb, StringBuilder nodes, ArrayList<Integer> lines,
+         boolean withSubgraph) throws HopsException, DMLRuntimeException {
+     StringBuilder builder = new StringBuilder();
+
+     if (sb instanceof WhileStatementBlock) {
+         addSubGraphHeader(builder, withSubgraph);
+
+         WhileStatementBlock wsb = (WhileStatementBlock) sb;
+         String label = "WHILE (lines " + wsb.getBeginLine() + "-" + wsb.getEndLine() + ")";
+         if (!wsb.getUpdateInPlaceVars().isEmpty())
+             label += " in-place=" + wsb.getUpdateInPlaceVars().toString();
+         // TODO: predicate hops are not shown for now
+
+         WhileStatement ws = (WhileStatement) sb.getStatement(0);
+         for (StatementBlock current : ws.getBody())
+             builder.append(getHopDAG(current, nodes, lines, withSubgraph));
+
+         addSubGraphFooter(builder, withSubgraph, label);
+     } else if (sb instanceof IfStatementBlock) {
+         addSubGraphHeader(builder, withSubgraph);
+         IfStatementBlock ifsb = (IfStatementBlock) sb;
+         String label = "IF (lines " + ifsb.getBeginLine() + "-" + ifsb.getEndLine() + ")";
+         // TODO: predicate hops are not shown for now
+
+         IfStatement ifs = (IfStatement) sb.getStatement(0);
+         for (StatementBlock current : ifs.getIfBody())
+             builder.append(getHopDAG(current, nodes, lines, withSubgraph));
+         // Close the IF cluster exactly once, after the whole body. Closing it
+         // inside the loop would emit one footer per child block (and none for
+         // an empty body), which unbalances the braces of the dot output.
+         addSubGraphFooter(builder, withSubgraph, label);
+
+         if (!ifs.getElseBody().isEmpty()) {
+             addSubGraphHeader(builder, withSubgraph);
+             label = "ELSE (lines " + ifsb.getBeginLine() + "-" + ifsb.getEndLine() + ")";
+
+             for (StatementBlock current : ifs.getElseBody())
+                 builder.append(getHopDAG(current, nodes, lines, withSubgraph));
+             addSubGraphFooter(builder, withSubgraph, label);
+         }
+     } else if (sb instanceof ForStatementBlock) {
+         ForStatementBlock fsb = (ForStatementBlock) sb;
+         addSubGraphHeader(builder, withSubgraph);
+
+         String loopKind = (sb instanceof ParForStatementBlock) ? "PARFOR" : "FOR";
+         String label = loopKind + " (lines " + fsb.getBeginLine() + "-" + fsb.getEndLine() + ")";
+         if (!fsb.getUpdateInPlaceVars().isEmpty())
+             label += " in-place=" + fsb.getUpdateInPlaceVars().toString();
+         // TODO: from/to/increment hops are not shown for now
+
+         ForStatement fs = (ForStatement) sb.getStatement(0);
+         for (StatementBlock current : fs.getBody())
+             builder.append(getHopDAG(current, nodes, lines, withSubgraph));
+         addSubGraphFooter(builder, withSubgraph, label);
+
+     } else if (sb instanceof FunctionStatementBlock) {
+         FunctionStatement fstmt = (FunctionStatement) sb.getStatement(0);
+         addSubGraphHeader(builder, withSubgraph);
+         String label = "Function (lines " + fstmt.getBeginLine() + "-" + fstmt.getEndLine() + ")";
+         for (StatementBlock current : fstmt.getBody())
+             builder.append(getHopDAG(current, nodes, lines, withSubgraph));
+         addSubGraphFooter(builder, withSubgraph, label);
+     } else {
+         // Generic StatementBlock: blocks that require recompilation are shown as
+         // a grey cluster. The cluster attributes are only emitted when a cluster
+         // was actually opened, i.e. when subgraphs are enabled.
+         boolean asCluster = withSubgraph && sb.requiresRecompilation();
+         addSubGraphHeader(builder, asCluster);
+
+         ArrayList<Hop> hopsDAG = sb.get_hops();
+         if (hopsDAG != null && !hopsDAG.isEmpty()) {
+             Hop.resetVisitStatus(hopsDAG);
+             for (Hop hop : hopsDAG)
+                 builder.append(getHopDAG(hop, nodes, lines, withSubgraph));
+             Hop.resetVisitStatus(hopsDAG);
+         }
+
+         if (asCluster) {
+             builder.append("style=filled;\n");
+             builder.append("color=lightgrey;\n");
+             String label = "(lines " + sb.getBeginLine() + "-" + sb.getEndLine() + ") [recompile="
+                     + sb.requiresRecompilation() + "]";
+             addSubGraphFooter(builder, true, label);
+         }
+     }
+     return builder;
+ }
+
private static String explainStatementBlock(StatementBlock sb, int level)
throws HopsException, DMLRuntimeException
{
@@ -636,6 +808,134 @@ public class Explain
return sb.toString();
}
+
+ /**
+  * Checks whether a hop passes the line filter.
+  *
+  * @param hop
+  *            hop to check
+  * @param lines
+  *            line numbers to display; an empty list accepts every hop
+  * @return true if the list is empty, or if the hop begins and ends on one
+  *         and the same line that is contained in the list
+  */
+ private static boolean isInRange(Hop hop, ArrayList<Integer> lines) {
+     if (lines.size() == 0) {
+         return true;
+     }
+     for (int candidate : lines) {
+         boolean beginsHere = hop.getBeginLine() == candidate;
+         if (beginsHere && candidate == hop.getEndLine()) {
+             return true;
+         }
+     }
+     return false;
+ }
+
+ private static StringBuilder getHopDAG(Hop hop, StringBuilder nodes, ArrayList<Integer> lines, boolean withSubgraph)
+ throws DMLRuntimeException {
+ StringBuilder sb = new StringBuilder();
+ if (hop.isVisited() || (!SHOW_LITERAL_HOPS && hop instanceof LiteralOp))
+ return sb;
+
+ for (Hop input : hop.getInput()) {
+ if ((SHOW_LITERAL_HOPS || !(input instanceof LiteralOp)) && isInRange(hop, lines)) {
+ String edgeLabel = showMem(input.getOutputMemEstimate(), true);
+ sb.append("h" + input.getHopID() + " -> h" + hop.getHopID() + " [label=\"" + edgeLabel + "\"];\n");
+ }
+ }
+ for (Hop input : hop.getInput())
+ sb.append(getHopDAG(input, nodes, lines, withSubgraph));
+
+ if (isInRange(hop, lines)) {
+ nodes.append("h" + hop.getHopID() + "[label=\"" + getNodeLabel(hop) + "\", " + "shape=\""
+ + getNodeShape(hop) + "\", color=\"" + getNodeColor(hop) + "\", tooltip=\"" + getNodeToolTip(hop)
+ + "\"];\n");
+ }
+ hop.setVisited();
+
+ return sb;
+ }
+
+ private static String getNodeLabel(Hop hop) {
+ StringBuilder sb = new StringBuilder();
+ sb.append(hop.getOpString());
+ if (hop instanceof AggBinaryOp) {
+ AggBinaryOp aggBinOp = (AggBinaryOp) hop;
+ if (aggBinOp.getMMultMethod() != null)
+ sb.append(" " + aggBinOp.getMMultMethod().name() + " ");
+ }
+ // data flow properties
+ if (SHOW_DATA_FLOW_PROPERTIES) {
+ if (hop.requiresReblock() && hop.requiresCheckpoint())
+ sb.append(", rblk,chkpt");
+ else if (hop.requiresReblock())
+ sb.append(", rblk");
+ else if (hop.requiresCheckpoint())
+ sb.append(", chkpt");
+ }
+ if (hop.getFilename() == null) {
+ sb.append("[" + hop.getBeginLine() + ":" + hop.getBeginColumn() + "-" + hop.getEndLine() + ":"
+ + hop.getEndColumn() + "]");
+ } else {
+ sb.append("[" + hop.getFilename() + " " + hop.getBeginLine() + ":" + hop.getBeginColumn() + "-"
+ + hop.getEndLine() + ":" + hop.getEndColumn() + "]");
+ }
+
+ if (hop.getUpdateType().isInPlace())
+ sb.append("," + hop.getUpdateType().toString().toLowerCase());
+ return sb.toString();
+ }
+
+ private static String getNodeToolTip(Hop hop) {
+ StringBuilder sb = new StringBuilder();
+ if (hop.getExecType() != null) {
+ sb.append(hop.getExecType().name());
+ }
+ sb.append("[" + hop.getDim1() + " X " + hop.getDim2() + "], nnz=" + hop.getNnz());
+ sb.append(", mem= [in=");
+ sb.append(showMem(hop.getInputMemEstimate(), false));
+ sb.append(", inter=");
+ sb.append(showMem(hop.getIntermediateMemEstimate(), false));
+ sb.append(", out=");
+ sb.append(showMem(hop.getOutputMemEstimate(), false));
+ sb.append(" -> ");
+ sb.append(showMem(hop.getMemEstimate(), true));
+ sb.append("]");
+ return sb.toString();
+ }
+
+ /**
+  * Maps the execution type of a hop to a dot node shape.
+  *
+  * @param hop
+  *            hop to display
+  * @return "ellipse" for CP, "box" for SPARK, "trapezium" for GPU,
+  *         "parallelogram" for MR, and "octagon" if the execution type is
+  *         not set or is any other value
+  */
+ private static String getNodeShape(Hop hop) {
+     if (hop.getExecType() == null) {
+         return "octagon";
+     }
+     switch (hop.getExecType()) {
+     case CP:
+         return "ellipse";
+     case SPARK:
+         return "box";
+     case GPU:
+         return "trapezium";
+     case MR:
+         return "parallelogram";
+     default:
+         return "octagon";
+     }
+ }
+
+ /**
+  * Maps the kind of a hop to a dot node color.
+  *
+  * @param hop
+  *            hop to display
+  * @return "wheat2" for persistent/transient reads, "wheat4" for
+  *         persistent/transient writes, "orangered2" for AggBinaryOp,
+  *         "royalblue2" for BinaryOp, "green" for ReorgOp, "yellow" for
+  *         UnaryOp, and "black" for everything else (including other
+  *         DataOp types)
+  */
+ private static String getNodeColor(Hop hop) {
+     if (hop instanceof DataOp) {
+         DataOp dataOp = (DataOp) hop;
+         boolean isRead = dataOp.getDataOpType() == DataOpTypes.PERSISTENTREAD
+                 || dataOp.getDataOpType() == DataOpTypes.TRANSIENTREAD;
+         if (isRead) {
+             return "wheat2";
+         }
+         boolean isWrite = dataOp.getDataOpType() == DataOpTypes.PERSISTENTWRITE
+                 || dataOp.getDataOpType() == DataOpTypes.TRANSIENTWRITE;
+         return isWrite ? "wheat4" : "black";
+     }
+     if (hop instanceof AggBinaryOp) {
+         return "orangered2";
+     }
+     if (hop instanceof BinaryOp) {
+         return "royalblue2";
+     }
+     if (hop instanceof ReorgOp) {
+         return "green";
+     }
+     if (hop instanceof UnaryOp) {
+         return "yellow";
+     }
+     return "black";
+ }
//////////////
// internal explain CNODE
@@ -867,6 +1167,7 @@ public class Explain
sb.append( offsetInst );
sb.append( tmp );
+
sb.append( '\n' );
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/ac1cf093/src/main/python/systemml/mlcontext.py
----------------------------------------------------------------------
diff --git a/src/main/python/systemml/mlcontext.py b/src/main/python/systemml/mlcontext.py
index 1e79648..b5a6bf9 100644
--- a/src/main/python/systemml/mlcontext.py
+++ b/src/main/python/systemml/mlcontext.py
@@ -19,7 +19,11 @@
#
#-------------------------------------------------------------
-__all__ = ['MLResults', 'MLContext', 'Script', 'dml', 'pydml', 'dmlFromResource', 'pydmlFromResource', 'dmlFromFile', 'pydmlFromFile', 'dmlFromUrl', 'pydmlFromUrl', '_java2py', 'Matrix']
+# Methods to create Script object
+script_factory_methods = [ 'dml', 'pydml', 'dmlFromResource', 'pydmlFromResource', 'dmlFromFile', 'pydmlFromFile', 'dmlFromUrl', 'pydmlFromUrl' ]
+# Utility methods
+util_methods = [ 'jvm_stdout', '_java2py', 'getHopDAG' ]
+__all__ = ['MLResults', 'MLContext', 'Script', 'Matrix' ] + script_factory_methods + util_methods
import os
@@ -27,13 +31,16 @@ try:
import py4j.java_gateway
from py4j.java_gateway import JavaObject
from pyspark import SparkContext
+ from pyspark.conf import SparkConf
import pyspark.mllib.common
except ImportError:
raise ImportError('Unable to import `pyspark`. Hint: Make sure you are running with PySpark.')
from .converters import *
from .classloader import *
+import threading, time
+_loadedSystemML = False
def _get_spark_context():
"""
Internal method to get already initialized SparkContext.
@@ -44,10 +51,93 @@ def _get_spark_context():
SparkContext
"""
if SparkContext._active_spark_context is not None:
- return SparkContext._active_spark_context
+ sc = SparkContext._active_spark_context
+ if not _loadedSystemML:
+ createJavaObject(sc, 'dummy')
+ _loadedSystemML = True
+ return sc
else:
raise Exception('Expected spark context to be created.')
+# This is a useful utility class to get the output of the driver JVM from within a Jupyter notebook
+# Example usage:
+# with jvm_stdout():
+# ml.execute(script)
+class jvm_stdout(object):
+ """
+ This is useful utility class to get the output of the driver JVM from within a Jupyter notebook
+
+ Parameters
+ ----------
+ parallel_flush: boolean
+ Should flush the stdout in parallel
+ """
+ def __init__(self, parallel_flush=False):
+ self.util = SparkContext._active_spark_context._jvm.org.apache.sysml.api.ml.Utils()
+ self.parallel_flush = parallel_flush
+ self.t = threading.Thread(target=self.flush_stdout)
+ self.stop = False
+
+ def flush_stdout(self):
+ while not self.stop:
+ time.sleep(1) # flush stdout every 1 second
+ str = self.util.flushStdOut()
+ if str != '':
+ str = str[:-1] if str.endswith('\n') else str
+ print(str)
+
+ def __enter__(self):
+ self.util.startRedirectStdOut()
+ if self.parallel_flush:
+ self.t.start()
+
+ def __exit__(self, *args):
+ if self.parallel_flush:
+ self.stop = True
+ self.t.join()
+ print(self.util.stopRedirectStdOut())
+
+
+def getHopDAG(ml, script, lines=None, conf=None, apply_rewrites=True, with_subgraph=False):
+ """
+ Compile a DML / PyDML script.
+
+ Parameters
+ ----------
+ ml: MLContext instance
+ MLContext instance.
+
+ script: Script instance
+ Script instance defined with the appropriate input and output variables.
+
+ lines: list of integers
+ Optional: only display the hops that have begin and end line number equals to the given integers.
+
+ conf: SparkConf instance
+ Optional spark configuration
+
+ apply_rewrites: boolean
+ If True, perform static rewrites, perform intra-/inter-procedural analysis to propagate size information into functions and apply dynamic rewrites
+
+ with_subgraph: boolean
+ If False, the dot graph will be created without subgraphs for statement blocks.
+
+ Returns
+ -------
+ hopDAG: string
+ hop DAG in dot format
+ """
+ if not isinstance(script, Script):
+ raise ValueError("Expected script to be an instance of Script")
+ scriptString = script.scriptString
+ script_java = script.script_java
+ lines = [ int(x) for x in lines ] if lines is not None else [int(-1)]
+ sc = _get_spark_context()
+ if conf is not None:
+ hopDAG = sc._jvm.org.apache.sysml.api.mlcontext.MLContextUtil.getHopDAG(ml._ml, script_java, lines, conf._jconf, apply_rewrites, with_subgraph)
+ else:
+ hopDAG = sc._jvm.org.apache.sysml.api.mlcontext.MLContextUtil.getHopDAG(ml._ml, script_java, lines, apply_rewrites, with_subgraph)
+ return hopDAG
def dml(scriptString):
"""
@@ -330,9 +420,9 @@ class Script(object):
self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile(scriptString)
elif scriptFormat == "file" and self.scriptType == "pydml":
self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromFile(scriptString)
- elif scriptFormat == "file" and self.scriptType == "dml":
+ elif isResource and self.scriptType == "dml":
self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromResource(scriptString)
- elif scriptFormat == "file" and self.scriptType == "pydml":
+ elif isResource and self.scriptType == "pydml":
self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromResource(scriptString)
elif scriptFormat == "string" and self.scriptType == "dml":
self.script_java = self.sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dml(scriptString)
@@ -605,7 +695,7 @@ class MLContext(object):
def __repr__(self):
return "MLContext"
-
+
def execute(self, script):
"""
Execute a DML / PyDML script.
http://git-wip-us.apache.org/repos/asf/systemml/blob/ac1cf093/src/main/scala/org/apache/sysml/api/ml/Utils.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/Utils.scala b/src/main/scala/org/apache/sysml/api/ml/Utils.scala
index da3edf5..a804f64 100644
--- a/src/main/scala/org/apache/sysml/api/ml/Utils.scala
+++ b/src/main/scala/org/apache/sysml/api/ml/Utils.scala
@@ -18,8 +18,69 @@
*/
package org.apache.sysml.api.ml
+import org.apache.spark.api.java.JavaPairRDD
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
+
+object Utils {
+ val originalOut = System.out
+ val originalErr = System.err
+}
class Utils {
def checkIfFileExists(filePath:String):Boolean = {
return org.apache.sysml.runtime.util.MapReduceTool.existsFileOnHDFS(filePath)
}
+
+ // --------------------------------------------------------------------------------
+ // Simple utility function to print the information about our binary blocked format
+ def getBinaryBlockInfo(binaryBlocks:JavaPairRDD[MatrixIndexes, MatrixBlock]):String = {
+ val sb = new StringBuilder
+ var partitionIndex = 0
+ for(str <- binaryBlocks.rdd.mapPartitions(binaryBlockIteratorToString(_), true).collect) {
+ sb.append("-------------------------------------\n")
+ sb.append("Partition " + partitionIndex + ":\n")
+ sb.append(str)
+ partitionIndex = partitionIndex + 1
+ }
+ sb.append("-------------------------------------\n")
+ return sb.toString()
+ }
+ def binaryBlockIteratorToString(it: Iterator[(MatrixIndexes, MatrixBlock)]): Iterator[String] = {
+ val sb = new StringBuilder
+ for(entry <- it) {
+ val mi = entry._1
+ val mb = entry._2
+ sb.append(mi.toString);
+ sb.append(" sparse? = ");
+ sb.append(mb.isInSparseFormat());
+ if(mb.isUltraSparse)
+ sb.append(" (ultra-sparse)")
+ sb.append(", nonzeros = ");
+ sb.append(mb.getNonZeros);
+ sb.append(", dimensions = ");
+ sb.append(mb.getNumRows);
+ sb.append(" X ");
+ sb.append(mb.getNumColumns);
+ sb.append("\n");
+ }
+ List[String](sb.toString).iterator
+ }
+ val baos = new java.io.ByteArrayOutputStream()
+ val baes = new java.io.ByteArrayOutputStream()
+ def startRedirectStdOut():Unit = {
+ System.setOut(new java.io.PrintStream(baos));
+ System.setErr(new java.io.PrintStream(baes));
+ }
+ def flushStdOut():String = {
+ val ret = baos.toString() + baes.toString()
+ baos.reset(); baes.reset()
+ return ret
+ }
+ def stopRedirectStdOut():String = {
+ val ret = baos.toString() + baes.toString()
+ System.setOut(Utils.originalOut)
+ System.setErr(Utils.originalErr)
+ return ret
+ }
+ // --------------------------------------------------------------------------------
}
\ No newline at end of file