You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/03/14 21:26:34 UTC

[incubator-tvm] branch master updated: [TIR] Introduce tir::PrimFunc (#5070)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new e031641  [TIR] Introduce tir::PrimFunc (#5070)
e031641 is described below

commit e03164159ce08f2739a26c10531b26713e72153e
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Sat Mar 14 14:26:23 2020 -0700

    [TIR] Introduce tir::PrimFunc (#5070)
    
    This PR introduces tir::PrimFunc which will be used as the TIR function
    container in the unified IR.
    
    Also streamlined the function attributes a bit further.
    - All common attributes are under tvm::attr
    - TIR-specific attributes are under tvm::tir::attr and come with a tir prefix
    - Use stl_style for attributes for now
---
 include/tvm/ir/function.h                          |  99 ++++++++++++
 include/tvm/ir/type.h                              |   7 +
 include/tvm/relay/function.h                       |  27 ----
 include/tvm/tir/function.h                         | 177 +++++++++++++++++++++
 include/tvm/tir/op.h                               |  12 ++
 python/tvm/ir/__init__.py                          |   3 +-
 python/tvm/ir/expr.py                              |   9 --
 python/tvm/ir/{__init__.py => function.py}         |  24 ++-
 python/tvm/relay/expr.py                           |   3 +-
 python/tvm/tir/__init__.py                         |   2 +
 python/tvm/tir/function.py                         |  86 ++++++++++
 src/ir/function.cc                                 |   1 +
 src/printer/relay_text_printer.cc                  |  13 +-
 src/relay/ir/function.cc                           |  18 +--
 src/tir/ir/function.cc                             |  91 +++++++++++
 src/tir/ir/op.cc                                   |  12 ++
 .../{test_lang_basic.py => test_tir_nodes.py}      |  24 ++-
 17 files changed, 536 insertions(+), 72 deletions(-)

diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h
index 4cb5d70..db7f446 100644
--- a/include/tvm/ir/function.h
+++ b/include/tvm/ir/function.h
@@ -33,6 +33,36 @@
 namespace tvm {
 
 /*!
+ * \brief Possible calling conventions.
+ *
+ *  NOTE: The calling convention also implies
+ *  the way we implement the function during lowering.
+ */
+enum class CallingConv : int {
+  /*!
+   * \brief Default calling convention.
+   *
+   * - Uses the native calling convention of the target.
+   * - Implementation: specified by the native target.
+   */
+  kDefault = 0,
+  /*!
+   * \brief Device kernel launch
+   *
+   * - Call by PackedFunc calling convention.
+   * - Implementation: defined by device runtime (e.g. runtime/cuda)
+   */
+  kDeviceKernelLaunch = 2,
+  /*!
+   * \brief PackedFunc that exposes a CPackedFunc signature.
+   *
+   * - Call by PackedFunc calling convention.
+   * - Implementation: Expose a function with the CPackedFunc signature.
+   */
+  kCPackedFunc = 3,
+};
+
+/*!
  * \brief Base node of all functions.
  *
  * We support several variants of functions throughout the stack.
@@ -115,5 +145,74 @@ class BaseFunc : public RelayExpr {
   TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
 };
 
+/*!
+ * \brief Create a new function that copies func, but overrides
+ *        the value of attribute attr_key with attr_value.
+ *
+ * \param func The input function.
+ * \param attr_key The attribute key.
+ * \param attr_value The attribute value.
+ *
+ * \tparam TFunc The function reference type; its node type must be final.
+ *
+ * \returns The new function with updated attributes.
+ *
+ * \note This function performs copy on write optimization for func.
+ *       If we move a uniquely referenced func into WithAttr,
+ *       then no additional copy will be performed.
+ *
+ *       This is also why it is a free function instead of a member function
+ *       and why we pass by value in the first argument.
+ *
+ * \code
+ *
+ *  // Recommended way to trigger copy on write
+ *  func = WithAttr(std::move(func), "key1", value1);
+ *  func = WithAttr(std::move(func), "key2", value2);
+ *
+ * \endcode
+ */
+template<typename TFunc,
+         typename = typename std::enable_if<
+           std::is_base_of<BaseFunc, TFunc>::value>::type>
+inline TFunc WithAttr(TFunc func,
+                      const std::string& attr_key,
+                      ObjectRef attr_value) {
+  using TNode = typename TFunc::ContainerType;
+  static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
+  TNode* node = func.CopyOnWrite();
+  if (node->attrs.defined()) {
+    node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
+  } else {
+    Map<std::string, ObjectRef> dict = {{attr_key, attr_value}};
+    node->attrs = DictAttrs(dict);
+  }
+  return func;
+}
+
+/*!
+ * \brief Generic attribute names that can be attached to any function.
+ *
+ * \sa tvm::tir::attr, tvm::relay::attr
+ */
+namespace attr {
+/*!
+ * \brief Indicates the special calling convention.
+ *
+ * Type: Integer
+ *
+ * \sa tvm::CallingConv
+ */
+constexpr const char* kCallingConv = "calling_conv";
+
+/*!
+ * \brief Compilation target of the function.
+ *
+ * Type: Target
+ *
+ * \sa tvm::Target
+ */
+constexpr const char* kTarget = "target";
+}  // namespace attr
 }  // namespace tvm
 #endif  // TVM_IR_FUNCTION_H_
diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h
index 9e87731..7fd224b 100644
--- a/include/tvm/ir/type.h
+++ b/include/tvm/ir/type.h
@@ -277,6 +277,13 @@ class TupleType : public Type {
 };
 
 /*!
+ * \return A type representing void, encoded as the empty TupleType.
+ */
+inline Type VoidType() {
+  return TupleType::Empty();
+}
+
+/*!
  * \brief Potential Constraints in a function.
  * \sa TypeConstraint
  */
diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h
index 27aa2e8..f7514c7 100644
--- a/include/tvm/relay/function.h
+++ b/include/tvm/relay/function.h
@@ -115,33 +115,6 @@ class Function : public BaseFunc {
 };
 
 /*!
- * \brief Create a new function that copies func, but overrides
- *        the attribute value key with the value.
- *
- * \param func The input function.
- * \param attr_key The attribute key.
- * \param attr_value The value attribute value.
- *
- * \returns The new function with updated attributes.
- *
- * \note This function performs copy on write optimization for func.
- *       If we move a uniquely referenced func into WithAttr,
- *       then no additional copy will be performed.
- *
- *       This is also why we make it as a function instead of a member function
- *       and why we pass by value in the first argument.
- *
- * \code
- *
- *  // Recommended way to trigger copy on write
- *  func = WithAttr(std::move(func), "key1", value1);
- *  func = WithAttr(std::move(func), "key2", value2);
- *
- * \endcode
- */
-TVM_DLL Function WithAttr(Function func, const std::string& attr_key, ObjectRef attr_value);
-
-/*!
  * \brief namespace of the attributes that can be attached to a relay::Function.
  */
 namespace attr {
diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
new file mode 100644
index 0000000..0680267
--- /dev/null
+++ b/include/tvm/tir/function.h
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/tir/function.h
+ * \brief TIR Function.
+ */
+#ifndef TVM_TIR_FUNCTION_H_
+#define TVM_TIR_FUNCTION_H_
+
+#include <tvm/ir/function.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/buffer.h>
+#include <tvm/tir/stmt.h>
+#include <string>
+
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Primitive functions that contain TIR statements.
+ *
+ * The PrimFunc provides a low-level code representation that does not
+ * automatically manage (TODO(review): this sentence is unfinished).
+ *
+ * \sa PrimFunc
+ */
+class PrimFuncNode : public BaseFuncNode {
+ public:
+  /*! \brief Function parameters */
+  Array<tir::Var> params;
+  /*! \brief The body of the function */
+  tir::Stmt body;
+  /*! \brief The return type of the function. */
+  Type ret_type;
+  /*!
+   * \brief Maps some parameters to specific Buffer data structures.
+   *
+   *  buffer_map provides a way to express data structure's field and shape
+   *  constraints. The provided information is used in the program analysis
+   *  and the code generation.
+   *
+   *  - It defines the vars in the Buffer (m, n) in the cases below when
+   *    they appear in the buffer_map for the first time.
+   *  - When a var appears multiple times, later occurrences translate into
+   *    runtime assertions that check the field constraint.
+   *
+   *  \code
+   *
+   *   # The corresponding fields of f are as follows
+   *   #
+   *   # - f.params = [a, b]
+   *   # - f.buffer_map = {a: A, b: B}
+   *   # - A = decl_buffer(shape=[m, n])
+   *   # - B = decl_buffer(shape=[m, n])
+   *
+   *   def f(a, b):
+   *       m, n = var(), var()
+   *       A = bind_buffer(a, shape=[m, n])
+   *       B = bind_buffer(b, shape=[m, n])
+   *       # body
+   *
+   *  \endcode
+   *
+   *  buffer_map is a sugar to express:
+   *  - Parameter unpacking: e.g. one can load a.shape[0] to get the value of m
+   *  - Constraint checking: a.shape[0] must equal b.shape[0] because they
+   *    both correspond to m.
+   *
+   *  While we could express parameter unpacking and constraints using
+   *  normal statements, making buffer_map a first-class citizen of PrimFunc
+   *  will make program analysis much easier.
+   *
+   * \note This field can be nullptr
+   */
+  Map<tir::Var, Buffer> buffer_map;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("params", &params);
+    v->Visit("body", &body);
+    v->Visit("ret_type", &ret_type);
+    v->Visit("buffer_map", &buffer_map);
+    v->Visit("attrs", &attrs);
+    v->Visit("span", &span);
+    v->Visit("_checked_type_", &checked_type_);
+  }
+
+  /*!
+   * \brief Return the derived function annotation of this function.
+   *
+   * \return The function type annotation.
+   * \note The function type annotation of PrimFunc is derived directly
+   *       from the parameter Vars, without the need for type inference.
+   */
+  TVM_DLL FuncType func_type_annotation() const;
+
+  static constexpr const char* _type_key = "tir.PrimFunc";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncNode, BaseFuncNode);
+};
+
+/*!
+ * \brief Managed reference to PrimFuncNode.
+ * \sa PrimFuncNode
+ */
+class PrimFunc : public BaseFunc {
+ public:
+  /*!
+   * \brief Constructor
+   * \param params The parameters of the function.
+   * \param body The body of the function.
+   * \param ret_type The return type of the function.
+   * \param buffer_map The buffer map for parameter buffer unpacking.
+   * \param attrs Additional function attributes.
+   */
+  TVM_DLL PrimFunc(Array<tir::Var> params,
+                   Stmt body,
+                   Type ret_type = VoidType(),
+                   Map<tir::Var, Buffer> buffer_map = NullValue<Map<tir::Var, Buffer>>(),
+                   DictAttrs attrs = NullValue<DictAttrs>());
+
+  TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode);
+};
+
+/*!
+ * \brief PrimFunc specific attribute names.
+ *
+ * \sa tvm::attr
+ */
+namespace attr {
+/*!
+ * \brief List of thread IterVar that a DeviceLaunch function corresponds to.
+ *
+ * Type: Array<tir::IterVar>
+ *
+ * We call a device kernel launch function f using the following convention:
+ *
+ * Call(f,
+ *      [arg1, arg2, ..., arg_n,
+ *       work_size_1, work_size_2, ... work_size_m])
+ *
+ * Here n = len(arg), m = len(work_size) = len(device_thread_axis).
+ *
+ * The list of device_thread_axis indicates how to bind the
+ * work_size arguments to the corresponding threads.
+ *
+ * \sa tvm::CallingConv::kDeviceKernelLaunch
+ */
+constexpr const char* kDeviceThreadAxis = "tir.device_thread_axis";
+
+/*!
+ * \brief Whether to set noalias rule on the function arguments.
+ *
+ * Type: Integer
+ */
+constexpr const char* kNoAlias = "tir.noalias";
+}  // namespace attr
+}  // namespace tir
+}  // namespace tvm
+#endif  // TVM_TIR_FUNCTION_H_
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index a30c3c9..6ee5063 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -28,6 +28,7 @@
 #ifndef TVM_TIR_OP_H_
 #define TVM_TIR_OP_H_
 
