You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by de...@apache.org on 2016/07/29 00:14:39 UTC
[2/4] incubator-systemml git commit: [SYSTEMML-593] MLContext redesign
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/mlcontext/Script.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/Script.java b/src/main/java/org/apache/sysml/api/mlcontext/Script.java
new file mode 100644
index 0000000..65d3338
--- /dev/null
+++ b/src/main/java/org/apache/sysml/api/mlcontext/Script.java
@@ -0,0 +1,652 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.mlcontext;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Set;
+
+import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.instructions.cp.Data;
+
+import scala.Tuple2;
+import scala.Tuple3;
+import scala.collection.JavaConversions;
+
+/**
+ * A Script object encapsulates a DML or PYDML script.
+ *
+ */
+public class Script {
+
+ /**
+ * The type of script ({@code ScriptType.DML} or {@code ScriptType.PYDML}).
+ */
+ private ScriptType scriptType;
+ /**
+ * The script content.
+ */
+ private String scriptString;
+ /**
+ * The optional name of the script.
+ */
+ private String name;
+ /**
+ * All inputs (input parameters ($) and input variables).
+ */
+ private Map<String, Object> inputs = new LinkedHashMap<String, Object>();
+ /**
+ * The input parameters ($).
+ */
+ private Map<String, Object> inputParameters = new LinkedHashMap<String, Object>();
+ /**
+ * The input variables.
+ */
+ private Set<String> inputVariables = new LinkedHashSet<String>();
+ /**
+ * The input matrix metadata if present.
+ */
+ private Map<String, MatrixMetadata> inputMatrixMetadata = new LinkedHashMap<String, MatrixMetadata>();
+ /**
+ * The output variables.
+ */
+ private Set<String> outputVariables = new LinkedHashSet<String>();
+ /**
+ * The symbol table containing the data associated with variables.
+ */
+ private LocalVariableMap symbolTable = new LocalVariableMap();
+ /**
+ * The ScriptExecutor which is used to define the execution of the script.
+ */
+ private ScriptExecutor scriptExecutor;
+ /**
+ * The results of the execution of the script.
+ */
+ private MLResults results;
+
+ /**
+ * Script constructor, which by default creates a DML script.
+ */
+ public Script() {
+     // no content yet; DML is the default dialect
+     this(null, ScriptType.DML);
+ }
+
+ /**
+  * Create a script of the given type whose content is supplied later.
+  *
+  * @param scriptType
+  *            {@code ScriptType.DML} or {@code ScriptType.PYDML}
+  */
+ public Script(ScriptType scriptType) {
+     this(null, scriptType);
+ }
+
+ /**
+  * Create a script from its content. The script type defaults to DML.
+  *
+  * @param scriptString
+  *            the script content as a string
+  */
+ public Script(String scriptString) {
+     this(scriptString, ScriptType.DML);
+ }
+
+ /**
+  * Create a script from its content and its type (DML or PYDML). The other
+  * constructors delegate to this one.
+  *
+  * @param scriptString
+  *            the script content as a string
+  * @param scriptType
+  *            {@code ScriptType.DML} or {@code ScriptType.PYDML}
+  */
+ public Script(String scriptString, ScriptType scriptType) {
+     this.scriptString = scriptString;
+     this.scriptType = scriptType;
+ }
+
+ /**
+ * Obtain the script type.
+ *
+ * @return {@code ScriptType.DML} or {@code ScriptType.PYDML}
+ */
+ public ScriptType getScriptType() {
+     // plain accessor: the stored type is returned as-is
+     return this.scriptType;
+ }
+
+ /**
+ * Set the type of script (DML or PYDML).
+ *
+ * @param scriptType
+ * {@code ScriptType.DML} or {@code ScriptType.PYDML}
+ */
+ public void setScriptType(ScriptType scriptType) {
+     // the new type is stored without validation
+     this.scriptType = scriptType;
+ }
+
+ /**
+ * Obtain the script string.
+ *
+ * @return the script string
+ */
+ public String getScriptString() {
+     // may be null when no content has been set or after clearAll()
+     return this.scriptString;
+ }
+
+ /**
+ * Set the script string.
+ *
+ * @param scriptString
+ * the script string
+ * @return {@code this} Script object to allow chaining of methods
+ */
+ public Script setScriptString(String scriptString) {
+     // replace the script content, then hand back this instance for chaining
+     this.scriptString = scriptString;
+     return this;
+ }
+
+ /**
+ * Obtain the input variable names as an unmodifiable set of strings.
+ *
+ * @return the input variable names
+ */
+ public Set<String> getInputVariables() {
+     // read-only view over the registered input variable names
+     Set<String> readOnlyView = Collections.unmodifiableSet(inputVariables);
+     return readOnlyView;
+ }
+
+ /**
+ * Obtain the output variable names as an unmodifiable set of strings.
+ *
+ * @return the output variable names
+ */
+ public Set<String> getOutputVariables() {
+     // read-only view over the registered output variable names
+     Set<String> readOnlyView = Collections.unmodifiableSet(outputVariables);
+     return readOnlyView;
+ }
+
+ /**
+ * Obtain the symbol table, which is essentially a
+ * {@code HashMap<String, Data>} representing variables and their values.
+ *
+ * @return the symbol table
+ */
+ public LocalVariableMap getSymbolTable() {
+     // the live table is returned, not a copy
+     return this.symbolTable;
+ }
+
+ /**
+ * Obtain an unmodifiable map of all inputs (parameters ($) and variables).
+ *
+ * @return all inputs to the script
+ */
+ public Map<String, Object> getInputs() {
+ return Collections.unmodifiableMap(inputs);
+ }
+
+ /**
+ * Obtain an unmodifiable map of input matrix metadata.
+ *
+ * @return input matrix metadata
+ */
+ public Map<String, MatrixMetadata> getInputMatrixMetadata() {
+ return Collections.unmodifiableMap(inputMatrixMetadata);
+ }
+
+ /**
+ * Pass a map of inputs to the script.
+ *
+ * @param inputs
+ * map of inputs (parameters ($) and variables).
+ * @return {@code this} Script object to allow chaining of methods
+ */
+ public Script in(Map<String, Object> inputs) {
+ for (Entry<String, Object> input : inputs.entrySet()) {
+ in(input.getKey(), input.getValue());
+ }
+
+ return this;
+ }
+
+ /**
+ * Pass a Scala Map of inputs to the script.
+ *
+ * @param inputs
+ * Scala Map of inputs (parameters ($) and variables).
+ * @return {@code this} Script object to allow chaining of methods
+ */
+ public Script in(scala.collection.Map<String, Object> inputs) {
+     // view the Scala map as a java.util.Map and reuse the Java-map overload
+     in(JavaConversions.mapAsJavaMap(inputs));
+     return this;
+ }
+
+ /**
+ * Pass a Scala Seq of inputs to the script. The inputs are either two-value
+ * or three-value tuples, where the first value is the variable name, the
+ * second value is the variable value, and the third optional value is the
+ * metadata.
+ *
+ * @param inputs
+ * Scala Seq of inputs (parameters ($) and variables).
+ * @return {@code this} Script object to allow chaining of methods
+ */
+ public Script in(scala.collection.Seq<Object> inputs) {
+     for (Object element : JavaConversions.asJavaList(inputs)) {
+         if (element instanceof Tuple3) {
+             // (name, value, metadata)
+             @SuppressWarnings("unchecked")
+             Tuple3<String, Object, MatrixMetadata> triple = (Tuple3<String, Object, MatrixMetadata>) element;
+             in(triple._1(), triple._2(), triple._3());
+             continue;
+         }
+         if (element instanceof Tuple2) {
+             // (name, value)
+             @SuppressWarnings("unchecked")
+             Tuple2<String, Object> pair = (Tuple2<String, Object>) element;
+             in(pair._1(), pair._2());
+             continue;
+         }
+         // anything else in the sequence is rejected
+         throw new MLContextException("Only Tuples of 2 or 3 values are permitted");
+     }
+     return this;
+ }
+
+ /**
+ * Obtain an unmodifiable map of all input parameters ($).
+ *
+ * @return input parameters ($)
+ */
+ public Map<String, Object> getInputParameters() {
+     // Return a read-only view, as the Javadoc promises and as getInputs() already does;
+     // previously the internal mutable map itself was handed out.
+     return Collections.unmodifiableMap(inputParameters);
+ }
+
+ /**
+ * Register an input (parameter ($) or variable).
+ *
+ * @param name
+ * name of the input
+ * @param value
+ * value of the input
+ * @return {@code this} Script object to allow chaining of methods
+ */
+ public Script in(String name, Object value) {
+ return in(name, value, null);
+ }
+
+ /**
+ * Register an input (parameter ($) or variable) with optional matrix
+ * metadata.
+ *
+ * @param name
+ * name of the input
+ * @param value
+ * value of the input
+ * @param matrixMetadata
+ * optional matrix metadata
+ * @return {@code this} Script object to allow chaining of methods
+ */
+ public Script in(String name, Object value, MatrixMetadata matrixMetadata) {
+     MLContextUtil.checkInputValueType(name, value);
+
+     if (name.startsWith("$")) {
+         // Input parameter ($): validate before recording anything, so that a rejected
+         // value (the check is presumably expected to throw) leaves the script state untouched.
+         MLContextUtil.checkInputParameterType(name, value);
+         inputs.put(name, value);
+         inputParameters.put(name, value);
+         return this;
+     }
+
+     // Input variable: convert first for the same reason, then record it.
+     Data data = MLContextUtil.convertInputType(name, value, matrixMetadata);
+     inputs.put(name, value);
+     if (data != null) {
+         symbolTable.put(name, data);
+         inputVariables.add(name);
+         // metadata is only kept for matrix inputs that actually supplied it
+         if (data instanceof MatrixObject && matrixMetadata != null) {
+             inputMatrixMetadata.put(name, matrixMetadata);
+         }
+     }
+     return this;
+ }
+
+ /**
+ * Register an output variable.
+ *
+ * @param outputName
+ * name of the output variable
+ * @return {@code this} Script object to allow chaining of methods
+ */
+ public Script out(String outputName) {
+     // a set is used, so registering the same name twice has no further effect
+     this.outputVariables.add(outputName);
+     return this;
+ }
+
+ /**
+ * Register output variables.
+ *
+ * @param outputNames
+ * names of the output variables
+ * @return {@code this} Script object to allow chaining of methods
+ */
+ public Script out(String... outputNames) {
+     // add every supplied name, in argument order
+     Collections.addAll(outputVariables, outputNames);
+     return this;
+ }
+
+ /**
+ * Clear the inputs, outputs, and symbol table.
+ */
+ public void clearIOS() {
+     // I (inputs), O (outputs), S (symbol table)
+     this.clearInputs();
+     this.clearOutputs();
+     this.clearSymbolTable();
+ }
+
+ /**
+ * Clear the inputs and outputs, but not the symbol table.
+ */
+ public void clearIO() {
+     // the symbol table is deliberately left alone here
+     this.clearInputs();
+     this.clearOutputs();
+ }
+
+ /**
+ * Clear the script string, inputs, outputs, and symbol table.
+ */
+ public void clearAll() {
+     // drop the script content, then everything clearIOS() covers
+     this.scriptString = null;
+     this.clearIOS();
+ }
+
+ /**
+ * Clear the inputs.
+ */
+ public void clearInputs() {
+     // empty every input-related collection; the collections themselves are kept
+     this.inputs.clear();
+     this.inputParameters.clear();
+     this.inputVariables.clear();
+     this.inputMatrixMetadata.clear();
+ }
+
+ /**
+ * Clear the outputs.
+ */
+ public void clearOutputs() {
+     // forget every registered output variable name
+     this.outputVariables.clear();
+ }
+
+ /**
+ * Clear the symbol table.
+ */
+ public void clearSymbolTable() {
+     // delegates to the symbol table's own removeAll()
+     this.symbolTable.removeAll();
+ }
+
+ /**
+ * Obtain the results of the script execution.
+ *
+ * @return the results of the script execution.
+ */
+ public MLResults results() {
+     // short-named twin of getResults()
+     return this.results;
+ }
+
+ /**
+ * Obtain the results of the script execution.
+ *
+ * @return the results of the script execution.
+ */
+ public MLResults getResults() {
+     // whatever was last stored through setResults(...)
+     return this.results;
+ }
+
+ /**
+ * Set the results of the script execution.
+ *
+ * @param results
+ * the results of the script execution.
+ */
+ public void setResults(MLResults results) {
+     // overwrite any previously stored results
+     this.results = results;
+ }
+
+ /**
+ * Obtain the script executor used by this Script.
+ *
+ * @return the ScriptExecutor used by this Script.
+ */
+ public ScriptExecutor getScriptExecutor() {
+     // null until an executor has been attached
+     return this.scriptExecutor;
+ }
+
+ /**
+ * Set the ScriptExecutor used by this Script.
+ *
+ * @param scriptExecutor
+ * the script executor
+ */
+ public void setScriptExecutor(ScriptExecutor scriptExecutor) {
+     // remember which executor is associated with this script
+     this.scriptExecutor = scriptExecutor;
+ }
+
+ /**
+ * Is the script type DML?
+ *
+ * @return {@code true} if the script type is DML, {@code false} otherwise
+ */
+ public boolean isDML() {
+     // the decision is delegated to the script type itself
+     return this.scriptType.isDML();
+ }
+
+ /**
+ * Is the script type PYDML?
+ *
+ * @return {@code true} if the script type is PYDML, {@code false} otherwise
+ */
+ public boolean isPYDML() {
+     // the decision is delegated to the script type itself
+     return this.scriptType.isPYDML();
+ }
+
+ /**
+ * Generate the script execution string, which adds read/load/write/save
+ * statements to the beginning and end of the script to execute.
+ *
+ * @return the script execution string
+ */
+ public String getScriptExecutionString() {
+     StringBuilder sb = new StringBuilder();
+
+     // Prologue: one statement per registered input variable. String inputs are inlined as
+     // quoted literals; everything else becomes a read/load call with an empty path.
+     // The input map is fetched once instead of re-wrapping it on every iteration.
+     Map<String, Object> allInputs = getInputs();
+     for (String in : getInputVariables()) {
+         Object inValue = allInputs.get(in);
+         sb.append(in);
+         if (isDML()) {
+             if (inValue instanceof String) {
+                 sb.append(" = ").append(MLContextUtil.quotedString((String) inValue)).append(";\n");
+             } else if (MLContextUtil.isBasicType(inValue)) {
+                 sb.append(" = read('', data_type='scalar');\n");
+             } else {
+                 sb.append(" = read('');\n");
+             }
+         } else if (isPYDML()) {
+             if (inValue instanceof String) {
+                 sb.append(" = ").append(MLContextUtil.quotedString((String) inValue)).append("\n");
+             } else if (MLContextUtil.isBasicType(inValue)) {
+                 sb.append(" = load('', data_type='scalar')\n");
+             } else {
+                 sb.append(" = load('')\n");
+             }
+         }
+     }
+
+     // Body: the user's script, newline-terminated. A script whose content was never set
+     // (or was cleared) contributes nothing; previously this appended the text "null" and
+     // then failed with a NullPointerException, e.g. when info() was called on an empty script.
+     String body = getScriptString();
+     if (body != null) {
+         sb.append(body);
+         if (!body.endsWith("\n")) {
+             sb.append("\n");
+         }
+     }
+
+     // Epilogue: one write/save call with an empty path per registered output variable.
+     for (String out : getOutputVariables()) {
+         if (isDML()) {
+             sb.append("write(").append(out).append(", '');\n");
+         } else if (isPYDML()) {
+             sb.append("save(").append(out).append(", '')\n");
+         }
+     }
+
+     return sb.toString();
+ }
+
+ @Override
+ public String toString() {
+     // inputs section, a newline, then the outputs section
+     String inputsSection = MLContextUtil.displayInputs("Inputs", inputs);
+     String outputsSection = MLContextUtil.displayOutputs("Outputs", outputVariables, symbolTable);
+     return new StringBuilder().append(inputsSection).append("\n").append(outputsSection).toString();
+ }
+
+ /**
+ * Display information about the script as a String. This consists of the
+ * script type, inputs, outputs, input parameters, input variables, output
+ * variables, the symbol table, the script string, and the script execution
+ * string.
+ *
+ * @return information about this script as a String
+ */
+ public String info() {
+     StringBuilder report = new StringBuilder();
+
+     // header
+     report.append("Script Type: ").append(scriptType).append("\n\n");
+
+     // one section per piece of script state, each followed by a newline
+     report.append(MLContextUtil.displayInputs("Inputs", inputs)).append("\n");
+     report.append(MLContextUtil.displayOutputs("Outputs", outputVariables, symbolTable)).append("\n");
+     report.append(MLContextUtil.displayMap("Input Parameters", inputParameters)).append("\n");
+     report.append(MLContextUtil.displaySet("Input Variables", inputVariables)).append("\n");
+     report.append(MLContextUtil.displaySet("Output Variables", outputVariables)).append("\n");
+     report.append(MLContextUtil.displaySymbolTable("Symbol Table", symbolTable));
+
+     // raw script followed by the generated execution string
+     report.append("\nScript String:\n").append(scriptString);
+     report.append("\nScript Execution String:\n").append(getScriptExecutionString()).append("\n");
+
+     return report.toString();
+ }
+
+ /**
+ * Display the script inputs.
+ *
+ * @return the script inputs
+ */
+ public String displayInputs() {
+     // formatting is delegated to MLContextUtil
+     final String title = "Inputs";
+     return MLContextUtil.displayInputs(title, inputs);
+ }
+
+ /**
+ * Display the script outputs.
+ *
+ * @return the script outputs as a String
+ */
+ public String displayOutputs() {
+     // formatting is delegated to MLContextUtil, which also receives the symbol table
+     final String title = "Outputs";
+     return MLContextUtil.displayOutputs(title, outputVariables, symbolTable);
+ }
+
+ /**
+ * Display the script input parameters.
+ *
+ * @return the script input parameters as a String
+ */
+ public String displayInputParameters() {
+     // formatting is delegated to MLContextUtil
+     final String title = "Input Parameters";
+     return MLContextUtil.displayMap(title, inputParameters);
+ }
+
+ /**
+ * Display the script input variables.
+ *
+ * @return the script input variables as a String
+ */
+ public String displayInputVariables() {
+     // formatting is delegated to MLContextUtil
+     final String title = "Input Variables";
+     return MLContextUtil.displaySet(title, inputVariables);
+ }
+
+ /**
+ * Display the script output variables.
+ *
+ * @return the script output variables as a String
+ */
+ public String displayOutputVariables() {
+     // formatting is delegated to MLContextUtil
+     final String title = "Output Variables";
+     return MLContextUtil.displaySet(title, outputVariables);
+ }
+
+ /**
+ * Display the script symbol table.
+ *
+ * @return the script symbol table as a String
+ */
+ public String displaySymbolTable() {
+     // formatting is delegated to MLContextUtil
+     final String title = "Symbol Table";
+     return MLContextUtil.displaySymbolTable(title, symbolTable);
+ }
+
+ /**
+ * Obtain the script name.
+ *
+ * @return the script name
+ */
+ public String getName() {
+     // the name is optional, so this may be null
+     return this.name;
+ }
+
+ /**
+ * Set the script name.
+ *
+ * @param name
+ * the script name
+ * @return {@code this} Script object to allow chaining of methods
+ */
+ public Script setName(String name) {
+     // store the name, then hand back this instance for chaining
+     this.name = name;
+     return this;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
new file mode 100644
index 0000000..4702af2
--- /dev/null
+++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
@@ -0,0 +1,624 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.mlcontext;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Set;
+
+import org.apache.commons.lang3.StringUtils;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.jmlc.JMLCUtils;
+import org.apache.sysml.api.monitoring.SparkMonitoringUtil;
+import org.apache.sysml.conf.ConfigurationManager;
+import org.apache.sysml.conf.DMLConfig;
+import org.apache.sysml.hops.HopsException;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.hops.OptimizerUtils.OptimizationLevel;
+import org.apache.sysml.hops.globalopt.GlobalOptimizerWrapper;
+import org.apache.sysml.hops.rewrite.ProgramRewriter;
+import org.apache.sysml.hops.rewrite.RewriteRemovePersistentReadWrite;
+import org.apache.sysml.lops.LopsException;
+import org.apache.sysml.parser.AParserWrapper;
+import org.apache.sysml.parser.DMLProgram;
+import org.apache.sysml.parser.DMLTranslator;
+import org.apache.sysml.parser.LanguageException;
+import org.apache.sysml.parser.ParseException;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
+import org.apache.sysml.runtime.controlprogram.Program;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
+import org.apache.sysml.utils.Explain;
+import org.apache.sysml.utils.Explain.ExplainCounts;
+import org.apache.sysml.utils.Statistics;
+
+/**
+ * ScriptExecutor executes a DML or PYDML Script object using SystemML. This is
+ * accomplished by calling the {@link #execute} method.
+ * <p>
+ * Script execution via the MLContext API typically consists of the following
+ * steps:
+ * </p>
+ * <ol>
+ * <li>Language Steps
+ * <ol>
+ * <li>Parse script into program</li>
+ * <li>Live variable analysis</li>
+ * <li>Validate program</li>
+ * </ol>
+ * </li>
+ * <li>HOP (High-Level Operator) Steps
+ * <ol>
+ * <li>Construct HOP DAGs</li>
+ * <li>Static rewrites</li>
+ * <li>Intra-/Inter-procedural analysis</li>
+ * <li>Dynamic rewrites</li>
+ * <li>Compute memory estimates</li>
+ * <li>Rewrite persistent reads and writes (MLContext-specific)</li>
+ * </ol>
+ * </li>
+ * <li>LOP (Low-Level Operator) Steps
+ * <ol>
+ * <li>Construct LOP DAGs</li>
+ * <li>Generate runtime program</li>
+ * <li>Execute runtime program</li>
+ * <li>Dynamic recompilation</li>
+ * </ol>
+ * </li>
+ * </ol>
+ * <p>
+ * Modifications to these steps can be accomplished by subclassing
+ * ScriptExecutor. For example, the following code will turn off the global data
+ * flow optimization check by subclassing ScriptExecutor and overriding the
+ * globalDataFlowOptimization method.
+ * </p>
+ *
+ * <code>ScriptExecutor scriptExecutor = new ScriptExecutor() {
+ * <br> // turn off global data flow optimization check
+ * <br> @Override
+ * <br> protected void globalDataFlowOptimization() {
+ * <br> return;
+ * <br> }
+ * <br>};
+ * <br>ml.execute(script, scriptExecutor);</code>
+ * <p>
+ *
+ * For more information, please see the {@link #execute} method.
+ */
+public class ScriptExecutor {
+
+ protected DMLConfig config;
+ protected SparkMonitoringUtil sparkMonitoringUtil;
+ protected DMLProgram dmlProgram;
+ protected DMLTranslator dmlTranslator;
+ protected Program runtimeProgram;
+ protected ExecutionContext executionContext;
+ protected Script script;
+ protected boolean explain = false;
+ protected boolean statistics = false;
+
+ /**
+ * ScriptExecutor constructor.
+ */
+ public ScriptExecutor() {
+     // pick up whatever configuration the ConfigurationManager currently provides
+     this.config = ConfigurationManager.getDMLConfig();
+ }
+
+ /**
+ * ScriptExecutor constructor, where the configuration properties are passed
+ * in.
+ *
+ * @param config
+ * the configuration properties to use by the ScriptExecutor
+ */
+ public ScriptExecutor(DMLConfig config) {
+     // keep the configuration locally ...
+     this.config = config;
+     // ... and also install it as the global configuration
+     ConfigurationManager.setGlobalConfig(config);
+ }
+
+ /**
+ * ScriptExecutor constructor, where a SparkMonitoringUtil object is passed
+ * in.
+ *
+ * @param sparkMonitoringUtil
+ * SparkMonitoringUtil object to monitor Spark
+ */
+ public ScriptExecutor(SparkMonitoringUtil sparkMonitoringUtil) {
+     // the no-argument constructor supplies the configuration
+     this();
+     this.sparkMonitoringUtil = sparkMonitoringUtil;
+ }
+
+ /**
+ * ScriptExecutor constructor, where the configuration properties and a
+ * SparkMonitoringUtil object are passed in.
+ *
+ * @param config
+ * the configuration properties to use by the ScriptExecutor
+ * @param sparkMonitoringUtil
+ * SparkMonitoringUtil object to monitor Spark
+ */
+ public ScriptExecutor(DMLConfig config, SparkMonitoringUtil sparkMonitoringUtil) {
+     // Delegate so that the configuration is also installed as the global configuration,
+     // exactly as ScriptExecutor(DMLConfig) and setConfig(DMLConfig) do; previously this
+     // constructor only set the field, leaving the global configuration out of sync.
+     this(config);
+     this.sparkMonitoringUtil = sparkMonitoringUtil;
+ }
+
+ /**
+ * Construct DAGs of high-level operators (HOPs) for each block of
+ * statements.
+ */
+ protected void constructHops() {
+ try {
+ dmlTranslator.constructHops(dmlProgram);
+ } catch (LanguageException e) {
+ throw new MLContextException("Exception occurred while constructing HOPS (high-level operators)", e);
+ } catch (ParseException e) {
+ throw new MLContextException("Exception occurred while constructing HOPS (high-level operators)", e);
+ }
+ }
+
+ /**
+ * Apply static rewrites, perform intra-/inter-procedural analysis to
+ * propagate size information into functions, apply dynamic rewrites, and
+ * compute memory estimates for all HOPs.
+ */
+ protected void rewriteHops() {
+ try {
+ dmlTranslator.rewriteHopsDAG(dmlProgram);
+ } catch (LanguageException e) {
+ throw new MLContextException("Exception occurred while rewriting HOPS (high-level operators)", e);
+ } catch (HopsException e) {
+ throw new MLContextException("Exception occurred while rewriting HOPS (high-level operators)", e);
+ } catch (ParseException e) {
+ throw new MLContextException("Exception occurred while rewriting HOPS (high-level operators)", e);
+ }
+ }
+
+ /**
+ * Output a description of the program to standard output.
+ */
+ protected void showExplanation() {
+ if (explain) {
+ try {
+ System.out.println(Explain.explain(dmlProgram));
+ } catch (HopsException e) {
+ throw new MLContextException("Exception occurred while explaining dml program", e);
+ } catch (DMLRuntimeException e) {
+ throw new MLContextException("Exception occurred while explaining dml program", e);
+ } catch (LanguageException e) {
+ throw new MLContextException("Exception occurred while explaining dml program", e);
+ }
+ }
+ }
+
+ /**
+ * Construct DAGs of low-level operators (LOPs) based on the DAGs of
+ * high-level operators (HOPs).
+ */
+ protected void constructLops() {
+ try {
+ dmlTranslator.constructLops(dmlProgram);
+ } catch (ParseException e) {
+ throw new MLContextException("Exception occurred while constructing LOPS (low-level operators)", e);
+ } catch (LanguageException e) {
+ throw new MLContextException("Exception occurred while constructing LOPS (low-level operators)", e);
+ } catch (HopsException e) {
+ throw new MLContextException("Exception occurred while constructing LOPS (low-level operators)", e);
+ } catch (LopsException e) {
+ throw new MLContextException("Exception occurred while constructing LOPS (low-level operators)", e);
+ }
+ }
+
+ /**
+ * Create runtime program. For each namespace, translate function statement
+ * blocks into function program blocks and add these to the runtime program.
+ * For each top-level block, add the program block to the runtime program.
+ */
+ protected void generateRuntimeProgram() {
+ try {
+ runtimeProgram = dmlProgram.getRuntimeProgram(config);
+ } catch (LanguageException e) {
+ throw new MLContextException("Exception occurred while generating runtime program", e);
+ } catch (DMLRuntimeException e) {
+ throw new MLContextException("Exception occurred while generating runtime program", e);
+ } catch (LopsException e) {
+ throw new MLContextException("Exception occurred while generating runtime program", e);
+ } catch (IOException e) {
+ throw new MLContextException("Exception occurred while generating runtime program", e);
+ }
+ }
+
+ /**
+ * Count the number of compiled MR Jobs/Spark Instructions in the runtime
+ * program and set this value in the statistics.
+ */
+ protected void countCompiledMRJobsAndSparkInstructions() {
+     // count first, then reset the compiled-job statistic to the counted number of jobs
+     ExplainCounts distributedOperations = Explain.countDistributedOperations(runtimeProgram);
+     Statistics.resetNoOfCompiledJobs(distributedOperations.numJobs);
+ }
+
+ /**
+ * Create an execution context and set its variables to be the symbol table
+ * of the script.
+ */
+ protected void createAndInitializeExecutionContext() {
+     executionContext = ExecutionContextFactory.createContext(runtimeProgram);
+
+     // hand the script's variables to the new context, when a table is present
+     LocalVariableMap scriptVariables = script.getSymbolTable();
+     if (scriptVariables == null) {
+         return;
+     }
+     executionContext.setVariables(scriptVariables);
+ }
+
+ /**
+ * Execute a DML or PYDML script. This is broken down into the following
+ * primary methods:
+ *
+ * <ol>
+ * <li>{@link #parseScript()}</li>
+ * <li>{@link #liveVariableAnalysis()}</li>
+ * <li>{@link #validateScript()}</li>
+ * <li>{@link #constructHops()}</li>
+ * <li>{@link #rewriteHops()}</li>
+ * <li>{@link #showExplanation()}</li>
+ * <li>{@link #rewritePersistentReadsAndWrites()}</li>
+ * <li>{@link #constructLops()}</li>
+ * <li>{@link #generateRuntimeProgram()}</li>
+ * <li>{@link #globalDataFlowOptimization()}</li>
+ * <li>{@link #countCompiledMRJobsAndSparkInstructions()}</li>
+ * <li>{@link #initializeCachingAndScratchSpace()}</li>
+ * <li>{@link #cleanupRuntimeProgram()}</li>
+ * <li>{@link #createAndInitializeExecutionContext()}</li>
+ * <li>{@link #executeRuntimeProgram()}</li>
+ * <li>{@link #cleanupAfterExecution()}</li>
+ * </ol>
+ *
+ * @param script
+ * the DML or PYDML script to execute
+ * @return the results of the script execution
+ */
+ public MLResults execute(Script script) {
+ // remember the script for the step methods below, which read this.script
+ this.script = script;
+ checkScriptHasTypeAndString();
+ // let the script know which executor ran it
+ script.setScriptExecutor(this);
+ setScriptStringInSparkMonitor();
+
+ // main steps in script execution
+ // NOTE: the order matters; later steps consume fields (dmlProgram, runtimeProgram,
+ // executionContext) that earlier steps assign
+ parseScript();
+ liveVariableAnalysis();
+ validateScript();
+ constructHops();
+ rewriteHops();
+ showExplanation();
+ rewritePersistentReadsAndWrites();
+ constructLops();
+ generateRuntimeProgram();
+ globalDataFlowOptimization();
+ countCompiledMRJobsAndSparkInstructions();
+ initializeCachingAndScratchSpace();
+ cleanupRuntimeProgram();
+ createAndInitializeExecutionContext();
+ executeRuntimeProgram();
+ setExplainRuntimeProgramInSparkMonitor();
+ cleanupAfterExecution();
+
+ // add symbol table to MLResults
+ MLResults mlResults = new MLResults(script);
+ // make the results reachable from the script as well as from the return value
+ script.setResults(mlResults);
+
+ // statistics are printed to standard output only when the statistics flag is set
+ if (statistics) {
+ System.out.println(Statistics.display());
+ }
+
+ return mlResults;
+ }
+
+ /**
+ * Perform any necessary cleanup operations after program execution.
+ */
+ protected void cleanupAfterExecution() {
+     // currently the only cleanup step: put missing inputs back into the symbol table
+     this.restoreInputsInSymbolTable();
+ }
+
+ /**
+ * Restore the input variables in the symbol table after script execution.
+ */
+ protected void restoreInputsInSymbolTable() {
+     Map<String, Object> registeredInputs = script.getInputs();
+     Map<String, MatrixMetadata> metadataByName = script.getInputMatrixMetadata();
+     LocalVariableMap variables = script.getSymbolTable();
+
+     // NOTE(review): script.in(...) is invoked while iterating the script's input-variable
+     // view; confirm that re-registering an existing name never structurally modifies that set.
+     for (String name : script.getInputVariables()) {
+         if (variables.get(name) != null) {
+             // still present after execution, nothing to restore
+             continue;
+         }
+         // re-register the original value, with its optional metadata (null when absent)
+         MatrixMetadata metadata = metadataByName.get(name);
+         script.in(name, registeredInputs.get(name), metadata);
+     }
+ }
+
+ /**
+ * Remove rmvar instructions so as to maintain registered outputs after the
+ * program terminates.
+ */
+ protected void cleanupRuntimeProgram() {
+     // the registered output names are passed along as an array (empty when there is no set)
+     String[] outputNames;
+     if (script.getOutputVariables() == null) {
+         outputNames = new String[0];
+     } else {
+         outputNames = script.getOutputVariables().toArray(new String[0]);
+     }
+     JMLCUtils.cleanupRuntimeProgram(runtimeProgram, outputNames);
+ }
+
+ /**
+ * Execute the runtime program. This involves execution of the program
+ * blocks that make up the runtime program and may involve dynamic
+ * recompilation.
+ */
+ protected void executeRuntimeProgram() {
+ try {
+ runtimeProgram.execute(executionContext);
+ } catch (DMLRuntimeException e) {
+ throw new MLContextException("Exception occurred while executing runtime program", e);
+ }
+ }
+
+ /**
+ * Obtain the SparkMonitoringUtil object.
+ *
+ * @return the SparkMonitoringUtil object, if available
+ */
+ public SparkMonitoringUtil getSparkMonitoringUtil() {
+     // null when the executor was created without one
+     return this.sparkMonitoringUtil;
+ }
+
+ /**
+ * Check security, create scratch space, cleanup working directories,
+ * initialize caching, and reset statistics.
+ */
+ protected void initializeCachingAndScratchSpace() {
+ try {
+ DMLScript.initHadoopExecution(config);
+ } catch (ParseException e) {
+ throw new MLContextException("Exception occurred initializing caching and scratch space", e);
+ } catch (DMLRuntimeException e) {
+ throw new MLContextException("Exception occurred initializing caching and scratch space", e);
+ } catch (IOException e) {
+ throw new MLContextException("Exception occurred initializing caching and scratch space", e);
+ }
+ }
+
+ /**
+ * Optimize the program.
+ */
+ protected void globalDataFlowOptimization() {
+ if (OptimizerUtils.isOptLevel(OptimizationLevel.O4_GLOBAL_TIME_MEMORY)) {
+ try {
+ runtimeProgram = GlobalOptimizerWrapper.optimizeProgram(dmlProgram, runtimeProgram);
+ } catch (DMLRuntimeException e) {
+ throw new MLContextException("Exception occurred during global data flow optimization", e);
+ } catch (HopsException e) {
+ throw new MLContextException("Exception occurred during global data flow optimization", e);
+ } catch (LopsException e) {
+ throw new MLContextException("Exception occurred during global data flow optimization", e);
+ }
+ }
+ }
+
+ /**
+  * Parse the script's execution string into a SystemML program and store
+  * the result in {@code dmlProgram}. A parser matching the script type (DML
+  * or PYDML) is created, the script's input parameters are converted to the
+  * string map form the parser expects, and the parser is then run. A
+  * {@code ParseException} is rethrown as an {@code MLContextException}.
+  */
+ protected void parseScript() {
+     try {
+         AParserWrapper parserWrapper = AParserWrapper.createParser(script.getScriptType().isPYDML());
+         Map<String, String> parserArguments = MLContextUtil.convertInputParametersForParser(
+                 script.getInputParameters(), script.getScriptType());
+         dmlProgram = parserWrapper.parse(null, script.getScriptExecutionString(), parserArguments);
+     } catch (ParseException ex) {
+         throw new MLContextException("Exception occurred while parsing script", ex);
+     }
+ }
+
+ /**
+ * Replace persistent reads and writes with transient reads and writes in
+ * the symbol table.
+ */
+ protected void rewritePersistentReadsAndWrites() {
+ LocalVariableMap symbolTable = script.getSymbolTable();
+ if (symbolTable != null) {
+ String[] inputs = (script.getInputVariables() == null) ? new String[0] : script.getInputVariables()
+ .toArray(new String[0]);
+ String[] outputs = (script.getOutputVariables() == null) ? new String[0] : script.getOutputVariables()
+ .toArray(new String[0]);
+ RewriteRemovePersistentReadWrite rewrite = new RewriteRemovePersistentReadWrite(inputs, outputs);
+ ProgramRewriter programRewriter = new ProgramRewriter(rewrite);
+ try {
+ programRewriter.rewriteProgramHopDAGs(dmlProgram);
+ } catch (LanguageException e) {
+ throw new MLContextException("Exception occurred while rewriting persistent reads and writes", e);
+ } catch (HopsException e) {
+ throw new MLContextException("Exception occurred while rewriting persistent reads and writes", e);
+ }
+ }
+
+ }
+
+ /**
+  * Store the given SystemML configuration on this object and also install
+  * it as the global configuration through {@code ConfigurationManager}.
+  *
+  * @param config
+  *            the configuration properties to use
+  */
+ public void setConfig(DMLConfig config) {
+     this.config = config;
+     // Same object as the argument; also publish it globally.
+     ConfigurationManager.setGlobalConfig(this.config);
+ }
+
+ /**
+  * Generate the explain output for the runtime program and hand it to the
+  * SparkMonitoringUtil. Nothing is done when no SparkMonitoringUtil is set.
+  * A {@code HopsException} raised while explaining is rethrown as an
+  * {@code MLContextException}.
+  */
+ protected void setExplainRuntimeProgramInSparkMonitor() {
+     if (sparkMonitoringUtil == null) {
+         return;
+     }
+     try {
+         sparkMonitoringUtil.setExplainOutput(Explain.explain(runtimeProgram));
+     } catch (HopsException ex) {
+         throw new MLContextException("Exception occurred while explaining runtime program", ex);
+     }
+ }
+
+ /**
+  * Pass the script's string content to the SparkMonitoringUtil. Nothing is
+  * done when no SparkMonitoringUtil is set.
+  */
+ protected void setScriptStringInSparkMonitor() {
+     if (sparkMonitoringUtil == null) {
+         return;
+     }
+     sparkMonitoringUtil.setDMLString(script.getScriptString());
+ }
+
+ /**
+  * Replace the SparkMonitoringUtil held by this object.
+  *
+  * @param sparkMonitoringUtil
+  *            the SparkMonitoringUtil to store; it is stored as given, so
+  *            {@code null} clears the current value
+  */
+ public void setSparkMonitoringUtil(SparkMonitoringUtil sparkMonitoringUtil) {
+     this.sparkMonitoringUtil = sparkMonitoringUtil;
+ }
+
+ /**
+ * Liveness analysis is performed on the program, obtaining sets of live-in
+ * and live-out variables by forward and backward passes over the program.
+ */
+ protected void liveVariableAnalysis() {
+ try {
+ dmlTranslator = new DMLTranslator(dmlProgram);
+ dmlTranslator.liveVariableAnalysis(dmlProgram);
+ } catch (DMLRuntimeException e) {
+ throw new MLContextException("Exception occurred during live variable analysis", e);
+ } catch (LanguageException e) {
+ throw new MLContextException("Exception occurred during live variable analysis", e);
+ }
+ }
+
+ /**
+ * Semantically validate the program's expressions, statements, and
+ * statement blocks in a single recursive pass over the program. Constant
+ * and size propagation occurs during this step.
+ */
+ protected void validateScript() {
+ try {
+ dmlTranslator.validateParseTree(dmlProgram);
+ } catch (LanguageException e) {
+ throw new MLContextException("Exception occurred while validating script", e);
+ } catch (ParseException e) {
+ throw new MLContextException("Exception occurred while validating script", e);
+ } catch (IOException e) {
+ throw new MLContextException("Exception occurred while validating script", e);
+ }
+ }
+
+ /**
+  * Verify that a Script is present, that it declares a type (DML or PYDML),
+  * and that its script string is neither {@code null} nor blank. The checks
+  * run in that order and the first one that fails raises an
+  * {@code MLContextException} describing the problem.
+  */
+ protected void checkScriptHasTypeAndString() {
+     if (script == null) {
+         throw new MLContextException("Script is null");
+     }
+     if (script.getScriptType() == null) {
+         throw new MLContextException("ScriptType (DML or PYDML) needs to be specified");
+     }
+     if (script.getScriptString() == null) {
+         throw new MLContextException("Script string is null");
+     }
+     if (StringUtils.isBlank(script.getScriptString())) {
+         throw new MLContextException("Script string is blank");
+     }
+ }
+
+ /**
+  * Return the DML program held by this object.
+  *
+  * @return the current {@code DMLProgram}
+  */
+ public DMLProgram getDmlProgram() {
+     return this.dmlProgram;
+ }
+
+ /**
+  * Return the translator held by this object.
+  *
+  * @return the current {@code DMLTranslator}
+  */
+ public DMLTranslator getDmlTranslator() {
+     return this.dmlTranslator;
+ }
+
+ /**
+  * Return the runtime program held by this object.
+  *
+  * @return the current runtime {@code Program}
+  */
+ public Program getRuntimeProgram() {
+     return this.runtimeProgram;
+ }
+
+ /**
+  * Return the execution context held by this object.
+  *
+  * @return the current {@code ExecutionContext}
+  */
+ public ExecutionContext getExecutionContext() {
+     return this.executionContext;
+ }
+
+ /**
+  * Return the Script that this ScriptExecutor is working with.
+  *
+  * @return the Script associated with this ScriptExecutor
+  */
+ public Script getScript() {
+     return this.script;
+ }
+
+ /**
+  * Record whether an explanation of the DML/PYDML program should be written
+  * to standard output.
+  *
+  * @param explain
+  *            {@code true} to request the explanation, {@code false} to
+  *            suppress it
+  */
+ public void setExplain(boolean explain) {
+     this.explain = explain;
+ }
+
+ /**
+  * Record whether statistics about the DML/PYDML program should be written
+  * to standard output.
+  *
+  * @param statistics
+  *            {@code true} to request statistics, {@code false} to suppress
+  *            them
+  */
+ public void setStatistics(boolean statistics) {
+     this.statistics = statistics;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/mlcontext/ScriptFactory.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptFactory.java b/src/main/java/org/apache/sysml/api/mlcontext/ScriptFactory.java
new file mode 100644
index 0000000..5f0e56b
--- /dev/null
+++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptFactory.java
@@ -0,0 +1,422 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.mlcontext;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.MalformedURLException;
+import java.net.URL;
+
+import org.apache.commons.io.FileUtils;
+import org.apache.commons.io.IOUtils;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.sysml.conf.ConfigurationManager;
+import org.apache.sysml.runtime.util.LocalFileUtils;
+
+/**
+ * Factory for creating DML and PYDML Script objects from strings, files, URLs,
+ * and input streams.
+ *
+ */
+public class ScriptFactory {
+
+    /**
+     * Create a DML Script object based on a string path to a file.
+     *
+     * @param scriptFilePath
+     *            path to DML script file (local or HDFS)
+     * @return DML Script object
+     */
+    public static Script dmlFromFile(String scriptFilePath) {
+        return scriptFromFile(scriptFilePath, ScriptType.DML);
+    }
+
+    /**
+     * Create a DML Script object based on an input stream. The stream is read
+     * to its end but is not closed; closing it remains the caller's
+     * responsibility.
+     *
+     * @param inputStream
+     *            input stream to DML
+     * @return DML Script object
+     */
+    public static Script dmlFromInputStream(InputStream inputStream) {
+        return scriptFromInputStream(inputStream, ScriptType.DML);
+    }
+
+    /**
+     * Create a DML Script object based on a file in the local file system. To
+     * create a DML Script object from a local file or HDFS, please use
+     * {@link #dmlFromFile(String)}.
+     *
+     * @param localScriptFile
+     *            the local DML file
+     * @return DML Script object
+     */
+    public static Script dmlFromLocalFile(File localScriptFile) {
+        return scriptFromLocalFile(localScriptFile, ScriptType.DML);
+    }
+
+    /**
+     * Create a DML Script object based on a string.
+     *
+     * @param scriptString
+     *            string of DML
+     * @return DML Script object
+     */
+    public static Script dmlFromString(String scriptString) {
+        return scriptFromString(scriptString, ScriptType.DML);
+    }
+
+    /**
+     * Create a DML Script object based on a URL path.
+     *
+     * @param scriptUrlPath
+     *            URL path to DML script
+     * @return DML Script object
+     */
+    public static Script dmlFromUrl(String scriptUrlPath) {
+        return scriptFromUrl(scriptUrlPath, ScriptType.DML);
+    }
+
+    /**
+     * Create a DML Script object based on a URL.
+     *
+     * @param scriptUrl
+     *            URL to DML script
+     * @return DML Script object
+     */
+    public static Script dmlFromUrl(URL scriptUrl) {
+        return scriptFromUrl(scriptUrl, ScriptType.DML);
+    }
+
+    /**
+     * Create a PYDML Script object based on a string path to a file.
+     *
+     * @param scriptFilePath
+     *            path to PYDML script file (local or HDFS)
+     * @return PYDML Script object
+     */
+    public static Script pydmlFromFile(String scriptFilePath) {
+        return scriptFromFile(scriptFilePath, ScriptType.PYDML);
+    }
+
+    /**
+     * Create a PYDML Script object based on an input stream. The stream is
+     * read to its end but is not closed; closing it remains the caller's
+     * responsibility.
+     *
+     * @param inputStream
+     *            input stream to PYDML
+     * @return PYDML Script object
+     */
+    public static Script pydmlFromInputStream(InputStream inputStream) {
+        return scriptFromInputStream(inputStream, ScriptType.PYDML);
+    }
+
+    /**
+     * Create a PYDML Script object based on a file in the local file system.
+     * To create a PYDML Script object from a local file or HDFS, please use
+     * {@link #pydmlFromFile(String)}.
+     *
+     * @param localScriptFile
+     *            the local PYDML file
+     * @return PYDML Script object
+     */
+    public static Script pydmlFromLocalFile(File localScriptFile) {
+        return scriptFromLocalFile(localScriptFile, ScriptType.PYDML);
+    }
+
+    /**
+     * Create a PYDML Script object based on a string.
+     *
+     * @param scriptString
+     *            string of PYDML
+     * @return PYDML Script object
+     */
+    public static Script pydmlFromString(String scriptString) {
+        return scriptFromString(scriptString, ScriptType.PYDML);
+    }
+
+    /**
+     * Create a PYDML Script object based on a URL path.
+     *
+     * @param scriptUrlPath
+     *            URL path to PYDML script
+     * @return PYDML Script object
+     */
+    public static Script pydmlFromUrl(String scriptUrlPath) {
+        return scriptFromUrl(scriptUrlPath, ScriptType.PYDML);
+    }
+
+    /**
+     * Create a PYDML Script object based on a URL.
+     *
+     * @param scriptUrl
+     *            URL to PYDML script
+     * @return PYDML Script object
+     */
+    public static Script pydmlFromUrl(URL scriptUrl) {
+        return scriptFromUrl(scriptUrl, ScriptType.PYDML);
+    }
+
+    /**
+     * Create a DML or PYDML Script object based on a string path to a file.
+     * The resulting Script is named after the given path.
+     *
+     * @param scriptFilePath
+     *            path to DML or PYDML script file (local or HDFS)
+     * @param scriptType
+     *            {@code ScriptType.DML} or {@code ScriptType.PYDML}
+     * @return DML or PYDML Script object
+     */
+    private static Script scriptFromFile(String scriptFilePath, ScriptType scriptType) {
+        String scriptString = getScriptStringFromFile(scriptFilePath);
+        return scriptFromString(scriptString, scriptType).setName(scriptFilePath);
+    }
+
+    /**
+     * Create a DML or PYDML Script object based on an input stream.
+     *
+     * @param inputStream
+     *            input stream to DML or PYDML
+     * @param scriptType
+     *            {@code ScriptType.DML} or {@code ScriptType.PYDML}
+     * @return DML or PYDML Script object
+     */
+    private static Script scriptFromInputStream(InputStream inputStream, ScriptType scriptType) {
+        String scriptString = getScriptStringFromInputStream(inputStream);
+        return scriptFromString(scriptString, scriptType);
+    }
+
+    /**
+     * Create a DML or PYDML Script object based on a file in the local file
+     * system. The resulting Script is named after the file's name. To create
+     * a Script object from a local file or HDFS, please use
+     * {@link #scriptFromFile(String, ScriptType)}.
+     *
+     * @param localScriptFile
+     *            The local DML or PYDML file
+     * @param scriptType
+     *            {@code ScriptType.DML} or {@code ScriptType.PYDML}
+     * @return DML or PYDML Script object
+     */
+    private static Script scriptFromLocalFile(File localScriptFile, ScriptType scriptType) {
+        String scriptString = getScriptStringFromFile(localScriptFile);
+        return scriptFromString(scriptString, scriptType).setName(localScriptFile.getName());
+    }
+
+    /**
+     * Create a DML or PYDML Script object based on a string.
+     *
+     * @param scriptString
+     *            string of DML or PYDML
+     * @param scriptType
+     *            {@code ScriptType.DML} or {@code ScriptType.PYDML}
+     * @return DML or PYDML Script object
+     */
+    private static Script scriptFromString(String scriptString, ScriptType scriptType) {
+        return new Script(scriptString, scriptType);
+    }
+
+    /**
+     * Create a DML or PYDML Script object based on a URL path. The resulting
+     * Script is named after the given URL path.
+     *
+     * @param scriptUrlPath
+     *            URL path to DML or PYDML script
+     * @param scriptType
+     *            {@code ScriptType.DML} or {@code ScriptType.PYDML}
+     * @return DML or PYDML Script object
+     */
+    private static Script scriptFromUrl(String scriptUrlPath, ScriptType scriptType) {
+        String scriptString = getScriptStringFromUrl(scriptUrlPath);
+        return scriptFromString(scriptString, scriptType).setName(scriptUrlPath);
+    }
+
+    /**
+     * Create a DML or PYDML Script object based on a URL. The resulting
+     * Script is named after the URL's string form.
+     *
+     * @param scriptUrl
+     *            URL to DML or PYDML script
+     * @param scriptType
+     *            {@code ScriptType.DML} or {@code ScriptType.PYDML}
+     * @return DML or PYDML Script object
+     */
+    private static Script scriptFromUrl(URL scriptUrl, ScriptType scriptType) {
+        String scriptString = getScriptStringFromUrl(scriptUrl);
+        return scriptFromString(scriptString, scriptType).setName(scriptUrl.toString());
+    }
+
+    /**
+     * Create a DML Script object based on a string. Shorthand for
+     * {@link #dmlFromString(String)}.
+     *
+     * @param scriptString
+     *            string of DML
+     * @return DML Script object
+     */
+    public static Script dml(String scriptString) {
+        return dmlFromString(scriptString);
+    }
+
+    /**
+     * Obtain a script string from a file in the local file system. To obtain a
+     * script string from a file in HDFS, please use
+     * {@link #getScriptStringFromFile(String)}.
+     *
+     * @param file
+     *            The script file.
+     * @return The script string.
+     * @throws MLContextException
+     *             If the file is null, its name fails validation, or a
+     *             problem occurs reading the script string from the file.
+     */
+    private static String getScriptStringFromFile(File file) {
+        if (file == null) {
+            throw new MLContextException("Script file is null");
+        }
+        String filePath = file.getPath();
+        try {
+            if (!LocalFileUtils.validateExternalFilename(filePath, false)) {
+                throw new MLContextException("Invalid (non-trustworthy) local filename: " + filePath);
+            }
+            return FileUtils.readFileToString(file);
+        } catch (IllegalArgumentException e) {
+            throw new MLContextException("Error trying to read script string from file: " + filePath, e);
+        } catch (IOException e) {
+            throw new MLContextException("Error trying to read script string from file: " + filePath, e);
+        }
+    }
+
+    /**
+     * Obtain a script string from a file. Paths starting with {@code hdfs:}
+     * or {@code gpfs:} are read through the Hadoop {@code FileSystem} API;
+     * every other path is read from the local file system.
+     *
+     * @param scriptFilePath
+     *            The file path to the script file (either local file system or
+     *            HDFS)
+     * @return The script string
+     * @throws MLContextException
+     *             If the path is null, fails validation, or a problem occurs
+     *             reading the script string from the file
+     */
+    private static String getScriptStringFromFile(String scriptFilePath) {
+        if (scriptFilePath == null) {
+            throw new MLContextException("Script file path is null");
+        }
+        try {
+            if (scriptFilePath.startsWith("hdfs:") || scriptFilePath.startsWith("gpfs:")) {
+                if (!LocalFileUtils.validateExternalFilename(scriptFilePath, true)) {
+                    throw new MLContextException("Invalid (non-trustworthy) hdfs/gpfs filename: " + scriptFilePath);
+                }
+                FileSystem fs = FileSystem.get(ConfigurationManager.getCachedJobConf());
+                Path path = new Path(scriptFilePath);
+                FSDataInputStream fsdis = fs.open(path);
+                try {
+                    return IOUtils.toString(fsdis);
+                } finally {
+                    // Always release the stream, even when reading fails.
+                    IOUtils.closeQuietly(fsdis);
+                }
+            } else {// from local file system
+                if (!LocalFileUtils.validateExternalFilename(scriptFilePath, false)) {
+                    throw new MLContextException("Invalid (non-trustworthy) local filename: " + scriptFilePath);
+                }
+                File scriptFile = new File(scriptFilePath);
+                return FileUtils.readFileToString(scriptFile);
+            }
+        } catch (IllegalArgumentException e) {
+            throw new MLContextException("Error trying to read script string from file: " + scriptFilePath, e);
+        } catch (IOException e) {
+            throw new MLContextException("Error trying to read script string from file: " + scriptFilePath, e);
+        }
+    }
+
+    /**
+     * Obtain a script string from an InputStream. The stream is read to its
+     * end but is deliberately left open, since it is owned by the caller.
+     *
+     * @param inputStream
+     *            The InputStream from which to read the script string
+     * @return The script string
+     * @throws MLContextException
+     *             If the stream is null or a problem occurs reading the
+     *             script string from the InputStream
+     */
+    private static String getScriptStringFromInputStream(InputStream inputStream) {
+        if (inputStream == null) {
+            throw new MLContextException("InputStream is null");
+        }
+        try {
+            return IOUtils.toString(inputStream);
+        } catch (IOException e) {
+            throw new MLContextException("Error trying to read script string from InputStream", e);
+        }
+    }
+
+    /**
+     * Obtain a script string from a URL.
+     *
+     * @param scriptUrlPath
+     *            The URL path to the script file
+     * @return The script string
+     * @throws MLContextException
+     *             If the path is null or malformed, or a problem occurs
+     *             reading the script string from the URL
+     */
+    private static String getScriptStringFromUrl(String scriptUrlPath) {
+        if (scriptUrlPath == null) {
+            throw new MLContextException("Script URL path is null");
+        }
+        try {
+            URL url = new URL(scriptUrlPath);
+            return getScriptStringFromUrl(url);
+        } catch (MalformedURLException e) {
+            throw new MLContextException("Error trying to read script string from URL path: " + scriptUrlPath, e);
+        }
+    }
+
+    /**
+     * Obtain a script string from a URL. Only {@code http} and {@code https}
+     * URLs are accepted.
+     *
+     * @param url
+     *            The script URL
+     * @return The script string
+     * @throws MLContextException
+     *             If the URL is null, uses an unsupported protocol, or a
+     *             problem occurs reading the script string from the URL
+     */
+    private static String getScriptStringFromUrl(URL url) {
+        if (url == null) {
+            throw new MLContextException("URL is null");
+        }
+        String urlString = url.toString();
+        if ((!urlString.toLowerCase().startsWith("http:")) && (!urlString.toLowerCase().startsWith("https:"))) {
+            throw new MLContextException("Currently only reading from http and https URLs is supported");
+        }
+        InputStream is = null;
+        try {
+            is = url.openStream();
+            return IOUtils.toString(is);
+        } catch (IOException e) {
+            throw new MLContextException("Error trying to read script string from URL: " + url, e);
+        } finally {
+            // Always release the connection's stream; closeQuietly tolerates null.
+            IOUtils.closeQuietly(is);
+        }
+    }
+
+    /**
+     * Create a PYDML script object based on a string. Shorthand for
+     * {@link #pydmlFromString(String)}.
+     *
+     * @param scriptString
+     *            string of PYDML
+     * @return PYDML Script object
+     */
+    public static Script pydml(String scriptString) {
+        return pydmlFromString(scriptString);
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/mlcontext/ScriptType.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptType.java b/src/main/java/org/apache/sysml/api/mlcontext/ScriptType.java
new file mode 100644
index 0000000..94c9057
--- /dev/null
+++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptType.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.mlcontext;
+
+/**
+ * ScriptType represents the type of script, DML (R-like syntax) or PYDML
+ * (Python-like syntax).
+ *
+ */
+public enum ScriptType {
+ /**
+ * R-like syntax.
+ */
+ DML,
+
+ /**
+ * Python-like syntax.
+ */
+ PYDML;
+
+ /**
+ * Obtain script type as a lowercase string ("dml" or "pydml").
+ *
+ * @return lowercase string representing the script type
+ */
+ public String lowerCase() {
+ return super.toString().toLowerCase();
+ }
+
+ /**
+ * Is the script type DML?
+ *
+ * @return {@code true} if the script type is DML, {@code false} otherwise
+ */
+ public boolean isDML() {
+ return (this == ScriptType.DML);
+ }
+
+ /**
+ * Is the script type PYDML?
+ *
+ * @return {@code true} if the script type is PYDML, {@code false} otherwise
+ */
+ public boolean isPYDML() {
+ return (this == ScriptType.PYDML);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
index 0eea221..c715331 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
@@ -36,11 +36,7 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.storage.RDDInfo;
import org.apache.spark.storage.StorageLevel;
-
-import scala.Tuple2;
-
import org.apache.sysml.api.DMLScript;
-import org.apache.sysml.api.MLContext;
import org.apache.sysml.api.MLContextProxy;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.OptimizerUtils;
@@ -82,6 +78,8 @@ import org.apache.sysml.runtime.util.MapReduceTool;
import org.apache.sysml.runtime.util.UtilFunctions;
import org.apache.sysml.utils.Statistics;
+import scala.Tuple2;
+
public class SparkExecutionContext extends ExecutionContext
{
@@ -178,22 +176,28 @@ public class SparkExecutionContext extends ExecutionContext
*
*/
private synchronized static void initSparkContext()
- {
+ {
//check for redundant spark context init
if( _spctx != null )
return;
-
+
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
//create a default spark context (master, appname, etc refer to system properties
//as given in the spark configuration or during spark-submit)
- MLContext mlCtx = MLContextProxy.getActiveMLContext();
- if(mlCtx != null)
+ Object mlCtxObj = MLContextProxy.getActiveMLContext();
+ if(mlCtxObj != null)
{
// This is when DML is called through spark shell
// Will clean the passing of static variables later as this involves minimal change to DMLScript
- _spctx = new JavaSparkContext(mlCtx.getSparkContext());
+ if (mlCtxObj instanceof org.apache.sysml.api.MLContext) {
+ org.apache.sysml.api.MLContext mlCtx = (org.apache.sysml.api.MLContext) mlCtxObj;
+ _spctx = new JavaSparkContext(mlCtx.getSparkContext());
+ } else if (mlCtxObj instanceof org.apache.sysml.api.mlcontext.MLContext) {
+ org.apache.sysml.api.mlcontext.MLContext mlCtx = (org.apache.sysml.api.mlcontext.MLContext) mlCtxObj;
+ _spctx = new JavaSparkContext(mlCtx.getSparkContext());
+ }
}
else
{
@@ -1424,11 +1428,26 @@ public class SparkExecutionContext extends ExecutionContext
}
}
- MLContext mlContext = MLContextProxy.getActiveMLContext();
- if(mlContext != null && mlContext.getMonitoringUtil() != null) {
- mlContext.getMonitoringUtil().setLineageInfo(inst, outDebugString);
- }
- else {
+
+ Object mlContextObj = MLContextProxy.getActiveMLContext();
+ if (mlContextObj != null) {
+ if (mlContextObj instanceof org.apache.sysml.api.MLContext) {
+ org.apache.sysml.api.MLContext mlCtx = (org.apache.sysml.api.MLContext) mlContextObj;
+ if (mlCtx.getMonitoringUtil() != null) {
+ mlCtx.getMonitoringUtil().setLineageInfo(inst, outDebugString);
+ } else {
+ throw new DMLRuntimeException("The method setLineageInfoForExplain should be called only through MLContext");
+ }
+ } else if (mlContextObj instanceof org.apache.sysml.api.mlcontext.MLContext) {
+ org.apache.sysml.api.mlcontext.MLContext mlCtx = (org.apache.sysml.api.mlcontext.MLContext) mlContextObj;
+ if (mlCtx.getSparkMonitoringUtil() != null) {
+ mlCtx.getSparkMonitoringUtil().setLineageInfo(inst, outDebugString);
+ } else {
+ throw new DMLRuntimeException("The method setLineageInfoForExplain should be called only through MLContext");
+ }
+ }
+
+ } else {
throw new DMLRuntimeException("The method setLineageInfoForExplain should be called only through MLContext");
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java
index 0c0d3f0..d5301e7 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java
@@ -19,7 +19,6 @@
package org.apache.sysml.runtime.instructions.spark;
-import org.apache.sysml.api.MLContext;
import org.apache.sysml.api.MLContextProxy;
import org.apache.sysml.lops.runtime.RunMRJobs;
import org.apache.sysml.runtime.DMLRuntimeException;
@@ -99,13 +98,23 @@ public abstract class SPInstruction extends Instruction
//spark-explain-specific handling of current instructions
//This only relevant for ComputationSPInstruction as in postprocess we call setDebugString which is valid only for ComputationSPInstruction
- MLContext mlCtx = MLContextProxy.getActiveMLContext();
- if( tmp instanceof ComputationSPInstruction
- && mlCtx != null && mlCtx.getMonitoringUtil() != null
- && ec instanceof SparkExecutionContext )
- {
- mlCtx.getMonitoringUtil().addCurrentInstruction((SPInstruction)tmp);
- MLContextProxy.setInstructionForMonitoring(tmp);
+ Object mlCtxObj = MLContextProxy.getActiveMLContext();
+ if (mlCtxObj instanceof org.apache.sysml.api.MLContext) {
+ org.apache.sysml.api.MLContext mlCtx = (org.apache.sysml.api.MLContext) mlCtxObj;
+ if (tmp instanceof ComputationSPInstruction
+ && mlCtx != null && mlCtx.getMonitoringUtil() != null
+ && ec instanceof SparkExecutionContext ) {
+ mlCtx.getMonitoringUtil().addCurrentInstruction((SPInstruction)tmp);
+ MLContextProxy.setInstructionForMonitoring(tmp);
+ }
+ } else if (mlCtxObj instanceof org.apache.sysml.api.mlcontext.MLContext) {
+ org.apache.sysml.api.mlcontext.MLContext mlCtx = (org.apache.sysml.api.mlcontext.MLContext) mlCtxObj;
+ if (tmp instanceof ComputationSPInstruction
+ && mlCtx != null && mlCtx.getSparkMonitoringUtil() != null
+ && ec instanceof SparkExecutionContext ) {
+ mlCtx.getSparkMonitoringUtil().addCurrentInstruction((SPInstruction)tmp);
+ MLContextProxy.setInstructionForMonitoring(tmp);
+ }
}
return tmp;
@@ -120,14 +129,25 @@ public abstract class SPInstruction extends Instruction
throws DMLRuntimeException
{
//spark-explain-specific handling of current instructions
- MLContext mlCtx = MLContextProxy.getActiveMLContext();
- if( this instanceof ComputationSPInstruction
- && mlCtx != null && mlCtx.getMonitoringUtil() != null
- && ec instanceof SparkExecutionContext )
- {
- SparkExecutionContext sec = (SparkExecutionContext) ec;
- sec.setDebugString(this, ((ComputationSPInstruction) this).getOutputVariableName());
- mlCtx.getMonitoringUtil().removeCurrentInstruction(this);
+ Object mlCtxObj = MLContextProxy.getActiveMLContext();
+ if (mlCtxObj instanceof org.apache.sysml.api.MLContext) {
+ org.apache.sysml.api.MLContext mlCtx = (org.apache.sysml.api.MLContext) mlCtxObj;
+ if (this instanceof ComputationSPInstruction
+ && mlCtx != null && mlCtx.getMonitoringUtil() != null
+ && ec instanceof SparkExecutionContext ) {
+ SparkExecutionContext sec = (SparkExecutionContext) ec;
+ sec.setDebugString(this, ((ComputationSPInstruction) this).getOutputVariableName());
+ mlCtx.getMonitoringUtil().removeCurrentInstruction(this);
+ }
+ } else if (mlCtxObj instanceof org.apache.sysml.api.mlcontext.MLContext) {
+ org.apache.sysml.api.mlcontext.MLContext mlCtx = (org.apache.sysml.api.mlcontext.MLContext) mlCtxObj;
+ if (this instanceof ComputationSPInstruction
+ && mlCtx != null && mlCtx.getSparkMonitoringUtil() != null
+ && ec instanceof SparkExecutionContext ) {
+ SparkExecutionContext sec = (SparkExecutionContext) ec;
+ sec.setDebugString(this, ((ComputationSPInstruction) this).getOutputVariableName());
+ mlCtx.getSparkMonitoringUtil().removeCurrentInstruction(this);
+ }
}
//maintain statistics
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/SparkListener.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/SparkListener.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/SparkListener.java
index 956b841..3bf2f67 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/SparkListener.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/SparkListener.java
@@ -33,16 +33,14 @@ import org.apache.spark.storage.RDDInfo;
import org.apache.spark.ui.jobs.StagesTab;
import org.apache.spark.ui.jobs.UIData.TaskUIData;
import org.apache.spark.ui.scope.RDDOperationGraphListener;
+import org.apache.sysml.api.MLContextProxy;
+import org.apache.sysml.runtime.instructions.spark.SPInstruction;
import scala.Option;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.xml.Node;
-import org.apache.sysml.api.MLContext;
-import org.apache.sysml.api.MLContextProxy;
-import org.apache.sysml.runtime.instructions.spark.SPInstruction;
-
// Instead of extending org.apache.spark.JavaSparkListener
/**
* This class is only used by MLContext for now. It is used to provide UI data for Python notebook.
@@ -94,9 +92,19 @@ public class SparkListener extends RDDOperationGraphListener {
jobDAGs.put(jobID, jobNodes);
synchronized(currentInstructions) {
for(SPInstruction inst : currentInstructions) {
- MLContext mlContext = MLContextProxy.getActiveMLContext();
- if(mlContext != null && mlContext.getMonitoringUtil() != null) {
- mlContext.getMonitoringUtil().setJobId(inst, jobID);
+ Object mlContextObj = MLContextProxy.getActiveMLContext();
+ if (mlContextObj != null) {
+ if (mlContextObj instanceof org.apache.sysml.api.MLContext) {
+ org.apache.sysml.api.MLContext mlContext = (org.apache.sysml.api.MLContext) mlContextObj;
+ if (mlContext.getMonitoringUtil() != null) {
+ mlContext.getMonitoringUtil().setJobId(inst, jobID);
+ }
+ } else if (mlContextObj instanceof org.apache.sysml.api.mlcontext.MLContext) {
+ org.apache.sysml.api.mlcontext.MLContext mlContext = (org.apache.sysml.api.mlcontext.MLContext) mlContextObj;
+ if (mlContext.getSparkMonitoringUtil() != null) {
+ mlContext.getSparkMonitoringUtil().setJobId(inst, jobID);
+ }
+ }
}
}
}
@@ -140,9 +148,19 @@ public class SparkListener extends RDDOperationGraphListener {
synchronized(currentInstructions) {
for(SPInstruction inst : currentInstructions) {
- MLContext mlContext = MLContextProxy.getActiveMLContext();
- if(mlContext != null && mlContext.getMonitoringUtil() != null) {
- mlContext.getMonitoringUtil().setStageId(inst, stageSubmitted.stageInfo().stageId());
+ Object mlContextObj = MLContextProxy.getActiveMLContext();
+ if (mlContextObj != null) {
+ if (mlContextObj instanceof org.apache.sysml.api.MLContext) {
+ org.apache.sysml.api.MLContext mlContext = (org.apache.sysml.api.MLContext) mlContextObj;
+ if (mlContext.getMonitoringUtil() != null) {
+ mlContext.getMonitoringUtil().setStageId(inst, stageSubmitted.stageInfo().stageId());
+ }
+ } else if (mlContextObj instanceof org.apache.sysml.api.mlcontext.MLContext) {
+ org.apache.sysml.api.mlcontext.MLContext mlContext = (org.apache.sysml.api.mlcontext.MLContext) mlContextObj;
+ if (mlContext.getSparkMonitoringUtil() != null) {
+ mlContext.getSparkMonitoringUtil().setStageId(inst, stageSubmitted.stageInfo().stageId());
+ }
+ }
}
}
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
index ccdc927..f022e40 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
@@ -410,7 +410,7 @@ public class RDDConverterUtilsExt
}
- private static class DataFrameAnalysisFunction implements Function<Row,Row> {
+ public static class DataFrameAnalysisFunction implements Function<Row,Row> {
private static final long serialVersionUID = 5705371332119770215L;
private RowAnalysisFunctionHelper helper = null;
boolean isVectorBasedRDD;
@@ -445,7 +445,7 @@ public class RDDConverterUtilsExt
}
- private static class DataFrameToBinaryBlockFunction implements PairFlatMapFunction<Iterator<Tuple2<Row,Long>>,MatrixIndexes,MatrixBlock> {
+ public static class DataFrameToBinaryBlockFunction implements PairFlatMapFunction<Iterator<Tuple2<Row,Long>>,MatrixIndexes,MatrixBlock> {
private static final long serialVersionUID = 653447740362447236L;
private RowToBinaryBlockFunctionHelper helper = null;
boolean isVectorBasedDF;