Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/06/14 18:47:47 UTC

[GitHub] [tvm] kparzysz-quic opened a new issue, #11716: [Bug] Crash with tensorize + prefetch after PR#11589

kparzysz-quic opened a new issue, #11716:
URL: https://github.com/apache/tvm/issues/11716

   Testcase:
   ```python
   import tvm
   
   
   def compute_something(A, C):
       ib = tvm.tir.ir_builder.create()
       S = C.vstore(
           0,
           tvm.tir.call_intrin(
               "uint8x2", tvm.ir.Op.get("tir.reinterpret"), A.vload(0) + A.vload(1)
           ),
       )
       ib.emit(S)
       return ib.get()
   
   
   def intrin_compute_something(S):
       A = tvm.te.placeholder((128,), dtype="int16", name="A")
       C = tvm.te.compute((128,), lambda i: (A[i] * S).astype("uint8"), name="C")
   
       Ab = tvm.tir.decl_buffer(
           A.shape, A.dtype, name="Ab", elem_offset=tvm.te.var("b_offset", "int32")
       )
       Cb = tvm.tir.decl_buffer(
           C.shape, C.dtype, name="Cb", elem_offset=tvm.te.var("c_offset", "int32")
       )
   
       def intrin_func(ins, outs):
           M = compute_something(ins[0], outs[0])
           return M, None, None
   
       return tvm.te.decl_tensor_intrin(
           C.op,
           intrin_func,
           binds={A: Ab, C: Cb},
           default_buffer_params={"offset_factor": 128},
       )
   
   
   def some_op(target):
       D, H, W = tvm.te.var("D"), tvm.te.var("H"), tvm.te.var("W")
       S = tvm.te.var("S", dtype="uint16")
       A = tvm.te.placeholder((H, W, D * 128), name="A", dtype="int16")
   
       C = tvm.te.compute(
           A.shape, lambda yy, xx, cc: (A[yy, xx, cc] * S).astype("uint8"), name="C"
       )
   
       # Create schedule: tensorize the inner 128-element channel loop, then prefetch A at the outer row loop
       s = tvm.te.create_schedule(C.op)
   
       cy, cx, cc = s[C].op.axis
       co, ci = s[C].split(cc, factor=128)
       s[C].tensorize(ci, intrin_compute_something(S))
       yo, yi = s[C].split(cy, factor=32)
       s[C].prefetch(A, yo, 1)
   
       module = tvm.build(s, [A, C, D, S], target)
       return module
   
   
   def test_some_op():
       module = some_op("llvm")
   
   
   test_some_op()
   ```
   
   Run: `python3 testcase.py`
   
   ```
   [...]
     7: tvm::arith::BufferTouchedDomain::VisitStmt_(tvm::tir::BufferStoreNode const*)
     6: void tvm::arith::BufferTouchedDomain::Touch<tvm::runtime::Array<tvm::PrimExpr, void> >(std::__1::vector<std::__1::vector<tvm::arith::IntSet, std::__1::allocator<tvm::arith::IntSet> >, std::__1::allocator<std::__1::vector<tvm::arith::IntSet, std::__1::allocator<tvm::arith::IntSet> > > >*, tvm::runtime::Array<tvm::PrimExpr, void> const&) const
     5: tvm::arith::EvalSet(tvm::PrimExpr, std::__1::unordered_map<tvm::tir::VarNode const*, tvm::arith::IntSet, std::__1::hash<tvm::tir::VarNode const*>, std::__1::equal_to<tvm::tir::VarNode const*>, std::__1::allocator<std::__1::pair<tvm::tir::VarNode const* const, tvm::arith::IntSet> > > const&)
     4: tvm::arith::EvalSet(tvm::PrimExpr, tvm::runtime::Map<tvm::tir::Var, tvm::arith::IntSet, void, void> const&)
     3: tvm::tir::ExprFunctor<tvm::arith::IntervalSet (tvm::PrimExpr const&)>::VisitExpr(tvm::PrimExpr const&)
     2: tvm::NodeFunctor<tvm::arith::IntervalSet (tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::arith::IntervalSet (tvm::PrimExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::arith::IntervalSet (tvm::PrimExpr const&)>*) const
     1: _ZZN3tvm3tir11ExprFunctorIF
     0: tvm::arith::IntervalSetEvaluator::VisitExpr_(tvm::tir::RampNode const*)
     File "/w/src/aitools/tvm-upstream/src/arith/int_set.cc", line 453
   TVMError:
   ---------------------------------------------------------------
   An error occurred during the execution of TVM.
   For more information, please see: https://tvm.apache.org/docs/errors.html
   ---------------------------------------------------------------
     Check failed: (eval_vec_) is false:
   ```
   
   cc: @csullivan 
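The trace ends in `IntervalSetEvaluator::VisitExpr_(tvm::tir::RampNode const*)` with `Check failed: (eval_vec_) is false`. One plausible reading, not confirmed in this thread, is that the intrinsic body stores a two-lane `uint8x2` value through `C.vstore(0, ...)`. That would make the store index a vector `Ramp`, which the interval evaluator reached through `BufferTouchedDomain` and `EvalSet` is not set up to bound. The toy Python model below illustrates only that failure mode. `Const`, `Ramp`, `Interval` and `eval_interval` are hypothetical names, not TVM APIs, and the `eval_vec` flag simply mirrors the name in the failed check.

```python
from dataclasses import dataclass
from typing import Union

# Hypothetical toy model of interval bounding; these classes are NOT TVM APIs.


@dataclass(frozen=True)
class Interval:
    lo: int
    hi: int


@dataclass(frozen=True)
class Const:
    value: int


@dataclass(frozen=True)
class Ramp:
    # Vector index: base, base + stride, ..., base + stride * (lanes - 1)
    base: "Expr"
    stride: int
    lanes: int


Expr = Union[Const, Ramp]


def eval_interval(expr: Expr, eval_vec: bool = False) -> Interval:
    """Return a closed interval [lo, hi] containing every value of `expr`."""
    if isinstance(expr, Const):
        return Interval(expr.value, expr.value)
    if isinstance(expr, Ramp):
        # An evaluator not configured for vectors rejects Ramp indices,
        # analogous to the "(eval_vec_) is false" check in the trace.
        if not eval_vec:
            raise RuntimeError("Check failed: (eval_vec_) is false")
        base = eval_interval(expr.base, eval_vec)
        span = expr.stride * (expr.lanes - 1)  # offset of the last lane
        return Interval(base.lo + min(0, span), base.hi + max(0, span))
    raise TypeError(f"unsupported expression: {expr!r}")


# A two-lane store at element offset 0, like a uint8x2 vstore(0, ...).
index = Ramp(Const(0), stride=1, lanes=2)
print(eval_interval(index, eval_vec=True))  # Interval(lo=0, hi=1)
try:
    eval_interval(index)
except RuntimeError as err:
    print(err)  # Check failed: (eval_vec_) is false
```

With vector evaluation enabled, the two-lane index starting at 0 is bounded by `[0, 1]`. Without it, the model raises, which is analogous to the reported check failure.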


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] csullivan commented on issue #11716: [Bug] Crash with tensorize + prefetch after PR#11589

Posted by GitBox <gi...@apache.org>.
csullivan commented on issue #11716:
URL: https://github.com/apache/tvm/issues/11716#issuecomment-1155810243

   Thanks @kparzysz-quic, PTAL at the above PR. I noticed that on main this test always results in a null set being returned from the touched-domain computation, because the buffer annotated with prefetch scope doesn't exist within the IR. That's a separate issue, but I'm pointing it out in case it's relevant for your use case.
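The comment notes that the touched domain comes back as a null set when the prefetched buffer does not appear in the IR. The sketch below is a hypothetical pure-Python model of that behaviour, not TVM's `BufferTouchedDomain`. It unions per-dimension bounds over the recorded accesses to one buffer, so a buffer that is never accessed yields an empty region. `touched_domain` and the access-tuple layout are illustrative assumptions.

```python
from typing import List, Tuple

# Hypothetical toy model, not a TVM API: each access is
# (buffer_name, [(lo, hi) bound for each dimension]).
Bounds = List[Tuple[int, int]]
Access = Tuple[str, Bounds]


def touched_domain(accesses: List[Access], buffer: str) -> Bounds:
    """Union the per-dimension bounds of every access to `buffer`.

    Returns an empty list (a "null" region) if `buffer` is never accessed.
    """
    region: Bounds = []
    for name, bounds in accesses:
        if name != buffer:
            continue  # accesses to other buffers do not contribute
        if not region:
            region = list(bounds)
        else:
            region = [
                (min(lo, new_lo), max(hi, new_hi))
                for (lo, hi), (new_lo, new_hi) in zip(region, bounds)
            ]
    return region


# Example: the recorded accesses only touch "C", so looking up "A" is empty.
accesses = [("C", [(0, 31), (0, 7), (0, 127)]), ("C", [(0, 31), (0, 7), (128, 255)])]
print(touched_domain(accesses, "C"))  # [(0, 31), (0, 7), (0, 255)]
print(touched_domain(accesses, "A"))  # []
```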


[GitHub] [tvm] kparzysz-quic closed issue #11716: [Bug] Crash with tensorize + prefetch after PR#11589

Posted by GitBox <gi...@apache.org>.
kparzysz-quic closed issue #11716: [Bug] Crash with tensorize + prefetch after PR#11589
URL: https://github.com/apache/tvm/issues/11716


[GitHub] [tvm] areusch commented on issue #11716: [Bug] Crash with tensorize + prefetch after PR#11589

Posted by GitBox <gi...@apache.org>.
areusch commented on issue #11716:
URL: https://github.com/apache/tvm/issues/11716#issuecomment-1155776195

   #11589

