You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/06/11 07:18:14 UTC
[tvm] branch main updated: [MetaSchedule] JSONDatabase Utilities (#11680)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0df69611b2 [MetaSchedule] JSONDatabase Utilities (#11680)
0df69611b2 is described below
commit 0df69611b2fb46724a0023dd8d389c9a1ecedcb8
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Sat Jun 11 00:18:10 2022 -0700
[MetaSchedule] JSONDatabase Utilities (#11680)
This PR adds utilities to JSONDatabase that speed up loading and saving.
---
python/tvm/meta_schedule/utils.py | 28 +-
src/meta_schedule/arg_info.cc | 2 +-
src/meta_schedule/database/database.cc | 2 +-
src/meta_schedule/database/database_utils.cc | 377 +++++++++++++++++++++
src/meta_schedule/database/json_database.cc | 80 ++++-
src/meta_schedule/utils.h | 103 +++---
.../python/unittest/test_meta_schedule_database.py | 68 ++--
7 files changed, 526 insertions(+), 134 deletions(-)
diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py
index 919a29e6cf..26bf206709 100644
--- a/python/tvm/meta_schedule/utils.py
+++ b/python/tvm/meta_schedule/utils.py
@@ -16,12 +16,11 @@
# under the License.
"""Utilities for meta schedule"""
import ctypes
-import json
import logging
import os
import shutil
from contextlib import contextmanager
-from typing import Any, List, Dict, Callable, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
import psutil # type: ignore
from tvm._ffi import get_global_func, register_func
@@ -296,31 +295,6 @@ def _json_de_tvm(obj: Any) -> Any:
raise TypeError("Not supported type: " + str(type(obj)))
-@register_func("meta_schedule.json_obj2str")
-def json_obj2str(json_obj: Any) -> str:
- json_obj = _json_de_tvm(json_obj)
- return json.dumps(json_obj)
-
-
-@register_func("meta_schedule.batch_json_str2obj")
-def batch_json_str2obj(json_strs: List[str]) -> List[Any]:
- """Covert a list of JSON strings to a list of json objects.
- Parameters
- ----------
- json_strs : List[str]
- The list of JSON strings
- Returns
- -------
- result : List[Any]
- The list of json objects
- """
- return [
- json.loads(json_str)
- for json_str in map(str.strip, json_strs)
- if json_str and (not json_str.startswith("#")) and (not json_str.startswith("//"))
- ]
-
-
def shash2hex(mod: IRModule) -> str:
"""Get the structural hash of a module.
diff --git a/src/meta_schedule/arg_info.cc b/src/meta_schedule/arg_info.cc
index 104662b6aa..9b225e8bea 100644
--- a/src/meta_schedule/arg_info.cc
+++ b/src/meta_schedule/arg_info.cc
@@ -88,7 +88,7 @@ TensorInfo TensorInfo::FromJSON(const ObjectRef& json_obj) {
dtype = runtime::String2DLDataType(dtype_str);
}
// Load json[2] => shape
- shape = Downcast<Array<Integer>>(json_array->at(2));
+ shape = AsIntArray(json_array->at(2));
} catch (const std::runtime_error& e) { // includes tvm::Error and dmlc::Error
LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj
<< "\nThe error is: " << e.what();
diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc
index 86d999e4fd..9905ff73c7 100644
--- a/src/meta_schedule/database/database.cc
+++ b/src/meta_schedule/database/database.cc
@@ -115,7 +115,7 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w
CHECK(json_array && json_array->size() == 4);
// Load json[1] => run_secs
if (json_array->at(1).defined()) {
- run_secs = Downcast<Array<FloatImm>>(json_array->at(1));
+ run_secs = AsFloatArray(json_array->at(1));
}
// Load json[2] => target
if (json_array->at(2).defined()) {
diff --git a/src/meta_schedule/database/database_utils.cc b/src/meta_schedule/database/database_utils.cc
new file mode 100644
index 0000000000..278c5267ea
--- /dev/null
+++ b/src/meta_schedule/database/database_utils.cc
@@ -0,0 +1,377 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <algorithm>
+#include <cctype>
+#include <cerrno>
+#include <cmath>
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+#include <iomanip>
+#include <limits>
+#include <sstream>
+#include <vector>
+
+#include "../../support/str_escape.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+/*!
+ * \brief Serialize a JSON-like object into `os`.
+ *
+ * Accepted inputs: an undefined ObjectRef (written as null), IntImm (bool dtype as true/false,
+ * otherwise an integer literal), FloatImm, String, Array, and Map whose keys are Strings. Map
+ * entries are written sorted by key, so the output is deterministic.
+ *
+ * \param json_obj The object to serialize.
+ * \param os The stream the JSON text is appended to.
+ */
+void JSONDumps(ObjectRef json_obj, std::ostringstream& os) {
+  if (!json_obj.defined()) {
+    os << "null";
+  } else if (const auto* int_imm = json_obj.as<IntImmNode>()) {
+    if (int_imm->dtype == DataType::Bool()) {
+      if (int_imm->value) {
+        os << "true";
+      } else {
+        os << "false";
+      }
+    } else {
+      os << int_imm->value;
+    }
+  } else if (const auto* float_imm = json_obj.as<FloatImmNode>()) {
+    double value = float_imm->value;
+    if (std::isnan(value)) {
+      // JSON has no NaN/Infinity. Use the spellings Python's `json.dumps` emits by default,
+      // i.e. what the former Python-side serializer wrote into existing database files.
+      os << "NaN";
+    } else if (std::isinf(value)) {
+      os << (value > 0 ? "Infinity" : "-Infinity");
+    } else {
+      // Format into a scratch stream so the caller's stream keeps its own precision setting.
+      std::ostringstream buffer;
+      buffer << std::setprecision(20) << value;
+      std::string text = buffer.str();
+      // Integral values (e.g. 2.0 or 1e10) print with neither '.' nor an exponent and would be
+      // read back as integers; append ".0" so they reload as floats.
+      if (text.find_first_of(".eE") == std::string::npos) {
+        text += ".0";
+      }
+      os << text;
+    }
+  } else if (const auto* str = json_obj.as<runtime::StringObj>()) {
+    os << '"' << support::StrEscape(str->data, str->size) << '"';
+  } else if (const auto* array = json_obj.as<runtime::ArrayNode>()) {
+    os << "[";
+    int n = array->size();
+    for (int i = 0; i < n; ++i) {
+      if (i != 0) {
+        os << ",";
+      }
+      JSONDumps(array->at(i), os);
+    }
+    os << "]";
+  } else if (const auto* dict = json_obj.as<runtime::MapNode>()) {
+    int n = dict->size();
+    std::vector<std::pair<String, ObjectRef>> key_values;
+    key_values.reserve(n);
+    for (const auto& kv : *dict) {
+      if (const auto* k = kv.first.as<StringObj>()) {
+        key_values.emplace_back(GetRef<String>(k), kv.second);
+      } else {
+        LOG(FATAL) << "TypeError: Only string keys are supported in JSON dumps, but got: "
+                   << kv.first->GetTypeKey();
+      }
+    }
+    // Sort by key so the serialized form does not depend on the map's iteration order.
+    std::sort(key_values.begin(), key_values.end());
+    os << "{";
+    for (int i = 0; i < n; ++i) {
+      const auto& kv = key_values[i];
+      if (i != 0) {
+        os << ",";
+      }
+      os << '"' << support::StrEscape(kv.first->data, kv.first->size) << '"';
+      os << ":";
+      JSONDumps(kv.second, os);
+    }
+    os << "}";
+  } else {
+    LOG(FATAL) << "TypeError: Unsupported type in JSON object: " << json_obj->GetTypeKey();
+  }
+}
+
+/*!
+ * \brief Serialize a JSON-like object into a newly built string.
+ * \param json_obj The object to serialize.
+ * \return The JSON text produced by the stream overload of JSONDumps.
+ */
+std::string JSONDumps(ObjectRef json_obj) {
+  std::ostringstream out;
+  JSONDumps(json_obj, out);
+  std::string result = out.str();
+  return result;
+}
+
+/*!
+ * \brief Tokenizer over the character range [st, ed) of a single JSON document.
+ *
+ * Besides standard JSON it accepts `NaN`, `Infinity` and `-Infinity` (written by Python's
+ * `json.dumps` for non-finite floats) and the non-standard string escape `\xHH`.
+ */
+class JSONTokenizer {
+ public:
+  enum class TokenType : int32_t {
+    kEOF = 0,          // end of file
+    kNull = 1,         // null
+    kTrue = 2,         // true
+    kFalse = 3,        // false
+    kLeftSquare = 4,   // [
+    kRightSquare = 5,  // ]
+    kLeftCurly = 6,    // {
+    kRightCurly = 7,   // }
+    kComma = 8,        // ,
+    kColon = 9,        // :
+    kInteger = 10,     // integers
+    kFloat = 11,       // floating point numbers
+    kString = 12,      // string
+  };
+
+  /*! \brief A token; `value` is only set for kInteger, kFloat and kString. */
+  struct Token {
+    TokenType type;
+    ObjectRef value{nullptr};
+  };
+
+  explicit JSONTokenizer(const char* st, const char* ed) : cur_(st), end_(ed) {}
+
+  /*! \brief Skip whitespace and return the next token, or kEOF once the input is exhausted. */
+  Token Next() {
+    // Cast to unsigned char: passing a negative char (e.g. a UTF-8 byte) to isspace is UB.
+    for (; cur_ != end_ && std::isspace(static_cast<unsigned char>(*cur_)); ++cur_) {
+    }
+    if (cur_ == end_) return Token{TokenType::kEOF};
+    if (NextLeftSquare()) return Token{TokenType::kLeftSquare};
+    if (NextRightSquare()) return Token{TokenType::kRightSquare};
+    if (NextLeftCurly()) return Token{TokenType::kLeftCurly};
+    if (NextRightCurly()) return Token{TokenType::kRightCurly};
+    if (NextComma()) return Token{TokenType::kComma};
+    if (NextColon()) return Token{TokenType::kColon};
+    if (NextNull()) return Token{TokenType::kNull};
+    if (NextTrue()) return Token{TokenType::kTrue};
+    if (NextFalse()) return Token{TokenType::kFalse};
+    // Non-finite floats. These must be tried before NextNumber, which would otherwise consume
+    // the leading '-' of "-Infinity" and then fail.
+    if (NextLiteral("NaN", 3)) return MakeFloat(std::numeric_limits<double>::quiet_NaN());
+    if (NextLiteral("Infinity", 8)) return MakeFloat(std::numeric_limits<double>::infinity());
+    if (NextLiteral("-Infinity", 9)) return MakeFloat(-std::numeric_limits<double>::infinity());
+    Token token;
+    if (NextString(&token)) return token;
+    if (NextNumber(&token)) return token;
+    LOG(FATAL) << "ValueError: Cannot tokenize: " << std::string(cur_, end_);
+    throw;
+  }
+
+ private:
+  bool NextLeftSquare() { return NextLiteral('['); }
+  bool NextRightSquare() { return NextLiteral(']'); }
+  bool NextLeftCurly() { return NextLiteral('{'); }
+  bool NextRightCurly() { return NextLiteral('}'); }
+  bool NextComma() { return NextLiteral(','); }
+  bool NextColon() { return NextLiteral(':'); }
+  bool NextNull() { return NextLiteral("null", 4); }
+  bool NextTrue() { return NextLiteral("true", 4); }
+  bool NextFalse() { return NextLiteral("false", 5); }
+
+  /*! \brief Wrap a double into a kFloat token holding a float64 FloatImm. */
+  static Token MakeFloat(double value) {
+    return Token{TokenType::kFloat, FloatImm(runtime::DataType::Float(64), value)};
+  }
+
+  /*!
+   * \brief Scan a number literal.
+   *
+   * Integers that fit in int32 become `Integer` (as before). Larger ones become a 64-bit IntImm,
+   * and ones beyond int64 fall back to float64. Lexemes containing '.', 'e' or 'E' become a
+   * float64 FloatImm.
+   */
+  bool NextNumber(Token* token) {
+    using runtime::DataType;
+    bool is_float = false;
+    const char* st = cur_;
+    for (; cur_ != end_; ++cur_) {
+      char c = *cur_;
+      if (std::isdigit(static_cast<unsigned char>(c)) || c == '+' || c == '-') {
+        continue;
+      } else if (c == '.' || c == 'e' || c == 'E') {
+        is_float = true;
+      } else {
+        break;
+      }
+    }
+    if (st == cur_) {
+      return false;
+    }
+    // The scan above is permissive (e.g. "1-2"), so validate by requiring strtoll/strtod to
+    // consume the whole lexeme. Unlike std::stoi/std::stod they never throw std::logic_error
+    // subclasses, which callers that catch std::runtime_error would not handle.
+    std::string text(st, cur_);
+    const char* text_begin = text.c_str();
+    const char* text_end = text_begin + text.size();
+    char* parsed_end = nullptr;
+    if (!is_float) {
+      errno = 0;
+      long long value = std::strtoll(text_begin, &parsed_end, 10);
+      if (parsed_end != text_end) {
+        LOG(FATAL) << "ValueError: Invalid number: " << text;
+      }
+      if (errno != ERANGE) {
+        if (value >= std::numeric_limits<int32_t>::min() &&
+            value <= std::numeric_limits<int32_t>::max()) {
+          *token = Token{TokenType::kInteger, Integer(static_cast<int>(value))};
+        } else {
+          *token = Token{TokenType::kInteger, IntImm(DataType::Int(64), value)};
+        }
+        return true;
+      }
+      // ERANGE: the integer does not fit in int64; read it as a float64 below.
+    }
+    double value = std::strtod(text_begin, &parsed_end);
+    if (parsed_end != text_end) {
+      LOG(FATAL) << "ValueError: Invalid number: " << text;
+    }
+    *token = Token{TokenType::kFloat, FloatImm(DataType::Float(64), value)};
+    return true;
+  }
+
+  /*!
+   * \brief Read the `n` hex digits following the current character, leaving `cur_` on the last
+   * digit consumed.
+   */
+  uint32_t ReadHexDigits(int n) {
+    uint32_t result = 0;
+    for (int i = 0; i < n; ++i) {
+      ++cur_;
+      if (cur_ == end_) {
+        LOG(FATAL) << "ValueError: Unexpected end of string in escape sequence";
+        throw;
+      }
+      char c = *cur_;
+      uint32_t digit = 0;
+      if (c >= '0' && c <= '9') {
+        digit = c - '0';
+      } else if (c >= 'a' && c <= 'f') {
+        digit = c - 'a' + 10;
+      } else if (c >= 'A' && c <= 'F') {
+        digit = c - 'A' + 10;
+      } else {
+        LOG(FATAL) << "ValueError: Invalid hex digit in escape sequence: " << c;
+      }
+      result = (result << 4) | digit;
+    }
+    return result;
+  }
+
+  /*! \brief Append the UTF-8 encoding of `code_point` (at most U+10FFFF) to `out`. */
+  static void AppendUTF8(uint32_t code_point, std::string* out) {
+    if (code_point < 0x80) {
+      out->push_back(static_cast<char>(code_point));
+    } else if (code_point < 0x800) {
+      out->push_back(static_cast<char>(0xC0 | (code_point >> 6)));
+      out->push_back(static_cast<char>(0x80 | (code_point & 0x3F)));
+    } else if (code_point < 0x10000) {
+      out->push_back(static_cast<char>(0xE0 | (code_point >> 12)));
+      out->push_back(static_cast<char>(0x80 | ((code_point >> 6) & 0x3F)));
+      out->push_back(static_cast<char>(0x80 | (code_point & 0x3F)));
+    } else {
+      out->push_back(static_cast<char>(0xF0 | (code_point >> 18)));
+      out->push_back(static_cast<char>(0x80 | ((code_point >> 12) & 0x3F)));
+      out->push_back(static_cast<char>(0x80 | ((code_point >> 6) & 0x3F)));
+      out->push_back(static_cast<char>(0x80 | (code_point & 0x3F)));
+    }
+  }
+
+  /*! \brief Scan a double-quoted string literal, decoding its escape sequences. */
+  bool NextString(Token* token) {
+    if (cur_ == end_ || *cur_ != '"') return false;
+    ++cur_;
+    std::string str;
+    for (; cur_ != end_ && *cur_ != '\"'; ++cur_) {
+      if (*cur_ != '\\') {
+        str.push_back(*cur_);
+        continue;
+      }
+      ++cur_;
+      if (cur_ == end_) {
+        LOG(FATAL) << "ValueError: Unexpected end of string: \\";
+        throw;
+      }
+      switch (*cur_) {
+        case '\"':
+          str.push_back('\"');
+          break;
+        case '\\':
+          str.push_back('\\');
+          break;
+        case '/':
+          str.push_back('/');
+          break;
+        case 'b':
+          str.push_back('\b');
+          break;
+        case 'f':
+          str.push_back('\f');
+          break;
+        case 'n':
+          str.push_back('\n');
+          break;
+        case 'r':
+          str.push_back('\r');
+          break;
+        case 't':
+          str.push_back('\t');
+          break;
+        case 'u': {
+          // \uXXXX: Python's `json.dumps` (the former writer) uses it for non-ASCII and control
+          // characters. Code points above U+FFFF arrive as a UTF-16 surrogate pair.
+          uint32_t code_point = ReadHexDigits(4);
+          if (code_point >= 0xD800 && code_point <= 0xDBFF) {
+            if (end_ - cur_ < 3 || cur_[1] != '\\' || cur_[2] != 'u') {
+              LOG(FATAL) << "ValueError: Unpaired surrogate in \\u escape";
+            }
+            cur_ += 2;  // now on the 'u' of the second escape
+            uint32_t low = ReadHexDigits(4);
+            if (low < 0xDC00 || low > 0xDFFF) {
+              LOG(FATAL) << "ValueError: Unpaired surrogate in \\u escape";
+            }
+            code_point = 0x10000 + ((code_point - 0xD800) << 10) + (low - 0xDC00);
+          } else if (code_point >= 0xDC00 && code_point <= 0xDFFF) {
+            LOG(FATAL) << "ValueError: Unpaired surrogate in \\u escape";
+          }
+          AppendUTF8(code_point, &str);
+          break;
+        }
+        case 'x':
+          // Non-standard `\xHH` (one raw byte). NOTE(review): accepted on the assumption that
+          // support::StrEscape, used by JSONDumps, emits this form for bytes outside printable
+          // ASCII -- confirm against support/str_escape.h.
+          str.push_back(static_cast<char>(ReadHexDigits(2)));
+          break;
+        default:
+          LOG(FATAL) << "ValueError: Unsupported escape sequence: \\" << *cur_;
+      }
+    }
+    if (cur_ == end_) {
+      LOG(FATAL) << "ValueError: Unexpected end of string";
+    }
+    ++cur_;
+    *token = Token{TokenType::kString, String(str)};
+    return true;
+  }
+
+  bool NextLiteral(char c) {
+    if (cur_ != end_ && *cur_ == c) {
+      ++cur_;
+      return true;
+    }
+    return false;
+  }
+
+  bool NextLiteral(const char* str, int len) {
+    if (cur_ + len <= end_ && std::strncmp(cur_, str, len) == 0) {
+      cur_ += len;
+      return true;
+    }
+    return false;
+  }
+  /*! \brief The current pointer */
+  const char* cur_;
+  /*! \brief End of the string */
+  const char* end_;
+
+  friend class JSONParser;
+};
+
+/*!
+ * \brief Recursive-descent parser that builds TVM objects from the tokens of one JSON document.
+ *
+ * Produces an undefined ObjectRef for null, Bool for true/false, the String/integer/float objects
+ * made by JSONTokenizer, Array<ObjectRef> for arrays and Map<String, ObjectRef> for objects.
+ * A trailing comma before ']' or '}' is tolerated.
+ */
+class JSONParser {
+ public:
+  using TokenType = JSONTokenizer::TokenType;
+  using Token = JSONTokenizer::Token;
+
+  explicit JSONParser(const char* st, const char* ed) : tokenizer_(st, ed) {}
+
+  /*!
+   * \brief Parse the whole input as exactly one JSON value.
+   * \return The value; an undefined ObjectRef for empty/whitespace-only input or for `null`.
+   */
+  ObjectRef Get() {
+    Token token = tokenizer_.Next();
+    if (token.type == TokenType::kEOF) {
+      return ObjectRef(nullptr);
+    }
+    ObjectRef result = ParseObject(std::move(token));
+    // Reject trailing content such as "[1] [2]" instead of silently dropping it.
+    const char* rest = tokenizer_.cur_;
+    Token trailing = tokenizer_.Next();
+    CHECK(trailing.type == TokenType::kEOF)
+        << "ValueError: Extra data after the JSON value: " << std::string(rest, tokenizer_.end_);
+    return result;
+  }
+
+ private:
+  /*! \brief Build the value that starts with `token`, consuming every token it spans. */
+  ObjectRef ParseObject(Token token) {
+    switch (token.type) {
+      case TokenType::kNull:
+        return ObjectRef(nullptr);
+      case TokenType::kTrue:
+        return Bool(true);
+      case TokenType::kFalse:
+        return Bool(false);
+      case TokenType::kLeftSquare:
+        return ParseArray();
+      case TokenType::kLeftCurly:
+        return ParseDict();
+      case TokenType::kString:
+      case TokenType::kInteger:
+      case TokenType::kFloat:
+        return token.value;
+      case TokenType::kRightSquare:
+        LOG(FATAL) << "ValueError: Unexpected token: ]";
+        throw;
+      case TokenType::kRightCurly:
+        LOG(FATAL) << "ValueError: Unexpected token: }";
+        throw;
+      case TokenType::kComma:
+        LOG(FATAL) << "ValueError: Unexpected token: ,";
+        throw;
+      case TokenType::kColon:
+        LOG(FATAL) << "ValueError: Unexpected token: :";
+        throw;
+      case TokenType::kEOF:
+        LOG(FATAL) << "ValueError: Unexpected EOF";
+        throw;
+      default:
+        // A bare `throw;` here (no active exception) would call std::terminate; report instead.
+        LOG(FATAL) << "ValueError: Unknown token type: " << static_cast<int>(token.type);
+        throw;
+    }
+  }
+
+  /*! \brief Parse array elements after the opening '['. */
+  Array<ObjectRef> ParseArray() {
+    bool is_first = true;
+    Array<ObjectRef> results;
+    for (;;) {
+      Token token;
+      if (is_first) {
+        // Pretend a comma precedes the first element so all elements share one code path.
+        is_first = false;
+        token = Token{TokenType::kComma};
+      } else {
+        token = tokenizer_.Next();
+      }
+      // Three cases overall:
+      // - Case 1. 1 token: "]"
+      // - Case 2. 2 tokens: ",", "]"
+      // - Case 3. 2 tokens: ",", "obj"
+      if (token.type == TokenType::kRightSquare) {  // Case 1
+        break;
+      } else if (token.type == TokenType::kComma) {
+        token = tokenizer_.Next();
+        if (token.type == TokenType::kRightSquare) {  // Case 2
+          break;
+        }
+        // Case 3
+        results.push_back(ParseObject(std::move(token)));
+        continue;
+      } else {
+        LOG(FATAL) << "ValueError: Unexpected token before: " << tokenizer_.cur_;
+      }
+    }
+    return results;
+  }
+
+  /*! \brief Parse key/value pairs after the opening '{'; keys must be strings. */
+  Map<String, ObjectRef> ParseDict() {
+    bool is_first = true;
+    Map<String, ObjectRef> results;
+    for (;;) {
+      Token token;
+      if (is_first) {
+        // Pretend a comma precedes the first entry so all entries share one code path.
+        is_first = false;
+        token = Token{TokenType::kComma};
+      } else {
+        token = tokenizer_.Next();
+      }
+      // Three cases overall:
+      // - Case 1. 1 token: "}"
+      // - Case 2. 2 tokens: ",", "}"
+      // - Case 3. 2 tokens: ",", "key", ":", "value"
+      if (token.type == TokenType::kRightCurly) {  // Case 1
+        break;
+      } else if (token.type == TokenType::kComma) {
+        token = tokenizer_.Next();
+        if (token.type == TokenType::kRightCurly) {  // Case 2
+          break;
+        }
+        // Case 3
+        ObjectRef key = ParseObject(std::move(token));
+        // `key` is undefined for a `null` key, so test via as<>() instead of dereferencing it.
+        const auto* key_str = key.as<StringObj>();
+        CHECK(key_str != nullptr) << "ValueError: key must be a string, but gets: " << key;
+        token = tokenizer_.Next();
+        CHECK(token.type == TokenType::kColon)
+            << "ValueError: Unexpected token before: " << tokenizer_.cur_;
+        ObjectRef value = ParseObject(tokenizer_.Next());
+        results.Set(GetRef<String>(key_str), value);
+        continue;
+      } else {
+        LOG(FATAL) << "ValueError: Unexpected token before: " << tokenizer_.cur_;
+      }
+    }
+    return results;
+  }
+
+  JSONTokenizer tokenizer_;
+};
+
+/*!
+ * \brief Parse a single JSON document.
+ * \param str The JSON text.
+ * \return The parsed object; an undefined ObjectRef for empty/whitespace-only text or `null`.
+ */
+ObjectRef JSONLoads(std::string str) {
+  const char* begin = str.data();
+  JSONParser parser(begin, begin + str.size());
+  return parser.Get();
+}
+
+} // namespace meta_schedule
+} // namespace tvm
diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc
index 155d223217..4f5bd9b136 100644
--- a/src/meta_schedule/database/json_database.cc
+++ b/src/meta_schedule/database/json_database.cc
@@ -17,6 +17,7 @@
* under the License.
*/
#include <set>
+#include <thread>
#include <unordered_map>
#include "../utils.h"
@@ -46,6 +47,45 @@ struct SortTuningRecordByMeanRunSecs {
}
};
+/*!
+ * \brief Read a JSON-lines file and parse every line into a JSON object.
+ *
+ * Blank lines, and lines whose first non-whitespace characters are `#` or `//`, are skipped,
+ * as the former Python-side reader (`meta_schedule.batch_json_str2obj`) did. Indices into the
+ * result therefore count only the data lines.
+ *
+ * \param path The path to the json file.
+ * \param num_threads The number of threads used to concurrently parse the lines. Values below 1
+ *   are treated as 1, because std::thread::hardware_concurrency() may report 0.
+ * \param allow_missing Whether to create new file when the given path is not found.
+ * \return The parsed objects, one per data line, in file order.
+ */
+std::vector<ObjectRef> JSONFileReadLines(const String& path, int num_threads, bool allow_missing) {
+  std::ifstream is(path);
+  if (is.good()) {
+    std::vector<std::string> json_strs;
+    for (std::string str; std::getline(is, str);) {
+      size_t first = str.find_first_not_of(" \t\n\v\f\r");
+      // Without this filter a blank line would parse to an undefined ObjectRef, which the
+      // callers then dereference.
+      if (first == std::string::npos || str[first] == '#' || str.compare(first, 2, "//") == 0) {
+        continue;
+      }
+      json_strs.push_back(str);
+    }
+    int n = json_strs.size();
+    std::vector<ObjectRef> json_objs(n);
+    int workers = num_threads > 0 ? num_threads : 1;
+    support::parallel_for_dynamic(0, n, workers, [&](int thread_id, int task_id) {
+      json_objs[task_id] = JSONLoads(json_strs[task_id]);
+    });
+    return json_objs;
+  }
+  CHECK(allow_missing) << "ValueError: File doesn't exist: " << path;
+  std::ofstream os(path);
+  CHECK(os.good()) << "ValueError: Cannot create new file: " << path;
+  return {};
+}
+
+/*!
+ * \brief Append one line to the file at `path`.
+ * \param path The path to the json file.
+ * \param line The text to write; a newline follows it and the stream is flushed.
+ */
+void JSONFileAppendLine(const String& path, const std::string& line) {
+  std::ofstream out(path, std::ios::app);
+  CHECK(out.good()) << "ValueError: Cannot open the file to write: " << path;
+  out << line << '\n' << std::flush;
+}
+
/*! \brief The default database implementation, which mimics two database tables with two files. */
class JSONDatabaseNode : public DatabaseNode {
public:
@@ -83,7 +123,7 @@ class JSONDatabaseNode : public DatabaseNode {
// If `mod` is new in `workloads2idx_`, append it to the workload file
if (inserted) {
it->second = static_cast<int>(this->workloads2idx_.size()) - 1;
- JSONFileAppendLine(this->path_workload, JSONObj2Str(workload->AsJSON()));
+ JSONFileAppendLine(this->path_workload, JSONDumps(workload->AsJSON()));
}
return it->first;
}
@@ -91,7 +131,7 @@ class JSONDatabaseNode : public DatabaseNode {
void CommitTuningRecord(const TuningRecord& record) {
this->tuning_records_.insert(record);
JSONFileAppendLine(this->path_tuning_record,
- JSONObj2Str(Array<ObjectRef>{
+ JSONDumps(Array<ObjectRef>{
/*workload_index=*/Integer(this->workloads2idx_.at(record->workload)),
/*tuning_record=*/record->AsJSON() //
}));
@@ -121,11 +161,12 @@ class JSONDatabaseNode : public DatabaseNode {
Database Database::JSONDatabase(String path_workload, String path_tuning_record,
bool allow_missing) {
+ int num_threads = std::thread::hardware_concurrency();
ObjectPtr<JSONDatabaseNode> n = make_object<JSONDatabaseNode>();
// Load `n->workloads2idx_` from `path_workload`
std::vector<Workload> workloads;
{
- Array<ObjectRef> json_objs = JSONStr2Obj(JSONFileReadLines(path_workload, allow_missing));
+ std::vector<ObjectRef> json_objs = JSONFileReadLines(path_workload, num_threads, allow_missing);
int n_objs = json_objs.size();
n->workloads2idx_.reserve(n_objs);
workloads.reserve(n_objs);
@@ -137,20 +178,25 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record,
}
// Load `n->tuning_records_` from `path_tuning_record`
{
- Array<ObjectRef> json_objs = JSONStr2Obj(JSONFileReadLines(path_tuning_record, allow_missing));
- for (const ObjectRef& json_obj : json_objs) {
- int workload_index = -1;
- ObjectRef tuning_record{nullptr};
- try {
- const ArrayNode* arr = json_obj.as<ArrayNode>();
- ICHECK_EQ(arr->size(), 2);
- workload_index = Downcast<Integer>(arr->at(0));
- tuning_record = arr->at(1);
- } catch (std::runtime_error& e) {
- LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj
- << "\nThe error is: " << e.what();
- }
- n->tuning_records_.insert(TuningRecord::FromJSON(tuning_record, workloads[workload_index]));
+ std::vector<ObjectRef> json_objs =
+ JSONFileReadLines(path_tuning_record, num_threads, allow_missing);
+ std::vector<TuningRecord> records;
+ records.resize(json_objs.size(), TuningRecord{nullptr});
+ support::parallel_for_dynamic(
+ 0, json_objs.size(), num_threads, [&](int thread_id, int task_id) {
+ const ObjectRef& json_obj = json_objs[task_id];
+ try {
+ const ArrayNode* arr = json_obj.as<ArrayNode>();
+ ICHECK_EQ(arr->size(), 2);
+ records[task_id] = TuningRecord::FromJSON(arr->at(1), //
+ workloads[Downcast<Integer>(arr->at(0))]);
+ } catch (std::runtime_error& e) {
+ LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj
+ << "\nThe error is: " << e.what();
+ }
+ });
+ for (const TuningRecord& record : records) {
+ n->tuning_records_.insert(record);
}
}
n->path_workload = path_workload;
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index be7745f23d..40c301c617 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -107,38 +107,6 @@ class PyLogMessage {
/*! \brief The type of the random state */
using TRandState = support::LinearCongruentialEngine::TRandState;
-/*!
- * \brief Read lines from a json file.
- * \param path The path to the json file.
- * \param allow_missing Whether to create new file when the given path is not found.
- * \return An array containing lines read from the json file.
- */
-inline Array<String> JSONFileReadLines(const String& path, bool allow_missing) {
- std::ifstream is(path);
- if (is.good()) {
- Array<String> results;
- for (std::string str; std::getline(is, str);) {
- results.push_back(str);
- }
- return results;
- }
- CHECK(allow_missing) << "ValueError: File doesn't exist: " << path;
- std::ofstream os(path);
- CHECK(os.good()) << "ValueError: Cannot create new file: " << path;
- return {};
-}
-
-/*!
- * \brief Append a line to a json file.
- * \param path The path to the json file.
- * \param line The line to append.
- */
-inline void JSONFileAppendLine(const String& path, const std::string& line) {
- std::ofstream os(path, std::ofstream::app);
- CHECK(os.good()) << "ValueError: Cannot open the file to write: " << path;
- os << line << std::endl;
-}
-
/*!
* \brief Get the base64 encoded result of a string.
* \param str The string to encode.
@@ -168,31 +136,18 @@ inline std::string Base64Decode(std::string str) {
}
/*!
- * \brief Parse lines of json string into a json object.
- * \param lines The lines of json string.
- * \return Array of json objects parsed.
- * \note The function calls the python-side json parser in runtime registry.
+ * \brief Parses a json string into a json object.
+ * \param json_str The json string.
+ * \return The json object
*/
-inline Array<ObjectRef> JSONStr2Obj(const Array<String>& lines) {
- static const runtime::PackedFunc* f_to_obj =
- runtime::Registry::Get("meta_schedule.batch_json_str2obj");
- ICHECK(f_to_obj) << "IndexError: Cannot find the packed function "
- "`meta_schedule.batch_json_str2obj` in the global registry";
- return (*f_to_obj)(lines);
-}
+ObjectRef JSONLoads(std::string json_str);
/*!
- * \brief Serialize a json object into a json string.
- * \param json_obj The json object to serialize.
- * \return A string containing the serialized json object.
- * \note The function calls the python-side json obj serializer in runtime registry.
+ * \brief Dumps a json object into a json string.
+ * \param json_obj The json object.
+ * \return The json string
*/
-inline String JSONObj2Str(const ObjectRef& json_obj) {
- static const runtime::PackedFunc* f_to_str = runtime::Registry::Get("meta_schedule.json_obj2str");
- ICHECK(f_to_str) << "IndexError: Cannot find the packed function "
- "`meta_schedule.json_obj2str` in the global registry";
- return (*f_to_str)(json_obj);
-}
+std::string JSONDumps(ObjectRef json_obj);
/*!
* \brief Converts a structural hash code to string
@@ -447,6 +402,48 @@ inline double GetRunMsMedian(const RunnerResult& runner_result) {
}
}
+/*!
+ * \brief Convert the given object to an array of floating point numbers
+ *
+ * Integer elements are accepted as well (e.g. integral values written without a fractional
+ * part). Every element is converted to a float32 FloatImm.
+ *
+ * \param obj The object to be converted; expected to be an Array of IntImm/FloatImm
+ * \return The array of floating point numbers
+ */
+inline Array<FloatImm> AsFloatArray(const ObjectRef& obj) {
+  const ArrayNode* arr = obj.as<ArrayNode>();
+  // `obj` may be undefined (JSON null); don't dereference it while building the message.
+  ICHECK(arr) << "TypeError: Expect an array, but gets: "
+              << (obj.defined() ? obj->GetTypeKey() : std::string("None"));
+  Array<FloatImm> results;
+  results.reserve(arr->size());
+  for (const ObjectRef& elem : *arr) {
+    if (const auto* int_imm = elem.as<IntImmNode>()) {
+      results.push_back(FloatImm(DataType::Float(32), int_imm->value));
+    } else if (const auto* float_imm = elem.as<FloatImmNode>()) {
+      results.push_back(FloatImm(DataType::Float(32), float_imm->value));
+    } else {
+      LOG(FATAL) << "TypeError: Expect an array of float or int, but gets: "
+                 << (elem.defined() ? elem->GetTypeKey() : std::string("None"));
+    }
+  }
+  return results;
+}
+
+/*!
+ * \brief Convert the given object to an array of integers
+ *
+ * Values within the int32 range become int32 `Integer`s. Larger values keep their full 64-bit
+ * value instead of being truncated by `Integer(int)`.
+ *
+ * \param obj The object to be converted; expected to be an Array of IntImm
+ * \return The array of integers
+ */
+inline Array<Integer> AsIntArray(const ObjectRef& obj) {
+  const ArrayNode* arr = obj.as<ArrayNode>();
+  // `obj` may be undefined (JSON null); don't dereference it while building the message.
+  ICHECK(arr) << "TypeError: Expect an array, but gets: "
+              << (obj.defined() ? obj->GetTypeKey() : std::string("None"));
+  constexpr int64_t kInt32Min = -2147483647LL - 1;
+  constexpr int64_t kInt32Max = 2147483647LL;
+  Array<Integer> results;
+  results.reserve(arr->size());
+  for (const ObjectRef& elem : *arr) {
+    if (const auto* int_imm = elem.as<IntImmNode>()) {
+      int64_t value = int_imm->value;
+      if (value >= kInt32Min && value <= kInt32Max) {
+        results.push_back(Integer(static_cast<int>(value)));
+      } else {
+        results.push_back(Integer(IntImm(DataType::Int(64), value)));
+      }
+    } else {
+      LOG(FATAL) << "TypeError: Expect an array of integers, but gets: "
+                 << (elem.defined() ? elem->GetTypeKey() : std::string("None"));
+    }
+  }
+  return results;
+}
+
} // namespace meta_schedule
} // namespace tvm
diff --git a/tests/python/unittest/test_meta_schedule_database.py b/tests/python/unittest/test_meta_schedule_database.py
index 1edfbe6c7a..ff0f350d89 100644
--- a/tests/python/unittest/test_meta_schedule_database.py
+++ b/tests/python/unittest/test_meta_schedule_database.py
@@ -17,20 +17,18 @@
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
"""Test Meta Schedule Database"""
import os.path as osp
-import sys
import tempfile
from typing import Callable
-import pytest
import tvm
import tvm.testing
+from tvm import meta_schedule as ms
from tvm import tir
from tvm.ir.module import IRModule
-from tvm.meta_schedule.arg_info import ArgInfo
-from tvm.meta_schedule.database import JSONDatabase, TuningRecord
from tvm.script import tir as T
from tvm.tir import Schedule
+
# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
# fmt: off
@tvm.script.ir_module
@@ -92,13 +90,13 @@ def _create_schedule(mod: IRModule, sch_fn: Callable[[Schedule], None]) -> Sched
return sch
-def _create_tmp_database(tmpdir: str) -> JSONDatabase:
+def _create_tmp_database(tmpdir: str) -> ms.database.JSONDatabase:
path_workload = osp.join(tmpdir, "workloads.json")
path_tuning_record = osp.join(tmpdir, "tuning_records.json")
- return JSONDatabase(path_workload, path_tuning_record)
+ return ms.database.JSONDatabase(path_workload, path_tuning_record)
-def _equal_record(a: TuningRecord, b: TuningRecord):
+def _equal_record(a: ms.database.TuningRecord, b: ms.database.TuningRecord):
assert str(a.trace) == str(b.trace)
assert str(a.run_secs) == str(b.run_secs)
# AWAIT(@zxybazh): change to export after fixing "(bool)0"
@@ -113,15 +111,15 @@ def test_meta_schedule_tuning_record_round_trip():
with tempfile.TemporaryDirectory() as tmpdir:
database = _create_tmp_database(tmpdir)
workload = database.commit_workload(mod)
- record = TuningRecord(
+ record = ms.database.TuningRecord(
_create_schedule(mod, _schedule_matmul).trace,
workload,
[1.5, 2.5, 1.8],
tvm.target.Target("llvm"),
- ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object
+ ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
)
database.commit_tuning_record(record)
- new_record = TuningRecord.from_json(record.as_json(), workload)
+ new_record = ms.database.TuningRecord.from_json(record.as_json(), workload)
_equal_record(record, new_record)
@@ -138,12 +136,12 @@ def test_meta_schedule_database_has_workload():
with tempfile.TemporaryDirectory() as tmpdir:
database = _create_tmp_database(tmpdir)
workload = database.commit_workload(mod)
- record = TuningRecord(
+ record = ms.database.TuningRecord(
_create_schedule(mod, _schedule_matmul).trace,
workload,
[1.5, 2.5, 1.8],
tvm.target.Target("llvm"),
- ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object
+ ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
)
database.commit_tuning_record(record)
assert len(database) == 1
@@ -156,12 +154,12 @@ def test_meta_schedule_database_add_entry():
with tempfile.TemporaryDirectory() as tmpdir:
database = _create_tmp_database(tmpdir)
workload = database.commit_workload(mod)
- record = TuningRecord(
+ record = ms.database.TuningRecord(
_create_schedule(mod, _schedule_matmul).trace,
workload,
[1.5, 2.5, 1.8],
tvm.target.Target("llvm"),
- ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object
+ ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
)
database.commit_tuning_record(record)
assert len(database) == 1
@@ -176,12 +174,12 @@ def test_meta_schedule_database_missing():
database = _create_tmp_database(tmpdir)
workload = database.commit_workload(mod)
workload_2 = database.commit_workload(mod_2)
- record = TuningRecord(
+ record = ms.database.TuningRecord(
_create_schedule(mod, _schedule_matmul).trace,
workload,
[1.5, 2.5, 1.8],
tvm.target.Target("llvm"),
- ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object
+ ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
)
database.commit_tuning_record(record)
ret = database.get_top_k(workload_2, 3)
@@ -195,47 +193,47 @@ def test_meta_schedule_database_sorting():
token = database.commit_workload(mod)
trace = _create_schedule(mod, _schedule_matmul).trace
records = [
- TuningRecord(
+ ms.database.TuningRecord(
trace,
token,
[7.0, 8.0, 9.0],
tvm.target.Target("llvm"),
- ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object
+ ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
),
- TuningRecord(
+ ms.database.TuningRecord(
trace,
token,
[1.0, 2.0, 3.0],
tvm.target.Target("llvm"),
- ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object
+ ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
),
- TuningRecord(
+ ms.database.TuningRecord(
trace,
token,
[4.0, 5.0, 6.0],
tvm.target.Target("llvm"),
- ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object
+ ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
),
- TuningRecord(
+ ms.database.TuningRecord(
trace,
token,
[1.1, 1.2, 600.0],
tvm.target.Target("llvm"),
- ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object
+ ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
),
- TuningRecord(
+ ms.database.TuningRecord(
trace,
token,
[1.0, 100.0, 6.0],
tvm.target.Target("llvm"),
- ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object
+ ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
),
- TuningRecord(
+ ms.database.TuningRecord(
trace,
token,
[4.0, 9.0, 8.0],
tvm.target.Target("llvm"),
- ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object
+ ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
),
]
for record in records:
@@ -257,31 +255,31 @@ def test_meta_schedule_database_reload():
token = database.commit_workload(mod)
trace = _create_schedule(mod, _schedule_matmul).trace
records = [
- TuningRecord(
+ ms.database.TuningRecord(
trace,
token,
[7.0, 8.0, 9.0],
tvm.target.Target("llvm"),
- ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object
+ ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
),
- TuningRecord(
+ ms.database.TuningRecord(
trace,
token,
[1.0, 2.0, 3.0],
tvm.target.Target("llvm"),
- ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object
+ ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
),
- TuningRecord(
+ ms.database.TuningRecord(
trace,
token,
[4.0, 5.0, 6.0],
tvm.target.Target("llvm"),
- ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object
+ ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
),
]
for record in records:
database.commit_tuning_record(record)
- new_database = JSONDatabase( # pylint: disable=unused-variable
+ new_database = ms.database.JSONDatabase(
path_workload=database.path_workload,
path_tuning_record=database.path_tuning_record,
)