You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/09/16 06:29:25 UTC
[tvm] branch main updated: [TVMScript] IRBuilder methods for `Axis` (#12808)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c0d2734056 [TVMScript] IRBuilder methods for `Axis` (#12808)
c0d2734056 is described below
commit c0d2734056d4d4bfc67a125b4e61194a809f22d5
Author: Yaxing Cai <ca...@gmail.com>
AuthorDate: Thu Sep 15 23:29:17 2022 -0700
[TVMScript] IRBuilder methods for `Axis` (#12808)
This PR introduces remaining IRBuilder methods for `Axis`.
Co-authored-by: yongwww <yo...@gmail.com>
---
include/tvm/script/ir_builder/tir/ir.h | 49 +++++++
python/tvm/script/ir_builder/tir/ir.py | 157 ++++++++++++++++++++-
src/script/ir_builder/tir/ir.cc | 86 +++++++++++
.../unittest/test_tvmscript_ir_builder_tir.py | 43 ++++++
4 files changed, 334 insertions(+), 1 deletion(-)
diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index 68948196ff..037606253a 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -141,6 +141,55 @@ void PreflattenedBuffer(Buffer postflattened_buffer, Array<PrimExpr> shape,
*/
BlockFrame Block(String name, bool no_realize = false);
+namespace axis {
+
+/*!
+ * \brief The spatial block axis defining function.
+ * \param dom The domain of the iteration variable.
+ * \param binding The binding value of the iteration variable.
+ * \param dtype The data type of the iteration variable.
+ * \return The iteration variable.
+ */
+Var Spatial(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+
+/*!
+ * \brief The reduced block axis defining function.
+ * \param dom The domain of the iteration variable.
+ * \param binding The binding value of the iteration variable.
+ * \param dtype The data type of the iteration variable.
+ * \return The iteration variable.
+ */
+Var Reduce(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+
+/*!
+ * \brief The scanning block axis defining function.
+ * \param dom The domain of the iteration variable.
+ * \param binding The binding value of the iteration variable.
+ * \param dtype The data type of the iteration variable.
+ * \return The iteration variable.
+ */
+Var Scan(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+
+/*!
+ * \brief The opaque block axis defining function.
+ * \param dom The domain of the iteration variable.
+ * \param binding The binding value of the iteration variable.
+ * \param dtype The data type of the iteration variable.
+ * \return The iteration variable.
+ */
+Var Opaque(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));
+
+/*!
+ * \brief The block axis remapping function.
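+ * \param kinds The types of the iteration variables, one character per binding: 'S' (spatial) or 'R' (reduce).
+ * \param bindings The binding values of the iteration variables.
+ * \param dtype The data types of the iteration variables (the implementation uses each binding's dtype instead).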
+ * \return The iteration variables.
+ */
+Array<Var> Remap(String kinds, Array<PrimExpr> bindings, DataType dtype = DataType::Int(32));
+
+} // namespace axis
+
/*!
* \brief The serial For statement.
* \param start The minimum value of iteration.
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index a5cdf8a3a1..40cd99c744 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -20,7 +20,7 @@
from numbers import Integral
from typing import Any, Dict, List, Optional, Union, Tuple
-from tvm.ir import Type
+from tvm.ir import Range, Type
from tvm.tir import (
Buffer,
BufferLoad,
@@ -344,6 +344,160 @@ def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame:
return _ffi_api.Block(name, no_realize) # pylint: disable=no-member # type: ignore
+def _as_range(dom: Union[Range, List[PrimExpr]]) -> Range:
+ """The range constructor.
+
+ Parameters
+ ----------
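+ dom : Union[Range, List[PrimExpr], Tuple[PrimExpr], PrimExpr]
+ The domain: a Range (returned as is), a list/tuple converted to Range(dom[0], dom[1]), or a single value converted to Range(0, dom).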
+
+ Returns
+ -------
+ res : Range
+ The Range.
+ """
+ if isinstance(dom, Range):
+ return dom
+ if isinstance(dom, (list, tuple)):
+ return Range(dom[0], dom[1])
+ return Range(0, dom)
+
+
+class axis: # pylint: disable=invalid-name
+ @staticmethod
+ def spatial(
+ dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32"
+ ) -> Var:
+ """The spatial block axis defining function.
+
+ Parameters
+ ----------
+ dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]]
+ The domain of the iteration variable.
+
+ binding : PrimExpr
+ The binding value of the iteration variable.
+
+ dtype : str
+ The data type of the iteration variable.
+
+ Returns
+ -------
+ res : Var
+ The iteration variable.
+ """
+ return _ffi_api.AxisSpatial( # pylint: disable=no-member # type: ignore
+ _as_range(dom), binding, dtype
+ )
+
+ @staticmethod
+ def reduce(
+ dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32"
+ ) -> Var:
+ """The reduced block axis defining function.
+
+ Parameters
+ ----------
+ dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]]
+ The domain of the iteration variable.
+
+ binding : PrimExpr
+ The binding value of the iteration variable.
+
+ dtype : str
+ The data type of the iteration variable.
+
+ Returns
+ -------
+ res : Var
+ The iteration variable.
+ """
+ return _ffi_api.AxisReduce( # pylint: disable=no-member # type: ignore
+ _as_range(dom), binding, dtype
+ )
+
+ @staticmethod
+ def scan(
+ dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32"
+ ) -> Var:
+ """The scanning block axis defining function.
+
+ Parameters
+ ----------
+ dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]]
+ The domain of the iteration variable.
+
+ binding : PrimExpr
+ The binding value of the iteration variable.
+
+ dtype : str
+ The data type of the iteration variable.
+
+ Returns
+ -------
+ res : Var
+ The iteration variable.
+ """
+ return _ffi_api.AxisScan( # pylint: disable=no-member # type: ignore
+ _as_range(dom), binding, dtype
+ )
+
+ @staticmethod
+ def opaque(
+ dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32"
+ ) -> Var:
+ """The opaque block axis defining function.
+
+ Parameters
+ ----------
+ dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]]
+ The domain of the iteration variable.
+
+ binding : PrimExpr
+ The binding value of the iteration variable.
+
+ dtype : str
+ The data type of the iteration variable.
+
+ Returns
+ -------
+ res : Var
+ The iteration variable.
+ """
+ return _ffi_api.AxisOpaque( # pylint: disable=no-member # type: ignore
+ _as_range(dom), binding, dtype
+ )
+
+ @staticmethod
+ def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> Union[List[Var], Var]:
+ """The block axis remapping function.
+
+ Parameters
+ ----------
+ kinds : str
+ The types of the iteration variables, one character per binding: "S" (spatial) or "R" (reduce).
+
+ bindings : List[PrimExpr]
+ The binding values of the iteration variables.
+
+ dtype : str
+ The data types of the iteration variables (the C++ implementation uses each binding's dtype instead).
+
+ Returns
+ -------
+ res : Union[List[Var], Var]
+ The iteration variables; a single Var is returned when there is exactly one.
+ """
+ iter_vars = _ffi_api.AxisRemap( # pylint: disable=no-member # type: ignore
+ kinds, bindings, dtype
+ )
+ return iter_vars[0] if len(iter_vars) == 1 else iter_vars
+
+ S = spatial # pylint: disable=invalid-name
+ R = reduce # pylint: disable=invalid-name
+
+
def serial(
start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
) -> frame.ForFrame:
@@ -843,6 +997,7 @@ __all__ = [
"match_buffer",
"preflattened_buffer",
"block",
+ "axis",
"serial",
"parallel",
"vectorized",
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 22c7face70..5013e32172 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -173,6 +173,86 @@ BlockFrame Block(String name, bool no_realize) {
return BlockFrame(n);
}
+namespace axis {
+
+IterVar PushBlockVar(IterVar iter_var, PrimExpr binding) {
+ if (Optional<BlockFrame> opt_frame = IRBuilder::Current()->GetLastFrame<BlockFrame>()) {
+ BlockFrame frame = opt_frame.value();
+ frame->iter_vars.push_back(iter_var);
+ frame->iter_values.push_back(binding);
+ } else {
+ LOG(FATAL) << "TypeError: The last frame is not BlockFrame";
+ }
+ return iter_var;
+}
+
+#define TVM_TIR_IR_BUILDER_AXIS(Method, Kind, Name) \
+ Var Method(Range dom, PrimExpr binding, DataType dtype) { \
+ ICHECK(dom.defined()) << Name << " axis must have a domain"; \
+ int bits = std::max({dom->min.dtype().bits(), dom->extent.dtype().bits(), dtype.bits()}); \
+ return PushBlockVar(IterVar(/*dom=*/dom, /*var=*/Var("", dtype.with_bits(bits)), \
+ /*iter_type=*/Kind, /*thread_tag=*/""), \
+ binding) \
+ ->var; \
+ }
+TVM_TIR_IR_BUILDER_AXIS(Spatial, tvm::tir::IterVarType::kDataPar, "Spatial");
+TVM_TIR_IR_BUILDER_AXIS(Reduce, tvm::tir::IterVarType::kCommReduce, "Reduction");
+TVM_TIR_IR_BUILDER_AXIS(Scan, tvm::tir::IterVarType::kOrdered, "Scan");
+TVM_TIR_IR_BUILDER_AXIS(Opaque, tvm::tir::IterVarType::kOpaque, "Opaque");
+#undef TVM_TIR_IR_BUILDER_AXIS
+
+Array<Var> Remap(String kinds, Array<PrimExpr> bindings, DataType dtype) {
+ using namespace tvm::tir;
+ Array<Var> results;
+ ICHECK_EQ(kinds.size(), bindings.size());
+ int n = bindings.size();
+ results.reserve(n);
+ for (int i = 0; i < n; ++i) {
+ char c = kinds.c_str()[i];
+ PrimExpr e = bindings[i];
+ const VarNode* v = e.as<VarNode>();
+ ICHECK(v) << "TypeError: Only Var is supported in T.axis.remap";
+ Range dom{nullptr};
+ for (const auto& frame : IRBuilder::Current()->frames) {
+ if (const auto* for_frame = frame.as<ForFrameNode>()) {
+ ICHECK_EQ(for_frame->doms.size(), for_frame->vars.size());
+ int n = for_frame->doms.size();
+ for (int i = 0; i < n; ++i) {
+ if (for_frame->vars[i].get() == v) {
+ dom = for_frame->doms[i];
+ break;
+ }
+ }
+ if (dom.defined()) {
+ break;
+ }
+ }
+ }
+ ICHECK(dom.defined()) << "TypeError: Variable is not in the loop: " << GetRef<Var>(v);
+ DataType dtype = v->dtype;
+ if (c == 'S') {
+ results.push_back(PushBlockVar(IterVar(/*dom=*/dom,
+ /*var=*/Var("", dtype),
+ /*iter_type=*/IterVarType::kDataPar,
+ /*thread_tag=*/""),
+ e)
+ ->var);
+ } else if (c == 'R') {
+ results.push_back(PushBlockVar(IterVar(/*dom=*/dom,
+ /*var=*/Var("", dtype),
+ /*iter_type=*/IterVarType::kCommReduce,
+ /*thread_tag=*/""),
+ e)
+ ->var);
+ } else {
+ LOG(FATAL) << "Unknown axis kind: " << c;
+ }
+ }
+ return results;
+}
+
+} // namespace axis
+
#define TVM_TIR_IR_BUILDER_FOR_FRAME(Method, Kind) \
ForFrame Method(PrimExpr start, PrimExpr stop, Optional<Map<String, ObjectRef>> annotations) { \
PrimExpr min = start; \
@@ -304,6 +384,12 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.PreflattenedBuffer").set_body_typed(P
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Block").set_body_typed(Block);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisSpatial").set_body_typed(axis::Spatial);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisReduce").set_body_typed(axis::Reduce);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisScan").set_body_typed(axis::Scan);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisOpaque").set_body_typed(axis::Opaque);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisRemap").set_body_typed(axis::Remap);
+
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Serial").set_body_typed(Serial);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Parallel").set_body_typed(Parallel);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Vectorized").set_body_typed(Vectorized);
diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
index 9cbfd75e22..d893ebc545 100644
--- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
@@ -114,6 +114,49 @@ def test_ir_builder_tir_block():
assert_structural_equal(block_realize_actual, block_realize_expected, map_free_vars=True)
+def test_ir_builder_tir_axis():
+ with IRBuilder() as ib:
+ a = T.var("int32", "a")
+ b = T.var("int32", "b")
+ c = T.var("int32", "c")
+ d = T.var("int32", "d")
+ with T.block("block"):
+ T.axis.spatial(8, a)
+ T.axis.reduce(16, b)
+ T.axis.scan(32, c)
+ T.axis.opaque(64, d)
+ T.evaluate(0)
+
+ # the block generated by IRBuilder
+ block_realize_actual = ib.get()
+
+ # the expected block
+ var_a = tir.Var("a", "int32")
+ var_b = tir.Var("b", "int32")
+ var_c = tir.Var("c", "int32")
+ var_d = tir.Var("d", "int32")
+ block_expected = tir.Block(
+ iter_vars=[
+ tir.IterVar((0, 8), tir.Var("", "int32"), iter_type=tir.IterVar.DataPar),
+ tir.IterVar((0, 16), tir.Var("", "int32"), iter_type=tir.IterVar.CommReduce),
+ tir.IterVar((0, 32), tir.Var("", "int32"), iter_type=tir.IterVar.Ordered),
+ tir.IterVar((0, 64), tir.Var("", "int32"), iter_type=tir.IterVar.DimInfo),
+ ],
+ reads=[],
+ writes=[],
+ name_hint="block",
+ body=tir.Evaluate(0),
+ annotations={"tir.script_parsing_detect_access": tir.IntImm("int64", 3)},
+ )
+ block_realize_expected = tir.BlockRealize(
+ iter_values=[var_a, var_b, var_c, var_d],
+ predicate=True,
+ block=block_expected,
+ )
+ # Check if the generated ir is expected
+ assert_structural_equal(block_realize_actual, block_realize_expected, map_free_vars=True)
+
+
def test_ir_builder_tir_for():
with IRBuilder() as ib:
with T.serial(128) as a: