You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/09/01 10:32:50 UTC
[tvm] branch main updated: [Relay] Extract intermediate node by its expression ID (#12646)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 38ba8c0bb6 [Relay] Extract intermediate node by its expression ID (#12646)
38ba8c0bb6 is described below
commit 38ba8c0bb69dd76203a96ba6b2a5c067fe0b2ba0
Author: sisleyli <43...@users.noreply.github.com>
AuthorDate: Thu Sep 1 18:32:42 2022 +0800
[Relay] Extract intermediate node by its expression ID (#12646)
[Relay] Extract Intermediate Expr by relay expr ID for analysis
modify doc comments
Co-authored-by: Bin Li <bi...@amd.com>
---
python/tvm/relay/analysis/analysis.py | 38 ++++++
src/relay/analysis/extract_intermediate_expr.cc | 88 ++++++++++++++
.../test_analysis_extract_intermediate_expr.py | 130 +++++++++++++++++++++
3 files changed, 256 insertions(+)
diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py
index 3b38c07a0a..12f659f003 100644
--- a/python/tvm/relay/analysis/analysis.py
+++ b/python/tvm/relay/analysis/analysis.py
@@ -431,3 +431,41 @@ def get_calibration_data(mod, data):
calib_data[gvar] = value
return calib_data
+
+
+def extract_intermdeiate_expr(mod, expr_id):
+ """Extract Relay Expr by its expression ID
+
+ This function is used for extracting Relay Expr
+ by its expression ID of the main function
+ that we can see in `print(mod["main"])`.
+
+ Parameters
+ ----------
+ mod : tvm.IRModule
+
+ expr_id : the Expr ID that we want to extract
+
+ Returns
+ -------
+ ret : Extracted IRModule
+
+ Examples
+ --------
+ .. code-block:: python
+
+ # Suppose our module is printed like this:
+ # def @main(%x: Tensor[(1, 1, 5, 1), float32], %w1, %w2) {
+ # %0 = nn.conv2d(%x, %w1, padding=[1, 1, 1, 1], channels=1, kernel_size=[3, 3]);
+ # %1 = nn.conv2d(%0, %w2, padding=[1, 1, 1, 1], channels=1, kernel_size=[3, 3]);
+ # %2 = add(%0, %1);
+ # %3 = split(%2, indices_or_sections=1);
+ # %4 = %3.0;
+ # add(%4, 1f)
+ # }
+ # if we want to extract `%1 = nn.conv2d`
+ from tvm import relay
+
+ relay.analysis.extract_intermdeiate_expr(mod, 1)
+ """
+ return _ffi_api.ExtractIntermediateExpr(mod, expr_id)
diff --git a/src/relay/analysis/extract_intermediate_expr.cc b/src/relay/analysis/extract_intermediate_expr.cc
new file mode 100644
index 0000000000..d7466e2729
--- /dev/null
+++ b/src/relay/analysis/extract_intermediate_expr.cc
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file extract_intermediate_expr.cc
+ * \brief Extracts a Relay Expr from the main function
+ *        by its expression ID, i.e. the ID
+ *        that we can see in `print(mod["main"])`.
+ */
+#include <tvm/node/structural_hash.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief Finds the expression with a given ID inside the "main" function of a module.
+ *
+ * Every Call, Tuple and TupleGetItem node handled by the visitor receives the next value
+ * of a counter that starts at 0. The node whose number equals the requested ID is
+ * remembered and wrapped into a new IRModule by Extract().
+ */
+class ExtractIntermediateExprWrapper : private MixedModeVisitor {
+ public:
+  /*!
+   * \param mod The module whose "main" function is searched.
+   * \param expr_id The ID of the expression to extract.
+   */
+  explicit ExtractIntermediateExprWrapper(const IRModule& mod, const int expr_id)
+      : mod_(mod), target_expr_id_(expr_id), counter_(0) {}
+
+  /*!
+   * \brief Walk "main" and build a module from the expression with the requested ID.
+   * \return A module created from the matching expression via IRModule::FromExpr.
+   */
+  IRModule Extract() {
+    // A negative ID can never match a counter that starts at 0, so reject it up front.
+    ICHECK(target_expr_id_ >= 0) << "expr_id must be non-negative, but got " << target_expr_id_;
+
+    VisitExpr(this->mod_->Lookup("main"));
+
+    // counter_ now holds how many expressions were numbered; the ID must be below it.
+    ICHECK(target_expr_id_ < counter_)
+        << "expr_id " << target_expr_id_ << " is out of range, valid IDs are [0, " << counter_
+        << ")";
+
+    return IRModule::FromExpr(target_op_, {});
+  }
+
+ private:
+  using MixedModeVisitor::VisitExpr_;
+
+  /*! \brief the module that is searched. */
+  const IRModule mod_;
+  /*! \brief the expr id that we want to extract. */
+  const int target_expr_id_;
+  /*! \brief the number that the next visited Call/Tuple/TupleGetItem node receives. */
+  int counter_;
+  /*! \brief the expression whose number matched target_expr_id_, if any. */
+  Expr target_op_;
+
+  void VisitExpr_(const CallNode* n) final {
+    CheckCounterAndIncrease(GetRef<Expr>(n));
+    MixedModeVisitor::VisitExpr_(n);
+  }
+
+  void VisitExpr_(const TupleNode* n) final {
+    CheckCounterAndIncrease(GetRef<Expr>(n));
+    MixedModeVisitor::VisitExpr_(n);
+  }
+
+  void VisitExpr_(const TupleGetItemNode* n) final {
+    CheckCounterAndIncrease(GetRef<Expr>(n));
+    MixedModeVisitor::VisitExpr_(n);
+  }
+
+  /*! \brief Record `expr` if it carries the requested number, then advance the counter. */
+  void CheckCounterAndIncrease(const Expr& expr) {
+    if (target_expr_id_ == counter_) {
+      target_op_ = expr;
+    }
+    ++counter_;
+  }
+};
+
+/*!
+ * \brief Extract the expression with the given ID from the "main" function of a module.
+ * \param mod The module to search.
+ * \param expr_id The ID of the expression to extract.
+ * \return The module produced by ExtractIntermediateExprWrapper::Extract.
+ */
+IRModule ExtractIntermediateExprPacked(const IRModule& mod, const int expr_id) {
+  ExtractIntermediateExprWrapper extractor(mod, expr_id);
+  return extractor.Extract();
+}
+
+// Register ExtractIntermediateExprPacked as the global function
+// "relay.analysis.ExtractIntermediateExpr".
+TVM_REGISTER_GLOBAL("relay.analysis.ExtractIntermediateExpr")
+ .set_body_typed(ExtractIntermediateExprPacked);
+
+} // namespace relay
+} // namespace tvm
diff --git a/tests/python/relay/test_analysis_extract_intermediate_expr.py b/tests/python/relay/test_analysis_extract_intermediate_expr.py
new file mode 100644
index 0000000000..abcaf880b4
--- /dev/null
+++ b/tests/python/relay/test_analysis_extract_intermediate_expr.py
@@ -0,0 +1,130 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test intermediate expression extraction"""
+import pytest
+import tvm
+from tvm import relay
+
+
+def get_conv_net():
+ """This gets the net for:
+ conv2d
+ / |
+ / |
+ conv2d |
+ \ |
+ \ |
+ elemwise add
+ |
+ |
+ |
+ split
+ |
+ |
+ |
+ elemwise add
+ """
+ dshape = (1, 1, 5, 1)
+ x = relay.var("x", shape=dshape)
+ y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1)
+ x1 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1)
+
+ z = relay.add(y, x1)
+
+ tuple_out = relay.op.split(z, indices_or_sections=1, axis=0)
+
+ tuple_0_add = relay.add(tuple_out[0], relay.const(1, dtype="float32"))
+
+ return tvm.IRModule.from_expr(tuple_0_add)
+
+
+def get_conv2d():
+ x = relay.var("x", shape=(1, 56, 56, 64))
+ weight1 = relay.var("weight1", shape=(3, 3, 64, 32))
+ y = relay.nn.conv2d(
+ x,
+ weight1,
+ channels=32,
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ data_layout="NHWC",
+ kernel_layout="HWIO",
+ )
+ return tvm.IRModule.from_expr(y)
+
+
+def test_extract():
+ dshape = (1, 1, 5, 1)
+
+ def before():
+ return get_conv_net()
+
+ def expected_0():
+ x = relay.var("x", shape=dshape)
+ y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1)
+ return tvm.IRModule.from_expr(y)
+
+ def expected_1():
+ x = relay.var("x", shape=dshape)
+ y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1)
+ x1 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1)
+ return tvm.IRModule.from_expr(x1)
+
+ def expected_2():
+ x = relay.var("x", shape=dshape)
+ y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1)
+ x1 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1)
+ z = relay.add(y, x1)
+ return tvm.IRModule.from_expr(z)
+
+ def expected_3():
+ x = relay.var("x", shape=dshape)
+ y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1)
+ x1 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1)
+ z = relay.add(y, x1)
+ tuple_out = relay.op.split(z, indices_or_sections=1, axis=0)
+ return tvm.IRModule.from_expr(tuple_out.astuple())
+
+ def expected_4():
+ # check tuple node
+ x = relay.var("x", shape=dshape)
+ y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1)
+ x1 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1)
+ z = relay.add(y, x1)
+ tuple_out = relay.op.split(z, indices_or_sections=1, axis=0)
+ return tvm.IRModule.from_expr(tuple_out[0])
+
+ assert tvm.ir.structural_equal(
+ relay.analysis.extract_intermdeiate_expr(before(), 0), expected_0()
+ )
+ assert tvm.ir.structural_equal(
+ relay.analysis.extract_intermdeiate_expr(before(), 1), expected_1()
+ )
+ assert tvm.ir.structural_equal(
+ relay.analysis.extract_intermdeiate_expr(before(), 2), expected_2()
+ )
+ assert tvm.ir.structural_equal(
+ (relay.analysis.extract_intermdeiate_expr(before(), 3)), expected_3()
+ )
+ assert tvm.ir.structural_equal(
+ relay.analysis.extract_intermdeiate_expr(before(), 4), expected_4()
+ )
+ assert tvm.ir.structural_equal(relay.analysis.extract_intermdeiate_expr(before(), 5), before())
+
+
+# Run the tests of this file through pytest when it is executed as a script.
+if __name__ == "__main__":
+ pytest.main([__file__])