You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lu...@apache.org on 2022/08/05 15:43:52 UTC
[tvm] branch main updated: [TIR] Add tir::builtin::assume (#12267)
This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8a0911c429 [TIR] Add tir::builtin::assume (#12267)
8a0911c429 is described below
commit 8a0911c429c9987710d8d79c81b1433b33615989
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Fri Aug 5 10:43:45 2022 -0500
[TIR] Add tir::builtin::assume (#12267)
* [RemoveAssume] Implemented T.assume in TVMScript, RemoveAssume
* [UnitTest] RemoveAssume, initial functionality tests
---
include/tvm/tir/builtin.h | 9 +++
python/tvm/script/tir/intrin.py | 11 ++++
python/tvm/tir/transform/transform.py | 11 ++++
src/printer/tvmscript_printer.cc | 8 +++
src/tir/op/builtin.cc | 4 ++
src/tir/transforms/remove_assume.cc | 69 ++++++++++++++++++++++
.../unittest/test_tir_transform_remove_assume.py | 57 ++++++++++++++++++
7 files changed, 169 insertions(+)
diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index 5fc42392c3..fc326c1873 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -720,6 +720,15 @@ TVM_DLL const Op& texture2d_load();
*/
TVM_DLL const Op& mem_copy();
+/*!
+ * \brief Provide a true statement that can be used for simplifications
+ *
+ * Takes a single argument: the constraint that is assumed to hold,
+ * e.g. `T.assume(A[0] == 5)` in TVMScript. Compile-time representation
+ * of known constraints, such as constraints on function inputs. It is
+ * intended to be stripped before codegen (e.g. by
+ * tir::transform::RemoveAssume) and does not occur in codegen.
+ */
+TVM_DLL const Op& assume();
+
/*! \brief The kind of structure field info used in intrinsic */
enum TVMStructFieldKind : int {
// array head address
diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/tir/intrin.py
index 440e1ca77d..627b89086a 100644
--- a/python/tvm/script/tir/intrin.py
+++ b/python/tvm/script/tir/intrin.py
@@ -240,6 +240,17 @@ class StoreIntrin(Intrin):
super().__init__(store, stmt=True)
+@register
+class AssumeIntrin(Intrin):
+    """TVMScript ``T.assume(constraint)``, parsed into ``Evaluate(tir.assume(constraint))``."""
+
+    def __init__(self):
+        def assume(constraint, span):
+            assume_call = tvm.tir.call_intrin("bool", "tir.assume", constraint, span=span)
+            return tvm.tir.Evaluate(assume_call)
+
+        super().__init__(assume, stmt=True)
+
+
@register
def comm_reducer(lambda_io, identities, span):
"""Create a CommReducer from lambda inputs/outputs and the identities"""
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 6cc7b2e1f8..d63c65dfdd 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -253,6 +253,17 @@ def RemoveNoOp():
return _ffi_api.RemoveNoOp() # type: ignore
+def RemoveAssume():
+    """Remove all instances of builtin::assume, then run RemoveNoOp to
+    clean up the no-op statements left behind.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    fpass = _ffi_api.RemoveAssume()  # type: ignore
+    return fpass
+
+
def BF16Legalize():
"""Legalize bf16 typed Ops.
Runs BF16Promote, BF16CastElimination and BF16TypeLowering
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index aaebc7409f..f2abf5c78d 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -1181,6 +1181,14 @@ Doc TVMScriptPrinter::VisitStmt_(const SeqStmtNode* op) {
}
Doc TVMScriptPrinter::VisitStmt_(const EvaluateNode* op) {
+ if (auto* call = op->value.as<CallNode>()) {
+ if (call->op.same_as(builtin::assume())) {
+ Doc doc;
+ doc << tir_prefix_ << ".assume(" << Print(call->args[0]) << ")";
+ return doc;
+ }
+ }
+
Doc doc;
doc << tir_prefix_ << ".evaluate(" << Print(op->value) << ")";
return doc;
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 1871a3d7bf..860f98dd14 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -288,6 +288,10 @@ TIR_DEFINE_BUILTIN_FUNC(texture2d_load)
TIR_DEFINE_BUILTIN_FUNC(mem_copy).set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
+// assume(constraint): compile-time-only fact with exactly one argument. It is
+// meant to be removed (e.g. by tir.transform.RemoveAssume) before codegen.
+TIR_DEFINE_BUILTIN_FUNC(assume)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kEmbedInfo))
+ .set_num_inputs(1);
+
} // namespace builtin
} // namespace tir
} // namespace tvm
diff --git a/src/tir/transforms/remove_assume.cc b/src/tir/transforms/remove_assume.cc
new file mode 100644
index 0000000000..928bcf02bc
--- /dev/null
+++ b/src/tir/transforms/remove_assume.cc
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file remove_assume.cc
+ * \brief Remove all instances of tir::builtin::assume
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+namespace tvm {
+namespace tir {
+
+// Replaces every `Evaluate(builtin::assume(...))` statement with the no-op
+// `Evaluate(0)`. RemoveAssume runs RemoveNoOp afterwards to delete those no-ops.
+class AssumeRemover : public StmtExprMutator {
+ public:
+  using Parent = StmtExprMutator;
+
+  Stmt VisitStmt_(const EvaluateNode* op) final {
+    if (const auto* call = op->value.as<CallNode>()) {
+      if (call->op.same_as(builtin::assume())) {
+        // The constraint is compile-time-only information: drop the whole
+        // statement without visiting or evaluating its argument.
+        return Evaluate(0);
+      }
+    }
+    // Not an assume: fall back to the default traversal through the alias.
+    return Parent::VisitStmt_(op);
+  }
+};
+
+namespace transform {
+/*! \brief PrimFunc-level pass that rewrites each assume statement into Evaluate(0). */
+Pass RemoveAssumeInternal() {
+  auto strip_assumes = [](PrimFunc func, IRModule /*mod*/, PassContext /*ctx*/) {
+    auto* func_node = func.CopyOnWrite();
+    func_node->body = AssumeRemover()(std::move(func_node->body));
+    return func;
+  };
+  constexpr int kOptLevel = 0;
+  return CreatePrimFuncPass(strip_assumes, kOptLevel, "tir.RemoveAssumeInternal", {});
+}
+
+/*!
+ * \brief Remove every builtin::assume, then run RemoveNoOp so the Evaluate(0)
+ *        placeholders (and any loops left empty by them) are cleaned up.
+ */
+Pass RemoveAssume() {
+  Pass strip_pass = RemoveAssumeInternal();
+  Pass cleanup_pass = RemoveNoOp();
+  return Sequential({strip_pass, cleanup_pass}, "tir.RemoveAssume");
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.RemoveAssume").set_body_typed(RemoveAssume);
+
+}  // namespace transform
+
+} // namespace tir
+} // namespace tvm
diff --git a/tests/python/unittest/test_tir_transform_remove_assume.py b/tests/python/unittest/test_tir_transform_remove_assume.py
new file mode 100644
index 0000000000..4223e40e3f
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_remove_assume.py
@@ -0,0 +1,57 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.script import tir as T
+from tvm import TVMError
+
+
+# Shared base for the cases below: each case's `before` function is run through
+# tvm.tir.transform.RemoveAssume() and (presumably, via the
+# tvm.testing.CompareBeforeAfter harness) compared against `expected`.
+class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
+ @tvm.testing.fixture
+ def transform(self):
+ return tvm.tir.transform.RemoveAssume()
+
+
+# A standalone T.assume statement is dropped; the following store A[0] = 10
+# is kept unchanged.
+class TestRemoveAssume(BaseBeforeAfter):
+ """Remove any instance of T.assume"""
+
+ def before(A: T.Buffer[1, "int32"]):
+ T.assume(A[0] == 5)
+ A[0] = 10
+
+ def expected(A: T.Buffer[1, "int32"]):
+ A[0] = 10
+
+
+# The first loop's body is only T.assume. Once the assume is removed, that loop
+# is an empty no-op and is expected to disappear (RemoveAssume also runs
+# RemoveNoOp), leaving only the loop that stores to A.
+class TestRemoveAssumeLoop(BaseBeforeAfter):
+ """Loops containing only T.assume should be removed"""
+
+ def before(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ T.assume(A[i] == 0)
+
+ for i in T.serial(16):
+ A[i] = 10
+
+ def expected(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ A[i] = 10
+
+
+# Allow running this test file directly as a script.
+if __name__ == "__main__":
+ tvm.testing.main()