Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/06/04 16:09:19 UTC

[GitHub] [incubator-tvm] majiang31312 edited a comment on issue #5686: [vulkan] Assertion in tir/transforms/lower_thread_allreduce.cc", line 157 TVMError: Check failed: v:

majiang31312 edited a comment on issue #5686:
URL: https://github.com/apache/incubator-tvm/issues/5686#issuecomment-638953820


   The fix seems quite simple, but I'm not sure whether it's complete. 
   Please take a look at the Discussion section. Thanks! @tqchen @wpan11nv 
   
   Problem:
   When num_thread = 1 (which is the case for Vulkan, because CreateTarget in target.cc sets thread_warp_size to 1),
   ```
   ko, ki = s[B].split(B.op.reduce_axis[0], factor=num_thread)
   s[B].bind(ki, te.thread_axis("threadIdx.x"))
   ```
   triggers "TVMError: Check failed: v:" in MakeAllreduce.
   With factor=1 the inner axis ki has extent 1, so the Simplify pass replaces the IterVar with a constant node, but MakeAllreduce expects a Var node.
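For intuition, here is a small pure-Python model (no TVM required) of how splitting a reduction axis maps indices. It assumes the split semantics `k = ko * factor + ki` with `0 <= ki < factor`; `split_ranges` is a hypothetical helper for illustration, not a TVM API.

```python
import math


def split_ranges(extent, factor):
    """Hypothetical model of splitting an axis of `extent` by `factor`.

    Returns the ranges of the outer (ko) and inner (ki) axes, assuming
    k = ko * factor + ki with 0 <= ki < factor.
    """
    outer = range(math.ceil(extent / factor))
    inner = range(factor)
    return outer, inner


m = 32
for factor in (4, 1):
    ko, ki = split_ranges(m, factor)
    # Every original index k in [0, m) is covered exactly once.
    covered = sorted(o * factor + i for o in ko for i in ki if o * factor + i < m)
    assert covered == list(range(m))
    print(f"factor={factor}: ko extent={len(ko)}, ki values={list(ki)}")
```

With `factor=1` the inner axis can only ever be 0, which is why Simplify may legally substitute the constant 0 for the bound thread variable, leaving an `IntImm` where MakeAllreduce expects a `Var`.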
   
   
   Reproduce:
   ```
   import tvm
   from tvm import te
   
   n, m = 32, 32
   num_thread = 1  # split factor 1 => inner reduce axis of extent 1
   A = te.placeholder((n, m), name="A", dtype="int8")
   k = te.reduce_axis((0, m), "k")
   B = te.compute((n,), lambda i: te.sum(A[i, k], axis=[k]), name="B")
   
   s = te.create_schedule(B.op)
   ko, ki = s[B].split(B.op.reduce_axis[0], factor=num_thread)
   # Bind the extent-1 inner reduce axis to a thread axis (cross-thread reduction).
   s[B].bind(ki, te.thread_axis("threadIdx.z"))
   
   # target = tvm.target.create("vulkan")  # target from the original report
   target = tvm.target.create("cuda")
   mod = tvm.lower(s, [A, B])
   mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
   # Simplify folds the extent-1 thread variable into the constant 0.
   mod = tvm.tir.transform.Simplify()(mod)
   print(mod)
   # Triggers "Check failed: v:" in MakeAllreduce.
   mod = tvm.tir.transform.LowerThreadAllreduce()(mod)
   ```
   
   Fix:
   ```
   --- a/src/tir/transforms/lower_thread_allreduce.cc
   +++ b/src/tir/transforms/lower_thread_allreduce.cc
   @@ -154,9 +154,17 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
        std::unordered_set<const VarNode*> reduce_set;
        for (size_t i = 2 + 2 * size; i < call->args.size(); ++i) {
          const VarNode* v = call->args[i].as<VarNode>();
   -      CHECK(v);
   -      reduce_set.insert(v);
   +      // The Simplify pass replaces an iteration variable with a constant
   +      // when the extent of the iteration is 1. As a threaded IterVar always starts from 0,
   +      // we can just ignore this variable in this case.
   +      if (v) {
   +        reduce_set.insert(v);
   +      } else {
   +        CHECK(call->args[i].as<IntImmNode>() && call->args[i].as<IntImmNode>()->value == 0)
   +          << "arg " << i << " should be a VarNode or an IntImmNode equal to 0";
   +      }
        }
   +
        size_t nmatch = 0;
        std::vector<ThreadEntry> vred, vpar;
        for (const AttrStmtNode* attr : thread_extents_) {
   @@ -170,6 +178,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
            const auto* ptr = attr->value.as<IntImmNode>();
            CHECK(ptr) << "Need constant extent for reduce set " << iv;
            e.extent = static_cast<int>(ptr->value);
   +        // Skip thread axes of extent 1: their index is always 0.
   +        if (e.extent == 1) {
   +          continue;
   +        }
   +
            if (reduce_set.count(iv->var.get())) {
              vred.push_back(e);
              ++nmatch;
   ```
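The two hunks above can be modeled in a few lines of Python. This is only a hedged sketch: it represents a `VarNode` as a Python `str` (the variable name) and an `IntImmNode` as an `int`, and the helpers `collect_reduce_set` and `usable_thread_extents` are hypothetical names that mirror the patched C++ logic rather than any TVM API.

```python
def collect_reduce_set(reduce_args):
    """Model of the first hunk: collect reduction thread variables.

    A str stands in for a VarNode and an int for an IntImmNode. A constant is
    accepted only if it is 0 (an extent-1 thread axis folded away by Simplify).
    """
    reduce_set = set()
    for i, arg in enumerate(reduce_args):
        if isinstance(arg, str):
            reduce_set.add(arg)
        elif isinstance(arg, int) and arg == 0:
            continue  # folded extent-1 thread axis: nothing to reduce over
        else:
            raise ValueError(f"arg {i} should be a VarNode or an IntImmNode equal to 0")
    return reduce_set


def usable_thread_extents(thread_extents):
    """Model of the second hunk: drop thread axes whose extent is 1."""
    return {name: ext for name, ext in thread_extents.items() if ext != 1}


# threadIdx.z was split with factor=1 and simplified to the constant 0.
print(collect_reduce_set([0]))  # set()
print(collect_reduce_set(["threadIdx.x", 0]))  # {'threadIdx.x'}
print(usable_thread_extents({"threadIdx.x": 32, "threadIdx.z": 1}))  # {'threadIdx.x': 32}
```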
   
   Discussion:
     Currently a threaded IterVar always starts from 0, so we can safely ignore the constant node.
     Maybe we could keep a record somewhere after we replace a VarNode with an IntImmNode? I think that would help handle such cases more clearly.
     By the way, the 'analyzer_.Simplify' call in BufIndex does not work as expected. It looks like the analyzer has not been initialized properly. I can provide test cases if someone wants to take a look.
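One possible shape for the "keep a record" idea in the Discussion, sketched in plain Python under loud assumptions: `SubstitutionRecord` and `fold_unit_extent_axes` are hypothetical names, TVM has no such API here, and a real version would live inside the C++ simplifier. The point is only that a later consumer could then tell an intentionally folded thread axis apart from a malformed argument, instead of inferring it from the constant's value.

```python
class SubstitutionRecord:
    """Hypothetical record of variables that a simplifier replaced by constants."""

    def __init__(self):
        self._folded = {}

    def record(self, var_name, value):
        # Remember that `var_name` was folded into the constant `value`.
        self._folded[var_name] = value

    def folded_value(self, var_name):
        # Return the constant `var_name` was folded into, or None if it was not.
        return self._folded.get(var_name)


def fold_unit_extent_axes(extents, record):
    """Hypothetical simplifier step: fold every axis of extent 1 to 0 and record it."""
    kept = {}
    for name, ext in extents.items():
        if ext == 1:
            record.record(name, 0)
        else:
            kept[name] = ext
    return kept


rec = SubstitutionRecord()
remaining = fold_unit_extent_axes({"threadIdx.x": 32, "threadIdx.z": 1}, rec)
print(remaining)  # {'threadIdx.x': 32}
print(rec.folded_value("threadIdx.z"))  # 0
print(rec.folded_value("threadIdx.x"))  # None
```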
   

