You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by jc...@apache.org on 2021/06/10 03:14:18 UTC

[tvm] branch main updated: Add metadata information to the listing of PassContext configuration listing function (#8226)

This is an automated email from the ASF dual-hosted git repository.

jcf94 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a468f08  Add metadata information to the listing of PassContext configuration listing function (#8226)
a468f08 is described below

commit a468f08d77ae8bc0dfd492cf5adfcafd026090aa
Author: Leandro Nunes <le...@arm.com>
AuthorDate: Thu Jun 10 04:14:05 2021 +0100

    Add metadata information to the listing of PassContext configuration listing function (#8226)
    
    * Rename PassContext::ListConfigNames() to PassContext::ListConfigs() and its
       Python counterpart tvm.ir.transform.PassContext.list_config_names -> list_configs()
    
     * Adjust PassContext::ListConfigs() to also include metadata (currently only the data type)
    
     * Adjust unit tests
---
 include/tvm/ir/transform.h                   |  6 +++---
 python/tvm/ir/transform.py                   | 12 +++++++++---
 src/ir/transform.cc                          | 16 +++++++++-------
 tests/cpp/relay_transform_sequential_test.cc |  9 +++++----
 tests/python/relay/test_pass_instrument.py   |  7 ++++---
 5 files changed, 30 insertions(+), 20 deletions(-)

diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index d5b50a7..cb556fc 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -184,10 +184,10 @@ class PassContext : public ObjectRef {
   TVM_DLL static PassContext Current();
 
   /*!
-   * \brief Get all supported configuration names, registered within the PassContext.
-   * \return List of all configuration names.
+   * \brief Get all supported configuration names and metadata, registered within the PassContext.
+   * \return Map from each config name to its metadata, which is itself a map of key-value pairs.
    */
-  TVM_DLL static Array<String> ListConfigNames();
+  TVM_DLL static Map<String, Map<String, String>> ListConfigs();
 
   /*!
    * \brief Call instrument implementations' callbacks when entering PassContext.
diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py
index 7a0ea82..9296244 100644
--- a/python/tvm/ir/transform.py
+++ b/python/tvm/ir/transform.py
@@ -121,9 +121,15 @@ class PassContext(tvm.runtime.Object):
         return _ffi_transform_api.GetCurrentPassContext()
 
     @staticmethod
-    def list_config_names():
-        """List all registered `PassContext` configuration names"""
-        return list(_ffi_transform_api.ListConfigNames())
+    def list_configs():
+        """List all registered `PassContext` configuration names and metadata.
+
+        Returns
+        -------
+        configs : Dict[str, Dict[str, str]]
+
+        """
+        return _ffi_transform_api.ListConfigs()
 
 
 @tvm._ffi.register_object("transform.Pass")
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index a8541b1..8120ca7 100644
--- a/src/ir/transform.cc
+++ b/src/ir/transform.cc
@@ -145,12 +145,14 @@ class PassConfigManager {
     }
   }
 
-  Array<String> ListConfigNames() {
-    Array<String> config_keys;
+  Map<String, Map<String, String>> ListConfigs() {
+    Map<String, Map<String, String>> configs;
     for (const auto& kv : key2vtype_) {
-      config_keys.push_back(kv.first);
+      Map<String, String> metadata;
+      metadata.Set("type", kv.second.type_key);
+      configs.Set(kv.first, metadata);
     }
-    return config_keys;
+    return configs;
   }
 
   static PassConfigManager* Global() {
@@ -171,8 +173,8 @@ void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_inde
   PassConfigManager::Global()->Register(key, value_type_index);
 }
 
-Array<String> PassContext::ListConfigNames() {
-  return PassConfigManager::Global()->ListConfigNames();
+Map<String, Map<String, String>> PassContext::ListConfigs() {
+  return PassConfigManager::Global()->ListConfigs();
 }
 
 PassContext PassContext::Create() { return PassContext(make_object<PassContextNode>()); }
@@ -619,7 +621,7 @@ Pass PrintIR(String header, bool show_meta_data) {
 
 TVM_REGISTER_GLOBAL("transform.PrintIR").set_body_typed(PrintIR);
 
-TVM_REGISTER_GLOBAL("transform.ListConfigNames").set_body_typed(PassContext::ListConfigNames);
+TVM_REGISTER_GLOBAL("transform.ListConfigs").set_body_typed(PassContext::ListConfigs);
 
 }  // namespace transform
 }  // namespace tvm
diff --git a/tests/cpp/relay_transform_sequential_test.cc b/tests/cpp/relay_transform_sequential_test.cc
index 16e9438..6d38e10 100644
--- a/tests/cpp/relay_transform_sequential_test.cc
+++ b/tests/cpp/relay_transform_sequential_test.cc
@@ -121,11 +121,12 @@ TEST(Relay, Sequential) {
   ICHECK(tvm::StructuralEqual()(f, expected));
 }
 
-TEST(PassContextListConfigNames, Basic) {
-  Array<String> configs = relay::transform::PassContext::ListConfigNames();
+TEST(PassContextListConfigs, Basic) {
+  Map<String, Map<String, String>> configs = relay::transform::PassContext::ListConfigs();
   ICHECK_EQ(configs.empty(), false);
-  ICHECK_EQ(std::count(std::begin(configs), std::end(configs), "relay.backend.use_auto_scheduler"),
-            1);
+
+  auto config = configs["relay.backend.use_auto_scheduler"];
+  ICHECK_EQ(config["type"], "IntImm");
 }
 
 int main(int argc, char** argv) {
diff --git a/tests/python/relay/test_pass_instrument.py b/tests/python/relay/test_pass_instrument.py
index c7405ae..610d4e4 100644
--- a/tests/python/relay/test_pass_instrument.py
+++ b/tests/python/relay/test_pass_instrument.py
@@ -183,10 +183,11 @@ def test_instrument_pass_counts():
 
 
 def test_list_pass_configs():
-    config_names = tvm.transform.PassContext.list_config_names()
+    configs = tvm.transform.PassContext.list_configs()
 
-    assert len(config_names) > 0
-    assert "relay.backend.use_auto_scheduler" in config_names
+    assert len(configs) > 0
+    assert "relay.backend.use_auto_scheduler" in configs.keys()
+    assert configs["relay.backend.use_auto_scheduler"]["type"] == "IntImm"
 
 
 def test_enter_pass_ctx_exception():