You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2017/07/14 22:00:16 UTC

systemml git commit: [HOTFIX] Fix for recently updated validation code for convolution operations

Repository: systemml
Updated Branches:
  refs/heads/master ccac6dd37 -> 6778a63b0


[HOTFIX] Fix for recently updated validation code for convolution operations

- Tested with NNTest in a local environment.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/6778a63b
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/6778a63b
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/6778a63b

Branch: refs/heads/master
Commit: 6778a63b02fc1c644501bae67cd24e639ed3a623
Parents: ccac6dd
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Fri Jul 14 13:59:41 2017 -0800
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Fri Jul 14 14:59:41 2017 -0700

----------------------------------------------------------------------
 .../sysml/parser/BuiltinFunctionExpression.java | 94 ++++++++++++--------
 1 file changed, 57 insertions(+), 37 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/6778a63b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
index 58760bc..54281cc 100644
--- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
@@ -1124,45 +1124,65 @@ public class BuiltinFunctionExpression extends DataIdentifier
 			output.setDataType(DataType.MATRIX);
 			output.setValueType(ValueType.DOUBLE);
 			output.setBlockDimensions(input.getOutput().getRowsInBlock(), input.getOutput().getColumnsInBlock());
-			// stride1, stride2, padding1, padding2, numImg, numChannels, imgSize, imgSize, 
- 			// filter_shape1=1, filter_shape2=1, filterSize/poolSize1, filterSize/poolSize1
-			try {
-				int start = 2;
-				if(!(this.getOpCode() == BuiltinFunctionOp.MAX_POOL || this.getOpCode() == BuiltinFunctionOp.AVG_POOL)) {
-					start = 1;
+			
+			if(this.getOpCode() == BuiltinFunctionOp.MAX_POOL_BACKWARD) {
+				output.setDimensions(input.getOutput().getDim1(), input.getOutput().getDim2());
+			}
+			else {
+				// stride1, stride2, padding1, padding2, numImg, numChannels, imgSize, imgSize, 
+	 			// filter_shape1=1, filter_shape2=1, filterSize/poolSize1, filterSize/poolSize1
+				try {
+					int start = 2;
+					if(!(this.getOpCode() == BuiltinFunctionOp.MAX_POOL || this.getOpCode() == BuiltinFunctionOp.AVG_POOL)) {
+						start = 1;
+					}
+					long stride_h = (long) getDoubleValue(_args[start++]);
+					long stride_w = (long) getDoubleValue(_args[start++]);
+					long pad_h = (long) getDoubleValue(_args[start++]);
+					long pad_w = (long) getDoubleValue(_args[start++]); 
+					long N = (long) getDoubleValue(_args[start++]);
+					long C = (long) getDoubleValue(_args[start++]);
+					long H = (long) getDoubleValue(_args[start++]);
+					long W = (long) getDoubleValue(_args[start++]);
+					long K = -1;
+					if(!(this.getOpCode() == BuiltinFunctionOp.MAX_POOL || this.getOpCode() == BuiltinFunctionOp.AVG_POOL)) {
+						K = (long) getDoubleValue(_args[start]);
+					}
+					start++; start++; // Increment index for K and C
+					long R = (long) getDoubleValue(_args[start++]);
+					long S = (long) getDoubleValue(_args[start++]);
+					
+					if(this.getOpCode() == BuiltinFunctionOp.CONV2D_BACKWARD_FILTER) {
+						output.setDimensions(K, C*R*S);
+					}
+					else if(this.getOpCode() == BuiltinFunctionOp.CONV2D_BACKWARD_DATA) {
+						output.setDimensions(N, C*H*W);
+					}
+					else if(H > 0 && W > 0 && stride_h > 0 && stride_w > 0 && pad_h >= 0 && pad_w >= 0 && R > 0 && S > 0) {
+						long P = ConvolutionUtils.getP(H, R, stride_h, pad_h);
+						long Q = ConvolutionUtils.getQ(W, S, stride_w, pad_w);
+						
+						// Try to set both rows and columns
+						if(this.getOpCode() == BuiltinFunctionOp.CONV2D) 
+							output.setDimensions(N, K*P*Q);
+						else if(this.getOpCode() == BuiltinFunctionOp.MAX_POOL || this.getOpCode() == BuiltinFunctionOp.AVG_POOL)
+							output.setDimensions(N, C*P*Q);
+						else
+							throw new LanguageException("");
+					}
+					else {
+						// Since columns cannot be computed, set only rows
+						if(this.getOpCode() == BuiltinFunctionOp.CONV2D) 
+							output.setDimensions(input.getOutput().getDim1(), -1);
+						else if(this.getOpCode() == BuiltinFunctionOp.MAX_POOL || this.getOpCode() == BuiltinFunctionOp.AVG_POOL)
+							output.setDimensions(input.getOutput().getDim1(), -1);
+						else
+							throw new LanguageException("");
+					}
 				}
-				long stride_h = (long) getDoubleValue(_args[start++]);
-				long stride_w = (long) getDoubleValue(_args[start++]);
-				long pad_h = (long) getDoubleValue(_args[start++]);
-				long pad_w = (long) getDoubleValue(_args[start++]); 
-				long N = (long) getDoubleValue(_args[start++]);
-				long C = (long) getDoubleValue(_args[start++]);
-				long H = (long) getDoubleValue(_args[start++]);
-				long W = (long) getDoubleValue(_args[start++]);
-				long K = -1;
-				if(!(this.getOpCode() == BuiltinFunctionOp.MAX_POOL || this.getOpCode() == BuiltinFunctionOp.AVG_POOL)) {
-					K = (long) getDoubleValue(_args[start]);
+				catch(Exception e) {
+					output.setDimensions(-1, -1); // To make sure that output dimensions are not incorrect even if getDoubleValue doesnot return value
 				}
-				start++; start++; // Increment index for K and C
-				long R = (long) getDoubleValue(_args[start++]);
-				long S = (long) getDoubleValue(_args[start++]);
-				long P = ConvolutionUtils.getP(H, R, stride_h, pad_h);
-				long Q = ConvolutionUtils.getP(W, S, stride_w, pad_w);
-				if(this.getOpCode() == BuiltinFunctionOp.CONV2D) 
-					output.setDimensions(N, K*P*Q);
-				else if(this.getOpCode() == BuiltinFunctionOp.CONV2D_BACKWARD_FILTER)
-					output.setDimensions(K, C*R*S);
-				else if(this.getOpCode() == BuiltinFunctionOp.CONV2D_BACKWARD_DATA)
-					output.setDimensions(N, C*H*W);
-				else if(this.getOpCode() == BuiltinFunctionOp.MAX_POOL)
-					output.setDimensions(N, C*P*Q);
-				else if(this.getOpCode() == BuiltinFunctionOp.MAX_POOL_BACKWARD)
-					output.setDimensions(N, C*H*W);
-				else
-					throw new LanguageException("");
-			}
-			catch(Exception e) {
-				output.setDimensions(input.getOutput().getDim1(), -1); // To make sure that output dimensions are not incorrect
 			}
 			checkMatrixParam(input);
 			if(input2 != null)