You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by zh...@apache.org on 2020/09/04 21:51:15 UTC

[incubator-tvm] branch master updated: [Relay/topi] Support scalar inputs in where op (#6383)

This is an automated email from the ASF dual-hosted git repository.

zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 8508ec3  [Relay/topi] Support scalar inputs in where op (#6383)
8508ec3 is described below

commit 8508ec34f1a6f52736747708a4227bedeb4899ff
Author: masahi <ma...@gmail.com>
AuthorDate: Sat Sep 5 06:51:01 2020 +0900

    [Relay/topi] Support scalar inputs in where op (#6383)
    
    * support where with scalars
    
    * add test for where with scalar
    
    * add comment
---
 include/tvm/topi/transform.h         | 21 +++++++++++++--------
 src/relay/op/tensor/transform.cc     | 13 ++++++++++++-
 tests/python/relay/test_op_level4.py | 31 ++++++++++++++++++++++++-------
 3 files changed, 49 insertions(+), 16 deletions(-)

diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index b09b035..af59928 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -891,16 +891,22 @@ inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y,
       << " vs " << y->shape.size();
   CHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
                                << y->dtype;
-  Array<PrimExpr> oshape = x->shape;
-  Tensor out;
 
-  if (condition->shape.size() != 1) {
+  if (x->shape.size() == 0) {
+    return compute(
+        condition->shape,
+        [&](const Array<Var>& indices) {
+          Array<PrimExpr> condition_idx{indices[0]};
+          return tvm::tir::Select(condition(condition_idx) != 0, x(), y());
+        },
+        name, tag);
+  } else if (condition->shape.size() != 1) {
     CHECK_EQ(condition->shape.size(), x->shape.size())
         << "condition array must be either have the same shape as x or to be a "
            "1-D array.Got different number of dimension: "
         << condition->shape.size() << " vs " << x->shape.size();
-    out = compute(
-        oshape,
+    return compute(
+        x->shape,
         [&](const Array<Var>& indices) {
           return tvm::tir::Select(condition(indices) != 0, x(indices), y(indices));
         },
@@ -909,15 +915,14 @@ inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y,
     CHECK_EQ(topi::GetConstInt(condition->shape[0]), topi::GetConstInt(x->shape[0]))
         << "If condition is 1-D, the first dimension must be the same as x: " << condition->shape[0]
         << " vs " << x->shape[0];
-    out = compute(
-        oshape,
+    return compute(
+        x->shape,
         [&](const Array<Var>& indices) {
           Array<PrimExpr> condition_idx{indices[0]};
           return tvm::tir::Select(condition(condition_idx) != 0, x(indices), y(indices));
         },
         name, tag);
   }
-  return out;
 }
 
 /*!
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 5126b1d..40051e4 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1662,7 +1662,12 @@ bool WhereRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
           << "condition and x must have the same shape: " << cond_shape << " vs " << x_shape;
     }
   }
-  reporter->Assign(types[3], TensorType(x_shape, x->dtype));
+  if (x_shape.size() == 0) {
+    // if x and y are scalar, the condition shape becomes the output shape
+    reporter->Assign(types[3], TensorType(cond_shape, x->dtype));
+  } else {
+    reporter->Assign(types[3], TensorType(x_shape, x->dtype));
+  }
   return true;
 }
 
@@ -1694,6 +1699,9 @@ size is the same as x’s first dimension size. Each row of the output array
 is from x’s row if the corresponding element from condition is true, and
 from y’s row if false.
 
+When x and y are scalars, condition must be a 1-D array. The output shape
+is the same as condition's shape.
+
 Note that all non-zero values are interpreted as True in condition.
 
 Examples::
@@ -1707,6 +1715,9 @@ Examples::
   cond = [1, 0]
   where(cond, x, y) = [[1, 2], [7, 8]]
 
+  cond = [0, 1]
+  where(cond, 1, -1) = [-1, 1]
+
 )code" TVM_ADD_FILELINE)
     .add_argument("condition", "Tensor", "Condition array")
     .add_argument("x", "Tensor", "First array to be selected")
diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py
index af38264..8c62f8c 100644
--- a/tests/python/relay/test_op_level4.py
+++ b/tests/python/relay/test_op_level4.py
@@ -144,6 +144,13 @@ def test_binary_int_broadcast_2():
 
 @tvm.testing.uses_gpu
 def test_where():
+    def run(func, inputs, ref_res):
+        for target, ctx in tvm.testing.enabled_targets():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(func)(*inputs)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+
     shape = (3, 4)
     dtype = "float32"
     cond = relay.var("cond", relay.TensorType(shape, dtype))
@@ -158,11 +165,21 @@ def test_where():
     x = np.random.uniform(size=shape).astype(dtype)
     y = np.random.uniform(size=shape).astype(dtype)
     ref_res = np.where(condition, x, y)
-    for target, ctx in tvm.testing.enabled_targets():
-        for kind in ["graph", "debug"]:
-            intrp = relay.create_executor(kind, ctx=ctx, target=target)
-            op_res = intrp.evaluate(func)(condition, x, y)
-            tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+
+    run(func, [condition, x, y], ref_res)
+
+    x = relay.const(1)
+    y = relay.const(-1)
+    shape = (3,)
+    dtype = "float32"
+    cond = relay.var("cond", relay.TensorType(shape, "bool"))
+    z = relay.where(cond, x, y)
+
+    func = relay.Function([cond], z)
+    condition = np.array([1, 0, 1], dtype=np.bool)
+    ref_res = np.where(condition, 1, -1)
+
+    run(func, [condition], ref_res)
 
 
 def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32"):
@@ -232,12 +249,12 @@ def test_reduce_functions():
         if not keepdims:
             x = np.squeeze(x, axis=axis)
         return x
-    
+
     def _unbiased_relay_wrapper(f):
         def _unbiased_func(x, axis=None, keepdims=False, exclude=False):
             return f(x, axis=axis, keepdims=keepdims, exclude=exclude, unbiased=True)
         return _unbiased_func
-    
+
     def _unbiased_np_wrapper(f):
         def _unbiased_func(a, axis=None, dtype=None, keepdims=None):
             return f(a, axis=axis, dtype=dtype, ddof=1, keepdims=keepdims)