You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by jc...@apache.org on 2021/06/10 03:14:18 UTC
[tvm] branch main updated: Add metadata information to the listing of PassContext configuration listing function (#8226)
This is an automated email from the ASF dual-hosted git repository.
jcf94 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a468f08 Add metadata information to the listing of PassContext configuration listing function (#8226)
a468f08 is described below
commit a468f08d77ae8bc0dfd492cf5adfcafd026090aa
Author: Leandro Nunes <le...@arm.com>
AuthorDate: Thu Jun 10 04:14:05 2021 +0100
Add metadata information to the listing of PassContext configuration listing function (#8226)
* Rename PassContext::ListConfigNames() to PassContext::ListConfigs(), and its
Python counterpart tvm.ir.transform.PassContext.list_config_names() to list_configs()
* Adjust PassContext::ListConfigs() to also include metadata for each config (currently only its type, e.g. "IntImm")
* Adjust unit tests
---
include/tvm/ir/transform.h | 6 +++---
python/tvm/ir/transform.py | 12 +++++++++---
src/ir/transform.cc | 16 +++++++++-------
tests/cpp/relay_transform_sequential_test.cc | 9 +++++----
tests/python/relay/test_pass_instrument.py | 7 ++++---
5 files changed, 30 insertions(+), 20 deletions(-)
diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index d5b50a7..cb556fc 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -184,10 +184,10 @@ class PassContext : public ObjectRef {
TVM_DLL static PassContext Current();
/*!
- * \brief Get all supported configuration names, registered within the PassContext.
- * \return List of all configuration names.
+ * \brief Get all supported configuration names and metadata, registered within the PassContext.
+ * \return Map indexed by the config name, pointing to the metadata map as key-value
*/
- TVM_DLL static Array<String> ListConfigNames();
+ TVM_DLL static Map<String, Map<String, String>> ListConfigs();
/*!
* \brief Call instrument implementations' callbacks when entering PassContext.
diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py
index 7a0ea82..9296244 100644
--- a/python/tvm/ir/transform.py
+++ b/python/tvm/ir/transform.py
@@ -121,9 +121,15 @@ class PassContext(tvm.runtime.Object):
return _ffi_transform_api.GetCurrentPassContext()
@staticmethod
- def list_config_names():
- """List all registered `PassContext` configuration names"""
- return list(_ffi_transform_api.ListConfigNames())
+ def list_configs():
+ """List all registered `PassContext` configuration names and metadata.
+
+ Returns
+ -------
+ configs : Dict[str, Dict[str, str]]
+
+ """
+ return _ffi_transform_api.ListConfigs()
@tvm._ffi.register_object("transform.Pass")
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index a8541b1..8120ca7 100644
--- a/src/ir/transform.cc
+++ b/src/ir/transform.cc
@@ -145,12 +145,14 @@ class PassConfigManager {
}
}
- Array<String> ListConfigNames() {
- Array<String> config_keys;
+ Map<String, Map<String, String>> ListConfigs() {
+ Map<String, Map<String, String>> configs;
for (const auto& kv : key2vtype_) {
- config_keys.push_back(kv.first);
+ Map<String, String> metadata;
+ metadata.Set("type", kv.second.type_key);
+ configs.Set(kv.first, metadata);
}
- return config_keys;
+ return configs;
}
static PassConfigManager* Global() {
@@ -171,8 +173,8 @@ void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_inde
PassConfigManager::Global()->Register(key, value_type_index);
}
-Array<String> PassContext::ListConfigNames() {
- return PassConfigManager::Global()->ListConfigNames();
+Map<String, Map<String, String>> PassContext::ListConfigs() {
+ return PassConfigManager::Global()->ListConfigs();
}
PassContext PassContext::Create() { return PassContext(make_object<PassContextNode>()); }
@@ -619,7 +621,7 @@ Pass PrintIR(String header, bool show_meta_data) {
TVM_REGISTER_GLOBAL("transform.PrintIR").set_body_typed(PrintIR);
-TVM_REGISTER_GLOBAL("transform.ListConfigNames").set_body_typed(PassContext::ListConfigNames);
+TVM_REGISTER_GLOBAL("transform.ListConfigs").set_body_typed(PassContext::ListConfigs);
} // namespace transform
} // namespace tvm
diff --git a/tests/cpp/relay_transform_sequential_test.cc b/tests/cpp/relay_transform_sequential_test.cc
index 16e9438..6d38e10 100644
--- a/tests/cpp/relay_transform_sequential_test.cc
+++ b/tests/cpp/relay_transform_sequential_test.cc
@@ -121,11 +121,12 @@ TEST(Relay, Sequential) {
ICHECK(tvm::StructuralEqual()(f, expected));
}
-TEST(PassContextListConfigNames, Basic) {
- Array<String> configs = relay::transform::PassContext::ListConfigNames();
+TEST(PassContextListConfigs, Basic) {
+ Map<String, Map<String, String>> configs = relay::transform::PassContext::ListConfigs();
ICHECK_EQ(configs.empty(), false);
- ICHECK_EQ(std::count(std::begin(configs), std::end(configs), "relay.backend.use_auto_scheduler"),
- 1);
+
+ auto config = configs["relay.backend.use_auto_scheduler"];
+ ICHECK_EQ(config["type"], "IntImm");
}
int main(int argc, char** argv) {
diff --git a/tests/python/relay/test_pass_instrument.py b/tests/python/relay/test_pass_instrument.py
index c7405ae..610d4e4 100644
--- a/tests/python/relay/test_pass_instrument.py
+++ b/tests/python/relay/test_pass_instrument.py
@@ -183,10 +183,11 @@ def test_instrument_pass_counts():
def test_list_pass_configs():
- config_names = tvm.transform.PassContext.list_config_names()
+ configs = tvm.transform.PassContext.list_configs()
- assert len(config_names) > 0
- assert "relay.backend.use_auto_scheduler" in config_names
+ assert len(configs) > 0
+ assert "relay.backend.use_auto_scheduler" in configs.keys()
+ assert configs["relay.backend.use_auto_scheduler"]["type"] == "IntImm"
def test_enter_pass_ctx_exception():