You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/15 17:45:05 UTC

[GitHub] [incubator-tvm] ANSHUMAN87 opened a new pull request #6066: [TIR][Transform] HoistIfThenElse added

ANSHUMAN87 opened a new pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066


   This is a follow-up PR. Please refer to #5559.
   
   cc @kevinthesun , @roastduck , @zhiics , @junrushao1994 , @tqchen .
   
   I have tried to cover all the possible scenarios. Please let me know in case I missed anything. TIA!
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r459906470



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const Object*, const Object*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {

Review comment:
       Maybe the examples I shared are not enough to cover all cases :)
   Let me give another example; maybe that will clear up all the queries.
   ```
   for (i: int32, 0, l: int32) {
     for (j: int32, 0, m: int32) {
       for (k: int32, 0, n: int32) {
         if @tir.likely(tvm.tir.any(i < 4, k >= 8), dtype=bool) {
           m
         } else {
           n
         }
       }
     }
   }
   ```
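   In the case above, `target_for_pos = 1` (only `i` matches, because the search skips the innermost loop), but we should not hoist in that case: the condition also depends on `k`, the loop variable of the `if` node's parent loop. The `IsParentForLoop` check is what rejects it.
   I hope this clears up your queries.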




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r460369122



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,374 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const VarNode*, const ForNode*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {
+          StopAndAddRecord(ordered_for_list_[target_for_pos], op);
+        }
+      }
+      if_var_list_.clear();
+      StmtExprVisitor::VisitStmt_(op);
+      StopRecording();
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const VarNode* op) final {
+    if (is_if_cond) {
+      if_var_list_.emplace_back(op);
+    }
+  }
+
+  HoistForIfTuple hoist_for_if_recorder;
+
+  void ResetRecorder() {
+    if (is_recorder_on) {
+      CHECK_GT(ordered_for_list_.size(), 0);
+      is_recorder_on = false;
+    }
+    ordered_for_list_.clear();
+    var_for_map_.clear();
+    hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
+  }
+
+  bool RecordingComplete() {
+    if (std::get<0>(hoist_for_if_recorder)) return true;
+    return false;
+  }
+
+  const ForNode* GetTargetForNode() { return std::get<1>(hoist_for_if_recorder); }
+
+  const IfThenElseNode* GetTargetIfNode() { return std::get<2>(hoist_for_if_recorder); }
+
+ private:
+  bool CheckValidIf() {
+    // If no if var list is present, then all the condition vars are possibly from AttrStmt, so stop
+    // hoisting
+    if (if_var_list_.size() == 0) {
+      return false;
+    }
+    if (CheckAttrVar()) {
+      return false;
+    }
+    return true;
+  }
+
+  bool IsParentForLoop(int loop_pos) {
+    // Check if the loop position is higher than the parent loop position
+    for (auto var : if_var_list_) {
+      if (GetParentLoopPos(var_for_map_[var]) >= loop_pos) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  int GetParentLoopPos(const Object* node) {
+    for (size_t i = 0; i < ordered_for_list_.size(); ++i) {
+      if (ordered_for_list_[i] == node) {
+        return i;
+      }
+    }
+    return -1;
+  }
+
+  void InitRecorder() { hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr); }
+
+  void StopRecording() { is_recorder_on = false; }
+
+  bool IsRecordingOn() { return is_recorder_on; }
+
+  void StartOrAddRecord(const ForNode* op) {
+    is_recorder_on = true;
+    if (!var_for_map_.count(op->loop_var.get())) {
+      var_for_map_.insert({op->loop_var.get(), op});
+    }
+    ordered_for_list_.emplace_back(op);
+  }
+
+  void RemoveRecord(const ForNode* op) {
+    StopRecording();
+    var_for_map_.erase(op->loop_var.get());
+    if (ordered_for_list_.size() > 0) ordered_for_list_.pop_back();
+  }
+
+  void StopAndAddRecord(const ForNode* for_node, const IfThenElseNode* if_node) {
+    hoist_for_if_recorder = std::make_tuple(true, for_node, if_node);
+    StopRecording();
+  }
+
+  void UpdateAttrVarList(const AttrStmtNode* op) {
+    if (const auto* iv = op->node.as<IterVarNode>()) {
+      attr_var_list_.insert(iv->var.get());
+    } else if (const auto* iv = op->node.as<VarNode>()) {
+      attr_var_list_.insert(iv);
+    }
+  }
+
+  void RemoveAttrVarList(const AttrStmtNode* op) {
+    if (const auto* iv = op->node.as<IterVarNode>()) {
+      attr_var_list_.erase(iv->var.get());
+    } else if (const auto* iv = op->node.as<VarNode>()) {
+      attr_var_list_.erase(iv);
+    }
+  }
+
+  bool CheckAttrVar() {
+    for (auto var : if_var_list_) {
+      if (attr_var_list_.count(var)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  std::vector<const ForNode*> ordered_for_list_;
+
+  std::vector<const VarNode*> if_var_list_;
+
+  std::unordered_set<const VarNode*> attr_var_list_;
+
+  VarForMap var_for_map_;
+
+  bool is_if_cond{false};
+  bool is_recorder_on{false};
+};
+
+class IfThenElseHoister : public StmtMutator {
+ public:
+  IfThenElseHoister() : hoist_selector(HoistCandidateSelector()) {}
+
+  Stmt VisitAndMutate(Stmt stmt) {
+    hoist_selector(stmt);
+    Stmt stmt_copy = std::move(stmt);
+
+    while (hoist_selector.RecordingComplete()) {
+      target_for = hoist_selector.GetTargetForNode();
+      target_if = hoist_selector.GetTargetIfNode();
+
+      stmt_copy = operator()(stmt_copy);
+
+      hoist_selector.ResetRecorder();
+      hoist_selector(stmt_copy);
+    }
+
+    // Support SSA Form
+    stmt_copy = ConvertSSA(stmt_copy);
+    return stmt_copy;
+  }
+
+  Stmt VisitStmt_(const ForNode* op) final {
+    if ((!is_updating) && (target_for == op)) {
+      is_updating = true;
+      is_then_case = true;
+      Stmt then_case = StmtMutator::VisitStmt_(op);

Review comment:
       Sorry, I could not clearly get your point here. The op node here is a `ForNode`; it does not have a `then_case`.
   Would you please help me understand better?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r460096506



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const Object*, const Object*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {

Review comment:
       Yes, you are right: for some cases we don't have to check `IsParentForLoop`
   if we traverse `(i < ordered_for_list_.size())`.
   But I am afraid that logic won't be a generalized one; there will still be cases where we need to check it.
   Above all, it is a fail-safe operation that never allows the hoisting to cross the `if` node's parent `For` node.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] roastduck commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
roastduck commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r459866041



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const Object*, const Object*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);

Review comment:
       Suppose we split a loop `for (i = 0; i < n; i++)` into an inner loop and an outer loop, and bind the outer loop to `threadIdx.x`. It is then common to have a condition like `if (threadIdx.x * tile_size + i.inner < n)`, which contains both the `attr` variable `threadIdx.x` and the loop variable `i.inner`. Can we handle this case?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] MarisaKirisame commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
MarisaKirisame commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r460100741



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const VarNode*, const ForNode*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {
+          StopAndAddRecord(ordered_for_list_[target_for_pos], op);
+        }
+      }
+      if_var_list_.clear();
+      StmtExprVisitor::VisitStmt_(op);
+      StopRecording();
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const VarNode* op) final {
+    if (is_if_cond) {
+      if_var_list_.emplace_back(op);
+    }
+  }
+
+  HoistForIfTuple hoist_for_if_recorder;
+
+  void ResetRecorder() {
+    if (is_recorder_on) {
+      CHECK_GT(ordered_for_list_.size(), 0);
+      is_recorder_on = false;
+    }
+    ordered_for_list_.clear();
+    var_for_map_.clear();
+    hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
+  }
+
+  bool RecordingComplete() {
+    if (std::get<0>(hoist_for_if_recorder)) return true;
+    return false;
+  }
+
+  const ForNode* GetTargetForNode() { return std::get<1>(hoist_for_if_recorder); }
+
+  const IfThenElseNode* GetTargetIfNode() { return std::get<2>(hoist_for_if_recorder); }
+
+ private:
+  bool CheckValidIf() {
+    // If no if var list is present, then all the condition vars are possibly from AttrStmt, so stop
+    // hoisting
+    if (if_var_list_.size() == 0) {
+      return false;
+    }
+    if (CheckAttrVar()) {
+      return false;
+    }
+    return true;
+  }
+
+  bool IsParentForLoop(int loop_pos) {
+    // Check if the loop position is higher than the parent loop position
+    for (auto var : if_var_list_) {
+      if (GetParentLoopPos(var_for_map_[var]) >= loop_pos) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  int GetParentLoopPos(const Object* node) {
+    for (size_t i = 0; i < ordered_for_list_.size(); ++i) {
+      if (ordered_for_list_[i] == node) {
+        return i;
+      }
+    }
+    return -1;
+  }
+
+  void InitRecorder() { hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr); }
+
+  void StopRecording() {
+    if (is_recorder_on) is_recorder_on = false;

Review comment:
       ```suggestion
       is_recorder_on = false;
   ```

##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const VarNode*, const ForNode*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {
+          StopAndAddRecord(ordered_for_list_[target_for_pos], op);
+        }
+      }
+      if_var_list_.clear();
+      StmtExprVisitor::VisitStmt_(op);
+      StopRecording();
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const VarNode* op) final {
+    if (is_if_cond) {
+      if_var_list_.emplace_back(op);
+    }
+  }
+
+  HoistForIfTuple hoist_for_if_recorder;
+
+  void ResetRecorder() {
+    if (is_recorder_on) {
+      CHECK_GT(ordered_for_list_.size(), 0);
+      is_recorder_on = false;
+    }
+    ordered_for_list_.clear();
+    var_for_map_.clear();
+    hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
+  }
+
+  bool RecordingComplete() {
+    if (std::get<0>(hoist_for_if_recorder)) return true;
+    return false;
+  }
+
+  const ForNode* GetTargetForNode() { return std::get<1>(hoist_for_if_recorder); }
+
+  const IfThenElseNode* GetTargetIfNode() { return std::get<2>(hoist_for_if_recorder); }
+
+ private:
+  bool CheckValidIf() {
+    // If no if var list is present, then all the condition vars are possibly from AttrStmt, so stop
+    // hoisting
+    if (if_var_list_.size() == 0) {
+      return false;
+    }
+    if (CheckAttrVar()) {
+      return false;
+    }
+    return true;
+  }
+
+  bool IsParentForLoop(int loop_pos) {
+    // Check if the loop position is higher than the parent loop position
+    for (auto var : if_var_list_) {
+      if (GetParentLoopPos(var_for_map_[var]) >= loop_pos) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  int GetParentLoopPos(const Object* node) {
+    for (size_t i = 0; i < ordered_for_list_.size(); ++i) {
+      if (ordered_for_list_[i] == node) {
+        return i;
+      }
+    }
+    return -1;
+  }
+
+  void InitRecorder() { hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr); }
+
+  void StopRecording() {
+    if (is_recorder_on) is_recorder_on = false;
+  }
+
+  bool IsRecordingOn() { return is_recorder_on; }
+
+  void StartOrAddRecord(const ForNode* op) {
+    if (!is_recorder_on) is_recorder_on = true;
+    if (!var_for_map_.count(op->loop_var.get())) {
+      var_for_map_.insert({op->loop_var.get(), op});
+    }
+    ordered_for_list_.emplace_back(op);
+  }
+
+  void RemoveRecord(const ForNode* op) {
+    StopRecording();
+    var_for_map_.erase(op->loop_var.get());
+    if (ordered_for_list_.size() > 0) ordered_for_list_.pop_back();

Review comment:
       why? imo this should be a CHECK() instead.

##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const VarNode*, const ForNode*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {
+          StopAndAddRecord(ordered_for_list_[target_for_pos], op);
+        }
+      }
+      if_var_list_.clear();
+      StmtExprVisitor::VisitStmt_(op);
+      StopRecording();
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const VarNode* op) final {
+    if (is_if_cond) {
+      if_var_list_.emplace_back(op);
+    }
+  }
+
+  HoistForIfTuple hoist_for_if_recorder;
+
+  void ResetRecorder() {
+    if (is_recorder_on) {
+      CHECK_GT(ordered_for_list_.size(), 0);
+      is_recorder_on = false;
+    }
+    ordered_for_list_.clear();
+    var_for_map_.clear();
+    hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
+  }
+
+  bool RecordingComplete() {
+    if (std::get<0>(hoist_for_if_recorder)) return true;
+    return false;
+  }
+
+  const ForNode* GetTargetForNode() { return std::get<1>(hoist_for_if_recorder); }
+
+  const IfThenElseNode* GetTargetIfNode() { return std::get<2>(hoist_for_if_recorder); }
+
+ private:
+  bool CheckValidIf() {
+    // If no if var list is present, then all the condition vars are possibly from AttrStmt, so stop
+    // hoisting
+    if (if_var_list_.size() == 0) {
+      return false;
+    }
+    if (CheckAttrVar()) {
+      return false;
+    }
+    return true;
+  }
+
+  bool IsParentForLoop(int loop_pos) {
+    // Check if the loop position is higher than the parent loop position
+    for (auto var : if_var_list_) {
+      if (GetParentLoopPos(var_for_map_[var]) >= loop_pos) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  int GetParentLoopPos(const Object* node) {
+    for (size_t i = 0; i < ordered_for_list_.size(); ++i) {
+      if (ordered_for_list_[i] == node) {
+        return i;
+      }
+    }
+    return -1;
+  }
+
+  void InitRecorder() { hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr); }
+
+  void StopRecording() {
+    if (is_recorder_on) is_recorder_on = false;
+  }
+
+  bool IsRecordingOn() { return is_recorder_on; }
+
+  void StartOrAddRecord(const ForNode* op) {
+    if (!is_recorder_on) is_recorder_on = true;

Review comment:
       ```suggestion
       is_recorder_on = true;
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#issuecomment-666365123


   > 0: I don't think a pass has to handle all cases for it to be merged. We can improve upon it incrementally.
   > 1: About loop unswitching, detecting invariants is very hard. But we can simply approximate by detecting exprs that don't have any vars changed in the loop. There is Rice's theorem, so no solution will be perfect, and we should do the least amount of work that gets us good enough performance.
   > 2: You should just build a data structure/helper function such as boundvar/freevar to help. I didn't find dealing with scope too complex in the pass I had written. If you have more specific questions/objections please bring them up.
   
   Thanks @MarisaKirisame for your enlightening response.
   
   I am in total agreement with all your points. :)
   
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] roastduck commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
roastduck commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r459932460



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const Object*, const Object*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);

Review comment:
       I encountered those cases when I was working with sparse kernels. But I agree that we can merge this PR for now and leave it for future improvements.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r460367484



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,374 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const VarNode*, const ForNode*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;

Review comment:
       Thanks @kparzysz-quic for the review!
   You are right. This type exists because my initial design was different. Later I just liked it :)
   Let's keep it!




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] kevinthesun commented on pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
kevinthesun commented on pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#issuecomment-668224314


   @ANSHUMAN87 Thanks for the clarification. Though we might not need to do so in this PR, it would be great if we could bring this in, since AFAIK GPU is the major use case for this pass in TVM.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] roastduck commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
roastduck commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r459866041



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const Object*, const Object*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);

Review comment:
       Suppose we split a loop `for (i = 0; i < n; i++)` into an inner loop and an outer loop, and bind the outer loop to `threadIdx.x`. It is then common to have a condition like `if (threadIdx.x * tile_size + i.inner < n)`, which involves both the `attr` variable `threadIdx.x` and the loop variable `i.inner`.
   
   If this condition is inside another loop, say Loop `j`, it should be hoisted. Sometimes when `n` is constant, this condition can be optimized out in other passes. But if `n` is a variable, for example, defined with `tvm.te.var`, it can't be optimized out. Can we handle this case?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r459959421



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const Object*, const Object*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);

Review comment:
       I will check on sparse kernels then, and I will keep this point open so that we can cover this scenario as well.
   
   The solution is very simple: just move the pass to the end and disable the Attr var list check feature. Unless there is no further optimisation beyond lowering in the scope.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] roastduck commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
roastduck commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r459867555



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const Object*, const Object*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {

Review comment:
       Let me describe my understanding.
   
   In case 1, `n` is 2, where Loops `i`, `j`, `k` are Loops 0, 1, 2, respectively. Both Loop 0 and Loop 1 match, so `match_for_loop_pos == 1`, and `target_for_pos == match_for_loop_pos + 1 == 2`, which is equal to `n`.
   
   In case 2, `n` is 0, where Loop `k` is Loop 0, while Loop `i` or `j` is not recorded. There is no match, so `target_for_pos == 0`, which is still equal to `n`.
   
   What is wrong?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#issuecomment-662572200


   All the CI issues are resolved now. Also, my internal tests show good results. I think we can start the review now. TIA!


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r459859824



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const Object*, const Object*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);

Review comment:
       I found that the passes working on these Attr nodes are distributed. Also, the if nodes are likely to be optimized out in those cases.
   But in reply to your concern: my current implementation does not stop hoisting for all Attr variables; it only considers the required ones. It handles if conditions with mixed cases like `if(i + global_var ==  3)`. So I think your actual concern is addressed.
   
   But if you have any specific case that the current logic does not handle, please let me know. I can check on it. Thanks!




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] tqchen edited a comment on pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
tqchen edited a comment on pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#issuecomment-663669923


   cc @ZihengJiang @merrymercy @Hzfengsy @kevinthesun @junrushao1994 @spectrometerHBH  @wpan11nv @kparzysz-quic please help to take  a look


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] kevinthesun edited a comment on pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
kevinthesun edited a comment on pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#issuecomment-667357815


   One thing I think it is good to have in this PR is to get some benchmark data. IMO this pass is especially valuable when tackling dynamic kernels in GPU which introduces a lot of branchings. It's great to see how much performance improvement we can have. For CPU, we need to make sure this pass doesn't introduce regression for some common workloads, such as resnet.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#issuecomment-667549552


   > One thing I think it is good to have in this PR is to get some benchmark data, since we now enable this pass by default. IMO this pass is especially valuable when tackling dynamic kernels in GPU which introduces a lot of branchings. It's great to see how much performance improvement we can have. For CPU, we need to make sure this pass doesn't introduce regression for some common workloads, such as resnet.
   
   @kevinthesun : I have verified the inference time for Resnet50 on CPU. There is no performance impact. In fact, I did not find any hoisting candidates.
   
   ### Hoisting Disabled :
   ![image](https://user-images.githubusercontent.com/32511895/89104943-9dde3a80-d43a-11ea-994d-85e1468c8241.png)
   
   
   ### Hoisting Enabled:
   ![image](https://user-images.githubusercontent.com/32511895/89104952-b2223780-d43a-11ea-8f92-e7fb4e555115.png)
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r460368833



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,374 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const VarNode*, const ForNode*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {
+          StopAndAddRecord(ordered_for_list_[target_for_pos], op);
+        }
+      }
+      if_var_list_.clear();
+      StmtExprVisitor::VisitStmt_(op);
+      StopRecording();
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const VarNode* op) final {
+    if (is_if_cond) {
+      if_var_list_.emplace_back(op);
+    }
+  }
+
+  HoistForIfTuple hoist_for_if_recorder;
+
+  void ResetRecorder() {
+    if (is_recorder_on) {
+      CHECK_GT(ordered_for_list_.size(), 0);
+      is_recorder_on = false;
+    }
+    ordered_for_list_.clear();
+    var_for_map_.clear();
+    hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
+  }
+
+  bool RecordingComplete() {
+    if (std::get<0>(hoist_for_if_recorder)) return true;

Review comment:
       Thanks for catching this!
   All of these if/else statements are broken down this way because I had to add logs while debugging initially :)




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] roastduck commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
roastduck commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r460019969



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const Object*, const Object*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {

Review comment:
       Let me make it explicit.
   
   ```
   if_var_list_ = [i, k]
   ordered_for_list_ = [for i, for j, for k]
   Iter 1: var == i, i == 0. (ordered_for_list_[i] == var_for_map_[var]) == true; match_for_loop_pos = 0; break;
   Iter 2: var == k, i == 0. (ordered_for_list_[i] == var_for_map_[var]) == false;
   Iter 3: var == k, i == 1. (ordered_for_list_[i] == var_for_map_[var]) == false;
   Iter 4: var == k, i == 2. (ordered_for_list_[i] == var_for_map_[var]) == true; match_for_loop_pos = 2; break;
   target_for_pos = match_for_loop_pos + 1 = 3
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#issuecomment-665439076


   Gentle ping @tqchen , @ZihengJiang @merrymercy @Hzfengsy @kevinthesun @junrushao1994 @spectrometerHBH @wpan11nv @kparzysz-quic !!!
   
   Let us discuss and reach a conclusion on the open points / challenges mentioned in my previous comment. TIA!


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#issuecomment-660693140


   > Since nearly all programs will be touched by this pass, potential bugs in this pass would be critical. Could you (maybe temporarily) add this pass into the default building procedure and run all the tests? It would greatly reduce potential bugs.
   
   @roastduck : Thanks a lot for your input! 
   I will definitely make sure your point is covered during internal testing.
   Currently I am working on resolving all the errors reported by the CI test suites. I will post an update once all the issues are resolved.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r459506386



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const Object*, const Object*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {

Review comment:
       I am sorry for the confusion!
   We can definitely add more comments to it. Let me first explain the logic, and then we can work out the necessary comments for it together.
   
   The above condition will be hit in 2 cases:
   Case 1 (partial match found):
   Case 2 (no match found):




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r460069525



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const Object*, const Object*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {

Review comment:
       Ah! I understand now where the confusion comes from. Sorry if my explanation was poor earlier.
   
   Actually, in your assumption above,
   it won't iterate up to step 3; it executes only up to step 2.
   
   This is because the for-loop condition is
   (i < ordered_for_list_.size() - 1)




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] MarisaKirisame merged pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
MarisaKirisame merged pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066


   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r459902322



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const Object*, const Object*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);

Review comment:
       Thanks for your detailed reply! Now I understand your point clearly.
   Hypothetically, if we have any such expression, then yes, it doesn't do hoisting.
   
   But I think these cases won't occur, because loop splitting and fusion generate related blocks in a bound scope.
   So within this scope we should not change the order of the if and for nodes, since the optimization technique depends solely on it.
   I am not sure when the condition `If this condition is inside another loop, say Loop j, it should be hoisted.`
   occurs. Maybe when we do out-of-order fusion, I think, which is not a real-time use case.
   
   However, in case we really want to support this, maybe we can apply this hoisting logic only after all passes, when all optimizations are done, with a special config that scans only for such cases; or we can always apply hoisting after all the passes. I am not too sure about the benefit, so maybe we can ask for others' opinions too, just to make sure everyone is in sync. :)




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] roastduck commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
roastduck commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r460083344



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const Object*, const Object*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {

Review comment:
       Oh, got it, but why? Why not change it to `(i < ordered_for_list_.size())` so we don't need to check `IsParentForLoop`?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] tqchen commented on pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#issuecomment-663669923


   cc @ZihengJiang @merrymercy @Hzfengsy @kevinthesun @junrushao1994 @spectrometerHBH  please help to take  a look


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] MarisaKirisame commented on pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
MarisaKirisame commented on pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#issuecomment-668234871


   Thanks @ANSHUMAN87 @roastduck @kevinthesun @kparzysz-quic @hzfan 


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#issuecomment-663817727


   > This is a limited case of "loop unswitching". Please consider a more general solution, where
   > 
   > ```
   > for (i = 0..N) {
   >   // statement A
   >   if (invariant-condition) {
   >     // statement B
   >   } else {
   >     // statement C
   >   }
   >   // statement D
   > }
   > ```
   > 
   > is transformed into
   > 
   > ```
   > if (invariant-condition) {
   >   for (i = 0..N) {
   >     // statement A
   >     // statement B
   >     // statement D
   >   }
   > } else {
   >   for (i = 0..N) {
   >     // statement A
   >     // statement C
   >     // statement D
   >   }
   > }
   > ```
   > 
   > Using the same logic you could unswitch attribute statements, if needed.
   
   Thanks for bringing up this point!
   In fact, this is at the top of my TODO list once the current PR is merged.
   Here it is not a question of a general solution; it is about covering more scenarios.
   The current changes cover most of the real-time scenarios (excluding hypothetical scenarios).
   This scenario is intentionally left uncovered for now because of the following challenges.
   
   1. Invariant identification: What is the best way to determine that the variables in the if condition are not updated inside the block?
        For example, in the case below, how can we find out whether `var n` is invariant in the loop block:
   ```
       var n = 3
       for (i: int32, 0, n: int32) {
         if @tir.likely((n > 23), dtype=bool) {
           n = n + 2      
           data[i] = i
         } else {
           n = n + 3
           data[i] = i + 1
         }
       }
   
   ```
   
   2. Scope identification: We need to scan each and every statement to do this, which makes the logic significantly more complex.
   
   Maybe we can discuss the first challenge further. Please suggest if anyone has an optimal solution for it.
   Let me know in case the scenario is not clear. TIA!
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r459955697



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const Object*, const Object*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {

Review comment:
       The loop above runs only n-2 times; in this case that is 1 time, which is why the value of target_for_pos is 1.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] tqchen edited a comment on pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
tqchen edited a comment on pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#issuecomment-667180453






----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#issuecomment-665651723


   > Gentle ping @tqchen , @ZihengJiang @merrymercy @Hzfengsy @kevinthesun @junrushao1994 @spectrometerHBH @wpan11nv @kparzysz-quic !!!
   > 
   > Let us discuss and bring a conclusion to the open points / challenges mentioned in my previous comment. TIA!
   
   also cc @MarisaKirisame !!! 
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r460367771



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,374 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const VarNode*, const ForNode*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {
+          StopAndAddRecord(ordered_for_list_[target_for_pos], op);
+        }
+      }
+      if_var_list_.clear();
+      StmtExprVisitor::VisitStmt_(op);
+      StopRecording();
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const VarNode* op) final {
+    if (is_if_cond) {
+      if_var_list_.emplace_back(op);
+    }
+  }
+
+  HoistForIfTuple hoist_for_if_recorder;
+
+  void ResetRecorder() {
+    if (is_recorder_on) {
+      CHECK_GT(ordered_for_list_.size(), 0);
+      is_recorder_on = false;
+    }
+    ordered_for_list_.clear();
+    var_for_map_.clear();
+    hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
+  }
+
+  bool RecordingComplete() {
+    if (std::get<0>(hoist_for_if_recorder)) return true;
+    return false;
+  }
+
+  const ForNode* GetTargetForNode() { return std::get<1>(hoist_for_if_recorder); }
+
+  const IfThenElseNode* GetTargetIfNode() { return std::get<2>(hoist_for_if_recorder); }
+
+ private:
+  bool CheckValidIf() {
+    // If no if var list is present, then all the condition vars are possibly from AttrStmt, so stop
+    // hoisting
+    if (if_var_list_.size() == 0) {
+      return false;
+    }
+    if (CheckAttrVar()) {
+      return false;
+    }
+    return true;
+  }
+
+  bool IsParentForLoop(int loop_pos) {
+    // Check if the loop position is higher than the parent loop position
+    for (auto var : if_var_list_) {
+      if (GetParentLoopPos(var_for_map_[var]) >= loop_pos) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  int GetParentLoopPos(const Object* node) {
+    for (size_t i = 0; i < ordered_for_list_.size(); ++i) {
+      if (ordered_for_list_[i] == node) {
+        return i;
+      }
+    }
+    return -1;
+  }
+
+  void InitRecorder() { hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr); }
+
+  void StopRecording() { is_recorder_on = false; }
+
+  bool IsRecordingOn() { return is_recorder_on; }
+
+  void StartOrAddRecord(const ForNode* op) {
+    is_recorder_on = true;
+    if (!var_for_map_.count(op->loop_var.get())) {
+      var_for_map_.insert({op->loop_var.get(), op});
+    }
+    ordered_for_list_.emplace_back(op);
+  }
+
+  void RemoveRecord(const ForNode* op) {
+    StopRecording();
+    var_for_map_.erase(op->loop_var.get());
+    if (ordered_for_list_.size() > 0) ordered_for_list_.pop_back();
+  }
+
+  void StopAndAddRecord(const ForNode* for_node, const IfThenElseNode* if_node) {
+    hoist_for_if_recorder = std::make_tuple(true, for_node, if_node);
+    StopRecording();
+  }
+
+  void UpdateAttrVarList(const AttrStmtNode* op) {
+    if (const auto* iv = op->node.as<IterVarNode>()) {
+      attr_var_list_.insert(iv->var.get());
+    } else if (const auto* iv = op->node.as<VarNode>()) {
+      attr_var_list_.insert(iv);
+    }
+  }
+
+  void RemoveAttrVarList(const AttrStmtNode* op) {
+    if (const auto* iv = op->node.as<IterVarNode>()) {
+      attr_var_list_.erase(iv->var.get());
+    } else if (const auto* iv = op->node.as<VarNode>()) {
+      attr_var_list_.erase(iv);
+    }
+  }
+
+  bool CheckAttrVar() {
+    for (auto var : if_var_list_) {
+      if (attr_var_list_.count(var)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  std::vector<const ForNode*> ordered_for_list_;
+
+  std::vector<const VarNode*> if_var_list_;
+
+  std::unordered_set<const VarNode*> attr_var_list_;
+
+  VarForMap var_for_map_;
+
+  bool is_if_cond{false};
+  bool is_recorder_on{false};

Review comment:
       Thanks @Hzfengsy for the review!
   I totally agree with your comment. I will definitely handle it.
   It is just that, throughout the TVM code base, this kind of naming convention is followed inconsistently.
   Sometimes I get confused about which one to follow :)




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r459506386



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const Object*, const Object*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {

Review comment:
       I am sorry for the confusion!
   We can definitely add more comments to it. Let me first explain the logic, then we can work out the necessary comment for it together.
   
   Notice that the loop above it, which searches for the for loop matching the if-condition variables, runs only (n - 1) times (when there are n candidate loops in the recorder).
   
   The above check is hit in 2 cases:
   case 1 (Partial match found), like below:
   ```
   for (i: int32, 0, l: int32) {
     for (j: int32, 0, m: int32) {
       if @tir.likely(((i + j) < 2), dtype=bool) {
         for (k: int32, 0, n: int32) {
           m
         }
       } else {
         for (k, 0, n) {
           n
         }
       }
     }
   }
   ```
   case 2 (No match found), like below:
   ```
   for (i: int32, 0, l: int32) {
     for (j: int32, 0, m: int32) {
       data: handle[((i*3) + j)] = ((float32*)data[((i*3) + j)] + 0.5f32)
       for (k: int32, 0, n: int32) {
         if @tir.likely((i < 2), dtype=bool) {
           m
         } else {
           n
         }
       }
     }
   }
   ```
   
   I hope that is clear. Please let me know if you have any questions.
   Could you please suggest what comment we should put there, so that it helps others understand it better? TIA!




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] MarisaKirisame commented on pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
MarisaKirisame commented on pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#issuecomment-665762811


   0: I don't think a pass has to handle all cases for it to be merged. We can improve upon it incrementally.
   1: About loop unswitching, detecting invariants is very hard. But we can simply approximate by detecting exprs that don't have any vars changed in the loop. Because of Rice's theorem, no solution will be perfect, so we should do the least amount of work that gets us good enough performance.
   2: You should just build a data structure/helper function such as boundvar/freevar to help. I didn't find dealing with scope too complex in the pass I wrote. If you have a more specific question/objection, please bring it up.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r460128717



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const VarNode*, const ForNode*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {
+          StopAndAddRecord(ordered_for_list_[target_for_pos], op);
+        }
+      }
+      if_var_list_.clear();
+      StmtExprVisitor::VisitStmt_(op);
+      StopRecording();
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const VarNode* op) final {
+    if (is_if_cond) {
+      if_var_list_.emplace_back(op);
+    }
+  }
+
+  HoistForIfTuple hoist_for_if_recorder;
+
+  void ResetRecorder() {
+    if (is_recorder_on) {
+      CHECK_GT(ordered_for_list_.size(), 0);
+      is_recorder_on = false;
+    }
+    ordered_for_list_.clear();
+    var_for_map_.clear();
+    hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
+  }
+
+  bool RecordingComplete() {
+    if (std::get<0>(hoist_for_if_recorder)) return true;
+    return false;
+  }
+
+  const ForNode* GetTargetForNode() { return std::get<1>(hoist_for_if_recorder); }
+
+  const IfThenElseNode* GetTargetIfNode() { return std::get<2>(hoist_for_if_recorder); }
+
+ private:
+  bool CheckValidIf() {
+    // If no if var list is present, then all the condition vars are possibly from AttrStmt, so stop
+    // hoisting
+    if (if_var_list_.size() == 0) {
+      return false;
+    }
+    if (CheckAttrVar()) {
+      return false;
+    }
+    return true;
+  }
+
+  bool IsParentForLoop(int loop_pos) {
+    // Check if the loop position is higher than the parent loop position
+    for (auto var : if_var_list_) {
+      if (GetParentLoopPos(var_for_map_[var]) >= loop_pos) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  int GetParentLoopPos(const Object* node) {
+    for (size_t i = 0; i < ordered_for_list_.size(); ++i) {
+      if (ordered_for_list_[i] == node) {
+        return i;
+      }
+    }
+    return -1;
+  }
+
+  void InitRecorder() { hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr); }
+
+  void StopRecording() {
+    if (is_recorder_on) is_recorder_on = false;
+  }
+
+  bool IsRecordingOn() { return is_recorder_on; }
+
+  void StartOrAddRecord(const ForNode* op) {
+    if (!is_recorder_on) is_recorder_on = true;
+    if (!var_for_map_.count(op->loop_var.get())) {
+      var_for_map_.insert({op->loop_var.get(), op});
+    }
+    ordered_for_list_.emplace_back(op);
+  }
+
+  void RemoveRecord(const ForNode* op) {
+    StopRecording();
+    var_for_map_.erase(op->loop_var.get());
+    if (ordered_for_list_.size() > 0) ordered_for_list_.pop_back();

Review comment:
       Yes, you are right. `StartOrAddRecord()` and `RemoveRecord()` should be called in pairs in the normal scenario.
   But when we hit a case where we want to purge all recordings and start all over again (while still inside the call stack), the call sequence will be `StartOrAddRecord() -> ResetRecorder() -> RemoveRecord()`,
   for example in the case of `ForNode -> SeqStmtNode -> ForNode`.
   
   




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] kparzysz-quic commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
kparzysz-quic commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r460310403



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,374 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const VarNode*, const ForNode*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {

Review comment:
       Put the short path under the if-statement, to avoid unnecessary indentation, i.e.
   ```
   if (RecordingComplete()) {
     StmtExprVisitor::VisitStmt_(op);
     return;
   }
   // rest of code
   ```

##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,374 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const VarNode*, const ForNode*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {

Review comment:
       Especially here, move the `VisitStmt` from [line 173](https://github.com/apache/incubator-tvm/pull/6066/files#diff-1979d9c6d4050100d8c34e00b032c1c7R173) under the if, and unindent everything else.

##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,374 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const VarNode*, const ForNode*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {
+          StopAndAddRecord(ordered_for_list_[target_for_pos], op);
+        }
+      }
+      if_var_list_.clear();
+      StmtExprVisitor::VisitStmt_(op);
+      StopRecording();
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const VarNode* op) final {
+    if (is_if_cond) {
+      if_var_list_.emplace_back(op);
+    }
+  }
+
+  HoistForIfTuple hoist_for_if_recorder;
+
+  void ResetRecorder() {
+    if (is_recorder_on) {
+      CHECK_GT(ordered_for_list_.size(), 0);
+      is_recorder_on = false;
+    }
+    ordered_for_list_.clear();
+    var_for_map_.clear();
+    hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
+  }
+
+  bool RecordingComplete() {
+    if (std::get<0>(hoist_for_if_recorder)) return true;
+    return false;
+  }
+
+  const ForNode* GetTargetForNode() { return std::get<1>(hoist_for_if_recorder); }
+
+  const IfThenElseNode* GetTargetIfNode() { return std::get<2>(hoist_for_if_recorder); }
+
+ private:
+  bool CheckValidIf() {
+    // If no if var list is present, then all the condition vars are possibly from AttrStmt, so stop
+    // hoisting
+    if (if_var_list_.size() == 0) {
+      return false;
+    }
+    if (CheckAttrVar()) {
+      return false;
+    }
+    return true;
+  }
+
+  bool IsParentForLoop(int loop_pos) {
+    // Check if the loop position is higher than the parent loop position
+    for (auto var : if_var_list_) {
+      if (GetParentLoopPos(var_for_map_[var]) >= loop_pos) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  int GetParentLoopPos(const Object* node) {
+    for (size_t i = 0; i < ordered_for_list_.size(); ++i) {
+      if (ordered_for_list_[i] == node) {
+        return i;
+      }
+    }
+    return -1;
+  }
+
+  void InitRecorder() { hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr); }
+
+  void StopRecording() { is_recorder_on = false; }
+
+  bool IsRecordingOn() { return is_recorder_on; }
+
+  void StartOrAddRecord(const ForNode* op) {
+    is_recorder_on = true;
+    if (!var_for_map_.count(op->loop_var.get())) {
+      var_for_map_.insert({op->loop_var.get(), op});
+    }
+    ordered_for_list_.emplace_back(op);
+  }
+
+  void RemoveRecord(const ForNode* op) {
+    StopRecording();
+    var_for_map_.erase(op->loop_var.get());
+    if (ordered_for_list_.size() > 0) ordered_for_list_.pop_back();
+  }
+
+  void StopAndAddRecord(const ForNode* for_node, const IfThenElseNode* if_node) {
+    hoist_for_if_recorder = std::make_tuple(true, for_node, if_node);
+    StopRecording();
+  }
+
+  void UpdateAttrVarList(const AttrStmtNode* op) {
+    if (const auto* iv = op->node.as<IterVarNode>()) {
+      attr_var_list_.insert(iv->var.get());
+    } else if (const auto* iv = op->node.as<VarNode>()) {
+      attr_var_list_.insert(iv);
+    }
+  }
+
+  void RemoveAttrVarList(const AttrStmtNode* op) {
+    if (const auto* iv = op->node.as<IterVarNode>()) {
+      attr_var_list_.erase(iv->var.get());
+    } else if (const auto* iv = op->node.as<VarNode>()) {
+      attr_var_list_.erase(iv);
+    }
+  }
+
+  bool CheckAttrVar() {
+    for (auto var : if_var_list_) {
+      if (attr_var_list_.count(var)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  std::vector<const ForNode*> ordered_for_list_;
+
+  std::vector<const VarNode*> if_var_list_;
+
+  std::unordered_set<const VarNode*> attr_var_list_;
+
+  VarForMap var_for_map_;
+
+  bool is_if_cond{false};
+  bool is_recorder_on{false};
+};
+
+class IfThenElseHoister : public StmtMutator {
+ public:
+  IfThenElseHoister() : hoist_selector(HoistCandidateSelector()) {}
+
+  Stmt VisitAndMutate(Stmt stmt) {
+    hoist_selector(stmt);
+    Stmt stmt_copy = std::move(stmt);
+
+    while (hoist_selector.RecordingComplete()) {
+      target_for = hoist_selector.GetTargetForNode();
+      target_if = hoist_selector.GetTargetIfNode();
+
+      stmt_copy = operator()(stmt_copy);
+
+      hoist_selector.ResetRecorder();
+      hoist_selector(stmt_copy);
+    }
+
+    // Support SSA Form
+    stmt_copy = ConvertSSA(stmt_copy);
+    return stmt_copy;
+  }
+
+  Stmt VisitStmt_(const ForNode* op) final {
+    if ((!is_updating) && (target_for == op)) {
+      is_updating = true;
+      is_then_case = true;
+      Stmt then_case = StmtMutator::VisitStmt_(op);
+      is_then_case = false;
+      Stmt else_case = Stmt();
+      if (target_if->else_case.defined()) {
+        else_case = StmtMutator::VisitStmt_(op);

Review comment:
       `VisitStmt(op->else_case)`
   
   Then you can drop the `is_then_case` variable.

##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,374 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const VarNode*, const ForNode*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {
+          StopAndAddRecord(ordered_for_list_[target_for_pos], op);
+        }
+      }
+      if_var_list_.clear();
+      StmtExprVisitor::VisitStmt_(op);
+      StopRecording();
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const VarNode* op) final {
+    if (is_if_cond) {
+      if_var_list_.emplace_back(op);
+    }
+  }
+
+  HoistForIfTuple hoist_for_if_recorder;
+
+  void ResetRecorder() {
+    if (is_recorder_on) {
+      CHECK_GT(ordered_for_list_.size(), 0);
+      is_recorder_on = false;
+    }
+    ordered_for_list_.clear();
+    var_for_map_.clear();
+    hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
+  }
+
+  bool RecordingComplete() {
+    if (std::get<0>(hoist_for_if_recorder)) return true;
+    return false;
+  }
+
+  const ForNode* GetTargetForNode() { return std::get<1>(hoist_for_if_recorder); }
+
+  const IfThenElseNode* GetTargetIfNode() { return std::get<2>(hoist_for_if_recorder); }
+
+ private:
+  bool CheckValidIf() {
+    // If no if var list is present, then all the condition vars are possibly from AttrStmt, so stop
+    // hoisting
+    if (if_var_list_.size() == 0) {
+      return false;
+    }
+    if (CheckAttrVar()) {
+      return false;
+    }
+    return true;
+  }
+
+  bool IsParentForLoop(int loop_pos) {
+    // Check if the loop position is higher than the parent loop position
+    for (auto var : if_var_list_) {
+      if (GetParentLoopPos(var_for_map_[var]) >= loop_pos) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  int GetParentLoopPos(const Object* node) {
+    for (size_t i = 0; i < ordered_for_list_.size(); ++i) {
+      if (ordered_for_list_[i] == node) {
+        return i;
+      }
+    }
+    return -1;
+  }
+
+  void InitRecorder() { hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr); }
+
+  void StopRecording() { is_recorder_on = false; }
+
+  bool IsRecordingOn() { return is_recorder_on; }
+
+  void StartOrAddRecord(const ForNode* op) {
+    is_recorder_on = true;
+    if (!var_for_map_.count(op->loop_var.get())) {
+      var_for_map_.insert({op->loop_var.get(), op});
+    }
+    ordered_for_list_.emplace_back(op);
+  }
+
+  void RemoveRecord(const ForNode* op) {
+    StopRecording();
+    var_for_map_.erase(op->loop_var.get());
+    if (ordered_for_list_.size() > 0) ordered_for_list_.pop_back();
+  }
+
+  void StopAndAddRecord(const ForNode* for_node, const IfThenElseNode* if_node) {
+    hoist_for_if_recorder = std::make_tuple(true, for_node, if_node);
+    StopRecording();
+  }
+
+  void UpdateAttrVarList(const AttrStmtNode* op) {
+    if (const auto* iv = op->node.as<IterVarNode>()) {
+      attr_var_list_.insert(iv->var.get());
+    } else if (const auto* iv = op->node.as<VarNode>()) {
+      attr_var_list_.insert(iv);
+    }
+  }
+
+  void RemoveAttrVarList(const AttrStmtNode* op) {
+    if (const auto* iv = op->node.as<IterVarNode>()) {
+      attr_var_list_.erase(iv->var.get());
+    } else if (const auto* iv = op->node.as<VarNode>()) {
+      attr_var_list_.erase(iv);
+    }
+  }
+
+  bool CheckAttrVar() {
+    for (auto var : if_var_list_) {
+      if (attr_var_list_.count(var)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  std::vector<const ForNode*> ordered_for_list_;
+
+  std::vector<const VarNode*> if_var_list_;
+
+  std::unordered_set<const VarNode*> attr_var_list_;
+
+  VarForMap var_for_map_;
+
+  bool is_if_cond{false};
+  bool is_recorder_on{false};
+};
+
+class IfThenElseHoister : public StmtMutator {
+ public:
+  IfThenElseHoister() : hoist_selector(HoistCandidateSelector()) {}
+
+  Stmt VisitAndMutate(Stmt stmt) {
+    hoist_selector(stmt);
+    Stmt stmt_copy = std::move(stmt);
+
+    while (hoist_selector.RecordingComplete()) {
+      target_for = hoist_selector.GetTargetForNode();
+      target_if = hoist_selector.GetTargetIfNode();
+
+      stmt_copy = operator()(stmt_copy);
+
+      hoist_selector.ResetRecorder();
+      hoist_selector(stmt_copy);
+    }
+
+    // Support SSA Form
+    stmt_copy = ConvertSSA(stmt_copy);
+    return stmt_copy;
+  }
+
+  Stmt VisitStmt_(const ForNode* op) final {
+    if ((!is_updating) && (target_for == op)) {
+      is_updating = true;
+      is_then_case = true;
+      Stmt then_case = StmtMutator::VisitStmt_(op);

Review comment:
       `VisitStmt(op->then_case)`

##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,374 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const VarNode*, const ForNode*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;

Review comment:
       You don't really need this type, since only one object of it is ever created.  You can just declare three separate member variables inside the class instead.

##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,374 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const VarNode*, const ForNode*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {
+          StopAndAddRecord(ordered_for_list_[target_for_pos], op);
+        }
+      }
+      if_var_list_.clear();
+      StmtExprVisitor::VisitStmt_(op);
+      StopRecording();
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const VarNode* op) final {
+    if (is_if_cond) {
+      if_var_list_.emplace_back(op);
+    }
+  }
+
+  HoistForIfTuple hoist_for_if_recorder;
+
+  void ResetRecorder() {
+    if (is_recorder_on) {
+      CHECK_GT(ordered_for_list_.size(), 0);
+      is_recorder_on = false;
+    }
+    ordered_for_list_.clear();
+    var_for_map_.clear();
+    hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
+  }
+
+  bool RecordingComplete() {
+    if (std::get<0>(hoist_for_if_recorder)) return true;
+    return false;
+  }
+
+  const ForNode* GetTargetForNode() { return std::get<1>(hoist_for_if_recorder); }
+
+  const IfThenElseNode* GetTargetIfNode() { return std::get<2>(hoist_for_if_recorder); }
+
+ private:
+  bool CheckValidIf() {
+    // If no if var list is present, then all the condition vars are possibly from AttrStmt, so stop
+    // hoisting
+    if (if_var_list_.size() == 0) {
+      return false;
+    }
+    if (CheckAttrVar()) {
+      return false;
+    }
+    return true;

Review comment:
       You can replace this whole function with
   ```
   return !if_var_list_.empty() && !CheckAttrVar();
   ```

##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,374 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const VarNode*, const ForNode*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {
+          StopAndAddRecord(ordered_for_list_[target_for_pos], op);
+        }
+      }
+      if_var_list_.clear();
+      StmtExprVisitor::VisitStmt_(op);
+      StopRecording();
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const VarNode* op) final {
+    if (is_if_cond) {
+      if_var_list_.emplace_back(op);
+    }
+  }
+
+  HoistForIfTuple hoist_for_if_recorder;
+
+  void ResetRecorder() {
+    if (is_recorder_on) {
+      CHECK_GT(ordered_for_list_.size(), 0);
+      is_recorder_on = false;
+    }
+    ordered_for_list_.clear();
+    var_for_map_.clear();
+    hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
+  }
+
+  bool RecordingComplete() {
+    if (std::get<0>(hoist_for_if_recorder)) return true;
+    return false;
+  }
+
+  const ForNode* GetTargetForNode() { return std::get<1>(hoist_for_if_recorder); }
+
+  const IfThenElseNode* GetTargetIfNode() { return std::get<2>(hoist_for_if_recorder); }
+
+ private:
+  bool CheckValidIf() {
+    // If no if var list is present, then all the condition vars are possibly from AttrStmt, so stop
+    // hoisting
+    if (if_var_list_.size() == 0) {
+      return false;
+    }
+    if (CheckAttrVar()) {
+      return false;
+    }
+    return true;
+  }
+
+  bool IsParentForLoop(int loop_pos) {
+    // Check if the loop position is higher than the parent loop position
+    for (auto var : if_var_list_) {
+      if (GetParentLoopPos(var_for_map_[var]) >= loop_pos) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  int GetParentLoopPos(const Object* node) {
+    for (size_t i = 0; i < ordered_for_list_.size(); ++i) {
+      if (ordered_for_list_[i] == node) {
+        return i;
+      }
+    }
+    return -1;
+  }
+
+  void InitRecorder() { hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr); }
+
+  void StopRecording() { is_recorder_on = false; }
+
+  bool IsRecordingOn() { return is_recorder_on; }

Review comment:
       Neither `StopRecording` nor `IsRecordingOn` really adds any value; you can use the `is_recorder_on` variable directly.

##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,374 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const VarNode*, const ForNode*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {
+          StopAndAddRecord(ordered_for_list_[target_for_pos], op);
+        }
+      }
+      if_var_list_.clear();
+      StmtExprVisitor::VisitStmt_(op);
+      StopRecording();
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const VarNode* op) final {
+    if (is_if_cond) {
+      if_var_list_.emplace_back(op);
+    }
+  }
+
+  HoistForIfTuple hoist_for_if_recorder;
+
+  void ResetRecorder() {
+    if (is_recorder_on) {
+      CHECK_GT(ordered_for_list_.size(), 0);
+      is_recorder_on = false;
+    }
+    ordered_for_list_.clear();
+    var_for_map_.clear();
+    hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
+  }
+
+  bool RecordingComplete() {
+    if (std::get<0>(hoist_for_if_recorder)) return true;

Review comment:
       `return std::get<0>(hoist_for_if_recorder);`




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] kparzysz-quic commented on pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
kparzysz-quic commented on pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#issuecomment-663767437


   This is a limited case of "loop unswitching".  Please consider a more general solution, where
   ```
   for (i = 0..N) {
     // statement A
     if (invariant-condition) {
       // statement B
     } else {
       // statement C
     }
     // statement D
   }
   ```
   is transformed into
   ```
   if (invariant-condition) {
     for (i = 0..N) {
       // statement A
       // statement B
       // statement D
     }
   } else {
     for (i = 0..N) {
       // statement A
       // statement C
       // statement D
     }
   }
   ```
   
   Using the same logic you could unswitch attribute statements, if needed.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r459511992



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const Object*, const Object*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);

Review comment:
       Thanks for bringing this up!
   In fact, I also had the same idea initially, and it was working perfectly fine.
   But while investigating CI failures I discovered that there is some dependency on the positioning of these If statements relative to global scope variables, so I had to add this logic to avoid hoisting in any such case.
   But as I mentioned in the code comment as well, if you have any specific case where hoisting should be enabled, we can add it, provided it does not violate the logic of other passes.
   
   This is the [link](https://github.com/apache/incubator-tvm/blob/06d756563d805e8a12dac84c6372071f35457e4f/tests/python/unittest/test_te_schedule_ops.py#L531) for the test case failure, when this logic was absent.
   
   Please let me know your thoughts on this.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r459853815



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const Object*, const Object*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {

Review comment:
       Sorry, the `target_for_pos` values will be different in the two cases.
   For a partial match or an exact match, it can be in [0, n-1], and if no match is found it will be 0. The common logic is there to accommodate all of these scenarios :)
   




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] kevinthesun commented on pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
kevinthesun commented on pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#issuecomment-667356063


   I'll take a look in the next few days.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r460367606



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,374 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const VarNode*, const ForNode*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {
+          StopAndAddRecord(ordered_for_list_[target_for_pos], op);
+        }
+      }
+      if_var_list_.clear();
+      StmtExprVisitor::VisitStmt_(op);
+      StopRecording();
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const VarNode* op) final {
+    if (is_if_cond) {
+      if_var_list_.emplace_back(op);
+    }
+  }
+
+  HoistForIfTuple hoist_for_if_recorder;
+
+  void ResetRecorder() {
+    if (is_recorder_on) {
+      CHECK_GT(ordered_for_list_.size(), 0);
+      is_recorder_on = false;
+    }
+    ordered_for_list_.clear();
+    var_for_map_.clear();
+    hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
+  }
+
+  bool RecordingComplete() {
+    if (std::get<0>(hoist_for_if_recorder)) return true;
+    return false;
+  }
+
+  const ForNode* GetTargetForNode() { return std::get<1>(hoist_for_if_recorder); }
+
+  const IfThenElseNode* GetTargetIfNode() { return std::get<2>(hoist_for_if_recorder); }
+
+ private:
+  bool CheckValidIf() {
+    // If no if var list is present, then all the condition vars are possibly from AttrStmt, so stop
+    // hoisting
+    if (if_var_list_.size() == 0) {
+      return false;
+    }
+    if (CheckAttrVar()) {
+      return false;
+    }
+    return true;
+  }
+
+  bool IsParentForLoop(int loop_pos) {
+    // Check if the loop position is higher than the parent loop position
+    for (auto var : if_var_list_) {
+      if (GetParentLoopPos(var_for_map_[var]) >= loop_pos) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  int GetParentLoopPos(const Object* node) {
+    for (size_t i = 0; i < ordered_for_list_.size(); ++i) {
+      if (ordered_for_list_[i] == node) {
+        return i;
+      }
+    }
+    return -1;
+  }
+
+  void InitRecorder() { hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr); }
+
+  void StopRecording() { is_recorder_on = false; }
+
+  bool IsRecordingOn() { return is_recorder_on; }

Review comment:
       Yes, you are right! But I prefer the function, as it is self-explanatory and easy to extend. Let's keep it.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] kevinthesun edited a comment on pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
kevinthesun edited a comment on pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#issuecomment-667357815


   One thing I think would be good to have in this PR is some benchmark data, since we now enable this pass by default. IMO this pass is especially valuable when tackling dynamic kernels on GPU, which introduce a lot of branching. It would be great to see how much performance improvement we can get. For CPU, we need to make sure this pass doesn't introduce regressions for some common workloads, such as ResNet.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#issuecomment-666461607


   I think we can go ahead and merge the PR, as most of the scenarios are handled.
   I will keep loop unswitching as an open point; once it is handled, I will raise a new PR for it.
   And we can keep adding new useful scenarios as and when someone reports them.
   
   @tqchen: Would you please share your opinion on this? TIA!


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] MarisaKirisame commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
MarisaKirisame commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r460117369



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const VarNode*, const ForNode*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {
+          StopAndAddRecord(ordered_for_list_[target_for_pos], op);
+        }
+      }
+      if_var_list_.clear();
+      StmtExprVisitor::VisitStmt_(op);
+      StopRecording();
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const VarNode* op) final {
+    if (is_if_cond) {
+      if_var_list_.emplace_back(op);
+    }
+  }
+
+  HoistForIfTuple hoist_for_if_recorder;
+
+  void ResetRecorder() {
+    if (is_recorder_on) {
+      CHECK_GT(ordered_for_list_.size(), 0);
+      is_recorder_on = false;
+    }
+    ordered_for_list_.clear();
+    var_for_map_.clear();
+    hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
+  }
+
+  bool RecordingComplete() {
+    if (std::get<0>(hoist_for_if_recorder)) return true;
+    return false;
+  }
+
+  const ForNode* GetTargetForNode() { return std::get<1>(hoist_for_if_recorder); }
+
+  const IfThenElseNode* GetTargetIfNode() { return std::get<2>(hoist_for_if_recorder); }
+
+ private:
+  bool CheckValidIf() {
+    // If no if var list is present, then all the condition vars are possibly from AttrStmt, so stop
+    // hoisting
+    if (if_var_list_.size() == 0) {
+      return false;
+    }
+    if (CheckAttrVar()) {
+      return false;
+    }
+    return true;
+  }
+
+  bool IsParentForLoop(int loop_pos) {
+    // Check if the loop position is higher than the parent loop position
+    for (auto var : if_var_list_) {
+      if (GetParentLoopPos(var_for_map_[var]) >= loop_pos) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  int GetParentLoopPos(const Object* node) {
+    for (size_t i = 0; i < ordered_for_list_.size(); ++i) {
+      if (ordered_for_list_[i] == node) {
+        return i;
+      }
+    }
+    return -1;
+  }
+
+  void InitRecorder() { hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr); }
+
+  void StopRecording() {
+    if (is_recorder_on) is_recorder_on = false;
+  }
+
+  bool IsRecordingOn() { return is_recorder_on; }
+
+  void StartOrAddRecord(const ForNode* op) {
+    if (!is_recorder_on) is_recorder_on = true;
+    if (!var_for_map_.count(op->loop_var.get())) {
+      var_for_map_.insert({op->loop_var.get(), op});
+    }
+    ordered_for_list_.emplace_back(op);
+  }
+
+  void RemoveRecord(const ForNode* op) {
+    StopRecording();
+    var_for_map_.erase(op->loop_var.get());
+    if (ordered_for_list_.size() > 0) ordered_for_list_.pop_back();

Review comment:
       I understand that it is undefined.
   However, what I don't understand is: shouldn't you call RemoveRecord as many times as you call AddRecord()/ResetRecord()?
   That is, you call this to remove a single record, and if there isn't a single record, it is the caller's fault, and the caller should fix the code.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] roastduck commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
roastduck commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r459838669



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const Object*, const Object*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {

Review comment:
       In both cases above, `target_for_pos` ends up being the innermost loop, i.e., equal to `n`. I think we only have to check whether `target_for_pos < n`, instead of running `IsParentForLoop`. Am I right?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r460106726



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const VarNode*, const ForNode*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {
+          StopAndAddRecord(ordered_for_list_[target_for_pos], op);
+        }
+      }
+      if_var_list_.clear();
+      StmtExprVisitor::VisitStmt_(op);
+      StopRecording();
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const VarNode* op) final {
+    if (is_if_cond) {
+      if_var_list_.emplace_back(op);
+    }
+  }
+
+  HoistForIfTuple hoist_for_if_recorder;
+
+  void ResetRecorder() {
+    if (is_recorder_on) {
+      CHECK_GT(ordered_for_list_.size(), 0);
+      is_recorder_on = false;
+    }
+    ordered_for_list_.clear();
+    var_for_map_.clear();
+    hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
+  }
+
+  bool RecordingComplete() {
+    if (std::get<0>(hoist_for_if_recorder)) return true;
+    return false;
+  }
+
+  const ForNode* GetTargetForNode() { return std::get<1>(hoist_for_if_recorder); }
+
+  const IfThenElseNode* GetTargetIfNode() { return std::get<2>(hoist_for_if_recorder); }
+
+ private:
+  bool CheckValidIf() {
+    // If no if var list is present, then all the condition vars are possibly from AttrStmt, so stop
+    // hoisting
+    if (if_var_list_.size() == 0) {
+      return false;
+    }
+    if (CheckAttrVar()) {
+      return false;
+    }
+    return true;
+  }
+
+  bool IsParentForLoop(int loop_pos) {
+    // Check if the loop position is higher than the parent loop position
+    for (auto var : if_var_list_) {
+      if (GetParentLoopPos(var_for_map_[var]) >= loop_pos) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  int GetParentLoopPos(const Object* node) {
+    for (size_t i = 0; i < ordered_for_list_.size(); ++i) {
+      if (ordered_for_list_[i] == node) {
+        return i;
+      }
+    }
+    return -1;
+  }
+
+  void InitRecorder() { hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr); }
+
+  void StopRecording() {
+    if (is_recorder_on) is_recorder_on = false;
+  }
+
+  bool IsRecordingOn() { return is_recorder_on; }
+
+  void StartOrAddRecord(const ForNode* op) {
+    if (!is_recorder_on) is_recorder_on = true;
+    if (!var_for_map_.count(op->loop_var.get())) {
+      var_for_map_.insert({op->loop_var.get(), op});
+    }
+    ordered_for_list_.emplace_back(op);
+  }
+
+  void RemoveRecord(const ForNode* op) {
+    StopRecording();
+    var_for_map_.erase(op->loop_var.get());
+    if (ordered_for_list_.size() > 0) ordered_for_list_.pop_back();

Review comment:
       Thanks @MarisaKirisame for review!
   
   This is a safety check, added because the behavior of `pop_back()` is undefined when the list is empty.
   This scenario can be hit if `ResetRecorder()` is called in between.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] MarisaKirisame commented on pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
MarisaKirisame commented on pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#issuecomment-667513096


   Everything is green, but since @kevinthesun wants to review it, I will wait a few days before merging.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] kevinthesun commented on pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
kevinthesun commented on pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#issuecomment-667826737


   > > One thing I think it is good to have in this PR is to get some benchmark data, since we now enable this pass by default. IMO this pass is especially valuable when tackling dynamic kernels in GPU which introduces a lot of branchings. It's great to see how much performance improvement we can have. For CPU, we need to make sure this pass doesn't introduce regression for some common workloads, such as resnet.
   > 
   > @kevinthesun : I have verified the inference time for Resnet50 on CPU. There is no performance impact. In fact i did not find anything as Hoisting candidate.
   > 
   > ### Hoisting Disabled :
   > ![image](https://user-images.githubusercontent.com/32511895/89104943-9dde3a80-d43a-11ea-994d-85e1468c8241.png)
   > 
   > ### Hoisting Enabled:
   > ![image](https://user-images.githubusercontent.com/32511895/89104952-b2223780-d43a-11ea-8f92-e7fb4e555115.png)
   > 
   > Hope it helps. Please let me know, if i have mistaken anything. TIA!
   
   Usually there are two cases which might involve this pass: 1) Loop tiling with a non-factor split. 2) Dynamic shape ops. If I remember correctly, a conv2d with symbolic batch size will generate an IR with a lot of hoist candidates. Due to limitations of nvcc, performance for such a kernel is quite terrible, and this pass is able to handle this. It would be nice if we could verify this case for GPU and add a unit test.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] roastduck commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
roastduck commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r459839683



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const Object*, const Object*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);

Review comment:
       Maybe the problem with that specific test is that some optimizations for threads require an exact positioning of the `if`s. Maybe we can put our `HoistIfThenElse` pass after that particular optimization pass.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] kevinthesun commented on pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
kevinthesun commented on pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#issuecomment-667357815


   One thing I think it is good to have in this PR is to get some benchmark data. IMO this pass is especially valuable when tackling dynamic kernels in GPU which introduces a lot of branchings. It's great to see how much performance improvement we can have. For CPU, we need to make sure this pass doesn't introduce regression.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] roastduck commented on pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
roastduck commented on pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#issuecomment-660601199


   Since nearly all programs will be touched by this pass, potential bugs in this pass would be critical. Could you (maybe temporarily) add this pass into the default building procedure and run all the tests? It would greatly reduce potential bugs.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#issuecomment-668640265


   Thanks a lot @MarisaKirisame , @roastduck @kevinthesun @kparzysz-quic @Hzfengsy, @tqchen !
   
   This PR has 2 open points, as per the discussion with all the members who participated.
   
   Summarizing as below:
   ```
   Open point 1:- Support hoisting for Attr Nodes with IterVar & GlobalVar(GPU kernel case).
   Open point 2:- Support for loop unswitching case.
   ```
   
   I have kept them on my TODO list and will ensure support for these cases in my future PRs.
   However, I request that anyone who finds more scenarios which need to be covered please raise an issue and tag me, so that I can handle those too. TIA!


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] roastduck commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
roastduck commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r459934599



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const Object*, const Object*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {

Review comment:
       Why is `target_for_pos = 1`?
   
   ```c++
   if (match_for_loop_pos < i) {
     match_for_loop_pos = i;
   }
   ```
   
   This update hits twice, once for `i` and once for `k`, and `match_for_loop_pos` ends up being 2, so `target_for_pos == 3`.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] roastduck commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
roastduck commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r460138057



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const Object*, const Object*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {

Review comment:
       That's fine. We can merge this PR first.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#issuecomment-668167439


   @kevinthesun : Thanks for your response! 
   I am familiar with the scenario you are referring to for the conv op. I agree with the possibility of hoisting in that case.
   The same concern was raised by @roastduck too, I believe.
   Unfortunately, because some intermediate passes depend on the positioning of For and If nodes, I temporarily disabled this hoisting scenario for Attr nodes (IterVar & GlobalVar).
   
   I think it would be better if I enable this hoisting scenario too and move the Hoisting pass to the end of the list during lowering. Maybe after that we can take the performance data.
   
   Please let me know your opinion on this. TIA!


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] tqchen commented on pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#issuecomment-667180453


   Need explicit approval https://tvm.apache.org/docs/contribute/code_review.html#approve-and-request-changes-explicitly from @kparzysz-quic @MarisaKirisame .
   
   The most important thing is the code clarity (others can understand the logic) and correctness.
   
   cc @ZihengJiang @junrushao1994 it would be great if you can also take a look 


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#issuecomment-667488801


   @kevinthesun : Thanks a lot for your input! I believe none of the cases covered so far degrades performance on either CPU or GPU :)
   Do you have any suggestions on how to obtain the benchmark data (like some test cases or existing tools)?
   
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] roastduck commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
roastduck commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r459188187



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const Object*, const Object*>;

Review comment:
       Can we use explicit types here? For example `std::unordered_map<const VarNode*, const ForNode*>`.

##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const Object*, const Object*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);

Review comment:
       Does this function mean that if the condition of an `if` uses any variable defined by an `attr`, that `if` will not be hoisted? Could we stop hoisting only when the `attr` is inside the outermost for loop? The `threadIdx` and `blockIdx` `attr`s are very commonly used in `if` conditions, and because these `attr`s sit outside all the loops, they should not block the hoisting.

##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const Object*, const Object*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {

Review comment:
       I couldn't figure out in which case the `target_for_pos` found above would not be a parent. Could you add some more comments here?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] Hzfengsy commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
Hzfengsy commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r460354269



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,374 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const VarNode*, const ForNode*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Check corresponding for loop
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
+            if (ordered_for_list_[i] == var_for_map_[var]) {
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If none of the for loop has the matching loop variable as if condition,
+        // then the if node need to be hoisted on top of all, provided no parent loop exists.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Check if target for loop is not the parent of current if node
+        if (!IsParentForLoop(target_for_pos)) {
+          StopAndAddRecord(ordered_for_list_[target_for_pos], op);
+        }
+      }
+      if_var_list_.clear();
+      StmtExprVisitor::VisitStmt_(op);
+      StopRecording();
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const VarNode* op) final {
+    if (is_if_cond) {
+      if_var_list_.emplace_back(op);
+    }
+  }
+
+  HoistForIfTuple hoist_for_if_recorder;
+
+  void ResetRecorder() {
+    if (is_recorder_on) {
+      CHECK_GT(ordered_for_list_.size(), 0);
+      is_recorder_on = false;
+    }
+    ordered_for_list_.clear();
+    var_for_map_.clear();
+    hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
+  }
+
+  bool RecordingComplete() {
+    if (std::get<0>(hoist_for_if_recorder)) return true;
+    return false;
+  }
+
+  const ForNode* GetTargetForNode() { return std::get<1>(hoist_for_if_recorder); }
+
+  const IfThenElseNode* GetTargetIfNode() { return std::get<2>(hoist_for_if_recorder); }
+
+ private:
+  bool CheckValidIf() {
+    // If no if var list is present, then all the condition vars are possibly from AttrStmt, so stop
+    // hoisting
+    if (if_var_list_.size() == 0) {
+      return false;
+    }
+    if (CheckAttrVar()) {
+      return false;
+    }
+    return true;
+  }
+
+  bool IsParentForLoop(int loop_pos) {
+    // Check if the loop position is higher than the parent loop position
+    for (auto var : if_var_list_) {
+      if (GetParentLoopPos(var_for_map_[var]) >= loop_pos) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  int GetParentLoopPos(const Object* node) {
+    for (size_t i = 0; i < ordered_for_list_.size(); ++i) {
+      if (ordered_for_list_[i] == node) {
+        return i;
+      }
+    }
+    return -1;
+  }
+
+  void InitRecorder() { hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr); }
+
+  void StopRecording() { is_recorder_on = false; }
+
+  bool IsRecordingOn() { return is_recorder_on; }
+
+  void StartOrAddRecord(const ForNode* op) {
+    is_recorder_on = true;
+    if (!var_for_map_.count(op->loop_var.get())) {
+      var_for_map_.insert({op->loop_var.get(), op});
+    }
+    ordered_for_list_.emplace_back(op);
+  }
+
+  void RemoveRecord(const ForNode* op) {
+    StopRecording();
+    var_for_map_.erase(op->loop_var.get());
+    if (ordered_for_list_.size() > 0) ordered_for_list_.pop_back();
+  }
+
+  void StopAndAddRecord(const ForNode* for_node, const IfThenElseNode* if_node) {
+    hoist_for_if_recorder = std::make_tuple(true, for_node, if_node);
+    StopRecording();
+  }
+
+  void UpdateAttrVarList(const AttrStmtNode* op) {
+    if (const auto* iv = op->node.as<IterVarNode>()) {
+      attr_var_list_.insert(iv->var.get());
+    } else if (const auto* iv = op->node.as<VarNode>()) {
+      attr_var_list_.insert(iv);
+    }
+  }
+
+  void RemoveAttrVarList(const AttrStmtNode* op) {
+    if (const auto* iv = op->node.as<IterVarNode>()) {
+      attr_var_list_.erase(iv->var.get());
+    } else if (const auto* iv = op->node.as<VarNode>()) {
+      attr_var_list_.erase(iv);
+    }
+  }
+
+  bool CheckAttrVar() {
+    for (auto var : if_var_list_) {
+      if (attr_var_list_.count(var)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  std::vector<const ForNode*> ordered_for_list_;
+
+  std::vector<const VarNode*> if_var_list_;
+
+  std::unordered_set<const VarNode*> attr_var_list_;
+
+  VarForMap var_for_map_;
+
+  bool is_if_cond{false};
+  bool is_recorder_on{false};

Review comment:
       ```suggestion
     bool is_if_cond_{false};
     bool is_recorder_on_{false};
   ```

##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,374 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const VarNode*, const ForNode*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Visitor that scans a stmt tree and records one (target For, IfThenElse) hoisting candidate in hoist_for_if_recorder.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // While no candidate has been recorded yet, add this loop to the record for the duration of its visit
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);  // pops the loop again and switches the recorder off
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);  // a candidate is already recorded: plain traversal
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If a SeqStmt is encountered in the middle of recording,
+    //  then everything recorded so far is purged, as it cannot be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain the set of vars attached to the AttrStmts currently being visited,
+    // to stop hoisting if any of them is used in an if condition.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add an exception for only that case
+    // rather than allowing it for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    if (IsRecordingOn()) {
+      is_if_cond = true;  // VisitExpr_(const VarNode*) collects vars only while this flag is set
+      StmtExprVisitor::VisitExpr(op->condition);
+      is_if_cond = false;
+
+      if (CheckValidIf()) {
+        // Find the largest index in ordered_for_list_ whose loop var is used in the condition
+        bool match_found = false;
+        size_t match_for_loop_pos = 0;
+        for (auto var : if_var_list_) {
+          for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {  // last recorded loop is not searched; NOTE(review): unsigned `size() - 1` wraps if the list is empty - presumably non-empty while recording, confirm
+            if (ordered_for_list_[i] == var_for_map_[var]) {  // NOTE(review): operator[] default-inserts a nullptr entry for vars that are not recorded loop vars
+              if (match_for_loop_pos < i) {
+                match_for_loop_pos = i;
+              }
+              match_found = true;
+              break;
+            }
+          }
+        }
+        // If no searched loop var is used in the if condition, the target is position 0 (the first
+        // recorded loop); otherwise it is the loop recorded right after the deepest matching loop.
+        int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;
+
+        // Record the candidate only if no condition var belongs to a loop at or after the target position
+        if (!IsParentForLoop(target_for_pos)) {
+          StopAndAddRecord(ordered_for_list_[target_for_pos], op);
+        }
+      }
+      if_var_list_.clear();
+      StmtExprVisitor::VisitStmt_(op);
+      StopRecording();
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const VarNode* op) final {
+    if (is_if_cond) {
+      if_var_list_.emplace_back(op);  // only vars seen while visiting an if condition are collected
+    }
+  }
+
+  HoistForIfTuple hoist_for_if_recorder;  // (candidate found, target For, target IfThenElse)
+
+  void ResetRecorder() {  // drops the recorded loops and candidate; attr_var_list_ and if_var_list_ are left untouched
+    if (is_recorder_on) {
+      CHECK_GT(ordered_for_list_.size(), 0);
+      is_recorder_on = false;
+    }
+    ordered_for_list_.clear();
+    var_for_map_.clear();
+    hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
+  }
+
+  bool RecordingComplete() {  // true once StopAndAddRecord has stored a candidate
+    if (std::get<0>(hoist_for_if_recorder)) return true;
+    return false;
+  }
+
+  const ForNode* GetTargetForNode() { return std::get<1>(hoist_for_if_recorder); }
+
+  const IfThenElseNode* GetTargetIfNode() { return std::get<2>(hoist_for_if_recorder); }
+
+ private:
+  bool CheckValidIf() {
+    // If the condition contains no vars at all, then all the condition vars are possibly from
+    // AttrStmt, so stop hoisting; also stop if any collected var is in attr_var_list_
+    if (if_var_list_.size() == 0) {
+      return false;
+    }
+    if (CheckAttrVar()) {
+      return false;
+    }
+    return true;
+  }
+
+  bool IsParentForLoop(int loop_pos) {
+    // True if any condition var maps to a loop whose index in ordered_for_list_ is >= loop_pos
+    for (auto var : if_var_list_) {
+      if (GetParentLoopPos(var_for_map_[var]) >= loop_pos) {  // vars without a recorded loop yield -1 here
+        return true;
+      }
+    }
+    return false;
+  }
+
+  int GetParentLoopPos(const Object* node) {  // index of node in ordered_for_list_, or -1
+    for (size_t i = 0; i < ordered_for_list_.size(); ++i) {
+      if (ordered_for_list_[i] == node) {
+        return i;
+      }
+    }
+    return -1;  // node is not a recorded loop
+  }
+
+  void InitRecorder() { hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr); }
+
+  void StopRecording() { is_recorder_on = false; }
+
+  bool IsRecordingOn() { return is_recorder_on; }
+
+  void StartOrAddRecord(const ForNode* op) {  // switches the recorder on and appends the loop
+    is_recorder_on = true;
+    if (!var_for_map_.count(op->loop_var.get())) {
+      var_for_map_.insert({op->loop_var.get(), op});
+    }
+    ordered_for_list_.emplace_back(op);
+  }
+
+  void RemoveRecord(const ForNode* op) {  // undoes StartOrAddRecord when the loop's visit ends
+    StopRecording();
+    var_for_map_.erase(op->loop_var.get());
+    if (ordered_for_list_.size() > 0) ordered_for_list_.pop_back();  // list may already be empty after ResetRecorder
+  }
+
+  void StopAndAddRecord(const ForNode* for_node, const IfThenElseNode* if_node) {
+    hoist_for_if_recorder = std::make_tuple(true, for_node, if_node);  // marks recording as complete
+    StopRecording();
+  }
+
+  void UpdateAttrVarList(const AttrStmtNode* op) {  // remembers the var behind op->node (IterVar or Var)
+    if (const auto* iv = op->node.as<IterVarNode>()) {
+      attr_var_list_.insert(iv->var.get());
+    } else if (const auto* iv = op->node.as<VarNode>()) {
+      attr_var_list_.insert(iv);
+    }
+  }
+
+  void RemoveAttrVarList(const AttrStmtNode* op) {  // mirror of UpdateAttrVarList, run after the AttrStmt is visited
+    if (const auto* iv = op->node.as<IterVarNode>()) {
+      attr_var_list_.erase(iv->var.get());
+    } else if (const auto* iv = op->node.as<VarNode>()) {
+      attr_var_list_.erase(iv);
+    }
+  }
+
+  bool CheckAttrVar() {  // true if any collected condition var is in attr_var_list_
+    for (auto var : if_var_list_) {
+      if (attr_var_list_.count(var)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  std::vector<const ForNode*> ordered_for_list_;  // recorded For nodes, appended on entry and popped on exit
+
+  std::vector<const VarNode*> if_var_list_;  // vars collected from the if condition being examined
+
+  std::unordered_set<const VarNode*> attr_var_list_;  // vars of the AttrStmts currently being visited
+
+  VarForMap var_for_map_;  // loop var -> its recorded For node
+
+  bool is_if_cond{false};  // set only while an if condition is being visited
+  bool is_recorder_on{false};  // set by StartOrAddRecord, cleared by StopRecording/ResetRecorder
+};
+
+class IfThenElseHoister : public StmtMutator {
+ public:
+  IfThenElseHoister() : hoist_selector(HoistCandidateSelector()) {}
+
+  Stmt VisitAndMutate(Stmt stmt) {
+    hoist_selector(stmt);
+    Stmt stmt_copy = std::move(stmt);
+
+    while (hoist_selector.RecordingComplete()) {
+      target_for = hoist_selector.GetTargetForNode();
+      target_if = hoist_selector.GetTargetIfNode();
+
+      stmt_copy = operator()(stmt_copy);
+
+      hoist_selector.ResetRecorder();
+      hoist_selector(stmt_copy);
+    }
+
+    // Support SSA Form
+    stmt_copy = ConvertSSA(stmt_copy);
+    return stmt_copy;
+  }
+
+  Stmt VisitStmt_(const ForNode* op) final {
+    if ((!is_updating) && (target_for == op)) {
+      is_updating = true;
+      is_then_case = true;
+      Stmt then_case = StmtMutator::VisitStmt_(op);
+      is_then_case = false;
+      Stmt else_case = Stmt();
+      if (target_if->else_case.defined()) {
+        else_case = StmtMutator::VisitStmt_(op);
+      }
+      is_updating = false;
+      return IfThenElse(target_if->condition, then_case, else_case);
+    }
+    return StmtMutator::VisitStmt_(op);
+  }
+
+  Stmt VisitStmt_(const IfThenElseNode* op) final {
+    if (is_updating && (target_if == op)) {
+      if (is_then_case) {
+        return StmtMutator::VisitStmt(op->then_case);
+      } else if (op->else_case.defined()) {
+        return StmtMutator::VisitStmt(op->else_case);
+      }
+    }
+    return StmtMutator::VisitStmt_(op);
+  }
+
+  const ForNode* target_for;
+  const IfThenElseNode* target_if;
+
+ private:
+  bool is_updating{false};
+  bool is_then_case{false};
+  HoistCandidateSelector hoist_selector;

Review comment:
       ```suggestion
     bool is_updating_{false};
     bool is_then_case_{false};
     HoistCandidateSelector hoist_selector_;
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 commented on a change in pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 commented on a change in pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#discussion_r459511992



##########
File path: src/tir/transforms/hoist_if_then_else.cc
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file hoist_if_then_else.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../arith/interval_set.h"
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_util.h"
+
+namespace tvm {
+namespace tir {
+
+using VarForMap = std::unordered_map<const Object*, const Object*>;
+using HoistForIfTuple = std::tuple<bool, const ForNode*, const IfThenElseNode*>;
+
+/*
+ * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
+ * For example, given the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
+ * Then we hoist IfThenElse stmt by one For stmt each step:
+ *
+ * Step 1:
+ * for (i = 0; i < 3; i++)
+ *     for (j = 0; j < 4; j++)
+ *         if (likely(i*2 < 4))
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Step 2:
+ * for (i = 0; i < 3; i++)
+ *     if (likely(i*2 < 4))
+ *         for (j = 0; j < 4; j++)
+ *             for (k = 0; k < 5; k++)
+ *                 A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * In this pass, we only continue detecting possible hoisting chance when visiting For,
+ * IfThenElse or AttrStmt Node. For example, for the following block:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * Only the For with k variable will be considered and the resulting stmt would be:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        A[i + j] = A[i + j] - 1
+ *        if (likely(i*2 < 4))
+ *            for (k = 0; k < 5; k++)
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *
+ * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
+ * block won't be optimized:
+ * for (i = 0; i < 3; i++)
+ *    for (j = 0; j < 4; j++)
+ *        for (k = 0; k < 5; k++)
+ *            if (likely(i*2 < 4))
+ *                A[3*i+2j+k] = B[7*i+3j+k]
+ *            if (likely(j > 2))
+ *                A[i+j+k] = B[i+j+k]
+ *
+ */
+
+// Select potential candidate IRs that can be hoisted.
+class HoistCandidateSelector final : public StmtExprVisitor {
+ public:
+  HoistCandidateSelector() { InitRecorder(); }
+
+  void VisitStmt_(const ForNode* op) final {
+    // Check if it is first for loop, then start the recorder
+    if (!RecordingComplete()) {
+      StartOrAddRecord(op);
+      StmtExprVisitor::VisitStmt_(op);
+      RemoveRecord(op);
+      return;
+    }
+
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const SeqStmtNode* op) final {
+    // If SeqStmt is encountered in the middle of recording
+    //  then need to purge all, as it can not be hoisted
+    if (IsRecordingOn()) {
+      ResetRecorder();
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Maintain list of all vars in AttrStmt
+    // To stop hoisting if any of the block variables are used.
+    //
+    // NOTE: If in future
+    // hoisting is required for any specific case,
+    // then add exception to only those case
+    // rather than allowing for all.
+    UpdateAttrVarList(op);
+    StmtExprVisitor::VisitStmt_(op);
+    RemoveAttrVarList(op);

Review comment:
       Thanks for bringing this up!
   In fact, I also had the same idea initially, and it was working perfectly fine.
   But while investigating CI failures, I discovered that there are some dependencies on the positioning of these If statements with global-scope variables. So I had to add this logic to avoid hoisting in any such cases.
   But as I mentioned in the comment as well, if you have a specific case where hoisting should be enabled, we can add it, provided it does not violate the logic of other passes.
   
   Please let me know your thoughts on this.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ANSHUMAN87 edited a comment on pull request #6066: [TIR][Transform] HoistIfThenElse added

Posted by GitBox <gi...@apache.org>.
ANSHUMAN87 edited a comment on pull request #6066:
URL: https://github.com/apache/incubator-tvm/pull/6066#issuecomment-667549552


   > One thing I think it is good to have in this PR is to get some benchmark data, since we now enable this pass by default. IMO this pass is especially valuable when tackling dynamic kernels in GPU which introduces a lot of branchings. It's great to see how much performance improvement we can have. For CPU, we need to make sure this pass doesn't introduce regression for some common workloads, such as resnet.
   
   @kevinthesun : I have verified the inference time for ResNet50 on CPU. There is no performance impact. In fact, I did not find any hoisting candidates in it.
   
   ### Hoisting Disabled :
   ![image](https://user-images.githubusercontent.com/32511895/89104943-9dde3a80-d43a-11ea-994d-85e1468c8241.png)
   
   
   ### Hoisting Enabled:
   ![image](https://user-images.githubusercontent.com/32511895/89104952-b2223780-d43a-11ea-8f92-e7fb4e555115.png)
   
   Hope it helps. Please let me know if I have made a mistake anywhere. TIA!


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org