You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/02/24 21:26:08 UTC

[tvm] branch main updated: Support creating Bool constants in the pattern_utils (#7507)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 88a4fdd  Support creating Bool constants in the pattern_utils (#7507)
88a4fdd is described below

commit 88a4fdddc2bdd41a62baaaa55dbd4c524d25933d
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Wed Feb 24 14:25:50 2021 -0700

    Support creating Bool constants in the pattern_utils (#7507)
---
 src/relay/transforms/pattern_utils.h          | 3 +++
 tests/python/relay/test_pass_simplify_expr.py | 2 +-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h
index bc0fcc9..c1eebde 100644
--- a/src/relay/transforms/pattern_utils.h
+++ b/src/relay/transforms/pattern_utils.h
@@ -86,6 +86,9 @@ namespace relay {
   } else if (type == DataType::UInt(8)) {                                             \
     typedef uint8_t DType;                                                            \
     { __VA_ARGS__ }                                                                   \
+  } else if (type == DataType::Bool()) {                                              \
+    typedef bool DType;                                                               \
+    { __VA_ARGS__ }                                                                   \
   } else if ((*tvm::runtime::Registry::Get("runtime._datatype_get_type_registered"))( \
                  static_cast<uint8_t>(type.code()))) {                                \
     typedef double DType;                                                             \
diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py
index 3d925bc..423f0a4 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -117,7 +117,7 @@ def test_simplify_full_elementwise():
                 assert tvm.ir.structural_equal(zz, after)
 
     for shape in [[10], [10, 10], [10, 10, 10]]:
-        for dtype in ["float32", "int32"]:
+        for dtype in ["float32", "int32", "bool"]:
             for value in [0, 1, 2]:
                 validate(shape, value, dtype)