You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/09/17 01:14:05 UTC

[tvm] branch main updated: [TIR] Construct the inverse in SuggestIndexMap (#12797)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 91cce56cfa [TIR] Construct the inverse in SuggestIndexMap (#12797)
91cce56cfa is described below

commit 91cce56cfa697a6a2e097bbae1c67ace22ef8af3
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Fri Sep 16 18:13:58 2022 -0700

    [TIR] Construct the inverse in SuggestIndexMap (#12797)
    
    Computing the inverse mapping requires arithmetic analysis which is not guaranteed to cover all cases. We provide the pre-defined inverse index map instead.
---
 include/tvm/tir/index_map.h                        | 26 +++++++++++-
 python/tvm/tir/function.py                         | 46 +++++++++++++++++---
 src/tir/ir/index_map.cc                            | 47 +++++++++++++++++----
 src/tir/schedule/analysis/layout.cc                | 49 +++++++++++++++++++---
 .../python/unittest/test_tir_schedule_analysis.py  | 41 ++++++++++++++++++
 5 files changed, 188 insertions(+), 21 deletions(-)

diff --git a/include/tvm/tir/index_map.h b/include/tvm/tir/index_map.h
index f461c5640b..8a176cb3ce 100644
--- a/include/tvm/tir/index_map.h
+++ b/include/tvm/tir/index_map.h
@@ -70,6 +70,18 @@ class IndexMapNode : public Object {
    */
   Array<PrimExpr> final_indices;
 
+  /*!
+   * \brief The inverse index map.
+   *
+   * When this is defined, IndexMap::Inverse will return the pre-defined inverse index map.
+   * Otherwise, the inverse index map will be computed on the fly.
+   * It is the user's responsibility to ensure the correctness of the pre-defined inverse index
+   * map.
+   *
+   * \note ObjectRef is used here instead of IndexMap to avoid circular reference.
+   */
+  Optional<ObjectRef> inverse_index_map;
+
   /*!
    * \brief Default constructor
    *
@@ -133,6 +145,7 @@ class IndexMapNode : public Object {
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("initial_indices", &initial_indices);
     v->Visit("final_indices", &final_indices);
+    v->Visit("inverse_index_map", &inverse_index_map);
   }
 
   bool SEqualReduce(const IndexMapNode* other, SEqualReducer equal) const {
@@ -153,15 +166,24 @@ class IndexMapNode : public Object {
 
 class IndexMap : public ObjectRef {
  public:
-  IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices);
+  /*!
+   * \brief The constructor
+   * \param initial_indices Variables representing the indices prior to remapping
+   * \param final_indices Expressions defining the indices after remapping.
+   * \param inverse_index_map The optional pre-defined inverse index map
+   */
+  IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices,
+           Optional<IndexMap> inverse_index_map = NullOpt);
 
   /*!
    * \brief Create an index map from a packed function
    * \param ndim The number of dimensions
    * \param func The function to be applied
+   * \param inverse_index_map The optional pre-defined inverse index map
    * \return The created index map
    */
-  static IndexMap FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func);
+  static IndexMap FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func,
+                           Optional<IndexMap> inverse_index_map = NullOpt);
 
   /*! \brief Generate the inverse mapping.
    *
diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
index 12c8053e39..e525fc2cc3 100644
--- a/python/tvm/tir/function.py
+++ b/python/tvm/tir/function.py
@@ -271,6 +271,12 @@ class IndexMap(Object):
         Variables representing the indices prior to remapping.
     final_indices : List[PrimExpr]
         Expressions defining the indices after remapping.
+    inverse_index_map : Union[Callable, Optional[IndexMap]]
+        The optional pre-defined inverse index map.
+        When this is defined, IndexMap::Inverse will return the pre-defined inverse index map.
+        Otherwise, the inverse index map will be computed on the fly.
+        It is the user's responsibility to ensure the correctness of the pre-defined inverse
+        index map.
     """
 
     initial_indices: List[Var]
@@ -281,11 +287,19 @@ class IndexMap(Object):
     # Stage.transform_layout for more details.
     AXIS_SEPARATOR = "axis_separator"
 
-    def __init__(self, initial_indices, final_indices):
-        self.__init_handle_by_constructor__(_ffi_api.IndexMap, initial_indices, final_indices)
+    def __init__(self, initial_indices, final_indices, inverse_index_map):
+        if isinstance(inverse_index_map, Callable):
+            inverse_index_map = IndexMap.from_func(inverse_index_map)
+        self.__init_handle_by_constructor__(
+            _ffi_api.IndexMap, initial_indices, final_indices, inverse_index_map
+        )
 
     @staticmethod
-    def from_func(mapping_function: Callable, ndim: Optional[int] = None):
+    def from_func(
+        mapping_function: Callable,
+        ndim: Optional[int] = None,
+        inverse_index_map: Union[Callable, Optional["IndexMap"]] = None,
+    ):
         """Create an index map from a function
 
         Parameters
@@ -305,6 +319,13 @@ class IndexMap(Object):
             mapping_function does not use variadic arguments, ndim is
             optional.
 
+        inverse_index_map : Union[Callable, Optional[IndexMap]]
+            The optional pre-defined inverse index map.
+            When this is defined, IndexMap::Inverse will return the pre-defined inverse index map.
+            Otherwise, the inverse index map will be computed on the fly.
+            It is the user's responsibility to ensure the correctness of the pre-defined inverse
+            index map.
+
         Returns
         -------
         index_map: IndexMap
@@ -312,7 +333,9 @@ class IndexMap(Object):
             Returns an IndexMap representing the `mapping_function`.
 
         """
-        index_map, axis_separators = IndexMap.from_func_with_separators(mapping_function, ndim)
+        index_map, axis_separators = IndexMap.from_func_with_separators(
+            mapping_function, ndim, inverse_index_map
+        )
         assert not axis_separators, (
             "The mapping_function provided to IndexMap.from_func "
             "may not return IndexMap.AXIS_SEPARATOR.  "
@@ -321,7 +344,11 @@ class IndexMap(Object):
         return index_map
 
     @staticmethod
-    def from_func_with_separators(mapping_function: Callable, ndim: Optional[int] = None):
+    def from_func_with_separators(
+        mapping_function: Callable,
+        ndim: Optional[int] = None,
+        inverse_index_map: Union[Callable, Optional["IndexMap"]] = None,
+    ):
         """Create an index map from a function
 
         Parameters
@@ -341,6 +368,13 @@ class IndexMap(Object):
             mapping_function does not use variadic arguments, ndim is
             optional.
 
+        inverse_index_map : Union[Callable, Optional[IndexMap]]
+            The optional pre-defined inverse index map.
+            When this is defined, IndexMap::Inverse will return the pre-defined inverse index map.
+            Otherwise, the inverse index map will be computed on the fly.
+            It is the user's responsibility to ensure the correctness of the pre-defined inverse
+            index map.
+
         Returns
         -------
         ret: Tuple[IndexMap, List[int]]
@@ -401,7 +435,7 @@ class IndexMap(Object):
                     f"Instead received {val} of type {type(val)}."
                 )
 
-        return IndexMap(initial_indices, final_indices), axis_separators
+        return IndexMap(initial_indices, final_indices, inverse_index_map), axis_separators
 
     def is_equivalent_to(self, other_map: "IndexMap") -> bool:
         """Return if the index maps are equivalent.
diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc
index 0e3c3b2774..cceff72ec8 100644
--- a/src/tir/ir/index_map.cc
+++ b/src/tir/ir/index_map.cc
@@ -34,20 +34,23 @@
 namespace tvm {
 namespace tir {
 
-IndexMap::IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices) {
+IndexMap::IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices,
+                   Optional<IndexMap> inverse_index_map) {
   auto n = make_object<IndexMapNode>();
   n->initial_indices = std::move(initial_indices);
   n->final_indices = std::move(final_indices);
+  n->inverse_index_map = std::move(inverse_index_map);
   data_ = std::move(n);
 }
 
-IndexMap IndexMap::FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func) {
+IndexMap IndexMap::FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func,
+                            Optional<IndexMap> inverse_index_map) {
   Array<Var> initial_indices;
   initial_indices.reserve(ndim);
   for (int i = 0; i < ndim; ++i) {
     initial_indices.push_back(Var("i" + std::to_string(i), DataType::Int(32)));
   }
-  return IndexMap(initial_indices, func(initial_indices));
+  return IndexMap(initial_indices, func(initial_indices), std::move(inverse_index_map));
 }
 
 std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> initial_ranges) const {
@@ -114,6 +117,10 @@ std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> initia
 }
 
 IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const {
+  if ((*this)->inverse_index_map.defined()) {
+    // Return the pre-defined inverse index map if it exists.
+    return Downcast<IndexMap>((*this)->inverse_index_map.value());
+  }
   // Dummy variables to represent the inverse's inputs.
   Array<Var> output_vars;
   for (size_t i = 0; i < (*this)->final_indices.size(); i++) {
@@ -232,7 +239,14 @@ Array<PrimExpr> IndexMapNode::MapShape(const Array<PrimExpr>& shape,
   return output;
 }
 
-String IndexMapNode::ToPythonString() const {
+/*!
+ * \brief Auxiliary function to convert an index map to a lambda expression in Python.
+ * \param initial_indices The initial indices in the index map.
+ * \param final_indices The final indices in the index map.
+ * \return The lambda expression string.
+ */
+std::string IndexMap2PythonLambdaExpr(const Array<Var>& initial_indices,
+                                      const Array<PrimExpr>& final_indices) {
   std::unordered_set<std::string> used_names;
   Map<Var, PrimExpr> var_remap;
   for (const Var& initial_index : initial_indices) {
@@ -259,10 +273,28 @@ String IndexMapNode::ToPythonString() const {
   }
   oss << ": (";
   for (size_t i = 0; i < final_indices.size(); ++i) {
+    if (i != 0) {
+      oss << " ";
+    }
     oss << Substitute(final_indices[i], var_remap);
-    oss << ", ";
+    oss << ",";
   }
   oss << ")";
+  return oss.str();
+}
+
+String IndexMapNode::ToPythonString() const {
+  std::string lambda_expr = IndexMap2PythonLambdaExpr(initial_indices, final_indices);
+  if (!inverse_index_map.defined()) {
+    return String(lambda_expr);
+  }
+  // Also convert the inverse index map.
+  IndexMap inverse = Downcast<IndexMap>(inverse_index_map.value());
+  std::string inverse_lambda_expr =
+      IndexMap2PythonLambdaExpr(inverse->initial_indices, inverse->final_indices);
+  std::ostringstream oss;
+  oss << "tvm.tir.IndexMap.from_func(" << lambda_expr
+      << ", inverse_index_map=" << inverse_lambda_expr << ")";
   return String(oss.str());
 }
 
@@ -275,8 +307,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 TVM_REGISTER_NODE_TYPE(IndexMapNode);
 
 TVM_REGISTER_GLOBAL("tir.IndexMap")
-    .set_body_typed([](Array<Var> initial_indices, Array<PrimExpr> final_indices) {
-      return IndexMap(initial_indices, final_indices);
+    .set_body_typed([](Array<Var> initial_indices, Array<PrimExpr> final_indices,
+                       Optional<IndexMap> inverse_index_map) {
+      return IndexMap(initial_indices, final_indices, inverse_index_map);
     });
 
 TVM_REGISTER_GLOBAL("tir.IndexMapMapIndices")
diff --git a/src/tir/schedule/analysis/layout.cc b/src/tir/schedule/analysis/layout.cc
index b0cafac315..b071b2d7e4 100644
--- a/src/tir/schedule/analysis/layout.cc
+++ b/src/tir/schedule/analysis/layout.cc
@@ -167,20 +167,25 @@ Optional<IndexMap> SuggestIndexMap(const Buffer& buffer, const Array<PrimExpr>&
     }
     return a.lower_factor > b.lower_factor;
   });
+  // Compute the inverse permutation by argsort
+  std::vector<int> inverse_order = order;
+  std::sort(inverse_order.begin(), inverse_order.end(),
+            [&order](int _a, int _b) -> bool { return order[_a] < order[_b]; });
   // Step 5. Create the indexing mapping
   auto f_alter_layout = [f_flatten_index = std::move(f_flatten_index),  //
-                         split_exprs = std::move(split_exprs),          //
-                         order = std::move(order),                      //
-                         shape = buffer->shape,                         //
+                         &split_exprs,                                  //
+                         &order,                                        //
+                             & shape = buffer->shape,                   //
                          analyzer                                       //
   ](Array<Var> indices) -> Array<PrimExpr> {
     ICHECK_EQ(indices.size(), shape.size());
     for (int i = 0, n = indices.size(); i < n; ++i) {
       analyzer->Bind(indices[i], Range::FromMinExtent(0, shape[i]));
     }
+    // Step 5.1: Fuse all indices into a flattened one
     PrimExpr index = f_flatten_index({indices.begin(), indices.end()});
     int ndim = split_exprs.size();
-    // Step 5.1. Split the flattened index according to `split_exprs`
+    // Step 5.2. Split the flattened index according to `split_exprs`
     std::vector<PrimExpr> split;
     split.reserve(ndim);
     for (int i = ndim - 1; i >= 0; --i) {
@@ -190,7 +195,7 @@ Optional<IndexMap> SuggestIndexMap(const Buffer& buffer, const Array<PrimExpr>&
       index = floordiv(index, extent);
     }
     std::reverse(split.begin(), split.end());
-    // Step 5.2. Reorder the indexing pattern according to `order`
+    // Step 5.3. Reorder the indexing pattern according to `order`
     Array<PrimExpr> results;
     results.reserve(ndim);
     for (int i = 0; i < ndim; ++i) {
@@ -198,7 +203,39 @@ Optional<IndexMap> SuggestIndexMap(const Buffer& buffer, const Array<PrimExpr>&
     }
     return results;
   };
-  return IndexMap::FromFunc(ndim, f_alter_layout);
+  // Step 6: Create the inverse index mapping.
+  auto f_inverse = [&inverse_order, &split_exprs, &shape = buffer->shape,
+                    analyzer](Array<Var> indices) -> Array<PrimExpr> {
+    ICHECK_EQ(indices.size(), split_exprs.size());
+    // Step 6.1: Reorder the indices according to `inverse_order`. This is the inverse of Step 5.3.
+    // After the inverse permutation, inv_permuted_indices[i] corresponds to split_exprs[i]
+    Array<Var> inv_permuted_indices;
+    inv_permuted_indices.reserve(indices.size());
+    for (int i = 0, n = indices.size(); i < n; ++i) {
+      const Var& index = indices[inverse_order[i]];
+      inv_permuted_indices.push_back(index);
+      analyzer->Bind(index, Range::FromMinExtent(0, Integer(split_exprs[i].extent)));
+    }
+
+    // Step 6.2: Fuse all the indices. This is the inverse of Step 5.2.
+    PrimExpr flattened_index = make_const(indices[0]->dtype, 0);
+    int64_t stride = 1;
+    for (int i = static_cast<int>(split_exprs.size()) - 1; i >= 0; --i) {
+      flattened_index = inv_permuted_indices[i] * Integer(stride) + flattened_index;
+      stride *= split_exprs[i].extent;
+    }
+    // Step 6.3: Split the flattened index into multiple indices. This is the inverse of Step 5.1.
+    Array<PrimExpr> result;
+    result.reserve(shape.size());
+    for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
+      PrimExpr index = analyzer->Simplify(floormod(flattened_index, shape[i]));
+      flattened_index = floordiv(flattened_index, shape[i]);
+      result.push_back(index);
+    }
+    return Array<PrimExpr>(result.rbegin(), result.rend());
+  };
+  IndexMap inverse_index_map = IndexMap::FromFunc(split_exprs.size(), f_inverse);
+  return IndexMap::FromFunc(ndim, f_alter_layout, inverse_index_map);
 }
 
 TVM_REGISTER_GLOBAL("tir.schedule.SuggestIndexMap")
diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py
index 5524abbaf0..378e5183b4 100644
--- a/tests/python/unittest/test_tir_schedule_analysis.py
+++ b/tests/python/unittest/test_tir_schedule_analysis.py
@@ -101,6 +101,47 @@ def test_suggest_index_map_bijective():
     assert index_map.is_equivalent_to(expected_index_map)
 
 
+def test_suggest_index_map_winograd():
+    """use case in winograd conv where the indices are complicated"""
+    fused_outer, i3_3_fused, i4_0, i4_1 = _make_vars("fused_outer", "i3_3_fused", "i4_0", "i4_1")
+    eps = floordiv(fused_outer, 336) * 2 + floordiv(floormod(fused_outer, 16), 8)
+    nu = floordiv(floormod(fused_outer, 336), 112) * 2 + floordiv(floormod(fused_outer, 8), 4)
+    co = floormod(fused_outer, 4) * 32 + i3_3_fused
+    ci = (i4_0 * 32) + i4_1
+    buffer = decl_buffer(shape=[6, 6, 128, 128])
+    index_map = suggest_index_map(
+        buffer=buffer,
+        indices=[eps, nu, co, ci],
+        loops=_make_loops(
+            loop_vars=[fused_outer, i3_3_fused, i4_0, i4_1],
+            extents=[1008, 32, 4, 32],
+        ),
+        predicate=True,
+    )
+    expected_index_map = IndexMap.from_func(
+        lambda i0, i1, i2, i3: (
+            floordiv(i0, 2),
+            floordiv(i1, 2),
+            floormod(i0, 2),
+            floormod(((i1 * 4) + floordiv(i2, 32)), 8),
+            floormod(i2, 32),
+            floordiv(i3, 32),
+            floormod(i3, 32),
+        )
+    )
+    assert index_map.is_equivalent_to(expected_index_map)
+    inverse_index_map = index_map.inverse(buffer.shape)
+    expected_inverse_index_map = IndexMap.from_func(
+        lambda i0, i1, i2, i3, i4, i5, i6: (
+            ((i0 * 2) + i2),
+            ((i1 * 2) + floordiv(((i3 * 32) + i4), 128)),
+            floormod(((i3 * 32) + i4), 128),
+            ((i5 * 32) + i6),
+        )
+    )
+    assert inverse_index_map.is_equivalent_to(expected_inverse_index_map)
+
+
 @tvm.script.ir_module
 class DenseVNNIModule:
     @T.prim_func