You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lm...@apache.org on 2020/12/24 14:36:43 UTC
[tvm] branch main updated: [AutoScheduler] Python based measure
callbacks (#7143)
This is an automated email from the ASF dual-hosted git repository.
lmzheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 68e7838 [AutoScheduler] Python based measure callbacks (#7143)
68e7838 is described below
commit 68e7838b04ac3db6f1a3553c32872cf51d2955a1
Author: Cody Yu <co...@gmail.com>
AuthorDate: Thu Dec 24 06:36:20 2020 -0800
[AutoScheduler] Python based measure callbacks (#7143)
* add
* make it work
* format
* add policy
* comment
* move test
* format
* fix ci
* Delete useless old code
Co-authored-by: Lianmin Zheng <li...@gmail.com>
---
include/tvm/auto_scheduler/measure.h | 29 ++++++++++++++++++++++
python/tvm/auto_scheduler/measure.py | 25 +++++++++++++++++++
src/auto_scheduler/measure.cc | 27 ++++++++++++++++++++
.../unittest/test_auto_scheduler_search_policy.py | 12 ++++++++-
4 files changed, 92 insertions(+), 1 deletion(-)
diff --git a/include/tvm/auto_scheduler/measure.h b/include/tvm/auto_scheduler/measure.h
index e8c01e8..841b6b9 100755
--- a/include/tvm/auto_scheduler/measure.h
+++ b/include/tvm/auto_scheduler/measure.h
@@ -232,6 +232,35 @@ class MeasureCallback : public ObjectRef {
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode);
};
+/*!
+ * \brief A MeasureCallback whose behavior is supplied by a function defined in Python.
+ * Callback() forwards its arguments to callback_func.
+ */
+class PythonBasedMeasureCallbackNode : public MeasureCallbackNode {
+ public:
+  /*! \brief The packed function (defined in Python) that Callback() forwards to. */
+  PackedFunc callback_func;
+
+  /*!
+   * \brief Forward the search policy, measure inputs and measure results to callback_func.
+   * \param policy The current search policy.
+   * \param inputs The measure inputs.
+   * \param results The measure results.
+   */
+  void Callback(const SearchPolicy& policy, const Array<MeasureInput>& inputs,
+                const Array<MeasureResult>& results) final;
+  static constexpr const char* _type_key = "auto_scheduler.PythonBasedMeasureCallback";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PythonBasedMeasureCallbackNode, MeasureCallbackNode);
+};
+
+/*!
+ * \brief Managed reference to PythonBasedMeasureCallbackNode.
+ * \sa PythonBasedMeasureCallbackNode
+ */
+class PythonBasedMeasureCallback : public MeasureCallback {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param callback_func The packed function, defined in Python, that the node's
+   *        Callback() forwards to. It is stored by value in the node.
+   */
+  explicit PythonBasedMeasureCallback(PackedFunc callback_func);
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PythonBasedMeasureCallback, MeasureCallback,
+                                        PythonBasedMeasureCallbackNode);
+};
+
// The base class of ProgramBuilders and ProgramRunners.
/*! \brief ProgramBuilder that builds the programs */
diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index 7e4f149..38a420d 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -70,6 +70,31 @@ class MeasureCallback(Object):
""" The base class of measurement callback functions. """
+@tvm._ffi.register_object("auto_scheduler.PythonBasedMeasureCallback")
+class PythonBasedMeasureCallback(MeasureCallback):
+ """Base class for measure callbacks implemented in python"""
+
+ def __init__(self):
+ def callback_func(policy, inputs, results):
+ self.callback(policy, inputs, results)
+
+ self.__init_handle_by_constructor__(_ffi_api.PythonBasedMeasureCallback, callback_func)
+
+ def callback(self, policy, inputs, results):
+ """The callback function.
+
+ Parameters
+ ----------
+ policy: auto_scheduler.search_policy.SearchPolicy
+ The search policy.
+ inputs : List[auto_scheduler.measure.MeasureInput]
+ The measurement inputs
+ results : List[auto_scheduler.measure.MeasureResult]
+ The measurement results
+ """
+ raise NotImplementedError
+
+
@tvm._ffi.register_object("auto_scheduler.MeasureInput")
class MeasureInput(Object):
"""Store the input of a measurement.
diff --git a/src/auto_scheduler/measure.cc b/src/auto_scheduler/measure.cc
index 5b7e886..c3212f2 100755
--- a/src/auto_scheduler/measure.cc
+++ b/src/auto_scheduler/measure.cc
@@ -27,6 +27,8 @@
#include <algorithm>
+#include "search_policy/empty_policy.h"
+#include "search_policy/sketch_policy.h"
#include "utils.h"
namespace tvm {
@@ -36,6 +38,7 @@ TVM_REGISTER_NODE_TYPE(MeasureInputNode);
TVM_REGISTER_NODE_TYPE(BuildResultNode);
TVM_REGISTER_NODE_TYPE(MeasureResultNode);
TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode);
+// Registered with the same object-type macro used for MeasureCallbackNode above.
+TVM_REGISTER_OBJECT_TYPE(PythonBasedMeasureCallbackNode);
TVM_REGISTER_OBJECT_TYPE(ProgramRunnerNode);
TVM_REGISTER_OBJECT_TYPE(ProgramBuilderNode);
TVM_REGISTER_OBJECT_TYPE(ProgramMeasurerNode);
@@ -183,6 +186,25 @@ Array<MeasureResult> RPCRunnerNode::Run(const Array<MeasureInput>& inputs,
return Array<MeasureResult>();
}
+/********** MeasureCallback **********/
+PythonBasedMeasureCallback::PythonBasedMeasureCallback(PackedFunc callback_func) {
+  // Build the node, move the packed function into it, then let this reference own the node.
+  ObjectPtr<PythonBasedMeasureCallbackNode> n = make_object<PythonBasedMeasureCallbackNode>();
+  n->callback_func = std::move(callback_func);
+  data_ = std::move(n);
+}
+
+/*!
+ * \brief Forward the measurement data to the Python callback, passing the policy
+ *        as its concrete reference type (SketchPolicy or EmptyPolicy).
+ * \param policy The current search policy; must be a SketchPolicy or an EmptyPolicy.
+ * \param inputs The measure inputs.
+ * \param results The measure results.
+ */
+void PythonBasedMeasureCallbackNode::Callback(const SearchPolicy& policy,
+                                              const Array<MeasureInput>& inputs,
+                                              const Array<MeasureResult>& results) {
+  // Use the checked downcast ObjectRef::as<T>(), which returns nullptr on a type mismatch
+  // (or a null policy). An unchecked static_cast would always take the first branch,
+  // leaving the EmptyPolicy branch unreachable and reinterpreting other policies as
+  // SketchPolicyNode (undefined behavior).
+  if (const auto* sketch_policy = policy.as<SketchPolicyNode>()) {
+    callback_func(GetRef<SketchPolicy>(sketch_policy), inputs, results);
+  } else if (const auto* empty_policy = policy.as<EmptyPolicyNode>()) {
+    callback_func(GetRef<EmptyPolicy>(empty_policy), inputs, results);
+  } else {
+    LOG(FATAL) << "Unrecognized search policy type. Expect SketchPolicy or EmptyPolicy";
+  }
+}
+
/********** ProgramMeasurer **********/
ProgramMeasurer::ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner,
Optional<Array<MeasureCallback>> callbacks, int verbose,
@@ -360,6 +382,11 @@ TVM_REGISTER_GLOBAL("auto_scheduler.MeasureResult")
return MeasureResult(costs, error_no, error_msg, all_cost, timestamp);
});
+// FFI constructor; called from Python as _ffi_api.PythonBasedMeasureCallback.
+TVM_REGISTER_GLOBAL("auto_scheduler.PythonBasedMeasureCallback")
+    .set_body_typed([](PackedFunc func) { return PythonBasedMeasureCallback(func); });
+
TVM_REGISTER_GLOBAL("auto_scheduler.ProgramMeasurer")
.set_body_typed([](ProgramBuilder builder, ProgramRunner runner,
Array<MeasureCallback> callbacks, int verbose, int max_continuous_error) {
diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py
index 1bb7449..6d4fb68 100644
--- a/tests/python/unittest/test_auto_scheduler_search_policy.py
+++ b/tests/python/unittest/test_auto_scheduler_search_policy.py
@@ -30,6 +30,16 @@ from test_auto_scheduler_common import matmul_auto_scheduler_test, PropagatingTh
import multiprocessing
+class CustomMeasureCallback(auto_scheduler.measure.PythonBasedMeasureCallback):
+ """A simple Python-based callback for testing."""
+
+ def callback(self, policy, inputs, results):
+ assert isinstance(policy, auto_scheduler.search_policy.SearchPolicy)
+ for inp, res in zip(inputs, results):
+ assert isinstance(inp, auto_scheduler.MeasureInput)
+ assert isinstance(res, auto_scheduler.MeasureResult)
+
+
def search_common(
workload=matmul_auto_scheduler_test,
target="llvm",
@@ -68,7 +78,7 @@ def search_common(
early_stopping=1,
runner=runner,
verbose=2,
- measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
+ measure_callbacks=[auto_scheduler.RecordToFile(log_file), CustomMeasureCallback()],
)
task.tune(tuning_options=tuning_options, search_policy=search_policy)
sch, args = task.apply_best(log_file)