Posted to commits@systemml.apache.org by ni...@apache.org on 2019/03/24 16:07:22 UTC

[systemml] branch master updated: [SYSTEMML-540] Added ternary aggregate operators for GPU backend

This is an automated email from the ASF dual-hosted git repository.

niketanpansare pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new 7fba4b2  [SYSTEMML-540] Added ternary aggregate operators for GPU backend
7fba4b2 is described below

commit 7fba4b29d653747a9ed038d282954a44fea3031c
Author: Niketan Pansare <np...@us.ibm.com>
AuthorDate: Sun Mar 24 09:06:55 2019 -0700

    [SYSTEMML-540] Added ternary aggregate operators for GPU backend
    
    - Also added steps to upload SystemML's Python package to PyPI.
---
 docs/release-process.md                            |  25 +++-
 .../java/org/apache/sysml/hops/AggUnaryOp.java     |  11 +-
 .../runtime/instructions/GPUInstructionParser.java |   7 ++
 .../gpu/AggregateTernaryGPUInstruction.java        | 130 +++++++++++++++++++++
 .../runtime/instructions/gpu/GPUInstruction.java   |   1 +
 .../sysml/runtime/matrix/data/LibMatrixCUDA.java   |  13 ++-
 .../sysml/test/gpu/AggregateTernaryTests.java      |  57 +++++++++
 .../sysml/test/gpu/AggregateUnaryOpTests.java      |   1 +
 .../apache/sysml/test/gpu/UnaryOpTestsBase.java    |  18 +++
 9 files changed, 250 insertions(+), 13 deletions(-)
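
The new opcodes implement ternary aggregates on the GPU: a full sum or a
column-wise sum over the element-wise product of three matrices. In DML terms
(adapted from the tests added by this commit), the two patterns are:

	out = sum(in1 * in2 * in3)      # compiles to tak+*
	out = colSums(in1 * in2 * in3)  # compiles to tack+*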

diff --git a/docs/release-process.md b/docs/release-process.md
index 2477cd0..c50a27e 100644
--- a/docs/release-process.md
+++ b/docs/release-process.md
@@ -388,7 +388,7 @@ file and remove all the `@Ignore` annotations from all the tests. Then run the N
 # Run other GPU Unit Tests 
 
 	rm result.txt
-	for t in AggregateUnaryOpTests  BinaryOpTests  MatrixMatrixElementWiseOpTests  RightIndexingTests AppendTest  MatrixMultiplicationOpTest ReorgOpTests ScalarMatrixElementwiseOpTests UnaryOpTests LstmTest LstmCPUTest
+	for t in AggregateUnaryOpTests AggregateTernaryTests  BinaryOpTests  MatrixMatrixElementWiseOpTests  RightIndexingTests AppendTest  MatrixMultiplicationOpTest ReorgOpTests ScalarMatrixElementwiseOpTests UnaryOpTests LstmTest LstmCPUTest
 	do
 		mvn -Dit.test="org.apache.sysml.test.gpu."$t verify -PgpuTests &> tmp.txt
 		SUCCESS=`grep "BUILD SUCCESS" tmp.txt`
@@ -503,8 +503,23 @@ The versioned project documentation is now deployed to the main website, and the
 
 ## Update Crawler configuration for the search indexing
 
-Create a PR or an issue to update the version number in the crawler configuration. 
-Please see the `start_urls` tag in the file [https://github.com/algolia/docsearch-configs/blob/master/configs/apache_systemml.json](https://github.com/algolia/docsearch-configs/blob/master/configs/apache_systemml.json).
-If the Algolia team provides us an updated `apiKey` or `indexName` credentials, then please update the corresponding entries in the file 
+- Create a PR or an issue to update the version number in the crawler configuration. Please see the `start_urls` tag in the file [https://github.com/algolia/docsearch-configs/blob/master/configs/apache_systemml.json](https://github.com/algolia/docsearch-configs/blob/master/configs/apache_systemml.json).
+- If the Algolia team provides updated `apiKey` or `indexName` credentials, please update the corresponding entries in the file 
 [https://github.com/apache/systemml/blob/master/docs/_layouts/global.html](https://github.com/apache/systemml/blob/master/docs/_layouts/global.html) 
-(see for `Algolia search section` in the previously mentioned HTML file).
\ No newline at end of file
+(look for the `Algolia search section` in the previously mentioned HTML file).
+
+## Upload Python package to PyPI
+
+Download the released `systemml-*-python.tar.gz` and its signature `systemml-*-python.tar.gz.asc`.
+
+	$ wget https://dist.apache.org/repos/dist/release/systemml/1.0.0/systemml-1.0.0-python.tar.gz
+	$ wget https://dist.apache.org/repos/dist/release/systemml/1.0.0/systemml-1.0.0-python.tar.gz.asc
+	
+Rename the files to remove the `-python` suffix.
+
+	$ mv systemml-1.0.0-python.tar.gz systemml-1.0.0.tar.gz
+	$ mv systemml-1.0.0-python.tar.gz.asc systemml-1.0.0.tar.gz.asc
+
+Upload the Python package to PyPI using [twine](https://pypi.org/project/twine/).
+
+	$ twine upload -u systemml systemml-1.0.0.tar.gz systemml-1.0.0.tar.gz.asc 
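+
+As an optional sanity check (this assumes the package is published on PyPI under the name `systemml`), install the uploaded release into a fresh environment and import it:
+
+	$ pip install systemml==1.0.0
+	$ python -c "import systemml"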
\ No newline at end of file
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index 48d18b7..92ec22c 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -93,9 +93,12 @@ public class AggUnaryOp extends MultiThreadedHop
 			return false;
 		
 		try {
-			if( isTernaryAggregateRewriteApplicable() || isUnaryAggregateOuterCPRewriteApplicable() ) {
+			if(isUnaryAggregateOuterCPRewriteApplicable()) {
 				return false;
 			}
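+			// ternary aggregates are now supported on the GPU backend
+			// (see AggregateTernaryGPUInstruction), so allow GPU execution here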
+			else if(isTernaryAggregateRewriteApplicable()) {
+				return true;
+			}
 			else if ((_op == AggOp.SUM    && (_direction == Direction.RowCol || _direction == Direction.Row || _direction == Direction.Col))
 					 || (_op == AggOp.SUM_SQ && (_direction == Direction.RowCol || _direction == Direction.Row || _direction == Direction.Col))
 					 || (_op == AggOp.MAX    && (_direction == Direction.RowCol || _direction == Direction.Row || _direction == Direction.Col))
@@ -498,10 +501,6 @@ public class AggUnaryOp extends MultiThreadedHop
 	{
 		boolean ret = false;
 		
-		// TODO: Disable ternary aggregate rewrite on GPU backend.
-		if(!ConfigurationManager.isGPU())
-			return false;
-		
 		//currently we support only sum over binary multiply but potentially 
 		//it can be generalized to any RC aggregate over two common binary operations
 		if( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES && _op == AggOp.SUM &&
@@ -713,8 +712,6 @@ public class AggUnaryOp extends MultiThreadedHop
 		// The execution type of a unary aggregate instruction should depend on the execution type of inputs to avoid OOM
 		// Since we only support matrix-vector and not vector-matrix, checking the execution type of input1 should suffice.
 		ExecType et_input = input1.optFindExecType();
-		// Because ternary aggregate are not supported on GPU
-		et_input = et_input == ExecType.GPU ? ExecType.CP :  et_input;
 		DirectionTypes dir = HopsDirection2Lops.get(_direction);
 		
 		return new TernaryAggregate(in1, in2, in3, Aggregate.OperationTypes.KahanSum, 
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
index 20058de..aabb36f 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
@@ -23,6 +23,7 @@ import java.util.HashMap;
 import org.apache.sysml.lops.RightIndex;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.instructions.gpu.AggregateBinaryGPUInstruction;
+import org.apache.sysml.runtime.instructions.gpu.AggregateTernaryGPUInstruction;
 import org.apache.sysml.runtime.instructions.gpu.ArithmeticBinaryGPUInstruction;
 import org.apache.sysml.runtime.instructions.gpu.BuiltinBinaryGPUInstruction;
 import org.apache.sysml.runtime.instructions.gpu.BuiltinUnaryGPUInstruction;
@@ -43,6 +44,9 @@ public class GPUInstructionParser  extends InstructionParser
 	static final HashMap<String, GPUINSTRUCTION_TYPE> String2GPUInstructionType;
 	static {
 		String2GPUInstructionType = new HashMap<>();
+		
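+		// Ternary Aggregate Operators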
+		String2GPUInstructionType.put( "tak+*"   , GPUINSTRUCTION_TYPE.AggregateTernary);
+		String2GPUInstructionType.put( "tack+*"  , GPUINSTRUCTION_TYPE.AggregateTernary);
 
 		// Neural Network Operators
 		String2GPUInstructionType.put( "relu_backward",          GPUINSTRUCTION_TYPE.Dnn);
@@ -179,6 +183,9 @@ public class GPUInstructionParser  extends InstructionParser
 		switch(gputype) {
 			case AggregateUnary:
 				return AggregateUnaryGPUInstruction.parseInstruction(str);
+				
+			case AggregateTernary:
+				return AggregateTernaryGPUInstruction.parseInstruction(str);
 
 			case AggregateBinary:
 				return AggregateBinaryGPUInstruction.parseInstruction(str);
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/AggregateTernaryGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/AggregateTernaryGPUInstruction.java
new file mode 100644
index 0000000..53eab47
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/AggregateTernaryGPUInstruction.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.instructions.gpu;
+
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.functionobjects.Multiply;
+import org.apache.sysml.runtime.instructions.InstructionUtils;
+import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.instructions.cp.DoubleObject;
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
+import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
+import org.apache.sysml.runtime.matrix.operators.AggregateTernaryOperator;
+import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysml.runtime.matrix.operators.Operator;
+import org.apache.sysml.utils.GPUStatistics;
+
+import jcuda.Pointer;
+
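+/**
+ * GPU instruction for the ternary aggregate opcodes tak+* (sum(A*B*C)) and
+ * tack+* (colSums(A*B*C)). It materializes the element-wise product of its
+ * inputs in a temporary device buffer and then reduces that buffer fully
+ * (scalar output) or column-wise (row-vector output).
+ */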
+public class AggregateTernaryGPUInstruction extends GPUInstruction {
+
+	private CPOperand _input1 = null;
+	private CPOperand _input2 = null;
+	private CPOperand _input3 = null;
+	private CPOperand _output = null;
+	
+	private AggregateTernaryGPUInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out,
+			String opcode, String istr) {
+		super(op, opcode, istr);
+		_gputype = GPUINSTRUCTION_TYPE.AggregateTernary;
+		_input1 = in1;
+		_input2 = in2;
+		_input3 = in3;
+		_output = out;
+	}
+
+	public static AggregateTernaryGPUInstruction parseInstruction( String str ) {
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		String opcode = parts[0];
+		
+		if ( opcode.equalsIgnoreCase("tak+*") || opcode.equalsIgnoreCase("tack+*") ) {
+			InstructionUtils.checkNumFields( parts, 4 );
+			
+			CPOperand in1 = new CPOperand(parts[1]);
+			CPOperand in2 = new CPOperand(parts[2]);
+			CPOperand in3 = new CPOperand(parts[3]);
+			CPOperand out = new CPOperand(parts[4]);
+			
+			AggregateTernaryOperator op = InstructionUtils.parseAggregateTernaryOperator(opcode, 1);
+			return new AggregateTernaryGPUInstruction(op, in1, in2, in3, out, opcode, str);
+		} 
+		else {
+			throw new DMLRuntimeException("AggregateTernaryGPUInstruction.parseInstruction():: Unknown opcode " + opcode);
+		}		
+	}
+	
+	@Override
+	public void processInstruction(ExecutionContext ec) {
+		GPUStatistics.incrementNoOfExecutedGPUInst();
+		GPUContext gCtx = ec.getGPUContext(0);
+		String instName = getExtendedOpcode();
+		AggregateTernaryOperator ab_op = (AggregateTernaryOperator) _optr;
+		MatrixObject in1 = getMatrixInputForGPUInstruction(ec, _input1.getName());
+		MatrixObject in2 = getMatrixInputForGPUInstruction(ec, _input2.getName());
+		
+		BinaryOperator bop = new BinaryOperator(Multiply.getMultiplyFnObject());
+		
+		int rlenA = LibMatrixCUDA.toInt(in1.getNumRows());
+		int rlenB = LibMatrixCUDA.toInt(in2.getNumRows());
+		int clenA = LibMatrixCUDA.toInt(in1.getNumColumns());
+		int clenB = LibMatrixCUDA.toInt(in2.getNumColumns());
+		int rlenOut = Math.max(rlenA, rlenB);
+		int clenOut = Math.max(clenA, clenB);
+		int sizeOfOutput =  rlenOut*clenOut;
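+		// temporary device buffer for the element-wise product;
+		// freed via cudaFreeHelper at the end of this method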
+		Pointer out = gCtx.allocate(instName, sizeOfOutput*LibMatrixCUDA.sizeOfDataType);
+	
+		// out = in1 * in2
+		Pointer A = LibMatrixCUDA.getDensePointer(gCtx, in1, instName); 
+		Pointer B = LibMatrixCUDA.getDensePointer(gCtx, in2, instName);
+		LibMatrixCUDA.denseMatrixMatrixOp(gCtx, instName, A, B, rlenA, clenA, rlenB, clenB, out, bop);
+		ec.releaseMatrixInputForGPUInstruction(_input1.getName());
+		ec.releaseMatrixInputForGPUInstruction(_input2.getName());
+		
+		if(!_input3.isLiteral()) {
+			// out = out * in3
+			MatrixObject in3 = getMatrixInputForGPUInstruction(ec, _input3.getName());
+			rlenB = LibMatrixCUDA.toInt(in3.getNumRows());
+			clenB = LibMatrixCUDA.toInt(in3.getNumColumns());
+			if(rlenB*clenB > sizeOfOutput) {
+				throw new DMLRuntimeException("Matrix-vector AggregateTernaryGPUInstruction is not supported.");
+			}
+			B = LibMatrixCUDA.getDensePointer(gCtx, in3, instName);
+			// note: out has dimensions rlenOut x clenOut, which may exceed in1's dims
+			LibMatrixCUDA.denseMatrixMatrixOp(gCtx, instName, out, B, rlenOut, clenOut, rlenB, clenB, out, bop);
+			ec.releaseMatrixInputForGPUInstruction(_input3.getName());
+		}
+		
+		if( _output.getDataType().isScalar() ) {
+			// sum( in1*in2*in3 )
+			double result = LibMatrixCUDA.reduceAll(gCtx, instName, "reduce_sum", out, sizeOfOutput);
+			ec.setScalarOutput(_output.getName(), new DoubleObject(result));
+		}
+		else {
+			// colSum( in1*in2*in3 )
+			Pointer out1 = LibMatrixCUDA.getDensePointer(gCtx, 
+					LibMatrixCUDA.getDenseMatrixOutputForGPUInstruction(ec, instName, _output.getName(), 1, clenOut), instName);
+			LibMatrixCUDA.reduceCol(gCtx, instName, "reduce_col_sum", out, out1, rlenOut, clenOut);
+			ec.releaseMatrixOutputForGPUInstruction(_output.getName());
+		}
+		
+		gCtx.cudaFreeHelper(instName, out, gCtx.EAGER_CUDA_FREE);
+	}
+}
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java
index 7f3b017..8b703e6 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java
@@ -36,6 +36,7 @@ import org.apache.sysml.utils.Statistics;
 public abstract class GPUInstruction extends Instruction {
 	public enum GPUINSTRUCTION_TYPE {
 		AggregateUnary,
+		AggregateTernary,
 		AggregateBinary,
 		RelationalBinary,
 		Dnn,
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index fd06578..657143a 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -984,7 +984,7 @@ public class LibMatrixCUDA {
 	 * @param n								size of array
 	 * @return	the reduced value
 	 */
-	private static double reduceAll(GPUContext gCtx, String instName, String kernelFunction, Pointer in, int n) {
+	public static double reduceAll(GPUContext gCtx, String instName, String kernelFunction, Pointer in, int n) {
 		if(LOG.isTraceEnabled()) {
 			LOG.trace("GPU : reduceAll for " + kernelFunction + ", GPUContext=" + gCtx);
 		}
@@ -1530,6 +1530,17 @@ public class LibMatrixCUDA {
 				a, b, c, maxRlen, maxClen, vecStatusA, vecStatusB, getBinaryOp(op.fn));
 		if (ConfigurationManager.isFinegrainedStatistics()) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_MATRIX_MATRIX_CELLWISE_OP_KERNEL, System.nanoTime() - t0);
 	}
+	
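+	/**
+	 * Computes C = A op B cell-wise for dense inputs, broadcasting row or
+	 * column vectors to the dimensions of the larger operand.
+	 */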
+	public static void denseMatrixMatrixOp(GPUContext gCtx, String instName, 
+			Pointer A, Pointer B,
+			int rlenA, int clenA, int rlenB, int clenB, 
+			Pointer C, BinaryOperator op) {
+		int vecStatusA = LibMatrixCUDA.getVectorStatus(rlenA, clenA).code();
+		int vecStatusB = LibMatrixCUDA.getVectorStatus(rlenB, clenB).code();
+		int maxRlen = Math.max(rlenA, rlenB);
+		int maxClen = Math.max(clenA, clenB);
+		matrixMatrixOp(gCtx, instName, A, B, maxRlen, maxClen, vecStatusA, vecStatusB, C, op);
+	}
 
 	/**
 	 * This enum declares the different vector shapes
diff --git a/src/test/java/org/apache/sysml/test/gpu/AggregateTernaryTests.java b/src/test/java/org/apache/sysml/test/gpu/AggregateTernaryTests.java
new file mode 100644
index 0000000..578eb26
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/gpu/AggregateTernaryTests.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.gpu;
+
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Test;
+
+/**
+ * Tests Ternary Aggregate ops
+ */
+public class AggregateTernaryTests extends UnaryOpTestsBase {
+
+	private final static String TEST_NAME = "AggregateTernaryTests";
+
+	@Override
+	public void setUp() {
+		super.setUp();
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_DIR, TEST_NAME);
+		getAndLoadTestConfiguration(TEST_NAME);
+	}
+	
+	@Test
+	public void ternaryAgg1() {
+		testTernaryUnaryOpMatrixOutput("out = sum(in1*in2*in3)", "gpu_tak+*", "in1", "in2", "in3",  "out", 30, 40, 0.9);
+	}
+	
+	@Test
+	public void ternaryAgg2() {
+		testTernaryUnaryOpMatrixOutput("out = colSums(in1*in2*in3)", "gpu_tack+*", "in1", "in2", "in3",  "out", 30, 40, 0.9);
+	}
+	
+	@Test
+	public void ternaryAgg3() {
+		testTernaryUnaryOpMatrixOutput("out = sum(in1*in2*in3)", "gpu_tak+*", "in1", "in2", "in3",  "out", 30, 40, 0.2);
+	}
+	
+	@Test
+	public void ternaryAgg4() {
+		testTernaryUnaryOpMatrixOutput("out = colSums(in1*in2*in3)", "gpu_tack+*", "in1", "in2", "in3",  "out", 30, 40, 0.2);
+	}
+}
diff --git a/src/test/java/org/apache/sysml/test/gpu/AggregateUnaryOpTests.java b/src/test/java/org/apache/sysml/test/gpu/AggregateUnaryOpTests.java
index 78a7c1b..ee6af94 100644
--- a/src/test/java/org/apache/sysml/test/gpu/AggregateUnaryOpTests.java
+++ b/src/test/java/org/apache/sysml/test/gpu/AggregateUnaryOpTests.java
@@ -162,4 +162,5 @@ public class AggregateUnaryOpTests extends UnaryOpTestsBase {
 	public void colSumsqs() {
 		testUnaryOpMatrixOutput("out = colSums(in1*in1)", "gpu_uacsqk+", "in1", "out");
 	}
+	
 }
diff --git a/src/test/java/org/apache/sysml/test/gpu/UnaryOpTestsBase.java b/src/test/java/org/apache/sysml/test/gpu/UnaryOpTestsBase.java
index 0f6b59c..1726ca7 100644
--- a/src/test/java/org/apache/sysml/test/gpu/UnaryOpTestsBase.java
+++ b/src/test/java/org/apache/sysml/test/gpu/UnaryOpTestsBase.java
@@ -102,5 +102,23 @@ public abstract class UnaryOpTestsBase extends GPUTests {
 		//assertHeavyHitterPresent(heavyHitterOpCode);
 		assertEqualObjects(outCPU.get(0), outGPU.get(0));
 	}
+	
+	public void testTernaryUnaryOpMatrixOutput(String scriptStr, String heavyHitterOpCode, 
+			String inStr1, String inStr2, String inStr3,  
+			String outStr,
+			int row, int column, double sparsity) {
+		// distinct seeds keep the three inputs different; identical inputs
+		// would mask operand-mixup bugs in the instruction under test
+		int seed = 99;
+		Matrix in1 = generateInputMatrix(spark, row, column, sparsity, seed);
+		Matrix in2 = generateInputMatrix(spark, row, column, sparsity, seed + 1);
+		Matrix in3 = generateInputMatrix(spark, row, column, sparsity, seed + 2);
+		HashMap<String, Object> inputs = new HashMap<>();
+		inputs.put(inStr1, in1);
+		inputs.put(inStr2, in2);
+		inputs.put(inStr3, in3);
+		List<Object> outCPU = runOnCPU(spark, scriptStr, inputs, Arrays.asList(outStr));
+		List<Object> outGPU = runOnGPU(spark, scriptStr, inputs, Arrays.asList(outStr));
+		assertHeavyHitterPresent(heavyHitterOpCode);
+		assertEqualObjects(outCPU.get(0), outGPU.get(0));
+	}
 
 }