You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by de...@apache.org on 2016/09/15 21:41:08 UTC

incubator-systemml git commit: [SYSTEMML-906][SYSTEMML-907] Update df first to df schema

Repository: incubator-systemml
Updated Branches:
  refs/heads/master ad65dfa10 -> 32fa72621


[SYSTEMML-906][SYSTEMML-907] Update df first to df schema

Utilize DataFrame schema to determine matrix/frame characteristics
rather than calling first().

Closes #242.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/32fa7262
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/32fa7262
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/32fa7262

Branch: refs/heads/master
Commit: 32fa72621b3f37bd2e83c3348ba97e2690498d3a
Parents: ad65dfa
Author: Deron Eriksson <de...@us.ibm.com>
Authored: Thu Sep 15 14:36:10 2016 -0700
Committer: Deron Eriksson <de...@us.ibm.com>
Committed: Thu Sep 15 14:36:10 2016 -0700

----------------------------------------------------------------------
 .../api/mlcontext/MLContextConversionUtil.java  |  26 ++--
 .../sysml/api/mlcontext/MLContextUtil.java      |  59 +++++----
 .../integration/mlcontext/MLContextTest.java    | 130 ++++++++++---------
 3 files changed, 120 insertions(+), 95 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/32fa7262/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
index 799225b..5476902 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
@@ -33,11 +33,11 @@ import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.VectorUDT;
 import org.apache.spark.rdd.RDD;
 import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.apache.sysml.api.MLContextProxy;
 import org.apache.sysml.conf.ConfigurationManager;
@@ -462,19 +462,23 @@ public class MLContextConversionUtil {
 			hasID = true;
 		} catch (IllegalArgumentException iae) {
 		}
-		Row firstRow = dataFrame.first();
+
+		StructField[] fields = schema.fields();
 		MatrixFormat mf = null;
 		if (hasID) {
-			Object object = firstRow.get(1);
-			mf = (object instanceof Vector) ? 
-				MatrixFormat.DF_VECTOR_WITH_INDEX :
-				MatrixFormat.DF_DOUBLES_WITH_INDEX;
+			if (fields[1].dataType() instanceof VectorUDT) {
+				mf = MatrixFormat.DF_VECTOR_WITH_INDEX;
+			} else {
+				mf = MatrixFormat.DF_DOUBLES_WITH_INDEX;
+			}
 		} else {
-			Object object = firstRow.get(0);
-			mf = (object instanceof Vector) ? 
-				MatrixFormat.DF_VECTOR : 
-				MatrixFormat.DF_DOUBLES;
+			if (fields[0].dataType() instanceof VectorUDT) {
+				mf = MatrixFormat.DF_VECTOR;
+			} else {
+				mf = MatrixFormat.DF_DOUBLES;
+			}
 		}
+
 		if (mf == null) {
 			throw new MLContextException("DataFrame format not recognized as an accepted SystemML MatrixFormat");
 		}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/32fa7262/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
index 8c2ba78..5ee0e3a 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
@@ -38,10 +38,13 @@ import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.VectorUDT;
 import org.apache.spark.rdd.RDD;
 import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.Row;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
 import org.apache.sysml.conf.CompilerConfig;
 import org.apache.sysml.conf.CompilerConfig.ConfigType;
 import org.apache.sysml.conf.ConfigurationManager;
