Posted to commits@tvm.apache.org by "jinhongyii (via GitHub)" <gi...@apache.org> on 2024/03/11 17:33:34 UTC

[PR] [Dlight] Add fallback for low batch gemv with outer reduction [tvm]

jinhongyii opened a new pull request, #16701:
URL: https://github.com/apache/tvm/pull/16701

   Add fallback for low batch gemv with outer reduction


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Re: [PR] [Dlight] Add fallback for low batch gemv with outer reduction [tvm]

Posted by "jinhongyii (via GitHub)" <gi...@apache.org>.
jinhongyii commented on PR #16701:
URL: https://github.com/apache/tvm/pull/16701#issuecomment-1992156384

   I can't reproduce your error on my branch. I think the error you're seeing has already been fixed in https://github.com/apache/tvm/commit/5bbe1aba6d0ca0f7422299a7b34c9e1a4181288d
   
   Please check whether your local branch includes this commit.


Re: [PR] [Dlight] Add fallback for low batch gemv with outer reduction [tvm]

Posted by "tqchen (via GitHub)" <gi...@apache.org>.
tqchen merged PR #16701:
URL: https://github.com/apache/tvm/pull/16701


Re: [PR] [Dlight] Add fallback for low batch gemv with outer reduction [tvm]

Posted by "kmn1024 (via GitHub)" <gi...@apache.org>.
kmn1024 commented on PR #16701:
URL: https://github.com/apache/tvm/pull/16701#issuecomment-1993155374

   Yes, you are right. Thanks so much!


Re: [PR] [Dlight] Add fallback for low batch gemv with outer reduction [tvm]

