You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/04/27 08:38:07 UTC

[tvm] branch main updated: [TVMScript] Support TVMScript template meta-programming over variables (#11097)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c09a24dcdc [TVMScript] Support TVMScript template meta-programming over variables (#11097)
c09a24dcdc is described below

commit c09a24dcdce3bc71133712c003c2135842b64be1
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Wed Apr 27 16:38:00 2022 +0800

    [TVMScript] Support TVMScript template meta-programming over variables (#11097)
    
    This PR supports a simple meta-programming paradigm for TVMScript, which allows users to access variables defined in the surrounding Python environment.
---
 python/tvm/script/context_maintainer.py            | 11 +++-
 python/tvm/script/parser.py                        | 15 ++++--
 .../unittest/test_tvmscript_meta_programming.py    | 59 ++++++++++++++++++++++
 3 files changed, 78 insertions(+), 7 deletions(-)

diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py
index 972e5845fc..f7f16855c7 100644
--- a/python/tvm/script/context_maintainer.py
+++ b/python/tvm/script/context_maintainer.py
@@ -121,6 +121,8 @@ class ContextMaintainer:
     """Dict[Var, Range]: The dict from loop var to its domain outside the block"""
     symbols: List[Dict[str, Union[Var, Buffer]]] = []
     """List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for the current scope"""
+    closure_vars: Dict[str, Object] = {}
+    """Dict[str, Object]: Nonlocal and global Python variables visible to the parsed function"""
 
     # function context
     func_params: List[Var] = []
@@ -144,12 +146,17 @@ class ContextMaintainer:
     root_alloc_buffers: List[Buffer] = []
     """List[Buffer]: The buffers allocated under root block"""
 
-    def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]):
+    def __init__(
+        self,
+        _report_error: Callable[[str, Union[Span, synr.ast.Span]], None],
+        closure_vars: Dict[str, Object],
+    ):
         # scope context
         self.node_stack = []
         self.block_info_stack = []
         self.loop_stack = {}
         self.symbols = []
+        self.closure_vars = closure_vars
         # function context
         self.func_params = []
         self.func_buffer_map = {}
@@ -233,7 +240,7 @@ class ContextMaintainer:
         for symbols in reversed(self.symbols):
             if name in symbols:
                 return symbols[name]
-        return None
+        return self.closure_vars.get(name)  # fall back to Python closure; None if absent
 
     def report_error(self, message: str, span: Union[Span, synr.ast.Span]):
         self._report_error(message, span)
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
index b01ad383c3..13b283bc0c 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser.py
@@ -158,18 +158,21 @@ class TVMScriptParser(Transformer):
 
     # pylint gets confused here with synr.Transformer which doesn't have a
     # custom init, so just disable it
-    def __init__(self, base_lineno, tir_namespace):  # pylint: disable=super-init-not-called
+    def __init__(
+        self, base_lineno, tir_namespace, closure_vars=None
+    ):  # pylint: disable=super-init-not-called
         self.context = None
 
         self.base_lineno = base_lineno
         self.current_lineno = 0
         self.current_col_offset = 0
         self.tir_namespace = tir_namespace
+        self.closure_vars = {} if closure_vars is None else closure_vars  # Dict[str, Object]
         self.meta = None
 
     def init_function_parsing_env(self):
         """Initialize function parsing environment"""
-        self.context = ContextMaintainer(self.report_error)  # scope emitter
+        self.context = ContextMaintainer(self.report_error, self.closure_vars)  # scope emitter
 
     def init_meta(self, meta_dict):
         if meta_dict is not None:
@@ -709,7 +712,7 @@ class TVMScriptParser(Transformer):
         self.context.enter_scope(nodes=node.body.stmts)
         # for scope handler process the scope
         arg_list = [
-            tvm.runtime.convert(arg, span=node.rhs.span)
+            tvm.runtime.convert(arg, span=tvm_span_from_synr(node.rhs.span))
             for arg in self.parse_arg_list(func, node.rhs)
         ]
         func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
@@ -1253,12 +1256,14 @@ def from_source(
     """
     if isinstance(input_func, str):
         tir_prefix = ["T", "tir"] if tir_prefix is None else tir_prefix
-        return to_ast(input_func, TVMDiagnosticCtx(), TVMScriptParser(0, tir_prefix))
+        return to_ast(input_func, TVMDiagnosticCtx(), TVMScriptParser(0, tir_prefix, {}))
     elif inspect.isfunction(input_func):
         _, start_line = inspect.getsourcelines(input_func)
         env: Dict[str, Any] = input_func.__globals__
         namespace = [key for key in env.keys() if env[key] is tir]
-        parser = TVMScriptParser(start_line, namespace)
+        _closure_vars = inspect.getclosurevars(input_func)
+        closure_vars = {**_closure_vars.globals, **_closure_vars.nonlocals}  # LEGB: nonlocals win
+        parser = TVMScriptParser(start_line, namespace, closure_vars)
         result = to_ast(input_func, TVMDiagnosticCtx(), parser)
         return result
     else:
diff --git a/tests/python/unittest/test_tvmscript_meta_programming.py b/tests/python/unittest/test_tvmscript_meta_programming.py
new file mode 100644
index 0000000000..2473c0c845
--- /dev/null
+++ b/tests/python/unittest/test_tvmscript_meta_programming.py
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm.script import tir as T
+
+
+def matmul_generator(M: int, N: int, K: int, dtype: str):  # args are read via closure
+    @T.prim_func
+    def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, [M, K], dtype=dtype)  # M, K, dtype: closure vars
+        B = T.match_buffer(b, [N, K], dtype=dtype)
+        C = T.match_buffer(c, [M, N], dtype=dtype)
+
+        for i, j, k in T.grid(M, N, K):
+            with T.block():
+                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                with T.init():
+                    C[vi, vj] = T.float32(0)  # NOTE(review): float32 even for fp16
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+    return matmul
+
+
+@T.prim_func
+def matmul_128_128_128_fp16(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128], dtype="float16")
+    B = T.match_buffer(b, [128, 128], dtype="float16")
+    C = T.match_buffer(c, [128, 128], dtype="float16")
+
+    for i, j, k in T.grid(128, 128, 128):
+        with T.block():
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            with T.init():
+                C[vi, vj] = T.float32(0)  # same float32 init as the generator
+            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+def test_meta_programming_matmul():
+    f = matmul_generator(128, 128, 128, "float16")  # M, N, K, dtype
+    tvm.ir.assert_structural_equal(f, matmul_128_128_128_fp16)  # generated == hand-written
+
+
+if __name__ == "__main__":
+    test_meta_programming_matmul()