@@ -468,8 +471,7 @@ public final class MLContextUtil {
 			} else if (hasFrameMetadata) {
 				return MLContextConversionUtil.dataFrameToFrameObject(name, dataFrame, (FrameMetadata) metadata);
 			} else if (!hasMetadata) {
-				Row firstRow = dataFrame.first();
-				boolean looksLikeMatrix = doesRowLookLikeMatrixRow(firstRow);
+				boolean looksLikeMatrix = doesDataFrameLookLikeMatrix(dataFrame);
 				if (looksLikeMatrix) {
 					return MLContextConversionUtil.dataFrameToMatrixObject(name, dataFrame);
 				} else {
@@ -540,24 +542,31 @@ public final class MLContextUtil {
 	}
 
 	/**
-	 * If no metadata is supplied for a DataFrame, this method can be used to
-	 * determine whether the data appears to be a matrix (or a frame)
+	 * Examine the DataFrame schema to determine whether the data appears to be
+	 * a matrix.
 	 * 
-	 * @param row
-	 *            a row in the DataFrame
-	 * @return {@code true} if the row appears to be a matrix row, {@code false}
-	 *         otherwise
+	 * @param df
+	 *            the DataFrame
+	 * @return {@code true} if the DataFrame appears to be a matrix,
+	 *         {@code false} otherwise
 	 */
-	public static boolean doesRowLookLikeMatrixRow(Row row) {
-		for (int i = 0; i < row.length(); i++) {
-			Object object = row.get(i);
-			if (object instanceof Vector) {
-				return true;
-			}
-			String str = object.toString();
-			try {
-				Double.parseDouble(str);
-			} catch (NumberFormatException e) {
+	public static boolean doesDataFrameLookLikeMatrix(DataFrame df) {
+		StructType schema = df.schema();
+		StructField[] fields = schema.fields();
+		if (fields == null) {
+			return true;
+		}
+		for (StructField field : fields) {
+			DataType dataType = field.dataType();
+			if ((dataType != DataTypes.DoubleType) && (dataType != DataTypes.IntegerType)
+					&& (dataType != DataTypes.LongType) && (!(dataType instanceof VectorUDT))) {
+				// uncomment if we support arrays of doubles for matrices
+				// if (dataType instanceof ArrayType) {
+				// ArrayType arrayType = (ArrayType) dataType;
+				// if (arrayType.elementType() == DataTypes.DoubleType) {
+				// continue;
+				// }
+				// }
 				return false;
 			}
 		}
@@ -931,9 +940,8 @@ public final class MLContextUtil {
 	 *         FrameObject, {@code false} otherwise.
 	 */
 	public static boolean doesSymbolTableContainFrameObject(LocalVariableMap symbolTable, String variableName) {
-		return (symbolTable != null
-			&& symbolTable.keySet().contains(variableName)
-			&& symbolTable.get(variableName) instanceof FrameObject);
+		return (symbolTable != null && symbolTable.keySet().contains(variableName)
+				&& symbolTable.get(variableName) instanceof FrameObject);
 	}
 
 	/**
@@ -948,8 +956,7 @@ public final class MLContextUtil {
 	 *         MatrixObject, {@code false} otherwise.
 	 */
 	public static boolean doesSymbolTableContainMatrixObject(LocalVariableMap symbolTable, String variableName) {
-		return (symbolTable != null
-			&& symbolTable.keySet().contains(variableName)
-			&& symbolTable.get(variableName) instanceof MatrixObject);
+		return (symbolTable != null && symbolTable.keySet().contains(variableName)
+				&& symbolTable.get(variableName) instanceof MatrixObject);
 	}
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/32fa7262/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
index 484a777..5d8e195 100644
--- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
@@ -510,12 +510,12 @@ public class MLContextTest extends AutomatedTestBase {
 		list.add("70,80,90");
 		JavaRDD<String> javaRddString = sc.parallelize(list);
 
-		JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+		JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow());
 		SQLContext sqlContext = new SQLContext(sc);
 		List<StructField> fields = new ArrayList<StructField>();
-		fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+		fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
+		fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true));
+		fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true));
 		StructType schema = DataTypes.createStructType(fields);
 		DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
 
@@ -536,12 +536,12 @@ public class MLContextTest extends AutomatedTestBase {
 		list.add("70,80,90");
 		JavaRDD<String> javaRddString = sc.parallelize(list);
 
-		JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+		JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow());
 		SQLContext sqlContext = new SQLContext(sc);
 		List<StructField> fields = new ArrayList<StructField>();
-		fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+		fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
+		fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true));
+		fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true));
 		StructType schema = DataTypes.createStructType(fields);
 		DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
 
