You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/10/28 18:51:46 UTC

[incubator-tvm] branch main updated: [API] Added remove_global_func to the Python API (#6787)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ad92efd  [API] Added remove_global_func to the Python API (#6787)
ad92efd is described below

commit ad92efdca0beab5939ef7cc4c82af950269ec0d1
Author: mbaret <55...@users.noreply.github.com>
AuthorDate: Wed Oct 28 18:51:28 2020 +0000

    [API] Added remove_global_func to the Python API (#6787)
    
    This is useful for unregistering functions after a test.
    
    Change-Id: Ic39499aa8f36bfe5470bc1f058ad3b96cf52b49c
---
 include/tvm/runtime/c_runtime_api.h |  6 ++++++
 python/tvm/_ffi/registry.py         | 11 +++++++++++
 src/runtime/registry.cc             |  6 ++++++
 3 files changed, 23 insertions(+)

diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h
index e25394a..aac49c1 100644
--- a/include/tvm/runtime/c_runtime_api.h
+++ b/include/tvm/runtime/c_runtime_api.h
@@ -366,6 +366,12 @@ TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out);
  */
 TVM_DLL int TVMFuncListGlobalNames(int* out_size, const char*** out_array);
 
+/*!
+ * \brief Remove a function from the global function registry.
+ * \param name The name the function was registered under.
+ * \return Status code; nonzero presumably means failure (NOTE(review): confirm). */
+TVM_DLL int TVMFuncRemoveGlobal(const char* name);
+
 // Array related apis for quick proptyping
 /*!
  * \brief Allocate a nd-array's memory,
diff --git a/python/tvm/_ffi/registry.py b/python/tvm/_ffi/registry.py
index 6637cd1..677ca5d 100644
--- a/python/tvm/_ffi/registry.py
+++ b/python/tvm/_ffi/registry.py
@@ -262,6 +262,17 @@ def extract_ext_funcs(finit):
     return fdict
 
 
+def remove_global_func(name):
+    """Remove a global function by name via the TVMFuncRemoveGlobal C API.
+
+    Parameters
+    ----------
+    name : str
+        The name the global function was registered under.
+    """
+    check_call(_LIB.TVMFuncRemoveGlobal(c_str(name)))  # NOTE(review): may not raise for unknown names
+
+
 def _get_api(f):
     flocal = f
     flocal.is_global = True
diff --git a/src/runtime/registry.cc b/src/runtime/registry.cc
index 6e74dc3..a652350 100644
--- a/src/runtime/registry.cc
+++ b/src/runtime/registry.cc
@@ -146,3 +146,9 @@ int TVMFuncListGlobalNames(int* out_size, const char*** out_array) {
   *out_size = static_cast<int>(ret->ret_vec_str.size());
   API_END();
 }
+
+int TVMFuncRemoveGlobal(const char* name) {  // C entry point used by Python remove_global_func
+  API_BEGIN();  // NOTE(review): presumably converts C++ exceptions into a nonzero status
+  tvm::runtime::Registry::Remove(name);  // drop `name` from the global function registry
+  API_END();  // NOTE(review): Remove's result is not checked; confirm unknown names are OK
+}