You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/05/31 23:33:45 UTC

[GitHub] [tvm] masahi opened a new pull request #8165: [Relay, TOPI] Refactor strided_slice and add axes argument

masahi opened a new pull request #8165:
URL: https://github.com/apache/tvm/pull/8165


   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] mbrookhart commented on a change in pull request #8165: [Relay, TOPI] Refactor strided_slice and add axes argument

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #8165:
URL: https://github.com/apache/tvm/pull/8165#discussion_r643532442



##########
File path: include/tvm/topi/transform.h
##########
@@ -594,137 +608,152 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b
 }
 
 /*!
- * \brief strided_slice of a tensor
+ * \brief strided_slice of a tensor with dynamic begin/end/stride
  *
  * \param x The input tensor
  * \param begin The indices to begin with in the slicing
  * \param end Indicies indicating end of the slice
  * \param strides Specifies the stride values, it can be negative
  * in that case, the input tensor will be reversed in that particular axis
- * \param slice_mode Specifies the slice mode
  * \param name The name of the operation
  * \param tag The tag to mark the operation
  *
  * \return A Tensor whose op member is the split operation
  */
-inline Tensor strided_slice(const Tensor& x, const Array<PrimExpr>& begin,
-                            const Array<PrimExpr>& end, const Array<PrimExpr>& strides,
-                            std::string slice_mode = "end", std::string name = "T_strided_slice",
-                            std::string tag = kInjective) {
-  size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
-  // Quick path for dynamic shape strided slice.
-  // This is for ease of use to dynamice strided slice in topi.
-  bool is_static = IsConstIntArray(x->shape);
-  is_static &= IsConstIntArray(begin);
-  is_static &= IsConstIntArray(end);
-  is_static &= IsConstIntArray(strides);
-
-  Array<PrimExpr> out_shape;
-  if (!is_static) {
-    ICHECK_EQ(strides.size(), src_tensor_dim);
-    for (size_t i = 0; i < src_tensor_dim; ++i) {
-      out_shape.push_back(indexdiv(end[i] - begin[i], strides[i]));
-    }
-    return te::compute(
-        out_shape,
-        [&](const Array<tvm::tir::Var>& indices) {
-          Array<PrimExpr> real_indices;
-          for (size_t i = 0; i < src_tensor_dim; ++i) {
-            real_indices.push_back(indices[i] * strides[i] + begin[i]);
-          }
-          return x(real_indices);
-        },
-        name, tag);
-  }
-
-  // Setup the ranges.
-  // NOTE: this code duplicates the shape inference logic relay.op
-  // Consider to refactor in the future.
-  std::vector<int64_t> stride_vec(src_tensor_dim, 1);
-  for (size_t i = 0; i < strides.size(); ++i) {
-    ICHECK(strides[i].defined());
-    stride_vec[i] = GetConstInt(strides[i]);
-  }
-
-  const int64_t max_range = std::numeric_limits<int64_t>::max();
+inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& begin,
+                                        const te::Tensor& end, const te::Tensor& strides,
+                                        std::string name = "T_strided_slice_dynamic",
+                                        std::string tag = topi::kInjective) {
+  const int64_t num_dynamic_axes = begin->shape[0].as<IntImmNode>()->value;
+  ICHECK_EQ(end->shape[0].as<IntImmNode>()->value, num_dynamic_axes);
+  ICHECK_EQ(strides->shape[0].as<IntImmNode>()->value, num_dynamic_axes);
 
-  std::vector<int64_t> begin_vec;
-  for (size_t i = 0; i < begin.size(); ++i) {
-    if (!begin[i].defined()) {
-      // value=None
-      begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
-    } else {
-      begin_vec.push_back(GetConstInt(begin[i]));
-    }
-  }
-  for (size_t i = begin_vec.size(); i < src_tensor_dim; ++i) {
-    begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
+  Array<PrimExpr> begin_expr, end_expr, strides_expr;
+  for (int64_t i = 0; i < num_dynamic_axes; ++i) {
+    auto i64_ind = IntImm(DataType::Int(64), i);
+    begin_expr.push_back(begin(i64_ind));
+    end_expr.push_back(end(i64_ind));
+    strides_expr.push_back(strides(i64_ind));
   }
+  return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, name, tag);
+}
 
-  std::vector<int64_t> end_vec;
-  for (size_t i = 0; i < end.size(); ++i) {
-    // allow end to be None
+/*!
+ * \brief Calcluate the output shape of strided_slice, the entry point for Relay type relation
+ *
+ * \param ishape The input tensor shape
+ * \param begin The indices to begin with in the slicing
+ * \param end Indicies indicating end of the slice
+ * \param strides Specifies the stride values, it can be negative
+ * in that case, the input tensor will be reversed in that particular axis
+ * \param axes Axes along which slicing is applied. When it is specified, the length of begin, end,
+ * strides, and axes argument must be equal
+ * \param slice_mode Specifies the slice mode
+ *
+ * \return The output shape of strided_slice using the arguments above
+ */
+inline Array<PrimExpr> StridedSliceOutputShape(
+    const Array<PrimExpr>& ishape, const Array<Integer>& begin, const Array<Integer>& end,
+    const Array<Integer>& strides, const Array<Integer>& axes, const std::string& slice_mode) {
+  ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size());
+  std::vector<int64_t> begin_vec, end_vec, strides_vec;
+  std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode);
+  auto begin_canonicalized = StridedSliceCanonicalizeBegin(ishape, begin_vec, strides_vec, axes,
+                                                           begin[0]->dtype, slice_mode);
+  return StridedSliceOutputShape(ishape, begin_vec, end_vec, strides_vec, axes, slice_mode,
+                                 begin_canonicalized, true);
+}
 