@@ -562,13 +562,13 @@ public class MLContextTest extends AutomatedTestBase {
 		list.add("3,7,8,9");
 		JavaRDD<String> javaRddString = sc.parallelize(list);
 
-		JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+		JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow());
 		SQLContext sqlContext = new SQLContext(sc);
 		List<StructField> fields = new ArrayList<StructField>();
-		fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+		fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.IntegerType, true));
+		fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
+		fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true));
+		fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true));
 		StructType schema = DataTypes.createStructType(fields);
 		DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
 
@@ -589,13 +589,13 @@ public class MLContextTest extends AutomatedTestBase {
 		list.add("3,7,8,9");
 		JavaRDD<String> javaRddString = sc.parallelize(list);
 
-		JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+		JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow());
 		SQLContext sqlContext = new SQLContext(sc);
 		List<StructField> fields = new ArrayList<StructField>();
-		fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+		fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.IntegerType, true));
+		fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
+		fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true));
+		fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true));
 		StructType schema = DataTypes.createStructType(fields);
 		DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
 
@@ -616,13 +616,13 @@ public class MLContextTest extends AutomatedTestBase {
 		list.add("2,4,5,6");
 		JavaRDD<String> javaRddString = sc.parallelize(list);
 
-		JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+		JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow());
 		SQLContext sqlContext = new SQLContext(sc);
 		List<StructField> fields = new ArrayList<StructField>();
-		fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+		fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.IntegerType, true));
+		fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
+		fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true));
+		fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true));
 		StructType schema = DataTypes.createStructType(fields);
 		DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
 
@@ -643,13 +643,13 @@ public class MLContextTest extends AutomatedTestBase {
 		list.add("2,4,5,6");
 		JavaRDD<String> javaRddString = sc.parallelize(list);
 
-		JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+		JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow());
 		SQLContext sqlContext = new SQLContext(sc);
 		List<StructField> fields = new ArrayList<StructField>();
-		fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+		fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.IntegerType, true));
+		fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
+		fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true));
+		fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true));
 		StructType schema = DataTypes.createStructType(fields);
 		DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
 
@@ -673,7 +673,7 @@ public class MLContextTest extends AutomatedTestBase {
 		JavaRDD<Row> javaRddRow = javaRddTuple.map(new DoubleVectorRow());
 		SQLContext sqlContext = new SQLContext(sc);
 		List<StructField> fields = new ArrayList<StructField>();
-		fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.StringType, true));
+		fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.IntegerType, true));
 		fields.add(DataTypes.createStructField("C1", new VectorUDT(), true));
 		StructType schema = DataTypes.createStructType(fields);
 		DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
@@ -698,7 +698,7 @@ public class MLContextTest extends AutomatedTestBase {
 		JavaRDD<Row> javaRddRow = javaRddTuple.map(new DoubleVectorRow());
 		SQLContext sqlContext = new SQLContext(sc);
 		List<StructField> fields = new ArrayList<StructField>();
-		fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.StringType, true));
+		fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.IntegerType, true));
 		fields.add(DataTypes.createStructField("C1", new VectorUDT(), true));
 		StructType schema = DataTypes.createStructType(fields);
 		DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
@@ -788,6 +788,20 @@ public class MLContextTest extends AutomatedTestBase {
 		}
 	}
 
+	static class CommaSeparatedValueStringToDoubleArrayRow implements Function<String, Row> {
+		private static final long serialVersionUID = -8058786466523637317L;
+
+		@Override
+		public Row call(String str) throws Exception {
+			String[] strings = str.split(",");
+			Double[] doubles = new Double[strings.length];
+			for (int i = 0; i < strings.length; i++) {
+				doubles[i] = Double.parseDouble(strings[i]);
+			}
+			return RowFactory.create((Object[]) doubles);
+		}
+	}
+
 	@Test
 	public void testCSVMatrixFileInputParameterSumDML() {
 		System.out.println("MLContextTest - CSV matrix file input parameter sum DML");
@@ -1836,12 +1850,12 @@ public class MLContextTest extends AutomatedTestBase {
 		list.add("70,80,90");
 		JavaRDD<String> javaRddString = sc.parallelize(list);
 
-		JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+		JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow());
 		SQLContext sqlContext = new SQLContext(sc);
 		List<StructField> fields = new ArrayList<StructField>();
-		fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+		fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
+		fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true));
+		fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true));
 		StructType schema = DataTypes.createStructType(fields);
 		DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
 
