Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/04/26 13:55:07 UTC

[GitHub] [tvm] mbaret opened a new pull request #7925: Add a 'rolling_buffer' scheduling primitive

mbaret opened a new pull request #7925:
URL: https://github.com/apache/tvm/pull/7925


   Support declaring a tensor as a rolling buffer so that it will be realized as a circular buffer without the need to recompute elements. For further detail you can take a look at this RFC: [TBA]
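
As a rough illustration of what "realized as a circular buffer" means here, the sketch below folds a logical row index onto a fixed number of slots with modulo arithmetic. It is plain Python rather than TVM code, and the names `fold_index` and `window_slots` are made up for this example.

```python
def fold_index(row, extent):
    """Map a logical row to its slot in a circular buffer of `extent` rows."""
    return row % extent


def window_slots(start, extent):
    """Slots occupied by the `extent` consecutive rows beginning at `start`."""
    return [fold_index(start + j, extent) for j in range(extent)]


# A window of 4 rows starting at row 6 wraps around the end of the buffer.
print(window_slots(6, 4))  # [2, 3, 0, 1]
```

Any `extent` consecutive rows land in distinct slots, so a window that slides forward can reuse the storage of rows it no longer needs.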
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] mbaret commented on pull request #7925: Add a 'rolling_buffer' scheduling primitive

Posted by GitBox <gi...@apache.org>.
mbaret commented on pull request #7925:
URL: https://github.com/apache/tvm/pull/7925#issuecomment-828544745


   @junrushao1994 @merrymercy could you take a look at this?





[GitHub] [tvm] mbaret commented on a change in pull request #7925: Add a 'rolling_buffer' scheduling primitive

Posted by GitBox <gi...@apache.org>.
mbaret commented on a change in pull request #7925:
URL: https://github.com/apache/tvm/pull/7925#discussion_r634334158



##########
File path: python/tvm/tir/transform/inject_rolling_buffer.py
##########
@@ -0,0 +1,214 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Inject rolling buffers through a TIR transformation."""
+# pylint: disable=invalid-name,unused-argument,inconsistent-return-statements
+from collections import defaultdict
+
+import tvm
+from tvm import arith
+
+
+def InjectRollingBuffer():

Review comment:
       The TIR snippets needed to illustrate this particular pass would be very large. There is a before/after TIR example in the tests, which can hopefully serve as a reference. In lieu of TIR in the documentation, I've improved the documentation somewhat and included a link to the RFC, which has a far more detailed explanation of the pass along with diagrams. Let me know if you think this is sufficient.







[GitHub] [tvm] mbaret commented on a change in pull request #7925: Add a 'rolling_buffer' scheduling primitive

Posted by GitBox <gi...@apache.org>.
mbaret commented on a change in pull request #7925:
URL: https://github.com/apache/tvm/pull/7925#discussion_r634332852



##########
File path: python/tvm/tir/transform/inject_rolling_buffer.py
##########
@@ -0,0 +1,214 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Inject rolling buffers through a TIR transformation."""
+# pylint: disable=invalid-name,unused-argument,inconsistent-return-statements
+from collections import defaultdict
+
+import tvm
+from tvm import arith
+
+
+def InjectRollingBuffer():
+    """Inject rolling buffer statements.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The pass
+    """
+    buffer_to_attrs = defaultdict(list)
+    rolling_buffers = set()
+    rolling_buffer_to_info = dict()
+    iter_vars = list()
+    hoist_buffer_to_for = defaultdict(list)
+
+    class RollingBufferInfo:
+        def __init__(self, rolling_axis, rolling_extent, axis_overlaps, axis_iter_vars):
+            self.rolling_axis = rolling_axis
+            self.rolling_extent = rolling_extent
+            self.axis_overlaps = axis_overlaps
+            self.axis_iter_vars = axis_iter_vars
+
+    def _pre_visit(stmt):
+        if isinstance(stmt, tvm.tir.For):
+            # Manage the stack of iter_vars
+            iter_vars.append(stmt)
+
+        elif isinstance(stmt, tvm.tir.AttrStmt):
+            if isinstance(stmt.node, tvm.tir.Buffer):
+                if stmt.attr_key == "rolling_buffer" and stmt.value.value:
+                    # If the attribute is indicating that a buffer should be a rolling
+                    # buffer, then update the rolling_buffers set to include the buffer
+                    rolling_buffers.add(stmt.node)
+                # Keep a dictionary associating attribute statements with the buffers
+                # they reference. We'll need this if the buffer gets hoisted and we
+                # need to hoist all of its attributes at the same time.
+                buffer_to_attrs[stmt.node].append(stmt)
+
+        elif isinstance(stmt, tvm.tir.BufferRealize):
+            if stmt.buffer in rolling_buffers:
+                # If a BufferRealize has been identified as needing to be made into
+                # a rolling buffer, begin the analysis...
+                bound_iter_vars = []
+                bound_strides = []
+                bound_overlaps = []
+                # We use the bound information of the BufferRealize to calculate
+                # how we can legally roll
+                for bound in stmt.bounds:
+                    # If the bound is an int, we can't roll over it
+                    if isinstance(bound.min, tvm.tir.IntImm):
+                        iter_var = None
+                        stride = 0
+                    # If the bound is just a Var, that implies the stride is 1
+                    elif isinstance(bound.min, tvm.tir.Var):
+                        iter_var = bound.min
+                        stride = 1
+                    # Otherwise, it's the iter var multiplied by the stride
+                    # If not we're in unknown behaviour, so assert
+                    else:
+                        assert isinstance(bound.min, tvm.tir.Mul)

Review comment:
       I couldn't think of an especially clean way to handle this. My primary reference for this is `inject_double_buffer.cc` and there the behaviour is to CHECK assert when the pass fails. I'm not sure exceptions are necessarily a better way to go here because that suggests to a degree that they can be handled. For now I've compromised by keeping the asserts but adding an assert message to at least explain what went wrong. Let me know what you think.
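
To make the "asserts with a message" compromise concrete, here is a small sketch of the stride analysis written against plain Python values instead of TIR nodes. The tuple encoding (`("mul", var, k)`, `("floordiv", inner, d)`) and the helper names `parse_bound_min` and `axis_overlap` are assumptions made for this illustration only; the real pass inspects `tvm.tir` expression nodes.

```python
import math


def parse_bound_min(expr):
    """Return (iter_var, stride) for the lower bound of one buffer axis.

    `expr` is a hypothetical stand-in for a TIR expression:
      int                     -> constant bound, cannot be rolled over
      str                     -> a bare loop variable, stride 1
      ("mul", var, k)         -> var * k, stride k
      ("floordiv", inner, d)  -> fractional stride, inner stride divided by d
    """
    divisor = 1
    if isinstance(expr, tuple) and expr[0] == "floordiv":
        _, expr, divisor = expr
    if isinstance(expr, int):
        return None, 0
    if isinstance(expr, str):
        iter_var, stride = expr, 1
    else:
        msg = "Rolling buffer injection failed: the buffer striding is unsupported"
        # Anything other than var * constant is outside what the analysis handles.
        assert isinstance(expr, tuple) and len(expr) == 3 and expr[0] == "mul", msg
        assert isinstance(expr[1], str), msg
        assert isinstance(expr[2], int), msg
        iter_var, stride = expr[1], expr[2]
    return iter_var, math.ceil(stride / divisor)


def axis_overlap(extent, iter_var, stride):
    """Rows shared by consecutive windows; zero for an axis that is not rolled."""
    return extent - stride if iter_var is not None else 0
```

With this encoding, `parse_bound_min(("mul", "hh.outer", 4))` yields a stride of 4, while an unsupported form such as `("add", "i", 1)` fails with the explanatory message rather than a bare `AssertionError`.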







[GitHub] [tvm] junrushao1994 commented on pull request #7925: Add a 'rolling_buffer' scheduling primitive

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #7925:
URL: https://github.com/apache/tvm/pull/7925#issuecomment-830468003


   CC @Hzfengsy @spectrometerHBH @jinhongyii @MasterJH5574





[GitHub] [tvm] mbaret commented on pull request #7925: Add a 'rolling_buffer' scheduling primitive

Posted by GitBox <gi...@apache.org>.
mbaret commented on pull request #7925:
URL: https://github.com/apache/tvm/pull/7925#issuecomment-860773309


   ping @manupa-arm, could you please take another look at this patch?





[GitHub] [tvm] mbaret commented on pull request #7925: Add a 'rolling_buffer' scheduling primitive

Posted by GitBox <gi...@apache.org>.
mbaret commented on pull request #7925:
URL: https://github.com/apache/tvm/pull/7925#issuecomment-866866452


   It would appear the conflict is quite non-trivial (the removal of the Python driver). I don't especially want to work around that by registering my pass in the global registry and then calling it from the C++ API, since that would pollute the C++ API with a Python dependency. Given this problem, I shall close this PR for now and rewrite the pass in C++ when I find the time.





[GitHub] [tvm] mbaret commented on a change in pull request #7925: Add a 'rolling_buffer' scheduling primitive

Posted by GitBox <gi...@apache.org>.
mbaret commented on a change in pull request #7925:
URL: https://github.com/apache/tvm/pull/7925#discussion_r646729876



##########
File path: python/tvm/tir/transform/inject_rolling_buffer.py
##########
@@ -0,0 +1,238 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Inject rolling buffers through a TIR transformation."""
+# pylint: disable=invalid-name,unused-argument,inconsistent-return-statements
+from collections import defaultdict, namedtuple
+import math
+
+import tvm
+from tvm import arith
+
+
+def InjectRollingBuffer():
+    """Inject rolling buffer statements.
+
+    Rolling buffers are buffers where one of the dimensions has been made into
+    a circular buffer. Two optimizations are implemented in order to accomplish
+    this: sliding window and storage folding. In particular, the sliding window
+    optimization is applied to the entire buffer (to avoid recomputing elements)
+    and storage folding is then applied to just the rolling dimension.
+
+    Rolling buffers must be inside a loop with only part of the buffer used per
+    iteration. The outermost axis will be rolled over.
+
+    For more information, see the RFC:
+    https://discuss.tvm.apache.org/t/rfc-introducing-a-rolling-buffer-scheduling-primitive/9836
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The pass
+    """
+    buffer_to_attrs = defaultdict(list)
+    rolling_buffers = set()
+    rolling_buffer_to_info = dict()
+    iter_vars = list()
+    hoist_buffer_to_for = defaultdict(list)
+
+    RollingBufferInfo = namedtuple(
+        "RollingBufferInfo", ["rolling_axis", "rolling_extent", "axis_overlaps", "axis_iter_vars"]
+    )
+
+    def _pre_visit(stmt):
+        if isinstance(stmt, tvm.tir.For):
+            # Manage the stack of iter_vars
+            iter_vars.append(stmt)
+
+        elif isinstance(stmt, tvm.tir.AttrStmt):
+            if isinstance(stmt.node, tvm.tir.Buffer):
+                if stmt.attr_key == "rolling_buffer_scope" and stmt.value.value:
+                    # If the attribute is indicating that a buffer should be a rolling
+                    # buffer, then update the rolling_buffers set to include the buffer
+                    rolling_buffers.add(stmt.node)
+                # Keep a dictionary associating attribute statements with the buffers
+                # they reference. We'll need this if the buffer gets hoisted and we
+                # need to hoist all of its attributes at the same time.
+                buffer_to_attrs[stmt.node].append(stmt)
+
+        elif isinstance(stmt, tvm.tir.BufferRealize):
+            if stmt.buffer in rolling_buffers:
+                # If a BufferRealize has been identified as needing to be made into
+                # a rolling buffer, begin the analysis...
+                bound_iter_vars = []
+                bound_overlaps = []
+                # We use the bound information of the BufferRealize to calculate
+                # how we can legally roll
+                for bound in stmt.bounds:
+                    divisor = 1
+                    # Handle the case of fractional strides
+                    # They take this form: floordiv(hh.outer, 2)
+                    # Strip the floordiv and keep track of the divisor
+                    if isinstance(bound.min, tvm.tir.FloorDiv):
+                        divisor = bound.min.b.value
+                        bound.min = bound.min.a
+                    # If the bound is an int, we can't roll over it
+                    if isinstance(bound.min, tvm.tir.IntImm):
+                        iter_var = None
+                        stride = 0
+                    # If the bound is just a Var, that implies the stride is 1
+                    elif isinstance(bound.min, tvm.tir.Var):
+                        iter_var = bound.min
+                        stride = 1
+                    # Otherwise, it's the iter var multiplied by the stride
+                    # If not we're in unknown behaviour, so assert
+                    else:
+                        assert isinstance(
+                            bound.min, tvm.tir.Mul
+                        ), "Rolling buffer injection failed: the buffer striding is unsupported"
+                        assert isinstance(
+                            bound.min.a, tvm.tir.Var
+                        ), "Rolling buffer injection failed: the buffer striding is unsupported"
+                        assert isinstance(
+                            bound.min.b, tvm.tir.IntImm
+                        ), "Rolling buffer injection failed: the buffer striding is unsupported"
+                        iter_var = bound.min.a
+                        stride = bound.min.b.value
+                    stride = math.ceil(stride / divisor)
+                    bound_iter_vars.append(iter_var)
+                    if iter_var is not None:
+                        bound_overlaps.append(bound.extent.value - stride)
+                    else:
+                        bound_overlaps.append(0)
+
+                # Pick the outermost iter_var that's mentioned in the bounds
+                # to be the rolling axis
+                roll_iter_var = None
+                roll_axis = -1
+                for loop in iter_vars:
+                    iter_var = loop.loop_var
+                    if iter_var in bound_iter_vars:
+                        roll_iter_var = iter_var
+                        roll_axis = bound_iter_vars.index(iter_var)
+                        break
+
+                # We must have found an axis to roll over
+                assert (
+                    roll_iter_var is not None
+                ), "Rolling buffer injection failed: no rolling axis found"
+                assert roll_axis != -1, "Rolling buffer injection failed: no rolling axis found"
+                rolling_buffer_info = RollingBufferInfo(
+                    roll_axis, stmt.bounds[roll_axis].extent.value, bound_overlaps, bound_iter_vars
+                )
+                rolling_buffer_to_info[stmt.buffer] = rolling_buffer_info
+                new_bounds = []
+                for i, extent in enumerate(stmt.buffer.shape):
+                    if i == rolling_buffer_info.rolling_axis:
+                        new_bounds.append(tvm.ir.Range(rolling_buffer_info.rolling_extent))
+                    else:
+                        new_bounds.append(tvm.ir.Range(extent))
+                new_realize = tvm.tir.BufferRealize(
+                    stmt.buffer, new_bounds, stmt.condition, stmt.body, stmt.span
+                )
+                hoist_buffer_to_for[iter_var].append(new_realize)
+
+    def _post_visit(stmt):
+        if isinstance(stmt, tvm.tir.For):
+            # Manage the stack of iter_vars
+            iter_vars.pop()
+            # If the loop corresponds to an iter_var that needs a BufferRealize
+            # hoisting to its scope, perform the hoisting
+            if stmt.loop_var in hoist_buffer_to_for:
+                body = stmt
+                for realize in hoist_buffer_to_for[stmt.loop_var]:
+                    attrs = buffer_to_attrs[realize.buffer]
+                    new_realize = tvm.tir.BufferRealize(
+                        realize.buffer, realize.bounds, realize.condition, body, realize.span
+                    )
+                    # The attributes attached to the BufferRealize need hoisting too
+                    for attr in attrs:
+                        if attr.attr_key == "rolling_buffer_scope":
+                            continue
+                        new_realize = tvm.tir.AttrStmt(
+                            attr.node, attr.attr_key, attr.value, new_realize, attr.span
+                        )
+                    body = new_realize
+                return body
+        elif isinstance(stmt, tvm.tir.AttrStmt):
+            if stmt.node in rolling_buffers:
+                # Remove the attribute statements attached to rolling buffers
+                # because they will have been hoisted to the relevant rolling
+                # scope
+                return stmt.body
+        elif isinstance(stmt, tvm.tir.BufferRealize):
+            if stmt.buffer in rolling_buffers:
+                # Remove the original BufferRealize for rolling buffers
+                # because they will have been hoisted to the relevant rolling
+                # scope
+                return stmt.body
+        elif isinstance(stmt, tvm.tir.BufferStore):
+            if stmt.buffer in rolling_buffer_to_info:
+                rolling_buffer_info = rolling_buffer_to_info[stmt.buffer]
+                indices = []
+                # First modify the access indices to use modulo arithmetic
+                # for the rolling axis
+                for i, index in enumerate(stmt.indices):
+                    if i == rolling_buffer_info.rolling_axis:
+                        indices.append(tvm.tir.FloorMod(index, rolling_buffer_info.rolling_extent))
+                    else:
+                        indices.append(index)
+                buffer_store = tvm.tir.BufferStore(stmt.buffer, stmt.value, indices, stmt.span)
+                # Then wrap the BufferStores in some Ifs to avoid recomputing elements
+                for i, iter_var in enumerate(rolling_buffer_info.axis_iter_vars):
+                    if iter_var is not None and rolling_buffer_info.axis_overlaps[i] > 0:
+                        dmap = {iter_var: arith.IntervalSet(0, 0)}
+                        term_2 = arith.Analyzer().int_set(stmt.indices[i], dmap).min_value
+                        buffer_store = tvm.tir.IfThenElse(
+                            tvm.tir.Or(
+                                iter_var < 1, term_2 >= rolling_buffer_info.axis_overlaps[i]
+                            ),
+                            buffer_store,
+                            None,
+                        )
+                return buffer_store
+        elif isinstance(stmt, tvm.tir.BufferLoad):
+            if stmt.buffer in rolling_buffer_to_info:
+                rolling_buffer_info = rolling_buffer_to_info[stmt.buffer]
+                indices = []
+                # Modify the access indices to use modulo arithmetic
+                # for the rolling axis
+                for i, index in enumerate(stmt.indices):
+                    if i == rolling_buffer_info.rolling_axis:
+                        indices.append(tvm.tir.FloorMod(index, rolling_buffer_info.rolling_extent))
+                    else:
+                        indices.append(index)
+                return tvm.tir.BufferLoad(stmt.buffer, indices, stmt.span)
+
+    def _ftransform(f, mod, ctx):
+        return f.with_body(
+            tvm.tir.stmt_functor.ir_transform(
+                f.body,
+                _pre_visit,
+                _post_visit,
+                [
+                    "tir.AttrStmt",
+                    "tir.BufferRealize",
+                    "tir.For",
+                    "tir.BufferStore",
+                    "tir.BufferLoad",
+                ],
+            )
+        )
+
+    return tvm.tir.transform.prim_func_pass(
+        _ftransform, opt_level=0, name="tir.InjectRollingBuffer"
+    )

Review comment:
       Yes, I'd definitely consider doing that rewrite at some point. I was actually waiting to see how the new 'scheduling passes' would look for TensorIR so that I could potentially follow any such pattern there. Unfortunately I don't currently have the time, so I'd appreciate it if this could be taken in with the TODO. Perhaps we can revisit once TensorIR is complete?
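
For readers who want a feel for what the pass produces without reading TIR, the following is a hedged plain-Python simulation of the two rewrites applied to stores and loads: the modulo index on the rolling axis and the `iter_var < 1 or offset >= overlap` guard that skips rows already computed by the previous iteration. The function `run_rolling` and its parameters are invented for this sketch, and it assumes `stride <= extent` so that the overlap is non-negative.

```python
def run_rolling(n_iters, extent, stride, compute):
    """Simulate a producer writing a sliding window into a rolling buffer.

    Outer iteration i needs rows [i * stride, i * stride + extent).
    Returns the number of rows computed and the windows the consumer read.
    """
    overlap = extent - stride
    buf = [None] * extent
    computed = 0
    windows = []
    for i in range(n_iters):
        for j in range(extent):
            row = i * stride + j
            # Guard: the first iteration fills everything; later iterations
            # skip the rows that overlap with the previous window.
            if i < 1 or j >= overlap:
                buf[row % extent] = compute(row)
                computed += 1
        # The consumer reads the whole window through the same modulo indexing.
        windows.append([buf[(i * stride + j) % extent] for j in range(extent)])
    return computed, windows


computed, windows = run_rolling(4, 3, 1, lambda r: r * r)
print(computed)  # 6
print(windows)   # [[0, 1, 4], [1, 4, 9], [4, 9, 16], [9, 16, 25]]
```

Without the guard the same run would compute 4 * 3 = 12 rows; with it, only the 3 rows of the first window plus 1 new row per later iteration are computed, and every window still reads the correct values.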







[GitHub] [tvm] giuseros commented on a change in pull request #7925: Add a 'rolling_buffer' scheduling primitive

Posted by GitBox <gi...@apache.org>.
giuseros commented on a change in pull request #7925:
URL: https://github.com/apache/tvm/pull/7925#discussion_r631173425



##########
File path: python/tvm/tir/transform/inject_rolling_buffer.py
##########
@@ -0,0 +1,214 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Inject rolling buffers through a TIR transformation."""
+# pylint: disable=invalid-name,unused-argument,inconsistent-return-statements
+from collections import defaultdict
+
+import tvm
+from tvm import arith
+
+
+def InjectRollingBuffer():
+    """Inject rolling buffer statements.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The pass
+    """
+    buffer_to_attrs = defaultdict(list)
+    rolling_buffers = set()
+    rolling_buffer_to_info = dict()
+    iter_vars = list()
+    hoist_buffer_to_for = defaultdict(list)
+
+    class RollingBufferInfo:
+        def __init__(self, rolling_axis, rolling_extent, axis_overlaps, axis_iter_vars):
+            self.rolling_axis = rolling_axis
+            self.rolling_extent = rolling_extent
+            self.axis_overlaps = axis_overlaps
+            self.axis_iter_vars = axis_iter_vars
+
+    def _pre_visit(stmt):
+        if isinstance(stmt, tvm.tir.For):
+            # Manage the stack of iter_vars
+            iter_vars.append(stmt)
+
+        elif isinstance(stmt, tvm.tir.AttrStmt):
+            if isinstance(stmt.node, tvm.tir.Buffer):
+                if stmt.attr_key == "rolling_buffer" and stmt.value.value:
+                    # If the attribute is indicating that a buffer should be a rolling
+                    # buffer, then update the rolling_buffers set to include the buffer
+                    rolling_buffers.add(stmt.node)
+                # Keep a dictionary associating attribute statements with the buffers
+                # they reference. We'll need this if the buffer gets hoisted and we
+                # need to hoist all of its attributes at the same time.
+                buffer_to_attrs[stmt.node].append(stmt)
+
+        elif isinstance(stmt, tvm.tir.BufferRealize):
+            if stmt.buffer in rolling_buffers:
+                # If a BufferRealize has been identified as needing to be made into
+                # a rolling buffer, begin the analysis...
+                bound_iter_vars = []
+                bound_strides = []
+                bound_overlaps = []
+                # We use the bound information of the BufferRealize to calculate
+                # how we can legally roll
+                for bound in stmt.bounds:
+                    # If the bound is an int, we can't roll over it
+                    if isinstance(bound.min, tvm.tir.IntImm):
+                        iter_var = None
+                        stride = 0
+                    # If the bound is just a Var, that implies the stride is 1
+                    elif isinstance(bound.min, tvm.tir.Var):
+                        iter_var = bound.min
+                        stride = 1
+                    # Otherwise, it's the iter var multiplied by the stride
+                    # If not we're in unknown behaviour, so assert
+                    else:
+                        assert isinstance(bound.min, tvm.tir.Mul)

Review comment:
       Should we simply skip the transformation instead of asserting?

##########
File path: src/te/schedule/schedule_ops.cc
##########
@@ -54,6 +54,9 @@ Stmt MakePipeline(const Stage& s, const std::unordered_map<IterVar, Range>& dom_
   pipeline = s->op->BuildRealize(s, dom_map, pipeline);
   // use attribute to mark scope of the operation.
   pipeline = AttrStmt(s->op, tir::attr::realize_scope, StringImm(s->scope), pipeline);
+  if (s->rolling_buffer) {
+    pipeline = AttrStmt(s->op, tir::attr::rolling_buffer, Bool(true), pipeline);

Review comment:
       What about `rolling_buffer_scope`?







[GitHub] [tvm] mbaret commented on pull request #7925: Add a 'rolling_buffer' scheduling primitive

Posted by GitBox <gi...@apache.org>.
mbaret commented on pull request #7925:
URL: https://github.com/apache/tvm/pull/7925#issuecomment-861413701


   ping @junrushao1994, could you take a look?





[GitHub] [tvm] mbaret commented on a change in pull request #7925: Add a 'rolling_buffer' scheduling primitive

Posted by GitBox <gi...@apache.org>.
mbaret commented on a change in pull request #7925:
URL: https://github.com/apache/tvm/pull/7925#discussion_r648188950



##########
File path: src/te/operation/compute_op.cc
##########
@@ -484,7 +484,9 @@ ComputeLoopNest ComputeLoopNest::Create(const BaseComputeOpNode* self, const Sta
     }
     ret.init_nest = MakeLoopNest(stage, dom_map, begin_loop, true, skip_iter, &(ret.init_vmap),
                                  debug_keep_trivial_loop);
-    ret.init_predicates = MakeBoundCheck(stage, dom_map, ret.init_vmap, true, skip_iter);
+    bool skip_ivar_domain = !stage->rolling_buffer;
+    ret.init_predicates =
+        MakeBoundCheck(stage, dom_map, ret.init_vmap, skip_ivar_domain, skip_iter);

Review comment:
       I agree this is a slight hack. In a rolling buffer we need to expand the size and scope of the intermediate buffers, and if we drop the bound checks they end up getting corrupted. To do this more accurately, I think we'd need to replicate something closer to Halide's `store_at` to formally change the realization point of the intermediate tensor. For now, this is my workaround.







[GitHub] [tvm] jcf94 edited a comment on pull request #7925: Add a 'rolling_buffer' scheduling primitive

Posted by GitBox <gi...@apache.org>.
jcf94 edited a comment on pull request #7925:
URL: https://github.com/apache/tvm/pull/7925#issuecomment-864671064


   @mbaret It seems this PR is ready to merge once the conflict has been resolved?
   
   Do we have a GitHub issue to track the progress of your Cascade scheduling & RFC? If not, I suggest creating one, since this schedule primitive only works under special conditions. 😄 





[GitHub] [tvm] manupa-arm commented on a change in pull request #7925: Add a 'rolling_buffer' scheduling primitive

Posted by GitBox <gi...@apache.org>.
manupa-arm commented on a change in pull request #7925:
URL: https://github.com/apache/tvm/pull/7925#discussion_r631650646



##########
File path: python/tvm/tir/transform/inject_rolling_buffer.py
##########
@@ -0,0 +1,214 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Inject rolling buffers through a TIR transformation."""
+# pylint: disable=invalid-name,unused-argument,inconsistent-return-statements
+from collections import defaultdict
+
+import tvm
+from tvm import arith
+
+
+def InjectRollingBuffer():

Review comment:
       I think it would generally be better to give more extensive documentation, possibly with an example of two TIR snippets (before and after) to show the transformation.
   As a community, we should try to encourage this.
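
   The effect the pass aims for can be illustrated outside TIR. The following is a minimal plain-Python sketch, not the pass's actual output: the function names, the toy producer f(i) = i * i, and the window and stride values are assumptions made for illustration. It mirrors two ideas visible in the quoted diff: the rolling axis is indexed modulo its extent, and stores are guarded so that elements are only computed on the first iteration or when they lie beyond the overlap with the previous iteration.

```python
def rolling_window_sum(n_out, window=3, stride=1):
    """Sum a sliding window over a toy producer using a rolling buffer.

    Returns (outputs, producer_calls). All names here are hypothetical.
    """
    overlap = window - stride      # elements shared by consecutive iterations
    buf = [None] * window          # rolling (circular) buffer of `window` slots
    producer_calls = 0
    outputs = []
    for h in range(n_out):
        for k in range(window):
            # Guard: compute everything on the first iteration, afterwards
            # only the elements not already held from the previous iteration.
            if h < 1 or k >= overlap:
                idx = h * stride + k
                buf[idx % window] = idx * idx   # toy producer f(i) = i * i
                producer_calls += 1
        # Reads use the same modulo indexing as the stores.
        outputs.append(sum(buf[(h * stride + k) % window] for k in range(window)))
    return outputs, producer_calls


def naive_window_sum(n_out, window=3, stride=1):
    """Reference version that recomputes every element for every window."""
    return [
        sum((h * stride + k) ** 2 for k in range(window)) for h in range(n_out)
    ]
```

   With four outputs, a window of 3 and a stride of 1, the rolling version calls the producer 6 times, where recomputing every window would take 12 calls.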

##########
File path: python/tvm/tir/transform/inject_rolling_buffer.py
##########
@@ -0,0 +1,214 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Inject rolling buffers through a TIR transformation."""
+# pylint: disable=invalid-name,unused-argument,inconsistent-return-statements
+from collections import defaultdict
+
+import tvm
+from tvm import arith
+
+
+def InjectRollingBuffer():
+    """Inject rolling buffer statements.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The pass
+    """
+    buffer_to_attrs = defaultdict(list)
+    rolling_buffers = set()
+    rolling_buffer_to_info = dict()
+    iter_vars = list()
+    hoist_buffer_to_for = defaultdict(list)
+
+    class RollingBufferInfo:
+        def __init__(self, rolling_axis, rolling_extent, axis_overlaps, axis_iter_vars):
+            self.rolling_axis = rolling_axis
+            self.rolling_extent = rolling_extent
+            self.axis_overlaps = axis_overlaps
+            self.axis_iter_vars = axis_iter_vars
+
+    def _pre_visit(stmt):
+        if isinstance(stmt, tvm.tir.For):
+            # Manage the stack of iter_vars
+            iter_vars.append(stmt)
+
+        elif isinstance(stmt, tvm.tir.AttrStmt):
+            if isinstance(stmt.node, tvm.tir.Buffer):
+                if stmt.attr_key == "rolling_buffer" and stmt.value.value:
+                    # If the attribute is indicating that a buffer should be a rolling
+                    # buffer, then update the rolling_buffers set to include the buffer
+                    rolling_buffers.add(stmt.node)
+                # Keep a dictionary associating attribute statements with the buffers
+                # they reference. We'll need this if the buffer gets hoisted and we
+                # need to hoist all of its attributes at the same time.
+                buffer_to_attrs[stmt.node].append(stmt)
+
+        elif isinstance(stmt, tvm.tir.BufferRealize):
+            if stmt.buffer in rolling_buffers:
+                # If a BufferRealize has been identified as needing to be made into
+                # a rolling buffer, begin the analysis...
+                bound_iter_vars = []
+                bound_strides = []
+                bound_overlaps = []
+                # We use the bound information of the BufferRealize to calculate
+                # how we can legally roll
+                for bound in stmt.bounds:
+                    # If the bound is an int, we can't roll over it
+                    if isinstance(bound.min, tvm.tir.IntImm):
+                        iter_var = None
+                        stride = 0
+                    # If the bound is just a Var, that implies the stride is 1
+                    elif isinstance(bound.min, tvm.tir.Var):
+                        iter_var = bound.min
+                        stride = 1
+                    # Otherwise, it's the iter var multiplied by the stride
+                    # If not we're in unknown behaviour, so assert
+                    else:
+                        assert isinstance(bound.min, tvm.tir.Mul)
+                        assert isinstance(bound.min.a, tvm.tir.Var)
+                        assert isinstance(bound.min.b, tvm.tir.IntImm)
+                        iter_var = bound.min.a
+                        stride = bound.min.b.value
+                    bound_iter_vars.append(iter_var)
+                    bound_strides.append(stride)
+                    if iter_var is not None:
+                        bound_overlaps.append(bound.extent.value - stride)
+                    else:
+                        bound_overlaps.append(0)
+
+                # Pick the outermost iter_var that's mentioned in the bounds
+                # to be the rolling axis
+                roll_iter_var = None
+                roll_axis = -1
+                for loop in iter_vars:
+                    iter_var = loop.loop_var
+                    if iter_var in bound_iter_vars:
+                        roll_iter_var = iter_var
+                        roll_axis = bound_iter_vars.index(iter_var)
+                        break
+
+                # We must have found an axis to roll over
+                assert roll_iter_var is not None
+                assert roll_axis != -1
+                rolling_buffer_info = RollingBufferInfo(
+                    roll_axis, stmt.bounds[roll_axis].extent.value, bound_overlaps, bound_iter_vars
+                )
+                rolling_buffer_to_info[stmt.buffer] = rolling_buffer_info
+                new_bounds = []
+                for i, extent in enumerate(stmt.buffer.shape):
+                    if i == rolling_buffer_info.rolling_axis:
+                        new_bounds.append(tvm.ir.Range(rolling_buffer_info.rolling_extent))
+                    else:
+                        new_bounds.append(tvm.ir.Range(extent))
+                new_realize = tvm.tir.BufferRealize(
+                    stmt.buffer, new_bounds, stmt.condition, stmt.body, stmt.span
+                )
+                hoist_buffer_to_for[iter_var].append(new_realize)
+
+    def _post_visit(stmt):
+        if isinstance(stmt, tvm.tir.For):
+            # Manage the stack of iter_vars
+            iter_vars.pop()
+            # If the loop corresponds to an iter_var that needs a BufferRealize
+            # hoisting to its scope, perform the hoisting
+            if stmt.loop_var in hoist_buffer_to_for:
+                body = stmt
+                for realize in hoist_buffer_to_for[stmt.loop_var]:
+                    attrs = buffer_to_attrs[realize.buffer]
+                    new_realize = tvm.tir.BufferRealize(
+                        realize.buffer, realize.bounds, realize.condition, body, realize.span
+                    )
+                    # The attributes attached to the BufferRealize need hoisting too
+                    for attr in attrs:
+                        if attr.attr_key == "rolling_buffer":
+                            continue
+                        new_realize = tvm.tir.AttrStmt(
+                            attr.node, attr.attr_key, attr.value, new_realize, attr.span
+                        )
+                    body = new_realize
+                return body
+        elif isinstance(stmt, tvm.tir.AttrStmt):
+            if stmt.node in rolling_buffers:
+                # Remove the attribute statements attached to rolling buffers
+                # because they will have been hoisted to the relevant rolling
+                # scope
+                return stmt.body

Review comment:
       Is the behaviour undefined for the else case? (This applies to all the following branches as well.)

##########
File path: python/tvm/tir/transform/inject_rolling_buffer.py
##########
@@ -0,0 +1,214 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Inject rolling buffers through a TIR transformation."""
+# pylint: disable=invalid-name,unused-argument,inconsistent-return-statements
+from collections import defaultdict
+
+import tvm
+from tvm import arith
+
+
+def InjectRollingBuffer():
+    """Inject rolling buffer statements.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The pass
+    """
+    buffer_to_attrs = defaultdict(list)
+    rolling_buffers = set()
+    rolling_buffer_to_info = dict()
+    iter_vars = list()
+    hoist_buffer_to_for = defaultdict(list)
+
+    class RollingBufferInfo:

Review comment:
       nit: consider using a NamedTuple if this is not going to be changed after init
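
   As a minimal sketch of this suggestion: the field names below are taken from the `RollingBufferInfo` class in the quoted diff, while the values are hypothetical and chosen only for illustration.

```python
from collections import namedtuple

# Immutable record with the same four fields as the hand-written class.
RollingBufferInfo = namedtuple(
    "RollingBufferInfo",
    ["rolling_axis", "rolling_extent", "axis_overlaps", "axis_iter_vars"],
)

# Hypothetical values; in the pass these come from the BufferRealize bounds.
info = RollingBufferInfo(
    rolling_axis=1,
    rolling_extent=4,
    axis_overlaps=[0, 2],
    axis_iter_vars=[None, "hh"],
)


def try_mutate(record):
    """Return True if assigning to a field is rejected."""
    try:
        record.rolling_extent = 8
    except AttributeError:
        return True
    return False
```

   Fields are still read by name (`info.rolling_axis`), but assigning to them after construction raises `AttributeError`, which documents that the record is fixed once built.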

##########
File path: python/tvm/tir/transform/inject_rolling_buffer.py
##########
@@ -0,0 +1,214 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Inject rolling buffers through a TIR transformation."""
+# pylint: disable=invalid-name,unused-argument,inconsistent-return-statements
+from collections import defaultdict
+
+import tvm
+from tvm import arith
+
+
+def InjectRollingBuffer():
+    """Inject rolling buffer statements.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The pass
+    """
+    buffer_to_attrs = defaultdict(list)
+    rolling_buffers = set()
+    rolling_buffer_to_info = dict()
+    iter_vars = list()
+    hoist_buffer_to_for = defaultdict(list)
+
+    class RollingBufferInfo:
+        def __init__(self, rolling_axis, rolling_extent, axis_overlaps, axis_iter_vars):
+            self.rolling_axis = rolling_axis
+            self.rolling_extent = rolling_extent
+            self.axis_overlaps = axis_overlaps
+            self.axis_iter_vars = axis_iter_vars
+
+    def _pre_visit(stmt):
+        if isinstance(stmt, tvm.tir.For):
+            # Manage the stack of iter_vars
+            iter_vars.append(stmt)
+
+        elif isinstance(stmt, tvm.tir.AttrStmt):
+            if isinstance(stmt.node, tvm.tir.Buffer):
+                if stmt.attr_key == "rolling_buffer" and stmt.value.value:
+                    # If the attribute is indicating that a buffer should be a rolling
+                    # buffer, then update the rolling_buffers set to include the buffer
+                    rolling_buffers.add(stmt.node)
+                # Keep a dictionary associating attribute statements with the buffers
+                # they reference. We'll need this if the buffer gets hoisted and we
+                # need to hoist all of its attributes at the same time.
+                buffer_to_attrs[stmt.node].append(stmt)
+
+        elif isinstance(stmt, tvm.tir.BufferRealize):
+            if stmt.buffer in rolling_buffers:
+                # If a BufferRealize has been identified as needing to be made into
+                # a rolling buffer, begin the analysis...
+                bound_iter_vars = []
+                bound_strides = []
+                bound_overlaps = []
+                # We use the bound information of the BufferRealize to calculate
+                # how we can legally roll
+                for bound in stmt.bounds:
+                    # If the bound is an int, we can't roll over it
+                    if isinstance(bound.min, tvm.tir.IntImm):
+                        iter_var = None
+                        stride = 0
+                    # If the bound is just a Var, that implies the stride is 1
+                    elif isinstance(bound.min, tvm.tir.Var):
+                        iter_var = bound.min
+                        stride = 1
+                    # Otherwise, it's the iter var multiplied by the stride
+                    # If not we're in unknown behaviour, so assert
+                    else:
+                        assert isinstance(bound.min, tvm.tir.Mul)

Review comment:
       I see how it would be hard to stop the pass in the middle.
   
   Depending on how confident we are that this case cannot occur, it should arguably be an exception rather than an assert.
   To be safe, could we raise an exception here?
   
   Given the immutability of TVM objects, I think we might need to "try" running these sorts of passes, but that part might be out of scope for this PR.
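
   A minimal sketch of raising an exception instead of asserting, under stated assumptions: `RollingBufferError` and `parse_bound_min` are hypothetical names, and the bound minimum is modelled with plain Python values (an int for a constant, a str for a loop variable, a `("mul", var, stride)` tuple for a strided variable) rather than real TIR nodes.

```python
class RollingBufferError(ValueError):
    """Hypothetical error type for this sketch; not part of TVM."""


def parse_bound_min(bound_min):
    """Return (iter_var, stride) for a toy stand-in of a bound minimum."""
    if isinstance(bound_min, int):
        # A constant minimum: this axis cannot be rolled over.
        return None, 0
    if isinstance(bound_min, str):
        # A bare loop variable implies a stride of 1.
        return bound_min, 1
    if (
        isinstance(bound_min, tuple)
        and len(bound_min) == 3
        and bound_min[0] == "mul"
        and isinstance(bound_min[1], str)
        and isinstance(bound_min[2], int)
    ):
        # Loop variable multiplied by a constant stride.
        return bound_min[1], bound_min[2]
    # Anything else is unsupported: fail loudly with a descriptive message.
    raise RollingBufferError(
        f"Rolling buffer injection failed: unsupported buffer striding {bound_min!r}"
    )
```

   One practical difference from `assert` is that assert statements are skipped when Python runs with `-O`, whereas an explicit `raise` always fires and lets a caller catch the specific error type.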

##########
File path: python/tvm/tir/transform/inject_rolling_buffer.py
##########
@@ -0,0 +1,214 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Inject rolling buffers through a TIR transformation."""
+# pylint: disable=invalid-name,unused-argument,inconsistent-return-statements
+from collections import defaultdict
+
+import tvm
+from tvm import arith
+
+
+def InjectRollingBuffer():
+    """Inject rolling buffer statements.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The pass
+    """
+    buffer_to_attrs = defaultdict(list)
+    rolling_buffers = set()
+    rolling_buffer_to_info = dict()
+    iter_vars = list()
+    hoist_buffer_to_for = defaultdict(list)
+
+    class RollingBufferInfo:
+        def __init__(self, rolling_axis, rolling_extent, axis_overlaps, axis_iter_vars):
+            self.rolling_axis = rolling_axis
+            self.rolling_extent = rolling_extent
+            self.axis_overlaps = axis_overlaps
+            self.axis_iter_vars = axis_iter_vars
+
+    def _pre_visit(stmt):
+        if isinstance(stmt, tvm.tir.For):
+            # Manage the stack of iter_vars
+            iter_vars.append(stmt)
+
+        elif isinstance(stmt, tvm.tir.AttrStmt):
+            if isinstance(stmt.node, tvm.tir.Buffer):
+                if stmt.attr_key == "rolling_buffer" and stmt.value.value:
+                    # If the attribute is indicating that a buffer should be a rolling
+                    # buffer, then update the rolling_buffers set to include the buffer
+                    rolling_buffers.add(stmt.node)
+                # Keep a dictionary associating attribute statements with the buffers
+                # they reference. We'll need this if the buffer gets hoisted and we
+                # need to hoist all of its attributes at the same time.
+                buffer_to_attrs[stmt.node].append(stmt)
+
+        elif isinstance(stmt, tvm.tir.BufferRealize):
+            if stmt.buffer in rolling_buffers:
+                # If a BufferRealize has been identified as needing to be made into
+                # a rolling buffer, begin the analysis...
+                bound_iter_vars = []
+                bound_strides = []
+                bound_overlaps = []
+                # We use the bound information of the BufferRealize to calculate
+                # how we can legally roll
+                for bound in stmt.bounds:
+                    # If the bound is an int, we can't roll over it
+                    if isinstance(bound.min, tvm.tir.IntImm):
+                        iter_var = None
+                        stride = 0
+                    # If the bound is just a Var, that implies the stride is 1
+                    elif isinstance(bound.min, tvm.tir.Var):
+                        iter_var = bound.min
+                        stride = 1
+                    # Otherwise, it's the iter var multiplied by the stride
+                    # If not we're in unknown behaviour, so assert
+                    else:
+                        assert isinstance(bound.min, tvm.tir.Mul)
+                        assert isinstance(bound.min.a, tvm.tir.Var)
+                        assert isinstance(bound.min.b, tvm.tir.IntImm)
+                        iter_var = bound.min.a
+                        stride = bound.min.b.value
+                    bound_iter_vars.append(iter_var)
+                    bound_strides.append(stride)
+                    if iter_var is not None:
+                        bound_overlaps.append(bound.extent.value - stride)
+                    else:
+                        bound_overlaps.append(0)
+
+                # Pick the outermost iter_var that's mentioned in the bounds
+                # to be the rolling axis
+                roll_iter_var = None
+                roll_axis = -1
+                for loop in iter_vars:
+                    iter_var = loop.loop_var
+                    if iter_var in bound_iter_vars:
+                        roll_iter_var = iter_var
+                        roll_axis = bound_iter_vars.index(iter_var)
+                        break
+
+                # We must have found an axis to roll over
+                assert roll_iter_var is not None
+                assert roll_axis != -1
+                rolling_buffer_info = RollingBufferInfo(
+                    roll_axis, stmt.bounds[roll_axis].extent.value, bound_overlaps, bound_iter_vars
+                )
+                rolling_buffer_to_info[stmt.buffer] = rolling_buffer_info
+                new_bounds = []
+                for i, extent in enumerate(stmt.buffer.shape):
+                    if i == rolling_buffer_info.rolling_axis:
+                        new_bounds.append(tvm.ir.Range(rolling_buffer_info.rolling_extent))
+                    else:
+                        new_bounds.append(tvm.ir.Range(extent))
+                new_realize = tvm.tir.BufferRealize(
+                    stmt.buffer, new_bounds, stmt.condition, stmt.body, stmt.span
+                )
+                hoist_buffer_to_for[iter_var].append(new_realize)
+
+    def _post_visit(stmt):
+        if isinstance(stmt, tvm.tir.For):
+            # Manage the stack of iter_vars
+            iter_vars.pop()
+            # If the loop corresponds to an iter_var that needs a BufferRealize
+            # hoisting to its scope, perform the hoisting
+            if stmt.loop_var in hoist_buffer_to_for:
+                body = stmt
+                for realize in hoist_buffer_to_for[stmt.loop_var]:
+                    attrs = buffer_to_attrs[realize.buffer]
+                    new_realize = tvm.tir.BufferRealize(
+                        realize.buffer, realize.bounds, realize.condition, body, realize.span
+                    )
+                    # The attributes attached to the BufferRealize need hoisting too
+                    for attr in attrs:
+                        if attr.attr_key == "rolling_buffer":
+                            continue
+                        new_realize = tvm.tir.AttrStmt(
+                            attr.node, attr.attr_key, attr.value, new_realize, attr.span
+                        )
+                    body = new_realize
+                return body
+        elif isinstance(stmt, tvm.tir.AttrStmt):
+            if stmt.node in rolling_buffers:
+                # Remove the attribute statements attached to rolling buffers
+                # because they will have been hoisted to the relevant rolling
+                # scope
+                return stmt.body
+        elif isinstance(stmt, tvm.tir.BufferRealize):
+            if stmt.buffer in rolling_buffers:
+                # Remove the original BufferRealize for rolling buffers
+                # because they will have been hoisted to the relevant rolling
+                # scope
+                return stmt.body
+        elif isinstance(stmt, tvm.tir.BufferStore):
+            if stmt.buffer in rolling_buffer_to_info:
+                rolling_buffer_info = rolling_buffer_to_info[stmt.buffer]
+                indices = []
+                # First modify the access indices to use modulo arithmetic
+                # for the rolling axis
+                for i, index in enumerate(stmt.indices):
+                    if i == rolling_buffer_info.rolling_axis:
+                        indices.append(tvm.tir.FloorMod(index, rolling_buffer_info.rolling_extent))
+                    else:
+                        indices.append(index)
+                buffer_store = tvm.tir.BufferStore(stmt.buffer, stmt.value, indices, stmt.span)
+                # Then wrap the BufferStores in some Ifs to avoid recomputing elements
+                for i, iter_var in enumerate(rolling_buffer_info.axis_iter_vars):
+                    if iter_var is not None and rolling_buffer_info.axis_overlaps[i] > 0:
+                        dmap = {iter_var: arith.IntervalSet(0, 0)}
+                        term_2 = arith.Analyzer().int_set(stmt.indices[i], dmap).min_value
+                        buffer_store = tvm.tir.IfThenElse(
+                            tvm.tir.Or(
+                                iter_var < 1, term_2 >= rolling_buffer_info.axis_overlaps[i]
+                            ),
+                            buffer_store,
+                            None,
+                        )
+                return buffer_store
+        elif isinstance(stmt, tvm.tir.BufferLoad):
+            if stmt.buffer in rolling_buffer_to_info:
+                rolling_buffer_info = rolling_buffer_to_info[stmt.buffer]
+                indices = []
+                # Modify the access indices to use modulo arithmetic
+                # for the rolling axis
+                for i, index in enumerate(stmt.indices):
+                    if i == rolling_buffer_info.rolling_axis:
+                        indices.append(tvm.tir.FloorMod(index, rolling_buffer_info.rolling_extent))
+                    else:
+                        indices.append(index)
+                return tvm.tir.BufferLoad(stmt.buffer, indices, stmt.span)

Review comment:
       Is the behaviour undefined for the else case?
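
   Whether the missing else branch matters depends on how the callbacks are driven, which this thread does not show. The toy sketch below assumes a transformer in which a callback returning None means "leave the node unchanged", which is how `tvm.tir.stmt_functor.ir_transform` is understood to treat its postorder callback. The `transform` and `double_ints` functions and the nested-list "tree" are hypothetical stand-ins, not TVM APIs.

```python
def transform(node, postorder):
    """Post-order rewrite of a nested-list tree.

    If the callback returns None, the (already rewritten) node is kept.
    """
    if isinstance(node, list):
        node = [transform(child, postorder) for child in node]
    result = postorder(node)
    return node if result is None else result


def double_ints(node):
    """Rewrite ints only; fall through (return None) for everything else."""
    if isinstance(node, int):
        return node * 2
    # No else branch: the implicit None tells `transform` to keep the node.


rewritten = transform([1, "x", [2, "y"]], double_ints)
```

   Under that convention the fall-through is well defined: strings and lists pass through untouched and only the ints are rewritten.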







[GitHub] [tvm] mbaret commented on pull request #7925: Add a 'rolling_buffer' scheduling primitive

Posted by GitBox <gi...@apache.org>.
mbaret commented on pull request #7925:
URL: https://github.com/apache/tvm/pull/7925#issuecomment-836810118


   @manupa-arm @giuseros 





[GitHub] [tvm] manupa-arm commented on a change in pull request #7925: Add a 'rolling_buffer' scheduling primitive

Posted by GitBox <gi...@apache.org>.
manupa-arm commented on a change in pull request #7925:
URL: https://github.com/apache/tvm/pull/7925#discussion_r642901265



##########
File path: python/tvm/tir/transform/inject_rolling_buffer.py
##########
@@ -0,0 +1,214 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Inject rolling buffers through a TIR transformation."""
+# pylint: disable=invalid-name,unused-argument,inconsistent-return-statements
+from collections import defaultdict
+
+import tvm
+from tvm import arith
+
+
+def InjectRollingBuffer():

Review comment:
       Fair enough :)







[GitHub] [tvm] jcf94 commented on a change in pull request #7925: Add a 'rolling_buffer' scheduling primitive

Posted by GitBox <gi...@apache.org>.
jcf94 commented on a change in pull request #7925:
URL: https://github.com/apache/tvm/pull/7925#discussion_r646312810



##########
File path: python/tvm/tir/transform/inject_rolling_buffer.py
##########
@@ -0,0 +1,238 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Inject rolling buffers through a TIR transformation."""
+# pylint: disable=invalid-name,unused-argument,inconsistent-return-statements
+from collections import defaultdict, namedtuple
+import math
+
+import tvm
+from tvm import arith
+
+
+def InjectRollingBuffer():
+    """Inject rolling buffer statements.
+
+    Rolling buffers are buffers where one of the dimensions has been made into
+    a circular buffer. Two optimizations are implemented in order to accomplish
+    this: sliding window and storage folding. In particular, the sliding window
+    optimization is applied to the entire buffer (to avoid recomputing elements)
+    and storage folding is then applied to just the rolling dimension.
+
+    Rolling buffers must be inside a loop with only part of the buffer used per
+    iteration. The outermost axis will be rolled over.
+
+    For more information, see the RFC:
+    https://discuss.tvm.apache.org/t/rfc-introducing-a-rolling-buffer-scheduling-primitive/9836
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The pass
+    """
+    buffer_to_attrs = defaultdict(list)
+    rolling_buffers = set()
+    rolling_buffer_to_info = dict()
+    iter_vars = list()
+    hoist_buffer_to_for = defaultdict(list)
+
+    RollingBufferInfo = namedtuple(
+        "RollingBufferInfo", ["rolling_axis", "rolling_extent", "axis_overlaps", "axis_iter_vars"]
+    )
+
+    def _pre_visit(stmt):
+        if isinstance(stmt, tvm.tir.For):
+            # Manage the stack of iter_vars
+            iter_vars.append(stmt)
+
+        elif isinstance(stmt, tvm.tir.AttrStmt):
+            if isinstance(stmt.node, tvm.tir.Buffer):
+                if stmt.attr_key == "rolling_buffer_scope" and stmt.value.value:
+                    # If the attribute is indicating that a buffer should be a rolling
+                    # buffer, then update the rolling_buffers set to include the buffer
+                    rolling_buffers.add(stmt.node)
+                # Keep a dictionary associating attribute statements with the buffers
+                # they reference. We'll need this if the buffer gets hoisted and we
+                # need to hoist all of its attributes at the same time.
+                buffer_to_attrs[stmt.node].append(stmt)
+
+        elif isinstance(stmt, tvm.tir.BufferRealize):
+            if stmt.buffer in rolling_buffers:
+                # If a BufferRealize has been identified as needing to be made into
+                # a rolling buffer, begin the analysis...
+                bound_iter_vars = []
+                bound_overlaps = []
+                # We use the bound information of the BufferRealize to calculate
+                # how we can legally roll
+                for bound in stmt.bounds:
+                    divisor = 1
+                    # Handle the case of fractional strides
+                    # They take this form: floordiv(hh.outer, 2)
+                    # Strip the floordiv and keep track of the divisor
+                    if isinstance(bound.min, tvm.tir.FloorDiv):
+                        divisor = bound.min.b.value
+                        bound.min = bound.min.a
+                    # If the bound is an int, we can't roll over it
+                    if isinstance(bound.min, tvm.tir.IntImm):
+                        iter_var = None
+                        stride = 0
+                    # If the bound is just a Var, that implies the stride is 1
+                    elif isinstance(bound.min, tvm.tir.Var):
+                        iter_var = bound.min
+                        stride = 1
+                    # Otherwise, it's the iter var multiplied by the stride
+                    # If not we're in unknown behaviour, so assert
+                    else:
+                        assert isinstance(
+                            bound.min, tvm.tir.Mul
+                        ), "Rolling buffer injection failed: the buffer striding is unsupported"
+                        assert isinstance(
+                            bound.min.a, tvm.tir.Var
+                        ), "Rolling buffer injection failed: the buffer striding is unsupported"
+                        assert isinstance(
+                            bound.min.b, tvm.tir.IntImm
+                        ), "Rolling buffer injection failed: the buffer striding is unsupported"
+                        iter_var = bound.min.a
+                        stride = bound.min.b.value
+                    stride = math.ceil(stride / divisor)
+                    bound_iter_vars.append(iter_var)
+                    if iter_var is not None:
+                        bound_overlaps.append(bound.extent.value - stride)
+                    else:
+                        bound_overlaps.append(0)
+
+                # Pick the outermost iter_var that's mentioned in the bounds
+                # to be the rolling axis
+                roll_iter_var = None
+                roll_axis = -1
+                for loop in iter_vars:
+                    iter_var = loop.loop_var
+                    if iter_var in bound_iter_vars:
+                        roll_iter_var = iter_var
+                        roll_axis = bound_iter_vars.index(iter_var)
+                        break
+
+                # We must have found an axis to roll over
+                assert (
+                    roll_iter_var is not None
+                ), "Rolling buffer injection failed: no rolling axis found"
+                assert roll_axis != -1, "Rolling buffer injection failed: no rolling axis found"
+                rolling_buffer_info = RollingBufferInfo(
+                    roll_axis, stmt.bounds[roll_axis].extent.value, bound_overlaps, bound_iter_vars
+                )
+                rolling_buffer_to_info[stmt.buffer] = rolling_buffer_info
+                new_bounds = []
+                for i, extent in enumerate(stmt.buffer.shape):
+                    if i == rolling_buffer_info.rolling_axis:
+                        new_bounds.append(tvm.ir.Range(rolling_buffer_info.rolling_extent))
+                    else:
+                        new_bounds.append(tvm.ir.Range(extent))
+                new_realize = tvm.tir.BufferRealize(
+                    stmt.buffer, new_bounds, stmt.condition, stmt.body, stmt.span
+                )
+                hoist_buffer_to_for[iter_var].append(new_realize)
+
+    def _post_visit(stmt):
+        if isinstance(stmt, tvm.tir.For):
+            # Manage the stack of iter_vars
+            iter_vars.pop()
+            # If the loop corresponds to an iter_var that needs a BufferRealize
+            # hoisting to its scope, perform the hoisting
+            if stmt.loop_var in hoist_buffer_to_for:
+                body = stmt
+                for realize in hoist_buffer_to_for[stmt.loop_var]:
+                    attrs = buffer_to_attrs[realize.buffer]
+                    new_realize = tvm.tir.BufferRealize(
+                        realize.buffer, realize.bounds, realize.condition, body, realize.span
+                    )
+                    # The attributes attached to the BufferRealize need hoisting too
+                    for attr in attrs:
+                        if attr.attr_key == "rolling_buffer_scope":
+                            continue
+                        new_realize = tvm.tir.AttrStmt(
+                            attr.node, attr.attr_key, attr.value, new_realize, attr.span
+                        )
+                    body = new_realize
+                return body
+        elif isinstance(stmt, tvm.tir.AttrStmt):
+            if stmt.node in rolling_buffers:
+                # Remove the attribute statements attached to rolling buffers
+                # because they will have been hoisted to the relevant rolling
+                # scope
+                return stmt.body
+        elif isinstance(stmt, tvm.tir.BufferRealize):
+            if stmt.buffer in rolling_buffers:
+                # Remove the original BufferRealize for rolling buffers
+                # because they will have been hoisted to the relevant rolling
+                # scope
+                return stmt.body
+        elif isinstance(stmt, tvm.tir.BufferStore):
+            if stmt.buffer in rolling_buffer_to_info:
+                rolling_buffer_info = rolling_buffer_to_info[stmt.buffer]
+                indices = []
+                # First modify the access indices to use modulo arithmetic
+                # for the rolling axis
+                for i, index in enumerate(stmt.indices):
+                    if i == rolling_buffer_info.rolling_axis:
+                        indices.append(tvm.tir.FloorMod(index, rolling_buffer_info.rolling_extent))
+                    else:
+                        indices.append(index)
+                buffer_store = tvm.tir.BufferStore(stmt.buffer, stmt.value, indices, stmt.span)
+                # Then wrap the BufferStores in some Ifs to avoid recomputing elements
+                for i, iter_var in enumerate(rolling_buffer_info.axis_iter_vars):
+                    if iter_var is not None and rolling_buffer_info.axis_overlaps[i] > 0:
+                        dmap = {iter_var: arith.IntervalSet(0, 0)}
+                        term_2 = arith.Analyzer().int_set(stmt.indices[i], dmap).min_value
+                        buffer_store = tvm.tir.IfThenElse(
+                            tvm.tir.Or(
+                                iter_var < 1, term_2 >= rolling_buffer_info.axis_overlaps[i]
+                            ),
+                            buffer_store,
+                            None,
+                        )
+                return buffer_store
+        elif isinstance(stmt, tvm.tir.BufferLoad):
+            if stmt.buffer in rolling_buffer_to_info:
+                rolling_buffer_info = rolling_buffer_to_info[stmt.buffer]
+                indices = []
+                # Modify the access indices to use modulo arithmetic
+                # for the rolling axis
+                for i, index in enumerate(stmt.indices):
+                    if i == rolling_buffer_info.rolling_axis:
+                        indices.append(tvm.tir.FloorMod(index, rolling_buffer_info.rolling_extent))
+                    else:
+                        indices.append(index)
+                return tvm.tir.BufferLoad(stmt.buffer, indices, stmt.span)
+
+    def _ftransform(f, mod, ctx):
+        return f.with_body(
+            tvm.tir.stmt_functor.ir_transform(
+                f.body,
+                _pre_visit,
+                _post_visit,
+                [
+                    "tir.AttrStmt",
+                    "tir.BufferRealize",
+                    "tir.For",
+                    "tir.BufferStore",
+                    "tir.BufferLoad",
+                ],
+            )
+        )
+
+    return tvm.tir.transform.prim_func_pass(
+        _ftransform, opt_level=0, name="tir.InjectRollingBuffer"
+    )

Review comment:
       This looks like a fairly complex pass. Would you consider rewriting it in C++? (Not necessary in this PR.)

##########
File path: src/te/operation/compute_op.cc
##########
@@ -484,7 +484,9 @@ ComputeLoopNest ComputeLoopNest::Create(const BaseComputeOpNode* self, const Sta
     }
     ret.init_nest = MakeLoopNest(stage, dom_map, begin_loop, true, skip_iter, &(ret.init_vmap),
                                  debug_keep_trivial_loop);
-    ret.init_predicates = MakeBoundCheck(stage, dom_map, ret.init_vmap, true, skip_iter);
+    bool skip_ivar_domain = !stage->rolling_buffer;
+    ret.init_predicates =
+        MakeBoundCheck(stage, dom_map, ret.init_vmap, skip_ivar_domain, skip_iter);

Review comment:
       ```suggestion
       ret.init_predicates =
           MakeBoundCheck(stage, dom_map, ret.init_vmap, !stage->rolling_buffer, skip_iter);
   ```
   
   Nit: Skipping the IterVar domain check here looks like a hack to me, though I don't have a better suggestion. It seems hard to add further checks to guard this operation.

##########
File path: python/tvm/tir/transform/inject_rolling_buffer.py
##########
@@ -0,0 +1,238 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Inject rolling buffers through a TIR transformation."""
+# pylint: disable=invalid-name,unused-argument,inconsistent-return-statements
+from collections import defaultdict, namedtuple
+import math
+
+import tvm
+from tvm import arith
+
+
+def InjectRollingBuffer():
+    """Inject rolling buffer statements.
+
+    Rolling buffers are buffers where one of the dimensions has been made into
+    a circular buffer. Two optimizations are implemented in order to accomplish
+    this: sliding window and storage folding. In particular, the sliding window
+    optimization is applied to the entire buffer (to avoid recomputing elements)
+    and storage folding is then applied to just the rolling dimension.
+
+    Rolling buffers must be inside a loop with only part of the buffer used per
+    iteration. The outermost axis will be rolled over.
+
+    For more information, see the RFC:
+    https://discuss.tvm.apache.org/t/rfc-introducing-a-rolling-buffer-scheduling-primitive/9836
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The pass
+    """
+    buffer_to_attrs = defaultdict(list)
+    rolling_buffers = set()
+    rolling_buffer_to_info = dict()
+    iter_vars = list()
+    hoist_buffer_to_for = defaultdict(list)
+
+    RollingBufferInfo = namedtuple(
+        "RollingBufferInfo", ["rolling_axis", "rolling_extent", "axis_overlaps", "axis_iter_vars"]
+    )
+
+    def _pre_visit(stmt):
+        if isinstance(stmt, tvm.tir.For):
+            # Manage the stack of iter_vars
+            iter_vars.append(stmt)
+
+        elif isinstance(stmt, tvm.tir.AttrStmt):
+            if isinstance(stmt.node, tvm.tir.Buffer):
+                if stmt.attr_key == "rolling_buffer_scope" and stmt.value.value:
+                    # If the attribute is indicating that a buffer should be a rolling
+                    # buffer, then update the rolling_buffers set to include the buffer
+                    rolling_buffers.add(stmt.node)
+                # Keep a dictionary associating attribute statements with the buffers
+                # they reference. We'll need this if the buffer gets hoisted and we
+                # need to hoist all of its attributes at the same time.
+                buffer_to_attrs[stmt.node].append(stmt)
+
+        elif isinstance(stmt, tvm.tir.BufferRealize):
+            if stmt.buffer in rolling_buffers:
+                # If a BufferRealize has been identified as needing to be made into
+                # a rolling buffer, begin the analysis...
+                bound_iter_vars = []
+                bound_overlaps = []
+                # We use the bound information of the BufferRealize to calculate
+                # how we can legally roll
+                for bound in stmt.bounds:
+                    divisor = 1
+                    # Handle the case of fractional strides
+                    # They take this form: floordiv(hh.outer, 2)
+                    # Strip the floordiv and keep track of the divisor
+                    if isinstance(bound.min, tvm.tir.FloorDiv):
+                        divisor = bound.min.b.value
+                        bound.min = bound.min.a
+                    # If the bound is an int, we can't roll over it
+                    if isinstance(bound.min, tvm.tir.IntImm):
+                        iter_var = None
+                        stride = 0
+                    # If the bound is just a Var, that implies the stride is 1
+                    elif isinstance(bound.min, tvm.tir.Var):
+                        iter_var = bound.min
+                        stride = 1
+                    # Otherwise, it's the iter var multiplied by the stride
+                    # If not we're in unknown behaviour, so assert
+                    else:
+                        assert isinstance(
+                            bound.min, tvm.tir.Mul
+                        ), "Rolling buffer injection failed: the buffer striding is unsupported"
+                        assert isinstance(
+                            bound.min.a, tvm.tir.Var
+                        ), "Rolling buffer injection failed: the buffer striding is unsupported"
+                        assert isinstance(
+                            bound.min.b, tvm.tir.IntImm
+                        ), "Rolling buffer injection failed: the buffer striding is unsupported"
+                        iter_var = bound.min.a
+                        stride = bound.min.b.value
+                    stride = math.ceil(stride / divisor)
+                    bound_iter_vars.append(iter_var)
+                    if iter_var is not None:
+                        bound_overlaps.append(bound.extent.value - stride)
+                    else:
+                        bound_overlaps.append(0)
+
+                # Pick the outermost iter_var that's mentioned in the bounds
+                # to be the rolling axis

Review comment:
       Would it be possible to support rolling over multiple dimensions?
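
For reference, the per-axis analysis in the hunk above reduces to one stride and one overlap per buffer dimension, even though only a single axis is later folded. A minimal plain-Python sketch of that arithmetic follows. `axis_overlap` is a hypothetical helper name, not part of TVM, and it takes already-extracted integers rather than TIR expressions, so treat it as a reading aid and not as the pass itself.

```python
import math


def axis_overlap(extent, stride, divisor=1):
    """Overlap between consecutive windows along one buffer axis.

    extent  -- window size along the axis (bound.extent in the pass)
    stride  -- multiplier on the loop variable; 0 means a constant bound
    divisor -- floordiv divisor for fractional strides, 1 if absent
    """
    if stride == 0:
        # A constant lower bound cannot be rolled over, so nothing overlaps.
        return 0
    # Fractional strides are rounded up, mirroring math.ceil in the diff.
    effective_stride = math.ceil(stride / divisor)
    return extent - effective_stride
```

With an extent of 4 and a stride of 2, two rows carry over from one window to the next; those are the rows the store guard later skips.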







[GitHub] [tvm] jcf94 commented on a change in pull request #7925: Add a 'rolling_buffer' scheduling primitive

Posted by GitBox <gi...@apache.org>.
jcf94 commented on a change in pull request #7925:
URL: https://github.com/apache/tvm/pull/7925#discussion_r646312810



##########
File path: python/tvm/tir/transform/inject_rolling_buffer.py
##########
@@ -0,0 +1,238 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Inject rolling buffers through a TIR transformation."""
+# pylint: disable=invalid-name,unused-argument,inconsistent-return-statements
+from collections import defaultdict, namedtuple
+import math
+
+import tvm
+from tvm import arith
+
+
+def InjectRollingBuffer():
+    """Inject rolling buffer statements.
+
+    Rolling buffers are buffers where one of the dimensions has been made into
+    a circular buffer. Two optimizations are implemented in order to accomplish
+    this: sliding window and storage folding. In particular, the sliding window
+    optimization is applied to the entire buffer (to avoid recomputing elements)
+    and storage folding is then applied to just the rolling dimension.
+
+    Rolling buffers must be inside a loop with only part of the buffer used per
+    iteration. The outermost axis will be rolled over.
+
+    For more information, see the RFC:
+    https://discuss.tvm.apache.org/t/rfc-introducing-a-rolling-buffer-scheduling-primitive/9836
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The pass
+    """
+    buffer_to_attrs = defaultdict(list)
+    rolling_buffers = set()
+    rolling_buffer_to_info = dict()
+    iter_vars = list()
+    hoist_buffer_to_for = defaultdict(list)
+
+    RollingBufferInfo = namedtuple(
+        "RollingBufferInfo", ["rolling_axis", "rolling_extent", "axis_overlaps", "axis_iter_vars"]
+    )
+
+    def _pre_visit(stmt):
+        if isinstance(stmt, tvm.tir.For):
+            # Manage the stack of iter_vars
+            iter_vars.append(stmt)
+
+        elif isinstance(stmt, tvm.tir.AttrStmt):
+            if isinstance(stmt.node, tvm.tir.Buffer):
+                if stmt.attr_key == "rolling_buffer_scope" and stmt.value.value:
+                    # If the attribute is indicating that a buffer should be a rolling
+                    # buffer, then update the rolling_buffers set to include the buffer
+                    rolling_buffers.add(stmt.node)
+                # Keep a dictionary associating attribute statements with the buffers
+                # they reference. We'll need this if the buffer gets hoisted and we
+                # need to hoist all of its attributes at the same time.
+                buffer_to_attrs[stmt.node].append(stmt)
+
+        elif isinstance(stmt, tvm.tir.BufferRealize):
+            if stmt.buffer in rolling_buffers:
+                # If a BufferRealize has been identified as needing to be made into
+                # a rolling buffer, begin the analysis...
+                bound_iter_vars = []
+                bound_overlaps = []
+                # We use the bound information of the BufferRealize to calculate
+                # how we can legally roll
+                for bound in stmt.bounds:
+                    divisor = 1
+                    # Handle the case of fractional strides
+                    # They take this form: floordiv(hh.outer, 2)
+                    # Strip the floordiv and keep track of the divisor
+                    if isinstance(bound.min, tvm.tir.FloorDiv):
+                        divisor = bound.min.b.value
+                        bound.min = bound.min.a
+                    # If the bound is an int, we can't roll over it
+                    if isinstance(bound.min, tvm.tir.IntImm):
+                        iter_var = None
+                        stride = 0
+                    # If the bound is just a Var, that implies the stride is 1
+                    elif isinstance(bound.min, tvm.tir.Var):
+                        iter_var = bound.min
+                        stride = 1
+                    # Otherwise, it's the iter var multiplied by the stride
+                    # If not we're in unknown behaviour, so assert
+                    else:
+                        assert isinstance(
+                            bound.min, tvm.tir.Mul
+                        ), "Rolling buffer injection failed: the buffer striding is unsupported"
+                        assert isinstance(
+                            bound.min.a, tvm.tir.Var
+                        ), "Rolling buffer injection failed: the buffer striding is unsupported"
+                        assert isinstance(
+                            bound.min.b, tvm.tir.IntImm
+                        ), "Rolling buffer injection failed: the buffer striding is unsupported"
+                        iter_var = bound.min.a
+                        stride = bound.min.b.value
+                    stride = math.ceil(stride / divisor)
+                    bound_iter_vars.append(iter_var)
+                    if iter_var is not None:
+                        bound_overlaps.append(bound.extent.value - stride)
+                    else:
+                        bound_overlaps.append(0)
+
+                # Pick the outermost iter_var that's mentioned in the bounds
+                # to be the rolling axis
+                roll_iter_var = None
+                roll_axis = -1
+                for loop in iter_vars:
+                    iter_var = loop.loop_var
+                    if iter_var in bound_iter_vars:
+                        roll_iter_var = iter_var
+                        roll_axis = bound_iter_vars.index(iter_var)
+                        break
+
+                # We must have found an axis to roll over
+                assert (
+                    roll_iter_var is not None
+                ), "Rolling buffer injection failed: no rolling axis found"
+                assert roll_axis != -1, "Rolling buffer injection failed: no rolling axis found"
+                rolling_buffer_info = RollingBufferInfo(
+                    roll_axis, stmt.bounds[roll_axis].extent.value, bound_overlaps, bound_iter_vars
+                )
+                rolling_buffer_to_info[stmt.buffer] = rolling_buffer_info
+                new_bounds = []
+                for i, extent in enumerate(stmt.buffer.shape):
+                    if i == rolling_buffer_info.rolling_axis:
+                        new_bounds.append(tvm.ir.Range(rolling_buffer_info.rolling_extent))
+                    else:
+                        new_bounds.append(tvm.ir.Range(extent))
+                new_realize = tvm.tir.BufferRealize(
+                    stmt.buffer, new_bounds, stmt.condition, stmt.body, stmt.span
+                )
+                hoist_buffer_to_for[iter_var].append(new_realize)
+
+    def _post_visit(stmt):
+        if isinstance(stmt, tvm.tir.For):
+            # Manage the stack of iter_vars
+            iter_vars.pop()
+            # If the loop corresponds to an iter_var that needs a BufferRealize
+            # hoisting to its scope, perform the hoisting
+            if stmt.loop_var in hoist_buffer_to_for:
+                body = stmt
+                for realize in hoist_buffer_to_for[stmt.loop_var]:
+                    attrs = buffer_to_attrs[realize.buffer]
+                    new_realize = tvm.tir.BufferRealize(
+                        realize.buffer, realize.bounds, realize.condition, body, realize.span
+                    )
+                    # The attributes attached to the BufferRealize need hoisting too
+                    for attr in attrs:
+                        if attr.attr_key == "rolling_buffer_scope":
+                            continue
+                        new_realize = tvm.tir.AttrStmt(
+                            attr.node, attr.attr_key, attr.value, new_realize, attr.span
+                        )
+                    body = new_realize
+                return body
+        elif isinstance(stmt, tvm.tir.AttrStmt):
+            if stmt.node in rolling_buffers:
+                # Remove the attribute statements attached to rolling buffers
+                # because they will have been hoisted to the relevant rolling
+                # scope
+                return stmt.body
+        elif isinstance(stmt, tvm.tir.BufferRealize):
+            if stmt.buffer in rolling_buffers:
+                # Remove the original BufferRealize for rolling buffers
+                # because they will have been hoisted to the relevant rolling
+                # scope
+                return stmt.body
+        elif isinstance(stmt, tvm.tir.BufferStore):
+            if stmt.buffer in rolling_buffer_to_info:
+                rolling_buffer_info = rolling_buffer_to_info[stmt.buffer]
+                indices = []
+                # First modify the access indices to use modulo arithmetic
+                # for the rolling axis
+                for i, index in enumerate(stmt.indices):
+                    if i == rolling_buffer_info.rolling_axis:
+                        indices.append(tvm.tir.FloorMod(index, rolling_buffer_info.rolling_extent))
+                    else:
+                        indices.append(index)
+                buffer_store = tvm.tir.BufferStore(stmt.buffer, stmt.value, indices, stmt.span)
+                # Then wrap the BufferStores in some Ifs to avoid recomputing elements
+                for i, iter_var in enumerate(rolling_buffer_info.axis_iter_vars):
+                    if iter_var is not None and rolling_buffer_info.axis_overlaps[i] > 0:
+                        dmap = {iter_var: arith.IntervalSet(0, 0)}
+                        term_2 = arith.Analyzer().int_set(stmt.indices[i], dmap).min_value
+                        buffer_store = tvm.tir.IfThenElse(
+                            tvm.tir.Or(
+                                iter_var < 1, term_2 >= rolling_buffer_info.axis_overlaps[i]
+                            ),
+                            buffer_store,
+                            None,
+                        )
+                return buffer_store
+        elif isinstance(stmt, tvm.tir.BufferLoad):
+            if stmt.buffer in rolling_buffer_to_info:
+                rolling_buffer_info = rolling_buffer_to_info[stmt.buffer]
+                indices = []
+                # Modify the access indices to use modulo arithmetic
+                # for the rolling axis
+                for i, index in enumerate(stmt.indices):
+                    if i == rolling_buffer_info.rolling_axis:
+                        indices.append(tvm.tir.FloorMod(index, rolling_buffer_info.rolling_extent))
+                    else:
+                        indices.append(index)
+                return tvm.tir.BufferLoad(stmt.buffer, indices, stmt.span)
+
+    def _ftransform(f, mod, ctx):
+        return f.with_body(
+            tvm.tir.stmt_functor.ir_transform(
+                f.body,
+                _pre_visit,
+                _post_visit,
+                [
+                    "tir.AttrStmt",
+                    "tir.BufferRealize",
+                    "tir.For",
+                    "tir.BufferStore",
+                    "tir.BufferLoad",
+                ],
+            )
+        )
+
+    return tvm.tir.transform.prim_func_pass(
+        _ftransform, opt_level=0, name="tir.InjectRollingBuffer"
+    )

Review comment:
       This looks like a fairly complex pass. Would you consider rewriting it in C++? (Not necessary in this PR; we can add a TODO here if a C++ migration is planned.)
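
To make the pass easier to follow, its two effects on a single rolled axis (modulo indexing and the store guard `iter_var < 1 or local_index >= overlap`) can be simulated without TVM. The sketch below is an assumption-laden model, not the pass: `simulate_rolling_buffer` and `produce` are hypothetical names, the buffer is one-dimensional, and the outer loop advances the window by a fixed integer stride.

```python
def simulate_rolling_buffer(num_iters, extent, stride, produce):
    """Model one rolled axis: a window of `extent` rows moving by `stride`.

    Returns the window contents read at each outer iteration and the list of
    rows that were actually computed.
    """
    overlap = extent - stride      # rows shared with the previous window
    buf = [None] * extent          # storage folded down to one window
    computed = []
    windows = []
    for i in range(num_iters):
        for local in range(extent):
            row = i * stride + local
            # Store guard: compute everything on the first iteration,
            # afterwards only the rows the previous window did not cover.
            if i < 1 or local >= overlap:
                buf[row % extent] = produce(row)
                computed.append(row)
        # Loads use the same modulo index as stores.
        windows.append([buf[(i * stride + local) % extent] for local in range(extent)])
    return windows, computed
```

Under these assumptions each row is produced exactly once, and every window still reads the values it would have seen from a full-size buffer.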







[GitHub] [tvm] mbaret commented on pull request #7925: Add a 'rolling_buffer' scheduling primitive

Posted by GitBox <gi...@apache.org>.
mbaret commented on pull request #7925:
URL: https://github.com/apache/tvm/pull/7925#issuecomment-844438758


   @Hzfengsy took 3 tries, but CI is now passing :)





[GitHub] [tvm] manupa-arm commented on a change in pull request #7925: Add a 'rolling_buffer' scheduling primitive

Posted by GitBox <gi...@apache.org>.
manupa-arm commented on a change in pull request #7925:
URL: https://github.com/apache/tvm/pull/7925#discussion_r642933641



##########
File path: python/tvm/tir/transform/inject_rolling_buffer.py
##########
@@ -0,0 +1,238 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Inject rolling buffers through a TIR transformation."""
+# pylint: disable=invalid-name,unused-argument,inconsistent-return-statements
+from collections import defaultdict, namedtuple
+import math
+
+import tvm
+from tvm import arith
+
+
+def InjectRollingBuffer():
+    """Inject rolling buffer statements.
+
+    Rolling buffers are buffers where one of the dimensions has been made into
+    a circular buffer. Two optimizations are implemented in order to accomplish
+    this: sliding window and storage folding. In particular, the sliding window
+    optimization is applied to the entire buffer (to avoid recomputing elements)
+    and storage folding is then applied to just the rolling dimension.
+
+    Rolling buffers must be inside a loop with only part of the buffer used per
+    iteration. The outermost axis will be rolled over.
+
+    For more information, see the RFC:
+    https://discuss.tvm.apache.org/t/rfc-introducing-a-rolling-buffer-scheduling-primitive/9836
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The pass
+    """
+    buffer_to_attrs = defaultdict(list)
+    rolling_buffers = set()
+    rolling_buffer_to_info = dict()
+    iter_vars = list()
+    hoist_buffer_to_for = defaultdict(list)
+
+    RollingBufferInfo = namedtuple(
+        "RollingBufferInfo", ["rolling_axis", "rolling_extent", "axis_overlaps", "axis_iter_vars"]
+    )
+
+    def _pre_visit(stmt):
+        if isinstance(stmt, tvm.tir.For):
+            # Manage the stack of iter_vars
+            iter_vars.append(stmt)
+
+        elif isinstance(stmt, tvm.tir.AttrStmt):
+            if isinstance(stmt.node, tvm.tir.Buffer):
+                if stmt.attr_key == "rolling_buffer_scope" and stmt.value.value:
+                    # If the attribute is indicating that a buffer should be a rolling
+                    # buffer, then update the rolling_buffers set to include the buffer
+                    rolling_buffers.add(stmt.node)
+                # Keep a dictionary associating attribute statements with the buffers
+                # they reference. We'll need this if the buffer gets hoisted and we
+                # need to hoist all of its attributes at the same time.
+                buffer_to_attrs[stmt.node].append(stmt)
+
+        elif isinstance(stmt, tvm.tir.BufferRealize):
+            if stmt.buffer in rolling_buffers:
+                # If a BufferRealize has been identified as needing to be made into
+                # a rolling buffer, begin the analysis...
+                bound_iter_vars = []
+                bound_overlaps = []
+                # We use the bound information of the BufferRealize to calculate
+                # how we can legally roll
+                for bound in stmt.bounds:
+                    divisor = 1
+                    # Handle the case of fractional strides
+                    # They take this form: floordiv(hh.outer, 2)
+                    # Strip the floordiv and keep track of the divisor
+                    if isinstance(bound.min, tvm.tir.FloorDiv):
+                        divisor = bound.min.b.value
+                        bound.min = bound.min.a
+                    # If the bound is an int, we can't roll over it
+                    if isinstance(bound.min, tvm.tir.IntImm):
+                        iter_var = None
+                        stride = 0
+                    # If the bound is just a Var, that implies the stride is 1
+                    elif isinstance(bound.min, tvm.tir.Var):
+                        iter_var = bound.min
+                        stride = 1
+                    # Otherwise, it's the iter var multiplied by the stride
+                    # If not we're in unknown behaviour, so assert
+                    else:
+                        assert isinstance(
+                            bound.min, tvm.tir.Mul
+                        ), "Rolling buffer injection failed: the buffer striding is unsupported"
+                        assert isinstance(
+                            bound.min.a, tvm.tir.Var
+                        ), "Rolling buffer injection failed: the buffer striding is unsupported"
+                        assert isinstance(
+                            bound.min.b, tvm.tir.IntImm
+                        ), "Rolling buffer injection failed: the buffer striding is unsupported"
+                        iter_var = bound.min.a
+                        stride = bound.min.b.value
+                    stride = math.ceil(stride / divisor)
+                    bound_iter_vars.append(iter_var)
+                    if iter_var is not None:
+                        bound_overlaps.append(bound.extent.value - stride)
+                    else:
+                        bound_overlaps.append(0)
+
+                # Pick the outermost iter_var that's mentioned in the bounds
+                # to be the rolling axis
+                roll_iter_var = None
+                roll_axis = -1
+                for loop in iter_vars:
+                    iter_var = loop.loop_var
+                    if iter_var in bound_iter_vars:

Review comment:
       Clarification question: why can't we just look at bound_iter_vars directly? Are non-outermost iter_vars identified there as well?

##########
File path: python/tvm/tir/transform/inject_rolling_buffer.py
##########
@@ -0,0 +1,238 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Inject rolling buffers through a TIR transformation."""
+# pylint: disable=invalid-name,unused-argument,inconsistent-return-statements
+from collections import defaultdict, namedtuple
+import math
+
+import tvm
+from tvm import arith
+
+
+def InjectRollingBuffer():
+    """Inject rolling buffer statements.
+
+    Rolling buffers are buffers where one of the dimensions has been made into
+    a circular buffer. Two optimizations are implemented in order to accomplish
+    this: sliding window and storage folding. In particular, the sliding window
+    optimization is applied to the entire buffer (to avoid recomputing elements)
+    and storage folding is then applied to just the rolling dimension.
+
+    Rolling buffers must be inside a loop with only part of the buffer used per
+    iteration. The outermost axis will be rolled over.
+
+    For more information, see the RFC:
+    https://discuss.tvm.apache.org/t/rfc-introducing-a-rolling-buffer-scheduling-primitive/9836
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The pass
+    """
+    buffer_to_attrs = defaultdict(list)
+    rolling_buffers = set()
+    rolling_buffer_to_info = dict()
+    iter_vars = list()
+    hoist_buffer_to_for = defaultdict(list)
+
+    RollingBufferInfo = namedtuple(
+        "RollingBufferInfo", ["rolling_axis", "rolling_extent", "axis_overlaps", "axis_iter_vars"]
+    )
+
+    def _pre_visit(stmt):
+        if isinstance(stmt, tvm.tir.For):
+            # Manage the stack of iter_vars
+            iter_vars.append(stmt)

Review comment:
       nit: are these iter vars or For stmts?

##########
File path: python/tvm/tir/transform/inject_rolling_buffer.py
##########
@@ -0,0 +1,238 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Inject rolling buffers through a TIR transformation."""
+# pylint: disable=invalid-name,unused-argument,inconsistent-return-statements
+from collections import defaultdict, namedtuple
+import math
+
+import tvm
+from tvm import arith
+
+
+def InjectRollingBuffer():
+    """Inject rolling buffer statements.
+
+    Rolling buffers are buffers where one of the dimensions has been made into
+    a circular buffer. Two optimizations are implemented in order to accomplish
+    this: sliding window and storage folding. In particular, the sliding window
+    optimization is applied to the entire buffer (to avoid recomputing elements)
+    and storage folding is then applied to just the rolling dimension.
+
+    Rolling buffers must be inside a loop with only part of the buffer used per
+    iteration. The outermost axis will be rolled over.
+
+    For more information, see the RFC:
+    https://discuss.tvm.apache.org/t/rfc-introducing-a-rolling-buffer-scheduling-primitive/9836
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The pass
+    """
+    buffer_to_attrs = defaultdict(list)
+    rolling_buffers = set()
+    rolling_buffer_to_info = dict()
+    iter_vars = list()
+    hoist_buffer_to_for = defaultdict(list)
+
+    RollingBufferInfo = namedtuple(
+        "RollingBufferInfo", ["rolling_axis", "rolling_extent", "axis_overlaps", "axis_iter_vars"]
+    )
+
+    def _pre_visit(stmt):
+        if isinstance(stmt, tvm.tir.For):
+            # Manage the stack of iter_vars
+            iter_vars.append(stmt)
+
+        elif isinstance(stmt, tvm.tir.AttrStmt):
+            if isinstance(stmt.node, tvm.tir.Buffer):
+                if stmt.attr_key == "rolling_buffer_scope" and stmt.value.value:
+                    # If the attribute is indicating that a buffer should be a rolling
+                    # buffer, then update the rolling_buffers set to include the buffer
+                    rolling_buffers.add(stmt.node)
+                # Keep a dictionary associating attribute statements with the buffers
+                # they reference. We'll need this if the buffer gets hoisted and we
+                # need to hoist all of its attributes at the same time.
+                buffer_to_attrs[stmt.node].append(stmt)
+
+        elif isinstance(stmt, tvm.tir.BufferRealize):
+            if stmt.buffer in rolling_buffers:
+                # If a BufferRealize has been identified as needing to be made into
+                # a rolling buffer, begin the analysis...
+                bound_iter_vars = []
+                bound_overlaps = []
+                # We use the bound information of the BufferRealize to calculate
+                # how we can legally roll
+                for bound in stmt.bounds:
+                    divisor = 1
+                    # Handle the case of fractional strides
+                    # They take this form: floordiv(hh.outer, 2)
+                    # Strip the floordiv and keep track of the divisor
+                    if isinstance(bound.min, tvm.tir.FloorDiv):
+                        divisor = bound.min.b.value
+                        bound.min = bound.min.a
+                    # If the bound is an int, we can't roll over it
+                    if isinstance(bound.min, tvm.tir.IntImm):
+                        iter_var = None
+                        stride = 0
+                    # If the bound is just a Var, that implies the stride is 1
+                    elif isinstance(bound.min, tvm.tir.Var):
+                        iter_var = bound.min
+                        stride = 1
+                    # Otherwise, it's the iter var multiplied by the stride
+                    # If not we're in unknown behaviour, so assert
+                    else:
+                        assert isinstance(
+                            bound.min, tvm.tir.Mul
+                        ), "Rolling buffer injection failed: the buffer striding is unsupported"
+                        assert isinstance(
+                            bound.min.a, tvm.tir.Var
+                        ), "Rolling buffer injection failed: the buffer striding is unsupported"
+                        assert isinstance(
+                            bound.min.b, tvm.tir.IntImm
+                        ), "Rolling buffer injection failed: the buffer striding is unsupported"
+                        iter_var = bound.min.a
+                        stride = bound.min.b.value
+                    stride = math.ceil(stride / divisor)
+                    bound_iter_vars.append(iter_var)
+                    if iter_var is not None:
+                        bound_overlaps.append(bound.extent.value - stride)
+                    else:
+                        bound_overlaps.append(0)
+
+                # Pick the outermost iter_var that's mentioned in the bounds
+                # to be the rolling axis
+                roll_iter_var = None
+                roll_axis = -1
+                for loop in iter_vars:
+                    iter_var = loop.loop_var

Review comment:
       See my comment above: I think iter_vars holds the loops (For stmts), and the actual iter_var is only extracted from each loop here.







[GitHub] [tvm] manupa-arm commented on pull request #7925: Add a 'rolling_buffer' scheduling primitive

Posted by GitBox <gi...@apache.org>.
manupa-arm commented on pull request #7925:
URL: https://github.com/apache/tvm/pull/7925#issuecomment-854674543


   Adding cc: @jcf94
   A friendly ping! We would really appreciate a review on this :)





[GitHub] [tvm] mbaret commented on pull request #7925: Add a 'rolling_buffer' scheduling primitive

Posted by GitBox <gi...@apache.org>.
mbaret commented on pull request #7925:
URL: https://github.com/apache/tvm/pull/7925#issuecomment-834459963


   Can anyone take a look at this?





[GitHub] [tvm] mbaret commented on a change in pull request #7925: Add a 'rolling_buffer' scheduling primitive

Posted by GitBox <gi...@apache.org>.
mbaret commented on a change in pull request #7925:
URL: https://github.com/apache/tvm/pull/7925#discussion_r648128260



##########
File path: python/tvm/tir/transform/inject_rolling_buffer.py
##########
@@ -0,0 +1,238 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Inject rolling buffers through a TIR transformation."""
+# pylint: disable=invalid-name,unused-argument,inconsistent-return-statements
+from collections import defaultdict, namedtuple
+import math
+
+import tvm
+from tvm import arith
+
+
+def InjectRollingBuffer():
+    """Inject rolling buffer statements.
+
+    Rolling buffers are buffers where one of the dimensions has been made into
+    a circular buffer. Two optimizations are implemented in order to accomplish
+    this: sliding window and storage folding. In particular, the sliding window
+    optimization is applied to the entire buffer (to avoid recomputing elements)
+    and storage folding is then applied to just the rolling dimension.
+
+    Rolling buffers must be inside a loop with only part of the buffer used per
+    iteration. The outermost axis will be rolled over.
+
+    For more information, see the RFC:
+    https://discuss.tvm.apache.org/t/rfc-introducing-a-rolling-buffer-scheduling-primitive/9836
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The pass
+    """
+    buffer_to_attrs = defaultdict(list)
+    rolling_buffers = set()
+    rolling_buffer_to_info = dict()
+    iter_vars = list()
+    hoist_buffer_to_for = defaultdict(list)
+
+    RollingBufferInfo = namedtuple(
+        "RollingBufferInfo", ["rolling_axis", "rolling_extent", "axis_overlaps", "axis_iter_vars"]
+    )
+
+    def _pre_visit(stmt):
+        if isinstance(stmt, tvm.tir.For):
+            # Manage the stack of iter_vars
+            iter_vars.append(stmt)
+
+        elif isinstance(stmt, tvm.tir.AttrStmt):
+            if isinstance(stmt.node, tvm.tir.Buffer):
+                if stmt.attr_key == "rolling_buffer_scope" and stmt.value.value:
+                    # If the attribute is indicating that a buffer should be a rolling
+                    # buffer, then update the rolling_buffers set to include the buffer
+                    rolling_buffers.add(stmt.node)
+                # Keep a dictionary associating attribute statements with the buffers
+                # they reference. We'll need this if the buffer gets hoisted and we
+                # need to hoist all of its attributes at the same time.
+                buffer_to_attrs[stmt.node].append(stmt)
+
+        elif isinstance(stmt, tvm.tir.BufferRealize):
+            if stmt.buffer in rolling_buffers:
+                # If a BufferRealize has been identified as needing to be made into
+                # a rolling buffer, begin the analysis...
+                bound_iter_vars = []
+                bound_overlaps = []
+                # We use the bound information of the BufferRealize to calculate
+                # how we can legally roll
+                for bound in stmt.bounds:
+                    divisor = 1
+                    # Handle the case of fractional strides
+                    # They take this form: floordiv(hh.outer, 2)
+                    # Strip the floordiv and keep track of the divisor
+                    if isinstance(bound.min, tvm.tir.FloorDiv):
+                        divisor = bound.min.b.value
+                        bound.min = bound.min.a
+                    # If the bound is an int, we can't roll over it
+                    if isinstance(bound.min, tvm.tir.IntImm):
+                        iter_var = None
+                        stride = 0
+                    # If the bound is just a Var, that implies the stride is 1
+                    elif isinstance(bound.min, tvm.tir.Var):
+                        iter_var = bound.min
+                        stride = 1
+                    # Otherwise, it's the iter var multiplied by the stride
+                    # If not we're in unknown behaviour, so assert
+                    else:
+                        assert isinstance(
+                            bound.min, tvm.tir.Mul
+                        ), "Rolling buffer injection failed: the buffer striding is unsupported"
+                        assert isinstance(
+                            bound.min.a, tvm.tir.Var
+                        ), "Rolling buffer injection failed: the buffer striding is unsupported"
+                        assert isinstance(
+                            bound.min.b, tvm.tir.IntImm
+                        ), "Rolling buffer injection failed: the buffer striding is unsupported"
+                        iter_var = bound.min.a
+                        stride = bound.min.b.value
+                    stride = math.ceil(stride / divisor)
+                    bound_iter_vars.append(iter_var)
+                    if iter_var is not None:
+                        bound_overlaps.append(bound.extent.value - stride)
+                    else:
+                        bound_overlaps.append(0)
+
+                # Pick the outermost iter_var that's mentioned in the bounds
+                # to be the rolling axis
+                roll_iter_var = None
+                roll_axis = -1
+                for loop in iter_vars:
+                    iter_var = loop.loop_var
+                    if iter_var in bound_iter_vars:

Review comment:
       It's because the loop nest doesn't necessarily iterate over a tensor in the same order as its bounds (e.g. the loops don't have to go axis 0, 1, 2...), so the first entry in bound_iter_vars isn't necessarily the outermost loop.
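
       To make that concrete, here is a minimal pure-Python sketch (not the pass itself). Plain strings stand in for the TIR loop variables, and the helper name `pick_rolling_axis` is made up for illustration. It walks the loop nest from the outside in and returns the first loop var that also drives one of the buffer's bounds:

```python
def pick_rolling_axis(loop_nest, bound_iter_vars):
    """Return (iter_var, axis) for the outermost loop var found in the bounds.

    loop_nest: loop variable names, outermost first (stand-in for the For stack).
    bound_iter_vars: for each buffer axis, the loop var driving it, or None.
    """
    for loop_var in loop_nest:
        if loop_var in bound_iter_vars:
            return loop_var, bound_iter_vars.index(loop_var)
    return None, -1


# Buffer axes are (N, H, W, C); axis 1 is driven by "h" and axis 2 by "w".
bound_iter_vars = [None, "h", "w", None]

# Loops ordered h then w: the rolling axis is 1.
print(pick_rolling_axis(["h", "w"], bound_iter_vars))  # ('h', 1)
# Loops reordered to w then h: "w" is now outermost, so the rolling axis is 2.
print(pick_rolling_axis(["w", "h"], bound_iter_vars))  # ('w', 2)
```

       Scanning bound_iter_vars on its own would return axis 1 in both cases, which is why the loop order has to be consulted.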







[GitHub] [tvm] jcf94 edited a comment on pull request #7925: Add a 'rolling_buffer' scheduling primitive

Posted by GitBox <gi...@apache.org>.
jcf94 edited a comment on pull request #7925:
URL: https://github.com/apache/tvm/pull/7925#issuecomment-864671064


   @mbaret It seems this PR is ready to merge once the conflict has been resolved?
   
   Do we have a GitHub issue to track the progress of your Cascade scheduling work and RFC? If not, I suggest creating one, since this schedule primitive can only be used under special conditions. 😄





[GitHub] [tvm] mbaret commented on pull request #7925: Add a 'rolling_buffer' scheduling primitive

Posted by GitBox <gi...@apache.org>.
mbaret commented on pull request #7925:
URL: https://github.com/apache/tvm/pull/7925#issuecomment-856053017


   > By the way, this PR didn't extend support to Relay integration; does that mean we're currently not able to use this feature in an end-to-end model? My understanding is that the fusion part of Relay does not support fusing multiple `pool` ops into one subgraph.
   
   This doesn't currently improve the Relay build flow; rather, it adds a new scheduling primitive that is, for the moment, primarily for use with tvm::build. We do, however, intend to introduce some inter-operating scheduling as part of this design https://discuss.tvm.apache.org/t/rfc-cascade-scheduling/8119 and hope to make use of this primitive as part of that.
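
   As a rough illustration of what the primitive buys you, the following pure-Python sketch simulates one rolled axis. It is not the TIR pass: `rolling_rows` is a made-up helper, the overlap is taken as extent minus stride as in the analysis code, and mapping each row to `row % extent` is an assumed way of realising the circular buffer.

```python
def rolling_rows(num_steps, extent, stride):
    """Simulate which rows are newly computed at each step of a rolled axis.

    Step i needs rows [i * stride, i * stride + extent). Rows produced by the
    previous step (the overlap) are reused rather than recomputed, and each
    row is stored at slot row % extent of the circular buffer.
    """
    overlap = extent - stride
    computed = []
    for i in range(num_steps):
        start = i * stride
        # The first step fills the whole window; later steps skip the overlap.
        first_new = start if i == 0 else start + overlap
        rows = range(first_new, start + extent)
        computed.append([(row, row % extent) for row in rows])
    return computed


# A window of 4 rows advancing by 2: only 2 new rows are computed per step
# after the first, and they overwrite the 2 slots that are no longer needed.
for step, rows in enumerate(rolling_rows(num_steps=3, extent=4, stride=2)):
    print(step, rows)
```

   Each tuple is (row, slot). Without rolling, every step would recompute all 4 rows of its window.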





[GitHub] [tvm] Hzfengsy commented on pull request #7925: Add a 'rolling_buffer' scheduling primitive

Posted by GitBox <gi...@apache.org>.
Hzfengsy commented on pull request #7925:
URL: https://github.com/apache/tvm/pull/7925#issuecomment-839377569


   It would be great if you could fix the CI errors before requesting a review :)





[GitHub] [tvm] manupa-arm commented on a change in pull request #7925: Add a 'rolling_buffer' scheduling primitive

Posted by GitBox <gi...@apache.org>.
manupa-arm commented on a change in pull request #7925:
URL: https://github.com/apache/tvm/pull/7925#discussion_r642905941



##########
File path: python/tvm/tir/transform/inject_rolling_buffer.py
##########
@@ -0,0 +1,214 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Inject rolling buffers through a TIR transformation."""
+# pylint: disable=invalid-name,unused-argument,inconsistent-return-statements
+from collections import defaultdict
+
+import tvm
+from tvm import arith
+
+
+def InjectRollingBuffer():
+    """Inject rolling buffer statements.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The pass
+    """
+    buffer_to_attrs = defaultdict(list)
+    rolling_buffers = set()
+    rolling_buffer_to_info = dict()
+    iter_vars = list()
+    hoist_buffer_to_for = defaultdict(list)
+
+    class RollingBufferInfo:
+        def __init__(self, rolling_axis, rolling_extent, axis_overlaps, axis_iter_vars):
+            self.rolling_axis = rolling_axis
+            self.rolling_extent = rolling_extent
+            self.axis_overlaps = axis_overlaps
+            self.axis_iter_vars = axis_iter_vars
+
+    def _pre_visit(stmt):
+        if isinstance(stmt, tvm.tir.For):
+            # Manage the stack of iter_vars
+            iter_vars.append(stmt)
+
+        elif isinstance(stmt, tvm.tir.AttrStmt):
+            if isinstance(stmt.node, tvm.tir.Buffer):
+                if stmt.attr_key == "rolling_buffer" and stmt.value.value:
+                    # If the attribute is indicating that a buffer should be a rolling
+                    # buffer, then update the rolling_buffers set to include the buffer
+                    rolling_buffers.add(stmt.node)
+                # Keep a dictionary associating attribute statements with the buffers
+                # they reference. We'll need this if the buffer gets hoisted and we
+                # need to hoist all of its attributes at the same time.
+                buffer_to_attrs[stmt.node].append(stmt)
+
+        elif isinstance(stmt, tvm.tir.BufferRealize):
+            if stmt.buffer in rolling_buffers:
+                # If a BufferRealize has been identified as needing to be made into
+                # a rolling buffer, begin the analysis...
+                bound_iter_vars = []
+                bound_strides = []
+                bound_overlaps = []
+                # We use the bound information of the BufferRealize to calculate
+                # how we can legally roll
+                for bound in stmt.bounds:
+                    # If the bound is an int, we can't roll over it
+                    if isinstance(bound.min, tvm.tir.IntImm):
+                        iter_var = None
+                        stride = 0
+                    # If the bound is just a Var, that implies the stride is 1
+                    elif isinstance(bound.min, tvm.tir.Var):
+                        iter_var = bound.min
+                        stride = 1
+                    # Otherwise, it's the iter var multiplied by the stride
+                    # If not we're in unknown behaviour, so assert
+                    else:
+                        assert isinstance(bound.min, tvm.tir.Mul)

Review comment:
       Well, the reason I suggested exceptions was to transfer responsibility to the caller (i.e., the pass pipeline) for checking whether the pass succeeded. The "handling" would then simply be the caller not running a failing pass. However, given the state of catching TVM exceptions (the last time I tried this, I could not catch an exception by type), I think assert might also work.
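
       For reference, the caller-side handling I had in mind looks roughly like the sketch below. It is plain Python with no TVM involved: `RollingBufferError`, `parse_stride` and `run_pipeline` are hypothetical names, and the toy analysis only mimics the shape of the stride checks in this file.

```python
class RollingBufferError(Exception):
    """Hypothetical error type for illustration; not part of TVM."""


def parse_stride(bound_min):
    """Toy stride analysis: accepts an int, a name, or a (name, int) pair."""
    if isinstance(bound_min, int):
        return None, 0  # constant bound: nothing to roll over
    if isinstance(bound_min, str):
        return bound_min, 1  # bare loop var: stride of 1
    if (
        isinstance(bound_min, tuple)
        and len(bound_min) == 2
        and isinstance(bound_min[0], str)
        and isinstance(bound_min[1], int)
    ):
        return bound_min  # loop var multiplied by a constant stride
    raise RollingBufferError("the buffer striding is unsupported")


def run_pipeline(bound_min):
    """The caller decides what a failed analysis means: here, skip the pass."""
    try:
        return parse_stride(bound_min)
    except RollingBufferError:
        return "skipped"


print(run_pipeline(("hh", 4)))  # ('hh', 4)
print(run_pipeline(3.5))        # skipped
```

       With an assert instead, the caller would have to catch AssertionError, which says less about what went wrong but sidesteps the exception-type problem mentioned above.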







[GitHub] [tvm] jcf94 commented on pull request #7925: Add a 'rolling_buffer' scheduling primitive

Posted by GitBox <gi...@apache.org>.
jcf94 commented on pull request #7925:
URL: https://github.com/apache/tvm/pull/7925#issuecomment-864671064


   @mbaret It seems this PR is ready to merge once the conflict has been resolved?





[GitHub] [tvm] mbaret closed pull request #7925: Add a 'rolling_buffer' scheduling primitive

Posted by GitBox <gi...@apache.org>.
mbaret closed pull request #7925:
URL: https://github.com/apache/tvm/pull/7925


   

