You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/12/28 22:47:23 UTC

[systemds] branch main updated: [SYSTEMDS-3108] Fix size inference for block indexing expressions

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new a3be9c0  [SYSTEMDS-3108] Fix size inference for block indexing expressions
a3be9c0 is described below

commit a3be9c019a68bdb57f9062018fe4d108bfd75651
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Tue Dec 28 23:45:18 2021 +0100

    [SYSTEMDS-3108] Fix size inference for block indexing expressions
    
    This patch fixes an issue in the existing size inference for block
    indexing expressions such as the following (with nc being a constant):
    
    R = X[(nc * (i-1) + 1) : (nc * i), ];
    
    Previously, we specifically added this for Kmeans, but the detection
    logic expected i to be a transient read of a variable. Later rewrites
    (e.g., removal of branches and merging of basic blocks) then caused i
    to never be bound to a logical variable (leaving it a hop intermediate),
    thus requiring a workaround at script level.
    
    This patch generalizes the rewrite to (1) work for arbitrary
    variables (transient reads or intermediates) by comparing the hops
    directly (which depends on common subexpression elimination), and
    (2) handle upper-bound variations such as (nc * i) and (i * nc).
---
 scripts/builtin/kmeans.dml                         |  1 -
 .../java/org/apache/sysds/hops/IndexingOp.java     | 35 +++++++++++-----------
 2 files changed, 17 insertions(+), 19 deletions(-)

diff --git a/scripts/builtin/kmeans.dml b/scripts/builtin/kmeans.dml
index 45b74d1..1e7e9df 100644
--- a/scripts/builtin/kmeans.dml
+++ b/scripts/builtin/kmeans.dml
@@ -210,7 +210,6 @@ m_kmeans = function(Matrix[Double] X, Integer k = 10, Integer runs = 10, Integer
              + ";  Avg WCSS = " + avg_wcss + ";  Worst WCSS = " + worst_wcss);
 
     C = All_Centroids [(num_centroids * (best_index - 1) + 1) : (num_centroids * best_index), ];
-    while(FALSE){} # workaround to make ncol t(C) known
     D =  -2 * (X %*% t(C)) + t(rowSums (C ^ 2));
     P = (D <= rowMins (D));
     aggr_P = t(cumsum (t(P)));
diff --git a/src/main/java/org/apache/sysds/hops/IndexingOp.java b/src/main/java/org/apache/sysds/hops/IndexingOp.java
index 394a5bd..17c097e 100644
--- a/src/main/java/org/apache/sysds/hops/IndexingOp.java
+++ b/src/main/java/org/apache/sysds/hops/IndexingOp.java
@@ -262,36 +262,35 @@ public class IndexingOp extends Hop
 	{
 		boolean ret = false;
 		LiteralOp constant = null;
-		DataOp var = null;
+		Hop var = null;
 
 		//handle lower bound
 		if( lbound instanceof BinaryOp && ((BinaryOp)lbound).getOp()==OpOp2.PLUS
-			&& lbound.getInput().get(1) instanceof LiteralOp 
-			&& HopRewriteUtils.getDoubleValueSafe((LiteralOp)lbound.getInput().get(1))==1
-			&& lbound.getInput().get(0) instanceof BinaryOp)
+			&& lbound.getInput(1) instanceof LiteralOp 
+			&& HopRewriteUtils.getDoubleValueSafe((LiteralOp)lbound.getInput(1))==1
+			&& lbound.getInput(0) instanceof BinaryOp)
 		{
-			BinaryOp lmult = (BinaryOp)lbound.getInput().get(0);
-			if( lmult.getOp()==OpOp2.MULT && lmult.getInput().get(0) instanceof LiteralOp
-				&& lmult.getInput().get(1) instanceof BinaryOp )
+			BinaryOp lmult = (BinaryOp)lbound.getInput(0);
+			if( lmult.getOp()==OpOp2.MULT && lmult.getInput(0) instanceof LiteralOp
+				&& lmult.getInput(1) instanceof BinaryOp )
 			{
-				BinaryOp lminus = (BinaryOp)lmult.getInput().get(1);
-				if( lminus.getOp()==OpOp2.MINUS && lminus.getInput().get(1) instanceof LiteralOp
-					&& HopRewriteUtils.getDoubleValueSafe((LiteralOp)lminus.getInput().get(1))==1 
-					&& lminus.getInput().get(0) instanceof DataOp )
+				BinaryOp lminus = (BinaryOp)lmult.getInput(1);
+				if( lminus.getOp()==OpOp2.MINUS && lminus.getInput(1) instanceof LiteralOp
+					&& HopRewriteUtils.getDoubleValueSafe((LiteralOp)lminus.getInput(1))==1 )
 				{
-					constant = (LiteralOp)lmult.getInput().get(0);
-					var = (DataOp) lminus.getInput().get(0);
+					constant = (LiteralOp)lmult.getInput(0);
+					var = lminus.getInput(0); //any DataOp or intermediate hop
 				}
 			}
 		}
 		
-		//handle upper bound
+		//handle upper bound (general check for var depends on CSE)
 		if( var != null && constant != null && ubound instanceof BinaryOp 
-			&& ubound.getInput().get(0) instanceof LiteralOp
-			&& ubound.getInput().get(1) instanceof DataOp 
-			&& ubound.getInput().get(1).getName().equals(var.getName()) ) 
+			&& ((ubound.getInput(0) instanceof LiteralOp && ubound.getInput(1) == var)
+			  ||(ubound.getInput(1) instanceof LiteralOp && ubound.getInput(0) == var)) )
 		{
-			LiteralOp constant2 = (LiteralOp)ubound.getInput().get(0);
+			int constIndex = (ubound.getInput(1) == var) ? 0 : 1;
+			LiteralOp constant2 = (LiteralOp)ubound.getInput(constIndex);
 			ret = ( HopRewriteUtils.getDoubleValueSafe(constant) == 
 					HopRewriteUtils.getDoubleValueSafe(constant2) );
 		}