You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by de...@apache.org on 2016/08/25 21:36:05 UTC
incubator-systemml git commit: [SYSTEMML-834] Improve MLContext DataFrame support
Repository: incubator-systemml
Updated Branches:
refs/heads/master 9f12b5c66 -> 97dee8fba
[SYSTEMML-834] Improve MLContext DataFrame support
Closes #218.
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/97dee8fb
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/97dee8fb
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/97dee8fb
Branch: refs/heads/master
Commit: 97dee8fba7252f9b868e23d69f23c36053f48445
Parents: 9f12b5c
Author: Deron Eriksson <de...@us.ibm.com>
Authored: Thu Aug 25 14:33:00 2016 -0700
Committer: Deron Eriksson <de...@us.ibm.com>
Committed: Thu Aug 25 14:33:00 2016 -0700
----------------------------------------------------------------------
.../sysml/api/mlcontext/BinaryBlockMatrix.java | 21 +-
.../api/mlcontext/MLContextConversionUtil.java | 66 +++-
.../apache/sysml/api/mlcontext/MLResults.java | 136 ++++++-
.../org/apache/sysml/api/mlcontext/Matrix.java | 46 ++-
.../sysml/api/mlcontext/MatrixFormat.java | 54 ++-
.../integration/mlcontext/MLContextTest.java | 390 ++++++++++++++++++-
6 files changed, 689 insertions(+), 24 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/97dee8fb/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java b/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java
index ea6fcf0..b13669d 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java
@@ -99,12 +99,21 @@ public class BinaryBlockMatrix {
public JavaPairRDD<MatrixIndexes, MatrixBlock> getBinaryBlocks() {
return binaryBlocks;
}
-
- public MatrixBlock getMatrixBlock() throws DMLRuntimeException {
- MatrixCharacteristics mc = getMatrixCharacteristics();
- MatrixBlock mb = SparkExecutionContext.toMatrixBlock(binaryBlocks, (int) mc.getRows(), (int) mc.getCols(),
- mc.getRowsPerBlock(), mc.getColsPerBlock(), mc.getNonZeros());
- return mb;
+
+ /**
+ * Obtain a SystemML binary-block matrix as a {@code MatrixBlock}
+ *
+ * @return the SystemML binary-block matrix as a {@code MatrixBlock}
+ */
+ public MatrixBlock getMatrixBlock() {
+ try {
+ MatrixCharacteristics mc = getMatrixCharacteristics();
+ MatrixBlock mb = SparkExecutionContext.toMatrixBlock(binaryBlocks, (int) mc.getRows(), (int) mc.getCols(),
+ mc.getRowsPerBlock(), mc.getColsPerBlock(), mc.getNonZeros());
+ return mb;
+ } catch (DMLRuntimeException e) {
+ throw new MLContextException("Exception while getting MatrixBlock from binary-block matrix", e);
+ }
}
/**
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/97dee8fb/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
index 33a5a3c..3a482ef 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
@@ -33,6 +33,7 @@ import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
@@ -311,7 +312,14 @@ public class MLContextConversionUtil {
} else {
matrixCharacteristics = new MatrixCharacteristics();
}
- determineDataFrameDimensionsIfNeeded(dataFrame, matrixCharacteristics);
+
+ if (isDataFrameWithIDColumn(matrixMetadata)) {
+ dataFrame = dataFrame.sort("ID").drop("ID");
+ }
+
+ boolean isVectorBasedDataFrame = isVectorBasedDataFrame(matrixMetadata);
+
+ determineDataFrameDimensionsIfNeeded(dataFrame, matrixCharacteristics, isVectorBasedDataFrame);
if (matrixMetadata != null) {
// so external reference can be updated with the metadata
matrixMetadata.setMatrixCharacteristics(matrixCharacteristics);
@@ -320,12 +328,50 @@ public class MLContextConversionUtil {
JavaRDD<Row> javaRDD = dataFrame.javaRDD();
JavaPairRDD<Row, Long> prepinput = javaRDD.zipWithIndex();
JavaPairRDD<MatrixIndexes, MatrixBlock> out = prepinput.mapPartitionsToPair(new DataFrameToBinaryBlockFunction(
- matrixCharacteristics, false));
+ matrixCharacteristics, isVectorBasedDataFrame));
out = RDDAggregateUtils.mergeByKey(out);
return out;
}
/**
+ * Return whether or not the DataFrame has an ID column.
+ *
+ * @param matrixMetadata
+ * the matrix metadata
+ * @return {@code true} if the DataFrame has an ID column, {@code false}
+ * otherwise.
+ */
+ public static boolean isDataFrameWithIDColumn(MatrixMetadata matrixMetadata) {
+ if (matrixMetadata == null) {
+ return false;
+ }
+ MatrixFormat matrixFormat = matrixMetadata.getMatrixFormat();
+ if (matrixFormat == null) {
+ return false;
+ }
+ return matrixFormat.hasIDColumn();
+ }
+
+ /**
+ * Return whether or not the DataFrame is vector-based.
+ *
+ * @param matrixMetadata
+ * the matrix metadata
+ * @return {@code true} if the DataFrame is vector-based, {@code false}
+ * otherwise.
+ */
+ public static boolean isVectorBasedDataFrame(MatrixMetadata matrixMetadata) {
+ if (matrixMetadata == null) {
+ return false;
+ }
+ MatrixFormat matrixFormat = matrixMetadata.getMatrixFormat();
+ if (matrixFormat == null) {
+ return false;
+ }
+ return matrixFormat.isVectorBased();
+ }
+
+ /**
* If the {@code DataFrame} dimensions aren't present in the
* {@code MatrixCharacteristics} metadata, determine the dimensions and
* place them in the {@code MatrixCharacteristics} metadata.
@@ -334,20 +380,28 @@ public class MLContextConversionUtil {
* the Spark {@code DataFrame}
* @param matrixCharacteristics
* the matrix metadata
+ * @param vectorBased
+ * is the DataFrame vector-based
*/
public static void determineDataFrameDimensionsIfNeeded(DataFrame dataFrame,
- MatrixCharacteristics matrixCharacteristics) {
+ MatrixCharacteristics matrixCharacteristics, boolean vectorBased) {
if (!matrixCharacteristics.dimsKnown(true)) {
- // only available to the new MLContext API, not the old API
MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContext();
SparkContext sparkContext = activeMLContext.getSparkContext();
@SuppressWarnings("resource")
JavaSparkContext javaSparkContext = new JavaSparkContext(sparkContext);
Accumulator<Double> aNnz = javaSparkContext.accumulator(0L);
- JavaRDD<Row> javaRDD = dataFrame.javaRDD().map(new DataFrameAnalysisFunction(aNnz, false));
+ JavaRDD<Row> javaRDD = dataFrame.javaRDD().map(new DataFrameAnalysisFunction(aNnz, vectorBased));
long numRows = javaRDD.count();
- long numColumns = dataFrame.columns().length;
+ long numColumns;
+ if (vectorBased) {
+ Vector v = (Vector) javaRDD.first().get(0);
+ numColumns = v.size();
+ } else {
+ numColumns = dataFrame.columns().length;
+ }
+
long numNonZeros = UtilFunctions.toLong(aNnz.value());
matrixCharacteristics.set(numRows, numColumns, matrixCharacteristics.getRowsPerBlock(),
matrixCharacteristics.getColsPerBlock(), numNonZeros);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/97dee8fb/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java b/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java
index 31798e0..dbc8f5d 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java
@@ -236,7 +236,7 @@ public class MLResults {
}
/**
- * Obtain an output as a {@code DataFrame} of doubles.
+ * Obtain an output as a {@code DataFrame} of doubles with an ID column.
* <p>
* The following matrix in DML:
* </p>
@@ -245,13 +245,13 @@ public class MLResults {
* <p>
* is equivalent to the following {@code DataFrame} of doubles:
* </p>
- * <code>[0.0,1.0,2.0]
- * <br>[1.0,3.0,4.0]
+ * <code>[1.0,1.0,2.0]
+ * <br>[2.0,3.0,4.0]
* </code>
*
* @param outputName
* the name of the output
- * @return the output as a {@code DataFrame} of doubles
+ * @return the output as a {@code DataFrame} of doubles with an ID column
*/
public DataFrame getDataFrame(String outputName) {
MatrixObject mo = getMatrixObject(outputName);
@@ -259,6 +259,35 @@ public class MLResults {
return df;
}
+ /**
+ * Obtain an output as a {@code DataFrame} of doubles or vectors with an ID
+ * column.
+ * <p>
+ * The following matrix in DML:
+ * </p>
+ * <code>M = full('1 2 3 4', rows=2, cols=2);
+ * </code>
+ * <p>
+ * is equivalent to the following {@code DataFrame} of doubles:
+ * </p>
+ * <code>[1.0,1.0,2.0]
+ * <br>[2.0,3.0,4.0]
+ * </code>
+ * <p>
+ * or the following {@code DataFrame} of vectors:
+ * </p>
+ * <code>[1.0,[1.0,2.0]]
+ * <br>[2.0,[3.0,4.0]]
+ * </code>
+ *
+ * @param outputName
+ * the name of the output
+ * @param isVectorDF
+ * {@code true} for a vector {@code DataFrame}, {@code false} for
+ * a double {@code DataFrame}
+ * @return the output as a {@code DataFrame} of doubles or vectors with an
+ * ID column
+ */
public DataFrame getDataFrame(String outputName, boolean isVectorDF) {
MatrixObject mo = getMatrixObject(outputName);
DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, isVectorDF);
@@ -266,6 +295,104 @@ public class MLResults {
}
/**
+ * Obtain an output as a {@code DataFrame} of doubles with an ID column.
+ * <p>
+ * The following matrix in DML:
+ * </p>
+ * <code>M = full('1 2 3 4', rows=2, cols=2);
+ * </code>
+ * <p>
+ * is equivalent to the following {@code DataFrame} of doubles:
+ * </p>
+ * <code>[1.0,1.0,2.0]
+ * <br>[2.0,3.0,4.0]
+ * </code>
+ *
+ * @param outputName
+ * the name of the output
+ * @return the output as a {@code DataFrame} of doubles with an ID column
+ */
+ public DataFrame getDataFrameDoubleWithIDColumn(String outputName) {
+ MatrixObject mo = getMatrixObject(outputName);
+ DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, false);
+ return df;
+ }
+
+ /**
+ * Obtain an output as a {@code DataFrame} of vectors with an ID column.
+ * <p>
+ * The following matrix in DML:
+ * </p>
+ * <code>M = full('1 2 3 4', rows=2, cols=2);
+ * </code>
+ * <p>
+ * is equivalent to the following {@code DataFrame} of vectors:
+ * </p>
+ * <code>[1.0,[1.0,2.0]]
+ * <br>[2.0,[3.0,4.0]]
+ * </code>
+ *
+ * @param outputName
+ * the name of the output
+ * @return the output as a {@code DataFrame} of vectors with an ID column
+ */
+ public DataFrame getDataFrameVectorWithIDColumn(String outputName) {
+ MatrixObject mo = getMatrixObject(outputName);
+ DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, true);
+ return df;
+ }
+
+ /**
+ * Obtain an output as a {@code DataFrame} of doubles with no ID column.
+ * <p>
+ * The following matrix in DML:
+ * </p>
+ * <code>M = full('1 2 3 4', rows=2, cols=2);
+ * </code>
+ * <p>
+ * is equivalent to the following {@code DataFrame} of doubles:
+ * </p>
+ * <code>[1.0,2.0]
+ * <br>[3.0,4.0]
+ * </code>
+ *
+ * @param outputName
+ * the name of the output
+ * @return the output as a {@code DataFrame} of doubles with no ID column
+ */
+ public DataFrame getDataFrameDoubleNoIDColumn(String outputName) {
+ MatrixObject mo = getMatrixObject(outputName);
+ DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, false);
+ df = df.sort("ID").drop("ID");
+ return df;
+ }
+
+ /**
+ * Obtain an output as a {@code DataFrame} of vectors with no ID column.
+ * <p>
+ * The following matrix in DML:
+ * </p>
+ * <code>M = full('1 2 3 4', rows=2, cols=2);
+ * </code>
+ * <p>
+ * is equivalent to the following {@code DataFrame} of vectors:
+ * </p>
+ * <code>[[1.0,2.0]]
+ * <br>[[3.0,4.0]]
+ * </code>
+ *
+ * @param outputName
+ * the name of the output
+ * @return the output as a {@code DataFrame} of vectors with no ID column
+ */
+ public DataFrame getDataFrameVectorNoIDColumn(String outputName) {
+ MatrixObject mo = getMatrixObject(outputName);
+ DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, true);
+ df = df.sort("ID").drop("ID");
+ return df;
+ }
+
+ /**
* Obtain an output as a {@code Matrix}.
*
* @param outputName
@@ -278,7 +405,6 @@ public class MLResults {
return matrix;
}
-
/**
* Obtain an output as a {@code BinaryBlockMatrix}.
*
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/97dee8fb/src/main/java/org/apache/sysml/api/mlcontext/Matrix.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/Matrix.java b/src/main/java/org/apache/sysml/api/mlcontext/Matrix.java
index 3ee41b7..abd785c 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/Matrix.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/Matrix.java
@@ -103,9 +103,9 @@ public class Matrix {
}
/**
- * Obtain the matrix as a {@code DataFrame}
+ * Obtain the matrix as a {@code DataFrame} of doubles with an ID column
*
- * @return the matrix as a {@code DataFrame}
+ * @return the matrix as a {@code DataFrame} of doubles with an ID column
*/
public DataFrame asDataFrame() {
DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(matrixObject, sparkExecutionContext, false);
@@ -113,6 +113,48 @@ public class Matrix {
}
/**
+ * Obtain the matrix as a {@code DataFrame} of doubles with an ID column
+ *
+ * @return the matrix as a {@code DataFrame} of doubles with an ID column
+ */
+ public DataFrame asDataFrameDoubleWithIDColumn() {
+ DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(matrixObject, sparkExecutionContext, false);
+ return df;
+ }
+
+ /**
+ * Obtain the matrix as a {@code DataFrame} of doubles with no ID column
+ *
+ * @return the matrix as a {@code DataFrame} of doubles with no ID column
+ */
+ public DataFrame asDataFrameDoubleNoIDColumn() {
+ DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(matrixObject, sparkExecutionContext, false);
+ df = df.sort("ID").drop("ID");
+ return df;
+ }
+
+ /**
+ * Obtain the matrix as a {@code DataFrame} of vectors with an ID column
+ *
+ * @return the matrix as a {@code DataFrame} of vectors with an ID column
+ */
+ public DataFrame asDataFrameVectorWithIDColumn() {
+ DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(matrixObject, sparkExecutionContext, true);
+ return df;
+ }
+
+ /**
+ * Obtain the matrix as a {@code DataFrame} of vectors with no ID column
+ *
+ * @return the matrix as a {@code DataFrame} of vectors with no ID column
+ */
+ public DataFrame asDataFrameVectorNoIDColumn() {
+ DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(matrixObject, sparkExecutionContext, true);
+ df = df.sort("ID").drop("ID");
+ return df;
+ }
+
+ /**
* Obtain the matrix as a {@code BinaryBlockMatrix}
*
* @return the matrix as a {@code BinaryBlockMatrix}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/97dee8fb/src/main/java/org/apache/sysml/api/mlcontext/MatrixFormat.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MatrixFormat.java b/src/main/java/org/apache/sysml/api/mlcontext/MatrixFormat.java
index 50ed634..a7ac395 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MatrixFormat.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MatrixFormat.java
@@ -34,6 +34,58 @@ public enum MatrixFormat {
* (I J V) format (sparse). I and J represent matrix coordinates and V
* represents the value. The I J and V values are space-separated.
*/
- IJV;
+ IJV,
+
+ /**
+ * DataFrame of doubles with an ID column.
+ */
+ DF_DOUBLES_WITH_ID_COLUMN,
+
+ /**
+ * DataFrame of doubles with no ID column.
+ */
+ DF_DOUBLES_WITH_NO_ID_COLUMN,
+
+ /**
+ * Vector DataFrame with an ID column.
+ */
+ DF_VECTOR_WITH_ID_COLUMN,
+
+ /**
+ * Vector DataFrame with no ID column.
+ */
+ DF_VECTOR_WITH_NO_ID_COLUMN;
+
+ /**
+ * Is the matrix format vector-based?
+ *
+ * @return {@code true} if matrix is a vector-based DataFrame, {@code false}
+ * otherwise.
+ */
+ public boolean isVectorBased() {
+ if (this == DF_VECTOR_WITH_ID_COLUMN) {
+ return true;
+ } else if (this == DF_VECTOR_WITH_NO_ID_COLUMN) {
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ /**
+ * Does the DataFrame have an ID column?
+ *
+ * @return {@code true} if the DataFrame has an ID column, {@code false}
+ * otherwise.
+ */
+ public boolean hasIDColumn() {
+ if (this == DF_DOUBLES_WITH_ID_COLUMN) {
+ return true;
+ } else if (this == DF_VECTOR_WITH_ID_COLUMN) {
+ return true;
+ } else {
+ return false;
+ }
+ }
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/97dee8fb/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
index e6e1046..7be657b 100644
--- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
@@ -46,6 +46,9 @@ import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.VectorUDT;
+import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
@@ -497,8 +500,8 @@ public class MLContextTest extends AutomatedTestBase {
}
@Test
- public void testDataFrameSumDML() {
- System.out.println("MLContextTest - DataFrame sum DML");
+ public void testDataFrameSumDMLDoublesWithNoIDColumn() {
+ System.out.println("MLContextTest - DataFrame sum DML, doubles with no ID column");
List<String> list = new ArrayList<String>();
list.add("10,20,30");
@@ -521,8 +524,8 @@ public class MLContextTest extends AutomatedTestBase {
}
@Test
- public void testDataFrameSumPYDML() {
- System.out.println("MLContextTest - DataFrame sum PYDML");
+ public void testDataFrameSumPYDMLDoublesWithNoIDColumn() {
+ System.out.println("MLContextTest - DataFrame sum PYDML, doubles with no ID column");
List<String> list = new ArrayList<String>();
list.add("10,20,30");
@@ -544,9 +547,236 @@ public class MLContextTest extends AutomatedTestBase {
ml.execute(script);
}
+ @Test
+ public void testDataFrameSumDMLDoublesWithIDColumn() {
+ System.out.println("MLContextTest - DataFrame sum DML, doubles with ID column");
+
+ List<String> list = new ArrayList<String>();
+ list.add("1,1,2,3");
+ list.add("2,4,5,6");
+ list.add("3,7,8,9");
+ JavaRDD<String> javaRddString = sc.parallelize(list);
+
+ JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+ SQLContext sqlContext = new SQLContext(sc);
+ List<StructField> fields = new ArrayList<StructField>();
+ fields.add(DataTypes.createStructField("ID", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+ StructType schema = DataTypes.createStructType(fields);
+ DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
+
+ MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_DOUBLES_WITH_ID_COLUMN);
+
+ Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame, mm);
+ setExpectedStdOut("sum: 45.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testDataFrameSumPYDMLDoublesWithIDColumn() {
+ System.out.println("MLContextTest - DataFrame sum PYDML, doubles with ID column");
+
+ List<String> list = new ArrayList<String>();
+ list.add("1,1,2,3");
+ list.add("2,4,5,6");
+ list.add("3,7,8,9");
+ JavaRDD<String> javaRddString = sc.parallelize(list);
+
+ JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+ SQLContext sqlContext = new SQLContext(sc);
+ List<StructField> fields = new ArrayList<StructField>();
+ fields.add(DataTypes.createStructField("ID", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+ StructType schema = DataTypes.createStructType(fields);
+ DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
+
+ MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_DOUBLES_WITH_ID_COLUMN);
+
+ Script script = pydml("print('sum: ' + sum(M))").in("M", dataFrame, mm);
+ setExpectedStdOut("sum: 45.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testDataFrameSumDMLDoublesWithIDColumnSortCheck() {
+ System.out.println("MLContextTest - DataFrame sum DML, doubles with ID column sort check");
+
+ List<String> list = new ArrayList<String>();
+ list.add("3,7,8,9");
+ list.add("1,1,2,3");
+ list.add("2,4,5,6");
+ JavaRDD<String> javaRddString = sc.parallelize(list);
+
+ JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+ SQLContext sqlContext = new SQLContext(sc);
+ List<StructField> fields = new ArrayList<StructField>();
+ fields.add(DataTypes.createStructField("ID", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+ StructType schema = DataTypes.createStructType(fields);
+ DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
+
+ MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_DOUBLES_WITH_ID_COLUMN);
+
+ Script script = dml("print('M[1,1]: ' + as.scalar(M[1,1]));").in("M", dataFrame, mm);
+ setExpectedStdOut("M[1,1]: 1.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testDataFrameSumPYDMLDoublesWithIDColumnSortCheck() {
+ System.out.println("MLContextTest - DataFrame sum PYDML ID, doubles with ID column sort check");
+
+ List<String> list = new ArrayList<String>();
+ list.add("3,7,8,9");
+ list.add("1,1,2,3");
+ list.add("2,4,5,6");
+ JavaRDD<String> javaRddString = sc.parallelize(list);
+
+ JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+ SQLContext sqlContext = new SQLContext(sc);
+ List<StructField> fields = new ArrayList<StructField>();
+ fields.add(DataTypes.createStructField("ID", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+ StructType schema = DataTypes.createStructType(fields);
+ DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
+
+ MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_DOUBLES_WITH_ID_COLUMN);
+
+ Script script = pydml("print('M[0,0]: ' + scalar(M[0,0]))").in("M", dataFrame, mm);
+ setExpectedStdOut("M[0,0]: 1.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testDataFrameSumDMLVectorWithIDColumn() {
+ System.out.println("MLContextTest - DataFrame sum DML, vector with ID column");
+ // Input rows are (ID, vector) pairs; the ID of each pair is a Double.
+ List<Tuple2<Double, Vector>> list = new ArrayList<Tuple2<Double, Vector>>();
+ list.add(new Tuple2<Double, Vector>(1.0, Vectors.dense(1.0, 2.0, 3.0)));
+ list.add(new Tuple2<Double, Vector>(2.0, Vectors.dense(4.0, 5.0, 6.0)));
+ list.add(new Tuple2<Double, Vector>(3.0, Vectors.dense(7.0, 8.0, 9.0)));
+ JavaRDD<Tuple2<Double, Vector>> javaRddTuple = sc.parallelize(list);
+
+ JavaRDD<Row> javaRddRow = javaRddTuple.map(new DoubleVectorRow());
+ SQLContext sqlContext = new SQLContext(sc);
+ List<StructField> fields = new ArrayList<StructField>();
+ fields.add(DataTypes.createStructField("ID", DataTypes.DoubleType, true)); // must match the Double ID emitted by DoubleVectorRow
+ fields.add(DataTypes.createStructField("C1", new VectorUDT(), true));
+ StructType schema = DataTypes.createStructType(fields);
+ DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
+
+ MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_ID_COLUMN);
+
+ Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame, mm);
+ setExpectedStdOut("sum: 45.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testDataFrameSumPYDMLVectorWithIDColumn() {
+ System.out.println("MLContextTest - DataFrame sum PYDML, vector with ID column");
+ // Input rows are (ID, vector) pairs; the ID of each pair is a Double.
+ List<Tuple2<Double, Vector>> list = new ArrayList<Tuple2<Double, Vector>>();
+ list.add(new Tuple2<Double, Vector>(1.0, Vectors.dense(1.0, 2.0, 3.0)));
+ list.add(new Tuple2<Double, Vector>(2.0, Vectors.dense(4.0, 5.0, 6.0)));
+ list.add(new Tuple2<Double, Vector>(3.0, Vectors.dense(7.0, 8.0, 9.0)));
+ JavaRDD<Tuple2<Double, Vector>> javaRddTuple = sc.parallelize(list);
+
+ JavaRDD<Row> javaRddRow = javaRddTuple.map(new DoubleVectorRow());
+ SQLContext sqlContext = new SQLContext(sc);
+ List<StructField> fields = new ArrayList<StructField>();
+ fields.add(DataTypes.createStructField("ID", DataTypes.DoubleType, true)); // must match the Double ID emitted by DoubleVectorRow
+ fields.add(DataTypes.createStructField("C1", new VectorUDT(), true));
+ StructType schema = DataTypes.createStructType(fields);
+ DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
+
+ MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_ID_COLUMN);
+ // Build the script with pydml(), not dml(), so the PYDML path is the one under test.
+ Script script = pydml("print('sum: ' + sum(M))").in("M", dataFrame, mm);
+ setExpectedStdOut("sum: 45.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testDataFrameSumDMLVectorWithNoIDColumn() {
+ System.out.println("MLContextTest - DataFrame sum DML, vector with no ID column");
+
+ List<Vector> list = new ArrayList<Vector>();
+ list.add(Vectors.dense(1.0, 2.0, 3.0));
+ list.add(Vectors.dense(4.0, 5.0, 6.0));
+ list.add(Vectors.dense(7.0, 8.0, 9.0));
+ JavaRDD<Vector> javaRddVector = sc.parallelize(list);
+
+ JavaRDD<Row> javaRddRow = javaRddVector.map(new VectorRow());
+ SQLContext sqlContext = new SQLContext(sc);
+ List<StructField> fields = new ArrayList<StructField>();
+ fields.add(DataTypes.createStructField("C1", new VectorUDT(), true));
+ StructType schema = DataTypes.createStructType(fields);
+ DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
+
+ MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_NO_ID_COLUMN);
+
+ Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame, mm);
+ setExpectedStdOut("sum: 45.0");
+ ml.execute(script);
+ }
+
+ @Test
+ public void testDataFrameSumPYDMLVectorWithNoIDColumn() {
+ System.out.println("MLContextTest - DataFrame sum PYDML, vector with no ID column");
+ // Each input row holds a single dense vector and no ID column.
+ List<Vector> list = new ArrayList<Vector>();
+ list.add(Vectors.dense(1.0, 2.0, 3.0));
+ list.add(Vectors.dense(4.0, 5.0, 6.0));
+ list.add(Vectors.dense(7.0, 8.0, 9.0));
+ JavaRDD<Vector> javaRddVector = sc.parallelize(list);
+
+ JavaRDD<Row> javaRddRow = javaRddVector.map(new VectorRow());
+ SQLContext sqlContext = new SQLContext(sc);
+ List<StructField> fields = new ArrayList<StructField>();
+ fields.add(DataTypes.createStructField("C1", new VectorUDT(), true));
+ StructType schema = DataTypes.createStructType(fields);
+ DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
+
+ MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_NO_ID_COLUMN);
+ // Build the script with pydml(), not dml(), so the PYDML path is the one under test.
+ Script script = pydml("print('sum: ' + sum(M))").in("M", dataFrame, mm);
+ setExpectedStdOut("sum: 45.0");
+ ml.execute(script);
+ }
+
+ static class DoubleVectorRow implements Function<Tuple2<Double, Vector>, Row> {
+ private static final long serialVersionUID = 3605080559931384163L;
+
+ @Override
+ public Row call(Tuple2<Double, Vector> tup) throws Exception {
+ Double doub = tup._1();
+ Vector vect = tup._2();
+ return RowFactory.create(doub, vect);
+ }
+ }
+
+ static class VectorRow implements Function<Vector, Row> {
+ private static final long serialVersionUID = 7077761802433569068L;
+
+ @Override
+ public Row call(Vector vect) throws Exception {
+ return RowFactory.create(vect);
+ }
+ }
+
static class CommaSeparatedValueStringToRow implements Function<String, Row> {
private static final long serialVersionUID = -7871020122671747808L;
+ @Override
public Row call(String str) throws Exception {
String[] fields = str.split(",");
return RowFactory.create((Object[]) fields);
@@ -1032,6 +1262,158 @@ public class MLContextTest extends AutomatedTestBase {
}
@Test
+ public void testOutputDataFrameDMLVectorWithIDColumn() {
+ System.out.println("MLContextTest - output DataFrame DML, vector with ID column");
+
+ String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
+ Script script = dml(s).out("M");
+ MLResults results = ml.execute(script);
+ DataFrame dataFrame = results.getDataFrameVectorWithIDColumn("M");
+ List<Row> list = dataFrame.collectAsList();
+
+ Row row1 = list.get(0);
+ Assert.assertEquals(1.0, row1.getDouble(0), 0.0);
+ Assert.assertArrayEquals(new double[] { 1.0, 2.0 }, ((Vector) row1.get(1)).toArray(), 0.0);
+
+ Row row2 = list.get(1);
+ Assert.assertEquals(2.0, row2.getDouble(0), 0.0);
+ Assert.assertArrayEquals(new double[] { 3.0, 4.0 }, ((Vector) row2.get(1)).toArray(), 0.0);
+ }
+
+ @Test
+ public void testOutputDataFramePYDMLVectorWithIDColumn() {
+ System.out.println("MLContextTest - output DataFrame PYDML, vector with ID column");
+
+ String s = "M = full('1 2 3 4', rows=2, cols=2)";
+ Script script = pydml(s).out("M");
+ MLResults results = ml.execute(script);
+ DataFrame dataFrame = results.getDataFrameVectorWithIDColumn("M");
+ List<Row> list = dataFrame.collectAsList();
+
+ Row row1 = list.get(0);
+ Assert.assertEquals(1.0, row1.getDouble(0), 0.0);
+ Assert.assertArrayEquals(new double[] { 1.0, 2.0 }, ((Vector) row1.get(1)).toArray(), 0.0);
+
+ Row row2 = list.get(1);
+ Assert.assertEquals(2.0, row2.getDouble(0), 0.0);
+ Assert.assertArrayEquals(new double[] { 3.0, 4.0 }, ((Vector) row2.get(1)).toArray(), 0.0);
+ }
+
+ @Test
+ public void testOutputDataFrameDMLVectorNoIDColumn() {
+ System.out.println("MLContextTest - output DataFrame DML, vector no ID column");
+
+ String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
+ Script script = dml(s).out("M");
+ MLResults results = ml.execute(script);
+ DataFrame dataFrame = results.getDataFrameVectorNoIDColumn("M");
+ List<Row> list = dataFrame.collectAsList();
+
+ Row row1 = list.get(0);
+ Assert.assertArrayEquals(new double[] { 1.0, 2.0 }, ((Vector) row1.get(0)).toArray(), 0.0);
+
+ Row row2 = list.get(1);
+ Assert.assertArrayEquals(new double[] { 3.0, 4.0 }, ((Vector) row2.get(0)).toArray(), 0.0);
+ }
+
+ @Test
+ public void testOutputDataFramePYDMLVectorNoIDColumn() {
+ System.out.println("MLContextTest - output DataFrame PYDML, vector no ID column");
+
+ String s = "M = full('1 2 3 4', rows=2, cols=2)";
+ Script script = pydml(s).out("M");
+ MLResults results = ml.execute(script);
+ DataFrame dataFrame = results.getDataFrameVectorNoIDColumn("M");
+ List<Row> list = dataFrame.collectAsList();
+
+ Row row1 = list.get(0);
+ Assert.assertArrayEquals(new double[] { 1.0, 2.0 }, ((Vector) row1.get(0)).toArray(), 0.0);
+
+ Row row2 = list.get(1);
+ Assert.assertArrayEquals(new double[] { 3.0, 4.0 }, ((Vector) row2.get(0)).toArray(), 0.0);
+ }
+
+ @Test
+ public void testOutputDataFrameDMLDoublesWithIDColumn() {
+ System.out.println("MLContextTest - output DataFrame DML, doubles with ID column");
+
+ String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
+ Script script = dml(s).out("M");
+ MLResults results = ml.execute(script);
+ DataFrame dataFrame = results.getDataFrameDoubleWithIDColumn("M");
+ List<Row> list = dataFrame.collectAsList();
+
+ Row row1 = list.get(0);
+ Assert.assertEquals(1.0, row1.getDouble(0), 0.0);
+ Assert.assertEquals(1.0, row1.getDouble(1), 0.0);
+ Assert.assertEquals(2.0, row1.getDouble(2), 0.0);
+
+ Row row2 = list.get(1);
+ Assert.assertEquals(2.0, row2.getDouble(0), 0.0);
+ Assert.assertEquals(3.0, row2.getDouble(1), 0.0);
+ Assert.assertEquals(4.0, row2.getDouble(2), 0.0);
+ }
+
+ @Test
+ public void testOutputDataFramePYDMLDoublesWithIDColumn() {
+ System.out.println("MLContextTest - output DataFrame PYDML, doubles with ID column");
+
+ String s = "M = full('1 2 3 4', rows=2, cols=2)";
+ Script script = pydml(s).out("M");
+ MLResults results = ml.execute(script);
+ DataFrame dataFrame = results.getDataFrameDoubleWithIDColumn("M");
+ List<Row> list = dataFrame.collectAsList();
+
+ Row row1 = list.get(0);
+ Assert.assertEquals(1.0, row1.getDouble(0), 0.0);
+ Assert.assertEquals(1.0, row1.getDouble(1), 0.0);
+ Assert.assertEquals(2.0, row1.getDouble(2), 0.0);
+
+ Row row2 = list.get(1);
+ Assert.assertEquals(2.0, row2.getDouble(0), 0.0);
+ Assert.assertEquals(3.0, row2.getDouble(1), 0.0);
+ Assert.assertEquals(4.0, row2.getDouble(2), 0.0);
+ }
+
+ @Test
+ public void testOutputDataFrameDMLDoublesNoIDColumn() {
+ System.out.println("MLContextTest - output DataFrame DML, doubles no ID column");
+
+ String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
+ Script script = dml(s).out("M");
+ MLResults results = ml.execute(script);
+ DataFrame dataFrame = results.getDataFrameDoubleNoIDColumn("M");
+ List<Row> list = dataFrame.collectAsList();
+
+ Row row1 = list.get(0);
+ Assert.assertEquals(1.0, row1.getDouble(0), 0.0);
+ Assert.assertEquals(2.0, row1.getDouble(1), 0.0);
+
+ Row row2 = list.get(1);
+ Assert.assertEquals(3.0, row2.getDouble(0), 0.0);
+ Assert.assertEquals(4.0, row2.getDouble(1), 0.0);
+ }
+
+ @Test
+ public void testOutputDataFramePYDMLDoublesNoIDColumn() {
+ System.out.println("MLContextTest - output DataFrame PYDML, doubles no ID column");
+
+ String s = "M = full('1 2 3 4', rows=2, cols=2)";
+ Script script = pydml(s).out("M");
+ MLResults results = ml.execute(script);
+ DataFrame dataFrame = results.getDataFrameDoubleNoIDColumn("M");
+ List<Row> list = dataFrame.collectAsList();
+
+ Row row1 = list.get(0);
+ Assert.assertEquals(1.0, row1.getDouble(0), 0.0);
+ Assert.assertEquals(2.0, row1.getDouble(1), 0.0);
+
+ Row row2 = list.get(1);
+ Assert.assertEquals(3.0, row2.getDouble(0), 0.0);
+ Assert.assertEquals(4.0, row2.getDouble(1), 0.0);
+ }
+
+ @Test
public void testTwoScriptsDML() {
System.out.println("MLContextTest - two scripts with inputs and outputs DML");