You are viewing a plain text version of this content. The canonical link for it is here [link not preserved in this plain-text copy].
Posted to commits@tvm.apache.org by tq...@apache.org on 2019/12/22 02:26:27 UTC
[incubator-tvm] branch master updated: [REFACTOR][DTYPE] Isolate dtype to runtime (#4560)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 7fa8aab [REFACTOR][DTYPE] Isolate dtype to runtime (#4560)
7fa8aab is described below
commit 7fa8aab563cca45797f4a694c1dfc06186549630
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Sat Dec 21 18:26:21 2019 -0800
[REFACTOR][DTYPE] Isolate dtype to runtime (#4560)
include/tvm/dtype.h -> include/tvm/runtime/data_type.h
Changes:
- Rename all old references of tvm::Type to DataType
- ExprNode.type -> ExprNode.dtype
- Expr.type() -> Expr.dtype()
- Move Expr-related functions into expr_operator.
- DataType::min() -> min_value(DataType)
- DataType::max() -> max_value(DataType)
- Move type constructors Int, UInt, Float, Handle, Bool into DataType.
- Int(bits) -> DataType::Int(bits)
- UInt(bits) -> DataType::UInt(bits)
---
include/tvm/attrs.h | 8 +-
include/tvm/buffer.h | 18 +-
include/tvm/channel.h | 4 +-
include/tvm/expr.h | 18 +-
include/tvm/expr_operator.h | 46 +++--
include/tvm/ir.h | 66 +++----
include/tvm/node/reflection.h | 3 +-
include/tvm/operation.h | 20 +-
include/tvm/packed_func_ext.h | 17 --
include/tvm/relay/attrs/memory.h | 2 +-
include/tvm/relay/base.h | 2 +-
include/tvm/{dtype.h => runtime/data_type.h} | 206 ++++++++++----------
include/tvm/runtime/packed_func.h | 41 +++-
include/tvm/tensor.h | 4 +-
nnvm/include/nnvm/compiler/util.h | 2 +-
nnvm/src/compiler/alter_op_layout.cc | 2 +-
nnvm/src/compiler/compile_engine.cc | 56 +++---
nnvm/src/compiler/compile_engine.h | 4 +-
nnvm/src/compiler/graph_fuse.cc | 6 +-
nnvm/src/compiler/graph_fuse.h | 2 +-
nnvm/src/top/nn/nn.cc | 4 +-
nnvm/src/top/tensor/elemwise.cc | 14 +-
nnvm/src/top/tensor/transform.cc | 6 +-
src/api/api_ir.cc | 14 +-
src/api/api_lang.cc | 8 +-
src/arithmetic/bound_deducer.cc | 2 +-
src/arithmetic/canonical_simplify.cc | 42 ++---
src/arithmetic/compute_expr.h | 2 +-
src/arithmetic/const_fold.h | 50 ++---
src/arithmetic/const_int_bound.cc | 18 +-
src/arithmetic/detect_linear_equation.cc | 18 +-
src/arithmetic/domain_touched.cc | 2 +-
src/arithmetic/int_set.cc | 24 +--
src/arithmetic/ir_mutator_with_analyzer.cc | 2 +-
src/arithmetic/pattern_match.h | 10 +-
src/arithmetic/rewrite_simplify.cc | 79 ++++----
src/autotvm/touch_extractor.cc | 30 +--
src/autotvm/touch_extractor.h | 10 +-
src/codegen/build_common.h | 2 +-
src/codegen/build_module.cc | 6 +-
src/codegen/codegen_c.cc | 118 ++++++------
src/codegen/codegen_c.h | 26 +--
src/codegen/codegen_c_host.cc | 16 +-
src/codegen/codegen_c_host.h | 2 +-
src/codegen/codegen_cuda.cc | 60 +++---
src/codegen/codegen_cuda.h | 18 +-
src/codegen/codegen_metal.cc | 36 ++--
src/codegen/codegen_metal.h | 6 +-
src/codegen/codegen_opencl.cc | 26 +--
src/codegen/codegen_opencl.h | 10 +-
src/codegen/codegen_opengl.cc | 18 +-
src/codegen/codegen_opengl.h | 4 +-
src/codegen/codegen_source_base.cc | 2 +-
src/codegen/codegen_source_base.h | 4 +-
src/codegen/codegen_vhls.cc | 12 +-
src/codegen/codegen_vhls.h | 2 +-
src/codegen/intrin_rule.cc | 4 +-
src/codegen/intrin_rule.h | 16 +-
src/codegen/intrin_rule_cuda.cc | 8 +-
src/codegen/intrin_rule_opencl.cc | 2 +-
src/codegen/llvm/codegen_amdgpu.cc | 8 +-
src/codegen/llvm/codegen_arm.cc | 45 ++---
src/codegen/llvm/codegen_cpu.cc | 30 +--
src/codegen/llvm/codegen_cpu.h | 4 +-
src/codegen/llvm/codegen_llvm.cc | 106 +++++------
src/codegen/llvm/codegen_llvm.h | 26 +--
src/codegen/llvm/codegen_nvptx.cc | 8 +-
src/codegen/llvm/codegen_x86_64.cc | 30 +--
src/codegen/llvm/intrin_rule_llvm.cc | 12 +-
src/codegen/llvm/intrin_rule_llvm.h | 16 +-
src/codegen/llvm/intrin_rule_nvptx.cc | 6 +-
src/codegen/llvm/intrin_rule_rocm.cc | 4 +-
src/codegen/spirv/codegen_spirv.cc | 72 +++----
src/codegen/spirv/codegen_spirv.h | 8 +-
src/codegen/spirv/intrin_rule_spirv.cc | 4 +-
src/codegen/spirv/ir_builder.cc | 56 +++---
src/codegen/spirv/ir_builder.h | 6 +-
src/codegen/stackvm/codegen_stackvm.cc | 16 +-
src/codegen/stackvm/codegen_stackvm.h | 2 +-
src/contrib/hybrid/codegen_hybrid.cc | 20 +-
src/contrib/hybrid/codegen_hybrid.h | 2 +-
src/lang/attrs.cc | 8 +-
src/lang/buffer.cc | 40 ++--
src/lang/channel.cc | 2 +-
src/lang/expr.cc | 75 +-------
src/lang/expr_operator.cc | 271 ++++++++++++++++-----------
src/lang/ir.cc | 132 ++++++-------
src/lang/tensor.cc | 2 +-
src/node/reflection.cc | 4 +-
src/node/serialization.cc | 6 +-
src/op/compute_op.cc | 14 +-
src/op/cross_thread_reduction.cc | 16 +-
src/op/extern_op.cc | 12 +-
src/op/hybrid_op.cc | 8 +-
src/op/op_util.cc | 8 +-
src/op/placeholder_op.cc | 6 +-
src/op/scan_op.cc | 2 +-
src/op/tensor_compute_op.cc | 6 +-
src/op/tensorize.cc | 14 +-
src/pass/arg_binder.cc | 76 ++++----
src/pass/bound_checker.cc | 30 +--
src/pass/combine_context_call.cc | 4 +-
src/pass/coproc_sync.cc | 12 +-
src/pass/detect_device.cc | 2 +-
src/pass/inject_copy_intrin.cc | 14 +-
src/pass/inject_double_buffer.cc | 26 +--
src/pass/inject_virtual_thread.cc | 24 +--
src/pass/ir_deep_compare.cc | 10 +-
src/pass/ir_mutator.cc | 12 +-
src/pass/ir_util.h | 38 ++--
src/pass/lift_attr_scope.cc | 4 +-
src/pass/loop_partition.cc | 10 +-
src/pass/lower_custom_datatypes.cc | 16 +-
src/pass/lower_intrin.cc | 22 +--
src/pass/lower_thread_allreduce.cc | 16 +-
src/pass/lower_tvm_builtin.cc | 80 ++++----
src/pass/lower_warp_memory.cc | 26 +--
src/pass/make_api.cc | 35 ++--
src/pass/narrow_channel_access.cc | 2 +-
src/pass/rewrite_unsafe_select.cc | 4 +-
src/pass/split_host_device.cc | 8 +-
src/pass/split_pipeline.cc | 14 +-
src/pass/ssa.cc | 12 +-
src/pass/storage_access.cc | 20 +-
src/pass/storage_access.h | 2 +-
src/pass/storage_flatten.cc | 52 ++---
src/pass/storage_rewrite.cc | 78 ++++----
src/pass/storage_sync.cc | 10 +-
src/pass/tensor_core.cc | 110 +++++------
src/pass/unroll_loop.cc | 2 +-
src/pass/vectorize_loop.cc | 82 ++++----
src/pass/verify_gpu_code.cc | 4 +-
src/relay/backend/build_module.cc | 2 +-
src/relay/backend/compile_engine.cc | 36 ++--
src/relay/backend/interpreter.cc | 10 +-
src/relay/backend/utils.h | 11 +-
src/relay/backend/vm/compiler.cc | 2 +-
src/relay/ir/doc.cc | 2 +-
src/relay/ir/doc.h | 6 +-
src/relay/ir/expr.cc | 4 +-
src/relay/ir/pretty_printer.cc | 20 +-
src/relay/ir/type.cc | 2 +-
src/relay/op/memory/memory.cc | 8 +-
src/relay/op/nn/nn.cc | 2 +-
src/relay/op/nn/pad.cc | 4 +-
src/relay/op/nn/upsampling.cc | 4 +-
src/relay/op/tensor/reduce.cc | 7 +-
src/relay/op/tensor/transform.cc | 19 +-
src/relay/op/type_relations.cc | 2 +-
src/relay/op/vision/multibox_op.cc | 2 +-
src/relay/op/vision/nms.cc | 4 +-
src/relay/pass/combine_parallel_conv2d.cc | 2 +-
src/relay/pass/fold_constant.cc | 2 +-
src/relay/pass/fuse_ops.cc | 12 +-
src/relay/pass/partial_eval.cc | 8 +-
src/relay/pass/pattern_util.h | 26 +--
src/relay/pass/quantize/calibrate.cc | 2 +-
src/relay/pass/quantize/quantize.cc | 6 +-
src/relay/pass/quantize/quantize.h | 6 +-
src/relay/pass/quantize/realize.cc | 13 +-
src/relay/pass/simplify_inference.cc | 6 +-
src/relay/pass/type_infer.cc | 2 +-
src/relay/qnn/op/add.cc | 10 +-
src/relay/qnn/op/convolution.cc | 49 ++---
src/relay/qnn/op/dense.cc | 16 +-
src/relay/qnn/op/dequantize.cc | 14 +-
src/relay/qnn/op/mul.cc | 8 +-
src/relay/qnn/op/quantize.cc | 12 +-
src/relay/qnn/op/requantize.cc | 10 +-
src/relay/qnn/util.cc | 2 +-
src/relay/qnn/util.h | 8 +-
src/runtime/contrib/tflite/tflite_runtime.cc | 38 ++--
src/schedule/graph.cc | 6 +-
src/schedule/message_passing.cc | 8 +-
src/schedule/schedule_dataflow_rewrite.cc | 4 +-
src/schedule/schedule_lang.cc | 4 +-
src/schedule/schedule_ops.cc | 4 +-
tests/cpp/attrs_test.cc | 2 +-
tests/cpp/build_module_test.cc | 12 +-
tests/cpp/ir_mutator_test.cc | 2 +-
tests/cpp/packed_func_test.cc | 8 +-
tests/cpp/pattern_match_test.cc | 13 +-
tests/cpp/relay_build_module_test.cc | 2 +-
tests/cpp/relay_pass_type_infer_test.cc | 2 +-
tests/cpp/relay_transform_sequential.cc | 2 +-
tests/cpp/simple_passes_test.cc | 2 +-
tests/cpp/tensor_test.cc | 8 +-
tests/cpp/topi_ewise_test.cc | 2 +-
topi/include/topi/broadcast.h | 4 +-
topi/include/topi/cuda/dense.h | 2 +-
topi/include/topi/detail/broadcast.h | 2 +-
topi/include/topi/detail/extern.h | 16 +-
topi/include/topi/elemwise.h | 14 +-
topi/include/topi/image/resize.h | 56 +++---
topi/include/topi/nn.h | 6 +-
topi/include/topi/nn/bnn.h | 8 +-
topi/include/topi/nn/dense.h | 2 +-
topi/include/topi/nn/dilate.h | 2 +-
topi/include/topi/nn/pooling.h | 88 +++++----
topi/include/topi/reduction.h | 14 +-
topi/include/topi/rocm/dense.h | 2 +-
topi/include/topi/transform.h | 24 +--
topi/src/topi.cc | 6 +-
203 files changed, 2003 insertions(+), 1947 deletions(-)
diff --git a/include/tvm/attrs.h b/include/tvm/attrs.h
index 2fbb9e6..8810c4e 100644
--- a/include/tvm/attrs.h
+++ b/include/tvm/attrs.h
@@ -159,7 +159,7 @@ class AttrsEqual {
bool operator()(const std::string& lhs, const std::string& rhs) const {
return lhs == rhs;
}
- bool operator()(const Type& lhs, const Type& rhs) const {
+ bool operator()(const DataType& lhs, const DataType& rhs) const {
return lhs == rhs;
}
// node comparator
@@ -506,8 +506,8 @@ inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) {
}
}
template<>
-inline void SetValue(Type* ptr, const TVMArgValue& val) {
- *ptr = val.operator Type();
+inline void SetValue(DataType* ptr, const TVMArgValue& val) {
+ *ptr = val.operator DataType();
}
template<>
inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
@@ -611,7 +611,7 @@ struct TypeName<uint64_t> {
};
template<>
-struct TypeName<Type> {
+struct TypeName<DataType> {
static constexpr const char* value = "Type";
};
diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h
index d2c2b40..fac18a9 100644
--- a/include/tvm/buffer.h
+++ b/include/tvm/buffer.h
@@ -74,14 +74,16 @@ class Buffer : public NodeRef {
* \param content_lanes The number of lanes for the (data) type.
* \param offset The offset of ptr.
*/
- TVM_DLL Expr access_ptr(int access_mask, Type ptr_type = Handle(),
- int content_lanes = 1, Expr offset = make_const(Int(32), 0)) const;
+ TVM_DLL Expr access_ptr(int access_mask,
+ DataType ptr_type = DataType::Handle(),
+ int content_lanes = 1,
+ Expr offset = make_const(DataType::Int(32), 0)) const;
/*!
* \brief Create an Expr that does a vector load at begin index.
* \param begin The beginning index
* \param dtype The data type to be loaded.
*/
- TVM_DLL Expr vload(Array<Expr> begin, Type dtype) const;
+ TVM_DLL Expr vload(Array<Expr> begin, DataType dtype) const;
/*!
* \brief Create a Stmt that does a vector store at begin index.
* \param begin The beginning index
@@ -108,7 +110,7 @@ class BufferNode : public Node {
*/
Var data;
/*! \brief data type in the content of the tensor */
- Type dtype;
+ DataType dtype;
/*! \brief The shape of the buffer */
Array<Expr> shape;
/*!
@@ -149,14 +151,14 @@ class BufferNode : public Node {
}
/*! \return preferred index type for this buffer node */
- Type DefaultIndexType() const {
- return shape.size() != 0 ? shape[0].type() : Int(32);
+ DataType DefaultIndexType() const {
+ return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32);
}
// User can specify data_alignment and offset_factor to be 0
// A default value will be picked.
TVM_DLL static Buffer make(Var ptr,
- Type dtype,
+ DataType dtype,
Array<Expr> shape,
Array<Expr> strides,
Expr elem_offset,
@@ -183,7 +185,7 @@ inline const BufferNode* Buffer::operator->() const {
* \sa BufferNode::make for complete constructor.
*/
TVM_DLL Buffer decl_buffer(Array<Expr> shape,
- Type dtype = Float(32),
+ DataType dtype = DataType::Float(32),
std::string name = "buffer");
} // namespace tvm
#endif // TVM_BUFFER_H_
diff --git a/include/tvm/channel.h b/include/tvm/channel.h
index 3a40a78..25ee7a9 100644
--- a/include/tvm/channel.h
+++ b/include/tvm/channel.h
@@ -52,14 +52,14 @@ struct ChannelNode : public Node {
/*! \brief Variable to channel handle */
Var handle_var;
/*! \brief default data type in read/write */
- Type dtype;
+ DataType dtype;
// visit all attributes
void VisitAttrs(AttrVisitor* v) {
v->Visit("handle_var", &handle_var);
v->Visit("dtype", &dtype);
}
- static Channel make(Var handle_var, Type dtype);
+ static Channel make(Var handle_var, DataType dtype);
static constexpr const char* _type_key = "Channel";
TVM_DECLARE_NODE_TYPE_INFO(ChannelNode, Node);
diff --git a/include/tvm/expr.h b/include/tvm/expr.h
index fc52421..f27cb98 100644
--- a/include/tvm/expr.h
+++ b/include/tvm/expr.h
@@ -29,11 +29,11 @@
#include <unordered_map>
#include <iostream>
#include "base.h"
-#include "dtype.h"
#include "node/node.h"
#include "node/container.h"
#include "node/functor.h"
#include "runtime/c_runtime_api.h"
+#include "runtime/data_type.h"
namespace tvm {
@@ -41,7 +41,7 @@ namespace tvm {
class ExprNode : public Node {
public:
/*! \brief The data type of the expression. */
- DataType type;
+ DataType dtype;
static constexpr const char* _type_key = "Expr";
TVM_DECLARE_BASE_NODE_INFO(ExprNode, Node);
@@ -69,8 +69,8 @@ class Expr : public NodeRef {
TVM_DLL Expr(std::string str); // NOLINT(*)
/*! \return the data type of this expression. */
- DataType type() const {
- return static_cast<const ExprNode*>(get())->type;
+ DataType dtype() const {
+ return static_cast<const ExprNode*>(get())->dtype;
}
/*! \brief type indicate the container type */
@@ -113,7 +113,7 @@ class Variable : public ExprNode {
static Var make(DataType dtype, std::string name_hint);
void VisitAttrs(AttrVisitor* v) {
- v->Visit("dtype", &type);
+ v->Visit("dtype", &dtype);
v->Visit("name", &name_hint);
}
@@ -126,14 +126,14 @@ class Var : public Expr {
public:
explicit Var(ObjectPtr<Object> n) : Expr(n) {}
TVM_DLL explicit Var(std::string name_hint = "v",
- Type t = Int(32));
+ DataType t = DataType::Int(32));
/*!
* \brief Make a new copy of var with same type, append suffix
* \param suffix The suffix to be appended.
* \return the new Var copy
*/
Var copy_with_suffix(const std::string& suffix) const {
- return Var((*this)->name_hint + suffix, (*this)->type);
+ return Var((*this)->name_hint + suffix, (*this)->dtype);
}
/*!
* \brief Get pointer to the internal value.
@@ -167,7 +167,7 @@ class IntImm : public ExprNode {
int64_t value;
void VisitAttrs(AttrVisitor* v) {
- v->Visit("dtype", &type);
+ v->Visit("dtype", &dtype);
v->Visit("value", &value);
}
@@ -452,7 +452,7 @@ inline const char* IterVarType2String(IterVarType t) {
* \param name_hint The name hint for the expression
* \param t The type of the expression
*/
-TVM_DLL Var var(std::string name_hint, Type t = Int(32));
+TVM_DLL Var var(std::string name_hint, DataType t = DataType::Int(32));
/*
* \brief Template function to convert Map to unordered_map
diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h
index 625ee8e..41e7aa5 100644
--- a/include/tvm/expr_operator.h
+++ b/include/tvm/expr_operator.h
@@ -44,20 +44,20 @@ namespace tvm {
*/
template<typename ValueType,
typename = typename std::enable_if<std::is_pod<ValueType>::value>::type>
-inline Expr make_const(Type t, ValueType value);
+inline Expr make_const(DataType t, ValueType value);
/*!
* \brief Make a const zero expr.
* \param t The target type.
* \return the result expression.
*/
-inline Expr make_zero(Type t);
+inline Expr make_zero(DataType t);
/*!
* \brief Make a constant true expression.
* \param lanes The number of lanes in the bool
* \return The result expression.
*/
inline Expr const_true(int lanes = 1) {
- return make_const(UInt(1, lanes), 1);
+ return make_const(DataType::UInt(1, lanes), 1);
}
/*!
* \brief Make a constant false expression.
@@ -65,7 +65,7 @@ inline Expr const_true(int lanes = 1) {
* \return The result expression.
*/
inline Expr const_false(int lanes = 1) {
- return make_const(UInt(1, lanes), 0);
+ return make_const(DataType::UInt(1, lanes), 0);
}
/*!
* \brief Get x as constant int expression.
@@ -140,6 +140,20 @@ inline bool is_zero(const Expr& x) {
inline bool is_const(const Expr& x);
/*!
+ * Query the maximum possible value of dtype.
+ * \param dtype The data type.
+ * \return the maximum possible value in this format.
+ */
+TVM_DLL Expr max_value(const DataType& dtype);
+
+/*!
+ * Query the minimum possible value of dtype.
+ * \param dtype The data type.
+ * \return the minimum possible value in this format.
+ */
+TVM_DLL Expr min_value(const DataType& dtype);
+
+/*!
* \brief Check whether x is a constant power of two
* If x is power of two, write the power to the shift.
*
@@ -157,7 +171,7 @@ TVM_DLL bool is_const_power_of_two_integer(const Expr& x, int* shift);
* \return The result expression.
* \note This function may return value if the type is the same.
*/
-TVM_DLL Expr cast(const Type& t, Expr value);
+TVM_DLL Expr cast(const DataType& t, Expr value);
/*!
* \brief perform reinterpret cast value to type.
*
@@ -166,7 +180,7 @@ TVM_DLL Expr cast(const Type& t, Expr value);
* \return The result expression.
* \note This function may return value if the type is the same.
*/
-TVM_DLL Expr reinterpret(const Type& t, Expr value);
+TVM_DLL Expr reinterpret(const DataType& t, Expr value);
/*!
* \brief add operator
*
@@ -586,7 +600,7 @@ TVM_DLL Expr trunc(Expr x);
// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline Expr OpName(Expr x) { \
- return ir::Call::make(x.type(), #OpName, {x}, ir::Call::PureIntrinsic); \
+ return ir::Call::make(x.dtype(), #OpName, {x}, ir::Call::PureIntrinsic); \
} \
TVM_DECLARE_INTRIN_UNARY(exp);
@@ -657,7 +671,7 @@ inline bool is_no_op(const Stmt& stmt) {
}
template<typename ValueType>
-inline Expr MakeConstScalar(Type t, ValueType value) {
+inline Expr MakeConstScalar(DataType t, ValueType value) {
if (t.is_int()) return ir::IntImm::make(t, static_cast<int64_t>(value));
if (t.is_uint()) return ir::UIntImm::make(t, static_cast<uint64_t>(value));
if (t.is_float()) return ir::FloatImm::make(t, static_cast<double>(value));
@@ -672,7 +686,7 @@ inline Expr MakeConstScalar(Type t, ValueType value) {
}
template<typename ValueType, typename>
-inline Expr make_const(Type t, ValueType value) {
+inline Expr make_const(DataType t, ValueType value) {
if (t.lanes() == 1) {
return MakeConstScalar(t, value);
} else {
@@ -681,9 +695,9 @@ inline Expr make_const(Type t, ValueType value) {
}
}
-inline Expr make_zero(Type t) {
+inline Expr make_zero(DataType t) {
if (t.is_handle()) {
- return reinterpret(t, make_const(UInt(64), 0));
+ return reinterpret(t, make_const(DataType::UInt(64), 0));
}
return make_const(t, 0);
}
@@ -703,13 +717,13 @@ inline Expr make_zero(Type t) {
return Name(Expr(a), b); \
} \
inline Expr Name(int a, const Expr& b) { \
- return Name(make_const(b.type(), a), b); \
+ return Name(make_const(b.dtype(), a), b); \
} \
inline Expr Name(const Expr& a, int b) { \
- return Name(a, make_const(a.type(), b)); \
+ return Name(a, make_const(a.dtype(), b)); \
} \
inline Expr Name(const Expr& a, double b) { \
- return Name(a, make_const(Float(64), b)); \
+ return Name(a, make_const(DataType::Float(64), b)); \
}
#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \
@@ -722,10 +736,10 @@ inline Expr make_zero(Type t) {
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \
inline Expr Name(const Expr& a, int b) { \
- return Name(a, make_const(a.type(), b)); \
+ return Name(a, make_const(a.dtype(), b)); \
} \
inline Expr Name(int a, const Expr& b) { \
- return Name(make_const(b.type(), a), b); \
+ return Name(make_const(b.dtype(), a), b); \
}
diff --git a/include/tvm/ir.h b/include/tvm/ir.h
index 53eb94e..33aa72b 100644
--- a/include/tvm/ir.h
+++ b/include/tvm/ir.h
@@ -46,7 +46,7 @@ class UIntImm : public ExprNode {
uint64_t value;
void VisitAttrs(AttrVisitor* v) {
- v->Visit("dtype", &type);
+ v->Visit("dtype", &dtype);
v->Visit("value", &value);
}
@@ -63,7 +63,7 @@ class FloatImm : public ExprNode {
double value;
void VisitAttrs(AttrVisitor* v) {
- v->Visit("dtype", &type);
+ v->Visit("dtype", &dtype);
v->Visit("value", &value);
}
@@ -80,7 +80,7 @@ class StringImm : public ExprNode {
std::string value;
void VisitAttrs(AttrVisitor* v) {
- v->Visit("dtype", &type);
+ v->Visit("dtype", &dtype);
v->Visit("value", &value);
}
@@ -100,7 +100,7 @@ class Cast : public ExprNode {
Expr value;
void VisitAttrs(AttrVisitor* v) {
- v->Visit("dtype", &type);
+ v->Visit("dtype", &dtype);
v->Visit("value", &value);
}
@@ -123,7 +123,7 @@ class BinaryOpNode : public ExprNode {
Expr b;
void VisitAttrs(AttrVisitor* v) {
- v->Visit("dtype", &(this->type));
+ v->Visit("dtype", &(this->dtype));
v->Visit("a", &a);
v->Visit("b", &b);
}
@@ -131,9 +131,9 @@ class BinaryOpNode : public ExprNode {
static Expr make(Expr a, Expr b) {
CHECK(a.defined()) << "ValueError: a is undefined\n";
CHECK(b.defined()) << "ValueError: b is undefined\n";
- CHECK(a.type() == b.type()) << "TypeError: mismatched types\n";
+ CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n";
NodePtr<T> node = make_node<T>();
- node->type = a.type();
+ node->dtype = a.dtype();
node->a = std::move(a);
node->b = std::move(b);
return Expr(node);
@@ -215,7 +215,7 @@ class CmpOpNode : public ExprNode {
Expr b;
void VisitAttrs(AttrVisitor* v) {
- v->Visit("dtype", &(this->type));
+ v->Visit("dtype", &(this->dtype));
v->Visit("a", &a);
v->Visit("b", &b);
}
@@ -223,9 +223,9 @@ class CmpOpNode : public ExprNode {
static Expr make(Expr a, Expr b) {
CHECK(a.defined()) << "ValueError: a is undefined\n";
CHECK(b.defined()) << "ValueError: b is undefined\n";
- CHECK(a.type() == b.type()) << "TypeError: mismatched types\n";
+ CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n";
NodePtr<T> node = make_node<T>();
- node->type = Bool(a.type().lanes());
+ node->dtype = DataType::Bool(a.dtype().lanes());
node->a = std::move(a);
node->b = std::move(b);
return Expr(node);
@@ -279,7 +279,7 @@ class And : public ExprNode {
Expr b;
void VisitAttrs(AttrVisitor* v) {
- v->Visit("dtype", &(this->type));
+ v->Visit("dtype", &(this->dtype));
v->Visit("a", &a);
v->Visit("b", &b);
}
@@ -299,7 +299,7 @@ class Or : public ExprNode {
Expr b;
void VisitAttrs(AttrVisitor* v) {
- v->Visit("dtype", &type);
+ v->Visit("dtype", &dtype);
v->Visit("a", &a);
v->Visit("b", &b);
}
@@ -317,7 +317,7 @@ class Not : public ExprNode {
Expr a;
void VisitAttrs(AttrVisitor* v) {
- v->Visit("dtype", &type);
+ v->Visit("dtype", &dtype);
v->Visit("a", &a);
}
@@ -344,7 +344,7 @@ class Select : public ExprNode {
Expr false_value;
void VisitAttrs(AttrVisitor* v) {
- v->Visit("dtype", &type);
+ v->Visit("dtype", &dtype);
v->Visit("condition", &condition);
v->Visit("true_value", &true_value);
v->Visit("false_value", &false_value);
@@ -381,13 +381,13 @@ class Load : public ExprNode {
Expr predicate;
void VisitAttrs(AttrVisitor* v) {
- v->Visit("dtype", &type);
+ v->Visit("dtype", &dtype);
v->Visit("buffer_var", &buffer_var);
v->Visit("index", &index);
v->Visit("predicate", &predicate);
}
- TVM_DLL static Expr make(DataType type, Var buffer_var, Expr index, Expr predicate);
+ TVM_DLL static Expr make(DataType dtype, Var buffer_var, Expr index, Expr predicate);
static constexpr const char* _type_key = "Load";
TVM_DECLARE_NODE_TYPE_INFO(Load, ExprNode);
@@ -412,7 +412,7 @@ class Ramp : public ExprNode {
int lanes;
void VisitAttrs(AttrVisitor* v) {
- v->Visit("dtype", &type);
+ v->Visit("dtype", &dtype);
v->Visit("base", &base);
v->Visit("stride", &stride);
v->Visit("lanes", &lanes);
@@ -433,7 +433,7 @@ class Broadcast : public ExprNode {
int lanes;
void VisitAttrs(AttrVisitor* v) {
- v->Visit("dtype", &type);
+ v->Visit("dtype", &dtype);
v->Visit("value", &value);
v->Visit("lanes", &lanes);
}
@@ -457,7 +457,7 @@ class Let : public ExprNode {
Expr body;
void VisitAttrs(AttrVisitor* v) {
- v->Visit("dtype", &type);
+ v->Visit("dtype", &dtype);
v->Visit("var", &var);
v->Visit("value", &value);
v->Visit("body", &body);
@@ -523,7 +523,7 @@ class Call : public ExprNode {
int value_index{0};
void VisitAttrs(AttrVisitor* v) {
- v->Visit("dtype", &type);
+ v->Visit("dtype", &dtype);
v->Visit("name", &name);
v->Visit("args", &args);
v->Visit("call_type", &call_type);
@@ -531,7 +531,7 @@ class Call : public ExprNode {
v->Visit("value_index", &value_index);
}
- TVM_DLL static Expr make(DataType type,
+ TVM_DLL static Expr make(DataType dtype,
std::string name,
Array<Expr> args,
CallType call_type,
@@ -695,7 +695,7 @@ class Reduce : public ExprNode {
int value_index);
void VisitAttrs(AttrVisitor* v) {
- v->Visit("dtype", &type);
+ v->Visit("dtype", &dtype);
v->Visit("combiner", &combiner);
v->Visit("source", &source);
v->Visit("axis", &axis);
@@ -713,7 +713,7 @@ class Any : public ExprNode {
void VisitAttrs(AttrVisitor* v) {}
/*! \brief Convert to var. */
Var ToVar() const {
- return Variable::make(Int(32), "any_dim");
+ return Variable::make(DataType::Int(32), "any_dim");
}
TVM_DLL static Expr make();
@@ -917,7 +917,7 @@ class Allocate : public StmtNode {
/*! \brief The buffer variable. */
Var buffer_var;
/*! \brief The type of the buffer. */
- DataType type;
+ DataType dtype;
/*! \brief The extents of the buffer. */
Array<Expr> extents;
/*! \brief Only allocate buffer when condition is satisfied. */
@@ -931,14 +931,14 @@ class Allocate : public StmtNode {
void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer_var", &buffer_var);
- v->Visit("dtype", &type);
+ v->Visit("dtype", &dtype);
v->Visit("extents", &extents);
v->Visit("condition", &condition);
v->Visit("body", &body);
}
TVM_DLL static Stmt make(Var buffer_var,
- DataType type,
+ DataType dtype,
Array<Expr> extents,
Expr condition,
Stmt body,
@@ -993,7 +993,7 @@ class Realize : public StmtNode {
/*! \brief The output value index if func's value is a tuple. */
int value_index;
/*! \brief The data type of the array. */
- DataType type;
+ DataType dtype;
/*! \brief Bounds to be realized. */
Region bounds;
/*! \brief Only realize if condition holds. */
@@ -1004,7 +1004,7 @@ class Realize : public StmtNode {
void VisitAttrs(AttrVisitor* v) {
v->Visit("func", &func);
v->Visit("value_index", &value_index);
- v->Visit("dtype", &type);
+ v->Visit("dtype", &dtype);
v->Visit("bounds", &bounds);
v->Visit("condition", &condition);
v->Visit("body", &body);
@@ -1012,7 +1012,7 @@ class Realize : public StmtNode {
TVM_DLL static Stmt make(FunctionRef func,
int value_index,
- DataType type,
+ DataType dtype,
Region bounds,
Expr condition,
Stmt body);
@@ -1165,20 +1165,20 @@ class Prefetch : public StmtNode {
/*! \brief The output value index if func's value is a tuple. */
int value_index;
/*! \brief The data type of the array. */
- DataType type;
+ DataType dtype;
/*! \brief Bounds to be prefetched. */
Region bounds;
void VisitAttrs(AttrVisitor* v) {
v->Visit("func", &func);
v->Visit("value_index", &value_index);
- v->Visit("type", &type);
+ v->Visit("dtype", &dtype);
v->Visit("bounds", &bounds);
}
TVM_DLL static Stmt make(FunctionRef func,
int value_index,
- DataType type,
+ DataType dtype,
Region bounds);
static constexpr const char* _type_key = "Prefetch";
@@ -1620,7 +1620,7 @@ constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync";
* \param dtype The data type
* \return Expr a expression with dtype.
*/
-inline Expr TypeAnnotation(Type dtype) {
+inline Expr TypeAnnotation(DataType dtype) {
return ir::Call::make(dtype,
"type_annotation", {},
ir::Call::PureIntrinsic);
diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h
index 35a8e1d..daffeb8 100644
--- a/include/tvm/node/reflection.h
+++ b/include/tvm/node/reflection.h
@@ -28,6 +28,7 @@
#include <tvm/runtime/memory.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/data_type.h>
#include <vector>
#include <string>
@@ -35,8 +36,6 @@
namespace tvm {
// forward declaration
-class DataType;
-
using runtime::Object;
using runtime::ObjectPtr;
using runtime::ObjectRef;
diff --git a/include/tvm/operation.h b/include/tvm/operation.h
index f53c1ce..34f584b 100644
--- a/include/tvm/operation.h
+++ b/include/tvm/operation.h
@@ -75,7 +75,7 @@ class OperationNode : public ir::FunctionBaseNode {
* \param i The output index.
* \return type of i-th output.
*/
- virtual Type output_dtype(size_t i) const = 0;
+ virtual DataType output_dtype(size_t i) const = 0;
/*!
* \brief Get shape of i-th output tensor.
* \param i The output index.
@@ -160,11 +160,11 @@ class PlaceholderOpNode : public OperationNode {
/*! \brief The shape of the input */
Array<Expr> shape;
/*! \brief The data type of the input. */
- Type dtype;
+ DataType dtype;
// override behavior.
int num_outputs() const final;
Array<IterVar> root_iter_vars() const final;
- Type output_dtype(size_t i) const final;
+ DataType output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(
@@ -197,7 +197,7 @@ class PlaceholderOpNode : public OperationNode {
}
static Operation make(std::string name,
Array<Expr> shape,
- Type dtype);
+ DataType dtype);
static constexpr const char* _type_key = "PlaceholderOp";
TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode, OperationNode);
@@ -243,7 +243,7 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
ComputeOpNode() {}
// override functions
int num_outputs() const final;
- Type output_dtype(size_t i) const final;
+ DataType output_dtype(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(
const Operation& self,
@@ -296,7 +296,7 @@ class TensorComputeOpNode : public BaseComputeOpNode {
TensorComputeOpNode() {}
// override functions
int num_outputs() const final;
- Type output_dtype(size_t i) const final;
+ DataType output_dtype(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(
const Operation& self,
@@ -370,7 +370,7 @@ class ScanOpNode : public OperationNode {
// override behavior.
int num_outputs() const final;
Array<IterVar> root_iter_vars() const final;
- Type output_dtype(size_t i) const final;
+ DataType output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(
@@ -437,7 +437,7 @@ class ExternOpNode : public OperationNode {
// override functions
int num_outputs() const final;
Array<IterVar> root_iter_vars() const final;
- Type output_dtype(size_t i) const final;
+ DataType output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(
@@ -505,7 +505,7 @@ class HybridOpNode : public OperationNode {
// override functions
int num_outputs() const final;
Array<IterVar> root_iter_vars() const final;
- Type output_dtype(size_t i) const final;
+ DataType output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(
@@ -562,7 +562,7 @@ using FBatchCompute = std::function<Array<Expr> (const Array<Var>& i)>;
* \param name The name of the Tensor.
*/
TVM_DLL Tensor placeholder(Array<Expr> shape,
- Type dtype = Float(32),
+ DataType dtype = DataType::Float(32),
std::string name = "placeholder");
/*!
diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h
index 71f8f55..93b7ac3 100644
--- a/include/tvm/packed_func_ext.h
+++ b/include/tvm/packed_func_ext.h
@@ -208,23 +208,6 @@ inline TObjectRef TVMRetValue::AsObjectRef() const {
return TObjectRef(ObjectPtr<Object>(ptr));
}
-// type related stuffs
-inline TVMRetValue& TVMRetValue::operator=(const DataType& t) {
- return this->operator=(t.operator DLDataType());
-}
-
-inline TVMRetValue::operator tvm::DataType() const {
- return DataType(operator DLDataType());
-}
-
-inline TVMArgValue::operator tvm::DataType() const {
- return DataType(operator DLDataType());
-}
-
-inline void TVMArgsSetter::operator()(
- size_t i, const DataType& t) const {
- this->operator()(i, t.operator DLDataType());
-}
} // namespace runtime
} // namespace tvm
#endif // TVM_PACKED_FUNC_EXT_H_
diff --git a/include/tvm/relay/attrs/memory.h b/include/tvm/relay/attrs/memory.h
index 2e279a5..c74b648 100644
--- a/include/tvm/relay/attrs/memory.h
+++ b/include/tvm/relay/attrs/memory.h
@@ -43,7 +43,7 @@ struct AllocTensorAttrs : public tvm::AttrsNode<AllocTensorAttrs> {
TVM_ATTR_FIELD(dtype)
.describe(
"The dtype of the tensor to allocate.")
- .set_default(Float(32, 1));
+ .set_default(DataType::Float(32, 1));
TVM_ATTR_FIELD(const_shape)
.describe(
"The shape of constant used to aid in type inference.");
diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h
index 42a01f0..32f9c32 100644
--- a/include/tvm/relay/base.h
+++ b/include/tvm/relay/base.h
@@ -63,7 +63,7 @@ using NodeRef = tvm::NodeRef;
/*!
* \brief Content data type.
*/
-using DataType = ::tvm::Type;
+using DataType = ::tvm::DataType;
/*!
* \brief Symbolic expression for tensor shape.
diff --git a/include/tvm/dtype.h b/include/tvm/runtime/data_type.h
similarity index 57%
rename from include/tvm/dtype.h
rename to include/tvm/runtime/data_type.h
index 9f7902d..5b222ac 100644
--- a/include/tvm/dtype.h
+++ b/include/tvm/runtime/data_type.h
@@ -17,23 +17,35 @@
* under the License.
*/
/*
- * \file tvm/dtype.h
- * \brief Data type used in IR.
+ * \file tvm/runtime/data_type.h
+ * \brief Primitive runtime data type.
*/
// Acknowledgement: DataType structure design originates from Halide.
-#ifndef TVM_DTYPE_H_
-#define TVM_DTYPE_H_
+#ifndef TVM_RUNTIME_DATA_TYPE_H_
+#define TVM_RUNTIME_DATA_TYPE_H_
-#include "runtime/packed_func.h"
+#include <tvm/runtime/c_runtime_api.h>
+#include <dmlc/logging.h>
+#include <type_traits>
-namespace tvm {
-class Expr;
+namespace tvm {
+namespace runtime {
/*!
- * \brief Primitive data types in tvm.
+ * \brief Runtime primitive data type.
+ *
+ * This class is a thin wrapper of DLDataType.
+ * We also make use of DataType in the compiler to store quick hints.
*/
class DataType {
public:
+ /*! \brief Type code for the DataType. */
+ enum TypeCode {
+ kInt = kDLInt,
+ kUInt = kDLUInt,
+ kFloat = kDLFloat,
+ kHandle = TVMTypeCode::kHandle,
+ };
/*! \brief default constructor */
DataType() {}
/*!
@@ -75,23 +87,23 @@ class DataType {
}
/*! \return whether type is a scalar type. */
bool is_bool() const {
- return code() == kDLUInt && bits() == 1;
+ return code() == DataType::kUInt && bits() == 1;
}
/*! \return whether type is a float type. */
bool is_float() const {
- return code() == kDLFloat;
+ return code() == DataType::kFloat;
}
/*! \return whether type is an int type. */
bool is_int() const {
- return code() == kDLInt;
+ return code() == DataType::kInt;
}
/*! \return whether type is an uint type. */
bool is_uint() const {
- return code() == kDLUInt;
+ return code() == DataType::kUInt;
}
/*! \return whether type is a handle type. */
bool is_handle() const {
- return code() == kHandle;
+ return code() == DataType::kHandle;
}
/*! \return whether type is a vector type. */
bool is_vector() const {
@@ -120,108 +132,94 @@ class DataType {
DataType element_of() const {
return with_lanes(1);
}
- // operator overloadings
+ /*!
+ * \brief Equal comparator.
+ * \param other The data type to compare against.
+ * \return The comparison result.
+ */
bool operator==(const DataType& other) const {
return
data_.code == other.data_.code &&
data_.bits == other.data_.bits &&
data_.lanes == other.data_.lanes;
}
+ /*!
+ * \brief NotEqual comparator.
+ * \param other The data type to compare against.
+ * \return The comparison result.
+ */
bool operator!=(const DataType& other) const {
return !operator==(other);
}
+ /*!
+ * \brief Converter to DLDataType
+ * \return the result.
+ */
operator DLDataType () const {
return data_;
}
- /*! \return the maximum possible value in this format. */
- TVM_DLL Expr max() const;
- /*! \return the minimum possible value in this format. */
- TVM_DLL Expr min() const;
+
+ /*!
+ * \brief Construct an int type.
+ * \param bits The number of bits in the type.
+ * \param lanes The number of lanes.
+ * \return The constructed data type.
+ */
+ static DataType Int(int bits, int lanes = 1) {
+ return DataType(kDLInt, bits, lanes);
+ }
+ /*!
+ * \brief Construct a uint type.
+ * \param bits The number of bits in the type.
+ * \param lanes The number of lanes.
+ * \return The constructed data type.
+ */
+ static DataType UInt(int bits, int lanes = 1) {
+ return DataType(kDLUInt, bits, lanes);
+ }
+ /*!
+ * \brief Construct a float type.
+ * \param bits The number of bits in the type.
+ * \param lanes The number of lanes.
+ * \return The constructed data type.
+ */
+ static DataType Float(int bits, int lanes = 1) {
+ return DataType(kDLFloat, bits, lanes);
+ }
+ /*!
+ * \brief Construct a bool type.
+ * \param lanes The number of lanes
+ * \return The constructed data type.
+ */
+ static DataType Bool(int lanes = 1) {
+ return DataType::UInt(1, lanes);
+ }
+ /*!
+ * \brief Construct a handle type.
+ * \param bits The number of bits in the type.
+ * \param lanes The number of lanes.
+ * \return The constructed data type.
+ */
+ static DataType Handle(int bits = 64, int lanes = 1) {
+ return DataType(kHandle, bits, lanes);
+ }
+ /*!
+ * \brief Get the corresponding type of TVMShapeIndex.
+ * \return The type of TVM shape index.
+ */
+ static DataType ShapeIndex() {
+ if (std::is_signed<tvm_index_t>::value) {
+ return DataType::Int(sizeof(tvm_index_t) * 8);
+ } else {
+ return DataType::UInt(sizeof(tvm_index_t) * 8);
+ }
+ }
private:
DLDataType data_;
};
/*!
- * \brief Construct an int type.
- * \param bits The number of bits in the type.
- * \param lanes The number of lanes.
- * \return The constructed data type.
- */
-inline DataType Int(int bits, int lanes = 1) {
- return DataType(kDLInt, bits, lanes);
-}
-
-/*!
- * \brief Construct an uint type.
- * \param bits The number of bits in the type.
- * \param lanes The number of lanes
- * \return The constructed data type.
- */
-inline DataType UInt(int bits, int lanes = 1) {
- return DataType(kDLUInt, bits, lanes);
-}
-
-/*!
- * \brief Construct a bool type.
- * \param lanes The number of lanes
- * \return The constructed data type.
- */
-inline DataType Bool(int lanes = 1) {
- return UInt(1, lanes);
-}
-
-/*!
- * \brief Construct an uint type.
- * \param bits The number of bits in the type.
- * \param lanes The number of lanes
- * \return The constructed data type.
- */
-inline DataType Float(int bits, int lanes = 1) {
- return DataType(kDLFloat, bits, lanes);
-}
-
-/*!
- * \brief Construct a handle type.
- * \param bits The number of bits in the type.
- * \param lanes The number of lanes
- * \return The constructed data type.
- */
-inline DataType Handle(int bits = 64, int lanes = 1) {
- return DataType(kHandle, bits, lanes);
-}
-
-/*!
- * \brief Get the corresponding type of TVMShapeIndex.
- * \return The type of TVM shape index.
- */
-inline DataType TVMShapeIndexType() {
- if (std::is_signed<tvm_index_t>::value) {
- return Int(sizeof(tvm_index_t) * 8);
- } else {
- return UInt(sizeof(tvm_index_t) * 8);
- }
-}
-
-/*!
- * \brief Convert DLDataType to DataType.
- * \param t The original type.
- * \return The conversion result.
- */
-inline DataType TVMType2Type(DLDataType t) {
- return DataType(t.code, t.bits, t.lanes);
-}
-
-/*!
- * \brief Convert DataType to DataType.
- * \param t The original type.
- * \return The conversion result.
- */
-inline DLDataType Type2TVMType(DataType t) {
- return t.operator DLDataType();
-}
-
-/*!
* \brief Get the number of bytes needed in a vector.
* \param dtype The data type.
* \return Number of bytes needed.
@@ -229,19 +227,15 @@ inline DLDataType Type2TVMType(DataType t) {
inline int GetVectorBytes(DataType dtype) {
int data_bits = dtype.bits() * dtype.lanes();
// allow bool to exist
- if (dtype == Bool()) return 1;
+ if (dtype == DataType::Bool()) return 1;
CHECK_EQ(data_bits % 8, 0U)
<< "Need to load/store by multiple of bytes";
return data_bits / 8;
}
-// Overload print function.
-inline std::ostream& operator<<(std::ostream& os, DataType dtype) { // NOLINT(*)
- using namespace tvm::runtime;
- return os << dtype.operator DLDataType();
-}
+} // namespace runtime
+
+using DataType = runtime::DataType;
-// Backward compatibility
-using Type = DataType;
} // namespace tvm
-#endif // TVM_DTYPE_H_
+#endif // TVM_RUNTIME_DATA_TYPE_H_
diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h
index 57c4291..1d7db66 100644
--- a/include/tvm/runtime/packed_func.h
+++ b/include/tvm/runtime/packed_func.h
@@ -28,6 +28,11 @@
#include <sstream>
#endif
#include <dmlc/logging.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/data_type.h>
+#include <tvm/runtime/object.h>
#include <functional>
#include <tuple>
#include <vector>
@@ -36,10 +41,7 @@
#include <memory>
#include <utility>
#include <type_traits>
-#include "c_runtime_api.h"
-#include "module.h"
-#include "ndarray.h"
-#include "object.h"
+
// Whether use TVM runtime in header only mode.
#ifndef TVM_RUNTIME_HEADER_ONLY
@@ -49,7 +51,6 @@
namespace tvm {
// forward declarations
class Integer;
-class DataType;
class Expr;
namespace runtime {
@@ -629,7 +630,7 @@ class TVMArgValue : public TVMPODValue_ {
typename = typename std::enable_if<
std::is_class<T>::value>::type>
inline operator T() const;
- inline operator tvm::DataType() const;
+ inline operator DataType() const;
inline operator tvm::Expr() const;
inline operator tvm::Integer() const;
};
@@ -834,8 +835,8 @@ class TVMRetValue : public TVMPODValue_ {
template<typename TObjectRef>
inline TObjectRef AsObjectRef() const;
// type related
- inline operator tvm::DataType() const;
- inline TVMRetValue& operator=(const tvm::DataType& other);
+ inline operator DataType() const;
+ inline TVMRetValue& operator=(const DataType& other);
private:
template<typename T>
@@ -1048,6 +1049,10 @@ inline TVMType String2TVMType(std::string s) {
return t;
}
+inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*)
+ return os << dtype.operator DLDataType();
+}
+
inline TVMArgValue TVMArgs::operator[](int i) const {
CHECK_LT(i, num_args)
<< "not enough argument passed, "
@@ -1198,7 +1203,7 @@ class TVMArgsSetter {
typename = typename std::enable_if<
extension_type_info<T>::code != 0>::type>
inline void operator()(size_t i, const T& value) const;
- inline void operator()(size_t i, const tvm::DataType& t) const;
+ inline void operator()(size_t i, const DataType& t) const;
private:
/*! \brief The values fields */
@@ -1362,6 +1367,24 @@ inline void TVMArgsSetter::operator()(size_t i, const T& value) const {
values_[i].v_handle = const_cast<T*>(&value);
}
+// PackedFunc support
+inline TVMRetValue& TVMRetValue::operator=(const DataType& t) {
+ return this->operator=(t.operator DLDataType());
+}
+
+inline TVMRetValue::operator DataType() const {
+ return DataType(operator DLDataType());
+}
+
+inline TVMArgValue::operator DataType() const {
+ return DataType(operator DLDataType());
+}
+
+inline void TVMArgsSetter::operator()(
+ size_t i, const DataType& t) const {
+ this->operator()(i, t.operator DLDataType());
+}
+
// extension type handling
template<typename T>
struct ExtTypeInfo {
diff --git a/include/tvm/tensor.h b/include/tvm/tensor.h
index 599d6ff..f44498a 100644
--- a/include/tvm/tensor.h
+++ b/include/tvm/tensor.h
@@ -163,7 +163,7 @@ class TensorNode : public Node {
/*! \brief The shape of the tensor */
Array<Expr> shape;
/*! \brief data type in the content of the tensor */
- Type dtype;
+ DataType dtype;
/*! \brief the source operation, can be None */
Operation op;
/*! \brief the output index from source operation */
@@ -178,7 +178,7 @@ class TensorNode : public Node {
v->Visit("value_index", &value_index);
}
TVM_DLL static Tensor make(Array<Expr> shape,
- Type dtype,
+ DataType dtype,
Operation op,
int value_index);
diff --git a/nnvm/include/nnvm/compiler/util.h b/nnvm/include/nnvm/compiler/util.h
index f108ff1..63d0655 100644
--- a/nnvm/include/nnvm/compiler/util.h
+++ b/nnvm/include/nnvm/compiler/util.h
@@ -41,7 +41,7 @@ namespace compiler {
inline tvm::Array<tvm::Expr> ShapeToArray(TShape shape) {
tvm::Array<tvm::Expr> result;
for (auto i : shape) {
- result.push_back(tvm::make_const(tvm::Int(32), i));
+ result.push_back(tvm::make_const(tvm::DataType::Int(32), i));
}
return result;
}
diff --git a/nnvm/src/compiler/alter_op_layout.cc b/nnvm/src/compiler/alter_op_layout.cc
index abc0022..8a6694f 100644
--- a/nnvm/src/compiler/alter_op_layout.cc
+++ b/nnvm/src/compiler/alter_op_layout.cc
@@ -46,7 +46,7 @@ tvm::Array<tvm::Tensor> GetTensorInfo(const IndexedGraph& idx_graph,
tvm::Array<tvm::Expr> shape;
for (int64_t x : shape_vec[idx_graph.entry_id(nid, i)]) {
CHECK_LE(x, static_cast<int64_t>(std::numeric_limits<int>::max()));
- shape.push_back(tvm::make_const(tvm::Int(32), x));
+ shape.push_back(tvm::make_const(tvm::DataType::Int(32), x));
}
vec.push_back(tvm::placeholder(
shape, GetTVMType(dtype_vec[idx_graph.entry_id(nid, i)])));
diff --git a/nnvm/src/compiler/compile_engine.cc b/nnvm/src/compiler/compile_engine.cc
index 67af852..82d8ff3 100644
--- a/nnvm/src/compiler/compile_engine.cc
+++ b/nnvm/src/compiler/compile_engine.cc
@@ -47,52 +47,52 @@ using namespace tvm;
* \param type the tvm type.
* \return corresponding DLDataType
*/
-int GetTypeFlag(tvm::Type type) {
- if (type == tvm::Float(32)) return 0;
- if (type == tvm::Float(64)) return 1;
- if (type == tvm::Float(16)) return 2;
- if (type == tvm::UInt(8)) return 3;
- if (type == tvm::Int(32)) return 4;
- if (type == tvm::Int(8)) return 5;
- if (type == tvm::Int(64)) return 6;
- if (type == tvm::Int(16)) return 7;
- if (type == tvm::UInt(16)) return 8;
- if (type == tvm::UInt(32)) return 9;
- if (type == tvm::UInt(64)) return 10;
- if (type == tvm::UInt(1)) return 11;
+int GetTypeFlag(tvm::DataType type) {
+ if (type == tvm::DataType::Float(32)) return 0;
+ if (type == tvm::DataType::Float(64)) return 1;
+ if (type == tvm::DataType::Float(16)) return 2;
+ if (type == tvm::DataType::UInt(8)) return 3;
+ if (type == tvm::DataType::Int(32)) return 4;
+ if (type == tvm::DataType::Int(8)) return 5;
+ if (type == tvm::DataType::Int(64)) return 6;
+ if (type == tvm::DataType::Int(16)) return 7;
+ if (type == tvm::DataType::UInt(16)) return 8;
+ if (type == tvm::DataType::UInt(32)) return 9;
+ if (type == tvm::DataType::UInt(64)) return 10;
+ if (type == tvm::DataType::UInt(1)) return 11;
LOG(FATAL) << "cannot convert " << type;
return 0;
}
// convert from type flag to tvm type.
-Type GetTVMType(int type_flag) {
+DataType GetTVMType(int type_flag) {
switch (type_flag) {
case 0:
- return tvm::Float(32);
+ return tvm::DataType::Float(32);
case 1:
- return tvm::Float(64);
+ return tvm::DataType::Float(64);
case 2:
- return tvm::Float(16);
+ return tvm::DataType::Float(16);
case 3:
- return tvm::UInt(8);
+ return tvm::DataType::UInt(8);
case 4:
- return tvm::Int(32);
+ return tvm::DataType::Int(32);
case 5:
- return tvm::Int(8);
+ return tvm::DataType::Int(8);
case 6:
- return tvm::Int(64);
+ return tvm::DataType::Int(64);
case 7:
- return tvm::Int(16);
+ return tvm::DataType::Int(16);
case 8:
- return tvm::UInt(16);
+ return tvm::DataType::UInt(16);
case 9:
- return tvm::UInt(32);
+ return tvm::DataType::UInt(32);
case 10:
- return tvm::UInt(64);
+ return tvm::DataType::UInt(64);
case 11:
- return tvm::UInt(1);
+ return tvm::DataType::UInt(1);
default:
LOG(FATAL) << "unknown type_flag=" << type_flag;
- return Float(32);
+ return DataType::Float(32);
}
}
@@ -216,7 +216,7 @@ class CompileEngine {
Array<Expr> shape;
for (int64_t x : shape_vec[idx.entry_id(nid, i)]) {
CHECK_LE(x, static_cast<int64_t>(std::numeric_limits<int>::max()));
- shape.push_back(make_const(Int(32), x));
+ shape.push_back(make_const(DataType::Int(32), x));
}
out_info.push_back(
placeholder(shape,
diff --git a/nnvm/src/compiler/compile_engine.h b/nnvm/src/compiler/compile_engine.h
index 8151f6c..b4fec10 100644
--- a/nnvm/src/compiler/compile_engine.h
+++ b/nnvm/src/compiler/compile_engine.h
@@ -117,7 +117,7 @@ GraphFunc GraphLower(Graph graph,
* \param type the tvm type
* \return corresponding DLDataType
*/
-int GetTypeFlag(tvm::Type type);
+int GetTypeFlag(tvm::DataType type);
/*!
* \brief Get TVM Type from type flag
@@ -125,7 +125,7 @@ int GetTypeFlag(tvm::Type type);
* \param type_flag the type flag
* \return corresponding TVM type
*/
-tvm::Type GetTVMType(int type_flag);
+tvm::DataType GetTVMType(int type_flag);
} // namespace compiler
} // namespace nnvm
diff --git a/nnvm/src/compiler/graph_fuse.cc b/nnvm/src/compiler/graph_fuse.cc
index 1b4a8e1..f6c1332 100644
--- a/nnvm/src/compiler/graph_fuse.cc
+++ b/nnvm/src/compiler/graph_fuse.cc
@@ -352,17 +352,17 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
prod *= x;
}
CHECK_LE(prod, static_cast<int64_t>(std::numeric_limits<int>::max()));
- shape.push_back(make_const(Int(32), prod));
+ shape.push_back(make_const(DataType::Int(32), prod));
} else {
for (int64_t x : shape_vec[idx.entry_id(e)]) {
CHECK_LE(x, static_cast<int64_t>(std::numeric_limits<int>::max()));
- shape.push_back(make_const(Int(32), x));
+ shape.push_back(make_const(DataType::Int(32), x));
}
}
std::ostringstream os_name;
os_name << "input" << fe.imap.size();
Tensor data = placeholder(
- shape, TVMType2Type(GetDLType(dtype_vec[idx.entry_id(e)])),
+ shape, DataType(GetDLType(dtype_vec[idx.entry_id(e)])),
os_name.str());
NodeEntry garg = Symbol::CreateVariable(os_name.str()).outputs[0];
fe.imap[e] = garg;
diff --git a/nnvm/src/compiler/graph_fuse.h b/nnvm/src/compiler/graph_fuse.h
index ce7da82..dd8d5d5 100644
--- a/nnvm/src/compiler/graph_fuse.h
+++ b/nnvm/src/compiler/graph_fuse.h
@@ -47,7 +47,7 @@ enum class FuseRule {
* \return corresponding DLDataType
*/
inline DLDataType GetDLType(int type_flag) {
- return tvm::Type2TVMType(GetTVMType(type_flag));
+ return GetTVMType(type_flag);
}
struct INodeEntryHash {
diff --git a/nnvm/src/top/nn/nn.cc b/nnvm/src/top/nn/nn.cc
index 63b2d45..1864ccd 100644
--- a/nnvm/src/top/nn/nn.cc
+++ b/nnvm/src/top/nn/nn.cc
@@ -631,11 +631,11 @@ NNVM_REGISTER_OP(pad)
<< "Illegal pad_width";
Array<tvm::Expr> pad_before;
for (size_t i = 0; i < pad_width.ndim(); ++i) {
- pad_before.push_back(tvm::make_const(tvm::Int(32), pad_width[i][0]));
+ pad_before.push_back(tvm::make_const(tvm::DataType::Int(32), pad_width[i][0]));
}
Array<tvm::Expr> pad_after;
for (size_t i = 0; i < pad_width.ndim(); ++i) {
- pad_after.push_back(tvm::make_const(tvm::Int(32), pad_width[i][1]));
+ pad_after.push_back(tvm::make_const(tvm::DataType::Int(32), pad_width[i][1]));
}
return Array<Tensor>{ topi::pad(inputs[0], pad_before, pad_after,
tvm::make_const(inputs[0]->dtype, param.pad_value)) };
diff --git a/nnvm/src/top/tensor/elemwise.cc b/nnvm/src/top/tensor/elemwise.cc
index 7a79db0..5ac6d91 100644
--- a/nnvm/src/top/tensor/elemwise.cc
+++ b/nnvm/src/top/tensor/elemwise.cc
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -482,7 +482,7 @@ NNVM_REGISTER_INIT_OP(full)
const Array<Tensor>& out_info) {
const InitOpWithScalarParam& param = nnvm::get<InitOpWithScalarParam>(attrs.parsed);
Array<Expr> shape = ShapeToArray(param.shape);
- Type dtype = GetTVMType(param.dtype);
+ DataType dtype = GetTVMType(param.dtype);
Expr fill_value = tvm::make_const(dtype, param.fill_value);
return Array<Tensor>{ topi::full(shape, dtype, fill_value) };
})
@@ -505,7 +505,7 @@ NNVM_REGISTER_INIT_OP(zeros)
const Array<Tensor>& out_info) {
const InitOpParam& param = nnvm::get<InitOpParam>(attrs.parsed);
Array<Expr> shape = ShapeToArray(param.shape);
- Type dtype = GetTVMType(param.dtype);
+ DataType dtype = GetTVMType(param.dtype);
Expr fill_value = tvm::make_const(dtype, 0);
return Array<Tensor>{ topi::full(shape, dtype, fill_value) };
})
@@ -528,7 +528,7 @@ NNVM_REGISTER_INIT_OP(ones)
const Array<Tensor>& out_info) {
const InitOpParam& param = nnvm::get<InitOpParam>(attrs.parsed);
Array<Expr> shape = ShapeToArray(param.shape);
- Type dtype = GetTVMType(param.dtype);
+ DataType dtype = GetTVMType(param.dtype);
Expr fill_value = tvm::make_const(dtype, 1);
return Array<Tensor>{ topi::full(shape, dtype, fill_value) };
})
@@ -950,8 +950,8 @@ Example::
const Array<Tensor>& out_info) {
const ClipParam params = get<ClipParam>(attrs.parsed);
return Array<Tensor>{
- topi::clip(inputs[0], tvm::make_const(tvm::Float(32), params.a_min),
- tvm::make_const(tvm::Float(32), params.a_max)) };
+ topi::clip(inputs[0], tvm::make_const(tvm::DataType::Float(32), params.a_min),
+ tvm::make_const(tvm::DataType::Float(32), params.a_max)) };
})
.add_argument("data", "NDArray-or-Symbol", "Input array.")
.add_arguments(ClipParam::__FIELDS__())
diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc
index 8b85c4e..a83f447 100644
--- a/nnvm/src/top/tensor/transform.cc
+++ b/nnvm/src/top/tensor/transform.cc
@@ -477,7 +477,7 @@ NNVM_REGISTER_OP(cast)
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const CastParam& param = nnvm::get<CastParam>(attrs.parsed);
- Type dtype = GetTVMType(param.dtype);
+ DataType dtype = GetTVMType(param.dtype);
return Array<Tensor>{ topi::cast(inputs[0], dtype) };
})
.set_support_level(1);
@@ -1261,8 +1261,8 @@ NNVM_REGISTER_OP(slice_like)
Array<Expr> target_shape = inputs[1]->shape;
Array<Expr> begin_idx, end_idx, strides;
for (size_t i = 0; i < src_shape.size(); ++i) {
- begin_idx.push_back(make_const(tvm::Int(32), 0));
- strides.push_back(make_const(tvm::Int(32), 1));
+ begin_idx.push_back(make_const(tvm::DataType::Int(32), 0));
+ strides.push_back(make_const(tvm::DataType::Int(32), 1));
}
end_idx = Array<Expr>(src_shape);
if (param.axis.ndim() == 0) {
diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc
index 9312c55..03f37b1 100644
--- a/src/api/api_ir.cc
+++ b/src/api/api_ir.cc
@@ -30,7 +30,7 @@ namespace tvm {
namespace ir {
TVM_REGISTER_API("_Var")
-.set_body_typed<VarExpr(std::string, Type)>([](std::string s, Type t) {
+.set_body_typed<VarExpr(std::string, DataType)>([](std::string s, DataType t) {
return Variable::make(t, s);
});
@@ -75,7 +75,7 @@ TVM_REGISTER_API("make.For")
TVM_REGISTER_API("make.Load")
.set_body([](TVMArgs args, TVMRetValue *ret) {
- Type t = args[0];
+ DataType t = args[0];
if (args.size() == 3) {
*ret = Load::make(t, args[1], args[2], const_true(t.lanes()));
} else {
@@ -87,7 +87,7 @@ TVM_REGISTER_API("make.Store")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Expr value = args[1];
if (args.size() == 3) {
- *ret = Store::make(args[0], value, args[2], const_true(value.type().lanes()));
+ *ret = Store::make(args[0], value, args[2], const_true(value.dtype().lanes()));
} else {
*ret = Store::make(args[0], value, args[2], args[3]);
}
@@ -97,8 +97,8 @@ TVM_REGISTER_API("make.Realize")
.set_body_typed(Realize::make);
TVM_REGISTER_API("make.Call")
-.set_body_typed<Expr(Type, std::string, Array<Expr>, int, FunctionRef, int)>([](
- Type type, std::string name,
+.set_body_typed<Expr(DataType, std::string, Array<Expr>, int, FunctionRef, int)>([](
+ DataType type, std::string name,
Array<Expr> args, int call_type,
FunctionRef func, int value_index
) {
@@ -166,8 +166,8 @@ TVM_REGISTER_API("make.Block")
// has default args
TVM_REGISTER_API("make.Allocate")
- .set_body_typed<Stmt(VarExpr, Type, Array<Expr>, Expr, Stmt)>([](
- VarExpr buffer_var, Type type, Array<Expr> extents, Expr condition, Stmt body
+ .set_body_typed<Stmt(VarExpr, DataType, Array<Expr>, Expr, Stmt)>([](
+ VarExpr buffer_var, DataType type, Array<Expr> extents, Expr condition, Stmt body
){
return Allocate::make(buffer_var, type, extents, condition, body);
});
diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc
index f3d6c5f..9cb797f 100644
--- a/src/api/api_lang.cc
+++ b/src/api/api_lang.cc
@@ -35,10 +35,10 @@
namespace tvm {
TVM_REGISTER_API("_min_value")
-.set_body_method(&DataType::min);
+.set_body_typed(min_value);
TVM_REGISTER_API("_max_value")
-.set_body_method(&DataType::max);
+.set_body_typed(max_value);
TVM_REGISTER_API("_const")
.set_body([](TVMArgs args, TVMRetValue* ret) {
@@ -287,8 +287,8 @@ TVM_REGISTER_API("_TensorHash")
});
TVM_REGISTER_API("_Placeholder")
-.set_body_typed<Tensor(Array<Expr>, Type, std::string)>([](
- Array<Expr> shape, Type dtype, std::string name
+.set_body_typed<Tensor(Array<Expr>, DataType, std::string)>([](
+ Array<Expr> shape, DataType dtype, std::string name
) {
return placeholder(shape, dtype, name);
});
diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc
index 31fedcc..19f0452 100644
--- a/src/arithmetic/bound_deducer.cc
+++ b/src/arithmetic/bound_deducer.cc
@@ -132,7 +132,7 @@ class BoundDeducer: public IRVisitor {
Expr target_var = left ? op->a : op->b;
SignType sign_operand;
- if (operand.type().is_uint()) {
+ if (operand.dtype().is_uint()) {
sign_operand = kPositive;
} else {
sign_operand = expr_map_[operand].sign_type();
diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc
index 1b576a6..022dd8e 100644
--- a/src/arithmetic/canonical_simplify.cc
+++ b/src/arithmetic/canonical_simplify.cc
@@ -115,7 +115,7 @@ class SplitExprNode : public CanonicalExprNode {
Expr NormalizeWithScale(int64_t sscale) const {
Expr res = this->index;
- Type dtype = this->type;
+ DataType dtype = this->dtype;
if (this->scale == 0) {
return make_const(dtype, 0);
}
@@ -190,9 +190,9 @@ class SumExprNode : public CanonicalExprNode {
Expr Normalize() const final {
// quick path 1.
if (this->args.size() == 0) {
- return make_const(this->type, this->base);
+ return make_const(this->dtype, this->base);
}
- return Normalize_(this->type,
+ return Normalize_(this->dtype,
SimplifySplitExprs(args),
base);
}
@@ -379,7 +379,7 @@ class SumExprNode : public CanonicalExprNode {
std::stable_sort(args.begin(), args.end(), fcompare);
return args;
}
- static Expr Normalize_(Type dtype,
+ static Expr Normalize_(DataType dtype,
const std::vector<SplitExpr>& args,
int64_t base) {
// Positive scales first
@@ -508,7 +508,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
expr = op->Normalize();
}
NodePtr<SplitExprNode> n = make_node<SplitExprNode>();
- n->type = expr.type();
+ n->dtype = expr.dtype();
n->index = std::move(expr);
n->div_mode = kTruncDiv;
return SplitExpr(n);
@@ -545,7 +545,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
return GetRef<SumExpr>(op);
}
NodePtr<SumExprNode> n = make_node<SumExprNode>();
- n->type = expr.type();
+ n->dtype = expr.dtype();
if (const auto* op = expr.as<IntImm>()) {
n->base = op->value;
return SumExpr(n);
@@ -560,7 +560,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
Expr CanonicalSimplifier::Impl::
Mutate_(const Add* op, const Expr& self) {
- if (!IsIndexType(op->type)) {
+ if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self);
}
// normalize
@@ -586,7 +586,7 @@ Mutate_(const Add* op, const Expr& self) {
Expr CanonicalSimplifier::Impl::
Mutate_(const Sub* op, const Expr& self) {
- if (!IsIndexType(op->type)) {
+ if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self);
}
// normalize
@@ -613,7 +613,7 @@ Mutate_(const Sub* op, const Expr& self) {
Expr CanonicalSimplifier::Impl::
Mutate_(const Mul* op, const Expr& self) {
- if (!IsIndexType(op->type)) {
+ if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self);
}
// normalize
@@ -657,8 +657,8 @@ SeparateDivisibleParts(const SumExprNode* psum,
SumExpr* out_non_divisible) {
auto divisible = make_node<SumExprNode>();
auto non_divisible = make_node<SumExprNode>();
- divisible->type = psum->type;
- non_divisible->type = psum->type;
+ divisible->dtype = psum->dtype;
+ non_divisible->dtype = psum->dtype;
if (psum->base % coeff == 0) {
divisible->base = psum->base;
@@ -698,11 +698,11 @@ SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
return lhs;
} else if (lhs->upper_factor <= (lhs->lower_factor * scaled_cval)) {
// (x % c1) / c2 => 0 when c2 >= c1
- return ToSplitExpr(make_zero(lhs.type()));
+ return ToSplitExpr(make_zero(lhs.dtype()));
} else {
// move the upper_factor modular into index.
lhs.CopyOnWrite()->index =
- ModImpl(lhs->index, make_const(lhs.type(), lhs->upper_factor), div_mode);
+ ModImpl(lhs->index, make_const(lhs.dtype(), lhs->upper_factor), div_mode);
lhs.CopyOnWrite()->upper_factor = SplitExprNode::kPosInf;
lhs.CopyOnWrite()->scale = 1;
lhs.CopyOnWrite()->lower_factor *= scaled_cval;
@@ -720,7 +720,7 @@ SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
Expr CanonicalSimplifier::Impl::
Mutate_(const Div* op, const Expr& self) {
- if (!IsIndexType(op->type)) {
+ if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self);
}
@@ -764,7 +764,7 @@ Mutate_(const Div* op, const Expr& self) {
// if a >= 0 && a < cval, then result == 0
auto cbound = analyzer_->const_int_bound(Normalize(a));
if (cbound->min_value >= 0 && cbound->max_value < cval) {
- return make_zero(a.type());
+ return make_zero(a.dtype());
}
}
return SplitDivConst(ToSplitExpr(std::move(a)), cval, kTruncDiv);
@@ -781,7 +781,7 @@ Mutate_(const Div* op, const Expr& self) {
Expr CanonicalSimplifier::Impl::
Mutate_(const FloorDiv* op, const Expr& self) {
- if (!IsIndexType(op->type)) {
+ if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self);
}
Expr a = this->CanonicalMutate(op->a);
@@ -820,7 +820,7 @@ Mutate_(const FloorDiv* op, const Expr& self) {
// if a >= 0 && a < cval, then result == 0
auto cbound = analyzer_->const_int_bound(Normalize(a));
if (cbound->min_value >= 0 && cbound->max_value < cval) {
- return make_zero(a.type());
+ return make_zero(a.dtype());
}
}
return SplitDivConst(ToSplitExpr(std::move(a)), cval, kFloorDiv);
@@ -859,7 +859,7 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
if (new_upper_factor < lhs->upper_factor &&
lhs->upper_factor != SplitExprNode::kPosInf) {
auto updated = ToSplitExpr(Mutate(ModImpl(
- lhs->index, make_const(lhs.type(), new_upper_factor), div_mode)));
+ lhs->index, make_const(lhs.dtype(), new_upper_factor), div_mode)));
// re-apply the lower_factor
if (lhs->lower_factor != 1) {
return SplitDivConst(updated, lhs->lower_factor, div_mode);
@@ -887,7 +887,7 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
Expr CanonicalSimplifier::Impl::
Mutate_(const Mod* op, const Expr& self) {
- if (!IsIndexType(op->type)) {
+ if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self);
}
// normalize
@@ -906,7 +906,7 @@ Mutate_(const Mod* op, const Expr& self) {
SumExpr lhs, extra;
SeparateDivisibleParts(psum, cval, &lhs, &extra);
if (extra->IsZero()) {
- return make_zero(a.type());
+ return make_zero(a.dtype());
}
// both lhs and extra are non-negative
if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) &&
@@ -957,7 +957,7 @@ Mutate_(const Mod* op, const Expr& self) {
Expr CanonicalSimplifier::Impl::
Mutate_(const FloorMod* op, const Expr& self) {
- if (!IsIndexType(op->type)) {
+ if (!IsIndexType(op->dtype)) {
return Rewriter::Mutate_(op, self);
}
// normalize
diff --git a/src/arithmetic/compute_expr.h b/src/arithmetic/compute_expr.h
index 4b001cf..806587a 100644
--- a/src/arithmetic/compute_expr.h
+++ b/src/arithmetic/compute_expr.h
@@ -56,7 +56,7 @@ inline Expr ComputeReduce(
const Array<Expr>& values, Expr empty_value);
inline bool GetConst(Expr e, int64_t* out) {
- if (e.type().is_vector()) return false;
+ if (e.dtype().is_vector()) return false;
const int64_t* v = as_const_int(e);
if (v) {
*out = *v; return true;
diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h
index 86f1927..93bf708 100644
--- a/src/arithmetic/const_fold.h
+++ b/src/arithmetic/const_fold.h
@@ -70,7 +70,7 @@ inline Expr TryConstFold(Expr a);
* \param type The type to represent index.
* \return the checked result.
*/
-inline bool IsIndexType(const Type& type) {
+inline bool IsIndexType(const DataType& type) {
return type.is_int() && type.lanes() == 1 &&
(type.bits() == 32 || type.bits() == 64);
}
@@ -92,8 +92,8 @@ inline bool IsIndexType(const Type& type) {
using ir::UIntImm; \
const IntImm* pa = a.as<IntImm>(); \
const IntImm* pb = b.as<IntImm>(); \
- const Type& ta = a.type(); \
- const Type& tb = b.type(); \
+ const DataType& ta = a.dtype(); \
+ const DataType& tb = b.dtype(); \
if (arith::IsIndexType(ta) && arith::IsIndexType(tb)) { \
BODY; \
} \
@@ -103,7 +103,7 @@ inline bool IsIndexType(const Type& type) {
template<>
inline Expr TryConstFold<ir::Add>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
- const Type& rtype = a.type();
+ const DataType& rtype = a.dtype();
if (pa && pb) return IntImm::make(rtype, pa->value + pb->value);
if (pa && pa->value == 0) return b;
if (pb && pb->value == 0) return a;
@@ -117,7 +117,7 @@ inline Expr TryConstFold<ir::Add>(Expr a, Expr b) {
template<>
inline Expr TryConstFold<ir::Sub>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
- const Type& rtype = a.type();
+ const DataType& rtype = a.dtype();
if (pa && pb) return IntImm::make(rtype, pa->value - pb->value);
if (pb && pb->value == 0) return a;
if (fa && fb) return FloatImm::make(rtype, fa->value - fb->value);
@@ -129,7 +129,7 @@ inline Expr TryConstFold<ir::Sub>(Expr a, Expr b) {
template<>
inline Expr TryConstFold<ir::Mul>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
- const Type& rtype = a.type();
+ const DataType& rtype = a.dtype();
if (pa && pb) return IntImm::make(rtype, pa->value * pb->value);
if (pa) {
if (pa->value == 1) return b;
@@ -155,7 +155,7 @@ inline Expr TryConstFold<ir::Mul>(Expr a, Expr b) {
template<>
inline Expr TryConstFold<ir::Div>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
- const Type& rtype = a.type();
+ const DataType& rtype = a.dtype();
if (pa && pb) {
// due to division and mod can have different modes
// NOTE: this will assumes truc div.
@@ -184,7 +184,7 @@ inline Expr TryConstFold<ir::Div>(Expr a, Expr b) {
template<>
inline Expr TryConstFold<ir::Mod>(Expr a, Expr b) {
TVM_INDEX_CONST_PROPAGATION({
- const Type& rtype = a.type();
+ const DataType& rtype = a.dtype();
if (pa && pb) {
return IntImm::make(rtype, pa->value % pb->value);
}
@@ -202,7 +202,7 @@ inline Expr TryConstFold<ir::Mod>(Expr a, Expr b) {
template<>
inline Expr TryConstFold<ir::FloorDiv>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
- const Type& rtype = a.type();
+ const DataType& rtype = a.dtype();
if (pa && pb) {
CHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm::make(rtype, arith::floordiv(pa->value, pb->value));
@@ -229,7 +229,7 @@ inline Expr TryConstFold<ir::FloorDiv>(Expr a, Expr b) {
template<>
inline Expr TryConstFold<ir::FloorMod>(Expr a, Expr b) {
TVM_INDEX_CONST_PROPAGATION({
- const Type& rtype = a.type();
+ const DataType& rtype = a.dtype();
if (pa && pb) {
return IntImm::make(rtype, arith::floormod(pa->value, pb->value));
}
@@ -247,7 +247,7 @@ inline Expr TryConstFold<ir::FloorMod>(Expr a, Expr b) {
template<>
inline Expr TryConstFold<ir::Min>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
- const Type& rtype = a.type();
+ const DataType& rtype = a.dtype();
if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value));
if (fa && fb) return FloatImm::make(rtype, std::min(fa->value, fb->value));
});
@@ -258,7 +258,7 @@ inline Expr TryConstFold<ir::Min>(Expr a, Expr b) {
template<>
inline Expr TryConstFold<ir::Max>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
- const Type& rtype = a.type();
+ const DataType& rtype = a.dtype();
if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value));
if (fa && fb) return FloatImm::make(rtype, std::max(fa->value, fb->value));
});
@@ -269,8 +269,8 @@ inline Expr TryConstFold<ir::Max>(Expr a, Expr b) {
template<>
inline Expr TryConstFold<ir::GT>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return UIntImm::make(UInt(1), pa->value > pb->value);
- if (fa && fb) return UIntImm::make(UInt(1), fa->value > fb->value);
+ if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value > pb->value);
+ if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value > fb->value);
});
return Expr();
}
@@ -278,8 +278,8 @@ inline Expr TryConstFold<ir::GT>(Expr a, Expr b) {
template<>
inline Expr TryConstFold<ir::GE>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return UIntImm::make(UInt(1), pa->value >= pb->value);
- if (fa && fb) return UIntImm::make(UInt(1), fa->value >= fb->value);
+ if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value >= pb->value);
+ if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value >= fb->value);
});
return Expr();
}
@@ -287,8 +287,8 @@ inline Expr TryConstFold<ir::GE>(Expr a, Expr b) {
template<>
inline Expr TryConstFold<ir::LT>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return UIntImm::make(UInt(1), pa->value < pb->value);
- if (fa && fb) return UIntImm::make(UInt(1), fa->value < fb->value);
+ if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value < pb->value);
+ if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value < fb->value);
});
return Expr();
}
@@ -296,8 +296,8 @@ inline Expr TryConstFold<ir::LT>(Expr a, Expr b) {
template<>
inline Expr TryConstFold<ir::LE>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return UIntImm::make(UInt(1), pa->value <= pb->value);
- if (fa && fb) return UIntImm::make(UInt(1), fa->value <= fb->value);
+ if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value <= pb->value);
+ if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value <= fb->value);
});
return Expr();
}
@@ -305,8 +305,8 @@ inline Expr TryConstFold<ir::LE>(Expr a, Expr b) {
template<>
inline Expr TryConstFold<ir::EQ>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return UIntImm::make(UInt(1), pa->value == pb->value);
- if (fa && fb) return UIntImm::make(UInt(1), fa->value == fb->value);
+ if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value == pb->value);
+ if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value == fb->value);
});
return Expr();
}
@@ -314,8 +314,8 @@ inline Expr TryConstFold<ir::EQ>(Expr a, Expr b) {
template<>
inline Expr TryConstFold<ir::NE>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return UIntImm::make(UInt(1), pa->value != pb->value);
- if (fa && fb) return UIntImm::make(UInt(1), fa->value != fb->value);
+ if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value != pb->value);
+ if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value != fb->value);
});
return Expr();
}
@@ -349,7 +349,7 @@ inline Expr TryConstFold<ir::Not>(Expr a) {
using ir::UIntImm;
const UIntImm* pa = a.as<UIntImm>();
if (pa) {
- return UIntImm::make(UInt(1), !(pa->value));
+ return UIntImm::make(DataType::UInt(1), !(pa->value));
}
return Expr();
}
diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc
index 6e11969..c051910 100644
--- a/src/arithmetic/const_int_bound.cc
+++ b/src/arithmetic/const_int_bound.cc
@@ -125,7 +125,7 @@ class ConstIntBoundAnalyzer::Impl :
// Override visitor behaviors
Entry VisitExprDefault_(const Node* op) final {
return Everything(
- static_cast<const ExprNode*>(op)->type);
+ static_cast<const ExprNode*>(op)->dtype);
}
Entry VisitExpr(const Expr& expr) final {
@@ -142,7 +142,7 @@ class ConstIntBoundAnalyzer::Impl :
Entry VisitExpr_(const Cast* op) final {
Entry a = VisitExpr(op->value);
- Entry b = Everything(op->type);
+ Entry b = Everything(op->dtype);
return Intersect(a, b);
}
@@ -154,7 +154,7 @@ class ConstIntBoundAnalyzer::Impl :
if (op->value <= static_cast<uint64_t>(kPosInf)) {
return MakeBound(op->value, op->value);
} else {
- return Everything(op->type);
+ return Everything(op->dtype);
}
}
@@ -211,7 +211,7 @@ class ConstIntBoundAnalyzer::Impl :
CHECK(!b.is_const(0)) << "mod by zero";
// mod by negative value is rare,
// and we just use the simpliest rule.
- return Everything(op->type);
+ return Everything(op->dtype);
}
}
@@ -242,7 +242,7 @@ class ConstIntBoundAnalyzer::Impl :
CHECK(!b.is_const(0)) << "floormod by zero";
// mod by negative value is rare,
// and we just use the simpliest rule.
- return Everything(op->type);
+ return Everything(op->dtype);
}
}
@@ -278,7 +278,7 @@ class ConstIntBoundAnalyzer::Impl :
} else if (op->is_intrinsic(Call::bitwise_and)) {
return VisitBitwiseAnd(op);
} else {
- return Everything(op->type);
+ return Everything(op->dtype);
}
}
@@ -288,7 +288,7 @@ class ConstIntBoundAnalyzer::Impl :
if (it != var_map_.end()) {
return it->second;
} else {
- return Everything(op->type);
+ return Everything(op->dtype);
}
}
@@ -311,7 +311,7 @@ class ConstIntBoundAnalyzer::Impl :
if (a.min_value >= 0) {
return MakeBound(0, a.max_value);
}
- return Everything(op->type);
+ return Everything(op->dtype);
}
}
@@ -466,7 +466,7 @@ class ConstIntBoundAnalyzer::Impl :
* \param dtype The data type.
* \return Bound that represent everything dtype can represent.
*/
- static Entry Everything(Type dtype) {
+ static Entry Everything(DataType dtype) {
if (!dtype.is_int() && !dtype.is_uint()) {
return MakeBound(kNegInf, kPosInf);
}
diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc
index 8c7f4f2..cf37545 100644
--- a/src/arithmetic/detect_linear_equation.cc
+++ b/src/arithmetic/detect_linear_equation.cc
@@ -53,10 +53,10 @@ class LinearEqDetector
*ret = VisitExpr(e, e);
if (fail_) return false;
if (!ret->base.defined()) {
- ret->base = make_zero(var_.type());
+ ret->base = make_zero(var_.dtype());
}
if (!ret->coeff.defined()) {
- ret->coeff = make_zero(var_.type());
+ ret->coeff = make_zero(var_.dtype());
}
return true;
}
@@ -100,7 +100,7 @@ class LinearEqDetector
LinearEqEntry VisitExpr_(const Variable* op, const Expr& e) final {
LinearEqEntry ret;
if (op == var_.get()) {
- ret.coeff = make_const(op->type, 1);
+ ret.coeff = make_const(op->dtype, 1);
} else {
ret.base = e;
}
@@ -190,16 +190,16 @@ bool DetectClipBound(
// canonical form: exp >= 0
Expr canonical;
if (const LT* op = cond.as<LT>()) {
- if (!op->a.type().is_int()) return false;
- canonical = op->b - op->a - make_const(op->a.type(), 1);
+ if (!op->a.dtype().is_int()) return false;
+ canonical = op->b - op->a - make_const(op->a.dtype(), 1);
} else if (const LE* op = cond.as<LE>()) {
- if (!op->a.type().is_int()) return false;
+ if (!op->a.dtype().is_int()) return false;
canonical = op->b - op->a;
} else if (const GT* op = cond.as<GT>()) {
- if (!op->a.type().is_int()) return false;
- canonical = op->a - op->b - make_const(op->a.type(), 1);
+ if (!op->a.dtype().is_int()) return false;
+ canonical = op->a - op->b - make_const(op->a.dtype(), 1);
} else if (const GE* op = cond.as<GE>()) {
- if (!op->a.type().is_int()) return false;
+ if (!op->a.dtype().is_int()) return false;
canonical = op->a - op->b;
} else {
return false;
diff --git a/src/arithmetic/domain_touched.cc b/src/arithmetic/domain_touched.cc
index c28346e..947f005 100644
--- a/src/arithmetic/domain_touched.cc
+++ b/src/arithmetic/domain_touched.cc
@@ -72,7 +72,7 @@ class FuncTouchedDomain final : public IRVisitor {
const IterVarNode* thread_axis = op->node.as<IterVarNode>();
CHECK(thread_axis);
const Variable* var = thread_axis->var.get();
- dom_map_[var] = IntSet::range(Range(make_zero(op->value.type()), op->value));
+ dom_map_[var] = IntSet::range(Range(make_zero(op->value.dtype()), op->value));
IRVisitor::Visit_(op);
dom_map_.erase(var);
} else {
diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc
index 9f8effb..e4f2042 100644
--- a/src/arithmetic/int_set.cc
+++ b/src/arithmetic/int_set.cc
@@ -33,8 +33,8 @@
namespace tvm {
namespace arith {
-Expr SymbolicLimits::pos_inf_ = Var("pos_inf", Handle());
-Expr SymbolicLimits::neg_inf_ = Var("neg_inf", Handle());
+Expr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle());
+Expr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle());
IntervalSet::IntervalSet(Expr min_value, Expr max_value) {
auto node = make_node<IntervalSetNode>();
@@ -54,8 +54,8 @@ TVM_REGISTER_API("arith._make_IntervalSet")
IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
Expr max_value = min(a->max_value, b->max_value);
Expr min_value = max(a->min_value, b->min_value);
- if ((max_value.type().is_int() || max_value.type().is_uint()) &&
- (min_value.type().is_int() || min_value.type().is_uint()) &&
+ if ((max_value.dtype().is_int() || max_value.dtype().is_uint()) &&
+ (min_value.dtype().is_int() || min_value.dtype().is_uint()) &&
analyzer->CanProveGreaterEqual(min_value - max_value, 1)) {
return IntervalSet::Empty();
} else {
@@ -105,8 +105,8 @@ inline IntervalSet Combine(Analyzer* analyzer,
return IntervalSet::SinglePoint(res);
}
if (is_logical_op<Op>::value) {
- return IntervalSet(make_const(a->min_value.type(), 0),
- make_const(a->min_value.type(), 1));
+ return IntervalSet(make_const(a->min_value.dtype(), 0),
+ make_const(a->min_value.dtype(), 1));
}
if (a->IsEmpty()) return a;
if (b->IsEmpty()) return b;
@@ -177,7 +177,7 @@ inline IntervalSet Combine<ir::Mul>(Analyzer* analyzer,
return IntervalSet(min_value, max_value);
} else if (a->HasUpperBound() && a->HasLowerBound()) {
using ir::Select;
- Expr sign = b->min_value >= make_zero(b->min_value.type().element_of());
+ Expr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
Expr e1 = a->min_value * b->min_value;
Expr e2 = a->max_value * b->min_value;
return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1));
@@ -212,7 +212,7 @@ inline IntervalSet Combine<ir::Div>(Analyzer* analyzer,
return IntervalSet(min_value, max_value);
} else if (a->HasUpperBound() && a->HasLowerBound()) {
using ir::Select;
- Expr sign = b->min_value >= make_zero(b->min_value.type().element_of());
+ Expr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
Expr e1 = a->min_value / b->min_value;
Expr e2 = a->max_value / b->min_value;
return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1));
@@ -242,7 +242,7 @@ inline IntervalSet Combine<ir::Mod>(Analyzer* analyzer,
// is the case of our application.
// TODO(tqchen): add bound constraints for a.
if (analyzer->CanProveGreaterEqual(divisor, 0)) {
- return IntervalSet(make_zero(divisor.type()), divisor - 1);
+ return IntervalSet(make_zero(divisor.dtype()), divisor - 1);
} else {
Expr bound = abs(divisor) - 1;
return IntervalSet(-bound, bound);
@@ -278,7 +278,7 @@ inline IntervalSet Combine<ir::FloorDiv>(Analyzer* analyzer,
return IntervalSet(min_value, max_value);
} else if (a->HasUpperBound() && a->HasLowerBound()) {
using ir::Select;
- Expr sign = b->min_value >= make_zero(b->min_value.type().element_of());
+ Expr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
Expr e1 = floordiv(a->min_value, b->min_value);
Expr e2 = floordiv(a->max_value, b->min_value);
return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1));
@@ -304,7 +304,7 @@ inline IntervalSet Combine<ir::FloorMod>(Analyzer* analyzer,
LOG(FATAL) << "Modular by zero in CombineInterval Mod";
}
if (analyzer->CanProveGreaterEqual(divisor, 0)) {
- return IntervalSet(make_zero(divisor.type()), divisor - 1);
+ return IntervalSet(make_zero(divisor.dtype()), divisor - 1);
} else {
Expr bound = abs(divisor) - 1;
return IntervalSet(-bound, bound);
@@ -476,7 +476,7 @@ class IntervalSetEvaluator :
IntervalSet base = Eval(op->base);
PVar<Integer> stride;
if (stride.Match(op->stride)) {
- Type t = op->base.type();
+ DataType t = op->base.dtype();
int64_t vstride = stride.Eval()->value;
if (vstride> 0) {
return Combine<Add>(
diff --git a/src/arithmetic/ir_mutator_with_analyzer.cc b/src/arithmetic/ir_mutator_with_analyzer.cc
index cda9d58..0d4b8f2 100644
--- a/src/arithmetic/ir_mutator_with_analyzer.cc
+++ b/src/arithmetic/ir_mutator_with_analyzer.cc
@@ -140,7 +140,7 @@ Mutate_(const Call* op, const Expr& self) {
false_value.same_as(op->args[2])) {
return self;
} else {
- return Call::make(op->type, op->name,
+ return Call::make(op->dtype, op->name,
{cond, true_value, false_value},
op->call_type);
}
diff --git a/src/arithmetic/pattern_match.h b/src/arithmetic/pattern_match.h
index f7d5483..fd07a37 100644
--- a/src/arithmetic/pattern_match.h
+++ b/src/arithmetic/pattern_match.h
@@ -291,7 +291,7 @@ class PConstWithTypeLike :
}
Expr Eval() const {
- return make_const(ref_.Eval().type(), value_);
+ return make_const(ref_.Eval().dtype(), value_);
}
private:
@@ -474,7 +474,7 @@ class PCastExpr :
bool Match_(const NodeRef& node) const {
if (const ir::Cast* ptr = node.as<ir::Cast>()) {
- if (!dtype_.Match_(ptr->type)) return false;
+ if (!dtype_.Match_(ptr->dtype)) return false;
if (!value_.Match_(ptr->value)) return false;
return true;
} else {
@@ -730,7 +730,7 @@ class PCallExpr :
#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \
struct OpName { \
static Expr Eval(Array<Expr> args) { \
- return ir::Call::make(args[0].type(), kName, args, \
+ return ir::Call::make(args[0].dtype(), kName, args, \
ir::Call::PureIntrinsic); \
} \
static constexpr const char* kName = IntrinStr; \
@@ -751,7 +751,7 @@ TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, "bitwise_xor");
#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \
struct OpName { \
static Expr Eval(Array<Expr> args) { \
- return ir::Call::make(args[0].type(), kName, args, \
+ return ir::Call::make(args[0].dtype(), kName, args, \
ir::Call::PureIntrinsic); \
} \
static constexpr const char* kName = IntrinStr; \
@@ -768,7 +768,7 @@ TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not");
struct PIfThenElseOp {
static Expr Eval(Array<Expr> args) {
return ir::Call::make(
- args[1].type(), kName, args,
+ args[1].dtype(), kName, args,
ir::Call::PureIntrinsic);
}
static constexpr const char* kName = "tvm_if_then_else";
diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc
index b26f833..235306c 100644
--- a/src/arithmetic/rewrite_simplify.cc
+++ b/src/arithmetic/rewrite_simplify.cc
@@ -129,7 +129,7 @@ Mutate_(const Add* op, const Expr& self) {
// Pattern var for lanes in broadcast and ramp
PVar<int> lanes;
// Vector rules
- if (op->type.lanes() != 1) {
+ if (op->dtype.lanes() != 1) {
TVM_TRY_REWRITE(ramp(b1, s1, lanes) + ramp(b2, s2, lanes),
ramp(b1 + b2, s1 + s2, lanes));
TVM_TRY_REWRITE(ramp(b1, s1, lanes) + broadcast(x, lanes),
@@ -140,7 +140,7 @@ Mutate_(const Add* op, const Expr& self) {
broadcast(x + y, lanes));
}
- if (IsIndexType(op->type)) {
+ if (IsIndexType(op->dtype)) {
// Index rules
// cancelation rules
TVM_TRY_REWRITE((x - y) + y, x);
@@ -244,7 +244,7 @@ Mutate_(const Sub* op, const Expr& self) {
// Pattern var for lanes in broadcast and ramp
PVar<int> lanes;
// Vector rules
- if (op->type.lanes() != 1) {
+ if (op->dtype.lanes() != 1) {
TVM_TRY_REWRITE(ramp(b1, s1, lanes) - ramp(b2, s2, lanes),
ramp(b1 - b2, s1 - s2, lanes));
TVM_TRY_REWRITE(ramp(b1, s1, lanes) - broadcast(x, lanes),
@@ -255,7 +255,7 @@ Mutate_(const Sub* op, const Expr& self) {
broadcast(x - y, lanes));
}
- if (IsIndexType(op->type)) {
+ if (IsIndexType(op->dtype)) {
// Index rules
// cancelation rules
TVM_TRY_REWRITE((x + y) - y, x);
@@ -443,7 +443,7 @@ Mutate_(const Mul* op, const Expr& self) {
// Pattern var for lanes in broadcast and ramp
PVar<int> lanes;
// Vector rules
- if (op->type.lanes() != 1) {
+ if (op->dtype.lanes() != 1) {
TVM_TRY_REWRITE(broadcast(x, lanes) * broadcast(y, lanes),
broadcast(x * y, lanes));
TVM_TRY_REWRITE(ramp(b1, s1, lanes) * broadcast(x, lanes),
@@ -452,7 +452,7 @@ Mutate_(const Mul* op, const Expr& self) {
ramp(b1 * x, s1 * x, lanes));
}
- if (IsIndexType(op->type)) {
+ if (IsIndexType(op->dtype)) {
// constant simplification rule
TVM_TRY_REWRITE((x + c1) * c2, x * c2 + c1 * c2);
TVM_TRY_REWRITE((x * c1) * c2, x * (c1 * c2));
@@ -484,12 +484,12 @@ Mutate_(const Div* op, const Expr& self) {
// x / 2.0 = x * 0.5
if (const FloatImm* ptr = op->b.as<FloatImm>()) {
- CHECK(op->type.is_float());
- return op->a * make_const(op->b.type(), 1.0 / ptr->value);
+ CHECK(op->dtype.is_float());
+ return op->a * make_const(op->b.dtype(), 1.0 / ptr->value);
}
// Vector rules
- if (op->type.lanes() != 1) {
+ if (op->dtype.lanes() != 1) {
// NOTE: use div as the pattern also works for float.
TVM_TRY_REWRITE(div(broadcast(x, lanes), broadcast(y, lanes)),
broadcast(div(x, y), lanes));
@@ -512,7 +512,7 @@ Mutate_(const Div* op, const Expr& self) {
}
}
- if (IsIndexType(op->type)) {
+ if (IsIndexType(op->dtype)) {
// Be-aware of the division rules:
// We adopt the default C division uses truncation instead of floordiv.
// This means most rules need to check non-negativeness of the operands.
@@ -524,7 +524,7 @@ Mutate_(const Div* op, const Expr& self) {
if (truncdiv(c1, c2).Match(ret)) {
int64_t c1val = c1.Eval()->value;
int64_t c2val = c2.Eval()->value;
- return make_const(op->type, truncdiv(c1val, c2val));
+ return make_const(op->dtype, truncdiv(c1val, c2val));
}
// while it is always true for trunc div
@@ -706,7 +706,7 @@ Mutate_(const Mod* op, const Expr& self) {
PVar<int> lanes;
// Vector rules
- if (op->type.lanes() != 1) {
+ if (op->dtype.lanes() != 1) {
TVM_TRY_REWRITE(truncmod(broadcast(x, lanes), broadcast(y, lanes)),
broadcast(truncmod(x, y), lanes));
@@ -734,7 +734,7 @@ Mutate_(const Mod* op, const Expr& self) {
}
}
- if (IsIndexType(op->type)) {
+ if (IsIndexType(op->dtype)) {
// Be-aware of the division rules:
// We adopt the default C division uses truncation instead of floordiv.
// This means most rules need to check non-negativeness of the operands.
@@ -762,9 +762,10 @@ Mutate_(const Mod* op, const Expr& self) {
// canonicalization: x % c == x % (-c) for truncated division
// NOTE: trunc div required
- TVM_TRY_RECURSIVE_REWRITE_IF(truncmod(x, c1),
- truncmod(x, PConst<Expr>(make_const(op->type, -c1.Eval()->value))),
- c1.Eval()->value < 0);
+ TVM_TRY_RECURSIVE_REWRITE_IF(
+ truncmod(x, c1),
+ truncmod(x, PConst<Expr>(make_const(op->dtype, -c1.Eval()->value))),
+ c1.Eval()->value < 0);
// try modular analysis
if (truncmod(x, c1).Match(ret)) {
@@ -794,7 +795,7 @@ Mutate_(const FloorDiv* op, const Expr& self) {
PVar<int> lanes;
// Vector rules
- if (op->type.lanes() != 1) {
+ if (op->dtype.lanes() != 1) {
TVM_TRY_REWRITE(floordiv(broadcast(x, lanes), broadcast(y, lanes)),
broadcast(floordiv(x, y), lanes));
// ramp // bcast
@@ -814,7 +815,7 @@ Mutate_(const FloorDiv* op, const Expr& self) {
}
}
- if (IsIndexType(op->type)) {
+ if (IsIndexType(op->dtype)) {
// Be-aware of the division rules: this is floor division.
TVM_TRY_REWRITE_IF(floordiv(floordiv(x, c1), c2), floordiv(x, c1 * c2),
c1.Eval()->value > 0 && c2.Eval()->value > 0);
@@ -939,7 +940,7 @@ Mutate_(const FloorMod* op, const Expr& self) {
PVar<int> lanes;
// Vector rules
- if (op->type.lanes() != 1) {
+ if (op->dtype.lanes() != 1) {
TVM_TRY_REWRITE(floormod(broadcast(x, lanes), broadcast(y, lanes)),
broadcast(floormod(x, y), lanes));
@@ -964,7 +965,7 @@ Mutate_(const FloorMod* op, const Expr& self) {
}
}
- if (IsIndexType(op->type)) {
+ if (IsIndexType(op->dtype)) {
// Be-aware of the division rules: we use floordiv/floormod here
TVM_TRY_REWRITE_IF(floormod(x * c1, c2), ZeroWithTypeLike(x),
c2.Eval()->value != 0 &&
@@ -1008,13 +1009,13 @@ Mutate_(const Min* op, const Expr& self) {
PVar<int> lanes;
// vector rule
- if (op->type.lanes() != 1) {
+ if (op->dtype.lanes() != 1) {
TVM_TRY_REWRITE(min(broadcast(x, lanes), broadcast(y, lanes)),
broadcast(min(x, y), lanes));
TVM_TRY_REWRITE(min(min(x, broadcast(y, lanes)), broadcast(z, lanes)),
min(x, broadcast(min(y, z), lanes)));
}
- if (IsIndexType(op->type)) {
+ if (IsIndexType(op->dtype)) {
TVM_TRY_REWRITE(min(x, x), x);
// constant int bound
@@ -1193,13 +1194,13 @@ Mutate_(const Max* op, const Expr& self) {
PVar<int> lanes;
// vector rule
- if (op->type.lanes() != 1) {
+ if (op->dtype.lanes() != 1) {
TVM_TRY_REWRITE(max(broadcast(x, lanes), broadcast(y, lanes)),
broadcast(max(x, y), lanes));
TVM_TRY_REWRITE(max(max(x, broadcast(y, lanes)), broadcast(z, lanes)),
max(x, broadcast(max(y, z), lanes)));
}
- if (IsIndexType(op->type)) {
+ if (IsIndexType(op->dtype)) {
TVM_TRY_REWRITE(max(x, x), x);
// constant int bound
@@ -1366,17 +1367,17 @@ Mutate_(const EQ* op, const Expr& self) {
PVar<int> lanes;
// vector rule
- if (op->type.lanes() != 1) {
+ if (op->dtype.lanes() != 1) {
TVM_TRY_REWRITE(broadcast(x, lanes) == broadcast(y, lanes),
broadcast(x == y, lanes));
}
- if (IsIndexType(op->a.type())) {
+ if (IsIndexType(op->a.dtype())) {
CompareResult result = TryCompare(op->a - op->b, 0);
if (result == kEQ) {
- return make_const(op->type, true);
+ return make_const(op->dtype, true);
} else if (result == kNE || result == kGT || result == kLT) {
- return make_const(op->type, false);
+ return make_const(op->dtype, false);
}
TVM_TRY_REWRITE(x - c1 == 0, x == c1);
TVM_TRY_REWRITE(c1 - x == 0, x == c1);
@@ -1420,20 +1421,20 @@ Mutate_(const LT* op, const Expr& self) {
PVar<int> lanes;
// vector rule
- if (op->type.lanes() != 1) {
+ if (op->dtype.lanes() != 1) {
TVM_TRY_REWRITE(broadcast(x, lanes) < broadcast(y, lanes),
broadcast(x < y, lanes));
TVM_TRY_REWRITE(ramp(x, s1, lanes) < ramp(y, s1, lanes),
broadcast(x < y, lanes));
}
- if (IsIndexType(op->a.type())) {
+ if (IsIndexType(op->a.dtype())) {
CompareResult result = TryCompare(op->a - op->b, 0);
if (result == kLT) {
- return make_const(op->type, true);
+ return make_const(op->dtype, true);
}
if (result == kEQ || result == kGT || result == kGE) {
- return make_const(op->type, false);
+ return make_const(op->dtype, false);
}
TVM_TRY_REWRITE(x + y < x + z, y < z);
@@ -1571,7 +1572,7 @@ Mutate_(const Not* op, const Expr& self) {
// Pattern var to match any expression
PVar<Expr> x, y;
PVar<int> lanes;
- if (op->type.lanes() != 1) {
+ if (op->dtype.lanes() != 1) {
TVM_TRY_REWRITE(!broadcast(x, lanes), broadcast(!x, lanes));
}
@@ -1600,12 +1601,12 @@ Mutate_(const And* op, const Expr& self) {
PVar<Integer> c1, c2;
PVar<int> lanes;
- if (op->type.lanes() != 1) {
+ if (op->dtype.lanes() != 1) {
TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes),
broadcast(x && y, lanes));
}
- auto cfalse = PConst<Expr>(make_const(op->type, false));
+ auto cfalse = PConst<Expr>(make_const(op->dtype, false));
TVM_TRY_REWRITE(x == y && x != y, cfalse);
TVM_TRY_REWRITE(x != y && x == y, cfalse);
TVM_TRY_REWRITE(x && !x, cfalse);
@@ -1649,12 +1650,12 @@ Mutate_(const Or* op, const Expr& self) {
PVar<Integer> c1, c2;
PVar<int> lanes;
- if (op->type.lanes() != 1) {
+ if (op->dtype.lanes() != 1) {
TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes),
broadcast(x || y, lanes));
}
- auto ctrue = PConst<Expr>(make_const(op->type, true));
+ auto ctrue = PConst<Expr>(make_const(op->dtype, true));
TVM_TRY_REWRITE(x == y || x != y, ctrue);
TVM_TRY_REWRITE(x != y || x == y, ctrue);
@@ -1720,7 +1721,7 @@ Mutate_(const Call* op, const Expr& self) {
for (const auto& constraint : literal_constraints_) {
// Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
if (Equal(constraint, op->args[0])) {
- return make_const(op->type, true);
+ return make_const(op->dtype, true);
}
}
}
@@ -1741,7 +1742,7 @@ Expr RewriteSimplifier::Impl::
Mutate_(const Cast* op, const Expr& self) {
Expr ret = IRMutator::Mutate_(op, self);
op = ret.as<Cast>();
- return cast(op->type, op->value);
+ return cast(op->dtype, op->value);
}
Expr RewriteSimplifier::Impl::
diff --git a/src/autotvm/touch_extractor.cc b/src/autotvm/touch_extractor.cc
index 101d8f1..f66a724 100644
--- a/src/autotvm/touch_extractor.cc
+++ b/src/autotvm/touch_extractor.cc
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -255,10 +255,10 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<Expr> > > *re
feature_row.push_back(Array<Expr>{std::string("_itervar_"), var});
Array<Expr> attr{std::string("_attr_"),
- FloatImm::make(Float(32), trans(fea.length)),
- IntImm::make(Int(32), fea.nest_level),
- FloatImm::make(Float(32), trans(fea.topdown_product)),
- FloatImm::make(Float(32), trans(fea.bottomup_product)),
+ FloatImm::make(DataType::Float(32), trans(fea.length)),
+ IntImm::make(DataType::Int(32), fea.nest_level),
+ FloatImm::make(DataType::Float(32), trans(fea.topdown_product)),
+ FloatImm::make(DataType::Float(32), trans(fea.bottomup_product)),
};
// one hot annotation
for (int i = 0; i < kNum; i++) {
@@ -268,9 +268,9 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<Expr> > > *re
// arithmetic
feature_row.push_back(Array<Expr>{std::string("_arith_"),
- FloatImm::make(Float(32), trans(fea.add_ct)),
- FloatImm::make(Float(32), trans(fea.mul_ct)),
- FloatImm::make(Float(32), trans(fea.div_ct)),
+ FloatImm::make(DataType::Float(32), trans(fea.add_ct)),
+ FloatImm::make(DataType::Float(32), trans(fea.mul_ct)),
+ FloatImm::make(DataType::Float(32), trans(fea.div_ct)),
});
// touch map
@@ -282,12 +282,12 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<Expr> > > *re
for (auto k : bufs) {
TouchPattern &v = fea.touch_feature[k];
feature_row.push_back(Array<Expr>{k,
- FloatImm::make(Float(32), trans(v.stride)),
- FloatImm::make(Float(32), trans(v.mod)),
- FloatImm::make(Float(32), trans(v.count)),
- FloatImm::make(Float(32), trans(v.reuse)),
- FloatImm::make(Float(32), trans(v.thread_count)),
- FloatImm::make(Float(32), trans(v.thread_reuse)),
+ FloatImm::make(DataType::Float(32), trans(v.stride)),
+ FloatImm::make(DataType::Float(32), trans(v.mod)),
+ FloatImm::make(DataType::Float(32), trans(v.count)),
+ FloatImm::make(DataType::Float(32), trans(v.reuse)),
+ FloatImm::make(DataType::Float(32), trans(v.thread_count)),
+ FloatImm::make(DataType::Float(32), trans(v.thread_reuse)),
});
}
diff --git a/src/autotvm/touch_extractor.h b/src/autotvm/touch_extractor.h
index e669064..1028b01 100644
--- a/src/autotvm/touch_extractor.h
+++ b/src/autotvm/touch_extractor.h
@@ -91,31 +91,31 @@ class TouchExtractor : public FeatureVisitor {
// arithmetic stats
void Visit_(const Add *op) {
- if (op->type.is_float())
+ if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].add_ct++;
IRVisitor::Visit_(op);
}
void Visit_(const Sub *op) {
- if (op->type.is_float())
+ if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].add_ct++;
IRVisitor::Visit_(op);
}
void Visit_(const Mul *op) {
- if (op->type.is_float())
+ if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].mul_ct++;
IRVisitor::Visit_(op);
}
void Visit_(const Div *op) {
- if (op->type.is_float())
+ if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].div_ct++;
IRVisitor::Visit_(op);
}
void Visit_(const Mod *op) {
- if (op->type.is_float())
+ if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].div_ct++;
IRVisitor::Visit_(op);
}
diff --git a/src/codegen/build_common.h b/src/codegen/build_common.h
index 8a21aee..b2c8953 100644
--- a/src/codegen/build_common.h
+++ b/src/codegen/build_common.h
@@ -39,7 +39,7 @@ ExtractFuncInfo(const Array<LoweredFunc>& funcs) {
for (LoweredFunc f : funcs) {
runtime::FunctionInfo info;
for (size_t i = 0; i < f->args.size(); ++i) {
- info.arg_types.push_back(Type2TVMType(f->args[i].type()));
+ info.arg_types.push_back(f->args[i].dtype());
}
for (size_t i = 0; i < f->thread_axis.size(); ++i) {
info.thread_axis_tags.push_back(f->thread_axis[i]->thread_tag);
diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc
index a7325a9..ca25731 100644
--- a/src/codegen/build_module.cc
+++ b/src/codegen/build_module.cc
@@ -334,12 +334,12 @@ Target DefaultTargetHost(Target target) {
}
Buffer BufferWithOffsetAlignment(Array<Expr> shape,
- Type dtype,
+ DataType dtype,
std::string name,
int data_alignment,
int offset_factor,
bool compact) {
- auto data = Var(name, Handle());
+ auto data = Var(name, DataType::Handle());
bool has_any = false;
if (!compact) {
for (const auto& it : shape) {
@@ -353,7 +353,7 @@ Buffer BufferWithOffsetAlignment(Array<Expr> shape,
Expr elem_offset;
if (offset_factor != 0) {
- elem_offset = Var(name + "_elem_offset", shape[0].type());
+ elem_offset = Var(name + "_elem_offset", shape[0].dtype());
} else {
elem_offset = Expr();
}
diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc
index eab542d..4b95e2c 100644
--- a/src/codegen/codegen_c.cc
+++ b/src/codegen/codegen_c.cc
@@ -79,7 +79,7 @@ void CodeGenC::AddFunction(LoweredFunc f) {
ReserveKeywordsAsUnique();
// add to alloc buffer type.
for (const auto & kv : f->handle_data_type) {
- RegisterHandleType(kv.first.get(), kv.second.type());
+ RegisterHandleType(kv.first.get(), kv.second.dtype());
}
this->stream << "void " << f->name << "(";
@@ -87,7 +87,7 @@ void CodeGenC::AddFunction(LoweredFunc f) {
Var v = f->args[i];
std::string vid = AllocVarID(v.get());
if (i != 0) stream << ", ";
- if (v.type().is_handle()) {
+ if (v.dtype().is_handle()) {
auto it = alloc_storage_scope_.find(v.get());
if (it != alloc_storage_scope_.end())
PrintStorageScope(it->second, stream);
@@ -104,7 +104,7 @@ void CodeGenC::AddFunction(LoweredFunc f) {
stream << ' ' << restrict_keyword_;
}
} else {
- PrintType(v.type(), stream);
+ PrintType(v.dtype(), stream);
}
stream << ' ' << vid;
}
@@ -125,14 +125,14 @@ void CodeGenC::PrintExpr(const Expr& n, std::ostream& os) { // NOLINT(*)
if (print_ssa_form_) {
std::ostringstream temp;
VisitExpr(n, temp);
- os << SSAGetID(temp.str(), n.type());
+ os << SSAGetID(temp.str(), n.dtype());
} else {
VisitExpr(n, os);
}
}
void CodeGenC::PrintSSAAssign(
- const std::string& target, const std::string& src, Type t) {
+ const std::string& target, const std::string& src, DataType t) {
PrintType(t, stream);
stream << ' ' << target << " = ";
if (src.length() > 3 &&
@@ -146,7 +146,7 @@ void CodeGenC::PrintSSAAssign(
// Print a reference expression to a buffer.
std::string CodeGenC::GetBufferRef(
- Type t, const Variable* buffer, Expr index) {
+ DataType t, const Variable* buffer, Expr index) {
std::ostringstream os;
std::string vid = GetVarID(buffer);
std::string scope;
@@ -213,7 +213,7 @@ std::string CodeGenC::GetBufferRef(
// Print a reference expression to a buffer.
std::string CodeGenC::GetStructRef(
- Type t, const Expr& buffer, const Expr& index, int kind) {
+ DataType t, const Expr& buffer, const Expr& index, int kind) {
if (kind < intrinsic::kArrKindBound_) {
std::ostringstream os;
os << "(((TVMArray*)";
@@ -265,13 +265,13 @@ std::string CodeGenC::GetStructRef(
}
-bool CodeGenC::HandleTypeMatch(const Variable* buf_var, Type t) const {
+bool CodeGenC::HandleTypeMatch(const Variable* buf_var, DataType t) const {
auto it = handle_data_type_.find(buf_var);
if (it == handle_data_type_.end()) return false;
return it->second == t;
}
-void CodeGenC::RegisterHandleType(const Variable* buf_var, Type t) {
+void CodeGenC::RegisterHandleType(const Variable* buf_var, DataType t) {
auto it = handle_data_type_.find(buf_var);
if (it == handle_data_type_.end()) {
handle_data_type_[buf_var] = t;
@@ -282,13 +282,13 @@ void CodeGenC::RegisterHandleType(const Variable* buf_var, Type t) {
}
void CodeGenC::PrintVecElemLoad(const std::string& vec,
- Type t, int i,
+ DataType t, int i,
std::ostream& os) { // NOLINT(*)
os << vec << ".s" << std::hex << i << std::dec;
}
void CodeGenC::PrintVecElemStore(const std::string& vec,
- Type t, int i,
+ DataType t, int i,
const std::string& value) {
this->PrintIndent();
stream << vec << ".s" << std::hex << i
@@ -296,19 +296,19 @@ void CodeGenC::PrintVecElemStore(const std::string& vec,
}
std::string CodeGenC::GetVecLoad(
- Type t, const Variable* buffer, Expr base) {
+ DataType t, const Variable* buffer, Expr base) {
return GetBufferRef(t, buffer, base);
}
void CodeGenC::PrintVecStore(const Variable* buffer,
- Type t, Expr base,
+ DataType t, Expr base,
const std::string& value) {
std::string ref = GetBufferRef(t, buffer, base);
this->PrintIndent();
stream << ref << " = " << value << ";\n";
}
-std::string CodeGenC::CastFromTo(std::string value, Type from, Type target) {
+std::string CodeGenC::CastFromTo(std::string value, DataType from, DataType target) {
if (from == target) return value;
std::ostringstream os;
os << "((";
@@ -328,7 +328,7 @@ void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) { /
CHECK_EQ(scope, "global");
}
-void CodeGenC::PrintType(Type t, std::ostream& os) { // NOLINT(*)
+void CodeGenC::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
CHECK_EQ(t.lanes(), 1)
<< "do not yet support vector types";
if (t.is_handle()) {
@@ -360,48 +360,48 @@ void CodeGenC::PrintType(Type t, std::ostream& os) { // NOLINT(*)
inline void PrintConst(const IntImm* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
- if (op->type == Int(32)) {
+ if (op->dtype == DataType::Int(32)) {
std::ostringstream temp;
temp << op->value;
p->MarkConst(temp.str());
os << temp.str();
} else {
os << "(";
- p->PrintType(op->type, os);
+ p->PrintType(op->dtype, os);
os << ")" << op->value;
}
}
inline void PrintConst(const UIntImm* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
- if (op->type == UInt(32)) {
+ if (op->dtype == DataType::UInt(32)) {
std::ostringstream temp;
temp << op->value << "U";
p->MarkConst(temp.str());
os << temp.str();
} else {
os << "(";
- p->PrintType(op->type, os);
+ p->PrintType(op->dtype, os);
os << ")" << op->value;
}
}
inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
- switch (op->type.bits()) {
+ switch (op->dtype.bits()) {
case 64: case 32: {
std::ostringstream temp;
temp << std::scientific << op->value;
- if (op->type.bits() == 32) temp << 'f';
+ if (op->dtype.bits() == 32) temp << 'f';
p->MarkConst(temp.str());
os << temp.str();
break;
}
case 16: {
os << '(';
- p->PrintType(op->type, os);
+ p->PrintType(op->dtype, os);
os << ')' << std::scientific <<op->value << 'f';
break;
}
- default: LOG(FATAL) << "Bad bit-width for float: " << op->type << "\n";
+ default: LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
}
}
@@ -423,7 +423,7 @@ inline void PrintBinaryExpr(const T* op,
const char *opstr,
std::ostream& os, // NOLINT(*)
CodeGenC* p) {
- if (op->type.lanes() == 1) {
+ if (op->dtype.lanes() == 1) {
if (isalpha(opstr[0])) {
os << opstr << '(';
p->PrintExpr(op->a, os);
@@ -438,7 +438,7 @@ inline void PrintBinaryExpr(const T* op,
os << ')';
}
} else {
- p->PrintVecBinaryOp(opstr, op->type, op->a, op->b, os);
+ p->PrintVecBinaryOp(opstr, op->dtype, op->a, op->b, os);
}
}
@@ -446,7 +446,7 @@ inline void PrintBinaryIntrinsic(const Call* op,
const char *opstr,
std::ostream& os, // NOLINT(*)
CodeGenC* p) {
- if (op->type.lanes() == 1) {
+ if (op->dtype.lanes() == 1) {
CHECK_EQ(op->args.size(), 2U);
os << '(';
p->PrintExpr(op->args[0], os);
@@ -454,13 +454,13 @@ inline void PrintBinaryIntrinsic(const Call* op,
p->PrintExpr(op->args[1], os);
os << ')';
} else {
- p->PrintVecBinaryOp(opstr, op->type, op->args[0], op->args[1], os);
+ p->PrintVecBinaryOp(opstr, op->dtype, op->args[0], op->args[1], os);
}
}
void CodeGenC::VisitExpr_(const Cast *op, std::ostream& os) { // NOLINT(*)
std::stringstream value;
this->PrintExpr(op->value, value);
- os << CastFromTo(value.str(), op->value.type(), op->type);
+ os << CastFromTo(value.str(), op->value.dtype(), op->dtype);
}
void CodeGenC::VisitExpr_(const Variable *op, std::ostream& os) { // NOLINT(*)
os << GetVarID(op);
@@ -553,7 +553,7 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
const Load *l = op->args[0].as<Load>();
CHECK(op->args.size() == 1 && l);
os << "((";
- this->PrintType(l->type.element_of(), os);
+ this->PrintType(l->dtype.element_of(), os);
os << " *)" << this->GetVarID(l->buffer_var.get())
<< " + ";
this->PrintExpr(l->index, os);
@@ -561,7 +561,7 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
} else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
CHECK_EQ(op->args.size(), 3U);
os << GetStructRef(
- op->type, op->args[0], op->args[1],
+ op->dtype, op->args[0], op->args[1],
op->args[2].as<IntImm>()->value);
} else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
CHECK_EQ(op->args.size(), 1U);
@@ -571,7 +571,7 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
} else if (op->is_intrinsic(Call::reinterpret)) {
// generate (*( TYPE *)(&(ARG)))
os << "(*(";
- this->PrintType(op->type, os);
+ this->PrintType(op->dtype, os);
os << " *)(&(";
this->PrintExpr(op->args[0], os);
os << ")))";
@@ -585,7 +585,7 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
if (op->call_type == Call::Intrinsic ||
op->call_type == Call::PureIntrinsic) {
LOG(FATAL) << "Unresolved intrinsic " << op->name
- << " with return type " << op->type;
+ << " with return type " << op->dtype;
} else {
LOG(FATAL) << "Unresolved call type " << op->call_type;
}
@@ -593,7 +593,7 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
}
void CodeGenC::PrintVecBinaryOp(
- const std::string& op, Type t,
+ const std::string& op, DataType t,
Expr lhs, Expr rhs, std::ostream& os) { // NOLINT(*)
if (isalpha(op[0])) {
os << op << "(";
@@ -611,17 +611,17 @@ void CodeGenC::PrintVecBinaryOp(
}
void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
- int lanes = op->type.lanes();
+ int lanes = op->dtype.lanes();
// delcare type.
- if (op->type.lanes() == 1) {
- std::string ref = GetBufferRef(op->type, op->buffer_var.get(), op->index);
+ if (op->dtype.lanes() == 1) {
+ std::string ref = GetBufferRef(op->dtype, op->buffer_var.get(), op->index);
os << ref;
} else {
CHECK(is_one(op->predicate))
<< "predicated load is not supported";
Expr base;
- if (GetRamp1Base(op->index, op->type.lanes(), &base)) {
- std::string ref = GetVecLoad(op->type, op->buffer_var.get(), base);
+ if (GetRamp1Base(op->index, op->dtype.lanes(), &base)) {
+ std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base);
os << ref;
} else {
// The assignment below introduces side-effect, and the resulting value cannot
@@ -631,16 +631,16 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
// load seperately.
std::string svalue = GetUniqueName("_");
this->PrintIndent();
- this->PrintType(op->type, stream);
+ this->PrintType(op->dtype, stream);
stream << ' ' << svalue << ";\n";
- std::string sindex = SSAGetID(PrintExpr(op->index), op->index.type());
+ std::string sindex = SSAGetID(PrintExpr(op->index), op->index.dtype());
std::string vid = GetVarID(op->buffer_var.get());
- Type elem_type = op->type.element_of();
+ DataType elem_type = op->dtype.element_of();
for (int i = 0; i < lanes; ++i) {
std::ostringstream value_temp;
if (!HandleTypeMatch(op->buffer_var.get(), elem_type)) {
value_temp << "((";
- if (op->buffer_var.get()->type.is_handle()) {
+ if (op->buffer_var.get()->dtype.is_handle()) {
auto it = alloc_storage_scope_.find(op->buffer_var.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, value_temp);
@@ -653,9 +653,9 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
value_temp << vid;
}
value_temp << '[';
- PrintVecElemLoad(sindex, op->index.type(), i, value_temp);
+ PrintVecElemLoad(sindex, op->index.dtype(), i, value_temp);
value_temp << ']';
- PrintVecElemStore(svalue, op->type, i, value_temp.str());
+ PrintVecElemStore(svalue, op->dtype, i, value_temp.str());
}
os << svalue;
EndScope(vec_scope);
@@ -664,7 +664,7 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
}
void CodeGenC::VisitStmt_(const Store* op) {
- Type t = op->value.type();
+ DataType t = op->value.dtype();
if (t.lanes() == 1) {
std::string value = this->PrintExpr(op->value);
std::string ref = this->GetBufferRef(t, op->buffer_var.get(), op->index);
@@ -683,15 +683,15 @@ void CodeGenC::VisitStmt_(const Store* op) {
int vec_scope = BeginScope();
// store elements seperately
- std::string index = SSAGetID(PrintExpr(op->index), op->index.type());
- std::string value = SSAGetID(PrintExpr(op->value), op->value.type());
+ std::string index = SSAGetID(PrintExpr(op->index), op->index.dtype());
+ std::string value = SSAGetID(PrintExpr(op->value), op->value.dtype());
std::string vid = GetVarID(op->buffer_var.get());
for (int i = 0; i < t.lanes(); ++i) {
this->PrintIndent();
- Type elem_type = t.element_of();
+ DataType elem_type = t.element_of();
if (!HandleTypeMatch(op->buffer_var.get(), elem_type)) {
stream << "((";
- if (op->buffer_var.get()->type.is_handle()) {
+ if (op->buffer_var.get()->dtype.is_handle()) {
auto it = alloc_storage_scope_.find(op->buffer_var.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream);
@@ -704,9 +704,9 @@ void CodeGenC::VisitStmt_(const Store* op) {
stream << vid;
}
stream << '[';
- PrintVecElemLoad(index, op->index.type(), i, stream);
+ PrintVecElemLoad(index, op->index.dtype(), i, stream);
stream << "] = ";
- PrintVecElemLoad(value, op->value.type(), i, stream);
+ PrintVecElemLoad(value, op->value.dtype(), i, stream);
stream << ";\n";
}
EndScope(vec_scope);
@@ -723,7 +723,7 @@ void CodeGenC::VisitExpr_(const Let* op, std::ostream& os) { // NOLINT(*)
void CodeGenC::VisitExpr_(const Ramp* op, std::ostream& os) { // NOLINT(*)
// constraint of current logic
- CHECK_EQ(op->base.type(), Int(32));
+ CHECK_EQ(op->base.dtype(), DataType::Int(32));
os << "((int" << op->lanes << ")(";
for (int i = 0; i < op->lanes; i++) {
os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i <<")";
@@ -758,7 +758,7 @@ void CodeGenC::VisitStmt_(const LetStmt* op) {
var_idmap_[op->var.get()] = value;
} else {
PrintIndent();
- if (op->var.type() == Handle() &&
+ if (op->var.dtype() == DataType::Handle() &&
handle_data_type_.count(op->var.get())) {
PrintType(handle_data_type_.at(op->var.get()), stream);
stream << "* "
@@ -767,7 +767,7 @@ void CodeGenC::VisitStmt_(const LetStmt* op) {
PrintType(handle_data_type_.at(op->var.get()), stream);
stream << "*)" << value << ";\n";
} else {
- PrintType(op->var.type(), this->stream);
+ PrintType(op->var.dtype(), this->stream);
this->stream << ' '
<< AllocVarID(op->var.get())
<< " = " << value << ";\n";
@@ -784,7 +784,7 @@ void CodeGenC::VisitStmt_(const Allocate* op) {
CHECK_EQ(op->free_function, "nop");
std::string new_data = PrintExpr(op->new_expr);
this->PrintIndent();
- PrintType(op->type, stream);
+ PrintType(op->dtype, stream);
stream << "* "<< vid << '=' << new_data << ";\n";
} else {
this->PrintIndent();
@@ -795,11 +795,11 @@ void CodeGenC::VisitStmt_(const Allocate* op) {
std::string scope = alloc_storage_scope_.at(buffer);
PrintStorageScope(scope, stream);
stream << ' ';
- PrintType(op->type, stream);
+ PrintType(op->dtype, stream);
stream << ' '<< vid << '['
<< constant_size << "];\n";
}
- RegisterHandleType(op->buffer_var.get(), op->type);
+ RegisterHandleType(op->buffer_var.get(), op->dtype);
this->PrintStmt(op->body);
}
@@ -841,7 +841,7 @@ void CodeGenC::VisitStmt_(const For* op) {
std::string vid = AllocVarID(op->loop_var.get());
CHECK(is_zero(op->min));
stream << "for (";
- PrintType(op->loop_var.type(), stream);
+ PrintType(op->loop_var.dtype(), stream);
stream << ' ' << vid << " = 0; "
<< vid << " < " << extent
<< "; ++" << vid << ") {\n";
@@ -890,7 +890,7 @@ void CodeGenC::VisitStmt_(const Evaluate *op) {
CHECK_EQ(call->args.size(), 4);
std::string value = PrintExpr(call->args[3]);
std::string ref = GetStructRef(
- call->args[3].type(),
+ call->args[3].dtype(),
call->args[0],
call->args[1],
call->args[2].as<IntImm>()->value);
diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h
index 8701cda..b8d3570 100644
--- a/src/codegen/codegen_c.h
+++ b/src/codegen/codegen_c.h
@@ -147,7 +147,7 @@ class CodeGenC :
* \param t The type representation.
* \param os The stream to print the ctype into
*/
- virtual void PrintType(Type t, std::ostream& os); // NOLINT(*)
+ virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*)
/*!
* \brief Print expr representing the thread tag
* \param IterVar iv The thread index to be binded;
@@ -157,51 +157,51 @@ class CodeGenC :
virtual void PrintStorageSync(const Call* op); // NOLINT(*)
// Binary vector op.
virtual void PrintVecBinaryOp(
- const std::string&op, Type op_type,
+ const std::string&op, DataType op_type,
Expr lhs, Expr rhs, std::ostream& os); // NOLINT(*)
// print vector load
- virtual std::string GetVecLoad(Type t, const Variable* buffer, Expr base);
+ virtual std::string GetVecLoad(DataType t, const Variable* buffer, Expr base);
// print vector store
virtual void PrintVecStore(const Variable* buffer,
- Type t, Expr base,
+ DataType t, Expr base,
const std::string& value); // NOLINT(*)
// print load of single element
virtual void PrintVecElemLoad(
- const std::string& vec, Type t, int i, std::ostream& os); // NOLINT(*)
+ const std::string& vec, DataType t, int i, std::ostream& os); // NOLINT(*)
// print store of single element.
virtual void PrintVecElemStore(
- const std::string& vec, Type t, int i, const std::string& value);
+ const std::string& vec, DataType t, int i, const std::string& value);
// Get a cast type from to
- virtual std::string CastFromTo(std::string value, Type from, Type target);
+ virtual std::string CastFromTo(std::string value, DataType from, DataType target);
protected:
// Print reference to struct location
std::string GetStructRef(
- Type t, const Expr& buffer, const Expr& index, int kind);
+ DataType t, const Expr& buffer, const Expr& index, int kind);
// print reference to a buffer as type t in index.
virtual std::string GetBufferRef(
- Type t, const Variable* buffer, Expr index);
+ DataType t, const Variable* buffer, Expr index);
/*!
* \brief If buffer is allocated as type t.
* \param buf_var The buffer variable.
* \param t The type to be checked.
*/
- bool HandleTypeMatch(const Variable* buf_var, Type t) const;
+ bool HandleTypeMatch(const Variable* buf_var, DataType t) const;
/*!
* \brief Register the data type of buf_var
* \param buf_var The buffer variable.
* \param t The type to be checked.
*/
- void RegisterHandleType(const Variable* buf_var, Type t);
+ void RegisterHandleType(const Variable* buf_var, DataType t);
// override
void PrintSSAAssign(
- const std::string& target, const std::string& src, Type t) final;
+ const std::string& target, const std::string& src, DataType t) final;
/*! \brief restrict keyword */
std::string restrict_keyword_{""};
/*! \brief the storage scope of allocation */
std::unordered_map<const Variable*, std::string> alloc_storage_scope_;
/*! \brief the data type of allocated buffers */
- std::unordered_map<const Variable*, Type> handle_data_type_;
+ std::unordered_map<const Variable*, DataType> handle_data_type_;
/*! \brief reserves common C keywords */
void ReserveKeywordsAsUnique();
diff --git a/src/codegen/codegen_c_host.cc b/src/codegen/codegen_c_host.cc
index 9c099a4..f2c54c2 100644
--- a/src/codegen/codegen_c_host.cc
+++ b/src/codegen/codegen_c_host.cc
@@ -48,7 +48,7 @@ void CodeGenCHost::AddFunction(LoweredFunc f) {
ReserveKeywordsAsUnique();
// add to alloc buffer type.
for (const auto & kv : f->handle_data_type) {
- RegisterHandleType(kv.first.get(), kv.second.type());
+ RegisterHandleType(kv.first.get(), kv.second.dtype());
}
this->stream << "#ifdef __cplusplus\n";
@@ -59,7 +59,7 @@ void CodeGenCHost::AddFunction(LoweredFunc f) {
Var v = f->args[i];
std::string vid = AllocVarID(v.get());
if (i != 0) stream << ", ";
- if (v.type().is_handle()) {
+ if (v.dtype().is_handle()) {
auto it = alloc_storage_scope_.find(v.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream);
@@ -77,7 +77,7 @@ void CodeGenCHost::AddFunction(LoweredFunc f) {
stream << ' ' << restrict_keyword_;
}
} else {
- PrintType(v.type(), stream);
+ PrintType(v.dtype(), stream);
}
stream << ' ' << vid;
}
@@ -96,14 +96,14 @@ std::string CodeGenCHost::Finish() {
return CodeGenC::Finish();
}
-void CodeGenCHost::PrintType(Type t, std::ostream& os) { // NOLINT(*)
+void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
CHECK_EQ(lanes, 1)
<< "does not support vector types";
os << "void*"; return;
}
- if (t == Bool()) {
+ if (t == DataType::Bool()) {
os << "bool"; return;
}
bool fail = false;
@@ -145,7 +145,7 @@ void CodeGenCHost::PrintType(Type t, std::ostream& os) { // NOLINT(*)
void CodeGenCHost::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
std::string v = PrintExpr(op->value);
os << "((";
- PrintType(op->type, os);
+ PrintType(op->dtype, os);
os << ")(";
for (int i = 0; i < op->lanes; ++i) {
if (i != 0) os << ", ";
@@ -268,10 +268,10 @@ inline void CodeGenCHost::PrintTernaryCondExpr(const T* op,
std::ostream& os) { // NOLINT(*)
std::ostringstream temp_a;
VisitExpr(op->a, temp_a);
- std::string a_id = SSAGetID(temp_a.str(), op->a.type());
+ std::string a_id = SSAGetID(temp_a.str(), op->a.dtype());
std::ostringstream temp_b;
VisitExpr(op->b, temp_b);
- std::string b_id = SSAGetID(temp_b.str(), op->b.type());
+ std::string b_id = SSAGetID(temp_b.str(), op->b.dtype());
os << "((" << a_id << ") " << compare << " (" << b_id << ") "
<< "? (" << a_id << ") : (" << b_id << "))";
diff --git a/src/codegen/codegen_c_host.h b/src/codegen/codegen_c_host.h
index 80e359c..44f8385 100644
--- a/src/codegen/codegen_c_host.h
+++ b/src/codegen/codegen_c_host.h
@@ -39,7 +39,7 @@ class CodeGenCHost final : public CodeGenC {
void AddFunction(LoweredFunc f);
std::string Finish();
- void PrintType(Type t, std::ostream& os) final; // NOLINT(*)
+ void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
// overload visitor functions
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc
index 6656fa0..06b542a 100644
--- a/src/codegen/codegen_cuda.cc
+++ b/src/codegen/codegen_cuda.cc
@@ -105,10 +105,10 @@ void CodeGenCUDA::VisitStmt_(const ir::For* op) {
void CodeGenCUDA::BindThreadIndex(const IterVar& iv) {
CHECK(!var_idmap_.count(iv->var.get()));
var_idmap_[iv->var.get()] =
- CastFromTo(iv->thread_tag, UInt(32), iv->var.type());
+ CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype());
}
-void CodeGenCUDA::PrintType(Type t, std::ostream& os) { // NOLINT(*)
+void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
CHECK_EQ(lanes, 1)
@@ -137,7 +137,7 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) { // NOLINT(*)
if (!fail && (lanes >= 2 && lanes <= 4)) {
os << lanes; return;
}
- } else if (t == Bool()) {
+ } else if (t == DataType::Bool()) {
os << "bool"; return;
} else if (t.is_uint() || t.is_int()) {
if (t.is_uint()) {
@@ -199,7 +199,7 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) { // NOLINT(*)
}
void CodeGenCUDA::PrintVecBinaryOp(
- const std::string&op, Type t,
+ const std::string&op, DataType t,
Expr lhs, Expr rhs, std::ostream& os) { // NOLINT(*)
// unpacking operations.
int lanes = t.lanes();
@@ -210,8 +210,8 @@ void CodeGenCUDA::PrintVecBinaryOp(
int vec_scope = BeginScope();
// default: unpack into individual ops.
- std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.type());
- std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.type());
+ std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype());
+ std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype());
std::string sret = GetUniqueName("_");
{
// delcare type.
@@ -223,15 +223,15 @@ void CodeGenCUDA::PrintVecBinaryOp(
std::ostringstream value_temp;
if (isalpha(op[0])) {
value_temp << op << "(";
- PrintVecElemLoad(vlhs, lhs.type(), i, value_temp);
+ PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
value_temp << ", ";
- PrintVecElemLoad(vrhs, rhs.type(), i, value_temp);
+ PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
value_temp << ")";
} else {
value_temp << "(";
- PrintVecElemLoad(vlhs, lhs.type(), i, value_temp);
+ PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
value_temp << op;
- PrintVecElemLoad(vrhs, rhs.type(), i, value_temp);
+ PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
value_temp << ")";
}
PrintVecElemStore(sret, t, i, value_temp.str());
@@ -242,7 +242,7 @@ void CodeGenCUDA::PrintVecBinaryOp(
}
void CodeGenCUDA::PrintVecElemLoad(
- const std::string& vec, Type t, int i, std::ostream& os) { // NOLINT(*)
+ const std::string& vec, DataType t, int i, std::ostream& os) { // NOLINT(*)
static const char access[] = {'x', 'y', 'z', 'w'};
CHECK(i >= 0 && i < 4);
if (t.is_int() && t.bits() == 8) {
@@ -253,7 +253,7 @@ void CodeGenCUDA::PrintVecElemLoad(
}
void CodeGenCUDA::PrintVecElemStore(
- const std::string& vec, Type t, int i, const std::string& value) {
+ const std::string& vec, DataType t, int i, const std::string& value) {
this->PrintIndent();
static const char access[] = {'x', 'y', 'z', 'w'};
CHECK(i >= 0 && i < 4);
@@ -390,7 +390,7 @@ void CodeGenCUDA::VisitStmt_(const Allocate* op) {
CHECK_EQ(op->free_function, "nop");
std::string new_data = PrintExpr(op->new_expr);
this->PrintIndent();
- PrintType(op->type, stream);
+ PrintType(op->dtype, stream);
stream << "* "<< vid << '=' << new_data << ";\n";
} else {
this->PrintIndent();
@@ -401,23 +401,27 @@ void CodeGenCUDA::VisitStmt_(const Allocate* op) {
std::string scope = alloc_storage_scope_.at(buffer);
if (scope.find("wmma.") == 0) {
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
- CHECK(op->type == Float(16) || op->type == Int(8) || op->type == UInt(8))
+ CHECK(op->dtype == DataType::Float(16) ||
+ op->dtype == DataType::Int(8) ||
+ op->dtype == DataType::UInt(8))
<< "Matrix_a and matrix_b only support half or char or unsigned char type for now";
} else {
- CHECK(op->type == Float(16) || op->type == Float(32) || op->type == Int(32))
+ CHECK(op->dtype == DataType::Float(16) ||
+ op->dtype == DataType::Float(32) ||
+ op->dtype == DataType::Int(32))
<< "Accumulator only support half, float and int type for now";
}
constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
- PrintWmmaScope(scope, op->type, buffer, stream);
+ PrintWmmaScope(scope, op->dtype, buffer, stream);
} else {
PrintStorageScope(scope, stream);
stream << ' ';
- PrintType(op->type, stream);
+ PrintType(op->dtype, stream);
}
stream << ' '<< vid << '['
<< constant_size << "];\n";
}
- RegisterHandleType(op->buffer_var.get(), op->type);
+ RegisterHandleType(op->buffer_var.get(), op->dtype);
this->PrintStmt(op->body);
}
@@ -449,7 +453,7 @@ void CodeGenCUDA::VisitExpr_(const Ramp* op, std::ostream& os) {
}
void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
- if (op->type.is_int() && op->type.bits() == 8 && op->lanes == 4) {
+ if (op->dtype.is_int() && op->dtype.bits() == 8 && op->lanes == 4) {
// make_int8x4
const int64_t *p = as_const_int(op->value);
CHECK(p);
@@ -461,7 +465,7 @@ void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLIN
std::string v = PrintExpr(op->value);
os << "make_";
- PrintType(op->type, os);
+ PrintType(op->dtype, os);
os << '(';
for (int i = 0; i < op->lanes; ++i) {
if (i != 0) os << ", ";
@@ -473,11 +477,11 @@ void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLIN
void CodeGenCUDA::VisitExpr_(const Shuffle* op, std::ostream &os) {
std::vector<std::string> to_shuffle(op->vectors.size());
for (int i = 0, e = op->vectors.size(); i < e; ++i) {
- CHECK(op->vectors[i].type().lanes() == 1) << "Only scalars can be shuffled in CUDA!";
+ CHECK(op->vectors[i].dtype().lanes() == 1) << "Only scalars can be shuffled in CUDA!";
to_shuffle[i] = PrintExpr(op->vectors[i]);
}
os << "make_";
- PrintType(op->type, os);
+ PrintType(op->dtype, os);
os << '(';
for (int i = 0, e = op->indices.size(); i < e; ++i) {
const int64_t *val = as_const_int(op->indices[i]);
@@ -489,21 +493,21 @@ void CodeGenCUDA::VisitExpr_(const Shuffle* op, std::ostream &os) {
}
inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*)
- switch (op->type.bits()) {
+ switch (op->dtype.bits()) {
case 64: case 32: {
std::ostringstream temp;
if (std::isinf(op->value)) {
if (op->value < 0) {
temp << "-";
}
- temp << ((op->type.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF");
+ temp << ((op->dtype.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF");
p->need_math_constants_h_ = true;
} else if (std::isnan(op->value)) {
- temp << ((op->type.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN");
+ temp << ((op->dtype.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN");
p->need_math_constants_h_ = true;
} else {
temp << std::scientific << op->value;
- if (op->type.bits() == 32) temp << 'f';
+ if (op->dtype.bits() == 32) temp << 'f';
}
p->MarkConst(temp.str());
os << temp.str();
@@ -514,7 +518,7 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p) { /
os << '(' << std::scientific << op->value << 'f' << ')';
break;
}
- default: LOG(FATAL) << "Bad bit-width for float: " << op->type << "\n";
+ default: LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
}
}
@@ -523,7 +527,7 @@ void CodeGenCUDA::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*
PrintConst(op, os, this);
}
-void CodeGenCUDA::PrintWmmaScope(const std::string &scope, Type t,
+void CodeGenCUDA::PrintWmmaScope(const std::string &scope, DataType t,
const Variable* variable, std::ostream &os) {
std::stringstream type;
PrintType(t, type);
diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h
index efb3004..74d6fba 100644
--- a/src/codegen/codegen_cuda.h
+++ b/src/codegen/codegen_cuda.h
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -47,13 +47,13 @@ class CodeGenCUDA final : public CodeGenC {
void PrintStorageSync(const Call* op) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintVecBinaryOp(
- const std::string&op, Type t,
+ const std::string&op, DataType t,
Expr lhs, Expr rhs, std::ostream& os) final; // NOLINT(*)
- void PrintType(Type t, std::ostream& os) final; // NOLINT(*)
+ void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
void PrintVecElemLoad(
- const std::string& vec, Type t, int i, std::ostream& os) final; // NOLINT(*)
+ const std::string& vec, DataType t, int i, std::ostream& os) final; // NOLINT(*)
void PrintVecElemStore(
- const std::string& vec, Type t, int i, const std::string& value) final;
+ const std::string& vec, DataType t, int i, const std::string& value) final;
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
// overload visitor
void VisitExpr_(const Ramp* op, std::ostream& os) final; // NOLINT(*)
@@ -84,8 +84,10 @@ class CodeGenCUDA final : public CodeGenC {
std::unordered_map<const Variable*, std::string> fragment_shapes;
std::unordered_map<const Variable*, std::string> fragment_layouts;
friend void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p);
- void PrintWmmaScope(const std::string& scope, Type t, const Variable* variable, std::ostream& os);
- int32_t GetWmmaFragmentSize(const std::string &scope, const Variable* variable, int32_t size);
+ void PrintWmmaScope(
+ const std::string& scope, DataType t, const Variable* variable, std::ostream& os);
+ int32_t GetWmmaFragmentSize(
+ const std::string &scope, const Variable* variable, int32_t size);
};
} // namespace codegen
diff --git a/src/codegen/codegen_metal.cc b/src/codegen/codegen_metal.cc
index 311bdcb..f4ff014 100644
--- a/src/codegen/codegen_metal.cc
+++ b/src/codegen/codegen_metal.cc
@@ -36,7 +36,7 @@ void CodeGenMetal::InitFuncState(LoweredFunc f) {
CodeGenC::InitFuncState(f);
// analyze the data;
for (Var arg : f->args) {
- if (arg.type().is_handle()) {
+ if (arg.dtype().is_handle()) {
alloc_storage_scope_[arg.get()] = "global";
}
}
@@ -57,7 +57,7 @@ void CodeGenMetal::AddFunction(LoweredFunc f) {
GetUniqueName("_");
// add to alloc buffer type.
for (const auto & kv : f->handle_data_type) {
- RegisterHandleType(kv.first.get(), kv.second.type());
+ RegisterHandleType(kv.first.get(), kv.second.dtype());
}
// Function header.
this->stream << "kernel void " << f->name << "(\n";
@@ -65,7 +65,7 @@ void CodeGenMetal::AddFunction(LoweredFunc f) {
size_t num_buffer = 0;
for (size_t i = 0; i < f->args.size(); ++i, ++num_buffer) {
Var v = f->args[i];
- if (!v.type().is_handle()) break;
+ if (!v.dtype().is_handle()) break;
stream << " ";
std::string vid = AllocVarID(v.get());
auto it = alloc_storage_scope_.find(v.get());
@@ -76,7 +76,7 @@ void CodeGenMetal::AddFunction(LoweredFunc f) {
PrintType(handle_data_type_.at(v.get()), stream);
stream << "*";
} else {
- PrintType(v.type(), stream);
+ PrintType(v.dtype(), stream);
}
stream << ' ' << vid
<< " [[ buffer(" << i << ") ]],\n";
@@ -92,19 +92,19 @@ void CodeGenMetal::AddFunction(LoweredFunc f) {
decl_stream << "struct " << arg_buf_type << " {\n";
for (size_t i = num_buffer; i < f->args.size(); ++i) {
Var v = f->args[i];
- CHECK(!v.type().is_handle());
+ CHECK(!v.dtype().is_handle());
std::string vid = AllocVarID(v.get());
std::ostringstream vref;
- if (v.type().bits() == 32) {
+ if (v.dtype().bits() == 32) {
decl_stream << " ";
- PrintType(v.type(), decl_stream);
+ PrintType(v.dtype(), decl_stream);
decl_stream << " " << vid << ";\n";
vref << varg << "." << vid;
} else {
// For non 32bit type, ref through arg union.
decl_stream << " __TVMArgUnion " << vid << ";\n";
vref << varg << "." << vid << ".v_";
- PrintType(v.type(), vref);
+ PrintType(v.dtype(), vref);
}
var_idmap_[v.get()] = vref.str();
}
@@ -121,10 +121,10 @@ void CodeGenMetal::AddFunction(LoweredFunc f) {
if (work_dim != 0) {
// use ushort by default for now
stream << " ";
- PrintType(UInt(thread_index_bits_, work_dim), stream);
+ PrintType(DataType::UInt(thread_index_bits_, work_dim), stream);
stream << " blockIdx [[threadgroup_position_in_grid]],\n";
stream << " ";
- PrintType(UInt(thread_index_bits_, work_dim), stream);
+ PrintType(DataType::UInt(thread_index_bits_, work_dim), stream);
stream << " threadIdx [[thread_position_in_threadgroup]]\n";
}
// bind thread axis
@@ -135,7 +135,7 @@ void CodeGenMetal::AddFunction(LoweredFunc f) {
vname = vname.substr(0, iv->thread_tag.length() - 2);
}
var_idmap_[iv->var.get()] =
- CastFromTo(vname, UInt(thread_index_bits_), iv->var.type());
+ CastFromTo(vname, DataType::UInt(thread_index_bits_), iv->var.dtype());
}
// the function scope.
stream << ") {\n";
@@ -149,17 +149,17 @@ void CodeGenMetal::AddFunction(LoweredFunc f) {
void CodeGenMetal::BindThreadIndex(const IterVar& iv) {
CHECK(!var_idmap_.count(iv->var.get()));
var_idmap_[iv->var.get()] =
- CastFromTo(iv->thread_tag, UInt(thread_index_bits_), iv->var.type());
+ CastFromTo(iv->thread_tag, DataType::UInt(thread_index_bits_), iv->var.dtype());
}
-void CodeGenMetal::PrintType(Type t, std::ostream& os) { // NOLINT(*)
+void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
CHECK_EQ(lanes, 1)
<< "do not yet support vector types";
os << "void*"; return;
}
- if (t == Bool()) {
+ if (t == DataType::Bool()) {
os << "bool"; return;
}
bool fail = false;
@@ -210,13 +210,13 @@ void CodeGenMetal::PrintStorageSync(const Call* op) {
}
void CodeGenMetal::PrintVecElemLoad(const std::string& vec,
- Type t, int i,
+ DataType t, int i,
std::ostream& os) { // NOLINT(*)
os << vec << "[" << i << "]";
}
void CodeGenMetal::PrintVecElemStore(const std::string& vec,
- Type t, int i,
+ DataType t, int i,
const std::string& value) {
this->PrintIndent();
stream << vec << "[" << i << "]"
@@ -236,7 +236,7 @@ void CodeGenMetal::PrintStorageScope(
void CodeGenMetal::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
std::string v = PrintExpr(op->value);
- PrintType(op->type, os);
+ PrintType(op->dtype, os);
os << "(";
for (int i = 0; i < op->lanes; ++i) {
if (i != 0) os << ", ";
@@ -249,7 +249,7 @@ void CodeGenMetal::VisitExpr_(const Call* op, std::ostream& os) { // NOLINT(*)
if (op->is_intrinsic(Call::reinterpret)) {
// generate as_type<TYPE>(ARG)
os << "(as_type<";
- this->PrintType(op->type, os);
+ this->PrintType(op->dtype, os);
os << ">(";
this->PrintExpr(op->args[0], os);
os << "))";
diff --git a/src/codegen/codegen_metal.h b/src/codegen/codegen_metal.h
index c009cd1..728e3e0 100644
--- a/src/codegen/codegen_metal.h
+++ b/src/codegen/codegen_metal.h
@@ -41,14 +41,14 @@ class CodeGenMetal final : public CodeGenC {
void InitFuncState(LoweredFunc f) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const Call* op) final; // NOLINT(*)
- void PrintType(Type t, std::ostream& os) final; // NOLINT(*)
+ void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
// print load of single element
void PrintVecElemLoad(
- const std::string& vec, Type t, int i, std::ostream& os) final; // NOLINT(*)
+ const std::string& vec, DataType t, int i, std::ostream& os) final; // NOLINT(*)
// print store of single element.
void PrintVecElemStore(
- const std::string& vec, Type t, int i, const std::string& value) final;
+ const std::string& vec, DataType t, int i, const std::string& value) final;
// overload visitor
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
diff --git a/src/codegen/codegen_opencl.cc b/src/codegen/codegen_opencl.cc
index 49dccb1..ae43419 100644
--- a/src/codegen/codegen_opencl.cc
+++ b/src/codegen/codegen_opencl.cc
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -39,7 +39,7 @@ CodeGenOpenCL::CodeGenOpenCL() {
void CodeGenOpenCL::InitFuncState(LoweredFunc f) {
CodeGenC::InitFuncState(f);
for (Var arg : f->args) {
- if (arg.type().is_handle()) {
+ if (arg.dtype().is_handle()) {
alloc_storage_scope_[arg.get()] = "global";
}
}
@@ -89,17 +89,17 @@ void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) {
os << "get_group_id(" << ts.dim_index << ")";
}
var_idmap_[iv->var.get()] =
- CastFromTo(os.str(), UInt(64), iv->var.type());
+ CastFromTo(os.str(), DataType::UInt(64), iv->var.dtype());
}
-void CodeGenOpenCL::PrintType(Type t, std::ostream& os) { // NOLINT(*)
+void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
CHECK_EQ(lanes, 1)
<< "do not yet support vector types";
os << "void*"; return;
}
- if (t == Bool()) {
+ if (t == DataType::Bool()) {
os << "bool"; return;
}
bool fail = false;
@@ -144,7 +144,7 @@ void CodeGenOpenCL::PrintType(Type t, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Cannot convert type " << t << " to OpenCL type";
}
-void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, Type t,
+void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, DataType t,
Expr base, std::ostream& os) { // NOLINT(*)
if (!HandleTypeMatch(buffer, t.element_of())) {
os << '(';
@@ -160,7 +160,7 @@ void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, Type t,
PrintExpr(base, os);
}
std::string CodeGenOpenCL::GetVecLoad(
- Type t, const Variable* buffer, Expr base) {
+ DataType t, const Variable* buffer, Expr base) {
std::ostringstream os;
os << "vload" << t.lanes() << "(0, ";
PrintVecAddr(buffer, t, base, os);
@@ -169,7 +169,7 @@ std::string CodeGenOpenCL::GetVecLoad(
}
void CodeGenOpenCL::PrintVecStore(const Variable* buffer,
- Type t, Expr base,
+ DataType t, Expr base,
const std::string& value) {
this->PrintIndent();
stream << "vstore" << t.lanes() << "(" << value << ", 0, ";
@@ -199,7 +199,7 @@ void CodeGenOpenCL::PrintStorageScope(
}
}
-std::string CodeGenOpenCL::CastFromTo(std::string value, Type from, Type target) {
+std::string CodeGenOpenCL::CastFromTo(std::string value, DataType from, DataType target) {
if (from == target) return value;
std::ostringstream os;
if (target.lanes() == 1) {
@@ -218,7 +218,7 @@ std::string CodeGenOpenCL::CastFromTo(std::string value, Type from, Type target)
void CodeGenOpenCL::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
std::string v = PrintExpr(op->value);
os << "((";
- PrintType(op->type, os);
+ PrintType(op->dtype, os);
os << ")(";
for (int i = 0; i < op->lanes; ++i) {
if (i != 0) os << ", ";
@@ -232,7 +232,7 @@ void CodeGenOpenCL::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
* add a cast */
if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
os << "(";
- PrintType(op->args[2].type(), os);
+ PrintType(op->args[2].dtype(), os);
os << ")";
}
CodeGenC::VisitExpr_(op, os);
@@ -242,7 +242,7 @@ void CodeGenOpenCL::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT(
/* Return type of ternary expression is not always same as its sub-expressions,
* add a cast */
os << "(";
- PrintType(op->true_value.type(), os);
+ PrintType(op->true_value.dtype(), os);
os << ")";
CodeGenC::VisitExpr_(op, os);
}
diff --git a/src/codegen/codegen_opencl.h b/src/codegen/codegen_opencl.h
index 32f4501..36324eb 100644
--- a/src/codegen/codegen_opencl.h
+++ b/src/codegen/codegen_opencl.h
@@ -43,16 +43,16 @@ class CodeGenOpenCL final : public CodeGenC {
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const Call* op) final; // NOLINT(*)
- void PrintType(Type t, std::ostream& os) final; // NOLINT(*)
- std::string GetVecLoad(Type t, const Variable* buffer,
+ void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
+ std::string GetVecLoad(DataType t, const Variable* buffer,
Expr base) final;
void PrintVecStore(const Variable* buffer,
- Type t, Expr base,
+ DataType t, Expr base,
const std::string& value) final; // NOLINT(*)
// the address of load/store
- void PrintVecAddr(const Variable* buffer, Type t,
+ void PrintVecAddr(const Variable* buffer, DataType t,
Expr base, std::ostream& os); // NOLINT(*)
- std::string CastFromTo(std::string value, Type from, Type target); // NOLINT(*)
+ std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*)
// overload visitor
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
diff --git a/src/codegen/codegen_opengl.cc b/src/codegen/codegen_opengl.cc
index 52e04db..db14be3 100644
--- a/src/codegen/codegen_opengl.cc
+++ b/src/codegen/codegen_opengl.cc
@@ -59,7 +59,7 @@ void CodeGenOpenGL::AddFunction(LoweredFunc f) {
GetUniqueName("_");
// add to alloc buffer type.
for (const auto& kv : f->handle_data_type) {
- RegisterHandleType(kv.first.get(), kv.second.type());
+ RegisterHandleType(kv.first.get(), kv.second.dtype());
}
// Allocate argument names. Store in `var_idmap_`.
@@ -93,7 +93,7 @@ void CodeGenOpenGL::AddFunction(LoweredFunc f) {
auto type_it = this->handle_data_type_.find(arg.get());
CHECK(type_it != this->handle_data_type_.cend()) << "Cannot find type.";
- auto type = Type2TVMType(type_it->second);
+ DLDataType type = type_it->second;
CHECK_EQ(type.lanes, 1) << "Vector type not supported.";
switch (type.code) {
@@ -129,7 +129,7 @@ void CodeGenOpenGL::AddFunction(LoweredFunc f) {
// Format: "uniform {type} {name};"
auto arg_name = GetVarID(arg.get());
- auto type = arg.get()->type;
+ auto type = arg.get()->dtype;
this->decl_stream << "uniform ";
PrintType(type, this->decl_stream);
@@ -207,7 +207,7 @@ std::string CodeGenOpenGL::TexelFetch(const Variable* buffer, Expr index) {
// Print a reference expression to a buffer.
// Format: texelFetch(buffer, index, 0).r
std::string CodeGenOpenGL::GetBufferRef(
- Type t, const Variable* buffer, Expr index) {
+ DataType t, const Variable* buffer, Expr index) {
CHECK_EQ(t.lanes(), 1) << "Vector type not supported.";
CHECK(HandleTypeMatch(buffer, t)) << "Type mismatch not supported.";
@@ -221,7 +221,7 @@ std::string CodeGenOpenGL::GetBufferRef(
}
}
-void CodeGenOpenGL::PrintType(Type t, std::ostream& os) {
+void CodeGenOpenGL::PrintType(DataType t, std::ostream& os) {
switch (t.code()) {
case kDLInt:
CHECK_EQ(t.bits(), 32) << "Only support 32-bit int.";
@@ -243,17 +243,17 @@ void CodeGenOpenGL::PrintType(Type t, std::ostream& os) {
// Codegen for immediate values
void CodeGenOpenGL::VisitExpr_(const IntImm* op, std::ostream& os) {
- CHECK_EQ(op->type, Int(32)) << "GLSL 3.0 only supports 32-bit ints.";
+ CHECK_EQ(op->dtype, DataType::Int(32)) << "GLSL 3.0 only supports 32-bit ints.";
CodeGenC::VisitExpr_(op, os);
}
void CodeGenOpenGL::VisitExpr_(const UIntImm* op, std::ostream& os) {
- CHECK_EQ(op->type, UInt(32)) << "GLSL 3.0 only supports 32-bit uints.";
+ CHECK_EQ(op->dtype, DataType::UInt(32)) << "GLSL 3.0 only supports 32-bit uints.";
CodeGenC::VisitExpr_(op, os);
}
void CodeGenOpenGL::VisitExpr_(const FloatImm* op, std::ostream& os) {
- CHECK_EQ(op->type, Float(32)) << "GLSL 3.0 only supports 32-bit floats.";
+ CHECK_EQ(op->dtype, DataType::Float(32)) << "GLSL 3.0 only supports 32-bit floats.";
CodeGenC::VisitExpr_(op, os);
}
@@ -273,7 +273,7 @@ void CodeGenOpenGL::VisitStmt_(const Evaluate* op) {
auto value = call->args[1];
// Doesn't support store to vector.
- auto type = value.type();
+ auto type = value.dtype();
CHECK_EQ(type.lanes(), 1)
<< "Vectorized store not implemented, type = " << type;
diff --git a/src/codegen/codegen_opengl.h b/src/codegen/codegen_opengl.h
index d18052f..46e87a8 100644
--- a/src/codegen/codegen_opengl.h
+++ b/src/codegen/codegen_opengl.h
@@ -45,8 +45,8 @@ class CodeGenOpenGL final : public CodeGenC {
void BindThreadIndex(const IterVar& iv) final;
void VisitStmt_(const Store* op) final;
std::string TexelFetch(const Variable* buffer, Expr index);
- std::string GetBufferRef(Type t, const Variable* buffer, Expr index) final;
- void PrintType(Type t, std::ostream& os) final; // NOLINT(*)
+ std::string GetBufferRef(DataType t, const Variable* buffer, Expr index) final;
+ void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
// Codegen for immediate values
void VisitExpr_(const IntImm* op, std::ostream& os) final; // NOLINT(*)
diff --git a/src/codegen/codegen_source_base.cc b/src/codegen/codegen_source_base.cc
index 9a9f525..7c4ed5b 100644
--- a/src/codegen/codegen_source_base.cc
+++ b/src/codegen/codegen_source_base.cc
@@ -52,7 +52,7 @@ std::string CodeGenSourceBase::GetUniqueName(std::string prefix) {
return prefix;
}
-std::string CodeGenSourceBase::SSAGetID(std::string src, Type t) {
+std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) {
if (name_alloc_map_.count(src)) return src;
auto it = ssa_assign_map_.find(src);
if (it != ssa_assign_map_.end()) {
diff --git a/src/codegen/codegen_source_base.h b/src/codegen/codegen_source_base.h
index e0608c6..7fd0eef 100644
--- a/src/codegen/codegen_source_base.h
+++ b/src/codegen/codegen_source_base.h
@@ -79,7 +79,7 @@ class CodeGenSourceBase {
* \param src The source expression
* \param t The type of the expression.
*/
- std::string SSAGetID(std::string src, Type t);
+ std::string SSAGetID(std::string src, DataType t);
/*!
* \brief get a unique name with the corresponding prefix
* \param prefix The prefix of the name
@@ -103,7 +103,7 @@ class CodeGenSourceBase {
* \param t The type of target.
*/
virtual void PrintSSAAssign(
- const std::string& target, const std::string& src, Type t) = 0;
+ const std::string& target, const std::string& src, DataType t) = 0;
/*! \brief the declaration stream */
std::ostringstream decl_stream;
diff --git a/src/codegen/codegen_vhls.cc b/src/codegen/codegen_vhls.cc
index 84329f9..40550d9 100644
--- a/src/codegen/codegen_vhls.cc
+++ b/src/codegen/codegen_vhls.cc
@@ -37,7 +37,7 @@ void CodeGenVivadoHLS::Init(bool output_ssa) {
this->stream << "#include <algorithm>\n\n";
}
-void CodeGenVivadoHLS::PrintType(Type t, std::ostream& os) {
+void CodeGenVivadoHLS::PrintType(DataType t, std::ostream& os) {
if (t.is_uint()) {
switch (t.bits()) {
case 8:
@@ -78,7 +78,7 @@ void CodeGenVivadoHLS::PreFunctionBody(LoweredFunc f) {
for (size_t i = 0; i < f->args.size(); ++i) {
Var v = f->args[i];
std::string vid = GetVarID(v.get());
- if (v.type().is_handle()) {
+ if (v.dtype().is_handle()) {
this->stream << "#pragma HLS INTERFACE m_axi port=" << vid << " offset=slave bundle=gmem\n";
}
this->stream << "#pragma HLS INTERFACE s_axilite port=" << vid << " bundle=control\n";
@@ -100,8 +100,8 @@ inline void PrintBinaryExpr(const T* op,
void CodeGenVivadoHLS::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT(*)
const char *opstr = "std::min";
- if (op->type.is_float()) {
- switch (op->type.bits()) {
+ if (op->dtype.is_float()) {
+ switch (op->dtype.bits()) {
case 32:
opstr = "fminf"; break;
case 64:
@@ -114,8 +114,8 @@ void CodeGenVivadoHLS::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT(
void CodeGenVivadoHLS::VisitExpr_(const Max *op, std::ostream& os) { // NOLINT(*)
const char *opstr = "std::max";
- if (op->type.is_float()) {
- switch (op->type.bits()) {
+ if (op->dtype.is_float()) {
+ switch (op->dtype.bits()) {
case 32:
opstr = "fmaxf"; break;
case 64:
diff --git a/src/codegen/codegen_vhls.h b/src/codegen/codegen_vhls.h
index 4ec7b10..e678edb 100644
--- a/src/codegen/codegen_vhls.h
+++ b/src/codegen/codegen_vhls.h
@@ -35,7 +35,7 @@ namespace codegen {
class CodeGenVivadoHLS final : public CodeGenC {
public:
void Init(bool output_ssa);
- void PrintType(Type t, std::ostream& os);
+ void PrintType(DataType t, std::ostream& os);
void AddFunction(LoweredFunc f);
void PreFunctionBody(LoweredFunc f);
void VisitExpr_(const Min *op, std::ostream& os);
diff --git a/src/codegen/intrin_rule.cc b/src/codegen/intrin_rule.cc
index f765c00..219b485 100644
--- a/src/codegen/intrin_rule.cc
+++ b/src/codegen/intrin_rule.cc
@@ -57,7 +57,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.rsqrt")
const Call* call = e.as<Call>();
CHECK(call != nullptr);
- auto one = make_const(call->args[0].type(), 1);
+ auto one = make_const(call->args[0].dtype(), 1);
*rv = one / sqrt(call->args[0]);
});
@@ -70,7 +70,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid")
const Call* call = e.as<Call>();
CHECK(call != nullptr);
- auto one = make_const(call->args[0].type(), 1);
+ auto one = make_const(call->args[0].dtype(), 1);
*rv = one / (one + exp(-call->args[0]));
});
diff --git a/src/codegen/intrin_rule.h b/src/codegen/intrin_rule.h
index 9f3bd79..581387d 100644
--- a/src/codegen/intrin_rule.h
+++ b/src/codegen/intrin_rule.h
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -37,10 +37,10 @@ using namespace ir;
// Add float suffix to the intrinsics
struct FloatSuffix {
- std::string operator()(Type t, std::string name) const {
- if (t == Float(32)) {
+ std::string operator()(DataType t, std::string name) const {
+ if (t == DataType::Float(32)) {
return name + 'f';
- } else if (t == Float(64)) {
+ } else if (t == DataType::Float(64)) {
return name;
} else {
return "";
@@ -50,7 +50,7 @@ struct FloatSuffix {
// Return the intrinsic name
struct Direct {
- std::string operator()(Type t, std::string name) const {
+ std::string operator()(DataType t, std::string name) const {
return name;
}
};
@@ -61,10 +61,10 @@ inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) {
Expr e = args[0];
const Call* call = e.as<Call>();
CHECK(call != nullptr);
- std::string name = T()(call->type, call->name);
+ std::string name = T()(call->dtype, call->name);
if (name.length() != 0) {
*rv = Call::make(
- call->type, name, call->args, Call::PureExtern);
+ call->dtype, name, call->args, Call::PureExtern);
} else {
*rv = e;
}
diff --git a/src/codegen/intrin_rule_cuda.cc b/src/codegen/intrin_rule_cuda.cc
index 4fed20f..3f6bc7b 100644
--- a/src/codegen/intrin_rule_cuda.cc
+++ b/src/codegen/intrin_rule_cuda.cc
@@ -28,7 +28,7 @@ namespace codegen {
namespace intrin {
// Add float suffix to the intrinsics, CUDA fast math.
struct CUDAMath {
- std::string operator()(Type t, std::string name) const {
+ std::string operator()(DataType t, std::string name) const {
if (t.lanes() == 1) {
if (t.is_float()) {
switch (t.bits()) {
@@ -44,7 +44,7 @@ struct CUDAMath {
};
struct CUDAFastMath : public CUDAMath {
- std::string operator()(Type t, std::string name) const {
+ std::string operator()(DataType t, std::string name) const {
if (t.lanes() == 1 && t.is_float() && t.bits() == 32) {
return "__" + name + 'f';
} else {
@@ -55,7 +55,7 @@ struct CUDAFastMath : public CUDAMath {
};
struct CUDAPopcount {
- std::string operator()(Type t, std::string name) const {
+ std::string operator()(DataType t, std::string name) const {
if (t.lanes() == 1 && t.is_uint()) {
switch (t.bits()) {
case 32: return "__popc";
@@ -68,7 +68,7 @@ struct CUDAPopcount {
};
struct CUDAShuffle {
- std::string operator()(Type t, std::string name) const {
+ std::string operator()(DataType t, std::string name) const {
return "__shfl";
}
};
diff --git a/src/codegen/intrin_rule_opencl.cc b/src/codegen/intrin_rule_opencl.cc
index 246747c..4b1d403 100644
--- a/src/codegen/intrin_rule_opencl.cc
+++ b/src/codegen/intrin_rule_opencl.cc
@@ -66,7 +66,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fmod")
// There is no warp shuffle instruction in standard OpenCL
// When shuffle is used, we assume it is intel's shuffle extension
struct IntelShuffle {
- std::string operator()(Type t, std::string name) const {
+ std::string operator()(DataType t, std::string name) const {
return "intel_sub_group_shuffle";
}
};
diff --git a/src/codegen/llvm/codegen_amdgpu.cc b/src/codegen/llvm/codegen_amdgpu.cc
index 491a304..f57a3ca 100644
--- a/src/codegen/llvm/codegen_amdgpu.cc
+++ b/src/codegen/llvm/codegen_amdgpu.cc
@@ -82,7 +82,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
<< "Can only handle constant size stack allocation in GPU";
StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
if (constant_size % 4 == 0 && info.alignment == 0) {
- info.alignment = GetTempAllocaAlignment(op->type, constant_size);
+ info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
}
// maximum necessary alignment in the AMD devices
if (info.alignment > 16) {
@@ -93,7 +93,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
// TODO(tqchen): for higher version of LLVM, local address space can be set.
llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
return builder_->CreateAlloca(
- LLVMType(op->type), ConstInt32(constant_size));
+ LLVMType(op->dtype), ConstInt32(constant_size));
});
if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100
@@ -108,7 +108,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
<< "Can only allocate shared or local memory inside kernel";
// Shared memory: address space == 3
const unsigned shared_address_space = 3;
- llvm::Type* type = llvm::ArrayType::get(LLVMType(op->type), constant_size);
+ llvm::Type* type = llvm::ArrayType::get(LLVMType(op->dtype), constant_size);
// Allocate shared memory in global, address_space = 3
llvm::GlobalVariable *global = new llvm::GlobalVariable(
*module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
@@ -122,7 +122,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
}
}
buf = builder_->CreatePointerCast(
- buf, LLVMType(op->type)->getPointerTo(
+ buf, LLVMType(op->dtype)->getPointerTo(
buf->getType()->getPointerAddressSpace()));
CHECK(!var_map_.count(op->buffer_var.get()));
var_map_[op->buffer_var.get()] = buf;
diff --git a/src/codegen/llvm/codegen_arm.cc b/src/codegen/llvm/codegen_arm.cc
index 9b21455..4c092df 100644
--- a/src/codegen/llvm/codegen_arm.cc
+++ b/src/codegen/llvm/codegen_arm.cc
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -61,14 +61,14 @@ Expr CodeGenARM::ARMPopcount(const Call *call) {
::llvm::Intrinsic::ID vpaddlu_id = ::llvm::Intrinsic::arm_neon_vpaddlu;
// Fallback to default llvm lowering rule if input type not a full vector or half vector length
- int total_size = call->type.bits() * call->type.lanes();
- if (!call->type.is_vector() || call->type.bits() == 8 ||
+ int total_size = call->dtype.bits() * call->dtype.lanes();
+ if (!call->dtype.is_vector() || call->dtype.bits() == 8 ||
(total_size != 128 && total_size != 64)) {
Array<Expr> vcnt_args;
- vcnt_args.push_back(ir::UIntImm::make(UInt(32), ctpop_id));
- vcnt_args.push_back(ir::UIntImm::make(UInt(32), 1));
+ vcnt_args.push_back(ir::UIntImm::make(DataType::UInt(32), ctpop_id));
+ vcnt_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1));
vcnt_args.push_back(e);
- return ir::Call::make(call->type, "llvm_intrin", vcnt_args, Call::PureIntrinsic);
+ return ir::Call::make(call->dtype, "llvm_intrin", vcnt_args, Call::PureIntrinsic);
}
// Popcount lowering rule:
@@ -77,9 +77,12 @@ Expr CodeGenARM::ARMPopcount(const Call *call) {
// to return back to original input type
// Dvisions are always divisible (number of bits = 64 or 128)
- Type uint8_type = Type(e.type().code(), 8, e.type().bits() * e.type().lanes() / 8);
- Type uint16_type = Type(uint8_type.code(), 16, uint8_type.bits() * uint8_type.lanes() / 16);
- Type uint32_type = Type(uint16_type.code(), 32, uint8_type.bits() * uint8_type.lanes() / 32);
+ DataType uint8_type = DataType(
+ e.dtype().code(), 8, e.dtype().bits() * e.dtype().lanes() / 8);
+ DataType uint16_type = DataType(
+ uint8_type.code(), 16, uint8_type.bits() * uint8_type.lanes() / 16);
+ DataType uint32_type = DataType(
+ uint16_type.code(), 32, uint8_type.bits() * uint8_type.lanes() / 32);
// Interpret input as vector of 8bit values
Expr input8 = reinterpret(uint8_type, e);
@@ -87,37 +90,37 @@ Expr CodeGenARM::ARMPopcount(const Call *call) {
const Call* c0 = input8.as<Call>();
CHECK(c0 != nullptr);
Array<Expr> vcnt8_args;
- vcnt8_args.push_back(ir::UIntImm::make(UInt(32), ctpop_id));
- vcnt8_args.push_back(ir::UIntImm::make(UInt(32), 1));
+ vcnt8_args.push_back(ir::UIntImm::make(DataType::UInt(32), ctpop_id));
+ vcnt8_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1));
vcnt8_args.push_back(input8);
Expr vcnt8 = ir::Call::make(uint8_type, "llvm_intrin", vcnt8_args, Call::PureIntrinsic);
// Accumulation 8->16bit
Array<Expr> vcnt16_args;
- vcnt16_args.push_back(ir::UIntImm::make(UInt(32), vpaddlu_id));
- vcnt16_args.push_back(ir::UIntImm::make(UInt(32), 1));
+ vcnt16_args.push_back(ir::UIntImm::make(DataType::UInt(32), vpaddlu_id));
+ vcnt16_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1));
vcnt16_args.push_back(vcnt8);
Expr vcnt16 = ir::Call::make(uint16_type, "llvm_intrin", vcnt16_args, Call::PureIntrinsic);
- if (call->type.bits() == 16) {
+ if (call->dtype.bits() == 16) {
return vcnt16;
}
// Accumulation 16->32bit
Array<Expr> vcnt32_args;
- vcnt32_args.push_back(ir::UIntImm::make(UInt(32), vpaddlu_id));
- vcnt32_args.push_back(ir::UIntImm::make(UInt(32), 1));
+ vcnt32_args.push_back(ir::UIntImm::make(DataType::UInt(32), vpaddlu_id));
+ vcnt32_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1));
vcnt32_args.push_back(vcnt16);
Expr vcnt32 = ir::Call::make(uint32_type, "llvm_intrin", vcnt32_args, Call::PureIntrinsic);
- if (call->type.bits() == 32) {
+ if (call->dtype.bits() == 32) {
return vcnt32;
}
// Accumulation 32->64bit
Array<Expr> vcnt64_args;
- vcnt64_args.push_back(ir::UIntImm::make(UInt(32), vpaddlu_id));
- vcnt64_args.push_back(ir::UIntImm::make(UInt(32), 1));
+ vcnt64_args.push_back(ir::UIntImm::make(DataType::UInt(32), vpaddlu_id));
+ vcnt64_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1));
vcnt64_args.push_back(vcnt32);
- return ir::Call::make(call->type, "llvm_intrin", vcnt64_args, Call::PureIntrinsic);
+ return ir::Call::make(call->dtype, "llvm_intrin", vcnt64_args, Call::PureIntrinsic);
}
TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm")
diff --git a/src/codegen/llvm/codegen_cpu.cc b/src/codegen/llvm/codegen_cpu.cc
index 0ba0c58..9f1a292 100644
--- a/src/codegen/llvm/codegen_cpu.cc
+++ b/src/codegen/llvm/codegen_cpu.cc
@@ -43,7 +43,7 @@ void CodeGenCPU::Init(const std::string& module_name,
func_handle_map_.clear();
export_system_symbols_.clear();
// TVM runtime types
- t_tvm_shape_index_ = llvm::Type::getIntNTy(*ctx, TVMShapeIndexType().bits());
+ t_tvm_shape_index_ = llvm::Type::getIntNTy(*ctx, DataType::ShapeIndex().bits());
t_tvm_context_ = llvm::StructType::create({t_int_, t_int_});
t_tvm_type_ = llvm::StructType::create({t_int8_, t_int8_, t_int16_});
t_tvm_func_handle_ = t_void_p_;
@@ -252,7 +252,7 @@ std::unique_ptr<llvm::Module> CodeGenCPU::Finish() {
return CodeGenLLVM::Finish();
}
llvm::Value* CodeGenCPU::CreateStructRefPtr(
- Type t, llvm::Value* buf, llvm::Value* index, int kind) {
+ DataType t, llvm::Value* buf, llvm::Value* index, int kind) {
if (kind < intrinsic::kArrKindBound_) {
if (buf->getType() == t_void_p_) {
buf = builder_->CreatePointerCast(buf, t_tvm_array_->getPointerTo());
@@ -329,7 +329,7 @@ llvm::Value* CodeGenCPU::CreateCallExtern(const Call* op) {
arg_types.push_back(v->getType());
}
llvm::FunctionType* ftype = llvm::FunctionType::get(
- LLVMType(op->type), arg_types, false);
+ LLVMType(op->dtype), arg_types, false);
// Check if it is available in global function table as injected function.
auto it = gv_func_map_.find(op->name);
if (it != gv_func_map_.end()) {
@@ -448,7 +448,7 @@ void CodeGenCPU::CreateComputeScope(const AttrStmt* op) {
llvm::Argument* v = &(*it);
const Var& var = vargs[idx];
new_vmap[var.get()] = v;
- if (var.type().is_handle() && !alias_var_set_.count(var.get())) {
+ if (var.dtype().is_handle() && !alias_var_set_.count(var.get())) {
// set non alias.
#if TVM_LLVM_VERSION >= 50
fcompute->addParamAttr(idx, llvm::Attribute::NoAlias);
@@ -532,8 +532,8 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) {
UnpackClosureData(cdata, vfields, &new_vmap);
// setup parallel env
ParallelEnv par_env;
- par_env.task_id = Var("task_id", Int(32));
- par_env.num_task = Var("num_task", Int(32));
+ par_env.task_id = Var("task_id", DataType::Int(32));
+ par_env.num_task = Var("num_task", DataType::Int(32));
new_vmap[par_env.task_id.get()] = task_id;
new_vmap[par_env.num_task.get()] = builder_->CreateLoad(
builder_->CreateInBoundsGEP(
@@ -670,7 +670,7 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) {
llvm::BasicBlock *
CodeGenCPU::MakeCallPacked(const Array<Expr> &args, llvm::Value **rvalue,
- llvm::Value **ret_tcode, const Type &r_type,
+ llvm::Value **ret_tcode, const DataType &r_type,
const int64_t begin, const int64_t end) {
using llvm::BasicBlock;
std::string func_name = args[0].as<StringImm>()->value;
@@ -684,15 +684,15 @@ CodeGenCPU::MakeCallPacked(const Array<Expr> &args, llvm::Value **rvalue,
builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()),
ConstInt32(begin));
llvm::Value *arg_tcode =
- CreateBufferPtr(Int(32), stack_tcode, ConstInt32(begin));
+ CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin));
llvm::Value *ret_value = builder_->CreateInBoundsGEP(
builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()),
ConstInt32(end));
- *ret_tcode = CreateBufferPtr(Int(32), stack_tcode, ConstInt32(end));
+ *ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end));
BasicBlock *end_block = CheckCallSuccess(builder_->CreateCall(
RuntimeTVMFuncCall(), {handle, arg_value, arg_tcode, ConstInt32(nargs),
ret_value, *ret_tcode}));
- Type r_api_type = ir::APIType(r_type);
+ DataType r_api_type = ir::APIType(r_type);
*rvalue = builder_->CreateAlignedLoad(
builder_->CreatePointerCast(ret_value,
LLVMType(r_api_type)->getPointerTo()),
@@ -705,7 +705,7 @@ llvm::Value *CodeGenCPU::CreateCallPacked(const Call *op) {
CHECK_EQ(op->args.size(), 5U);
llvm::Value *rvalue = nullptr;
llvm::Value *ret_tcode = nullptr;
- MakeCallPacked(op->args, &rvalue, &ret_tcode, op->type,
+ MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype,
op->args[3].as<IntImm>()->value,
op->args[4].as<IntImm>()->value);
return rvalue;
@@ -717,7 +717,7 @@ llvm::Value *CodeGenCPU::CreateCallTracePacked(const Call *op) {
llvm::Value *rvalue = nullptr;
llvm::Value *ret_tcode = nullptr;
BasicBlock *end_block = MakeCallPacked(
- op->args, &rvalue, &ret_tcode, op->type, op->args[3].as<IntImm>()->value,
+ op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as<IntImm>()->value,
op->args[4].as<IntImm>()->value);
// Get traced value.
llvm::Value *traced_value = MakeValue(op->args[5]);
@@ -800,7 +800,7 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) {
CHECK_EQ(op->args.size(), 3U);
int kind = op->args[2].as<IntImm>()->value;
llvm::Value* ref = this->CreateStructRefPtr(
- op->type, MakeValue(op->args[0]),
+ op->dtype, MakeValue(op->args[0]),
MakeValue(op->args[1]), kind);
if (kind == intrinsic::kArrAddr) {
return builder_->CreatePointerCast(ref, t_void_p_);
@@ -812,7 +812,7 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) {
int kind = op->args[2].as<IntImm>()->value;
llvm::Value* value = MakeValue(op->args[3]);
llvm::Value* ref = this->CreateStructRefPtr(
- op->args[3].type(), MakeValue(op->args[0]),
+ op->args[3].dtype(), MakeValue(op->args[0]),
MakeValue(op->args[1]), kind);
CHECK(kind != intrinsic::kArrAddr);
if (value->getType()->isPointerTy()) {
@@ -922,7 +922,7 @@ void CodeGenCPU::VisitStmt_(const For* op) {
CHECK(parallel_env_.task_id.defined());
CHECK(parallel_env_.num_task.defined());
CHECK(parallel_env_.penv != nullptr);
- Type t = op->extent.type();
+ DataType t = op->extent.dtype();
Expr num_task = cast(t, parallel_env_.num_task);
Expr task_id = cast(t, parallel_env_.task_id);
CHECK(!parallel_env_.in_parallel_loop)
diff --git a/src/codegen/llvm/codegen_cpu.h b/src/codegen/llvm/codegen_cpu.h
index 52e6f6c..b9e1275 100644
--- a/src/codegen/llvm/codegen_cpu.h
+++ b/src/codegen/llvm/codegen_cpu.h
@@ -96,14 +96,14 @@ class CodeGenCPU : public CodeGenLLVM {
llvm::Value* CreateStaticHandle();
llvm::Value* GetPackedFuncHandle(const std::string& str);
llvm::Value* PackClosureData(const Array<Var>& fields, uint64_t *num_bytes);
- llvm::Value* CreateStructRefPtr(Type t, llvm::Value* buffer, llvm::Value* index, int kind);
+ llvm::Value* CreateStructRefPtr(DataType t, llvm::Value* buffer, llvm::Value* index, int kind);
void UnpackClosureData(llvm::Value*cdata,
const Array<Var>& fields,
std::unordered_map<const Variable*, llvm::Value*>* vmap);
// Make packed call.
llvm::BasicBlock *MakeCallPacked(const Array<Expr> &args,
llvm::Value **rvalue,
- llvm::Value **ret_tcode, const Type &r_type,
+ llvm::Value **ret_tcode, const DataType &r_type,
const int64_t begin, const int64_t end);
// create call into tvm packed function.
llvm::Value* CreateCallPacked(const Call* op);
diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc
index 2cff88b..94ad8b7 100644
--- a/src/codegen/llvm/codegen_llvm.cc
+++ b/src/codegen/llvm/codegen_llvm.cc
@@ -115,11 +115,11 @@ void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
std::vector<llvm::Type*> arg_types;
is_restricted_ = f->is_restricted;
for (Var arg : f->args) {
- Type t = arg.type();
+ DataType t = arg.dtype();
if (t.is_handle()) {
auto it = f->handle_data_type.find(arg);
if (it != f->handle_data_type.end()) {
- arg_types.push_back(LLVMType((*it).second.type())
+ arg_types.push_back(LLVMType((*it).second.dtype())
->getPointerTo(GetGlobalAddressSpace()));
} else {
arg_types.push_back(t_int8_->getPointerTo(GetGlobalAddressSpace()));
@@ -128,7 +128,7 @@ void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
alias_var_set_.insert(arg.get());
}
} else {
- arg_types.push_back(LLVMType(arg.type()));
+ arg_types.push_back(LLVMType(arg.dtype()));
}
}
llvm::FunctionType* ftype = llvm::FunctionType::get(
@@ -147,7 +147,7 @@ void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
const Var& var = f->args[i];
var_map_[var.get()] = v;
if (is_restricted_) {
- if (var.type().is_handle() && !alias_var_set_.count(var.get())) {
+ if (var.dtype().is_handle() && !alias_var_set_.count(var.get())) {
// set non alias.
#if TVM_LLVM_VERSION >= 50
function_->addParamAttr(i, llvm::Attribute::NoAlias);
@@ -302,7 +302,7 @@ unsigned CodeGenLLVM::GetGlobalAddressSpace() {
return 0;
}
-llvm::Type* CodeGenLLVM::LLVMType(const Type& t) const {
+llvm::Type* CodeGenLLVM::LLVMType(const DataType& t) const {
if (t.is_handle()) {
CHECK_EQ(t.lanes(), 1);
return t_void_p_;
@@ -335,7 +335,7 @@ llvm::Type* CodeGenLLVM::LLVMType(const Type& t) const {
void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
const Variable* buffer,
Expr index,
- Type type) {
+ DataType type) {
if (alias_var_set_.count(buffer) != 0) {
// Mark all possibly aliased pointer as same type.
llvm::MDNode* meta = md_tbaa_alias_set_;
@@ -387,7 +387,7 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
md_builder_->createTBAAStructTagNode(meta, meta, 0));
}
-void CodeGenLLVM::GetAlignment(Type t,
+void CodeGenLLVM::GetAlignment(DataType t,
const Variable* buf_var,
const Expr& index,
int* p_alignment,
@@ -474,7 +474,7 @@ llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {
}
llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
- llvm::Value* mask = llvm::UndefValue::get(LLVMType(Int(32, target_lanes)));
+ llvm::Value* mask = llvm::UndefValue::get(LLVMType(DataType::Int(32, target_lanes)));
int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
if (num_elems == target_lanes) return vec;
CHECK_LT(num_elems, target_lanes);
@@ -542,19 +542,19 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin,
loop_value->addIncoming(begin, pre_block);
CHECK(!var_map_.count(loop_var.get()));
var_map_[loop_var.get()] = loop_value;
- builder_->CreateCondBr(CreateLT(loop_var.type(), loop_value, end),
+ builder_->CreateCondBr(CreateLT(loop_var.dtype(), loop_value, end),
for_body, for_end, md_very_likely_branch_);
builder_->SetInsertPoint(for_body);
this->VisitStmt(body);
var_map_.erase(loop_var.get());
- llvm::Value* loop_next = CreateAdd(loop_var.type(), loop_value, stride);
+ llvm::Value* loop_next = CreateAdd(loop_var.dtype(), loop_value, stride);
loop_value->addIncoming(loop_next, builder_->GetInsertBlock());
builder_->CreateBr(for_begin);
builder_->SetInsertPoint(for_end);
}
// cast operatpr
-llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) {
+llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* value) {
llvm::Type * target = LLVMType(to);
if (value->getType() == target) return value;
if (to.is_handle()) {
@@ -609,7 +609,7 @@ llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) {
}
llvm::Value* CodeGenLLVM::CreateBufferPtr(
- Type t, llvm::Value* buffer, llvm::Value* index) {
+ DataType t, llvm::Value* buffer, llvm::Value* index) {
CHECK_EQ(t.lanes(), 1);
llvm::PointerType* btype = llvm::dyn_cast<llvm::PointerType>(buffer->getType());
CHECK(btype != nullptr);
@@ -622,7 +622,7 @@ llvm::Value* CodeGenLLVM::CreateBufferPtr(
}
llvm::Value* CodeGenLLVM::CreateBufferVecPtr(
- Type t, llvm::Value* buffer, llvm::Value* index) {
+ DataType t, llvm::Value* buffer, llvm::Value* index) {
CHECK_GT(t.lanes(), 1);
llvm::PointerType* btype = llvm::dyn_cast<llvm::PointerType>(buffer->getType());
CHECK(btype != nullptr);
@@ -647,7 +647,7 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
arg_type.push_back(arg_value.back()->getType());
}
llvm::FunctionType* ftype = llvm::FunctionType::get(
- LLVMType(op->type), arg_type, false);
+ LLVMType(op->dtype), arg_type, false);
llvm::Function* f = module_->getFunction(op->name);
if (f == nullptr) {
f = llvm::Function::Create(
@@ -674,7 +674,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
sig_type.push_back(arg_value.back()->getType());
}
}
- llvm::Type *return_type = LLVMType(op->type);
+ llvm::Type *return_type = LLVMType(op->dtype);
if (sig_type.size() > 0 && return_type != sig_type[0]) {
sig_type.insert(sig_type.begin(), return_type);
}
@@ -692,7 +692,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
} else if (op->is_intrinsic(Call::shift_left)) {
return builder_->CreateShl(MakeValue(op->args[0]), MakeValue(op->args[1]));
} else if (op->is_intrinsic(Call::shift_right)) {
- if (op->args[0].type().is_int()) {
+ if (op->args[0].dtype().is_int()) {
return builder_->CreateAShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
} else {
return builder_->CreateLShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
@@ -707,13 +707,13 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
unsigned addrspace;
if (!r) {
ptr = CreateBufferPtr(
- l->type, MakeValue(l->buffer_var), MakeValue(l->index));
+ l->dtype, MakeValue(l->buffer_var), MakeValue(l->index));
addrspace = llvm::dyn_cast<llvm::PointerType>(
ptr->getType())->getAddressSpace();
} else {
- Expr index = r->base / make_const(Int(32), r->lanes);
+ Expr index = r->base / make_const(DataType::Int(32), r->lanes);
ptr = CreateBufferVecPtr(
- l->type, MakeValue(l->buffer_var), MakeValue(index));
+ l->dtype, MakeValue(l->buffer_var), MakeValue(index));
addrspace = llvm::dyn_cast<llvm::PointerType>(
ptr->getType())->getAddressSpace();
}
@@ -723,7 +723,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
} else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
return builder_->CreateIsNull(MakeValue(op->args[0]));
} else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
- CHECK_EQ(op->args[0].type().lanes(), 1)
+ CHECK_EQ(op->args[0].dtype().lanes(), 1)
<< "if_then_else can only take scalar condition";
using llvm::BasicBlock;
BasicBlock* then_block = BasicBlock::Create(
@@ -747,7 +747,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
value->addIncoming(else_value, else_value_block);
return value;
} else if (op->is_intrinsic(Call::reinterpret)) {
- llvm::Type * target = LLVMType(op->type);
+ llvm::Type * target = LLVMType(op->dtype);
return builder_->CreateBitCast(MakeValue(op->args[0]), target);
} else if (op->is_intrinsic(Call::isnan)) {
// TODO(hgt312): set fast math flag
@@ -779,13 +779,13 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
void CodeGenLLVM::Scalarize(const Expr& e,
std::function<void(int i, llvm::Value* v)> f) {
if (const Ramp* ramp = e.as<Ramp>()) {
- for (int i = 0; i < ramp->type.lanes(); ++i) {
+ for (int i = 0; i < ramp->dtype.lanes(); ++i) {
Expr offset = ramp->base + (ramp->stride * i);
f(i, MakeValue(offset));
}
} else {
llvm::Value* value = MakeValue(e);
- for (int i = 0; i < e.type().lanes(); ++i) {
+ for (int i = 0; i < e.dtype().lanes(); ++i) {
f(i, builder_->CreateExtractElement(value, i));
}
}
@@ -798,18 +798,18 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Variable* op) {
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Cast* op) {
- return CreateCast(op->value.type(), op->type, MakeValue(op->value));
+ return CreateCast(op->value.dtype(), op->dtype, MakeValue(op->value));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const IntImm* op) {
- return llvm::ConstantInt::getSigned(LLVMType(op->type), op->value);
+ return llvm::ConstantInt::getSigned(LLVMType(op->dtype), op->value);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const UIntImm* op) {
- return llvm::ConstantInt::get(LLVMType(op->type), op->value);
+ return llvm::ConstantInt::get(LLVMType(op->dtype), op->value);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImm* op) {
- return llvm::ConstantFP::get(LLVMType(op->type), op->value);
+ return llvm::ConstantFP::get(LLVMType(op->dtype), op->value);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const StringImm* op) {
@@ -818,7 +818,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const StringImm* op) {
#define DEFINE_CODEGEN_BINARY_OP(Op) \
llvm::Value* CodeGenLLVM::Create ## Op( \
- Type t, llvm::Value* a, llvm::Value *b) { \
+ DataType t, llvm::Value* a, llvm::Value *b) { \
if (t.is_int()) { \
if (t.bits() >= 32) { \
return builder_->CreateNSW ## Op (a, b); \
@@ -837,7 +837,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const StringImm* op) {
} \
} \
llvm::Value* CodeGenLLVM::VisitExpr_(const Op* op) { \
- return Create ## Op(op->type, MakeValue(op->a), MakeValue(op->b)); \
+ return Create ## Op(op->dtype, MakeValue(op->a), MakeValue(op->b)); \
}
DEFINE_CODEGEN_BINARY_OP(Add);
@@ -846,7 +846,7 @@ DEFINE_CODEGEN_BINARY_OP(Mul);
#define DEFINE_CODEGEN_CMP_OP(Op) \
llvm::Value* CodeGenLLVM::Create ## Op( \
- Type t, llvm::Value* a, llvm::Value* b) { \
+ DataType t, llvm::Value* a, llvm::Value* b) { \
if (t.is_int()) { \
return builder_->CreateICmpS ## Op (a, b); \
} else if (t.is_uint()) { \
@@ -857,7 +857,7 @@ DEFINE_CODEGEN_BINARY_OP(Mul);
} \
} \
llvm::Value* CodeGenLLVM::VisitExpr_(const Op* op) { \
- return Create ## Op(op->a.type(), MakeValue(op->a), MakeValue(op->b)); \
+ return Create ## Op(op->a.dtype(), MakeValue(op->a), MakeValue(op->b)); \
}
DEFINE_CODEGEN_CMP_OP(LT);
@@ -868,12 +868,12 @@ DEFINE_CODEGEN_CMP_OP(GE);
llvm::Value* CodeGenLLVM::VisitExpr_(const Div* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
- if (op->type.is_int()) {
+ if (op->dtype.is_int()) {
return builder_->CreateSDiv(a, b);
- } else if (op->type.is_uint()) {
+ } else if (op->dtype.is_uint()) {
return builder_->CreateUDiv(a, b);
} else {
- CHECK(op->type.is_float());
+ CHECK(op->dtype.is_float());
return builder_->CreateFDiv(a, b);
}
}
@@ -881,12 +881,12 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Div* op) {
llvm::Value* CodeGenLLVM::VisitExpr_(const Mod* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
- if (op->type.is_int()) {
+ if (op->dtype.is_int()) {
return builder_->CreateSRem(a, b);
- } else if (op->type.is_uint()) {
+ } else if (op->dtype.is_uint()) {
return builder_->CreateURem(a, b);
} else {
- CHECK(op->type.is_float());
+ CHECK(op->dtype.is_float());
return builder_->CreateFRem(a, b);
}
}
@@ -894,19 +894,19 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Mod* op) {
llvm::Value* CodeGenLLVM::VisitExpr_(const Min* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
- return builder_->CreateSelect(CreateLT(op->a.type(), a, b), a, b);
+ return builder_->CreateSelect(CreateLT(op->a.dtype(), a, b), a, b);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Max* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
- return builder_->CreateSelect(CreateGT(op->a.type(), a, b), a, b);
+ return builder_->CreateSelect(CreateGT(op->a.dtype(), a, b), a, b);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const EQ* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
- if (op->a.type().is_int() || op->a.type().is_uint()) {
+ if (op->a.dtype().is_int() || op->a.dtype().is_uint()) {
return builder_->CreateICmpEQ(a, b);
} else {
return builder_->CreateFCmpOEQ(a, b);
@@ -916,7 +916,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const EQ* op) {
llvm::Value* CodeGenLLVM::VisitExpr_(const NE* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
- if (op->a.type().is_int() || op->a.type().is_uint()) {
+ if (op->a.dtype().is_int() || op->a.dtype().is_uint()) {
return builder_->CreateICmpNE(a, b);
} else {
return builder_->CreateFCmpONE(a, b);
@@ -950,7 +950,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Let* op) {
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
- Type t = op->type;
+ DataType t = op->dtype;
bool is_volatile = volatile_buf_.count(op->buffer_var.get());
llvm::Value* buffer = MakeValue(op->buffer_var);
llvm::Value* index = MakeValue(op->index);
@@ -1010,10 +1010,10 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
}
llvm::Value* CodeGenLLVM::VisitExpr_(const Ramp* op) {
- llvm::Value* vec = llvm::UndefValue::get(LLVMType(op->type));
+ llvm::Value* vec = llvm::UndefValue::get(LLVMType(op->dtype));
for (int i = 0; i < op->lanes; ++i) {
vec = builder_->CreateInsertElement(
- vec, MakeValue(op->base + op->stride * make_const(op->stride.type(), i)),
+ vec, MakeValue(op->base + op->stride * make_const(op->stride.dtype(), i)),
ConstInt32(i));
}
return vec;
@@ -1024,7 +1024,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Shuffle* op) {
int total_lanes = 0;
for (int i = 0, e = op->vectors.size(); i < e; ++i) {
vecs[i] = VisitExpr(op->vectors[i]);
- total_lanes += op->vectors[i].type().lanes();
+ total_lanes += op->vectors[i].dtype().lanes();
}
llvm::Value* v0 = CreateVecConcat(vecs);
std::vector<uint32_t> idx(op->indices.size());
@@ -1045,7 +1045,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Broadcast* op) {
void CodeGenLLVM::VisitStmt_(const Store* op) {
CHECK(is_one(op->predicate));
- Type t = op->value.type();
+ DataType t = op->value.dtype();
bool is_volatile = volatile_buf_.count(op->buffer_var.get());
llvm::Value* buffer = MakeValue(op->buffer_var);
llvm::Value* index = MakeValue(op->index);
@@ -1056,7 +1056,7 @@ void CodeGenLLVM::VisitStmt_(const Store* op) {
GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits);
llvm::Value* ptr = CreateBufferPtr(t, buffer, index);
llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile);
- AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.type());
+ AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.dtype());
return;
} else {
// vector store
@@ -1071,7 +1071,7 @@ void CodeGenLLVM::VisitStmt_(const Store* op) {
t.element_of(), buffer, MakeValue(ramp->base));
ptr = builder_->CreatePointerCast(ptr, LLVMType(t)->getPointerTo(addrspace));
llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile);
- AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.type());
+ AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.dtype());
return;
}
}
@@ -1084,7 +1084,7 @@ void CodeGenLLVM::VisitStmt_(const Store* op) {
llvm::StoreInst* store = builder_->CreateAlignedStore(
builder_->CreateExtractElement(value, i),
ptr, basic_align, is_volatile);
- AddAliasInfo(store, op->buffer_var.get(), Expr(), op->value.type());
+ AddAliasInfo(store, op->buffer_var.get(), Expr(), op->value.dtype());
};
this->Scalarize(op->index, f);
}
@@ -1142,7 +1142,7 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) {
<< "Can only handle constant size stack allocation";
StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
if (constant_size % 4 == 0 && info.alignment == 0) {
- info.alignment = GetTempAllocaAlignment(op->type, constant_size);
+ info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
}
// maximum necessary alignment in the NV devices
if (info.alignment > 16) {
@@ -1150,7 +1150,7 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) {
}
llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
return builder_->CreateAlloca(
- LLVMType(op->type), ConstInt32(constant_size));
+ LLVMType(op->dtype), ConstInt32(constant_size));
});
if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100
@@ -1163,7 +1163,7 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) {
buf = alloca;
}
buf = builder_->CreatePointerCast(
- buf, LLVMType(op->type)->getPointerTo(
+ buf, LLVMType(op->dtype)->getPointerTo(
buf->getType()->getPointerAddressSpace()));
CHECK(!var_map_.count(op->buffer_var.get()));
var_map_[op->buffer_var.get()] = buf;
@@ -1204,7 +1204,7 @@ void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
void CodeGenLLVM::VisitStmt_(const LetStmt* op) {
CHECK(!var_map_.count(op->var.get()));
- if (op->var.type().is_handle()) {
+ if (op->var.dtype().is_handle()) {
if (!is_restricted_) {
alias_var_set_.insert(op->var.get());
}
diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h
index b7d091b..08c836a 100644
--- a/src/codegen/llvm/codegen_llvm.h
+++ b/src/codegen/llvm/codegen_llvm.h
@@ -206,12 +206,12 @@ class CodeGenLLVM :
* \param t The original type.
* \return LLVM type of t
*/
- llvm::Type* LLVMType(const Type& t) const;
+ llvm::Type* LLVMType(const DataType& t) const;
// initialize the function state.
void InitFuncState();
// Get alignment given index.
void GetAlignment(
- Type t, const Variable* buf_var, const Expr& index,
+ DataType t, const Variable* buf_var, const Expr& index,
int* p_alignment, int* p_native_bits);
// Get constant string
llvm::Value* GetConstString(const std::string& str);
@@ -221,19 +221,19 @@ class CodeGenLLVM :
// handle module import
void HandleImport(const std::string& code);
// cast operatpr
- llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
+ llvm::Value* CreateCast(DataType from, DataType to, llvm::Value* value);
// comparison op
llvm::Value* GetVarValue(const Variable* v) const;
- llvm::Value* CreateLT(Type t, llvm::Value* a, llvm::Value* b);
- llvm::Value* CreateLE(Type t, llvm::Value* a, llvm::Value* b);
- llvm::Value* CreateGT(Type t, llvm::Value* a, llvm::Value* b);
- llvm::Value* CreateGE(Type t, llvm::Value* a, llvm::Value* b);
- llvm::Value* CreateAdd(Type t, llvm::Value* a, llvm::Value* b);
- llvm::Value* CreateSub(Type t, llvm::Value* a, llvm::Value* b);
- llvm::Value* CreateMul(Type t, llvm::Value* a, llvm::Value* b);
+ llvm::Value* CreateLT(DataType t, llvm::Value* a, llvm::Value* b);
+ llvm::Value* CreateLE(DataType t, llvm::Value* a, llvm::Value* b);
+ llvm::Value* CreateGT(DataType t, llvm::Value* a, llvm::Value* b);
+ llvm::Value* CreateGE(DataType t, llvm::Value* a, llvm::Value* b);
+ llvm::Value* CreateAdd(DataType t, llvm::Value* a, llvm::Value* b);
+ llvm::Value* CreateSub(DataType t, llvm::Value* a, llvm::Value* b);
+ llvm::Value* CreateMul(DataType t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
- llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index);
- llvm::Value* CreateBufferVecPtr(Type t, llvm::Value* buffer, llvm::Value* index);
+ llvm::Value* CreateBufferPtr(DataType t, llvm::Value* buffer, llvm::Value* index);
+ llvm::Value* CreateBufferVecPtr(DataType t, llvm::Value* buffer, llvm::Value* index);
// Vector concatenation.
llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent);
llvm::Value* CreateVecFlip(llvm::Value* vec);
@@ -245,7 +245,7 @@ class CodeGenLLVM :
llvm::Value* stride,
const VarExpr& loop_var, const Stmt& body);
// add alias information.
- void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index, Type type);
+ void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index, DataType type);
// The IRBuilder.
using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>;
// The current function
diff --git a/src/codegen/llvm/codegen_nvptx.cc b/src/codegen/llvm/codegen_nvptx.cc
index b6bc6ef..372408c 100644
--- a/src/codegen/llvm/codegen_nvptx.cc
+++ b/src/codegen/llvm/codegen_nvptx.cc
@@ -58,7 +58,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
<< "Can only handle constant size stack allocation in GPU";
StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
if (constant_size % 4 == 0 && info.alignment == 0) {
- info.alignment = GetTempAllocaAlignment(op->type, constant_size);
+ info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
}
// maximum necessary alignment in the NV devices
if (info.alignment > 16) {
@@ -69,7 +69,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
// TODO(tqchen): for higher version of LLVM, local address space can be set.
llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
return builder_->CreateAlloca(
- LLVMType(op->type), ConstInt32(constant_size));
+ LLVMType(op->dtype), ConstInt32(constant_size));
});
if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100
@@ -84,7 +84,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
<< "Can only allocate shared or local memory inside kernel";
// Shared memory: address space == 3
const unsigned shared_address_space = 3;
- llvm::Type* type = llvm::ArrayType::get(LLVMType(op->type), constant_size);
+ llvm::Type* type = llvm::ArrayType::get(LLVMType(op->dtype), constant_size);
// Allocate shared memory in global, address_space = 3
llvm::GlobalVariable *global = new llvm::GlobalVariable(
*module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
@@ -98,7 +98,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
}
}
buf = builder_->CreatePointerCast(
- buf, LLVMType(op->type)->getPointerTo(
+ buf, LLVMType(op->dtype)->getPointerTo(
buf->getType()->getPointerAddressSpace()));
CHECK(!var_map_.count(op->buffer_var.get()));
var_map_[op->buffer_var.get()] = buf;
diff --git a/src/codegen/llvm/codegen_x86_64.cc b/src/codegen/llvm/codegen_x86_64.cc
index 804d9b2..5d72b56 100644
--- a/src/codegen/llvm/codegen_x86_64.cc
+++ b/src/codegen/llvm/codegen_x86_64.cc
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -74,8 +74,8 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const Cast* op) {
// LLVM does not automatically generate the correct instruction sequences for
// half -> float conversion (i.e. using AVX2/AVX-512 vectorized variants of
// vcvtph2ps), so we explicitly generate them ourselves.
- const auto from = op->value.type();
- const auto to = op->type;
+ const auto from = op->value.dtype();
+ const auto to = op->dtype;
if (from.is_float() && to.is_float() && from.bits() == 16 && to.bits() == 32) {
CHECK_EQ(from.lanes(), to.lanes());
CHECK_NOTNULL(target_machine_);
@@ -85,21 +85,25 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const Cast* op) {
if (from.lanes() >= 16 && has_avx512) {
return CallVectorIntrin(
- ::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16, LLVMType(Float(32, from.lanes())),
+ ::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16,
+ LLVMType(DataType::Float(32, from.lanes())),
{
- MakeValue(ir::Call::make(Int(16, from.lanes()), ir::Call::reinterpret, {op->value},
- ir::Call::PureIntrinsic)),
- MakeValue(ir::Broadcast::make(ir::FloatImm::make(Float(32), 0), from.lanes())),
- /*mask=*/MakeValue(ir::IntImm::make(Int(16), -1)),
- /*rounding-mode=*/MakeValue(ir::IntImm::make(Int(32), 4)),
+ MakeValue(ir::Call::make(
+ DataType::Int(16, from.lanes()), ir::Call::reinterpret, {op->value},
+ ir::Call::PureIntrinsic)),
+ MakeValue(
+ ir::Broadcast::make(ir::FloatImm::make(DataType::Float(32), 0), from.lanes())),
+ /*mask=*/MakeValue(ir::IntImm::make(DataType::Int(16), -1)),
+ /*rounding-mode=*/MakeValue(ir::IntImm::make(DataType::Int(32), 4)),
});
}
if (from.lanes() >= 8 && has_f16c) {
return CallVectorIntrin(
- ::llvm::Intrinsic::x86_vcvtph2ps_256, 8, LLVMType(Float(32, from.lanes())),
- {MakeValue(ir::Call::make(Int(16, from.lanes()), ir::Call::reinterpret, {op->value},
- ir::Call::PureIntrinsic))});
+ ::llvm::Intrinsic::x86_vcvtph2ps_256, 8, LLVMType(DataType::Float(32, from.lanes())),
+ {MakeValue(ir::Call::make(
+ DataType::Int(16, from.lanes()), ir::Call::reinterpret, {op->value},
+ ir::Call::PureIntrinsic))});
}
}
diff --git a/src/codegen/llvm/intrin_rule_llvm.cc b/src/codegen/llvm/intrin_rule_llvm.cc
index fd28d7e..da07ff3 100644
--- a/src/codegen/llvm/intrin_rule_llvm.cc
+++ b/src/codegen/llvm/intrin_rule_llvm.cc
@@ -67,19 +67,19 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh")
const ir::Call* call = e.as<ir::Call>();
CHECK(call != nullptr);
const Expr& x = call->args[0];
- Expr one = make_const(x.type(), 1);
- Expr two = make_const(x.type(), 2);
- Expr neg_two = make_const(x.type(), -2);
+ Expr one = make_const(x.dtype(), 1);
+ Expr two = make_const(x.dtype(), 2);
+ Expr neg_two = make_const(x.dtype(), -2);
Expr exp_neg2x = ir::Call::make(
- x.type(), "exp", {neg_two * x}, ir::Call::PureIntrinsic);
+ x.dtype(), "exp", {neg_two * x}, ir::Call::PureIntrinsic);
Expr exp_pos2x = ir::Call::make(
- x.type(), "exp", {two * x}, ir::Call::PureIntrinsic);
+ x.dtype(), "exp", {two * x}, ir::Call::PureIntrinsic);
Expr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x);
Expr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one);
*rv = ir::Select::make(
- x >= make_zero(x.type()), tanh_pos, tanh_neg);
+ x >= make_zero(x.dtype()), tanh_pos, tanh_neg);
});
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.pow")
diff --git a/src/codegen/llvm/intrin_rule_llvm.h b/src/codegen/llvm/intrin_rule_llvm.h
index c0b5241..7863a3d 100644
--- a/src/codegen/llvm/intrin_rule_llvm.h
+++ b/src/codegen/llvm/intrin_rule_llvm.h
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -41,14 +41,14 @@ inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
CHECK(call != nullptr);
Array<Expr> cargs;
// intrin id.
- cargs.push_back(ir::UIntImm::make(UInt(32), id));
- cargs.push_back(ir::UIntImm::make(UInt(32), num_signature));
+ cargs.push_back(ir::UIntImm::make(DataType::UInt(32), id));
+ cargs.push_back(ir::UIntImm::make(DataType::UInt(32), num_signature));
for (Expr arg : call->args) {
cargs.push_back(arg);
}
*rv = ir::Call::make(
- call->type, "llvm_intrin", cargs, ir::Call::PureIntrinsic);
+ call->dtype, "llvm_intrin", cargs, ir::Call::PureIntrinsic);
}
template<unsigned id, int num_signature>
@@ -58,13 +58,13 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
CHECK(call != nullptr);
Array<Expr> cargs;
// intrin id.
- cargs.push_back(ir::UIntImm::make(UInt(32), id));
- cargs.push_back(ir::UIntImm::make(UInt(32), num_signature));
+ cargs.push_back(ir::UIntImm::make(DataType::UInt(32), id));
+ cargs.push_back(ir::UIntImm::make(DataType::UInt(32), num_signature));
for (Expr arg : call->args) {
cargs.push_back(arg);
}
*rv = ir::Call::make(
- call->type, "llvm_intrin", cargs, ir::Call::Intrinsic);
+ call->dtype, "llvm_intrin", cargs, ir::Call::Intrinsic);
}
} // namespace codegen
diff --git a/src/codegen/llvm/intrin_rule_nvptx.cc b/src/codegen/llvm/intrin_rule_nvptx.cc
index 4718cf7..862d06b 100644
--- a/src/codegen/llvm/intrin_rule_nvptx.cc
+++ b/src/codegen/llvm/intrin_rule_nvptx.cc
@@ -35,11 +35,11 @@ inline void DispatchExternLibDevice(const TVMArgs& args, TVMRetValue* rv) {
using namespace ir;
const Call* call = e.as<Call>();
CHECK(call != nullptr);
- CHECK(call->type.bits() == 32 || call->type.bits() == 64) << "Only support float32 or float64.";
+ CHECK(call->dtype.bits() == 32 || call->dtype.bits() == 64) << "Only support float32 or float64.";
std::ostringstream intrinsic_name;
intrinsic_name << "__nv_" << call->name;
- if (call->type.bits() == 32) intrinsic_name << "f";
- *rv = Call::make(call->type, intrinsic_name.str(), call->args,
+ if (call->dtype.bits() == 32) intrinsic_name << "f";
+ *rv = Call::make(call->dtype, intrinsic_name.str(), call->args,
Call::PureExtern);
}
diff --git a/src/codegen/llvm/intrin_rule_rocm.cc b/src/codegen/llvm/intrin_rule_rocm.cc
index 5ad5261..22b3245 100644
--- a/src/codegen/llvm/intrin_rule_rocm.cc
+++ b/src/codegen/llvm/intrin_rule_rocm.cc
@@ -36,8 +36,8 @@ inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) {
const Call* call = e.as<Call>();
CHECK(call != nullptr);
std::ostringstream intrinsic_name;
- intrinsic_name << "__ocml_" << call->name << "_f" << call->type.bits();
- *rv = Call::make(call->type, intrinsic_name.str(), call->args,
+ intrinsic_name << "__ocml_" << call->name << "_f" << call->dtype.bits();
+ *rv = Call::make(call->dtype, intrinsic_name.str(), call->args,
Call::PureExtern);
}
diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc
index be2b6cc..7800e47 100644
--- a/src/codegen/spirv/codegen_spirv.cc
+++ b/src/codegen/spirv/codegen_spirv.cc
@@ -37,11 +37,11 @@ std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const LoweredFunc& f) {
std::vector<Var> pod_args;
uint32_t num_buffer = 0;
for (Var arg : f->args) {
- Type t = arg.type();
+ DataType t = arg.dtype();
if (t.is_handle()) {
auto it = f->handle_data_type.find(arg);
if (it != f->handle_data_type.end()) {
- Type value_type = (*it).second.type();
+ DataType value_type = (*it).second.dtype();
spirv::Value arg_value = builder_->BufferArgument(
builder_->GetSType(value_type), 0, num_buffer);
storage_info_[arg.get()].UpdateContentType(value_type);
@@ -61,7 +61,7 @@ std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const LoweredFunc& f) {
if (pod_args.size() != 0) {
std::vector<spirv::SType> value_types;
for (size_t i = 0; i < pod_args.size(); ++i) {
- value_types.push_back(builder_->GetSType(pod_args[i].type()));
+ value_types.push_back(builder_->GetSType(pod_args[i].dtype()));
}
spirv::Value ptr = builder_->DeclarePushConstant(value_types);
for (size_t i = 0; i < pod_args.size(); ++i) {
@@ -103,7 +103,7 @@ spirv::Value CodeGenSPIRV::GetThreadIndex(
} else {
v = builder_->GetWorkgroupID(ts.dim_index);
}
- return builder_->Cast(builder_->GetSType(iv->var.type()), v);
+ return builder_->Cast(builder_->GetSType(iv->var.dtype()), v);
}
spirv::Value CodeGenSPIRV::CreateStorageSync(const Call* op) {
@@ -112,7 +112,7 @@ spirv::Value CodeGenSPIRV::CreateStorageSync(const Call* op) {
if (sync == "warp") {
return value;
} else if (sync == "shared") {
- auto type_int = builder_->GetSType(Int(32));
+ auto type_int = builder_->GetSType(DataType::Int(32));
builder_->MakeInst(
spv::OpControlBarrier,
builder_->IntImm(type_int, static_cast<int64_t>(spv::ScopeWorkgroup)),
@@ -133,15 +133,15 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Variable* op) {
}
spirv::Value CodeGenSPIRV::VisitExpr_(const IntImm* op) {
- return builder_->IntImm(builder_->GetSType(op->type), op->value);
+ return builder_->IntImm(builder_->GetSType(op->dtype), op->value);
}
spirv::Value CodeGenSPIRV::VisitExpr_(const UIntImm* op) {
- return builder_->UIntImm(builder_->GetSType(op->type), op->value);
+ return builder_->UIntImm(builder_->GetSType(op->dtype), op->value);
}
spirv::Value CodeGenSPIRV::VisitExpr_(const FloatImm* op) {
- return builder_->FloatImm(builder_->GetSType(op->type), op->value);
+ return builder_->FloatImm(builder_->GetSType(op->dtype), op->value);
}
spirv::Value CodeGenSPIRV::VisitExpr_(const StringImm* op) {
@@ -150,7 +150,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const StringImm* op) {
}
spirv::Value CodeGenSPIRV::VisitExpr_(const Cast* op) {
- return builder_->Cast(builder_->GetSType(op->type), MakeValue(op->value));
+ return builder_->Cast(builder_->GetSType(op->dtype), MakeValue(op->value));
}
spirv::Value CodeGenSPIRV::VisitExpr_(const Add* op) {
@@ -248,7 +248,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Call* op) {
values.push_back(MakeValue(op->args[i]));
}
return builder_->CallGLSL450(
- builder_->GetSType(op->type), inst_id, values);
+ builder_->GetSType(op->dtype), inst_id, values);
} else if (op->is_intrinsic(Call::bitwise_and)) {
CHECK_EQ(op->args.size(), 2U);
spirv::Value a = MakeValue(op->args[0]);
@@ -277,13 +277,13 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Call* op) {
CHECK_EQ(op->args.size(), 2U);
spirv::Value a = MakeValue(op->args[0]);
spirv::Value b = MakeValue(op->args[1]);
- if (op->args[0].type().is_int()) {
+ if (op->args[0].dtype().is_int()) {
return builder_->MakeValue(spv::OpShiftRightArithmetic, a.stype, a, b);
} else {
return builder_->MakeValue(spv::OpShiftRightLogical, a.stype, a, b);
}
} else if (op->is_intrinsic(Call::reinterpret)) {
- return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->type),
+ return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->dtype),
MakeValue(op->args[0]));
} else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
return this->CreateStorageSync(op);
@@ -316,17 +316,17 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Call* op) {
} else if (op->is_intrinsic("popcount")) {
return builder_->MakeValue(
spv::OpBitCount,
- builder_->GetSType(op->type),
+ builder_->GetSType(op->dtype),
MakeValue(op->args[0]));
} else {
if (op->call_type == Call::Intrinsic ||
op->call_type == Call::PureIntrinsic) {
LOG(FATAL) << "Unresolved intrinsic " << op->name
- << " with return type " << op->type;
+ << " with return type " << op->dtype;
} else if (op->call_type == Call::Extern ||
op->call_type == Call::PureExtern) {
LOG(FATAL) << "Unresolved extern " << op->name
- << " with return type " << op->type;
+ << " with return type " << op->dtype;
} else {
LOG(FATAL) << "Unresolved call type " << op->call_type;
}
@@ -341,7 +341,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Ramp* op) {
spirv::Value v = base;
if (i != 0) {
spirv::Value offset = MakeValue(
- make_const(op->stride.type(), i) * op->stride);
+ make_const(op->stride.dtype(), i) * op->stride);
v = builder_->Add(v, offset);
}
values.push_back(v);
@@ -364,7 +364,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Load* op) {
CHECK(it != storage_info_.end());
StorageInfo& info = it->second;
if (!info.content_fixed) {
- info.UpdateContentType(op->type);
+ info.UpdateContentType(op->dtype);
}
spirv::SType content_type = builder_->GetSType(info.content_type);
@@ -376,15 +376,15 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Load* op) {
if (info.is_volatile) {
mask |= spv::MemoryAccessVolatileMask;
}
- if (op->type.lanes() == 1) {
- CHECK_EQ(info.content_type, op->type)
+ if (op->dtype.lanes() == 1) {
+ CHECK_EQ(info.content_type, op->dtype)
<< "Vulkan only allow one type access to the same buffer";
spirv::Value index = MakeValue(op->index);
spirv::Value ptr = builder_->StructArrayAccess(
ptr_type, buffer, index);
return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask);
} else {
- if (op->type.element_of() == info.content_type) {
+ if (op->dtype.element_of() == info.content_type) {
// because content type is element type, we can only do scalarize load.
std::vector<spirv::Value> values;
auto f = [&](int i, spirv::Value index) {
@@ -398,13 +398,13 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Load* op) {
} else {
if (const Ramp* ramp = op->index.as<Ramp>()) {
if (is_one(ramp->stride)) {
- CHECK_EQ(ramp->lanes, op->type.lanes());
+ CHECK_EQ(ramp->lanes, op->dtype.lanes());
arith::ModularSet me = analyzer_->modular_set(ramp->base);
CHECK((me->coeff % ramp->lanes) == 0 &&
(me->base % ramp->lanes) == 0)
<< "Only aligned vector access is allowed in SPIRV";
Expr vec_index = ir::Simplify(
- ramp->base / make_const(ramp->base.type(), ramp->lanes));
+ ramp->base / make_const(ramp->base.dtype(), ramp->lanes));
spirv::Value ptr = builder_->StructArrayAccess(
ptr_type, buffer, MakeValue(vec_index));
return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask);
@@ -420,14 +420,14 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Load* op) {
void CodeGenSPIRV::Scalarize(const Expr& e,
std::function<void(int i, spirv::Value v)> f) {
if (const Ramp* ramp = e.as<Ramp>()) {
- for (int i = 0; i < ramp->type.lanes(); ++i) {
+ for (int i = 0; i < ramp->dtype.lanes(); ++i) {
Expr offset = ramp->base + ramp->stride * i;
f(i, MakeValue(offset));
}
} else {
- spirv::SType etype = builder_->GetSType(e.type().element_of());
+ spirv::SType etype = builder_->GetSType(e.dtype().element_of());
spirv::Value value = MakeValue(e);
- for (int i = 0; i < e.type().lanes(); ++i) {
+ for (int i = 0; i < e.dtype().lanes(); ++i) {
f(i, builder_->MakeValue(
spv::OpCompositeExtract, etype, value, i));
}
@@ -441,7 +441,7 @@ void CodeGenSPIRV::VisitStmt_(const Store* op) {
StorageInfo& info = it->second;
if (!info.content_fixed) {
- info.UpdateContentType(op->value.type());
+ info.UpdateContentType(op->value.dtype());
}
spirv::SType content_type = builder_->GetSType(info.content_type);
@@ -455,15 +455,15 @@ void CodeGenSPIRV::VisitStmt_(const Store* op) {
mask |= spv::MemoryAccessVolatileMask;
}
- if (op->value.type().lanes() == 1) {
- CHECK_EQ(info.content_type, op->value.type())
+ if (op->value.dtype().lanes() == 1) {
+ CHECK_EQ(info.content_type, op->value.dtype())
<< "Vulkan only allow one type access to the same buffer";
spirv::Value index = MakeValue(op->index);
spirv::Value ptr = builder_->StructArrayAccess(
ptr_type, buffer, index);
builder_->MakeInst(spv::OpStore, ptr, value, mask);
} else {
- if (op->value.type().element_of() == info.content_type) {
+ if (op->value.dtype().element_of() == info.content_type) {
// because content type is element type, we can only do scalarize load.
auto f = [&](int i, spirv::Value index) {
spirv::Value elem = builder_->MakeValue(
@@ -476,13 +476,13 @@ void CodeGenSPIRV::VisitStmt_(const Store* op) {
} else {
if (const Ramp* ramp = op->index.as<Ramp>()) {
if (is_one(ramp->stride)) {
- CHECK_EQ(ramp->lanes, op->value.type().lanes());
+ CHECK_EQ(ramp->lanes, op->value.dtype().lanes());
arith::ModularSet me = analyzer_->modular_set(ramp->base);
CHECK((me->coeff % ramp->lanes) == 0 &&
(me->base % ramp->lanes) == 0)
<< "Only aligned vector access is allowed in SPIRV";
Expr vec_index = ir::Simplify(
- ramp->base / make_const(ramp->base.type(), ramp->lanes));
+ ramp->base / make_const(ramp->base.dtype(), ramp->lanes));
spirv::Value ptr = builder_->StructArrayAccess(
ptr_type, buffer, MakeValue(vec_index));
builder_->MakeInst(spv::OpStore, ptr, value, mask);
@@ -530,7 +530,7 @@ void CodeGenSPIRV::VisitStmt_(const For* op) {
// loop continue
builder_->StartLabel(continue_label);
spirv::Value one =
- op->loop_var.type().is_int() ?
+ op->loop_var.dtype().is_int() ?
builder_->IntImm(loop_var.stype, 1) :
builder_->UIntImm(loop_var.stype, 1);
spirv::Value next_value = builder_->Add(loop_var, one);
@@ -576,13 +576,13 @@ void CodeGenSPIRV::VisitStmt_(const IfThenElse* op) {
void CodeGenSPIRV::VisitStmt_(const Allocate* op) {
CHECK(!is_zero(op->condition));
CHECK(!op->new_expr.defined());
- CHECK(!op->type.is_handle());
+ CHECK(!op->dtype.is_handle());
int32_t constant_size = op->constant_allocation_size();
CHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation in GPU";
spirv::Value buf;
StorageInfo& info = storage_info_[op->buffer_var.get()];
- spirv::SType etype = builder_->GetSType(op->type);
+ spirv::SType etype = builder_->GetSType(op->dtype);
if (info.scope.rank == runtime::StorageRank::kLocal) {
buf = builder_->Allocate(
etype, static_cast<uint32_t>(constant_size),
@@ -597,7 +597,7 @@ void CodeGenSPIRV::VisitStmt_(const Allocate* op) {
spv::StorageClassWorkgroup);
}
CHECK(!info.content_fixed);
- info.UpdateContentType(op->type);
+ info.UpdateContentType(op->dtype);
CHECK(!var_map_.count(op->buffer_var.get()));
var_map_[op->buffer_var.get()] = buf;
this->VisitStmt(op->body);
@@ -632,7 +632,7 @@ void CodeGenSPIRV::VisitStmt_(const AssertStmt* op) {
void CodeGenSPIRV::VisitStmt_(const LetStmt* op) {
CHECK(!var_map_.count(op->var.get()));
- CHECK(!op->var.type().is_handle());
+ CHECK(!op->var.dtype().is_handle());
var_map_[op->var.get()] = MakeValue(op->value);
analyzer_->Bind(op->var, op->value);
this->VisitStmt(op->body);
diff --git a/src/codegen/spirv/codegen_spirv.h b/src/codegen/spirv/codegen_spirv.h
index eca3614..3d16377 100644
--- a/src/codegen/spirv/codegen_spirv.h
+++ b/src/codegen/spirv/codegen_spirv.h
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -112,10 +112,10 @@ class CodeGenSPIRV:
/*! \brief Whether it is volatile */
bool content_fixed{false};
/*! \brief Current content type */
- Type content_type{Handle()};
+ DataType content_type{DataType::Handle()};
// Update content type if it hasn't beenupdated.
- void UpdateContentType(Type type) {
+ void UpdateContentType(DataType type) {
if (content_fixed) {
CHECK_EQ(type, content_type)
<< "Cannot use two different content type in GLSL model";
diff --git a/src/codegen/spirv/intrin_rule_spirv.cc b/src/codegen/spirv/intrin_rule_spirv.cc
index fca9aa2..7a347e5 100644
--- a/src/codegen/spirv/intrin_rule_spirv.cc
+++ b/src/codegen/spirv/intrin_rule_spirv.cc
@@ -39,13 +39,13 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
CHECK(call != nullptr);
Array<Expr> cargs;
// intrin id.
- cargs.push_back(ir::UIntImm::make(UInt(32), id));
+ cargs.push_back(ir::UIntImm::make(DataType::UInt(32), id));
for (Expr arg : call->args) {
cargs.push_back(arg);
}
*rv = ir::Call::make(
- call->type, "spirv_glsl450", cargs, ir::Call::PureIntrinsic);
+ call->dtype, "spirv_glsl450", cargs, ir::Call::PureIntrinsic);
}
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor")
diff --git a/src/codegen/spirv/ir_builder.cc b/src/codegen/spirv/ir_builder.cc
index 35d57d7..6f8d96e 100644
--- a/src/codegen/spirv/ir_builder.cc
+++ b/src/codegen/spirv/ir_builder.cc
@@ -53,10 +53,10 @@ void IRBuilder::InitHeader() {
void IRBuilder::InitPreDefs() {
ext_glsl450_ = ExtInstImport("GLSL.std.450");
- t_int32_ = DeclareType(Int(32));
- t_uint32_ = DeclareType(UInt(32));
- t_bool_ = DeclareType(UInt(1));
- t_fp32_ = DeclareType(Float(32));
+ t_int32_ = DeclareType(DataType::Int(32));
+ t_uint32_ = DeclareType(DataType::UInt(32));
+ t_bool_ = DeclareType(DataType::UInt(1));
+ t_fp32_ = DeclareType(DataType::Float(32));
const_i32_zero_ = IntImm(t_int32_, 0);
// declare void, and void functions
t_void_.id = id_counter_++;
@@ -66,14 +66,14 @@ void IRBuilder::InitPreDefs() {
.AddSeq(t_void_func_, t_void_).Commit(&global_);
}
-SType IRBuilder::GetSType(const Type& dtype) {
- if (dtype == Int(32)) {
+SType IRBuilder::GetSType(const DataType& dtype) {
+ if (dtype == DataType::Int(32)) {
return t_int32_;
- } else if (dtype == UInt(1)) {
+ } else if (dtype == DataType::UInt(1)) {
return t_bool_;
- } else if (dtype == Float(32)) {
+ } else if (dtype == DataType::Float(32)) {
return t_fp32_;
- } else if (dtype == UInt(32)) {
+ } else if (dtype == DataType::UInt(32)) {
return t_uint32_;
}
uint32_t type_key;
@@ -99,7 +99,7 @@ SType IRBuilder::GetPointerType(const SType& value_type,
}
SType t;
t.id = id_counter_++;
- t.type = Handle();
+ t.type = DataType::Handle();
t.element_type_id = value_type.id;
t.storage_class = storage_class;
ib_.Begin(spv::OpTypePointer)
@@ -118,11 +118,11 @@ SType IRBuilder::GetStructArrayType(const SType& value_type,
SType arr_type;
arr_type.id = id_counter_++;
- arr_type.type = Handle();
+ arr_type.type = DataType::Handle();
arr_type.element_type_id = value_type.id;
if (num_elems != 0) {
- Value length = UIntImm(GetSType(UInt(32)), num_elems);
+ Value length = UIntImm(GetSType(DataType::UInt(32)), num_elems);
ib_.Begin(spv::OpTypeArray)
.AddSeq(arr_type, value_type, length).Commit(&global_);
} else {
@@ -138,7 +138,7 @@ SType IRBuilder::GetStructArrayType(const SType& value_type,
// declare struct of array
SType struct_type;
struct_type.id = id_counter_++;
- struct_type.type = Handle();
+ struct_type.type = DataType::Handle();
struct_type.element_type_id = value_type.id;
ib_.Begin(spv::OpTypeStruct)
.AddSeq(struct_type, arr_type).Commit(&global_);
@@ -183,7 +183,7 @@ Value IRBuilder::FloatImm(const SType& dtype, double value) {
} else {
CHECK_EQ(dtype.type.bits(), 16);
return Cast(dtype,
- FloatImm(GetSType(Float(32)), value));
+ FloatImm(GetSType(DataType::Float(32)), value));
}
}
@@ -206,7 +206,7 @@ Value IRBuilder::DeclarePushConstant(const std::vector<SType>& value_types) {
CHECK_EQ(push_const_.id, 0);
SType struct_type;
struct_type.id = id_counter_++;
- struct_type.type = Handle();
+ struct_type.type = DataType::Handle();
ib_.Begin(spv::OpTypeStruct).Add(struct_type);
for (const SType& vtype : value_types) {
ib_.Add(vtype);
@@ -218,7 +218,7 @@ Value IRBuilder::DeclarePushConstant(const std::vector<SType>& value_types) {
ib_.Begin(spv::OpMemberDecorate)
.AddSeq(struct_type, i, spv::DecorationOffset, offset)
.Commit(&decorate_);
- Type t = value_types[i].type;
+ DataType t = value_types[i].type;
uint32_t nbits = t.bits() * t.lanes();
CHECK_EQ(nbits % 8 , 0);
offset += nbits / 8;
@@ -296,7 +296,7 @@ Value IRBuilder::Allocate(const SType& value_type,
Value IRBuilder::GetWorkgroupID(uint32_t dim_index) {
if (workgroup_id_.id == 0) {
- SType vec3_type = this->GetSType(Int(32).with_lanes(3));
+ SType vec3_type = this->GetSType(DataType::Int(32).with_lanes(3));
SType ptr_type = this->GetPointerType(
vec3_type, spv::StorageClassInput);
workgroup_id_ = NewValue(ptr_type, kVectorPtr);
@@ -315,7 +315,7 @@ Value IRBuilder::GetWorkgroupID(uint32_t dim_index) {
Value IRBuilder::GetLocalID(uint32_t dim_index) {
if (local_id_.id == 0) {
- SType vec3_type = this->GetSType(Int(32).with_lanes(3));
+ SType vec3_type = this->GetSType(DataType::Int(32).with_lanes(3));
SType ptr_type = this->GetPointerType(vec3_type, spv::StorageClassInput);
local_id_ = NewValue(ptr_type, kVectorPtr);
ib_.Begin(spv::OpVariable)
@@ -339,7 +339,7 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) {
}
CHECK_LE(dtype.type.bits(), 64);
Value ret = NewValue(dtype, kConstant);
- if (dtype.type == UInt(1)) {
+ if (dtype.type == DataType::UInt(1)) {
// bool types.
if (*pvalue) {
ib_.Begin(spv::OpConstantTrue).AddSeq(ret);
@@ -367,7 +367,7 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) {
return ret;
}
-SType IRBuilder::DeclareType(const Type& dtype) {
+SType IRBuilder::DeclareType(const DataType& dtype) {
if (dtype.lanes() == 1) {
SType t;
t.id = id_counter_++;
@@ -426,7 +426,7 @@ Value IRBuilder::CallGLSL450(const SType& ret_type,
Value IRBuilder::Concat(const std::vector<Value>& vec) {
bool is_const = vec[0].flag == kConstant;
- Type etype = vec[0].stype.type;
+ DataType etype = vec[0].stype.type;
int lanes = etype.lanes();
for (size_t i = 1; i < vec.size(); ++i) {
CHECK_EQ(etype, vec[i].stype.type.element_of())
@@ -456,10 +456,10 @@ Value IRBuilder::Concat(const std::vector<Value>& vec) {
Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) {
CHECK_NE(value.stype.id, 0U);
if (value.stype.id == dst_type.id) return value;
- const tvm::Type& from = value.stype.type;
- const tvm::Type& to = dst_type.type;
+ const tvm::DataType& from = value.stype.type;
+ const tvm::DataType& to = dst_type.type;
CHECK_EQ(from.lanes(), to.lanes());
- if (from == Bool()) {
+ if (from == DataType::Bool()) {
if (to.is_int()) {
return Select(value, IntImm(dst_type, 1), IntImm(dst_type, 0));
} else if (to.is_uint()) {
@@ -471,7 +471,7 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) {
LOG(FATAL) << "cannot cast from " << from << " to " << to;
return Value();
}
- } else if (to == Bool()) {
+ } else if (to == DataType::Bool()) {
if (from.is_int()) {
return NE(value, IntImm(value.stype, 0));
} else if (to.is_uint()) {
@@ -558,7 +558,7 @@ Value IRBuilder::Mod(Value a, Value b) {
Value IRBuilder::_OpName(Value a, Value b) { \
CHECK_EQ(a.stype.id, b.stype.id); \
CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \
- const auto& bool_type = this->GetSType(UInt(1).with_lanes(a.stype.type.lanes())); \
+ const auto& bool_type = this->GetSType(DataType::UInt(1).with_lanes(a.stype.type.lanes())); \
if (a.stype.type.is_int()) { \
return MakeValue(spv::OpS##_Op, bool_type, a, b); \
} else if (a.stype.type.is_uint()) { \
@@ -578,7 +578,7 @@ DEFINE_BUILDER_CMP_OP(GE, GreaterThanEqual);
Value IRBuilder::_OpName(Value a, Value b) { \
CHECK_EQ(a.stype.id, b.stype.id); \
CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \
- const auto& bool_type = this->GetSType(UInt(1).with_lanes(a.stype.type.lanes())); \
+ const auto& bool_type = this->GetSType(DataType::UInt(1).with_lanes(a.stype.type.lanes())); \
if (a.stype.type.is_int() || a.stype.type.is_uint()) { \
return MakeValue(spv::OpI##_Op, bool_type, a, b); \
} else { \
@@ -592,7 +592,7 @@ DEFINE_BUILDER_CMP_UOP(NE, NotEqual);
Value IRBuilder::Select(Value cond, Value a, Value b) {
CHECK_EQ(a.stype.id, b.stype.id);
- CHECK_EQ(cond.stype.type.element_of(), UInt(1));
+ CHECK_EQ(cond.stype.type.element_of(), DataType::UInt(1));
return MakeValue(spv::OpSelect, a.stype, cond, a, b);
}
diff --git a/src/codegen/spirv/ir_builder.h b/src/codegen/spirv/ir_builder.h
index c04af74..3843cbb 100644
--- a/src/codegen/spirv/ir_builder.h
+++ b/src/codegen/spirv/ir_builder.h
@@ -45,7 +45,7 @@ struct SType {
/*! \brief The Id to represent type */
uint32_t id{0};
/*! \brief corresponding TVM type */
- tvm::Type type;
+ tvm::DataType type;
/*! \brief content type id if it is a pointer/struct-array class */
uint32_t element_type_id{0};
/*! \brief The storage class, if it is a pointer */
@@ -424,7 +424,7 @@ class IRBuilder {
* \param dtype The data type.
* \return The corresponding spirv type.
*/
- SType GetSType(const tvm::Type& dtype);
+ SType GetSType(const tvm::DataType& dtype);
/*!
* \brief Get the pointer type that points to value_type
* \param value_type.
@@ -575,7 +575,7 @@ class IRBuilder {
// get constant given value encoded in uint64_t
Value GetConst_(const SType& dtype, const uint64_t* pvalue);
// declare type
- SType DeclareType(const Type& dtype);
+ SType DeclareType(const DataType& dtype);
/*! \brief internal instruction builder */
InstrBuilder ib_;
/*! \brief Current label */
diff --git a/src/codegen/stackvm/codegen_stackvm.cc b/src/codegen/stackvm/codegen_stackvm.cc
index fd2a5f7..52cabaf 100644
--- a/src/codegen/stackvm/codegen_stackvm.cc
+++ b/src/codegen/stackvm/codegen_stackvm.cc
@@ -100,12 +100,12 @@ int CodeGenStackVM::GetVarID(const Variable* v) const {
void CodeGenStackVM::VisitExpr_(const Load* op) {
this->Push(op->buffer_var);
- StackVM::OpCode code = StackVM::GetLoad(Type2TVMType(op->type));
+ StackVM::OpCode code = StackVM::GetLoad(op->dtype);
if (const IntImm* index = op->index.as<IntImm>()) {
this->PushOp(code, index->value);
} else {
this->Push(op->index);
- this->PushOp(StackVM::PUSH_I64, op->type.element_of().bytes());
+ this->PushOp(StackVM::PUSH_I64, op->dtype.element_of().bytes());
this->PushOp(StackVM::MUL_I64);
this->PushOp(StackVM::ADDR_ADD);
this->PushOp(code, 0);
@@ -114,13 +114,13 @@ void CodeGenStackVM::VisitExpr_(const Load* op) {
void CodeGenStackVM::VisitStmt_(const Store* op) {
this->Push(op->buffer_var);
- StackVM::OpCode code = StackVM::GetStore(Type2TVMType(op->value.type()));
+ StackVM::OpCode code = StackVM::GetStore(op->value.dtype());
if (const IntImm* index = op->index.as<IntImm>()) {
this->Push(op->value);
this->PushOp(code, index->value);
} else {
this->Push(op->index);
- this->PushOp(StackVM::PUSH_I64, op->value.type().element_of().bytes());
+ this->PushOp(StackVM::PUSH_I64, op->value.dtype().element_of().bytes());
this->PushOp(StackVM::MUL_I64);
this->PushOp(StackVM::ADDR_ADD);
this->Push(op->value);
@@ -147,7 +147,7 @@ void CodeGenStackVM::VisitExpr_(const Call* op) {
CHECK(op->args.size() == 1 && l);
this->PushOp(StackVM::LOAD_HEAP, GetVarID(l->buffer_var.get()));
this->Push(l->index);
- this->PushOp(StackVM::PUSH_I64, l->type.element_of().bytes());
+ this->PushOp(StackVM::PUSH_I64, l->dtype.element_of().bytes());
this->PushOp(StackVM::MUL_I64);
this->PushOp(StackVM::ADDR_ADD);
} else if (op->is_intrinsic(Call::reinterpret)) {
@@ -248,7 +248,7 @@ void CodeGenStackVM::PushBinary(StackVM::OpCode op_int64,
const Expr& b) {
this->Push(a);
this->Push(b);
- Type t = a.type();
+ DataType t = a.dtype();
if (t.is_int()) {
this->PushOp(op_int64);
} else if (t.is_uint()) {
@@ -258,7 +258,7 @@ void CodeGenStackVM::PushBinary(StackVM::OpCode op_int64,
}
}
-void CodeGenStackVM::PushCast(Type dst, Type src) {
+void CodeGenStackVM::PushCast(DataType dst, DataType src) {
if (dst.is_int()) {
if (src.is_int() || src.is_uint()) return;
} else if (dst.is_uint()) {
@@ -297,7 +297,7 @@ void CodeGenStackVM::VisitExpr_(const Variable *op) {
void CodeGenStackVM::VisitExpr_(const Cast *op) {
this->Push(op->value);
- PushCast(op->type, op->value.type());
+ PushCast(op->dtype, op->value.dtype());
}
void CodeGenStackVM::VisitExpr_(const Add *op) {
diff --git a/src/codegen/stackvm/codegen_stackvm.h b/src/codegen/stackvm/codegen_stackvm.h
index 1e6dd64..dcae072 100644
--- a/src/codegen/stackvm/codegen_stackvm.h
+++ b/src/codegen/stackvm/codegen_stackvm.h
@@ -108,7 +108,7 @@ class CodeGenStackVM
const Expr& a,
const Expr& b);
// push cast;
- void PushCast(Type dst, Type src);
+ void PushCast(DataType dst, DataType src);
// overloadable functions
// expression
void VisitExpr_(const Variable* op) final;
diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc
index 9e55d9b..2bb8609 100644
--- a/src/contrib/hybrid/codegen_hybrid.cc
+++ b/src/contrib/hybrid/codegen_hybrid.cc
@@ -57,7 +57,7 @@ std::string CodeGenHybrid::Finish() {
return stream.str();
}
-void CodeGenHybrid::PrintType(Type t, std::ostream &os) {
+void CodeGenHybrid::PrintType(DataType t, std::ostream &os) {
if (t.is_float()) {
os << "float";
CHECK(t.bits() == 16 || t.bits() == 32 || t.bits() == 64);
@@ -76,11 +76,11 @@ void CodeGenHybrid::VisitExpr_(const IntImm *op, std::ostream& os) { // NOLINT(
os << op->value;
}
void CodeGenHybrid::VisitExpr_(const UIntImm *op, std::ostream& os) { // NOLINT(*)
- PrintType(op->type, os);
+ PrintType(op->dtype, os);
os << "(" << op->value << ")";
}
void CodeGenHybrid::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*)
- PrintType(op->type, os);
+ PrintType(op->dtype, os);
os << "(" << std::setprecision(20) << op->value << ")";
}
void CodeGenHybrid::VisitExpr_(const StringImm *op, std::ostream& os) { // NOLINT(*)
@@ -92,7 +92,7 @@ inline void PrintBinaryExpr(const T* op,
const char *opstr,
std::ostream& os, // NOLINT(*)
CodeGenHybrid* p) {
- CHECK(op->type.lanes() == 1) << "vec bin op not implemented";
+ CHECK(op->dtype.lanes() == 1) << "vec bin op not implemented";
if (isalpha(opstr[0])) {
os << opstr << '(';
p->PrintExpr(op->a, os);
@@ -114,7 +114,7 @@ inline void PrintBinaryIntrinsitc(const Call* op,
const char *opstr,
std::ostream& os, // NOLINT(*)
CodeGenHybrid* p) {
- CHECK(op->type.lanes() == 1) << "vec bin intrin not implemented";
+ CHECK(op->dtype.lanes() == 1) << "vec bin intrin not implemented";
CHECK_EQ(op->args.size(), 2U);
os << '(';
p->PrintExpr(op->args[0], os);
@@ -124,10 +124,10 @@ inline void PrintBinaryIntrinsitc(const Call* op,
}
void CodeGenHybrid::VisitExpr_(const Cast *op, std::ostream& os) { // NOLINT(*)
- if (op->type == op->value.type()) {
+ if (op->dtype == op->value.dtype()) {
PrintExpr(op->value, stream);
} else {
- PrintType(op->type, os);
+ PrintType(op->dtype, os);
os << "(";
PrintExpr(op->value, os);
os << ")";
@@ -148,14 +148,14 @@ void CodeGenHybrid::VisitExpr_(const Mul *op, std::ostream& os) { // NOLINT(*)
}
void CodeGenHybrid::VisitExpr_(const Div *op, std::ostream& os) { // NOLINT(*)
- if (op->type.is_int())
+ if (op->dtype.is_int())
PrintBinaryExpr(op, "//", os, this);
else
PrintBinaryExpr(op, "/", os, this);
}
void CodeGenHybrid::VisitExpr_(const FloorDiv *op, std::ostream& os) { // NOLINT(*)
- if (op->type.is_int())
+ if (op->dtype.is_int())
PrintBinaryExpr(op, "//", os, this);
else
PrintBinaryExpr(op, "/", os, this);
@@ -320,7 +320,7 @@ void CodeGenHybrid::VisitStmt_(const Realize *op) {
}
if (op->bounds.size() == 1) stream << ", ";
stream << "), '";
- PrintType(op->type, stream);
+ PrintType(op->dtype, stream);
stream << "', '";
stream << alloc_storage_scope_[op->func] << "')\n";
}
diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h
index 8667569..2c719b0 100644
--- a/src/contrib/hybrid/codegen_hybrid.h
+++ b/src/contrib/hybrid/codegen_hybrid.h
@@ -138,7 +138,7 @@ class CodeGenHybrid :
* \param t The type representation.
* \param os The stream to print the ctype into
*/
- virtual void PrintType(Type t, std::ostream& os); // NOLINT(*)
+ virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*)
private:
/*! \brief The current indent of the code dump. */
diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc
index 007a68b..b83734b 100644
--- a/src/lang/attrs.cc
+++ b/src/lang/attrs.cc
@@ -177,7 +177,7 @@ bool AttrsEqualHandler::VisitAttr_(const Not* lhs, const ObjectRef& other) {
bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<Cast>()) {
- if (lhs->type != rhs->type) return false;
+ if (lhs->dtype != rhs->dtype) return false;
return Equal(lhs->value, rhs->value);
} else {
return false;
@@ -188,7 +188,7 @@ bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<Call>()) {
return
lhs->name == rhs->name &&
- lhs->type == rhs->type &&
+ lhs->dtype == rhs->dtype &&
lhs->call_type == rhs->call_type &&
Equal(lhs->args, rhs->args);
} else {
@@ -290,7 +290,7 @@ size_t AttrsHashHandler::VisitAttr_(const Cast* op) {
static size_t key = std::hash<std::string>()(Cast::_type_key);
AttrsHash hasher;
size_t res = key;
- res = Combine(res, hasher(op->type));
+ res = Combine(res, hasher(op->dtype));
res = Combine(res, Hash(op->value));
return res;
}
@@ -300,7 +300,7 @@ size_t AttrsHashHandler::VisitAttr_(const Call* op) {
AttrsHash hasher;
size_t res = key;
res = Combine(res, hasher(op->name));
- res = Combine(res, hasher(op->type));
+ res = Combine(res, hasher(op->dtype));
res = Combine(res, Hash(op->args));
return res;
}
diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc
index 77e7410..eb5d87e 100644
--- a/src/lang/buffer.cc
+++ b/src/lang/buffer.cc
@@ -42,10 +42,10 @@ Array<Expr> SimplifyArray(Array<Expr> array) {
}
Buffer decl_buffer(Array<Expr> shape,
- Type dtype,
+ DataType dtype,
std::string name) {
return BufferNode::make(
- Var(name, Handle()),
+ Var(name, DataType::Handle()),
dtype,
shape,
Array<Expr>(),
@@ -279,30 +279,30 @@ inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
return base;
}
-inline Expr BufferOffset(const BufferNode* n, Array<Expr> index, Type dtype) {
+inline Expr BufferOffset(const BufferNode* n, Array<Expr> index, DataType dtype) {
Expr offset = ElemOffset(n, index);
if (n->dtype.lanes() != 1) {
- offset = offset * make_const(offset.type(), dtype.lanes());
+ offset = offset * make_const(offset.dtype(), dtype.lanes());
}
if (dtype.lanes() != 1) {
- return ir::Ramp::make(offset, make_const(offset.type(), 1), dtype.lanes());
+ return ir::Ramp::make(offset, make_const(offset.dtype(), 1), dtype.lanes());
} else {
return offset;
}
}
-Expr Buffer::vload(Array<Expr> begin, Type dtype) const {
- // specially handle bool, stored as Int(8)
+Expr Buffer::vload(Array<Expr> begin, DataType dtype) const {
+ // specially handle bool, stored as DataType::Int(8)
const BufferNode* n = operator->();
CHECK(dtype.element_of() == n->dtype.element_of() &&
dtype.lanes() % n->dtype.lanes() == 0)
<< "Cannot load " << dtype
<< " from buffer of " << n->dtype;
- if (dtype == Bool()) {
+ if (dtype == DataType::Bool()) {
return ir::Cast::make(
- Bool(),
+ DataType::Bool(),
ir::Load::make(
- Int(8), n->data, BufferOffset(n, begin, Int(8)),
+ DataType::Int(8), n->data, BufferOffset(n, begin, DataType::Int(8)),
const_true()));
} else {
return ir::Load::make(
@@ -312,17 +312,17 @@ Expr Buffer::vload(Array<Expr> begin, Type dtype) const {
}
Stmt Buffer::vstore(Array<Expr> begin, Expr value) const {
- // specially handle bool, stored as Int(8)
+ // specially handle bool, stored as DataType::Int(8)
const BufferNode* n = operator->();
- Type dtype = value.type();
+ DataType dtype = value.dtype();
CHECK(dtype.element_of() == n->dtype.element_of() &&
dtype.lanes() % n->dtype.lanes() == 0)
<< "Cannot load " << dtype
<< " from buffer of " << n->dtype;
- if (value.type() == Bool()) {
+ if (value.dtype() == DataType::Bool()) {
return ir::Store::make(n->data,
- ir::Cast::make(Int(8), value),
- BufferOffset(n, begin, Int(8)),
+ ir::Cast::make(DataType::Int(8), value),
+ BufferOffset(n, begin, DataType::Int(8)),
const_true());
} else {
return ir::Store::make(n->data, value, BufferOffset(n, begin, dtype),
@@ -381,7 +381,7 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
n->buffer_type);
}
-Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr offset) const {
+Expr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, Expr offset) const {
const BufferNode* self = operator->();
Expr e_dtype;
Expr extent;
@@ -396,21 +396,21 @@ Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr
Expr elem_offset = self->elem_offset + offset;
if (content_lanes > 1) {
e_dtype = ir::TypeAnnotation(self->dtype.with_lanes(content_lanes));
- extent = extent / make_const(self->elem_offset.type(), content_lanes);
- elem_offset = self->elem_offset / make_const(self->elem_offset.type(),
+ extent = extent / make_const(self->elem_offset.dtype(), content_lanes);
+ elem_offset = self->elem_offset / make_const(self->elem_offset.dtype(),
content_lanes);
} else {
e_dtype = ir::TypeAnnotation(self->dtype);
}
Array<Expr> acc_args{
e_dtype, self->data, elem_offset,
- extent, make_const(Int(32), access_mask)};
+ extent, make_const(DataType::Int(32), access_mask)};
return ir::Call::make(
ptr_type, ir::intrinsic::tvm_access_ptr, acc_args, ir::Call::Intrinsic);
}
Buffer BufferNode::make(Var data,
- Type dtype,
+ DataType dtype,
Array<Expr> shape,
Array<Expr> strides,
Expr elem_offset,
diff --git a/src/lang/channel.cc b/src/lang/channel.cc
index cb3e2f5..555562a 100644
--- a/src/lang/channel.cc
+++ b/src/lang/channel.cc
@@ -24,7 +24,7 @@
namespace tvm {
-Channel ChannelNode::make(Var handle_var, Type dtype) {
+Channel ChannelNode::make(Var handle_var, DataType dtype) {
auto n = make_node<ChannelNode>();
n->handle_var = handle_var;
n->dtype = dtype;
diff --git a/src/lang/expr.cc b/src/lang/expr.cc
index 6a69fda..997c151 100644
--- a/src/lang/expr.cc
+++ b/src/lang/expr.cc
@@ -29,70 +29,11 @@
namespace tvm {
-// maximum and min values
-Expr DataType::max() const {
- using namespace ir;
- CHECK_EQ(lanes(), 1);
- if (is_int()) {
- if (bits() == 64) {
- return IntImm::make(*this, std::numeric_limits<int64_t>::max());
- } else if (bits() < 64) {
- int64_t val = 1;
- val = (val << (bits() - 1)) - 1;
- return IntImm::make(*this, val);
- }
- } else if (is_uint()) {
- if (bits() == 64) {
- return UIntImm::make(*this, std::numeric_limits<uint64_t>::max());
- } else if (bits() < 64) {
- uint64_t val = 1;
- val = (val << static_cast<uint64_t>(bits())) - 1;
- return UIntImm::make(*this, val);
- }
- } else if (is_float()) {
- if (bits() == 64) {
- return FloatImm::make(*this, std::numeric_limits<double>::max());
- } else if (bits() == 32) {
- return FloatImm::make(*this, std::numeric_limits<float>::max());
- } else if (bits() == 16) {
- return FloatImm::make(*this, 65504.0);
- }
- }
- LOG(FATAL) << "Cannot decide max_value for type" << *this;
- return Expr();
-}
-
-Expr DataType::min() const {
- using namespace ir;
- CHECK_EQ(lanes(), 1);
- if (is_int()) {
- if (bits() == 64) {
- return IntImm::make(*this, std::numeric_limits<int64_t>::lowest());
- } else if (bits() < 64) {
- int64_t val = 1;
- val = -(val << (bits() - 1));
- return IntImm::make(*this, val);
- }
- } else if (is_uint()) {
- return UIntImm::make(*this, 0);
- } else if (is_float()) {
- if (bits() == 64) {
- return FloatImm::make(*this, std::numeric_limits<double>::lowest());
- } else if (bits() == 32) {
- return FloatImm::make(*this, std::numeric_limits<float>::lowest());
- } else if (bits() == 16) {
- return FloatImm::make(*this, -65504.0);
- }
- }
- LOG(FATAL) << "Cannot decide min_value for type" << *this;
- return Expr();
-}
-
Expr::Expr(int32_t value)
- : Expr(IntImm::make(Int(32), value)) {}
+ : Expr(IntImm::make(DataType::Int(32), value)) {}
Expr::Expr(float value)
- : Expr(ir::FloatImm::make(Float(32), value)) {}
+ : Expr(ir::FloatImm::make(DataType::Float(32), value)) {}
Expr::Expr(std::string str)
: Expr(ir::StringImm::make(str)) {}
@@ -102,7 +43,7 @@ Var::Var(std::string name_hint, DataType t)
Var Variable::make(DataType t, std::string name_hint) {
NodePtr<Variable> node = make_node<Variable>();
- node->type = t;
+ node->dtype = t;
node->name_hint = std::move(name_hint);
return Var(node);
}
@@ -113,11 +54,11 @@ Range::Range(Expr begin, Expr end)
is_zero(begin) ? end : (end - begin))) {
}
-Integer IntImm::make(Type t, int64_t value) {
+Integer IntImm::make(DataType t, int64_t value) {
CHECK(t.is_int() && t.is_scalar())
<< "ValueError: IntImm can only take scalar.";
NodePtr<IntImm> node = make_node<IntImm>();
- node->type = t;
+ node->dtype = t;
node->value = value;
return Integer(node);
}
@@ -152,7 +93,7 @@ void Dump(const NodeRef& n) {
std::cerr << n << "\n";
}
-Var var(std::string name_hint, Type t) {
+Var var(std::string name_hint, DataType t) {
return Var(name_hint, t);
}
@@ -184,10 +125,10 @@ IRPrinter::FType& IRPrinter::vtable() {
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IntImm>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const IntImm*>(node.get());
- if (op->type == Int(32)) {
+ if (op->dtype == DataType::Int(32)) {
p->stream << op->value;
} else {
- p->stream << "(" << op->type << ")" << op->value;
+ p->stream << "(" << op->dtype << ")" << op->value;
}
});
diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc
index 220d437..1166e7e 100644
--- a/src/lang/expr_operator.cc
+++ b/src/lang/expr_operator.cc
@@ -30,16 +30,16 @@
namespace tvm {
// simple cast that only checks if type matches and cast
-inline Expr SimpleCast(const Type& t, Expr value) {
- if (value.type() == t) return value;
+inline Expr SimpleCast(const DataType& t, Expr value) {
+ if (value.dtype() == t) return value;
return ir::Cast::make(t, value);
}
// The public function with a quick checking path.
void BinaryOpMatchTypes(Expr& lhs, Expr& rhs) { // NOLINT(*)
- if (lhs.type() == rhs.type()) return;
- Type ltype = lhs.type();
- Type rtype = rhs.type();
+ if (lhs.dtype() == rhs.dtype()) return;
+ DataType ltype = lhs.dtype();
+ DataType rtype = rhs.dtype();
if (ltype.lanes() == 1 && rtype.lanes() != 1) {
lhs = ir::Broadcast::make(lhs, rtype.lanes());
} else if (rtype.lanes() == 1 && ltype.lanes() != 1) {
@@ -48,37 +48,96 @@ void BinaryOpMatchTypes(Expr& lhs, Expr& rhs) { // NOLINT(*)
CHECK(ltype.lanes() == rtype.lanes())
<< "Cannot match type " << ltype << " vs " << rtype;
}
- if (lhs.type() == rhs.type()) return;
+ if (lhs.dtype() == rhs.dtype()) return;
// Only do very simple type coversion
- // int->float, int(32)->int(64)
+ // int->float, DataType::Int(32)->DataType::Int(64)
// require the types to be relatively consistent
// This will the reduce amount code generated by operators
// and also help user to find potential type conversion problems.
- if (!lhs.type().is_float() && rhs.type().is_float()) {
+ if (!lhs.dtype().is_float() && rhs.dtype().is_float()) {
// int->float
- lhs = cast(rhs.type(), lhs);
- } else if (lhs.type().is_float() && !rhs.type().is_float()) {
+ lhs = cast(rhs.dtype(), lhs);
+ } else if (lhs.dtype().is_float() && !rhs.dtype().is_float()) {
// int->float
- rhs = cast(lhs.type(), rhs);
- } else if ((lhs.type().is_int() && rhs.type().is_int()) ||
- (lhs.type().is_uint() && rhs.type().is_uint())) {
+ rhs = cast(lhs.dtype(), rhs);
+ } else if ((lhs.dtype().is_int() && rhs.dtype().is_int()) ||
+ (lhs.dtype().is_uint() && rhs.dtype().is_uint())) {
// promote int to higher bits
- if (lhs.type().bits() < rhs.type().bits()) {
- lhs = cast(rhs.type(), lhs);
+ if (lhs.dtype().bits() < rhs.dtype().bits()) {
+ lhs = cast(rhs.dtype(), lhs);
} else {
- rhs = cast(lhs.type(), rhs);
+ rhs = cast(lhs.dtype(), rhs);
}
- } else if ((lhs.type().is_int() && rhs.type().is_uint()) ||
- (lhs.type().is_uint() && rhs.type().is_int())) {
- int bits = std::max(lhs.type().bits(), rhs.type().bits());
- lhs = SimpleCast(Int(bits, lhs.type().lanes()), lhs);
- rhs = SimpleCast(Int(bits, rhs.type().lanes()), rhs);
+ } else if ((lhs.dtype().is_int() && rhs.dtype().is_uint()) ||
+ (lhs.dtype().is_uint() && rhs.dtype().is_int())) {
+ int bits = std::max(lhs.dtype().bits(), rhs.dtype().bits());
+ lhs = SimpleCast(DataType::Int(bits, lhs.dtype().lanes()), lhs);
+ rhs = SimpleCast(DataType::Int(bits, rhs.dtype().lanes()), rhs);
} else {
LOG(FATAL) << "Cannot match type " << ltype << " vs " << rtype;
}
}
+// maximum and min limits
+Expr max_value(const DataType& dtype) {
+ using namespace ir;
+ CHECK_EQ(dtype.lanes(), 1);
+ if (dtype.is_int()) {
+ if (dtype.bits() == 64) {
+ return IntImm::make(dtype, std::numeric_limits<int64_t>::max());
+ } else if (dtype.bits() < 64) {
+ int64_t val = 1;
+ val = (val << (dtype.bits() - 1)) - 1;
+ return IntImm::make(dtype, val);
+ }
+ } else if (dtype.is_uint()) {
+ if (dtype.bits() == 64) {
+ return UIntImm::make(dtype, std::numeric_limits<uint64_t>::max());
+ } else if (dtype.bits() < 64) {
+ uint64_t val = 1;
+ val = (val << static_cast<uint64_t>(dtype.bits())) - 1;
+ return UIntImm::make(dtype, val);
+ }
+ } else if (dtype.is_float()) {
+ if (dtype.bits() == 64) {
+ return FloatImm::make(dtype, std::numeric_limits<double>::max());
+ } else if (dtype.bits() == 32) {
+ return FloatImm::make(dtype, std::numeric_limits<float>::max());
+ } else if (dtype.bits() == 16) {
+ return FloatImm::make(dtype, 65504.0);
+ }
+ }
+ LOG(FATAL) << "Cannot decide max_value for type" << dtype;
+ return Expr();
+}
+
+Expr min_value(const DataType& dtype) {
+ using namespace ir;
+ CHECK_EQ(dtype.lanes(), 1);
+ if (dtype.is_int()) {
+ if (dtype.bits() == 64) {
+ return IntImm::make(dtype, std::numeric_limits<int64_t>::lowest());
+ } else if (dtype.bits() < 64) {
+ int64_t val = 1;
+ val = -(val << (dtype.bits() - 1));
+ return IntImm::make(dtype, val);
+ }
+ } else if (dtype.is_uint()) {
+ return UIntImm::make(dtype, 0);
+ } else if (dtype.is_float()) {
+ if (dtype.bits() == 64) {
+ return FloatImm::make(dtype, std::numeric_limits<double>::lowest());
+ } else if (dtype.bits() == 32) {
+ return FloatImm::make(dtype, std::numeric_limits<float>::lowest());
+ } else if (dtype.bits() == 16) {
+ return FloatImm::make(dtype, -65504.0);
+ }
+ }
+ LOG(FATAL) << "Cannot decide min_value for type" << dtype;
+ return Expr();
+}
+
template<typename ValueType>
inline bool ConstPowerHelper(ValueType val, int *shift) {
if (val <= 0) return false;
@@ -103,11 +162,11 @@ bool is_const_power_of_two_integer(const Expr& x, int* shift) {
}
}
-Expr cast(const Type& t, Expr value) {
+Expr cast(const DataType& t, Expr value) {
using ir::IntImm;
using ir::UIntImm;
using ir::FloatImm;
- if (value.type() == t) return value;
+ if (value.dtype() == t) return value;
// const fold IntImm as they are used in index computations
if (t.lanes() == 1) {
if (const IntImm* op = value.as<IntImm>()) {
@@ -119,10 +178,10 @@ Expr cast(const Type& t, Expr value) {
}
return ir::Cast::make(t, value);
} else {
- if (value.type().lanes() == 1) {
+ if (value.dtype().lanes() == 1) {
// manually unroll cast
- Type vtype = t.element_of();
- if (value.type() != vtype) {
+ DataType vtype = t.element_of();
+ if (value.dtype() != vtype) {
if (const IntImm* op = value.as<IntImm>()) {
value = make_const(vtype, op->value);
} else if (const UIntImm* op = value.as<UIntImm>()) {
@@ -135,14 +194,14 @@ Expr cast(const Type& t, Expr value) {
}
return ir::Broadcast::make(value, t.lanes());
} else {
- CHECK(value.type().lanes() == t.lanes());
+ CHECK(value.dtype().lanes() == t.lanes());
return ir::Cast::make(t, value);
}
}
}
-Expr reinterpret(const Type& t, Expr value) {
- if (value.type() == t) return value;
+Expr reinterpret(const DataType& t, Expr value) {
+ if (value.dtype() == t) return value;
return ir::Call::make(t, ir::Call::reinterpret, { value }, ir::Call::PureIntrinsic);
}
@@ -159,9 +218,9 @@ Expr operator-(Expr a) {
using ir::FloatImm;
const IntImm* pa = a.as<IntImm>();
const FloatImm* fa = a.as<FloatImm>();
- if (pa) return ir::IntImm::make(a.type(), -pa->value);
- if (fa) return ir::FloatImm::make(a.type(), -fa->value);
- return make_zero(a.type()) - a;
+ if (pa) return ir::IntImm::make(a.dtype(), -pa->value);
+ if (fa) return ir::FloatImm::make(a.dtype(), -fa->value);
+ return make_zero(a.dtype()) - a;
}
Expr operator-(Expr a, Expr b) {
@@ -186,8 +245,8 @@ Expr div(Expr a, Expr b) {
}
Expr truncdiv(Expr a, Expr b) {
- CHECK(a.type().is_int() || a.type().is_uint());
- CHECK(b.type().is_int() || b.type().is_uint());
+ CHECK(a.dtype().is_int() || a.dtype().is_uint());
+ CHECK(b.dtype().is_int() || b.dtype().is_uint());
return div(a, b);
}
@@ -216,8 +275,8 @@ Expr indexmod(Expr a, Expr b) {
}
Expr floordiv(Expr a, Expr b) {
- CHECK(a.type().is_int() || a.type().is_uint());
- CHECK(b.type().is_int() || b.type().is_uint());
+ CHECK(a.dtype().is_int() || a.dtype().is_uint());
+ CHECK(b.dtype().is_int() || b.dtype().is_uint());
BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::FloorDiv>(a, b);
if (ret.defined()) return ret;
@@ -225,8 +284,8 @@ Expr floordiv(Expr a, Expr b) {
}
Expr floormod(Expr a, Expr b) {
- CHECK(a.type().is_int() || a.type().is_uint());
- CHECK(b.type().is_int() || b.type().is_uint());
+ CHECK(a.dtype().is_int() || a.dtype().is_uint());
+ CHECK(b.dtype().is_int() || b.dtype().is_uint());
BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::FloorMod>(a, b);
if (ret.defined()) return ret;
@@ -264,7 +323,7 @@ Expr max(Expr a, Expr b) {
Expr if_then_else(Expr cond, Expr true_value, Expr false_value) {
using ir::IntImm;
using ir::UIntImm;
- CHECK(cond.type() == Bool(1))
+ CHECK(cond.dtype() == DataType::Bool(1))
<< "if_then_else only accept the condition to be boolean type.";
BinaryOpMatchTypes(true_value, false_value);
if (const UIntImm* op = cond.as<UIntImm>()) {
@@ -281,7 +340,7 @@ Expr if_then_else(Expr cond, Expr true_value, Expr false_value) {
}
}
return ir::Call::make(
- true_value.type(),
+ true_value.dtype(),
ir::intrinsic::tvm_if_then_else,
{cond, true_value, false_value},
ir::Call::PureIntrinsic);
@@ -289,7 +348,7 @@ Expr if_then_else(Expr cond, Expr true_value, Expr false_value) {
Expr likely(Expr cond) {
if (is_const(cond)) return cond;
- return ir::Call::make(cond.type(), ir::Call::likely, { cond }, ir::Call::PureIntrinsic);
+ return ir::Call::make(cond.dtype(), ir::Call::likely, { cond }, ir::Call::PureIntrinsic);
}
Expr operator>(Expr a, Expr b) {
@@ -335,23 +394,23 @@ Expr operator!=(Expr a, Expr b) {
}
Expr operator&&(Expr a, Expr b) {
- CHECK(a.type().is_bool());
- CHECK(b.type().is_bool());
+ CHECK(a.dtype().is_bool());
+ CHECK(b.dtype().is_bool());
Expr ret = arith::TryConstFold<ir::And>(a, b);
if (ret.defined()) return ret;
return ir::And::make(a, b);
}
Expr operator||(Expr a, Expr b) {
- CHECK(a.type().is_bool());
- CHECK(b.type().is_bool());
+ CHECK(a.dtype().is_bool());
+ CHECK(b.dtype().is_bool());
Expr ret = arith::TryConstFold<ir::Or>(a, b);
if (ret.defined()) return ret;
return ir::Or::make(a, b);
}
Expr operator!(Expr a) {
- CHECK(a.type().is_bool());
+ CHECK(a.dtype().is_bool());
Expr ret = arith::TryConstFold<ir::Not>(a);
if (ret.defined()) return ret;
return ir::Not::make(a);
@@ -360,211 +419,211 @@ Expr operator!(Expr a) {
Expr operator>>(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
- const Type& rtype = a.type();
+ const DataType& rtype = a.dtype();
if (pa && pb) return IntImm::make(rtype, (pa->value >> pb->value));
if (pb) {
if (pb->value == 0) return a;
}
});
- return ir::Call::make(a.type(), ir::Call::shift_right, { a, b }, ir::Call::PureIntrinsic);
+ return ir::Call::make(a.dtype(), ir::Call::shift_right, { a, b }, ir::Call::PureIntrinsic);
}
Expr operator<<(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
- const Type& rtype = a.type();
+ const DataType& rtype = a.dtype();
if (pa && pb) return IntImm::make(rtype, (pa->value << pb->value));
if (pb) {
if (pb->value == 0) return a;
}
});
- return ir::Call::make(a.type(), ir::Call::shift_left, { a, b }, ir::Call::PureIntrinsic);
+ return ir::Call::make(a.dtype(), ir::Call::shift_left, { a, b }, ir::Call::PureIntrinsic);
}
Expr operator&(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
- const Type& rtype = a.type();
+ const DataType& rtype = a.dtype();
if (pa && pb) return IntImm::make(rtype, (pa->value & pb->value));
});
- return ir::Call::make(a.type(), ir::Call::bitwise_and, { a, b }, ir::Call::PureIntrinsic);
+ return ir::Call::make(a.dtype(), ir::Call::bitwise_and, { a, b }, ir::Call::PureIntrinsic);
}
Expr operator|(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
- const Type& rtype = a.type();
+ const DataType& rtype = a.dtype();
if (pa && pb) return IntImm::make(rtype, (pa->value | pb->value));
});
- return ir::Call::make(a.type(), ir::Call::bitwise_or, { a, b }, ir::Call::PureIntrinsic);
+ return ir::Call::make(a.dtype(), ir::Call::bitwise_or, { a, b }, ir::Call::PureIntrinsic);
}
Expr operator^(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
- const Type& rtype = a.type();
+ const DataType& rtype = a.dtype();
if (pa && pb) return IntImm::make(rtype, (pa->value ^ pb->value));
});
- return ir::Call::make(a.type(), ir::Call::bitwise_xor, { a, b }, ir::Call::PureIntrinsic);
+ return ir::Call::make(a.dtype(), ir::Call::bitwise_xor, { a, b }, ir::Call::PureIntrinsic);
}
Expr operator~(Expr a) {
- CHECK(a.type().is_int() || a.type().is_uint());
- return ir::Call::make(a.type(), ir::Call::bitwise_not, { a }, ir::Call::PureIntrinsic);
+ CHECK(a.dtype().is_int() || a.dtype().is_uint());
+ return ir::Call::make(a.dtype(), ir::Call::bitwise_not, { a }, ir::Call::PureIntrinsic);
}
Expr pow(Expr x, Expr y) {
BinaryOpMatchTypes(x, y);
- CHECK(x.type().is_float()) << "power only applies to float";
- return ir::Call::make(x.type(), "pow", { x, y }, ir::Call::PureIntrinsic);
+ CHECK(x.dtype().is_float()) << "power only applies to float";
+ return ir::Call::make(x.dtype(), "pow", { x, y }, ir::Call::PureIntrinsic);
}
Expr abs(Expr x) {
- if (x.type().is_int()) {
+ if (x.dtype().is_int()) {
using ir::IntImm;
const IntImm* px = x.as<IntImm>();
if (px) {
- return ir::IntImm::make(x.type(), std::abs(px->value));
+ return ir::IntImm::make(x.dtype(), std::abs(px->value));
}
- return ir::Select::make(x >= make_zero(x.type()), x, -x);
- } else if (x.type().is_float()) {
+ return ir::Select::make(x >= make_zero(x.dtype()), x, -x);
+ } else if (x.dtype().is_float()) {
using ir::FloatImm;
const FloatImm* fx = x.as<FloatImm>();
if (fx) {
- return ir::FloatImm::make(x.type(), std::fabs(fx->value));
+ return ir::FloatImm::make(x.dtype(), std::fabs(fx->value));
}
- return ir::Call::make(x.type(), "fabs", {x}, ir::Call::PureIntrinsic);
- } else if (x.type().is_uint()) {
+ return ir::Call::make(x.dtype(), "fabs", {x}, ir::Call::PureIntrinsic);
+ } else if (x.dtype().is_uint()) {
return x;
} else {
- LOG(FATAL) << "Data type " << x.type()
+ LOG(FATAL) << "Data type " << x.dtype()
<<" not supported for absolute op. Skipping absolute op...";
return x;
}
}
Expr isnan(Expr x) {
- Type t = Bool(x.type().lanes());
- if (x.type().is_int() || x.type().is_uint()) {
+ DataType t = DataType::Bool(x.dtype().lanes());
+ if (x.dtype().is_int() || x.dtype().is_uint()) {
return make_const(t, false);
- } else if (x.type().is_float()) {
+ } else if (x.dtype().is_float()) {
using ir::FloatImm;
const FloatImm* fx = x.as<FloatImm>();
if (fx) {
return make_const(t, std::isnan(fx->value));
}
- if (x.type().bits() == 16) {
+ if (x.dtype().bits() == 16) {
return ir::Call::make(t, ir::Call::isnan,
- {cast(Float(32, t.lanes()), std::move(x))},
+ {cast(DataType::Float(32, t.lanes()), std::move(x))},
ir::Call::PureIntrinsic);
} else {
return ir::Call::make(t, ir::Call::isnan, {x}, ir::Call::PureIntrinsic);
}
} else {
- LOG(FATAL) << "Data type " << x.type()
+ LOG(FATAL) << "Data type " << x.dtype()
<<" not supported for isnan op. Skipping isnan op...";
return x;
}
}
Expr sum(Expr source, Array<IterVar> rdom) {
- Var x("x", source.type()), y("y", source.type());
+ Var x("x", source.dtype()), y("y", source.dtype());
Expr result = ir::Add::make(x, y);
- Expr identity_element = make_zero(source.type());
+ Expr identity_element = make_zero(source.dtype());
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
- return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
+ return ir::Reduce::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
}
Expr all(Expr source, Array<IterVar> rdom) {
- CHECK(source.type().is_bool());
- Var x("x", source.type()), y("y", source.type());
+ CHECK(source.dtype().is_bool());
+ Var x("x", source.dtype()), y("y", source.dtype());
Expr result = ir::And::make(x, y);
- Expr identity_element = make_const(source.type(), true);
+ Expr identity_element = make_const(source.dtype(), true);
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
- return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
+ return ir::Reduce::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
}
Expr any(Expr source, Array<IterVar> rdom) {
- CHECK(source.type().is_bool());
- Var x("x", source.type()), y("y", source.type());
+ CHECK(source.dtype().is_bool());
+ Var x("x", source.dtype()), y("y", source.dtype());
Expr result = ir::Or::make(x, y);
- Expr identity_element = make_const(source.type(), false);
+ Expr identity_element = make_const(source.dtype(), false);
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
- return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
+ return ir::Reduce::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
}
Expr max(Expr source, Array<IterVar> rdom) {
- Var x("x", source.type()), y("y", source.type());
+ Var x("x", source.dtype()), y("y", source.dtype());
Expr result = ir::Max::make(x, y);
- Expr identity_element = source.type().min();
+ Expr identity_element = min_value(source.dtype());
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
- return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
+ return ir::Reduce::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
}
Expr min(Expr source, Array<IterVar> rdom) {
- Var x("x", source.type()), y("y", source.type());
+ Var x("x", source.dtype()), y("y", source.dtype());
Expr result = ir::Min::make(x, y);
- Expr identity_element = source.type().max();
+ Expr identity_element = max_value(source.dtype());
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
- return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
+ return ir::Reduce::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
}
Expr prod(Expr source, Array<IterVar> rdom) {
- Var x("x", source.type()), y("y", source.type());
+ Var x("x", source.dtype()), y("y", source.dtype());
Expr result = ir::Mul::make(x, y);
- Expr identity_element = make_const(source.type(), 1);
+ Expr identity_element = make_const(source.dtype(), 1);
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
- return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
+ return ir::Reduce::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
}
Expr fmod(Expr x, Expr y) {
BinaryOpMatchTypes(x, y);
- CHECK(x.type().is_float()) << "fmod only applies to float";
- return ir::Call::make(x.type(), "fmod", { x, y }, ir::Call::PureIntrinsic);
+ CHECK(x.dtype().is_float()) << "fmod only applies to float";
+ return ir::Call::make(x.dtype(), "fmod", { x, y }, ir::Call::PureIntrinsic);
}
Expr floor(Expr x) {
using ir::FloatImm;
const FloatImm* fx = x.as<FloatImm>();
- if (fx) return FloatImm::make(x.type(), std::floor(fx->value));
- return ir::Call::make(x.type(), "floor", {x}, ir::Call::PureIntrinsic);
+ if (fx) return FloatImm::make(x.dtype(), std::floor(fx->value));
+ return ir::Call::make(x.dtype(), "floor", {x}, ir::Call::PureIntrinsic);
}
Expr ceil(Expr x) {
using ir::FloatImm;
const FloatImm* fx = x.as<FloatImm>();
- if (fx) return FloatImm::make(x.type(), std::ceil(fx->value));
- return ir::Call::make(x.type(), "ceil", {x}, ir::Call::PureIntrinsic);
+ if (fx) return FloatImm::make(x.dtype(), std::ceil(fx->value));
+ return ir::Call::make(x.dtype(), "ceil", {x}, ir::Call::PureIntrinsic);
}
Expr round(Expr x) {
using ir::FloatImm;
const FloatImm* fx = x.as<FloatImm>();
- if (fx) return FloatImm::make(x.type(), std::nearbyint(fx->value));
- return ir::Call::make(x.type(), "round", {x}, ir::Call::PureIntrinsic);
+ if (fx) return FloatImm::make(x.dtype(), std::nearbyint(fx->value));
+ return ir::Call::make(x.dtype(), "round", {x}, ir::Call::PureIntrinsic);
}
Expr nearbyint(Expr x) {
using ir::FloatImm;
const FloatImm* fx = x.as<FloatImm>();
- if (fx) return FloatImm::make(x.type(), std::nearbyint(fx->value));
- return ir::Call::make(x.type(), "nearbyint", {x}, ir::Call::PureIntrinsic);
+ if (fx) return FloatImm::make(x.dtype(), std::nearbyint(fx->value));
+ return ir::Call::make(x.dtype(), "nearbyint", {x}, ir::Call::PureIntrinsic);
}
Expr trunc(Expr x) {
using ir::FloatImm;
const FloatImm* fx = x.as<FloatImm>();
if (fx) {
- return FloatImm::make(x.type(), (fx->value < 0 ? std::ceil(fx->value) :
+ return FloatImm::make(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) :
std::floor(fx->value)));
}
- return ir::Call::make(x.type(), "trunc", {x}, ir::Call::PureIntrinsic);
+ return ir::Call::make(x.dtype(), "trunc", {x}, ir::Call::PureIntrinsic);
}
} // namespace tvm
diff --git a/src/lang/ir.cc b/src/lang/ir.cc
index bb8401d..427e026 100644
--- a/src/lang/ir.cc
+++ b/src/lang/ir.cc
@@ -35,7 +35,7 @@ Expr UIntImm::make(DataType t, uint64_t value) {
CHECK(t.is_uint() && t.lanes() == 1)
<< "ValueError: UIntImm can only take scalar";
NodePtr<UIntImm> node = make_node<UIntImm>();
- node->type = t;
+ node->dtype = t;
node->value = value;
return Expr(node);
}
@@ -44,23 +44,23 @@ Expr FloatImm::make(DataType t, double value) {
CHECK_EQ(t.lanes(), 1)
<< "ValueError: FloatImm can only take scalar";
NodePtr<FloatImm> node = make_node<FloatImm>();
- node->type = t;
+ node->dtype = t;
node->value = value;
return Expr(node);
}
Expr StringImm::make(std::string value) {
NodePtr<StringImm> node = make_node<StringImm>();
- node->type = Handle();
+ node->dtype = DataType::Handle();
node->value = std::move(value);
return Expr(node);
}
Expr Cast::make(DataType t, Expr value) {
CHECK(value.defined());
- CHECK_EQ(t.lanes(), value.type().lanes());
+ CHECK_EQ(t.lanes(), value.dtype().lanes());
NodePtr<Cast> node = make_node<Cast>();
- node->type = t;
+ node->dtype = t;
node->value = std::move(value);
return Expr(node);
}
@@ -68,12 +68,12 @@ Expr Cast::make(DataType t, Expr value) {
Expr And::make(Expr a, Expr b) {
CHECK(a.defined()) << "ValueError: a is undefined";
CHECK(b.defined()) << "ValueError: b is undefined";
- CHECK(a.type().is_bool());
- CHECK(b.type().is_bool());
- CHECK(a.type() == b.type()) << "TypeError: mismatched types";
+ CHECK(a.dtype().is_bool());
+ CHECK(b.dtype().is_bool());
+ CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types";
NodePtr<And> node = make_node<And>();
- node->type = Bool(a.type().lanes());
+ node->dtype = DataType::Bool(a.dtype().lanes());
node->a = std::move(a);
node->b = std::move(b);
return Expr(node);
@@ -82,12 +82,12 @@ Expr And::make(Expr a, Expr b) {
Expr Or::make(Expr a, Expr b) {
CHECK(a.defined()) << "ValueError: a is undefined";
CHECK(b.defined()) << "ValueError: b is undefined";
- CHECK(a.type().is_bool());
- CHECK(b.type().is_bool());
- CHECK(a.type() == b.type()) << "TypeError: mismatched types";
+ CHECK(a.dtype().is_bool());
+ CHECK(b.dtype().is_bool());
+ CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types";
NodePtr<Or> node = make_node<Or>();
- node->type = Bool(a.type().lanes());
+ node->dtype = DataType::Bool(a.dtype().lanes());
node->a = std::move(a);
node->b = std::move(b);
return Expr(node);
@@ -95,10 +95,10 @@ Expr Or::make(Expr a, Expr b) {
Expr Not::make(Expr a) {
CHECK(a.defined()) << "ValueError: a is undefined";
- CHECK(a.type().is_bool());
+ CHECK(a.dtype().is_bool());
NodePtr<Not> node = make_node<Not>();
- node->type = Bool(a.type().lanes());
+ node->dtype = DataType::Bool(a.dtype().lanes());
node->a = std::move(a);
return Expr(node);
}
@@ -107,27 +107,27 @@ Expr Select::make(Expr condition, Expr true_value, Expr false_value) {
CHECK(condition.defined()) << "ValueError: condition is undefined";
CHECK(true_value.defined()) << "ValueError: true_value is undefined";
CHECK(false_value.defined()) << "ValueError: true_value is undefined";
- CHECK(condition.type().is_bool());
- CHECK_EQ(condition.type().lanes(), true_value.type().lanes());
- CHECK(false_value.type() == true_value.type()) << "TypeError: mismatched types";
+ CHECK(condition.dtype().is_bool());
+ CHECK_EQ(condition.dtype().lanes(), true_value.dtype().lanes());
+ CHECK(false_value.dtype() == true_value.dtype()) << "TypeError: mismatched types";
NodePtr<Select> node = make_node<Select>();
- node->type = true_value.type();
+ node->dtype = true_value.dtype();
node->condition = std::move(condition);
node->true_value = std::move(true_value);
node->false_value = std::move(false_value);
return Expr(node);
}
-Expr Load::make(DataType type, Var buffer_var, Expr index, Expr predicate) {
+Expr Load::make(DataType dtype, Var buffer_var, Expr index, Expr predicate) {
CHECK(buffer_var.defined());
CHECK(predicate.defined());
CHECK(index.defined());
- CHECK_EQ(type.lanes(), index.type().lanes());
- CHECK_EQ(type.lanes(), predicate.type().lanes());
+ CHECK_EQ(dtype.lanes(), index.dtype().lanes());
+ CHECK_EQ(dtype.lanes(), predicate.dtype().lanes());
NodePtr<Load> node = make_node<Load>();
- node->type = type;
+ node->dtype = dtype;
node->buffer_var = std::move(buffer_var);
node->index = std::move(index);
node->predicate = std::move(predicate);
@@ -138,13 +138,13 @@ Expr Load::make(DataType type, Var buffer_var, Expr index, Expr predicate) {
Expr Ramp::make(Expr base, Expr stride, int lanes) {
CHECK(base.defined());
CHECK(stride.defined());
- CHECK(base.type().is_scalar());
- CHECK(stride.type().is_scalar());
+ CHECK(base.dtype().is_scalar());
+ CHECK(stride.dtype().is_scalar());
CHECK_GT(lanes, 1);
- CHECK_EQ(stride.type(), base.type());
+ CHECK_EQ(stride.dtype(), base.dtype());
NodePtr<Ramp> node = make_node<Ramp>();
- node->type = base.type().with_lanes(lanes);
+ node->dtype = base.dtype().with_lanes(lanes);
node->base = base;
node->stride = stride;
node->lanes = lanes;
@@ -153,11 +153,11 @@ Expr Ramp::make(Expr base, Expr stride, int lanes) {
Expr Broadcast::make(Expr value, int lanes) {
CHECK(value.defined());
- CHECK(value.type().is_scalar());
+ CHECK(value.dtype().is_scalar());
CHECK_GT(lanes, 1);
NodePtr<Broadcast> node = make_node<Broadcast>();
- node->type = value.type().with_lanes(lanes);
+ node->dtype = value.dtype().with_lanes(lanes);
node->value = std::move(value);
node->lanes = lanes;
return Expr(node);
@@ -166,10 +166,10 @@ Expr Broadcast::make(Expr value, int lanes) {
Expr Let::make(Var var, Expr value, Expr body) {
CHECK(value.defined());
CHECK(body.defined());
- CHECK_EQ(value.type(), var.type());
+ CHECK_EQ(value.dtype(), var.dtype());
NodePtr<Let> node = make_node<Let>();
- node->type = body.type();
+ node->dtype = body.dtype();
node->var = std::move(var);
node->value = std::move(value);
node->body = std::move(body);
@@ -192,7 +192,7 @@ bool Call::is_vectorizable() const {
return false;
}
-Expr Call::make(DataType type,
+Expr Call::make(DataType dtype,
std::string name,
Array<Expr> args,
CallType call_type,
@@ -204,12 +204,12 @@ Expr Call::make(DataType type,
if (call_type == Halide) {
for (size_t i = 0; i < args.size(); ++i) {
- CHECK(args[i].type().is_int());
+ CHECK(args[i].dtype().is_int());
}
}
NodePtr<Call> node = make_node<Call>();
- node->type = type;
+ node->dtype = dtype;
node->name = std::move(name);
node->args = std::move(args);
node->call_type = call_type;
@@ -223,17 +223,17 @@ Expr Shuffle::make(Array<Expr> vectors,
CHECK_NE(vectors.size(), 0U);
CHECK_NE(indices.size(), 0U);
- Type base_type = vectors[0].type().element_of();
+ DataType base_type = vectors[0].dtype().element_of();
int total_lanes = 0;
for (Expr val : vectors) {
- CHECK(val.type().element_of() == base_type);
- total_lanes += val.type().lanes();
+ CHECK(val.dtype().element_of() == base_type);
+ total_lanes += val.dtype().lanes();
}
CHECK_LE(indices.size(), static_cast<size_t>(total_lanes));
NodePtr<Shuffle> node = make_node<Shuffle>();
- node->type = base_type.with_lanes(static_cast<int>(indices.size()));
+ node->dtype = base_type.with_lanes(static_cast<int>(indices.size()));
node->vectors = std::move(vectors);
node->indices = std::move(indices);
return Expr(node);
@@ -247,8 +247,8 @@ Expr Shuffle::make_concat(Array<Expr> vectors) {
Array<Expr> indices;
int index = 0;
for (const Expr& e : vectors) {
- for (int i = 0; i < e.type().lanes(); ++i) {
- indices.push_back(IntImm::make(Int(32), index++));
+ for (int i = 0; i < e.dtype().lanes(); ++i) {
+ indices.push_back(IntImm::make(DataType::Int(32), index++));
}
}
return make(vectors, indices);
@@ -298,7 +298,7 @@ Expr Reduce::make(CommReducer combiner, Array<Expr> source,
for (size_t i = 0; i < axis.size(); ++i) {
CHECK(axis[i].defined());
}
- n->type = source[value_index].type();
+ n->dtype = source[value_index].dtype();
n->combiner = std::move(combiner);
n->source = std::move(source);
n->axis = std::move(axis);
@@ -315,7 +315,7 @@ Expr Any::make() {
Stmt LetStmt::make(Var var, Expr value, Stmt body) {
CHECK(value.defined());
CHECK(body.defined());
- CHECK_EQ(value.type(), var.type());
+ CHECK_EQ(value.dtype(), var.dtype());
NodePtr<LetStmt> node = make_node<LetStmt>();
node->var = std::move(var);
@@ -338,7 +338,7 @@ Stmt AttrStmt::make(NodeRef node,
Stmt AssertStmt::make(Expr condition, Expr message, Stmt body) {
CHECK(condition.defined());
- CHECK(message.type() == Int(32) ||
+ CHECK(message.dtype() == DataType::Int(32) ||
message.as<StringImm>())
<< "TypeError: AssertStmt message must be an int or string:"
<< message << "\n";
@@ -368,9 +368,9 @@ Stmt For::make(Var loop_var,
Stmt body) {
CHECK(min.defined());
CHECK(extent.defined());
- CHECK(min.type().is_scalar());
- CHECK(extent.type().is_scalar());
- CHECK(loop_var.type().is_scalar());
+ CHECK(min.dtype().is_scalar());
+ CHECK(extent.dtype().is_scalar());
+ CHECK(loop_var.dtype().is_scalar());
CHECK(body.defined());
NodePtr<For> node = make_node<For>();
@@ -387,8 +387,8 @@ Stmt Store::make(Var buffer_var, Expr value, Expr index, Expr predicate) {
CHECK(value.defined());
CHECK(index.defined());
CHECK(predicate.defined());
- CHECK_EQ(value.type().lanes(), index.type().lanes());
- CHECK_EQ(value.type().lanes(), predicate.type().lanes());
+ CHECK_EQ(value.dtype().lanes(), index.dtype().lanes());
+ CHECK_EQ(value.dtype().lanes(), predicate.dtype().lanes());
NodePtr<Store> node = make_node<Store>();
node->buffer_var = std::move(buffer_var);
@@ -416,7 +416,7 @@ Stmt Provide::make(FunctionRef func, int value_index, Expr value, Array<Expr> ar
}
Stmt Allocate::make(Var buffer_var,
- DataType type,
+ DataType dtype,
Array<Expr> extents,
Expr condition,
Stmt body,
@@ -424,15 +424,15 @@ Stmt Allocate::make(Var buffer_var,
std::string free_function) {
for (size_t i = 0; i < extents.size(); ++i) {
CHECK(extents[i].defined());
- CHECK(extents[i].type().is_scalar());
+ CHECK(extents[i].dtype().is_scalar());
}
CHECK(body.defined());
CHECK(condition.defined());
- CHECK(condition.type().is_bool());
+ CHECK(condition.dtype().is_bool());
NodePtr<Allocate> node = make_node<Allocate>();
node->buffer_var = std::move(buffer_var);
- node->type = type;
+ node->dtype = dtype;
node->extents = std::move(extents);
node->condition = std::move(condition);
node->body = std::move(body);
@@ -464,42 +464,42 @@ Stmt Free::make(Var buffer_var) {
Stmt Realize::make(FunctionRef func,
int value_index,
- DataType type,
+ DataType dtype,
Region bounds,
Expr condition,
Stmt body) {
for (size_t i = 0; i < bounds.size(); ++i) {
CHECK(bounds[i]->min.defined());
CHECK(bounds[i]->extent.defined());
- CHECK(bounds[i]->min.type().is_scalar());
- CHECK(bounds[i]->extent.type().is_scalar());
+ CHECK(bounds[i]->min.dtype().is_scalar());
+ CHECK(bounds[i]->extent.dtype().is_scalar());
}
CHECK(body.defined());
CHECK(condition.defined());
- CHECK(condition.type().is_bool());
+ CHECK(condition.dtype().is_bool());
NodePtr<Realize> node = make_node<Realize>();
node->func = std::move(func);
node->value_index = value_index;
- node->type = type;
+ node->dtype = dtype;
node->bounds = std::move(bounds);
node->condition = std::move(condition);
node->body = std::move(body);
return Stmt(node);
}
-Stmt Prefetch::make(FunctionRef func, int value_index, DataType type, Region bounds) {
+Stmt Prefetch::make(FunctionRef func, int value_index, DataType dtype, Region bounds) {
for (size_t i = 0; i < bounds.size(); ++i) {
CHECK(bounds[i]->min.defined());
CHECK(bounds[i]->extent.defined());
- CHECK(bounds[i]->min.type().is_scalar());
- CHECK(bounds[i]->extent.type().is_scalar());
+ CHECK(bounds[i]->min.dtype().is_scalar());
+ CHECK(bounds[i]->extent.dtype().is_scalar());
}
NodePtr<Prefetch> node = make_node<Prefetch>();
node->func = std::move(func);
node->value_index = value_index;
- node->type = type;
+ node->dtype = dtype;
node->bounds = std::move(bounds);
return Stmt(node);
}
@@ -555,14 +555,14 @@ Stmt Evaluate::make(Expr value) {
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<UIntImm>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const UIntImm*>(node.get());
- p->stream << "(" << op->type << ")" << op->value;
+ p->stream << "(" << op->dtype << ")" << op->value;
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<FloatImm>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const FloatImm*>(node.get());
auto& stream = p->stream;
- switch (op->type.bits()) {
+ switch (op->dtype.bits()) {
case 64:
stream << op->value;
break;
@@ -573,7 +573,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
stream << op->value << 'h';
break;
default:
- LOG(FATAL) << "Unknown float type bits=" << op->type.bits();
+ LOG(FATAL) << "Unknown float type bits=" << op->dtype.bits();
}
});
@@ -616,7 +616,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Cast>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Cast*>(node.get());
- p->stream << op->type << '(';
+ p->stream << op->dtype << '(';
p->Print(op->value);
p->stream << ')';
})
@@ -959,7 +959,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Allocate>([](const ObjectRef& node, IRPrinter* p) {
auto* op = static_cast<const Allocate*>(node.get());
p->PrintIndent();
- p->stream << "allocate " << op->buffer_var << "[" << op->type;
+ p->stream << "allocate " << op->buffer_var << "[" << op->dtype;
for (size_t i = 0; i < op->extents.size(); ++i) {
p->stream << " * ";
p->Print(op->extents[i]);
diff --git a/src/lang/tensor.cc b/src/lang/tensor.cc
index 05ba6f7..1c11093 100644
--- a/src/lang/tensor.cc
+++ b/src/lang/tensor.cc
@@ -56,7 +56,7 @@ Tensor Operation::output(size_t i) const {
}
Tensor TensorNode::make(Array<Expr> shape,
- Type dtype,
+ DataType dtype,
Operation op,
int value_index) {
auto n = make_node<TensorNode>();
diff --git a/src/node/reflection.cc b/src/node/reflection.cc
index e92ca92..f535837 100644
--- a/src/node/reflection.cc
+++ b/src/node/reflection.cc
@@ -61,7 +61,7 @@ class AttrGetter : public AttrVisitor {
void Visit(const char* key, void** value) final {
if (skey == key) *ret = static_cast<void*>(value[0]);
}
- void Visit(const char* key, Type* value) final {
+ void Visit(const char* key, DataType* value) final {
if (skey == key) *ret = value[0];
}
void Visit(const char* key, std::string* value) final {
@@ -135,7 +135,7 @@ class AttrDir : public AttrVisitor {
void Visit(const char* key, void** value) final {
names->push_back(key);
}
- void Visit(const char* key, Type* value) final {
+ void Visit(const char* key, DataType* value) final {
names->push_back(key);
}
void Visit(const char* key, std::string* value) final {
diff --git a/src/node/serialization.cc b/src/node/serialization.cc
index cb310eb..5a991aa 100644
--- a/src/node/serialization.cc
+++ b/src/node/serialization.cc
@@ -39,11 +39,11 @@
namespace tvm {
inline std::string Type2String(const DataType& t) {
- return runtime::TVMType2String(Type2TVMType(t));
+ return runtime::TVMType2String(t);
}
-inline Type String2Type(std::string s) {
- return TVMType2Type(runtime::String2TVMType(s));
+inline DataType String2Type(std::string s) {
+ return DataType(runtime::String2TVMType(s));
}
// indexer to index all the nodes
diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc
index 5f5d2d4..bd129ac 100644
--- a/src/op/compute_op.cc
+++ b/src/op/compute_op.cc
@@ -70,9 +70,9 @@ Array<IterVar> BaseComputeOpNode::root_iter_vars() const {
return ret;
}
-Type ComputeOpNode::output_dtype(size_t idx) const {
+DataType ComputeOpNode::output_dtype(size_t idx) const {
CHECK_LT(idx, num_outputs());
- return body[idx].type();
+ return body[idx].dtype();
}
Array<Expr> BaseComputeOpNode::output_shape(size_t idx) const {
@@ -100,7 +100,7 @@ Tensor compute(Array<Expr> shape,
std::ostringstream os;
os << "ax" << i;
axis.emplace_back(IterVarNode::make(
- Range(0, shape[i]), Var(os.str(), shape[i].type()), kDataPar));
+ Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar));
args.push_back(axis.back()->var);
}
@@ -122,7 +122,7 @@ Array<Tensor> compute(Array<Expr> shape,
std::ostringstream os;
os << "ax" << i;
axis.emplace_back(IterVarNode::make(
- Range(0, shape[i]), Var(os.str(), shape[i].type()), kDataPar));
+ Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar));
args.push_back(axis.back()->var);
}
@@ -190,7 +190,7 @@ Operation ComputeOpNode::ReplaceInputs(
for (size_t k = 0; k < this->body.size(); ++k) {
auto n = make_node<ir::Reduce>(*r);
n->value_index = static_cast<int>(k);
- n->type = r->source[k].type();
+ n->dtype = r->source[k].dtype();
arr.push_back(Expr(n));
}
} else {
@@ -229,7 +229,7 @@ void ComputeOpNode::PropBoundToInputs(
IntSet arg_intset = EvalSet(call->args[i], dom_map);
const arith::IntervalSetNode* arg_interval = arg_intset.as<arith::IntervalSetNode>();
if (arg_interval) {
- Expr shape_i_min_value = make_zero(t->shape[i].type());
+ Expr shape_i_min_value = make_zero(t->shape[i].dtype());
Expr shape_i_max_value = t->shape[i] - 1;
Expr min_value = arg_interval->min_value;
Expr max_value = arg_interval->max_value;
@@ -295,7 +295,7 @@ Stmt BaseComputeOpNode::BuildRealize(
attr->dim_align_offset};
realize = ir::AttrStmt::make(
t, ir::attr::buffer_dim_align,
- Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic),
+ Call::make(DataType::Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic),
realize);
}
}
diff --git a/src/op/cross_thread_reduction.cc b/src/op/cross_thread_reduction.cc
index 818acb9..4a3aa54 100644
--- a/src/op/cross_thread_reduction.cc
+++ b/src/op/cross_thread_reduction.cc
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -57,14 +57,14 @@ Stmt MakeCrossThreadReduction(
cond = cond && v;
}
Array<Expr> freduce_args;
- freduce_args.push_back(make_const(UInt(32), static_cast<uint32_t>(size)));
+ freduce_args.push_back(make_const(DataType::UInt(32), static_cast<uint32_t>(size)));
for (size_t i = 0; i < size; ++i) {
freduce_args.push_back(reduces[0]->source[i]);
}
freduce_args.push_back(cond);
std::vector<Var> res_handles(size);
for (size_t idx = 0; idx < size; ++idx) {
- res_handles[idx] = Var("reduce_temp" + std::to_string(idx), Handle());
+ res_handles[idx] = Var("reduce_temp" + std::to_string(idx), DataType::Handle());
freduce_args.push_back(res_handles[idx]);
}
@@ -85,17 +85,17 @@ Stmt MakeCrossThreadReduction(
}
Stmt reduce_body = Evaluate::make(Call::make(
- Handle(),
+ DataType::Handle(),
ir::intrinsic::tvm_thread_allreduce,
freduce_args, Call::Intrinsic));
reduce_body = AttrStmt::make(
reduces[0]->combiner,
attr::reduce_scope,
- make_zero(Handle()),
+ make_zero(DataType::Handle()),
reduce_body);
std::vector<Stmt> assigns(size);
for (size_t idx = 0; idx < size; ++idx) {
- Type t = reduces[idx]->type;
+ DataType t = reduces[idx]->dtype;
assigns[idx] = Provide::make(
stage->op, idx,
Load::make(t, res_handles[idx], 0, const_true(t.lanes())), args);
@@ -106,7 +106,7 @@ Stmt MakeCrossThreadReduction(
Stmt body = Block::make(reduce_body, assign_body);
for (size_t idx = size; idx != 0; --idx) {
body = Allocate::make(
- res_handles[idx - 1], reduces[idx - 1]->type, {1}, const_true(), body);
+ res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
body = AttrStmt::make(
res_handles[idx - 1], attr::storage_scope, StringImm::make("local"), body);
}
diff --git a/src/op/extern_op.cc b/src/op/extern_op.cc
index 35fe469..883ebdc 100644
--- a/src/op/extern_op.cc
+++ b/src/op/extern_op.cc
@@ -46,7 +46,7 @@ Array<IterVar> ExternOpNode::root_iter_vars() const {
return {};
}
-Type ExternOpNode::output_dtype(size_t i) const {
+DataType ExternOpNode::output_dtype(size_t i) const {
return output_placeholders[i]->dtype;
}
@@ -122,7 +122,7 @@ void ExternOpNode::PropBoundToInputs(
for (size_t i = 0; i < t->shape.size(); ++i) {
dom.data[i].emplace_back(IntSet::range(
Range::make_by_min_extent(
- make_const(t->shape[i].type(), 0), t->shape[i])));
+ make_const(t->shape[i].dtype(), 0), t->shape[i])));
}
}
}
@@ -145,7 +145,7 @@ Stmt ExternOpNode::BuildRealize(
for (size_t i = 0; i < t->shape.size(); ++i) {
bounds.push_back(
Range::make_by_min_extent(
- make_const(t->shape[i].type(), 0), t->shape[i]));
+ make_const(t->shape[i].dtype(), 0), t->shape[i]));
}
realize_body = ir::Realize::make(
t->op, t->value_index, t->dtype,
@@ -159,19 +159,19 @@ Stmt ExternOpNode::BuildProvide(
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
- Stmt ret = AttrStmt::make(make_zero(Int(32)), attr::extern_scope, 0, this->body);
+ Stmt ret = AttrStmt::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body);
auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
Array<NodeRef> bind_spec;
Array<Expr> tuple;
bind_spec.push_back(buffer);
bind_spec.push_back(tensor);
for (size_t k = 0; k < buffer->shape.size(); ++k) {
- tuple.push_back(make_const(buffer->shape[k].type(), 0));
+ tuple.push_back(make_const(buffer->shape[k].dtype(), 0));
tuple.push_back(buffer->shape[k]);
}
ret = AttrStmt::make(
bind_spec, attr::buffer_bind_scope,
- Call::make(Handle(), intrinsic::tvm_tuple, tuple, Call::Intrinsic), ret);
+ Call::make(DataType::Handle(), intrinsic::tvm_tuple, tuple, Call::Intrinsic), ret);
};
for (size_t i = output_placeholders.size(); i != 0; --i) {
f_push_bind(output_placeholders[i - 1], stage->op.output(i - 1));
diff --git a/src/op/hybrid_op.cc b/src/op/hybrid_op.cc
index 7a99ea1..1e1a814 100644
--- a/src/op/hybrid_op.cc
+++ b/src/op/hybrid_op.cc
@@ -52,7 +52,7 @@ Array<IterVar> HybridOpNode::root_iter_vars() const {
return this->axis;
}
-Type HybridOpNode::output_dtype(size_t i) const {
+DataType HybridOpNode::output_dtype(size_t i) const {
return outputs[i]->dtype;
}
@@ -138,7 +138,7 @@ void HybridOpNode::PropBoundToInputs(
for (size_t i = 0; i < t->shape.size(); ++i) {
dom.data[i].emplace_back(IntSet::range(
Range::make_by_min_extent(
- make_const(t->shape[i].type(), 0), t->shape[i])));
+ make_const(t->shape[i].dtype(), 0), t->shape[i])));
}
}
}
@@ -166,7 +166,7 @@ Stmt HybridOpNode::BuildRealize(
for (size_t i = 0; i < t->shape.size(); ++i) {
bounds.push_back(
Range::make_by_min_extent(
- make_const(t->shape[i].type(), 0), t->shape[i]));
+ make_const(t->shape[i].dtype(), 0), t->shape[i]));
}
realize_body = ir::Realize::make(
t->op, t->value_index, t->dtype,
@@ -180,7 +180,7 @@ Stmt HybridOpNode::BuildProvide(
const std::unordered_map<IterVar, Range> &dom_map,
bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
- Stmt ret = AttrStmt::make(make_zero(Int(32)), attr::extern_scope, 0, this->body);
+ Stmt ret = AttrStmt::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body);
std::unordered_map<Tensor, Tensor> rmap;
for (int i = 0; i < this->num_outputs(); ++i) {
rmap[outputs[i]] = stage->op.output(i);
diff --git a/src/op/op_util.cc b/src/op/op_util.cc
index 6916031..cd3b168 100644
--- a/src/op/op_util.cc
+++ b/src/op/op_util.cc
@@ -74,7 +74,7 @@ MakeLoopNest(const Stage& stage,
if (bind_iv->thread_tag.length() == 0) {
// Only generate new loop if we're not bound to a thread.
if (new_loop_var) {
- var = Var(iv->var->name_hint + ".init", bind_iv->var.type());
+ var = Var(iv->var->name_hint + ".init", bind_iv->var.dtype());
}
ForType for_type = ForType::Serial;
@@ -98,7 +98,7 @@ MakeLoopNest(const Stage& stage,
const std::string& pkey = it_attr->pragma_keys[k].as<StringImm>()->value;
Expr pvalue = it_attr->pragma_values[k];
if (!pvalue.defined()) {
- pvalue = make_const(Int(32), 1);
+ pvalue = make_const(DataType::Int(32), 1);
}
nest[i + 1].emplace_back(
AttrStmt::make(iv, ir::attr::pragma_scope_prefix + pkey, pvalue, no_op));
@@ -114,7 +114,7 @@ MakeLoopNest(const Stage& stage,
for_type, DeviceAPI::None, no_op));
value_map[iv] = var;
} else {
- Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.type());
+ Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.dtype());
nest[i + 1].emplace_back(
For::make(idx, 0, dom->extent,
for_type, DeviceAPI::None, no_op));
@@ -197,7 +197,7 @@ class TensorReplacer : public ir::IRMutator {
auto it = vmap_.find(t);
if (it != vmap_.end()) {
Expr ret = ir::Call::make(
- op->type, it->second->op->name, op->args,
+ op->dtype, it->second->op->name, op->args,
op->call_type, it->second->op, it->second->value_index);
found = true;
return IRMutator::Mutate_(ret.as<ir::Call>(), ret);
diff --git a/src/op/placeholder_op.cc b/src/op/placeholder_op.cc
index 91b0589..6910f63 100644
--- a/src/op/placeholder_op.cc
+++ b/src/op/placeholder_op.cc
@@ -42,7 +42,7 @@ Array<IterVar> PlaceholderOpNode::root_iter_vars() const {
return {};
}
-Type PlaceholderOpNode::output_dtype(size_t i) const {
+DataType PlaceholderOpNode::output_dtype(size_t i) const {
CHECK_EQ(i, 0U);
return dtype;
}
@@ -54,7 +54,7 @@ Array<Expr> PlaceholderOpNode::output_shape(size_t i) const {
Operation PlaceholderOpNode::make(std::string name,
Array<Expr> shape,
- Type dtype) {
+ DataType dtype) {
auto n = make_node<PlaceholderOpNode>();
n->name = name;
n->shape = shape;
@@ -62,7 +62,7 @@ Operation PlaceholderOpNode::make(std::string name,
return Operation(n);
}
-Tensor placeholder(Array<Expr> shape, Type dtype, std::string name) {
+Tensor placeholder(Array<Expr> shape, DataType dtype, std::string name) {
return PlaceholderOpNode::make(name, shape, dtype).output(0);
}
diff --git a/src/op/scan_op.cc b/src/op/scan_op.cc
index b02073b..e83a231 100644
--- a/src/op/scan_op.cc
+++ b/src/op/scan_op.cc
@@ -53,7 +53,7 @@ Array<IterVar> ScanOpNode::root_iter_vars() const {
return ret;
}
-Type ScanOpNode::output_dtype(size_t i) const {
+DataType ScanOpNode::output_dtype(size_t i) const {
return update[i]->dtype;
}
diff --git a/src/op/tensor_compute_op.cc b/src/op/tensor_compute_op.cc
index 83cdd76..e59f90f 100644
--- a/src/op/tensor_compute_op.cc
+++ b/src/op/tensor_compute_op.cc
@@ -46,7 +46,7 @@ int TensorComputeOpNode::num_outputs() const {
return static_cast<int>(this->intrin->buffers.size() - this->inputs.size());
}
-Type TensorComputeOpNode::output_dtype(size_t i) const {
+DataType TensorComputeOpNode::output_dtype(size_t i) const {
return this->intrin->buffers[this->inputs.size() + i]->dtype;
}
@@ -155,7 +155,7 @@ Stmt TensorComputeOpNode::BuildProvide(
}
input_bind_nest.emplace_back(AttrStmt::make(
bind_spec, ir::attr::buffer_bind_scope,
- Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
+ Call::make(DataType::Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
}
// output binding
@@ -179,7 +179,7 @@ Stmt TensorComputeOpNode::BuildProvide(
output_bind_nest.emplace_back(AttrStmt::make(
bind_spec, ir::attr::buffer_bind_scope,
- Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
+ Call::make(DataType::Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
}
// Check variable remap
diff --git a/src/op/tensorize.cc b/src/op/tensorize.cc
index c4abf0b..b7f32de 100644
--- a/src/op/tensorize.cc
+++ b/src/op/tensorize.cc
@@ -173,7 +173,7 @@ class TensorIntrinMatcher final : public IRMutator {
args.push_back(op->args[i] - e.region[i]->min);
}
return Call::make(
- op->type, e.tensor->op->name, args,
+ op->dtype, e.tensor->op->name, args,
op->call_type, e.tensor->op, e.tensor->value_index);
}
}
@@ -341,12 +341,12 @@ void VerifyTensorizeBody(
lhs = CanonicalSimplify(lhs, compute_intrin_iter_space);
Expr rhs = Simplify(intrin_compute->body[i], compute_intrin_iter_space);
rhs = CanonicalSimplify(rhs, compute_intrin_iter_space);
- if (lhs.type() != rhs.type()) {
+ if (lhs.dtype() != rhs.dtype()) {
LOG(FATAL)
<< "Failed to match the data type with TensorIntrin "
<< intrin->name << "'s declaration "
- << " provided=" << lhs.type()
- << ", intrin=" << rhs.type();
+ << " provided=" << lhs.dtype()
+ << ", intrin=" << rhs.dtype();
}
CHECK(Equal(lhs, rhs))
<< "Failed to match the compute with TensorIntrin "
@@ -390,7 +390,7 @@ Stmt MakeTensorize(const ComputeOpNode* self,
}
input_bind_nest.emplace_back(AttrStmt::make(
bind_spec, ir::attr::buffer_bind_scope,
- Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
+ Call::make(DataType::Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
}
// output binding
const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
@@ -410,7 +410,7 @@ Stmt MakeTensorize(const ComputeOpNode* self,
Array<NodeRef> bind_spec{buffer, tensor};
output_bind_nest.emplace_back(AttrStmt::make(
bind_spec, ir::attr::buffer_bind_scope,
- Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
+ Call::make(DataType::Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
}
// Check variable remap
std::unordered_map<const Variable*, Expr> vmap;
@@ -430,7 +430,7 @@ Stmt MakeTensorize(const ComputeOpNode* self,
IterVar target = intrin_compute->reduce_axis[i - start];
auto it = out_dom.find(iv);
CHECK(it != out_dom.end());
- binder.Bind(target->dom->min, make_const(iv->dom->min.type(), 0),
+ binder.Bind(target->dom->min, make_const(iv->dom->min.dtype(), 0),
"tensir_intrin.reduction.min");
binder.Bind(target->dom->extent, it->second->extent,
"tensir_intrin.reduction.extent");
diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc
index f892b6b..e4ff9cb 100644
--- a/src/pass/arg_binder.cc
+++ b/src/pass/arg_binder.cc
@@ -50,7 +50,7 @@ bool ArgBinder::Bind_(const Expr& arg,
const Expr& value,
const std::string& arg_name,
bool with_lets) {
- CHECK_EQ(arg.type(), value.type());
+ CHECK_EQ(arg.dtype(), value.dtype());
if (const Variable* v = arg.as<Variable>()) {
auto it = def_map_->find(v);
if (it == def_map_->end()) {
@@ -118,8 +118,8 @@ void ArgBinder::BindBuffer(const Buffer& arg,
if (Bind_(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset", false)) {
if (arg->offset_factor > 1) {
Expr offset = value->elem_offset;
- Expr factor = make_const(offset.type(), arg->offset_factor);
- Expr zero = make_zero(offset.type());
+ Expr factor = make_const(offset.dtype(), arg->offset_factor);
+ Expr zero = make_zero(offset.dtype());
BinderAddAssert(truncmod(offset, factor) == zero,
arg_name + ".elem_offset", &asserts_);
}
@@ -153,7 +153,7 @@ void ArgBinder::BindBuffer(const Buffer& arg,
}
}
-inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMStructFieldKind kind) {
+inline Expr TVMArrayGet(DataType t, Var arr, intrinsic::TVMStructFieldKind kind) {
return TVMStructGet(t, arr, 0, kind);
}
@@ -162,8 +162,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
const Expr& device_id,
const Var& handle,
const std::string& arg_name) {
- const Type tvm_shape_type = TVMShapeIndexType();
- const Type tvm_ndim_type = Int(32);
+ const DataType tvm_shape_type = DataType::ShapeIndex();
+ const DataType tvm_ndim_type = DataType::Int(32);
const Stmt nop = Evaluate::make(0);
// dimension checks
Expr v_ndim = TVMArrayGet(tvm_ndim_type, handle, intrinsic::kArrNDim);
@@ -175,52 +175,52 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
<< buffer->shape.size();
asserts_.emplace_back(AssertStmt::make(a_ndim == v_ndim, ndim_err_msg.str(), nop));
// type checks
- Type dtype = buffer->dtype;
+ DataType dtype = buffer->dtype;
std::ostringstream type_err_msg;
type_err_msg << arg_name << ".dtype is expected to be " << dtype;
- Expr cond = (TVMArrayGet(UInt(8), handle, intrinsic::kArrTypeCode) ==
- UIntImm::make(UInt(8), dtype.code()) &&
- TVMArrayGet(UInt(8), handle, intrinsic::kArrTypeBits) ==
- UIntImm::make(UInt(8), dtype.bits()) &&
- TVMArrayGet(UInt(16), handle, intrinsic::kArrTypeLanes) ==
- UIntImm::make(UInt(16), dtype.lanes()));
+ Expr cond = (TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeCode) ==
+ UIntImm::make(DataType::UInt(8), dtype.code()) &&
+ TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeBits) ==
+ UIntImm::make(DataType::UInt(8), dtype.bits()) &&
+ TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) ==
+ UIntImm::make(DataType::UInt(16), dtype.lanes()));
asserts_.emplace_back(AssertStmt::make(cond, type_err_msg.str(), nop));
// data field
- if (Bind_(buffer->data, TVMArrayGet(Handle(), handle, intrinsic::kArrData),
+ if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData),
arg_name + ".data", true)) {
Var vptr(buffer->data);
def_handle_dtype_.Set(vptr, ir::TypeAnnotation(buffer->dtype));
// mark alignment of external bufs
init_nest_.emplace_back(AttrStmt::make(
vptr, ir::attr::storage_alignment,
- IntImm::make(Int(32), buffer->data_alignment), nop));
+ IntImm::make(DataType::Int(32), buffer->data_alignment), nop));
}
- Var v_shape(arg_name + ".shape", Handle());
+ Var v_shape(arg_name + ".shape", DataType::Handle());
def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0));
init_nest_.emplace_back(LetStmt::make(
- v_shape, TVMArrayGet(Handle(), handle, intrinsic::kArrShape), nop));
+ v_shape, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrShape), nop));
for (size_t k = 0; k < buffer->shape.size(); ++k) {
std::ostringstream field_name;
field_name << v_shape->name_hint << '[' << k << ']';
Bind_(buffer->shape[k],
- cast(buffer->shape[k].type(),
+ cast(buffer->shape[k].dtype(),
Load::make(tvm_shape_type, v_shape,
- IntImm::make(Int(32), k), const_true(1))),
+ IntImm::make(DataType::Int(32), k), const_true(1))),
field_name.str(), true);
}
// strides field
- Var v_strides(arg_name + ".strides", Handle());
+ Var v_strides(arg_name + ".strides", DataType::Handle());
def_handle_dtype_.Set(v_strides, ir::TypeAnnotation(tvm_shape_type));
init_nest_.emplace_back(LetStmt::make(
- v_strides, TVMArrayGet(Handle(), handle, intrinsic::kArrStrides),
+ v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides),
nop));
Expr is_null = Call::make(
- Bool(1), intrinsic::tvm_handle_is_null,
+ DataType::Bool(1), intrinsic::tvm_handle_is_null,
{v_strides}, Call::PureIntrinsic);
if (buffer->strides.size() == 0) {
// Assert the buffer is compact
- Type stype = buffer->DefaultIndexType();
+ DataType stype = buffer->DefaultIndexType();
Expr expect_stride = make_const(stype, 1);
Array<Expr> conds;
for (size_t i = buffer->shape.size(); i != 0; --i) {
@@ -228,7 +228,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
Expr svalue = cast(
stype,
Load::make(tvm_shape_type, v_strides,
- IntImm::make(Int(32), k), const_true(1)));
+ IntImm::make(DataType::Int(32), k), const_true(1)));
conds.push_back(expect_stride == svalue);
expect_stride = expect_stride * buffer->shape[k];
}
@@ -243,15 +243,15 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
asserts_.emplace_back(Block::make(check, Evaluate::make(0)));
}
} else if (buffer->buffer_type == kAutoBroadcast) {
- Type stype = buffer->DefaultIndexType();
+ DataType stype = buffer->DefaultIndexType();
Expr stride = make_const(stype, 1);
for (size_t i = buffer->shape.size(); i != 0; --i) {
size_t k = i - 1;
std::ostringstream field_name;
field_name << v_strides->name_hint << '[' << k << ']';
- Expr value = cast(buffer->shape[k].type(),
+ Expr value = cast(buffer->shape[k].dtype(),
Load::make(tvm_shape_type, v_strides,
- IntImm::make(Int(32), k), const_true(1)));
+ IntImm::make(DataType::Int(32), k), const_true(1)));
value = tvm::if_then_else(is_null, stride, value);
value = tvm::if_then_else(buffer->shape[k] == 1, 0, value);
Bind_(buffer->strides[k], value, field_name.str(), true);
@@ -266,9 +266,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
std::ostringstream field_name;
field_name << v_strides->name_hint << '[' << k << ']';
Bind_(buffer->strides[k],
- cast(buffer->shape[k].type(),
+ cast(buffer->shape[k].dtype(),
Load::make(tvm_shape_type, v_strides,
- IntImm::make(Int(32), k), const_true(1))),
+ IntImm::make(DataType::Int(32), k), const_true(1))),
field_name.str(), true);
}
}
@@ -276,29 +276,29 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
int data_bytes = GetVectorBytes(buffer->dtype);
int64_t const_offset;
if (arith::GetConst(buffer->elem_offset, &const_offset)) {
- Bind_(make_const(UInt(64), const_offset * data_bytes),
- TVMArrayGet(UInt(64), handle, intrinsic::kArrByteOffset),
+ Bind_(make_const(DataType::UInt(64), const_offset * data_bytes),
+ TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset),
arg_name + ".byte_offset", true);
} else {
if (Bind_(buffer->elem_offset,
- cast(buffer->elem_offset.type(),
- (TVMArrayGet(UInt(64), handle, intrinsic::kArrByteOffset) /
- make_const(UInt(64), data_bytes))),
+ cast(buffer->elem_offset.dtype(),
+ (TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset) /
+ make_const(DataType::UInt(64), data_bytes))),
arg_name + ".elem_offset", true)) {
if (buffer->offset_factor > 1) {
Expr offset = buffer->elem_offset;
- Expr factor = make_const(offset.type(), buffer->offset_factor);
- Expr zero = make_zero(offset.type());
+ Expr factor = make_const(offset.dtype(), buffer->offset_factor);
+ Expr zero = make_zero(offset.dtype());
BinderAddAssert(truncmod(offset, factor) == zero, arg_name + ".elem_offset", &asserts_);
}
}
}
// device info.
Bind_(device_type,
- TVMArrayGet(Int(32), handle, intrinsic::kArrDeviceType),
+ TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceType),
arg_name + ".device_type", true);
Bind_(device_id,
- TVMArrayGet(Int(32), handle, intrinsic::kArrDeviceId),
+ TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceId),
arg_name + ".device_id", true);
}
diff --git a/src/pass/bound_checker.cc b/src/pass/bound_checker.cc
index 55f9847..648302e 100644
--- a/src/pass/bound_checker.cc
+++ b/src/pass/bound_checker.cc
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -58,7 +58,7 @@ class BoundChecker : public IRMutator {
Stmt Mutate_(const Allocate *op, const Stmt &s) final {
// If the shape was updated we should update the hashtable.
if (UpdateIsNeeded(op->buffer_var)) {
- Update(op->buffer_var, op->extents, op->type);
+ Update(op->buffer_var, op->extents, op->dtype);
}
return IRMutator::Mutate_(op, s);
}
@@ -108,26 +108,26 @@ class BoundChecker : public IRMutator {
}
void Update(const VarExpr &buffer_var, const Array<Expr> &new_shape,
- const Type &type) {
+ const DataType &type) {
// Sanity check at first.
if (!new_shape.size()) {
return;
}
for (size_t i = 0; i < new_shape.size(); ++i) {
- if (!new_shape[0].defined() || !new_shape[i].type().is_scalar() ||
+ if (!new_shape[0].defined() || !new_shape[i].dtype().is_scalar() ||
is_negative_const(new_shape[i])) {
return;
}
}
// Scalarize the shape.
- Expr shape = Mul::make(make_const(UInt(64), type.lanes()),
- Cast::make(UInt(64), new_shape[0]));
+ Expr shape = Mul::make(make_const(DataType::UInt(64), type.lanes()),
+ Cast::make(DataType::UInt(64), new_shape[0]));
for (size_t i = 1; i < new_shape.size(); ++i) {
// Cast to unsigned to avoid integer overlow at frist.
- shape = Mul::make(shape, Mul::make(make_const(UInt(64), type.lanes()),
- Cast::make(UInt(64), new_shape[i])));
+ shape = Mul::make(shape, Mul::make(make_const(DataType::UInt(64), type.lanes()),
+ Cast::make(DataType::UInt(64), new_shape[i])));
}
mem_to_shape_[buffer_var.get()] = shape;
}
@@ -139,9 +139,9 @@ class BoundChecker : public IRMutator {
if (const Ramp *ramp_index = index.as<Ramp>()) {
return ramp_index->base.defined() &&
- ramp_index->base.type().is_scalar() &&
+ ramp_index->base.dtype().is_scalar() &&
ramp_index->stride.defined() &&
- ramp_index->stride.type().is_scalar() && (ramp_index->lanes > 0);
+ ramp_index->stride.dtype().is_scalar() && (ramp_index->lanes > 0);
}
return true;
}
@@ -168,7 +168,7 @@ class BoundChecker : public IRMutator {
// Non inclusive range.
index = Add::make(
ramp_index->base,
- Mul::make(ramp_index->stride, make_const(ramp_index->stride.type(),
+ Mul::make(ramp_index->stride, make_const(ramp_index->stride.dtype(),
ramp_index->lanes - 1)));
}
@@ -177,11 +177,11 @@ class BoundChecker : public IRMutator {
upper_bound = ir::Simplify(upper_bound);
// Cast to the same type - signed, to be able to check lower bound.
- index = Cast::make(Int(64), index);
- upper_bound = Cast::make(Int(64), upper_bound);
+ index = Cast::make(DataType::Int(64), index);
+ upper_bound = Cast::make(DataType::Int(64), upper_bound);
// Looks like a lower bound should always be zero after normalization.
- Expr lower_bound = make_zero(Int(64));
+ Expr lower_bound = make_zero(DataType::Int(64));
Expr current_condition =
And::make(GE::make(index, lower_bound), LT::make(index, upper_bound));
diff --git a/src/pass/combine_context_call.cc b/src/pass/combine_context_call.cc
index d7fb779..f1cb8fe 100644
--- a/src/pass/combine_context_call.cc
+++ b/src/pass/combine_context_call.cc
@@ -48,14 +48,14 @@ class ContextCallCombiner final : public IRMutator {
if (it != ctx_map_.end()) {
return it->second;
} else {
- CHECK(ctx.type().is_handle());
+ CHECK(ctx.dtype().is_handle());
std::string name;
if (const Call* call = ctx.as<Call>()) {
name = call->name + "_cache";
} else {
name = "ctx_cache_";
}
- Var ctx_var(name, ctx.type());
+ Var ctx_var(name, ctx.dtype());
ctx_map_[ctx] = ctx_var;
return std::move(ctx_var);
}
diff --git a/src/pass/coproc_sync.cc b/src/pass/coproc_sync.cc
index 3dacb6d..4aa8879 100644
--- a/src/pass/coproc_sync.cc
+++ b/src/pass/coproc_sync.cc
@@ -198,7 +198,7 @@ class CoProcSyncPlanner : public StorageAccessVisitor {
std::vector<Stmt> GetSync(std::string sync_name) {
return {Evaluate::make(Call::make(
- Int(32),
+ DataType::Int(32),
sync_name,
{}, Call::Intrinsic))};
}
@@ -345,7 +345,7 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
Expr min = r->min;
Expr extent = r->extent;
return Evaluate::make(Call::make(
- Int(32), func,
+ DataType::Int(32), func,
{wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}, Call::Intrinsic));
}
// Write barrier name
@@ -588,14 +588,14 @@ class CoProcInstDepDetector : public IRVisitor {
Stmt MakePush(int from, int to) {
return Evaluate::make(Call::make(
- Int(32), sync_push_name_,
- {make_const(Int(32), from), make_const(Int(32), to)},
+ DataType::Int(32), sync_push_name_,
+ {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)},
Call::Intrinsic));
}
Stmt MakePop(int from, int to) {
return Evaluate::make(Call::make(
- Int(32), sync_pop_name_,
- {make_const(Int(32), from), make_const(Int(32), to)},
+ DataType::Int(32), sync_pop_name_,
+ {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)},
Call::Intrinsic));
}
// sync states.
diff --git a/src/pass/detect_device.cc b/src/pass/detect_device.cc
index 92e368b..cd7c979 100644
--- a/src/pass/detect_device.cc
+++ b/src/pass/detect_device.cc
@@ -28,7 +28,7 @@
namespace tvm {
namespace ir {
Stmt DecorateDeviceScope(Stmt stmt) {
- Stmt body = AttrStmt::make(make_zero(Int(32)),
+ Stmt body = AttrStmt::make(make_zero(DataType::Int(32)),
ir::attr::device_scope,
0,
stmt);
diff --git a/src/pass/inject_copy_intrin.cc b/src/pass/inject_copy_intrin.cc
index 3b14836..7b7c5df 100644
--- a/src/pass/inject_copy_intrin.cc
+++ b/src/pass/inject_copy_intrin.cc
@@ -88,7 +88,7 @@ class CopyIntrinInjector : public IRMutator {
load = cast->value.as<Load>();
}
if (load == nullptr) return false;
- if (load->type.lanes() != 1) return false;
+ if (load->dtype.lanes() != 1) return false;
Array<Var> loop_vars;
for (const For* op : loops) {
loop_vars.push_back(op->loop_var);
@@ -101,7 +101,7 @@ class CopyIntrinInjector : public IRMutator {
Array<Expr> dst_shape;
const size_t loop_var_size = loop_vars.size();
if (loop_var_size == 0) {
- dst_shape.push_back(make_const(Int(32), 1));
+ dst_shape.push_back(make_const(DataType::Int(32), 1));
} else {
for (const For* op : loops) {
dst_shape.push_back(op->extent);
@@ -121,7 +121,7 @@ class CopyIntrinInjector : public IRMutator {
for (size_t i = 0; i < src_shape.size(); ++i) {
Expr min_value = clip_bound[2 * i];
Expr max_value = clip_bound[2 * i + 1];
- Type t = loop_vars[i].type();
+ DataType t = loop_vars[i].dtype();
Expr svalue = src_shape[i];
if (min_value.defined()) {
Expr pbefore = Simplify(Max::make(min_value, make_zero(t)));
@@ -148,12 +148,12 @@ class CopyIntrinInjector : public IRMutator {
Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + loop_var_size);
Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size);
if (loop_var_size == 0) {
- src_strides.push_back(make_const(Int(32), 1));
- dst_strides.push_back(make_const(Int(32), 1));
+ src_strides.push_back(make_const(DataType::Int(32), 1));
+ dst_strides.push_back(make_const(DataType::Int(32), 1));
}
Buffer dst = BufferNode::make(
store->buffer_var,
- store->value.type(),
+ store->value.dtype(),
dst_shape,
dst_strides,
store_strides[loop_var_size],
@@ -162,7 +162,7 @@ class CopyIntrinInjector : public IRMutator {
0, 0, kDefault);
Buffer src = BufferNode::make(
load->buffer_var,
- load->type,
+ load->dtype,
src_shape,
src_strides,
src_elem_offset,
diff --git a/src/pass/inject_double_buffer.cc b/src/pass/inject_double_buffer.cc
index 065bbd4..78d3305 100644
--- a/src/pass/inject_double_buffer.cc
+++ b/src/pass/inject_double_buffer.cc
@@ -100,10 +100,10 @@ class DoubleBufferInjector : public IRMutator {
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
it->second.stride = arith::ComputeReduce<Mul>(
- op->extents, Expr()) * op->type.lanes();
+ op->extents, Expr()) * op->dtype.lanes();
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Allocate>();
- Array<Expr> new_extents{make_const(op->extents[0].type(), 2)};
+ Array<Expr> new_extents{make_const(op->extents[0].dtype(), 2)};
for (Expr e : op->extents) {
new_extents.push_back(e);
}
@@ -114,7 +114,7 @@ class DoubleBufferInjector : public IRMutator {
StringImm::make(it->second.scope),
Evaluate::make(0)));
alloc_nest.emplace_back(Allocate::make(
- op->buffer_var, op->type, new_extents, op->condition,
+ op->buffer_var, op->dtype, new_extents, op->condition,
Evaluate::make(0)));
return op->body;
} else {
@@ -135,15 +135,15 @@ class DoubleBufferInjector : public IRMutator {
CHECK(is_zero(old_loop->min));
Expr zero = old_loop->min;
Expr new_ext =
- old_loop->extent - make_const(old_loop->loop_var.type(), 1);
- Expr factor = make_const(new_ext.type(), split_loop_);
+ old_loop->extent - make_const(old_loop->loop_var.dtype(), 1);
+ Expr factor = make_const(new_ext.dtype(), split_loop_);
Expr outer_ext = new_ext / factor;
Expr tail_base = outer_ext * factor;
- Var outer_var(old_loop->loop_var->name_hint + ".outer", old_loop->loop_var.type());
+ Var outer_var(old_loop->loop_var->name_hint + ".outer", old_loop->loop_var.dtype());
std::unordered_map<const Variable*, Expr> vmap;
std::vector<Stmt> loop_seq;
for (int32_t i = 0; i < split_loop_; ++i) {
- vmap[old_loop->loop_var.get()] = outer_var * factor + make_const(factor.type(), i);
+ vmap[old_loop->loop_var.get()] = outer_var * factor + make_const(factor.dtype(), i);
loop_seq.emplace_back(Substitute(old_loop->body, vmap));
}
Stmt loop = For::make(
@@ -153,7 +153,7 @@ class DoubleBufferInjector : public IRMutator {
std::vector<Stmt> tail_seq;
Stmt tail_body = StripDoubleBufferWrite().Mutate(old_loop->body);
for (int32_t i = 0; i < split_loop_; ++i) {
- Expr idx = tail_base + make_const(tail_base.type(), i);
+ Expr idx = tail_base + make_const(tail_base.dtype(), i);
vmap[old_loop->loop_var.get()] = idx;
tail_seq.emplace_back(
IfThenElse::make(idx < old_loop->extent,
@@ -196,7 +196,7 @@ class DoubleBufferInjector : public IRMutator {
const StorageEntry& e = it->second;
CHECK(e.stride.defined());
CHECK(e.switch_read_var.defined());
- return Load::make(op->type,
+ return Load::make(op->dtype,
op->buffer_var,
e.switch_read_var * e.stride + op->index,
op->predicate);
@@ -222,12 +222,12 @@ class DoubleBufferInjector : public IRMutator {
}
StorageEntry& e = it->second;
e.loop = loop_nest_.back();
- Expr zero = make_const(e.loop->loop_var.type(), 0);
- Expr one = make_const(e.loop->loop_var.type(), 1);
- Expr two = make_const(e.loop->loop_var.type(), 2);
+ Expr zero = make_const(e.loop->loop_var.dtype(), 0);
+ Expr one = make_const(e.loop->loop_var.dtype(), 1);
+ Expr two = make_const(e.loop->loop_var.dtype(), 2);
Expr loop_shift = e.loop->loop_var + one;
e.switch_write_var = Var(e.loop->loop_var->name_hint + ".db",
- e.loop->loop_var.type());
+ e.loop->loop_var.dtype());
e.switch_read_var = indexmod(e.loop->loop_var, two);
in_double_buffer_scope_ = true;
Stmt body = Mutate(op->body);
diff --git a/src/pass/inject_virtual_thread.cc b/src/pass/inject_virtual_thread.cc
index eafe5a9..c80c7fc 100644
--- a/src/pass/inject_virtual_thread.cc
+++ b/src/pass/inject_virtual_thread.cc
@@ -222,7 +222,7 @@ class VTInjector : public IRMutator {
}
auto it = alloc_remap_.find(op->buffer_var.get());
if (it != alloc_remap_.end()) {
- return Load::make(op->type, op->buffer_var,
+ return Load::make(op->dtype, op->buffer_var,
RewriteIndex(op->index, it->second),
op->predicate);
} else {
@@ -233,7 +233,7 @@ class VTInjector : public IRMutator {
Expr Mutate_(const Call* op, const Expr& e) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
CHECK_EQ(op->args.size(), 5U);
- Type dtype = op->args[0].type();
+ DataType dtype = op->args[0].dtype();
const Variable* buffer = op->args[1].as<Variable>();
auto it = alloc_remap_.find(buffer);
if (it == alloc_remap_.end()) return IRMutator::Mutate_(op, e);
@@ -241,10 +241,10 @@ class VTInjector : public IRMutator {
Expr offset = Mutate(op->args[2]);
Expr extent = Mutate(op->args[3]);
Expr stride =
- it->second / make_const(offset.type(), dtype.lanes());
+ it->second / make_const(offset.dtype(), dtype.lanes());
offset = stride * var_ + offset;
return Call::make(
- op->type, op->name,
+ op->dtype, op->name,
{op->args[0], op->args[1], offset, extent, op->args[4]},
op->call_type);
} else if (op->is_intrinsic(intrinsic::tvm_context_id)) {
@@ -395,9 +395,9 @@ class VTInjector : public IRMutator {
if (touched_var_.count(op->buffer_var.get()) || !allow_share_) {
// place v on highest dimension.
Expr stride = arith::ComputeReduce<Mul>(
- op->extents, Expr()) * op->type.lanes();
+ op->extents, Expr()) * op->dtype.lanes();
Array<Expr> other;
- other.push_back(make_const(op->extents[0].type(), num_threads_));
+ other.push_back(make_const(op->extents[0].dtype(), num_threads_));
for (Expr e : extents) {
other.push_back(e);
}
@@ -417,7 +417,7 @@ class VTInjector : public IRMutator {
return s;
} else {
return Allocate::make(
- op->buffer_var, op->type,
+ op->buffer_var, op->dtype,
extents, condition, body,
op->new_expr, op->free_function);
}
@@ -439,19 +439,19 @@ class VTInjector : public IRMutator {
// only unroll if number of vthreads are small
if (max_loop_depth_ == 0 && num_threads_ < 16) {
// do unrolling if it is inside innermost content.
- Stmt blk = Substitute(stmt, {{var_, make_zero(var_.type())}});
+ Stmt blk = Substitute(stmt, {{var_, make_zero(var_.dtype())}});
for (int i = 1; i < num_threads_; ++i) {
blk = Block::make(
- blk, Substitute(stmt, {{var_, make_const(var_.type(), i)}}));
+ blk, Substitute(stmt, {{var_, make_const(var_.dtype(), i)}}));
}
return blk;
} else {
// insert a for loop
- Var idx(var_->name_hint + ".s", var_->type);
+ Var idx(var_->name_hint + ".s", var_->dtype);
Map<Var, Expr> values{{var_, idx}};
stmt = Substitute(stmt, values);
- return For::make(idx, make_zero(idx.type()),
- make_const(idx.type(), num_threads_),
+ return For::make(idx, make_zero(idx.dtype()),
+ make_const(idx.dtype(), num_threads_),
ForType::Serial, DeviceAPI::None, stmt);
}
}
diff --git a/src/pass/ir_deep_compare.cc b/src/pass/ir_deep_compare.cc
index cb859d0..e399e7f 100644
--- a/src/pass/ir_deep_compare.cc
+++ b/src/pass/ir_deep_compare.cc
@@ -63,7 +63,7 @@ class IRDeepCompare :
if (order_ != 0) return;
if (n.same_as(other)) return;
if (CompareValue(n->type_index(), other->type_index()) != 0) return;
- if (CompareType(n.type(), other.type()) != 0) return;
+ if (CompareType(n.dtype(), other.dtype()) != 0) return;
ExprComparator::VisitExpr(n, other);
}
@@ -119,7 +119,7 @@ class IRDeepCompare :
} else {
if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return;
}
- if (CompareType(op->type, rhs->type) != 0) return;
+ if (CompareType(op->dtype, rhs->dtype) != 0) return;
if (CompareArray(op->extents, rhs->extents) != 0) return;
if (CompareExpr(op->condition, rhs->condition) != 0) return;
if (CompareStmt(op->body, rhs->body) != 0) return;
@@ -166,7 +166,7 @@ class IRDeepCompare :
const Realize* rhs = other.as<Realize>();
if (CompareNodeRef(op->func, rhs->func) != 0) return;
if (CompareValue(op->value_index, rhs->value_index) != 0) return;
- if (CompareType(op->type, rhs->type) != 0) return;
+ if (CompareType(op->dtype, rhs->dtype) != 0) return;
if (CompareRegion(op->bounds, rhs->bounds) != 0) return;
if (CompareStmt(op->body, rhs->body) != 0) return;
}
@@ -175,7 +175,7 @@ class IRDeepCompare :
const Prefetch* rhs = other.as<Prefetch>();
if (CompareNodeRef(op->func, rhs->func) != 0) return;
if (CompareValue(op->value_index, rhs->value_index) != 0) return;
- if (CompareType(op->type, rhs->type) != 0) return;
+ if (CompareType(op->dtype, rhs->dtype) != 0) return;
if (CompareRegion(op->bounds, rhs->bounds) != 0) return;
}
@@ -369,7 +369,7 @@ class IRDeepCompare :
return order_;
}
- int CompareType(const Type& lhs, const Type& rhs) {
+ int CompareType(const DataType& lhs, const DataType& rhs) {
if (order_ != 0) return order_;
if (lhs == rhs) return order_;
if (CompareValue(lhs.code(), rhs.code()) != 0) return order_;
diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc
index 6022267..b300989 100644
--- a/src/pass/ir_mutator.cc
+++ b/src/pass/ir_mutator.cc
@@ -179,7 +179,7 @@ Stmt IRMutator::Mutate_(const Allocate* op, const Stmt& s) {
return s;
} else {
return Allocate::make(
- op->buffer_var, op->type,
+ op->buffer_var, op->dtype,
new_extents, condition, body,
new_expr, op->free_function);
}
@@ -247,7 +247,7 @@ Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) {
return s;
} else {
return Realize::make(op->func, op->value_index,
- op->type, new_bounds,
+ op->dtype, new_bounds,
condition, body);
}
}
@@ -273,7 +273,7 @@ Stmt IRMutator::Mutate_(const Prefetch* op, const Stmt& s) {
return s;
} else {
return Prefetch::make(op->func, op->value_index,
- op->type, new_bounds);
+ op->dtype, new_bounds);
}
}
@@ -358,7 +358,7 @@ Expr IRMutator::Mutate_(const Load* op, const Expr& e) {
if (index.same_as(op->index) && pred.same_as(op->predicate)) {
return e;
} else {
- return Load::make(op->type, op->buffer_var, index, pred);
+ return Load::make(op->dtype, op->buffer_var, index, pred);
}
}
@@ -378,7 +378,7 @@ Expr IRMutator::Mutate_(const Call* op, const Expr& e) {
if (op->args.same_as(new_args)) {
return e;
} else {
- return Call::make(op->type, op->name, new_args, op->call_type,
+ return Call::make(op->dtype, op->name, new_args, op->call_type,
op->func, op->value_index);
}
}
@@ -432,7 +432,7 @@ Expr IRMutator::Mutate_(const Cast* op, const Expr& e) {
if (value.same_as(op->value)) {
return e;
} else {
- return Cast::make(op->type, value);
+ return Cast::make(op->dtype, value);
}
}
diff --git a/src/pass/ir_util.h b/src/pass/ir_util.h
index 690feca..0f8bb99 100644
--- a/src/pass/ir_util.h
+++ b/src/pass/ir_util.h
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -89,12 +89,12 @@ inline Array<T> UpdateArray(Array<T> arr, F fupdate) {
* \return the get expression.
*/
inline Expr TVMStructGet(
- Type dtype, Var handle, int index,
+ DataType dtype, Var handle, int index,
intrinsic::TVMStructFieldKind kind) {
Array<Expr> args ={
handle,
- make_const(Int(32), index),
- make_const(Int(32), static_cast<int>(kind))};
+ make_const(DataType::Int(32), index),
+ make_const(DataType::Int(32), static_cast<int>(kind))};
return Call::make(dtype, intrinsic::tvm_struct_get, args, Call::PureIntrinsic);
}
@@ -104,10 +104,10 @@ inline Expr TVMStructGet(
* \param dtype The data type.
* \param offset the offset index.
*/
-inline Expr AddressOffset(Var handle, Type dtype, int offset) {
+inline Expr AddressOffset(Var handle, DataType dtype, int offset) {
return Call::make(
- Handle(), intrinsic::tvm_address_of,
- {Load::make(dtype, handle, make_const(Int(32), offset * dtype.lanes()),
+ DataType::Handle(), intrinsic::tvm_address_of,
+ {Load::make(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()),
const_true(dtype.lanes()))},
Call::PureIntrinsic);
}
@@ -118,13 +118,13 @@ inline Expr AddressOffset(Var handle, Type dtype, int offset) {
* \param dtype The data type.
* \param offset the offset index.
*/
-inline Expr AddressOffset(Var handle, Type dtype, Expr offset) {
+inline Expr AddressOffset(Var handle, DataType dtype, Expr offset) {
if (dtype.lanes() != 1) {
- offset = offset * make_const(offset.type(), dtype.lanes());
- offset = Ramp::make(offset, make_const(offset.type(), 1), dtype.lanes());
+ offset = offset * make_const(offset.dtype(), dtype.lanes());
+ offset = Ramp::make(offset, make_const(offset.dtype(), 1), dtype.lanes());
}
return Call::make(
- Handle(), intrinsic::tvm_address_of,
+ DataType::Handle(), intrinsic::tvm_address_of,
{Load::make(dtype, handle, offset,
const_true(dtype.lanes()))},
Call::PureIntrinsic);
@@ -143,11 +143,11 @@ inline Stmt TVMStructSet(
intrinsic::TVMStructFieldKind kind, Expr value) {
Array<Expr> args ={
handle,
- make_const(Int(32), index),
- make_const(Int(32), static_cast<int>(kind)),
+ make_const(DataType::Int(32), index),
+ make_const(DataType::Int(32), static_cast<int>(kind)),
value};
return Evaluate::make(
- Call::make(Int(32), intrinsic::tvm_struct_set, args, Call::Intrinsic));
+ Call::make(DataType::Int(32), intrinsic::tvm_struct_set, args, Call::Intrinsic));
}
/*!
@@ -155,13 +155,13 @@ inline Stmt TVMStructSet(
* \param t The original type.
* \return The corresponding API type.
*/
-inline Type APIType(Type t) {
+inline DataType APIType(DataType t) {
if (t.is_handle()) return t;
CHECK_EQ(t.lanes(), 1)
<< "Cannot pass vector type through packed API.";
- if (t.is_uint() || t.is_int()) return Int(64);
+ if (t.is_uint() || t.is_int()) return DataType::Int(64);
CHECK(t.is_float());
- return Float(64);
+ return DataType::Float(64);
}
/*!
@@ -170,7 +170,7 @@ inline Type APIType(Type t) {
* \param const_size The constant size of the array.
* \return the alignment
*/
-inline int GetTempAllocaAlignment(Type type, int32_t const_size) {
+inline int GetTempAllocaAlignment(DataType type, int32_t const_size) {
int align = runtime::kTempAllocaAlignment;
if (const_size > 0) {
int64_t const_s = static_cast<int64_t>(const_size) * type.bits() * type.lanes() / 8;
diff --git a/src/pass/lift_attr_scope.cc b/src/pass/lift_attr_scope.cc
index adcaaeb..cfc6e5a 100644
--- a/src/pass/lift_attr_scope.cc
+++ b/src/pass/lift_attr_scope.cc
@@ -57,7 +57,7 @@ class AttrScopeLifter : public IRMutator {
attr_node_ = NodeRef();
attr_value_ = Expr();
return Allocate::make(
- op->buffer_var, op->type,
+ op->buffer_var, op->dtype,
op->extents, op->condition, body,
op->new_expr, op->free_function);
} else {
@@ -198,7 +198,7 @@ class AttrScopeLifter : public IRMutator {
static bool ValueSame(const Expr& a, const Expr& b) {
if (a.same_as(b)) return true;
if (a->type_index() != b->type_index()) return false;
- if (a.type() != b.type()) return false;
+ if (a.dtype() != b.dtype()) return false;
if (const IntImm* op = a.as<IntImm>()) {
return op->value == b.as<IntImm>()->value;
}
diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc
index ef5cc9c..1ac3867 100644
--- a/src/pass/loop_partition.cc
+++ b/src/pass/loop_partition.cc
@@ -181,7 +181,7 @@ class PartitionFinder : public IRVisitor {
const IterVarNode* thread_axis = op->node.as<IterVarNode>();
CHECK(thread_axis);
const Variable* var = thread_axis->var.get();
- IntSet dom = IntSet::range(Range(make_zero(op->value.type()), op->value));
+ IntSet dom = IntSet::range(Range(make_zero(op->value.dtype()), op->value));
hint_map_.insert({var, dom});
relax_map_.insert({var, dom});
IRVisitor::Visit_(op);
@@ -351,12 +351,12 @@ class LoopPartitioner : public IRMutator {
if (scope.rank == 1) {
// threadIdx should be put into relax map, in case of divergence.
relax_map_.insert({var.get(),
- IntSet::interval(make_zero(var.type()), op->value - 1)});
+ IntSet::interval(make_zero(var.dtype()), op->value - 1)});
res = IRMutator::Mutate_(op, stmt);
relax_map_.erase(var.get());
} else {
hint_map_.insert({var.get(),
- IntSet::interval(make_zero(var.type()), op->value - 1)});
+ IntSet::interval(make_zero(var.dtype()), op->value - 1)});
res = IRMutator::Mutate_(op, stmt);
hint_map_.erase(var.get());
}
@@ -595,9 +595,9 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
inline Stmt LoopPartitioner::MakeFor(const Node *node, Expr extent, Stmt body) {
const For *for_node = static_cast<const For*>(node);
CHECK(for_node);
- if (analyzer_.CanProve(extent == make_const(Int(32), 1))) {
+ if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1))) {
// If the loop extent is 1, do not create the loop anymore
- return Substitute(body, {{Var{for_node->loop_var}, make_const(Int(32), 0)}});
+ return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}});
} else {
return For::make(for_node->loop_var, 0, extent,
for_node->for_type, for_node->device_api, body);
diff --git a/src/pass/lower_custom_datatypes.cc b/src/pass/lower_custom_datatypes.cc
index 3e71868..e24cddd 100644
--- a/src/pass/lower_custom_datatypes.cc
+++ b/src/pass/lower_custom_datatypes.cc
@@ -42,8 +42,8 @@ class CustomDatatypesLowerer : public IRMutator {
explicit CustomDatatypesLowerer(const std::string& target) : target_(target) {}
inline Expr Mutate_(const Cast* op, const Expr& e) final {
- auto type_code = op->type.code();
- auto src_type_code = op->value.type().code();
+ auto type_code = op->dtype.code();
+ auto src_type_code = op->value.dtype().code();
// If either datatype is a registered custom datatype, we must lower.
bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code) ||
datatype::Registry::Global()->GetTypeRegistered(src_type_code);
@@ -60,7 +60,7 @@ class CustomDatatypesLowerer : public IRMutator {
}
inline Expr Mutate_(const FloatImm* imm, const Expr& e) final {
- auto type_code = imm->type.code();
+ auto type_code = imm->dtype.code();
if (datatype::Registry::Global()->GetTypeRegistered(type_code)) {
auto lower = datatype::GetFloatImmLowerFunc(target_, type_code);
CHECK(lower) << "FloatImm lowering function for target " << target_ << " type "
@@ -71,12 +71,12 @@ class CustomDatatypesLowerer : public IRMutator {
}
inline Stmt Mutate_(const Allocate* allocate, const Stmt& s) final {
- bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(allocate->type.code());
+ bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(allocate->dtype.code());
Stmt stmt = IRMutator::Mutate_(allocate, s);
allocate = stmt.as<Allocate>();
if (toBeLowered) {
- auto new_allocate_type = UInt(allocate->type.bits(), allocate->type.lanes());
+ auto new_allocate_type = DataType::UInt(allocate->dtype.bits(), allocate->dtype.lanes());
return Allocate::make(allocate->buffer_var, new_allocate_type, allocate->extents,
allocate->condition, allocate->body, allocate->new_expr,
allocate->free_function);
@@ -85,11 +85,11 @@ class CustomDatatypesLowerer : public IRMutator {
}
inline Expr Mutate_(const Load* load, const Expr& e) final {
- bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(load->type.code());
+ bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(load->dtype.code());
Expr expr = IRMutator::Mutate_(load, e);
load = expr.as<Load>();
if (toBeLowered) {
- auto new_load_type = UInt(load->type.bits());
+ auto new_load_type = DataType::UInt(load->dtype.bits());
return Load::make(new_load_type, load->buffer_var, load->index, load->predicate);
}
return expr;
@@ -97,7 +97,7 @@ class CustomDatatypesLowerer : public IRMutator {
#define DEFINE_MUTATE__(OP) \
inline Expr Mutate_(const OP* op, const Expr& e) final { \
- auto type_code = op->type.code(); \
+ auto type_code = op->dtype.code(); \
bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \
Expr expr = IRMutator::Mutate_(op, e); \
op = expr.as<OP>(); \
diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc
index c2a2fe6..f0b0b3c 100644
--- a/src/pass/lower_intrin.cc
+++ b/src/pass/lower_intrin.cc
@@ -76,7 +76,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
op = ret.as<FloorDiv>();
if (op == nullptr) return ret;
int shift;
- const DataType& dtype = op->type;
+ const DataType& dtype = op->dtype;
CHECK(dtype.is_int() || dtype.is_uint());
if (support_bitwise_op_ &&
@@ -97,7 +97,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
// condition on b >= 0.
// truncmod(a, b) < 0 will implies ceildiv,
// So we need to correct these cases.
- if ((dtype == Int(32) || dtype == Int(64)) && support_bitwise_op_) {
+ if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) {
// equivalent to rdiv + (rmod >= 0 ? 0: -1);
return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1));
} else {
@@ -123,7 +123,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
if (op == nullptr) return ret;
// Lower floordiv to native truncdiv.
int shift;
- const DataType& dtype = op->type;
+ const DataType& dtype = op->dtype;
CHECK(dtype.is_int() || dtype.is_uint());
if (support_bitwise_op_ &&
@@ -144,7 +144,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
// mod(a, b) < 0 will imply we are doing ceildiv,
// So we need to correct these cases.
Expr rmod = truncmod(op->a, op->b);
- if ((dtype == Int(32) || dtype == Int(64)) && support_bitwise_op_) {
+ if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) {
// (rmod >> shift) & b
// -> (rmod >= 0 ? 0: -1) & b
// -> rmod >= 0 ? 0 : b
@@ -207,23 +207,23 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
if (const Cast* cast = bcast->value.as<Cast>()) {
auto should_swap = [&]() {
// Maintain behaviour (int8 -> int16, fp16 -> fp32).
- if (cast->type.bits() == cast->value.type().bits() * 2) {
+ if (cast->dtype.bits() == cast->value.dtype().bits() * 2) {
return true;
}
// Check both operands are integer-like.
- if (!cast->type.is_uint() && !cast->type.is_int()) {
+ if (!cast->dtype.is_uint() && !cast->dtype.is_int()) {
return false;
}
- if (!cast->value.type().is_uint() && !cast->value.type().is_int()) {
+ if (!cast->value.dtype().is_uint() && !cast->value.dtype().is_int()) {
return false;
}
// If both are integer-like, swap if we have a widening cast.
- return cast->type.bits() > cast->value.type().bits();
+ return cast->dtype.bits() > cast->value.dtype().bits();
};
if (should_swap()) {
Expr new_bcast = Broadcast::make(cast->value, bcast->lanes);
- return Cast::make(bcast->type, new_bcast);
+ return Cast::make(bcast->dtype, new_bcast);
}
}
}
@@ -236,9 +236,9 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer {
Expr lhs = SwapBroadcastCast(a);
Expr rhs = SwapBroadcastCast(b);
- if (fma_ != nullptr && op->type.is_float()) {
+ if (fma_ != nullptr && op->dtype.is_float()) {
Expr r = (*fma_)(Call::make(
- op->type, "fma", {lhs, rhs, c}, Call::PureIntrinsic));
+ op->dtype, "fma", {lhs, rhs, c}, Call::PureIntrinsic));
if (r.defined()) return this->Mutate(r);
} else {
if (!lhs.same_as(a) || !rhs.same_as(b)) {
diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc
index e8ea52e..2a12118 100644
--- a/src/pass/lower_thread_allreduce.cc
+++ b/src/pass/lower_thread_allreduce.cc
@@ -83,7 +83,7 @@ class ThreadAllreduceBuilder final : public IRMutator {
stmt = AttrStmt::make(
repl->buffer_var, attr::volatile_scope, 1, op->body);
stmt = Allocate::make(
- repl->buffer_var, repl->type,
+ repl->buffer_var, repl->dtype,
repl->extents, repl->condition, stmt);
stmt = AttrStmt::make(
repl->buffer_var, attr::storage_scope,
@@ -125,14 +125,14 @@ class ThreadAllreduceBuilder final : public IRMutator {
CHECK_EQ(size, size_of_args->value);
Array<Expr> inits = combiner->identity_element;
std::vector<Expr> values(size);
- std::vector<Type> types(size);
+ std::vector<DataType> types(size);
Expr cond = call->args[size+1];
for (size_t idx = 0; idx < size; ++idx) {
values[idx] = call->args[1+idx];
if (!is_one(cond)) {
values[idx] = Select::make(cond, values[idx], inits[idx]);
}
- types[idx] = values[idx].type();
+ types[idx] = values[idx].dtype();
}
std::vector<const Variable*> buffers(size);
for (size_t idx = 0; idx < size; ++idx) {
@@ -197,7 +197,7 @@ class ThreadAllreduceBuilder final : public IRMutator {
// previous iteration on the same buffer.
seq.emplace_back(SyncThread("shared"));
for (size_t idx = 0; idx < size; ++idx) {
- shared_bufs[idx] = Var("red_buf"+std::to_string(idx), Handle());
+ shared_bufs[idx] = Var("red_buf"+std::to_string(idx), DataType::Handle());
Expr pred = const_true(types[idx].lanes());
seq.emplace_back(Store::make(
shared_bufs[idx], values[idx],
@@ -212,7 +212,7 @@ class ThreadAllreduceBuilder final : public IRMutator {
Expr pred = const_true(types[idx].lanes());
load_remap_[buffers[idx]] = Load::make(
types[idx], shared_bufs[idx],
- BufIndex(make_zero(reduce_index.type()), group_index, reduce_extent), pred);
+ BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred);
alloc_remap_[buffers[idx]] = Allocate::make(
shared_bufs[idx], types[idx],
{Expr(group_extent), Expr(reduce_extent)},
@@ -222,7 +222,7 @@ class ThreadAllreduceBuilder final : public IRMutator {
}
// make allreduce.
Stmt MakeBufAllreduce(const CommReducerNode *combiner,
- const std::vector<Type>& types,
+ const std::vector<DataType>& types,
const Array<Var>& shared_bufs,
Expr reduce_index,
Expr group_index,
@@ -293,7 +293,7 @@ class ThreadAllreduceBuilder final : public IRMutator {
int& total_extent = *out_total_extent;
total_extent = 1;
if (tvec.size() == 0) {
- return make_zero(Int(32));
+ return make_zero(DataType::Int(32));
}
Expr ret;
@@ -311,7 +311,7 @@ class ThreadAllreduceBuilder final : public IRMutator {
// sync thread op.
static Stmt SyncThread(const std::string& sync) {
return Evaluate::make(
- Call::make(Int(32), intrinsic::tvm_storage_sync,
+ Call::make(DataType::Int(32), intrinsic::tvm_storage_sync,
{StringImm::make(sync)},
Call::Intrinsic));
}
diff --git a/src/pass/lower_tvm_builtin.cc b/src/pass/lower_tvm_builtin.cc
index e73956c..c8c8fa9 100644
--- a/src/pass/lower_tvm_builtin.cc
+++ b/src/pass/lower_tvm_builtin.cc
@@ -33,12 +33,12 @@ namespace ir {
inline Expr ConstInt32(size_t index) {
CHECK_LE(index, std::numeric_limits<int>::max());
- return make_const(Int(32), static_cast<int>(index));
+ return make_const(DataType::Int(32), static_cast<int>(index));
}
inline Expr StackAlloca(std::string type, size_t num) {
Array<Expr> args = {StringImm::make(type), ConstInt32(num)};
- return Call::make(Handle(), intrinsic::tvm_stack_alloca, args, Call::Intrinsic);
+ return Call::make(DataType::Handle(), intrinsic::tvm_stack_alloca, args, Call::Intrinsic);
}
// Calculate the statistics of packed function.
@@ -46,10 +46,10 @@ inline Expr StackAlloca(std::string type, size_t num) {
class BuiltinLower : public IRMutator {
public:
Stmt Build(Stmt stmt) {
- stack_shape_ = Var("stack_shape", Handle());
- stack_array_ = Var("stack_array", Handle());
- stack_value_ = Var("stack_value", Handle());
- stack_tcode_ = Var("stack_tcode", Handle());
+ stack_shape_ = Var("stack_shape", DataType::Handle());
+ stack_array_ = Var("stack_array", DataType::Handle());
+ stack_value_ = Var("stack_value", DataType::Handle());
+ stack_tcode_ = Var("stack_tcode", DataType::Handle());
stmt = this->Mutate(stmt);
if (max_shape_stack_ != 0) {
stmt = LetStmt::make(
@@ -86,7 +86,7 @@ class BuiltinLower : public IRMutator {
if (op->new_expr.defined()) return stmt;
// Get constant allocation bound.
int64_t dev_type;
- int64_t nbytes = GetVectorBytes(op->type);
+ int64_t nbytes = GetVectorBytes(op->dtype);
if (device_type_.defined()) {
if (arith::GetConst(device_type_, &dev_type)) {
if (dev_type == kDLCPU) {
@@ -97,18 +97,18 @@ class BuiltinLower : public IRMutator {
}
}
}
- Expr total_bytes = make_const(op->extents[0].type(), nbytes);
+ Expr total_bytes = make_const(op->extents[0].dtype(), nbytes);
for (size_t i = 0; i < op->extents.size(); ++i) {
total_bytes = total_bytes * op->extents[i];
}
CHECK(device_type_.defined()) << "Unknown device type in current IR";
CHECK(device_id_.defined()) << "Unknown device id in current IR";
- Stmt throw_last_error = Evaluate::make(Call::make(Int(32),
+ Stmt throw_last_error = Evaluate::make(Call::make(DataType::Int(32),
intrinsic::tvm_throw_last_error, {},
Call::Intrinsic));
Stmt body = Block::make(
- IfThenElse::make(Call::make(Bool(1),
+ IfThenElse::make(Call::make(DataType::Bool(1),
intrinsic::tvm_handle_is_null,
{op->buffer_var}, Call::PureIntrinsic),
throw_last_error),
@@ -116,27 +116,27 @@ class BuiltinLower : public IRMutator {
Stmt alloca = LetStmt::make(
op->buffer_var,
- Call::make(op->buffer_var.type(),
+ Call::make(op->buffer_var.dtype(),
"TVMBackendAllocWorkspace",
- {cast(Int(32), device_type_),
- cast(Int(32), device_id_),
- cast(UInt(64), total_bytes),
- IntImm::make(Int(32), op->type.code()),
- IntImm::make(Int(32), op->type.bits())},
+ {cast(DataType::Int(32), device_type_),
+ cast(DataType::Int(32), device_id_),
+ cast(DataType::UInt(64), total_bytes),
+ IntImm::make(DataType::Int(32), op->dtype.code()),
+ IntImm::make(DataType::Int(32), op->dtype.bits())},
Call::Extern),
body);
- Expr free_op = Call::make(Int(32),
+ Expr free_op = Call::make(DataType::Int(32),
"TVMBackendFreeWorkspace",
- {cast(Int(32), device_type_),
- cast(Int(32), device_id_),
+ {cast(DataType::Int(32), device_type_),
+ cast(DataType::Int(32), device_id_),
op->buffer_var},
Call::Extern);
- Stmt free_stmt = IfThenElse::make(free_op != make_zero(Int(32)), throw_last_error);
+ Stmt free_stmt = IfThenElse::make(free_op != make_zero(DataType::Int(32)), throw_last_error);
body = Block::make(alloca, free_stmt);
body = AttrStmt::make(
op->buffer_var, attr::storage_alignment,
- make_const(Int(32), runtime::kTempAllocaAlignment),
+ make_const(DataType::Int(32), runtime::kTempAllocaAlignment),
body);
return body;
}
@@ -164,7 +164,7 @@ class BuiltinLower : public IRMutator {
} else if (op->is_intrinsic(intrinsic::tvm_stack_make_array)) {
return MakeArray(op, e);
} else if (op->is_intrinsic(intrinsic::tvm_context_id)) {
- return make_zero(op->type);
+ return make_zero(op->dtype);
} else {
return IRMutator::Mutate_(op, e);
}
@@ -177,10 +177,10 @@ class BuiltinLower : public IRMutator {
op = expr.as<Call>();
for (size_t i = 0; i < op->args.size(); ++i) {
prep_seq_.emplace_back(
- Store::make(stack_shape_, cast(Int(64), op->args[i]),
+ Store::make(stack_shape_, cast(DataType::Int(64), op->args[i]),
ConstInt32(stack_begin +i), const_true(1)));
}
- return AddressOffset(stack_shape_, Int(64), stack_begin);
+ return AddressOffset(stack_shape_, DataType::Int(64), stack_begin);
}
// make array
Expr MakeArray(const Call* op, const Expr& e) {
@@ -194,40 +194,40 @@ class BuiltinLower : public IRMutator {
TVMStructSet(stack_array_, idx, intrinsic::kArrShape, op->args[1]));
Expr strides = op->args[2];
if (!strides.defined() || is_zero(strides)) {
- strides = make_zero(Handle());
+ strides = make_zero(DataType::Handle());
}
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrStrides, strides));
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrNDim, op->args[3]));
- Type dtype = op->args[4].type();
+ DataType dtype = op->args[4].dtype();
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrTypeCode,
- make_const(UInt(8), static_cast<int>(dtype.code()))));
+ make_const(DataType::UInt(8), static_cast<int>(dtype.code()))));
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrTypeBits,
- make_const(UInt(8), dtype.bits())));
+ make_const(DataType::UInt(8), dtype.bits())));
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrTypeLanes,
- make_const(UInt(16), dtype.lanes())));
+ make_const(DataType::UInt(16), dtype.lanes())));
// set byte offset
int data_bytes = GetVectorBytes(dtype);
Expr byte_offset = op->args[5];
if (!is_zero(byte_offset)) {
- byte_offset = byte_offset * make_const(byte_offset.type(), data_bytes);
+ byte_offset = byte_offset * make_const(byte_offset.dtype(), data_bytes);
}
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrByteOffset,
- cast(UInt(64), byte_offset)));
+ cast(DataType::UInt(64), byte_offset)));
CHECK(device_type_.defined()) << "Unknown device type in current IR";
CHECK(device_id_.defined()) << "Unknown device id in current IR";
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceId,
- cast(Int(32), device_id_)));
+ cast(DataType::Int(32), device_id_)));
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceType,
- cast(Int(32), device_type_)));
- return TVMStructGet(Handle(), stack_array_, idx, intrinsic::kArrAddr);
+ cast(DataType::Int(32), device_type_)));
+ return TVMStructGet(DataType::Handle(), stack_array_, idx, intrinsic::kArrAddr);
}
// call packed.
Expr MakeCallPacked(const Call* op, const Expr& e) {
@@ -241,8 +241,8 @@ class BuiltinLower : public IRMutator {
for (size_t i = 1; i < op->args.size(); ++i) {
Expr stack_index = ConstInt32(arg_stack_begin + i - 1);
Expr arg = op->args[i];
- Type t = arg.type();
- Type api_type = APIType(t);
+ DataType t = arg.dtype();
+ DataType api_type = APIType(t);
if (t != api_type) {
arg = Cast::make(api_type, arg);
}
@@ -274,7 +274,7 @@ class BuiltinLower : public IRMutator {
ConstInt32(arg_stack_begin + op->args.size() - 1)
};
return Call::make(
- Int(32), intrinsic::tvm_call_packed_lowered,
+ DataType::Int(32), intrinsic::tvm_call_packed_lowered,
packed_args, Call::Intrinsic);
}
@@ -290,8 +290,8 @@ class BuiltinLower : public IRMutator {
for (size_t i = 1; i < op->args.size(); ++i) {
Expr stack_index = ConstInt32(arg_stack_begin + i - 1);
Expr arg = op->args[i];
- Type t = arg.type();
- Type api_type = APIType(t);
+ DataType t = arg.dtype();
+ DataType api_type = APIType(t);
if (t != api_type) {
arg = Cast::make(api_type, arg);
}
@@ -324,7 +324,7 @@ class BuiltinLower : public IRMutator {
op->args[args_size - 1]
};
return Call::make(
- op->type, intrinsic::tvm_call_trace_packed_lowered,
+ op->dtype, intrinsic::tvm_call_trace_packed_lowered,
packed_args, Call::Intrinsic);
}
diff --git a/src/pass/lower_warp_memory.cc b/src/pass/lower_warp_memory.cc
index 393605e..0ed2b62 100644
--- a/src/pass/lower_warp_memory.cc
+++ b/src/pass/lower_warp_memory.cc
@@ -94,11 +94,11 @@ class WarpStoreCoeffFinder : private IRVisitor {
/// Visitor implementation
void Visit_(const Store *op) final {
if (op->buffer_var.get() == buffer_) {
- if (op->value.type().lanes() == 1) {
+ if (op->value.dtype().lanes() == 1) {
UpdatePattern(op->index);
} else {
Expr base;
- CHECK(GetRamp1Base(op->index, op->value.type().lanes(), &base))
+ CHECK(GetRamp1Base(op->index, op->value.dtype().lanes(), &base))
<< "LowerWarpMemory failed due to store index=" << op->index
<< ", can only handle continuous store";
UpdatePattern(base);
@@ -196,7 +196,7 @@ class WarpAccessRewriter : protected IRMutator {
int alloc_size = op->constant_allocation_size();
CHECK_GT(alloc_size, 0)
<< "warp memory only support constant alloc size";
- alloc_size *= op->type.lanes();
+ alloc_size *= op->dtype.lanes();
warp_index_ = WarpIndexFinder(warp_size_).Find(op->body)->var;
warp_coeff_ = WarpStoreCoeffFinder(
buffer_, warp_index_, analyzer_).Find(op->body);
@@ -205,8 +205,8 @@ class WarpAccessRewriter : protected IRMutator {
warp_group_ = alloc_size / (warp_size_ * warp_coeff_);
return Allocate::make(
op->buffer_var,
- op->type,
- {make_const(Int(32), alloc_size / warp_size_)},
+ op->dtype,
+ {make_const(DataType::Int(32), alloc_size / warp_size_)},
op->condition,
this->Mutate(op->body));
}
@@ -237,8 +237,8 @@ class WarpAccessRewriter : protected IRMutator {
<< "LowerWarpMemory failed to rewrite load to shuffle for index "
<< op->index << " local_index=" << local_index;
Expr load_value = Load::make(
- op->type, op->buffer_var, local_index, op->predicate);
- return Call::make(load_value.type(),
+ op->dtype, op->buffer_var, local_index, op->predicate);
+ return Call::make(load_value.dtype(),
intrinsic::tvm_warp_shuffle,
{load_value, group},
Call::Intrinsic);
@@ -252,15 +252,15 @@ class WarpAccessRewriter : protected IRMutator {
// source index is the corresponding source index
// in this access pattern.
std::pair<Expr, Expr> SplitIndexByGroup(const Expr& index) {
- if (index.type().lanes() != 1) {
+ if (index.dtype().lanes() != 1) {
Expr base, local_index, group;
- CHECK(GetRamp1Base(index, index.type().lanes(), &base));
+ CHECK(GetRamp1Base(index, index.dtype().lanes(), &base));
std::tie(local_index, group) = SplitIndexByGroup(base);
local_index =
- Ramp::make(local_index, make_const(local_index.type(), 1), index.type().lanes());
+ Ramp::make(local_index, make_const(local_index.dtype(), 1), index.dtype().lanes());
return std::make_pair(local_index, group);
}
- Expr m = make_const(index.type(), warp_coeff_);
+ Expr m = make_const(index.dtype(), warp_coeff_);
// simple case, warp index is on the highest.
if (warp_group_ == 1) {
@@ -269,9 +269,9 @@ class WarpAccessRewriter : protected IRMutator {
return std::make_pair(x, z);
} else {
Expr x = analyzer_->canonical_simplify(indexmod(index, m));
- Expr y = index / make_const(index.type(), warp_coeff_ * warp_size_);
+ Expr y = index / make_const(index.dtype(), warp_coeff_ * warp_size_);
y = y * m + x;
- Expr z = indexdiv(indexmod(index, make_const(index.type(), warp_coeff_ * warp_size_)),
+ Expr z = indexdiv(indexmod(index, make_const(index.dtype(), warp_coeff_ * warp_size_)),
m);
return std::make_pair(analyzer_->canonical_simplify(y),
analyzer_->canonical_simplify(z));
diff --git a/src/pass/make_api.cc b/src/pass/make_api.cc
index 4d9c92b..74b8f89 100644
--- a/src/pass/make_api.cc
+++ b/src/pass/make_api.cc
@@ -51,9 +51,9 @@ LoweredFunc MakeAPI(Stmt body,
int num_packed_args = num_args - num_unpacked_args;
// Data field definitions
// The packed fields
- Var v_packed_args("args", Handle());
- Var v_packed_arg_type_ids("arg_type_ids", Handle());
- Var v_num_packed_args("num_args", Int(32));
+ Var v_packed_args("args", DataType::Handle());
+ Var v_packed_arg_type_ids("arg_type_ids", DataType::Handle());
+ Var v_num_packed_args("num_args", DataType::Int(32));
// The arguments of the function.
Array<Var> args;
// The device context
@@ -66,12 +66,12 @@ LoweredFunc MakeAPI(Stmt body,
// ---------------------------
// local function definitions
// load i-th argument as type t
- auto f_arg_value = [&](Type t, int i) {
+ auto f_arg_value = [&](DataType t, int i) {
Array<Expr> call_args{v_packed_args,
- IntImm::make(Int(32), i),
- IntImm::make(Int(32), intrinsic::kTVMValueContent)};
+ IntImm::make(DataType::Int(32), i),
+ IntImm::make(DataType::Int(32), intrinsic::kTVMValueContent)};
// load 64 bit version
- Type api_type = APIType(t);
+ DataType api_type = APIType(t);
Expr res = Call::make(
api_type, intrinsic::tvm_struct_get, call_args,
Call::PureIntrinsic);
@@ -86,7 +86,7 @@ LoweredFunc MakeAPI(Stmt body,
std::ostringstream os;
os << "arg" << i;
const Variable* v = api_args[i].as<Variable>();
- return Var(os.str(), v ? v->type: Handle());
+ return Var(os.str(), v ? v->dtype: DataType::Handle());
};
// ---------------------------
// start of logics
@@ -110,14 +110,15 @@ LoweredFunc MakeAPI(Stmt body,
if (i < num_packed_args) {
// Value loads
seq_init.emplace_back(LetStmt::make(
- v_arg, f_arg_value(v_arg.type(), i), nop));
+ v_arg, f_arg_value(v_arg.dtype(), i), nop));
// type code checks
- Var tcode(v_arg->name_hint + ".code", Int(32));
+ Var tcode(v_arg->name_hint + ".code", DataType::Int(32));
seq_init.emplace_back(LetStmt::make(
tcode, Load::make(
- Int(32), v_packed_arg_type_ids, IntImm::make(Int(32), i), const_true(1)),
+ DataType::Int(32), v_packed_arg_type_ids,
+ IntImm::make(DataType::Int(32), i), const_true(1)),
nop));
- Type t = v_arg.type();
+ DataType t = v_arg.dtype();
if (t.is_handle()) {
std::ostringstream msg;
msg << name << ": Expect arg[" << i << "] to be pointer";
@@ -174,7 +175,7 @@ LoweredFunc MakeAPI(Stmt body,
n->is_packed_func = num_unpacked_args == 0;
n->is_restricted = is_restricted;
body = AttrStmt::make(
- make_zero(Int(32)), attr::compute_scope,
+ make_zero(DataType::Int(32)), attr::compute_scope,
StringImm::make(name + "_compute_"), body);
// Set device context
if (vmap.count(device_id.get())) {
@@ -186,7 +187,7 @@ LoweredFunc MakeAPI(Stmt body,
node, attr::device_context_type, device_type, nop));
Stmt set_device = IfThenElse::make(
device_type != kDLCPU, Evaluate::make(Call::make(
- Int(32), intrinsic::tvm_call_packed,
+ DataType::Int(32), intrinsic::tvm_call_packed,
{StringImm::make(runtime::symbol::tvm_set_device),
device_type, device_id}, Call::Intrinsic)));
body = Block::make(set_device, body);
@@ -215,7 +216,7 @@ class DeviceTypeBinder: public IRMutator {
if (op->attr_key == attr::device_context_type) {
if (const Variable* var = op->value.as<Variable>()) {
var_ = var;
- Expr value = make_const(op->value.type(), device_type_);
+ Expr value = make_const(op->value.dtype(), device_type_);
Stmt body = IRMutator::Mutate_(op, s);
var_ = nullptr;
std::ostringstream os;
@@ -245,14 +246,14 @@ class DeviceTypeBinder: public IRMutator {
Expr res = IRMutator::Mutate_(op, e);
op = res.as<NE>();
if (ir::Equal(op->a, op->b)) {
- return make_const(op->type, false);
+ return make_const(op->dtype, false);
}
return res;
}
Expr Mutate_(const Variable* op, const Expr& e) final {
if (op == var_) {
- return make_const(op->type, device_type_);
+ return make_const(op->dtype, device_type_);
} else {
return e;
}
diff --git a/src/pass/narrow_channel_access.cc b/src/pass/narrow_channel_access.cc
index 13c4e51..6687512 100644
--- a/src/pass/narrow_channel_access.cc
+++ b/src/pass/narrow_channel_access.cc
@@ -93,7 +93,7 @@ class ChannelAccessIndexRewriter : public IRMutator {
op = expr.as<Load>();
if (read_access_ && buf_var_ == op->buffer_var.get()) {
return Load::make(
- op->type, op->buffer_var, ir::Simplify(op->index - min_),
+ op->dtype, op->buffer_var, ir::Simplify(op->index - min_),
op->predicate);
} else {
return expr;
diff --git a/src/pass/rewrite_unsafe_select.cc b/src/pass/rewrite_unsafe_select.cc
index 25ed039..43e3005 100644
--- a/src/pass/rewrite_unsafe_select.cc
+++ b/src/pass/rewrite_unsafe_select.cc
@@ -115,12 +115,12 @@ class UnsafeSelectRewriter : public IRMutator {
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Select>();
UnsafeExprDetector unsafe;
- bool cond_is_scalar_bool = op->condition.type().is_bool() && op->condition.type().is_scalar();
+ bool cond_is_scalar_bool = op->condition.dtype().is_bool() && op->condition.dtype().is_scalar();
if ((unsafe.VisitExpr(op->true_value) ||
unsafe.VisitExpr(op->false_value)) &&
cond_is_scalar_bool) {
return Call::make(
- op->type,
+ op->dtype,
intrinsic::tvm_if_then_else,
{op->condition, op->true_value, op->false_value},
Call::Intrinsic);
diff --git a/src/pass/split_host_device.cc b/src/pass/split_host_device.cc
index 2239e5a..4cbc258 100644
--- a/src/pass/split_host_device.cc
+++ b/src/pass/split_host_device.cc
@@ -165,7 +165,7 @@ class IRUseDefAnalysis : public IRMutator {
class HostDeviceSplitter : public IRMutator {
public:
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
- handle_data_type_[op->buffer_var.get()] = make_const(op->type, 0);
+ handle_data_type_[op->buffer_var.get()] = make_const(op->dtype, 0);
return IRMutator::Mutate_(op, s);
}
@@ -209,7 +209,7 @@ class HostDeviceSplitter : public IRMutator {
n->thread_axis = m.thread_axis_;
// Strictly order the arguments: Var pointers, positional arguments.
for (Var v : m.undefined_) {
- if (v.type().is_handle()) {
+ if (v.dtype().is_handle()) {
n->args.push_back(v);
// mark handle data type.
auto it = handle_data_type_.find(v.get());
@@ -219,7 +219,7 @@ class HostDeviceSplitter : public IRMutator {
}
}
for (Var v : m.undefined_) {
- if (!v.type().is_handle()) {
+ if (!v.dtype().is_handle()) {
n->args.push_back(v);
}
}
@@ -234,7 +234,7 @@ class HostDeviceSplitter : public IRMutator {
}
device_funcs_.emplace_back(f_device);
return Evaluate::make(Call::make(
- Int(32), intrinsic::tvm_call_packed,
+ DataType::Int(32), intrinsic::tvm_call_packed,
call_args, Call::Intrinsic));
}
diff --git a/src/pass/split_pipeline.cc b/src/pass/split_pipeline.cc
index 7aefb1b..549b4c6 100644
--- a/src/pass/split_pipeline.cc
+++ b/src/pass/split_pipeline.cc
@@ -116,7 +116,7 @@ class MarkChannelAccess : public IRMutator {
int32_t csize = op->constant_allocation_size();
Expr alloc_size;
if (csize > 0) {
- alloc_size = IntImm::make(Int(32), csize);
+ alloc_size = IntImm::make(DataType::Int(32), csize);
} else {
alloc_size = op->extents[0];
for (size_t i = 1; i < op->extents.size(); ++i) {
@@ -183,17 +183,17 @@ class StageSplitter : public IRMutator {
std::ostringstream cname;
cname << "fifo." << temp_fifo_count_++;
// Create FIFO channel for load.
- Channel ch = ChannelNode::make(Var(cname.str(), Handle()), op->type);
+ Channel ch = ChannelNode::make(Var(cname.str(), DataType::Handle()), op->dtype);
Expr index = Mutate(op->index);
Stmt provide = Store::make(
ch->handle_var,
- Load::make(op->type, op->buffer_var, index, op->predicate),
+ Load::make(op->dtype, op->buffer_var, index, op->predicate),
0, op->predicate);
Stmt temp = nest_.back(); nest_.pop_back();
stages_.emplace_back(BuildStage(provide, ch));
nest_.push_back(temp);
fifo_map_[ch->handle_var.get()] = ch;
- return Load::make(op->type, ch->handle_var, 0, op->predicate);
+ return Load::make(op->dtype, ch->handle_var, 0, op->predicate);
}
Stmt Split(Stmt stmt, const ProducerConsumer* env) {
@@ -246,7 +246,7 @@ class StageSplitter : public IRMutator {
} else if (s.as<Block>()) {
} else if (const Allocate* op = s.as<Allocate>()) {
nest.emplace_back(Allocate::make(
- op->buffer_var, op->type, op->extents,
+ op->buffer_var, op->dtype, op->extents,
op->condition, no_op, op->new_expr, op->free_function));
MarkChannel(op);
} else {
@@ -256,11 +256,11 @@ class StageSplitter : public IRMutator {
body = Substitute(MergeNest(nest, body), subst);
return AttrStmt::make(
target, ir::attr::pipeline_stage_scope,
- make_const(Int(32), stage_index), body);
+ make_const(DataType::Int(32), stage_index), body);
}
void MarkChannel(const Allocate* op) {
if (!cmap_.count(op->buffer_var.get())) {
- Channel ch = ChannelNode::make(Var(op->buffer_var), op->type);
+ Channel ch = ChannelNode::make(Var(op->buffer_var), op->dtype);
cmap_[op->buffer_var.get()] = ch;
}
}
diff --git a/src/pass/ssa.cc b/src/pass/ssa.cc
index 83fc032..0fff1e6 100644
--- a/src/pass/ssa.cc
+++ b/src/pass/ssa.cc
@@ -83,7 +83,7 @@ class IRConvertSSA final : public IRMutator {
const VarExpr& v = op->var;
if (defined_.count(v.get())) {
Expr value = IRMutator::Mutate(op->value);
- VarExpr new_var = Variable::make(v.type(), v->name_hint);
+ VarExpr new_var = Variable::make(v.dtype(), v->name_hint);
scope_[v.get()].push_back(new_var);
Expr body = IRMutator::Mutate(op->body);
scope_[v.get()].pop_back();
@@ -98,7 +98,7 @@ class IRConvertSSA final : public IRMutator {
op = expr.as<Load>();
if (scope_.count(op->buffer_var.get())) {
return Load::make(
- op->type, scope_[op->buffer_var.get()].back(),
+ op->dtype, scope_[op->buffer_var.get()].back(),
op->index, op->predicate);
} else {
return expr;
@@ -119,7 +119,7 @@ class IRConvertSSA final : public IRMutator {
const VarExpr& v = op->var;
if (defined_.count(v.get())) {
Expr value = IRMutator::Mutate(op->value);
- VarExpr new_var = Variable::make(v.type(), v->name_hint);
+ VarExpr new_var = Variable::make(v.dtype(), v->name_hint);
scope_[v.get()].push_back(new_var);
Stmt body = IRMutator::Mutate(op->body);
scope_[v.get()].pop_back();
@@ -132,7 +132,7 @@ class IRConvertSSA final : public IRMutator {
Stmt Mutate_(const For* op, const Stmt& s) final {
const VarExpr& v = op->loop_var;
if (defined_.count(v.get())) {
- VarExpr new_var = Variable::make(v.type(), v->name_hint);
+ VarExpr new_var = Variable::make(v.dtype(), v->name_hint);
scope_[v.get()].push_back(new_var);
Stmt stmt = IRMutator::Mutate_(op, s);
scope_[v.get()].pop_back();
@@ -147,13 +147,13 @@ class IRConvertSSA final : public IRMutator {
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
const VarExpr& v = op->buffer_var;
if (defined_.count(v.get())) {
- VarExpr new_var = Variable::make(v.type(), v->name_hint);
+ VarExpr new_var = Variable::make(v.dtype(), v->name_hint);
scope_[v.get()].push_back(new_var);
Stmt stmt = IRMutator::Mutate_(op, s);
scope_[v.get()].pop_back();
op = stmt.as<Allocate>();
return Allocate::make(
- new_var, op->type, op->extents, op->condition,
+ new_var, op->dtype, op->extents, op->condition,
op->body, op->new_expr, op->free_function);
} else {
defined_.insert(v.get());
diff --git a/src/pass/storage_access.cc b/src/pass/storage_access.cc
index 6f9e18f..c146a87 100644
--- a/src/pass/storage_access.cc
+++ b/src/pass/storage_access.cc
@@ -40,7 +40,7 @@ void StorageAccessVisitor::Visit_(const Load* op) {
AccessEntry e;
e.threads = env_threads();
e.buffer = op->buffer_var;
- e.dtype = op->type.element_of();
+ e.dtype = op->dtype.element_of();
e.touched = arith::IntSet::vector(op->index);
e.type = kRead;
e.scope = scope;
@@ -60,7 +60,7 @@ void StorageAccessVisitor::Visit_(const Store* op) {
AccessEntry e;
e.threads = env_threads();
e.buffer = op->buffer_var;
- e.dtype = op->value.type().element_of();
+ e.dtype = op->value.dtype().element_of();
e.touched = arith::IntSet::vector(op->index);
e.type = kWrite;
e.scope = scope;
@@ -186,7 +186,7 @@ void StorageAccessVisitor::Visit_(const Call* op) {
IRVisitor::Visit_(l);
} else if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
CHECK_EQ(op->args.size(), 5U);
- Type dtype = op->args[0].type();
+ DataType dtype = op->args[0].dtype();
const Variable* buffer = op->args[1].as<Variable>();
Expr offset = op->args[2];
Expr extent = op->args[3];
@@ -251,7 +251,7 @@ class StorageAccessInfoLower : public IRMutator {
<< "Double allocation of " << it->second.scope.to_string();
if (info->head_address.defined()) {
return Allocate::make(
- op->buffer_var, op->type, op->extents, op->condition,
+ op->buffer_var, op->dtype, op->extents, op->condition,
op->body, info->head_address, "nop");
}
return op->body;
@@ -292,24 +292,24 @@ class StorageAccessInfoLower : public IRMutator {
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Call>();
CHECK_EQ(op->args.size(), 5U);
- Type dtype = op->args[0].type();
+ DataType dtype = op->args[0].dtype();
const Variable* buffer = op->args[1].as<Variable>();
Var buffer_var = Downcast<Var>(op->args[1]);
Expr offset = op->args[2];
auto it = storage_info_.find(buffer);
if (it != storage_info_.end() && it->second.info.defined()) {
return MakeTaggedAccessPtr(
- op->type, buffer_var, dtype, offset,
+ op->dtype, buffer_var, dtype, offset,
it->second.info);
}
- CHECK(op->type.is_handle());
+ CHECK(op->dtype.is_handle());
// Change to address_of
return AddressOffset(buffer_var, dtype, offset);
}
- Expr MakeTaggedAccessPtr(Type ptr_type,
+ Expr MakeTaggedAccessPtr(DataType ptr_type,
Var buffer_var,
- Type dtype,
+ DataType dtype,
Expr offset,
const MemoryInfo& info) {
if (ptr_type.is_handle()) {
@@ -321,7 +321,7 @@ class StorageAccessInfoLower : public IRMutator {
CHECK_EQ(info->unit_bits % dtype_bits, 0);
return cast(ptr_type,
ir::Simplify(offset / make_const(
- offset.type(), info->unit_bits / dtype_bits)));
+ offset.dtype(), info->unit_bits / dtype_bits)));
}
// The storage entry.
struct StorageEntry {
diff --git a/src/pass/storage_access.h b/src/pass/storage_access.h
index 8832b52..028645b 100644
--- a/src/pass/storage_access.h
+++ b/src/pass/storage_access.h
@@ -58,7 +58,7 @@ class StorageAccessVisitor : public IRVisitor {
/*! \brief The buffer variable, if any */
Var buffer = NullValue<Var>();
/*! \brief The access data type */
- Type dtype;
+ DataType dtype;
/*! \brief The touched access range */
arith::IntSet touched;
/*! \brief The type of access */
diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc
index 3851e1c..d6dde29 100644
--- a/src/pass/storage_flatten.cc
+++ b/src/pass/storage_flatten.cc
@@ -137,7 +137,7 @@ class StorageFlattener : public IRMutator {
<< "Read a buffer that is already out of scope";
if (is_opengl_) {
return Evaluate::make(Call::make(
- Type(),
+ DataType(),
Call::glsl_texture_store,
{e.buffer->data, op->value},
Call::Intrinsic));
@@ -190,12 +190,12 @@ class StorageFlattener : public IRMutator {
// use small alignment for small arrays
int32_t const_size = Allocate::constant_allocation_size(shape);
- int align = GetTempAllocaAlignment(op->type, const_size);
+ int align = GetTempAllocaAlignment(op->dtype, const_size);
if (skey.tag.length() != 0) {
MemoryInfo info = GetMemoryInfo(skey.to_string());
if (info.defined()) {
- align = (info->max_simd_bits + op->type.bits() - 1) / op->type.bits();
- CHECK_LE(const_size * op->type.bits(), info->max_num_bits)
+ align = (info->max_simd_bits + op->dtype.bits() - 1) / op->dtype.bits();
+ CHECK_LE(const_size * op->dtype.bits(), info->max_num_bits)
<< "Allocation exceed bound of memory tag " << skey.to_string();
}
}
@@ -204,12 +204,12 @@ class StorageFlattener : public IRMutator {
std::vector<Expr> rstrides;
const std::vector<DimAlignInfo>& avec = dim_align_[key];
int first_dim = 0;
- Expr stride = make_const(shape[first_dim].type(), 1);
+ Expr stride = make_const(shape[first_dim].dtype(), 1);
for (size_t i = shape.size(); i != 0; --i) {
size_t dim = i - 1;
if (dim < avec.size() && avec[dim].align_factor != 0) {
- Expr factor = make_const(stride.type(), avec[dim].align_factor);
- Expr offset = make_const(stride.type(), avec[dim].align_offset);
+ Expr factor = make_const(stride.dtype(), avec[dim].align_factor);
+ Expr offset = make_const(stride.dtype(), avec[dim].align_offset);
stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
stride = ir::Simplify(stride);
}
@@ -220,8 +220,8 @@ class StorageFlattener : public IRMutator {
}
e.buffer = BufferNode::make(
- Var(key.GetName(), Handle()),
- op->type, shape, strides, Expr(),
+ Var(key.GetName(), DataType::Handle()),
+ op->dtype, shape, strides, Expr(),
key.GetName(), skey.to_string(),
align, 0, kDefault);
@@ -230,26 +230,26 @@ class StorageFlattener : public IRMutator {
buf_map_[key].released = true;
Stmt ret;
- Type storage_type = e.buffer->dtype;
+ DataType storage_type = e.buffer->dtype;
// specially handle bool, lower its storage
- // type to be Int(8)(byte)
- if (storage_type == Bool()) {
- storage_type = Int(8);
+ // type to be DataType::Int(8) (byte)
+ if (storage_type == DataType::Bool()) {
+ storage_type = DataType::Int(8);
}
if (strides.size() != 0) {
int first_dim = 0;
ret = Allocate::make(
e.buffer->data, storage_type,
{e.buffer->strides[first_dim] * e.buffer->shape[first_dim]},
- make_const(Bool(e.buffer->dtype.lanes()), true), body);
+ make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body);
} else {
shape = e.buffer->shape;
if (shape.size() == 0) {
- shape.push_back(make_const(Int(32), 1));
+ shape.push_back(make_const(DataType::Int(32), 1));
}
ret = Allocate::make(
e.buffer->data, storage_type, shape,
- make_const(Bool(e.buffer->dtype.lanes()), true), body);
+ make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body);
}
ret = AttrStmt::make(
e.buffer->data, attr::storage_scope,
@@ -271,7 +271,7 @@ class StorageFlattener : public IRMutator {
!it->second.same_as(op->buffer_var)) {
CHECK(it->second.as<Variable>());
VarExpr buf_var = Downcast<VarExpr>(it->second);
- return Load::make(op->type, buf_var, op->index, op->predicate);
+ return Load::make(op->dtype, buf_var, op->index, op->predicate);
} else {
return expr;
}
@@ -342,10 +342,12 @@ class StorageFlattener : public IRMutator {
args.push_back(op->bounds[i]->min);
}
auto &func_name = op->func->func_name();
- vars.push_back(VarExpr("prefetch." + func_name + "." + std::to_string(starts), Int(32)));
+ vars.push_back(VarExpr(
+ "prefetch." + func_name + "." + std::to_string(starts), DataType::Int(32)));
args.push_back(op->bounds[starts]->min + stride * vars.back());
for (int i = starts - 1; i >= 0; --i) {
- vars.push_back(VarExpr("prefetch." + func_name + "." + std::to_string(i), Int(32)));
+ vars.push_back(VarExpr(
+ "prefetch." + func_name + "." + std::to_string(i), DataType::Int(32)));
args.push_back(vars.back() + op->bounds[i]->min);
}
for (int i = starts; i >= 0; --i) {
@@ -354,8 +356,8 @@ class StorageFlattener : public IRMutator {
vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, stmt);
} else {
Expr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype);
- Expr address = Call::make(Handle(), tvm_address_of, {load}, Call::PureIntrinsic);
- Expr prefetch = Call::make(op->type, Call::prefetch, {address, 0, 3, 1}, Call::Intrinsic);
+ Expr address = Call::make(DataType::Handle(), tvm_address_of, {load}, Call::PureIntrinsic);
+ Expr prefetch = Call::make(op->dtype, Call::prefetch, {address, 0, 3, 1}, Call::Intrinsic);
stmt = Evaluate::make(prefetch);
Expr extent = (op->bounds[i]->extent - 1) / stride + 1;
stmt = For::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt);
@@ -484,7 +486,7 @@ class StorageFlattener : public IRMutator {
return false;
for (size_t i = 0; i < shape.size(); ++i) {
- if (!shape[i].defined() || !shape[i].type().is_scalar() ||
+ if (!shape[i].defined() || !shape[i].dtype().is_scalar() ||
is_negative_const(shape[i])) {
return false;
}
@@ -492,12 +494,12 @@ class StorageFlattener : public IRMutator {
return true;
}
- Expr MakeBound(const Type &type, const Array<Expr> &shape) {
+ Expr MakeBound(const DataType &type, const Array<Expr> &shape) {
// We have already checked the shape size to be greater then 0.
- Expr bound = Mul::make(make_const(shape[0].type(), type.lanes()), shape[0]);
+ Expr bound = Mul::make(make_const(shape[0].dtype(), type.lanes()), shape[0]);
for (size_t i = 1; i < shape.size(); ++i) {
bound = Mul::make(
- bound, Mul::make(make_const(bound.type(), type.lanes()), shape[i]));
+ bound, Mul::make(make_const(bound.dtype(), type.lanes()), shape[i]));
}
return bound;
}
diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc
index 18b6634..12a06da 100644
--- a/src/pass/storage_rewrite.cc
+++ b/src/pass/storage_rewrite.cc
@@ -306,7 +306,7 @@ class InplaceOpVerifier : public IRVisitor {
}
if (src_ == buf) {
if (store_ == nullptr ||
- store_->value.type() != op->type ||
+ store_->value.dtype() != op->dtype ||
!ir::Equal(store_->index, op->index)) {
result_ = false; return;
}
@@ -370,7 +370,7 @@ class StoragePlanRewriter : public IRMutator {
if (it == alloc_map_.end()) return stmt;
return Store::make(it->second->alloc_var,
op->value,
- RemapIndex(op->value.type(), op->index, it->second),
+ RemapIndex(op->value.dtype(), op->index, it->second),
op->predicate);
}
Expr Mutate_(const Load* op, const Expr& e) final {
@@ -378,9 +378,9 @@ class StoragePlanRewriter : public IRMutator {
op = expr.as<Load>();
auto it = alloc_map_.find(op->buffer_var.get());
if (it == alloc_map_.end()) return expr;
- return Load::make(op->type,
+ return Load::make(op->dtype,
it->second->alloc_var,
- RemapIndex(op->type, op->index, it->second),
+ RemapIndex(op->dtype, op->index, it->second),
op->predicate);
}
Expr Mutate_(const Variable* op, const Expr& e) final {
@@ -397,7 +397,7 @@ class StoragePlanRewriter : public IRMutator {
Expr Mutate_(const Call* op, const Expr& e) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
CHECK_EQ(op->args.size(), 5U);
- Type dtype = op->args[0].type();
+ DataType dtype = op->args[0].dtype();
const Variable* buffer = op->args[1].as<Variable>();
auto it = alloc_map_.find(buffer);
if (it == alloc_map_.end()) return IRMutator::Mutate_(op, e);
@@ -407,10 +407,10 @@ class StoragePlanRewriter : public IRMutator {
uint64_t elem_bits = dtype.bits() * dtype.lanes();
CHECK_EQ(se->bits_offset % elem_bits, 0U);
if (se->bits_offset != 0) {
- offset = make_const(offset.type(), se->bits_offset / elem_bits) + offset;
+ offset = make_const(offset.dtype(), se->bits_offset / elem_bits) + offset;
}
return Call::make(
- op->type, op->name,
+ op->dtype, op->name,
{op->args[0], se->alloc_var, offset, extent, op->args[4]},
op->call_type);
} else {
@@ -485,7 +485,7 @@ class StoragePlanRewriter : public IRMutator {
// The var expr of new allocation.
VarExpr alloc_var;
// The allocation element type.
- Type elem_type;
+ DataType elem_type;
// This is non-zero if this allocate is folded into another one
// the address(in bits) becomes alloc_var + bits_offset;
// can be effectively converted to the element type.
@@ -524,11 +524,11 @@ class StoragePlanRewriter : public IRMutator {
return MergeNest(nest, body);
}
// Remap the index
- Expr RemapIndex(Type dtype, Expr index, StorageEntry* e) {
+ Expr RemapIndex(DataType dtype, Expr index, StorageEntry* e) {
if (e->bits_offset == 0) return index;
uint64_t elem_bits = dtype.bits() * dtype.lanes();
CHECK_EQ(e->bits_offset % elem_bits, 0U);
- return make_const(index.type(), e->bits_offset / elem_bits) + index;
+ return make_const(index.dtype(), e->bits_offset / elem_bits) + index;
}
// Prepare the new allocations
void PrepareNewAlloc() {
@@ -564,16 +564,16 @@ class StoragePlanRewriter : public IRMutator {
}
// Get the allocation size;
e->alloc_var = e->allocs[0]->buffer_var;
- Type alloc_type = e->allocs[0]->type;
+ DataType alloc_type = e->allocs[0]->dtype;
for (const Allocate* op : e->allocs) {
- if (op->type.lanes() > alloc_type.lanes()) {
- alloc_type = op->type;
+ if (op->dtype.lanes() > alloc_type.lanes()) {
+ alloc_type = op->dtype;
}
}
if (e->allocs.size() == 1) {
// simply use the original allocation.
Expr sz = arith::ComputeReduce<Mul>(e->allocs[0]->extents,
- make_const(Int(32), 1));
+ make_const(DataType::Int(32), 1));
e->new_alloc = Allocate::make(
e->alloc_var, alloc_type, {sz},
e->allocs[0]->condition, Evaluate::make(0));
@@ -587,8 +587,8 @@ class StoragePlanRewriter : public IRMutator {
// Build a merged allocation
Expr combo_size;
for (const Allocate* op : e->allocs) {
- Expr sz = arith::ComputeReduce<Mul>(op->extents, make_const(Int(32), 1));
- auto nbits = op->type.bits() * op->type.lanes();
+ Expr sz = arith::ComputeReduce<Mul>(op->extents, make_const(DataType::Int(32), 1));
+ auto nbits = op->dtype.bits() * op->dtype.lanes();
if (const auto* imm = sz.as<IntImm>()) {
if (imm->value > std::numeric_limits<int>::max() / nbits) {
LOG(WARNING) << "The allocation requires : " << imm->value
@@ -596,7 +596,7 @@ class StoragePlanRewriter : public IRMutator {
<< " bits, which is greater than the maximum of"
" int32. The size is cast to int64."
<< "\n";
- sz = make_const(Int(64), imm->value);
+ sz = make_const(DataType::Int(64), imm->value);
}
}
// transform to bits
@@ -613,7 +613,7 @@ class StoragePlanRewriter : public IRMutator {
combo_size = indexdiv(combo_size, type_bits);
// round up for can not divided
if (!divided) {
- combo_size = combo_size + make_const(Int(32), 1);
+ combo_size = combo_size + make_const(DataType::Int(32), 1);
}
combo_size = ir::Simplify(combo_size);
e->new_alloc = Allocate::make(
@@ -658,7 +658,7 @@ class StoragePlanRewriter : public IRMutator {
}
}
uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes();
- Expr alloc_size = make_const(e->allocs[0]->extents[0].type(),
+ Expr alloc_size = make_const(e->allocs[0]->extents[0].dtype(),
(total_bits + type_bits - 1) / type_bits);
e->new_alloc = Allocate::make(
e->alloc_var, e->elem_type, {alloc_size}, const_true(),
@@ -751,12 +751,12 @@ class StoragePlanRewriter : public IRMutator {
StorageEntry* src_entry = alloc_map_.at(src);
if (src_entry->scope == ae.storage_scope &&
src_entry->attach_scope_ == thread_scope_ &&
- src_entry->elem_type == ae.alloc->type.element_of() &&
+ src_entry->elem_type == ae.alloc->dtype.element_of() &&
visitor.Check(s.stmt, var, src)) {
uint64_t const_nbits =
static_cast<uint64_t>(ae.alloc->constant_allocation_size()) *
- ae.alloc->type.bits() *
- ae.alloc->type.lanes();
+ ae.alloc->dtype.bits() *
+ ae.alloc->dtype.lanes();
if (src_entry->const_nbits == const_nbits && !inplace_found) {
// successfully inplace
dst_entry = src_entry;
@@ -816,7 +816,7 @@ class StoragePlanRewriter : public IRMutator {
std::unique_ptr<StorageEntry> entry(new StorageEntry());
entry->attach_scope_ = attach_scope;
entry->scope = scope;
- entry->elem_type = op->type.element_of();
+ entry->elem_type = op->dtype.element_of();
entry->const_nbits = const_nbits;
StorageEntry* e = entry.get();
alloc_vec_.emplace_back(std::move(entry));
@@ -830,13 +830,13 @@ class StoragePlanRewriter : public IRMutator {
// skip plan for local variable,
// compiler can do a better job with register allocation.
const uint64_t match_range = 16;
- uint64_t op_elem_bits = op->type.bits() * op->type.lanes();
+ uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes();
uint64_t const_nbits = static_cast<uint64_t>(
op->constant_allocation_size() * op_elem_bits);
// disable reuse of small arrays, they will be lowered to registers in LLVM
// This rules only apply if we are using non special memory
if (scope.tag.length() == 0) {
- if (scope.rank >= StorageRank::kWarp || op->type.is_handle()) {
+ if (scope.rank >= StorageRank::kWarp || op->dtype.is_handle()) {
return NewAlloc(op, attach_scope, scope, const_nbits);
}
if (const_nbits > 0 && const_nbits <= 32) {
@@ -865,7 +865,7 @@ class StoragePlanRewriter : public IRMutator {
StorageEntry *e = it->second;
if (e->attach_scope_ != attach_scope) continue;
if (e->scope != scope) continue;
- if (e->elem_type != op->type.element_of()) continue;
+ if (e->elem_type != op->dtype.element_of()) continue;
e->const_nbits = std::max(const_nbits, e->const_nbits);
const_free_map_.erase(it);
return e;
@@ -877,7 +877,7 @@ class StoragePlanRewriter : public IRMutator {
StorageEntry* e = *it;
if (e->attach_scope_ != attach_scope) continue;
if (e->scope != scope) continue;
- if (e->elem_type != op->type.element_of()) continue;
+ if (e->elem_type != op->dtype.element_of()) continue;
sym_free_list_.erase(it);
return e;
}
@@ -896,7 +896,7 @@ class StoragePlanRewriter : public IRMutator {
if (e->scope.tag.length() == 0) {
// Disable sharing of local memory.
if (e->scope.rank >= StorageRank::kWarp ||
- e->allocs[0]->type.is_handle()) return;
+ e->allocs[0]->dtype.is_handle()) return;
// disable reuse of small arrays
if (e->const_nbits > 0 && e->const_nbits <= 32) return;
}
@@ -932,17 +932,17 @@ class StoragePlanRewriter : public IRMutator {
class VectorAllocRewriter : public IRMutator {
public:
Expr Mutate_(const Load* op, const Expr& e) final {
- UpdateTypeMap(op->buffer_var.get(), op->type);
+ UpdateTypeMap(op->buffer_var.get(), op->dtype);
return IRMutator::Mutate_(op, e);
}
Stmt Mutate_(const Store* op, const Stmt& s) final {
- UpdateTypeMap(op->buffer_var.get(), op->value.type());
+ UpdateTypeMap(op->buffer_var.get(), op->value.dtype());
return IRMutator::Mutate_(op, s);
}
Expr Mutate_(const Call* op, const Expr& e) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
- Type dtype = op->args[0].type();
+ DataType dtype = op->args[0].dtype();
const Variable* buffer = op->args[1].as<Variable>();
UpdateTypeMap(buffer, dtype);
}
@@ -955,15 +955,15 @@ class VectorAllocRewriter : public IRMutator {
const auto& tvec = acc_map_[op->buffer_var.get()];
if (tvec.size() == 1 &&
- tvec[0].element_of() == op->type.element_of() &&
- tvec[0].lanes() % op->type.lanes() == 0 &&
- tvec[0].lanes() != op->type.lanes()) {
- int factor = tvec[0].lanes() / op->type.lanes();
+ tvec[0].element_of() == op->dtype.element_of() &&
+ tvec[0].lanes() % op->dtype.lanes() == 0 &&
+ tvec[0].lanes() != op->dtype.lanes()) {
+ int factor = tvec[0].lanes() / op->dtype.lanes();
Array<Expr> extents = op->extents;
arith::ModularSet me = analyzer_.modular_set(extents[extents.size() - 1]);
if (me->base % factor == 0 && me->coeff % factor == 0) {
extents.Set(extents.size() - 1,
- extents[extents.size() - 1] / make_const(extents[0].type(), factor));
+ extents[extents.size() - 1] / make_const(extents[0].dtype(), factor));
return Allocate::make(
op->buffer_var, tvec[0], extents,
op->condition, op->body);
@@ -972,7 +972,7 @@ class VectorAllocRewriter : public IRMutator {
return stmt;
}
- void UpdateTypeMap(const Variable* buffer, Type t) {
+ void UpdateTypeMap(const Variable* buffer, DataType t) {
auto& tvec = acc_map_[buffer];
if (std::find(tvec.begin(), tvec.end(), t) == tvec.end()) {
tvec.push_back(t);
@@ -980,7 +980,7 @@ class VectorAllocRewriter : public IRMutator {
}
// Internal access map
- std::unordered_map<const Variable*, std::vector<Type> > acc_map_;
+ std::unordered_map<const Variable*, std::vector<DataType> > acc_map_;
// internal analyzer
arith::Analyzer analyzer_;
};
@@ -991,7 +991,7 @@ LoweredFunc PointerValueTypeRewrite(LoweredFunc f) {
VectorAllocRewriter rewriter;
n->body = rewriter.Mutate(n->body);
for (Var arg : f->args) {
- if (arg.type().is_handle()) {
+ if (arg.dtype().is_handle()) {
const auto& tvec = rewriter.acc_map_[arg.get()];
if (tvec.size() == 1) {
Expr dtype = make_const(tvec[0], 0);
diff --git a/src/pass/storage_sync.cc b/src/pass/storage_sync.cc
index 69b1a31..018a6bb 100644
--- a/src/pass/storage_sync.cc
+++ b/src/pass/storage_sync.cc
@@ -211,7 +211,7 @@ class ThreadSyncInserter : public IRMutator {
barrier = MakeGlobalBarrier();
} else {
barrier = Evaluate::make(
- Call::make(Int(32), intrinsic::tvm_storage_sync,
+ Call::make(DataType::Int(32), intrinsic::tvm_storage_sync,
{StringImm::make(sync_scope_.to_string())},
Call::Intrinsic));
}
@@ -303,7 +303,7 @@ class ThreadSyncInserter : public IRMutator {
CHECK(op != nullptr);
Array<Expr> pargs = {StringImm::make(runtime::symbol::tvm_prepare_global_barrier)};
Stmt prep = Evaluate::make(
- Call::make(Int(32), intrinsic::tvm_call_packed, pargs, Call::Intrinsic));
+ Call::make(DataType::Int(32), intrinsic::tvm_call_packed, pargs, Call::Intrinsic));
Stmt body = op->body;
for (const auto& kv : rw_stats_) {
const auto& e = kv.second;
@@ -313,7 +313,7 @@ class ThreadSyncInserter : public IRMutator {
}
rw_stats_.clear();
Stmt kinit = Evaluate::make(
- Call::make(Int(32), intrinsic::tvm_global_barrier_kinit, {}, Call::Intrinsic));
+ Call::make(DataType::Int(32), intrinsic::tvm_global_barrier_kinit, {}, Call::Intrinsic));
body = Block::make(kinit, body);
body = AttrStmt::make(
op->node, op->attr_key, op->value, body);
@@ -331,7 +331,7 @@ class ThreadSyncInserter : public IRMutator {
num_blocks_ = (num_blocks_.defined() ?
attr->value * num_blocks_ : attr->value);
} else if (s.rank == 1) {
- Expr cond = iv->var == make_zero(iv->var.type());
+ Expr cond = iv->var == make_zero(iv->var.dtype());
is_lead_ = is_lead_.defined() ? (is_lead_ && cond) : cond;
}
}
@@ -339,7 +339,7 @@ class ThreadSyncInserter : public IRMutator {
CHECK_EQ(num_work_dim_, thread_extents_.size());
}
return Evaluate::make(
- Call::make(Int(32), intrinsic::tvm_storage_sync,
+ Call::make(DataType::Int(32), intrinsic::tvm_storage_sync,
{StringImm::make(sync_scope_.to_string()),
is_lead_, num_blocks_},
Call::Intrinsic));
diff --git a/src/pass/tensor_core.cc b/src/pass/tensor_core.cc
index b855427..2ead2b9 100644
--- a/src/pass/tensor_core.cc
+++ b/src/pass/tensor_core.cc
@@ -60,11 +60,11 @@ std::string simplify_name(std::string input) {
}
}
-Expr unpack_type_cast(const Expr &input, const Type &target_type) {
... 3110 lines suppressed ...