You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/04/06 23:39:08 UTC

[tvm] 01/03: [TVMScript] IRModule TVMScript Parser.

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 11c13ace0b5cef71f50193248ecaac7e845ee25e
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Wed Feb 8 22:31:47 2023 +0800

    [TVMScript] IRModule TVMScript Parser.
    
    This PR adds the TVMScript parser/ir_builder support based on the
    blockbuilder.  This commit contains the non-relax portions from
    https://github.com/apache/tvm/pull/13932.
    
    Co-authored-by: Ruihang Lai <ru...@cs.cmu.edu>
    Co-authored-by: Junru Shao <ju...@gmail.com>
    Co-authored-by: Tianqi Chen <ti...@gmail.com>
    Co-authored-by: Yuchen Jin <yu...@cs.washington.edu>
    Co-authored-by: Steven S. Lyubomirsky <sl...@gmail.com>
    Co-authored-by: Yong Wu <yo...@gmail.com>
---
 include/tvm/script/ir_builder/ir/frame.h       | 11 ++++--
 include/tvm/script/ir_builder/ir/ir.h          | 17 +++++++++
 python/tvm/script/ir_builder/base.py           |  6 ++--
 python/tvm/script/ir_builder/ir/__init__.py    |  2 +-
 python/tvm/script/ir_builder/ir/ir.py          | 45 +++++++++++++++++++++++
 python/tvm/script/parser/core/diagnostics.py   |  2 +-
 python/tvm/script/parser/core/evaluator.py     |  2 +-
 python/tvm/script/parser/core/parser.py        | 50 ++++++++++++++++++--------
 python/tvm/script/parser/ir/parser.py          |  4 +++
 python/tvm/script/parser/tir/entry.py          |  4 +--
 python/tvm/script/parser/tir/parser.py         | 26 ++++++++++++++
 src/script/ir_builder/ir/frame.cc              | 12 ++++---
 src/script/ir_builder/ir/ir.cc                 | 32 ++++++++++++++++-
 src/script/ir_builder/ir/{frame.cc => utils.h} | 30 +++++++++-------
 src/script/ir_builder/tir/frame.cc             | 15 ++++++--
 src/script/ir_builder/tir/utils.h              |  2 +-
 16 files changed, 213 insertions(+), 47 deletions(-)

diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h
index 887981ccff..dacfc361a6 100644
--- a/include/tvm/script/ir_builder/ir/frame.h
+++ b/include/tvm/script/ir_builder/ir/frame.h
@@ -38,12 +38,17 @@ namespace ir {
  */
 class IRModuleFrameNode : public IRBuilderFrameNode {
  public:
-  Array<GlobalVar> global_vars;
-  Array<BaseFunc> functions;
+  /*! \brief A map from string names to global variables that ensures global uniqueness. */
+  Map<String, GlobalVar> global_var_map;
+  /*!
+   * \brief A map from GlobalVar to all global functions.
+   * \note Only defined functions are in the map, while declared functions are not included.
+   */
+  Map<GlobalVar, BaseFunc> functions;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     IRBuilderFrameNode::VisitAttrs(v);
-    v->Visit("global_vars", &global_vars);
+    v->Visit("global_vars", &global_var_map);
     v->Visit("functions", &functions);
   }
 
diff --git a/include/tvm/script/ir_builder/ir/ir.h b/include/tvm/script/ir_builder/ir/ir.h
index f0e7cc6f5c..49bdcf60e6 100644
--- a/include/tvm/script/ir_builder/ir/ir.h
+++ b/include/tvm/script/ir_builder/ir/ir.h
@@ -37,6 +37,23 @@ namespace ir {
  */
 TVM_DLL IRModuleFrame IRModule();
 
+/*!
+ * \brief Declare a Function without giving the specific function implementation.
+ * \note Usually used for cross-function calls; the body is given later via `DefFunction`.
+ * \param func_name The function unique name.
+ * \param func_signature A Function w/o body, which is used to specify the function signature
+ *                       (i.e. func params and func return type/shape).
+ * \return The corresponding GlobalVar.
+ */
+TVM_DLL GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature);
+
+/*!
+ * \brief Define the function which is declared before.
+ * \param func_name The function unique name.
+ * \param func The given function implementation
+ */
+TVM_DLL void DefFunction(const String& func_name, const BaseFunc& func);
+
 }  // namespace ir
 }  // namespace ir_builder
 }  // namespace script
diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py
index 7aa33ee49c..b35bbd0a7d 100644
--- a/python/tvm/script/ir_builder/base.py
+++ b/python/tvm/script/ir_builder/base.py
@@ -64,8 +64,10 @@ class IRBuilderFrame(_Object):
         _ffi_api.IRBuilderFrameEnter(self)  # type: ignore[attr-defined] # pylint: disable=no-member
         return self
 
-    def __exit__(self, ptype, value, trace) -> None:  # pylint: disable=unused-argument
-        _ffi_api.IRBuilderFrameExit(self)  # type: ignore[attr-defined] # pylint: disable=no-member
+    def __exit__(self, exc_type, exc_value, trace) -> None:  # pylint: disable=unused-argument
+        if exc_type is None and exc_value is None:
+            # Do not execute `FrameExit` if the with scope exits because of exceptions
+            _ffi_api.IRBuilderFrameExit(self)  # type: ignore[attr-defined] # pylint: disable=no-member
 
     def add_callback(self, callback: Callable[[], None]) -> None:
         """Add a callback method invoked when exiting the with-scope.
diff --git a/python/tvm/script/ir_builder/ir/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py
index ebb9728737..946be263a7 100644
--- a/python/tvm/script/ir_builder/ir/__init__.py
+++ b/python/tvm/script/ir_builder/ir/__init__.py
@@ -16,4 +16,4 @@
 # under the License.
 """Package tvm.script.ir_builder.ir"""
 from .frame import IRModuleFrame
-from .ir import ir_module
+from .ir import decl_function, def_function, ir_module
diff --git a/python/tvm/script/ir_builder/ir/ir.py b/python/tvm/script/ir_builder/ir/ir.py
index 213180463c..796d6f3aad 100644
--- a/python/tvm/script/ir_builder/ir/ir.py
+++ b/python/tvm/script/ir_builder/ir/ir.py
@@ -16,9 +16,54 @@
 # under the License.
 """Package tvm.script.ir_builder.ir.ir"""
 
+from tvm.ir import BaseFunc, GlobalVar
+
 from . import _ffi_api
 from .frame import IRModuleFrame
 
 
 def ir_module() -> IRModuleFrame:
+    """Start a ir_module frame.
+    Returns
+    -------
+    frame: IRModuleFrame
+        The constructed frame.
+    """
     return _ffi_api.IRModule()  # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def decl_function(func_name: str, func_signature: BaseFunc) -> GlobalVar:
+    """Declare a Function without given the specific function implementation.
+    Parameters
+    ----------
+    func_name : str
+        The function unique name.
+
+    func_signature: Optional[BaseFunc]
+        A Function w/o body, which is used to specify the function signature
+        (i.e. func params and func return type/shape).
+
+    Note
+    ----
+    It is usually used for cross-function calls; the body is given later via `def_function`.
+    Returns
+    -------
+    gv : GlobalVar
+        The corresponding GlobalVar.
+    """
+
+    return _ffi_api.DeclFunction(  # type: ignore[attr-defined] # pylint: disable=no-member
+        func_name, func_signature
+    )
+
+
+def def_function(func_name: str, func: BaseFunc) -> None:
+    """Define the function which is declared before.
+    Parameters
+    ----------
+    func_name : str
+        The function unique name.
+    func: BaseFunc
+        The given function implementation
+    """
+    return _ffi_api.DefFunction(func_name, func)  # type: ignore[attr-defined] # pylint: disable=no-member
diff --git a/python/tvm/script/parser/core/diagnostics.py b/python/tvm/script/parser/core/diagnostics.py
index ad7ae50347..2767a97f60 100644
--- a/python/tvm/script/parser/core/diagnostics.py
+++ b/python/tvm/script/parser/core/diagnostics.py
@@ -220,7 +220,7 @@ class Diagnostics:
         level : diagnostics.DiagnosticLevel
             The diagnostic level.
         """
