You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2023/09/08 04:30:02 UTC
[tvm] branch unity updated: [Disco][Op] scatter_from_worker0 (#15680)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 9e7227e084 [Disco][Op] scatter_from_worker0 (#15680)
9e7227e084 is described below
commit 9e7227e084b550d2e08ca058aecd0f3ac19672dd
Author: Lesheng Jin <34...@users.noreply.github.com>
AuthorDate: Thu Sep 7 21:29:56 2023 -0700
[Disco][Op] scatter_from_worker0 (#15680)
This PR introduces the op `scatter_from_worker0`, which performs a scatter operation from worker-0, chunking the given buffer into equal parts.
---
include/tvm/relax/attrs/ccl.h | 12 +++++
python/tvm/relax/op/ccl/ccl.py | 19 ++++++++
python/tvm/relax/transform/legalize_ops/ccl.py | 34 ++++++++++++-
src/relax/op/ccl/ccl.cc | 48 ++++++++++++++++++
src/relax/op/ccl/ccl.h | 3 ++
tests/python/relax/test_op_ccl.py | 57 ++++++++++++++++++++++
.../relax/test_transform_legalize_ops_ccl.py | 21 ++++++++
7 files changed, 193 insertions(+), 1 deletion(-)
diff --git a/include/tvm/relax/attrs/ccl.h b/include/tvm/relax/attrs/ccl.h
index 45de0e949c..b4b3880384 100644
--- a/include/tvm/relax/attrs/ccl.h
+++ b/include/tvm/relax/attrs/ccl.h
@@ -40,6 +40,18 @@ struct AllReduceAttrs : public tvm::AttrsNode<AllReduceAttrs> {
}
}; // struct AllReduceAttrs
+/*!
+ * \brief Attributes used in scatter_from_worker0 operators.
+ *
+ * scatter_from_worker0 chunks the buffer held by worker-0 into `num_workers` equal parts,
+ * one part per worker.
+ */
+struct ScatterFromWorker0Attrs : public tvm::AttrsNode<ScatterFromWorker0Attrs> {
+ // Number of workers; also the number of equal parts the buffer is chunked into.
+ int num_workers;
+
+ TVM_DECLARE_ATTRS(ScatterFromWorker0Attrs, "relax.attrs.ScatterFromWorker0Attrs") {
+ TVM_ATTR_FIELD(num_workers)
+ .describe(
+ "The number of workers, also the number of parts the given buffer should be chunked "
+ "into.");
+ }
+}; // struct ScatterFromWorker0Attrs
+
} // namespace relax
} // namespace tvm
diff --git a/python/tvm/relax/op/ccl/ccl.py b/python/tvm/relax/op/ccl/ccl.py
index 093e01b638..62e5914a89 100644
--- a/python/tvm/relax/op/ccl/ccl.py
+++ b/python/tvm/relax/op/ccl/ccl.py
@@ -58,3 +58,22 @@ def broadcast_from_worker0(x: Expr) -> Expr:
The same tensor, which has been broadcast to all other workers.
"""
return _ffi_api.broadcast_from_worker0(x)
+
+
+def scatter_from_worker0(x: Expr, num_workers: int) -> Expr:
+    """Perform a scatter operation from worker-0, chunking the given buffer into equal parts.
+
+    The buffer is chunked along axis 0, so the size of axis 0 of ``x`` is expected to be
+    divisible by ``num_workers``.
+
+    Parameters
+    ----------
+    x : relax.Expr
+        The buffer to be divided into equal parts and sent to each worker accordingly.
+
+    num_workers : int
+        The number of workers, i.e. the number of parts the given buffer should be chunked into.
+
+    Returns
+    -------
+    result : relax.Expr
+        Chunked Tensor received by different workers.
+    """
+    return _ffi_api.scatter_from_worker0(x, num_workers)
diff --git a/python/tvm/relax/transform/legalize_ops/ccl.py b/python/tvm/relax/transform/legalize_ops/ccl.py
index c9c09952b2..019f1726f0 100644
--- a/python/tvm/relax/transform/legalize_ops/ccl.py
+++ b/python/tvm/relax/transform/legalize_ops/ccl.py
@@ -16,9 +16,11 @@
# under the License.
# pylint: disable=invalid-name
"""Default legalization function for ccl operators."""
+from tvm import tir, arith
from ...block_builder import BlockBuilder
from ...expr import Call, Expr, ShapeExpr
from ...op import call_pure_packed
+from ...struct_info import TensorStructInfo, ShapeStructInfo
from .common import register_legalize
@@ -46,9 +48,39 @@ def _allreduce(_bb: BlockBuilder, call: Call) -> Expr:
@register_legalize("relax.ccl.broadcast_from_worker0")
-def broadcast_from_worker0(_bb: BlockBuilder, call: Call) -> Expr:
+def _broadcast_from_worker0(_bb: BlockBuilder, call: Call) -> Expr:
+    """Legalize ``relax.ccl.broadcast_from_worker0`` to a pure packed call.
+
+    The call targets ``runtime.disco.broadcast_from_worker0`` and reuses the input's
+    struct info as the result struct info.
+    """
    return call_pure_packed(
        "runtime.disco.broadcast_from_worker0",
        call.args[0],
        sinfo_args=call.args[0].struct_info,
    )
+
+
+@register_legalize("relax.ccl.scatter_from_worker0")
+def _scatter_from_worker0(_bb: BlockBuilder, call: Call) -> Expr:
+    """Legalize ``relax.ccl.scatter_from_worker0`` to a pure packed call.
+
+    The call targets ``runtime.disco.scatter_from_worker0``. The result struct info is
+    the input's, with the size of axis 0 divided by ``num_workers``.
+
+    Raises
+    ------
+    ValueError
+        If the input is not a tensor with a known shape of rank >= 1, or if the size of
+        axis 0 is provably not divisible by ``num_workers``.
+    """
+    input_sinfo = call.args[0].struct_info
+    num_workers = call.attrs.num_workers
+    if not isinstance(input_sinfo, TensorStructInfo):
+        raise ValueError(
+            "The input struct info of scatter_from_worker0 should be TensorStructInfo."
+        )
+    shape_sinfo = None if input_sinfo.shape is None else input_sinfo.shape.struct_info
+    if not isinstance(shape_sinfo, ShapeStructInfo) or shape_sinfo.values is None:
+        raise ValueError(
+            "The input tensor of scatter_from_worker0 should have a known shape."
+        )
+    arg_shape = list(shape_sinfo.values)
+    if not arg_shape:
+        raise ValueError(
+            "scatter_from_worker0 expects an input tensor with at least one axis."
+        )
+
+    # Only a remainder that simplifies to a non-zero constant is rejected. A symbolic
+    # remainder such as ``floormod(m, 2)`` cannot be decided here (and ``expr == 0`` on
+    # it is not a numeric test), so symbolic axis-0 sizes are accepted, matching the
+    # C++ struct info inference of this op.
+    modulo = arith.Analyzer().simplify(tir.floormod(arg_shape[0], num_workers))
+    if isinstance(modulo, tir.IntImm) and modulo.value != 0:
+        raise ValueError(
+            "scatter_from_worker0 expects the size of axis 0 of input tensor "
+            "to be divisible by num_workers. However, the axis 0 of input tensor "
+            f"is {arg_shape[0]} while num_workers is {num_workers}"
+        )
+    output_shape = [tir.div(arg_shape[0], num_workers)] + arg_shape[1:]
+    return call_pure_packed(
+        "runtime.disco.scatter_from_worker0",
+        call.args[0],
+        sinfo_args=TensorStructInfo(
+            shape=output_shape,
+            dtype=input_sinfo.dtype,
+            vdevice=input_sinfo.vdevice,
+        ),
+    )
diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc
index 6fa6c96db3..a114cb1dae 100644
--- a/src/relax/op/ccl/ccl.cc
+++ b/src/relax/op/ccl/ccl.cc
@@ -70,5 +70,53 @@ TVM_REGISTER_OP("relax.ccl.broadcast_from_worker0")
.set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise)
.set_attr<Bool>("FPurity", Bool(true));
+/* relax.ccl.scatter_from_worker0 */
+TVM_REGISTER_NODE_TYPE(ScatterFromWorker0Attrs);
+
+/*!
+ * \brief Build a call to the relax.ccl.scatter_from_worker0 operator.
+ * \param data The buffer to be chunked by worker-0 and sent to each worker.
+ * \param num_workers The number of workers, i.e. the number of equal chunks.
+ * \return A Call to relax.ccl.scatter_from_worker0 carrying ScatterFromWorker0Attrs.
+ */
+Expr scatter_from_worker0(Expr data, int num_workers) {
+  ObjectPtr<ScatterFromWorker0Attrs> attrs = make_object<ScatterFromWorker0Attrs>();
+  attrs->num_workers = num_workers;
+  static const Op& op = Op::Get("relax.ccl.scatter_from_worker0");
+
+  return Call(op, {std::move(data)}, Attrs{attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.ccl.scatter_from_worker0").set_body_typed(scatter_from_worker0);
+
+/*!
+ * \brief Infer the struct info of relax.ccl.scatter_from_worker0.
+ *
+ * The output keeps the input dtype (and vdevice, if any) and divides the size of axis 0
+ * by num_workers. A diagnostic is reported when num_workers is not positive, when the
+ * input is a scalar tensor, or when axis 0 is provably not divisible by num_workers.
+ */
+StructInfo InferStructInfoScatterFromWorker0(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  DataType output_dtype = input_sinfo->dtype;
+
+  const auto* attrs = call->attrs.as<ScatterFromWorker0Attrs>();
+  int num_workers = attrs->num_workers;
+  if (num_workers <= 0) {
+    // Guards the floormod/div below against a zero or negative divisor.
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "scatter_from_worker0 expects num_workers to be positive. However, "
+                        "num_workers is "
+                     << num_workers);
+  }
+
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  auto input_shape = input_sinfo->GetShape();
+  CHECK(input_shape.defined()) << "input tensor of scatter_from_worker0 should have defined shape.";
+  if (input_shape.value().empty()) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "scatter_from_worker0 expects the input tensor to have at least one "
+                        "axis, but got a scalar tensor.");
+  }
+
+  // Only a provably non-divisible axis 0 is rejected; a symbolic size such as `m` is
+  // accepted and the output size becomes div(m, num_workers).
+  PrimExpr axis0 = input_shape.value()[0];
+  if (analyzer->CanProve(floormod(axis0, PrimExpr(num_workers)) != 0)) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "scatter_from_worker0 expects the size of axis 0 of input tensor to be "
+                        "divisible by num_workers. However, the axis 0 of input tensor is "
+                     << axis0 << " while num_workers is " << num_workers);
+  }
+
+  Array<PrimExpr> output_shape = input_shape.value();
+  output_shape.Set(0, div(axis0, num_workers));
+  if (input_sinfo->vdevice.defined()) {
+    return TensorStructInfo(ShapeExpr(output_shape), output_dtype, input_sinfo->vdevice.value());
+  }
+  return TensorStructInfo(ShapeExpr(output_shape), output_dtype);
+}
+
+// Register relax.ccl.scatter_from_worker0: a single tensor input "x", attributes of type
+// ScatterFromWorker0Attrs, struct info inferred by InferStructInfoScatterFromWorker0, and
+// marked pure through the FPurity attribute.
+TVM_REGISTER_OP("relax.ccl.scatter_from_worker0")
+ .set_num_inputs(1)
+ .add_argument("x", "Tensor",
+ "The buffer to be divided into equal parts and sent to each worker accordingly.")
+ .set_attrs_type<ScatterFromWorker0Attrs>()
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScatterFromWorker0)
+ .set_attr<Bool>("FPurity", Bool(true));
+
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/ccl/ccl.h b/src/relax/op/ccl/ccl.h
index 55402f3d37..169b4b446c 100644
--- a/src/relax/op/ccl/ccl.h
+++ b/src/relax/op/ccl/ccl.h
@@ -38,6 +38,9 @@ Expr allreduce(Expr data, String op_type);
/*! \brief Broadcast data from worker-0 to all other workers. */
Expr broadcast_from_worker0(Expr data);
+/*!
+ * \brief Perform a scatter operation from worker-0, chunking the given buffer into equal parts.
+ * \param data The buffer to be chunked and sent to each worker.
+ * \param num_workers The number of workers, i.e. the number of parts to chunk the buffer into.
+ * \return A Call to the relax.ccl.scatter_from_worker0 operator.
+ */
+Expr scatter_from_worker0(Expr data, int num_workers);
+
} // namespace relax
} // namespace tvm
diff --git a/tests/python/relax/test_op_ccl.py b/tests/python/relax/test_op_ccl.py
index 09924d27ec..1c98df4ea7 100644
--- a/tests/python/relax/test_op_ccl.py
+++ b/tests/python/relax/test_op_ccl.py
@@ -160,5 +160,62 @@ def test_broadcast_from_worker0_infer_struct_info_more_input_dtype():
)
+def test_scatter_from_worker0_infer_struct_info():
+    """Static shapes: axis 0 is divided by num_workers, other axes and dtype are kept."""
+    bb = relax.BlockBuilder()
+    cases = [
+        (
+            relax.Var("x", R.Tensor((2, 3), "float32")),
+            2,
+            relax.TensorStructInfo((1, 3), "float32"),
+        ),
+        (
+            relax.Var("x", R.Tensor((3, 4, 5))),
+            3,
+            relax.TensorStructInfo((1, 4, 5), dtype=""),
+        ),
+    ]
+    for inp, num_workers, expected in cases:
+        _check_inference(bb, relax.op.ccl.scatter_from_worker0(inp, num_workers), expected)
+
+
+def test_scatter_from_worker0_infer_struct_info_shape_symbolic():
+    """Symbolic dims: a symbolic axis 0 is divided symbolically, others pass through."""
+    bb = relax.BlockBuilder()
+    m = tir.Var("m", "int64")
+    n = tir.Var("n", "int64")
+    symbolic_axis0 = relax.Var("x", R.Tensor((m, n), "float32"))
+    static_axis0 = relax.Var("x", R.Tensor((4, n), "float32"))
+
+    expected_symbolic = relax.TensorStructInfo((tir.div(m, 2), n), "float32")
+    _check_inference(bb, relax.op.ccl.scatter_from_worker0(symbolic_axis0, 2), expected_symbolic)
+
+    expected_static = relax.TensorStructInfo((2, n), "float32")
+    _check_inference(bb, relax.op.ccl.scatter_from_worker0(static_axis0, 2), expected_static)
+
+
+def test_scatter_from_worker0_infer_struct_info_shape_var():
+    """A tensor whose shape comes from a shape variable with known values."""
+    bb = relax.BlockBuilder()
+    shape_var = relax.Var("s", relax.ShapeStructInfo((2, 4, 8)))
+    tensor = relax.Var("x", relax.TensorStructInfo(shape_var, "float32"))
+    scatter_call = relax.op.ccl.scatter_from_worker0(tensor, 2)
+    _check_inference(bb, scatter_call, relax.TensorStructInfo((1, 4, 8), "float32"))
+
+
+def test_scatter_from_worker0_infer_struct_info_more_input_dtype():
+    """The output dtype always matches the input dtype."""
+    bb = relax.BlockBuilder()
+    dtypes = ["float64", "int8", "int64"]
+    inputs = [relax.Var("x", R.Tensor((2, 3), dtype)) for dtype in dtypes]
+    for dtype, inp in zip(dtypes, inputs):
+        _check_inference(
+            bb, relax.op.ccl.scatter_from_worker0(inp, 2), relax.TensorStructInfo((1, 3), dtype)
+        )
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_ccl.py b/tests/python/relax/test_transform_legalize_ops_ccl.py
index 9bce76cecb..b1da283e70 100644
--- a/tests/python/relax/test_transform_legalize_ops_ccl.py
+++ b/tests/python/relax/test_transform_legalize_ops_ccl.py
@@ -72,5 +72,26 @@ def test_broadcast_from_zero():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_scatter_from_worker0():
+    """LegalizeOps lowers R.ccl.scatter_from_worker0 to call_pure_packed of
+    runtime.disco.scatter_from_worker0, with axis 0 divided by num_workers (10 -> 5)."""
+    # fmt: off
+    @tvm.script.ir_module
+    class ScatterFromWorker0:
+        @R.function
+        def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((5, 10), "float32"):
+            gv0: R.Tensor((5, 10), "float32") = R.ccl.scatter_from_worker0(x, 2)
+            return gv0
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((5, 10), dtype="float32"):
+            gv0: R.Tensor((5, 10), dtype="float32") = R.call_pure_packed("runtime.disco.scatter_from_worker0", x, sinfo_args=R.Tensor((5, 10), dtype="float32"))
+            return gv0
+    # fmt: on
+
+    mod = LegalizeOps()(ScatterFromWorker0)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()