You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/04/08 12:50:18 UTC

[tvm] branch main updated: introduce pass lower_init_block (#7806)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 24e62ca  introduce pass lower_init_block (#7806)
24e62ca is described below

commit 24e62ca6cb428de9e2e412a38dbe901f99d9dcce
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Thu Apr 8 20:49:49 2021 +0800

    introduce pass lower_init_block (#7806)
    
    
    Co-authored-by: Junru Shao <ju...@gmail.com>
    Co-authored-by: Ruihang Lai <la...@qq.com>
---
 include/tvm/tir/transform.h                        |  6 ++
 python/tvm/tir/transform/transform.py              | 11 +++
 src/tir/transforms/lower_init_block.cc             | 85 ++++++++++++++++++++++
 .../test_tir_transform_lower_init_block.py         | 53 ++++++++++++++
 4 files changed, 155 insertions(+)

diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index f31e515..2397caf 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -346,6 +346,12 @@ TVM_DLL Pass PointerValueTypeRewrite();
  */
 TVM_DLL Pass HoistIfThenElse();
 
+/*!
+ * \brief Lower block init stmts into the block body, guarded on its reduce iter vars.
+ * \return The pass.
+ */
+TVM_DLL Pass LowerInitBlock();
+
 }  // namespace transform
 }  // namespace tir
 }  // namespace tvm
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 40dd170..8bd63bd 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -536,3 +536,14 @@ def HoistIfThenElse(variant=None):
         return _ffi_api.HoistIfThenElseBasic()
     elif variant is None:
         return _ffi_api.HoistIfThenElse()
+
+
+def LowerInitBlock():
+    """Lower block init stmt into IfThenElse stmts
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.LowerInitBlock()
diff --git a/src/tir/transforms/lower_init_block.cc b/src/tir/transforms/lower_init_block.cc
new file mode 100644
index 0000000..c8aca51
--- /dev/null
+++ b/src/tir/transforms/lower_init_block.cc
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file lower_init_block.cc
+ * \brief Lower block init stmts into (conditionally guarded) stmts prepended to the block body.
+ */
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+namespace tvm {
+namespace tir {
+
+// Rewrites every Block that carries an `init` statement: the init is
+// prepended to the block body (guarded by DoLowering) and the block's
+// `init` field is cleared.
+class InitBlockLower : public StmtMutator {
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    // Blocks without an init statement only need their children visited.
+    if (!op->init.defined()) return StmtMutator::VisitStmt_(op);
+    // Note: the init statement itself is not recursively mutated here.
+    Stmt lowered_init = DoLowering(op->init.value(), op->iter_vars);
+    Stmt new_body = VisitStmt(op->body);
+    auto node = CopyOnWrite(op);
+    node->init = NullOpt;
+    node->body = SeqStmt::Flatten(lowered_init, new_body);
+    return Block(node);
+  }
+
+  // Wraps `init` in an IfThenElse whose condition is the conjunction of
+  // `iv == iv->dom->min` over every kCommReduce iter var, in declaration
+  // order. Returns `init` unchanged when there are no reduction iter vars.
+  static Stmt DoLowering(const Stmt& init, const Array<IterVar>& iter_vars) {
+    PrimExpr guard;  // undefined until the first reduction iter var is seen
+    for (const IterVar& iv : iter_vars) {
+      if (iv->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      PrimExpr at_start = equal(iv->var, iv->dom->min);
+      // Left fold: ((c0 && c1) && c2) ...
+      guard = guard.defined() ? logical_and(guard, at_start) : at_start;
+    }
+    return guard.defined() ? Stmt(IfThenElse(guard, init)) : init;
+  }
+};
+
+PrimFunc LowerInitBlock(PrimFunc func) {  // lowers every block init inside `func`
+  auto* fnode = func.CopyOnWrite();  // clone the node first if it is shared
+  fnode->body = InitBlockLower()(std::move(fnode->body));
+  return func;
+}
+
+namespace transform {
+Pass LowerInitBlock() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    return LowerInitBlock(std::move(f));
+  };
+  // PassContext's disabled_pass/required_pass lists match on this name, so keep it
+  // consistent with the public API / FFI name (tir.transform.LowerInitBlock).
+  return CreatePrimFuncPass(pass_func, 0, "tir.LowerInitBlock", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.LowerInitBlock").set_body_typed(LowerInitBlock);
+}  // namespace transform
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/tests/python/unittest/test_tir_transform_lower_init_block.py b/tests/python/unittest/test_tir_transform_lower_init_block.py
new file mode 100644
index 0000000..3fb8331
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_lower_init_block.py
@@ -0,0 +1,53 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import tir
+from tvm.script import ty
+
+
+@tvm.script.tir
+class WithInit:  # input module: reduction block carrying a tir.init() stmt
+    def main(a: ty.handle, b: ty.handle) -> None:
+        A = tir.match_buffer(a, [64, 64, 64])
+        B = tir.match_buffer(b, [64])
+        # i: spatial axis; j, k: reduction axes (j starts at 0, k starts at 32)
+        with tir.block([64, tir.reduce_axis(0, 64), tir.reduce_axis(32, 64)]) as [i, j, k]:
+            with tir.init():
+                B[i] = tir.float32(0)
+            B[i] += A[i, j, k]
+
+
+@tvm.script.tir
+class WithBranch:  # expected output: the init stmt lowered into a guarded branch
+    def main(a: ty.handle, b: ty.handle) -> None:
+        A = tir.match_buffer(a, [64, 64, 64])
+        B = tir.match_buffer(b, [64])
+        # the init runs only when every reduction axis is at its domain minimum
+        with tir.block([64, tir.reduce_axis(0, 64), tir.reduce_axis(32, 64)]) as [i, j, k]:
+            if (j == 0) and (k == 32):
+                B[i] = tir.float32(0)
+            B[i] += A[i, j, k]
+
+
+def test_lower_reduction():
+    origin_mod = WithInit()
+    mod = tvm.tir.transform.LowerInitBlock()(origin_mod)
+    tvm.ir.assert_structural_equal(mod, WithBranch(), True)
+
+
+if __name__ == "__main__":  # allow running this test module directly
+    test_lower_reduction()