You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lm...@apache.org on 2020/12/23 14:37:20 UTC
[tvm] branch main updated: [AutoScheduler] Support string processing to records (#7144)
This is an automated email from the ASF dual-hosted git repository.
lmzheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e51bcdd [AutoScheduler] Support string processing to records (#7144)
e51bcdd is described below
commit e51bcdd1ed99fc813454c68911f8032c852f48b4
Author: Cody Yu <co...@gmail.com>
AuthorDate: Wed Dec 23 06:37:10 2020 -0800
[AutoScheduler] Support string processing to records (#7144)
* [AutoScheduler] Support string processing to records
* doc
* remove log
---
include/tvm/auto_scheduler/measure_record.h | 6 +++-
python/tvm/auto_scheduler/measure_record.py | 37 ++++++++++++++++++++++
src/auto_scheduler/measure_record.cc | 23 +++++++++++---
.../python/unittest/test_auto_scheduler_measure.py | 14 ++++++--
4 files changed, 72 insertions(+), 8 deletions(-)
diff --git a/include/tvm/auto_scheduler/measure_record.h b/include/tvm/auto_scheduler/measure_record.h
index fa8fe2b..4d7952f 100755
--- a/include/tvm/auto_scheduler/measure_record.h
+++ b/include/tvm/auto_scheduler/measure_record.h
@@ -34,6 +34,8 @@
namespace tvm {
namespace auto_scheduler {
+const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.4"; // NOLINT(*)
+
/*! \brief Callback for logging the input and results of measurements to file */
class RecordToFileNode : public MeasureCallbackNode {
public:
@@ -116,9 +118,11 @@ class RecordReader : public ObjectRef {
* \param os A pointer to a output stream.
* \param inputs The MeasureInputs to be written.
* \param results The MeasureResults to be written.
+ * \param log_version Version string written to each record's "v" field; defaults to AUTO_SCHEDULER_LOG_VERSION.
*/
void WriteMeasureRecords(std::ostream* os, const Array<MeasureInput>& inputs,
- const Array<MeasureResult>& results);
+ const Array<MeasureResult>& results,
+ const std::string log_version = AUTO_SCHEDULER_LOG_VERSION);
/*!
* \brief Read one measure record from a string.
diff --git a/python/tvm/auto_scheduler/measure_record.py b/python/tvm/auto_scheduler/measure_record.py
index d6fea5c..35e5e9b 100644
--- a/python/tvm/auto_scheduler/measure_record.py
+++ b/python/tvm/auto_scheduler/measure_record.py
@@ -98,6 +98,43 @@ class RecordReader(Object):
yield ret[0], ret[1] # (input, result)
+def load_record_from_string(record):
+ """
+ Load a single measure record from a string.
+
+ Parameters
+ ----------
+ record: str
+ A record string, including the serialized MeasureInput and MeasureResult.
+
+ Returns
+ -------
+ ret: Tuple[MeasureInput, MeasureResult]
+ The parsed (MeasureInput, MeasureResult) pair, returned by the FFI as a two-element Array.
+ """
+ return _ffi_api.ReadMeasureRecord(record)
+
+
+def dump_record_to_string(inp, res):
+ """
+ Dump a single measure record (one input and its result) to a string.
+
+ Parameters
+ ----------
+ inp: MeasureInput
+ The measure input.
+
+ res: MeasureResult
+ The measure result.
+
+ Returns
+ -------
+ ret: str
+ The record serialized as one JSON object followed by a newline.
+ """
+ return _ffi_api.WriteMeasureRecords(inp, res)
+
+
def load_records(filename):
"""
Load measurement records from a file.
diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc
index aad0abe..faf3fca 100644
--- a/src/auto_scheduler/measure_record.cc
+++ b/src/auto_scheduler/measure_record.cc
@@ -279,8 +279,6 @@ namespace auto_scheduler {
TVM_REGISTER_OBJECT_TYPE(RecordToFileNode);
TVM_REGISTER_OBJECT_TYPE(RecordReaderNode);
-const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.4"; // NOLINT(*)
-
RecordToFile::RecordToFile(String filename) {
auto node = make_object<RecordToFileNode>();
node->filename = std::move(filename);
@@ -288,13 +286,13 @@ RecordToFile::RecordToFile(String filename) {
}
void WriteMeasureRecords(std::ostream* os, const Array<MeasureInput>& inputs,
- const Array<MeasureResult>& results) {
+ const Array<MeasureResult>& results, const std::string log_version) {
dmlc::JSONWriter writer(os);
for (size_t i = 0; i < inputs.size(); ++i) {
writer.BeginObject(false);
writer.WriteObjectKeyValue("i", *inputs[i].operator->());
writer.WriteObjectKeyValue("r", *results[i].operator->());
- writer.WriteObjectKeyValue("v", AUTO_SCHEDULER_LOG_VERSION);
+ writer.WriteObjectKeyValue("v", log_version);
writer.EndObject();
*os << "\n";
}
@@ -398,6 +396,23 @@ TVM_REGISTER_GLOBAL("auto_scheduler.RecordReaderReadNext").set_body_typed([](Rec
}
});
+TVM_REGISTER_GLOBAL("auto_scheduler.ReadMeasureRecord").set_body_typed([](const std::string& str) {
+ auto inp = make_object<MeasureInputNode>();
+ auto res = make_object<MeasureResultNode>();
+ std::string log_version;  // out-param for ReadMeasureRecord; its value is discarded here
+ ReadMeasureRecord(str, inp.get(), res.get(), &log_version);
+ return Array<ObjectRef>{ObjectRef(inp), ObjectRef(res)};
+});
+
+TVM_REGISTER_GLOBAL("auto_scheduler.WriteMeasureRecords")
+ .set_body_typed([](MeasureInput inp, MeasureResult res) {
+ auto inps = Array<MeasureInput>({inp});
+ auto ress = Array<MeasureResult>({res});
+ std::ostringstream ss;  // serialize in memory instead of to a file
+ WriteMeasureRecords(&ss, inps, ress);  // uses the default log_version (AUTO_SCHEDULER_LOG_VERSION)
+ return String(ss.str());
+ });
+
TVM_REGISTER_GLOBAL("auto_scheduler.SaveRecords")
.set_body_typed([](String filename, Array<MeasureInput> in, Array<MeasureResult> res) {
std::ofstream ofs(filename, std::ofstream::app);
diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py
index 10bb0b4..e9f1fa4 100644
--- a/tests/python/unittest/test_auto_scheduler_measure.py
+++ b/tests/python/unittest/test_auto_scheduler_measure.py
@@ -34,11 +34,19 @@ def record_common(dag, s):
inp = auto_scheduler.measure.MeasureInput(task, s)
res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1)
+ # Test in-memory record processing.
+ record_str = auto_scheduler.measure_record.dump_record_to_string(inp, res)
+ r_inp, r_res = auto_scheduler.measure_record.load_record_from_string(record_str)
+ # Only check the workload_key for simplification.
+ assert inp.task.workload_key == r_inp.task.workload_key
+ assert str(res) == str(r_res)
+
+ # Test file-based record processing.
with tempfile.NamedTemporaryFile() as fp:
auto_scheduler.save_records(fp.name, [inp], [res])
log_reader = auto_scheduler.RecordReader(fp.name)
- inputs, results = log_reader.read_lines()
+ inputs, _ = log_reader.read_lines()
assert len(inputs) == 1
s1 = dag.infer_bound_from_state(s)
@@ -180,7 +188,7 @@ def test_recover_measure_input():
auto_scheduler.save_records(fp.name, [inp], [res])
log_reader = auto_scheduler.RecordReader(fp.name)
- inputs, results = log_reader.read_lines()
+ inputs, _ = log_reader.read_lines()
assert len(inputs) == 1
raw_inp = inputs[0]
@@ -266,7 +274,7 @@ def test_measure_target_host():
auto_scheduler.save_records(fp.name, [inp], [res])
log_reader = auto_scheduler.RecordReader(fp.name)
- inputs, results = log_reader.read_lines()
+ inputs, _ = log_reader.read_lines()
assert len(inputs) == 1
raw_inp = inputs[0]