You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/09/16 06:29:25 UTC

[tvm] branch main updated: [TVMScript] IRBuilder methods for `Axis` (#12808)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c0d2734056 [TVMScript] IRBuilder methods for `Axis` (#12808)
c0d2734056 is described below

commit c0d2734056d4d4bfc67a125b4e61194a809f22d5
Author: Yaxing Cai <ca...@gmail.com>
AuthorDate: Thu Sep 15 23:29:17 2022 -0700

    [TVMScript] IRBuilder methods for `Axis` (#12808)
    
    This PR introduces remaining IRBuilder methods for `Axis`.
    
    Co-authored-by: yongwww <yo...@gmail.com>
---
 include/tvm/script/ir_builder/tir/ir.h             |  49 +++++++
 python/tvm/script/ir_builder/tir/ir.py             | 157 ++++++++++++++++++++-
 src/script/ir_builder/tir/ir.cc                    |  86 +++++++++++
 .../unittest/test_tvmscript_ir_builder_tir.py      |  43 ++++++
 4 files changed, 334 insertions(+), 1 deletion(-)

diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index 68948196ff..037606253a 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -141,6 +141,55 @@ void PreflattenedBuffer(Buffer postflattened_buffer, Array<PrimExpr> shape,
  */
 BlockFrame Block(String name, bool no_realize = false);
 
+namespace axis {
+
+/*!
+ * \brief The spatial block axis defining function.
+ * \param dom The domain of the iteration variable. Must be defined.
+ * \param binding The binding value of the iteration variable.
+ * \param dtype The data type of the iteration variable. The implementation widens it to the
+ *        largest bit width found among `dom->min`, `dom->extent` and `dtype`.
+ * \return The iteration variable.
+ * \note Fails with a fatal error when `GetLastFrame<BlockFrame>()` of the current IRBuilder
+ *       does not yield a frame.
+ */
+Var Spatial(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+
+/*!
+ * \brief The reduced block axis defining function.
+ * \param dom The domain of the iteration variable. Must be defined.
+ * \param binding The binding value of the iteration variable.
+ * \param dtype The data type of the iteration variable. The implementation widens it to the
+ *        largest bit width found among `dom->min`, `dom->extent` and `dtype`.
+ * \return The iteration variable.
+ * \note Fails with a fatal error when `GetLastFrame<BlockFrame>()` of the current IRBuilder
+ *       does not yield a frame.
+ */
+Var Reduce(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+
+/*!
+ * \brief The scanning block axis defining function.
+ * \param dom The domain of the iteration variable. Must be defined.
+ * \param binding The binding value of the iteration variable.
+ * \param dtype The data type of the iteration variable. The implementation widens it to the
+ *        largest bit width found among `dom->min`, `dom->extent` and `dtype`.
+ * \return The iteration variable.
+ * \note Fails with a fatal error when `GetLastFrame<BlockFrame>()` of the current IRBuilder
+ *       does not yield a frame.
+ */
+Var Scan(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+
+/*!
+ * \brief The opaque block axis defining function.
+ * \param dom The domain of the iteration variable. Must be defined.
+ * \param binding The binding value of the iteration variable.
+ * \param dtype The data type of the iteration variable. The implementation widens it to the
+ *        largest bit width found among `dom->min`, `dom->extent` and `dtype`.
+ * \return The iteration variable.
+ * \note Fails with a fatal error when `GetLastFrame<BlockFrame>()` of the current IRBuilder
+ *       does not yield a frame.
+ */
+Var Opaque(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+
+/*!
+ * \brief The block axis remapping function.
+ * \param kinds The types of the iteration variables, one character per binding:
+ *        'S' creates a data-parallel (spatial) axis and 'R' a commutative-reduce axis.
+ *        Any other character is a fatal error. Its length must equal `bindings.size()`.
+ * \param bindings The binding values of the iteration variables. Each one must be a Var that
+ *        is a loop variable of a ForFrame in the current IRBuilder; the domain of the new
+ *        axis is taken from that loop.
+ * \param dtype The data types of the iteration variables.
+ *        NOTE(review): the implementation uses the dtype of each binding Var instead and does
+ *        not appear to read this parameter -- confirm whether that is intended.
+ * \return The iteration variables, in the order of `bindings`.
+ */
+Array<Var> Remap(String kinds, Array<PrimExpr> bindings, DataType dtype = DataType::Int(32));
+
+}  // namespace axis
+
 /*!
  * \brief The serial For statement.
  * \param start The minimum value of iteration.
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index a5cdf8a3a1..40cd99c744 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -20,7 +20,7 @@
 from numbers import Integral
 from typing import Any, Dict, List, Optional, Union, Tuple
 
-from tvm.ir import Type
+from tvm.ir import Range, Type
 from tvm.tir import (
     Buffer,
     BufferLoad,
@@ -344,6 +344,160 @@ def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame:
     return _ffi_api.Block(name, no_realize)  # pylint: disable=no-member # type: ignore
 
 
+def _as_range(dom: Union[Range, List[PrimExpr]]) -> Range:
+    """The range constructor.
+
+    Parameters
+    ----------
+    dom : Union[Range, List[PrimExpr]]
+        The domain.
+
+    Returns
+    -------
+    res : Range
+        The Range.
+    """
+    if isinstance(dom, Range):
+        return dom
+    if isinstance(dom, (list, tuple)):
+        return Range(dom[0], dom[1])
+    return Range(0, dom)
+
+
+class axis:  # pylint: disable=invalid-name
+    @staticmethod
+    def spatial(
+        dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32"
+    ) -> Var:
+        """The spatial block axis defining function.
+
+        Parameters
+        ----------
+        dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]]
+            The domain of the iteration variable.
+
+        binding : PrimExpr
+            The binding value of the iteration variable.
+
+        dtype : str
+            The data type of the iteration variable.
+
+        Returns
+        -------
+        res : Var
+            The iteration variable.
+        """
+        return _ffi_api.AxisSpatial(  # pylint: disable=no-member # type: ignore
+            _as_range(dom), binding, dtype
+        )
+
+    @staticmethod
+    def reduce(
+        dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32"
+    ) -> Var:
+        """The reduced block axis defining function.
+
+        Parameters
+        ----------
+        dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]]
+            The domain of the iteration variable.
+
+        binding : PrimExpr
+            The binding value of the iteration variable.
+
+        dtype : str
+            The data type of the iteration variable.
+
+        Returns
+        -------
+        res : Var
+            The iteration variable.
+        """
+        return _ffi_api.AxisReduce(  # pylint: disable=no-member # type: ignore
+            _as_range(dom), binding, dtype
+        )
+
+    @staticmethod
+    def scan(
+        dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32"
+    ) -> Var:
+        """The scanning block axis defining function.
+
+        Parameters
+        ----------
+        dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]]
+            The domain of the iteration variable.
+
+        binding : PrimExpr
+            The binding value of the iteration variable.
+
+        dtype : str
+            The data type of the iteration variable.
+
+        Returns
+        -------
+        res : Var
+            The iteration variable.
+        """
+        return _ffi_api.AxisScan(  # pylint: disable=no-member # type: ignore
+            _as_range(dom), binding, dtype
+        )
+
+    @staticmethod
+    def opaque(
+        dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32"
+    ) -> Var:
+        """The opaque block axis defining function.
+
+        Parameters
+        ----------
+        dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]]
+            The domain of the iteration variable.
+
+        binding : PrimExpr
+            The binding value of the iteration variable.
+
+        dtype : str
+            The data type of the iteration variable.
+
+        Returns
+        -------
+        res : Var
+            The iteration variable.
+        """
+        return _ffi_api.AxisOpaque(  # pylint: disable=no-member # type: ignore
+            _as_range(dom), binding, dtype
+        )
+
+    @staticmethod
+    def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> Union[List[Var], Var]:
+        """The block axis remapping function.
+
+        Parameters
+        ----------
+        kinds : str
+            The types of the iteration variables.
+
+        bindings : List[PrimExpr]
+            The binding values of the iteration variables.
+
+        dtype : str
+            The data types of the iteration variables.
+
+        Returns
+        -------
+        res : Var
+            The iteration variables.
+        """
+        iter_vars = _ffi_api.AxisRemap(  # pylint: disable=no-member # type: ignore
+            kinds, bindings, dtype
+        )
+        return iter_vars[0] if len(iter_vars) == 1 else iter_vars
+
+    S = spatial  # pylint: disable=invalid-name
+    R = reduce  # pylint: disable=invalid-name
+
+
 def serial(
     start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
 ) -> frame.ForFrame:
@@ -843,6 +997,7 @@ __all__ = [
     "match_buffer",
     "preflattened_buffer",
     "block",
+    "axis",
     "serial",
     "parallel",
     "vectorized",
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 22c7face70..5013e32172 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -173,6 +173,86 @@ BlockFrame Block(String name, bool no_realize) {
   return BlockFrame(n);
 }
 
+namespace axis {
+
+/*!
+ * \brief Record a block iteration variable and its binding on the current BlockFrame.
+ * \param iter_var The iteration variable to append to the frame's `iter_vars`.
+ * \param binding The value appended to the frame's `iter_values`.
+ * \return The same `iter_var`, so that callers can chain `->var`.
+ */
+IterVar PushBlockVar(IterVar iter_var, PrimExpr binding) {
+  Optional<BlockFrame> opt_frame = IRBuilder::Current()->GetLastFrame<BlockFrame>();
+  if (!opt_frame.defined()) {
+    LOG(FATAL) << "TypeError: The last frame is not BlockFrame";
+  }
+  BlockFrame block_frame = opt_frame.value();
+  block_frame->iter_vars.push_back(iter_var);
+  block_frame->iter_values.push_back(binding);
+  return iter_var;
+}
+
+// Generates one axis-defining function per iteration kind. The variable is given the widest
+// bit width found among the domain's min, the domain's extent and the requested dtype.
+#define TVM_TIR_IR_BUILDER_AXIS(Method, Kind, Name)                            \
+  Var Method(Range dom, PrimExpr binding, DataType dtype) {                    \
+    ICHECK(dom.defined()) << Name << " axis must have a domain";               \
+    const int min_bits = dom->min.dtype().bits();                              \
+    const int extent_bits = dom->extent.dtype().bits();                        \
+    const int bits = std::max({min_bits, extent_bits, dtype.bits()});          \
+    IterVar iter_var(/*dom=*/dom, /*var=*/Var("", dtype.with_bits(bits)),      \
+                     /*iter_type=*/Kind, /*thread_tag=*/"");                   \
+    return PushBlockVar(iter_var, binding)->var;                               \
+  }
+TVM_TIR_IR_BUILDER_AXIS(Spatial, tvm::tir::IterVarType::kDataPar, "Spatial");
+TVM_TIR_IR_BUILDER_AXIS(Reduce, tvm::tir::IterVarType::kCommReduce, "Reduction");
+TVM_TIR_IR_BUILDER_AXIS(Scan, tvm::tir::IterVarType::kOrdered, "Scan");
+TVM_TIR_IR_BUILDER_AXIS(Opaque, tvm::tir::IterVarType::kOpaque, "Opaque");
+#undef TVM_TIR_IR_BUILDER_AXIS
+
+/*!
+ * \brief Create one block axis per binding, reusing the domain of the loop that defines it.
+ * \param kinds One character per binding: 'S' (data-parallel) or 'R' (commutative reduce).
+ * \param bindings Loop variables; each must be a Var owned by a ForFrame of the current builder.
+ * \param dtype Not consulted: every new axis variable takes the dtype of its binding Var.
+ * \return The newly created axis variables, in the order of `bindings`.
+ */
+Array<Var> Remap(String kinds, Array<PrimExpr> bindings, DataType dtype) {
+  using namespace tvm::tir;
+  ICHECK_EQ(kinds.size(), bindings.size());
+  const int num_axes = bindings.size();
+  const char* kind_chars = kinds.c_str();
+  Array<Var> results;
+  results.reserve(num_axes);
+  for (int axis_idx = 0; axis_idx < num_axes; ++axis_idx) {
+    const char kind = kind_chars[axis_idx];
+    PrimExpr binding = bindings[axis_idx];
+    const VarNode* loop_var = binding.as<VarNode>();
+    ICHECK(loop_var) << "TypeError: Only Var is supported in T.axis.remap";
+    // Search the builder's frames for the loop that owns `loop_var` and borrow its domain.
+    Range dom{nullptr};
+    for (const auto& frame : IRBuilder::Current()->frames) {
+      const auto* for_frame = frame.as<ForFrameNode>();
+      if (for_frame == nullptr) {
+        continue;
+      }
+      ICHECK_EQ(for_frame->doms.size(), for_frame->vars.size());
+      const int num_loops = for_frame->doms.size();
+      for (int loop_idx = 0; loop_idx < num_loops; ++loop_idx) {
+        if (for_frame->vars[loop_idx].get() == loop_var) {
+          dom = for_frame->doms[loop_idx];
+          break;
+        }
+      }
+      if (dom.defined()) {
+        break;
+      }
+    }
+    ICHECK(dom.defined()) << "TypeError: Variable is not in the loop: " << GetRef<Var>(loop_var);
+    const DataType var_dtype = loop_var->dtype;
+    // Only spatial and reduction axes can be remapped; anything else is rejected.
+    IterVarType iter_type = IterVarType::kDataPar;
+    if (kind == 'R') {
+      iter_type = IterVarType::kCommReduce;
+    } else if (kind != 'S') {
+      LOG(FATAL) << "Unknown axis kind: " << kind;
+    }
+    IterVar iter_var(/*dom=*/dom, /*var=*/Var("", var_dtype), /*iter_type=*/iter_type,
+                     /*thread_tag=*/"");
+    results.push_back(PushBlockVar(iter_var, binding)->var);
+  }
+  return results;
+}
+
+}  // namespace axis
+
 #define TVM_TIR_IR_BUILDER_FOR_FRAME(Method, Kind)                                                \
   ForFrame Method(PrimExpr start, PrimExpr stop, Optional<Map<String, ObjectRef>> annotations) {  \
     PrimExpr min = start;                                                                         \
@@ -304,6 +384,12 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.PreflattenedBuffer").set_body_typed(P
 
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Block").set_body_typed(Block);
 
+// Register the `axis` builders as global functions under `script.ir_builder.tir.Axis*`.
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisSpatial").set_body_typed(axis::Spatial);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisReduce").set_body_typed(axis::Reduce);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisScan").set_body_typed(axis::Scan);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisOpaque").set_body_typed(axis::Opaque);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisRemap").set_body_typed(axis::Remap);
+
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Serial").set_body_typed(Serial);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Parallel").set_body_typed(Parallel);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Vectorized").set_body_typed(Vectorized);
diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
index 9cbfd75e22..d893ebc545 100644
--- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
@@ -114,6 +114,49 @@ def test_ir_builder_tir_block():
     assert_structural_equal(block_realize_actual, block_realize_expected, map_free_vars=True)
 
 
+def test_ir_builder_tir_axis():
+    with IRBuilder() as ib:
+        a = T.var("int32", "a")
+        b = T.var("int32", "b")
+        c = T.var("int32", "c")
+        d = T.var("int32", "d")
+        with T.block("block"):
+            T.axis.spatial(8, a)
+            T.axis.reduce(16, b)
+            T.axis.scan(32, c)
+            T.axis.opaque(64, d)
+            T.evaluate(0)
+
+    # the block generated by IRBuilder
+    block_realize_actual = ib.get()
+
+    # the expected block
+    var_a = tir.Var("a", "int32")
+    var_b = tir.Var("b", "int32")
+    var_c = tir.Var("c", "int32")
+    var_d = tir.Var("d", "int32")
+    block_expected = tir.Block(
+        iter_vars=[
+            tir.IterVar((0, 8), tir.Var("", "int32"), iter_type=tir.IterVar.DataPar),
+            tir.IterVar((0, 16), tir.Var("", "int32"), iter_type=tir.IterVar.CommReduce),
+            tir.IterVar((0, 32), tir.Var("", "int32"), iter_type=tir.IterVar.Ordered),
+            tir.IterVar((0, 64), tir.Var("", "int32"), iter_type=tir.IterVar.DimInfo),
+        ],
+        reads=[],
+        writes=[],
+        name_hint="block",
+        body=tir.Evaluate(0),
+        annotations={"tir.script_parsing_detect_access": tir.IntImm("int64", 3)},
+    )
+    block_realize_expected = tir.BlockRealize(
+        iter_values=[var_a, var_b, var_c, var_d],
+        predicate=True,
+        block=block_expected,
+    )
+    # Check if the generated ir is expected
+    assert_structural_equal(block_realize_actual, block_realize_expected, map_free_vars=True)
+
+
 def test_ir_builder_tir_for():
     with IRBuilder() as ib:
         with T.serial(128) as a: