Posted to commits@tvm.apache.org by "JackWeiw (via GitHub)" <gi...@apache.org> on 2024/03/06 10:56:55 UTC

[PR] Add some case in rewrite_simplify [tvm]

JackWeiw opened a new pull request, #16679:
URL: https://github.com/apache/tvm/pull/16679

   When I try to read a column of a softmax matrix into shared memory, TVM generates a complicated index expression
   ([script here](https://gist.github.com/JackWeiw/3e7bd0f1ed62225505916e87738d7751)).
   After this PR,
   `T.max(T.int64(1), (m + T.int64(127)) // T.int64(128) * T.int64(128))` will be simplified to
   `(m + T.int64(127)) // T.int64(128) * T.int64(128)`, and
   `T.min(T.int64(0), (m + T.int64(127)) // T.int64(128) * T.int64(128) - T.int64(1))` will be simplified to
   `T.int64(0)`.
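The two proposed rewrites can be checked by brute force. The following is a minimal sketch in plain Python. It assumes that Python's `//` on integers matches TIR's `floordiv` (both round toward negative infinity). `padded` and the `*_rule_holds` functions are hypothetical helpers, not TVM APIs.

```python
def padded(m: int) -> int:
    # (m + 127) // 128 * 128: round m up to a multiple of 128 (floor division).
    return (m + 127) // 128 * 128


def max_rule_holds(m: int) -> bool:
    # Proposed rewrite: max(1, padded(m)) -> padded(m)
    return max(1, padded(m)) == padded(m)


def min_rule_holds(m: int) -> bool:
    # Proposed rewrite: min(0, padded(m) - 1) -> 0
    return min(0, padded(m) - 1) == 0


# Both rewrites hold for every positive extent checked here...
assert all(max_rule_holds(m) and min_rule_holds(m) for m in range(1, 10_000))

# ...but both fail for a zero extent, which symbolic shapes are allowed to have.
print(max(1, padded(0)), padded(0))  # 1 0
print(min(0, padded(0) - 1))         # -1
```

Both rules are valid only under the extra assumption `m > 0`, which is the point raised in the review below.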


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Re: [PR] Add some case in rewrite_simplify [tvm]

Posted by "JackWeiw (via GitHub)" <gi...@apache.org>.
JackWeiw commented on PR #16679:
URL: https://github.com/apache/tvm/pull/16679#issuecomment-1985883927

   > Hmm, that's the tricky part. This assertion is specific to your use case, but isn't true in general. For symbolic shapes in both TIR and Relax, it is legal for the axis to have an extent of zero. How are you currently generating the PrimFunc from your script?
   
   In my use case, I get the IRModule generated by mlc-llm, and as far as I can tell there are at most two symbolic shapes, `m` and `n`. I can extract the symbolic shape information from the function parameters. Maybe I can simply add a rewrite pass that inserts `T.Assert(m > 0)` into every PrimFunc that has dynamic arguments; I will try it. Please let me know if you have a better idea!



Re: [PR] Add some case in rewrite_simplify [tvm]

Posted by "JackWeiw (via GitHub)" <gi...@apache.org>.
JackWeiw commented on PR #16679:
URL: https://github.com/apache/tvm/pull/16679#issuecomment-1985884585

   > I think it's okay to close this PR, as that functionality should be covered by the test cases for `FloorDiv` simplification
   
   Thanks for your review. I am closing this PR.



Re: [PR] Add some case in rewrite_simplify [tvm]

Posted by "Lunderberg (via GitHub)" <gi...@apache.org>.
Lunderberg commented on code in PR #16679:
URL: https://github.com/apache/tvm/pull/16679#discussion_r1514838945


##########
src/arith/rewrite_simplify.cc:
##########
@@ -1177,6 +1177,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) {
     TVM_TRY_REWRITE_IF(matches_one_of(min(x, floordiv(x, c2) * c2), min(floordiv(x, c2) * c2, x)),
                        floordiv(x, c2) * c2, c2.Eval()->value > 0);
 
+    TVM_TRY_REWRITE_IF(matches_one_of(min(floordiv(x + c1, c2) * c2 + c3, c4),

Review Comment:
   This rewrite rule isn't correct, as it assumes that `x` is non-negative.
   
   The expression `min((x + 15) // 16 * 16 + 15, 0)` would match this pattern with `{c1: 15, c2: 16, c3: 15, c4: 0}`, and would be rewritten to `0`. However, if `x` were `-1024`, then this expression should evaluate to `min(-1009, 0)`, that is, `-1009`.
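The counterexample can be reproduced in plain Python. This sketch assumes Python's `//` behaves like TIR's `floordiv` (rounding toward negative infinity). `lhs` and `rhs` are hypothetical names for the original expression and the rule's rewritten result.

```python
def lhs(x: int) -> int:
    # Original expression: min((x + 15) // 16 * 16 + 15, 0), with floor division.
    return min((x + 15) // 16 * 16 + 15, 0)


def rhs(x: int) -> int:
    # What the proposed rule rewrites the expression to: the constant c4 = 0.
    return 0


x = -1024
print((x + 15) // 16)  # -64, since floor(-1009 / 16) = floor(-63.06...) = -64
print(lhs(x), rhs(x))  # -1009 0

# Search a small window for inputs where the rewrite changes the result.
failures = [v for v in range(-100, 100) if lhs(v) != rhs(v)]
print(max(failures))   # -16
```

In this window the rewrite is wrong exactly for `x <= -16`. Non-negative `x` is therefore a sufficient condition for the rule, but the pattern alone does not guarantee it.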




Re: [PR] Add some case in rewrite_simplify [tvm]

Posted by "JackWeiw (via GitHub)" <gi...@apache.org>.
JackWeiw commented on PR #16679:
URL: https://github.com/apache/tvm/pull/16679#issuecomment-1985227683

   > I agree, it sounds like this is an issue where there's an insufficient condition, rather than a missing rewrite rule. If the test case is run without this PR, but providing the knowledge that `x > 0`, (`TestCase(tvm.te.min(0, fld(x + 3, 4) * 4 - 1), 0, x>0)` in the unit tests), then it passes.
   > 
   > There are a couple of different ways that assumptions can be added to the function.
   > 
   > 1. Using the variable as a dynamic shape.  This is already applied, but isn't a strong enough assumption to perform the simplifications you want.  In your script, this would provide `m >= 0`, while the simplification requires `m > 0`.  Since TIR allows a zero-dimension axis, this is the correct default behavior.
   > 2. Adding `T.assume(m > 0)` to the function.  Using `T.assume` provides additional information, but has undefined behavior if the assumption is violated.
   > 3. Adding `T.Assert(m > 0, "Some error message here")` to the function.  Using `T.Assert` performs a runtime validation, raising an exception if the assertion is violated.  All compile-time simplifications after the `T.Assert` may assume that the assumption holds.
   > 
   > I'd recommend using `T.Assert`, and only falling back to the `T.assume` if there are performance reasons to avoid the runtime validation.
   
   Thanks for your suggestion! Yes, without this PR but with the precondition `x > 0` added, both test cases (Min and Max) pass.
   BTW, could you please give me one more hint on how to automatically add `T.Assert(m > 0, "Some error message here")` to all the functions in an IRModule that have symbolic shapes? I am not very familiar with writing a pass, so your suggestion would help me a lot!
   Also, should I close this PR, or should I keep the added test case with the precondition `x > 0` (`TestCase(tvm.te.min(0, fld(x + 3, 4) * 4 - 1), 0, x > 0)`) and revert the change in rewrite_simplify?
     



Re: [PR] Add some case in rewrite_simplify [tvm]

Posted by "JackWeiw (via GitHub)" <gi...@apache.org>.
JackWeiw closed pull request #16679: Add some case in rewrite_simplify
URL: https://github.com/apache/tvm/pull/16679



Re: [PR] Add some case in rewrite_simplify [tvm]

Posted by "Lunderberg (via GitHub)" <gi...@apache.org>.
Lunderberg commented on PR #16679:
URL: https://github.com/apache/tvm/pull/16679#issuecomment-1983991434

   I agree, it sounds like this is an issue where there's an insufficient condition, rather than a missing rewrite rule.  If the test case is run without this PR, but providing the knowledge that `x > 0`, (`TestCase(tvm.te.min(0, fld(x + 3, 4) * 4 - 1), 0, x>0)` in the unit tests), then it passes.
   
   There are a couple of different ways that assumptions can be added to the function.
   
   1. Using the variable as a dynamic shape.  This is already applied, but isn't a strong enough assumption to perform the simplifications you want.  In your script, this would provide `m >= 0`, while the simplification requires `m > 0`.  Since TIR allows a zero-dimension axis, this is the correct default behavior.
   2. Adding `T.assume(m > 0)` to the function.  Using `T.assume` provides additional information, but has undefined behavior if the assumption is violated.
   3. Adding `T.Assert(m > 0, "Some error message here")` to the function.  Using `T.Assert` performs a runtime validation, raising an exception if the assertion is violated.  All compile-time simplifications after the `T.Assert` may assume that the assumption holds.
   
   I'd recommend using `T.Assert`, and only falling back to the `T.assume` if there are performance reasons to avoid the runtime validation.
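The distinction between `m >= 0` and `m > 0` can be illustrated with a toy model of bound-based simplification. This is a sketch in plain Python and is not TVM's `arith.Analyzer`. It relies only on the fact that, for `c2 > 0`, `floordiv(x + c1, c2) * c2 + c3` is non-decreasing in `x`, so its minimum over `x >= lo` is its value at `x = lo`. `lower_bound` and `try_simplify_min` are hypothetical helpers.

```python
from typing import Optional


def lower_bound(lo: int, c1: int, c2: int, c3: int) -> int:
    # For c2 > 0, floordiv(x + c1, c2) * c2 + c3 is non-decreasing in x,
    # so its minimum over x >= lo is its value at x = lo.
    assert c2 > 0
    return (lo + c1) // c2 * c2 + c3


def try_simplify_min(lo: int, c1: int, c2: int, c3: int, c4: int) -> Optional[int]:
    """Return c4 if min(c4, floordiv(x + c1, c2) * c2 + c3) == c4 for all x >= lo, else None."""
    if lower_bound(lo, c1, c2, c3) >= c4:
        return c4
    return None  # cannot prove it; leave the expression unchanged


# min(0, fld(x + 3, 4) * 4 - 1), the expression from the unit test discussed above:
print(try_simplify_min(0, 3, 4, -1, 0))  # None: with only x >= 0, x = 0 gives min(0, -1) = -1
print(try_simplify_min(1, 3, 4, -1, 0))  # 0: with x >= 1, the second operand is at least 3
```

In this model, an assertion such as `T.Assert(m > 0, ...)` roughly corresponds to raising `lo` from 0 to 1. That stronger bound is what makes the rewrite provable.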



Re: [PR] Add some case in rewrite_simplify [tvm]

Posted by "Lunderberg (via GitHub)" <gi...@apache.org>.
Lunderberg commented on PR #16679:
URL: https://github.com/apache/tvm/pull/16679#issuecomment-1985866404

   > Thanks for your suggestion! Yes, without this PR but with the precondition `x > 0` added, both test cases (Min and Max) pass.
   > BTW, could you please give me one more hint on how to automatically add `T.Assert(m > 0, "Some error message here")` to all the functions in an IRModule that have symbolic shapes?
   
   Hmm, that's the tricky part.  This assertion is specific to your use case, but isn't true in general.  For symbolic shapes in both TIR and Relax, it is legal for the axis to have an extent of zero.  How are you currently generating the PrimFunc from your script?
   
   > Also, should I close this PR, or should I keep the added test case with the precondition `x > 0` (`TestCase(tvm.te.min(0, fld(x + 3, 4) * 4 - 1), 0, x > 0)`) and revert the change in rewrite_simplify?
   
   I think it's okay to close this PR, as that functionality should be covered by the test cases for `FloorDiv` simplification.



Re: [PR] Add some case in rewrite_simplify [tvm]

Posted by "JackWeiw (via GitHub)" <gi...@apache.org>.
JackWeiw commented on PR #16679:
URL: https://github.com/apache/tvm/pull/16679#issuecomment-1982165909

   Sir, `x` here refers to the column size of an input matrix, so it should never be zero. Would it make sense to add the condition `x > 0`?
   
   > -------- Original message -------- From: Eric Lunderberg; Date: Thursday, March 7, 2024, 04:23; To: apache/tvm; Cc: Wei Tao, Author; Subject: Re: [apache/tvm] Add some case in rewrite_simplify (PR #16679)
   >
   > @Lunderberg requested changes on this pull request.
   >
   > I did a bit of testing, and these simplifications should not occur. For the test case `TestCase(tvm.te.min(0, fld(x + 3, 4) * 4 - 1), 0)`, substituting `0` for `x`, the `min(0, fld(x+3, 4)*4 - 1)` would evaluate to `min(0, -1)`. However, this simplification would cause it to instead evaluate to `0`.
   >
   > Similarly for `TestCase(tvm.te.max(1, fld(x + 3, 4) * 4), fld(x + 3, 4) * 4)`, substituting zero for `x` results in `max(1, 0)`, but the simplification rule would result in `0`.



Re: [PR] Add some case in rewrite_simplify [tvm]

Posted by "JackWeiw (via GitHub)" <gi...@apache.org>.
JackWeiw commented on PR #16679:
URL: https://github.com/apache/tvm/pull/16679#issuecomment-1982225958

   > I did a bit of testing, and these simplifications should not occur. For the test case `TestCase(tvm.te.min(0, fld(x + 3, 4) * 4 - 1), 0)`, substituting `0` for `x`, the `min(0, fld(x+3, 4)*4 - 1)` would evaluate to `min(0, -1)`. However, this simplification would cause it to instead evaluate to `0`.
   > 
   > Similarly for `TestCase(tvm.te.max(1, fld(x + 3, 4) * 4), fld(x + 3, 4) * 4)`, substituting zero for `x` results in `max(1, 0)`, but the simplification rule would result in `0`.
   
   Sir, `x` here refers to the column size of an input matrix, so it should never be zero. Would it make sense to add the condition `x > 0`? Or could you give me some advice on how to make the index in [this script](https://gist.github.com/JackWeiw/3e7bd0f1ed62225505916e87738d7751) simpler (it reads a column of a matrix into memory, and the column length is dynamic)?


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org