Posted to commits@tvm.apache.org by jc...@apache.org on 2021/06/30 05:59:42 UTC

[tvm] branch main updated: [Topi][Unittests] Parametrized tests in `test_topi_dense.py`, split out gpu-independent implementations (#8336)

This is an automated email from the ASF dual-hosted git repository.

jcf94 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ae58f2c  [Topi][Unittests] Parametrized tests in `test_topi_dense.py`, split out gpu-independent implementations (#8336)
ae58f2c is described below

commit ae58f2c387de9944d241a083ce9a0dd4c9ae613d
Author: Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Tue Jun 29 22:59:16 2021 -0700

    [Topi][Unittests] Parametrized tests in `test_topi_dense.py`, split out gpu-independent implementations (#8336)
    
    * [Topi][UnitTests] Parametrized tests in test_topi_dense.py
    
    Tests now run for multiple data types and can be extended with
    additional data types.
    
    * [Topi] Separated generic-gpu nn.dense implementations into topi.gpu.dense
    
    As a follow-up to the earlier renaming of "gpu" to "cuda", this
    separates implementations that require CUDA (e.g. dense_cublas.cuda)
    from implementations that require any GPU, not necessarily a CUDA
    GPU (e.g. dense_small_batch.gpu).
    
    My intent is to pair this migration with the extension of unit tests
    to cover additional GPU runtimes, migrating only implementations that
    run correctly on non-CUDA GPU devices.
    
    * [Vulkan][Codegen] Updated storage sync to avoid incorrect matmul results on some GPUs
    
    - In ThreadAllreduceBuilder, separate the loads from the stores so
      that a memory barrier can be placed in between.
    
    - In Vulkan codegen, added Workgroup memory semantics to the
      subgroup-level thread sync, since different subgroup threads can
      still access workgroup memory.  Longer term, TIR enhancements may
      be needed to separate synchronization of control flow from
      synchronization of memory.
    
    Co-authored-by: Eric Lunderberg <el...@octoml.ai>
---
 python/tvm/relay/op/strategy/cuda.py               |  19 +-
 python/tvm/testing.py                              |   2 +-
 python/tvm/topi/__init__.py                        |   1 +
 python/tvm/topi/cuda/dense.py                      | 191 -------------------
 python/tvm/topi/gpu/__init__.py                    |  20 ++
 python/tvm/topi/{cuda => gpu}/dense.py             | 169 ++--------------
 src/target/spirv/codegen_spirv.cc                  |  28 ++-
 src/tir/transforms/lower_thread_allreduce.cc       |  57 +++++-
 tests/python/relay/test_autotvm_task_extraction.py |   2 +-
 tests/python/topi/python/test_topi_dense.py        | 212 +++++++++++----------
 10 files changed, 228 insertions(+), 473 deletions(-)
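
For orientation before the diff: the rewritten tests below rely on the tvm.testing parametrization helpers instead of hand-rolled verify_* loops. A minimal sketch of that pattern (illustrative names only, not part of this commit):

    import numpy as np
    import tvm.testing

    # tvm.testing.parameter() turns each value into a pytest parameter that
    # any test in the module can request by name.
    use_bias = tvm.testing.parameter(True, False)
    batch_size = tvm.testing.parameter(1, 2, 128)

    # tvm.testing.fixture() computes (and optionally caches) reference data
    # once per parameter combination.
    @tvm.testing.fixture(cache_return_value=True)
    def ref_data(batch_size):
        return np.random.uniform(size=(batch_size, 8)).astype("float32")

    # Requesting "target" and "dev" makes the test run once per enabled target.
    def test_example(target, dev, batch_size, use_bias, ref_data):
        assert ref_data.shape[0] == batch_size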

diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 6418f1f..683f3ec 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -705,7 +705,12 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
     data, weights = inputs
     b, i = get_const_tuple(data.shape)
     o, _ = get_const_tuple(weights.shape)
-    if data.dtype == "int8" and weights.dtype == "int8" and out_type.dtype == "int32":
+    if (
+        target.kind.name == "cuda"
+        and data.dtype == "int8"
+        and weights.dtype == "int8"
+        and out_type.dtype == "int32"
+    ):
         strategy.add_implementation(
             wrap_compute_dense(topi.cuda.dense_int8),
             wrap_topi_schedule(topi.cuda.schedule_dense_int8),
@@ -713,16 +718,16 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
         )
     else:
         strategy.add_implementation(
-            wrap_compute_dense(topi.cuda.dense_small_batch),
-            wrap_topi_schedule(topi.cuda.schedule_dense_small_batch),
-            name="dense_small_batch.cuda",
+            wrap_compute_dense(topi.gpu.dense_small_batch),
+            wrap_topi_schedule(topi.gpu.schedule_dense_small_batch),
+            name="dense_small_batch.gpu",
         )
 
         with SpecializedCondition(b >= 32):
             strategy.add_implementation(
-                wrap_compute_dense(topi.cuda.dense_large_batch),
-                wrap_topi_schedule(topi.cuda.schedule_dense_large_batch),
-                name="dense_large_batch.cuda",
+                wrap_compute_dense(topi.gpu.dense_large_batch),
+                wrap_topi_schedule(topi.gpu.schedule_dense_large_batch),
+                name="dense_large_batch.gpu",
                 plevel=5,
             )
         if target.kind.name == "cuda":
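
Outside of the relay strategy, the relocated compute/schedule pairs can be exercised directly; a minimal sketch, assuming a Vulkan-capable device at index 0 (shapes and the chosen target are arbitrary examples):

    import numpy as np
    import tvm
    from tvm import te, topi

    A = te.placeholder((2, 1024), name="A", dtype="float32")
    B = te.placeholder((1000, 1024), name="B", dtype="float32")

    # Any GPU target works here, not just CUDA; that is the point of the move
    # from topi.cuda.* to topi.gpu.* for these implementations.
    target = tvm.target.Target("vulkan")
    with target:
        D = topi.gpu.dense_small_batch(A, B)
        s = topi.gpu.schedule_dense_small_batch([D])

    f = tvm.build(s, [A, B, D], target, name="dense")
    dev = tvm.device("vulkan", 0)
    a = tvm.nd.array(np.random.uniform(size=(2, 1024)).astype("float32"), dev)
    b = tvm.nd.array(np.random.uniform(size=(1000, 1024)).astype("float32"), dev)
    d = tvm.nd.array(np.zeros((2, 1000), dtype="float32"), dev)
    f(a, b, d)
    np.testing.assert_allclose(d.numpy(), a.numpy() @ b.numpy().T, rtol=1e-5)
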
diff --git a/python/tvm/testing.py b/python/tvm/testing.py
index 8178b0a..4721c00 100644
--- a/python/tvm/testing.py
+++ b/python/tvm/testing.py
@@ -414,7 +414,7 @@ def _get_targets(target_str=None):
 
 
 DEFAULT_TEST_TARGETS = (
-    "llvm;cuda;opencl;metal;rocm;vulkan;nvptx;"
+    "llvm;cuda;opencl;metal;rocm;vulkan -from_device=0;nvptx;"
     "llvm -device=arm_cpu;opencl -device=mali,aocl_sw_emu"
 )
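
A note on the change above: "-from_device=0" asks TVM to fill in Vulkan target attributes (e.g. supports_int8) by querying local device 0, which the capability checks in the updated test below rely on. A small sketch:

    import tvm

    # Attributes now reflect the actual driver/device rather than defaults.
    target = tvm.target.Target("vulkan -from_device=0")
    print(target.attrs.get("supports_int8", False))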
 
diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py
index ef2c5c1..9b843ae 100644
--- a/python/tvm/topi/__init__.py
+++ b/python/tvm/topi/__init__.py
@@ -49,6 +49,7 @@ from . import generic
 from . import nn
 from . import x86
 from . import cuda
+from . import gpu
 from . import arm_cpu
 from . import mali
 from . import bifrost
diff --git a/python/tvm/topi/cuda/dense.py b/python/tvm/topi/cuda/dense.py
index 8adc38b..0f410ae 100644
--- a/python/tvm/topi/cuda/dense.py
+++ b/python/tvm/topi/cuda/dense.py
@@ -19,10 +19,8 @@
 import logging
 from tvm import te
 import tvm.autotvm as autotvm
-from tvm.autotvm.task.space import SplitEntity
 from tvm.contrib import cublas
 from .tensor_intrin import dp4a
-from .. import nn
 from .. import tag
 from .. import generic
 from ..utils import traverse_inline, get_const_tuple
@@ -57,195 +55,6 @@ def schedule_dense_cublas(_, outs):
     return generic.schedule_extern(outs)
 
 
-@autotvm.register_topi_compute("dense_small_batch.cuda")
-def dense_small_batch(cfg, data, weight, bias=None, out_dtype=None):
-    """Dense operator on CUDA"""
-    return nn.dense(data, weight, bias, out_dtype)
-
-
-@autotvm.register_topi_schedule("dense_small_batch.cuda")
-def schedule_dense_small_batch(cfg, outs):
-    """Schedule float32/64 dense with small batch size"""
-    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
-    s = te.create_schedule([x.op for x in outs])
-
-    def _callback(op):
-        if op.tag == "dense":
-            _schedule_dense_small_batch(cfg, s, op.output(0))
-
-    traverse_inline(s, outs[0].op, _callback)
-    return s
-
-
-def _schedule_dense_small_batch(cfg, s, C):
-    A, weights = C.op.input_tensors
-    _, in_dim_weights = get_const_tuple(weights.shape)
-    _, in_dim_A = get_const_tuple(A.shape)
-
-    if isinstance(in_dim_A, int):
-        in_dim = in_dim_A
-    elif isinstance(in_dim_weights, int):
-        in_dim = in_dim_weights
-    else:
-        in_dim = None
-
-    if in_dim is not None:
-        cfg.define_split("tile_k", in_dim, num_outputs=2)
-        if cfg.is_fallback:
-            cfg["tile_k"] = SplitEntity([-1, 64] if in_dim > 64 else [1, 64])
-        _, kf = cfg["tile_k"].apply(s, C, C.op.reduce_axis[0])
-    else:
-        tile_k = 64
-        _, kf = s[C].split(C.op.reduce_axis[0], tile_k)
-
-    CF = s.rfactor(C, kf)
-
-    if C.op in s.outputs:
-        Out = C
-    else:
-        Out = s.outputs[0].output(0)
-        s[C].compute_at(s[Out], s[Out].op.axis[1])
-    s[Out].bind(s[Out].op.axis[0], te.thread_axis("blockIdx.y"))
-    s[Out].bind(s[Out].op.axis[1], te.thread_axis("blockIdx.x"))
-
-    tx = s[C].op.reduce_axis[0]
-    thread_x = te.thread_axis("threadIdx.x")
-    s[C].bind(tx, thread_x)
-    s[CF].compute_at(s[C], tx)
-    s[C].set_store_predicate(thread_x.var.equal(0))
-    s[Out].set_store_predicate(thread_x.var.equal(0))
-
-
-@autotvm.register_topi_compute("dense_large_batch.cuda")
-def dense_large_batch(cfg, data, weight, bias=None, out_dtype=None):
-    """Dense operator on CUDA"""
-    return nn.dense(data, weight, bias, out_dtype)
-
-
-@autotvm.register_topi_schedule("dense_large_batch.cuda")
-def schedule_dense_large_batch(cfg, outs):
-    """Schedule float32/64 dense with large batch size"""
-    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
-    s = te.create_schedule([x.op for x in outs])
-
-    def _callback(op):
-        if op.tag == "dense":
-            _schedule_dense_large_batch(cfg, s, op.output(0))
-
-    traverse_inline(s, outs[0].op, _callback)
-    return s
-
-
-def _schedule_dense_large_batch(cfg, s, C):
-    """Schedule float32/64 dense with large batch size"""
-    A, B = C.op.input_tensors
-    batch, in_dim = get_const_tuple(A.shape)
-    out_dim, _ = get_const_tuple(B.shape)
-    k = C.op.reduce_axis[0]
-
-    # create tuning space
-    try:
-        block_cand = [64, 128]
-        vthread_cand = [2 ** x for x in range(1, 7)]
-        n_thread_cand = [2 ** x for x in range(3, 7)]
-        cfg.define_split(
-            "tile_x",
-            batch,
-            num_outputs=4,
-            filter=lambda x: (
-                x.size[1] in vthread_cand
-                and x.size[2] in n_thread_cand
-                and (x.size[1] * x.size[2] * x.size[3]) in block_cand
-            ),
-        )
-        cfg.define_split(
-            "tile_y",
-            out_dim,
-            num_outputs=4,
-            filter=lambda x: (
-                x.size[1] in vthread_cand
-                and x.size[2] in n_thread_cand
-                and (x.size[1] * x.size[2] * x.size[3]) in block_cand
-            ),
-        )
-        cfg.define_split("tile_k", in_dim, num_outputs=3, filter=lambda x: x.size[0] > 2)
-    except IndexError:
-        # Index error happens when no entities left after filtering, which was designed
-        # to prune tuning space for better search efficiency.
-        logger.debug("Tuning space was created without pruning due to unfit shapes")
-        cfg.define_split("tile_x", batch, num_outputs=4)
-        cfg.define_split("tile_y", out_dim, num_outputs=4)
-        cfg.define_split("tile_k", in_dim, num_outputs=3)
-
-    if cfg.is_fallback:
-        if batch > 1:
-            cfg["tile_x"] = SplitEntity([-1, 2, 16, 2])
-        else:
-            cfg["tile_x"] = SplitEntity([1, 1, 1, 1])
-        if out_dim > 1:
-            cfg["tile_y"] = SplitEntity([-1, 2, 16, 2])
-        else:
-            cfg["tile_y"] = SplitEntity([1, 1, 1, 1])
-        if in_dim > 8:
-            cfg["tile_k"] = SplitEntity([-1, 8, 1])
-        else:
-            cfg["tile_k"] = SplitEntity([-1, 1, 1])
-
-    # Explicit memory access
-    AA = s.cache_read(A, "shared", [C])
-    BB = s.cache_read(B, "shared", [C])
-    AL = s.cache_read(AA, "local", [C])
-    BL = s.cache_read(BB, "local", [C])
-    CC = s.cache_write(C, "local")
-
-    # Deal with op fusion
-    if C.op not in s.outputs:
-        s[C].compute_inline()
-        C = s.outputs[0].output(0)
-
-    # Split and reorder computation
-    bx, txz, tx, xi = cfg["tile_x"].apply(s, C, C.op.axis[0])
-    by, tyz, ty, yi = cfg["tile_y"].apply(s, C, C.op.axis[1])
-    s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi)
-    s[CC].compute_at(s[C], tx)
-
-    # Binding
-    s[C].bind(by, te.thread_axis("blockIdx.y"))
-    s[C].bind(bx, te.thread_axis("blockIdx.x"))
-    s[C].bind(tyz, te.thread_axis("vthread"))
-    s[C].bind(txz, te.thread_axis("vthread"))
-    s[C].bind(ty, te.thread_axis("threadIdx.y"))
-    s[C].bind(tx, te.thread_axis("threadIdx.x"))
-
-    # Split reduction
-    yo, xo = CC.op.axis
-    ko, kt, ki = cfg["tile_k"].apply(s, CC, k)
-    s[CC].reorder(ko, kt, ki, yo, xo)
-    s[AA].compute_at(s[CC], ko)
-    s[BB].compute_at(s[CC], ko)
-    s[CC].unroll(kt)
-    s[AL].compute_at(s[CC], kt)
-    s[BL].compute_at(s[CC], kt)
-
-    # Schedule for A's shared memory load
-    num_thread_x = cfg["tile_x"].size[2]
-    ty, _ = s[AA].split(s[AA].op.axis[0], nparts=num_thread_x)
-    _, xi = s[AA].split(s[AA].op.axis[1], factor=num_thread_x * 4)
-    tx, xi = s[AA].split(xi, nparts=num_thread_x)
-    s[AA].bind(ty, te.thread_axis("threadIdx.y"))
-    s[AA].bind(tx, te.thread_axis("threadIdx.x"))
-    s[AA].double_buffer()
-
-    # Schedule for B' shared memory load
-    num_thread_y = cfg["tile_y"].size[2]
-    ty, _ = s[BB].split(s[BB].op.axis[0], nparts=num_thread_y)
-    _, xi = s[BB].split(s[BB].op.axis[1], factor=num_thread_y * 4)
-    tx, xi = s[BB].split(xi, nparts=num_thread_y)
-    s[BB].bind(ty, te.thread_axis("threadIdx.y"))
-    s[BB].bind(tx, te.thread_axis("threadIdx.x"))
-    s[BB].double_buffer()
-
-
 @autotvm.register_topi_compute("dense_int8.cuda")
 def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
     """Dense operator for int8 on CUDA"""
diff --git a/python/tvm/topi/gpu/__init__.py b/python/tvm/topi/gpu/__init__.py
new file mode 100644
index 0000000..6d9fd39
--- /dev/null
+++ b/python/tvm/topi/gpu/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=redefined-builtin, wildcard-import
+"""GPU specific declaration and schedules."""
+from .dense import *
diff --git a/python/tvm/topi/cuda/dense.py b/python/tvm/topi/gpu/dense.py
similarity index 57%
copy from python/tvm/topi/cuda/dense.py
copy to python/tvm/topi/gpu/dense.py
index 8adc38b..806aa9f 100644
--- a/python/tvm/topi/cuda/dense.py
+++ b/python/tvm/topi/gpu/dense.py
@@ -14,56 +14,28 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 # pylint: disable=invalid-name, unused-argument
 """Schedule for dense operator"""
+
 import logging
-from tvm import te
-import tvm.autotvm as autotvm
+
+from tvm import autotvm, te
 from tvm.autotvm.task.space import SplitEntity
-from tvm.contrib import cublas
-from .tensor_intrin import dp4a
+
 from .. import nn
-from .. import tag
-from .. import generic
 from ..utils import traverse_inline, get_const_tuple
 
 logger = logging.getLogger("topi")
 
 
-@autotvm.register_topi_compute("dense_cublas.cuda")
-def dense_cublas(cfg, data, weight, bias=None, out_dtype=None):
-    """Dense operator on CUDA with CUBLAS"""
-    assert len(data.shape) == 2 and len(weight.shape) == 2, "only support 2-dim dense"
-    if bias is not None:
-        assert len(bias.shape) == 1
-    if out_dtype is None:
-        out_dtype = data.dtype
-    assert out_dtype == data.dtype, "Mixed precision not supported."
-    batch, in_dim = get_const_tuple(data.shape)
-    out_dim, _ = get_const_tuple(weight.shape)
-    matmul = cublas.matmul(data, weight, False, True)
-    if all(isinstance(d, int) for d in [batch, in_dim, out_dim]):
-        cfg.add_flop(batch * in_dim * out_dim * 2)
-    if bias is not None:
-        matmul = te.compute(
-            (batch, out_dim), lambda i, j: matmul[i, j] + bias[j], tag=tag.BROADCAST
-        )
-    return matmul
-
-
-@autotvm.register_topi_schedule("dense_cublas.cuda")
-def schedule_dense_cublas(_, outs):
-    """Schedule dense operator using CUBLAS"""
-    return generic.schedule_extern(outs)
-
-
-@autotvm.register_topi_compute("dense_small_batch.cuda")
+@autotvm.register_topi_compute("dense_small_batch.gpu")
 def dense_small_batch(cfg, data, weight, bias=None, out_dtype=None):
-    """Dense operator on CUDA"""
+    """Dense operator on GPU"""
     return nn.dense(data, weight, bias, out_dtype)
 
 
-@autotvm.register_topi_schedule("dense_small_batch.cuda")
+@autotvm.register_topi_schedule("dense_small_batch.gpu")
 def schedule_dense_small_batch(cfg, outs):
     """Schedule float32/64 dense with small batch size"""
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
@@ -116,13 +88,13 @@ def _schedule_dense_small_batch(cfg, s, C):
     s[Out].set_store_predicate(thread_x.var.equal(0))
 
 
-@autotvm.register_topi_compute("dense_large_batch.cuda")
+@autotvm.register_topi_compute("dense_large_batch.gpu")
 def dense_large_batch(cfg, data, weight, bias=None, out_dtype=None):
-    """Dense operator on CUDA"""
+    """Dense operator on GPU"""
     return nn.dense(data, weight, bias, out_dtype)
 
 
-@autotvm.register_topi_schedule("dense_large_batch.cuda")
+@autotvm.register_topi_schedule("dense_large_batch.gpu")
 def schedule_dense_large_batch(cfg, outs):
     """Schedule float32/64 dense with large batch size"""
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
@@ -244,122 +216,3 @@ def _schedule_dense_large_batch(cfg, s, C):
     s[BB].bind(ty, te.thread_axis("threadIdx.y"))
     s[BB].bind(tx, te.thread_axis("threadIdx.x"))
     s[BB].double_buffer()
-
-
-@autotvm.register_topi_compute("dense_int8.cuda")
-def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
-    """Dense operator for int8 on CUDA"""
-    if out_dtype is None:
-        out_dtype = data.dtype
-
-    batch, in_dim = get_const_tuple(data.shape)
-    out_dim, _ = get_const_tuple(weight.shape)
-    k = te.reduce_axis((0, in_dim), name="k")
-
-    matmul = te.compute(
-        (batch, out_dim),
-        lambda i, j: te.sum(
-            data[i, k].astype(out_dtype) * weight[j, k].astype(out_dtype), axis=[k]
-        ),
-        tag="dense_int8",
-    )
-
-    cfg.add_flop(batch * in_dim * out_dim * 2)
-
-    if bias is not None:
-        matmul = te.compute(
-            (batch, out_dim),
-            lambda i, j: matmul[i, j] + bias[j].astype(out_dtype),
-            tag=tag.BROADCAST,
-        )
-        cfg.add_flop(batch * out_dim)
-
-    return matmul
-
-
-@autotvm.register_topi_schedule("dense_int8.cuda")
-def schedule_dense_int8(cfg, outs):
-    """Dense schedule for int8 on CUDA"""
-    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
-    s = te.create_schedule([x.op for x in outs])
-
-    def _callback(op):
-        if "dense_int8" in op.tag:
-            _schedule_dense_int8(cfg, s, op.output(0))
-
-    traverse_inline(s, outs[0].op, _callback)
-    return s
-
-
-_dp4a = dp4a("shared", "shared", "local")
-
-
-def _schedule_dense_int8(cfg, s, output):
-    data, weight = s[output].op.input_tensors
-
-    batch, in_dim = get_const_tuple(data.shape)
-    out_dim, _ = get_const_tuple(weight.shape)
-
-    in_dim_factor = 4
-    assert in_dim % in_dim_factor == 0, "Input dimension must divide {}".format(in_dim_factor)
-    if in_dim % 16 == 0:
-        in_dim_factor = 16
-
-    # create tuning space
-    cfg.define_split("tile_y", batch, num_outputs=4)
-    cfg.define_split("tile_x", out_dim, num_outputs=4)
-    cfg.define_split("tile_k", in_dim // in_dim_factor, num_outputs=2)
-    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
-
-    # create cache stage
-    AA = s.cache_read(data, "shared", [output])
-    WW = s.cache_read(weight, "shared", [output])
-    CC = s.cache_write(output, "local")
-
-    # handle bias
-    if output.op not in s.outputs:
-        s[output].compute_inline()
-        output = s.outputs[0].output(0)
-
-    n, x = s[output].op.axis
-
-    # this is the scope to attach global config inside this kernel
-    kernel_scope, n = s[output].split(n, nparts=1)
-
-    ko = CC.op.reduce_axis[0]
-    ko, ki = s[CC].split(ko, factor=4)
-    ko, kt = cfg["tile_k"].apply(s, CC, ko)
-    s[CC].tensorize(ki, _dp4a)
-    by, vy, ty, yi = cfg["tile_y"].apply(s, output, n)
-    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
-
-    s[output].reorder(by, bx, vy, vx, ty, tx, yi, xi)
-    s[output].bind(by, te.thread_axis("blockIdx.y"))
-    s[output].bind(bx, te.thread_axis("blockIdx.x"))
-    s[output].bind(vy, te.thread_axis("vthread"))
-    s[output].bind(vx, te.thread_axis("vthread"))
-    s[output].bind(ty, te.thread_axis("threadIdx.y"))
-    s[output].bind(tx, te.thread_axis("threadIdx.x"))
-    n_ty = cfg["tile_y"].size[2]
-    n_tx = cfg["tile_x"].size[2]
-
-    s[CC].compute_at(s[output], tx)
-    yo, xo = CC.op.axis[:2]
-    s[CC].reorder(ko, kt, yo, xo, ki)
-
-    for load in [AA, WW]:
-        s[load].compute_at(s[CC], ko)
-
-        outer, inner = s[load].split(s[load].op.axis[-1], factor=in_dim_factor)
-        s[load].vectorize(inner)
-        fused = s[load].op.axis[:-1] + [outer]
-        fused = s[load].fuse(*fused)
-
-        fused, tx = s[load].split(fused, factor=n_tx)
-        fused, ty = s[load].split(fused, factor=n_ty)
-        s[load].bind(tx, te.thread_axis("threadIdx.x"))
-        s[load].bind(ty, te.thread_axis("threadIdx.y"))
-
-    s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
-    s[output].pragma(kernel_scope, "unroll_explicit", False)
-    return s
diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index 2628406..5d52bee 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -63,6 +63,7 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::
         }
         spirv::Value arg_value = builder_->BufferArgument(builder_->GetSType(value_storage_type),
                                                           descriptor_set, num_buffer);
+        builder_->SetName(arg_value, arg->name_hint);
         storage_info_[arg.get()].UpdateContentType(value_storage_type);
         var_map_[arg.get()] = arg_value;
       } else {
@@ -144,15 +145,21 @@ spirv::Value CodeGenSPIRV::CreateStorageSync(const CallNode* op) {
   uint32_t vulkan_api_version = spirv_support_.vulkan_api_version;
 
   int64_t sync_scope;
-  int64_t memory_semantics;
+  int64_t memory_semantics = spv::MemorySemanticsSequentiallyConsistentMask;
   if ((sync == "warp") && (vulkan_api_version >= VK_API_VERSION_1_1)) {
+    // Synchronize control at the Subgroup level, but memory at the
+    // Workgroup level.  This is because different invocations in a
+    // subgroup may have each modified memory that exists at the
+    // workgroup scope.  This should be changed if/when tir exposes
+    // more information as to which memory access needs to be
+    // synchronized.
     sync_scope = spv::ScopeSubgroup;
-    memory_semantics =
-        spv::MemorySemanticsSequentiallyConsistentMask | spv::MemorySemanticsSubgroupMemoryMask;
+    memory_semantics |=
+        spv::MemorySemanticsSubgroupMemoryMask | spv::MemorySemanticsWorkgroupMemoryMask;
+
   } else if ((sync == "shared") || (sync == "warp")) {
     sync_scope = spv::ScopeWorkgroup;
-    memory_semantics =
-        spv::MemorySemanticsSequentiallyConsistentMask | spv::MemorySemanticsWorkgroupMemoryMask;
+    memory_semantics |= spv::MemorySemanticsWorkgroupMemoryMask;
   } else {
     LOG(FATAL) << "Do not support sync " << sync;
   }
@@ -161,6 +168,7 @@ spirv::Value CodeGenSPIRV::CreateStorageSync(const CallNode* op) {
   builder_->MakeInst(spv::OpControlBarrier, builder_->IntImm(type_int, sync_scope),
                      builder_->IntImm(type_int, sync_scope),
                      builder_->IntImm(type_int, memory_semantics));
+
   return value;
 }
 
@@ -642,14 +650,16 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) {
   if (info.scope.rank == runtime::StorageRank::kLocal) {
     buf =
         builder_->Allocate(etype, static_cast<uint32_t>(constant_size), spv::StorageClassFunction);
-  } else {
-    // shared memory
-    ICHECK(info.scope.rank == runtime::StorageRank::kShared)
-        << "Can only allocate shared or local memory inside kernel";
+  } else if (info.scope.rank == runtime::StorageRank::kShared) {
     // Shared memory
     buf =
         builder_->Allocate(etype, static_cast<uint32_t>(constant_size), spv::StorageClassWorkgroup);
+  } else {
+    LOG(FATAL) << "Can only allocate shared or local memory inside kernel";
   }
+
+  builder_->SetName(buf, op->buffer_var->name_hint);
+
   ICHECK(!info.content_fixed);
   info.UpdateContentType(op->dtype);
   ICHECK(!var_map_.count(op->buffer_var.get()));
diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index 9598f07..9e53681 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -388,7 +388,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     size_t size = shared_bufs.size();
     PrimExpr buf_index = BufIndex(reduce_index, group_index, reduce_extent);
     // make reduction
-    auto freduce = [&](int offset) {
+    auto fload = [&](int offset) {
       Array<PrimExpr> a, b;
       for (size_t i = 0; i < size; ++i) {
         b.push_back(Load(types[i], shared_bufs[i],
@@ -397,12 +397,19 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
         a.push_back(Load(types[i], shared_bufs[i], buf_index, const_true()));
       }
       Array<PrimExpr> ret = (*combiner)(a, b);
+      return ret;
+    };
+    auto fstore = [&](const Array<PrimExpr>& ret) {
       std::vector<Stmt> stores(size);
       for (size_t i = 0; i < size; ++i) {
         stores[i] = Store(shared_bufs[i], ret[i], buf_index, const_true());
       }
       return SeqStmt::Flatten(stores);
     };
+    auto freduce = [&](int offset) {
+      auto ret = fload(offset);
+      return fstore(ret);
+    };
     // Step one, check for
     if (reduce_align > reduce_extent) {
       // reduction with the boundary condition
@@ -420,15 +427,47 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       seq.emplace_back(SyncThread("shared"));
     }
     // in warp synchronization.
-    std::vector<Stmt> in_warp_seq;
-    PrimExpr in_warp_cond = reduce_index < (reduce_align >> 1);
-    while (reduce_align > 1) {
-      reduce_align = reduce_align >> 1;
-      in_warp_seq.emplace_back(freduce(reduce_align));
-      in_warp_seq.emplace_back(SyncThread("warp"));
-    }
-    if (in_warp_seq.size() != 0) {
+    if (reduce_align > 1) {
+      PrimExpr in_warp_cond = reduce_index < (reduce_align >> 1);
+
+      std::vector<Stmt> in_warp_seq;
+
+      while (reduce_align > 1) {
+        reduce_align = reduce_align >> 1;
+
+        // freduce can read/write to the same memory location.  For
+        // example, with reduce_align of 4, threadIdx 3 reads from
+        // memory location 7 as threadIdx 7 is writing to it.
+        // Therefore, we need to separate out the load from the store
+        // with a memory barrier in-between.  This isn't necessary for
+        // the earlier normal synchronization, because those are each
+        // protected by an if-statement.  The if-statement is avoided
+        // here to reduce thread divergence.
+        auto loads = fload(reduce_align);
+
+        Array<Var> in_warp_local_vars;
+        for (auto expr : loads) {
+          Var var(
+              "w_" + std::to_string(reduce_align) + "_" + std::to_string(in_warp_local_vars.size()),
+              expr->dtype);
+          in_warp_local_vars.push_back(var);
+        }
+
+        std::vector<Stmt> in_let_statement;
+        in_let_statement.emplace_back(SyncThread("warp"));
+        in_let_statement.emplace_back(
+            fstore({in_warp_local_vars.begin(), in_warp_local_vars.end()}));
+        in_let_statement.emplace_back(SyncThread("warp"));
+
+        Stmt body = SeqStmt::Flatten(in_let_statement);
+        for (size_t i = 0; i < size; i++) {
+          body = LetStmt(in_warp_local_vars[i], loads[i], body);
+        }
+        in_warp_seq.push_back(body);
+      }
+
       Stmt warp_body = SeqStmt::Flatten(in_warp_seq);
+
       seq.emplace_back(IfThenElse(in_warp_cond, warp_body));
       seq.emplace_back(SyncThread("shared"));
     }
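
For context, this pass handles cross-thread reductions such as the one produced by the dense_small_batch schedule above (reduction axis bound to threadIdx.x). A minimal standalone reproducer of that lowering path, adapted from the standard TVM cross-thread reduction example (illustrative only):

    import tvm
    from tvm import te

    n = te.var("n")
    A = te.placeholder((n, 128), name="A")
    k = te.reduce_axis((0, 128), name="k")
    B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")

    s = te.create_schedule(B.op)
    ko, ki = s[B].split(B.op.reduce_axis[0], factor=32)
    BF = s.rfactor(B, ki)
    tx = te.thread_axis("threadIdx.x")
    s[B].bind(s[B].op.axis[0], te.thread_axis("blockIdx.x"))
    s[B].bind(s[B].op.reduce_axis[0], tx)
    s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
    s[B].set_store_predicate(tx.var.equal(0))

    # The lowered TIR contains a tvm_thread_allreduce call; when building for
    # a GPU target, LowerThreadAllreduce expands it into the shared-memory
    # tree reduction whose in-warp steps now use the load/let/barrier/store
    # pattern shown in the diff above.
    print(tvm.lower(s, [A, B], simple_mode=True))
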
diff --git a/tests/python/relay/test_autotvm_task_extraction.py b/tests/python/relay/test_autotvm_task_extraction.py
index b3f1868..83480a0 100644
--- a/tests/python/relay/test_autotvm_task_extraction.py
+++ b/tests/python/relay/test_autotvm_task_extraction.py
@@ -115,7 +115,7 @@ def test_task_extraction_for_dense_int8_cuda():
 
     mod, params = get_net(1, 16, 32, "float32", "float32")
     tasks = autotvm.task.extract_from_program(mod, target=target, params=params, ops=(dense,))
-    assert len(tasks) == 1 and tasks[0].name == "dense_small_batch.cuda"
+    assert len(tasks) == 1 and tasks[0].name == "dense_small_batch.gpu"
 
     mod, params = get_net(1, 16, 32, "int8", "int32")
     tasks = autotvm.task.extract_from_program(mod, target=target, params=params, ops=(dense,))
diff --git a/tests/python/topi/python/test_topi_dense.py b/tests/python/topi/python/test_topi_dense.py
index 07301fa..235a094 100644
--- a/tests/python/topi/python/test_topi_dense.py
+++ b/tests/python/topi/python/test_topi_dense.py
@@ -15,26 +15,37 @@
 # specific language governing permissions and limitations
 # under the License.
 """Test code for dense operator"""
+import contextlib
 import numpy as np
+import pytest
+import sys
+
 import tvm
-from tvm import te
-from tvm import topi
+import tvm.testing
 import tvm.topi.testing
+from tvm import te, topi
 from tvm.topi.utils import get_const_tuple
-from tvm.contrib.pickle_memoize import memoize
 
 from common import Int8Fallback
-import tvm.testing
 
-_dense_implement = {
+use_bias = tvm.testing.parameter(True, False)
+batch_size = tvm.testing.parameter(1, 2, 128)
+in_dim, out_dim = tvm.testing.parameters((1024, 1000))
+in_dtype, out_dtype = tvm.testing.parameters(
+    ("float32", "float32"),
+    ("int8", "int32"),
+)
+
+
+_dense_implementations = {
     "generic": [(topi.nn.dense, topi.generic.schedule_dense)],
     "cpu": [
         (topi.x86.dense_nopack, topi.x86.schedule_dense_nopack),
         (topi.x86.dense_pack, topi.x86.schedule_dense_pack),
     ],
     "gpu": [
-        (topi.cuda.dense_small_batch, topi.cuda.schedule_dense_small_batch),
-        (topi.cuda.dense_large_batch, topi.cuda.schedule_dense_large_batch),
+        (topi.gpu.dense_small_batch, topi.gpu.schedule_dense_small_batch),
+        (topi.gpu.dense_large_batch, topi.gpu.schedule_dense_large_batch),
     ],
     "mali": [(topi.mali.dense, topi.mali.schedule_dense)],
     "bifrost": [(topi.bifrost.dense, topi.bifrost.schedule_dense)],
@@ -43,108 +54,115 @@ _dense_implement = {
 }
 
 
-def verify_dense(batch, in_dim, out_dim, use_bias=True):
-    A = te.placeholder((batch, in_dim), name="A")
-    B = te.placeholder((out_dim, in_dim), name="B")
-    C = te.placeholder((out_dim,), name="C")
-    dtype = A.dtype
-
-    # use memoize to pickle the test data for next time use
-    @memoize("topi.tests.test_topi_dense")
-    def get_ref_data():
-        a_np = np.random.uniform(size=(batch, in_dim)).astype(dtype)
-        b_np = np.random.uniform(size=(out_dim, in_dim)).astype(dtype)
-        c_np = np.random.uniform(size=(out_dim,)).astype(dtype)
-        if use_bias:
-            d_np = np.maximum(np.dot(a_np, b_np.T) + c_np, 0.0)
-        else:
-            d_np = np.maximum(np.dot(a_np, b_np.T), 0.0)
-        return (a_np, b_np, c_np, d_np)
-
-    # get the test data
-    a_np, b_np, c_np, d_np = get_ref_data()
-
-    def check_device(device, dev):
-        print("Running on target: %s" % device)
-        for fcompute, fschedule in tvm.topi.testing.dispatch(device, _dense_implement):
-            with tvm.target.Target(device):
-                D = fcompute(A, B, C if use_bias else None)
-                D = topi.nn.relu(D)
-                s = fschedule([D])
-            a = tvm.nd.array(a_np, dev)
-            b = tvm.nd.array(b_np, dev)
-            c = tvm.nd.array(c_np, dev)
-            d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), dev)
-            f = tvm.build(s, [A, B, C, D], device, name="dense")
-            f(a, b, c, d)
-            tvm.testing.assert_allclose(d.numpy(), d_np, rtol=1e-5)
-
-    for device, dev in tvm.testing.enabled_targets():
-        check_device(device, dev)
-
-
-def verify_dense_int8(batch, in_dim, out_dim, use_bias=True):
-    dtype = "int8"
-    out_dtype = "int32"
-    A = te.placeholder((batch, in_dim), name="A", dtype=dtype)
-    B = te.placeholder((out_dim, in_dim), name="B", dtype=dtype)
+@tvm.testing.fixture(cache_return_value=True)
+def dense_ref_data(batch_size, in_dim, out_dim, use_bias, in_dtype, out_dtype):
+    if "float" in in_dtype:
+        a_np = np.random.uniform(size=(batch_size, in_dim)).astype(in_dtype)
+        b_np = np.random.uniform(size=(out_dim, in_dim)).astype(in_dtype)
+        c_np = np.random.uniform(size=(out_dim,)).astype(out_dtype)
+    elif in_dtype == "int8":
+        a_np = np.random.randint(low=-128, high=127, size=(batch_size, in_dim)).astype(in_dtype)
+        b_np = np.random.randint(low=-128, high=127, size=(out_dim, in_dim)).astype(in_dtype)
+        c_np = np.random.randint(low=-128, high=127, size=(out_dim,)).astype(out_dtype)
+    else:
+        raise ValueError("No method to generate test data for data type '{}'".format(in_dtype))
+
+    matmul = np.dot(a_np.astype(out_dtype), b_np.T.astype(out_dtype))
+
+    if use_bias:
+        matmul += c_np
+
+    d_np = np.maximum(matmul, 0)
+    return (a_np, b_np, c_np, d_np)
+
+
+def test_dense(
+    target,
+    dev,
+    batch_size,
+    in_dim,
+    out_dim,
+    use_bias,
+    dense_ref_data,
+    in_dtype,
+    out_dtype,
+    implementations=None,
+):
+    target = tvm.target.Target(target)
+
+    if (
+        in_dtype == "int8"
+        and target.kind.name == "cuda"
+        and not tvm.contrib.nvcc.have_int8(dev.compute_version)
+    ):
+        pytest.xfail("CUDA int8 intrinsics not available")
+
+    if (
+        in_dtype == "int8"
+        and target.kind.name == "vulkan"
+        and not target.attrs.get("supports_int8", False)
+    ):
+        pytest.xfail("Vulkan int8 driver support not available")
+
+    if (
+        target.kind.name not in ["llvm", "c"]
+        and len(set(target.keys) & set(_dense_implementations)) == 0
+    ):
+        pytest.xfail("No implementation for tvm.topi.testing.dispatch to find")
+
+    A = te.placeholder((batch_size, in_dim), name="A", dtype=in_dtype)
+    B = te.placeholder((out_dim, in_dim), name="B", dtype=in_dtype)
     C = te.placeholder((out_dim,), name="C", dtype=out_dtype)
 
-    # use memoize to pickle the test data for next time use
-    @memoize("topi.tests.test_topi_dense_int8")
-    def get_ref_data():
-        a_np = np.random.randint(low=-128, high=127, size=(batch, in_dim)).astype(dtype)
-        b_np = np.random.randint(low=-128, high=127, size=(out_dim, in_dim)).astype(dtype)
-        c_np = np.random.randint(low=-128, high=127, size=(out_dim,)).astype(out_dtype)
-        d_np = np.dot(a_np.astype(out_dtype), b_np.T.astype(out_dtype))
-        if use_bias:
-            d_np += c_np
-        d_np = np.maximum(d_np, 0.0)
-        return (a_np, b_np, c_np, d_np)
-
-    # get the test data
-    a_np, b_np, c_np, d_np = get_ref_data()
-
-    def check_device(device):
-        dev = tvm.device(device, 0)
-        if device == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version):
-            print("Skip because int8 intrinsics are not available")
-            return
-
-        print("Running on target: %s" % device)
-        with tvm.target.Target(device):
-            D = topi.cuda.dense_int8(A, B, C if use_bias else None, out_dtype)
+    a_np, b_np, c_np, d_np = dense_ref_data
+
+    if implementations is None:
+        implementations = tvm.topi.testing.dispatch(target, _dense_implementations)
+
+    for fcompute, fschedule in implementations:
+        with tvm.target.Target(target):
+            D = fcompute(A, B, C if use_bias else None, out_dtype)
             D = topi.nn.relu(D)
-            s = topi.cuda.schedule_dense_int8([D])
+            s = fschedule([D])
+
         a = tvm.nd.array(a_np, dev)
         b = tvm.nd.array(b_np, dev)
         c = tvm.nd.array(c_np, dev)
         d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=out_dtype), dev)
-        f = tvm.build(s, [A, B, C, D], device, name="dense")
+        f = tvm.build(s, [A, B, C, D], target, name="dense")
         f(a, b, c, d)
         tvm.testing.assert_allclose(d.numpy(), d_np, rtol=1e-5)
 
-    for device in ["cuda"]:
-        check_device(device)
-
-
-@tvm.testing.uses_gpu
-def test_dense():
-    verify_dense(1, 1024, 1000, use_bias=True)
-    verify_dense(1, 1024, 1000, use_bias=False)
-    verify_dense(2, 1024, 1000, use_bias=True)
-    verify_dense(128, 1024, 1000, use_bias=False)
-    verify_dense(128, 1024, 1000, use_bias=True)
-
 
-@tvm.testing.requires_cuda
-@tvm.testing.requires_gpu
-def test_dense_int8():
+@pytest.mark.parametrize("target,in_dtype,out_dtype", [("cuda", "int8", "int32")])
+def test_dense_cuda_int8(
+    target,
+    dev,
+    batch_size,
+    in_dim,
+    out_dim,
+    use_bias,
+    dense_ref_data,
+    in_dtype,
+    out_dtype,
+):
+    implementations = [
+        (topi.cuda.dense_int8, topi.cuda.schedule_dense_int8),
+    ]
     with Int8Fallback():
-        verify_dense_int8(2, 1024, 1000, use_bias=True)
-        verify_dense_int8(2, 1024, 1000, use_bias=False)
+        test_dense(
+            target,
+            dev,
+            batch_size,
+            in_dim,
+            out_dim,
+            use_bias,
+            dense_ref_data,
+            in_dtype,
+            out_dtype,
+            implementations=implementations,
+        )
 
 
 if __name__ == "__main__":
-    test_dense()
-    test_dense_int8()
+    sys.exit(pytest.main(sys.argv))
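
With the module delegating to pytest.main, subsets of the parametrized cases can be selected programmatically or from the command line; a small usage sketch (the -k expression and the TVM_TEST_TARGETS environment variable read by tvm.testing to override the default target list are examples, not part of the diff):

    import sys
    import pytest

    # Run only the float32 parametrizations of this file, e.g. after exporting
    # TVM_TEST_TARGETS="vulkan -from_device=0" to restrict the target list.
    sys.exit(pytest.main(["tests/python/topi/python/test_topi_dense.py", "-k", "float32"]))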