You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ho...@apache.org on 2024/02/02 18:57:28 UTC
(tvm) branch main updated: [Relax][Frontend] "tensor_ir_inplace" op (#16498)
This is an automated email from the ASF dual-hosted git repository.
hongyij pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5c68932cae [Relax][Frontend] "tensor_ir_inplace" op (#16498)
5c68932cae is described below
commit 5c68932cae820864cca10197ad2889531fca63ec
Author: Ruihang Lai <ru...@cs.cmu.edu>
AuthorDate: Fri Feb 2 13:57:21 2024 -0500
[Relax][Frontend] "tensor_ir_inplace" op (#16498)
This PR introduces the `tensor_ir_inplace_op` for the frontend
so that we can leverage our `call_tir_inplace` in the SLM model
definition flow.
One unit test is added. This PR also fixes a few typos in
type annotations.
---
python/tvm/relax/frontend/nn/op.py | 69 ++++++++++++++++++++
python/tvm/relax/op/base.py | 20 +++---
tests/python/relax/test_frontend_nn_op.py | 105 ++++++++++++++++++++++++++++++
3 files changed, 184 insertions(+), 10 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index 720e3dd3b4..fbca48f0ee 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -1629,6 +1629,75 @@ def tensor_ir_op(
)
+def tensor_ir_inplace_op(
+ func: _tir.PrimFunc,
+ name_hint: str,
+ args: Union[Tensor, Sequence[Union[Tensor, rx.ShapeExpr, _tir.PrimExpr]]],
+ inplace_indices: Union[int, List[int]],
+ out: OutType,
+) -> OutType:
+ """Create a `call_tir_inplace` binding with given PrimFunc
+
+ Parameters
+ ----------
+ func : _tir.PrimFunc
+ The PrimFunc to call.
+
+ name_hint : str
+ Name hint.
+
+ args : Union[Tensor, Sequence[Union[Tensor, rx.ShapeExpr, _tir.PrimExpr]]]
+ The arguments to pass to the PrimFunc.
+
+ inplace_indices : Union[int, List[int]]
+ Specify which arguments should be used for in-place computations.
+ If `inplace_indices` is a single integer, it will be made into a singleton list.
+ Suppose `inplace_indices[i] = j`, where `j >= 0`. Then the `i`th output
+ will be an alias of `args[j]`.
+ If `inplace_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor.
+ At least one member of `inplace_indices` must not be -1.
+
+ out : Union[Tensor, List[Tensor]]
+ The output tensors.
+
+ Returns
+ -------
+ result : Tensor
+ The result tensor
+ """
+ from tvm import relax as rx # pylint: disable=import-outside-toplevel
+
+ call_tir_args, tir_vars = [], []
+ if not isinstance(args, (tuple, list)):
+ args = [args]
+
+ for arg in args:
+ if isinstance(arg, Tensor):
+ call_tir_args.append(arg._expr)
+ elif isinstance(arg, (rx.ShapeExpr, _tir.PrimExpr)):
+ tir_vars.append(arg)
+ else:
+ raise TypeError(
+ "Unsupported type: tensor_ir_inplace_op args expect Tensor or ShapeExpr or"
+ f" PrimExpr, but got {type(arg)}"
+ )
+
+ if isinstance(out, Tensor):
+ out_sinfo = [out._expr.struct_info]
+ else:
+ out_sinfo = [x._expr.struct_info for x in out]
+
+ bb = BlockBuilder.current()
+ global_var = bb.add_func(func, name_hint)
+
+ return wrap_nested(
+ bb.emit(
+ rx.call_tir_inplace(global_var, call_tir_args, inplace_indices, out_sinfo, tir_vars)
+ ),
+ name=name_hint,
+ )
+
+
def extern(
name: str,
args: Sequence[Union[Tensor, _tir.PrimExpr, int, float, str]],
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index b363dc6952..92235ffb47 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -198,13 +198,13 @@ def call_tir_inplace(
args : Expr
The input arguments.
- input_indices : Union[int, List[int]]
+ inplace_indices : Union[int, List[int]]
Specify which arguments should be used for in-place computations.
- If `input_indices` is a single integer, it will be made into a singleton list.
- Suppose `input_indices[i] = j`, where `j >= 0`. Then the `i`th output
+ If `inplace_indices` is a single integer, it will be made into a singleton list.
+ Suppose `inplace_indices[i] = j`, where `j >= 0`. Then the `i`th output
will be an alias of `args[j]`.
- If `input_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor.
- At least one member of `input_indices` must not be -1.
+ If `inplace_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor.
+ At least one member of `inplace_indices` must not be -1.
out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]]
The structure info of the call_tir_inplace output.
@@ -637,13 +637,13 @@ def call_inplace_packed(
args: Expr
The arguments for the PackedFunc.
- input_indices : Union[int, List[int]]
+ inplace_indices : Union[int, List[int]]
Specify which arguments should be used for in-place computations.
- If `input_indices` is a single integer, it will be made into a singleton list.
- Suppose `input_indices[i] = j`, where `j >= 0`. Then the `i`th output
+ If `inplace_indices` is a single integer, it will be made into a singleton list.
+ Suppose `inplace_indices[i] = j`, where `j >= 0`. Then the `i`th output
will be an alias of `args[j]`.
- If `input_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor.
- At least one member of `input_indices` must not be -1.
+ If `inplace_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor.
+ At least one member of `inplace_indices` must not be -1.
sinfo_args: Union[StructInfo, List[StructInfo]]
The list of structure info arguments (giving the structural info for the returned value).
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index ed2e3753b2..c74e06490f 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -619,6 +619,111 @@ def test_tensor_ir_op():
tvm.ir.assert_structural_equal(irmodule, Expected)
+def test_tensor_ir_inplace_op():
+ hidden_size = 4096
+ dtype = "float16"
+
+ @T.prim_func
+ def inplace_take(
+ var_weight: T.handle, var_pos: T.handle, var_embeddings: T.handle, offset: T.int64
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ vocab_size = T.int64()
+ weight = T.match_buffer(var_weight, (vocab_size, hidden_size), dtype)
+ seq_len = T.int64()
+ total_seq_len = T.int64()
+ pos = T.match_buffer(var_pos, (seq_len,), "int32")
+ embeddings = T.match_buffer(var_embeddings, (total_seq_len, hidden_size), dtype)
+ for ax0, ax1 in T.grid(seq_len, hidden_size):
+ with T.block("T_take"):
+ v0, v1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(weight[pos[v0], v1], pos[v0])
+ T.writes(embeddings[v0, v1])
+ embeddings[v0 + offset, v1] = weight[pos[v0], v1]
+
+ class Model(Module):
+ def test(
+ self, embedding_table: Tensor, input_ids: Tensor, embedding_dst: Tensor, offset: int
+ ):
+ tensor_expr_op_out = op.tensor_ir_op(
+ inplace_take,
+ "inplace_take",
+ args=[embedding_table, input_ids, embedding_dst, offset],
+ out=Tensor.placeholder(embedding_dst.shape, embedding_dst.dtype),
+ )
+ return tensor_expr_op_out
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def inplace_take(
+ var_weight: T.handle, var_pos: T.handle, var_embeddings: T.handle, offset: T.int64
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ vocab_size = T.int64()
+ weight = T.match_buffer(var_weight, (vocab_size, hidden_size), dtype)
+ seq_len = T.int64()
+ total_seq_len = T.int64()
+ pos = T.match_buffer(var_pos, (seq_len,), "int32")
+ embeddings = T.match_buffer(var_embeddings, (total_seq_len, hidden_size), dtype)
+ for ax0, ax1 in T.grid(seq_len, hidden_size):
+ with T.block("T_take"):
+ v0, v1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(weight[pos[v0], v1], pos[v0])
+ T.writes(embeddings[v0, v1])
+ embeddings[v0 + offset, v1] = weight[pos[v0], v1]
+
+ @R.function
+ def _initialize_effect() -> R.Tuple(R.Object):
+ with R.dataflow():
+ _io: R.Object = R.null_value()
+ lv: R.Tuple(R.Object) = (_io,)
+ gv: R.Tuple(R.Object) = lv
+ R.output(gv)
+ return gv
+
+ @R.function
+ def test(
+ embedding_table: R.Tensor(("vocab_size", hidden_size), dtype),
+ input_ids: R.Tensor(("seq_len",), "int32"),
+ embedding_dst: R.Tensor(("total_seq_len", hidden_size), dtype),
+ offset: R.Shape(["offset_1"]),
+ packed_params: R.Tuple,
+ ) -> R.Tensor(("total_seq_len", hidden_size), dtype):
+ total_seq_len = T.int64()
+ offset_1 = T.int64()
+ R.func_attr({"num_input": 4})
+ cls = Expected
+ with R.dataflow():
+ lv1 = R.call_tir(
+ cls.inplace_take,
+ (embedding_table, input_ids, embedding_dst),
+ out_sinfo=R.Tensor((total_seq_len, hidden_size), dtype),
+ tir_vars=R.shape([offset_1]),
+ )
+ gv1: R.Tensor((total_seq_len, hidden_size), dtype) = lv1
+ R.output(gv1)
+ return gv1
+
+ m = Model()
+ irmodule, _ = m.export_tvm(
+ spec={
+ "test": {
+ "embedding_table": spec.Tensor(["vocab_size", hidden_size], dtype),
+ "input_ids": spec.Tensor(["seq_len"], "int32"),
+ "embedding_dst": spec.Tensor(["total_seq_len", hidden_size], dtype),
+ "offset": int,
+ "$": {
+ "param_mode": "packed",
+ "effect_mode": "none",
+ },
+ },
+ },
+ debug=True,
+ )
+ tvm.ir.assert_structural_equal(irmodule, Expected)
+
+
def test_extern():
class Model(Module):
def test(self, q: Tensor, k: Tensor, v: Tensor):