You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/04/08 12:50:18 UTC
[tvm] branch main updated: introduce pass lower_init_block (#7806)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 24e62ca introduce pass lower_init_block (#7806)
24e62ca is described below
commit 24e62ca6cb428de9e2e412a38dbe901f99d9dcce
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Thu Apr 8 20:49:49 2021 +0800
introduce pass lower_init_block (#7806)
Co-authored-by: Junru Shao <ju...@gmail.com>
Co-authored-by: Ruihang Lai <la...@qq.com>
---
include/tvm/tir/transform.h | 6 ++
python/tvm/tir/transform/transform.py | 11 +++
src/tir/transforms/lower_init_block.cc | 85 ++++++++++++++++++++++
.../test_tir_transform_lower_init_block.py | 53 ++++++++++++++
4 files changed, 155 insertions(+)
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index f31e515..2397caf 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -346,6 +346,12 @@ TVM_DLL Pass PointerValueTypeRewrite();
*/
TVM_DLL Pass HoistIfThenElse();
+/*!
+ * \brief Lower block init stmts into IfThenElse stmts.
+ *
+ * For every block that has an init stmt, the init is removed from the block
+ * and prepended to the block body. If the block has reduction (kCommReduce)
+ * iterators, the prepended init is wrapped in an IfThenElse whose condition
+ * requires every reduction iterator to equal the minimum of its domain.
+ * If the block has no reduction iterators, the init is prepended as-is.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass LowerInitBlock();
+
} // namespace transform
} // namespace tir
} // namespace tvm
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 40dd170..8bd63bd 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -536,3 +536,14 @@ def HoistIfThenElse(variant=None):
return _ffi_api.HoistIfThenElseBasic()
elif variant is None:
return _ffi_api.HoistIfThenElse()
+
+
+def LowerInitBlock():
+ """Lower block init stmt into IfThenElse stmts
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.LowerInitBlock()
diff --git a/src/tir/transforms/lower_init_block.cc b/src/tir/transforms/lower_init_block.cc
new file mode 100644
index 0000000..c8aca51
--- /dev/null
+++ b/src/tir/transforms/lower_init_block.cc
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file lower_init_block.cc
+ * \brief Lower block init stmts into IfThenElse branch stmts.
+ */
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+namespace tvm {
+namespace tir {
+
+class InitBlockLower : public StmtMutator {
+ private:
+  /*!
+   * \brief Move a block's init stmt into its body as a (possibly guarded) prologue.
+   * \param block The block being visited.
+   * \return If the block has no init, the result of the default StmtMutator
+   *         traversal. Otherwise a copy of the block whose init is cleared and
+   *         whose body is the lowered init followed by the mutated original body.
+   */
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block->init.defined()) {
+      // Lower the init first, then recurse into the body so nested blocks are lowered too.
+      Stmt guarded_init = DoLowering(block->init.value(), block->iter_vars);
+      Stmt mutated_body = VisitStmt(block->body);
+      auto node = CopyOnWrite(block);
+      node->init = NullOpt;
+      node->body = SeqStmt::Flatten(guarded_init, mutated_body);
+      return Block(node);
+    }
+    // No init stmt: fall back to the default traversal.
+    return StmtMutator::VisitStmt_(block);
+  }
+
+  /*!
+   * \brief Guard an init stmt so it only runs at the start of the reduction domain.
+   * \param init The init stmt of a block.
+   * \param iter_vars The block's iteration variables.
+   * \return `init` unchanged when no iterator is kCommReduce; otherwise
+   *         IfThenElse((r0 == min0) && (r1 == min1) && ..., init), with the
+   *         conjunction built left to right in iter_vars order.
+   */
+  static Stmt DoLowering(const Stmt& init, const Array<IterVar>& iter_vars) {
+    PrimExpr cond;  // stays undefined until the first reduction iterator is seen
+    for (const IterVar& iter : iter_vars) {
+      if (iter->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      PrimExpr at_min = equal(iter->var, iter->dom->min);
+      cond = cond.defined() ? logical_and(cond, at_min) : at_min;
+    }
+    if (!cond.defined()) {
+      // Purely spatial block: the init runs unconditionally.
+      return init;
+    }
+    return IfThenElse(cond, init);
+  }
+};
+
+/*!
+ * \brief Lower every block init stmt in a PrimFunc.
+ * \param func The function to transform.
+ * \return The function whose body has had InitBlockLower applied.
+ */
+PrimFunc LowerInitBlock(PrimFunc func) {
+  PrimFuncNode* node = func.CopyOnWrite();
+  InitBlockLower lowerer;
+  node->body = lowerer(std::move(node->body));
+  return func;
+}
+
+namespace transform {
+
+/*!
+ * \brief Create the module-level pass that lowers block init stmts in every PrimFunc.
+ * \return The pass, named "tir.LowerInitBlock".
+ */
+Pass LowerInitBlock() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    // Qualified on purpose: unqualified lookup would first find this zero-argument
+    // transform::LowerInitBlock and only reach tir::LowerInitBlock(PrimFunc) via ADL.
+    return tir::LowerInitBlock(std::move(f));
+  };
+  // The pass name matches the pass itself and the "tir.transform.LowerInitBlock"
+  // global below (previously "tir.LowerReduction"), so name-based PassContext
+  // options such as disabled_pass can select it.
+  return CreatePrimFuncPass(pass_func, 0, "tir.LowerInitBlock", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.LowerInitBlock").set_body_typed(LowerInitBlock);
+
+}  // namespace transform
+
+} // namespace tir
+} // namespace tvm
diff --git a/tests/python/unittest/test_tir_transform_lower_init_block.py b/tests/python/unittest/test_tir_transform_lower_init_block.py
new file mode 100644
index 0000000..3fb8331
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_lower_init_block.py
@@ -0,0 +1,53 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import tir
+from tvm.script import ty
+
+
+# Input module: B[i] accumulates A[i, j, k] over the reduction axes j in [0, 64)
+# and k in [32, 64). B[i] is zero-initialized through a block init stmt
+# (tir.init()), which is what LowerInitBlock rewrites.
+@tvm.script.tir
+class WithInit:
+ def main(a: ty.handle, b: ty.handle) -> None:
+ A = tir.match_buffer(a, [64, 64, 64])
+ B = tir.match_buffer(b, [64])
+
+ with tir.block([64, tir.reduce_axis(0, 64), tir.reduce_axis(32, 64)]) as [i, j, k]:
+ with tir.init():
+ B[i] = tir.float32(0)
+ B[i] += A[i, j, k]
+
+
+# Expected module after LowerInitBlock is applied to WithInit: the init stmt is
+# prepended to the block body, guarded so it only runs when every reduction
+# iterator equals its domain minimum (j == 0 and k == 32).
+@tvm.script.tir
+class WithBranch:
+ def main(a: ty.handle, b: ty.handle) -> None:
+ A = tir.match_buffer(a, [64, 64, 64])
+ B = tir.match_buffer(b, [64])
+
+ with tir.block([64, tir.reduce_axis(0, 64), tir.reduce_axis(32, 64)]) as [i, j, k]:
+ if (j == 0) and (k == 32):
+ B[i] = tir.float32(0)
+ B[i] += A[i, j, k]
+
+
+def test_lower_reduction():
+ origin_mod = WithInit()
+ mod = tvm.tir.transform.LowerInitBlock()(origin_mod)
+ tvm.ir.assert_structural_equal(mod, WithBranch(), True)
+
+
+if __name__ == "__main__":
+ test_lower_reduction()