You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ba...@apache.org on 2021/07/12 17:16:50 UTC
[systemds] branch master updated: [SYSTEMDS-3055] Frame replace support
This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 282e20d [SYSTEMDS-3055] Frame replace support
282e20d is described below
commit 282e20d73fadc98b4afef1959870f728072b8418
Author: baunsgaard <ba...@tugraz.at>
AuthorDate: Mon Jul 12 18:12:02 2021 +0200
[SYSTEMDS-3055] Frame replace support
Add support for replace on a frame both for CP and SP instructions.
Simply provide a frame as the target, together with a string pattern and a string replacement:
X = replace(target=X, pattern ="REPLACE_ME", replacement = "SOMETHING_ELSE")
Closes #1344
---
.../ParameterizedBuiltinFunctionExpression.java | 11 ++-
.../cp/ParameterizedBuiltinCPInstruction.java | 22 ++++--
.../spark/ParameterizedBuiltinSPInstruction.java | 65 ++++++++++-----
.../sysds/runtime/matrix/data/FrameBlock.java | 13 +++
.../test/functions/frame/FrameReplaceTest.java | 92 ++++++++++++++++++++++
src/test/scripts/functions/frame/ReplaceTest.dml | 28 +++++++
6 files changed, 204 insertions(+), 27 deletions(-)
diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index ec731e6..d074d0d 100644
--- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -478,7 +478,9 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
private void validateReplace(DataIdentifier output, boolean conditional) {
//check existence and correctness of arguments
Expression target = getVarParam("target");
- checkTargetParam(target, conditional);
+ if( target.getOutput().getDataType() != DataType.FRAME ){
+ checkTargetParam(target, conditional);
+ }
Expression pattern = getVarParam("pattern");
if( pattern==null ) {
@@ -497,8 +499,11 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
}
// Output is a matrix with same dims as input
- output.setDataType(DataType.MATRIX);
- output.setValueType(ValueType.FP64);
+ output.setDataType(target.getOutput().getDataType());
+ if(target.getOutput().getDataType() == DataType.FRAME)
+ output.setValueType(ValueType.STRING);
+ else
+ output.setValueType(ValueType.FP64);
output.setDimensions(target.getOutput().getDim1(), target.getOutput().getDim2());
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index 54a5339..f115b52 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -225,12 +225,22 @@ public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction
ec.releaseMatrixInput(params.get("select"));
}
else if(opcode.equalsIgnoreCase("replace")) {
- MatrixBlock target = ec.getMatrixInput(params.get("target"));
- double pattern = Double.parseDouble(params.get("pattern"));
- double replacement = Double.parseDouble(params.get("replacement"));
- MatrixBlock ret = target.replaceOperations(new MatrixBlock(), pattern, replacement);
- ec.setMatrixOutput(output.getName(), ret);
- ec.releaseMatrixInput(params.get("target"));
+ if(ec.isFrameObject(params.get("target"))){
+ FrameBlock target = ec.getFrameInput(params.get("target"));
+ String pattern = params.get("pattern");
+ String replacement = params.get("replacement");
+ FrameBlock ret = target.replaceOperations(pattern, replacement);
+ ec.setFrameOutput(output.getName(), ret);
+ ec.releaseFrameInput(params.get("target"));
+ }else{
+ MatrixBlock target = ec.getMatrixInput(params.get("target"));
+ double pattern = Double.parseDouble(params.get("pattern"));
+ double replacement = Double.parseDouble(params.get("replacement"));
+ MatrixBlock ret = target.replaceOperations(new MatrixBlock(), pattern, replacement);
+ ec.setMatrixOutput(output.getName(), ret);
+ ec.releaseMatrixInput(params.get("target"));
+ }
+
}
else if(opcode.equals("lowertri") || opcode.equals("uppertri")) {
MatrixBlock target = ec.getMatrixInput(params.get("target"));
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
index 9975925..40e152f 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
@@ -358,25 +358,38 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction
}
}
else if(opcode.equalsIgnoreCase("replace")) {
- JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec
- .getBinaryMatrixBlockRDDHandleForVariable(params.get("target"));
- DataCharacteristics mcIn = sec.getDataCharacteristics(params.get("target"));
-
- // execute replace operation
- double pattern = Double.parseDouble(params.get("pattern"));
- double replacement = Double.parseDouble(params.get("replacement"));
- JavaPairRDD<MatrixIndexes, MatrixBlock> out = in1.mapValues(new RDDReplaceFunction(pattern, replacement));
-
- // store output rdd handle
- sec.setRDDHandleForVariable(output.getName(), out);
- sec.addLineageRDD(output.getName(), params.get("target"));
+ if(sec.isFrameObject(params.get("target"))){
+ params.get("target");
+ JavaPairRDD<Long, FrameBlock> in1 = sec.getFrameBinaryBlockRDDHandleForVariable(params.get("target"));
+ DataCharacteristics mcIn = sec.getDataCharacteristics(params.get("target"));
+ String pattern = params.get("pattern");
+ String replacement = params.get("replacement");
+ JavaPairRDD<Long, FrameBlock> out = in1.mapValues(new RDDFrameReplaceFunction(pattern, replacement));
+ sec.setRDDHandleForVariable(output.getName(), out);
+ sec.addLineageRDD(output.getName(), params.get("target"));
+ sec.getDataCharacteristics(output.getName()).set(mcIn.getRows(), mcIn.getCols(), mcIn.getBlocksize(), mcIn.getNonZeros());
+ }
+ else {
+ JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec
+ .getBinaryMatrixBlockRDDHandleForVariable(params.get("target"));
+ DataCharacteristics mcIn = sec.getDataCharacteristics(params.get("target"));
+
+ // execute replace operation
+ double pattern = Double.parseDouble(params.get("pattern"));
+ double replacement = Double.parseDouble(params.get("replacement"));
+ JavaPairRDD<MatrixIndexes, MatrixBlock> out = in1.mapValues(new RDDReplaceFunction(pattern, replacement));
+
+ // store output rdd handle
+ sec.setRDDHandleForVariable(output.getName(), out);
+ sec.addLineageRDD(output.getName(), params.get("target"));
+
+ // update output statistics (required for correctness)
+ sec.getDataCharacteristics(output.getName()).set(mcIn.getRows(),
+ mcIn.getCols(),
+ mcIn.getBlocksize(),
+ (pattern != 0 && replacement != 0) ? mcIn.getNonZeros() : -1);
+ }
- // update output statistics (required for correctness)
- DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName());
- mcOut.set(mcIn.getRows(),
- mcIn.getCols(),
- mcIn.getBlocksize(),
- (pattern != 0 && replacement != 0) ? mcIn.getNonZeros() : -1);
}
else if(opcode.equalsIgnoreCase("lowertri") || opcode.equalsIgnoreCase("uppertri")) {
JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec
@@ -544,6 +557,22 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction
}
}
+ public static class RDDFrameReplaceFunction implements Function<FrameBlock, FrameBlock>{
+ private static final long serialVersionUID = 6576713401901671660L;
+ private final String _pattern;
+ private final String _replacement;
+
+ public RDDFrameReplaceFunction(String pattern, String replacement){
+ _pattern = pattern;
+ _replacement = replacement;
+ }
+
+ @Override
+ public FrameBlock call(FrameBlock arg0){
+ return arg0.replaceOperations(_pattern, _replacement);
+ }
+ }
+
private static class RDDExtractTriangularFunction
implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> {
private static final long serialVersionUID = 2754868819184155702L;
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
index 322cfad..8ee6f33 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
@@ -2237,4 +2237,17 @@ public class FrameBlock implements CacheBlock, Externalizable {
public String apply(String input) {return null;}
public String apply(String input1, String input2) { return null;}
}
+
+ public FrameBlock replaceOperations(String pattern, String replacement){
+ FrameBlock ret = new FrameBlock(this);
+ for(int i = 0; i < ret.getNumColumns(); i++){
+ Array colData = ret._coldata[i];
+ for(int j = 0; j < colData._size; j++){
+ Object ent = colData.get(j);
+ if(ent != null && ent.equals(pattern))
+ colData.set(j,replacement);
+ }
+ }
+ return ret;
+ }
}
diff --git a/src/test/java/org/apache/sysds/test/functions/frame/FrameReplaceTest.java b/src/test/java/org/apache/sysds/test/functions/frame/FrameReplaceTest.java
new file mode 100644
index 0000000..73868e3
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/frame/FrameReplaceTest.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.frame;
+
+import static org.junit.Assert.assertTrue;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+public class FrameReplaceTest extends AutomatedTestBase {
+ // private static final Log LOG = LogFactory.getLog(FrameReplaceTest.class.getName());
+ private final static String TEST_DIR = "functions/frame/";
+ private final static String TEST_NAME = "ReplaceTest";
+ private final static String TEST_CLASS_DIR = TEST_DIR + FrameReplaceTest.class.getSimpleName() + "/";
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+ }
+
+ @Test
+ public void testParforFrameIntermediatesCP() {
+ runReplaceTest(ExecType.CP);
+ }
+
+ @Test
+ public void testParforFrameIntermediatesSpark() {
+ runReplaceTest(ExecType.SPARK);
+ }
+
+ private void runReplaceTest(ExecType et) {
+ ExecMode platformOld = rtplatform;
+ switch(et) {
+ case SPARK:
+ rtplatform = ExecMode.SPARK;
+ break;
+ default:
+ rtplatform = ExecMode.HYBRID;
+ break;
+ }
+
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ if(rtplatform == ExecMode.SPARK || rtplatform == ExecMode.HYBRID)
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+ try {
+ // setup testcase
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {};
+
+ // run test
+ String out = runTest(null).toString();
+
+ assertTrue(out.contains("south"));
+ assertTrue(!out.contains("north"));
+
+ }
+ catch(Exception ex) {
+ throw new RuntimeException(ex);
+ }
+ finally {
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+ }
+
+}
diff --git a/src/test/scripts/functions/frame/ReplaceTest.dml b/src/test/scripts/functions/frame/ReplaceTest.dml
new file mode 100644
index 0000000..2a12b48
--- /dev/null
+++ b/src/test/scripts/functions/frame/ReplaceTest.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read("src/test/resources/datasets/homes/homes.csv")
+
+X = replace(target = X, pattern="north", replacement="south")
+X = replace(target = X, pattern="east", replacement="south")
+X = replace(target = X, pattern="west", replacement="south")
+
+print(toString(X))
\ No newline at end of file