Posted to commits@tvm.apache.org by ju...@apache.org on 2022/08/04 18:35:03 UTC

[tvm-rfcs] branch main updated: [RFC] TVMScript Metaprogramming (#79)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-rfcs.git


The following commit(s) were added to refs/heads/main by this push:
     new ffbf686  [RFC] TVMScript Metaprogramming (#79)
ffbf686 is described below

commit ffbf68643a20a9b3daeb2adc87933d976e5a5513
Author: Lite Ye <ye...@gmail.com>
AuthorDate: Thu Aug 4 14:34:57 2022 -0400

    [RFC] TVMScript Metaprogramming (#79)
    
    * Add TVMScript Metaprogramming RFC
    
    * Add clarification for host program interleaving
    
    * Fix typo and formating
    
    * Add clarification around parse-time eval
    
    * Further clarify parse-time eval
    
    * Add clarification on how F3 and F4 will be implemented by parse-time eval
    
    * Clarify the signature of user-provided function
    
    * Clarify allowed types for auto capture
    
    * Add section for security implication of eval
    
    * Add short discussion on Taichi and Triton
    
    * Clarify on how the new approach improves tooling
    
    * Clarify and fix typos
---
 rfcs/0079-tvmscript-metaprogramming.md | 525 +++++++++++++++++++++++++++++++++
 1 file changed, 525 insertions(+)

diff --git a/rfcs/0079-tvmscript-metaprogramming.md b/rfcs/0079-tvmscript-metaprogramming.md
new file mode 100644
index 0000000..86c3c70
--- /dev/null
+++ b/rfcs/0079-tvmscript-metaprogramming.md
@@ -0,0 +1,525 @@
+- Feature Name: tvmscript-metaprogramming
+- Start Date: 2022-06-16
+- RFC PR: [apache/tvm-rfcs#79](https://github.com/apache/tvm-rfcs/pull/79)
+- GitHub Issue: [apache/tvm#0000](https://github.com/apache/tvm/issues/0000)
+- Co-Authors: Yaxing Cai ([**@cyx-6**](https://github.com/cyx-6), main implementation), Lite Ye
+  ([**@yelite**](https://github.com/yelite)), Yong Wu
+  ([**@yongwww**](https://github.com/yongwww)), Yuchen Jin
+  ([**@YuchenJin**](https://github.com/YuchenJin)), Eric Lunderberg
+  ([**@Lunderberg**](https://github.com/Lunderberg)), Masahiro Masuda
+  ([**@masahi**](https://github.com/masahi)), Junru Shao
+  ([**@junrushao1994**](https://github.com/junrushao1994), main designer)
+
+# Summary
+[summary]: #summary
+
+This RFC proposes a new TVMScript parser infrastructure that supports extensive
+metaprogramming and syntactic sugar. The new infrastructure is IR-agnostic and
+treats TIR as just one of its dialects. It will also provide better integration
+with Python tooling (pylint, mypy, etc.).
+
+# Motivation
+[motivation]: #motivation
+
+**What is TVMScript.**
+See [Blitz Course to TensorIR](https://tvm.apache.org/docs/tutorial/tensor_ir_blitz_course.html) and
+[TVMScript Unified Printer RFC](https://github.com/apache/tvm-rfcs/pull/74/files#diff-6965a40ad8df7618ae68e11c88f924542a506c74a931cc3011ae9f99989b5f51R20-R26)
+for an introduction to TVMScript.
+
+**What is metaprogramming.** In the context of TVMScript, metaprogramming means
+a programmable way to control IR generation. For example,
+https://github.com/apache/tvm/pull/11097 added a metaprogramming feature to the
+TVMScript parser that lets users programmatically control the shapes of the
+input buffers of a `PrimFunc`.
+
+### Limitation of current design
+
+The current parser lacks generic metaprogramming capabilities that would give
+users more control over IR construction. This makes it challenging to support
+operators like NMS (non-maximum suppression, which is crucial to object
+detection models). There is an implementation of NMS at
+[python/tvm/topi/cuda/nms.py#L367-L386](https://github.com/apache/tvm/blob/d0650bad66d0ff89a01347537021bc442a98c223/python/tvm/topi/cuda/nms.py#L367-L386).
+Implementing NMS-like operators requires rank-polymorphism and the ability to
+interleave the host program with TVMScript, both of which are difficult to
+implement under the current design.
+
+TVMScript also needs reasonable support for Python tooling. Currently it doesn’t
+play nicely with pylint and mypy. For example,
+[test_meta_schedule_postproc_rewrite_tensorize.py](https://github.com/apache/tvm/blob/d0650bad66d0ff89a01347537021bc442a98c223/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py)
+has 100+ warnings from pylint within only 500 lines of code. This
+confuses users and leaves the impression that TVMScript isn’t a
+mature, production-ready product. Even though this can be
+incrementally improved under the current design, we believe it’s easier to get
+an ideal result with a design that has tooling support in mind.
+
+The current design also lacks a unified approach for different IRs. At
+[https://github.com/tlc-pack/relax/tree/relax/python/tvm/script/relax](https://github.com/tlc-pack/relax/tree/relax/python/tvm/script/relax),
+a mature implementation of the TVMScript parser is maintained for Relax, but it
+is hard to extend if we want to support more IRs for TVM unity.
+
+To conclude, with this RFC, we want to:
+
+1. Add more metaprogramming features to TVMScript, making it easier for TVM
+   developers to write complicated operators.
+2. Improve tooling and documentation of TVMScript, reducing the friction for an
+   average machine learning practitioner to use TVMScript.
+3. Modularize and infrastructuralize the TVMScript parser, lowering the cost of
+   implementing a parser for a new IR.
+
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Metaprogramming features to support
+
+### (F1) Template Metaprogramming
+
+Users should be able to use variables from the outer scope in a TVMScript
+function/class. The parsed result should be identical to the function/class
+with each variable replaced by its value. For instance,
+
+```python
+@T.prim_func
+def matmul(
+  A: T.Buffer[(128, 128)],
+) -> None:
+  ...
+
+def gen_matmul(n, m):  # returns the generated PrimFunc
+  @T.prim_func
+  def f(A: T.Buffer[(n, m)]):
+    ...
+  return f
+
+f = gen_matmul(n=128, m=128) # `f` should be identical to `matmul`
+```
+
+This is already partially supported by https://github.com/apache/tvm/pull/11097
+for `PrimExpr` values captured from an outer function. With the new parser, we
+want to support this feature in more places and with more variable types.
+
+### (F2) Rank-polymorphism
+
+Users should be able to write a single function to handle different ranks of
+input buffers (different numbers of dimensions). For example, users should be
+able to write a generic function that performs a broadcast add,
+
+```python
+def broadcast_add(a, b, c):
+  @T.prim_func
+  def f(
+    A: T.BufferFrom(a),
+    B: T.BufferFrom(b),
+    C: T.BufferFrom(c),
+  ) -> None:
+    for i, i_a, i_b in T.some_broadcast_method(A.shape, B.shape):
+      with T.block():
+        C[*i] = A[*i_a] + B[*i_b]
+  return f
+
+broadcast_add(
+  a = Buffer((128, 1), "float32"),
+  b = Buffer((1, 128), "float32"),
+  c = Buffer((128, 128), "float32"),
+)
+```
+
+### (F3) Sugar: TE Compute in TIR
+
+Users should be able to replace boilerplate code with a function call that is
+expanded into a larger chunk of code during parsing. For example, we may want to
+use TE’s compute-like syntax to replace nested loops,
+
+```python
+@T.prim_func
+def te_compute_sugar(
+  A: T.Buffer[(128, 128)],
+  B: T.Buffer[(128, 128)],
+) -> None:
+  ...
+  C = T.compute((128, 128), lambda i, j: A[i, j] + B[i, j])
+  ...
+
+## expands to ====>
+
+@T.prim_func
+def te_compute_expanded(
+  A: T.Buffer[(128, 128)],
+  B: T.Buffer[(128, 128)],
+) -> None:
+  ...
+  for i in range(128):
+    for j in range(128):
+      with T.block("..."):
+        C[i, j] = A[i, j] + B[i, j]
+  ...
+```
+
+### (F4) Interleave host program and TVMScript program to customize metaprogramming
+
+As an escape hatch from writing code to be parsed by the TVMScript
+parser, users should be able to write imperative code that constructs IR nodes
+directly and embed it inside regular TVMScript. This code will be evaluated
+by the Python interpreter at parse time. This gives users the ultimate tool when
+TVMScript isn’t expressive enough for their use cases. For example, at
+[python/tvm/topi/vision/nms.py#L380-L431](https://github.com/apache/tvm/blob/3cb4597ed48360e3f3d80161d1c03f833072d28e/python/tvm/topi/vision/nms.py#L380-L431),
+there are blocks of repetitive code that compute the coordinates of the four
+corners of a bounding box. This can be simplified as:
+
+```python
+# Before, without IRBuilder interleaving
+@T.prim_func
+def nms(...):
+  ...
+  for i in range(batch_size):
+    ...
+    a_l = min(
+      output[batch_idx, box_a_idx, box_start_idx],
+      output[batch_idx, box_a_idx, box_start_idx + 2],
+    )
+    a_t = min(
+      output[batch_idx, box_a_idx, box_start_idx + 1],
+      output[batch_idx, box_a_idx, box_start_idx + 3],
+    )
+    a_r = max(
+      output[batch_idx, box_a_idx, box_start_idx],
+      output[batch_idx, box_a_idx, box_start_idx + 2],
+    )
+    a_b = max(
+      output[batch_idx, box_a_idx, box_start_idx + 1],
+      output[batch_idx, box_a_idx, box_start_idx + 3],
+    )
+    ...
+    for k in range(j):
+      check_iou = ...
+      ...
+      if check_iou > 0:
+        # b_l: left, b_t: top, b_r: right, b_b: bottom
+        b_l = min(
+          output[batch_idx, box_b_idx, box_start_idx],
+          output[batch_idx, box_b_idx, box_start_idx + 2],
+        )
+        b_t = min(
+          output[batch_idx, box_b_idx, box_start_idx + 1],
+          output[batch_idx, box_b_idx, box_start_idx + 3],
+        )
+        b_r = max(
+          output[batch_idx, box_b_idx, box_start_idx],
+          output[batch_idx, box_b_idx, box_start_idx + 2],
+        )
+        b_b = max(
+          output[batch_idx, box_b_idx, box_start_idx + 1],
+          output[batch_idx, box_b_idx, box_start_idx + 3],
+        )
+        ...
+
+# With IRBuilder interleaving:
+
+from tvm.script import tir as T
+
+def get_box_coordinates(output, batch_idx, box_idx, box_start_idx):
+  """a method executed by python interpreter"""
+  box_l = T.min(
+    output[batch_idx, box_idx, box_start_idx],
+    output[batch_idx, box_idx, box_start_idx + 2],
+  ) # type(box_l) is PrimExpr
+  ... # Repeat for other coordinates
+  return box_l, box_t, box_r, box_b
+
+@T.prim_func(capture=[get_box_coordinates])
+def nms(...):
+  ...
+  for i in range(batch_size):
+    ...
+    a_l, a_t, a_r, a_b = get_box_coordinates(output, batch_idx, box_a_idx, box_start_idx)
+    ...
+    for k in range(j):
+      check_iou = ...
+      ...
+      if check_iou > 0:
+        b_l, b_t, b_r, b_b = get_box_coordinates(output, batch_idx, box_b_idx, box_start_idx)
+        ...
+```
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## IRBuilder as Core
+
+As the foundation of IR construction, we will provide a set of APIs called
+IRBuilder that lets users construct IR imperatively. IRBuilder will be used by
+the parser, as well as by users directly, as described in feature F4. IRBuilder
+allows users to write code in a style similar to TVMScript while it is
+executed as a host program. For example,
+
+```python
+from tvm.script.builder import Builder, def_, def_many
+from tvm.script import tir as T
+
+with Builder() as b:
+  with T.prim_func():
+    T.func_name("main")
+    buffer_a = T.Buffer((128, 128, 128), "float32")
+    buffer_b = T.Buffer((128, 128, 128), "float32")
+    arg_a = T.arg("A", buffer_a)
+    arg_b = T.arg("B", buffer_b)
+    with T.grid(128, 128, 128) as (i, j, k):
+      def_many(["i", "j", "k"], [i, j, k])
+      with T.block(name="block"):
+        vi = def_("vi", T.axis.spatial(128, i))
+        vj = def_("vj", T.axis.spatial(128, j))
+        vk = def_("vk", T.axis.reduce(128, k))
+
+f = b.get() # f is a PrimFunc
+```
+
+produces a result equivalent to
+
+```python
+@T.prim_func
+def main(
+  A: T.Buffer(shape=(128, 128, 128), dtype="float32"),
+  B: T.Buffer(shape=(128, 128, 128), dtype="float32"),
+) -> None:
+  for i, j, k in T.grid(128, 128, 128):
+      with T.block("block"):
+          vi = T.axis.S(128, i)
+          vj = T.axis.S(128, j)
+          vk = T.axis.R(128, k)
+```
+
+As shown in the example above, users don't need to pass the builder `b` to
+subsequent calls to the IRBuilder API. The current builder state is kept in
+thread-local storage, which improves the ergonomics of the IRBuilder API by
+avoiding the need to pass the builder state explicitly.
+
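The implicit builder lookup can be illustrated in plain Python. This is only a minimal sketch of the thread-local pattern, assuming a hypothetical `Builder` class and `func_name` helper; it is not the actual IRBuilder, which the RFC plans to implement in C++.

```python
import threading

# Hypothetical stand-in for the IRBuilder state; not the real TVM implementation.
_local = threading.local()


class Builder:
    """Records API calls made while it is the innermost active builder."""

    def __init__(self):
        self.calls = []

    def __enter__(self):
        # Each thread owns its own stack, so builders on different threads do not interfere.
        if not hasattr(_local, "stack"):
            _local.stack = []
        _local.stack.append(self)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        _local.stack.pop()
        return False  # never swallow exceptions

    @staticmethod
    def current():
        stack = getattr(_local, "stack", [])
        if not stack:
            raise RuntimeError("no active Builder on this thread")
        return stack[-1]


def func_name(name):
    """An API call that finds its builder implicitly."""
    Builder.current().calls.append(("func_name", name))


with Builder() as b:
    func_name("main")  # no explicit `b` argument

print(b.calls)
```

Keeping a stack rather than a single slot lets `with Builder()` blocks nest, and calling `func_name` outside any `with` block raises a `RuntimeError` in this sketch.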
+The implementation of IRBuilder will be in C++ so that it can be used in an
+environment without Python. Python bindings will be created to expose IRBuilder
+to the TVMScript parser.
+
+With the separation between IRBuilder and parser, most implementation and
+documentation can be shared between TVMScript and the IR definition. For example,
+most operators are simply imported into the IRBuilder package, like
+```python
+from tvm.tir import sin, cos
+```
+so the documentation and type signatures only need to be written once, and the
+APIs are guaranteed to be consistent.
+
+## Parse-time evaluation
+
+The TVMScript parser can be considered a thin layer built on top of the IRBuilder
+API. The parser transforms the input AST into a sequence of calls to the
+IRBuilder API by evaluating small fragments of code as it visits the AST. The IRBuilder is
+responsible for building the actual IR graph. All metaprogramming features we
+discussed above (F1 through F4) can be implemented through this parse-time
+evaluation in a consistent manner. Using the same `gen_matmul` example from F1,
+
+```python
+def gen_matmul(n, m):  # returns the generated PrimFunc
+  @T.prim_func
+  def f(A: T.Buffer[(n, m)]):
+    ...
+  return f
+```
+
+Here the parser does the following:
+
+1. Collect the environment inside `gen_matmul` as a dictionary.
+    1. All primitive types (`int`, `str`, `float` and `None`) are captured
+       automatically, while advanced types (like functions) need to be
+       explicitly declared in the decorator to be captured (for example,
+       `@T.prim_func(capture=[get_box_coordinates])`).
+2. Call the corresponding function from IRBuilder as the parser visits the AST
+   of function `f`.
+    1. When visiting the function argument, call `eval` on its type annotation
+       with the environment captured in the first step.
+    2. `T.Buffer[(n, m)]` gets evaluated to a value with type `tir.Buffer`.
+    3. Call the IRBuilder API `T.arg("A", buffer)` to add an argument to the
+       function that’s being constructed.
+
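The two steps above can be sketched with the standard `ast` module. This is an illustrative sketch only: `Buffer`, `collect_env` and `eval_arg_annotations` are hypothetical names, `Buffer` stands in for `T.Buffer`, and the function source is passed as a string instead of being recovered from a decorated function.

```python
import ast

PRIMITIVES = (int, str, float, type(None))


class Buffer:
    """Hypothetical stand-in for `T.Buffer`; it only stores a shape."""

    def __init__(self, shape):
        self.shape = tuple(shape)

    def __class_getitem__(cls, shape):
        # Makes `Buffer[(n, m)]` evaluate to a Buffer instance.
        return cls(shape)


def collect_env(outer_vars, capture=()):
    """Step 1: keep primitives automatically, other objects only if listed in `capture`."""
    env = {k: v for k, v in outer_vars.items() if isinstance(v, PRIMITIVES)}
    env.update({obj.__name__: obj for obj in capture})
    return env


def eval_arg_annotations(source, env):
    """Step 2: evaluate each argument annotation of the first function in `source`."""
    func = ast.parse(source).body[0]
    result = []
    for arg in func.args.args:
        code = compile(ast.Expression(body=arg.annotation), "<annotation>", "eval")
        result.append((arg.arg, eval(code, dict(env))))
    return result


outer = {"n": 128, "m": 128, "helper": len}  # `helper` is not captured automatically
env = collect_env(outer, capture=[Buffer])
args = eval_arg_annotations("def f(A: Buffer[(n, m)]): ...", env)
print(args[0][0], args[0][1].shape)
```

In the real parser the evaluated buffer would then be handed to the IRBuilder call `T.arg("A", buffer)`; the sketch stops at the evaluation step.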
+Another example,
+
+```python
+for *i in T.grid(*A.shape):
+  ...
+```
+
+The parser will:
+
+1. Evaluate the expression `T.grid(*A.shape)` using the steps described above.
+   `T.grid` returns a value that is nearly equivalent to `List[Var]`.
+2. Call `exec` on a specially constructed statement `*i = __tvm_rhs_var__`,
+   with `locals` that maps `__tvm_rhs_var__` to the value evaluated in step 1.
+3. Collect the value of `i` from the `locals` dictionary.
+4. Call the IRBuilder API `def_many(["i"], [i])`.
+
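The `exec`-based binding in steps 2 and 3 can be sketched in a few lines of plain Python. The helper name `bind_targets` is hypothetical, the loop variables are replaced by placeholder strings, and the sketch uses a `first, *rest, last` target as an example of a starred assignment.

```python
def bind_targets(lhs_source, value):
    """Bind `value` to the assignment targets written in `lhs_source`."""
    local_vars = {"__tvm_rhs_var__": value}
    # Only the fixed form `<lhs> = __tvm_rhs_var__` is ever executed.
    exec(f"{lhs_source} = __tvm_rhs_var__", {}, local_vars)
    del local_vars["__tvm_rhs_var__"]
    return local_vars


loop_vars = ["i0", "i1", "i2", "i3"]  # placeholders for the values returned by T.grid
bound = bind_targets("first, *rest, last", loop_vars)
print(bound)
```

Delegating to `exec` means the unpacking rules (starred targets, nested tuples, length checks) come from the Python interpreter itself rather than from a hand-written reimplementation. The names collected in `bound` are what would then be passed to `def_many`.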
+As mentioned above, all metaprogramming features (F1 through F4) can be
+implemented through this parse-time evaluation. It's straightforward to see how
+F1 and F2 are implemented by parse-time evaluation, but it might be harder to
+grasp the idea behind F3 and F4. 
+
+For F3 (TE Compute in TIR),
+```python
+C = T.compute((128, 128), lambda i, j: A[i, j] + B[i, j])
+# build a similar graph as
+for i in range(128):
+  for j in range(128):
+    with T.block("..."):
+      C[i, j] = A[i, j] + B[i, j]
+```
+`T.compute` is provided in the IRBuilder API and constructs all IR nodes (For,
+BlockRealize, Block and BufferStore). `T.compute` calls the lambda function to
+get the rhs of the BufferStore, then returns a `Buffer` node that
+represents `C`. The parser handles the assignment by assigning `"C"` to the
+`name_hint` of the returned buffer (by calling the IRBuilder API `def_("C", C)`)
+and putting it into the internal variable table (which is used to resolve
+variables when evaluating subsequent statements and expressions).
+
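To make the "calls the lambda function to get the rhs" step concrete, here is a toy sketch in plain Python. All classes and the `compute` function below are hypothetical stand-ins, not TVM APIs: the lambda is called once with symbolic loop variables, and the loop nest is emitted as text instead of IR nodes (the `T.block` wrapper is omitted for brevity).

```python
class Expr:
    def __add__(self, other):
        return Add(self, other)


class Var(Expr):
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


class Load(Expr):
    def __init__(self, buffer_name, indices):
        self.buffer_name, self.indices = buffer_name, indices

    def __repr__(self):
        return f"{self.buffer_name}[{', '.join(map(repr, self.indices))}]"


class Add(Expr):
    def __init__(self, lhs, rhs):
        self.lhs, self.rhs = lhs, rhs

    def __repr__(self):
        return f"{self.lhs!r} + {self.rhs!r}"


class Buffer:
    def __init__(self, name):
        self.name = name

    def __getitem__(self, indices):
        indices = indices if isinstance(indices, tuple) else (indices,)
        return Load(self.name, indices)


def compute(shape, fcompute, name="C"):
    """Call `fcompute` once with symbolic loop variables and emit the loop nest as text."""
    code = fcompute.__code__
    loop_vars = [Var(n) for n in code.co_varnames[: code.co_argcount]]
    rhs = fcompute(*loop_vars)  # a symbolic expression, not a number
    lines = [
        "  " * depth + f"for {var!r} in range({extent}):"
        for depth, (var, extent) in enumerate(zip(loop_vars, shape))
    ]
    index = ", ".join(map(repr, loop_vars))
    lines.append("  " * len(shape) + f"{name}[{index}] = {rhs!r}")
    return "\n".join(lines)


A, B = Buffer("A"), Buffer("B")
print(compute((128, 128), lambda i, j: A[i, j] + B[i, j]))
```

The key point is that `A[i, j] + B[i, j]` never touches real data: operator overloading on the stand-in classes builds an expression tree, in the same spirit as `PrimExpr` construction in TIR.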
+For F4 (Interleave host program and TVMScript program),
+```python
+a_l, a_t, a_r, a_b = get_box_coordinates(output, batch_idx, box_a_idx, box_start_idx)
+```
+The call to the `get_box_coordinates` function is evaluated when the parser visits the
+assign statement. The parser calls the IRBuilder API `def_many(["a_l", "a_t", "a_r", "a_b"], <returned_tuple>)`
+and puts the results into the internal variable table.
+
+Note that we will not place extra restrictions on the signature of the
+user-provided function (`get_box_coordinates` in this example). More precisely,
+the function argument types can be anything, because the parser is able to
+capture outer variables, thus bringing variables of arbitrary type into scope.
+The restriction on the return type depends on the language spec of the target
+IR. For example, in TIR a user can write `A[i] = ...` to represent a
+`BufferStore`, where the rhs is a `PrimExpr`; the user can then provide a custom
+function `compute_magic_number(index: PrimExpr) -> PrimExpr` and use it on the
+rhs as `A[i] = compute_magic_number(i)`.
+
+By running `eval` and `exec` provided by the Python interpreter, we can implement
+language features that are difficult to implement manually, and also make sure
+that expressions in TVMScript have the same semantics as in regular Python code.
+
+## Parser Registration
+
+The logic for processing a particular Python syntax node is registered
+through a decorator. For example,
+
+```python
+@dispatch.register(token="tir", type_name="For")
+def visit_for(self: Parser, node: doc.For) -> None:
+    ...
+```
+
+handles the transformation of a Python `For` node. The `token="tir"` in the
+decorator means that the handler is for TIR. `self: Parser` provides all the
+infrastructural APIs and maintains a stack of `token`s to determine which
+function to dispatch to. This makes embedding different IRs possible (for
+example, embedding TIR in Relax). The folder structure will look like
+
+```
+python/tvm/script/
+└── parser
+    ├── core
+    │   └── ... # Parser infrastructure
+    ├── tir # TIR dialect
+    │   └── ...
+    └── relax # Hypothetical Relax dialect (not part of our RFC)
+        └── ...
+```
+
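The decorator-based registration and token-stack dispatch described in this section can be sketched as follows. This is a simplified illustration with assumed names: `register`, `_HANDLERS` and this `Parser` are hypothetical, and the standard `ast` module stands in for the RFC's `doc` node types.

```python
import ast

_HANDLERS = {}


def register(token, type_name):
    """Register a handler for one (dialect token, AST node type) pair."""

    def decorator(func):
        _HANDLERS[(token, type_name)] = func
        return func

    return decorator


class Parser:
    def __init__(self, token):
        self.tokens = [token]  # the innermost dialect is last
        self.log = []

    def visit(self, node):
        key = (self.tokens[-1], type(node).__name__)
        if key not in _HANDLERS:
            raise NotImplementedError(f"no handler registered for {key}")
        return _HANDLERS[key](self, node)


@register(token="tir", type_name="For")
def visit_for(self, node):
    self.log.append(f"tir.For {node.target.id}")
    for stmt in node.body:
        self.visit(stmt)


@register(token="tir", type_name="Pass")
def visit_pass(self, node):
    self.log.append("tir.Pass")


parser = Parser("tir")
parser.visit(ast.parse("for i in range(4):\n    pass").body[0])
print(parser.log)
```

Embedding one dialect inside another would amount to pushing a different token onto `parser.tokens` while visiting the embedded region and popping it afterwards, so the same `visit` call resolves to the other dialect's handlers.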
+## Usage of `eval` and `exec`
+
+The parser uses `eval` and `exec` in the following places:
+- It calls `eval` to evaluate expression fragments.
+- It calls `exec` to evaluate different kinds of assignment statements, like `first, *rest, last = T.grid(...)`.
+
+The usage of `eval` and `exec` is necessary for our implementation. TVMScript
+allows users to construct IR graphs in TVM declaratively as if they were writing
+Python code, lowering the barrier to using TVM for low-level customization.
+However, it is still restrictive and does not allow the use of many Python
+features to build abstractions for users' code. All features proposed in this
+RFC can be seen as an effort to narrow this gap, so they are designed to
+follow Python semantics. On the implementation side, the most robust
+approach is to leverage the Python interpreter itself to provide those
+features, rather than write our own restricted Python interpreter, and
+`eval` and `exec` are the most suitable way to achieve this. Other
+mechanisms, like multiprocessing + IPC and subinterpreters, either lack an easy
+path to exchange Python objects or require a dependency on the C API of CPython.
+
+In our use cases, `eval` and `exec` do not create additional
+security risks for users. All inputs to `eval` come from the user's code
+directly, without modification. Our usage of `exec` only executes a
+specific form of code,
+```python
+<lhs> = __tvm_rhs_var__
+```
+where `<lhs>` is the tokens from the left-hand side of the assign statement,
+taken directly from the user's code. The situation is very different from the
+cases that make `eval` and `exec` infamous, where they are used to evaluate
+untrusted input from end users, typically in a web server. Furthermore, we
+require users to explicitly capture variables. Therefore, the evaluation will
+only involve objects from `tvm`, Python builtins and objects explicitly captured
+by users. This rules out the possibility that an external function is called by
+the parser without the user's acknowledgement. If a malicious actor wants to
+exploit the system through the `eval` or `exec` in the TVMScript parser, they
+must first have another RCE (remote code execution) vulnerability in the Python
+runtime to modify the code at runtime, which makes such an exploit pointless
+(one needs an RCE first in order to exploit another RCE with the same exposure).
+
+
+# Drawbacks
+[drawbacks]: #drawbacks
+
+N/A
+
+# Rationale and alternatives
+[rationale-and-alternatives]: #rationale-and-alternatives
+
+N/A
+
+# Prior art
+[prior-art]: #prior-art
+
+### Prior works in TVM
+- Hybrid Script: [https://tvm.apache.org/docs/reference/langref/hybrid_script.html](https://tvm.apache.org/docs/reference/langref/hybrid_script.html)
+- RFC for TVMScript: [https://discuss.tvm.apache.org/t/rfc-hybrid-script-support-for-tir/7516](https://discuss.tvm.apache.org/t/rfc-hybrid-script-support-for-tir/7516)
+
+### Other libraries
+#### Taichi 
+[https://www.taichi-lang.org](https://www.taichi-lang.org/)
+
+Taichi has a metaprogramming model ([link to doc](https://docs.taichi.graphics/docs/master/meta)) very similar to the one presented in this RFC.
+The biggest difference is that Taichi requires `ti.static` to be wrapped around everything that needs to be evaluated at compile time. It also
+has advanced features like loop unrolling, compile-time branching and compile-time recursion.
+
+The TVMScript parser does not need a special marker to denote compile-time
+evaluation. Expressions are consistently evaluated at compile time (evaluated
+to an IR node like PrimExpr, rather than a concrete value like a float matrix),
+thanks to the separation of IRBuilder and parser. Features like loop
+unrolling can be implemented in the IRBuilder layer per target IR. This keeps
+the core parser as minimal as possible.
+
+
+#### Triton 
+[http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf)
+
+Triton uses meta-parameters to generalize kernels (for example,
+[tutorials/03-matrix-multiplication.html](https://triton-lang.org/master/getting-started/tutorials/03-matrix-multiplication.html)).
+Meta-parameters are placed together with real parameters, but with the type
+annotation `tl.constexpr` to differentiate them. This method slightly deviates
+from regular Python semantics, as users will intuitively expect to pass them
+together with real parameters when calling the kernel.
+
+In TVMScript, one of the design principles is to narrow the gap between
+TVMScript and regular Python code. TVMScript should not surprise users with
+syntax or language features that deviate from Python. All features proposed in
+this RFC are designed to strictly follow the semantics of Python and aim to
+be intuitive to Python users.
+
+# Unresolved questions
+[unresolved-questions]: #unresolved-questions
+
+N/A
+
+# Future possibilities
+[future-possibilities]: #future-possibilities
+
+N/A