You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/11/11 00:08:58 UTC
[tvm] branch main updated: [MetaSchedule] Improve inlining and `VerifyGPUCode` for quantized model workload (#13334)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 93fdf83e8f [MetaSchedule] Improve inlining and `VerifyGPUCode` for quantized model workload (#13334)
93fdf83e8f is described below
commit 93fdf83e8f40b806ee5a8bd6625e0f4e431b459d
Author: masahi <ma...@gmail.com>
AuthorDate: Fri Nov 11 09:08:51 2022 +0900
[MetaSchedule] Improve inlining and `VerifyGPUCode` for quantized model workload (#13334)
* [MetaSchedule] Add a new schedule rule to inline all scalar constants
* add doc
* reorg
* identify constant block by its structure, not by name
---
include/tvm/meta_schedule/schedule_rule.h | 10 ++
python/tvm/meta_schedule/schedule_rule/__init__.py | 2 +-
.../tvm/meta_schedule/schedule_rule/auto_inline.py | 17 +++
src/meta_schedule/postproc/verify_gpu_code.cc | 2 +
src/meta_schedule/schedule_rule/auto_inline.cc | 37 +++++++
src/meta_schedule/schedule_rule/schedule_rule.cc | 3 +
src/tir/analysis/verify_gpu_code.cc | 13 +++
.../metaschedule_e2e/test_resnet50_int8.py | 5 +-
...test_meta_schedule_schedule_rule_auto_inline.py | 115 +++++++++++++++++++++
9 files changed, 201 insertions(+), 3 deletions(-)
diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index da8f1faa8e..70dec47e60 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -125,6 +125,16 @@ class ScheduleRule : public runtime::ObjectRef {
bool require_injective, //
bool require_ordered, //
Optional<Array<String>> disallow_op);
+
+ /*!
+ * \brief Create a rule that inlines blocks producing a constant scalar, i.e. blocks that
+ * read no buffer and write exactly one 0-dimensional buffer. Such blocks are counted as an
+ * extra producer and make ReverseComputeInline fail during AutoInline unless they are inlined
+ * first, so this rule should be placed before AutoInline in the rule list.
+ * \return The schedule rule created
+ */
+ TVM_DLL static ScheduleRule InlineConstantScalars();
+
/*!
* \brief Create a mega rule: multi-level tiling with data reuse
* \param structure The tiling structure. Recommended:
diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py
index 5971ad53c4..d330fc7139 100644
--- a/python/tvm/meta_schedule/schedule_rule/__init__.py
+++ b/python/tvm/meta_schedule/schedule_rule/__init__.py
@@ -22,7 +22,7 @@ blocks in a schedule. See also PostOrderApply.
from .add_rfactor import AddRFactor
from .apply_custom_rule import ApplyCustomRule
from .auto_bind import AutoBind
-from .auto_inline import AutoInline
+from .auto_inline import AutoInline, InlineConstantScalars
from .cross_thread_reduction import CrossThreadReduction
from .multi_level_tiling import (
MultiLevelTiling,
diff --git a/python/tvm/meta_schedule/schedule_rule/auto_inline.py b/python/tvm/meta_schedule/schedule_rule/auto_inline.py
index 22206f3fcc..c84dbaf89b 100644
--- a/python/tvm/meta_schedule/schedule_rule/auto_inline.py
+++ b/python/tvm/meta_schedule/schedule_rule/auto_inline.py
@@ -65,3 +65,20 @@ class AutoInline(ScheduleRule):
require_ordered,
disallow_op,
)
+
+
+@register_object("meta_schedule.InlineConstantScalars")
+class InlineConstantScalars(ScheduleRule):
+ """Inline blocks that produce a constant scalar: blocks with no reads and one 0-d write.
+
+ Such blocks are counted as an extra producer and make ReverseComputeInline fail during
+ AutoInline unless they are inlined first, so place InlineConstantScalars before AutoInline
+ in the rule list. The rule takes no parameters.
+ """
+
+ def __init__(
+ self,
+ ) -> None:
+ self.__init_handle_by_constructor__(
+ _ffi_api.ScheduleRuleInlineConstantScalars, # type: ignore # pylint: disable=no-member
+ )
diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc
index 0828ee5384..ae6f3474bb 100644
--- a/src/meta_schedule/postproc/verify_gpu_code.cc
+++ b/src/meta_schedule/postproc/verify_gpu_code.cc
@@ -175,10 +175,14 @@ class VerifyGPUCodeNode : public PostprocNode {
pass_list.push_back(tir::transform::InjectDoubleBuffer());
pass_list.push_back(tir::transform::StorageRewrite());
pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
+ pass_list.push_back(tir::transform::LowerIntrin());
// Convert Function to IRModule
transform::PassContext pass_ctx = transform::PassContext::Current();
tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol",
runtime::String(g_var->name_hint));
+ // LowerIntrin selects intrinsic lowering rules by target kind, so attach the target this
+ // postproc was initialized with instead of hard-coding CUDA for every GPU backend.
+ f = WithAttr(std::move(f), tvm::attr::kTarget, this->target_);
bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
if (noalias) {
f = WithAttr(std::move(f), "tir.noalias", Bool(true));
diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc
index dcdc83f95c..d2d48b9008 100644
--- a/src/meta_schedule/schedule_rule/auto_inline.cc
+++ b/src/meta_schedule/schedule_rule/auto_inline.cc
@@ -189,5 +189,51 @@ TVM_REGISTER_NODE_TYPE(AutoInlineNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoInline")
.set_body_typed(ScheduleRule::AutoInline);
+/*!
+ * \brief Inline blocks that produce a constant scalar.
+ *
+ * A block is treated as a constant-scalar producer when it reads no buffer and writes exactly
+ * one 0-dimensional buffer. Inlining such blocks first removes an extra producer that would
+ * otherwise make ReverseComputeInline fail during AutoInline.
+ */
+class InlineConstantScalarsNode : public ScheduleRuleNode {
+ public:
+ void InitializeWithTuneContext(const TuneContext& context) final {}
+
+ Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
+ // Look for a block of the form
+ // block compile_engine_const(iter_var(vi, range(min=0, ext=1))) {
+ // reads([])
+ // writes([compile_engine_const[]])
+ // compile_engine_const[] = 59
+ // }
+ auto block = sch->Get(block_rv);
+ bool is_constant_scalar = block->reads.size() == 0 && block->writes.size() == 1 &&
+ block->writes[0]->buffer->shape.size() == 0;
+ // Matching the structure does not guarantee the block is inlinable (e.g. it may write a
+ // function output buffer), so check first instead of letting ComputeInline throw.
+ if (is_constant_scalar && tir::CanComputeInline(sch->state(), sch->GetSRef(block_rv))) {
+ sch->ComputeInline(block_rv);
+ }
+ return {sch};
+ }
+
+ ScheduleRule Clone() const final {
+ ObjectPtr<InlineConstantScalarsNode> n = make_object<InlineConstantScalarsNode>(*this);
+ return ScheduleRule(n);
+ }
+
+ static constexpr const char* _type_key = "meta_schedule.InlineConstantScalars";
+ TVM_DECLARE_FINAL_OBJECT_INFO(InlineConstantScalarsNode, ScheduleRuleNode);
+};
+
+ScheduleRule ScheduleRule::InlineConstantScalars() {
+ ObjectPtr<InlineConstantScalarsNode> n = make_object<InlineConstantScalarsNode>();
+ return ScheduleRule(n);
+}
+
+TVM_REGISTER_NODE_TYPE(InlineConstantScalarsNode);
+TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInlineConstantScalars")
+ .set_body_typed(ScheduleRule::InlineConstantScalars);
} // namespace meta_schedule
} // namespace tvm
diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc
index 141b93be5e..b1e8c3695d 100644
--- a/src/meta_schedule/schedule_rule/schedule_rule.cc
+++ b/src/meta_schedule/schedule_rule/schedule_rule.cc
@@ -54,6 +54,7 @@ ScheduleRule ScheduleRule::PyScheduleRule(
Array<ScheduleRule> ScheduleRule::DefaultLLVM() {
return {
ScheduleRule::ApplyCustomRule(),
+ ScheduleRule::InlineConstantScalars(),
ScheduleRule::AutoInline(
/*into_producer=*/false,
/*into_consumer=*/true,
@@ -100,6 +101,7 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDA() {
Map<String, ObjectRef>{{"req", String("must")},
{"levels", Array<Integer>{3}}, //
{"scope", String("local")}}),
+ ScheduleRule::InlineConstantScalars(),
ScheduleRule::AutoInline(
/*into_producer=*/true,
/*into_consumer=*/true,
@@ -178,6 +180,7 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() {
Array<ScheduleRule> ScheduleRule::DefaultHexagon() {
return {
ScheduleRule::ApplyCustomRule(),
+ ScheduleRule::InlineConstantScalars(),
ScheduleRule::AutoInline(
/*into_producer=*/false,
/*into_consumer=*/true,
diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc
index f0672f3921..3377515a95 100644
--- a/src/tir/analysis/verify_gpu_code.cc
+++ b/src/tir/analysis/verify_gpu_code.cc
@@ -209,6 +209,19 @@ class GPUCodeVerifier : public StmtExprVisitor {
}
}
+ void VisitExpr_(const CastNode* op) {  // Vectorized casts must also fit in max_vector_bytes_.
+ if (op->dtype.lanes() > 1) {  // Scalar casts are not constrained.
+ if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
+ std::stringstream s;
+ s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes ("
+ << op->dtype.bytes() << ") for dtype " << op->dtype
+ << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")";
+ errors_.push_back(s.str());  // Record the violation and keep verifying.
+ }
+ }
+ ExprVisitor::VisitExpr_(op);  // Recurse into the cast's operand.
+ }
+
void VisitExpr_(const BufferLoadNode* op) {
if (op->dtype.lanes() > 1) {
if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
diff --git a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
index b703c79c5d..9edf5877fd 100644
--- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
+++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
@@ -33,6 +33,7 @@ from tvm.contrib.hexagon.meta_schedule import (
)
from tvm.meta_schedule import postproc, schedule_rule
from tvm.tir.schedule import BlockRV, Schedule
+from tvm.tir.schedule.analysis import has_block
from tvm.tir.tensor_intrin.hexagon import VRMPY_u8i8i32_INTRIN, VRMPY_u8u8i32_INTRIN
from ..infrastructure import get_hexagon_target
@@ -206,9 +207,9 @@ def _schedule_packed_8x8x32_conv2d():
def schedule_fn(sch, conv2d_block: Optional[BlockRV] = None) -> bool:
if conv2d_block is None:
- try:
+ if has_block(sch, "conv2d_NCHWc_int8"):
conv2d_block = sch.get_block("conv2d_NCHWc_int8")
- except ValueError:
+ else:
return False
assert "conv2d_NCHWc_int8" in sch.get(conv2d_block).annotations["schedule_rule"]
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
index c17209e2cb..1baa13793f 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
@@ -15,7 +15,10 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+import pytest
+
import tvm
+from tvm.tir import Schedule
from tvm import meta_schedule as ms
from tvm.meta_schedule.testing.space_generation import generate_design_space
from tvm.script import tir as T
@@ -334,6 +337,101 @@ class ConstConsumer:
T.writes(T_full[ax0, ax1, ax2])
T_full[ax0, ax1, ax2] = T.int64(0)
+
+@tvm.script.ir_module
+class Conv2dInt8:
+ @T.prim_func
+ def main(p0: T.Buffer[(16, 14, 14, 256), "int8"], p1: T.Buffer[(1024, 1, 1, 256), "int8"], p2: T.Buffer[(1, 1, 1, 1024), "int32"], p3: T.Buffer[(1, 1, 1, 1024), "int32"], p4: T.Buffer[1024, "int32"], p5: T.Buffer[1024, "int32"], p6: T.Buffer[1024, "int32"], p7: T.Buffer[1, "int32"], p8: T.Buffer[(16, 14, 14, 1024), "int32"], compute: T.Buffer[(16, 14, 14, 1024), "int32"]) -> None:
+ # function attr dict
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ # body: block "compile_engine_const" writes the 0-d scalar that block "T_add_1" reads
+ # with T.block("root")
+ compile_engine_const = T.alloc_buffer([], dtype="int32")
+ pad_temp = T.alloc_buffer([16, 14, 14, 256], dtype="int8")
+ conv2d_nhwc = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
+ T_subtract = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
+ T_add = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
+ compute_1 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
+ T_add_1 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
+ compute_2 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
+ T_subtract_1 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
+ compute_3 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
+ T_add_2 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
+ with T.block("compile_engine_const"):
+ vi = T.axis.spatial(1, 0)
+ T.reads()
+ T.writes(compile_engine_const[()])
+ compile_engine_const[()] = 59
+ for i0, i1, i2, i3 in T.grid(16, 14, 14, 256):
+ with T.block("pad_temp"):
+ i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+ T.reads(p0[i0_1, i1_1, i2_1, i3_1])
+ T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1])
+ pad_temp[i0_1, i1_1, i2_1, i3_1] = p0[i0_1, i1_1, i2_1, i3_1]
+ for i0, i1, i2, i3, i4, i5, i6 in T.grid(16, 14, 14, 1024, 1, 1, 256):
+ with T.block("conv2d_nhwc"):
+ nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
+ T.reads(pad_temp[nn, yy + ry, xx + rx, rc], p1[ff, ry, rx, rc])
+ T.writes(conv2d_nhwc[nn, yy, xx, ff])
+ with T.init():
+ conv2d_nhwc[nn, yy, xx, ff] = 0
+ conv2d_nhwc[nn, yy, xx, ff] = conv2d_nhwc[nn, yy, xx, ff] + T.cast(pad_temp[nn, yy + ry, xx + rx, rc], "int32") * T.cast(p1[ff, ry, rx, rc], "int32")
+ for i0, i1, i2, i3 in T.grid(16, 14, 14, 1024):
+ with T.block("T_subtract"):
+ ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+ T.reads(conv2d_nhwc[ax0, ax1, ax2, ax3], p2[0, 0, 0, ax3])
+ T.writes(T_subtract[ax0, ax1, ax2, ax3])
+ T_subtract[ax0, ax1, ax2, ax3] = conv2d_nhwc[ax0, ax1, ax2, ax3] - p2[0, 0, 0, ax3]
+ for i0, i1, i2, i3 in T.grid(16, 14, 14, 1024):
+ with T.block("T_add"):
+ ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+ T.reads(T_subtract[ax0, ax1, ax2, ax3], p3[0, 0, 0, ax3])
+ T.writes(T_add[ax0, ax1, ax2, ax3])
+ T_add[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] + p3[0, 0, 0, ax3]
+ for i0, i1, i2, i3 in T.grid(16, 14, 14, 1024):
+ with T.block("compute"):
+ i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+ T.reads(T_add[i0_2, i1_2, i2_2, i3_2], p4[i3_2], p5[i3_2], p6[i3_2])
+ T.writes(compute_1[i0_2, i1_2, i2_2, i3_2])
+ compute_1[i0_2, i1_2, i2_2, i3_2] = T.q_multiply_shift_per_axis(T_add[i0_2, i1_2, i2_2, i3_2], p4[i3_2], p5[i3_2], p6[i3_2], 31, False, True, dtype="int32")
+ for i0_3, i1_3, i2_3, i3_3 in T.grid(16, 14, 14, 1024):
+ with T.block("T_add_1"):
+ ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
+ T.reads(compile_engine_const[()], compute_1[ax0, ax1, ax2, ax3])
+ T.writes(T_add_1[ax0, ax1, ax2, ax3])
+ T_add_1[ax0, ax1, ax2, ax3] = compile_engine_const[()] + compute_1[ax0, ax1, ax2, ax3]
+ for i0_4, i1_4, i2_4, i3_4 in T.grid(16, 14, 14, 1024):
+ with T.block("compute_1"):
+ i0_5, i1_5, i2_5, i3_5 = T.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4])
+ T.reads(T_add_1[i0_5, i1_5, i2_5, i3_5])
+ T.writes(compute_2[i0_5, i1_5, i2_5, i3_5])
+ compute_2[i0_5, i1_5, i2_5, i3_5] = T.max(T.min(T_add_1[i0_5, i1_5, i2_5, i3_5], 255), 0)
+ for i0_6, i1_6, i2_6, i3_6 in T.grid(16, 14, 14, 1024):
+ with T.block("T_subtract_1"):
+ ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6])
+ T.reads(compute_2[ax0, ax1, ax2, ax3], p7[0])
+ T.writes(T_subtract_1[ax0, ax1, ax2, ax3])
+ T_subtract_1[ax0, ax1, ax2, ax3] = compute_2[ax0, ax1, ax2, ax3] - p7[0]
+ for i0_7, i1_7, i2_7, i3_7 in T.grid(16, 14, 14, 1024):
+ with T.block("compute_2"):
+ i0_8, i1_8, i2_8, i3_8 = T.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7])
+ T.reads(T_subtract_1[i0_8, i1_8, i2_8, i3_8])
+ T.writes(compute_3[i0_8, i1_8, i2_8, i3_8])
+ compute_3[i0_8, i1_8, i2_8, i3_8] = T.q_multiply_shift(T_subtract_1[i0_8, i1_8, i2_8, i3_8], 1408572815, 31, 1, dtype="int32")
+ for i0_9, i1_9, i2_9, i3_9 in T.grid(16, 14, 14, 1024):
+ with T.block("T_add_2"):
+ ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_9, i1_9, i2_9, i3_9])
+ T.reads(compute_3[ax0, ax1, ax2, ax3], p8[ax0, ax1, ax2, ax3])
+ T.writes(T_add_2[ax0, ax1, ax2, ax3])
+ T_add_2[ax0, ax1, ax2, ax3] = compute_3[ax0, ax1, ax2, ax3] + p8[ax0, ax1, ax2, ax3]
+ for i0_10, i1_10, i2_10, i3_10 in T.grid(16, 14, 14, 1024):
+ with T.block("compute_3"):
+ i0_11, i1_11, i2_11, i3_11 = T.axis.remap("SSSS", [i0_10, i1_10, i2_10, i3_10])
+ T.reads(T_add_2[i0_11, i1_11, i2_11, i3_11])
+ T.writes(compute[i0_11, i1_11, i2_11, i3_11])
+ compute[i0_11, i1_11, i2_11, i3_11] = T.max(T.min(T_add_2[i0_11, i1_11, i2_11, i3_11], 255), 0)
+
+
# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
# fmt: on
@@ -398,9 +496,30 @@ def test_inline_constant_tensor():
tvm.ir.assert_structural_equal(lhs=space.mod, rhs=ConstConsumer)
+def test_conv2d_int8_inline_constant_scalars():
+ """InlineConstantScalars removes the scalar producer that blocks reverse_compute_inline."""
+ sch = Schedule(Conv2dInt8)
+
+ conv2d = sch.get_block("conv2d_nhwc")
+ sch.cache_write(conv2d, 0, "shared")
+
+ # Before the rule runs, T_add_1 reads both compute_1 and the scalar compile_engine_const,
+ # so reverse_compute_inline must reject it.
+ with pytest.raises(tvm.tir.ScheduleError) as e:
+ sch.reverse_compute_inline(sch.get_block("T_add_1"))
+
+ err_msg = "The block is only allowed to read a single buffer region, but it reads 2 region(s)"
+ # Inspect the exception itself; str() of the ExceptionInfo wrapper is a possibly truncated repr.
+ assert err_msg in str(e.value)
+
+ ms.schedule_rule.InlineConstantScalars().apply(sch, sch.get_block("compile_engine_const"))
+ sch.reverse_compute_inline(sch.get_block("T_add_1"))
+
+
if __name__ == "__main__":  # run the tests listed below directly, without the pytest runner
test_inline_consumer_chain()
test_inline_into_cache()
test_inline_into_multiple_consumers()
test_inline_pure_spatial()
test_inline_constant_tensor()
+ test_conv2d_int8_inline_constant_scalars()