You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by kp...@apache.org on 2023/07/11 16:37:57 UTC

[tvm] branch main updated: [TIR] Implement TIR macros (#15260)

This is an automated email from the ASF dual-hosted git repository.

kparzysz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fddbec7079 [TIR] Implement TIR macros (#15260)
fddbec7079 is described below

commit fddbec7079a817d3339b45fedd3c5b8326cfafac
Author: Krzysztof Parzyszek <kp...@quicinc.com>
AuthorDate: Tue Jul 11 11:37:51 2023 -0500

    [TIR] Implement TIR macros (#15260)
    
    * [TIR] Implement TIR macros
    
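    This patch introduces two new symbols: `T.macro` and `T.insert`.
    `T.macro` is a decorator that, when applied to a function, turns the
    body of that function into a piece of TIR that can be inserted via
    `T.insert` into a PrimFunc. (Later in this series `T.insert` was
    removed; a macro is now expanded by calling it like a function, e.g.
    `copy_backwards(A, B, 128)`.)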
    
    For example:
    
    ```python
    @T.macro
    def copy_backwards(dst, src, size):
        with T.block("backwards"):
            for i in T.serial(size):
                ai = T.axis.remap("S", [i])
                T.reads(src[0:size])
                T.writes(dst[0:size])
                dst[ai] = src[size - ai - 1]
    
    @T.prim_func
    def foo_int32(A: T.Buffer((128,), "int32"), B: T.Buffer((128,), "int32")):
        T.insert(copy_backwards, A, B, 128)
    
    @T.prim_func
    def foo_int8(A: T.Buffer((128,), "int8"), B: T.Buffer((128,), "int8")):
        T.insert(copy_backwards, A, B, 128)
    ```
    
    The above will generate two PrimFuncs that do the same backwards copy,
    but applied to buffers with different data types.
    
    Semantics:
    - A function decorated with @T.macro can have any parameters that
      follow Python syntax, e.g. positional, keyword, etc. Type annotations
      are not required, but are allowed.
    - The arguments to `T.insert` are the macro name followed by the
      macro's argument list.
      For `T.insert(arg1, arg2, arg3, ...)`, the values are substituted into
      the body of the macro as in the call `arg1(arg2, arg3, ...)`.
      The body with the substituted values is then inserted at the point
      where the `T.insert` call is located.
    
    * Fix linter
    
    * Fix linter again
    
    One linter suggested something that the other didn't like...
    
    * Get rid of T.insert, apply macro via function-call syntax
    
    * Store closure vars in TIRMacro
    
    * ast.parse always returns ast.Module, hence doc is doc.Module
    
    * Simplify `expand_macro`, capture environment variables
    
    * Implement macro hygiene
    
    * Fix linter
    
    * Make T.macro work same as T.macro()
    
    The previous commit inadvertently made T.macro (without parentheses)
    illegal; the only accepted form was T.macro(). Restore T.macro
    as a valid decorator use.
    
    * Edit comment: insertion -> expansion
    
    * Add import pytest
    
    * One more typo...
    
    * Remove stale testcase
---
 python/tvm/script/parser/_core.py                  |   2 +-
 python/tvm/script/parser/core/entry.py             |  31 +++---
 python/tvm/script/parser/tir/__init__.py           |   4 +-
 python/tvm/script/parser/tir/entry.py              |  99 ++++++++++++++++++-
 python/tvm/script/parser/tir/parser.py             |  60 +++++++++++-
 tests/python/unittest/test_tvmscript_parser_tir.py | 107 +++++++++++++++++++++
 6 files changed, 286 insertions(+), 17 deletions(-)

diff --git a/python/tvm/script/parser/_core.py b/python/tvm/script/parser/_core.py
index 4f5411dc36..b7ba5ee471 100644
--- a/python/tvm/script/parser/_core.py
+++ b/python/tvm/script/parser/_core.py
@@ -18,5 +18,5 @@
 # pylint: disable=unused-import
 from .core import dispatch, doc, utils
 from .core.dispatch import OpMethod, register_op
-from .core.entry import parse
+from .core.entry import parse, parse_macro
 from .core.parser import Parser
diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py
index 5315c0f675..08a593d5d3 100644
--- a/python/tvm/script/parser/core/entry.py
+++ b/python/tvm/script/parser/core/entry.py
@@ -25,6 +25,25 @@ from .error import ParserError
 from .parser import Parser
 
 
+def _default_globals() -> Dict[str, Any]:
+    """Return the globals TVMScript sees by default: tvm, plus I/ir and T/tir aliases."""
+    import tvm  # pylint: disable=import-outside-toplevel
+    from tvm.script.parser import ir  # pylint: disable=import-outside-toplevel
+    from tvm.script.parser import tir  # pylint: disable=import-outside-toplevel
+
+    return {"tvm": tvm, "I": ir, "ir": ir, "T": tir, "tir": tir}
+
+
+def parse_macro(program: Union[Any, str], extra_vars: Dict[str, Any] = None) -> Any:
+    """Return the macro's AST, its source text (used by __repr__), and closure variables."""
+    # No TIR is generated here; the AST is converted into TIR when the macro is expanded.
+    source = Source(program)
+    source_txt = source.source
+    source_ast = source.as_ast()
+    closure_vars = extra_vars or _default_globals()  # an empty dict also falls back
+    return source_ast, source_txt, closure_vars
+
+
 def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) -> Any:
     """Register a method for a operand type, AST operator node and operand index.
 
@@ -42,17 +61,7 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None)
         The parsed TVMScript program.
     """
     if extra_vars is None:
-        import tvm  # pylint: disable=import-outside-toplevel
-        from tvm.script.parser import ir  # pylint: disable=import-outside-toplevel
-        from tvm.script.parser import tir  # pylint: disable=import-outside-toplevel
-
-        extra_vars = {
-            "tvm": tvm,
-            "I": ir,
-            "ir": ir,
-            "T": tir,
-            "tir": tir,
-        }
+        extra_vars = _default_globals()
 
     ann = {}
     if inspect.isfunction(program):
diff --git a/python/tvm/script/parser/tir/__init__.py b/python/tvm/script/parser/tir/__init__.py
index ad16821a89..9d3fc1ec98 100644
--- a/python/tvm/script/parser/tir/__init__.py
+++ b/python/tvm/script/parser/tir/__init__.py
@@ -30,6 +30,6 @@ if TYPE_CHECKING:
     # so most tvmscript won't trigger pylint error here.
     prim_func = staticmethod
 else:
-    from .entry import prim_func
+    from .entry import prim_func, macro
 
-__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func"]
+__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func", "macro"]
diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py
index d5bff7a856..64b71d699f 100644
--- a/python/tvm/script/parser/tir/entry.py
+++ b/python/tvm/script/parser/tir/entry.py
@@ -16,13 +16,13 @@
 # under the License.
 """The entry point of TVM parser for tir."""
 import inspect
-from typing import Callable, Union
+from typing import Any, Callable, Dict, Union
 
 from tvm.ir.base import deprecated
 from tvm.tir import Buffer, PrimFunc
 
 from ...ir_builder.tir import buffer, ptr
-from .._core import parse, utils
+from .._core import doc, parse, parse_macro, utils
 
 
 def prim_func(func: Callable) -> Union[PrimFunc, Callable]:
@@ -50,6 +50,101 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]:
 setattr(prim_func, "dispatch_token", "tir")
 
 
+# Semantics of TIR macros:
+# - A function decorated with @T.macro can have any parameters that
+#   follow Python syntax, e.g. positional, keyword, etc. Type annotations
+#   are not required, but are allowed.
+# - Macro use follows the same syntax as a function call.
+#   For `macro_name(arg1, arg2, arg3, ...)`, the values are substituted into
+#   the body of the macro, and the body with the substituted values is then
+#   expanded at the point where the call to the macro is located.
+
+
+class TIRMacro:
+    """A T.macro definition: parsed AST, source text, captured variables, hygiene flag."""
+
+    def __init__(
+        self,
+        source_ast: doc.AST,
+        source_txt: str,
+        closure_vars: Dict[str, Any],
+        func: Callable,
+        hygienic: bool,
+    ) -> None:
+        self.source_ast = source_ast  # module AST containing the macro's FunctionDef
+        self.source_txt = source_txt  # original source text, returned by __repr__
+        self.closure_vars = closure_vars  # names captured at definition time
+        self.func = func  # the undecorated Python function; its signature binds call args
+        self.hygienic = hygienic  # if True, body symbols resolve only via closure_vars/args
+
+    def __repr__(self) -> str:
+        return self.source_txt
+
+
+def macro(*args, hygienic: bool = True) -> Union[TIRMacro, Callable]:
+    """Decorator for TIR macro definitions.
+
+    Parameters
+    ----------
+    hygienic: bool
+        Specifies whether the macro is hygienic or not.
+        A macro is hygienic if all symbols used in the macro's body are resolved
+        to values from the location of the macro definition. A non-hygienic macro
+        will have its symbols resolved to values at the time of the macro's use.
+
+        Example:
+        ```
+        import tvm
+        from tvm.script import tir as T
+
+        x_value = 128
+
+        @T.macro(hygienic=True)
+        def static_capture(A, B):
+            B[()] = A[x_value]          ### x_value binds to 128
+
+        @T.macro(hygienic=False)
+        def dynamic_capture(A, B):
+            B[()] = A[x_value]          ### x_value will bind at the time of use
+
+
+        @T.prim_func
+        def use1(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
+            for x_value in T.serial(10):
+                static_capture(A, B)    ### Produces B[()] = A[128]
+
+        @T.prim_func
+        def use2(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
+            for x_value in T.serial(10):
+                dynamic_capture(A, B)   ### Produces B[()] = A[x_value]
+        ```
+    """
+
+    def _decorator(func: Callable) -> TIRMacro:
+        source_ast, source_txt, closure_vars = parse_macro(
+            func, utils.inspect_function_capture(func)
+        )
+        obj = TIRMacro(source_ast, source_txt, closure_vars, func, hygienic)
+        obj.__name__ = func.__name__
+        # We don't need to explicitly store the return value anywhere.
+        # This function is a decorator, so the return value will replace
+        # the function definition (to which the decorator is applied)
+        # in that function's namespace.
+        return obj
+
+    if len(args) == 0:  # @T.macro() or @T.macro(hygienic=...)
+        return _decorator
+    if len(args) == 1 and inspect.isfunction(args[0]):  # bare @T.macro
+        return _decorator(args[0])
+
+    raise ValueError(
+        "Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])"
+    )
+
+
+# There is no dispatch_token for macro, because defining a macro does not invoke the parser.
+
+
 class BufferProxy:
     """Buffer proxy class for constructing tir buffer."""
 
diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py
index f81f9bd9ea..67e14d0e97 100644
--- a/python/tvm/script/parser/tir/parser.py
+++ b/python/tvm/script/parser/tir/parser.py
@@ -17,8 +17,9 @@
 """The base parser for tir"""
 
 import contextlib
+import inspect
 from functools import partial
-from typing import Any
+from typing import Any, Union
 
 import tvm
 from tvm.ir import GlobalVar, PrimType
@@ -29,6 +30,8 @@ from ...ir_builder import tir as T
 from ...ir_builder.base import IRBuilder
 from ...ir_builder.base import IRBuilderFrame as Frame
 from .._core import Parser, dispatch, doc
+from ..core.parser import VarTable
+from .entry import TIRMacro
 
 
 def bind_with_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
@@ -427,6 +430,12 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
     node : doc.Expr
         The doc AST Expr node.
     """
+
+    if isinstance(node.value, doc.Call):
+        callee = self.eval_expr(node.value.func)
+        if isinstance(callee, TIRMacro):
+            return expand_macro(self, callee, node.value)
+
     res = self.eval_expr(node.value)
     if res is None:
         pass
@@ -447,6 +456,7 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
         pass
     else:
         self.report_error(node, f"Parsing resulted in unexpected type {type(res)}")
+    return None  # For pylint
 
 
 @dispatch.register(token="tir", type_name="If")
@@ -528,3 +538,51 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar
     # Only ret_type is needed for func_signature.
     func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type)
     return I.decl_function(node.name, func_signature)
+
+
+def expand_macro(self: Parser, callee: TIRMacro, call: doc.Call) -> None:
+    """Expand a macro call in place: bind the call's arguments to the macro's parameters
+    and parse the macro body at the call site, in the scope selected by `callee.hygienic`.
+    """
+
+    assert isinstance(callee, TIRMacro), f"Unexpected macro type {type(callee)}"
+
+    def find_macro_def(name: str, decl_list: doc.AST) -> Union[doc.FunctionDef, Any]:
+        for decl in decl_list:
+            if isinstance(decl, doc.FunctionDef) and decl.name == name:
+                return decl
+        return None
+
+    # Locate the macro's FunctionDef. report_error, unlike assert, survives -O.
+    macro_def = find_macro_def(callee.__name__, callee.source_ast.body)
+    if macro_def is None:
+        self.report_error(call, f"Invalid macro AST for {callee.__name__}")
+
+    # Bind the arguments exactly as the call `callee.func(*args, **kwargs)` would.
+    args = [self.eval_expr(arg) for arg in call.args]
+    kwargs = {kw.arg: self.eval_expr(kw.value) for kw in call.keywords}
+    param_binding = inspect.signature(callee.func).bind(*args, **kwargs)
+    param_binding.apply_defaults()
+    local_vars = param_binding.arguments
+
+    if callee.hygienic:
+        # Hygienic: parse the body with a fresh var_table whose single frame holds
+        # only the environment captured at definition time plus the arguments.
+        saved_var_table = self.var_table
+        self.var_table = VarTable()
+        try:
+            with self.var_table.with_frame():
+                for k, v in callee.closure_vars.items():
+                    self.var_table.add(k, v)
+                for k, v in local_vars.items():
+                    self.var_table.add(k, v)
+                self.visit_body(macro_def.body)
+        finally:
+            # Restore the caller's symbol table even if parsing the body fails.
+            self.var_table = saved_var_table
+    else:
+        # Non-hygienic: layer the arguments on top of the caller's scope.
+        with self.var_table.with_frame():
+            for k, v in local_vars.items():
+                self.var_table.add(k, v)
+            self.visit_body(macro_def.body)
diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py b/tests/python/unittest/test_tvmscript_parser_tir.py
index 31bf5cc101..38d3e14746 100644
--- a/tests/python/unittest/test_tvmscript_parser_tir.py
+++ b/tests/python/unittest/test_tvmscript_parser_tir.py
@@ -16,6 +16,7 @@
 # under the License.
 """Unittests for tvm.script.parser.tir"""
 
+import pytest
 import tvm.testing
 from tvm.script.parser import tir as T
 from tvm import ir, tir
@@ -71,5 +72,111 @@ def test_tir_func_name():
     assert matmul.__name__ == "matmul"
 
 
+def test_tir_macro_decorator_signature():
+    @T.prim_func
+    def evaluate0():
+        T.evaluate(0)
+
+    # Ok, no parentheses
+    @T.macro
+    def func1():
+        T.evaluate(0)
+
+    assert func1.hygienic
+
+    @T.prim_func
+    def use1():
+        func1()
+
+    tvm.ir.assert_structural_equal(use1, evaluate0)
+
+    # Ok, empty parentheses
+    @T.macro()
+    def func2():
+        T.evaluate(0)
+
+    assert func2.hygienic
+
+    @T.prim_func
+    def use2():
+        func2()
+
+    tvm.ir.assert_structural_equal(use2, evaluate0)
+
+    with pytest.raises(ValueError):
+        # Wrong: non-keyword argument
+        @T.macro(True)
+        def func3():
+            T.evaluate()
+
+
+def test_tir_macro_signature():
+    @T.macro
+    def assign(i, *args, t1, **kwargs):  # positional, *args, keyword-only and **kwargs
+        vi, vj, vk = T.axis.remap("SSR", [i, args[0], args[1]])
+        kwargs["t3"][vi, vj] = kwargs["t3"][vi, vj] + t1[vi, vk] * kwargs["t2"][vj, vk]
+
+    @T.prim_func
+    def matmul_w_macro(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, [128, 128])
+        B = T.match_buffer(b, [128, 128])
+        C = T.match_buffer(c, [128, 128])
+        for i, j, k in T.grid(128, 128, 128):
+            with T.block("update"):
+                assign(i, j, k, t1=A, t2=B, t3=C)  # j, k -> *args; t2, t3 -> **kwargs
+
+    @T.prim_func
+    def matmul_no_macro(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, [128, 128])
+        B = T.match_buffer(b, [128, 128])
+        C = T.match_buffer(c, [128, 128])
+        for i, j, k in T.grid(128, 128, 128):
+            with T.block("update"):
+                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+    tvm.ir.assert_structural_equal(matmul_no_macro, matmul_w_macro)
+
+
+def test_tir_macro_hygienic():
+    x_value = 128  # captured when static_capture is defined
+
+    @T.macro(hygienic=True)
+    def static_capture(A, B):
+        B[()] = A[x_value]  # resolves to the captured 128
+
+    @T.prim_func
+    def use_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
+        for x_value in T.serial(10):  # this loop var must not leak into the macro
+            static_capture(A, B)
+
+    @T.prim_func
+    def expected_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
+        for x_value in range(10):
+            B[()] = A[128]
+
+    tvm.ir.assert_structural_equal(use_hygienic, expected_hygienic)
+
+
+def test_tir_macro_non_hygienic():
+    x_value = 128  # shadowed at the call site below, so the macro never sees it
+
+    @T.macro(hygienic=False)
+    def dynamic_capture(A, B):
+        B[()] = A[x_value]  # resolved when the macro is expanded
+
+    @T.prim_func
+    def use_non_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
+        for x_value in T.serial(10):  # this loop var is what the macro binds to
+            dynamic_capture(A, B)
+
+    @T.prim_func
+    def expected_non_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
+        for x_value in range(10):
+            B[()] = A[x_value]
+
+    tvm.ir.assert_structural_equal(use_non_hygienic, expected_non_hygienic)
+
+
 if __name__ == "__main__":
     tvm.testing.main()