-        lineno = node.lineno or self.source.start_line
+        lineno = node.lineno or 1
         col_offset = node.col_offset or self.source.start_column
         end_lineno = node.end_lineno or lineno
         end_col_offset = node.end_col_offset or col_offset
diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py
index 3a72a3c331..075aedd891 100644
--- a/python/tvm/script/parser/core/evaluator.py
+++ b/python/tvm/script/parser/core/evaluator.py
@@ -203,7 +203,7 @@ class ExprEvaluator:
             else:
                 value = self._eval_expr(node.__class__(**fields))
         except Exception as e:  # pylint: disable=broad-except,invalid-name
-            self.parser.report_error(node, str(e))
+            self.parser.report_error(node, e)
         return self._add_intermediate_result(value)
 
     def _eval_lambda(self, node: doc.Lambda) -> Any:
diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py
index fdccabcd23..837b7cce5d 100644
--- a/python/tvm/script/parser/core/parser.py
+++ b/python/tvm/script/parser/core/parser.py
@@ -60,6 +60,10 @@ def _deferred(exit_f: Callable[[], None]):
     return context()
 
 
+def _do_nothing(*args, **kwargs):  # pylint: disable=unused-argument
+    pass
+
+
 class VarTableFrame:
     """The variable table frame.
     A frame of variable table stores the variables created in one block or scope.
@@ -260,6 +264,17 @@ class Parser(doc.NodeVisitor):
             node = self.diag.source.as_ast()
             self.visit(node)
 
+    def get_dispatch_token(self, node: doc.FunctionDef) -> str:
+        if not isinstance(node, doc.FunctionDef):
+            self.report_error(node, "Only can get dispatch token for function.")
+        if not node.decorator_list:
+            self.report_error(node, "Function must be decorated")
+        # TODO: only the last decorator is parsed
+        decorator = self.eval_expr(node.decorator_list[-1])
+        if not hasattr(decorator, "dispatch_token"):
+            self.report_error(node, "The parser does not understand the decorator")
+        return decorator.dispatch_token
+
     def with_dispatch_token(self, token: str):
         """Add a new dispatching token as with statement.
 
@@ -389,6 +404,8 @@ class Parser(doc.NodeVisitor):
         # Only take the last line of the error message
         if isinstance(err, TVMError):
             msg = list(filter(None, str(err).split("\n")))[-1]
+        elif isinstance(err, KeyError):
+            msg = "KeyError: " + str(err)
         else:
             msg = str(err)
         self.diag.error(node, msg)
@@ -458,30 +475,33 @@ class Parser(doc.NodeVisitor):
         """
         return _dispatch(self, "tvm_annotation")(self, node)
 
-    def visit_FunctionDef(self, node: doc.FunctionDef) -> Any:  # pylint: disable=invalid-name
-        """The general function definition visiting method.
+    def visit_FunctionDef(self, node: doc.FunctionDef) -> None:  # pylint: disable=invalid-name
+        """The general function definition visit method.
 
         Parameters
         ----------
         node : doc.FunctionDef
-            The doc AST function definition node.
-
-        Returns
-        -------
-        res : Any
-            The visiting result.
+            The doc FunctionDef node.
         """
