You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/11/22 15:14:12 UTC

[GitHub] [tvm] areusch commented on a diff in pull request #13402: [Relay][Frontend] Span filling common API

areusch commented on code in PR #13402:
URL: https://github.com/apache/tvm/pull/13402#discussion_r1029444355


##########
tests/python/frontend/test_common.py:
##########
@@ -27,6 +32,203 @@ def test_key_is_not_present():
     assert not attrs.has_attr("b")
 
 
+def test_set_span():
+    def _verify_env_var_switch():
+        def _res(should_fill):
+            if should_fill:
+                with testing.enable_span_filling():
+                    return set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
+            else:
+                with testing.disable_span_filling():
+                    return set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
+
+        disable = relay.var("x", shape=(1, 64, 56, 56))
+        enable = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var"))
+
+        assert _verify_structural_equal_with_span(_res(False), disable)
+        assert _verify_structural_equal_with_span(_res(True), enable)
+
+    # Should tag all exprs without span, and stop when expr is span-tagged
+    def _verify_builtin_tuple():
+        def _res():
+            a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a"))
+            b = relay.const(np.zeros([1, 1, 1]), dtype="int64")
+            return set_span(tuple([a, b]), "tuple")
+
+        def _golden():
+            a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a"))
+            b = relay.const(np.zeros([1, 1, 1]), dtype="int64", span=_create_span("tuple"))
+            return tuple([a, b])
+
+        res_tuple, golden_tuple = _res(), _golden()
+        assert len(res_tuple) == len(golden_tuple)
+        for i in range(len(res_tuple)):
+            assert _verify_structural_equal_with_span(res_tuple[i], golden_tuple[i])
+
+    def _verify_builtin_list():
+        def _res():
+            a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a"))
+            b = relay.const(np.zeros([1, 1, 1]), dtype="int64")
+            t = relay.Tuple([a, b])
+            t_a = relay.TupleGetItem(t, 0)
+            t_b = relay.TupleGetItem(t, 1)
+            return set_span([t_a, t_b], "list")
+
+        def _golden():
+            a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a"))
+            b = relay.const(np.zeros([1, 1, 1]), dtype="int64", span=_create_span("list"))
+            t = relay.Tuple([a, b], span=_create_span("list"))
+            t_a = relay.TupleGetItem(t, 0, span=_create_span("list"))
+            t_b = relay.TupleGetItem(t, 1, span=_create_span("list"))
+            return [t_a, t_b]
+
+        res_list, golden_list = _res(), _golden()
+        assert len(res_list) == len(golden_list)
+        for i in range(len(res_list)):
+            assert _verify_structural_equal_with_span(res_list[i], golden_list[i])
+
+    def _verify_var():
+        x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
+        x_expected = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var"))
+        assert _verify_structural_equal_with_span(x, x_expected)
+
+    def _verify_constant():
+        c = set_span(relay.const(np.ones([64, 64, 3, 3]), dtype="int64"), "const_c")
+        c_expected = relay.const(
+            np.ones([64, 64, 3, 3]), dtype="int64", span=_create_span("const_c")
+        )
+        assert _verify_structural_equal_with_span(c, c_expected)
+
+    def _verify_call():
+        def _res():
+            x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
+            w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64")
+            y = set_span(
+                relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1)), "conv2d"
+            )
+            return relay.Function([x], y)
+
+        def _golden():
+            x = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var"))
+            w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64", span=_create_span("conv2d"))
+            y = _set_span(
+                relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1)), "conv2d"
+            )
+            return relay.Function([x], y)
+
+        assert _verify_structural_equal_with_span(_res(), _golden())
+
+    def _verify_tuple():
+        def _res():
+            a = set_span(relay.const(np.ones([1, 1, 1]), dtype="int64"), "a")
+            b = relay.const(np.ones([1, 1, 1]), dtype="int64")
+            t = set_span(relay.Tuple([a, b]), "t")
+            return relay.Function([], t)
+
+        def _golden():
+            a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a"))
+            b = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("t"))
+            t = relay.Tuple([a, b], span=_create_span("t"))
+            return relay.Function([], t)
+
+        assert _verify_structural_equal_with_span(_res(), _golden())
+
+    def _verify_tuple_getitem():
+        def _res():
+            a = set_span(relay.const(np.ones([1, 1, 1]), dtype="int64"), "a")
+            b = relay.const(np.ones([1, 1, 1]), dtype="int64")
+            t = relay.Tuple([a, b])
+            i = set_span(relay.TupleGetItem(t, 0), "i")
+            return relay.Function([], i)
+
+        def _golden():
+            a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a"))
+            b = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("i"))
+            t = relay.Tuple([a, b], span=_create_span("i"))
+            i = relay.TupleGetItem(t, 0, span=_create_span("i"))
+            return relay.Function([], i)
+
+        assert _verify_structural_equal_with_span(_res(), _golden())
+
+    def _verify_let():
+        def _res():
+            x = set_span(relay.Var("x"), "x_var")
+            c_1 = relay.const(np.ones(10))
+            add = relay.add(x, x)
+            body = set_span(relay.Let(x, c_1, add), "let")
+
+            c_2 = set_span(relay.const(np.zeros(10)), "zeros")
+            y = set_span(relay.add(body, c_2), "add_2")
+            return relay.Function([x], y)
+
+        def _golden():
+            x = relay.Var("x", span=_create_span("x_var"))
+            c_1 = relay.const(np.ones(10), span=_create_span("let"))
+            add = _set_span(relay.add(x, x), "let")
+            body = relay.Let(x, c_1, add, span=_create_span("let"))
+
+            c_2 = relay.const(np.zeros(10), span=_create_span("zeros"))
+            y = _set_span(relay.add(body, c_2), "add_2")
+            return relay.Function([x], y)
+
+        assert _verify_structural_equal_with_span(_res(), _golden())
+
+    def _verify_if():
+        def _res():
+            x = set_span(relay.var("x", shape=[], dtype="float32"), "x_var")
+            y = set_span(relay.var("y", shape=[], dtype="float32"), "y_var")
+            eq = relay.equal(x, y)
+
+            true_branch = set_span(relay.add(x, y), "true_branch")
+            false_branch = relay.subtract(x, y)
+            ife = set_span(relay.If(eq, true_branch, false_branch), "if")
+            return relay.Function([x, y], ife)
+
+        def _golden():
+            x = relay.var("x", shape=[], dtype="float32", span=_create_span("x_var"))
+            y = relay.var("y", shape=[], dtype="float32", span=_create_span("y_var"))
+            eq = _set_span(relay.equal(x, y), "if")
+
+            true_branch = _set_span(relay.add(x, y), "true_branch")
+            false_branch = _set_span(relay.subtract(x, y), "if")
+            ife = relay.If(eq, true_branch, false_branch, span=_create_span("if"))
+            return relay.Function([x, y], ife)
+
+        assert _verify_structural_equal_with_span(_res(), _golden())
+
+    def _verify_fn():
+        def _res():
+            x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
+            w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64")
+            y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1))
+            f = set_span(relay.Function([x], y), "func")
+            return f
+
+        def _golden():
+            x = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var"))
+            w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64", span=_create_span("func"))
+            y = _set_span(
+                relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1)), "func"
+            )
+            f = relay.Function([x], y, span=_create_span("func"))
+            return f
+
+        assert _verify_structural_equal_with_span(_res(), _golden())
+
+    _verify_env_var_switch()
+    _verify_builtin_tuple()
+    _verify_builtin_list()
+    _verify_var()
+    _verify_constant()
+    _verify_call()
+    _verify_tuple()
+    _verify_tuple_getitem()
+    _verify_let()
+    _verify_if()
+    _verify_fn()
+
+
 if __name__ == "__main__":
     test_key_is_present()

