Posted to commits@tvm.apache.org by lu...@apache.org on 2023/09/27 15:02:25 UTC

[tvm] branch unity updated: [Disco] Add AllGather (#15764)

This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 230f8b2491 [Disco] Add AllGather (#15764)
230f8b2491 is described below

commit 230f8b249147ddaf8a514282f6dad96337ac8f5e
Author: Farshid Salemi Parizi <fp...@octoml.ai>
AuthorDate: Wed Sep 27 08:02:16 2023 -0700

    [Disco] Add AllGather (#15764)
    
    * [Disco] Add AllGather
    
    * update allgather to be compatible with #15766
    
    * lint fix
    
    * add num workers to all gather
    
    * change num_gpus to prim value instead of int
    
    * fix typo
    
    * remove all_gather attrs and minor improvement
    
    * Update allgather doc
    
    Co-authored-by: Lesheng Jin <34...@users.noreply.github.com>
    
    ---------
    
    Co-authored-by: Lesheng Jin <34...@users.noreply.github.com>
---
 python/tvm/relax/op/ccl/ccl.py                     | 26 +++++++++-
 python/tvm/relax/transform/legalize_ops/ccl.py     | 25 ++++++++++
 python/tvm/runtime/disco/session.py                | 17 +++++++
 src/relax/op/ccl/ccl.cc                            | 36 ++++++++++++++
 src/relax/op/ccl/ccl.h                             |  3 ++
 src/runtime/disco/builtin.cc                       |  3 ++
 src/runtime/disco/builtin.h                        |  6 +++
 src/runtime/disco/nccl/nccl.cc                     | 11 +++++
 tests/python/disco/test_ccl.py                     | 23 +++++++++
 tests/python/relax/test_op_ccl.py                  | 55 ++++++++++++++++++++++
 .../relax/test_transform_legalize_ops_ccl.py       | 23 +++++++++
 11 files changed, 227 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/op/ccl/ccl.py b/python/tvm/relax/op/ccl/ccl.py
index 62e5914a89..4829bac761 100644
--- a/python/tvm/relax/op/ccl/ccl.py
+++ b/python/tvm/relax/op/ccl/ccl.py
@@ -15,9 +15,12 @@
 # specific language governing permissions and limitations
 # under the License.
 """Relax Collective Communications Library (CCL) operators"""
-from . import _ffi_api
+from typing import Union
+from tvm.relax import PrimValue
 
+from . import _ffi_api
 from ...expr import Expr
+from ....ir import PrimExpr
 
 
 def allreduce(x, op_type: str = "sum"):  # pylint: disable=invalid-name
@@ -44,6 +47,27 @@ def allreduce(x, op_type: str = "sum"):  # pylint: disable=invalid-name
     return _ffi_api.allreduce(x, op_type)  # type: ignore # pylint: disable=no-member
 
 
+def allgather(x, num_workers: Union[int, PrimExpr, PrimValue]):  # pylint: disable=invalid-name
+    """AllGather operator
+
+    Parameters
+    ----------
+    x : relax.Expr
+      The input tensor.
+
+    num_workers : Union[int, PrimExpr, PrimValue]
+      The number of workers to gather data from.
+
+    Returns
+    -------
+    result : relax.Expr
+      The result of allgather.
+    """
+    if not isinstance(num_workers, PrimValue):
+        num_workers = PrimValue(num_workers)
+    return _ffi_api.allgather(x, num_workers)  # type: ignore # pylint: disable=no-member
+
+
 def broadcast_from_worker0(x: Expr) -> Expr:
     """Broadcast data from worker-0 to all other workers.
 
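For reference, a minimal sketch of driving the new helper from Python,
mirroring the operator test added further below; a plain int num_workers is
wrapped into a PrimValue by the code above:

    from tvm import relax
    from tvm.ir import Op
    from tvm.script import relax as R

    # Build a call to the new collective operator; the plain int 2 is
    # converted to relax.PrimValue(2) inside allgather().
    x = relax.Var("x", R.Tensor((2, 3), "float32"))
    call = relax.op.ccl.allgather(x, 2)
    assert call.op == Op.get("relax.ccl.allgather")
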
diff --git a/python/tvm/relax/transform/legalize_ops/ccl.py b/python/tvm/relax/transform/legalize_ops/ccl.py
index 7b51cb7738..9b13d1be7c 100644
--- a/python/tvm/relax/transform/legalize_ops/ccl.py
+++ b/python/tvm/relax/transform/legalize_ops/ccl.py
@@ -46,6 +46,31 @@ def _allreduce(_bb: BlockBuilder, call: Call) -> Expr:
     )
 
 
+@register_legalize("relax.ccl.allgather")
+def _allgather(_bb: BlockBuilder, call: Call) -> Expr:
+    output_shape = []
+    arg_sinfo = call.args[0].struct_info
+    assert isinstance(
+        arg_sinfo, TensorStructInfo
+    ), "The input struct info of allgather should be TensorStructInfo."
+    assert isinstance(arg_sinfo.shape.struct_info, ShapeStructInfo)
+    arg_shape = arg_sinfo.shape.struct_info
+    for i, shape_value in enumerate(arg_shape.values):
+        if i == 0:
+            output_shape.append(shape_value * call.args[1].value)
+        else:
+            output_shape.append(shape_value)
+    return call_dps_packed(
+        "runtime.disco.allgather",
+        call.args[0],
+        out_sinfo=TensorStructInfo(
+            shape=output_shape,
+            dtype=arg_sinfo.dtype,
+            vdevice=arg_sinfo.vdevice,
+        ),
+    )
+
+
 @register_legalize("relax.ccl.broadcast_from_worker0")
 def _broadcast_from_worker0(_bb: BlockBuilder, call: Call) -> Expr:
     return call_dps_packed(
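
The legalization applies a simple shape rule: only the leading axis grows, by
a factor of num_workers. A small standalone sketch of that rule, assuming a
static input shape:

    # Output shape of allgather: the leading dimension is scaled by
    # num_workers; all remaining dimensions are unchanged.
    def allgather_out_shape(in_shape, num_workers):
        return [in_shape[0] * num_workers] + list(in_shape[1:])

    assert allgather_out_shape([10, 10], 2) == [20, 10]
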
diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py
index d05561c2d1..bd015f945e 100644
--- a/python/tvm/runtime/disco/session.py
+++ b/python/tvm/runtime/disco/session.py
@@ -329,6 +329,23 @@ class Session(Object):
         func = self._get_cached_method("runtime.disco.allreduce")
         func(src, op, dst)
 
+    def allgather(
+        self,
+        src: DRef,
+        dst: DRef,
+    ) -> None:
+        """Perform an allgather operation on an array.
+
+        Parameters
+        ----------
+        src : DRef
+            The local array contributed by this worker.
+        dst : DRef
+            The array that receives the gathered result on every worker.
+        """
+        func = self._get_cached_method("runtime.disco.allgather")
+        func(src, dst)
+
 
 @register_object("runtime.disco.ThreadedSession")
 class ThreadedSession(Session):
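
A hedged usage sketch of the new Session method, following the disco test
added below (it assumes a two-worker NCCL setup and that ThreadedSession is
exported from tvm.runtime.disco):

    from tvm.runtime import disco

    sess = disco.ThreadedSession(num_workers=2)
    sess.init_ccl("nccl", 0, 1)
    d_src = sess.empty((3, 3, 2), "float32")  # 18 elements per worker
    d_dst = sess.empty((3, 4, 3), "float32")  # 36 elements after gathering
    sess.allgather(d_src, d_dst)
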
diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc
index a114cb1dae..4372dd0aa6 100644
--- a/src/relax/op/ccl/ccl.cc
+++ b/src/relax/op/ccl/ccl.cc
@@ -50,6 +50,42 @@ TVM_REGISTER_OP("relax.ccl.allreduce")
     .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.ccl.allgather */
+Expr allgather(Expr x, Expr num_workers) {
+  static const Op& op = Op::Get("relax.ccl.allgather");
+  return Call(op, {std::move(x), std::move(num_workers)});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.ccl.allgather").set_body_typed(allgather);
+
+StructInfo InferStructInfoAllGather(const Call& call, const BlockBuilder& ctx) {
+  CHECK_EQ(call->args.size(), 2);
+  auto input_sinfo = Downcast<TensorStructInfo>(call->args[0]->struct_info_);
+  auto num_workers_sinfo = Downcast<PrimStructInfo>(call->args[1]->struct_info_);
+
+  auto num_workers = num_workers_sinfo->value;
+
+  DataType output_dtype = input_sinfo->dtype;
+  auto input_shape = input_sinfo->GetShape();
+  if (!input_shape.defined()) {
+    return input_sinfo;
+  }
+  Array<PrimExpr> output_shape = input_shape.value();
+  output_shape.Set(0, floor(output_shape[0] * num_workers.value()));
+  VDevice vdevice;
+  if (input_sinfo->vdevice.defined()) {
+    vdevice = input_sinfo->vdevice.value();
+  }
+  return TensorStructInfo(ShapeExpr(output_shape), output_dtype, vdevice);
+}
+
+TVM_REGISTER_OP("relax.ccl.allgather")
+    .set_num_inputs(2)
+    .add_argument("x", "Tensor", "Input to which allgather will be applied.")
+    .add_argument("num_workers", "PrimValue", "The number of workers to gather data from.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAllGather)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 /* relax.ccl.broadcast_from_worker0 */
 Expr broadcast_from_worker0(Expr x) {
   static const Op& op = Op::Get("relax.ccl.broadcast_from_worker0");
diff --git a/src/relax/op/ccl/ccl.h b/src/relax/op/ccl/ccl.h
index 169b4b446c..7742997d5a 100644
--- a/src/relax/op/ccl/ccl.h
+++ b/src/relax/op/ccl/ccl.h
@@ -35,6 +35,9 @@ namespace relax {
 /*! \brief AllReduce. */
 Expr allreduce(Expr data, String op_type);
 
+/*! \brief AllGather. */
+Expr allgather(Expr data, Expr num_workers);
+
 /*! \brief Broadcast data from worker-0 to all other workers. */
 Expr broadcast_from_worker0(Expr data);
 
diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc
index 06408c723a..5aea39cf66 100644
--- a/src/runtime/disco/builtin.cc
+++ b/src/runtime/disco/builtin.cc
@@ -84,6 +84,8 @@ void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) {
   GetCCLFunc("allreduce")(send, static_cast<int>(reduce_kind), recv);
 }
 
+void AllGather(NDArray send, NDArray recv) { GetCCLFunc("allgather")(send, recv); }
+
 void BroadcastFromWorker0(NDArray send, NDArray recv) {
   GetCCLFunc("broadcast_from_worker0")(send, recv);
 }
@@ -114,6 +116,7 @@ TVM_REGISTER_GLOBAL("runtime.disco.allreduce")
       CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind;
       AllReduce(send, static_cast<ReduceKind>(kind), recv);
     });
+TVM_REGISTER_GLOBAL("runtime.disco.allgather").set_body_typed(AllGather);
 TVM_REGISTER_GLOBAL("runtime.disco.broadcast_from_worker0").set_body_typed(BroadcastFromWorker0);
 TVM_REGISTER_GLOBAL("runtime.disco.scatter_from_worker0").set_body_typed(ScatterFromWorker0);
 TVM_REGISTER_GLOBAL("runtime.disco.gather_to_worker0").set_body_typed(GatherToWorker0);
diff --git a/src/runtime/disco/builtin.h b/src/runtime/disco/builtin.h
index 10b20562d0..cfbf2e2477 100644
--- a/src/runtime/disco/builtin.h
+++ b/src/runtime/disco/builtin.h
@@ -52,6 +52,12 @@ NDArray DiscoEmptyNDArray(ShapeTuple shape, DataType dtype, Device device);
  * \return The outcome of allreduce
  */
 void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv);
+/*!
+ * \brief Perform an allgather operation using the underlying communication library
+ * \param send The array sent from this worker
+ * \param recv The array that receives the gathered result
+ */
+void AllGather(NDArray send, NDArray recv);
 /*!
  * \brief Perform a broadcast operation from worker-0
  * \param buffer The buffer to be broadcasted
diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index 6c882176ff..0ce8b985ff 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -192,6 +192,15 @@ void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) {
                           /*op=*/AsNCCLRedOp(reduce_kind), ctx->comm, stream));
 }
 
+void AllGather(NDArray send, NDArray recv) {
+  CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
+  ShapeTuple shape = send.Shape();
+  int64_t numel = shape->Product();
+  cudaStream_t stream = ctx->GetDefaultStream();
+  NCCL_CALL(ncclAllGather(send->data, recv->data, numel,
+                          /*datatype=*/AsNCCLDataType(DataType(send->dtype)), ctx->comm, stream));
+}
+
 void BroadcastFromWorker0(NDArray send, NDArray recv) {
   CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
   ICHECK(send.Shape()->Product() == recv.Shape()->Product());
@@ -316,6 +325,8 @@ TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".allreduce")
       CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind;
       AllReduce(send, static_cast<ReduceKind>(kind), recv);
     });
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".allgather")
+    .set_body_typed([](NDArray send, NDArray recv) { AllGather(send, recv); });
 TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".broadcast_from_worker0")
     .set_body_typed(BroadcastFromWorker0);
 TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".scatter_from_worker0")
diff --git a/tests/python/disco/test_ccl.py b/tests/python/disco/test_ccl.py
index ecd2e07287..4ecc14babc 100644
--- a/tests/python/disco/test_ccl.py
+++ b/tests/python/disco/test_ccl.py
@@ -77,6 +77,29 @@ def test_allreduce(session_kind, ccl):
         np.testing.assert_equal(result, expected)
 
 
+@pytest.mark.parametrize("session_kind", _all_session_kinds)
+@pytest.mark.parametrize("ccl", _ccl)
+def test_allgather(session_kind, ccl):
+    devices = [0, 1]
+    sess = session_kind(num_workers=len(devices))
+    sess.init_ccl(ccl, *devices)
+
+    array = np.arange(36, dtype="float32")
+    d_src = sess.empty((3, 3, 2), "float32")
+    d_dst = sess.empty((3, 4, 3), "float32")
+    d_src.debug_copy_from(0, array[:18])
+    d_src.debug_copy_from(1, array[18:])
+    sess.allgather(d_src, d_dst)
+    np.testing.assert_equal(
+        d_dst.debug_get_from_remote(0).numpy(),
+        array.reshape(3, 4, 3),
+    )
+    np.testing.assert_equal(
+        d_dst.debug_get_from_remote(1).numpy(),
+        array.reshape(3, 4, 3),
+    )
+
+
 @pytest.mark.parametrize("session_kind", _all_session_kinds)
 @pytest.mark.parametrize("ccl", _ccl)
 def test_broadcast_from_worker0(session_kind, ccl):
diff --git a/tests/python/relax/test_op_ccl.py b/tests/python/relax/test_op_ccl.py
index 1c98df4ea7..6b2f375e59 100644
--- a/tests/python/relax/test_op_ccl.py
+++ b/tests/python/relax/test_op_ccl.py
@@ -27,6 +27,7 @@ def test_op_correctness():
     x = relax.Var("x", R.Tensor((2, 3), "float32"))
     assert relax.op.ccl.allreduce(x).op == Op.get("relax.ccl.allreduce")
     assert relax.op.ccl.broadcast_from_worker0(x).op == Op.get("relax.ccl.broadcast_from_worker0")
+    assert relax.op.ccl.allgather(x, 2).op == Op.get("relax.ccl.allgather")
 
 
 def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
@@ -86,6 +87,60 @@ def test_allreduce_infer_struct_info_more_input_dtype():
     _check_inference(bb, relax.op.ccl.allreduce(x2), relax.TensorStructInfo((2, 3), "int64"))
 
 
+def test_allgather_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+    x2 = relax.Var("x", R.Tensor("float32", ndim=-1))
+    x3 = relax.Var("x", R.Tensor((2, 3)))
+    x4 = relax.Var("x", R.Tensor())
+    x5 = relax.Var("x", R.Tensor((3, 4)))
+
+    _check_inference(bb, relax.op.ccl.allgather(x0, 2), relax.TensorStructInfo((4, 3), "float32"))
+    _check_inference(
+        bb, relax.op.ccl.allgather(x1, 2), relax.TensorStructInfo(dtype="float32", ndim=3)
+    )
+    _check_inference(bb, relax.op.ccl.allgather(x2, 2), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(bb, relax.op.ccl.allgather(x3, 2), relax.TensorStructInfo((4, 3), dtype=""))
+    _check_inference(bb, relax.op.ccl.allgather(x4, 2), relax.TensorStructInfo(dtype=""))
+    _check_inference(bb, relax.op.ccl.allgather(x5, 2), relax.TensorStructInfo((6, 4), dtype=""))
+
+
+def test_allgather_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    m = tir.Var("m", "int64")
+    n = tir.Var("n", "int64")
+    x0 = relax.Var("x", R.Tensor((m, n), "float32"))
+    x1 = relax.Var("x", R.Tensor((4, n), "float32"))
+
+    _check_inference(
+        bb, relax.op.ccl.allgather(x0, 2), relax.TensorStructInfo((m * 2, n), "float32")
+    )
+    _check_inference(bb, relax.op.ccl.allgather(x1, 2), relax.TensorStructInfo((8, n), "float32"))
+
+
+def test_allgather_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2))
+    s1 = relax.Var("s", relax.ShapeStructInfo())
+    x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+
+    _check_inference(bb, relax.op.ccl.allgather(x0, 2), relax.TensorStructInfo(s0, "float32"))
+    _check_inference(bb, relax.op.ccl.allgather(x1, 2), relax.TensorStructInfo(s1, "float32"))
+
+
+def test_allgather_infer_struct_info_more_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3), "float64"))
+    x1 = relax.Var("x", R.Tensor((2, 3), "int8"))
+    x2 = relax.Var("x", R.Tensor((2, 3), "int64"))
+
+    _check_inference(bb, relax.op.ccl.allgather(x0, 2), relax.TensorStructInfo((4, 3), "float64"))
+    _check_inference(bb, relax.op.ccl.allgather(x1, 2), relax.TensorStructInfo((4, 3), "int8"))
+    _check_inference(bb, relax.op.ccl.allgather(x2, 2), relax.TensorStructInfo((4, 3), "int64"))
+
+
 def test_broadcast_from_worker0_infer_struct_info():
     bb = relax.BlockBuilder()
     x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
diff --git a/tests/python/relax/test_transform_legalize_ops_ccl.py b/tests/python/relax/test_transform_legalize_ops_ccl.py
index 071c9bc939..bb2e8f3394 100644
--- a/tests/python/relax/test_transform_legalize_ops_ccl.py
+++ b/tests/python/relax/test_transform_legalize_ops_ccl.py
@@ -51,6 +51,29 @@ def test_allreduce():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_allgather():
+    # fmt: off
+    @tvm.script.ir_module
+    class AllGather:
+        @R.function
+        def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((10, 10), "float32"):
+            gv0: R.Tensor((20, 10), "float32") = R.ccl.allgather(x, 2)
+            gv1 = R.ccl.allgather(x, 2)
+            return x
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"):
+            gv0: R.Tensor((20, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allgather", [x], out_sinfo=R.Tensor((20, 10), dtype="float32"))
+            gv1: R.Tensor((20, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allgather", [x], out_sinfo=R.Tensor((20, 10), dtype="float32"))
+            return x
+    # fmt: on
+
+    mod = LegalizeOps()(AllGather)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_broadcast_from_zero():
     # fmt: off
     @tvm.script.ir_module