You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2023/03/15 04:56:08 UTC

[tvm] branch main updated: [TIR][Hexagon] Enhancement of NarrowDataType pass for binary ops (#14298)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 970cd1def8 [TIR][Hexagon] Enhancement of NarrowDataType pass for binary ops (#14298)
970cd1def8 is described below

commit 970cd1def8df7f140499590d11cfe235a56bc642
Author: ibsidorenko <98...@users.noreply.github.com>
AuthorDate: Wed Mar 15 07:55:59 2023 +0300

    [TIR][Hexagon] Enhancement of NarrowDataType pass for binary ops (#14298)
    
    This is enhancement of PR#13327.
    
    Motivation:
    While experimenting with MetaScheduler on the Hexagon target, we found that
    avg_pool2d performs rather poorly because of a lack of vectorized code.
    The IndexDataTypeNormalizer pass converts all indices to int64, and
    NarrowDataTypeRewriter is supposed to do the opposite (narrow them back to
    int32). When narrowing fails, average pooling is left with a lot of int64
    arithmetic that cannot be vectorized.
    
    What was done:
    Added support for binary ops ("div", "max", "min", "+", etc.) to
    NarrowDataTypeRewriter. When the operands of a binary operation have
    different bit widths, the rewriter now downcasts them instead of
    upcasting (as it did before).
    
    Performance impact:
    avg_pool2d from quantized InceptionV3 with the shape [1, 8, 35, 35, 32]
    (NCHW32c layout), tuned with MetaScheduler on a Snapdragon 8 Gen 1:
    
    shape             | Before fix, ms | After fix, ms |   speedup   |
    ------------------|----------------|---------------|-------------|
    avg_pool2d, int32 |      6.67      |      4.41     | 1.51x (-34%)|
    -----------------------------------------------------------------|
---
 src/tir/transforms/narrow_datatype.cc              | 38 ++++++++++++++
 .../unittest/test_tir_transform_narrow_datatype.py | 61 ++++++++++++++++++++++
 2 files changed, 99 insertions(+)

diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc
index e9c57eb78e..ad8132521d 100644
--- a/src/tir/transforms/narrow_datatype.cc
+++ b/src/tir/transforms/narrow_datatype.cc
@@ -258,6 +258,44 @@ class NarrowDataTypeRewriter : public IndexDataTypeRewriter {
     return Parent::VisitExpr_(op);
   }
 
+#define TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC)             \
+  PrimExpr VisitExpr_(const OP* op) { /* FUNC rebuilds the node */        \
+    PrimExpr a = this->VisitExpr(op->a); /* pass 1: current mode */       \
+    PrimExpr b = this->VisitExpr(op->b);                                  \
+    if (op->a.same_as(a) && op->b.same_as(b) && a.dtype() == b.dtype()) { \
+      return GetRef<PrimExpr>(op); /* unchanged: reuse node */            \
+    } else {                                                              \
+      if (a.dtype() != b.dtype()) { /* dtypes differ: retry narrowed */   \
+        bool is_enabled = is_enabled_; /* save flag */                    \
+        is_enabled_ = true; /* enable for operands */                     \
+        PrimExpr lhs = this->VisitExpr(op->a); /* pass 2 on originals */  \
+        PrimExpr rhs = this->VisitExpr(op->b);                            \
+        is_enabled_ = is_enabled; /* restore flag */                      \
+        return FUNC(lhs, rhs);                                            \
+      } else {                                                            \
+        return FUNC(a, b); /* operands rewritten */                       \
+      }                                                                   \
+    }                                                                     \
+  }
+  // On an operand dtype mismatch, both sides are re-visited with is_enabled_ set (see macro above).
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+);
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-);
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*);
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div);
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(ModNode, truncmod);
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorDivNode, floordiv);
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorModNode, floormod);
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min);
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max);
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==);
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=);
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=);
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<);  // NOLINT(*)
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>);  // NOLINT(*)
+  TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=);
+
+#undef TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH  // only needed for the definitions above
+
  private:
   // the internal visitor to deduce the narrowed dtype
   DataTypeVisitor visitor_;
diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py
index 56b63c8893..b3b0c6f59b 100644
--- a/tests/python/unittest/test_tir_transform_narrow_datatype.py
+++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py
@@ -346,5 +346,66 @@ def test_block():
     tvm.ir.assert_structural_equal(after, expected_after)
 
 
+def test_avg_pool2d():  # an int64 avg_pool2d body should narrow back to int32
+    @T.prim_func  # before: loop vars, indices and divisor arithmetic in int64
+    def before(PSUM: T.Buffer((313600,), "int32"), PAVG: T.Buffer((313600,), "int32")):
+        for j in T.parallel(T.int64(0), T.int64(280)):
+            for i in T.serial(T.int64(0), T.int64(35)):
+                for vi in T.vectorized(T.int64(0), T.int64(32)):
+                    PAVG[(((j * T.int64(1120)) + (i * T.int64(32))) + vi)] = T.cast(
+                        T.Div(
+                            T.cast(PSUM[(((j * T.int64(1120)) + (i * T.int64(32))) + vi)], "int64"),
+                            T.max(
+                                (
+                                    (
+                                        (
+                                            T.min(
+                                                T.int64(1),
+                                                (T.int64(34) - T.floormod(j, T.int64(35))),
+                                            )
+                                            + T.int64(2)
+                                        )
+                                        - T.max(
+                                            (T.int64(1) - T.floormod(j, T.int64(35))), T.int64(0)
+                                        )
+                                    )
+                                    * (
+                                        (T.min(T.int64(1), (T.int64(34) - i)) + T.int64(2))
+                                        - T.max((T.int64(1) - i), T.int64(0))
+                                    )
+                                ),
+                                T.int64(1),
+                            ),
+                        ),
+                        "int32",
+                    )
+
+    @T.prim_func  # expected: all int32, no casts, no max(..., 1) around the divisor
+    def expected_after(PSUM: T.Buffer((313600,), "int32"), PAVG: T.Buffer((313600,), "int32")):
+        for j in T.parallel(T.int32(0), T.int32(280)):
+            for i in T.serial(T.int32(0), T.int32(35)):
+                for vi in T.vectorized(T.int32(0), T.int32(32)):
+                    PAVG[(((j * T.int32(1120)) + (i * T.int32(32))) + vi)] = T.Div(
+                        PSUM[(((j * T.int32(1120)) + (i * T.int32(32))) + vi)],
+                        (
+                            (
+                                (
+                                    T.min(T.int32(1), (T.int32(34) - T.floormod(j, T.int32(35))))
+                                    + T.int32(2)
+                                )
+                                - T.max((T.int32(1) - T.floormod(j, T.int32(35))), T.int32(0))
+                            )
+                            * (
+                                (T.min(T.int32(1), (T.int32(34) - i)) + T.int32(2))
+                                - T.max((T.int32(1) - i), T.int32(0))
+                            )
+                        ),
+                    )
+
+    after = tvm.tir.transform.NarrowDataType(32)(tvm.IRModule.from_expr(before))
+    after = tvm.tir.transform.Simplify()(after)  # presumably folds leftover casts/max
+    tvm.ir.assert_structural_equal(after["main"], expected_after)
+
+
 if __name__ == "__main__":
     tvm.testing.main()