Posted to commits@tvm.apache.org by jr...@apache.org on 2020/11/06 02:53:02 UTC

[incubator-tvm] 01/02: Debug segfault from loading Python

This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit 35d49bf279a07aaa92e255bb5d48ae485bda3cdf
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Sun Oct 25 17:26:47 2020 -0700

    Debug segfault from loading Python
---
 python/tvm/__init__.py                            |  2 ++
 python/tvm/relay/__init__.py                      |  3 +-
 python/tvm/relay/analysis/__init__.py             |  2 +-
 python/tvm/relay/analysis/analysis.py             |  6 ++--
 python/tvm/relay/analysis/annotated_regions.py    |  2 +-
 python/tvm/relay/analysis/call_graph.py           |  4 +--
 python/tvm/relay/analysis/sparse_dense.py         | 15 ++++----
 python/tvm/relay/backend/graph_runtime_factory.py |  2 +-
 python/tvm/relay/build_module.py                  |  5 ++-
 python/tvm/relay/op/op.py                         | 43 +++++++++++------------
 python/tvm/relay/transform/__init__.py            |  2 +-
 python/tvm/relay/transform/memory_alloc.py        |  7 ++--
 python/tvm/relay/transform/transform.py           |  5 +--
 python/tvm/topi/cuda/__init__.py                  |  2 --
 python/tvm/topi/cuda/sparse.py                    |  3 +-
 rust/tvm-rt/src/map.rs                            | 12 +++++++
 rust/tvm-rt/src/module.rs                         | 16 +++++++++
 rust/tvm-rt/src/to_function.rs                    |  1 +
 rust/tvm/Cargo.toml                               |  2 +-
 rust/tvm/src/python.rs                            | 21 ++++++++---
 src/runtime/module.cc                             |  2 +-
 21 files changed, 101 insertions(+), 56 deletions(-)

diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py
index 569e8f0..60f81f4 100644
--- a/python/tvm/__init__.py
+++ b/python/tvm/__init__.py
@@ -67,6 +67,8 @@ from . import support
 # Contrib initializers
 from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel
 
+def cleanup():
+    _ffi.base._LIB = None
 
 def tvm_wrap_excepthook(exception_hook):
     """Wrap given excepthook with TVM additional work."""
diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index cd96ecc..7e6ed4f 100644
--- a/python/tvm/relay/__init__.py
+++ b/python/tvm/relay/__init__.py
@@ -60,8 +60,7 @@ from . import qnn
 from .scope_builder import ScopeBuilder
 
 # Load Memory Passes
-from .transform import memory_alloc
-from .transform import memory_plan
+from .transform import memory_alloc, memory_plan
 
 # Required to traverse large programs
 setrecursionlimit(10000)
diff --git a/python/tvm/relay/analysis/__init__.py b/python/tvm/relay/analysis/__init__.py
index b4ea7f3..4ea4de7 100644
--- a/python/tvm/relay/analysis/__init__.py
+++ b/python/tvm/relay/analysis/__init__.py
@@ -26,7 +26,7 @@ from .annotated_regions import AnnotatedRegionSet
 from . import call_graph
 from .call_graph import CallGraph
 
-# Feature
+# # Feature
 from . import feature
 from . import sparse_dense
 
diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py
index 7e49461..48e9ce0 100644
--- a/python/tvm/relay/analysis/analysis.py
+++ b/python/tvm/relay/analysis/analysis.py
@@ -20,9 +20,9 @@
 This file contains the set of passes for Relay, which exposes an interface for
 configuring the passes and scripting them in Python.
 """
-from tvm.ir import IRModule
-from tvm.relay import transform, build_module
-from tvm.runtime.ndarray import cpu
+from ...ir import IRModule
+from ...relay import transform, build_module
+from ...runtime.ndarray import cpu
 
 from . import _ffi_api
 from .feature import Feature
diff --git a/python/tvm/relay/analysis/annotated_regions.py b/python/tvm/relay/analysis/annotated_regions.py
index 437b97b..a18ccb9 100644
--- a/python/tvm/relay/analysis/annotated_regions.py
+++ b/python/tvm/relay/analysis/annotated_regions.py
@@ -17,7 +17,7 @@
 # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import
 """Regions used in Relay."""
 
-from tvm.runtime import Object
+from ...runtime import Object
 from . import _ffi_api
 
 
diff --git a/python/tvm/relay/analysis/call_graph.py b/python/tvm/relay/analysis/call_graph.py
index 966659a..fd9704d 100644
--- a/python/tvm/relay/analysis/call_graph.py
+++ b/python/tvm/relay/analysis/call_graph.py
@@ -17,8 +17,8 @@
 # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import
 """Call graph used in Relay."""
 
-from tvm.ir import IRModule
-from tvm.runtime import Object
+from ...ir import IRModule
+from ...runtime import Object
 from ..expr import GlobalVar
 from . import _ffi_api
 
diff --git a/python/tvm/relay/analysis/sparse_dense.py b/python/tvm/relay/analysis/sparse_dense.py
index d521748..51fab34 100644
--- a/python/tvm/relay/analysis/sparse_dense.py
+++ b/python/tvm/relay/analysis/sparse_dense.py
@@ -22,8 +22,8 @@ to block sparse model
 """
 from collections import namedtuple
 import numpy as np
-import scipy.sparse as sp
-import tvm
+
+from ... import nd, runtime
 from . import _ffi_api
 
 
@@ -73,6 +73,7 @@ def process_params(expr, params, block_size, sparsity_threshold):
     ret : Namedtuple[weight_name: Array[String], weight_shape: Array[Array[IntImm]]]
         return names of qualified dense weight and the shape in BSR format
     """
+    import scipy.sparse as sp
     memo = SparseAnalysisResult(weight_name=[], weight_shape=[])
     weight_names = _search_dense_op_weight(expr)
     for name in weight_names:
@@ -89,11 +90,11 @@ def process_params(expr, params, block_size, sparsity_threshold):
                 + list(sparse_weight.indices.shape)
                 + list(sparse_weight.indptr.shape)
             )
-            params[name + ".data"] = tvm.nd.array(sparse_weight.data)
-            params[name + ".indices"] = tvm.nd.array(sparse_weight.indices)
-            params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr)
+            params[name + ".data"] = nd.array(sparse_weight.data)
+            params[name + ".indices"] = nd.array(sparse_weight.indices)
+            params[name + ".indptr"] = nd.array(sparse_weight.indptr)
     ret = SparseAnalysisResult(
-        weight_name=tvm.runtime.convert(memo.weight_name),
-        weight_shape=tvm.runtime.convert(memo.weight_shape),
+        weight_name=runtime.convert(memo.weight_name),
+        weight_shape=runtime.convert(memo.weight_shape),
     )
     return ret
diff --git a/python/tvm/relay/backend/graph_runtime_factory.py b/python/tvm/relay/backend/graph_runtime_factory.py
index 4c6ac47..3427a62 100644
--- a/python/tvm/relay/backend/graph_runtime_factory.py
+++ b/python/tvm/relay/backend/graph_runtime_factory.py
@@ -21,7 +21,7 @@ from tvm._ffi.registry import get_global_func
 from tvm.runtime import ndarray
 
 
-class GraphRuntimeFactoryModule(object):
+class GraphRuntimeFactoryModule:
     """Graph runtime factory module.
     This is a module of graph runtime factory
 
diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index 35bd8e6..7e32dea 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -24,7 +24,7 @@ import numpy as np
 from tvm.ir import IRModule
 
 from tvm.tir import expr as tvm_expr
-from .. import nd as _nd, autotvm
+from .. import nd as _nd, autotvm, register_func
 from ..target import Target
 from ..contrib import graph_runtime as _graph_rt
 from . import _build_module
@@ -186,6 +186,9 @@ class BuildModule(object):
             ret[key] = value.data
         return ret
 
+@register_func("tvm.relay.build")
+def _rust_build_module(mod, target=None, target_host=None, params=None, mod_name="default"):
+    return build(mod, target, target_host, params, mod_name).module
 
 def build(mod, target=None, target_host=None, params=None, mod_name="default"):
     """Helper function that builds a Relay function to run on TVM graph
diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py
index fa420c4..b8c1d69 100644
--- a/python/tvm/relay/op/op.py
+++ b/python/tvm/relay/op/op.py
@@ -16,12 +16,11 @@
 # under the License.
 # pylint: disable=unused-argument,invalid-name
 """The base node types for the Relay language."""
-import tvm._ffi
-import tvm.ir
+from ... import _ffi, ir
 from tvm.auto_scheduler.relay_integration import auto_schedule_topi, auto_schedule_impl_suffix
-from tvm.driver import lower, build
-from tvm.target import get_native_generic_func, GenericFunc
-from tvm.runtime import Object
+from ...driver import lower, build
+from ...target import get_native_generic_func, GenericFunc
+from ...runtime import Object
 from . import _make
 
 
@@ -38,7 +37,7 @@ def get(op_name):
     op : Op
         The op of the corresponding name
     """
-    return tvm.ir.Op.get(op_name)
+    return ir.Op.get(op_name)
 
 
 class OpPattern(object):
@@ -65,7 +64,7 @@ class OpPattern(object):
     OPAQUE = 8
 
 
-@tvm._ffi.register_object("relay.OpImplementation")
+@_ffi.register_object("relay.OpImplementation")
 class OpImplementation(Object):
     """Operator implementation"""
 
@@ -112,12 +111,12 @@ class OpImplementation(Object):
         return _OpImplementationSchedule(self, attrs, outs, target)
 
 
-@tvm._ffi.register_object("relay.OpSpecialization")
+@_ffi.register_object("relay.OpSpecialization")
 class OpSpecialization(Object):
     """Operator specialization"""
 
 
-@tvm._ffi.register_object("relay.OpStrategy")
+@_ffi.register_object("relay.OpStrategy")
 class OpStrategy(Object):
     """Operator strategy"""
 
@@ -208,7 +207,7 @@ def register_compute(op_name, compute=None, level=10):
     level : int
         The priority level
     """
-    return tvm.ir.register_op_attr(op_name, "FTVMCompute", compute, level)
+    return ir.register_op_attr(op_name, "FTVMCompute", compute, level)
 
 
 def register_strategy(op_name, fstrategy=None, level=10):
@@ -229,7 +228,7 @@ def register_strategy(op_name, fstrategy=None, level=10):
     if not isinstance(fstrategy, GenericFunc):
         assert hasattr(fstrategy, "generic_func_node")
         fstrategy = fstrategy.generic_func_node
-    return tvm.ir.register_op_attr(op_name, "FTVMStrategy", fstrategy, level)
+    return ir.register_op_attr(op_name, "FTVMStrategy", fstrategy, level)
 
 
 def register_schedule(op_name, schedule, level=10):
@@ -310,7 +309,7 @@ def register_alter_op_layout(op_name, alter_layout=None, level=10):
     level : int
         The priority level
     """
-    return tvm.ir.register_op_attr(op_name, "FTVMAlterOpLayout", alter_layout, level)
+    return ir.register_op_attr(op_name, "FTVMAlterOpLayout", alter_layout, level)
 
 
 def register_convert_op_layout(op_name, convert_layout=None, level=10):
@@ -327,7 +326,7 @@ def register_convert_op_layout(op_name, convert_layout=None, level=10):
     level : int
         The priority level
     """
-    return tvm.ir.register_op_attr(op_name, "FTVMConvertOpLayout", convert_layout, level)
+    return ir.register_op_attr(op_name, "FTVMConvertOpLayout", convert_layout, level)
 
 
 def register_legalize(op_name, legal_op=None, level=10):
@@ -344,7 +343,7 @@ def register_legalize(op_name, legal_op=None, level=10):
     level : int
         The priority level
     """
-    return tvm.ir.register_op_attr(op_name, "FTVMLegalize", legal_op, level)
+    return ir.register_op_attr(op_name, "FTVMLegalize", legal_op, level)
 
 
 def register_pattern(op_name, pattern, level=10):
@@ -361,7 +360,7 @@ def register_pattern(op_name, pattern, level=10):
     level : int
         The priority level
     """
-    return tvm.ir.register_op_attr(op_name, "TOpPattern", pattern, level)
+    return ir.register_op_attr(op_name, "TOpPattern", pattern, level)
 
 
 def register_gradient(op_name, fgradient=None, level=10):
@@ -378,7 +377,7 @@ def register_gradient(op_name, fgradient=None, level=10):
     level : int
         The priority level
     """
-    return tvm.ir.register_op_attr(op_name, "FPrimalGradient", fgradient, level)
+    return ir.register_op_attr(op_name, "FPrimalGradient", fgradient, level)
 
 
 def register_shape_func(op_name, data_dependant, shape_func=None, level=10):
@@ -400,7 +399,7 @@ def register_shape_func(op_name, data_dependant, shape_func=None, level=10):
         The priority level
     """
     get(op_name).set_attr("TShapeDataDependant", data_dependant, level)
-    return tvm.ir.register_op_attr(op_name, "FShapeFunc", shape_func, level)
+    return ir.register_op_attr(op_name, "FShapeFunc", shape_func, level)
 
 
 def register_external_compiler(op_name, fexternal=None, level=10):
@@ -419,15 +418,15 @@ def register_external_compiler(op_name, fexternal=None, level=10):
     level : int
         The priority level
     """
-    return tvm.ir.register_op_attr(op_name, "FTVMExternalCompiler", fexternal, level)
+    return ir.register_op_attr(op_name, "FTVMExternalCompiler", fexternal, level)
 
 
-@tvm._ffi.register_func("relay.op.compiler._lower")
+_ffi.register_func("relay.op.compiler._lower")
 def _lower(name, schedule, inputs, outputs):
     return lower(schedule, list(inputs) + list(outputs), name=name)
 
 
-@tvm._ffi.register_func("relay.op.compiler._build")
+_ffi.register_func("relay.op.compiler._build")
 def _build(lowered_funcs):
     return build(lowered_funcs, target="llvm")
 
@@ -444,7 +443,7 @@ def debug(expr, debug_func=None):
 
     if debug_func:
         name = "debugger_func{}".format(__DEBUG_COUNTER__)
-        tvm._ffi.register_func(name, debug_func)
+        _ffi.register_func(name, debug_func)
         __DEBUG_COUNTER__ += 1
     else:
         name = ""
@@ -452,4 +451,4 @@ def debug(expr, debug_func=None):
     return _make.debug(expr, name)
 
 
-tvm._ffi._init_api("relay.op", __name__)
+_ffi._init_api("relay.op", __name__)
diff --git a/python/tvm/relay/transform/__init__.py b/python/tvm/relay/transform/__init__.py
index 1d0ea17..9684e42 100644
--- a/python/tvm/relay/transform/__init__.py
+++ b/python/tvm/relay/transform/__init__.py
@@ -19,4 +19,4 @@
 # transformation passes
 from .transform import *
 from .recast import recast
-from . import memory_alloc
+# from . import memory_alloc
diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py
index 66528c8..593a411 100644
--- a/python/tvm/relay/transform/memory_alloc.py
+++ b/python/tvm/relay/transform/memory_alloc.py
@@ -20,14 +20,13 @@ A pass for manifesting explicit memory allocations.
 """
 import numpy as np
 
-from tvm.ir.transform import PassContext, module_pass
-from tvm.relay.transform import InferType
-from tvm import nd, container
+from ... import DataType, register_func, nd, container, cpu
+from ...ir.transform import PassContext, module_pass
+from . import InferType
 from ..function import Function
 from ..expr_functor import ExprVisitor, ExprMutator
 from ..scope_builder import ScopeBuilder
 from .. import op
-from ... import DataType, register_func
 from .. import ty, expr
 from ..backend import compile_engine
 from ..op.memory import flatten_tuple_type, from_tuple_type, to_tuple_type
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index 4907a0b..3b01182 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -23,11 +23,12 @@ import inspect
 import functools
 import warnings
 
+from ...ir import transform as tvm_transform
 import tvm.ir
 from tvm import te
 from tvm.runtime import ndarray as _nd
 
-from tvm import relay
+# from tvm import relay
 from . import _ffi_api
 
 
@@ -82,7 +83,7 @@ def build_config(opt_level=2, required_pass=None, disabled_pass=None, trace=None
 
 
 @tvm._ffi.register_object("relay.FunctionPass")
-class FunctionPass(tvm.ir.transform.Pass):
+class FunctionPass():
     """A pass that works on each tvm.relay.Function in a module. A function
     pass class should be created through `function_pass`.
     """
diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py
index 3ff544f..47badb5 100644
--- a/python/tvm/topi/cuda/__init__.py
+++ b/python/tvm/topi/cuda/__init__.py
@@ -17,8 +17,6 @@
 
 # pylint: disable=redefined-builtin, wildcard-import
 """CUDA specific declaration and schedules."""
-from __future__ import absolute_import as _abs
-
 from .conv1d import *
 from .conv1d_transpose_ncw import *
 from .conv2d import *
diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py
index ebac551..50f6ae8 100644
--- a/python/tvm/topi/cuda/sparse.py
+++ b/python/tvm/topi/cuda/sparse.py
@@ -17,7 +17,6 @@
 
 """Sparse operators"""
 import numpy as np
-import scipy.sparse as sp
 
 import tvm
 from tvm import relay, te
@@ -326,6 +325,7 @@ def schedule_sparse_dense_padded(outs):
 
 def pad_sparse_matrix(matrix, blocksize):
     """Pad rows of sparse matrix matrix so that they are a multiple of blocksize."""
+    import scipy.sparse as sp
     assert isinstance(matrix, sp.bsr_matrix)
     new_entries = np.zeros(matrix.shape[0], dtype=matrix.indptr.dtype)
     bsr = matrix.blocksize[0]
@@ -362,6 +362,7 @@ def _alter_sparse_dense_layout(_attrs, inputs, _tinfos, _out_type):
     sparse_dense implementation for one that operates on a padded matrix. We
     also pad the matrix.
     """
+    import scipy.sparse as sp
     if (
         isinstance(inputs[1], relay.Constant)
         and isinstance(inputs[2], relay.Constant)
diff --git a/rust/tvm-rt/src/map.rs b/rust/tvm-rt/src/map.rs
index b8bfb4e..5df9040 100644
--- a/rust/tvm-rt/src/map.rs
+++ b/rust/tvm-rt/src/map.rs
@@ -107,6 +107,18 @@ where
         let oref: ObjectRef = map_get_item(self.object.clone(), key.upcast())?;
         oref.downcast()
     }
+
+    pub fn empty() -> Self {
+        Self::from_iter(vec![].into_iter())
+    }
+
+    //(@jroesch): I don't think this is a correct implementation.
+    pub fn null() -> Self {
+        Map {
+            object: ObjectRef::null(),
+            _data: PhantomData,
+        }
+    }
 }
 
 pub struct IntoIter<K, V> {
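
A minimal sketch of the two constructors added above, using String keys and
ObjectRef values as illustrative type parameters and assuming the usual
tvm_rt crate-root re-exports; note the author's own caveat that null() may
not be a correct implementation.

    use tvm_rt::{Map, ObjectRef, String as TString};

    fn make_maps() {
        // A real, zero-length map backed by an allocated map object.
        let empty: Map<TString, ObjectRef> = Map::empty();
        // A map wrapping a null ObjectRef; see the caveat above.
        let null_map: Map<TString, ObjectRef> = Map::null();
        let _ = (empty, null_map);
    }
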
diff --git a/rust/tvm-rt/src/module.rs b/rust/tvm-rt/src/module.rs
index c0822a5..18347da 100644
--- a/rust/tvm-rt/src/module.rs
+++ b/rust/tvm-rt/src/module.rs
@@ -30,6 +30,8 @@ use tvm_sys::ffi;
 
 use crate::errors::Error;
 use crate::{errors, function::Function};
+use crate::{String as TString};
+use crate::RetValue;
 
 const ENTRY_FUNC: &str = "__tvm_main__";
 
@@ -49,6 +51,9 @@ crate::external! {
 
     #[name("runtime.ModuleLoadFromFile")]
     fn load_from_file(file_name: CString, format: CString) -> Module;
+
+    #[name("runtime.ModuleSaveToFile")]
+    fn save_to_file(module: ffi::TVMModuleHandle, name: TString, fmt: TString);
 }
 
 impl Module {
@@ -110,6 +115,10 @@ impl Module {
         Ok(module)
     }
 
+    pub fn save_to_file(&self, name: String, fmt: String) -> Result<(), Error> {
+        save_to_file(self.handle(), name.into(), fmt.into())
+    }
+
     /// Checks if a target device is enabled for a module.
     pub fn enabled(&self, target: &str) -> bool {
         let target = CString::new(target).unwrap();
@@ -128,3 +137,10 @@ impl Drop for Module {
         check_call!(ffi::TVMModFree(self.handle));
     }
 }
+
+// impl std::convert::TryFrom<RetValue> for Module {
+//     type Error = Error;
+//     fn try_from(ret_value: RetValue) -> Result<Module, Self::Error> {
+//         Ok(Module::new(ret_value.try_into()?))
+//     }
+// }
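
A minimal sketch of the new save_to_file binding, assuming a crate-root
Module re-export and an already-built LLVM module; "o" is one of the formats
SaveToFile accepts for LLVM modules, and the module's origin is illustrative.

    use tvm_rt::{errors::Error, Module};

    // Persist a compiled module as an object file through the new
    // runtime.ModuleSaveToFile binding.
    fn export(module: &Module) -> Result<(), Error> {
        module.save_to_file("deploy.o".to_string(), "o".to_string())
    }
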
diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs
index affd81b..c5ede7d 100644
--- a/rust/tvm-rt/src/to_function.rs
+++ b/rust/tvm-rt/src/to_function.rs
@@ -255,6 +255,7 @@ impl_typed_and_to_function!(2; A, B);
 impl_typed_and_to_function!(3; A, B, C);
 impl_typed_and_to_function!(4; A, B, C, D);
 impl_typed_and_to_function!(5; A, B, C, D, E);
+impl_typed_and_to_function!(6; A, B, C, D, E, G);
 
 #[cfg(test)]
 mod tests {
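
With the arity-6 expansion above, a plain six-argument Rust function can now
be converted and registered like the lower-arity cases. A minimal sketch,
assuming tvm-rt's existing register helper; the global name "rust.add6" is
made up for illustration.

    use tvm_rt::errors::Error;
    use tvm_rt::function::register;

    fn add6(a: i64, b: i64, c: i64, d: i64, e: i64, f: i64) -> i64 {
        a + b + c + d + e + f
    }

    fn register_add6() -> Result<(), Error> {
        // Only possible now that impl_typed_and_to_function! covers six args.
        register(add6, "rust.add6")
    }
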
diff --git a/rust/tvm/Cargo.toml b/rust/tvm/Cargo.toml
index 153a195..c1d8aa8 100644
--- a/rust/tvm/Cargo.toml
+++ b/rust/tvm/Cargo.toml
@@ -50,7 +50,7 @@ tvm-macros = { version = "*", path = "../tvm-macros/" }
 paste = "0.1"
 mashup = "0.1"
 once_cell = "^1.3.1"
-pyo3 = { version = "0.11.1", optional = true }
+pyo3 = { version = "^0.12", optional = true }
 codespan-reporting = "0.9.5"
 structopt = { version = "0.3" }
 
diff --git a/rust/tvm/src/python.rs b/rust/tvm/src/python.rs
index 89558af..50ce7b0 100644
--- a/rust/tvm/src/python.rs
+++ b/rust/tvm/src/python.rs
@@ -18,6 +18,7 @@
  */
 
 use pyo3::prelude::*;
+use once_cell::sync::OnceCell;
 
 /// Load the Python interpreter into the address space.
 ///
@@ -29,6 +30,8 @@ use pyo3::prelude::*;
 pub fn load() -> Result<String, ()> {
     let gil = Python::acquire_gil();
     let py = gil.python();
+    // let main_mod = initialize();
+    //let main_mod = main_mod.as_ref(py);
     load_python_tvm_(py).map_err(|e| {
         // We can't display Python exceptions via std::fmt::Display,
         // so print the error here manually.
@@ -36,12 +39,22 @@ pub fn load() -> Result<String, ()> {
     })
 }
 
-// const TVMC_CODE: &'static str = include_str!("tvmc.py");
+pub fn import(mod_to_import: &str) -> PyResult<()> {
+    let gil = Python::acquire_gil();
+    let py = gil.python();
+    import_python(py, mod_to_import)?;
+    Ok(())
+}
+
+fn import_python<'p, 'b: 'p>(py: Python<'p>, to_import: &'b str) -> PyResult<&'p PyModule> {
+    let imported_mod = py.import(to_import)?;
+    Ok(imported_mod)
+}
+
 
 fn load_python_tvm_(py: Python) -> PyResult<String> {
-    let sys = py.import("tvm")?;
-    let version: String = sys.get("__version__")?.extract()?;
-    // py.run(TVMC_CODE, None, None)?;
+    let imported_mod = import_python(py, "tvm")?;
+    let version: String = imported_mod.get("__version__")?.extract()?;
     Ok(version)
 }
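
Taken together, the embedder-side sequence this branch is working toward: a
minimal sketch, assuming the crate is built with its Python/pyo3 feature
enabled; main() and the error handling are illustrative.

    fn main() {
        // load() brings up the interpreter, imports the top-level tvm
        // module, and returns tvm.__version__ on success.
        let version = tvm::python::load().expect("failed to load Python TVM");
        println!("Loaded Python TVM {}", version);

        // The new import() pulls in further Python modules by name.
        tvm::python::import("tvm.relay").expect("failed to import tvm.relay");
    }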
 
diff --git a/src/runtime/module.cc b/src/runtime/module.cc
index ac2b60f..af5feab 100644
--- a/src/runtime/module.cc
+++ b/src/runtime/module.cc
@@ -175,7 +175,7 @@ TVM_REGISTER_GLOBAL("runtime.ModuleGetTypeKey").set_body_typed([](Module mod) {
 TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile").set_body_typed(Module::LoadFromFile);
 
 TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile")
-    .set_body_typed([](Module mod, std::string name, std::string fmt) {
+    .set_body_typed([](Module mod, tvm::String name, tvm::String fmt) {
       mod->SaveToFile(name, fmt);
     });