You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/06/11 07:18:14 UTC

[tvm] branch main updated: [MetaSchedule] JSONDatabase Utilities (#11680)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0df69611b2 [MetaSchedule] JSONDatabase Utilities (#11680)
0df69611b2 is described below

commit 0df69611b2fb46724a0023dd8d389c9a1ecedcb8
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Sat Jun 11 00:18:10 2022 -0700

    [MetaSchedule] JSONDatabase Utilities (#11680)
    
    This PR adds utilities to JSONDatabase to speed up its loading and saving.
---
 python/tvm/meta_schedule/utils.py                  |  28 +-
 src/meta_schedule/arg_info.cc                      |   2 +-
 src/meta_schedule/database/database.cc             |   2 +-
 src/meta_schedule/database/database_utils.cc       | 377 +++++++++++++++++++++
 src/meta_schedule/database/json_database.cc        |  80 ++++-
 src/meta_schedule/utils.h                          | 103 +++---
 .../python/unittest/test_meta_schedule_database.py |  68 ++--
 7 files changed, 526 insertions(+), 134 deletions(-)

diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py
index 919a29e6cf..26bf206709 100644
--- a/python/tvm/meta_schedule/utils.py
+++ b/python/tvm/meta_schedule/utils.py
@@ -16,12 +16,11 @@
 # under the License.
 """Utilities for meta schedule"""
 import ctypes
-import json
 import logging
 import os
 import shutil
 from contextlib import contextmanager
-from typing import Any, List, Dict, Callable, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import psutil  # type: ignore
 from tvm._ffi import get_global_func, register_func
@@ -296,31 +295,6 @@ def _json_de_tvm(obj: Any) -> Any:
     raise TypeError("Not supported type: " + str(type(obj)))
 
 
-@register_func("meta_schedule.json_obj2str")
-def json_obj2str(json_obj: Any) -> str:
-    json_obj = _json_de_tvm(json_obj)
-    return json.dumps(json_obj)
-
-
-@register_func("meta_schedule.batch_json_str2obj")
-def batch_json_str2obj(json_strs: List[str]) -> List[Any]:
-    """Covert a list of JSON strings to a list of json objects.
-    Parameters
-    ----------
-    json_strs : List[str]
-        The list of JSON strings
-    Returns
-    -------
-    result : List[Any]
-        The list of json objects
-    """
-    return [
-        json.loads(json_str)
-        for json_str in map(str.strip, json_strs)
-        if json_str and (not json_str.startswith("#")) and (not json_str.startswith("//"))
-    ]
-
-
 def shash2hex(mod: IRModule) -> str:
     """Get the structural hash of a module.
 
diff --git a/src/meta_schedule/arg_info.cc b/src/meta_schedule/arg_info.cc
index 104662b6aa..9b225e8bea 100644
--- a/src/meta_schedule/arg_info.cc
+++ b/src/meta_schedule/arg_info.cc
@@ -88,7 +88,7 @@ TensorInfo TensorInfo::FromJSON(const ObjectRef& json_obj) {
       dtype = runtime::String2DLDataType(dtype_str);
     }
     // Load json[2] => shape
-    shape = Downcast<Array<Integer>>(json_array->at(2));
+    shape = AsIntArray(json_array->at(2));
   } catch (const std::runtime_error& e) {  // includes tvm::Error and dmlc::Error
     LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj
                << "\nThe error is: " << e.what();
diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc
index 86d999e4fd..9905ff73c7 100644
--- a/src/meta_schedule/database/database.cc
+++ b/src/meta_schedule/database/database.cc
@@ -115,7 +115,7 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w
     CHECK(json_array && json_array->size() == 4);
     // Load json[1] => run_secs
     if (json_array->at(1).defined()) {
-      run_secs = Downcast<Array<FloatImm>>(json_array->at(1));
+      run_secs = AsFloatArray(json_array->at(1));
     }
     // Load json[2] => target
     if (json_array->at(2).defined()) {
diff --git a/src/meta_schedule/database/database_utils.cc b/src/meta_schedule/database/database_utils.cc
new file mode 100644
index 0000000000..278c5267ea
--- /dev/null
+++ b/src/meta_schedule/database/database_utils.cc
@@ -0,0 +1,377 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <iomanip>
+#include <sstream>
+#include <vector>
+
+#include "../../support/str_escape.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+void JSONDumps(ObjectRef json_obj, std::ostringstream& os) {
+  if (!json_obj.defined()) {
+    os << "null";
+  } else if (const auto* int_imm = json_obj.as<IntImmNode>()) {
+    if (int_imm->dtype == DataType::Bool()) {
+      if (int_imm->value) {
+        os << "true";
+      } else {
+        os << "false";
+      }
+    } else {
+      os << int_imm->value;
+    }
+  } else if (const auto* float_imm = json_obj.as<FloatImmNode>()) {
+    os << std::setprecision(20) << float_imm->value;
+  } else if (const auto* str = json_obj.as<runtime::StringObj>()) {
+    os << '"' << support::StrEscape(str->data, str->size) << '"';
+  } else if (const auto* array = json_obj.as<runtime::ArrayNode>()) {
+    os << "[";
+    int n = array->size();
+    for (int i = 0; i < n; ++i) {
+      if (i != 0) {
+        os << ",";
+      }
+      JSONDumps(array->at(i), os);
+    }
+    os << "]";
+  } else if (const auto* dict = json_obj.as<runtime::MapNode>()) {
+    int n = dict->size();
+    std::vector<std::pair<String, ObjectRef>> key_values;
+    key_values.reserve(n);
+    for (const auto& kv : *dict) {
+      if (const auto* k = kv.first.as<StringObj>()) {
+        key_values.emplace_back(GetRef<String>(k), kv.second);
+      } else {
+        LOG(FATAL) << "TypeError: Only string keys are supported in JSON dumps, but got: "
+                   << kv.first->GetTypeKey();
+      }
+    }
+    std::sort(key_values.begin(), key_values.end());
+    os << "{";
+    for (int i = 0; i < n; ++i) {
+      const auto& kv = key_values[i];
+      if (i != 0) {
+        os << ",";
+      }
+      os << '"' << support::StrEscape(kv.first->data, kv.first->size) << '"';
+      os << ":";
+      JSONDumps(kv.second, os);
+    }
+    os << "}";
+  } else {
+    LOG(FATAL) << "TypeError: Unsupported type in JSON object: " << json_obj->GetTypeKey();
+  }
+}
+
+std::string JSONDumps(ObjectRef json_obj) {
+  std::ostringstream os;
+  JSONDumps(json_obj, os);
+  return os.str();
+}
+
+class JSONTokenizer {
+ public:
+  enum class TokenType : int32_t {
+    kEOF = 0,          // end of file
+    kNull = 1,         // null
+    kTrue = 2,         // true
+    kFalse = 3,        // false
+    kLeftSquare = 4,   // [
+    kRightSquare = 5,  // ]
+    kLeftCurly = 6,    // {
+    kRightCurly = 7,   // }
+    kComma = 8,        // ,
+    kColon = 9,        // :
+    kInteger = 10,     // integers
+    kFloat = 11,       // floating point numbers
+    kString = 12,      // string
+  };
+
+  struct Token {
+    TokenType type;
+    ObjectRef value{nullptr};
+  };
+
+  explicit JSONTokenizer(const char* st, const char* ed) : cur_(st), end_(ed) {}
+
+  Token Next() {
+    for (; cur_ != end_ && std::isspace(static_cast<unsigned char>(*cur_)); ++cur_) {
+    }
+    if (cur_ == end_) return Token{TokenType::kEOF};
+    if (NextLeftSquare()) return Token{TokenType::kLeftSquare};
+    if (NextRightSquare()) return Token{TokenType::kRightSquare};
+    if (NextLeftCurly()) return Token{TokenType::kLeftCurly};
+    if (NextRightCurly()) return Token{TokenType::kRightCurly};
+    if (NextComma()) return Token{TokenType::kComma};
+    if (NextColon()) return Token{TokenType::kColon};
+    if (NextNull()) return Token{TokenType::kNull};
+    if (NextTrue()) return Token{TokenType::kTrue};
+    if (NextFalse()) return Token{TokenType::kFalse};
+    Token token;
+    if (NextString(&token)) return token;
+    if (NextNumber(&token)) return token;
+    LOG(FATAL) << "ValueError: Cannot tokenize: " << std::string(cur_, end_);
+    throw;
+  }
+
+ private:
+  bool NextLeftSquare() { return NextLiteral('['); }
+  bool NextRightSquare() { return NextLiteral(']'); }
+  bool NextLeftCurly() { return NextLiteral('{'); }
+  bool NextRightCurly() { return NextLiteral('}'); }
+  bool NextComma() { return NextLiteral(','); }
+  bool NextColon() { return NextLiteral(':'); }
+  bool NextNull() { return NextLiteral("null", 4); }
+  bool NextTrue() { return NextLiteral("true", 4); }
+  bool NextFalse() { return NextLiteral("false", 5); }
+
+  bool NextNumber(Token* token) {
+    using runtime::DataType;
+    bool is_float = false;
+    const char* st = cur_;
+    for (; cur_ != end_; ++cur_) {
+      if (std::isdigit(static_cast<unsigned char>(*cur_)) || *cur_ == '+' || *cur_ == '-') {
+        continue;
+      } else if (*cur_ == '.' || *cur_ == 'e' || *cur_ == 'E') {
+        is_float = true;
+      } else {
+        break;
+      }
+    }
+    if (st == cur_) {
+      return false;
+    }
+    // TODO(@junrushao1994): error checking; stoi/stod throw invalid_argument/out_of_range
+    if (is_float) {
+      *token = Token{TokenType::kFloat,
+                     FloatImm(DataType::Float(64),  //
+                              std::stod(std::string(st, cur_)))};
+    } else {
+      *token = Token{TokenType::kInteger,  //
+                     Integer(std::stoi(std::string(st, cur_)))};
+    }
+    return true;
+  }
+
+  bool NextString(Token* token) {
+    if (cur_ == end_ || *cur_ != '"') return false;
+    ++cur_;
+    std::string str;
+    for (; cur_ != end_ && *cur_ != '\"'; ++cur_) {
+      if (*cur_ != '\\') {
+        str.push_back(*cur_);
+        continue;
+      }
+      ++cur_;
+      if (cur_ == end_) {
+        LOG(FATAL) << "ValueError: Unexpected end of string: \\";
+        throw;
+      }
+      switch (*cur_) {
+        case '\"':
+          str.push_back('\"');
+          break;
+        case '\\':
+          str.push_back('\\');
+          break;
+        case '/':
+          str.push_back('/');
+          break;
+        case 'b':
+          str.push_back('\b');
+          break;
+        case 'f':
+          str.push_back('\f');
+          break;
+        case 'n':
+          str.push_back('\n');
+          break;
+        case 'r':
+          str.push_back('\r');
+          break;
+        case 't':
+          str.push_back('\t');
+          break;
+        default:
+          LOG(FATAL) << "ValueError: Unsupported escape sequence: \\" << *cur_;
+      }
+    }
+    if (cur_ == end_) {
+      LOG(FATAL) << "ValueError: Unexpected end of string";
+    }
+    ++cur_;
+    *token = Token{TokenType::kString, String(str)};
+    return true;
+  }
+
+  bool NextLiteral(char c) {
+    if (cur_ != end_ && *cur_ == c) {
+      ++cur_;
+      return true;
+    }
+    return false;
+  }
+
+  bool NextLiteral(const char* str, int len) {
+    if (cur_ + len <= end_ && std::strncmp(cur_, str, len) == 0) {
+      cur_ += len;
+      return true;
+    }
+    return false;
+  }
+  /*! \brief The current pointer */
+  const char* cur_;
+  /*! \brief End of the string */
+  const char* end_;
+
+  friend class JSONParser;
+};
+
+class JSONParser {
+ public:
+  using TokenType = JSONTokenizer::TokenType;
+  using Token = JSONTokenizer::Token;
+
+  explicit JSONParser(const char* st, const char* ed) : tokenizer_(st, ed) {}
+
+  ObjectRef Get() {
+    Token token = tokenizer_.Next();
+    if (token.type == TokenType::kEOF) {
+      return ObjectRef(nullptr);
+    }
+    return ParseObject(std::move(token));
+  }
+
+ private:
+  ObjectRef ParseObject(Token token) {
+    switch (token.type) {
+      case TokenType::kNull:
+        return ObjectRef(nullptr);
+      case TokenType::kTrue:
+        return Bool(true);
+      case TokenType::kFalse:
+        return Bool(false);
+      case TokenType::kLeftSquare:
+        return ParseArray();
+      case TokenType::kLeftCurly:
+        return ParseDict();
+      case TokenType::kString:
+      case TokenType::kInteger:
+      case TokenType::kFloat:
+        return token.value;
+      case TokenType::kRightSquare:
+        LOG(FATAL) << "ValueError: Unexpected token: ]";
+      case TokenType::kRightCurly:
+        LOG(FATAL) << "ValueError: Unexpected token: }";
+      case TokenType::kComma:
+        LOG(FATAL) << "ValueError: Unexpected token: ,";
+      case TokenType::kColon:
+        LOG(FATAL) << "ValueError: Unexpected token: :";
+      case TokenType::kEOF:
+        LOG(FATAL) << "ValueError: Unexpected EOF";
+      default:
+        throw;
+    }
+  }
+
+  Array<ObjectRef> ParseArray() {
+    bool is_first = true;
+    Array<ObjectRef> results;
+    for (;;) {
+      Token token;
+      if (is_first) {
+        is_first = false;
+        token = Token{TokenType::kComma};
+      } else {
+        token = tokenizer_.Next();
+      }
+      // Three cases overall:
+      // - Case 1. 1 token: "]"
+      // - Case 2. 2 tokens: ",", "]"
+      // - Case 3. 2 tokens: ",", "obj"
+      if (token.type == TokenType::kRightSquare) {  // Case 1
+        break;
+      } else if (token.type == TokenType::kComma) {
+        token = tokenizer_.Next();
+        if (token.type == TokenType::kRightSquare) {  // Case 2
+          break;
+        }
+        // Case 3
+        results.push_back(ParseObject(std::move(token)));
+        continue;
+      } else {
+        LOG(FATAL) << "ValueError: Unexpected token before: " << tokenizer_.cur_;
+      }
+    }
+    return results;
+  }
+
+  Map<String, ObjectRef> ParseDict() {
+    bool is_first = true;
+    Map<String, ObjectRef> results;
+    for (;;) {
+      Token token;
+      if (is_first) {
+        is_first = false;
+        token = Token{TokenType::kComma};
+      } else {
+        token = tokenizer_.Next();
+      }
+      // Three cases overall:
+      // - Case 1. 1 token: "}"
+      // - Case 2. 2 tokens: ",", "}"
+      // - Case 3. 4 tokens: ",", "key", ":", "value"
+      if (token.type == TokenType::kRightCurly) {  // Case 1
+        break;
+      } else if (token.type == TokenType::kComma) {
+        token = tokenizer_.Next();
+        if (token.type == TokenType::kRightCurly) {  // Case 2
+          break;
+        }
+        // Case 3
+        ObjectRef key = ParseObject(std::move(token));
+        CHECK(key.as<StringObj>() != nullptr)
+            << "ValueError: key must be a string, but gets: " << key;
+        token = tokenizer_.Next();
+        CHECK(token.type == TokenType::kColon)
+            << "ValueError: Unexpected token before: " << tokenizer_.cur_;
+        ObjectRef value = ParseObject(tokenizer_.Next());
+        results.Set(Downcast<String>(key), value);
+        continue;
+      } else {
+        LOG(FATAL) << "ValueError: Unexpected token before: " << tokenizer_.cur_;
+      }
+    }
+    return results;
+  }
+
+  JSONTokenizer tokenizer_;
+};
+
+ObjectRef JSONLoads(std::string str) {
+  const char* st = str.c_str();
+  const char* ed = st + str.length();
+  return JSONParser(st, ed).Get();
+}
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc
index 155d223217..4f5bd9b136 100644
--- a/src/meta_schedule/database/json_database.cc
+++ b/src/meta_schedule/database/json_database.cc
@@ -17,6 +17,7 @@
  * under the License.
  */
 #include <set>
+#include <thread>
 #include <unordered_map>
 
 #include "../utils.h"
@@ -46,6 +47,45 @@ struct SortTuningRecordByMeanRunSecs {
   }
 };
 
+/*!
+ * \brief Read a json file, parsing each non-blank line into a json object.
+ * \param path The path to the json file.
+ * \param num_threads The number of threads used to concurrently parse the lines.
+ * \param allow_missing Whether to create new file when the given path is not found.
+ * \return The parsed json objects, one per non-blank line, in file order.
+ */
+std::vector<ObjectRef> JSONFileReadLines(const String& path, int num_threads, bool allow_missing) {
+  std::ifstream is(path);
+  if (is.good()) {
+    std::vector<String> json_strs;
+    for (std::string str; std::getline(is, str);) {
+      if (str.find_first_not_of(" \t\n\v\f\r") != std::string::npos) json_strs.push_back(str);
+    }
+    int n = json_strs.size();
+    std::vector<ObjectRef> json_objs;
+    json_objs.resize(n);
+    support::parallel_for_dynamic(0, n, num_threads, [&](int thread_id, int task_id) {
+      json_objs[task_id] = JSONLoads(json_strs[task_id]);
+    });
+    return json_objs;
+  }
+  CHECK(allow_missing) << "ValueError: File doesn't exist: " << path;
+  std::ofstream os(path);
+  CHECK(os.good()) << "ValueError: Cannot create new file: " << path;
+  return {};
+}
+
+/*!
+ * \brief Append a line to a json file.
+ * \param path The path to the json file.
+ * \param line The line to append.
+ */
+void JSONFileAppendLine(const String& path, const std::string& line) {
+  std::ofstream os(path, std::ofstream::app);
+  CHECK(os.good()) << "ValueError: Cannot open the file to write: " << path;
+  os << line << std::endl;
+}
+
 /*! \brief The default database implementation, which mimics two database tables with two files. */
 class JSONDatabaseNode : public DatabaseNode {
  public:
@@ -83,7 +123,7 @@ class JSONDatabaseNode : public DatabaseNode {
     // If `mod` is new in `workloads2idx_`, append it to the workload file
     if (inserted) {
       it->second = static_cast<int>(this->workloads2idx_.size()) - 1;
-      JSONFileAppendLine(this->path_workload, JSONObj2Str(workload->AsJSON()));
+      JSONFileAppendLine(this->path_workload, JSONDumps(workload->AsJSON()));
     }
     return it->first;
   }
@@ -91,7 +131,7 @@ class JSONDatabaseNode : public DatabaseNode {
   void CommitTuningRecord(const TuningRecord& record) {
     this->tuning_records_.insert(record);
     JSONFileAppendLine(this->path_tuning_record,
-                       JSONObj2Str(Array<ObjectRef>{
+                       JSONDumps(Array<ObjectRef>{
                            /*workload_index=*/Integer(this->workloads2idx_.at(record->workload)),
                            /*tuning_record=*/record->AsJSON()  //
                        }));
@@ -121,11 +161,12 @@ class JSONDatabaseNode : public DatabaseNode {
 
 Database Database::JSONDatabase(String path_workload, String path_tuning_record,
                                 bool allow_missing) {
+  int num_threads = std::thread::hardware_concurrency();
   ObjectPtr<JSONDatabaseNode> n = make_object<JSONDatabaseNode>();
   // Load `n->workloads2idx_` from `path_workload`
   std::vector<Workload> workloads;
   {
-    Array<ObjectRef> json_objs = JSONStr2Obj(JSONFileReadLines(path_workload, allow_missing));
+    std::vector<ObjectRef> json_objs = JSONFileReadLines(path_workload, num_threads, allow_missing);
     int n_objs = json_objs.size();
     n->workloads2idx_.reserve(n_objs);
     workloads.reserve(n_objs);
@@ -137,20 +178,25 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record,
   }
   // Load `n->tuning_records_` from `path_tuning_record`
   {
-    Array<ObjectRef> json_objs = JSONStr2Obj(JSONFileReadLines(path_tuning_record, allow_missing));
-    for (const ObjectRef& json_obj : json_objs) {
-      int workload_index = -1;
-      ObjectRef tuning_record{nullptr};
-      try {
-        const ArrayNode* arr = json_obj.as<ArrayNode>();
-        ICHECK_EQ(arr->size(), 2);
-        workload_index = Downcast<Integer>(arr->at(0));
-        tuning_record = arr->at(1);
-      } catch (std::runtime_error& e) {
-        LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj
-                   << "\nThe error is: " << e.what();
-      }
-      n->tuning_records_.insert(TuningRecord::FromJSON(tuning_record, workloads[workload_index]));
+    std::vector<ObjectRef> json_objs =
+        JSONFileReadLines(path_tuning_record, num_threads, allow_missing);
+    std::vector<TuningRecord> records;
+    records.resize(json_objs.size(), TuningRecord{nullptr});
+    support::parallel_for_dynamic(
+        0, json_objs.size(), num_threads, [&](int thread_id, int task_id) {
+          const ObjectRef& json_obj = json_objs[task_id];
+          try {
+            const ArrayNode* arr = json_obj.as<ArrayNode>();
+            ICHECK_EQ(arr->size(), 2);
+            records[task_id] = TuningRecord::FromJSON(arr->at(1),  //
+                                                      workloads[Downcast<Integer>(arr->at(0))]);
+          } catch (std::runtime_error& e) {
+            LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj
+                       << "\nThe error is: " << e.what();
+          }
+        });
+    for (const TuningRecord& record : records) {
+      n->tuning_records_.insert(record);
     }
   }
   n->path_workload = path_workload;
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index be7745f23d..40c301c617 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -107,38 +107,6 @@ class PyLogMessage {
 /*! \brief The type of the random state */
 using TRandState = support::LinearCongruentialEngine::TRandState;
 
-/*!
- * \brief Read lines from a json file.
- * \param path The path to the json file.
- * \param allow_missing Whether to create new file when the given path is not found.
- * \return An array containing lines read from the json file.
- */
-inline Array<String> JSONFileReadLines(const String& path, bool allow_missing) {
-  std::ifstream is(path);
-  if (is.good()) {
-    Array<String> results;
-    for (std::string str; std::getline(is, str);) {
-      results.push_back(str);
-    }
-    return results;
-  }
-  CHECK(allow_missing) << "ValueError: File doesn't exist: " << path;
-  std::ofstream os(path);
-  CHECK(os.good()) << "ValueError: Cannot create new file: " << path;
-  return {};
-}
-
-/*!
- * \brief Append a line to a json file.
- * \param path The path to the json file.
- * \param line The line to append.
- */
-inline void JSONFileAppendLine(const String& path, const std::string& line) {
-  std::ofstream os(path, std::ofstream::app);
-  CHECK(os.good()) << "ValueError: Cannot open the file to write: " << path;
-  os << line << std::endl;
-}
-
 /*!
  * \brief Get the base64 encoded result of a string.
  * \param str The string to encode.
@@ -168,31 +136,18 @@ inline std::string Base64Decode(std::string str) {
 }
 
 /*!
- * \brief Parse lines of json string into a json object.
- * \param lines The lines of json string.
- * \return Array of json objects parsed.
- * \note The function calls the python-side json parser in runtime registry.
+ * \brief Parses a json string into a json object.
+ * \param json_str The json string.
+ * \return The json object
  */
-inline Array<ObjectRef> JSONStr2Obj(const Array<String>& lines) {
-  static const runtime::PackedFunc* f_to_obj =
-      runtime::Registry::Get("meta_schedule.batch_json_str2obj");
-  ICHECK(f_to_obj) << "IndexError: Cannot find the packed function "
-                      "`meta_schedule.batch_json_str2obj` in the global registry";
-  return (*f_to_obj)(lines);
-}
+ObjectRef JSONLoads(std::string json_str);
 
 /*!
- * \brief Serialize a json object into a json string.
- * \param json_obj The json object to serialize.
- * \return A string containing the serialized json object.
- * \note The function calls the python-side json obj serializer in runtime registry.
+ * \brief Dumps a json object into a json string.
+ * \param json_obj The json object.
+ * \return The json string
  */
-inline String JSONObj2Str(const ObjectRef& json_obj) {
-  static const runtime::PackedFunc* f_to_str = runtime::Registry::Get("meta_schedule.json_obj2str");
-  ICHECK(f_to_str) << "IndexError: Cannot find the packed function "
-                      "`meta_schedule.json_obj2str` in the global registry";
-  return (*f_to_str)(json_obj);
-}
+std::string JSONDumps(ObjectRef json_obj);
 
 /*!
  * \brief Converts a structural hash code to string
@@ -447,6 +402,48 @@ inline double GetRunMsMedian(const RunnerResult& runner_result) {
   }
 }
 
+/*!
+ * \brief Convert the given object to an array of floating point numbers
+ * \param obj The object to be converted
+ * \return The array of floating point numbers
+ */
+inline Array<FloatImm> AsFloatArray(const ObjectRef& obj) {
+  const ArrayNode* arr = obj.as<ArrayNode>();
+  ICHECK(arr) << "TypeError: Expect an array, but gets: " << obj->GetTypeKey();
+  Array<FloatImm> results;
+  results.reserve(arr->size());
+  for (const ObjectRef& elem : *arr) {
+    if (const auto* int_imm = elem.as<IntImmNode>()) {
+      results.push_back(FloatImm(DataType::Float(32), int_imm->value));
+    } else if (const auto* float_imm = elem.as<FloatImmNode>()) {
+      results.push_back(FloatImm(DataType::Float(32), float_imm->value));
+    } else {
+      LOG(FATAL) << "TypeError: Expect an array of float or int, but gets: " << elem->GetTypeKey();
+    }
+  }
+  return results;
+}
+
+/*!
+ * \brief Convert the given object to an array of integers
+ * \param obj The object to be converted
+ * \return The array of integers
+ */
+inline Array<Integer> AsIntArray(const ObjectRef& obj) {
+  const ArrayNode* arr = obj.as<ArrayNode>();
+  ICHECK(arr) << "TypeError: Expect an array, but gets: " << obj->GetTypeKey();
+  Array<Integer> results;
+  results.reserve(arr->size());
+  for (const ObjectRef& elem : *arr) {
+    if (const auto* int_imm = elem.as<IntImmNode>()) {
+      results.push_back(Integer(int_imm->value));
+    } else {
+      LOG(FATAL) << "TypeError: Expect an array of integers, but gets: " << elem->GetTypeKey();
+    }
+  }
+  return results;
+}
+
 }  // namespace meta_schedule
 }  // namespace tvm
 
diff --git a/tests/python/unittest/test_meta_schedule_database.py b/tests/python/unittest/test_meta_schedule_database.py
index 1edfbe6c7a..ff0f350d89 100644
--- a/tests/python/unittest/test_meta_schedule_database.py
+++ b/tests/python/unittest/test_meta_schedule_database.py
@@ -17,20 +17,18 @@
 # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
 """Test Meta Schedule Database"""
 import os.path as osp
-import sys
 import tempfile
 from typing import Callable
 
-import pytest
 import tvm
 import tvm.testing
+from tvm import meta_schedule as ms
 from tvm import tir
 from tvm.ir.module import IRModule
-from tvm.meta_schedule.arg_info import ArgInfo
-from tvm.meta_schedule.database import JSONDatabase, TuningRecord
 from tvm.script import tir as T
 from tvm.tir import Schedule
 
+
 # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
 # fmt: off
 @tvm.script.ir_module
@@ -92,13 +90,13 @@ def _create_schedule(mod: IRModule, sch_fn: Callable[[Schedule], None]) -> Sched
     return sch
 
 
-def _create_tmp_database(tmpdir: str) -> JSONDatabase:
+def _create_tmp_database(tmpdir: str) -> ms.database.JSONDatabase:
     path_workload = osp.join(tmpdir, "workloads.json")
     path_tuning_record = osp.join(tmpdir, "tuning_records.json")
-    return JSONDatabase(path_workload, path_tuning_record)
+    return ms.database.JSONDatabase(path_workload, path_tuning_record)
 
 
-def _equal_record(a: TuningRecord, b: TuningRecord):
+def _equal_record(a: ms.database.TuningRecord, b: ms.database.TuningRecord):
     assert str(a.trace) == str(b.trace)
     assert str(a.run_secs) == str(b.run_secs)
     # AWAIT(@zxybazh): change to export after fixing "(bool)0"
@@ -113,15 +111,15 @@ def test_meta_schedule_tuning_record_round_trip():
     with tempfile.TemporaryDirectory() as tmpdir:
         database = _create_tmp_database(tmpdir)
         workload = database.commit_workload(mod)
-        record = TuningRecord(
+        record = ms.database.TuningRecord(
             _create_schedule(mod, _schedule_matmul).trace,
             workload,
             [1.5, 2.5, 1.8],
             tvm.target.Target("llvm"),
-            ArgInfo.from_prim_func(func=mod["main"]),  # pylint: disable=unsubscriptable-object
+            ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
         )
         database.commit_tuning_record(record)
-        new_record = TuningRecord.from_json(record.as_json(), workload)
+        new_record = ms.database.TuningRecord.from_json(record.as_json(), workload)
         _equal_record(record, new_record)
 
 
@@ -138,12 +136,12 @@ def test_meta_schedule_database_has_workload():
     with tempfile.TemporaryDirectory() as tmpdir:
         database = _create_tmp_database(tmpdir)
         workload = database.commit_workload(mod)
-        record = TuningRecord(
+        record = ms.database.TuningRecord(
             _create_schedule(mod, _schedule_matmul).trace,
             workload,
             [1.5, 2.5, 1.8],
             tvm.target.Target("llvm"),
-            ArgInfo.from_prim_func(func=mod["main"]),  # pylint: disable=unsubscriptable-object
+            ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
         )
         database.commit_tuning_record(record)
         assert len(database) == 1
@@ -156,12 +154,12 @@ def test_meta_schedule_database_add_entry():
     with tempfile.TemporaryDirectory() as tmpdir:
         database = _create_tmp_database(tmpdir)
         workload = database.commit_workload(mod)
-        record = TuningRecord(
+        record = ms.database.TuningRecord(
             _create_schedule(mod, _schedule_matmul).trace,
             workload,
             [1.5, 2.5, 1.8],
             tvm.target.Target("llvm"),
-            ArgInfo.from_prim_func(func=mod["main"]),  # pylint: disable=unsubscriptable-object
+            ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
         )
         database.commit_tuning_record(record)
         assert len(database) == 1
@@ -176,12 +174,12 @@ def test_meta_schedule_database_missing():
         database = _create_tmp_database(tmpdir)
         workload = database.commit_workload(mod)
         workload_2 = database.commit_workload(mod_2)
-        record = TuningRecord(
+        record = ms.database.TuningRecord(
             _create_schedule(mod, _schedule_matmul).trace,
             workload,
             [1.5, 2.5, 1.8],
             tvm.target.Target("llvm"),
-            ArgInfo.from_prim_func(func=mod["main"]),  # pylint: disable=unsubscriptable-object
+            ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
         )
         database.commit_tuning_record(record)
         ret = database.get_top_k(workload_2, 3)
@@ -195,47 +193,47 @@ def test_meta_schedule_database_sorting():
         token = database.commit_workload(mod)
         trace = _create_schedule(mod, _schedule_matmul).trace
         records = [
-            TuningRecord(
+            ms.database.TuningRecord(
                 trace,
                 token,
                 [7.0, 8.0, 9.0],
                 tvm.target.Target("llvm"),
-                ArgInfo.from_prim_func(func=mod["main"]),  # pylint: disable=unsubscriptable-object
+                ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
             ),
-            TuningRecord(
+            ms.database.TuningRecord(
                 trace,
                 token,
                 [1.0, 2.0, 3.0],
                 tvm.target.Target("llvm"),
-                ArgInfo.from_prim_func(func=mod["main"]),  # pylint: disable=unsubscriptable-object
+                ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
             ),
-            TuningRecord(
+            ms.database.TuningRecord(
                 trace,
                 token,
                 [4.0, 5.0, 6.0],
                 tvm.target.Target("llvm"),
-                ArgInfo.from_prim_func(func=mod["main"]),  # pylint: disable=unsubscriptable-object
+                ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
             ),
-            TuningRecord(
+            ms.database.TuningRecord(
                 trace,
                 token,
                 [1.1, 1.2, 600.0],
                 tvm.target.Target("llvm"),
-                ArgInfo.from_prim_func(func=mod["main"]),  # pylint: disable=unsubscriptable-object
+                ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
             ),
-            TuningRecord(
+            ms.database.TuningRecord(
                 trace,
                 token,
                 [1.0, 100.0, 6.0],
                 tvm.target.Target("llvm"),
-                ArgInfo.from_prim_func(func=mod["main"]),  # pylint: disable=unsubscriptable-object
+                ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
             ),
-            TuningRecord(
+            ms.database.TuningRecord(
                 trace,
                 token,
                 [4.0, 9.0, 8.0],
                 tvm.target.Target("llvm"),
-                ArgInfo.from_prim_func(func=mod["main"]),  # pylint: disable=unsubscriptable-object
+                ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
             ),
         ]
         for record in records:
@@ -257,31 +255,31 @@ def test_meta_schedule_database_reload():
         token = database.commit_workload(mod)
         trace = _create_schedule(mod, _schedule_matmul).trace
         records = [
-            TuningRecord(
+            ms.database.TuningRecord(
                 trace,
                 token,
                 [7.0, 8.0, 9.0],
                 tvm.target.Target("llvm"),
-                ArgInfo.from_prim_func(func=mod["main"]),  # pylint: disable=unsubscriptable-object
+                ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
             ),
-            TuningRecord(
+            ms.database.TuningRecord(
                 trace,
                 token,
                 [1.0, 2.0, 3.0],
                 tvm.target.Target("llvm"),
-                ArgInfo.from_prim_func(func=mod["main"]),  # pylint: disable=unsubscriptable-object
+                ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
             ),
-            TuningRecord(
+            ms.database.TuningRecord(
                 trace,
                 token,
                 [4.0, 5.0, 6.0],
                 tvm.target.Target("llvm"),
-                ArgInfo.from_prim_func(func=mod["main"]),  # pylint: disable=unsubscriptable-object
+                ms.arg_info.ArgInfo.from_prim_func(func=mod["main"]),
             ),
         ]
         for record in records:
             database.commit_tuning_record(record)
-        new_database = JSONDatabase(  # pylint: disable=unused-variable
+        new_database = ms.database.JSONDatabase(
             path_workload=database.path_workload,
             path_tuning_record=database.path_tuning_record,
         )