Review Comment:
   Can you change this to `tvm.testing.main()` so the test cases are split out?



##########
tests/python/relay/utils/tag_span.py:
##########
@@ -0,0 +1,106 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import relay, tir
+from tvm.relay.expr_functor import ExprVisitor
+
+
+def _set_span(expr, src):
+    if isinstance(expr, relay.Call):
+        return relay.Call(expr.op, expr.args, expr.attrs, expr.type_args, _create_span(src))
+    elif isinstance(expr, relay.Var):
+        return relay.var(expr.name_hint, expr.type_annotation, None, None, _create_span(src))
+    elif isinstance(expr, relay.TupleGetItem):
+        return relay.TupleGetItem(expr.tuple_value, expr.index, _create_span(src))
+    elif isinstance(expr, relay.Constant):
+        return relay.Constant(expr.data, _create_span(src))
+    elif isinstance(expr, relay.TupleWrapper):
+        return relay.TupleWrapper(_set_span(expr.tuple_value, src), expr.size)
+    elif isinstance(expr, relay.Tuple):
+        return relay.Tuple(expr.fields, _create_span(src))
+    elif isinstance(expr, tir.AttrStmt):
+        return tir.AttrStmt(expr.node, expr.attr_key, expr.value, expr.body, _create_span(src))
+
+    assert False, f"unsupported type {type(expr)}"
+
+
+def _create_span(src):
+    if isinstance(src, list):
+        tmp_list = []
+        for s in src:
+            if isinstance(s, str):
+                tmp_list.append(_create_span(s))
+            elif isinstance(s, relay.Span):
+                tmp_list.append(s)
+            elif isinstance(s, relay.SequentialSpan):
+                tmp_list.extend(s.spans)
+            elif s is None:
+                tmp_list.append(s)
+            else:
+                assert False, f"unsupported type {type(s)}"
+        return relay.SequentialSpan(tmp_list)
+    return relay.Span(relay.SourceName(src), 0, 0, 0, 0)
+
+
+def _collect_spans(objref):
+    class Collector:
+        def __init__(self):
+            self._spans = []
+
+        def collect(self, objref):
+            if hasattr(objref, "span"):
+                self._spans.append(objref.span)
+
+        @property
+        def get_spans(self):
+            return self._spans
+
+    pov = None
+    if isinstance(objref, relay.Expr):
+        pov = relay.analysis.post_order_visit
+    elif isinstance(objref, (tir.Stmt, tir.expr.PrimExprWithOp)):
+        pov = tir.stmt_functor.post_order_visit
+    else:
+        assert False, f"unsupported type {type(objref)}"
+
+    c = Collector()
+    pov(objref, c.collect)
+    return c.get_spans
+
+
+def _verify_span(lhs, rhs):
+    lhs_spans, rhs_spans = _collect_spans(lhs), _collect_spans(rhs)
+
+    if len(lhs_spans) != len(rhs_spans):
+        return False
+
+    for i in range(len(lhs_spans)):
+        if not tvm.ir.structural_equal(lhs_spans[i], rhs_spans[i]):
+            return False
+    return True

