This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 88a4fdd  Support creating Bool constants in the pattern_utils (#7507)
88a4fdd is described below

commit 88a4fdddc2bdd41a62baaaa55dbd4c524d25933d
Author: Matthew Brookhart <[email protected]>
AuthorDate: Wed Feb 24 14:25:50 2021 -0700

    Support creating Bool constants in the pattern_utils (#7507)
---
 src/relay/transforms/pattern_utils.h          | 3 +++
 tests/python/relay/test_pass_simplify_expr.py | 2 +-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h
index bc0fcc9..c1eebde 100644
--- a/src/relay/transforms/pattern_utils.h
+++ b/src/relay/transforms/pattern_utils.h
@@ -86,6 +86,9 @@ namespace relay {
   } else if (type == DataType::UInt(8)) {                                              \
     typedef uint8_t DType;                                                             \
     { __VA_ARGS__ }                                                                    \
+  } else if (type == DataType::Bool()) {                                               \
+    typedef bool DType;                                                                \
+    { __VA_ARGS__ }                                                                    \
   } else if ((*tvm::runtime::Registry::Get("runtime._datatype_get_type_registered"))( \
                  static_cast<uint8_t>(type.code()))) {                                 \
     typedef double DType;                                                              \
diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py
index 3d925bc..423f0a4 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -117,7 +117,7 @@ def test_simplify_full_elementwise():
                 assert tvm.ir.structural_equal(zz, after)
 
     for shape in [[10], [10, 10], [10, 10, 10]]:
-        for dtype in ["float32", "int32"]:
+        for dtype in ["float32", "int32", "bool"]:
             for value in [0, 1, 2]:
                 validate(shape, value, dtype)
 

Reply via email to