Posted to commits@tvm.apache.org by "jinhongyii (via GitHub)" <gi...@apache.org> on 2024/03/11 17:33:34 UTC
[PR] [Dlight] Add fallback for low batch gemv with outer reduction [tvm]
jinhongyii opened a new pull request, #16701:
URL: https://github.com/apache/tvm/pull/16701
Add fallback for low batch gemv with outer reduction
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
Re: [PR] [Dlight] Add fallback for low batch gemv with outer reduction [tvm]
Posted by "jinhongyii (via GitHub)" <gi...@apache.org>.
jinhongyii commented on PR #16701:
URL: https://github.com/apache/tvm/pull/16701#issuecomment-1992156384
I can't reproduce your error on my branch. I believe the error you're seeing was already fixed in https://github.com/apache/tvm/commit/5bbe1aba6d0ca0f7422299a7b34c9e1a4181288d
Please check whether that commit is in your local branch.
Re: [PR] [Dlight] Add fallback for low batch gemv with outer reduction [tvm]
Posted by "tqchen (via GitHub)" <gi...@apache.org>.
tqchen merged PR #16701:
URL: https://github.com/apache/tvm/pull/16701
Re: [PR] [Dlight] Add fallback for low batch gemv with outer reduction [tvm]
Posted by "kmn1024 (via GitHub)" <gi...@apache.org>.
kmn1024 commented on PR #16701:
URL: https://github.com/apache/tvm/pull/16701#issuecomment-1993155374
Yes, you are right. Thanks so much!
Re: [PR] [Dlight] Add fallback for low batch gemv with outer reduction [tvm]
Posted by "kmn1024 (via GitHub)" <gi...@apache.org>.
kmn1024 commented on PR #16701:
URL: https://github.com/apache/tvm/pull/16701#issuecomment-1989893523
Thanks for the fix! I tried it and got past the previous error, but now the same `compile` command gives a new error:
```
[2024-03-12 09:57:31] INFO pipeline.py:43: Running TVM Dlight low-level optimizations
Traceback (most recent call last):
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/ubuntu/new-mlc/mlc-llm/python/mlc_chat/__main__.py", line 47, in <module>
main()
File "/home/ubuntu/new-mlc/mlc-llm/python/mlc_chat/__main__.py", line 24, in main
cli.main(sys.argv[2:])
File "/home/ubuntu/new-mlc/mlc-llm/python/mlc_chat/cli/compile.py", line 131, in main
compile(
File "/home/ubuntu/new-mlc/mlc-llm/python/mlc_chat/interface/compile.py", line 229, in compile
_compile(args, model_config)
File "/home/ubuntu/new-mlc/mlc-llm/python/mlc_chat/interface/compile.py", line 176, in _compile
args.build_func(
File "/home/ubuntu/new-mlc/mlc-llm/python/mlc_chat/support/auto_target.py", line 235, in build
relax.build(
File "/home/ubuntu/new-mlc/relax/python/tvm/relax/vm_build.py", line 335, in build
mod = pipeline(mod)
File "/home/ubuntu/new-mlc/relax/python/tvm/ir/transform.py", line 238, in __call__
return _ffi_transform_api.RunPass(self, mod)
File "/home/ubuntu/new-mlc/relax/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
raise_last_ffi_error()
File "/home/ubuntu/new-mlc/relax/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
File "/home/ubuntu/new-mlc/mlc-llm/python/mlc_chat/compiler_pass/pipeline.py", line 151, in _pipeline
mod = seq(mod)
File "/home/ubuntu/new-mlc/relax/python/tvm/ir/transform.py", line 238, in __call__
return _ffi_transform_api.RunPass(self, mod)
File "/home/ubuntu/new-mlc/relax/python/tvm/ir/transform.py", line 307, in _pass_func
return inst.transform_module(mod, ctx)
File "/home/ubuntu/new-mlc/mlc-llm/python/mlc_chat/compiler_pass/low_batch_specialization.py", line 28, in transform_module
low_batch_mod = dl.ApplyDefaultSchedule(
File "/home/ubuntu/new-mlc/relax/python/tvm/ir/transform.py", line 238, in __call__
return _ffi_transform_api.RunPass(self, mod)
File "/home/ubuntu/new-mlc/relax/python/tvm/ir/transform.py", line 307, in _pass_func
return inst.transform_module(mod, ctx)
File "/home/ubuntu/new-mlc/relax/python/tvm/dlight/base/transform.py", line 64, in transform_module
sch = _apply_rules(func, target, self.rules, tunable=False)
File "/home/ubuntu/new-mlc/relax/python/tvm/dlight/base/transform.py", line 80, in _apply_rules
space = rule.apply(func, target, tunable)
File "/home/ubuntu/new-mlc/relax/python/tvm/dlight/gpu/low_batch_gemv.py", line 273, in apply
is_inner_reduction = normalize(sch, block_info)
File "/home/ubuntu/new-mlc/relax/python/tvm/dlight/gpu/low_batch_gemv.py", line 200, in normalize
sch.reorder(*dynamic_loops, *batch_loops, *s_loops, *r_loops, *c_loops)
File "/home/ubuntu/new-mlc/relax/python/tvm/tir/schedule/_type_checker.py", line 340, in wrap
return func(*args, **kwargs)
File "/home/ubuntu/new-mlc/relax/python/tvm/tir/schedule/schedule.py", line 982, in reorder
_ffi_api.ScheduleReorder(self, ordered_loops) # type: ignore # pylint: disable=no-member
File "/home/ubuntu/new-mlc/relax/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
raise_last_ffi_error()
File "/home/ubuntu/new-mlc/relax/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
tvm.tir.schedule.schedule.ScheduleError: Traceback (most recent call last):
1: tvm::tir::TracedScheduleNode::Reorder(tvm::runtime::Array<tvm::tir::LoopRV, void> const&)
at /home/ubuntu/new-mlc/relax/src/tir/schedule/traced_schedule.cc:269
0: tvm::tir::ConcreteScheduleNode::Reorder(tvm::runtime::Array<tvm::tir::LoopRV, void> const&)
at /home/ubuntu/new-mlc/relax/src/tir/schedule/concrete_schedule.cc:589
ScheduleError: An error occurred in the schedule primitive 'reorder'.
The IR with diagnostic is:
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func(private=True)
def main(var_A: T.handle, var_B: T.handle, var_matmul: T.handle):
T.func_attr({"tir.noalias": T.bool(True)})
total_seq_len = T.int64()
A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), total_seq_len), "float16")
B = T.match_buffer(var_B, (T.int64(1), T.int64(32), total_seq_len, T.int64(64)), "float16")
matmul = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), T.int64(1), T.int64(64)), "float16")
with T.block("root"):
T.reads()
T.writes()
A_pad = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), (total_seq_len + T.int64(1)) // T.int64(2) * T.int64(2)), "float16")
B_pad = T.alloc_buffer((T.int64(1), T.int64(32), (total_seq_len + T.int64(1)) // T.int64(2) * T.int64(2), T.int64(64)), "float16")
for ax0 in range(T.int64(32)):
for ax1 in range((total_seq_len + T.int64(1)) // T.int64(2) * T.int64(2)):
with T.block("A_pad"):
v0 = T.axis.spatial(T.int64(32), ax0)
v1 = T.axis.spatial((total_seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax1)
T.reads(A[T.int64(0), v0, T.int64(0), v1])
T.writes(A_pad[T.int64(0), v0, T.int64(0), v1])
A_pad[T.int64(0), v0, T.int64(0), v1] = T.if_then_else(v1 < total_seq_len, A[T.int64(0), v0, T.int64(0), v1], T.float16(0))
for ax0 in range(T.int64(32)):
for ax1 in range((total_seq_len + T.int64(1)) // T.int64(2) * T.int64(2)):
for ax2 in range(T.int64(64)):
with T.block("B_pad"):
v0 = T.axis.spatial(T.int64(32), ax0)
v1 = T.axis.spatial((total_seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax1)
v2 = T.axis.spatial(T.int64(64), ax2)
T.reads(B[T.int64(0), v0, v1, v2])
T.writes(B_pad[T.int64(0), v0, v1, v2])
B_pad[T.int64(0), v0, v1, v2] = T.if_then_else(v1 < total_seq_len, B[T.int64(0), v0, v1, v2], T.float16(0))
for ax0 in range(T.int64(32)):
for ax1 in range(T.int64(64)):
# tir.For#0
for ax2 in range((total_seq_len + T.int64(1)) // T.int64(2) * T.int64(2)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for u in range(1):
^^^^^^^^^^^^^^^^^^
with T.block("matmul"):
^^^^^^^^^^^^^^^^^^^^^^^
v0 = T.axis.spatial(T.int64(32), ax0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v1 = T.axis.spatial(T.int64(64), ax1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v2 = T.axis.reduce((total_seq_len + T.int64(1)) // T.int64(2) * T.int64(2), ax2)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.reads(A_pad[T.int64(0), v0, T.int64(0), v2], B_pad[T.int64(0), v0, v2, v1])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.writes(matmul[T.int64(0), v0, T.int64(0), v1])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
with T.init():
^^^^^^^^^^^^^^
matmul[T.int64(0), v0, T.int64(0), v1] = T.float16(0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
matmul[T.int64(0), v0, T.int64(0), v1] = matmul[T.int64(0), v0, T.int64(0), v1] + A_pad[T.int64(0), v0, T.int64(0), v2] * B_pad[T.int64(0), v0, v2, v1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Error message: Loop tir.For#0 appears in the input array for multiple times.
```
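For context, the `ScheduleError` at the end means the same loop handle was passed to `sch.reorder` more than once — for this outer-reduction shape, the loop lists assembled in `normalize` (e.g. `dynamic_loops` and `r_loops`) evidently overlapped. A minimal, TVM-independent sketch of the kind of duplicate check that fires here (the function name is hypothetical, not TVM API):

```python
def find_duplicate(loops):
    """Return the first item that appears more than once in `loops`, or None.

    Mirrors the precondition that a reorder primitive enforces: each loop
    may be listed at most once in the requested ordering.
    """
    seen = set()
    for loop in loops:
        if loop in seen:
            return loop  # first repeated loop handle
        seen.add(loop)
    return None


# If the dynamic-batch loop also ends up in the reduction-loop list,
# the combined ordering contains it twice and reorder must reject it:
combined = ["ax2", "ax0", "ax1", "ax2"]  # "ax2" plays both roles here
print(find_duplicate(combined))
```

In the IR above, the padded reduction loop `ax2` over `(total_seq_len + 1) // 2 * 2` is the loop flagged as `tir.For#0`; the merged fix avoids collecting it into more than one of the loop groups before the `sch.reorder` call.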