-    if (!end[i].defined()) {
-      end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
-    } else if (slice_mode == "size") {
-      int64_t end_val = GetConstInt(end[i]);
-      if (end_val < 0) {
-        end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
-      } else {
-        end_vec.push_back(begin_vec[i] + end_val);
-      }
-    } else {
-      end_vec.push_back(GetConstInt(end[i]));
-    }
-  }
-  for (size_t i = end_vec.size(); i < src_tensor_dim; ++i) {
-    end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
-  }
-  // Compute
-  Array<PrimExpr> begin_expr;
-  Array<PrimExpr> strides_expr;
-
-  for (size_t i = 0; i < src_tensor_dim; ++i) {
-    int64_t begin_range = stride_vec[i] < 0 ? -1 : 0;
-    int64_t dim_i = GetConstInt(x->shape[i]);
-    int64_t end_range = stride_vec[i] < 0 ? dim_i - 1 : dim_i;
-    // transform negative indices to positive value, clips on the correct range
-    auto index_canonicalization = [dim_i, begin_range, end_range](int64_t index) {
-      if (index < 0) {
-        index += dim_i;
-      }
-      return std::min(std::max(index, begin_range), end_range);
-    };
-
-    int64_t begin_i = index_canonicalization(begin_vec[i]);
-    int64_t end_i = index_canonicalization(end_vec[i]);
-
-    int interval = std::abs(end_i - begin_i);
-    int slice_size =
-        static_cast<int>((interval + std::abs(stride_vec[i]) - 1) / std::abs(stride_vec[i]));
-    ICHECK(stride_vec[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i))
-        << ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i]
-        << "] is invalid for axis=" << i;
-
-    begin_expr.push_back(make_const(begin[0].dtype(), begin_i));
-    strides_expr.push_back(
-        make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()), stride_vec[i]));
-    out_shape.push_back(slice_size);
-  }
+/*!
+ * \brief strided_slice of a tensor
+ *
+ * \param x The input tensor
+ * \param begin The indices to begin with in the slicing
+ * \param end Indicies indicating end of the slice
+ * \param strides Specifies the stride values, it can be negative
+ * in that case, the input tensor will be reversed in that particular axis
+ * \param axes Axes along which slicing is applied. When it is specified, the length of begin, end,
+ * strides, and axes argument must be equal
+ * \param slice_mode Specifies the slice mode
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the split operation

Review comment:
       split->slice? (The `\return` line says "split operation", but this function performs a slice.)

##########
File path: src/relay/op/tensor/transform.cc
##########
@@ -2445,99 +2445,40 @@ bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attr
     return false;
   }
 
-  auto dshape = data->shape;
-  int64_t num_axis = dshape.size();
-
-  // calculate output shape
-  std::vector<IndexExpr> oshape(num_axis);
-  if (param->begin && param->end && param->strides) {
-    // stride will be set as 1 if slice mode is enabled
-    std::vector<int64_t> stride_vec(num_axis, 1);
-    if (param->slice_mode == "end") {
-      for (size_t i = 0; i < param->strides.value().size(); ++i) {
-        ICHECK(param->strides.value()[i].defined());
-        stride_vec[i] = param->strides.value()[i]->value;
-      }
-    }
-    const int64_t max_range = std::numeric_limits<int64_t>::max();
-    std::vector<int64_t> begin_vec;
-    for (size_t i = 0; i < param->begin.value().size(); ++i) {
-      if (!param->begin.value()[i].defined()) {
-        begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
-      } else {
-        begin_vec.push_back(param->begin.value()[i]->value);
-      }
-    }
-    for (int64_t i = begin_vec.size(); i < num_axis; ++i) {
-      begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
-    }
+  ICHECK(param->begin) << "strided_slice recieved invalid begin " << param->begin;
+  ICHECK(param->end) << "strided_slice recieved invalid end " << param->end;
+  ICHECK(param->strides) << "strided_slice recieved invalid strides " << param->strides;
+
+  auto begin = param->begin.value();
+  auto end = param->end.value();
+  auto strides = param->strides.value();
+
+  const size_t src_tensor_dim = static_cast<size_t>(data->shape.size());
+  Array<Integer> axes;
+  if (param->axes) {
+    axes = param->axes.value();
+    ICHECK(axes.size() == begin.size() && axes.size() == end.size() &&
+           axes.size() == strides.size())
+        << "axes, begin, end, and strides must have the same length";
+  } else {
+    for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i);
 
-    std::vector<int64_t> end_vec;
-    for (size_t i = 0; i < param->end.value().size(); ++i) {
-      // allow end to be None
-      if (!param->end.value()[i].defined()) {
-        end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
-      } else if (param->slice_mode == "size") {
-        if (param->end.value()[i]->value < 0) {
-          end_vec.push_back(max_range);
-        } else {
-          end_vec.push_back(begin_vec[i] + param->end.value()[i]->value);
-        }
-      } else if (param->slice_mode == "end") {
-        end_vec.push_back(param->end.value()[i]->value);
-      } else {
-        LOG(FATAL) << "Unsupported slice mode: " << param->slice_mode;
-      }
+    const IntImm one = IntImm(DataType::Int(64), 1);
+    const IntImm zero = IntImm(DataType::Int(64), 0);
+    const IntImm max_range = IntImm(DataType::Int(64), std::numeric_limits<int64_t>::max());
+
+    for (size_t i = strides.size(); i < src_tensor_dim; ++i) {
+      strides.push_back(one);
     }
-    for (int64_t i = end_vec.size(); i < num_axis; ++i) {
-      end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
+    for (size_t i = begin.size(); i < src_tensor_dim; ++i) {
+      begin.push_back(topi::GetConstInt(strides[i]) > 0 ? zero : max_range);
     }
-
-    for (int64_t i = 0; i < num_axis; ++i) {
-      int64_t stride_v = stride_vec[i];
-      int64_t begin_v = begin_vec[i];
-      int64_t end_v = end_vec[i];
-
-      if ((stride_v == 1 && begin_v == 0 && end_v == max_range) ||
-          (stride_v == -1 && begin_v == max_range && end_v == 0)) {
-        // Quick path, do not slice this dimension.
-        oshape[i] = dshape[i];
-        continue;
-      }
-      // Normal path, require the shape to be concrete integer.
-      // Require concrete integer as symbolic inference of min/max
-      // can get complicated and not very helpful.
-      const int64_t* p_dim_size = tir::as_const_int(dshape[i]);
-      if (!p_dim_size) {
-        oshape[i] = dshape[i];
-        continue;
-      }
-      int64_t dim_size = p_dim_size[0];
-      begin_v = (begin_v < 0) ? dim_size + begin_v : begin_v;
-      end_v = (end_v < 0) ? dim_size + end_v : end_v;
-
-      int64_t slice_range, step;
-      if (stride_v < 0) {
-        if (end_v < -1) end_v = -1;
-        ICHECK_LE(end_v, begin_v) << "strided_slice get empty slice at axis " << i;
-        begin_v = std::min(dim_size - 1, begin_v);
-        slice_range = begin_v - end_v;
-        step = -stride_v;
-      } else {
-        if (begin_v < 0) begin_v = 0;
-        ICHECK_GE(stride_v, 0);
-        ICHECK_LE(begin_v, end_v) << "strided_slice get invalid slice at axis " << i;
-        end_v = std::min(dim_size, end_v);
-        slice_range = end_v - begin_v;
-        step = stride_v;
-      }
-      oshape[i] = tir::make_const(dshape[i].dtype(), (slice_range + step - 1) / step);
+    for (size_t i = end.size(); i < src_tensor_dim; ++i) {
+      end.push_back(topi::GetConstInt(strides[i]) < 0 ? zero : max_range);
     }
-  } else {
-    ICHECK(param->begin) << "strided_slice recieved invalid begin " << param->begin;
-    ICHECK(param->end) << "strided_slice recieved invalid end " << param->end;
-    ICHECK(param->strides) << "strided_slice recieved invalid strides " << param->strides;
   }
+  auto oshape =
+      topi::StridedSliceOutputShape(data->shape, begin, end, strides, axes, param->slice_mode);

Review comment:
       Thank you for moving this to a common utility :bowing_man:

##########
File path: python/tvm/relay/op/transform.py
##########
@@ -917,7 +922,7 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"):
         begin = _make.where(begin < cast_like(const(0), begin), begin + ishape_slice, begin)
         begin = _make.where(begin >= ishape_slice, ishape_slice, begin)
         return _dyn_make.strided_slice(data, begin, end, strides, slice_mode)

Review comment:
       Do we not support axes with dynamic begin?

##########
File path: src/topi/transform.cc
##########
@@ -174,11 +174,26 @@ TVM_REGISTER_GLOBAL("topi.einsum").set_body([](TVMArgs args, TVMRetValue* rv) {
 });
 
 TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue* rv) {

Review comment:
       Why not allow axes arguments here?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on a change in pull request #8165: [Relay, TOPI] Refactor strided_slice and add axes argument

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #8165:
URL: https://github.com/apache/tvm/pull/8165#discussion_r643593227



##########
File path: python/tvm/relay/op/transform.py
##########
@@ -917,7 +922,7 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"):
         begin = _make.where(begin < cast_like(const(0), begin), begin + ishape_slice, begin)
         begin = _make.where(begin >= ishape_slice, ishape_slice, begin)
         return _dyn_make.strided_slice(data, begin, end, strides, slice_mode)

Review comment:
       Yeah I haven't put much thought into `axes` argument for dynamic strided slice. My goal was to preserve more static dimensions along `axes`, which doesn't apply to the dynamic parameter variant. We should support it for convenience and API consistency sake, but for now I'd like to leave it as TODO and assert `axes is None` here until someone complains or comes up with a good use case.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] mbrookhart commented on a change in pull request #8165: [Relay, TOPI] Refactor strided_slice and add axes argument

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #8165:
URL: https://github.com/apache/tvm/pull/8165#discussion_r643536633



##########
File path: python/tvm/relay/op/transform.py
##########
@@ -917,7 +922,7 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"):
         begin = _make.where(begin < cast_like(const(0), begin), begin + ishape_slice, begin)
         begin = _make.where(begin >= ishape_slice, ishape_slice, begin)
         return _dyn_make.strided_slice(data, begin, end, strides, slice_mode)

Review comment:
       Perhaps we should throw an error if we hit this codepath and axes is defined? With a TODO for extending axes to the dynamic case?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi merged pull request #8165: [Relay, TOPI] Refactor strided_slice and add axes argument

Posted by GitBox <gi...@apache.org>.
masahi merged pull request #8165:
URL: https://github.com/apache/tvm/pull/8165


   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on a change in pull request #8165: [Relay, TOPI] Refactor strided_slice and add axes argument

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #8165:
URL: https://github.com/apache/tvm/pull/8165#discussion_r643770095



##########
File path: src/topi/transform.cc
##########
@@ -174,11 +174,26 @@ TVM_REGISTER_GLOBAL("topi.einsum").set_body([](TVMArgs args, TVMRetValue* rv) {
 });
 
 TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue* rv) {

Review comment:
       fixed




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi merged pull request #8165: [Relay, TOPI] Refactor strided_slice and add axes argument

Posted by GitBox <gi...@apache.org>.
masahi merged pull request #8165:
URL: https://github.com/apache/tvm/pull/8165


   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on a change in pull request #8165: [Relay, TOPI] Refactor strided_slice and add axes argument

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #8165:
URL: https://github.com/apache/tvm/pull/8165#discussion_r643783471



##########
File path: python/tvm/relay/op/transform.py
##########
@@ -917,7 +922,7 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"):
         begin = _make.where(begin < cast_like(const(0), begin), begin + ishape_slice, begin)
         begin = _make.where(begin >= ishape_slice, ishape_slice, begin)
         return _dyn_make.strided_slice(data, begin, end, strides, slice_mode)

Review comment:
       Added TODO and assert.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on pull request #8165: [Relay, TOPI] Refactor strided_slice and add axes argument

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #8165:
URL: https://github.com/apache/tvm/pull/8165#issuecomment-853398328


   thanks @mbrookhart 


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on pull request #8165: [Relay, TOPI] Refactor strided_slice and add axes argument

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #8165:
URL: https://github.com/apache/tvm/pull/8165#issuecomment-853398328


   thanks @mbrookhart 


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org