You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by de...@apache.org on 2016/09/04 17:41:51 UTC
[1/2] incubator-systemml git commit: [SYSTEMML-896] Additional MLContext Frame support
Repository: incubator-systemml
Updated Branches:
refs/heads/master b7657dbc3 -> d39865e9e
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d39865e9/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java b/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java
index 3841bc8..d8446a9 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java
@@ -24,11 +24,7 @@ import java.util.Set;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
-import org.apache.sysml.hops.OptimizerUtils;
-import org.apache.sysml.parser.Expression.ValueType;
-import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
-import org.apache.sysml.runtime.controlprogram.caching.CacheException;
import org.apache.sysml.runtime.controlprogram.caching.FrameObject;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
@@ -39,11 +35,6 @@ import org.apache.sysml.runtime.instructions.cp.DoubleObject;
import org.apache.sysml.runtime.instructions.cp.IntObject;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
import org.apache.sysml.runtime.instructions.cp.StringObject;
-import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
-import org.apache.sysml.runtime.matrix.MatrixDimensionsMetaData;
-import org.apache.sysml.runtime.matrix.data.FrameBlock;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-import org.apache.sysml.runtime.util.DataConverter;
import scala.Tuple1;
import scala.Tuple10;
@@ -120,21 +111,7 @@ public class MLResults {
*/
public MatrixObject getMatrixObject(String outputName) {
Data data = getData(outputName);
- if(data instanceof ScalarObject) {
- double val = getDouble(outputName);
- MatrixObject one_X_one_mo = new MatrixObject(ValueType.DOUBLE, " ", new MatrixDimensionsMetaData(new MatrixCharacteristics(1, 1, OptimizerUtils.DEFAULT_BLOCKSIZE, OptimizerUtils.DEFAULT_BLOCKSIZE, 1)));
- MatrixBlock mb = new MatrixBlock(1, 1, false);
- mb.allocateDenseBlock();
- mb.setValue(0, 0, val);
- try {
- one_X_one_mo.acquireModify(mb);
- one_X_one_mo.release();
- } catch (CacheException e) {
- throw new RuntimeException(e);
- }
- return one_X_one_mo;
- }
- else if (!(data instanceof MatrixObject)) {
+ if (!(data instanceof MatrixObject)) {
throw new MLContextException("Variable '" + outputName + "' not a matrix");
}
MatrixObject mo = (MatrixObject) data;
@@ -163,7 +140,7 @@ public class MLResults {
* the name of the output
* @return the output as a two-dimensional {@code double} array
*/
- public double[][] getDoubleMatrix(String outputName) {
+ public double[][] getMatrixAs2DDoubleArray(String outputName) {
MatrixObject mo = getMatrixObject(outputName);
double[][] doubleMatrix = MLContextConversionUtil.matrixObjectToDoubleMatrix(mo);
return doubleMatrix;
@@ -190,22 +167,16 @@ public class MLResults {
* @return the output as a {@code JavaRDD<String>} in IJV format
*/
public JavaRDD<String> getJavaRDDStringIJV(String outputName) {
- MatrixObject mo = getMatrixObject(outputName);
- JavaRDD<String> javaRDDStringIJV = MLContextConversionUtil.matrixObjectToJavaRDDStringIJV(mo);
- return javaRDDStringIJV;
- }
-
- /**
- * Obtain an output as a {@code JavaRDD<String>} in IJV format.
- *
- * @param outputName
- * the name of the output
- * @return the output as a {@code JavaRDD<String>} in IJV format
- */
- public JavaRDD<String> getFrameJavaRDDStringIJV(String outputName) {
- FrameObject fo = getFrameObject(outputName);
- JavaRDD<String> javaRDDStringIJV = MLContextConversionUtil.frameObjectToJavaRDDStringIJV(fo);
- return javaRDDStringIJV;
+ if (isMatrixObject(outputName)) {
+ MatrixObject mo = getMatrixObject(outputName);
+ JavaRDD<String> javaRDDStringIJV = MLContextConversionUtil.matrixObjectToJavaRDDStringIJV(mo);
+ return javaRDDStringIJV;
+ } else if (isFrameObject(outputName)) {
+ FrameObject fo = getFrameObject(outputName);
+ JavaRDD<String> javaRDDStringIJV = MLContextConversionUtil.frameObjectToJavaRDDStringIJV(fo);
+ return javaRDDStringIJV;
+ }
+ return null;
}
/**
@@ -227,22 +198,16 @@ public class MLResults {
* @return the output as a {@code JavaRDD<String>} in CSV format
*/
public JavaRDD<String> getJavaRDDStringCSV(String outputName) {
- MatrixObject mo = getMatrixObject(outputName);
- JavaRDD<String> javaRDDStringCSV = MLContextConversionUtil.matrixObjectToJavaRDDStringCSV(mo);
- return javaRDDStringCSV;
- }
-
- /**
- * Obtain an output as a {@code JavaRDD<String>} in CSV format.
- *
- * @param outputName
- * the name of the output
- * @return the output as a {@code JavaRDD<String>} in CSV format
- */
- public JavaRDD<String> getFrameJavaRDDStringCSV(String outputName, String delimiter) {
- FrameObject fo = getFrameObject(outputName);
- JavaRDD<String> javaRDDStringCSV = MLContextConversionUtil.frameObjectToJavaRDDStringCSV(fo, delimiter);
- return javaRDDStringCSV;
+ if (isMatrixObject(outputName)) {
+ MatrixObject mo = getMatrixObject(outputName);
+ JavaRDD<String> javaRDDStringCSV = MLContextConversionUtil.matrixObjectToJavaRDDStringCSV(mo);
+ return javaRDDStringCSV;
+ } else if (isFrameObject(outputName)) {
+ FrameObject fo = getFrameObject(outputName);
+ JavaRDD<String> javaRDDStringCSV = MLContextConversionUtil.frameObjectToJavaRDDStringCSV(fo, ",");
+ return javaRDDStringCSV;
+ }
+ return null;
}
/**
@@ -264,23 +229,16 @@ public class MLResults {
* @return the output as a {@code RDD<String>} in CSV format
*/
public RDD<String> getRDDStringCSV(String outputName) {
- MatrixObject mo = getMatrixObject(outputName);
- RDD<String> rddStringCSV = MLContextConversionUtil.matrixObjectToRDDStringCSV(mo);
- return rddStringCSV;
- }
-
-
- /**
- * Obtain an output as a {@code RDD<String>} in CSV format.
- *
- * @param outputName
- * the name of the output
- * @return the output as a {@code RDD<String>} in CSV format
- */
- public RDD<String> getFrameRDDStringCSV(String outputName, String delimiter) {
- FrameObject fo = getFrameObject(outputName);
- RDD<String> rddStringCSV = MLContextConversionUtil.frameObjectToRDDStringCSV(fo, delimiter);
- return rddStringCSV;
+ if (isMatrixObject(outputName)) {
+ MatrixObject mo = getMatrixObject(outputName);
+ RDD<String> rddStringCSV = MLContextConversionUtil.matrixObjectToRDDStringCSV(mo);
+ return rddStringCSV;
+ } else if (isFrameObject(outputName)) {
+ FrameObject fo = getFrameObject(outputName);
+ RDD<String> rddStringCSV = MLContextConversionUtil.frameObjectToRDDStringCSV(fo, ",");
+ return rddStringCSV;
+ }
+ return null;
}
/**
@@ -304,26 +262,21 @@ public class MLResults {
* @return the output as a {@code RDD<String>} in IJV format
*/
public RDD<String> getRDDStringIJV(String outputName) {
- MatrixObject mo = getMatrixObject(outputName);
- RDD<String> rddStringIJV = MLContextConversionUtil.matrixObjectToRDDStringIJV(mo);
- return rddStringIJV;
- }
-
- /**
- * Obtain an output as a {@code RDD<String>} in IJV format.
- *
- * @param outputName
- * the name of the output
- * @return the output as a {@code RDD<String>} in IJV format
- */
- public RDD<String> getFrameRDDStringIJV(String outputName) {
- FrameObject fo = getFrameObject(outputName);
- RDD<String> rddStringIJV = MLContextConversionUtil.frameObjectToRDDStringIJV(fo);
- return rddStringIJV;
+ if (isMatrixObject(outputName)) {
+ MatrixObject mo = getMatrixObject(outputName);
+ RDD<String> rddStringIJV = MLContextConversionUtil.matrixObjectToRDDStringIJV(mo);
+ return rddStringIJV;
+ } else if (isFrameObject(outputName)) {
+ FrameObject fo = getFrameObject(outputName);
+ RDD<String> rddStringIJV = MLContextConversionUtil.frameObjectToRDDStringIJV(fo);
+ return rddStringIJV;
+ }
+ return null;
}
/**
- * Obtain an output as a {@code DataFrame} of doubles with an ID column.
+ * Obtain an output as a {@code DataFrame}. If outputting a Matrix, this
+ * will be a DataFrame of doubles with an ID column.
* <p>
* The following matrix in DML:
* </p>
@@ -338,12 +291,53 @@ public class MLResults {
*
* @param outputName
* the name of the output
- * @return the output as a {@code DataFrame} of doubles with an ID column
+ * @return the output as a {@code DataFrame}
*/
public DataFrame getDataFrame(String outputName) {
- MatrixObject mo = getMatrixObject(outputName);
- DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, false);
- return df;
+ if (isMatrixObject(outputName)) {
+ MatrixObject mo = getMatrixObject(outputName);
+ DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, false);
+ return df;
+ } else if (isFrameObject(outputName)) {
+ FrameObject mo = getFrameObject(outputName);
+ DataFrame df = MLContextConversionUtil.frameObjectToDataFrame(mo, sparkExecutionContext);
+ return df;
+ }
+ return null;
+ }
+
+ /**
+ * Is the output a MatrixObject?
+ *
+ * @param outputName
+ * the name of the output
+ * @return {@code true} if the output is a MatrixObject, {@code false}
+ * otherwise.
+ */
+ private boolean isMatrixObject(String outputName) {
+ Data data = getData(outputName);
+ if (data instanceof MatrixObject) {
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ /**
+ * Is the output a FrameObject?
+ *
+ * @param outputName
+ * the name of the output
+ * @return {@code true} if the output is a FrameObject, {@code false}
+ * otherwise.
+ */
+ private boolean isFrameObject(String outputName) {
+ Data data = getData(outputName);
+ if (data instanceof FrameObject) {
+ return true;
+ } else {
+ return false;
+ }
}
/**
@@ -376,6 +370,9 @@ public class MLResults {
* ID column
*/
public DataFrame getDataFrame(String outputName, boolean isVectorDF) {
+ if (isFrameObject(outputName)) {
+ throw new MLContextException("This method currently supports only matrices");
+ }
MatrixObject mo = getMatrixObject(outputName);
DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, isVectorDF);
return df;
@@ -400,6 +397,9 @@ public class MLResults {
* @return the output as a {@code DataFrame} of doubles with an ID column
*/
public DataFrame getDataFrameDoubleWithIDColumn(String outputName) {
+ if (isFrameObject(outputName)) {
+ throw new MLContextException("This method currently supports only matrices");
+ }
MatrixObject mo = getMatrixObject(outputName);
DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, false);
return df;
@@ -424,6 +424,9 @@ public class MLResults {
* @return the output as a {@code DataFrame} of vectors with an ID column
*/
public DataFrame getDataFrameVectorWithIDColumn(String outputName) {
+ if (isFrameObject(outputName)) {
+ throw new MLContextException("This method currently supports only matrices");
+ }
MatrixObject mo = getMatrixObject(outputName);
DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, true);
return df;
@@ -448,6 +451,9 @@ public class MLResults {
* @return the output as a {@code DataFrame} of doubles with no ID column
*/
public DataFrame getDataFrameDoubleNoIDColumn(String outputName) {
+ if (isFrameObject(outputName)) {
+ throw new MLContextException("This method currently supports only matrices");
+ }
MatrixObject mo = getMatrixObject(outputName);
DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, false);
df = df.sort("ID").drop("ID");
@@ -473,6 +479,9 @@ public class MLResults {
* @return the output as a {@code DataFrame} of vectors with no ID column
*/
public DataFrame getDataFrameVectorNoIDColumn(String outputName) {
+ if (isFrameObject(outputName)) {
+ throw new MLContextException("This method currently supports only matrices");
+ }
MatrixObject mo = getMatrixObject(outputName);
DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, true);
df = df.sort("ID").drop("ID");
@@ -480,29 +489,29 @@ public class MLResults {
}
/**
- * Obtain an output as a {@code DataFrame} without an ID column.
+ * Obtain an output as a {@code Matrix}.
*
* @param outputName
* the name of the output
- * @return the output as a {@code DataFrame} without an ID column
+ * @return the output as a {@code Matrix}
*/
- public DataFrame getFrameDataFrame(String outputName) {
- FrameObject mo = getFrameObject(outputName);
- DataFrame df = MLContextConversionUtil.frameObjectToDataFrame(mo, sparkExecutionContext);
- return df;
+ public Matrix getMatrix(String outputName) {
+ MatrixObject mo = getMatrixObject(outputName);
+ Matrix matrix = new Matrix(mo, sparkExecutionContext);
+ return matrix;
}
/**
- * Obtain an output as a {@code Matrix}.
+ * Obtain an output as a {@code Frame}.
*
* @param outputName
* the name of the output
- * @return the output as a {@code Matrix}
+ * @return the output as a {@code Frame}
*/
- public Matrix getMatrix(String outputName) {
- MatrixObject mo = getMatrixObject(outputName);
- Matrix matrix = new Matrix(mo, sparkExecutionContext);
- return matrix;
+ public Frame getFrame(String outputName) {
+ FrameObject fo = getFrameObject(outputName);
+ Frame frame = new Frame(fo, sparkExecutionContext);
+ return frame;
}
/**
@@ -526,22 +535,10 @@ public class MLResults {
* the name of the output
* @return the output as a two-dimensional {@code String} array
*/
- public String[][] getFrame(String outputName) {
- try {
- Data data = getData(outputName);
- if (!(data instanceof FrameObject)) {
- throw new MLContextException("Variable '" + outputName + "' not a frame");
- }
- FrameObject fo = (FrameObject) data;
- FrameBlock fb = fo.acquireRead();
- String[][] frame = DataConverter.convertToStringFrame(fb);
- fo.release();
- return frame;
- } catch (CacheException e) {
- throw new MLContextException("Cache exception when reading frame", e);
- } catch (DMLRuntimeException e) {
- throw new MLContextException("DML runtime exception when reading frame", e);
- }
+ public String[][] getFrameAs2DStringArray(String outputName) {
+ FrameObject frameObject = getFrameObject(outputName);
+ String[][] frame = MLContextConversionUtil.frameObjectTo2DStringArray(frameObject);
+ return frame;
}
/**
@@ -569,8 +566,10 @@ public class MLResults {
if (data instanceof ScalarObject) {
ScalarObject so = (ScalarObject) data;
return so.getValue();
- } else if(data instanceof MatrixObject) {
+ } else if (data instanceof MatrixObject) {
return getMatrix(outputName);
+ } else if (data instanceof FrameObject) {
+ return getFrame(outputName);
} else {
return data;
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d39865e9/src/main/java/org/apache/sysml/api/mlcontext/MatrixMetadata.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MatrixMetadata.java b/src/main/java/org/apache/sysml/api/mlcontext/MatrixMetadata.java
index 1ea3a10..513b74d 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MatrixMetadata.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MatrixMetadata.java
@@ -27,7 +27,7 @@ import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
* columns per block in the matrix.
*
*/
-public class MatrixMetadata {
+public class MatrixMetadata extends Metadata {
private Long numRows = null;
private Long numColumns = null;
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d39865e9/src/main/java/org/apache/sysml/api/mlcontext/Metadata.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/Metadata.java b/src/main/java/org/apache/sysml/api/mlcontext/Metadata.java
new file mode 100644
index 0000000..c1c0a36
--- /dev/null
+++ b/src/main/java/org/apache/sysml/api/mlcontext/Metadata.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.mlcontext;
+
+/**
+ * Abstract metadata class for MLContext API. Complex types such as SystemML
+ * matrices and frames typically require metadata, so this abstract class serves
+ * as a common parent class of these types.
+ *
+ */
+public abstract class Metadata {
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d39865e9/src/main/java/org/apache/sysml/api/mlcontext/Script.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/Script.java b/src/main/java/org/apache/sysml/api/mlcontext/Script.java
index bfa947c..17a3996 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/Script.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/Script.java
@@ -68,13 +68,9 @@ public class Script {
*/
private Set<String> inputVariables = new LinkedHashSet<String>();
/**
- * The input variable type (if its frame of matrix).
+ * The input matrix or frame metadata if present.
*/
- private Map<String, Boolean> inputVariablesType = new LinkedHashMap<String, Boolean>();
- /**
- * The input matrix metadata if present.
- */
- private Map<String, MatrixMetadata> inputMatrixMetadata = new LinkedHashMap<String, MatrixMetadata>();
+ private Map<String, Metadata> inputMetadata = new LinkedHashMap<String, Metadata>();
/**
* The output variables.
*/
@@ -186,15 +182,6 @@ public class Script {
}
/**
- * Obtain the input variable type flag (if its frame or not)
- *
- * @return the input variable names
- */
- public Map<String, Boolean> getInputVariablesType() {
- return inputVariablesType;
- }
-
- /**
* Obtain the output variable names as an unmodifiable set of strings.
*
* @return the output variable names
@@ -223,12 +210,12 @@ public class Script {
}
/**
- * Obtain an unmodifiable map of input matrix metadata.
+ * Obtain an unmodifiable map of input matrix/frame metadata.
*
- * @return input matrix metadata
+ * @return input matrix/frame metadata
*/
- public Map<String, MatrixMetadata> getInputMatrixMetadata() {
- return Collections.unmodifiableMap(inputMatrixMetadata);
+ public Map<String, Metadata> getInputMetadata() {
+ return Collections.unmodifiableMap(inputMetadata);
}
/**
@@ -317,7 +304,7 @@ public class Script {
* @return {@code this} Script object to allow chaining of methods
*/
public Script in(String name, Object value) {
- return in(name, value, (MatrixMetadata) null);
+ return in(name, value, null);
}
/**
@@ -328,45 +315,11 @@ public class Script {
* name of the input
* @param value
* value of the input
- * @param matrixFormat
- * optional matrix format
+ * @param metadata
+ * optional matrix/frame metadata
* @return {@code this} Script object to allow chaining of methods
*/
- public Script in(String name, Object value, MatrixFormat matrixFormat) {
- MatrixMetadata matrixMetadata = new MatrixMetadata(matrixFormat);
- return in(name, value, matrixMetadata);
- }
-
- /**
- * Register an input (parameter ($) or variable) with optional matrix
- * metadata.
- *
- * @param name
- * name of the input
- * @param value
- * value of the input
- * @param matrixMetadata
- * optional matrix metadata
- * @return {@code this} Script object to allow chaining of methods
- */
- public Script in(String name, Object value, MatrixMetadata matrixMetadata) {
- return in(name, value, matrixMetadata, false);
- }
- /**
- * Register an input (parameter ($) or variable) with optional matrix
- * metadata.
- *
- * @param name
- * name of the input
- * @param value
- * value of the input
- * @param matrixMetadata
- * optional matrix metadata
- * @param bFrame
- * if input is of type frame
- * @return {@code this} Script object to allow chaining of methods
- */
- public Script in(String name, Object value, MatrixMetadata matrixMetadata, boolean bFrame) {
+ public Script in(String name, Object value, Metadata metadata) {
MLContextUtil.checkInputValueType(name, value);
if (inputs == null) {
inputs = new LinkedHashMap<String, Object>();
@@ -380,17 +333,13 @@ public class Script {
}
inputParameters.put(name, value);
} else {
- Data data = MLContextUtil.convertInputType(name, value, matrixMetadata, bFrame);
+ Data data = MLContextUtil.convertInputType(name, value, metadata);
if (data != null) {
symbolTable.put(name, data);
inputVariables.add(name);
- if (inputVariablesType == null) {
- inputVariablesType = new LinkedHashMap<String, Boolean>();
- }
- inputVariablesType.put(name, new Boolean(bFrame));
if (data instanceof MatrixObject || data instanceof FrameObject) {
- if (matrixMetadata != null) {
- inputMatrixMetadata.put(name, matrixMetadata);
+ if (metadata != null) {
+ inputMetadata.put(name, metadata);
}
}
}
@@ -454,8 +403,7 @@ public class Script {
inputs.clear();
inputParameters.clear();
inputVariables.clear();
- inputVariablesType.clear();
- inputMatrixMetadata.clear();
+ inputMetadata.clear();
}
/**
@@ -556,11 +504,10 @@ public class Script {
sb.append(" = " + quotedString + ";\n");
} else if (MLContextUtil.isBasicType(inValue)) {
sb.append(" = read('', data_type='scalar');\n");
+ } else if (MLContextUtil.doesSymbolTableContainFrameObject(symbolTable, in)) {
+ sb.append(" = read('', data_type='frame');\n");
} else {
- if(inputVariablesType.get(in).booleanValue())
- sb.append(" = read('', data_type='frame');\n");
- else
- sb.append(" = read('');\n");
+ sb.append(" = read('');\n");
}
} else if (isPYDML()) {
if (inValue instanceof String) {
@@ -568,11 +515,10 @@ public class Script {
sb.append(" = " + quotedString + "\n");
} else if (MLContextUtil.isBasicType(inValue)) {
sb.append(" = load('', data_type='scalar')\n");
+ } else if (MLContextUtil.doesSymbolTableContainFrameObject(symbolTable, in)) {
+ sb.append(" = load('', data_type='frame')\n");
} else {
- if(inputVariablesType.get(in).booleanValue())
- sb.append(" = load('', data_type='frame')\n");
- else
- sb.append(" = load('')\n");
+ sb.append(" = load('')\n");
}
}
@@ -603,7 +549,7 @@ public class Script {
public String toString() {
StringBuilder sb = new StringBuilder();
- sb.append(MLContextUtil.displayInputs("Inputs", inputs));
+ sb.append(MLContextUtil.displayInputs("Inputs", inputs, symbolTable));
sb.append("\n");
sb.append(MLContextUtil.displayOutputs("Outputs", outputVariables, symbolTable));
return sb.toString();
@@ -623,7 +569,7 @@ public class Script {
sb.append("Script Type: ");
sb.append(scriptType);
sb.append("\n\n");
- sb.append(MLContextUtil.displayInputs("Inputs", inputs));
+ sb.append(MLContextUtil.displayInputs("Inputs", inputs, symbolTable));
sb.append("\n");
sb.append(MLContextUtil.displayOutputs("Outputs", outputVariables, symbolTable));
sb.append("\n");
@@ -649,7 +595,7 @@ public class Script {
* @return the script inputs
*/
public String displayInputs() {
- return MLContextUtil.displayInputs("Inputs", inputs);
+ return MLContextUtil.displayInputs("Inputs", inputs, symbolTable);
}
/**
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d39865e9/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
index cf0d09f..2973ed2 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
@@ -348,14 +348,14 @@ public class ScriptExecutor {
*/
protected void restoreInputsInSymbolTable() {
Map<String, Object> inputs = script.getInputs();
- Map<String, MatrixMetadata> inputMatrixMetadata = script.getInputMatrixMetadata();
+ Map<String, Metadata> inputMetadata = script.getInputMetadata();
LocalVariableMap symbolTable = script.getSymbolTable();
Set<String> inputVariables = script.getInputVariables();
for (String inputVariable : inputVariables) {
if (symbolTable.get(inputVariable) == null) {
// retrieve optional metadata if it exists
- MatrixMetadata mm = inputMatrixMetadata.get(inputVariable);
- script.in(inputVariable, inputs.get(inputVariable), mm, script.getInputVariablesType().get(inputVariable));
+ Metadata m = inputMetadata.get(inputVariable);
+ script.in(inputVariable, inputs.get(inputVariable), m);
}
}
}
@@ -451,8 +451,8 @@ public class ScriptExecutor {
if (symbolTable != null) {
String[] inputs = (script.getInputVariables() == null) ? new String[0] : script.getInputVariables()
.toArray(new String[0]);
- String[] outputs = (script.getOutputVariables() == null) ? new String[0] : script.getOutputVariables()
- .toArray(new String[0]);
+ String[] outputs = (script.getOutputVariables() == null) ? new String[0]
+ : script.getOutputVariables().toArray(new String[0]);
RewriteRemovePersistentReadWrite rewrite = new RewriteRemovePersistentReadWrite(inputs, outputs);
ProgramRewriter programRewriter = new ProgramRewriter(rewrite);
try {
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d39865e9/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java
new file mode 100644
index 0000000..98c8b10
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java
@@ -0,0 +1,557 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.mlcontext;
+
+import static org.apache.sysml.api.mlcontext.ScriptFactory.dml;
+import static org.apache.sysml.api.mlcontext.ScriptFactory.pydml;
+
+import java.io.File;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.sysml.api.mlcontext.FrameFormat;
+import org.apache.sysml.api.mlcontext.FrameMetadata;
+import org.apache.sysml.api.mlcontext.MLContext;
+import org.apache.sysml.api.mlcontext.MLResults;
+import org.apache.sysml.api.mlcontext.MatrixFormat;
+import org.apache.sysml.api.mlcontext.MatrixMetadata;
+import org.apache.sysml.api.mlcontext.Script;
+import org.apache.sysml.parser.Expression.ValueType;
+import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.mlcontext.MLContextTest.CommaSeparatedValueStringToRow;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import scala.collection.Iterator;
+
+public class MLContextFrameTest extends AutomatedTestBase {
+ protected final static String TEST_DIR = "org/apache/sysml/api/mlcontext";
+ protected final static String TEST_NAME = "MLContextFrame";
+
+ public static enum SCRIPT_TYPE {
+ DML, PYDML
+ };
+
+ public static enum IO_TYPE {
+ ANY, FILE, JAVA_RDD_STR_CSV, JAVA_RDD_STR_IJV, RDD_STR_CSV, RDD_STR_IJV, DATAFRAME
+ };
+
+ private static SparkConf conf;
+ private static JavaSparkContext sc;
+ private static MLContext ml;
+
+ @BeforeClass
+ public static void setUpClass() {
+ if (conf == null)
+ conf = new SparkConf().setAppName("MLContextFrameTest").setMaster("local");
+ if (sc == null)
+ sc = new JavaSparkContext(conf);
+ ml = new MLContext(sc);
+ }
+
+ @Override
+ public void setUp() {
+ addTestConfiguration(TEST_DIR, TEST_NAME);
+ getAndLoadTestConfiguration(TEST_NAME);
+ }
+
+ @Test
+ public void testFrameJavaRDD_CSV_DML() {
+ testFrame(FrameFormat.CSV, SCRIPT_TYPE.DML, IO_TYPE.JAVA_RDD_STR_CSV, IO_TYPE.ANY);
+ }
+
+ @Test
+ public void testFrameJavaRDD_CSV_DML_OutJavaRddCSV() {
+ testFrame(FrameFormat.CSV, SCRIPT_TYPE.DML, IO_TYPE.JAVA_RDD_STR_CSV, IO_TYPE.JAVA_RDD_STR_CSV);
+ }
+
+ @Test
+ public void testFrameJavaRDD_CSV_PYDML() {
+ testFrame(FrameFormat.CSV, SCRIPT_TYPE.PYDML, IO_TYPE.JAVA_RDD_STR_CSV, IO_TYPE.ANY);
+ }
+
+ @Test
+ public void testFrameRDD_CSV_PYDML() {
+ testFrame(FrameFormat.CSV, SCRIPT_TYPE.PYDML, IO_TYPE.RDD_STR_CSV, IO_TYPE.ANY);
+ }
+
+ @Test
+ public void testFrameJavaRDD_CSV_PYDML_OutRddIJV() {
+ testFrame(FrameFormat.CSV, SCRIPT_TYPE.PYDML, IO_TYPE.JAVA_RDD_STR_CSV, IO_TYPE.RDD_STR_IJV);
+ }
+
+ @Test
+ public void testFrameJavaRDD_IJV_DML() {
+ testFrame(FrameFormat.IJV, SCRIPT_TYPE.DML, IO_TYPE.JAVA_RDD_STR_IJV, IO_TYPE.ANY);
+ }
+
+ @Test
+ public void testFrameRDD_IJV_DML() {
+ testFrame(FrameFormat.IJV, SCRIPT_TYPE.DML, IO_TYPE.RDD_STR_IJV, IO_TYPE.ANY);
+ }
+
+ @Test
+ public void testFrameJavaRDD_IJV_DML_OutRddCSV() {
+ testFrame(FrameFormat.IJV, SCRIPT_TYPE.DML, IO_TYPE.JAVA_RDD_STR_IJV, IO_TYPE.RDD_STR_CSV);
+ }
+
+ @Test
+ public void testFrameJavaRDD_IJV_PYDML() {
+ testFrame(FrameFormat.IJV, SCRIPT_TYPE.PYDML, IO_TYPE.JAVA_RDD_STR_IJV, IO_TYPE.ANY);
+ }
+
+ @Test
+ public void testFrameJavaRDD_IJV_PYDML_OutJavaRddIJV() {
+ testFrame(FrameFormat.IJV, SCRIPT_TYPE.PYDML, IO_TYPE.JAVA_RDD_STR_IJV, IO_TYPE.JAVA_RDD_STR_IJV);
+ }
+
+ @Test
+ public void testFrameFile_CSV_DML() {
+ testFrame(FrameFormat.CSV, SCRIPT_TYPE.DML, IO_TYPE.FILE, IO_TYPE.ANY);
+ }
+
+ @Test
+ public void testFrameFile_CSV_PYDML() {
+ testFrame(FrameFormat.CSV, SCRIPT_TYPE.PYDML, IO_TYPE.FILE, IO_TYPE.ANY);
+ }
+
+ @Test
+ public void testFrameFile_IJV_DML() {
+ testFrame(FrameFormat.IJV, SCRIPT_TYPE.DML, IO_TYPE.FILE, IO_TYPE.ANY);
+ }
+
+ @Test
+ public void testFrameFile_IJV_PYDML() {
+ testFrame(FrameFormat.IJV, SCRIPT_TYPE.PYDML, IO_TYPE.FILE, IO_TYPE.ANY);
+ }
+
+ @Test
+ public void testFrameDataFrame_CSV_DML() {
+ testFrame(FrameFormat.CSV, SCRIPT_TYPE.DML, IO_TYPE.DATAFRAME, IO_TYPE.ANY);
+ }
+
+ @Test
+ public void testFrameDataFrame_CSV_PYDML() {
+ testFrame(FrameFormat.CSV, SCRIPT_TYPE.PYDML, IO_TYPE.DATAFRAME, IO_TYPE.ANY);
+ }
+
+ @Test
+ public void testFrameDataFrameOutDataFrame_CSV_DML() {
+ testFrame(FrameFormat.CSV, SCRIPT_TYPE.DML, IO_TYPE.DATAFRAME, IO_TYPE.DATAFRAME);
+ }
+
+ public void testFrame(FrameFormat format, SCRIPT_TYPE script_type, IO_TYPE inputType, IO_TYPE outputType) {
+
+ System.out.println("MLContextTest - Frame JavaRDD<String> for format: " + format + " Script: " + script_type);
+
+ List<String> listA = new ArrayList<String>();
+ List<String> listB = new ArrayList<String>();
+ FrameMetadata fmA = null, fmB = null;
+ Script script = null;
+
+ if (inputType != IO_TYPE.FILE) {
+ if (format == FrameFormat.CSV) {
+ listA.add("1,Str2,3.0,true");
+ listA.add("4,Str5,6.0,false");
+ listA.add("7,Str8,9.0,true");
+
+ listB.add("Str12,13.0,true");
+ listB.add("Str25,26.0,false");
+
+ fmA = new FrameMetadata(FrameFormat.CSV, 3, 4);
+ fmB = new FrameMetadata(FrameFormat.CSV, 2, 3);
+ } else if (format == FrameFormat.IJV) {
+ listA.add("1 1 1");
+ listA.add("1 2 Str2");
+ listA.add("1 3 3.0");
+ listA.add("1 4 true");
+ listA.add("2 1 4");
+ listA.add("2 2 Str5");
+ listA.add("2 3 6.0");
+ listA.add("2 4 false");
+ listA.add("3 1 7");
+ listA.add("3 2 Str8");
+ listA.add("3 3 9.0");
+ listA.add("3 4 true");
+
+ listB.add("1 1 Str12");
+ listB.add("1 2 13.0");
+ listB.add("1 3 true");
+ listB.add("2 1 Str25");
+ listB.add("2 2 26.0");
+ listB.add("2 3 false");
+
+ fmA = new FrameMetadata(FrameFormat.IJV, 3, 4);
+ fmB = new FrameMetadata(FrameFormat.IJV, 2, 3);
+ }
+ JavaRDD<String> javaRDDA = sc.parallelize(listA);
+ JavaRDD<String> javaRDDB = sc.parallelize(listB);
+
+ if (inputType == IO_TYPE.DATAFRAME) {
+ JavaRDD<Row> javaRddRowA = javaRDDA.map(new MLContextTest.CommaSeparatedValueStringToRow());
+ JavaRDD<Row> javaRddRowB = javaRDDB.map(new MLContextTest.CommaSeparatedValueStringToRow());
+
+ ValueType[] schemaA = { ValueType.INT, ValueType.STRING, ValueType.DOUBLE, ValueType.BOOLEAN };
+ List<ValueType> lschemaA = Arrays.asList(schemaA);
+ ValueType[] schemaB = { ValueType.STRING, ValueType.DOUBLE, ValueType.BOOLEAN };
+ List<ValueType> lschemaB = Arrays.asList(schemaB);
+
+ // Create DataFrame
+ SQLContext sqlContext = new SQLContext(sc);
+ StructType dfSchemaA = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(lschemaA);
+ DataFrame dataFrameA = sqlContext.createDataFrame(javaRddRowA, dfSchemaA);
+ StructType dfSchemaB = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(lschemaB);
+ DataFrame dataFrameB = sqlContext.createDataFrame(javaRddRowB, dfSchemaB);
+ if (script_type == SCRIPT_TYPE.DML)
+ script = dml("A[2:3,2:4]=B;C=A[2:3,2:3]").in("A", dataFrameA, fmA).in("B", dataFrameB, fmB).out("A")
+ .out("C");
+ else if (script_type == SCRIPT_TYPE.PYDML)
+ // DO NOT USE ; at the end of any statment, it throws NPE
+ script = pydml("A[$X:$Y,$X:$Z]=B\nC=A[$X:$Y,$X:$Y]").in("A", dataFrameA, fmA)
+ .in("B", dataFrameB, fmB)
+ // Value for ROW index gets incremented at script
+ // level to adjust index in PyDML, but not for
+ // Column Index
+ .in("$X", 1).in("$Y", 3).in("$Z", 4).out("A").out("C");
+ } else {
+ if (inputType == IO_TYPE.JAVA_RDD_STR_CSV || inputType == IO_TYPE.JAVA_RDD_STR_IJV) {
+ if (script_type == SCRIPT_TYPE.DML)
+ script = dml("A[2:3,2:4]=B;C=A[2:3,2:3]").in("A", javaRDDA, fmA).in("B", javaRDDB, fmB).out("A")
+ .out("C");
+ else if (script_type == SCRIPT_TYPE.PYDML)
+ // DO NOT USE ; at the end of any statment, it throws
+ // NPE
+ script = pydml("A[$X:$Y,$X:$Z]=B\nC=A[$X:$Y,$X:$Y]").in("A", javaRDDA, fmA)
+ .in("B", javaRDDB, fmB)
+ // Value for ROW index gets incremented at
+ // script level to adjust index in PyDML, but
+ // not for Column Index
+ .in("$X", 1).in("$Y", 3).in("$Z", 4).out("A").out("C");
+ } else if (inputType == IO_TYPE.RDD_STR_CSV || inputType == IO_TYPE.RDD_STR_IJV) {
+ RDD<String> rddA = JavaRDD.toRDD(javaRDDA);
+ RDD<String> rddB = JavaRDD.toRDD(javaRDDB);
+
+ if (script_type == SCRIPT_TYPE.DML)
+ script = dml("A[2:3,2:4]=B;C=A[2:3,2:3]").in("A", rddA, fmA).in("B", rddB, fmB).out("A")
+ .out("C");
+ else if (script_type == SCRIPT_TYPE.PYDML)
+ // DO NOT USE ; at the end of any statment, it throws
+ // NPE
+ script = pydml("A[$X:$Y,$X:$Z]=B\nC=A[$X:$Y,$X:$Y]").in("A", rddA, fmA).in("B", rddB, fmB)
+ // Value for ROW index gets incremented at
+ // script level to adjust index in PyDML, but
+ // not for Column Index
+ .in("$X", 1).in("$Y", 3).in("$Z", 4).out("A").out("C");
+ }
+
+ }
+
+ } else { // Input type is file
+ String fileA = null, fileB = null;
+ if (format == FrameFormat.CSV) {
+ fileA = baseDirectory + File.separator + "FrameA.csv";
+ fileB = baseDirectory + File.separator + "FrameB.csv";
+ } else if (format == FrameFormat.IJV) {
+ fileA = baseDirectory + File.separator + "FrameA.ijv";
+ fileB = baseDirectory + File.separator + "FrameB.ijv";
+ }
+
+ if (script_type == SCRIPT_TYPE.DML)
+ script = dml("A=read($A); B=read($B);A[2:3,2:4]=B;C=A[2:3,2:3]").in("$A", fileA, fmA)
+ .in("$B", fileB, fmB).out("A").out("C");
+ else if (script_type == SCRIPT_TYPE.PYDML)
+ // DO NOT USE ; at the end of any statment, it throws NPE
+ script = pydml("A=load($A)\nB=load($B)\nA[$X:$Y,$X:$Z]=B\nC=A[$X:$Y,$X:$Y]").in("$A", fileA)
+ .in("$B", fileB)
+ // Value for ROW index gets incremented at script level
+ // to adjust index in PyDML, but not for Column Index
+ .in("$X", 1).in("$Y", 3).in("$Z", 4).out("A").out("C");
+ }
+
+ MLResults mlResults = ml.execute(script);
+
+ if (outputType == IO_TYPE.JAVA_RDD_STR_CSV) {
+
+ JavaRDD<String> javaRDDStringCSVA = mlResults.getJavaRDDStringCSV("A");
+ List<String> linesA = javaRDDStringCSVA.collect();
+ Assert.assertEquals("1,Str2,3.0,true", linesA.get(0));
+ Assert.assertEquals("4,Str12,13.0,true", linesA.get(1));
+ Assert.assertEquals("7,Str25,26.0,false", linesA.get(2));
+
+ JavaRDD<String> javaRDDStringCSVC = mlResults.getJavaRDDStringCSV("C");
+ List<String> linesC = javaRDDStringCSVC.collect();
+ Assert.assertEquals("Str12,13.0", linesC.get(0));
+ Assert.assertEquals("Str25,26.0", linesC.get(1));
+ } else if (outputType == IO_TYPE.JAVA_RDD_STR_IJV) {
+ JavaRDD<String> javaRDDStringIJVA = mlResults.getJavaRDDStringIJV("A");
+ List<String> linesA = javaRDDStringIJVA.collect();
+ Assert.assertEquals("1 1 1", linesA.get(0));
+ Assert.assertEquals("1 2 Str2", linesA.get(1));
+ Assert.assertEquals("1 3 3.0", linesA.get(2));
+ Assert.assertEquals("1 4 true", linesA.get(3));
+ Assert.assertEquals("2 1 4", linesA.get(4));
+ Assert.assertEquals("2 2 Str12", linesA.get(5));
+ Assert.assertEquals("2 3 13.0", linesA.get(6));
+ Assert.assertEquals("2 4 true", linesA.get(7));
+
+ JavaRDD<String> javaRDDStringIJVC = mlResults.getJavaRDDStringIJV("C");
+ List<String> linesC = javaRDDStringIJVC.collect();
+ Assert.assertEquals("1 1 Str12", linesC.get(0));
+ Assert.assertEquals("1 2 13.0", linesC.get(1));
+ Assert.assertEquals("2 1 Str25", linesC.get(2));
+ Assert.assertEquals("2 2 26.0", linesC.get(3));
+ } else if (outputType == IO_TYPE.RDD_STR_CSV) {
+ RDD<String> rddStringCSVA = mlResults.getRDDStringCSV("A");
+ Iterator<String> iteratorA = rddStringCSVA.toLocalIterator();
+ Assert.assertEquals("1,Str2,3.0,true", iteratorA.next());
+ Assert.assertEquals("4,Str12,13.0,true", iteratorA.next());
+ Assert.assertEquals("7,Str25,26.0,false", iteratorA.next());
+
+ RDD<String> rddStringCSVC = mlResults.getRDDStringCSV("C");
+ Iterator<String> iteratorC = rddStringCSVC.toLocalIterator();
+ Assert.assertEquals("Str12,13.0", iteratorC.next());
+ Assert.assertEquals("Str25,26.0", iteratorC.next());
+ } else if (outputType == IO_TYPE.RDD_STR_IJV) {
+ RDD<String> rddStringIJVA = mlResults.getRDDStringIJV("A");
+ Iterator<String> iteratorA = rddStringIJVA.toLocalIterator();
+ Assert.assertEquals("1 1 1", iteratorA.next());
+ Assert.assertEquals("1 2 Str2", iteratorA.next());
+ Assert.assertEquals("1 3 3.0", iteratorA.next());
+ Assert.assertEquals("1 4 true", iteratorA.next());
+ Assert.assertEquals("2 1 4", iteratorA.next());
+ Assert.assertEquals("2 2 Str12", iteratorA.next());
+ Assert.assertEquals("2 3 13.0", iteratorA.next());
+ Assert.assertEquals("2 4 true", iteratorA.next());
+ Assert.assertEquals("3 1 7", iteratorA.next());
+ Assert.assertEquals("3 2 Str25", iteratorA.next());
+ Assert.assertEquals("3 3 26.0", iteratorA.next());
+ Assert.assertEquals("3 4 false", iteratorA.next());
+
+ RDD<String> rddStringIJVC = mlResults.getRDDStringIJV("C");
+ Iterator<String> iteratorC = rddStringIJVC.toLocalIterator();
+ Assert.assertEquals("1 1 Str12", iteratorC.next());
+ Assert.assertEquals("1 2 13.0", iteratorC.next());
+ Assert.assertEquals("2 1 Str25", iteratorC.next());
+ Assert.assertEquals("2 2 26.0", iteratorC.next());
+
+ } else if (outputType == IO_TYPE.DATAFRAME) {
+
+ DataFrame dataFrameA = mlResults.getDataFrame("A");
+ List<Row> listAOut = dataFrameA.collectAsList();
+
+ Row row1 = listAOut.get(0);
+ Assert.assertEquals("Mistmatch with expected value", "1", row1.getString(0));
+ Assert.assertEquals("Mistmatch with expected value", "Str2", row1.getString(1));
+ Assert.assertEquals("Mistmatch with expected value", "3.0", row1.getString(2));
+ Assert.assertEquals("Mistmatch with expected value", "true", row1.getString(3));
+
+ Row row2 = listAOut.get(1);
+ Assert.assertEquals("Mistmatch with expected value", "4", row2.getString(0));
+ Assert.assertEquals("Mistmatch with expected value", "Str12", row2.getString(1));
+ Assert.assertEquals("Mistmatch with expected value", "13.0", row2.getString(2));
+ Assert.assertEquals("Mistmatch with expected value", "true", row2.getString(3));
+
+ DataFrame dataFrameC = mlResults.getDataFrame("C");
+ List<Row> listCOut = dataFrameC.collectAsList();
+
+ Row row3 = listCOut.get(0);
+ Assert.assertEquals("Mistmatch with expected value", "Str12", row3.getString(0));
+ Assert.assertEquals("Mistmatch with expected value", "13.0", row3.getString(1));
+
+ Row row4 = listCOut.get(1);
+ Assert.assertEquals("Mistmatch with expected value", "Str25", row4.getString(0));
+ Assert.assertEquals("Mistmatch with expected value", "26.0", row4.getString(1));
+ } else {
+ String[][] frameA = mlResults.getFrameAs2DStringArray("A");
+ Assert.assertEquals("Str2", frameA[0][1]);
+ Assert.assertEquals("3.0", frameA[0][2]);
+ Assert.assertEquals("13.0", frameA[1][2]);
+ Assert.assertEquals("true", frameA[1][3]);
+ Assert.assertEquals("Str25", frameA[2][1]);
+
+ String[][] frameC = mlResults.getFrameAs2DStringArray("C");
+ Assert.assertEquals("Str12", frameC[0][0]);
+ Assert.assertEquals("Str25", frameC[1][0]);
+ Assert.assertEquals("13.0", frameC[0][1]);
+ Assert.assertEquals("26.0", frameC[1][1]);
+ }
+ }
+
+ @Test
+ public void testOutputFrameDML() {
+ System.out.println("MLContextFrameTest - output frame DML");
+
+ String s = "M = read($Min, data_type='frame', format='csv');";
+ String csvFile = baseDirectory + File.separator + "one-two-three-four.csv";
+ Script script = dml(s).in("$Min", csvFile).out("M");
+ String[][] frame = ml.execute(script).getFrameAs2DStringArray("M");
+ Assert.assertEquals("one", frame[0][0]);
+ Assert.assertEquals("two", frame[0][1]);
+ Assert.assertEquals("three", frame[1][0]);
+ Assert.assertEquals("four", frame[1][1]);
+ }
+
+ @Test
+ public void testOutputFramePYDML() {
+ System.out.println("MLContextFrameTest - output frame PYDML");
+
+ String s = "M = load($Min, data_type='frame', format='csv')";
+ String csvFile = baseDirectory + File.separator + "one-two-three-four.csv";
+ Script script = pydml(s).in("$Min", csvFile).out("M");
+ String[][] frame = ml.execute(script).getFrameAs2DStringArray("M");
+ Assert.assertEquals("one", frame[0][0]);
+ Assert.assertEquals("two", frame[0][1]);
+ Assert.assertEquals("three", frame[1][0]);
+ Assert.assertEquals("four", frame[1][1]);
+ }
+
+ @Test
+ public void testInputFrameAndMatrixOutputMatrix() {
+ System.out.println("MLContextFrameTest - input frame and matrix, output matrix");
+
+ List<String> dataA = new ArrayList<String>();
+ dataA.add("Test1,4.0");
+ dataA.add("Test2,5.0");
+ dataA.add("Test3,6.0");
+ JavaRDD<String> javaRddStringA = sc.parallelize(dataA);
+
+ List<String> dataB = new ArrayList<String>();
+ dataB.add("1.0");
+ dataB.add("2.0");
+ JavaRDD<String> javaRddStringB = sc.parallelize(dataB);
+
+ JavaRDD<Row> javaRddRowA = javaRddStringA.map(new CommaSeparatedValueStringToRow());
+ JavaRDD<Row> javaRddRowB = javaRddStringB.map(new CommaSeparatedValueStringToRow());
+
+ SQLContext sqlContext = new SQLContext(sc);
+
+ List<StructField> fieldsA = new ArrayList<StructField>();
+ fieldsA.add(DataTypes.createStructField("1", DataTypes.StringType, true));
+ fieldsA.add(DataTypes.createStructField("2", DataTypes.DoubleType, true));
+ StructType schemaA = DataTypes.createStructType(fieldsA);
+ DataFrame dataFrameA = sqlContext.createDataFrame(javaRddRowA, schemaA);
+
+ List<StructField> fieldsB = new ArrayList<StructField>();
+ fieldsB.add(DataTypes.createStructField("1", DataTypes.DoubleType, true));
+ StructType schemaB = DataTypes.createStructType(fieldsB);
+ DataFrame dataFrameB = sqlContext.createDataFrame(javaRddRowB, schemaB);
+
+ String dmlString = "[tA, tAM] = transformencode (target = A, spec = \"{ids: true ,recode: [ 1, 2 ]}\");\n"
+ + "C = tA %*% B;\n" + "M = s * C;";
+
+ Script script = dml(dmlString)
+ .in("A", dataFrameA,
+ new FrameMetadata(FrameFormat.CSV, dataFrameA.count(), (long) dataFrameA.columns().length))
+ .in("B", dataFrameB,
+ new MatrixMetadata(MatrixFormat.CSV, dataFrameB.count(), (long) dataFrameB.columns().length))
+ .in("s", 2).out("M");
+ MLResults results = ml.execute(script);
+ double[][] matrix = results.getMatrixAs2DDoubleArray("M");
+ Assert.assertEquals(6.0, matrix[0][0], 0.0);
+ Assert.assertEquals(12.0, matrix[1][0], 0.0);
+ Assert.assertEquals(18.0, matrix[2][0], 0.0);
+ }
+
+ // NOTE: the ordering of the frame values seem to come out differently here
+ // than in the scala shell,
+ // so this should be investigated or explained.
+ // @Test
+ // public void testInputFrameOutputMatrixAndFrame() {
+ // System.out.println("MLContextFrameTest - input frame, output matrix and
+ // frame");
+ //
+ // List<String> dataA = new ArrayList<String>();
+ // dataA.add("Test1,Test4");
+ // dataA.add("Test2,Test5");
+ // dataA.add("Test3,Test6");
+ // JavaRDD<String> javaRddStringA = sc.parallelize(dataA);
+ //
+ // JavaRDD<Row> javaRddRowA = javaRddStringA.map(new
+ // CommaSeparatedValueStringToRow());
+ //
+ // SQLContext sqlContext = new SQLContext(sc);
+ //
+ // List<StructField> fieldsA = new ArrayList<StructField>();
+ // fieldsA.add(DataTypes.createStructField("1", DataTypes.StringType,
+ // true));
+ // fieldsA.add(DataTypes.createStructField("2", DataTypes.StringType,
+ // true));
+ // StructType schemaA = DataTypes.createStructType(fieldsA);
+ // DataFrame dataFrameA = sqlContext.createDataFrame(javaRddRowA, schemaA);
+ //
+ // String dmlString = "[tA, tAM] = transformencode (target = A, spec =
+ // \"{ids: true ,recode: [ 1, 2 ]}\");\n";
+ //
+ // Script script = dml(dmlString)
+ // .in("A", dataFrameA,
+ // new FrameMetadata(FrameFormat.CSV, dataFrameA.count(), (long)
+ // dataFrameA.columns().length))
+ // .out("tA", "tAM");
+ // MLResults results = ml.execute(script);
+ // double[][] matrix = results.getMatrixAs2DDoubleArray("tA");
+ // Assert.assertEquals(1.0, matrix[0][0], 0.0);
+ // Assert.assertEquals(1.0, matrix[0][1], 0.0);
+ // Assert.assertEquals(2.0, matrix[1][0], 0.0);
+ // Assert.assertEquals(2.0, matrix[1][1], 0.0);
+ // Assert.assertEquals(3.0, matrix[2][0], 0.0);
+ // Assert.assertEquals(3.0, matrix[2][1], 0.0);
+ //
+ // TODO: Add asserts for frame if ordering is as expected
+ // String[][] frame = results.getFrameAs2DStringArray("tAM");
+ // for (int i = 0; i < frame.length; i++) {
+ // for (int j = 0; j < frame[i].length; j++) {
+ // System.out.println("[" + i + "][" + j + "]:" + frame[i][j]);
+ // }
+ // }
+ // }
+
+ @After
+ public void tearDown() {
+ super.tearDown();
+ }
+
+ @AfterClass
+ public static void tearDownClass() {
+ // stop spark context to allow single jvm tests (otherwise the
+ // next test that tries to create a SparkContext would fail)
+ sc.stop();
+ sc = null;
+ conf = null;
+
+ // clear status mlcontext and spark exec context
+ ml.close();
+ ml = null;
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d39865e9/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
index fd220d9..0252b50 100644
--- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
@@ -38,7 +38,6 @@ import java.io.InputStream;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.ArrayList;
-import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -67,9 +66,7 @@ import org.apache.sysml.api.mlcontext.MatrixFormat;
import org.apache.sysml.api.mlcontext.MatrixMetadata;
import org.apache.sysml.api.mlcontext.Script;
import org.apache.sysml.api.mlcontext.ScriptExecutor;
-import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils;
import org.apache.sysml.test.integration.AutomatedTestBase;
import org.junit.After;
import org.junit.AfterClass;
@@ -87,9 +84,6 @@ public class MLContextTest extends AutomatedTestBase {
protected final static String TEST_DIR = "org/apache/sysml/api/mlcontext";
protected final static String TEST_NAME = "MLContext";
- public static enum SCRIPT_TYPE {DML, PYDML, SCALA};
- public static enum IO_TYPE {ANY, FILE, JAVA_RDD_STR_CSV, JAVA_RDD_STR_IJV, RDD_STR_CSV, RDD_STR_IJV, DATAFRAME};
-
private static SparkConf conf;
private static JavaSparkContext sc;
private static MLContext ml;
@@ -927,7 +921,7 @@ public class MLContextTest extends AutomatedTestBase {
public void testOutputDoubleArrayMatrixDML() {
System.out.println("MLContextTest - output double array matrix DML");
String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
- double[][] matrix = ml.execute(dml(s).out("M")).getDoubleMatrix("M");
+ double[][] matrix = ml.execute(dml(s).out("M")).getMatrixAs2DDoubleArray("M");
Assert.assertEquals(1.0, matrix[0][0], 0);
Assert.assertEquals(2.0, matrix[0][1], 0);
Assert.assertEquals(3.0, matrix[1][0], 0);
@@ -938,7 +932,7 @@ public class MLContextTest extends AutomatedTestBase {
public void testOutputDoubleArrayMatrixPYDML() {
System.out.println("MLContextTest - output double array matrix PYDML");
String s = "M = full('1 2 3 4', rows=2, cols=2)";
- double[][] matrix = ml.execute(pydml(s).out("M")).getDoubleMatrix("M");
+ double[][] matrix = ml.execute(pydml(s).out("M")).getMatrixAs2DDoubleArray("M");
Assert.assertEquals(1.0, matrix[0][0], 0);
Assert.assertEquals(2.0, matrix[0][1], 0);
Assert.assertEquals(3.0, matrix[1][0], 0);
@@ -1032,34 +1026,6 @@ public class MLContextTest extends AutomatedTestBase {
}
@Test
- public void testOutputFrameDML() {
- System.out.println("MLContextTest - output frame DML");
-
- String s = "M = read($Min, data_type='frame', format='csv');";
- String csvFile = baseDirectory + File.separator + "one-two-three-four.csv";
- Script script = dml(s).in("$Min", csvFile).out("M");
- String[][] frame = ml.execute(script).getFrame("M");
- Assert.assertEquals("one", frame[0][0]);
- Assert.assertEquals("two", frame[0][1]);
- Assert.assertEquals("three", frame[1][0]);
- Assert.assertEquals("four", frame[1][1]);
- }
-
- @Test
- public void testOutputFramePYDML() {
- System.out.println("MLContextTest - output frame PYDML");
-
- String s = "M = load($Min, data_type='frame', format='csv')";
- String csvFile = baseDirectory + File.separator + "one-two-three-four.csv";
- Script script = pydml(s).in("$Min", csvFile).out("M");
- String[][] frame = ml.execute(script).getFrame("M");
- Assert.assertEquals("one", frame[0][0]);
- Assert.assertEquals("two", frame[0][1]);
- Assert.assertEquals("three", frame[1][0]);
- Assert.assertEquals("four", frame[1][1]);
- }
-
- @Test
public void testOutputJavaRDDStringIJVDML() {
System.out.println("MLContextTest - output Java RDD String IJV DML");
@@ -1518,7 +1484,7 @@ public class MLContextTest extends AutomatedTestBase {
String s = "M = matrix('1 2 3 4', rows=2, cols=2); N = sum(M)";
// alternative to .out("M").out("N")
MLResults results = ml.execute(dml(s).out("M", "N"));
- double[][] matrix = results.getDoubleMatrix("M");
+ double[][] matrix = results.getMatrixAs2DDoubleArray("M");
double sum = results.getDouble("N");
Assert.assertEquals(1.0, matrix[0][0], 0);
Assert.assertEquals(2.0, matrix[0][1], 0);
@@ -1534,7 +1500,7 @@ public class MLContextTest extends AutomatedTestBase {
String s = "M = full('1 2 3 4', rows=2, cols=2)\nN = sum(M)";
// alternative to .out("M").out("N")
MLResults results = ml.execute(pydml(s).out("M", "N"));
- double[][] matrix = results.getDoubleMatrix("M");
+ double[][] matrix = results.getMatrixAs2DDoubleArray("M");
double sum = results.getDouble("N");
Assert.assertEquals(1.0, matrix[0][0], 0);
Assert.assertEquals(2.0, matrix[0][1], 0);
@@ -2262,6 +2228,7 @@ public class MLContextTest extends AutomatedTestBase {
setExpectedStdOut("sum: 45.0");
ml.execute(script);
}
+
// NOTE: Uncomment these tests once they work
// @SuppressWarnings({ "rawtypes", "unchecked" })
@@ -2330,322 +2297,6 @@ public class MLContextTest extends AutomatedTestBase {
// ml.execute(script);
// }
- ////////////////////////////////////////////
- // SystemML Frame MLContext testset Begin
- ////////////////////////////////////////////
- @Test
- public void testFrameJavaRDD_CSV_DML() {
- testFrame(MatrixFormat.CSV, SCRIPT_TYPE.DML, IO_TYPE.JAVA_RDD_STR_CSV, IO_TYPE.ANY);
- }
-
- @Test
- public void testFrameJavaRDD_CSV_DML_OutJavaRddCSV() {
- testFrame(MatrixFormat.CSV, SCRIPT_TYPE.DML, IO_TYPE.JAVA_RDD_STR_CSV, IO_TYPE.JAVA_RDD_STR_CSV);
- }
-
- @Test
- public void testFrameJavaRDD_CSV_PYDML() {
- testFrame(MatrixFormat.CSV, SCRIPT_TYPE.PYDML, IO_TYPE.JAVA_RDD_STR_CSV, IO_TYPE.ANY);
- }
-
- @Test
- public void testFrameRDD_CSV_PYDML() {
- testFrame(MatrixFormat.CSV, SCRIPT_TYPE.PYDML, IO_TYPE.RDD_STR_CSV, IO_TYPE.ANY);
- }
-
- @Test
- public void testFrameJavaRDD_CSV_PYDML_OutRddIJV() {
- testFrame(MatrixFormat.CSV, SCRIPT_TYPE.PYDML, IO_TYPE.JAVA_RDD_STR_CSV, IO_TYPE.RDD_STR_IJV);
- }
-
- @Test
- public void testFrameJavaRDD_IJV_DML() {
- testFrame(MatrixFormat.IJV, SCRIPT_TYPE.DML, IO_TYPE.JAVA_RDD_STR_IJV, IO_TYPE.ANY);
- }
-
- @Test
- public void testFrameRDD_IJV_DML() {
- testFrame(MatrixFormat.IJV, SCRIPT_TYPE.DML, IO_TYPE.RDD_STR_IJV, IO_TYPE.ANY);
- }
-
- @Test
- public void testFrameJavaRDD_IJV_DML_OutRddCSV() {
- testFrame(MatrixFormat.IJV, SCRIPT_TYPE.DML, IO_TYPE.JAVA_RDD_STR_IJV, IO_TYPE.RDD_STR_CSV);
- }
-
- @Test
- public void testFrameJavaRDD_IJV_PYDML() {
- testFrame(MatrixFormat.IJV, SCRIPT_TYPE.PYDML, IO_TYPE.JAVA_RDD_STR_IJV, IO_TYPE.ANY);
- }
-
- @Test
- public void testFrameJavaRDD_IJV_PYDML_OutJavaRddIJV() {
- testFrame(MatrixFormat.IJV, SCRIPT_TYPE.PYDML, IO_TYPE.JAVA_RDD_STR_IJV, IO_TYPE.JAVA_RDD_STR_IJV);
- }
-
- @Test
- public void testFrameFile_CSV_DML() {
- testFrame(MatrixFormat.CSV, SCRIPT_TYPE.DML, IO_TYPE.FILE, IO_TYPE.ANY);
- }
-
- @Test
- public void testFrameFile_CSV_PYDML() {
- testFrame(MatrixFormat.CSV, SCRIPT_TYPE.PYDML, IO_TYPE.FILE, IO_TYPE.ANY);
- }
-
- @Test
- public void testFrameFile_IJV_DML() {
- testFrame(MatrixFormat.IJV, SCRIPT_TYPE.DML, IO_TYPE.FILE, IO_TYPE.ANY);
- }
-
- @Test
- public void testFrameFile_IJV_PYDML() {
- testFrame(MatrixFormat.IJV, SCRIPT_TYPE.PYDML, IO_TYPE.FILE, IO_TYPE.ANY);
- }
-
- @Test
- public void testFrameDataFrame_CSV_DML() {
- testFrame(MatrixFormat.CSV, SCRIPT_TYPE.DML, IO_TYPE.DATAFRAME, IO_TYPE.ANY);
- }
-
- @Test
- public void testFrameDataFrame_CSV_PYDML() {
- testFrame(MatrixFormat.CSV, SCRIPT_TYPE.PYDML, IO_TYPE.DATAFRAME, IO_TYPE.ANY);
- }
-
- @Test
- public void testFrameDataFrameOutDataFrame_CSV_DML() {
- testFrame(MatrixFormat.CSV, SCRIPT_TYPE.DML, IO_TYPE.DATAFRAME, IO_TYPE.DATAFRAME);
- }
-
-
-
-
- public void testFrame(MatrixFormat format, SCRIPT_TYPE script_type, IO_TYPE inputType, IO_TYPE outputType) {
-
- System.out.println("MLContextTest - Frame JavaRDD<String> for format: " + format + " Script: " + script_type);
-
- List<String> listA = new ArrayList<String>();
- List<String> listB = new ArrayList<String>();
- MatrixMetadata mmA = null, mmB = null;
- Script script = null;
-
-
- if(inputType != IO_TYPE.FILE) {
- if(format == MatrixFormat.CSV) {
- listA.add("1,Str2,3.0,true");
- listA.add("4,Str5,6.0,false");
- listA.add("7,Str8,9.0,true");
-
- listB.add("Str12,13.0,true");
- listB.add("Str25,26.0,false");
-
- mmA = new MatrixMetadata(MatrixFormat.CSV, 3, 4);
- mmB = new MatrixMetadata(MatrixFormat.CSV, 2, 3);
- } else if(format == MatrixFormat.IJV) {
- listA.add("1 1 1");
- listA.add("1 2 Str2");
- listA.add("1 3 3.0");
- listA.add("1 4 true");
- listA.add("2 1 4");
- listA.add("2 2 Str5");
- listA.add("2 3 6.0");
- listA.add("2 4 false");
- listA.add("3 1 7");
- listA.add("3 2 Str8");
- listA.add("3 3 9.0");
- listA.add("3 4 true");
-
- listB.add("1 1 Str12");
- listB.add("1 2 13.0");
- listB.add("1 3 true");
- listB.add("2 1 Str25");
- listB.add("2 2 26.0");
- listB.add("2 3 false");
-
- mmA = new MatrixMetadata(MatrixFormat.IJV, 3, 4);
- mmB = new MatrixMetadata(MatrixFormat.IJV, 2, 3);
- }
- JavaRDD<String> javaRDDA = sc.parallelize(listA);
- JavaRDD<String> javaRDDB = sc.parallelize(listB);
-
- if(inputType == IO_TYPE.DATAFRAME) {
- JavaRDD<Row> javaRddRowA = javaRDDA.map(new CommaSeparatedValueStringToRow());
- JavaRDD<Row> javaRddRowB = javaRDDB.map(new CommaSeparatedValueStringToRow());
-
- ValueType[] schemaA = {ValueType.INT, ValueType.STRING, ValueType.DOUBLE, ValueType.BOOLEAN};
- List<ValueType> lschemaA = Arrays.asList(schemaA);
- ValueType[] schemaB = {ValueType.STRING, ValueType.DOUBLE, ValueType.BOOLEAN};
- List<ValueType> lschemaB = Arrays.asList(schemaB);
-
- //Create DataFrame
- SQLContext sqlContext = new SQLContext(sc);
- StructType dfSchemaA = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(lschemaA);
- DataFrame dataFrameA = sqlContext.createDataFrame(javaRddRowA, dfSchemaA);
- StructType dfSchemaB = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(lschemaB);
- DataFrame dataFrameB = sqlContext.createDataFrame(javaRddRowB, dfSchemaB);
- if (script_type == SCRIPT_TYPE.DML)
- script = dml("A[2:3,2:4]=B;C=A[2:3,2:3]").in("A", dataFrameA, mmA, true).in("B", dataFrameB, mmB, true).out("A").out("C");
- else if (script_type == SCRIPT_TYPE.PYDML)
- // DO NOT USE ; at the end of any statment, it throws NPE
- script = pydml("A[$X:$Y,$X:$Z]=B\nC=A[$X:$Y,$X:$Y]").in("A", dataFrameA, mmA, true).in("B", dataFrameB, mmB, true)
- // Value for ROW index gets incremented at script level to adjust index in PyDML, but not for Column Index
- .in("$X", 1).in("$Y", 3).in("$Z", 4).out("A").out("C");
- } else {
- if(inputType == IO_TYPE.JAVA_RDD_STR_CSV || inputType == IO_TYPE.JAVA_RDD_STR_IJV) {
- if (script_type == SCRIPT_TYPE.DML)
- script = dml("A[2:3,2:4]=B;C=A[2:3,2:3]").in("A", javaRDDA, mmA, true).in("B", javaRDDB, mmB, true).out("A").out("C");
- else if (script_type == SCRIPT_TYPE.PYDML)
- // DO NOT USE ; at the end of any statment, it throws NPE
- script = pydml("A[$X:$Y,$X:$Z]=B\nC=A[$X:$Y,$X:$Y]").in("A", javaRDDA, mmA, true).in("B", javaRDDB, mmB, true)
- // Value for ROW index gets incremented at script level to adjust index in PyDML, but not for Column Index
- .in("$X", 1).in("$Y", 3).in("$Z", 4).out("A").out("C");
- } else if(inputType == IO_TYPE.RDD_STR_CSV || inputType == IO_TYPE.RDD_STR_IJV) {
- RDD<String> rddA = JavaRDD.toRDD(javaRDDA);
- RDD<String> rddB = JavaRDD.toRDD(javaRDDB);
-
- if (script_type == SCRIPT_TYPE.DML)
- script = dml("A[2:3,2:4]=B;C=A[2:3,2:3]").in("A", rddA, mmA, true).in("B", rddB, mmB, true).out("A").out("C");
- else if (script_type == SCRIPT_TYPE.PYDML)
- // DO NOT USE ; at the end of any statment, it throws NPE
- script = pydml("A[$X:$Y,$X:$Z]=B\nC=A[$X:$Y,$X:$Y]").in("A", rddA, mmA, true).in("B", rddB, mmB, true)
- // Value for ROW index gets incremented at script level to adjust index in PyDML, but not for Column Index
- .in("$X", 1).in("$Y", 3).in("$Z", 4).out("A").out("C");
- }
-
- }
-
- } else { // Input type is file
- String fileA = null, fileB = null;
- if(format == MatrixFormat.CSV) {
- fileA = baseDirectory + File.separator + "FrameA.csv";
- fileB = baseDirectory + File.separator + "FrameB.csv";
- } else if(format == MatrixFormat.IJV) {
- fileA = baseDirectory + File.separator + "FrameA.ijv";
- fileB = baseDirectory + File.separator + "FrameB.ijv";
- }
-
- if (script_type == SCRIPT_TYPE.DML)
- script = dml("A=read($A); B=read($B);A[2:3,2:4]=B;C=A[2:3,2:3]").in("$A", fileA, mmA, true).in("$B", fileB, mmB, true).out("A").out("C");
- else if (script_type == SCRIPT_TYPE.PYDML)
- // DO NOT USE ; at the end of any statment, it throws NPE
- script = pydml("A=load($A)\nB=load($B)\nA[$X:$Y,$X:$Z]=B\nC=A[$X:$Y,$X:$Y]").in("$A", fileA).in("$B", fileB)
- // Value for ROW index gets incremented at script level to adjust index in PyDML, but not for Column Index
- .in("$X", 1).in("$Y", 3).in("$Z", 4).out("A").out("C");
- }
-
- MLResults mlResults = ml.execute(script);
-
- if(outputType == IO_TYPE.JAVA_RDD_STR_CSV) {
-
- JavaRDD<String> javaRDDStringCSVA = mlResults.getFrameJavaRDDStringCSV("A", ",");
- List<String> linesA = javaRDDStringCSVA.collect();
- Assert.assertEquals("1,Str2,3.0,true", linesA.get(0));
- Assert.assertEquals("4,Str12,13.0,true", linesA.get(1));
- Assert.assertEquals("7,Str25,26.0,false", linesA.get(2));
-
- JavaRDD<String> javaRDDStringCSVC = mlResults.getFrameJavaRDDStringCSV("C", ",");
- List<String> linesC = javaRDDStringCSVC.collect();
- Assert.assertEquals("Str12,13.0", linesC.get(0));
- Assert.assertEquals("Str25,26.0", linesC.get(1));
- } else if(outputType == IO_TYPE.JAVA_RDD_STR_IJV) {
- JavaRDD<String> javaRDDStringIJVA = mlResults.getFrameJavaRDDStringIJV("A");
- List<String> linesA = javaRDDStringIJVA.collect();
- Assert.assertEquals("1 1 1", linesA.get(0));
- Assert.assertEquals("1 2 Str2", linesA.get(1));
- Assert.assertEquals("1 3 3.0", linesA.get(2));
- Assert.assertEquals("1 4 true", linesA.get(3));
- Assert.assertEquals("2 1 4", linesA.get(4));
- Assert.assertEquals("2 2 Str12", linesA.get(5));
- Assert.assertEquals("2 3 13.0", linesA.get(6));
- Assert.assertEquals("2 4 true", linesA.get(7));
-
- JavaRDD<String> javaRDDStringIJVC = mlResults.getFrameJavaRDDStringIJV("C");
- List<String> linesC = javaRDDStringIJVC.collect();
- Assert.assertEquals("1 1 Str12", linesC.get(0));
- Assert.assertEquals("1 2 13.0", linesC.get(1));
- Assert.assertEquals("2 1 Str25", linesC.get(2));
- Assert.assertEquals("2 2 26.0", linesC.get(3));
- } else if(outputType == IO_TYPE.RDD_STR_CSV) {
- RDD<String> rddStringCSVA = mlResults.getFrameRDDStringCSV("A", ","); //TODO fix delimiter
- Iterator<String> iteratorA = rddStringCSVA.toLocalIterator();
- Assert.assertEquals("1,Str2,3.0,true", iteratorA.next());
- Assert.assertEquals("4,Str12,13.0,true", iteratorA.next());
- Assert.assertEquals("7,Str25,26.0,false", iteratorA.next());
-
- RDD<String> rddStringCSVC = mlResults.getFrameRDDStringCSV("C", ","); //TODO fix delimiter
- Iterator<String> iteratorC = rddStringCSVC.toLocalIterator();
- Assert.assertEquals("Str12,13.0", iteratorC.next());
- Assert.assertEquals("Str25,26.0", iteratorC.next());
- } else if(outputType == IO_TYPE.RDD_STR_IJV) {
- RDD<String> rddStringIJVA = mlResults.getFrameRDDStringIJV("A");
- Iterator<String> iteratorA = rddStringIJVA.toLocalIterator();
- Assert.assertEquals("1 1 1", iteratorA.next());
- Assert.assertEquals("1 2 Str2", iteratorA.next());
- Assert.assertEquals("1 3 3.0", iteratorA.next());
- Assert.assertEquals("1 4 true", iteratorA.next());
- Assert.assertEquals("2 1 4", iteratorA.next());
- Assert.assertEquals("2 2 Str12", iteratorA.next());
- Assert.assertEquals("2 3 13.0", iteratorA.next());
- Assert.assertEquals("2 4 true", iteratorA.next());
- Assert.assertEquals("3 1 7", iteratorA.next());
- Assert.assertEquals("3 2 Str25", iteratorA.next());
- Assert.assertEquals("3 3 26.0", iteratorA.next());
- Assert.assertEquals("3 4 false", iteratorA.next());
-
- RDD<String> rddStringIJVC = mlResults.getFrameRDDStringIJV("C");
- Iterator<String> iteratorC = rddStringIJVC.toLocalIterator();
- Assert.assertEquals("1 1 Str12", iteratorC.next());
- Assert.assertEquals("1 2 13.0", iteratorC.next());
- Assert.assertEquals("2 1 Str25", iteratorC.next());
- Assert.assertEquals("2 2 26.0", iteratorC.next());
-
- } else if(outputType == IO_TYPE.DATAFRAME) {
-
- DataFrame dataFrameA = mlResults.getFrameDataFrame("A");
- List<Row> listAOut = dataFrameA.collectAsList();
-
- Row row1 = listAOut.get(0);
- Assert.assertEquals("Mistmatch with expected value", "1", row1.getString(0));
- Assert.assertEquals("Mistmatch with expected value", "Str2", row1.getString(1));
- Assert.assertEquals("Mistmatch with expected value", "3.0", row1.getString(2));
- Assert.assertEquals("Mistmatch with expected value", "true", row1.getString(3));
-
- Row row2 = listAOut.get(1);
- Assert.assertEquals("Mistmatch with expected value", "4", row2.getString(0));
- Assert.assertEquals("Mistmatch with expected value", "Str12", row2.getString(1));
- Assert.assertEquals("Mistmatch with expected value", "13.0", row2.getString(2));
- Assert.assertEquals("Mistmatch with expected value", "true", row2.getString(3));
-
- DataFrame dataFrameC = mlResults.getFrameDataFrame("C");
- List<Row> listCOut = dataFrameC.collectAsList();
-
- Row row3 = listCOut.get(0);
- Assert.assertEquals("Mistmatch with expected value", "Str12", row3.getString(0));
- Assert.assertEquals("Mistmatch with expected value", "13.0", row3.getString(1));
-
- Row row4 = listCOut.get(1);
- Assert.assertEquals("Mistmatch with expected value", "Str25", row4.getString(0));
- Assert.assertEquals("Mistmatch with expected value", "26.0", row4.getString(1));
- } else {
- String[][] frameA = mlResults.getFrame("A");
- Assert.assertEquals("Str2", frameA[0][1]);
- Assert.assertEquals("3.0", frameA[0][2]);
- Assert.assertEquals("13.0", frameA[1][2]);
- Assert.assertEquals("true", frameA[1][3]);
- Assert.assertEquals("Str25", frameA[2][1]);
-
- String[][] frameC = mlResults.getFrame("C");
- Assert.assertEquals("Str12", frameC[0][0]);
- Assert.assertEquals("Str25", frameC[1][0]);
- Assert.assertEquals("13.0", frameC[0][1]);
- Assert.assertEquals("26.0", frameC[1][1]);
- }
- }
- ////////////////////////////////////////////
- // SystemML Frame MLContext testset End
- ////////////////////////////////////////////
-
@After
public void tearDown() {
super.tearDown();
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d39865e9/src/test_suites/java/org/apache/sysml/test/integration/mlcontext/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/mlcontext/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/mlcontext/ZPackageSuite.java
index 5687a55..387579f 100644
--- a/src/test_suites/java/org/apache/sysml/test/integration/mlcontext/ZPackageSuite.java
+++ b/src/test_suites/java/org/apache/sysml/test/integration/mlcontext/ZPackageSuite.java
@@ -27,7 +27,8 @@ import org.junit.runners.Suite;
* they should not be run in parallel. */
@RunWith(Suite.class)
@Suite.SuiteClasses({
- org.apache.sysml.test.integration.mlcontext.MLContextTest.class
+ org.apache.sysml.test.integration.mlcontext.MLContextTest.class,
+ org.apache.sysml.test.integration.mlcontext.MLContextFrameTest.class
})
[2/2] incubator-systemml git commit: [SYSTEMML-896] Additional
MLContext Frame support
Posted by de...@apache.org.
[SYSTEMML-896] Additional MLContext Frame support
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/d39865e9
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/d39865e9
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/d39865e9
Branch: refs/heads/master
Commit: d39865e9e3468cee0fab95cb9d8efe1ba4fe992f
Parents: b7657db
Author: Deron Eriksson <de...@us.ibm.com>
Authored: Sat Sep 3 23:10:40 2016 -0700
Committer: Deron Eriksson <de...@us.ibm.com>
Committed: Sat Sep 3 23:10:40 2016 -0700
----------------------------------------------------------------------
.../sysml/api/mlcontext/BinaryBlockFrame.java | 179 +++++
.../sysml/api/mlcontext/BinaryBlockMatrix.java | 4 +-
.../org/apache/sysml/api/mlcontext/Frame.java | 138 ++++
.../apache/sysml/api/mlcontext/FrameFormat.java | 42 ++
.../sysml/api/mlcontext/FrameMetadata.java | 695 +++++++++++++++++++
.../apache/sysml/api/mlcontext/FrameSchema.java | 128 ++++
.../api/mlcontext/MLContextConversionUtil.java | 250 ++++---
.../sysml/api/mlcontext/MLContextUtil.java | 306 +++++---
.../apache/sysml/api/mlcontext/MLResults.java | 247 ++++---
.../sysml/api/mlcontext/MatrixMetadata.java | 2 +-
.../apache/sysml/api/mlcontext/Metadata.java | 30 +
.../org/apache/sysml/api/mlcontext/Script.java | 100 +--
.../sysml/api/mlcontext/ScriptExecutor.java | 10 +-
.../mlcontext/MLContextFrameTest.java | 557 +++++++++++++++
.../integration/mlcontext/MLContextTest.java | 359 +---------
.../integration/mlcontext/ZPackageSuite.java | 3 +-
16 files changed, 2319 insertions(+), 731 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d39865e9/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockFrame.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockFrame.java b/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockFrame.java
new file mode 100644
index 0000000..88b1b38
--- /dev/null
+++ b/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockFrame.java
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.mlcontext;
+
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.sql.DataFrame;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+import org.apache.sysml.runtime.matrix.data.FrameBlock;
+
+/**
+ * BinaryBlockFrame stores data as a SystemML binary-block frame representation.
+ *
+ */
+public class BinaryBlockFrame {
+
+ JavaPairRDD<Long, FrameBlock> binaryBlocks;
+ FrameMetadata frameMetadata;
+
+ /**
+ * Convert a Spark DataFrame to a SystemML binary-block representation.
+ *
+ * @param dataFrame
+ * the Spark DataFrame
+ * @param frameMetadata
+ * frame metadata, such as number of rows and columns
+ */
+ public BinaryBlockFrame(DataFrame dataFrame, FrameMetadata frameMetadata) {
+ this.frameMetadata = frameMetadata;
+ binaryBlocks = MLContextConversionUtil.dataFrameToFrameBinaryBlocks(dataFrame, frameMetadata);
+ }
+
+ /**
+ * Convert a Spark DataFrame to a SystemML binary-block representation,
+ * specifying the number of rows and columns.
+ *
+ * @param dataFrame
+ * the Spark DataFrame
+ * @param numRows
+ * the number of rows
+ * @param numCols
+ * the number of columns
+ */
+ public BinaryBlockFrame(DataFrame dataFrame, long numRows, long numCols) {
+ this(dataFrame, new FrameMetadata(numRows, numCols, MLContextUtil.defaultBlockSize(),
+ MLContextUtil.defaultBlockSize()));
+ }
+
+ /**
+ * Convert a Spark DataFrame to a SystemML binary-block representation.
+ *
+ * @param dataFrame
+ * the Spark DataFrame
+ */
+ public BinaryBlockFrame(DataFrame dataFrame) {
+ this(dataFrame, new FrameMetadata());
+ }
+
+ /**
+ * Create a BinaryBlockFrame, specifying the SystemML binary-block frame and
+ * its metadata.
+ *
+ * @param binaryBlocks
+ * the {@code JavaPairRDD<Long, FrameBlock>} frame
+ * @param matrixCharacteristics
+ * the frame metadata as {@code MatrixCharacteristics}
+ */
+ public BinaryBlockFrame(JavaPairRDD<Long, FrameBlock> binaryBlocks, MatrixCharacteristics matrixCharacteristics) {
+ this.binaryBlocks = binaryBlocks;
+ this.frameMetadata = new FrameMetadata(matrixCharacteristics);
+ }
+
+ /**
+ * Create a BinaryBlockFrame, specifying the SystemML binary-block frame and
+ * its metadata.
+ *
+ * @param binaryBlocks
+ * the {@code JavaPairRDD<Long, FrameBlock>} frame
+ * @param frameMetadata
+ * the frame metadata as {@code FrameMetadata}
+ */
+ public BinaryBlockFrame(JavaPairRDD<Long, FrameBlock> binaryBlocks, FrameMetadata frameMetadata) {
+ this.binaryBlocks = binaryBlocks;
+ this.frameMetadata = frameMetadata;
+ }
+
+ /**
+ * Obtain a SystemML binary-block frame as a
+ * {@code JavaPairRDD<Long, FrameBlock>}
+ *
+ * @return the SystemML binary-block frame
+ */
+ public JavaPairRDD<Long, FrameBlock> getBinaryBlocks() {
+ return binaryBlocks;
+ }
+
+ /**
+ * Obtain a SystemML binary-block frame as a {@code FrameBlock}
+ *
+ * @return the SystemML binary-block frame as a {@code FrameBlock}
+ */
+ public FrameBlock getFrameBlock() {
+ try {
+ MatrixCharacteristics mc = getMatrixCharacteristics();
+ FrameSchema frameSchema = frameMetadata.getFrameSchema();
+ FrameBlock mb = SparkExecutionContext.toFrameBlock(binaryBlocks, frameSchema.getSchema(),
+ (int) mc.getRows(), (int) mc.getCols());
+ return mb;
+ } catch (DMLRuntimeException e) {
+ throw new MLContextException("Exception while getting FrameBlock from binary-block frame", e);
+ }
+ }
+
+ /**
+ * Obtain the SystemML binary-block frame characteristics
+ *
+ * @return the frame metadata as {@code MatrixCharacteristics}
+ */
+ public MatrixCharacteristics getMatrixCharacteristics() {
+ return frameMetadata.asMatrixCharacteristics();
+ }
+
+ /**
+ * Obtain the SystemML binary-block frame metadata
+ *
+ * @return the frame metadata as {@code FrameMetadata}
+ */
+ public FrameMetadata getFrameMetadata() {
+ return frameMetadata;
+ }
+
+ /**
+ * Set the SystemML binary-block frame metadata
+ *
+ * @param frameMetadata
+ * the frame metadata
+ */
+ public void setFrameMetadata(FrameMetadata frameMetadata) {
+ this.frameMetadata = frameMetadata;
+ }
+
+ /**
+ * Set the SystemML binary-block frame as a
+ * {@code JavaPairRDD<Long, FrameBlock>}
+ *
+ * @param binaryBlocks
+ * the SystemML binary-block frame
+ */
+ public void setBinaryBlocks(JavaPairRDD<Long, FrameBlock> binaryBlocks) {
+ this.binaryBlocks = binaryBlocks;
+ }
+
+ @Override
+ public String toString() {
+ if (frameMetadata != null) {
+ return frameMetadata.toString();
+ } else {
+ return super.toString();
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d39865e9/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java b/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java
index b13669d..ffa8a11 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java
@@ -28,7 +28,7 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
/**
- * BinaryBlockMatrix stores data as a SystemML binary-block representation.
+ * BinaryBlockMatrix stores data as a SystemML binary-block matrix representation.
*
*/
public class BinaryBlockMatrix {
@@ -46,7 +46,7 @@ public class BinaryBlockMatrix {
*/
public BinaryBlockMatrix(DataFrame dataFrame, MatrixMetadata matrixMetadata) {
this.matrixMetadata = matrixMetadata;
- binaryBlocks = MLContextConversionUtil.dataFrameToBinaryBlocks(dataFrame, matrixMetadata);
+ binaryBlocks = MLContextConversionUtil.dataFrameToMatrixBinaryBlocks(dataFrame, matrixMetadata);
}
/**
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d39865e9/src/main/java/org/apache/sysml/api/mlcontext/Frame.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/Frame.java b/src/main/java/org/apache/sysml/api/mlcontext/Frame.java
new file mode 100644
index 0000000..ee447df
--- /dev/null
+++ b/src/main/java/org/apache/sysml/api/mlcontext/Frame.java
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.mlcontext;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.sql.DataFrame;
+import org.apache.sysml.runtime.controlprogram.caching.FrameObject;
+import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+
+/**
+ * Frame encapsulates a SystemML frame and provides access to its data as a
+ * two-dimensional String array, Spark RDDs, a Spark DataFrame, or a
+ * {@code BinaryBlockFrame}.
+ *
+ */
+public class Frame {
+
+	private final FrameObject frameObject;
+	private final SparkExecutionContext sparkExecutionContext;
+
+	/**
+	 * Create a Frame that wraps a SystemML FrameObject.
+	 *
+	 * @param frameObject
+	 *            the SystemML FrameObject holding the frame data
+	 * @param sparkExecutionContext
+	 *            the SparkExecutionContext used for the {@code DataFrame} and
+	 *            {@code BinaryBlockFrame} conversions
+	 */
+	public Frame(FrameObject frameObject, SparkExecutionContext sparkExecutionContext) {
+		this.frameObject = frameObject;
+		this.sparkExecutionContext = sparkExecutionContext;
+	}
+
+	/**
+	 * Obtain the frame as a SystemML FrameObject.
+	 *
+	 * @return the frame as a SystemML FrameObject
+	 */
+	public FrameObject asFrameObject() {
+		return frameObject;
+	}
+
+	/**
+	 * Obtain the frame as a two-dimensional String array
+	 *
+	 * @return the frame as a two-dimensional String array
+	 */
+	public String[][] as2DStringArray() {
+		return MLContextConversionUtil.frameObjectTo2DStringArray(frameObject);
+	}
+
+	/**
+	 * Obtain the frame as a {@code JavaRDD<String>} in IJV format
+	 *
+	 * @return the frame as a {@code JavaRDD<String>} in IJV format
+	 */
+	public JavaRDD<String> asJavaRDDStringIJV() {
+		return MLContextConversionUtil.frameObjectToJavaRDDStringIJV(frameObject);
+	}
+
+	/**
+	 * Obtain the frame as a {@code JavaRDD<String>} in CSV format, using a
+	 * comma as the delimiter
+	 *
+	 * @return the frame as a {@code JavaRDD<String>} in CSV format
+	 */
+	public JavaRDD<String> asJavaRDDStringCSV() {
+		return MLContextConversionUtil.frameObjectToJavaRDDStringCSV(frameObject, ",");
+	}
+
+	/**
+	 * Obtain the frame as a {@code RDD<String>} in CSV format, using a comma as
+	 * the delimiter
+	 *
+	 * @return the frame as a {@code RDD<String>} in CSV format
+	 */
+	public RDD<String> asRDDStringCSV() {
+		return MLContextConversionUtil.frameObjectToRDDStringCSV(frameObject, ",");
+	}
+
+	/**
+	 * Obtain the frame as a {@code RDD<String>} in IJV format
+	 *
+	 * @return the frame as a {@code RDD<String>} in IJV format
+	 */
+	public RDD<String> asRDDStringIJV() {
+		return MLContextConversionUtil.frameObjectToRDDStringIJV(frameObject);
+	}
+
+	/**
+	 * Obtain the frame as a {@code DataFrame}
+	 *
+	 * @return the frame as a {@code DataFrame}
+	 */
+	public DataFrame asDataFrame() {
+		return MLContextConversionUtil.frameObjectToDataFrame(frameObject, sparkExecutionContext);
+	}
+
+	/**
+	 * Obtain the frame as a {@code BinaryBlockFrame}
+	 *
+	 * @return the frame as a {@code BinaryBlockFrame}
+	 */
+	public BinaryBlockFrame asBinaryBlockFrame() {
+		return MLContextConversionUtil.frameObjectToBinaryBlockFrame(frameObject, sparkExecutionContext);
+	}
+
+	/**
+	 * Obtain the frame metadata, built from the
+	 * {@code MatrixCharacteristics} of the underlying FrameObject
+	 *
+	 * @return the frame metadata
+	 */
+	public FrameMetadata getFrameMetadata() {
+		MatrixCharacteristics matrixCharacteristics = frameObject.getMatrixCharacteristics();
+		return new FrameMetadata(matrixCharacteristics);
+	}
+
+	/**
+	 * Returns the string form of the underlying FrameObject.
+	 */
+	@Override
+	public String toString() {
+		return frameObject.toString();
+	}
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d39865e9/src/main/java/org/apache/sysml/api/mlcontext/FrameFormat.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/FrameFormat.java b/src/main/java/org/apache/sysml/api/mlcontext/FrameFormat.java
new file mode 100644
index 0000000..bce8e5d
--- /dev/null
+++ b/src/main/java/org/apache/sysml/api/mlcontext/FrameFormat.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.mlcontext;
+
+/**
+ * Enumeration of the frame data formats accepted by the MLContext API.
+ *
+ */
+public enum FrameFormat {
+
+	/**
+	 * Dense format: comma-separated values.
+	 */
+	CSV,
+
+	/**
+	 * Sparse (I J V) format: each entry holds the row index I, the column
+	 * index J, and the value V, separated by spaces.
+	 */
+	IJV;
+
+	/**
+	 * Reports whether data in this format carries an ID column.
+	 *
+	 * @return {@code false} for every frame format
+	 */
+	public boolean hasIDColumn() {
+		final boolean idColumnPresent = false;
+		return idColumnPresent;
+	}
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d39865e9/src/main/java/org/apache/sysml/api/mlcontext/FrameMetadata.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/FrameMetadata.java b/src/main/java/org/apache/sysml/api/mlcontext/FrameMetadata.java
new file mode 100644
index 0000000..5aabd80
--- /dev/null
+++ b/src/main/java/org/apache/sysml/api/mlcontext/FrameMetadata.java
@@ -0,0 +1,695 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.mlcontext;
+
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+
+/**
+ * Frame metadata, such as the frame format, the frame schema, the number of
+ * rows, the number of columns, the number of non-zero values, the number of
+ * rows per block, and the number of columns per block in the frame. Every
+ * value is optional and is {@code null} until it is set.
+ *
+ */
+public class FrameMetadata extends Metadata {
+
+	private Long numRows = null;
+	private Long numColumns = null;
+	private Long numNonZeros = null;
+	private Integer numRowsPerBlock = null;
+	private Integer numColumnsPerBlock = null;
+	private FrameFormat frameFormat;
+	private FrameSchema frameSchema;
+
+	/**
+	 * Constructor to create a FrameMetadata object with no values set.
+	 */
+	public FrameMetadata() {
+	}
+
+	/**
+	 * Constructor to create a FrameMetadata object based on a string
+	 * representation of a frame schema.
+	 *
+	 * @param schema
+	 *            String representation of the frame schema.
+	 */
+	public FrameMetadata(String schema) {
+		this(new FrameSchema(schema));
+	}
+
+	/**
+	 * Constructor to create a FrameMetadata object based on frame format.
+	 *
+	 * @param frameFormat
+	 *            The frame format.
+	 */
+	public FrameMetadata(FrameFormat frameFormat) {
+		this(frameFormat, null, null, null, null, null, null);
+	}
+
+	/**
+	 * Constructor to create a FrameMetadata object based on frame schema.
+	 *
+	 * @param frameSchema
+	 *            The frame schema.
+	 */
+	public FrameMetadata(FrameSchema frameSchema) {
+		this(null, frameSchema, null, null, null, null, null);
+	}
+
+	/**
+	 * Constructor to create a FrameMetadata object based on frame format and
+	 * frame schema.
+	 *
+	 * @param frameFormat
+	 *            The frame format.
+	 * @param frameSchema
+	 *            The frame schema.
+	 */
+	public FrameMetadata(FrameFormat frameFormat, FrameSchema frameSchema) {
+		this(frameFormat, frameSchema, null, null, null, null, null);
+	}
+
+	/**
+	 * Constructor to create a FrameMetadata object based on frame format, frame
+	 * schema, the number of rows, and the number of columns in a frame.
+	 *
+	 * @param frameFormat
+	 *            The frame format.
+	 * @param frameSchema
+	 *            The frame schema.
+	 * @param numRows
+	 *            The number of rows in the frame.
+	 * @param numColumns
+	 *            The number of columns in the frame.
+	 */
+	public FrameMetadata(FrameFormat frameFormat, FrameSchema frameSchema, Long numRows, Long numColumns) {
+		this(frameFormat, frameSchema, numRows, numColumns, null, null, null);
+	}
+
+	/**
+	 * Constructor to create a FrameMetadata object based on frame format, frame
+	 * schema, the number of rows, and the number of columns in a frame.
+	 *
+	 * @param frameFormat
+	 *            The frame format.
+	 * @param frameSchema
+	 *            The frame schema.
+	 * @param numRows
+	 *            The number of rows in the frame.
+	 * @param numColumns
+	 *            The number of columns in the frame.
+	 */
+	public FrameMetadata(FrameFormat frameFormat, FrameSchema frameSchema, int numRows, int numColumns) {
+		this(frameFormat, frameSchema, (long) numRows, (long) numColumns, null, null, null);
+	}
+
+	/**
+	 * Constructor to create a FrameMetadata object based on frame format, frame
+	 * schema, the number of rows, the number of columns, the number of non-zero
+	 * values, the number of rows per block, and the number of columns per block
+	 * in a frame.
+	 *
+	 * @param frameFormat
+	 *            The frame format.
+	 * @param frameSchema
+	 *            The frame schema.
+	 * @param numRows
+	 *            The number of rows in the frame.
+	 * @param numColumns
+	 *            The number of columns in the frame.
+	 * @param numNonZeros
+	 *            The number of non-zero values in the frame.
+	 * @param numRowsPerBlock
+	 *            The number of rows per block in the frame.
+	 * @param numColumnsPerBlock
+	 *            The number of columns per block in the frame.
+	 */
+	public FrameMetadata(FrameFormat frameFormat, FrameSchema frameSchema, Long numRows, Long numColumns,
+			Long numNonZeros, Integer numRowsPerBlock, Integer numColumnsPerBlock) {
+		// Every other non-default constructor delegates here, so this is the
+		// only place where the fields are assigned during construction.
+		this.frameFormat = frameFormat;
+		this.frameSchema = frameSchema;
+		this.numRows = numRows;
+		this.numColumns = numColumns;
+		this.numNonZeros = numNonZeros;
+		this.numRowsPerBlock = numRowsPerBlock;
+		this.numColumnsPerBlock = numColumnsPerBlock;
+	}
+
+	/**
+	 * Constructor to create a FrameMetadata object based on frame format, frame
+	 * schema, the number of rows, the number of columns, the number of non-zero
+	 * values, the number of rows per block, and the number of columns per block
+	 * in a frame.
+	 *
+	 * @param frameFormat
+	 *            The frame format.
+	 * @param frameSchema
+	 *            The frame schema.
+	 * @param numRows
+	 *            The number of rows in the frame.
+	 * @param numColumns
+	 *            The number of columns in the frame.
+	 * @param numNonZeros
+	 *            The number of non-zero values in the frame.
+	 * @param numRowsPerBlock
+	 *            The number of rows per block in the frame.
+	 * @param numColumnsPerBlock
+	 *            The number of columns per block in the frame.
+	 */
+	public FrameMetadata(FrameFormat frameFormat, FrameSchema frameSchema, int numRows, int numColumns, int numNonZeros,
+			int numRowsPerBlock, int numColumnsPerBlock) {
+		// the (long) casts make only the boxed canonical constructor applicable
+		this(frameFormat, frameSchema, (long) numRows, (long) numColumns, (long) numNonZeros, numRowsPerBlock,
+				numColumnsPerBlock);
+	}
+
+	/**
+	 * Constructor to create a FrameMetadata object based on frame format, the
+	 * number of rows, and the number of columns in a frame.
+	 *
+	 * @param frameFormat
+	 *            The frame format.
+	 * @param numRows
+	 *            The number of rows in the frame.
+	 * @param numColumns
+	 *            The number of columns in the frame.
+	 */
+	public FrameMetadata(FrameFormat frameFormat, Long numRows, Long numColumns) {
+		this(frameFormat, null, numRows, numColumns, null, null, null);
+	}
+
+	/**
+	 * Constructor to create a FrameMetadata object based on frame format, the
+	 * number of rows, and the number of columns in a frame.
+	 *
+	 * @param frameFormat
+	 *            The frame format.
+	 * @param numRows
+	 *            The number of rows in the frame.
+	 * @param numColumns
+	 *            The number of columns in the frame.
+	 */
+	public FrameMetadata(FrameFormat frameFormat, int numRows, int numColumns) {
+		this(frameFormat, null, (long) numRows, (long) numColumns, null, null, null);
+	}
+
+	/**
+	 * Constructor to create a FrameMetadata object based on frame format, the
+	 * number of rows, the number of columns, and the number of non-zero values
+	 * in a frame.
+	 *
+	 * @param frameFormat
+	 *            The frame format.
+	 * @param numRows
+	 *            The number of rows in the frame.
+	 * @param numColumns
+	 *            The number of columns in the frame.
+	 * @param numNonZeros
+	 *            The number of non-zero values in the frame.
+	 */
+	public FrameMetadata(FrameFormat frameFormat, Long numRows, Long numColumns, Long numNonZeros) {
+		this(frameFormat, null, numRows, numColumns, numNonZeros, null, null);
+	}
+
+	/**
+	 * Constructor to create a FrameMetadata object based on frame format, the
+	 * number of rows, the number of columns, and the number of non-zero values
+	 * in a frame.
+	 *
+	 * @param frameFormat
+	 *            The frame format.
+	 * @param numRows
+	 *            The number of rows in the frame.
+	 * @param numColumns
+	 *            The number of columns in the frame.
+	 * @param numNonZeros
+	 *            The number of non-zero values in the frame.
+	 */
+	public FrameMetadata(FrameFormat frameFormat, int numRows, int numColumns, int numNonZeros) {
+		this(frameFormat, null, (long) numRows, (long) numColumns, (long) numNonZeros, null, null);
+	}
+
+	/**
+	 * Constructor to create a FrameMetadata object based on frame format, the
+	 * number of rows, the number of columns, the number of non-zero values, the
+	 * number of rows per block, and the number of columns per block in a frame.
+	 *
+	 * @param frameFormat
+	 *            The frame format.
+	 * @param numRows
+	 *            The number of rows in the frame.
+	 * @param numColumns
+	 *            The number of columns in the frame.
+	 * @param numNonZeros
+	 *            The number of non-zero values in the frame.
+	 * @param numRowsPerBlock
+	 *            The number of rows per block in the frame.
+	 * @param numColumnsPerBlock
+	 *            The number of columns per block in the frame.
+	 */
+	public FrameMetadata(FrameFormat frameFormat, Long numRows, Long numColumns, Long numNonZeros,
+			Integer numRowsPerBlock, Integer numColumnsPerBlock) {
+		this(frameFormat, null, numRows, numColumns, numNonZeros, numRowsPerBlock, numColumnsPerBlock);
+	}
+
+	/**
+	 * Constructor to create a FrameMetadata object based on frame format, the
+	 * number of rows, the number of columns, the number of non-zero values, the
+	 * number of rows per block, and the number of columns per block in a frame.
+	 *
+	 * @param frameFormat
+	 *            The frame format.
+	 * @param numRows
+	 *            The number of rows in the frame.
+	 * @param numColumns
+	 *            The number of columns in the frame.
+	 * @param numNonZeros
+	 *            The number of non-zero values in the frame.
+	 * @param numRowsPerBlock
+	 *            The number of rows per block in the frame.
+	 * @param numColumnsPerBlock
+	 *            The number of columns per block in the frame.
+	 */
+	public FrameMetadata(FrameFormat frameFormat, int numRows, int numColumns, int numNonZeros, int numRowsPerBlock,
+			int numColumnsPerBlock) {
+		this(frameFormat, null, (long) numRows, (long) numColumns, (long) numNonZeros, numRowsPerBlock,
+				numColumnsPerBlock);
+	}
+
+	/**
+	 * Constructor to create a FrameMetadata object based on the number of rows
+	 * and the number of columns in a frame.
+	 *
+	 * @param numRows
+	 *            The number of rows in the frame.
+	 * @param numColumns
+	 *            The number of columns in the frame.
+	 */
+	public FrameMetadata(Long numRows, Long numColumns) {
+		this(null, null, numRows, numColumns, null, null, null);
+	}
+
+	/**
+	 * Constructor to create a FrameMetadata object based on the number of rows
+	 * and the number of columns in a frame.
+	 *
+	 * @param numRows
+	 *            The number of rows in the frame.
+	 * @param numColumns
+	 *            The number of columns in the frame.
+	 */
+	public FrameMetadata(int numRows, int numColumns) {
+		this(null, null, (long) numRows, (long) numColumns, null, null, null);
+	}
+
+	/**
+	 * Constructor to create a FrameMetadata object based on the number of rows,
+	 * the number of columns, and the number of non-zero values in a frame.
+	 *
+	 * @param numRows
+	 *            The number of rows in the frame.
+	 * @param numColumns
+	 *            The number of columns in the frame.
+	 * @param numNonZeros
+	 *            The number of non-zero values in the frame.
+	 */
+	public FrameMetadata(Long numRows, Long numColumns, Long numNonZeros) {
+		this(null, null, numRows, numColumns, numNonZeros, null, null);
+	}
+
+	/**
+	 * Constructor to create a FrameMetadata object based on the number of rows,
+	 * the number of columns, and the number of non-zero values in a frame.
+	 *
+	 * @param numRows
+	 *            The number of rows in the frame.
+	 * @param numColumns
+	 *            The number of columns in the frame.
+	 * @param numNonZeros
+	 *            The number of non-zero values in the frame.
+	 */
+	public FrameMetadata(int numRows, int numColumns, int numNonZeros) {
+		this(null, null, (long) numRows, (long) numColumns, (long) numNonZeros, null, null);
+	}
+
+	/**
+	 * Constructor to create a FrameMetadata object based on the number of rows,
+	 * the number of columns, the number of rows per block, and the number of
+	 * columns per block in a frame.
+	 *
+	 * @param numRows
+	 *            The number of rows in the frame.
+	 * @param numColumns
+	 *            The number of columns in the frame.
+	 * @param numRowsPerBlock
+	 *            The number of rows per block in the frame.
+	 * @param numColumnsPerBlock
+	 *            The number of columns per block in the frame.
+	 */
+	public FrameMetadata(Long numRows, Long numColumns, Integer numRowsPerBlock, Integer numColumnsPerBlock) {
+		this(null, null, numRows, numColumns, null, numRowsPerBlock, numColumnsPerBlock);
+	}
+
+	/**
+	 * Constructor to create a FrameMetadata object based on the number of rows,
+	 * the number of columns, the number of rows per block, and the number of
+	 * columns per block in a frame.
+	 *
+	 * @param numRows
+	 *            The number of rows in the frame.
+	 * @param numColumns
+	 *            The number of columns in the frame.
+	 * @param numRowsPerBlock
+	 *            The number of rows per block in the frame.
+	 * @param numColumnsPerBlock
+	 *            The number of columns per block in the frame.
+	 */
+	public FrameMetadata(int numRows, int numColumns, int numRowsPerBlock, int numColumnsPerBlock) {
+		this(null, null, (long) numRows, (long) numColumns, null, numRowsPerBlock, numColumnsPerBlock);
+	}
+
+	/**
+	 * Constructor to create a FrameMetadata object based on the number of rows,
+	 * the number of columns, the number of non-zero values, the number of rows
+	 * per block, and the number of columns per block in a frame.
+	 *
+	 * @param numRows
+	 *            The number of rows in the frame.
+	 * @param numColumns
+	 *            The number of columns in the frame.
+	 * @param numNonZeros
+	 *            The number of non-zero values in the frame.
+	 * @param numRowsPerBlock
+	 *            The number of rows per block in the frame.
+	 * @param numColumnsPerBlock
+	 *            The number of columns per block in the frame.
+	 */
+	public FrameMetadata(Long numRows, Long numColumns, Long numNonZeros, Integer numRowsPerBlock,
+			Integer numColumnsPerBlock) {
+		this(null, null, numRows, numColumns, numNonZeros, numRowsPerBlock, numColumnsPerBlock);
+	}
+
+	/**
+	 * Constructor to create a FrameMetadata object based on the number of rows,
+	 * the number of columns, the number of non-zero values, the number of rows
+	 * per block, and the number of columns per block in a frame.
+	 *
+	 * @param numRows
+	 *            The number of rows in the frame.
+	 * @param numColumns
+	 *            The number of columns in the frame.
+	 * @param numNonZeros
+	 *            The number of non-zero values in the frame.
+	 * @param numRowsPerBlock
+	 *            The number of rows per block in the frame.
+	 * @param numColumnsPerBlock
+	 *            The number of columns per block in the frame.
+	 */
+	public FrameMetadata(int numRows, int numColumns, int numNonZeros, int numRowsPerBlock, int numColumnsPerBlock) {
+		this(null, null, (long) numRows, (long) numColumns, (long) numNonZeros, numRowsPerBlock, numColumnsPerBlock);
+	}
+
+	/**
+	 * Constructor to create a FrameMetadata object based on a
+	 * MatrixCharacteristics object.
+	 *
+	 * @param matrixCharacteristics
+	 *            the frame metadata as a MatrixCharacteristics object
+	 */
+	public FrameMetadata(MatrixCharacteristics matrixCharacteristics) {
+		this((FrameSchema) null, matrixCharacteristics);
+	}
+
+	/**
+	 * Constructor to create a FrameMetadata object based on the frame schema
+	 * and a MatrixCharacteristics object.
+	 *
+	 * @param frameSchema
+	 *            The frame schema.
+	 * @param matrixCharacteristics
+	 *            the frame metadata as a MatrixCharacteristics object
+	 */
+	public FrameMetadata(FrameSchema frameSchema, MatrixCharacteristics matrixCharacteristics) {
+		this(null, frameSchema, matrixCharacteristics.getRows(), matrixCharacteristics.getCols(),
+				matrixCharacteristics.getNonZeros(), matrixCharacteristics.getRowsPerBlock(),
+				matrixCharacteristics.getColsPerBlock());
+	}
+
+	/**
+	 * Set the FrameMetadata dimension fields (rows, columns, non-zeros, rows
+	 * per block, columns per block) based on a MatrixCharacteristics object.
+	 * The frame format and frame schema are left unchanged.
+	 *
+	 * @param matrixCharacteristics
+	 *            the frame metadata as a MatrixCharacteristics object
+	 */
+	public void setMatrixCharacteristics(MatrixCharacteristics matrixCharacteristics) {
+		this.numRows = matrixCharacteristics.getRows();
+		this.numColumns = matrixCharacteristics.getCols();
+		this.numNonZeros = matrixCharacteristics.getNonZeros();
+		this.numRowsPerBlock = matrixCharacteristics.getRowsPerBlock();
+		this.numColumnsPerBlock = matrixCharacteristics.getColsPerBlock();
+	}
+
+	/**
+	 * Obtain the number of rows
+	 *
+	 * @return the number of rows, or {@code null} if not set
+	 */
+	public Long getNumRows() {
+		return numRows;
+	}
+
+	/**
+	 * Set the number of rows
+	 *
+	 * @param numRows
+	 *            the number of rows
+	 */
+	public void setNumRows(Long numRows) {
+		this.numRows = numRows;
+	}
+
+	/**
+	 * Obtain the number of columns
+	 *
+	 * @return the number of columns, or {@code null} if not set
+	 */
+	public Long getNumColumns() {
+		return numColumns;
+	}
+
+	/**
+	 * Set the number of columns
+	 *
+	 * @param numColumns
+	 *            the number of columns
+	 */
+	public void setNumColumns(Long numColumns) {
+		this.numColumns = numColumns;
+	}
+
+	/**
+	 * Obtain the number of non-zero values
+	 *
+	 * @return the number of non-zero values, or {@code null} if not set
+	 */
+	public Long getNumNonZeros() {
+		return numNonZeros;
+	}
+
+	/**
+	 * Set the number of non-zero values
+	 *
+	 * @param numNonZeros
+	 *            the number of non-zero values
+	 */
+	public void setNumNonZeros(Long numNonZeros) {
+		this.numNonZeros = numNonZeros;
+	}
+
+	/**
+	 * Obtain the number of rows per block
+	 *
+	 * @return the number of rows per block, or {@code null} if not set
+	 */
+	public Integer getNumRowsPerBlock() {
+		return numRowsPerBlock;
+	}
+
+	/**
+	 * Set the number of rows per block
+	 *
+	 * @param numRowsPerBlock
+	 *            the number of rows per block
+	 */
+	public void setNumRowsPerBlock(Integer numRowsPerBlock) {
+		this.numRowsPerBlock = numRowsPerBlock;
+	}
+
+	/**
+	 * Obtain the number of columns per block
+	 *
+	 * @return the number of columns per block, or {@code null} if not set
+	 */
+	public Integer getNumColumnsPerBlock() {
+		return numColumnsPerBlock;
+	}
+
+	/**
+	 * Set the number of columns per block
+	 *
+	 * @param numColumnsPerBlock
+	 *            the number of columns per block
+	 */
+	public void setNumColumnsPerBlock(Integer numColumnsPerBlock) {
+		this.numColumnsPerBlock = numColumnsPerBlock;
+	}
+
+	/**
+	 * Convert the frame metadata to a MatrixCharacteristics object.
+	 * {@code null} is returned if the number of rows, the number of columns,
+	 * the number of non-zero values, the number of rows per block, and the
+	 * number of columns per block are all {@code null}; the frame format and
+	 * frame schema are not considered. Otherwise, unset row, column, and
+	 * non-zero counts become -1, and unset block sizes become
+	 * {@link MLContextUtil#defaultBlockSize()}.
+	 *
+	 * @return the frame metadata as a MatrixCharacteristics object, or
+	 *         {@code null} if all dimension values are null
+	 */
+	public MatrixCharacteristics asMatrixCharacteristics() {
+
+		if ((numRows == null) && (numColumns == null) && (numRowsPerBlock == null) && (numColumnsPerBlock == null)
+				&& (numNonZeros == null)) {
+			return null;
+		}
+
+		long nr = (numRows == null) ? -1 : numRows;
+		long nc = (numColumns == null) ? -1 : numColumns;
+		int nrpb = (numRowsPerBlock == null) ? MLContextUtil.defaultBlockSize() : numRowsPerBlock;
+		int ncpb = (numColumnsPerBlock == null) ? MLContextUtil.defaultBlockSize() : numColumnsPerBlock;
+		long nnz = (numNonZeros == null) ? -1 : numNonZeros;
+		MatrixCharacteristics mc = new MatrixCharacteristics(nr, nc, nrpb, ncpb, nnz);
+		return mc;
+	}
+
+	/**
+	 * Describe the dimension values of this metadata, showing "None" for any
+	 * value that is not set. The frame format and frame schema are not
+	 * included.
+	 */
+	@Override
+	public String toString() {
+		return "rows: " + fieldDisplay(numRows) + ", columns: " + fieldDisplay(numColumns) + ", non-zeros: "
+				+ fieldDisplay(numNonZeros) + ", rows per block: " + fieldDisplay(numRowsPerBlock)
+				+ ", columns per block: " + fieldDisplay(numColumnsPerBlock);
+	}
+
+	/**
+	 * Render a field for {@link #toString()}: "None" when {@code null},
+	 * otherwise the field's own string form.
+	 */
+	private String fieldDisplay(Object field) {
+		if (field == null) {
+			return "None";
+		} else {
+			return field.toString();
+		}
+	}
+
+	/**
+	 * Obtain the frame format
+	 *
+	 * @return the frame format, or {@code null} if not set
+	 */
+	public FrameFormat getFrameFormat() {
+		return frameFormat;
+	}
+
+	/**
+	 * Set the frame format
+	 *
+	 * @param frameFormat
+	 *            the frame format
+	 */
+	public void setFrameFormat(FrameFormat frameFormat) {
+		this.frameFormat = frameFormat;
+	}
+
+	/**
+	 * Obtain the frame schema
+	 *
+	 * @return the frame schema, or {@code null} if not set
+	 */
+	public FrameSchema getFrameSchema() {
+		return frameSchema;
+	}
+
+	/**
+	 * Set the frame schema
+	 *
+	 * @param frameSchema
+	 *            the frame schema
+	 */
+	public void setFrameSchema(FrameSchema frameSchema) {
+		this.frameSchema = frameSchema;
+	}
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d39865e9/src/main/java/org/apache/sysml/api/mlcontext/FrameSchema.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/FrameSchema.java b/src/main/java/org/apache/sysml/api/mlcontext/FrameSchema.java
new file mode 100644
index 0000000..040d77b
--- /dev/null
+++ b/src/main/java/org/apache/sysml/api/mlcontext/FrameSchema.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.mlcontext;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+
+import org.apache.commons.lang3.StringUtils;
+import org.apache.sysml.parser.Expression.ValueType;
+
+/**
+ * The frame schema, stored as a list of {@code ValueType} values.
+ *
+ */
+public class FrameSchema {
+
+	private List<ValueType> schema = null;
+
+	/**
+	 * Constructor to create a FrameSchema with no schema set.
+	 */
+	public FrameSchema() {
+	}
+
+	/**
+	 * Constructor that specifies the schema as a list of {@code ValueType}
+	 * values.
+	 *
+	 * @param schema
+	 *            the frame schema
+	 */
+	public FrameSchema(List<ValueType> schema) {
+		this.schema = schema;
+	}
+
+	/**
+	 * Constructor that specifies the schema as a comma-separated string of
+	 * {@code ValueType} names. Names are matched case-insensitively and
+	 * whitespace around each name is ignored. A blank string leaves the schema
+	 * unset ({@code null}).
+	 *
+	 * @param schema
+	 *            the frame schema as a string
+	 * @throws IllegalArgumentException
+	 *             if a name is not a {@code ValueType}
+	 */
+	public FrameSchema(String schema) {
+		this.schema = schemaStringToListOfValueTypes(schema);
+	}
+
+	/**
+	 * Obtain the frame schema
+	 *
+	 * @return the frame schema as a list of {@code ValueType} values
+	 */
+	public List<ValueType> getSchema() {
+		return schema;
+	}
+
+	/**
+	 * Set the frame schema
+	 *
+	 * @param schema
+	 *            the frame schema
+	 */
+	public void setSchema(List<ValueType> schema) {
+		this.schema = schema;
+	}
+
+	/**
+	 * Set the frame schema, specifying the frame schema as a comma-separated
+	 * string of {@code ValueType} names (case-insensitive, surrounding
+	 * whitespace ignored).
+	 *
+	 * @param schema
+	 *            the frame schema as a string
+	 * @throws IllegalArgumentException
+	 *             if a name is not a {@code ValueType}
+	 */
+	public void setSchemaAsString(String schema) {
+		this.schema = schemaStringToListOfValueTypes(schema);
+	}
+
+	/**
+	 * Convert a schema string to a list of {@code ValueType} values
+	 *
+	 * @param schemaString
+	 *            the frame schema as a string
+	 * @return the frame schema as a list of {@code ValueType} values, or
+	 *         {@code null} if {@code schemaString} is blank
+	 * @throws IllegalArgumentException
+	 *             if a name is not a {@code ValueType}
+	 */
+	private static List<ValueType> schemaStringToListOfValueTypes(String schemaString) {
+		if (StringUtils.isBlank(schemaString)) {
+			return null;
+		}
+		String[] cols = schemaString.split(",");
+		List<ValueType> list = new ArrayList<ValueType>(cols.length);
+		for (String col : cols) {
+			String name = col.trim();
+			try {
+				// Locale.ROOT: with the default locale, e.g. Turkish, "int"
+				// would upper-case to "\u0130NT" and match no ValueType.
+				list.add(ValueType.valueOf(name.toUpperCase(Locale.ROOT)));
+			} catch (IllegalArgumentException e) {
+				throw new IllegalArgumentException(
+						"Unrecognized value type '" + name + "' in frame schema \"" + schemaString + "\"", e);
+			}
+		}
+		return list;
+	}
+
+	/**
+	 * Obtain the schema as a comma-separated string
+	 *
+	 * @return the frame schema as a string, or {@code null} if the schema is
+	 *         unset or empty
+	 */
+	public String getSchemaAsString() {
+		if ((schema == null) || (schema.size() == 0)) {
+			return null;
+		}
+		StringBuilder sb = new StringBuilder();
+		for (int i = 0; i < schema.size(); i++) {
+			ValueType vt = schema.get(i);
+			sb.append(vt);
+			if (i + 1 < schema.size()) {
+				sb.append(",");
+			}
+		}
+		return sb.toString();
+	}
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d39865e9/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
index 7feb86a..b0f8432 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
@@ -198,16 +198,16 @@ public class MLContextConversionUtil {
* name of the variable associated with the frame
* @param frameBlock
* frame as a FrameBlock
- * @param matrixMetadata
- * the matrix metadata
+ * @param frameMetadata
+ * the frame metadata
* @return the {@code FrameBlock} converted to a {@code FrameObject}
*/
public static FrameObject frameBlockToFrameObject(String variableName, FrameBlock frameBlock,
- MatrixMetadata matrixMetadata) {
+ FrameMetadata frameMetadata) {
try {
MatrixCharacteristics matrixCharacteristics;
- if (matrixMetadata != null) {
- matrixCharacteristics = matrixMetadata.asMatrixCharacteristics();
+ if (frameMetadata != null) {
+ matrixCharacteristics = frameMetadata.asMatrixCharacteristics();
} else {
matrixCharacteristics = new MatrixCharacteristics();
}
@@ -273,16 +273,15 @@ public class MLContextConversionUtil {
}
/**
- * Convert a {@code JavaPairRDD<Long, FrameBlock>} to a
- * {@code FrameObject}.
+ * Convert a {@code JavaPairRDD<Long, FrameBlock>} to a {@code FrameObject}.
*
* @param variableName
* name of the variable associated with the frame
* @param binaryBlocks
- * {@code JavaPairRDD<Long, FrameBlock>} representation
- * of a binary-block frame
- * @return the {@code JavaPairRDD<Long, FrameBlock>} frame
- * converted to a {@code FrameObject}
+ * {@code JavaPairRDD<Long, FrameBlock>} representation of a
+ * binary-block frame
+ * @return the {@code JavaPairRDD<Long, FrameBlock>} frame converted to a
+ * {@code FrameObject}
*/
public static FrameObject binaryBlocksToFrameObject(String variableName,
JavaPairRDD<Long, FrameBlock> binaryBlocks) {
@@ -290,33 +289,32 @@ public class MLContextConversionUtil {
}
/**
- * Convert a {@code JavaPairRDD<Long, FrameBlock>} to a
- * {@code FrameObject}.
+ * Convert a {@code JavaPairRDD<Long, FrameBlock>} to a {@code FrameObject}.
*
* @param variableName
* name of the variable associated with the frame
* @param binaryBlocks
- * {@code JavaPairRDD<Long, FrameBlock>} representation
- * of a binary-block frame
- * @param matrixMetadata
- * the matrix metadata
- * @return the {@code JavaPairRDD<Long, FrameBlock>} frame
- * converted to a {@code FrameObject}
+ * {@code JavaPairRDD<Long, FrameBlock>} representation of a
+ * binary-block frame
+ * @param frameMetadata
+ * the frame metadata
+ * @return the {@code JavaPairRDD<Long, FrameBlock>} frame converted to a
+ * {@code FrameObject}
*/
- public static FrameObject binaryBlocksToFrameObject(String variableName,
- JavaPairRDD<Long, FrameBlock> binaryBlocks, MatrixMetadata matrixMetadata) {
+ public static FrameObject binaryBlocksToFrameObject(String variableName, JavaPairRDD<Long, FrameBlock> binaryBlocks,
+ FrameMetadata frameMetadata) {
MatrixCharacteristics matrixCharacteristics;
- if (matrixMetadata != null) {
- matrixCharacteristics = matrixMetadata.asMatrixCharacteristics();
+ if (frameMetadata != null) {
+ matrixCharacteristics = frameMetadata.asMatrixCharacteristics();
} else {
matrixCharacteristics = new MatrixCharacteristics();
}
+		// FrameMetadata.asMatrixCharacteristics() returns null when no dimension
+		// value is set (e.g. metadata built from only a format or a schema).
+		// Treat that like absent metadata rather than building the FrameObject
+		// with null MatrixCharacteristics.
+		if (matrixCharacteristics == null) {
+			matrixCharacteristics = new MatrixCharacteristics();
+		}
- MatrixFormatMetaData mtd = new MatrixFormatMetaData(matrixCharacteristics,
- OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);
- FrameObject frameObject = new FrameObject(MLContextUtil.scratchSpace() + "/" + "temp_"
- + System.nanoTime() + variableName, mtd);
+ MatrixFormatMetaData mtd = new MatrixFormatMetaData(matrixCharacteristics, OutputInfo.BinaryBlockOutputInfo,
+ InputInfo.BinaryBlockInputInfo);
+ FrameObject frameObject = new FrameObject(
+ MLContextUtil.scratchSpace() + "/" + "temp_" + System.nanoTime() + variableName, mtd);
frameObject.setRDDHandle(new RDDObject(binaryBlocks, variableName));
return frameObject;
}
@@ -352,8 +350,8 @@ public class MLContextConversionUtil {
if (matrixMetadata == null) {
matrixMetadata = new MatrixMetadata();
}
- JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlock = MLContextConversionUtil.dataFrameToBinaryBlocks(dataFrame,
- matrixMetadata);
+ JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlock = MLContextConversionUtil
+ .dataFrameToMatrixBinaryBlocks(dataFrame, matrixMetadata);
MatrixObject matrixObject = MLContextConversionUtil.binaryBlocksToMatrixObject(variableName, binaryBlock,
matrixMetadata);
return matrixObject;
@@ -368,9 +366,8 @@ public class MLContextConversionUtil {
* the Spark {@code DataFrame}
* @return the {@code DataFrame} matrix converted to a converted to a
* {@code FrameObject}
- * @throws DMLRuntimeException
*/
- public static FrameObject dataFrameToFrameObject(String variableName, DataFrame dataFrame) throws DMLRuntimeException {
+ public static FrameObject dataFrameToFrameObject(String variableName, DataFrame dataFrame) {
+		// Delegate with null metadata. The three-argument overload replaces null
+		// with an empty FrameMetadata and wraps DMLRuntimeException in
+		// MLContextException, so no checked exception is declared here.
return dataFrameToFrameObject(variableName, dataFrame, null);
}
@@ -381,28 +378,38 @@ public class MLContextConversionUtil {
* name of the variable associated with the frame
* @param dataFrame
* the Spark {@code DataFrame}
- * @param matrixMetadata
- * the matrix metadata
+ * @param frameMetadata
+ * the frame metadata
* @return the {@code DataFrame} frame converted to a converted to a
* {@code FrameObject}
- * @throws DMLRuntimeException
*/
public static FrameObject dataFrameToFrameObject(String variableName, DataFrame dataFrame,
- MatrixMetadata matrixMetadata) throws DMLRuntimeException {
- if (matrixMetadata == null) {
- matrixMetadata = new MatrixMetadata();
- }
+ FrameMetadata frameMetadata) {
+ try {
+ if (frameMetadata == null) {
+ frameMetadata = new FrameMetadata();
+ }
- JavaSparkContext javaSparkContext = MLContextUtil.getJavaSparkContext((MLContext) MLContextProxy.getActiveMLContext());
- boolean containsID = isDataFrameWithIDColumn(matrixMetadata);
- MatrixCharacteristics matrixCharacteristics = matrixMetadata.asMatrixCharacteristics();
- JavaPairRDD<Long, FrameBlock> binaryBlock =
- FrameRDDConverterUtils.dataFrameToBinaryBlock(javaSparkContext, dataFrame,
- matrixCharacteristics, containsID);
+ JavaSparkContext javaSparkContext = MLContextUtil
+ .getJavaSparkContext((MLContext) MLContextProxy.getActiveMLContext());
+ boolean containsID = isDataFrameWithIDColumn(frameMetadata);
+ MatrixCharacteristics matrixCharacteristics = frameMetadata.asMatrixCharacteristics();
+			// No dimensions supplied: infer them from the DataFrame and record
+			// them in the frame metadata.
+ if (matrixCharacteristics == null) {
+ matrixCharacteristics = new MatrixCharacteristics();
+ long rows = dataFrame.count();
+				// The ID column (present when containsID is true) identifies rows and
+				// is not one of the frame's data columns, so it is not counted.
+				int cols = dataFrame.columns().length - (containsID ? 1 : 0);
+ matrixCharacteristics.setDimension(rows, cols);
+ frameMetadata.setMatrixCharacteristics(matrixCharacteristics);
+ }
+ JavaPairRDD<Long, FrameBlock> binaryBlock = FrameRDDConverterUtils.dataFrameToBinaryBlock(javaSparkContext,
+ dataFrame, matrixCharacteristics, containsID);
- FrameObject frameObject = MLContextConversionUtil.binaryBlocksToFrameObject(variableName, binaryBlock,
- matrixMetadata);
- return frameObject;
+ FrameObject frameObject = MLContextConversionUtil.binaryBlocksToFrameObject(variableName, binaryBlock,
+ frameMetadata);
+ return frameObject;
+ } catch (DMLRuntimeException e) {
+ throw new MLContextException("Exception converting DataFrame to FrameObject", e);
+ }
}
/**
@@ -415,8 +422,8 @@ public class MLContextConversionUtil {
* {@code JavaPairRDD<MatrixIndexes,
* MatrixBlock>} binary-block matrix
*/
- public static JavaPairRDD<MatrixIndexes, MatrixBlock> dataFrameToBinaryBlocks(DataFrame dataFrame) {
- return dataFrameToBinaryBlocks(dataFrame, null);
+ public static JavaPairRDD<MatrixIndexes, MatrixBlock> dataFrameToMatrixBinaryBlocks(DataFrame dataFrame) {
+ return dataFrameToMatrixBinaryBlocks(dataFrame, null);
}
/**
@@ -431,7 +438,7 @@ public class MLContextConversionUtil {
* {@code JavaPairRDD<MatrixIndexes,
* MatrixBlock>} binary-block matrix
*/
- public static JavaPairRDD<MatrixIndexes, MatrixBlock> dataFrameToBinaryBlocks(DataFrame dataFrame,
+ public static JavaPairRDD<MatrixIndexes, MatrixBlock> dataFrameToMatrixBinaryBlocks(DataFrame dataFrame,
MatrixMetadata matrixMetadata) {
determineMatrixFormatIfNeeded(dataFrame, matrixMetadata);
@@ -467,6 +474,23 @@ public class MLContextConversionUtil {
}
/**
+ * Convert a {@code DataFrame} to a {@code JavaPairRDD<Long, FrameBlock>}
+ * binary-block frame.
+ *
+ * @param dataFrame
+ * the Spark {@code DataFrame}
+ * @param frameMetadata
+ * the frame metadata
+ * @return the {@code DataFrame} matrix converted to a
+ * {@code JavaPairRDD<Long,
+ * FrameBlock>} binary-block frame
+ */
+ public static JavaPairRDD<Long, FrameBlock> dataFrameToFrameBinaryBlocks(DataFrame dataFrame,
+ FrameMetadata frameMetadata) {
+ throw new MLContextException("dataFrameToFrameBinaryBlocks is unimplemented");
+ }
+
+ /**
* If the MatrixFormat of the DataFrame has not been explicitly specified,
* attempt to determine the proper MatrixFormat.
*
@@ -530,6 +554,25 @@ public class MLContextConversionUtil {
}
/**
+ * Return whether or not the DataFrame has an ID column.
+ *
+ * @param frameMetadata
+ * the frame metadata
+ * @return {@code true} if the DataFrame has an ID column, {@code false}
+ * otherwise.
+ */
+ public static boolean isDataFrameWithIDColumn(FrameMetadata frameMetadata) {
+ if (frameMetadata == null) {
+ return false;
+ }
+ FrameFormat frameFormat = frameMetadata.getFrameFormat();
+ if (frameFormat == null) {
+ return false;
+ }
+ return frameFormat.hasIDColumn();
+ }
+
+ /**
* Return whether or not the DataFrame is vector-based.
*
* @param matrixMetadata
@@ -645,27 +688,29 @@ public class MLContextConversionUtil {
* name of the variable associated with the frame
* @param javaRDD
* the Java RDD of strings
- * @param matrixMetadata
- * matrix metadata
+ * @param frameMetadata
+ * frame metadata
* @return the {@code JavaRDD<String>} converted to a {@code FrameObject}
*/
public static FrameObject javaRDDStringCSVToFrameObject(String variableName, JavaRDD<String> javaRDD,
- MatrixMetadata matrixMetadata) {
+ FrameMetadata frameMetadata) {
JavaPairRDD<LongWritable, Text> javaPairRDD = javaRDD.mapToPair(new ConvertStringToLongTextPair());
MatrixCharacteristics matrixCharacteristics;
- if (matrixMetadata != null) {
- matrixCharacteristics = matrixMetadata.asMatrixCharacteristics();
+ if (frameMetadata != null) {
+ matrixCharacteristics = frameMetadata.asMatrixCharacteristics();
} else {
matrixCharacteristics = new MatrixCharacteristics();
}
JavaPairRDD<LongWritable, Text> javaPairRDDText = javaPairRDD.mapToPair(new CopyTextInputFunction());
-
+
JavaSparkContext jsc = MLContextUtil.getJavaSparkContext((MLContext) MLContextProxy.getActiveMLContext());
- FrameObject frameObject = new FrameObject(null, new MatrixFormatMetaData(matrixCharacteristics, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo));
+ FrameObject frameObject = new FrameObject(null, new MatrixFormatMetaData(matrixCharacteristics,
+ OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo));
JavaPairRDD<Long, FrameBlock> rdd;
try {
- rdd = FrameRDDConverterUtils.csvToBinaryBlock(jsc, javaPairRDDText, matrixCharacteristics, false, ",", false, -1);
+ rdd = FrameRDDConverterUtils.csvToBinaryBlock(jsc, javaPairRDDText, matrixCharacteristics, false, ",",
+ false, -1);
} catch (DMLRuntimeException e) {
e.printStackTrace();
return null;
@@ -710,30 +755,31 @@ public class MLContextConversionUtil {
* name of the variable associated with the frame
* @param javaRDD
* the Java RDD of strings
- * @param matrixMetadata
- * matrix metadata
+ * @param frameMetadata
+ * frame metadata
* @return the {@code JavaRDD<String>} converted to a {@code FrameObject}
*/
public static FrameObject javaRDDStringIJVToFrameObject(String variableName, JavaRDD<String> javaRDD,
- MatrixMetadata matrixMetadata) {
+ FrameMetadata frameMetadata) {
JavaPairRDD<LongWritable, Text> javaPairRDD = javaRDD.mapToPair(new ConvertStringToLongTextPair());
MatrixCharacteristics matrixCharacteristics;
- if (matrixMetadata != null) {
- matrixCharacteristics = matrixMetadata.asMatrixCharacteristics();
+ if (frameMetadata != null) {
+ matrixCharacteristics = frameMetadata.asMatrixCharacteristics();
} else {
matrixCharacteristics = new MatrixCharacteristics();
}
-
+
JavaPairRDD<LongWritable, Text> javaPairRDDText = javaPairRDD.mapToPair(new CopyTextInputFunction());
-
+
JavaSparkContext jsc = MLContextUtil.getJavaSparkContext((MLContext) MLContextProxy.getActiveMLContext());
- FrameObject frameObject = new FrameObject(null, new MatrixFormatMetaData(matrixCharacteristics, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo));
+ FrameObject frameObject = new FrameObject(null, new MatrixFormatMetaData(matrixCharacteristics,
+ OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo));
JavaPairRDD<Long, FrameBlock> rdd;
try {
List<ValueType> lschema = null;
- if(lschema == null)
- lschema = Collections.nCopies((int)matrixCharacteristics.getCols(), ValueType.STRING);
+ if (lschema == null)
+ lschema = Collections.nCopies((int) matrixCharacteristics.getCols(), ValueType.STRING);
rdd = FrameRDDConverterUtils.textCellToBinaryBlock(jsc, javaPairRDDText, matrixCharacteristics, lschema);
} catch (DMLRuntimeException e) {
e.printStackTrace();
@@ -794,15 +840,15 @@ public class MLContextConversionUtil {
* name of the variable associated with the frame
* @param rdd
* the RDD of strings
- * @param matrixMetadata
+ * @param frameMetadata
* frame metadata
* @return the {@code RDD<String>} converted to a {@code FrameObject}
*/
public static FrameObject rddStringCSVToFrameObject(String variableName, RDD<String> rdd,
- MatrixMetadata matrixMetadata) {
+ FrameMetadata frameMetadata) {
ClassTag<String> tag = scala.reflect.ClassTag$.MODULE$.apply(String.class);
JavaRDD<String> javaRDD = JavaRDD.fromRDD(rdd, tag);
- return javaRDDStringCSVToFrameObject(variableName, javaRDD, matrixMetadata);
+ return javaRDDStringCSVToFrameObject(variableName, javaRDD, frameMetadata);
}
/**
@@ -832,15 +878,15 @@ public class MLContextConversionUtil {
* name of the variable associated with the frame
* @param rdd
* the RDD of strings
- * @param matrixMetadata
+ * @param frameMetadata
* frame metadata
* @return the {@code RDD<String>} converted to a {@code FrameObject}
*/
public static FrameObject rddStringIJVToFrameObject(String variableName, RDD<String> rdd,
- MatrixMetadata matrixMetadata) {
+ FrameMetadata frameMetadata) {
ClassTag<String> tag = scala.reflect.ClassTag$.MODULE$.apply(String.class);
JavaRDD<String> javaRDD = JavaRDD.fromRDD(rdd, tag);
- return javaRDDStringIJVToFrameObject(variableName, javaRDD, matrixMetadata);
+ return javaRDDStringIJVToFrameObject(variableName, javaRDD, frameMetadata);
}
/**
@@ -898,8 +944,7 @@ public class MLContextConversionUtil {
}
/**
- * Convert a {@code FrameObject} to a {@code JavaRDD<String>} in CSV
- * format.
+ * Convert a {@code FrameObject} to a {@code JavaRDD<String>} in CSV format.
*
* @param frameObject
* the {@code FrameObject}
@@ -908,7 +953,7 @@ public class MLContextConversionUtil {
public static JavaRDD<String> frameObjectToJavaRDDStringCSV(FrameObject frameObject, String delimiter) {
List<String> list = frameObjectToListStringCSV(frameObject, delimiter);
- JavaSparkContext jsc = MLContextUtil.getJavaSparkContext((MLContext)MLContextProxy.getActiveMLContext());
+ JavaSparkContext jsc = MLContextUtil.getJavaSparkContext((MLContext) MLContextProxy.getActiveMLContext());
JavaRDD<String> javaRDDStringCSV = jsc.parallelize(list);
return javaRDDStringCSV;
}
@@ -933,8 +978,7 @@ public class MLContextConversionUtil {
}
/**
- * Convert a {@code FrameObject} to a {@code JavaRDD<String>} in IJV
- * format.
+ * Convert a {@code FrameObject} to a {@code JavaRDD<String>} in IJV format.
*
* @param frameObject
* the {@code FrameObject}
@@ -943,7 +987,7 @@ public class MLContextConversionUtil {
public static JavaRDD<String> frameObjectToJavaRDDStringIJV(FrameObject frameObject) {
List<String> list = frameObjectToListStringIJV(frameObject);
- JavaSparkContext jsc = MLContextUtil.getJavaSparkContext((MLContext)MLContextProxy.getActiveMLContext());
+ JavaSparkContext jsc = MLContextUtil.getJavaSparkContext((MLContext) MLContextProxy.getActiveMLContext());
JavaRDD<String> javaRDDStringIJV = jsc.parallelize(list);
return javaRDDStringIJV;
}
@@ -1343,5 +1387,49 @@ public class MLContextConversionUtil {
throw new MLContextException("DMLRuntimeException while converting matrix object to BinaryBlockMatrix", e);
}
}
-
+
+ /**
+ * Convert a {@code FrameObject} to a {@code BinaryBlockFrame}.
+ *
+ * @param frameObject
+ * the {@code FrameObject}
+ * @param sparkExecutionContext
+ * the Spark execution context
+ * @return the {@code FrameObject} converted to a {@code BinaryBlockFrame}
+ */
+ public static BinaryBlockFrame frameObjectToBinaryBlockFrame(FrameObject frameObject,
+ SparkExecutionContext sparkExecutionContext) {
+ try {
+ @SuppressWarnings("unchecked")
+ JavaPairRDD<Long, FrameBlock> binaryBlock = (JavaPairRDD<Long, FrameBlock>) sparkExecutionContext
+ .getRDDHandleForFrameObject(frameObject, InputInfo.BinaryBlockInputInfo);
+ MatrixCharacteristics matrixCharacteristics = frameObject.getMatrixCharacteristics();
+ FrameSchema fs = new FrameSchema(frameObject.getSchema());
+ FrameMetadata fm = new FrameMetadata(fs, matrixCharacteristics);
+ BinaryBlockFrame binaryBlockFrame = new BinaryBlockFrame(binaryBlock, fm);
+ return binaryBlockFrame;
+ } catch (DMLRuntimeException e) {
+ throw new MLContextException("DMLRuntimeException while converting frame object to BinaryBlockFrame", e);
+ }
+ }
+
+ /**
+ * Convert a {@code FrameObject} to a two-dimensional string array.
+ *
+ * @param frameObject
+ * the {@code FrameObject}
+ * @return the {@code FrameObject} converted to a {@code String[][]}
+ */
+ public static String[][] frameObjectTo2DStringArray(FrameObject frameObject) {
+ try {
+ FrameBlock fb = frameObject.acquireRead();
+ String[][] frame = DataConverter.convertToStringFrame(fb);
+ frameObject.release();
+ return frame;
+ } catch (CacheException e) {
+ throw new MLContextException("CacheException while converting frame object to 2D string array", e);
+ } catch (DMLRuntimeException e) {
+ throw new MLContextException("DMLRuntimeException while converting frame object to 2D string array", e);
+ }
+ }
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d39865e9/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
index 9813174..566fba1 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
@@ -25,7 +25,6 @@ import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.HashMap;
-import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
@@ -39,14 +38,15 @@ import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
import org.apache.sysml.conf.CompilerConfig;
import org.apache.sysml.conf.CompilerConfig.ConfigType;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.conf.DMLConfig;
import org.apache.sysml.parser.ParseException;
-import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.caching.FrameObject;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
@@ -76,7 +76,8 @@ public final class MLContextUtil {
*/
@SuppressWarnings("rawtypes")
public static final Class[] COMPLEX_DATA_TYPES = { JavaRDD.class, RDD.class, DataFrame.class,
- BinaryBlockMatrix.class, Matrix.class, (new double[][] {}).getClass(), MatrixBlock.class, URL.class };
+ BinaryBlockMatrix.class, BinaryBlockFrame.class, Matrix.class, Frame.class, (new double[][] {}).getClass(),
+ MatrixBlock.class, URL.class };
/**
* All data types supported by the MLContext API
@@ -157,8 +158,8 @@ public final class MLContextUtil {
*/
public static void verifySparkVersionSupported(SparkContext sc) {
if (!MLContextUtil.isSparkVersionSupported(sc.version())) {
- throw new MLContextException("SystemML requires Spark " + MLContext.SYSTEMML_MINIMUM_SPARK_VERSION
- + " or greater");
+ throw new MLContextException(
+ "SystemML requires Spark " + MLContext.SYSTEMML_MINIMUM_SPARK_VERSION + " or greater");
}
}
@@ -201,40 +202,6 @@ public final class MLContextUtil {
}
/**
- * Convenience method to generate a {@code Map<String, Object>} of key/value
- * pairs.
- * <p>
- * Example:<br>
- * {@code Map<String, Object> inputMap = MLContextUtil.generateInputMap("A", 1, "B", "two", "C", 3);}
- * <br>
- * <br>
- * This is equivalent to:<br>
- * <code>Map<String, Object> inputMap = new LinkedHashMap<String, Object>(){{
- * <br>put("A", 1);
- * <br>put("B", "two");
- * <br>put("C", 3);
- * <br>}};</code>
- *
- * @param objs
- * List of String/Object pairs
- * @return Map of String/Object pairs
- * @throws MLContextException
- * if the number of arguments is not an even number
- */
- public static Map<String, Object> generateInputMap(Object... objs) {
- int len = objs.length;
- if ((len & 1) == 1) {
- throw new MLContextException("The number of arguments needs to be an even number");
- }
- Map<String, Object> map = new LinkedHashMap<String, Object>();
- int i = 0;
- while (i < len) {
- map.put((String) objs[i++], objs[i++]);
- }
- return map;
- }
-
- /**
* Verify that the types of input values are supported.
*
* @param inputs
@@ -314,8 +281,8 @@ public final class MLContextUtil {
}
}
if (!supported) {
- throw new MLContextException("Input parameter (\"" + parameterName + "\") value type not supported: "
- + o.getClass());
+ throw new MLContextException(
+ "Input parameter (\"" + parameterName + "\") value type not supported: " + o.getClass());
}
}
@@ -412,7 +379,7 @@ public final class MLContextUtil {
* @return input in SystemML data representation
*/
public static Data convertInputType(String parameterName, Object parameterValue) {
- return convertInputType(parameterName, parameterValue, null, false);
+ return convertInputType(parameterName, parameterValue, null);
}
/**
@@ -422,15 +389,16 @@ public final class MLContextUtil {
* The name of the input parameter
* @param parameterValue
* The value of the input parameter
- * @param matrixMetadata
- * matrix metadata
- * @param bFrame
- * if input is of type frame
+ * @param metadata
+ * matrix/frame metadata
* @return input in SystemML data representation
*/
- public static Data convertInputType(String parameterName, Object parameterValue, MatrixMetadata matrixMetadata, boolean bFrame) {
+ public static Data convertInputType(String parameterName, Object parameterValue, Metadata metadata) {
String name = parameterName;
Object value = parameterValue;
+ boolean hasMetadata = (metadata != null) ? true : false;
+ boolean hasMatrixMetadata = hasMetadata && (metadata instanceof MatrixMetadata) ? true : false;
+ boolean hasFrameMetadata = hasMetadata && (metadata instanceof FrameMetadata) ? true : false;
if (name == null) {
throw new MLContextException("Input parameter name is null");
} else if (value == null) {
@@ -438,91 +406,138 @@ public final class MLContextUtil {
} else if (value instanceof JavaRDD<?>) {
@SuppressWarnings("unchecked")
JavaRDD<String> javaRDD = (JavaRDD<String>) value;
- if(!bFrame) {
+
+ if (hasMatrixMetadata) {
+ MatrixMetadata matrixMetadata = (MatrixMetadata) metadata;
MatrixObject matrixObject;
- if ((matrixMetadata != null) && (matrixMetadata.getMatrixFormat() == MatrixFormat.IJV)) {
- matrixObject = MLContextConversionUtil.javaRDDStringIJVToMatrixObject(name, javaRDD, matrixMetadata);
+ if (matrixMetadata.getMatrixFormat() == MatrixFormat.IJV) {
+ matrixObject = MLContextConversionUtil.javaRDDStringIJVToMatrixObject(name, javaRDD,
+ matrixMetadata);
} else {
- matrixObject = MLContextConversionUtil.javaRDDStringCSVToMatrixObject(name, javaRDD, matrixMetadata);
+ matrixObject = MLContextConversionUtil.javaRDDStringCSVToMatrixObject(name, javaRDD,
+ matrixMetadata);
}
return matrixObject;
- } else {
+ } else if (hasFrameMetadata) {
+ FrameMetadata frameMetadata = (FrameMetadata) metadata;
FrameObject frameObject;
- if ((matrixMetadata != null) && (matrixMetadata.getMatrixFormat() == MatrixFormat.IJV)) {
- frameObject = MLContextConversionUtil.javaRDDStringIJVToFrameObject(name, javaRDD, matrixMetadata);
+ if (frameMetadata.getFrameFormat() == FrameFormat.IJV) {
+ frameObject = MLContextConversionUtil.javaRDDStringIJVToFrameObject(name, javaRDD, frameMetadata);
} else {
- frameObject = MLContextConversionUtil.javaRDDStringCSVToFrameObject(name, javaRDD, matrixMetadata);
+ frameObject = MLContextConversionUtil.javaRDDStringCSVToFrameObject(name, javaRDD, frameMetadata);
}
return frameObject;
+ } else if (!hasMetadata) {
+ String firstLine = javaRDD.first();
+ boolean isAllNumbers = isCSVLineAllNumbers(firstLine);
+ if (isAllNumbers) {
+ MatrixObject matrixObject = MLContextConversionUtil.javaRDDStringCSVToMatrixObject(name, javaRDD);
+ return matrixObject;
+ } else {
+ FrameObject frameObject = MLContextConversionUtil.javaRDDStringCSVToFrameObject(name, javaRDD);
+ return frameObject;
+ }
}
+
} else if (value instanceof RDD<?>) {
@SuppressWarnings("unchecked")
RDD<String> rdd = (RDD<String>) value;
- if(!bFrame) {
+
+ if (hasMatrixMetadata) {
+ MatrixMetadata matrixMetadata = (MatrixMetadata) metadata;
MatrixObject matrixObject;
- if ((matrixMetadata != null) && (matrixMetadata.getMatrixFormat() == MatrixFormat.IJV)) {
+ if (matrixMetadata.getMatrixFormat() == MatrixFormat.IJV) {
matrixObject = MLContextConversionUtil.rddStringIJVToMatrixObject(name, rdd, matrixMetadata);
} else {
matrixObject = MLContextConversionUtil.rddStringCSVToMatrixObject(name, rdd, matrixMetadata);
}
return matrixObject;
- } else {
+ } else if (hasFrameMetadata) {
+ FrameMetadata frameMetadata = (FrameMetadata) metadata;
FrameObject frameObject;
- if ((matrixMetadata != null) && (matrixMetadata.getMatrixFormat() == MatrixFormat.IJV)) {
- frameObject = MLContextConversionUtil.rddStringIJVToFrameObject(name, rdd, matrixMetadata);
+ if (frameMetadata.getFrameFormat() == FrameFormat.IJV) {
+ frameObject = MLContextConversionUtil.rddStringIJVToFrameObject(name, rdd, frameMetadata);
} else {
- frameObject = MLContextConversionUtil.rddStringCSVToFrameObject(name, rdd, matrixMetadata);
+ frameObject = MLContextConversionUtil.rddStringCSVToFrameObject(name, rdd, frameMetadata);
}
return frameObject;
+ } else if (!hasMetadata) {
+ String firstLine = rdd.first();
+ boolean isAllNumbers = isCSVLineAllNumbers(firstLine);
+ if (isAllNumbers) {
+ MatrixObject matrixObject = MLContextConversionUtil.rddStringCSVToMatrixObject(name, rdd);
+ return matrixObject;
+ } else {
+ FrameObject frameObject = MLContextConversionUtil.rddStringCSVToFrameObject(name, rdd);
+ return frameObject;
+ }
}
-
} else if (value instanceof MatrixBlock) {
MatrixBlock matrixBlock = (MatrixBlock) value;
MatrixObject matrixObject = MLContextConversionUtil.matrixBlockToMatrixObject(name, matrixBlock,
- matrixMetadata);
+ (MatrixMetadata) metadata);
return matrixObject;
} else if (value instanceof FrameBlock) {
FrameBlock frameBlock = (FrameBlock) value;
FrameObject frameObject = MLContextConversionUtil.frameBlockToFrameObject(name, frameBlock,
- matrixMetadata);
+ (FrameMetadata) metadata);
return frameObject;
} else if (value instanceof DataFrame) {
DataFrame dataFrame = (DataFrame) value;
- if(!bFrame) {
- MatrixObject matrixObject = MLContextConversionUtil
- .dataFrameToMatrixObject(name, dataFrame, matrixMetadata);
+
+ if (hasMatrixMetadata) {
+ MatrixObject matrixObject = MLContextConversionUtil.dataFrameToMatrixObject(name, dataFrame,
+ (MatrixMetadata) metadata);
return matrixObject;
- } else {
- FrameObject frameObject = null;
- try {
- frameObject = MLContextConversionUtil
- .dataFrameToFrameObject(name, dataFrame, matrixMetadata);
- } catch (DMLRuntimeException e) {
- e.printStackTrace();
- }
+ } else if (hasFrameMetadata) {
+ FrameObject frameObject = MLContextConversionUtil.dataFrameToFrameObject(name, dataFrame,
+ (FrameMetadata) metadata);
return frameObject;
+ } else if (!hasMetadata) {
+ Row firstRow = dataFrame.first();
+ boolean looksLikeMatrix = doesRowLookLikeMatrixRow(firstRow);
+ if (looksLikeMatrix) {
+ MatrixObject matrixObject = MLContextConversionUtil.dataFrameToMatrixObject(name, dataFrame);
+ return matrixObject;
+ } else {
+ FrameObject frameObject = MLContextConversionUtil.dataFrameToFrameObject(name, dataFrame);
+ return frameObject;
+ }
}
} else if (value instanceof BinaryBlockMatrix) {
BinaryBlockMatrix binaryBlockMatrix = (BinaryBlockMatrix) value;
- if (matrixMetadata == null) {
- matrixMetadata = binaryBlockMatrix.getMatrixMetadata();
+ if (metadata == null) {
+ metadata = binaryBlockMatrix.getMatrixMetadata();
}
JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlocks = binaryBlockMatrix.getBinaryBlocks();
MatrixObject matrixObject = MLContextConversionUtil.binaryBlocksToMatrixObject(name, binaryBlocks,
- matrixMetadata);
+ (MatrixMetadata) metadata);
return matrixObject;
+ } else if (value instanceof BinaryBlockFrame) {
+ BinaryBlockFrame binaryBlockFrame = (BinaryBlockFrame) value;
+ if (metadata == null) {
+ metadata = binaryBlockFrame.getFrameMetadata();
+ }
+ JavaPairRDD<Long, FrameBlock> binaryBlocks = binaryBlockFrame.getBinaryBlocks();
+ FrameObject frameObject = MLContextConversionUtil.binaryBlocksToFrameObject(name, binaryBlocks,
+ (FrameMetadata) metadata);
+ return frameObject;
} else if (value instanceof Matrix) {
Matrix matrix = (Matrix) value;
MatrixObject matrixObject = matrix.asMatrixObject();
return matrixObject;
+ } else if (value instanceof Frame) {
+ Frame frame = (Frame) value;
+ FrameObject frameObject = frame.asFrameObject();
+ return frameObject;
} else if (value instanceof double[][]) {
double[][] doubleMatrix = (double[][]) value;
MatrixObject matrixObject = MLContextConversionUtil.doubleMatrixToMatrixObject(name, doubleMatrix,
- matrixMetadata);
+ (MatrixMetadata) metadata);
return matrixObject;
} else if (value instanceof URL) {
URL url = (URL) value;
- MatrixObject matrixObject = MLContextConversionUtil.urlToMatrixObject(name, url, matrixMetadata);
+ MatrixObject matrixObject = MLContextConversionUtil.urlToMatrixObject(name, url, (MatrixMetadata) metadata);
return matrixObject;
} else if (value instanceof Integer) {
Integer i = (Integer) value;
@@ -545,6 +560,56 @@ public final class MLContextUtil {
}
/**
+ * If no metadata is supplied for an RDD or JavaRDD, this method can be used
+ * to determine whether the data appears to be matrix (or a frame)
+ *
+ * @param line
+ * a line of the RDD
+ * @return {@code true} if all the csv-separated values are numbers,
+ * {@code false} otherwise
+ */
+ public static boolean isCSVLineAllNumbers(String line) {
+ if (StringUtils.isBlank(line)) {
+ return false;
+ }
+ String[] parts = line.split(",");
+ for (int i = 0; i < parts.length; i++) {
+ String part = parts[i].trim();
+ try {
+ Double.parseDouble(part);
+ } catch (NumberFormatException e) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ /**
+ * If no metadata is supplied for a DataFrame, this method can be used to
+ * determine whether the data appears to be a matrix (or a frame)
+ *
+ * @param row
+ * a row in the DataFrame
+ * @return {@code true} if the row appears to be a matrix row, {@code false}
+ * otherwise
+ */
+ public static boolean doesRowLookLikeMatrixRow(Row row) {
+ for (int i = 0; i < row.length(); i++) {
+ Object object = row.get(i);
+ if (object instanceof Vector) {
+ return true;
+ }
+ String str = object.toString();
+ try {
+ Double.parseDouble(str);
+ } catch (NumberFormatException e) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ /**
* Return the default matrix block size.
*
* @return the default matrix block size
@@ -559,7 +624,7 @@ public final class MLContextUtil {
* @return the lcoation of the scratch space directory
*/
public static String scratchSpace() {
- return ConfigurationManager.getScratchSpace();
+ return ConfigurationManager.getScratchSpace();
}
/**
@@ -738,7 +803,7 @@ public final class MLContextUtil {
* the map of inputs
* @return the script inputs represented as a String
*/
- public static String displayInputs(String name, Map<String, Object> map) {
+ public static String displayInputs(String name, Map<String, Object> map, LocalVariableMap symbolTable) {
StringBuilder sb = new StringBuilder();
sb.append(name);
sb.append(":\n");
@@ -764,6 +829,11 @@ public final class MLContextUtil {
sb.append(" (");
sb.append(type);
+ if (doesSymbolTableContainMatrixObject(symbolTable, key)) {
+ sb.append(" as Matrix");
+ } else if (doesSymbolTableContainFrameObject(symbolTable, key)) {
+ sb.append(" as Frame");
+ }
sb.append(") ");
sb.append(key);
@@ -890,14 +960,78 @@ public final class MLContextUtil {
return sb.toString();
}
- public static SparkContext getSparkContext(MLContext mlContext)
- {
+ /**
+ * Obtain the Spark Context
+ *
+ * @param mlContext
+ * the SystemML MLContext
+ * @return the Spark Context
+ */
+ public static SparkContext getSparkContext(MLContext mlContext) {
return mlContext.getSparkContext();
}
- public static JavaSparkContext getJavaSparkContext(MLContext mlContext)
- {
+ /**
+ * Obtain the Java Spark Context
+ *
+ * @param mlContext
+ * the SystemML MLContext
+ * @return the Java Spark Context
+ */
+ public static JavaSparkContext getJavaSparkContext(MLContext mlContext) {
return new JavaSparkContext(mlContext.getSparkContext());
}
+ /**
+ * Determine if the symbol table contains a FrameObject with the given
+ * variable name.
+ *
+ * @param symbolTable
+ * the LocalVariableMap
+ * @param variableName
+ * the variable name
+ * @return {@code true} if the variable in the symbol table is a
+ * FrameObject, {@code false} otherwise.
+ */
+ public static boolean doesSymbolTableContainFrameObject(LocalVariableMap symbolTable, String variableName) {
+ if (symbolTable == null) {
+ return false;
+ }
+ Data data = symbolTable.get(variableName);
+ if (data == null) {
+ return false;
+ }
+ if (data instanceof FrameObject) {
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ /**
+ * Determine if the symbol table contains a MatrixObject with the given
+ * variable name.
+ *
+ * @param symbolTable
+ * the LocalVariableMap
+ * @param variableName
+ * the variable name
+ * @return {@code true} if the variable in the symbol table is a
+ * MatrixObject, {@code false} otherwise.
+ */
+ public static boolean doesSymbolTableContainMatrixObject(LocalVariableMap symbolTable, String variableName) {
+ if (symbolTable == null) {
+ return false;
+ }
+ Data data = symbolTable.get(variableName);
+ if (data == null) {
+ return false;
+ }
+ if (data instanceof MatrixObject) {
+ return true;
+ } else {
+ return false;
+ }
+ }
+
}