You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2022/01/29 04:24:24 UTC

[tvm] branch main updated: [MetaSchedule][M4a] Mutator: Mutate-Tile-Size (#10092)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4d0dac3  [MetaSchedule][M4a] Mutator: Mutate-Tile-Size (#10092)
4d0dac3 is described below

commit 4d0dac3e552f157d641b8cf01119250203134676
Author: Ruihang Lai <la...@qq.com>
AuthorDate: Sat Jan 29 12:24:03 2022 +0800

    [MetaSchedule][M4a] Mutator: Mutate-Tile-Size (#10092)
    
    * [MetaSchedule][M4a] Mutator: Mutate-Tile-Size
    
    Co-authored-by: Junru Shao <ju...@gmail.com>
    Co-authored-by: Xiyou Zhou <xi...@octoml.ai>
    Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
    Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
    Co-authored-by: Hongyi Jin <32...@qq.com>
    Co-authored-by: Wuwei Lin <wu...@apache.org>
    
    * Python 3.7 has no `math.prod` (it was added in Python 3.8)
    
    Co-authored-by: Junru Shao <ju...@gmail.com>
    Co-authored-by: Xiyou Zhou <xi...@octoml.ai>
    Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
    Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
    Co-authored-by: Hongyi Jin <32...@qq.com>
    Co-authored-by: Wuwei Lin <wu...@apache.org>
---
 include/tvm/meta_schedule/mutator.h                |   2 +-
 python/tvm/meta_schedule/mutator/__init__.py       |   1 +
 .../mutator/{__init__.py => mutate_tile_size.py}   |  24 +-
 src/meta_schedule/mutator/mutate_tile_size.cc      | 273 +++++++++++++++++++++
 .../test_meta_schedule_mutator_mutate_tile_size.py |  93 +++++++
 5 files changed, 383 insertions(+), 10 deletions(-)

diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h
index e3fa847..002fa51 100644
--- a/include/tvm/meta_schedule/mutator.h
+++ b/include/tvm/meta_schedule/mutator.h
@@ -111,7 +111,7 @@ class PyMutatorNode : public MutatorNode {
  */
 class Mutator : public runtime::ObjectRef {
  public:
-  /*! \brief Create a Mutator that mutates the tile size. */
+  /*! \brief Create a Mutator that mutates the decision of instruction Sample-Perfect-Tile */
   TVM_DLL static Mutator MutateTileSize();
   /*!
    * \brief Create a Mutator that mutates the parallel extent
diff --git a/python/tvm/meta_schedule/mutator/__init__.py b/python/tvm/meta_schedule/mutator/__init__.py
index af3485b..e534ba1 100644
--- a/python/tvm/meta_schedule/mutator/__init__.py
+++ b/python/tvm/meta_schedule/mutator/__init__.py
@@ -21,5 +21,6 @@ design space.
 """
 from .mutator import Mutator, PyMutator
 from .mutate_compute_location import MutateComputeLocation
+from .mutate_tile_size import MutateTileSize
 from .mutate_parallel import MutateParallel
 from .mutate_unroll import MutateUnroll
diff --git a/python/tvm/meta_schedule/mutator/__init__.py b/python/tvm/meta_schedule/mutator/mutate_tile_size.py
similarity index 60%
copy from python/tvm/meta_schedule/mutator/__init__.py
copy to python/tvm/meta_schedule/mutator/mutate_tile_size.py
index af3485b..ff432a6 100644
--- a/python/tvm/meta_schedule/mutator/__init__.py
+++ b/python/tvm/meta_schedule/mutator/mutate_tile_size.py
@@ -14,12 +14,18 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""
-The tvm.meta_schedule.mutator package.
-Meta Schedule mutator that mutates the trace to explore the
-design space.
-"""
-from .mutator import Mutator, PyMutator
-from .mutate_compute_location import MutateComputeLocation
-from .mutate_parallel import MutateParallel
-from .mutate_unroll import MutateUnroll
+"""Mutator that mutates the decision of instruction Sample-Perfect-Tile"""
+from tvm._ffi.registry import register_object
+
+from .. import _ffi_api
+from .mutator import Mutator
+
+
+@register_object("meta_schedule.MutateTileSize")
+class MutateTileSize(Mutator):
+    """Mutator that mutates the decision of instruction Sample-Perfect-Tile"""
+
+    def __init__(self) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.MutatorMutateTileSize,  # type: ignore # pylint: disable=no-member
+        )
diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc
new file mode 100644
index 0000000..6e03488
--- /dev/null
+++ b/src/meta_schedule/mutator/mutate_tile_size.cc
@@ -0,0 +1,273 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <mutex>
+#include <unordered_map>
+
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+using tir::Instruction;
+using tir::InstructionKind;
+using tir::Trace;
+
+/*!
+ * \brief Convert the decision attached to a Sample-Perfect-Tile instruction into plain integers
+ * \param decision The decision object recorded in the trace; it must be an array
+ * \return The tile sizes held by the decision, in order
+ */
+std::vector<int64_t> DowncastTilingDecision(const ObjectRef& decision) {
+  const runtime::ArrayNode* arr = TVM_TYPE_AS(arr, decision, runtime::ArrayNode);
+  Array<ObjectRef> tile_sizes = GetRef<Array<ObjectRef>>(arr);
+  return support::AsVector<ObjectRef, int64_t>(tile_sizes);
+}
+
+/*!
+ * \brief Multiply together all elements of a vector
+ * \param array The values to multiply
+ * \return The product of all values; 1 for an empty vector
+ */
+int64_t Product(const std::vector<int64_t>& array) {
+  int64_t prod = 1;
+  for (size_t i = 0, n = array.size(); i < n; ++i) {
+    prod = prod * array[i];
+  }
+  return prod;
+}
+
+/*!
+ * \brief A mutator that mutates tiling-related sampling decisions in a trace. Each application
+ * picks one of:
+ *  - the decision of a Sample-Perfect-Tile instruction, mutated by moving a factor from one tile
+ *    to another;
+ *  - the decision of a Sample-Categorical instruction whose output is used as the value of a
+ *    `meta_schedule_cooperative_fetch` annotation, mutated by sampling a different candidate.
+ */
+class MutateTileSizeNode : public MutatorNode {
+ public:
+  /*! \brief The mutator has no attributes to visit. */
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+  static constexpr const char* _type_key = "meta_schedule.MutateTileSize";
+  TVM_DECLARE_FINAL_OBJECT_INFO(MutateTileSizeNode, MutatorNode);
+
+ public:
+  // Inherit from `MutatorNode`; the mutator holds no state, so there is nothing to initialize.
+  void InitializeWithTuneContext(const TuneContext& context) final {}
+  // Inherit from `MutatorNode`
+  Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;
+};
+
+/*!
+ * \brief Collect the Sample-Perfect-Tile instructions of the trace whose decision can be mutated,
+ * i.e. the decision has at least two tiles whose product is at least 2
+ * \param trace The trace to scan
+ * \param inst Output: the Sample-Perfect-Tile instructions found
+ * \param decision Output: the tile sizes decided by each instruction in `inst`, in the same order
+ */
+void FindSamplePerfectTile(const Trace& trace, std::vector<Instruction>* inst,
+                           std::vector<std::vector<int64_t>>* decision) {
+  static const InstructionKind& inst_sample_perfect_tile =
+      InstructionKind::Get("SamplePerfectTile");
+  inst->reserve(trace->decisions.size());
+  decision->reserve(trace->decisions.size());
+  for (const auto& kv : trace->decisions) {
+    const Instruction& candidate = kv.first;
+    if (!candidate->kind.same_as(inst_sample_perfect_tile)) {
+      continue;
+    }
+    std::vector<int64_t> tiles = DowncastTilingDecision(kv.second);
+    // A single tile, or tiles that are all 1, leave nothing to move around
+    bool can_mutate = tiles.size() >= 2 && Product(tiles) >= 2;
+    if (!can_mutate) {
+      continue;
+    }
+    inst->push_back(candidate);
+    decision->push_back(std::move(tiles));
+  }
+}
+
+/*!
+ * \brief Collect the Sample-Categorical instructions whose sampled value is used as the value of a
+ * `meta_schedule_cooperative_fetch` annotation, together with their current decisions
+ * \param trace The trace to scan
+ * \param inst Output: the Sample-Categorical instructions found
+ * \param decision Output: the decision (index of the chosen candidate) of each instruction in
+ * `inst`, in the same order
+ */
+void FindSampleVectorize(const Trace& trace, std::vector<Instruction>* inst,
+                         std::vector<int64_t>* decision) {
+  static const InstructionKind& inst_sample_categorical = InstructionKind::Get("SampleCategorical");
+  static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate");
+  // The loops below declare locals named `inst`/`decision`, so alias the outputs first
+  std::vector<Instruction>& found_insts = *inst;
+  std::vector<int64_t>& found_decisions = *decision;
+  const size_t n_decisions = trace->decisions.size();
+  found_insts.reserve(n_decisions);
+  found_decisions.reserve(n_decisions);
+  // Step 1. Record the random variables used as values of cooperative-fetch annotations
+  std::unordered_set<const Object*> annotated;
+  annotated.reserve(n_decisions);
+  for (const Instruction& inst : trace->insts) {
+    if (!inst->kind.same_as(inst_annotate)) {
+      continue;
+    }
+    ICHECK_EQ(inst->attrs.size(), 1);
+    ICHECK_EQ(inst->inputs.size(), 2);
+    if (Downcast<String>(inst->attrs[0]) == tir::attr::meta_schedule_cooperative_fetch) {
+      const auto* ann_val = inst->inputs[1].as<tir::ExprRVNode>();
+      ICHECK(ann_val);
+      annotated.insert(ann_val);
+    }
+  }
+  // Step 2. Keep the sampling instructions that produced one of those random variables
+  for (const auto& kv : trace->decisions) {
+    const Instruction& inst = kv.first;
+    if (!inst->kind.same_as(inst_sample_categorical)) {
+      continue;
+    }
+    ICHECK_EQ(inst->outputs.size(), 1);
+    if (annotated.count(inst->outputs[0].get()) == 0) {
+      continue;
+    }
+    const ObjectRef& decision = kv.second;
+    const auto* d = TVM_TYPE_AS(d, decision, IntImmNode);
+    found_insts.push_back(inst);
+    found_decisions.push_back(d->value);
+  }
+}
+
+/*! \brief A process-wide, mutex-protected memo of the factors of integers */
+struct FactorMemo {
+  /*!
+   * \brief Find all factors of the input integer
+   * \param n The integer to be factorized
+   * \return The factors of the input integer in ascending order. Both 1 and n itself are included,
+   * so the result has at least two elements for any n >= 2 (and is {1} for n == 1).
+   */
+  static std::vector<int> Factorize(int n) {
+    if (const std::vector<int>* result = Global()->Query(n)) {
+      return *result;
+    }
+    std::vector<int> result;
+    // Divisors come in pairs (i, n / i) with i <= sqrt(n). The bound must be inclusive, otherwise
+    // the square root of a perfect square (e.g. 2 for n = 4) would be dropped.
+    for (int64_t i = 1; i * i <= n; ++i) {
+      if (n % i == 0) {
+        result.push_back(static_cast<int>(i));
+        if (i * i != n) {
+          result.push_back(static_cast<int>(n / i));
+        }
+      }
+    }
+    std::sort(result.begin(), result.end());
+    Global()->Add(n, result);
+    return result;
+  }
+
+ private:
+  /*!
+   * \brief Look up a memoized factorization
+   * \param n The integer to look up
+   * \return A pointer to the memoized factors, or nullptr if `n` has not been memoized yet
+   * \note The pointer stays usable after the lock is released: entries are never erased, and
+   * rehashing an std::unordered_map does not invalidate pointers to its elements.
+   */
+  const std::vector<int>* Query(int n) {
+    std::unique_lock<std::mutex> lock(mutex);
+    auto it = memo.find(n);
+    if (it != memo.end()) {
+      return &it->second;
+    }
+    return nullptr;
+  }
+
+  /*!
+   * \brief Memoize the factors of `n`; if another thread memoized `n` first, its entry is kept
+   * \param n The integer that was factorized
+   * \param result The factors of `n`
+   */
+  void Add(int n, std::vector<int> result) {
+    std::unique_lock<std::mutex> lock(mutex);
+    memo.emplace(n, std::move(result));
+  }
+
+  /*! \brief The singleton memo shared by all callers */
+  static FactorMemo* Global() {
+    static FactorMemo singleton;
+    return &singleton;
+  }
+
+  /*! \brief Maps an integer to its sorted factors */
+  std::unordered_map<int, std::vector<int>> memo;
+  /*! \brief Guards `memo` */
+  std::mutex mutex;
+};
+
+/*!
+ * \brief Mutate the decision of a Sample-Perfect-Tile instruction by moving a factor from one tile
+ * (`x`) to another (`y`), which keeps the product of all tiles unchanged
+ * \param trace The trace to be mutated
+ * \param inst The Sample-Perfect-Tile instruction whose decision is mutated
+ * \param tiles The current decision of `inst`; taken by value and modified locally
+ * \param rand_state The random state used for all sampling in this function
+ * \return The new trace with the mutated decision, or NullOpt when there are only two tiles and
+ * the sampled move into the innermost tile would exceed its limit
+ * \note NOTE(review): assumes every tile is >= 1 and at least one tile is > 1 (which the filter in
+ * FindSamplePerfectTile establishes for positive tiles); otherwise the loop below never ends.
+ */
+Optional<Trace> MutateSampleTileSize(const Trace& trace, Instruction inst,
+                                     std::vector<int64_t> tiles, TRandState* rand_state) {
+  int n_splits = tiles.size();
+  // Step 1. Choose two loops, `x` and `y`
+  int x, y;
+  // Retry until a valid move is sampled: both the `tiles[x] <= 1` check and a failure to fit the
+  // innermost limit below restart the loop with freshly sampled `x` and `y`.
+  while (true) {
+    // Step 1.1. The source `x` must have a tile size > 1 so that it can be divided
+    x = tir::SampleInt(rand_state, 0, n_splits);
+    if (tiles[x] <= 1) {
+      continue;
+    }
+    // Step 1.2. The destination `y` is one of the other `n_splits - 1` loops: sample from
+    // [0, n_splits - 1) and shift indices >= x up by one to skip `x`
+    y = tir::SampleInt(rand_state, 0, n_splits - 1);
+    if (y >= x) {
+      ++y;
+    }
+    std::vector<int> factors = FactorMemo::Factorize(tiles[x]);
+    // Step 2. Choose the divide factor
+    int64_t divide_factor;
+    if (y != n_splits - 1) {
+      // `y` is not the innermost loop: any factor except `factors[0]`, which is 1, is allowed
+      divide_factor = factors[tir::SampleInt(rand_state, 1, factors.size())];
+    } else {
+      // `y` is the innermost loop: its new size `factor * tiles[y]` must not exceed `limit`, read
+      // from the instruction's second attribute (presumably max_innermost_factor; confirm)
+      int64_t limit = Downcast<Integer>(inst->attrs[1])->value;
+      // Find the largest index whose factor still fits within `limit` (factors are ascending)
+      int max_factor_index = static_cast<int>(factors.size()) - 1;
+      for (; max_factor_index >= 1; max_factor_index--) {
+        if (factors[max_factor_index] * tiles[y] <= limit) {
+          break;
+        }
+      }
+      if (max_factor_index == 0) {
+        if (n_splits <= 2) {
+          return NullOpt;
+        }
+        // No factor other than 1 fits into the innermost loop; resample both `x` and `y`.
+        continue;
+      }
+      divide_factor = factors[tir::SampleInt(rand_state, 1, max_factor_index + 1)];
+    }
+    tiles[x] /= divide_factor;
+    tiles[y] *= divide_factor;
+    return trace->WithDecision(inst, support::AsArray<int64_t, ObjectRef>(tiles),
+                               /*remove_postproc=*/true);
+  }
+}
+
+/*!
+ * \brief Mutate the decision of a Sample-Categorical instruction used for cooperative fetch by
+ * sampling a candidate different from the current one
+ * \param trace The trace to be mutated
+ * \param inst The Sample-Categorical instruction whose decision is mutated
+ * \param original_decision The current decision of `inst`, i.e. the index of the chosen candidate
+ * \param rand_state The random state used for sampling
+ * \return The new trace with the mutated decision, or NullOpt if no other candidate can be drawn
+ */
+Optional<Trace> MutateSampleVectorize(const Trace& trace, Instruction inst,
+                                      int64_t original_decision, TRandState* rand_state) {
+  ICHECK_EQ(inst->attrs.size(), 2);
+  std::vector<double> probs =
+      support::AsVector<FloatImm, double>(Downcast<Array<FloatImm>>(inst->attrs[1]));
+  // `erase` below is undefined behavior for an out-of-range index
+  ICHECK(0 <= original_decision && original_decision < static_cast<int64_t>(probs.size()))
+      << "ValueError: decision " << original_decision << " is out of range for " << probs.size()
+      << " candidates";
+  // Exclude the current decision so that the sampled candidate is always a different one
+  probs.erase(probs.begin() + original_decision);
+  // With a single candidate, or only zero-probability alternatives left, there is nothing valid to
+  // sample, so the decision cannot be mutated
+  double total_prob = 0.0;
+  for (double p : probs) {
+    total_prob += p;
+  }
+  if (total_prob <= 0.0) {
+    return NullOpt;
+  }
+  int result = tir::MakeMultinomialSampler(rand_state, probs)();
+  // `result` indexes the reduced list; shift it back to an index into the full candidate list
+  if (result >= original_decision) {
+    result += 1;
+  }
+  return trace->WithDecision(inst, Integer(result), /*remove_postproc=*/true);
+}
+
+/*!
+ * \brief Pick one mutable decision of the trace at random, either a tiling decision or a
+ * cooperative-fetch vectorization decision, and mutate it
+ */
+Optional<Trace> MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_state) {
+  std::vector<Instruction> tile_insts;
+  std::vector<std::vector<int64_t>> tile_decisions;
+  FindSamplePerfectTile(trace, &tile_insts, &tile_decisions);
+  std::vector<Instruction> vectorize_insts;
+  std::vector<int64_t> vectorize_decisions;
+  FindSampleVectorize(trace, &vectorize_insts, &vectorize_decisions);
+  const int n_tile = tile_insts.size();
+  const int n_vectorize = vectorize_insts.size();
+  const int n_total = n_tile + n_vectorize;
+  if (n_total == 0) {
+    return NullOpt;
+  }
+  // Indices [0, n_tile) select a tiling decision; the rest select a vectorization decision
+  const int chosen = tir::SampleInt(rand_state, 0, n_total);
+  if (chosen >= n_tile) {
+    const int idx = chosen - n_tile;
+    return MutateSampleVectorize(trace, vectorize_insts[idx], vectorize_decisions[idx],
+                                 rand_state);
+  }
+  return MutateSampleTileSize(trace, tile_insts[chosen], tile_decisions[chosen], rand_state);
+}
+
+/*! \brief Construct the tile-size mutator; exposed through FFI as `MutatorMutateTileSize` */
+Mutator Mutator::MutateTileSize() {
+  auto node = make_object<MutateTileSizeNode>();
+  return Mutator(std::move(node));
+}
+
+// Register the node type and the FFI constructor used by `tvm.meta_schedule.mutator`
+TVM_REGISTER_NODE_TYPE(MutateTileSizeNode);
+TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateTileSize").set_body_typed(Mutator::MutateTileSize);
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_tile_size.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_tile_size.py
new file mode 100644
index 0000000..9e75497
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_tile_size.py
@@ -0,0 +1,93 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+import operator
+from functools import reduce
+from typing import List
+
+from tvm.meta_schedule import TuneContext
+from tvm.meta_schedule.mutator import MutateTileSize, Mutator
+from tvm.script import tir as T
+from tvm.target import Target
+from tvm.tir import Schedule
+
+# pylint: disable=invalid-name, no-member
+
+
+# 512x512x512 matmul workload used as the tuning target. Note that `B` is indexed as
+# B[vj, vk], so the block computes C = A @ B^T rather than A @ B.
+@T.prim_func
+def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, [512, 512])
+    B = T.match_buffer(b, [512, 512])
+    C = T.match_buffer(c, [512, 512])
+    for i, j, k in T.grid(512, 512, 512):  # type: ignore
+        with T.block("C"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])  # type: ignore
+            with T.init():
+                C[vi, vj] = 0.0  # type: ignore
+            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+# pylint: enable=invalid-name, no-member
+
+
+def _sch(decisions: List[List[int]]) -> Schedule:
+    """Replay a fixed schedule of `matmul` whose only sampled instruction is a 4-way
+    Sample-Perfect-Tile on loop `i`; all other splits use constant factors.
+
+    Parameters
+    ----------
+    decisions : List[List[int]]
+        A single-element list holding the decision of that Sample-Perfect-Tile instruction,
+        e.g. ``[[4, 32, 4, 1]]``.
+
+    Returns
+    -------
+    Schedule
+        The schedule, whose trace is fed to the mutator.
+    """
+    # NOTE: the order of the calls below fixes the instruction indices in the trace. The test
+    # expects the Sample-Perfect-Tile instruction at index 4 (after get_block, get_consumers,
+    # cache_write and get_loops), so do not insert or reorder calls before it.
+    sch = Schedule(matmul, debug_mask="all")
+    # pylint: disable=invalid-name
+    (d0,) = decisions
+    b0 = sch.get_block(name="C", func_name="main")
+    sch.get_consumers(block=b0)
+    b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")
+    l2, l3, l4 = sch.get_loops(block=b0)
+    v5, v6, v7, v8 = sch.sample_perfect_tile(
+        loop=l2,
+        n=4,
+        max_innermost_factor=64,
+        decision=d0,
+    )
+    l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])
+    l17, l18, l19, l20 = sch.split(loop=l3, factors=[8, 4, 8, 2])
+    l23, l24 = sch.split(loop=l4, factors=[512, 1])
+    sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)
+    sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=True)
+    # pylint: enable=invalid-name
+    return sch
+
+
+def _make_mutator(target: Target) -> Mutator:
+    mutator = MutateTileSize()
+    mutator.initialize_with_tune_context(TuneContext(mod=matmul, target=target))
+    return mutator
+
+
+def test_mutate_tile_size_matmul():
+    mutator = _make_mutator(
+        target=Target("llvm --num-cores=16"),
+    )
+    results = {}
+    sch = _sch(decisions=[[4, 32, 4, 1]])
+    for _ in range(100):
+        trace = mutator.apply(sch.trace)
+        assert trace.insts[4].kind.name == "SamplePerfectTile"
+        decision = trace.decisions[trace.insts[4]]
+        decision = [int(x) for x in decision]
+        results[str(decision)] = decision
+        assert reduce(operator.mul, decision, 1) == 512
+    assert len(results) > 15
+
+
+if __name__ == "__main__":
+    test_mutate_tile_size_matmul()