You are viewing a plain-text version of this content. The canonical link to it was a hyperlink in the original page and is not preserved in this rendering.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/06/13 00:23:18 UTC
[incubator-tvm] branch master updated: [TIR][REFACTOR] Cleanup unused classes (#5789)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 1c256f4 [TIR][REFACTOR] Cleanup unused classes (#5789)
1c256f4 is described below
commit 1c256f48a415e7c775cbf2a892a3d8ca29e3d25d
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Fri Jun 12 17:23:05 2020 -0700
[TIR][REFACTOR] Cleanup unused classes (#5789)
---
include/tvm/arith/bound.h | 8 ++------
include/tvm/te/operation.h | 8 +++++---
include/tvm/te/tensor.h | 5 +++--
include/tvm/tir/expr.h | 34 ----------------------------------
include/tvm/tir/var.h | 2 --
src/arith/domain_touched.cc | 6 +++---
src/contrib/hybrid/codegen_hybrid.cc | 2 +-
src/te/schedule/graph.cc | 12 ++++++------
src/tir/transforms/inject_prefetch.cc | 2 +-
9 files changed, 21 insertions(+), 58 deletions(-)
diff --git a/include/tvm/arith/bound.h b/include/tvm/arith/bound.h
index df1a9e7..12b91cc 100644
--- a/include/tvm/arith/bound.h
+++ b/include/tvm/arith/bound.h
@@ -32,13 +32,9 @@
#include <unordered_map>
namespace tvm {
-// forward delcare Tensor
-namespace te {
-class Tensor;
-}
namespace arith {
-using tir::Domain;
+using tir::Region;
using tir::Stmt;
using tir::Var;
using tir::VarNode;
@@ -82,7 +78,7 @@ IntSet DeduceBound(PrimExpr v, PrimExpr cond,
* \param consider_stores If stores are considered.
* \return The domain that covers all the calls or provides within the given statement.
*/
-Domain DomainTouched(const Stmt& body, const tir::Buffer& buffer, bool consider_loads,
+Region DomainTouched(const Stmt& body, const tir::Buffer& buffer, bool consider_loads,
bool consider_stores);
} // namespace arith
diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h
index 4b7037a..dbd07fa 100644
--- a/include/tvm/te/operation.h
+++ b/include/tvm/te/operation.h
@@ -53,7 +53,7 @@ struct TensorDom {
/*!
* \brief Base class of all operation nodes
*/
-class OperationNode : public tir::FunctionBaseNode {
+class OperationNode : public Object {
public:
/*! \brief optional name of the operation */
std::string name;
@@ -61,8 +61,10 @@ class OperationNode : public tir::FunctionBaseNode {
std::string tag;
/*! \brief additional attributes of the operation*/
Map<String, ObjectRef> attrs;
- /*! \return name of the operation */
- const std::string& func_name() const final { return name; }
+ // virtual destructor.
+ virtual ~OperationNode() {}
+ /*! \return number of outputs */
+ virtual int num_outputs() const = 0;
/*!
* \return The list of iteration variable at root
* \note root_iter_vars decides the shape of the outputs.
diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h
index 0c4af4b..2f9fa2f 100644
--- a/include/tvm/te/tensor.h
+++ b/include/tvm/te/tensor.h
@@ -42,13 +42,14 @@ using namespace tvm::tir;
// internal node container for Operation
class OperationNode;
+class Tensor;
/*! \brief Operation that produces tensors */
-class Operation : public tir::FunctionRef {
+class Operation : public ObjectRef {
public:
/*! \brief default constructor */
Operation() {}
- explicit Operation(ObjectPtr<Object> n) : FunctionRef(n) {}
+ explicit Operation(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h
index 423f09e..4b6b28d 100644
--- a/include/tvm/tir/expr.h
+++ b/include/tvm/tir/expr.h
@@ -870,40 +870,6 @@ class Let : public PrimExpr {
TVM_DEFINE_OBJECT_REF_METHODS(Let, PrimExpr, LetNode);
};
-// Call node, represent a function call or a multi-dimensional array load.
-//
-// TODO(tvm-team):
-// Refactor call with more explicit property registrations.
-// rather than calling a string symbol.
-// We should move most information into function itself and remove name.
-
-/*! \brief Base node of internal functions. */
-class FunctionBaseNode : public Object {
- public:
- /*! \brief virtual destructor */
- virtual ~FunctionBaseNode() {}
- /*! \return the name of the function */
- virtual const std::string& func_name() const = 0;
- /*! \return the number of outputs of this function */
- virtual int num_outputs() const = 0;
-
- // fall back to pointer equality now before refactor.
- bool SEqualReduce(const FunctionBaseNode* other, SEqualReducer equal) const {
- return this == other;
- }
-
- void SHashReduce(SHashReducer hash_reduce) const {}
-
- static constexpr const bool _type_has_method_sequal_reduce = true;
- static constexpr const bool _type_has_method_shash_reduce = true;
-};
-
-/*! \brief reference to a function */
-class FunctionRef : public ObjectRef {
- public:
- TVM_DEFINE_OBJECT_REF_METHODS(FunctionRef, ObjectRef, FunctionBaseNode);
-};
-
/*!
* \brief Call node.
*/
diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h
index 363bf6b..9f09824 100644
--- a/include/tvm/tir/var.h
+++ b/include/tvm/tir/var.h
@@ -226,8 +226,6 @@ enum IterVarType : int {
kTensorized = 8
};
-using Domain = Array<Range>;
-
/*!
* \brief An iteration variable representing an iteration
* over a one dimensional interval.
diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc
index 0ac4a89..b44d9f7 100644
--- a/src/arith/domain_touched.cc
+++ b/src/arith/domain_touched.cc
@@ -40,9 +40,9 @@ class BufferTouchedDomain final : public StmtExprVisitor {
BufferTouchedDomain(const Buffer& buffer, bool consider_loads, bool consider_stores)
: buffer_(buffer), consider_loads_(consider_loads), consider_stores_(consider_stores) {}
- Domain Find(const Stmt& stmt) {
+ Region Find(const Stmt& stmt) {
operator()(stmt);
- Domain ret;
+ Region ret;
Range none;
for (size_t i = 0; i < bounds_.size(); ++i) {
ret.push_back(arith::Union(bounds_[i]).cover_range(none));
@@ -107,7 +107,7 @@ class BufferTouchedDomain final : public StmtExprVisitor {
std::unordered_map<const VarNode*, IntSet> dom_map_;
};
-Domain DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads,
+Region DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads,
bool consider_stores) {
return BufferTouchedDomain(buffer, consider_loads, consider_stores).Find(stmt);
}
diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc
index e9ec585..e08f39f 100644
--- a/src/contrib/hybrid/codegen_hybrid.cc
+++ b/src/contrib/hybrid/codegen_hybrid.cc
@@ -414,7 +414,7 @@ std::string CodeGenHybrid::GetTensorID(const Tensor& tensor) {
if (id_map_.count(key)) {
return id_map_[key];
}
- std::string name_hint = tensor->op->func_name();
+ std::string name_hint = tensor->op->name;
if (tensor->op->num_outputs() > 1) {
name_hint += "_v" + std::to_string(tensor->value_index);
}
diff --git a/src/te/schedule/graph.cc b/src/te/schedule/graph.cc
index 62557ed..09e8995 100644
--- a/src/te/schedule/graph.cc
+++ b/src/te/schedule/graph.cc
@@ -36,15 +36,15 @@ namespace tvm {
namespace te {
// key to specific tensor dimension.
struct TensorDimKey {
- tir::FunctionRef f;
+ Operation op;
int value_index;
int dim;
TensorDimKey() {}
- TensorDimKey(const Tensor& t, int dim) : f(t->op), value_index(t->value_index), dim(dim) {}
+ TensorDimKey(const Tensor& t, int dim) : op(t->op), value_index(t->value_index), dim(dim) {}
TensorDimKey(const Tensor& t, size_t dim)
- : f(t->op), value_index(t->value_index), dim(static_cast<int>(dim)) {}
+ : op(t->op), value_index(t->value_index), dim(static_cast<int>(dim)) {}
inline bool operator==(const TensorDimKey& other) const {
- return f == other.f && value_index == other.value_index && dim == other.dim;
+ return op == other.op && value_index == other.value_index && dim == other.dim;
}
inline bool operator!=(const TensorDimKey& other) const { return !operator==(other); }
};
@@ -55,7 +55,7 @@ namespace std {
template <>
struct hash<::tvm::te::TensorDimKey> {
std::size_t operator()(const ::tvm::te::TensorDimKey& k) const {
- size_t lhs = ::tvm::ObjectPtrHash()(k.f);
+ size_t lhs = ::tvm::ObjectPtrHash()(k.op);
size_t rhs = static_cast<size_t>(k.value_index) << 16UL | static_cast<size_t>(k.dim);
lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
return lhs;
@@ -378,7 +378,7 @@ Map<IterVar, PrimExpr> ScanFixPointAnalysis(const Operation& scan_op) {
if (k != target && place_holder_ref.count(k)) break;
stack.pop_back();
if (!reach.count(k)) {
- LOG(FATAL) << "cannot find reach of " << k.f << "-" << k.dim;
+ LOG(FATAL) << "cannot find reach of " << k.op << "-" << k.dim;
}
for (TensorDimKey kk : reach.at(k)) {
diff --git a/src/tir/transforms/inject_prefetch.cc b/src/tir/transforms/inject_prefetch.cc
index 3b626f0..9c27a71 100644
--- a/src/tir/transforms/inject_prefetch.cc
+++ b/src/tir/transforms/inject_prefetch.cc
@@ -45,7 +45,7 @@ class PrefetchInjector : public StmtMutator {
if (op && op->attr_key == attr::prefetch_scope) {
Buffer buffer = Downcast<Buffer>(op->node);
CHECK_NE(loop_nest_.size(), 0U);
- Domain domain = DomainTouched(op->body, buffer, true, false);
+ Region domain = DomainTouched(op->body, buffer, true, false);
Region region;
auto iter_var = loop_nest_.back().get();