Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/04/22 17:26:03 UTC

[GitHub] [tvm] mbrookhart edited a comment on pull request #7910: [Relay] Shape func fix for all_class_nms and where op

mbrookhart edited a comment on pull request #7910:
URL: https://github.com/apache/tvm/pull/7910#issuecomment-825041882


   The subgraph that's causing the `where` issue is this:
   ```
     %5 = take(%p3, 1 /* ty=int32 */) /* ty=int64 */;
     %6 = add(4 /* ty=int64 */, %5) /* ty=int64 */;
     %7 = where(False /* ty=bool */, %6, 4 /* ty=int64 */) /* ty=int64 */;
     %8 = take(%p4, %7, axis=1) /* ty=Tensor[(?), float32] */;
   ```
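   For reference, here is the same computation in plain NumPy. It assumes `%p3` is a 1-D int64 tensor and `%p4` is a 2-D float32 tensor, matching the unit test below; the function name `subgraph` and the sample values are illustrative only. It shows the shapes the shape functions should produce: `where` on three scalar operands yields a rank-0 result, and `take` with a rank-0 index along `axis=1` drops that axis, giving the 1-D `Tensor[(?), float32]` annotated on `%8`.

   ```python
   import numpy as np

   def subgraph(p3: np.ndarray, p4: np.ndarray) -> np.ndarray:
       # %5 = take(%p3, 1): selecting a single element gives a rank-0 value
       v5 = np.take(p3, 1)
       # %6 = add(4, %5): scalar + scalar is still rank 0
       v6 = np.int64(4) + v5
       # %7 = where(False, %6, 4): all three operands are scalars, so the
       # broadcast output shape is () and the result is rank 0
       v7 = np.where(False, v6, np.int64(4))
       # %8 = take(%p4, %7, axis=1): a rank-0 index removes axis 1,
       # leaving a 1-D result of length p4.shape[0]
       return np.take(p4, v7, axis=1)

   # Illustrative inputs (hypothetical values)
   p3 = np.array([7, 3], dtype="int64")
   p4 = np.arange(12, dtype="float32").reshape(2, 6)
   out = subgraph(p3, p4)
   print(out.shape)  # (2,)
   ```

   Because the condition is `False`, the index is always 4, so the result equals column 4 of `p4`, which is also what the test below checks with `y_np[:, 4]`.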
   And this unit test reproduces it:
   ```
   diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
   index ef02b6f10..7b21c01b2 100644
   --- a/tests/python/relay/test_any.py
   +++ b/tests/python/relay/test_any.py
   @@ -1512,6 +1512,22 @@ def test_any_where():
            any_dims(2), any_dims(2), any_dims(2), (3, 4), (3, 1), (1, 4), y_np_shape_invalid=(2, 4)
        )
    
   +    # Test scalar where in a dynamically shaped graph
   +    x_np = np.random.randn(2).astype("int64")
   +    y_np = np.random.randn(2, 6).astype("float32")
   +    expected = y_np[:, 4]
   +    x = relay.var("x", shape=any_dims(1), dtype="int64")
   +    y = relay.var("y", shape=any_dims(2), dtype="float32")
   +
   +    left = relay.take(x, relay.const(1, dtype="int32")) + relay.const(4, "int64")
   +    right = relay.const(4, "int64")
   +    where = relay.where(relay.const(False, "bool"), left, right)
   +    z = relay.take(y, where, axis=1)
   +
   +    mod = tvm.IRModule()
   +    mod["main"] = relay.Function([x, y], z)
   +    check_result([x_np, y_np], mod, expected)
   +
    
    @tvm.testing.uses_gpu
    def test_non_max_suppression():
   ```
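   The output shape of `where(cond, x, y)` follows NumPy-style broadcasting across all three operand shapes, and when every operand is rank 0 the result must be rank 0 as well. The following is a minimal pure-Python sketch of that rule, using hypothetical helpers `broadcast_pair` and `where_out_shape`. It is not TVM's actual shape function, only an illustration of the expected shapes, cross-checked against `np.broadcast_shapes` (available in NumPy 1.20 and later).

   ```python
   from typing import Sequence, Tuple

   import numpy as np

   def broadcast_pair(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
       # Hypothetical helper: align shapes from the right; missing leading
       # dimensions behave like size 1.
       ndim = max(len(a), len(b))
       a = (1,) * (ndim - len(a)) + tuple(a)
       b = (1,) * (ndim - len(b)) + tuple(b)
       out = []
       for da, db in zip(a, b):
           if da == db or db == 1:
               out.append(da)
           elif da == 1:
               out.append(db)
           else:
               raise ValueError(f"incompatible dims {da} and {db}")
       return tuple(out)

   def where_out_shape(cond, x, y) -> Tuple[int, ...]:
       # Hypothetical helper: broadcast condition, x and y together.
       return broadcast_pair(broadcast_pair(cond, x), y)

   print(where_out_shape((), (), ()))          # ()
   print(where_out_shape((3, 1), (1, 4), ()))  # (3, 4)
   ```

   The first call corresponds to the scalar `where` in the subgraph above. The second mirrors the `(3, 1)`, `(1, 4)` operand shapes already used in `test_any_where`.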

