You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2022/08/18 12:43:07 UTC
[tvm] branch main updated: [TVMScript] IRBuilder, IRBuilderFrame base class (#12482)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 250b68e202 [TVMScript] IRBuilder, IRBuilderFrame base class (#12482)
250b68e202 is described below
commit 250b68e2028bd84d7386ffaf8a7c91feef8d45ec
Author: Yaxing Cai <ca...@gmail.com>
AuthorDate: Thu Aug 18 05:43:00 2022 -0700
[TVMScript] IRBuilder, IRBuilderFrame base class (#12482)
* [TVMScript] IRBuilder, IRBuilderFrame base class
This PR introduces basic data structures of the generic IRBuilder
across the codebase.
IRBuilder is a general-purpose IRBuilder that can be used in TIR, Relax
and any other vendor-specific dialects; IRBuilderFrame is where contextual
information is stored in the IRBuilder.
* fix linter
* Update include/tvm/script/ir_builder/base.h
Co-authored-by: Junru Shao <ju...@gmail.com>
---
include/tvm/script/ir_builder/base.h | 302 +++++++++++++++++++++
python/tvm/script/ir_builder/__init__.py | 18 ++
python/tvm/script/ir_builder/_ffi_api.py | 20 ++
python/tvm/script/ir_builder/base.py | 181 ++++++++++++
src/script/ir_builder/base.cc | 115 ++++++++
.../unittest/test_tvmscript_ir_builder_base.py | 42 +++
6 files changed, 678 insertions(+)
diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h
new file mode 100644
index 0000000000..61ca3eb9f7
--- /dev/null
+++ b/include/tvm/script/ir_builder/base.h
@@ -0,0 +1,302 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SCRIPT_IR_BUILDER_BASE_H_
+#define TVM_SCRIPT_IR_BUILDER_BASE_H_
+
+#include <tvm/ir/expr.h>
+#include <tvm/ir/function.h>
+#include <tvm/node/node.h>
+
+#include <vector>
+
+namespace tvm {
+namespace script {
+namespace ir_builder {
+
+////////////////////////////// IRBuilderFrame //////////////////////////////
+
+/*!
+ * \brief A stack frame of the IRBuilder used to keep track of the current scope.
+ * Furthermore, the information stored in each stack frame can be useful for context-dependent
+ * IR construction.
+ *
+ * \example
+ *
+ * The `T::MatchBuffer` below adds an element in `PrimFuncNode::buffer_map`:
+ *
+ * \code {.cpp}
+ *
+ * namespace T = tvm::script::ir_builder::tir;
+ * With<PrimFuncFrame> _(...);
+ * Buffer buffer = T::MatchBuffer(...);
+ *
+ * \endcode
+ *
+ * The `T::MatchBuffer` below instead generates `MatchBufferRegion` in a TIR block:
+ *
+ * \code {.cpp}
+ *
+ * namespace T = tvm::script::ir_builder::tir;
+ * With<PrimFuncFrame> _(...);
+ * {
+ *   With<BlockFrame> _2(...);
+ *   Buffer buffer = T::MatchBuffer(...);
+ * }
+ *
+ * \endcode
+ */
+class IRBuilderFrameNode : public runtime::Object {
+ public:
+  /*!
+   * \brief Callbacks run by `ExitWithScope`, most recently added first; the list is cleared
+   * once they have run.
+   */
+  std::vector<runtime::TypedPackedFunc<void()>> callbacks;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    // `callbacks` is not visited, so this node exposes no reflected attributes.
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.IRBuilderFrame";
+  TVM_DECLARE_BASE_OBJECT_INFO(IRBuilderFrameNode, runtime::Object);
+
+ public:
+  /*! \brief Default destructor. */
+  virtual ~IRBuilderFrameNode() = default;
+  /*!
+   * \brief The method called when entering RAII scope. The base implementation pushes this
+   * frame onto the frame stack of `IRBuilder::Current()`.
+   * \sa tvm::support::With
+   */
+  virtual void EnterWithScope();
+  /*!
+   * \brief The method called when exiting RAII scope. The base implementation runs
+   * `callbacks` in reverse order, clears them, and pops the innermost frame of
+   * `IRBuilder::Current()`.
+   * \sa tvm::support::With
+   */
+  virtual void ExitWithScope();
+  /*!
+   * \brief Add a callback method invoked when exiting the RAII scope.
+   *
+   * The callback is appended to the innermost frame of `IRBuilder::Current()`, which is not
+   * necessarily this frame. It fails if that builder has no frames.
+   * NOTE(review): confirm that targeting the innermost frame, rather than `this`, is intended.
+   *
+   * \param callback The callback to be added.
+   */
+  void AddCallback(runtime::TypedPackedFunc<void()> callback);
+};
+
+/*!
+ * \brief Managed reference to an IRBuilderFrameNode.
+ *
+ * Only subclasses may create an empty handle. The with-scope entry points forward to the
+ * corresponding virtual methods of the referenced node.
+ *
+ * \sa IRBuilderFrameNode
+ */
+class IRBuilderFrame : public runtime::ObjectRef {
+ public:
+  /*! \brief Default (virtual) destructor. */
+  virtual ~IRBuilderFrame() = default;
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRBuilderFrame, ObjectRef, IRBuilderFrameNode);
+
+ protected:
+  /*! \brief Default construction is reserved for subclasses. */
+  IRBuilderFrame() = default;
+
+ public:
+  /*!
+   * \brief Forward to `IRBuilderFrameNode::EnterWithScope` of the referenced node.
+   * \sa IRBuilderFrameNode::EnterWithScope
+   */
+  void EnterWithScope() {
+    ICHECK(data_ != nullptr);
+    IRBuilderFrameNode* node = static_cast<IRBuilderFrameNode*>(data_.get());
+    node->EnterWithScope();
+  }
+  /*!
+   * \brief Forward to `IRBuilderFrameNode::ExitWithScope` of the referenced node, then drop
+   * this handle's reference (the handle is empty afterwards).
+   * \sa IRBuilderFrameNode::ExitWithScope
+   */
+  void ExitWithScope() {
+    ICHECK(data_ != nullptr);
+    IRBuilderFrameNode* node = static_cast<IRBuilderFrameNode*>(data_.get());
+    node->ExitWithScope();
+    data_.reset();
+  }
+};
+
+////////////////////////////// IRBuilder //////////////////////////////
+
+/*!
+ * \brief A dialect-agnostic IRBuilder that constructs any IR of TVM.
+ * An idiomatic use of this class is to put it inside the RAII with-scope and call
+ * dialect-specific methods within that scope. Upon exiting the scope, the constructed IR
+ * can be retrieved with `Get`.
+ *
+ * \code
+ *
+ * PrimFunc ConstructPrimFunc() {
+ *   using tvm::script::ir_builder::IRBuilder;
+ *   namespace T = tvm::script::ir_builder::tir;
+ *   IRBuilder builder;
+ *   // Step 1. Place IRBuilder inside the with-scope.
+ *   {
+ *     With<IRBuilder> _(builder);
+ *     // Step 2. Call dialect-specific methods.
+ *     With<T::PrimFuncFrame> _2(...);
+ *     T::MatchBuffer(...);
+ *   }
+ *   // Step 3. Return the constructed PrimFunc.
+ *   return builder->Get<PrimFunc>();
+ * }
+ *
+ * \endcode
+ */
+class IRBuilderNode : public runtime::Object {
+ public:
+  /*! \brief A stack of context frames in the IRBuilder; the innermost frame is at the back. */
+  runtime::Array<IRBuilderFrame> frames;
+  /*! \brief The outcome of IR construction; reset to NullOpt by `IRBuilder::EnterWithScope`. */
+  Optional<ObjectRef> result;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("frames", &frames);
+    v->Visit("result", &result);
+  }
+
+  static constexpr const char* _type_key = "script.ir_builder.IRBuilder";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IRBuilderNode, runtime::Object);
+
+ public:
+  /*!
+   * \brief Find a frame of the given type in the stack `this->frames` from top to bottom.
+   * \tparam TFrame The type of the frame to find.
+   * \return The innermost frame of type `TFrame` if found, otherwise NullOpt.
+   */
+  template <typename TFrame>
+  inline Optional<TFrame> FindFrame() const;
+  /*!
+   * \brief Get the frame on top of the stack `this->frames` if its type is `TFrame`.
+   * \tparam TFrame The assumed type of the last frame on stack.
+   * \return The frame if the stack is non-empty and the top of the stack is of type `TFrame`.
+   * Otherwise NullOpt.
+   */
+  template <typename TFrame>
+  inline Optional<TFrame> GetLastFrame() const;
+  /*!
+   * \brief Get the IR being constructed.
+   * \tparam TObjectRef The type of the IR being constructed.
+   * \return The resulting IR. Fails with an "IndexError" if no result exists yet, or with a
+   * "TypeError" if the result is not of type `TObjectRef`.
+   */
+  template <typename TObjectRef>
+  inline TObjectRef Get() const;
+};
+
+/*!
+ * \brief Managed reference to an IRBuilderNode.
+ * \sa IRBuilderNode
+ */
+class IRBuilder : public runtime::ObjectRef {
+ public:
+  /*! \brief Creates an IRBuilder with no frames and no result. */
+  IRBuilder();
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRBuilder, ObjectRef, IRBuilderNode);
+
+ public:
+  /*!
+   * \brief Puts the current IRBuilder into a thread-local scope, which can be retrieved using
+   * `IRBuilder::Current()`.
+   *
+   * Fails with a "ValueError" if the builder still has frames left over, and resets its
+   * `result` to NullOpt.
+   *
+   * \code {.cpp}
+   * IRBuilder builder;
+   * {
+   *   With<IRBuilder> _(builder);
+   *   // IRBuilder::Current() == builder
+   * }
+   * // Without an enclosing builder, IRBuilder::Current() now fails with
+   * // "ValueError: No builder in current scope".
+   * \endcode
+   *
+   * \sa IRBuilder::Current
+   * \sa IRBuilder::ExitWithScope
+   * \sa tvm::support::With
+   */
+  void EnterWithScope();
+  /*!
+   * \brief Exit the RAII scope by popping the innermost builder of the thread-local stack.
+   * \sa IRBuilder::EnterWithScope
+   * \sa IRBuilder::Current
+   * \sa tvm::support::With
+   */
+  void ExitWithScope();
+  /*!
+   * \brief Get the current IRBuilder in the current thread-local scope.
+   * \return The innermost IRBuilder entered on this thread. Fails with a "ValueError" if none.
+   * \sa IRBuilder::EnterWithScope
+   * \sa IRBuilder::ExitWithScope
+   * \sa tvm::support::With
+   */
+  static IRBuilder Current();
+  /*!
+   * \brief Give a string name to the `obj` using the naming rule registered for its type.
+   * Fails if `obj` is undefined or no rule is registered for its type.
+   * \tparam TObjectRef The type of the object to name.
+   * \param name The name to give to the object.
+   * \param obj The object to name.
+   * \return `obj` itself.
+   */
+  template <class TObjectRef>
+  inline static TObjectRef Name(String name, TObjectRef obj);
+};
+
+////////////////////////////// Details //////////////////////////////
+
+namespace details {
+
+/*!
+ * \brief Type-dispatched registry of naming rules used by `IRBuilder::Name`.
+ */
+class Namer {
+ public:
+  /*! \brief Functor that dispatches on the runtime type of the object being named. */
+  using FType = NodeFunctor<void(const ObjectRef&, String)>;
+  /*! \brief The single, process-wide dispatch table that naming rules are registered into. */
+  static FType& vtable();
+  /*!
+   * \brief Name `node` with `name` through the dispatch table.
+   * Fails if `node` is undefined or if its type has no registered rule.
+   */
+  static void Name(ObjectRef node, String name);
+};
+
+}  // namespace details
+
+/*!
+ * \brief Implementation of `IRBuilder::Name`: applies the naming rule registered for the
+ * object's type, then hands the same reference back typed as `TObjectRef`.
+ */
+template <class TObjectRef>
+inline TObjectRef IRBuilder::Name(String name, TObjectRef obj) {
+  // Dispatch to the type-specific naming rule first ...
+  details::Namer::Name(obj, name);
+  // ... then return the very same reference under the caller's static type.
+  TObjectRef named = Downcast<TObjectRef>(obj);
+  return named;
+}
+
+/*!
+ * \brief Implementation of `IRBuilderNode::FindFrame`: scan the frame stack from the
+ * innermost (back) frame outwards and return the first one whose node is a `TFrame` node.
+ */
+template <typename TFrame>
+inline Optional<TFrame> IRBuilderNode::FindFrame() const {
+  using TFrameNode = typename TFrame::ContainerType;
+  for (size_t i = frames.size(); i > 0; --i) {
+    const TFrameNode* node = frames[i - 1].template as<TFrameNode>();
+    if (node != nullptr) {
+      return GetRef<TFrame>(node);
+    }
+  }
+  return NullOpt;
+}
+
+/*!
+ * \brief Implementation of `IRBuilderNode::GetLastFrame`: only the innermost frame is
+ * inspected. An empty stack, or an innermost frame of a different type, yields NullOpt.
+ */
+template <typename TFrame>
+inline Optional<TFrame> IRBuilderNode::GetLastFrame() const {
+  using TFrameNode = typename TFrame::ContainerType;
+  if (frames.empty()) {
+    return NullOpt;
+  }
+  IRBuilderFrame last = frames.back();
+  if (!last->IsInstance<TFrameNode>()) {
+    return NullOpt;
+  }
+  return Downcast<TFrame>(last);
+}
+
+/*!
+ * \brief Implementation of `IRBuilderNode::Get`: return `result` as a `TObjectRef`.
+ * Fails with an "IndexError" if no result exists yet, and with a "TypeError" naming both the
+ * expected and the actual type if the result has a different type.
+ */
+template <typename TObjectRef>
+inline TObjectRef IRBuilderNode::Get() const {
+  using TObject = typename TObjectRef::ContainerType;
+  CHECK(result.defined()) << "IndexError: No result exists in IRBuilder yet";
+  const auto* n = result.as<TObject>();
+  // Report the actual type too, so a mismatch is diagnosable from the message alone.
+  // The streamed operands are evaluated only when the check fails (result is defined here).
+  CHECK(n != nullptr) << "TypeError: IRBuilder result is not of type: " << TObject::_type_key
+                      << ". Got: " << result.value()->GetTypeKey();
+  return GetRef<TObjectRef>(n);
+}
+
+} // namespace ir_builder
+} // namespace script
+} // namespace tvm
+
+#endif // TVM_SCRIPT_IR_BUILDER_BASE_H_
diff --git a/python/tvm/script/ir_builder/__init__.py b/python/tvm/script/ir_builder/__init__.py
new file mode 100644
index 0000000000..b325fadd86
--- /dev/null
+++ b/python/tvm/script/ir_builder/__init__.py
@@ -0,0 +1,18 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.script.ir_builder is a generic IR builder for TVM."""
+from .base import IRBuilder
diff --git a/python/tvm/script/ir_builder/_ffi_api.py b/python/tvm/script/ir_builder/_ffi_api.py
new file mode 100644
index 0000000000..68811c9e01
--- /dev/null
+++ b/python/tvm/script/ir_builder/_ffi_api.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI APIs for tvm.script.ir_builder"""
+import tvm._ffi
+
+tvm._ffi._init_api("script.ir_builder", __name__) # pylint: disable=protected-access
diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py
new file mode 100644
index 0000000000..767fa8bf25
--- /dev/null
+++ b/python/tvm/script/ir_builder/base.py
@@ -0,0 +1,181 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""A generic IRBuilder across the TVM stack"""
+from typing import Any, Callable, List
+
+from tvm._ffi import register_object as _register_object
+from tvm.runtime import Object as _Object
+
+from . import _ffi_api
+
+
+@_register_object("script.ir_builder.IRBuilderFrame")
+class IRBuilderFrame(_Object):
+ """A stack frame of the IRBuilder used to keep track of the current scope.
+ Furthermore, the information stored in each stack frame can be useful for context-dependent
+ IR construction.
+
+ Examples
+ --------
+
+ The `T.match_buffer` below instead an element in the buffer map of `PrimFuncFrame`:
+
+ .. code-block:: python
+
+ from tvm.script.ir_builder import tir as T
+ from tvm.script.ir_builder import IRBuilder
+
+ with IRBuilder() as builder:
+ with T.prim_func(...): # pushes a PrimFuncFrame (subclass of IRBuilderFrame)
+ # to `builder`'s stack of frames
+ buffer = T.match_buffer(...)
+
+
+ The `T.match_buffer` below instead generates `MatchBufferRegion` in a TIR block:
+
+ .. code-block:: python
+
+ from tvm.script.ir_builder import tir as T
+ from tvm.script.ir_builder import IRBuilder
+
+ with IRBuilder() as builder:
+ with T.prim_func(...): # pushes a PrimFuncFrame (subclass of IRBuilderFrame)
+ # to `builder`'s stack of frames
+ with T.block(...): # pushes a BlockFrame (subclass of IRBuilderFrame)
+ # to `builder`'s stack of frames
+ buffer = T.match_buffer(...)
+ """
+
+ def __enter__(self) -> "IRBuilderFrame":
+ _ffi_api.IRBuilderFrameEnter(self) # pylint: disable=no-member # type: ignore
+ return self
+
+ def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument
+ _ffi_api.IRBuilderFrameExit(self) # pylint: disable=no-member # type: ignore
+
+ def add_callback(self, callback: Callable[[], None]) -> None:
+ """Add a callback method invoked when exiting the with-scope.
+
+ Parameters
+ ----------
+ callback : Callable[[], None]
+ The callback method to be invoked.
+ """
+ _ffi_api.IRBuilderFrameAddCallback( # pylint: disable=no-member # type: ignore
+ self, callback
+ )
+
+
+@_register_object("script.ir_builder.IRBuilder")
+class IRBuilder(_Object):
+ """A dialect-agnostic IRBuilder that constructs any IR of TVM.
+
+ Examples
+ --------
+ An idiomatic use of this class is to put this inside the with-scope,
+ call dialect-specific methods accordingly. Upon exiting the scope.
+
+ .. code-block:: python
+ from tvm.script.ir_builder import tir as T
+ from tvm.script.ir_builder import IRBuilder
+
+ with IRBuilder() as builder:
+ with T.prim_func(...): # pushes a PrimFuncFrame (subclass of IRBuilderFrame)
+ # to `builder`'s stack of frames
+ buffer = T.match_buffer(...)
+
+ return builder.get() # returns the constructed IR, i.e. tir.PrimFunc
+ """
+
+ def __init__(self) -> None:
+ """Construct an IRBuilder."""
+ self.__init_handle_by_constructor__(
+ _ffi_api.IRBuilder # pylint: disable=no-member # type: ignore
+ )
+
+ def __enter__(self) -> "IRBuilder":
+ """Enter the with-scope for IRBuilder, which allows the IRBuilder to be discoverable
+ using `IRBuilder.current()`.
+
+ Examples
+ --------
+ .. code-block:: python
+ from tvm.script.ir_builder import IRBuilder
+
+ with IRBuilder() as builder:
+ assert IRBuilder.current() == builder
+ """
+ _ffi_api.IRBuilderEnter(self) # pylint: disable=no-member # type: ignore
+ return self
+
+ def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument
+ _ffi_api.IRBuilderExit(self) # pylint: disable=no-member # type: ignore
+
+ @staticmethod
+ def current() -> "IRBuilder":
+ """Get the current IRBuilder put in the with-scope.
+
+ Returns
+ -------
+ builder : IRBuilder
+ The current IRBuilder.
+ """
+ return _ffi_api.IRBuilderCurrent() # pylint: disable=no-member # type: ignore
+
+ def get(self) -> _Object:
+ """Get the constructed IR."""
+ return _ffi_api.IRBuilderGet(self) # pylint: disable=no-member # type: ignore
+
+ @staticmethod
+ def name(s: str, v: Any) -> Any:
+ """Set the name of an object.
+
+ Parameters
+ ----------
+ s : str
+ The name of the object.
+ v : Any
+ The object to name.
+
+ Returns
+ -------
+ v : Any
+ The same object with the name set.
+ """
+ return _ffi_api.IRBuilderName(s, v) # pylint: disable=no-member # type: ignore
+
+ @staticmethod
+ def name_many( # pylint: disable=invalid-name
+ s: List[str],
+ vs: List[Any],
+ ) -> List[Any]:
+ """Set the name of a list of objects.
+
+ Parameters
+ ----------
+ s : List[str]
+ The names of the objects.
+ vs : List[Any]
+ The objects to name.
+
+ Returns
+ -------
+ vs : List[Any]
+ The same objects with the names set.
+ """
+ assert len(s) == len(vs)
+ return [IRBuilder.name(i, v) for i, v in zip(s, vs)]
diff --git a/src/script/ir_builder/base.cc b/src/script/ir_builder/base.cc
new file mode 100644
index 0000000000..8303efff4f
--- /dev/null
+++ b/src/script/ir_builder/base.cc
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/ir/module.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/script/ir_builder/base.h>
+
+namespace tvm {
+namespace script {
+namespace ir_builder {
+
+/*! \brief Push this frame onto the frame stack of the IRBuilder currently in scope. */
+void IRBuilderFrameNode::EnterWithScope() {
+  IRBuilder builder = IRBuilder::Current();
+  IRBuilderFrame self = GetRef<IRBuilderFrame>(this);
+  builder->frames.push_back(self);
+}
+
+/*!
+ * \brief Leave this frame: run its callbacks (most recently added first), drop them, and pop
+ * the innermost frame of `IRBuilder::Current()`.
+ */
+void IRBuilderFrameNode::ExitWithScope() {
+  // Iterate by index and invoke a copy of each callback. A callback may call AddCallback(),
+  // which appends to the innermost frame of the current builder -- typically this very frame,
+  // since it is still on the stack here. That push_back can reallocate `callbacks`, which would
+  // invalidate reverse iterators and any reference to the element being invoked.
+  // Callbacks appended during this loop are not run; clear() below discards them.
+  for (size_t i = callbacks.size(); i > 0; --i) {
+    runtime::TypedPackedFunc<void()> callback = callbacks[i - 1];
+    callback();
+  }
+  this->callbacks.clear();
+  IRBuilder::Current()->frames.pop_back();
+}
+
+/*!
+ * \brief Append `callback` to the innermost frame of the current IRBuilder; fails if that
+ * builder has no frames.
+ */
+void IRBuilderFrameNode::AddCallback(runtime::TypedPackedFunc<void()> callback) {
+  IRBuilder builder = IRBuilder::Current();
+  if (builder->frames.empty()) {
+    LOG(FATAL) << "ValueError: No frames in Builder to add callback";
+  }
+  // NOTE(review): the callback goes to the innermost frame, which is not necessarily `this`
+  // frame -- confirm this is intended.
+  IRBuilderFrame innermost = builder->frames.back();
+  innermost->callbacks.push_back(callback);
+}
+
+/*! \brief Create a builder with an empty frame stack and no result. */
+IRBuilder::IRBuilder() {
+  auto node = make_object<IRBuilderNode>();
+  node->result = NullOpt;
+  node->frames.clear();
+  data_ = node;
+}
+
+/*!
+ * \brief The per-thread stack of IRBuilders entered via `IRBuilder::EnterWithScope`; the
+ * innermost builder is at the back.
+ *
+ * This helper is private to this translation unit and is not declared in any header, so it
+ * is given internal linkage to avoid exporting (and possibly clashing on) the symbol.
+ */
+static std::vector<IRBuilder>* ThreadLocalBuilderStack() {
+  thread_local std::vector<IRBuilder> stack;
+  return &stack;
+}
+
+/*!
+ * \brief Make this builder the current one on this thread. The builder must have no frames
+ * left over; its result is reset before it is pushed.
+ */
+void IRBuilder::EnterWithScope() {
+  IRBuilderNode* n = this->get();
+  CHECK(n->frames.empty()) << "ValueError: There are frame(s) left in the builder: "
+                           << n->frames.size()
+                           << ". Please use a fresh new builder every time building IRs";
+  n->result = NullOpt;
+  ThreadLocalBuilderStack()->push_back(*this);
+}
+
+/*!
+ * \brief Pop this builder from the thread-local stack. Scopes must be exited in the reverse
+ * order of entering them.
+ */
+void IRBuilder::ExitWithScope() {
+  std::vector<IRBuilder>* stack = ThreadLocalBuilderStack();
+  ICHECK(!stack->empty());
+  // Popping a different builder would silently leave the wrong IRBuilder current for the
+  // rest of this thread, so verify the scopes are properly nested.
+  ICHECK(stack->back().same_as(*this))
+      << "IRBuilder scopes must be exited in the reverse order of entering them";
+  stack->pop_back();
+}
+
+/*!
+ * \brief Return the innermost builder entered on this thread; fails with a "ValueError" if
+ * none is in scope.
+ */
+IRBuilder IRBuilder::Current() {
+  const std::vector<IRBuilder>* stack = ThreadLocalBuilderStack();
+  CHECK(!stack->empty()) << "ValueError: No builder in current scope";
+  const IRBuilder& innermost = stack->back();
+  return innermost;
+}
+
+namespace details {
+
+/*! \brief The single naming dispatch table, created on first use. */
+Namer::FType& Namer::vtable() {
+  static FType registry;
+  return registry;
+}
+
+/*!
+ * \brief Name `node` with `name` using the rule registered for its type.
+ * Fails with a "ValueError" if `node` is undefined or no rule is registered for its type.
+ */
+void Namer::Name(ObjectRef node, String name) {
+  static const FType& f = vtable();
+  CHECK(node.defined()) << "ValueError: Cannot name nullptr with: " << name;
+  // Close the quote opened around the type key so the message is well-formed.
+  CHECK(f.can_dispatch(node)) << "ValueError: Do not know how to name type \""
+                              << node->GetTypeKey() << "\"";
+  f(node, name);
+}
+
+} // namespace details
+
+// Reflection registration of the two node types.
+TVM_REGISTER_NODE_TYPE(IRBuilderFrameNode);
+TVM_REGISTER_NODE_TYPE(IRBuilderNode);
+// Frame entry points used by the Python IRBuilderFrame (__enter__/__exit__/add_callback).
+// These bind the node-level methods directly; unlike the C++ `IRBuilderFrame::ExitWithScope`
+// wrapper, the Python handle is not reset after exiting.
+TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderFrameEnter")
+    .set_body_method<IRBuilderFrame>(&IRBuilderFrameNode::EnterWithScope);
+TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderFrameExit")
+    .set_body_method<IRBuilderFrame>(&IRBuilderFrameNode::ExitWithScope);
+TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderFrameAddCallback")
+    .set_body_method<IRBuilderFrame>(&IRBuilderFrameNode::AddCallback);
+// Builder construction, scoping and lookup used by the Python IRBuilder.
+TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilder").set_body_typed([]() { return IRBuilder(); });
+TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderEnter").set_body_method(&IRBuilder::EnterWithScope);
+TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderExit").set_body_method(&IRBuilder::ExitWithScope);
+TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderCurrent").set_body_typed(IRBuilder::Current);
+// The result is returned as a plain ObjectRef; Python callers receive the concrete object.
+TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderGet")
+    .set_body_method<IRBuilder>(&IRBuilderNode::Get<ObjectRef>);
+// Argument order is (name, object), matching `IRBuilder.name(s, v)` on the Python side.
+TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderName").set_body_typed(IRBuilder::Name<ObjectRef>);
+
+} // namespace ir_builder
+} // namespace script
+} // namespace tvm
diff --git a/tests/python/unittest/test_tvmscript_ir_builder_base.py b/tests/python/unittest/test_tvmscript_ir_builder_base.py
new file mode 100644
index 0000000000..b41e8cdd92
--- /dev/null
+++ b/tests/python/unittest/test_tvmscript_ir_builder_base.py
@@ -0,0 +1,42 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unittests for tvm.script.ir_builder.base"""
+import pytest
+from tvm.script.ir_builder import IRBuilder
+
+
+def test_ir_builder_scope():
+ with IRBuilder() as ib: # pylint: disable=invalid-name
+ assert IRBuilder.current() == ib
+
+
+def test_ir_builder_multi_scope():
+ with IRBuilder() as ib: # pylint: disable=invalid-name
+ with IRBuilder() as ib2: # pylint: disable=invalid-name
+ assert IRBuilder.current() == ib2
+ assert IRBuilder.current() == ib
+
+
+def test_ir_builder_no_scope():
+ with pytest.raises(ValueError):
+ IRBuilder.current()
+
+
+if __name__ == "__main__":
+ test_ir_builder_scope()
+ test_ir_builder_multi_scope()
+ test_ir_builder_no_scope()