-        if not node.decorator_list:
-            self.report_error(node, "Function must be decorated")
-        # TODO: only the last decorator is parsed
-        decorator = self.eval_expr(node.decorator_list[-1])
-        if not hasattr(decorator, "dispatch_token"):
-            self.report_error(node, "The parser does not understand the decorator")
-        token = decorator.dispatch_token
+        token = self.get_dispatch_token(node)
+        current_token = self.dispatch_tokens[-1]
         func = dispatch.get(token=token, type_name="FunctionDef", default=None)
         if func is None:
             self.report_error(node, "The parser does not understand the decorator")
+        pre_func = dispatch.get(
+            token=current_token, type_name="pre_token_switch", default=_do_nothing
+        )
+        post_func = dispatch.get(
+            token=current_token, type_name="post_token_switch", default=_do_nothing
+        )
+        pre_func(self, node)
         _dispatch_wrapper(func)(self, node)
+        post_func(self, node)
+
+    def visit_tvm_declare_function(self, node: doc.FunctionDef) -> None:
+        token = self.get_dispatch_token(node)
+        with self.with_dispatch_token(token):
+            _dispatch(self, "tvm_declare_function")(self, node)
 
     def visit_ClassDef(self, node: doc.ClassDef) -> Any:  # pylint: disable=invalid-name
         """The general class definition visiting method.
diff --git a/python/tvm/script/parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py
index e0268412d2..13b3e29859 100644
--- a/python/tvm/script/parser/ir/parser.py
+++ b/python/tvm/script/parser/ir/parser.py
@@ -32,8 +32,12 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:
     node : doc.ClassDef
         The doc AST class definition node.
     """
+
     with self.var_table.with_frame():
         with I.ir_module():
+            for stmt in node.body:
+                if isinstance(stmt, doc.FunctionDef):
+                    self.visit_tvm_declare_function(stmt)
             with self.with_dispatch_token("ir"):
                 self.visit_body(node.body)
 
diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py
index 411a7f8f3c..649f817411 100644
--- a/python/tvm/script/parser/tir/entry.py
+++ b/python/tvm/script/parser/tir/entry.py
@@ -83,7 +83,7 @@ class BufferProxy:
             return self(keys)
         if len(keys) >= 2 and not isinstance(keys[1], str):
             return self(keys)
-        return self(*keys)  # pylint: disable=no-member # type: ignore
+        return self(*keys)  # type: ignore[attr-defined] # pylint: disable=no-member
 
 
 class PtrProxy:
@@ -93,7 +93,7 @@ class PtrProxy:
     def __call__(self, dtype, storage_scope="global"):
         if callable(dtype):
             dtype = dtype().dtype
-        return ptr(dtype, storage_scope)  # pylint: disable=no-member # type: ignore
+        return ptr(dtype, storage_scope)  # type: ignore[attr-defined] # pylint: disable=no-member
 
     @deprecated("T.Ptr[...]", "T.handle(...)")
     def __getitem__(self, keys):
diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py
index 8a067267a3..63171f6722 100644
--- a/python/tvm/script/parser/tir/parser.py
+++ b/python/tvm/script/parser/tir/parser.py
@@ -24,6 +24,7 @@ import tvm
 from tvm.ir import PrimType
 from tvm.tir import Buffer, IterVar, PrimExpr, Var
 
+from ...ir_builder import ir as I
 from ...ir_builder import tir as T
 from ...ir_builder.base import IRBuilder
 from ...ir_builder.base import IRBuilderFrame as Frame
