You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lm...@apache.org on 2020/12/24 14:36:43 UTC

[tvm] branch main updated: [AutoScheduler] Python based measure callbacks (#7143)

This is an automated email from the ASF dual-hosted git repository.

lmzheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 68e7838  [AutoScheduler] Python based measure callbacks (#7143)
68e7838 is described below

commit 68e7838b04ac3db6f1a3553c32872cf51d2955a1
Author: Cody Yu <co...@gmail.com>
AuthorDate: Thu Dec 24 06:36:20 2020 -0800

    [AutoScheduler] Python based measure callbacks (#7143)
    
    * add
    
    * make it work
    
    * format
    
    * add policy
    
    * comment
    
    * move test
    
    * format
    
    * fix ci
    
    * Delete useless old code
    
    Co-authored-by: Lianmin Zheng <li...@gmail.com>
---
 include/tvm/auto_scheduler/measure.h               | 29 ++++++++++++++++++++++
 python/tvm/auto_scheduler/measure.py               | 25 +++++++++++++++++++
 src/auto_scheduler/measure.cc                      | 27 ++++++++++++++++++++
 .../unittest/test_auto_scheduler_search_policy.py  | 12 ++++++++-
 4 files changed, 92 insertions(+), 1 deletion(-)

diff --git a/include/tvm/auto_scheduler/measure.h b/include/tvm/auto_scheduler/measure.h
index e8c01e8..841b6b9 100755
--- a/include/tvm/auto_scheduler/measure.h
+++ b/include/tvm/auto_scheduler/measure.h
@@ -232,6 +232,35 @@ class MeasureCallback : public ObjectRef {
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode);
 };
 
+/*!
+ * \brief A measure callback whose logic is implemented in Python.
+ *
+ * Callback() forwards its arguments to the packed function stored in callback_func.
+ */
+class PythonBasedMeasureCallbackNode : public MeasureCallbackNode {
+ public:
+  /*! \brief Handle to the callback function defined in Python. */
+  PackedFunc callback_func;
+
+  /*!
+   * \brief Forward a batch of measurements to the Python callback.
+   * \param policy The search policy passed to the callback.
+   * \param inputs The measurement inputs.
+   * \param results The measurement results.
+   */
+  void Callback(const SearchPolicy& policy, const Array<MeasureInput>& inputs,
+                const Array<MeasureResult>& results) final;
+
+  static constexpr const char* _type_key = "auto_scheduler.PythonBasedMeasureCallback";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PythonBasedMeasureCallbackNode, MeasureCallbackNode);
+};
+
+/*!
+ * \brief Managed reference to PythonBasedMeasureCallbackNode.
+ * \sa PythonBasedMeasureCallbackNode
+ */
+class PythonBasedMeasureCallback : public MeasureCallback {
+ public:
+  /*!
+   * \brief Wrap a Python-defined function as a measure callback.
+   * \param callback_func Packed function that PythonBasedMeasureCallbackNode::Callback
+   *        invokes with (policy, inputs, results).
+   */
+  explicit PythonBasedMeasureCallback(PackedFunc callback_func);
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PythonBasedMeasureCallback,
+                                        MeasureCallback,
+                                        PythonBasedMeasureCallbackNode);
+};
+
 // The base class of ProgramBuilders and ProgramRunners.
 
 /*! \brief ProgramBuilder that builds the programs */
diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index 7e4f149..38a420d 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -70,6 +70,31 @@ class MeasureCallback(Object):
     """ The base class of measurement callback functions. """
 
 
