You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2022/07/25 06:50:40 UTC

[tvm] branch main updated: [TIR] Well-Formed Verifier (#12166)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4d84620d07 [TIR] Well-Formed Verifier (#12166)
4d84620d07 is described below

commit 4d84620d07630f84c60b9386484ab2d8c102a0fe
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Mon Jul 25 14:50:33 2022 +0800

    [TIR] Well-Formed Verifier (#12166)
    
    * tir_well_formed_verifier
    
    * fix typo
    
    * lint
    
    * fix testcase
---
 include/tvm/tir/analysis.h                         |  11 +-
 python/tvm/tir/analysis/analysis.py                |  20 +++
 src/tir/analysis/verify_well_formed.cc             | 137 +++++++++++++++++++++
 src/tir/schedule/state.cc                          |   1 +
 .../test_tir_analysis_verify_well_formed.py        |  57 +++++++++
 .../test_tir_schedule_set_axis_separator.py        |   9 +-
 .../python/unittest/test_tir_schedule_set_scope.py |  15 +--
 7 files changed, 235 insertions(+), 15 deletions(-)

diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h
index 8306cb173e..d60a222ac2 100644
--- a/include/tvm/tir/analysis.h
+++ b/include/tvm/tir/analysis.h
@@ -19,7 +19,7 @@
 
 /*!
  * \file tvm/tir/analysis.h
- * \brief Analysis utilitie and passes for TIR.
+ * \brief Analysis utilities and passes for TIR.
  */
 #ifndef TVM_TIR_ANALYSIS_H_
 #define TVM_TIR_ANALYSIS_H_
@@ -220,6 +220,15 @@ TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func,
  */
 TVM_DLL Map<Buffer, Optional<Stmt>> DetectBufferAccessLCA(const PrimFunc& func);
 
+/*!
+ * \brief Verify if the given TIR is well-formed. The verification includes:
+ *        - Check that expressions do not contain vars that are defined outside the block.
+ * \param func The PrimFunc to be verified.
+ * \param assert_mode Whether to raise an error when the function is not well-formed.
+ * \return Whether it is a well-formed TIR function.
+ */
+TVM_DLL bool VerifyWellFormed(const PrimFunc& func, bool assert_mode = true);
+
 // Pass variants of verification analysis
 // directly throws RuntimeError when verification fails.
 namespace transform {
diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py
index 7fc73ef4c4..13674daa24 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -300,3 +300,23 @@ def apply_prim_func_arg_and_result_memory_constraints(
     return _ffi_api.ApplyPrimFuncArgAndResultMemoryConstraints(  # type: ignore # pylint: disable=no-member
         func, relay_func_type, arg_and_result_memory_scopes
     )
+
+
+def verify_well_formed(func: PrimFunc, assert_mode: bool = True) -> bool:
+    """Verify if the given TIR is well-formed. The verification includes:
+        - Check that expressions do not contain vars that are defined outside the block.
+
+    Parameters
+    ----------
+    func: tvm.tir.PrimFunc
+        The function to be verified.
+
+    assert_mode: bool
+        Whether to raise an error when the function is not well-formed.
+
+    Returns
+    -------
+    result: bool
+        Whether it is a well-formed TIR function.
+    """
+    return _ffi_api.VerifyWellFormed(func, assert_mode)  # type: ignore # pylint: disable=no-member
diff --git a/src/tir/analysis/verify_well_formed.cc b/src/tir/analysis/verify_well_formed.cc
new file mode 100644
index 0000000000..878618fbe6
--- /dev/null
+++ b/src/tir/analysis/verify_well_formed.cc
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/analysis/verify_well_formed.cc
+ * \brief Check if schedulable tir is well-formed.
+ */
+
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "../ir/functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+/*! \brief Verify that no Expr inside the block contains:
+ *    1. loop vars outside the current block.
+ *    2. block vars of parent blocks.
+ */
+class BlockVarAccessVerifier : public StmtExprVisitor {
+ public:
+  static bool Verify(const PrimFunc& func, bool assert_mode) {
+    BlockVarAccessVerifier verifier(assert_mode);
+    verifier(func->body);
+    return !verifier.has_error_;
+  }
+
+ private:
+  explicit BlockVarAccessVerifier(bool assert_mode) : assert_mode_(assert_mode) {}
+
+  void VisitStmt(const Stmt& stmt) final {
+    if (!has_error_) {
+      StmtExprVisitor::VisitStmt(stmt);
+    }
+  }
+
+  void VisitExpr(const PrimExpr& expr) final {
+    if (!has_error_) {
+      StmtExprVisitor::VisitExpr(expr);
+    }
+  }
+
+  void VisitExpr_(const VarNode* op) final {
+    auto it = loop_vars_.find(op);
+    if (it != loop_vars_.end() && it->second < cur_block_level_) {
+      has_error_ = true;
+      if (assert_mode_) {
+        report_error(op);
+      }
+    }
+  }
+
+  void VisitStmt_(const ForNode* op) final {
+    ICHECK(loop_vars_.find(op->loop_var.get()) == loop_vars_.end());
+    loop_vars_[op->loop_var.get()] = cur_block_level_;
+    StmtExprVisitor::VisitStmt_(op);
+    loop_vars_.erase(op->loop_var.get());
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    // Do not check boundary if it's an opaque block.
+    cur_block_level_ += !op->iter_vars.empty();
+
+    // Step 0. Skip block iter var's domain
+
+    // Step 1. Visit read/write regions
+    auto fvisit_buffer_region = [this](const BufferRegion& s) {
+      for (const auto& range : s->region) {
+        this->VisitExpr(range->min);
+        this->VisitExpr(range->extent);
+      }
+    };
+    VisitArray(op->reads, fvisit_buffer_region);
+    VisitArray(op->writes, fvisit_buffer_region);
+
+    // Step 2. Visit match buffers
+    VisitArray(op->match_buffers,
+               [fvisit_buffer_region](const MatchBufferRegion& match_buffer_region) {
+                 fvisit_buffer_region(match_buffer_region->source);
+               });
+
+    // Step 3. Visit init and body
+    if (op->init.defined()) {
+      this->VisitStmt(op->init.value());
+    }
+    this->VisitStmt(op->body);
+
+    cur_block_level_ -= !op->iter_vars.empty();
+  }
+
+ private:
+  void report_error(const VarNode* var) {
+    // TODO(siyuan): use the error message from the parser.
+    LOG(FATAL) << "Well-formedness check failed: outside defined var " << var->name_hint
+               << " is used inside the current block.";
+  }
+
+  /*! \brief The map from outside loop vars to their corresponding block levels. */
+  std::unordered_map<const VarNode*, size_t> loop_vars_;
+  /*! \brief Whether it's in assert mode. */
+  bool assert_mode_;
+  /*! \brief Current nested block stack level. */
+  size_t cur_block_level_{0};
+  /*! \brief Whether there is error. */
+  bool has_error_{false};
+};
+
+bool VerifyWellFormed(const PrimFunc& func, bool assert_mode) {
+  if (!BlockVarAccessVerifier::Verify(func, assert_mode)) {
+    return false;
+  }
+  // TODO(Siyuan): add more checks here.
+  return true;
+}
+
+TVM_REGISTER_GLOBAL("tir.analysis.VerifyWellFormed").set_body_typed(VerifyWellFormed);
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc
index 3c11d24853..dadabba485 100644
--- a/src/tir/schedule/state.cc
+++ b/src/tir/schedule/state.cc
@@ -413,6 +413,7 @@ class StateCreator : private StmtVisitor {
     for (const auto& kv : n->mod->functions) {
       const BaseFunc& base_func = kv.second;
       if (const auto* func = base_func.as<PrimFuncNode>()) {
+        VerifyWellFormed(GetRef<PrimFunc>(func));
         creator.VisitStmt(func->body);
         BlockInfoCollector::Collect(self, func->body);
       }
diff --git a/tests/python/unittest/test_tir_analysis_verify_well_formed.py b/tests/python/unittest/test_tir_analysis_verify_well_formed.py
new file mode 100644
index 0000000000..b3028a0148
--- /dev/null
+++ b/tests/python/unittest/test_tir_analysis_verify_well_formed.py
@@ -0,0 +1,57 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+from tvm.script import tir as T
+
+
+def test_pass_simple():
+    @T.prim_func
+    def element_wise(
+        A: T.Buffer[(128, 128), "float32"],
+        C: T.Buffer[(128, 128), "float32"],
+    ):
+        B = T.alloc_buffer((128, 128), "float32")
+        for i, j in T.grid(128, 128):
+            with T.block("B"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                B[vi, vj] = A[vi, vj] * 2.0
+        for i, j in T.grid(128, 128):
+            with T.block("C"):
+                # It's an opaque block, so it can use outside variables
+                C[i, j] = B[i, j] * 2.0
+
+    assert tvm.tir.analysis.verify_well_formed(element_wise)
+
+
+def test_fail_use_out_loop_var():
+    @T.prim_func
+    def element_wise(
+        A: T.Buffer[(128, 128), "float32"],
+        B: T.Buffer[(128, 128), "float32"],
+    ):
+        for i, j in T.grid(128, 128):
+            with T.block("B"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                # we cannot use `i` since it's defined outside the block
+                B[vi, vj] = A[i, vj] * 2.0
+
+    assert not tvm.tir.analysis.verify_well_formed(element_wise, assert_mode=False)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/unittest/test_tir_schedule_set_axis_separator.py b/tests/python/unittest/test_tir_schedule_set_axis_separator.py
index 102b3d1cd7..9502da1829 100644
--- a/tests/python/unittest/test_tir_schedule_set_axis_separator.py
+++ b/tests/python/unittest/test_tir_schedule_set_axis_separator.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=missing-function-docstring,missing-module-docstring
-import sys
 import pytest
 import tvm
 import tvm.testing
@@ -76,12 +75,12 @@ def element_wise_subregion_match(A: T.Buffer[(128, 128), "float32"], C: T.Buffer
     for i, j in T.grid(128, 128):
         with T.block("B"):
             vi, vj = T.axis.remap("SS", [i, j])
-            B_subregion0 = T.match_buffer(B[i, j], [], offset_factor=1)
+            B_subregion0 = T.match_buffer(B[vi, vj], [], offset_factor=1)
             B_subregion0[()] = A[vi, vj] * 2.0
     for i, j in T.grid(128, 128):
         with T.block("C"):
             vi, vj = T.axis.remap("SS", [i, j])
-            B_subregion1 = T.match_buffer(B[i, j], [], offset_factor=1)
+            B_subregion1 = T.match_buffer(B[vi, vj], [], offset_factor=1)
             C[vi, vj] = B_subregion1[()] + 1.0
 
 
@@ -92,12 +91,12 @@ def element_wise_subregion_match_set_axis_separator(A: T.Buffer[(128, 128), "flo
     for i, j in T.grid(128, 128):
         with T.block("B"):
             vi, vj = T.axis.remap("SS", [i, j])
-            B_subregion0 = T.match_buffer(B[i, j], [], dtype="float32", offset_factor=1, axis_separators=[1])
+            B_subregion0 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[1])
             B_subregion0[()] = A[vi, vj] * T.float32(2)
     for i, j in T.grid(128, 128):
         with T.block("C"):
             vi, vj = T.axis.remap("SS", [i, j])
-            B_subregion1 = T.match_buffer(B[i, j], [], dtype="float32", offset_factor=1, axis_separators=[1])
+            B_subregion1 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[1])
             C[vi, vj] = B_subregion1[()] + T.float32(1)
 
 
diff --git a/tests/python/unittest/test_tir_schedule_set_scope.py b/tests/python/unittest/test_tir_schedule_set_scope.py
index b2e8479462..adac81e629 100644
--- a/tests/python/unittest/test_tir_schedule_set_scope.py
+++ b/tests/python/unittest/test_tir_schedule_set_scope.py
@@ -17,6 +17,7 @@
 # pylint: disable=missing-function-docstring,missing-module-docstring
 import pytest
 import tvm
+import tvm.testing
 from tvm import tir
 from tvm.script import tir as T
 from tvm.tir.schedule.testing import verify_trace_roundtrip
@@ -59,12 +60,12 @@ def element_wise_subregion_match(A: T.Buffer[(128, 128), "float32"], C: T.Buffer
     for i, j in T.grid(128, 128):
         with T.block("B"):
             vi, vj = T.axis.remap("SS", [i, j])
-            B_subregion0 = T.match_buffer(B[i, j], [], offset_factor=1)
+            B_subregion0 = T.match_buffer(B[vi, vj], [], offset_factor=1)
             B_subregion0[()] = A[vi, vj] * 2.0
     for i, j in T.grid(128, 128):
         with T.block("C"):
             vi, vj = T.axis.remap("SS", [i, j])
-            B_subregion1 = T.match_buffer(B[i, j], [], offset_factor=1)
+            B_subregion1 = T.match_buffer(B[vi, vj], [], offset_factor=1)
             C[vi, vj] = B_subregion1[()] + 1.0
 
 
@@ -75,12 +76,12 @@ def element_wise_subregion_match_set_scope(A: T.Buffer[(128, 128), "float32"], C
     for i, j in T.grid(128, 128):
         with T.block("B"):
             vi, vj = T.axis.remap("SS", [i, j])
-            B_subregion0_shared = T.match_buffer(B_shared[i, j], [], dtype="float32", scope="shared", offset_factor=1)
+            B_subregion0_shared = T.match_buffer(B_shared[vi, vj], [], dtype="float32", scope="shared", offset_factor=1)
             B_subregion0_shared[()] = A[vi, vj] * T.float32(2)
     for i, j in T.grid(128, 128):
         with T.block("C"):
             vi, vj = T.axis.remap("SS", [i, j])
-            B_subregion1_shared = T.match_buffer(B_shared[i, j], [], dtype="float32", scope="shared", offset_factor=1)
+            B_subregion1_shared = T.match_buffer(B_shared[vi, vj], [], dtype="float32", scope="shared", offset_factor=1)
             C[vi, vj] = B_subregion1_shared[()] + T.float32(1)
 
 
@@ -128,8 +129,4 @@ def test_set_scope_subregion():
 
 
 if __name__ == "__main__":
-    test_set_scope()
-    test_set_scope_fail_on_output_buffer()
-    test_set_scope_fail_on_index_out_of_bound()
-    test_set_scope_fail_on_invalid_scope()
-    test_set_scope_subregion()
+    tvm.testing.main()