You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by de...@apache.org on 2016/07/29 00:14:38 UTC
[1/4] incubator-systemml git commit: [SYSTEMML-593] MLContext redesign
Repository: incubator-systemml
Updated Branches:
refs/heads/master 873bae76b -> 457bbd3a4
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
new file mode 100644
index 0000000..4441105
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
@@ -0,0 +1,1713 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.mlcontext;
+
+import static org.apache.sysml.api.mlcontext.ScriptFactory.dml;
+import static org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile;
+import static org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromInputStream;
+import static org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromLocalFile;
+import static org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromUrl;
+import static org.apache.sysml.api.mlcontext.ScriptFactory.pydml;
+import static org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromFile;
+import static org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromInputStream;
+import static org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromLocalFile;
+import static org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromUrl;
+import static org.junit.Assert.assertTrue;
+
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.sysml.api.mlcontext.BinaryBlockMatrix;
+import org.apache.sysml.api.mlcontext.MLContext;
+import org.apache.sysml.api.mlcontext.MLContextConversionUtil;
+import org.apache.sysml.api.mlcontext.MLContextException;
+import org.apache.sysml.api.mlcontext.MLResults;
+import org.apache.sysml.api.mlcontext.MatrixFormat;
+import org.apache.sysml.api.mlcontext.MatrixMetadata;
+import org.apache.sysml.api.mlcontext.Script;
+import org.apache.sysml.api.mlcontext.ScriptExecutor;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Test;
+
+import scala.Tuple2;
+import scala.Tuple3;
+import scala.collection.Iterator;
+import scala.collection.JavaConversions;
+import scala.collection.Seq;
+
+public class MLContextTest extends AutomatedTestBase {
// Location and name of this test's configuration within the test harness.
protected final static String TEST_DIR = "org/apache/sysml/api/mlcontext";
protected final static String TEST_NAME = "MLContext";

// Spark config/context are static so the JVM-wide SparkContext is created
// exactly once across all tests (Spark permits only one active context per JVM).
static SparkConf conf;
static JavaSparkContext sc;
// Fresh MLContext per test; rebuilt in setUp() against the shared context.
MLContext ml;

@Override
public void setUp() {
    // Register and load this test's configuration with AutomatedTestBase.
    addTestConfiguration(TEST_DIR, TEST_NAME);
    getAndLoadTestConfiguration(TEST_NAME);

    // Lazily initialize the local Spark context on first use only.
    if (conf == null) {
        conf = new SparkConf().setAppName("MLContextTest").setMaster("local");
    }
    if (sc == null) {
        sc = new JavaSparkContext(conf);
    }
    ml = new MLContext(sc);
    // ml.setExplain(true);
}
+
+ @Test
+ public void testCreateDMLScriptBasedOnStringAndExecute() {
+ System.out.println("MLContextTest - create DML script based on string and execute");
+ String testString = "Create DML script based on string and execute";
+ setExpectedStdOut(testString);
+ Script script = dml("print('" + testString + "');");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testCreatePYDMLScriptBasedOnStringAndExecute() {
+ System.out.println("MLContextTest - create PYDML script based on string and execute");
+ String testString = "Create PYDML script based on string and execute";
+ setExpectedStdOut(testString);
+ Script script = pydml("print('" + testString + "')");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testCreateDMLScriptBasedOnFileAndExecute() {
+ System.out.println("MLContextTest - create DML script based on file and execute");
+ setExpectedStdOut("hello world");
+ Script script = dmlFromFile(baseDirectory + File.separator + "hello-world.dml");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testCreatePYDMLScriptBasedOnFileAndExecute() {
+ System.out.println("MLContextTest - create PYDML script based on file and execute");
+ setExpectedStdOut("hello world");
+ Script script = pydmlFromFile(baseDirectory + File.separator + "hello-world.pydml");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testCreateDMLScriptBasedOnInputStreamAndExecute() throws FileNotFoundException {
+ System.out.println("MLContextTest - create DML script based on InputStream and execute");
+ setExpectedStdOut("hello world");
+ File file = new File(baseDirectory + File.separator + "hello-world.dml");
+ InputStream is = new FileInputStream(file);
+ Script script = dmlFromInputStream(is);
+ ml.execute(script);
+ }
+
+ @Test
+ public void testCreatePYDMLScriptBasedOnInputStreamAndExecute() throws FileNotFoundException {
+ System.out.println("MLContextTest - create PYDML script based on InputStream and execute");
+ setExpectedStdOut("hello world");
+ File file = new File(baseDirectory + File.separator + "hello-world.pydml");
+ InputStream is = new FileInputStream(file);
+ Script script = pydmlFromInputStream(is);
+ ml.execute(script);
+ }
+
+ @Test
+ public void testCreateDMLScriptBasedOnLocalFileAndExecute() {
+ System.out.println("MLContextTest - create DML script based on local file and execute");
+ setExpectedStdOut("hello world");
+ File file = new File(baseDirectory + File.separator + "hello-world.dml");
+ Script script = dmlFromLocalFile(file);
+ ml.execute(script);
+ }
+
+ @Test
+ public void testCreatePYDMLScriptBasedOnLocalFileAndExecute() {
+ System.out.println("MLContextTest - create PYDML script based on local file and execute");
+ setExpectedStdOut("hello world");
+ File file = new File(baseDirectory + File.separator + "hello-world.pydml");
+ Script script = pydmlFromLocalFile(file);
+ ml.execute(script);
+ }
+
+ @Test
+ public void testCreateDMLScriptBasedOnURL() throws MalformedURLException {
+ System.out.println("MLContextTest - create DML script based on URL");
+ String urlString = "https://raw.githubusercontent.com/apache/incubator-systemml/master/src/test/scripts/applications/hits/HITS.dml";
+ URL url = new URL(urlString);
+ Script script = dmlFromUrl(url);
+ String expectedContent = "Licensed to the Apache Software Foundation";
+ String s = script.getScriptString();
+ assertTrue("Script string doesn't contain expected content: " + expectedContent, s.contains(expectedContent));
+ }
+
+ @Test
+ public void testCreatePYDMLScriptBasedOnURL() throws MalformedURLException {
+ System.out.println("MLContextTest - create PYDML script based on URL");
+ String urlString = "https://raw.githubusercontent.com/apache/incubator-systemml/master/src/test/scripts/applications/hits/HITS.pydml";
+ URL url = new URL(urlString);
+ Script script = pydmlFromUrl(url);
+ String expectedContent = "Licensed to the Apache Software Foundation";
+ String s = script.getScriptString();
+ assertTrue("Script string doesn't contain expected content: " + expectedContent, s.contains(expectedContent));
+ }
+
+ @Test
+ public void testCreateDMLScriptBasedOnURLString() throws MalformedURLException {
+ System.out.println("MLContextTest - create DML script based on URL string");
+ String urlString = "https://raw.githubusercontent.com/apache/incubator-systemml/master/src/test/scripts/applications/hits/HITS.dml";
+ Script script = dmlFromUrl(urlString);
+ String expectedContent = "Licensed to the Apache Software Foundation";
+ String s = script.getScriptString();
+ assertTrue("Script string doesn't contain expected content: " + expectedContent, s.contains(expectedContent));
+ }
+
+ @Test
+ public void testCreatePYDMLScriptBasedOnURLString() throws MalformedURLException {
+ System.out.println("MLContextTest - create PYDML script based on URL string");
+ String urlString = "https://raw.githubusercontent.com/apache/incubator-systemml/master/src/test/scripts/applications/hits/HITS.pydml";
+ Script script = pydmlFromUrl(urlString);
+ String expectedContent = "Licensed to the Apache Software Foundation";
+ String s = script.getScriptString();
+ assertTrue("Script string doesn't contain expected content: " + expectedContent, s.contains(expectedContent));
+ }
+
+ @Test
+ public void testExecuteDMLScript() {
+ System.out.println("MLContextTest - execute DML script");
+ String testString = "hello dml world!";
+ setExpectedStdOut(testString);
+ Script script = new Script("print('" + testString + "');", org.apache.sysml.api.mlcontext.ScriptType.DML);
+ ml.execute(script);
+ }
+
+ @Test
+ public void testExecutePYDMLScript() {
+ System.out.println("MLContextTest - execute PYDML script");
+ String testString = "hello pydml world!";
+ setExpectedStdOut(testString);
+ Script script = new Script("print('" + testString + "')", org.apache.sysml.api.mlcontext.ScriptType.PYDML);
+ ml.execute(script);
+ }
+
+ @Test
+ public void testInputParametersAddDML() {
+ System.out.println("MLContextTest - input parameters add DML");
+
+ String s = "x = $X; y = $Y; print('x + y = ' + (x + y));";
+ Script script = dml(s).in("$X", 3).in("$Y", 4);
+ setExpectedStdOut("x + y = 7");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testInputParametersAddPYDML() {
+ System.out.println("MLContextTest - input parameters add PYDML");
+
+ String s = "x = $X\ny = $Y\nprint('x + y = ' + (x + y))";
+ Script script = pydml(s).in("$X", 3).in("$Y", 4);
+ setExpectedStdOut("x + y = 7");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testJavaRDDCSVSumDML() {
+ System.out.println("MLContextTest - JavaRDD<String> CSV sum DML");
+
+ List<String> list = new ArrayList<String>();
+ list.add("1,2,3");
+ list.add("4,5,6");
+ list.add("7,8,9");
+ JavaRDD<String> javaRDD = sc.parallelize(list);
+
+ Script script = dml("print('sum: ' + sum(M));").in("M", javaRDD);
+ setExpectedStdOut("sum: 45.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testJavaRDDCSVSumPYDML() {
+ System.out.println("MLContextTest - JavaRDD<String> CSV sum PYDML");
+
+ List<String> list = new ArrayList<String>();
+ list.add("1,2,3");
+ list.add("4,5,6");
+ list.add("7,8,9");
+ JavaRDD<String> javaRDD = sc.parallelize(list);
+
+ Script script = pydml("print('sum: ' + sum(M))").in("M", javaRDD);
+ setExpectedStdOut("sum: 45.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testJavaRDDIJVSumDML() {
+ System.out.println("MLContextTest - JavaRDD<String> IJV sum DML");
+
+ List<String> list = new ArrayList<String>();
+ list.add("1 1 5");
+ list.add("2 2 5");
+ list.add("3 3 5");
+ JavaRDD<String> javaRDD = sc.parallelize(list);
+
+ MatrixMetadata mm = new MatrixMetadata(MatrixFormat.IJV, 3, 3);
+
+ Script script = dml("print('sum: ' + sum(M));").in("M", javaRDD, mm);
+ setExpectedStdOut("sum: 15.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testJavaRDDIJVSumPYDML() {
+ System.out.println("MLContextTest - JavaRDD<String> IJV sum PYDML");
+
+ List<String> list = new ArrayList<String>();
+ list.add("1 1 5");
+ list.add("2 2 5");
+ list.add("3 3 5");
+ JavaRDD<String> javaRDD = sc.parallelize(list);
+
+ MatrixMetadata mm = new MatrixMetadata(MatrixFormat.IJV, 3, 3);
+
+ Script script = pydml("print('sum: ' + sum(M))").in("M", javaRDD, mm);
+ setExpectedStdOut("sum: 15.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testJavaRDDAndInputParameterDML() {
+ System.out.println("MLContextTest - JavaRDD<String> and input parameter DML");
+
+ List<String> list = new ArrayList<String>();
+ list.add("1,2");
+ list.add("3,4");
+ JavaRDD<String> javaRDD = sc.parallelize(list);
+
+ String s = "M = M + $X; print('sum: ' + sum(M));";
+ Script script = dml(s).in("M", javaRDD).in("$X", 1);
+ setExpectedStdOut("sum: 14.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testJavaRDDAndInputParameterPYDML() {
+ System.out.println("MLContextTest - JavaRDD<String> and input parameter PYDML");
+
+ List<String> list = new ArrayList<String>();
+ list.add("1,2");
+ list.add("3,4");
+ JavaRDD<String> javaRDD = sc.parallelize(list);
+
+ String s = "M = M + $X\nprint('sum: ' + sum(M))";
+ Script script = pydml(s).in("M", javaRDD).in("$X", 1);
+ setExpectedStdOut("sum: 14.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testInputMapDML() {
+ System.out.println("MLContextTest - input map DML");
+
+ List<String> list = new ArrayList<String>();
+ list.add("10,20");
+ list.add("30,40");
+ final JavaRDD<String> javaRDD = sc.parallelize(list);
+
+ Map<String, Object> inputs = new HashMap<String, Object>() {
+ private static final long serialVersionUID = 1L;
+ {
+ put("$X", 2);
+ put("M", javaRDD);
+ }
+ };
+
+ String s = "M = M + $X; print('sum: ' + sum(M));";
+ Script script = dml(s).in(inputs);
+ setExpectedStdOut("sum: 108.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testInputMapPYDML() {
+ System.out.println("MLContextTest - input map PYDML");
+
+ List<String> list = new ArrayList<String>();
+ list.add("10,20");
+ list.add("30,40");
+ final JavaRDD<String> javaRDD = sc.parallelize(list);
+
+ Map<String, Object> inputs = new HashMap<String, Object>() {
+ private static final long serialVersionUID = 1L;
+ {
+ put("$X", 2);
+ put("M", javaRDD);
+ }
+ };
+
+ String s = "M = M + $X\nprint('sum: ' + sum(M))";
+ Script script = pydml(s).in(inputs);
+ setExpectedStdOut("sum: 108.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testCustomExecutionStepDML() {
+ System.out.println("MLContextTest - custom execution step DML");
+ String testString = "custom execution step";
+ setExpectedStdOut(testString);
+ Script script = new Script("print('" + testString + "');", org.apache.sysml.api.mlcontext.ScriptType.DML);
+
+ ScriptExecutor scriptExecutor = new ScriptExecutor() {
+ // turn off global data flow optimization check
+ @Override
+ protected void globalDataFlowOptimization() {
+ return;
+ }
+ };
+ ml.execute(script, scriptExecutor);
+ }
+
+ @Test
+ public void testCustomExecutionStepPYDML() {
+ System.out.println("MLContextTest - custom execution step PYDML");
+ String testString = "custom execution step";
+ setExpectedStdOut(testString);
+ Script script = new Script("print('" + testString + "')", org.apache.sysml.api.mlcontext.ScriptType.PYDML);
+
+ ScriptExecutor scriptExecutor = new ScriptExecutor() {
+ // turn off global data flow optimization check
+ @Override
+ protected void globalDataFlowOptimization() {
+ return;
+ }
+ };
+ ml.execute(script, scriptExecutor);
+ }
+
+ @Test
+ public void testRDDSumCSVDML() {
+ System.out.println("MLContextTest - RDD<String> CSV sum DML");
+
+ List<String> list = new ArrayList<String>();
+ list.add("1,1,1");
+ list.add("2,2,2");
+ list.add("3,3,3");
+ JavaRDD<String> javaRDD = sc.parallelize(list);
+ RDD<String> rdd = JavaRDD.toRDD(javaRDD);
+
+ Script script = dml("print('sum: ' + sum(M));").in("M", rdd);
+ setExpectedStdOut("sum: 18.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testRDDSumCSVPYDML() {
+ System.out.println("MLContextTest - RDD<String> CSV sum PYDML");
+
+ List<String> list = new ArrayList<String>();
+ list.add("1,1,1");
+ list.add("2,2,2");
+ list.add("3,3,3");
+ JavaRDD<String> javaRDD = sc.parallelize(list);
+ RDD<String> rdd = JavaRDD.toRDD(javaRDD);
+
+ Script script = pydml("print('sum: ' + sum(M))").in("M", rdd);
+ setExpectedStdOut("sum: 18.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testRDDSumIJVDML() {
+ System.out.println("MLContextTest - RDD<String> IJV sum DML");
+
+ List<String> list = new ArrayList<String>();
+ list.add("1 1 1");
+ list.add("2 1 2");
+ list.add("1 2 3");
+ list.add("3 3 4");
+ JavaRDD<String> javaRDD = sc.parallelize(list);
+ RDD<String> rdd = JavaRDD.toRDD(javaRDD);
+
+ MatrixMetadata mm = new MatrixMetadata(MatrixFormat.IJV, 3, 3);
+
+ Script script = dml("print('sum: ' + sum(M));").in("M", rdd, mm);
+ setExpectedStdOut("sum: 10.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testRDDSumIJVPYDML() {
+ System.out.println("MLContextTest - RDD<String> IJV sum PYDML");
+
+ List<String> list = new ArrayList<String>();
+ list.add("1 1 1");
+ list.add("2 1 2");
+ list.add("1 2 3");
+ list.add("3 3 4");
+ JavaRDD<String> javaRDD = sc.parallelize(list);
+ RDD<String> rdd = JavaRDD.toRDD(javaRDD);
+
+ MatrixMetadata mm = new MatrixMetadata(MatrixFormat.IJV, 3, 3);
+
+ Script script = pydml("print('sum: ' + sum(M))").in("M", rdd, mm);
+ setExpectedStdOut("sum: 10.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testDataFrameSumDML() {
+ System.out.println("MLContextTest - DataFrame sum DML");
+
+ List<String> list = new ArrayList<String>();
+ list.add("10,20,30");
+ list.add("40,50,60");
+ list.add("70,80,90");
+ JavaRDD<String> javaRddString = sc.parallelize(list);
+
+ JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+ SQLContext sqlContext = new SQLContext(sc);
+ List<StructField> fields = new ArrayList<StructField>();
+ fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+ StructType schema = DataTypes.createStructType(fields);
+ DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
+
+ Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame);
+ setExpectedStdOut("sum: 450.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testDataFrameSumPYDML() {
+ System.out.println("MLContextTest - DataFrame sum PYDML");
+
+ List<String> list = new ArrayList<String>();
+ list.add("10,20,30");
+ list.add("40,50,60");
+ list.add("70,80,90");
+ JavaRDD<String> javaRddString = sc.parallelize(list);
+
+ JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+ SQLContext sqlContext = new SQLContext(sc);
+ List<StructField> fields = new ArrayList<StructField>();
+ fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+ StructType schema = DataTypes.createStructType(fields);
+ DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
+
+ Script script = pydml("print('sum: ' + sum(M))").in("M", dataFrame);
+ setExpectedStdOut("sum: 450.0");
+ ml.execute(script);
+ }
+
+ static class CommaSeparatedValueStringToRow implements Function<String, Row> {
+ private static final long serialVersionUID = -7871020122671747808L;
+
+ public Row call(String str) throws Exception {
+ String[] fields = str.split(",");
+ return RowFactory.create((Object[]) fields);
+ }
+ }
+
+ @Test
+ public void testCSVMatrixFileInputParameterSumDML() {
+ System.out.println("MLContextTest - CSV matrix file input parameter sum DML");
+
+ String s = "M = read($Min); print('sum: ' + sum(M));";
+ String csvFile = baseDirectory + File.separator + "1234.csv";
+ setExpectedStdOut("sum: 10.0");
+ ml.execute(dml(s).in("$Min", csvFile));
+ }
+
+ @Test
+ public void testCSVMatrixFileInputParameterSumPYDML() {
+ System.out.println("MLContextTest - CSV matrix file input parameter sum PYDML");
+
+ String s = "M = load($Min)\nprint('sum: ' + sum(M))";
+ String csvFile = baseDirectory + File.separator + "1234.csv";
+ setExpectedStdOut("sum: 10.0");
+ ml.execute(pydml(s).in("$Min", csvFile));
+ }
+
+ @Test
+ public void testCSVMatrixFileInputVariableSumDML() {
+ System.out.println("MLContextTest - CSV matrix file input variable sum DML");
+
+ String s = "M = read(Min); print('sum: ' + sum(M));";
+ String csvFile = baseDirectory + File.separator + "1234.csv";
+ setExpectedStdOut("sum: 10.0");
+ ml.execute(dml(s).in("Min", csvFile));
+ }
+
+ @Test
+ public void testCSVMatrixFileInputVariableSumPYDML() {
+ System.out.println("MLContextTest - CSV matrix file input variable sum PYDML");
+
+ String s = "M = load(Min)\nprint('sum: ' + sum(M))";
+ String csvFile = baseDirectory + File.separator + "1234.csv";
+ setExpectedStdOut("sum: 10.0");
+ ml.execute(pydml(s).in("Min", csvFile));
+ }
+
+ @Test
+ public void test2DDoubleSumDML() {
+ System.out.println("MLContextTest - two-dimensional double array sum DML");
+
+ double[][] matrix = new double[][] { { 10.0, 20.0 }, { 30.0, 40.0 } };
+
+ Script script = dml("print('sum: ' + sum(M));").in("M", matrix);
+ setExpectedStdOut("sum: 100.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void test2DDoubleSumPYDML() {
+ System.out.println("MLContextTest - two-dimensional double array sum PYDML");
+
+ double[][] matrix = new double[][] { { 10.0, 20.0 }, { 30.0, 40.0 } };
+
+ Script script = pydml("print('sum: ' + sum(M))").in("M", matrix);
+ setExpectedStdOut("sum: 100.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testAddScalarIntegerInputsDML() {
+ System.out.println("MLContextTest - add scalar integer inputs DML");
+ String s = "total = in1 + in2; print('total: ' + total);";
+ Script script = dml(s).in("in1", 1).in("in2", 2);
+ setExpectedStdOut("total: 3");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testAddScalarIntegerInputsPYDML() {
+ System.out.println("MLContextTest - add scalar integer inputs PYDML");
+ String s = "total = in1 + in2\nprint('total: ' + total)";
+ Script script = pydml(s).in("in1", 1).in("in2", 2);
+ setExpectedStdOut("total: 3");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testInputScalaMapDML() {
+ System.out.println("MLContextTest - input Scala map DML");
+
+ List<String> list = new ArrayList<String>();
+ list.add("10,20");
+ list.add("30,40");
+ final JavaRDD<String> javaRDD = sc.parallelize(list);
+
+ Map<String, Object> inputs = new HashMap<String, Object>() {
+ private static final long serialVersionUID = 1L;
+ {
+ put("$X", 2);
+ put("M", javaRDD);
+ }
+ };
+
+ scala.collection.mutable.Map<String, Object> scalaMap = JavaConversions.asScalaMap(inputs);
+
+ String s = "M = M + $X; print('sum: ' + sum(M));";
+ Script script = dml(s).in(scalaMap);
+ setExpectedStdOut("sum: 108.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testInputScalaMapPYDML() {
+ System.out.println("MLContextTest - input Scala map PYDML");
+
+ List<String> list = new ArrayList<String>();
+ list.add("10,20");
+ list.add("30,40");
+ final JavaRDD<String> javaRDD = sc.parallelize(list);
+
+ Map<String, Object> inputs = new HashMap<String, Object>() {
+ private static final long serialVersionUID = 1L;
+ {
+ put("$X", 2);
+ put("M", javaRDD);
+ }
+ };
+
+ scala.collection.mutable.Map<String, Object> scalaMap = JavaConversions.asScalaMap(inputs);
+
+ String s = "M = M + $X\nprint('sum: ' + sum(M))";
+ Script script = pydml(s).in(scalaMap);
+ setExpectedStdOut("sum: 108.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testOutputDoubleArrayMatrixDML() {
+ System.out.println("MLContextTest - output double array matrix DML");
+ String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
+ double[][] matrix = ml.execute(dml(s).out("M")).getDoubleMatrix("M");
+ Assert.assertEquals(1.0, matrix[0][0], 0);
+ Assert.assertEquals(2.0, matrix[0][1], 0);
+ Assert.assertEquals(3.0, matrix[1][0], 0);
+ Assert.assertEquals(4.0, matrix[1][1], 0);
+ }
+
+ @Test
+ public void testOutputDoubleArrayMatrixPYDML() {
+ System.out.println("MLContextTest - output double array matrix PYDML");
+ String s = "M = full('1 2 3 4', rows=2, cols=2)";
+ double[][] matrix = ml.execute(pydml(s).out("M")).getDoubleMatrix("M");
+ Assert.assertEquals(1.0, matrix[0][0], 0);
+ Assert.assertEquals(2.0, matrix[0][1], 0);
+ Assert.assertEquals(3.0, matrix[1][0], 0);
+ Assert.assertEquals(4.0, matrix[1][1], 0);
+ }
+
+ @Test
+ public void testOutputScalarLongDML() {
+ System.out.println("MLContextTest - output scalar long DML");
+ String s = "m = 5;";
+ long result = ml.execute(dml(s).out("m")).getLong("m");
+ Assert.assertEquals(5, result);
+ }
+
+ @Test
+ public void testOutputScalarLongPYDML() {
+ System.out.println("MLContextTest - output scalar long PYDML");
+ String s = "m = 5";
+ long result = ml.execute(pydml(s).out("m")).getLong("m");
+ Assert.assertEquals(5, result);
+ }
+
+ @Test
+ public void testOutputScalarDoubleDML() {
+ System.out.println("MLContextTest - output scalar double DML");
+ String s = "m = 1.23";
+ double result = ml.execute(dml(s).out("m")).getDouble("m");
+ Assert.assertEquals(1.23, result, 0);
+ }
+
+ @Test
+ public void testOutputScalarDoublePYDML() {
+ System.out.println("MLContextTest - output scalar double PYDML");
+ String s = "m = 1.23";
+ double result = ml.execute(pydml(s).out("m")).getDouble("m");
+ Assert.assertEquals(1.23, result, 0);
+ }
+
+ @Test
+ public void testOutputScalarBooleanDML() {
+ System.out.println("MLContextTest - output scalar boolean DML");
+ String s = "m = FALSE;";
+ boolean result = ml.execute(dml(s).out("m")).getBoolean("m");
+ Assert.assertEquals(false, result);
+ }
+
+ @Test
+ public void testOutputScalarBooleanPYDML() {
+ System.out.println("MLContextTest - output scalar boolean PYDML");
+ String s = "m = False";
+ boolean result = ml.execute(pydml(s).out("m")).getBoolean("m");
+ Assert.assertEquals(false, result);
+ }
+
+ @Test
+ public void testOutputScalarStringDML() {
+ System.out.println("MLContextTest - output scalar string DML");
+ String s = "m = 'hello';";
+ String result = ml.execute(dml(s).out("m")).getString("m");
+ Assert.assertEquals("hello", result);
+ }
+
+ @Test
+ public void testOutputScalarStringPYDML() {
+ System.out.println("MLContextTest - output scalar string PYDML");
+ String s = "m = 'hello'";
+ String result = ml.execute(pydml(s).out("m")).getString("m");
+ Assert.assertEquals("hello", result);
+ }
+
+ @Test
+ public void testInputFrameDML() {
+ System.out.println("MLContextTest - input frame DML");
+
+ String s = "M = read(Min, data_type='frame', format='csv'); print(toString(M));";
+ String csvFile = baseDirectory + File.separator + "one-two-three-four.csv";
+ Script script = dml(s).in("Min", csvFile);
+ setExpectedStdOut("one");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testInputFramePYDML() {
+ System.out.println("MLContextTest - input frame PYDML");
+
+ String s = "M = load(Min, data_type='frame', format='csv')\nprint(toString(M))";
+ String csvFile = baseDirectory + File.separator + "one-two-three-four.csv";
+ Script script = pydml(s).in("Min", csvFile);
+ setExpectedStdOut("one");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testOutputFrameDML() {
+ System.out.println("MLContextTest - output frame DML");
+
+ String s = "M = read(Min, data_type='frame', format='csv');";
+ String csvFile = baseDirectory + File.separator + "one-two-three-four.csv";
+ Script script = dml(s).in("Min", csvFile).out("M");
+ String[][] frame = ml.execute(script).getFrame("M");
+ Assert.assertEquals("one", frame[0][0]);
+ Assert.assertEquals("two", frame[0][1]);
+ Assert.assertEquals("three", frame[1][0]);
+ Assert.assertEquals("four", frame[1][1]);
+ }
+
+ @Test
+ public void testOutputFramePYDML() {
+ System.out.println("MLContextTest - output frame PYDML");
+
+ String s = "M = load(Min, data_type='frame', format='csv')";
+ String csvFile = baseDirectory + File.separator + "one-two-three-four.csv";
+ Script script = pydml(s).in("Min", csvFile).out("M");
+ String[][] frame = ml.execute(script).getFrame("M");
+ Assert.assertEquals("one", frame[0][0]);
+ Assert.assertEquals("two", frame[0][1]);
+ Assert.assertEquals("three", frame[1][0]);
+ Assert.assertEquals("four", frame[1][1]);
+ }
+
+ @Test
+ public void testOutputJavaRDDStringIJVDML() {
+ System.out.println("MLContextTest - output Java RDD String IJV DML");
+
+ String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
+ Script script = dml(s).out("M");
+ MLResults results = ml.execute(script);
+ JavaRDD<String> javaRDDStringIJV = results.getJavaRDDStringIJV("M");
+ List<String> lines = javaRDDStringIJV.collect();
+ Assert.assertEquals("1 1 1.0", lines.get(0));
+ Assert.assertEquals("1 2 2.0", lines.get(1));
+ Assert.assertEquals("2 1 3.0", lines.get(2));
+ Assert.assertEquals("2 2 4.0", lines.get(3));
+ }
+
+ @Test
+ public void testOutputJavaRDDStringIJVPYDML() {
+ System.out.println("MLContextTest - output Java RDD String IJV PYDML");
+
+ String s = "M = full('1 2 3 4', rows=2, cols=2)";
+ Script script = pydml(s).out("M");
+ MLResults results = ml.execute(script);
+ JavaRDD<String> javaRDDStringIJV = results.getJavaRDDStringIJV("M");
+ List<String> lines = javaRDDStringIJV.collect();
+ Assert.assertEquals("1 1 1.0", lines.get(0));
+ Assert.assertEquals("1 2 2.0", lines.get(1));
+ Assert.assertEquals("2 1 3.0", lines.get(2));
+ Assert.assertEquals("2 2 4.0", lines.get(3));
+ }
+
+ @Test
+ public void testOutputJavaRDDStringCSVDenseDML() {
+ System.out.println("MLContextTest - output Java RDD String CSV Dense DML");
+
+ String s = "M = matrix('1 2 3 4', rows=2, cols=2); print(toString(M));";
+ Script script = dml(s).out("M");
+ MLResults results = ml.execute(script);
+ JavaRDD<String> javaRDDStringCSV = results.getJavaRDDStringCSV("M");
+ List<String> lines = javaRDDStringCSV.collect();
+ Assert.assertEquals("1.0,2.0", lines.get(0));
+ Assert.assertEquals("3.0,4.0", lines.get(1));
+ }
+
+ @Test
+ public void testOutputJavaRDDStringCSVDensePYDML() {
+ System.out.println("MLContextTest - output Java RDD String CSV Dense PYDML");
+
+ String s = "M = full('1 2 3 4', rows=2, cols=2)\nprint(toString(M))";
+ Script script = pydml(s).out("M");
+ MLResults results = ml.execute(script);
+ JavaRDD<String> javaRDDStringCSV = results.getJavaRDDStringCSV("M");
+ List<String> lines = javaRDDStringCSV.collect();
+ Assert.assertEquals("1.0,2.0", lines.get(0));
+ Assert.assertEquals("3.0,4.0", lines.get(1));
+ }
+
+ /**
+ * Reading from dense and sparse matrices is handled differently, so we have
+ * tests for both dense and sparse matrices.
+ */
+ @Test
+ public void testOutputJavaRDDStringCSVSparseDML() {
+ // 10x10 sparse matrix whose only non-zeros sit in the top-left 2x2 corner.
+ // NOTE(review): the asserts compare whole CSV lines against "1.0,2.0" —
+ // this assumes trailing zero columns are omitted from the CSV output; confirm.
+ System.out.println("MLContextTest - output Java RDD String CSV Sparse DML");
+
+ String s = "M = matrix(0, rows=10, cols=10); M[1,1]=1; M[1,2]=2; M[2,1]=3; M[2,2]=4; print(toString(M));";
+ Script script = dml(s).out("M");
+ MLResults results = ml.execute(script);
+ JavaRDD<String> javaRDDStringCSV = results.getJavaRDDStringCSV("M");
+ List<String> lines = javaRDDStringCSV.collect();
+ Assert.assertEquals("1.0,2.0", lines.get(0));
+ Assert.assertEquals("3.0,4.0", lines.get(1));
+ }
+
+ /**
+ * Reading from dense and sparse matrices is handled differently, so we have
+ * tests for both dense and sparse matrices.
+ */
+ @Test
+ public void testOutputJavaRDDStringCSVSparsePYDML() {
+ // PYDML variant of the sparse-CSV test; note PYDML indexing is 0-based
+ // ([0,0]..[1,1]) where the DML script used 1-based [1,1]..[2,2].
+ System.out.println("MLContextTest - output Java RDD String CSV Sparse PYDML");
+
+ String s = "M = full(0, rows=10, cols=10)\nM[0,0]=1\nM[0,1]=2\nM[1,0]=3\nM[1,1]=4\nprint(toString(M))";
+ Script script = pydml(s).out("M");
+ MLResults results = ml.execute(script);
+ JavaRDD<String> javaRDDStringCSV = results.getJavaRDDStringCSV("M");
+ List<String> lines = javaRDDStringCSV.collect();
+ Assert.assertEquals("1.0,2.0", lines.get(0));
+ Assert.assertEquals("3.0,4.0", lines.get(1));
+ }
+
+ @Test
+ public void testOutputRDDStringIJVDML() {
+ System.out.println("MLContextTest - output RDD String IJV DML");
+
+ String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
+ Script script = dml(s).out("M");
+ MLResults results = ml.execute(script);
+ RDD<String> rddStringIJV = results.getRDDStringIJV("M");
+ Iterator<String> iterator = rddStringIJV.toLocalIterator();
+ Assert.assertEquals("1 1 1.0", iterator.next());
+ Assert.assertEquals("1 2 2.0", iterator.next());
+ Assert.assertEquals("2 1 3.0", iterator.next());
+ Assert.assertEquals("2 2 4.0", iterator.next());
+ }
+
+ @Test
+ public void testOutputRDDStringIJVPYDML() {
+ // PYDML variant: matrix built with full(); IJV output still uses 1-based coordinates.
+ System.out.println("MLContextTest - output RDD String IJV PYDML");
+
+ String s = "M = full('1 2 3 4', rows=2, cols=2)";
+ Script script = pydml(s).out("M");
+ MLResults results = ml.execute(script);
+ RDD<String> rddStringIJV = results.getRDDStringIJV("M");
+ Iterator<String> iterator = rddStringIJV.toLocalIterator();
+ Assert.assertEquals("1 1 1.0", iterator.next());
+ Assert.assertEquals("1 2 2.0", iterator.next());
+ Assert.assertEquals("2 1 3.0", iterator.next());
+ Assert.assertEquals("2 2 4.0", iterator.next());
+ }
+
+ @Test
+ public void testOutputRDDStringCSVDenseDML() {
+ // Dense DML matrix retrieved as a Scala RDD<String>, one CSV line per matrix row.
+ System.out.println("MLContextTest - output RDD String CSV Dense DML");
+
+ String s = "M = matrix('1 2 3 4', rows=2, cols=2); print(toString(M));";
+ Script script = dml(s).out("M");
+ MLResults results = ml.execute(script);
+ RDD<String> rddStringCSV = results.getRDDStringCSV("M");
+ Iterator<String> iterator = rddStringCSV.toLocalIterator();
+ Assert.assertEquals("1.0,2.0", iterator.next());
+ Assert.assertEquals("3.0,4.0", iterator.next());
+ }
+
+ @Test
+ public void testOutputRDDStringCSVDensePYDML() {
+ // PYDML variant of the dense Scala-RDD CSV test.
+ System.out.println("MLContextTest - output RDD String CSV Dense PYDML");
+
+ String s = "M = full('1 2 3 4', rows=2, cols=2)\nprint(toString(M))";
+ Script script = pydml(s).out("M");
+ MLResults results = ml.execute(script);
+ RDD<String> rddStringCSV = results.getRDDStringCSV("M");
+ Iterator<String> iterator = rddStringCSV.toLocalIterator();
+ Assert.assertEquals("1.0,2.0", iterator.next());
+ Assert.assertEquals("3.0,4.0", iterator.next());
+ }
+
+ @Test
+ public void testOutputRDDStringCSVSparseDML() {
+ // Sparse (10x10, 4 non-zeros) DML matrix retrieved as a Scala RDD<String> in CSV form.
+ // NOTE(review): asserts assume trailing zero columns are omitted from CSV lines — confirm.
+ System.out.println("MLContextTest - output RDD String CSV Sparse DML");
+
+ String s = "M = matrix(0, rows=10, cols=10); M[1,1]=1; M[1,2]=2; M[2,1]=3; M[2,2]=4; print(toString(M));";
+ Script script = dml(s).out("M");
+ MLResults results = ml.execute(script);
+ RDD<String> rddStringCSV = results.getRDDStringCSV("M");
+ Iterator<String> iterator = rddStringCSV.toLocalIterator();
+ Assert.assertEquals("1.0,2.0", iterator.next());
+ Assert.assertEquals("3.0,4.0", iterator.next());
+ }
+
+ @Test
+ public void testOutputRDDStringCSVSparsePYDML() {
+ // PYDML variant of the sparse Scala-RDD CSV test (0-based indexing in the script).
+ System.out.println("MLContextTest - output RDD String CSV Sparse PYDML");
+
+ String s = "M = full(0, rows=10, cols=10)\nM[0,0]=1\nM[0,1]=2\nM[1,0]=3\nM[1,1]=4\nprint(toString(M))";
+ Script script = pydml(s).out("M");
+ MLResults results = ml.execute(script);
+ RDD<String> rddStringCSV = results.getRDDStringCSV("M");
+ Iterator<String> iterator = rddStringCSV.toLocalIterator();
+ Assert.assertEquals("1.0,2.0", iterator.next());
+ Assert.assertEquals("3.0,4.0", iterator.next());
+ }
+
+ @Test
+ public void testOutputDataFrameDML() {
+ // Retrieve a matrix output as a Spark DataFrame.
+ // NOTE(review): column 0 holds 0.0/1.0 per row — it appears to be a row-index
+ // column added by the conversion, with matrix values in columns 1..2; confirm.
+ System.out.println("MLContextTest - output DataFrame DML");
+
+ String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
+ Script script = dml(s).out("M");
+ MLResults results = ml.execute(script);
+ DataFrame dataFrame = results.getDataFrame("M");
+ List<Row> list = dataFrame.collectAsList();
+ Row row1 = list.get(0);
+ Assert.assertEquals(0.0, row1.getDouble(0), 0.0);
+ Assert.assertEquals(1.0, row1.getDouble(1), 0.0);
+ Assert.assertEquals(2.0, row1.getDouble(2), 0.0);
+
+ Row row2 = list.get(1);
+ Assert.assertEquals(1.0, row2.getDouble(0), 0.0);
+ Assert.assertEquals(3.0, row2.getDouble(1), 0.0);
+ Assert.assertEquals(4.0, row2.getDouble(2), 0.0);
+ }
+
+ @Test
+ public void testOutputDataFramePYDML() {
+ // PYDML variant of the DataFrame-output test (see the DML variant for the
+ // meaning of column 0).
+ System.out.println("MLContextTest - output DataFrame PYDML");
+
+ String s = "M = full('1 2 3 4', rows=2, cols=2)";
+ Script script = pydml(s).out("M");
+ MLResults results = ml.execute(script);
+ DataFrame dataFrame = results.getDataFrame("M");
+ List<Row> list = dataFrame.collectAsList();
+ Row row1 = list.get(0);
+ Assert.assertEquals(0.0, row1.getDouble(0), 0.0);
+ Assert.assertEquals(1.0, row1.getDouble(1), 0.0);
+ Assert.assertEquals(2.0, row1.getDouble(2), 0.0);
+
+ Row row2 = list.get(1);
+ Assert.assertEquals(1.0, row2.getDouble(0), 0.0);
+ Assert.assertEquals(3.0, row2.getDouble(1), 0.0);
+ Assert.assertEquals(4.0, row2.getDouble(2), 0.0);
+ }
+
+ @Test
+ public void testTwoScriptsDML() {
+ System.out.println("MLContextTest - two scripts with inputs and outputs DML");
+
+ double[][] m1 = new double[][] { { 1.0, 2.0 }, { 3.0, 4.0 } };
+ String s1 = "sum1 = sum(m1);";
+ double sum1 = ml.execute(dml(s1).in("m1", m1).out("sum1")).getDouble("sum1");
+ Assert.assertEquals(10.0, sum1, 0.0);
+
+ double[][] m2 = new double[][] { { 5.0, 6.0 }, { 7.0, 8.0 } };
+ String s2 = "sum2 = sum(m2);";
+ double sum2 = ml.execute(dml(s2).in("m2", m2).out("sum2")).getDouble("sum2");
+ Assert.assertEquals(26.0, sum2, 0.0);
+ }
+
+ @Test
+ public void testTwoScriptsPYDML() {
+ // PYDML variant: two independent scripts executed back-to-back on the same MLContext.
+ System.out.println("MLContextTest - two scripts with inputs and outputs PYDML");
+
+ double[][] m1 = new double[][] { { 1.0, 2.0 }, { 3.0, 4.0 } };
+ String s1 = "sum1 = sum(m1)";
+ double sum1 = ml.execute(pydml(s1).in("m1", m1).out("sum1")).getDouble("sum1");
+ Assert.assertEquals(10.0, sum1, 0.0);
+
+ double[][] m2 = new double[][] { { 5.0, 6.0 }, { 7.0, 8.0 } };
+ String s2 = "sum2 = sum(m2)";
+ double sum2 = ml.execute(pydml(s2).in("m2", m2).out("sum2")).getDouble("sum2");
+ Assert.assertEquals(26.0, sum2, 0.0);
+ }
+
+ @Test
+ public void testOneScriptTwoExecutionsDML() {
+ // Reuse one Script object for two executions; clearAll() resets its
+ // script string, inputs, outputs, and results between runs.
+ System.out.println("MLContextTest - one script with two executions DML");
+
+ Script script = new Script(org.apache.sysml.api.mlcontext.ScriptType.DML);
+
+ double[][] m1 = new double[][] { { 1.0, 2.0 }, { 3.0, 4.0 } };
+ script.setScriptString("sum1 = sum(m1);").in("m1", m1).out("sum1");
+ ml.execute(script);
+ Assert.assertEquals(10.0, script.results().getDouble("sum1"), 0.0);
+
+ script.clearAll();
+
+ double[][] m2 = new double[][] { { 5.0, 6.0 }, { 7.0, 8.0 } };
+ script.setScriptString("sum2 = sum(m2);").in("m2", m2).out("sum2");
+ ml.execute(script);
+ Assert.assertEquals(26.0, script.results().getDouble("sum2"), 0.0);
+ }
+
+ @Test
+ public void testOneScriptTwoExecutionsPYDML() {
+ // PYDML variant: one Script object reused across two executions via clearAll().
+ System.out.println("MLContextTest - one script with two executions PYDML");
+
+ Script script = new Script(org.apache.sysml.api.mlcontext.ScriptType.PYDML);
+
+ double[][] m1 = new double[][] { { 1.0, 2.0 }, { 3.0, 4.0 } };
+ script.setScriptString("sum1 = sum(m1)").in("m1", m1).out("sum1");
+ ml.execute(script);
+ Assert.assertEquals(10.0, script.results().getDouble("sum1"), 0.0);
+
+ script.clearAll();
+
+ double[][] m2 = new double[][] { { 5.0, 6.0 }, { 7.0, 8.0 } };
+ script.setScriptString("sum2 = sum(m2)").in("m2", m2).out("sum2");
+ ml.execute(script);
+ Assert.assertEquals(26.0, script.results().getDouble("sum2"), 0.0);
+ }
+
+ @Test
+ public void testInputParameterBooleanDML() {
+ System.out.println("MLContextTest - input parameter boolean DML");
+
+ String s = "x = $X; if (x == TRUE) { print('yes'); }";
+ Script script = dml(s).in("$X", true);
+ setExpectedStdOut("yes");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testInputParameterBooleanPYDML() {
+ // PYDML variant: Java boolean bound to input parameter $X (PYDML uses True, not TRUE).
+ System.out.println("MLContextTest - input parameter boolean PYDML");
+
+ String s = "x = $X\nif (x == True):\n print('yes')";
+ Script script = pydml(s).in("$X", true);
+ setExpectedStdOut("yes");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testMultipleOutDML() {
+ System.out.println("MLContextTest - multiple out DML");
+
+ String s = "M = matrix('1 2 3 4', rows=2, cols=2); N = sum(M)";
+ // alternative to .out("M").out("N")
+ MLResults results = ml.execute(dml(s).out("M", "N"));
+ double[][] matrix = results.getDoubleMatrix("M");
+ double sum = results.getDouble("N");
+ Assert.assertEquals(1.0, matrix[0][0], 0);
+ Assert.assertEquals(2.0, matrix[0][1], 0);
+ Assert.assertEquals(3.0, matrix[1][0], 0);
+ Assert.assertEquals(4.0, matrix[1][1], 0);
+ Assert.assertEquals(10.0, sum, 0);
+ }
+
+ @Test
+ public void testMultipleOutPYDML() {
+ // PYDML variant: register two outputs with a single out("M", "N") call.
+ System.out.println("MLContextTest - multiple out PYDML");
+
+ String s = "M = full('1 2 3 4', rows=2, cols=2)\nN = sum(M)";
+ // alternative to .out("M").out("N")
+ MLResults results = ml.execute(pydml(s).out("M", "N"));
+ double[][] matrix = results.getDoubleMatrix("M");
+ double sum = results.getDouble("N");
+ Assert.assertEquals(1.0, matrix[0][0], 0);
+ Assert.assertEquals(2.0, matrix[0][1], 0);
+ Assert.assertEquals(3.0, matrix[1][0], 0);
+ Assert.assertEquals(4.0, matrix[1][1], 0);
+ Assert.assertEquals(10.0, sum, 0);
+ }
+
+ @Test
+ public void testOutputMatrixObjectDML() {
+ System.out.println("MLContextTest - output matrix object DML");
+ String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
+ MatrixObject mo = ml.execute(dml(s).out("M")).getMatrixObject("M");
+ RDD<String> rddStringCSV = MLContextConversionUtil.matrixObjectToRDDStringCSV(mo);
+ Iterator<String> iterator = rddStringCSV.toLocalIterator();
+ Assert.assertEquals("1.0,2.0", iterator.next());
+ Assert.assertEquals("3.0,4.0", iterator.next());
+ }
+
+ @Test
+ public void testOutputMatrixObjectPYDML() {
+ System.out.println("MLContextTest - output matrix object PYDML");
+ String s = "M = full('1 2 3 4', rows=2, cols=2);";
+ MatrixObject mo = ml.execute(pydml(s).out("M")).getMatrixObject("M");
+ RDD<String> rddStringCSV = MLContextConversionUtil.matrixObjectToRDDStringCSV(mo);
+ Iterator<String> iterator = rddStringCSV.toLocalIterator();
+ Assert.assertEquals("1.0,2.0", iterator.next());
+ Assert.assertEquals("3.0,4.0", iterator.next());
+ }
+
+ @Test
+ public void testInputBinaryBlockMatrixDML() {
+ // Build a DataFrame of string-valued CSV rows, wrap it in a BinaryBlockMatrix,
+ // and feed it to a DML script; avg of 10..90 is 50.
+ System.out.println("MLContextTest - input BinaryBlockMatrix DML");
+
+ List<String> list = new ArrayList<String>();
+ list.add("10,20,30");
+ list.add("40,50,60");
+ list.add("70,80,90");
+ JavaRDD<String> javaRddString = sc.parallelize(list);
+
+ JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+ SQLContext sqlContext = new SQLContext(sc);
+ List<StructField> fields = new ArrayList<StructField>();
+ fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+ StructType schema = DataTypes.createStructType(fields);
+ DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
+
+ BinaryBlockMatrix binaryBlockMatrix = new BinaryBlockMatrix(dataFrame);
+ Script script = dml("avg = avg(M);").in("M", binaryBlockMatrix).out("avg");
+ double avg = ml.execute(script).getDouble("avg");
+ Assert.assertEquals(50.0, avg, 0.0);
+ }
+
+ @Test
+ public void testInputBinaryBlockMatrixPYDML() {
+ // PYDML variant of the BinaryBlockMatrix input test (see the DML variant).
+ System.out.println("MLContextTest - input BinaryBlockMatrix PYDML");
+
+ List<String> list = new ArrayList<String>();
+ list.add("10,20,30");
+ list.add("40,50,60");
+ list.add("70,80,90");
+ JavaRDD<String> javaRddString = sc.parallelize(list);
+
+ JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+ SQLContext sqlContext = new SQLContext(sc);
+ List<StructField> fields = new ArrayList<StructField>();
+ fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+ StructType schema = DataTypes.createStructType(fields);
+ DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
+
+ BinaryBlockMatrix binaryBlockMatrix = new BinaryBlockMatrix(dataFrame);
+ Script script = pydml("avg = avg(M)").in("M", binaryBlockMatrix).out("avg");
+ double avg = ml.execute(script).getDouble("avg");
+ Assert.assertEquals(50.0, avg, 0.0);
+ }
+
+ @Test
+ public void testOutputBinaryBlockMatrixDML() {
+ // Retrieve the output as a BinaryBlockMatrix, then convert to IJV lines.
+ System.out.println("MLContextTest - output BinaryBlockMatrix DML");
+ String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
+ BinaryBlockMatrix binaryBlockMatrix = ml.execute(dml(s).out("M")).getBinaryBlockMatrix("M");
+
+ JavaRDD<String> javaRDDStringIJV = MLContextConversionUtil
+ .binaryBlockMatrixToJavaRDDStringIJV(binaryBlockMatrix);
+ List<String> lines = javaRDDStringIJV.collect();
+ Assert.assertEquals("1 1 1.0", lines.get(0));
+ Assert.assertEquals("1 2 2.0", lines.get(1));
+ Assert.assertEquals("2 1 3.0", lines.get(2));
+ Assert.assertEquals("2 2 4.0", lines.get(3));
+ }
+
+ @Test
+ public void testOutputBinaryBlockMatrixPYDML() {
+ System.out.println("MLContextTest - output BinaryBlockMatrix PYDML");
+ String s = "M = full('1 2 3 4', rows=2, cols=2);";
+ BinaryBlockMatrix binaryBlockMatrix = ml.execute(pydml(s).out("M")).getBinaryBlockMatrix("M");
+
+ JavaRDD<String> javaRDDStringIJV = MLContextConversionUtil
+ .binaryBlockMatrixToJavaRDDStringIJV(binaryBlockMatrix);
+ List<String> lines = javaRDDStringIJV.collect();
+ Assert.assertEquals("1 1 1.0", lines.get(0));
+ Assert.assertEquals("1 2 2.0", lines.get(1));
+ Assert.assertEquals("2 1 3.0", lines.get(2));
+ Assert.assertEquals("2 2 4.0", lines.get(3));
+ }
+
+ @Test
+ public void testOutputListStringCSVDenseDML() {
+ // Convert a dense matrix output to a local List of CSV lines (no RDD round-trip).
+ System.out.println("MLContextTest - output List String CSV Dense DML");
+
+ String s = "M = matrix('1 2 3 4', rows=2, cols=2); print(toString(M));";
+ Script script = dml(s).out("M");
+ MLResults results = ml.execute(script);
+ MatrixObject mo = results.getMatrixObject("M");
+ List<String> lines = MLContextConversionUtil.matrixObjectToListStringCSV(mo);
+ Assert.assertEquals("1.0,2.0", lines.get(0));
+ Assert.assertEquals("3.0,4.0", lines.get(1));
+ }
+
+ @Test
+ public void testOutputListStringCSVDensePYDML() {
+ // PYDML variant of the dense List-of-CSV-lines test.
+ System.out.println("MLContextTest - output List String CSV Dense PYDML");
+
+ String s = "M = full('1 2 3 4', rows=2, cols=2)\nprint(toString(M))";
+ Script script = pydml(s).out("M");
+ MLResults results = ml.execute(script);
+ MatrixObject mo = results.getMatrixObject("M");
+ List<String> lines = MLContextConversionUtil.matrixObjectToListStringCSV(mo);
+ Assert.assertEquals("1.0,2.0", lines.get(0));
+ Assert.assertEquals("3.0,4.0", lines.get(1));
+ }
+
+ @Test
+ public void testOutputListStringCSVSparseDML() {
+ // Sparse-matrix variant of the List-of-CSV-lines test.
+ // NOTE(review): asserts assume trailing zero columns are omitted from CSV lines — confirm.
+ System.out.println("MLContextTest - output List String CSV Sparse DML");
+
+ String s = "M = matrix(0, rows=10, cols=10); M[1,1]=1; M[1,2]=2; M[2,1]=3; M[2,2]=4; print(toString(M));";
+ Script script = dml(s).out("M");
+ MLResults results = ml.execute(script);
+ MatrixObject mo = results.getMatrixObject("M");
+ List<String> lines = MLContextConversionUtil.matrixObjectToListStringCSV(mo);
+ Assert.assertEquals("1.0,2.0", lines.get(0));
+ Assert.assertEquals("3.0,4.0", lines.get(1));
+ }
+
+ @Test
+ public void testOutputListStringCSVSparsePYDML() {
+ // PYDML variant of the sparse List-of-CSV-lines test (0-based indexing in the script).
+ System.out.println("MLContextTest - output List String CSV Sparse PYDML");
+
+ String s = "M = full(0, rows=10, cols=10)\nM[0,0]=1\nM[0,1]=2\nM[1,0]=3\nM[1,1]=4\nprint(toString(M))";
+ Script script = pydml(s).out("M");
+ MLResults results = ml.execute(script);
+ MatrixObject mo = results.getMatrixObject("M");
+ List<String> lines = MLContextConversionUtil.matrixObjectToListStringCSV(mo);
+ Assert.assertEquals("1.0,2.0", lines.get(0));
+ Assert.assertEquals("3.0,4.0", lines.get(1));
+ }
+
+ @Test
+ public void testOutputListStringIJVDenseDML() {
+ // Convert a dense matrix output to a local List of IJV lines ("row col value", 1-based).
+ System.out.println("MLContextTest - output List String IJV Dense DML");
+
+ String s = "M = matrix('1 2 3 4', rows=2, cols=2); print(toString(M));";
+ Script script = dml(s).out("M");
+ MLResults results = ml.execute(script);
+ MatrixObject mo = results.getMatrixObject("M");
+ List<String> lines = MLContextConversionUtil.matrixObjectToListStringIJV(mo);
+ Assert.assertEquals("1 1 1.0", lines.get(0));
+ Assert.assertEquals("1 2 2.0", lines.get(1));
+ Assert.assertEquals("2 1 3.0", lines.get(2));
+ Assert.assertEquals("2 2 4.0", lines.get(3));
+ }
+
+ @Test
+ public void testOutputListStringIJVDensePYDML() {
+ // PYDML variant of the dense List-of-IJV-lines test.
+ System.out.println("MLContextTest - output List String IJV Dense PYDML");
+
+ String s = "M = full('1 2 3 4', rows=2, cols=2)\nprint(toString(M))";
+ Script script = pydml(s).out("M");
+ MLResults results = ml.execute(script);
+ MatrixObject mo = results.getMatrixObject("M");
+ List<String> lines = MLContextConversionUtil.matrixObjectToListStringIJV(mo);
+ Assert.assertEquals("1 1 1.0", lines.get(0));
+ Assert.assertEquals("1 2 2.0", lines.get(1));
+ Assert.assertEquals("2 1 3.0", lines.get(2));
+ Assert.assertEquals("2 2 4.0", lines.get(3));
+ }
+
+ @Test
+ public void testOutputListStringIJVSparseDML() {
+ // Sparse-matrix variant: only the four non-zero cells appear in the IJV list.
+ System.out.println("MLContextTest - output List String IJV Sparse DML");
+
+ String s = "M = matrix(0, rows=10, cols=10); M[1,1]=1; M[1,2]=2; M[2,1]=3; M[2,2]=4; print(toString(M));";
+ Script script = dml(s).out("M");
+ MLResults results = ml.execute(script);
+ MatrixObject mo = results.getMatrixObject("M");
+ List<String> lines = MLContextConversionUtil.matrixObjectToListStringIJV(mo);
+ Assert.assertEquals("1 1 1.0", lines.get(0));
+ Assert.assertEquals("1 2 2.0", lines.get(1));
+ Assert.assertEquals("2 1 3.0", lines.get(2));
+ Assert.assertEquals("2 2 4.0", lines.get(3));
+ }
+
+ @Test
+ public void testOutputListStringIJVSparsePYDML() {
+ // PYDML variant of the sparse List-of-IJV-lines test (0-based indexing in the script).
+ System.out.println("MLContextTest - output List String IJV Sparse PYDML");
+
+ String s = "M = full(0, rows=10, cols=10)\nM[0,0]=1\nM[0,1]=2\nM[1,0]=3\nM[1,1]=4\nprint(toString(M))";
+ Script script = pydml(s).out("M");
+ MLResults results = ml.execute(script);
+ MatrixObject mo = results.getMatrixObject("M");
+ List<String> lines = MLContextConversionUtil.matrixObjectToListStringIJV(mo);
+ Assert.assertEquals("1 1 1.0", lines.get(0));
+ Assert.assertEquals("1 2 2.0", lines.get(1));
+ Assert.assertEquals("2 1 3.0", lines.get(2));
+ Assert.assertEquals("2 2 4.0", lines.get(3));
+ }
+
+ @Test
+ public void testJavaRDDGoodMetadataDML() {
+ System.out.println("MLContextTest - JavaRDD<String> good metadata DML");
+
+ List<String> list = new ArrayList<String>();
+ list.add("1,2,3");
+ list.add("4,5,6");
+ list.add("7,8,9");
+ JavaRDD<String> javaRDD = sc.parallelize(list);
+
+ MatrixMetadata mm = new MatrixMetadata(3, 3, 9);
+
+ Script script = dml("print('sum: ' + sum(M));").in("M", javaRDD, mm);
+ setExpectedStdOut("sum: 45.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testJavaRDDGoodMetadataPYDML() {
+ // PYDML variant: JavaRDD input with matching MatrixMetadata (3 rows, 3 cols, 9 non-zeros).
+ System.out.println("MLContextTest - JavaRDD<String> good metadata PYDML");
+
+ List<String> list = new ArrayList<String>();
+ list.add("1,2,3");
+ list.add("4,5,6");
+ list.add("7,8,9");
+ JavaRDD<String> javaRDD = sc.parallelize(list);
+
+ MatrixMetadata mm = new MatrixMetadata(3, 3, 9);
+
+ Script script = pydml("print('sum: ' + sum(M))").in("M", javaRDD, mm);
+ setExpectedStdOut("sum: 45.0");
+ ml.execute(script);
+ }
+
+ @Test(expected = MLContextException.class)
+ public void testJavaRDDBadMetadataDML() {
+ // Metadata declares 1x1 but the RDD holds a 3x3 matrix; execution must fail
+ // with an MLContextException.
+ System.out.println("MLContextTest - JavaRDD<String> bad metadata DML");
+
+ List<String> list = new ArrayList<String>();
+ list.add("1,2,3");
+ list.add("4,5,6");
+ list.add("7,8,9");
+ JavaRDD<String> javaRDD = sc.parallelize(list);
+
+ MatrixMetadata mm = new MatrixMetadata(1, 1, 9);
+
+ Script script = dml("print('sum: ' + sum(M));").in("M", javaRDD, mm);
+ ml.execute(script);
+ }
+
+ @Test(expected = MLContextException.class)
+ public void testJavaRDDBadMetadataPYDML() {
+ System.out.println("MLContextTest - JavaRDD<String> bad metadata PYML");
+
+ List<String> list = new ArrayList<String>();
+ list.add("1,2,3");
+ list.add("4,5,6");
+ list.add("7,8,9");
+ JavaRDD<String> javaRDD = sc.parallelize(list);
+
+ MatrixMetadata mm = new MatrixMetadata(1, 1, 9);
+
+ Script script = dml("print('sum: ' + sum(M))").in("M", javaRDD, mm);
+ ml.execute(script);
+ }
+
+ @Test
+ public void testRDDGoodMetadataDML() {
+ // Scala RDD<String> input (converted from a JavaRDD) with matching metadata.
+ System.out.println("MLContextTest - RDD<String> good metadata DML");
+
+ List<String> list = new ArrayList<String>();
+ list.add("1,1,1");
+ list.add("2,2,2");
+ list.add("3,3,3");
+ JavaRDD<String> javaRDD = sc.parallelize(list);
+ RDD<String> rdd = JavaRDD.toRDD(javaRDD);
+
+ MatrixMetadata mm = new MatrixMetadata(3, 3, 9);
+
+ Script script = dml("print('sum: ' + sum(M));").in("M", rdd, mm);
+ setExpectedStdOut("sum: 18.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testRDDGoodMetadataPYDML() {
+ // PYDML variant of the Scala RDD<String> input test with matching metadata.
+ System.out.println("MLContextTest - RDD<String> good metadata PYDML");
+
+ List<String> list = new ArrayList<String>();
+ list.add("1,1,1");
+ list.add("2,2,2");
+ list.add("3,3,3");
+ JavaRDD<String> javaRDD = sc.parallelize(list);
+ RDD<String> rdd = JavaRDD.toRDD(javaRDD);
+
+ MatrixMetadata mm = new MatrixMetadata(3, 3, 9);
+
+ Script script = pydml("print('sum: ' + sum(M))").in("M", rdd, mm);
+ setExpectedStdOut("sum: 18.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testDataFrameGoodMetadataDML() {
+ // DataFrame of string-typed CSV columns as matrix input, with matching metadata.
+ System.out.println("MLContextTest - DataFrame good metadata DML");
+
+ List<String> list = new ArrayList<String>();
+ list.add("10,20,30");
+ list.add("40,50,60");
+ list.add("70,80,90");
+ JavaRDD<String> javaRddString = sc.parallelize(list);
+
+ JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+ SQLContext sqlContext = new SQLContext(sc);
+ List<StructField> fields = new ArrayList<StructField>();
+ fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+ StructType schema = DataTypes.createStructType(fields);
+ DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
+
+ MatrixMetadata mm = new MatrixMetadata(3, 3, 9);
+
+ Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame, mm);
+ setExpectedStdOut("sum: 450.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testDataFrameGoodMetadataPYDML() {
+ // PYDML variant of the DataFrame input test with matching metadata.
+ System.out.println("MLContextTest - DataFrame good metadata PYDML");
+
+ List<String> list = new ArrayList<String>();
+ list.add("10,20,30");
+ list.add("40,50,60");
+ list.add("70,80,90");
+ JavaRDD<String> javaRddString = sc.parallelize(list);
+
+ JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+ SQLContext sqlContext = new SQLContext(sc);
+ List<StructField> fields = new ArrayList<StructField>();
+ fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+ StructType schema = DataTypes.createStructType(fields);
+ DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
+
+ MatrixMetadata mm = new MatrixMetadata(3, 3, 9);
+
+ Script script = pydml("print('sum: ' + sum(M))").in("M", dataFrame, mm);
+ setExpectedStdOut("sum: 450.0");
+ ml.execute(script);
+ }
+
+ @SuppressWarnings({ "rawtypes", "unchecked" })
+ @Test
+ public void testInputTupleSeqNoMetadataDML() {
+ // Pass multiple named inputs at once as a Scala Seq of (name, RDD) Tuple2s,
+ // without any MatrixMetadata.
+ System.out.println("MLContextTest - Tuple sequence no metadata DML");
+
+ List<String> list1 = new ArrayList<String>();
+ list1.add("1,2");
+ list1.add("3,4");
+ JavaRDD<String> javaRDD1 = sc.parallelize(list1);
+ RDD<String> rdd1 = JavaRDD.toRDD(javaRDD1);
+
+ List<String> list2 = new ArrayList<String>();
+ list2.add("5,6");
+ list2.add("7,8");
+ JavaRDD<String> javaRDD2 = sc.parallelize(list2);
+ RDD<String> rdd2 = JavaRDD.toRDD(javaRDD2);
+
+ Tuple2 tuple1 = new Tuple2("m1", rdd1);
+ Tuple2 tuple2 = new Tuple2("m2", rdd2);
+ List tupleList = new ArrayList();
+ tupleList.add(tuple1);
+ tupleList.add(tuple2);
+ Seq seq = JavaConversions.asScalaBuffer(tupleList).toSeq();
+
+ Script script = dml("print('sums: ' + sum(m1) + ' ' + sum(m2));").in(seq);
+ setExpectedStdOut("sums: 10.0 26.0");
+ ml.execute(script);
+ }
+
+ @SuppressWarnings({ "rawtypes", "unchecked" })
+ @Test
+ public void testInputTupleSeqNoMetadataPYDML() {
+ // PYDML variant: multiple named inputs as a Scala Seq of (name, RDD) Tuple2s.
+ System.out.println("MLContextTest - Tuple sequence no metadata PYDML");
+
+ List<String> list1 = new ArrayList<String>();
+ list1.add("1,2");
+ list1.add("3,4");
+ JavaRDD<String> javaRDD1 = sc.parallelize(list1);
+ RDD<String> rdd1 = JavaRDD.toRDD(javaRDD1);
+
+ List<String> list2 = new ArrayList<String>();
+ list2.add("5,6");
+ list2.add("7,8");
+ JavaRDD<String> javaRDD2 = sc.parallelize(list2);
+ RDD<String> rdd2 = JavaRDD.toRDD(javaRDD2);
+
+ Tuple2 tuple1 = new Tuple2("m1", rdd1);
+ Tuple2 tuple2 = new Tuple2("m2", rdd2);
+ List tupleList = new ArrayList();
+ tupleList.add(tuple1);
+ tupleList.add(tuple2);
+ Seq seq = JavaConversions.asScalaBuffer(tupleList).toSeq();
+
+ Script script = pydml("print('sums: ' + sum(m1) + ' ' + sum(m2))").in(seq);
+ setExpectedStdOut("sums: 10.0 26.0");
+ ml.execute(script);
+ }
+
+ @SuppressWarnings({ "rawtypes", "unchecked" })
+ @Test
+ public void testInputTupleSeqWithMetadataDML() {
+ // As the no-metadata Seq test, but each input is a (name, RDD, MatrixMetadata) Tuple3.
+ System.out.println("MLContextTest - Tuple sequence with metadata DML");
+
+ List<String> list1 = new ArrayList<String>();
+ list1.add("1,2");
+ list1.add("3,4");
+ JavaRDD<String> javaRDD1 = sc.parallelize(list1);
+ RDD<String> rdd1 = JavaRDD.toRDD(javaRDD1);
+
+ List<String> list2 = new ArrayList<String>();
+ list2.add("5,6");
+ list2.add("7,8");
+ JavaRDD<String> javaRDD2 = sc.parallelize(list2);
+ RDD<String> rdd2 = JavaRDD.toRDD(javaRDD2);
+
+ MatrixMetadata mm1 = new MatrixMetadata(2, 2);
+ MatrixMetadata mm2 = new MatrixMetadata(2, 2);
+
+ Tuple3 tuple1 = new Tuple3("m1", rdd1, mm1);
+ Tuple3 tuple2 = new Tuple3("m2", rdd2, mm2);
+ List tupleList = new ArrayList();
+ tupleList.add(tuple1);
+ tupleList.add(tuple2);
+ Seq seq = JavaConversions.asScalaBuffer(tupleList).toSeq();
+
+ Script script = dml("print('sums: ' + sum(m1) + ' ' + sum(m2));").in(seq);
+ setExpectedStdOut("sums: 10.0 26.0");
+ ml.execute(script);
+ }
+
+ @SuppressWarnings({ "rawtypes", "unchecked" })
+ @Test
+ public void testInputTupleSeqWithMetadataPYDML() {
+ // PYDML variant: inputs as (name, RDD, MatrixMetadata) Tuple3s in a Scala Seq.
+ System.out.println("MLContextTest - Tuple sequence with metadata PYDML");
+
+ List<String> list1 = new ArrayList<String>();
+ list1.add("1,2");
+ list1.add("3,4");
+ JavaRDD<String> javaRDD1 = sc.parallelize(list1);
+ RDD<String> rdd1 = JavaRDD.toRDD(javaRDD1);
+
+ List<String> list2 = new ArrayList<String>();
+ list2.add("5,6");
+ list2.add("7,8");
+ JavaRDD<String> javaRDD2 = sc.parallelize(list2);
+ RDD<String> rdd2 = JavaRDD.toRDD(javaRDD2);
+
+ MatrixMetadata mm1 = new MatrixMetadata(2, 2);
+ MatrixMetadata mm2 = new MatrixMetadata(2, 2);
+
+ Tuple3 tuple1 = new Tuple3("m1", rdd1, mm1);
+ Tuple3 tuple2 = new Tuple3("m2", rdd2, mm2);
+ List tupleList = new ArrayList();
+ tupleList.add(tuple1);
+ tupleList.add(tuple2);
+ Seq seq = JavaConversions.asScalaBuffer(tupleList).toSeq();
+
+ Script script = pydml("print('sums: ' + sum(m1) + ' ' + sum(m2))").in(seq);
+ setExpectedStdOut("sums: 10.0 26.0");
+ ml.execute(script);
+ }
+
+ // NOTE: Uncomment these tests once they work
+
+ // @SuppressWarnings({ "rawtypes", "unchecked" })
+ // @Test
+ // public void testInputTupleSeqWithAndWithoutMetadataDML() {
+ // System.out.println("MLContextTest - Tuple sequence with and without metadata DML");
+ //
+ // List<String> list1 = new ArrayList<String>();
+ // list1.add("1,2");
+ // list1.add("3,4");
+ // JavaRDD<String> javaRDD1 = sc.parallelize(list1);
+ // RDD<String> rdd1 = JavaRDD.toRDD(javaRDD1);
+ //
+ // List<String> list2 = new ArrayList<String>();
+ // list2.add("5,6");
+ // list2.add("7,8");
+ // JavaRDD<String> javaRDD2 = sc.parallelize(list2);
+ // RDD<String> rdd2 = JavaRDD.toRDD(javaRDD2);
+ //
+ // MatrixMetadata mm1 = new MatrixMetadata(2, 2);
+ //
+ // Tuple3 tuple1 = new Tuple3("m1", rdd1, mm1);
+ // Tuple2 tuple2 = new Tuple2("m2", rdd2);
+ // List tupleList = new ArrayList();
+ // tupleList.add(tuple1);
+ // tupleList.add(tuple2);
+ // Seq seq = JavaConversions.asScalaBuffer(tupleList).toSeq();
+ //
+ // Script script =
+ // dml("print('sums: ' + sum(m1) + ' ' + sum(m2));").in(seq);
+ // setExpectedStdOut("sums: 10.0 26.0");
+ // ml.execute(script);
+ // }
+ //
+ // @SuppressWarnings({ "rawtypes", "unchecked" })
+ // @Test
+ // public void testInputTupleSeqWithAndWithoutMetadataPYDML() {
+ // System.out.println("MLContextTest - Tuple sequence with and without metadata PYDML");
+ //
+ // List<String> list1 = new ArrayList<String>();
+ // list1.add("1,2");
+ // list1.add("3,4");
+ // JavaRDD<String> javaRDD1 = sc.parallelize(list1);
+ // RDD<String> rdd1 = JavaRDD.toRDD(javaRDD1);
+ //
+ // List<String> list2 = new ArrayList<String>();
+ // list2.add("5,6");
+ // list2.add("7,8");
+ // JavaRDD<String> javaRDD2 = sc.parallelize(list2);
+ // RDD<String> rdd2 = JavaRDD.toRDD(javaRDD2);
+ //
+ // MatrixMetadata mm1 = new MatrixMetadata(2, 2);
+ //
+ // Tuple3 tuple1 = new Tuple3("m1", rdd1, mm1);
+ // Tuple2 tuple2 = new Tuple2("m2", rdd2);
+ // List tupleList = new ArrayList();
+ // tupleList.add(tuple1);
+ // tupleList.add(tuple2);
+ // Seq seq = JavaConversions.asScalaBuffer(tupleList).toSeq();
+ //
+ // Script script =
+ // pydml("print('sums: ' + sum(m1) + ' ' + sum(m2))").in(seq);
+ // setExpectedStdOut("sums: 10.0 26.0");
+ // ml.execute(script);
+ // }
+
+ @After
+ public void tearDown() {
+ // Delegate to the parent test harness teardown (kept explicit for clarity).
+ super.tearDown();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/test/scripts/org/apache/sysml/api/mlcontext/1234.csv
----------------------------------------------------------------------
diff --git a/src/test/scripts/org/apache/sysml/api/mlcontext/1234.csv b/src/test/scripts/org/apache/sysml/api/mlcontext/1234.csv
new file mode 100644
index 0000000..e055049
--- /dev/null
+++ b/src/test/scripts/org/apache/sysml/api/mlcontext/1234.csv
@@ -0,0 +1,2 @@
+1.0,2.0
+3.0,4.0
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/test/scripts/org/apache/sysml/api/mlcontext/1234.csv.mtd
----------------------------------------------------------------------
diff --git a/src/test/scripts/org/apache/sysml/api/mlcontext/1234.csv.mtd b/src/test/scripts/org/apache/sysml/api/mlcontext/1234.csv.mtd
new file mode 100644
index 0000000..d57e93d
--- /dev/null
+++ b/src/test/scripts/org/apache/sysml/api/mlcontext/1234.csv.mtd
@@ -0,0 +1,13 @@
+{
+ "data_type": "matrix",
+ "value_type": "double",
+ "rows": 2,
+ "cols": 2,
+ "nnz": 4,
+ "format": "csv",
+ "header": false,
+ "sep": ",",
+ "description": {
+ "author": "SystemML"
+ }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/test/scripts/org/apache/sysml/api/mlcontext/hello-world.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/org/apache/sysml/api/mlcontext/hello-world.dml b/src/test/scripts/org/apache/sysml/api/mlcontext/hello-world.dml
new file mode 100644
index 0000000..32e8eb5
--- /dev/null
+++ b/src/test/scripts/org/apache/sysml/api/mlcontext/hello-world.dml
@@ -0,0 +1,22 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+print('hello world');
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/test/scripts/org/apache/sysml/api/mlcontext/hello-world.pydml
----------------------------------------------------------------------
diff --git a/src/test/scripts/org/apache/sysml/api/mlcontext/hello-world.pydml b/src/test/scripts/org/apache/sysml/api/mlcontext/hello-world.pydml
new file mode 100644
index 0000000..01c348b
--- /dev/null
+++ b/src/test/scripts/org/apache/sysml/api/mlcontext/hello-world.pydml
@@ -0,0 +1,22 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+print('hello world')
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/test/scripts/org/apache/sysml/api/mlcontext/one-two-three-four.csv
----------------------------------------------------------------------
diff --git a/src/test/scripts/org/apache/sysml/api/mlcontext/one-two-three-four.csv b/src/test/scripts/org/apache/sysml/api/mlcontext/one-two-three-four.csv
new file mode 100644
index 0000000..eedf6ea
--- /dev/null
+++ b/src/test/scripts/org/apache/sysml/api/mlcontext/one-two-three-four.csv
@@ -0,0 +1,2 @@
+one,two
+three,four
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/test/scripts/org/apache/sysml/api/mlcontext/one-two-three-four.csv.mtd
----------------------------------------------------------------------
diff --git a/src/test/scripts/org/apache/sysml/api/mlcontext/one-two-three-four.csv.mtd b/src/test/scripts/org/apache/sysml/api/mlcontext/one-two-three-four.csv.mtd
new file mode 100644
index 0000000..8c6dcd1
--- /dev/null
+++ b/src/test/scripts/org/apache/sysml/api/mlcontext/one-two-three-four.csv.mtd
@@ -0,0 +1,5 @@
+{
+ "data_type": "frame",
+ "format": "csv",
+ "header": false
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/test_suites/java/org/apache/sysml/test/integration/mlcontext/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/mlcontext/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/mlcontext/ZPackageSuite.java
new file mode 100644
index 0000000..5687a55
--- /dev/null
+++ b/src/test_suites/java/org/apache/sysml/test/integration/mlcontext/ZPackageSuite.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.mlcontext;
+
+import org.junit.runner.RunWith;
+import org.junit.runners.Suite;
+
+/** Group together the tests in this package/related subpackages into a single suite so that the Maven build
+ * won't run two of them at once. Since the DML and PyDML equivalent tests currently share the same directories,
+ * they should not be run in parallel. */
+@RunWith(Suite.class)
+@Suite.SuiteClasses({
+ org.apache.sysml.test.integration.mlcontext.MLContextTest.class
+})
+
+
+/** This class is just a holder for the above JUnit annotations. */
+public class ZPackageSuite {
+
+}
[4/4] incubator-systemml git commit: [SYSTEMML-593] MLContext redesign
Posted by de...@apache.org.
[SYSTEMML-593] MLContext redesign
Closes #199.
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/457bbd3a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/457bbd3a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/457bbd3a
Branch: refs/heads/master
Commit: 457bbd3a4aca2c75163f4cbaed3faa2a9cb14d72
Parents: 873bae7
Author: Deron Eriksson <de...@us.ibm.com>
Authored: Thu Jul 28 17:11:40 2016 -0700
Committer: Deron Eriksson <de...@us.ibm.com>
Committed: Thu Jul 28 17:11:40 2016 -0700
----------------------------------------------------------------------
pom.xml | 1 +
.../java/org/apache/sysml/api/DMLScript.java | 2 +-
.../java/org/apache/sysml/api/MLContext.java | 4 +
.../org/apache/sysml/api/MLContextProxy.java | 55 +-
.../sysml/api/mlcontext/BinaryBlockMatrix.java | 148 ++
.../apache/sysml/api/mlcontext/MLContext.java | 505 ++++++
.../api/mlcontext/MLContextConversionUtil.java | 720 ++++++++
.../sysml/api/mlcontext/MLContextException.java | 47 +
.../sysml/api/mlcontext/MLContextUtil.java | 844 +++++++++
.../apache/sysml/api/mlcontext/MLResults.java | 1299 +++++++++++++
.../org/apache/sysml/api/mlcontext/Matrix.java | 141 ++
.../sysml/api/mlcontext/MatrixFormat.java | 39 +
.../sysml/api/mlcontext/MatrixMetadata.java | 522 ++++++
.../org/apache/sysml/api/mlcontext/Script.java | 652 +++++++
.../sysml/api/mlcontext/ScriptExecutor.java | 624 +++++++
.../sysml/api/mlcontext/ScriptFactory.java | 422 +++++
.../apache/sysml/api/mlcontext/ScriptType.java | 65 +
.../context/SparkExecutionContext.java | 47 +-
.../instructions/spark/SPInstruction.java | 52 +-
.../spark/functions/SparkListener.java | 38 +-
.../spark/utils/RDDConverterUtilsExt.java | 4 +-
.../integration/mlcontext/MLContextTest.java | 1713 ++++++++++++++++++
.../org/apache/sysml/api/mlcontext/1234.csv | 2 +
.../org/apache/sysml/api/mlcontext/1234.csv.mtd | 13 +
.../apache/sysml/api/mlcontext/hello-world.dml | 22 +
.../sysml/api/mlcontext/hello-world.pydml | 22 +
.../sysml/api/mlcontext/one-two-three-four.csv | 2 +
.../api/mlcontext/one-two-three-four.csv.mtd | 5 +
.../integration/mlcontext/ZPackageSuite.java | 37 +
29 files changed, 7991 insertions(+), 56 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index 9574679..f7ec5b7 100644
--- a/pom.xml
+++ b/pom.xml
@@ -394,6 +394,7 @@
<include>**/integration/functions/gdfo/*Suite.java</include>
<include>**/integration/functions/sparse/*Suite.java</include>
<include>**/integration/functions/**/*Test*.java</include>
+ <include>**/integration/mlcontext/*Suite.java</include>
<include>**/integration/scalability/**/*Test.java</include>
</includes>
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/DMLScript.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/DMLScript.java b/src/main/java/org/apache/sysml/api/DMLScript.java
index 814bcb8..3d76273 100644
--- a/src/main/java/org/apache/sysml/api/DMLScript.java
+++ b/src/main/java/org/apache/sysml/api/DMLScript.java
@@ -777,7 +777,7 @@ public class DMLScript
* @throws DMLRuntimeException
*
*/
- static void initHadoopExecution( DMLConfig config )
+ public static void initHadoopExecution( DMLConfig config )
throws IOException, ParseException, DMLRuntimeException
{
//check security aspects
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/MLContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/MLContext.java b/src/main/java/org/apache/sysml/api/MLContext.java
index 2600c35..a03c8b7 100644
--- a/src/main/java/org/apache/sysml/api/MLContext.java
+++ b/src/main/java/org/apache/sysml/api/MLContext.java
@@ -87,6 +87,10 @@ import org.apache.sysml.utils.Explain.ExplainCounts;
import org.apache.sysml.utils.Statistics;
/**
+ * The MLContext API has been redesigned and this API will be deprecated.
+ * Please migrate to {@link org.apache.sysml.api.mlcontext.MLContext}.
+ * <p>
+ *
* MLContext is useful for passing RDDs as input/output to SystemML. This API avoids the need to read/write
* from HDFS (which is another way to pass inputs to SystemML).
* <p>
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/MLContextProxy.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/MLContextProxy.java b/src/main/java/org/apache/sysml/api/MLContextProxy.java
index ee16690..f8f31d6 100644
--- a/src/main/java/org/apache/sysml/api/MLContextProxy.java
+++ b/src/main/java/org/apache/sysml/api/MLContextProxy.java
@@ -61,8 +61,10 @@ public class MLContextProxy
*/
public static ArrayList<Instruction> performCleanupAfterRecompilation(ArrayList<Instruction> tmp)
{
- if(MLContext.getActiveMLContext() != null) {
- return MLContext.getActiveMLContext().performCleanupAfterRecompilation(tmp);
+ if(org.apache.sysml.api.MLContext.getActiveMLContext() != null) {
+ return org.apache.sysml.api.MLContext.getActiveMLContext().performCleanupAfterRecompilation(tmp);
+ } else if (org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext() != null) {
+ return org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext().getInternalProxy().performCleanupAfterRecompilation(tmp);
}
return tmp;
}
@@ -76,28 +78,55 @@ public class MLContextProxy
public static void setAppropriateVarsForRead(Expression source, String targetname)
throws LanguageException
{
- MLContext mlContext = MLContext.getActiveMLContext();
- if(mlContext != null) {
- mlContext.setAppropriateVarsForRead(source, targetname);
+ if(org.apache.sysml.api.MLContext.getActiveMLContext() != null) {
+ org.apache.sysml.api.MLContext.getActiveMLContext().setAppropriateVarsForRead(source, targetname);
+ } else if (org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext() != null) {
+ org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext().getInternalProxy().setAppropriateVarsForRead(source, targetname);
}
}
- public static MLContext getActiveMLContext() {
- return MLContext.getActiveMLContext();
+ public static Object getActiveMLContext() {
+ if (org.apache.sysml.api.MLContext.getActiveMLContext() != null) {
+ return org.apache.sysml.api.MLContext.getActiveMLContext();
+ } else if (org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext() != null) {
+ return org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext();
+ } else {
+ return null;
+ }
+
}
public static void setInstructionForMonitoring(Instruction inst) {
Location loc = inst.getLocation();
- MLContext mlContext = MLContext.getActiveMLContext();
- if(loc != null && mlContext != null && mlContext.getMonitoringUtil() != null) {
- mlContext.getMonitoringUtil().setInstructionLocation(loc, inst);
+ if (loc == null) {
+ return;
+ }
+
+ if (org.apache.sysml.api.MLContext.getActiveMLContext() != null) {
+ org.apache.sysml.api.MLContext mlContext = org.apache.sysml.api.MLContext.getActiveMLContext();
+ if(mlContext.getMonitoringUtil() != null) {
+ mlContext.getMonitoringUtil().setInstructionLocation(loc, inst);
+ }
+ } else if (org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext() != null) {
+ org.apache.sysml.api.mlcontext.MLContext mlContext = org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext();
+ if(mlContext.getSparkMonitoringUtil() != null) {
+ mlContext.getSparkMonitoringUtil().setInstructionLocation(loc, inst);
+ }
}
}
public static void addRDDForInstructionForMonitoring(SPInstruction inst, Integer rddID) {
- MLContext mlContext = MLContext.getActiveMLContext();
- if(mlContext != null && mlContext.getMonitoringUtil() != null) {
- mlContext.getMonitoringUtil().addRDDForInstruction(inst, rddID);
+
+ if (org.apache.sysml.api.MLContext.getActiveMLContext() != null) {
+ org.apache.sysml.api.MLContext mlContext = org.apache.sysml.api.MLContext.getActiveMLContext();
+ if(mlContext.getMonitoringUtil() != null) {
+ mlContext.getMonitoringUtil().addRDDForInstruction(inst, rddID);
+ }
+ } else if (org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext() != null) {
+ org.apache.sysml.api.mlcontext.MLContext mlContext = org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext();
+ if(mlContext.getSparkMonitoringUtil() != null) {
+ mlContext.getSparkMonitoringUtil().addRDDForInstruction(inst, rddID);
+ }
}
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java b/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java
new file mode 100644
index 0000000..8c9f923
--- /dev/null
+++ b/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.mlcontext;
+
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.sql.DataFrame;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
+
+/**
+ * BinaryBlockMatrix stores data as a SystemML binary-block representation.
+ *
+ */
+public class BinaryBlockMatrix {
+
+ JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlocks;
+ MatrixMetadata matrixMetadata;
+
+ /**
+ * Convert a Spark DataFrame to a SystemML binary-block representation.
+ *
+ * @param dataFrame
+ * the Spark DataFrame
+ * @param matrixMetadata
+ * matrix metadata, such as number of rows and columns
+ */
+ public BinaryBlockMatrix(DataFrame dataFrame, MatrixMetadata matrixMetadata) {
+ this.matrixMetadata = matrixMetadata;
+ binaryBlocks = MLContextConversionUtil.dataFrameToBinaryBlocks(dataFrame, matrixMetadata);
+ }
+
+ /**
+ * Convert a Spark DataFrame to a SystemML binary-block representation,
+ * specifying the number of rows and columns.
+ *
+ * @param dataFrame
+ * the Spark DataFrame
+ * @param numRows
+ * the number of rows
+ * @param numCols
+ * the number of columns
+ */
+ public BinaryBlockMatrix(DataFrame dataFrame, long numRows, long numCols) {
+ this(dataFrame, new MatrixMetadata(numRows, numCols, MLContextUtil.defaultBlockSize(),
+ MLContextUtil.defaultBlockSize()));
+ }
+
+ /**
+ * Convert a Spark DataFrame to a SystemML binary-block representation.
+ *
+ * @param dataFrame
+ * the Spark DataFrame
+ */
+ public BinaryBlockMatrix(DataFrame dataFrame) {
+ this(dataFrame, new MatrixMetadata());
+ }
+
+ /**
+ * Create a BinaryBlockMatrix, specifying the SystemML binary-block matrix
+ * and its metadata.
+ *
+ * @param binaryBlocks
+ * the {@code JavaPairRDD<MatrixIndexes, MatrixBlock>} matrix
+ * @param matrixCharacteristics
+ * the matrix metadata as {@code MatrixCharacteristics}
+ */
+ public BinaryBlockMatrix(JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlocks,
+ MatrixCharacteristics matrixCharacteristics) {
+ this.binaryBlocks = binaryBlocks;
+ this.matrixMetadata = new MatrixMetadata(matrixCharacteristics);
+ }
+
+ /**
+ * Obtain a SystemML binary-block matrix as a
+ * {@code JavaPairRDD<MatrixIndexes, MatrixBlock>}
+ *
+ * @return the SystemML binary-block matrix
+ */
+ public JavaPairRDD<MatrixIndexes, MatrixBlock> getBinaryBlocks() {
+ return binaryBlocks;
+ }
+
+ /**
+ * Obtain the SystemML binary-block matrix characteristics
+ *
+ * @return the matrix metadata as {@code MatrixCharacteristics}
+ */
+ public MatrixCharacteristics getMatrixCharacteristics() {
+ return matrixMetadata.asMatrixCharacteristics();
+ }
+
+ /**
+ * Obtain the SystemML binary-block matrix metadata
+ *
+ * @return the matrix metadata as {@code MatrixMetadata}
+ */
+ public MatrixMetadata getMatrixMetadata() {
+ return matrixMetadata;
+ }
+
+ /**
+ * Set the SystemML binary-block matrix metadata
+ *
+ * @param matrixMetadata
+ * the matrix metadata
+ */
+ public void setMatrixMetadata(MatrixMetadata matrixMetadata) {
+ this.matrixMetadata = matrixMetadata;
+ }
+
+ /**
+ * Set the SystemML binary-block matrix as a
+ * {@code JavaPairRDD<MatrixIndexes, MatrixBlock>}
+ *
+ * @param binaryBlocks
+ * the SystemML binary-block matrix
+ */
+ public void setBinaryBlocks(JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlocks) {
+ this.binaryBlocks = binaryBlocks;
+ }
+
+ @Override
+ public String toString() {
+ if (matrixMetadata != null) {
+ return matrixMetadata.toString();
+ } else {
+ return super.toString();
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
new file mode 100644
index 0000000..05deec2
--- /dev/null
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
@@ -0,0 +1,505 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.mlcontext;
+
+import java.util.ArrayList;
+import java.util.Date;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import org.apache.spark.SparkContext;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.api.MLContextProxy;
+import org.apache.sysml.api.monitoring.SparkMonitoringUtil;
+import org.apache.sysml.conf.ConfigurationManager;
+import org.apache.sysml.conf.DMLConfig;
+import org.apache.sysml.parser.DataExpression;
+import org.apache.sysml.parser.Expression;
+import org.apache.sysml.parser.IntIdentifier;
+import org.apache.sysml.parser.StringIdentifier;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
+import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.instructions.Instruction;
+import org.apache.sysml.runtime.instructions.cp.Data;
+import org.apache.sysml.runtime.instructions.cp.ScalarObject;
+import org.apache.sysml.runtime.instructions.cp.VariableCPInstruction;
+import org.apache.sysml.runtime.instructions.spark.functions.SparkListener;
+import org.apache.sysml.runtime.matrix.MatrixFormatMetaData;
+import org.apache.sysml.runtime.matrix.data.OutputInfo;
+
+/**
+ * The MLContext API offers programmatic access to SystemML on Spark from
+ * languages such as Scala, Java, and Python.
+ *
+ */
+public class MLContext {
+
+ /**
+ * Minimum Spark version supported by SystemML.
+ */
+ public static final String SYSTEMML_MINIMUM_SPARK_VERSION = "1.4.0";
+
+ /**
+ * SparkContext object.
+ */
+ private SparkContext sc = null;
+
+ /**
+ * SparkMonitoringUtil monitors SystemML performance on Spark.
+ */
+ private SparkMonitoringUtil sparkMonitoringUtil = null;
+
+ /**
+ * Reference to the currently executing script.
+ */
+ private Script executingScript = null;
+
+ /**
+ * The currently active MLContext.
+ */
+ private static MLContext activeMLContext = null;
+
+ /**
+ * Contains cleanup methods used by MLContextProxy.
+ */
+ private InternalProxy internalProxy = new InternalProxy();
+
+ /**
+ * Whether or not an explanation of the DML/PYDML program should be output
+ * to standard output.
+ */
+ private boolean explain = false;
+
+ /**
+ * Whether or not statistics of the DML/PYDML program execution should be
+ * output to standard output.
+ */
+ private boolean statistics = false;
+
+ private List<String> scriptHistoryStrings = new ArrayList<String>();
+ private Map<String, Script> scripts = new LinkedHashMap<String, Script>();
+
+ /**
+ * Retrieve the currently active MLContext. This is used internally by
+ * SystemML via MLContextProxy.
+ *
+ * @return the active MLContext
+ */
+ public static MLContext getActiveMLContext() {
+ return activeMLContext;
+ }
+
+ /**
+ * Create an MLContext based on a SparkContext for interaction with SystemML
+ * on Spark.
+ *
+ * @param sparkContext
+ * SparkContext
+ */
+ public MLContext(SparkContext sparkContext) {
+ this(sparkContext, false);
+ }
+
+ /**
+ * Create an MLContext based on a JavaSparkContext for interaction with
+ * SystemML on Spark.
+ *
+ * @param javaSparkContext
+ * JavaSparkContext
+ */
+ public MLContext(JavaSparkContext javaSparkContext) {
+ this(javaSparkContext.sc(), false);
+ }
+
+ /**
+ * Create an MLContext based on a SparkContext for interaction with SystemML
+ * on Spark, optionally monitor performance.
+ *
+ * @param sc
+ * SparkContext object.
+ * @param monitorPerformance
+ * {@code true} if performance should be monitored, {@code false}
+ * otherwise
+ */
+ public MLContext(SparkContext sc, boolean monitorPerformance) {
+ initMLContext(sc, monitorPerformance);
+ }
+
+ /**
+ * Initialize MLContext. Verify Spark version supported, set default
+ * execution mode, set MLContextProxy, set default config, set compiler
+ * config, and configure monitoring if needed.
+ *
+ * @param sc
+ * SparkContext object.
+ * @param monitorPerformance
+ * {@code true} if performance should be monitored, {@code false}
+ * otherwise
+ */
+ private void initMLContext(SparkContext sc, boolean monitorPerformance) {
+
+ if (activeMLContext == null) {
+ System.out.println(MLContextUtil.welcomeMessage());
+ }
+
+ this.sc = sc;
+ MLContextUtil.verifySparkVersionSupported(sc);
+ // by default, run in hybrid Spark mode for optimal performance
+ DMLScript.rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK;
+
+ activeMLContext = this;
+ MLContextProxy.setActive(true);
+
+ MLContextUtil.setDefaultConfig();
+ MLContextUtil.setCompilerConfig();
+
+ if (monitorPerformance) {
+ SparkListener sparkListener = new SparkListener(sc);
+ sparkMonitoringUtil = new SparkMonitoringUtil(sparkListener);
+ sc.addSparkListener(sparkListener);
+ }
+ }
+
+ /**
+ * Clean up the variables from the buffer pool, including evicted files,
+ * because the buffer pool holds references.
+ */
+ public void clearCache() {
+ CacheableData.cleanupCacheDir();
+ }
+
+ /**
+ * Reset configuration settings to default settings.
+ */
+ public void resetConfig() {
+ MLContextUtil.setDefaultConfig();
+ }
+
+ /**
+ * Set configuration property, such as
+ * {@code setConfigProperty("localtmpdir", "/tmp/systemml")}.
+ *
+ * @param propertyName
+ * property name
+ * @param propertyValue
+ * property value
+ */
+ public void setConfigProperty(String propertyName, String propertyValue) {
+ DMLConfig config = ConfigurationManager.getDMLConfig();
+ try {
+ config.setTextValue(propertyName, propertyValue);
+ } catch (DMLRuntimeException e) {
+ throw new MLContextException(e);
+ }
+ }
+
+ /**
+ * Execute a DML or PYDML Script.
+ *
+ * @param script
+ * The DML or PYDML Script object to execute.
+ */
+ public MLResults execute(Script script) {
+ ScriptExecutor scriptExecutor = new ScriptExecutor(sparkMonitoringUtil);
+ scriptExecutor.setExplain(explain);
+ scriptExecutor.setStatistics(statistics);
+ return execute(script, scriptExecutor);
+ }
+
+ /**
+ * Execute a DML or PYDML Script object using a ScriptExecutor. The
+ * ScriptExecutor class can be extended to allow the modification of the
+ * default execution pathway.
+ *
+ * @param script
+ * the DML or PYDML Script object
+ * @param scriptExecutor
+ * the ScriptExecutor that defines the script execution pathway
+ */
+ public MLResults execute(Script script, ScriptExecutor scriptExecutor) {
+ try {
+ executingScript = script;
+
+ Long time = new Long((new Date()).getTime());
+ if ((script.getName() == null) || (script.getName().equals(""))) {
+ script.setName(time.toString());
+ }
+
+ MLResults results = scriptExecutor.execute(script);
+
+ String history = MLContextUtil.createHistoryForScript(script, time);
+ scriptHistoryStrings.add(history);
+ scripts.put(script.getName(), script);
+
+ return results;
+ } catch (RuntimeException e) {
+ throw new MLContextException("Exception when executing script", e);
+ }
+ }
+
+ /**
+ * Set SystemML configuration based on a configuration file.
+ *
+ * @param configFilePath
+ * path to the configuration file
+ */
+ public void setConfig(String configFilePath) {
+ MLContextUtil.setConfig(configFilePath);
+ }
+
+ /**
+ * Obtain the SparkMonitoringUtil if it is available.
+ *
+ * @return the SparkMonitoringUtil if it is available.
+ */
+ public SparkMonitoringUtil getSparkMonitoringUtil() {
+ return sparkMonitoringUtil;
+ }
+
+ /**
+ * Obtain the SparkContext associated with this MLContext.
+ *
+ * @return the SparkContext associated with this MLContext.
+ */
+ public SparkContext getSparkContext() {
+ return sc;
+ }
+
+ /**
+ * Whether or not an explanation of the DML/PYDML program should be output
+ * to standard output.
+ *
+ * @return {@code true} if explanation should be output, {@code false}
+ * otherwise
+ */
+ public boolean isExplain() {
+ return explain;
+ }
+
+ /**
+ * Whether or not an explanation of the DML/PYDML program should be output
+ * to standard output.
+ *
+ * @param explain
+ * {@code true} if explanation should be output, {@code false}
+ * otherwise
+ */
+ public void setExplain(boolean explain) {
+ this.explain = explain;
+ }
+
+ /**
+ * Used internally by MLContextProxy.
+ *
+ */
+ public class InternalProxy {
+
+ public void setAppropriateVarsForRead(Expression source, String target) {
+ boolean isTargetRegistered = isRegisteredAsInput(target);
+ boolean isReadExpression = (source instanceof DataExpression && ((DataExpression) source).isRead());
+ if (isTargetRegistered && isReadExpression) {
+ DataExpression exp = (DataExpression) source;
+ // Do not check metadata file for registered reads
+ exp.setCheckMetadata(false);
+
+ MatrixObject mo = getMatrixObject(target);
+ if (mo != null) {
+ int blp = source.getBeginLine();
+ int bcp = source.getBeginColumn();
+ int elp = source.getEndLine();
+ int ecp = source.getEndColumn();
+ exp.addVarParam(DataExpression.READROWPARAM,
+ new IntIdentifier(mo.getNumRows(), source.getFilename(), blp, bcp, elp, ecp));
+ exp.addVarParam(DataExpression.READCOLPARAM,
+ new IntIdentifier(mo.getNumColumns(), source.getFilename(), blp, bcp, elp, ecp));
+ exp.addVarParam(DataExpression.READNUMNONZEROPARAM,
+ new IntIdentifier(mo.getNnz(), source.getFilename(), blp, bcp, elp, ecp));
+ exp.addVarParam(DataExpression.DATATYPEPARAM, new StringIdentifier("matrix", source.getFilename(),
+ blp, bcp, elp, ecp));
+ exp.addVarParam(DataExpression.VALUETYPEPARAM, new StringIdentifier("double", source.getFilename(),
+ blp, bcp, elp, ecp));
+
+ if (mo.getMetaData() instanceof MatrixFormatMetaData) {
+ MatrixFormatMetaData metaData = (MatrixFormatMetaData) mo.getMetaData();
+ if (metaData.getOutputInfo() == OutputInfo.CSVOutputInfo) {
+ exp.addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(
+ DataExpression.FORMAT_TYPE_VALUE_CSV, source.getFilename(), blp, bcp, elp, ecp));
+ } else if (metaData.getOutputInfo() == OutputInfo.TextCellOutputInfo) {
+ exp.addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(
+ DataExpression.FORMAT_TYPE_VALUE_TEXT, source.getFilename(), blp, bcp, elp, ecp));
+ } else if (metaData.getOutputInfo() == OutputInfo.BinaryBlockOutputInfo) {
+ exp.addVarParam(
+ DataExpression.ROWBLOCKCOUNTPARAM,
+ new IntIdentifier(mo.getNumRowsPerBlock(), source.getFilename(), blp, bcp, elp, ecp));
+ exp.addVarParam(DataExpression.COLUMNBLOCKCOUNTPARAM,
+ new IntIdentifier(mo.getNumColumnsPerBlock(), source.getFilename(), blp, bcp, elp,
+ ecp));
+ exp.addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(
+ DataExpression.FORMAT_TYPE_VALUE_BINARY, source.getFilename(), blp, bcp, elp, ecp));
+ } else {
+ throw new MLContextException("Unsupported format through MLContext");
+ }
+ }
+ }
+
+ }
+ }
+
+ private boolean isRegisteredAsInput(String parameterName) {
+ if (executingScript != null) {
+ Set<String> inputVariableNames = executingScript.getInputVariables();
+ if (inputVariableNames != null) {
+ return inputVariableNames.contains(parameterName);
+ }
+ }
+ return false;
+ }
+
+ private MatrixObject getMatrixObject(String parameterName) {
+ if (executingScript != null) {
+ LocalVariableMap symbolTable = executingScript.getSymbolTable();
+ if (symbolTable != null) {
+ Data data = symbolTable.get(parameterName);
+ if (data instanceof MatrixObject) {
+ return (MatrixObject) data;
+ } else {
+ if (data instanceof ScalarObject) {
+ return null;
+ }
+ }
+ }
+ }
+ throw new MLContextException("getMatrixObject not set for parameter: " + parameterName);
+ }
+
+ public ArrayList<Instruction> performCleanupAfterRecompilation(ArrayList<Instruction> instructions) {
+ if (executingScript == null) {
+ return instructions;
+ }
+ Set<String> outputVariableNames = executingScript.getOutputVariables();
+ if (outputVariableNames == null) {
+ return instructions;
+ }
+
+ for (int i = 0; i < instructions.size(); i++) {
+ Instruction inst = instructions.get(i);
+ if (inst instanceof VariableCPInstruction && ((VariableCPInstruction) inst).isRemoveVariable()) {
+ VariableCPInstruction varInst = (VariableCPInstruction) inst;
+ for (String outputVariableName : outputVariableNames)
+ if (varInst.isRemoveVariable(outputVariableName)) {
+ instructions.remove(i);
+ i--;
+ break;
+ }
+ }
+ }
+ return instructions;
+ }
+ }
+
+ /**
+ * Used internally by MLContextProxy.
+ *
+ */
+ public InternalProxy getInternalProxy() {
+ return internalProxy;
+ }
+
+ /**
+ * Whether or not statistics of the DML/PYDML program execution should be
+ * output to standard output.
+ *
+ * @return {@code true} if statistics should be output, {@code false}
+ * otherwise
+ */
+ public boolean isStatistics() {
+ return statistics;
+ }
+
+ /**
+ * Whether or not statistics of the DML/PYDML program execution should be
+ * output to standard output.
+ *
+ * @param statistics
+ * {@code true} if statistics should be output, {@code false}
+ * otherwise
+ */
+ public void setStatistics(boolean statistics) {
+ DMLScript.STATISTICS = statistics;
+ this.statistics = statistics;
+ }
+
+ /**
+ * Obtain a map of the scripts that have executed.
+ *
+ * @return a map of the scripts that have executed
+ */
+ public Map<String, Script> getScripts() {
+ return scripts;
+ }
+
+ /**
+ * Obtain a script that has executed by name.
+ *
+ * @param name
+ * the name of the script
+ * @return the script corresponding to the name
+ */
+ public Script getScriptByName(String name) {
+ Script script = scripts.get(name);
+ if (script == null) {
+ throw new MLContextException("Script with name '" + name + "' not found.");
+ }
+ return script;
+ }
+
+ /**
+ * Display the history of scripts that have executed.
+ *
+ * @return the history of scripts that have executed
+ */
+ public String history() {
+ return MLContextUtil.displayScriptHistory(scriptHistoryStrings);
+ }
+
+ /**
+ * Clear all the scripts, removing them from the history, and clear the
+ * cache.
+ */
+ public void clear() {
+ Set<String> scriptNames = scripts.keySet();
+ for (String scriptName : scriptNames) {
+ Script script = scripts.get(scriptName);
+ script.clearAll();
+ }
+
+ scripts.clear();
+ scriptHistoryStrings.clear();
+
+ clearCache();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
new file mode 100644
index 0000000..33226d2
--- /dev/null
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
@@ -0,0 +1,720 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.mlcontext;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.spark.Accumulator;
+import org.apache.spark.SparkContext;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
+import org.apache.sysml.api.MLContextProxy;
+import org.apache.sysml.parser.Expression.ValueType;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.CacheException;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysml.runtime.instructions.spark.data.RDDObject;
+import org.apache.sysml.runtime.instructions.spark.functions.ConvertStringToLongTextPair;
+import org.apache.sysml.runtime.instructions.spark.functions.CopyBlockPairFunction;
+import org.apache.sysml.runtime.instructions.spark.functions.CopyTextInputFunction;
+import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
+import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt;
+import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt.DataFrameAnalysisFunction;
+import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt.DataFrameToBinaryBlockFunction;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+import org.apache.sysml.runtime.matrix.MatrixFormatMetaData;
+import org.apache.sysml.runtime.matrix.data.IJV;
+import org.apache.sysml.runtime.matrix.data.InputInfo;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysml.runtime.matrix.data.OutputInfo;
+import org.apache.sysml.runtime.util.DataConverter;
+import org.apache.sysml.runtime.util.UtilFunctions;
+
+import scala.collection.JavaConversions;
+import scala.reflect.ClassTag;
+
+/**
+ * Utility class containing methods to perform data conversions.
+ *
+ */
+public class MLContextConversionUtil {
+
/**
 * Convert a two-dimensional double array to a {@code MatrixObject}.
 * Convenience overload that supplies no metadata (default block sizes are
 * derived from the array dimensions by the three-argument overload).
 *
 * @param variableName
 *            name of the variable associated with the matrix
 * @param doubleMatrix
 *            matrix of double values
 * @return the two-dimensional double matrix converted to a
 *         {@code MatrixObject}
 */
public static MatrixObject doubleMatrixToMatrixObject(String variableName, double[][] doubleMatrix) {
    return doubleMatrixToMatrixObject(variableName, doubleMatrix, null);
}
+
+ /**
+ * Convert a two-dimensional double array to a {@code MatrixObject}.
+ *
+ * @param variableName
+ * name of the variable associated with the matrix
+ * @param doubleMatrix
+ * matrix of double values
+ * @param matrixMetadata
+ * the matrix metadata
+ * @return the two-dimensional double matrix converted to a
+ * {@code MatrixObject}
+ */
+ public static MatrixObject doubleMatrixToMatrixObject(String variableName, double[][] doubleMatrix,
+ MatrixMetadata matrixMetadata) {
+ try {
+ MatrixBlock matrixBlock = DataConverter.convertToMatrixBlock(doubleMatrix);
+ MatrixCharacteristics matrixCharacteristics;
+ if (matrixMetadata != null) {
+ matrixCharacteristics = matrixMetadata.asMatrixCharacteristics();
+ } else {
+ matrixCharacteristics = new MatrixCharacteristics(matrixBlock.getNumRows(),
+ matrixBlock.getNumColumns(), MLContextUtil.defaultBlockSize(), MLContextUtil.defaultBlockSize());
+ }
+
+ MatrixFormatMetaData meta = new MatrixFormatMetaData(matrixCharacteristics,
+ OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);
+ MatrixObject matrixObject = new MatrixObject(ValueType.DOUBLE, MLContextUtil.scratchSpace() + "/"
+ + variableName, meta);
+ matrixObject.acquireModify(matrixBlock);
+ matrixObject.release();
+ return matrixObject;
+ } catch (DMLRuntimeException e) {
+ throw new MLContextException("Exception converting double[][] array to MatrixObject", e);
+ }
+ }
+
/**
 * Convert a {@code JavaPairRDD<MatrixIndexes, MatrixBlock>} to a
 * {@code MatrixObject}. Convenience overload that supplies no metadata.
 *
 * @param variableName
 *            name of the variable associated with the matrix
 * @param binaryBlocks
 *            {@code JavaPairRDD<MatrixIndexes, MatrixBlock>} representation
 *            of a binary-block matrix
 * @return the {@code JavaPairRDD<MatrixIndexes, MatrixBlock>} matrix
 *         converted to a {@code MatrixObject}
 */
public static MatrixObject binaryBlocksToMatrixObject(String variableName,
        JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlocks) {
    return binaryBlocksToMatrixObject(variableName, binaryBlocks, null);
}
+
+ /**
+ * Convert a {@code JavaPairRDD<MatrixIndexes, MatrixBlock>} to a
+ * {@code MatrixObject}.
+ *
+ * @param variableName
+ * name of the variable associated with the matrix
+ * @param binaryBlocks
+ * {@code JavaPairRDD<MatrixIndexes, MatrixBlock>} representation
+ * of a binary-block matrix
+ * @param matrixMetadata
+ * the matrix metadata
+ * @return the {@code JavaPairRDD<MatrixIndexes, MatrixBlock>} matrix
+ * converted to a {@code MatrixObject}
+ */
+ public static MatrixObject binaryBlocksToMatrixObject(String variableName,
+ JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlocks, MatrixMetadata matrixMetadata) {
+
+ MatrixCharacteristics matrixCharacteristics;
+ if (matrixMetadata != null) {
+ matrixCharacteristics = matrixMetadata.asMatrixCharacteristics();
+ } else {
+ matrixCharacteristics = new MatrixCharacteristics();
+ }
+
+ JavaPairRDD<MatrixIndexes, MatrixBlock> javaPairRdd = binaryBlocks.mapToPair(new CopyBlockPairFunction());
+
+ MatrixObject matrixObject = new MatrixObject(ValueType.DOUBLE, MLContextUtil.scratchSpace() + "/" + "temp_"
+ + System.nanoTime(), new MatrixFormatMetaData(matrixCharacteristics, OutputInfo.BinaryBlockOutputInfo,
+ InputInfo.BinaryBlockInputInfo));
+ matrixObject.setRDDHandle(new RDDObject(javaPairRdd, variableName));
+ return matrixObject;
+ }
+
/**
 * Convert a {@code DataFrame} to a {@code MatrixObject}. Convenience
 * overload that supplies no metadata.
 *
 * @param variableName
 *            name of the variable associated with the matrix
 * @param dataFrame
 *            the Spark {@code DataFrame}
 * @return the {@code DataFrame} matrix converted to a
 *         {@code MatrixObject}
 */
public static MatrixObject dataFrameToMatrixObject(String variableName, DataFrame dataFrame) {
    return dataFrameToMatrixObject(variableName, dataFrame, null);
}
+
+ /**
+ * Convert a {@code DataFrame} to a {@code MatrixObject}.
+ *
+ * @param variableName
+ * name of the variable associated with the matrix
+ * @param dataFrame
+ * the Spark {@code DataFrame}
+ * @param matrixMetadata
+ * the matrix metadata
+ * @return the {@code DataFrame} matrix converted to a converted to a
+ * {@code MatrixObject}
+ */
+ public static MatrixObject dataFrameToMatrixObject(String variableName, DataFrame dataFrame,
+ MatrixMetadata matrixMetadata) {
+ if (matrixMetadata == null) {
+ matrixMetadata = new MatrixMetadata();
+ }
+ JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlock = MLContextConversionUtil.dataFrameToBinaryBlocks(
+ dataFrame, matrixMetadata);
+ MatrixObject matrixObject = MLContextConversionUtil.binaryBlocksToMatrixObject(variableName, binaryBlock,
+ matrixMetadata);
+ return matrixObject;
+ }
+
/**
 * Convert a {@code DataFrame} to a
 * {@code JavaPairRDD<MatrixIndexes, MatrixBlock>} binary-block matrix.
 * Convenience overload that supplies no metadata.
 *
 * @param dataFrame
 *            the Spark {@code DataFrame}
 * @return the {@code DataFrame} matrix converted to a
 *         {@code JavaPairRDD<MatrixIndexes,
 *         MatrixBlock>} binary-block matrix
 */
public static JavaPairRDD<MatrixIndexes, MatrixBlock> dataFrameToBinaryBlocks(DataFrame dataFrame) {
    return dataFrameToBinaryBlocks(dataFrame, null);
}
+
+ /**
+ * Convert a {@code DataFrame} to a
+ * {@code JavaPairRDD<MatrixIndexes, MatrixBlock>} binary-block matrix.
+ *
+ * @param dataFrame
+ * the Spark {@code DataFrame}
+ * @param matrixMetadata
+ * the matrix metadata
+ * @return the {@code DataFrame} matrix converted to a
+ * {@code JavaPairRDD<MatrixIndexes,
+ * MatrixBlock>} binary-block matrix
+ */
+ public static JavaPairRDD<MatrixIndexes, MatrixBlock> dataFrameToBinaryBlocks(DataFrame dataFrame,
+ MatrixMetadata matrixMetadata) {
+
+ MatrixCharacteristics matrixCharacteristics;
+ if (matrixMetadata != null) {
+ matrixCharacteristics = matrixMetadata.asMatrixCharacteristics();
+ if (matrixCharacteristics == null) {
+ matrixCharacteristics = new MatrixCharacteristics();
+ }
+ } else {
+ matrixCharacteristics = new MatrixCharacteristics();
+ }
+ determineDataFrameDimensionsIfNeeded(dataFrame, matrixCharacteristics);
+ if (matrixMetadata != null) {
+ // so external reference can be updated with the metadata
+ matrixMetadata.setMatrixCharacteristics(matrixCharacteristics);
+ }
+
+ JavaRDD<Row> javaRDD = dataFrame.javaRDD();
+ JavaPairRDD<Row, Long> prepinput = javaRDD.zipWithIndex();
+ JavaPairRDD<MatrixIndexes, MatrixBlock> out = prepinput.mapPartitionsToPair(new DataFrameToBinaryBlockFunction(
+ matrixCharacteristics, false));
+ out = RDDAggregateUtils.mergeByKey(out);
+ return out;
+ }
+
+ /**
+ * If the {@code DataFrame} dimensions aren't present in the
+ * {@code MatrixCharacteristics} metadata, determine the dimensions and
+ * place them in the {@code MatrixCharacteristics} metadata.
+ *
+ * @param dataFrame
+ * the Spark {@code DataFrame}
+ * @param matrixCharacteristics
+ * the matrix metadata
+ */
+ public static void determineDataFrameDimensionsIfNeeded(DataFrame dataFrame,
+ MatrixCharacteristics matrixCharacteristics) {
+ if (!matrixCharacteristics.dimsKnown(true)) {
+ // only available to the new MLContext API, not the old API
+ MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContext();
+ SparkContext sparkContext = activeMLContext.getSparkContext();
+ @SuppressWarnings("resource")
+ JavaSparkContext javaSparkContext = new JavaSparkContext(sparkContext);
+
+ Accumulator<Double> aNnz = javaSparkContext.accumulator(0L);
+ JavaRDD<Row> javaRDD = dataFrame.javaRDD().map(new DataFrameAnalysisFunction(aNnz, false));
+ long numRows = javaRDD.count();
+ long numColumns = dataFrame.columns().length;
+ long numNonZeros = UtilFunctions.toLong(aNnz.value());
+ matrixCharacteristics.set(numRows, numColumns, matrixCharacteristics.getRowsPerBlock(),
+ matrixCharacteristics.getColsPerBlock(), numNonZeros);
+ }
+ }
+
/**
 * Convert a {@code JavaRDD<String>} in CSV format to a {@code MatrixObject}.
 * Convenience overload that supplies no metadata.
 *
 * @param variableName
 *            name of the variable associated with the matrix
 * @param javaRDD
 *            the Java RDD of strings
 * @return the {@code JavaRDD<String>} converted to a {@code MatrixObject}
 */
public static MatrixObject javaRDDStringCSVToMatrixObject(String variableName, JavaRDD<String> javaRDD) {
    return javaRDDStringCSVToMatrixObject(variableName, javaRDD, null);
}
+
+ /**
+ * Convert a {@code JavaRDD<String>} in CSV format to a {@code MatrixObject}
+ *
+ * @param variableName
+ * name of the variable associated with the matrix
+ * @param javaRDD
+ * the Java RDD of strings
+ * @param matrixMetadata
+ * matrix metadata
+ * @return the {@code JavaRDD<String>} converted to a {@code MatrixObject}
+ */
+ public static MatrixObject javaRDDStringCSVToMatrixObject(String variableName, JavaRDD<String> javaRDD,
+ MatrixMetadata matrixMetadata) {
+ JavaPairRDD<LongWritable, Text> javaPairRDD = javaRDD.mapToPair(new ConvertStringToLongTextPair());
+ MatrixCharacteristics matrixCharacteristics;
+ if (matrixMetadata != null) {
+ matrixCharacteristics = matrixMetadata.asMatrixCharacteristics();
+ } else {
+ matrixCharacteristics = new MatrixCharacteristics();
+ }
+ MatrixObject matrixObject = new MatrixObject(ValueType.DOUBLE, null, new MatrixFormatMetaData(
+ matrixCharacteristics, OutputInfo.CSVOutputInfo, InputInfo.CSVInputInfo));
+ JavaPairRDD<LongWritable, Text> javaPairRDD2 = javaPairRDD.mapToPair(new CopyTextInputFunction());
+ matrixObject.setRDDHandle(new RDDObject(javaPairRDD2, variableName));
+ return matrixObject;
+ }
+
+ /**
+ * Convert a {@code JavaRDD<String>} in IJV format to a {@code MatrixObject}
+ * . Note that metadata is required for IJV format.
+ *
+ * @param variableName
+ * name of the variable associated with the matrix
+ * @param javaRDD
+ * the Java RDD of strings
+ * @param matrixMetadata
+ * matrix metadata
+ * @return the {@code JavaRDD<String>} converted to a {@code MatrixObject}
+ */
+ public static MatrixObject javaRDDStringIJVToMatrixObject(String variableName, JavaRDD<String> javaRDD,
+ MatrixMetadata matrixMetadata) {
+ JavaPairRDD<LongWritable, Text> javaPairRDD = javaRDD.mapToPair(new ConvertStringToLongTextPair());
+ MatrixCharacteristics matrixCharacteristics;
+ if (matrixMetadata != null) {
+ matrixCharacteristics = matrixMetadata.asMatrixCharacteristics();
+ } else {
+ matrixCharacteristics = new MatrixCharacteristics();
+ }
+ MatrixObject matrixObject = new MatrixObject(ValueType.DOUBLE, null, new MatrixFormatMetaData(
+ matrixCharacteristics, OutputInfo.TextCellOutputInfo, InputInfo.TextCellInputInfo));
+ JavaPairRDD<LongWritable, Text> javaPairRDD2 = javaPairRDD.mapToPair(new CopyTextInputFunction());
+ matrixObject.setRDDHandle(new RDDObject(javaPairRDD2, variableName));
+ return matrixObject;
+ }
+
/**
 * Convert a {@code RDD<String>} in CSV format to a {@code MatrixObject}.
 * Convenience overload that supplies no metadata.
 *
 * @param variableName
 *            name of the variable associated with the matrix
 * @param rdd
 *            the RDD of strings
 * @return the {@code RDD<String>} converted to a {@code MatrixObject}
 */
public static MatrixObject rddStringCSVToMatrixObject(String variableName, RDD<String> rdd) {
    return rddStringCSVToMatrixObject(variableName, rdd, null);
}
+
+ /**
+ * Convert a {@code RDD<String>} in CSV format to a {@code MatrixObject}
+ *
+ * @param variableName
+ * name of the variable associated with the matrix
+ * @param rdd
+ * the RDD of strings
+ * @param matrixMetadata
+ * matrix metadata
+ * @return the {@code RDD<String>} converted to a {@code MatrixObject}
+ */
+ public static MatrixObject rddStringCSVToMatrixObject(String variableName, RDD<String> rdd,
+ MatrixMetadata matrixMetadata) {
+ ClassTag<String> tag = scala.reflect.ClassTag$.MODULE$.apply(String.class);
+ JavaRDD<String> javaRDD = JavaRDD.fromRDD(rdd, tag);
+ return javaRDDStringCSVToMatrixObject(variableName, javaRDD, matrixMetadata);
+ }
+
+ /**
+ * Convert a {@code RDD<String>} in IJV format to a {@code MatrixObject}.
+ * Note that metadata is required for IJV format.
+ *
+ * @param variableName
+ * name of the variable associated with the matrix
+ * @param rdd
+ * the RDD of strings
+ * @param matrixMetadata
+ * matrix metadata
+ * @return the {@code RDD<String>} converted to a {@code MatrixObject}
+ */
+ public static MatrixObject rddStringIJVToMatrixObject(String variableName, RDD<String> rdd,
+ MatrixMetadata matrixMetadata) {
+ ClassTag<String> tag = scala.reflect.ClassTag$.MODULE$.apply(String.class);
+ JavaRDD<String> javaRDD = JavaRDD.fromRDD(rdd, tag);
+ return javaRDDStringIJVToMatrixObject(variableName, javaRDD, matrixMetadata);
+ }
+
+ /**
+ * Convert an {@code BinaryBlockMatrix} to a {@code JavaRDD<String>} in IVJ
+ * format.
+ *
+ * @param binaryBlockMatrix
+ * the {@code BinaryBlockMatrix}
+ * @return the {@code BinaryBlockMatrix} converted to a
+ * {@code JavaRDD<String>}
+ */
+ public static JavaRDD<String> binaryBlockMatrixToJavaRDDStringIJV(BinaryBlockMatrix binaryBlockMatrix) {
+ JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlock = binaryBlockMatrix.getBinaryBlocks();
+ MatrixCharacteristics matrixCharacteristics = binaryBlockMatrix.getMatrixCharacteristics();
+ try {
+ JavaRDD<String> javaRDDString = RDDConverterUtilsExt.binaryBlockToStringRDD(binaryBlock,
+ matrixCharacteristics, "text");
+ return javaRDDString;
+ } catch (DMLRuntimeException e) {
+ throw new MLContextException("Exception converting BinaryBlockMatrix to JavaRDD<String> (ijv)", e);
+ }
+ }
+
/**
 * Convert an {@code BinaryBlockMatrix} to a {@code RDD<String>} in IJV
 * format.
 *
 * @param binaryBlockMatrix
 *            the {@code BinaryBlockMatrix}
 * @return the {@code BinaryBlockMatrix} converted to a {@code RDD<String>}
 */
public static RDD<String> binaryBlockMatrixToRDDStringIJV(BinaryBlockMatrix binaryBlockMatrix) {
    // convert via the JavaRDD form, then unwrap to the Scala RDD
    JavaRDD<String> javaRDD = binaryBlockMatrixToJavaRDDStringIJV(binaryBlockMatrix);
    RDD<String> rdd = JavaRDD.toRDD(javaRDD);
    return rdd;
}
+
+ /**
+ * Convert a {@code MatrixObject} to a {@code JavaRDD<String>} in CSV
+ * format.
+ *
+ * @param matrixObject
+ * the {@code MatrixObject}
+ * @return the {@code MatrixObject} converted to a {@code JavaRDD<String>}
+ */
+ public static JavaRDD<String> matrixObjectToJavaRDDStringCSV(MatrixObject matrixObject) {
+ List<String> list = matrixObjectToListStringCSV(matrixObject);
+
+ MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContext();
+ SparkContext sc = activeMLContext.getSparkContext();
+ @SuppressWarnings("resource")
+ JavaSparkContext jsc = new JavaSparkContext(sc);
+ JavaRDD<String> javaRDDStringCSV = jsc.parallelize(list);
+ return javaRDDStringCSV;
+ }
+
+ /**
+ * Convert a {@code MatrixObject} to a {@code JavaRDD<String>} in IJV
+ * format.
+ *
+ * @param matrixObject
+ * the {@code MatrixObject}
+ * @return the {@code MatrixObject} converted to a {@code JavaRDD<String>}
+ */
+ public static JavaRDD<String> matrixObjectToJavaRDDStringIJV(MatrixObject matrixObject) {
+ List<String> list = matrixObjectToListStringIJV(matrixObject);
+
+ MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContext();
+ SparkContext sc = activeMLContext.getSparkContext();
+ @SuppressWarnings("resource")
+ JavaSparkContext jsc = new JavaSparkContext(sc);
+ JavaRDD<String> javaRDDStringCSV = jsc.parallelize(list);
+ return javaRDDStringCSV;
+ }
+
/**
 * Convert a {@code MatrixObject} to a {@code RDD<String>} in IJV format.
 * The matrix content is materialized locally before being parallelized.
 *
 * @param matrixObject
 *            the {@code MatrixObject}
 * @return the {@code MatrixObject} converted to a {@code RDD<String>}
 */
public static RDD<String> matrixObjectToRDDStringIJV(MatrixObject matrixObject) {

    // NOTE: The following works when called from Java but does not
    // currently work when called from Spark Shell (when you call
    // collect() on the RDD<String>).
    //
    // JavaRDD<String> javaRDD = jsc.parallelize(list);
    // RDD<String> rdd = JavaRDD.toRDD(javaRDD);
    //
    // Therefore, we call parallelize() on the SparkContext rather than
    // the JavaSparkContext to produce the RDD<String> for Scala.

    List<String> list = matrixObjectToListStringIJV(matrixObject);

    MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContext();
    SparkContext sc = activeMLContext.getSparkContext();
    // Scala parallelize requires a Scala Seq and an explicit ClassTag
    ClassTag<String> tag = scala.reflect.ClassTag$.MODULE$.apply(String.class);
    RDD<String> rddString = sc.parallelize(JavaConversions.asScalaBuffer(list), sc.defaultParallelism(), tag);
    return rddString;
}
+
/**
 * Convert a {@code MatrixObject} to a {@code RDD<String>} in CSV format.
 * The matrix content is materialized locally before being parallelized.
 *
 * @param matrixObject
 *            the {@code MatrixObject}
 * @return the {@code MatrixObject} converted to a {@code RDD<String>}
 */
public static RDD<String> matrixObjectToRDDStringCSV(MatrixObject matrixObject) {

    // NOTE: The following works when called from Java but does not
    // currently work when called from Spark Shell (when you call
    // collect() on the RDD<String>).
    //
    // JavaRDD<String> javaRDD = jsc.parallelize(list);
    // RDD<String> rdd = JavaRDD.toRDD(javaRDD);
    //
    // Therefore, we call parallelize() on the SparkContext rather than
    // the JavaSparkContext to produce the RDD<String> for Scala.

    List<String> list = matrixObjectToListStringCSV(matrixObject);

    MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContext();
    SparkContext sc = activeMLContext.getSparkContext();
    // Scala parallelize requires a Scala Seq and an explicit ClassTag
    ClassTag<String> tag = scala.reflect.ClassTag$.MODULE$.apply(String.class);
    RDD<String> rddString = sc.parallelize(JavaConversions.asScalaBuffer(list), sc.defaultParallelism(), tag);
    return rddString;
}
+
+ /**
+ * Convert a {@code MatrixObject} to a {@code List<String>} in CSV format.
+ *
+ * @param matrixObject
+ * the {@code MatrixObject}
+ * @return the {@code MatrixObject} converted to a {@code List<String>}
+ */
+ public static List<String> matrixObjectToListStringCSV(MatrixObject matrixObject) {
+ try {
+ MatrixBlock mb = matrixObject.acquireRead();
+
+ int rows = mb.getNumRows();
+ int cols = mb.getNumColumns();
+ List<String> list = new ArrayList<String>();
+
+ if (mb.getNonZeros() > 0) {
+ if (mb.isInSparseFormat()) {
+ Iterator<IJV> iter = mb.getSparseBlockIterator();
+ int prevCellRow = -1;
+ StringBuilder sb = null;
+ while (iter.hasNext()) {
+ IJV cell = iter.next();
+ int i = cell.getI();
+ double v = cell.getV();
+ if (i > prevCellRow) {
+ if (sb == null) {
+ sb = new StringBuilder();
+ } else {
+ list.add(sb.toString());
+ sb = new StringBuilder();
+ }
+ sb.append(v);
+ prevCellRow = i;
+ } else if (i == prevCellRow) {
+ sb.append(",");
+ sb.append(v);
+ }
+ }
+ if (sb != null) {
+ list.add(sb.toString());
+ }
+ } else {
+ for (int i = 0; i < rows; i++) {
+ StringBuilder sb = new StringBuilder();
+ for (int j = 0; j < cols; j++) {
+ if (j > 0) {
+ sb.append(",");
+ }
+ sb.append(mb.getValueDenseUnsafe(i, j));
+ }
+ list.add(sb.toString());
+ }
+ }
+ }
+
+ matrixObject.release();
+ return list;
+ } catch (CacheException e) {
+ throw new MLContextException("Cache exception while converting matrix object to List<String> CSV format", e);
+ }
+ }
+
+ /**
+ * Convert a {@code MatrixObject} to a {@code List<String>} in IJV format.
+ *
+ * @param matrixObject
+ * the {@code MatrixObject}
+ * @return the {@code MatrixObject} converted to a {@code List<String>}
+ */
+ public static List<String> matrixObjectToListStringIJV(MatrixObject matrixObject) {
+ try {
+ MatrixBlock mb = matrixObject.acquireRead();
+
+ int rows = mb.getNumRows();
+ int cols = mb.getNumColumns();
+ List<String> list = new ArrayList<String>();
+
+ if (mb.getNonZeros() > 0) {
+ if (mb.isInSparseFormat()) {
+ Iterator<IJV> iter = mb.getSparseBlockIterator();
+ StringBuilder sb = null;
+ while (iter.hasNext()) {
+ IJV cell = iter.next();
+ sb = new StringBuilder();
+ sb.append(cell.getI() + 1);
+ sb.append(" ");
+ sb.append(cell.getJ() + 1);
+ sb.append(" ");
+ sb.append(cell.getV());
+ list.add(sb.toString());
+ }
+ } else {
+ StringBuilder sb = null;
+ for (int i = 0; i < rows; i++) {
+ sb = new StringBuilder();
+ for (int j = 0; j < cols; j++) {
+ sb = new StringBuilder();
+ sb.append(i + 1);
+ sb.append(" ");
+ sb.append(j + 1);
+ sb.append(" ");
+ sb.append(mb.getValueDenseUnsafe(i, j));
+ list.add(sb.toString());
+ }
+ }
+ }
+ }
+
+ matrixObject.release();
+ return list;
+ } catch (CacheException e) {
+ throw new MLContextException("Cache exception while converting matrix object to List<String> IJV format", e);
+ }
+ }
+
/**
 * Convert a {@code MatrixObject} to a two-dimensional double array. The
 * full matrix is materialized in local memory.
 *
 * @param matrixObject
 *            the {@code MatrixObject}
 * @return the {@code MatrixObject} converted to a {@code double[][]}
 */
public static double[][] matrixObjectToDoubleMatrix(MatrixObject matrixObject) {
    try {
        MatrixBlock mb = matrixObject.acquireRead();
        double[][] matrix = DataConverter.convertToDoubleMatrix(mb);
        matrixObject.release();
        return matrix;
    } catch (CacheException e) {
        throw new MLContextException("Cache exception while converting matrix object to double matrix", e);
    }
}
+
/**
 * Convert a {@code MatrixObject} to a {@code DataFrame}.
 *
 * @param matrixObject
 *            the {@code MatrixObject}
 * @param sparkExecutionContext
 *            the Spark execution context
 * @return the {@code MatrixObject} converted to a {@code DataFrame}
 */
public static DataFrame matrixObjectToDataFrame(MatrixObject matrixObject,
        SparkExecutionContext sparkExecutionContext) {
    try {
        // the RDD handle API returns a raw JavaPairRDD; the cast is safe
        // for binary-block input info
        @SuppressWarnings("unchecked")
        JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlockMatrix = (JavaPairRDD<MatrixIndexes, MatrixBlock>) sparkExecutionContext
                .getRDDHandleForMatrixObject(matrixObject, InputInfo.BinaryBlockInputInfo);
        MatrixCharacteristics matrixCharacteristics = matrixObject.getMatrixCharacteristics();

        MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContext();
        SparkContext sc = activeMLContext.getSparkContext();
        // NOTE(review): a new SQLContext is constructed per call; consider
        // reusing a single instance if this becomes hot
        SQLContext sqlContext = new SQLContext(sc);
        DataFrame df = RDDConverterUtilsExt.binaryBlockToDataFrame(binaryBlockMatrix, matrixCharacteristics,
                sqlContext);
        return df;
    } catch (DMLRuntimeException e) {
        throw new MLContextException("DMLRuntimeException while converting matrix object to DataFrame", e);
    }
}
+
+ /**
+ * Convert a {@code MatrixObject} to a {@code BinaryBlockMatrix}.
+ *
+ * @param matrixObject
+ * the {@code MatrixObject}
+ * @param sparkExecutionContext
+ * the Spark execution context
+ * @return the {@code MatrixObject} converted to a {@code BinaryBlockMatrix}
+ */
+ public static BinaryBlockMatrix matrixObjectToBinaryBlockMatrix(MatrixObject matrixObject,
+ SparkExecutionContext sparkExecutionContext) {
+ try {
+ @SuppressWarnings("unchecked")
+ JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlock = (JavaPairRDD<MatrixIndexes, MatrixBlock>) sparkExecutionContext
+ .getRDDHandleForMatrixObject(matrixObject, InputInfo.BinaryBlockInputInfo);
+ MatrixCharacteristics matrixCharacteristics = matrixObject.getMatrixCharacteristics();
+ BinaryBlockMatrix binaryBlockMatrix = new BinaryBlockMatrix(binaryBlock, matrixCharacteristics);
+ return binaryBlockMatrix;
+ } catch (DMLRuntimeException e) {
+ throw new MLContextException("DMLRuntimeException while converting matrix object to BinaryBlockMatrix", e);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/mlcontext/MLContextException.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextException.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextException.java
new file mode 100644
index 0000000..63e6b64
--- /dev/null
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextException.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.mlcontext;
+
+/**
+ * Uncaught exception representing SystemML exceptions that occur through the
+ * MLContext API
+ *
+ */
+public class MLContextException extends RuntimeException {
+
+ private static final long serialVersionUID = 1L;
+
+ public MLContextException() {
+ super();
+ }
+
+ public MLContextException(String message, Throwable cause) {
+ super(message, cause);
+ }
+
+ public MLContextException(String message) {
+ super(message);
+ }
+
+ public MLContextException(Throwable cause) {
+ super(cause);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
new file mode 100644
index 0000000..feb616e
--- /dev/null
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
@@ -0,0 +1,844 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.mlcontext;
+
+import java.io.FileNotFoundException;
+import java.text.DateFormat;
+import java.text.SimpleDateFormat;
+import java.util.Date;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Scanner;
+import java.util.Set;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.lang3.text.WordUtils;
+import org.apache.spark.SparkContext;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.sql.DataFrame;
+import org.apache.sysml.conf.CompilerConfig;
+import org.apache.sysml.conf.CompilerConfig.ConfigType;
+import org.apache.sysml.conf.ConfigurationManager;
+import org.apache.sysml.conf.DMLConfig;
+import org.apache.sysml.parser.ParseException;
+import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
+import org.apache.sysml.runtime.controlprogram.caching.FrameObject;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.instructions.cp.BooleanObject;
+import org.apache.sysml.runtime.instructions.cp.Data;
+import org.apache.sysml.runtime.instructions.cp.DoubleObject;
+import org.apache.sysml.runtime.instructions.cp.IntObject;
+import org.apache.sysml.runtime.instructions.cp.StringObject;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
+
+/**
+ * Utility class containing methods for working with the MLContext API.
+ *
+ */
+public final class MLContextUtil {
+
	/**
	 * Basic (scalar) data types supported by the MLContext API: Integer,
	 * Boolean, Double, and String.
	 */
	@SuppressWarnings("rawtypes")
	public static final Class[] BASIC_DATA_TYPES = { Integer.class, Boolean.class, Double.class, String.class };

	/**
	 * Complex (matrix-bearing) data types supported by the MLContext API:
	 * JavaRDD, RDD, DataFrame, BinaryBlockMatrix, Matrix, and double[][].
	 */
	@SuppressWarnings("rawtypes")
	public static final Class[] COMPLEX_DATA_TYPES = { JavaRDD.class, RDD.class, DataFrame.class,
			BinaryBlockMatrix.class, Matrix.class, (new double[][] {}).getClass() };

	/**
	 * All data types supported by the MLContext API: the union of
	 * {@link #BASIC_DATA_TYPES} and {@link #COMPLEX_DATA_TYPES}.
	 */
	@SuppressWarnings("rawtypes")
	public static final Class[] ALL_SUPPORTED_DATA_TYPES = (Class[]) ArrayUtils.addAll(BASIC_DATA_TYPES,
			COMPLEX_DATA_TYPES);
+
+ /**
+ * Compare two version strings (ie, "1.4.0" and "1.4.1").
+ *
+ * @param versionStr1
+ * First version string.
+ * @param versionStr2
+ * Second version string.
+ * @return If versionStr1 is less than versionStr2, return {@code -1}. If
+ * versionStr1 equals versionStr2, return {@code 0}. If versionStr1
+ * is greater than versionStr2, return {@code 1}.
+ * @throws MLContextException
+ * if versionStr1 or versionStr2 is {@code null}
+ */
+ private static int compareVersion(String versionStr1, String versionStr2) {
+ if (versionStr1 == null) {
+ throw new MLContextException("First version argument to compareVersion() is null");
+ }
+ if (versionStr2 == null) {
+ throw new MLContextException("Second version argument to compareVersion() is null");
+ }
+
+ Scanner scanner1 = null;
+ Scanner scanner2 = null;
+ try {
+ scanner1 = new Scanner(versionStr1);
+ scanner2 = new Scanner(versionStr2);
+ scanner1.useDelimiter("\\.");
+ scanner2.useDelimiter("\\.");
+
+ while (scanner1.hasNextInt() && scanner2.hasNextInt()) {
+ int version1 = scanner1.nextInt();
+ int version2 = scanner2.nextInt();
+ if (version1 < version2) {
+ return -1;
+ } else if (version1 > version2) {
+ return 1;
+ }
+ }
+
+ return scanner1.hasNextInt() ? 1 : 0;
+ } finally {
+ scanner1.close();
+ scanner2.close();
+ }
+ }
+
+ /**
+ * Determine whether the Spark version is supported.
+ *
+ * @param sparkVersion
+ * Spark version string (ie, "1.5.0").
+ * @return {@code true} if Spark version supported; otherwise {@code false}.
+ */
+ public static boolean isSparkVersionSupported(String sparkVersion) {
+ if (compareVersion(sparkVersion, MLContext.SYSTEMML_MINIMUM_SPARK_VERSION) < 0) {
+ return false;
+ } else {
+ return true;
+ }
+ }
+
+ /**
+ * Check that the Spark version is supported. If it isn't supported, throw
+ * an MLContextException.
+ *
+ * @param sc
+ * SparkContext
+ * @throws MLContextException
+ * thrown if Spark version isn't supported
+ */
+ public static void verifySparkVersionSupported(SparkContext sc) {
+ if (!MLContextUtil.isSparkVersionSupported(sc.version())) {
+ throw new MLContextException("SystemML requires Spark " + MLContext.SYSTEMML_MINIMUM_SPARK_VERSION
+ + " or greater");
+ }
+ }
+
	/**
	 * Set default SystemML configuration properties by installing a fresh
	 * {@code DMLConfig} as the global configuration.
	 */
	public static void setDefaultConfig() {
		ConfigurationManager.setGlobalConfig(new DMLConfig());
	}

	/**
	 * Set SystemML configuration properties based on a configuration file.
	 *
	 * @param configFilePath
	 *            Path to configuration file.
	 * @throws MLContextException
	 *             if configuration file was not found or a parse exception
	 *             occurred
	 */
	public static void setConfig(String configFilePath) {
		try {
			DMLConfig config = new DMLConfig(configFilePath);
			ConfigurationManager.setGlobalConfig(config);
		} catch (ParseException e) {
			throw new MLContextException("Parse Exception when setting config", e);
		} catch (FileNotFoundException e) {
			throw new MLContextException("File not found (" + configFilePath + ") when setting config", e);
		}
	}

	/**
	 * Set SystemML compiler configuration properties for MLContext.
	 */
	public static void setCompilerConfig() {
		CompilerConfig compilerConfig = new CompilerConfig();
		compilerConfig.set(ConfigType.IGNORE_UNSPECIFIED_ARGS, true);
		compilerConfig.set(ConfigType.REJECT_READ_WRITE_UNKNOWNS, false);
		// NOTE(review): CSE of persistent reads is disabled here, presumably
		// because MLContext inputs can be rebound between executions -- confirm
		compilerConfig.set(ConfigType.ALLOW_CSE_PERSISTENT_READS, false);
		ConfigurationManager.setGlobalConfig(compilerConfig);
	}
+
+ /**
+ * Convenience method to generate a {@code Map<String, Object>} of key/value
+ * pairs.
+ * <p>
+ * Example:<br>
+ * {@code Map<String, Object> inputMap = MLContextUtil.generateInputMap("A", 1, "B", "two", "C", 3);}
+ * <br>
+ * <br>
+ * This is equivalent to:<br>
+ * <code>Map<String, Object> inputMap = new LinkedHashMap<String, Object>(){{
+ * <br>put("A", 1);
+ * <br>put("B", "two");
+ * <br>put("C", 3);
+ * <br>}};</code>
+ *
+ * @param objs
+ * List of String/Object pairs
+ * @return Map of String/Object pairs
+ * @throws MLContextException
+ * if the number of arguments is not an even number
+ */
+ public static Map<String, Object> generateInputMap(Object... objs) {
+ int len = objs.length;
+ if ((len & 1) == 1) {
+ throw new MLContextException("The number of arguments needs to be an even number");
+ }
+ Map<String, Object> map = new LinkedHashMap<String, Object>();
+ int i = 0;
+ while (i < len) {
+ map.put((String) objs[i++], objs[i++]);
+ }
+ return map;
+ }
+
+ /**
+ * Verify that the types of input values are supported.
+ *
+ * @param inputs
+ * Map of String/Object pairs
+ * @throws MLContextException
+ * if an input value type is not supported
+ */
+ public static void checkInputValueTypes(Map<String, Object> inputs) {
+ for (Entry<String, Object> entry : inputs.entrySet()) {
+ checkInputValueType(entry.getKey(), entry.getValue());
+ }
+ }
+
+ /**
+ * Verify that the type of input value is supported.
+ *
+ * @param name
+ * The name of the input
+ * @param value
+ * The value of the input
+ * @throws MLContextException
+ * if the input value type is not supported
+ */
+ public static void checkInputValueType(String name, Object value) {
+
+ if (name == null) {
+ throw new MLContextException("No input name supplied");
+ } else if (value == null) {
+ throw new MLContextException("No input value supplied");
+ }
+
+ Object o = value;
+ boolean supported = false;
+ for (Class<?> clazz : ALL_SUPPORTED_DATA_TYPES) {
+ if (o.getClass().equals(clazz)) {
+ supported = true;
+ break;
+ } else if (clazz.isAssignableFrom(o.getClass())) {
+ supported = true;
+ break;
+ }
+ }
+ if (!supported) {
+ throw new MLContextException("Input name (\"" + value + "\") value type not supported: " + o.getClass());
+ }
+ }
+
+ /**
+ * Verify that the type of input parameter value is supported.
+ *
+ * @param parameterName
+ * The name of the input parameter
+ * @param parameterValue
+ * The value of the input parameter
+ * @throws MLContextException
+ * if the input parameter value type is not supported
+ */
+ public static void checkInputParameterType(String parameterName, Object parameterValue) {
+
+ if (parameterName == null) {
+ throw new MLContextException("No parameter name supplied");
+ } else if (parameterValue == null) {
+ throw new MLContextException("No parameter value supplied");
+ } else if (!parameterName.startsWith("$")) {
+ throw new MLContextException("Input parameter name must start with a $");
+ }
+
+ Object o = parameterValue;
+ boolean supported = false;
+ for (Class<?> clazz : BASIC_DATA_TYPES) {
+ if (o.getClass().equals(clazz)) {
+ supported = true;
+ break;
+ } else if (clazz.isAssignableFrom(o.getClass())) {
+ supported = true;
+ break;
+ }
+ }
+ if (!supported) {
+ throw new MLContextException("Input parameter (\"" + parameterName + "\") value type not supported: "
+ + o.getClass());
+ }
+ }
+
+ /**
+ * Is the object one of the supported basic data types? (Integer, Boolean,
+ * Double, String)
+ *
+ * @param object
+ * the object type to be examined
+ * @return {@code true} if type is a basic data type; otherwise
+ * {@code false}.
+ */
+ public static boolean isBasicType(Object object) {
+ for (Class<?> clazz : BASIC_DATA_TYPES) {
+ if (object.getClass().equals(clazz)) {
+ return true;
+ } else if (clazz.isAssignableFrom(object.getClass())) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ /**
+ * Is the object one of the supported complex data types? (JavaRDD, RDD,
+ * DataFrame, BinaryBlockMatrix, Matrix, double[][])
+ *
+ * @param object
+ * the object type to be examined
+ * @return {@code true} if type is a complexe data type; otherwise
+ * {@code false}.
+ */
+ public static boolean isComplexType(Object object) {
+ for (Class<?> clazz : COMPLEX_DATA_TYPES) {
+ if (object.getClass().equals(clazz)) {
+ return true;
+ } else if (clazz.isAssignableFrom(object.getClass())) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ /**
+ * Converts non-string basic input parameter values to strings to pass to
+ * the parser.
+ *
+ * @param basicInputParameterMap
+ * map of input parameters
+ * @param scriptType
+ * {@code ScriptType.DML} or {@code ScriptType.PYDML}
+ * @return map of String/String name/value pairs
+ */
+ public static Map<String, String> convertInputParametersForParser(Map<String, Object> basicInputParameterMap,
+ ScriptType scriptType) {
+ if (basicInputParameterMap == null) {
+ return null;
+ }
+ if (scriptType == null) {
+ throw new MLContextException("ScriptType needs to be specified");
+ }
+ Map<String, String> convertedMap = new HashMap<String, String>();
+ for (Entry<String, Object> entry : basicInputParameterMap.entrySet()) {
+ String key = entry.getKey();
+ Object value = entry.getValue();
+ if (value == null) {
+ throw new MLContextException("Input parameter value is null for: " + entry.getKey());
+ } else if (value instanceof Integer) {
+ convertedMap.put(key, Integer.toString((Integer) value));
+ } else if (value instanceof Boolean) {
+ if (scriptType == ScriptType.DML) {
+ convertedMap.put(key, String.valueOf((Boolean) value).toUpperCase());
+ } else {
+ convertedMap.put(key, WordUtils.capitalize(String.valueOf((Boolean) value)));
+ }
+ } else if (value instanceof Double) {
+ convertedMap.put(key, Double.toString((Double) value));
+ } else if (value instanceof String) {
+ convertedMap.put(key, (String) value);
+ }
+ }
+ return convertedMap;
+ }
+
	/**
	 * Convert an input value to its internal SystemML representation, with no
	 * matrix metadata supplied. Convenience overload of
	 * {@code convertInputType(String, Object, MatrixMetadata)}.
	 *
	 * @param parameterName
	 *            The name of the input parameter
	 * @param parameterValue
	 *            The value of the input parameter
	 * @return input in SystemML data representation
	 */
	public static Data convertInputType(String parameterName, Object parameterValue) {
		return convertInputType(parameterName, parameterValue, null);
	}
+
	/**
	 * Convert an input value to its internal SystemML representation.
	 * Matrix-bearing inputs (JavaRDD, RDD, DataFrame, BinaryBlockMatrix,
	 * Matrix, double[][]) become {@code MatrixObject}s; scalar inputs become
	 * the corresponding {@code ScalarObject} subclass.
	 *
	 * @param parameterName
	 *            The name of the input parameter
	 * @param parameterValue
	 *            The value of the input parameter
	 * @param matrixMetadata
	 *            matrix metadata (may be null); for RDD inputs, a format of
	 *            {@code MatrixFormat.IJV} selects IJV parsing, otherwise CSV
	 * @return input in SystemML data representation, or {@code null} if the
	 *         value matches none of the supported types (callers must handle
	 *         the null case)
	 */
	public static Data convertInputType(String parameterName, Object parameterValue, MatrixMetadata matrixMetadata) {
		String name = parameterName;
		Object value = parameterValue;
		if (name == null) {
			throw new MLContextException("Input parameter name is null");
		} else if (value == null) {
			throw new MLContextException("Input parameter value is null for: " + parameterName);
		} else if (value instanceof JavaRDD<?>) {
			// RDDs of strings are parsed as IJV only when metadata says so;
			// the default is CSV
			@SuppressWarnings("unchecked")
			JavaRDD<String> javaRDD = (JavaRDD<String>) value;
			MatrixObject matrixObject;
			if ((matrixMetadata != null) && (matrixMetadata.getMatrixFormat() == MatrixFormat.IJV)) {
				matrixObject = MLContextConversionUtil.javaRDDStringIJVToMatrixObject(name, javaRDD, matrixMetadata);
			} else {
				matrixObject = MLContextConversionUtil.javaRDDStringCSVToMatrixObject(name, javaRDD, matrixMetadata);
			}
			return matrixObject;
		} else if (value instanceof RDD<?>) {
			@SuppressWarnings("unchecked")
			RDD<String> rdd = (RDD<String>) value;
			MatrixObject matrixObject;
			if ((matrixMetadata != null) && (matrixMetadata.getMatrixFormat() == MatrixFormat.IJV)) {
				matrixObject = MLContextConversionUtil.rddStringIJVToMatrixObject(name, rdd, matrixMetadata);
			} else {
				matrixObject = MLContextConversionUtil.rddStringCSVToMatrixObject(name, rdd, matrixMetadata);
			}

			return matrixObject;
		} else if (value instanceof DataFrame) {
			DataFrame dataFrame = (DataFrame) value;
			MatrixObject matrixObject = MLContextConversionUtil
					.dataFrameToMatrixObject(name, dataFrame, matrixMetadata);
			return matrixObject;
		} else if (value instanceof BinaryBlockMatrix) {
			// explicit metadata takes precedence; otherwise use the metadata
			// carried by the BinaryBlockMatrix itself
			BinaryBlockMatrix binaryBlockMatrix = (BinaryBlockMatrix) value;
			if (matrixMetadata == null) {
				matrixMetadata = binaryBlockMatrix.getMatrixMetadata();
			}
			JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlocks = binaryBlockMatrix.getBinaryBlocks();
			MatrixObject matrixObject = MLContextConversionUtil.binaryBlocksToMatrixObject(name, binaryBlocks,
					matrixMetadata);
			return matrixObject;
		} else if (value instanceof Matrix) {
			Matrix matrix = (Matrix) value;
			MatrixObject matrixObject = matrix.asMatrixObject();
			return matrixObject;
		} else if (value instanceof double[][]) {
			double[][] doubleMatrix = (double[][]) value;
			MatrixObject matrixObject = MLContextConversionUtil.doubleMatrixToMatrixObject(name, doubleMatrix,
					matrixMetadata);
			return matrixObject;
		} else if (value instanceof Integer) {
			Integer i = (Integer) value;
			IntObject iObj = new IntObject(i);
			return iObj;
		} else if (value instanceof Double) {
			Double d = (Double) value;
			DoubleObject dObj = new DoubleObject(d);
			return dObj;
		} else if (value instanceof String) {
			String s = (String) value;
			StringObject sObj = new StringObject(s);
			return sObj;
		} else if (value instanceof Boolean) {
			Boolean b = (Boolean) value;
			BooleanObject bObj = new BooleanObject(b);
			return bObj;
		}
		// unsupported type: no conversion performed
		return null;
	}
+
+ /**
+ * Return the default matrix block size.
+ *
+ * @return the default matrix block size
+ */
+ public static int defaultBlockSize() {
+ DMLConfig conf = ConfigurationManager.getDMLConfig();
+ int blockSize = conf.getIntValue(DMLConfig.DEFAULT_BLOCK_SIZE);
+ return blockSize;
+ }
+
+ /**
+ * Return the location of the scratch space directory.
+ *
+ * @return the lcoation of the scratch space directory
+ */
+ public static String scratchSpace() {
+ DMLConfig conf = ConfigurationManager.getDMLConfig();
+ String scratchSpace = conf.getTextValue(DMLConfig.SCRATCH_SPACE);
+ return scratchSpace;
+ }
+
+ /**
+ * Return a double-quoted string with inner single and double quotes
+ * escaped.
+ *
+ * @param str
+ * the original string
+ * @return double-quoted string with inner single and double quotes escaped
+ */
+ public static String quotedString(String str) {
+ if (str == null) {
+ return null;
+ }
+
+ StringBuilder sb = new StringBuilder();
+ sb.append("\"");
+ for (int i = 0; i < str.length(); i++) {
+ char ch = str.charAt(i);
+ if ((ch == '\'') || (ch == '"')) {
+ if ((i > 0) && (str.charAt(i - 1) != '\\')) {
+ sb.append('\\');
+ } else if (i == 0) {
+ sb.append('\\');
+ }
+ }
+ sb.append(ch);
+ }
+ sb.append("\"");
+
+ return sb.toString();
+ }
+
+ /**
+ * Display the keys and values in a Map
+ *
+ * @param mapName
+ * the name of the map
+ * @param map
+ * Map of String keys and Object values
+ * @return the keys and values in the Map as a String
+ */
+ public static String displayMap(String mapName, Map<String, Object> map) {
+ StringBuilder sb = new StringBuilder();
+ sb.append(mapName);
+ sb.append(":\n");
+ Set<String> keys = map.keySet();
+ if (keys.isEmpty()) {
+ sb.append("None\n");
+ } else {
+ int count = 0;
+ for (String key : keys) {
+ sb.append(" [");
+ sb.append(++count);
+ sb.append("] ");
+ sb.append(key);
+ sb.append(": ");
+ sb.append(map.get(key));
+ sb.append("\n");
+ }
+ }
+ return sb.toString();
+ }
+
+ /**
+ * Display the values in a Set
+ *
+ * @param setName
+ * the name of the Set
+ * @param set
+ * Set of String values
+ * @return the values in the Set as a String
+ */
+ public static String displaySet(String setName, Set<String> set) {
+ StringBuilder sb = new StringBuilder();
+ sb.append(setName);
+ sb.append(":\n");
+ if (set.isEmpty()) {
+ sb.append("None\n");
+ } else {
+ int count = 0;
+ for (String value : set) {
+ sb.append(" [");
+ sb.append(++count);
+ sb.append("] ");
+ sb.append(value);
+ sb.append("\n");
+ }
+ }
+ return sb.toString();
+ }
+
+ /**
+ * Display the keys and values in the symbol table
+ *
+ * @param name
+ * the name of the symbol table
+ * @param symbolTable
+ * the LocalVariableMap
+ * @return the keys and values in the symbol table as a String
+ */
+ public static String displaySymbolTable(String name, LocalVariableMap symbolTable) {
+ StringBuilder sb = new StringBuilder();
+ sb.append(name);
+ sb.append(":\n");
+ sb.append(displaySymbolTable(symbolTable));
+ return sb.toString();
+ }
+
+ /**
+ * Display the keys and values in the symbol table
+ *
+ * @param symbolTable
+ * the LocalVariableMap
+ * @return the keys and values in the symbol table as a String
+ */
+ public static String displaySymbolTable(LocalVariableMap symbolTable) {
+ StringBuilder sb = new StringBuilder();
+ Set<String> keys = symbolTable.keySet();
+ if (keys.isEmpty()) {
+ sb.append("None\n");
+ } else {
+ int count = 0;
+ for (String key : keys) {
+ sb.append(" [");
+ sb.append(++count);
+ sb.append("]");
+
+ sb.append(" (");
+ sb.append(determineOutputTypeAsString(symbolTable, key));
+ sb.append(") ");
+
+ sb.append(key);
+
+ sb.append(": ");
+ sb.append(symbolTable.get(key));
+ sb.append("\n");
+ }
+ }
+ return sb.toString();
+ }
+
	/**
	 * Obtain a symbol table output type as a String ("Boolean", "Double",
	 * "Long", "String", "Matrix", "Frame", or "Unknown").
	 *
	 * @param symbolTable
	 *            the symbol table
	 * @param outputName
	 *            the name of the output variable
	 * @return the symbol table output type for a variable as a String
	 */
	public static String determineOutputTypeAsString(LocalVariableMap symbolTable, String outputName) {
		Data data = symbolTable.get(outputName);
		if (data instanceof BooleanObject) {
			return "Boolean";
		} else if (data instanceof DoubleObject) {
			return "Double";
		} else if (data instanceof IntObject) {
			// IntObject reported as "Long" -- SystemML integers are presumably
			// 64-bit; confirm against IntObject's internal representation
			return "Long";
		} else if (data instanceof StringObject) {
			return "String";
		} else if (data instanceof MatrixObject) {
			return "Matrix";
		} else if (data instanceof FrameObject) {
			return "Frame";
		}
		// also covers a missing (null) entry
		return "Unknown";
	}
+
+ /**
+ * Obtain a display of script inputs.
+ *
+ * @param name
+ * the title to display for the inputs
+ * @param map
+ * the map of inputs
+ * @return the script inputs represented as a String
+ */
+ public static String displayInputs(String name, Map<String, Object> map) {
+ StringBuilder sb = new StringBuilder();
+ sb.append(name);
+ sb.append(":\n");
+ Set<String> keys = map.keySet();
+ if (keys.isEmpty()) {
+ sb.append("None\n");
+ } else {
+ int count = 0;
+ for (String key : keys) {
+ Object object = map.get(key);
+ @SuppressWarnings("rawtypes")
+ Class clazz = object.getClass();
+ String type = clazz.getSimpleName();
+ if (object instanceof JavaRDD<?>) {
+ type = "JavaRDD";
+ } else if (object instanceof RDD<?>) {
+ type = "RDD";
+ }
+
+ sb.append(" [");
+ sb.append(++count);
+ sb.append("]");
+
+ sb.append(" (");
+ sb.append(type);
+ sb.append(") ");
+
+ sb.append(key);
+ sb.append(": ");
+ String str = object.toString();
+ str = StringUtils.abbreviate(str, 100);
+ sb.append(str);
+ sb.append("\n");
+ }
+ }
+ return sb.toString();
+ }
+
+ /**
+ * Obtain a display of the script outputs.
+ *
+ * @param name
+ * the title to display for the outputs
+ * @param outputNames
+ * the names of the output variables
+ * @param symbolTable
+ * the symbol table
+ * @return the script outputs represented as a String
+ *
+ */
+ public static String displayOutputs(String name, Set<String> outputNames, LocalVariableMap symbolTable) {
+ StringBuilder sb = new StringBuilder();
+ sb.append(name);
+ sb.append(":\n");
+ sb.append(displayOutputs(outputNames, symbolTable));
+ return sb.toString();
+ }
+
+ /**
+ * Obtain a display of the script outputs.
+ *
+ * @param outputNames
+ * the names of the output variables
+ * @param symbolTable
+ * the symbol table
+ * @return the script outputs represented as a String
+ *
+ */
+ public static String displayOutputs(Set<String> outputNames, LocalVariableMap symbolTable) {
+ StringBuilder sb = new StringBuilder();
+ if (outputNames.isEmpty()) {
+ sb.append("None\n");
+ } else {
+ int count = 0;
+ for (String outputName : outputNames) {
+ sb.append(" [");
+ sb.append(++count);
+ sb.append("] ");
+
+ if (symbolTable.get(outputName) != null) {
+ sb.append("(");
+ sb.append(determineOutputTypeAsString(symbolTable, outputName));
+ sb.append(") ");
+ }
+
+ sb.append(outputName);
+
+ if (symbolTable.get(outputName) != null) {
+ sb.append(": ");
+ sb.append(symbolTable.get(outputName));
+ }
+
+ sb.append("\n");
+ }
+ }
+ return sb.toString();
+ }
+
+ /**
+ * The SystemML welcome message
+ *
+ * @return the SystemML welcome message
+ */
+ public static String welcomeMessage() {
+ StringBuilder sb = new StringBuilder();
+ sb.append("\nWelcome to Apache SystemML!\n");
+ return sb.toString();
+ }
+
+ /**
+ * Generate a String history entry for a script.
+ *
+ * @param script
+ * the script
+ * @param when
+ * when the script was executed
+ * @return a script history entry as a String
+ */
+ public static String createHistoryForScript(Script script, long when) {
+ DateFormat dateFormat = new SimpleDateFormat("MM/dd/yyyy HH:mm:ss.SSS");
+ StringBuilder sb = new StringBuilder();
+ sb.append("Script Name: " + script.getName() + "\n");
+ sb.append("When: " + dateFormat.format(new Date(when)) + "\n");
+ sb.append(script.displayInputs());
+ sb.append(script.displayOutputs());
+ sb.append(script.displaySymbolTable());
+ return sb.toString();
+ }
+
+ /**
+ * Generate a String listing of the script execution history.
+ *
+ * @param scriptHistory
+ * the list of script history entries
+ * @return the listing of the script execution history as a String
+ */
+ public static String displayScriptHistory(List<String> scriptHistory) {
+ StringBuilder sb = new StringBuilder();
+ sb.append("MLContext Script History:\n");
+ if (scriptHistory.isEmpty()) {
+ sb.append("None");
+ }
+ int i = 1;
+ for (String history : scriptHistory) {
+ sb.append("--------------------------------------------\n");
+ sb.append("#" + (i++) + ":\n");
+ sb.append(history);
+ }
+ return sb.toString();
+ }
+
+}
[3/4] incubator-systemml git commit: [SYSTEMML-593] MLContext redesign
Posted by de...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java b/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java
new file mode 100644
index 0000000..bd1b6bc
--- /dev/null
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java
@@ -0,0 +1,1299 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.mlcontext;
+
+import java.util.Set;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.sql.DataFrame;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
+import org.apache.sysml.runtime.controlprogram.caching.CacheException;
+import org.apache.sysml.runtime.controlprogram.caching.FrameObject;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysml.runtime.instructions.cp.BooleanObject;
+import org.apache.sysml.runtime.instructions.cp.Data;
+import org.apache.sysml.runtime.instructions.cp.DoubleObject;
+import org.apache.sysml.runtime.instructions.cp.IntObject;
+import org.apache.sysml.runtime.instructions.cp.ScalarObject;
+import org.apache.sysml.runtime.instructions.cp.StringObject;
+import org.apache.sysml.runtime.matrix.data.FrameBlock;
+import org.apache.sysml.runtime.util.DataConverter;
+
+import scala.Tuple1;
+import scala.Tuple10;
+import scala.Tuple11;
+import scala.Tuple12;
+import scala.Tuple13;
+import scala.Tuple14;
+import scala.Tuple15;
+import scala.Tuple16;
+import scala.Tuple17;
+import scala.Tuple18;
+import scala.Tuple19;
+import scala.Tuple2;
+import scala.Tuple20;
+import scala.Tuple21;
+import scala.Tuple22;
+import scala.Tuple3;
+import scala.Tuple4;
+import scala.Tuple5;
+import scala.Tuple6;
+import scala.Tuple7;
+import scala.Tuple8;
+import scala.Tuple9;
+
+/**
+ * MLResults handles the results returned from executing a Script using the
+ * MLContext API.
+ *
+ */
+public class MLResults {
	// symbol table holding the output variables produced by script execution
	protected LocalVariableMap symbolTable = null;
	// the executed script these results belong to (may be null)
	protected Script script = null;
	// set when execution happened in a SparkExecutionContext; required by the
	// RDD/DataFrame output conversion methods
	protected SparkExecutionContext sparkExecutionContext = null;

	public MLResults() {
	}

	public MLResults(LocalVariableMap symbolTable) {
		this.symbolTable = symbolTable;
	}

	/**
	 * Create results for an executed script, capturing its symbol table and,
	 * when the script ran in Spark, its SparkExecutionContext.
	 *
	 * @param script
	 *            the executed script
	 */
	public MLResults(Script script) {
		this.script = script;
		this.symbolTable = script.getSymbolTable();
		ScriptExecutor scriptExecutor = script.getScriptExecutor();
		ExecutionContext executionContext = scriptExecutor.getExecutionContext();
		if (executionContext instanceof SparkExecutionContext) {
			sparkExecutionContext = (SparkExecutionContext) executionContext;
		}
	}
+
+ /**
+ * Obtain an output as a {@code Data} object.
+ *
+ * @param outputName
+ * the name of the output
+ * @return the output as a {@code Data} object
+ */
+ public Data getData(String outputName) {
+ Set<String> keys = symbolTable.keySet();
+ if (!keys.contains(outputName)) {
+ throw new MLContextException("Variable '" + outputName + "' not found");
+ }
+ Data data = symbolTable.get(outputName);
+ return data;
+ }
+
+ /**
+ * Obtain an output as a {@code MatrixObject}
+ *
+ * @param outputName
+ * the name of the output
+ * @return the output as a {@code MatrixObject}
+ */
+ public MatrixObject getMatrixObject(String outputName) {
+ Data data = getData(outputName);
+ if (!(data instanceof MatrixObject)) {
+ throw new MLContextException("Variable '" + outputName + "' not a matrix");
+ }
+ MatrixObject mo = (MatrixObject) data;
+ return mo;
+ }
+
+ /**
+ * Obtain an output as a two-dimensional {@code double} array.
+ *
+ * @param outputName
+ * the name of the output
+ * @return the output as a two-dimensional {@code double} array
+ */
+ public double[][] getDoubleMatrix(String outputName) {
+ MatrixObject mo = getMatrixObject(outputName);
+ double[][] doubleMatrix = MLContextConversionUtil.matrixObjectToDoubleMatrix(mo);
+ return doubleMatrix;
+ }
+
+ /**
+ * Obtain an output as a {@code JavaRDD<String>} in IJV format.
+ * <p>
+ * The following matrix in DML:
+ * </p>
+ * <code>M = full('1 2 3 4', rows=2, cols=2);
+ * </code>
+ * <p>
+ * is equivalent to the following {@code JavaRDD<String>} in IJV format:
+ * </p>
+ * <code>1 1 1.0
+ * <br>1 2 2.0
+ * <br>2 1 3.0
+ * <br>2 2 4.0
+ * </code>
+ *
+ * @param outputName
+ * the name of the output
+ * @return the output as a {@code JavaRDD<String>} in IJV format
+ */
+ public JavaRDD<String> getJavaRDDStringIJV(String outputName) {
+ MatrixObject mo = getMatrixObject(outputName);
+ JavaRDD<String> javaRDDStringIJV = MLContextConversionUtil.matrixObjectToJavaRDDStringIJV(mo);
+ return javaRDDStringIJV;
+ }
+
+ /**
+ * Obtain an output as a {@code JavaRDD<String>} in CSV format.
+ * <p>
+ * The following matrix in DML:
+ * </p>
+ * <code>M = full('1 2 3 4', rows=2, cols=2);
+ * </code>
+ * <p>
+ * is equivalent to the following {@code JavaRDD<String>} in CSV format:
+ * </p>
+ * <code>1.0,2.0
+ * <br>3.0,4.0
+ * </code>
+ *
+ * @param outputName
+ * the name of the output
+ * @return the output as a {@code JavaRDD<String>} in CSV format
+ */
+ public JavaRDD<String> getJavaRDDStringCSV(String outputName) {
+ MatrixObject mo = getMatrixObject(outputName);
+ JavaRDD<String> javaRDDStringCSV = MLContextConversionUtil.matrixObjectToJavaRDDStringCSV(mo);
+ return javaRDDStringCSV;
+ }
+
+ /**
+ * Obtain an output as a {@code RDD<String>} in CSV format.
+ * <p>
+ * The following matrix in DML:
+ * </p>
+ * <code>M = full('1 2 3 4', rows=2, cols=2);
+ * </code>
+ * <p>
+ * is equivalent to the following {@code RDD<String>} in CSV format:
+ * </p>
+ * <code>1.0,2.0
+ * <br>3.0,4.0
+ * </code>
+ *
+ * @param outputName
+ * the name of the output
+ * @return the output as a {@code RDD<String>} in CSV format
+ */
+ public RDD<String> getRDDStringCSV(String outputName) {
+ MatrixObject mo = getMatrixObject(outputName);
+ RDD<String> rddStringCSV = MLContextConversionUtil.matrixObjectToRDDStringCSV(mo);
+ return rddStringCSV;
+ }
+
+	/**
+	 * Obtain an output as a Scala {@code RDD<String>} in IJV format.
+	 * <p>
+	 * For example, the DML matrix {@code M = full('1 2 3 4', rows=2, cols=2);}
+	 * corresponds to the RDD lines {@code 1 1 1.0}, {@code 1 2 2.0},
+	 * {@code 2 1 3.0} and {@code 2 2 4.0}.
+	 * </p>
+	 *
+	 * @param outputName
+	 *            the name of the output
+	 * @return the output as a {@code RDD<String>} in IJV format
+	 */
+	public RDD<String> getRDDStringIJV(String outputName) {
+		// Scala RDD variant of getJavaRDDStringIJV ("i j value" triples)
+		MatrixObject matrix = getMatrixObject(outputName);
+		return MLContextConversionUtil.matrixObjectToRDDStringIJV(matrix);
+	}
+
+	/**
+	 * Obtain an output as a {@code DataFrame} of doubles.
+	 * <p>
+	 * For example, the DML matrix {@code M = full('1 2 3 4', rows=2, cols=2);}
+	 * corresponds to the rows {@code [0.0,1.0,2.0]} and {@code [1.0,3.0,4.0]}
+	 * (per this example, the first column appears to carry a row index —
+	 * see {@link MLContextConversionUtil#matrixObjectToDataFrame}).
+	 * </p>
+	 *
+	 * @param outputName
+	 *            the name of the output
+	 * @return the output as a {@code DataFrame} of doubles
+	 */
+	public DataFrame getDataFrame(String outputName) {
+		MatrixObject matrix = getMatrixObject(outputName);
+		return MLContextConversionUtil.matrixObjectToDataFrame(matrix, sparkExecutionContext);
+	}
+
+	/**
+	 * Obtain an output as a {@code Matrix} wrapper.
+	 *
+	 * @param outputName
+	 *            the name of the output
+	 * @return the output as a {@code Matrix}
+	 */
+	public Matrix getMatrix(String outputName) {
+		// wrap the matrix object together with the Spark execution context
+		return new Matrix(getMatrixObject(outputName), sparkExecutionContext);
+	}
+
+	/**
+	 * Obtain an output as a {@code BinaryBlockMatrix}.
+	 *
+	 * @param outputName
+	 *            the name of the output
+	 * @return the output as a {@code BinaryBlockMatrix}
+	 */
+	public BinaryBlockMatrix getBinaryBlockMatrix(String outputName) {
+		MatrixObject matrix = getMatrixObject(outputName);
+		return MLContextConversionUtil.matrixObjectToBinaryBlockMatrix(matrix, sparkExecutionContext);
+	}
+
+	/**
+	 * Obtain an output as a two-dimensional {@code String} array.
+	 *
+	 * @param outputName
+	 *            the name of the output
+	 * @return the output as a two-dimensional {@code String} array
+	 * @throws MLContextException
+	 *             if the variable is not a frame or the frame cannot be read
+	 */
+	public String[][] getFrame(String outputName) {
+		Data data = getData(outputName);
+		if (!(data instanceof FrameObject)) {
+			throw new MLContextException("Variable '" + outputName + "' not a frame");
+		}
+		FrameObject fo = (FrameObject) data;
+		try {
+			FrameBlock fb = fo.acquireRead();
+			try {
+				return DataConverter.convertToStringFrame(fb);
+			} finally {
+				// release in finally: previously the read pin leaked when
+				// convertToStringFrame threw a DMLRuntimeException
+				fo.release();
+			}
+		} catch (CacheException e) {
+			throw new MLContextException("Cache exception when reading frame", e);
+		} catch (DMLRuntimeException e) {
+			throw new MLContextException("DML runtime exception when reading frame", e);
+		}
+	}
+
+	/**
+	 * Obtain a {@code double} output.
+	 *
+	 * @param outputName
+	 *            the name of the output
+	 * @return the output as a {@code double}
+	 */
+	public double getDouble(String outputName) {
+		return getScalarObject(outputName).getDoubleValue();
+	}
+
+	/**
+	 * Obtain an output as a {@code ScalarObject}.
+	 *
+	 * @param outputName
+	 *            the name of the output
+	 * @return the output as a {@code Scalar} object
+	 * @throws MLContextException
+	 *             if the named variable is not a scalar
+	 */
+	public ScalarObject getScalarObject(String outputName) {
+		Data data = getData(outputName);
+		if (data instanceof ScalarObject) {
+			return (ScalarObject) data;
+		}
+		throw new MLContextException("Variable '" + outputName + "' not a scalar");
+	}
+
+	/**
+	 * Obtain a {@code boolean} output.
+	 *
+	 * @param outputName
+	 *            the name of the output
+	 * @return the output as a {@code boolean}
+	 */
+	public boolean getBoolean(String outputName) {
+		return getScalarObject(outputName).getBooleanValue();
+	}
+
+	/**
+	 * Obtain a {@code long} output.
+	 *
+	 * @param outputName
+	 *            the name of the output
+	 * @return the output as a {@code long}
+	 */
+	public long getLong(String outputName) {
+		return getScalarObject(outputName).getLongValue();
+	}
+
+	/**
+	 * Obtain a {@code String} output.
+	 *
+	 * @param outputName
+	 *            the name of the output
+	 * @return the output as a {@code String}
+	 */
+	public String getString(String outputName) {
+		return getScalarObject(outputName).getStringValue();
+	}
+
+	/**
+	 * Obtain the {@code Script} object that produced these results.
+	 *
+	 * @return the DML or PYDML Script object
+	 */
+	public Script getScript() {
+		return this.script;
+	}
+
+ /**
+ * Obtain a Scala tuple.
+ *
+ * @param outputName1
+ * the name of the first output
+ * @return a Scala tuple
+ */
+ @SuppressWarnings("unchecked")
+ public <T> Tuple1<T> getTuple(String outputName1) {
+ return new Tuple1<T>((T) outputValue(outputName1));
+ }
+
+ /**
+ * Obtain a Scala tuple.
+ *
+ * @param outputName1
+ * the name of the first output
+ * @param outputName2
+ * the name of the second output
+ * @return a Scala tuple
+ */
+ @SuppressWarnings("unchecked")
+ public <T1, T2> Tuple2<T1, T2> getTuple(String outputName1, String outputName2) {
+ return new Tuple2<T1, T2>((T1) outputValue(outputName1), (T2) outputValue(outputName2));
+ }
+
+ /**
+ * Obtain a Scala tuple.
+ *
+ * @param outputName1
+ * the name of the first output
+ * @param outputName2
+ * the name of the second output
+ * @param outputName3
+ * the name of the third output
+ * @return a Scala tuple
+ */
+ @SuppressWarnings("unchecked")
+ public <T1, T2, T3> Tuple3<T1, T2, T3> getTuple(String outputName1, String outputName2, String outputName3) {
+ return new Tuple3<T1, T2, T3>((T1) outputValue(outputName1), (T2) outputValue(outputName2),
+ (T3) outputValue(outputName3));
+ }
+
+	/**
+	 * Obtain four outputs as a Scala {@code Tuple4}.
+	 *
+	 * @param outputName1 the name of the first output
+	 * @param outputName2 the name of the second output
+	 * @param outputName3 the name of the third output
+	 * @param outputName4 the name of the fourth output
+	 * @return a Scala tuple
+	 */
+	@SuppressWarnings("unchecked")
+	public <T1, T2, T3, T4> Tuple4<T1, T2, T3, T4> getTuple(String outputName1, String outputName2, String outputName3,
+			String outputName4) {
+		return new Tuple4<T1, T2, T3, T4>(
+				(T1) outputValue(outputName1), (T2) outputValue(outputName2),
+				(T3) outputValue(outputName3), (T4) outputValue(outputName4));
+	}
+
+	/**
+	 * Obtain five outputs as a Scala {@code Tuple5}.
+	 *
+	 * @param outputName1 the name of the first output
+	 * @param outputName2 the name of the second output
+	 * @param outputName3 the name of the third output
+	 * @param outputName4 the name of the fourth output
+	 * @param outputName5 the name of the fifth output
+	 * @return a Scala tuple
+	 */
+	@SuppressWarnings("unchecked")
+	public <T1, T2, T3, T4, T5> Tuple5<T1, T2, T3, T4, T5> getTuple(String outputName1, String outputName2,
+			String outputName3, String outputName4, String outputName5) {
+		return new Tuple5<T1, T2, T3, T4, T5>(
+				(T1) outputValue(outputName1), (T2) outputValue(outputName2),
+				(T3) outputValue(outputName3), (T4) outputValue(outputName4),
+				(T5) outputValue(outputName5));
+	}
+
+	/**
+	 * Obtain six outputs as a Scala {@code Tuple6}.
+	 *
+	 * @param outputName1 the name of the first output
+	 * @param outputName2 the name of the second output
+	 * @param outputName3 the name of the third output
+	 * @param outputName4 the name of the fourth output
+	 * @param outputName5 the name of the fifth output
+	 * @param outputName6 the name of the sixth output
+	 * @return a Scala tuple
+	 */
+	@SuppressWarnings("unchecked")
+	public <T1, T2, T3, T4, T5, T6> Tuple6<T1, T2, T3, T4, T5, T6> getTuple(String outputName1, String outputName2,
+			String outputName3, String outputName4, String outputName5, String outputName6) {
+		return new Tuple6<T1, T2, T3, T4, T5, T6>(
+				(T1) outputValue(outputName1), (T2) outputValue(outputName2),
+				(T3) outputValue(outputName3), (T4) outputValue(outputName4),
+				(T5) outputValue(outputName5), (T6) outputValue(outputName6));
+	}
+
+	/**
+	 * Obtain seven outputs as a Scala {@code Tuple7}.
+	 *
+	 * @param outputName1 the name of the first output
+	 * @param outputName2 the name of the second output
+	 * @param outputName3 the name of the third output
+	 * @param outputName4 the name of the fourth output
+	 * @param outputName5 the name of the fifth output
+	 * @param outputName6 the name of the sixth output
+	 * @param outputName7 the name of the seventh output
+	 * @return a Scala tuple
+	 */
+	@SuppressWarnings("unchecked")
+	public <T1, T2, T3, T4, T5, T6, T7> Tuple7<T1, T2, T3, T4, T5, T6, T7> getTuple(String outputName1,
+			String outputName2, String outputName3, String outputName4, String outputName5, String outputName6,
+			String outputName7) {
+		return new Tuple7<T1, T2, T3, T4, T5, T6, T7>(
+				(T1) outputValue(outputName1), (T2) outputValue(outputName2),
+				(T3) outputValue(outputName3), (T4) outputValue(outputName4),
+				(T5) outputValue(outputName5), (T6) outputValue(outputName6),
+				(T7) outputValue(outputName7));
+	}
+
+	/**
+	 * Obtain eight outputs as a Scala {@code Tuple8}.
+	 *
+	 * @param outputName1 the name of the first output
+	 * @param outputName2 the name of the second output
+	 * @param outputName3 the name of the third output
+	 * @param outputName4 the name of the fourth output
+	 * @param outputName5 the name of the fifth output
+	 * @param outputName6 the name of the sixth output
+	 * @param outputName7 the name of the seventh output
+	 * @param outputName8 the name of the eighth output
+	 * @return a Scala tuple
+	 */
+	@SuppressWarnings("unchecked")
+	public <T1, T2, T3, T4, T5, T6, T7, T8> Tuple8<T1, T2, T3, T4, T5, T6, T7, T8> getTuple(String outputName1,
+			String outputName2, String outputName3, String outputName4, String outputName5, String outputName6,
+			String outputName7, String outputName8) {
+		return new Tuple8<T1, T2, T3, T4, T5, T6, T7, T8>(
+				(T1) outputValue(outputName1), (T2) outputValue(outputName2),
+				(T3) outputValue(outputName3), (T4) outputValue(outputName4),
+				(T5) outputValue(outputName5), (T6) outputValue(outputName6),
+				(T7) outputValue(outputName7), (T8) outputValue(outputName8));
+	}
+
+	/**
+	 * Obtain nine outputs as a Scala {@code Tuple9}.
+	 *
+	 * @param outputName1 the name of the first output
+	 * @param outputName2 the name of the second output
+	 * @param outputName3 the name of the third output
+	 * @param outputName4 the name of the fourth output
+	 * @param outputName5 the name of the fifth output
+	 * @param outputName6 the name of the sixth output
+	 * @param outputName7 the name of the seventh output
+	 * @param outputName8 the name of the eighth output
+	 * @param outputName9 the name of the ninth output
+	 * @return a Scala tuple
+	 */
+	@SuppressWarnings("unchecked")
+	public <T1, T2, T3, T4, T5, T6, T7, T8, T9> Tuple9<T1, T2, T3, T4, T5, T6, T7, T8, T9> getTuple(String outputName1,
+			String outputName2, String outputName3, String outputName4, String outputName5, String outputName6,
+			String outputName7, String outputName8, String outputName9) {
+		return new Tuple9<T1, T2, T3, T4, T5, T6, T7, T8, T9>(
+				(T1) outputValue(outputName1), (T2) outputValue(outputName2),
+				(T3) outputValue(outputName3), (T4) outputValue(outputName4),
+				(T5) outputValue(outputName5), (T6) outputValue(outputName6),
+				(T7) outputValue(outputName7), (T8) outputValue(outputName8),
+				(T9) outputValue(outputName9));
+	}
+
+	/**
+	 * Obtain ten outputs as a Scala {@code Tuple10}.
+	 *
+	 * @param outputName1 the name of the first output
+	 * @param outputName2 the name of the second output
+	 * @param outputName3 the name of the third output
+	 * @param outputName4 the name of the fourth output
+	 * @param outputName5 the name of the fifth output
+	 * @param outputName6 the name of the sixth output
+	 * @param outputName7 the name of the seventh output
+	 * @param outputName8 the name of the eighth output
+	 * @param outputName9 the name of the ninth output
+	 * @param outputName10 the name of the tenth output
+	 * @return a Scala tuple
+	 */
+	@SuppressWarnings("unchecked")
+	public <T1, T2, T3, T4, T5, T6, T7, T8, T9, T10> Tuple10<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10> getTuple(
+			String outputName1, String outputName2, String outputName3, String outputName4, String outputName5,
+			String outputName6, String outputName7, String outputName8, String outputName9, String outputName10) {
+		return new Tuple10<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10>(
+				(T1) outputValue(outputName1), (T2) outputValue(outputName2),
+				(T3) outputValue(outputName3), (T4) outputValue(outputName4),
+				(T5) outputValue(outputName5), (T6) outputValue(outputName6),
+				(T7) outputValue(outputName7), (T8) outputValue(outputName8),
+				(T9) outputValue(outputName9), (T10) outputValue(outputName10));
+	}
+
+	/**
+	 * Obtain eleven outputs as a Scala {@code Tuple11}.
+	 *
+	 * @param outputName1 the name of the first output
+	 * @param outputName2 the name of the second output
+	 * @param outputName3 the name of the third output
+	 * @param outputName4 the name of the fourth output
+	 * @param outputName5 the name of the fifth output
+	 * @param outputName6 the name of the sixth output
+	 * @param outputName7 the name of the seventh output
+	 * @param outputName8 the name of the eighth output
+	 * @param outputName9 the name of the ninth output
+	 * @param outputName10 the name of the tenth output
+	 * @param outputName11 the name of the eleventh output
+	 * @return a Scala tuple
+	 */
+	@SuppressWarnings("unchecked")
+	public <T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11> Tuple11<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11> getTuple(
+			String outputName1, String outputName2, String outputName3, String outputName4, String outputName5,
+			String outputName6, String outputName7, String outputName8, String outputName9, String outputName10,
+			String outputName11) {
+		return new Tuple11<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11>(
+				(T1) outputValue(outputName1), (T2) outputValue(outputName2),
+				(T3) outputValue(outputName3), (T4) outputValue(outputName4),
+				(T5) outputValue(outputName5), (T6) outputValue(outputName6),
+				(T7) outputValue(outputName7), (T8) outputValue(outputName8),
+				(T9) outputValue(outputName9), (T10) outputValue(outputName10),
+				(T11) outputValue(outputName11));
+	}
+
+	/**
+	 * Obtain twelve outputs as a Scala {@code Tuple12}.
+	 *
+	 * @param outputName1 the name of the first output
+	 * @param outputName2 the name of the second output
+	 * @param outputName3 the name of the third output
+	 * @param outputName4 the name of the fourth output
+	 * @param outputName5 the name of the fifth output
+	 * @param outputName6 the name of the sixth output
+	 * @param outputName7 the name of the seventh output
+	 * @param outputName8 the name of the eighth output
+	 * @param outputName9 the name of the ninth output
+	 * @param outputName10 the name of the tenth output
+	 * @param outputName11 the name of the eleventh output
+	 * @param outputName12 the name of the twelfth output
+	 * @return a Scala tuple
+	 */
+	@SuppressWarnings("unchecked")
+	public <T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12> Tuple12<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12> getTuple(
+			String outputName1, String outputName2, String outputName3, String outputName4, String outputName5,
+			String outputName6, String outputName7, String outputName8, String outputName9, String outputName10,
+			String outputName11, String outputName12) {
+		return new Tuple12<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12>(
+				(T1) outputValue(outputName1), (T2) outputValue(outputName2),
+				(T3) outputValue(outputName3), (T4) outputValue(outputName4),
+				(T5) outputValue(outputName5), (T6) outputValue(outputName6),
+				(T7) outputValue(outputName7), (T8) outputValue(outputName8),
+				(T9) outputValue(outputName9), (T10) outputValue(outputName10),
+				(T11) outputValue(outputName11), (T12) outputValue(outputName12));
+	}
+
+	/**
+	 * Obtain thirteen outputs as a Scala {@code Tuple13}.
+	 *
+	 * @param outputName1 the name of the first output
+	 * @param outputName2 the name of the second output
+	 * @param outputName3 the name of the third output
+	 * @param outputName4 the name of the fourth output
+	 * @param outputName5 the name of the fifth output
+	 * @param outputName6 the name of the sixth output
+	 * @param outputName7 the name of the seventh output
+	 * @param outputName8 the name of the eighth output
+	 * @param outputName9 the name of the ninth output
+	 * @param outputName10 the name of the tenth output
+	 * @param outputName11 the name of the eleventh output
+	 * @param outputName12 the name of the twelfth output
+	 * @param outputName13 the name of the thirteenth output
+	 * @return a Scala tuple
+	 */
+	@SuppressWarnings("unchecked")
+	public <T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13> Tuple13<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13> getTuple(
+			String outputName1, String outputName2, String outputName3, String outputName4, String outputName5,
+			String outputName6, String outputName7, String outputName8, String outputName9, String outputName10,
+			String outputName11, String outputName12, String outputName13) {
+		return new Tuple13<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13>(
+				(T1) outputValue(outputName1), (T2) outputValue(outputName2),
+				(T3) outputValue(outputName3), (T4) outputValue(outputName4),
+				(T5) outputValue(outputName5), (T6) outputValue(outputName6),
+				(T7) outputValue(outputName7), (T8) outputValue(outputName8),
+				(T9) outputValue(outputName9), (T10) outputValue(outputName10),
+				(T11) outputValue(outputName11), (T12) outputValue(outputName12),
+				(T13) outputValue(outputName13));
+	}
+
+	/**
+	 * Obtain fourteen outputs as a Scala {@code Tuple14}.
+	 *
+	 * @param outputName1 the name of the first output
+	 * @param outputName2 the name of the second output
+	 * @param outputName3 the name of the third output
+	 * @param outputName4 the name of the fourth output
+	 * @param outputName5 the name of the fifth output
+	 * @param outputName6 the name of the sixth output
+	 * @param outputName7 the name of the seventh output
+	 * @param outputName8 the name of the eighth output
+	 * @param outputName9 the name of the ninth output
+	 * @param outputName10 the name of the tenth output
+	 * @param outputName11 the name of the eleventh output
+	 * @param outputName12 the name of the twelfth output
+	 * @param outputName13 the name of the thirteenth output
+	 * @param outputName14 the name of the fourteenth output
+	 * @return a Scala tuple
+	 */
+	@SuppressWarnings("unchecked")
+	public <T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14> Tuple14<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14> getTuple(
+			String outputName1, String outputName2, String outputName3, String outputName4, String outputName5,
+			String outputName6, String outputName7, String outputName8, String outputName9, String outputName10,
+			String outputName11, String outputName12, String outputName13, String outputName14) {
+		return new Tuple14<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14>(
+				(T1) outputValue(outputName1), (T2) outputValue(outputName2),
+				(T3) outputValue(outputName3), (T4) outputValue(outputName4),
+				(T5) outputValue(outputName5), (T6) outputValue(outputName6),
+				(T7) outputValue(outputName7), (T8) outputValue(outputName8),
+				(T9) outputValue(outputName9), (T10) outputValue(outputName10),
+				(T11) outputValue(outputName11), (T12) outputValue(outputName12),
+				(T13) outputValue(outputName13), (T14) outputValue(outputName14));
+	}
+
+	/**
+	 * Obtain fifteen outputs as a Scala {@code Tuple15}.
+	 *
+	 * @param outputName1 the name of the first output
+	 * @param outputName2 the name of the second output
+	 * @param outputName3 the name of the third output
+	 * @param outputName4 the name of the fourth output
+	 * @param outputName5 the name of the fifth output
+	 * @param outputName6 the name of the sixth output
+	 * @param outputName7 the name of the seventh output
+	 * @param outputName8 the name of the eighth output
+	 * @param outputName9 the name of the ninth output
+	 * @param outputName10 the name of the tenth output
+	 * @param outputName11 the name of the eleventh output
+	 * @param outputName12 the name of the twelfth output
+	 * @param outputName13 the name of the thirteenth output
+	 * @param outputName14 the name of the fourteenth output
+	 * @param outputName15 the name of the fifteenth output
+	 * @return a Scala tuple
+	 */
+	@SuppressWarnings("unchecked")
+	public <T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15> Tuple15<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15> getTuple(
+			String outputName1, String outputName2, String outputName3, String outputName4, String outputName5,
+			String outputName6, String outputName7, String outputName8, String outputName9, String outputName10,
+			String outputName11, String outputName12, String outputName13, String outputName14, String outputName15) {
+		return new Tuple15<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15>(
+				(T1) outputValue(outputName1), (T2) outputValue(outputName2),
+				(T3) outputValue(outputName3), (T4) outputValue(outputName4),
+				(T5) outputValue(outputName5), (T6) outputValue(outputName6),
+				(T7) outputValue(outputName7), (T8) outputValue(outputName8),
+				(T9) outputValue(outputName9), (T10) outputValue(outputName10),
+				(T11) outputValue(outputName11), (T12) outputValue(outputName12),
+				(T13) outputValue(outputName13), (T14) outputValue(outputName14),
+				(T15) outputValue(outputName15));
+	}
+
+	/**
+	 * Obtain sixteen outputs as a Scala {@code Tuple16}.
+	 *
+	 * @param outputName1 the name of the first output
+	 * @param outputName2 the name of the second output
+	 * @param outputName3 the name of the third output
+	 * @param outputName4 the name of the fourth output
+	 * @param outputName5 the name of the fifth output
+	 * @param outputName6 the name of the sixth output
+	 * @param outputName7 the name of the seventh output
+	 * @param outputName8 the name of the eighth output
+	 * @param outputName9 the name of the ninth output
+	 * @param outputName10 the name of the tenth output
+	 * @param outputName11 the name of the eleventh output
+	 * @param outputName12 the name of the twelfth output
+	 * @param outputName13 the name of the thirteenth output
+	 * @param outputName14 the name of the fourteenth output
+	 * @param outputName15 the name of the fifteenth output
+	 * @param outputName16 the name of the sixteenth output
+	 * @return a Scala tuple
+	 */
+	@SuppressWarnings("unchecked")
+	public <T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16> Tuple16<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16> getTuple(
+			String outputName1, String outputName2, String outputName3, String outputName4, String outputName5,
+			String outputName6, String outputName7, String outputName8, String outputName9, String outputName10,
+			String outputName11, String outputName12, String outputName13, String outputName14, String outputName15,
+			String outputName16) {
+		return new Tuple16<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16>(
+				(T1) outputValue(outputName1), (T2) outputValue(outputName2),
+				(T3) outputValue(outputName3), (T4) outputValue(outputName4),
+				(T5) outputValue(outputName5), (T6) outputValue(outputName6),
+				(T7) outputValue(outputName7), (T8) outputValue(outputName8),
+				(T9) outputValue(outputName9), (T10) outputValue(outputName10),
+				(T11) outputValue(outputName11), (T12) outputValue(outputName12),
+				(T13) outputValue(outputName13), (T14) outputValue(outputName14),
+				(T15) outputValue(outputName15), (T16) outputValue(outputName16));
+	}
+
+	/**
+	 * Obtain seventeen outputs as a Scala {@code Tuple17}.
+	 *
+	 * @param outputName1 the name of the first output
+	 * @param outputName2 the name of the second output
+	 * @param outputName3 the name of the third output
+	 * @param outputName4 the name of the fourth output
+	 * @param outputName5 the name of the fifth output
+	 * @param outputName6 the name of the sixth output
+	 * @param outputName7 the name of the seventh output
+	 * @param outputName8 the name of the eighth output
+	 * @param outputName9 the name of the ninth output
+	 * @param outputName10 the name of the tenth output
+	 * @param outputName11 the name of the eleventh output
+	 * @param outputName12 the name of the twelfth output
+	 * @param outputName13 the name of the thirteenth output
+	 * @param outputName14 the name of the fourteenth output
+	 * @param outputName15 the name of the fifteenth output
+	 * @param outputName16 the name of the sixteenth output
+	 * @param outputName17 the name of the seventeenth output
+	 * @return a Scala tuple
+	 */
+	@SuppressWarnings("unchecked")
+	public <T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17> Tuple17<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17> getTuple(
+			String outputName1, String outputName2, String outputName3, String outputName4, String outputName5,
+			String outputName6, String outputName7, String outputName8, String outputName9, String outputName10,
+			String outputName11, String outputName12, String outputName13, String outputName14, String outputName15,
+			String outputName16, String outputName17) {
+		return new Tuple17<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17>(
+				(T1) outputValue(outputName1), (T2) outputValue(outputName2),
+				(T3) outputValue(outputName3), (T4) outputValue(outputName4),
+				(T5) outputValue(outputName5), (T6) outputValue(outputName6),
+				(T7) outputValue(outputName7), (T8) outputValue(outputName8),
+				(T9) outputValue(outputName9), (T10) outputValue(outputName10),
+				(T11) outputValue(outputName11), (T12) outputValue(outputName12),
+				(T13) outputValue(outputName13), (T14) outputValue(outputName14),
+				(T15) outputValue(outputName15), (T16) outputValue(outputName16),
+				(T17) outputValue(outputName17));
+	}
+
+	/**
+	 * Obtain eighteen outputs as a Scala {@code Tuple18}.
+	 *
+	 * @param outputName1 the name of the first output
+	 * @param outputName2 the name of the second output
+	 * @param outputName3 the name of the third output
+	 * @param outputName4 the name of the fourth output
+	 * @param outputName5 the name of the fifth output
+	 * @param outputName6 the name of the sixth output
+	 * @param outputName7 the name of the seventh output
+	 * @param outputName8 the name of the eighth output
+	 * @param outputName9 the name of the ninth output
+	 * @param outputName10 the name of the tenth output
+	 * @param outputName11 the name of the eleventh output
+	 * @param outputName12 the name of the twelfth output
+	 * @param outputName13 the name of the thirteenth output
+	 * @param outputName14 the name of the fourteenth output
+	 * @param outputName15 the name of the fifteenth output
+	 * @param outputName16 the name of the sixteenth output
+	 * @param outputName17 the name of the seventeenth output
+	 * @param outputName18 the name of the eighteenth output
+	 * @return a Scala tuple
+	 */
+	@SuppressWarnings("unchecked")
+	public <T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18> Tuple18<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18> getTuple(
+			String outputName1, String outputName2, String outputName3, String outputName4, String outputName5,
+			String outputName6, String outputName7, String outputName8, String outputName9, String outputName10,
+			String outputName11, String outputName12, String outputName13, String outputName14, String outputName15,
+			String outputName16, String outputName17, String outputName18) {
+		return new Tuple18<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18>(
+				(T1) outputValue(outputName1), (T2) outputValue(outputName2),
+				(T3) outputValue(outputName3), (T4) outputValue(outputName4),
+				(T5) outputValue(outputName5), (T6) outputValue(outputName6),
+				(T7) outputValue(outputName7), (T8) outputValue(outputName8),
+				(T9) outputValue(outputName9), (T10) outputValue(outputName10),
+				(T11) outputValue(outputName11), (T12) outputValue(outputName12),
+				(T13) outputValue(outputName13), (T14) outputValue(outputName14),
+				(T15) outputValue(outputName15), (T16) outputValue(outputName16),
+				(T17) outputValue(outputName17), (T18) outputValue(outputName18));
+	}
+
+	/**
+	 * Obtain nineteen outputs as a Scala {@code Tuple19}.
+	 *
+	 * @param outputName1 the name of the first output
+	 * @param outputName2 the name of the second output
+	 * @param outputName3 the name of the third output
+	 * @param outputName4 the name of the fourth output
+	 * @param outputName5 the name of the fifth output
+	 * @param outputName6 the name of the sixth output
+	 * @param outputName7 the name of the seventh output
+	 * @param outputName8 the name of the eighth output
+	 * @param outputName9 the name of the ninth output
+	 * @param outputName10 the name of the tenth output
+	 * @param outputName11 the name of the eleventh output
+	 * @param outputName12 the name of the twelfth output
+	 * @param outputName13 the name of the thirteenth output
+	 * @param outputName14 the name of the fourteenth output
+	 * @param outputName15 the name of the fifteenth output
+	 * @param outputName16 the name of the sixteenth output
+	 * @param outputName17 the name of the seventeenth output
+	 * @param outputName18 the name of the eighteenth output
+	 * @param outputName19 the name of the nineteenth output
+	 * @return a Scala tuple
+	 */
+	@SuppressWarnings("unchecked")
+	public <T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19> Tuple19<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19> getTuple(
+			String outputName1, String outputName2, String outputName3, String outputName4, String outputName5,
+			String outputName6, String outputName7, String outputName8, String outputName9, String outputName10,
+			String outputName11, String outputName12, String outputName13, String outputName14, String outputName15,
+			String outputName16, String outputName17, String outputName18, String outputName19) {
+		return new Tuple19<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19>(
+				(T1) outputValue(outputName1), (T2) outputValue(outputName2),
+				(T3) outputValue(outputName3), (T4) outputValue(outputName4),
+				(T5) outputValue(outputName5), (T6) outputValue(outputName6),
+				(T7) outputValue(outputName7), (T8) outputValue(outputName8),
+				(T9) outputValue(outputName9), (T10) outputValue(outputName10),
+				(T11) outputValue(outputName11), (T12) outputValue(outputName12),
+				(T13) outputValue(outputName13), (T14) outputValue(outputName14),
+				(T15) outputValue(outputName15), (T16) outputValue(outputName16),
+				(T17) outputValue(outputName17), (T18) outputValue(outputName18),
+				(T19) outputValue(outputName19));
+	}
+
+	/**
+	 * Obtain twenty outputs as a Scala {@code Tuple20}.
+	 *
+	 * @param outputName1 the name of the first output
+	 * @param outputName2 the name of the second output
+	 * @param outputName3 the name of the third output
+	 * @param outputName4 the name of the fourth output
+	 * @param outputName5 the name of the fifth output
+	 * @param outputName6 the name of the sixth output
+	 * @param outputName7 the name of the seventh output
+	 * @param outputName8 the name of the eighth output
+	 * @param outputName9 the name of the ninth output
+	 * @param outputName10 the name of the tenth output
+	 * @param outputName11 the name of the eleventh output
+	 * @param outputName12 the name of the twelfth output
+	 * @param outputName13 the name of the thirteenth output
+	 * @param outputName14 the name of the fourteenth output
+	 * @param outputName15 the name of the fifteenth output
+	 * @param outputName16 the name of the sixteenth output
+	 * @param outputName17 the name of the seventeenth output
+	 * @param outputName18 the name of the eighteenth output
+	 * @param outputName19 the name of the nineteenth output
+	 * @param outputName20 the name of the twentieth output
+	 * @return a Scala tuple
+	 */
+	@SuppressWarnings("unchecked")
+	public <T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20> Tuple20<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20> getTuple(
+			String outputName1, String outputName2, String outputName3, String outputName4, String outputName5,
+			String outputName6, String outputName7, String outputName8, String outputName9, String outputName10,
+			String outputName11, String outputName12, String outputName13, String outputName14, String outputName15,
+			String outputName16, String outputName17, String outputName18, String outputName19, String outputName20) {
+		return new Tuple20<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20>(
+				(T1) outputValue(outputName1), (T2) outputValue(outputName2),
+				(T3) outputValue(outputName3), (T4) outputValue(outputName4),
+				(T5) outputValue(outputName5), (T6) outputValue(outputName6),
+				(T7) outputValue(outputName7), (T8) outputValue(outputName8),
+				(T9) outputValue(outputName9), (T10) outputValue(outputName10),
+				(T11) outputValue(outputName11), (T12) outputValue(outputName12),
+				(T13) outputValue(outputName13), (T14) outputValue(outputName14),
+				(T15) outputValue(outputName15), (T16) outputValue(outputName16),
+				(T17) outputValue(outputName17), (T18) outputValue(outputName18),
+				(T19) outputValue(outputName19), (T20) outputValue(outputName20));
+	}
+
+	/**
+	 * Obtain twenty-one script outputs as a Scala tuple.
+	 *
+	 * @param outputName1 the name of the first output
+	 * @param outputName2 the name of the second output
+	 * @param outputName3 the name of the third output
+	 * @param outputName4 the name of the fourth output
+	 * @param outputName5 the name of the fifth output
+	 * @param outputName6 the name of the sixth output
+	 * @param outputName7 the name of the seventh output
+	 * @param outputName8 the name of the eighth output
+	 * @param outputName9 the name of the ninth output
+	 * @param outputName10 the name of the tenth output
+	 * @param outputName11 the name of the eleventh output
+	 * @param outputName12 the name of the twelfth output
+	 * @param outputName13 the name of the thirteenth output
+	 * @param outputName14 the name of the fourteenth output
+	 * @param outputName15 the name of the fifteenth output
+	 * @param outputName16 the name of the sixteenth output
+	 * @param outputName17 the name of the seventeenth output
+	 * @param outputName18 the name of the eighteenth output
+	 * @param outputName19 the name of the nineteenth output
+	 * @param outputName20 the name of the twentieth output
+	 * @param outputName21 the name of the twenty-first output
+	 * @return a Scala tuple holding the twenty-one output values
+	 */
+	@SuppressWarnings("unchecked")
+	public <T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21> Tuple21<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21> getTuple(
+			String outputName1, String outputName2, String outputName3, String outputName4, String outputName5,
+			String outputName6, String outputName7, String outputName8, String outputName9, String outputName10,
+			String outputName11, String outputName12, String outputName13, String outputName14, String outputName15,
+			String outputName16, String outputName17, String outputName18, String outputName19, String outputName20,
+			String outputName21) {
+		// One cast per line: each output value is looked up by name and cast to
+		// the caller-specified type parameter so Scala sees concrete types, not Anys.
+		return new Tuple21<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21>(
+				(T1) outputValue(outputName1),
+				(T2) outputValue(outputName2),
+				(T3) outputValue(outputName3),
+				(T4) outputValue(outputName4),
+				(T5) outputValue(outputName5),
+				(T6) outputValue(outputName6),
+				(T7) outputValue(outputName7),
+				(T8) outputValue(outputName8),
+				(T9) outputValue(outputName9),
+				(T10) outputValue(outputName10),
+				(T11) outputValue(outputName11),
+				(T12) outputValue(outputName12),
+				(T13) outputValue(outputName13),
+				(T14) outputValue(outputName14),
+				(T15) outputValue(outputName15),
+				(T16) outputValue(outputName16),
+				(T17) outputValue(outputName17),
+				(T18) outputValue(outputName18),
+				(T19) outputValue(outputName19),
+				(T20) outputValue(outputName20),
+				(T21) outputValue(outputName21));
+	}
+
+	/**
+	 * Obtain twenty-two script outputs as a Scala tuple.
+	 *
+	 * @param outputName1 the name of the first output
+	 * @param outputName2 the name of the second output
+	 * @param outputName3 the name of the third output
+	 * @param outputName4 the name of the fourth output
+	 * @param outputName5 the name of the fifth output
+	 * @param outputName6 the name of the sixth output
+	 * @param outputName7 the name of the seventh output
+	 * @param outputName8 the name of the eighth output
+	 * @param outputName9 the name of the ninth output
+	 * @param outputName10 the name of the tenth output
+	 * @param outputName11 the name of the eleventh output
+	 * @param outputName12 the name of the twelfth output
+	 * @param outputName13 the name of the thirteenth output
+	 * @param outputName14 the name of the fourteenth output
+	 * @param outputName15 the name of the fifteenth output
+	 * @param outputName16 the name of the sixteenth output
+	 * @param outputName17 the name of the seventeenth output
+	 * @param outputName18 the name of the eighteenth output
+	 * @param outputName19 the name of the nineteenth output
+	 * @param outputName20 the name of the twentieth output
+	 * @param outputName21 the name of the twenty-first output
+	 * @param outputName22 the name of the twenty-second output
+	 * @return a Scala tuple holding the twenty-two output values
+	 */
+	@SuppressWarnings("unchecked")
+	public <T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21, T22> Tuple22<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21, T22> getTuple(
+			String outputName1, String outputName2, String outputName3, String outputName4, String outputName5,
+			String outputName6, String outputName7, String outputName8, String outputName9, String outputName10,
+			String outputName11, String outputName12, String outputName13, String outputName14, String outputName15,
+			String outputName16, String outputName17, String outputName18, String outputName19, String outputName20,
+			String outputName21, String outputName22) {
+		// One cast per line: each output value is looked up by name and cast to
+		// the caller-specified type parameter so Scala sees concrete types, not Anys.
+		return new Tuple22<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21, T22>(
+				(T1) outputValue(outputName1),
+				(T2) outputValue(outputName2),
+				(T3) outputValue(outputName3),
+				(T4) outputValue(outputName4),
+				(T5) outputValue(outputName5),
+				(T6) outputValue(outputName6),
+				(T7) outputValue(outputName7),
+				(T8) outputValue(outputName8),
+				(T9) outputValue(outputName9),
+				(T10) outputValue(outputName10),
+				(T11) outputValue(outputName11),
+				(T12) outputValue(outputName12),
+				(T13) outputValue(outputName13),
+				(T14) outputValue(outputName14),
+				(T15) outputValue(outputName15),
+				(T16) outputValue(outputName16),
+				(T17) outputValue(outputName17),
+				(T18) outputValue(outputName18),
+				(T19) outputValue(outputName19),
+				(T20) outputValue(outputName20),
+				(T21) outputValue(outputName21),
+				(T22) outputValue(outputName22));
+	}
+
+	/**
+	 * Provide support for Scala tuples by returning an output value cast to a
+	 * specific output type. MLResults tuple support requires specifying the
+	 * object types at runtime to avoid the items in the tuple being returned as
+	 * Anys.
+	 *
+	 * @param outputName
+	 *            the name of the output
+	 * @return the output value cast to a specific output type
+	 */
+	@SuppressWarnings("unchecked")
+	private <T> T outputValue(String outputName) {
+		Data data = getData(outputName);
+		// Scalars are unwrapped to standard Java boxed types; valueOf is used
+		// instead of the deprecated boxing constructors (new Boolean/Double/Long)
+		// to avoid needless object allocation.
+		if (data instanceof BooleanObject) {
+			return (T) Boolean.valueOf(((BooleanObject) data).getBooleanValue());
+		} else if (data instanceof DoubleObject) {
+			return (T) Double.valueOf(((DoubleObject) data).getDoubleValue());
+		} else if (data instanceof IntObject) {
+			return (T) Long.valueOf(((IntObject) data).getLongValue());
+		} else if (data instanceof StringObject) {
+			return (T) ((StringObject) data).getStringValue();
+		} else if (data instanceof MatrixObject) {
+			// Matrices and frames are wrapped in their MLContext API classes.
+			return (T) getMatrix(outputName);
+		} else if (data instanceof FrameObject) {
+			return (T) getFrame(outputName);
+		}
+		// Fall back to the raw Data object for any other type.
+		return (T) data;
+	}
+
+	/**
+	 * Obtain the symbol table, which is essentially a {@code Map<String, Data>}
+	 * representing variables and their values as SystemML representations.
+	 * Note that the returned map is the live symbol table, not a copy.
+	 *
+	 * @return the symbol table of variables and their SystemML values
+	 */
+	public LocalVariableMap getSymbolTable() {
+		return symbolTable;
+	}
+
+	/**
+	 * Display the script outputs and their values.
+	 */
+	@Override
+	public String toString() {
+		// Delegate directly; the previous single-append StringBuilder was
+		// unnecessary object churn.
+		return MLContextUtil.displayOutputs(script.getOutputVariables(), symbolTable);
+	}
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/mlcontext/Matrix.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/Matrix.java b/src/main/java/org/apache/sysml/api/mlcontext/Matrix.java
new file mode 100644
index 0000000..178a6e5
--- /dev/null
+++ b/src/main/java/org/apache/sysml/api/mlcontext/Matrix.java
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.mlcontext;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.sql.DataFrame;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+
+/**
+ * Matrix encapsulates a SystemML matrix. It allows for easy conversion to
+ * various other formats, such as RDDs, JavaRDDs, DataFrames,
+ * BinaryBlockMatrices, and double[][]s. After script execution, it offers a
+ * convenient format for obtaining SystemML matrix data in Scala tuples.
+ *
+ */
+public class Matrix {
+
+	/** The underlying SystemML matrix representation; never reassigned. */
+	private final MatrixObject matrixObject;
+	/** Execution context used for the Spark-based conversions. */
+	private final SparkExecutionContext sparkExecutionContext;
+
+	public Matrix(MatrixObject matrixObject, SparkExecutionContext sparkExecutionContext) {
+		this.matrixObject = matrixObject;
+		this.sparkExecutionContext = sparkExecutionContext;
+	}
+
+	/**
+	 * Obtain the matrix as a SystemML MatrixObject.
+	 *
+	 * @return the matrix as a SystemML MatrixObject
+	 */
+	public MatrixObject asMatrixObject() {
+		return matrixObject;
+	}
+
+	/**
+	 * Obtain the matrix as a two-dimensional double array.
+	 *
+	 * @return the matrix as a two-dimensional double array
+	 */
+	public double[][] asDoubleMatrix() {
+		return MLContextConversionUtil.matrixObjectToDoubleMatrix(matrixObject);
+	}
+
+	/**
+	 * Obtain the matrix as a {@code JavaRDD<String>} in IJV format.
+	 *
+	 * @return the matrix as a {@code JavaRDD<String>} in IJV format
+	 */
+	public JavaRDD<String> asJavaRDDStringIJV() {
+		return MLContextConversionUtil.matrixObjectToJavaRDDStringIJV(matrixObject);
+	}
+
+	/**
+	 * Obtain the matrix as a {@code JavaRDD<String>} in CSV format.
+	 *
+	 * @return the matrix as a {@code JavaRDD<String>} in CSV format
+	 */
+	public JavaRDD<String> asJavaRDDStringCSV() {
+		return MLContextConversionUtil.matrixObjectToJavaRDDStringCSV(matrixObject);
+	}
+
+	/**
+	 * Obtain the matrix as a {@code RDD<String>} in CSV format.
+	 *
+	 * @return the matrix as a {@code RDD<String>} in CSV format
+	 */
+	public RDD<String> asRDDStringCSV() {
+		return MLContextConversionUtil.matrixObjectToRDDStringCSV(matrixObject);
+	}
+
+	/**
+	 * Obtain the matrix as a {@code RDD<String>} in IJV format.
+	 *
+	 * @return the matrix as a {@code RDD<String>} in IJV format
+	 */
+	public RDD<String> asRDDStringIJV() {
+		return MLContextConversionUtil.matrixObjectToRDDStringIJV(matrixObject);
+	}
+
+	/**
+	 * Obtain the matrix as a {@code DataFrame}.
+	 *
+	 * @return the matrix as a {@code DataFrame}
+	 */
+	public DataFrame asDataFrame() {
+		return MLContextConversionUtil.matrixObjectToDataFrame(matrixObject, sparkExecutionContext);
+	}
+
+	/**
+	 * Obtain the matrix as a {@code BinaryBlockMatrix}.
+	 *
+	 * @return the matrix as a {@code BinaryBlockMatrix}
+	 */
+	public BinaryBlockMatrix asBinaryBlockMatrix() {
+		return MLContextConversionUtil.matrixObjectToBinaryBlockMatrix(matrixObject, sparkExecutionContext);
+	}
+
+	/**
+	 * Obtain the matrix metadata.
+	 *
+	 * @return the matrix metadata
+	 */
+	public MatrixMetadata getMatrixMetadata() {
+		// Metadata is derived on demand from the current matrix characteristics.
+		return new MatrixMetadata(matrixObject.getMatrixCharacteristics());
+	}
+
+	@Override
+	public String toString() {
+		return matrixObject.toString();
+	}
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/mlcontext/MatrixFormat.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MatrixFormat.java b/src/main/java/org/apache/sysml/api/mlcontext/MatrixFormat.java
new file mode 100644
index 0000000..50ed634
--- /dev/null
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MatrixFormat.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.mlcontext;
+
+/**
+ * MatrixFormat represents the different matrix formats supported by the
+ * MLContext API.
+ *
+ */
+public enum MatrixFormat {
+	/**
+	 * Comma-separated value format (dense). Each line holds one matrix row.
+	 */
+	CSV,
+
+	/**
+	 * (I J V) format (sparse). I and J represent matrix coordinates (row and
+	 * column indices) and V represents the value at that position. The I, J,
+	 * and V values are space-separated.
+	 */
+	IJV;
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/mlcontext/MatrixMetadata.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MatrixMetadata.java b/src/main/java/org/apache/sysml/api/mlcontext/MatrixMetadata.java
new file mode 100644
index 0000000..1ea3a10
--- /dev/null
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MatrixMetadata.java
@@ -0,0 +1,522 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.mlcontext;
+
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+
+/**
+ * Matrix metadata, such as the number of rows, the number of columns, the
+ * number of non-zero values, the number of rows per block, and the number of
+ * columns per block in the matrix.
+ *
+ */
+public class MatrixMetadata {
+
+	// All fields are optional; null means "unknown / not specified".
+	private Long numRows = null;
+	private Long numColumns = null;
+	private Long numNonZeros = null;
+	private Integer numRowsPerBlock = null;
+	private Integer numColumnsPerBlock = null;
+	private MatrixFormat matrixFormat;
+
+	public MatrixMetadata() {
+	}
+
+	/**
+	 * Constructor to create a MatrixMetadata object based on matrix format.
+	 *
+	 * @param matrixFormat
+	 *            The matrix format.
+	 */
+	public MatrixMetadata(MatrixFormat matrixFormat) {
+		this.matrixFormat = matrixFormat;
+	}
+
+	/**
+	 * Constructor to create a MatrixMetadata object based on matrix format, the
+	 * number of rows, and the number of columns in a matrix.
+	 *
+	 * @param matrixFormat
+	 *            The matrix format.
+	 * @param numRows
+	 *            The number of rows in the matrix.
+	 * @param numColumns
+	 *            The number of columns in the matrix.
+	 */
+	public MatrixMetadata(MatrixFormat matrixFormat, Long numRows, Long numColumns) {
+		this.matrixFormat = matrixFormat;
+		this.numRows = numRows;
+		this.numColumns = numColumns;
+	}
+
+	/**
+	 * Constructor to create a MatrixMetadata object based on matrix format, the
+	 * number of rows, and the number of columns in a matrix. Delegates to the
+	 * {@code Long}-based constructor so the assignment logic exists once.
+	 *
+	 * @param matrixFormat
+	 *            The matrix format.
+	 * @param numRows
+	 *            The number of rows in the matrix.
+	 * @param numColumns
+	 *            The number of columns in the matrix.
+	 */
+	public MatrixMetadata(MatrixFormat matrixFormat, int numRows, int numColumns) {
+		this(matrixFormat, Long.valueOf(numRows), Long.valueOf(numColumns));
+	}
+
+	/**
+	 * Constructor to create a MatrixMetadata object based on matrix format, the
+	 * number of rows, the number of columns, and the number of non-zero values
+	 * in a matrix.
+	 *
+	 * @param matrixFormat
+	 *            The matrix format.
+	 * @param numRows
+	 *            The number of rows in the matrix.
+	 * @param numColumns
+	 *            The number of columns in the matrix.
+	 * @param numNonZeros
+	 *            The number of non-zero values in the matrix.
+	 */
+	public MatrixMetadata(MatrixFormat matrixFormat, Long numRows, Long numColumns, Long numNonZeros) {
+		this.matrixFormat = matrixFormat;
+		this.numRows = numRows;
+		this.numColumns = numColumns;
+		this.numNonZeros = numNonZeros;
+	}
+
+	/**
+	 * Constructor to create a MatrixMetadata object based on matrix format, the
+	 * number of rows, the number of columns, and the number of non-zero values
+	 * in a matrix. Delegates to the {@code Long}-based constructor.
+	 *
+	 * @param matrixFormat
+	 *            The matrix format.
+	 * @param numRows
+	 *            The number of rows in the matrix.
+	 * @param numColumns
+	 *            The number of columns in the matrix.
+	 * @param numNonZeros
+	 *            The number of non-zero values in the matrix.
+	 */
+	public MatrixMetadata(MatrixFormat matrixFormat, int numRows, int numColumns, int numNonZeros) {
+		this(matrixFormat, Long.valueOf(numRows), Long.valueOf(numColumns), Long.valueOf(numNonZeros));
+	}
+
+	/**
+	 * Constructor to create a MatrixMetadata object based on matrix format, the
+	 * number of rows, the number of columns, the number of non-zero values, the
+	 * number of rows per block, and the number of columns per block in a
+	 * matrix.
+	 *
+	 * @param matrixFormat
+	 *            The matrix format.
+	 * @param numRows
+	 *            The number of rows in the matrix.
+	 * @param numColumns
+	 *            The number of columns in the matrix.
+	 * @param numNonZeros
+	 *            The number of non-zero values in the matrix.
+	 * @param numRowsPerBlock
+	 *            The number of rows per block in the matrix.
+	 * @param numColumnsPerBlock
+	 *            The number of columns per block in the matrix.
+	 */
+	public MatrixMetadata(MatrixFormat matrixFormat, Long numRows, Long numColumns, Long numNonZeros,
+			Integer numRowsPerBlock, Integer numColumnsPerBlock) {
+		this.matrixFormat = matrixFormat;
+		this.numRows = numRows;
+		this.numColumns = numColumns;
+		this.numNonZeros = numNonZeros;
+		this.numRowsPerBlock = numRowsPerBlock;
+		this.numColumnsPerBlock = numColumnsPerBlock;
+	}
+
+	/**
+	 * Constructor to create a MatrixMetadata object based on matrix format, the
+	 * number of rows, the number of columns, the number of non-zero values, the
+	 * number of rows per block, and the number of columns per block in a
+	 * matrix. Delegates to the {@code Long}-based constructor.
+	 *
+	 * @param matrixFormat
+	 *            The matrix format.
+	 * @param numRows
+	 *            The number of rows in the matrix.
+	 * @param numColumns
+	 *            The number of columns in the matrix.
+	 * @param numNonZeros
+	 *            The number of non-zero values in the matrix.
+	 * @param numRowsPerBlock
+	 *            The number of rows per block in the matrix.
+	 * @param numColumnsPerBlock
+	 *            The number of columns per block in the matrix.
+	 */
+	public MatrixMetadata(MatrixFormat matrixFormat, int numRows, int numColumns, int numNonZeros, int numRowsPerBlock,
+			int numColumnsPerBlock) {
+		this(matrixFormat, Long.valueOf(numRows), Long.valueOf(numColumns), Long.valueOf(numNonZeros),
+				Integer.valueOf(numRowsPerBlock), Integer.valueOf(numColumnsPerBlock));
+	}
+
+	/**
+	 * Constructor to create a MatrixMetadata object based on the number of rows
+	 * and the number of columns in a matrix.
+	 *
+	 * @param numRows
+	 *            The number of rows in the matrix.
+	 * @param numColumns
+	 *            The number of columns in the matrix.
+	 */
+	public MatrixMetadata(Long numRows, Long numColumns) {
+		this.numRows = numRows;
+		this.numColumns = numColumns;
+	}
+
+	/**
+	 * Constructor to create a MatrixMetadata object based on the number of rows
+	 * and the number of columns in a matrix. Delegates to the
+	 * {@code Long}-based constructor.
+	 *
+	 * @param numRows
+	 *            The number of rows in the matrix.
+	 * @param numColumns
+	 *            The number of columns in the matrix.
+	 */
+	public MatrixMetadata(int numRows, int numColumns) {
+		this(Long.valueOf(numRows), Long.valueOf(numColumns));
+	}
+
+	/**
+	 * Constructor to create a MatrixMetadata object based on the number of
+	 * rows, the number of columns, and the number of non-zero values in a
+	 * matrix.
+	 *
+	 * @param numRows
+	 *            The number of rows in the matrix.
+	 * @param numColumns
+	 *            The number of columns in the matrix.
+	 * @param numNonZeros
+	 *            The number of non-zero values in the matrix.
+	 */
+	public MatrixMetadata(Long numRows, Long numColumns, Long numNonZeros) {
+		this.numRows = numRows;
+		this.numColumns = numColumns;
+		this.numNonZeros = numNonZeros;
+	}
+
+	/**
+	 * Constructor to create a MatrixMetadata object based on the number of
+	 * rows, the number of columns, and the number of non-zero values in a
+	 * matrix. Delegates to the {@code Long}-based constructor.
+	 *
+	 * @param numRows
+	 *            The number of rows in the matrix.
+	 * @param numColumns
+	 *            The number of columns in the matrix.
+	 * @param numNonZeros
+	 *            The number of non-zero values in the matrix.
+	 */
+	public MatrixMetadata(int numRows, int numColumns, int numNonZeros) {
+		this(Long.valueOf(numRows), Long.valueOf(numColumns), Long.valueOf(numNonZeros));
+	}
+
+	/**
+	 * Constructor to create a MatrixMetadata object based on the number of
+	 * rows, the number of columns, the number of rows per block, and the number
+	 * of columns per block in a matrix.
+	 *
+	 * @param numRows
+	 *            The number of rows in the matrix.
+	 * @param numColumns
+	 *            The number of columns in the matrix.
+	 * @param numRowsPerBlock
+	 *            The number of rows per block in the matrix.
+	 * @param numColumnsPerBlock
+	 *            The number of columns per block in the matrix.
+	 */
+	public MatrixMetadata(Long numRows, Long numColumns, Integer numRowsPerBlock, Integer numColumnsPerBlock) {
+		this.numRows = numRows;
+		this.numColumns = numColumns;
+		this.numRowsPerBlock = numRowsPerBlock;
+		this.numColumnsPerBlock = numColumnsPerBlock;
+	}
+
+	/**
+	 * Constructor to create a MatrixMetadata object based on the number of
+	 * rows, the number of columns, the number of rows per block, and the number
+	 * of columns per block in a matrix. Delegates to the {@code Long}-based
+	 * constructor.
+	 *
+	 * @param numRows
+	 *            The number of rows in the matrix.
+	 * @param numColumns
+	 *            The number of columns in the matrix.
+	 * @param numRowsPerBlock
+	 *            The number of rows per block in the matrix.
+	 * @param numColumnsPerBlock
+	 *            The number of columns per block in the matrix.
+	 */
+	public MatrixMetadata(int numRows, int numColumns, int numRowsPerBlock, int numColumnsPerBlock) {
+		this(Long.valueOf(numRows), Long.valueOf(numColumns), Integer.valueOf(numRowsPerBlock),
+				Integer.valueOf(numColumnsPerBlock));
+	}
+
+	/**
+	 * Constructor to create a MatrixMetadata object based on the number of
+	 * rows, the number of columns, the number of non-zero values, the number of
+	 * rows per block, and the number of columns per block in a matrix.
+	 *
+	 * @param numRows
+	 *            The number of rows in the matrix.
+	 * @param numColumns
+	 *            The number of columns in the matrix.
+	 * @param numNonZeros
+	 *            The number of non-zero values in the matrix.
+	 * @param numRowsPerBlock
+	 *            The number of rows per block in the matrix.
+	 * @param numColumnsPerBlock
+	 *            The number of columns per block in the matrix.
+	 */
+	public MatrixMetadata(Long numRows, Long numColumns, Long numNonZeros, Integer numRowsPerBlock,
+			Integer numColumnsPerBlock) {
+		this.numRows = numRows;
+		this.numColumns = numColumns;
+		this.numNonZeros = numNonZeros;
+		this.numRowsPerBlock = numRowsPerBlock;
+		this.numColumnsPerBlock = numColumnsPerBlock;
+	}
+
+	/**
+	 * Constructor to create a MatrixMetadata object based on the number of
+	 * rows, the number of columns, the number of non-zero values, the number of
+	 * rows per block, and the number of columns per block in a matrix.
+	 * Delegates to the {@code Long}-based constructor.
+	 *
+	 * @param numRows
+	 *            The number of rows in the matrix.
+	 * @param numColumns
+	 *            The number of columns in the matrix.
+	 * @param numNonZeros
+	 *            The number of non-zero values in the matrix.
+	 * @param numRowsPerBlock
+	 *            The number of rows per block in the matrix.
+	 * @param numColumnsPerBlock
+	 *            The number of columns per block in the matrix.
+	 */
+	public MatrixMetadata(int numRows, int numColumns, int numNonZeros, int numRowsPerBlock, int numColumnsPerBlock) {
+		this(Long.valueOf(numRows), Long.valueOf(numColumns), Long.valueOf(numNonZeros),
+				Integer.valueOf(numRowsPerBlock), Integer.valueOf(numColumnsPerBlock));
+	}
+
+	/**
+	 * Constructor to create a MatrixMetadata object based on a
+	 * MatrixCharacteristics object.
+	 *
+	 * @param matrixCharacteristics
+	 *            the matrix metadata as a MatrixCharacteristics object
+	 */
+	public MatrixMetadata(MatrixCharacteristics matrixCharacteristics) {
+		// Reuse the setter so the field-copy logic exists in exactly one place.
+		setMatrixCharacteristics(matrixCharacteristics);
+	}
+
+	/**
+	 * Set the MatrixMetadata fields based on a MatrixCharacteristics object.
+	 *
+	 * @param matrixCharacteristics
+	 *            the matrix metadata as a MatrixCharacteristics object
+	 */
+	public void setMatrixCharacteristics(MatrixCharacteristics matrixCharacteristics) {
+		this.numRows = matrixCharacteristics.getRows();
+		this.numColumns = matrixCharacteristics.getCols();
+		this.numNonZeros = matrixCharacteristics.getNonZeros();
+		this.numRowsPerBlock = matrixCharacteristics.getRowsPerBlock();
+		this.numColumnsPerBlock = matrixCharacteristics.getColsPerBlock();
+	}
+
+	/**
+	 * Obtain the number of rows
+	 *
+	 * @return the number of rows, or {@code null} if unknown
+	 */
+	public Long getNumRows() {
+		return numRows;
+	}
+
+	/**
+	 * Set the number of rows
+	 *
+	 * @param numRows
+	 *            the number of rows
+	 */
+	public void setNumRows(Long numRows) {
+		this.numRows = numRows;
+	}
+
+	/**
+	 * Obtain the number of columns
+	 *
+	 * @return the number of columns, or {@code null} if unknown
+	 */
+	public Long getNumColumns() {
+		return numColumns;
+	}
+
+	/**
+	 * Set the number of columns
+	 *
+	 * @param numColumns
+	 *            the number of columns
+	 */
+	public void setNumColumns(Long numColumns) {
+		this.numColumns = numColumns;
+	}
+
+	/**
+	 * Obtain the number of non-zero values
+	 *
+	 * @return the number of non-zero values, or {@code null} if unknown
+	 */
+	public Long getNumNonZeros() {
+		return numNonZeros;
+	}
+
+	/**
+	 * Set the number of non-zero values
+	 *
+	 * @param numNonZeros
+	 *            the number of non-zero values
+	 */
+	public void setNumNonZeros(Long numNonZeros) {
+		this.numNonZeros = numNonZeros;
+	}
+
+	/**
+	 * Obtain the number of rows per block
+	 *
+	 * @return the number of rows per block, or {@code null} if unknown
+	 */
+	public Integer getNumRowsPerBlock() {
+		return numRowsPerBlock;
+	}
+
+	/**
+	 * Set the number of rows per block
+	 *
+	 * @param numRowsPerBlock
+	 *            the number of rows per block
+	 */
+	public void setNumRowsPerBlock(Integer numRowsPerBlock) {
+		this.numRowsPerBlock = numRowsPerBlock;
+	}
+
+	/**
+	 * Obtain the number of columns per block
+	 *
+	 * @return the number of columns per block, or {@code null} if unknown
+	 */
+	public Integer getNumColumnsPerBlock() {
+		return numColumnsPerBlock;
+	}
+
+	/**
+	 * Set the number of columns per block
+	 *
+	 * @param numColumnsPerBlock
+	 *            the number of columns per block
+	 */
+	public void setNumColumnsPerBlock(Integer numColumnsPerBlock) {
+		this.numColumnsPerBlock = numColumnsPerBlock;
+	}
+
+	/**
+	 * Convert the matrix metadata to a MatrixCharacteristics object. If all
+	 * field values are {@code null}, {@code null} is returned.
+	 *
+	 * @return the matrix metadata as a MatrixCharacteristics object, or
+	 *         {@code null} if all field values are null
+	 */
+	public MatrixCharacteristics asMatrixCharacteristics() {
+
+		if ((numRows == null) && (numColumns == null) && (numRowsPerBlock == null) && (numColumnsPerBlock == null)
+				&& (numNonZeros == null)) {
+			return null;
+		}
+
+		// Unknown dimensions/non-zeros are encoded as -1; unknown block sizes
+		// fall back to the system default block size.
+		long nr = (numRows == null) ? -1 : numRows;
+		long nc = (numColumns == null) ? -1 : numColumns;
+		int nrpb = (numRowsPerBlock == null) ? MLContextUtil.defaultBlockSize() : numRowsPerBlock;
+		int ncpb = (numColumnsPerBlock == null) ? MLContextUtil.defaultBlockSize() : numColumnsPerBlock;
+		long nnz = (numNonZeros == null) ? -1 : numNonZeros;
+		return new MatrixCharacteristics(nr, nc, nrpb, ncpb, nnz);
+	}
+
+	@Override
+	public String toString() {
+		return "rows: " + fieldDisplay(numRows) + ", columns: " + fieldDisplay(numColumns) + ", non-zeros: "
+				+ fieldDisplay(numNonZeros) + ", rows per block: " + fieldDisplay(numRowsPerBlock)
+				+ ", columns per block: " + fieldDisplay(numColumnsPerBlock);
+	}
+
+	/** Render a field value for display, using "None" for unset (null) fields. */
+	private String fieldDisplay(Object field) {
+		if (field == null) {
+			return "None";
+		} else {
+			return field.toString();
+		}
+	}
+
+	/**
+	 * Obtain the matrix format
+	 *
+	 * @return the matrix format, or {@code null} if not specified
+	 */
+	public MatrixFormat getMatrixFormat() {
+		return matrixFormat;
+	}
+
+	/**
+	 * Set the matrix format
+	 *
+	 * @param matrixFormat
+	 *            the matrix format
+	 */
+	public void setMatrixFormat(MatrixFormat matrixFormat) {
+		this.matrixFormat = matrixFormat;
+	}
+
+}
[2/4] incubator-systemml git commit: [SYSTEMML-593] MLContext redesign
Posted by de...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/mlcontext/Script.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/Script.java b/src/main/java/org/apache/sysml/api/mlcontext/Script.java
new file mode 100644
index 0000000..65d3338
--- /dev/null
+++ b/src/main/java/org/apache/sysml/api/mlcontext/Script.java
@@ -0,0 +1,652 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.mlcontext;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Set;
+
+import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.instructions.cp.Data;
+
+import scala.Tuple2;
+import scala.Tuple3;
+import scala.collection.JavaConversions;
+
+/**
+ * A Script object encapsulates a DML or PYDML script.
+ *
+ */
+public class Script {
+
+ /**
+ * The type of script ({@code ScriptType.DML} or {@code ScriptType.PYDML}).
+ */
+ private ScriptType scriptType;
+ /**
+ * The script content.
+ */
+ private String scriptString;
+ /**
+ * The optional name of the script.
+ */
+ private String name;
+ /**
+ * All inputs (input parameters ($) and input variables).
+ */
+ private Map<String, Object> inputs = new LinkedHashMap<String, Object>();
+ /**
+ * The input parameters ($).
+ */
+ private Map<String, Object> inputParameters = new LinkedHashMap<String, Object>();
+ /**
+ * The input variables.
+ */
+ private Set<String> inputVariables = new LinkedHashSet<String>();
+ /**
+ * The input matrix metadata if present.
+ */
+ private Map<String, MatrixMetadata> inputMatrixMetadata = new LinkedHashMap<String, MatrixMetadata>();
+ /**
+ * The output variables.
+ */
+ private Set<String> outputVariables = new LinkedHashSet<String>();
+ /**
+ * The symbol table containing the data associated with variables.
+ */
+ private LocalVariableMap symbolTable = new LocalVariableMap();
+ /**
+ * The ScriptExecutor which is used to define the execution of the script.
+ */
+ private ScriptExecutor scriptExecutor;
+ /**
+ * The results of the execution of the script.
+ */
+ private MLResults results;
+
+ /**
+ * Script constructor, which by default creates a DML script.
+ */
+ public Script() {
+ scriptType = ScriptType.DML;
+ }
+
+ /**
+ * Script constructor, specifying the type of script ({@code ScriptType.DML}
+ * or {@code ScriptType.PYDML}).
+ *
+ * @param scriptType
+ * {@code ScriptType.DML} or {@code ScriptType.PYDML}
+ */
+ public Script(ScriptType scriptType) {
+ this.scriptType = scriptType;
+ }
+
+ /**
+ * Script constructor, specifying the script content. By default, the script
+ * type is DML.
+ *
+ * @param scriptString
+ * the script content as a string
+ */
+ public Script(String scriptString) {
+ this.scriptString = scriptString;
+ this.scriptType = ScriptType.DML;
+ }
+
+ /**
+ * Script constructor, specifying the script content and the type of script
+ * (DML or PYDML).
+ *
+ * @param scriptString
+ * the script content as a string
+ * @param scriptType
+ * {@code ScriptType.DML} or {@code ScriptType.PYDML}
+ */
+ public Script(String scriptString, ScriptType scriptType) {
+ this.scriptString = scriptString;
+ this.scriptType = scriptType;
+ }
+
+ /**
+ * Obtain the script type.
+ *
+ * @return {@code ScriptType.DML} or {@code ScriptType.PYDML}
+ */
+ public ScriptType getScriptType() {
+ return scriptType;
+ }
+
+ /**
+ * Set the type of script (DML or PYDML).
+ *
+ * @param scriptType
+ * {@code ScriptType.DML} or {@code ScriptType.PYDML}
+ */
+ public void setScriptType(ScriptType scriptType) {
+ this.scriptType = scriptType;
+ }
+
+ /**
+ * Obtain the script string.
+ *
+ * @return the script string
+ */
+ public String getScriptString() {
+ return scriptString;
+ }
+
+ /**
+ * Set the script string.
+ *
+ * @param scriptString
+ * the script string
+ * @return {@code this} Script object to allow chaining of methods
+ */
+ public Script setScriptString(String scriptString) {
+ this.scriptString = scriptString;
+ return this;
+ }
+
+ /**
+ * Obtain the input variable names as an unmodifiable set of strings.
+ *
+ * @return the input variable names
+ */
+ public Set<String> getInputVariables() {
+ return Collections.unmodifiableSet(inputVariables);
+ }
+
+ /**
+ * Obtain the output variable names as an unmodifiable set of strings.
+ *
+ * @return the output variable names
+ */
+ public Set<String> getOutputVariables() {
+ return Collections.unmodifiableSet(outputVariables);
+ }
+
+ /**
+ * Obtain the symbol table, which is essentially a
+ * {@code HashMap<String, Data>} representing variables and their values.
+ *
+ * @return the symbol table
+ */
+ public LocalVariableMap getSymbolTable() {
+ return symbolTable;
+ }
+
+ /**
+ * Obtain an unmodifiable map of all inputs (parameters ($) and variables).
+ *
+ * @return all inputs to the script
+ */
+ public Map<String, Object> getInputs() {
+ return Collections.unmodifiableMap(inputs);
+ }
+
+ /**
+ * Obtain an unmodifiable map of input matrix metadata.
+ *
+ * @return input matrix metadata
+ */
+ public Map<String, MatrixMetadata> getInputMatrixMetadata() {
+ return Collections.unmodifiableMap(inputMatrixMetadata);
+ }
+
+ /**
+ * Pass a map of inputs to the script.
+ *
+ * @param inputs
+ * map of inputs (parameters ($) and variables).
+ * @return {@code this} Script object to allow chaining of methods
+ */
+ public Script in(Map<String, Object> inputs) {
+ for (Entry<String, Object> input : inputs.entrySet()) {
+ in(input.getKey(), input.getValue());
+ }
+
+ return this;
+ }
+
+ /**
+ * Pass a Scala Map of inputs to the script.
+ *
+ * @param inputs
+ * Scala Map of inputs (parameters ($) and variables).
+ * @return {@code this} Script object to allow chaining of methods
+ */
+ public Script in(scala.collection.Map<String, Object> inputs) {
+ Map<String, Object> javaMap = JavaConversions.mapAsJavaMap(inputs);
+ in(javaMap);
+
+ return this;
+ }
+
+ /**
+ * Pass a Scala Seq of inputs to the script. The inputs are either two-value
+ * or three-value tuples, where the first value is the variable name, the
+ * second value is the variable value, and the third optional value is the
+ * metadata.
+ *
+ * @param inputs
+ * Scala Seq of inputs (parameters ($) and variables).
+ * @return {@code this} Script object to allow chaining of methods
+ */
+ public Script in(scala.collection.Seq<Object> inputs) {
+ List<Object> list = JavaConversions.asJavaList(inputs);
+ for (Object obj : list) {
+ if (obj instanceof Tuple3) {
+ @SuppressWarnings("unchecked")
+ Tuple3<String, Object, MatrixMetadata> t3 = (Tuple3<String, Object, MatrixMetadata>) obj;
+ in(t3._1(), t3._2(), t3._3());
+ } else if (obj instanceof Tuple2) {
+ @SuppressWarnings("unchecked")
+ Tuple2<String, Object> t2 = (Tuple2<String, Object>) obj;
+ in(t2._1(), t2._2());
+ } else {
+ throw new MLContextException("Only Tuples of 2 or 3 values are permitted");
+ }
+ }
+ return this;
+ }
+
+ /**
+ * Obtain an unmodifiable map of all input parameters ($).
+ *
+ * @return input parameters ($)
+ */
+ public Map<String, Object> getInputParameters() {
+ return inputParameters;
+ }
+
+ /**
+ * Register an input (parameter ($) or variable).
+ *
+ * @param name
+ * name of the input
+ * @param value
+ * value of the input
+ * @return {@code this} Script object to allow chaining of methods
+ */
+ public Script in(String name, Object value) {
+ return in(name, value, null);
+ }
+
+ /**
+ * Register an input (parameter ($) or variable) with optional matrix
+ * metadata.
+ *
+ * @param name
+ * name of the input
+ * @param value
+ * value of the input
+ * @param matrixMetadata
+ * optional matrix metadata
+ * @return {@code this} Script object to allow chaining of methods
+ */
+ public Script in(String name, Object value, MatrixMetadata matrixMetadata) {
+ MLContextUtil.checkInputValueType(name, value);
+ if (inputs == null) {
+ inputs = new LinkedHashMap<String, Object>();
+ }
+ inputs.put(name, value);
+
+ if (name.startsWith("$")) {
+ MLContextUtil.checkInputParameterType(name, value);
+ if (inputParameters == null) {
+ inputParameters = new LinkedHashMap<String, Object>();
+ }
+ inputParameters.put(name, value);
+ } else {
+ Data data = MLContextUtil.convertInputType(name, value, matrixMetadata);
+ if (data != null) {
+ symbolTable.put(name, data);
+ inputVariables.add(name);
+ if (data instanceof MatrixObject) {
+ if (matrixMetadata != null) {
+ inputMatrixMetadata.put(name, matrixMetadata);
+ }
+ }
+ }
+ }
+ return this;
+ }
+
+ /**
+ * Register an output variable.
+ *
+ * @param outputName
+ * name of the output variable
+ * @return {@code this} Script object to allow chaining of methods
+ */
+ public Script out(String outputName) {
+ outputVariables.add(outputName);
+ return this;
+ }
+
+ /**
+ * Register output variables.
+ *
+ * @param outputNames
+ * names of the output variables
+ * @return {@code this} Script object to allow chaining of methods
+ */
+ public Script out(String... outputNames) {
+ outputVariables.addAll(Arrays.asList(outputNames));
+ return this;
+ }
+
+ /**
+ * Clear the inputs, outputs, and symbol table.
+ */
+ public void clearIOS() {
+ clearInputs();
+ clearOutputs();
+ clearSymbolTable();
+ }
+
+ /**
+ * Clear the inputs and outputs, but not the symbol table.
+ */
+ public void clearIO() {
+ clearInputs();
+ clearOutputs();
+ }
+
+ /**
+ * Clear the script string, inputs, outputs, and symbol table.
+ */
+ public void clearAll() {
+ scriptString = null;
+ clearIOS();
+ }
+
+ /**
+ * Clear the inputs.
+ */
+ public void clearInputs() {
+ inputs.clear();
+ inputParameters.clear();
+ inputVariables.clear();
+ inputMatrixMetadata.clear();
+ }
+
+ /**
+ * Clear the outputs.
+ */
+ public void clearOutputs() {
+ outputVariables.clear();
+ }
+
+ /**
+ * Clear the symbol table.
+ */
+ public void clearSymbolTable() {
+ symbolTable.removeAll();
+ }
+
+ /**
+ * Obtain the results of the script execution.
+ *
+ * @return the results of the script execution.
+ */
+ public MLResults results() {
+ return results;
+ }
+
+ /**
+ * Obtain the results of the script execution.
+ *
+ * @return the results of the script execution.
+ */
+ public MLResults getResults() {
+ return results;
+ }
+
+ /**
+ * Set the results of the script execution.
+ *
+ * @param results
+ * the results of the script execution.
+ */
+ public void setResults(MLResults results) {
+ this.results = results;
+ }
+
+ /**
+ * Obtain the script executor used by this Script.
+ *
+ * @return the ScriptExecutor used by this Script.
+ */
+ public ScriptExecutor getScriptExecutor() {
+ return scriptExecutor;
+ }
+
+ /**
+ * Set the ScriptExecutor used by this Script.
+ *
+ * @param scriptExecutor
+ * the script executor
+ */
+ public void setScriptExecutor(ScriptExecutor scriptExecutor) {
+ this.scriptExecutor = scriptExecutor;
+ }
+
+ /**
+ * Is the script type DML?
+ *
+ * @return {@code true} if the script type is DML, {@code false} otherwise
+ */
+ public boolean isDML() {
+ return scriptType.isDML();
+ }
+
+ /**
+ * Is the script type PYDML?
+ *
+ * @return {@code true} if the script type is PYDML, {@code false} otherwise
+ */
+ public boolean isPYDML() {
+ return scriptType.isPYDML();
+ }
+
+ /**
+ * Generate the script execution string, which adds read/load/write/save
+ * statements to the beginning and end of the script to execute.
+ *
+ * @return the script execution string
+ */
+ public String getScriptExecutionString() {
+ StringBuilder sb = new StringBuilder();
+
+ Set<String> ins = getInputVariables();
+ for (String in : ins) {
+ Object inValue = getInputs().get(in);
+ sb.append(in);
+ if (isDML()) {
+ if (inValue instanceof String) {
+ String quotedString = MLContextUtil.quotedString((String) inValue);
+ sb.append(" = " + quotedString + ";\n");
+ } else if (MLContextUtil.isBasicType(inValue)) {
+ sb.append(" = read('', data_type='scalar');\n");
+ } else {
+ sb.append(" = read('');\n");
+ }
+ } else if (isPYDML()) {
+ if (inValue instanceof String) {
+ String quotedString = MLContextUtil.quotedString((String) inValue);
+ sb.append(" = " + quotedString + "\n");
+ } else if (MLContextUtil.isBasicType(inValue)) {
+ sb.append(" = load('', data_type='scalar')\n");
+ } else {
+ sb.append(" = load('')\n");
+ }
+ }
+
+ }
+
+ sb.append(getScriptString());
+ if (!getScriptString().endsWith("\n")) {
+ sb.append("\n");
+ }
+
+ Set<String> outs = getOutputVariables();
+ for (String out : outs) {
+ if (isDML()) {
+ sb.append("write(");
+ sb.append(out);
+ sb.append(", '');\n");
+ } else if (isPYDML()) {
+ sb.append("save(");
+ sb.append(out);
+ sb.append(", '')\n");
+ }
+ }
+
+ return sb.toString();
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+
+ sb.append(MLContextUtil.displayInputs("Inputs", inputs));
+ sb.append("\n");
+ sb.append(MLContextUtil.displayOutputs("Outputs", outputVariables, symbolTable));
+ return sb.toString();
+ }
+
+ /**
+ * Display information about the script as a String. This consists of the
+ * script type, inputs, outputs, input parameters, input variables, output
+ * variables, the symbol table, the script string, and the script execution
+ * string.
+ *
+ * @return information about this script as a String
+ */
+ public String info() {
+ StringBuilder sb = new StringBuilder();
+
+ sb.append("Script Type: ");
+ sb.append(scriptType);
+ sb.append("\n\n");
+ sb.append(MLContextUtil.displayInputs("Inputs", inputs));
+ sb.append("\n");
+ sb.append(MLContextUtil.displayOutputs("Outputs", outputVariables, symbolTable));
+ sb.append("\n");
+ sb.append(MLContextUtil.displayMap("Input Parameters", inputParameters));
+ sb.append("\n");
+ sb.append(MLContextUtil.displaySet("Input Variables", inputVariables));
+ sb.append("\n");
+ sb.append(MLContextUtil.displaySet("Output Variables", outputVariables));
+ sb.append("\n");
+ sb.append(MLContextUtil.displaySymbolTable("Symbol Table", symbolTable));
+ sb.append("\nScript String:\n");
+ sb.append(scriptString);
+ sb.append("\nScript Execution String:\n");
+ sb.append(getScriptExecutionString());
+ sb.append("\n");
+
+ return sb.toString();
+ }
+
+ /**
+ * Display the script inputs.
+ *
+ * @return the script inputs
+ */
+ public String displayInputs() {
+ return MLContextUtil.displayInputs("Inputs", inputs);
+ }
+
+ /**
+ * Display the script outputs.
+ *
+ * @return the script outputs as a String
+ */
+ public String displayOutputs() {
+ return MLContextUtil.displayOutputs("Outputs", outputVariables, symbolTable);
+ }
+
+ /**
+ * Display the script input parameters.
+ *
+ * @return the script input parameters as a String
+ */
+ public String displayInputParameters() {
+ return MLContextUtil.displayMap("Input Parameters", inputParameters);
+ }
+
+ /**
+ * Display the script input variables.
+ *
+ * @return the script input variables as a String
+ */
+ public String displayInputVariables() {
+ return MLContextUtil.displaySet("Input Variables", inputVariables);
+ }
+
+ /**
+ * Display the script output variables.
+ *
+ * @return the script output variables as a String
+ */
+ public String displayOutputVariables() {
+ return MLContextUtil.displaySet("Output Variables", outputVariables);
+ }
+
+ /**
+ * Display the script symbol table.
+ *
+ * @return the script symbol table as a String
+ */
+ public String displaySymbolTable() {
+ return MLContextUtil.displaySymbolTable("Symbol Table", symbolTable);
+ }
+
+ /**
+ * Obtain the script name.
+ *
+ * @return the script name
+ */
+ public String getName() {
+ return name;
+ }
+
+ /**
+ * Set the script name.
+ *
+ * @param name
+ * the script name
+ * @return {@code this} Script object to allow chaining of methods
+ */
+ public Script setName(String name) {
+ this.name = name;
+ return this;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
new file mode 100644
index 0000000..4702af2
--- /dev/null
+++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
@@ -0,0 +1,624 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.mlcontext;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Set;
+
+import org.apache.commons.lang3.StringUtils;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.jmlc.JMLCUtils;
+import org.apache.sysml.api.monitoring.SparkMonitoringUtil;
+import org.apache.sysml.conf.ConfigurationManager;
+import org.apache.sysml.conf.DMLConfig;
+import org.apache.sysml.hops.HopsException;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.hops.OptimizerUtils.OptimizationLevel;
+import org.apache.sysml.hops.globalopt.GlobalOptimizerWrapper;
+import org.apache.sysml.hops.rewrite.ProgramRewriter;
+import org.apache.sysml.hops.rewrite.RewriteRemovePersistentReadWrite;
+import org.apache.sysml.lops.LopsException;
+import org.apache.sysml.parser.AParserWrapper;
+import org.apache.sysml.parser.DMLProgram;
+import org.apache.sysml.parser.DMLTranslator;
+import org.apache.sysml.parser.LanguageException;
+import org.apache.sysml.parser.ParseException;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
+import org.apache.sysml.runtime.controlprogram.Program;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
+import org.apache.sysml.utils.Explain;
+import org.apache.sysml.utils.Explain.ExplainCounts;
+import org.apache.sysml.utils.Statistics;
+
+/**
+ * ScriptExecutor executes a DML or PYDML Script object using SystemML. This is
+ * accomplished by calling the {@link #execute} method.
+ * <p>
+ * Script execution via the MLContext API typically consists of the following
+ * steps:
+ * </p>
+ * <ol>
+ * <li>Language Steps
+ * <ol>
+ * <li>Parse script into program</li>
+ * <li>Live variable analysis</li>
+ * <li>Validate program</li>
+ * </ol>
+ * </li>
+ * <li>HOP (High-Level Operator) Steps
+ * <ol>
+ * <li>Construct HOP DAGs</li>
+ * <li>Static rewrites</li>
+ * <li>Intra-/Inter-procedural analysis</li>
+ * <li>Dynamic rewrites</li>
+ * <li>Compute memory estimates</li>
+ * <li>Rewrite persistent reads and writes (MLContext-specific)</li>
+ * </ol>
+ * </li>
+ * <li>LOP (Low-Level Operator) Steps
+ * <ol>
+ * <li>Construct LOP DAGs</li>
+ * <li>Generate runtime program</li>
+ * <li>Execute runtime program</li>
+ * <li>Dynamic recompilation</li>
+ * </ol>
+ * </li>
+ * </ol>
+ * <p>
+ * Modifications to these steps can be accomplished by subclassing
+ * ScriptExecutor. For example, the following code will turn off the global data
+ * flow optimization check by subclassing ScriptExecutor and overriding the
+ * globalDataFlowOptimization method.
+ * </p>
+ *
+ * <code>ScriptExecutor scriptExecutor = new ScriptExecutor() {
+ * <br> // turn off global data flow optimization check
+ * <br> @Override
+ * <br> protected void globalDataFlowOptimization() {
+ * <br> return;
+ * <br> }
+ * <br>};
+ * <br>ml.execute(script, scriptExecutor);</code>
+ * <p>
+ *
+ * For more information, please see the {@link #execute} method.
+ */
+public class ScriptExecutor {
+
	// SystemML configuration properties used during compilation and execution
	protected DMLConfig config;
	// optional Spark monitoring hook; may be null — TODO confirm all call sites null-check
	protected SparkMonitoringUtil sparkMonitoringUtil;
	// parsed DML/PYDML program, produced by parseScript()
	protected DMLProgram dmlProgram;
	// translator used for HOP/LOP construction and rewrites
	protected DMLTranslator dmlTranslator;
	// compiled runtime program, produced by generateRuntimeProgram()
	protected Program runtimeProgram;
	// execution context holding the runtime symbol table
	protected ExecutionContext executionContext;
	// the script currently being executed by this executor
	protected Script script;
	// when true, print a program explanation before execution
	protected boolean explain = false;
	// when true, print execution statistics after execution
	protected boolean statistics = false;
+
+ /**
+ * ScriptExecutor constructor.
+ */
+ public ScriptExecutor() {
+ config = ConfigurationManager.getDMLConfig();
+ }
+
+ /**
+ * ScriptExecutor constructor, where the configuration properties are passed
+ * in.
+ *
+ * @param config
+ * the configuration properties to use by the ScriptExecutor
+ */
+ public ScriptExecutor(DMLConfig config) {
+ this.config = config;
+ ConfigurationManager.setGlobalConfig(config);
+ }
+
+ /**
+ * ScriptExecutor constructor, where a SparkMonitoringUtil object is passed
+ * in.
+ *
+ * @param sparkMonitoringUtil
+ * SparkMonitoringUtil object to monitor Spark
+ */
+ public ScriptExecutor(SparkMonitoringUtil sparkMonitoringUtil) {
+ this();
+ this.sparkMonitoringUtil = sparkMonitoringUtil;
+ }
+
+ /**
+ * ScriptExecutor constructor, where the configuration properties and a
+ * SparkMonitoringUtil object are passed in.
+ *
+ * @param config
+ * the configuration properties to use by the ScriptExecutor
+ * @param sparkMonitoringUtil
+ * SparkMonitoringUtil object to monitor Spark
+ */
+ public ScriptExecutor(DMLConfig config, SparkMonitoringUtil sparkMonitoringUtil) {
+ this.config = config;
+ this.sparkMonitoringUtil = sparkMonitoringUtil;
+ }
+
+ /**
+ * Construct DAGs of high-level operators (HOPs) for each block of
+ * statements.
+ */
+ protected void constructHops() {
+ try {
+ dmlTranslator.constructHops(dmlProgram);
+ } catch (LanguageException e) {
+ throw new MLContextException("Exception occurred while constructing HOPS (high-level operators)", e);
+ } catch (ParseException e) {
+ throw new MLContextException("Exception occurred while constructing HOPS (high-level operators)", e);
+ }
+ }
+
+ /**
+ * Apply static rewrites, perform intra-/inter-procedural analysis to
+ * propagate size information into functions, apply dynamic rewrites, and
+ * compute memory estimates for all HOPs.
+ */
+ protected void rewriteHops() {
+ try {
+ dmlTranslator.rewriteHopsDAG(dmlProgram);
+ } catch (LanguageException e) {
+ throw new MLContextException("Exception occurred while rewriting HOPS (high-level operators)", e);
+ } catch (HopsException e) {
+ throw new MLContextException("Exception occurred while rewriting HOPS (high-level operators)", e);
+ } catch (ParseException e) {
+ throw new MLContextException("Exception occurred while rewriting HOPS (high-level operators)", e);
+ }
+ }
+
+ /**
+ * Output a description of the program to standard output.
+ */
+ protected void showExplanation() {
+ if (explain) {
+ try {
+ System.out.println(Explain.explain(dmlProgram));
+ } catch (HopsException e) {
+ throw new MLContextException("Exception occurred while explaining dml program", e);
+ } catch (DMLRuntimeException e) {
+ throw new MLContextException("Exception occurred while explaining dml program", e);
+ } catch (LanguageException e) {
+ throw new MLContextException("Exception occurred while explaining dml program", e);
+ }
+ }
+ }
+
+ /**
+ * Construct DAGs of low-level operators (LOPs) based on the DAGs of
+ * high-level operators (HOPs).
+ */
+ protected void constructLops() {
+ try {
+ dmlTranslator.constructLops(dmlProgram);
+ } catch (ParseException e) {
+ throw new MLContextException("Exception occurred while constructing LOPS (low-level operators)", e);
+ } catch (LanguageException e) {
+ throw new MLContextException("Exception occurred while constructing LOPS (low-level operators)", e);
+ } catch (HopsException e) {
+ throw new MLContextException("Exception occurred while constructing LOPS (low-level operators)", e);
+ } catch (LopsException e) {
+ throw new MLContextException("Exception occurred while constructing LOPS (low-level operators)", e);
+ }
+ }
+
+ /**
+ * Create runtime program. For each namespace, translate function statement
+ * blocks into function program blocks and add these to the runtime program.
+ * For each top-level block, add the program block to the runtime program.
+ */
+ protected void generateRuntimeProgram() {
+ try {
+ runtimeProgram = dmlProgram.getRuntimeProgram(config);
+ } catch (LanguageException e) {
+ throw new MLContextException("Exception occurred while generating runtime program", e);
+ } catch (DMLRuntimeException e) {
+ throw new MLContextException("Exception occurred while generating runtime program", e);
+ } catch (LopsException e) {
+ throw new MLContextException("Exception occurred while generating runtime program", e);
+ } catch (IOException e) {
+ throw new MLContextException("Exception occurred while generating runtime program", e);
+ }
+ }
+
+ /**
+ * Count the number of compiled MR Jobs/Spark Instructions in the runtime
+ * program and set this value in the statistics.
+ */
+ protected void countCompiledMRJobsAndSparkInstructions() {
+ ExplainCounts counts = Explain.countDistributedOperations(runtimeProgram);
+ Statistics.resetNoOfCompiledJobs(counts.numJobs);
+ }
+
+ /**
+ * Create an execution context and set its variables to be the symbol table
+ * of the script.
+ */
+ protected void createAndInitializeExecutionContext() {
+ executionContext = ExecutionContextFactory.createContext(runtimeProgram);
+ LocalVariableMap symbolTable = script.getSymbolTable();
+ if (symbolTable != null) {
+ executionContext.setVariables(symbolTable);
+ }
+ }
+
	/**
	 * Execute a DML or PYDML script. This is broken down into the following
	 * primary methods:
	 *
	 * <ol>
	 * <li>{@link #parseScript()}</li>
	 * <li>{@link #liveVariableAnalysis()}</li>
	 * <li>{@link #validateScript()}</li>
	 * <li>{@link #constructHops()}</li>
	 * <li>{@link #rewriteHops()}</li>
	 * <li>{@link #showExplanation()}</li>
	 * <li>{@link #rewritePersistentReadsAndWrites()}</li>
	 * <li>{@link #constructLops()}</li>
	 * <li>{@link #generateRuntimeProgram()}</li>
	 * <li>{@link #globalDataFlowOptimization()}</li>
	 * <li>{@link #countCompiledMRJobsAndSparkInstructions()}</li>
	 * <li>{@link #initializeCachingAndScratchSpace()}</li>
	 * <li>{@link #cleanupRuntimeProgram()}</li>
	 * <li>{@link #createAndInitializeExecutionContext()}</li>
	 * <li>{@link #executeRuntimeProgram()}</li>
	 * <li>{@link #cleanupAfterExecution()}</li>
	 * </ol>
	 *
	 * @param script
	 *            the DML or PYDML script to execute
	 */
	public MLResults execute(Script script) {
		this.script = script;
		checkScriptHasTypeAndString();
		// the script and its executor reference each other
		script.setScriptExecutor(this);
		setScriptStringInSparkMonitor();

		// main steps in script execution; the order of these calls matters,
		// since each step consumes state produced by the preceding ones
		parseScript();
		liveVariableAnalysis();
		validateScript();
		constructHops();
		rewriteHops();
		showExplanation();
		rewritePersistentReadsAndWrites();
		constructLops();
		generateRuntimeProgram();
		globalDataFlowOptimization();
		countCompiledMRJobsAndSparkInstructions();
		initializeCachingAndScratchSpace();
		cleanupRuntimeProgram();
		createAndInitializeExecutionContext();
		executeRuntimeProgram();
		setExplainRuntimeProgramInSparkMonitor();
		cleanupAfterExecution();

		// add symbol table to MLResults
		MLResults mlResults = new MLResults(script);
		script.setResults(mlResults);

		// optionally print execution statistics to standard output
		if (statistics) {
			System.out.println(Statistics.display());
		}

		return mlResults;
	}
+
+ /**
+ * Perform any necessary cleanup operations after program execution.
+ */
+ protected void cleanupAfterExecution() {
+ restoreInputsInSymbolTable();
+ }
+
+ /**
+ * Restore the input variables in the symbol table after script execution.
+ */
+ protected void restoreInputsInSymbolTable() {
+ Map<String, Object> inputs = script.getInputs();
+ Map<String, MatrixMetadata> inputMatrixMetadata = script.getInputMatrixMetadata();
+ LocalVariableMap symbolTable = script.getSymbolTable();
+ Set<String> inputVariables = script.getInputVariables();
+ for (String inputVariable : inputVariables) {
+ if (symbolTable.get(inputVariable) == null) {
+ // retrieve optional metadata if it exists
+ MatrixMetadata mm = inputMatrixMetadata.get(inputVariable);
+ script.in(inputVariable, inputs.get(inputVariable), mm);
+ }
+ }
+ }
+
+ /**
+ * Remove rmvar instructions so as to maintain registered outputs after the
+ * program terminates.
+ */
+ protected void cleanupRuntimeProgram() {
+ JMLCUtils.cleanupRuntimeProgram(runtimeProgram, (script.getOutputVariables() == null) ? new String[0] : script
+ .getOutputVariables().toArray(new String[0]));
+ }
+
+ /**
+ * Execute the runtime program. This involves execution of the program
+ * blocks that make up the runtime program and may involve dynamic
+ * recompilation.
+ */
+ protected void executeRuntimeProgram() {
+ try {
+ runtimeProgram.execute(executionContext);
+ } catch (DMLRuntimeException e) {
+ throw new MLContextException("Exception occurred while executing runtime program", e);
+ }
+ }
+
+ /**
+ * Obtain the SparkMonitoringUtil object.
+ *
+ * @return the SparkMonitoringUtil object, if available
+ */
+ public SparkMonitoringUtil getSparkMonitoringUtil() {
+ return sparkMonitoringUtil;
+ }
+
+ /**
+ * Check security, create scratch space, cleanup working directories,
+ * initialize caching, and reset statistics.
+ */
+ protected void initializeCachingAndScratchSpace() {
+ try {
+ DMLScript.initHadoopExecution(config);
+ } catch (ParseException e) {
+ throw new MLContextException("Exception occurred initializing caching and scratch space", e);
+ } catch (DMLRuntimeException e) {
+ throw new MLContextException("Exception occurred initializing caching and scratch space", e);
+ } catch (IOException e) {
+ throw new MLContextException("Exception occurred initializing caching and scratch space", e);
+ }
+ }
+
+ /**
+ * Optimize the program.
+ */
+ protected void globalDataFlowOptimization() {
+ if (OptimizerUtils.isOptLevel(OptimizationLevel.O4_GLOBAL_TIME_MEMORY)) {
+ try {
+ runtimeProgram = GlobalOptimizerWrapper.optimizeProgram(dmlProgram, runtimeProgram);
+ } catch (DMLRuntimeException e) {
+ throw new MLContextException("Exception occurred during global data flow optimization", e);
+ } catch (HopsException e) {
+ throw new MLContextException("Exception occurred during global data flow optimization", e);
+ } catch (LopsException e) {
+ throw new MLContextException("Exception occurred during global data flow optimization", e);
+ }
+ }
+ }
+
+ /**
+ * Parse the script into an ANTLR parse tree and convert this parse tree
+ * into a SystemML program. Parsing includes lexical/syntactic analysis.
+ */
+ protected void parseScript() {
+ try {
+ // pick the DML or PYDML parser based on the script type
+ AParserWrapper parser = AParserWrapper.createParser(script.getScriptType().isPYDML());
+ // input parameters must be converted to string form for the parser
+ Map<String, String> parserParams = MLContextUtil.convertInputParametersForParser(
+ script.getInputParameters(), script.getScriptType());
+ dmlProgram = parser.parse(null, script.getScriptExecutionString(), parserParams);
+ } catch (ParseException ex) {
+ throw new MLContextException("Exception occurred while parsing script", ex);
+ }
+ }
+
+ /**
+ * Replace persistent reads and writes with transient reads and writes in
+ * the symbol table. Only applies when a symbol table is present.
+ */
+ protected void rewritePersistentReadsAndWrites() {
+ LocalVariableMap symbolTable = script.getSymbolTable();
+ if (symbolTable == null) {
+ return; // nothing to rewrite without a symbol table
+ }
+ String[] inputs = (script.getInputVariables() != null)
+ ? script.getInputVariables().toArray(new String[0]) : new String[0];
+ String[] outputs = (script.getOutputVariables() != null)
+ ? script.getOutputVariables().toArray(new String[0]) : new String[0];
+ ProgramRewriter rewriter = new ProgramRewriter(new RewriteRemovePersistentReadWrite(inputs, outputs));
+ try {
+ rewriter.rewriteProgramHopDAGs(dmlProgram);
+ } catch (HopsException ex) {
+ throw new MLContextException("Exception occurred while rewriting persistent reads and writes", ex);
+ } catch (LanguageException ex) {
+ throw new MLContextException("Exception occurred while rewriting persistent reads and writes", ex);
+ }
+ }
+
+ /**
+ * Set the SystemML configuration properties. Note that the configuration
+ * is also registered globally via the ConfigurationManager, so it affects
+ * state beyond this executor.
+ *
+ * @param config
+ * The configuration properties
+ */
+ public void setConfig(DMLConfig config) {
+ this.config = config;
+ ConfigurationManager.setGlobalConfig(config);
+ }
+
+ /**
+ * If a SparkMonitoringUtil is registered, record the explanation of the
+ * runtime program in it; otherwise do nothing.
+ */
+ protected void setExplainRuntimeProgramInSparkMonitor() {
+ if (sparkMonitoringUtil == null) {
+ return; // no monitor registered, nothing to record
+ }
+ try {
+ sparkMonitoringUtil.setExplainOutput(Explain.explain(runtimeProgram));
+ } catch (HopsException ex) {
+ throw new MLContextException("Exception occurred while explaining runtime program", ex);
+ }
+ }
+
+ /**
+ * If a SparkMonitoringUtil is registered, record the script string in it;
+ * otherwise do nothing.
+ */
+ protected void setScriptStringInSparkMonitor() {
+ if (sparkMonitoringUtil == null) {
+ return; // no monitor registered, nothing to record
+ }
+ sparkMonitoringUtil.setDMLString(script.getScriptString());
+ }
+
+ /**
+ * Set the SparkMonitoringUtil object, which is used to gather monitoring
+ * information (such as the script string and explain output) during
+ * execution.
+ *
+ * @param sparkMonitoringUtil
+ * The SparkMonitoringUtil object
+ */
+ public void setSparkMonitoringUtil(SparkMonitoringUtil sparkMonitoringUtil) {
+ this.sparkMonitoringUtil = sparkMonitoringUtil;
+ }
+
+ /**
+ * Liveness analysis is performed on the program, obtaining sets of live-in
+ * and live-out variables by forward and backward passes over the program.
+ */
+ protected void liveVariableAnalysis() {
+ try {
+ dmlTranslator = new DMLTranslator(dmlProgram);
+ dmlTranslator.liveVariableAnalysis(dmlProgram);
+ } catch (DMLRuntimeException e) {
+ throw new MLContextException("Exception occurred during live variable analysis", e);
+ } catch (LanguageException e) {
+ throw new MLContextException("Exception occurred during live variable analysis", e);
+ }
+ }
+
+ /**
+ * Semantically validate the program's expressions, statements, and
+ * statement blocks in a single recursive pass over the program. Constant
+ * and size propagation occurs during this step.
+ */
+ protected void validateScript() {
+ try {
+ dmlTranslator.validateParseTree(dmlProgram);
+ } catch (LanguageException e) {
+ throw new MLContextException("Exception occurred while validating script", e);
+ } catch (ParseException e) {
+ throw new MLContextException("Exception occurred while validating script", e);
+ } catch (IOException e) {
+ throw new MLContextException("Exception occurred while validating script", e);
+ }
+ }
+
+ /**
+ * Validate that the Script object is non-null, that it has a type (DML or
+ * PYDML), and that its script string is non-null and non-blank.
+ */
+ protected void checkScriptHasTypeAndString() {
+ if (script == null) {
+ throw new MLContextException("Script is null");
+ }
+ if (script.getScriptType() == null) {
+ throw new MLContextException("ScriptType (DML or PYDML) needs to be specified");
+ }
+ if (script.getScriptString() == null) {
+ throw new MLContextException("Script string is null");
+ }
+ if (StringUtils.isBlank(script.getScriptString())) {
+ throw new MLContextException("Script string is blank");
+ }
+ }
+
+ /**
+ * Obtain the high-level program (the parsed DML/PYDML program)
+ *
+ * @return the program
+ */
+ public DMLProgram getDmlProgram() {
+ return dmlProgram;
+ }
+
+ /**
+ * Obtain the translator (used for liveness analysis and validation of the
+ * program)
+ *
+ * @return the translator
+ */
+ public DMLTranslator getDmlTranslator() {
+ return dmlTranslator;
+ }
+
+ /**
+ * Obtain the runtime program generated from the high-level program
+ *
+ * @return the runtime program
+ */
+ public Program getRuntimeProgram() {
+ return runtimeProgram;
+ }
+
+ /**
+ * Obtain the execution context in which the runtime program executes
+ *
+ * @return the execution context
+ */
+ public ExecutionContext getExecutionContext() {
+ return executionContext;
+ }
+
+ /**
+ * Obtain the Script object associated with this ScriptExecutor
+ *
+ * @return the Script object associated with this ScriptExecutor
+ */
+ public Script getScript() {
+ return script;
+ }
+
+ /**
+ * Set whether or not an explanation of the DML/PYDML program should be
+ * output to standard output.
+ *
+ * @param explain
+ * {@code true} if explanation should be output, {@code false}
+ * otherwise
+ */
+ public void setExplain(boolean explain) {
+ this.explain = explain;
+ }
+
+ /**
+ * Set whether or not statistics about the DML/PYDML program should be
+ * output to standard output.
+ *
+ * @param statistics
+ * {@code true} if statistics should be output, {@code false}
+ * otherwise
+ */
+ public void setStatistics(boolean statistics) {
+ this.statistics = statistics;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/mlcontext/ScriptFactory.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptFactory.java b/src/main/java/org/apache/sysml/api/mlcontext/ScriptFactory.java
new file mode 100644
index 0000000..5f0e56b
--- /dev/null
+++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptFactory.java
@@ -0,0 +1,422 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.mlcontext;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.MalformedURLException;
+import java.net.URL;
+
+import org.apache.commons.io.FileUtils;
+import org.apache.commons.io.IOUtils;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.sysml.conf.ConfigurationManager;
+import org.apache.sysml.runtime.util.LocalFileUtils;
+
+/**
+ * Factory for creating DML and PYDML Script objects from strings, files, URLs,
+ * and input streams.
+ *
+ */
+public class ScriptFactory {
+
+ /**
+ * Create a DML Script object based on a string path to a file.
+ *
+ * @param scriptFilePath
+ * path to DML script file (local or HDFS)
+ * @return DML Script object
+ */
+ public static Script dmlFromFile(String scriptFilePath) {
+ return scriptFromFile(scriptFilePath, ScriptType.DML);
+ }
+
+ /**
+ * Create a DML Script object based on an input stream.
+ *
+ * @param inputStream
+ * input stream to DML
+ * @return DML Script object
+ */
+ public static Script dmlFromInputStream(InputStream inputStream) {
+ return scriptFromInputStream(inputStream, ScriptType.DML);
+ }
+
+ /**
+ * Creates a DML Script object based on a file in the local file system. To
+ * create a DML Script object from a local file or HDFS, please use
+ * {@link #dmlFromFile(String)}.
+ *
+ * @param localScriptFile
+ * the local DML file
+ * @return DML Script object
+ */
+ public static Script dmlFromLocalFile(File localScriptFile) {
+ return scriptFromLocalFile(localScriptFile, ScriptType.DML);
+ }
+
+ /**
+ * Create a DML Script object based on a string.
+ *
+ * @param scriptString
+ * string of DML
+ * @return DML Script object
+ */
+ public static Script dmlFromString(String scriptString) {
+ return scriptFromString(scriptString, ScriptType.DML);
+ }
+
+ /**
+ * Create a DML Script object based on a URL path.
+ *
+ * @param scriptUrlPath
+ * URL path to DML script
+ * @return DML Script object
+ */
+ public static Script dmlFromUrl(String scriptUrlPath) {
+ return scriptFromUrl(scriptUrlPath, ScriptType.DML);
+ }
+
+ /**
+ * Create a DML Script object based on a URL.
+ *
+ * @param scriptUrl
+ * URL to DML script
+ * @return DML Script object
+ */
+ public static Script dmlFromUrl(URL scriptUrl) {
+ return scriptFromUrl(scriptUrl, ScriptType.DML);
+ }
+
+ /**
+ * Create a PYDML Script object based on a string path to a file.
+ *
+ * @param scriptFilePath
+ * path to PYDML script file (local or HDFS)
+ * @return PYDML Script object
+ */
+ public static Script pydmlFromFile(String scriptFilePath) {
+ return scriptFromFile(scriptFilePath, ScriptType.PYDML);
+ }
+
+ /**
+ * Create a PYDML Script object based on an input stream.
+ *
+ * @param inputStream
+ * input stream to PYDML
+ * @return PYDML Script object
+ */
+ public static Script pydmlFromInputStream(InputStream inputStream) {
+ return scriptFromInputStream(inputStream, ScriptType.PYDML);
+ }
+
+ /**
+ * Creates a PYDML Script object based on a file in the local file system.
+ * To create a PYDML Script object from a local file or HDFS, please use
+ * {@link #pydmlFromFile(String)}.
+ *
+ * @param localScriptFile
+ * the local PYDML file
+ * @return PYDML Script object
+ */
+ public static Script pydmlFromLocalFile(File localScriptFile) {
+ return scriptFromLocalFile(localScriptFile, ScriptType.PYDML);
+ }
+
+ /**
+ * Create a PYDML Script object based on a string.
+ *
+ * @param scriptString
+ * string of PYDML
+ * @return PYDML Script object
+ */
+ public static Script pydmlFromString(String scriptString) {
+ return scriptFromString(scriptString, ScriptType.PYDML);
+ }
+
+ /**
+ * Create a PYDML Script object based on a URL path.
+ *
+ * @param scriptUrlPath
+ * URL path to PYDML script
+ * @return PYDML Script object
+ */
+ public static Script pydmlFromUrl(String scriptUrlPath) {
+ return scriptFromUrl(scriptUrlPath, ScriptType.PYDML);
+ }
+
+ /**
+ * Create a PYDML Script object based on a URL.
+ *
+ * @param scriptUrl
+ * URL to PYDML script
+ * @return PYDML Script object
+ */
+ public static Script pydmlFromUrl(URL scriptUrl) {
+ return scriptFromUrl(scriptUrl, ScriptType.PYDML);
+ }
+
+ /**
+ * Create a DML or PYDML Script object based on a string path to a file.
+ *
+ * @param scriptFilePath
+ * path to DML or PYDML script file (local or HDFS)
+ * @param scriptType
+ * {@code ScriptType.DML} or {@code ScriptType.PYDML}
+ * @return DML or PYDML Script object
+ */
+ private static Script scriptFromFile(String scriptFilePath, ScriptType scriptType) {
+ String scriptString = getScriptStringFromFile(scriptFilePath);
+ return scriptFromString(scriptString, scriptType).setName(scriptFilePath);
+ }
+
+ /**
+ * Create a DML or PYDML Script object based on an input stream.
+ *
+ * @param inputStream
+ * input stream to DML or PYDML
+ * @param scriptType
+ * {@code ScriptType.DML} or {@code ScriptType.PYDML}
+ * @return DML or PYDML Script object
+ */
+ private static Script scriptFromInputStream(InputStream inputStream, ScriptType scriptType) {
+ String scriptString = getScriptStringFromInputStream(inputStream);
+ return scriptFromString(scriptString, scriptType);
+ }
+
+ /**
+ * Creates a DML or PYDML Script object based on a file in the local file
+ * system. To create a Script object from a local file or HDFS, please use
+ * {@link #scriptFromFile(String, ScriptType)}.
+ *
+ * @param localScriptFile
+ * The local DML or PYDML file
+ * @param scriptType
+ * {@code ScriptType.DML} or {@code ScriptType.PYDML}
+ * @return DML or PYDML Script object
+ */
+ private static Script scriptFromLocalFile(File localScriptFile, ScriptType scriptType) {
+ String scriptString = getScriptStringFromFile(localScriptFile);
+ return scriptFromString(scriptString, scriptType).setName(localScriptFile.getName());
+ }
+
+ /**
+ * Create a DML or PYDML Script object based on a string.
+ *
+ * @param scriptString
+ * string of DML or PYDML
+ * @param scriptType
+ * {@code ScriptType.DML} or {@code ScriptType.PYDML}
+ * @return DML or PYDML Script object
+ */
+ private static Script scriptFromString(String scriptString, ScriptType scriptType) {
+ Script script = new Script(scriptString, scriptType);
+ return script;
+ }
+
+ /**
+ * Create a DML or PYDML Script object based on a URL path.
+ *
+ * @param scriptUrlPath
+ * URL path to DML or PYDML script
+ * @param scriptType
+ * {@code ScriptType.DML} or {@code ScriptType.PYDML}
+ * @return DML or PYDML Script object
+ */
+ private static Script scriptFromUrl(String scriptUrlPath, ScriptType scriptType) {
+ String scriptString = getScriptStringFromUrl(scriptUrlPath);
+ return scriptFromString(scriptString, scriptType).setName(scriptUrlPath);
+ }
+
+ /**
+ * Create a DML or PYDML Script object based on a URL.
+ *
+ * @param scriptUrl
+ * URL to DML or PYDML script
+ * @param scriptType
+ * {@code ScriptType.DML} or {@code ScriptType.PYDML}
+ * @return DML or PYDML Script object
+ */
+ private static Script scriptFromUrl(URL scriptUrl, ScriptType scriptType) {
+ String scriptString = getScriptStringFromUrl(scriptUrl);
+ return scriptFromString(scriptString, scriptType).setName(scriptUrl.toString());
+ }
+
+ /**
+ * Create a DML Script object based on a string.
+ *
+ * @param scriptString
+ * string of DML
+ * @return DML Script object
+ */
+ public static Script dml(String scriptString) {
+ return dmlFromString(scriptString);
+ }
+
+ /**
+ * Obtain a script string from a file in the local file system. To obtain a
+ * script string from a file in HDFS, please use
+ * getScriptStringFromFile(String scriptFilePath).
+ *
+ * @param file
+ * The script file.
+ * @return The script string.
+ * @throws MLContextException
+ * If a problem occurs reading the script string from the file.
+ */
+ private static String getScriptStringFromFile(File file) {
+ if (file == null) {
+ throw new MLContextException("Script file is null");
+ }
+ String filePath = file.getPath();
+ try {
+ if (!LocalFileUtils.validateExternalFilename(filePath, false)) {
+ throw new MLContextException("Invalid (non-trustworthy) local filename: " + filePath);
+ }
+ String scriptString = FileUtils.readFileToString(file);
+ return scriptString;
+ } catch (IllegalArgumentException e) {
+ throw new MLContextException("Error trying to read script string from file: " + filePath, e);
+ } catch (IOException e) {
+ throw new MLContextException("Error trying to read script string from file: " + filePath, e);
+ }
+ }
+
+ /**
+ * Obtain a script string from a file.
+ *
+ * @param scriptFilePath
+ * The file path to the script file (either local file system or
+ * HDFS)
+ * @return The script string
+ * @throws MLContextException
+ * If a problem occurs reading the script string from the file
+ */
+ private static String getScriptStringFromFile(String scriptFilePath) {
+ if (scriptFilePath == null) {
+ throw new MLContextException("Script file path is null");
+ }
+ try {
+ if (scriptFilePath.startsWith("hdfs:") || scriptFilePath.startsWith("gpfs:")) {
+ if (!LocalFileUtils.validateExternalFilename(scriptFilePath, true)) {
+ throw new MLContextException("Invalid (non-trustworthy) hdfs/gpfs filename: " + scriptFilePath);
+ }
+ FileSystem fs = FileSystem.get(ConfigurationManager.getCachedJobConf());
+ Path path = new Path(scriptFilePath);
+ // close the HDFS stream after reading to avoid a resource leak
+ FSDataInputStream fsdis = fs.open(path);
+ try {
+ return IOUtils.toString(fsdis);
+ } finally {
+ IOUtils.closeQuietly(fsdis);
+ }
+ } else {// from local file system
+ if (!LocalFileUtils.validateExternalFilename(scriptFilePath, false)) {
+ throw new MLContextException("Invalid (non-trustworthy) local filename: " + scriptFilePath);
+ }
+ File scriptFile = new File(scriptFilePath);
+ String scriptString = FileUtils.readFileToString(scriptFile);
+ return scriptString;
+ }
+ } catch (IllegalArgumentException e) {
+ throw new MLContextException("Error trying to read script string from file: " + scriptFilePath, e);
+ } catch (IOException e) {
+ throw new MLContextException("Error trying to read script string from file: " + scriptFilePath, e);
+ }
+ }
+
+ /**
+ * Obtain a script string from an InputStream. Note that the input stream
+ * is not closed by this method; closing it is the caller's responsibility.
+ *
+ * @param inputStream
+ * The InputStream from which to read the script string
+ * @return The script string
+ * @throws MLContextException
+ * If a problem occurs reading the script string from the
+ * InputStream
+ */
+ private static String getScriptStringFromInputStream(InputStream inputStream) {
+ if (inputStream == null) {
+ throw new MLContextException("InputStream is null");
+ }
+ try {
+ String scriptString = IOUtils.toString(inputStream);
+ return scriptString;
+ } catch (IOException e) {
+ throw new MLContextException("Error trying to read script string from InputStream", e);
+ }
+ }
+
+ /**
+ * Obtain a script string from a URL.
+ *
+ * @param scriptUrlPath
+ * The URL path to the script file
+ * @return The script string
+ * @throws MLContextException
+ * If a problem occurs reading the script string from the URL
+ */
+ private static String getScriptStringFromUrl(String scriptUrlPath) {
+ if (scriptUrlPath == null) {
+ throw new MLContextException("Script URL path is null");
+ }
+ try {
+ URL url = new URL(scriptUrlPath);
+ return getScriptStringFromUrl(url);
+ } catch (MalformedURLException e) {
+ throw new MLContextException("Error trying to read script string from URL path: " + scriptUrlPath, e);
+ }
+ }
+
+ /**
+ * Obtain a script string from a URL. Only http and https URLs are
+ * currently supported.
+ *
+ * @param url
+ * The script URL
+ * @return The script string
+ * @throws MLContextException
+ * If a problem occurs reading the script string from the URL
+ */
+ private static String getScriptStringFromUrl(URL url) {
+ if (url == null) {
+ throw new MLContextException("URL is null");
+ }
+ String urlString = url.toString();
+ if ((!urlString.toLowerCase().startsWith("http:")) && (!urlString.toLowerCase().startsWith("https:"))) {
+ throw new MLContextException("Currently only reading from http and https URLs is supported");
+ }
+ try {
+ // close the URL stream after reading to avoid a resource leak
+ InputStream is = url.openStream();
+ try {
+ return IOUtils.toString(is);
+ } finally {
+ IOUtils.closeQuietly(is);
+ }
+ } catch (IOException e) {
+ throw new MLContextException("Error trying to read script string from URL: " + url, e);
+ }
+ }
+
+ /**
+ * Create a PYDML script object based on a string.
+ *
+ * @param scriptString
+ * string of PYDML
+ * @return PYDML Script object
+ */
+ public static Script pydml(String scriptString) {
+ return pydmlFromString(scriptString);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/mlcontext/ScriptType.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptType.java b/src/main/java/org/apache/sysml/api/mlcontext/ScriptType.java
new file mode 100644
index 0000000..94c9057
--- /dev/null
+++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptType.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.mlcontext;
+
+/**
+ * ScriptType represents the type of script: DML (R-like syntax) or PYDML
+ * (Python-like syntax).
+ *
+ */
+public enum ScriptType {
+
+ /** R-like syntax. */
+ DML,
+
+ /** Python-like syntax. */
+ PYDML;
+
+ /**
+ * Obtain the script type as a lowercase string ("dml" or "pydml").
+ *
+ * @return lowercase string representing the script type
+ */
+ public String lowerCase() {
+ return name().toLowerCase();
+ }
+
+ /**
+ * Report whether the script type is DML.
+ *
+ * @return {@code true} if the script type is DML, {@code false} otherwise
+ */
+ public boolean isDML() {
+ return this == DML;
+ }
+
+ /**
+ * Report whether the script type is PYDML.
+ *
+ * @return {@code true} if the script type is PYDML, {@code false} otherwise
+ */
+ public boolean isPYDML() {
+ return this == PYDML;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
index 0eea221..c715331 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
@@ -36,11 +36,7 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.storage.RDDInfo;
import org.apache.spark.storage.StorageLevel;
-
-import scala.Tuple2;
-
import org.apache.sysml.api.DMLScript;
-import org.apache.sysml.api.MLContext;
import org.apache.sysml.api.MLContextProxy;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.OptimizerUtils;
@@ -82,6 +78,8 @@ import org.apache.sysml.runtime.util.MapReduceTool;
import org.apache.sysml.runtime.util.UtilFunctions;
import org.apache.sysml.utils.Statistics;
+import scala.Tuple2;
+
public class SparkExecutionContext extends ExecutionContext
{
@@ -178,22 +176,28 @@ public class SparkExecutionContext extends ExecutionContext
*
*/
private synchronized static void initSparkContext()
- {
+ {
//check for redundant spark context init
if( _spctx != null )
return;
-
+
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
//create a default spark context (master, appname, etc refer to system properties
//as given in the spark configuration or during spark-submit)
- MLContext mlCtx = MLContextProxy.getActiveMLContext();
- if(mlCtx != null)
+ Object mlCtxObj = MLContextProxy.getActiveMLContext();
+ if(mlCtxObj != null)
{
// This is when DML is called through spark shell
// Will clean the passing of static variables later as this involves minimal change to DMLScript
- _spctx = new JavaSparkContext(mlCtx.getSparkContext());
+ if (mlCtxObj instanceof org.apache.sysml.api.MLContext) {
+ org.apache.sysml.api.MLContext mlCtx = (org.apache.sysml.api.MLContext) mlCtxObj;
+ _spctx = new JavaSparkContext(mlCtx.getSparkContext());
+ } else if (mlCtxObj instanceof org.apache.sysml.api.mlcontext.MLContext) {
+ org.apache.sysml.api.mlcontext.MLContext mlCtx = (org.apache.sysml.api.mlcontext.MLContext) mlCtxObj;
+ _spctx = new JavaSparkContext(mlCtx.getSparkContext());
+ }
}
else
{
@@ -1424,11 +1428,26 @@ public class SparkExecutionContext extends ExecutionContext
}
}
- MLContext mlContext = MLContextProxy.getActiveMLContext();
- if(mlContext != null && mlContext.getMonitoringUtil() != null) {
- mlContext.getMonitoringUtil().setLineageInfo(inst, outDebugString);
- }
- else {
+
+ Object mlContextObj = MLContextProxy.getActiveMLContext();
+ if (mlContextObj != null) {
+ if (mlContextObj instanceof org.apache.sysml.api.MLContext) {
+ org.apache.sysml.api.MLContext mlCtx = (org.apache.sysml.api.MLContext) mlContextObj;
+ if (mlCtx.getMonitoringUtil() != null) {
+ mlCtx.getMonitoringUtil().setLineageInfo(inst, outDebugString);
+ } else {
+ throw new DMLRuntimeException("The method setLineageInfoForExplain should be called only through MLContext");
+ }
+ } else if (mlContextObj instanceof org.apache.sysml.api.mlcontext.MLContext) {
+ org.apache.sysml.api.mlcontext.MLContext mlCtx = (org.apache.sysml.api.mlcontext.MLContext) mlContextObj;
+ if (mlCtx.getSparkMonitoringUtil() != null) {
+ mlCtx.getSparkMonitoringUtil().setLineageInfo(inst, outDebugString);
+ } else {
+ throw new DMLRuntimeException("The method setLineageInfoForExplain should be called only through MLContext");
+ }
+ }
+
+ } else {
throw new DMLRuntimeException("The method setLineageInfoForExplain should be called only through MLContext");
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java
index 0c0d3f0..d5301e7 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java
@@ -19,7 +19,6 @@
package org.apache.sysml.runtime.instructions.spark;
-import org.apache.sysml.api.MLContext;
import org.apache.sysml.api.MLContextProxy;
import org.apache.sysml.lops.runtime.RunMRJobs;
import org.apache.sysml.runtime.DMLRuntimeException;
@@ -99,13 +98,23 @@ public abstract class SPInstruction extends Instruction
//spark-explain-specific handling of current instructions
//This only relevant for ComputationSPInstruction as in postprocess we call setDebugString which is valid only for ComputationSPInstruction
- MLContext mlCtx = MLContextProxy.getActiveMLContext();
- if( tmp instanceof ComputationSPInstruction
- && mlCtx != null && mlCtx.getMonitoringUtil() != null
- && ec instanceof SparkExecutionContext )
- {
- mlCtx.getMonitoringUtil().addCurrentInstruction((SPInstruction)tmp);
- MLContextProxy.setInstructionForMonitoring(tmp);
+ Object mlCtxObj = MLContextProxy.getActiveMLContext();
+ if (mlCtxObj instanceof org.apache.sysml.api.MLContext) {
+ org.apache.sysml.api.MLContext mlCtx = (org.apache.sysml.api.MLContext) mlCtxObj;
+ if (tmp instanceof ComputationSPInstruction
+ && mlCtx != null && mlCtx.getMonitoringUtil() != null
+ && ec instanceof SparkExecutionContext ) {
+ mlCtx.getMonitoringUtil().addCurrentInstruction((SPInstruction)tmp);
+ MLContextProxy.setInstructionForMonitoring(tmp);
+ }
+ } else if (mlCtxObj instanceof org.apache.sysml.api.mlcontext.MLContext) {
+ org.apache.sysml.api.mlcontext.MLContext mlCtx = (org.apache.sysml.api.mlcontext.MLContext) mlCtxObj;
+ if (tmp instanceof ComputationSPInstruction
+ && mlCtx != null && mlCtx.getSparkMonitoringUtil() != null
+ && ec instanceof SparkExecutionContext ) {
+ mlCtx.getSparkMonitoringUtil().addCurrentInstruction((SPInstruction)tmp);
+ MLContextProxy.setInstructionForMonitoring(tmp);
+ }
}
return tmp;
@@ -120,14 +129,25 @@ public abstract class SPInstruction extends Instruction
throws DMLRuntimeException
{
//spark-explain-specific handling of current instructions
- MLContext mlCtx = MLContextProxy.getActiveMLContext();
- if( this instanceof ComputationSPInstruction
- && mlCtx != null && mlCtx.getMonitoringUtil() != null
- && ec instanceof SparkExecutionContext )
- {
- SparkExecutionContext sec = (SparkExecutionContext) ec;
- sec.setDebugString(this, ((ComputationSPInstruction) this).getOutputVariableName());
- mlCtx.getMonitoringUtil().removeCurrentInstruction(this);
+ Object mlCtxObj = MLContextProxy.getActiveMLContext();
+ if (mlCtxObj instanceof org.apache.sysml.api.MLContext) {
+ org.apache.sysml.api.MLContext mlCtx = (org.apache.sysml.api.MLContext) mlCtxObj;
+ if (this instanceof ComputationSPInstruction
+ && mlCtx != null && mlCtx.getMonitoringUtil() != null
+ && ec instanceof SparkExecutionContext ) {
+ SparkExecutionContext sec = (SparkExecutionContext) ec;
+ sec.setDebugString(this, ((ComputationSPInstruction) this).getOutputVariableName());
+ mlCtx.getMonitoringUtil().removeCurrentInstruction(this);
+ }
+ } else if (mlCtxObj instanceof org.apache.sysml.api.mlcontext.MLContext) {
+ org.apache.sysml.api.mlcontext.MLContext mlCtx = (org.apache.sysml.api.mlcontext.MLContext) mlCtxObj;
+ if (this instanceof ComputationSPInstruction
+ && mlCtx != null && mlCtx.getSparkMonitoringUtil() != null
+ && ec instanceof SparkExecutionContext ) {
+ SparkExecutionContext sec = (SparkExecutionContext) ec;
+ sec.setDebugString(this, ((ComputationSPInstruction) this).getOutputVariableName());
+ mlCtx.getSparkMonitoringUtil().removeCurrentInstruction(this);
+ }
}
//maintain statistics
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/SparkListener.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/SparkListener.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/SparkListener.java
index 956b841..3bf2f67 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/SparkListener.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/SparkListener.java
@@ -33,16 +33,14 @@ import org.apache.spark.storage.RDDInfo;
import org.apache.spark.ui.jobs.StagesTab;
import org.apache.spark.ui.jobs.UIData.TaskUIData;
import org.apache.spark.ui.scope.RDDOperationGraphListener;
+import org.apache.sysml.api.MLContextProxy;
+import org.apache.sysml.runtime.instructions.spark.SPInstruction;
import scala.Option;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.xml.Node;
-import org.apache.sysml.api.MLContext;
-import org.apache.sysml.api.MLContextProxy;
-import org.apache.sysml.runtime.instructions.spark.SPInstruction;
-
// Instead of extending org.apache.spark.JavaSparkListener
/**
* This class is only used by MLContext for now. It is used to provide UI data for Python notebook.
@@ -94,9 +92,19 @@ public class SparkListener extends RDDOperationGraphListener {
jobDAGs.put(jobID, jobNodes);
synchronized(currentInstructions) {
for(SPInstruction inst : currentInstructions) {
- MLContext mlContext = MLContextProxy.getActiveMLContext();
- if(mlContext != null && mlContext.getMonitoringUtil() != null) {
- mlContext.getMonitoringUtil().setJobId(inst, jobID);
+ Object mlContextObj = MLContextProxy.getActiveMLContext();
+ if (mlContextObj != null) {
+ if (mlContextObj instanceof org.apache.sysml.api.MLContext) {
+ org.apache.sysml.api.MLContext mlContext = (org.apache.sysml.api.MLContext) mlContextObj;
+ if (mlContext.getMonitoringUtil() != null) {
+ mlContext.getMonitoringUtil().setJobId(inst, jobID);
+ }
+ } else if (mlContextObj instanceof org.apache.sysml.api.mlcontext.MLContext) {
+ org.apache.sysml.api.mlcontext.MLContext mlContext = (org.apache.sysml.api.mlcontext.MLContext) mlContextObj;
+ if (mlContext.getSparkMonitoringUtil() != null) {
+ mlContext.getSparkMonitoringUtil().setJobId(inst, jobID);
+ }
+ }
}
}
}
@@ -140,9 +148,19 @@ public class SparkListener extends RDDOperationGraphListener {
synchronized(currentInstructions) {
for(SPInstruction inst : currentInstructions) {
- MLContext mlContext = MLContextProxy.getActiveMLContext();
- if(mlContext != null && mlContext.getMonitoringUtil() != null) {
- mlContext.getMonitoringUtil().setStageId(inst, stageSubmitted.stageInfo().stageId());
+ Object mlContextObj = MLContextProxy.getActiveMLContext();
+ if (mlContextObj != null) {
+ if (mlContextObj instanceof org.apache.sysml.api.MLContext) {
+ org.apache.sysml.api.MLContext mlContext = (org.apache.sysml.api.MLContext) mlContextObj;
+ if (mlContext.getMonitoringUtil() != null) {
+ mlContext.getMonitoringUtil().setStageId(inst, stageSubmitted.stageInfo().stageId());
+ }
+ } else if (mlContextObj instanceof org.apache.sysml.api.mlcontext.MLContext) {
+ org.apache.sysml.api.mlcontext.MLContext mlContext = (org.apache.sysml.api.mlcontext.MLContext) mlContextObj;
+ if (mlContext.getSparkMonitoringUtil() != null) {
+ mlContext.getSparkMonitoringUtil().setStageId(inst, stageSubmitted.stageInfo().stageId());
+ }
+ }
}
}
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
index ccdc927..f022e40 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
@@ -410,7 +410,7 @@ public class RDDConverterUtilsExt
}
- private static class DataFrameAnalysisFunction implements Function<Row,Row> {
+ public static class DataFrameAnalysisFunction implements Function<Row,Row> {
private static final long serialVersionUID = 5705371332119770215L;
private RowAnalysisFunctionHelper helper = null;
boolean isVectorBasedRDD;
@@ -445,7 +445,7 @@ public class RDDConverterUtilsExt
}
- private static class DataFrameToBinaryBlockFunction implements PairFlatMapFunction<Iterator<Tuple2<Row,Long>>,MatrixIndexes,MatrixBlock> {
+ public static class DataFrameToBinaryBlockFunction implements PairFlatMapFunction<Iterator<Tuple2<Row,Long>>,MatrixIndexes,MatrixBlock> {
private static final long serialVersionUID = 653447740362447236L;
private RowToBinaryBlockFunctionHelper helper = null;
boolean isVectorBasedDF;