You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2022/01/29 04:24:24 UTC
[tvm] branch main updated: [MetaSchedule][M4a] Mutator: Mutate-Tile-Size (#10092)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4d0dac3 [MetaSchedule][M4a] Mutator: Mutate-Tile-Size (#10092)
4d0dac3 is described below
commit 4d0dac3e552f157d641b8cf01119250203134676
Author: Ruihang Lai <la...@qq.com>
AuthorDate: Sat Jan 29 12:24:03 2022 +0800
[MetaSchedule][M4a] Mutator: Mutate-Tile-Size (#10092)
* [MetaSchedule][M4a] Mutator: Mutate-Tile-Size
Co-authored-by: Junru Shao <ju...@gmail.com>
Co-authored-by: Xiyou Zhou <xi...@octoml.ai>
Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
Co-authored-by: Hongyi Jin <32...@qq.com>
Co-authored-by: Wuwei Lin <wu...@apache.org>
* Python 3.8 has no `math.prod`
Co-authored-by: Junru Shao <ju...@gmail.com>
Co-authored-by: Xiyou Zhou <xi...@octoml.ai>
Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
Co-authored-by: Hongyi Jin <32...@qq.com>
Co-authored-by: Wuwei Lin <wu...@apache.org>
---
include/tvm/meta_schedule/mutator.h | 2 +-
python/tvm/meta_schedule/mutator/__init__.py | 1 +
.../mutator/{__init__.py => mutate_tile_size.py} | 24 +-
src/meta_schedule/mutator/mutate_tile_size.cc | 273 +++++++++++++++++++++
.../test_meta_schedule_mutator_mutate_tile_size.py | 93 +++++++
5 files changed, 383 insertions(+), 10 deletions(-)
diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h
index e3fa847..002fa51 100644
--- a/include/tvm/meta_schedule/mutator.h
+++ b/include/tvm/meta_schedule/mutator.h
@@ -111,7 +111,7 @@ class PyMutatorNode : public MutatorNode {
*/
class Mutator : public runtime::ObjectRef {
public:
- /*! \brief Create a Mutator that mutates the tile size. */
+ /*! \brief Create a Mutator that mutates the decision of instruction Sample-Perfect-Tile */
TVM_DLL static Mutator MutateTileSize();
/*!
* \brief Create a Mutator that mutates the parallel extent
diff --git a/python/tvm/meta_schedule/mutator/__init__.py b/python/tvm/meta_schedule/mutator/__init__.py
index af3485b..e534ba1 100644
--- a/python/tvm/meta_schedule/mutator/__init__.py
+++ b/python/tvm/meta_schedule/mutator/__init__.py
@@ -21,5 +21,6 @@ design space.
"""
from .mutator import Mutator, PyMutator
from .mutate_compute_location import MutateComputeLocation
+from .mutate_tile_size import MutateTileSize
from .mutate_parallel import MutateParallel
from .mutate_unroll import MutateUnroll
diff --git a/python/tvm/meta_schedule/mutator/__init__.py b/python/tvm/meta_schedule/mutator/mutate_tile_size.py
similarity index 60%
copy from python/tvm/meta_schedule/mutator/__init__.py
copy to python/tvm/meta_schedule/mutator/mutate_tile_size.py
index af3485b..ff432a6 100644
--- a/python/tvm/meta_schedule/mutator/__init__.py
+++ b/python/tvm/meta_schedule/mutator/mutate_tile_size.py
@@ -14,12 +14,18 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""
-The tvm.meta_schedule.mutator package.
-Meta Schedule mutator that mutates the trace to explore the
-design space.
-"""
-from .mutator import Mutator, PyMutator
-from .mutate_compute_location import MutateComputeLocation
-from .mutate_parallel import MutateParallel
-from .mutate_unroll import MutateUnroll
+"""Mutator that mutates the decision of instruction Sample-Perfect-Tile"""
+from tvm._ffi.registry import register_object
+
+from .. import _ffi_api
+from .mutator import Mutator
+
+
+@register_object("meta_schedule.MutateTileSize")  # same string as the C++ node's _type_key
+class MutateTileSize(Mutator):
+ """Mutator that mutates the decision of instruction Sample-Perfect-Tile"""
+
+ def __init__(self) -> None:  # takes no configuration arguments
+ self.__init_handle_by_constructor__(  # presumably builds the handle by calling the FFI function below
+ _ffi_api.MutatorMutateTileSize, # type: ignore # pylint: disable=no-member
+ )
diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc
new file mode 100644
index 0000000..6e03488
--- /dev/null
+++ b/src/meta_schedule/mutator/mutate_tile_size.cc
@@ -0,0 +1,273 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <mutex>
+#include <unordered_map>
+
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+using tir::Instruction;
+using tir::InstructionKind;
+using tir::Trace;
+
+/*!
+ * \brief Downcast the decision of Sample-Perfect-Tile to an array of integers
+ * \param decision The decision of Sample-Perfect-Tile
+ * \return The elements of the decision, converted to a vector of int64_t
+ */
+std::vector<int64_t> DowncastTilingDecision(const ObjectRef& decision) {
+ const auto* arr = TVM_TYPE_AS(arr, decision, runtime::ArrayNode);  // project macro; presumably a checked cast - confirm
+ return support::AsVector<ObjectRef, int64_t>(GetRef<Array<ObjectRef>>(arr));  // element-wise conversion
+}
+
+/*!
+ * \brief Calculate the product of elements in an array
+ * \param array The array
+ * \return The product of elements in the array; 1 if the array is empty
+ */
+int64_t Product(const std::vector<int64_t>& array) {
+ int64_t result = 1;  // multiplicative identity, also the result for an empty array
+ for (int64_t x : array) {
+ result *= x;  // NOTE(review): no overflow check is performed here
+ }
+ return result;
+}
+
+/*! \brief A mutator that mutates the decision of instruction Sample-Perfect-Tile */
+class MutateTileSizeNode : public MutatorNode {
+ public:
+ void VisitAttrs(tvm::AttrVisitor* v) {}  // empty: this node declares no fields to visit
+ static constexpr const char* _type_key = "meta_schedule.MutateTileSize";  // same key as the Python wrapper
+ TVM_DECLARE_FINAL_OBJECT_INFO(MutateTileSizeNode, MutatorNode);
+
+ public:
+ // Inherit from `MutatorNode`; empty body, nothing is read from the tune context
+ void InitializeWithTuneContext(const TuneContext& context) final {}
+ // Inherit from `MutatorNode`; defined out of line later in this file
+ Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;
+};
+
+/*!
+ * \brief Find the Sample-Perfect-Tile instructions and their decisions in the trace
+ * \param trace The trace
+ * \param inst Output: the mutable instructions found (appended to)
+ * \param decision Output: the tile decisions of those instructions, index-aligned with `inst`
+ */
+void FindSamplePerfectTile(const Trace& trace, std::vector<Instruction>* inst,
+ std::vector<std::vector<int64_t>>* decision) {
+ static const InstructionKind& inst_sample_perfect_tile =
+ InstructionKind::Get("SamplePerfectTile");  // function-local static: looked up once
+ std::vector<Instruction>& instructions = *inst;
+ std::vector<std::vector<int64_t>>& decisions = *decision;
+ instructions.reserve(trace->decisions.size());  // upper bound: one entry per recorded decision
+ decisions.reserve(trace->decisions.size());
+ for (const auto& kv : trace->decisions) {
+ const Instruction& inst = kv.first;  // shadows the output parameter `inst`
+ const ObjectRef& decision = kv.second;  // shadows the output parameter `decision`
+ if (inst->kind.same_as(inst_sample_perfect_tile)) {
+ std::vector<int64_t> tiles = DowncastTilingDecision(decision);
+ if (tiles.size() >= 2 && Product(tiles) >= 2) {  // need two tiles and some tile > 1 to move a factor
+ instructions.push_back(inst);
+ decisions.push_back(tiles);
+ }
+ }
+ }
+}
+
+/*!
+ * \brief Find all Sample-Categorical instructions (and their decisions) whose outputs are used for
+ * cooperative fetch annotation
+ * \param trace The trace
+ * \param inst Output: the instructions found (appended to)
+ * \param decision Output: the integer decisions of those instructions, index-aligned with `inst`
+ */
+void FindSampleVectorize(const Trace& trace, std::vector<Instruction>* inst,
+ std::vector<int64_t>* decision) {
+ static const InstructionKind& inst_sample_categorical = InstructionKind::Get("SampleCategorical");
+ static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate");
+ std::vector<Instruction>& instructions = *inst;
+ std::vector<int64_t>& decisions = *decision;
+ std::unordered_set<const Object*> annotated;  // random variables used as cooperative-fetch annotation values
+ instructions.reserve(trace->decisions.size());  // upper bound: one entry per recorded decision
+ decisions.reserve(trace->decisions.size());
+ annotated.reserve(trace->decisions.size());
+ // Find annotation with `meta_schedule_cooperative_fetch`
+ for (const Instruction& inst : trace->insts) {
+ if (inst->kind.same_as(inst_annotate)) {
+ ICHECK_EQ(inst->attrs.size(), 1);  // attrs[0] is compared against the annotation key below
+ ICHECK_EQ(inst->inputs.size(), 2);  // inputs[1] is read as the annotation value below
+ if (Downcast<String>(inst->attrs[0]) == tir::attr::meta_schedule_cooperative_fetch) {
+ const auto* ann_val = inst->inputs[1].as<tir::ExprRVNode>();
+ ICHECK(ann_val);  // the value must be an ExprRV, otherwise the check fails
+ annotated.insert(ann_val);
+ }
+ }
+ }
+ // Find sampling instruction that generates the annotation
+ for (const auto& kv : trace->decisions) {
+ const Instruction& inst = kv.first;  // shadows the output parameter `inst`
+ const ObjectRef& decision = kv.second;  // shadows the output parameter `decision`
+ if (inst->kind.same_as(inst_sample_categorical)) {
+ ICHECK_EQ(inst->outputs.size(), 1);
+ if (annotated.count(inst->outputs[0].get())) {  // matched by object address
+ const auto* d = TVM_TYPE_AS(d, decision, IntImmNode);
+ instructions.push_back(inst);
+ decisions.push_back(d->value);
+ }
+ }
+ }
+}
+
+struct FactorMemo {  // memoizes Factorize() results in a mutex-guarded singleton
+ /*!
+ * \brief Find all positive factors of the input integer, caching the result per input
+ * \param n The integer to be factorized
+ * \return The factors of the input integer in ascending order (empty when n < 1)
+ */
+ static std::vector<int> Factorize(int n) {
+ if (const std::vector<int>* result = Global()->Query(n)) {
+ return *result;  // cache hit: return a copy of the stored factors
+ }
+ std::vector<int> result;
+ for (int64_t i = 1; i * i <= n; ++i) {  // `<=`: the root of a perfect square must be tested too
+ if (n % i == 0) {
+ result.push_back(i);
+ if (i * i != n) {  // do not list the square root twice
+ result.push_back(n / i);
+ }
+ }
+ }
+ std::sort(result.begin(), result.end());
+ Global()->Add(n, result);
+ return result;
+ }
+
+ private:
+ const std::vector<int>* Query(int n) {  // returns the cached factors of n, or nullptr on a miss
+ std::unique_lock<std::mutex> lock(mutex);
+ auto it = memo.find(n);
+ if (it != memo.end()) {
+ return &it->second;  // used after unlock: entries are never erased or overwritten in this struct
+ }
+ return nullptr;
+ }
+
+ void Add(int n, std::vector<int> result) {  // stores the factors of n; an existing entry is kept
+ std::unique_lock<std::mutex> lock(mutex);
+ memo.emplace(n, std::move(result));
+ }
+
+ static FactorMemo* Global() {  // the single shared instance, created on first use
+ static FactorMemo singleton;
+ return &singleton;
+ }
+
+ std::unordered_map<int, std::vector<int>> memo;  // n -> sorted factors of n
+ std::mutex mutex;  // guards `memo`
+};
+
+Optional<Trace> MutateSampleTileSize(const Trace& trace, Instruction inst,
+ std::vector<int64_t> tiles, TRandState* rand_state) {  // moves one factor between two tiles
+ int n_splits = tiles.size();
+ // Step 1. Choose two loops, `x` and `y`
+ int x, y;  // x: tile that gives up a factor, y: tile that receives it
+ // select source
+ while (true) {
+ x = tir::SampleInt(rand_state, 0, n_splits);  // presumably half-open [0, n_splits); x is used as an index
+ if (tiles[x] <= 1) {
+ continue;  // a tile of 1 has nothing to give; resample
+ }
+ y = tir::SampleInt(rand_state, 0, n_splits - 1);
+ if (y >= x) {
+ ++y;  // skip over x so that y != x
+ }
+ std::vector<int> factors = FactorMemo::Factorize(tiles[x]);  // sorted ascending, factors[0] == 1
+ // Step 2. Choose the divide factor
+ int64_t divide_factor;
+ if (y != n_splits - 1) {
+ divide_factor = factors[tir::SampleInt(rand_state, 1, factors.size())];  // index 0 (factor 1) excluded
+ } else {
+ int64_t limit = Downcast<Integer>(inst->attrs[1])->value;  // presumably max_innermost_factor - confirm
+ int max_factor_index = static_cast<int>(factors.size()) - 1;
+ for (; max_factor_index >= 1; max_factor_index--) {  // largest factor keeping the last tile <= limit
+ if (factors[max_factor_index] * tiles[y] <= limit) {
+ break;
+ }
+ }
+ if (max_factor_index == 0) {  // only the trivial factor 1 fits under the limit
+ if (n_splits <= 2) {
+ return NullOpt;  // gives up when there are at most two tiles
+ }
+ // Failed for this (x, y) pair, resample both.
+ continue;
+ }
+ divide_factor = factors[tir::SampleInt(rand_state, 1, max_factor_index + 1)];
+ }
+ tiles[x] /= divide_factor;  // `tiles` is a by-value copy, so the caller's vector is untouched
+ tiles[y] *= divide_factor;  // the product of all tiles is unchanged
+ return trace->WithDecision(inst, support::AsArray<int64_t, ObjectRef>(tiles),
+ /*remove_postproc=*/true);
+ }
+}
+
+Optional<Trace> MutateSampleVectorize(const Trace& trace, Instruction inst,
+ int64_t original_decision, TRandState* rand_state) {  // re-draws a different categorical choice
+ ICHECK_EQ(inst->attrs.size(), 2);  // attrs[1] is read below as the probability list
+ std::vector<double> probs =
+ support::AsVector<FloatImm, double>(Downcast<Array<FloatImm>>(inst->attrs[1]));
+ probs.erase(probs.begin() + original_decision);  // drop the current choice so it cannot be drawn again
+ int result = tir::MakeMultinomialSampler(rand_state, probs)();  // NOTE(review): `probs` is empty if there was only one candidate - confirm sampler behavior
+ if (result >= original_decision) {
+ result += 1;  // map the index in the shortened list back to the original list
+ }
+ return trace->WithDecision(inst, Integer(result), /*remove_postproc=*/true);
+}
+
+Optional<Trace> MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_state) {  // mutates one sampled decision
+ std::vector<Instruction> sample_perfect_tile_insts;
+ std::vector<Instruction> sample_vectorize_insts;
+ std::vector<std::vector<int64_t>> sample_perfect_tile_tiles;  // index-aligned with sample_perfect_tile_insts
+ std::vector<int64_t> sample_vectorize_decisions;  // index-aligned with sample_vectorize_insts
+ FindSamplePerfectTile(trace, &sample_perfect_tile_insts, &sample_perfect_tile_tiles);
+ FindSampleVectorize(trace, &sample_vectorize_insts, &sample_vectorize_decisions);
+ int size_a = sample_perfect_tile_insts.size();
+ int size_b = sample_vectorize_insts.size();
+ if (size_a == 0 && size_b == 0) {
+ return NullOpt;  // nothing in this trace can be mutated
+ }
+ int n = tir::SampleInt(rand_state, 0, size_a + size_b);  // one draw over both candidate lists
+ if (n < size_a) {  // indices [0, size_a) select a perfect-tile instruction
+ return MutateSampleTileSize(trace, sample_perfect_tile_insts[n], sample_perfect_tile_tiles[n],
+ rand_state);
+ } else {  // the remaining indices select a cooperative-fetch categorical instruction
+ n -= size_a;
+ return MutateSampleVectorize(trace, sample_vectorize_insts[n], sample_vectorize_decisions[n],
+ rand_state);
+ }
+}
+
+Mutator Mutator::MutateTileSize() { return Mutator(make_object<MutateTileSizeNode>()); }  // factory declared in mutator.h
+
+TVM_REGISTER_NODE_TYPE(MutateTileSizeNode);
+TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateTileSize").set_body_typed(Mutator::MutateTileSize);  // global used by the Python wrapper
+
+} // namespace meta_schedule
+} // namespace tvm
diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_tile_size.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_tile_size.py
new file mode 100644
index 0000000..9e75497
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_tile_size.py
@@ -0,0 +1,93 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+import operator
+from functools import reduce
+from typing import List
+
+from tvm.meta_schedule import TuneContext
+from tvm.meta_schedule.mutator import MutateTileSize, Mutator
+from tvm.script import tir as T
+from tvm.target import Target
+from tvm.tir import Schedule
+
+# pylint: disable=invalid-name, no-member
+
+
+@T.prim_func
+def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:  # 512x512x512 workload used by the test below
+ A = T.match_buffer(a, [512, 512])
+ B = T.match_buffer(b, [512, 512])
+ C = T.match_buffer(c, [512, 512])
+ for i, j, k in T.grid(512, 512, 512): # type: ignore
+ with T.block("C"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k]) # type: ignore
+ with T.init():
+ C[vi, vj] = 0.0 # type: ignore
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]  # B is indexed [vj, vk], i.e. it is read transposed
+
+
+# pylint: enable=invalid-name, no-member
+
+
+def _sch(decisions: List[List[int]]) -> Schedule:  # builds the schedule whose trace the test mutates
+ sch = Schedule(matmul, debug_mask="all")
+ # pylint: disable=invalid-name
+ (d0,) = decisions  # exactly one decision list is expected
+ b0 = sch.get_block(name="C", func_name="main")
+ sch.get_consumers(block=b0)  # result unused
+ b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")
+ l2, l3, l4 = sch.get_loops(block=b0)
+ v5, v6, v7, v8 = sch.sample_perfect_tile(  # fifth schedule call; the test expects it at trace.insts[4]
+ loop=l2,
+ n=4,
+ max_innermost_factor=64,
+ decision=d0,
+ )
+ l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])  # the only split driven by sampled factors
+ l17, l18, l19, l20 = sch.split(loop=l3, factors=[8, 4, 8, 2])  # fixed factors, not sampled
+ l23, l24 = sch.split(loop=l4, factors=[512, 1])  # fixed factors, not sampled
+ sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)
+ sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=True)
+ # pylint: enable=invalid-name
+ return sch
+
+
+def _make_mutator(target: Target) -> Mutator:  # returns a MutateTileSize mutator initialized for `matmul`
+ mutator = MutateTileSize()
+ mutator.initialize_with_tune_context(TuneContext(mod=matmul, target=target))
+ return mutator
+
+
+def test_mutate_tile_size_matmul():  # mutations must keep the tile product and yield varied decisions
+ mutator = _make_mutator(
+ target=Target("llvm --num-cores=16"),
+ )
+ results = {}  # distinct mutated decisions, keyed by their string form
+ sch = _sch(decisions=[[4, 32, 4, 1]])  # initial tiling, 4 * 32 * 4 * 1 == 512
+ for _ in range(100):
+ trace = mutator.apply(sch.trace)  # every iteration mutates the same base trace
+ assert trace.insts[4].kind.name == "SamplePerfectTile"
+ decision = trace.decisions[trace.insts[4]]
+ decision = [int(x) for x in decision]
+ results[str(decision)] = decision
+ assert reduce(operator.mul, decision, 1) == 512  # the mutated tiles must still multiply to 512
+ assert len(results) > 15  # expects more than 15 distinct decisions out of 100 mutations
+
+
+if __name__ == "__main__":  # allows running this test file directly as a script
+ test_mutate_tile_size_matmul()