Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/06/22 17:36:15 UTC

[GitHub] [tvm] AndrewZhaoLuo edited a comment on issue #8294: CUDA support for mixed precision pass

AndrewZhaoLuo edited a comment on issue #8294:
URL: https://github.com/apache/tvm/issues/8294#issuecomment-866190408


   I'm just going to turn off accumulating to fp32 for now; I don't want to manually audit every schedule ever written to check for correctness.
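   For context on what turning off fp32 accumulation gives up: once an fp16 running sum grows large enough, its ulp exceeds the addend and the sum stalls. This can be sketched in plain Python using `struct`'s half-precision codec (an illustrative model, not TVM code; the function names are made up):

```python
import struct

def to_fp16(x: float) -> float:
    # Round a Python float to the nearest IEEE 754 binary16 value
    # via struct's half-precision format code 'e'.
    return struct.unpack('<e', struct.pack('<e', x))[0]

def sum_fp16_accumulator(values):
    # Keep the running sum in fp16: round after every addition,
    # as a kernel accumulating in half precision would.
    acc = 0.0
    for v in values:
        acc = to_fp16(acc + to_fp16(v))
    return acc

def sum_wide_accumulator(values):
    # Round only the inputs to fp16; accumulate in a wider type
    # (Python's float is fp64 here, standing in for an fp32 accumulator).
    acc = 0.0
    for v in values:
        acc += to_fp16(v)
    return acc

vals = [0.1] * 4096
# The fp16 accumulator stalls once its ulp (0.25 past 256) outweighs
# the ~0.1 addend, while the wide accumulator stays near the true sum.
narrow = sum_fp16_accumulator(vals)  # stalls around 256
wide = sum_wide_accumulator(vals)    # close to 409.5
```

   Whether that precision loss matters depends on the reduction sizes in the schedules being converted, which is exactly why auditing each one is unappealing.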
   
   With accumulation turned off, all the unit tests except one pass. The failing one is the problem described by @Lunderberg. It seems trickier, since I don't understand how CUDA codegen works at all:
   
   ```
   E               rt_mod_dev = codegen.build_module(mod_dev, target) if len(mod_dev.functions) != 0 else None
   E             File "/home/aluo/tvm/python/tvm/target/codegen.py", line 39, in build_module
   E               return _ffi_api.Build(mod, target)
   E             File "/home/aluo/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
   E               raise get_last_ffi_error()
   E             19: TVMFuncCall
   E             18: void tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::IRModule, tvm::Target)>::AssignTypedLambda<tvm::runtime::Module (*)(tvm::IRModule, tvm::Target)>(tvm::runtime::Module (*)(tvm::IRModule, tvm::Target), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
   E             17: tvm::codegen::Build(tvm::IRModule, tvm::Target)
   E             16: void tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::IRModule, tvm::Target)>::AssignTypedLambda<tvm::runtime::Module (*)(tvm::IRModule, tvm::Target)>(tvm::runtime::Module (*)(tvm::IRModule, tvm::Target), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
   E             15: tvm::codegen::BuildCUDA(tvm::IRModule, tvm::Target)
   E             14: tvm::codegen::CodeGenC::AddFunction(tvm::tir::PrimFunc const&)
   E             13: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const
   E             12: tvm::codegen::CodeGenCUDA::VisitStmt_(tvm::tir::AttrStmtNode const*)
   E             11: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::AttrStmtNode const*)
   E             10: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const
   E             9: tvm::codegen::CodeGenCUDA::VisitStmt_(tvm::tir::AttrStmtNode const*)
   E             8: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::AttrStmtNode const*)
   E             7: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const
   E             6: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::IfThenElseNode const*)
   E             5: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const
   E             4: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::StoreNode const*)
   E             3: tvm::codegen::CodeGenC::PrintExpr[abi:cxx11](tvm::PrimExpr const&)
   E             2: tvm::codegen::CodeGenC::PrintExpr(tvm::PrimExpr const&, std::ostream&)
   E             1: tvm::codegen::CodeGenCUDA::VisitExpr_(tvm::tir::BroadcastNode const*, std::ostream&)
   E             0: tvm::codegen::CodeGenCUDA::PrintType(tvm::runtime::DataType, std::ostream&)
   E             File "/home/aluo/tvm/src/target/source/codegen_cuda.cc", line 149
   E           TVMError: 
   E           ---------------------------------------------------------------
   E           An error occurred during the execution of TVM.
   E           For more information, please see: https://tvm.apache.org/docs/errors.html
   E           ---------------------------------------------------------------
   E           
   E             Check failed: lanes % 2 == 0 (1 vs. 0) : only support even lane for half type
   ```
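   The failed check (`lanes % 2 == 0 (1 vs. 0)`) means a half-precision vector with an odd lane count reached `PrintType`. CUDA's native vectorized fp16 type is `half2`, so the codegen can only print even lane counts. A rough sketch of that constraint (the function name and return strings are hypothetical, not TVM's actual `PrintType` logic):

```python
def cuda_half_vector_type(lanes: int) -> str:
    """Illustrative model of the constraint behind the failed check:
    CUDA packs vectorized fp16 into half2 pairs, so only even lane
    counts have a straightforward printable type."""
    if lanes == 1:
        return "half"  # scalar half is fine
    if lanes % 2 != 0:
        # mirrors: Check failed: lanes % 2 == 0 ...
        raise ValueError("only support even lane for half type")
    # an even-width vector can be emitted as lanes // 2 half2 elements
    # (the exact spelling TVM emits differs; this naming is made up)
    return "half2" if lanes == 2 else f"half2[{lanes // 2}]"
```

   So the fix presumably lives either in the codegen (teach it to print odd-lane half vectors) or upstream, by keeping the mixed precision pass or vectorizer from producing odd-lane fp16 broadcasts in the first place.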

