Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/12/07 10:53:57 UTC

[GitHub] [tvm] hzfan opened a new pull request #7045: [Arith] Simplify cast

hzfan opened a new pull request #7045:
URL: https://github.com/apache/tvm/pull/7045


   Follow-up to #6691.
   Simplify `cast(i32, c * 2 + 1) + 1 - cast(i32, c * 2)` to `2` by first transforming it to `cast(i32, c * 2) + cast(i32, 1) + 1 - cast(i32, c * 2)`.
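A minimal sketch of why the rewrite gives the same value, assuming `c` is a 64-bit integer and `c * 2 + 1` fits in 32 bits. The function names `Original` and `Rewritten` are hypothetical, and this is plain C++ rather than TVM code:

```cpp
#include <cstdint>

// cast(i32, c * 2 + 1) + 1 - cast(i32, c * 2)
int32_t Original(int64_t c) {
  return static_cast<int32_t>(c * 2 + 1) + 1 - static_cast<int32_t>(c * 2);
}

// cast(i32, c * 2) + cast(i32, 1) + 1 - cast(i32, c * 2)
// The cast has been pushed onto each term of the sum, so the two
// cast(i32, c * 2) terms cancel and only 1 + 1 remains.
int32_t Rewritten(int64_t c) {
  return static_cast<int32_t>(c * 2) + static_cast<int32_t>(1) + 1 -
         static_cast<int32_t>(c * 2);
}
```

For any `c` in that range both functions return `2`; the rewritten form just makes the cancellation visible to a term-by-term simplifier.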
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #7045: [Arith] Simplify cast

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #7045:
URL: https://github.com/apache/tvm/pull/7045#discussion_r548436987



##########
File path: src/arith/canonical_simplify.cc
##########
@@ -77,6 +77,25 @@ inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
   }
 }
 
+bool CheckCastImpl(DataType dtype, PrimExpr value, Analyzer* analyzer) {
+  if (!IsIndexType(dtype)) {
+    return false;
+  }
+  ConstIntBound bound = analyzer->const_int_bound(value);
+  int64_t ubound = Downcast<IntImm>(max_value(dtype))->value;
+  int64_t lbound = Downcast<IntImm>(min_value(dtype))->value;
+  if (value.dtype().bits() <= dtype.bits() ||  // upcast is safe
+      (bound->max_value <= ubound && bound->min_value >= lbound)) {
+    return true;
+  }
+  return false;
+}
+
+#define CHECK_CAST(DTYPE, VALUE)                \

Review comment:
       Always prefix macros with `TVM_` (except for glog-style checks). Also, please use a long, specific name to avoid conflicts, and remember to `#undef` it once it is no longer used.
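A minimal sketch of the convention requested here, assuming a file-local helper macro. The macro name `TVM_EXAMPLE_RETURN_FALSE_IF_OUT_OF_RANGE` and the function `AllFitInInt8` are hypothetical and do not appear in the PR:

```cpp
#include <cstdint>

// Prefixed with TVM_ and given a long, specific name so it is unlikely to
// collide with macros from other headers.
#define TVM_EXAMPLE_RETURN_FALSE_IF_OUT_OF_RANGE(VALUE, LO, HI) \
  if ((VALUE) < (LO) || (VALUE) > (HI)) {                        \
    return false;                                                \
  }

// Returns true only if both values fit in a signed 8-bit integer.
bool AllFitInInt8(int64_t a, int64_t b) {
  TVM_EXAMPLE_RETURN_FALSE_IF_OUT_OF_RANGE(a, -128, 127)
  TVM_EXAMPLE_RETURN_FALSE_IF_OUT_OF_RANGE(b, -128, 127)
  return true;
}

// Undefine the macro after its last use so it does not leak into the rest
// of the translation unit.
#undef TVM_EXAMPLE_RETURN_FALSE_IF_OUT_OF_RANGE
```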







[GitHub] [tvm] tqchen commented on pull request #7045: [Arith] Simplify cast

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #7045:
URL: https://github.com/apache/tvm/pull/7045#issuecomment-756221819


   Thanks @hzfan @junrushao1994! This is now merged.





[GitHub] [tvm] hzfan commented on a change in pull request #7045: [Arith] Simplify cast

Posted by GitBox <gi...@apache.org>.
hzfan commented on a change in pull request #7045:
URL: https://github.com/apache/tvm/pull/7045#discussion_r551789221



##########
File path: src/arith/canonical_simplify.cc
##########
@@ -77,6 +77,25 @@ inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
   }
 }
 
+bool CheckCastImpl(DataType dtype, PrimExpr value, Analyzer* analyzer) {

Review comment:
       I think `CastIsSafe` is a better name, since the function checks both upcasts and downcasts.
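A standalone sketch of the check under the proposed name, with a documentation comment. It assumes signed integer types only and replaces TVM's `DataType` and `ConstIntBound` with a bit width and a hypothetical `IntBound` struct; `MaxOf` and `MinOf` are likewise illustrative helpers, not TVM functions:

```cpp
#include <cstdint>

// Hypothetical stand-in for the bound returned by a const-int-bound analysis.
struct IntBound {
  int64_t min_value;
  int64_t max_value;
};

// Largest and smallest values of a signed integer with `bits` bits (1..64).
int64_t MaxOf(int bits) {
  return bits == 64 ? INT64_MAX : (int64_t{1} << (bits - 1)) - 1;
}
int64_t MinOf(int bits) {
  return bits == 64 ? INT64_MIN : -(int64_t{1} << (bits - 1));
}

/*!
 * \brief Check whether casting a value to a type with dst_bits bits keeps its value.
 * \param src_bits Bit width of the value's current type.
 * \param dst_bits Bit width of the target type.
 * \param bound Known bound of the value.
 * \return true for an upcast, or for a downcast whose bound fits the target.
 */
bool CastIsSafe(int src_bits, int dst_bits, IntBound bound) {
  if (src_bits <= dst_bits) {
    return true;  // upcast is safe
  }
  return bound.max_value <= MaxOf(dst_bits) && bound.min_value >= MinOf(dst_bits);
}
```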














[GitHub] [tvm] tqchen commented on pull request #7045: [Arith] Simplify cast

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #7045:
URL: https://github.com/apache/tvm/pull/7045#issuecomment-743355546


   cc @merrymercy @yzhliu @kazum @ZihengJiang could you please help take a look?





[GitHub] [tvm] tqchen commented on a change in pull request #7045: [Arith] Simplify cast

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #7045:
URL: https://github.com/apache/tvm/pull/7045#discussion_r551976566



##########
File path: src/arith/canonical_simplify.cc
##########
@@ -1071,6 +1208,33 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) {
   return ret;
 }
 
+PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) {
+  if (!IsIndexType(op->dtype)) {
+    return Rewriter::VisitExpr_(op);
+  }
+  // normalize
+  PrimExpr value = this->CanonicalMutate(op->value);
+  PrimExpr ret;
+  // PushCastToChildren
+  if (value.as<SumExprNode>()) {
+    SumExpr se = Downcast<SumExpr>(value);
+    if (se->CanPushCastToChildren(op->dtype, analyzer_)) {
+      se.CopyOnWrite()->PushCastToChildren(op->dtype);
+      ret = se;

Review comment:
       Consider returning directly here.

##########
File path: src/arith/canonical_simplify.cc
##########
@@ -1071,6 +1208,33 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) {
   return ret;
 }
 
+PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) {
+  if (!IsIndexType(op->dtype)) {
+    return Rewriter::VisitExpr_(op);
+  }
+  // normalize
+  PrimExpr value = this->CanonicalMutate(op->value);
+  PrimExpr ret;
+  // PushCastToChildren
+  if (value.as<SumExprNode>()) {
+    SumExpr se = Downcast<SumExpr>(value);
+    if (se->CanPushCastToChildren(op->dtype, analyzer_)) {
+      se.CopyOnWrite()->PushCastToChildren(op->dtype);
+      ret = se;
+    }
+  } else if (value.as<SplitExprNode>()) {
+    SplitExpr se = Downcast<SplitExpr>(value);
+    if (se->CanPushCastToChildren(op->dtype, analyzer_)) {
+      se.CopyOnWrite()->PushCastToChildren(op->dtype);
+      ret = se;

Review comment:
       Consider returning directly here.

##########
File path: src/arith/canonical_simplify.cc
##########
@@ -1071,6 +1208,33 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) {
   return ret;
 }
 
+PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) {
+  if (!IsIndexType(op->dtype)) {
+    return Rewriter::VisitExpr_(op);
+  }
+  // normalize
+  PrimExpr value = this->CanonicalMutate(op->value);
+  PrimExpr ret;
+  // PushCastToChildren
+  if (value.as<SumExprNode>()) {
+    SumExpr se = Downcast<SumExpr>(value);
+    if (se->CanPushCastToChildren(op->dtype, analyzer_)) {
+      se.CopyOnWrite()->PushCastToChildren(op->dtype);
+      ret = se;
+    }
+  } else if (value.as<SplitExprNode>()) {
+    SplitExpr se = Downcast<SplitExpr>(value);
+    if (se->CanPushCastToChildren(op->dtype, analyzer_)) {
+      se.CopyOnWrite()->PushCastToChildren(op->dtype);
+      ret = se;
+    }
+  }
+  if (!ret.defined()) {
+    ret = Rewriter::VisitExpr_(op);

Review comment:
       Return directly here as well: `return Rewriter::VisitExpr_(op);`
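The control flow these three comments ask for can be sketched as follows. This is a toy model, not the TVM visitor: `Kind`, `Expr`, `can_push`, and `VisitCast` are hypothetical stand-ins for the node types, the `CanPushCastToChildren` result, and `VisitExpr_`:

```cpp
#include <string>

// Hypothetical stand-ins for SumExpr, SplitExpr, and any other expression.
enum class Kind { kSum, kSplit, kOther };

struct Expr {
  Kind kind;
  bool can_push;  // models the result of CanPushCastToChildren
};

// Each successful branch returns immediately, so no `ret` variable and no
// trailing `if (!ret.defined())` check are needed.
std::string VisitCast(const Expr& value) {
  if (value.kind == Kind::kSum && value.can_push) {
    return "sum with cast pushed to children";
  }
  if (value.kind == Kind::kSplit && value.can_push) {
    return "split with cast pushed to children";
  }
  return "default rewrite";  // models Rewriter::VisitExpr_(op)
}
```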







[GitHub] [tvm] tqchen commented on pull request #7045: [Arith] Simplify cast

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #7045:
URL: https://github.com/apache/tvm/pull/7045#issuecomment-752985049


   Thanks @hzfan for continuing to polish the code :) Arith analysis is core to most of our transformations, so it is important to make sure the code is clean and readable.
   
   I have read the PR carefully and added a few more comments to improve readability.
   
   @junrushao1994 @yzhliu It would be great to have you take another look as well.





[GitHub] [tvm] junrushao1994 commented on pull request #7045: [Arith] Simplify cast

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #7045:
URL: https://github.com/apache/tvm/pull/7045#issuecomment-751847448


   It looks good on my side. @yzhliu would you mind taking a look? Thanks!





[GitHub] [tvm] tqchen merged pull request #7045: [Arith] Simplify cast

Posted by GitBox <gi...@apache.org>.
tqchen merged pull request #7045:
URL: https://github.com/apache/tvm/pull/7045


   





[GitHub] [tvm] hzfan commented on pull request #7045: [Arith] Simplify cast

Posted by GitBox <gi...@apache.org>.
hzfan commented on pull request #7045:
URL: https://github.com/apache/tvm/pull/7045#issuecomment-751660044


   @junrushao1994 I just fixed the macro thing and CI's green. Please take another look.





[GitHub] [tvm] tqchen commented on a change in pull request #7045: [Arith] Simplify cast

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #7045:
URL: https://github.com/apache/tvm/pull/7045#discussion_r550500956



##########
File path: src/arith/canonical_simplify.cc
##########
@@ -77,6 +77,25 @@ inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
   }
 }
 
+bool CheckCastImpl(DataType dtype, PrimExpr value, Analyzer* analyzer) {

Review comment:
       Perhaps rename to UpcastIsSafe?

##########
File path: src/arith/canonical_simplify.cc
##########
@@ -128,6 +147,50 @@ class SplitExprNode : public CanonicalExprNode {
 
   void MulToSelf(int64_t scale) { this->scale *= scale; }
 
+  /*!
+   * \brief check if cast(dtype, self) is safe
+   * \param dtype The target datatype
+   * \param analyzer The analyzer
+   * \return whether the cast is safe or not
+   */
+  bool CheckCast(DataType dtype, Analyzer* analyzer) const {

Review comment:
       Rename to `CanPushCastToChildren`.

##########
File path: src/arith/canonical_simplify.cc
##########
@@ -77,6 +77,25 @@ inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
   }
 }
 
+bool CheckCastImpl(DataType dtype, PrimExpr value, Analyzer* analyzer) {
+  if (!IsIndexType(dtype)) {
+    return false;
+  }
+  ConstIntBound bound = analyzer->const_int_bound(value);
+  int64_t ubound = Downcast<IntImm>(max_value(dtype))->value;
+  int64_t lbound = Downcast<IntImm>(min_value(dtype))->value;
+  if (value.dtype().bits() <= dtype.bits() ||  // upcast is safe
+      (bound->max_value <= ubound && bound->min_value >= lbound)) {
+    return true;
+  }
+  return false;
+}
+
+#define TVM_CHECK_CANONICAL_SIMPLIFY_CAST(DTYPE, VALUE) \
+  if (!CheckCastImpl(DTYPE, VALUE, analyzer)) {         \

Review comment:
       Inline the function call. While the macro makes the code a bit more concise, it also makes the code harder to understand.

##########
File path: src/arith/canonical_simplify.cc
##########
@@ -77,6 +77,25 @@ inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
   }
 }
 
+bool CheckCastImpl(DataType dtype, PrimExpr value, Analyzer* analyzer) {

Review comment:
       Please document this function.

##########
File path: src/arith/canonical_simplify.cc
##########
@@ -255,6 +318,61 @@ class SumExprNode : public CanonicalExprNode {
 
   void AddToSelf(const SumExpr& other, int64_t scale);
 
+  /*!
+   * \brief check if cast(dtype, self) is safe
+   * \param dtype The target datatype
+   * \param analyzer The analyzer
+   * \return whether the cast is safe or not
+   */
+  bool CheckCast(DataType dtype, Analyzer* analyzer) const {
+    // cast(dtype, arg_1 + arg_2 + ... arg_n) ==
+    // cast(dtype, arg_1) + ... + cast(dtype, arg_n)
+    // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of
+    // its intermediate results fit in the range of dtype
+    if (dtype.bits() >= this->dtype.bits()) {
+      return true;  // upcast is safe
+    }
+    PrimExpr res = make_const(dtype, 0);
+    for (size_t i = 0; i < args.size(); ++i) {
+      if (args[i]->scale > 0) {
+        res = res + args[i]->Normalize();
+        TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res)
+      }
+    }
+    if (base > 0) {
+      res = res + make_const(dtype, base);
+      TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res)
+    }
+    // negative scales follows using sub.
+    for (size_t i = 0; i < args.size(); ++i) {
+      if (args[i]->scale < 0) {
+        res = res - args[i]->NormalizeWithScale(-1);
+        TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res)
+      }
+    }
+    if (base < 0) {
+      res = res - make_const(dtype, -base);
+      TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res)
+    }
+    for (const auto& arg : args) {
+      if (!arg->CheckCast(dtype, analyzer)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /*!
+   * \brief self = cast(dtype, self)
+   * \param dtype The target datatype
+   */
+  void CastTo(DataType dtype) {

Review comment:
       Rename to `PushCastToChildren`.
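The intermediate-result check in the hunk above can be sketched with explicit calls in place of the macro. This is a simplified model under stated assumptions: the target type is fixed to a signed 32-bit integer, each term is represented only by a hypothetical `IntBound`, bounds are small enough that the 64-bit accumulation itself cannot overflow, and the ordering of positive-scale terms before negative-scale terms is omitted. `FitsInt32` is an illustrative helper, not a TVM function:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the constant integer bound of one term.
struct IntBound {
  int64_t min_value;
  int64_t max_value;
};

// True if every value in the bound is representable as a signed 32-bit integer.
bool FitsInt32(IntBound b) {
  return b.min_value >= INT32_MIN && b.max_value <= INT32_MAX;
}

// cast(i32, t_1 + ... + t_n) may be rewritten as cast(i32, t_1) + ... + cast(i32, t_n)
// only if every term and every partial sum stays within the int32 range.
bool CanPushCastToChildren(const std::vector<IntBound>& terms) {
  IntBound partial{0, 0};
  for (const IntBound& t : terms) {
    if (!FitsInt32(t)) {
      return false;  // the term itself would not survive the cast
    }
    partial.min_value += t.min_value;
    partial.max_value += t.max_value;
    if (!FitsInt32(partial)) {
      return false;  // explicit check where the macro used to be
    }
  }
  return true;
}
```

Writing the `if (...) return false;` out at each step keeps the early exit visible at the call site, which is the readability point raised in the review.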

##########
File path: src/arith/canonical_simplify.cc
##########
@@ -128,6 +147,50 @@ class SplitExprNode : public CanonicalExprNode {
 
   void MulToSelf(int64_t scale) { this->scale *= scale; }
 
+  /*!
+   * \brief check if cast(dtype, self) is safe
+   * \param dtype The target datatype
+   * \param analyzer The analyzer
+   * \return whether the cast is safe or not
+   */
+  bool CheckCast(DataType dtype, Analyzer* analyzer) const {
+    // cast(dtype, index % upper_factor / lower_factor * scale) ==
+    // cast(dtype, index) % upper_factor / lower_factor * scale
+    // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of
+    // its intermediate results fit in the range of dtype
+    if (dtype.bits() >= this->dtype.bits()) {
+      return true;  // upcast is safe
+    }
+    PrimExpr res = this->index;
+    if (this->scale == 0) {
+      return true;
+    }
+    TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res)
+    if (this->upper_factor != SplitExprNode::kPosInf) {
+      res = ModImpl(res, make_const(this->dtype, this->upper_factor), div_mode);
+      TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res)
+    }
+    if (this->lower_factor != 1) {
+      res = DivImpl(res, make_const(this->dtype, this->lower_factor), div_mode);
+      TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res)
+    }
+    if (this->scale != 1) {
+      ICHECK(!this->dtype.is_uint() || this->scale > 0);
+      res = res * make_const(this->dtype, this->scale);
+      TVM_CHECK_CANONICAL_SIMPLIFY_CAST(dtype, res)
+    }
+    return true;
+  }
+
+  /*!
+   * \brief self = cast(dtype, self)
+   * \param dtype The target datatype
+   */
+  void CastTo(DataType dtype) {

Review comment:
       Rename to `PushCastToChildren`

##########
File path: src/arith/canonical_simplify.cc
##########
@@ -1071,6 +1192,32 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) {
   return ret;
 }
 
+PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) {
+  if (!IsIndexType(op->dtype)) {
+    return Rewriter::VisitExpr_(op);
+  }
+  // normalize
+  PrimExpr value = this->CanonicalMutate(op->value);
+  PrimExpr ret;

Review comment:
       Add a comment here: `// PushCastToChildren`.







[GitHub] [tvm] hzfan commented on a change in pull request #7045: [Arith] Simplify cast

Posted by GitBox <gi...@apache.org>.
hzfan commented on a change in pull request #7045:
URL: https://github.com/apache/tvm/pull/7045#discussion_r549225503



##########
File path: src/arith/canonical_simplify.cc
##########
@@ -77,6 +77,25 @@ inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
   }
 }
 
+bool CheckCastImpl(DataType dtype, PrimExpr value, Analyzer* analyzer) {
+  if (!IsIndexType(dtype)) {
+    return false;
+  }
+  ConstIntBound bound = analyzer->const_int_bound(value);
+  int64_t ubound = Downcast<IntImm>(max_value(dtype))->value;
+  int64_t lbound = Downcast<IntImm>(min_value(dtype))->value;
+  if (value.dtype().bits() <= dtype.bits() ||  // upcast is safe
+      (bound->max_value <= ubound && bound->min_value >= lbound)) {
+    return true;
+  }
+  return false;
+}
+
+#define CHECK_CAST(DTYPE, VALUE)                \

Review comment:
       Fixed and renamed to `TVM_CHECK_CANONICAL_SIMPLIFY_CAST`. Thanks for the suggestion!







[GitHub] [tvm] tqchen commented on pull request #7045: [Arith] Simplify cast

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #7045:
URL: https://github.com/apache/tvm/pull/7045#issuecomment-751360524


   cc @hzfan please update per review comments

