Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/02/04 19:57:35 UTC

[GitHub] [tvm] lazycal opened a new pull request #10172: [TIR] Fix Ramp dtype mismatch in VectorizeLoop and NarrowDataType passes

lazycal opened a new pull request #10172:
URL: https://github.com/apache/tvm/pull/10172


   The following model
   ```python
   import tvm
   from tvm import relay
   import numpy as np
   
   xshape = (1, 1, 1)
   inp = np.random.uniform(size=xshape).astype(np.int64)
   
   x = relay.var("x", shape=xshape, dtype='int64')
   x = relay.cast(x, 'int64')
   x = relay.broadcast_to(x, relay.const([1, 2, 2], dtype='int64'))
   func = relay.Function(relay.analysis.free_vars(x), -x)
   mod = tvm.IRModule.from_expr(func)
   
   with tvm.transform.PassContext(opt_level=0):
       relay.create_executor("debug", mod, tvm.cpu()).evaluate()(inp)
   ```
   triggers two `base`/`stride` dtype mismatches in `Ramp`, one in the VectorizeLoop pass and the other in the NarrowDataType pass. Specifically:
   - During VectorizeLoop, a loop variable is converted to a ramp whose dtype is always `int32`. This PR changes it to use the loop variable's dtype.
   - During NarrowDataType, `stride` can be inferred as `int32` while `base` is not (see the added test case for details). This PR adds an upcast when rewriting a `Ramp` node whose `base` and `stride` are inferred with different bit widths.
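
A minimal sketch of the two fixes as a plain-Python model, not TVM code. `IntType`, `loop_var_ramp`, and `unify_ramp_operands` are hypothetical names introduced only for illustration; the actual changes live in the C++ passes.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class IntType:
    """Hypothetical stand-in for a scalar integer dtype (e.g. int32, int64)."""
    bits: int

    def __str__(self) -> str:
        return f"int{self.bits}"


def loop_var_ramp(var_dtype: IntType, lanes: int) -> str:
    """Model of the VectorizeLoop fix: the ramp takes the loop variable's
    dtype instead of a hard-coded int32."""
    return f"ramp({var_dtype}(0), {var_dtype}(1), {lanes})"


def unify_ramp_operands(base: IntType, stride: IntType):
    """Model of the NarrowDataType fix: upcast the narrower operand so that
    base and stride end up with the same bit width."""
    wide = IntType(max(base.bits, stride.bits))
    return wide, wide


if __name__ == "__main__":
    print(loop_var_ramp(IntType(64), 4))
    print(unify_ramp_operands(IntType(64), IntType(32)))
```

In this model the narrower operand is widened rather than the wider one narrowed, so neither operand's value range is reduced.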


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] ganler commented on a change in pull request #10172: [TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType passes

Posted by GitBox <gi...@apache.org>.
ganler commented on a change in pull request #10172:
URL: https://github.com/apache/tvm/pull/10172#discussion_r800307383



##########
File path: src/tir/transforms/narrow_datatype.cc
##########
@@ -253,6 +253,23 @@ class DataTypeRewriter : public StmtExprMutator {
     return StmtExprMutator::VisitExpr_(op);
   }
 
+  PrimExpr VisitExpr_(const RampNode* op) final {
+    PrimExpr base = VisitExpr(op->base);
+    PrimExpr stride = VisitExpr(op->stride);
+    if (base.same_as(op->base) && stride.same_as(op->stride)) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      if (base.dtype().is_int()) {
+        ICHECK(stride.dtype().is_int()) << "Ramp base is int but stride is " << stride.dtype();

Review comment:
       I think we can simply assume that `base` and `stride` are of integer types. However, I also noticed that in
   
   https://github.com/apache/tvm/blob/22c488e3a829ad700de6547be6096fb2d1f02e81/src/tir/ir/expr.cc#L705
   
   this assumption is not checked. I am curious whether `base`/`stride` could ever be of float types.







[GitHub] [tvm] ganler edited a comment on pull request #10172: [TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType passes

Posted by GitBox <gi...@apache.org>.
ganler edited a comment on pull request #10172:
URL: https://github.com/apache/tvm/pull/10172#issuecomment-1031072485


   @lazycal  Using this implementation (based on yours) in https://github.com/lazycal/tvm/blob/ffe6649855c4c247f4bb85c9d48c5ca157850a1d/src/tir/ir/expr.cc#L705 fixes the bug you mentioned and might be general enough to cover other hidden ones, if we can assume that `base` and `stride` must be integers.
   
   ```c++
   Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span) {
     ICHECK(base.defined());
     ICHECK(stride.defined());
     ICHECK(base.dtype().is_scalar());
     ICHECK(stride.dtype().is_scalar());
     ICHECK_GT(lanes, 1);
     ICHECK(base.dtype().is_int());
     ICHECK(stride.dtype().is_int());
     
     if (base.dtype() != stride.dtype()) {
       size_t bits = std::max(base.dtype().bits(), stride.dtype().bits());
       DataType dtype = base.dtype().with_bits(bits);
       if (base.dtype() != dtype) base = cast(dtype, base);
       if (stride.dtype() != dtype) stride = cast(dtype, stride);
     }
   
     ObjectPtr<RampNode> node = make_object<RampNode>();
     node->dtype = base.dtype().with_lanes(lanes);
     node->base = base;
     node->stride = stride;
     node->lanes = lanes;
     node->span = std::move(span);
     data_ = std::move(node);
   }
   ```





[GitHub] [tvm] lazycal commented on a change in pull request #10172: [TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType passes

Posted by GitBox <gi...@apache.org>.
lazycal commented on a change in pull request #10172:
URL: https://github.com/apache/tvm/pull/10172#discussion_r800855224



##########
File path: src/tir/transforms/narrow_datatype.cc
##########
@@ -253,6 +253,23 @@ class DataTypeRewriter : public StmtExprMutator {
     return StmtExprMutator::VisitExpr_(op);
   }
 
+  PrimExpr VisitExpr_(const RampNode* op) final {
+    PrimExpr base = VisitExpr(op->base);
+    PrimExpr stride = VisitExpr(op->stride);
+    if (base.same_as(op->base) && stride.same_as(op->stride)) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      if (base.dtype().is_int()) {
+        ICHECK(stride.dtype().is_int()) << "Ramp base is int but stride is " << stride.dtype();

Review comment:
       I don't know whether they can be floats, so I added that conservative check, which only acts on integers. If they can only be integers, we should add the `ICHECK` you mention.







[GitHub] [tvm] ganler commented on pull request #10172: [TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType passes

Posted by GitBox <gi...@apache.org>.
ganler commented on pull request #10172:
URL: https://github.com/apache/tvm/pull/10172#issuecomment-1031747692


   @lazycal Fair consideration! 
   
   I also ran `./tests/scripts/task_python_unittest.sh` with the direct fix. It passes all related tests (except a few that fail because of my environment), which suggests that, at least in the unit tests, `Ramp` is only used with integer `base` and `stride`, and that casting between `int32` and `int64` during `Ramp` node construction works.





[GitHub] [tvm] lazycal commented on pull request #10172: [TIR] Fix Ramp dtype mismatch in VectorizeLoop and NarrowDataType passes

Posted by GitBox <gi...@apache.org>.
lazycal commented on pull request #10172:
URL: https://github.com/apache/tvm/pull/10172#issuecomment-1030311312


   Not sure whom I should ask for a review, but this seems similar to PR #9582, so I'm cc'ing the reviewers there: @YuchenJin @junrushao1994 @Mousius





[GitHub] [tvm] junrushao1994 commented on pull request #10172: [TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType passes

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #10172:
URL: https://github.com/apache/tvm/pull/10172#issuecomment-1031074881


   This is definitely an interesting bug! CC: @vinx13 @yzhliu @hzfan would be great if you guys could take a look :-)








[GitHub] [tvm] lazycal commented on pull request #10172: [TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType passes

Posted by GitBox <gi...@apache.org>.
lazycal commented on pull request #10172:
URL: https://github.com/apache/tvm/pull/10172#issuecomment-1031694343


   > @lazycal Using this implementation (based on yours) in https://github.com/lazycal/tvm/blob/ffe6649855c4c247f4bb85c9d48c5ca157850a1d/src/tir/ir/expr.cc#L705 fixes the bug you mentioned and might be general enough to cover other hidden ones, if we can assume that `base` and `stride` must be integers.
   > 
   > ```c++
   > Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span) {
   >   ICHECK(base.defined());
   >   ICHECK(stride.defined());
   >   ICHECK(base.dtype().is_scalar());
   >   ICHECK(stride.dtype().is_scalar());
   >   ICHECK_GT(lanes, 1);
   >   ICHECK(base.dtype().is_int());
   >   ICHECK(stride.dtype().is_int());
   >   
   >   if (base.dtype() != stride.dtype()) {
   >     size_t bits = std::max(base.dtype().bits(), stride.dtype().bits());
   >     DataType dtype = base.dtype().with_bits(bits);
   >     if (base.dtype() != dtype) base = cast(dtype, base);
   >     if (stride.dtype() != dtype) stride = cast(dtype, stride);
   >   }
   > 
   >   ObjectPtr<RampNode> node = make_object<RampNode>();
   >   node->dtype = base.dtype().with_lanes(lanes);
   >   node->base = base;
   >   node->stride = stride;
   >   node->lanes = lanes;
   >   node->span = std::move(span);
   >   data_ = std::move(node);
   > }
   > ```
   
   I didn't do that because:
   - I want to make as few changes as possible. Changing the constructor of a fundamental class seems more likely to break things, IMO.
   - I also think the author must have written this class that way for a reason. One reason, I guess, could be to catch such corner cases in pass rewrite algorithms. You are right that implicitly upcasting in the constructor fixes both bugs, but I am not sure this is always the desired behavior. For example, some other pass might need special handling other than upcasting; if it forgets to do so, it would be good to catch that, whereas implicit upcasting would let it go unnoticed or surface only much later.
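
The trade-off can be illustrated with a small plain-Python model (not TVM code; `make_ramp_strict`, `make_ramp_implicit`, `buggy_pass`, and `is_caught` are hypothetical names used only for this sketch). A strict constructor surfaces a pass that forgot to reconcile dtypes, whereas an upcasting constructor accepts the same input silently.

```python
def make_ramp_strict(base_bits: int, stride_bits: int, lanes: int):
    """Model of a constructor that rejects mismatched operand widths."""
    if base_bits != stride_bits:
        raise TypeError(f"mismatched types: int{base_bits} vs. int{stride_bits}")
    return ("ramp", base_bits, lanes)


def make_ramp_implicit(base_bits: int, stride_bits: int, lanes: int):
    """Model of a constructor that silently upcasts to the wider width."""
    return ("ramp", max(base_bits, stride_bits), lanes)


def buggy_pass(make_ramp):
    """Hypothetical pass that narrows only the stride and forgets the base."""
    return make_ramp(64, 32, 4)


def is_caught(make_ramp) -> bool:
    """Return True if the constructor exposes the buggy pass with an error."""
    try:
        buggy_pass(make_ramp)
    except TypeError:
        return True
    return False


if __name__ == "__main__":
    print("strict catches bug:", is_caught(make_ramp_strict))
    print("implicit catches bug:", is_caught(make_ramp_implicit))
```

Under these assumptions, only the strict variant reports the mistake at the point where the `Ramp` is built.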





[GitHub] [tvm] masahi merged pull request #10172: [TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType passes

Posted by GitBox <gi...@apache.org>.
masahi merged pull request #10172:
URL: https://github.com/apache/tvm/pull/10172


   








[GitHub] [tvm] tqchen commented on pull request #10172: [TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType passes

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #10172:
URL: https://github.com/apache/tvm/pull/10172#issuecomment-1030644414


   Also cc @vinx13 @hzfan @yzhliu: please take a look when you have time.





[GitHub] [tvm] masahi commented on pull request #10172: [TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType passes

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #10172:
URL: https://github.com/apache/tvm/pull/10172#issuecomment-1047280891


   I've just hit a similar error, when compiling an int8 model with tensorized ops (VNNI):
   
   ```
     21: tvm::te::MakeTensorize(tvm::te::ComputeOpNode const*, tvm::te::Stage const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > > const&, bool)
     20: tvm::te::VerifyTensorizeBody(tvm::te::ComputeOpNode const*, tvm::te::Stage const&, std::unordered_map<tvm::tir::IterVar, tvm::PrimExpr, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::PrimExpr> > > const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > > const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > > const&, std::unordered_map<tvm::te::Tensor, tvm::runtime::Array<tvm::Range, void>, std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, std::allocator<std::pair<tvm::te::Tensor const, tvm::runtime::Array<tvm::Range, void> > > > const&, tvm::te::TensorIntrin const&)
     19: tvm::te::MatchTensorizeBody(tvm::te::ComputeOpNode const*, tvm::te::Stage const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > > const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > > const&, std::unordered_map<tvm::te::Tensor, tvm::runtime::Array<tvm::Range, void>, std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, std::allocator<std::pair<tvm::te::Tensor const, tvm::runtime::Array<tvm::Range, void> > > > const&, tvm::te::TensorIntrin const&, tvm::runtime::Map<tvm::tir::Var, tvm::Range, void, void>*)
     18: non-virtual thunk to tvm::tir::StmtExprMutator::VisitExpr(tvm::PrimExpr const&)
     17: _ZZN3tvm3tir11ExprFunctorIFNS_8PrimExprERKS2_EE10InitVTableEvENUlRKNS_7runtime
     16: tvm::te::TensorIntrinMatcher::VisitExpr_(tvm::tir::ReduceNode const*)
   
     ...
   
     0: tvm::tir::FloorDiv::FloorDiv(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)
     File "/home/masa/projects/dev/tvm/src/tir/ir/expr.cc", line 322
   TypeError: Check failed: (a.dtype() == b.dtype()) is false: mismatched types. int64 vs. int32
   ```
   
   I wonder if this is related.

