You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by co...@apache.org on 2021/03/15 19:51:46 UTC

[tvm] branch main updated: Remove pytest dependency in arm_compute_lib.py (#7556)

This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 068fed9  Remove pytest dependency in arm_compute_lib.py (#7556)
068fed9 is described below

commit 068fed94cf3468e3df510ac8a9aed635ed746804
Author: Nicola Lancellotti <ni...@arm.com>
AuthorDate: Mon Mar 15 19:51:26 2021 +0000

    Remove pytest dependency in arm_compute_lib.py (#7556)
    
    * Add an OpAttrContext class that allows temporarily changing an attribute of an operator
    
    Change-Id: I19b809a105ea8769e56bd89e028e090959a08728
    
    * Replace TempOpAttr with OpAttrContext in arm_compute_lib.py
    
    Change-Id: I1c42dd6a29e765b06ce28192397016efeea2e82a
---
 python/tvm/relay/op/contrib/arm_compute_lib.py | 39 ++++++++++++++++++++++++--
 1 file changed, 36 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/op/contrib/arm_compute_lib.py b/python/tvm/relay/op/contrib/arm_compute_lib.py
index 139f25f..fabb639 100644
--- a/python/tvm/relay/op/contrib/arm_compute_lib.py
+++ b/python/tvm/relay/op/contrib/arm_compute_lib.py
@@ -18,11 +18,11 @@
 """Arm Compute Library supported operators."""
 import tvm
 
+from tvm import relay
 from tvm._ffi import register_func
 from tvm.relay.expr import const
 from tvm.relay import transform
 from tvm.relay.build_module import bind_params_by_name
-from tvm.relay.testing.temp_op_attr import TempOpAttr
 
 from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr
 from .register import register_pattern_table
@@ -111,9 +111,9 @@ def preprocess_module(mod):
 
         return convert_conv
 
-    with TempOpAttr(
+    with OpAttrContext(
         "nn.conv2d", "FTVMConvertOpLayout", convert_layout_conv2d(tvm.relay.nn.conv2d)
-    ), TempOpAttr(
+    ), OpAttrContext(
         "qnn.conv2d", "FTVMConvertOpLayout", convert_layout_conv2d(tvm.relay.qnn.op.conv2d)
     ):
         seq = tvm.transform.Sequential(
@@ -481,3 +481,36 @@ def qnn_add(expr):
             return False
 
     return True
+
+
+class OpAttrContext(object):
+    """Temporarily override an operator attribute inside a ``with`` block."""
+
+    def __init__(self, op_name, attr_key, attr_value):
+        """Look up the op and remember the override applied on ``__enter__``.
+
+        Parameters
+        ----------
+        op_name : str
+            Name of the registered op to modify, e.g. "nn.conv2d".
+
+        attr_key : str
+            The attribute name, e.g. "FTVMConvertOpLayout".
+
+        attr_value : object
+            The value installed while the context is active.
+        """
+        self.op = relay.op.get(op_name)
+        self.attr_key = attr_key
+        self.attr_value = attr_value
+
+    def __enter__(self):
+        self.older_attr = self.op.get_attr(self.attr_key)
+        self.op.reset_attr(self.attr_key)  # drop current value before override
+        self.op.set_attr(self.attr_key, self.attr_value)
+        return self
+
+    def __exit__(self, ptype, value, trace):
+        self.op.reset_attr(self.attr_key)
+        if self.older_attr is not None:  # also restore falsy previous values
+            self.op.set_attr(self.attr_key, self.older_attr)