Posted to commits@tvm.apache.org by ma...@apache.org on 2022/07/28 09:03:33 UTC

[tvm] branch main updated: [MetaSchedule] Integration test for CUDA AutoTensorization (#12142)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ee319d9d23 [MetaSchedule] Integration test for CUDA AutoTensorization (#12142)
ee319d9d23 is described below

commit ee319d9d23c80091da9c4fb764b1e6d49d462714
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Thu Jul 28 02:03:27 2022 -0700

    [MetaSchedule] Integration test for CUDA AutoTensorization (#12142)
    
    * [MetaSchedule] Integration test for CUDA AutoTensorization
    
    * cleanup
    
    * fix
---
 python/tvm/meta_schedule/default_config.py         | 52 +++++++++++++++
 src/meta_schedule/schedule_rule/auto_bind.cc       |  3 +
 .../test_meta_schedule_auto_tensorize.py           | 74 ++++++++++++++++++++--
 3 files changed, 123 insertions(+), 6 deletions(-)

diff --git a/python/tvm/meta_schedule/default_config.py b/python/tvm/meta_schedule/default_config.py
index e99dd1383a..dc021e1731 100644
--- a/python/tvm/meta_schedule/default_config.py
+++ b/python/tvm/meta_schedule/default_config.py
@@ -349,3 +349,55 @@ class _DefaultCUDA:
             M.MutateUnroll(): 0.08,
             M.MutateThreadBinding(): 0.02,
         }
+
+
+class _DefaultCUDATensorCore:
+    """Default tuning configuration for CUDA TensorCore."""
+
+    @staticmethod
+    def schedule_rules():
+        from tvm.meta_schedule import schedule_rule as M
+        from tvm.tir.tensor_intrin import get_wmma_intrin_group
+
+        return [
+            M.MultiLevelTilingTensorCore(
+                intrin_groups=[
+                    get_wmma_intrin_group(
+                        store_scope="shared",
+                        in_dtype="float16",
+                        out_dtype="float16",
+                        trans_b=trans_b,
+                    )
+                    for trans_b in [False, True]
+                ],
+                structure="SSSRRSRS",
+                tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],
+                max_innermost_factor=4,
+                vector_load_lens=[1, 2, 3, 4],
+                reuse_read=M.ReuseType(req="must", levels=[4], scope="shared"),
+                reuse_write=M.ReuseType(
+                    req="must",
+                    levels=[2],
+                    scope="shared",
+                ),
+            ),
+            *_DefaultCUDA.schedule_rules(),
+        ]
+
+    @staticmethod
+    def postprocs() -> List[Postproc]:
+        from tvm.meta_schedule import postproc as M
+
+        return [
+            M.DisallowDynamicLoop(),
+            M.RewriteCooperativeFetch(),
+            M.RewriteUnboundBlock(),
+            M.RewriteParallelVectorizeUnroll(),
+            M.RewriteReductionBlock(),
+            M.RewriteTensorize(),
+            M.VerifyGPUCode(),
+        ]
+
+    @staticmethod
+    def mutator_probs() -> Dict[Mutator, float]:
+        return _DefaultCUDA.mutator_probs()
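
For reference, the WMMA intrin group requested above can also be inspected standalone. A minimal sketch, assuming a TVM build that includes this commit; the keyword arguments mirror the call in _DefaultCUDATensorCore.schedule_rules(), and the exact dictionary keys depend on the TVM version:

    from tvm.tir.tensor_intrin import get_wmma_intrin_group

    # Same arguments as in _DefaultCUDATensorCore.schedule_rules() above.
    group = get_wmma_intrin_group(
        store_scope="shared",
        in_dtype="float16",
        out_dtype="float16",
        trans_b=True,
    )
    # Expected: a dict mapping tensorization roles to registered intrin
    # names (e.g. "init", "load_a", "load_b", "compute", "store").
    print(group)
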
diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc
index a67432ebc5..ff4d26084e 100644
--- a/src/meta_schedule/schedule_rule/auto_bind.cc
+++ b/src/meta_schedule/schedule_rule/auto_bind.cc
@@ -34,6 +34,9 @@ void BindBlockThreadIdx(const tir::Schedule& sch, const tir::BlockRV& block_rv,
   if (block_sref->parent == nullptr) {
     return;
   }
+  if (tir::HasBeenMultiLevelTiled(block_sref)) {
+    return;
+  }
   Array<StmtSRef> loops = tir::GetLoops(block_sref);
   int n = loops.size();
   int i_block_idx = -1;
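
The early return above avoids double thread-binding: MultiLevelTilingTensorCore already binds its outer tiles via tile_binds ("blockIdx.y", "blockIdx.x", "threadIdx.y"), and the AutoBind rule from the shared _DefaultCUDA rule set runs over the same blocks afterwards. A minimal sketch of the rule on the Python side; the parameter values shown are illustrative, not taken from this diff:

    from tvm.meta_schedule import schedule_rule as M

    # AutoBind adds blockIdx/threadIdx bindings to blocks that lack them.
    # With the HasBeenMultiLevelTiled() check above, blocks already tiled
    # (and thread-bound) by MultiLevelTilingTensorCore are now skipped
    # instead of being bound a second time.
    auto_bind = M.AutoBind(
        max_threadblocks=256,
        thread_extents=[32, 64, 128, 256, 512, 1024],
    )
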
diff --git a/tests/python/integration/test_meta_schedule_auto_tensorize.py b/tests/python/integration/test_meta_schedule_auto_tensorize.py
index b855dc6fa0..b1525df10e 100644
--- a/tests/python/integration/test_meta_schedule_auto_tensorize.py
+++ b/tests/python/integration/test_meta_schedule_auto_tensorize.py
@@ -27,6 +27,7 @@ from tvm import meta_schedule as ms
 from tvm import relay
 from tvm.meta_schedule import ApplyHistoryBest, postproc, schedule_rule
 from tvm.meta_schedule.relay_integration import extract_task_from_relay
+from tvm.meta_schedule.testing import relay_workload
 from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base
 from tvm.meta_schedule.tune import tune_extracted_tasks
 from tvm.tir.tensor_intrin import AMDGPU_SDOT4_INTRIN, DP4A_INTRIN
@@ -337,10 +338,71 @@ def test_dp4a_bert_int8():
     # _test_bert_int8("rocm", sch_rules_for_sdot4, postprocs_for_dp4a)
 
 
+@tvm.testing.requires_gpu
+@pytest.mark.skip("Slow on CI")
+@pytest.mark.parametrize(
+    ["model_name", "input_shape"],
+    [("bert_base", (8, 128)), ("resnet_18", (16, 3, 224, 224)), ("resnet_50", (16, 3, 224, 224))],
+)
+def test_cuda_tensor_core(model_name, input_shape):
+    """Integration tests of auto tensorization with CUDA tensor core"""
+    target = tvm.target.Target("nvidia/geforce-rtx-3070")
+    dev = tvm.cuda()
+    if model_name.startswith("bert"):
+        data = tvm.nd.array(np.random.randint(0, 30521, size=input_shape), dev)  # embedding size
+    else:
+        data = tvm.nd.array(np.random.randn(*input_shape).astype("float32"), dev)
+
+    mod, params, (input_name, _, _) = relay_workload.get_network(model_name, input_shape)
+    seq = tvm.transform.Sequential(
+        [
+            relay.transform.ToMixedPrecision(),
+        ]
+    )
+
+    with tvm.transform.PassContext(opt_level=3):
+        mod = seq(mod)
+
+    def convert_layout(mod):
+        seq = tvm.transform.Sequential(
+            [relay.transform.ConvertLayout({"nn.conv2d": ["NHWC", "OHWI"]})]
+        )
+        with tvm.transform.PassContext(opt_level=3):
+            mod = seq(mod)
+        return mod
+
+    with tempfile.TemporaryDirectory() as work_dir:
+        with ms.Profiler() as profiler:
+            rt_mod1: tvm.runtime.Module = ms.tune_relay(
+                mod=convert_layout(mod),
+                params=params,
+                target=target,
+                config=ms.TuneConfig(
+                    num_trials_per_iter=32,
+                    max_trials_per_task=200,
+                    max_trials_global=3000,
+                ),
+                sch_rules=ms.default_config._DefaultCUDATensorCore.schedule_rules,
+                postprocs=ms.default_config._DefaultCUDATensorCore.postprocs,
+                work_dir=work_dir,
+            )
+        print(profiler.table())
+
+        # Compile without meta-scheduler for correctness check
+        with tvm.transform.PassContext(opt_level=0):
+            rt_mod2 = relay.build(mod, target=target, params=params)
+
+        def get_output(data, lib):
+            module = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
+            module.set_input(input_name, data)
+            module.run()
+            return module.get_output(0).numpy()
+
+        # Check correctness
+        actual_output = get_output(data, rt_mod1)
+        expected_output = get_output(data, rt_mod2)
+        assert np.allclose(actual_output, expected_output, rtol=1e-2, atol=2e-2)
+
+
 if __name__ == "__main__":
-    test_vnni_dense()
-    test_vnni_conv2d()
-    test_vnni_bert_int8()
-    test_dp4a_dense()
-    test_dp4a_conv2d()
-    test_dp4a_bert_int8()
+    tvm.testing.main()
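
The new test is skipped on CI ("Slow on CI"). A minimal sketch for running one parametrization locally, assuming the pytest.mark.skip line is removed or commented out first and a tensor-core-capable CUDA GPU is available:

    import pytest

    # Select only the bert_base parametrization of the new test; run from
    # the TVM repository root.
    pytest.main(
        [
            "tests/python/integration/test_meta_schedule_auto_tensorize.py",
            "-k", "test_cuda_tensor_core and bert_base",
            "-v",
        ]
    )
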