You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this conversion.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/04/04 19:04:51 UTC

[2/3] incubator-systemml git commit: [SYSTEMML-1424] Extended codegen operations and cost model

[SYSTEMML-1424] Extended codegen operations and cost model

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/69d8b7c4
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/69d8b7c4
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/69d8b7c4

Branch: refs/heads/master
Commit: 69d8b7c4b53deb3a1d3e4eba99b8718366df1a86
Parents: 3547619
Author: Matthias Boehm <mb...@gmail.com>
Authored: Mon Apr 3 18:25:44 2017 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Mon Apr 3 18:25:44 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/codegen/cplan/CNodeBinary.java   | 54 ++++++++++++--------
 .../sysml/hops/codegen/cplan/CNodeUnary.java    |  6 +--
 .../template/PlanSelectionFuseCostBased.java    | 33 +++++++++---
 3 files changed, 60 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/69d8b7c4/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
index 5ec7231..b6b6ce5 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
@@ -35,9 +35,9 @@ public class CNodeBinary extends CNode
 		VECT_LESS_SCALAR, VECT_LESSEQUAL_SCALAR, VECT_GREATER_SCALAR, VECT_GREATEREQUAL_SCALAR,
 		MULT, DIV, PLUS, MINUS, MODULUS, INTDIV, 
 		LESS, LESSEQUAL, GREATER, GREATEREQUAL, EQUAL,NOTEQUAL,
-		MIN, MAX, AND, OR, LOG, POW,
-		MINUS1_MULT;
-
+		MIN, MAX, AND, OR, LOG, LOG_NZ, POW,
+		MINUS1_MULT, MINUS_NZ;
+		
 		public static boolean contains(String value) {
 			for( BinType bt : values()  )
 				if( bt.name().equals(value) )
@@ -85,41 +85,45 @@ public class CNodeBinary extends CNode
 				
 				/*Can be replaced by function objects*/
 				case MULT:
-					return "    double %TMP% = %IN1% * %IN2%;\n" ;
+					return "    double %TMP% = %IN1% * %IN2%;\n";
 				
 				case DIV:
-					return "    double %TMP% = %IN1% / %IN2%;\n" ;
+					return "    double %TMP% = %IN1% / %IN2%;\n";
 				case PLUS:
-					return "    double %TMP% = %IN1% + %IN2%;\n" ;
+					return "    double %TMP% = %IN1% + %IN2%;\n";
 				case MINUS:
-					return "    double %TMP% = %IN1% - %IN2%;\n" ;
+					return "    double %TMP% = %IN1% - %IN2%;\n";
 				case MODULUS:
-					return "    double %TMP% = LibSpoofPrimitives.mod(%IN1%, %IN2%);\n" ;
+					return "    double %TMP% = LibSpoofPrimitives.mod(%IN1%, %IN2%);\n";
 				case INTDIV: 
-					return "    double %TMP% = LibSpoofPrimitives.intDiv(%IN1%, %IN2%);\n" ;
+					return "    double %TMP% = LibSpoofPrimitives.intDiv(%IN1%, %IN2%);\n";
 				case LESS:
-					return "    double %TMP% = (%IN1% < %IN2%) ? 1 : 0;\n" ;
+					return "    double %TMP% = (%IN1% < %IN2%) ? 1 : 0;\n";
 				case LESSEQUAL:
-					return "    double %TMP% = (%IN1% <= %IN2%) ? 1 : 0;\n" ;
+					return "    double %TMP% = (%IN1% <= %IN2%) ? 1 : 0;\n";
 				case GREATER:
-					return "    double %TMP% = (%IN1% > %IN2%) ? 1 : 0;\n" ;
+					return "    double %TMP% = (%IN1% > %IN2%) ? 1 : 0;\n";
 				case GREATEREQUAL: 
-					return "    double %TMP% = (%IN1% >= %IN2%) ? 1 : 0;\n" ;
+					return "    double %TMP% = (%IN1% >= %IN2%) ? 1 : 0;\n";
 				case EQUAL:
-					return "    double %TMP% = (%IN1% == %IN2%) ? 1 : 0;\n" ;
+					return "    double %TMP% = (%IN1% == %IN2%) ? 1 : 0;\n";
 				case NOTEQUAL: 
-					return "    double %TMP% = (%IN1% != %IN2%) ? 1 : 0;\n" ;
+					return "    double %TMP% = (%IN1% != %IN2%) ? 1 : 0;\n";
 				
 				case MIN:
-					return "    double %TMP% = (%IN1% <= %IN2%) ? %IN1% : %IN2%;\n" ;
+					return "    double %TMP% = (%IN1% <= %IN2%) ? %IN1% : %IN2%;\n";
 				case MAX:
-					return "    double %TMP% = (%IN1% >= %IN2%) ? %IN1% : %IN2%;\n" ;
+					return "    double %TMP% = (%IN1% >= %IN2%) ? %IN1% : %IN2%;\n";
 				case LOG:
-					return "    double %TMP% = FastMath.log(%IN1%)/FastMath.log(%IN2%);\n" ;
+					return "    double %TMP% = FastMath.log(%IN1%)/FastMath.log(%IN2%);\n";
+				case LOG_NZ:
+					return "    double %TMP% = (%IN1% == 0) ? 0 : FastMath.log(%IN1%)/FastMath.log(%IN2%);\n";	
 				case POW:
-					return "    double %TMP% = Math.pow(%IN1%, %IN2%);\n" ;
+					return "    double %TMP% = Math.pow(%IN1%, %IN2%);\n";
 				case MINUS1_MULT:
-					return "    double %TMP% = 1 - %IN1% * %IN2%;\n" ;
+					return "    double %TMP% = 1 - %IN1% * %IN2%;\n";
+				case MINUS_NZ:
+					return "    double %TMP% = (%IN1% != 0) ? %IN1% - %IN2% : 0;\n";
 					
 				default: 
 					throw new RuntimeException("Invalid binary type: "+this.toString());
@@ -225,6 +229,7 @@ public class CNodeBinary extends CNode
 			case DIV: return "b(/)";
 			case PLUS: return "b(+)";
 			case MINUS: return "b(-)";
+			case POW: return "b(^)";
 			case MODULUS: return "b(%%)";
 			case INTDIV: return "b(%/%)";
 			case LESS: return "b(<)";
@@ -233,8 +238,11 @@ public class CNodeBinary extends CNode
 			case GREATEREQUAL: return "b(>=)";
 			case EQUAL: return "b(==)";
 			case NOTEQUAL: return "b(!=)";
+			case OR: return "b(|)";
+			case AND: return "b(&)";
 			case MINUS1_MULT: return "b(1-*)";
-			default: return "b("+_type.name()+")";
+			case MINUS_NZ: return "b(-nz)";
+			default: return "b("+_type.name().toLowerCase()+")";
 		}
 	}
 	
