You are viewing a plain-text version of this content; the hyperlink to its canonical version is not preserved in this text.
Posted to commits@tvm.apache.org by wu...@apache.org on 2022/08/06 00:46:16 UTC

[tvm] branch main updated: [TIR][Schedule] Support annotate dict typed value (#12288)

This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2a7af612f7 [TIR][Schedule] Support annotate dict typed value (#12288)
2a7af612f7 is described below

commit 2a7af612f71bb9177e73c0e1df72db22fa039bd3
Author: wrongtest <wr...@gmail.com>
AuthorDate: Sat Aug 6 08:46:11 2022 +0800

    [TIR][Schedule] Support annotate dict typed value (#12288)
    
    * TIR schedule: support dict-typed annotation values
    
    * fix lint
    
    * fix comment issues
---
 python/tvm/tir/schedule/_type_checker.py           | 34 ++++++++++++++++++++++
 python/tvm/tir/schedule/schedule.py                | 11 +++++--
 src/tir/schedule/concrete_schedule.cc              | 15 ++++++++++
 src/tir/schedule/instruction_traits.h              | 24 +++++++++++++++
 src/tir/schedule/trace.cc                          |  8 +++++
 .../python/unittest/test_tir_schedule_utilities.py |  8 +++--
 .../unittest/test_type_annotation_checker.py       | 10 ++++++-
 7 files changed, 105 insertions(+), 5 deletions(-)

diff --git a/python/tvm/tir/schedule/_type_checker.py b/python/tvm/tir/schedule/_type_checker.py
index 564d23afad..d45b4fb84b 100644
--- a/python/tvm/tir/schedule/_type_checker.py
+++ b/python/tvm/tir/schedule/_type_checker.py
@@ -41,6 +41,13 @@ if hasattr(typing, "_GenericAlias"):
                 return [subtype]
             return None
 
+        @staticmethod
+        def dict_(type_: Any) -> Any:
+            if _Subtype._origin(type_) is dict:
+                (ktype, vtype) = type_.__args__
+                return [ktype, vtype]
+            return None
+
         @staticmethod
         def tuple_(type_: Any) -> Optional[List[type]]:
             if _Subtype._origin(type_) is tuple:
@@ -75,6 +82,14 @@ elif hasattr(typing, "_Union"):
                     return [subtype]
             return None
 
+        @staticmethod
+        def dict_(type_: Any) -> Optional[List[type]]:
+            if isinstance(type_, typing.GenericMeta):  # type: ignore # pylint: disable=no-member
+                if type_.__name__ == "Dict":
+                    (ktype, vtype) = type_.__args__  # type: ignore # pylint: disable=no-member
+                    return [ktype, vtype]
+            return None
+
         @staticmethod
         def tuple_(type_: Any) -> Optional[List[type]]:
             if isinstance(type_, typing.GenericMeta):  # type: ignore # pylint: disable=no-member
@@ -108,6 +123,10 @@ def _dispatcher(type_: Any) -> Tuple[str, List[type]]:
     if subtype is not None:
         return "list", subtype
 
+    subtype = _Subtype.dict_(type_)
+    if subtype is not None:
+        return "dict", subtype
+
     subtype = _Subtype.tuple_(type_)
     if subtype is not None:
         return "tuple", subtype
@@ -127,6 +146,7 @@ _TYPE2STR: Dict[Any, Callable] = {
     "none": lambda: "None",
     "atomic": lambda t: str(t.__name__),
     "list": lambda t: f"List[{_type2str(t)}]",
+    "dict": lambda k, v: f"Dict[{_type2str(k)}, {_type2str(v)}]",
     "tuple": lambda *t: f"Tuple[{', '.join([_type2str(x) for x in t])}]",
     "optional": lambda t: f"Optional[{_type2str(t)}]",
     "union": lambda *t: f"Union[{', '.join([_type2str(x) for x in t])}]",
@@ -177,6 +197,19 @@ def _type_check_vtable() -> Dict[str, Callable]:
                 return error_msg
         return None
 
+    def _type_check_dict(dict_obj: Dict[Any, Any], name: str, *types: Any) -> Optional[str]:
+        ktype_, vtype_ = types
+        if not isinstance(dict_obj, dict):
+            return _type_check_err(dict_obj, name, dict)
+        for k, v in dict_obj.items():
+            error_msg = _type_check(k, f"{name}[{k}]", ktype_)
+            if error_msg is not None:
+                return error_msg
+            error_msg = _type_check(v, f"{name}[{k}]", vtype_)
+            if error_msg is not None:
+                return error_msg
+        return None
+
     def _type_check_tuple(v: Any, name: str, *types: Any) -> Optional[str]:
         if not isinstance(v, tuple):
             return _type_check_err(v, name, Tuple[types])
@@ -202,6 +235,7 @@ def _type_check_vtable() -> Dict[str, Callable]:
         "none": _type_check_none,
         "atomic": _type_check_atomic,
         "list": _type_check_list,
+        "dict": _type_check_dict,
         "tuple": _type_check_tuple,
         "optional": _type_check_optional,
         "union": _type_check_union,
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 7bec054b73..cf031c014c 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2186,12 +2186,19 @@ class Schedule(Object):
 
     ########## Schedule: Annotation ##########
 
+    PrimAnnotationValueT = Union[str, int, float, ExprRV]
+    AnnotationValueT = Union[
+        PrimAnnotationValueT,
+        List[PrimAnnotationValueT],
+        Dict[str, Union[PrimAnnotationValueT, List[PrimAnnotationValueT]]],
+    ]
+
     @type_checked
     def annotate(
         self,
         block_or_loop: Union[BlockRV, LoopRV],
         ann_key: str,
-        ann_val: Union[str, int, float, ExprRV, List[Union[str, int, float, ExprRV]]],
+        ann_val: AnnotationValueT,
     ) -> None:
         """Annotate a block/loop with a key value pair
 
@@ -2201,7 +2208,7 @@ class Schedule(Object):
             The block/loop to be annotated
         ann_key : str
             The annotation key
-        ann_val : Union[str, int, float, ExprRV, List[Union[str, int, float, ExprRV]]]
+        ann_val : AnnotationValueT
             The annotation value
 
         Examples
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 9d0bb885e2..a73d6165a2 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -699,6 +699,21 @@ ObjectRef ConcreteScheduleNode::CheckAndGetAnnotationValue(const ObjectRef& ann_
     }
     return std::move(result);
   }
+  if (const auto* dict = ann_val.as<MapNode>()) {
+    Map<String, ObjectRef> result;
+    for (auto it = dict->begin(); it != dict->end(); ++it) {
+      const auto& key = it->first;
+      auto value = CheckAndGetAnnotationValue(it->second);
+      if (const StringImmNode* imm = key.as<StringImmNode>()) {
+        result.Set(imm->value, value);
+      } else if (key->IsInstance<StringObj>()) {
+        result.Set(Downcast<String>(key), value);
+      } else {
+        LOG(FATAL) << "TypeError: annotation dict key expect to be String or StringImm";
+      }
+    }
+    return std::move(result);
+  }
   LOG(FATAL)
       << "TypeError: Only strings, integers, floats, ExprRVs and Arrays are supported for now, but "
       << "gets: " << ann_val->GetTypeKey();
diff --git a/src/tir/schedule/instruction_traits.h b/src/tir/schedule/instruction_traits.h
index 14d05a4a34..56c69224fe 100644
--- a/src/tir/schedule/instruction_traits.h
+++ b/src/tir/schedule/instruction_traits.h
@@ -22,7 +22,9 @@
 #include <tvm/tir/schedule/instruction.h>
 #include <tvm/tir/schedule/schedule.h>
 
+#include <algorithm>
 #include <sstream>
+#include <string>
 #include <utility>
 #include <vector>
 
@@ -447,6 +449,28 @@ inline void PythonAPICall::AsPythonString(const ObjectRef& obj, std::ostream& os
       AsPythonString(e, os);
     }
     os << ']';
+  } else if (const auto* dict = obj.as<MapNode>()) {
+    os << '{';
+    bool is_first = true;
+    std::vector<std::pair<std::string, std::string>> dict_items;
+    for (auto it = dict->begin(); it != dict->end(); ++it) {
+      std::ostringstream ks;
+      AsPythonString(it->first, ks);
+      std::ostringstream vs;
+      AsPythonString(it->second, vs);
+      dict_items.emplace_back(ks.str(), vs.str());
+    }
+    std::sort(dict_items.begin(), dict_items.end(),
+              [](const auto& p1, const auto& p2) { return p1.first < p2.first; });
+    for (const auto& kv : dict_items) {
+      if (is_first) {
+        is_first = false;
+      } else {
+        os << ", ";
+      }
+      os << '\"' << kv.first << "\": " << kv.second;
+    }
+    os << '}';
   } else {
     LOG(FATAL) << "ValueError: Cannot translate type '" << obj->GetTypeKey()
                << "' to python. Its value is: " << obj;
diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc
index 9fa86917c5..395613bf4c 100644
--- a/src/tir/schedule/trace.cc
+++ b/src/tir/schedule/trace.cc
@@ -112,6 +112,9 @@ Array<ObjectRef> TranslateInputRVs(
     } else if (input->IsInstance<ArrayNode>()) {
       // Case 4: array
       results.push_back(TranslateInputRVs(Downcast<Array<ObjectRef>>(input), rv_names));
+    } else if (input->IsInstance<MapNode>()) {
+      // Case 5: dict
+      results.push_back(input);
     } else if (input->IsInstance<BlockRVNode>() || inputs->IsInstance<LoopRVNode>() ||
                inputs->IsInstance<VarNode>()) {
       LOG(FATAL) << "IndexError: Random variable is not defined " << input;
@@ -139,6 +142,11 @@ Array<ObjectRef> TranslateInputRVs(const Array<ObjectRef>& inputs,
       results.push_back(TranslateInputRVs(Downcast<Array<ObjectRef>>(input), named_rvs));
       continue;
     }
+    // Case 5. dict
+    if (input->IsInstance<MapNode>()) {
+      results.push_back(input);
+      continue;
+    }
     const auto* str = input.as<StringObj>();
     CHECK(str) << "TypeError: Expect String, but gets: " << input->GetTypeKey();
     CHECK_GT(str->size, 0) << "ValueError: Empty string is not allowed in input names";
diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py
index c479555590..41844a868e 100644
--- a/tests/python/unittest/test_tir_schedule_utilities.py
+++ b/tests/python/unittest/test_tir_schedule_utilities.py
@@ -67,7 +67,7 @@ def matmul_relu_ann1(a: T.handle, b: T.handle, d: T.handle) -> None:
     B = T.match_buffer(b, (1024, 1024))
     C = T.alloc_buffer((1024, 1024))
     D = T.match_buffer(d, (1024, 1024))
-    for i in T.serial(0, 1024, annotations={"test1": "aaa"}):
+    for i in T.serial(0, 1024, annotations={"test1": "aaa", "test4": {"arr": [0, 0], "key": 3}}):
         for j in T.serial(0, 1024, annotations={"test2": 612, "test3": ["aa", 1]}):
             for k in T.serial(0, 1024):
                 with T.block("matmul"):
@@ -92,7 +92,7 @@ def matmul_relu_ann2(a: T.handle, b: T.handle, d: T.handle) -> None:
             vi, vj, vk = T.axis.remap("SSR", [i, j, k])
             with T.init():
                 C[vi, vj] = 0.0
-            T.block_attr({"test1": "aaa"})
+            T.block_attr({"test1": "aaa", "test4": {"arr": [0, 0], "key": 3}})
             C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
     for i, j in T.grid(1024, 1024):
         with T.block("relu"):
@@ -279,11 +279,13 @@ def test_annotate_unannotate_loop():
     sch.annotate(sch.get_loops(matmul)[0], "test1", "aaa")
     sch.annotate(sch.get_loops(matmul)[1], "test2", 612)
     sch.annotate(sch.get_loops(matmul)[1], "test3", ["aa", 1])
+    sch.annotate(sch.get_loops(matmul)[0], "test4", {"arr": [0, 0], "key": 3})
     tvm.ir.assert_structural_equal(sch.mod["main"], matmul_relu_ann1)
     verify_trace_roundtrip(sch=sch, mod=matmul_relu)
     sch.unannotate(sch.get_loops(matmul)[0], "test1")
     sch.unannotate(sch.get_loops(matmul)[1], "test2")
     sch.unannotate(sch.get_loops(matmul)[1], "test3")
+    sch.unannotate(sch.get_loops(matmul)[0], "test4")
     verify_trace_roundtrip(sch=sch, mod=matmul_relu)
 
 
@@ -294,11 +296,13 @@ def test_annotate_unannotate_block():
     sch.annotate(matmul, "test1", "aaa")
     sch.annotate(relu, "test2", 0.22)
     sch.annotate(relu, "test3", ["aa", 1])
+    sch.annotate(matmul, "test4", {"arr": [0, 0], "key": 3})
     tvm.ir.assert_structural_equal(sch.mod["main"], matmul_relu_ann2)
     verify_trace_roundtrip(sch=sch, mod=matmul_relu)
     sch.unannotate(matmul, "test1")
     sch.unannotate(relu, "test2")
     sch.unannotate(relu, "test3")
+    sch.unannotate(matmul, "test4")
     verify_trace_roundtrip(sch=sch, mod=matmul_relu)
 
 
diff --git a/tests/python/unittest/test_type_annotation_checker.py b/tests/python/unittest/test_type_annotation_checker.py
index 9f6f29c7ff..e84ae043d3 100644
--- a/tests/python/unittest/test_type_annotation_checker.py
+++ b/tests/python/unittest/test_type_annotation_checker.py
@@ -16,7 +16,8 @@
 # under the License.
 """Test type checker based on python's type annotations"""
 
-from typing import List, Tuple, Union
+import sys
+from typing import Dict, List, Tuple, Union
 
 import pytest
 
@@ -44,6 +45,13 @@ test_cases = [
             ["5"],
         ],
     },
+    {
+        "type_annotation": Dict[str, int],
+        "positive_cases": [
+            {"key1": 0, "key2": 1, "key3": -1},
+        ],
+        "negative_cases": [None, [1], {1: "1"}],
+    },
     {
         "type_annotation": Tuple[int],
         "positive_cases": [