You are viewing a plain-text version of this content. The canonical link for it (a hyperlink in the original page) was not preserved in this copy.
Posted to commits@tvm.apache.org by wu...@apache.org on 2021/12/09 18:05:52 UTC
[tvm] branch main updated: [FIX] Simplify during create prim func (#9691)
This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6f623a9 [FIX] Simplify during create prim func (#9691)
6f623a9 is described below
commit 6f623a96f9d2f28ef98762ca6167988bb480f085
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Fri Dec 10 02:05:27 2021 +0800
[FIX] Simplify during create prim func (#9691)
---
src/te/operation/create_primfunc.cc | 34 ++++++++++++++++--------
tests/python/unittest/test_te_create_primfunc.py | 16 +++++++++--
2 files changed, 37 insertions(+), 13 deletions(-)
diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index d90681a..e43013c 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -17,6 +17,7 @@
* under the License.
*/
+#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
@@ -83,19 +84,21 @@ struct CreateFuncInfo {
BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::Tensor& tensor,
Array<PrimExpr> bindings, PrimExpr expr_body,
- CreateFuncInfo* info) {
+ CreateFuncInfo* info, arith::Analyzer* analyzer) {
// Step 1. Push_back data_par axis and reduce_axis into block_vars.
Array<IterVar> iter_vars;
std::unordered_map<const VarNode*, PrimExpr> var_map;
iter_vars.reserve(compute_op->axis.size() + compute_op->reduce_axis.size());
- auto f_push_block_vars = [&iter_vars, &var_map](const Array<IterVar>& iters) {
+ auto f_push_block_vars = [&iter_vars, &var_map, &analyzer](const Array<IterVar>& iters) {
for (IterVar iter_var : iters) {
// Create new var
Var new_var(iter_var->var->name_hint, iter_var->var->dtype);
var_map[iter_var->var.get()] = new_var;
IterVarNode* iter_var_node = iter_var.CopyOnWrite();
- iter_var_node->dom = Range::FromMinExtent(iter_var->dom->min, iter_var->dom->extent);
+ const PrimExpr& dom_min = analyzer->Simplify(iter_var->dom->min);
+ const PrimExpr& dom_extent = analyzer->Simplify(iter_var->dom->extent);
+ iter_var_node->dom = Range::FromMinExtent(dom_min, dom_extent);
iter_var_node->var = new_var;
iter_vars.push_back(iter_var);
}
@@ -130,11 +133,14 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::
const PrimExpr& lhs = BufferLoad(buffer, indices);
const PrimExpr& rhs = Substitute(info->transformer(reduce->source[0]), var_map);
ICHECK(lhs->dtype == rhs->dtype);
- body = BufferStore(buffer, reduce->combiner.get()->operator()({lhs}, {rhs})[0], indices);
- init = BufferStore(buffer, reduce->combiner->identity_element[0], indices);
+ const PrimExpr& reduce_body = reduce->combiner.get()->operator()({lhs}, {rhs})[0];
+ const PrimExpr& init_body = reduce->combiner->identity_element[0];
+ body = BufferStore(buffer, analyzer->Simplify(reduce_body), indices);
+ init = BufferStore(buffer, analyzer->Simplify(init_body), indices);
} else {
// Case 2. Data parallel compute
- body = BufferStore(buffer, Substitute(info->transformer(expr_body), var_map), indices);
+ const PrimExpr& compute_body = Substitute(info->transformer(expr_body), var_map);
+ body = BufferStore(buffer, analyzer->Simplify(compute_body), indices);
}
// Step 6. Add script_parsing_detect_access attr for auto complete the whole IR.
@@ -156,7 +162,8 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te::
/*annotations=*/std::move(annotations)));
}
-Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* info) {
+Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* info,
+ arith::Analyzer* analyzer) {
// Step 1. Creating loop vars for block bindings.
Array<IterVar> axes = compute_op->axis;
axes.insert(axes.end(), compute_op->reduce_axis.begin(), compute_op->reduce_axis.end());
@@ -169,16 +176,18 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in
for (int i = 0; i < compute_op->num_outputs(); ++i) {
const te::Tensor& tensor = compute_op.output(i);
PrimExpr expr_body = compute_op->body[i];
- seq_stmt.push_back(
- GenerateBlockFromTensor(compute_op, tensor, bindings, std::move(expr_body), info));
+ seq_stmt.push_back(GenerateBlockFromTensor(compute_op, tensor, bindings, std::move(expr_body),
+ info, analyzer));
}
Stmt body = SeqStmt::Flatten(seq_stmt);
// Step 3. Generate loop nesting.
for (size_t i = axes.size(); i > 0; --i) {
const IterVar& axis = axes[i - 1];
+ PrimExpr dom_min = analyzer->Simplify(axis->dom->min);
+ PrimExpr dom_extent = analyzer->Simplify(axis->dom->extent);
const Var& loop_var = Downcast<Var>(bindings[i - 1]);
- body = For(loop_var, axis->dom->min, axis->dom->extent, ForKind::kSerial, body);
+ body = For(loop_var, dom_min, dom_extent, ForKind::kSerial, body);
}
return body;
@@ -256,6 +265,8 @@ PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list) {
CreateFuncInfo info(arg_list);
// Root body stmts.
Array<Stmt> root_stmts;
+ // Analyzer
+ arith::Analyzer analyzer;
// Step 3. Rewrite compute stages into blocks.
for (const te::Operation& op : order) {
@@ -270,7 +281,8 @@ PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list) {
info.tensor2buffers[tensor] = buffer;
} else if (const auto* compute_op = op.as<te::ComputeOpNode>()) {
// Case 2. ComputeOp (te.compute)
- root_stmts.push_back(GenerateStmtFromCompute(GetRef<te::ComputeOp>(compute_op), &info));
+ root_stmts.push_back(
+ GenerateStmtFromCompute(GetRef<te::ComputeOp>(compute_op), &info, &analyzer));
} else if (const auto extern_op = op.as<te::ExternOpNode>()) {
// Case 3. ExternOp (te.extern)
root_stmts.push_back(GenerateStmtFromExternOp(GetRef<te::ExternOp>(extern_op), &info));
diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py
index 6b5c26d..dc72a5f 100644
--- a/tests/python/unittest/test_te_create_primfunc.py
+++ b/tests/python/unittest/test_te_create_primfunc.py
@@ -17,7 +17,7 @@
# pylint: disable=missing-function-docstring,missing-module-docstring
import tvm
from tvm.script import tir as T
-from tvm import te, tir
+from tvm import te, tir, topi
import numpy as np
import tvm.testing
@@ -135,7 +135,7 @@ def tir_conv2d(a: T.handle, w: T.handle, b: T.handle) -> None:
with T.block("Apad"):
nn, cc, yy, xx = T.axis.remap("SSSS", [n, c, y, x])
Apad[nn, cc, yy, xx] = T.if_then_else(
- yy >= 1 and yy - 1 < 14 and xx >= 1 and xx - 1 < 14,
+ 1 <= yy and yy < 15 and 1 <= xx and xx < 15,
A[nn, cc, yy - 1, xx - 1],
0.0,
dtype="float32",
@@ -327,6 +327,17 @@ def test_data_dependent_access():
tvm.testing.assert_allclose(a_np[b_np], c.numpy())
+def test_select_simplify():
+ placeholder = te.placeholder([1, 128, 10, 10, 4], dtype="float32")
+ tensor = topi.nn.adaptive_pool(placeholder, [1, 1], "avg", "NCHW4c")
+ result = te.create_prim_func([placeholder, tensor])
+ script_func = result.script()
+ # There should be no Select
+ assert script_func.find("Select") == -1
+ # There should be no undefined vars
+ assert script_func.find("Var") == -1
+
+
if __name__ == "__main__":
test_unique_name()
test_matmul()
@@ -337,3 +348,4 @@ if __name__ == "__main__":
test_arg_order()
test_error_reporting()
test_constant()
+ test_select_simplify()