Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/01/15 13:13:35 UTC

[GitHub] [tvm] Hzfengsy opened a new pull request #9940: [MetaSchedule] random compute location

Hzfengsy opened a new pull request #9940:
URL: https://github.com/apache/tvm/pull/9940


   This PR adds one of the search rules for MetaSchedule.
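   
   Below is a minimal usage sketch (not part of this PR) of how the new rule could be constructed and applied to a single block through the C++ API. It assumes the public headers `tvm/meta_schedule/schedule_rule.h` and `tvm/tir/schedule/schedule.h`, an already-created schedule `sch`, and a block named "C"; note that `InitializeWithTuneContext` is a no-op for this rule, so no tuning context is needed for the sketch.
   
   ```cpp
   // Hedged sketch: the wrapper function, header choice, and the block name "C"
   // are illustrative assumptions, not code from this PR.
   #include <tvm/meta_schedule/schedule_rule.h>
   #include <tvm/tir/schedule/schedule.h>
   
   using namespace tvm;
   using namespace tvm::meta_schedule;
   
   Array<tir::Schedule> ApplyRandomComputeLocation(const tir::Schedule& sch) {
     ScheduleRule rule = ScheduleRule::RandomComputeLocation();
     tir::BlockRV block = sch->GetBlock("C");
     // Samples a compute-at location for the block and returns the resulting
     // schedule(s); if the block does not satisfy the rule's conditions, the
     // original schedule is returned unchanged.
     return rule->Apply(sch, block);
   }
   ```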
   
   Thanks to all co-authors for contributing!
   
   Co-authored-by: Junru Shao <ju...@gmail.com>
   Co-authored-by: Xiyou Zhou <xi...@octoml.ai>
   Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
   Co-authored-by: Ruihang Lai <la...@qq.com>
   Co-authored-by: Hongyi Jin <32...@qq.com>
   Co-authored-by: Wuwei Lin <wu...@apache.org>
   
   cc @junrushao1994 @comaniac @jcf94 





[GitHub] [tvm] junrushao1994 commented on pull request #9940: [MetaSchedule] random compute location

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #9940:
URL: https://github.com/apache/tvm/pull/9940#issuecomment-1014222632


   @Hzfengsy let's address the last comment and get it merged ASAP





[GitHub] [tvm] Hzfengsy commented on pull request #9940: [MetaSchedule] random compute location

Posted by GitBox <gi...@apache.org>.
Hzfengsy commented on pull request #9940:
URL: https://github.com/apache/tvm/pull/9940#issuecomment-1014292517


   Addressed. Please take another look!





[GitHub] [tvm] junrushao1994 merged pull request #9940: [MetaSchedule] random compute location

Posted by GitBox <gi...@apache.org>.
junrushao1994 merged pull request #9940:
URL: https://github.com/apache/tvm/pull/9940


   





[GitHub] [tvm] junrushao1994 commented on a change in pull request #9940: [MetaSchedule] random compute location

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9940:
URL: https://github.com/apache/tvm/pull/9940#discussion_r785481018



##########
File path: src/meta_schedule/schedule_rule/random_compute_location.cc
##########
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+class RandomComputeLocationNode : public ScheduleRuleNode {
+ public:
+  // Inherited from ScheduleRuleNode
+  void InitializeWithTuneContext(const TuneContext& context) final {}
+
+  // Inherited from ScheduleRuleNode
+  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
+    if (!CheckConditions(sch, block_rv)) {
+      return {sch};
+    }
+
+    // Step 1. If the producer of the input block needs a random compute-at location (specified by
+    // the annotation), we colect the producer first, and transform the producer block later.
+    // - The reason we collect the producer before transforming the input block is that, if the
+    // decision of Sample-Compute-Location is "compute-inline" for the input block, we can no longer
+    // access the input block. Hence we collect its producer ahead of time.
+    // - Note that only single producer is allowed in this case.
+    Array<tir::BlockRV> producers{nullptr};
+    if (tir::HasAnn(sch->GetSRef(block_rv), tir::attr::meta_schedule_random_compute_producer,
+                    true)) {
+      producers = sch->GetProducers(block_rv);
+      sch->Unannotate(block_rv, tir::attr::meta_schedule_random_compute_producer);
+      ICHECK_EQ(producers.size(), 1);
+    }
+
+    // Step 2. Transform the input block.
+    tir::Schedule res = RandomlyComputeAt(sch, block_rv);
+
+    // Step 3. Transform the producer block if compute-location sampling is needed.
+    if (producers.defined()) {
+      res = RandomlyComputeAt(res, producers[0]);
+    }

Review comment:
       Sounds great! Thanks for the explanation







[GitHub] [tvm] Hzfengsy commented on a change in pull request #9940: [MetaSchedule] random compute location

Posted by GitBox <gi...@apache.org>.
Hzfengsy commented on a change in pull request #9940:
URL: https://github.com/apache/tvm/pull/9940#discussion_r785431614



##########
File path: src/meta_schedule/schedule_rule/random_compute_location.cc
##########
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+class RandomComputeLocationNode : public ScheduleRuleNode {
+ public:
+  // Inherited from ScheduleRuleNode
+  void InitializeWithTuneContext(const TuneContext& context) final {}
+
+  // Inherited from ScheduleRuleNode
+  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
+    if (!CheckConditions(sch, block_rv)) {
+      return {sch};
+    }
+
+    // Step 1. If the producer of the input block needs a random compute-at location (specified by
+    // the annotation), we colect the producer first, and transform the producer block later.
+    // - The reason we collect the producer before transforming the input block is that, if the
+    // decision of Sample-Compute-Location is "compute-inline" for the input block, we can no longer
+    // access the input block. Hence we collect its producer ahead of time.
+    // - Note that only single producer is allowed in this case.
+    Array<tir::BlockRV> producers{nullptr};
+    if (tir::HasAnn(sch->GetSRef(block_rv), tir::attr::meta_schedule_random_compute_producer,
+                    true)) {
+      producers = sch->GetProducers(block_rv);
+      sch->Unannotate(block_rv, tir::attr::meta_schedule_random_compute_producer);
+      ICHECK_EQ(producers.size(), 1);
+    }
+
+    // Step 2. Transform the input block.
+    tir::Schedule res = RandomlyComputeAt(sch, block_rv);
+
+    // Step 3. Transform the producer block if compute-location sampling is needed.
+    if (producers.defined()) {
+      res = RandomlyComputeAt(res, producers[0]);
+    }

Review comment:
       1. There are potentially two `compute_at`s: the first moves the current block to its consumers, and the second `compute_at`s the producer;
   2. `inline` is one of the choices of `RandomlyComputeAt`, so it is allowed.
   3. One corner case would be multi-stage producers, i.e. `block_1` -> `block_2` -> `input_block`. In this case `block_1` will never be `compute_at`-ed, but we have no such cases in real-world usage.







[GitHub] [tvm] junrushao1994 commented on pull request #9940: [MetaSchedule] random compute location

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #9940:
URL: https://github.com/apache/tvm/pull/9940#issuecomment-1013777982


   Looks like there is some flakiness in "CI / Windows". Retriggered





[GitHub] [tvm] junrushao1994 commented on a change in pull request #9940: [MetaSchedule] random compute location

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9940:
URL: https://github.com/apache/tvm/pull/9940#discussion_r785480957



##########
File path: src/meta_schedule/schedule_rule/random_compute_location.cc
##########
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+class RandomComputeLocationNode : public ScheduleRuleNode {
+ public:
+  // Inherited from ScheduleRuleNode
+  void InitializeWithTuneContext(const TuneContext& context) final {}
+
+  // Inherited from ScheduleRuleNode
+  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
+    if (!CheckConditions(sch, block_rv)) {
+      return {sch};
+    }
+
+    // Step 1. If the producer of the input block needs a random compute-at location (specified by
+    // the annotation), we colect the producer first, and transform the producer block later.
+    // - The reason we collect the producer before transforming the input block is that, if the
+    // decision of Sample-Compute-Location is "compute-inline" for the input block, we can no longer
+    // access the input block. Hence we collect its producer ahead of time.
+    // - Note that only single producer is allowed in this case.
+    Array<tir::BlockRV> producers{nullptr};
+    if (tir::HasAnn(sch->GetSRef(block_rv), tir::attr::meta_schedule_random_compute_producer,
+                    true)) {
+      producers = sch->GetProducers(block_rv);
+      sch->Unannotate(block_rv, tir::attr::meta_schedule_random_compute_producer);
+      ICHECK_EQ(producers.size(), 1);
+    }
+
+    // Step 2. Transform the input block.
+    tir::Schedule res = RandomlyComputeAt(sch, block_rv);
+
+    // Step 3. Transform the producer block if compute-location sampling is needed.
+    if (producers.defined()) {
+      res = RandomlyComputeAt(res, producers[0]);
+    }
+
+    return {res};
+  }
+
+ private:
+  bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) const {
+    const tir::StmtSRef& block_sref = sch->GetSRef(block_rv);
+    const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+
+    // Cond 1. The block is not the root block.
+    if (block_sref->parent == nullptr) {
+      return false;
+    }
+    // Cond 2. The block should be the direct child block of the root block.
+    if (GetScopeRoot(sch->state(), block_sref,          //
+                     /*require_stage_pipeline=*/false,  //
+                     /*require_subtree_compact_dataflow=*/false)
+            ->parent != nullptr) {
+      return false;
+    }
+    // Cond 3 & 4. The block has at least one outer loop, and the outermost loop has only one child
+    // block.
+    Array<tir::StmtSRef> loop_srefs = tir::GetLoops(block_sref);
+    if (loop_srefs.empty()) {
+      return false;
+    }
+    if (tir::GetChildBlockSRefOnSRefTree(sch->state(), loop_srefs[0]).size() > 1) {
+      return false;
+    }
+    // Cond 5. The block is not tiled. We check this condition by examine the block's annotation.
+    if (tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_tiling_structure).defined()) {
+      return false;
+    }
+    // Cond 6. The block has at lease one consumer.
+    if (tir::GetConsumers(sch->state(), sch->GetSRef(block_rv)).empty()) {
+      return false;
+    }
+    return true;
+  }
+
+  /*!
+   * \brief Keep sampling a compute-at location for the input block until success.
+   * \param sch The TIR schedule
+   * \param block_rv The block whose compute-at location is to be sampled
+   * \return The TIR schedule after transformation
+   */
+  tir::Schedule RandomlyComputeAt(const tir::Schedule& sch, const tir::BlockRV& block_rv) {
+    for (;;) {
+      tir::LoopRV compute_at_loc = sch->SampleComputeLocation(block_rv);
+      try {
+        sch->ComputeAt(block_rv, compute_at_loc, true);
+      } catch (const dmlc::Error& e) {
+        // ComputeAt fails, cleanup the following before re-try:
+        // 1) trace: instruction & decisions
+        // 2) sym_tab
+        sch->trace().value()->Pop();
+        sch->RemoveRV(compute_at_loc);
+        continue;
+      }
+      break;
+    }

Review comment:
       Note that ScheduleRules are run only once before the search, which takes much less than 1 second, so there is no need to optimize their performance







[GitHub] [tvm] junrushao1994 commented on a change in pull request #9940: [MetaSchedule] random compute location

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9940:
URL: https://github.com/apache/tvm/pull/9940#discussion_r785375497



##########
File path: src/meta_schedule/schedule_rule/random_compute_location.cc
##########
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+class RandomComputeLocationNode : public ScheduleRuleNode {
+ public:
+  // Inherited from ScheduleRuleNode
+  void InitializeWithTuneContext(const TuneContext& context) final {}
+
+  // Inherited from ScheduleRuleNode
+  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
+    if (!CheckConditions(sch, block_rv)) {
+      return {sch};
+    }
+
+    // Step 1. If the producer of the input block needs a random compute-at location (specified by
+    // the annotation), we colect the producer first, and transform the producer block later.

Review comment:
       ```suggestion
       // the annotation), we collect the producer first, and transform the producer block later.
   ```

##########
File path: src/tir/schedule/primitive/sampling.cc
##########
@@ -354,6 +354,40 @@ std::vector<int64_t> SamplePerfectTile(
   return result;
 }
 
+tir::StmtSRef SampleComputeLocation(tir::ScheduleState self,
+                                    support::LinearCongruentialEngine::TRandState* rand_state,
+                                    const StmtSRef& block_sref, Optional<Integer>* decision) {
+  // Step 1. Collect all possible compute-at locations.
+  Array<tir::StmtSRef> location_srefs;
+  std::vector<int> location_indices;
+  std::tie(location_srefs, location_indices) = CollectComputeLocation(self, block_sref);
+  ICHECK_EQ(location_srefs.size(), location_indices.size());
+
+  // Step 2. If there was a previous decision, keep the decision unchanged if it exists in the
+  // location candidates. Otherwise, pick the location before the previous decision.
+  // Step 3. If there was not a previous decision, sample a decision from the collected locations.
+  if (decision->defined()) {
+    int64_t old_decision = Downcast<Integer>(*decision)->value;
+    auto it = std::lower_bound(location_indices.begin(), location_indices.end(), old_decision);
+    int idx = it - location_indices.begin();
+
+    if (it != location_indices.end() && *it == old_decision) {
+      *decision = Integer(old_decision);
+      return location_srefs[idx];
+    } else if (it != location_indices.begin()) {
+      *decision = Integer(*--it);

Review comment:
       nit: just to make it a bit clearer :-)
   
   ```suggestion
         *decision = Integer(location_indices[idx - 1]);
   ```
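   
   As a side note on the decision-reuse logic in this hunk: a previous decision is kept if it is still among the (sorted) candidate indices, and otherwise the candidate just before it is chosen. The snippet below is a small self-contained sketch of that behaviour using plain `std::vector` data (the values are made up for illustration), matching the suggested `location_indices[idx - 1]` form.
   
   ```cpp
   #include <algorithm>
   #include <cstdio>
   #include <vector>
   
   int main() {
     std::vector<int> location_indices = {-2, -1, 1, 3};  // sorted candidate indices
     int old_decision = 2;                                 // no longer a valid candidate
     auto it = std::lower_bound(location_indices.begin(), location_indices.end(), old_decision);
     int idx = static_cast<int>(it - location_indices.begin());
     if (it != location_indices.end() && *it == old_decision) {
       std::printf("keep decision %d\n", old_decision);
     } else if (it != location_indices.begin()) {
       std::printf("fall back to %d\n", location_indices[idx - 1]);  // prints "fall back to 1"
     } else {
       std::printf("no smaller candidate\n");
     }
     return 0;
   }
   ```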

##########
File path: src/meta_schedule/schedule_rule/random_compute_location.cc
##########
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+class RandomComputeLocationNode : public ScheduleRuleNode {
+ public:
+  // Inherited from ScheduleRuleNode
+  void InitializeWithTuneContext(const TuneContext& context) final {}
+
+  // Inherited from ScheduleRuleNode
+  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
+    if (!CheckConditions(sch, block_rv)) {
+      return {sch};
+    }
+
+    // Step 1. If the producer of the input block needs a random compute-at location (specified by
+    // the annotation), we colect the producer first, and transform the producer block later.
+    // - The reason we collect the producer before transforming the input block is that, if the
+    // decision of Sample-Compute-Location is "compute-inline" for the input block, we can no longer
+    // access the input block. Hence we collect its producer ahead of time.
+    // - Note that only single producer is allowed in this case.
+    Array<tir::BlockRV> producers{nullptr};
+    if (tir::HasAnn(sch->GetSRef(block_rv), tir::attr::meta_schedule_random_compute_producer,
+                    true)) {
+      producers = sch->GetProducers(block_rv);
+      sch->Unannotate(block_rv, tir::attr::meta_schedule_random_compute_producer);
+      ICHECK_EQ(producers.size(), 1);
+    }
+
+    // Step 2. Transform the input block.
+    tir::Schedule res = RandomlyComputeAt(sch, block_rv);
+
+    // Step 3. Transform the producer block if compute-location sampling is needed.
+    if (producers.defined()) {
+      res = RandomlyComputeAt(res, producers[0]);
+    }
+
+    return {res};
+  }
+
+ private:
+  bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) const {
+    const tir::StmtSRef& block_sref = sch->GetSRef(block_rv);
+    const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+
+    // Cond 1. The block is not the root block.
+    if (block_sref->parent == nullptr) {
+      return false;
+    }
+    // Cond 2. The block should be the direct child block of the root block.
+    if (GetScopeRoot(sch->state(), block_sref,          //
+                     /*require_stage_pipeline=*/false,  //
+                     /*require_subtree_compact_dataflow=*/false)
+            ->parent != nullptr) {
+      return false;
+    }
+    // Cond 3 & 4. The block has at least one outer loop, and the outermost loop has only one child
+    // block.
+    Array<tir::StmtSRef> loop_srefs = tir::GetLoops(block_sref);
+    if (loop_srefs.empty()) {
+      return false;
+    }
+    if (tir::GetChildBlockSRefOnSRefTree(sch->state(), loop_srefs[0]).size() > 1) {
+      return false;
+    }
+    // Cond 5. The block is not tiled. We check this condition by examine the block's annotation.
+    if (tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_tiling_structure).defined()) {
+      return false;
+    }
+    // Cond 6. The block has at lease one consumer.
+    if (tir::GetConsumers(sch->state(), sch->GetSRef(block_rv)).empty()) {
+      return false;
+    }
+    return true;
+  }
+
+  /*!
+   * \brief Keep sampling a compute-at location for the input block until success.
+   * \param sch The TIR schedule
+   * \param block_rv The block whose compute-at location is to be sampled
+   * \return The TIR schedule after transformation
+   */
+  tir::Schedule RandomlyComputeAt(const tir::Schedule& sch, const tir::BlockRV& block_rv) {
+    for (;;) {
+      tir::LoopRV compute_at_loc = sch->SampleComputeLocation(block_rv);
+      try {
+        sch->ComputeAt(block_rv, compute_at_loc, true);
+      } catch (const dmlc::Error& e) {
+        // ComputeAt fails, cleanup the following before re-try:
+        // 1) trace: instruction & decisions
+        // 2) sym_tab
+        sch->trace().value()->Pop();
+        sch->RemoveRV(compute_at_loc);
+        continue;
+      }
+      break;
+    }
+    return sch;
+  }
+
+ public:
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+
+  static constexpr const char* _type_key = "meta_schedule.RandomComputeLocation";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RandomComputeLocationNode, ScheduleRuleNode);
+};
+
+ScheduleRule ScheduleRule::RandomComputeLocation() {
+  ObjectPtr<RandomComputeLocationNode> n = make_object<RandomComputeLocationNode>();
+  return ScheduleRule(n);

Review comment:
       nit: can be merged into a single line:
   
   ```suggestion
     return ScheduleRule(make_object<RandomComputeLocationNode>());
   ```

##########
File path: src/meta_schedule/schedule_rule/random_compute_location.cc
##########
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+class RandomComputeLocationNode : public ScheduleRuleNode {
+ public:
+  // Inherited from ScheduleRuleNode
+  void InitializeWithTuneContext(const TuneContext& context) final {}
+
+  // Inherited from ScheduleRuleNode
+  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
+    if (!CheckConditions(sch, block_rv)) {
+      return {sch};
+    }
+
+    // Step 1. If the producer of the input block needs a random compute-at location (specified by
+    // the annotation), we colect the producer first, and transform the producer block later.
+    // - The reason we collect the producer before transforming the input block is that, if the
+    // decision of Sample-Compute-Location is "compute-inline" for the input block, we can no longer
+    // access the input block. Hence we collect its producer ahead of time.
+    // - Note that only single producer is allowed in this case.
+    Array<tir::BlockRV> producers{nullptr};
+    if (tir::HasAnn(sch->GetSRef(block_rv), tir::attr::meta_schedule_random_compute_producer,
+                    true)) {
+      producers = sch->GetProducers(block_rv);
+      sch->Unannotate(block_rv, tir::attr::meta_schedule_random_compute_producer);
+      ICHECK_EQ(producers.size(), 1);
+    }
+
+    // Step 2. Transform the input block.
+    tir::Schedule res = RandomlyComputeAt(sch, block_rv);
+
+    // Step 3. Transform the producer block if compute-location sampling is needed.
+    if (producers.defined()) {
+      res = RandomlyComputeAt(res, producers[0]);
+    }

Review comment:
       To make sure I understand correctly, it means `Compute-At` could potentially happen twice (when it comes with annotation): first compute the producer onto this block, then move this block somewhere to one of its consumers.
   
   Is there any corner case we potentially want to check carefully and disallow? For example, do we allow inlining the producer?

##########
File path: src/tir/schedule/analysis/analysis.cc
##########
@@ -646,6 +646,152 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr
   }
 }
 
+IterVarType GetLoopIterType(const StmtSRef& loop_sref) {
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  const Var& loop_var = loop->loop_var;
+  int n_spatial = 0;
+  int n_reduce = 0;
+  int n_other = 0;
+  auto f_visit = [&loop_var, &n_spatial, &n_reduce, &n_other](const ObjectRef& obj) -> bool {
+    if (const auto* realize = obj.as<BlockRealizeNode>()) {
+      const BlockNode* block = realize->block.get();
+      // Number of block vars and their bindings
+      ICHECK_EQ(realize->iter_values.size(), block->iter_vars.size());
+      size_t n = realize->iter_values.size();
+      for (size_t i = 0; i < n; ++i) {
+        const IterVar& iter_var = block->iter_vars[i];
+        const PrimExpr& binding = realize->iter_values[i];
+        // Categorize the current block var
+        int* ref = nullptr;
+        if (iter_var->iter_type == IterVarType::kDataPar) {
+          ref = &n_spatial;
+        } else if (iter_var->iter_type == IterVarType::kCommReduce) {
+          ref = &n_reduce;
+        } else {
+          ref = &n_other;
+        }
+        // Visit the binding to see if `loop_var` appears
+        PostOrderVisit(binding, [&ref, &loop_var](const ObjectRef& obj) -> void {
+          if (obj.same_as(loop_var)) {
+            (*ref) += 1;
+          }
+        });
+      }
+      return false;
+    }
+    return true;
+  };
+  PreOrderVisit(loop->body, f_visit);
+  if (n_other) {
+    return IterVarType::kOpaque;
+  } else if (n_spatial && n_reduce) {
+    return IterVarType::kOpaque;
+  } else if (n_reduce) {
+    return IterVarType::kCommReduce;
+  } else {
+    return IterVarType::kDataPar;
+  }
+}
+
+StmtSRef GetSRefLowestCommonAncestor(const Array<StmtSRef>& srefs) {
+  CHECK(!srefs.empty()) << "ValueError: The input array is required to have at least one sref";
+
+  std::unordered_map<const StmtSRefNode*, size_t> sref_visited_cnt;
+  for (const StmtSRef& sref : srefs) {
+    const StmtSRefNode* p = sref.get();
+    while (p != nullptr) {
+      ++sref_visited_cnt[p];
+      p = p->parent;
+    }
+  }
+  size_t n_sref = srefs.size();
+  const StmtSRefNode* p = srefs[0].get();
+  while (p != nullptr && sref_visited_cnt[p] != n_sref) {
+    p = p->parent;
+  }
+  ICHECK(p != nullptr);
+  return GetRef<StmtSRef>(p);
+}
+
+std::pair<Array<StmtSRef>, std::vector<int>> CollectComputeLocation(const ScheduleState& self,
+                                                                    const StmtSRef& block_sref) {
+  Array<StmtSRef> location_srefs;
+  std::vector<int> location_indices;
+
+  // Step 1. Add the "compute-root" candidate. Add the "compute-inline" candidate if the block can
+  // be inlined.
+  if (CanComputeInline(self, block_sref)) {
+    location_srefs.push_back(StmtSRef::InlineMark());
+    location_indices.push_back(-2);
+  }
+  location_srefs.push_back(StmtSRef::RootMark());
+  location_indices.push_back(-1);
+
+  // Step 2. If the block has no consumer, there is no more candidate.
+  Array<StmtSRef> consumers = GetConsumers(self, block_sref);
+  if (consumers.empty()) {
+    return std::make_pair(location_srefs, location_indices);
+  }
+
+  // Step 3. Get the deepest loop that the input block can be computed at (namely "boundary"). If
+  // such a loop cannot be found, there is no more candidate and we just return.
+  StmtSRef loop_boundary = consumers.size() > 1 ? GetSRefLowestCommonAncestor(consumers)
+                                                : GetRef<StmtSRef>(consumers[0]->parent);
+  if (loop_boundary->StmtAs<ForNode>() == nullptr) {
+    return std::make_pair(location_srefs, location_indices);
+  }
+
+  // Step 4. Collect the loops outside the first consumer and locate the boundary loop. The position
+  // of the boundary loop reveals the number of possible additional candidates.
+  Array<StmtSRef> loop_srefs = GetLoops(consumers[0]);
+  size_t lca_pos =
+      std::find(loop_srefs.begin(), loop_srefs.end(), loop_boundary) - loop_srefs.begin();
+  ICHECK_LT(lca_pos, loop_srefs.size());
+  size_t n_candidate = lca_pos + 1;
+
+  // Step 5. Find the position of the deepest data-parallel loop among the candidate loops. This
+  // position is used for removing the unwanted candidates from the perspective of performance.
+  std::vector<IterVarType> loop_iter_types;
+  loop_iter_types.reserve(n_candidate);
+  int i_last_datapar = -1;
+  for (size_t i = 0; i < n_candidate; ++i) {
+    IterVarType iter_type = GetLoopIterType(loop_srefs[i]);

Review comment:
       We might want to improve the performance of this snippet in the future, but it doesn't look like it's the bottleneck now :-)

##########
File path: src/meta_schedule/schedule_rule/random_compute_location.cc
##########
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+class RandomComputeLocationNode : public ScheduleRuleNode {
+ public:
+  // Inherited from ScheduleRuleNode
+  void InitializeWithTuneContext(const TuneContext& context) final {}
+
+  // Inherited from ScheduleRuleNode
+  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
+    if (!CheckConditions(sch, block_rv)) {
+      return {sch};
+    }
+
+    // Step 1. If the producer of the input block needs a random compute-at location (specified by
+    // the annotation), we colect the producer first, and transform the producer block later.
+    // - The reason we collect the producer before transforming the input block is that, if the
+    // decision of Sample-Compute-Location is "compute-inline" for the input block, we can no longer
+    // access the input block. Hence we collect its producer ahead of time.
+    // - Note that only single producer is allowed in this case.
+    Array<tir::BlockRV> producers{nullptr};
+    if (tir::HasAnn(sch->GetSRef(block_rv), tir::attr::meta_schedule_random_compute_producer,
+                    true)) {
+      producers = sch->GetProducers(block_rv);
+      sch->Unannotate(block_rv, tir::attr::meta_schedule_random_compute_producer);
+      ICHECK_EQ(producers.size(), 1);
+    }
+
+    // Step 2. Transform the input block.
+    tir::Schedule res = RandomlyComputeAt(sch, block_rv);
+
+    // Step 3. Transform the producer block if compute-location sampling is needed.
+    if (producers.defined()) {
+      res = RandomlyComputeAt(res, producers[0]);
+    }
+
+    return {res};
+  }
+
+ private:
+  bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) const {
+    const tir::StmtSRef& block_sref = sch->GetSRef(block_rv);
+    const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+
+    // Cond 1. The block is not the root block.
+    if (block_sref->parent == nullptr) {
+      return false;
+    }
+    // Cond 2. The block should be the direct child block of the root block.
+    if (GetScopeRoot(sch->state(), block_sref,          //
+                     /*require_stage_pipeline=*/false,  //
+                     /*require_subtree_compact_dataflow=*/false)
+            ->parent != nullptr) {
+      return false;
+    }
+    // Cond 3 & 4. The block has at least one outer loop, and the outermost loop has only one child
+    // block.
+    Array<tir::StmtSRef> loop_srefs = tir::GetLoops(block_sref);
+    if (loop_srefs.empty()) {
+      return false;
+    }
+    if (tir::GetChildBlockSRefOnSRefTree(sch->state(), loop_srefs[0]).size() > 1) {
+      return false;
+    }
+    // Cond 5. The block is not tiled. We check this condition by examine the block's annotation.
+    if (tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_tiling_structure).defined()) {
+      return false;
+    }
+    // Cond 6. The block has at lease one consumer.
+    if (tir::GetConsumers(sch->state(), sch->GetSRef(block_rv)).empty()) {
+      return false;
+    }
+    return true;
+  }
+
+  /*!
+   * \brief Keep sampling a compute-at location for the input block until success.
+   * \param sch The TIR schedule
+   * \param block_rv The block whose compute-at location is to be sampled
+   * \return The TIR schedule after transformation
+   */
+  tir::Schedule RandomlyComputeAt(const tir::Schedule& sch, const tir::BlockRV& block_rv) {
+    for (;;) {
+      tir::LoopRV compute_at_loc = sch->SampleComputeLocation(block_rv);
+      try {
+        sch->ComputeAt(block_rv, compute_at_loc, true);
+      } catch (const dmlc::Error& e) {
+        // ComputeAt fails, cleanup the following before re-try:
+        // 1) trace: instruction & decisions
+        // 2) sym_tab
+        sch->trace().value()->Pop();
+        sch->RemoveRV(compute_at_loc);
+        continue;
+      }
+      break;
+    }

Review comment:
       The try-catch loop here is not desirable. Shall we use `CanComputeAt` in `Sample-Compute-Location` to make sure every outcome of it works?

##########
File path: python/tvm/meta_schedule/schedule_rule/__init__.py
##########
@@ -16,4 +16,6 @@
 Meta Schedule schedule rules are used for modification of
 blocks in a schedule. See also PostOrderApply.
 """
+

Review comment:
       No need for this blank line, I suppose.

##########
File path: src/meta_schedule/schedule_rule/random_compute_location.cc
##########
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+class RandomComputeLocationNode : public ScheduleRuleNode {
+ public:
+  // Inherited from ScheduleRuleNode
+  void InitializeWithTuneContext(const TuneContext& context) final {}
+
+  // Inherited from ScheduleRuleNode
+  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
+    if (!CheckConditions(sch, block_rv)) {
+      return {sch};
+    }
+
+    // Step 1. If the producer of the input block needs a random compute-at location (specified by
+    // the annotation), we colect the producer first, and transform the producer block later.
+    // - The reason we collect the producer before transforming the input block is that, if the
+    // decision of Sample-Compute-Location is "compute-inline" for the input block, we can no longer
+    // access the input block. Hence we collect its producer ahead of time.
+    // - Note that only single producer is allowed in this case.
+    Array<tir::BlockRV> producers{nullptr};
+    if (tir::HasAnn(sch->GetSRef(block_rv), tir::attr::meta_schedule_random_compute_producer,
+                    true)) {
+      producers = sch->GetProducers(block_rv);
+      sch->Unannotate(block_rv, tir::attr::meta_schedule_random_compute_producer);
+      ICHECK_EQ(producers.size(), 1);
+    }
+
+    // Step 2. Transform the input block.
+    tir::Schedule res = RandomlyComputeAt(sch, block_rv);
+
+    // Step 3. Transform the producer block if compute-location sampling is needed.
+    if (producers.defined()) {
+      res = RandomlyComputeAt(res, producers[0]);
+    }
+
+    return {res};
+  }
+
+ private:
+  bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) const {
+    const tir::StmtSRef& block_sref = sch->GetSRef(block_rv);

Review comment:
       No need to use a reference here:
   
   ```suggestion
       tir::StmtSRef block_sref = sch->GetSRef(block_rv);
   ```

##########
File path: src/meta_schedule/schedule_rule/random_compute_location.cc
##########
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+class RandomComputeLocationNode : public ScheduleRuleNode {

Review comment:
       nit: add a blank line above







[GitHub] [tvm] MasterJH5574 commented on a change in pull request #9940: [MetaSchedule] random compute location

Posted by GitBox <gi...@apache.org>.
MasterJH5574 commented on a change in pull request #9940:
URL: https://github.com/apache/tvm/pull/9940#discussion_r785383712



##########
File path: src/meta_schedule/schedule_rule/random_compute_location.cc
##########
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+class RandomComputeLocationNode : public ScheduleRuleNode {
+ public:
+  // Inherited from ScheduleRuleNode
+  void InitializeWithTuneContext(const TuneContext& context) final {}
+
+  // Inherited from ScheduleRuleNode
+  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
+    if (!CheckConditions(sch, block_rv)) {
+      return {sch};
+    }
+
+    // Step 1. If the producer of the input block needs a random compute-at location (specified by
+    // the annotation), we colect the producer first, and transform the producer block later.
+    // - The reason we collect the producer before transforming the input block is that, if the
+    // decision of Sample-Compute-Location is "compute-inline" for the input block, we can no longer
+    // access the input block. Hence we collect its producer ahead of time.
+    // - Note that only single producer is allowed in this case.
+    Array<tir::BlockRV> producers{nullptr};
+    if (tir::HasAnn(sch->GetSRef(block_rv), tir::attr::meta_schedule_random_compute_producer,
+                    true)) {
+      producers = sch->GetProducers(block_rv);
+      sch->Unannotate(block_rv, tir::attr::meta_schedule_random_compute_producer);
+      ICHECK_EQ(producers.size(), 1);
+    }
+
+    // Step 2. Transform the input block.
+    tir::Schedule res = RandomlyComputeAt(sch, block_rv);
+
+    // Step 3. Transform the producer block if compute-location sampling is needed.
+    if (producers.defined()) {
+      res = RandomlyComputeAt(res, producers[0]);
+    }
+
+    return {res};
+  }
+
+ private:
+  bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) const {
+    const tir::StmtSRef& block_sref = sch->GetSRef(block_rv);
+    const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+
+    // Cond 1. The block is not the root block.
+    if (block_sref->parent == nullptr) {
+      return false;
+    }
+    // Cond 2. The block should be the direct child block of the root block.
+    if (GetScopeRoot(sch->state(), block_sref,          //
+                     /*require_stage_pipeline=*/false,  //
+                     /*require_subtree_compact_dataflow=*/false)
+            ->parent != nullptr) {
+      return false;
+    }
+    // Cond 3 & 4. The block has at least one outer loop, and the outermost loop has only one child
+    // block.
+    Array<tir::StmtSRef> loop_srefs = tir::GetLoops(block_sref);
+    if (loop_srefs.empty()) {
+      return false;
+    }
+    if (tir::GetChildBlockSRefOnSRefTree(sch->state(), loop_srefs[0]).size() > 1) {
+      return false;
+    }
+    // Cond 5. The block is not tiled. We check this condition by examine the block's annotation.
+    if (tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_tiling_structure).defined()) {
+      return false;
+    }

Review comment:
       You can add an analysis function named "HasBeenMultiLevelTiled" like this and use the function here, since the rule CrossThreadReduction also needs this analysis function.
   
   https://github.com/junrushao1994/tvm/blob/meta-schedule/src/tir/schedule/analysis/analysis.cc#L1930-L1932
   
   https://github.com/junrushao1994/tvm/blob/meta-schedule/src/meta_schedule/schedule_rule/random_compute_location.cc#L50-L53
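   
   For reference, a rough sketch of what such a helper could look like, based on the annotation check in Cond 5 above (the name is taken from this comment; the actual implementation in the linked branch may differ):
   
   ```cpp
   // Hedged sketch: returns true if the block already carries the
   // meta_schedule_tiling_structure annotation, i.e. it has been tiled.
   bool HasBeenMultiLevelTiled(const tir::StmtSRef& block_sref) {
     return tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_tiling_structure).defined();
   }
   ```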







[GitHub] [tvm] Hzfengsy commented on a change in pull request #9940: [MetaSchedule] random compute location

Posted by GitBox <gi...@apache.org>.
Hzfengsy commented on a change in pull request #9940:
URL: https://github.com/apache/tvm/pull/9940#discussion_r785432527



##########
File path: src/meta_schedule/schedule_rule/random_compute_location.cc
##########
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+class RandomComputeLocationNode : public ScheduleRuleNode {
+ public:
+  // Inherited from ScheduleRuleNode
+  void InitializeWithTuneContext(const TuneContext& context) final {}
+
+  // Inherited from ScheduleRuleNode
+  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
+    if (!CheckConditions(sch, block_rv)) {
+      return {sch};
+    }
+
+    // Step 1. If the producer of the input block needs a random compute-at location (specified by
+    // the annotation), we colect the producer first, and transform the producer block later.
+    // - The reason we collect the producer before transforming the input block is that, if the
+    // decision of Sample-Compute-Location is "compute-inline" for the input block, we can no longer
+    // access the input block. Hence we collect its producer ahead of time.
+    // - Note that only single producer is allowed in this case.
+    Array<tir::BlockRV> producers{nullptr};
+    if (tir::HasAnn(sch->GetSRef(block_rv), tir::attr::meta_schedule_random_compute_producer,
+                    true)) {
+      producers = sch->GetProducers(block_rv);
+      sch->Unannotate(block_rv, tir::attr::meta_schedule_random_compute_producer);
+      ICHECK_EQ(producers.size(), 1);
+    }
+
+    // Step 2. Transform the input block.
+    tir::Schedule res = RandomlyComputeAt(sch, block_rv);
+
+    // Step 3. Transform the producer block if compute-location sampling is needed.
+    if (producers.defined()) {
+      res = RandomlyComputeAt(res, producers[0]);
+    }
+
+    return {res};
+  }
+
+ private:
+  bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) const {
+    const tir::StmtSRef& block_sref = sch->GetSRef(block_rv);
+    const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+
+    // Cond 1. The block is not the root block.
+    if (block_sref->parent == nullptr) {
+      return false;
+    }
+    // Cond 2. The block should be the direct child block of the root block.
+    if (GetScopeRoot(sch->state(), block_sref,          //
+                     /*require_stage_pipeline=*/false,  //
+                     /*require_subtree_compact_dataflow=*/false)
+            ->parent != nullptr) {
+      return false;
+    }
+    // Cond 3 & 4. The block has at least one outer loop, and the outermost loop has only one child
+    // block.
+    Array<tir::StmtSRef> loop_srefs = tir::GetLoops(block_sref);
+    if (loop_srefs.empty()) {
+      return false;
+    }
+    if (tir::GetChildBlockSRefOnSRefTree(sch->state(), loop_srefs[0]).size() > 1) {
+      return false;
+    }
+    // Cond 5. The block is not tiled. We check this condition by examine the block's annotation.
+    if (tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_tiling_structure).defined()) {
+      return false;
+    }
+    // Cond 6. The block has at lease one consumer.
+    if (tir::GetConsumers(sch->state(), sch->GetSRef(block_rv)).empty()) {
+      return false;
+    }
+    return true;
+  }
+
+  /*!
+   * \brief Keep sampling a compute-at location for the input block until success.
+   * \param sch The TIR schedule
+   * \param block_rv The block whose compute-at location is to be sampled
+   * \return The TIR schedule after transformation
+   */
+  tir::Schedule RandomlyComputeAt(const tir::Schedule& sch, const tir::BlockRV& block_rv) {
+    for (;;) {
+      tir::LoopRV compute_at_loc = sch->SampleComputeLocation(block_rv);
+      try {
+        sch->ComputeAt(block_rv, compute_at_loc, true);
+      } catch (const dmlc::Error& e) {
+        // ComputeAt fails, cleanup the following before re-try:
+        // 1) trace: instruction & decisions
+        // 2) sym_tab
+        sch->trace().value()->Pop();
+        sch->RemoveRV(compute_at_loc);
+        continue;
+      }
+      break;
+    }

Review comment:
       Thanks for pointing this out. It would be a good solution for clearer code logic; however, it is not a good answer for performance.
   
   `CanComputeAt` is itself implemented with try-catch, which means it contains both the checking stage and the mutation stage:
   https://github.com/apache/tvm/blob/ecc2e563df1a0b1d7e9d712bce90ee94948c3848/src/tir/schedule/primitive/compute_at.cc#L537-L547
   
   Calling `CanComputeAt` to get validated samples would therefore make each accepted sample run the compute-at logic twice.
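   
   To illustrate the concern, below is a rough sketch (not actual TVM code; the `CanComputeAt` signature is assumed from the linked snippet and may not match exactly) of what a pre-checked variant of `RandomlyComputeAt` would look like. Because `CanComputeAt` runs the full compute-at logic internally just to answer yes/no, every accepted sample would go through that logic twice: once for the check and once for the real mutation.
   
   ```cpp
   // Hedged sketch only; assumes the same includes as the file under review.
   // Special locations (root / inline marks) would need separate handling,
   // which is omitted here.
   tir::Schedule PreCheckedRandomlyComputeAt(const tir::Schedule& sch, const tir::BlockRV& block_rv) {
     for (;;) {
       tir::LoopRV compute_at_loc = sch->SampleComputeLocation(block_rv);
       // Pass 1: validation only; internally this runs the compute-at logic
       // and swallows the failure.
       if (!tir::CanComputeAt(sch->state(), sch->GetSRef(block_rv), sch->GetSRef(compute_at_loc),
                              /*preserve_unit_loops=*/true)) {
         sch->trace().value()->Pop();
         sch->RemoveRV(compute_at_loc);
         continue;
       }
       // Pass 2: the real transformation, running the same logic again.
       sch->ComputeAt(block_rv, compute_at_loc, /*preserve_unit_loops=*/true);
       break;
     }
     return sch;
   }
   ```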



