You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/12/31 15:11:16 UTC

[GitHub] [tvm] tqchen commented on a change in pull request #7045: [Arith] Simplify cast

tqchen commented on a change in pull request #7045:
URL: https://github.com/apache/tvm/pull/7045#discussion_r550500956



##########
File path: src/arith/canonical_simplify.cc
##########
@@ -77,6 +77,25 @@ inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
   }
 }
 
+bool CheckCastImpl(DataType dtype, PrimExpr value, Analyzer* analyzer) {

Review comment:
       Perhaps rename to UpcastIsSafe?

##########
File path: src/arith/canonical_simplify.cc
##########
@@ -128,6 +147,50 @@ class SplitExprNode : public CanonicalExprNode {
 
   void MulToSelf(int64_t scale) { this->scale *= scale; }
 
+  /*!
+   * \brief check if cast(dtype, self) is safe
+   * \param dtype The target datatype
+   * \param analyzer The analyzer
+   * \return whether the cast is safe or not
+   */
+  bool CheckCast(DataType dtype, Analyzer* analyzer) const {

Review comment:
       CanPushCastToChildren

##########
File path: src/arith/canonical_simplify.cc
##########
@@ -77,6 +77,25 @@ inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
   }
 }
 
+bool CheckCastImpl(DataType dtype, PrimExpr value, Analyzer* analyzer) {
+  if (!IsIndexType(dtype)) {
+    return false;
+  }
+  ConstIntBound bound = analyzer->const_int_bound(value);
+  int64_t ubound = Downcast<IntImm>(max_value(dtype))->value;
+  int64_t lbound = Downcast<IntImm>(min_value(dtype))->value;
+  if (value.dtype().bits() <= dtype.bits() ||  // upcast is safe
+      (bound->max_value <= ubound && bound->min_value >= lbound)) {
+    return true;
+  }
+  return false;
+}
+
+#define TVM_CHECK_CANONICAL_SIMPLIFY_CAST(DTYPE, VALUE) \
+  if (!CheckCastImpl(DTYPE, VALUE, analyzer)) {         \

Review comment:
       Inline the function call: while the macro makes the code a bit more concise, it makes the code itself harder to understand.

##########
File path: src/arith/canonical_simplify.cc
##########
@@ -77,6 +77,25 @@ inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
   }
 }
 
+bool CheckCastImpl(DataType dtype, PrimExpr value, Analyzer* analyzer) {

Review comment:
       document this function

##########
File path: src/arith/canonical_simplify.cc
##########
@@ -255,6 +318,61 @@ class SumExprNode : public CanonicalExprNode {
 
   void AddToSelf(const SumExpr& other, int64_t scale);
 
+  /*!
+   * \brief check if cast(dtype, self) is safe
+   * \param dtype The target datatype
+   * \param analyzer The analyzer
+   * \return whether the cast is safe or not
+   */
+  bool CheckCast(DataType dtype, Analyzer* analyzer) const {
+    // cast(dtype, arg_1 + arg_2 + ... arg_n) ==
+    // cast(dtype, arg_1) + ... + cast(dtype, arg_n)
+    // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of
+    // its intermediate results fit in the range of dtype
+    if (dtype.bits() >= this->dtype.bits()) {
+      return true;  // upcast is safe
+    }
+    PrimExpr res = make_const(dtype, 0);
+    for (size_t i = 0; i < args.size(); ++i) {
+      if (args[i]->scale > 0) {
+        res = res + args[i]->Normalize();
+        TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res)
+      }
+    }
+    if (base > 0) {
+      res = res + make_const(dtype, base);
+      TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res)
+    }
+    // negative scales follows using sub.
+    for (size_t i = 0; i < args.size(); ++i) {
+      if (args[i]->scale < 0) {
+        res = res - args[i]->NormalizeWithScale(-1);
+        TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res)
+      }
+    }
+    if (base < 0) {
+      res = res - make_const(dtype, -base);
+      TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res)
+    }
+    for (const auto& arg : args) {
+      if (!arg->CheckCast(dtype, analyzer)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /*!
+   * \brief self = cast(dtype, self)
+   * \param dtype The target datatype
+   */
+  void CastTo(DataType dtype) {

Review comment:
       PushCastToChildren

##########
File path: src/arith/canonical_simplify.cc
##########
@@ -128,6 +147,50 @@ class SplitExprNode : public CanonicalExprNode {
 
   void MulToSelf(int64_t scale) { this->scale *= scale; }
 
+  /*!
+   * \brief check if cast(dtype, self) is safe
+   * \param dtype The target datatype
+   * \param analyzer The analyzer
+   * \return whether the cast is safe or not
+   */
+  bool CheckCast(DataType dtype, Analyzer* analyzer) const {
+    // cast(dtype, index % upper_factor / lower_factor * scale) ==
+    // cast(dtype, index) % upper_factor / lower_factor * scale
+    // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of
+    // its intermediate results fit in the range of dtype
+    if (dtype.bits() >= this->dtype.bits()) {
+      return true;  // upcast is safe
+    }
+    PrimExpr res = this->index;
+    if (this->scale == 0) {
+      return true;
+    }
+    TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res)
+    if (this->upper_factor != SplitExprNode::kPosInf) {
+      res = ModImpl(res, make_const(this->dtype, this->upper_factor), div_mode);
+      TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res)
+    }
+    if (this->lower_factor != 1) {
+      res = DivImpl(res, make_const(this->dtype, this->lower_factor), div_mode);
+      TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res)
+    }
+    if (this->scale != 1) {
+      ICHECK(!this->dtype.is_uint() || this->scale > 0);
+      res = res * make_const(this->dtype, this->scale);
+      TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res)
+    }
+    return true;
+  }
+
+  /*!
+   * \brief self = cast(dtype, self)
+   * \param dtype The target datatype
+   */
+  void CastTo(DataType dtype) {

Review comment:
       Rename to `PushCastToChildren`

##########
File path: src/arith/canonical_simplify.cc
##########
@@ -1071,6 +1192,32 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) {
   return ret;
 }
 
+PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) {
+  if (!IsIndexType(op->dtype)) {
+    return Rewriter::VisitExpr_(op);
+  }
+  // normalize
+  PrimExpr value = this->CanonicalMutate(op->value);
+  PrimExpr ret;

Review comment:
       comment: PushCastToChildren




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org