You are viewing a plain-text version of this content. The hyperlink to its canonical version was not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/12/28 11:11:11 UTC
[tvm] branch main updated: [TIR] Create Layout with specified axis dtype (#13663)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d6507b256f [TIR] Create Layout with specified axis dtype (#13663)
d6507b256f is described below
commit d6507b256f2f133d2acc187f1740ebe5c082f914
Author: Ruihang Lai <ru...@cs.cmu.edu>
AuthorDate: Wed Dec 28 06:11:06 2022 -0500
[TIR] Create Layout with specified axis dtype (#13663)
---
include/tvm/tir/data_layout.h | 4 +++-
python/tvm/tir/data_layout.py | 8 ++++++--
src/tir/ir/data_layout.cc | 15 +++++++++------
tests/python/unittest/test_tir_data_layout.py | 27 ++++++++++++++++++++++++++-
4 files changed, 44 insertions(+), 10 deletions(-)
diff --git a/include/tvm/tir/data_layout.h b/include/tvm/tir/data_layout.h
index 81c3e98e66..7aefef6e48 100644
--- a/include/tvm/tir/data_layout.h
+++ b/include/tvm/tir/data_layout.h
@@ -137,8 +137,10 @@ class Layout : public ObjectRef {
* the corresponding lower case with factor size
* indicates the split dimension.
* return undefined layout if "__undef__" is passed.
+ * \param dtype The dtype of generated axes vars in the returned layout.
+ * It is required to be integer type.
*/
- TVM_DLL Layout(const std::string& name); // NOLINT(*)
+ TVM_DLL Layout(const std::string& name, DataType dtype = DataType::Int(32)); // NOLINT(*)
/*!
* \brief access the internal node container
diff --git a/python/tvm/tir/data_layout.py b/python/tvm/tir/data_layout.py
index f46a154612..71cc404ee2 100644
--- a/python/tvm/tir/data_layout.py
+++ b/python/tvm/tir/data_layout.py
@@ -163,7 +163,7 @@ class BijectiveLayout(Object):
return _ffi_api.BijectiveLayoutBackwardShape(self, shape) # type: ignore
-def layout(layout_str: str) -> Layout:
+def layout(layout_str: str, dtype: str = "int32") -> Layout:
"""Create a layout node from a string.
Parameters
@@ -177,12 +177,16 @@ def layout(layout_str: str) -> Layout:
Here subordinate axis channel_block=16 is the factor size of
the primal axis C (channel).
+ dtype : str
+ The dtype of generated axes vars in the returned layout.
+ It is required to be integer type.
+
Returns
-------
layout : Layout
The created layout
"""
- return _ffi_api.Layout(layout_str) # type: ignore
+ return _ffi_api.Layout(layout_str, dtype) # type: ignore
def bijective_layout(
diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc
index 3b22ffc711..3bcb6e8d53 100644
--- a/src/tir/ir/data_layout.cc
+++ b/src/tir/ir/data_layout.cc
@@ -90,7 +90,8 @@ Layout::Layout(const Array<IterVar>& axes) {
data_ = std::move(node);
}
-Layout::Layout(const std::string& name) { // NOLINT(*)
+Layout::Layout(const std::string& name, DataType dtype) { // NOLINT(*)
+ CHECK(dtype.is_int()) << "TypeError: The input dtype should be integer type";
if (name == "__undef__") return;
auto node = make_object<LayoutNode>();
@@ -106,14 +107,14 @@ Layout::Layout(const std::string& name) { // NOLINT(*)
<< " before dimension " << c;
std::string shape_name("_shape");
shape_name.insert(0, 1, c);
- IterVar axis =
- IterVar(Range(PrimExpr(0), Var(shape_name)), Var(std::string(1, c)), tir::kDataPar);
+ IterVar axis(Range(IntImm(dtype, 0), Var(shape_name, dtype)), Var(std::string(1, c), dtype),
+ tir::kDataPar);
node->axes.push_back(axis);
} else if (c >= 'a' && c <= 'z') {
ICHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor
<< " for dimension " << c;
- IterVar axis =
- IterVar(Range(PrimExpr(0), PrimExpr(factor)), Var(std::string(1, c)), tir::kDataPar);
+ IterVar axis(Range(IntImm(dtype, 0), IntImm(dtype, factor)), Var(std::string(1, c), dtype),
+ tir::kDataPar);
node->axes.push_back(axis);
factor = 0;
} else if (c >= '0' && c <= '9') {
@@ -426,7 +427,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
<< ")";
});
-TVM_REGISTER_GLOBAL("tir.Layout").set_body_typed([](std::string name) { return Layout(name); });
+TVM_REGISTER_GLOBAL("tir.Layout").set_body_typed([](std::string name, DataType dtype) {
+ return Layout(name, dtype);
+});
TVM_REGISTER_GLOBAL("tir.LayoutIndexOf").set_body_typed([](Layout layout, std::string axis) -> int {
return layout.IndexOf(LayoutAxis::Get(axis));
diff --git a/tests/python/unittest/test_tir_data_layout.py b/tests/python/unittest/test_tir_data_layout.py
index 5c2eb8febd..a76cb50da3 100644
--- a/tests/python/unittest/test_tir_data_layout.py
+++ b/tests/python/unittest/test_tir_data_layout.py
@@ -16,8 +16,9 @@
# under the License.
"""Test layout and bijective-layout node"""
+import pytest
import tvm
-from tvm import te
+import tvm.error
from tvm.topi.utils import get_const_tuple
@@ -52,6 +53,29 @@ def test_layout():
assert layout[-1] == "c"
+def test_layout_dtype():
+ layout_i32 = tvm.tir.layout("NCHW")
+ assert layout_i32.axes[0].var.dtype == "int32"
+ assert layout_i32.axes[0].dom.min.dtype == "int32"
+ assert layout_i32.axes[0].dom.extent.dtype == "int32"
+ assert layout_i32.axes[1].var.dtype == "int32"
+ assert layout_i32.axes[1].dom.min.dtype == "int32"
+ assert layout_i32.axes[1].dom.extent.dtype == "int32"
+
+ layout_i64 = tvm.tir.layout("NCHW", dtype="int64")
+ assert layout_i64.axes[2].var.dtype == "int64"
+ assert layout_i64.axes[2].dom.min.dtype == "int64"
+ assert layout_i64.axes[2].dom.extent.dtype == "int64"
+ assert layout_i64.axes[3].var.dtype == "int64"
+ assert layout_i64.axes[3].dom.min.dtype == "int64"
+ assert layout_i64.axes[3].dom.extent.dtype == "int64"
+
+ with pytest.raises(TypeError):
+ tvm.tir.layout("NCHW", dtype="float32")
+ with pytest.raises(TypeError):
+ tvm.tir.layout("NCHW", dtype=None)
+
+
def test_bilayout_convertible():
# not convertible
assert tvm.tir.bijective_layout("NCHW", "ABCD") is None
@@ -88,6 +112,7 @@ def test_bilayout_index():
if __name__ == "__main__":
test_layout()
+ test_layout_dtype()
test_bilayout_convertible()
test_bilayout_shape()
test_bilayout_index()