You are viewing a plain-text version of this content. (The canonical link to the original was a hyperlink and was not preserved in this plain-text copy.)
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/02/24 21:26:08 UTC
[tvm] branch main updated: Support creating Bool constants in the pattern_utils (#7507)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 88a4fdd Support creating Bool constants in the pattern_utils (#7507)
88a4fdd is described below
commit 88a4fdddc2bdd41a62baaaa55dbd4c524d25933d
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Wed Feb 24 14:25:50 2021 -0700
Support creating Bool constants in the pattern_utils (#7507)
---
src/relay/transforms/pattern_utils.h | 3 +++
tests/python/relay/test_pass_simplify_expr.py | 2 +-
2 files changed, 4 insertions(+), 1 deletion(-)
diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h
index bc0fcc9..c1eebde 100644
--- a/src/relay/transforms/pattern_utils.h
+++ b/src/relay/transforms/pattern_utils.h
@@ -86,6 +86,9 @@ namespace relay {
} else if (type == DataType::UInt(8)) { \
typedef uint8_t DType; \
{ __VA_ARGS__ } \
+ } else if (type == DataType::Bool()) { \
+ typedef bool DType; \
+ { __VA_ARGS__ } \
} else if ((*tvm::runtime::Registry::Get("runtime._datatype_get_type_registered"))( \
static_cast<uint8_t>(type.code()))) { \
typedef double DType; \
diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py
index 3d925bc..423f0a4 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -117,7 +117,7 @@ def test_simplify_full_elementwise():
assert tvm.ir.structural_equal(zz, after)
for shape in [[10], [10, 10], [10, 10, 10]]:
- for dtype in ["float32", "int32"]:
+ for dtype in ["float32", "int32", "bool"]:
for value in [0, 1, 2]:
validate(shape, value, dtype)