You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text rendering.)
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/09/17 01:14:05 UTC
[tvm] branch main updated: [TIR] Construct the inverse in SuggestIndexMap (#12797)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 91cce56cfa [TIR] Construct the inverse in SuggestIndexMap (#12797)
91cce56cfa is described below
commit 91cce56cfa697a6a2e097bbae1c67ace22ef8af3
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Fri Sep 16 18:13:58 2022 -0700
[TIR] Construct the inverse in SuggestIndexMap (#12797)
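Computing the inverse mapping on the fly requires arithmetic analysis, which is not guaranteed to handle all cases. Instead, SuggestIndexMap now constructs the inverse index map explicitly and attaches it to the returned IndexMap as a pre-defined inverse, which IndexMap::Inverse returns directly.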
---
include/tvm/tir/index_map.h | 26 +++++++++++-
python/tvm/tir/function.py | 46 +++++++++++++++++---
src/tir/ir/index_map.cc | 47 +++++++++++++++++----
src/tir/schedule/analysis/layout.cc | 49 +++++++++++++++++++---
.../python/unittest/test_tir_schedule_analysis.py | 41 ++++++++++++++++++
5 files changed, 188 insertions(+), 21 deletions(-)
diff --git a/include/tvm/tir/index_map.h b/include/tvm/tir/index_map.h
index f461c5640b..8a176cb3ce 100644
--- a/include/tvm/tir/index_map.h
+++ b/include/tvm/tir/index_map.h
@@ -70,6 +70,18 @@ class IndexMapNode : public Object {
*/
Array<PrimExpr> final_indices;
+ /*!
+ * \brief The inverse index map.
+ *
+ * When this is defined, IndexMap::Inverse will return the pre-defined inverse index map.
+ * Otherwise, the inverse index map will be computed on the fly.
+ * It is the user's responsibility to ensure the correctness of the pre-defined inverse index
+ * map.
+ *
+ * \note ObjectRef is used here instead of IndexMap to avoid circular reference.
+ */
+ Optional<ObjectRef> inverse_index_map;
+
/*!
* \brief Default constructor
*
@@ -133,6 +145,7 @@ class IndexMapNode : public Object {
void VisitAttrs(AttrVisitor* v) {
v->Visit("initial_indices", &initial_indices);
v->Visit("final_indices", &final_indices);
+ v->Visit("inverse_index_map", &inverse_index_map);
}
bool SEqualReduce(const IndexMapNode* other, SEqualReducer equal) const {
@@ -153,15 +166,24 @@ class IndexMapNode : public Object {
class IndexMap : public ObjectRef {
public:
- IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices);
+ /*!
+ * \brief The constructor
+ * \param initial_indices Variables representing the indices prior to remapping
+ * \param final_indices Expressions defining the indices after remapping.
+ * \param inverse_index_map The optional pre-defined inverse index map
+ */
+ IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices,
+ Optional<IndexMap> inverse_index_map = NullOpt);
/*!
* \brief Create an index map from a packed function
* \param ndim The number of dimensions
* \param func The function to be applied
+ * \param inverse_index_map The optional pre-defined inverse index map
* \return The created index map
*/
- static IndexMap FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func);
+ static IndexMap FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func,
+ Optional<IndexMap> inverse_index_map = NullOpt);
/*! \brief Generate the inverse mapping.
*
diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
index 12c8053e39..e525fc2cc3 100644
--- a/python/tvm/tir/function.py
+++ b/python/tvm/tir/function.py
@@ -271,6 +271,12 @@ class IndexMap(Object):
Variables representing the indices prior to remapping.
final_indices : List[PrimExpr]
Expressions defining the indices after remapping.
+ inverse_index_map : Union[Callable, Optional[IndexMap]]
+ The optional pre-defined inverse index map.
+ When this is defined, IndexMap::Inverse will return the pre-defined inverse index map.
+ Otherwise, the inverse index map will be computed on the fly.
+ It is the user's responsibility to ensure the correctness of the pre-defined inverse
+ index map.
"""
initial_indices: List[Var]
@@ -281,11 +287,19 @@ class IndexMap(Object):
# Stage.transform_layout for more details.
AXIS_SEPARATOR = "axis_separator"
- def __init__(self, initial_indices, final_indices):
- self.__init_handle_by_constructor__(_ffi_api.IndexMap, initial_indices, final_indices)
+ def __init__(self, initial_indices, final_indices, inverse_index_map):
+ if isinstance(inverse_index_map, Callable):
+ inverse_index_map = IndexMap.from_func(inverse_index_map)
+ self.__init_handle_by_constructor__(
+ _ffi_api.IndexMap, initial_indices, final_indices, inverse_index_map
+ )
@staticmethod
- def from_func(mapping_function: Callable, ndim: Optional[int] = None):
+ def from_func(
+ mapping_function: Callable,
+ ndim: Optional[int] = None,
+ inverse_index_map: Union[Callable, Optional["IndexMap"]] = None,
+ ):
"""Create an index map from a function
Parameters
@@ -305,6 +319,13 @@ class IndexMap(Object):
mapping_function does not use variadic arguments, ndim is
optional.
+ inverse_index_map : Union[Callable, Optional[IndexMap]]
+ The optional pre-defined inverse index map.
+ When this is defined, IndexMap::Inverse will return the pre-defined inverse index map.
+ Otherwise, the inverse index map will be computed on the fly.
+ It is the user's responsibility to ensure the correctness of the pre-defined inverse
+ index map.
+
Returns
-------
index_map: IndexMap
@@ -312,7 +333,9 @@ class IndexMap(Object):
Returns an IndexMap representing the `mapping_function`.
"""
- index_map, axis_separators = IndexMap.from_func_with_separators(mapping_function, ndim)
+ index_map, axis_separators = IndexMap.from_func_with_separators(
+ mapping_function, ndim, inverse_index_map
+ )
assert not axis_separators, (
"The mapping_function provided to IndexMap.from_func "
"may not return IndexMap.AXIS_SEPARATOR. "
@@ -321,7 +344,11 @@ class IndexMap(Object):
return index_map
@staticmethod
- def from_func_with_separators(mapping_function: Callable, ndim: Optional[int] = None):
+ def from_func_with_separators(
+ mapping_function: Callable,
+ ndim: Optional[int] = None,
+ inverse_index_map: Union[Callable, Optional["IndexMap"]] = None,
+ ):
"""Create an index map from a function
Parameters
@@ -341,6 +368,13 @@ class IndexMap(Object):
mapping_function does not use variadic arguments, ndim is
optional.
+ inverse_index_map : Union[Callable, Optional[IndexMap]]
+ The optional pre-defined inverse index map.
+ When this is defined, IndexMap::Inverse will return the pre-defined inverse index map.
+ Otherwise, the inverse index map will be computed on the fly.
+ It is the user's responsibility to ensure the correctness of the pre-defined inverse
+ index map.
+
Returns
-------
ret: Tuple[IndexMap, List[int]]
@@ -401,7 +435,7 @@ class IndexMap(Object):
f"Instead received {val} of type {type(val)}."
)
- return IndexMap(initial_indices, final_indices), axis_separators
+ return IndexMap(initial_indices, final_indices, inverse_index_map), axis_separators
def is_equivalent_to(self, other_map: "IndexMap") -> bool:
"""Return if the index maps are equivalent.
diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc
index 0e3c3b2774..cceff72ec8 100644
--- a/src/tir/ir/index_map.cc
+++ b/src/tir/ir/index_map.cc
@@ -34,20 +34,23 @@
namespace tvm {
namespace tir {
-IndexMap::IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices) {
+IndexMap::IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices,
+ Optional<IndexMap> inverse_index_map) {
auto n = make_object<IndexMapNode>();
n->initial_indices = std::move(initial_indices);
n->final_indices = std::move(final_indices);
+ n->inverse_index_map = std::move(inverse_index_map);
data_ = std::move(n);
}
-IndexMap IndexMap::FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func) {
+IndexMap IndexMap::FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func,
+ Optional<IndexMap> inverse_index_map) {
Array<Var> initial_indices;
initial_indices.reserve(ndim);
for (int i = 0; i < ndim; ++i) {
initial_indices.push_back(Var("i" + std::to_string(i), DataType::Int(32)));
}
- return IndexMap(initial_indices, func(initial_indices));
+ return IndexMap(initial_indices, func(initial_indices), std::move(inverse_index_map));
}
std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> initial_ranges) const {
@@ -114,6 +117,10 @@ std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> initia
}
IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const {
+ if ((*this)->inverse_index_map.defined()) {
+ // return the pre-defined inverse index map if exists.
+ return Downcast<IndexMap>((*this)->inverse_index_map.value());
+ }
// Dummy variables to represent the inverse's inputs.
Array<Var> output_vars;
for (size_t i = 0; i < (*this)->final_indices.size(); i++) {
@@ -232,7 +239,14 @@ Array<PrimExpr> IndexMapNode::MapShape(const Array<PrimExpr>& shape,
return output;
}
-String IndexMapNode::ToPythonString() const {
+/*!
+ * \brief Auxiliary function to convert an index map to a lambda expression in Python.
+ * \param initial_indices The initial indices in the index map.
+ * \param final_indices The final indices in the index map.
+ * \return The lambda expression string.
+ */
+std::string IndexMap2PythonLambdaExpr(const Array<Var>& initial_indices,
+ const Array<PrimExpr>& final_indices) {
std::unordered_set<std::string> used_names;
Map<Var, PrimExpr> var_remap;
for (const Var& initial_index : initial_indices) {
@@ -259,10 +273,28 @@ String IndexMapNode::ToPythonString() const {
}
oss << ": (";
for (size_t i = 0; i < final_indices.size(); ++i) {
+ if (i != 0) {
+ oss << " ";
+ }
oss << Substitute(final_indices[i], var_remap);
- oss << ", ";
+ oss << ",";
}
oss << ")";
+ return oss.str();
+}
+
+String IndexMapNode::ToPythonString() const {
+ std::string lambda_expr = IndexMap2PythonLambdaExpr(initial_indices, final_indices);
+ if (!inverse_index_map.defined()) {
+ return String(lambda_expr);
+ }
+ // Also convert the inverse index map.
+ IndexMap inverse = Downcast<IndexMap>(inverse_index_map.value());
+ std::string inverse_lambda_expr =
+ IndexMap2PythonLambdaExpr(inverse->initial_indices, inverse->final_indices);
+ std::ostringstream oss;
+ oss << "tvm.tir.IndexMap.from_func(" << lambda_expr
+ << ", inverse_index_map=" << inverse_lambda_expr << ")";
return String(oss.str());
}
@@ -275,8 +307,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
TVM_REGISTER_NODE_TYPE(IndexMapNode);
TVM_REGISTER_GLOBAL("tir.IndexMap")
- .set_body_typed([](Array<Var> initial_indices, Array<PrimExpr> final_indices) {
- return IndexMap(initial_indices, final_indices);
+ .set_body_typed([](Array<Var> initial_indices, Array<PrimExpr> final_indices,
+ Optional<IndexMap> inverse_index_map) {
+ return IndexMap(initial_indices, final_indices, inverse_index_map);
});
TVM_REGISTER_GLOBAL("tir.IndexMapMapIndices")
diff --git a/src/tir/schedule/analysis/layout.cc b/src/tir/schedule/analysis/layout.cc
index b0cafac315..b071b2d7e4 100644
--- a/src/tir/schedule/analysis/layout.cc
+++ b/src/tir/schedule/analysis/layout.cc
@@ -167,20 +167,25 @@ Optional<IndexMap> SuggestIndexMap(const Buffer& buffer, const Array<PrimExpr>&
}
return a.lower_factor > b.lower_factor;
});
+ // Compute the inverse permutation by argsort
+ std::vector<int> inverse_order = order;
+ std::sort(inverse_order.begin(), inverse_order.end(),
+ [&order](int _a, int _b) -> bool { return order[_a] < order[_b]; });
// Step 5. Create the indexing mapping
auto f_alter_layout = [f_flatten_index = std::move(f_flatten_index), //
- split_exprs = std::move(split_exprs), //
- order = std::move(order), //
- shape = buffer->shape, //
+ &split_exprs, //
+ &order, //
+ & shape = buffer->shape, //
analyzer //
](Array<Var> indices) -> Array<PrimExpr> {
ICHECK_EQ(indices.size(), shape.size());
for (int i = 0, n = indices.size(); i < n; ++i) {
analyzer->Bind(indices[i], Range::FromMinExtent(0, shape[i]));
}
+ // Step 5.1: Fuse all indices into a flattened one
PrimExpr index = f_flatten_index({indices.begin(), indices.end()});
int ndim = split_exprs.size();
- // Step 5.1. Split the flattened index according to `split_exprs`
+ // Step 5.2. Split the flattened index according to `split_exprs`
std::vector<PrimExpr> split;
split.reserve(ndim);
for (int i = ndim - 1; i >= 0; --i) {
@@ -190,7 +195,7 @@ Optional<IndexMap> SuggestIndexMap(const Buffer& buffer, const Array<PrimExpr>&
index = floordiv(index, extent);
}
std::reverse(split.begin(), split.end());
- // Step 5.2. Reorder the indexing pattern according to `order`
+ // Step 5.3. Reorder the indexing pattern according to `order`
Array<PrimExpr> results;
results.reserve(ndim);
for (int i = 0; i < ndim; ++i) {
@@ -198,7 +203,39 @@ Optional<IndexMap> SuggestIndexMap(const Buffer& buffer, const Array<PrimExpr>&
}
return results;
};
- return IndexMap::FromFunc(ndim, f_alter_layout);
+ // Step 6: Create the inverse index mapping.
+ auto f_inverse = [&inverse_order, &split_exprs, &shape = buffer->shape,
+ analyzer](Array<Var> indices) -> Array<PrimExpr> {
+ ICHECK_EQ(indices.size(), split_exprs.size());
+ // Step 6.1: Reorder the indices according to `inverse_order`. This is the inverse of Step 5.3.
+ // After the inverse permutation, indices[i] corresponds to split_exprs[i]
+ Array<Var> inv_permuted_indices;
+ inv_permuted_indices.reserve(indices.size());
+ for (int i = 0, n = indices.size(); i < n; ++i) {
+ const Var& index = indices[inverse_order[i]];
+ inv_permuted_indices.push_back(index);
+ analyzer->Bind(index, Range::FromMinExtent(0, Integer(split_exprs[i].extent)));
+ }
+
+ // Step 6.2: Fuse all the indices. This is the inverse of Step 5.2.
+ PrimExpr flattened_index = make_const(indices[0]->dtype, 0);
+ int64_t stride = 1;
+ for (int i = static_cast<int>(split_exprs.size()) - 1; i >= 0; --i) {
+ flattened_index = inv_permuted_indices[i] * Integer(stride) + flattened_index;
+ stride *= split_exprs[i].extent;
+ }
+ // Step 6.3: Split the flattened index into multiple indices. This is the inverse of Step 5.1.
+ Array<PrimExpr> result;
+ result.reserve(shape.size());
+ for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
+ PrimExpr index = analyzer->Simplify(floormod(flattened_index, shape[i]));
+ flattened_index = floordiv(flattened_index, shape[i]);
+ result.push_back(index);
+ }
+ return Array<PrimExpr>(result.rbegin(), result.rend());
+ };
+ IndexMap inverse_index_map = IndexMap::FromFunc(split_exprs.size(), f_inverse);
+ return IndexMap::FromFunc(ndim, f_alter_layout, inverse_index_map);
}
TVM_REGISTER_GLOBAL("tir.schedule.SuggestIndexMap")
diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py
index 5524abbaf0..378e5183b4 100644
--- a/tests/python/unittest/test_tir_schedule_analysis.py
+++ b/tests/python/unittest/test_tir_schedule_analysis.py
@@ -101,6 +101,47 @@ def test_suggest_index_map_bijective():
assert index_map.is_equivalent_to(expected_index_map)
+def test_suggest_index_map_winograd():
+ """use case in winograd conv where the indices are complicated"""
+ fused_outer, i3_3_fused, i4_0, i4_1 = _make_vars("fused_outer", "i3_3_fused", "i4_0", "i4_1")
+ eps = floordiv(fused_outer, 336) * 2 + floordiv(floormod(fused_outer, 16), 8)
+ nu = floordiv(floormod(fused_outer, 336), 112) * 2 + floordiv(floormod(fused_outer, 8), 4)
+ co = floormod(fused_outer, 4) * 32 + i3_3_fused
+ ci = (i4_0 * 32) + i4_1
+ buffer = decl_buffer(shape=[6, 6, 128, 128])
+ index_map = suggest_index_map(
+ buffer=buffer,
+ indices=[eps, nu, co, ci],
+ loops=_make_loops(
+ loop_vars=[fused_outer, i3_3_fused, i4_0, i4_1],
+ extents=[1008, 32, 4, 32],
+ ),
+ predicate=True,
+ )
+ expected_index_map = IndexMap.from_func(
+ lambda i0, i1, i2, i3: (
+ floordiv(i0, 2),
+ floordiv(i1, 2),
+ floormod(i0, 2),
+ floormod(((i1 * 4) + floordiv(i2, 32)), 8),
+ floormod(i2, 32),
+ floordiv(i3, 32),
+ floormod(i3, 32),
+ )
+ )
+ assert index_map.is_equivalent_to(expected_index_map)
+ inverse_index_map = index_map.inverse(buffer.shape)
+ expected_inverse_index_map = IndexMap.from_func(
+ lambda i0, i1, i2, i3, i4, i5, i6: (
+ ((i0 * 2) + i2),
+ ((i1 * 2) + floordiv(((i3 * 32) + i4), 128)),
+ floormod(((i3 * 32) + i4), 128),
+ ((i5 * 32) + i6),
+ )
+ )
+ assert inverse_index_map.is_equivalent_to(expected_inverse_index_map)
+
+
@tvm.script.ir_module
class DenseVNNIModule:
@T.prim_func