You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/12/28 11:11:11 UTC

[tvm] branch main updated: [TIR] Create Layout with specified axis dtype (#13663)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d6507b256f [TIR] Create Layout with specified axis dtype (#13663)
d6507b256f is described below

commit d6507b256f2f133d2acc187f1740ebe5c082f914
Author: Ruihang Lai <ru...@cs.cmu.edu>
AuthorDate: Wed Dec 28 06:11:06 2022 -0500

    [TIR] Create Layout with specified axis dtype (#13663)
---
 include/tvm/tir/data_layout.h                 |  4 +++-
 python/tvm/tir/data_layout.py                 |  8 ++++++--
 src/tir/ir/data_layout.cc                     | 15 +++++++++------
 tests/python/unittest/test_tir_data_layout.py | 27 ++++++++++++++++++++++++++-
 4 files changed, 44 insertions(+), 10 deletions(-)

diff --git a/include/tvm/tir/data_layout.h b/include/tvm/tir/data_layout.h
index 81c3e98e66..7aefef6e48 100644
--- a/include/tvm/tir/data_layout.h
+++ b/include/tvm/tir/data_layout.h
@@ -137,8 +137,10 @@ class Layout : public ObjectRef {
    *        the corresponding lower case with factor size
    *        indicates the split dimension.
    *        return undefined layout if "__undef__" is passed.
+   * \param dtype The dtype of generated axes vars in the returned layout.
+   *        It is required to be integer type.
    */
-  TVM_DLL Layout(const std::string& name);  // NOLINT(*)
+  TVM_DLL Layout(const std::string& name, DataType dtype = DataType::Int(32));  // NOLINT(*)
 
   /*!
    * \brief access the internal node container
diff --git a/python/tvm/tir/data_layout.py b/python/tvm/tir/data_layout.py
index f46a154612..71cc404ee2 100644
--- a/python/tvm/tir/data_layout.py
+++ b/python/tvm/tir/data_layout.py
@@ -163,7 +163,7 @@ class BijectiveLayout(Object):
         return _ffi_api.BijectiveLayoutBackwardShape(self, shape)  # type: ignore
 
 
-def layout(layout_str: str) -> Layout:
+def layout(layout_str: str, dtype: str = "int32") -> Layout:
     """Create a layout node from a string.
 
     Parameters
@@ -177,12 +177,16 @@ def layout(layout_str: str) -> Layout:
         Here subordinate axis channel_block=16 is the factor size of
         the primal axis C (channel).
 
+    dtype : str
+        The dtype of generated axes vars in the returned layout.
+        It is required to be integer type.
+
     Returns
     -------
     layout : Layout
         The created layout
     """
-    return _ffi_api.Layout(layout_str)  # type: ignore
+    return _ffi_api.Layout(layout_str, dtype)  # type: ignore
 
 
 def bijective_layout(
diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc
index 3b22ffc711..3bcb6e8d53 100644
--- a/src/tir/ir/data_layout.cc
+++ b/src/tir/ir/data_layout.cc
@@ -90,7 +90,8 @@ Layout::Layout(const Array<IterVar>& axes) {
   data_ = std::move(node);
 }
 
-Layout::Layout(const std::string& name) {  // NOLINT(*)
+Layout::Layout(const std::string& name, DataType dtype) {  // NOLINT(*)
+  CHECK(dtype.is_int()) << "TypeError: The input dtype should be integer type";
   if (name == "__undef__") return;
 
   auto node = make_object<LayoutNode>();
@@ -106,14 +107,14 @@ Layout::Layout(const std::string& name) {  // NOLINT(*)
                            << " before dimension " << c;
       std::string shape_name("_shape");
       shape_name.insert(0, 1, c);
-      IterVar axis =
-          IterVar(Range(PrimExpr(0), Var(shape_name)), Var(std::string(1, c)), tir::kDataPar);
+      IterVar axis(Range(IntImm(dtype, 0), Var(shape_name, dtype)), Var(std::string(1, c), dtype),
+                   tir::kDataPar);
       node->axes.push_back(axis);
     } else if (c >= 'a' && c <= 'z') {
       ICHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor
                            << " for dimension " << c;
-      IterVar axis =
-          IterVar(Range(PrimExpr(0), PrimExpr(factor)), Var(std::string(1, c)), tir::kDataPar);
+      IterVar axis(Range(IntImm(dtype, 0), IntImm(dtype, factor)), Var(std::string(1, c), dtype),
+                   tir::kDataPar);
       node->axes.push_back(axis);
       factor = 0;
     } else if (c >= '0' && c <= '9') {
@@ -426,7 +427,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
                 << ")";
     });
 
-TVM_REGISTER_GLOBAL("tir.Layout").set_body_typed([](std::string name) { return Layout(name); });
+TVM_REGISTER_GLOBAL("tir.Layout").set_body_typed([](std::string name, DataType dtype) {
+  return Layout(name, dtype);
+});
 
 TVM_REGISTER_GLOBAL("tir.LayoutIndexOf").set_body_typed([](Layout layout, std::string axis) -> int {
   return layout.IndexOf(LayoutAxis::Get(axis));
diff --git a/tests/python/unittest/test_tir_data_layout.py b/tests/python/unittest/test_tir_data_layout.py
index 5c2eb8febd..a76cb50da3 100644
--- a/tests/python/unittest/test_tir_data_layout.py
+++ b/tests/python/unittest/test_tir_data_layout.py
@@ -16,8 +16,9 @@
 # under the License.
 """Test layout and bijective-layout node"""
 
+import pytest
 import tvm
-from tvm import te
+import tvm.error
 from tvm.topi.utils import get_const_tuple
 
 
@@ -52,6 +53,29 @@ def test_layout():
     assert layout[-1] == "c"
 
 
+def test_layout_dtype():
+    layout_i32 = tvm.tir.layout("NCHW")
+    assert layout_i32.axes[0].var.dtype == "int32"
+    assert layout_i32.axes[0].dom.min.dtype == "int32"
+    assert layout_i32.axes[0].dom.extent.dtype == "int32"
+    assert layout_i32.axes[1].var.dtype == "int32"
+    assert layout_i32.axes[1].dom.min.dtype == "int32"
+    assert layout_i32.axes[1].dom.extent.dtype == "int32"
+
+    layout_i64 = tvm.tir.layout("NCHW", dtype="int64")
+    assert layout_i64.axes[2].var.dtype == "int64"
+    assert layout_i64.axes[2].dom.min.dtype == "int64"
+    assert layout_i64.axes[2].dom.extent.dtype == "int64"
+    assert layout_i64.axes[3].var.dtype == "int64"
+    assert layout_i64.axes[3].dom.min.dtype == "int64"
+    assert layout_i64.axes[3].dom.extent.dtype == "int64"
+
+    with pytest.raises(TypeError):
+        tvm.tir.layout("NCHW", dtype="float32")
+    with pytest.raises(TypeError):
+        tvm.tir.layout("NCHW", dtype=None)
+
+
 def test_bilayout_convertible():
     # not convertible
     assert tvm.tir.bijective_layout("NCHW", "ABCD") is None
@@ -88,6 +112,7 @@ def test_bilayout_index():
 
 if __name__ == "__main__":
     test_layout()
+    test_layout_dtype()
     test_bilayout_convertible()
     test_bilayout_shape()
     test_bilayout_index()