Posted to commits@tvm.apache.org by "MasterJH5574 (via GitHub)" <gi...@apache.org> on 2023/03/19 21:43:17 UTC

[GitHub] [tvm] MasterJH5574 commented on a diff in pull request #14324: [Unity][Op] Fix Strided Slice Shape Inference

MasterJH5574 commented on code in PR #14324:
URL: https://github.com/apache/tvm/pull/14324#discussion_r1141475786


##########
src/relax/op/tensor/index.cc:
##########
@@ -142,6 +142,26 @@ Expr strided_slice(Expr x,                 //
 
 TVM_REGISTER_GLOBAL("relax.op.strided_slice").set_body_typed(strided_slice);
 
+inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, int64_t stride) {
+  // Same as topi strided slice CanonicalizeIndex function in
+  // include/tvm/topi/detail/strided_slice.h
+  PrimExpr begin_range = stride < 0 ? -1 : 0;
+  PrimExpr end_range = stride < 0 ? extent - 1 : extent;
+  index = if_then_else(index < 0, index + extent, index);
+  return min(max(index, begin_range), end_range);  // NOLINT
+}
+
+PrimExpr GetLength(PrimExpr begin, PrimExpr end, const int64_t stride, const PrimExpr& ndim) {
+  begin = CanonicalizeIndex(begin, ndim, stride);
+  end = CanonicalizeIndex(end, ndim, stride);

Review Comment:
   The last argument here is the length of the dimension being sliced, not the tensor's `ndim`, so the parameter should be named accordingly.
   ```suggestion
   PrimExpr GetLength(PrimExpr begin, PrimExpr end, const int64_t stride, const PrimExpr& length) {
     begin = CanonicalizeIndex(begin, length, stride);
     end = CanonicalizeIndex(end, length, stride);
   ```
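
To illustrate why the distinction matters, here is a minimal sketch that mirrors `CanonicalizeIndex` from the diff using plain `int64_t` values instead of `PrimExpr`. The names `CanonicalizeIndexInt` and `SliceLengthInt` are hypothetical, and the ceil-division length formula is an assumption (the usual strided-slice convention); the rest of `GetLength` is not shown in this hunk.

```cpp
#include <algorithm>
#include <cstdint>

// Plain-integer stand-in for CanonicalizeIndex from the diff
// (hypothetical name; PrimExpr replaced by int64_t for illustration).
inline int64_t CanonicalizeIndexInt(int64_t index, int64_t extent, int64_t stride) {
  int64_t begin_range = stride < 0 ? -1 : 0;
  int64_t end_range = stride < 0 ? extent - 1 : extent;
  // Negative indices count from the end of the dimension.
  if (index < 0) index += extent;
  // Clamp into the valid range for the given stride direction.
  return std::min(std::max(index, begin_range), end_range);
}

// Hypothetical helper: number of elements selected along one axis.
// ASSUMPTION: the length is ceil(|end - begin| / |stride|), clamped at zero;
// this part of GetLength does not appear in the quoted diff.
inline int64_t SliceLengthInt(int64_t begin, int64_t end, int64_t stride, int64_t length) {
  begin = CanonicalizeIndexInt(begin, length, stride);
  end = CanonicalizeIndexInt(end, length, stride);
  int64_t span = stride < 0 ? begin - end : end - begin;
  int64_t step = stride < 0 ? -stride : stride;
  if (span <= 0) return 0;
  return (span + step - 1) / step;  // ceil division for positive operands
}
```

For a tensor of shape `(2, 10)` sliced along axis 1 with `begin = 0`, `end = 100`, `stride = 1`, passing the dimension length `10` clamps `end` to `10`, whereas passing `ndim = 2` would clamp it to `2` and understate the inferred extent.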



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.