You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lm...@apache.org on 2020/12/23 14:37:20 UTC

[tvm] branch main updated: [AutoScheduler] Support string processing to records (#7144)

This is an automated email from the ASF dual-hosted git repository.

lmzheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e51bcdd  [AutoScheduler] Support string processing to records (#7144)
e51bcdd is described below

commit e51bcdd1ed99fc813454c68911f8032c852f48b4
Author: Cody Yu <co...@gmail.com>
AuthorDate: Wed Dec 23 06:37:10 2020 -0800

    [AutoScheduler] Support string processing to records (#7144)
    
    * [AutoScheduler] Support string processing to records
    
    * doc
    
    * remove log
---
 include/tvm/auto_scheduler/measure_record.h        |  6 +++-
 python/tvm/auto_scheduler/measure_record.py        | 37 ++++++++++++++++++++++
 src/auto_scheduler/measure_record.cc               | 23 +++++++++++---
 .../python/unittest/test_auto_scheduler_measure.py | 14 ++++++--
 4 files changed, 72 insertions(+), 8 deletions(-)

diff --git a/include/tvm/auto_scheduler/measure_record.h b/include/tvm/auto_scheduler/measure_record.h
index fa8fe2b..4d7952f 100755
--- a/include/tvm/auto_scheduler/measure_record.h
+++ b/include/tvm/auto_scheduler/measure_record.h
@@ -34,6 +34,8 @@
 namespace tvm {
 namespace auto_scheduler {
 
+const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.4";  // NOLINT(*)
+
 /*! \brief Callback for logging the input and results of measurements to file */
 class RecordToFileNode : public MeasureCallbackNode {
  public:
@@ -116,9 +118,11 @@ class RecordReader : public ObjectRef {
  * \param os A pointer to a output stream.
  * \param inputs The MeasureInputs to be written.
  * \param results The MeasureResults to be written.
+ * \param log_version The log version for the given record.
  */
 void WriteMeasureRecords(std::ostream* os, const Array<MeasureInput>& inputs,
-                         const Array<MeasureResult>& results);
+                         const Array<MeasureResult>& results,
+                         const std::string log_version = AUTO_SCHEDULER_LOG_VERSION);
 
 /*!
  * \brief Read one measure record from a string.
diff --git a/python/tvm/auto_scheduler/measure_record.py b/python/tvm/auto_scheduler/measure_record.py
index d6fea5c..35e5e9b 100644
--- a/python/tvm/auto_scheduler/measure_record.py
+++ b/python/tvm/auto_scheduler/measure_record.py
@@ -98,6 +98,43 @@ class RecordReader(Object):
             yield ret[0], ret[1]  # (input, result)
 
 
+def load_record_from_string(record):
+    """
+    Load the measure record from string.
+
+    Parameters
+    ----------
+    record: str
+        A record string, including the serialized MeasureInput and MeasureResult.
+
+    Returns
+    -------
+    ret: Tuple[MeasureInput, MeasureResult]
+        A tuple of MeasureInput, MeasureResult.
+    """
+    return _ffi_api.ReadMeasureRecord(record)
+
+
+def dump_record_to_string(inp, res):
+    """
+    Dump the measure record to a string.
+
+    Parameters
+    ----------
+    inp: MeasureInput
+        The measure input.
+
+    res: MeasureResult
+        The measure result.
+
+    Returns
+    -------
+    ret: str
+        The dumped string.
+    """
+    return _ffi_api.WriteMeasureRecords(inp, res)
+
+
 def load_records(filename):
     """
     Load measurement records from a file.
diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc
index aad0abe..faf3fca 100644
--- a/src/auto_scheduler/measure_record.cc
+++ b/src/auto_scheduler/measure_record.cc
@@ -279,8 +279,6 @@ namespace auto_scheduler {
 TVM_REGISTER_OBJECT_TYPE(RecordToFileNode);
 TVM_REGISTER_OBJECT_TYPE(RecordReaderNode);
 
-const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.4";  // NOLINT(*)
-
 RecordToFile::RecordToFile(String filename) {
   auto node = make_object<RecordToFileNode>();
   node->filename = std::move(filename);
@@ -288,13 +286,13 @@ RecordToFile::RecordToFile(String filename) {
 }
 
 void WriteMeasureRecords(std::ostream* os, const Array<MeasureInput>& inputs,
-                         const Array<MeasureResult>& results) {
+                         const Array<MeasureResult>& results, const std::string log_version) {
   dmlc::JSONWriter writer(os);
   for (size_t i = 0; i < inputs.size(); ++i) {
     writer.BeginObject(false);
     writer.WriteObjectKeyValue("i", *inputs[i].operator->());
     writer.WriteObjectKeyValue("r", *results[i].operator->());
-    writer.WriteObjectKeyValue("v", AUTO_SCHEDULER_LOG_VERSION);
+    writer.WriteObjectKeyValue("v", log_version);
     writer.EndObject();
     *os << "\n";
   }
@@ -398,6 +396,23 @@ TVM_REGISTER_GLOBAL("auto_scheduler.RecordReaderReadNext").set_body_typed([](Rec
   }
 });
 
+TVM_REGISTER_GLOBAL("auto_scheduler.ReadMeasureRecord").set_body_typed([](const std::string& str) {
+  auto inp = make_object<MeasureInputNode>();
+  auto res = make_object<MeasureResultNode>();
+  std::string log_version;
+  ReadMeasureRecord(str, inp.get(), res.get(), &log_version);
+  return Array<ObjectRef>{ObjectRef(inp), ObjectRef(res)};
+});
+
+TVM_REGISTER_GLOBAL("auto_scheduler.WriteMeasureRecords")
+    .set_body_typed([](MeasureInput inp, MeasureResult res) {
+      auto inps = Array<MeasureInput>({inp});
+      auto ress = Array<MeasureResult>({res});
+      std::ostringstream ss;
+      WriteMeasureRecords(&ss, inps, ress);
+      return String(ss.str());
+    });
+
 TVM_REGISTER_GLOBAL("auto_scheduler.SaveRecords")
     .set_body_typed([](String filename, Array<MeasureInput> in, Array<MeasureResult> res) {
       std::ofstream ofs(filename, std::ofstream::app);
diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py
index 10bb0b4..e9f1fa4 100644
--- a/tests/python/unittest/test_auto_scheduler_measure.py
+++ b/tests/python/unittest/test_auto_scheduler_measure.py
@@ -34,11 +34,19 @@ def record_common(dag, s):
     inp = auto_scheduler.measure.MeasureInput(task, s)
     res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1)
 
+    # Test in-memory record processing.
+    record_str = auto_scheduler.measure_record.dump_record_to_string(inp, res)
+    r_inp, r_res = auto_scheduler.measure_record.load_record_from_string(record_str)
+    # Only check the workload_key for simplicity.
+    assert inp.task.workload_key == r_inp.task.workload_key
+    assert str(res) == str(r_res)
+
+    # Test file-based record processing.
     with tempfile.NamedTemporaryFile() as fp:
         auto_scheduler.save_records(fp.name, [inp], [res])
 
         log_reader = auto_scheduler.RecordReader(fp.name)
-        inputs, results = log_reader.read_lines()
+        inputs, _ = log_reader.read_lines()
         assert len(inputs) == 1
 
         s1 = dag.infer_bound_from_state(s)
@@ -180,7 +188,7 @@ def test_recover_measure_input():
         auto_scheduler.save_records(fp.name, [inp], [res])
 
         log_reader = auto_scheduler.RecordReader(fp.name)
-        inputs, results = log_reader.read_lines()
+        inputs, _ = log_reader.read_lines()
         assert len(inputs) == 1
 
         raw_inp = inputs[0]
@@ -266,7 +274,7 @@ def test_measure_target_host():
         auto_scheduler.save_records(fp.name, [inp], [res])
 
         log_reader = auto_scheduler.RecordReader(fp.name)
-        inputs, results = log_reader.read_lines()
+        inputs, _ = log_reader.read_lines()
         assert len(inputs) == 1
 
         raw_inp = inputs[0]