You are viewing a plain text version of this content. The canonical link for it is here (the link itself is not preserved in this plain-text copy).
Posted to commits@systemds.apache.org by ba...@apache.org on 2021/07/12 17:16:50 UTC

[systemds] branch master updated: [SYSTEMDS-3055] Frame replace support

This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 282e20d  [SYSTEMDS-3055] Frame replace support
282e20d is described below

commit 282e20d73fadc98b4afef1959870f728072b8418
Author: baunsgaard <ba...@tugraz.at>
AuthorDate: Mon Jul 12 18:12:02 2021 +0200

    [SYSTEMDS-3055] Frame replace support
    
    Add support for replace on frames, for both CP and SP instructions.
    Simply provide a frame as target, plus a string pattern and replacement:
    
    X = replace(target=X, pattern="REPLACE_ME", replacement="SOMETHING_ELSE")
    
    Closes #1344
---
 .../ParameterizedBuiltinFunctionExpression.java    | 11 ++-
 .../cp/ParameterizedBuiltinCPInstruction.java      | 22 ++++--
 .../spark/ParameterizedBuiltinSPInstruction.java   | 65 ++++++++++-----
 .../sysds/runtime/matrix/data/FrameBlock.java      | 13 +++
 .../test/functions/frame/FrameReplaceTest.java     | 92 ++++++++++++++++++++++
 src/test/scripts/functions/frame/ReplaceTest.dml   | 28 +++++++
 6 files changed, 204 insertions(+), 27 deletions(-)

diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index ec731e6..d074d0d 100644
--- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -478,7 +478,9 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
 	private void validateReplace(DataIdentifier output, boolean conditional) {
 		//check existence and correctness of arguments
 		Expression target = getVarParam("target");
-		checkTargetParam(target, conditional);
+		//validate non-frame targets; test null first so a missing target does not NPE here
+		if( target == null || target.getOutput().getDataType() != DataType.FRAME )
+			checkTargetParam(target, conditional);
 		
 		Expression pattern = getVarParam("pattern");
 		if( pattern==null ) {
@@ -497,8 +499,11 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
 		}	
 		
 		// Output is a matrix with same dims as input
-		output.setDataType(DataType.MATRIX);
-		output.setValueType(ValueType.FP64);
+		// (frame targets produce a STRING-valued frame of the same dims instead)
+		DataType outType = target.getOutput().getDataType();
+		output.setDataType(outType);
+		output.setValueType(outType == DataType.FRAME ?
+			ValueType.STRING : ValueType.FP64);
 		output.setDimensions(target.getOutput().getDim1(), target.getOutput().getDim2());
 	}
 
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index 54a5339..f115b52 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -225,12 +225,22 @@ public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction
 				ec.releaseMatrixInput(params.get("select"));
 		}
 		else if(opcode.equalsIgnoreCase("replace")) {
-			MatrixBlock target = ec.getMatrixInput(params.get("target"));
-			double pattern = Double.parseDouble(params.get("pattern"));
-			double replacement = Double.parseDouble(params.get("replacement"));
-			MatrixBlock ret = target.replaceOperations(new MatrixBlock(), pattern, replacement);
-			ec.setMatrixOutput(output.getName(), ret);
-			ec.releaseMatrixInput(params.get("target"));
+			String targetName = params.get("target");
+			if(ec.isFrameObject(targetName)) {
+				// frames: string pattern/replacement, see FrameBlock.replaceOperations
+				FrameBlock target = ec.getFrameInput(targetName);
+				FrameBlock ret = target.replaceOperations(params.get("pattern"), params.get("replacement"));
+				ec.setFrameOutput(output.getName(), ret);
+				ec.releaseFrameInput(targetName);
+			}
+			else {
+				MatrixBlock target = ec.getMatrixInput(targetName);
+				double pattern = Double.parseDouble(params.get("pattern"));
+				double replacement = Double.parseDouble(params.get("replacement"));
+				MatrixBlock ret = target.replaceOperations(new MatrixBlock(), pattern, replacement);
+				ec.setMatrixOutput(output.getName(), ret);
+				ec.releaseMatrixInput(targetName);
+			}
 		}
 		else if(opcode.equals("lowertri") || opcode.equals("uppertri")) {
 			MatrixBlock target = ec.getMatrixInput(params.get("target"));
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
index 9975925..40e152f 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
@@ -358,25 +358,38 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction
 			}
 		}
 		else if(opcode.equalsIgnoreCase("replace")) {
-			JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec
-				.getBinaryMatrixBlockRDDHandleForVariable(params.get("target"));
-			DataCharacteristics mcIn = sec.getDataCharacteristics(params.get("target"));
-
-			// execute replace operation
-			double pattern = Double.parseDouble(params.get("pattern"));
-			double replacement = Double.parseDouble(params.get("replacement"));
-			JavaPairRDD<MatrixIndexes, MatrixBlock> out = in1.mapValues(new RDDReplaceFunction(pattern, replacement));
-
-			// store output rdd handle
-			sec.setRDDHandleForVariable(output.getName(), out);
-			sec.addLineageRDD(output.getName(), params.get("target"));
+			String targetName = params.get("target");
+			if(sec.isFrameObject(targetName)) {
+				// frame replace: string pattern and replacement, applied to each frame block
+				JavaPairRDD<Long, FrameBlock> in1 = sec.getFrameBinaryBlockRDDHandleForVariable(targetName);
+				DataCharacteristics mcIn = sec.getDataCharacteristics(targetName);
+				String pattern = params.get("pattern");
+				String replacement = params.get("replacement");
+				JavaPairRDD<Long, FrameBlock> out = in1.mapValues(new RDDFrameReplaceFunction(pattern, replacement));
+				sec.setRDDHandleForVariable(output.getName(), out);
+				sec.addLineageRDD(output.getName(), targetName);
+				// output keeps the dimensions (and nnz metadata) of the input
+				sec.getDataCharacteristics(output.getName()).set(mcIn.getRows(), mcIn.getCols(),
+					mcIn.getBlocksize(), mcIn.getNonZeros());
+			}
+			else {
+				JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryMatrixBlockRDDHandleForVariable(targetName);
+				DataCharacteristics mcIn = sec.getDataCharacteristics(targetName);
+
+				// execute replace operation
+				double pattern = Double.parseDouble(params.get("pattern"));
+				double replacement = Double.parseDouble(params.get("replacement"));
+				JavaPairRDD<MatrixIndexes, MatrixBlock> out = in1.mapValues(new RDDReplaceFunction(pattern, replacement));
+
+				// store output rdd handle
+				sec.setRDDHandleForVariable(output.getName(), out);
+				sec.addLineageRDD(output.getName(), targetName);
+
+				// update output statistics (required for correctness)
+				sec.getDataCharacteristics(output.getName()).set(mcIn.getRows(), mcIn.getCols(),
+					mcIn.getBlocksize(), (pattern != 0 && replacement != 0) ? mcIn.getNonZeros() : -1);
+			}
 
-			// update output statistics (required for correctness)
-			DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName());
-			mcOut.set(mcIn.getRows(),
-				mcIn.getCols(),
-				mcIn.getBlocksize(),
-				(pattern != 0 && replacement != 0) ? mcIn.getNonZeros() : -1);
 		}
 		else if(opcode.equalsIgnoreCase("lowertri") || opcode.equalsIgnoreCase("uppertri")) {
 			JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec
@@ -544,6 +557,22 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction
 		}
 	}
 
