You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lu...@apache.org on 2023/09/27 15:02:25 UTC
[tvm] branch unity updated: [Disco] Add AllGather (#15764)
This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 230f8b2491 [Disco] Add AllGather (#15764)
230f8b2491 is described below
commit 230f8b249147ddaf8a514282f6dad96337ac8f5e
Author: Farshid Salemi Parizi <fp...@octoml.ai>
AuthorDate: Wed Sep 27 08:02:16 2023 -0700
[Disco] Add AllGather (#15764)
* [Disco] Add AllGather
* update allgather to be compatible with #15766
* lint fix
* add num workers to all gather
* change num_gpus to prim value instead of int
* fix typo
* remove all_gather attrs and minor improvement
* Update allgather doc
Co-authored-by: Lesheng Jin <34...@users.noreply.github.com>
---------
Co-authored-by: Lesheng Jin <34...@users.noreply.github.com>
---
python/tvm/relax/op/ccl/ccl.py | 26 +++++++++-
python/tvm/relax/transform/legalize_ops/ccl.py | 25 ++++++++++
python/tvm/runtime/disco/session.py | 17 +++++++
src/relax/op/ccl/ccl.cc | 36 ++++++++++++++
src/relax/op/ccl/ccl.h | 3 ++
src/runtime/disco/builtin.cc | 3 ++
src/runtime/disco/builtin.h | 6 +++
src/runtime/disco/nccl/nccl.cc | 11 +++++
tests/python/disco/test_ccl.py | 23 +++++++++
tests/python/relax/test_op_ccl.py | 55 ++++++++++++++++++++++
.../relax/test_transform_legalize_ops_ccl.py | 23 +++++++++
11 files changed, 227 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/op/ccl/ccl.py b/python/tvm/relax/op/ccl/ccl.py
index 62e5914a89..4829bac761 100644
--- a/python/tvm/relax/op/ccl/ccl.py
+++ b/python/tvm/relax/op/ccl/ccl.py
@@ -15,9 +15,12 @@
# specific language governing permissions and limitations
# under the License.
"""Relax Collective Communications Library (CCL) operators"""
-from . import _ffi_api
+from typing import Union
+from tvm.relax import PrimValue
+from . import _ffi_api
from ...expr import Expr
+from ....ir import PrimExpr
def allreduce(x, op_type: str = "sum"): # pylint: disable=invalid-name
@@ -44,6 +47,27 @@ def allreduce(x, op_type: str = "sum"): # pylint: disable=invalid-name
return _ffi_api.allreduce(x, op_type) # type: ignore # pylint: disable=no-member
+def allgather(  # pylint: disable=invalid-name
+    x: Expr, num_workers: Union[int, PrimExpr, PrimValue]
+) -> Expr:
+    """AllGather operator.
+
+    Gathers ``x`` from all workers and concatenates the pieces along the first
+    axis, so the first dimension of the result is ``num_workers`` times the
+    first dimension of ``x``.
+
+    Parameters
+    ----------
+    x : relax.Expr
+        The input tensor.
+
+    num_workers : Union[int, PrimExpr, PrimValue]
+        The number of workers to gather data from. Values that are not already
+        a ``PrimValue`` are wrapped into one.
+
+    Returns
+    -------
+    result : relax.Expr
+        The result of allgather.
+
+    Raises
+    ------
+    ValueError
+        If ``num_workers`` is a Python ``int`` smaller than 1.
+    """
+    if isinstance(num_workers, int) and num_workers < 1:
+        raise ValueError(f"allgather expects num_workers >= 1, but got {num_workers}")
+    if not isinstance(num_workers, PrimValue):
+        num_workers = PrimValue(num_workers)
+    return _ffi_api.allgather(x, num_workers)  # type: ignore # pylint: disable=no-member
+
+
def broadcast_from_worker0(x: Expr) -> Expr:
"""Broadcast data from worker-0 to all other workers.
diff --git a/python/tvm/relax/transform/legalize_ops/ccl.py b/python/tvm/relax/transform/legalize_ops/ccl.py
index 7b51cb7738..9b13d1be7c 100644
--- a/python/tvm/relax/transform/legalize_ops/ccl.py
+++ b/python/tvm/relax/transform/legalize_ops/ccl.py
@@ -46,6 +46,31 @@ def _allreduce(_bb: BlockBuilder, call: Call) -> Expr:
)
+@register_legalize("relax.ccl.allgather")
+def _allgather(_bb: BlockBuilder, call: Call) -> Expr:
+    """Legalize ``relax.ccl.allgather`` to the packed function ``runtime.disco.allgather``.
+
+    The call becomes a ``call_dps_packed`` whose output tensor has the input's
+    dtype and vdevice, and whose first dimension is the input's first dimension
+    multiplied by the ``num_workers`` argument (``call.args[1]``).
+    """
+    arg_sinfo = call.args[0].struct_info
+    assert isinstance(
+        arg_sinfo, TensorStructInfo
+    ), "The input struct info of allgather should be TensorStructInfo."
+    # The DPS output buffer needs explicit dimensions, so the input shape must be
+    # known here (the dimensions may still be symbolic).
+    assert arg_sinfo.shape is not None and isinstance(
+        arg_sinfo.shape.struct_info, ShapeStructInfo
+    ), "allgather can only be legalized when the input shape is known."
+    arg_shape_values = arg_sinfo.shape.struct_info.values
+    assert (
+        arg_shape_values is not None and len(arg_shape_values) >= 1
+    ), "allgather can only be legalized for inputs of known rank >= 1."
+    # Read num_workers from its PrimStructInfo, like the C++ FInferStructInfo does,
+    # so any argument whose struct info carries a value works, not only PrimValue.
+    num_workers = call.args[1].struct_info.value
+    assert num_workers is not None, "allgather can only be legalized with a known num_workers."
+
+    first_dim, *other_dims = arg_shape_values
+    output_shape = [first_dim * num_workers, *other_dims]
+    return call_dps_packed(
+        "runtime.disco.allgather",
+        call.args[0],
+        out_sinfo=TensorStructInfo(
+            shape=output_shape,
+            dtype=arg_sinfo.dtype,
+            vdevice=arg_sinfo.vdevice,
+        ),
+    )
+
+
@register_legalize("relax.ccl.broadcast_from_worker0")
def _broadcast_from_worker0(_bb: BlockBuilder, call: Call) -> Expr:
return call_dps_packed(
diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py
index d05561c2d1..bd015f945e 100644
--- a/python/tvm/runtime/disco/session.py
+++ b/python/tvm/runtime/disco/session.py
@@ -329,6 +329,23 @@ class Session(Object):
func = self._get_cached_method("runtime.disco.allreduce")
func(src, op, dst)
+    def allgather(
+        self,
+        src: DRef,
+        dst: DRef,
+    ) -> None:
+        """Perform an allgather operation on an array.
+
+        Every worker contributes its ``src`` array, and every worker receives the
+        contributions of all workers in ``dst``, concatenated in worker order.
+
+        Parameters
+        ----------
+        src : DRef
+            The array to be gathered from.
+        dst : DRef
+            The array to be gathered to. It should hold the number of workers
+            times the number of elements of ``src``.
+
+        Notes
+        -----
+        The result is written into ``dst`` in place; nothing is returned.
+        """
+        func = self._get_cached_method("runtime.disco.allgather")
+        func(src, dst)
+
@register_object("runtime.disco.ThreadedSession")
class ThreadedSession(Session):
diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc
index a114cb1dae..4372dd0aa6 100644
--- a/src/relax/op/ccl/ccl.cc
+++ b/src/relax/op/ccl/ccl.cc
@@ -50,6 +50,42 @@ TVM_REGISTER_OP("relax.ccl.allreduce")
.set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise)
.set_attr<Bool>("FPurity", Bool(true));
+/* relax.ccl.allgather */
+Expr allgather(Expr x, Expr num_workers) {
+  static const Op& op = Op::Get("relax.ccl.allgather");
+  return Call(op, {std::move(x), std::move(num_workers)});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.ccl.allgather").set_body_typed(allgather);
+
+/*!
+ * \brief Infer the struct info of relax.ccl.allgather(x, num_workers).
+ *
+ * The output keeps the dtype and vdevice of `x`; its first dimension is the first
+ * dimension of `x` multiplied by the value of `num_workers`. If the shape of `x`
+ * is unknown, the struct info of `x` is returned unchanged; if only the value of
+ * `num_workers` is unknown, dtype, ndim and vdevice are kept.
+ */
+StructInfo InferStructInfoAllGather(const Call& call, const BlockBuilder& ctx) {
+  CHECK_EQ(call->args.size(), 2)
+      << "relax.ccl.allgather expects 2 arguments (x, num_workers), but got "
+      << call->args.size();
+  const auto* input_sinfo = call->args[0]->struct_info_.as<TensorStructInfoNode>();
+  if (input_sinfo == nullptr) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "relax.ccl.allgather expects the input x to be a Tensor, but got "
+                     << call->args[0]->struct_info_);
+  }
+  const auto* num_workers_sinfo = call->args[1]->struct_info_.as<PrimStructInfoNode>();
+  if (num_workers_sinfo == nullptr) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "relax.ccl.allgather expects num_workers to be a PrimValue, but got "
+                     << call->args[1]->struct_info_);
+  }
+  if (!num_workers_sinfo->dtype.is_int() && !num_workers_sinfo->dtype.is_uint()) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "relax.ccl.allgather expects an integer num_workers, but got dtype "
+                     << num_workers_sinfo->dtype);
+  }
+  if (input_sinfo->ndim == 0) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "relax.ccl.allgather gathers along the first axis and cannot be "
+                        "applied to a 0-dimensional tensor");
+  }
+
+  DataType output_dtype = input_sinfo->dtype;
+  VDevice vdevice;
+  if (input_sinfo->vdevice.defined()) {
+    vdevice = input_sinfo->vdevice.value();
+  }
+
+  auto input_shape = input_sinfo->GetShape();
+  if (!input_shape.defined()) {
+    return GetRef<TensorStructInfo>(input_sinfo);
+  }
+  if (!num_workers_sinfo->value.defined()) {
+    // The number of workers is not known statically, so the first output
+    // dimension cannot be expressed; keep only what is known about the output.
+    return TensorStructInfo(output_dtype, input_sinfo->ndim, vdevice);
+  }
+  Array<PrimExpr> output_shape = input_shape.value();
+  // Both operands are integers, so the product is already an integral extent.
+  output_shape.Set(0, output_shape[0] * num_workers_sinfo->value.value());
+  return TensorStructInfo(ShapeExpr(output_shape), output_dtype, vdevice);
+}
+
+TVM_REGISTER_OP("relax.ccl.allgather")
+    .set_num_inputs(2)
+    .add_argument("x", "Tensor", "Input to which allgather will be applied.")
+    .add_argument("num_workers", "PrimValue", "The number of workers to gather data from.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAllGather)
+    // NOTE(review): the unary element-wise layout rule gives the output the input's
+    // layout, while allgather changes the extent of axis 0 -- confirm this is safe
+    // for layout transforms that move axis 0.
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise)
+    .set_attr<Bool>("FPurity", Bool(true));
+
/* relax.ccl.broadcast_from_worker0 */
Expr broadcast_from_worker0(Expr x) {
static const Op& op = Op::Get("relax.ccl.broadcast_from_worker0");
diff --git a/src/relax/op/ccl/ccl.h b/src/relax/op/ccl/ccl.h
index 169b4b446c..7742997d5a 100644
--- a/src/relax/op/ccl/ccl.h
+++ b/src/relax/op/ccl/ccl.h
@@ -35,6 +35,9 @@ namespace relax {
/*! \brief AllReduce. */
Expr allreduce(Expr data, String op_type);
+/*!
+ * \brief AllGather: gather `data` from all workers, concatenated along the first axis.
+ * \param data The input tensor held by each worker.
+ * \param num_workers The number of workers to gather from, passed as a PrimValue.
+ * \return The call expression of relax.ccl.allgather.
+ */
+Expr allgather(Expr data, Expr num_workers);
+
/*! \brief Broadcast data from worker-0 to all other workers. */
Expr broadcast_from_worker0(Expr data);
diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc
index 06408c723a..5aea39cf66 100644
--- a/src/runtime/disco/builtin.cc
+++ b/src/runtime/disco/builtin.cc
@@ -84,6 +84,8 @@ void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) {
GetCCLFunc("allreduce")(send, static_cast<int>(reduce_kind), recv);
}
+/*! \brief Forward allgather to the CCL function returned by GetCCLFunc("allgather"). */
+void AllGather(NDArray send, NDArray recv) { GetCCLFunc("allgather")(send, recv); }
+
void BroadcastFromWorker0(NDArray send, NDArray recv) {
GetCCLFunc("broadcast_from_worker0")(send, recv);
}
@@ -114,6 +116,7 @@ TVM_REGISTER_GLOBAL("runtime.disco.allreduce")
CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind;
AllReduce(send, static_cast<ReduceKind>(kind), recv);
});
+// Called by Session.allgather and by the call_dps_packed produced when legalizing
+// relax.ccl.allgather.
+TVM_REGISTER_GLOBAL("runtime.disco.allgather").set_body_typed(AllGather);
TVM_REGISTER_GLOBAL("runtime.disco.broadcast_from_worker0").set_body_typed(BroadcastFromWorker0);
TVM_REGISTER_GLOBAL("runtime.disco.scatter_from_worker0").set_body_typed(ScatterFromWorker0);
TVM_REGISTER_GLOBAL("runtime.disco.gather_to_worker0").set_body_typed(GatherToWorker0);
diff --git a/src/runtime/disco/builtin.h b/src/runtime/disco/builtin.h
index 10b20562d0..cfbf2e2477 100644
--- a/src/runtime/disco/builtin.h
+++ b/src/runtime/disco/builtin.h
@@ -52,6 +52,12 @@ NDArray DiscoEmptyNDArray(ShapeTuple shape, DataType dtype, Device device);
* \return The outcome of allreduce
*/
void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv);
+/*!
+ * \brief Perform an allgather operation using the underlying communication library
+ * \param send The array on this worker that is contributed to the gather
+ * \param recv The array that receives the contributions of all workers,
+ *        concatenated in worker order
+ */
+void AllGather(NDArray send, NDArray recv);
/*!
* \brief Perform a broadcast operation from worker-0
* \param buffer The buffer to be broadcasted
diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index 6c882176ff..0ce8b985ff 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -192,6 +192,15 @@ void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) {
/*op=*/AsNCCLRedOp(reduce_kind), ctx->comm, stream));
}
+/*!
+ * \brief Gather `send` from every rank into `recv` on every rank via ncclAllGather.
+ *
+ * `send` is passed as a flat buffer of `numel` elements of its dtype; NCCL lays the
+ * per-rank buffers out back to back in `recv` in rank order.
+ * NOTE(review): `recv` is assumed to hold (number of workers) * numel elements of
+ * the same dtype as `send`. Neither is checked here, unlike BroadcastFromWorker0,
+ * which ICHECKs element counts -- consider adding the same validation.
+ */
+void AllGather(NDArray send, NDArray recv) {
+ CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
+ ShapeTuple shape = send.Shape();
+ int64_t numel = shape->Product();
+ cudaStream_t stream = ctx->GetDefaultStream();
+ // The count argument of ncclAllGather is the per-rank element count, not the total.
+ NCCL_CALL(ncclAllGather(send->data, recv->data, numel,
+ /*datatype=*/AsNCCLDataType(DataType(send->dtype)), ctx->comm, stream));
+}
+
void BroadcastFromWorker0(NDArray send, NDArray recv) {
CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
ICHECK(send.Shape()->Product() == recv.Shape()->Product());
@@ -316,6 +325,8 @@ TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".allreduce")
CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind;
AllReduce(send, static_cast<ReduceKind>(kind), recv);
});
+// AllGather needs no argument conversion, so it is registered directly, like
+// BroadcastFromWorker0 below.
+TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".allgather").set_body_typed(AllGather);
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".broadcast_from_worker0")
.set_body_typed(BroadcastFromWorker0);
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".scatter_from_worker0")
diff --git a/tests/python/disco/test_ccl.py b/tests/python/disco/test_ccl.py
index ecd2e07287..4ecc14babc 100644
--- a/tests/python/disco/test_ccl.py
+++ b/tests/python/disco/test_ccl.py
@@ -77,6 +77,29 @@ def test_allreduce(session_kind, ccl):
np.testing.assert_equal(result, expected)
+@pytest.mark.parametrize("session_kind", _all_session_kinds)
+@pytest.mark.parametrize("ccl", _ccl)
+def test_allgather(session_kind, ccl):
+ """Two workers contribute 18 floats each; both must receive all 36 in worker order."""
+ devices = [0, 1]
+ sess = session_kind(num_workers=len(devices))
+ sess.init_ccl(ccl, *devices)
+
+ array = np.arange(36, dtype="float32")
+ d_src = sess.empty((3, 3, 2), "float32")
+ # dst only needs room for num_workers * src.size elements; its shape need not
+ # match src's, because the gather copies flat element buffers.
+ d_dst = sess.empty((3, 4, 3), "float32")
+ d_src.debug_copy_from(0, array[:18])
+ d_src.debug_copy_from(1, array[18:])
+ sess.allgather(d_src, d_dst)
+ np.testing.assert_equal(
+ d_dst.debug_get_from_remote(0).numpy(),
+ array.reshape(3, 4, 3),
+ )
+ np.testing.assert_equal(
+ d_dst.debug_get_from_remote(1).numpy(),
+ array.reshape(3, 4, 3),
+ )
+
+
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_broadcast_from_worker0(session_kind, ccl):
diff --git a/tests/python/relax/test_op_ccl.py b/tests/python/relax/test_op_ccl.py
index 1c98df4ea7..6b2f375e59 100644
--- a/tests/python/relax/test_op_ccl.py
+++ b/tests/python/relax/test_op_ccl.py
@@ -27,6 +27,7 @@ def test_op_correctness():
x = relax.Var("x", R.Tensor((2, 3), "float32"))
assert relax.op.ccl.allreduce(x).op == Op.get("relax.ccl.allreduce")
assert relax.op.ccl.broadcast_from_worker0(x).op == Op.get("relax.ccl.broadcast_from_worker0")
+ assert relax.op.ccl.allgather(x, 2).op == Op.get("relax.ccl.allgather")
def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
@@ -86,6 +87,60 @@ def test_allreduce_infer_struct_info_more_input_dtype():
_check_inference(bb, relax.op.ccl.allreduce(x2), relax.TensorStructInfo((2, 3), "int64"))
+def test_allgather_infer_struct_info():
+ """allgather scales the first dimension by num_workers and keeps the dtype.
+
+ Inputs without a known shape keep their own (possibly unknown) ndim and dtype.
+ """
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
+ x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+ x2 = relax.Var("x", R.Tensor("float32", ndim=-1))
+ x3 = relax.Var("x", R.Tensor((2, 3)))
+ x4 = relax.Var("x", R.Tensor())
+ x5 = relax.Var("x", R.Tensor((3, 4)))
+
+ _check_inference(bb, relax.op.ccl.allgather(x0, 2), relax.TensorStructInfo((4, 3), "float32"))
+ _check_inference(
+ bb, relax.op.ccl.allgather(x1, 2), relax.TensorStructInfo(dtype="float32", ndim=3)
+ )
+ _check_inference(bb, relax.op.ccl.allgather(x2, 2), relax.TensorStructInfo(dtype="float32"))
+ _check_inference(bb, relax.op.ccl.allgather(x3, 2), relax.TensorStructInfo((4, 3), dtype=""))
+ _check_inference(bb, relax.op.ccl.allgather(x4, 2), relax.TensorStructInfo(dtype=""))
+ _check_inference(bb, relax.op.ccl.allgather(x5, 2), relax.TensorStructInfo((6, 4), dtype=""))
+
+
+def test_allgather_infer_struct_info_shape_symbolic():
+ """A symbolic first dimension is scaled symbolically; a constant one is folded (4 * 2 -> 8)."""
+ bb = relax.BlockBuilder()
+ m = tir.Var("m", "int64")
+ n = tir.Var("n", "int64")
+ x0 = relax.Var("x", R.Tensor((m, n), "float32"))
+ x1 = relax.Var("x", R.Tensor((4, n), "float32"))
+
+ _check_inference(
+ bb, relax.op.ccl.allgather(x0, 2), relax.TensorStructInfo((m * 2, n), "float32")
+ )
+ _check_inference(bb, relax.op.ccl.allgather(x1, 2), relax.TensorStructInfo((8, n), "float32"))
+
+
+def test_allgather_infer_struct_info_shape_var():
+ """With only a shape variable (no known dimension values), the input struct info is kept."""
+ bb = relax.BlockBuilder()
+ s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2))
+ s1 = relax.Var("s", relax.ShapeStructInfo())
+ x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+ x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+
+ _check_inference(bb, relax.op.ccl.allgather(x0, 2), relax.TensorStructInfo(s0, "float32"))
+ _check_inference(bb, relax.op.ccl.allgather(x1, 2), relax.TensorStructInfo(s1, "float32"))
+
+
+def test_allgather_infer_struct_info_more_input_dtype():
+ """The output dtype follows the input dtype for float64, int8 and int64 inputs."""
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3), "float64"))
+ x1 = relax.Var("x", R.Tensor((2, 3), "int8"))
+ x2 = relax.Var("x", R.Tensor((2, 3), "int64"))
+
+ _check_inference(bb, relax.op.ccl.allgather(x0, 2), relax.TensorStructInfo((4, 3), "float64"))
+ _check_inference(bb, relax.op.ccl.allgather(x1, 2), relax.TensorStructInfo((4, 3), "int8"))
+ _check_inference(bb, relax.op.ccl.allgather(x2, 2), relax.TensorStructInfo((4, 3), "int64"))
+
+
def test_broadcast_from_worker0_infer_struct_info():
bb = relax.BlockBuilder()
x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
diff --git a/tests/python/relax/test_transform_legalize_ops_ccl.py b/tests/python/relax/test_transform_legalize_ops_ccl.py
index 071c9bc939..bb2e8f3394 100644
--- a/tests/python/relax/test_transform_legalize_ops_ccl.py
+++ b/tests/python/relax/test_transform_legalize_ops_ccl.py
@@ -51,6 +51,29 @@ def test_allreduce():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_allgather():
+ """LegalizeOps lowers R.ccl.allgather to call_dps_packed("runtime.disco.allgather")."""
+ # fmt: off
+ @tvm.script.ir_module
+ class AllGather:
+ @R.function
+ def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((10, 10), "float32"):
+ gv0: R.Tensor((20, 10), "float32") = R.ccl.allgather(x, 2)
+ # gv1 has no annotation, so its (20, 10) shape comes from FInferStructInfo.
+ gv1 = R.ccl.allgather(x, 2)
+ return x
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"):
+ gv0: R.Tensor((20, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allgather", [x], out_sinfo=R.Tensor((20, 10), dtype="float32"))
+ gv1: R.Tensor((20, 10), dtype="float32") = R.call_dps_packed("runtime.disco.allgather", [x], out_sinfo=R.Tensor((20, 10), dtype="float32"))
+ return x
+ # fmt: on
+
+ mod = LegalizeOps()(AllGather)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_broadcast_from_zero():
# fmt: off
@tvm.script.ir_module