+@tvm._ffi.register_object("auto_scheduler.PythonBasedMeasureCallback")
+class PythonBasedMeasureCallback(MeasureCallback):
+    """Base class for measure callbacks implemented in python"""
+
+    def __init__(self):
+        def callback_func(policy, inputs, results):
+            self.callback(policy, inputs, results)
+
+        self.__init_handle_by_constructor__(_ffi_api.PythonBasedMeasureCallback, callback_func)
+
+    def callback(self, policy, inputs, results):
+        """The callback function.
+
+        Parameters
+        ----------
+        policy: auto_scheduler.search_policy.SearchPolicy
+            The search policy.
+        inputs : List[auto_scheduler.measure.MeasureInput]
+            The measurement inputs
+        results : List[auto_scheduler.measure.MeasureResult]
+            The measurement results
+        """
+        raise NotImplementedError
+
+
 @tvm._ffi.register_object("auto_scheduler.MeasureInput")
 class MeasureInput(Object):
     """Store the input of a measurement.
diff --git a/src/auto_scheduler/measure.cc b/src/auto_scheduler/measure.cc
index 5b7e886..c3212f2 100755
--- a/src/auto_scheduler/measure.cc
+++ b/src/auto_scheduler/measure.cc
@@ -27,6 +27,8 @@
 
 #include <algorithm>
 
+#include "search_policy/empty_policy.h"
+#include "search_policy/sketch_policy.h"
 #include "utils.h"
 
 namespace tvm {
@@ -36,6 +38,7 @@ TVM_REGISTER_NODE_TYPE(MeasureInputNode);
 TVM_REGISTER_NODE_TYPE(BuildResultNode);
 TVM_REGISTER_NODE_TYPE(MeasureResultNode);
 TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode);
+TVM_REGISTER_OBJECT_TYPE(PythonBasedMeasureCallbackNode);
 TVM_REGISTER_OBJECT_TYPE(ProgramRunnerNode);
 TVM_REGISTER_OBJECT_TYPE(ProgramBuilderNode);
 TVM_REGISTER_OBJECT_TYPE(ProgramMeasurerNode);
@@ -183,6 +186,25 @@ Array<MeasureResult> RPCRunnerNode::Run(const Array<MeasureInput>& inputs,
   return Array<MeasureResult>();
 }
 
+/********** MeasureCallback **********/
+PythonBasedMeasureCallback::PythonBasedMeasureCallback(PackedFunc callback_func) {
+  // Allocate the node, hand it the Python function, then let this reference own it.
+  auto n = make_object<PythonBasedMeasureCallbackNode>();
+  n->callback_func = std::move(callback_func);
+  data_ = std::move(n);
+}
+
+/*!
+ * \brief Forward a batch of measurements to the Python callback.
+ *
+ * The policy is passed on as its concrete reference type (SketchPolicy or
+ * EmptyPolicy). Any other policy type, or an undefined policy, aborts via LOG(FATAL).
+ */
+void PythonBasedMeasureCallbackNode::Callback(const SearchPolicy& policy,
+                                              const Array<MeasureInput>& inputs,
+                                              const Array<MeasureResult>& results) {
+  // ObjectRef::as<T>() checks the runtime type and yields nullptr on a mismatch.
+  // An unchecked static_cast would accept every policy as a SketchPolicyNode
+  // (undefined behavior for EmptyPolicy) and make the later branches unreachable.
+  if (const auto* sketch_policy = policy.as<SketchPolicyNode>()) {
+    callback_func(GetRef<SketchPolicy>(sketch_policy), inputs, results);
+  } else if (const auto* empty_policy = policy.as<EmptyPolicyNode>()) {
+    callback_func(GetRef<EmptyPolicy>(empty_policy), inputs, results);
+  } else {
+    LOG(FATAL) << "Unrecognized search policy type. Expect SketchPolicy or EmptyPolicy";
+  }
+}
+
 /********** ProgramMeasurer **********/
 ProgramMeasurer::ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner,
                                  Optional<Array<MeasureCallback>> callbacks, int verbose,
@@ -360,6 +382,11 @@ TVM_REGISTER_GLOBAL("auto_scheduler.MeasureResult")
       return MeasureResult(costs, error_no, error_msg, all_cost, timestamp);
     });
 
+// FFI constructor used by the Python PythonBasedMeasureCallback class.
+TVM_REGISTER_GLOBAL("auto_scheduler.PythonBasedMeasureCallback")
+    .set_body_typed([](PackedFunc func) -> PythonBasedMeasureCallback {
+      return PythonBasedMeasureCallback(func);
+    });
+
 TVM_REGISTER_GLOBAL("auto_scheduler.ProgramMeasurer")
     .set_body_typed([](ProgramBuilder builder, ProgramRunner runner,
                        Array<MeasureCallback> callbacks, int verbose, int max_continuous_error) {
diff --git a/tests/python/unittest/test_auto_scheduler_search_policy.py b/tests/python/unittest/test_auto_scheduler_search_policy.py
index 1bb7449..6d4fb68 100644
--- a/tests/python/unittest/test_auto_scheduler_search_policy.py
+++ b/tests/python/unittest/test_auto_scheduler_search_policy.py
@@ -30,6 +30,16 @@ from test_auto_scheduler_common import matmul_auto_scheduler_test, PropagatingTh
 import multiprocessing
 
 
+class CustomMeasureCallback(auto_scheduler.measure.PythonBasedMeasureCallback):
+    """A simple Python-based callback for testing."""
+
+    def callback(self, policy, inputs, results):
+        assert isinstance(policy, auto_scheduler.search_policy.SearchPolicy)
+        for inp, res in zip(inputs, results):
+            assert isinstance(inp, auto_scheduler.MeasureInput)
+            assert isinstance(res, auto_scheduler.MeasureResult)
+
+
 def search_common(
     workload=matmul_auto_scheduler_test,
     target="llvm",
@@ -68,7 +78,7 @@ def search_common(
             early_stopping=1,
             runner=runner,
             verbose=2,
-            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
+            measure_callbacks=[auto_scheduler.RecordToFile(log_file), CustomMeasureCallback()],
         )
         task.tune(tuning_options=tuning_options, search_policy=search_policy)
         sch, args = task.apply_best(log_file)