You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/06/13 00:23:18 UTC

[incubator-tvm] branch master updated: [TIR][REFACTOR] Cleanup unused classes (#5789)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 1c256f4  [TIR][REFACTOR] Cleanup unused classes (#5789)
1c256f4 is described below

commit 1c256f48a415e7c775cbf2a892a3d8ca29e3d25d
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Fri Jun 12 17:23:05 2020 -0700

    [TIR][REFACTOR] Cleanup unused classes (#5789)
---
 include/tvm/arith/bound.h             |  8 ++------
 include/tvm/te/operation.h            |  8 +++++---
 include/tvm/te/tensor.h               |  5 +++--
 include/tvm/tir/expr.h                | 34 ----------------------------------
 include/tvm/tir/var.h                 |  2 --
 src/arith/domain_touched.cc           |  6 +++---
 src/contrib/hybrid/codegen_hybrid.cc  |  2 +-
 src/te/schedule/graph.cc              | 12 ++++++------
 src/tir/transforms/inject_prefetch.cc |  2 +-
 9 files changed, 21 insertions(+), 58 deletions(-)

diff --git a/include/tvm/arith/bound.h b/include/tvm/arith/bound.h
index df1a9e7..12b91cc 100644
--- a/include/tvm/arith/bound.h
+++ b/include/tvm/arith/bound.h
@@ -32,13 +32,9 @@
 #include <unordered_map>
 
 namespace tvm {
-// forward delcare Tensor
-namespace te {
-class Tensor;
-}
 namespace arith {
 
-using tir::Domain;
+using tir::Region;
 using tir::Stmt;
 using tir::Var;
 using tir::VarNode;
@@ -82,7 +78,7 @@ IntSet DeduceBound(PrimExpr v, PrimExpr cond,
  * \param consider_stores If stores are considered.
  * \return The domain that covers all the calls or provides within the given statement.
  */
-Domain DomainTouched(const Stmt& body, const tir::Buffer& buffer, bool consider_loads,
+Region DomainTouched(const Stmt& body, const tir::Buffer& buffer, bool consider_loads,
                      bool consider_stores);
 
 }  // namespace arith
diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h
index 4b7037a..dbd07fa 100644
--- a/include/tvm/te/operation.h
+++ b/include/tvm/te/operation.h
@@ -53,7 +53,7 @@ struct TensorDom {
 /*!
  * \brief Base class of all operation nodes
  */
-class OperationNode : public tir::FunctionBaseNode {
+class OperationNode : public Object {
  public:
   /*! \brief optional name of the operation */
   std::string name;
@@ -61,8 +61,10 @@ class OperationNode : public tir::FunctionBaseNode {
   std::string tag;
   /*! \brief additional attributes of the operation*/
   Map<String, ObjectRef> attrs;
-  /*! \return name of the operation */
-  const std::string& func_name() const final { return name; }
+  // virtual destructor.
+  virtual ~OperationNode() {}
+  /*! \return number of outputs */
+  virtual int num_outputs() const = 0;
   /*!
    * \return The list of iteration variable at root
    * \note root_iter_vars decides the shape of the outputs.
diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h
index 0c4af4b..2f9fa2f 100644
--- a/include/tvm/te/tensor.h
+++ b/include/tvm/te/tensor.h
@@ -42,13 +42,14 @@ using namespace tvm::tir;
 
 // internal node container for Operation
 class OperationNode;
+class Tensor;
 
 /*! \brief Operation that produces tensors */
-class Operation : public tir::FunctionRef {
+class Operation : public ObjectRef {
  public:
   /*! \brief default constructor  */
   Operation() {}
-  explicit Operation(ObjectPtr<Object> n) : FunctionRef(n) {}
+  explicit Operation(ObjectPtr<Object> n) : ObjectRef(n) {}
   /*!
    * \brief access the internal node container
    * \return the pointer to the internal node container
diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h
index 423f09e..4b6b28d 100644
--- a/include/tvm/tir/expr.h
+++ b/include/tvm/tir/expr.h
@@ -870,40 +870,6 @@ class Let : public PrimExpr {
   TVM_DEFINE_OBJECT_REF_METHODS(Let, PrimExpr, LetNode);
 };
 
-// Call node, represent a function call or a multi-dimensional array load.
-//
-// TODO(tvm-team):
-// Refactor call with more explicit property registrations.
-// rather than calling a string symbol.
-// We should move most information into function itself and remove name.
-
-/*! \brief Base node of internal functions. */
-class FunctionBaseNode : public Object {
- public:
-  /*! \brief virtual destructor */
-  virtual ~FunctionBaseNode() {}
-  /*! \return the name of the function */
-  virtual const std::string& func_name() const = 0;
-  /*! \return the number of outputs of this function */
-  virtual int num_outputs() const = 0;
-
-  // fall back to pointer equality now before refactor.
-  bool SEqualReduce(const FunctionBaseNode* other, SEqualReducer equal) const {
-    return this == other;
-  }
-
-  void SHashReduce(SHashReducer hash_reduce) const {}
-
-  static constexpr const bool _type_has_method_sequal_reduce = true;
-  static constexpr const bool _type_has_method_shash_reduce = true;
-};
-
-/*! \brief reference to a function */
-class FunctionRef : public ObjectRef {
- public:
-  TVM_DEFINE_OBJECT_REF_METHODS(FunctionRef, ObjectRef, FunctionBaseNode);
-};
-
 /*!
  * \brief Call node.
  */
diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h
index 363bf6b..9f09824 100644
--- a/include/tvm/tir/var.h
+++ b/include/tvm/tir/var.h
@@ -226,8 +226,6 @@ enum IterVarType : int {
   kTensorized = 8
 };
 
-using Domain = Array<Range>;
-
 /*!
  * \brief An iteration variable representing an iteration
  *  over a one dimensional interval.
diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc
index 0ac4a89..b44d9f7 100644
--- a/src/arith/domain_touched.cc
+++ b/src/arith/domain_touched.cc
@@ -40,9 +40,9 @@ class BufferTouchedDomain final : public StmtExprVisitor {
   BufferTouchedDomain(const Buffer& buffer, bool consider_loads, bool consider_stores)
       : buffer_(buffer), consider_loads_(consider_loads), consider_stores_(consider_stores) {}
 
-  Domain Find(const Stmt& stmt) {
+  Region Find(const Stmt& stmt) {
     operator()(stmt);
-    Domain ret;
+    Region ret;
     Range none;
     for (size_t i = 0; i < bounds_.size(); ++i) {
       ret.push_back(arith::Union(bounds_[i]).cover_range(none));
@@ -107,7 +107,7 @@ class BufferTouchedDomain final : public StmtExprVisitor {
   std::unordered_map<const VarNode*, IntSet> dom_map_;
 };
 
-Domain DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads,
+Region DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads,
                      bool consider_stores) {
   return BufferTouchedDomain(buffer, consider_loads, consider_stores).Find(stmt);
 }
diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc
index e9ec585..e08f39f 100644
--- a/src/contrib/hybrid/codegen_hybrid.cc
+++ b/src/contrib/hybrid/codegen_hybrid.cc
@@ -414,7 +414,7 @@ std::string CodeGenHybrid::GetTensorID(const Tensor& tensor) {
   if (id_map_.count(key)) {
     return id_map_[key];
   }
-  std::string name_hint = tensor->op->func_name();
+  std::string name_hint = tensor->op->name;
   if (tensor->op->num_outputs() > 1) {
     name_hint += "_v" + std::to_string(tensor->value_index);
   }
diff --git a/src/te/schedule/graph.cc b/src/te/schedule/graph.cc
index 62557ed..09e8995 100644
--- a/src/te/schedule/graph.cc
+++ b/src/te/schedule/graph.cc
@@ -36,15 +36,15 @@ namespace tvm {
 namespace te {
 // key to specific tensor dimension.
 struct TensorDimKey {
-  tir::FunctionRef f;
+  Operation op;
   int value_index;
   int dim;
   TensorDimKey() {}
-  TensorDimKey(const Tensor& t, int dim) : f(t->op), value_index(t->value_index), dim(dim) {}
+  TensorDimKey(const Tensor& t, int dim) : op(t->op), value_index(t->value_index), dim(dim) {}
   TensorDimKey(const Tensor& t, size_t dim)
-      : f(t->op), value_index(t->value_index), dim(static_cast<int>(dim)) {}
+      : op(t->op), value_index(t->value_index), dim(static_cast<int>(dim)) {}
   inline bool operator==(const TensorDimKey& other) const {
-    return f == other.f && value_index == other.value_index && dim == other.dim;
+    return op == other.op && value_index == other.value_index && dim == other.dim;
   }
   inline bool operator!=(const TensorDimKey& other) const { return !operator==(other); }
 };
@@ -55,7 +55,7 @@ namespace std {
 template <>
 struct hash<::tvm::te::TensorDimKey> {
   std::size_t operator()(const ::tvm::te::TensorDimKey& k) const {
-    size_t lhs = ::tvm::ObjectPtrHash()(k.f);
+    size_t lhs = ::tvm::ObjectPtrHash()(k.op);
     size_t rhs = static_cast<size_t>(k.value_index) << 16UL | static_cast<size_t>(k.dim);
     lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
     return lhs;
@@ -378,7 +378,7 @@ Map<IterVar, PrimExpr> ScanFixPointAnalysis(const Operation& scan_op) {
           if (k != target && place_holder_ref.count(k)) break;
           stack.pop_back();
           if (!reach.count(k)) {
-            LOG(FATAL) << "cannot find reach of " << k.f << "-" << k.dim;
+            LOG(FATAL) << "cannot find reach of " << k.op << "-" << k.dim;
           }
 
           for (TensorDimKey kk : reach.at(k)) {
diff --git a/src/tir/transforms/inject_prefetch.cc b/src/tir/transforms/inject_prefetch.cc
index 3b626f0..9c27a71 100644
--- a/src/tir/transforms/inject_prefetch.cc
+++ b/src/tir/transforms/inject_prefetch.cc
@@ -45,7 +45,7 @@ class PrefetchInjector : public StmtMutator {
     if (op && op->attr_key == attr::prefetch_scope) {
       Buffer buffer = Downcast<Buffer>(op->node);
       CHECK_NE(loop_nest_.size(), 0U);
-      Domain domain = DomainTouched(op->body, buffer, true, false);
+      Region domain = DomainTouched(op->body, buffer, true, false);
       Region region;
 
       auto iter_var = loop_nest_.back().get();