Posted to commits@tvm.apache.org by ho...@apache.org on 2024/02/02 18:57:28 UTC

(tvm) branch main updated: [Relax][Frontend] "tensor_ir_inplace" op (#16498)

This is an automated email from the ASF dual-hosted git repository.

hongyij pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5c68932cae [Relax][Frontend] "tensor_ir_inplace" op (#16498)
5c68932cae is described below

commit 5c68932cae820864cca10197ad2889531fca63ec
Author: Ruihang Lai <ru...@cs.cmu.edu>
AuthorDate: Fri Feb 2 13:57:21 2024 -0500

    [Relax][Frontend] "tensor_ir_inplace" op (#16498)
    
    This PR introduces the `tensor_ir_inplace_op` for the frontend
    so that we can leverage `call_tir_inplace` in the SLM model
    definition flow.
    
    One unit test is added. This PR also fixes a few naming typos
    in docstrings.
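    
    For reference, a minimal usage sketch of the new op (the PrimFunc
    `take_inplace` and the tensors below are illustrative, not part of
    this patch):
    
        out = op.tensor_ir_inplace_op(
            take_inplace,            # a TIR PrimFunc
            "take_inplace",          # name hint for the binding
            args=[table, ids, dst],  # nn.Tensor / PrimExpr arguments
            inplace_indices=[2],     # output 0 aliases args[2] (dst)
            out=Tensor.placeholder(dst.shape, dst.dtype),
        )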
---
 python/tvm/relax/frontend/nn/op.py        |  69 ++++++++++++++++++++
 python/tvm/relax/op/base.py               |  20 +++---
 tests/python/relax/test_frontend_nn_op.py | 105 ++++++++++++++++++++++++++++++
 3 files changed, 184 insertions(+), 10 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index 720e3dd3b4..fbca48f0ee 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -1629,6 +1629,75 @@ def tensor_ir_op(
     )
 
 
+def tensor_ir_inplace_op(
+    func: _tir.PrimFunc,
+    name_hint: str,
+    args: Union[Tensor, Sequence[Union[Tensor, rx.ShapeExpr, _tir.PrimExpr]]],
+    inplace_indices: Union[int, List[int]],
+    out: OutType,
+) -> OutType:
+    """Create a `call_tir_inplace` binding with given PrimFunc
+
+    Parameters
+    ----------
+    func : _tir.PrimFunc
+        The PrimFunc to call.
+
+    name_hint : str
+        The name hint of the emitted binding.
+
+    args : Union[Tensor, Sequence[Union[Tensor, rx.ShapeExpr, _tir.PrimExpr]]]
+        The arguments to pass to the PrimFunc.
+
+    inplace_indices : Union[int, List[int]]
+        Specify which arguments should be used for in-place computations.
+        If `inplace_indices` is a single integer, it will be made into a singleton list.
+        Suppose `inplace_indices[i] = j`, where `j >= 0`. Then the `i`th output
+        will be an alias of `args[j]`.
+        If `inplace_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor.
+        At least one member of `inplace_indices` must not be -1.
+
+    out : Union[Tensor, List[Tensor]]
+        The output tensors.
+
+    Returns
+    -------
+    result : Union[Tensor, List[Tensor]]
+        The result tensor(s), matching the structure of `out`.
+    """
+    from tvm import relax as rx  # pylint: disable=import-outside-toplevel
+
+    call_tir_args, tir_vars = [], []
+    if not isinstance(args, (tuple, list)):
+        args = [args]
+
+    for arg in args:
+        if isinstance(arg, Tensor):
+            call_tir_args.append(arg._expr)
+        elif isinstance(arg, (rx.ShapeExpr, _tir.PrimExpr)):
+            tir_vars.append(arg)
+        else:
+            raise TypeError(
+                "Unsupported type: tensor_ir_inplace_op args expect Tensor or ShapeExpr or"
+                f" PrimExpr, but got {type(arg)}"
+            )
+
+    if isinstance(out, Tensor):
+        out_sinfo = [out._expr.struct_info]
+    else:
+        out_sinfo = [x._expr.struct_info for x in out]
+
+    bb = BlockBuilder.current()
+    global_var = bb.add_func(func, name_hint)
+
+    return wrap_nested(
+        bb.emit(
+            rx.call_tir_inplace(global_var, call_tir_args, inplace_indices, out_sinfo, tir_vars)
+        ),
+        name=name_hint,
+    )
+
+
 def extern(
     name: str,
     args: Sequence[Union[Tensor, _tir.PrimExpr, int, float, str]],
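Note: in `tensor_ir_inplace_op` above, positional arguments are split by
type: `Tensor` arguments become `call_tir` arguments, while `ShapeExpr` and
`PrimExpr` arguments are collected into `tir_vars`. A sketch of the Relax
binding this emits (names and shapes are illustrative, not part of this
patch):

    # args=[x, y, n] with x, y : Tensor and n : PrimExpr emits roughly:
    lv = R.call_tir_inplace(
        cls.my_prim_func,
        (x, y),                     # Tensor args
        out_sinfo=R.Tensor((4, 4), "float32"),
        inplace_indices=[1],        # the output aliases y
        tir_vars=R.shape([n]),      # PrimExpr args
    )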
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index b363dc6952..92235ffb47 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -198,13 +198,13 @@ def call_tir_inplace(
     args : Expr
         The input arguments.
 
-    input_indices : Union[int, List[int]]
+    inplace_indices : Union[int, List[int]]
         Specify which arguments should be used for in-place computations.
-        If `input_indices` is a single integer, it will be made into a singleton list.
-        Suppose `input_indices[i] = j`, where `j >= 0`. Then the `i`th output
+        If `inplace_indices` is a single integer, it will be made into a singleton list.
+        Suppose `inplace_indices[i] = j`, where `j >= 0`. Then the `i`th output
         will be an alias of `args[j]`.
-        If `input_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor.
-        At least one member of `input_indices` must not be -1.
+        If `inplace_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor.
+        At least one member of `inplace_indices` must not be -1.
 
     out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]]
         The structure info of the call_tir_inplace output.
@@ -637,13 +637,13 @@ def call_inplace_packed(
     args: Expr
       The arguments for the PackedFunc.
 
-    input_indices : Union[int, List[int]]
+    inplace_indices : Union[int, List[int]]
       Specify which arguments should be used for in-place computations.
-      If `input_indices` is a single integer, it will be made into a singleton list.
-      Suppose `input_indices[i] = j`, where `j >= 0`. Then the `i`th output
+      If `inplace_indices` is a single integer, it will be made into a singleton list.
+      Suppose `inplace_indices[i] = j`, where `j >= 0`. Then the `i`th output
       will be an alias of `args[j]`.
-      If `input_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor.
-      At least one member of `input_indices` must not be -1.
+      If `inplace_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor.
+      At least one member of `inplace_indices` must not be -1.
 
     sinfo_args: Union[StructInfo, List[StructInfo]]
         The list of structure info arguments (giving the structural info for the returned value).
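To make the `inplace_indices` convention concrete, a sketch (variable names
and shapes are illustrative, not part of this patch): with
`inplace_indices=[-1, 0]`, output 0 is freshly allocated while output 1
aliases `args[0]`:

    from tvm import relax

    call = relax.op.call_tir_inplace(
        gvar,                       # GlobalVar of the PrimFunc
        relax.Tuple([x, y]),        # x, y : relax.Var
        inplace_indices=[-1, 0],    # out[0] fresh, out[1] aliases x
        out_sinfo=[
            relax.TensorStructInfo((4,), "float32"),  # freshly allocated
            x.struct_info,                            # same layout as x
        ],
    )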
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index ed2e3753b2..c74e06490f 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -619,6 +619,111 @@ def test_tensor_ir_op():
     tvm.ir.assert_structural_equal(irmodule, Expected)
 
 
+def test_tensor_ir_inplace_op():
+    hidden_size = 4096
+    dtype = "float16"
+
+    @T.prim_func
+    def inplace_take(
+        var_weight: T.handle, var_pos: T.handle, var_embeddings: T.handle, offset: T.int64
+    ):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        vocab_size = T.int64()
+        weight = T.match_buffer(var_weight, (vocab_size, hidden_size), dtype)
+        seq_len = T.int64()
+        total_seq_len = T.int64()
+        pos = T.match_buffer(var_pos, (seq_len,), "int32")
+        embeddings = T.match_buffer(var_embeddings, (total_seq_len, hidden_size), dtype)
+        for ax0, ax1 in T.grid(seq_len, hidden_size):
+            with T.block("T_take"):
+                v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(weight[pos[v0], v1], pos[v0])
+                T.writes(embeddings[v0 + offset, v1])
+                embeddings[v0 + offset, v1] = weight[pos[v0], v1]
+
+    class Model(Module):
+        def test(
+            self, embedding_table: Tensor, input_ids: Tensor, embedding_dst: Tensor, offset: int
+        ):
+            tensor_expr_op_out = op.tensor_ir_inplace_op(
+                inplace_take,
+                "inplace_take",
+                args=[embedding_table, input_ids, embedding_dst, offset],
+                inplace_indices=[2],
+                out=Tensor.placeholder(embedding_dst.shape, embedding_dst.dtype),
+            )
+            return tensor_expr_op_out
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def inplace_take(
+            var_weight: T.handle, var_pos: T.handle, var_embeddings: T.handle, offset: T.int64
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            vocab_size = T.int64()
+            weight = T.match_buffer(var_weight, (vocab_size, hidden_size), dtype)
+            seq_len = T.int64()
+            total_seq_len = T.int64()
+            pos = T.match_buffer(var_pos, (seq_len,), "int32")
+            embeddings = T.match_buffer(var_embeddings, (total_seq_len, hidden_size), dtype)
+            for ax0, ax1 in T.grid(seq_len, hidden_size):
+                with T.block("T_take"):
+                    v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(weight[pos[v0], v1], pos[v0])
+                    T.writes(embeddings[v0 + offset, v1])
+                    embeddings[v0 + offset, v1] = weight[pos[v0], v1]
+
+        @R.function
+        def _initialize_effect() -> R.Tuple(R.Object):
+            with R.dataflow():
+                _io: R.Object = R.null_value()
+                lv: R.Tuple(R.Object) = (_io,)
+                gv: R.Tuple(R.Object) = lv
+                R.output(gv)
+            return gv
+
+        @R.function
+        def test(
+            embedding_table: R.Tensor(("vocab_size", hidden_size), dtype),
+            input_ids: R.Tensor(("seq_len",), "int32"),
+            embedding_dst: R.Tensor(("total_seq_len", hidden_size), dtype),
+            offset: R.Shape(["offset_1"]),
+            packed_params: R.Tuple,
+        ) -> R.Tensor(("total_seq_len", hidden_size), dtype):
+            total_seq_len = T.int64()
+            offset_1 = T.int64()
+            R.func_attr({"num_input": 4})
+            cls = Expected
+            with R.dataflow():
+                lv1 = R.call_tir_inplace(
+                    cls.inplace_take,
+                    (embedding_table, input_ids, embedding_dst),
+                    out_sinfo=R.Tensor((total_seq_len, hidden_size), dtype),
+                    inplace_indices=[2],
+                    tir_vars=R.shape([offset_1]),
+                )
+                gv1: R.Tensor((total_seq_len, hidden_size), dtype) = lv1
+                R.output(gv1)
+            return gv1
+
+    m = Model()
+    irmodule, _ = m.export_tvm(
+        spec={
+            "test": {
+                "embedding_table": spec.Tensor(["vocab_size", hidden_size], dtype),
+                "input_ids": spec.Tensor(["seq_len"], "int32"),
+                "embedding_dst": spec.Tensor(["total_seq_len", hidden_size], dtype),
+                "offset": int,
+                "$": {
+                    "param_mode": "packed",
+                    "effect_mode": "none",
+                },
+            },
+        },
+        debug=True,
+    )
+    tvm.ir.assert_structural_equal(irmodule, Expected)
+
+
 def test_extern():
     class Model(Module):
         def test(self, q: Tensor, k: Tensor, v: Tensor):