You are viewing a plain text version of this content. The link to the canonical version was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/11/11 00:08:58 UTC

[tvm] branch main updated: [MetaSchedule] Improve inlining and `VerifyGPUCode` for quantized model workload (#13334)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 93fdf83e8f [MetaSchedule] Improve inlining and `VerifyGPUCode` for quantized model workload (#13334)
93fdf83e8f is described below

commit 93fdf83e8f40b806ee5a8bd6625e0f4e431b459d
Author: masahi <ma...@gmail.com>
AuthorDate: Fri Nov 11 09:08:51 2022 +0900

    [MetaSchedule] Improve inlining and `VerifyGPUCode` for quantized model workload (#13334)
    
    * [MetaSchedule] Add a new schedule rule to inline all scalar constants
    
    * add doc
    
    * reorg
    
    * identify constant block by its structure, not by name
---
 include/tvm/meta_schedule/schedule_rule.h          |  10 ++
 python/tvm/meta_schedule/schedule_rule/__init__.py |   2 +-
 .../tvm/meta_schedule/schedule_rule/auto_inline.py |  17 +++
 src/meta_schedule/postproc/verify_gpu_code.cc      |   2 +
 src/meta_schedule/schedule_rule/auto_inline.cc     |  37 +++++++
 src/meta_schedule/schedule_rule/schedule_rule.cc   |   3 +
 src/tir/analysis/verify_gpu_code.cc                |  13 +++
 .../metaschedule_e2e/test_resnet50_int8.py         |   5 +-
 ...test_meta_schedule_schedule_rule_auto_inline.py | 115 +++++++++++++++++++++
 9 files changed, 201 insertions(+), 3 deletions(-)

diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index da8f1faa8e..70dec47e60 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -125,6 +125,16 @@ class ScheduleRule : public runtime::ObjectRef {
                                          bool require_injective,      //
                                          bool require_ordered,        //
                                          Optional<Array<String>> disallow_op);
+
+  /*!
+   * \brief Create a rule that inlines blocks producing a constant scalar, i.e. blocks with no
+   * reads and a single write to a 0-dimensional buffer. Such blocks are counted as an extra
+   * producer and therefore prevent ReverseComputeInline during AutoInline unless they are
+   * inlined first, so it is recommended to run InlineConstantScalars before AutoInline.
+   * \return The schedule rule created
+   */
+  TVM_DLL static ScheduleRule InlineConstantScalars();
+
   /*!
    * \brief Create a mega rule: multi-level tiling with data reuse
    * \param structure The tiling structure. Recommended:
diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py
index 5971ad53c4..d330fc7139 100644
--- a/python/tvm/meta_schedule/schedule_rule/__init__.py
+++ b/python/tvm/meta_schedule/schedule_rule/__init__.py
@@ -22,7 +22,7 @@ blocks in a schedule. See also PostOrderApply.
 from .add_rfactor import AddRFactor
 from .apply_custom_rule import ApplyCustomRule
 from .auto_bind import AutoBind
-from .auto_inline import AutoInline
+from .auto_inline import AutoInline, InlineConstantScalars
 from .cross_thread_reduction import CrossThreadReduction
 from .multi_level_tiling import (
     MultiLevelTiling,
diff --git a/python/tvm/meta_schedule/schedule_rule/auto_inline.py b/python/tvm/meta_schedule/schedule_rule/auto_inline.py
index 22206f3fcc..c84dbaf89b 100644
--- a/python/tvm/meta_schedule/schedule_rule/auto_inline.py
+++ b/python/tvm/meta_schedule/schedule_rule/auto_inline.py
@@ -65,3 +65,20 @@ class AutoInline(ScheduleRule):
             require_ordered,
             disallow_op,
         )
+
+
+@register_object("meta_schedule.InlineConstantScalars")
+class InlineConstantScalars(ScheduleRule):
+    """Inline blocks that produce a constant scalar.
+
+    A block qualifies when it reads nothing and writes a single 0-dimensional buffer, e.g.
+    ``compile_engine_const[()] = 59``. Such blocks are counted as an extra producer and get in
+    the way of ReverseComputeInline during AutoInline, so run this rule before AutoInline.
+    """
+
+    def __init__(self) -> None:
+        # The rule is stateless: the underlying C++ factory takes no arguments.
+        # Linters cannot resolve FFI-registered names, hence the suppressions below.
+        self.__init_handle_by_constructor__(
+            _ffi_api.ScheduleRuleInlineConstantScalars,  # type: ignore # pylint: disable=no-member
+        )
diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc
index 0828ee5384..ae6f3474bb 100644
--- a/src/meta_schedule/postproc/verify_gpu_code.cc
+++ b/src/meta_schedule/postproc/verify_gpu_code.cc
@@ -175,10 +175,12 @@ class VerifyGPUCodeNode : public PostprocNode {
           pass_list.push_back(tir::transform::InjectDoubleBuffer());
           pass_list.push_back(tir::transform::StorageRewrite());
           pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
+          pass_list.push_back(tir::transform::LowerIntrin());
           // Convert Function to IRModule
           transform::PassContext pass_ctx = transform::PassContext::Current();
           tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol",
                                      runtime::String(g_var->name_hint));
+          f = WithAttr(f, tvm::attr::kTarget, Target("cuda"));  // Required for LowerIntrin
           bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
           if (noalias) {
             f = WithAttr(std::move(f), "tir.noalias", Bool(true));
diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc
index dcdc83f95c..d2d48b9008 100644
--- a/src/meta_schedule/schedule_rule/auto_inline.cc
+++ b/src/meta_schedule/schedule_rule/auto_inline.cc
@@ -189,5 +189,42 @@ TVM_REGISTER_NODE_TYPE(AutoInlineNode);
 TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoInline")
     .set_body_typed(ScheduleRule::AutoInline);
 
+/*! \brief Inline blocks that produce a constant scalar, when ComputeInline allows it. */
+class InlineConstantScalarsNode : public ScheduleRuleNode {
+ public:
+  void InitializeWithTuneContext(const TuneContext& context) final {}
+
+  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
+    // Match a block with no reads and a single write to a 0-dimensional buffer, e.g.
+    //   compile_engine_const[()] = 59
+    tir::Block block = sch->Get(block_rv);
+    if (block->reads.empty() && block->writes.size() == 1 &&
+        block->writes[0]->buffer->shape.empty()) {
+      // ComputeInline throws on blocks it cannot inline, e.g. an output block writing a scalar
+      // the function returns; check first so that such blocks are left untouched.
+      if (tir::CanComputeInline(sch->state(), sch->GetSRef(block_rv))) {
+        sch->ComputeInline(block_rv);
+      }
+    }
+    return {sch};
+  }
+
+  ScheduleRule Clone() const final {
+    ObjectPtr<InlineConstantScalarsNode> n = make_object<InlineConstantScalarsNode>(*this);
+    return ScheduleRule(n);
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.InlineConstantScalars";
+  TVM_DECLARE_FINAL_OBJECT_INFO(InlineConstantScalarsNode, ScheduleRuleNode);
+};
+
+ScheduleRule ScheduleRule::InlineConstantScalars() {
+  ObjectPtr<InlineConstantScalarsNode> n = make_object<InlineConstantScalarsNode>();
+  return ScheduleRule(n);
+}
+
+TVM_REGISTER_NODE_TYPE(InlineConstantScalarsNode);
+TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInlineConstantScalars")
+    .set_body_typed(ScheduleRule::InlineConstantScalars);
 }  // namespace meta_schedule
 }  // namespace tvm
diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc
index 141b93be5e..b1e8c3695d 100644
--- a/src/meta_schedule/schedule_rule/schedule_rule.cc
+++ b/src/meta_schedule/schedule_rule/schedule_rule.cc
@@ -54,6 +54,7 @@ ScheduleRule ScheduleRule::PyScheduleRule(
 Array<ScheduleRule> ScheduleRule::DefaultLLVM() {
   return {
       ScheduleRule::ApplyCustomRule(),
+      ScheduleRule::InlineConstantScalars(),
       ScheduleRule::AutoInline(
           /*into_producer=*/false,
           /*into_consumer=*/true,
@@ -100,6 +101,7 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDA() {
           Map<String, ObjectRef>{{"req", String("must")},
                                  {"levels", Array<Integer>{3}},  //
                                  {"scope", String("local")}}),
+      ScheduleRule::InlineConstantScalars(),
       ScheduleRule::AutoInline(
           /*into_producer=*/true,
           /*into_consumer=*/true,
@@ -178,6 +180,7 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() {
 Array<ScheduleRule> ScheduleRule::DefaultHexagon() {
   return {
       ScheduleRule::ApplyCustomRule(),
+      ScheduleRule::InlineConstantScalars(),
       ScheduleRule::AutoInline(
           /*into_producer=*/false,
           /*into_consumer=*/true,
diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc
index f0672f3921..3377515a95 100644
--- a/src/tir/analysis/verify_gpu_code.cc
+++ b/src/tir/analysis/verify_gpu_code.cc
@@ -209,6 +209,19 @@ class GPUCodeVerifier : public StmtExprVisitor {
     }
   }
 
+  void VisitExpr_(const CastNode* op) {
+    // Flag vectorized casts whose result size (lanes * bytes) exceeds max_vector_bytes_.
+    const DataType& dtype = op->dtype;
+    if (dtype.lanes() > 1 && static_cast<size_t>(dtype.lanes() * dtype.bytes()) > max_vector_bytes_) {
+      std::stringstream msg;
+      msg << "Number of lanes (" << dtype.lanes() << ") times number of bytes (" << dtype.bytes()
+          << ") for dtype " << dtype << " is greater than the maximum number of vector bytes ("
+          << max_vector_bytes_ << ")";
+      errors_.push_back(msg.str());
+    }
+    ExprVisitor::VisitExpr_(op);
+  }
+
   void VisitExpr_(const BufferLoadNode* op) {
     if (op->dtype.lanes() > 1) {
       if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
diff --git a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
index b703c79c5d..9edf5877fd 100644
--- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
+++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
@@ -33,6 +33,7 @@ from tvm.contrib.hexagon.meta_schedule import (
 )
 from tvm.meta_schedule import postproc, schedule_rule
 from tvm.tir.schedule import BlockRV, Schedule
+from tvm.tir.schedule.analysis import has_block
 from tvm.tir.tensor_intrin.hexagon import VRMPY_u8i8i32_INTRIN, VRMPY_u8u8i32_INTRIN
 
 from ..infrastructure import get_hexagon_target
@@ -206,9 +207,9 @@ def _schedule_packed_8x8x32_conv2d():
 
     def schedule_fn(sch, conv2d_block: Optional[BlockRV] = None) -> bool:
         if conv2d_block is None:
-            try:
+            if has_block(sch, "conv2d_NCHWc_int8"):
                 conv2d_block = sch.get_block("conv2d_NCHWc_int8")
-            except ValueError:
+            else:
                 return False
 
         assert "conv2d_NCHWc_int8" in sch.get(conv2d_block).annotations["schedule_rule"]
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
index c17209e2cb..1baa13793f 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
@@ -15,7 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+import pytest
+
 import tvm
+from tvm.tir import Schedule
 from tvm import meta_schedule as ms
 from tvm.meta_schedule.testing.space_generation import generate_design_space
 from tvm.script import tir as T
@@ -334,6 +337,101 @@ class ConstConsumer:
                 T.writes(T_full[ax0, ax1, ax2])
                 T_full[ax0, ax1, ax2] = T.int64(0)
 
+
+@tvm.script.ir_module
+class Conv2dInt8:
+    @T.prim_func
+    def main(p0: T.Buffer[(16, 14, 14, 256), "int8"], p1: T.Buffer[(1024, 1, 1, 256), "int8"], p2: T.Buffer[(1, 1, 1, 1024), "int32"], p3: T.Buffer[(1, 1, 1, 1024), "int32"], p4: T.Buffer[1024, "int32"], p5: T.Buffer[1024, "int32"], p6: T.Buffer[1024, "int32"], p7: T.Buffer[1, "int32"], p8: T.Buffer[(16, 14, 14, 1024), "int32"], compute: T.Buffer[(16, 14, 14, 1024), "int32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body: block "compile_engine_const" writes a 0-d scalar that block "T_add_1" reads
+        # with T.block("root")
+        compile_engine_const = T.alloc_buffer([], dtype="int32")
+        pad_temp = T.alloc_buffer([16, 14, 14, 256], dtype="int8")
+        conv2d_nhwc = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
+        T_subtract = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
+        T_add = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
+        compute_1 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
+        T_add_1 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
+        compute_2 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
+        T_subtract_1 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
+        compute_3 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
+        T_add_2 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
+        with T.block("compile_engine_const"):
+            vi = T.axis.spatial(1, 0)
+            T.reads()
+            T.writes(compile_engine_const[()])
+            compile_engine_const[()] = 59
+        for i0, i1, i2, i3 in T.grid(16, 14, 14, 256):
+            with T.block("pad_temp"):
+                i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                T.reads(p0[i0_1, i1_1, i2_1, i3_1])
+                T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1])
+                pad_temp[i0_1, i1_1, i2_1, i3_1] = p0[i0_1, i1_1, i2_1, i3_1]
+        for i0, i1, i2, i3, i4, i5, i6 in T.grid(16, 14, 14, 1024, 1, 1, 256):
+            with T.block("conv2d_nhwc"):
+                nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
+                T.reads(pad_temp[nn, yy + ry, xx + rx, rc], p1[ff, ry, rx, rc])
+                T.writes(conv2d_nhwc[nn, yy, xx, ff])
+                with T.init():
+                    conv2d_nhwc[nn, yy, xx, ff] = 0
+                conv2d_nhwc[nn, yy, xx, ff] = conv2d_nhwc[nn, yy, xx, ff] + T.cast(pad_temp[nn, yy + ry, xx + rx, rc], "int32") * T.cast(p1[ff, ry, rx, rc], "int32")
+        for i0, i1, i2, i3 in T.grid(16, 14, 14, 1024):
+            with T.block("T_subtract"):
+                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                T.reads(conv2d_nhwc[ax0, ax1, ax2, ax3], p2[0, 0, 0, ax3])
+                T.writes(T_subtract[ax0, ax1, ax2, ax3])
+                T_subtract[ax0, ax1, ax2, ax3] = conv2d_nhwc[ax0, ax1, ax2, ax3] - p2[0, 0, 0, ax3]
+        for i0, i1, i2, i3 in T.grid(16, 14, 14, 1024):
+            with T.block("T_add"):
+                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                T.reads(T_subtract[ax0, ax1, ax2, ax3], p3[0, 0, 0, ax3])
+                T.writes(T_add[ax0, ax1, ax2, ax3])
+                T_add[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] + p3[0, 0, 0, ax3]
+        for i0, i1, i2, i3 in T.grid(16, 14, 14, 1024):
+            with T.block("compute"):
+                i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                T.reads(T_add[i0_2, i1_2, i2_2, i3_2], p4[i3_2], p5[i3_2], p6[i3_2])
+                T.writes(compute_1[i0_2, i1_2, i2_2, i3_2])
+                compute_1[i0_2, i1_2, i2_2, i3_2] = T.q_multiply_shift_per_axis(T_add[i0_2, i1_2, i2_2, i3_2], p4[i3_2], p5[i3_2], p6[i3_2], 31, False, True, dtype="int32")
+        for i0_3, i1_3, i2_3, i3_3 in T.grid(16, 14, 14, 1024):
+            with T.block("T_add_1"):
+                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
+                T.reads(compile_engine_const[()], compute_1[ax0, ax1, ax2, ax3])
+                T.writes(T_add_1[ax0, ax1, ax2, ax3])
+                T_add_1[ax0, ax1, ax2, ax3] = compile_engine_const[()] + compute_1[ax0, ax1, ax2, ax3]
+        for i0_4, i1_4, i2_4, i3_4 in T.grid(16, 14, 14, 1024):
+            with T.block("compute_1"):
+                i0_5, i1_5, i2_5, i3_5 = T.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4])
+                T.reads(T_add_1[i0_5, i1_5, i2_5, i3_5])
+                T.writes(compute_2[i0_5, i1_5, i2_5, i3_5])
+                compute_2[i0_5, i1_5, i2_5, i3_5] = T.max(T.min(T_add_1[i0_5, i1_5, i2_5, i3_5], 255), 0)
+        for i0_6, i1_6, i2_6, i3_6 in T.grid(16, 14, 14, 1024):
+            with T.block("T_subtract_1"):
+                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6])
+                T.reads(compute_2[ax0, ax1, ax2, ax3], p7[0])
+                T.writes(T_subtract_1[ax0, ax1, ax2, ax3])
+                T_subtract_1[ax0, ax1, ax2, ax3] = compute_2[ax0, ax1, ax2, ax3] - p7[0]
+        for i0_7, i1_7, i2_7, i3_7 in T.grid(16, 14, 14, 1024):
+            with T.block("compute_2"):
+                i0_8, i1_8, i2_8, i3_8 = T.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7])
+                T.reads(T_subtract_1[i0_8, i1_8, i2_8, i3_8])
+                T.writes(compute_3[i0_8, i1_8, i2_8, i3_8])
+                compute_3[i0_8, i1_8, i2_8, i3_8] = T.q_multiply_shift(T_subtract_1[i0_8, i1_8, i2_8, i3_8], 1408572815, 31, 1, dtype="int32")
+        for i0_9, i1_9, i2_9, i3_9 in T.grid(16, 14, 14, 1024):
+            with T.block("T_add_2"):
+                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_9, i1_9, i2_9, i3_9])
+                T.reads(compute_3[ax0, ax1, ax2, ax3], p8[ax0, ax1, ax2, ax3])
+                T.writes(T_add_2[ax0, ax1, ax2, ax3])
+                T_add_2[ax0, ax1, ax2, ax3] = compute_3[ax0, ax1, ax2, ax3] + p8[ax0, ax1, ax2, ax3]
+        for i0_10, i1_10, i2_10, i3_10 in T.grid(16, 14, 14, 1024):
+            with T.block("compute_3"):
+                i0_11, i1_11, i2_11, i3_11 = T.axis.remap("SSSS", [i0_10, i1_10, i2_10, i3_10])
+                T.reads(T_add_2[i0_11, i1_11, i2_11, i3_11])
+                T.writes(compute[i0_11, i1_11, i2_11, i3_11])
+                compute[i0_11, i1_11, i2_11, i3_11] = T.max(T.min(T_add_2[i0_11, i1_11, i2_11, i3_11], 255), 0)
+
+
 # pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
 # fmt: on
 
@@ -398,9 +496,26 @@ def test_inline_constant_tensor():
     tvm.ir.assert_structural_equal(lhs=space.mod, rhs=ConstConsumer)
 
 
+def test_conv2d_int8_inline_constant_scalars():
+    """A scalar-constant producer prevents reverse_compute_inline of its consumer until inlined."""
+    sch = Schedule(Conv2dInt8)
+    conv2d = sch.get_block("conv2d_nhwc")
+    sch.cache_write(conv2d, 0, "shared")
+
+    # T_add_1 reads both compile_engine_const and compute_1, so it cannot be reverse-inlined yet.
+    with pytest.raises(tvm.tir.ScheduleError) as e:
+        sch.reverse_compute_inline(sch.get_block("T_add_1"))
+    err_msg = "The block is only allowed to read a single buffer region, but it reads 2 region(s)"
+    assert err_msg in str(e.value)
+
+    ms.schedule_rule.InlineConstantScalars().apply(sch, sch.get_block("compile_engine_const"))
+    sch.reverse_compute_inline(sch.get_block("T_add_1"))
+
+
 if __name__ == "__main__":
     test_inline_consumer_chain()
     test_inline_into_cache()
     test_inline_into_multiple_consumers()
     test_inline_pure_spatial()
     test_inline_constant_tensor()
+    test_conv2d_int8_inline_constant_scalars()