Posted to commits@systemml.apache.org by de...@apache.org on 2016/08/15 21:52:59 UTC

incubator-systemml git commit: [SYSTEMML-860] SparkR/HydraR integration with SystemML

Repository: incubator-systemml
Updated Branches:
  refs/heads/master 5ac32d6be -> 6df0d2348


[SYSTEMML-860] SparkR/HydraR integration with SystemML

Closes #212.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/6df0d234
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/6df0d234
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/6df0d234

Branch: refs/heads/master
Commit: 6df0d2348e77d583ef02974e5a1f1120a959270a
Parents: 5ac32d6
Author: Alok Singh <si...@us.ibm.com>
Authored: Mon Aug 15 14:49:44 2016 -0700
Committer: Deron Eriksson <de...@us.ibm.com>
Committed: Mon Aug 15 14:49:44 2016 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/api/MLContext.java    | 76 +++++++++++++++++++-
 .../spark/utils/RDDConverterUtilsExt.java       | 67 ++++++++++++++++-
 2 files changed, 141 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/6df0d234/src/main/java/org/apache/sysml/api/MLContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/MLContext.java b/src/main/java/org/apache/sysml/api/MLContext.java
index d8a290d..405478f 100644
--- a/src/main/java/org/apache/sysml/api/MLContext.java
+++ b/src/main/java/org/apache/sysml/api/MLContext.java
@@ -837,7 +837,52 @@ public class MLContext {
 		argsArr = args.toArray(argsArr);
 		return execute(dmlScriptFilePath, argsArr, parsePyDML, configFilePath);
 	}
-	
+
+	/*
+	 * NOTE: when calling from SparkR, passing a Map from R to Java does not
+	 * work, so we instead pass two parallel arrays representing the keys and
+	 * the values of the named arguments.
+	 */
+	/**
+	 * Execute DML script by passing named arguments, supplied as parallel
+	 * key and value lists, using the specified config file.
+	 * @param dmlScriptFilePath path to the DML script (local filesystem or HDFS)
+	 * @param argsName names of the named arguments
+	 * @param argsValues values of the named arguments, parallel to argsName
+	 * @param configFilePath path to the SystemML config file
+	 * @throws IOException
+	 * @throws DMLException
+	 * @throws ParseException
+	 */
+	public MLOutput execute(String dmlScriptFilePath, ArrayList<String> argsName,
+							ArrayList<String> argsValues, String configFilePath)
+			throws IOException, DMLException, ParseException  {
+		HashMap<String, String> newNamedArgs = new HashMap<String, String>();
+		if (argsName.size() != argsValues.size()) {
+			throw new DMLException("size of argsName (" + argsName.size() +
+					") does not match size of argsValues (" + argsValues.size() + ")");
+		}
+		for (int i = 0; i < argsName.size(); i++) {
+			String k = argsName.get(i);
+			String v = argsValues.get(i);
+			newNamedArgs.put(k, v);
+		}
+		return execute(dmlScriptFilePath, newNamedArgs, configFilePath);
+	}
+
+	/**
+	 * Execute DML script by passing named arguments, supplied as parallel
+	 * key and value lists, using the default configuration.
+	 * @param dmlScriptFilePath path to the DML script (local filesystem or HDFS)
+	 * @param argsName names of the named arguments
+	 * @param argsValues values of the named arguments, parallel to argsName
+	 * @throws IOException
+	 * @throws DMLException
+	 * @throws ParseException
+	 */
+	public MLOutput execute(String dmlScriptFilePath, ArrayList<String> argsName,
+							ArrayList<String> argsValues)
+			throws IOException, DMLException, ParseException  {
+		return execute(dmlScriptFilePath, argsName, argsValues, null);
+	}
+
 	/**
 	 * Experimental: Execute DML script by passing positional arguments if parsePyDML=true, using specified config file.
 	 * @param dmlScriptFilePath
@@ -1163,11 +1208,40 @@ public class MLContext {
 		return executeScript(dmlScript, false, configFilePath);
 	}
 
 	public MLOutput executeScript(String dmlScript, boolean isPyDML, String configFilePath)
 			throws IOException, DMLException {
 		return compileAndExecuteScript(dmlScript, null, false, false, isPyDML, configFilePath);
 	}
 
+	/*
+	 * NOTE: when calling from SparkR, passing a HashMap from R to Java does
+	 * not work, so we instead pass two parallel arrays representing the keys
+	 * and the values of the named arguments.
+	 */
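+	/**
+	 * Execute DML script (given as a string) by passing named arguments,
+	 * supplied as parallel key and value lists, using the specified config file.
+	 */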
+	public MLOutput executeScript(String dmlScript, ArrayList<String> argsName,
+								  ArrayList<String> argsValues, String configFilePath)
+			throws IOException, DMLException, ParseException  {
+		HashMap<String, String> newNamedArgs = new HashMap<String, String>();
+		if (argsName.size() != argsValues.size()) {
+			throw new DMLException("size of argsName (" + argsName.size() +
+					") does not match size of argsValues (" + argsValues.size() + ")");
+		}
+		for (int i = 0; i < argsName.size(); i++) {
+			String k = argsName.get(i);
+			String v = argsValues.get(i);
+			newNamedArgs.put(k, v);
+		}
+		return executeScript(dmlScript, newNamedArgs, configFilePath);
+	}
+
+	public MLOutput executeScript(String dmlScript, ArrayList<String> argsName,
+								  ArrayList<String> argsValues)
+			throws IOException, DMLException, ParseException  {
+		return executeScript(dmlScript, argsName, argsValues, null);
+	}
+
 	public MLOutput executeScript(String dmlScript, scala.collection.immutable.Map<String, String> namedArgs)
 			throws IOException, DMLException {
 		return executeScript(dmlScript, new HashMap<String, String>(scala.collection.JavaConversions.mapAsJavaMap(namedArgs)), null);

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/6df0d234/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
index 72ab230..88dd44c 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
@@ -37,6 +37,7 @@ import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.function.PairFlatMapFunction;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.VectorUDT;
+import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix;
 import org.apache.spark.mllib.linalg.distributed.MatrixEntry;
 import org.apache.spark.sql.DataFrame;
@@ -141,7 +142,71 @@ public class RDDConverterUtilsExt
 			throw new DMLRuntimeException("The output format:" + format + " is not implemented yet.");
 		}
 	}
-	
+
+	public static DataFrame stringDataFrameToVectorDataFrame(SQLContext sqlContext, DataFrame inputDF)
+			throws DMLRuntimeException {
+
+		StructField[] oldSchema = inputDF.schema().fields();
+		// create the new schema: same column names, but typed as Vector
+		StructField[] newSchema = new StructField[oldSchema.length];
+		for(int i = 0; i < oldSchema.length; i++) {
+			String colName = oldSchema[i].name();
+			newSchema[i] = DataTypes.createStructField(colName, new VectorUDT(), true);
+		}
+
+		// converter: parses each String cell into an MLlib Vector
+		class StringToVector implements Function<Tuple2<Row, Long>, Row> {
+			private static final long serialVersionUID = -4733816995375745659L;
+			@Override
+			public Row call(Tuple2<Row, Long> arg0) throws Exception {
+				Row oldRow = arg0._1;
+				int oldNumCols = oldRow.length();
+				if (oldNumCols > 1) {
+					throw new DMLRuntimeException("The row must have at most one column");
+				}
+
+				// parse the various string representations, e.g.
+				// ((1.2, 4.3, 3.4)) or (1.2, 3.4, 2.2) or (1.2 3.4)
+				// [[1.2, 34.3, 1.2, 1.2]] or [1.2, 3.4] or [1.3 1.2]
+				ArrayList<Object> fieldsArr = new ArrayList<Object>();
+				for (int i = 0; i < oldRow.length(); i++) {
+					Object ci = oldRow.get(i);
+					if (ci instanceof String) {
+						String cis = (String) ci;
+						StringBuilder sb = new StringBuilder(cis.trim());
+						for (int nid = 0; nid < 2; nid++) { // strip up to two levels of nesting
+							if ((sb.charAt(0) == '(' && sb.charAt(sb.length() - 1) == ')') ||
+									(sb.charAt(0) == '[' && sb.charAt(sb.length() - 1) == ']')
+									) {
+								sb.deleteCharAt(0);
+								sb.setLength(sb.length() - 1);
+							}
+						}
+						// normalize separators (commas and/or spaces) into
+						// single commas so that Vectors.parse can read the result
+						String ncis = "[" + sb.toString().trim().replaceAll("[ ,]+", ",") + "]";
+						Vector v = Vectors.parse(ncis);
+						fieldsArr.add(v);
+					} else {
+						throw new DMLRuntimeException("Only String is supported");
+					}
+				}
+				return RowFactory.create(fieldsArr.toArray());
+			}
+		}
+
+		// build the output DataFrame over the converted rows
+		JavaRDD<Row> newRows = inputDF.rdd().toJavaRDD().zipWithIndex().map(new StringToVector());
+		// DataFrame outDF = sqlContext.createDataFrame(newRows, new StructType(newSchema)); //TODO investigate why it doesn't work
+		DataFrame outDF = sqlContext.createDataFrame(newRows.rdd(),
+				DataTypes.createStructType(newSchema));
+
+		return outDF;
+	}
+
 	public static JavaPairRDD<MatrixIndexes, MatrixBlock> vectorDataFrameToBinaryBlock(SparkContext sc,
 			DataFrame inputDF, MatrixCharacteristics mcOut, boolean containsID, String vectorColumnName) throws DMLRuntimeException {
 		return vectorDataFrameToBinaryBlock(new JavaSparkContext(sc), inputDF, mcOut, containsID, vectorColumnName);
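
----------------------------------------------------------------------
For illustration, a minimal usage sketch of the new converter (the column
name "V" and input values are hypothetical; assumes an existing
JavaSparkContext sc and SQLContext sqlContext, the usual Spark 1.x SQL
imports, and exception handling omitted):

    // one String column holding vectors in the supported textual forms
    List<String> data = Arrays.asList("(1.2, 3.4)", "[1.1, 2.2]", "((5.0, 6.0))");
    JavaRDD<Row> rows = sc.parallelize(data).map(new Function<String, Row>() {
        public Row call(String s) { return RowFactory.create(s); }
    });
    StructType schema = DataTypes.createStructType(new StructField[] {
        DataTypes.createStructField("V", DataTypes.StringType, true) });
    DataFrame strDF = sqlContext.createDataFrame(rows, schema);

    // each string becomes an MLlib Vector; the column keeps its name "V"
    DataFrame vecDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(sqlContext, strDF);
    vecDF.show();
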