Posted to commits@tvm.apache.org by "zxybazh (via GitHub)" <gi...@apache.org> on 2023/03/02 23:50:16 UTC

[GitHub] [tvm] zxybazh opened a new pull request, #14182: [Unity] Introduce Default Schedule Pass

zxybazh opened a new pull request, #14182:
URL: https://github.com/apache/tvm/pull/14182

   This PR introduces a default schedule pass that creates default thread bindings for PrimFuncs, including symbolic-shape functions, so that these PrimFuncs can be built and run on a CUDA device. The pass picks out all the blocks inside each PrimFunc and performs loop fusion, splitting, and reordering based on the loop extents and the target information (maximum number of thread blocks and maximum number of threads per block).
   
   Co-authored-by: Josh Fromm  <<j...@octoml.ai>>
   CC: @sunggg @vinx13 @tqchen 
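
   For illustration, here is a rough Python sketch of that logic written against the public `tir.Schedule` API on a made-up element-wise PrimFunc; the 256 thread-block cap mirrors the constant used in this PR, and 1024 is what the default `cuda` target reports as `max_num_threads`:
   
   ```
   import tvm
   from tvm import tir
   from tvm.script import tir as T
   
   @T.prim_func
   def add_one(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32")):
       for i, j in T.grid(1024, 1024):
           with T.block("add_one"):
               vi, vj = T.axis.remap("SS", [i, j])
               B[vi, vj] = A[vi, vj] + 1.0
   
   sch = tir.Schedule(add_one)
   block = sch.get_block("add_one")
   # fuse all data-parallel loops of the block
   fused = sch.fuse(*sch.get_loops(block))
   
   max_threadblocks = 256        # constant used by this pass
   max_threads_per_block = 1024  # what Target("cuda") reports as max_num_threads
   
   extent = sch.get(fused).extent
   if isinstance(extent, tir.IntImm) and extent.value <= max_threadblocks * max_threads_per_block:
       # small static extent: one split, bound to blockIdx.x / threadIdx.x
       bx, tx = sch.split(fused, factors=[None, min(extent.value, max_threads_per_block)])
       sch.bind(bx, "blockIdx.x")
       sch.bind(tx, "threadIdx.x")
   else:
       # symbolic or very large extent: keep a serial leftover loop innermost
       leftover, bx, tx = sch.split(fused, factors=[None, max_threadblocks, max_threads_per_block])
       sch.reorder(bx, tx, leftover)
       sch.bind(bx, "blockIdx.x")
       sch.bind(tx, "threadIdx.x")
   
   print(sch.mod.script())
   ```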


[GitHub] [tvm] jwfromm commented on a diff in pull request #14182: [Unity] Introduce Default Schedule Pass

Posted by "jwfromm (via GitHub)" <gi...@apache.org>.
jwfromm commented on code in PR #14182:
URL: https://github.com/apache/tvm/pull/14182#discussion_r1125113839


##########
include/tvm/tir/transform.h:
##########
@@ -709,6 +709,12 @@ TVM_DLL Pass ManifestSharedMemoryLocalStage();
  */
 TVM_DLL Pass InstrumentProfileIntrinsics();
 
+/*!
+ * \brief Create default schedule for PrimFuncs to run on cuda device.
+ * \return The Pass.
+ */
+TVM_DLL Pass DefaultSchedule();

Review Comment:
   I do think that would be a better name; `DefaultSchedule` makes it sound like it would be useful for CPU too.



[GitHub] [tvm] tqchen commented on a diff in pull request #14182: [Unity] Introduce Default Schedule Pass

Posted by "tqchen (via GitHub)" <gi...@apache.org>.
tqchen commented on code in PR #14182:
URL: https://github.com/apache/tvm/pull/14182#discussion_r1125840522


##########
include/tvm/tir/transform.h:
##########
@@ -709,6 +709,12 @@ TVM_DLL Pass ManifestSharedMemoryLocalStage();
  */
 TVM_DLL Pass InstrumentProfileIntrinsics();
 
+/*!
+ * \brief Create default schedule for PrimFuncs to run on cuda device.
+ * \return The Pass.
+ */
+TVM_DLL Pass DefaultSchedule();

Review Comment:
   I agree, `DefaultGPUSchedule` seems to be a good name here.



[GitHub] [tvm] zxybazh commented on a diff in pull request #14182: [Unity] Introduce Default Schedule Pass

Posted by "zxybazh (via GitHub)" <gi...@apache.org>.
zxybazh commented on code in PR #14182:
URL: https://github.com/apache/tvm/pull/14182#discussion_r1125870096


##########
include/tvm/tir/transform.h:
##########
@@ -709,6 +709,12 @@ TVM_DLL Pass ManifestSharedMemoryLocalStage();
  */
 TVM_DLL Pass InstrumentProfileIntrinsics();
 
+/*!
+ * \brief Create default schedule for PrimFuncs to run on cuda device.
+ * \return The Pass.
+ */
+TVM_DLL Pass DefaultSchedule();

Review Comment:
   It appears that there is consensus among the contributors that the current name `DefaultSchedule` is not clear enough, and many prefer the new name `DefaultGPUSchedule` as it allows for potential expansion to other GPU targets in the future. As a result, I have renamed it to `DefaultGPUSchedule` for the time being, but I am still open to other suggestions for naming.



[GitHub] [tvm] jwfromm commented on a diff in pull request #14182: [Unity] Introduce Default Schedule Pass

Posted by "jwfromm (via GitHub)" <gi...@apache.org>.
jwfromm commented on code in PR #14182:
URL: https://github.com/apache/tvm/pull/14182#discussion_r1125108824


##########
include/tvm/tir/transform.h:
##########
@@ -709,6 +709,12 @@ TVM_DLL Pass ManifestSharedMemoryLocalStage();
  */
 TVM_DLL Pass InstrumentProfileIntrinsics();
 
+/*!
+ * \brief Create default schedule for PrimFuncs to run on cuda device.
+ * \return The Pass.
+ */
+TVM_DLL Pass DefaultSchedule();

Review Comment:
   Should we call it `DefaultCudaSchedule` or `DefaultGPUSchedule` to indicate it's specifically for thread bindings?



[GitHub] [tvm] jwfromm commented on pull request #14182: [Unity] Introduce Default GPU Schedule Pass

Posted by "jwfromm (via GitHub)" <gi...@apache.org>.
jwfromm commented on PR #14182:
URL: https://github.com/apache/tvm/pull/14182#issuecomment-1456706737

   Thanks for this excellent PR @zxybazh! This is now merged.


[GitHub] [tvm] zxybazh commented on a diff in pull request #14182: [Unity] Introduce Default Schedule Pass

Posted by "zxybazh (via GitHub)" <gi...@apache.org>.
zxybazh commented on code in PR #14182:
URL: https://github.com/apache/tvm/pull/14182#discussion_r1125110479


##########
include/tvm/tir/transform.h:
##########
@@ -709,6 +709,12 @@ TVM_DLL Pass ManifestSharedMemoryLocalStage();
  */
 TVM_DLL Pass InstrumentProfileIntrinsics();
 
+/*!
+ * \brief Create default schedule for PrimFuncs to run on cuda device.
+ * \return The Pass.
+ */
+TVM_DLL Pass DefaultSchedule();

Review Comment:
   I'm open to naming options; `DefaultCudaSchedule` sounds like a better version given it's only targeting CUDA for now.



[GitHub] [tvm] psrivas2 commented on a diff in pull request #14182: [Unity] Introduce Default Schedule Pass

Posted by "psrivas2 (via GitHub)" <gi...@apache.org>.
psrivas2 commented on code in PR #14182:
URL: https://github.com/apache/tvm/pull/14182#discussion_r1124542691


##########
src/relax/transform/default_schedule.cc:
##########
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/schedule/schedule.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "../../meta_schedule/utils.h"
+#include "../../relay/analysis/graph_partitioner.h"
+#include "../../support/arena.h"
+#include "../../tir/ir/functor_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief The helper class to schedule functions and build a new module which calls the new TIR
+ * function.
+ */
+class ThreadBindMutator : public ExprMutator {
+ public:
+  static IRModule Transform(const IRModule& mod, int64_t max_thread_per_block) {
+    ThreadBindMutator mutator(mod);
+
+    for (const auto& kv : mod->functions) {

Review Comment:
   nit: we can use `const auto& [gv, func] : ...` to avoid the `kv.first` and `kv.second` accesses



##########
src/relax/transform/default_schedule.cc:
##########
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/schedule/schedule.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "../../meta_schedule/utils.h"
+#include "../../relay/analysis/graph_partitioner.h"
+#include "../../support/arena.h"
+#include "../../tir/ir/functor_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief The helper class to schedule functions and build a new module which calls the new TIR
+ * function.
+ */
+class ThreadBindMutator : public ExprMutator {
+ public:
+  static IRModule Transform(const IRModule& mod, int64_t max_thread_per_block) {
+    ThreadBindMutator mutator(mod);
+
+    for (const auto& kv : mod->functions) {
+      const GlobalVar& gv = kv.first;
+      const BaseFunc& func = kv.second;
+
+      if (func->IsInstance<tir::PrimFuncNode>()) {
+        IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({kv}));
+        tir::Schedule sch = tir::Schedule::Traced(mod, /*seed=*/-1, /*debug_mask=*/0,
+                                                  tir::ScheduleErrorRenderLevel::kDetail);
+        Array<tir::BlockRV> blocks = meta_schedule::BlockCollector::Collect(sch);
+        for (const tir::BlockRV& block : blocks) {
+          // fetch the loops
+          Array<tir::LoopRV> loops = sch->GetLoops(block);
+          bool scheduled = false;
+          for (const tir::LoopRV& loop : loops) {
+            if (sch->Get(loop)->thread_binding.defined()) {
+              scheduled = true;
+              break;
+            }
+          }
+          // skip if already scheduled
+          if (scheduled) {
+            continue;
+          }
+          Array<tir::IterVar> iters = sch->Get(block)->iter_vars;
+          ICHECK_EQ(loops.size(), iters.size());
+          Array<tir::LoopRV> data_parallel_loops;
+          // only fuse data parallel loops
+          for (size_t i = 0; i < loops.size(); ++i) {
+            if (iters[i]->iter_type == tir::IterVarType::kDataPar) {
+              data_parallel_loops.push_back(loops[i]);
+            }
+          }
+          if (data_parallel_loops.size() == 0) {
+            continue;
+          }
+          // fuse all data parallel loops
+          tir::LoopRV fused = sch->Fuse(data_parallel_loops, /*preserve_unit_iters=*/false);
+          int64_t product = std::numeric_limits<int64_t>::max();
+          if (sch->Get(fused)->extent->IsInstance<tir::IntImmNode>()) {
+            product = sch->Get(fused)->extent.as<tir::IntImmNode>()->value;
+          }
+          static const int64_t max_threadblocks = 256;
+          // schedule the fused loop
+          if (product > max_thread_per_block * max_threadblocks) {
+            Array<tir::LoopRV> splits = sch->Split(
+                fused,
+                /*factors=*/{NullOpt, Integer(max_threadblocks), Integer(max_thread_per_block)});
+            sch->Reorder(/*ordered_loop_rvs=*/{splits[1], splits[2], splits[0]});
+            sch->Bind(splits[1], "blockIdx.x");
+            sch->Bind(splits[2], "threadIdx.x");
+          } else {
+            Array<tir::LoopRV> splits = sch->Split(
+                fused, /*factors=*/{NullOpt, Integer(std::min(product, max_thread_per_block))});
+            sch->Bind(splits[0], "blockIdx.x");
+            sch->Bind(splits[1], "threadIdx.x");
+          }
+        }
+        mutator.builder_->AddFunction(sch->mod()->Lookup(gv->name_hint), gv->name_hint);
+      } else {
+        mutator.builder_->AddFunction(func, gv->name_hint);
+      }
+    }
+    return mutator.builder_->GetContextIRModule();
+  }
+
+ private:
+  explicit ThreadBindMutator(const IRModule& mod) : mod_(mod) {}

Review Comment:
   Calling the `ExprMutator` constructor with the module, i.e. `ExprMutator(mod)`, here would populate the builder context IRModule with the old IRModule. So we won't have to call `builder_->AddFunction` (line 102) for functions that are not modified at all, and for the others we can use `builder_->UpdateFunction` (line 100).



##########
src/relax/transform/default_schedule.cc:
##########
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/schedule/schedule.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "../../meta_schedule/utils.h"
+#include "../../relay/analysis/graph_partitioner.h"
+#include "../../support/arena.h"
+#include "../../tir/ir/functor_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief The helper class to schedule functions and build a new module which calls the new TIR
+ * function.
+ */
+class ThreadBindMutator : public ExprMutator {

Review Comment:
   +1, we can skip using `ExprMutator` for this and directly use the Schedule APIs to modify the PrimFunc, then return the module with the updated PrimFunc.



[GitHub] [tvm] MasterJH5574 commented on a diff in pull request #14182: [Unity] Introduce Default Schedule Pass

Posted by "MasterJH5574 (via GitHub)" <gi...@apache.org>.
MasterJH5574 commented on code in PR #14182:
URL: https://github.com/apache/tvm/pull/14182#discussion_r1125350328


##########
python/tvm/tir/transform/transform.py:
##########
@@ -1040,3 +1040,13 @@ def InstallDebugSpans():
         The result pass
     """
     return _ffi_api.InstallDebugSpans()  # type: ignore
+
+
+def DefaultSchedule():
+    """Default schedule for PrimFuncs.

Review Comment:
   I’m thinking about performance. If we don’t intend to guarantee decent performance, we should probably state in the documentation that this pass only intends to make the IRModule buildable and makes no guarantee about performance.
   
   BTW, could we give the pass more documentation? The PR description would serve very well as the docstring.
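
   For instance, something along these lines in `python/tvm/tir/transform/transform.py`, following the pattern of the neighbouring wrappers shown in the diff above (`_ffi_api` is the module-local FFI namespace those wrappers already use; this wording is only a suggestion, not the docstring that ended up being committed):
   
   ```
   def DefaultSchedule():
       """Set default thread bindings for unscheduled PrimFuncs so that the
       IRModule can be built and run on a GPU device.
   
       For every block, the pass fuses its data-parallel loops, splits the
       fused loop according to the target's maximum number of threads per
       block (with a cap on the number of thread blocks), and binds the
       resulting loops to ``blockIdx.x`` / ``threadIdx.x``. It only aims to
       make the module buildable and gives no guarantee about performance.
   
       Returns
       -------
       fpass : tvm.transform.Pass
           The result pass
       """
       return _ffi_api.DefaultSchedule()  # type: ignore
   ```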



##########
src/tir/transforms/default_schedule.cc:
##########
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/tir/schedule/schedule.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "../../meta_schedule/utils.h"
+#include "../../relay/analysis/graph_partitioner.h"
+#include "../../support/arena.h"
+#include "../../tir/ir/functor_common.h"

Review Comment:
   These headers are redundant. They seem to come from the FuseOps pass.



[GitHub] [tvm] tqchen commented on pull request #14182: [Unity] Introduce Default Schedule Pass

Posted by "tqchen (via GitHub)" <gi...@apache.org>.
tqchen commented on PR #14182:
URL: https://github.com/apache/tvm/pull/14182#issuecomment-1455100927

   
   
   please run the following command to update to the latest changes
   ```
   git rebase --onto upstream/unity upstream/unity-rebase-backup-2023-03-05 
   ```
   


[GitHub] [tvm] yongwww commented on a diff in pull request #14182: [Unity] Introduce Default Schedule Pass

Posted by "yongwww (via GitHub)" <gi...@apache.org>.
yongwww commented on code in PR #14182:
URL: https://github.com/apache/tvm/pull/14182#discussion_r1124882665


##########
src/relax/transform/default_schedule.cc:
##########
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/schedule/schedule.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "../../meta_schedule/utils.h"
+#include "../../relay/analysis/graph_partitioner.h"
+#include "../../support/arena.h"
+#include "../../tir/ir/functor_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief The helper class to schedule functions and build a new module which calls the new TIR
+ * function.
+ */
+class ThreadBindMutator : public ExprMutator {
+ public:
+  static IRModule Transform(const IRModule& mod, int64_t max_thread_per_block) {
+    ThreadBindMutator mutator(mod);
+
+    for (const auto& kv : mod->functions) {
+      const GlobalVar& gv = kv.first;
+      const BaseFunc& func = kv.second;
+
+      if (func->IsInstance<tir::PrimFuncNode>()) {
+        IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({kv}));
+        tir::Schedule sch = tir::Schedule::Traced(mod, /*seed=*/-1, /*debug_mask=*/0,
+                                                  tir::ScheduleErrorRenderLevel::kDetail);
+        Array<tir::BlockRV> blocks = meta_schedule::BlockCollector::Collect(sch);
+        for (const tir::BlockRV& block : blocks) {
+          // fetch the loops
+          Array<tir::LoopRV> loops = sch->GetLoops(block);
+          bool scheduled = false;
+          for (const tir::LoopRV& loop : loops) {
+            if (sch->Get(loop)->thread_binding.defined()) {
+              scheduled = true;
+              break;
+            }
+          }
+          // skip if already scheduled
+          if (scheduled) {
+            continue;
+          }
+          Array<tir::IterVar> iters = sch->Get(block)->iter_vars;
+          ICHECK_EQ(loops.size(), iters.size());
+          Array<tir::LoopRV> data_parallel_loops;
+          // only fuse data parallel loops
+          for (size_t i = 0; i < loops.size(); ++i) {
+            if (iters[i]->iter_type == tir::IterVarType::kDataPar) {
+              data_parallel_loops.push_back(loops[i]);
+            }
+          }
+          if (data_parallel_loops.size() == 0) {
+            continue;
+          }
+          // fuse all data parallel loops
+          tir::LoopRV fused = sch->Fuse(data_parallel_loops, /*preserve_unit_iters=*/false);
+          int64_t product = std::numeric_limits<int64_t>::max();
+          if (sch->Get(fused)->extent->IsInstance<tir::IntImmNode>()) {
+            product = sch->Get(fused)->extent.as<tir::IntImmNode>()->value;
+          }
+          static const int64_t max_threadblocks = 256;
+          // schedule the fused loop
+          if (product > max_thread_per_block * max_threadblocks) {
+            Array<tir::LoopRV> splits = sch->Split(
+                fused,
+                /*factors=*/{NullOpt, Integer(max_threadblocks), Integer(max_thread_per_block)});
+            sch->Reorder(/*ordered_loop_rvs=*/{splits[1], splits[2], splits[0]});
+            sch->Bind(splits[1], "blockIdx.x");
+            sch->Bind(splits[2], "threadIdx.x");
+          } else {
+            Array<tir::LoopRV> splits = sch->Split(
+                fused, /*factors=*/{NullOpt, Integer(std::min(product, max_thread_per_block))});
+            sch->Bind(splits[0], "blockIdx.x");
+            sch->Bind(splits[1], "threadIdx.x");
+          }
+        }
+        mutator.builder_->AddFunction(sch->mod()->Lookup(gv->name_hint), gv->name_hint);
+      } else {
+        mutator.builder_->AddFunction(func, gv->name_hint);
+      }
+    }
+    return mutator.builder_->GetContextIRModule();
+  }
+
+ private:
+  explicit ThreadBindMutator(const IRModule& mod) : mod_(mod) {}
+
+ private:
+  /*! \brief The IRModule */
+  const IRModule& mod_;
+};
+
+IRModule DefaultSchedule(IRModule mod, int64_t max_thread_per_block) {
+  mod = ThreadBindMutator::Transform(mod, max_thread_per_block);
+  return mod;
+}
+
+namespace transform {
+
+Pass DefaultSchedule() {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =  //
+      [=](IRModule m, PassContext pc) {
+        tvm::Target target = tvm::Target::Current();
+        Integer max_thread_per_block = target->GetAttr<Integer>("max_num_threads").value_or(-1);
+        if (target->kind->name != "cuda") {
+          ICHECK_NE(max_thread_per_block, -1) << "max_num_threads is not set for target " << target;
+          return m;
+        }
+        return relax::DefaultSchedule(m, max_thread_per_block.IntValue());
+      };
+  return CreateModulePass(/*pass_function=*/pass_func,      //
+                          /*opt_level=*/0,                  //
+                          /*pass_name=*/"DefaultSchedule",  //
+                          /*required=*/{});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.DefaultSchedule").set_body_typed(DefaultSchedule);

Review Comment:
   I feel it would be good to move it into `tir.transform.DefaultSchedule` (src/tir/transform), since it looks like only the PrimFunc is updated. A usage sketch is below.
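
   If it does end up under `tir.transform`, usage would presumably look something like the following (written against the `DefaultGPUSchedule` name the pass was eventually renamed to; the PrimFunc is a made-up placeholder):
   
   ```
   import tvm
   from tvm.script import tir as T
   
   @T.prim_func
   def copy(A: T.Buffer((256,), "float32"), B: T.Buffer((256,), "float32")):
       for i in T.serial(256):
           with T.block("copy"):
               vi = T.axis.spatial(256, i)
               B[vi] = A[vi]
   
   mod = tvm.IRModule({"copy": copy})
   # the pass reads the current target from the context scope
   with tvm.target.Target("cuda"):
       mod = tvm.tir.transform.DefaultGPUSchedule()(mod)
   print(mod.script())
   ```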



[GitHub] [tvm] zxybazh commented on a diff in pull request #14182: [Unity] Introduce Default Schedule Pass

Posted by "zxybazh (via GitHub)" <gi...@apache.org>.
zxybazh commented on code in PR #14182:
URL: https://github.com/apache/tvm/pull/14182#discussion_r1125859329


##########
python/tvm/tir/transform/transform.py:
##########
@@ -1040,3 +1040,13 @@ def InstallDebugSpans():
         The result pass
     """
     return _ffi_api.InstallDebugSpans()  # type: ignore
+
+
+def DefaultSchedule():
+    """Default schedule for PrimFuncs.

Review Comment:
   Thanks for the advice! Added more documentation based on the PR description and the intention of this PR.



[GitHub] [tvm] vinx13 commented on a diff in pull request #14182: [Unity] Introduce Default Schedule Pass

Posted by "vinx13 (via GitHub)" <gi...@apache.org>.
vinx13 commented on code in PR #14182:
URL: https://github.com/apache/tvm/pull/14182#discussion_r1123927489


##########
include/tvm/relax/transform.h:
##########
@@ -272,6 +273,12 @@ TVM_DLL Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions);
 TVM_DLL Pass RunCodegen(Optional<Map<String, Map<String, ObjectRef>>> target_options,
                         Array<runtime::String> entry_functions);
 
+/*!
+ * \brief Create default schedule for PrimFuncs.
+ * \return The Pass.
+ */
+TVM_DLL Pass DefaultSchedule(tvm::Target target);

Review Comment:
   shall we get the target using `Target::Current()` instead?
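
   For reference, the Python-side counterpart is `Target.current()`, which picks up whatever target was entered via a `with` scope:
   
   ```
   import tvm
   from tvm.target import Target
   
   with Target("cuda"):
       # inside a pass, the target entered via `with` is visible as the current target
       cur = Target.current()
       print(cur.kind.name)                 # cuda
       print(cur.attrs["max_num_threads"])  # 1024 for the default cuda target
   ```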



##########
src/relax/transform/default_schedule.cc:
##########
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/schedule/schedule.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "../../meta_schedule/utils.h"
+#include "../../relay/analysis/graph_partitioner.h"
+#include "../../support/arena.h"
+#include "../../tir/ir/functor_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief The helper class to schedule functions and build a new module which calls the new TIR
+ * function.
+ */
+class ThreadBindMutator : public ExprMutator {
+ public:
+  static IRModule Transform(const IRModule& mod, int64_t max_thread_per_block) {
+    ThreadBindMutator mutator(mod);
+
+    for (const auto& kv : mod->functions) {
+      const GlobalVar& gv = kv.first;
+      const BaseFunc& func = kv.second;
+
+      if (func->IsInstance<tir::PrimFuncNode>()) {
+        IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({kv}));
+        tir::Schedule sch = tir::Schedule::Traced(mod, /*seed=*/-1, /*debug_mask=*/0,
+                                                  tir::ScheduleErrorRenderLevel::kDetail);
+        Array<tir::BlockRV> blocks = meta_schedule::BlockCollector::Collect(sch);
+        for (const tir::BlockRV& block : blocks) {
+          // fetch the loops
+          Array<tir::LoopRV> loops = sch->GetLoops(block);
+          bool scheduled = false;
+          for (const tir::LoopRV& loop : loops) {
+            if (sch->Get(loop)->thread_binding.defined()) {
+              scheduled = true;
+              break;
+            }
+          }
+          // skip if already scheduled
+          if (scheduled) {
+            continue;
+          }
+          Array<tir::IterVar> iters = sch->Get(block)->iter_vars;
+          ICHECK_EQ(loops.size(), iters.size());
+          Array<tir::LoopRV> data_parallel_loops;
+          // only fuse data parallel loops
+          for (size_t i = 0; i < loops.size(); ++i) {
+            if (iters[i]->iter_type == tir::IterVarType::kDataPar) {
+              data_parallel_loops.push_back(loops[i]);
+            }
+          }
+          if (data_parallel_loops.size() == 0) {
+            continue;
+          }
+          // fuse all data parallel loops
+          tir::LoopRV fused = sch->Fuse(data_parallel_loops, /*preserve_unit_iters=*/false);
+          int64_t product = std::numeric_limits<int64_t>::max();
+          if (sch->Get(fused)->extent->IsInstance<tir::IntImmNode>()) {
+            product = sch->Get(fused)->extent.as<tir::IntImmNode>()->value;
+          }
+          static const int64_t max_threadblocks = 256;
+          // schedule the fused loop
+          if (product > max_thread_per_block * max_threadblocks) {
+            Array<tir::LoopRV> splits = sch->Split(
+                fused,
+                /*factors=*/{NullOpt, Integer(max_threadblocks), Integer(max_thread_per_block)});
+            sch->Reorder(/*ordered_loop_rvs=*/{splits[1], splits[2], splits[0]});
+            sch->Bind(splits[1], "blockIdx.x");
+            sch->Bind(splits[2], "threadIdx.x");
+          } else {
+            Array<tir::LoopRV> splits = sch->Split(
+                fused, /*factors=*/{NullOpt, Integer(std::min(product, max_thread_per_block))});
+            sch->Bind(splits[0], "blockIdx.x");
+            sch->Bind(splits[1], "threadIdx.x");
+          }
+        }
+        mutator.builder_->AddFunction(sch->mod()->Lookup(gv->name_hint), gv->name_hint);
+      } else {
+        mutator.builder_->AddFunction(func, gv->name_hint);

Review Comment:
   not needed, since the function is already in the module and is not changed



##########
src/relax/transform/default_schedule.cc:
##########
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/schedule/schedule.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "../../meta_schedule/utils.h"
+#include "../../relay/analysis/graph_partitioner.h"
+#include "../../support/arena.h"
+#include "../../tir/ir/functor_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief The helper class to schedule functions and build a new module which calls the new TIR
+ * function.
+ */
+class ThreadBindMutator : public ExprMutator {
+ public:
+  static IRModule Transform(const IRModule& mod, int64_t max_thread_per_block) {
+    ThreadBindMutator mutator(mod);
+
+    for (const auto& kv : mod->functions) {
+      const GlobalVar& gv = kv.first;
+      const BaseFunc& func = kv.second;
+
+      if (func->IsInstance<tir::PrimFuncNode>()) {
+        IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({kv}));

Review Comment:
   it is possible to skip creating a new module if we use `Schedule::WorkOn` to directly schedule the original module (after copy-on-write), see https://discuss.tvm.apache.org/t/manual-scheduling-of-call-tir-functions-in-relax/14446/7?u=vinx13
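
   On the Python side this corresponds to `Schedule.work_on`; a small sketch with two placeholder PrimFuncs (names and bodies made up for illustration):
   
   ```
   import tvm
   from tvm import tir
   from tvm.script import tir as T
   
   @T.prim_func
   def zeros(A: T.Buffer((128,), "float32")):
       for i in T.serial(128):
           with T.block("zeros"):
               vi = T.axis.spatial(128, i)
               A[vi] = 0.0
   
   @T.prim_func
   def ones(A: T.Buffer((128,), "float32")):
       for i in T.serial(128):
           with T.block("ones"):
               vi = T.axis.spatial(128, i)
               A[vi] = 1.0
   
   mod = tvm.IRModule({"zeros": zeros, "ones": ones})
   sch = tir.Schedule(mod)
   # schedule one PrimFunc of the multi-function module in place,
   # without constructing a single-function module first
   sch.work_on("zeros")
   (loop,) = sch.get_loops(sch.get_block("zeros"))
   sch.bind(loop, "threadIdx.x")
   ```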



##########
src/relax/transform/default_schedule.cc:
##########
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/schedule/schedule.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "../../meta_schedule/utils.h"
+#include "../../relay/analysis/graph_partitioner.h"
+#include "../../support/arena.h"
+#include "../../tir/ir/functor_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief The helper class to schedule functions and build a new module which calls the new TIR
+ * function.
+ */
+class ThreadBindMutator : public ExprMutator {

Review Comment:
   since we are not visiting the function bodies, it's probably not necessary to use an `ExprMutator`; we can directly update the module using `IRModule::Add`



[GitHub] [tvm] tqchen commented on pull request #14182: [Unity] Introduce Default Schedule Pass

Posted by "tqchen (via GitHub)" <gi...@apache.org>.
tqchen commented on PR #14182:
URL: https://github.com/apache/tvm/pull/14182#issuecomment-1454307858

   cc @junrushao 


[GitHub] [tvm] tvm-bot commented on pull request #14182: [Unity] Introduce Default Schedule Pass

Posted by "tvm-bot (via GitHub)" <gi...@apache.org>.
tvm-bot commented on PR #14182:
URL: https://github.com/apache/tvm/pull/14182#issuecomment-1452739461

   Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from [Reviewers](https://github.com/apache/incubator-tvm/blob/master/CONTRIBUTORS.md#reviewers) by @-ing them in a comment.
   
    * cc @quic-sanirudh See [#10317](https://github.com/apache/tvm/issues/10317) for details
   
   Generated by [tvm-bot](https://github.com/apache/tvm/blob/main/ci/README.md#github-actions)


[GitHub] [tvm] masahi commented on a diff in pull request #14182: [Unity] Introduce Default Schedule Pass

Posted by "masahi (via GitHub)" <gi...@apache.org>.
masahi commented on code in PR #14182:
URL: https://github.com/apache/tvm/pull/14182#discussion_r1124080592


##########
src/meta_schedule/utils.h:
##########
@@ -554,6 +554,68 @@ inline double Sum(const Array<FloatImm>& arr) {
   return sum;
 }
 
+/*! \brief Collecting all the blocks */
+class BlockCollector : public tir::StmtVisitor {
+ public:
+  static Array<tir::BlockRV> Collect(const tir::Schedule& sch,
+                                     const runtime::PackedFunc f_block_filter = nullptr) {  //
+    return BlockCollector(sch, f_block_filter).Run();
+  }
+
+ private:
+  /*! \brief Entry point */
+  Array<tir::BlockRV> Run() {
+    std::vector<tir::BlockRV> results;
+    for (const auto& kv : sch_->mod()->functions) {
+      const GlobalVar& gv = kv.first;         // `gv->name_hint` is the name of the function
+      const BaseFunc& base_func = kv.second;  // this can be PrimFunc or relay::Function

Review Comment:
   can use
   
   ```
   for (const auto& [gv, base_func]: ...
   ```



[GitHub] [tvm] tqchen commented on pull request #14182: [Unity] Introduce Default Schedule Pass

Posted by "tqchen (via GitHub)" <gi...@apache.org>.
tqchen commented on PR #14182:
URL: https://github.com/apache/tvm/pull/14182#issuecomment-1454103249

   cc @junrushao 


[GitHub] [tvm] MasterJH5574 commented on a diff in pull request #14182: [Unity] Introduce Default Schedule Pass

Posted by "MasterJH5574 (via GitHub)" <gi...@apache.org>.
MasterJH5574 commented on code in PR #14182:
URL: https://github.com/apache/tvm/pull/14182#discussion_r1125349558


##########
include/tvm/tir/transform.h:
##########
@@ -709,6 +709,12 @@ TVM_DLL Pass ManifestSharedMemoryLocalStage();
  */
 TVM_DLL Pass InstrumentProfileIntrinsics();
 
+/*!
+ * \brief Create default schedule for PrimFuncs to run on cuda device.
+ * \return The Pass.
+ */
+TVM_DLL Pass DefaultSchedule();

Review Comment:
   Though this pass is intended only for the CUDA backend at this moment, I suppose it will actually work for some other GPU backends as well, since the tasks tuned by MetaSchedule on these backends all follow the same schedule rules.
   https://github.com/apache/tvm/blob/main/src/meta_schedule/utils.h#L512
   
   So I personally would prefer the name “DefaultGPUSchedule”, listing the GPU backends this pass supports in the docstring. The pass already checks whether the target has the “maximum # of threads per block” attribute, so we can fall back with no change if the target doesn’t have it.
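
   As a quick illustration, the attribute check could look like this on the Python side (the list of target kinds is just an example):
   
   ```
   import tvm
   
   # targets that carry "max_num_threads" could be handled by the pass;
   # those that do not (e.g. llvm for CPU) would be left unchanged
   for kind in ["cuda", "vulkan", "metal", "rocm", "llvm"]:
       target = tvm.target.Target(kind)
       attrs = dict(target.attrs.items())
       print(kind, "->", attrs.get("max_num_threads"))
   ```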



[GitHub] [tvm] psrivas2 commented on a diff in pull request #14182: [Unity] Introduce Default GPU Schedule Pass

Posted by "psrivas2 (via GitHub)" <gi...@apache.org>.
psrivas2 commented on code in PR #14182:
URL: https://github.com/apache/tvm/pull/14182#discussion_r1126760754


##########
src/tir/transforms/default_gpu_schedule.cc:
##########
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "../../meta_schedule/utils.h"
+
+namespace tvm {
+namespace tir {
+namespace transform {
+/*!
+ * \brief A helper class to do default thread binding for a block.

Review Comment:
   nit: s/class/function



[GitHub] [tvm] jwfromm merged pull request #14182: [Unity] Introduce Default GPU Schedule Pass

Posted by "jwfromm (via GitHub)" <gi...@apache.org>.
jwfromm merged PR #14182:
URL: https://github.com/apache/tvm/pull/14182


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org