You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink itself is not preserved in this plain-text version).
Posted to commits@tvm.apache.org by sy...@apache.org on 2022/07/25 06:50:40 UTC
[tvm] branch main updated: [TIR] Well-Formed Verifier (#12166)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4d84620d07 [TIR] Well-Formed Verifier (#12166)
4d84620d07 is described below
commit 4d84620d07630f84c60b9386484ab2d8c102a0fe
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Mon Jul 25 14:50:33 2022 +0800
[TIR] Well-Formed Verifier (#12166)
* tir_well_formed_verifier
* fix typo
* lint
* fix testcase
---
include/tvm/tir/analysis.h | 11 +-
python/tvm/tir/analysis/analysis.py | 20 +++
src/tir/analysis/verify_well_formed.cc | 137 +++++++++++++++++++++
src/tir/schedule/state.cc | 1 +
.../test_tir_analysis_verify_well_formed.py | 57 +++++++++
.../test_tir_schedule_set_axis_separator.py | 9 +-
.../python/unittest/test_tir_schedule_set_scope.py | 15 +--
7 files changed, 235 insertions(+), 15 deletions(-)
diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h
index 8306cb173e..d60a222ac2 100644
--- a/include/tvm/tir/analysis.h
+++ b/include/tvm/tir/analysis.h
@@ -19,7 +19,7 @@
/*!
* \file tvm/tir/analysis.h
- * \brief Analysis utilitie and passes for TIR.
+ * \brief Analysis utilities and passes for TIR.
*/
#ifndef TVM_TIR_ANALYSIS_H_
#define TVM_TIR_ANALYSIS_H_
@@ -220,6 +220,15 @@ TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func,
*/
TVM_DLL Map<Buffer, Optional<Stmt>> DetectBufferAccessLCA(const PrimFunc& func);
+/*!
+ * \brief Verify whether the given TIR function is well-formed. The verification includes:
+ * - Check that expressions inside a non-opaque block do not use loop vars defined outside it.
+ * \param func The PrimFunc to be verified.
+ * \param assert_mode Whether to raise an error (rather than return false) on failure.
+ * \return Whether it is a well-formed TIR function.
+ */
+TVM_DLL bool VerifyWellFormed(const PrimFunc& func, bool assert_mode = true);
+
// Pass variants of verification analysis
// directly throws RuntimeError when verification fails.
namespace transform {
diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py
index 7fc73ef4c4..13674daa24 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -300,3 +300,23 @@ def apply_prim_func_arg_and_result_memory_constraints(
return _ffi_api.ApplyPrimFuncArgAndResultMemoryConstraints( # type: ignore # pylint: disable=no-member
func, relay_func_type, arg_and_result_memory_scopes
)
+
+
+def verify_well_formed(func: PrimFunc, assert_mode: bool = True) -> bool:
+    """Verify whether the given TIR function is well-formed. The verification includes:
+    - Check that expressions inside a non-opaque block do not use loop vars defined outside it.
+
+    Parameters
+    ----------
+    func: tvm.tir.PrimFunc
+        The function to be verified.
+
+    assert_mode: bool
+        Whether to raise an error (rather than return False) when the check fails.
+
+    Returns
+    -------
+    result: bool
+        Whether it is a well-formed TIR function.
+    """
+    return _ffi_api.VerifyWellFormed(func, assert_mode)  # type: ignore # pylint: disable=no-member
diff --git a/src/tir/analysis/verify_well_formed.cc b/src/tir/analysis/verify_well_formed.cc
new file mode 100644
index 0000000000..878618fbe6
--- /dev/null
+++ b/src/tir/analysis/verify_well_formed.cc
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/analysis/verify_well_formed.cc
+ * \brief Check if schedulable tir is well-formed.
+ */
+
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "../ir/functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+/*! \brief Verify that no Expr inside a block with iter vars (non-opaque) contains:
+ * 1. loop vars defined outside the current block.
+ * 2. block vars of parent blocks. NOTE(review): only 1. is implemented below.
+ */
+class BlockVarAccessVerifier : public StmtExprVisitor {
+ public:
+  static bool Verify(const PrimFunc& func, bool assert_mode) {
+    BlockVarAccessVerifier verifier(assert_mode);
+    verifier(func->body);
+    return !verifier.has_error_;
+  }
+
+ private:
+  explicit BlockVarAccessVerifier(bool assert_mode) : assert_mode_(assert_mode) {}
+
+  void VisitStmt(const Stmt& stmt) final {  // stop descending after the first error
+    if (!has_error_) {
+      StmtExprVisitor::VisitStmt(stmt);
+    }
+  }
+
+  void VisitExpr(const PrimExpr& expr) final {  // stop descending after the first error
+    if (!has_error_) {
+      StmtExprVisitor::VisitExpr(expr);
+    }
+  }
+
+  void VisitExpr_(const VarNode* op) final {
+    auto it = loop_vars_.find(op);
+    if (it != loop_vars_.end() && it->second < cur_block_level_) {  // defined in an outer block
+      has_error_ = true;
+      if (assert_mode_) {
+        report_error(op);
+      }
+    }
+  }
+
+  void VisitStmt_(const ForNode* op) final {
+    ICHECK(loop_vars_.find(op->loop_var.get()) == loop_vars_.end());  // no loop var rebinding
+    loop_vars_[op->loop_var.get()] = cur_block_level_;  // remember defining level
+    StmtExprVisitor::VisitStmt_(op);
+    loop_vars_.erase(op->loop_var.get());
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    // Opaque blocks (no iter vars) do not open a new level, so they may use outer loop vars.
+    cur_block_level_ += !op->iter_vars.empty();
+
+    // Step 0. Block iter var domains are intentionally not visited.
+
+    // Step 1. Visit read/write regions
+    auto fvisit_buffer_region = [this](const BufferRegion& s) {
+      for (const auto& range : s->region) {
+        this->VisitExpr(range->min);
+        this->VisitExpr(range->extent);
+      }
+    };
+    VisitArray(op->reads, fvisit_buffer_region);
+    VisitArray(op->writes, fvisit_buffer_region);
+
+    // Step 2. Visit match buffers
+    VisitArray(op->match_buffers,
+               [fvisit_buffer_region](const MatchBufferRegion& match_buffer_region) {
+                 fvisit_buffer_region(match_buffer_region->source);
+               });
+
+    // Step 3. Visit init and body
+    if (op->init.defined()) {
+      this->VisitStmt(op->init.value());
+    }
+    this->VisitStmt(op->body);
+
+    cur_block_level_ -= !op->iter_vars.empty();  // restore the level on exit
+  }
+
+ private:
+  void report_error(const VarNode* var) {
+    // TODO(siyuan): use the error message from the parser.
+    LOG(FATAL) << "Well-formedness check failed: outside defined var " << var->name_hint
+               << " is used inside the current block.";
+  }
+
+  /*! \brief Map from each in-scope loop var to the block level at which it is defined. */
+  std::unordered_map<const VarNode*, size_t> loop_vars_;
+  /*! \brief Whether to raise an error via LOG(FATAL) on the first violation. */
+  bool assert_mode_;
+  /*! \brief Number of enclosing non-opaque blocks at the current traversal point. */
+  size_t cur_block_level_{0};
+  /*! \brief Whether a violation has been found; once set, further visits are skipped. */
+  bool has_error_{false};
+};
+
+bool VerifyWellFormed(const PrimFunc& func, bool assert_mode) {
+  if (!BlockVarAccessVerifier::Verify(func, assert_mode)) {  // loop-var scoping check
+    return false;
+  }
+  // TODO(Siyuan): add more well-formedness checks here.
+  return true;
+}
+
+TVM_REGISTER_GLOBAL("tir.analysis.VerifyWellFormed").set_body_typed(VerifyWellFormed);
+
+} // namespace tir
+} // namespace tvm
diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc
index 3c11d24853..dadabba485 100644
--- a/src/tir/schedule/state.cc
+++ b/src/tir/schedule/state.cc
@@ -413,6 +413,7 @@ class StateCreator : private StmtVisitor {
for (const auto& kv : n->mod->functions) {
const BaseFunc& base_func = kv.second;
if (const auto* func = base_func.as<PrimFuncNode>()) {
+ VerifyWellFormed(GetRef<PrimFunc>(func));
creator.VisitStmt(func->body);
BlockInfoCollector::Collect(self, func->body);
}
diff --git a/tests/python/unittest/test_tir_analysis_verify_well_formed.py b/tests/python/unittest/test_tir_analysis_verify_well_formed.py
new file mode 100644
index 0000000000..b3028a0148
--- /dev/null
+++ b/tests/python/unittest/test_tir_analysis_verify_well_formed.py
@@ -0,0 +1,57 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+from tvm.script import tir as T
+
+
+def test_pass_simple():
+    @T.prim_func
+    def element_wise(
+        A: T.Buffer[(128, 128), "float32"],
+        C: T.Buffer[(128, 128), "float32"],
+    ):
+        B = T.alloc_buffer((128, 128), "float32")
+        for i, j in T.grid(128, 128):
+            with T.block("B"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                B[vi, vj] = A[vi, vj] * 2.0
+        for i, j in T.grid(128, 128):
+            with T.block("C"):
+                # An opaque block (no iter vars) may use loop vars defined outside it
+                C[i, j] = B[i, j] * 2.0
+
+    assert tvm.tir.analysis.verify_well_formed(element_wise)
+
+
+def test_fail_use_out_loop_var():
+    @T.prim_func
+    def element_wise(
+        A: T.Buffer[(128, 128), "float32"],
+        B: T.Buffer[(128, 128), "float32"],
+    ):
+        for i, j in T.grid(128, 128):
+            with T.block("B"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                # `i` is a loop var defined outside this block, so it must not be used here
+                B[vi, vj] = A[i, vj] * 2.0
+
+    assert not tvm.tir.analysis.verify_well_formed(element_wise, assert_mode=False)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()  # presumably runs this file's tests via pytest
diff --git a/tests/python/unittest/test_tir_schedule_set_axis_separator.py b/tests/python/unittest/test_tir_schedule_set_axis_separator.py
index 102b3d1cd7..9502da1829 100644
--- a/tests/python/unittest/test_tir_schedule_set_axis_separator.py
+++ b/tests/python/unittest/test_tir_schedule_set_axis_separator.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-function-docstring,missing-module-docstring
-import sys
import pytest
import tvm
import tvm.testing
@@ -76,12 +75,12 @@ def element_wise_subregion_match(A: T.Buffer[(128, 128), "float32"], C: T.Buffer
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
- B_subregion0 = T.match_buffer(B[i, j], [], offset_factor=1)
+ B_subregion0 = T.match_buffer(B[vi, vj], [], offset_factor=1)
B_subregion0[()] = A[vi, vj] * 2.0
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
- B_subregion1 = T.match_buffer(B[i, j], [], offset_factor=1)
+ B_subregion1 = T.match_buffer(B[vi, vj], [], offset_factor=1)
C[vi, vj] = B_subregion1[()] + 1.0
@@ -92,12 +91,12 @@ def element_wise_subregion_match_set_axis_separator(A: T.Buffer[(128, 128), "flo
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
- B_subregion0 = T.match_buffer(B[i, j], [], dtype="float32", offset_factor=1, axis_separators=[1])
+ B_subregion0 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[1])
B_subregion0[()] = A[vi, vj] * T.float32(2)
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
- B_subregion1 = T.match_buffer(B[i, j], [], dtype="float32", offset_factor=1, axis_separators=[1])
+ B_subregion1 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[1])
C[vi, vj] = B_subregion1[()] + T.float32(1)
diff --git a/tests/python/unittest/test_tir_schedule_set_scope.py b/tests/python/unittest/test_tir_schedule_set_scope.py
index b2e8479462..adac81e629 100644
--- a/tests/python/unittest/test_tir_schedule_set_scope.py
+++ b/tests/python/unittest/test_tir_schedule_set_scope.py
@@ -17,6 +17,7 @@
# pylint: disable=missing-function-docstring,missing-module-docstring
import pytest
import tvm
+import tvm.testing
from tvm import tir
from tvm.script import tir as T
from tvm.tir.schedule.testing import verify_trace_roundtrip
@@ -59,12 +60,12 @@ def element_wise_subregion_match(A: T.Buffer[(128, 128), "float32"], C: T.Buffer
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
- B_subregion0 = T.match_buffer(B[i, j], [], offset_factor=1)
+ B_subregion0 = T.match_buffer(B[vi, vj], [], offset_factor=1)
B_subregion0[()] = A[vi, vj] * 2.0
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
- B_subregion1 = T.match_buffer(B[i, j], [], offset_factor=1)
+ B_subregion1 = T.match_buffer(B[vi, vj], [], offset_factor=1)
C[vi, vj] = B_subregion1[()] + 1.0
@@ -75,12 +76,12 @@ def element_wise_subregion_match_set_scope(A: T.Buffer[(128, 128), "float32"], C
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
- B_subregion0_shared = T.match_buffer(B_shared[i, j], [], dtype="float32", scope="shared", offset_factor=1)
+ B_subregion0_shared = T.match_buffer(B_shared[vi, vj], [], dtype="float32", scope="shared", offset_factor=1)
B_subregion0_shared[()] = A[vi, vj] * T.float32(2)
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
- B_subregion1_shared = T.match_buffer(B_shared[i, j], [], dtype="float32", scope="shared", offset_factor=1)
+ B_subregion1_shared = T.match_buffer(B_shared[vi, vj], [], dtype="float32", scope="shared", offset_factor=1)
C[vi, vj] = B_subregion1_shared[()] + T.float32(1)
@@ -128,8 +129,4 @@ def test_set_scope_subregion():
if __name__ == "__main__":
- test_set_scope()
- test_set_scope_fail_on_output_buffer()
- test_set_scope_fail_on_index_out_of_bound()
- test_set_scope_fail_on_invalid_scope()
- test_set_scope_subregion()
+ tvm.testing.main()