You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2022/06/02 06:57:39 UTC
[tvm] branch main updated: [TE] Fix `te.CreatePrimFunc` for 0-dim computation (#11518)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c6d7ecd0b5 [TE] Fix `te.CreatePrimFunc` for 0-dim computation (#11518)
c6d7ecd0b5 is described below
commit c6d7ecd0b5e71796c79b001f439322ae1d0ddbe0
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Wed Jun 1 23:57:33 2022 -0700
[TE] Fix `te.CreatePrimFunc` for 0-dim computation (#11518)
For a 0-dimensional computation, `te.CreatePrimFunc` creates an opaque block with zero block iters.
TVMScript auto-completion then fails to add a proper root block around it, so the computation block itself ends up at the top of the function body.
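For example, the top-level block of the resulting function carries the computation's reads and writes, which a root block should not have: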
```python
>>> from tvm import te
>>> a = te.placeholder((), name="a", dtype="int32")
>>> b = te.placeholder((), name="b", dtype="int32")
>>> c = te.compute(a.shape, lambda *i: a(*i) + b(*i), name="c")
>>> f = te.create_prim_func([a, b, c])
>>> print(f.body.block.reads)
[a[], b[]]
>>> print(f.body.block.writes)
[c[]]
```
This PR fixes the issue by ensuring that `te.CreatePrimFunc` always creates schedulable blocks with at least one block iter.
For a 0-dimensional computation, it adds a spatial block iter `vi` with extent 1, bound to 0:
```python
@T.prim_func
def func(a: T.Buffer[(), "int32"], b: T.Buffer[(), "int32"], c: T.Buffer[(), "int32"]) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    with T.block("c"):
        vi = T.axis.spatial(1, 0)
        T.reads(a[()], b[()])
        T.writes(c[()])
        c[()] = a[()] + b[()]
```
---
src/meta_schedule/task_scheduler/task_scheduler.cc | 2 ++
src/te/operation/create_primfunc.cc | 8 ++++++-
tests/python/unittest/test_te_create_primfunc.py | 27 ++++++++++++++++++++++
3 files changed, 36 insertions(+), 1 deletion(-)
diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc
index 7485f4e076..fd1d95cd1f 100644
--- a/src/meta_schedule/task_scheduler/task_scheduler.cc
+++ b/src/meta_schedule/task_scheduler/task_scheduler.cc
@@ -94,6 +94,8 @@ void SendToRunner(const Runner& runner, const TuneContext& context, PackedFunc l
void TaskSchedulerNode::InitializeTask(int task_id) {
TuneContext task = this->tasks[task_id];
+ TVM_PY_LOG(INFO, this->logging_func)
+ << "Initializing Task #" << task_id << ": " << task->task_name;
TVM_PY_LOG(INFO, task->logging_func)
<< "Initializing Task #" << task_id << ": " << task->task_name;
CHECK(task->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined";
diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index 03ad551c68..27cfdd605c 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -264,6 +264,12 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
}
// Set script_parsing_detect_access
annotations.Set(tir::attr::script_parsing_detect_access, IntImm(DataType::Int(32), 3));
+ if (iter_vars.empty()) {
+ IterVar iter(Range::FromMinExtent(0, 1), Var("vi", DataType::Int(32)), IterVarType::kDataPar);
+ PrimExpr binding(0);
+ iter_vars.push_back(iter);
+ bindings.push_back(binding);
+ }
// Step 6. Create Block and BlockRealize.
return BlockRealize(/*iter_values=*/std::move(bindings),
@@ -454,7 +460,7 @@ PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list) {
{{"global_symbol", String("main")}, {"tir.noalias", Bool(true)}});
const auto* complete = runtime::Registry::Get("script.Complete");
ICHECK(complete);
- func = (*complete)(func, info.root_alloc);
+ func = (*complete)(std::move(func), info.root_alloc);
return LayoutFreePlaceholdersNormalizer().Process(std::move(func));
}
diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py
index 014ca71a81..5d9ad003b4 100644
--- a/tests/python/unittest/test_te_create_primfunc.py
+++ b/tests/python/unittest/test_te_create_primfunc.py
@@ -524,6 +524,32 @@ def test_int64_indices():
assert loop.extent.dtype == "int64"
+def test_zero_dim_add():
+ def te_func():
+ a = te.placeholder((), name="a", dtype="int32")
+ b = te.placeholder((), name="b", dtype="int32")
+ c = te.compute(a.shape, lambda *i: a(*i) + b(*i), name="c")
+ return [a, b, c]
+
+ @T.prim_func
+ def expected(
+ a: T.Buffer[(), "int32"],
+ b: T.Buffer[(), "int32"],
+ c: T.Buffer[(), "int32"],
+ ) -> None:
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ with T.block("root"):
+ T.reads()
+ T.writes()
+ with T.block("c"):
+ vi = T.axis.spatial(1, 0)
+ T.reads(a[()], b[()])
+ T.writes(c[()])
+ c[()] = a[()] + b[()]
+
+ _check_workload(te_func, expected)
+
+
if __name__ == "__main__":
test_unique_name_complete_block()
test_unique_name_reduction_block()
@@ -541,3 +567,4 @@ if __name__ == "__main__":
test_argmax_idx_val()
test_argmax_val_idx()
test_int64_indices()
+ test_zero_dim_add()