You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2018/05/10 19:30:31 UTC

[4/4] systemml git commit: [MINOR] Fix uaggouterchain compilation (output data types)

[MINOR] Fix uaggouterchain compilation (output data types)

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/6f2c885e
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/6f2c885e
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/6f2c885e

Branch: refs/heads/master
Commit: 6f2c885e8aad480349e039fcd0390feb341b3639
Parents: f9020a1
Author: Matthias Boehm <mb...@gmail.com>
Authored: Thu May 10 12:27:58 2018 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Thu May 10 12:28:57 2018 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/AggUnaryOp.java  |  8 +++---
 .../cp/UaggOuterChainCPInstruction.java         | 27 +++++---------------
 .../binary/matrix/UaggOuterChainTest.java       |  3 +--
 3 files changed, 12 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/6f2c885e/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index 136d2d6..d3e0570 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -191,14 +191,14 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
 					
 						if (getDataType() == DataType.SCALAR) {
 							UnaryCP unary1 = new UnaryCP(agg1, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR),
-									                    getDataType(), getValueType());
+								getDataType(), getValueType());
 							unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
 							setLineNumbers(unary1);
-							setLops(unary1);
+							agg1 = unary1;
 						}
 					
-					}				
-					else { //general case		
+					}
+					else { //general case
 						int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
 						agg1 = new PartialAggregate(input.constructLops(), 
 								HopsAgg2Lops.get(_op), HopsDirection2Lops.get(_direction), getDataType(),getValueType(), et, k);

http://git-wip-us.apache.org/repos/asf/systemml/blob/6f2c885e/src/main/java/org/apache/sysml/runtime/instructions/cp/UaggOuterChainCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/UaggOuterChainCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/UaggOuterChainCPInstruction.java
index 908e5bd..e6dd403 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/UaggOuterChainCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/UaggOuterChainCPInstruction.java
@@ -74,7 +74,7 @@ public class UaggOuterChainCPInstruction extends UnaryCPInstruction {
 		boolean rightCached = (_uaggOp.indexFn instanceof ReduceCol || _uaggOp.indexFn instanceof ReduceAll
 				|| !LibMatrixOuterAgg.isSupportedUaggOp(_uaggOp, _bOp));
 
-		MatrixBlock mbLeft = null, mbRight = null, mbOut = null;		
+		MatrixBlock mbLeft = null, mbRight = null, mbOut = null;
 		//get the main data input
 		if( rightCached ) { 
 			mbLeft = ec.getMatrixInput(input1.getName(), getExtendedOpcode());
@@ -94,26 +94,13 @@ public class UaggOuterChainCPInstruction extends UnaryCPInstruction {
 		if( _uaggOp.aggOp.correctionExists )
 			mbOut.dropLastRowsOrColumns(_uaggOp.aggOp.correctionLocation);
 		
-		String output_name = output.getName();
-		//final aggregation if required
-		if(_uaggOp.indexFn instanceof ReduceAll ) //RC AGG (output is scalar)
-		{
-			//create and set output scalar
-			ScalarObject ret = null;
-			switch( output.getValueType() ) {
-				case DOUBLE:  ret = new DoubleObject(mbOut.quickGetValue(0, 0)); break;
-				
-				default: 
-					throw new DMLRuntimeException("Invalid output value type: "+output.getValueType());
-			}
-			ec.setScalarOutput(output_name, ret);
+		if(_uaggOp.indexFn instanceof ReduceAll ) { //RC AGG (output is scalar)
+			ec.setMatrixOutput(output.getName(), new MatrixBlock(
+				mbOut.quickGetValue(0, 0)), getExtendedOpcode());
 		}
-		else //R/C AGG (output is rdd)
-		{	
-			//Additional memory requirement to convert from dense to sparse can be leveraged from released memory needed for input data above.
+		else { //R/C AGG (output is rdd)
 			mbOut.examSparsity();
-			ec.setMatrixOutput(output_name, mbOut, getExtendedOpcode());
+			ec.setMatrixOutput(output.getName(), mbOut, getExtendedOpcode());
 		}
-		
-	}		
+	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/6f2c885e/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/UaggOuterChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/UaggOuterChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/UaggOuterChainTest.java
index 04a00c9..e031b53 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/UaggOuterChainTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/UaggOuterChainTest.java
@@ -44,7 +44,6 @@ import org.apache.sysml.utils.Statistics;
  */
 public class UaggOuterChainTest extends AutomatedTestBase 
 {
-	
 	private final static String TEST_NAME1 = "UaggOuterChain";
 	private final static String TEST_DIR = "functions/binary/matrix/";
 	private final static String TEST_CLASS_DIR = TEST_DIR + UaggOuterChainTest.class.getSimpleName() + "/";
@@ -1318,7 +1317,7 @@ public class UaggOuterChainTest extends AutomatedTestBase
 			
 			loadTestConfiguration(config, TEST_CACHE_DIR);
 			
-			String HOME = SCRIPT_DIR + TEST_DIR;			
+			String HOME = SCRIPT_DIR + TEST_DIR;
 			fullDMLScriptName = HOME + TEST_NAME + suffix + strSumTypeSuffix + ".dml";
 			programArgs = new String[]{"-stats", "-explain","-args", 
 				input("A"), input("B"), output("C")};