You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wu...@apache.org on 2021/12/09 18:05:52 UTC

[tvm] branch main updated: [FIX] Simplify during create prim func (#9691)

This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6f623a9  [FIX] Simplify during create prim func (#9691)
6f623a9 is described below

commit 6f623a96f9d2f28ef98762ca6167988bb480f085
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Fri Dec 10 02:05:27 2021 +0800

    [FIX] Simplify during create prim func (#9691)
---
 src/te/operation/create_primfunc.cc              | 34 ++++++++++++++++--------
 tests/python/unittest/test_te_create_primfunc.py | 16 +++++++++--
 2 files changed, 37 insertions(+), 13 deletions(-)

diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index d90681a..e43013c 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -17,6 +17,7 @@
  * under the License.
  */
 
+#include <tvm/arith/analyzer.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/tir/function.h>
 #include <tvm/tir/stmt_functor.h>
@@ -83,19 +84,21 @@ struct CreateFuncInfo {
 
 BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::Tensor& tensor,
                                      Array<PrimExpr> bindings, PrimExpr expr_body,
-                                     CreateFuncInfo* info) {
+                                     CreateFuncInfo* info, arith::Analyzer* analyzer) {
   // Step 1. Push_back data_par axis and reduce_axis into block_vars.
   Array<IterVar> iter_vars;
   std::unordered_map<const VarNode*, PrimExpr> var_map;
   iter_vars.reserve(compute_op->axis.size() + compute_op->reduce_axis.size());
-  auto f_push_block_vars = [&iter_vars, &var_map](const Array<IterVar>& iters) {
+  auto f_push_block_vars = [&iter_vars, &var_map, &analyzer](const Array<IterVar>& iters) {
     for (IterVar iter_var : iters) {
       // Create new var
       Var new_var(iter_var->var->name_hint, iter_var->var->dtype);
       var_map[iter_var->var.get()] = new_var;
 
       IterVarNode* iter_var_node = iter_var.CopyOnWrite();
-      iter_var_node->dom = Range::FromMinExtent(iter_var->dom->min, iter_var->dom->extent);
+      const PrimExpr& dom_min = analyzer->Simplify(iter_var->dom->min);
+      const PrimExpr& dom_extent = analyzer->Simplify(iter_var->dom->extent);
+      iter_var_node->dom = Range::FromMinExtent(dom_min, dom_extent);
       iter_var_node->var = new_var;
       iter_vars.push_back(iter_var);
     }
@@ -130,11 +133,14 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::
     const PrimExpr& lhs = BufferLoad(buffer, indices);
     const PrimExpr& rhs = Substitute(info->transformer(reduce->source[0]), var_map);
     ICHECK(lhs->dtype == rhs->dtype);
-    body = BufferStore(buffer, reduce->combiner.get()->operator()({lhs}, {rhs})[0], indices);
-    init = BufferStore(buffer, reduce->combiner->identity_element[0], indices);
+    const PrimExpr& reduce_body = reduce->combiner.get()->operator()({lhs}, {rhs})[0];
+    const PrimExpr& init_body = reduce->combiner->identity_element[0];
+    body = BufferStore(buffer, analyzer->Simplify(reduce_body), indices);
+    init = BufferStore(buffer, analyzer->Simplify(init_body), indices);
   } else {
     // Case 2. Data parallel compute
-    body = BufferStore(buffer, Substitute(info->transformer(expr_body), var_map), indices);
+    const PrimExpr& compute_body = Substitute(info->transformer(expr_body), var_map);
+    body = BufferStore(buffer, analyzer->Simplify(compute_body), indices);
   }
 
   // Step 6. Add script_parsing_detect_access attr for auto complete the whole IR.
@@ -156,7 +162,8 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::
                             /*annotations=*/std::move(annotations)));
 }
 
-Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* info) {
+Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* info,
+                             arith::Analyzer* analyzer) {
   // Step 1. Creating loop vars for block bindings.
   Array<IterVar> axes = compute_op->axis;
   axes.insert(axes.end(), compute_op->reduce_axis.begin(), compute_op->reduce_axis.end());
@@ -169,16 +176,18 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in
   for (int i = 0; i < compute_op->num_outputs(); ++i) {
     const te::Tensor& tensor = compute_op.output(i);
     PrimExpr expr_body = compute_op->body[i];
-    seq_stmt.push_back(
-        GenerateBlockFromTensor(compute_op, tensor, bindings, std::move(expr_body), info));
+    seq_stmt.push_back(GenerateBlockFromTensor(compute_op, tensor, bindings, std::move(expr_body),
+                                               info, analyzer));
   }
   Stmt body = SeqStmt::Flatten(seq_stmt);
 
   // Step 3. Generate loop nesting.
   for (size_t i = axes.size(); i > 0; --i) {
     const IterVar& axis = axes[i - 1];
+    PrimExpr dom_min = analyzer->Simplify(axis->dom->min);
+    PrimExpr dom_extent = analyzer->Simplify(axis->dom->extent);
     const Var& loop_var = Downcast<Var>(bindings[i - 1]);
-    body = For(loop_var, axis->dom->min, axis->dom->extent, ForKind::kSerial, body);
+    body = For(loop_var, dom_min, dom_extent, ForKind::kSerial, body);
   }
 
   return body;
@@ -256,6 +265,8 @@ PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list) {
   CreateFuncInfo info(arg_list);
   // Root body stmts.
   Array<Stmt> root_stmts;
+  // Analyzer
+  arith::Analyzer analyzer;
 
   // Step 3. Rewrite compute stages into blocks.
   for (const te::Operation& op : order) {
@@ -270,7 +281,8 @@ PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list) {
       info.tensor2buffers[tensor] = buffer;
     } else if (const auto* compute_op = op.as<te::ComputeOpNode>()) {
       // Case 2. ComputeOp (te.compute)
-      root_stmts.push_back(GenerateStmtFromCompute(GetRef<te::ComputeOp>(compute_op), &info));
+      root_stmts.push_back(
+          GenerateStmtFromCompute(GetRef<te::ComputeOp>(compute_op), &info, &analyzer));
     } else if (const auto extern_op = op.as<te::ExternOpNode>()) {
       // Case 3. ExternOp (te.extern)
       root_stmts.push_back(GenerateStmtFromExternOp(GetRef<te::ExternOp>(extern_op), &info));
diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py
index 6b5c26d..dc72a5f 100644
--- a/tests/python/unittest/test_te_create_primfunc.py
+++ b/tests/python/unittest/test_te_create_primfunc.py
@@ -17,7 +17,7 @@
 # pylint: disable=missing-function-docstring,missing-module-docstring
 import tvm
 from tvm.script import tir as T
-from tvm import te, tir
+from tvm import te, tir, topi
 import numpy as np
 import tvm.testing
 
@@ -135,7 +135,7 @@ def tir_conv2d(a: T.handle, w: T.handle, b: T.handle) -> None:
         with T.block("Apad"):
             nn, cc, yy, xx = T.axis.remap("SSSS", [n, c, y, x])
             Apad[nn, cc, yy, xx] = T.if_then_else(
-                yy >= 1 and yy - 1 < 14 and xx >= 1 and xx - 1 < 14,
+                1 <= yy and yy < 15 and 1 <= xx and xx < 15,
                 A[nn, cc, yy - 1, xx - 1],
                 0.0,
                 dtype="float32",
@@ -327,6 +327,17 @@ def test_data_dependent_access():
     tvm.testing.assert_allclose(a_np[b_np], c.numpy())
 
 
+def test_select_simplify():
+    placeholder = te.placeholder([1, 128, 10, 10, 4], dtype="float32")
+    tensor = topi.nn.adaptive_pool(placeholder, [1, 1], "avg", "NCHW4c")
+    result = te.create_prim_func([placeholder, tensor])
+    script_func = result.script()
+    # There should be no Select
+    assert script_func.find("Select") == -1
+    # There should be no undefined vars
+    assert script_func.find("Var") == -1
+
+
 if __name__ == "__main__":
     test_unique_name()
     test_matmul()
@@ -337,3 +348,4 @@ if __name__ == "__main__":
     test_arg_order()
     test_error_reporting()
     test_constant()
+    test_select_simplify()