You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2020/08/10 20:41:14 UTC
[systemds] branch master updated: [SYSTEMDS-2601] Comparison
operators for frame-frame ops (CP, Spark)
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 98ea24d [SYSTEMDS-2601] Comparison operators for frame-frame ops (CP, Spark)
98ea24d is described below
commit 98ea24d9f7d713bdcc1c0898ee79eaa493b096b3
Author: Shafaq Siddiqi <sh...@tugraz.at>
AuthorDate: Mon Aug 10 21:55:31 2020 +0200
[SYSTEMDS-2601] Comparison operators for frame-frame ops (CP, Spark)
Closes #1009.
---
.../org/apache/sysds/parser/DMLTranslator.java | 4 +
.../apache/sysds/parser/RelationalExpression.java | 13 +-
.../cp/BinaryFrameFrameCPInstruction.java | 23 ++-
.../spark/BinaryFrameFrameSPInstruction.java | 44 +++++-
.../sysds/runtime/matrix/data/FrameBlock.java | 79 +++++++++
.../functions/binary/frame/FrameEqualTest.java | 176 +++++++++++++++++++++
.../functions/binary/frame/frameComparisonTest.R | 46 ++++++
.../functions/binary/frame/frameComparisonTest.dml | 43 +++++
8 files changed, 415 insertions(+), 13 deletions(-)
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index a9e5a6e..f84f469 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -1812,6 +1812,10 @@ public class DMLTranslator
target.setDataType(DataType.MATRIX);
target.setValueType(ValueType.FP64);
}
+ else if(left.getDataType() == DataType.FRAME || right.getDataType() == DataType.FRAME) {
+ target.setDataType(DataType.FRAME);
+ target.setValueType(ValueType.BOOLEAN);
+ }
else {
// Added to support scalar relational comparison
target.setDataType(DataType.SCALAR);
diff --git a/src/main/java/org/apache/sysds/parser/RelationalExpression.java b/src/main/java/org/apache/sysds/parser/RelationalExpression.java
index 5c19a18..f0b4695 100644
--- a/src/main/java/org/apache/sysds/parser/RelationalExpression.java
+++ b/src/main/java/org/apache/sysds/parser/RelationalExpression.java
@@ -140,7 +140,9 @@ public class RelationalExpression extends Expression
output.setParseInfo(this);
boolean isLeftMatrix = (_left.getOutput() != null && _left.getOutput().getDataType() == DataType.MATRIX);
- boolean isRightMatrix = (_right.getOutput() != null && _right.getOutput().getDataType() == DataType.MATRIX);
+ boolean isRightMatrix = (_right.getOutput() != null && _right.getOutput().getDataType() == DataType.MATRIX);
+ boolean isLeftFrame = (_left.getOutput() != null && _left.getOutput().getDataType() == DataType.FRAME);
+ boolean isRightFrame = (_right.getOutput() != null && _right.getOutput().getDataType() == DataType.FRAME);
if(isLeftMatrix || isRightMatrix) {
// Added to support matrix relational comparison
if(isLeftMatrix && isRightMatrix) {
@@ -155,6 +157,15 @@ public class RelationalExpression extends Expression
//double; once we support boolean matrices this needs to change
output.setValueType(ValueType.FP64);
}
+ else if(isLeftFrame && isRightFrame) {
+ output.setDataType(DataType.FRAME);
+ output.setDimensions(_left.getOutput().getDim1(), _left.getOutput().getDim2());
+ output.setValueType(ValueType.BOOLEAN);
+ }
+ else if( isLeftFrame || isRightFrame ) {
+ raiseValidateError("Unsupported relational expression for mixed types "
+ +_left.getOutput().getDataType().name()+" "+_right.getOutput().getDataType().name());
+ }
else {
output.setBooleanProperties();
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java
index 1116675..7968b18 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.instructions.cp;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
public class BinaryFrameFrameCPInstruction extends BinaryCPInstruction
@@ -32,16 +33,26 @@ public class BinaryFrameFrameCPInstruction extends BinaryCPInstruction
@Override
public void processInstruction(ExecutionContext ec) {
- // Read input matrices
+ // get input frames
FrameBlock inBlock1 = ec.getFrameInput(input1.getName());
FrameBlock inBlock2 = ec.getFrameInput(input2.getName());
-
- // Perform computation using input frames, and produce the result frame
- FrameBlock retBlock = inBlock1.dropInvalid(inBlock2);
+
+ if(getOpcode().equals("dropInvalidType")) {
+ // Perform computation using input frames, and produce the result frame
+ FrameBlock retBlock = inBlock1.dropInvalid(inBlock2);
+ // Attach result frame with FrameBlock associated with output_name
+ ec.setFrameOutput(output.getName(), retBlock);
+ }
+ else {
+ // Execute binary operations
+ BinaryOperator dop = (BinaryOperator) _optr;
+ FrameBlock outBlock = inBlock1.binaryOperations(dop, inBlock2, null);
+ // Attach result frame with FrameBlock associated with output_name
+ ec.setFrameOutput(output.getName(), outBlock);
+ }
+
// Release the memory occupied by input frames
ec.releaseFrameInput(input1.getName());
ec.releaseFrameInput(input2.getName());
- // Attach result frame with FrameBlock associated with output_name
- ec.setFrameOutput(output.getName(), retBlock);
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameFrameSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameFrameSPInstruction.java
index 263abf3..021ac84 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameFrameSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameFrameSPInstruction.java
@@ -29,7 +29,9 @@ import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
+import scala.Tuple2;
public class BinaryFrameFrameSPInstruction extends BinarySPInstruction {
protected BinaryFrameFrameSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
@@ -55,16 +57,33 @@ public class BinaryFrameFrameSPInstruction extends BinarySPInstruction {
@Override
public void processInstruction(ExecutionContext ec) {
SparkExecutionContext sec = (SparkExecutionContext)ec;
+
// Get input RDDs
JavaPairRDD<Long, FrameBlock> in1 = sec.getFrameBinaryBlockRDDHandleForVariable(input1.getName());
- // get schema frame-block
- Broadcast<FrameBlock> fb = sec.getSparkContext().broadcast(sec.getFrameInput(input2.getName()));
- JavaPairRDD<Long, FrameBlock> out = in1.mapValues(new isCorrectbySchema(fb.getValue()));
- //release input frame
- sec.releaseFrameInput(input2.getName());
- //set output RDD
+ JavaPairRDD<Long, FrameBlock> out = null;
+
+ if(getOpcode().equals("dropInvalidType")) {
+ // get schema frame-block
+ Broadcast<FrameBlock> fb = sec.getSparkContext().broadcast(sec.getFrameInput(input2.getName()));
+ out = in1.mapValues(new isCorrectbySchema(fb.getValue()));
+ //release input frame
+ sec.releaseFrameInput(input2.getName());
+ }
+ else {
+ JavaPairRDD<Long, FrameBlock> in2 = sec.getFrameBinaryBlockRDDHandleForVariable(input2.getName());
+ // create output frame
+ BinaryOperator dop = (BinaryOperator) _optr;
+ // check for binary operations
+ out = in1.join(in2).mapValues(new FrameComparison(dop));
+ }
+
+ //set output RDD and maintain dependencies
sec.setRDDHandleForVariable(output.getName(), out);
sec.addLineageRDD(output.getName(), input1.getName());
+ if( getOpcode().equals("dropInvalidType") )
+ sec.addLineageBroadcast(output.getName(), input2.getName());
+ else
+ sec.addLineageRDD(output.getName(), input2.getName());
}
private static class isCorrectbySchema implements Function<FrameBlock,FrameBlock> {
@@ -81,4 +100,17 @@ public class BinaryFrameFrameSPInstruction extends BinarySPInstruction {
return arg0.dropInvalid(schema_frame);
}
}
+
+	/**
+	 * Spark function applied to the values of two joined frame RDDs: it compares the
+	 * first frame block of each pair against the second one via
+	 * FrameBlock.binaryOperations, using the binary operator given at construction.
+	 */
+	private static class FrameComparison implements Function<Tuple2<FrameBlock, FrameBlock>, FrameBlock> {
+		private static final long serialVersionUID = 5850400295183766401L;
+
+		// operator handed through to FrameBlock.binaryOperations for every pair
+		private final BinaryOperator _op;
+
+		public FrameComparison(BinaryOperator op) {
+			this._op = op;
+		}
+
+		@Override
+		public FrameBlock call(Tuple2<FrameBlock, FrameBlock> pair) throws Exception {
+			final FrameBlock lhs = pair._1();
+			final FrameBlock rhs = pair._2();
+			// no pre-allocated output block is passed (third argument is null)
+			return lhs.binaryOperations(_op, rhs, null);
+		}
+	}
}
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
index 325819b..e473acd 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
@@ -41,7 +41,10 @@ import org.apache.sysds.api.DMLException;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
+import org.apache.sysds.runtime.functionobjects.ValueComparisonFunction;
+import org.apache.sysds.runtime.instructions.cp.*;
import org.apache.sysds.runtime.io.IOUtilFunctions;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.transform.encode.EncoderRecode;
import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.runtime.util.UtilFunctions;
@@ -277,6 +280,7 @@ public class FrameBlock implements CacheBlock, Externalizable
case BOOLEAN: _coldata[j] = new BooleanArray(new boolean[numRows]); break;
case INT32: _coldata[j] = new IntegerArray(new int[numRows]); break;
case INT64: _coldata[j] = new LongArray(new long[numRows]); break;
+ case FP32: _coldata[j] = new FloatArray(new float[numRows]); break;
case FP64: _coldata[j] = new DoubleArray(new double[numRows]); break;
default: throw new RuntimeException("Unsupported value type: "+_schema[j]);
}
@@ -702,6 +706,8 @@ public class FrameBlock implements CacheBlock, Externalizable
case BOOLEAN: arr = new BooleanArray(new boolean[_numRows]); break;
case INT64: arr = new LongArray(new long[_numRows]); break;
case FP64: arr = new DoubleArray(new double[_numRows]); break;
+ case INT32: arr = new IntegerArray(new int[_numRows]); break;
+ case FP32: arr = new FloatArray(new float[_numRows]); break;
default: throw new IOException("Unsupported value type: "+vt);
}
arr.readFields(in);
@@ -837,6 +843,79 @@ public class FrameBlock implements CacheBlock, Externalizable
+ 32 + value.length(); //char array
}
+	/**
+	 * Performs a cell-wise value comparison of this frame with another frame of the same
+	 * dimensions (equal, not equal, less than, greater than, less than or equal,
+	 * greater than or equal) and stores the boolean result of each comparison in a new frame.
+	 *
+	 * The comparison type is chosen per column from the two column schemas: if either side
+	 * is STRING the cells are compared as strings; otherwise, if either side is FP64/FP32,
+	 * as doubles; otherwise, if either side is INT64/INT32, as longs; otherwise as booleans.
+	 * Cells where at least one side is null are resolved by checkAndSetEmpty and are not
+	 * passed to the comparison function.
+	 *
+	 * @param bop binary operator, its function must be a ValueComparisonFunction
+	 * @param that frame block of rhs of m * n dimensions
+	 * @param out output frame block (not used by this implementation, a new frame is returned)
+	 * @return a boolean frameBlock of m * n dimensions
+	 * @throws DMLRuntimeException if the dimensions of the two frames differ or the
+	 *         operator is not a value comparison
+	 */
+	public FrameBlock binaryOperations(BinaryOperator bop, FrameBlock that, FrameBlock out) {
+		//cell-wise semantics: both the number of columns and the number of rows must match
+		if(getNumColumns() != that.getNumColumns() || getNumRows() != that.getNumRows())
+			throw new DMLRuntimeException("Frame dimension mismatch "+getNumRows()+" * "+getNumColumns()+
+				" != "+that.getNumRows()+" * "+that.getNumColumns());
+		if( !(bop.fn instanceof ValueComparisonFunction) )
+			throw new DMLRuntimeException("Unsupported binary operation on frames (only comparisons supported)");
+		ValueComparisonFunction vcomp = (ValueComparisonFunction) bop.fn;
+
+		final int nrow = getNumRows();
+		final int ncol = getNumColumns();
+		String[][] outputData = new String[nrow][ncol];
+
+		//compare output value, incl implicit type promotion if necessary
+		for (int i = 0; i < ncol; i++) {
+			final ValueType vt1 = getSchema()[i];
+			final ValueType vt2 = that.getSchema()[i];
+
+			if (vt1 == ValueType.STRING || vt2 == ValueType.STRING) {
+				for (int j = 0; j < nrow; j++) {
+					if(checkAndSetEmpty(this, that, outputData, j, i))
+						continue;
+					String v1 = UtilFunctions.objectToString(get(j, i));
+					String v2 = UtilFunctions.objectToString(that.get(j, i));
+					outputData[j][i] = String.valueOf(vcomp.compare(v1, v2));
+				}
+			}
+			else if (vt1 == ValueType.FP64 || vt2 == ValueType.FP64 ||
+				vt1 == ValueType.FP32 || vt2 == ValueType.FP32) {
+				for (int j = 0; j < nrow; j++) {
+					if(checkAndSetEmpty(this, that, outputData, j, i))
+						continue;
+					double v1 = Double.parseDouble(get(j, i).toString());
+					double v2 = Double.parseDouble(that.get(j, i).toString());
+					outputData[j][i] = String.valueOf(vcomp.compare(v1, v2));
+				}
+			}
+			else if (vt1 == ValueType.INT64 || vt2 == ValueType.INT64 ||
+				vt1 == ValueType.INT32 || vt2 == ValueType.INT32) {
+				for (int j = 0; j < nrow; j++) {
+					if(checkAndSetEmpty(this, that, outputData, j, i))
+						continue;
+					//parse as long: INT64 cells may exceed the range accepted by Integer.parseInt
+					//NOTE(review): a BOOLEAN column on the other side still fails to parse here - confirm intended
+					long v1 = Long.parseLong(get(j, i).toString());
+					long v2 = Long.parseLong(that.get(j, i).toString());
+					outputData[j][i] = String.valueOf(vcomp.compare(v1, v2));
+				}
+			}
+			else {
+				for (int j = 0; j < nrow; j++) {
+					if(checkAndSetEmpty(this, that, outputData, j, i))
+						continue;
+					boolean v1 = Boolean.parseBoolean(get(j, i).toString());
+					boolean v2 = Boolean.parseBoolean(that.get(j, i).toString());
+					outputData[j][i] = String.valueOf(vcomp.compare(v1, v2));
+				}
+			}
+		}
+
+		return new FrameBlock(UtilFunctions.nCopies(ncol, ValueType.BOOLEAN), outputData);
+	}
+
+ private static boolean checkAndSetEmpty(FrameBlock fb1, FrameBlock fb2, String[][] out, int r, int c) {
+ if(fb1.get(r, c) == null || fb2.get(r, c) == null) {
+ out[r][c] = (fb1.get(r, c) == null && fb2.get(r, c) == null) ? "true" : "false";
+ return true;
+ }
+ return false;
+ }
+
///////
// indexing and append operations
diff --git a/src/test/java/org/apache/sysds/test/functions/binary/frame/FrameEqualTest.java b/src/test/java/org/apache/sysds/test/functions/binary/frame/FrameEqualTest.java
new file mode 100644
index 0000000..cdb8999
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/binary/frame/FrameEqualTest.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.binary.frame;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+/**
+ * Tests the frame-frame comparison operators (>, <, ==, !=, >=, <=) by running
+ * frameComparisonTest.dml in CP and Spark execution mode and comparing its output
+ * against the result of the R script frameComparisonTest.R.
+ */
+public class FrameEqualTest extends AutomatedTestBase {
+	private final static String TEST_NAME = "frameComparisonTest";
+	private final static String TEST_DIR = "functions/binary/frame/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + FrameEqualTest.class.getSimpleName() + "/";
+
+	private final static int rows = 100;
+	// column schemas of the two input frames; columns at the same position are compared with each other
+	private final static Types.ValueType[] schemaStrings1 = {Types.ValueType.FP64, Types.ValueType.BOOLEAN, Types.ValueType.INT64, Types.ValueType.STRING, Types.ValueType.STRING, Types.ValueType.FP64};
+	private final static Types.ValueType[] schemaStrings2 = {Types.ValueType.INT64, Types.ValueType.BOOLEAN, Types.ValueType.FP32, Types.ValueType.FP64, Types.ValueType.STRING, Types.ValueType.FP32};
+
+	/** Comparison to run; the constant name is passed verbatim to the DML and R scripts. */
+	public enum TestType {
+		GREATER, LESS, EQUALS, NOT_EQUALS, GREATER_EQUALS, LESS_EQUALS,
+	}
+
+	@BeforeClass
+	public static void init() {
+		TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+	}
+
+	@AfterClass
+	public static void cleanUp() {
+		if (TEST_CACHE_ENABLED) {
+			TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+		}
+	}
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		// NOTE(review): the configuration registers output "D" while the test writes and reads "C" - confirm intended
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"D"}));
+		if (TEST_CACHE_ENABLED) {
+			setOutAndExpectedDeletionDisabled(true);
+		}
+	}
+
+	@Test
+	public void testFrameEqualCP() {
+		runComparisonTest(schemaStrings1, schemaStrings2, rows, schemaStrings1.length, TestType.EQUALS, ExecType.CP);
+	}
+
+	@Test
+	public void testFrameEqualSpark() {
+		runComparisonTest(schemaStrings1, schemaStrings2, rows, schemaStrings1.length, TestType.EQUALS, ExecType.SPARK);
+	}
+
+	@Test
+	public void testFrameNotEqualCP() {
+		runComparisonTest(schemaStrings1, schemaStrings2, rows, schemaStrings1.length, TestType.NOT_EQUALS, ExecType.CP);
+	}
+
+	@Test
+	public void testFrameNotEqualSpark() {
+		runComparisonTest(schemaStrings1, schemaStrings2, rows, schemaStrings1.length, TestType.NOT_EQUALS, ExecType.SPARK);
+	}
+
+	@Test
+	public void testFrameLessThanCP() {
+		runComparisonTest(schemaStrings1, schemaStrings2, rows, schemaStrings1.length, TestType.LESS, ExecType.CP);
+	}
+
+	@Test
+	public void testFrameLessThanSpark() {
+		runComparisonTest(schemaStrings1, schemaStrings2, rows, schemaStrings1.length, TestType.LESS, ExecType.SPARK);
+	}
+
+	@Test
+	public void testFrameGreaterEqualsCP() {
+		runComparisonTest(schemaStrings1, schemaStrings2, rows, schemaStrings1.length, TestType.GREATER_EQUALS, ExecType.CP);
+	}
+
+	@Test
+	public void testFrameGreaterEqualsSpark() {
+		runComparisonTest(schemaStrings1, schemaStrings2, rows, schemaStrings1.length, TestType.GREATER_EQUALS, ExecType.SPARK);
+	}
+
+	@Test
+	public void testFrameLessEqualsCP() {
+		runComparisonTest(schemaStrings1, schemaStrings2, rows, schemaStrings1.length, TestType.LESS_EQUALS, ExecType.CP);
+	}
+
+	@Test
+	public void testFrameLessEqualsSpark() {
+		runComparisonTest(schemaStrings1, schemaStrings2, rows, schemaStrings1.length, TestType.LESS_EQUALS, ExecType.SPARK);
+	}
+
+	@Test
+	public void testFrameGreaterThanCP() {
+		runComparisonTest(schemaStrings1, schemaStrings2, rows, schemaStrings1.length, TestType.GREATER, ExecType.CP);
+	}
+
+	@Test
+	public void testFrameGreaterThanSpark() {
+		runComparisonTest(schemaStrings1, schemaStrings2, rows, schemaStrings1.length, TestType.GREATER, ExecType.SPARK);
+	}
+
+	/**
+	 * Writes two random input frames, runs the DML script and the R script for the given
+	 * comparison, and checks that both produce the same result matrix "C".
+	 *
+	 * @param schema1 column schema used to write input frame A
+	 * @param schema2 column schema used to write input frame B
+	 * @param rows number of rows of both inputs
+	 * @param cols number of columns of both inputs
+	 * @param type comparison operator under test
+	 * @param et execution type (CP or SPARK)
+	 */
+	private void runComparisonTest(Types.ValueType[] schema1, Types.ValueType[] schema2, int rows, int cols,
+		TestType type, ExecType et)
+	{
+		// remember global settings so they can be restored in the finally block
+		Types.ExecMode platformOld = setExecMode(et);
+		boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+
+		try {
+			getAndLoadTestConfiguration(TEST_NAME);
+
+			// presumably (rows, cols, min, max, sparsity, seed) - verify against AutomatedTestBase
+			double[][] A = getRandomMatrix(rows, cols, 2, 3, 1, 2);
+			double[][] B = getRandomMatrix(rows, cols, 10, 20, 1, 0);
+
+			// use the schemas handed in by the caller rather than the class-level defaults
+			writeInputFrameWithMTD("A", A, true, schema1, FileFormat.BINARY);
+			writeInputFrameWithMTD("B", B, true, schema2, FileFormat.BINARY);
+
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[] {"-explain", "recompile_runtime", "-nvargs", "A=" + input("A"), "B=" + input("B"),
+				"rows=" + String.valueOf(rows), "cols=" + Integer.toString(cols), "type=" + String.valueOf(type), "C=" + output("C")};
+
+			fullRScriptName = HOME + TEST_NAME + ".R";
+			rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + String.valueOf(type) + " " + expectedDir();
+
+			runTest(true, false, null, -1);
+			runRScript(true);
+
+			//compare matrices
+			HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("C");
+			HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromFS("C");
+
+			double eps = 0.0001;
+			TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+		}
+		catch (Exception ex) {
+			throw new RuntimeException(ex);
+		}
+		finally {
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+			// NOTE(review): these two flags are set to true although this method never changes them - confirm intended
+			OptimizerUtils.ALLOW_AUTO_VECTORIZATION = true;
+			OptimizerUtils.ALLOW_OPERATOR_FUSION = true;
+		}
+	}
+}
diff --git a/src/test/scripts/functions/binary/frame/frameComparisonTest.R b/src/test/scripts/functions/binary/frame/frameComparisonTest.R
new file mode 100644
index 0000000..c931d7e
--- /dev/null
+++ b/src/test/scripts/functions/binary/frame/frameComparisonTest.R
@@ -0,0 +1,46 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Reference implementation for the frame comparison test.
+# args[1]: path prefix of the inputs ("A.csv" and "B.csv" are appended to it)
+# args[2]: name of the comparison to run (GREATER, LESS, EQUALS, NOT_EQUALS, GREATER_EQUALS, LESS_EQUALS)
+# args[3]: path prefix of the output ("C" is appended to it)
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+# read both inputs as data frames without a header row, keeping strings as character columns
+A=read.csv(paste(args[1], "A.csv", sep=""), header = FALSE, stringsAsFactors=FALSE)
+B=read.csv(paste(args[1], "B.csv", sep=""), header = FALSE, stringsAsFactors=FALSE)
+
+test = args[2]
+
+# element-wise comparison of A and B; C is left undefined if the name matches none of the cases
+if(test == "GREATER")
+{
+C = A > B
+} else if (test == "LESS") {
+C = A < B
+} else if (test == "EQUALS") {
+C = A == B
+} else if (test == "NOT_EQUALS") {
+C = A != B
+} else if(test == "GREATER_EQUALS") {
+C = A >= B
+} else if(test == "LESS_EQUALS") {
+C = A <= B
+}
+
+# write the comparison result as a sparse matrix in MatrixMarket format
+writeMM(as(C, "CsparseMatrix"), paste(args[3], "C", sep=""));
\ No newline at end of file
diff --git a/src/test/scripts/functions/binary/frame/frameComparisonTest.dml b/src/test/scripts/functions/binary/frame/frameComparisonTest.dml
new file mode 100644
index 0000000..c43a614
--- /dev/null
+++ b/src/test/scripts/functions/binary/frame/frameComparisonTest.dml
@@ -0,0 +1,43 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Frame comparison test script.
+# Named arguments: $A, $B (input frame paths), $rows, $cols (input dimensions),
+# $type (name of the comparison to run), $C (output path).
+A = read($A, rows=$rows, cols=$cols, data_type="frame", format="binary", header=FALSE);
+B = read($B, rows=$rows, cols=$cols, data_type="frame", format="binary", header=FALSE);
+
+test = $type
+
+# apply the selected frame-frame comparison operator to A and B
+if(test == "GREATER")
+ C = A > B
+else if (test == "LESS")
+ C = A < B
+else if (test == "EQUALS")
+ C = A == B
+else if (test == "NOT_EQUALS")
+ C = A != B
+else if (test == "GREATER_EQUALS")
+ C = A >= B
+else if (test == "LESS_EQUALS")
+ C = A <= B
+
+# convert the comparison result frame to a matrix before writing it
+C = as.matrix(C)
+# print("this is C "+toString(C))
+
+write(C, $C);
\ No newline at end of file