Posted by "kmn1024 (via GitHub)" <gi...@apache.org>.
kmn1024 commented on PR #16701:
URL: https://github.com/apache/tvm/pull/16701#issuecomment-1989893523

   Thanks for the fix! I tried it and got past the previous error, but now the same `compile` command gives a new error:
   ```
   [2024-03-12 09:57:31] INFO pipeline.py:43: Running TVM Dlight low-level optimizations
   Traceback (most recent call last):
     File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
       return _run_code(code, main_globals, None,
     File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
       exec(code, run_globals)
     File "/home/ubuntu/new-mlc/mlc-llm/python/mlc_chat/__main__.py", line 47, in <module>
       main()
     File "/home/ubuntu/new-mlc/mlc-llm/python/mlc_chat/__main__.py", line 24, in main
       cli.main(sys.argv[2:])
     File "/home/ubuntu/new-mlc/mlc-llm/python/mlc_chat/cli/compile.py", line 131, in main
       compile(
     File "/home/ubuntu/new-mlc/mlc-llm/python/mlc_chat/interface/compile.py", line 229, in compile
       _compile(args, model_config)
     File "/home/ubuntu/new-mlc/mlc-llm/python/mlc_chat/interface/compile.py", line 176, in _compile
       args.build_func(
     File "/home/ubuntu/new-mlc/mlc-llm/python/mlc_chat/support/auto_target.py", line 235, in build
       relax.build(
     File "/home/ubuntu/new-mlc/relax/python/tvm/relax/vm_build.py", line 335, in build
       mod = pipeline(mod)
     File "/home/ubuntu/new-mlc/relax/python/tvm/ir/transform.py", line 238, in __call__
       return _ffi_transform_api.RunPass(self, mod)
     File "/home/ubuntu/new-mlc/relax/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
       raise_last_ffi_error()
     File "/home/ubuntu/new-mlc/relax/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
       raise py_err
     File "/home/ubuntu/new-mlc/mlc-llm/python/mlc_chat/compiler_pass/pipeline.py", line 151, in _pipeline
       mod = seq(mod)
     File "/home/ubuntu/new-mlc/relax/python/tvm/ir/transform.py", line 238, in __call__
       return _ffi_transform_api.RunPass(self, mod)
     File "/home/ubuntu/new-mlc/relax/python/tvm/ir/transform.py", line 307, in _pass_func
       return inst.transform_module(mod, ctx)
     File "/home/ubuntu/new-mlc/mlc-llm/python/mlc_chat/compiler_pass/low_batch_specialization.py", line 28, in transform_module
       low_batch_mod = dl.ApplyDefaultSchedule(
     File "/home/ubuntu/new-mlc/relax/python/tvm/ir/transform.py", line 238, in __call__
       return _ffi_transform_api.RunPass(self, mod)
     File "/home/ubuntu/new-mlc/relax/python/tvm/ir/transform.py", line 307, in _pass_func
       return inst.transform_module(mod, ctx)
     File "/home/ubuntu/new-mlc/relax/python/tvm/dlight/base/transform.py", line 64, in transform_module
       sch = _apply_rules(func, target, self.rules, tunable=False)
     File "/home/ubuntu/new-mlc/relax/python/tvm/dlight/base/transform.py", line 80, in _apply_rules
       space = rule.apply(func, target, tunable)
     File "/home/ubuntu/new-mlc/relax/python/tvm/dlight/gpu/low_batch_gemv.py", line 273, in apply
       is_inner_reduction = normalize(sch, block_info)
     File "/home/ubuntu/new-mlc/relax/python/tvm/dlight/gpu/low_batch_gemv.py", line 200, in normalize
       sch.reorder(*dynamic_loops, *batch_loops, *s_loops, *r_loops, *c_loops)
     File "/home/ubuntu/new-mlc/relax/python/tvm/tir/schedule/_type_checker.py", line 340, in wrap
       return func(*args, **kwargs)
     File "/home/ubuntu/new-mlc/relax/python/tvm/tir/schedule/schedule.py", line 982, in reorder
       _ffi_api.ScheduleReorder(self, ordered_loops)  # type: ignore # pylint: disable=no-member
     File "/home/ubuntu/new-mlc/relax/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
       raise_last_ffi_error()
     File "/home/ubuntu/new-mlc/relax/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
       raise py_err
   tvm.tir.schedule.schedule.ScheduleError: Traceback (most recent call last):
     1: tvm::tir::TracedScheduleNode::Reorder(tvm::runtime::Array<tvm::tir::LoopRV, void> const&)
           at /home/ubuntu/new-mlc/relax/src/tir/schedule/traced_schedule.cc:269
     0: tvm::tir::ConcreteScheduleNode::Reorder(tvm::runtime::Array<tvm::tir::LoopRV, void> const&)
           at /home/ubuntu/new-mlc/relax/src/tir/schedule/concrete_schedule.cc:589
   ScheduleError: An error occurred in the schedule primitive 'reorder'.
   The IR with diagnostic is:
   # from tvm.script import ir as I
   # from tvm.script import tir as T
   
   @I.ir_module
   class Module:
       @T.prim_func(private=True)
       def main(var_A: T.handle, var_B: T.handle, var_matmul: T.handle):
           T.func_attr({"tir.noalias": T.bool(True)})
           total_seq_len = T.int64()
           A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), total_seq_len), "float16")
           B = T.match_buffer(var_B, (T.int64(1), T.int64(32), total_seq_len, T.int64(64)), "float16")
           matmul = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), T.int64(1), T.int64(64)), "float16")
           with T.block("root"):
               T.reads()
               T.writes()
               A_pad = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), (total_seq_len + T.int64(1)) // T.int64(2) * T.int64(2)), "float16")
               B_pad = T.alloc_buffer((T.int64(1), T.int64(32), (total_seq_len + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(64)), "float16")
               for ax0 in range(T.int64(32)):
                   for ax1 in range((total_seq_len + T.int64(1)) // T.int64(2) * T.int64(2)):
                       with T.block("A_pad"):
                           v0 = T.axis.spatial(T.int64(32), ax0)
                           v1 = T.axis.spatial((total_seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax1)
                           T.reads(A[T.int64(0), v0, T.int64(0), v1])
                           T.writes(A_pad[T.int64(0), v0, T.int64(0), v1])
                           A_pad[T.int64(0), v0, T.int64(0), v1] = T.if_then_else(v1 < total_seq_len, A[T.int64(0), v0, T.int64(0), v1], T.float16(0))
               for ax0 in range(T.int64(32)):
                   for ax1 in range((total_seq_len + T.int64(1)) // T.int64(2) * T.int64(2)):
                       for ax2 in range(T.int64(64)):
                           with T.block("B_pad"):
                               v0 = T.axis.spatial(T.int64(32), ax0)
                               v1 = T.axis.spatial((total_seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax1)
                               v2 = T.axis.spatial(T.int64(64), ax2)
                               T.reads(B[T.int64(0), v0, v1, v2])
                               T.writes(B_pad[T.int64(0), v0, v1, v2])
                               B_pad[T.int64(0), v0, v1, v2] = T.if_then_else(v1 < total_seq_len, B[T.int64(0), v0, v1, v2], T.float16(0))
               for ax0 in range(T.int64(32)):
                   for ax1 in range(T.int64(64)):
                       # tir.For#0
                       for ax2 in range((total_seq_len + T.int64(1)) // T.int64(2) * T.int64(2)):
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                           for u in range(1):
                           ^^^^^^^^^^^^^^^^^^
                               with T.block("matmul"):
                               ^^^^^^^^^^^^^^^^^^^^^^^
                                   v0 = T.axis.spatial(T.int64(32), ax0)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                   v1 = T.axis.spatial(T.int64(64), ax1)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                   v2 = T.axis.reduce((total_seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax2)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                   T.reads(A_pad[T.int64(0), v0, T.int64(0), v2], B_pad[T.int64(0), v0, v2, v1])
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                   T.writes(matmul[T.int64(0), v0, T.int64(0), v1])
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                   with T.init():
                                   ^^^^^^^^^^^^^^
                                       matmul[T.int64(0), v0, T.int64(0), v1] = T.float16(0)
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                   matmul[T.int64(0), v0, T.int64(0), v1] = matmul[T.int64(0), v0, T.int64(0), v1] + A_pad[T.int64(0), v0, T.int64(0), v2] * B_pad[T.int64(0), v0, v2, v1]
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   Error message: Loop tir.For#0 appears in the input array for multiple times.
   ```
