You are viewing a plain text version of this content. The canonical version was linked from the original page; that hyperlink is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/11/16 10:15:08 UTC

[GitHub] [tvm] ekalda commented on a change in pull request #9457: Add the Arm(R) Ethos(TM)-U NPU identity operator

ekalda commented on a change in pull request #9457:
URL: https://github.com/apache/tvm/pull/9457#discussion_r750115026



##########
File path: python/tvm/relay/backend/contrib/ethosu/legalize.py
##########
@@ -123,6 +123,108 @@ def __call__(self, *args, **kwargs):
         pass
 
 
+class StridedSliceRewriter(DFPatternCallback):
+    """This pass brings the strided slice out of the partitioned function"""
+
+    def __init__(self):
+        super().__init__(require_type=True, rewrite_once=True)
+        self.pattern = (wildcard().has_attr({"Composite": "ethosu.strided_slice"}))(wildcard())
+
+    def callback(
+        self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
+    ) -> tvm.relay.Expr:
+        input = post.args[0]
+        attrs = post.op.body.attrs
+        begin = attrs.begin
+        end = attrs.end
+        strides = attrs.strides
+        axes = attrs.axes
+        slice_mode = attrs.slice_mode
+        strided_slice = relay.op.strided_slice(
+            input, begin, end, strides=strides, axes=axes, slice_mode=slice_mode
+        )
+        return strided_slice
+
+
+@ir.transform.module_pass(opt_level=1)
+class LegalizeStridedSlice:
+    """This is the pass that wraps StridedSliceRewriter"""
+
+    def transform_module(
+        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
+    ) -> tvm.ir.IRModule:
+        for global_var, func in mod.functions.items():
+            func = rewrite(StridedSliceRewriter(), func)
+            mod.update_func(global_var, func)
+        return mod
+
+    def __call__(self, *args, **kwargs):
+        pass
+
+
+class ReshapeRewriter(DFPatternCallback):
+    """This pass brings the reshape out of the partitioned function"""
+
+    def __init__(self):
+        super().__init__(require_type=True, rewrite_once=True)
+        self.pattern = (wildcard().has_attr({"Composite": "ethosu.reshape"}))(wildcard())
+
+    def callback(
+        self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
+    ) -> tvm.relay.Expr:
+        reshape_input = post.args[0]
+        new_shape = post.op.body.attrs.newshape
+        reshape = relay.op.reshape(reshape_input, newshape=new_shape)
+        return reshape
+
+
+@ir.transform.module_pass(opt_level=1)
+class LegalizeReshape:
+    """This is the pass that wraps ReshapeRewriter"""
+
+    def transform_module(
+        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
+    ) -> tvm.ir.IRModule:
+        for global_var, func in mod.functions.items():
+            func = rewrite(ReshapeRewriter(), func)
+            mod.update_func(global_var, func)
+        return mod
+
+    def __call__(self, *args, **kwargs):
+        pass
+
+
+class NoOpRewriter(DFPatternCallback):
+    """This pass adds and idenity operator to reshape and strided slice to avoid a no op without a consumer"""

Review comment:
       Done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org