+	/** Spark function applying {@link FrameBlock#replaceOperations(String, String)} to each frame block. */
+	public static class RDDFrameReplaceFunction implements Function<FrameBlock, FrameBlock> {
+		private static final long serialVersionUID = 6576713401901671660L;
+		private final String _pattern;
+		private final String _replacement;
+		public RDDFrameReplaceFunction(String pattern, String replacement) {
+			_pattern = pattern;
+			_replacement = replacement;
+		}
+
+		@Override
+		public FrameBlock call(FrameBlock blk) {
+			return blk.replaceOperations(_pattern, _replacement);
+		}
+	}
+
 	private static class RDDExtractTriangularFunction
 		implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> {
 		private static final long serialVersionUID = 2754868819184155702L;
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
index 322cfad..8ee6f33 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
@@ -2237,4 +2237,17 @@ public class FrameBlock implements CacheBlock, Externalizable  {
 		public String apply(String input) {return null;}
 		public String apply(String input1, String input2) {	return null;}
 	}
+
+	/** Returns {@code new FrameBlock(this)} with every non-null cell equal to the pattern set to the replacement. */
+	public FrameBlock replaceOperations(String pattern, String replacement) {
+		FrameBlock ret = new FrameBlock(this);
+		for(int c = 0; c < ret.getNumColumns(); c++) {
+			Array col = ret._coldata[c];
+			for(int r = 0; r < col._size; r++) {
+				Object val = col.get(r);
+				if(val != null && val.equals(pattern)) col.set(r, replacement);
+			}
+		}
+		return ret;
+	}
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/frame/FrameReplaceTest.java b/src/test/java/org/apache/sysds/test/functions/frame/FrameReplaceTest.java
new file mode 100644
index 0000000..73868e3
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/frame/FrameReplaceTest.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.frame;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+/** Tests replace() on frames, executed in hybrid (CP) and Spark mode. */
+public class FrameReplaceTest extends AutomatedTestBase {
+	private final static String TEST_DIR = "functions/frame/";
+	private final static String TEST_NAME = "ReplaceTest";
+	private final static String TEST_CLASS_DIR = TEST_DIR + FrameReplaceTest.class.getSimpleName() + "/";
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+	}
+
+	@Test
+	public void testReplaceCP() {
+		runReplaceTest(ExecType.CP);
+	}
+
+	@Test
+	public void testReplaceSpark() {
+		runReplaceTest(ExecType.SPARK);
+	}
+
+	private void runReplaceTest(ExecType et) {
+		ExecMode platformOld = rtplatform;
+		switch(et) {
+			case SPARK:
+				rtplatform = ExecMode.SPARK;
+				break;
+			default:
+				rtplatform = ExecMode.HYBRID;
+				break;
+		}
+
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		if(rtplatform == ExecMode.SPARK || rtplatform == ExecMode.HYBRID)
+			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+		try {
+			// setup testcase
+			getAndLoadTestConfiguration(TEST_NAME);
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[] {};
+
+			// run test; the script replaces "north", "east" and "west" with "south"
+			String out = runTest(null).toString();
+
+			assertTrue("replacement value missing from output", out.contains("south"));
+			assertFalse("pattern value still present in output", out.contains("north"));
+		}
+		catch(Exception ex) {
+			throw new RuntimeException(ex);
+		}
+		finally {
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	}
+
+}
diff --git a/src/test/scripts/functions/frame/ReplaceTest.dml b/src/test/scripts/functions/frame/ReplaceTest.dml
new file mode 100644
index 0000000..2a12b48
--- /dev/null
+++ b/src/test/scripts/functions/frame/ReplaceTest.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read("src/test/resources/datasets/homes/homes.csv")
+
+X = replace(target = X, pattern="north", replacement="south")
+X = replace(target = X, pattern="east", replacement="south")
+X = replace(target = X, pattern="west", replacement="south")
+
+print(toString(X))