You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2023/08/11 13:12:22 UTC

[systemds] branch main updated: [SYSTEMDS-3613] Fix missing size propagation on transformapply/decode

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 7115b3707a [SYSTEMDS-3613] Fix missing size propagation on transformapply/decode
7115b3707a is described below

commit 7115b3707a802026ee287ce82d666b1b756941b5
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Fri Aug 11 15:08:10 2023 +0200

    [SYSTEMDS-3613] Fix missing size propagation on transformapply/decode
    
    This patch fixes the missing size propagation for transformapply and
    transformdecode. By parsing the transform spec and/or using the
    metadata frame (which has the original number of columns), we now infer
    the rows/cols unless there are encoders that change the number of
    columns. Feature hashing could also be supported, but for the sake of
    simplicity it currently is not.
---
 .../apache/sysds/hops/ParameterizedBuiltinOp.java  | 32 ++++++++++++++++++----
 .../sysds/runtime/transform/meta/TfMetaUtils.java  | 15 ++++++++++
 2 files changed, 42 insertions(+), 5 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
index 4e70a3bf09..01883e2f5d 100644
--- a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
@@ -50,7 +50,10 @@ import org.apache.sysds.parser.Statement;
 import org.apache.sysds.runtime.instructions.cp.ParamservBuiltinCPInstruction;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
+import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
 import org.apache.sysds.runtime.util.UtilFunctions;
+import org.apache.wink.json4j.JSONObject;
 
 
 /**
@@ -840,16 +843,35 @@ public class ParameterizedBuiltinOp extends MultiThreadedHop {
 			}
 			case TRANSFORMDECODE: {
 				Hop target = getTargetHop();
+				Hop meta = getParameterHop("meta");
 				//rows remain unchanged for recoding and dummy coding
-				setDim1( target.getDim1() );
-				//cols remain unchanged only if no dummy coding
-				//TODO parse json spec
+				setDim1(target.getDim1());
+				//cols are taken from meta, which is aligned with the original (decoded) input columns
+				setDim2(meta.getDim2());
 				break;
 			}
 			case TRANSFORMAPPLY: {
 				//rows remain unchanged only if no omitting
-				//cols remain unchanged of no dummy coding 
-				//TODO parse json spec
+				//cols remain unchanged if no dummy coding, feature hashing, word embeddings
+				Hop target = getTargetHop();
+				Hop spec = getParameterHop("spec");
+				if( dimsKnown() ) {
+					//dims already known (spec previously parsed), so it is safe to refresh them from the new input
+					setDim1(target.getDim1());
+					setDim2(target.getDim2());
+				}
+				else if( spec instanceof LiteralOp ) {
+					try {
+						JSONObject jspec = new JSONObject(((LiteralOp)spec).getStringValue());
+						if( TfMetaUtils.checkValidEncoders(jspec, TfMethod.RECODE, TfMethod.BIN, TfMethod.UDF) ) {
+							setDim1(target.getDim1());
+							setDim2(target.getDim2());
+						}
+					}
+					catch(Exception ex) {
+						throw new HopsException(ex);
+					}
+				}
 				break;
 			}
 			case TRANSFORMCOLMAP: {
diff --git a/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java b/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
index 5ae26b1c3a..99fbe92bf2 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
@@ -490,4 +490,19 @@ public class TfMetaUtils
 				throw new DMLRuntimeException("Transform specification includes an invalid encoder: "+key);
 		}
 	}
+	
+	@SuppressWarnings("unchecked")
+	public static boolean checkValidEncoders(JSONObject jSpec, TfMethod... encoders) {
+		Set<String> validEncoders = new HashSet<>();
+		validEncoders.addAll(Arrays.asList("ids","K"));
+		for( TfMethod tf : encoders )
+			validEncoders.add(tf.toString());
+		Iterator<String> keys = jSpec.keys();
+		while( keys.hasNext() ) {
+			String key = keys.next();
+			if( !validEncoders.contains(key) )
+				return false;
+		}
+		return true;
+	}
 }