Posted to commits@tvm.apache.org by ju...@apache.org on 2021/01/07 00:59:33 UTC

[tvm] branch main updated: [TIR][REFACTOR] Enforce allocate to use the correct var pointer hint. (#7216)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d777e7c  [TIR][REFACTOR] Enforce allocate to use the correct var pointer hint. (#7216)
d777e7c is described below

commit d777e7c612cf7a9aae4d8433c36f031c6b6f985c
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Wed Jan 6 19:59:12 2021 -0500

    [TIR][REFACTOR] Enforce allocate to use the correct var pointer hint. (#7216)
    
    * [TIR][REFACTOR] Enforce allocate to only accept buffer_var with correct PtrType.
    
    This is a refactoring step to clean up a legacy issue: opaque buffer
    vars carried no pointer type information. Now every allocation comes with
    the right pointer data type. Places touched:
    
    - TVMScript parser: pass the parsed arguments into enter_scope so the
      buffer var gets the correct pointer type.
    - Cross-thread allreduce: give the reduction temporaries the right
      pointer type.
    - Storage rewrite: set up the right pointer type on rewritten allocations.
    - Custom datatypes: remap the variables to new vars with the correct
      pointer type.
    
    
    * Address comments
    
    Co-authored-by: Tristan Konolige <tr...@gmail.com>
---
 include/tvm/tir/op.h                          |   2 +-
 python/tvm/script/parser.py                   |  25 +++--
 python/tvm/script/scope_handler.py            |  13 ++-
 python/tvm/tir/buffer.py                      |   5 +-
 src/driver/driver_api.cc                      |   3 +-
 src/target/source/codegen_cuda.cc             |   6 +-
 src/te/operation/cross_thread_reduction.cc    |   6 +-
 src/tir/ir/buffer.cc                          |  14 ++-
 src/tir/ir/stmt.cc                            |   9 +-
 src/tir/ir/stmt_functor.cc                    |  14 ++-
 src/tir/transforms/lower_custom_datatypes.cc  | 147 ++++++++++++++++++--------
 src/tir/transforms/lower_thread_allreduce.cc  |  16 +--
 src/tir/transforms/storage_rewrite.cc         |  34 +++---
 tests/cpp/ir_functor_test.cc                  |  10 +-
 tests/python/unittest/test_tir_constructor.py |   1 +
 15 files changed, 209 insertions(+), 96 deletions(-)
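
The core invariant this patch enforces: the buffer var of an Allocate (and
the data field of a Buffer) must be annotated with PointerType(PrimType(dtype))
instead of an opaque handle. A minimal sketch of the two conventions, using
the Python API exercised by the updated unit test at the bottom of the patch:

    import tvm

    # New convention: the buffer var carries a typed pointer annotation.
    buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32")))

    # Old convention, now rejected by the Allocate constructor:
    # buffer_var = tvm.te.var("buf", "handle")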

diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 61481d9..4a907fc 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -1241,7 +1241,7 @@ inline void DivAmbiguityError(const TA& a) {
                 "please call div, indexdiv/indexmod, "
                 "floordiv/floormod or truncdiv/truncmod directly "
                 "to avoid ambiguity in the code. "
-                "Checkout these functions in expr_operator.h.");
+                "Checkout these functions in tir/op.h.");
 }
 
 // The following code are not intended to be used in the codebase.
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
index db976d0..33b0bab 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser.py
@@ -230,6 +230,19 @@ class TVMScriptParser(Transformer):
         """Match the arguments of a function call in the AST to the required
         arguments of the function. This handles positional arguments,
         positional arguments specified by name, keyword arguments, and varargs.
+
+        Parameters
+        ----------
+        func : Function
+            The function that provides the signature
+
+        node_call : ast.Call
+            The AST call node that calls into the function.
+
+        Returns
+        -------
+        arg_list : list
+            The parsed positional arguments.
         """
         assert isinstance(node_call, ast.Call)
         # collect arguments
@@ -435,8 +448,8 @@ class TVMScriptParser(Transformer):
                         node.rhs.span,
                     )
                 # Pattern 4
-                func.enter_scope(node, self.context)
                 arg_list = self.parse_arg_list(func, node.rhs)
+                func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
                 func.body = self.parse_body(node)
                 return func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span)
             elif isinstance(func, SpecialStmt):
@@ -532,9 +545,9 @@ class TVMScriptParser(Transformer):
         self.current_col_offset = node.span.start_column
         self.context.new_scope(nodes=node.body.stmts)
         # for scope handler process the scope
-        func.enter_scope(node, self.context)
-        func.body = self.parse_body(node)
         arg_list = self.parse_arg_list(func, node.rhs)
+        func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
+        func.body = self.parse_body(node)
         res = func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span)
         # exit the scope
         self.context.pop_scope()
@@ -571,9 +584,9 @@ class TVMScriptParser(Transformer):
         self.current_col_offset = node.body.span.start_column
         self.context.new_scope(nodes=node.body.stmts)
         # with scope handler process the scope
-        func.enter_scope(node, self.context)
-        func.body = self.parse_body(node)
         arg_list = self.parse_arg_list(func, node.rhs)
+        func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
+        func.body = self.parse_body(node)
         res = func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span)
         # exit the scope
         self.context.pop_scope()
@@ -689,7 +702,7 @@ class TVMScriptParser(Transformer):
         if isinstance(func, Intrin) and func.stmt:
             return func.handle(arg_list, node.call.func_name.span)
         elif isinstance(func, WithScopeHandler) and func.concise_scope and not func.def_symbol:
-            func.enter_scope(node, self.context)
+            func.enter_scope(node, self.context, arg_list, node.call.func_name.span)
             func.body = self.parse_body(node)
             return func.exit_scope(node, self.context, arg_list, node.call.func_name.span)
         elif isinstance(func, SpecialStmt) and not func.def_symbol:
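
The reordering in the hunks above is load-bearing: parse_arg_list now runs
before enter_scope, so handlers that define symbols on entry (notably
allocate) already know the parsed arguments, including the dtype needed for
the pointer annotation. Paraphrasing the new flow with the parser's own
calls (an ordering sketch, not a standalone snippet):

    arg_list = self.parse_arg_list(func, node.rhs)       # 1. parse arguments
    func.enter_scope(node, self.context, arg_list,       # 2. define symbols,
                     node.rhs.func_name.span)            #    args now in hand
    func.body = self.parse_body(node)                    # 3. parse the body
    res = func.exit_scope(node, self.context, arg_list,  # 4. build the result
                          node.rhs.func_name.span)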
diff --git a/python/tvm/script/scope_handler.py b/python/tvm/script/scope_handler.py
index 7f252e3..21ed7f6 100644
--- a/python/tvm/script/scope_handler.py
+++ b/python/tvm/script/scope_handler.py
@@ -35,7 +35,7 @@ class ScopeHandler:
     def signature(self):
         return "tir." + self.func.__name__, get_param_list(self.func)
 
-    def enter_scope(self, node, context):
+    def enter_scope(self, node, context, arg_list, span):
         pass
 
     def exit_scope(self, node, context, arg_list, span):
@@ -86,7 +86,7 @@ class Allocate(WithScopeHandler):
         super().__init__(allocate, concise_scope=True, def_symbol=True)
         self.buffer_var = None
 
-    def enter_scope(self, node, context):
+    def enter_scope(self, node, context, arg_list, span):
         # define buffer vars in symbol table
         if isinstance(node, ast.With):
             names = WithScopeHandler.get_optional_var_names(node, context)
@@ -98,7 +98,12 @@ class Allocate(WithScopeHandler):
         else:
             raise Exception("Internal Bug")
 
-        self.buffer_var = tvm.te.var(name, "handle", span=from_synr_span(node.lhs.id.span))
+        def setup_buffer_var(extents, dtype, scope, condition=True, span=None):
+            """Setup buffer var for a given type."""
+            buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype))
+            self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span)
+
+        setup_buffer_var(*arg_list, span=from_synr_span(node.lhs.id.span))
         context.update_symbol(name, self.buffer_var)
 
 
@@ -187,7 +192,7 @@ class ForScopeHandler(ScopeHandler):
         super().__init__(func)
         self.loop_vars = None
 
-    def enter_scope(self, node, context):
+    def enter_scope(self, node, context, arg_list, span):
         assert isinstance(node, ast.For)
 
         loop_var_names = list()
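
With the arguments available on entry, the allocate handler can define a
correctly typed var instead of a bare handle. A minimal sketch of what
setup_buffer_var constructs (the name "buf" is illustrative):

    import tvm

    dtype = "float32"  # taken from the allocate call's arguments
    buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype))
    buffer_var = tvm.tir.Var("buf", buffer_ptr_type)
    assert buffer_var.type_annotation.element_type.dtype == dtype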
diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py
index 2f50aa8..95966a5 100644
--- a/python/tvm/tir/buffer.py
+++ b/python/tvm/tir/buffer.py
@@ -247,7 +247,10 @@ def decl_buffer(
         shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32"
         elem_offset = Var("%s_elem_offset" % name, shape_dtype)
     if data is None:
-        data = Var(name, PointerType(PrimType(dtype)), span)
+        # Bool is represented as uint1 in the IR, but stored as int8
+        storage_type = PrimType(dtype)
+        storage_type = PrimType("int8") if storage_type.dtype == "bool" else storage_type
+        data = Var(name, PointerType(storage_type), span)
     return _ffi_api.Buffer(
         data,
         dtype,
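
One subtlety worth calling out: bool is represented as uint1 in the IR but
stored as one byte, so the backing pointer is declared as int8 while the
buffer keeps its logical dtype. A small sketch:

    import tvm

    buf = tvm.tir.decl_buffer((10,), "bool", name="flags")
    assert buf.dtype == "bool"        # logical element type is unchanged
    print(buf.data.type_annotation)   # the storage pointer is to int8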
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index f88b621..bbbb7e3 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -69,7 +69,8 @@ Target DefaultTargetHost(Target target) {
 
 tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype, std::string name,
                                       int data_alignment, int offset_factor, bool compact) {
-  auto data = tir::Var(name, PointerType(PrimType(dtype)));
+  DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype);
+  auto data = tir::Var(name, PointerType(PrimType(storage_dtype)));
   bool has_any = false;
   if (!compact) {
     for (const auto& it : shape) {
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index c0fb39f..6c73716 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -581,7 +581,11 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
   int32_t constant_size = op->constant_allocation_size();
   ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now";
   const VarNode* buffer = op->buffer_var.as<VarNode>();
-  std::string scope = alloc_storage_scope_.at(buffer);
+  auto it = alloc_storage_scope_.find(buffer);
+  ICHECK(it != alloc_storage_scope_.end())
+      << "Buffer " << op->buffer_var << " is missing an AttrStmt with a \"storage_scope\" key";
+
+  std::string scope = it->second;
   if (scope.find("wmma.") == 0) {
     if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
       ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) ||
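
Previously the bare .at() lookup raised an opaque std::out_of_range when an
allocation reached CUDA codegen without a "storage_scope" attribute; the
explicit find plus ICHECK turns that into a diagnostic naming the offending
buffer.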
diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc
index b0fb9b6..da20dd8 100644
--- a/src/te/operation/cross_thread_reduction.cc
+++ b/src/te/operation/cross_thread_reduction.cc
@@ -145,7 +145,8 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
     Array<PrimExpr> lhs;
     for (size_t i = 0; i < size; ++i) {
       DataType t = reduces[i]->dtype;
-      normal_res_handles.emplace_back("normal_reduce_temp" + std::to_string(i), DataType::Handle());
+      normal_res_handles.emplace_back("normal_reduce_temp" + std::to_string(i),
+                                      PointerType(PrimType(t)));
       lhs.push_back(Load(t, normal_res_handles[i], 0, const_true(t.lanes())));
     }
     Array<PrimExpr> init_value = combiner->identity_element;
@@ -175,7 +176,8 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
   freduce_args.push_back(const_true(1));
   std::vector<Var> res_handles(size);
   for (size_t idx = 0; idx < size; ++idx) {
-    res_handles[idx] = Var("reduce_temp" + std::to_string(idx), DataType::Handle());
+    DataType dtype = reduces[idx]->dtype;
+    res_handles[idx] = Var("reduce_temp" + std::to_string(idx), PointerType(PrimType(dtype)));
     freduce_args.push_back(res_handles[idx]);
   }
 
diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc
index 23a2b3a..1667eb7 100644
--- a/src/tir/ir/buffer.cc
+++ b/src/tir/ir/buffer.cc
@@ -46,8 +46,9 @@ Array<PrimExpr> SimplifyArray(arith::Analyzer* ana, Array<PrimExpr> array) {
 }
 
 Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype, String name, Span span) {
-  return Buffer(Var(name, PointerType(PrimType(dtype)), span), dtype, shape, Array<PrimExpr>(),
-                PrimExpr(), name, "", 0, 0, kDefault, span);
+  DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype);
+  return Buffer(Var(name, PointerType(PrimType(storage_dtype)), span), dtype, shape,
+                Array<PrimExpr>(), PrimExpr(), name, "", 0, 0, kDefault, span);
 }
 
 // Split the given expression w.r.t the add operator
@@ -384,9 +385,14 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane
 Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
                PrimExpr elem_offset, String name, String scope, int data_alignment,
                int offset_factor, BufferType buffer_type, Span span) {
-  ICHECK(IsPointerType(data->type_annotation, dtype))
+  DataType storage_dtype = dtype;
+  // specially handle bool
+  if (storage_dtype == DataType::Bool()) {
+    storage_dtype = DataType::Int(8);
+  }
+  ICHECK(IsPointerType(data->type_annotation, storage_dtype))
       << "Buffer data field expect to have the right pointer type annotation"
-      << " annotation=" << data->type_annotation << ", dtype=" << dtype;
+      << " annotation=" << data->type_annotation << ", storage_dtype=" << storage_dtype;
 
   auto n = make_object<BufferNode>();
   n->data = std::move(data);
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 86960d9..fd03046 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -274,9 +274,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 // Allocate
 Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
                    Stmt body, Span span) {
-  // TODO(tvm-team): Add invariant check to make sure
-  // IsPointerPType(buffer_var->type_annotation, dtype)
-  // once we fix the allocate tvm script printing.
+  CHECK(IsPointerType(buffer_var->type_annotation, dtype))
+      << "The allocated data type (" << dtype
+      << ") does not match the type annotation of the buffer " << buffer_var << " ("
+      << buffer_var->type_annotation
+      << "). The data type should be an element of the pointer type.";
+
   for (size_t i = 0; i < extents.size(); ++i) {
     ICHECK(extents[i].defined());
     ICHECK(extents[i].dtype().is_scalar());
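
The old TODO is now an enforced precondition: the dtype handed to Allocate
must match the element type of the buffer var's pointer annotation. A
minimal sketch, mirroring the updated test at the bottom of this patch:

    import tvm

    nop = tvm.tir.Evaluate(0)
    buffer_var = tvm.tir.Var(
        "buf", tvm.ir.PointerType(tvm.ir.PrimType("float32")))

    # OK: "float32" matches the pointer's element type.
    tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), nop)

    # A mismatched dtype (or a bare handle var) now fails the CHECK:
    # tvm.tir.Allocate(buffer_var, "int32", [10], tvm.tir.const(1, "uint1"), nop)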
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index 529380b..e0ccb49 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -480,7 +480,6 @@ class IRSubstitue : public StmtExprMutator {
   }
 
   PrimExpr VisitExpr_(const LoadNode* op) final {
-    // NOTE: we do not explicit recursivly mutate op->buffer_var
     PrimExpr ret = StmtExprMutator::VisitExpr_(op);
     op = ret.as<LoadNode>();
     if (auto mapped_var = vmap_(op->buffer_var)) {
@@ -491,7 +490,6 @@ class IRSubstitue : public StmtExprMutator {
   }
 
   Stmt VisitStmt_(const StoreNode* op) final {
-    // NOTE: we do not explicit recursivly mutate op->buffer_var
     Stmt ret = StmtExprMutator::VisitStmt_(op);
     op = ret.as<StoreNode>();
     if (auto mapped_var = vmap_(op->buffer_var)) {
@@ -501,6 +499,18 @@ class IRSubstitue : public StmtExprMutator {
     }
   }
 
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    Stmt ret = StmtExprMutator::VisitStmt_(op);
+    op = ret.as<AttrStmtNode>();
+    // remap var node in attr
+    if (const auto* var_node = op->node.as<VarNode>()) {
+      if (auto mapped_var = vmap_(GetRef<Var>(var_node))) {
+        return AttrStmt(mapped_var, op->attr_key, op->value, op->body);
+      }
+    }
+    return ret;
+  }
+
  private:
   std::function<Optional<PrimExpr>(const Var&)> vmap_;
 };
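
Because the passes in this patch replace buffer vars wholesale, Substitute
must also rewrite a var referenced from an AttrStmt's node field (e.g. a
"storage_scope" attr), not just loads and stores. A minimal sketch from the
Python side:

    import tvm
    from tvm import tir

    ptr_ty = tvm.ir.PointerType(tvm.ir.PrimType("float32"))
    old, new = tir.Var("buf", ptr_ty), tir.Var("buf_new", ptr_ty)
    attr = tir.AttrStmt(old, "storage_scope", tir.StringImm("shared"),
                        tir.Evaluate(0))
    rewritten = tir.stmt_functor.substitute(attr, {old: new})
    assert rewritten.node.same_as(new)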
diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc
index a3e5a92..21f1b18 100644
--- a/src/tir/transforms/lower_custom_datatypes.cc
+++ b/src/tir/transforms/lower_custom_datatypes.cc
@@ -44,14 +44,14 @@ class CustomDatatypesLowerer : public StmtExprMutator {
  public:
   explicit CustomDatatypesLowerer(const std::string& target) : target_(target) {}
 
-  inline PrimExpr VisitExpr_(const CastNode* op) final {
+  PrimExpr VisitExpr_(const CastNode* op) final {
     auto type_code = op->dtype.code();
     auto src_type_code = op->value.dtype().code();
     // If either datatype is a registered custom datatype, we must lower.
-    bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code) ||
-                       datatype::Registry::Global()->GetTypeRegistered(src_type_code);
+    bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(type_code) ||
+                         datatype::Registry::Global()->GetTypeRegistered(src_type_code);
     PrimExpr expr = StmtExprMutator::VisitExpr_(op);
-    if (toBeLowered) {
+    if (to_be_lowered) {
       auto lower = datatype::GetCastLowerFunc(target_, type_code, src_type_code);
       ICHECK(lower) << "Cast lowering function for target " << target_ << " destination type "
                     << static_cast<unsigned>(type_code) << " source type "
@@ -61,7 +61,7 @@ class CustomDatatypesLowerer : public StmtExprMutator {
     return expr;
   }
 
-  inline PrimExpr VisitExpr_(const FloatImmNode* imm) final {
+  PrimExpr VisitExpr_(const FloatImmNode* imm) final {
     auto type_code = imm->dtype.code();
     auto e = GetRef<PrimExpr>(imm);
     if (datatype::Registry::Global()->GetTypeRegistered(type_code)) {
@@ -73,35 +73,86 @@ class CustomDatatypesLowerer : public StmtExprMutator {
     return e;
   }
 
-  inline Stmt VisitStmt_(const AllocateNode* allocate) final {
-    bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(allocate->dtype.code());
-    Stmt stmt = StmtExprMutator::VisitStmt_(allocate);
-    allocate = stmt.as<AllocateNode>();
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    Var var = GetRef<Var>(op);
 
-    if (toBeLowered) {
+    auto itr = var_remap_.find(var);
+    if (itr != var_remap_.end()) {
+      return itr->second;
+    } else {
+      return std::move(var);
+    }
+  }
+
+  Stmt VisitStmt_(const AllocateNode* allocate) final {
+    bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(allocate->dtype.code());
+
+    if (to_be_lowered) {
       auto new_allocate_type = DataType::UInt(allocate->dtype.bits(), allocate->dtype.lanes());
-      return Allocate(allocate->buffer_var, new_allocate_type, allocate->extents,
-                      allocate->condition, allocate->body);
+      auto new_buffer_var =
+          Var(allocate->buffer_var->name_hint, PointerType(PrimType(new_allocate_type)));
+      var_remap_[allocate->buffer_var] = new_buffer_var;
+
+      Stmt stmt = StmtExprMutator::VisitStmt_(allocate);
+      allocate = stmt.as<AllocateNode>();
+
+      return Allocate(new_buffer_var, new_allocate_type, allocate->extents, allocate->condition,
+                      allocate->body);
+    } else {
+      return StmtExprMutator::VisitStmt_(allocate);
     }
-    return stmt;
   }
 
-  inline PrimExpr VisitExpr_(const LoadNode* load) final {
-    bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(load->dtype.code());
+  PrimExpr VisitExpr_(const LoadNode* load) final {
+    bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(load->dtype.code());
     PrimExpr expr = StmtExprMutator::VisitExpr_(load);
     load = expr.as<LoadNode>();
-    if (toBeLowered) {
+    if (to_be_lowered) {
       auto new_load_type = DataType::UInt(load->dtype.bits());
-      return Load(new_load_type, load->buffer_var, load->index, load->predicate);
+      auto buffer_var = load->buffer_var;
+      auto it = var_remap_.find(buffer_var);
+      if (it != var_remap_.end()) {
+        buffer_var = it->second;
+      }
+      return Load(new_load_type, buffer_var, load->index, load->predicate);
     }
     return expr;
   }
 
-  inline PrimExpr VisitExpr_(const CallNode* call) final {
-    bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(call->dtype.code());
+  Stmt VisitStmt_(const StoreNode* op) final {
+    Stmt ret = StmtExprMutator::VisitStmt_(op);
+    op = ret.as<StoreNode>();
+
+    auto it = var_remap_.find(op->buffer_var);
+    if (it != var_remap_.end()) {
+      return Store(it->second, op->value, op->index, op->predicate);
+    } else {
+      return ret;
+    }
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    Stmt ret = StmtExprMutator::VisitStmt_(op);
+    op = ret.as<AttrStmtNode>();
+    // Due to legacy reasons, some attr nodes can contain
+    // information (e.g. alignment) about buffer variables.
+    // Remap these vars when needed.
+    // TODO(tvm-team): remove the rewriting once the buffer var
+    // attrs are being refactored into the corresponding definition node
+    if (const auto* var_node = op->node.as<VarNode>()) {
+      auto it = var_remap_.find(GetRef<Var>(var_node));
+      if (it != var_remap_.end()) {
+        return AttrStmt(it->second, op->attr_key, op->value, op->body);
+      }
+    }
+    return ret;
+  }
+
+  PrimExpr VisitExpr_(const CallNode* call) final {
+    bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(call->dtype.code());
     PrimExpr expr = StmtExprMutator::VisitExpr_(call);
     call = expr.as<CallNode>();
-    if (toBeLowered) {
+    if (to_be_lowered) {
       auto op = call->op.as<OpNode>();
       ICHECK(op != nullptr) << "Lowering non-intrinsic Calls not implemented";
       auto lower = datatype::GetIntrinLowerFunc(target_, op->name, call->dtype.code());
@@ -113,38 +164,42 @@ class CustomDatatypesLowerer : public StmtExprMutator {
     return expr;
   }
 
-#define DEFINE_MUTATE(OP, NodeName)                                                \
-  inline PrimExpr VisitExpr_(const NodeName* op) final {                           \
-    auto type_code = op->dtype.code();                                             \
-    bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \
-    PrimExpr expr = StmtExprMutator::VisitExpr_(op);                               \
-    op = expr.as<NodeName>();                                                      \
-    if (toBeLowered) {                                                             \
-      auto lower = datatype::Get##OP##LowerFunc(target_, type_code);               \
-      ICHECK(lower) << #OP " lowering function for target " << target_ << " type " \
-                    << static_cast<unsigned>(type_code) << " not found";           \
-      return (*lower)(expr);                                                       \
-    }                                                                              \
-    return expr;                                                                   \
+#define TVM_DEFINE_MUTATE_CUSTOM_DTYPE(OP, NodeName)                                 \
+  PrimExpr VisitExpr_(const NodeName* op) final {                                    \
+    auto type_code = op->dtype.code();                                               \
+    bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);                                 \
+    op = expr.as<NodeName>();                                                        \
+    if (to_be_lowered) {                                                             \
+      auto lower = datatype::Get##OP##LowerFunc(target_, type_code);                 \
+      ICHECK(lower) << #OP " lowering function for target " << target_ << " type "   \
+                    << static_cast<unsigned>(type_code) << " not found";             \
+      return (*lower)(expr);                                                         \
+    }                                                                                \
+    return expr;                                                                     \
   }
 
-  DEFINE_MUTATE(Add, AddNode);
-  DEFINE_MUTATE(Sub, SubNode);
-  DEFINE_MUTATE(Mul, MulNode);
-  DEFINE_MUTATE(Div, DivNode);
-  DEFINE_MUTATE(Mod, ModNode);
-  DEFINE_MUTATE(Min, MinNode);
-  DEFINE_MUTATE(Max, MaxNode);
-  DEFINE_MUTATE(EQ, EQNode);
-  DEFINE_MUTATE(NE, NENode);
-  DEFINE_MUTATE(LT, LTNode);
-  DEFINE_MUTATE(LE, LENode);
-  DEFINE_MUTATE(GT, GTNode);
-  DEFINE_MUTATE(GE, GENode);
+  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Add, AddNode);
+  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Sub, SubNode);
+  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Mul, MulNode);
+  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Div, DivNode);
+  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Mod, ModNode);
+  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Min, MinNode);
+  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Max, MaxNode);
+  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(EQ, EQNode);
+  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(NE, NENode);
+  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(LT, LTNode);
+  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(LE, LENode);
+  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(GT, GTNode);
+  TVM_DEFINE_MUTATE_CUSTOM_DTYPE(GE, GENode);
   // Later changes may need to add more mutate functions as we support workloads with more ops.
 
+#undef TVM_DEFINE_MUTATE_CUSTOM_DTYPE
+
  private:
   std::string target_;
+  // remap buffer vars
+  std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
 };
 
 namespace transform {
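
The fix follows the same shape here: instead of reusing the old buffer var
(whose pointer annotation names the custom type), the allocate visitor mints
a fresh var typed with the storage dtype, records it in var_remap_, and every
visitor that touches a buffer var consults the map. Roughly, for a custom
type lowered to uint32 (the type name below is illustrative):

    # Before LowerCustomDatatypes:
    #   buf : Pointer(custom[posit32]32);  Allocate(buf, custom[posit32]32, ...)
    # After:
    #   buf : Pointer(uint32);             Allocate(buf, uint32, ...)
    # with every Load, Store, and AttrStmt use remapped to the new var.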
diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index c24e26b..f6cb096 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -224,14 +224,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       PrimExpr index(0);
 
       for (size_t idx = 0; idx < size; ++idx) {
-        shared_bufs[idx] = Var("red_buf" + std::to_string(idx), DataType::Handle());
+        Type ptr_type = PointerType(PrimType(types[idx]));
+        shared_bufs[idx] = Var("red_buf" + std::to_string(idx), ptr_type);
         PrimExpr pred = const_true(types[idx].lanes());
         seq.emplace_back(Store(shared_bufs[idx], values[idx], index, pred));
 
         // Uses a local variable to store the shuffled data.
         // Later on, this allocation will be properly attached to this statement.
-        Var var("t" + std::to_string(idx), types[idx]);
-        Stmt s = Allocate(var, var.dtype(), {PrimExpr(1)}, pred, Evaluate(0));
+        Var var("t" + std::to_string(idx), ptr_type);
+        Stmt s = Allocate(var, types[idx], {PrimExpr(1)}, pred, Evaluate(0));
         local_vars.push_back(s);
       }
 
@@ -239,14 +240,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       // a divergent control flow. Here it uses a variable to cache the current
       // active channels.
       //
-      Var mask_var("mask", DataType::UInt(32));
+      DataType mask_dtype = DataType::UInt(32);
+      Var mask_var("mask", PointerType(PrimType(mask_dtype)));
       {
         PrimExpr pred = const_true(1);
-        PrimExpr mask = Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {});
+        PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
         seq.emplace_back(Store(mask_var, mask, index, pred));
         // Push allocation with an empty body. Later this will be fixed
         // when the entire body is ready.
-        auto stmt = Allocate(mask_var, mask_var->dtype, {PrimExpr(1)}, pred, Evaluate(0));
+        auto stmt = Allocate(mask_var, mask_dtype, {PrimExpr(1)}, pred, Evaluate(0));
         local_vars.push_back(stmt);
       }
 
@@ -338,7 +340,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       // previous iteration on the same buffer.
       seq.emplace_back(SyncThread("shared"));
       for (size_t idx = 0; idx < size; ++idx) {
-        shared_bufs[idx] = Var("red_buf" + std::to_string(idx), DataType::Handle());
+        shared_bufs[idx] = Var("red_buf" + std::to_string(idx), PointerType(PrimType(types[idx])));
         PrimExpr pred = const_true(types[idx].lanes());
         seq.emplace_back(Store(shared_bufs[idx], values[idx],
                                BufIndex(reduce_index, group_index, reduce_extent), pred));
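
Note the subtle cleanup on the warp-shuffle path: the temporary vars used to
carry the value dtype as their own var type, and Allocate was handed
var.dtype; now the var is the pointer and Allocate receives the element
type. For the mask cache, for instance:

    import tvm

    dtype = "uint32"
    mask_var = tvm.tir.Var("mask", tvm.ir.PointerType(tvm.ir.PrimType(dtype)))
    alloc = tvm.tir.Allocate(mask_var, dtype, [1], tvm.tir.const(1, "uint1"),
                             tvm.tir.Evaluate(0))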
diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index 78c5ca7..d4c5ca0 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -23,6 +23,7 @@
  *  Re-write data access to enable memory sharing when possible.
  */
 #include <tvm/arith/analyzer.h>
+#include <tvm/ir/type.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/target/target_info.h>
 #include <tvm/tir/analysis.h>
@@ -934,7 +935,12 @@ class VectorAllocRewriter : public StmtExprMutator {
       if (me->base % factor == 0 && me->coeff % factor == 0) {
         extents.Set(extents.size() - 1,
                     extents[extents.size() - 1] / make_const(extents[0].dtype(), factor));
-        return Allocate(op->buffer_var, tvec[0], extents, op->condition, op->body);
+        // create a new buffer var
+        DataType new_dtype = tvec[0];
+        Var new_buffer_var(op->buffer_var->name_hint, PointerType(PrimType(new_dtype)));
+        // record the remap request.
+        var_remap_.Set(op->buffer_var, new_buffer_var);
+        return Allocate(new_buffer_var, new_dtype, extents, op->condition, op->body);
       }
     }
     return stmt;
@@ -949,23 +955,21 @@ class VectorAllocRewriter : public StmtExprMutator {
 
   // Internal access map
   std::unordered_map<const VarNode*, std::vector<DataType> > acc_map_;
+  // Variables to remap
+  Map<tir::Var, PrimExpr> var_remap_;
   // internal analyzer
   arith::Analyzer analyzer_;
 };
 
-Stmt StorageRewrite(Stmt stmt) {
-  stmt = StoragePlanRewriter().Rewrite(std::move(stmt), true);
-  return VectorAllocRewriter()(std::move(stmt));
-}
-
 PrimFunc PointerValueTypeRewrite(PrimFunc f) {
   auto* n = f.CopyOnWrite();
   VectorAllocRewriter rewriter;
-  n->body = rewriter(n->body);
+  n->body = rewriter(std::move(n->body));
 
+  Map<tir::Var, PrimExpr> var_remap = std::move(rewriter.var_remap_);
   Array<tir::Var> args;
-  Map<tir::Var, PrimExpr> remap_vars;
 
+  // rewrite parameters if needed.
   for (Var var : f->params) {
     if (var.dtype().is_handle()) {
       const auto& tvec = rewriter.acc_map_[var.get()];
@@ -973,15 +977,14 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f) {
       if (tvec.size() == 1) {
         tir::Var new_var(var->name_hint, PointerType(PrimType(tvec[0])));
         args.push_back(new_var);
-        remap_vars.Set(var, new_var);
-
+        var_remap.Set(var, new_var);
       } else {
         // always set data type to be non vectorized so
         // load/store can still work via scalarization
         if (tvec.size() != 0 && !var->type_annotation.defined()) {
           tir::Var new_var(var->name_hint, PointerType(PrimType(tvec[0].with_lanes(1))));
           args.push_back(new_var);
-          remap_vars.Set(var, new_var);
+          var_remap.Set(var, new_var);
         } else {
           args.push_back(var);
         }
@@ -991,9 +994,13 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f) {
     }
   }
 
+  // if no variable remap is needed, return the function unchanged.
+  if (var_remap.size() == 0) return f;
+
+  // remap the variables.
   ICHECK_EQ(args.size(), n->params.size());
   n->params = args;
-  n->body = Substitute(n->body, remap_vars);
+  n->body = Substitute(n->body, var_remap);
   return f;
 }
 
@@ -1003,8 +1010,7 @@ Pass StorageRewrite() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
     auto* n = f.CopyOnWrite();
     n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true);
-    n->body = VectorAllocRewriter()(std::move(n->body));
-    return f;
+    return PointerValueTypeRewrite(std::move(f));
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {});
 }
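
StorageRewrite now subsumes the old two-step flow: plan rewriting runs first,
then PointerValueTypeRewrite retypes rewritten allocations, merges the
var_remap_ entries collected by VectorAllocRewriter with any parameter
rewrites, and applies a single Substitute over the body. The user-facing
entry point is unchanged:

    import tvm

    # Same pass object as before; it now ends with PointerValueTypeRewrite,
    # so rewritten allocations get fresh, correctly typed buffer vars.
    storage_rewrite_pass = tvm.tir.transform.StorageRewrite()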
diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc
index 683caaa..9be8398 100644
--- a/tests/cpp/ir_functor_test.cc
+++ b/tests/cpp/ir_functor_test.cc
@@ -114,8 +114,9 @@ TEST(IRF, StmtVisitor) {
   auto fmaketest = [&]() {
     auto z = x + 1;
     Stmt body = Evaluate(z);
-    Var buffer("b", DataType::Handle());
-    return Allocate(buffer, DataType::Float(32), {z, z}, const_true(), body);
+    DataType dtype = DataType::Float(32);
+    Var buffer("b", PointerType(PrimType(dtype)));
+    return Allocate(buffer, dtype, {z, z}, const_true(), body);
   };
   v(fmaketest());
   ICHECK_EQ(v.count, 3);
@@ -140,8 +141,9 @@ TEST(IRF, StmtMutator) {
   auto fmakealloc = [&]() {
     auto z = x + 1;
     Stmt body = Evaluate(z);
-    Var buffer("b", DataType::Handle());
-    return Allocate(buffer, DataType::Float(32), {1, z}, const_true(), body);
+    DataType dtype = DataType::Float(32);
+    Var buffer("b", PointerType(PrimType(dtype)));
+    return Allocate(buffer, dtype, {1, z}, const_true(), body);
   };
 
   auto fmakeif = [&]() {
diff --git a/tests/python/unittest/test_tir_constructor.py b/tests/python/unittest/test_tir_constructor.py
index 3cde5d7..2bf4ba5 100644
--- a/tests/python/unittest/test_tir_constructor.py
+++ b/tests/python/unittest/test_tir_constructor.py
@@ -154,6 +154,7 @@ def test_stmt_constructor():
     assert x.index.value == 10
     assert x.value.value == 1
 
+    buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32")))
     x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), nop)
     assert isinstance(x, tvm.tir.Allocate)
     assert x.dtype == "float32"