Posted to commits@tvm.apache.org by ma...@apache.org on 2022/04/21 02:08:15 UTC

[tvm] branch main updated: [TIR] Add TileWithTensorIntrin (#11075)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0070b6cc05 [TIR] Add TileWithTensorIntrin (#11075)
0070b6cc05 is described below

commit 0070b6cc0557cce64c13e3b64f58d1f3d85a4687
Author: Masahiro Masuda <ma...@gmail.com>
AuthorDate: Thu Apr 21 11:08:10 2022 +0900

    [TIR] Add TileWithTensorIntrin (#11075)
    
    Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
    Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
    Co-authored-by: Hongyi Jin <32...@qq.com>
    Co-authored-by: Ruihang Lai <la...@qq.com>
    Co-authored-by: Wuwei Lin <wu...@apache.org>
    
---
 python/tvm/tir/schedule/__init__.py                |   1 +
 python/tvm/tir/schedule/transform.py               |  42 +++++
 src/tir/schedule/transform.cc                      |  63 +++++++
 src/tir/schedule/transform.h                       |  13 ++
 .../python/unittest/test_tir_schedule_transform.py | 181 +++++++++++++++++++++
 5 files changed, 300 insertions(+)

diff --git a/python/tvm/tir/schedule/__init__.py b/python/tvm/tir/schedule/__init__.py
index 66ac7b9d77..63638a8945 100644
--- a/python/tvm/tir/schedule/__init__.py
+++ b/python/tvm/tir/schedule/__init__.py
@@ -24,3 +24,4 @@ from .state import ScheduleDebugMask, ScheduleState
 from .trace import Trace
 
 from . import analysis
+from . import transform
diff --git a/python/tvm/tir/schedule/transform.py b/python/tvm/tir/schedule/transform.py
new file mode 100644
index 0000000000..5dbc06846d
--- /dev/null
+++ b/python/tvm/tir/schedule/transform.py
@@ -0,0 +1,42 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Transformation on TIR schedule."""
+from typing import Optional
+
+from tvm.tir.schedule import Schedule, BlockRV, LoopRV
+from . import _ffi_api
+
+
+def tile_with_tensor_intrin(sch: Schedule, block: BlockRV, intrin_name: str) -> Optional[LoopRV]:
+    """Tile a subset of loops in the block according to the given tensor intrinsic.
+
+    Parameters
+    ----------
+    sch : Schedule
+        The schedule to which tiling is applied
+    block : BlockRV
+        The block whose subset of loops will be tiled
+    intrin_name : str
+        The name of a tensor intrinsic; it must be registered via TensorIntrin.register(...) beforehand
+
+    Returns
+    -------
+    tiled_loop_rv : Optional[LoopRV]
+        LoopRV corresponding to the outermost loop of the block tiled according to the given
+        intrin, or NullOpt if no valid loop mapping is found
+    """
+    return _ffi_api.TileWithTensorIntrin(sch, block, intrin_name)  # type: ignore
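
A minimal usage sketch, assuming the VNNI dot-product intrinsic registered by
tvm.tir.tensor_intrin.x86 and the DenseVNNIModule defined in the new test file further
below; the typical follow-up is to tensorize the returned loop:

    from tvm.tir import Schedule
    from tvm.tir.schedule.transform import tile_with_tensor_intrin
    from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN

    sch = Schedule(DenseVNNIModule)  # IRModule as defined in the new test file below
    block = sch.get_block("compute")
    # Split and reorder the loops of the block so the innermost loops match the intrinsic.
    loop = tile_with_tensor_intrin(sch, block, VNNI_DOT_16x4_INTRIN)
    if loop is not None:
        # The returned LoopRV is the outermost of the tiled inner loops; it can be tensorized.
        sch.tensorize(loop, VNNI_DOT_16x4_INTRIN)
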
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index ffb6b2d526..b2e71a9a0d 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -136,5 +136,68 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_
   throw OnlyLeafError(self->mod, GetRef<Block>(leaf_block), GetRef<Block>(scope_block));
 }
 
+Optional<LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
+                                      const String& intrin_name) {
+  Optional<tir::TensorizeInfo> opt_tensorize_info = GetTensorizeLoopMapping(
+      sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc);
+  if (!opt_tensorize_info) return NullOpt;
+  const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get();
+  // Construct a mapping from tir loops back to LoopRVs
+  Map<tir::StmtSRef, LoopRV> loop2rv;
+  {
+    Array<LoopRV> loop_rvs = sch->GetLoops(block_rv);
+    for (const LoopRV& loop_rv : loop_rvs) {
+      loop2rv.Set(sch->GetSRef(loop_rv), loop_rv);
+    }
+  }
+  // Split the loops
+  arith::Analyzer analyzer;
+  std::unordered_set<const tir::StmtSRefNode*> inner_loops;
+  std::vector<LoopRV> reorder_suffix;
+  reorder_suffix.resize(info->loop_map.size());
+  for (const auto& kv : info->loop_map) {
+    // Extract mapping (block_loop => desc_loop)
+    const tir::StmtSRef& block_loop_sref = kv.first;
+    const tir::ForNode* block_loop = block_loop_sref->StmtAs<tir::ForNode>();
+    const tir::ForNode* desc_loop = kv.second.get();
+    ICHECK(block_loop != nullptr && desc_loop != nullptr);
+    // Extract the loop extent
+    PrimExpr block_extent = analyzer.Simplify(block_loop->extent);
+    PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent);
+    const auto* int_block_extent = block_extent.as<IntImmNode>();
+    const auto* int_desc_extent = desc_extent.as<IntImmNode>();
+    ICHECK(int_block_extent != nullptr && int_desc_extent != nullptr);
+    // Check divisibility
+    int64_t total = int_block_extent->value;
+    int64_t inner = int_desc_extent->value;
+    ICHECK_EQ(total % inner, 0);
+    int64_t outer = int_block_extent->value / int_desc_extent->value;
+    // Do the split
+    Array<LoopRV> split = sch->Split(loop2rv.at(block_loop_sref), {Integer(outer), Integer(inner)});
+    ICHECK_EQ(split.size(), 2);
+    inner_loops.insert(sch->GetSRef(split[1]).operator->());
+    // The inner split will be reordered to the loop domain that is tensorized
+    int desc_loop_index = info->desc_loop_indexer.at(GetRef<tir::For>(desc_loop));
+    reorder_suffix[desc_loop_index] = split[1];
+  }
+  // Reorder the loops
+  std::vector<LoopRV> reorder_list;
+  bool meet = false;
+  Array<LoopRV> all_loops = sch->GetLoops(block_rv);
+  for (const LoopRV& loop : all_loops) {
+    if (inner_loops.count(sch->GetSRef(loop).operator->())) {
+      meet = true;
+    } else if (meet) {
+      reorder_list.push_back(loop);
+    }
+  }
+  reorder_list.insert(reorder_list.end(), reorder_suffix.begin(), reorder_suffix.end());
+  sch->Reorder(reorder_list);
+  ICHECK(!reorder_suffix.empty());
+  return reorder_suffix[0];
+}
+
+TVM_REGISTER_GLOBAL("tir.schedule.TileWithTensorIntrin").set_body_typed(TileWithTensorIntrin);
+
 }  // namespace tir
 }  // namespace tvm
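
For the dense VNNI workload exercised in the tests below, the helper amounts to the
following manual schedule; this is a sketch, where the inner factors 16 and 4 are the
extents of the corresponding loops in the VNNI dot-product description and
DenseVNNIModule comes from the new test file:

    from tvm.tir import Schedule

    sch = Schedule(DenseVNNIModule)
    block = sch.get_block("compute")
    i, j, k = sch.get_loops(block)
    # Split each mapped loop so that its inner extent equals the matching desc loop extent.
    j0, j1 = sch.split(j, factors=[64, 16])  # 1024 = 64 * 16
    k0, k1 = sch.split(k, factors=[256, 4])  # 1024 = 256 * 4
    # Move the inner splits innermost, keeping the order of the description loops.
    sch.reorder(k0, j1, k1)
    # j1 is the loop that TileWithTensorIntrin returns for this module.
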
diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h
index 3932c4bdbd..12326b3418 100644
--- a/src/tir/schedule/transform.h
+++ b/src/tir/schedule/transform.h
@@ -19,6 +19,7 @@
 #ifndef TVM_TIR_SCHEDULE_TRANSFORM_H_
 #define TVM_TIR_SCHEDULE_TRANSFORM_H_
 
+#include <tvm/tir/schedule/schedule.h>
 #include <tvm/tir/schedule/state.h>
 
 namespace tvm {
@@ -104,6 +105,18 @@ Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, c
 void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_sref,
                           Stmt* src_stmt, Stmt* tgt_stmt);
 
+/*!
+ * \brief Tile a subset of loops in the block according to the given tensor intrinsic.
+ * \param sch The schedule to which tiling is applied
+ * \param block_rv The block whose subset of loops will be tiled
+ * \param intrin_name The name of a tensor intrinsic; it must be registered via
+ * TensorIntrin.register(...) beforehand
+ * \return LoopRV corresponding to the outermost loop of the block tiled according to the given
+ * intrin, or NullOpt if no valid loop mapping is found
+ */
+Optional<tir::LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
+                                           const String& intrin_name);
+
 }  // namespace tir
 }  // namespace tvm
 
diff --git a/tests/python/unittest/test_tir_schedule_transform.py b/tests/python/unittest/test_tir_schedule_transform.py
new file mode 100644
index 0000000000..6dfd4315ec
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_transform.py
@@ -0,0 +1,181 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN
+
+from tvm.tir import Schedule
+from tvm.script import tir as T
+from tvm.tir.schedule.transform import tile_with_tensor_intrin
+
+
+@tvm.script.ir_module
+class DenseVNNIModule:
+    @T.prim_func
+    def main(
+        placeholder: T.Buffer[(1024, 1024), "uint8"],
+        placeholder_1: T.Buffer[(64, 256, 16, 4), "int8"],
+        compute: T.Buffer[(1024, 1024), "int32"],
+    ) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            for i0, i1, i2 in T.grid(1024, 1024, 1024):
+                with T.block("compute"):
+                    i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+                    T.reads(placeholder[i, k], placeholder_1[j // 16, k // 4, j % 16, k % 4])
+                    T.writes(compute[i, j])
+                    with T.init():
+                        compute[i, j] = 0
+                    compute[i, j] = compute[i, j] + T.cast(placeholder[i, k], "int32") * T.cast(
+                        placeholder_1[j // 16, k // 4, j % 16, k % 4], "int32"
+                    )
+
+
+@tvm.script.ir_module
+class DenseVNNIModuleTiled:
+    @T.prim_func
+    def main(
+        placeholder: T.Buffer[(1024, 1024), "uint8"],
+        placeholder_1: T.Buffer[(64, 256, 16, 4), "int8"],
+        compute: T.Buffer[(1024, 1024), "int32"],
+    ) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        for i0, i1_0, i2_0, i1_1, i2_1 in T.grid(1024, 64, 256, 16, 4):
+            with T.block("compute"):
+                i = T.axis.spatial(1024, i0)
+                j = T.axis.spatial(1024, i1_0 * 16 + i1_1)
+                k = T.axis.reduce(1024, i2_0 * 4 + i2_1)
+                T.reads(placeholder[i, k], placeholder_1[j // 16, k // 4, j % 16, k % 4])
+                T.writes(compute[i, j])
+                with T.init():
+                    compute[i, j] = 0
+                compute[i, j] = compute[i, j] + T.cast(placeholder[i, k], "int32") * T.cast(
+                    placeholder_1[j // 16, k // 4, j % 16, k % 4], "int32"
+                )
+
+
+@tvm.script.ir_module
+class Conv2dNCHWcVNNIModule:
+    @T.prim_func
+    def main(
+        placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
+        placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
+        conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
+    ) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4):
+            with T.block("conv2d_NCHWc_int8"):
+                (
+                    n,
+                    oc_chunk,
+                    oh,
+                    ow,
+                    oc_block,
+                    kh,
+                    kw,
+                    ic_outer,
+                    ic_f_inner,
+                    ic_s_inner,
+                ) = T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9])
+                T.reads(
+                    placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner],
+                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
+                )
+                T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block])
+                with T.init():
+                    conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0
+                conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[
+                    n, oc_chunk, oh, ow, oc_block
+                ] + T.cast(
+                    placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32"
+                ) * T.cast(
+                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
+                    "int32",
+                )
+
+
+@tvm.script.ir_module
+class Conv2dNCHWcVNNIModuleTiled:
+    @T.prim_func
+    def main(
+        placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
+        placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
+        conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
+    ) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        for i0, i1, i2, i3, i4_0, i5, i6, i7, i8, i9_0, i4_1, i9_1 in T.grid(
+            1, 16, 56, 56, 1, 1, 1, 4, 4, 1, 16, 4
+        ):
+            with T.block("conv2d_NCHWc_int8"):
+                n = T.axis.spatial(1, 0)
+                oc_chunk, oh, ow, oc_block = T.axis.remap("SSSS", [i1, i2, i3, i4_1])
+                kh = T.axis.reduce(1, 0)
+                kw = T.axis.reduce(1, 0)
+                ic_outer, ic_f_inner, ic_s_inner = T.axis.remap("RRR", [i7, i8, i9_1])
+                T.reads(
+                    placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner],
+                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
+                )
+                T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block])
+                with T.init():
+                    conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0
+                conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[
+                    n, oc_chunk, oh, ow, oc_block
+                ] + T.cast(
+                    placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32"
+                ) * T.cast(
+                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
+                    "int32",
+                )
+
+
+def test_tile_with_tensor_intrin_dense_vnni():
+    s = Schedule(DenseVNNIModule)
+    block = s.get_block("compute")
+
+    tiled_loop = tile_with_tensor_intrin(s, block, VNNI_DOT_16x4_INTRIN)
+
+    _, _, _, i1_1, _ = s.get_loops(block)
+
+    assert s.get(tiled_loop) == s.get(i1_1)
+    tvm.ir.assert_structural_equal(s.mod, DenseVNNIModuleTiled)
+
+
+def test_tile_with_tensor_intrin_conv2d_nchwc_vnni():
+    s = Schedule(Conv2dNCHWcVNNIModule)
+    block = s.get_block("conv2d_NCHWc_int8")
+
+    tiled_loop = tile_with_tensor_intrin(s, block, VNNI_DOT_16x4_INTRIN)
+
+    tiled_loops = s.get_loops(block)
+
+    assert len(tiled_loops) == 12
+    assert s.get(tiled_loop) == s.get(tiled_loops[-2])
+
+    tvm.ir.assert_structural_equal(s.mod, Conv2dNCHWcVNNIModuleTiled)
+
+
+if __name__ == "__main__":
+    test_tile_with_tensor_intrin_dense_vnni()
+    test_tile_with_tensor_intrin_conv2d_nchwc_vnni()
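
The new tests only rewrite TIR and compare the result with tvm.ir.assert_structural_equal;
nothing is compiled or executed, so a VNNI-capable CPU should not be needed. They can be run
directly with python tests/python/unittest/test_tir_schedule_transform.py or through pytest.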