You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ac...@apache.org on 2016/08/29 06:01:55 UTC
incubator-systemml git commit: [SYSTEMML-568] Frame MLContext support
Repository: incubator-systemml
Updated Branches:
refs/heads/master 4cbb02819 -> 02a9f2770
[SYSTEMML-568] Frame MLContext support
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/02a9f277
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/02a9f277
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/02a9f277
Branch: refs/heads/master
Commit: 02a9f277000bd144c729311dac6c04bcb520180f
Parents: 4cbb028
Author: Arvind Surve <ac...@yahoo.com>
Authored: Sun Aug 28 23:01:20 2016 -0700
Committer: Arvind Surve <ac...@yahoo.com>
Committed: Sun Aug 28 23:01:20 2016 -0700
----------------------------------------------------------------------
.../java/org/apache/sysml/api/MLContext.java | 218 ++++++++++--
.../java/org/apache/sysml/api/MLOutput.java | 49 ++-
.../api/mlcontext/MLContextConversionUtil.java | 34 ++
.../sysml/api/mlcontext/MLContextUtil.java | 6 +
.../sysml/runtime/util/UtilFunctions.java | 40 +++
.../functions/frame/FrameConverterTest.java | 33 +-
.../functions/mlcontext/FrameTest.java | 351 +++++++++++++++++++
src/test/scripts/functions/frame/FrameGeneral.R | 35 ++
.../scripts/functions/frame/FrameGeneral.dml | 30 ++
9 files changed, 732 insertions(+), 64 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/02a9f277/src/main/java/org/apache/sysml/api/MLContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/MLContext.java b/src/main/java/org/apache/sysml/api/MLContext.java
index 405478f..8f6e95f 100644
--- a/src/main/java/org/apache/sysml/api/MLContext.java
+++ b/src/main/java/org/apache/sysml/api/MLContext.java
@@ -23,6 +23,7 @@ package org.apache.sysml.api;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Scanner;
@@ -62,6 +63,7 @@ import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysml.runtime.controlprogram.caching.FrameObject;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
@@ -73,11 +75,13 @@ import org.apache.sysml.runtime.instructions.spark.functions.ConvertStringToLong
import org.apache.sysml.runtime.instructions.spark.functions.CopyBlockPairFunction;
import org.apache.sysml.runtime.instructions.spark.functions.CopyTextInputFunction;
import org.apache.sysml.runtime.instructions.spark.functions.SparkListener;
+import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils;
import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.MatrixFormatMetaData;
import org.apache.sysml.runtime.matrix.data.CSVFileFormatProperties;
import org.apache.sysml.runtime.matrix.data.FileFormatProperties;
+import org.apache.sysml.runtime.matrix.data.FrameBlock;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
@@ -263,6 +267,21 @@ public class MLContext {
}
/**
+ * Register DataFrame as input. DataFrame is assumed to be in row format and each cell can be converted into
+ * SystemML frame row. Each column may be of type Double, Float, Long, Integer, String or Boolean.
+ * <p>
+ * Marks the variable in the DML script as input variable.
+ * Note that this expects a "varName = read(...)" statement in the DML script which through non-MLContext invocation
+ * would have been created by reading a HDFS file.
+ * @param varName
+ * @param df
+ * @throws DMLRuntimeException
+ */
+ public void registerFrameInput(String varName, DataFrame df) throws DMLRuntimeException {
+ registerFrameInput(varName, df, false);
+ }
+
+ /**
* Register DataFrame as input.
* Marks the variable in the DML script as input variable.
* Note that this expects a "varName = read(...)" statement in the DML script which through non-MLContext invocation
@@ -279,6 +298,21 @@ public class MLContext {
}
/**
+ * Register DataFrame as input. DataFrame is assumed to be in row format and each cell can be converted into
+ * SystemML frame row. Each column may be of type Double, Float, Long, Integer, String or Boolean.
+ * <p>
+ * @param varName
+ * @param df
+ * @param containsID true if the DataFrame has an ID column which denotes the row ID, false otherwise.
+ * @throws DMLRuntimeException
+ */
+ public void registerFrameInput(String varName, DataFrame df, boolean containsID) throws DMLRuntimeException {
+ MatrixCharacteristics mcOut = new MatrixCharacteristics();
+ JavaPairRDD<Long, FrameBlock> rdd = FrameRDDConverterUtils.dataFrameToBinaryBlock(new JavaSparkContext(_sc), df, mcOut, containsID);
+ registerInput(varName, rdd, mcOut.getRows(), mcOut.getCols(), null);
+ }
+
+ /**
* Experimental API. Not supported in Python MLContext API.
* @param varName
* @param df
@@ -520,6 +554,87 @@ public class MLContext {
checkIfRegisteringInputAllowed();
}
+ /**
+ * Register Frame with CSV/Text as inputs: with dimensions.
+ * File properties (example: delim, fill, ..) can be specified through props else defaults will be used.
+ * <p>
+ * Marks the variable in the DML script as input variable.
+ * Note that this expects a "varName = read(...)" statement in the DML script which through non-MLContext invocation
+ * would have been created by reading a HDFS file.
+ * @param varName
+ * @param rddIn
+ * @param format
+ * @param rlen
+ * @param clen
+ * @param props
+ * @param schema
+ * List of column types.
+ * @throws DMLRuntimeException
+ */
+ public void registerInput(String varName, JavaRDD<String> rddIn, String format, long rlen, long clen, FileFormatProperties props,
+ List<ValueType> schema) throws DMLRuntimeException {
+ if(!(DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK || DMLScript.rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK)) {
+ throw new DMLRuntimeException("The registerInput functionality only supported for spark runtime. Please use MLContext(sc) instead of default constructor.");
+ }
+
+ long nnz = -1;
+ if(_variables == null)
+ _variables = new LocalVariableMap();
+ if(_inVarnames == null)
+ _inVarnames = new ArrayList<String>();
+
+ JavaPairRDD<LongWritable, Text> rddText = rddIn.mapToPair(new ConvertStringToLongTextPair());
+
+ MatrixCharacteristics mc = new MatrixCharacteristics(rlen, clen, OptimizerUtils.DEFAULT_BLOCKSIZE, OptimizerUtils.DEFAULT_BLOCKSIZE, nnz);
+ FrameObject fo = new FrameObject(null, new MatrixFormatMetaData(mc, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo));
+ JavaPairRDD<Long, FrameBlock> rdd = null;
+ if( format.equals("csv") ) {
+ //TODO replace default block size
+
+ rdd = FrameRDDConverterUtils.csvToBinaryBlock(new JavaSparkContext(getSparkContext()), rddText, mc, false, ",", false, -1);
+ }
+ else if( format.equals("text") ) {
+ if(rlen == -1 || clen == -1) {
+ throw new DMLRuntimeException("The metadata is required in registerInput for format:" + format);
+ }
+ //TODO replace default block size
+ rdd = FrameRDDConverterUtils.textCellToBinaryBlock(new JavaSparkContext(getSparkContext()), rddText, mc, schema);
+ }
+ else {
+
+ throw new DMLRuntimeException("Incorrect format in registerInput: " + format);
+ }
+ if(props != null)
+ fo.setFileFormatProperties(props);
+
+ fo.setRDDHandle(new RDDObject(rdd, varName));
+ _variables.put(varName, fo);
+ _inVarnames.add(varName);
+ checkIfRegisteringInputAllowed();
+ }
+
+ private void registerInput(String varName, JavaPairRDD<Long, FrameBlock> rdd, long rlen, long clen, FileFormatProperties props) throws DMLRuntimeException {
+ if(!(DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK || DMLScript.rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK)) {
+ throw new DMLRuntimeException("The registerInput functionality only supported for spark runtime. Please use MLContext(sc) instead of default constructor.");
+ }
+
+ if(_variables == null)
+ _variables = new LocalVariableMap();
+ if(_inVarnames == null)
+ _inVarnames = new ArrayList<String>();
+
+ MatrixCharacteristics mc = new MatrixCharacteristics(rlen, clen, OptimizerUtils.DEFAULT_BLOCKSIZE, OptimizerUtils.DEFAULT_BLOCKSIZE, -1);
+ FrameObject fo = new FrameObject(null, new MatrixFormatMetaData(mc, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo));
+
+ if(props != null)
+ fo.setFileFormatProperties(props);
+
+ fo.setRDDHandle(new RDDObject(rdd, varName));
+ _variables.put(varName, fo);
+ _inVarnames.add(varName);
+ checkIfRegisteringInputAllowed();
+ }
+
// ------------------------------------------------------------------------------------
// 3. Binary blocked RDD: Support JavaPairRDD<MatrixIndexes,MatrixBlock>
@@ -1008,37 +1123,70 @@ public class MLContext {
// Do not check metadata file for registered reads
((DataExpression) source).setCheckMetadata(false);
- MatrixObject mo = null;
- try {
- mo = getMatrixObject(target);
- int blp = source.getBeginLine(); int bcp = source.getBeginColumn();
- int elp = source.getEndLine(); int ecp = source.getEndColumn();
- ((DataExpression) source).addVarParam(DataExpression.READROWPARAM, new IntIdentifier(mo.getNumRows(), source.getFilename(), blp, bcp, elp, ecp));
- ((DataExpression) source).addVarParam(DataExpression.READCOLPARAM, new IntIdentifier(mo.getNumColumns(), source.getFilename(), blp, bcp, elp, ecp));
- ((DataExpression) source).addVarParam(DataExpression.READNUMNONZEROPARAM, new IntIdentifier(mo.getNnz(), source.getFilename(), blp, bcp, elp, ecp));
- ((DataExpression) source).addVarParam(DataExpression.DATATYPEPARAM, new StringIdentifier("matrix", source.getFilename(), blp, bcp, elp, ecp));
- ((DataExpression) source).addVarParam(DataExpression.VALUETYPEPARAM, new StringIdentifier("double", source.getFilename(), blp, bcp, elp, ecp));
+ if (((DataExpression)source).getDataType() == Expression.DataType.MATRIX) {
+
+ MatrixObject mo = null;
- if(mo.getMetaData() instanceof MatrixFormatMetaData) {
- MatrixFormatMetaData metaData = (MatrixFormatMetaData) mo.getMetaData();
- if(metaData.getOutputInfo() == OutputInfo.CSVOutputInfo) {
- ((DataExpression) source).addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(DataExpression.FORMAT_TYPE_VALUE_CSV, source.getFilename(), blp, bcp, elp, ecp));
- }
- else if(metaData.getOutputInfo() == OutputInfo.TextCellOutputInfo) {
- ((DataExpression) source).addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(DataExpression.FORMAT_TYPE_VALUE_TEXT, source.getFilename(), blp, bcp, elp, ecp));
- }
- else if(metaData.getOutputInfo() == OutputInfo.BinaryBlockOutputInfo) {
- ((DataExpression) source).addVarParam(DataExpression.ROWBLOCKCOUNTPARAM, new IntIdentifier(mo.getNumRowsPerBlock(), source.getFilename(), blp, bcp, elp, ecp));
- ((DataExpression) source).addVarParam(DataExpression.COLUMNBLOCKCOUNTPARAM, new IntIdentifier(mo.getNumColumnsPerBlock(), source.getFilename(), blp, bcp, elp, ecp));
- ((DataExpression) source).addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(DataExpression.FORMAT_TYPE_VALUE_BINARY, source.getFilename(), blp, bcp, elp, ecp));
+ try {
+ mo = getMatrixObject(target);
+ int blp = source.getBeginLine(); int bcp = source.getBeginColumn();
+ int elp = source.getEndLine(); int ecp = source.getEndColumn();
+ ((DataExpression) source).addVarParam(DataExpression.READROWPARAM, new IntIdentifier(mo.getNumRows(), source.getFilename(), blp, bcp, elp, ecp));
+ ((DataExpression) source).addVarParam(DataExpression.READCOLPARAM, new IntIdentifier(mo.getNumColumns(), source.getFilename(), blp, bcp, elp, ecp));
+ ((DataExpression) source).addVarParam(DataExpression.READNUMNONZEROPARAM, new IntIdentifier(mo.getNnz(), source.getFilename(), blp, bcp, elp, ecp));
+ ((DataExpression) source).addVarParam(DataExpression.DATATYPEPARAM, new StringIdentifier("matrix", source.getFilename(), blp, bcp, elp, ecp));
+ ((DataExpression) source).addVarParam(DataExpression.VALUETYPEPARAM, new StringIdentifier("double", source.getFilename(), blp, bcp, elp, ecp));
+
+ if(mo.getMetaData() instanceof MatrixFormatMetaData) {
+ MatrixFormatMetaData metaData = (MatrixFormatMetaData) mo.getMetaData();
+ if(metaData.getOutputInfo() == OutputInfo.CSVOutputInfo) {
+ ((DataExpression) source).addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(DataExpression.FORMAT_TYPE_VALUE_CSV, source.getFilename(), blp, bcp, elp, ecp));
+ }
+ else if(metaData.getOutputInfo() == OutputInfo.TextCellOutputInfo) {
+ ((DataExpression) source).addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(DataExpression.FORMAT_TYPE_VALUE_TEXT, source.getFilename(), blp, bcp, elp, ecp));
+ }
+ else if(metaData.getOutputInfo() == OutputInfo.BinaryBlockOutputInfo) {
+ ((DataExpression) source).addVarParam(DataExpression.ROWBLOCKCOUNTPARAM, new IntIdentifier(mo.getNumRowsPerBlock(), source.getFilename(), blp, bcp, elp, ecp));
+ ((DataExpression) source).addVarParam(DataExpression.COLUMNBLOCKCOUNTPARAM, new IntIdentifier(mo.getNumColumnsPerBlock(), source.getFilename(), blp, bcp, elp, ecp));
+ ((DataExpression) source).addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(DataExpression.FORMAT_TYPE_VALUE_BINARY, source.getFilename(), blp, bcp, elp, ecp));
+ }
+ else {
+ throw new LanguageException("Unsupported format through MLContext");
+ }
}
- else {
- throw new LanguageException("Unsupported format through MLContext");
+ } catch (DMLRuntimeException e) {
+ throw new LanguageException(e);
+ }
+ } else if (((DataExpression)source).getDataType() == Expression.DataType.FRAME) {
+ FrameObject mo = null;
+ try {
+ mo = getFrameObject(target);
+ int blp = source.getBeginLine(); int bcp = source.getBeginColumn();
+ int elp = source.getEndLine(); int ecp = source.getEndColumn();
+ ((DataExpression) source).addVarParam(DataExpression.READROWPARAM, new IntIdentifier(mo.getNumRows(), source.getFilename(), blp, bcp, elp, ecp));
+ ((DataExpression) source).addVarParam(DataExpression.READCOLPARAM, new IntIdentifier(mo.getNumColumns(), source.getFilename(), blp, bcp, elp, ecp));
+ ((DataExpression) source).addVarParam(DataExpression.DATATYPEPARAM, new StringIdentifier("frame", source.getFilename(), blp, bcp, elp, ecp));
+ ((DataExpression) source).addVarParam(DataExpression.VALUETYPEPARAM, new StringIdentifier("double", source.getFilename(), blp, bcp, elp, ecp)); //TODO change to schema
+
+ if(mo.getMetaData() instanceof MatrixFormatMetaData) {
+ MatrixFormatMetaData metaData = (MatrixFormatMetaData) mo.getMetaData();
+ if(metaData.getOutputInfo() == OutputInfo.CSVOutputInfo) {
+ ((DataExpression) source).addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(DataExpression.FORMAT_TYPE_VALUE_CSV, source.getFilename(), blp, bcp, elp, ecp));
+ }
+ else if(metaData.getOutputInfo() == OutputInfo.TextCellOutputInfo) {
+ ((DataExpression) source).addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(DataExpression.FORMAT_TYPE_VALUE_TEXT, source.getFilename(), blp, bcp, elp, ecp));
+ }
+ else if(metaData.getOutputInfo() == OutputInfo.BinaryBlockOutputInfo) {
+ ((DataExpression) source).addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(DataExpression.FORMAT_TYPE_VALUE_BINARY, source.getFilename(), blp, bcp, elp, ecp));
+ }
+ else {
+ throw new LanguageException("Unsupported format through MLContext");
+ }
}
+ } catch (DMLRuntimeException e) {
+ throw new LanguageException(e);
}
- } catch (DMLRuntimeException e) {
- throw new LanguageException(e);
- }
+ }
}
}
@@ -1129,6 +1277,18 @@ public class MLContext {
throw new DMLRuntimeException("ERROR: getMatrixObject not set for variable:" + varName);
}
+ private FrameObject getFrameObject(String varName) throws DMLRuntimeException {
+ if(_variables != null) {
+ Data mo = _variables.get(varName);
+ if(mo instanceof FrameObject) {
+ return (FrameObject) mo;
+ }
+ else {
+ throw new DMLRuntimeException("ERROR: Incorrect type");
+ }
+ }
+ throw new DMLRuntimeException("ERROR: getMatrixObject not set for variable:" + varName);
+ }
private int compareVersion(String versionStr1, String versionStr2) {
Scanner s1 = null;
@@ -1329,7 +1489,7 @@ public class MLContext {
if(DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK || DMLScript.rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK) {
- Map<String, JavaPairRDD<MatrixIndexes,MatrixBlock>> retVal = null;
+ Map<String, JavaPairRDD<?,?>> retVal = null;
// Depending on whether registerInput/registerOutput was called initialize the variables
String[] inputs; String[] outputs;
@@ -1361,7 +1521,7 @@ public class MLContext {
for( String ovar : _outVarnames ) {
if( _variables.keySet().contains(ovar) ) {
if(retVal == null) {
- retVal = new HashMap<String, JavaPairRDD<MatrixIndexes,MatrixBlock>>();
+ retVal = new HashMap<String, JavaPairRDD<?,?>>();
}
retVal.put(ovar, ((SparkExecutionContext) ec).getBinaryBlockRDDHandleForVariable(ovar));
outMetadata.put(ovar, ec.getMatrixCharacteristics(ovar)); // For converting output to dataframe
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/02a9f277/src/main/java/org/apache/sysml/api/MLOutput.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/MLOutput.java b/src/main/java/org/apache/sysml/api/MLOutput.java
index 55daf17..916a652 100644
--- a/src/main/java/org/apache/sysml/api/MLOutput.java
+++ b/src/main/java/org/apache/sysml/api/MLOutput.java
@@ -27,6 +27,7 @@ import java.util.Map.Entry;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.mllib.linalg.DenseVector;
@@ -41,8 +42,11 @@ import org.apache.spark.sql.types.StructType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.instructions.spark.functions.GetMLBlock;
+import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils;
import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+import org.apache.sysml.runtime.matrix.data.CSVFileFormatProperties;
+import org.apache.sysml.runtime.matrix.data.FrameBlock;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.util.UtilFunctions;
@@ -55,7 +59,7 @@ import scala.Tuple2;
*/
public class MLOutput {
- Map<String, JavaPairRDD<MatrixIndexes,MatrixBlock>> _outputs;
+ Map<String, JavaPairRDD<?,?>> _outputs;
private Map<String, MatrixCharacteristics> _outMetadata = null;
public MatrixBlock getMatrixBlock(String varName) throws DMLRuntimeException {
@@ -66,14 +70,32 @@ public class MLOutput {
mc.getRowsPerBlock(), mc.getColsPerBlock(), mc.getNonZeros());
return mb;
}
- public MLOutput(Map<String, JavaPairRDD<MatrixIndexes,MatrixBlock>> outputs, Map<String, MatrixCharacteristics> outMetadata) {
+
+ public MLOutput(Map<String, JavaPairRDD<?,?>> outputs, Map<String, MatrixCharacteristics> outMetadata) {
this._outputs = outputs;
this._outMetadata = outMetadata;
}
+ @SuppressWarnings("unchecked")
public JavaPairRDD<MatrixIndexes,MatrixBlock> getBinaryBlockedRDD(String varName) throws DMLRuntimeException {
if(_outputs.containsKey(varName)) {
- return _outputs.get(varName);
+ JavaPairRDD<?,?> tmp = _outputs.get(varName);
+ if (tmp.first()._2() instanceof MatrixBlock)
+ return (JavaPairRDD<MatrixIndexes,MatrixBlock>)tmp;
+ else
+ return null;
+ }
+ throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
+ }
+
+ @SuppressWarnings("unchecked")
+ public JavaPairRDD<Long,FrameBlock> getFrameBinaryBlockedRDD(String varName) throws DMLRuntimeException {
+ if(_outputs.containsKey(varName)) {
+ JavaPairRDD<?,?> tmp = _outputs.get(varName);
+ if (tmp.first()._2() instanceof FrameBlock)
+ return (JavaPairRDD<Long,FrameBlock>)tmp;
+ else
+ return null;
}
throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
}
@@ -197,6 +219,27 @@ public class MLOutput {
}
+ public JavaRDD<String> getStringFrameRDD(String varName, String format, CSVFileFormatProperties fprop ) throws DMLRuntimeException {
+ JavaPairRDD<Long, FrameBlock> binaryRDD = getFrameBinaryBlockedRDD(varName);
+ MatrixCharacteristics mcIn = getMatrixCharacteristics(varName);
+ if(format.equals("csv")) {
+ return FrameRDDConverterUtils.binaryBlockToCsv(binaryRDD, mcIn, fprop, false);
+ }
+ else if(format.equals("text")) {
+ return FrameRDDConverterUtils.binaryBlockToTextCell(binaryRDD, mcIn);
+ }
+ else {
+ throw new DMLRuntimeException("The output format:" + format + " is not implemented yet.");
+ }
+
+ }
+
+ public DataFrame getDataFrameRDD(String varName, JavaSparkContext jsc) throws DMLRuntimeException {
+ JavaPairRDD<Long, FrameBlock> binaryRDD = getFrameBinaryBlockedRDD(varName);
+ MatrixCharacteristics mcIn = getMatrixCharacteristics(varName);
+ return FrameRDDConverterUtils.binaryBlockToDataFrame(binaryRDD, mcIn, jsc);
+ }
+
public MLMatrix getMLMatrix(MLContext ml, SQLContext sqlContext, String varName) throws DMLRuntimeException {
if(sqlContext == null) {
throw new DMLRuntimeException("SQLContext is not created.");
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/02a9f277/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
index 3a482ef..0c98dea 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
@@ -42,6 +42,7 @@ import org.apache.sysml.api.MLContextProxy;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.caching.CacheException;
+import org.apache.sysml.runtime.controlprogram.caching.FrameObject;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.instructions.spark.data.RDDObject;
@@ -54,6 +55,7 @@ import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt.Da
import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt.DataFrameToBinaryBlockFunction;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.MatrixFormatMetaData;
+import org.apache.sysml.runtime.matrix.data.FrameBlock;
import org.apache.sysml.runtime.matrix.data.IJV;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
@@ -187,6 +189,38 @@ public class MLContextConversionUtil {
}
/**
+ * Convert a {@code FrameBlock} to a {@code FrameObject}.
+ *
+ * @param variableName
+ * name of the variable associated with the frame
+ * @param frameBlock
+ * frame as a FrameBlock
+ * @param matrixMetadata
+ * the matrix metadata
+ * @return the {@code FrameBlock} converted to a {@code FrameObject}
+ */
+ public static FrameObject frameBlockToframeObject(String variableName, FrameBlock frameBlock,
+ MatrixMetadata matrixMetadata) {
+ try {
+ MatrixCharacteristics matrixCharacteristics;
+ if (matrixMetadata != null) {
+ matrixCharacteristics = matrixMetadata.asMatrixCharacteristics();
+ } else {
+ matrixCharacteristics = new MatrixCharacteristics();
+ }
+ MatrixFormatMetaData mtd = new MatrixFormatMetaData(matrixCharacteristics,
+ OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);
+ FrameObject frameObject = new FrameObject(MLContextUtil.scratchSpace() + "/"
+ + variableName, mtd);
+ frameObject.acquireModify(frameBlock);
+ frameObject.release();
+ return frameObject;
+ } catch (CacheException e) {
+ throw new MLContextException("Exception converting MatrixBlock to MatrixObject", e);
+ }
+ }
+
+ /**
* Convert a {@code JavaPairRDD<MatrixIndexes, MatrixBlock>} to a
* {@code MatrixObject}.
*
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/02a9f277/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
index ea7857e..120df32 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
@@ -53,6 +53,7 @@ import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.instructions.cp.DoubleObject;
import org.apache.sysml.runtime.instructions.cp.IntObject;
import org.apache.sysml.runtime.instructions.cp.StringObject;
+import org.apache.sysml.runtime.matrix.data.FrameBlock;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
@@ -456,6 +457,11 @@ public final class MLContextUtil {
MatrixObject matrixObject = MLContextConversionUtil.matrixBlockToMatrixObject(name, matrixBlock,
matrixMetadata);
return matrixObject;
+ } else if (value instanceof FrameBlock) {
+ FrameBlock frameBlock = (FrameBlock) value;
+ FrameObject frameObject = MLContextConversionUtil.frameBlockToframeObject(name, frameBlock,
+ matrixMetadata);
+ return frameObject;
} else if (value instanceof DataFrame) {
DataFrame dataFrame = (DataFrame) value;
MatrixObject matrixObject = MLContextConversionUtil
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/02a9f277/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java b/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java
index 88221f2..4b98f88 100644
--- a/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java
+++ b/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java
@@ -23,6 +23,11 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
@@ -659,4 +664,39 @@ public class UtilFunctions
return DataTypes.createStructType(fields);
}
+ /*
+ * It will return JavaRDD<Row> based on csv data input file.
+ */
+ public static JavaRDD<Row> getRowRDD(JavaSparkContext sc, String fnameIn, String separator, List<ValueType> schema)
+ {
+ // Load a text file and convert each line to a java rdd.
+ JavaRDD<String> dataRdd = sc.textFile(fnameIn);
+ return dataRdd.map(new RowGenerator(schema));
+ }
+
+ /*
+ * Row Generator class based on individual line in CSV file.
+ */
+ private static class RowGenerator implements Function<String,Row>
+ {
+ private static final long serialVersionUID = -6736256507697511070L;
+
+ List<ValueType> _schema = null;
+
+ public RowGenerator(List<ValueType> schema)
+ {
+ _schema = schema;
+ }
+
+ @Override
+ public Row call(String record) throws Exception {
+ String[] fields = record.split(",");
+ Object[] objects = new Object[fields.length];
+ for (int i=0; i<fields.length; i++) {
+ objects[i] = UtilFunctions.stringToObject(_schema.get(i), fields[i]);
+ }
+ return RowFactory.create(objects);
+ }
+ }
+
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/02a9f277/src/test/java/org/apache/sysml/test/integration/functions/frame/FrameConverterTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/frame/FrameConverterTest.java b/src/test/java/org/apache/sysml/test/integration/functions/frame/FrameConverterTest.java
index 441a63b..fb076bd 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/frame/FrameConverterTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/frame/FrameConverterTest.java
@@ -27,13 +27,11 @@ import java.util.List;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
-import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.StructType;
import org.apache.sysml.api.DMLScript;
@@ -522,7 +520,7 @@ public class FrameConverterTest extends AutomatedTestBase
//Create DataFrame
SQLContext sqlContext = new SQLContext(sc);
StructType dfSchema = UtilFunctions.convertFrameSchemaToDFSchema(schema);
- JavaRDD<Row> rowRDD = getRowRDD(sc, fnameIn, separator);
+ JavaRDD<Row> rowRDD = UtilFunctions.getRowRDD(sc, fnameIn, separator, schema);
DataFrame df = sqlContext.createDataFrame(rowRDD, dfSchema);
JavaPairRDD<LongWritable, FrameBlock> rddOut = FrameRDDConverterUtils
@@ -552,33 +550,4 @@ public class FrameConverterTest extends AutomatedTestBase
sec.close();
}
-
- /*
- * It will return JavaRDD<Row> based on csv data input file.
- */
- JavaRDD<Row> getRowRDD(JavaSparkContext sc, String fnameIn, String separator)
- {
- // Load a text file and convert each line to a java rdd.
- JavaRDD<String> dataRdd = sc.textFile(fnameIn);
- return dataRdd.map(new RowGenerator());
- }
-
- /*
- * Row Generator class based on individual line in CSV file.
- */
- private static class RowGenerator implements Function<String,Row>
- {
- private static final long serialVersionUID = -6736256507697511070L;
-
- @Override
- public Row call(String record) throws Exception {
- String[] fields = record.split(",");
- Object[] objects = new Object[fields.length];
- for (int i=0; i<fields.length; i++) {
- objects[i] = fields[i];
- }
- return RowFactory.create(objects);
- }
- }
-
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/02a9f277/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java
new file mode 100644
index 0000000..b6184cf
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java
@@ -0,0 +1,351 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.mlcontext;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+
+import org.apache.hadoop.io.LongWritable;
+import org.apache.spark.SparkContext;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.StructType;
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysml.api.DMLException;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.api.MLContext;
+import org.apache.sysml.api.MLOutput;
+import org.apache.sysml.parser.Expression.ValueType;
+import org.apache.sysml.parser.DataExpression;
+import org.apache.sysml.parser.ParseException;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils;
+import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils.LongFrameToLongWritableFrameFunction;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+import org.apache.sysml.runtime.matrix.data.CSVFileFormatProperties;
+import org.apache.sysml.runtime.matrix.data.FrameBlock;
+import org.apache.sysml.runtime.matrix.data.InputInfo;
+import org.apache.sysml.runtime.matrix.data.OutputInfo;
+import org.apache.sysml.runtime.util.MapReduceTool;
+import org.apache.sysml.runtime.util.UtilFunctions;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+
+
+public class FrameTest extends AutomatedTestBase
+{
+ private final static String TEST_DIR = "functions/frame/";
+ private final static String TEST_NAME = "FrameGeneral";
+ private final static String TEST_CLASS_DIR = TEST_DIR + FrameTest.class.getSimpleName() + "/";
+
+ private final static int min=0;
+ private final static int max=100;
+ private final static int rows = 2245;
+ private final static int cols = 1264;
+
+ private final static double sparsity1 = 1.0;
+ private final static double sparsity2 = 0.35;
+
+ private final static double epsilon=0.0000000001;
+
+
+ private final static List<ValueType> schemaMixedLargeListStr = Collections.nCopies(cols/4, ValueType.STRING);
+ private final static List<ValueType> schemaMixedLargeListDble = Collections.nCopies(cols/4, ValueType.DOUBLE);
+ private final static List<ValueType> schemaMixedLargeListInt = Collections.nCopies(cols/4, ValueType.INT);
+ private final static List<ValueType> schemaMixedLargeListBool = Collections.nCopies(cols/4, ValueType.BOOLEAN);
+ private static ValueType[] schemaMixedLarge = null;
+ static {
+ final List<ValueType> schemaMixedLargeList = new ArrayList<ValueType>(schemaMixedLargeListStr);
+ schemaMixedLargeList.addAll(schemaMixedLargeListDble);
+ schemaMixedLargeList.addAll(schemaMixedLargeListInt);
+ schemaMixedLargeList.addAll(schemaMixedLargeListBool);
+ schemaMixedLarge = new ValueType[schemaMixedLargeList.size()];
+ schemaMixedLarge = (ValueType[]) schemaMixedLargeList.toArray(schemaMixedLarge);
+ }
+
+ @Override
+ public void setUp() {
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME,
+ new String[] {"AB", "C"}));
+ }
+
+ @Test
+ public void testCSVInCSVOut() throws IOException, DMLException, ParseException {
+ testFrameGeneral(InputInfo.CSVInputInfo, OutputInfo.CSVOutputInfo);
+ }
+
+ @Test
+ public void testCSVInTextOut() throws IOException, DMLException, ParseException {
+ testFrameGeneral(InputInfo.TextCellInputInfo, OutputInfo.CSVOutputInfo);
+ }
+
+ @Test
+ public void testTextInCSVOut() throws IOException, DMLException, ParseException {
+ testFrameGeneral(InputInfo.CSVInputInfo, OutputInfo.TextCellOutputInfo);
+ }
+
+ @Test
+ public void testTextInTextOut() throws IOException, DMLException, ParseException {
+ testFrameGeneral(InputInfo.TextCellInputInfo, OutputInfo.TextCellOutputInfo);
+ }
+
+ @Test
+ public void testDataFrameInCSVOut() throws IOException, DMLException, ParseException {
+ testFrameGeneral(InputInfo.CSVInputInfo, true, false);
+ }
+
+ @Test
+ public void testDataFrameInTextOut() throws IOException, DMLException, ParseException {
+ testFrameGeneral(InputInfo.TextCellInputInfo, true, false);
+ }
+
+ @Test
+ public void testDataFrameInDataFrameOut() throws IOException, DMLException, ParseException {
+ testFrameGeneral(true, true);
+ }
+
+ private void testFrameGeneral(InputInfo iinfo, OutputInfo oinfo) throws IOException, DMLException, ParseException {
+ testFrameGeneral(iinfo, oinfo, false, false);
+ }
+
+ private void testFrameGeneral(InputInfo iinfo, boolean bFromDataFrame, boolean bToDataFrame) throws IOException, DMLException, ParseException {
+ testFrameGeneral(iinfo, OutputInfo.CSVOutputInfo, bFromDataFrame, bToDataFrame);
+ }
+
+ private void testFrameGeneral(boolean bFromDataFrame, boolean bToDataFrame) throws IOException, DMLException, ParseException {
+ testFrameGeneral(InputInfo.BinaryBlockInputInfo, OutputInfo.CSVOutputInfo, bFromDataFrame, bToDataFrame);
+ }
+
+ private void testFrameGeneral(InputInfo iinfo, OutputInfo oinfo, boolean bFromDataFrame, boolean bToDataFrame) throws IOException, DMLException, ParseException {
+
+ boolean oldConfig = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ RUNTIME_PLATFORM oldRT = DMLScript.rtplatform;
+ DMLScript.rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK;
+
+ this.scriptType = ScriptType.DML;
+
+ int rowstart = 234, rowend = 1478, colstart = 125, colend = 568;
+ int bRows = rowend-rowstart+1, bCols = colend-colstart+1;
+
+ int rowstartC = 124, rowendC = 1178, colstartC = 143, colendC = 368;
+ int cRows = rowendC-rowstartC+1, cCols = colendC-colstartC+1;
+
+ HashMap<String, ValueType[]> outputSchema = new HashMap<String, ValueType[]>();
+ HashMap<String, MatrixCharacteristics> outputMC = new HashMap<String, MatrixCharacteristics>();
+
+ TestConfiguration config = getTestConfiguration(TEST_NAME);
+
+ loadTestConfiguration(config);
+
+ List<String> proArgs = new ArrayList<String>();
+ proArgs.add(input("A"));
+ proArgs.add(Integer.toString(rows));
+ proArgs.add(Integer.toString(cols));
+ proArgs.add(input("B"));
+ proArgs.add(Integer.toString(bRows));
+ proArgs.add(Integer.toString(bCols));
+ proArgs.add(Integer.toString(rowstart));
+ proArgs.add(Integer.toString(rowend));
+ proArgs.add(Integer.toString(colstart));
+ proArgs.add(Integer.toString(colend));
+ proArgs.add(output("A"));
+ proArgs.add(Integer.toString(rowstartC));
+ proArgs.add(Integer.toString(rowendC));
+ proArgs.add(Integer.toString(colstartC));
+ proArgs.add(Integer.toString(colendC));
+ proArgs.add(output("C"));
+ programArgs = proArgs.toArray(new String[proArgs.size()]);
+
+ fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME + ".dml";
+
+ ValueType[] schema = schemaMixedLarge;
+
+ //initialize the frame data.
+ List<ValueType> lschema = Arrays.asList(schema);
+
+ fullRScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME + ".R";
+ rCmd = "Rscript" + " " + fullRScriptName + " " +
+ inputDir() + " " + rowstart + " " + rowend + " " + colstart + " " + colend + " " + expectedDir()
+ + " " + rowstartC + " " + rowendC + " " + colstartC + " " + colendC;
+
+ double sparsity=sparsity1;//rand.nextDouble();
+ double[][] A = getRandomMatrix(rows, cols, min, max, sparsity, 1111 /*\\System.currentTimeMillis()*/);
+ writeInputFrameWithMTD("A", A, true, lschema, oinfo);
+
+ sparsity=sparsity2;//rand.nextDouble();
+ double[][] B = getRandomMatrix((int)(bRows), (int)(bCols), min, max, sparsity, 2345 /*System.currentTimeMillis()*/);
+ //Following way of creation causes serialization issue in frame processing
+ //List<ValueType> lschemaB = lschema.subList((int)colstart-1, (int)colend);
+ ValueType[] schemaB = new ValueType[bCols];
+ for (int i = 0; i < bCols; ++i)
+ schemaB[i] = schema[colstart-1+i];
+ List<ValueType> lschemaB = Arrays.asList(schemaB);
+ writeInputFrameWithMTD("B", B, true, lschemaB, oinfo);
+
+ ValueType[] schemaC = new ValueType[colendC-colstartC+1];
+ for (int i = 0; i < cCols; ++i)
+ schemaC[i] = schema[colstartC-1+i];
+
+ MLContext mlCtx = getMLContextForTesting();
+ SparkContext sc = mlCtx.getSparkContext();
+ JavaSparkContext jsc = new JavaSparkContext(sc);
+
+ DataFrame dfA = null, dfB = null;
+ if(bFromDataFrame)
+ {
+ //Create DataFrame for input A
+ SQLContext sqlContext = new SQLContext(sc);
+ StructType dfSchemaA = UtilFunctions.convertFrameSchemaToDFSchema(lschema);
+ JavaRDD<Row> rowRDDA = UtilFunctions.getRowRDD(jsc, input("A"), DataExpression.DEFAULT_DELIM_DELIMITER, lschema);
+ dfA = sqlContext.createDataFrame(rowRDDA, dfSchemaA);
+
+ //Create DataFrame for input B
+ StructType dfSchemaB = UtilFunctions.convertFrameSchemaToDFSchema(lschemaB);
+ JavaRDD<Row> rowRDDB = UtilFunctions.getRowRDD(jsc, input("B"), DataExpression.DEFAULT_DELIM_DELIMITER, lschemaB);
+ dfB = sqlContext.createDataFrame(rowRDDB, dfSchemaB);
+ }
+
+ try
+ {
+ mlCtx.reset(true); // Cleanup config to ensure future MLContext testcases have correct 'cp.parallel.matrixmult'
+
+ String format = "csv";
+ if(oinfo == OutputInfo.TextCellOutputInfo)
+ format = "text";
+
+ if(bFromDataFrame)
+ mlCtx.registerFrameInput("A", dfA, false);
+ else {
+ JavaRDD<String> aIn = jsc.textFile(input("A"));
+ mlCtx.registerInput("A", aIn, format, rows, cols, new CSVFileFormatProperties(), lschema);
+ }
+
+ if(bFromDataFrame)
+ mlCtx.registerFrameInput("B", dfB, false);
+ else {
+ JavaRDD<String> bIn = jsc.textFile(input("B"));
+ mlCtx.registerInput("B", bIn, format, bRows, bCols, new CSVFileFormatProperties(), lschemaB);
+ }
+
+ // Output one frame to HDFS and get one as RDD //TODO HDFS input/output to do
+ mlCtx.registerOutput("A");
+ mlCtx.registerOutput("C");
+
+ MLOutput out = mlCtx.execute(fullDMLScriptName, programArgs);
+
+ format = "csv";
+ if(iinfo == InputInfo.TextCellInputInfo)
+ format = "text";
+
+ String fName = output("AB");
+ try {
+ MapReduceTool.deleteFileIfExistOnHDFS( fName );
+ } catch (IOException e) {
+ throw new DMLRuntimeException("Error: While deleting file on HDFS");
+ }
+
+ if(!bToDataFrame)
+ {
+ JavaRDD<String> aOut = out.getStringFrameRDD("A", format, new CSVFileFormatProperties());
+ aOut.saveAsTextFile(fName);
+ } else {
+ DataFrame df = out.getDataFrameRDD("A", jsc);
+
+ //Convert back DataFrame to binary block for comparison using original binary to converted DF and back to binary
+ MatrixCharacteristics mc = new MatrixCharacteristics(rows, cols, -1, -1, -1);
+ JavaPairRDD<LongWritable, FrameBlock> rddOut = FrameRDDConverterUtils
+ .dataFrameToBinaryBlock(jsc, df, mc, false)
+ .mapToPair(new LongFrameToLongWritableFrameFunction());
+ rddOut.saveAsHadoopFile(output("AB"), LongWritable.class, FrameBlock.class, OutputInfo.BinaryBlockOutputInfo.outputFormatClass);
+ }
+
+ fName = output("C");
+ try {
+ MapReduceTool.deleteFileIfExistOnHDFS( fName );
+ } catch (IOException e) {
+ throw new DMLRuntimeException("Error: While deleting file on HDFS");
+ }
+ if(!bToDataFrame)
+ {
+ JavaRDD<String> aOut = out.getStringFrameRDD("C", format, new CSVFileFormatProperties());
+ aOut.saveAsTextFile(fName);
+ } else {
+ DataFrame df = out.getDataFrameRDD("C", jsc);
+
+ //Convert back DataFrame to binary block for comparison using original binary to converted DF and back to binary
+ MatrixCharacteristics mc = new MatrixCharacteristics(cRows, cCols, -1, -1, -1);
+ JavaPairRDD<LongWritable, FrameBlock> rddOut = FrameRDDConverterUtils
+ .dataFrameToBinaryBlock(jsc, df, mc, false)
+ .mapToPair(new LongFrameToLongWritableFrameFunction());
+ rddOut.saveAsHadoopFile(fName, LongWritable.class, FrameBlock.class, OutputInfo.BinaryBlockOutputInfo.outputFormatClass);
+ }
+
+ runRScript(true);
+
+ outputSchema.put("AB", schema);
+ outputMC.put("AB", new MatrixCharacteristics(rows, cols, -1, -1));
+ outputSchema.put("C", schemaC);
+ outputMC.put("C", new MatrixCharacteristics(cRows, cCols, -1, -1));
+
+ for(String file: config.getOutputFiles())
+ {
+ MatrixCharacteristics md = outputMC.get(file);
+ FrameBlock frameBlock = readDMLFrameFromHDFS(file, iinfo, md);
+ FrameBlock frameRBlock = readRFrameFromHDFS(file+".csv", InputInfo.CSVInputInfo, md);
+ ValueType[] schemaOut = outputSchema.get(file);
+ verifyFrameData(frameBlock, frameRBlock, schemaOut);
+ System.out.println("File " + file + " processed successfully.");
+ }
+
+ //cleanup mlcontext (prevent test memory leaks)
+ mlCtx.reset();
+
+ System.out.println("Frame MLContext test completed successfully.");
+ }
+ finally {
+ DMLScript.rtplatform = oldRT;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = oldConfig;
+ }
+ }
+
+ private void verifyFrameData(FrameBlock frame1, FrameBlock frame2, ValueType[] schema) {
+ for ( int i=0; i<frame1.getNumRows(); i++ )
+ for( int j=0; j<frame1.getNumColumns(); j++ ) {
+ Object val1 = UtilFunctions.stringToObject(schema[j], UtilFunctions.objectToString(frame1.get(i, j)));
+ Object val2 = UtilFunctions.stringToObject(schema[j], UtilFunctions.objectToString(frame2.get(i, j)));
+ if( TestUtils.compareToR(schema[j], val1, val2, epsilon) != 0)
+ Assert.fail("The DML data for cell ("+ i + "," + j + ") is " + val1 +
+ ", not same as the R value " + val2);
+ }
+ }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/02a9f277/src/test/scripts/functions/frame/FrameGeneral.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/frame/FrameGeneral.R b/src/test/scripts/functions/frame/FrameGeneral.R
new file mode 100644
index 0000000..079c74c
--- /dev/null
+++ b/src/test/scripts/functions/frame/FrameGeneral.R
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+# args[1] is the path prefix of the headerless input files A.csv and B.csv
+A=read.csv(paste(args[1], "A.csv", sep=""), header = FALSE, stringsAsFactors=FALSE)
+B=read.csv(paste(args[1], "B.csv", sep=""), header = FALSE, stringsAsFactors=FALSE)
+# Left indexing: rows args[2]:args[3], columns args[4]:args[5] of A are zeroed, then replaced by B
+A[args[2]:args[3],args[4]:args[5]]=0
+A[args[2]:args[3],args[4]:args[5]]=B
+write.csv(A, paste(args[6], "AB.csv", sep=""), row.names = FALSE, quote = FALSE)
+# Right indexing: C is rows args[7]:args[8], columns args[9]:args[10] of the updated A
+C=A[args[7]:args[8],args[9]:args[10]]
+write.csv(C, paste(args[6], "C.csv", sep=""), row.names = FALSE, quote = FALSE)
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/02a9f277/src/test/scripts/functions/frame/FrameGeneral.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/frame/FrameGeneral.dml b/src/test/scripts/functions/frame/FrameGeneral.dml
new file mode 100644
index 0000000..9d9a2f7
--- /dev/null
+++ b/src/test/scripts/functions/frame/FrameGeneral.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+#
+# Left Indexing test: frames A ($1, $2 x $3) and B ($4, $5 x $6); rows $7:$8, columns $9:$10 of A are replaced by B
+A=read($1, data_type="frame", rows=$2, cols=$3)
+B=read($4, data_type="frame", rows=$5, cols=$6)
+A[$7:$8,$9:$10]=B
+write(A, $11)
+
+# Right Indexing test: C is the block of rows $12:$13 and columns $14:$15 of the updated A, written to $16
+C=A[$12:$13,$14:$15]
+write(C, $16)