You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/30 07:31:08 UTC

[GitHub] [incubator-tvm] mbrookhart commented on a change in pull request #6124: [Relay] change device annotation from post DFS to recursive

mbrookhart commented on a change in pull request #6124:
URL: https://github.com/apache/incubator-tvm/pull/6124#discussion_r462608830



##########
File path: src/relay/transforms/device_annotation.cc
##########
@@ -414,22 +427,23 @@ class DeviceInfo {
     void VisitExpr_(const TupleGetItemNode* op) final { ExprVisitor::VisitExpr_(op); }
 
     void VisitExpr_(const VarNode* vn) final {
-      post_dfs_order_.push_back(std::make_pair(vn, has_copy_));
+      device_tag_[vn] = dev_type_;
     }
 
     void VisitExpr_(const LetNode* ln) final {
       ExprVisitor::VisitExpr_(ln);
-      post_dfs_order_.push_back(std::make_pair(ln, has_copy_));
+      device_tag_[ln] = dev_type_;
     }
 
     void VisitExpr_(const IfNode* in) final {
       ExprVisitor::VisitExpr_(in);
-      post_dfs_order_.push_back(std::make_pair(in, has_copy_));
+      device_tag_[in] = dev_type_;
     }
 
     int num_device_copy_ops_{0};
-    bool has_copy_ = false;
-    std::vector<std::pair<const ExprNode*, bool>> post_dfs_order_;
+    int dev_type_ = -1;
+    int out_dev_type_ = -1;

Review comment:
      I don't love making this class state, precisely because you then have to do awkward bookkeeping to keep that state correct across your recursive calls. You could pass it down as recursive arguments instead, but that essentially requires re-implementing the visitor with ExprFunctor... so maybe this is the cleanest solution.

##########
File path: tests/python/relay/test_pass_annotation.py
##########
@@ -309,6 +309,76 @@ def test_visitor_annotation():
     test_visitor_annotation()
 
 
+def test_propogation():
+    R""" The network and device type is as following:
+                  x           1
+                  |
+                 log          1
+                /   \
+              log2 log10      2
+                \   /
+                 add          2
+                  |
+                 tan          1
+    """
+    ctx1 = tvm.context(1)
+    ctx2 = tvm.context(2)
+
+    expected_dev_type = {
+        'log': ctx1,
+        'log2': ctx2,
+        'log10': ctx2,
+        'add': ctx2,
+        'tan': ctx1
+    }
+
+    x = relay.var("x", shape=(3,))
+
+    def annotated():
+        log = relay.log(x)
+        _log = relay.annotation.on_device(log, expected_dev_type['log'])
+        log2 = relay.log2(_log)
+        _log2 = relay.annotation.on_device(log2, expected_dev_type['log2'])
+        log10 = relay.log10(_log)
+        _log10 = relay.annotation.on_device(log10, expected_dev_type['log10'])
+        add = relay.add(_log2, _log10)
+        _add = relay.annotation.on_device(add, expected_dev_type['add'])
+        tan = relay.tan(_add)
+        _tan = relay.annotation.on_device(tan, expected_dev_type['tan'])
+
+        func = run_opt_pass(_tan, transform.RewriteAnnotatedOps(ctx1.device_type))
+        return func
+
+    def expected():
+        log = relay.log(x)
+        _log_left = relay.device_copy(log, ctx1, ctx2)
+        _log_right = relay.device_copy(log, ctx1, ctx2)
+        log2 = relay.log2(_log_left)
+        log10 = relay.log10(_log_right)
+        add = relay.add(log2, log10)
+        _add = relay.device_copy(add, ctx2, ctx1)
+        tan = relay.tan(_add)
+
+        func = run_opt_pass(tan, transform.InferType())
+        return func
+
+    annotated_expr = annotated()
+    expected_expr = expected()
+    assert tvm.ir.structural_equal(annotated_expr, expected_expr)

Review comment:
      I'm curious: master passes this check but fails on line 377. Why doesn't structural_equal catch the error in the device copy op?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org