Posted to commits@tvm.apache.org by "LeshengJin (via GitHub)" <gi...@apache.org> on 2023/09/06 07:07:51 UTC

[GitHub] [tvm] LeshengJin opened a new pull request, #15680: [Disco][Op] scatter_from_worker0

LeshengJin opened a new pull request, #15680:
URL: https://github.com/apache/tvm/pull/15680

   This PR introduces the op `scatter_from_worker0`, which performs a scatter operation from worker 0, splitting the given buffer into equal-sized chunks, one per worker.
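A minimal pure-Python sketch of the scatter semantics described above, for illustration only. It assumes the buffer is split along its first axis and that the first dimension must be divisible by the number of workers. The helper name `scatter_from_worker0_reference` is hypothetical and is not part of TVM's API.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def scatter_from_worker0_reference(buffer: Sequence[T], num_workers: int) -> List[List[T]]:
    """Hypothetical reference: split worker 0's buffer into equal chunks.

    Assumes the split happens along the first axis; chunk i is what
    worker i would receive.
    """
    if num_workers <= 0:
        raise ValueError("num_workers must be positive")
    if len(buffer) % num_workers != 0:
        raise ValueError(
            f"first dimension {len(buffer)} is not divisible by num_workers={num_workers}"
        )
    chunk = len(buffer) // num_workers
    # Worker i gets the contiguous slice [i * chunk, (i + 1) * chunk).
    return [list(buffer[i * chunk:(i + 1) * chunk]) for i in range(num_workers)]
```

For example, scattering a 4x2 buffer across 2 workers gives each worker a 2x2 chunk. Rows 0-1 go to worker 0 and rows 2-3 go to worker 1.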


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] junrushao merged pull request #15680: [Disco][Op] scatter_from_worker0

Posted by "junrushao (via GitHub)" <gi...@apache.org>.
junrushao merged PR #15680:
URL: https://github.com/apache/tvm/pull/15680




[GitHub] [tvm] junrushao commented on a diff in pull request #15680: [Disco][Op] scatter_from_worker0

Posted by "junrushao (via GitHub)" <gi...@apache.org>.
junrushao commented on code in PR #15680:
URL: https://github.com/apache/tvm/pull/15680#discussion_r1319180740


##########
src/relax/op/ccl/ccl.cc:
##########
@@ -70,5 +70,55 @@ TVM_REGISTER_OP("relax.ccl.broadcast_from_worker0")
     .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.ccl.scatter_from_worker0 */
+TVM_REGISTER_NODE_TYPE(ScatterFromWorker0Attrs);
+
+Expr scatter_from_worker0(Expr data, int num_workers) {
+  ObjectPtr<ScatterFromWorker0Attrs> attrs = make_object<ScatterFromWorker0Attrs>();
+  attrs->num_workers = std::move(num_workers);
+  static const Op& op = Op::Get("relax.ccl.scatter_from_worker0");
+
+  return Call(op, {std::move(data)}, Attrs{attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.ccl.scatter_from_worker0").set_body_typed(scatter_from_worker0);
+
+StructInfo InferStructInfoScatterFromWorker0(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  DataType output_dtype = input_sinfo->dtype;
+
+  const auto* attrs = call->attrs.as<ScatterFromWorker0Attrs>();
+  int num_workers = attrs->num_workers;
+
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  const auto* input_shape = input_sinfo->shape.as<ShapeExprNode>();
+  input_sinfo->shape.as<VarNode>();
+  CHECK(input_shape != nullptr)
+      << "input tensor of scatter_from_worker0 should have defined ShapeExpr as shape";

Review Comment:
   Use `TensorStructInfoNode::GetShape`
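The following is a rough Python analogue of the shape check the reviewer is pointing at. As far as we understand, `TensorStructInfoNode::GetShape` returns an optional array of dimensions, which is empty when the shape is unknown. The sketch models this as `None`. It is a sketch under stated assumptions, not the PR's implementation. It assumes the split is along axis 0 and uses concrete integers only, whereas the C++ code works on symbolic `PrimExpr` values through `arith::Analyzer`. The function name `infer_scatter_output_shape` is hypothetical.

```python
from typing import List, Optional


def infer_scatter_output_shape(input_shape: Optional[List[int]], num_workers: int) -> List[int]:
    """Hypothetical analogue of InferStructInfoScatterFromWorker0.

    `input_shape` stands in for the optional result of GetShape():
    None means the input tensor's shape is not known.
    """
    if input_shape is None:
        raise ValueError("input tensor of scatter_from_worker0 should have a defined shape")
    if len(input_shape) == 0:
        raise ValueError("input tensor of scatter_from_worker0 must have at least one dimension")
    if input_shape[0] % num_workers != 0:
        raise ValueError(
            f"first dimension {input_shape[0]} is not divisible by num_workers={num_workers}"
        )
    # Assumed semantics: each worker receives 1/num_workers of axis 0;
    # all remaining dimensions are unchanged.
    return [input_shape[0] // num_workers] + list(input_shape[1:])
```

Working from the optional shape, instead of downcasting `shape` to `ShapeExprNode`, also covers inputs whose shape is bound to a variable rather than written as a literal `ShapeExpr`.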





[GitHub] [tvm] junrushao commented on a diff in pull request #15680: [Disco][Op] scatter_from_worker0

Posted by "junrushao (via GitHub)" <gi...@apache.org>.
junrushao commented on code in PR #15680:
URL: https://github.com/apache/tvm/pull/15680#discussion_r1319179762


##########
python/tvm/relax/transform/legalize_ops/ccl.py:
##########
@@ -46,9 +48,34 @@ def _allreduce(_bb: BlockBuilder, call: Call) -> Expr:
 
 
 @register_legalize("relax.ccl.broadcast_from_worker0")
-def broadcast_from_worker0(_bb: BlockBuilder, call: Call) -> Expr:
+def _broadcast_from_worker0(_bb: BlockBuilder, call: Call) -> Expr:
     return call_pure_packed(
         "runtime.disco.broadcast_from_worker0",
         call.args[0],
         sinfo_args=call.args[0].struct_info,
     )
+
+
+@register_legalize("relax.ccl.scatter_from_worker0")
+def _scatter_from_worker0(_bb: BlockBuilder, call: Call) -> Expr:
+    output_shape = []
+    for i, shape_value in enumerate(call.args[0].struct_info.shape.values):

Review Comment:
   ```python
   assert isinstance(call.args[0].struct_info, TensorStructInfo)
   assert isinstance(call.args[0].struct_info.shape.struct_info, ShapeStructInfo)
   arg_shape = call.args[0].struct_info.shape.struct_info
   ```


