You are viewing a plain text version of this content; the canonical version is on the original mailing-list archive page.
Posted to commits@systemml.apache.org by na...@apache.org on 2017/07/05 18:34:13 UTC
systemml git commit: [SYSTEMML-1735] relational operators for GPU
Repository: systemml
Updated Branches:
refs/heads/master 978d4de47 -> a7364746a
[SYSTEMML-1735] relational operators for GPU
Closes #557
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/a7364746
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/a7364746
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/a7364746
Branch: refs/heads/master
Commit: a7364746a462069853421d59db1093ab145253c9
Parents: 978d4de
Author: Nakul Jindal <na...@gmail.com>
Authored: Wed Jul 5 11:33:41 2017 -0700
Committer: Nakul Jindal <na...@gmail.com>
Committed: Wed Jul 5 11:33:41 2017 -0700
----------------------------------------------------------------------
relational.dml | 6 +
.../java/org/apache/sysml/hops/BinaryOp.java | 8 +-
.../instructions/GPUInstructionParser.java | 13 +-
.../instructions/gpu/GPUInstruction.java | 34 ++--
.../MatrixMatrixArithmeticGPUInstruction.java | 2 +-
...rixMatrixRelationalBinaryGPUInstruction.java | 69 ++++++++
.../gpu/RelationalBinaryGPUInstruction.java | 68 +++++++
...larMatrixRelationalBinaryGPUInstruction.java | 61 +++++++
.../instructions/gpu/context/CSRPointer.java | 6 +-
.../instructions/gpu/context/GPUObject.java | 2 +-
.../runtime/matrix/data/LibMatrixCUDA.java | 177 ++++++++++++++-----
.../gpu/MatrixMatrixElementWiseOpTests.java | 32 +++-
.../gpu/ScalarMatrixElementwiseOpTests.java | 64 ++++++-
13 files changed, 477 insertions(+), 65 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/a7364746/relational.dml
----------------------------------------------------------------------
diff --git a/relational.dml b/relational.dml
new file mode 100644
index 0000000..3f492a1
--- /dev/null
+++ b/relational.dml
@@ -0,0 +1,6 @@
+# Exercises the elementwise relational operator >= on two random 10x10
+# matrices and prints the resulting matrix.
+# NOTE(review): this ad-hoc script is added at the repository root; consider
+# moving it next to the other test scripts or dropping it from the commit.
+A = rand(rows=10, cols=10)
+B = rand(rows=10, cols=10)
+
+C = A >= B
+
+print(toString(C))
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/a7364746/src/main/java/org/apache/sysml/hops/BinaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/BinaryOp.java b/src/main/java/org/apache/sysml/hops/BinaryOp.java
index 83209ef..36f573c 100644
--- a/src/main/java/org/apache/sysml/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/BinaryOp.java
@@ -582,7 +582,9 @@ public class BinaryOp extends Hop
if(DMLScript.USE_ACCELERATOR && (DMLScript.FORCE_ACCELERATOR || getMemEstimate() < GPUContextPool
.initialGPUMemBudget())
&& (op == OpOp2.MULT || op == OpOp2.PLUS || op == OpOp2.MINUS || op == OpOp2.DIV || op == OpOp2.POW
- || op == OpOp2.MINUS_NZ || op == OpOp2.MINUS1_MULT || op == OpOp2.MODULUS || op == OpOp2.INTDIV) ) {
+ || op == OpOp2.MINUS_NZ || op == OpOp2.MINUS1_MULT || op == OpOp2.MODULUS || op == OpOp2.INTDIV
+ || op == OpOp2.LESS || op == OpOp2.LESSEQUAL || op == OpOp2.EQUAL || op == OpOp2.NOTEQUAL
+ || op == OpOp2.GREATER || op == OpOp2.GREATEREQUAL)) {
et = ExecType.GPU;
}
Unary unary1 = new Unary(getInput().get(0).constructLops(),
@@ -602,7 +604,9 @@ public class BinaryOp extends Hop
if(DMLScript.USE_ACCELERATOR && (DMLScript.FORCE_ACCELERATOR || getMemEstimate() < GPUContextPool
.initialGPUMemBudget())
&& (op == OpOp2.MULT || op == OpOp2.PLUS || op == OpOp2.MINUS || op == OpOp2.DIV || op == OpOp2.POW
- || op == OpOp2.SOLVE || op == OpOp2.MINUS1_MULT || op == OpOp2.MODULUS || op == OpOp2.INTDIV)) {
+ || op == OpOp2.SOLVE || op == OpOp2.MINUS1_MULT || op == OpOp2.MODULUS || op == OpOp2.INTDIV
+ || op == OpOp2.LESS || op == OpOp2.LESSEQUAL || op == OpOp2.EQUAL || op == OpOp2.NOTEQUAL
+ || op == OpOp2.GREATER || op == OpOp2.GREATEREQUAL)) {
et = ExecType.GPU;
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/a7364746/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
index 5fd6fa0..17b1578 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java
@@ -30,6 +30,7 @@ import org.apache.sysml.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysml.runtime.instructions.gpu.MatrixMatrixAxpyGPUInstruction;
import org.apache.sysml.runtime.instructions.gpu.GPUInstruction.GPUINSTRUCTION_TYPE;
import org.apache.sysml.runtime.instructions.gpu.MMTSJGPUInstruction;
+import org.apache.sysml.runtime.instructions.gpu.RelationalBinaryGPUInstruction;
import org.apache.sysml.runtime.instructions.gpu.ReorgGPUInstruction;
import org.apache.sysml.runtime.instructions.gpu.AggregateUnaryGPUInstruction;
@@ -115,6 +116,14 @@ public class GPUInstructionParser extends InstructionParser
String2GPUInstructionType.put( "uavar" , GPUINSTRUCTION_TYPE.AggregateUnary); // Variance
String2GPUInstructionType.put( "uarvar" , GPUINSTRUCTION_TYPE.AggregateUnary); // Row Variance
String2GPUInstructionType.put( "uacvar" , GPUINSTRUCTION_TYPE.AggregateUnary); // Col Variance
+
+ // Relational Binary
+ String2GPUInstructionType.put( "==" , GPUINSTRUCTION_TYPE.RelationalBinary);
+ String2GPUInstructionType.put( "!=" , GPUINSTRUCTION_TYPE.RelationalBinary);
+ String2GPUInstructionType.put( "<" , GPUINSTRUCTION_TYPE.RelationalBinary);
+ String2GPUInstructionType.put( ">" , GPUINSTRUCTION_TYPE.RelationalBinary);
+ String2GPUInstructionType.put( "<=" , GPUINSTRUCTION_TYPE.RelationalBinary);
+ String2GPUInstructionType.put( ">=" , GPUINSTRUCTION_TYPE.RelationalBinary);
}
public static GPUInstruction parseSingleInstruction (String str )
@@ -168,7 +177,9 @@ public class GPUInstructionParser extends InstructionParser
return MatrixMatrixAxpyGPUInstruction.parseInstruction(str);
else
return ArithmeticBinaryGPUInstruction.parseInstruction(str);
-
+ case RelationalBinary:
+ return RelationalBinaryGPUInstruction.parseInstruction(str);
+
default:
throw new DMLRuntimeException("Invalid GPU Instruction Type: " + gputype );
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/a7364746/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java
index 48b7da6..7f981eb 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java
@@ -32,7 +32,18 @@ import org.apache.sysml.utils.Statistics;
public abstract class GPUInstruction extends Instruction
{
- public enum GPUINSTRUCTION_TYPE { AggregateUnary, AggregateBinary, Convolution, MMTSJ, Reorg, ArithmeticBinary, BuiltinUnary, BuiltinBinary, Builtin };
+ /**
+ * Categories of GPU instructions. GPUInstructionParser maps each opcode string
+ * to one of these values and dispatches instruction parsing on it
+ * (e.g. the relational opcodes "==", "!=", "<", ">", "<=", ">=" map to RelationalBinary).
+ */
+ public enum GPUINSTRUCTION_TYPE {
+ AggregateUnary,
+ AggregateBinary,
+ RelationalBinary,
+ Convolution,
+ MMTSJ,
+ Reorg,
+ ArithmeticBinary,
+ BuiltinUnary,
+ BuiltinBinary,
+ Builtin
+ };
// Memory/conversions
public final static String MISC_TIMER_HOST_TO_DEVICE = "H2D"; // time spent in bringing data to gpu (from host)
@@ -46,7 +57,8 @@ public abstract class GPUInstruction extends Instruction
public final static String MISC_TIMER_CUDA_FREE = "f"; // time spent in calling cudaFree
public final static String MISC_TIMER_ALLOCATE = "a"; // time spent to allocate memory on gpu
- public final static String MISC_TIMER_ALLOCATE_DENSE_OUTPUT = "ao"; // time spent to allocate dense output (recorded differently than MISC_TIMER_ALLOCATE)
+ public final static String MISC_TIMER_ALLOCATE_DENSE_OUTPUT = "ad"; // time spent to allocate dense output (recorded differently than MISC_TIMER_ALLOCATE)
+ public final static String MISC_TIMER_ALLOCATE_SPARSE_OUTPUT = "as"; // time spent to allocate sparse output (recorded differently than MISC_TIMER_ALLOCATE)
public final static String MISC_TIMER_SET_ZERO = "az"; // time spent to allocate
public final static String MISC_TIMER_REUSE = "r"; // time spent in reusing already allocated memory on GPU (mainly for the count)
@@ -114,27 +126,27 @@ public abstract class GPUInstruction extends Instruction
protected GPUINSTRUCTION_TYPE _gputype;
protected Operator _optr;
-
+
protected boolean _requiresLabelUpdate = false;
-
+
public GPUInstruction(String opcode, String istr) {
type = INSTRUCTION_TYPE.GPU;
instString = istr;
-
+
//prepare opcode and update requirement for repeated usage
instOpcode = opcode;
_requiresLabelUpdate = super.requiresLabelUpdate();
}
-
+
public GPUInstruction(Operator op, String opcode, String istr) {
this(opcode, istr);
_optr = op;
}
-
+
public GPUINSTRUCTION_TYPE getGPUInstructionType() {
return _gputype;
}
-
+
@Override
public boolean requiresLabelUpdate() {
return _requiresLabelUpdate;
@@ -147,11 +159,11 @@ public abstract class GPUInstruction extends Instruction
@Override
public Instruction preprocessInstruction(ExecutionContext ec)
- throws DMLRuntimeException
+ throws DMLRuntimeException
{
//default preprocess behavior (e.g., debug state)
Instruction tmp = super.preprocessInstruction(ec);
-
+
//instruction patching
if( tmp.requiresLabelUpdate() ) { //update labels only if required
//note: no exchange of updated instruction as labels might change in the general case
@@ -162,7 +174,7 @@ public abstract class GPUInstruction extends Instruction
return tmp;
}
- @Override
+ @Override
public abstract void processInstruction(ExecutionContext ec)
throws DMLRuntimeException;
http://git-wip-us.apache.org/repos/asf/systemml/blob/a7364746/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixMatrixArithmeticGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixMatrixArithmeticGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixMatrixArithmeticGPUInstruction.java
index a03f9b1..ef3333d 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixMatrixArithmeticGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixMatrixArithmeticGPUInstruction.java
@@ -71,7 +71,7 @@ public class MatrixMatrixArithmeticGPUInstruction extends ArithmeticBinaryGPUIns
ec.setMetaData(_output.getName(), (int)rlen, (int)clen);
BinaryOperator bop = (BinaryOperator) _optr;
- LibMatrixCUDA.matrixScalarArithmetic(ec, ec.getGPUContext(0), getExtendedOpcode(), in1, in2, _output.getName(), isLeftTransposed, isRightTransposed, bop);
+ LibMatrixCUDA.matrixMatrixArithmetic(ec, ec.getGPUContext(0), getExtendedOpcode(), in1, in2, _output.getName(), isLeftTransposed, isRightTransposed, bop);
ec.releaseMatrixInputForGPUInstruction(_input1.getName());
ec.releaseMatrixInputForGPUInstruction(_input2.getName());
http://git-wip-us.apache.org/repos/asf/systemml/blob/a7364746/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixMatrixRelationalBinaryGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixMatrixRelationalBinaryGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixMatrixRelationalBinaryGPUInstruction.java
new file mode 100644
index 0000000..a7e969f
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixMatrixRelationalBinaryGPUInstruction.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.instructions.gpu;
+
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
+import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysml.runtime.matrix.operators.Operator;
+import org.apache.sysml.utils.GPUStatistics;
+
+public class MatrixMatrixRelationalBinaryGPUInstruction extends RelationalBinaryGPUInstruction {
+
+	/**
+	 * Creates a GPU instruction that compares two matrix operands elementwise.
+	 *
+	 * @param op     binary operator describing the comparison
+	 * @param in1    left matrix operand
+	 * @param in2    right matrix operand
+	 * @param out    output matrix operand
+	 * @param opcode instruction opcode
+	 * @param istr   full instruction string
+	 */
+	public MatrixMatrixRelationalBinaryGPUInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out,
+			String opcode, String istr) {
+		super(op, in1, in2, out, opcode, istr);
+	}
+
+	@Override
+	public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
+		GPUStatistics.incrementNoOfExecutedGPUInst();
+
+		MatrixObject left = getMatrixInputForGPUInstruction(ec, _input1.getName());
+		MatrixObject right = getMatrixInputForGPUInstruction(ec, _input2.getName());
+
+		// The output takes the larger extent in each dimension: for equally shaped inputs this is
+		// simply their common shape; for outer/vector broadcasts ([100,1] vs [1,100], or
+		// [100,100] vs [100,1]) it is the broadcast shape.
+		long outRows = Math.max(left.getNumRows(), right.getNumRows());
+		long outCols = Math.max(left.getNumColumns(), right.getNumColumns());
+		ec.setMetaData(_output.getName(), (int) outRows, (int) outCols);
+
+		LibMatrixCUDA.matrixMatrixRelational(ec, ec.getGPUContext(0), getExtendedOpcode(), left, right,
+				_output.getName(), (BinaryOperator) _optr);
+
+		ec.releaseMatrixInputForGPUInstruction(_input1.getName());
+		ec.releaseMatrixInputForGPUInstruction(_input2.getName());
+		ec.releaseMatrixOutputForGPUInstruction(_output.getName());
+	}
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/a7364746/src/main/java/org/apache/sysml/runtime/instructions/gpu/RelationalBinaryGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/RelationalBinaryGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/RelationalBinaryGPUInstruction.java
new file mode 100644
index 0000000..8dedf0b
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/RelationalBinaryGPUInstruction.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.instructions.gpu;
+
+import org.apache.sysml.parser.Expression;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.instructions.InstructionUtils;
+import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.matrix.operators.Operator;
+
+public abstract class RelationalBinaryGPUInstruction extends GPUInstruction {
+
+	protected CPOperand _input1;
+	protected CPOperand _input2;
+	protected CPOperand _output;
+
+	/**
+	 * Creates a relational (comparison) GPU instruction.
+	 *
+	 * @param op     operator implementing the comparison (binary or scalar, depending on operand types)
+	 * @param in1    first operand
+	 * @param in2    second operand
+	 * @param out    output operand
+	 * @param opcode instruction opcode, e.g. {@code "=="} or {@code "<="}
+	 * @param istr   full instruction string
+	 */
+	public RelationalBinaryGPUInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
+		super(op, opcode, istr);
+		_gputype = GPUINSTRUCTION_TYPE.RelationalBinary;
+		_input1 = in1;
+		_input2 = in2;
+		_output = out;
+	}
+
+	/**
+	 * Parses a relational GPU instruction with three operands ({@code in1 in2 out}) and returns
+	 * the matrix-matrix or scalar-matrix variant that matches the operand data types.
+	 *
+	 * @param str instruction string
+	 * @return the parsed instruction
+	 * @throws DMLRuntimeException if the field count is wrong, the opcode cannot be parsed, or the
+	 *                             combination of operand data types is not supported on the GPU
+	 */
+	public static RelationalBinaryGPUInstruction parseInstruction ( String str ) throws DMLRuntimeException {
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		InstructionUtils.checkNumFields ( parts, 3 );
+
+		String opcode = parts[0];
+		CPOperand in1 = new CPOperand(parts[1]);
+		CPOperand in2 = new CPOperand(parts[2]);
+		CPOperand out = new CPOperand(parts[3]);
+
+		Expression.DataType dt1 = in1.getDataType();
+		Expression.DataType dt2 = in2.getDataType();
+		Expression.DataType dt3 = out.getDataType();
+
+		// Mixed operand types get a scalar operator; the flag tells it whether the first operand is the scalar.
+		Operator operator = (dt1 != dt2) ?
+				InstructionUtils.parseScalarBinaryOperator(opcode, (dt1 == Expression.DataType.SCALAR)) :
+				InstructionUtils.parseBinaryOperator(opcode);
+
+		boolean matrixOutput = dt3 == Expression.DataType.MATRIX;
+		boolean matrix1 = dt1 == Expression.DataType.MATRIX;
+		boolean matrix2 = dt2 == Expression.DataType.MATRIX;
+		boolean scalar1 = dt1 == Expression.DataType.SCALAR;
+		boolean scalar2 = dt2 == Expression.DataType.SCALAR;
+
+		if (matrixOutput && matrix1 && matrix2) {
+			return new MatrixMatrixRelationalBinaryGPUInstruction(operator, in1, in2, out, opcode, str);
+		}
+		if (matrixOutput && ((scalar1 && matrix2) || (matrix1 && scalar2))) {
+			return new ScalarMatrixRelationalBinaryGPUInstruction(operator, in1, in2, out, opcode, str);
+		}
+		// Report exactly what was seen so that unsupported plans are easy to diagnose.
+		throw new DMLRuntimeException("Unsupported GPU RelationalBinaryGPUInstruction: opcode=" + opcode
+				+ ", input data types=(" + dt1 + ", " + dt2 + "), output data type=" + dt3);
+	}
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/a7364746/src/main/java/org/apache/sysml/runtime/instructions/gpu/ScalarMatrixRelationalBinaryGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ScalarMatrixRelationalBinaryGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ScalarMatrixRelationalBinaryGPUInstruction.java
new file mode 100644
index 0000000..2a084b9
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ScalarMatrixRelationalBinaryGPUInstruction.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.instructions.gpu;
+
+import org.apache.sysml.parser.Expression;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.instructions.cp.ScalarObject;
+import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
+import org.apache.sysml.runtime.matrix.operators.Operator;
+import org.apache.sysml.runtime.matrix.operators.ScalarOperator;
+import org.apache.sysml.utils.GPUStatistics;
+
+public class ScalarMatrixRelationalBinaryGPUInstruction extends RelationalBinaryGPUInstruction {
+
+	/**
+	 * Creates a GPU instruction that compares a matrix operand against a scalar operand elementwise.
+	 *
+	 * @param op     scalar operator describing the comparison
+	 * @param in1    first operand (matrix or scalar)
+	 * @param in2    second operand (scalar or matrix)
+	 * @param out    output matrix operand
+	 * @param opcode instruction opcode
+	 * @param istr   full instruction string
+	 */
+	public ScalarMatrixRelationalBinaryGPUInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out,
+			String opcode, String istr) {
+		super(op, in1, in2, out, opcode, istr);
+	}
+
+	@Override
+	public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
+		GPUStatistics.incrementNoOfExecutedGPUInst();
+
+		// Work out which operand is the matrix; the other is treated as the scalar.
+		boolean matrixOnLeft = _input1.getDataType() == Expression.DataType.MATRIX;
+		CPOperand matOperand = matrixOnLeft ? _input1 : _input2;
+		CPOperand scalarOperand = matrixOnLeft ? _input2 : _input1;
+
+		MatrixObject in = getMatrixInputForGPUInstruction(ec, matOperand.getName());
+		ScalarObject constant = (ScalarObject) ec.getScalarInput(scalarOperand.getName(),
+				scalarOperand.getValueType(), scalarOperand.isLiteral());
+
+		// The output has exactly the shape of the matrix operand.
+		ec.setMetaData(_output.getName(), (int) in.getNumRows(), (int) in.getNumColumns());
+
+		// Bind the runtime scalar value into the operator before launching the kernel.
+		ScalarOperator scalarOp = (ScalarOperator) _optr;
+		scalarOp.setConstant(constant.getDoubleValue());
+		LibMatrixCUDA.matrixScalarRelational(ec, ec.getGPUContext(0), getExtendedOpcode(), in, _output.getName(), scalarOp);
+
+		ec.releaseMatrixInputForGPUInstruction(matOperand.getName());
+		ec.releaseMatrixOutputForGPUInstruction(_output.getName());
+	}
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/a7364746/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
index b15dd69..a4bff9a 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
@@ -275,10 +275,8 @@ public class CSRPointer {
* @throws DMLRuntimeException if DMLRuntimeException occurs
*/
public static CSRPointer allocateEmpty(GPUContext gCtx, long nnz2, long rows) throws DMLRuntimeException {
- LOG.trace(
- "GPU : allocateEmpty from CSRPointer with nnz=" + nnz2 + " and rows=" + rows + ", GPUContext=" + gCtx);
- assert nnz2
- > -1 : "Incorrect usage of internal API, number of non zeroes is less than 0 when trying to allocate sparse data on GPU";
+ LOG.trace("GPU : allocateEmpty from CSRPointer with nnz=" + nnz2 + " and rows=" + rows + ", GPUContext=" + gCtx);
+ assert nnz2 > -1 : "Incorrect usage of internal API, number of non zeroes is less than 0 when trying to allocate sparse data on GPU";
CSRPointer r = new CSRPointer(gCtx);
r.nnz = nnz2;
if (nnz2 == 0) {
http://git-wip-us.apache.org/repos/asf/systemml/blob/a7364746/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
index 366eee5..94ceb36 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
@@ -891,7 +891,7 @@ public class GPUObject {
"Block not in sparse format on host yet the device sparse matrix pointer is not null");
if (this.isSparseAndEmpty()) {
- MatrixBlock tmp = new MatrixBlock(); // Empty Block
+ MatrixBlock tmp = new MatrixBlock((int)mat.getNumRows(), (int)mat.getNumColumns(), 0l); // Empty Block
mat.acquireModify(tmp);
mat.release();
} else {
http://git-wip-us.apache.org/repos/asf/systemml/blob/a7364746/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index 7b6e9b7..6f28313 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -62,8 +62,6 @@ import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToDevice;
import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToHost;
import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyHostToDevice;
-import jcuda.jcusparse.cusparseAction;
-import jcuda.jcusparse.cusparseIndexBase;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.api.DMLScript;
@@ -137,7 +135,9 @@ import jcuda.jcudnn.cudnnStatus;
import jcuda.jcudnn.cudnnTensorDescriptor;
import jcuda.jcusolver.JCusolverDn;
import jcuda.jcusparse.JCusparse;
+import jcuda.jcusparse.cusparseAction;
import jcuda.jcusparse.cusparseHandle;
+import jcuda.jcusparse.cusparseIndexBase;
/**
* All CUDA kernels and library calls are redirected through this class
@@ -2317,15 +2317,48 @@ public class LibMatrixCUDA {
//********************************************************************/
/**
- * Entry point to perform elementwise matrix-scalar operation specified by op
+ * Entry point to perform elementwise matrix-scalar relational operation specified by op
*
- * @param ec execution context
- * @param gCtx a valid {@link GPUContext}
- * @param instName the invoking instruction's name for record {@link Statistics}.
- * @param in input matrix
+ * @param ec execution context
+ * @param gCtx a valid {@link GPUContext}
+ * @param instName the invoking instruction's name for record {@link Statistics}.
+ * @param in input matrix
* @param outputName output matrix name
+ * @param op scalar operator
+ * @throws DMLRuntimeException if DMLRuntimeException occurs
+ */
+	public static void matrixScalarRelational(ExecutionContext ec, GPUContext gCtx, String instName, MatrixObject in, String outputName, ScalarOperator op) throws DMLRuntimeException {
+		if (ec.getGPUContext(0) != gCtx)
+			throw new DMLRuntimeException("GPU : Invalid internal state, the GPUContext set with the ExecutionContext is not the same used to run this LibMatrixCUDA function");
+		double constant = op.getConstant();
+		LOG.trace("GPU : matrixScalarRelational, scalar: " + constant + ", GPUContext=" + gCtx);
+
+		// An empty sparse input holds only zeros, so applying the operator to 0 yields the value
+		// of every output cell and no kernel launch is needed.
+		if (isSparseAndEmpty(gCtx, in)) {
+			setOutputToConstant(ec, gCtx, instName, op.executeScalar(0.0), outputName);
+			return;
+		}
+
+		// General case: run the dense scalar kernel into a freshly allocated dense output.
+		Pointer inPtr = getDensePointer(gCtx, in, instName);
+		MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, instName, outputName);
+		Pointer outPtr = getDensePointer(gCtx, out, instName);
+		matrixScalarOp(gCtx, instName, inPtr, constant, (int) in.getNumRows(), (int) in.getNumColumns(), outPtr, op);
+	}
+
+ /**
+ * Entry point to perform elementwise matrix-scalar arithmetic operation specified by op
+ *
+ * @param ec execution context
+ * @param gCtx a valid {@link GPUContext}
+ * @param instName the invoking instruction's name for record {@link Statistics}.
+ * @param in input matrix
+ * @param outputName output matrix name
* @param isInputTransposed true if input transposed
- * @param op scalar operator
+ * @param op scalar operator
* @throws DMLRuntimeException if DMLRuntimeException occurs
*/
public static void matrixScalarArithmetic(ExecutionContext ec, GPUContext gCtx, String instName, MatrixObject in, String outputName, boolean isInputTransposed, ScalarOperator op) throws DMLRuntimeException {
@@ -2342,6 +2375,7 @@ public class LibMatrixCUDA {
}
else if(op.fn instanceof Multiply || op.fn instanceof And) {
setOutputToConstant(ec, gCtx, instName, 0.0, outputName);
+
}
else if(op.fn instanceof Power) {
setOutputToConstant(ec, gCtx, instName, 1.0, outputName);
@@ -2393,8 +2427,44 @@ public class LibMatrixCUDA {
//}
}
+
+	/**
+	 * Performs the elementwise relational operation specified by op on two input matrices in1 and in2.
+	 * Sparse-and-empty inputs are treated as all-zero matrices so that trivial cases avoid the general kernel:
+	 * <ul>
+	 * <li>both inputs empty: every output cell is the result of comparing 0 with 0;</li>
+	 * <li>one input empty and the other neither a row nor a column vector: the empty side is replaced
+	 * by the scalar 0 and the scalar kernel runs over the other input;</li>
+	 * <li>otherwise (including an empty input paired with a broadcast vector) the general
+	 * matrix-matrix kernel is used, since the output may be larger than either input.</li>
+	 * </ul>
+	 *
+	 * @param ec         execution context
+	 * @param gCtx       a valid {@link GPUContext}
+	 * @param instName   the invoking instruction's name for record {@link Statistics}.
+	 * @param in1        input matrix 1
+	 * @param in2        input matrix 2
+	 * @param outputName output matrix name
+	 * @param op         binary operator (one of the relational functions)
+	 * @throws DMLRuntimeException if the GPU context is inconsistent or both inputs are empty and op is not a
+	 *                             supported relational function
+	 */
+	public static void matrixMatrixRelational(ExecutionContext ec, GPUContext gCtx, String instName, MatrixObject in1, MatrixObject in2,
+			String outputName, BinaryOperator op) throws DMLRuntimeException {
+
+		if (ec.getGPUContext(0) != gCtx)
+			throw new DMLRuntimeException("GPU : Invalid internal state, the GPUContext set with the ExecutionContext is not the same used to run this LibMatrixCUDA function");
+
+		boolean in1SparseAndEmpty = isSparseAndEmpty(gCtx, in1);
+		boolean in2SparseAndEmpty = isSparseAndEmpty(gCtx, in2);
+		// A row/column vector may be broadcast to a larger output; the scalar shortcut would only
+		// fill as many cells as the vector has, so vectors must go through the general kernel.
+		boolean in1IsVector = in1.getNumRows() == 1 || in1.getNumColumns() == 1;
+		boolean in2IsVector = in2.getNumRows() == 1 || in2.getNumColumns() == 1;
+
+		if (in1SparseAndEmpty && in2SparseAndEmpty) {
+			// 0 compared with 0: strict comparisons and != are false, the others are true.
+			if (op.fn instanceof LessThan || op.fn instanceof GreaterThan || op.fn instanceof NotEquals) {
+				setOutputToConstant(ec, gCtx, instName, 0.0, outputName);
+			} else if (op.fn instanceof LessThanEquals || op.fn instanceof GreaterThanEquals || op.fn instanceof Equals) {
+				setOutputToConstant(ec, gCtx, instName, 1.0, outputName);
+			} else {
+				// Fail loudly rather than leave the output matrix unallocated.
+				throw new DMLRuntimeException("GPU : Unsupported relational operator " + op.fn.getClass().getSimpleName());
+			}
+		} else if (in1SparseAndEmpty && !in2IsVector) {
+			// C = empty_in1 op in2  ==>  C = 0.0 op in2
+			matrixScalarRelational(ec, gCtx, instName, in2, outputName, new LeftScalarOperator(op.fn, 0.0));
+		} else if (in2SparseAndEmpty && !in1IsVector) {
+			// C = in1 op empty_in2  ==>  C = in1 op 0.0
+			matrixScalarRelational(ec, gCtx, instName, in1, outputName, new RightScalarOperator(op.fn, 0.0));
+		} else {
+			// NOTE(review): for an empty input paired with a vector this relies on matrixMatrixOp's
+			// general (broadcast-aware) path handling the empty operand — confirm against matrixMatrixOp.
+			matrixMatrixOp(ec, gCtx, instName, in1, in2, outputName, false, false, op);
+		}
+	}
+
/**
- * Performs elementwise operation specified by op of two input matrices in1 and in2
+ * Performs elementwise arithmetic operation specified by op of two input matrices in1 and in2
*
* @param ec execution context
* @param gCtx a valid {@link GPUContext}
@@ -2407,7 +2477,7 @@ public class LibMatrixCUDA {
* @param op binary operator
* @throws DMLRuntimeException if DMLRuntimeException occurs
*/
- public static void matrixScalarArithmetic(ExecutionContext ec, GPUContext gCtx, String instName, MatrixObject in1, MatrixObject in2,
+ public static void matrixMatrixArithmetic(ExecutionContext ec, GPUContext gCtx, String instName, MatrixObject in1, MatrixObject in2,
String outputName, boolean isLeftTransposed, boolean isRightTransposed, BinaryOperator op) throws DMLRuntimeException {
if (ec.getGPUContext(0) != gCtx)
throw new DMLRuntimeException("GPU : Invalid internal state, the GPUContext set with the ExecutionContext is not the same used to run this LibMatrixCUDA function");
@@ -2456,24 +2526,25 @@ public class LibMatrixCUDA {
int clenA = (int) in.getNumColumns();
Pointer A = getDensePointer(gCtx, in, instName); // TODO: FIXME: Implement sparse binCellSparseScalarOp kernel
double scalar = op.getConstant();
- MatrixObject out = ec.getMatrixObject(outputName);
- getDenseMatrixOutputForGPUInstruction(ec, instName, outputName); // Allocated the dense output matrix
+ // MatrixObject out = ec.getMatrixObject(outputName);
+ MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, instName, outputName); // Allocated the dense output matrix
Pointer C = getDensePointer(gCtx, out, instName);
matrixScalarOp(gCtx, instName, A, scalar, rlenA, clenA, C, op);
}
/**
- * Helper method to launch binary scalar-matrix arithmetic operations CUDA kernel.
- * This method is isolated to be taken advatage of from other operations
+ * Helper method to launch binary scalar-matrix arithmetic/relational operations CUDA kernel.
+ * This method is isolated to be taken advantage of from other operations
* as it accepts JCuda {@link Pointer} instances instead of {@link MatrixObject} instances.
- * @param gCtx a valid {@link GPUContext}
+ *
+ * @param gCtx a valid {@link GPUContext}
* @param instName the invoking instruction's name for record {@link Statistics}.
- * @param a the dense input matrix (allocated on GPU)
- * @param scalar the scalar value to do the op
- * @param rlenA row length of matrix a
- * @param clenA column lenght of matrix a
- * @param c the dense output matrix
- * @param op operation to perform
+ * @param a the dense input matrix (allocated on GPU)
+ * @param scalar the scalar value to do the op
+ * @param rlenA number of rows of matrix a
+ * @param clenA number of columns of matrix a
+ * @param c the dense output matrix
+ * @param op operation to perform
* @throws DMLRuntimeException throws runtime exception
*/
private static void matrixScalarOp(GPUContext gCtx, String instName, Pointer a, double scalar, int rlenA, int clenA, Pointer c, ScalarOperator op) throws DMLRuntimeException {
@@ -2490,15 +2561,16 @@ public class LibMatrixCUDA {
/**
* Utility to launch binary cellwise matrix-matrix operations CUDA kernel
- * @param gCtx a valid {@link GPUContext}
- * @param ec execution context
- * @param instName the invoking instruction's name for record {@link Statistics}.
- * @param in1 left input matrix
- * @param in2 right input matrix
- * @param outputName output variable name
- * @param isLeftTransposed true if left matrix is transposed
+ *
+ * @param gCtx a valid {@link GPUContext}
+ * @param ec execution context
+ * @param instName the invoking instruction's name for record {@link Statistics}.
+ * @param in1 left input matrix
+ * @param in2 right input matrix
+ * @param outputName output variable name
+ * @param isLeftTransposed true if left matrix is transposed
* @param isRightTransposed true if right matrix is transposed
- * @param op operator
+ * @param op operator
* @throws DMLRuntimeException if DMLRuntimeException occurs
*/
private static void matrixMatrixOp(ExecutionContext ec, GPUContext gCtx, String instName, MatrixObject in1, MatrixObject in2,
@@ -2679,19 +2751,21 @@ public class LibMatrixCUDA {
if (ec.getGPUContext(0) != gCtx)
throw new DMLRuntimeException("GPU : Invalid internal state, the GPUContext set with the ExecutionContext is not the same used to run this LibMatrixCUDA function");
if(constant == 0) {
- // TODO: Create sparse empty block instead
+ MatrixObject out = getSparseMatrixOutputForGPUInstruction(ec, 0, instName, outputName);
+ } else {
+ //MatrixObject out = ec.getMatrixObject(outputName);
+ MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, instName, outputName); // Allocated the dense output matrix
+ Pointer A = getDensePointer(gCtx, out, instName);
+ int rlen = (int) out.getNumRows();
+ int clen = (int) out.getNumColumns();
+ long t0 = 0;
+ if (GPUStatistics.DISPLAY_STATISTICS)
+ t0 = System.nanoTime();
+ int size = rlen * clen;
+ getCudaKernels(gCtx).launchKernel("fill", ExecutionConfig.getConfigForSimpleVectorOperations(size), A, constant, size);
+ if (GPUStatistics.DISPLAY_STATISTICS)
+ GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_FILL_KERNEL, System.nanoTime() - t0);
}
- MatrixObject out = ec.getMatrixObject(outputName);
- getDenseMatrixOutputForGPUInstruction(ec, instName, outputName); // Allocated the dense output matrix
- Pointer A = getDensePointer(gCtx, out, instName);
- int rlen = (int) out.getNumRows();
- int clen = (int) out.getNumColumns();
- long t0=0;
- if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
- int size = rlen * clen;
- getCudaKernels(gCtx).launchKernel("fill", ExecutionConfig.getConfigForSimpleVectorOperations(size),
- A, constant, size);
- if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_FILL_KERNEL, System.nanoTime() - t0);
}
/**
@@ -3374,4 +3448,25 @@ public class LibMatrixCUDA {
return mb.getKey();
}
+ /**
+ * Helper method to get the sparse output block (allocated on the GPU) for an instruction.
+ * Also records the allocation time into {@link Statistics} when statistics are enabled.
+ * @param ec active {@link ExecutionContext}
+ * @param nnz number of non zeroes in output matrix
+ * @param instName the invoking instruction's name for record {@link Statistics}.
+ * @param name name of the output matrix (that the {@link ExecutionContext} is aware of)
+ * @return the matrix object
+ * @throws DMLRuntimeException if an error occurs
+ */
+ private static MatrixObject getSparseMatrixOutputForGPUInstruction(ExecutionContext ec, long nnz, String instName, String name) throws DMLRuntimeException {
+ long t0 = GPUStatistics.DISPLAY_STATISTICS ? System.nanoTime() : 0;
+ Pair<MatrixObject, Boolean> mb = ec.getSparseMatrixOutputForGPUInstruction(name, nnz);
+ // timing is only recorded when mb's Boolean is true (presumably "a new block was allocated" - confirm)
+ if (mb.getValue() && GPUStatistics.DISPLAY_STATISTICS) {
+ GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_ALLOCATE_SPARSE_OUTPUT, System.nanoTime() - t0);
+ }
+ return mb.getKey();
+ }
+
+
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/a7364746/src/test/java/org/apache/sysml/test/gpu/MatrixMatrixElementWiseOpTests.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/MatrixMatrixElementWiseOpTests.java b/src/test/java/org/apache/sysml/test/gpu/MatrixMatrixElementWiseOpTests.java
index 744b2c2..490befa 100644
--- a/src/test/java/org/apache/sysml/test/gpu/MatrixMatrixElementWiseOpTests.java
+++ b/src/test/java/org/apache/sysml/test/gpu/MatrixMatrixElementWiseOpTests.java
@@ -34,8 +34,8 @@ import org.junit.Test;
public class MatrixMatrixElementWiseOpTests extends GPUTests {
private final static String TEST_NAME = "MatrixMatrixElementWiseOpTests";
- private final int[] rowSizes = new int[] { 1, 64, 130, 1024, 2049 };
- private final int[] columnSizes = new int[] { 1, 64, 130, 1024, 2049 };
+ private final int[] rowSizes = new int[] { 1, 64, 1024, 2049 }; // presumably the row counts exercised by the tests; 130 was dropped in this change
+ private final int[] columnSizes = new int[] { 1, 64, 1024, 2049 }; // presumably the column counts exercised by the tests; 130 was dropped in this change
private final double[] sparsities = new double[] { 0.0, 0.03, 0.3, 0.9 };
private final double[] scalars = new double[] { 0.0, 0.5, 2.0 };
private final int seed = 42;
@@ -171,7 +171,35 @@ public class MatrixMatrixElementWiseOpTests extends GPUTests {
runMatrixRowVectorTest("O = 1 - X * Y", "X", "Y", "O", "gpu_1-*");
}
+ @Test
+ public void testLessThan() { // matrix-matrix elementwise O = X < Y; "gpu_<" is presumably the expected GPU opcode
+ runMatrixMatrixElementwiseTest("O = X < Y", "X", "Y", "O", "gpu_<");
+ }
+
+ @Test
+ public void testLessThanEqual() { // matrix-matrix elementwise O = X <= Y; "gpu_<=" is presumably the expected GPU opcode
+ runMatrixMatrixElementwiseTest("O = X <= Y", "X", "Y", "O", "gpu_<=");
+ }
+
+ @Test
+ public void testGreaterThan() { // matrix-matrix elementwise O = X > Y; "gpu_>" is presumably the expected GPU opcode
+ runMatrixMatrixElementwiseTest("O = X > Y", "X", "Y", "O", "gpu_>");
+ }
+ @Test
+ public void testGreaterThanEqual() { // matrix-matrix elementwise O = X >= Y; "gpu_>=" is presumably the expected GPU opcode
+ runMatrixMatrixElementwiseTest("O = X >= Y", "X", "Y", "O", "gpu_>=");
+ }
+
+ @Test
+ public void testEqual() { // matrix-matrix elementwise O = X == Y; "gpu_==" is presumably the expected GPU opcode
+ runMatrixMatrixElementwiseTest("O = X == Y", "X", "Y", "O", "gpu_==");
+ }
+
+ @Test
+ public void NotEqual() { // matrix-matrix elementwise O = X != Y. NOTE(review): name lacks the "test" prefix of its siblings; kept to preserve the test id
+ runMatrixMatrixElementwiseTest("O = X != Y", "X", "Y", "O", "gpu_!=");
+ }
/**
* Runs a simple matrix-matrix elementwise op test
http://git-wip-us.apache.org/repos/asf/systemml/blob/a7364746/src/test/java/org/apache/sysml/test/gpu/ScalarMatrixElementwiseOpTests.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/ScalarMatrixElementwiseOpTests.java b/src/test/java/org/apache/sysml/test/gpu/ScalarMatrixElementwiseOpTests.java
index c58365a..7ceeb0f 100644
--- a/src/test/java/org/apache/sysml/test/gpu/ScalarMatrixElementwiseOpTests.java
+++ b/src/test/java/org/apache/sysml/test/gpu/ScalarMatrixElementwiseOpTests.java
@@ -35,8 +35,8 @@ public class ScalarMatrixElementwiseOpTests extends GPUTests {
private final static String TEST_NAME = "ScalarMatrixElementwiseOpTests";
- private final int[] rowSizes = new int[] { 1, 64, 130, 2049 };
- private final int[] columnSizes = new int[] { 1, 64, 130, 2049 };
+ private final int[] rowSizes = new int[] { 1, 64, 2049 }; // presumably the row counts exercised by the tests; 130 was dropped in this change
+ private final int[] columnSizes = new int[] { 1, 64, 2049 }; // presumably the column counts exercised by the tests; 130 was dropped in this change
private final double[] sparsities = new double[] { 0.0, 0.03, 0.3, 0.9 };
private final int seed = 42;
@@ -48,6 +48,66 @@ public class ScalarMatrixElementwiseOpTests extends GPUTests {
}
@Test
+ public void testLessThanRightScalar() { // O = X < scalar for scalar in {0.0, 20.0}; "gpu_<" is presumably the expected GPU opcode
+ runScalarMatrixElementWiseTests("O = X < scalar", "X", "scalar", "O", new double[] { 0.0, 20.0 }, "gpu_<");
+ }
+
+ @Test
+ public void testLessThanLeftScalar() { // O = scalar < X (scalar on the left) for scalar in {0.0, 20.0}
+ runScalarMatrixElementWiseTests("O = scalar < X", "X", "scalar", "O", new double[] { 0.0, 20.0 }, "gpu_<");
+ }
+
+ @Test
+ public void testLessThanEqualRightScalar() { // O = X <= scalar for scalar in {0.0, 20.0}
+ runScalarMatrixElementWiseTests("O = X <= scalar", "X", "scalar", "O", new double[] { 0.0, 20.0 }, "gpu_<=");
+ }
+
+ @Test
+ public void testLessThanEqualLeftScalar() { // O = scalar <= X (scalar on the left) for scalar in {0.0, 20.0}
+ runScalarMatrixElementWiseTests("O = scalar <= X", "X", "scalar", "O", new double[] { 0.0, 20.0 }, "gpu_<=");
+ }
+
+ @Test
+ public void testGreaterThanRightScalar() { // O = X > scalar for scalar in {0.0, 20.0}
+ runScalarMatrixElementWiseTests("O = X > scalar", "X", "scalar", "O", new double[] { 0.0, 20.0 }, "gpu_>");
+ }
+
+ @Test
+ public void testGreaterThanLeftScalar() { // O = scalar > X (scalar on the left) for scalar in {0.0, 20.0}
+ runScalarMatrixElementWiseTests("O = scalar > X", "X", "scalar", "O", new double[] { 0.0, 20.0 }, "gpu_>");
+ }
+
+ @Test
+ public void testGreaterThanEqualRightScalar() { // O = X >= scalar for scalar in {0.0, 20.0}
+ runScalarMatrixElementWiseTests("O = X >= scalar", "X", "scalar", "O", new double[] { 0.0, 20.0 }, "gpu_>=");
+ }
+
+ @Test
+ public void testGreaterThanEqualLeftScalar() { // O = scalar >= X (scalar on the left) for scalar in {0.0, 20.0}
+ runScalarMatrixElementWiseTests("O = scalar >= X", "X", "scalar", "O", new double[] { 0.0, 20.0 }, "gpu_>=");
+ }
+
+ @Test
+ public void testEqualRightScalar() { // O = X == scalar for scalar in {0.0, 20.0}
+ runScalarMatrixElementWiseTests("O = X == scalar", "X", "scalar", "O", new double[] { 0.0, 20.0 }, "gpu_==");
+ }
+
+ @Test
+ public void testEqualLeftScalar() { // O = scalar == X (scalar on the left) for scalar in {0.0, 20.0}
+ runScalarMatrixElementWiseTests("O = scalar == X", "X", "scalar", "O", new double[] { 0.0, 20.0 }, "gpu_==");
+ }
+
+ @Test
+ public void testNotEqualRightScalar() { // O = X != scalar for scalar in {0.0, 20.0}
+ runScalarMatrixElementWiseTests("O = X != scalar", "X", "scalar", "O", new double[] { 0.0, 20.0 }, "gpu_!=");
+ }
+
+ @Test
+ public void testNotEqualEqualLeftScalar() { // O = scalar != X (scalar on the left). NOTE(review): "Equal" is duplicated in the name (likely meant testNotEqualLeftScalar); kept to preserve the test id
+ runScalarMatrixElementWiseTests("O = scalar != X", "X", "scalar", "O", new double[] { 0.0, 20.0 }, "gpu_!=");
+ }
+
+ @Test // O = X + scalar for scalar in {0.0, 0.5, 20.0}
public void testPlusRightScalar() {
runScalarMatrixElementWiseTests("O = X + scalar", "X", "scalar", "O", new double[] { 0.0, 0.5, 20.0 }, "gpu_+");
}