You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by de...@apache.org on 2017/02/10 22:11:30 UTC
incubator-systemml git commit: [SYSTEMML-1232] Migrate
stringDataFrameToVectorDataFrame to ml Vector
Repository: incubator-systemml
Updated Branches:
refs/heads/master 10b7b8669 -> 67f16c46e
[SYSTEMML-1232] Migrate stringDataFrameToVectorDataFrame to ml Vector
Restore and migrate RDDConverterUtilsExt.stringDataFrameToVectorDataFrame
method from mllib Vector class to ml Vector class. Use NumericParser since
ml.linalg.Vectors.parse() does not exist.
Closes #379.
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/67f16c46
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/67f16c46
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/67f16c46
Branch: refs/heads/master
Commit: 67f16c46e692adfe2533cc31103374d6e5d39bb3
Parents: 10b7b86
Author: Deron Eriksson <de...@us.ibm.com>
Authored: Sun Feb 5 14:21:04 2017 -0800
Committer: Deron Eriksson <de...@us.ibm.com>
Committed: Fri Feb 10 14:05:29 2017 -0800
----------------------------------------------------------------------
pom.xml | 1 +
.../spark/utils/RDDConverterUtilsExt.java | 112 +++++++++++--
.../conversion/RDDConverterUtilsExtTest.java | 160 +++++++++++++++++++
.../integration/conversion/ZPackageSuite.java | 36 +++++
4 files changed, 296 insertions(+), 13 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/67f16c46/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index ab088c8..f81557e 100644
--- a/pom.xml
+++ b/pom.xml
@@ -372,6 +372,7 @@
<includes>
<include>**/integration/applications/**/*Suite.java</include>
+ <include>**/integration/conversion/*Suite.java</include>
<include>**/integration/functions/data/*Suite.java</include>
<include>**/integration/functions/gdfo/*Suite.java</include>
<include>**/integration/functions/sparse/*Suite.java</include>
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/67f16c46/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
index e0d347f..e3b4d0c 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
@@ -21,15 +21,14 @@ package org.apache.sysml.runtime.instructions.spark.utils;
import java.io.IOException;
import java.io.Serializable;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
import java.util.ArrayList;
-import java.util.Arrays;
import java.util.Iterator;
-import java.util.List;
-import java.util.Scanner;
import org.apache.hadoop.io.Text;
-import org.apache.spark.Accumulator;
import org.apache.spark.SparkContext;
+import org.apache.spark.SparkException;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
@@ -40,6 +39,7 @@ import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix;
import org.apache.spark.mllib.linalg.distributed.MatrixEntry;
+import org.apache.spark.mllib.util.NumericParser;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
@@ -47,15 +47,7 @@ import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
-
-import java.nio.ByteBuffer;
-import java.nio.ByteOrder;
-
-import scala.Tuple2;
-
import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.instructions.spark.functions.ConvertMatrixBlockToIJVLines;
-import org.apache.sysml.runtime.io.IOUtilFunctions;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixCell;
@@ -63,7 +55,8 @@ import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
import org.apache.sysml.runtime.matrix.mapred.ReblockBuffer;
import org.apache.sysml.runtime.util.FastStringTokenizer;
-import org.apache.sysml.runtime.util.UtilFunctions;
+
+import scala.Tuple2;
/**
* NOTE: These are experimental converter utils. Once thoroughly tested, they
@@ -362,4 +355,97 @@ public class RDDConverterUtilsExt
ret.addAll(SparkUtils.fromIndexedMatrixBlock(rettmp));
}
}
+
+	/**
+	 * Convert a dataframe of comma-separated string rows to a dataframe of
+	 * ml.linalg.Vector rows.
+	 *
+	 * <p>
+	 * Each non-null string cell may be wrapped in up to two levels of matching
+	 * parentheses or square brackets, which are stripped before parsing; null
+	 * cells become null vectors. Example input rows:<br>
+	 *
+	 * <code>
+	 * ((1.2, 4.3, 3.4))<br>
+	 * (1.2, 3.4, 2.2)<br>
+	 * [[1.2, 34.3, 1.2, 1.25]]<br>
+	 * [1.2, 3.4]<br>
+	 * </code>
+	 *
+	 * @param sqlContext
+	 *            Spark SQL Context
+	 * @param inputDF
+	 *            dataframe of comma-separated row strings to convert to
+	 *            dataframe of ml.linalg.Vector rows; each row may have at most
+	 *            one column, and non-null cells must be strings
+	 * @return dataframe of ml.linalg.Vector rows; conversion is lazy, so
+	 *         malformed input fails the Spark job when an action runs
+	 * @throws DMLRuntimeException
+	 *             declared for API compatibility; per-row conversion errors
+	 *             are thrown inside the Spark tasks
+	 */
+	public static Dataset<Row> stringDataFrameToVectorDataFrame(SQLContext sqlContext, Dataset<Row> inputDF)
+			throws DMLRuntimeException {
+		// each column keeps its name but becomes a nullable ml.linalg.Vector column
+		StructField[] oldSchema = inputDF.schema().fields();
+		StructField[] newSchema = new StructField[oldSchema.length];
+		for (int i = 0; i < oldSchema.length; i++) {
+			newSchema[i] = DataTypes.createStructField(oldSchema[i].name(), new VectorUDT(), true);
+		}
+
+		// converter: parses the string cells of one row into dense vectors
+		class StringToVector implements Function<Row, Row> {
+			private static final long serialVersionUID = -4733816995375745659L;
+
+			@Override
+			public Row call(Row oldRow) throws Exception {
+				int oldNumCols = oldRow.length();
+				if (oldNumCols > 1) {
+					throw new DMLRuntimeException("The row must have at most one column");
+				}
+
+				// parse the various strings (null cells stay null), e.g.
+				// ((1.2, 4.3, 3.4)) or (1.2, 3.4, 2.2)
+				// [[1.2, 34.3, 1.2, 1.2]] or [1.2, 3.4]
+				Object[] fields = new Object[oldNumCols];
+				for (int i = 0; i < oldNumCols; i++) {
+					Object ci = oldRow.get(i);
+					if (ci instanceof String) {
+						fields[i] = parseVector((String) ci);
+					} else if (ci != null) {
+						throw new DMLRuntimeException("Only String is supported");
+					}
+				}
+				return RowFactory.create(fields);
+			}
+
+			/** Strips up to two levels of matching () or [] nesting and parses the numbers. */
+			private Vector parseVector(String cis) throws DMLRuntimeException {
+				StringBuilder sb = new StringBuilder(cis.trim());
+				// at most two passes; the length check avoids charAt on "" (e.g. after "()" or "[]")
+				for (int nid = 0; nid < 2 && sb.length() >= 2; nid++) {
+					char first = sb.charAt(0);
+					char last = sb.charAt(sb.length() - 1);
+					if ((first == '(' && last == ')') || (first == '[' && last == ']')) {
+						sb.deleteCharAt(0);
+						sb.setLength(sb.length() - 1);
+					}
+				}
+				// normalize to the bracketed, space-free list form NumericParser accepts
+				String ncis = "[" + sb.toString().replaceAll(" *, *", ",") + "]";
+				try {
+					// the surrounding [ ] makes NumericParser.parse always return a double[]
+					return Vectors.dense((double[]) NumericParser.parse(ncis));
+				} catch (Exception e) {
+					// NumericParser (Scala) throws SparkException without declaring it, so javac
+					// rejects catching SparkException by name here; catch Exception instead
+					throw new DMLRuntimeException("Error converting to double array. " + e.getMessage(), e);
+				}
+			}
+		}
+
+		// map rows directly; no row index is needed (zipWithIndex would cost an extra Spark job)
+		JavaRDD<Row> newRows = inputDF.rdd().toJavaRDD().map(new StringToVector());
+		return sqlContext.createDataFrame(newRows.rdd(), DataTypes.createStructType(newSchema));
+	}
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/67f16c46/src/test/java/org/apache/sysml/test/integration/conversion/RDDConverterUtilsExtTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/conversion/RDDConverterUtilsExtTest.java b/src/test/java/org/apache/sysml/test/integration/conversion/RDDConverterUtilsExtTest.java
new file mode 100644
index 0000000..7a69423
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/conversion/RDDConverterUtilsExtTest.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.conversion;
+
+import static org.junit.Assert.assertTrue;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkException;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class RDDConverterUtilsExtTest extends AutomatedTestBase {
+
+ private static SparkConf conf;
+ private static JavaSparkContext sc;
+
+ @BeforeClass
+ public static void setUpClass() {
+ if (conf == null)
+ conf = SparkExecutionContext.createSystemMLSparkConf().setAppName("RDDConverterUtilsExtTest")
+ .setMaster("local");
+ if (sc == null)
+ sc = new JavaSparkContext(conf);
+ }
+
+ @Override
+ public void setUp() {
+ // no setup required
+ }
+
+ /**
+ * Convert a basic String to a spark.sql.Row.
+ */
+ static class StringToRow implements Function<String, Row> {
+ private static final long serialVersionUID = 3945939649355731805L;
+
+ @Override
+ public Row call(String str) throws Exception {
+ return RowFactory.create(str);
+ }
+ }
+
+ @Test
+ public void testStringDataFrameToVectorDataFrame() throws DMLRuntimeException {
+ List<String> list = new ArrayList<String>();
+ list.add("((1.2, 4.3, 3.4))");
+ list.add("(1.2, 3.4, 2.2)");
+ list.add("[[1.2, 34.3, 1.2, 1.25]]");
+ list.add("[1.2, 3.4]");
+ JavaRDD<String> javaRddString = sc.parallelize(list);
+ JavaRDD<Row> javaRddRow = javaRddString.map(new StringToRow());
+ SQLContext sqlContext = new SQLContext(sc);
+ List<StructField> fields = new ArrayList<StructField>();
+ fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
+ StructType schema = DataTypes.createStructType(fields);
+ Dataset<Row> inDF = sqlContext.createDataFrame(javaRddRow, schema);
+ Dataset<Row> outDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(sqlContext, inDF);
+
+ List<String> expectedResults = new ArrayList<String>();
+ expectedResults.add("[[1.2,4.3,3.4]]");
+ expectedResults.add("[[1.2,3.4,2.2]]");
+ expectedResults.add("[[1.2,34.3,1.2,1.25]]");
+ expectedResults.add("[[1.2,3.4]]");
+
+ List<Row> outputList = outDF.collectAsList();
+ for (Row row : outputList) {
+ assertTrue("Expected results don't contain: " + row, expectedResults.contains(row.toString()));
+ }
+ }
+
+ @Test
+ public void testStringDataFrameToVectorDataFrameNull() throws DMLRuntimeException {
+ List<String> list = new ArrayList<String>();
+ list.add("[1.2, 3.4]");
+ list.add(null);
+ JavaRDD<String> javaRddString = sc.parallelize(list);
+ JavaRDD<Row> javaRddRow = javaRddString.map(new StringToRow());
+ SQLContext sqlContext = new SQLContext(sc);
+ List<StructField> fields = new ArrayList<StructField>();
+ fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
+ StructType schema = DataTypes.createStructType(fields);
+ Dataset<Row> inDF = sqlContext.createDataFrame(javaRddRow, schema);
+ Dataset<Row> outDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(sqlContext, inDF);
+
+ List<String> expectedResults = new ArrayList<String>();
+ expectedResults.add("[[1.2,3.4]]");
+ expectedResults.add("[null]");
+
+ List<Row> outputList = outDF.collectAsList();
+ for (Row row : outputList) {
+ assertTrue("Expected results don't contain: " + row, expectedResults.contains(row.toString()));
+ }
+ }
+
+ @Test(expected = SparkException.class)
+ public void testStringDataFrameToVectorDataFrameNonNumbers() throws DMLRuntimeException {
+ List<String> list = new ArrayList<String>();
+ list.add("[cheeseburger,fries]");
+ JavaRDD<String> javaRddString = sc.parallelize(list);
+ JavaRDD<Row> javaRddRow = javaRddString.map(new StringToRow());
+ SQLContext sqlContext = new SQLContext(sc);
+ List<StructField> fields = new ArrayList<StructField>();
+ fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true));
+ StructType schema = DataTypes.createStructType(fields);
+ Dataset<Row> inDF = sqlContext.createDataFrame(javaRddRow, schema);
+ Dataset<Row> outDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(sqlContext, inDF);
+ // trigger evaluation to throw exception
+ outDF.collectAsList();
+ }
+
+ @After
+ public void tearDown() {
+ super.tearDown();
+ }
+
+ @AfterClass
+ public static void tearDownClass() {
+ // stop spark context to allow single jvm tests (otherwise the
+ // next test that tries to create a SparkContext would fail)
+ sc.stop();
+ sc = null;
+ conf = null;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/67f16c46/src/test_suites/java/org/apache/sysml/test/integration/conversion/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/conversion/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/conversion/ZPackageSuite.java
new file mode 100644
index 0000000..b8ab13d
--- /dev/null
+++ b/src/test_suites/java/org/apache/sysml/test/integration/conversion/ZPackageSuite.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.conversion;
+
+import org.junit.runner.RunWith;
+import org.junit.runners.Suite;
+
+/**
+ * Groups the tests in this package and related subpackages into a single
+ * JUnit suite so that the Maven build won't run two of them at once.
+ * This class is just a holder for the JUnit annotations below.
+ */
+@RunWith(Suite.class)
+@Suite.SuiteClasses({
+	org.apache.sysml.test.integration.conversion.RDDConverterUtilsExtTest.class
+})
+public class ZPackageSuite {
+
+}