You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/09/01 10:32:50 UTC

[tvm] branch main updated: [Relay] Extract intermediate node by its expression ID (#12646)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 38ba8c0bb6 [Relay] Extract intermediate node by its expression ID (#12646)
38ba8c0bb6 is described below

commit 38ba8c0bb69dd76203a96ba6b2a5c067fe0b2ba0
Author: sisleyli <43...@users.noreply.github.com>
AuthorDate: Thu Sep 1 18:32:42 2022 +0800

    [Relay] Extract intermediate node by its expression ID (#12646)
    
    [Relay] Extract Intermediate Expr by relay expr ID for analysis
    
    modify doc comments
    
    Co-authored-by: Bin Li <bi...@amd.com>
---
 python/tvm/relay/analysis/analysis.py              |  38 ++++++
 src/relay/analysis/extract_intermediate_expr.cc    |  88 ++++++++++++++
 .../test_analysis_extract_intermediate_expr.py     | 130 +++++++++++++++++++++
 3 files changed, 256 insertions(+)

diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py
index 3b38c07a0a..12f659f003 100644
--- a/python/tvm/relay/analysis/analysis.py
+++ b/python/tvm/relay/analysis/analysis.py
@@ -431,3 +431,41 @@ def get_calibration_data(mod, data):
         calib_data[gvar] = value
 
     return calib_data
+
+
+def extract_intermdeiate_expr(mod, expr_id):
+    """Extract Relay Expr by its expression ID
+
+    This function is used for extracting Relay Expr
+    by its expression ID of the main function
+    that we can see in `print(mod["main"])`.
+
+    Parameters
+    ----------
+    mod : tvm.IRModule
+
+    expr_id : the Expr ID that we want to extract
+
+    Returns
+    -------
+    ret : Extracted IRModule
+
+    Examples
+    --------
+    .. code-block:: python
+
+        # Suppose our module is printed like this:
+        # def @main(%x: Tensor[(1, 1, 5, 1), float32], %w1, %w2) {
+        #   %0 = nn.conv2d(%x, %w1, padding=[1, 1, 1, 1], channels=1, kernel_size=[3, 3]);
+        #   %1 = nn.conv2d(%0, %w2, padding=[1, 1, 1, 1], channels=1, kernel_size=[3, 3]);
+        #   %2 = add(%0, %1);
+        #   %3 = split(%2, indices_or_sections=1);
+        #   %4 = %3.0;
+        #   add(%4, 1f)
+        # }
+        # if we want to extract `%1 = nn.conv2d`
+        from tvm import relay
+
+        relay.analysis.extract_intermdeiate_expr(mod, 1)
+    """
+    return _ffi_api.ExtractIntermediateExpr(mod, expr_id)
diff --git a/src/relay/analysis/extract_intermediate_expr.cc b/src/relay/analysis/extract_intermediate_expr.cc
new file mode 100644
index 0000000000..d7466e2729
--- /dev/null
+++ b/src/relay/analysis/extract_intermediate_expr.cc
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file extract_intermediate_expr.cc
+ * \brief Extract a Relay expression from the main function of an IRModule
+ *        by the expression ID shown in `print(mod["main"])`.
+ */
+#include <tvm/node/structural_hash.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief Locates the Relay expression with a given expression ID inside the
+ *        "main" function of an IRModule and wraps it into a new module.
+ *
+ * Call, Tuple and TupleGetItem nodes are numbered 0, 1, 2, ... in the order this
+ * visitor visits them. The numbering is intended to match the `%N` IDs printed by
+ * `print(mod["main"])`; NOTE(review): that correspondence is only exercised by the
+ * tests for simple dataflow graphs.
+ */
+class ExtractIntermediateExprWrapper : private MixedModeVisitor {
+ public:
+  /*!
+   * \param mod The module whose "main" function is searched.
+   * \param expr_id The ID of the expression to extract.
+   */
+  explicit ExtractIntermediateExprWrapper(const IRModule& mod, const int expr_id)
+      : mod_(mod), target_expr_id_(expr_id), counter_(0) {}
+
+  /*!
+   * \brief Number the expressions of "main" and extract the one with the target ID.
+   * \return A new IRModule created by IRModule::FromExpr from the located expression.
+   */
+  IRModule Extract() {
+    VisitExpr(this->mod_->Lookup("main"));
+
+    // An ID outside [0, counter_) never matched, so target_op_ is still null.
+    // Report the valid range so the caller can correct the ID.
+    ICHECK(target_expr_id_ >= 0 && target_expr_id_ < counter_)
+        << "expr_id " << target_expr_id_ << " is out of range: the main function has "
+        << counter_ << " numbered expressions, so valid IDs are in [0, " << counter_ << ")";
+
+    return IRModule::FromExpr(target_op_, {});
+  }
+
+ private:
+  using MixedModeVisitor::VisitExpr_;
+
+  /*! \brief The module being searched. */
+  const IRModule mod_;
+  /*! \brief The expr id that we want to extract. */
+  const int target_expr_id_;
+  /*! \brief How many expressions have been numbered so far; also the next ID to assign. */
+  int counter_;
+  /*! \brief The expression whose ID equals target_expr_id_ (null until found). */
+  Expr target_op_;
+
+  void VisitExpr_(const CallNode* n) final {
+    CheckCounterAndIncrease(GetRef<Expr>(n));
+    MixedModeVisitor::VisitExpr_(n);
+  }
+
+  void VisitExpr_(const TupleNode* n) final {
+    CheckCounterAndIncrease(GetRef<Expr>(n));
+    MixedModeVisitor::VisitExpr_(n);
+  }
+
+  void VisitExpr_(const TupleGetItemNode* n) final {
+    CheckCounterAndIncrease(GetRef<Expr>(n));
+    MixedModeVisitor::VisitExpr_(n);
+  }
+
+  /*! \brief Assign the next ID to \p expr and remember it if that ID is the target. */
+  void CheckCounterAndIncrease(const Expr& expr) {
+    if (target_expr_id_ == counter_) {
+      target_op_ = expr;
+    }
+    ++counter_;
+  }
+};
+
+/*!
+ * \brief Extract the expression with ID \p expr_id from mod["main"] as a new IRModule.
+ * \param mod The module to search.
+ * \param expr_id The expression ID, as numbered by ExtractIntermediateExprWrapper.
+ * \return The module produced by ExtractIntermediateExprWrapper::Extract.
+ */
+IRModule ExtractIntermediateExprPacked(const IRModule& mod, const int expr_id) {
+  ExtractIntermediateExprWrapper extractor(mod, expr_id);
+  return extractor.Extract();
+}
+
+// Looked up from Python as _ffi_api.ExtractIntermediateExpr.
+TVM_REGISTER_GLOBAL("relay.analysis.ExtractIntermediateExpr")
+    .set_body_typed(ExtractIntermediateExprPacked);
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/relay/test_analysis_extract_intermediate_expr.py b/tests/python/relay/test_analysis_extract_intermediate_expr.py
new file mode 100644
index 0000000000..abcaf880b4
--- /dev/null
+++ b/tests/python/relay/test_analysis_extract_intermediate_expr.py
@@ -0,0 +1,130 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test extraction of intermediate Relay expressions by expression ID"""
+import pytest
+import tvm
+from tvm import relay
+
+
+def get_conv_net():
+    """This gets the net for:
+          conv2d
+          /  |
+         /   |
+    conv2d   |
+        \    |
+         \   |
+        elemwise add
+             |
+             |
+             |
+           split
+             |
+             |
+             |
+        elemwise add
+    """
+    dshape = (1, 1, 5, 1)
+    x = relay.var("x", shape=dshape)
+    y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1)
+    x1 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1)
+
+    z = relay.add(y, x1)
+
+    tuple_out = relay.op.split(z, indices_or_sections=1, axis=0)
+
+    tuple_0_add = relay.add(tuple_out[0], relay.const(1, dtype="float32"))
+
+    return tvm.IRModule.from_expr(tuple_0_add)
+
+
+def get_conv2d():
+    x = relay.var("x", shape=(1, 56, 56, 64))
+    weight1 = relay.var("weight1", shape=(3, 3, 64, 32))
+    y = relay.nn.conv2d(
+        x,
+        weight1,
+        channels=32,
+        kernel_size=(3, 3),
+        padding=(1, 1),
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+    )
+    return tvm.IRModule.from_expr(y)
+
+
+def test_extract():
+    dshape = (1, 1, 5, 1)
+
+    def before():
+        return get_conv_net()
+
+    def expected_0():
+        x = relay.var("x", shape=dshape)
+        y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1)
+        return tvm.IRModule.from_expr(y)
+
+    def expected_1():
+        x = relay.var("x", shape=dshape)
+        y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1)
+        x1 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1)
+        return tvm.IRModule.from_expr(x1)
+
+    def expected_2():
+        x = relay.var("x", shape=dshape)
+        y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1)
+        x1 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1)
+        z = relay.add(y, x1)
+        return tvm.IRModule.from_expr(z)
+
+    def expected_3():
+        x = relay.var("x", shape=dshape)
+        y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1)
+        x1 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1)
+        z = relay.add(y, x1)
+        tuple_out = relay.op.split(z, indices_or_sections=1, axis=0)
+        return tvm.IRModule.from_expr(tuple_out.astuple())
+
+    def expected_4():
+        # check tuple node
+        x = relay.var("x", shape=dshape)
+        y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1)
+        x1 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1)
+        z = relay.add(y, x1)
+        tuple_out = relay.op.split(z, indices_or_sections=1, axis=0)
+        return tvm.IRModule.from_expr(tuple_out[0])
+
+    assert tvm.ir.structural_equal(
+        relay.analysis.extract_intermdeiate_expr(before(), 0), expected_0()
+    )
+    assert tvm.ir.structural_equal(
+        relay.analysis.extract_intermdeiate_expr(before(), 1), expected_1()
+    )
+    assert tvm.ir.structural_equal(
+        relay.analysis.extract_intermdeiate_expr(before(), 2), expected_2()
+    )
+    assert tvm.ir.structural_equal(
+        (relay.analysis.extract_intermdeiate_expr(before(), 3)), expected_3()
+    )
+    assert tvm.ir.structural_equal(
+        relay.analysis.extract_intermdeiate_expr(before(), 4), expected_4()
+    )
+    assert tvm.ir.structural_equal(relay.analysis.extract_intermdeiate_expr(before(), 5), before())
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])