Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/10/26 07:22:51 UTC

[GitHub] [incubator-tvm] t-vi commented on a change in pull request #6759: [Relay, TOPI] Complete rewrite of where op to support broadcasting

t-vi commented on a change in pull request #6759:
URL: https://github.com/apache/incubator-tvm/pull/6759#discussion_r511757403



##########
File path: include/tvm/topi/broadcast.h
##########
@@ -69,6 +69,46 @@ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t,
   return tvm::te::compute(oshape, l, name, tag);
 }
 
+inline tvm::te::Tensor broadcast_shape_tensors(const tvm::te::Tensor& shape_tensor1,
+                                               const tvm::te::Tensor& shape_tensor2,
+                                               std::string name = "T_broadcast_shape_tensors",
+                                               std::string tag = kBroadcast) {
+  const auto rank1 = detail::GetConstInt(shape_tensor1->shape[0]);
+  const auto rank2 = detail::GetConstInt(shape_tensor2->shape[0]);
+  const auto out_rank = std::max<int32_t>(rank1, rank2);
+  const tvm::PrimExpr one = tvm::cast(shape_tensor1->dtype, PrimExpr(1));
+
+  auto select_dim = [&](const tvm::te::Tensor& shape_tensor, int rank,
+                        tvm::tir::Var index) -> PrimExpr {
+    if (rank < out_rank) {
+      // if the rank is smaller, dimension 1 is prepended according to
+      // the numpy broadcasting semantics.
+      return tvm::tir::Select(rank - (out_rank - index) < 0, one,
+                              shape_tensor[rank - (out_rank - index)]);
+    } else {
+      // rank == out_rank, safe to index directly
+      return shape_tensor[index];
+    }
+  };
+
+  auto func = [&](tvm::Array<tvm::tir::Var> ovars) {
+    auto index = ovars[0];
+    PrimExpr dim1 = select_dim(shape_tensor1, rank1, index);
+    PrimExpr dim2 = select_dim(shape_tensor2, rank2, index);
+    if (topi::detail::EqualCheck(one, dim1)) {
+      return dim2;
+    } else if (topi::detail::EqualCheck(one, dim2)) {
+      return dim1;
+    }
+    return tvm::max(dim1, dim2);

Review comment:
   Two comments (not sure whether they need to be addressed in this PR):
   - Does this (with EqualCheck and a C++ if) work as expected on dynamic shapes? Should it?
   - While this seems to work for valid broadcasting (apart from the potential dynamic-shape caveat), it fails to reject invalid broadcasting, i.e. when neither dim is 1 but the two differ. Of course, this might be intentional to cover dynamic shapes (if they aren't covered by the two cases above), but it might lead to confusing error messages etc., so I think it would be neat to have a comment explaining what each code path handles.
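To make the two review points concrete, here is a minimal standalone C++ sketch of per-dimension NumPy-style broadcasting with one commented branch per case. It is not TVM code and not the PR's implementation. It assumes a hypothetical Dim type, std::optional<int64_t>, where an empty optional stands for a dimension known only at runtime (roughly a non-constant PrimExpr in TVM, for which a compile-time equality check against 1 cannot succeed). The names Dim, BroadcastDim and BroadcastShapes are illustrative only.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for a shape dimension: a known constant, or
// std::nullopt when the extent is only known at runtime (dynamic).
using Dim = std::optional<int64_t>;

// Broadcast a single pair of dimensions under NumPy rules, with one
// branch per case so each code path is explicit.
Dim BroadcastDim(const Dim& d1, const Dim& d2) {
  // Case 1: one side is statically 1, so the result is the other side,
  // whether that side is static (including 0) or dynamic.
  if (d1 && *d1 == 1) return d2;
  if (d2 && *d2 == 1) return d1;
  // Case 2: both sides are static and neither is 1; they must be equal,
  // otherwise the broadcast is invalid and is rejected here.
  if (d1 && d2) {
    if (*d1 != *d2) {
      throw std::invalid_argument("cannot broadcast dimensions " +
                                  std::to_string(*d1) + " and " +
                                  std::to_string(*d2));
    }
    return d1;
  }
  // Case 3: at least one side is dynamic and neither is statically 1.
  // The result cannot be decided at compile time; a real implementation
  // would need a runtime check here.
  return std::nullopt;
}

// Broadcast two shapes: right-align them and treat missing leading
// dimensions as 1 (the prepending rule described in the diff above).
std::vector<Dim> BroadcastShapes(const std::vector<Dim>& s1,
                                 const std::vector<Dim>& s2) {
  const std::size_t out_rank = std::max(s1.size(), s2.size());
  const std::size_t off1 = out_rank - s1.size();
  const std::size_t off2 = out_rank - s2.size();
  std::vector<Dim> out(out_rank);
  for (std::size_t i = 0; i < out_rank; ++i) {
    const Dim d1 = i < off1 ? Dim(1) : s1[i - off1];
    const Dim d2 = i < off2 ? Dim(1) : s2[i - off2];
    out[i] = BroadcastDim(d1, d2);
  }
  return out;
}
```

With this split, BroadcastShapes({Dim(2)}, {Dim(3)}) throws instead of quietly yielding 3, which is what a max() fallback produces for two static, mismatched dims. Keeping case 1 separate also matters for size-0 dims: under NumPy semantics a size-1 dim broadcast against a size-0 dim gives 0, whereas max() would give 1. Case 3 is where the open question about dynamic shapes lives: whether to emit a runtime assertion there or accept max() as a best effort.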




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org