Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/12/11 18:10:26 UTC

[GitHub] [tvm] kk2049 opened a new issue #9715: [Bug] Te.gradient not work with complex forward workload

kk2049 opened a new issue #9715:
URL: https://github.com/apache/tvm/issues/9715


   ### My problem
   I am trying to use the auto-scheduler to generate CUDA source code for the backward stage of an NCHW winograd_conv2d. Because of some bugs in topi.cuda.conv2d_winograd.winograd_cuda, I copied part of its code to build my own workload.
   
   Without te.gradient, this workload works and successfully generates source code for the forward stage. But when I add te.gradient, the workload no longer works and I get the error message below: `Check failed: (!repl_op.same_as(s->op)) is false: Cannot find Tensor(shape=[4, 2], op.name=A) in the inputs of compute(extracted_tensor.d.shared, ......`
   
   I am really confused now. The fact that forward-stage codegen works suggests that my workload is at least partly correct, so I suspect this is caused by a bug in TVM, but I am not sure.
   
   Could someone help me find out whether this is a TVM bug? 
   
   Thanks a lot!!!
   
   ### Expected behavior
   
   This code should find a valid schedule.
   
   ### Actual behavior
   
   I get the error below when I start tuning.
   ```
   Get devices for measurement successfully!
   ----------------------------------------------------------------------
   ------------------------------  [ Search ]
   ----------------------------------------------------------------------
   Traceback (most recent call last):
     File "bug_scheduler.py", line 189, in <module>
       task.tune(tune_option)
     File "/data/anaconda3/envs/env3.7/lib/python3.7/site-packages/tvm-0.8.0-py3.7-linux-x86_64.egg/tvm/auto_scheduler/search_task.py", line 498, in tune
       _ffi_api.AutoSchedule(search_policy, tuning_options)
     File "/data/anaconda3/envs/env3.7/lib/python3.7/site-packages/tvm-0.8.0-py3.7-linux-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
       raise get_last_ffi_error()
   tvm._ffi.base.TVMError: Traceback (most recent call last):
     13: TVMFuncCall
     12: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::runtime::Array<tvm::runtime::ObjectRef, void> (tvm::auto_scheduler::SearchPolicy, tvm::auto_scheduler::TuningOptions)>::AssignTypedLambda<tvm::auto_scheduler::{lambda(tvm::auto_scheduler::SearchPolicy, tvm::auto_scheduler::TuningOptions)#3}>(tvm::auto_scheduler::{lambda(tvm::auto_scheduler::SearchPolicy, tvm::auto_scheduler::TuningOptions)#3}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
     11: tvm::auto_scheduler::AutoSchedule(tvm::auto_scheduler::SearchPolicy, tvm::auto_scheduler::TuningOptions)
     10: tvm::auto_scheduler::SketchPolicyNode::Search(int, int, int, tvm::auto_scheduler::ProgramMeasurer)
     9: tvm::auto_scheduler::SketchPolicyNode::SearchOneRound(int, tvm::runtime::Array<tvm::auto_scheduler::State, void>*)
     8: tvm::auto_scheduler::SketchPolicyNode::GenerateSketches()
     7: tvm::auto_scheduler::RuleAddCacheRead::Apply(tvm::auto_scheduler::SketchPolicyNode const&, tvm::auto_scheduler::State const&, int) const
     6: tvm::auto_scheduler::State::cache_read(int, tvm::runtime::String const&, tvm::runtime::Array<tvm::Integer, void> const&, tvm::auto_scheduler::ComputeDAG const&)
     5: tvm::auto_scheduler::CacheReadStepNode::ApplyToState(tvm::auto_scheduler::State*, tvm::auto_scheduler::ComputeDAG const&) const
     4: tvm::auto_scheduler::ComputeDAG::ReplayAndGetDAG(tvm::runtime::Array<tvm::auto_scheduler::Step, void> const&) const
     3: tvm::auto_scheduler::ComputeDAG::ApplySteps(tvm::runtime::Array<tvm::auto_scheduler::Step, void> const&, tvm::runtime::Array<tvm::te::Stage, void>*, tvm::runtime::Map<tvm::te::Stage, tvm::runtime::Array<tvm::tir::IterVar, void>, tvm::runtime::ObjectHash, tvm::runtime::ObjectEqual>*, tvm::auto_scheduler::LayoutRewriteOption) const
     2: tvm::auto_scheduler::StepApplyToSchedule(tvm::auto_scheduler::Step const&, tvm::runtime::Array<tvm::te::Stage, void>*, tvm::runtime::Map<tvm::te::Stage, tvm::runtime::Array<tvm::tir::IterVar, void>, tvm::runtime::ObjectHash, tvm::runtime::ObjectEqual>*, tvm::te::Schedule*, tvm::runtime::Array<tvm::auto_scheduler::Step, void> const&)
     1: tvm::auto_scheduler::CacheReadStepNode::ApplyToSchedule(tvm::runtime::Array<tvm::te::Stage, void>*, tvm::runtime::Map<tvm::te::Stage, tvm::runtime::Array<tvm::tir::IterVar, void>, tvm::runtime::ObjectHash, tvm::runtime::ObjectEqual>*, tvm::te::Schedule*) const
     0: tvm::te::Schedule::cache_read(tvm::te::Tensor const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::Array<tvm::te::Operation, void> const&)
     File "/data/apache-tvm-src-v0.8.0.rc0/src/te/schedule/schedule_dataflow_rewrite.cc", line 168
   TVMError: 
   ---------------------------------------------------------------
   An error occurred during the execution of TVM.
   For more information, please see: https://tvm.apache.org/docs/errors.html
   ---------------------------------------------------------------
     Check failed: (!repl_op.same_as(s->op)) is false: Cannot find Tensor(shape=[4, 2], op.name=A) in the inputs of compute(extracted_tensor.d.shared, body=[extracted_tensor[ax0, ax1, ax2, ax3]], axis=[iter_var(ax0, range(min=0, ext=2)), iter_var(ax1, range(min=0, ext=2)), iter_var(ax2, range(min=0, ext=4)), iter_var(ax3, range(min=0, ext=4))], reduce_axis=[], tag=, attrs={})
   
   ```
   ### Environment
   
   My system is Ubuntu 16.04.
   My CUDA version is 10.2.
   My TVM version is 0.8.0, built from the source code on the Download Apache TVM Source Code web page.
   
   ### Steps to reproduce
   
   Sorry for posting such a long script. I tried to cut parts of it out to get a smaller reproduction, but the error is only triggered by this full version.
   ``` python
   import os
   
   import numpy as np
   import tvm
   from tvm import auto_scheduler
   
   import logging
   from tvm import te, topi
   from tvm import autotvm
   
   from tvm.topi import nn
   from tvm.topi.utils import get_const_int, get_const_tuple, traverse_inline
   from tvm.topi.nn.winograd_util import winograd_transform_matrices
   from tvm.topi.nn.conv2d import conv2d_winograd_nhwc, _conv2d_winograd_nhwc_impl
   import sys
   from tvm.topi.testing import conv2d_nchw_python
   
   def _infer_tile_size(data, kernel, layout="NCHW"):
       if layout == "NCHW":
           N, CI, H, W = get_const_tuple(data.shape)
       else:
           assert layout == "NHWC"
           N, H, W, CI = get_const_tuple(data.shape)
   
       if H % 8 == 0:
           return 4
       return 2
   
   @auto_scheduler.register_workload
   def conv2d_layer(N, H, W, CO, CI, KH, KW, stride, padding):
       data = te.placeholder((N, CI, H, W), name="data")
       kernel = te.placeholder((CO, CI, KH, KW), name="kernel")
   
       # NOTE: the stride and padding arguments are overridden with fixed values here
       stride = (1,1)
       padding = (1,1)
       dilation = (1,1)
       pre_computed = False
       out_dtype = "float32"
   
       tile_size = _infer_tile_size(data, kernel)
       N, CI, H, W = get_const_tuple(data.shape)
   
       if isinstance(N, tvm.tir.Any):
           N = tvm.te.size_var("n")
   
       if not isinstance(H, int) or not isinstance(W, int):
           raise RuntimeError(
               "cuda winograd conv2d doesn't support dynamic input height or width."
           )
   
       if isinstance(dilation, int):
           dilation_h = dilation_w = dilation
       else:
           dilation_h, dilation_w = dilation
       HSTR, WSTR = (stride, stride) if isinstance(stride, int) else stride
   
       if not pre_computed:  # kernel tensor is raw tensor, do strict check
           if dilation_h != 1 or dilation_w != 1:
               kernel = nn.dilate(kernel, (1, 1, dilation_h, dilation_w))
           CO, CI, KH, KW = get_const_tuple(kernel.shape)
           alpha = KW + tile_size - 1
           assert HSTR == 1 and WSTR == 1 and KH == KW
       else:
           # kernel tensor is pre-transformed. this op is created by alter op layout.
           # dilation is not supported
           alpha, _, CI, CO = get_const_tuple(kernel.shape)
           KH = KW = alpha + 1 - tile_size
           assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1
   
       pt, pl, pb, pr = nn.get_pad_tuple(padding, (KH, KW))
       data_pad = nn.pad(data, (0, 0, pt, pl), (0, 0, pb, pr), name="data_pad")
   
       r = KW
       m = tile_size
       A, B, G = winograd_transform_matrices(m, r, out_dtype)
   
       H = (H + pt + pb - KH) // HSTR + 1
       W = (W + pl + pr - KW) // WSTR + 1
       nH, nW = (H + m - 1) // m, (W + m - 1) // m
   
       P = N * nH * nW if isinstance(N, int) else nH * nW
   
       # transform kernel
       if not pre_computed:
           r_kh = te.reduce_axis((0, KH), name="r_kh")
           r_kw = te.reduce_axis((0, KW), name="r_kw")
           kernel_pack = te.compute(
               (alpha, alpha, CI, CO),
               lambda eps, nu, ci, co: te.sum(
                   kernel[co][ci][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]
               ),
               name="my_kernel_pack",
           )
       else:
           kernel_pack = kernel    
       
       idxdiv = tvm.tir.indexdiv
       idxmod = tvm.tir.indexmod
       # pack input tile
       input_tile = te.compute(
           (CI, P, alpha, alpha),
           lambda c, p, eps_1, nu_1: data_pad[idxdiv(p, (nH * nW))][c][
               idxmod(idxdiv(p, nW), nH) * m + eps_1
           ][idxmod(p, nW) * m + nu_1],
           name="my_d",
       )
   
       # dy = tvm.te.placeholder(input_tile.shape, name="input2_dy")
       # [dw] = tvm.te.gradient(input_tile, [data], head=dy)
       # return [data, kernel, input_tile, dy, dw]
   
       # transform data
       r_a = te.reduce_axis((0, alpha), "r_a")
       r_b = te.reduce_axis((0, alpha), "r_b")
       data_pack = te.compute(
           (alpha, alpha, CI, P),
           lambda eps, nu, ci, p: te.sum(
               input_tile[ci][p][r_a][r_b] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b]
           ),
           name="my_data_pack",
       )
   
       # dy = tvm.te.placeholder(data_pack.shape, name="input2_dy")
       # [dw] = tvm.te.gradient(data_pack, [data], head=dy)
       # return [data, kernel, data_pack, dy, dw]
   
       # do batch gemm
       ci = te.reduce_axis((0, CI), name="ci")
       bgemm = te.compute(
           (alpha, alpha, CO, P),
           lambda eps, nu, co, p: te.sum(
               kernel_pack[eps][nu][ci][co] * data_pack[eps][nu][ci][p], axis=[ci]
           ),
           name="my_bgemm",
       )
       # inverse transform
       r_a_2 = te.reduce_axis((0, alpha), "r_a_2")
       r_b_2 = te.reduce_axis((0, alpha), "r_b_2")
       inverse = te.compute(
           (CO, P, m, m),
           lambda co, p, vh, vw: te.sum(
               bgemm[r_a_2][r_b_2][co][p] * A[r_a_2][vh] * A[r_b_2][vw], axis=[r_a_2, r_b_2]
           ),
           name="my_inverse",
       )
   
       # output
       output = te.compute(
           (N, CO, H, W),
           lambda n, co, h, w: inverse[
               co, n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m), idxmod(h, m), idxmod(w, m)
           ],
           name="my_output",
           tag="conv2d_nchw_winograd",
       )
       
       dy = tvm.te.placeholder(output.shape, name="input2_dy")
       [dw] = tvm.te.gradient(output, [data], head=dy)
       return [data, kernel, output, dy, dw]
       # return [data, kernel, output]
   
   target = tvm.target.Target("cuda")
   
   # Use the last layer in ResNet-50
   N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
   task = auto_scheduler.SearchTask(
       func=conv2d_layer, args=(N, H, W, CO, CI, KH, KW, strides, padding), target=target
   )
   
   # Inspect the computational graph
   print("Computational DAG:")
   print(task.compute_dag)
   
   log_file = "conv2d.json"
   if os.path.exists(log_file):
       os.remove(log_file)
   measure_ctx = auto_scheduler.LocalRPCMeasureContext(min_repeat_ms=300)
   tune_option = auto_scheduler.TuningOptions(
       num_measure_trials=10,  # change this to 1000 to achieve the best performance
       runner=measure_ctx.runner,
       measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
       verbose=2,
   )
   
   # Run auto-tuning (search)
   task.tune(tune_option)
   # Apply the best schedule
   sch, args = task.apply_best(log_file)
   
   
   ```
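For reference, the forward part of the script is a Winograd F(m, r) convolution. With H = W = 7, `_infer_tile_size` returns m = 2, and with r = KW = 3 this gives alpha = KW + tile_size - 1 = 4, consistent with the [4, 2] shape reported for `A` in the error message. The numpy sketch below checks a single 4x4 tile of F(2, 3) against direct cross-correlation, using the same index layout as the TE stages (`my_kernel_pack`, `my_data_pack`, `my_bgemm`, `my_inverse`). The matrices are the textbook Lavin-Gray F(2, 3) matrices, which is an assumption: TVM's `winograd_transform_matrices` may generate different signs or scaling, but the stage structure is the same.

```python
import numpy as np

# Textbook Winograd F(2, 3) matrices (Lavin & Gray), an assumption here: TVM's
# winograd_transform_matrices(2, 3, ...) may differ in signs/scaling.
BT = np.array([[1, 0, -1, 0],
               [0, 1, 1, 0],
               [0, -1, 1, 0],
               [0, 1, 0, -1]], dtype=np.float64)
G = np.array([[1.0, 0.0, 0.0],
              [0.5, 0.5, 0.5],
              [0.5, -0.5, 0.5],
              [0.0, 0.0, 1.0]])
AT = np.array([[1, 1, 1, 0],
               [0, 1, -1, -1]], dtype=np.float64)
# Store them the way the TE code indexes them: B[r][eps], G[eps][k], A[r][vh].
B, A = BT.T, AT.T  # shapes (4, 4) and (4, 2)


def winograd_f23_tile(d, g):
    """Map one 4x4 input tile d and one 3x3 kernel g to a 2x2 output tile."""
    U = G @ g @ G.T  # like my_kernel_pack: sum g[kh, kw] * G[eps, kh] * G[nu, kw]
    V = B.T @ d @ B  # like my_data_pack: sum d[ra, rb] * B[ra, eps] * B[rb, nu]
    M = U * V        # like my_bgemm, with a single input channel (no sum over ci)
    return A.T @ M @ A  # like my_inverse: sum M[ra, rb] * A[ra, vh] * A[rb, vw]


def direct_tile(d, g):
    """Reference: direct 3x3 cross-correlation of the tile (stride 1, no padding)."""
    return np.array([[np.sum(d[i:i + 3, j:j + 3] * g) for j in range(2)] for i in range(2)])


rng = np.random.default_rng(0)
d = rng.standard_normal((4, 4))
g = rng.standard_normal((3, 3))
print(np.allclose(winograd_f23_tile(d, g), direct_tile(d, g)))  # True
```

Summing `M` over input channels before the inverse transform is what turns the elementwise product into the batched GEMM (`my_bgemm`) in the script.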
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] kk2049 commented on issue #9715: [Bug] Te.gradient not work with complex forward workload

kk2049 commented on issue #9715:
URL: https://github.com/apache/tvm/issues/9715#issuecomment-995867087


   @comaniac Sorry to bother you. (I really appreciated your help with `te.gradient` a few months ago in #8991.) Could you help me again with this problem? I am confused by this bug and have no idea how to fix it. Thanks a lot!!





[GitHub] [tvm] comaniac commented on issue #9715: [Bug] Te.gradient not work with complex forward workload

comaniac commented on issue #9715:
URL: https://github.com/apache/tvm/issues/9715#issuecomment-996055035


   It looks like the auto-scheduler has issues generating the schedule sketch for this workload. You could first try to build and run this workload on CPU without tuning to see whether we can identify the problem. If that doesn't work, then something must be wrong with the workload or with te.gradient. Otherwise, we could investigate the compute DAG to see why the auto-scheduler fails on this workload generated by te.gradient.





[GitHub] [tvm] kk2049 commented on issue #9715: [Bug] Te.gradient not work with complex forward workload

kk2049 commented on issue #9715:
URL: https://github.com/apache/tvm/issues/9715#issuecomment-997443436


   @comaniac Thanks for your reply! I ran this workload with `tvm.target.Target("llvm")` and it launched successfully. So I switched back to `tvm.target.Target("cuda")` and printed the compute DAG. It looks like this:
   ```
   Computational DAG:
   kernel = PLACEHOLDER [512, 512, 3, 3]
   G(i, j) = select(((floormod(i, 4) == 3) && (floormod(j, 3) == 2)), 1f, select(((floormod(i, 4) == 3) && (floormod(j, 3) == 1)),  ..(OMITTED).. (floormod(i, 4) == 0) && (floormod(j, 3) == 1)), 0f, select(((floormod(i, 4) == 0) && (floormod(j, 3) == 0)), 1f, 0f))))))))))))
   my_kernel_pack(eps, nu, ci, co) += ((kernel[co, ci, r_kh, r_kw]*G[eps, r_kh])*G[nu, r_kw])
   data = PLACEHOLDER [1, 512, 7, 7]
   data_pad(i0, i1, i2, i3) = tir.if_then_else(((((i2 >= 1) && (i2 < 8)) && (i3 >= 1)) && (i3 < 8)), data[i0, i1, (i2 - 1), (i3 - 1)], 0f)
   my_d(c, p, eps_1, nu_1) = data_pad[floordiv(p, 16), c, ((floormod(floordiv(p, 4), 4)*2) + eps_1), ((floormod(p, 4)*2) + nu_1)]
   B(i, j) = select(((floormod(i, 4) == 3) && (floormod(j, 4) == 3)), 1f, select(((floormod(i, 4) == 3) && (floormod(j, 4) == 2)),  ..(OMITTED).. ormod(i, 4) == 0) && (floormod(j, 4) == 1)), 0f, select(((floormod(i, 4) == 0) && (floormod(j, 4) == 0)), 1f, 0f))))))))))))))))
   my_data_pack(eps, nu, ci, p) += ((my_d[ci, p, r_a, r_b]*B[r_a, eps])*B[r_b, nu])
   my_bgemm(eps, nu, co, p) += (my_kernel_pack[eps, nu, ci, co]*my_data_pack[eps, nu, ci, p])
   A(i, j) = select(((floormod(i, 4) == 3) && (floormod(j, 2) == 1)), 1f, select(((floormod(i, 4) == 3) && (floormod(j, 2) == 0)),  ..(OMITTED).. ct(((floormod(i, 4) == 0) && (floormod(j, 2) == 1)), 0f, select(((floormod(i, 4) == 0) && (floormod(j, 2) == 0)), 1f, 0f))))))))
   my_inverse(co, p, vh, vw) += ((my_bgemm[r_a_2, r_b_2, co, p]*A[r_a_2, vh])*A[r_b_2, vw])
   my_output(n, co, h, w) = my_inverse[co, ((((n*4)*4) + (floordiv(h, 2)*4)) + floordiv(w, 2)), floormod(h, 2), floormod(w, 2)]
   input2_dy = PLACEHOLDER [1, 512, 7, 7]
   my_output.my_inverse.grad(ax0, ax1, ax2, ax3) = select((((((((ax2*4) + (floordiv((7 + (ax1*-2)), 8)*-8)) <= 24) && (((ax1*-2) +  ..(OMITTED).. ) <= 15)), input2_dy[0, ax0, (ax2 + (floordiv((7 + (ax1*-2)), 8)*-2)), (((floordiv((7 + (ax1*-2)), 8)*8) + (ax1*2)) + ax3)], 0f)
   extracted_tensor(n0_n0_vh.shifted.shifted, n1_n1_vw.shifted.shifted, n2_n2_jac_i0.shifted.shifted, n3_n3_jac_i1.shifted.shifted) = (A[n2_n2_jac_i0.shifted.shifted, n0_n0_vh.shifted.shifted]*A[n3_n3_jac_i1.shifted.shifted, n1_n1_vw.shifted.shifted])
   my_inverse.my_bgemm.grad(ax0, ax1, ax2, ax3) += (my_output.my_inverse.grad[ax2, ax3, n0_n0_k2.shifted.shifted, n1_n1_k3.shifted.shifted]*extracted_tensor[n0_n0_k2.shifted.shifted, n1_n1_k3.shifted.shifted, ax0, ax1])
   my_bgemm.my_data_pack.grad(ax0, ax1, ax2, ax3) += (my_inverse.my_bgemm.grad[ax0, ax1, n0_n0_k2.shifted.shifted, ax3]*my_kernel_pack[ax0, ax1, ax2, n0_n0_k2.shifted.shifted])
   extracted_tensor(n0_n0_eps.shifted.shifted, n1_n1_nu.shifted.shifted, n4_n4_jac_i2.shifted.shifted, n5_n5_jac_i3.shifted.shifted) = (B[n4_n4_jac_i2.shifted.shifted, n0_n0_eps.shifted.shifted]*B[n5_n5_jac_i3.shifted.shifted, n1_n1_nu.shifted.shifted])
   my_data_pack.my_d.grad(ax0, ax1, ax2, ax3) += (my_bgemm.my_data_pack.grad[n0_n0_k0.shifted.shifted, n1_n1_k1.shifted.shifted, ax0, ax1]*extracted_tensor[n0_n0_k0.shifted.shifted, n1_n1_k1.shifted.shifted, ax2, ax3])
   data_pad.data.grad(ax0, ax1, ax2, ax3) += my_data_pack.my_d.grad[ax1, (((((floordiv((ax2 + 1), 2) + n0_n0_fdiv1.shifted.shifted) ..(OMITTED).. ormod((ax2 + 1), 2) + (n0_n0_fdiv1.shifted.shifted*-2)) + 2), ((floormod((ax3 + 1), 2) + (n1_n1_fmod1.shifted.shifted*-2)) + 2)]
   ```
   I tried to inspect this DAG myself but could not find anything useful. Maybe you can spot something in it?
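One hedged way to scan a long printed DAG like this is to collect the stage name at the start of each line and flag names that are defined more than once. In the DAG above, for instance, `extracted_tensor` is defined twice (once from `A`, once from `B`); whether that has anything to do with the `cache_read` failure is not established here. The helper below is a hypothetical, text-only sketch that works on the output of `print(task.compute_dag)` and does not call any TVM API.

```python
import re
from collections import Counter

# Matches the stage name at the start of a printed ComputeDAG line, e.g.
#   "kernel = PLACEHOLDER [512, 512, 3, 3]"  -> "kernel"
#   "my_bgemm(eps, nu, co, p) += ..."        -> "my_bgemm"
#   "my_output.my_inverse.grad(ax0, ...) = " -> "my_output.my_inverse.grad"
_STAGE_RE = re.compile(r"^\s*([A-Za-z_][\w.]*)\s*(?:\(|\+?=)")


def duplicate_stage_names(dag_text):
    """Return the sorted stage names that appear more than once in a printed DAG."""
    names = []
    for line in dag_text.splitlines():
        match = _STAGE_RE.match(line)
        if match:
            names.append(match.group(1))
    counts = Counter(names)
    return sorted(name for name, count in counts.items() if count > 1)


dag_excerpt = """Computational DAG:
A(i, j) = select(...)
B(i, j) = select(...)
extracted_tensor(a, b, c, d) = (A[c, a]*A[d, b])
my_inverse.my_bgemm.grad(ax0, ax1, ax2, ax3) += (...)
extracted_tensor(a, b, c, d) = (B[c, a]*B[d, b])
"""
print(duplicate_stage_names(dag_excerpt))  # ['extracted_tensor']
```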
   
   Thanks a lot for your help!!!

