You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2021/01/07 00:59:33 UTC
[tvm] branch main updated: [TIR][REFACTOR] Enforce allocate to use
the correct var pointer hint. (#7216)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d777e7c [TIR][REFACTOR] Enforce allocate to use the correct var pointer hint. (#7216)
d777e7c is described below
commit d777e7c612cf7a9aae4d8433c36f031c6b6f985c
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Wed Jan 6 19:59:12 2021 -0500
[TIR][REFACTOR] Enforce allocate to use the correct var pointer hint. (#7216)
* [TIR][REFACTOR] Enforce allocate to only accept buffer_var with correct PtrType.
This is a refactoring step to clean up the legacy issue of opaque buffer
vars that carry no pointer type information. Now every allocation comes with
the right pointer data type. Places touched:
- TVMScript Parser: add the right info to get the correct pointer type.
- Cross thread all reduce: set the right pointer type.
- Storage rewrite: set up the right pointer type.
- Custom dtype: remap the variables with new pointer type.
x
* Address comments
Co-authored-by: Tristan Konolige <tr...@gmail.com>
Co-authored-by: Tristan Konolige <tr...@gmail.com>
---
include/tvm/tir/op.h | 2 +-
python/tvm/script/parser.py | 25 +++--
python/tvm/script/scope_handler.py | 13 ++-
python/tvm/tir/buffer.py | 5 +-
src/driver/driver_api.cc | 3 +-
src/target/source/codegen_cuda.cc | 6 +-
src/te/operation/cross_thread_reduction.cc | 6 +-
src/tir/ir/buffer.cc | 14 ++-
src/tir/ir/stmt.cc | 9 +-
src/tir/ir/stmt_functor.cc | 14 ++-
src/tir/transforms/lower_custom_datatypes.cc | 147 ++++++++++++++++++--------
src/tir/transforms/lower_thread_allreduce.cc | 16 +--
src/tir/transforms/storage_rewrite.cc | 34 +++---
tests/cpp/ir_functor_test.cc | 10 +-
tests/python/unittest/test_tir_constructor.py | 1 +
15 files changed, 209 insertions(+), 96 deletions(-)
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 61481d9..4a907fc 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -1241,7 +1241,7 @@ inline void DivAmbiguityError(const TA& a) {
"please call div, indexdiv/indexmod, "
"floordiv/floormod or truncdiv/truncmod directly "
"to avoid ambiguity in the code. "
- "Checkout these functions in expr_operator.h.");
+ "Checkout these functions in tir/op.h.");
}
// The following code are not intended to be used in the codebase.
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
index db976d0..33b0bab 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser.py
@@ -230,6 +230,19 @@ class TVMScriptParser(Transformer):
"""Match the arguments of a function call in the AST to the required
arguments of the function. This handles positional arguments,
positional arguments specified by name, keyword arguments, and varargs.
+
+ Parameters
+ ----------
+ func : Function
+ The function that provides the signature
+
+ node_call: ast.Call
+ The AST call node that calls into the function.
+
+ Returns
+ -------
+ arg_list : list
+ The parsed positional argument.
"""
assert isinstance(node_call, ast.Call)
# collect arguments
@@ -435,8 +448,8 @@ class TVMScriptParser(Transformer):
node.rhs.span,
)
# Pattern 4
- func.enter_scope(node, self.context)
arg_list = self.parse_arg_list(func, node.rhs)
+ func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
func.body = self.parse_body(node)
return func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span)
elif isinstance(func, SpecialStmt):
@@ -532,9 +545,9 @@ class TVMScriptParser(Transformer):
self.current_col_offset = node.span.start_column
self.context.new_scope(nodes=node.body.stmts)
# for scope handler process the scope
- func.enter_scope(node, self.context)
- func.body = self.parse_body(node)
arg_list = self.parse_arg_list(func, node.rhs)
+ func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
+ func.body = self.parse_body(node)
res = func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span)
# exit the scope
self.context.pop_scope()
@@ -571,9 +584,9 @@ class TVMScriptParser(Transformer):
self.current_col_offset = node.body.span.start_column
self.context.new_scope(nodes=node.body.stmts)
# with scope handler process the scope
- func.enter_scope(node, self.context)
- func.body = self.parse_body(node)
arg_list = self.parse_arg_list(func, node.rhs)
+ func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
+ func.body = self.parse_body(node)
res = func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span)
# exit the scope
self.context.pop_scope()
@@ -689,7 +702,7 @@ class TVMScriptParser(Transformer):
if isinstance(func, Intrin) and func.stmt:
return func.handle(arg_list, node.call.func_name.span)
elif isinstance(func, WithScopeHandler) and func.concise_scope and not func.def_symbol:
- func.enter_scope(node, self.context)
+ func.enter_scope(node, self.context, arg_list, node.call.func_name.span)
func.body = self.parse_body(node)
return func.exit_scope(node, self.context, arg_list, node.call.func_name.span)
elif isinstance(func, SpecialStmt) and not func.def_symbol:
diff --git a/python/tvm/script/scope_handler.py b/python/tvm/script/scope_handler.py
index 7f252e3..21ed7f6 100644
--- a/python/tvm/script/scope_handler.py
+++ b/python/tvm/script/scope_handler.py
@@ -35,7 +35,7 @@ class ScopeHandler:
def signature(self):
return "tir." + self.func.__name__, get_param_list(self.func)
- def enter_scope(self, node, context):
+ def enter_scope(self, node, context, arg_list, span):
pass
def exit_scope(self, node, context, arg_list, span):
@@ -86,7 +86,7 @@ class Allocate(WithScopeHandler):
super().__init__(allocate, concise_scope=True, def_symbol=True)
self.buffer_var = None
- def enter_scope(self, node, context):
+ def enter_scope(self, node, context, arg_list, span):
# define buffer vars in symbol table
if isinstance(node, ast.With):
names = WithScopeHandler.get_optional_var_names(node, context)
@@ -98,7 +98,12 @@ class Allocate(WithScopeHandler):
else:
raise Exception("Internal Bug")
- self.buffer_var = tvm.te.var(name, "handle", span=from_synr_span(node.lhs.id.span))
+ def setup_buffer_var(extents, dtype, scope, condition=True, span=None):
+ """Setup buffer var for a given type."""
+ buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype))
+ self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span)
+
+ setup_buffer_var(*arg_list, span=from_synr_span(node.lhs.id.span))
context.update_symbol(name, self.buffer_var)
@@ -187,7 +192,7 @@ class ForScopeHandler(ScopeHandler):
super().__init__(func)
self.loop_vars = None
- def enter_scope(self, node, context):
+ def enter_scope(self, node, context, arg_list, span):
assert isinstance(node, ast.For)
loop_var_names = list()
diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py
index 2f50aa8..95966a5 100644
--- a/python/tvm/tir/buffer.py
+++ b/python/tvm/tir/buffer.py
@@ -247,7 +247,10 @@ def decl_buffer(
shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32"
elem_offset = Var("%s_elem_offset" % name, shape_dtype)
if data is None:
- data = Var(name, PointerType(PrimType(dtype)), span)
+ # Bool is represented as uint1 in the IR, but stored as int8
+ storage_type = PrimType(dtype)
+ storage_type = PrimType("int8") if storage_type.dtype == "bool" else storage_type
+ data = Var(name, PointerType(storage_type), span)
return _ffi_api.Buffer(
data,
dtype,
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index f88b621..bbbb7e3 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -69,7 +69,8 @@ Target DefaultTargetHost(Target target) {
tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype, std::string name,
int data_alignment, int offset_factor, bool compact) {
- auto data = tir::Var(name, PointerType(PrimType(dtype)));
+ DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype);
+ auto data = tir::Var(name, PointerType(PrimType(storage_dtype)));
bool has_any = false;
if (!compact) {
for (const auto& it : shape) {
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index c0fb39f..6c73716 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -581,7 +581,11 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
int32_t constant_size = op->constant_allocation_size();
ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now";
const VarNode* buffer = op->buffer_var.as<VarNode>();
- std::string scope = alloc_storage_scope_.at(buffer);
+ auto it = alloc_storage_scope_.find(buffer);
+ ICHECK(it != alloc_storage_scope_.end())
+ << "Buffer " << op->buffer_var << " is missing an AttrStmt with a \"storage_scope\" key";
+
+ std::string scope = it->second;
if (scope.find("wmma.") == 0) {
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) ||
diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc
index b0fb9b6..da20dd8 100644
--- a/src/te/operation/cross_thread_reduction.cc
+++ b/src/te/operation/cross_thread_reduction.cc
@@ -145,7 +145,8 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
Array<PrimExpr> lhs;
for (size_t i = 0; i < size; ++i) {
DataType t = reduces[i]->dtype;
- normal_res_handles.emplace_back("normal_reduce_temp" + std::to_string(i), DataType::Handle());
+ normal_res_handles.emplace_back("normal_reduce_temp" + std::to_string(i),
+ PointerType(PrimType(t)));
lhs.push_back(Load(t, normal_res_handles[i], 0, const_true(t.lanes())));
}
Array<PrimExpr> init_value = combiner->identity_element;
@@ -175,7 +176,8 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
freduce_args.push_back(const_true(1));
std::vector<Var> res_handles(size);
for (size_t idx = 0; idx < size; ++idx) {
- res_handles[idx] = Var("reduce_temp" + std::to_string(idx), DataType::Handle());
+ DataType dtype = reduces[idx]->dtype;
+ res_handles[idx] = Var("reduce_temp" + std::to_string(idx), PointerType(PrimType(dtype)));
freduce_args.push_back(res_handles[idx]);
}
diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc
index 23a2b3a..1667eb7 100644
--- a/src/tir/ir/buffer.cc
+++ b/src/tir/ir/buffer.cc
@@ -46,8 +46,9 @@ Array<PrimExpr> SimplifyArray(arith::Analyzer* ana, Array<PrimExpr> array) {
}
Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype, String name, Span span) {
- return Buffer(Var(name, PointerType(PrimType(dtype)), span), dtype, shape, Array<PrimExpr>(),
- PrimExpr(), name, "", 0, 0, kDefault, span);
+ DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype);
+ return Buffer(Var(name, PointerType(PrimType(storage_dtype)), span), dtype, shape,
+ Array<PrimExpr>(), PrimExpr(), name, "", 0, 0, kDefault, span);
}
// Split the given expression w.r.t the add operator
@@ -384,9 +385,14 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane
Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
PrimExpr elem_offset, String name, String scope, int data_alignment,
int offset_factor, BufferType buffer_type, Span span) {
- ICHECK(IsPointerType(data->type_annotation, dtype))
+ DataType storage_dtype = dtype;
+ // specially handle bool
+ if (storage_dtype == DataType::Bool()) {
+ storage_dtype = DataType::Int(8);
+ }
+ ICHECK(IsPointerType(data->type_annotation, storage_dtype))
<< "Buffer data field expect to have the right pointer type annotation"
- << " annotation=" << data->type_annotation << ", dtype=" << dtype;
+ << " annotation=" << data->type_annotation << ", storage_dtype=" << storage_dtype;
auto n = make_object<BufferNode>();
n->data = std::move(data);
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 86960d9..fd03046 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -274,9 +274,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
// Allocate
Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
Stmt body, Span span) {
- // TODO(tvm-team): Add invariant check to make sure
- // IsPointerPType(buffer_var->type_annotation, dtype)
- // once we fix the allocate tvm script printing.
+ CHECK(IsPointerType(buffer_var->type_annotation, dtype))
+ << "The allocated data type (" << dtype
+ << ") does not match the type annotation of the buffer " << buffer_var << " ("
+ << buffer_var->type_annotation
+ << "). The data type should be an element of the pointer type.";
+
for (size_t i = 0; i < extents.size(); ++i) {
ICHECK(extents[i].defined());
ICHECK(extents[i].dtype().is_scalar());
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index 529380b..e0ccb49 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -480,7 +480,6 @@ class IRSubstitue : public StmtExprMutator {
}
PrimExpr VisitExpr_(const LoadNode* op) final {
- // NOTE: we do not explicit recursivly mutate op->buffer_var
PrimExpr ret = StmtExprMutator::VisitExpr_(op);
op = ret.as<LoadNode>();
if (auto mapped_var = vmap_(op->buffer_var)) {
@@ -491,7 +490,6 @@ class IRSubstitue : public StmtExprMutator {
}
Stmt VisitStmt_(const StoreNode* op) final {
- // NOTE: we do not explicit recursivly mutate op->buffer_var
Stmt ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<StoreNode>();
if (auto mapped_var = vmap_(op->buffer_var)) {
@@ -501,6 +499,18 @@ class IRSubstitue : public StmtExprMutator {
}
}
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
+ Stmt ret = StmtExprMutator::VisitStmt_(op);
+ op = ret.as<AttrStmtNode>();
+ // remap var node in attr
+ if (const auto* var_node = op->node.as<VarNode>()) {
+ if (auto mapped_var = vmap_(GetRef<Var>(var_node))) {
+ return AttrStmt(mapped_var, op->attr_key, op->value, op->body);
+ }
+ }
+ return ret;
+ }
+
private:
std::function<Optional<PrimExpr>(const Var&)> vmap_;
};
diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc
index a3e5a92..21f1b18 100644
--- a/src/tir/transforms/lower_custom_datatypes.cc
+++ b/src/tir/transforms/lower_custom_datatypes.cc
@@ -44,14 +44,14 @@ class CustomDatatypesLowerer : public StmtExprMutator {
public:
explicit CustomDatatypesLowerer(const std::string& target) : target_(target) {}
- inline PrimExpr VisitExpr_(const CastNode* op) final {
+ PrimExpr VisitExpr_(const CastNode* op) final {
auto type_code = op->dtype.code();
auto src_type_code = op->value.dtype().code();
// If either datatype is a registered custom datatype, we must lower.
- bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code) ||
- datatype::Registry::Global()->GetTypeRegistered(src_type_code);
+ bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(type_code) ||
+ datatype::Registry::Global()->GetTypeRegistered(src_type_code);
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
- if (toBeLowered) {
+ if (to_be_lowered) {
auto lower = datatype::GetCastLowerFunc(target_, type_code, src_type_code);
ICHECK(lower) << "Cast lowering function for target " << target_ << " destination type "
<< static_cast<unsigned>(type_code) << " source type "
@@ -61,7 +61,7 @@ class CustomDatatypesLowerer : public StmtExprMutator {
return expr;
}
- inline PrimExpr VisitExpr_(const FloatImmNode* imm) final {
+ PrimExpr VisitExpr_(const FloatImmNode* imm) final {
auto type_code = imm->dtype.code();
auto e = GetRef<PrimExpr>(imm);
if (datatype::Registry::Global()->GetTypeRegistered(type_code)) {
@@ -73,35 +73,86 @@ class CustomDatatypesLowerer : public StmtExprMutator {
return e;
}
- inline Stmt VisitStmt_(const AllocateNode* allocate) final {
- bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(allocate->dtype.code());
- Stmt stmt = StmtExprMutator::VisitStmt_(allocate);
- allocate = stmt.as<AllocateNode>();
+ PrimExpr VisitExpr_(const VarNode* op) final {
+ Var var = GetRef<Var>(op);
- if (toBeLowered) {
+ auto itr = var_remap_.find(var);
+ if (itr != var_remap_.end()) {
+ return itr->second;
+ } else {
+ return std::move(var);
+ }
+ }
+
+ Stmt VisitStmt_(const AllocateNode* allocate) final {
+ bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(allocate->dtype.code());
+
+ if (to_be_lowered) {
auto new_allocate_type = DataType::UInt(allocate->dtype.bits(), allocate->dtype.lanes());
- return Allocate(allocate->buffer_var, new_allocate_type, allocate->extents,
- allocate->condition, allocate->body);
+ auto new_buffer_var =
+ Var(allocate->buffer_var->name_hint, PointerType(PrimType(new_allocate_type)));
+ var_remap_[allocate->buffer_var] = new_buffer_var;
+
+ Stmt stmt = StmtExprMutator::VisitStmt_(allocate);
+ allocate = stmt.as<AllocateNode>();
+
+ return Allocate(new_buffer_var, new_allocate_type, allocate->extents, allocate->condition,
+ allocate->body);
+ } else {
+ return StmtExprMutator::VisitStmt_(allocate);
}
- return stmt;
}
- inline PrimExpr VisitExpr_(const LoadNode* load) final {
- bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(load->dtype.code());
+ PrimExpr VisitExpr_(const LoadNode* load) final {
+ bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(load->dtype.code());
PrimExpr expr = StmtExprMutator::VisitExpr_(load);
load = expr.as<LoadNode>();
- if (toBeLowered) {
+ if (to_be_lowered) {
auto new_load_type = DataType::UInt(load->dtype.bits());
- return Load(new_load_type, load->buffer_var, load->index, load->predicate);
+ auto buffer_var = load->buffer_var;
+ auto it = var_remap_.find(buffer_var);
+ if (it != var_remap_.end()) {
+ buffer_var = it->second;
+ }
+ return Load(new_load_type, buffer_var, load->index, load->predicate);
}
return expr;
}
- inline PrimExpr VisitExpr_(const CallNode* call) final {
- bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(call->dtype.code());
+ Stmt VisitStmt_(const StoreNode* op) final {
+ Stmt ret = StmtExprMutator::VisitStmt_(op);
+ op = ret.as<StoreNode>();
+
+ auto it = var_remap_.find(op->buffer_var);
+ if (it != var_remap_.end()) {
+ return Store(it->second, op->value, op->index, op->predicate);
+ } else {
+ return ret;
+ }
+ }
+
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
+ Stmt ret = StmtExprMutator::VisitStmt_(op);
+ op = ret.as<AttrStmtNode>();
+ // Due to legacy reasons, some attr node can contain
+ // information(e.g. alignment) of buffer variables.
+ // remap these vars when needed
+ // TODO(tvm-team): remove the rewriting once the buffer var
+ // attrs are being refactored into the corresponding definition node
+ if (const auto* var_node = op->node.as<VarNode>()) {
+ auto it = var_remap_.find(GetRef<Var>(var_node));
+ if (it != var_remap_.end()) {
+ return AttrStmt(it->second, op->attr_key, op->value, op->body);
+ }
+ }
+ return ret;
+ }
+
+ PrimExpr VisitExpr_(const CallNode* call) final {
+ bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(call->dtype.code());
PrimExpr expr = StmtExprMutator::VisitExpr_(call);
call = expr.as<CallNode>();
- if (toBeLowered) {
+ if (to_be_lowered) {
auto op = call->op.as<OpNode>();
ICHECK(op != nullptr) << "Lowering non-intrinsic Calls not implemented";
auto lower = datatype::GetIntrinLowerFunc(target_, op->name, call->dtype.code());
@@ -113,38 +164,42 @@ class CustomDatatypesLowerer : public StmtExprMutator {
return expr;
}
-#define DEFINE_MUTATE(OP, NodeName) \
- inline PrimExpr VisitExpr_(const NodeName* op) final { \
- auto type_code = op->dtype.code(); \
- bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \
- PrimExpr expr = StmtExprMutator::VisitExpr_(op); \
- op = expr.as<NodeName>(); \
- if (toBeLowered) { \
- auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \
- ICHECK(lower) << #OP " lowering function for target " << target_ << " type " \
- << static_cast<unsigned>(type_code) << " not found"; \
- return (*lower)(expr); \
- } \
- return expr; \
+#define TVM_DEFINE_MUTATE_CUSTOM_DTYPE(OP, NodeName) \
+ PrimExpr VisitExpr_(const NodeName* op) final { \
+ auto type_code = op->dtype.code(); \
+ bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op); \
+ op = expr.as<NodeName>(); \
+ if (to_be_lowered) { \
+ auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \
+ ICHECK(lower) << #OP " lowering function for target " << target_ << " type " \
+ << static_cast<unsigned>(type_code) << " not found"; \
+ return (*lower)(expr); \
+ } \
+ return expr; \
}
- DEFINE_MUTATE(Add, AddNode);
- DEFINE_MUTATE(Sub, SubNode);
- DEFINE_MUTATE(Mul, MulNode);
- DEFINE_MUTATE(Div, DivNode);
- DEFINE_MUTATE(Mod, ModNode);
- DEFINE_MUTATE(Min, MinNode);
- DEFINE_MUTATE(Max, MaxNode);
- DEFINE_MUTATE(EQ, EQNode);
- DEFINE_MUTATE(NE, NENode);
- DEFINE_MUTATE(LT, LTNode);
- DEFINE_MUTATE(LE, LENode);
- DEFINE_MUTATE(GT, GTNode);
- DEFINE_MUTATE(GE, GENode);
+ TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Add, AddNode);
+ TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Sub, SubNode);
+ TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Mul, MulNode);
+ TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Div, DivNode);
+ TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Mod, ModNode);
+ TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Min, MinNode);
+ TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Max, MaxNode);
+ TVM_DEFINE_MUTATE_CUSTOM_DTYPE(EQ, EQNode);
+ TVM_DEFINE_MUTATE_CUSTOM_DTYPE(NE, NENode);
+ TVM_DEFINE_MUTATE_CUSTOM_DTYPE(LT, LTNode);
+ TVM_DEFINE_MUTATE_CUSTOM_DTYPE(LE, LENode);
+ TVM_DEFINE_MUTATE_CUSTOM_DTYPE(GT, GTNode);
+ TVM_DEFINE_MUTATE_CUSTOM_DTYPE(GE, GENode);
// Later changes may need to add more mutate functions as we support workloads with more ops.
+#undef TVM_DEFINE_MUTATE_CUSTOM_DTYPE
+
private:
std::string target_;
+ // remap buffer vars
+ std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
};
namespace transform {
diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index c24e26b..f6cb096 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -224,14 +224,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
PrimExpr index(0);
for (size_t idx = 0; idx < size; ++idx) {
- shared_bufs[idx] = Var("red_buf" + std::to_string(idx), DataType::Handle());
+ Type ptr_type = PointerType(PrimType(types[idx]));
+ shared_bufs[idx] = Var("red_buf" + std::to_string(idx), ptr_type);
PrimExpr pred = const_true(types[idx].lanes());
seq.emplace_back(Store(shared_bufs[idx], values[idx], index, pred));
// Uses a local variable to store the shuffled data.
// Later on, this allocation will be properly attached to this statement.
- Var var("t" + std::to_string(idx), types[idx]);
- Stmt s = Allocate(var, var.dtype(), {PrimExpr(1)}, pred, Evaluate(0));
+ Var var("t" + std::to_string(idx), ptr_type);
+ Stmt s = Allocate(var, types[idx], {PrimExpr(1)}, pred, Evaluate(0));
local_vars.push_back(s);
}
@@ -239,14 +240,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
// a divergent control flow. Here it uses a variable to cache the current
// active channels.
//
- Var mask_var("mask", DataType::UInt(32));
+ DataType mask_dtype = DataType::UInt(32);
+ Var mask_var("mask", PointerType(PrimType(mask_dtype)));
{
PrimExpr pred = const_true(1);
- PrimExpr mask = Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {});
+ PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
seq.emplace_back(Store(mask_var, mask, index, pred));
// Push allocation with an empty body. Later this will be fixed
// when the entire body is ready.
- auto stmt = Allocate(mask_var, mask_var->dtype, {PrimExpr(1)}, pred, Evaluate(0));
+ auto stmt = Allocate(mask_var, mask_dtype, {PrimExpr(1)}, pred, Evaluate(0));
local_vars.push_back(stmt);
}
@@ -338,7 +340,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
// previous iteration on the same buffer.
seq.emplace_back(SyncThread("shared"));
for (size_t idx = 0; idx < size; ++idx) {
- shared_bufs[idx] = Var("red_buf" + std::to_string(idx), DataType::Handle());
+ shared_bufs[idx] = Var("red_buf" + std::to_string(idx), PointerType(PrimType(types[idx])));
PrimExpr pred = const_true(types[idx].lanes());
seq.emplace_back(Store(shared_bufs[idx], values[idx],
BufIndex(reduce_index, group_index, reduce_extent), pred));
diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index 78c5ca7..d4c5ca0 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -23,6 +23,7 @@
* Re-write data access to enable memory sharing when possible.
*/
#include <tvm/arith/analyzer.h>
+#include <tvm/ir/type.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target_info.h>
#include <tvm/tir/analysis.h>
@@ -934,7 +935,12 @@ class VectorAllocRewriter : public StmtExprMutator {
if (me->base % factor == 0 && me->coeff % factor == 0) {
extents.Set(extents.size() - 1,
extents[extents.size() - 1] / make_const(extents[0].dtype(), factor));
- return Allocate(op->buffer_var, tvec[0], extents, op->condition, op->body);
+ // create a new buffer var
+ DataType new_dtype = tvec[0];
+ Var new_buffer_var(op->buffer_var->name_hint, PointerType(PrimType(new_dtype)));
+ // update the remap req.
+ var_remap_.Set(op->buffer_var, new_buffer_var);
+ return Allocate(new_buffer_var, new_dtype, extents, op->condition, op->body);
}
}
return stmt;
@@ -949,23 +955,21 @@ class VectorAllocRewriter : public StmtExprMutator {
// Internal access map
std::unordered_map<const VarNode*, std::vector<DataType> > acc_map_;
+ // Variables to remap
+ Map<tir::Var, PrimExpr> var_remap_;
// internal analyzer
arith::Analyzer analyzer_;
};
-Stmt StorageRewrite(Stmt stmt) {
- stmt = StoragePlanRewriter().Rewrite(std::move(stmt), true);
- return VectorAllocRewriter()(std::move(stmt));
-}
-
PrimFunc PointerValueTypeRewrite(PrimFunc f) {
auto* n = f.CopyOnWrite();
VectorAllocRewriter rewriter;
- n->body = rewriter(n->body);
+ n->body = rewriter(std::move(n->body));
+ Map<tir::Var, PrimExpr> var_remap = std::move(rewriter.var_remap_);
Array<tir::Var> args;
- Map<tir::Var, PrimExpr> remap_vars;
+ // rewrite paramters if needed.
for (Var var : f->params) {
if (var.dtype().is_handle()) {
const auto& tvec = rewriter.acc_map_[var.get()];
@@ -973,15 +977,14 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f) {
if (tvec.size() == 1) {
tir::Var new_var(var->name_hint, PointerType(PrimType(tvec[0])));
args.push_back(new_var);
- remap_vars.Set(var, new_var);
-
+ var_remap.Set(var, new_var);
} else {
// always set data type to be non vectorized so
// load/store can still work via scalarization
if (tvec.size() != 0 && !var->type_annotation.defined()) {
tir::Var new_var(var->name_hint, PointerType(PrimType(tvec[0].with_lanes(1))));
args.push_back(new_var);
- remap_vars.Set(var, new_var);
+ var_remap.Set(var, new_var);
} else {
args.push_back(var);
}
@@ -991,9 +994,13 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f) {
}
}
+ // no variable remap is needed.
+ if (var_remap.size() == 0) return f;
+
+ // remap the variables.
ICHECK_EQ(args.size(), n->params.size());
n->params = args;
- n->body = Substitute(n->body, remap_vars);
+ n->body = Substitute(n->body, var_remap);
return f;
}
@@ -1003,8 +1010,7 @@ Pass StorageRewrite() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true);
- n->body = VectorAllocRewriter()(std::move(n->body));
- return f;
+ return PointerValueTypeRewrite(std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {});
}
diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc
index 683caaa..9be8398 100644
--- a/tests/cpp/ir_functor_test.cc
+++ b/tests/cpp/ir_functor_test.cc
@@ -114,8 +114,9 @@ TEST(IRF, StmtVisitor) {
auto fmaketest = [&]() {
auto z = x + 1;
Stmt body = Evaluate(z);
- Var buffer("b", DataType::Handle());
- return Allocate(buffer, DataType::Float(32), {z, z}, const_true(), body);
+ DataType dtype = DataType::Float(32);
+ Var buffer("b", PointerType(PrimType(dtype)));
+ return Allocate(buffer, dtype, {z, z}, const_true(), body);
};
v(fmaketest());
ICHECK_EQ(v.count, 3);
@@ -140,8 +141,9 @@ TEST(IRF, StmtMutator) {
auto fmakealloc = [&]() {
auto z = x + 1;
Stmt body = Evaluate(z);
- Var buffer("b", DataType::Handle());
- return Allocate(buffer, DataType::Float(32), {1, z}, const_true(), body);
+ DataType dtype = DataType::Float(32);
+ Var buffer("b", PointerType(PrimType(dtype)));
+ return Allocate(buffer, dtype, {1, z}, const_true(), body);
};
auto fmakeif = [&]() {
diff --git a/tests/python/unittest/test_tir_constructor.py b/tests/python/unittest/test_tir_constructor.py
index 3cde5d7..2bf4ba5 100644
--- a/tests/python/unittest/test_tir_constructor.py
+++ b/tests/python/unittest/test_tir_constructor.py
@@ -154,6 +154,7 @@ def test_stmt_constructor():
assert x.index.value == 10
assert x.value.value == 1
+ buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32")))
x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), nop)
assert isinstance(x, tvm.tir.Allocate)
assert x.dtype == "float32"