You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/01 21:07:12 UTC

[GitHub] [incubator-tvm] masahi commented on a change in pull request #5961: [Relay/TOPI][OP] Add meshgrid op in Relay, TOPI, Pytorch frontend

masahi commented on a change in pull request #5961:
URL: https://github.com/apache/incubator-tvm/pull/5961#discussion_r448616432



##########
File path: src/relay/op/tensor/transform.cc
##########
@@ -1269,6 +1269,93 @@ RELAY_REGISTER_OP("repeat")
     .set_attr<FTVMCompute>("FTVMCompute", RepeatCompute)
     .set_attr<TOpPattern>("TOpPattern", kBroadcast);
 
+// meshgrid operator
+// Register MeshgridAttrs (the attrs object carrying the `indexing` mode that
+// MeshgridRel and MeshgridCompute read) as a TVM node type.
+TVM_REGISTER_NODE_TYPE(MeshgridAttrs);
+
+/*!
+ * \brief Type relation for meshgrid.
+ *
+ * \param types [data, result]. `data` must be a tuple of scalar or 1-D tensors
+ *        that all share one dtype; `result` is assigned a tuple holding one
+ *        grid type per input, all with the same shape.
+ * \param num_inputs Number of op inputs (unused; the op takes one tuple).
+ * \param raw_attrs Must be MeshgridAttrs; `indexing == "xy"` swaps the first
+ *        two grid dimensions.
+ * \param reporter Receives the inferred result type.
+ * \return false while the input type (or any tuple field) is still
+ *         incomplete, true once the result type has been assigned.
+ */
+bool MeshgridRel(const Array<Type>& types, int num_inputs, const Attrs& raw_attrs,
+                 const TypeReporter& reporter) {
+  // types: [data, result]
+  CHECK_EQ(types.size(), 2);
+  const MeshgridAttrs* attrs = raw_attrs.as<MeshgridAttrs>();
+  CHECK(attrs != nullptr);
+
+  // An incomplete input must defer inference rather than raise. This check has
+  // to come before the tuple check, which would otherwise reject it.
+  if (types[0].as<IncompleteTypeNode>() != nullptr) {
+    return false;
+  }
+  const auto* tensor_tuple = types[0].as<TupleTypeNode>();
+  if (tensor_tuple == nullptr) {
+    throw Error(
+        ErrorBuilder() << "meshgrid requires a tuple of tensors as the first argument, found "
+                       << PrettyPrint(types[0]));
+  }
+  const int data_length = static_cast<int>(tensor_tuple->fields.size());
+  // fields[0] is read below to pick the dtype, so reject an empty tuple explicitly.
+  if (data_length == 0) {
+    throw Error("relay.meshgrid requires at least one input tensor");
+  }
+
+  // Defer until every field is known. This must precede all TensorType
+  // downcasts, including the one that reads the first field's dtype.
+  for (const Type& ele : tensor_tuple->fields) {
+    if (ele.as<IncompleteTypeNode>() != nullptr) {
+      return false;
+    }
+  }
+
+  // Get first dtype.
+  const DataType dtype = Downcast<TensorType>(tensor_tuple->fields[0])->dtype;
+
+  // Get size of output grid: one dimension per input (scalars contribute 1).
+  std::vector<IndexExpr> grid_shape;
+  grid_shape.reserve(data_length);
+  for (const Type& ele : tensor_tuple->fields) {
+    const auto& e = Downcast<TensorType>(ele);
+    const int e_ndim = static_cast<int>(e->shape.size());
+    if (e->dtype != dtype) {
+      throw Error("relay.meshgrid requires all tensors have the same dtype");
+    }
+    if (e_ndim == 0) {
+      grid_shape.emplace_back(1);
+    } else if (e_ndim == 1) {
+      grid_shape.emplace_back(e->shape[0]);
+    } else {
+      throw Error("relay.meshgrid requires all tensors be either scalars or 1-D vectors.");
+    }
+  }
+
+  // "xy" mode swaps first two dimensions
+  if (attrs->indexing == "xy" && grid_shape.size() >= 2) {
+    std::swap(grid_shape[0], grid_shape[1]);
+  }
+
+  // There is one output grid for each input, all with same shape.
+  std::vector<Type> grids;
+  grids.reserve(data_length);
+  for (int i = 0; i < data_length; i++) {
+    grids.emplace_back(TensorType(grid_shape, dtype));
+  }
+  reporter->Assign(types[1], TupleType(Array<Type>(grids)));
+  return true;
+}
+
+/*!
+ * \brief Compute function for meshgrid.
+ *
+ * Forwards the input tensors and the `indexing` attribute to topi::meshgrid
+ * and returns the tensors it produces.
+ */
+Array<te::Tensor> MeshgridCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                  const Type& out_type) {
+  const auto* meshgrid_attrs = attrs.as<MeshgridAttrs>();
+  CHECK(meshgrid_attrs != nullptr);
+  return topi::meshgrid(inputs, meshgrid_attrs->indexing);
+}
+
+/*!
+ * \brief Build a call to the `meshgrid` op.
+ * \param data Expression for the input tuple of tensors.
+ * \param indexing Indexing mode, stored in the call's MeshgridAttrs.
+ * \return A Call node invoking `meshgrid` on `data`.
+ */
+Expr MakeMeshgrid(Expr data, String indexing) {
+  static const Op& meshgrid_op = Op::Get("meshgrid");
+  auto meshgrid_attrs = make_object<MeshgridAttrs>();
+  meshgrid_attrs->indexing = std::move(indexing);
+  const Attrs call_attrs(meshgrid_attrs);
+  return Call(meshgrid_op, {data}, call_attrs, {});
+}
+
+// Register MakeMeshgrid as the global function "relay.op._make.meshgrid".
+TVM_REGISTER_GLOBAL("relay.op._make.meshgrid").set_body_typed(MakeMeshgrid);
+
+// Register the meshgrid op: one (tuple) argument, MeshgridAttrs, the
+// MeshgridRel type relation and the MeshgridCompute compute function.
+RELAY_REGISTER_OP("meshgrid")
+    .describe(R"code(Create coordinate matrices from coordinate vectors.
+
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<MeshgridAttrs>()
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input list of tensors.")
+    .set_support_level(3)
+    .add_type_rel("Meshgrid", MeshgridRel)
+    .set_attr<FTVMCompute>("FTVMCompute", MeshgridCompute)
+    // NOTE(review): kInjective lets the fusion pass fuse this op with almost
+    // any other op, but meshgrid returns a tuple of tensors. Confirm fusion
+    // handles the tuple output correctly; kOpaque is the conservative choice.
+    .set_attr<TOpPattern>("TOpPattern", kInjective);

Review comment:
       I'm not sure injective is the right choice. Because this op returns multiple tensors, it's not clear to me how it would interact with the op fusion pass (injective means it can be fused with basically any other op).
   
   But if this is working for your use case, I think it is ok for now.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org