Posted to commits@tvm.apache.org by "abhikran-quic (via GitHub)" <gi...@apache.org> on 2023/05/15 11:58:15 UTC

[GitHub] [tvm] abhikran-quic opened a new pull request, #14854: [TIR][Compute-at] Enable complex floordiv/floormod expressions in compute_at

abhikran-quic opened a new pull request, #14854:
URL: https://github.com/apache/tvm/pull/14854

   Currently, `compute_at` cannot move blocks when the tensor indices contain complex floordiv/floormod operations. For example, if a tensor layout contains `(width % 2) // 2`, `compute_at` fails with an error. This change enables the `compute_at` scheduling directive to handle such complex expressions.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] quic-sanirudh merged pull request #14854: [TIR][Compute-at] Enable complex floordiv/floormod expressions in compute_at

Posted by "quic-sanirudh (via GitHub)" <gi...@apache.org>.
quic-sanirudh merged PR #14854:
URL: https://github.com/apache/tvm/pull/14854


[GitHub] [tvm] abhikran-quic commented on pull request #14854: [TIR][Compute-at] Enable complex floordiv/floormod expressions in compute_at

Posted by "abhikran-quic (via GitHub)" <gi...@apache.org>.
abhikran-quic commented on PR #14854:
URL: https://github.com/apache/tvm/pull/14854#issuecomment-1557295772

   Thank you @wrongtest-intellif for the review. I'm currently working on addressing your comments.


[GitHub] [tvm] wrongtest-intellif commented on a diff in pull request #14854: [TIR][Compute-at] Enable complex floordiv/floormod expressions in compute_at

Posted by "wrongtest-intellif (via GitHub)" <gi...@apache.org>.
wrongtest-intellif commented on code in PR #14854:
URL: https://github.com/apache/tvm/pull/14854#discussion_r1226217746


##########
src/tir/schedule/primitive/compute_at.cc:
##########
@@ -422,19 +422,38 @@ std::pair<Var, BlockVarDomainInfo> SolveBlockVarDomain(const arith::IntSet& prov
       var_dom = arith::IntSet::Interval(required_min, required_max);
       var_bound = arith::IntSet::Interval(0, dim_max);
     } else {
-      arith::PVar<PrimExpr> p_f;
-      if ((floordiv(p_v, p_f)).Match(provided_min)) {
-        // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
-        PrimExpr fac = p_f.Eval();
-        if (analyzer->CanProveGreaterEqual(fac, 1)) {
-          var = p_v.Eval();
-          var_dom = arith::IntSet::Interval(required_min * fac,
-                                            analyzer->Simplify(required_max * fac + fac - 1));
-          var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * fac + fac - 1));
+      arith::PVar<PrimExpr> p_f1, p_f2;
+      if ((floordiv(p_f1, p_f2).Match(provided_min))) {
+        PrimExpr var_expr = p_f1.Eval();
+        if (var_expr->IsInstance<VarNode>()) {
+          // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
+          PrimExpr fac = p_f2.Eval();
+          if (analyzer->CanProveGreaterEqual(fac, 1)) {
+            var = Downcast<Var>(var_expr);
+            var_dom = arith::IntSet::Interval(required_min * fac,
+                                              analyzer->Simplify(required_max * fac + fac - 1));
+            var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * fac + fac - 1));
+          }
+        } else {
+          const arith::IntSet new_provided = arith::IntSet::SinglePoint(p_f1.Eval());
+          const arith::IntSet new_required = arith::IntSet::SinglePoint(p_f2.Eval());

Review Comment:
   I think `new_required` should be `arith::IntSet::Interval(required_min * fac, required_max * fac + fac - 1)`?
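
As an illustration only (not part of the PR), the floordiv rule quoted in the diff comment, `a <= x // fac <= b` with `fac > 0` implying `a * fac <= x <= b * fac + fac - 1`, can be checked by brute force in plain Python, whose `//` is also floor division. The helper names `floordiv_preimage` and `brute_force_preimage` are hypothetical.

```python
def floordiv_preimage(a, b, fac):
    """Interval of x satisfying a <= x // fac <= b, assuming fac >= 1."""
    if fac < 1:
        raise ValueError("the rule only holds for a positive factor")
    return a * fac, b * fac + fac - 1


def brute_force_preimage(a, b, fac, search=range(-200, 200)):
    """Enumerate x in a finite window and return the min/max that satisfy the bound."""
    xs = [x for x in search if a <= x // fac <= b]
    return min(xs), max(xs)


# The closed-form interval agrees with enumeration for a few sample inputs.
for a, b, fac in [(1, 2, 2), (-3, 5, 4), (0, 0, 7)]:
    assert floordiv_preimage(a, b, fac) == brute_force_preimage(a, b, fac)
```

For the nested index `(x % 8) // 2` with a required interval of `[1, 2]`, the same rule gives `[2, 5]` as the interval to require of `x % 8`, which is an interval rather than a single point.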
   



##########
src/tir/schedule/primitive/compute_at.cc:
##########
@@ -422,19 +422,38 @@ std::pair<Var, BlockVarDomainInfo> SolveBlockVarDomain(const arith::IntSet& prov
       var_dom = arith::IntSet::Interval(required_min, required_max);
       var_bound = arith::IntSet::Interval(0, dim_max);
     } else {
-      arith::PVar<PrimExpr> p_f;
-      if ((floordiv(p_v, p_f)).Match(provided_min)) {
-        // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
-        PrimExpr fac = p_f.Eval();
-        if (analyzer->CanProveGreaterEqual(fac, 1)) {
-          var = p_v.Eval();
-          var_dom = arith::IntSet::Interval(required_min * fac,
-                                            analyzer->Simplify(required_max * fac + fac - 1));
-          var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * fac + fac - 1));
+      arith::PVar<PrimExpr> p_f1, p_f2;
+      if ((floordiv(p_f1, p_f2).Match(provided_min))) {
+        PrimExpr var_expr = p_f1.Eval();
+        if (var_expr->IsInstance<VarNode>()) {
+          // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
+          PrimExpr fac = p_f2.Eval();
+          if (analyzer->CanProveGreaterEqual(fac, 1)) {
+            var = Downcast<Var>(var_expr);
+            var_dom = arith::IntSet::Interval(required_min * fac,
+                                              analyzer->Simplify(required_max * fac + fac - 1));
+            var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * fac + fac - 1));
+          }
+        } else {
+          const arith::IntSet new_provided = arith::IntSet::SinglePoint(p_f1.Eval());
+          const arith::IntSet new_required = arith::IntSet::SinglePoint(p_f2.Eval());
+          return SolveBlockVarDomain(new_provided, new_required, dim_max, analyzer);
+        }
+      } else if ((floormod(p_f1, p_f2).Match(provided_min))) {
+        PrimExpr var_expr = p_f1.Eval();
+        if (var_expr->IsInstance<VarNode>()) {
+          // generally domain of (x % fac) enforce no constraints to domain of x
+          Var var_mod = Downcast<Var>(var_expr);
+          return {var_mod, BlockVarDomainInfo()};
+        } else {
+          PrimExpr mod_1 = p_f1.Eval();
+          PrimExpr mod_2 = p_f2.Eval();
+          if (analyzer->CanProveGreaterEqual(mod_1, 1) &&
+              analyzer->CanProveGreaterEqual(mod_2, 1)) {
+            const arith::IntSet new_provided = arith::IntSet::SinglePoint(p_f1.Eval());
+            return SolveBlockVarDomain(new_provided, required, dim_max, analyzer);

Review Comment:
   We are deducing the required region of `lhs` from the required region of `lhs % fac`.
   If `a <= lhs % fac <= b` and `lhs > 0 && fac > 0 && a >= 0`, we know that `a <= lhs` but cannot conclude that `lhs <= b`, so `new_required` should generally be `[origin_required_min, +inf)` instead of `required`. It would therefore be great to make the following changes:
   - Add a check to ensure `required_min >= 0`
   - Construct the deduced new required set `[required_min, +inf)` instead of the original one
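
A small plain-Python check, offered only as a sketch of the argument above, shows that the lower bound carries over from `lhs % fac` to `lhs` while the upper bound does not. The function name `check_mod_bounds` and the sample values are assumptions chosen for illustration.

```python
def check_mod_bounds(a, b, fac, lhs_values):
    """For every lhs with a <= lhs % fac <= b, test whether a <= lhs and lhs <= b hold.

    Returns (lower_bound_holds, upper_bound_holds) over all matching lhs values.
    """
    lower_ok = True
    upper_ok = True
    for lhs in lhs_values:
        if a <= lhs % fac <= b:
            lower_ok = lower_ok and (a <= lhs)
            upper_ok = upper_ok and (lhs <= b)
    return lower_ok, upper_ok


# Positive lhs, positive fac, non-negative a: the preconditions from the comment.
result = check_mod_bounds(a=2, b=5, fac=8, lhs_values=range(1, 64))
```

Here `lhs = 10` satisfies `2 <= 10 % 8 <= 5` yet exceeds `b = 5`, so only the lower bound survives, which is why the deduced set is `[required_min, +inf)`.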
   



##########
src/tir/schedule/primitive/compute_at.cc:
##########
@@ -422,19 +422,38 @@ std::pair<Var, BlockVarDomainInfo> SolveBlockVarDomain(const arith::IntSet& prov
       var_dom = arith::IntSet::Interval(required_min, required_max);
       var_bound = arith::IntSet::Interval(0, dim_max);
     } else {
-      arith::PVar<PrimExpr> p_f;
-      if ((floordiv(p_v, p_f)).Match(provided_min)) {
-        // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
-        PrimExpr fac = p_f.Eval();
-        if (analyzer->CanProveGreaterEqual(fac, 1)) {
-          var = p_v.Eval();
-          var_dom = arith::IntSet::Interval(required_min * fac,
-                                            analyzer->Simplify(required_max * fac + fac - 1));
-          var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * fac + fac - 1));
+      arith::PVar<PrimExpr> p_f1, p_f2;
+      if ((floordiv(p_f1, p_f2).Match(provided_min))) {
+        PrimExpr var_expr = p_f1.Eval();
+        if (var_expr->IsInstance<VarNode>()) {
+          // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
+          PrimExpr fac = p_f2.Eval();
+          if (analyzer->CanProveGreaterEqual(fac, 1)) {

Review Comment:
   Should we check `analyzer->CanProveGreaterEqual(fac, 1)` for both the var and non-var cases?



[GitHub] [tvm] Hzfengsy commented on pull request #14854: [TIR][Compute-at] Enable complex floordiv/floormod expressions in compute_at

Posted by "Hzfengsy (via GitHub)" <gi...@apache.org>.
Hzfengsy commented on PR #14854:
URL: https://github.com/apache/tvm/pull/14854#issuecomment-1549446438

   cc @wrongtest-intellif 


[GitHub] [tvm] tvm-bot commented on pull request #14854: [TIR][Compute-at] Enable complex floordiv/floormod expressions in compute_at

Posted by "tvm-bot (via GitHub)" <gi...@apache.org>.
tvm-bot commented on PR #14854:
URL: https://github.com/apache/tvm/pull/14854#issuecomment-1547718239

   <!---bot-comment-->
   
   Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from [Reviewers](https://github.com/apache/incubator-tvm/blob/master/CONTRIBUTORS.md#reviewers) by @-ing them in a comment.
   
   <!--bot-comment-ccs-start-->
    * cc @Hzfengsy, @junrushao, @quic-sanirudh, @shingjan <sub>See [#10317](https://github.com/apache/tvm/issues/10317) for details</sub><!--bot-comment-ccs-end-->
   
   <sub>Generated by [tvm-bot](https://github.com/apache/tvm/blob/main/ci/README.md#github-actions)</sub>


[GitHub] [tvm] abhikran-quic commented on a diff in pull request #14854: [TIR][Compute-at] Enable complex floordiv/floormod expressions in compute_at

Posted by "abhikran-quic (via GitHub)" <gi...@apache.org>.
abhikran-quic commented on code in PR #14854:
URL: https://github.com/apache/tvm/pull/14854#discussion_r1209241946


##########
src/tir/schedule/primitive/compute_at.cc:
##########
@@ -422,19 +422,27 @@ std::pair<Var, BlockVarDomainInfo> SolveBlockVarDomain(const arith::IntSet& prov
       var_dom = arith::IntSet::Interval(required_min, required_max);
       var_bound = arith::IntSet::Interval(0, dim_max);
     } else {
-      arith::PVar<PrimExpr> p_f;
-      if ((floordiv(p_v, p_f)).Match(provided_min)) {
+      arith::PVar<PrimExpr> p_f1, p_f2;
+      if ((floordiv(p_v, p_f1)).Match(provided_min)) {
         // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
-        PrimExpr fac = p_f.Eval();
+        PrimExpr fac = p_f1.Eval();
         if (analyzer->CanProveGreaterEqual(fac, 1)) {
           var = p_v.Eval();
           var_dom = arith::IntSet::Interval(required_min * fac,
                                             analyzer->Simplify(required_max * fac + fac - 1));
           var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * fac + fac - 1));
         }
-      } else if ((floormod(p_v, p_f).Match(provided_min))) {
+      } else if ((floormod(p_v, p_f1).Match(provided_min))) {
         // generally domain of (x % fac) enforce no constraints to domain of x
         return {p_v.Eval(), BlockVarDomainInfo()};
+      } else if ((floordiv(p_f1, p_f2).Match(provided_min))) {
+        auto* div_f = provided_min.as<FloorDivNode>();
+        const arith::IntSet new_provided = arith::IntSet::SinglePoint(div_f->a);
+        return SolveBlockVarDomain(new_provided, required, dim_max, analyzer);
+      } else if ((floormod(p_f1, p_f2).Match(provided_min))) {
+        auto* div_f = provided_min.as<FloorModNode>();

Review Comment:
   I agree with you. I have added a condition to check for positive operands.



##########
src/tir/schedule/primitive/compute_at.cc:
##########
@@ -422,19 +422,27 @@ std::pair<Var, BlockVarDomainInfo> SolveBlockVarDomain(const arith::IntSet& prov
       var_dom = arith::IntSet::Interval(required_min, required_max);
       var_bound = arith::IntSet::Interval(0, dim_max);
     } else {
-      arith::PVar<PrimExpr> p_f;
-      if ((floordiv(p_v, p_f)).Match(provided_min)) {
+      arith::PVar<PrimExpr> p_f1, p_f2;
+      if ((floordiv(p_v, p_f1)).Match(provided_min)) {
         // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
-        PrimExpr fac = p_f.Eval();
+        PrimExpr fac = p_f1.Eval();
         if (analyzer->CanProveGreaterEqual(fac, 1)) {
           var = p_v.Eval();
           var_dom = arith::IntSet::Interval(required_min * fac,
                                             analyzer->Simplify(required_max * fac + fac - 1));
           var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * fac + fac - 1));
         }
-      } else if ((floormod(p_v, p_f).Match(provided_min))) {
+      } else if ((floormod(p_v, p_f1).Match(provided_min))) {
         // generally domain of (x % fac) enforce no constraints to domain of x
         return {p_v.Eval(), BlockVarDomainInfo()};
+      } else if ((floordiv(p_f1, p_f2).Match(provided_min))) {

Review Comment:
   Thank you. I've added changes to pass `p_f2` as `new_provided` to floordiv.
   
   Regarding your comment on merging this with line 426, I'm seeing a compilation error at line 430 if I try to assign `p_f1.Eval()` to `var`. IMHO, this is a genuine error. Please let me know if you have ideas for getting past it.



##########
src/tir/schedule/primitive/compute_at.cc:
##########
@@ -422,19 +422,27 @@ std::pair<Var, BlockVarDomainInfo> SolveBlockVarDomain(const arith::IntSet& prov
       var_dom = arith::IntSet::Interval(required_min, required_max);
       var_bound = arith::IntSet::Interval(0, dim_max);
     } else {
-      arith::PVar<PrimExpr> p_f;
-      if ((floordiv(p_v, p_f)).Match(provided_min)) {
+      arith::PVar<PrimExpr> p_f1, p_f2;
+      if ((floordiv(p_v, p_f1)).Match(provided_min)) {
         // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
-        PrimExpr fac = p_f.Eval();
+        PrimExpr fac = p_f1.Eval();
         if (analyzer->CanProveGreaterEqual(fac, 1)) {
           var = p_v.Eval();
           var_dom = arith::IntSet::Interval(required_min * fac,
                                             analyzer->Simplify(required_max * fac + fac - 1));
           var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * fac + fac - 1));
         }
-      } else if ((floormod(p_v, p_f).Match(provided_min))) {
+      } else if ((floormod(p_v, p_f1).Match(provided_min))) {
         // generally domain of (x % fac) enforce no constraints to domain of x
         return {p_v.Eval(), BlockVarDomainInfo()};
+      } else if ((floordiv(p_f1, p_f2).Match(provided_min))) {
+        auto* div_f = provided_min.as<FloorDivNode>();
+        const arith::IntSet new_provided = arith::IntSet::SinglePoint(div_f->a);
+        return SolveBlockVarDomain(new_provided, required, dim_max, analyzer);
+      } else if ((floormod(p_f1, p_f2).Match(provided_min))) {
+        auto* div_f = provided_min.as<FloorModNode>();
+        const arith::IntSet new_provided = arith::IntSet::SinglePoint(div_f->a);

Review Comment:
   Sure. I've removed the use of `div_f`



[GitHub] [tvm] abhikran-quic commented on a diff in pull request #14854: [TIR][Compute-at] Enable complex floordiv/floormod expressions in compute_at

Posted by "abhikran-quic (via GitHub)" <gi...@apache.org>.
abhikran-quic commented on code in PR #14854:
URL: https://github.com/apache/tvm/pull/14854#discussion_r1226376298


##########
src/tir/schedule/primitive/compute_at.cc:
##########
@@ -422,19 +422,38 @@ std::pair<Var, BlockVarDomainInfo> SolveBlockVarDomain(const arith::IntSet& prov
       var_dom = arith::IntSet::Interval(required_min, required_max);
       var_bound = arith::IntSet::Interval(0, dim_max);
     } else {
-      arith::PVar<PrimExpr> p_f;
-      if ((floordiv(p_v, p_f)).Match(provided_min)) {
-        // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
-        PrimExpr fac = p_f.Eval();
-        if (analyzer->CanProveGreaterEqual(fac, 1)) {
-          var = p_v.Eval();
-          var_dom = arith::IntSet::Interval(required_min * fac,
-                                            analyzer->Simplify(required_max * fac + fac - 1));
-          var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * fac + fac - 1));
+      arith::PVar<PrimExpr> p_f1, p_f2;
+      if ((floordiv(p_f1, p_f2).Match(provided_min))) {
+        PrimExpr var_expr = p_f1.Eval();
+        if (var_expr->IsInstance<VarNode>()) {
+          // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
+          PrimExpr fac = p_f2.Eval();
+          if (analyzer->CanProveGreaterEqual(fac, 1)) {

Review Comment:
   Agreed. I've made the check common to the var and non-var cases.



[GitHub] [tvm] abhikran-quic commented on pull request #14854: [TIR][Compute-at] Enable complex floordiv/floormod expressions in compute_at

Posted by "abhikran-quic (via GitHub)" <gi...@apache.org>.
abhikran-quic commented on PR #14854:
URL: https://github.com/apache/tvm/pull/14854#issuecomment-1588695677

   @tvm-bot rerun


[GitHub] [tvm] wrongtest-intellif commented on a diff in pull request #14854: [TIR][Compute-at] Enable complex floordiv/floormod expressions in compute_at

Posted by "wrongtest-intellif (via GitHub)" <gi...@apache.org>.
wrongtest-intellif commented on code in PR #14854:
URL: https://github.com/apache/tvm/pull/14854#discussion_r1211707656


##########
src/tir/schedule/primitive/compute_at.cc:
##########
@@ -422,19 +422,27 @@ std::pair<Var, BlockVarDomainInfo> SolveBlockVarDomain(const arith::IntSet& prov
       var_dom = arith::IntSet::Interval(required_min, required_max);
       var_bound = arith::IntSet::Interval(0, dim_max);
     } else {
-      arith::PVar<PrimExpr> p_f;
-      if ((floordiv(p_v, p_f)).Match(provided_min)) {
+      arith::PVar<PrimExpr> p_f1, p_f2;
+      if ((floordiv(p_v, p_f1)).Match(provided_min)) {
         // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
-        PrimExpr fac = p_f.Eval();
+        PrimExpr fac = p_f1.Eval();
         if (analyzer->CanProveGreaterEqual(fac, 1)) {
           var = p_v.Eval();
           var_dom = arith::IntSet::Interval(required_min * fac,
                                             analyzer->Simplify(required_max * fac + fac - 1));
           var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * fac + fac - 1));
         }
-      } else if ((floormod(p_v, p_f).Match(provided_min))) {
+      } else if ((floormod(p_v, p_f1).Match(provided_min))) {
         // generally domain of (x % fac) enforce no constraints to domain of x
         return {p_v.Eval(), BlockVarDomainInfo()};
+      } else if ((floordiv(p_f1, p_f2).Match(provided_min))) {

Review Comment:
   Yeah, `p_f1.Eval()` is a `PrimExpr` but `var` is a `Var`, so it could be written as `var = Downcast<Var>(p_f1.Eval())`.
   
   Of course, we should first check whether `p_f1.Eval()` is an instance of `VarNode`; if it is, we can safely convert it to `Var` and end the recursion.



[GitHub] [tvm] wrongtest-intellif commented on a diff in pull request #14854: [TIR][Compute-at] Enable complex floordiv/floormod expressions in compute_at

Posted by "wrongtest-intellif (via GitHub)" <gi...@apache.org>.
wrongtest-intellif commented on code in PR #14854:
URL: https://github.com/apache/tvm/pull/14854#discussion_r1196680316


##########
src/tir/schedule/primitive/compute_at.cc:
##########
@@ -422,19 +422,27 @@ std::pair<Var, BlockVarDomainInfo> SolveBlockVarDomain(const arith::IntSet& prov
       var_dom = arith::IntSet::Interval(required_min, required_max);
       var_bound = arith::IntSet::Interval(0, dim_max);
     } else {
-      arith::PVar<PrimExpr> p_f;
-      if ((floordiv(p_v, p_f)).Match(provided_min)) {
+      arith::PVar<PrimExpr> p_f1, p_f2;
+      if ((floordiv(p_v, p_f1)).Match(provided_min)) {
         // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
-        PrimExpr fac = p_f.Eval();
+        PrimExpr fac = p_f1.Eval();
         if (analyzer->CanProveGreaterEqual(fac, 1)) {
           var = p_v.Eval();
           var_dom = arith::IntSet::Interval(required_min * fac,
                                             analyzer->Simplify(required_max * fac + fac - 1));
           var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * fac + fac - 1));
         }
-      } else if ((floormod(p_v, p_f).Match(provided_min))) {
+      } else if ((floormod(p_v, p_f1).Match(provided_min))) {
         // generally domain of (x % fac) enforce no constraints to domain of x
         return {p_v.Eval(), BlockVarDomainInfo()};
+      } else if ((floordiv(p_f1, p_f2).Match(provided_min))) {
+        auto* div_f = provided_min.as<FloorDivNode>();
+        const arith::IntSet new_provided = arith::IntSet::SinglePoint(div_f->a);
+        return SolveBlockVarDomain(new_provided, required, dim_max, analyzer);
+      } else if ((floormod(p_f1, p_f2).Match(provided_min))) {
+        auto* div_f = provided_min.as<FloorModNode>();

Review Comment:
   Suppose the provided range min is `x // 2 % 8` and the required min is `y`; then `x // 2 % 8 >= y` implies `x // 2 >= y` only when `x // 2` is positive. Could we add a check to ensure that `p_f1.Eval()` and `p_f2.Eval()` are both positive?
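
To make the need for the positivity check concrete, here is a hedged plain-Python sketch that searches for a value of the modulo's left operand where `lhs % fac >= y` holds but `lhs >= y` fails. Python's `%` with a positive divisor behaves like floormod; the function name `find_counterexample` and the sample ranges are hypothetical.

```python
def find_counterexample(y, fac, lhs_values):
    """Return the first lhs with lhs % fac >= y but lhs < y, or None if there is none."""
    for lhs in lhs_values:
        if lhs % fac >= y and lhs < y:
            return lhs
    return None


# Non-negative operands: lhs >= lhs % fac, so the implication always holds.
no_violation = find_counterexample(y=5, fac=8, lhs_values=range(0, 64))

# Negative operands: floormod wraps into [0, fac), so the implication can fail.
violation = find_counterexample(y=5, fac=8, lhs_values=range(-16, 0))
```

With negative operands the search finds a violation (for instance `-1 % 8` is `7`, which is at least `5`, while `-1` is not), whereas no violation exists once the operand is known to be non-negative.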



##########
tests/python/unittest/test_tir_schedule_compute_at.py:
##########
@@ -995,6 +995,41 @@ def floordiv_and_floormod_indices_after_reverse_compute_at(a: T.handle, b: T.han
                 Y[v_i] = temp[v_i // 16, v_i % 16]
 
 
+@T.prim_func
+def recursive_floordiv_and_floormod_indices(a: T.handle, b: T.handle) -> None:
+    X = T.match_buffer(a, [16, 16])
+    Y = T.match_buffer(b, [256])
+    temp = T.alloc_buffer([16, 4, 2, 2])
+    for i, j in T.grid(16, 16):
+        with T.block("A"):
+            v_i, v_j = T.axis.remap("SS", [i, j])
+            temp[v_i, v_j // 4, (v_j % 4) //2, v_j % 2] = X[v_j, v_i] + 1.0
+    for i, j in T.grid(16, 16):
+        with T.block("B"):
+            v_i, v_j = T.axis.remap("SS", [i, j])
+            Y[v_i*16 + v_j] = temp[v_i, v_j // 4, (v_j % 4) // 2, (v_j %2)]
+
+
+@T.prim_func
+def recursive_floordiv_and_floormod_indices_after_reverse_compute_at(a: T.handle, b: T.handle) -> None:
+    X = T.match_buffer(a, [16, 16])
+    Y = T.match_buffer(b, [256])
+    temp = T.alloc_buffer((16, 4, 2, 2))
+    for i in range(16):
+        for j in range(16):
+            with T.block("A"):
+                v_i, v_j = T.axis.remap("SS", [i, j])
+                T.reads(X[v_j, v_i])
+                T.writes(temp[v_i, v_j // 4, v_j % 4 // 2, v_j % 2])
+                temp[v_i, v_j // 4, v_j % 4 // 2, v_j % 2] = X[v_j, v_i] + T.float32(1)

Review Comment:
   The unit test seems to always pass on the main branch. Could we change it to a non-bijective form like
   `temp[v_i, v_j // 5, v_j % 4 // 2, v_j % 3] = X[v_j, v_i]`?



##########
src/tir/schedule/primitive/compute_at.cc:
##########
@@ -422,19 +422,27 @@ std::pair<Var, BlockVarDomainInfo> SolveBlockVarDomain(const arith::IntSet& prov
       var_dom = arith::IntSet::Interval(required_min, required_max);
       var_bound = arith::IntSet::Interval(0, dim_max);
     } else {
-      arith::PVar<PrimExpr> p_f;
-      if ((floordiv(p_v, p_f)).Match(provided_min)) {
+      arith::PVar<PrimExpr> p_f1, p_f2;
+      if ((floordiv(p_v, p_f1)).Match(provided_min)) {
         // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
-        PrimExpr fac = p_f.Eval();
+        PrimExpr fac = p_f1.Eval();
         if (analyzer->CanProveGreaterEqual(fac, 1)) {
           var = p_v.Eval();
           var_dom = arith::IntSet::Interval(required_min * fac,
                                             analyzer->Simplify(required_max * fac + fac - 1));
           var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * fac + fac - 1));
         }
-      } else if ((floormod(p_v, p_f).Match(provided_min))) {
+      } else if ((floormod(p_v, p_f1).Match(provided_min))) {
         // generally domain of (x % fac) enforce no constraints to domain of x
         return {p_v.Eval(), BlockVarDomainInfo()};
+      } else if ((floordiv(p_f1, p_f2).Match(provided_min))) {

Review Comment:
   Suppose the provided range min is `x % 8 // 2` and the required min is `y`; then `x % 8 // 2 >= y` implies `x % 8 >= y * 2`. We then treat the provided min as `x % 8` and the required min as `y * 2` and recurse. IIUC, we require the multiplier `p_f2` on the new provided for floordiv?
   
   Also, could we consider merging this branch with the branch at line 426, so that the recursion simply ends when `p_f1.Eval()` is a variable?
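
The recursion described above can be sketched in plain Python with a toy expression tree. This is not TVM's `SolveBlockVarDomain`: the classes `Var`, `FloorDiv`, `FloorMod` and the function `solve` are hypothetical stand-ins, factors are plain integers, and any modulo is conservatively treated as yielding no constraint on the variable.

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class FloorDiv:
    a: "Expr"
    b: int


@dataclass(frozen=True)
class FloorMod:
    a: "Expr"
    b: int


Expr = Union[Var, FloorDiv, FloorMod]
Bounds = Optional[Tuple[int, int]]  # None means "no constraint deduced"


def solve(provided: Expr, lo: int, hi: int) -> Tuple[str, Bounds]:
    """Deduce bounds on the innermost variable from lo <= provided <= hi."""
    if isinstance(provided, Var):
        # Base case: the recursion ends once the operand is a variable.
        return provided.name, (lo, hi)
    if isinstance(provided, FloorDiv):
        fac = provided.b
        if fac < 1:
            raise ValueError("factor must be provably positive")
        # lo <= e // fac <= hi  ==>  lo * fac <= e <= hi * fac + fac - 1
        return solve(provided.a, lo * fac, hi * fac + fac - 1)
    if isinstance(provided, FloorMod):
        # Simplification for this sketch: a modulo gives no bound on its operand.
        inner = provided.a
        while not isinstance(inner, Var):
            inner = inner.a
        return inner.name, None
    raise TypeError(f"unsupported expression: {provided!r}")


nested_div = solve(FloorDiv(FloorDiv(Var("x"), 2), 4), 1, 1)
div_of_mod = solve(FloorDiv(FloorMod(Var("x"), 8), 2), 1, 2)
```

In the first call the required interval is scaled twice on the way down to `x`; in the second, the scaled interval reaches the modulo, where this sketch stops deducing anything about `x`.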
   
   
   



##########
src/tir/schedule/primitive/compute_at.cc:
##########
@@ -422,19 +422,27 @@ std::pair<Var, BlockVarDomainInfo> SolveBlockVarDomain(const arith::IntSet& prov
       var_dom = arith::IntSet::Interval(required_min, required_max);
       var_bound = arith::IntSet::Interval(0, dim_max);
     } else {
-      arith::PVar<PrimExpr> p_f;
-      if ((floordiv(p_v, p_f)).Match(provided_min)) {
+      arith::PVar<PrimExpr> p_f1, p_f2;
+      if ((floordiv(p_v, p_f1)).Match(provided_min)) {
         // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
-        PrimExpr fac = p_f.Eval();
+        PrimExpr fac = p_f1.Eval();
         if (analyzer->CanProveGreaterEqual(fac, 1)) {
           var = p_v.Eval();
           var_dom = arith::IntSet::Interval(required_min * fac,
                                             analyzer->Simplify(required_max * fac + fac - 1));
           var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * fac + fac - 1));
         }
-      } else if ((floormod(p_v, p_f).Match(provided_min))) {
+      } else if ((floormod(p_v, p_f1).Match(provided_min))) {
         // generally domain of (x % fac) enforce no constraints to domain of x
         return {p_v.Eval(), BlockVarDomainInfo()};
+      } else if ((floordiv(p_f1, p_f2).Match(provided_min))) {
+        auto* div_f = provided_min.as<FloorDivNode>();
+        const arith::IntSet new_provided = arith::IntSet::SinglePoint(div_f->a);
+        return SolveBlockVarDomain(new_provided, required, dim_max, analyzer);
+      } else if ((floormod(p_f1, p_f2).Match(provided_min))) {
+        auto* div_f = provided_min.as<FloorModNode>();
+        const arith::IntSet new_provided = arith::IntSet::SinglePoint(div_f->a);

Review Comment:
   We can just use `p_f1.Eval()` rather than explicitly casting to `div_f`.



[GitHub] [tvm] wrongtest-intellif commented on a diff in pull request #14854: [TIR][Compute-at] Enable complex floordiv/floormod expressions in compute_at

Posted by "wrongtest-intellif (via GitHub)" <gi...@apache.org>.
wrongtest-intellif commented on code in PR #14854:
URL: https://github.com/apache/tvm/pull/14854#discussion_r1211538782


##########
tests/python/unittest/test_tir_schedule_compute_at.py:
##########
@@ -995,6 +995,41 @@ def floordiv_and_floormod_indices_after_reverse_compute_at(a: T.handle, b: T.han
                 Y[v_i] = temp[v_i // 16, v_i % 16]
 
 
+@T.prim_func
+def recursive_floordiv_and_floormod_indices(a: T.handle, b: T.handle) -> None:
+    X = T.match_buffer(a, [16, 16])
+    Y = T.match_buffer(b, [256])
+    temp = T.alloc_buffer([16, 4, 2, 2])
+    for i, j in T.grid(16, 16):
+        with T.block("A"):
+            v_i, v_j = T.axis.remap("SS", [i, j])
+            temp[v_i, v_j // 4, (v_j % 4) //2, v_j % 2] = X[v_j, v_i] + 1.0
+    for i, j in T.grid(16, 16):
+        with T.block("B"):
+            v_i, v_j = T.axis.remap("SS", [i, j])
+            Y[v_i*16 + v_j] = temp[v_i, v_j // 4, (v_j % 4) // 2, (v_j %2)]
+
+
+@T.prim_func
+def recursive_floordiv_and_floormod_indices_after_reverse_compute_at(a: T.handle, b: T.handle) -> None:
+    X = T.match_buffer(a, [16, 16])
+    Y = T.match_buffer(b, [256])
+    temp = T.alloc_buffer((16, 4, 2, 2))
+    for i in range(16):
+        for j in range(16):
+            with T.block("A"):
+                v_i, v_j = T.axis.remap("SS", [i, j])
+                T.reads(X[v_j, v_i])
+                T.writes(temp[v_i, v_j // 4, v_j % 4 // 2, v_j % 2])
+                temp[v_i, v_j // 4, v_j % 4 // 2, v_j % 2] = X[v_j, v_i] + T.float32(1)

Review Comment:
   https://github.com/apache/tvm/blob/880126a7e2f1cd044e1b61a520b2c144a90f0c62/src/tir/schedule/primitive/compute_at.cc#L604
   We have dedicated support for the bijective mapping form, so there is a concern that the test case does not exercise the code being changed here.



[GitHub] [tvm] abhikran-quic commented on a diff in pull request #14854: [TIR][Compute-at] Enable complex floordiv/floormod expressions in compute_at

Posted by "abhikran-quic (via GitHub)" <gi...@apache.org>.
abhikran-quic commented on code in PR #14854:
URL: https://github.com/apache/tvm/pull/14854#discussion_r1209785817


##########
tests/python/unittest/test_tir_schedule_compute_at.py:
##########
@@ -995,6 +995,41 @@ def floordiv_and_floormod_indices_after_reverse_compute_at(a: T.handle, b: T.han
                 Y[v_i] = temp[v_i // 16, v_i % 16]
 
 
+@T.prim_func
+def recursive_floordiv_and_floormod_indices(a: T.handle, b: T.handle) -> None:
+    X = T.match_buffer(a, [16, 16])
+    Y = T.match_buffer(b, [256])
+    temp = T.alloc_buffer([16, 4, 2, 2])
+    for i, j in T.grid(16, 16):
+        with T.block("A"):
+            v_i, v_j = T.axis.remap("SS", [i, j])
+            temp[v_i, v_j // 4, (v_j % 4) //2, v_j % 2] = X[v_j, v_i] + 1.0
+    for i, j in T.grid(16, 16):
+        with T.block("B"):
+            v_i, v_j = T.axis.remap("SS", [i, j])
+            Y[v_i*16 + v_j] = temp[v_i, v_j // 4, (v_j % 4) // 2, (v_j %2)]
+
+
+@T.prim_func
+def recursive_floordiv_and_floormod_indices_after_reverse_compute_at(a: T.handle, b: T.handle) -> None:
+    X = T.match_buffer(a, [16, 16])
+    Y = T.match_buffer(b, [256])
+    temp = T.alloc_buffer((16, 4, 2, 2))
+    for i in range(16):
+        for j in range(16):
+            with T.block("A"):
+                v_i, v_j = T.axis.remap("SS", [i, j])
+                T.reads(X[v_j, v_i])
+                T.writes(temp[v_i, v_j // 4, v_j % 4 // 2, v_j % 2])
+                temp[v_i, v_j // 4, v_j % 4 // 2, v_j % 2] = X[v_j, v_i] + T.float32(1)

Review Comment:
   I'm seeing an error when applying `reverse_compute_at` with the `index_map` shared above.
   Could you please clarify whether the idea is to test index accesses in a non-bijective form?





[GitHub] [tvm] abhikran-quic commented on pull request #14854: [TIR][Compute-at] Enable complex floordiv/floormod expressions in compute_at

Posted by "abhikran-quic (via GitHub)" <gi...@apache.org>.
abhikran-quic commented on PR #14854:
URL: https://github.com/apache/tvm/pull/14854#issuecomment-1548756895

   cc @Hzfengsy, @junrushao, @quic-sanirudh, @shingjan




[GitHub] [tvm] abhikran-quic commented on a diff in pull request #14854: [TIR][Compute-at] Enable complex floordiv/floormod expressions in compute_at

Posted by "abhikran-quic (via GitHub)" <gi...@apache.org>.
abhikran-quic commented on code in PR #14854:
URL: https://github.com/apache/tvm/pull/14854#discussion_r1220062219


##########
tests/python/unittest/test_tir_schedule_compute_at.py:
##########
@@ -995,6 +995,41 @@ def floordiv_and_floormod_indices_after_reverse_compute_at(a: T.handle, b: T.han
                 Y[v_i] = temp[v_i // 16, v_i % 16]
 
 
+@T.prim_func
+def recursive_floordiv_and_floormod_indices(a: T.handle, b: T.handle) -> None:
+    X = T.match_buffer(a, [16, 16])
+    Y = T.match_buffer(b, [256])
+    temp = T.alloc_buffer([16, 4, 2, 2])
+    for i, j in T.grid(16, 16):
+        with T.block("A"):
+            v_i, v_j = T.axis.remap("SS", [i, j])
+            temp[v_i, v_j // 4, (v_j % 4) //2, v_j % 2] = X[v_j, v_i] + 1.0
+    for i, j in T.grid(16, 16):
+        with T.block("B"):
+            v_i, v_j = T.axis.remap("SS", [i, j])
+            Y[v_i*16 + v_j] = temp[v_i, v_j // 4, (v_j % 4) // 2, (v_j %2)]
+
+
+@T.prim_func
+def recursive_floordiv_and_floormod_indices_after_reverse_compute_at(a: T.handle, b: T.handle) -> None:
+    X = T.match_buffer(a, [16, 16])
+    Y = T.match_buffer(b, [256])
+    temp = T.alloc_buffer((16, 4, 2, 2))
+    for i in range(16):
+        for j in range(16):
+            with T.block("A"):
+                v_i, v_j = T.axis.remap("SS", [i, j])
+                T.reads(X[v_j, v_i])
+                T.writes(temp[v_i, v_j // 4, v_j % 4 // 2, v_j % 2])
+                temp[v_i, v_j // 4, v_j % 4 // 2, v_j % 2] = X[v_j, v_i] + T.float32(1)

Review Comment:
   Sure. I've added a test case that exercises the lines of code changed in this PR. It is reduced from an issue I hit with a model whose schedule I was trying to optimize using `compute_at`.



##########
src/tir/schedule/primitive/compute_at.cc:
##########
@@ -422,19 +422,27 @@ std::pair<Var, BlockVarDomainInfo> SolveBlockVarDomain(const arith::IntSet& prov
       var_dom = arith::IntSet::Interval(required_min, required_max);
       var_bound = arith::IntSet::Interval(0, dim_max);
     } else {
-      arith::PVar<PrimExpr> p_f;
-      if ((floordiv(p_v, p_f)).Match(provided_min)) {
+      arith::PVar<PrimExpr> p_f1, p_f2;
+      if ((floordiv(p_v, p_f1)).Match(provided_min)) {
         // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
-        PrimExpr fac = p_f.Eval();
+        PrimExpr fac = p_f1.Eval();
         if (analyzer->CanProveGreaterEqual(fac, 1)) {
           var = p_v.Eval();
           var_dom = arith::IntSet::Interval(required_min * fac,
                                             analyzer->Simplify(required_max * fac + fac - 1));
           var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * fac + fac - 1));
         }
-      } else if ((floormod(p_v, p_f).Match(provided_min))) {
+      } else if ((floormod(p_v, p_f1).Match(provided_min))) {
         // generally domain of (x % fac) enforce no constraints to domain of x
         return {p_v.Eval(), BlockVarDomainInfo()};
+      } else if ((floordiv(p_f1, p_f2).Match(provided_min))) {

Review Comment:
   Thank you for the suggestion. I've merged the separate conditions for `floordiv` and `floormod`, respectively.
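
The rule quoted in the diff comment above, `a <= (x // fac) <= b` with `fac > 0` implies `a * fac <= x <= b * fac + fac - 1`, can be sanity-checked by brute force. This is a minimal Python sketch; `invert_floordiv` is a hypothetical helper written for illustration and is not TVM API.

```python
def invert_floordiv(a, b, fac):
    """Closed interval of integers x that satisfy a <= x // fac <= b, for fac >= 1."""
    if fac < 1:
        raise ValueError("fac must be a positive integer")
    return a * fac, b * fac + fac - 1


# Example: 1 <= x // 4 <= 2.
lo, hi = invert_floordiv(1, 2, 4)

# Brute-force the same constraint over a window wider than the expected answer.
brute = [x for x in range(-64, 64) if 1 <= x // 4 <= 2]
matches = brute == list(range(lo, hi + 1))
```

Python's `//` rounds toward negative infinity, which matches the floor semantics of `floordiv`, so the same check also holds for negative bounds.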





[GitHub] [tvm] abhikran-quic commented on pull request #14854: [TIR][Compute-at] Enable complex floordiv/floormod expressions in compute_at

Posted by "abhikran-quic (via GitHub)" <gi...@apache.org>.
abhikran-quic commented on PR #14854:
URL: https://github.com/apache/tvm/pull/14854#issuecomment-1584147755

   Hi @wrongtest-intellif, could you please review this PR?




[GitHub] [tvm] abhikran-quic commented on a diff in pull request #14854: [TIR][Compute-at] Enable complex floordiv/floormod expressions in compute_at

Posted by "abhikran-quic (via GitHub)" <gi...@apache.org>.
abhikran-quic commented on code in PR #14854:
URL: https://github.com/apache/tvm/pull/14854#discussion_r1226409513


##########
src/tir/schedule/primitive/compute_at.cc:
##########
@@ -422,19 +422,38 @@ std::pair<Var, BlockVarDomainInfo> SolveBlockVarDomain(const arith::IntSet& prov
       var_dom = arith::IntSet::Interval(required_min, required_max);
       var_bound = arith::IntSet::Interval(0, dim_max);
     } else {
-      arith::PVar<PrimExpr> p_f;
-      if ((floordiv(p_v, p_f)).Match(provided_min)) {
-        // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
-        PrimExpr fac = p_f.Eval();
-        if (analyzer->CanProveGreaterEqual(fac, 1)) {
-          var = p_v.Eval();
-          var_dom = arith::IntSet::Interval(required_min * fac,
-                                            analyzer->Simplify(required_max * fac + fac - 1));
-          var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * fac + fac - 1));
+      arith::PVar<PrimExpr> p_f1, p_f2;
+      if ((floordiv(p_f1, p_f2).Match(provided_min))) {
+        PrimExpr var_expr = p_f1.Eval();
+        if (var_expr->IsInstance<VarNode>()) {
+          // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
+          PrimExpr fac = p_f2.Eval();
+          if (analyzer->CanProveGreaterEqual(fac, 1)) {
+            var = Downcast<Var>(var_expr);
+            var_dom = arith::IntSet::Interval(required_min * fac,
+                                              analyzer->Simplify(required_max * fac + fac - 1));
+            var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * fac + fac - 1));
+          }
+        } else {
+          const arith::IntSet new_provided = arith::IntSet::SinglePoint(p_f1.Eval());
+          const arith::IntSet new_required = arith::IntSet::SinglePoint(p_f2.Eval());

Review Comment:
   Sure. I've made the change with one modification: I've added `analyzer->Simplify` to the second argument of `arith::IntSet::Interval`. Please let me know whether it looks good to you. The tests that I ran passed with this change.



##########
src/tir/schedule/primitive/compute_at.cc:
##########
@@ -422,19 +422,38 @@ std::pair<Var, BlockVarDomainInfo> SolveBlockVarDomain(const arith::IntSet& prov
       var_dom = arith::IntSet::Interval(required_min, required_max);
       var_bound = arith::IntSet::Interval(0, dim_max);
     } else {
-      arith::PVar<PrimExpr> p_f;
-      if ((floordiv(p_v, p_f)).Match(provided_min)) {
-        // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
-        PrimExpr fac = p_f.Eval();
-        if (analyzer->CanProveGreaterEqual(fac, 1)) {
-          var = p_v.Eval();
-          var_dom = arith::IntSet::Interval(required_min * fac,
-                                            analyzer->Simplify(required_max * fac + fac - 1));
-          var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * fac + fac - 1));
+      arith::PVar<PrimExpr> p_f1, p_f2;
+      if ((floordiv(p_f1, p_f2).Match(provided_min))) {
+        PrimExpr var_expr = p_f1.Eval();
+        if (var_expr->IsInstance<VarNode>()) {
+          // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
+          PrimExpr fac = p_f2.Eval();
+          if (analyzer->CanProveGreaterEqual(fac, 1)) {

Review Comment:
   Agreed. I've made the check common to the var and non-var cases.



##########
src/tir/schedule/primitive/compute_at.cc:
##########
@@ -422,19 +422,38 @@ std::pair<Var, BlockVarDomainInfo> SolveBlockVarDomain(const arith::IntSet& prov
       var_dom = arith::IntSet::Interval(required_min, required_max);
       var_bound = arith::IntSet::Interval(0, dim_max);
     } else {
-      arith::PVar<PrimExpr> p_f;
-      if ((floordiv(p_v, p_f)).Match(provided_min)) {
-        // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
-        PrimExpr fac = p_f.Eval();
-        if (analyzer->CanProveGreaterEqual(fac, 1)) {
-          var = p_v.Eval();
-          var_dom = arith::IntSet::Interval(required_min * fac,
-                                            analyzer->Simplify(required_max * fac + fac - 1));
-          var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * fac + fac - 1));
+      arith::PVar<PrimExpr> p_f1, p_f2;
+      if ((floordiv(p_f1, p_f2).Match(provided_min))) {
+        PrimExpr var_expr = p_f1.Eval();
+        if (var_expr->IsInstance<VarNode>()) {
+          // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
+          PrimExpr fac = p_f2.Eval();
+          if (analyzer->CanProveGreaterEqual(fac, 1)) {
+            var = Downcast<Var>(var_expr);
+            var_dom = arith::IntSet::Interval(required_min * fac,
+                                              analyzer->Simplify(required_max * fac + fac - 1));
+            var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * fac + fac - 1));
+          }
+        } else {
+          const arith::IntSet new_provided = arith::IntSet::SinglePoint(p_f1.Eval());
+          const arith::IntSet new_required = arith::IntSet::SinglePoint(p_f2.Eval());
+          return SolveBlockVarDomain(new_provided, new_required, dim_max, analyzer);
+        }
+      } else if ((floormod(p_f1, p_f2).Match(provided_min))) {
+        PrimExpr var_expr = p_f1.Eval();
+        if (var_expr->IsInstance<VarNode>()) {
+          // generally domain of (x % fac) enforce no constraints to domain of x
+          Var var_mod = Downcast<Var>(var_expr);
+          return {var_mod, BlockVarDomainInfo()};
+        } else {
+          PrimExpr mod_1 = p_f1.Eval();
+          PrimExpr mod_2 = p_f2.Eval();
+          if (analyzer->CanProveGreaterEqual(mod_1, 1) &&
+              analyzer->CanProveGreaterEqual(mod_2, 1)) {
+            const arith::IntSet new_provided = arith::IntSet::SinglePoint(p_f1.Eval());
+            return SolveBlockVarDomain(new_provided, required, dim_max, analyzer);

Review Comment:
   Sure, I agree with you. I have updated the PR.
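
The diff comment "generally domain of (x % fac) enforce no constraints to domain of x" can be illustrated with a small Python sketch. The helper `solutions_mod` is hypothetical and only enumerates solutions by brute force; it is not how TVM solves the domain. It shows that the values of `x` satisfying a bound on `x % fac` are spread across the whole extent rather than forming one contiguous interval, so an interval-based solver gains little by trying to tighten the domain of `x`.

```python
def solutions_mod(a, b, fac, extent):
    """All x in [0, extent) with a <= x % fac <= b, found by enumeration."""
    return [x for x in range(extent) if a <= x % fac <= b]


# Example: x % 2 == 1 over the 16-element axis used in the test.
xs = solutions_mod(1, 1, 2, 16)

# The solutions are every other integer, so they do not form a single interval.
is_contiguous = xs == list(range(xs[0], xs[-1] + 1))
```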


