You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/04/21 02:08:15 UTC
[tvm] branch main updated: [TIR] Add TileWithTensorIntrin (#11075)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0070b6cc05 [TIR] Add TileWithTensorIntrin (#11075)
0070b6cc05 is described below
commit 0070b6cc0557cce64c13e3b64f58d1f3d85a4687
Author: Masahiro Masuda <ma...@gmail.com>
AuthorDate: Thu Apr 21 11:08:10 2022 +0900
[TIR] Add TileWithTensorIntrin (#11075)
Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
Co-authored-by: Hongyi Jin <32...@qq.com>
Co-authored-by: Ruihang Lai <la...@qq.com>
Co-authored-by: Wuwei Lin <wu...@apache.org>
Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
Co-authored-by: Hongyi Jin <32...@qq.com>
Co-authored-by: Ruihang Lai <la...@qq.com>
Co-authored-by: Wuwei Lin <wu...@apache.org>
---
python/tvm/tir/schedule/__init__.py | 1 +
python/tvm/tir/schedule/transform.py | 42 +++++
src/tir/schedule/transform.cc | 63 +++++++
src/tir/schedule/transform.h | 13 ++
.../python/unittest/test_tir_schedule_transform.py | 181 +++++++++++++++++++++
5 files changed, 300 insertions(+)
diff --git a/python/tvm/tir/schedule/__init__.py b/python/tvm/tir/schedule/__init__.py
index 66ac7b9d77..63638a8945 100644
--- a/python/tvm/tir/schedule/__init__.py
+++ b/python/tvm/tir/schedule/__init__.py
@@ -24,3 +24,4 @@ from .state import ScheduleDebugMask, ScheduleState
from .trace import Trace
from . import analysis
+from . import transform
diff --git a/python/tvm/tir/schedule/transform.py b/python/tvm/tir/schedule/transform.py
new file mode 100644
index 0000000000..5dbc06846d
--- /dev/null
+++ b/python/tvm/tir/schedule/transform.py
@@ -0,0 +1,42 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Transformation on TIR schedule."""
+from typing import Optional
+
+from tvm.tir.schedule import Schedule, BlockRV, LoopRV
+from . import _ffi_api
+
+
+def tile_with_tensor_intrin(sch: Schedule, block: BlockRV, intrin_name: str) -> Optional[LoopRV]:
+ """Tile a subset of loops in the block according to the given tensor intrinsic.
+
+ Parameters
+ ----------
+ sch : Schedule
+ The schedule to which tiling is applied
+ block : BlockRV
+ The block whose subset of loops will be tiled
+ intrin_name : str
+ intrin_name : str
+ The name of a tensor intrinsic, must be registered via TensorIntrin.register(...) beforehand
+
+ Returns
+ -------
+ tiled_loop_rv : Optional[LoopRV]
+ LoopRV corresponding to the outermost loop of a block tiled according to the given intrin
+ NullOpt if no valid loop mapping is found
+ """
+ return _ffi_api.TileWithTensorIntrin(sch, block, intrin_name) # type: ignore
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index ffb6b2d526..b2e71a9a0d 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -136,5 +136,68 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_
throw OnlyLeafError(self->mod, GetRef<Block>(leaf_block), GetRef<Block>(scope_block));
}
+Optional<LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
+ const String& intrin_name) {
+ Optional<tir::TensorizeInfo> opt_tensorize_info = GetTensorizeLoopMapping(
+ sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc);
+ if (!opt_tensorize_info) return NullOpt;
+ const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get();
+ // Construct a mapping from tir loops back to LoopRVs
+ Map<tir::StmtSRef, LoopRV> loop2rv;
+ {
+ Array<LoopRV> loop_rvs = sch->GetLoops(block_rv);
+ for (const LoopRV& loop_rv : loop_rvs) {
+ loop2rv.Set(sch->GetSRef(loop_rv), loop_rv);
+ }
+ }
+ // Split the loops
+ arith::Analyzer analyzer;
+ std::unordered_set<const tir::StmtSRefNode*> inner_loops;
+ std::vector<LoopRV> reorder_suffix;
+ reorder_suffix.resize(info->loop_map.size());
+ for (const auto& kv : info->loop_map) {
+ // Extract mapping (block_loop => desc_loop)
+ const tir::StmtSRef& block_loop_sref = kv.first;
+ const tir::ForNode* block_loop = block_loop_sref->StmtAs<tir::ForNode>();
+ const tir::ForNode* desc_loop = kv.second.get();
+ ICHECK(block_loop != nullptr && desc_loop != nullptr);
+ // Extract the loop extent
+ PrimExpr block_extent = analyzer.Simplify(block_loop->extent);
+ PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent);
+ const auto* int_block_extent = block_extent.as<IntImmNode>();
+ const auto* int_desc_extent = desc_extent.as<IntImmNode>();
+ ICHECK(int_block_extent != nullptr && int_desc_extent != nullptr);
+ // Check divisibility
+ int64_t total = int_block_extent->value;
+ int64_t inner = int_desc_extent->value;
+ ICHECK_EQ(total % inner, 0);
+ int64_t outer = int_block_extent->value / int_desc_extent->value;
+ // Do the split
+ Array<LoopRV> split = sch->Split(loop2rv.at(block_loop_sref), {Integer(outer), Integer(inner)});
+ ICHECK_EQ(split.size(), 2);
+ inner_loops.insert(sch->GetSRef(split[1]).operator->());
+ // The inner split will be reordered to the loop domain that is tensorized
+ int desc_loop_index = info->desc_loop_indexer.at(GetRef<tir::For>(desc_loop));
+ reorder_suffix[desc_loop_index] = split[1];
+ }
+ // Reorder the loops
+ std::vector<LoopRV> reorder_list;
+ bool meet = false;
+ Array<LoopRV> all_loops = sch->GetLoops(block_rv);
+ for (const LoopRV& loop : all_loops) {
+ if (inner_loops.count(sch->GetSRef(loop).operator->())) {
+ meet = true;
+ } else if (meet) {
+ reorder_list.push_back(loop);
+ }
+ }
+ reorder_list.insert(reorder_list.end(), reorder_suffix.begin(), reorder_suffix.end());
+ sch->Reorder(reorder_list);
+ ICHECK(!reorder_suffix.empty());
+ return reorder_suffix[0];
+}
+
+TVM_REGISTER_GLOBAL("tir.schedule.TileWithTensorIntrin").set_body_typed(TileWithTensorIntrin);
+
} // namespace tir
} // namespace tvm
diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h
index 3932c4bdbd..12326b3418 100644
--- a/src/tir/schedule/transform.h
+++ b/src/tir/schedule/transform.h
@@ -19,6 +19,7 @@
#ifndef TVM_TIR_SCHEDULE_TRANSFORM_H_
#define TVM_TIR_SCHEDULE_TRANSFORM_H_
+#include <tvm/tir/schedule/schedule.h>
#include <tvm/tir/schedule/state.h>
namespace tvm {
@@ -104,6 +105,18 @@ Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, c
void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_sref,
Stmt* src_stmt, Stmt* tgt_stmt);
+/*!
+ * \brief Tile a subset of loops in the block according to the given tensor intrinsic.
+ * \param sch The schedule to which tiling is applied
+ * \param block_rv The block whose subset of loops will be tiled
+ * \param intrin_name The name of a tensor intrinsic, must be registered via
+ * TensorIntrin.register(...) beforehand
+ * \return LoopRV corresponding to the outermost loop of a
+ * block tiled according to the given intrin, NullOpt if a valid loop mapping is not found
+ */
+Optional<tir::LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
+ const String& intrin_name);
+
} // namespace tir
} // namespace tvm
diff --git a/tests/python/unittest/test_tir_schedule_transform.py b/tests/python/unittest/test_tir_schedule_transform.py
new file mode 100644
index 0000000000..6dfd4315ec
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_transform.py
@@ -0,0 +1,181 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN
+
+from tvm.tir import Schedule
+from tvm.script import tir as T
+from tvm.tir.schedule.transform import tile_with_tensor_intrin
+
+
+@tvm.script.ir_module
+class DenseVNNIModule:
+ @T.prim_func
+ def main(
+ placeholder: T.Buffer[(1024, 1024), "uint8"],
+ placeholder_1: T.Buffer[(64, 256, 16, 4), "int8"],
+ compute: T.Buffer[(1024, 1024), "int32"],
+ ) -> None:
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ with T.block("root"):
+ T.reads()
+ T.writes()
+ for i0, i1, i2 in T.grid(1024, 1024, 1024):
+ with T.block("compute"):
+ i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+ T.reads(placeholder[i, k], placeholder_1[j // 16, k // 4, j % 16, k % 4])
+ T.writes(compute[i, j])
+ with T.init():
+ compute[i, j] = 0
+ compute[i, j] = compute[i, j] + T.cast(placeholder[i, k], "int32") * T.cast(
+ placeholder_1[j // 16, k // 4, j % 16, k % 4], "int32"
+ )
+
+
+@tvm.script.ir_module
+class DenseVNNIModuleTiled:
+ @T.prim_func
+ def main(
+ placeholder: T.Buffer[(1024, 1024), "uint8"],
+ placeholder_1: T.Buffer[(64, 256, 16, 4), "int8"],
+ compute: T.Buffer[(1024, 1024), "int32"],
+ ) -> None:
+ # function attr dict
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ # body
+ # with T.block("root")
+ for i0, i1_0, i2_0, i1_1, i2_1 in T.grid(1024, 64, 256, 16, 4):
+ with T.block("compute"):
+ i = T.axis.spatial(1024, i0)
+ j = T.axis.spatial(1024, i1_0 * 16 + i1_1)
+ k = T.axis.reduce(1024, i2_0 * 4 + i2_1)
+ T.reads(placeholder[i, k], placeholder_1[j // 16, k // 4, j % 16, k % 4])
+ T.writes(compute[i, j])
+ with T.init():
+ compute[i, j] = 0
+ compute[i, j] = compute[i, j] + T.cast(placeholder[i, k], "int32") * T.cast(
+ placeholder_1[j // 16, k // 4, j % 16, k % 4], "int32"
+ )
+
+
+@tvm.script.ir_module
+class Conv2dNCHWcVNNIModule:
+ @T.prim_func
+ def main(
+ placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
+ placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
+ conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
+ ) -> None:
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4):
+ with T.block("conv2d_NCHWc_int8"):
+ (
+ n,
+ oc_chunk,
+ oh,
+ ow,
+ oc_block,
+ kh,
+ kw,
+ ic_outer,
+ ic_f_inner,
+ ic_s_inner,
+ ) = T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9])
+ T.reads(
+ placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner],
+ placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
+ )
+ T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block])
+ with T.init():
+ conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0
+ conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[
+ n, oc_chunk, oh, ow, oc_block
+ ] + T.cast(
+ placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32"
+ ) * T.cast(
+ placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
+ "int32",
+ )
+
+
+@tvm.script.ir_module
+class Conv2dNCHWcVNNIModuleTiled:
+ @T.prim_func
+ def main(
+ placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
+ placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
+ conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
+ ) -> None:
+ # function attr dict
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ # body
+ # with T.block("root")
+ for i0, i1, i2, i3, i4_0, i5, i6, i7, i8, i9_0, i4_1, i9_1 in T.grid(
+ 1, 16, 56, 56, 1, 1, 1, 4, 4, 1, 16, 4
+ ):
+ with T.block("conv2d_NCHWc_int8"):
+ n = T.axis.spatial(1, 0)
+ oc_chunk, oh, ow, oc_block = T.axis.remap("SSSS", [i1, i2, i3, i4_1])
+ kh = T.axis.reduce(1, 0)
+ kw = T.axis.reduce(1, 0)
+ ic_outer, ic_f_inner, ic_s_inner = T.axis.remap("RRR", [i7, i8, i9_1])
+ T.reads(
+ placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner],
+ placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
+ )
+ T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block])
+ with T.init():
+ conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0
+ conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[
+ n, oc_chunk, oh, ow, oc_block
+ ] + T.cast(
+ placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32"
+ ) * T.cast(
+ placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
+ "int32",
+ )
+
+
+def test_tile_with_tensor_intrin_dense_vnni():
+ s = Schedule(DenseVNNIModule)
+ block = s.get_block("compute")
+
+ tiled_loop = tile_with_tensor_intrin(s, block, VNNI_DOT_16x4_INTRIN)
+
+ _, _, _, i1_1, _ = s.get_loops(block)
+
+ assert s.get(tiled_loop) == s.get(i1_1)
+ tvm.ir.assert_structural_equal(s.mod, DenseVNNIModuleTiled)
+
+
+def test_tile_with_tensor_intrin_conv2d_nchwc_vnni():
+ s = Schedule(Conv2dNCHWcVNNIModule)
+ block = s.get_block("conv2d_NCHWc_int8")
+
+ tiled_loop = tile_with_tensor_intrin(s, block, VNNI_DOT_16x4_INTRIN)
+
+ tiled_loops = s.get_loops(block)
+
+ assert len(tiled_loops) == 12
+ assert s.get(tiled_loop) == s.get(tiled_loops[-2])
+
+ tvm.ir.assert_structural_equal(s.mod, Conv2dNCHWcVNNIModuleTiled)
+
+
+if __name__ == "__main__":
+ test_tile_with_tensor_intrin_dense_vnni()
+ test_tile_with_tensor_intrin_conv2d_nchwc_vnni()