You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2022/06/02 06:57:39 UTC

[tvm] branch main updated: [TE] Fix `te.CreatePrimFunc` for 0-dim computation (#11518)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c6d7ecd0b5 [TE] Fix `te.CreatePrimFunc` for 0-dim computation (#11518)
c6d7ecd0b5 is described below

commit c6d7ecd0b5e71796c79b001f439322ae1d0ddbe0
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Wed Jun 1 23:57:33 2022 -0700

    [TE] Fix `te.CreatePrimFunc` for 0-dim computation (#11518)
    
    For 0-dimensional computation, `te.CreatePrimFunc` creates an opaque block with 0 block iters,
    which is mistakenly passed into TVMScript auto-completion, causing it to fail to add the root block properly.
    As an example,
    
    ```python
    >> from tvm import te
    >> a = te.placeholder((), name="a", dtype="int32")
    >> b = te.placeholder((), name="b", dtype="int32")
    >> c = te.compute(a.shape, lambda *i: a(*i) + b(*i), name="c")
    >> f = te.create_prim_func([a, b, c])
    >> print(f.body.block.reads)
    [a[], b[]]
    >> print(f.body.block.writes)
    [c[]]
    ```
    
    This PR fixes this issue by enforcing the consistency that `te.CreatePrimFunc`
    always creates schedulable blocks with at least 1 block iter:
    
    ```python
    @T.prim_func
    def func(a: T.Buffer[(), "int32"], b: T.Buffer[(), "int32"], c: T.Buffer[(), "int32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        with T.block("c"):
            vi = T.axis.spatial(1, 0)
            T.reads(a[()], b[()])
            T.writes(c[()])
            c[()] = a[()] + b[()]
    ```
---
 src/meta_schedule/task_scheduler/task_scheduler.cc |  2 ++
 src/te/operation/create_primfunc.cc                |  8 ++++++-
 tests/python/unittest/test_te_create_primfunc.py   | 27 ++++++++++++++++++++++
 3 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc
index 7485f4e076..fd1d95cd1f 100644
--- a/src/meta_schedule/task_scheduler/task_scheduler.cc
+++ b/src/meta_schedule/task_scheduler/task_scheduler.cc
@@ -94,6 +94,8 @@ void SendToRunner(const Runner& runner, const TuneContext& context, PackedFunc l
 
 void TaskSchedulerNode::InitializeTask(int task_id) {
   TuneContext task = this->tasks[task_id];
+  TVM_PY_LOG(INFO, this->logging_func)
+      << "Initializing Task #" << task_id << ": " << task->task_name;
   TVM_PY_LOG(INFO, task->logging_func)
       << "Initializing Task #" << task_id << ": " << task->task_name;
   CHECK(task->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined";
diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index 03ad551c68..27cfdd605c 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -264,6 +264,12 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
   }
   // Set script_parsing_detect_access
   annotations.Set(tir::attr::script_parsing_detect_access, IntImm(DataType::Int(32), 3));
+  if (iter_vars.empty()) {
+    IterVar iter(Range::FromMinExtent(0, 1), Var("vi", DataType::Int(32)), IterVarType::kDataPar);
+    PrimExpr binding(0);
+    iter_vars.push_back(iter);
+    bindings.push_back(binding);
+  }
 
   // Step 6. Create Block and BlockRealize.
   return BlockRealize(/*iter_values=*/std::move(bindings),
@@ -454,7 +460,7 @@ PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list) {
                             {{"global_symbol", String("main")}, {"tir.noalias", Bool(true)}});
   const auto* complete = runtime::Registry::Get("script.Complete");
   ICHECK(complete);
-  func = (*complete)(func, info.root_alloc);
+  func = (*complete)(std::move(func), info.root_alloc);
   return LayoutFreePlaceholdersNormalizer().Process(std::move(func));
 }
 
diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py
index 014ca71a81..5d9ad003b4 100644
--- a/tests/python/unittest/test_te_create_primfunc.py
+++ b/tests/python/unittest/test_te_create_primfunc.py
@@ -524,6 +524,32 @@ def test_int64_indices():
     assert loop.extent.dtype == "int64"
 
 
+def test_zero_dim_add():
+    def te_func():
+        a = te.placeholder((), name="a", dtype="int32")
+        b = te.placeholder((), name="b", dtype="int32")
+        c = te.compute(a.shape, lambda *i: a(*i) + b(*i), name="c")
+        return [a, b, c]
+
+    @T.prim_func
+    def expected(
+        a: T.Buffer[(), "int32"],
+        b: T.Buffer[(), "int32"],
+        c: T.Buffer[(), "int32"],
+    ) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            with T.block("c"):
+                vi = T.axis.spatial(1, 0)
+                T.reads(a[()], b[()])
+                T.writes(c[()])
+                c[()] = a[()] + b[()]
+
+    _check_workload(te_func, expected)
+
+
 if __name__ == "__main__":
     test_unique_name_complete_block()
     test_unique_name_reduction_block()
@@ -541,3 +567,4 @@ if __name__ == "__main__":
     test_argmax_idx_val()
     test_argmax_val_idx()
     test_int64_indices()
+    test_zero_dim_add()