You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by de...@apache.org on 2017/09/01 23:07:22 UTC
systemml git commit: [MINOR] Increase MLContext test coverage
Repository: systemml
Updated Branches:
refs/heads/master 8dbc93022 -> 912c65506
[MINOR] Increase MLContext test coverage
Create MLContext tests to test previously untested methods.
Update MLContext and MLContextConversionUtil to avoid NPEs.
Closes #649.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/912c6550
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/912c6550
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/912c6550
Branch: refs/heads/master
Commit: 912c65506d626c8b0128ceb80744fde49efd4a1a
Parents: 8dbc930
Author: Deron Eriksson <de...@apache.org>
Authored: Fri Sep 1 16:02:55 2017 -0700
Committer: Deron Eriksson <de...@apache.org>
Committed: Fri Sep 1 16:02:55 2017 -0700
----------------------------------------------------------------------
.../apache/sysml/api/mlcontext/MLContext.java | 10 +-
.../api/mlcontext/MLContextConversionUtil.java | 3 +
.../sysml/api/mlcontext/MLContextUtil.java | 206 ++++++------
.../integration/mlcontext/MLContextTest.java | 314 +++++++++++++++++++
4 files changed, 431 insertions(+), 102 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/912c6550/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
index 83eedb3..35720a5 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
@@ -55,7 +55,7 @@ public class MLContext {
/**
* Logger for MLContext
*/
- public static Logger log = Logger.getLogger(MLContext.class);
+ protected static Logger log = Logger.getLogger(MLContext.class);
/**
* SparkSession object.
@@ -665,7 +665,9 @@ public class MLContext {
// clear local status, but do not stop sc as it
// may be used or stopped externally
- executionScript.clearAll();
+ if (executionScript != null) {
+ executionScript.clearAll();
+ }
resetConfig();
spark = null;
}
@@ -693,7 +695,7 @@ public class MLContext {
*/
public String version() {
if (info() == null) {
- return "Version not available";
+ return MLContextUtil.VERSION_NOT_AVAILABLE;
}
return info().version();
}
@@ -705,7 +707,7 @@ public class MLContext {
*/
public String buildTime() {
if (info() == null) {
- return "Build time not available";
+ return MLContextUtil.BUILD_TIME_NOT_AVAILABLE;
}
return info().buildTime();
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/912c6550/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
index 5883127..3f12ace 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
@@ -482,6 +482,9 @@ public class MLContextConversionUtil {
* the matrix metadata, if available
*/
public static void determineMatrixFormatIfNeeded(Dataset<Row> dataFrame, MatrixMetadata matrixMetadata) {
+ if (matrixMetadata == null) {
+ return;
+ }
MatrixFormat matrixFormat = matrixMetadata.getMatrixFormat();
if (matrixFormat != null) {
return;
http://git-wip-us.apache.org/repos/asf/systemml/blob/912c6550/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
index 51d38a5..03184e3 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
@@ -91,122 +91,32 @@ import org.w3c.dom.NodeList;
*
*/
public final class MLContextUtil {
-
+
/**
- * Get HOP DAG in dot format for a DML or PYDML Script.
- *
- * @param mlCtx
- * MLContext object.
- * @param script
- * The DML or PYDML Script object to execute.
- * @param lines
- * Only display the hops that have begin and end line number
- * equals to the given integers.
- * @param performHOPRewrites
- * should perform static rewrites, perform
- * intra-/inter-procedural analysis to propagate size information
- * into functions and apply dynamic rewrites
- * @param withSubgraph
- * If false, the dot graph will be created without subgraphs for
- * statement blocks.
- * @return hop DAG in dot format
- * @throws LanguageException
- * if error occurs
- * @throws DMLRuntimeException
- * if error occurs
- * @throws HopsException
- * if error occurs
+ * Version not available message.
*/
- public static String getHopDAG(MLContext mlCtx, Script script, ArrayList<Integer> lines,
- boolean performHOPRewrites, boolean withSubgraph) throws HopsException, DMLRuntimeException,
- LanguageException {
- return getHopDAG(mlCtx, script, lines, null, performHOPRewrites, withSubgraph);
- }
+ public static final String VERSION_NOT_AVAILABLE = "Version not available";
/**
- * Get HOP DAG in dot format for a DML or PYDML Script.
- *
- * @param mlCtx
- * MLContext object.
- * @param script
- * The DML or PYDML Script object to execute.
- * @param lines
- * Only display the hops that have begin and end line number
- * equals to the given integers.
- * @param newConf
- * Spark Configuration.
- * @param performHOPRewrites
- * should perform static rewrites, perform
- * intra-/inter-procedural analysis to propagate size information
- * into functions and apply dynamic rewrites
- * @param withSubgraph
- * If false, the dot graph will be created without subgraphs for
- * statement blocks.
- * @return hop DAG in dot format
- * @throws LanguageException
- * if error occurs
- * @throws DMLRuntimeException
- * if error occurs
- * @throws HopsException
- * if error occurs
+ * Build time not available message.
*/
- public static String getHopDAG(MLContext mlCtx, Script script, ArrayList<Integer> lines, SparkConf newConf,
- boolean performHOPRewrites, boolean withSubgraph) throws HopsException, DMLRuntimeException,
- LanguageException {
- SparkConf oldConf = mlCtx.getSparkSession().sparkContext().getConf();
- SparkExecutionContext.SparkClusterConfig systemmlConf = SparkExecutionContext.getSparkClusterConfig();
- long oldMaxMemory = InfrastructureAnalyzer.getLocalMaxMemory();
- try {
- if (newConf != null) {
- systemmlConf.analyzeSparkConfiguation(newConf);
- InfrastructureAnalyzer.setLocalMaxMemory(newConf.getSizeAsBytes("spark.driver.memory"));
- }
- ScriptExecutor scriptExecutor = new ScriptExecutor();
- scriptExecutor.setExecutionType(mlCtx.getExecutionType());
- scriptExecutor.setGPU(mlCtx.isGPU());
- scriptExecutor.setForceGPU(mlCtx.isForceGPU());
- scriptExecutor.setInit(mlCtx.isInitBeforeExecution());
- if (mlCtx.isInitBeforeExecution()) {
- mlCtx.setInitBeforeExecution(false);
- }
- scriptExecutor.setMaintainSymbolTable(mlCtx.isMaintainSymbolTable());
-
- Long time = new Long((new Date()).getTime());
- if ((script.getName() == null) || (script.getName().equals(""))) {
- script.setName(time.toString());
- }
-
- mlCtx.setExecutionScript(script);
- scriptExecutor.compile(script, performHOPRewrites);
- Explain.reset();
- // To deal with potential Py4J issues
- lines = lines.size() == 1 && lines.get(0) == -1 ? new ArrayList<Integer>() : lines;
- return Explain.getHopDAG(scriptExecutor.dmlProgram, lines, withSubgraph);
- } catch (RuntimeException e) {
- throw new MLContextException("Exception when compiling script", e);
- } finally {
- if (newConf != null) {
- systemmlConf.analyzeSparkConfiguation(oldConf);
- InfrastructureAnalyzer.setLocalMaxMemory(oldMaxMemory);
- }
- }
- }
+ public static final String BUILD_TIME_NOT_AVAILABLE = "Build time not available";
/**
- * Basic data types supported by the MLContext API
+ * Basic data types supported by the MLContext API.
*/
@SuppressWarnings("rawtypes")
public static final Class[] BASIC_DATA_TYPES = { Integer.class, Boolean.class, Double.class, String.class };
/**
- * Complex data types supported by the MLContext API
+ * Complex data types supported by the MLContext API.
*/
@SuppressWarnings("rawtypes")
public static final Class[] COMPLEX_DATA_TYPES = { JavaRDD.class, RDD.class, Dataset.class, Matrix.class,
Frame.class, (new double[][] {}).getClass(), MatrixBlock.class, URL.class };
/**
- * All data types supported by the MLContext API
+ * All data types supported by the MLContext API.
*/
@SuppressWarnings("rawtypes")
public static final Class[] ALL_SUPPORTED_DATA_TYPES = (Class[]) ArrayUtils.addAll(BASIC_DATA_TYPES,
@@ -1252,4 +1162,104 @@ public final class MLContextUtil {
}
}
}
+
+ /**
+ * Get HOP DAG in dot format for a DML or PYDML Script.
+ *
+ * @param mlCtx
+ * MLContext object.
+ * @param script
+ * The DML or PYDML Script object to execute.
+ * @param lines
+ * Only display the hops whose begin and end line numbers
+ * are equal to the given integers.
+ * @param performHOPRewrites
+ * should perform static rewrites, perform
+ * intra-/inter-procedural analysis to propagate size information
+ * into functions and apply dynamic rewrites
+ * @param withSubgraph
+ * If false, the dot graph will be created without subgraphs for
+ * statement blocks.
+ * @return hop DAG in dot format
+ * @throws LanguageException
+ * if error occurs
+ * @throws DMLRuntimeException
+ * if error occurs
+ * @throws HopsException
+ * if error occurs
+ */
+ public static String getHopDAG(MLContext mlCtx, Script script, ArrayList<Integer> lines, boolean performHOPRewrites,
+ boolean withSubgraph) throws HopsException, DMLRuntimeException, LanguageException {
+ return getHopDAG(mlCtx, script, lines, null, performHOPRewrites, withSubgraph);
+ }
+
+ /**
+ * Get HOP DAG in dot format for a DML or PYDML Script.
+ *
+ * @param mlCtx
+ * MLContext object.
+ * @param script
+ * The DML or PYDML Script object to execute.
+ * @param lines
+ * Only display the hops whose begin and end line numbers
+ * are equal to the given integers.
+ * @param newConf
+ * Spark Configuration.
+ * @param performHOPRewrites
+ * should perform static rewrites, perform
+ * intra-/inter-procedural analysis to propagate size information
+ * into functions and apply dynamic rewrites
+ * @param withSubgraph
+ * If false, the dot graph will be created without subgraphs for
+ * statement blocks.
+ * @return hop DAG in dot format
+ * @throws LanguageException
+ * if error occurs
+ * @throws DMLRuntimeException
+ * if error occurs
+ * @throws HopsException
+ * if error occurs
+ */
+ public static String getHopDAG(MLContext mlCtx, Script script, ArrayList<Integer> lines, SparkConf newConf,
+ boolean performHOPRewrites, boolean withSubgraph)
+ throws HopsException, DMLRuntimeException, LanguageException {
+ SparkConf oldConf = mlCtx.getSparkSession().sparkContext().getConf();
+ SparkExecutionContext.SparkClusterConfig systemmlConf = SparkExecutionContext.getSparkClusterConfig();
+ long oldMaxMemory = InfrastructureAnalyzer.getLocalMaxMemory();
+ try {
+ if (newConf != null) {
+ systemmlConf.analyzeSparkConfiguation(newConf);
+ InfrastructureAnalyzer.setLocalMaxMemory(newConf.getSizeAsBytes("spark.driver.memory"));
+ }
+ ScriptExecutor scriptExecutor = new ScriptExecutor();
+ scriptExecutor.setExecutionType(mlCtx.getExecutionType());
+ scriptExecutor.setGPU(mlCtx.isGPU());
+ scriptExecutor.setForceGPU(mlCtx.isForceGPU());
+ scriptExecutor.setInit(mlCtx.isInitBeforeExecution());
+ if (mlCtx.isInitBeforeExecution()) {
+ mlCtx.setInitBeforeExecution(false);
+ }
+ scriptExecutor.setMaintainSymbolTable(mlCtx.isMaintainSymbolTable());
+
+ Long time = new Long((new Date()).getTime());
+ if ((script.getName() == null) || (script.getName().equals(""))) {
+ script.setName(time.toString());
+ }
+
+ mlCtx.setExecutionScript(script);
+ scriptExecutor.compile(script, performHOPRewrites);
+ Explain.reset();
+ // To deal with potential Py4J issues
+ lines = lines.size() == 1 && lines.get(0) == -1 ? new ArrayList<Integer>() : lines;
+ return Explain.getHopDAG(scriptExecutor.dmlProgram, lines, withSubgraph);
+ } catch (RuntimeException e) {
+ throw new MLContextException("Exception when compiling script", e);
+ } finally {
+ if (newConf != null) {
+ systemmlConf.analyzeSparkConfiguation(oldConf);
+ InfrastructureAnalyzer.setLocalMaxMemory(oldMaxMemory);
+ }
+ }
+ }
+
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/912c6550/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
index b08f5b9..9e4cfac 100644
--- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
@@ -42,10 +42,13 @@ import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
+import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
@@ -54,10 +57,12 @@ import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.DoubleType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.sysml.api.mlcontext.MLContextConversionUtil;
import org.apache.sysml.api.mlcontext.MLContextException;
+import org.apache.sysml.api.mlcontext.MLContextUtil;
import org.apache.sysml.api.mlcontext.MLResults;
import org.apache.sysml.api.mlcontext.Matrix;
import org.apache.sysml.api.mlcontext.MatrixFormat;
@@ -69,11 +74,14 @@ import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysml.runtime.util.DataConverter;
import org.junit.Assert;
import org.junit.Test;
+import scala.Tuple1;
import scala.Tuple2;
import scala.Tuple3;
+import scala.Tuple4;
import scala.collection.Iterator;
import scala.collection.JavaConversions;
import scala.collection.Seq;
@@ -2756,4 +2764,310 @@ public class MLContextTest extends MLContextTestBase {
Assert.assertEquals(3, results.getLong("y"));
}
+ @Test
+ public void testOutputDataFrameOfVectorsDML() {
+ System.out.println("MLContextTest - output DataFrame of vectors DML");
+
+ String s = "m=matrix('1 2 3 4',rows=2,cols=2);";
+ Script script = dml(s).out("m");
+ MLResults results = ml.execute(script);
+ Dataset<Row> df = results.getDataFrame("m", true);
+ Dataset<Row> sortedDF = df.sort(RDDConverterUtils.DF_ID_COLUMN);
+
+ // verify column types
+ StructType schema = sortedDF.schema();
+ StructField[] fields = schema.fields();
+ StructField idColumn = fields[0];
+ StructField vectorColumn = fields[1];
+ Assert.assertTrue(idColumn.dataType() instanceof DoubleType);
+ Assert.assertTrue(vectorColumn.dataType() instanceof VectorUDT);
+
+ List<Row> list = sortedDF.collectAsList();
+
+ Row row1 = list.get(0);
+ Assert.assertEquals(1.0, row1.getDouble(0), 0.0);
+ Vector v1 = (DenseVector) row1.get(1);
+ double[] arr1 = v1.toArray();
+ Assert.assertArrayEquals(new double[] { 1.0, 2.0 }, arr1, 0.0);
+
+ Row row2 = list.get(1);
+ Assert.assertEquals(2.0, row2.getDouble(0), 0.0);
+ Vector v2 = (DenseVector) row2.get(1);
+ double[] arr2 = v2.toArray();
+ Assert.assertArrayEquals(new double[] { 3.0, 4.0 }, arr2, 0.0);
+ }
+
+ @Test
+ public void testOutputDoubleArrayFromMatrixDML() {
+ System.out.println("MLContextTest - output double array from matrix DML");
+
+ String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
+ double[][] matrix = ml.execute(dml(s).out("M")).getMatrix("M").to2DDoubleArray();
+ Assert.assertEquals(1.0, matrix[0][0], 0);
+ Assert.assertEquals(2.0, matrix[0][1], 0);
+ Assert.assertEquals(3.0, matrix[1][0], 0);
+ Assert.assertEquals(4.0, matrix[1][1], 0);
+ }
+
+ @Test
+ public void testOutputDataFrameFromMatrixDML() {
+ System.out.println("MLContextTest - output DataFrame from matrix DML");
+
+ String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
+ Script script = dml(s).out("M");
+ Dataset<Row> df = ml.execute(script).getMatrix("M").toDF();
+ Dataset<Row> sortedDF = df.sort(RDDConverterUtils.DF_ID_COLUMN);
+ List<Row> list = sortedDF.collectAsList();
+ Row row1 = list.get(0);
+ Assert.assertEquals(1.0, row1.getDouble(0), 0.0);
+ Assert.assertEquals(1.0, row1.getDouble(1), 0.0);
+ Assert.assertEquals(2.0, row1.getDouble(2), 0.0);
+
+ Row row2 = list.get(1);
+ Assert.assertEquals(2.0, row2.getDouble(0), 0.0);
+ Assert.assertEquals(3.0, row2.getDouble(1), 0.0);
+ Assert.assertEquals(4.0, row2.getDouble(2), 0.0);
+ }
+
+ @Test
+ public void testOutputDataFrameDoublesNoIDColumnFromMatrixDML() {
+ System.out.println("MLContextTest - output DataFrame of doubles with no ID column from matrix DML");
+
+ String s = "M = matrix('1 2 3 4', rows=1, cols=4);";
+ Script script = dml(s).out("M");
+ Dataset<Row> df = ml.execute(script).getMatrix("M").toDFDoubleNoIDColumn();
+ List<Row> list = df.collectAsList();
+
+ Row row = list.get(0);
+ Assert.assertEquals(1.0, row.getDouble(0), 0.0);
+ Assert.assertEquals(2.0, row.getDouble(1), 0.0);
+ Assert.assertEquals(3.0, row.getDouble(2), 0.0);
+ Assert.assertEquals(4.0, row.getDouble(3), 0.0);
+ }
+
+ @Test
+ public void testOutputDataFrameDoublesWithIDColumnFromMatrixDML() {
+ System.out.println("MLContextTest - output DataFrame of doubles with ID column from matrix DML");
+
+ String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
+ Script script = dml(s).out("M");
+ Dataset<Row> df = ml.execute(script).getMatrix("M").toDFDoubleWithIDColumn();
+ Dataset<Row> sortedDF = df.sort(RDDConverterUtils.DF_ID_COLUMN);
+ List<Row> list = sortedDF.collectAsList();
+
+ Row row1 = list.get(0);
+ Assert.assertEquals(1.0, row1.getDouble(0), 0.0);
+ Assert.assertEquals(1.0, row1.getDouble(1), 0.0);
+ Assert.assertEquals(2.0, row1.getDouble(2), 0.0);
+
+ Row row2 = list.get(1);
+ Assert.assertEquals(2.0, row2.getDouble(0), 0.0);
+ Assert.assertEquals(3.0, row2.getDouble(1), 0.0);
+ Assert.assertEquals(4.0, row2.getDouble(2), 0.0);
+ }
+
+ @Test
+ public void testOutputDataFrameVectorsNoIDColumnFromMatrixDML() {
+ System.out.println("MLContextTest - output DataFrame of vectors with no ID column from matrix DML");
+
+ String s = "M = matrix('1 2 3 4', rows=1, cols=4);";
+ Script script = dml(s).out("M");
+ Dataset<Row> df = ml.execute(script).getMatrix("M").toDFVectorNoIDColumn();
+ List<Row> list = df.collectAsList();
+
+ Row row = list.get(0);
+ Assert.assertArrayEquals(new double[] { 1.0, 2.0, 3.0, 4.0 }, ((Vector) row.get(0)).toArray(), 0.0);
+ }
+
+ @Test
+ public void testOutputDataFrameVectorsWithIDColumnFromMatrixDML() {
+ System.out.println("MLContextTest - output DataFrame of vectors with ID column from matrix DML");
+
+ String s = "M = matrix('1 2 3 4', rows=1, cols=4);";
+ Script script = dml(s).out("M");
+ Dataset<Row> df = ml.execute(script).getMatrix("M").toDFVectorWithIDColumn();
+ List<Row> list = df.collectAsList();
+
+ Row row = list.get(0);
+ Assert.assertEquals(1.0, row.getDouble(0), 0.0);
+ Assert.assertArrayEquals(new double[] { 1.0, 2.0, 3.0, 4.0 }, ((Vector) row.get(1)).toArray(), 0.0);
+ }
+
+ @Test
+ public void testOutputJavaRDDStringCSVFromMatrixDML() {
+ System.out.println("MLContextTest - output Java RDD String CSV from matrix DML");
+
+ String s = "M = matrix('1 2 3 4', rows=1, cols=4);";
+ Script script = dml(s).out("M");
+ JavaRDD<String> javaRDDStringCSV = ml.execute(script).getMatrix("M").toJavaRDDStringCSV();
+ List<String> lines = javaRDDStringCSV.collect();
+ Assert.assertEquals("1.0,2.0,3.0,4.0", lines.get(0));
+ }
+
+ @Test
+ public void testOutputJavaRDDStringIJVFromMatrixDML() {
+ System.out.println("MLContextTest - output Java RDD String IJV from matrix DML");
+
+ String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
+ Script script = dml(s).out("M");
+ MLResults results = ml.execute(script);
+ JavaRDD<String> javaRDDStringIJV = results.getJavaRDDStringIJV("M");
+ List<String> lines = javaRDDStringIJV.sortBy(row -> row, true, 1).collect();
+ Assert.assertEquals("1 1 1.0", lines.get(0));
+ Assert.assertEquals("1 2 2.0", lines.get(1));
+ Assert.assertEquals("2 1 3.0", lines.get(2));
+ Assert.assertEquals("2 2 4.0", lines.get(3));
+ }
+
+ @Test
+ public void testOutputRDDStringCSVFromMatrixDML() {
+ System.out.println("MLContextTest - output RDD String CSV from matrix DML");
+
+ String s = "M = matrix('1 2 3 4', rows=1, cols=4);";
+ Script script = dml(s).out("M");
+ RDD<String> rddStringCSV = ml.execute(script).getMatrix("M").toRDDStringCSV();
+ Iterator<String> iterator = rddStringCSV.toLocalIterator();
+ Assert.assertEquals("1.0,2.0,3.0,4.0", iterator.next());
+ }
+
+ @Test
+ public void testOutputRDDStringIJVFromMatrixDML() {
+ System.out.println("MLContextTest - output RDD String IJV from matrix DML");
+
+ String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
+ Script script = dml(s).out("M");
+ RDD<String> rddStringIJV = ml.execute(script).getMatrix("M").toRDDStringIJV();
+ String[] rows = (String[]) rddStringIJV.collect();
+ Arrays.sort(rows);
+ Assert.assertEquals("1 1 1.0", rows[0]);
+ Assert.assertEquals("1 2 2.0", rows[1]);
+ Assert.assertEquals("2 1 3.0", rows[2]);
+ Assert.assertEquals("2 2 4.0", rows[3]);
+ }
+
+ @Test
+ public void testMLContextVersionMessage() {
+ System.out.println("MLContextTest - version message");
+
+ String version = ml.version();
+ // not available until jar built
+ Assert.assertEquals(MLContextUtil.VERSION_NOT_AVAILABLE, version);
+ }
+
+ @Test
+ public void testMLContextBuildTimeMessage() {
+ System.out.println("MLContextTest - build time message");
+
+ String buildTime = ml.buildTime();
+ // not available until jar built
+ Assert.assertEquals(MLContextUtil.BUILD_TIME_NOT_AVAILABLE, buildTime);
+ }
+
+ @Test
+ public void testMLContextCreateAndClose() {
+ // MLContext created by the @BeforeClass method in MLContextTestBase
+ // MLContext closed by the @AfterClass method in MLContextTestBase
+ System.out.println("MLContextTest - create MLContext and close (without script execution)");
+ }
+
+ @Test
+ public void testDataFrameToBinaryBlocks() {
+ System.out.println("MLContextTest - DataFrame to binary blocks");
+
+ List<String> list = new ArrayList<String>();
+ list.add("1,2,3");
+ list.add("4,5,6");
+ list.add("7,8,9");
+ JavaRDD<String> javaRddString = sc.parallelize(list);
+
+ JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow());
+ List<StructField> fields = new ArrayList<StructField>();
+ fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true));
+ StructType schema = DataTypes.createStructType(fields);
+ Dataset<Row> dataFrame = spark.createDataFrame(javaRddRow, schema);
+
+ JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlocks = MLContextConversionUtil
+ .dataFrameToMatrixBinaryBlocks(dataFrame);
+ Tuple2<MatrixIndexes, MatrixBlock> first = binaryBlocks.first();
+ MatrixBlock mb = first._2();
+ double[][] matrix = DataConverter.convertToDoubleMatrix(mb);
+ Assert.assertArrayEquals(new double[] { 1.0, 2.0, 3.0 }, matrix[0], 0.0);
+ Assert.assertArrayEquals(new double[] { 4.0, 5.0, 6.0 }, matrix[1], 0.0);
+ Assert.assertArrayEquals(new double[] { 7.0, 8.0, 9.0 }, matrix[2], 0.0);
+ }
+
+ @Test
+ public void testGetTuple1DML() {
+ System.out.println("MLContextTest - Get Tuple1<Matrix> DML");
+ JavaRDD<String> javaRddString = sc
+ .parallelize(Stream.of("1,2,3", "4,5,6", "7,8,9").collect(Collectors.toList()));
+ JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow());
+ List<StructField> fields = new ArrayList<StructField>();
+ fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true));
+ StructType schema = DataTypes.createStructType(fields);
+ Dataset<Row> df = spark.createDataFrame(javaRddRow, schema);
+
+ Script script = dml("N=M*2").in("M", df).out("N");
+ Tuple1<Matrix> tuple = ml.execute(script).getTuple("N");
+ double[][] n = tuple._1().to2DDoubleArray();
+ Assert.assertEquals(2.0, n[0][0], 0);
+ Assert.assertEquals(4.0, n[0][1], 0);
+ Assert.assertEquals(6.0, n[0][2], 0);
+ Assert.assertEquals(8.0, n[1][0], 0);
+ Assert.assertEquals(10.0, n[1][1], 0);
+ Assert.assertEquals(12.0, n[1][2], 0);
+ Assert.assertEquals(14.0, n[2][0], 0);
+ Assert.assertEquals(16.0, n[2][1], 0);
+ Assert.assertEquals(18.0, n[2][2], 0);
+ }
+
+ @Test
+ public void testGetTuple2DML() {
+ System.out.println("MLContextTest - Get Tuple2<Matrix,Double> DML");
+
+ double[][] m = new double[][] { { 1, 2 }, { 3, 4 } };
+
+ Script script = dml("N=M*2;s=sum(N)").in("M", m).out("N", "s");
+ Tuple2<Matrix, Double> tuple = ml.execute(script).getTuple("N", "s");
+ double[][] n = tuple._1().to2DDoubleArray();
+ double s = tuple._2();
+ Assert.assertArrayEquals(new double[] { 2, 4 }, n[0], 0.0);
+ Assert.assertArrayEquals(new double[] { 6, 8 }, n[1], 0.0);
+ Assert.assertEquals(20.0, s, 0.0);
+ }
+
+ @Test
+ public void testGetTuple3DML() {
+ System.out.println("MLContextTest - Get Tuple3<Long,Double,Boolean> DML");
+
+ Script script = dml("a=1+2;b=a+0.5;c=TRUE;").out("a", "b", "c");
+ Tuple3<Long, Double, Boolean> tuple = ml.execute(script).getTuple("a", "b", "c");
+ long a = tuple._1();
+ double b = tuple._2();
+ boolean c = tuple._3();
+ Assert.assertEquals(3, a);
+ Assert.assertEquals(3.5, b, 0.0);
+ Assert.assertEquals(true, c);
+ }
+
+ @Test
+ public void testGetTuple4DML() {
+ System.out.println("MLContextTest - Get Tuple4<Long,Double,Boolean,String> DML");
+
+ Script script = dml("a=1+2;b=a+0.5;c=TRUE;d=\"yes it's \"+c").out("a", "b", "c", "d");
+ Tuple4<Long, Double, Boolean, String> tuple = ml.execute(script).getTuple("a", "b", "c", "d");
+ long a = tuple._1();
+ double b = tuple._2();
+ boolean c = tuple._3();
+ String d = tuple._4();
+ Assert.assertEquals(3, a);
+ Assert.assertEquals(3.5, b, 0.0);
+ Assert.assertEquals(true, c);
+ Assert.assertEquals("yes it's TRUE", d);
+ }
+
}