@@ -277,7 +285,8 @@ public class CNodeBinary extends CNode
 			case DIV: 
 			case PLUS: 
 			case MINUS: 
-			case MINUS1_MULT:	
+			case MINUS1_MULT:
+			case MINUS_NZ:
 			case MODULUS: 
 			case INTDIV: 	
 			//SCALAR Comparison
@@ -293,6 +302,7 @@ public class CNodeBinary extends CNode
 			case AND: 
 			case OR: 			
 			case LOG: 
+			case LOG_NZ:	
 			case POW: 
 				_rows = 0;
 				_cols = 0;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/69d8b7c4/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
index 75b2630..119dc8c 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
@@ -28,10 +28,10 @@ public class CNodeUnary extends CNode
 {
 	public enum UnaryType {
 		ROW_SUMS, LOOKUP_R, LOOKUP_RC, LOOKUP0, //codegen specific
-		EXP, POW2, MULT2, SQRT, LOG,
+		EXP, POW2, MULT2, SQRT, LOG, LOG_NZ,
 		ABS, ROUND, CEIL, FLOOR, SIGN, 
 		SIN, COS, TAN, ASIN, ACOS, ATAN,
-		SELP, SPROP, SIGMOID, LOG_NZ; 
+		SELP, SPROP, SIGMOID; 
 		
 		public static boolean contains(String value) {
 			for( UnaryType ut : values()  )
@@ -156,7 +156,7 @@ public class CNodeUnary extends CNode
 			case LOOKUP_R:	return "u(ixr)";
 			case LOOKUP_RC:	return "u(ixrc)";
 			case LOOKUP0:	return "u(ix0)";
-			default:		return "u("+_type.name()+")";
+			default:		return "u("+_type.name().toLowerCase()+")";
 		}
 	}
 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/69d8b7c4/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
index 653f43b..47717c2 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
@@ -453,11 +453,13 @@ public class PlanSelectionFuseCostBased extends PlanSelection
 				case CEIL:
 				case FLOOR:
 				case SIGN:
-				case SELP:   costs = 1; break; 
+				case SELP:    costs = 1; break; 
 				case SPROP:
-				case SQRT:   costs = 2; break;
-				case EXP:    costs = 18; break;
-				case LOG:    costs = 32; break;
+				case SQRT:    costs = 2; break;
+				case EXP:     costs = 18; break;
+				case SIGMOID: costs = 21; break;
+				case LOG:    
+				case LOG_NZ:  costs = 32; break;
 				case NCOL:
 				case NROW:
 				case PRINT:
@@ -466,6 +468,12 @@ public class PlanSelectionFuseCostBased extends PlanSelection
 				case CAST_AS_INT:
 				case CAST_AS_MATRIX:
 				case CAST_AS_SCALAR: costs = 1; break;
+				case SIN:     costs = 18; break;
+				case COS:     costs = 22; break;
+				case TAN:     costs = 42; break;
+				case ASIN:    costs = 93; break;
+				case ACOS:    costs = 103; break;
+				case ATAN:    costs = 40; break;
 				case CUMSUM:
 				case CUMMIN:
 				case CUMMAX:
@@ -480,6 +488,10 @@ public class PlanSelectionFuseCostBased extends PlanSelection
 				case MULT: 
 				case PLUS:
 				case MINUS:
+				case MIN:
+				case MAX: 
+				case AND:
+				case OR:
 				case EQUAL:
 				case NOTEQUAL:
 				case LESS:
@@ -487,11 +499,16 @@ public class PlanSelectionFuseCostBased extends PlanSelection
 				case GREATER:
 				case GREATEREQUAL: 
 				case CBIND:
-				case RBIND: costs = 1; break;
-				case DIV:   costs = 22; break;
-				case LOG:   costs = 32; break;
-				case POW:   costs = (HopRewriteUtils.isLiteralOfValue(
+				case RBIND:   costs = 1; break;
+				case INTDIV:  costs = 6; break;
+				case MODULUS: costs = 8; break;
+				case DIV:    costs = 22; break;
+				case LOG:
+				case LOG_NZ: costs = 32; break;
+				case POW:    costs = (HopRewriteUtils.isLiteralOfValue(
 						current.getInput().get(1), 2) ? 1 : 16); break;
+				case MINUS_NZ:
+				case MINUS1_MULT: costs = 2; break;
 				default:
 					throw new RuntimeException("Cost model not "
 						+ "implemented yet for: "+((BinaryOp)current).getOp());