Review Comment:
   This style is a bit different from the typical pytest style, e.g. `assert foo == bar` rather than `return foo == bar`.



##########
python/tvm/testing/utils.py:
##########
@@ -2081,3 +2081,28 @@ def pprint(name, obj):
                 f"or an instance of `tvm.tir.PrimFunc`.  "
                 f"Instead, received {type(expected)}."
             )
+
+
+class _control_span_filling:
+    def __init__(self, on=True):
+        self._old_state = os.environ["TVM_SPANFILLING"] if "TVM_SPANFILLING" in os.environ else None

Review Comment:
   Just curious: why are you consulting `os.environ` here?



##########
src/relay/ir/expr.cc:
##########
@@ -72,8 +72,8 @@ Constant::Constant(runtime::NDArray data, Span span) {
 
 TVM_REGISTER_NODE_TYPE(ConstantNode);
 
-TVM_REGISTER_GLOBAL("relay.ir.Constant").set_body_typed([](runtime::NDArray data) {
-  return Constant(data);
+TVM_REGISTER_GLOBAL("relay.ir.Constant").set_body_typed([](runtime::NDArray data, Span span) {

Review Comment:
   I'm not sure this concern should block this review, but one thing I notice here is that this now makes `Span` a required argument of the Relay IR constructor (here, `relay.ir.Constant`). Personally, I think that's reasonable: you can always explicitly pass `Span()` if you want to provide an undefined span. It is slightly more burdensome, so I could see differing opinions around the community here. But personally, I think that's the tradeoff we have to live with, so I wouldn't block this review over it.



##########
tests/python/frontend/test_common.py:
##########
@@ -27,6 +32,203 @@ def test_key_is_not_present():
     assert not attrs.has_attr("b")
 
 
+def test_set_span():

Review Comment:
   I suggest making this a TestCase class and promoting the various helper functions (e.g. `_verify_env_var_switch`) to class methods (e.g. `test_env_var_switch`), so that they run as separate test cases.



##########
python/tvm/relay/frontend/common.py:
##########
@@ -997,3 +1002,167 @@ def try_resolve_var_to_const(x, graph_params):
         return _op.const(value, dtype)
 
     return x
+
+
+class _SpanFiller(ExprMutator):

Review Comment:
   We have been implementing a number of passes in C++, since the Python visitor infrastructure involves many calls that cross the FFI boundary. Would you be up for moving this into C++?



##########
tests/python/relay/utils/tag_span.py:
##########
@@ -0,0 +1,106 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import relay, tir
+from tvm.relay.expr_functor import ExprVisitor
+
+
+def _set_span(expr, src):

Review Comment:
   We have an analogue, `WithFields`, in C++; I think this would be better placed there. For now, see [FunctionWithFields](https://github.com/apache/tvm/blob/main/python/tvm/relay/function.py#L68).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org