You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/05/20 04:51:12 UTC

[GitHub] [tvm] xqdan commented on a change in pull request #8079: Complete register op from python

xqdan commented on a change in pull request #8079:
URL: https://github.com/apache/tvm/pull/8079#discussion_r635762635



##########
File path: python/tvm/ir/op.py
##########
@@ -85,17 +85,76 @@ def reset_attr(self, attr_name):
         """
         _ffi_api.OpResetAttr(self, attr_name)
 
+    def add_type_rel(self, rel_name, type_rel_func=None):

Review comment:
       ```
   
   from __future__ import absolute_import as _abs
   from tvm.relay.op import op as _op
   from tvm import relay
   from tvm.ir.attrs import DictAttrs
   from tx.utils.utils import get_broadcast_result_shape
   
   
   def npu_add_rel(arg_types, attrs):
       """Define the npu.add type relation.

       Parameters
       ----------
       arg_types : List[Type]
           All input types; exactly two are expected.

       attrs : DictAttrs
           Attributes of the OpNode. Only type-checked when truthy.

       Returns
       -------
       out : Type or None
           A ``relay.TensorType`` whose shape is the broadcast of the two
           input shapes and whose dtype is taken from the first input.
           Returns None if either input type is not a ``relay.TensorType``.

       Raises
       ------
       ValueError
           If ``arg_types`` does not contain exactly two types.
       TypeError
           If ``attrs`` is truthy but is not a ``DictAttrs``.
       """
       # Explicit raises rather than ``assert``: asserts are stripped under
       # ``python -O``, which would turn a bad call into an IndexError (too few
       # types) or silently ignore extra types.
       if len(arg_types) != 2:
           raise ValueError(
               "type relation arg number mismatch! expected 2, got {}".format(len(arg_types))
           )
       if attrs and not isinstance(attrs, DictAttrs):
           raise TypeError(
               "npu.add attrs must be DictAttrs, got {}".format(type(attrs).__name__)
           )

       inputa_type = arg_types[0]
       inputb_type = arg_types[1]
       # Both inputs must be tensor types before a result shape can be derived.
       if not isinstance(inputa_type, relay.TensorType) or not isinstance(
           inputb_type, relay.TensorType
       ):
           return None

       output_shape = get_broadcast_result_shape(inputa_type.shape, inputb_type.shape)

       # NOTE(review): the output dtype is taken from the first input only; a
       # dtype mismatch between the inputs is not checked here.
       return relay.TensorType(output_shape, inputa_type.dtype)
   
   
   # Every step below is required to fully define a user-defined operator.
   # The description is a plain Python string: the C++ raw-string delimiter
   # form R"code(...)code" has no meaning in Python, and writing
   # r"code(...)code" would embed the literal "code(" / ")code" text in the
   # registered description.
   _op.register("npu.add", "Add two tensors with inner broadcasting.")
   # Look the operator up once instead of once per attribute call.
   _npu_add_op = _op.get("npu.add")
   _npu_add_op.set_num_inputs(2)
   _npu_add_op.add_argument("data_0", "Tensor", "The input data tensor.")
   _npu_add_op.add_argument("data_1", "Tensor", "The input data tensor.")
   _npu_add_op.add_type_rel("NPUAdd", npu_add_rel)
   _npu_add_op.set_support_level(1)
   _op.register_pattern("npu.add", _op.OpPattern.ELEMWISE)
   _op.register_stateful("npu.add", False)
   # Drop the temporary handle so no extra name is left in the module namespace.
   del _npu_add_op
   ```
   




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org