@@ -1862,12 +1876,12 @@ public class MLContextTest extends AutomatedTestBase {
 		list.add("70,80,90");
 		JavaRDD<String> javaRddString = sc.parallelize(list);
 
-		JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+		JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow());
 		SQLContext sqlContext = new SQLContext(sc);
 		List<StructField> fields = new ArrayList<StructField>();
-		fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+		fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
+		fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true));
+		fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true));
 		StructType schema = DataTypes.createStructType(fields);
 		DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
 
@@ -2052,12 +2066,12 @@ public class MLContextTest extends AutomatedTestBase {
 		list.add("4,4,4");
 		JavaRDD<String> javaRddString = sc.parallelize(list);
 
-		JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+		JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow());
 		SQLContext sqlContext = new SQLContext(sc);
 		List<StructField> fields = new ArrayList<StructField>();
-		fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+		fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
+		fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true));
+		fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true));
 		StructType schema = DataTypes.createStructType(fields);
 		DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
 
@@ -2076,12 +2090,12 @@ public class MLContextTest extends AutomatedTestBase {
 		list.add("4,4,4");
 		JavaRDD<String> javaRddString = sc.parallelize(list);
 
-		JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+		JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow());
 		SQLContext sqlContext = new SQLContext(sc);
 		List<StructField> fields = new ArrayList<StructField>();
-		fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+		fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
+		fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true));
+		fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true));
 		StructType schema = DataTypes.createStructType(fields);
 		DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
 
@@ -2100,13 +2114,13 @@ public class MLContextTest extends AutomatedTestBase {
 		list.add("3,4,4,4");
 		JavaRDD<String> javaRddString = sc.parallelize(list);
 
-		JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+		JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow());
 		SQLContext sqlContext = new SQLContext(sc);
 		List<StructField> fields = new ArrayList<StructField>();
-		fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+		fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.IntegerType, true));
+		fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
+		fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true));
+		fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true));
 		StructType schema = DataTypes.createStructType(fields);
 		DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
 
@@ -2125,13 +2139,13 @@ public class MLContextTest extends AutomatedTestBase {
 		list.add("3,4,4,4");
 		JavaRDD<String> javaRddString = sc.parallelize(list);
 
-		JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow());
+		JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow());
 		SQLContext sqlContext = new SQLContext(sc);
 		List<StructField> fields = new ArrayList<StructField>();
-		fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true));
-		fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true));
+		fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.IntegerType, true));
+		fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true));
+		fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true));
+		fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true));
 		StructType schema = DataTypes.createStructType(fields);
 		DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
 
@@ -2153,7 +2167,7 @@ public class MLContextTest extends AutomatedTestBase {
 		JavaRDD<Row> javaRddRow = javaRddTuple.map(new DoubleVectorRow());
 		SQLContext sqlContext = new SQLContext(sc);
 		List<StructField> fields = new ArrayList<StructField>();
-		fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.StringType, true));
+		fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.IntegerType, true));
 		fields.add(DataTypes.createStructField("C1", new VectorUDT(), true));
 		StructType schema = DataTypes.createStructType(fields);
 		DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);
@@ -2176,7 +2190,7 @@ public class MLContextTest extends AutomatedTestBase {
 		JavaRDD<Row> javaRddRow = javaRddTuple.map(new DoubleVectorRow());
 		SQLContext sqlContext = new SQLContext(sc);
 		List<StructField> fields = new ArrayList<StructField>();
-		fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.StringType, true));
+		fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.IntegerType, true));
 		fields.add(DataTypes.createStructField("C1", new VectorUDT(), true));
 		StructType schema = DataTypes.createStructType(fields);
 		DataFrame dataFrame = sqlContext.createDataFrame(javaRddRow, schema);