You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text version).
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/12/01 11:45:28 UTC

[GitHub] [tvm] ekalda commented on a change in pull request #9589: [microNPU] Add support for TFLite concatenate

ekalda commented on a change in pull request #9589:
URL: https://github.com/apache/tvm/pull/9589#discussion_r760107566



##########
File path: python/tvm/relay/backend/contrib/ethosu/tir/passes.py
##########
@@ -488,3 +485,170 @@ def _encode_constants(mod):
         return new_func, new_const_dict
 
     return _encode_constants
+
+
+def RemoveConcatenates():
+    """Remove concatenate operators by modifying the input buffers to write directly into
+    the concatenated buffer with the appropriate offset.
+
+    This pass works in two stages. The first finds every concatenate operation (marked by
+    pragma_op = ethosu_concatenate) and it performs the following analysis. For each buffer
+    that is concatenated, the buffer is marked that it is to be replaced with the concat
+    buffer and the axis along which it is concatenated as well as the offset along that
+    axis is recorded in 'ReplaceInfo'. Once this analysis is completed, the concatenate
+    loop nest along with its buffer realization statements are removed.
+
+    In the second stage, the input buffers to the concatenate operators are rewritten
+    to use the concat buffer directly. This means applying the correct offset to the
+    concatenation axis where ever the buffer is loaded or stored. Additionally, as the
+    realization statements for the concat buffers were removed in the first stage, they
+    are rewritten in place of the input buffer realization with the earliest liveness."""
+
+    in_concat = [False]  # Whether the visitor is currently inside a concatenate operator
+    concat_buffers = []  # The buffers produced by concatenate operators
+    buffer_replace_map = {}  # A map of buffers to be replaced with the concat buffer
+    attrs_by_buffer = {}  # AttrStmts by the buffer they reference
+    realizes_by_buffer = {}  # BufferRealize statements by the buffer they reference
+    first_replacements = {}  # The first buffers to be replaced by a given concat buffer
+
+    ReplaceInfo = namedtuple("ReplaceInfo", ["buffer", "axis", "offset"])
+
+    def _get_replace_info(buffer_load, concat_buffer):
+        axis = 0
+        offset = 0
+        dmap = dict()
+
+        for i, index in enumerate(buffer_load.indices):
+            if isinstance(index, tvm.tir.Sub):
+                axis = i
+                dmap = {}
+
+                def _visit(stmt):
+                    if isinstance(stmt, tvm.tir.Var):
+                        dmap[stmt] = tvm.arith.IntervalSet(0, 0)
+
+                tvm.tir.stmt_functor.post_order_visit(index, _visit)
+                offset = abs(int(tvm.arith.Analyzer().int_set(index, dmap).max_value))
+        return ReplaceInfo(concat_buffer, axis, offset)
+
+    def _pre_remove(stmt):
+        if isinstance(stmt, tvm.tir.BufferRealize):
+            # Record the realize statements by buffer as we need to hoist some of these
+            realizes_by_buffer[stmt.buffer] = stmt
+        if isinstance(stmt, tvm.tir.AttrStmt):
+            if stmt.attr_key == "realize_scope" and isinstance(stmt.node, tvm.tir.Buffer):
+                # Record the realize_scope attrs by buffer as we need to hoist some of these
+                attrs_by_buffer[stmt.node] = stmt
+            if stmt.attr_key == "pragma_op" and stmt.value.value == "ethosu_concatenate":
+                # Record that we're entering a concatenate loop nest
+                in_concat[0] = True
+        if isinstance(stmt, tvm.tir.BufferLoad) and in_concat[0]:
+            # Any buffer loaded inside a concat is a buffer we intend to replace with this pass.
+            # The buffer_replace_map keeps track of which buffers need replacing with the
+            # concat buffer.
+            replace_info = _get_replace_info(stmt, concat_buffers[-1])
+            buffer_replace_map[stmt.buffer] = replace_info
+        if isinstance(stmt, tvm.tir.BufferStore) and in_concat[0]:
+            # If we're inside a concat, the BufferStore indicates what the concat buffer is
+            concat_buffers.append(stmt.buffer)
+
+    def _post_remove(stmt):
+        if isinstance(stmt, tvm.tir.AttrStmt):
+            if isinstance(stmt.node, tvm.tir.Buffer) and stmt.node in concat_buffers:
+                return stmt.body
+            if stmt.attr_key == "pragma_op" and stmt.value.value == "ethosu_concatenate":
+                # When we leave a concatenate operator, record it and then remove the loop nest
+                in_concat[0] = False
+                return tvm.tir.Evaluate(0)
+        if isinstance(stmt, tvm.tir.BufferRealize):
+            if stmt.buffer in concat_buffers:
+                return stmt.body
+        return None
+
+    def _pre_replace(stmt):
+        if isinstance(stmt, (tvm.tir.BufferLoad, tvm.tir.BufferStore)):
+            # The first buffer referenced that needs replacing with a concat buffer shall
+            # be the one that the concat buffer realize is hoisted to.
+            if stmt.buffer in buffer_replace_map:
+                concat_buffer = buffer_replace_map[stmt.buffer].buffer
+                if concat_buffer not in first_replacements:
+                    first_replacements[concat_buffer] = stmt.buffer
+
+    def _post_replace(stmt):
+        if isinstance(stmt, tvm.tir.BufferStore):
+            if stmt.buffer in buffer_replace_map:
+                # Replace the original buffer store with a new one into the concat buffer
+                # and adjust the indices accordingly to account for the offset
+                replace_info = buffer_replace_map[stmt.buffer]
+                concat_buffer = replace_info.buffer
+                new_indices = list(stmt.indices)
+                new_indices[replace_info.axis] += replace_info.offset
+                # DODGY STORE NODE

Review comment:
       That comment is courtesy of @mbaret, so I'm not exactly sure what makes this store node dodgy, but I updated the comment anyway :) 




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org