You are viewing a plain text version of this content; the canonical (HTML) version is available in the Apache mailing-list archive.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/06/02 08:44:10 UTC

[tvm] branch main updated: [Bugfix][TIR] Handle bool tensor in FlattenBuffer (#11532)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4c513b9de3 [Bugfix][TIR] Handle bool tensor in FlattenBuffer (#11532)
4c513b9de3 is described below

commit 4c513b9de3ebfdf4a1356f0daf7350e74ca74005
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Thu Jun 2 01:44:05 2022 -0700

    [Bugfix][TIR] Handle bool tensor in FlattenBuffer (#11532)
    
    This PR fixes an existing bug in TIR lowering where the TIR below triggers an error:
    
    ```python
    @T.prim_func
    def func(a: T.Buffer[10, "bool"], b: T.Buffer[10, "bool"]) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i in T.serial(10):
            with T.block("b"):
                vi = T.axis.spatial(10, i)
                b[vi] = a[vi]
    
    tvm.build(func, target="llvm")
    ```
    
    The error message is:
    
    ```
      File "/root/Projects/tvm-dev/src/tir/transforms/flatten_buffer.cc", line 173
    TVMError:
    ---------------------------------------------------------------
    An error occurred during the execution of TVM.
    For more information, please see: https://tvm.apache.org/docs/errors.html
    ---------------------------------------------------------------
    
    Check failed: store->buffer->dtype == DataType::Int(8) (bool vs. int8) : Expected int8 backing array
    for boolean tensor
    ```
    
    This PR fixes this behavior.
---
 src/tir/transforms/flatten_buffer.cc               | 18 +++++------
 .../unittest/test_tir_transform_flatten_buffer.py  | 37 +++++++++++++++++++++-
 2 files changed, 45 insertions(+), 10 deletions(-)

diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc
index c7cc51d271..21de191db0 100644
--- a/src/tir/transforms/flatten_buffer.cc
+++ b/src/tir/transforms/flatten_buffer.cc
@@ -53,9 +53,7 @@ class BufferFlattener : public StmtExprMutator {
   static PrimFunc Flatten(PrimFunc func) {
     Map<Var, Buffer> preflattened_buffer_map =
         Merge(func->buffer_map, func->preflattened_buffer_map);
-
     auto pass = BufferFlattener(func->buffer_map);
-
     auto writer = func.CopyOnWrite();
     writer->body = pass.VisitStmt(func->body);
     writer->preflattened_buffer_map = preflattened_buffer_map;
@@ -137,7 +135,7 @@ class BufferFlattener : public StmtExprMutator {
     } else {
       PrimExpr expr = it->second;
       if (expr.dtype() != var.dtype()) {
-        expr = Cast(var.dtype(), std::move(expr));
+        expr = tvm::cast(var.dtype(), std::move(expr));
       }
       return expr;
     }
@@ -164,33 +162,35 @@ class BufferFlattener : public StmtExprMutator {
 
   Stmt VisitStmt_(const BufferStoreNode* op) final {
     BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+    bool store_returns_bool = (op->value.dtype() == DataType::Bool());
+    store = VisitBufferAccess(store);
 
     // Handle casts from the value's dtype to the dtype of the
     // backing array.
     // TODO(Lunderberg): Move the handling of boolean into a
     // dedicated pass.
-    if (store->value.dtype() == DataType::Bool()) {
+    if (store_returns_bool) {
       ICHECK_EQ(store->buffer->dtype, DataType::Int(8))
           << "Expected int8 backing array for boolean tensor";
       auto writer = store.CopyOnWrite();
-      writer->value = tir::Cast(DataType::Int(8), store->value);
+      writer->value = tvm::cast(DataType::Int(8), store->value);
+      return store;
     }
-    auto flattened_indices = store->buffer->ElemOffset(store->indices);
-    return VisitBufferAccess(std::move(store));
+    return store;
   }
 
   PrimExpr VisitExpr_(const BufferLoadNode* op) final {
     bool load_returns_bool = (op->dtype == DataType::Bool());
     BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
     load = VisitBufferAccess(load);
-
     // Handle casts from dtype of the backing array to value's dtype.
     // TODO(Lunderberg): Move the handling of boolean into a
     // dedicated pass.
     if (load_returns_bool) {
       ICHECK_EQ(load->buffer->dtype, DataType::Int(8))
           << "Expected int8 backing array for boolean tensor";
-      return tir::Cast(DataType::Bool(), load);
+      load.CopyOnWrite()->dtype = DataType::Int(8);
+      return tvm::cast(DataType::Bool(), load);
     } else {
       return std::move(load);
     }
diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py
index 65be43aba3..f1a33a4fb2 100644
--- a/tests/python/unittest/test_tir_transform_flatten_buffer.py
+++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
-from tvm import tir, te
+from tvm import te, tir
 from tvm.script import tir as T
 
 
@@ -268,6 +268,33 @@ def annotated_loops(a: T.handle) -> None:
         A[i] = 0.0
 
 
+@T.prim_func
+def boolean_handling_before(a: T.Buffer[10, "bool"], b: T.Buffer[10, "bool"]) -> None:
+    for i0 in T.serial(10):
+        with T.block("b"):
+            T.reads(a[i0])
+            T.writes(b[i0])
+            b[i0] = a[i0]
+
+
+@T.prim_func
+def boolean_handling_after(a: T.Buffer[10, "int8"], b: T.Buffer[10, "int8"]) -> None:
+    T.preflattened_buffer(a, [10], dtype="bool", data=a.data)
+    T.preflattened_buffer(b, [10], dtype="bool", data=b.data)
+    # body
+    for i0 in T.serial(10):
+        b[i0] = T.cast(T.cast(a[i0], "bool"), "int8")
+
+
+@T.prim_func
+def boolean_handle_after(a: T.Buffer[10, "int8"], b: T.Buffer[10, "int8"]) -> None:
+    T.preflattened_buffer(a, [10], dtype="bool", data=a.data)
+    T.preflattened_buffer(b, [10], dtype="bool", data=b.data)
+    # body
+    for i0 in T.serial(10):
+        b[i0] = T.cast(T.cast(a[i0], "bool"), "int8")
+
+
 def test_elementwise():
     _check(compacted_elementwise_func, flattened_elementwise_func)
 
@@ -319,6 +346,13 @@ def test_annotated_loops():
     tvm.ir.assert_structural_equal(attr3.value, tvm.tir.FloatImm("float32", 0.0))
 
 
+def test_boolean_handling():
+    _check(boolean_handling_before, boolean_handling_after)
+    # mod = tvm.IRModule.from_expr(boolean_handling_before)
+    # mod = tvm.tir.transform.FlattenBuffer()(mod)
+    # print(mod.script())
+
+
 if __name__ == "__main__":
     test_elementwise()
     test_gpu_workload()
@@ -329,3 +363,4 @@ if __name__ == "__main__":
     test_strided_buffer()
     test_lower_te()
     test_annotated_loops()
+    test_boolean_handling()