@@ -473,3 +474,28 @@ def visit_return(self: Parser, node: doc.Return) -> None:
         The doc AST return node.
     """
     self.report_error(node, "Return is not allowed.")
+
+
+@dispatch.register(token="tir", type_name="tvm_declare_function")
+def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None:
+    """The function declaration step for tir
+
+    Parameters
+    ----------
+    self : Parser
+        The visiting parser.
+
+    node : doc.FunctionDef
+        The doc AST function definition node.
+    """
+
+    ret_type = None
+    if node.returns is not None:
+        ret_type = self.eval_expr(node.returns)
+        if callable(ret_type):
+            ret_type = PrimType(ret_type().dtype)
+
+    # Only ret_type is needed for func_signature.
+    func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type)
+    global_var = I.decl_function(node.name, func_signature)
+    self.var_table.add(node.name, global_var)
diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc
index a81c56922d..addf129284 100644
--- a/src/script/ir_builder/ir/frame.cc
+++ b/src/script/ir_builder/ir/frame.cc
@@ -26,11 +26,15 @@ namespace ir_builder {
 namespace ir {
 
 void IRModuleFrameNode::ExitWithScope() {
-  ICHECK_EQ(functions.size(), global_vars.size());
-  int n = functions.size();
   Map<GlobalVar, BaseFunc> func_map;
-  for (int i = 0; i < n; ++i) {
-    func_map.Set(global_vars[i], functions[i]);
+  CHECK_EQ(functions.size(), global_var_map.size())
+      << "All functions must be defined in the IRModule. Got " << global_var_map.size()
+      << "declared function(s), but only " << functions.size() << "defined function(s).";
+  for (const auto& kv : functions) {
+    const GlobalVar& gv = kv.first;
+    const BaseFunc& func = kv.second;
+    CHECK(func.defined()) << "ValueError: function " << gv->name_hint << " is not defined";
+    func_map.Set(gv, func);
   }
   IRBuilder builder = IRBuilder::Current();
   ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc
index a8cc452e4f..5764e90c8d 100644
--- a/src/script/ir_builder/ir/ir.cc
+++ b/src/script/ir_builder/ir/ir.cc
@@ -20,6 +20,8 @@
 #include <tvm/runtime/registry.h>
 #include <tvm/script/ir_builder/ir/ir.h>
 
+#include "./utils.h"
+
 namespace tvm {
 namespace script {
 namespace ir_builder {
@@ -27,12 +29,40 @@ namespace ir {
 
 IRModuleFrame IRModule() {
   ObjectPtr<IRModuleFrameNode> n = make_object<IRModuleFrameNode>();
-  n->global_vars.clear();
+  n->global_var_map.clear();
   n->functions.clear();
   return IRModuleFrame(n);
 }
 
+GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature) {
+  IRModuleFrame frame = FindModuleFrame("I.DeclFunction");
+  CHECK(!frame->global_var_map.count(func_name))
+      << "ValueError: function " << func_name << " already exists";
+  GlobalVar gv = GlobalVar(func_name);
+  CHECK(frame->functions.find(gv) == frame->functions.end())
+      << "ValueError: function " << func_name << " has already been defined.";
+  frame->global_var_map.Set(func_name, gv);
+  if (func_signature.defined()) {
+    frame->functions.Set(gv, func_signature);
+  }
+  return gv;
+}
+
+void DefFunction(const String& func_name, const BaseFunc& func) {
+  IRModuleFrame frame = FindModuleFrame("I.DefFunction");
+  auto it = frame->global_var_map.find(func_name);
+  CHECK(it != frame->global_var_map.end())
+      << "ValueError: function " << func_name << " does not exist, please declare it first.";
+  const GlobalVar& gv = (*it).second;
+  frame->functions.Set(gv, func);
+  if (func->checked_type_.defined()) {
+    gv->checked_type_ = func->checked_type_;
+  }
+}
+
 TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule);
+TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction);
+TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction);
 
 }  // namespace ir
 }  // namespace ir_builder
diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/utils.h
similarity index 59%
copy from src/script/ir_builder/ir/frame.cc
copy to src/script/ir_builder/ir/utils.h
index a81c56922d..58d5e53f70 100644
--- a/src/script/ir_builder/ir/frame.cc
+++ b/src/script/ir_builder/ir/utils.h
@@ -16,8 +16,9 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-#include <tvm/ir/module.h>
-#include <tvm/runtime/registry.h>
+#ifndef TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_
+#define TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_
+
 #include <tvm/script/ir_builder/ir/frame.h>
 
 namespace tvm {
@@ -25,21 +26,24 @@ namespace script {
 namespace ir_builder {
 namespace ir {
 
-void IRModuleFrameNode::ExitWithScope() {
-  ICHECK_EQ(functions.size(), global_vars.size());
-  int n = functions.size();
-  Map<GlobalVar, BaseFunc> func_map;
-  for (int i = 0; i < n; ++i) {
-    func_map.Set(global_vars[i], functions[i]);
-  }
+inline IRModuleFrame FindModuleFrame(const String& method) {
   IRBuilder builder = IRBuilder::Current();
-  ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
-  builder->result = tvm::IRModule(func_map);
+  if (Optional<IRModuleFrame> frame = builder->FindFrame<IRModuleFrame>()) {
+    const Optional<IRModuleFrame>& last_module_frame = builder->GetLastFrame<IRModuleFrame>();
+    if (last_module_frame.defined() && last_module_frame.value() == frame) {
+      return frame.value();
+    }
+  } else {
+    LOG(FATAL) << "ValueError: IRModule frame not find. Please ensure '" << method
+               << "' is called under I.ir_module()";
+  }
+  LOG(FATAL) << "ValueError: '" << method << "' must be called immediately under I.ir_module()";
+  throw;
 }
 
-TVM_REGISTER_NODE_TYPE(IRModuleFrameNode);
-
 }  // namespace ir
 }  // namespace ir_builder
 }  // namespace script
 }  // namespace tvm
+
+#endif  // TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_
diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc
index 1e63201a40..dd8d3c2ed3 100644
--- a/src/script/ir_builder/tir/frame.cc
+++ b/src/script/ir_builder/tir/frame.cc
@@ -16,6 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+#include <tvm/script/ir_builder/ir/ir.h>
 #include <tvm/script/ir_builder/tir/frame.h>
 #include <tvm/tir/function.h>
 
@@ -41,9 +42,17 @@ void PrimFuncFrameNode::ExitWithScope() {
     ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
     builder->result = func;
   } else if (Optional<ir::IRModuleFrame> opt_frame = builder->FindFrame<ir::IRModuleFrame>()) {
-    ir::IRModuleFrame frame = opt_frame.value();
-    frame->global_vars.push_back(GlobalVar(name.value_or("")));
-    frame->functions.push_back(func);
+    CHECK(name.defined()) << "ValueError: The function name must be defined before exiting the "
+                             "function scope, if it's defined in a Module";
+    const ir::IRModuleFrame& frame = opt_frame.value();
+    const String& func_name = name.value_or("");
+    if (!frame->global_var_map.count(func_name)) {
+      // Case. First time visiting the function.
+      ir::DeclFunction(func_name, func);
+    }
+    // Define the function.
+    // Note we do checks to disallow redefinition of functions inside the `DefFunction`.
+    ir::DefFunction(func_name, func);
   } else {
     LOG(FATAL) << "ValueError: Cannot find where to insert PrimFunc";
   }
diff --git a/src/script/ir_builder/tir/utils.h b/src/script/ir_builder/tir/utils.h
index 7ccc132fa1..f3b547532c 100644
--- a/src/script/ir_builder/tir/utils.h
+++ b/src/script/ir_builder/tir/utils.h
@@ -87,7 +87,7 @@ inline PrimFuncFrame FindPrimFuncFrame(const String& method) {
  * \return The top frame of BlockFrame.
  */
 inline BlockFrame FindBlockFrame(const String& method) {
-  if (Optional<BlockFrame> frame = IRBuilder::Current()->GetLastFrame<BlockFrame>()) {
+  if (Optional<BlockFrame> frame = IRBuilder::Current()->FindFrame<BlockFrame>()) {
     return frame.value();
   } else if (Optional<BlockFrame> frame = IRBuilder::Current()->FindFrame<BlockFrame>()) {
     LOG(FATAL) << "ValueError: " << method << " must be called at the top of a T.block().  "