You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/04/17 05:29:36 UTC
[tvm] branch main updated: [BugFix][TIR] Fix rfactor when RF block becomes spatial (#11031)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8d868f6bf3 [BugFix][TIR] Fix rfactor when RF block becomes spatial (#11031)
8d868f6bf3 is described below
commit 8d868f6bf3802dcf61cea2697ee81ffeae08b6b0
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Sat Apr 16 22:29:31 2022 -0700
[BugFix][TIR] Fix rfactor when RF block becomes spatial (#11031)
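Should fix #10899. When none of the rfactor block's iterators is a reduction iterator (i.e. the block becomes purely spatial), the block is now created without an init statement, and its body stores the combiner's right-hand side directly instead of applying the reducer to the previous buffer value.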
---
src/tir/schedule/primitive/reduction.cc | 34 ++++++++----
tests/python/unittest/test_tir_schedule_rfactor.py | 63 ++++++++++++++++++++--
2 files changed, 83 insertions(+), 14 deletions(-)
diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc
index fddf73da01..99ca03b6c9 100644
--- a/src/tir/schedule/primitive/reduction.cc
+++ b/src/tir/schedule/primitive/reduction.cc
@@ -578,7 +578,14 @@ class BaseBlockCreator {
for (int i = 0; i < n_block_iters_; ++i) {
CreateNormalIters(i);
}
- CreateReductionUpdate();
+ bool has_reduce_iter = false;
+ for (const IterVar& iter_var : iter_vars_) {
+ if (iter_var->iter_type == IterVarType::kCommReduce) {
+ has_reduce_iter = true;
+ break;
+ }
+ }
+ CreateReductionUpdate(has_reduce_iter);
CreateReadWriteRegions();
String new_block_name = old_block_realize_->block->name_hint;
@@ -587,15 +594,17 @@ class BaseBlockCreator {
new_block_name = new_block_name + "_rf";
predicate = old_block_realize_->predicate;
}
+ Optional<Stmt> init_block =
+ has_reduce_iter ? BufferStore(new_reduction_update_->buffer, reducer_->identity_element[0],
+ new_reduction_update_->indices)
+ : Optional<Stmt>(NullOpt);
new_block_ = Block(
/*iter_vars=*/iter_vars_,
/*reads=*/read_regions_,
/*writes=*/write_regions_,
/*name_hint=*/new_block_name,
/*body=*/new_reduction_update_,
- /*init=*/
- BufferStore(new_reduction_update_->buffer, reducer_->identity_element[0],
- new_reduction_update_->indices),
+ /*init=*/init_block,
/*alloc_buffers=*/{},
/*match_buffers=*/{},
/*annotations=*/old_block_realize_->block->annotations);
@@ -605,7 +614,7 @@ class BaseBlockCreator {
private:
virtual void CreateAdditionalIter() = 0;
virtual void CreateNormalIters(int idx) = 0;
- virtual void CreateReductionUpdate() = 0;
+ virtual void CreateReductionUpdate(bool has_reduce_iter) = 0;
virtual void CreateReadWriteRegions() = 0;
public:
@@ -734,14 +743,17 @@ class RFactorBlockCreator : public BaseBlockCreator {
var_map_.Set(old_iter->var, Substitute(old_binding, loop_var2block_binding_));
}
- void CreateReductionUpdate() final {
+ void CreateReductionUpdate(bool has_reduce_iter) final {
rf_buf_access_indices_ = old_reduction_update_->indices;
rf_buf_access_indices_.insert(rf_buf_access_indices_.begin() + factor_axis_,
additional_iter_->var);
- new_reduction_update_ = BufferStore(
- rf_buffer_,
- (*reducer_.get())({BufferLoad(rf_buffer_, rf_buf_access_indices_)}, {combiner_rhs_})[0],
- rf_buf_access_indices_);
+ PrimExpr rhs{nullptr};
+ if (has_reduce_iter) {
+ rhs = (*reducer_.get())({BufferLoad(rf_buffer_, rf_buf_access_indices_)}, {combiner_rhs_})[0];
+ } else {
+ rhs = combiner_rhs_;
+ }
+ new_reduction_update_ = BufferStore(rf_buffer_, rhs, rf_buf_access_indices_);
new_reduction_update_ = Downcast<BufferStore>(Substitute(new_reduction_update_, var_map_));
}
@@ -830,7 +842,7 @@ class WriteBackBlockCreator : public BaseBlockCreator {
}
}
- void CreateReductionUpdate() final {
+ void CreateReductionUpdate(bool has_reduce_iter) final {
wb_lhs_ = Downcast<BufferLoad>(Substitute(combiner_lhs_, var_map_));
wb_rhs_ =
Downcast<BufferLoad>(Substitute(BufferLoad(rf_buffer_, rf_buf_access_indices_), var_map_));
diff --git a/tests/python/unittest/test_tir_schedule_rfactor.py b/tests/python/unittest/test_tir_schedule_rfactor.py
index b2885404c5..a533668023 100644
--- a/tests/python/unittest/test_tir_schedule_rfactor.py
+++ b/tests/python/unittest/test_tir_schedule_rfactor.py
@@ -472,9 +472,7 @@ def rowsum_zero_dim_rfactor(a: T.handle, b: T.handle) -> None:
for i in range(128):
with T.block("B_rf"):
vi0 = T.axis.S(128, i)
- with T.init():
- B_rf[vi0] = 0.0
- B_rf[vi0] = B_rf[vi0] + A[vi0]
+ B_rf[vi0] = A[vi0]
for i in range(128):
with T.block("B"):
@@ -606,6 +604,56 @@ def multiple_reduction_blocks_rfactor(a: T.handle, f: T.handle) -> None:
F[fi, fj] = (F[fi, fj] + A[fi, fj, fk]) + E[fi, fj]
+@T.prim_func
+def rfactor_spatial_only(
+ A: T.Buffer[(1, 512, 7, 7), "float32"],
+ B: T.Buffer[(1, 512, 1, 1), "float32"],
+) -> None:
+ for _i0, i1, _i2, _i3, i4, _i5 in T.grid(1, 512, 1, 1, 49, 1):
+ with T.block("acc"):
+ ax0 = T.axis.spatial(1, 0)
+ ax1 = T.axis.spatial(512, i1)
+ ax2 = T.axis.spatial(1, 0)
+ ax3 = T.axis.spatial(1, 0)
+ rv0 = T.axis.reduce(7, i4 // 7)
+ rv1 = T.axis.reduce(7, i4 % 7)
+ T.reads(A[ax0, ax1, ax2 * 7 + rv0, ax3 * 7 + rv1])
+ T.writes(B[ax0, ax1, ax2, ax3])
+ with T.init():
+ B[ax0, ax1, ax2, ax3] = T.float32(0)
+ B[ax0, ax1, ax2, ax3] = (
+ B[ax0, ax1, ax2, ax3] + A[ax0, ax1, ax2 * 7 + rv0, ax3 * 7 + rv1]
+ )
+
+
+@T.prim_func
+def rfactor_spatial_only_after(
+ A: T.Buffer[(1, 512, 7, 7), "float32"],
+ B: T.Buffer[(1, 512, 1, 1), "float32"],
+) -> None:
+ # body
+ # with T.block("root")
+ B_rf = T.alloc_buffer([1, 512, 1, 1, 49], dtype="float32")
+ for _i0, i1, _i2, _i3, i4, _i5 in T.grid(1, 512, 1, 1, 49, 1):
+ with T.block("acc_rf"):
+ vi4 = T.axis.spatial(49, i4)
+ ax0 = T.axis.spatial(1, 0)
+ ax1 = T.axis.spatial(512, i1)
+ ax2 = T.axis.spatial(1, 0)
+ ax3 = T.axis.spatial(1, 0)
+ B_rf[ax0, ax1, ax2, ax3, vi4] = A[ax0, ax1, ax2 * 7 + vi4 // 7, ax3 * 7 + vi4 % 7]
+ for _i0, i1, _i2, _i3, i4, _i5 in T.grid(1, 512, 1, 1, 49, 1):
+ with T.block("acc"):
+ vi4 = T.axis.reduce(49, i4)
+ ax0 = T.axis.spatial(1, 0)
+ ax1 = T.axis.spatial(512, i1)
+ ax2 = T.axis.spatial(1, 0)
+ ax3 = T.axis.spatial(1, 0)
+ with T.init():
+ B[ax0, ax1, ax2, ax3] = T.float32(0)
+ B[ax0, ax1, ax2, ax3] = B[ax0, ax1, ax2, ax3] + B_rf[ax0, ax1, ax2, ax3, vi4]
+
+
# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
@@ -800,5 +848,14 @@ def test_reduction_rfactor_with_annotation():
verify_trace_roundtrip(s, mod=square_sum_with_annotation)
+def test_reduction_rfactor_spatial_only():
+ s = tir.Schedule(rfactor_spatial_only, debug_mask="all")
+ block = s.get_block(name="acc", func_name="main")
+ _, _, _, _, loop, _ = s.get_loops(block)
+ s.rfactor(loop=loop, factor_axis=4)
+ tvm.ir.assert_structural_equal(s.mod["main"], rfactor_spatial_only_after)
+ verify_trace_roundtrip(s, mod=rfactor_spatial_only)
+
+
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))