You are viewing a plain-text version of this content; the canonical link to it was not preserved in this copy.
Posted to commits@tvm.apache.org by wu...@apache.org on 2022/08/06 00:46:16 UTC
[tvm] branch main updated: [TIR][Schedule] Support annotate dict typed value (#12288)
This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2a7af612f7 [TIR][Schedule] Support annotate dict typed value (#12288)
2a7af612f7 is described below
commit 2a7af612f71bb9177e73c0e1df72db22fa039bd3
Author: wrongtest <wr...@gmail.com>
AuthorDate: Sat Aug 6 08:46:11 2022 +0800
[TIR][Schedule] Support annotate dict typed value (#12288)
* TIR schedule: support dict-typed annotation values
* fix lint
* fix comment issues
---
python/tvm/tir/schedule/_type_checker.py | 34 ++++++++++++++++++++++
python/tvm/tir/schedule/schedule.py | 11 +++++--
src/tir/schedule/concrete_schedule.cc | 15 ++++++++++
src/tir/schedule/instruction_traits.h | 24 +++++++++++++++
src/tir/schedule/trace.cc | 8 +++++
.../python/unittest/test_tir_schedule_utilities.py | 8 +++--
.../unittest/test_type_annotation_checker.py | 10 ++++++-
7 files changed, 105 insertions(+), 5 deletions(-)
diff --git a/python/tvm/tir/schedule/_type_checker.py b/python/tvm/tir/schedule/_type_checker.py
index 564d23afad..d45b4fb84b 100644
--- a/python/tvm/tir/schedule/_type_checker.py
+++ b/python/tvm/tir/schedule/_type_checker.py
@@ -41,6 +41,13 @@ if hasattr(typing, "_GenericAlias"):
return [subtype]
return None
+ @staticmethod
+ def dict_(type_: Any) -> Any:
+ if _Subtype._origin(type_) is dict:
+ (ktype, vtype) = type_.__args__
+ return [ktype, vtype]
+ return None
+
@staticmethod
def tuple_(type_: Any) -> Optional[List[type]]:
if _Subtype._origin(type_) is tuple:
@@ -75,6 +82,14 @@ elif hasattr(typing, "_Union"):
return [subtype]
return None
+ @staticmethod
+ def dict_(type_: Any) -> Optional[List[type]]:
+ if isinstance(type_, typing.GenericMeta): # type: ignore # pylint: disable=no-member
+ if type_.__name__ == "Dict":
+ (ktype, vtype) = type_.__args__ # type: ignore # pylint: disable=no-member
+ return [ktype, vtype]
+ return None
+
@staticmethod
def tuple_(type_: Any) -> Optional[List[type]]:
if isinstance(type_, typing.GenericMeta): # type: ignore # pylint: disable=no-member
@@ -108,6 +123,10 @@ def _dispatcher(type_: Any) -> Tuple[str, List[type]]:
if subtype is not None:
return "list", subtype
+ subtype = _Subtype.dict_(type_)
+ if subtype is not None:
+ return "dict", subtype
+
subtype = _Subtype.tuple_(type_)
if subtype is not None:
return "tuple", subtype
@@ -127,6 +146,7 @@ _TYPE2STR: Dict[Any, Callable] = {
"none": lambda: "None",
"atomic": lambda t: str(t.__name__),
"list": lambda t: f"List[{_type2str(t)}]",
+ "dict": lambda k, v: f"Dict[{_type2str(k)}, {_type2str(v)}]",
"tuple": lambda *t: f"Tuple[{', '.join([_type2str(x) for x in t])}]",
"optional": lambda t: f"Optional[{_type2str(t)}]",
"union": lambda *t: f"Union[{', '.join([_type2str(x) for x in t])}]",
@@ -177,6 +197,19 @@ def _type_check_vtable() -> Dict[str, Callable]:
return error_msg
return None
+ def _type_check_dict(dict_obj: Dict[Any, Any], name: str, *types: Any) -> Optional[str]:
+ ktype_, vtype_ = types
+ if not isinstance(dict_obj, dict):
+ return _type_check_err(dict_obj, name, dict)
+ for k, v in dict_obj.items():
+ error_msg = _type_check(k, f"{name}[{k}]", ktype_)
+ if error_msg is not None:
+ return error_msg
+ error_msg = _type_check(v, f"{name}[{k}]", vtype_)
+ if error_msg is not None:
+ return error_msg
+ return None
+
def _type_check_tuple(v: Any, name: str, *types: Any) -> Optional[str]:
if not isinstance(v, tuple):
return _type_check_err(v, name, Tuple[types])
@@ -202,6 +235,7 @@ def _type_check_vtable() -> Dict[str, Callable]:
"none": _type_check_none,
"atomic": _type_check_atomic,
"list": _type_check_list,
+ "dict": _type_check_dict,
"tuple": _type_check_tuple,
"optional": _type_check_optional,
"union": _type_check_union,
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 7bec054b73..cf031c014c 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2186,12 +2186,19 @@ class Schedule(Object):
########## Schedule: Annotation ##########
+ PrimAnnotationValueT = Union[str, int, float, ExprRV]
+ AnnotationValueT = Union[
+ PrimAnnotationValueT,
+ List[PrimAnnotationValueT],
+ Dict[str, Union[PrimAnnotationValueT, List[PrimAnnotationValueT]]],
+ ]
+
@type_checked
def annotate(
self,
block_or_loop: Union[BlockRV, LoopRV],
ann_key: str,
- ann_val: Union[str, int, float, ExprRV, List[Union[str, int, float, ExprRV]]],
+ ann_val: AnnotationValueT,
) -> None:
"""Annotate a block/loop with a key value pair
@@ -2201,7 +2208,7 @@ class Schedule(Object):
The block/loop to be annotated
ann_key : str
The annotation key
- ann_val : Union[str, int, float, ExprRV, List[Union[str, int, float, ExprRV]]]
+ ann_val : AnnotationValueT
The annotation value
Examples
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 9d0bb885e2..a73d6165a2 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -699,6 +699,21 @@ ObjectRef ConcreteScheduleNode::CheckAndGetAnnotationValue(const ObjectRef& ann_
}
return std::move(result);
}
+ if (const auto* dict = ann_val.as<MapNode>()) {
+ Map<String, ObjectRef> result;
+ for (auto it = dict->begin(); it != dict->end(); ++it) {
+ const auto& key = it->first;
+ auto value = CheckAndGetAnnotationValue(it->second);
+ if (const StringImmNode* imm = key.as<StringImmNode>()) {
+ result.Set(imm->value, value);
+ } else if (key->IsInstance<StringObj>()) {
+ result.Set(Downcast<String>(key), value);
+ } else {
+ LOG(FATAL) << "TypeError: annotation dict key expect to be String or StringImm";
+ }
+ }
+ return std::move(result);
+ }
LOG(FATAL)
<< "TypeError: Only strings, integers, floats, ExprRVs and Arrays are supported for now, but "
<< "gets: " << ann_val->GetTypeKey();
diff --git a/src/tir/schedule/instruction_traits.h b/src/tir/schedule/instruction_traits.h
index 14d05a4a34..56c69224fe 100644
--- a/src/tir/schedule/instruction_traits.h
+++ b/src/tir/schedule/instruction_traits.h
@@ -22,7 +22,9 @@
#include <tvm/tir/schedule/instruction.h>
#include <tvm/tir/schedule/schedule.h>
+#include <algorithm>
#include <sstream>
+#include <string>
#include <utility>
#include <vector>
@@ -447,6 +449,28 @@ inline void PythonAPICall::AsPythonString(const ObjectRef& obj, std::ostream& os
AsPythonString(e, os);
}
os << ']';
+ } else if (const auto* dict = obj.as<MapNode>()) {
+ os << '{';
+ bool is_first = true;
+ std::vector<std::pair<std::string, std::string>> dict_items;
+ for (auto it = dict->begin(); it != dict->end(); ++it) {
+ std::ostringstream ks;
+ AsPythonString(it->first, ks);
+ std::ostringstream vs;
+ AsPythonString(it->second, vs);
+ dict_items.emplace_back(ks.str(), vs.str());
+ }
+ std::sort(dict_items.begin(), dict_items.end(),
+ [](const auto& p1, const auto& p2) { return p1.first < p2.first; });
+ for (const auto& kv : dict_items) {
+ if (is_first) {
+ is_first = false;
+ } else {
+ os << ", ";
+ }
+ os << '\"' << kv.first << "\": " << kv.second;
+ }
+ os << '}';
} else {
LOG(FATAL) << "ValueError: Cannot translate type '" << obj->GetTypeKey()
<< "' to python. Its value is: " << obj;
diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc
index 9fa86917c5..395613bf4c 100644
--- a/src/tir/schedule/trace.cc
+++ b/src/tir/schedule/trace.cc
@@ -112,6 +112,9 @@ Array<ObjectRef> TranslateInputRVs(
} else if (input->IsInstance<ArrayNode>()) {
// Case 4: array
results.push_back(TranslateInputRVs(Downcast<Array<ObjectRef>>(input), rv_names));
+ } else if (input->IsInstance<MapNode>()) {
+ // Case 5: dict
+ results.push_back(input);
} else if (input->IsInstance<BlockRVNode>() || inputs->IsInstance<LoopRVNode>() ||
inputs->IsInstance<VarNode>()) {
LOG(FATAL) << "IndexError: Random variable is not defined " << input;
@@ -139,6 +142,11 @@ Array<ObjectRef> TranslateInputRVs(const Array<ObjectRef>& inputs,
results.push_back(TranslateInputRVs(Downcast<Array<ObjectRef>>(input), named_rvs));
continue;
}
+ // Case 5. dict
+ if (input->IsInstance<MapNode>()) {
+ results.push_back(input);
+ continue;
+ }
const auto* str = input.as<StringObj>();
CHECK(str) << "TypeError: Expect String, but gets: " << input->GetTypeKey();
CHECK_GT(str->size, 0) << "ValueError: Empty string is not allowed in input names";
diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py
index c479555590..41844a868e 100644
--- a/tests/python/unittest/test_tir_schedule_utilities.py
+++ b/tests/python/unittest/test_tir_schedule_utilities.py
@@ -67,7 +67,7 @@ def matmul_relu_ann1(a: T.handle, b: T.handle, d: T.handle) -> None:
B = T.match_buffer(b, (1024, 1024))
C = T.alloc_buffer((1024, 1024))
D = T.match_buffer(d, (1024, 1024))
- for i in T.serial(0, 1024, annotations={"test1": "aaa"}):
+ for i in T.serial(0, 1024, annotations={"test1": "aaa", "test4": {"arr": [0, 0], "key": 3}}):
for j in T.serial(0, 1024, annotations={"test2": 612, "test3": ["aa", 1]}):
for k in T.serial(0, 1024):
with T.block("matmul"):
@@ -92,7 +92,7 @@ def matmul_relu_ann2(a: T.handle, b: T.handle, d: T.handle) -> None:
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = 0.0
- T.block_attr({"test1": "aaa"})
+ T.block_attr({"test1": "aaa", "test4": {"arr": [0, 0], "key": 3}})
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
for i, j in T.grid(1024, 1024):
with T.block("relu"):
@@ -279,11 +279,13 @@ def test_annotate_unannotate_loop():
sch.annotate(sch.get_loops(matmul)[0], "test1", "aaa")
sch.annotate(sch.get_loops(matmul)[1], "test2", 612)
sch.annotate(sch.get_loops(matmul)[1], "test3", ["aa", 1])
+ sch.annotate(sch.get_loops(matmul)[0], "test4", {"arr": [0, 0], "key": 3})
tvm.ir.assert_structural_equal(sch.mod["main"], matmul_relu_ann1)
verify_trace_roundtrip(sch=sch, mod=matmul_relu)
sch.unannotate(sch.get_loops(matmul)[0], "test1")
sch.unannotate(sch.get_loops(matmul)[1], "test2")
sch.unannotate(sch.get_loops(matmul)[1], "test3")
+ sch.unannotate(sch.get_loops(matmul)[0], "test4")
verify_trace_roundtrip(sch=sch, mod=matmul_relu)
@@ -294,11 +296,13 @@ def test_annotate_unannotate_block():
sch.annotate(matmul, "test1", "aaa")
sch.annotate(relu, "test2", 0.22)
sch.annotate(relu, "test3", ["aa", 1])
+ sch.annotate(matmul, "test4", {"arr": [0, 0], "key": 3})
tvm.ir.assert_structural_equal(sch.mod["main"], matmul_relu_ann2)
verify_trace_roundtrip(sch=sch, mod=matmul_relu)
sch.unannotate(matmul, "test1")
sch.unannotate(relu, "test2")
sch.unannotate(relu, "test3")
+ sch.unannotate(matmul, "test4")
verify_trace_roundtrip(sch=sch, mod=matmul_relu)
diff --git a/tests/python/unittest/test_type_annotation_checker.py b/tests/python/unittest/test_type_annotation_checker.py
index 9f6f29c7ff..e84ae043d3 100644
--- a/tests/python/unittest/test_type_annotation_checker.py
+++ b/tests/python/unittest/test_type_annotation_checker.py
@@ -16,7 +16,8 @@
# under the License.
"""Test type checker based on python's type annotations"""
-from typing import List, Tuple, Union
+import sys
+from typing import Dict, List, Tuple, Union
import pytest
@@ -44,6 +45,13 @@ test_cases = [
["5"],
],
},
+ {
+ "type_annotation": Dict[str, int],
+ "positive_cases": [
+ {"key1": 0, "key2": 1, "key3": -1},
+ ],
+ "negative_cases": [None, [1], {1: "1"}],
+ },
{
"type_annotation": Tuple[int],
"positive_cases": [