You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2023/09/08 04:30:02 UTC

[tvm] branch unity updated: [Disco][Op] scatter_from_worker0 (#15680)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 9e7227e084 [Disco][Op] scatter_from_worker0 (#15680)
9e7227e084 is described below

commit 9e7227e084b550d2e08ca058aecd0f3ac19672dd
Author: Lesheng Jin <34...@users.noreply.github.com>
AuthorDate: Thu Sep 7 21:29:56 2023 -0700

    [Disco][Op] scatter_from_worker0 (#15680)
    
    This PR introduces the op scatter_from_worker0, which performs a scatter operation from worker-0, chunking the given buffer into equal parts.
---
 include/tvm/relax/attrs/ccl.h                      | 12 +++++
 python/tvm/relax/op/ccl/ccl.py                     | 19 ++++++++
 python/tvm/relax/transform/legalize_ops/ccl.py     | 34 ++++++++++++-
 src/relax/op/ccl/ccl.cc                            | 48 ++++++++++++++++++
 src/relax/op/ccl/ccl.h                             |  3 ++
 tests/python/relax/test_op_ccl.py                  | 57 ++++++++++++++++++++++
 .../relax/test_transform_legalize_ops_ccl.py       | 21 ++++++++
 7 files changed, 193 insertions(+), 1 deletion(-)

diff --git a/include/tvm/relax/attrs/ccl.h b/include/tvm/relax/attrs/ccl.h
index 45de0e949c..b4b3880384 100644
--- a/include/tvm/relax/attrs/ccl.h
+++ b/include/tvm/relax/attrs/ccl.h
@@ -40,6 +40,18 @@ struct AllReduceAttrs : public tvm::AttrsNode<AllReduceAttrs> {
   }
 };  // struct AllReduceAttrs
 
+/*! \brief Attributes used in scatter_from_worker0 operators */
+struct ScatterFromWorker0Attrs : public tvm::AttrsNode<ScatterFromWorker0Attrs> {
+  int num_workers;
+
+  TVM_DECLARE_ATTRS(ScatterFromWorker0Attrs, "relax.attrs.ScatterFromWorker0Attrs") {
+    TVM_ATTR_FIELD(num_workers)
+        .describe(
+            "The number of workers, also the number of parts the given buffer should be chunked "
+            "into.");
+  }
+};  // struct ScatterFromWorker0Attrs
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/python/tvm/relax/op/ccl/ccl.py b/python/tvm/relax/op/ccl/ccl.py
index 093e01b638..62e5914a89 100644
--- a/python/tvm/relax/op/ccl/ccl.py
+++ b/python/tvm/relax/op/ccl/ccl.py
@@ -58,3 +58,22 @@ def broadcast_from_worker0(x: Expr) -> Expr:
       The same tensor, which has been broadcast to all other workers.
     """
     return _ffi_api.broadcast_from_worker0(x)
+
+
+def scatter_from_worker0(x: Expr, num_workers: int) -> Expr:
+    """Perform a scatter operation from worker-0, chunking the given buffer into equal parts.
+
+    Parameters
+    ----------
+    x : relax.Expr
+      The buffer to be divided into equal parts and sent to each worker accordingly.
+
+    num_workers : int
+      The number of workers, i.e. the number of parts the given buffer should be chunked into.
+
+    Returns
+    -------
+    result : relax.Expr
+      Chunked Tensor received by different workers.
+    """
+    return _ffi_api.scatter_from_worker0(x, num_workers)
diff --git a/python/tvm/relax/transform/legalize_ops/ccl.py b/python/tvm/relax/transform/legalize_ops/ccl.py
index c9c09952b2..019f1726f0 100644
--- a/python/tvm/relax/transform/legalize_ops/ccl.py
+++ b/python/tvm/relax/transform/legalize_ops/ccl.py
@@ -16,9 +16,11 @@
 # under the License.
 # pylint: disable=invalid-name
 """Default legalization function for ccl operators."""
+from tvm import tir, arith
 from ...block_builder import BlockBuilder
 from ...expr import Call, Expr, ShapeExpr
 from ...op import call_pure_packed
+from ...struct_info import TensorStructInfo, ShapeStructInfo
 from .common import register_legalize
 
 
@@ -46,9 +48,39 @@ def _allreduce(_bb: BlockBuilder, call: Call) -> Expr:
 
 
 @register_legalize("relax.ccl.broadcast_from_worker0")
-def broadcast_from_worker0(_bb: BlockBuilder, call: Call) -> Expr:
+def _broadcast_from_worker0(_bb: BlockBuilder, call: Call) -> Expr:
     return call_pure_packed(
         "runtime.disco.broadcast_from_worker0",
         call.args[0],
         sinfo_args=call.args[0].struct_info,
     )
+
+
+@register_legalize("relax.ccl.scatter_from_worker0")
+def _scatter_from_worker0(_bb: BlockBuilder, call: Call) -> Expr:
+    # Lower to the disco packed function; the result shape has axis 0 divided by num_workers.
+    tensor_sinfo = call.args[0].struct_info
+    assert isinstance(
+        tensor_sinfo, TensorStructInfo
+    ), "The input struct info of scatter_from_worker0 should be TensorStructInfo."
+    shape_sinfo = tensor_sinfo.shape.struct_info
+    assert isinstance(shape_sinfo, ShapeStructInfo)
+    scattered_shape = list(shape_sinfo.values)
+    if scattered_shape:
+        axis0 = scattered_shape[0]
+        remainder = arith.Analyzer().simplify(axis0 % call.attrs.num_workers)
+        assert remainder == 0, (
+            "scatter_from_worker0 expects the size of axis 0 of input tensor "
+            "to be divisible by num_workers. However, the axis 0 of input tensor "
+            f"is {axis0} while num_workers is {call.attrs.num_workers}"
+        )
+        scattered_shape[0] = tir.div(axis0, call.attrs.num_workers)
+    return call_pure_packed(
+        "runtime.disco.scatter_from_worker0",
+        call.args[0],
+        sinfo_args=TensorStructInfo(
+            shape=scattered_shape,
+            dtype=tensor_sinfo.dtype,
+            vdevice=tensor_sinfo.vdevice,
+        ),
+    )
diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc
index 6fa6c96db3..a114cb1dae 100644
--- a/src/relax/op/ccl/ccl.cc
+++ b/src/relax/op/ccl/ccl.cc
@@ -70,5 +70,53 @@ TVM_REGISTER_OP("relax.ccl.broadcast_from_worker0")
     .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.ccl.scatter_from_worker0 */
+TVM_REGISTER_NODE_TYPE(ScatterFromWorker0Attrs);
+
+Expr scatter_from_worker0(Expr data, int num_workers) {
+  ObjectPtr<ScatterFromWorker0Attrs> attrs = make_object<ScatterFromWorker0Attrs>();
+  attrs->num_workers = num_workers;  // plain int: std::move would only copy it
+  static const Op& op = Op::Get("relax.ccl.scatter_from_worker0");
+  // Build the operator call with `data` as the sole argument and num_workers carried in attrs.
+  return Call(op, {std::move(data)}, Attrs{attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.ccl.scatter_from_worker0").set_body_typed(scatter_from_worker0);
+
+// Result keeps the input dtype and vdevice; the size of axis 0 is divided by num_workers.
+StructInfo InferStructInfoScatterFromWorker0(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  DataType output_dtype = input_sinfo->dtype;
+  const auto* attrs = call->attrs.as<ScatterFromWorker0Attrs>();
+  int num_workers = attrs->num_workers;
+
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  auto input_shape = input_sinfo->GetShape();
+  CHECK(input_shape.defined()) << "input tensor of scatter_from_worker0 should have defined shape.";
+
+  // Fail only when axis 0 is provably not divisible; undecidable symbolic sizes are accepted.
+  if (analyzer->CanProve(floormod(input_shape.value()[0], PrimExpr(num_workers)) != 0)) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "scatter_from_worker0 expects the size of axis 0 of input tensor to be "
+                        "divisible by the num_workers. However, the axis 0 of input tensor is "
+                     << input_shape.value()[0] << " while num_workers is " << num_workers);
+  }
+
+  Array<PrimExpr> output_shape = input_shape.value();
+  output_shape.Set(0, div(output_shape[0], num_workers));
+  if (input_sinfo->vdevice.defined()) {
+    return TensorStructInfo(ShapeExpr(output_shape), output_dtype, input_sinfo->vdevice.value());
+  }
+  return TensorStructInfo(ShapeExpr(output_shape), output_dtype);
+}
+
+TVM_REGISTER_OP("relax.ccl.scatter_from_worker0")
+    .set_num_inputs(1)
+    .add_argument("x", "Tensor",
+                  "The buffer to be divided into equal parts and sent to each worker accordingly.")
+    .set_attrs_type<ScatterFromWorker0Attrs>()
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScatterFromWorker0)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/ccl/ccl.h b/src/relax/op/ccl/ccl.h
index 55402f3d37..169b4b446c 100644
--- a/src/relax/op/ccl/ccl.h
+++ b/src/relax/op/ccl/ccl.h
@@ -38,6 +38,9 @@ Expr allreduce(Expr data, String op_type);
 /*! \brief Broadcast data from worker-0 to all other workers. */
 Expr broadcast_from_worker0(Expr data);
 
+/*! \brief Perform a scatter operation from worker-0, chunking the given buffer into equal parts. */
+Expr scatter_from_worker0(Expr data, int num_workers);
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/tests/python/relax/test_op_ccl.py b/tests/python/relax/test_op_ccl.py
index 09924d27ec..1c98df4ea7 100644
--- a/tests/python/relax/test_op_ccl.py
+++ b/tests/python/relax/test_op_ccl.py
@@ -160,5 +160,62 @@ def test_broadcast_from_worker0_infer_struct_info_more_input_dtype():
     )
 
 
+def test_scatter_from_worker0_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
+    x1 = relax.Var("x", R.Tensor((3, 4, 5)))
+
+    _check_inference(
+        bb, relax.op.ccl.scatter_from_worker0(x0, 2), relax.TensorStructInfo((1, 3), "float32")
+    )
+    _check_inference(
+        bb, relax.op.ccl.scatter_from_worker0(x1, 3), relax.TensorStructInfo((1, 4, 5), dtype="")
+    )
+
+
+def test_scatter_from_worker0_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    m = tir.Var("m", "int64")
+    n = tir.Var("n", "int64")
+    x0 = relax.Var("x", R.Tensor((m, n), "float32"))
+    x1 = relax.Var("x", R.Tensor((4, n), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.ccl.scatter_from_worker0(x0, 2),
+        relax.TensorStructInfo((tir.div(m, 2), n), "float32"),
+    )
+    _check_inference(
+        bb, relax.op.ccl.scatter_from_worker0(x1, 2), relax.TensorStructInfo((2, n), "float32")
+    )
+
+
+def test_scatter_from_worker0_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s", relax.ShapeStructInfo((2, 4, 8)))
+    x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+
+    _check_inference(
+        bb, relax.op.ccl.scatter_from_worker0(x0, 2), relax.TensorStructInfo((1, 4, 8), "float32")
+    )
+
+
+def test_scatter_from_worker0_infer_struct_info_more_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3), "float64"))
+    x1 = relax.Var("x", R.Tensor((2, 3), "int8"))
+    x2 = relax.Var("x", R.Tensor((2, 3), "int64"))
+
+    _check_inference(
+        bb, relax.op.ccl.scatter_from_worker0(x0, 2), relax.TensorStructInfo((1, 3), "float64")
+    )
+    _check_inference(
+        bb, relax.op.ccl.scatter_from_worker0(x1, 2), relax.TensorStructInfo((1, 3), "int8")
+    )
+    _check_inference(
+        bb, relax.op.ccl.scatter_from_worker0(x2, 2), relax.TensorStructInfo((1, 3), "int64")
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_ccl.py b/tests/python/relax/test_transform_legalize_ops_ccl.py
index 9bce76cecb..b1da283e70 100644
--- a/tests/python/relax/test_transform_legalize_ops_ccl.py
+++ b/tests/python/relax/test_transform_legalize_ops_ccl.py
@@ -72,5 +72,26 @@ def test_broadcast_from_zero():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_scatter_from_worker0():
+    # fmt: off
+    @tvm.script.ir_module
+    class ScatterFromWorker0:
+        @R.function
+        def main(x: R.Tensor((10, 10), "float32"))  -> R.Tensor((5, 10), "float32"):
+            gv0: R.Tensor((5, 10), "float32") = R.ccl.scatter_from_worker0(x, 2)
+            return gv0
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((5, 10), dtype="float32"):
+            gv0: R.Tensor((5, 10), dtype="float32") = R.call_pure_packed("runtime.disco.scatter_from_worker0", x, sinfo_args=R.Tensor((5, 10), dtype="float32"))
+            return gv0
+    # fmt: on
+
+    mod = LegalizeOps()(ScatterFromWorker0)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()