Posted to commits@singa.apache.org by wa...@apache.org on 2020/03/02 09:02:27 UTC

[singa] branch dev updated: add gemm

This is an automated email from the ASF dual-hosted git repository.

wangwei pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/singa.git


The following commit(s) were added to refs/heads/dev by this push:
     new 1869bab  add gemm
     new 1fcfbfa  Merge pull request #614 from joddiy/gemm_v3
1869bab is described below

commit 1869babfab8c3bd8ad08e63b811f2f75e99d62d8
Author: joddiy <jo...@qq.com>
AuthorDate: Sat Feb 29 11:22:51 2020 +0800

    add gemm
---
 python/singa/autograd.py         | 108 ++++++++++++++++++++++++++
 python/singa/sonnx.py            |  99 ++++++++++++------------
 test/python/test_onnx.py         |  50 ++++++------
 test/python/test_onnx_backend.py | 161 +++++++++++++++++++++++++++++++++++++++
 test/python/test_operation.py    |  69 +++++++++++++++++
 5 files changed, 415 insertions(+), 72 deletions(-)

diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index d818ece..35611f6 100644
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -2833,3 +2833,111 @@ class Reciprocal(Operation):
 
 def reciprocal(x):
     return Reciprocal()(x)[0]
+
+
+class Gemm(Operation):
+
+    def __init__(self, alpha=1.0, beta=1.0, transA=0, transB=0):
+        """
+        init a General Matrix multiplication(Gemm) operator
+        Compute Y = alpha * A' * B' + beta * C, where input tensor A has shape (M, K) or (K, M), input tensor B has shape (K, N) or (N, K), input tensor C is broadcastable to shape (M, N), and output tensor Y has shape (M, N).
+        A' = transpose(A) if transA else A
+        B' = transpose(B) if transB else B
+        Args:alpha: 
+            float, Scalar multiplier for the product of input tensors A * B.
+        Args:beta: 
+            float, Scalar multiplier for input tensor C.
+        Args:transA: 
+            int, Whether A should be transposed
+        Args:transB: 
+            int, Whether B should be transposed
+        Returns: 
+            tensor, the output
+        """
+        super(Gemm, self).__init__()
+        self.alpha = alpha
+        self.beta = beta
+        self.transA = transA
+        self.transB = transB
+
+    def forward(self, A, B, C=None):
+        """
+        forward propogation of Gemm
+        Args:A: 
+            tensor, The shape of A should be (M, K) if transA is 0, or (K, M) if transA is non-zero.
+        Args:B: 
+            tensor, The shape of B should be (K, N) if transB is 0, or (N, K) if transB is non-zero.
+        Args:C: 
+            tensor(optional), Optional input tensor C. If not specified, the computation is done as if C is a scalar 0. The shape of C should be unidirectional broadcastable to (M, N).
+        Returns: 
+            tensor, the output
+        """
+        _A = singa.DefaultTranspose(A) if self.transA == 1 else A
+        _B = singa.DefaultTranspose(B) if self.transB == 1 else B
+        if training:
+            self.inputs = (_A, _B, C)
+        tmpM = singa.MultFloat(singa.Mult(_A, _B), self.alpha)
+        if C:
+            tmpM = singa.__add__(tmpM, singa.MultFloat(C, self.beta))
+        return tmpM
+
+    def backward(self, dy):
+        """
+        backward propogation of Gemm
+        Args:dy: 
+            tensor, The shape of A should be (M, K) if transA is 0, or (K, M) if transA is non-zero.
+        Returns: 
+            tensor, the gradient over A
+            tensor, the gradient over B
+            tensor(optional), the gradient over C
+        """
+        _A, _B, C = self.inputs
+        # y = alpha * A  * B  => da = alpha * dy * BT
+        # y = alpha * A  * BT => da = alpha * dy * B
+        # y = alpha * AT * B  => da = alpha * B * dyT = alpha * (dy * BT)T
+        # y = alpha * AT * BT => da = alpha * BT * dyT = alpha * (dy * B)T
+        da = singa.MultFloat(singa.Mult(dy, singa.DefaultTranspose(_B)),
+                             self.alpha)
+        if self.transA:
+            da = singa.DefaultTranspose(da)
+
+        # y = alpha * A  * B  => db = alpha * AT * dy
+        # y = alpha * AT * B  => db = alpha * A * dy
+        # y = alpha * A  * BT => db = alpha * dyT * A = alpha * (AT * dy)T
+        # y = alpha * AT * BT => db = alpha * dyT * AT = alpha * (A * dy)T
+        db = singa.MultFloat(singa.Mult(singa.DefaultTranspose(_A), dy),
+                             self.alpha)
+        if self.transB:
+            db = singa.DefaultTranspose(db)
+        if C:
+            dc = back_broadcast(dy.shape(), C.shape(),
+                                singa.MultFloat(dy, self.beta))
+            return da, db, dc
+        else:
+            return da, db
+
+
+def gemm(A, B, C=None, alpha=1.0, beta=1.0, transA=0, transB=0):
+    """
+    init a General Matrix multiplication(Gemm) operator
+    Compute Y = alpha * A' * B' + beta * C, where input tensor A has shape (M, K) or (K, M), input tensor B has shape (K, N) or (N, K), input tensor C is broadcastable to shape (M, N), and output tensor Y has shape (M, N).
+    A' = transpose(A) if transA else A
+    B' = transpose(B) if transB else B
+    Args:A: 
+        tensor, The shape of A should be (M, K) if transA is 0, or (K, M) if transA is non-zero.
+    Args:B: 
+        tensor, The shape of B should be (K, N) if transB is 0, or (N, K) if transB is non-zero.
+    Args:C: 
+        tensor(optional), Optional input tensor C. If not specified, the computation is done as if C is a scalar 0. The shape of C should be unidirectional broadcastable to (M, N).
+    Args:alpha: 
+        float, Scalar multiplier for the product of input tensors A * B.
+    Args:beta: 
+        float, Scalar multiplier for input tensor C.
+    Args:transA: 
+        int, Whether A should be transposed
+    Args:transB: 
+        int, Whether B should be transposed
+    Returns: 
+        tensor, the output
+    """
+    return Gemm(alpha, beta, transA, transB)(A, B, C)[0]
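The gradient formulas in the Gemm.backward comment block can be spot-checked with plain
numpy; a minimal sketch for the transA=1, transB=0 case, using the shapes given in the
docstrings:

    import numpy as np

    # Shapes follow the docstring for transA=1, transB=0: A is (K, M), B is (K, N).
    M, K, N = 3, 4, 5
    alpha, beta = 0.25, 0.35
    A = np.random.randn(K, M).astype(np.float32)
    B = np.random.randn(K, N).astype(np.float32)
    C = np.random.randn(1, N).astype(np.float32)   # broadcastable to (M, N)
    dy = np.ones((M, N), dtype=np.float32)         # upstream gradient

    Y = alpha * A.T.dot(B) + beta * C              # forward: Y = alpha * A' * B' + beta * C

    da = alpha * dy.dot(B.T)                       # alpha * dy * B^T ...
    da = da.T                                      # ... transposed back because transA=1
    db = alpha * A.dot(dy)                         # alpha * A * dy (no transpose since transB=0)
    dc = beta * dy.sum(axis=0, keepdims=True)      # reduce dy over the broadcast axis of C

    assert Y.shape == (M, N)
    assert da.shape == A.shape and db.shape == B.shape and dc.shape == C.shape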
diff --git a/python/singa/sonnx.py b/python/singa/sonnx.py
index c0e24c9..2bdb079 100755
--- a/python/singa/sonnx.py
+++ b/python/singa/sonnx.py
@@ -162,7 +162,7 @@ class SingaFrontend(object):
         'Concat': 'Concat',
         'Flatten': 'Flatten',
         'AddBias': 'Add',
-        # 'GEMM': 'Gemm',
+        'Gemm': 'Gemm',
         'Reshape': 'Reshape',
         'Sum': 'Sum',
         'cos': 'Cos',
@@ -217,7 +217,7 @@ class SingaFrontend(object):
         '_BatchNorm2d': '_create_batchnorm',
         'Concat': '_create_concat',
         'Flatten': '_create_flatten',
-        # 'GEMM': '_create_gemm',
+        'Gemm': '_create_gemm',
         'Reshape': '_create_reshape',
         'SoftMax': '_create_softmax',
         'SeLU': '_create_selu',
@@ -431,27 +431,27 @@ class SingaFrontend(object):
         ])
         return node
 
-    # @classmethod
-    # def _create_gemm(cls, op, op_t):
-    #     """
-    #     get a onnx node from singa gemm operator
-    #     Args:
-    #         op: a given operator
-    #     Args:
-    #         op_t: the tensor of the operator
-    #     Returns:
-    #         the onnx node
-    #     """
-    #     node = cls._common_singa_tensor_to_onnx_node(op, op_t)
-
-    #     node.attribute.extend([
-    #         helper.make_attribute('alpha', float(op.alpha)),
-    #         helper.make_attribute('beta', float(op.beta)),
-    #         helper.make_attribute('transA', 1 if op.transA else 0),
-    #         helper.make_attribute('transB', 1 if op.transB else 0),
-    #     ])
-
-    #     return node
+    @classmethod
+    def _create_gemm(cls, op, op_t):
+        """
+        get a onnx node from singa gemm operator
+        Args:
+            op: a given operator
+        Args:
+            op_t: the tensor of the operator
+        Returns: 
+            the onnx node
+        """
+        node = cls._common_singa_tensor_to_onnx_node(op, op_t)
+
+        node.attribute.extend([
+            helper.make_attribute('alpha', float(op.alpha)),
+            helper.make_attribute('beta', float(op.beta)),
+            helper.make_attribute('transA', op.transA),
+            helper.make_attribute('transB', op.transB),
+        ])
+
+        return node
 
     @classmethod
     def _create_batchnorm(cls, op, op_t):
@@ -617,7 +617,6 @@ class SingaFrontend(object):
         if optype in cls._bool_operators:
             y_dtype = cls._bool_operators[optype]
         Y = [helper.make_tensor_value_info(y.name, y_dtype, y.shape)]
-
         for op, yid, op_t in topol:
             optype = cls._get_singa_op_type(op)
             # print(op.name, cls._get_singa_op_type(op), op_t, optype, yid)
@@ -779,7 +778,7 @@ class SingaBackend(Backend):
         'BatchNormalization': 'batchnorm_2d',
         'Concat': 'Concat',
         'Flatten': 'Flatten',
-        # 'Gemm': 'GEMM',
+        'Gemm': 'Gemm',
         'Reshape': 'reshape',
         'Sum': 'sum',
         'Cos': 'cos',
@@ -835,7 +834,7 @@ class SingaBackend(Backend):
         'Concat': '_create_concat',
         'MatMul': '_create_matmul',
         'Flatten': '_create_flatten',
-        # 'Gemm': '_create_gemm',
+        'Gemm': '_create_gemm',
         'Reshape': '_create_reshape',
         'Softmax': '_create_softmax',
         'Selu': '_create_selu',
@@ -1163,28 +1162,32 @@ class SingaBackend(Backend):
                                                        opset_version)
         return None, forward(axis=factor)
 
-    # @classmethod
-    # def _create_gemm(cls, onnx_node, inputs, opset_version):
-    #     """
-    #     get the gemm operator from onnx node
-    #     Args:
-    #         onnx_node: a given onnx node
-    #     Args:
-    #         inputs: the input tensor
-    #     Args:
-    #         opset_version: the opset version
-    #     Returns:
-    #         the handle of singa operator
-    #     Returns:
-    #         the autograd of singa operator
-    #     """
-    #     x = inputs[0]
-    #     alpha = onnx_node.attrs["alpha"]
-    #     beta = onnx_node.attrs["beta"]
-    #     transA = False if onnx_node.attrs["transA"] == 0 else True
-    #     transB = False if onnx_node.attrs["transB"] == 0 else True
-    #     _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs, opset_version)
-    #     return None, forward(alpha=alpha, beta=beta, transA=transA, transB=transB)
+    @classmethod
+    def _create_gemm(cls, onnx_node, inputs, opset_version):
+        """
+        get the gemm operator from onnx node
+        Args:
+            onnx_node: a given onnx node
+        Args:
+            inputs: the input tensor
+        Args:
+            opset_version: the opset version
+        Returns: 
+            the handle of singa operator
+        Returns: 
+            the autograd of singa operator
+        """
+        x = inputs[0]
+        alpha = onnx_node.getattr('alpha', 1.)
+        beta = onnx_node.getattr('beta', 1.)
+        transA = onnx_node.getattr('transA', 0)
+        transB = onnx_node.getattr('transB', 0)
+        _, forward = cls._common_onnx_node_to_singa_op(onnx_node, inputs,
+                                                       opset_version)
+        return None, forward(alpha=alpha,
+                             beta=beta,
+                             transA=transA,
+                             transB=transB)
 
     @classmethod
     def _create_flatten(cls, onnx_node, inputs, opset_version):
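The frontend's _create_gemm writes alpha, beta, transA and transB as node attributes, and
the backend's _create_gemm reads them back with the ONNX defaults (alpha=1.0, beta=1.0,
transA=0, transB=0). A small sketch of that attribute round trip, using only the onnx
helper API:

    from onnx import helper

    # A Gemm node carrying the same attributes _create_gemm emits; anything left
    # unset falls back to the defaults used by onnx_node.getattr above.
    node = helper.make_node('Gemm', inputs=['A', 'B', 'C'], outputs=['Y'],
                            alpha=0.25, beta=0.35, transA=1)

    attrs = {a.name: helper.get_attribute_value(a) for a in node.attribute}
    alpha = attrs.get('alpha', 1.0)   # 0.25
    beta = attrs.get('beta', 1.0)     # 0.35
    transA = attrs.get('transA', 0)   # 1
    transB = attrs.get('transB', 0)   # 0, since it was not set on the node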
diff --git a/test/python/test_onnx.py b/test/python/test_onnx.py
index 4eb9c6a..b6cc7d2 100644
--- a/test/python/test_onnx.py
+++ b/test/python/test_onnx.py
@@ -259,30 +259,32 @@ class TestPythonOnnx(unittest.TestCase):
                                              tensor.to_numpy(y_t[0]),
                                              decimal=5)
 
-    # def test_gemm(self):
-    #     A = np.random.randn(2, 3).astype(np.float32)
-    #     B = np.random.rand(3, 4).astype(np.float32)
-    #     C = np.random.rand(2, 4).astype(np.float32)
-    #     alpha = 1.0
-    #     beta = 2.0
-
-    #     tA = tensor.from_numpy(A)
-    #     tB = tensor.from_numpy(B)
-    #     tC = tensor.from_numpy(C)
-    #     tA.to_device(gpu_dev)
-    #     tB.to_device(gpu_dev)
-    #     tC.to_device(gpu_dev)
-    #     y = autograd.GEMM(alpha, beta, False, False)(tA, tB, tC)[0]
-
-    #     # frontend
-    #     model = sonnx.to_onnx([tA, tB, tC], [y])
-    #     # print('The model is:\n{}'.format(model))
-
-    #     # backend
-    #     sg_ir = sonnx.prepare(model, device=gpu_dev)
-    #     y_t = sg_ir.run([tA, tB, tC])
-
-    #     np.testing.assert_array_almost_equal(tensor.to_numpy(y), tensor.to_numpy(y_t[0]), decimal=5)
+    def test_gemm(self):
+        A = np.random.randn(2, 3).astype(np.float32)
+        B = np.random.rand(3, 4).astype(np.float32)
+        C = np.random.rand(2, 4).astype(np.float32)
+        alpha = 1.0
+        beta = 2.0
+
+        tA = tensor.from_numpy(A)
+        tB = tensor.from_numpy(B)
+        tC = tensor.from_numpy(C)
+        tA.to_device(gpu_dev)
+        tB.to_device(gpu_dev)
+        tC.to_device(gpu_dev)
+        y = autograd.Gemm(alpha, beta, 0, 0)(tA, tB, tC)[0]
+
+        # frontend
+        model = sonnx.to_onnx([tA, tB, tC], [y])
+        # print('The model is:\n{}'.format(model))
+
+        # backend
+        sg_ir = sonnx.prepare(model, device=gpu_dev)
+        y_t = sg_ir.run([tA, tB, tC])
+
+        np.testing.assert_array_almost_equal(tensor.to_numpy(y),
+                                             tensor.to_numpy(y_t[0]),
+                                             decimal=5)
 
     def test_reshape(self):
         x = np.array([0.1, -1.0, 0.4, 4.0, -0.9,
diff --git a/test/python/test_onnx_backend.py b/test/python/test_onnx_backend.py
index 4c1ee8e..84b69e6 100644
--- a/test/python/test_onnx_backend.py
+++ b/test/python/test_onnx_backend.py
@@ -2028,6 +2028,167 @@ class TestPythonOnnxBackend(unittest.TestCase):
         y = np.random.randn(5).astype(np.float32)
         z = x * y
         expect(node, inputs=[x, y], outputs=[z], name='test_mul_bcast')
+        
+    def test_gemm_default_zero_bias(self):
+        node = onnx.helper.make_node(
+            'Gemm',
+            inputs=['a', 'b', 'c'],
+            outputs=['y']
+        )
+        a = np.random.ranf([3, 5]).astype(np.float32)
+        b = np.random.ranf([5, 4]).astype(np.float32)
+        c = np.zeros([1, 4]).astype(np.float32)
+        y = gemm_reference_implementation(a, b, c)
+        expect(node, inputs=[a, b, c], outputs=[y],
+                name='test_gemm_default_zero_bias')
+
+    def test_gemm_default_no_bias(self):
+        node = onnx.helper.make_node(
+            'Gemm',
+            inputs=['a', 'b'],
+            outputs=['y']
+        )
+        a = np.random.ranf([2, 10]).astype(np.float32)
+        b = np.random.ranf([10, 3]).astype(np.float32)
+        y = gemm_reference_implementation(a, b)
+        expect(node, inputs=[a, b], outputs=[y],
+                name='test_gemm_default_no_bias')
+
+    def test_gemm_default_scalar_bias(self):
+        node = onnx.helper.make_node(
+            'Gemm',
+            inputs=['a', 'b', 'c'],
+            outputs=['y']
+        )
+        a = np.random.ranf([2, 3]).astype(np.float32)
+        b = np.random.ranf([3, 4]).astype(np.float32)
+        c = np.array(3.14).astype(np.float32)
+        y = gemm_reference_implementation(a, b, c)
+        expect(node, inputs=[a, b, c], outputs=[y],
+                name='test_gemm_default_scalar_bias')
+
+    def test_gemm_default_single_elem_vector_bias(self):
+        node = onnx.helper.make_node(
+            'Gemm',
+            inputs=['a', 'b', 'c'],
+            outputs=['y']
+        )
+        a = np.random.ranf([3, 7]).astype(np.float32)
+        b = np.random.ranf([7, 3]).astype(np.float32)
+        c = np.random.ranf([1]).astype(np.float32)
+        y = gemm_reference_implementation(a, b, c)
+        expect(node, inputs=[a, b, c], outputs=[y],
+                name='test_gemm_default_single_elem_vector_bias')
+
+    def test_gemm_default_vector_bias(self):
+        node = onnx.helper.make_node(
+            'Gemm',
+            inputs=['a', 'b', 'c'],
+            outputs=['y']
+        )
+        a = np.random.ranf([2, 7]).astype(np.float32)
+        b = np.random.ranf([7, 4]).astype(np.float32)
+        c = np.random.ranf([1, 4]).astype(np.float32)
+        y = gemm_reference_implementation(a, b, c)
+        expect(node, inputs=[a, b, c], outputs=[y],
+                name='test_gemm_default_vector_bias')
+
+    def test_gemm_default_matrix_bias(self):
+        node = onnx.helper.make_node(
+            'Gemm',
+            inputs=['a', 'b', 'c'],
+            outputs=['y']
+        )
+        a = np.random.ranf([3, 6]).astype(np.float32)
+        b = np.random.ranf([6, 4]).astype(np.float32)
+        c = np.random.ranf([3, 4]).astype(np.float32)
+        y = gemm_reference_implementation(a, b, c)
+        expect(node, inputs=[a, b, c], outputs=[y],
+                name='test_gemm_default_matrix_bias')
+
+    def test_gemm_transposeA(self):
+        node = onnx.helper.make_node(
+            'Gemm',
+            inputs=['a', 'b', 'c'],
+            outputs=['y'],
+            transA=1
+        )
+        a = np.random.ranf([6, 3]).astype(np.float32)
+        b = np.random.ranf([6, 4]).astype(np.float32)
+        c = np.zeros([1, 4]).astype(np.float32)
+        y = gemm_reference_implementation(a, b, c, transA=1)
+        expect(node, inputs=[a, b, c], outputs=[y],
+                name='test_gemm_transposeA')
+
+    def test_gemm_transposeB(self):
+        node = onnx.helper.make_node(
+            'Gemm',
+            inputs=['a', 'b', 'c'],
+            outputs=['y'],
+            transB=1
+        )
+        a = np.random.ranf([3, 6]).astype(np.float32)
+        b = np.random.ranf([4, 6]).astype(np.float32)
+        c = np.zeros([1, 4]).astype(np.float32)
+        y = gemm_reference_implementation(a, b, c, transB=1)
+        expect(node, inputs=[a, b, c], outputs=[y],
+                name='test_gemm_transposeB')
+
+    def test_gemm_alpha(self):
+        node = onnx.helper.make_node(
+            'Gemm',
+            inputs=['a', 'b', 'c'],
+            outputs=['y'],
+            alpha=0.5
+        )
+        a = np.random.ranf([3, 5]).astype(np.float32)
+        b = np.random.ranf([5, 4]).astype(np.float32)
+        c = np.zeros([1, 4]).astype(np.float32)
+        y = gemm_reference_implementation(a, b, c, alpha=0.5)
+        expect(node, inputs=[a, b, c], outputs=[y],
+                name='test_gemm_alpha')
+
+    def test_gemm_beta(self):
+        node = onnx.helper.make_node(
+            'Gemm',
+            inputs=['a', 'b', 'c'],
+            outputs=['y'],
+            beta=0.5
+        )
+        a = np.random.ranf([2, 7]).astype(np.float32)
+        b = np.random.ranf([7, 4]).astype(np.float32)
+        c = np.random.ranf([1, 4]).astype(np.float32)
+        y = gemm_reference_implementation(a, b, c, beta=0.5)
+        expect(node, inputs=[a, b, c], outputs=[y],
+                name='test_gemm_beta')
+
+    def test_gemm_all_attributes(self):
+        node = onnx.helper.make_node(
+            'Gemm',
+            inputs=['a', 'b', 'c'],
+            outputs=['y'],
+            alpha=0.25,
+            beta=0.35,
+            transA=1,
+            transB=1
+        )
+        a = np.random.ranf([4, 3]).astype(np.float32)
+        b = np.random.ranf([5, 4]).astype(np.float32)
+        c = np.random.ranf([1, 5]).astype(np.float32)
+        y = gemm_reference_implementation(a, b, c, transA=1, transB=1, alpha=0.25, beta=0.35)
+        expect(node, inputs=[a, b, c], outputs=[y],
+                name='test_gemm_all_attributes')
+
+
+def gemm_reference_implementation(A, B, C=None, alpha=1., beta=1., transA=0,
+                                transB=0):  # type: (np.ndarray, np.ndarray, Optional[np.ndarray], float, float, int, int) -> np.ndarray
+    A = A if transA == 0 else A.T
+    B = B if transB == 0 else B.T
+    C = C if C is not None else np.array(0)
+
+    Y = alpha * np.dot(A, B) + beta * C
+
+    return Y
 
 
 # return padding shape of conv2d or pooling
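gemm_reference_implementation above mirrors the ONNX semantics: C is broadcast against the
(M, N) product and scaled by beta. A quick numpy check with arbitrary shapes:

    import numpy as np

    # A (1, 4) bias broadcasts against the (3, 4) product; alpha and beta scale the terms.
    a = np.random.ranf([3, 5]).astype(np.float32)
    b = np.random.ranf([5, 4]).astype(np.float32)
    c = np.random.ranf([1, 4]).astype(np.float32)
    y = gemm_reference_implementation(a, b, c, alpha=0.5, beta=2.0)
    np.testing.assert_allclose(y, 0.5 * np.dot(a, b) + 2.0 * c, rtol=1e-5)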
diff --git a/test/python/test_operation.py b/test/python/test_operation.py
index c0539d6..68e8c7f 100755
--- a/test/python/test_operation.py
+++ b/test/python/test_operation.py
@@ -3397,5 +3397,74 @@ class TestPythonOperation(unittest.TestCase):
                                                  decimal=5)
 
 
+    def gemm_test(self, dev):
+        configs = [
+            # alpha, beta, transA, transB, shapeA, shapeB, shapeC, shapeY
+            [0.25, 0.35, 0, 0, (3, 4), (4, 5), (1, 5), (3, 5)],
+            [0.25, 0.35, 0, 1, (3, 4), (5, 4), (1, 5), (3, 5)],
+            [0.25, 0.35, 1, 0, (4, 3), (4, 5), (1, 5), (3, 5)],
+            [0.25, 0.35, 1, 1, (4, 3), (5, 4), (1, 5), (3, 5)],
+        ]
+        for config in configs:
+            alpha = config[0]
+            beta = config[1]
+            transA = config[2]
+            transB = config[3]
+            shapeA = config[4]
+            shapeB = config[5]
+            shapeC = config[6]
+            shapeY = config[7]
+            A = np.random.randn(*shapeA).astype(np.float32)
+            B = np.random.randn(*shapeB).astype(np.float32)
+            C = np.random.randn(*shapeC).astype(np.float32)
+            DY = np.ones(shapeY, dtype=np.float32)
+
+            a = tensor.from_numpy(A)
+            a.to_device(dev)
+            b = tensor.from_numpy(B)
+            b.to_device(dev)
+            c = tensor.from_numpy(C)
+            c.to_device(dev)
+            dy = tensor.from_numpy(DY)
+            dy.to_device(dev)
+
+            result = autograd.gemm(a, b, c, alpha, beta, transA, transB)
+            da, db, dc = result.creator.backward(dy.data)
+
+            # Y = alpha * A' * B' + beta * C
+            _A = A if transA == 0 else A.T
+            _B = B if transB == 0 else B.T
+            C = C if C is not None else np.array(0)
+            Y = alpha * np.dot(_A, _B) + beta * C
+
+            DA = alpha * np.matmul(DY, _B.T)
+            DA = DA if transA == 0 else DA.T
+            DB = alpha * np.matmul(_A.T, DY)
+            DB = DB if transB == 0 else DB.T
+            DC = beta * np.sum(DY, axis=axis_helper(Y.shape, C.shape)).reshape(
+                C.shape)
+
+            np.testing.assert_array_almost_equal(tensor.to_numpy(result),
+                                                 Y,
+                                                 decimal=5)
+            np.testing.assert_array_almost_equal(tensor.to_numpy(
+                tensor.from_raw_tensor(da)),
+                                                 DA,
+                                                 decimal=5)
+            np.testing.assert_array_almost_equal(tensor.to_numpy(
+                tensor.from_raw_tensor(db)),
+                                                 DB,
+                                                 decimal=5)
+            np.testing.assert_array_almost_equal(tensor.to_numpy(
+                tensor.from_raw_tensor(dc)),
+                                                 DC,
+                                                 decimal=5)
+
+    def test_gemm_cpu(self):
+        self.gemm_test(cpu_dev)
+
+    def test_gemm_gpu(self):
+        self.gemm_test(gpu_dev)
+
 if __name__ == '__main__':
     unittest.main()