You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by du...@apache.org on 2016/08/19 22:47:34 UTC
incubator-systemml git commit: [SYSTEMML-853] Python API for new
MLContext
Repository: incubator-systemml
Updated Branches:
refs/heads/master 9ac1d4f86 -> 4fff6f769
[SYSTEMML-853] Python API for new MLContext
This adds a new Python API that targets the new MLContext API on the Java/Scala side.
Closes #211.
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/4fff6f76
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/4fff6f76
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/4fff6f76
Branch: refs/heads/master
Commit: 4fff6f76951c42dcf902d584a26089eda27c43a8
Parents: 9ac1d4f
Author: MechCoder <mk...@nyu.edu>
Authored: Fri Aug 19 15:46:07 2016 -0700
Committer: Mike Dusenberry <mw...@us.ibm.com>
Committed: Fri Aug 19 15:46:07 2016 -0700
----------------------------------------------------------------------
.../apache/sysml/api/mlcontext/MLResults.java | 108 ++++++----
.../org/apache/sysml/api/mlcontext/Script.java | 98 +++++----
src/main/python/SystemML.py | 203 +++++++++++++++++++
src/main/python/SystemMLtests.py | 87 ++++++++
4 files changed, 412 insertions(+), 84 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4fff6f76/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java b/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java
index 582a73e..289f490 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -92,7 +92,7 @@ public class MLResults {
/**
* Obtain an output as a {@code Data} object.
- *
+ *
* @param outputName
* the name of the output
* @return the output as a {@code Data} object
@@ -108,7 +108,7 @@ public class MLResults {
/**
* Obtain an output as a {@code MatrixObject}
- *
+ *
* @param outputName
* the name of the output
* @return the output as a {@code MatrixObject}
@@ -124,7 +124,7 @@ public class MLResults {
/**
* Obtain an output as a two-dimensional {@code double} array.
- *
+ *
* @param outputName
* the name of the output
* @return the output as a two-dimensional {@code double} array
@@ -150,7 +150,7 @@ public class MLResults {
* <br>2 1 3.0
* <br>2 2 4.0
* </code>
- *
+ *
* @param outputName
* the name of the output
* @return the output as a {@code JavaRDD<String>} in IJV format
@@ -174,7 +174,7 @@ public class MLResults {
* <code>1.0,2.0
* <br>3.0,4.0
* </code>
- *
+ *
* @param outputName
* the name of the output
* @return the output as a {@code JavaRDD<String>} in CSV format
@@ -198,7 +198,7 @@ public class MLResults {
* <code>1.0,2.0
* <br>3.0,4.0
* </code>
- *
+ *
* @param outputName
* the name of the output
* @return the output as a {@code RDD<String>} in CSV format
@@ -224,7 +224,7 @@ public class MLResults {
* <br>2 1 3.0
* <br>2 2 4.0
* </code>
- *
+ *
* @param outputName
* the name of the output
* @return the output as a {@code RDD<String>} in IJV format
@@ -248,7 +248,7 @@ public class MLResults {
* <code>[0.0,1.0,2.0]
* <br>[1.0,3.0,4.0]
* </code>
- *
+ *
* @param outputName
* the name of the output
* @return the output as a {@code DataFrame} of doubles
@@ -258,7 +258,7 @@ public class MLResults {
DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, false);
return df;
}
-
+
public DataFrame getDataFrame(String outputName, boolean isVectorDF) {
MatrixObject mo = getMatrixObject(outputName);
DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, isVectorDF);
@@ -267,7 +267,7 @@ public class MLResults {
/**
* Obtain an output as a {@code Matrix}.
- *
+ *
* @param outputName
* the name of the output
* @return the output as a {@code Matrix}
@@ -277,11 +277,11 @@ public class MLResults {
Matrix matrix = new Matrix(mo, sparkExecutionContext);
return matrix;
}
-
+
/**
* Obtain an output as a {@code BinaryBlockMatrix}.
- *
+ *
* @param outputName
* the name of the output
* @return the output as a {@code BinaryBlockMatrix}
@@ -295,7 +295,7 @@ public class MLResults {
/**
* Obtain an output as a two-dimensional {@code String} array.
- *
+ *
* @param outputName
* the name of the output
* @return the output as a two-dimensional {@code String} array
@@ -320,7 +320,7 @@ public class MLResults {
/**
* Obtain a {@code double} output
- *
+ *
* @param outputName
* the name of the output
* @return the output as a {@code double}
@@ -331,8 +331,26 @@ public class MLResults {
}
/**
+ * Obtain an output as a plain {@code Object}. If the output is a
+ * {@code ScalarObject}, its value (as returned by
+ * {@code ScalarObject.getValue()}) is returned; otherwise the output's
+ * {@code Data} object is returned as-is. The Python API's
+ * {@code MLResults.get} calls this method and converts the result on the
+ * Python side.
+ *
+ * @param outputName
+ * the name of the output
+ * @return the scalar value for scalar outputs, otherwise the output's
+ * {@code Data} object
+ */
+
+ public Object get(String outputName) {
+ Data data = getData(outputName);
+ if (data instanceof ScalarObject) {
+ ScalarObject so = (ScalarObject) data;
+ return so.getValue();
+ } else {
+ return data;
+ }
+ }
+
+ /**
* Obtain an output as a {@code Scalar} object.
- *
+ *
* @param outputName
* the name of the output
* @return the output as a {@code Scalar} object
@@ -348,7 +366,7 @@ public class MLResults {
/**
* Obtain a {@code boolean} output
- *
+ *
* @param outputName
* the name of the output
* @return the output as a {@code boolean}
@@ -360,7 +378,7 @@ public class MLResults {
/**
* Obtain a {@code long} output
- *
+ *
* @param outputName
* the name of the output
* @return the output as a {@code long}
@@ -372,7 +390,7 @@ public class MLResults {
/**
* Obtain a {@code String} output
- *
+ *
* @param outputName
* the name of the output
* @return the output as a {@code String}
@@ -384,7 +402,7 @@ public class MLResults {
/**
* Obtain the Script object associated with these results.
- *
+ *
* @return the DML or PYDML Script object
*/
public Script getScript() {
@@ -393,7 +411,7 @@ public class MLResults {
/**
* Obtain a Scala tuple.
- *
+ *
* @param outputName1
* the name of the first output
* @return a Scala tuple
@@ -405,7 +423,7 @@ public class MLResults {
/**
* Obtain a Scala tuple.
- *
+ *
* @param outputName1
* the name of the first output
* @param outputName2
@@ -419,7 +437,7 @@ public class MLResults {
/**
* Obtain a Scala tuple.
- *
+ *
* @param outputName1
* the name of the first output
* @param outputName2
@@ -436,7 +454,7 @@ public class MLResults {
/**
* Obtain a Scala tuple.
- *
+ *
* @param outputName1
* the name of the first output
* @param outputName2
@@ -456,7 +474,7 @@ public class MLResults {
/**
* Obtain a Scala tuple.
- *
+ *
* @param outputName1
* the name of the first output
* @param outputName2
@@ -478,7 +496,7 @@ public class MLResults {
/**
* Obtain a Scala tuple.
- *
+ *
* @param outputName1
* the name of the first output
* @param outputName2
@@ -503,7 +521,7 @@ public class MLResults {
/**
* Obtain a Scala tuple.
- *
+ *
* @param outputName1
* the name of the first output
* @param outputName2
@@ -531,7 +549,7 @@ public class MLResults {
/**
* Obtain a Scala tuple.
- *
+ *
* @param outputName1
* the name of the first output
* @param outputName2
@@ -561,7 +579,7 @@ public class MLResults {
/**
* Obtain a Scala tuple.
- *
+ *
* @param outputName1
* the name of the first output
* @param outputName2
@@ -594,7 +612,7 @@ public class MLResults {
/**
* Obtain a Scala tuple.
- *
+ *
* @param outputName1
* the name of the first output
* @param outputName2
@@ -629,7 +647,7 @@ public class MLResults {
/**
* Obtain a Scala tuple.
- *
+ *
* @param outputName1
* the name of the first output
* @param outputName2
@@ -668,7 +686,7 @@ public class MLResults {
/**
* Obtain a Scala tuple.
- *
+ *
* @param outputName1
* the name of the first output
* @param outputName2
@@ -709,7 +727,7 @@ public class MLResults {
/**
* Obtain a Scala tuple.
- *
+ *
* @param outputName1
* the name of the first output
* @param outputName2
@@ -752,7 +770,7 @@ public class MLResults {
/**
* Obtain a Scala tuple.
- *
+ *
* @param outputName1
* the name of the first output
* @param outputName2
@@ -798,7 +816,7 @@ public class MLResults {
/**
* Obtain a Scala tuple.
- *
+ *
* @param outputName1
* the name of the first output
* @param outputName2
@@ -846,7 +864,7 @@ public class MLResults {
/**
* Obtain a Scala tuple.
- *
+ *
* @param outputName1
* the name of the first output
* @param outputName2
@@ -898,7 +916,7 @@ public class MLResults {
/**
* Obtain a Scala tuple.
- *
+ *
* @param outputName1
* the name of the first output
* @param outputName2
@@ -952,7 +970,7 @@ public class MLResults {
/**
* Obtain a Scala tuple.
- *
+ *
* @param outputName1
* the name of the first output
* @param outputName2
@@ -1008,7 +1026,7 @@ public class MLResults {
/**
* Obtain a Scala tuple.
- *
+ *
* @param outputName1
* the name of the first output
* @param outputName2
@@ -1067,7 +1085,7 @@ public class MLResults {
/**
* Obtain a Scala tuple.
- *
+ *
* @param outputName1
* the name of the first output
* @param outputName2
@@ -1128,7 +1146,7 @@ public class MLResults {
/**
* Obtain a Scala tuple.
- *
+ *
* @param outputName1
* the name of the first output
* @param outputName2
@@ -1192,7 +1210,7 @@ public class MLResults {
/**
* Obtain a Scala tuple.
- *
+ *
* @param outputName1
* the name of the first output
* @param outputName2
@@ -1262,7 +1280,7 @@ public class MLResults {
* specific output type. MLResults tuple support requires specifying the
* object types at runtime to avoid the items in the tuple being returned as
* Anys.
- *
+ *
* @param outputName
* the name of the output
* @return the output value cast to a specific output type
@@ -1289,7 +1307,7 @@ public class MLResults {
/**
* Obtain the symbol table, which is essentially a {@code Map<String, Data>}
* representing variables and their values as SystemML representations.
- *
+ *
* @return the symbol table
*/
public LocalVariableMap getSymbolTable() {
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4fff6f76/src/main/java/org/apache/sysml/api/mlcontext/Script.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/Script.java b/src/main/java/org/apache/sysml/api/mlcontext/Script.java
index 782f2c6..28667cf 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/Script.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/Script.java
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -97,7 +97,7 @@ public class Script {
/**
* Script constructor, specifying the type of script ({@code ScriptType.DML}
* or {@code ScriptType.PYDML}).
- *
+ *
* @param scriptType
* {@code ScriptType.DML} or {@code ScriptType.PYDML}
*/
@@ -108,7 +108,7 @@ public class Script {
/**
* Script constructor, specifying the script content. By default, the script
* type is DML.
- *
+ *
* @param scriptString
* the script content as a string
*/
@@ -120,7 +120,7 @@ public class Script {
/**
* Script constructor, specifying the script content and the type of script
* (DML or PYDML).
- *
+ *
* @param scriptString
* the script content as a string
* @param scriptType
@@ -133,7 +133,7 @@ public class Script {
/**
* Obtain the script type.
- *
+ *
* @return {@code ScriptType.DML} or {@code ScriptType.PYDML}
*/
public ScriptType getScriptType() {
@@ -142,7 +142,7 @@ public class Script {
/**
* Set the type of script (DML or PYDML).
- *
+ *
* @param scriptType
* {@code ScriptType.DML} or {@code ScriptType.PYDML}
*/
@@ -152,7 +152,7 @@ public class Script {
/**
* Obtain the script string.
- *
+ *
* @return the script string
*/
public String getScriptString() {
@@ -161,7 +161,7 @@ public class Script {
/**
* Set the script string.
- *
+ *
* @param scriptString
* the script string
* @return {@code this} Script object to allow chaining of methods
@@ -173,7 +173,7 @@ public class Script {
/**
* Obtain the input variable names as an unmodifiable set of strings.
- *
+ *
* @return the input variable names
*/
public Set<String> getInputVariables() {
@@ -182,7 +182,7 @@ public class Script {
/**
* Obtain the output variable names as an unmodifiable set of strings.
- *
+ *
* @return the output variable names
*/
public Set<String> getOutputVariables() {
@@ -192,7 +192,7 @@ public class Script {
/**
* Obtain the symbol table, which is essentially a
* {@code HashMap<String, Data>} representing variables and their values.
- *
+ *
* @return the symbol table
*/
public LocalVariableMap getSymbolTable() {
@@ -201,7 +201,7 @@ public class Script {
/**
* Obtain an unmodifiable map of all inputs (parameters ($) and variables).
- *
+ *
* @return all inputs to the script
*/
public Map<String, Object> getInputs() {
@@ -210,7 +210,7 @@ public class Script {
/**
* Obtain an unmodifiable map of input matrix metadata.
- *
+ *
* @return input matrix metadata
*/
public Map<String, MatrixMetadata> getInputMatrixMetadata() {
@@ -219,7 +219,7 @@ public class Script {
/**
* Pass a map of inputs to the script.
- *
+ *
* @param inputs
* map of inputs (parameters ($) and variables).
* @return {@code this} Script object to allow chaining of methods
@@ -232,9 +232,13 @@ public class Script {
return this;
}
+ /**
+ * Pass a map of inputs to the script. Equivalent to {@link #in(Map)}.
+ * This alias exists because {@code in} is a reserved word in Python, so
+ * Python callers (via Py4J) cannot invoke {@code in(...)} by name.
+ *
+ * @param inputs
+ * map of inputs (parameters ($) and variables).
+ * @return {@code this} Script object to allow chaining of methods
+ */
+ public Script input(Map<String, Object> inputs) {
+ return in(inputs);
+ }
+
/**
* Pass a Scala Map of inputs to the script.
- *
+ *
* @param inputs
* Scala Map of inputs (parameters ($) and variables).
* @return {@code this} Script object to allow chaining of methods
@@ -246,12 +250,16 @@ public class Script {
return this;
}
+ /**
+ * Pass a Scala Map of inputs to the script. Equivalent to
+ * {@link #in(scala.collection.Map)}; provided as an alias because
+ * {@code in} is a reserved word in Python.
+ *
+ * @param inputs
+ * Scala Map of inputs (parameters ($) and variables).
+ * @return {@code this} Script object to allow chaining of methods
+ */
+ public Script input(scala.collection.Map<String, Object> inputs) {
+ return in(inputs);
+ }
+
/**
* Pass a Scala Seq of inputs to the script. The inputs are either two-value
* or three-value tuples, where the first value is the variable name, the
* second value is the variable value, and the third optional value is the
* metadata.
- *
+ *
* @param inputs
* Scala Seq of inputs (parameters ($) and variables).
* @return {@code this} Script object to allow chaining of methods
@@ -274,9 +282,13 @@ public class Script {
return this;
}
+ /**
+ * Pass a Scala Seq of input tuples to the script. Equivalent to
+ * {@link #in(scala.collection.Seq)}; provided as an alias because
+ * {@code in} is a reserved word in Python.
+ *
+ * @param inputs
+ * Scala Seq of inputs (parameters ($) and variables).
+ * @return {@code this} Script object to allow chaining of methods
+ */
+ public Script input(scala.collection.Seq<Object> inputs) {
+ return in(inputs);
+ }
+
/**
* Obtain an unmodifiable map of all input parameters ($).
- *
+ *
* @return input parameters ($)
*/
public Map<String, Object> getInputParameters() {
@@ -285,7 +297,7 @@ public class Script {
/**
* Register an input (parameter ($) or variable).
- *
+ *
* @param name
* name of the input
* @param value
@@ -296,10 +308,14 @@ public class Script {
return in(name, value, null);
}
+ /**
+ * Register an input (parameter ($) or variable). Equivalent to
+ * {@link #in(String, Object)}; provided as an alias because {@code in}
+ * is a reserved word in Python. The Python API's
+ * {@code MLContext.execute} registers each input through this method.
+ *
+ * @param name
+ * name of the input
+ * @param value
+ * value of the input
+ * @return {@code this} Script object to allow chaining of methods
+ */
+ public Script input(String name, Object value) {
+ return in(name, value);
+ }
+
/**
* Register an input (parameter ($) or variable) with optional matrix
* metadata.
- *
+ *
* @param name
* name of the input
* @param value
@@ -336,9 +352,13 @@ public class Script {
return this;
}
+ /**
+ * Register an input (parameter ($) or variable) with optional matrix
+ * metadata. Equivalent to {@link #in(String, Object, MatrixMetadata)};
+ * provided as an alias because {@code in} is a reserved word in Python.
+ *
+ * @param name
+ * name of the input
+ * @param value
+ * value of the input
+ * @param matrixMetadata
+ * optional matrix metadata
+ * @return {@code this} Script object to allow chaining of methods
+ */
+ public Script input(String name, Object value, MatrixMetadata matrixMetadata) {
+ return in(name, value, matrixMetadata);
+ }
+
/**
* Register an output variable.
- *
+ *
* @param outputName
* name of the output variable
* @return {@code this} Script object to allow chaining of methods
@@ -350,7 +370,7 @@ public class Script {
/**
* Register output variables.
- *
+ *
* @param outputNames
* names of the output variables
* @return {@code this} Script object to allow chaining of methods
@@ -411,7 +431,7 @@ public class Script {
/**
* Obtain the results of the script execution.
- *
+ *
* @return the results of the script execution.
*/
public MLResults results() {
@@ -420,7 +440,7 @@ public class Script {
/**
* Obtain the results of the script execution.
- *
+ *
* @return the results of the script execution.
*/
public MLResults getResults() {
@@ -429,7 +449,7 @@ public class Script {
/**
* Set the results of the script execution.
- *
+ *
* @param results
* the results of the script execution.
*/
@@ -439,7 +459,7 @@ public class Script {
/**
* Obtain the script executor used by this Script.
- *
+ *
* @return the ScriptExecutor used by this Script.
*/
public ScriptExecutor getScriptExecutor() {
@@ -448,7 +468,7 @@ public class Script {
/**
* Set the ScriptExecutor used by this Script.
- *
+ *
* @param scriptExecutor
* the script executor
*/
@@ -458,7 +478,7 @@ public class Script {
/**
* Is the script type DML?
- *
+ *
* @return {@code true} if the script type is DML, {@code false} otherwise
*/
public boolean isDML() {
@@ -467,7 +487,7 @@ public class Script {
/**
* Is the script type PYDML?
- *
+ *
* @return {@code true} if the script type is PYDML, {@code false} otherwise
*/
public boolean isPYDML() {
@@ -477,7 +497,7 @@ public class Script {
/**
* Generate the script execution string, which adds read/load/write/save
* statements to the beginning and end of the script to execute.
- *
+ *
* @return the script execution string
*/
public String getScriptExecutionString() {
@@ -545,7 +565,7 @@ public class Script {
* script type, inputs, outputs, input parameters, input variables, output
* variables, the symbol table, the script string, and the script execution
* string.
- *
+ *
* @return information about this script as a String
*/
public String info() {
@@ -576,7 +596,7 @@ public class Script {
/**
* Display the script inputs.
- *
+ *
* @return the script inputs
*/
public String displayInputs() {
@@ -585,7 +605,7 @@ public class Script {
/**
* Display the script outputs.
- *
+ *
* @return the script outputs as a String
*/
public String displayOutputs() {
@@ -594,7 +614,7 @@ public class Script {
/**
* Display the script input parameters.
- *
+ *
* @return the script input parameters as a String
*/
public String displayInputParameters() {
@@ -603,7 +623,7 @@ public class Script {
/**
* Display the script input variables.
- *
+ *
* @return the script input variables as a String
*/
public String displayInputVariables() {
@@ -612,7 +632,7 @@ public class Script {
/**
* Display the script output variables.
- *
+ *
* @return the script output variables as a String
*/
public String displayOutputVariables() {
@@ -621,7 +641,7 @@ public class Script {
/**
* Display the script symbol table.
- *
+ *
* @return the script symbol table as a String
*/
public String displaySymbolTable() {
@@ -630,7 +650,7 @@ public class Script {
/**
* Obtain the script name.
- *
+ *
* @return the script name
*/
public String getName() {
@@ -639,7 +659,7 @@ public class Script {
/**
* Set the script name.
- *
+ *
* @param name
* the script name
* @return {@code this} Script object to allow chaining of methods
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4fff6f76/src/main/python/SystemML.py
----------------------------------------------------------------------
diff --git a/src/main/python/SystemML.py b/src/main/python/SystemML.py
new file mode 100644
index 0000000..85731ed
--- /dev/null
+++ b/src/main/python/SystemML.py
@@ -0,0 +1,203 @@
+#!/usr/bin/python
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+import os
+
+from py4j.java_gateway import JavaObject
+from py4j.java_collections import ListConverter, JavaArray, JavaList
+from pyspark import SparkContext, RDD
+from pyspark.mllib.common import _java2py, _py2java
+from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
+from pyspark.sql import DataFrame
+
+
+class MLResults(object):
+ """
+ Wrapper around the Java ML Results object.
+
+ Parameters
+ ----------
+ results: JavaObject
+ A Java MLResults object as returned by calling ml.execute()
+
+ sc: SparkContext
+ SparkContext
+ """
+ def __init__(self, results, sc):
+ self._java_results = results
+ self.sc = sc
+
+ def __repr__(self):
+ return "MLResults"
+
+ def get(self, *outputs):
+ """
+ Parameters
+ ----------
+ outputs: string, list of strings
+ Output variables as defined inside the DML script.
+ """
+ outs = [_java2py(self.sc, self._java_results.get(out)) for out in outputs]
+ if len(outs) == 1:
+ return outs[0]
+ return outs
+
+
+class Script(object):
+ """
+ Instance of a DML/PyDML Script.
+
+ Parameters
+ ----------
+ path: string
+ Can be either a file path to a DML script or a DML script itself.
+ """
+ def __init__(self, scriptString, scriptType="dml"):
+ self.scriptString = scriptString
+ self.scriptType = scriptType
+ self._input = {}
+ self._output = []
+
+ def input(self, *args, **kwargs):
+ """
+ Parameters
+ ----------
+ args: name, value tuple
+ where name is a string and currently supported value formats
+ are double, string, rdds and list of such object.
+
+ kwargs: dict of name, value pairs
+ To know what formats are supported for name and value, look above.
+ """
+ if args and len(args) != 2:
+ raise ValueError("Expected name, value pair.")
+ elif args:
+ self._input[args[0]] = args[1]
+ for name, value in kwargs.items():
+ self._input[name] = value
+ return self
+
+ def out(self, *names):
+ """
+ Parameters
+ ----------
+ outputs: string, list of strings
+ Output variables as defined inside the DML script.
+ """
+ self._output.extend(names)
+ return self
+
+
+def pydml(scriptString):
+ """
+ Create a pydml script object based on a string.
+
+ Parameters
+ ----------
+ scriptString: string
+ Can be a path to a pydml script or a pydml script itself.
+
+ Returns
+ -------
+ script: Script instance
+ Instance of a script object.
+ """
+ if not isinstance(scriptString, str):
+ raise ValueError("scriptString should be a string, got %s" % type(scriptString))
+ return Script(scriptString, scriptType="pydml")
+
+
+def dml(scriptString):
+ """
+ Create a dml script object based on a string.
+
+ Parameters
+ ----------
+ scriptString: string
+ Can be a path to a dml script or a dml script itself.
+
+ Returns
+ -------
+ script: Script instance
+ Instance of a script object.
+ """
+ if not isinstance(scriptString, str):
+ raise ValueError("scriptString should be a string, got %s" % type(scriptString))
+ return Script(scriptString, scriptType="dml")
+
+
+class MLContext(object):
+ """
+ Wrapper around the new SystemML MLContext.
+
+ Parameters
+ ----------
+ sc: SparkContext
+ SparkContext
+ """
+ def __init__(self, sc):
+ if not isinstance(sc, SparkContext):
+ raise ValueError("Expected sc to be a SparkContext, got " % sc)
+ self._sc = sc
+ self._ml = sc._jvm.org.apache.sysml.api.mlcontext.MLContext(sc._jsc)
+
+ def __repr__(self):
+ return "MLContext"
+
+ def execute(self, script):
+ """
+ Execute a DML / PyDML script.
+
+ Parameters
+ ----------
+ script: Script instance
+ Script instance defined with the appropriate input and output variables.
+
+ Returns
+ -------
+ ml_results: MLResults
+ MLResults instance.
+ """
+ if not isinstance(script, Script):
+ raise ValueError("Expected script to be an instance of Script")
+ scriptString = script.scriptString
+ if script.scriptType == "dml":
+ if scriptString.endswith(".dml"):
+ if os.path.exists(scriptString):
+ script_java = self._sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile(scriptString)
+ else:
+ raise ValueError("path: %s does not exist" % scriptString)
+ else:
+ script_java = self._sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.dml(scriptString)
+ elif script.scriptType == "pydml":
+ if scriptString.endswith(".pydml"):
+ if os.path.exists(scriptString):
+ script_java = self._sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromFile(scriptString)
+ else:
+ raise ValueError("path: %s does not exist" % scriptString)
+ else:
+ script_java = self._sc._jvm.org.apache.sysml.api.mlcontext.ScriptFactory.pydml(scriptString)
+
+ for key, val in script._input.items():
+ script_java.input(key, _py2java(self._sc, val))
+ for val in script._output:
+ script_java.out(val)
+ return MLResults(self._ml.execute(script_java), self._sc)
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/4fff6f76/src/main/python/SystemMLtests.py
----------------------------------------------------------------------
diff --git a/src/main/python/SystemMLtests.py b/src/main/python/SystemMLtests.py
new file mode 100644
index 0000000..5dcae4a
--- /dev/null
+++ b/src/main/python/SystemMLtests.py
@@ -0,0 +1,87 @@
+#!/usr/bin/python
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+import unittest
+
+from pyspark.sql import SQLContext
+from pyspark.context import SparkContext
+
+from SystemML import dml
+from SystemML import pydml
+from SystemML import MLContext
+
+sc = SparkContext()
+ml = MLContext(sc)
+
+class TestAPI(unittest.TestCase):
+
+ def test_output_string(self):
+ script = dml("x1 = 'Hello World'").out("x1")
+ self.assertEqual(ml.execute(script).get("x1"), "Hello World")
+
+ def test_output_list(self):
+ script = """
+ x1 = 0.2
+ x2 = x1 + 1
+ x3 = x1 + 2
+ """
+ script = dml(script).out("x1", "x2", "x3")
+ self.assertEqual(ml.execute(script).get("x1", "x2"), [0.2, 1.2])
+ self.assertEqual(ml.execute(script).get("x1", "x3"), [0.2, 2.2])
+
+ def test_input_single(self):
+ script = """
+ x2 = x1 + 1
+ x3 = x1 + 2
+ """
+ script = dml(script).input("x1", 5).out("x2", "x3")
+ self.assertEqual(ml.execute(script).get("x2", "x3"), [6, 7])
+
+ def test_input(self):
+ script = """
+ x3 = x1 + x2
+ """
+ script = dml(script).input(x1=5, x2=3).out("x3")
+ self.assertEqual(ml.execute(script).get("x3"), 8)
+
+ def test_rdd(self):
+ sums = """
+ s1 = sum(m1)
+ s2 = sum(m2)
+ s3 = 'whatever'
+ """
+ rdd1 = sc.parallelize(["1.0,2.0", "3.0,4.0"])
+ rdd2 = sc.parallelize(["5.0,6.0", "7.0,8.0"])
+ script = dml(sums).input(m1=rdd1).input(m2=rdd2).out("s1", "s2", "s3")
+ self.assertEqual(
+ ml.execute(script).get("s1", "s2", "s3"), [10.0, 26.0, "whatever"])
+
+ def test_pydml(self):
+ script = "A = full('1 2 3 4 5 6 7 8 9', rows=3, cols=3)\nx = toString(A)"
+ script = pydml(script).out("x")
+ self.assertEqual(
+ ml.execute(script).get("x"),
+ '1.000 2.000 3.000\n4.000 5.000 6.000\n7.000 8.000 9.000\n'
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()