You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/14 19:14:11 UTC

[GitHub] [incubator-tvm] jwfromm commented on a change in pull request #5235: [RELAY][Fix] i64 indices

jwfromm commented on a change in pull request #5235:
URL: https://github.com/apache/incubator-tvm/pull/5235#discussion_r454584237



##########
File path: tests/python/relay/test_pass_fuse_ops.py
##########
@@ -621,6 +621,81 @@ def expected():
     after = run_opt_pass(expected(), transform.InferType())
     assert tvm.ir.structural_equal(zz, after)
 
+
+def test_fuse_take():
+    """Test fusion case involving concat and take"""
+
+    def before():
+        shape = (tvm.tir.const(10, "int64"),
+                 tvm.tir.const(1, "int64"))
+        x = relay.var("x", shape=shape)
+        concat = relay.concatenate([x,x], axis=-1)
+        out = relay.op.take(concat, indices=relay.const([0], dtype="int64"))
+        return relay.Function(relay.analysis.free_vars(out), out)
+
+    def expected():
+        shape1 = (tvm.tir.const(10, "int64"),
+                  tvm.tir.const(1, "int64"))
+        shape2 = (tvm.tir.const(1, "int64"),)
+        x = relay.var("x", shape=shape1)
+        p0 = relay.var("p0", shape=shape1)
+        p1 = relay.var("p1", shape=shape2,
+                             dtype="int64")
+        c = relay.const([0], dtype="int64")
+        concat = relay.concatenate([p0,p0], axis=-1)
+        out = relay.op.take(concat, indices=p1)
+
+        f0 = relay.Function([p0, p1], out)
+        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+
+        y = relay.Call(f0, [x, c])
+        return relay.Function([x], y)
+
+    orig = before()
+    fuse0(tvm.IRModule.from_expr(orig))
+    m = fuse2(tvm.IRModule.from_expr(orig))
+    relay.build(m, 'llvm')
+    after = run_opt_pass(expected(), transform.InferType())
+    assert tvm.ir.structural_equal(m["main"], after)
+
+
+def test_fuse_gather_nd():
+    """Test fusion case involving concat and gather_nd"""
+
+    def before():
+        shape = (tvm.tir.const(10, "int64"),
+                 tvm.tir.const(1, "int64"))
+        x = relay.var("x", shape=shape)
+        concat = relay.concatenate([x,x], axis=-1)
+        out = relay.gather_nd(concat, indices=relay.expr.const([[0,1],[1,0]], dtype="int64"))
+        return relay.Function(relay.analysis.free_vars(out), out)
+
+    def expected():
+        shape1 = (tvm.tir.const(10, "int64"),
+                  tvm.tir.const(1, "int64"))
+        shape2 = (tvm.tir.const(2, "int64"),
+                  tvm.tir.const(2, "int64"))
+        x = relay.var("x", shape=shape1)
+        p0 = relay.var("p0", shape=shape1)
+        p1 = relay.var("p1", shape=shape2, dtype="int64")
+        c = relay.const([[0,1],[1,0]], dtype="int64")
+        concat = relay.concatenate([p0,p0], axis=-1)
+        out = relay.gather_nd(concat, indices=p1)
+
+        f0 = relay.Function([p0, p1], out)
+        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

Review comment:
       Just curious: why do we set this to `int32` instead of `int64`? I know that all the other tests in this file do the same, but it's unclear what the actual purpose is.

##########
File path: src/te/schedule/operation_inline.cc
##########
@@ -63,7 +63,9 @@ class OperationInliner final : public StmtExprMutator {
       } else {
         Map<Var, PrimExpr> vmap;
         for (size_t i = 0; i < args_.size(); ++i) {
-          vmap.Set(args_[i], op->indices[i]);
+          // indices into `operation_` must be in the range of its output shape,
+          // so we can safely cast the indices without worrying about overflow

Review comment:
       I don't understand this comment; can you try to clarify it?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org