You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2023/08/11 13:12:22 UTC
[systemds] branch main updated: [SYSTEMDS-3613] Fix missing size propagation on transformapply/decode
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 7115b3707a [SYSTEMDS-3613] Fix missing size propagation on transformapply/decode
7115b3707a is described below
commit 7115b3707a802026ee287ce82d666b1b756941b5
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Fri Aug 11 15:08:10 2023 +0200
[SYSTEMDS-3613] Fix missing size propagation on transformapply/decode
This patch fixes the missing size propagation for transformapply and
transformdecode. By parsing the transformspec and/or using the meta
data frame (of original number of columns) we now infer the rows/cols
unless there are encoders that change the number of columns. For
feature hashing we could also support it, but for the sake of simplicity
we currently don't do it.
---
.../apache/sysds/hops/ParameterizedBuiltinOp.java | 32 ++++++++++++++++++----
.../sysds/runtime/transform/meta/TfMetaUtils.java | 15 ++++++++++
2 files changed, 42 insertions(+), 5 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
index 4e70a3bf09..01883e2f5d 100644
--- a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
@@ -50,7 +50,10 @@ import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.instructions.cp.ParamservBuiltinCPInstruction;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
+import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
import org.apache.sysds.runtime.util.UtilFunctions;
+import org.apache.wink.json4j.JSONObject;
/**
@@ -840,16 +843,35 @@ public class ParameterizedBuiltinOp extends MultiThreadedHop {
}
case TRANSFORMDECODE: {
Hop target = getTargetHop();
+ Hop meta = getParameterHop("meta");
//rows remain unchanged for recoding and dummy coding
- setDim1( target.getDim1() );
- //cols remain unchanged only if no dummy coding
- //TODO parse json spec
+ setDim1(target.getDim1());
+ //cols remain unchanged only if no dummy coding, but meta aligned with input columns
+ setDim2(meta.getDim2());
break;
}
case TRANSFORMAPPLY: {
//rows remain unchanged only if no omitting
- //cols remain unchanged of no dummy coding
- //TODO parse json spec
+ //cols remain unchanged if no dummy coding, feature hashing, word embeddings
+ Hop target = getTargetHop();
+ Hop spec = getParameterHop("spec");
+ if( dimsKnown() ) {
+ //safe to update according to new input as previously parsed
+ setDim1(target.getDim1());
+ setDim2(target.getDim2());
+ }
+ else if( spec instanceof LiteralOp ) {
+ try {
+ JSONObject jspec = new JSONObject(((LiteralOp)spec).getStringValue());
+ if( TfMetaUtils.checkValidEncoders(jspec, TfMethod.RECODE, TfMethod.BIN, TfMethod.UDF) ) {
+ setDim1(target.getDim1());
+ setDim2(target.getDim2());
+ }
+ }
+ catch(Exception ex) {
+ throw new HopsException(ex);
+ }
+ }
break;
}
case TRANSFORMCOLMAP: {
diff --git a/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java b/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
index 5ae26b1c3a..99fbe92bf2 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
@@ -490,4 +490,19 @@ public class TfMetaUtils
throw new DMLRuntimeException("Transform specification includes an invalid encoder: "+key);
}
}
+
+ @SuppressWarnings("unchecked")
+ public static boolean checkValidEncoders(JSONObject jSpec, TfMethod... encoders) {
+ Set<String> validEncoders = new HashSet<>();
+ validEncoders.addAll(Arrays.asList("ids","K"));
+ for( TfMethod tf : encoders )
+ validEncoders.add(tf.toString());
+ Iterator<String> keys = jSpec.keys();
+ while( keys.hasNext() ) {
+ String key = keys.next();
+ if( !validEncoders.contains(key) )
+ return false;
+ }
+ return true;
+ }
}