You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lu...@apache.org on 2022/08/05 15:43:52 UTC

[tvm] branch main updated: [TIR] Add tir::builtin::assume (#12267)

This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8a0911c429 [TIR] Add tir::builtin::assume (#12267)
8a0911c429 is described below

commit 8a0911c429c9987710d8d79c81b1433b33615989
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Fri Aug 5 10:43:45 2022 -0500

    [TIR] Add tir::builtin::assume (#12267)
    
    * [RemoveAssume] Implemented T.assume in TVMScript, RemoveAssume
    
    * [UnitTest] RemoveAssume, initial functionality tests
---
 include/tvm/tir/builtin.h                          |  9 +++
 python/tvm/script/tir/intrin.py                    | 11 ++++
 python/tvm/tir/transform/transform.py              | 11 ++++
 src/printer/tvmscript_printer.cc                   |  8 +++
 src/tir/op/builtin.cc                              |  4 ++
 src/tir/transforms/remove_assume.cc                | 69 ++++++++++++++++++++++
 .../unittest/test_tir_transform_remove_assume.py   | 57 ++++++++++++++++++
 7 files changed, 169 insertions(+)

diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index 5fc42392c3..fc326c1873 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -720,6 +720,15 @@ TVM_DLL const Op& texture2d_load();
  */
 TVM_DLL const Op& mem_copy();
 
+/*!
+ * \brief Provide a true statement that can be used for simplifications
+ *
+ * Compile-time representation of known constraints about function
+ * inputs.  This assumption is removed when lowering, and does not
+ * occur in codegen.
+ */
+TVM_DLL const Op& assume();
+
 /*! \brief The kind of structure field info used in intrinsic */
 enum TVMStructFieldKind : int {
   // array head address
diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/tir/intrin.py
index 440e1ca77d..627b89086a 100644
--- a/python/tvm/script/tir/intrin.py
+++ b/python/tvm/script/tir/intrin.py
@@ -240,6 +240,17 @@ class StoreIntrin(Intrin):
         super().__init__(store, stmt=True)
 
 
+@register
+class AssumeIntrin(Intrin):
+    def __init__(self):
+        def assume(constraint, span):
+            return tvm.tir.Evaluate(
+                tvm.tir.call_intrin("bool", "tir.assume", constraint, span=span)
+            )
+
+        super().__init__(assume, stmt=True)
+
+
 @register
 def comm_reducer(lambda_io, identities, span):
     """Create a CommReducer from lambda inputs/outputs and the identities"""
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 6cc7b2e1f8..d63c65dfdd 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -253,6 +253,17 @@ def RemoveNoOp():
     return _ffi_api.RemoveNoOp()  # type: ignore
 
 
+def RemoveAssume():
+    """Remove all instances of builtin::assume
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.RemoveAssume()  # type: ignore
+
+
 def BF16Legalize():
     """Legalize bf16 typed Ops.
     Runs BF16Promote, BF16CastElimination and BF16TypeLowering
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index aaebc7409f..f2abf5c78d 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -1181,6 +1181,14 @@ Doc TVMScriptPrinter::VisitStmt_(const SeqStmtNode* op) {
 }
 
 Doc TVMScriptPrinter::VisitStmt_(const EvaluateNode* op) {
+  if (auto* call = op->value.as<CallNode>()) {
+    if (call->op.same_as(builtin::assume())) {
+      Doc doc;
+      doc << tir_prefix_ << ".assume(" << Print(call->args[0]) << ")";
+      return doc;
+    }
+  }
+
   Doc doc;
   doc << tir_prefix_ << ".evaluate(" << Print(op->value) << ")";
   return doc;
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 1871a3d7bf..860f98dd14 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -288,6 +288,10 @@ TIR_DEFINE_BUILTIN_FUNC(texture2d_load)
 TIR_DEFINE_BUILTIN_FUNC(mem_copy).set_attr<TCallEffectKind>("TCallEffectKind",
                                                             Integer(CallEffectKind::kOpaque));
 
+TIR_DEFINE_BUILTIN_FUNC(assume)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kEmbedInfo))
+    .set_num_inputs(1);
+
 }  // namespace builtin
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/transforms/remove_assume.cc b/src/tir/transforms/remove_assume.cc
new file mode 100644
index 0000000000..928bcf02bc
--- /dev/null
+++ b/src/tir/transforms/remove_assume.cc
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file remove_assume.cc
+ * \brief Remove calls to tir::builtin::assume
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+namespace tvm {
+namespace tir {
+
+// Replace each Evaluate statement whose value is a builtin::assume call with the no-op Evaluate(0)
+class AssumeRemover : public StmtExprMutator {
+ public:
+  using Parent = StmtExprMutator;
+
+  Stmt VisitStmt_(const EvaluateNode* op) final {
+    if (auto* call = op->value.as<CallNode>()) {
+      if (call->op.same_as(builtin::assume())) {
+        return Evaluate(0);
+      }
+    }
+    return StmtExprMutator::VisitStmt_(op);
+  }
+};
+
+namespace transform {
+Pass RemoveAssumeInternal() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    n->body = AssumeRemover()(std::move(n->body));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.RemoveAssumeInternal", {});
+}
+
+Pass RemoveAssume() {
+  return Sequential({RemoveAssumeInternal(), RemoveNoOp()}, "tir.RemoveAssume");
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.RemoveAssume").set_body_typed(RemoveAssume);
+
+}  // namespace transform
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/tests/python/unittest/test_tir_transform_remove_assume.py b/tests/python/unittest/test_tir_transform_remove_assume.py
new file mode 100644
index 0000000000..4223e40e3f
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_remove_assume.py
@@ -0,0 +1,57 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.script import tir as T
+from tvm import TVMError
+
+
+class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
+    @tvm.testing.fixture
+    def transform(self):
+        return tvm.tir.transform.RemoveAssume()
+
+
+class TestRemoveAssume(BaseBeforeAfter):
+    """Remove any instance of T.assume"""
+
+    def before(A: T.Buffer[1, "int32"]):
+        T.assume(A[0] == 5)
+        A[0] = 10
+
+    def expected(A: T.Buffer[1, "int32"]):
+        A[0] = 10
+
+
+class TestRemoveAssumeLoop(BaseBeforeAfter):
+    """Loops containing only T.assume should be removed"""
+
+    def before(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            T.assume(A[i] == 0)
+
+        for i in T.serial(16):
+            A[i] = 10
+
+    def expected(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            A[i] = 10
+
+
+if __name__ == "__main__":
+    tvm.testing.main()