You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/06/02 08:44:10 UTC
[tvm] branch main updated: [Bugfix][TIR] Handle bool tensor in FlattenBuffer (#11532)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4c513b9de3 [Bugfix][TIR] Handle bool tensor in FlattenBuffer (#11532)
4c513b9de3 is described below
commit 4c513b9de3ebfdf4a1356f0daf7350e74ca74005
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Thu Jun 2 01:44:05 2022 -0700
[Bugfix][TIR] Handle bool tensor in FlattenBuffer (#11532)
This PR fixes an existing bug in TIR lowering (the FlattenBuffer pass) where building the TIR below fails with an error:
```python
@T.prim_func
def func(a: T.Buffer[10, "bool"], b: T.Buffer[10, "bool"]) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
for i in T.serial(10):
with T.block("b"):
vi = T.axis.spatial(10, i)
b[vi] = a[vi]
tvm.build(func, target="llvm")
```
The error message is:
```
File "/root/Projects/tvm-dev/src/tir/transforms/flatten_buffer.cc", line 173
TVMError:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
Check failed: store->buffer->dtype == DataType::Int(8) (bool vs. int8) : Expected int8 backing array for boolean tensor
```
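This PR fixes the bug by flattening a boolean buffer store before checking its backing dtype, so the check sees the int8 backing buffer instead of the original bool buffer. It also gives boolean buffer loads an int8 dtype before casting them back to bool.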
---
src/tir/transforms/flatten_buffer.cc | 18 +++++------
.../unittest/test_tir_transform_flatten_buffer.py | 37 +++++++++++++++++++++-
2 files changed, 45 insertions(+), 10 deletions(-)
diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc
index c7cc51d271..21de191db0 100644
--- a/src/tir/transforms/flatten_buffer.cc
+++ b/src/tir/transforms/flatten_buffer.cc
@@ -53,9 +53,7 @@ class BufferFlattener : public StmtExprMutator {
static PrimFunc Flatten(PrimFunc func) {
Map<Var, Buffer> preflattened_buffer_map =
Merge(func->buffer_map, func->preflattened_buffer_map);
-
auto pass = BufferFlattener(func->buffer_map);
-
auto writer = func.CopyOnWrite();
writer->body = pass.VisitStmt(func->body);
writer->preflattened_buffer_map = preflattened_buffer_map;
@@ -137,7 +135,7 @@ class BufferFlattener : public StmtExprMutator {
} else {
PrimExpr expr = it->second;
if (expr.dtype() != var.dtype()) {
- expr = Cast(var.dtype(), std::move(expr));
+ expr = tvm::cast(var.dtype(), std::move(expr));
}
return expr;
}
@@ -164,33 +162,35 @@ class BufferFlattener : public StmtExprMutator {
Stmt VisitStmt_(const BufferStoreNode* op) final {
BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+ bool store_returns_bool = (op->value.dtype() == DataType::Bool());
+ store = VisitBufferAccess(store);
// Handle casts from the value's dtype to the dtype of the
// backing array.
// TODO(Lunderberg): Move the handling of boolean into a
// dedicated pass.
- if (store->value.dtype() == DataType::Bool()) {
+ if (store_returns_bool) {
ICHECK_EQ(store->buffer->dtype, DataType::Int(8))
<< "Expected int8 backing array for boolean tensor";
auto writer = store.CopyOnWrite();
- writer->value = tir::Cast(DataType::Int(8), store->value);
+ writer->value = tvm::cast(DataType::Int(8), store->value);
+ return store;
}
- auto flattened_indices = store->buffer->ElemOffset(store->indices);
- return VisitBufferAccess(std::move(store));
+ return store;
}
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
bool load_returns_bool = (op->dtype == DataType::Bool());
BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
load = VisitBufferAccess(load);
-
// Handle casts from dtype of the backing array to value's dtype.
// TODO(Lunderberg): Move the handling of boolean into a
// dedicated pass.
if (load_returns_bool) {
ICHECK_EQ(load->buffer->dtype, DataType::Int(8))
<< "Expected int8 backing array for boolean tensor";
- return tir::Cast(DataType::Bool(), load);
+ load.CopyOnWrite()->dtype = DataType::Int(8);
+ return tvm::cast(DataType::Bool(), load);
} else {
return std::move(load);
}
diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py
index 65be43aba3..f1a33a4fb2 100644
--- a/tests/python/unittest/test_tir_transform_flatten_buffer.py
+++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
-from tvm import tir, te
+from tvm import te, tir
from tvm.script import tir as T
@@ -268,6 +268,33 @@ def annotated_loops(a: T.handle) -> None:
A[i] = 0.0
+@T.prim_func
+def boolean_handling_before(a: T.Buffer[10, "bool"], b: T.Buffer[10, "bool"]) -> None:
+ for i0 in T.serial(10):
+ with T.block("b"):
+ T.reads(a[i0])
+ T.writes(b[i0])
+ b[i0] = a[i0]
+
+
+@T.prim_func
+def boolean_handling_after(a: T.Buffer[10, "int8"], b: T.Buffer[10, "int8"]) -> None:
+ T.preflattened_buffer(a, [10], dtype="bool", data=a.data)
+ T.preflattened_buffer(b, [10], dtype="bool", data=b.data)
+ # body
+ for i0 in T.serial(10):
+ b[i0] = T.cast(T.cast(a[i0], "bool"), "int8")
+
+
+@T.prim_func
+def boolean_handle_after(a: T.Buffer[10, "int8"], b: T.Buffer[10, "int8"]) -> None:
+ T.preflattened_buffer(a, [10], dtype="bool", data=a.data)
+ T.preflattened_buffer(b, [10], dtype="bool", data=b.data)
+ # body
+ for i0 in T.serial(10):
+ b[i0] = T.cast(T.cast(a[i0], "bool"), "int8")
+
+
def test_elementwise():
_check(compacted_elementwise_func, flattened_elementwise_func)
@@ -319,6 +346,13 @@ def test_annotated_loops():
tvm.ir.assert_structural_equal(attr3.value, tvm.tir.FloatImm("float32", 0.0))
+def test_boolean_handling():
+ _check(boolean_handling_before, boolean_handling_after)
+ # mod = tvm.IRModule.from_expr(boolean_handling_before)
+ # mod = tvm.tir.transform.FlattenBuffer()(mod)
+ # print(mod.script())
+
+
if __name__ == "__main__":
test_elementwise()
test_gpu_workload()
@@ -329,3 +363,4 @@ if __name__ == "__main__":
test_strided_buffer()
test_lower_te()
test_annotated_loops()
+ test_boolean_handling()