+#include <tvm/ir/type.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt.h>
 
@@ -37,6 +38,7 @@
 
 
 namespace tvm {
+
 // Most common operators can be overloaded by argument type(PrimExpr).
 // So we put them under the root namespace.
 // It is also necessary to overload operators for PrimExpr.
@@ -45,6 +47,16 @@ namespace tvm {
 // as they are more specific to the tir namespace.
 
 /*!
+ * \brief Get the type of the expression under the unified type system.
+ *
+ * This function may return a more refined type than
+ * the runtime type provided by expr->dtype.
+ *
+ * \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType.
+ */
+TVM_DLL Type GetType(const PrimExpr& expr);
+
+/*!
  * Query the maximum possible value of dtype.
  * \param dtype The data type.
  * \return the maximum possible value in this format.
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py
index a718124..4160326 100644
--- a/python/tvm/ir/__init__.py
+++ b/python/tvm/ir/__init__.py
@@ -21,7 +21,8 @@ from .type import Type, TypeKind, TypeVar, GlobalTypeVar, TupleType
 from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
 from .tensor_type import TensorType
 from .type_relation import TypeCall, TypeRelation
-from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, BaseFunc, Range
+from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range
+from .function import BaseFunc
 from .adt import Constructor, TypeData
 from .module import IRModule
 from .attrs import Attrs, DictAttrs, make_node
diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py
index 00ceb5b..4e6bf16 100644
--- a/python/tvm/ir/expr.py
+++ b/python/tvm/ir/expr.py
@@ -51,15 +51,6 @@ class RelayExpr(BaseExpr):
         return ret
 
 
-class BaseFunc(RelayExpr):
-    """Base class of all functions."""
-    @property
-    def attrs(self):
-        """Return the attrs member of the function.
-        """
-        return _ffi_api.BaseFunc_Attrs(self)
-
-
 @tvm._ffi.register_object("relay.GlobalVar")
 class GlobalVar(RelayExpr):
     """A global variable in the IR.
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/function.py
similarity index 55%
copy from python/tvm/ir/__init__.py
copy to python/tvm/ir/function.py
index a718124..70eb51a 100644
--- a/python/tvm/ir/__init__.py
+++ b/python/tvm/ir/function.py
@@ -14,17 +14,15 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=unused-import
-"""Common data structures across all IR variants."""
-from .base import SourceName, Span, Node, EnvFunc, load_json, save_json
-from .type import Type, TypeKind, TypeVar, GlobalTypeVar, TupleType
-from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
-from .tensor_type import TensorType
-from .type_relation import TypeCall, TypeRelation
-from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, BaseFunc, Range
-from .adt import Constructor, TypeData
-from .module import IRModule
-from .attrs import Attrs, DictAttrs, make_node
-from .container import Array, Map
+"""Function definitions."""
+from .expr import RelayExpr
+from . import _ffi_api
 
-from . import transform
+
+class BaseFunc(RelayExpr):
+    """Base class of all functions."""
+    @property
+    def attrs(self):
+        """Return the attrs member of the function (via ir.BaseFunc_Attrs).
+        """
+        return _ffi_api.BaseFunc_Attrs(self)
diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
index a3c6251..61a5fb7 100644
--- a/python/tvm/relay/expr.py
+++ b/python/tvm/relay/expr.py
@@ -282,7 +282,8 @@ class Function(BaseFunc):
         func : Function
             A new copy of the function
         """
-        return _expr.FunctionWithAttr(self, attr_key, attr_value)
+        return _expr.FunctionWithAttr(
+            self, attr_key, convert(attr_value))
 
 
 
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index fa244ac..b8a56f8 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -31,6 +31,8 @@ from .stmt import Stmt, LetStmt, AssertStmt, ProducerConsumer, For
 from .stmt import Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt
 from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_list
 
+from .function import PrimFunc
+
 from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern
 from .op import call_llvm_intrin, all, any, min_value, max_value, trace
 from .op import exp, exp2, exp10, log, log2, log10
diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
new file mode 100644
index 0000000..37946f6
--- /dev/null
+++ b/python/tvm/tir/function.py
@@ -0,0 +1,86 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Function data types."""
+
+import tvm._ffi
+import tvm.runtime
+from tvm.ir import BaseFunc
+from .buffer import Buffer
+from .expr import Var
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("tir.PrimFunc")
+class PrimFunc(BaseFunc):
+    """A function declaration expression.
+
+    Parameters
+    ----------
+    params: List[Union[tvm.tir.Var, tvm.tir.Buffer]]
+        Input parameters; each Buffer becomes a new handle Var bound in buffer_map.
+
+    body: tvm.tir.Stmt
+        The body of the function.
+
+    ret_type: Optional[tvm.ir.Type]
+        The return type annotation of the function; void when None.
+
+    buffer_map : Optional[Map[tvm.tir.Var, tvm.tir.Buffer]]
+        The buffer binding map. It is copied; the caller's map is never mutated.
+
+    attrs: Optional[tvm.Attrs]
+        Attributes of the function, can be None
+    """
+    def __init__(self,
+                 params,
+                 body,
+                 ret_type=None,
+                 buffer_map=None,
+                 attrs=None):
+        param_list = []
+        buffer_map = {} if buffer_map is None else dict(buffer_map.items())
+        for x in params:
+            if isinstance(x, Buffer):
+                var = Var(x.name, dtype="handle")
+                param_list.append(var)
+                buffer_map[var] = x
+            elif isinstance(x, Var):
+                param_list.append(x)
+            else:
+                raise TypeError("params can only contain Var or Buffer")
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs)
+
+    def with_attr(self, attr_key, attr_value):
+        """Create a new copy of the function with the attribute updated.
+
+        Parameters
+        ----------
+        attr_key : str
+            The attribute key to use.
+
+        attr_value : Any
+            The new attribute value; converted via tvm.runtime.convert.
+
+        Returns
+        -------
+        func : PrimFunc
+            A new copy of the function
+        """
+        return _ffi_api.PrimFuncWithAttr(
+            self, attr_key, tvm.runtime.convert(attr_value))
diff --git a/src/ir/function.cc b/src/ir/function.cc
index d3753d8..e7ccbbe 100644
--- a/src/ir/function.cc
+++ b/src/ir/function.cc
@@ -30,4 +30,5 @@ TVM_REGISTER_GLOBAL("ir.BaseFunc_Attrs")
 .set_body_typed([](BaseFunc func) {
   return func->attrs;
 });
+
 }  // namespace tvm
diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc
index 2799be0..56e77b7 100644
--- a/src/printer/relay_text_printer.cc
+++ b/src/printer/relay_text_printer.cc
@@ -99,7 +99,11 @@ class RelayTextPrinter :
   }
 
   Doc PrintFinal(const ObjectRef& node) {
-    if (node.as<ExprNode>()) {
+    if (node.as<BaseFuncNode>() != nullptr &&
+        node.as<relay::FunctionNode>() == nullptr) {
+      // Temporarily skip non-relay functions; as<> avoids deref of a null ref.
+      // TODO(tvm-team) enhance the code to work for all functions
+    } else if (node.as<ExprNode>()) {
       Expr expr = Downcast<Expr>(node);
       dg_ = DependencyGraph::Create(&arena_, expr);
     }
@@ -122,7 +126,10 @@ class RelayTextPrinter :
   std::vector<Doc> PrintFuncAttrs(const Attrs& attrs);
 
   Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false) {
-    if (node.as<ExprNode>()) {
+    bool is_non_relay_func =
+        node.as<BaseFuncNode>() != nullptr &&
+        node.as<relay::FunctionNode>() == nullptr;
+    if (node.as<ExprNode>() && !is_non_relay_func) {
       return PrintExpr(Downcast<Expr>(node), meta, try_inline);
     } else if (node.as<TypeNode>()) {
       return PrintType(Downcast<Type>(node), meta);
@@ -134,7 +141,7 @@ class RelayTextPrinter :
       // default module.
       std::ostringstream os;
       os << node;
-      return Doc() << os.str();
+      return Doc::RawText(os.str());
     }
   }
 
diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc
index c1bd710..63ad4dd 100644
--- a/src/relay/ir/function.cc
+++ b/src/relay/ir/function.cc
@@ -60,18 +60,6 @@ bool FunctionNode::UseDefaultCompiler() const {
   return !val.defined() || val->value == "default";
 }
 
-Function WithAttr(Function func, const std::string& attr_key, ObjectRef attr_value) {
-  FunctionNode* node = func.CopyOnWrite();
-  if (node->attrs.defined()) {
-    node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
-  } else {
-    Map<std::string, ObjectRef> dict = {{attr_key, attr_value}};
-    node->attrs = DictAttrs(dict);
-  }
-  return func;
-}
-
-
 TVM_REGISTER_NODE_TYPE(FunctionNode);
 
 TVM_REGISTER_GLOBAL("relay._make.Function")
@@ -94,9 +82,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 
 TVM_REGISTER_GLOBAL("relay._expr.FunctionWithAttr")
 .set_body_typed(
-  [](Function func, std::string name, ObjectRef ref) {
-    return WithAttr(std::move(func), name, ref);
-});
+    [](Function func, std::string name, ObjectRef ref) {
+      return WithAttr(std::move(func), name, ref);
+    });
 
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc
new file mode 100644
index 0000000..7464e3a
--- /dev/null
+++ b/src/tir/ir/function.cc
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/function.cc
+ * \brief The function data structure.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/function.h>
+
+namespace tvm {
+namespace tir {
+
+PrimFunc::PrimFunc(Array<tir::Var> params,
+                   Stmt body,
+                   Type ret_type,
+                   Map<tir::Var, Buffer> buffer_map,
+                   DictAttrs attrs) {
+  // Default to a void return type when none is given (e.g. None from Python).
+  // TODO(tvm-team) consider type deduction from body.
+  if (!ret_type.defined()) {
+    ret_type = VoidType();
+  }
+  auto n = make_object<PrimFuncNode>();
+  n->params = std::move(params);
+  n->body = std::move(body);
+  n->ret_type = std::move(ret_type);
+  n->buffer_map = std::move(buffer_map);
+  n->attrs = std::move(attrs);
+  data_ = std::move(n);
+}
+
+FuncType PrimFuncNode::func_type_annotation() const {
+  Array<Type> param_types;
+  for (const auto& param : this->params) {
+    param_types.push_back(GetType(param));
+  }
+  return FuncType(param_types, ret_type, {}, {});
+}
+
+TVM_REGISTER_NODE_TYPE(PrimFuncNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+.set_dispatch<PrimFuncNode>([](const ObjectRef& ref, ReprPrinter* p) {
+  // TODO(tvm-team) redirect to Text printer once we have a good text format.
+  auto* node = static_cast<const PrimFuncNode*>(ref.get());
+  p->stream << "PrimFunc(" << node->params << ")";
+  if (node->attrs.defined()) {
+    p->stream << " attrs=" << node->attrs;
+  }
+  p->stream << " {\n";
+  p->indent += 2;
+  p->Print(node->body);
+  p->indent -= 2;
+  p->stream << "}\n";
+});
+
+
+TVM_REGISTER_GLOBAL("tir.PrimFunc")
+.set_body_typed([](Array<tir::Var> params,
+                   Stmt body,
+                   Type ret_type,
+                   Map<tir::Var, Buffer> buffer_map,
+                   DictAttrs attrs) {
+  return PrimFunc(params, body, ret_type, buffer_map, attrs);
+});
+
+
+TVM_REGISTER_GLOBAL("tir.PrimFuncWithAttr")
+.set_body_typed([](PrimFunc func, std::string name, ObjectRef ref) {
+  return WithAttr(std::move(func), name, ref);
+});
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/ir/op.cc b/src/tir/ir/op.cc
index 2882fea..b073643 100644
--- a/src/tir/ir/op.cc
+++ b/src/tir/ir/op.cc
@@ -32,6 +32,18 @@ namespace tvm {
 
 using namespace tir;
 
+
+Type GetType(const PrimExpr& expr) {
+  runtime::DataType dtype = expr.dtype();
+  // int/uint/float dtypes fully determine the type, so PrimType is exact here.
+  if (dtype.is_int() || dtype.is_uint() || dtype.is_float()) {
+    return PrimType(dtype);
+  }
+  // TODO(tqchen): all other dtypes fall back to PrimType for now; add recursive
+  // type inference for Var and Call once the IR has the corresponding fields.
+  return PrimType(dtype);
+}
+
 // simple cast that only checks if type matches and cast
 inline PrimExpr SimpleCast(const DataType& t, PrimExpr value) {
   if (value.dtype() == t) return value;
diff --git a/tests/python/unittest/test_lang_basic.py b/tests/python/unittest/test_tir_nodes.py
similarity index 92%
rename from tests/python/unittest/test_lang_basic.py
rename to tests/python/unittest/test_tir_nodes.py
index c279194..3a7985d 100644
--- a/tests/python/unittest/test_lang_basic.py
+++ b/tests/python/unittest/test_tir_nodes.py
@@ -19,6 +19,7 @@ from tvm import te
 import numpy as np
 
 
+
 def test_const():
     x = tvm.tir.const(1, "int32")
     print(x.dtype)
@@ -46,8 +47,8 @@ def test_make():
     x = tvm.tir.const(1, "int32")
     y = te.var("x")
     z = x + y
-    assert isinstance(tvm.te.max(x, y), tvm.tir.Max)
-    assert isinstance(tvm.te.min(x, y), tvm.tir.Min)
+    assert isinstance(tvm.tir.max(x, y), tvm.tir.Max)
+    assert isinstance(tvm.tir.min(x, y), tvm.tir.Min)
 
 
 def test_ir():
@@ -111,7 +112,6 @@ def test_stmt():
                  tvm.tir.For.Serial, 0,
                  x)
 
-
 def test_dir():
     x = te.var('x')
     dir(x)
@@ -247,8 +247,26 @@ def test_equality_string_imm():
     x == y.value
     x == y
 
+def test_prim_func():
+    x = te.var('x')
+    y = te.var('y')
+    b = tvm.tir.decl_buffer((x,), "float32")
+    stmt = tvm.tir.LetStmt(
+        x, 10, tvm.tir.Evaluate(x + 1))
+
+    func = tvm.tir.PrimFunc(
+        [x, y, b], stmt)
+
+    assert func.buffer_map[func.params[2]].same_as(b)
+
+    assert len(func.buffer_map) == 1
+    f2 = func.with_attr("calling_conv", 1)
+    assert f2.attrs["calling_conv"].value == 1
+    assert func.attrs is None
+
 
 if __name__ == "__main__":
+    test_prim_func()
     test_cast()
     test_attr()
     test_const()