You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/04/17 02:09:12 UTC

[GitHub] [tvm] junrushao1994 opened a new pull request, #11031: [BugFix][TIR] Fix rfactor when RF block becomes spatial

junrushao1994 opened a new pull request, #11031:
URL: https://github.com/apache/tvm/pull/11031

   Should fix #10899


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] junrushao1994 commented on pull request #11031: [BugFix][TIR] Fix rfactor when RF block becomes spatial

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on PR #11031:
URL: https://github.com/apache/tvm/pull/11031#issuecomment-1100787778

   CC @zxybazh @MasterJH5574 


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] junrushao1994 commented on a diff in pull request #11031: [BugFix][TIR] Fix rfactor when RF block becomes spatial

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on code in PR #11031:
URL: https://github.com/apache/tvm/pull/11031#discussion_r851694486


##########
src/tir/schedule/primitive/reduction.cc:
##########
@@ -734,14 +743,19 @@ class RFactorBlockCreator : public BaseBlockCreator {
     var_map_.Set(old_iter->var, Substitute(old_binding, loop_var2block_binding_));
   }
 
-  void CreateReductionUpdate() final {
+  void CreateReductionUpdate(bool has_reduce_iter) final {
     rf_buf_access_indices_ = old_reduction_update_->indices;
     rf_buf_access_indices_.insert(rf_buf_access_indices_.begin() + factor_axis_,
                                   additional_iter_->var);
-    new_reduction_update_ = BufferStore(
-        rf_buffer_,
-        (*reducer_.get())({BufferLoad(rf_buffer_, rf_buf_access_indices_)}, {combiner_rhs_})[0],
-        rf_buf_access_indices_);
+    PrimExpr rhs{nullptr};
+    if (has_reduce_iter) {
+      rhs = (*reducer_.get())({BufferLoad(rf_buffer_, rf_buf_access_indices_)}, {combiner_rhs_})[0];
+    } else {
+      arith::Analyzer analyzer;
+      rhs = (*reducer_.get())(reducer_->identity_element, {combiner_rhs_})[0];
+      rhs = analyzer.Simplify(rhs);

Review Comment:
   Good point! I wasn't aware!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] MasterJH5574 commented on a diff in pull request #11031: [BugFix][TIR] Fix rfactor when RF block becomes spatial

Posted by GitBox <gi...@apache.org>.
MasterJH5574 commented on code in PR #11031:
URL: https://github.com/apache/tvm/pull/11031#discussion_r851693583


##########
tests/python/unittest/test_tir_schedule_rfactor.py:
##########
@@ -606,6 +604,68 @@ def multiple_reduction_blocks_rfactor(a: T.handle, f: T.handle) -> None:
                     F[fi, fj] = (F[fi, fj] + A[fi, fj, fk]) + E[fi, fj]
 
 
+@T.prim_func
+def rfactor_spatial_only(
+    A: T.Buffer[(1, 512, 7, 7), "float32"],
+    B: T.Buffer[(1, 512, 1, 1), "float32"],
+) -> None:
+    acc = T.alloc_buffer([1, 512, 1, 1], dtype="float32")
+    for i0, i1, i2, i3, i4, _i5 in T.grid(1, 512, 1, 1, 49, 1):
+        with T.block("acc"):
+            ax0 = T.axis.spatial(1, 0)
+            ax1 = T.axis.spatial(512, i1)
+            ax2 = T.axis.spatial(1, 0)
+            ax3 = T.axis.spatial(1, 0)
+            rv0 = T.axis.reduce(7, i4 // 7)
+            rv1 = T.axis.reduce(7, i4 % 7)
+            T.reads(A[ax0, ax1, ax2 * 7 + rv0, ax3 * 7 + rv1])
+            T.writes(acc[ax0, ax1, ax2, ax3])
+            with T.init():
+                acc[ax0, ax1, ax2, ax3] = T.float32(0)
+            acc[ax0, ax1, ax2, ax3] = (
+                acc[ax0, ax1, ax2, ax3] + A[ax0, ax1, ax2 * 7 + rv0, ax3 * 7 + rv1]
+            )
+    for i0, i1, i2, i3 in T.grid(1, 512, 1, 1):
+        with T.block("B"):
+            ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+            T.reads(acc[ax0, ax1, ax2, ax3])
+            T.writes(B[ax0, ax1, ax2, ax3])
+            B[ax0, ax1, ax2, ax3] = acc[ax0, ax1, ax2, ax3]

Review Comment:
   The block looks redundant to me 🤔 since it doesn’t relate to the rfactor operation. Is there any specific reason we want to keep it here?



##########
src/tir/schedule/primitive/reduction.cc:
##########
@@ -734,14 +743,19 @@ class RFactorBlockCreator : public BaseBlockCreator {
     var_map_.Set(old_iter->var, Substitute(old_binding, loop_var2block_binding_));
   }
 
-  void CreateReductionUpdate() final {
+  void CreateReductionUpdate(bool has_reduce_iter) final {
     rf_buf_access_indices_ = old_reduction_update_->indices;
     rf_buf_access_indices_.insert(rf_buf_access_indices_.begin() + factor_axis_,
                                   additional_iter_->var);
-    new_reduction_update_ = BufferStore(
-        rf_buffer_,
-        (*reducer_.get())({BufferLoad(rf_buffer_, rf_buf_access_indices_)}, {combiner_rhs_})[0],
-        rf_buf_access_indices_);
+    PrimExpr rhs{nullptr};
+    if (has_reduce_iter) {
+      rhs = (*reducer_.get())({BufferLoad(rf_buffer_, rf_buf_access_indices_)}, {combiner_rhs_})[0];
+    } else {
+      arith::Analyzer analyzer;
+      rhs = (*reducer_.get())(reducer_->identity_element, {combiner_rhs_})[0];
+      rhs = analyzer.Simplify(rhs);

Review Comment:
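   I think it’s okay to just set `rhs = combiner_rhs_` here, since combining the reducer's identity element with `combiner_rhs_` simply yields `combiner_rhs_` :eyes: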



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] junrushao1994 commented on a diff in pull request #11031: [BugFix][TIR] Fix rfactor when RF block becomes spatial

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on code in PR #11031:
URL: https://github.com/apache/tvm/pull/11031#discussion_r851695280


##########
tests/python/unittest/test_tir_schedule_rfactor.py:
##########
@@ -606,6 +604,68 @@ def multiple_reduction_blocks_rfactor(a: T.handle, f: T.handle) -> None:
                     F[fi, fj] = (F[fi, fj] + A[fi, fj, fk]) + E[fi, fj]
 
 
+@T.prim_func
+def rfactor_spatial_only(
+    A: T.Buffer[(1, 512, 7, 7), "float32"],
+    B: T.Buffer[(1, 512, 1, 1), "float32"],
+) -> None:
+    acc = T.alloc_buffer([1, 512, 1, 1], dtype="float32")
+    for i0, i1, i2, i3, i4, _i5 in T.grid(1, 512, 1, 1, 49, 1):
+        with T.block("acc"):
+            ax0 = T.axis.spatial(1, 0)
+            ax1 = T.axis.spatial(512, i1)
+            ax2 = T.axis.spatial(1, 0)
+            ax3 = T.axis.spatial(1, 0)
+            rv0 = T.axis.reduce(7, i4 // 7)
+            rv1 = T.axis.reduce(7, i4 % 7)
+            T.reads(A[ax0, ax1, ax2 * 7 + rv0, ax3 * 7 + rv1])
+            T.writes(acc[ax0, ax1, ax2, ax3])
+            with T.init():
+                acc[ax0, ax1, ax2, ax3] = T.float32(0)
+            acc[ax0, ax1, ax2, ax3] = (
+                acc[ax0, ax1, ax2, ax3] + A[ax0, ax1, ax2 * 7 + rv0, ax3 * 7 + rv1]
+            )
+    for i0, i1, i2, i3 in T.grid(1, 512, 1, 1):
+        with T.block("B"):
+            ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+            T.reads(acc[ax0, ax1, ax2, ax3])
+            T.writes(B[ax0, ax1, ax2, ax3])
+            B[ax0, ax1, ax2, ax3] = acc[ax0, ax1, ax2, ax3]

Review Comment:
   Oh I just copied from @zxybazh's example. Removed!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] junrushao1994 merged pull request #11031: [BugFix][TIR] Fix rfactor when RF block becomes spatial

Posted by GitBox <gi...@apache.org>.
junrushao1994 merged PR #11031:
URL: https://github.com/apache/tvm/pull/11031


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org