You are viewing a plain text version of this content. The canonical link for it was not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/04/17 05:29:36 UTC

[tvm] branch main updated: [BugFix][TIR] Fix rfactor when RF block becomes spatial (#11031)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8d868f6bf3 [BugFix][TIR] Fix rfactor when RF block becomes spatial (#11031)
8d868f6bf3 is described below

commit 8d868f6bf3802dcf61cea2697ee81ffeae08b6b0
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Sat Apr 16 22:29:31 2022 -0700

    [BugFix][TIR] Fix rfactor when RF block becomes spatial (#11031)
    
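    Should fix #10899.

    When none of the block iterators of the newly created block is a
    reduction iterator (kCommReduce), the block is now built without an
    init statement, and the rfactor block stores the combiner's right-hand
    side directly instead of combining it with the value previously held
    in the rfactor buffer.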
---
 src/tir/schedule/primitive/reduction.cc            | 34 ++++++++----
 tests/python/unittest/test_tir_schedule_rfactor.py | 63 ++++++++++++++++++++--
 2 files changed, 83 insertions(+), 14 deletions(-)

diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc
index fddf73da01..99ca03b6c9 100644
--- a/src/tir/schedule/primitive/reduction.cc
+++ b/src/tir/schedule/primitive/reduction.cc
@@ -578,7 +578,14 @@ class BaseBlockCreator {
     for (int i = 0; i < n_block_iters_; ++i) {
       CreateNormalIters(i);
     }
-    CreateReductionUpdate();
+    bool has_reduce_iter = false;
+    for (const IterVar& iter_var : iter_vars_) {
+      if (iter_var->iter_type == IterVarType::kCommReduce) {
+        has_reduce_iter = true;
+        break;
+      }
+    }
+    CreateReductionUpdate(has_reduce_iter);
     CreateReadWriteRegions();
 
     String new_block_name = old_block_realize_->block->name_hint;
@@ -587,15 +594,17 @@ class BaseBlockCreator {
       new_block_name = new_block_name + "_rf";
       predicate = old_block_realize_->predicate;
     }
+    Optional<Stmt> init_block =
+        has_reduce_iter ? BufferStore(new_reduction_update_->buffer, reducer_->identity_element[0],
+                                      new_reduction_update_->indices)
+                        : Optional<Stmt>(NullOpt);
     new_block_ = Block(
         /*iter_vars=*/iter_vars_,
         /*reads=*/read_regions_,
         /*writes=*/write_regions_,
         /*name_hint=*/new_block_name,
         /*body=*/new_reduction_update_,
-        /*init=*/
-        BufferStore(new_reduction_update_->buffer, reducer_->identity_element[0],
-                    new_reduction_update_->indices),
+        /*init=*/init_block,
         /*alloc_buffers=*/{},
         /*match_buffers=*/{},
         /*annotations=*/old_block_realize_->block->annotations);
@@ -605,7 +614,7 @@ class BaseBlockCreator {
  private:
   virtual void CreateAdditionalIter() = 0;
   virtual void CreateNormalIters(int idx) = 0;
-  virtual void CreateReductionUpdate() = 0;
+  virtual void CreateReductionUpdate(bool has_reduce_iter) = 0;
   virtual void CreateReadWriteRegions() = 0;
 
  public:
@@ -734,14 +743,17 @@ class RFactorBlockCreator : public BaseBlockCreator {
     var_map_.Set(old_iter->var, Substitute(old_binding, loop_var2block_binding_));
   }
 
-  void CreateReductionUpdate() final {
+  void CreateReductionUpdate(bool has_reduce_iter) final {
     rf_buf_access_indices_ = old_reduction_update_->indices;
     rf_buf_access_indices_.insert(rf_buf_access_indices_.begin() + factor_axis_,
                                   additional_iter_->var);
-    new_reduction_update_ = BufferStore(
-        rf_buffer_,
-        (*reducer_.get())({BufferLoad(rf_buffer_, rf_buf_access_indices_)}, {combiner_rhs_})[0],
-        rf_buf_access_indices_);
+    PrimExpr rhs{nullptr};
+    if (has_reduce_iter) {
+      rhs = (*reducer_.get())({BufferLoad(rf_buffer_, rf_buf_access_indices_)}, {combiner_rhs_})[0];
+    } else {
+      rhs = combiner_rhs_;
+    }
+    new_reduction_update_ = BufferStore(rf_buffer_, rhs, rf_buf_access_indices_);
     new_reduction_update_ = Downcast<BufferStore>(Substitute(new_reduction_update_, var_map_));
   }
 
@@ -830,7 +842,7 @@ class WriteBackBlockCreator : public BaseBlockCreator {
     }
   }
 
-  void CreateReductionUpdate() final {
+  void CreateReductionUpdate(bool has_reduce_iter) final {
     wb_lhs_ = Downcast<BufferLoad>(Substitute(combiner_lhs_, var_map_));
     wb_rhs_ =
         Downcast<BufferLoad>(Substitute(BufferLoad(rf_buffer_, rf_buf_access_indices_), var_map_));
diff --git a/tests/python/unittest/test_tir_schedule_rfactor.py b/tests/python/unittest/test_tir_schedule_rfactor.py
index b2885404c5..a533668023 100644
--- a/tests/python/unittest/test_tir_schedule_rfactor.py
+++ b/tests/python/unittest/test_tir_schedule_rfactor.py
@@ -472,9 +472,7 @@ def rowsum_zero_dim_rfactor(a: T.handle, b: T.handle) -> None:
     for i in range(128):
         with T.block("B_rf"):
             vi0 = T.axis.S(128, i)
-            with T.init():
-                B_rf[vi0] = 0.0
-            B_rf[vi0] = B_rf[vi0] + A[vi0]
+            B_rf[vi0] = A[vi0]
 
     for i in range(128):
         with T.block("B"):
@@ -606,6 +604,56 @@ def multiple_reduction_blocks_rfactor(a: T.handle, f: T.handle) -> None:
                     F[fi, fj] = (F[fi, fj] + A[fi, fj, fk]) + E[fi, fj]
 
 
+@T.prim_func
+def rfactor_spatial_only(
+    A: T.Buffer[(1, 512, 7, 7), "float32"],
+    B: T.Buffer[(1, 512, 1, 1), "float32"],
+) -> None:
+    for _i0, i1, _i2, _i3, i4, _i5 in T.grid(1, 512, 1, 1, 49, 1):
+        with T.block("acc"):
+            ax0 = T.axis.spatial(1, 0)
+            ax1 = T.axis.spatial(512, i1)
+            ax2 = T.axis.spatial(1, 0)
+            ax3 = T.axis.spatial(1, 0)
+            rv0 = T.axis.reduce(7, i4 // 7)
+            rv1 = T.axis.reduce(7, i4 % 7)
+            T.reads(A[ax0, ax1, ax2 * 7 + rv0, ax3 * 7 + rv1])
+            T.writes(B[ax0, ax1, ax2, ax3])
+            with T.init():
+                B[ax0, ax1, ax2, ax3] = T.float32(0)
+            B[ax0, ax1, ax2, ax3] = (
+                B[ax0, ax1, ax2, ax3] + A[ax0, ax1, ax2 * 7 + rv0, ax3 * 7 + rv1]
+            )
+
+
+@T.prim_func
+def rfactor_spatial_only_after(
+    A: T.Buffer[(1, 512, 7, 7), "float32"],
+    B: T.Buffer[(1, 512, 1, 1), "float32"],
+) -> None:
+    # body
+    # with T.block("root")
+    B_rf = T.alloc_buffer([1, 512, 1, 1, 49], dtype="float32")
+    for _i0, i1, _i2, _i3, i4, _i5 in T.grid(1, 512, 1, 1, 49, 1):
+        with T.block("acc_rf"):
+            vi4 = T.axis.spatial(49, i4)
+            ax0 = T.axis.spatial(1, 0)
+            ax1 = T.axis.spatial(512, i1)
+            ax2 = T.axis.spatial(1, 0)
+            ax3 = T.axis.spatial(1, 0)
+            B_rf[ax0, ax1, ax2, ax3, vi4] = A[ax0, ax1, ax2 * 7 + vi4 // 7, ax3 * 7 + vi4 % 7]
+    for _i0, i1, _i2, _i3, i4, _i5 in T.grid(1, 512, 1, 1, 49, 1):
+        with T.block("acc"):
+            vi4 = T.axis.reduce(49, i4)
+            ax0 = T.axis.spatial(1, 0)
+            ax1 = T.axis.spatial(512, i1)
+            ax2 = T.axis.spatial(1, 0)
+            ax3 = T.axis.spatial(1, 0)
+            with T.init():
+                B[ax0, ax1, ax2, ax3] = T.float32(0)
+            B[ax0, ax1, ax2, ax3] = B[ax0, ax1, ax2, ax3] + B_rf[ax0, ax1, ax2, ax3, vi4]
+
+
 # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
 
 
@@ -800,5 +848,14 @@ def test_reduction_rfactor_with_annotation():
     verify_trace_roundtrip(s, mod=square_sum_with_annotation)
 
 
+def test_reduction_rfactor_spatial_only():
+    s = tir.Schedule(rfactor_spatial_only, debug_mask="all")
+    block = s.get_block(name="acc", func_name="main")
+    _, _, _, _, loop, _ = s.get_loops(block)
+    s.rfactor(loop=loop, factor_axis=4)
+    tvm.ir.assert_structural_equal(s.mod["main"], rfactor_spatial_only_after)
+    verify_trace_roundtrip(s, mod=rfactor_spatial_only)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))