Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/11/29 02:08:25 UTC

[GitHub] [tvm] comaniac opened a new issue, #13508: [Bug] Long lowering time after #13217

comaniac opened a new issue, #13508:
URL: https://github.com/apache/tvm/issues/13508

   ### Expected behavior
   
   The lowering time of the given case should be around 10 seconds.
   
   ### Actual behavior
   
   The lowering time is more than 550 seconds.
   
   ### Environment
   
   Any environment with commit 101e3a4ade226a2b9cdef6437a285af18aef9cf8 (#13217) or later.
   
   ### Steps to reproduce
   
   The script:
   
   ```python
   import time
   
   import tvm
   from tvm import topi
   
   class Timer:
       def __init__(self, msg):
           self.msg = msg
           print(f"{msg}...", flush=True)
   
       def __enter__(self):
           self.start = time.time()
   
       def __exit__(self, *args):
           print(f"{self.msg}...{time.time() - self.start:.2f}s", flush=True)
   
   def resize2d_dx_compute(inp, dy):
       """compute definition for resize2d_dx op"""
       size = (64, 32)
       layout = "NCHW"
       method = "cubic"
       coord_trans = "half_pixel"
       rounding_method = ""
       cubic_alpha = -0.75
       cubic_exclude = 0
       out_dtype = "float32"
   
       out = topi.image.resize2d(
           inp,
           (None, None, None, None),
           size,
           layout,
           method,
           coord_trans,
           rounding_method,
           bicubic_alpha=cubic_alpha,
           bicubic_exclude=cubic_exclude,
           out_dtype=out_dtype,
       )
       grads = tvm.te.gradient(out, [inp], head=dy)
       return grads
   
   inp = tvm.te.placeholder((32, 3, 32, 32), name="inp")
   dy = tvm.te.placeholder((32, 3, 64, 32), name="dy")
   with Timer("te.gradient"):
       grads = resize2d_dx_compute(inp, dy)
   
   # This problem is platform-independent.
   with Timer("schedule"):
       sch = topi.x86.injective.schedule_injective(grads)
   
   with Timer("lower"):
       print(tvm.lower(sch, [inp, dy, grads[0]], simple_mode=True))
   ```
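The `Timer` helper above only prints the elapsed time, and its `__enter__` returns `None`, so `with Timer(...) as t:` would bind nothing useful. A minimal variant that also records the elapsed time might look like the sketch below. It assumes the goal is to assert on the measurement, for example in a regression test that checks lowering stays near the roughly 10 seconds expected for this case. It uses `time.perf_counter`, which is intended for measuring intervals, and the `elapsed` attribute is an addition that is not part of the original script.

```python
import time


class Timer:
    """Context manager that prints and records the wall-clock time of a block."""

    def __init__(self, msg):
        self.msg = msg
        self.elapsed = None  # filled in by __exit__ (addition to the original helper)

    def __enter__(self):
        print(f"{self.msg}...", flush=True)
        self.start = time.perf_counter()
        return self  # lets callers write `with Timer(...) as t:`

    def __exit__(self, *args):
        self.elapsed = time.perf_counter() - self.start
        print(f"{self.msg}...{self.elapsed:.2f}s", flush=True)
        return False  # never swallow exceptions raised inside the block
```

In the reproducer, wrapping the `tvm.lower` call in `with Timer("lower") as t:` would make `t.elapsed` available afterwards. Any time budget you compare it against is your own choice and is not specified in this issue.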
   
   1. Check out a commit before 101e3a4ade226a2b9cdef6437a285af18aef9cf8 (#13217) and run the script.
   2. Check out commit 101e3a4ade226a2b9cdef6437a285af18aef9cf8 (#13217) and run the script again.
   
   Here is the lowered IR without and with this commit:
   
   Without this commit:
   ```
   @main = primfn(inp_1: handle, dy_1: handle, resize.inp.grad_1: handle) -> ()
     attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
     buffers = {inp: Buffer(inp_2: Pointer(float32), float32, [32, 3, 32, 32], []),
                dy: Buffer(dy_2: Pointer(float32), float32, [32, 3, 64, 32], []),
                resize.inp.grad: Buffer(resize.inp.grad_2: Pointer(float32), float32, [32, 3, 32, 32], [])}
     buffer_map = {inp_1: inp, dy_1: dy, resize.inp.grad_1: resize.inp.grad} {
     for (ax0.ax1.fused: int32, 0, 96) "parallel" {
       for (ax2: int32, 0, 32) {
         for (ax3.outer: int32, 0, 2) {
           resize.inp.grad_3: Buffer(resize.inp.grad_2, float32, [98304], [])[ramp((((ax0.ax1.fused*1024) + (ax2*32)) + (ax3.outer*16)), 1, 16)] = broadcast(0f32, 16)
           for (n0_n0_k2.shifted.shifted: int32, 0, 64) {
             for (n1_n1_k3.shifted.shifted: int32, 0, 32) {
               for (ax3.inner.s: int32, 0, 16) {
                 let cse_var_3: float32 = cast(float32, n1_n1_k3.shifted.shifted)
                 let cse_var_2: int32 = ((ax3.outer*16) + ax3.inner.s)
                 let cse_var_1: float32 = (((cast(float32, n0_n0_k2.shifted.shifted) + 0.5f32)*0.5f32) - 0.5f32)
                 if ((((((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0))) || (((ax2 - max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) || (((ax2 - max(min((
 cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) || (((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) {
                   let cse_var_4: int32 = ((((ax0.ax1.fused*1024) + (ax2*32)) + (ax3.outer*16)) + ax3.inner.s)
                   resize.inp.grad_3[cse_var_4] = (resize.inp.grad_3[cse_var_4] + (dy_3: Buffer(dy_2, float32, [196608], [])[(((ax0.ax1.fused*2048) + (n0_n0_k2.shifted.shifted*32)) + n1_n1_k3.shifted.shifted)]*(select((((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))) || (((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float3
 2)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))) || (((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((ca
 st(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))), (select(((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int3
 2, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))) || (((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))), (select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=fl
 oat32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, d
 type=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_
 var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(-0.75f32*(((((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))) - (2f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) + (cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))), 0f32) + select((((((a
 x2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(in
 t32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select(
 (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(((1.25f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1,
  dtype=float32)))) - (2.25f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) + 1f32)), 0f32)), 0f32) + select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == ma
 x(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cs
 e_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - 
 @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(((-1.25f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) + (1.5f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) - (-0.75f32*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))), 0f32)), 0f32) + select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floo
 r(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @t
 ir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2),
  31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*((0.75f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) + (-0.75f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))))), 0f32))))
                 }
               }
             }
           }
         }
       }
     }
   }
   ```
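How to read the IR above:

- `cse_var_1` is the half_pixel source row, `((n0 + 0.5) * 0.5) - 0.5`, where 0.5 is input height 32 over output height 64.
- `cse_var_3` is the source column. The width scale is 32/32 = 1, so it reduces to `n1`.
- Each candidate input index is clamped as `max(min(floor(s) + k, 31), 0)` for k in -1..2.
- The four cubic polynomials with alpha = -0.75 appear as the tap weights.

The loop nest is a gather. For every input element `(ax2, ax3)`, it visits all 64×32 `dy` elements and uses the large condition and `select` expressions to keep only the taps that land on that element.

The pure-Python sketch below reproduces that math for a single (batch, channel) plane, alongside an equivalent scatter-add formulation. It is an illustration under these assumptions, not TVM's implementation. All function names are hypothetical, and the sketch says nothing about why lowering became slower after #13217.

```python
import math
import random

ALPHA = -0.75  # bicubic_alpha used in the reproducer


def cubic_weights(t, a=ALPHA):
    """Weights of the taps at floor(s)-1, floor(s), floor(s)+1, floor(s)+2,
    where t = s - floor(s); these match the four polynomials in the IR."""
    return [
        a * (t**3 - 2 * t**2 + t),
        (a + 2) * t**3 - (a + 3) * t**2 + 1,
        -(a + 2) * t**3 + (2 * a + 3) * t**2 - a * t,
        -a * t**3 + a * t**2,
    ]


def axis_taps(d, in_size, out_size):
    """(clamped input index, weight) pairs for output index d along one axis,
    using the half_pixel mapping s = (d + 0.5) * in_size / out_size - 0.5."""
    s = (d + 0.5) * (in_size / out_size) - 0.5
    base = math.floor(s)
    weights = cubic_weights(s - base)
    return [(min(max(base + k, 0), in_size - 1), w)
            for k, w in zip((-1, 0, 1, 2), weights)]


def resize_forward(inp, out_h, out_w):
    """Separable bicubic resize of one 2-D plane (the forward op)."""
    in_h, in_w = len(inp), len(inp[0])
    ty = [axis_taps(y, in_h, out_h) for y in range(out_h)]
    tx = [axis_taps(x, in_w, out_w) for x in range(out_w)]
    return [[sum(wy * wx * inp[iy][ix] for iy, wy in ty[y] for ix, wx in tx[x])
             for x in range(out_w)] for y in range(out_h)]


def grad_scatter(dy, in_h, in_w):
    """Gradient w.r.t. the input: scatter each dy element onto its 4x4 taps."""
    out_h, out_w = len(dy), len(dy[0])
    ty = [axis_taps(y, in_h, out_h) for y in range(out_h)]
    tx = [axis_taps(x, in_w, out_w) for x in range(out_w)]
    g = [[0.0] * in_w for _ in range(in_h)]
    for y in range(out_h):
        for x in range(out_w):
            for iy, wy in ty[y]:
                for ix, wx in tx[x]:
                    g[iy][ix] += dy[y][x] * wy * wx
    return g


def grad_gather(dy, in_h, in_w):
    """Same gradient in gather form, like the IR's loop nest: for each input
    element (i, j), visit every dy element and sum the weights of the taps
    (possibly several, after clamping) that land on (i, j)."""
    out_h, out_w = len(dy), len(dy[0])
    ty = [axis_taps(y, in_h, out_h) for y in range(out_h)]
    tx = [axis_taps(x, in_w, out_w) for x in range(out_w)]
    g = [[0.0] * in_w for _ in range(in_h)]
    for i in range(in_h):
        for j in range(in_w):
            for y in range(out_h):
                wy = sum(w for iy, w in ty[y] if iy == i)
                for x in range(out_w):
                    wx = sum(w for ix, w in tx[x] if ix == j)
                    g[i][j] += dy[y][x] * wy * wx
    return g


if __name__ == "__main__":
    random.seed(0)
    # Same 2x-height / 1x-width ratio as the issue, on a small plane.
    in_h, in_w, out_h, out_w = 8, 8, 16, 8
    inp = [[random.uniform(-1, 1) for _ in range(in_w)] for _ in range(in_h)]
    dy = [[random.uniform(-1, 1) for _ in range(out_w)] for _ in range(out_h)]
    g = grad_scatter(dy, in_h, in_w)
    # Adjoint check: <resize(inp), dy> should equal <inp, grad(dy)>.
    out = resize_forward(inp, out_h, out_w)
    lhs = sum(a * b for ra, rb in zip(out, dy) for a, b in zip(ra, rb))
    rhs = sum(a * b for ra, rb in zip(inp, g) for a, b in zip(ra, rb))
    print(abs(lhs - rhs) < 1e-9)
```

Both gradient functions compute the same values. The scatter form touches each `dy` element only 16 times. The gather form, like the IR, pairs every input element with every `dy` element.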
   
   With this commit:
   ```
   @main = primfn(inp_1: handle, dy_1: handle, resize.inp.grad_1: handle) -> ()
     attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
     buffers = {inp: Buffer(inp_2: Pointer(float32), float32, [32, 3, 32, 32], []),
                dy: Buffer(dy_2: Pointer(float32), float32, [32, 3, 64, 32], []),
                resize.inp.grad: Buffer(resize.inp.grad_2: Pointer(float32), float32, [32, 3, 32, 32], [])}
     buffer_map = {inp_1: inp, dy_1: dy, resize.inp.grad_1: resize.inp.grad} {
     for (ax0.ax1.fused: int32, 0, 96) "parallel" {
       for (ax2: int32, 0, 32) {
         for (ax3.outer: int32, 0, 2) {
           resize.inp.grad_3: Buffer(resize.inp.grad_2, float32, [98304], [])[ramp((((ax0.ax1.fused*1024) + (ax2*32)) + (ax3.outer*16)), 1, 16)] = broadcast(0f32, 16)
           for (n0_n0_k2.shifted.shifted: int32, 0, 64) {
             for (n1_n1_k3.shifted.shifted: int32, 0, 32) {
               for (ax3.inner.s: int32, 0, 16) {
                 let cse_var_3: float32 = cast(float32, n1_n1_k3.shifted.shifted)
                 let cse_var_2: int32 = ((ax3.outer*16) + ax3.inner.s)
                 let cse_var_1: float32 = (((cast(float32, n0_n0_k2.shifted.shifted) + 0.5f32)*0.5f32) - 0.5f32)
                 if ((((((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0))) || (((ax2 - max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) || (((ax2 - max(min((
 cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) || (((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) {
                   let cse_var_4: int32 = ((((ax0.ax1.fused*1024) + (ax2*32)) + (ax3.outer*16)) + ax3.inner.s)
                   resize.inp.grad_3[cse_var_4] = (resize.inp.grad_3[cse_var_4] + (dy_3: Buffer(dy_2, float32, [196608], [])[(((ax0.ax1.fused*2048) + (n0_n0_k2.shifted.shifted*32)) + n1_n1_k3.shifted.shifted)]*(select((((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))) || (((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float3
 2)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))) || (((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((ca
 st(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))), (select(((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int3
 2, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))) || (((((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))), (select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=fl
 oat32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, d
 type=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_
 var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(-0.75f32*(((((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))) - (2f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) + (cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))), 0f32) + select((((((a
 x2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(in
 t32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select(
 (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(((1.25f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1,
  dtype=float32)))) - (2.25f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) + 1f32)), 0f32)), 0f32) + select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == ma
 x(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cs
 e_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - 
 @tir.floor(cse_var_3, dtype=float32))))), 0f32))*(((-1.25f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) + (1.5f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) - (-0.75f32*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))), 0f32)), 0f32) + select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floo
 r(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @t
 ir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2),
  31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))), 0f32))*((0.75f32*(((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) + (-0.75f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))))), 0f32))))
                 }
               }
             }
           }
         }
       }
     }
   }
   ```
   
   cc @Lunderberg @masahi 
   
   ### Triage
   
   * needs-triage
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] comaniac commented on issue #13508: [Bug] Long lowering time after #13217

Posted by GitBox <gi...@apache.org>.
comaniac commented on issue #13508:
URL: https://github.com/apache/tvm/issues/13508#issuecomment-1342282317

   Yeah this particular issue has been resolved. Closed.




[GitHub] [tvm] Lunderberg commented on issue #13508: [Bug] Long lowering time after #13217

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on issue #13508:
URL: https://github.com/apache/tvm/issues/13508#issuecomment-1332722603

   Testing on the GPU with both the CUDA and Vulkan backends (nvidia-driver-470 on Ubuntu 21.04) shows a similar effect. It isn't quite as dramatic (only about 50x slower instead of 1000x), but it's still a large effect. Both GPU tests used the same compute definition, but with `topi.cuda.injective.schedule_injective` instead of `topi.x86.injective.schedule_injective`.
   
   ![image](https://user-images.githubusercontent.com/3888575/204904629-6bc069e5-4c95-4c92-b738-1528bcdb6e53.png)
   
   The specific fix here wasn't on the transformation side, but a change to the topi operator. The nice thing is that, at the operator level, it can convert floating-point constants to integer ratios (e.g. the `-0.75` in the example into `Fraction(-3, 4)`) before they are folded too far to be recognized. The downside is that it isn't as general a solution as a TIR-level transformation.
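The exact-ratio idea can be sketched with Python's `fractions` module. This is a minimal illustration, not the code from the actual topi change: `to_ratio` and `cubic_weights` are hypothetical helpers. The four tap polynomials are read off the lowered TIR above for `cubic_alpha = -0.75` and rewritten in terms of `alpha` (for example, `alpha + 2` and `alpha + 3` give the `1.25f32` and `2.25f32` constants), with `t` playing the role of `cse_var - floor(cse_var)`.

```python
from fractions import Fraction

def to_ratio(value: float, max_denominator: int = 1000) -> Fraction:
    """Hypothetical helper: convert a float constant to an exact integer ratio.

    -0.75 is exactly representable in binary, so Fraction(-0.75) is already -3/4;
    limit_denominator also recovers 1/10 from an inexact value such as 0.1.
    """
    return Fraction(value).limit_denominator(max_denominator)

def cubic_weights(t: Fraction, alpha: Fraction) -> list:
    """Weights of the four taps at offsets -1, 0, +1, +2 for fractional offset t."""
    return [
        alpha * (t**3 - 2 * t**2 + t),
        (alpha + 2) * t**3 - (alpha + 3) * t**2 + 1,
        -(alpha + 2) * t**3 + (2 * alpha + 3) * t**2 - alpha * t,
        -alpha * t**3 + alpha * t**2,
    ]

alpha = to_ratio(-0.75)
print(alpha)                                      # -3/4
print(sum(cubic_weights(Fraction(1, 4), alpha)))  # 1
```

Because every weight is an exact rational, the four taps sum to exactly 1 for any `t`, an identity that is hard to recover once the constants have been folded into float literals.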




[GitHub] [tvm] comaniac commented on issue #13508: [Bug] Long lowering time after #13217

Posted by GitBox <gi...@apache.org>.
comaniac commented on issue #13508:
URL: https://github.com/apache/tvm/issues/13508#issuecomment-1331032516

   @Lunderberg your analysis makes a lot of sense. After removing `kApplyConstraintsToBooleanBranches`, the lowering time of the example became 13 seconds, which looks much more reasonable to me. It would be great if you could fix it by disabling the analyzer along with the flag.
   
   In addition, just my two cents about the simplification (correct me if I'm wrong): low-level compilers (e.g., nvcc, LLVM) should be capable of simplifying this expression by themselves, so you might not see any performance improvement even if you apply this simplification at the TIR level.




[GitHub] [tvm] tqchen commented on issue #13508: [Bug] Long lowering time after #13217

Posted by GitBox <gi...@apache.org>.
tqchen commented on issue #13508:
URL: https://github.com/apache/tvm/issues/13508#issuecomment-1331136601

   Ideally we should keep the simplifier lightweight. In this case, disabling `kApplyConstraintsToBooleanBranches` makes sense.
   
   




[GitHub] [tvm] comaniac commented on issue #13508: [Bug] Long lowering time after #13217

Posted by GitBox <gi...@apache.org>.
comaniac commented on issue #13508:
URL: https://github.com/apache/tvm/issues/13508#issuecomment-1332521275

   Thanks for the fix and investigation. Apparently the LLVM backend isn't aware of this transform. If possible, could you also benchmark on GPU to test nvcc?
   
   Based on the benchmark results, I agree that we should include this transform once we make it reasonably lightweight.




[GitHub] [tvm] comaniac closed issue #13508: [Bug] Long lowering time after #13217

Posted by GitBox <gi...@apache.org>.
comaniac closed issue #13508: [Bug] Long lowering time after #13217
URL: https://github.com/apache/tvm/issues/13508




[GitHub] [tvm] masahi commented on issue #13508: [Bug] Long lowering time after #13217

Posted by GitBox <gi...@apache.org>.
masahi commented on issue #13508:
URL: https://github.com/apache/tvm/issues/13508#issuecomment-1330543584

   Hmm, strange: the new flag `use_dataflow_analysis` in `RemoveNoOp` is set to false by default, so I thought it shouldn't affect the default lowering in any way.




[GitHub] [tvm] Lunderberg commented on issue #13508: [Bug] Long lowering time after #13217

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on issue #13508:
URL: https://github.com/apache/tvm/issues/13508#issuecomment-1330735356

   Thank you, I'm seeing the same behavior with this example. Using f5a102c83, the lowering step runs in 0.18s, while using 101e3a4ad (just after PR #13217) it runs in 11.78s, the same 50x difference in performance that you're seeing. As @masahi pointed out, this definitely shouldn't be the case, since the additional analysis is disabled by default. I'm investigating it.




[GitHub] [tvm] Lunderberg commented on issue #13508: [Bug] Long lowering time after #13217

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on issue #13508:
URL: https://github.com/apache/tvm/issues/13508#issuecomment-1330823429

   Looks like the performance degradation is from `RemoveNoOp`. Even though the data-flow analysis is disabled by default, the analyzer of `IRMutatorWithAnalyzer` still collects scoped information. Simplifications done by that analyzer don't show up in the output TIR unless they are used to prove a statement to be a no-op (e.g. by having a negative loop extent), but they still add to the time required for lowering.
   
   A quick fix may be to disable `arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches`, which is currently enabled for the analyzer in `RemoveNoOp`; disabling it restored the performance in this test case. Can you check whether it also improves the performance on your side by removing `kApplyConstraintsToBooleanBranches` from [this line](https://github.com/apache/tvm/blob/main/src/tir/transforms/remove_no_op.cc#L309)?
   
   I'm continuing to investigate whether this should be disabled, or whether something else is wrong with the simplifications. The lowered TIR has a lot of expressions that I would expect to be simplified. For example, the first `@tir.floor` in the if condition is `tir.floor((((cast(float32, n0_n0_k2.shifted.shifted) + 0.5f32)*0.5f32) - 0.5f32), dtype=float32)`, which is equivalent to `floordiv(n0_n0_k2.shifted.shifted - 1, 2)`.
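The equivalence can be spot-checked numerically. For integer `x`, `((x + 0.5)*0.5) - 0.5 = (2x - 1)/4`, and since `2x - 1` is odd it is never a multiple of 4, so its floor equals the floor of `(2x - 2)/4 = (x - 1)/2`. A quick sketch, using Python floats in place of TIR `float32` and `//` (floor division) in place of `floordiv`, with `x` standing in for `n0_n0_k2.shifted.shifted`:

```python
import math

def float_form(x: int) -> int:
    """floor(((x + 0.5) * 0.5) - 0.5), as written in the lowered TIR."""
    return math.floor(((float(x) + 0.5) * 0.5) - 0.5)

def int_form(x: int) -> int:
    """floordiv(x - 1, 2); Python's // also rounds toward negative infinity."""
    return (x - 1) // 2

mismatches = [x for x in range(1024) if float_form(x) != int_form(x)]
print(mismatches)  # []
```

Python floats are 64-bit, but for these small inputs every intermediate value is a multiple of 0.25 and is exact in `float32` as well.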




[GitHub] [tvm] Lunderberg commented on issue #13508: [Bug] Long lowering time after #13217

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on issue #13508:
URL: https://github.com/apache/tvm/issues/13508#issuecomment-1332461163

   @comaniac @tqchen I've submitted #13524, which disables the use of simplifier extensions by `RemoveNoOp`. My main concern was that it would prevent some of the planned simplifications in #13299, but all the test cases can either be handled without extensions, or have data-flow analysis enabled, which uses all the extensions.
   
   > In addition, just my two cents about the simplification (correct me if I'm wrong): low-level compilers (e.g., nvcc, LLVM) should be capable of simplifying this expression by themselves, so you might not see any performance improvement even if you apply this simplification at the TIR level.
   
   I was curious about this, and did some benchmarks after modifying `resize.py` to use only integer fractions to compute the indices and the linear/cubic interpolation weights, which gave about a 1000x improvement in execution speed on the LLVM backend. Apart from the floats, the main difference in the TIR was that `tir.VectorizeLoop` could identify an opportunity to vectorize the innermost loop for integer indices, but couldn't do so for floating-point indices.
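A sketch of what an integer-only coordinate computation could look like, assuming the `half_pixel` transform `src = (x + 0.5) * in_size / out_size - 0.5` (for the 32 to 64 upsampling of the height axis in the reproducer, this is exactly the `((x + 0.5)*0.5) - 0.5` expression seen in the TIR). The function names are hypothetical and this is not the actual change to `resize.py`: the tap index comes from a single floor division, and the interpolation weight is kept as an exact fraction instead of `src - floor(src)`.

```python
import math
from fractions import Fraction

def half_pixel_float(out_x: int, in_size: int, out_size: int):
    """Float version: left tap index and fractional interpolation weight."""
    src = (out_x + 0.5) * (in_size / out_size) - 0.5
    idx = math.floor(src)
    return idx, src - idx

def half_pixel_int(out_x: int, in_size: int, out_size: int):
    """Integer-only version of the same half_pixel source coordinate.

    src = ((2*out_x + 1) * in_size - out_size) / (2 * out_size)
    """
    num = (2 * out_x + 1) * in_size - out_size
    den = 2 * out_size
    idx = num // den                   # floor division, like TIR floordiv
    weight = Fraction(num % den, den)  # non-negative remainder, like floormod
    return idx, weight

# Height axis of the reproducer: 32 input rows upsampled to 64 output rows.
for out_x in range(64):
    assert half_pixel_int(out_x, 32, 64) == half_pixel_float(out_x, 32, 64)
print(half_pixel_int(0, 32, 64))  # (-1, Fraction(3, 4))
```

With integer inputs the index becomes plain `floordiv`/`floormod` arithmetic, the kind of integer index expression that, per the benchmark described above, `tir.VectorizeLoop` was able to vectorize.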
   
   ![image](https://user-images.githubusercontent.com/3888575/204859152-5e24bc5b-f556-4d08-a1e5-77bff8d2a179.png)
   
   
   Since there was such a benefit, I'm going to clean up and PR those changes to `topi.image.resize`. (It also gives a 10x improvement in the time required to lower the schedule, which is a nice plus.)




[GitHub] [tvm] masahi commented on issue #13508: [Bug] Long lowering time after #13217

Posted by GitBox <gi...@apache.org>.
masahi commented on issue #13508:
URL: https://github.com/apache/tvm/issues/13508#issuecomment-1342262592

   Can we close this?

