You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by de...@apache.org on 2016/09/15 21:41:08 UTC
incubator-systemml git commit: [SYSTEMML-906][SYSTEMML-907] Update df first to df schema
Repository: incubator-systemml
Updated Branches:
refs/heads/master ad65dfa10 -> 32fa72621
[SYSTEMML-906][SYSTEMML-907] Update df first to df schema
Utilize DataFrame schema to determine matrix/frame characteristics
rather than calling first().
Closes #242.
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/32fa7262
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/32fa7262
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/32fa7262
Branch: refs/heads/master
Commit: 32fa72621b3f37bd2e83c3348ba97e2690498d3a
Parents: ad65dfa
Author: Deron Eriksson <de...@us.ibm.com>
Authored: Thu Sep 15 14:36:10 2016 -0700
Committer: Deron Eriksson <de...@us.ibm.com>
Committed: Thu Sep 15 14:36:10 2016 -0700
----------------------------------------------------------------------
.../api/mlcontext/MLContextConversionUtil.java | 26 ++--
.../sysml/api/mlcontext/MLContextUtil.java | 59 +++++----
.../integration/mlcontext/MLContextTest.java | 130 ++++++++++---------
3 files changed, 120 insertions(+), 95 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/32fa7262/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
index 799225b..5476902 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
@@ -33,11 +33,11 @@ import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.sysml.api.MLContextProxy;
import org.apache.sysml.conf.ConfigurationManager;
@@ -462,19 +462,23 @@ public class MLContextConversionUtil {
hasID = true;
} catch (IllegalArgumentException iae) {
}
- Row firstRow = dataFrame.first();
+
+ StructField[] fields = schema.fields();
MatrixFormat mf = null;
if (hasID) {
- Object object = firstRow.get(1);
- mf = (object instanceof Vector) ?
- MatrixFormat.DF_VECTOR_WITH_INDEX :
- MatrixFormat.DF_DOUBLES_WITH_INDEX;
+ if (fields[1].dataType() instanceof VectorUDT) {
+ mf = MatrixFormat.DF_VECTOR_WITH_INDEX;
+ } else {
+ mf = MatrixFormat.DF_DOUBLES_WITH_INDEX;
+ }
} else {
- Object object = firstRow.get(0);
- mf = (object instanceof Vector) ?
- MatrixFormat.DF_VECTOR :
- MatrixFormat.DF_DOUBLES;
+ if (fields[0].dataType() instanceof VectorUDT) {
+ mf = MatrixFormat.DF_VECTOR;
+ } else {
+ mf = MatrixFormat.DF_DOUBLES;
+ }
}
+
if (mf == null) {
throw new MLContextException("DataFrame format not recognized as an accepted SystemML MatrixFormat");
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/32fa7262/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
index 8c2ba78..5ee0e3a 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
@@ -38,10 +38,13 @@ import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.Row;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
import org.apache.sysml.conf.CompilerConfig;
import org.apache.sysml.conf.CompilerConfig.ConfigType;
import org.apache.sysml.conf.ConfigurationManager;
@@ -468,8 +471,7 @@ public final class MLContextUtil {
} else if (hasFrameMetadata) {
return MLContextConversionUtil.dataFrameToFrameObject(name, dataFrame, (FrameMetadata) metadata);
} else if (!hasMetadata) {
- Row firstRow = dataFrame.first();
- boolean looksLikeMatrix = doesRowLookLikeMatrixRow(firstRow);
+ boolean looksLikeMatrix = doesDataFrameLookLikeMatrix(dataFrame);
if (looksLikeMatrix) {
return MLContextConversionUtil.dataFrameToMatrixObject(name, dataFrame);
} else {
@@ -540,24 +542,31 @@ public final class MLContextUtil {
}
/**
- * If no metadata is supplied for a DataFrame, this method can be used to
- * determine whether the data appears to be a matrix (or a frame)
+ * Examine the DataFrame schema to determine whether the data appears to be
+ * a matrix.
*
- * @param row
- * a row in the DataFrame
- * @return {@code true} if the row appears to be a matrix row, {@code false}
- * otherwise
+ * @param df
+ * the DataFrame
+ * @return {@code true} if the DataFrame appears to be a matrix,
+ * {@code false} otherwise
*/
- public static boolean doesRowLookLikeMatrixRow(Row row) {
- for (int i = 0; i < row.length(); i++) {
- Object object = row.get(i);
- if (object instanceof Vector) {
- return true;
- }
- String str = object.toString();
- try {
- Double.parseDouble(str);
- } catch (NumberFormatException e) {
+ public static boolean doesDataFrameLookLikeMatrix(DataFrame df) {
+ StructType schema = df.schema();
+ StructField[] fields = schema.fields();
+ if (fields == null) {
+ return true;
+ }
+ for (StructField field : fields) {
+ DataType dataType = field.dataType();
+ if ((dataType != DataTypes.DoubleType) && (dataType != DataTypes.IntegerType)
+ && (dataType != DataTypes.LongType) && (!(dataType instanceof VectorUDT))) {
+ // uncomment if we support arrays of doubles for matrices
+ // if (dataType instanceof ArrayType) {
+ // ArrayType arrayType = (ArrayType) dataType;
+ // if (arrayType.elementType() == DataTypes.DoubleType) {
+ // continue;
+ // }
+ // }
return false;
}
}
@@ -931,9 +940,8 @@ public final class MLContextUtil {
* FrameObject, {@code false} otherwise.
*/
public static boolean doesSymbolTableContainFrameObject(LocalVariableMap symbolTable, String variableName) {
- return (symbolTable != null
- && symbolTable.keySet().contains(variableName)
- && symbolTable.get(variableName) instanceof FrameObject);
+ return (symbolTable != null && symbolTable.keySet().contains(variableName)
+ && symbolTable.get(variableName) instanceof FrameObject);
}
/**
@@ -948,8 +956,7 @@ public final class MLContextUtil {
* MatrixObject, {@code false} otherwise.
*/
public static boolean doesSymbolTableContainMatrixObject(LocalVariableMap symbolTable, String variableName) {
- return (symbolTable != null
- && symbolTable.keySet().contains(variableName)
- && symbolTable.get(variableName) instanceof MatrixObject);
+ return (symbolTable != null && symbolTable.keySet().contains(variableName)
+ && symbolTable.get(variableName) instanceof MatrixObject);
}
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/32fa7262/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
index 484a777..5d8e195 100644
--- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
@@ -510,12 +510,12 @@ public class MLContextTest extends AutomatedTestBase {
list.add("70,80,90");
JavaRDD<String> javaRddString = sc.parallelize(list);
- JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+ JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow());
SQLContext sqlContext = new SQLContext(sc);
List<StructField> fields = new ArrayList<StructField>();
- fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true));
StructType schema = DataTypes.createStructType(fields);
DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
@@ -536,12 +536,12 @@ public class MLContextTest extends AutomatedTestBase {
list.add("70,80,90");
JavaRDD<String> javaRddString = sc.parallelize(list);
- JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+ JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow());
SQLContext sqlContext = new SQLContext(sc);
List<StructField> fields = new ArrayList<StructField>();
- fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true));
StructType schema = DataTypes.createStructType(fields);
DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
@@ -562,13 +562,13 @@ public class MLContextTest extends AutomatedTestBase {
list.add("3,7,8,9");
JavaRDD<String> javaRddString = sc.parallelize(list);
- JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+ JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow());
SQLContext sqlContext = new SQLContext(sc);
List<StructField> fields = new ArrayList<StructField>();
- fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.IntegerType, true));
+ fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true));
StructType schema = DataTypes.createStructType(fields);
DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
@@ -589,13 +589,13 @@ public class MLContextTest extends AutomatedTestBase {
list.add("3,7,8,9");
JavaRDD<String> javaRddString = sc.parallelize(list);
- JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+ JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow());
SQLContext sqlContext = new SQLContext(sc);
List<StructField> fields = new ArrayList<StructField>();
- fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.IntegerType, true));
+ fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true));
StructType schema = DataTypes.createStructType(fields);
DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
@@ -616,13 +616,13 @@ public class MLContextTest extends AutomatedTestBase {
list.add("2,4,5,6");
JavaRDD<String> javaRddString = sc.parallelize(list);
- JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+ JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow());
SQLContext sqlContext = new SQLContext(sc);
List<StructField> fields = new ArrayList<StructField>();
- fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.IntegerType, true));
+ fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true));
StructType schema = DataTypes.createStructType(fields);
DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
@@ -643,13 +643,13 @@ public class MLContextTest extends AutomatedTestBase {
list.add("2,4,5,6");
JavaRDD<String> javaRddString = sc.parallelize(list);
- JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+ JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow());
SQLContext sqlContext = new SQLContext(sc);
List<StructField> fields = new ArrayList<StructField>();
- fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.IntegerType, true));
+ fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true));
StructType schema = DataTypes.createStructType(fields);
DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
@@ -673,7 +673,7 @@ public class MLContextTest extends AutomatedTestBase {
JavaRDD<Row> javaRddRow = javaRddTuple.map(new DoubleVectorRow());
SQLContext sqlContext = new SQLContext(sc);
List<StructField> fields = new ArrayList<StructField>();
- fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.IntegerType, true));
fields.add(DataTypes.createStructField("C1", new VectorUDT(), true));
StructType schema = DataTypes.createStructType(fields);
DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
@@ -698,7 +698,7 @@ public class MLContextTest extends AutomatedTestBase {
JavaRDD<Row> javaRddRow = javaRddTuple.map(new DoubleVectorRow());
SQLContext sqlContext = new SQLContext(sc);
List<StructField> fields = new ArrayList<StructField>();
- fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.IntegerType, true));
fields.add(DataTypes.createStructField("C1", new VectorUDT(), true));
StructType schema = DataTypes.createStructType(fields);
DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
@@ -788,6 +788,20 @@ public class MLContextTest extends AutomatedTestBase {
}
}
+ static class CommaSeparatedValueStringToDoubleArrayRow implements Function<String, Row> {
+ private static final long serialVersionUID = -8058786466523637317L;
+
+ @Override
+ public Row call(String str) throws Exception {
+ String[] strings = str.split(",");
+ Double[] doubles = new Double[strings.length];
+ for (int i = 0; i < strings.length; i++) {
+ doubles[i] = Double.parseDouble(strings[i]);
+ }
+ return RowFactory.create((Object[]) doubles);
+ }
+ }
+
@Test
public void testCSVMatrixFileInputParameterSumDML() {
System.out.println("MLContextTest - CSV matrix file input parameter sum DML");
@@ -1836,12 +1850,12 @@ public class MLContextTest extends AutomatedTestBase {
list.add("70,80,90");
JavaRDD<String> javaRddString = sc.parallelize(list);
- JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+ JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow());
SQLContext sqlContext = new SQLContext(sc);
List<StructField> fields = new ArrayList<StructField>();
- fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true));
StructType schema = DataTypes.createStructType(fields);
DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
@@ -1862,12 +1876,12 @@ public class MLContextTest extends AutomatedTestBase {
list.add("70,80,90");
JavaRDD<String> javaRddString = sc.parallelize(list);
- JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+ JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow());
SQLContext sqlContext = new SQLContext(sc);
List<StructField> fields = new ArrayList<StructField>();
- fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true));
StructType schema = DataTypes.createStructType(fields);
DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
@@ -2052,12 +2066,12 @@ public class MLContextTest extends AutomatedTestBase {
list.add("4,4,4");
JavaRDD<String> javaRddString = sc.parallelize(list);
- JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+ JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow());
SQLContext sqlContext = new SQLContext(sc);
List<StructField> fields = new ArrayList<StructField>();
- fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true));
StructType schema = DataTypes.createStructType(fields);
DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
@@ -2076,12 +2090,12 @@ public class MLContextTest extends AutomatedTestBase {
list.add("4,4,4");
JavaRDD<String> javaRddString = sc.parallelize(list);
- JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+ JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow());
SQLContext sqlContext = new SQLContext(sc);
List<StructField> fields = new ArrayList<StructField>();
- fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true));
StructType schema = DataTypes.createStructType(fields);
DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
@@ -2100,13 +2114,13 @@ public class MLContextTest extends AutomatedTestBase {
list.add("3,4,4,4");
JavaRDD<String> javaRddString = sc.parallelize(list);
- JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+ JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow());
SQLContext sqlContext = new SQLContext(sc);
List<StructField> fields = new ArrayList<StructField>();
- fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.IntegerType, true));
+ fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true));
StructType schema = DataTypes.createStructType(fields);
DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
@@ -2125,13 +2139,13 @@ public class MLContextTest extends AutomatedTestBase {
list.add("3,4,4,4");
JavaRDD<String> javaRddString = sc.parallelize(list);
- JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+ JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow());
SQLContext sqlContext = new SQLContext(sc);
List<StructField> fields = new ArrayList<StructField>();
- fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
- fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.IntegerType, true));
+ fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true));
+ fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true));
StructType schema = DataTypes.createStructType(fields);
DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
@@ -2153,7 +2167,7 @@ public class MLContextTest extends AutomatedTestBase {
JavaRDD<Row> javaRddRow = javaRddTuple.map(new DoubleVectorRow());
SQLContext sqlContext = new SQLContext(sc);
List<StructField> fields = new ArrayList<StructField>();
- fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.IntegerType, true));
fields.add(DataTypes.createStructField("C1", new VectorUDT(), true));
StructType schema = DataTypes.createStructType(fields);
DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
@@ -2176,7 +2190,7 @@ public class MLContextTest extends AutomatedTestBase {
JavaRDD<Row> javaRddRow = javaRddTuple.map(new DoubleVectorRow());
SQLContext sqlContext = new SQLContext(sc);
List<StructField> fields = new ArrayList<StructField>();
- fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.StringType, true));
+ fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.IntegerType, true));
fields.add(DataTypes.createStructField("C1", new VectorUDT(), true));
StructType schema = DataTypes.createStructType(fields);
DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);