Posted to commits@tvm.apache.org by "cxx122 (via GitHub)" <gi...@apache.org> on 2023/07/02 12:48:31 UTC

[GitHub] [tvm] cxx122 opened a new issue, #15201: [Bug] Error reporting “Detected bound for max(2, ax0)conflicts with memorization” after schedule

cxx122 opened a new issue, #15201:
URL: https://github.com/apache/tvm/issues/15201

   The TE program below runs without any problems before schedule optimization. After the auto-scheduler's schedule is applied, however, it fails with the error message `Check failed: ((val->second->min_value == res.min_value && val->second->max_value == res.max_value) || (val->second->min_value == everything.min_value && val->second->max_value == everything.max_value)) is false: Detected bound for max(2, ax0)conflicts with memorization`
   
   ```python
   from tvm import te, tir
   
   def te_test():
       A_1 = te.placeholder([1, 3, 224], name='A')
       data_dilate = te.compute([1, 3, 447], lambda i0, i1, i2 : tir.if_then_else((te.floormod(i2, 2) == 0), A_1[i0, i1, te.floordiv(i2, 2)], tir.const(0, dtype="float32")), name='data_dilate')
       data_pad = te.compute([1, 3, 454], lambda i0, i1, i2 : tir.if_then_else(tir.And((i2 >= 3) ,  (i2 < 450)), data_dilate[i0, i1, (i2 - 3)], tir.const(0, dtype="float32")), name='data_pad')
       W_2 = te.placeholder([3, 32, 5], name='W')
       kernel = te.compute([32, 3, 5], lambda o, i, w : W_2[i, o, (4 - w)], name='kernel')
       k = te.reduce_axis([0, 3], name='k')
       inline_tensor = te.compute([32, 3, 5], lambda ax0, ax1, ax2 : te.sum(tir.floor(kernel[k, ax1, ax2]) + kernel[ax0, ax1, ax2], axis=[k]), name='inline_tensor')
       k = te.reduce_axis([0, 3], name='k')
       fuse_tensor = te.compute([32, 3, 5], lambda ax0, ax1, ax2 : te.sum(tir.trunc(inline_tensor[k, ax1, ax2]) - inline_tensor[ax0, ax1, ax2], axis=[k]), name='fuse_tensor')
       dc = te.reduce_axis([0, 3], name='dc')
       dw = te.reduce_axis([0, 5], name='dw')
       compute = te.compute([1, 32, 450], lambda b, c, w : te.sum((data_pad[b, dc, (w + dw)]*fuse_tensor[c, dc, dw]), axis=[dc, dw]), name='compute')
       k = te.reduce_axis([0, 1], name='k')
       fuse_tensor = te.compute([1, 32, 450], lambda ax0, ax1, ax2 : te.max(tir.trunc((tir.const(0, dtype="float32") - tir.sqrt(compute[k, ax1, ax2]))), axis=[k]), name='fuse_tensor')
       return [A_1, W_2, fuse_tensor]
   ```
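For checking outputs independently of TVM, the unscheduled semantics of `te_test` can be written directly in NumPy. This is a hedged sketch and not part of TVM. The function name `te_test_reference` is hypothetical, and float32 is assumed throughout. The size-1 `te.max` reduction is treated as the identity on its single element. NaN handling (for example `sqrt` of a negative sum) is not claimed to match TVM. Summation order also differs from the TVM build, so any comparison should use a tolerance.

```python
import numpy as np


def te_test_reference(A: np.ndarray, W: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy sketch of te_test(); A is (1, 3, 224), W is (3, 32, 5)."""
    A = A.astype(np.float32)
    W = W.astype(np.float32)

    # data_dilate: A[..., i2 // 2] at even i2, zero at odd i2 (length 447).
    data_dilate = np.zeros((1, 3, 447), dtype=np.float32)
    data_dilate[:, :, 0::2] = A

    # data_pad: data_dilate[..., i2 - 3] for 3 <= i2 < 450, zero elsewhere (length 454).
    data_pad = np.zeros((1, 3, 454), dtype=np.float32)
    data_pad[:, :, 3:450] = data_dilate

    # kernel[o, i, w] = W[i, o, 4 - w]: swap the first two axes, reverse the last.
    kernel = np.transpose(W, (1, 0, 2))[:, :, ::-1]

    # inline_tensor[a] = sum_{k<3} (floor(kernel[k]) + kernel[a])
    inline_tensor = np.floor(kernel[:3]).sum(axis=0) + 3 * kernel

    # first fuse_tensor[a] = sum_{k<3} (trunc(inline_tensor[k]) - inline_tensor[a])
    fuse = np.trunc(inline_tensor[:3]).sum(axis=0) - 3 * inline_tensor

    # compute[0, c, w] = sum_{dc<3, dw<5} data_pad[0, dc, w + dw] * fuse[c, dc, dw]
    windows = np.stack([data_pad[0, :, dw:dw + 450] for dw in range(5)], axis=-1)
    compute = np.einsum("dwk,cdk->cw", windows, fuse)[np.newaxis]

    # Output: trunc(0 - sqrt(compute)); the size-1 te.max reduction is the identity here.
    with np.errstate(invalid="ignore"):  # sqrt of a negative sum yields NaN
        return np.trunc(np.float32(0) - np.sqrt(compute))
```

With all-ones inputs the two reductions cancel exactly, so the output is all zeros. That makes a convenient sanity check before comparing against a TVM build.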
   
   ### Expected behavior
   
   The scheduled program should build without errors, just as the unscheduled program does.
   
   ### Actual behavior
   
   `Check failed: ((val->second->min_value == res.min_value && val->second->max_value == res.max_value) || (val->second->min_value == everything.min_value && val->second->max_value == everything.max_value)) is false: Detected bound for max(2, ax0)conflicts with memorization`
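As printed, the failing check accepts a newly computed bound `res` for an expression only if it equals the memorized bound or if the memorized bound is the unconstrained `everything` range. The pure-Python sketch below mirrors just that condition. `BoundMemo`, `max_with_const`, the int64 stand-in limits for `everything`, and the `ax0` ranges are all hypothetical, and the sketch does not reproduce TVM's analyzer. It shows how visiting `max(2, ax0)` a second time with a different range for `ax0` yields the same message.

```python
from dataclasses import dataclass

# Assumption: stand-in limits for the unconstrained ("everything") bound.
NEG_INF = -(2**63)
POS_INF = 2**63 - 1


@dataclass(frozen=True)
class Bound:
    min_value: int
    max_value: int


EVERYTHING = Bound(NEG_INF, POS_INF)


class BoundMemo:
    """Hypothetical memo table; mirrors only the condition printed in the error."""

    def __init__(self):
        self._memo = {}

    def update(self, expr: str, res: Bound) -> None:
        old = self._memo.get(expr)
        # Accept if nothing is memorized yet, the bound is unchanged,
        # or the memorized bound was the unconstrained "everything" range.
        if old is not None and old != res and old != EVERYTHING:
            raise RuntimeError(f"Detected bound for {expr}conflicts with memorization")
        self._memo[expr] = res


def max_with_const(c: int, var: Bound) -> Bound:
    """Bound of max(c, v) for v in [var.min_value, var.max_value]."""
    return Bound(max(c, var.min_value), max(c, var.max_value))


memo = BoundMemo()
memo.update("max(2, ax0)", max_with_const(2, Bound(0, 7)))  # first visit: ax0 in [0, 7]
try:
    memo.update("max(2, ax0)", max_with_const(2, Bound(0, 31)))  # revisit: ax0 in [0, 31]
except RuntimeError as err:
    print(err)  # prints: Detected bound for max(2, ax0)conflicts with memorization
```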
   
   The lowered TIR for the failing schedule:
   
   ```
   @I.ir_module
   class Module:
       @T.prim_func
       def main(A: T.Buffer((1, 3, 224), "float32"), W: T.Buffer((3, 32, 5), "float32"), fuse_tensor: T.Buffer((1, 32, 450), "float32")):
           T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)})
           data_dilate = T.allocate([14400], "float32", "global")
           data_pad = T.allocate([1362], "float32", "global")
           data_dilate_1 = T.Buffer((1341,), data=data_dilate)
           for i0_i1_fused_i2_fused in T.parallel(1341):
               cse_var_1: T.int32 = i0_i1_fused_i2_fused % 447
               A_1 = T.Buffer((672,), data=A.data)
               data_dilate_1[i0_i1_fused_i2_fused] = T.if_then_else(cse_var_1 % 2 == 0, A_1[i0_i1_fused_i2_fused // 447 * 224 + cse_var_1 // 2], T.float32(0))
           data_pad_1 = T.Buffer((1362,), data=data_pad)
           for i0_i1_fused_i2_fused in T.parallel(1362):
               cse_var_2: T.int32 = i0_i1_fused_i2_fused % 454
               data_pad_1[i0_i1_fused_i2_fused] = T.if_then_else(3 <= cse_var_2 and cse_var_2 < 450, data_dilate_1[i0_i1_fused_i2_fused // 454 * 447 + cse_var_2 - 3], T.float32(0))
           data_dilate_2 = T.Buffer((14400,), data=data_dilate)
           for b_outer_c_outer_fused_w_outer_fused in T.parallel(4):
               inline_tensor = T.allocate([b_outer_c_outer_fused_w_outer_fused * 120 + 120], "float32", "global")
               ax0 = T.int32()
               kernel = T.allocate([T.max(2, ax0) + 1], "float32", "global")
               compute_local = T.allocate([3600], "float32", "local")
               fuse_tensor_1 = T.allocate([8], "float32", "global")
               inline_tensor_1 = T.Buffer(((b_outer_c_outer_fused_w_outer_fused % 4 * 8 + 8) * 3 * 5,), data=inline_tensor)
               for ax0, ax1, ax2 in T.grid(b_outer_c_outer_fused_w_outer_fused * 8 + 8, 3, 5):
                   cse_var_3: T.int32 = ax0 * 15 + ax1 * 5 + ax2
                   kernel_1 = T.Buffer((T.max(2, ax0) + 1,), data=kernel)
                   for o in range(T.max(2, ax0) + 1):
                       W_1 = T.Buffer((480,), data=W.data)
                       kernel_1[o] = W_1[ax1 * 160 + o * 5 + 4 - ax2]
                   inline_tensor_1[cse_var_3] = T.float32(0)
                   inline_tensor_1[cse_var_3] = inline_tensor_1[cse_var_3] + (T.floor(kernel_1[0]) + kernel_1[ax0])
                   inline_tensor_1[cse_var_3] = inline_tensor_1[cse_var_3] + (T.floor(kernel_1[1]) + kernel_1[ax0])
                   inline_tensor_1[cse_var_3] = inline_tensor_1[cse_var_3] + (T.floor(kernel_1[2]) + kernel_1[ax0])
               compute_local_1 = T.Buffer((3600,), data=compute_local, scope="local")
               for w_c_outer_outer_inner in range(5):
                   for c_c_outer_inner_init, w_c_outer_inner_init in T.grid(8, 10):
                       compute_local_1[c_c_outer_inner_init * 450 + w_c_outer_outer_inner * 90 + w_c_outer_inner_init * 9:c_c_outer_inner_init * 450 + w_c_outer_outer_inner * 90 + w_c_outer_inner_init * 9 + 9] = T.Broadcast(T.float32(0), 9)
                   for dc_outer, dw_outer in T.grid(3, 5):
                       fuse_tensor_2 = T.Buffer((8,), data=fuse_tensor_1, align=32)
                       for ax0_1 in range(8):
                           fuse_tensor_2[ax0_1] = T.float32(0)
                           for k in range(3):
                               cse_var_4: T.int32 = dc_outer * 5
                               fuse_tensor_2[ax0_1] = fuse_tensor_2[ax0_1] + (T.trunc(inline_tensor_1[k * 15 + cse_var_4 + dw_outer]) - inline_tensor_1[b_outer_c_outer_fused_w_outer_fused * 120 + ax0_1 * 15 + cse_var_4 + dw_outer])
                       for c_c_outer_inner, w_c_outer_inner in T.grid(8, 10):
                           cse_var_7: T.int32 = w_c_outer_outer_inner * 90
                           cse_var_6: T.int32 = w_c_outer_inner * 9
                           cse_var_5: T.int32 = c_c_outer_inner * 450 + cse_var_7 + cse_var_6
                           compute_local_1[cse_var_5:cse_var_5 + 9] = compute_local_1[cse_var_5:cse_var_5 + 9] + data_pad_1[dc_outer * 454 + cse_var_7 + cse_var_6 + dw_outer:dc_outer * 454 + cse_var_7 + cse_var_6 + dw_outer + 9] * T.Broadcast(fuse_tensor_2[c_c_outer_inner], 9)
               for c_inner, w_inner in T.grid(8, 450):
                   cse_var_8: T.int32 = c_inner * 450
                   data_dilate_2[b_outer_c_outer_fused_w_outer_fused * 3600 + cse_var_8 + w_inner] = compute_local_1[cse_var_8 + w_inner]
           for ax0_ax1_fused_ax2_fused in T.parallel(14400):
               fuse_tensor_1 = T.Buffer((14400,), data=fuse_tensor.data)
               fuse_tensor_1[ax0_ax1_fused_ax2_fused] = T.float32(-3.4028234663852886e+38)
               fuse_tensor_1[ax0_ax1_fused_ax2_fused] = T.max(fuse_tensor_1[ax0_ax1_fused_ax2_fused], T.trunc(T.float32(0) - T.sqrt(data_dilate_2[ax0_ax1_fused_ax2_fused])))
   ```
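In the lowered TIR above, the extent of the `kernel` allocation, `T.max(2, ax0) + 1`, refers to `ax0` outside the `T.grid` loop that binds it. That is why the printer declares `ax0` as a free variable (`ax0 = T.int32()`), and it plausibly explains why the bound analysis encounters `max(2, ax0)` under inconsistent ranges. The toy checker below is a hypothetical sketch over a hand-written miniature of the loop nest. Its `Allocate`/`For` classes are illustrative and are not TVM's IR. It flags this pattern: an allocation extent that uses a variable no enclosing loop has bound yet.

```python
from dataclasses import dataclass, field


@dataclass
class Allocate:
    """Toy allocation node: records which variables its extent uses."""
    name: str
    extent_vars: set


@dataclass
class For:
    """Toy loop node: binds `var` for the statements in `body`."""
    var: str
    extent_vars: set
    body: list = field(default_factory=list)


def find_unbound_uses(stmts, bound=frozenset()):
    """Return (node, var) pairs where an extent uses a variable not yet bound."""
    errors = []
    for stmt in stmts:
        # Extents are evaluated in the enclosing scope, before a loop binds its own var.
        for var in sorted(stmt.extent_vars - bound):
            label = stmt.name if isinstance(stmt, Allocate) else f"for {stmt.var}"
            errors.append((label, var))
        if isinstance(stmt, For):
            errors += find_unbound_uses(stmt.body, bound | {stmt.var})
    return errors


# Miniature of the loop nest in the dump; only the variables that matter are kept.
outer = "b_outer_c_outer_fused_w_outer_fused"
program = [
    For(outer, set(), body=[
        Allocate("inline_tensor", {outer}),  # extent outer * 120 + 120
        Allocate("kernel", {"ax0"}),         # extent T.max(2, ax0) + 1
        For("ax0", {outer}, body=[           # extent outer * 8 + 8
            For("o", {"ax0"}),               # extent T.max(2, ax0) + 1
        ]),
    ]),
]

print(find_unbound_uses(program))  # prints [('kernel', 'ax0')]
```

Moving the toy `kernel` allocation inside the `ax0` loop makes the checker return an empty list. That only illustrates the scoping rule and is not a proposed fix for the auto-scheduler.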
   
   ### Environment
   
   * Operating System: Ubuntu 18.04
   * TVM version: v0.10.dev0
   
   ### Steps to reproduce
   
   ```python
   import tvm
   import json
   import numpy as np
   from tvm import te
   from tvm import tir
   from tvm import testing
   from tvm import auto_scheduler
   from tvm.auto_scheduler.workload_registry import register_workload_tensors
   
   POLICY_PARAMS = {
       "eps_greedy": 0.05,
       "retry_search_one_round_on_empty": 1,
       "sample_init_min_population": 3,
       "sample_init_use_measured_ratio": 0.2,
       "evolutionary_search_population": 5,
       "evolutionary_search_num_iters": 4,
       "evolutionary_search_mutation_prob": 0.85,
       "cpu_multi_level_tiling_structure": "SSRSRS",
       "gpu_multi_level_tiling_structure": "SSSRRSRS",
       # Notice: the default thread bind policy of GPU assumes the tiling structure to have at
       # least 3 spatial tiling levels in outermost
       "max_innermost_split_factor": 64,
       "max_vectorize_size": 16,
       "disable_change_compute_location": 0,
   } 
   
   def te_test():
       A_1 = te.placeholder([1, 3, 224], name='A')
       data_dilate = te.compute([1, 3, 447], lambda i0, i1, i2 : tir.if_then_else((te.floormod(i2, 2) == 0), A_1[i0, i1, te.floordiv(i2, 2)], tir.const(0, dtype="float32")), name='data_dilate')
       data_pad = te.compute([1, 3, 454], lambda i0, i1, i2 : tir.if_then_else(tir.And((i2 >= 3) ,  (i2 < 450)), data_dilate[i0, i1, (i2 - 3)], tir.const(0, dtype="float32")), name='data_pad')
       W_2 = te.placeholder([3, 32, 5], name='W')
       kernel = te.compute([32, 3, 5], lambda o, i, w : W_2[i, o, (4 - w)], name='kernel')
       k = te.reduce_axis([0, 3], name='k')
       inline_tensor = te.compute([32, 3, 5], lambda ax0, ax1, ax2 : te.sum(tir.floor(kernel[k, ax1, ax2]) + kernel[ax0, ax1, ax2], axis=[k]), name='inline_tensor')
       k = te.reduce_axis([0, 3], name='k')
       fuse_tensor = te.compute([32, 3, 5], lambda ax0, ax1, ax2 : te.sum(tir.trunc(inline_tensor[k, ax1, ax2]) - inline_tensor[ax0, ax1, ax2], axis=[k]), name='fuse_tensor')
       dc = te.reduce_axis([0, 3], name='dc')
       dw = te.reduce_axis([0, 5], name='dw')
       compute = te.compute([1, 32, 450], lambda b, c, w : te.sum((data_pad[b, dc, (w + dw)]*fuse_tensor[c, dc, dw]), axis=[dc, dw]), name='compute')
       k = te.reduce_axis([0, 1], name='k')
       fuse_tensor = te.compute([1, 32, 450], lambda ax0, ax1, ax2 : te.max(tir.trunc((tir.const(0, dtype="float32") - tir.sqrt(compute[k, ax1, ax2]))), axis=[k]), name='fuse_tensor')
       return [A_1, W_2, fuse_tensor]
   
   # Get dag and print it.
   
   tensors = te_test()
   dag = auto_scheduler.ComputeDAG(tensors)
   saved_dict = json.loads(tvm.ir.save_json(tensors))
   with open("./saved_json.txt", "w") as file:
       file.write(tvm.ir.save_json(tensors))
   print(dag)
   
   # Get inputs.
   
   inputs = []
   for tensor in dag.tensors:
       shape = [x.value if hasattr(x, "value") and isinstance(x.value, int) else 1 for x in tensor.shape]
       inputs.append((2 * np.random.uniform(size=shape)+1).astype(tensor.dtype))
   
   # Get program with no schedule.
   
   results = []
   mod_list = []
   pre_schedule_list = dag.apply_steps_from_state(dag.get_init_state())
   pre_mod = tvm.lower(pre_schedule_list[0], pre_schedule_list[1], simple_mode=True)
   mod_list.append(pre_mod)
   with tvm.transform.PassContext(opt_level=0):
       mod_exec = tvm.build(pre_mod)
       print(pre_mod)
   
   new_inputs = [tvm.nd.array(x, tvm.cpu()) for x in inputs.copy()]
   mod_exec(*new_inputs)
   result = []
   for x in new_inputs:
       try:
           result.append(x.numpy() if isinstance(
               x, tvm.runtime.NDArray) else x)
       except (ValueError, tvm.TVMError):
           result.append(None)
   results.append(result)
   
   # Get program with schedule.
   
   register_workload_tensors(dag.workload_key(), tensors)
   task = auto_scheduler.SearchTask(workload_key=dag.workload_key(), target=tvm.target.Target("llvm"))
   policy = auto_scheduler.SketchPolicy(task, verbose=0, params=POLICY_PARAMS)
   states = policy.sample_initial_population()
   
   for state in states:
       schedule_list = dag.apply_steps_from_state(state)
       mod = tvm.lower(schedule_list[0], schedule_list[1], simple_mode=True)
       mod_list.append(mod)
       print(mod)
       with tvm.transform.PassContext(opt_level=0):
           mod_exec = tvm.build(mod)
       
       new_inputs = [tvm.nd.array(x, tvm.cpu()) for x in inputs.copy()]
       mod_exec(*new_inputs)
       result = []
       for x in new_inputs:
           try:
               result.append(x.numpy() if isinstance(
                   x, tvm.runtime.NDArray) else x)
           except (ValueError, tvm.TVMError):
               result.append(None)
       results.append(result)
   
   for i in range(1, len(results)):
       result = results[i]
       for compare_idx in [-1]:
           try:
               tvm.testing.assert_allclose(results[0][compare_idx], result[compare_idx], rtol=1e-5)
           except AssertionError as e:
               print(e)
               print(mod_list[i])
               break
   
   ```
   
   ### Triage
   
   * tune:auto_scheduler
   * tir:schedule 
   * tir:transform	


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] tqchen closed issue #15201: [Bug] Error reporting “Detected bound for max(2, ax0)conflicts with memorization” after schedule

Posted by "tqchen (via GitHub)" <gi...@apache.org>.
tqchen closed issue #15201: [Bug] Error reporting “Detected bound for max(2, ax0)conflicts with memorization” after schedule
URL: https://github.com/apache/tvm/issues/15201


[GitHub] [tvm] tqchen commented on issue #15201: [Bug] Error reporting “Detected bound for max(2, ax0)conflicts with memorization” after schedule

Posted by "tqchen (via GitHub)" <gi...@apache.org>.
tqchen commented on issue #15201:
URL: https://github.com/apache/tvm/issues/15201#issuecomment-1616652964

   Thank you for reporting the issue. We are moving towards TensorIR and meta-schedule, and it is expected that some of the TE and auto-scheduler paths may not cover all cases and may require tuning to guard against possible failure cases. At the moment we don't have cycles to fix legacy issues in TE.

