You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2021/12/23 01:37:32 UTC
[tvm] branch main updated: [M3c][MetaScheduler] Add More Measure Callbacks. (#9780)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b35fc83 [M3c][MetaScheduler] Add More Measure Callbacks. (#9780)
b35fc83 is described below
commit b35fc83670b47f18f997f24e1f22e263cd48e9fc
Author: Xiyou Zhou <xi...@octoml.ai>
AuthorDate: Wed Dec 22 17:36:57 2021 -0800
[M3c][MetaScheduler] Add More Measure Callbacks. (#9780)
* Add measure callbacks.
Co-authored-by: Junru Shao <ju...@gmail.com>
Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
Co-authored-by: Ruihang Lai <la...@qq.com>
Co-authored-by: Hongyi Jin <32...@qq.com>
Co-authored-by: Wuwei Lin <wu...@apache.org>
Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
* Fix comments.
Co-authored-by: Junru Shao <ju...@gmail.com>
Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
Co-authored-by: Ruihang Lai <la...@qq.com>
Co-authored-by: Hongyi Jin <32...@qq.com>
Co-authored-by: Wuwei Lin <wu...@apache.org>
Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
---
include/tvm/meta_schedule/measure_callback.h | 146 +++++++++
include/tvm/meta_schedule/task_scheduler.h | 6 +
include/tvm/meta_schedule/tune_context.h | 1 +
.../tvm/meta_schedule/measure_callback/__init__.py | 24 ++
.../measure_callback/add_to_database.py | 30 ++
.../measure_callback/echo_statistics.py | 30 ++
.../measure_callback/measure_callback.py | 104 +++++++
.../measure_callback/remove_build_artifact.py | 30 ++
.../measure_callback/update_cost_model.py | 30 ++
.../measure_callback/add_to_database.cc | 65 ++++
.../measure_callback/echo_statistics.cc | 336 +++++++++++++++++++++
.../measure_callback/measure_callback.cc | 50 +++
.../measure_callback/remove_build_artifact.cc | 52 ++++
.../measure_callback/update_cost_model.cc | 53 ++++
src/meta_schedule/utils.h | 19 ++
.../test_meta_schedule_measure_callback.py | 132 ++++++++
16 files changed, 1108 insertions(+)
diff --git a/include/tvm/meta_schedule/measure_callback.h b/include/tvm/meta_schedule/measure_callback.h
new file mode 100644
index 0000000..e9abb12
--- /dev/null
+++ b/include/tvm/meta_schedule/measure_callback.h
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef TVM_META_SCHEDULE_MEASURE_CALLBACK_H_
+#define TVM_META_SCHEDULE_MEASURE_CALLBACK_H_
+
+#include <tvm/meta_schedule/builder.h>
+#include <tvm/meta_schedule/runner.h>
+#include <tvm/meta_schedule/search_strategy.h>
+#include <tvm/meta_schedule/tune_context.h>
+
+namespace tvm {
+namespace meta_schedule {
+
+class TaskScheduler;
+
+/*! \brief Rules to apply after measure results are available. */
+class MeasureCallbackNode : public runtime::Object {
+ public:
+  /*! \brief Virtual destructor. */
+  virtual ~MeasureCallbackNode() = default;
+
+  /*! \brief The base measure callback has no attributes to visit. */
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+
+  /*!
+   * \brief Apply a measure callback rule with given arguments.
+   * \param task_scheduler The task scheduler.
+   * \param task_id The id of the task (tune context) to apply measure callbacks.
+   * \param measure_candidates The measure candidates.
+   * \param builder_results The builder results by building the measure candidates.
+   * \param runner_results The runner results by running the built measure candidates.
+   */
+  virtual void Apply(const TaskScheduler& task_scheduler,                //
+                     int task_id,                                        //
+                     const Array<MeasureCandidate>& measure_candidates,  //
+                     const Array<BuilderResult>& builder_results,        //
+                     const Array<RunnerResult>& runner_results) = 0;
+
+  static constexpr const char* _type_key = "meta_schedule.MeasureCallback";
+  TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object);
+};
+
+/*! \brief The measure callback with customized methods on the python-side. */
+class PyMeasureCallbackNode : public MeasureCallbackNode {
+ public:
+  /*!
+   * \brief The signature of the packed function implementing `Apply`.
+   * \param task_scheduler The task scheduler.
+   * \param task_id The id of the task (tune context) to apply measure callbacks.
+   * \param measure_candidates The measure candidates.
+   * \param builder_results The builder results by building the measure candidates.
+   * \param runner_results The runner results by running the built measure candidates.
+   */
+  using FApply =
+      runtime::TypedPackedFunc<void(const TaskScheduler& task_scheduler,                //
+                                    int task_id,                                        //
+                                    const Array<MeasureCandidate>& measure_candidates,  //
+                                    const Array<BuilderResult>& builder_results,        //
+                                    const Array<RunnerResult>& runner_results)>;
+  /*!
+   * \brief The signature of the packed function implementing `AsString`.
+   * \return The string representation of the measure callback.
+   */
+  using FAsString = runtime::TypedPackedFunc<String()>;
+
+  /*! \brief The packed function to the `Apply` function. */
+  FApply f_apply;
+  /*! \brief The packed function to the `AsString` function. */
+  FAsString f_as_string;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    // `f_apply` is not visited
+    // `f_as_string` is not visited
+  }
+
+  /*!
+   * \brief Forward the call to the python-side `f_apply`.
+   * \note Fails an ICHECK when `f_apply` was not provided.
+   */
+  void Apply(const TaskScheduler& task_scheduler,                //
+             int task_id,                                        //
+             const Array<MeasureCandidate>& measure_candidates,  //
+             const Array<BuilderResult>& builder_results,        //
+             const Array<RunnerResult>& runner_results) final {
+    ICHECK(f_apply != nullptr) << "PyMeasureCallback's Apply method not implemented!";
+    f_apply(task_scheduler, task_id, measure_candidates, builder_results, runner_results);
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.PyMeasureCallback";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PyMeasureCallbackNode, MeasureCallbackNode);
+};
+
+/*!
+ * \brief Managed reference to MeasureCallbackNode
+ * \sa MeasureCallbackNode
+ */
+class MeasureCallback : public runtime::ObjectRef {
+ public:
+  /*!
+   * \brief Create a measure callback that adds the measurement results into the database.
+   * \return The measure callback created.
+   */
+  TVM_DLL static MeasureCallback AddToDatabase();
+  /*!
+   * \brief Create a measure callback that removes the build artifacts from the disk.
+   * \return The measure callback created.
+   */
+  TVM_DLL static MeasureCallback RemoveBuildArtifact();
+  /*!
+   * \brief Create a measure callback that echoes the statistics of the tuning process
+   * (via LOG(INFO)).
+   * \return The measure callback created.
+   */
+  TVM_DLL static MeasureCallback EchoStatistics();
+  /*!
+   * \brief Create a measure callback that updates the cost model with measurement results.
+   * \return The measure callback created.
+   */
+  TVM_DLL static MeasureCallback UpdateCostModel();
+  /*!
+   * \brief Create a measure callback with customized methods on the python-side.
+   * \param f_apply The packed function of `Apply`.
+   * \param f_as_string The packed function of `AsString`.
+   * \return The measure callback created.
+   */
+  TVM_DLL static MeasureCallback PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply,
+                                                   PyMeasureCallbackNode::FAsString f_as_string);
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode);
+};
+
+} // namespace meta_schedule
+} // namespace tvm
+
+#endif // TVM_META_SCHEDULE_MEASURE_CALLBACK_H_
diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h
index 5841e85..f28c33d 100644
--- a/include/tvm/meta_schedule/task_scheduler.h
+++ b/include/tvm/meta_schedule/task_scheduler.h
@@ -73,6 +73,10 @@ class TaskSchedulerNode : public runtime::Object {
Runner runner{nullptr};
/*! \brief The database of the scheduler. */
Database database{nullptr};
+ /*! \brief The cost model of the scheduler. */
+ Optional<CostModel> cost_model;
+ /*! \brief The list of measure callbacks of the scheduler. */
+ Array<MeasureCallback> measure_callbacks;
/*! \brief The default desctructor. */
virtual ~TaskSchedulerNode() = default;
@@ -82,6 +86,8 @@ class TaskSchedulerNode : public runtime::Object {
v->Visit("builder", &builder);
v->Visit("runner", &runner);
v->Visit("database", &database);
+ v->Visit("cost_model", &cost_model);
+ v->Visit("measure_callbacks", &measure_callbacks);
}
/*! \brief Auto-tuning. */
diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h
index 559f2da..6eacd4d 100644
--- a/include/tvm/meta_schedule/tune_context.h
+++ b/include/tvm/meta_schedule/tune_context.h
@@ -20,6 +20,7 @@
#define TVM_META_SCHEDULE_TUNE_CONTEXT_H_
#include <tvm/ir/module.h>
+#include <tvm/meta_schedule/schedule_rule.h>
#include <tvm/meta_schedule/space_generator.h>
#include <tvm/support/random_engine.h>
#include <tvm/target/target.h>
diff --git a/python/tvm/meta_schedule/measure_callback/__init__.py b/python/tvm/meta_schedule/measure_callback/__init__.py
new file mode 100644
index 0000000..f697e77
--- /dev/null
+++ b/python/tvm/meta_schedule/measure_callback/__init__.py
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+The tvm.meta_schedule.measure_callback package.
+"""
+from .measure_callback import MeasureCallback, PyMeasureCallback
+from .add_to_database import AddToDatabase
+from .echo_statistics import EchoStatistics
+from .remove_build_artifact import RemoveBuildArtifact
+from .update_cost_model import UpdateCostModel
diff --git a/python/tvm/meta_schedule/measure_callback/add_to_database.py b/python/tvm/meta_schedule/measure_callback/add_to_database.py
new file mode 100644
index 0000000..ab61e87
--- /dev/null
+++ b/python/tvm/meta_schedule/measure_callback/add_to_database.py
@@ -0,0 +1,30 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""A callback that adds the measurement results into the database"""
+from tvm._ffi import register_object
+
+from .. import _ffi_api
+from .measure_callback import MeasureCallback
+
+
+@register_object("meta_schedule.AddToDatabase")
+class AddToDatabase(MeasureCallback):
+ def __init__(self) -> None:
+ """A callback that adds the measurement results into the database"""
+ self.__init_handle_by_constructor__(
+ _ffi_api.MeasureCallbackAddToDatabase, # type: ignore # pylint: disable=no-member
+ )
diff --git a/python/tvm/meta_schedule/measure_callback/echo_statistics.py b/python/tvm/meta_schedule/measure_callback/echo_statistics.py
new file mode 100644
index 0000000..867409f
--- /dev/null
+++ b/python/tvm/meta_schedule/measure_callback/echo_statistics.py
@@ -0,0 +1,30 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""A callback that echos the statistics of the tuning process to the console"""
+from tvm._ffi import register_object
+
+from .. import _ffi_api
+from .measure_callback import MeasureCallback
+
+
+@register_object("meta_schedule.EchoStatistics")
+class EchoStatistics(MeasureCallback):
+ def __init__(self) -> None:
+ """A callback that echos the statistics of the tuning process to the console"""
+ self.__init_handle_by_constructor__(
+ _ffi_api.MeasureCallbackEchoStatistics, # type: ignore # pylint: disable=no-member
+ )
diff --git a/python/tvm/meta_schedule/measure_callback/measure_callback.py b/python/tvm/meta_schedule/measure_callback/measure_callback.py
new file mode 100644
index 0000000..2b3a369
--- /dev/null
+++ b/python/tvm/meta_schedule/measure_callback/measure_callback.py
@@ -0,0 +1,104 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Meta Schedule MeasureCallback."""
+
+from typing import List, TYPE_CHECKING
+
+from tvm._ffi import register_object
+from tvm.runtime import Object
+
+from .. import _ffi_api
+from ..builder import BuilderResult
+from ..runner import RunnerResult
+from ..search_strategy import MeasureCandidate
+from ..utils import _get_hex_address, check_override
+
+if TYPE_CHECKING:
+ from ..task_scheduler import TaskScheduler
+
+
+@register_object("meta_schedule.MeasureCallback")
+class MeasureCallback(Object):
+ """Rules to apply after measure results is available."""
+
+ def apply(
+ self,
+ task_scheduler: "TaskScheduler",
+ task_id: int,
+ measure_candidates: List[MeasureCandidate],
+ builder_results: List[BuilderResult],
+ runner_results: List[RunnerResult],
+ ) -> None:
+ """Apply a measure callback to the given schedule.
+
+ Parameters
+ ----------
+ task_scheduler: TaskScheduler
+ The task scheduler.
+ task_id: int
+ The task id.
+ measure_candidates: List[MeasureCandidate]
+ The measure candidates.
+ builder_results: List[BuilderResult]
+ The builder results by building the measure candidates.
+ runner_results: List[RunnerResult]
+ The runner results by running the built measure candidates.
+ """
+ return _ffi_api.MeasureCallbackApply( # type: ignore # pylint: disable=no-member
+ self,
+ task_scheduler,
+ task_id,
+ measure_candidates,
+ builder_results,
+ runner_results,
+ )
+
+
+@register_object("meta_schedule.PyMeasureCallback")
+class PyMeasureCallback(MeasureCallback):
+ """An abstract MeasureCallback with customized methods on the python-side."""
+
+ def __init__(self):
+ """Constructor."""
+
+ @check_override(self.__class__, MeasureCallback)
+ def f_apply(
+ task_scheduler: "TaskScheduler",
+ task_id: int,
+ measure_candidates: List[MeasureCandidate],
+ builder_results: List[BuilderResult],
+ runner_results: List[RunnerResult],
+ ) -> None:
+ return self.apply(
+ task_scheduler,
+ task_id,
+ measure_candidates,
+ builder_results,
+ runner_results,
+ )
+
+ def f_as_string() -> str:
+ return str(self)
+
+ self.__init_handle_by_constructor__(
+ _ffi_api.MeasureCallbackPyMeasureCallback, # type: ignore # pylint: disable=no-member
+ f_apply,
+ f_as_string,
+ )
+
+ def __str__(self) -> str:
+ return f"PyMeasureCallback({_get_hex_address(self.handle)})"
diff --git a/python/tvm/meta_schedule/measure_callback/remove_build_artifact.py b/python/tvm/meta_schedule/measure_callback/remove_build_artifact.py
new file mode 100644
index 0000000..4b2e1ab
--- /dev/null
+++ b/python/tvm/meta_schedule/measure_callback/remove_build_artifact.py
@@ -0,0 +1,30 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""A callback that removes the build artifacts from the disk"""
+from tvm._ffi import register_object
+
+from .. import _ffi_api
+from .measure_callback import MeasureCallback
+
+
+@register_object("meta_schedule.RemoveBuildArtifact")
+class RemoveBuildArtifact(MeasureCallback):
+ def __init__(self) -> None:
+ """A callback that removes the build artifacts from the disk"""
+ self.__init_handle_by_constructor__(
+ _ffi_api.MeasureCallbackRemoveBuildArtifact, # type: ignore # pylint: disable=no-member
+ )
diff --git a/python/tvm/meta_schedule/measure_callback/update_cost_model.py b/python/tvm/meta_schedule/measure_callback/update_cost_model.py
new file mode 100644
index 0000000..c6ee1d2
--- /dev/null
+++ b/python/tvm/meta_schedule/measure_callback/update_cost_model.py
@@ -0,0 +1,30 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""A measure callback that updates the cost model"""
+from tvm._ffi import register_object
+
+from .. import _ffi_api
+from .measure_callback import MeasureCallback
+
+
+@register_object("meta_schedule.UpdateCostModel")
+class UpdateCostModel(MeasureCallback):
+ def __init__(self) -> None:
+ """A measure callback that updates the cost model"""
+ self.__init_handle_by_constructor__(
+ _ffi_api.MeasureCallbackUpdateCostModel, # type: ignore # pylint: disable=no-member
+ )
diff --git a/src/meta_schedule/measure_callback/add_to_database.cc b/src/meta_schedule/measure_callback/add_to_database.cc
new file mode 100644
index 0000000..b294053
--- /dev/null
+++ b/src/meta_schedule/measure_callback/add_to_database.cc
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+class AddToDatabaseNode : public MeasureCallbackNode {
+ public:
+  /*!
+   * \brief Commit the successfully measured candidates of a task into the scheduler's database.
+   * \param task_scheduler The task scheduler providing the tasks and the database.
+   * \param task_id The index of the task in `task_scheduler->tasks`.
+   * \param measure_candidates The measured candidates.
+   * \param builder_results The builder results (not used by this callback).
+   * \param runner_results The runner results, one per measure candidate.
+   */
+  void Apply(const TaskScheduler& task_scheduler, int task_id,
+             const Array<MeasureCandidate>& measure_candidates,
+             const Array<BuilderResult>& builder_results,
+             const Array<RunnerResult>& runner_results) final {
+    // Validate before touching the database so a malformed call has no side effect.
+    ICHECK_EQ(runner_results.size(), measure_candidates.size());
+    TuneContext task = task_scheduler->tasks[task_id];
+    Database database = task_scheduler->database;
+    Workload workload = database->CommitWorkload(task->mod.value());
+    Target target = task->target.value();
+    int n = runner_results.size();
+    for (int i = 0; i < n; ++i) {
+      RunnerResult result = runner_results[i];
+      MeasureCandidate candidate = measure_candidates[i];
+      // Only successful runs with timing data become tuning records; calling
+      // `run_secs.value()` on an undefined Optional would fail an ICHECK.
+      if (result->error_msg.defined() || !result->run_secs.defined()) {
+        continue;
+      }
+      database->CommitTuningRecord(TuningRecord(
+          /*trace=*/candidate->sch->trace().value(),
+          /*run_secs=*/result->run_secs.value(),
+          /*workload=*/workload,
+          /*target=*/target,
+          /*args_info=*/candidate->args_info));
+    }
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.AddToDatabase";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AddToDatabaseNode, MeasureCallbackNode);
+};
+
+/*! \brief Factory for the callback that records measurements into the database. */
+MeasureCallback MeasureCallback::AddToDatabase() {
+  return MeasureCallback(make_object<AddToDatabaseNode>());
+}
+
+// Expose the node type and its factory to the FFI.
+TVM_REGISTER_NODE_TYPE(AddToDatabaseNode);
+TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackAddToDatabase")
+    .set_body_typed(MeasureCallback::AddToDatabase);
+
+} // namespace meta_schedule
+} // namespace tvm
diff --git a/src/meta_schedule/measure_callback/echo_statistics.cc b/src/meta_schedule/measure_callback/echo_statistics.cc
new file mode 100644
index 0000000..1209e6c
--- /dev/null
+++ b/src/meta_schedule/measure_callback/echo_statistics.cc
@@ -0,0 +1,336 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <sstream>
+
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Count the arithmetic operations in the bodies of the PrimFuncs of a module.
+ * \param mod The module to analyze; non-PrimFunc functions are ignored.
+ * \return The operation count, excluding operations of dtype int32, int64 and uint1.
+ * \note Every loop extent must be an IntImm, otherwise an ICHECK fails.
+ */
+double CountFlop(const IRModule& mod) {
+  // Per-dtype operation counters, keyed by a packed 32-bit encoding of the DataType.
+  struct TResult {
+    using TTable = std::unordered_map<int32_t, double>;
+
+    TResult() = default;
+
+    explicit TResult(const tvm::DataType& dtype) { Add(dtype); }
+
+    /*! \brief Count one operation of the given dtype. */
+    void Add(const tvm::DataType& dtype) { data_[DataType2Int(dtype)] += 1; }
+
+    /*! \brief Accumulate the counts of `rhs` into this result. */
+    TResult& operator+=(const TResult& rhs) {
+      for (const auto& kv : rhs.data_) {
+        data_[kv.first] += kv.second;
+      }
+      return *this;
+    }
+
+    /*! \brief Scale every count by `rhs` (used for loop trip counts). */
+    TResult& operator*=(int64_t rhs) {
+      for (auto& kv : data_) {
+        kv.second *= rhs;
+      }
+      return *this;
+    }
+
+    /*! \brief Take the per-dtype maximum with `rhs` (used for alternative branches). */
+    TResult& MaxWith(const TResult& rhs) {
+      for (const auto& kv : rhs.data_) {
+        double& v = data_[kv.first];
+        if (v < kv.second) {
+          v = kv.second;
+        }
+      }
+      return *this;
+    }
+
+    struct DType {
+      uint8_t code : 8;
+      uint8_t bits : 8;
+      uint16_t lanes : 16;
+    };
+    static_assert(sizeof(DType) == 4, "Incorrect size of DType");
+
+    static String Int2Str(int32_t dtype) {
+      union {
+        DType dst;
+        int32_t src;
+      } converter;
+      converter.src = dtype;
+      static const char* const type_code_tab[] = {"int", "uint", "float", "handle", "bfloat"};
+      constexpr int kNumTypeCodes = sizeof(type_code_tab) / sizeof(type_code_tab[0]);
+      std::ostringstream os;
+      if (converter.dst.code < kNumTypeCodes) {
+        os << type_code_tab[converter.dst.code];
+      } else {
+        // Type codes without a table entry (e.g. custom types) must not index past the table.
+        os << "custom[" << static_cast<int>(converter.dst.code) << "]";
+      }
+      os << static_cast<int>(converter.dst.bits);
+      if (converter.dst.lanes != 1) {
+        os << "x" << static_cast<int>(converter.dst.lanes);
+      }
+      return os.str();
+    }
+
+    static int32_t DataType2Int(const tvm::DataType& dtype) {
+      union {
+        DType src;
+        int32_t dst;
+      } converter;
+      converter.src.code = dtype.code();
+      converter.src.bits = dtype.bits();
+      converter.src.lanes = dtype.lanes();
+      return converter.dst;
+    }
+
+    TTable data_;
+  };
+
+  class FlopCounter : public ExprFunctor<TResult(const PrimExpr& n)>,
+                      public StmtFunctor<TResult(const Stmt& n)> {
+   public:
+    ~FlopCounter() {}
+
+    TResult VisitExpr(const PrimExpr& expr) override { return ExprFunctor::VisitExpr(expr); }
+    TResult VisitStmt(const Stmt& stmt) override { return StmtFunctor::VisitStmt(stmt); }
+
+    TResult VisitStmt_(const IfThenElseNode* branch) override {
+      TResult cond = VisitExpr(branch->condition);
+      if (branch->else_case.defined()) {
+        // Only one branch executes, so count the more expensive one per dtype.
+        cond += VisitStmt(branch->then_case).MaxWith(VisitStmt(branch->else_case));
+      } else {
+        // `else_case` is optional; dispatching on an undefined Stmt dereferences null.
+        cond += VisitStmt(branch->then_case);
+      }
+      return cond;
+    }
+
+    TResult VisitStmt_(const BufferStoreNode* store) override {
+      TResult result = VisitExpr(store->value);
+      for (const PrimExpr& e : store->indices) {
+        result += VisitExpr(e);
+      }
+      return result;
+    }
+
+    TResult VisitStmt_(const SeqStmtNode* seq) override {
+      TResult result;
+      for (const Stmt& stmt : seq->seq) {
+        result += VisitStmt(stmt);
+      }
+      return result;
+    }
+
+    TResult VisitStmt_(const BlockRealizeNode* block) override {
+      return VisitStmt(block->block->body);
+    }
+
+    TResult VisitStmt_(const BlockNode* block) override {
+      TResult result;
+      if (block->init.defined()) {
+        result += VisitStmt(block->init.value());
+      }
+      result += VisitStmt(block->body);
+      return result;
+    }
+
+    TResult VisitStmt_(const ForNode* loop) override {
+      TResult result = VisitStmt(loop->body);
+      const auto* int_imm = loop->extent.as<IntImmNode>();
+      ICHECK(int_imm) << "TypeError: Expect the extent of a loop to be IntImm, but gets: "
+                      << loop->extent->GetTypeKey();
+      result *= int_imm->value;
+      return result;
+    }
+
+#define TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(Node) \
+  TResult VisitExpr_(const Node* op) final {        \
+    TResult result(op->dtype);                      \
+    result += VisitExpr(op->a);                     \
+    result += VisitExpr(op->b);                     \
+    return result;                                  \
+  }
+    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(AddNode);
+    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(SubNode);
+    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(MulNode);
+    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(DivNode);
+    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(ModNode);
+    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(FloorDivNode);
+    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(FloorModNode);
+    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(MinNode);
+    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(MaxNode);
+    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(EQNode);
+    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(NENode);
+    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(LTNode);
+    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(LENode);
+    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(GTNode);
+    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(GENode);
+    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(AndNode);
+    TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(OrNode);
+#undef TVM_META_SCHEDULE_FLOP_COUNTER_BINARY
+    TResult VisitExpr_(const CastNode* op) override { return VisitExpr(op->value); }
+    TResult VisitExpr_(const VarNode* op) override { return TResult(); }
+    TResult VisitExpr_(const SizeVarNode* op) override { return TResult(); }
+    TResult VisitExpr_(const BufferLoadNode* op) override { return TResult(); }
+    TResult VisitExpr_(const IntImmNode* op) override { return TResult(); }
+    TResult VisitExpr_(const FloatImmNode* op) override { return TResult(); }
+    TResult VisitExpr_(const NotNode* op) override {
+      TResult result(op->dtype);
+      result += VisitExpr(op->a);
+      return result;
+    }
+    TResult VisitExpr_(const SelectNode* op) override {
+      TResult cond = VisitExpr(op->condition);
+      cond += VisitExpr(op->true_value).MaxWith(VisitExpr(op->false_value));
+      return cond;
+    }
+    TResult VisitExpr_(const CallNode* op) override {
+      TResult ret;
+      for (const auto& x : op->args) {
+        ret += VisitExpr(x);
+      }
+      return ret;
+    }
+  };
+  FlopCounter counter;
+  TResult result;
+  for (const auto& kv : mod->functions) {
+    const BaseFunc& base_func = kv.second;
+    if (const auto* prim_func = base_func.as<PrimFuncNode>()) {
+      result += counter.VisitStmt(prim_func->body);
+    }
+  }
+  // Operations on int32/int64 (presumably index arithmetic) and uint1 (booleans) are excluded.
+  double cnt = 0.0;
+  int i32 = TResult::DataType2Int(tvm::DataType::Int(32));
+  int i64 = TResult::DataType2Int(tvm::DataType::Int(64));
+  int u1 = TResult::DataType2Int(tvm::DataType::UInt(1));
+  for (const auto& kv : result.data_) {
+    if (kv.first != i32 && kv.first != i64 && kv.first != u1) {
+      cnt += kv.second;
+    }
+  }
+  return cnt;
+}
+
+} // namespace tir
+} // namespace tvm
+
+namespace tvm {
+namespace meta_schedule {
+
+/*! \brief Sentinel "infinitely slow" running time, in milliseconds. */
+constexpr const double kMaxTime = 1e10;
+
+/*! \brief Render a task label of the form "#<task_id>: <task_name>". */
+std::string GetTaskName(const TuneContext& task, int task_id) {
+  std::ostringstream label;
+  label << '#' << task_id << ": " << task->task_name;
+  return label.str();
+}
+
+/*!
+ * \brief Average the measured running times and convert the result to milliseconds.
+ * \param run_secs The measured running times, in seconds.
+ * \return The mean running time in milliseconds, or kMaxTime when `run_secs` is empty.
+ */
+double GetRunMs(const Array<FloatImm>& run_secs) {
+  if (run_secs.empty()) {
+    // Avoid 0/0 (NaN); kMaxTime never compares below the initial best time of a task.
+    return kMaxTime;
+  }
+  double total = 0.0;
+  for (const FloatImm& i : run_secs) {
+    total += i->value;
+  }
+  return total * 1e3 / run_secs.size();
+}
+
+/*! \brief Running tuning statistics of a single task. */
+struct TaskInfo {
+  std::string name;
+  double flop = 0.0;
+  int trials = 0;
+  int best_round = -1;
+  double best_ms = kMaxTime;
+  double best_gflops = 0.0;
+  int error_count = 0;
+
+  explicit TaskInfo(const String& name) : name(name) {}
+
+  /*! \brief Record a successful trial that ran for `run_ms` milliseconds and log it. */
+  void Update(double run_ms) {
+    trials += 1;
+    const double gflops = flop / run_ms / 1e6;
+    if (run_ms < best_ms) {
+      best_ms = run_ms;
+      best_round = trials;
+      best_gflops = gflops;
+    }
+    LOG(INFO) << "[" << name << "] Trial #" << trials  //
+              << std::fixed << std::setprecision(4)    //
+              << ": GFLOPs: " << gflops                //
+              << ". Time: " << run_ms << " ms"         //
+              << ". Best GFLOPs: " << best_gflops;
+  }
+
+  /*!
+   * \brief Record a failed trial and log the error (post-processed by the registered
+   * `meta_schedule._process_error_message` function) together with the candidate.
+   */
+  void UpdateError(std::string err, const MeasureCandidate& candidate) {
+    static const auto* f_proc = runtime::Registry::Get("meta_schedule._process_error_message");
+    ICHECK(f_proc != nullptr);
+    const std::string processed = (*f_proc)(err).operator std::string();
+    error_count += 1;
+    trials += 1;
+    LOG(INFO) << "[" << name << "] Trial #" << trials  //
+              << std::fixed << std::setprecision(4)    //
+              << ": Error in building: " << processed << "\n"
+              << tir::AsTVMScript(candidate->sch->mod()) << "\n"
+              << Concat(candidate->sch->trace().value()->AsPython(false), "\n");
+  }
+};
+
+class EchoStatisticsNode : public MeasureCallbackNode {
+ public:
+  /*!
+   * \brief Log the outcome of every measured candidate of one task.
+   * \param task_scheduler The task scheduler; its task list initializes `task_info` once.
+   * \param task_id The index of the task whose candidates were measured.
+   * \param measure_candidates The measure candidates.
+   * \param builder_results The builder results, one per candidate.
+   * \param runner_results The runner results, one per candidate.
+   */
+  void Apply(const TaskScheduler& task_scheduler, int task_id,
+             const Array<MeasureCandidate>& measure_candidates,
+             const Array<BuilderResult>& builder_results,
+             const Array<RunnerResult>& runner_results) final {
+    if (this->task_info.empty()) {
+      SetupTaskInfo(task_scheduler->tasks);
+    }
+    ICHECK_EQ(measure_candidates.size(), builder_results.size());
+    ICHECK_EQ(measure_candidates.size(), runner_results.size());
+    // `task_info` is built only once, so a task id it does not cover would index out of bounds.
+    ICHECK(0 <= task_id && task_id < static_cast<int>(this->task_info.size()))
+        << "IndexError: task_id " << task_id << " is out of range [0, "
+        << this->task_info.size() << ")";
+    int n = measure_candidates.size();
+    TaskInfo& info = this->task_info[task_id];
+    for (int i = 0; i < n; ++i) {
+      MeasureCandidate candidate = measure_candidates[i];
+      BuilderResult builder_result = builder_results[i];
+      RunnerResult runner_result = runner_results[i];
+      if (Optional<String> err = builder_result->error_msg) {
+        info.UpdateError(err.value(), candidate);
+      } else if (Optional<String> err = runner_result->error_msg) {
+        info.UpdateError(err.value(), candidate);
+      } else {
+        ICHECK(runner_result->run_secs.defined());
+        info.Update(GetRunMs(runner_result->run_secs.value()));
+      }
+    }
+  }
+
+  /*! \brief Create one TaskInfo per task, with its FLOP count computed up front. */
+  void SetupTaskInfo(const Array<TuneContext>& tasks) {
+    task_info.reserve(tasks.size());
+    int task_id = 0;
+    for (const TuneContext& task : tasks) {
+      task_info.emplace_back(GetTaskName(task, task_id));
+      TaskInfo& info = task_info.back();
+      info.flop = tir::CountFlop(task->mod.value());
+      ++task_id;
+    }
+  }
+
+  std::vector<TaskInfo> task_info;
+
+  static constexpr const char* _type_key = "meta_schedule.EchoStatistics";
+  TVM_DECLARE_FINAL_OBJECT_INFO(EchoStatisticsNode, MeasureCallbackNode);
+};
+
+/*! \brief Factory for the callback that logs tuning statistics. */
+MeasureCallback MeasureCallback::EchoStatistics() {
+  return MeasureCallback(make_object<EchoStatisticsNode>());
+}
+
+// Expose the node type and its factory to the FFI.
+TVM_REGISTER_NODE_TYPE(EchoStatisticsNode);
+TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackEchoStatistics")
+    .set_body_typed(MeasureCallback::EchoStatistics);
+
+} // namespace meta_schedule
+} // namespace tvm
diff --git a/src/meta_schedule/measure_callback/measure_callback.cc b/src/meta_schedule/measure_callback/measure_callback.cc
new file mode 100644
index 0000000..733d118
--- /dev/null
+++ b/src/meta_schedule/measure_callback/measure_callback.cc
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+/*!
+ * \brief Create a measure callback whose behavior is provided by packed functions
+ * (this is the constructor exposed to Python as `meta_schedule.MeasureCallbackPyMeasureCallback`).
+ * \param f_apply The function invoked by Apply.
+ * \param f_as_string The function used to print the callback.
+ * \return The created measure callback.
+ */
+MeasureCallback MeasureCallback::PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply,  //
+                                                   PyMeasureCallbackNode::FAsString f_as_string) {
+  ObjectPtr<PyMeasureCallbackNode> node = make_object<PyMeasureCallbackNode>();
+  node->f_as_string = std::move(f_as_string);
+  node->f_apply = std::move(f_apply);
+  return MeasureCallback(std::move(node));
+}
+
+// Printing a PyMeasureCallback delegates to its user-supplied `f_as_string`, which must be set.
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<PyMeasureCallbackNode>([](const ObjectRef& ref, ReprPrinter* printer) {
+      const PyMeasureCallbackNode* node = ref.as<PyMeasureCallbackNode>();
+      ICHECK(node);
+      const PyMeasureCallbackNode::FAsString& as_string = node->f_as_string;
+      ICHECK(as_string != nullptr) << "PyMeasureCallback's AsString method not implemented!";
+      printer->stream << as_string();
+    });
+
+// Register the base callback node and the packed-function-backed node with the object system.
+TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode);
+TVM_REGISTER_NODE_TYPE(PyMeasureCallbackNode);
+
+// FFI entry points: invoke `Apply` on any MeasureCallback, and construct a PyMeasureCallback
+// from a pair of packed functions.
+TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackApply")
+ .set_body_method<MeasureCallback>(&MeasureCallbackNode::Apply);
+TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackPyMeasureCallback")
+ .set_body_typed(MeasureCallback::PyMeasureCallback);
+
+} // namespace meta_schedule
+} // namespace tvm
diff --git a/src/meta_schedule/measure_callback/remove_build_artifact.cc b/src/meta_schedule/measure_callback/remove_build_artifact.cc
new file mode 100644
index 0000000..649636d
--- /dev/null
+++ b/src/meta_schedule/measure_callback/remove_build_artifact.cc
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+/*!
+ * \brief Measure callback that removes the build artifact referenced by each builder result's
+ * `artifact_path`, using the global packed function "meta_schedule.remove_build_dir".
+ */
+class RemoveBuildArtifactNode : public MeasureCallbackNode {
+ public:
+  void Apply(const TaskScheduler& task_scheduler, int task_id,
+             const Array<MeasureCandidate>& measure_candidates,
+             const Array<BuilderResult>& builder_results,
+             const Array<RunnerResult>& runner_results) final {
+    // Looked up once. Registry::Get returns nullptr when nothing is registered under this name,
+    // so verify it before calling instead of dereferencing a null pointer.
+    static const PackedFunc* f_rm = runtime::Registry::Get("meta_schedule.remove_build_dir");
+    for (const BuilderResult& build_result : builder_results) {
+      if (Optional<String> path = build_result->artifact_path) {
+        ICHECK(f_rm != nullptr)
+            << "Global function \"meta_schedule.remove_build_dir\" is not registered; cannot "
+               "remove build artifact: "
+            << path.value();
+        (*f_rm)(path.value());
+      }
+    }
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.RemoveBuildArtifact";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RemoveBuildArtifactNode, MeasureCallbackNode);
+};
+
+/*!
+ * \brief Create a measure callback that removes the build artifacts of measured candidates.
+ * \return The created measure callback.
+ */
+MeasureCallback MeasureCallback::RemoveBuildArtifact() {
+  return MeasureCallback(make_object<RemoveBuildArtifactNode>());
+}
+
+TVM_REGISTER_NODE_TYPE(RemoveBuildArtifactNode);
+TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackRemoveBuildArtifact")
+    .set_body_typed(MeasureCallback::RemoveBuildArtifact);
+
+} // namespace meta_schedule
+} // namespace tvm
diff --git a/src/meta_schedule/measure_callback/update_cost_model.cc b/src/meta_schedule/measure_callback/update_cost_model.cc
new file mode 100644
index 0000000..58c86ab
--- /dev/null
+++ b/src/meta_schedule/measure_callback/update_cost_model.cc
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+/*!
+ * \brief Measure callback that updates the task scheduler's cost model with the measured
+ * candidates and their runner results.
+ */
+class UpdateCostModelNode : public MeasureCallbackNode {
+ public:
+  void Apply(const TaskScheduler& task_scheduler, int task_id,
+             const Array<MeasureCandidate>& measure_candidates,
+             const Array<BuilderResult>& builder_results,
+             const Array<RunnerResult>& runner_results) final {
+    TuneContext task = task_scheduler->tasks[task_id];
+    ICHECK(task_scheduler->cost_model.defined())  //
+        << "Cost model must be defined for the task scheduler!";
+    // Train on the candidates handed to this callback: they are the ones that pair element-wise
+    // with `runner_results`. Reading the task's stored candidates instead would silently ignore
+    // the argument and could mismatch the results.
+    ICHECK_EQ(measure_candidates.size(), runner_results.size());
+    CostModel cost_model = task_scheduler->cost_model.value();
+    cost_model->Update(task, measure_candidates, runner_results);
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.UpdateCostModel";
+  TVM_DECLARE_FINAL_OBJECT_INFO(UpdateCostModelNode, MeasureCallbackNode);
+};
+
+/*!
+ * \brief Create a measure callback that feeds measurement results into the cost model.
+ * \return The created measure callback.
+ */
+MeasureCallback MeasureCallback::UpdateCostModel() {
+  return MeasureCallback(make_object<UpdateCostModelNode>());
+}
+
+TVM_REGISTER_NODE_TYPE(UpdateCostModelNode);
+TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackUpdateCostModel")
+    .set_body_typed(MeasureCallback::UpdateCostModel);
+
+} // namespace meta_schedule
+} // namespace tvm
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index f4f9575..0a9ce4a 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -25,6 +25,7 @@
#include <tvm/meta_schedule/cost_model.h>
#include <tvm/meta_schedule/database.h>
#include <tvm/meta_schedule/feature_extractor.h>
+#include <tvm/meta_schedule/measure_callback.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/meta_schedule/schedule_rule.h>
#include <tvm/meta_schedule/search_strategy.h>
@@ -214,6 +215,24 @@ inline std::vector<support::LinearCongruentialEngine::TRandState> ForkSeed(
return results;
}
+/*!
+ * \brief Join strings, placing a delimiter between each pair of adjacent elements
+ * \param strs The strings to join
+ * \param delim The delimiter inserted between adjacent strings (not before the first or after the last)
+ * \return The joined string; the empty string when `strs` is empty
+ */
+inline std::string Concat(const Array<String>& strs, const std::string& delim) {
+  std::ostringstream os;
+  bool is_first = true;
+  for (const String& str : strs) {
+    if (!is_first) {
+      os << delim;
+    }
+    os << str;
+    is_first = false;
+  }
+  return os.str();
+}
+
} // namespace meta_schedule
} // namespace tvm
diff --git a/tests/python/unittest/test_meta_schedule_measure_callback.py b/tests/python/unittest/test_meta_schedule_measure_callback.py
new file mode 100644
index 0000000..b36d6ca
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_measure_callback.py
@@ -0,0 +1,132 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+import re
+from typing import List
+
+import pytest
+import tvm
+from tvm.ir.base import assert_structural_equal
+from tvm.meta_schedule.builder import BuilderResult
+from tvm.meta_schedule.measure_callback import PyMeasureCallback
+from tvm.meta_schedule.runner import RunnerResult
+from tvm.meta_schedule.search_strategy import MeasureCandidate
+from tvm.meta_schedule.task_scheduler.task_scheduler import TaskScheduler
+from tvm.meta_schedule.utils import _get_hex_address
+from tvm.script import tir as T
+from tvm.tir.schedule import Schedule
+
+# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,
+# fmt: off
+
+# Reference workload: a 1024x1024x1024 float32 matrix multiply (C = A @ B, with C zero-initialized),
+# used by the tests below to build MeasureCandidates and as the structural-equality reference.
+@tvm.script.ir_module
+class Matmul:
+ @T.prim_func
+ def main(a: T.handle, b: T.handle, c: T.handle) -> None:
+ T.func_attr({"global_symbol": "main"})
+ A = T.match_buffer(a, (1024, 1024), "float32")
+ B = T.match_buffer(b, (1024, 1024), "float32")
+ C = T.match_buffer(c, (1024, 1024), "float32")
+ for i, j, k in T.grid(1024, 1024, 1024):
+ with T.block("matmul"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ with T.init():
+ C[vi, vj] = 0.0
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+# fmt: on
+# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
+
+
+def test_meta_schedule_measure_callback():
+ class FancyMeasureCallback(PyMeasureCallback):
+ def apply(
+ self,
+ task_scheduler: TaskScheduler,
+ task_id: int,
+ measure_candidates: List[MeasureCandidate],
+ builds: List[BuilderResult],
+ results: List[RunnerResult],
+ ) -> None:
+ assert len(measure_candidates) == 1
+ assert_structural_equal(measure_candidates[0].sch.mod, Matmul)
+ assert (
+ len(builds) == 1
+ and builds[0].error_msg is None
+ and builds[0].artifact_path == "test_build"
+ )
+ assert (
+ len(results) == 1 and results[0].error_msg is None and len(results[0].run_secs) == 2
+ )
+
+ measure_callback = FancyMeasureCallback()
+ measure_callback.apply(
+ TaskScheduler(),
+ 0,
+ [MeasureCandidate(Schedule(Matmul), None)],
+ [BuilderResult("test_build", None)],
+ [RunnerResult([1.0, 2.1], None)],
+ )
+
+
+def test_meta_schedule_measure_callback_fail():
+ class FailingMeasureCallback(PyMeasureCallback):
+ def apply(
+ self,
+ task_scheduler: TaskScheduler,
+ task_id: int,
+ measure_candidates: List[MeasureCandidate],
+ builds: List[BuilderResult],
+ results: List[RunnerResult],
+ ) -> None:
+ raise ValueError("test")
+
+ measure_callback = FailingMeasureCallback()
+ with pytest.raises(ValueError, match="test"):
+ measure_callback.apply(
+ TaskScheduler(),
+ 0,
+ [MeasureCandidate(Schedule(Matmul), None)],
+ [BuilderResult("test_build", None)],
+ [RunnerResult([1.0, 2.1], None)],
+ )
+
+
+def test_meta_schedule_measure_callback_as_string():
+ class NotSoFancyMeasureCallback(PyMeasureCallback):
+ def apply(
+ self,
+ task_scheduler: "TaskScheduler",
+ task_id: int,
+ measure_candidates: List[MeasureCandidate],
+ builds: List[BuilderResult],
+ results: List[RunnerResult],
+ ) -> None:
+ pass
+
+ def __str__(self) -> str:
+ return f"NotSoFancyMeasureCallback({_get_hex_address(self.handle)})"
+
+ measure_callback = NotSoFancyMeasureCallback()
+ pattern = re.compile(r"NotSoFancyMeasureCallback\(0x[a-f|0-9]*\)")
+ assert pattern.match(str(measure_callback))
+
+
+if __name__ == "__main__":
+ test_meta_schedule_measure_callback()
+ test_meta_schedule_measure_callback_fail()
+ test_meta_schedule_measure_callback_as_string()