You are viewing a plain-text version of this content. The canonical link to the original was a hyperlink ("here") that is not preserved in this plain-text copy.
Posted to commits@systemds.apache.org by se...@apache.org on 2022/07/26 07:26:23 UTC

[systemds] branch main updated: [SYSTEMDS-3018] Federated Rewriting Fixes

This is an automated email from the ASF dual-hosted git repository.

sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 03fc10328a [SYSTEMDS-3018] Federated Rewriting Fixes
03fc10328a is described below

commit 03fc10328a18fe731d9d2089e25802518cb26d27
Author: sebwrede <sw...@know-center.at>
AuthorDate: Fri Jul 22 10:39:00 2022 +0200

    [SYSTEMDS-3018] Federated Rewriting Fixes
    
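    Change how repetition estimates are updated so that each hop's estimate
    is updated at most once, preventing infinite loops.
    Add the memo table size to the explain output, and register the "*2"
    opcode as a binary instruction in the federated instruction parser.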
    
    Closes #1669.
---
 src/main/java/org/apache/sysds/hops/Hop.java                       | 5 ++++-
 src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java      | 6 +++++-
 .../apache/sysds/runtime/instructions/FEDInstructionParser.java    | 1 +
 src/main/java/org/apache/sysds/utils/Explain.java                  | 7 +++++++
 4 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java
index 4d1dff8f22..3988a6b59f 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -94,6 +94,7 @@ public abstract class Hop implements ParseInfo {
 	protected FederatedOutput _federatedOutput = FederatedOutput.NONE;
 	protected FederatedCost _federatedCost = new FederatedCost();
 	protected double repetitions = 1;
+	protected boolean repetitionsUpdated = false;
 
 	/**
 	 * Field defining if prefetch should be activated for operation.
@@ -1556,8 +1557,10 @@ public abstract class Hop implements ParseInfo {
 	}
 
 	public void updateRepetitionEstimates(double repetitions){
-		if ( !federatedCostInitialized() ){
+		LOG.trace("Updating repetition estimates of " + this.getName() + " to " + repetitions);
+		if ( !federatedCostInitialized() && !repetitionsUpdated ){
 			this.repetitions = repetitions;
+			this.repetitionsUpdated = true;
 			for ( Hop input : getInput() )
 				input.updateRepetitionEstimates(repetitions);
 		}
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java b/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
index f84aecc5e8..3ecb0b29b9 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
@@ -161,10 +161,14 @@ public class MemoTable {
 			.orElseThrow(() -> new DMLRuntimeException("FType not found in memo"));
 	}
 
+	public int getSize(){
+		return hopRelMemo.size();
+	}
+
 	@Override
 	public String toString(){
 		StringBuilder sb = new StringBuilder();
-		sb.append("Federated MemoTable has ").append(hopRelMemo.size()).append(" entries with the following values:");
+		sb.append("Federated MemoTable has ").append(getSize()).append(" entries with the following values:");
 		sb.append("\n").append("{").append("\n");
 		for (Map.Entry<Long,List<HopRel>> hopEntry : hopRelMemo.entrySet()){
 			sb.append("  ").append(hopEntry.getKey()).append(":").append("\n");
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index 81d2983da1..f61e86e800 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -73,6 +73,7 @@ public class FEDInstructionParser extends InstructionParser
 		String2FEDInstructionType.put( "/" ,  FEDType.Binary );
 		String2FEDInstructionType.put( "1-*", FEDType.Binary); //special * case
 		String2FEDInstructionType.put( "^2" , FEDType.Binary); //special ^ case
+		String2FEDInstructionType.put( "*2" , FEDType.Binary); //special * case
 		String2FEDInstructionType.put( "max", FEDType.Binary );
 		String2FEDInstructionType.put( "==",  FEDType.Binary);
 		String2FEDInstructionType.put( "!=",  FEDType.Binary);
diff --git a/src/main/java/org/apache/sysds/utils/Explain.java b/src/main/java/org/apache/sysds/utils/Explain.java
index ded46c039a..75740c1c5a 100644
--- a/src/main/java/org/apache/sysds/utils/Explain.java
+++ b/src/main/java/org/apache/sysds/utils/Explain.java
@@ -125,6 +125,7 @@ public class Explain
 		return "# EXPLAIN ("+type.name()+"):\n"
 				+ Explain.explainMemoryBudget(counts)+"\n"
 				+ Explain.explainDegreeOfParallelism(counts)
+				+ Explain.explainMemoTableSize()
 				+ Explain.explain(prog, rtprog, type, counts);
 	}
 
@@ -185,6 +186,12 @@ public class Explain
 		return sb.toString();
 	}
 
+	private static String explainMemoTableSize(){
+		if ( MEMO_TABLE != null )
+			return "\n# Number of HOPs in Memo = " + MEMO_TABLE.getSize();
+		else return "";
+	}
+
 	public static String explain(DMLProgram prog, Program rtprog, ExplainType type) {
 		return explain(prog, rtprog, type, null);
 	}