Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/12/07 19:26:41 UTC

[GitHub] [tvm] Wheest opened a new pull request #7050: Sparse Conv2D for CPU (NCHW)

Wheest opened a new pull request #7050:
URL: https://github.com/apache/tvm/pull/7050


   This pull request adds sparse conv2d implementations for CPU to TOPI.  I have implemented sparse GEMM convolution and sparse direct convolution for the NCHW data layout, using the CSR sparse data format.
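
For readers unfamiliar with the GEMM formulation: the input is first unfolded (im2col) into a `K x M` matrix, with `K = IC*KH*KW` and `M = OH*OW`, which the CSR weight matrix (one row per output channel) then multiplies. A minimal plain-Python sketch of the unfolding for one NCHW image, assuming no padding or dilation; `im2col_ref` is a hypothetical name and not part of the PR:

```python
def im2col_ref(x, kdim, stride=1):
    """Unfold x[ic][h][w] into a K x M matrix, K = IC*KH*KW, M = OH*OW.

    Assumption: no padding and no dilation, to keep the sketch short.
    """
    KH, KW = kdim
    IC, IH, IW = len(x), len(x[0]), len(x[0][0])
    OH = (IH - KH) // stride + 1
    OW = (IW - KW) // stride + 1
    cols = []
    for k in range(IC * KH * KW):
        # Decode the flattened kernel position k into (channel, row, col).
        ic, kh, kw = k // (KH * KW), (k // KW) % KH, k % KW
        cols.append([x[ic][(m // OW) * stride + kh][(m % OW) * stride + kw]
                     for m in range(OH * OW)])
    return cols


# One 3x3 input channel unfolded for a 2x2 kernel: 4 rows (K) by 4 columns (M).
image = [[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
unfolded = im2col_ref(image, kdim=(2, 2))
```

Each row of `unfolded` holds the input values that one kernel position touches across all output pixels, so a sparse weight row only needs to read the rows matching its nonzero columns.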
   
   The extension to the C++ runtime is pretty stable.  The code for TOPI is not yet clean or well integrated, and I am looking for guidance from other developers.
   
   [This gist](https://gist.github.com/Wheest/94433f73ff3279669bf35adcc38b321d) has a simple example of running a single layer Conv2D network with sparsity.  
   
   You can choose what algorithm the Relay strategy uses with the two environment variables:
   
   ```
   export TVM_DIRECT_CONV=1
   export TVM_GEMM_CONV=0
   ```
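
How the Relay strategy reads these variables is not shown in this message. A minimal sketch, assuming each variable is an on/off flag where `1` enables the algorithm; `pick_sparse_conv_algorithm` is a hypothetical helper, not a function from the PR:

```python
import os


def pick_sparse_conv_algorithm(env=None):
    """Return "direct", "gemm", or None from the two flags (hypothetical helper)."""
    env = os.environ if env is None else env
    use_direct = env.get("TVM_DIRECT_CONV", "0") == "1"
    use_gemm = env.get("TVM_GEMM_CONV", "0") == "1"
    if use_direct and use_gemm:
        raise ValueError("enable only one of TVM_DIRECT_CONV and TVM_GEMM_CONV")
    if use_direct:
        return "direct"
    if use_gemm:
        return "gemm"
    # Assumption: with neither flag set, the caller decides what to fall back to.
    return None
```

Passing a plain dict instead of `os.environ` makes the selection easy to exercise without touching the process environment.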
   
   Comments on how to improve the integration are appreciated.  Further pull requests could add other sparse algorithms and sparse data formats.
   
   I am in the process of creating sparse versions for GPU runtimes, but am having some difficulties, which I am discussing on the [Discuss forum](https://discuss.tvm.apache.org/t/sparse-opencl-error-scheduling-sparse-computations-that-use-tir-ir-builder/).
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] Wheest commented on a change in pull request #7050: Sparse Conv2D for CPU (NCHW)

Posted by GitBox <gi...@apache.org>.
Wheest commented on a change in pull request #7050:
URL: https://github.com/apache/tvm/pull/7050#discussion_r538450478



##########
File path: python/tvm/topi/sparse/csrmm.py
##########
@@ -121,3 +121,46 @@ def csrmm(a, b, c=None):
         2-D with shape [m, n]
     """
     return csrmm_default(a.data, a.indices, a.indptr, b, c)
+
+
+def batch_csrmm(data, indices, indptr, dense, oshape):
+    # pylint: disable=invalid-name
+    assert len(data.shape) == 1 and len(indices.shape) == 1 and len(indptr.shape) == 1 \
+        and len(dense.shape) == 3, "only support 2-dim csrmm"
+    assert indptr.dtype == 'int32', f"CSR indptr must be integers, but is {indptr.dtype}"
+    assert indices.dtype == 'int32', f"CSR indices must be integers, but is {indices.dtype}"
+
+    assert isinstance(dense, te.tensor.Tensor), \
+        "dense matrix is assumed to be tvm.te.Tensor, but dense is `%s`" % (type(dense))
+
+    M = simplify(indptr.shape[0]-1)
+    batches, _, N = dense.shape
+    def csrmm_default_ir(data, indices, indptr, dense, out):
+        """define ir for csrmm"""
+        irb = tvm.tir.ir_builder.create()
+        data_ptr = irb.buffer_ptr(data)
+        indices_ptr = irb.buffer_ptr(indices)
+        indptr_ptr = irb.buffer_ptr(indptr)
+        dense_ptr = irb.buffer_ptr(dense)
+        out_ptr = irb.buffer_ptr(out)
+        M = simplify(indptr.shape[0]-1)
+        batches, _, N = dense.shape
+        with irb.for_range(0, batches, name='batch') as batch:
+            with irb.for_range(0, N, for_type="vectorize", name='n') as n:
+                with irb.for_range(0, M, for_type="parallel", name='row') as row:
+                    dot = irb.allocate('float32', (1,), name='dot', scope='local')
+                    out_ptr[(batch*N*M) + (row*N+n)] = 0.
+                    dot[0] = 0.
+                    row_start = indptr_ptr[row]
+                    row_end = indptr_ptr[row+1]
+                    row_elems = row_end-row_start
+                    with irb.for_range(0, row_elems, name='idx') as idx:
+                        elem = row_start+idx
+                        dot[0] += data_ptr[elem] * dense_ptr[indices_ptr[elem]*N+n]
+                    out_ptr[(batch*N*M) + row*N+n] += dot[0]
+        return irb.get()
+    matmul = te.extern(oshape, [data, indices, indptr, dense],
+                       lambda ins, outs: csrmm_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
+                       tag="csrmm", dtype='float32', name='out')

Review comment:
       +1. The TOPI call for the sparse conv2d will have this type information, and I will pass it through.







[GitHub] [tvm] Wheest commented on a change in pull request #7050: Sparse Conv2D for CPU (NCHW)

Posted by GitBox <gi...@apache.org>.
Wheest commented on a change in pull request #7050:
URL: https://github.com/apache/tvm/pull/7050#discussion_r538449021



##########
File path: python/tvm/topi/sparse/csrmm.py
##########
@@ -121,3 +121,46 @@ def csrmm(a, b, c=None):
         2-D with shape [m, n]
     """
     return csrmm_default(a.data, a.indices, a.indptr, b, c)
+
+
+def batch_csrmm(data, indices, indptr, dense, oshape):
+    # pylint: disable=invalid-name
+    assert len(data.shape) == 1 and len(indices.shape) == 1 and len(indptr.shape) == 1 \
+        and len(dense.shape) == 3, "only support 2-dim csrmm"
+    assert indptr.dtype == 'int32', f"CSR indptr must be integers, but is {indptr.dtype}"
+    assert indices.dtype == 'int32', f"CSR indices must be integers, but is {indices.dtype}"
+
+    assert isinstance(dense, te.tensor.Tensor), \
+        "dense matrix is assumed to be tvm.te.Tensor, but dense is `%s`" % (type(dense))
+
+    M = simplify(indptr.shape[0]-1)
+    batches, _, N = dense.shape
+    def csrmm_default_ir(data, indices, indptr, dense, out):
+        """define ir for csrmm"""
+        irb = tvm.tir.ir_builder.create()
+        data_ptr = irb.buffer_ptr(data)
+        indices_ptr = irb.buffer_ptr(indices)
+        indptr_ptr = irb.buffer_ptr(indptr)
+        dense_ptr = irb.buffer_ptr(dense)
+        out_ptr = irb.buffer_ptr(out)
+        M = simplify(indptr.shape[0]-1)
+        batches, _, N = dense.shape
+        with irb.for_range(0, batches, name='batch') as batch:
+            with irb.for_range(0, N, for_type="vectorize", name='n') as n:
+                with irb.for_range(0, M, for_type="parallel", name='row') as row:
+                    dot = irb.allocate('float32', (1,), name='dot', scope='local')
+                    out_ptr[(batch*N*M) + (row*N+n)] = 0.

Review comment:
       +1 on multidimensional access, will update my code when I go back to it.
   
   
   
   >     1. You probably want to implement block sparsity as CSR is just a special case of BSR.
   
   I'll need to read more about block sparsity, as it's not a format I understand yet.  I'll take a look; it would be interesting if we get two formats for the price of one.
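
To make the "special case" remark concrete: BSR stores small dense blocks where CSR stores scalars, so CSR is BSR with 1x1 blocks. A plain-Python sketch; the helper names are hypothetical and the storage layout (per-block nested lists) is chosen for readability, not taken from TVM:

```python
def bsr_to_dense(data, indices, indptr, block_shape, n_cols):
    """Expand a BSR matrix into a dense list of rows.

    data[i] is one dense block (R x C), indices[i] is its block column, and
    indptr delimits the blocks that belong to each block row.
    """
    R, C = block_shape
    n_block_rows = len(indptr) - 1
    dense = [[0.0] * n_cols for _ in range(n_block_rows * R)]
    for brow in range(n_block_rows):
        for i in range(indptr[brow], indptr[brow + 1]):
            bcol = indices[i]
            for r in range(R):
                for c in range(C):
                    dense[brow * R + r][bcol * C + c] = data[i][r][c]
    return dense


def csr_to_dense(data, indices, indptr, n_cols):
    """CSR is BSR with 1x1 blocks: wrap each scalar as a [[value]] block."""
    blocks = [[[v]] for v in data]
    return bsr_to_dense(blocks, indices, indptr, (1, 1), n_cols)


# The 2x4 matrix [[1, 0, 2, 0], [0, 0, 0, 3]] in CSR form.
csr_dense = csr_to_dense([1.0, 2.0, 3.0], [0, 2, 3], [0, 2, 3], n_cols=4)

# The same matrix stored as one block row of two 2x2 blocks.
bsr_dense = bsr_to_dense(
    data=[[[1.0, 0.0], [0.0, 0.0]], [[2.0, 0.0], [0.0, 3.0]]],
    indices=[0, 1], indptr=[0, 2], block_shape=(2, 2), n_cols=4)
```

Both calls produce the same dense matrix; the blocked form stores explicit zeros inside each block in exchange for fewer index entries.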
   
   >     2. There is already existing code to convert a dense matrix to a sparse matrix. Do we need another version for conv2d?
   
   The existing `csrmm` function does not support batches, as far as I know: it computes `(NxK) x (KxM)` rather than the `(NxK) x (BxKxM)` we need, where `B` is the number of batches.  Ideally there would be a single sparse matmul; in theory we could make `B` calls to the standard `csrmm` function.  In practice I'd need to think about what that would look like from an implementation perspective, since the size data is stored in the TVM tensors and is used in `csrmm` to generate the function.
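
The "`B` calls" idea can be sketched in plain Python, leaving TVM's code generation aside. `csrmm_ref` and `batch_csrmm_ref` are hypothetical stand-ins for illustration, not the TOPI functions:

```python
def csrmm_ref(data, indices, indptr, dense):
    """Multiply a CSR matrix (N x K) by a dense K x M matrix, giving N x M."""
    n_rows = len(indptr) - 1
    n_cols = len(dense[0])
    out = [[0.0] * n_cols for _ in range(n_rows)]
    for row in range(n_rows):
        for elem in range(indptr[row], indptr[row + 1]):
            k = indices[elem]  # column of the sparse matrix = row of dense
            for m in range(n_cols):
                out[row][m] += data[elem] * dense[k][m]
    return out


def batch_csrmm_ref(data, indices, indptr, dense_batches):
    """(N x K) x (B x K x M) -> (B x N x M), as B independent csrmm calls."""
    return [csrmm_ref(data, indices, indptr, dense) for dense in dense_batches]


# Sparse weights [[1, 0, 2], [0, 3, 0]] applied to two dense 3x2 operands.
result = batch_csrmm_ref(
    [1.0, 2.0, 3.0], [0, 2, 1], [0, 2, 3],
    [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
     [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]])
```

The sparse operand is shared across the batch, so only the dense slice changes from call to call.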
   
   







[GitHub] [tvm] ThatAIGeek commented on pull request #7050: Sparse Conv2D for CPU (NCHW)

Posted by GitBox <gi...@apache.org>.
ThatAIGeek commented on pull request #7050:
URL: https://github.com/apache/tvm/pull/7050#issuecomment-814953070


   Hi there!
   What is the state of this? Sparse convolution is something I would really like to have in TVM.
   





[GitHub] [tvm] Wheest commented on a change in pull request #7050: Sparse Conv2D for CPU (NCHW)

Posted by GitBox <gi...@apache.org>.
Wheest commented on a change in pull request #7050:
URL: https://github.com/apache/tvm/pull/7050#discussion_r538435168



##########
File path: python/tvm/topi/nn/conv2d_sparse.py
##########
@@ -0,0 +1,261 @@
+import tvm
+from tvm import te
+from tvm.topi import nn
+from tvm.topi.nn.util import get_pad_tuple
+from tvm.topi.util import get_const_tuple
+from tvm import autotvm
+from ..nn.conv2d import conv2d_infer_layout, _get_workload as _get_conv2d_workload
+from ..util import get_const_tuple, traverse_inline
+from tvm.topi.sparse import batch_csrmm, csrmm_default
+
+def _fallback_schedule(cfg, wkl):
+    HPAD, WPAD = wkl.hpad, wkl.wpad
+    HSTR, WSTR = wkl.hstride, wkl.wstride
+    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
+
+def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False,
+                        layout='NCHW'):
+    """
+    Get default schedule config for the workload
+    """
+    static_data_shape = []
+    for dim in get_const_tuple(data.shape):
+        if isinstance(dim, tvm.tir.Var):
+            static_data_shape.append(1)
+        else:
+            static_data_shape.append(dim)
+    data = te.placeholder(static_data_shape, dtype=data.dtype)
+    wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout)
+    is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
+    _fallback_schedule(cfg, wkl)
+
+def conv2d_sparse_gemm_nchw(data, w_data, w_indices, w_indptr,
+                            OC, KH, KW,
+                            strides, padding, dilation,
+                            out_dtype='float32'):
+    """Compute conv2d by transforming the input,
+    executing GEMM and not transforming the output back yet"""
+    batches, IC, IH, IW = get_const_tuple(data.shape)
+
+    K = KH * KW
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = \
+        get_pad_tuple(padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+
+    N = OC
+    K = KH * KW * IC
+    M = OH * OW
+
+    if pad_top or pad_left:
+        data_pad = nn.pad(data, [0, 0, pad_top, pad_left], [0, 0, pad_down, pad_right],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    # --- Im2col
+
+    B_shape = (batches, K, M)
+    idxmod = tvm.tir.indexmod
+    idxdiv = tvm.tir.indexdiv
+    # print(KH, KW, IC, OW, HSTR)
+
+    B = te.compute(B_shape, lambda n, k, m:
+                   data_pad[n, (k // (KH*KW)) % IC,
+                            (k // KH) % KW + ((m // OW) * HSTR),
+                            (k % KW) + ((m % OW) * WSTR)],
+                       name='data_im2col')
+
+
+    # --- GEMM: A*B'
+    # oshape = (batches, N, M)
+    oshape = (batches, OC, OH, OW)
+    # B = te.compute((N,M), lambda n, m:
+    #                B[0, n, m],
+    #                name='data_flatten')
+    C = batch_csrmm(w_data, w_indices, w_indptr, B, oshape)
+    # C = csrmm_default(w_data, w_indices, w_indptr, B)
+
+
+    # placeholder reshape
+    # k = te.reduce_axis((0, K), 'k')
+    # C = te.compute(
+    #     oshape,
+    #     lambda b, c, h, w: te.sum(C[b, c, w] * C[b, c, w], axis=k),
+    #     name='C')
+
+    return C
+
+def csrdc(data, indices, indptr, inputs, oshape, kdim, strides, padding):

Review comment:
       CSR Direct Convolution, following the naming convention from `csrmm` (CSR Matrix Multiply) in `python/tvm/topi/sparse/csrmm.py`
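
The body of `csrdc` is not quoted in this thread. As a rough illustration of what a CSR direct convolution computes, here is a plain-Python sketch for a single NCHW image. The function name, the simplified signature, and the weight layout (one CSR row per output channel, column index `(ic*KH + kh)*KW + kw`) are assumptions for illustration, and padding and dilation are omitted:

```python
def csr_direct_conv2d_ref(w_data, w_indices, w_indptr, x, kdim, stride=1):
    """Conv2d over x[ic][h][w] with CSR weights, no padding or dilation.

    Assumed layout: the CSR matrix has OC rows and IC*KH*KW columns, with
    column = (ic*KH + kh)*KW + kw.
    """
    KH, KW = kdim
    IH, IW = len(x[0]), len(x[0][0])
    OC = len(w_indptr) - 1
    OH = (IH - KH) // stride + 1
    OW = (IW - KW) // stride + 1
    out = [[[0.0] * OW for _ in range(OH)] for _ in range(OC)]
    for oc in range(OC):
        # Only the stored (nonzero) weights of this output channel are visited.
        for elem in range(w_indptr[oc], w_indptr[oc + 1]):
            col = w_indices[elem]
            ic, kh, kw = col // (KH * KW), (col // KW) % KH, col % KW
            for oh in range(OH):
                for ow in range(OW):
                    out[oc][oh][ow] += (
                        w_data[elem] * x[ic][oh * stride + kh][ow * stride + kw])
    return out


# One 3x3 input channel and one 2x2 kernel [[1, 0], [0, 2]] stored as CSR.
x = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]
y = csr_direct_conv2d_ref([1.0, 2.0], [0, 3], [0, 2], x, kdim=(2, 2))
```

Unlike the GEMM route, no unfolded copy of the input is materialised; each nonzero weight is applied directly to the input window it touches.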







[GitHub] [tvm] tkonolige commented on a change in pull request #7050: Sparse Conv2D for CPU (NCHW)

Posted by GitBox <gi...@apache.org>.
tkonolige commented on a change in pull request #7050:
URL: https://github.com/apache/tvm/pull/7050#discussion_r537793211



##########
File path: python/tvm/topi/nn/conv2d_sparse.py
##########
@@ -0,0 +1,261 @@
+import tvm
+from tvm import te
+from tvm.topi import nn
+from tvm.topi.nn.util import get_pad_tuple
+from tvm.topi.util import get_const_tuple
+from tvm import autotvm
+from ..nn.conv2d import conv2d_infer_layout, _get_workload as _get_conv2d_workload
+from ..util import get_const_tuple, traverse_inline
+from tvm.topi.sparse import batch_csrmm, csrmm_default
+
+def _fallback_schedule(cfg, wkl):
+    HPAD, WPAD = wkl.hpad, wkl.wpad
+    HSTR, WSTR = wkl.hstride, wkl.wstride
+    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
+
+def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False,
+                        layout='NCHW'):
+    """
+    Get default schedule config for the workload
+    """
+    static_data_shape = []
+    for dim in get_const_tuple(data.shape):
+        if isinstance(dim, tvm.tir.Var):
+            static_data_shape.append(1)
+        else:
+            static_data_shape.append(dim)
+    data = te.placeholder(static_data_shape, dtype=data.dtype)
+    wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout)
+    is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
+    _fallback_schedule(cfg, wkl)
+
+def conv2d_sparse_gemm_nchw(data, w_data, w_indices, w_indptr,
+                            OC, KH, KW,
+                            strides, padding, dilation,
+                            out_dtype='float32'):
+    """Compute conv2d by transforming the input,
+    executing GEMM and not transforming the output back yet"""
+    batches, IC, IH, IW = get_const_tuple(data.shape)
+
+    K = KH * KW
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = \
+        get_pad_tuple(padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+
+    N = OC
+    K = KH * KW * IC
+    M = OH * OW
+
+    if pad_top or pad_left:
+        data_pad = nn.pad(data, [0, 0, pad_top, pad_left], [0, 0, pad_down, pad_right],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    # --- Im2col
+
+    B_shape = (batches, K, M)
+    idxmod = tvm.tir.indexmod
+    idxdiv = tvm.tir.indexdiv
+    # print(KH, KW, IC, OW, HSTR)
+
+    B = te.compute(B_shape, lambda n, k, m:
+                   data_pad[n, (k // (KH*KW)) % IC,
+                            (k // KH) % KW + ((m // OW) * HSTR),
+                            (k % KW) + ((m % OW) * WSTR)],
+                       name='data_im2col')
+
+
+    # --- GEMM: A*B'
+    # oshape = (batches, N, M)
+    oshape = (batches, OC, OH, OW)
+    # B = te.compute((N,M), lambda n, m:
+    #                B[0, n, m],
+    #                name='data_flatten')
+    C = batch_csrmm(w_data, w_indices, w_indptr, B, oshape)
+    # C = csrmm_default(w_data, w_indices, w_indptr, B)
+
+
+    # placeholder reshape
+    # k = te.reduce_axis((0, K), 'k')
+    # C = te.compute(
+    #     oshape,
+    #     lambda b, c, h, w: te.sum(C[b, c, w] * C[b, c, w], axis=k),
+    #     name='C')
+
+    return C
+
+def csrdc(data, indices, indptr, inputs, oshape, kdim, strides, padding):

Review comment:
       What is csrdc?







[GitHub] [tvm] jroesch commented on pull request #7050: Sparse Conv2D for CPU (NCHW)

Posted by GitBox <gi...@apache.org>.
jroesch commented on pull request #7050:
URL: https://github.com/apache/tvm/pull/7050#issuecomment-1016774381


   This PR appears to be out of date, please feel free to reopen it if this is not the case.
   
   As part of the new year we are attempting to triage the project's open pull requests to ensure that code which
   is ready for review and/or merging receives adequate attention.
   
   Thanks again for your contribution, and feel free to reach out to discuss these changes.





[GitHub] [tvm] giuseros commented on pull request #7050: Sparse Conv2D for CPU (NCHW)

Posted by GitBox <gi...@apache.org>.
giuseros commented on pull request #7050:
URL: https://github.com/apache/tvm/pull/7050#issuecomment-740673709


   Hi @Wheest , 
   This is very interesting. I see you are adding support for x86. Is there anything x86-specific? If not, how hard would it be to also support ARM CPUs?
   
   Thanks,





[GitHub] [tvm] Wheest commented on pull request #7050: Sparse Conv2D for CPU (NCHW)

Posted by GitBox <gi...@apache.org>.
Wheest commented on pull request #7050:
URL: https://github.com/apache/tvm/pull/7050#issuecomment-740733320


   Hi there @giuseros, I'm testing on both x86 and ARM, and as long as you're not autotuning it will use the same code.  Nothing is x86-specific yet.  The code just happens to be defined in the `x86` TOPI directory for now; perhaps I should move it to the `sparse` or `nn` one.
   
   See the [gist](https://gist.github.com/Wheest/94433f73ff3279669bf35adcc38b321d) for an example of how it is used.  You can also apply the `data_dep_optimization` to sparsify Conv2D on full networks, 





[GitHub] [tvm] jroesch closed pull request #7050: Sparse Conv2D for CPU (NCHW)

Posted by GitBox <gi...@apache.org>.
jroesch closed pull request #7050:
URL: https://github.com/apache/tvm/pull/7050


   





[GitHub] [tvm] tkonolige commented on a change in pull request #7050: Sparse Conv2D for CPU (NCHW)

Posted by GitBox <gi...@apache.org>.
tkonolige commented on a change in pull request #7050:
URL: https://github.com/apache/tvm/pull/7050#discussion_r538663169



##########
File path: python/tvm/topi/nn/conv2d_sparse.py
##########
@@ -0,0 +1,261 @@
+import tvm
+from tvm import te
+from tvm.topi import nn
+from tvm.topi.nn.util import get_pad_tuple
+from tvm.topi.util import get_const_tuple
+from tvm import autotvm
+from ..nn.conv2d import conv2d_infer_layout, _get_workload as _get_conv2d_workload
+from ..util import get_const_tuple, traverse_inline
+from tvm.topi.sparse import batch_csrmm, csrmm_default
+
+def _fallback_schedule(cfg, wkl):
+    HPAD, WPAD = wkl.hpad, wkl.wpad
+    HSTR, WSTR = wkl.hstride, wkl.wstride
+    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
+
+def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False,
+                        layout='NCHW'):
+    """
+    Get default schedule config for the workload
+    """
+    static_data_shape = []
+    for dim in get_const_tuple(data.shape):
+        if isinstance(dim, tvm.tir.Var):
+            static_data_shape.append(1)
+        else:
+            static_data_shape.append(dim)
+    data = te.placeholder(static_data_shape, dtype=data.dtype)
+    wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout)
+    is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
+    _fallback_schedule(cfg, wkl)
+
+def conv2d_sparse_gemm_nchw(data, w_data, w_indices, w_indptr,
+                            OC, KH, KW,
+                            strides, padding, dilation,
+                            out_dtype='float32'):
+    """Compute conv2d by transforming the input,
+    executing GEMM and not transforming the output back yet"""
+    batches, IC, IH, IW = get_const_tuple(data.shape)
+
+    K = KH * KW
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = \
+        get_pad_tuple(padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+
+    N = OC
+    K = KH * KW * IC
+    M = OH * OW
+
+    if pad_top or pad_left:
+        data_pad = nn.pad(data, [0, 0, pad_top, pad_left], [0, 0, pad_down, pad_right],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    # --- Im2col
+
+    B_shape = (batches, K, M)
+    idxmod = tvm.tir.indexmod
+    idxdiv = tvm.tir.indexdiv
+    # print(KH, KW, IC, OW, HSTR)
+
+    B = te.compute(B_shape, lambda n, k, m:
+                   data_pad[n, (k // (KH*KW)) % IC,
+                            (k // KH) % KW + ((m // OW) * HSTR),
+                            (k % KW) + ((m % OW) * WSTR)],
+                       name='data_im2col')
+
+
+    # --- GEMM: A*B'
+    # oshape = (batches, N, M)
+    oshape = (batches, OC, OH, OW)
+    # B = te.compute((N,M), lambda n, m:
+    #                B[0, n, m],
+    #                name='data_flatten')
+    C = batch_csrmm(w_data, w_indices, w_indptr, B, oshape)
+    # C = csrmm_default(w_data, w_indices, w_indptr, B)
+
+
+    # placeholder reshape
+    # k = te.reduce_axis((0, K), 'k')
+    # C = te.compute(
+    #     oshape,
+    #     lambda b, c, h, w: te.sum(C[b, c, w] * C[b, c, w], axis=k),
+    #     name='C')
+
+    return C
+
+def csrdc(data, indices, indptr, inputs, oshape, kdim, strides, padding):

Review comment:
       Could we use that (`csr_direct_convolution`) as the name instead?







[GitHub] [tvm] tkonolige commented on a change in pull request #7050: Sparse Conv2D for CPU (NCHW)

Posted by GitBox <gi...@apache.org>.
tkonolige commented on a change in pull request #7050:
URL: https://github.com/apache/tvm/pull/7050#discussion_r537807196



##########
File path: python/tvm/topi/sparse/csrmm.py
##########
@@ -121,3 +121,46 @@ def csrmm(a, b, c=None):
         2-D with shape [m, n]
     """
     return csrmm_default(a.data, a.indices, a.indptr, b, c)
+
+
+def batch_csrmm(data, indices, indptr, dense, oshape):
+    # pylint: disable=invalid-name
+    assert len(data.shape) == 1 and len(indices.shape) == 1 and len(indptr.shape) == 1 \
+        and len(dense.shape) == 3, "only support 2-dim csrmm"
+    assert indptr.dtype == 'int32', f"CSR indptr must be integers, but is {indptr.dtype}"
+    assert indices.dtype == 'int32', f"CSR indices must be integers, but is {indices.dtype}"
+
+    assert isinstance(dense, te.tensor.Tensor), \
+        "dense matrix is assumed to be tvm.te.Tensor, but dense is `%s`" % (type(dense))
+
+    M = simplify(indptr.shape[0]-1)
+    batches, _, N = dense.shape
+    def csrmm_default_ir(data, indices, indptr, dense, out):
+        """define ir for csrmm"""
+        irb = tvm.tir.ir_builder.create()
+        data_ptr = irb.buffer_ptr(data)
+        indices_ptr = irb.buffer_ptr(indices)
+        indptr_ptr = irb.buffer_ptr(indptr)
+        dense_ptr = irb.buffer_ptr(dense)
+        out_ptr = irb.buffer_ptr(out)
+        M = simplify(indptr.shape[0]-1)
+        batches, _, N = dense.shape
+        with irb.for_range(0, batches, name='batch') as batch:
+            with irb.for_range(0, N, for_type="vectorize", name='n') as n:
+                with irb.for_range(0, M, for_type="parallel", name='row') as row:
+                    dot = irb.allocate('float32', (1,), name='dot', scope='local')
+                    out_ptr[(batch*N*M) + (row*N+n)] = 0.

Review comment:
       ir_builder supports multidimensional access (`out_ptr[batch, row, n]`), which might make this code cleaner.
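
For reference, the hand-written offset in the quoted line is the row-major flattening of `[batch, row, n]` over an output of shape `(batches, M, N)`. A plain-Python check of that equivalence, independent of TVM; `flat_index` is a hypothetical helper:

```python
def flat_index(idx, shape):
    """Row-major offset of a multidimensional index."""
    offset = 0
    for i, dim in zip(idx, shape):
        assert 0 <= i < dim, "index out of bounds"
        offset = offset * dim + i
    return offset


batches, M, N = 2, 3, 4  # example output shape (batches, M, N)
for batch in range(batches):
    for row in range(M):
        for n in range(N):
            # Matches the offset written out by hand in the diff.
            assert flat_index((batch, row, n), (batches, M, N)) == batch * N * M + row * N + n
```

The same rule gives `batch*K*N + k*N + n` for `dense[batch, k, n]` on a `(batches, K, N)` operand (`K` being the middle dimension, which the diff leaves unnamed), so a hand-written offset has to carry a batch term there as well.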

##########
File path: python/tvm/topi/sparse/csrmm.py
##########
@@ -121,3 +121,46 @@ def csrmm(a, b, c=None):
         2-D with shape [m, n]
     """
     return csrmm_default(a.data, a.indices, a.indptr, b, c)
+
+
+def batch_csrmm(data, indices, indptr, dense, oshape):
+    # pylint: disable=invalid-name
+    assert len(data.shape) == 1 and len(indices.shape) == 1 and len(indptr.shape) == 1 \
+        and len(dense.shape) == 3, "only support 2-dim csrmm"
+    assert indptr.dtype == 'int32', f"CSR indptr must be integers, but is {indptr.dtype}"
+    assert indices.dtype == 'int32', f"CSR indices must be integers, but is {indices.dtype}"
+
+    assert isinstance(dense, te.tensor.Tensor), \
+        "dense matrix is assumed to be tvm.te.Tensor, but dense is `%s`" % (type(dense))
+
+    M = simplify(indptr.shape[0]-1)
+    batches, _, N = dense.shape
+    def csrmm_default_ir(data, indices, indptr, dense, out):
+        """define ir for csrmm"""
+        irb = tvm.tir.ir_builder.create()
+        data_ptr = irb.buffer_ptr(data)
+        indices_ptr = irb.buffer_ptr(indices)
+        indptr_ptr = irb.buffer_ptr(indptr)
+        dense_ptr = irb.buffer_ptr(dense)
+        out_ptr = irb.buffer_ptr(out)
+        M = simplify(indptr.shape[0]-1)
+        batches, _, N = dense.shape
+        with irb.for_range(0, batches, name='batch') as batch:
+            with irb.for_range(0, N, for_type="vectorize", name='n') as n:
+                with irb.for_range(0, M, for_type="parallel", name='row') as row:
+                    dot = irb.allocate('float32', (1,), name='dot', scope='local')
+                    out_ptr[(batch*N*M) + (row*N+n)] = 0.
+                    dot[0] = 0.
+                    row_start = indptr_ptr[row]
+                    row_end = indptr_ptr[row+1]
+                    row_elems = row_end-row_start
+                    with irb.for_range(0, row_elems, name='idx') as idx:
+                        elem = row_start+idx
+                        dot[0] += data_ptr[elem] * dense_ptr[indices_ptr[elem]*N+n]
+                    out_ptr[(batch*N*M) + row*N+n] += dot[0]
+        return irb.get()
+    matmul = te.extern(oshape, [data, indices, indptr, dense],
+                       lambda ins, outs: csrmm_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
+                       tag="csrmm", dtype='float32', name='out')

Review comment:
       I think we would like to support more than float32.







[GitHub] [tvm] altanh commented on pull request #7050: Sparse Conv2D for CPU (NCHW)

Posted by GitBox <gi...@apache.org>.
altanh commented on pull request #7050:
URL: https://github.com/apache/tvm/pull/7050#issuecomment-740137646


   cc @tkonolige who has some sparse experience

