You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/01/23 16:06:52 UTC

[tvm] branch main updated: [MetaSchedule] Schedule Rule: Parallelize-Vectorize-Unroll (#10033)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new de01c3e  [MetaSchedule] Schedule Rule: Parallelize-Vectorize-Unroll (#10033)
de01c3e is described below

commit de01c3e2a732613f920264c2ae73874ade3e16f6
Author: Ruihang Lai <la...@qq.com>
AuthorDate: Mon Jan 24 00:06:10 2022 +0800

    [MetaSchedule] Schedule Rule: Parallelize-Vectorize-Unroll (#10033)
    
    Co-authored-by: Junru Shao <ju...@gmail.com>
    Co-authored-by: Xiyou Zhou <xi...@octoml.ai>
    Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
    Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
    Co-authored-by: Hongyi Jin <32...@qq.com>
    Co-authored-by: Wuwei Lin <wu...@apache.org>
    
    Co-authored-by: Junru Shao <ju...@gmail.com>
    Co-authored-by: Xiyou Zhou <xi...@octoml.ai>
    Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
    Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
    Co-authored-by: Hongyi Jin <32...@qq.com>
    Co-authored-by: Wuwei Lin <wu...@apache.org>
---
 include/tvm/meta_schedule/schedule_rule.h          |  13 ++-
 include/tvm/tir/stmt.h                             |  12 ++
 python/tvm/meta_schedule/schedule_rule/__init__.py |   3 +-
 .../schedule_rule/parallel_vectorize_unroll.py     |  64 ++++++++++
 python/tvm/meta_schedule/testing/schedule_rule.py  |  28 +++++
 .../schedule_rule/parallel_vectorize_unroll.cc     | 129 +++++++++++++++++++++
 ...dule_schedule_rule_parallel_vectorize_unroll.py | 105 +++++++++++++++++
 7 files changed, 347 insertions(+), 7 deletions(-)

diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index b3a4f78..3911a52 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -171,19 +171,20 @@ class ScheduleRule : public runtime::ObjectRef {
   TVM_DLL static ScheduleRule CrossThreadReduction(Array<Integer> thread_extents);
   /*!
    * \brief A rule that randomly select a compute-at location for a free block
-   * \return The rule created
+   * \return The schedule rule created
    */
   TVM_DLL static ScheduleRule RandomComputeLocation();
   /*!
-   * \brief Mark parallelize, vectorize and unroll to each block correspondingly
+   * \brief Mark parallelize, vectorize and unroll to the root block. The mark will be applied to
+   * each block in a follow-up post processor
    * \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the
-   * uplimit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. Use -1 to disable
+   * upper limit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. Use -1 to disable
    * parallelism.
    * \param max_vectorize_extent The maximum extent to be vectorized.
-   * It sets the uplimit of the CPU vectorization. Use -1 to disable vectorization.
-   * \param unroll_max_steps The maximum number of unroll steps to be done.
+   * It sets the upper limit of the hardware target vectorization. Use -1 to disable vectorization.
+   * \param unroll_max_steps The options of the maximum number of unroll steps to be done.
    * Use an empty array to disable unroll.
-   * \param unroll_explicit Whether to explicitly unroll the loop, or just add a unroll pragma.
+   * \param unroll_explicit Whether to explicitly unroll the loop, or just add an "unroll" pragma.
    * \return The schedule rule created
    */
   TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core,            //
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 7bc3b69..0a05439 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -1368,6 +1368,18 @@ constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_str
 constexpr const char* meta_schedule_random_compute_producer =
     "meta_schedule.random_compute_producer";
 
+/*! \brief Mark auto-parallel setting on the block. */
+constexpr const char* meta_schedule_parallel = "meta_schedule.parallel";
+
+/*! \brief Mark auto-vectorize setting on the block. */
+constexpr const char* meta_schedule_vectorize = "meta_schedule.vectorize";
+
+/*! \brief Mark auto-unroll setting on the block, requesting explicit loop unrolling. */
+constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_explicit";
+
+/*! \brief Mark auto-unroll setting on the block, requesting only an "unroll" pragma. */
+constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit";
+
 /*!
  * \brief Check if attr_key is a pragma key extension
  * \param attr_key The attr key to be compared
diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py
index c54eecf..ce66323 100644
--- a/python/tvm/meta_schedule/schedule_rule/__init__.py
+++ b/python/tvm/meta_schedule/schedule_rule/__init__.py
@@ -19,5 +19,6 @@ blocks in a schedule. See also PostOrderApply.
 from .add_rfactor import AddRFactor
 from .auto_inline import AutoInline
 from .cross_thread_reduction import CrossThreadReduction
-from .schedule_rule import PyScheduleRule, ScheduleRule
+from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll
 from .random_compute_location import RandomComputeLocation
+from .schedule_rule import PyScheduleRule, ScheduleRule
diff --git a/python/tvm/meta_schedule/schedule_rule/parallel_vectorize_unroll.py b/python/tvm/meta_schedule/schedule_rule/parallel_vectorize_unroll.py
new file mode 100644
index 0000000..a79ea91
--- /dev/null
+++ b/python/tvm/meta_schedule/schedule_rule/parallel_vectorize_unroll.py
@@ -0,0 +1,64 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Rule that mark parallelize, vectorize and unroll to the root block. The mark will be applied to
+each block in a follow-up post processor"""
+from typing import List, Optional
+
+from tvm._ffi import register_object
+
+from .. import _ffi_api
+from .schedule_rule import ScheduleRule
+
+
+@register_object("meta_schedule.ParallelizeVectorizeUnroll")
+class ParallelizeVectorizeUnroll(ScheduleRule):
+    """Rule that marks parallelize, vectorize and unroll on the root block. The marks will be
+    applied to each block in a follow-up postprocessor
+
+    Parameters
+    ----------
+    max_jobs_per_core: int
+        The maximum number of jobs to be launched per CPU core. It sets the upper limit of CPU
+        parallelism, i.e. `num_cores * max_jobs_per_core`.
+        Use -1 to disable parallelism.
+    max_vectorize_extent: int
+        The maximum extent to be vectorized. It sets the upper limit of the hardware target
+        vectorization.
+        Use -1 to disable vectorization.
+    unroll_max_steps: Optional[List[int]]
+        The options of the maximum number of unroll steps to be done.
+        Use None (or an empty list) to disable unroll.
+    unroll_explicit: bool
+        Whether to explicitly unroll the loop, or just add an "unroll" pragma
+    """
+
+    def __init__(
+        self,
+        max_jobs_per_core: int = 16,
+        max_vectorize_extent: int = 16,
+        unroll_max_steps: Optional[List[int]] = None,
+        unroll_explicit: bool = True,
+    ) -> None:
+        steps = [] if unroll_max_steps is None else unroll_max_steps
+        constructor = _ffi_api.ScheduleRuleParallelizeVectorizeUnroll  # type: ignore # pylint: disable=no-member
+        self.__init_handle_by_constructor__(
+            constructor,
+            max_jobs_per_core,
+            max_vectorize_extent,
+            steps,
+            unroll_explicit,
+        )
diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py
index b9606ee..464d249 100644
--- a/python/tvm/meta_schedule/testing/schedule_rule.py
+++ b/python/tvm/meta_schedule/testing/schedule_rule.py
@@ -19,6 +19,8 @@ from tvm.meta_schedule.schedule_rule import (
     AddRFactor,
     AutoInline,
     CrossThreadReduction,
+    ParallelizeVectorizeUnroll,
+    RandomComputeLocation,
     ScheduleRule,
 )
 from tvm.target import Target
@@ -61,3 +63,29 @@ def cross_thread_reduction(target: Target) -> ScheduleRule:
     if target.kind.name == "cuda":
         return CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512])
     raise NotImplementedError(f"{target.kind.name} is not supported")
+
+
+def random_compute_location(target: Target) -> ScheduleRule:
+    """Default schedule rule for random-compute-location (llvm targets only)"""
+    if target.kind.name != "llvm":
+        raise NotImplementedError(f"{target.kind.name} is not supported")
+    return RandomComputeLocation()
+
+
+def parallel_vectorize_unroll(target: Target) -> ScheduleRule:
+    """Default schedule rule for parallel-vectorize-unroll.
+
+    CPU (llvm) targets parallelize, vectorize and unroll; GPU (cuda) targets only unroll.
+    """
+    if target.kind.name == "llvm":
+        max_jobs, max_vectorize, unroll_steps = 16, 32, [0, 16, 64, 512]
+    elif target.kind.name == "cuda":
+        max_jobs, max_vectorize, unroll_steps = -1, -1, [0, 16, 64, 512, 1024]
+    else:
+        raise NotImplementedError(f"{target.kind.name} is not supported")
+    return ParallelizeVectorizeUnroll(
+        max_jobs_per_core=max_jobs,
+        max_vectorize_extent=max_vectorize,
+        unroll_max_steps=unroll_steps,
+        unroll_explicit=True,
+    )
diff --git a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc
new file mode 100644
index 0000000..c0e57a6
--- /dev/null
+++ b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*! \brief Whether `block_rv` is the root block of `sch`, i.e. its sref has no parent. */
+bool IsRootBlock(const Schedule& sch, const BlockRV& block_rv) {
+  return sch->GetSRef(block_rv)->parent == nullptr;
+}
+
+}  // namespace tir
+}  // namespace tvm
+
+namespace tvm {
+namespace meta_schedule {
+
+class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode {
+ public:
+  // Inherited from ScheduleRuleNode: caches the parallel extent when parallelism is enabled
+  void InitializeWithTuneContext(const TuneContext& context) final {
+    ICHECK(context->target.defined());
+    if (this->max_jobs_per_core != -1) {
+      Target target = context->target.value();
+      this->max_parallel_extent_ = GetTargetNumCores(target) * max_jobs_per_core;
+    }
+  }
+
+  // Inherited from ScheduleRuleNode
+  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& root_rv) final {
+    // Only the root block is annotated; any other block is returned unchanged.
+    if (!tir::IsRootBlock(sch, root_rv)) {
+      return {sch};
+    }
+
+    // Parallelization: `max_parallel_extent_` is computed in InitializeWithTuneContext
+    if (max_jobs_per_core != -1) {
+      ICHECK_NE(max_parallel_extent_, -1) << "InitializeWithTuneContext must be called first";
+      sch->Annotate(root_rv, tir::attr::meta_schedule_parallel, Integer(max_parallel_extent_));
+    }
+    // Vectorization
+    if (max_vectorize_extent != -1) {
+      sch->Annotate(root_rv, tir::attr::meta_schedule_vectorize, Integer(max_vectorize_extent));
+    }
+    // Unroll: sample one of the candidate max unroll steps with uniform probability
+    if (!unroll_max_steps.empty()) {
+      int n = unroll_max_steps.size();
+      double prob = 1.0 / n;
+      Array<FloatImm> probs(n, FloatImm(DataType::Float(64), prob));
+      PrimExpr max_step = sch->SampleCategorical(unroll_max_steps, probs);
+      if (unroll_explicit) {
+        sch->Annotate(root_rv, tir::attr::meta_schedule_unroll_explicit, max_step);
+      } else {
+        sch->Annotate(root_rv, tir::attr::meta_schedule_unroll_implicit, max_step);
+      }
+    }
+    return {sch};
+  }
+
+ public:
+  /*!
+   * \brief The maximum number of jobs to be launched per CPU core. It sets the
+   * upper limit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. Use -1 to disable
+   * parallelism.
+   */
+  int64_t max_jobs_per_core;
+  /*!
+   * \brief The maximum extent to be vectorized.
+   * It sets the upper limit of the hardware target vectorization. Use -1 to disable vectorization.
+   */
+  int max_vectorize_extent;
+  /*!
+   * \brief The options of the maximum number of unroll steps to be done.
+   * Use an empty array to disable unroll.
+   */
+  Array<Integer> unroll_max_steps;
+  /*! \brief Whether to explicitly unroll the loop, or just add an "unroll" pragma. */
+  bool unroll_explicit;
+  /*! \brief Cached `num_cores * max_jobs_per_core`, or -1 when it has not been computed. */
+  int64_t max_parallel_extent_ = -1;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("max_jobs_per_core", &max_jobs_per_core);
+    v->Visit("max_vectorize_extent", &max_vectorize_extent);
+    v->Visit("unroll_max_steps", &unroll_max_steps);
+    v->Visit("unroll_explicit", &unroll_explicit);
+    // `max_parallel_extent_` is not visited
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.ParallelizeVectorizeUnroll";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ParallelizeVectorizeUnrollNode, ScheduleRuleNode);
+};
+
+ScheduleRule ScheduleRule::ParallelizeVectorizeUnroll(int max_jobs_per_core,
+                                                      int max_vectorize_extent,
+                                                      Array<Integer> unroll_max_steps,
+                                                      bool unroll_explicit) {
+  auto node = make_object<ParallelizeVectorizeUnrollNode>();
+  node->max_jobs_per_core = max_jobs_per_core;
+  node->max_parallel_extent_ = -1;  // Filled in by InitializeWithTuneContext when enabled
+  node->max_vectorize_extent = max_vectorize_extent;
+  node->unroll_max_steps = unroll_max_steps;
+  node->unroll_explicit = unroll_explicit;
+  return ScheduleRule(node);
+}
+
+TVM_REGISTER_NODE_TYPE(ParallelizeVectorizeUnrollNode);  // Paired with Python's register_object
+TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleParallelizeVectorizeUnroll")  // Python ctor
+    .set_body_typed(ScheduleRule::ParallelizeVectorizeUnroll);
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py b/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py
new file mode 100644
index 0000000..e57799f
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py
@@ -0,0 +1,105 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+import tvm
+from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
+from tvm.meta_schedule.testing.schedule_rule import parallel_vectorize_unroll
+from tvm.meta_schedule.testing.space_generation import check_trace
+from tvm.meta_schedule.tune_context import TuneContext
+from tvm.script import tir as T
+from tvm.target import Target
+
+# fmt: off
+# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
+
+@tvm.script.ir_module
+class Matmul:  # unannotated 1024x1024x1024 matmul used as the input workload
+    @T.prim_func
+    def main(a: T.handle, b: T.handle, c: T.handle) -> None:
+        T.func_attr({"global_symbol": "main"})
+        A = T.match_buffer(a, (1024, 1024), "float32")
+        B = T.match_buffer(b, (1024, 1024), "float32")
+        C = T.match_buffer(c, (1024, 1024), "float32")
+        for i, j, k in T.grid(1024, 1024, 1024):
+            with T.block("matmul"):
+                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                with T.init():
+                    C[vi, vj] = 0.0
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+
+@tvm.script.ir_module
+class ParallelizeVectorizeUnroll:  # NOTE(review): not referenced by the test below; its annotation values (128/16/2) differ from the expected trace (512/32/v1)
+    @T.prim_func
+    def main(a: T.handle, b: T.handle, c: T.handle) -> None:
+        T.func_attr({"global_symbol": "main"})
+        A = T.match_buffer(a, (1024, 1024), "float32")
+        B = T.match_buffer(b, (1024, 1024), "float32")
+        C = T.match_buffer(c, (1024, 1024), "float32")
+        with T.block("root"):
+            T.reads([])
+            T.writes([])
+            T.block_attr({"meta_schedule.parallel": 128, "meta_schedule.vectorize": 16, "meta_schedule.unroll_explicit": 2})
+            for i, j, k in T.grid(1024, 1024, 1024):
+                with T.block("matmul"):
+                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                    with T.init():
+                        C[vi, vj] = 0.0
+                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
+# fmt: on
+
+
+def _create_context(mod, target, rule):
+    """Build a TuneContext holding only `rule`, with its space generator and rules initialized."""
+    context = TuneContext(
+        mod=mod,
+        target=target,
+        space_generator=PostOrderApply(),
+        sch_rules=[rule],
+        task_name="test",
+    )
+    for component in [context.space_generator, *context.sch_rules]:
+        component.initialize_with_tune_context(context)
+    return context
+
+
+def test_parallel_vectorize_unroll():
+    target = Target("llvm --num-cores=32")
+    # `--num-cores=32` x max_jobs_per_core=16 of the llvm rule gives a parallel extent of 512
+    expected = [
+        [
+            'b0 = sch.get_block(name="root", func_name="main")',
+            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.parallel", ann_val=512)',
+            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.vectorize", ann_val=32)',
+            "v1 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])",
+            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.unroll_explicit", ann_val=v1)',
+        ]
+    ]
+    ctx = _create_context(
+        mod=Matmul,
+        target=target,
+        rule=parallel_vectorize_unroll(target=target),
+    )
+    spaces = ctx.space_generator.generate_design_space(mod=Matmul)
+    assert len(spaces) == 1
+    check_trace(spaces, expected)
+
+
+if __name__ == "__main__":  # run the single test directly when executed as a script
+    test_parallel_vectorize_unroll()