You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by co...@apache.org on 2021/03/15 19:51:46 UTC
[tvm] branch main updated: Remove pytest dependency in arm_compute_lib.py (#7556)
This is an automated email from the ASF dual-hosted git repository.
comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 068fed9 Remove pytest dependency in arm_compute_lib.py (#7556)
068fed9 is described below
commit 068fed94cf3468e3df510ac8a9aed635ed746804
Author: Nicola Lancellotti <ni...@arm.com>
AuthorDate: Mon Mar 15 19:51:26 2021 +0000
Remove pytest dependency in arm_compute_lib.py (#7556)
* Add OpAttrContext class, which allows temporarily changing an attribute of an operator
Change-Id: I19b809a105ea8769e56bd89e028e090959a08728
* Replace TempOpAttr with OpAttrContext in arm_compute_lib.py
Change-Id: I1c42dd6a29e765b06ce28192397016efeea2e82a
---
python/tvm/relay/op/contrib/arm_compute_lib.py | 39 ++++++++++++++++++++++++--
1 file changed, 36 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relay/op/contrib/arm_compute_lib.py b/python/tvm/relay/op/contrib/arm_compute_lib.py
index 139f25f..fabb639 100644
--- a/python/tvm/relay/op/contrib/arm_compute_lib.py
+++ b/python/tvm/relay/op/contrib/arm_compute_lib.py
@@ -18,11 +18,11 @@
"""Arm Compute Library supported operators."""
import tvm
+from tvm import relay
from tvm._ffi import register_func
from tvm.relay.expr import const
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name
-from tvm.relay.testing.temp_op_attr import TempOpAttr
from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr
from .register import register_pattern_table
@@ -111,9 +111,9 @@ def preprocess_module(mod):
return convert_conv
- with TempOpAttr(
+ with OpAttrContext(
"nn.conv2d", "FTVMConvertOpLayout", convert_layout_conv2d(tvm.relay.nn.conv2d)
- ), TempOpAttr(
+ ), OpAttrContext(
"qnn.conv2d", "FTVMConvertOpLayout", convert_layout_conv2d(tvm.relay.qnn.op.conv2d)
):
seq = tvm.transform.Sequential(
@@ -481,3 +481,36 @@ def qnn_add(expr):
return False
return True
+
+
+class OpAttrContext(object):
+    """Context manager that temporarily overrides one attribute of a Relay op."""
+
+    def __init__(self, op_name, attr_key, attr_value):
+        """Look up the op and record the override; nothing changes until ``with``.
+
+        Parameters
+        ----------
+        op_name : str
+            Name passed to ``relay.op.get`` to look up the op.
+
+        attr_key : str
+            The attribute name.
+
+        attr_value : object
+            Value the attribute holds inside the ``with`` block.
+        """
+        self.op = relay.op.get(op_name)
+        self.attr_key = attr_key
+        self.attr_value = attr_value
+
+    def __enter__(self):
+        self.older_attr = self.op.get_attr(self.attr_key)
+        self.op.reset_attr(self.attr_key)
+        self.op.set_attr(self.attr_key, self.attr_value)
+        return self
+
+    def __exit__(self, ptype, value, trace):
+        self.op.reset_attr(self.attr_key)
+        if self.older_attr is not None:  # restore falsy-but-set values (e.g. 0) too
+            self.op.set_attr(self.attr_key, self.older_attr)