Posted to commits@mxnet.apache.org by la...@apache.org on 2020/12/02 18:26:01 UTC

[incubator-mxnet] branch master updated: Update TVM integration to v0.7 (#19613)

This is an automated email from the ASF dual-hosted git repository.

lausen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new af26e92  Update TVM integration to v0.7 (#19613)
af26e92 is described below

commit af26e92c5c7e3053c8e69f0cdf607e22856f3786
Author: Leonard Lausen <la...@amazon.com>
AuthorDate: Wed Dec 2 11:23:52 2020 -0700

    Update TVM integration to v0.7 (#19613)
    
    * Update TVM integration to v0.7
    
    * Disable TVM Bridge on Windows
    
    TVM Bridge does not introduce link dependencies, but it causes a link failure on Windows. This may be an MSVC bug.
    
    [2020-12-02T01:54:37.813Z] LINK: command "C:\PROGRA~2\MICROS~1\2019\COMMUN~1\VC\Tools\MSVC\1428~1.293\bin\Hostx64\x64\link.exe /nologo @CMakeFiles\mxnet.rsp /out:libmxnet.dll /implib:libmxnet.lib /pdb:libmxnet.pdb /dll /version:0.0 /machine:x64 /INCREMENTAL:NO /OPT:REF /OPT:ICF /MANIFEST /MANIFESTFILE:libmxnet.dll.manifest" failed (exit code 1120) with the following output:
    [2020-12-02T01:54:37.813Z]    Creating library libmxnet.lib and object libmxnet.exp
    [2020-12-02T01:54:37.813Z] tvm_bridge.cc.obj : error LNK2019: unresolved external symbol "__declspec(dllimport) protected: void __cdecl tvm::runtime::Object::DecRef(void)" (__imp_?DecRef@Object@runtime@tvm@@IEAAXXZ) referenced in function "private: void __cdecl tvm::runtime::TVMRetValue::Clear(void)" (?Clear@TVMRetValue@runtime@tvm@@AEAAXXZ)
---
 3rdparty/tvm                        |  2 +-
 CMakeLists.txt                      |  8 +++-
 cmake/BuildTVM.cmake                |  6 +--
 contrib/tvmop/basic/ufunc.py        | 76 ++++++++++++++++++-------------------
 contrib/tvmop/compile.py            | 22 ++++++-----
 contrib/tvmop/core/fromnumeric.py   | 10 ++---
 contrib/tvmop/core/multiarray.py    | 18 ++++-----
 contrib/tvmop/core/umath.py         | 32 ++++++++--------
 contrib/tvmop/opdef.py              |  2 +-
 contrib/tvmop/utils.py              | 12 +++---
 python/mxnet/libinfo.py             |  1 +
 src/operator/tvmop/op_module.cc     |  2 +-
 tests/python/gpu/test_tvm_bridge.py | 10 ++---
 13 files changed, 105 insertions(+), 96 deletions(-)

diff --git a/3rdparty/tvm b/3rdparty/tvm
index 9bd2c7b..efdac94 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit 9bd2c7b44208ed992061f8c2688e1137357f1db1
+Subproject commit efdac9439506d1de5eec91ecc795982c78e41909
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 425824b..7d34d5f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -510,6 +510,11 @@ endif()
 FILE(GLOB_RECURSE SOURCE "src/*.cc" "src/*.h" "include/*.h")
 FILE(GLOB_RECURSE CUDA "src/*.cu" "src/*.cuh")
 
+if(MSVC)
+  FILE(GLOB_RECURSE TVM_BRIDGE_SOURCE "src/*/tvm_bridge.cc")
+  list(REMOVE_ITEM SOURCE ${TVM_BRIDGE_SOURCE})
+endif()
+
 if(NOT USE_INTGEMM)
   FILE(GLOB_RECURSE INTGEMM_OPERATOR_SOURCE "src/operator/contrib/intgemm/*.cc" "src/operator/contrib/intgemm/*.h")
   list(REMOVE_ITEM SOURCE ${INTGEMM_OPERATOR_SOURCE})
@@ -865,12 +870,11 @@ function(BuildTVMOP)
   include(cmake/BuildTVM.cmake)
   add_subdirectory("3rdparty/tvm")
   set_target_properties(tvm PROPERTIES CXX_CLANG_TIDY "")  # don't lint 3rdparty dependency
-  set_target_properties(tvm_topi PROPERTIES CXX_CLANG_TIDY "")  # don't lint 3rdparty dependency
   set_target_properties(tvm_runtime PROPERTIES CXX_CLANG_TIDY "")  # don't lint 3rdparty dependency
 endfunction()
 
 if(USE_TVM_OP)
-  list(APPEND mxnet_LINKER_LIBS ${CMAKE_CURRENT_BINARY_DIR}/3rdparty/tvm/libtvm_runtime.so)
+  list(APPEND mxnet_LINKER_LIBS tvm_runtime)
   BuildTVMOP()
   find_package(Python3 REQUIRED)
   set(TVM_OP_COMPILE_OPTIONS "-o${CMAKE_CURRENT_BINARY_DIR}" "--config" "${CMAKE_CURRENT_BINARY_DIR}/tvmop.conf" "-L" "${CMAKE_CURRENT_BINARY_DIR}/3rdparty/tvm")
diff --git a/cmake/BuildTVM.cmake b/cmake/BuildTVM.cmake
index 5f57959..4637c2c 100644
--- a/cmake/BuildTVM.cmake
+++ b/cmake/BuildTVM.cmake
@@ -85,9 +85,9 @@ set(USE_LLVM ON)
 set(USE_BLAS none)
 
 # /path/to/mkl: mkl root path when use mkl blas library
-# set(USE_MKL_PATH /opt/intel/mkl) for UNIX
-# set(USE_MKL_PATH ../IntelSWTools/compilers_and_libraries_2018/windows/mkl) for WIN32
-set(USE_MKL_PATH none)
+# set(USE_MKL /opt/intel/mkl) for UNIX
+# set(USE_MKL ../IntelSWTools/compilers_and_libraries_2018/windows/mkl) for WIN32
+set(USE_MKL OFF)
 
 # Whether use contrib.random in runtime
 set(USE_RANDOM OFF)
diff --git a/contrib/tvmop/basic/ufunc.py b/contrib/tvmop/basic/ufunc.py
index 912263e..1c7b82e 100644
--- a/contrib/tvmop/basic/ufunc.py
+++ b/contrib/tvmop/basic/ufunc.py
@@ -21,11 +21,11 @@ from .. import defop, AllTypes, RealTypes
 from .. import assign_by_req, reduce_axes
 
 def compute_add(dtype, ndim):
-    A = tvm.placeholder([tvm.size_var() for _ in range(ndim)], name='A', dtype=dtype)
-    B = tvm.placeholder([tvm.size_var() for _ in range(ndim)], name='B', dtype=dtype)
-    C = tvm.compute([tvm.size_var() for _ in range(ndim)],
+    A = tvm.te.placeholder([tvm.te.size_var() for _ in range(ndim)], name='A', dtype=dtype)
+    B = tvm.te.placeholder([tvm.te.size_var() for _ in range(ndim)], name='B', dtype=dtype)
+    C = tvm.te.compute([tvm.te.size_var() for _ in range(ndim)],
                     lambda *index: A[index] + B[index], name='C')
-    s = tvm.create_schedule(C.op)
+    s = tvm.te.create_schedule(C.op)
     return s, A, B, C
 
 
@@ -44,12 +44,12 @@ def vadd(dtype, ndim):
        dtype=["float32", "float64"], ndim=[5])
 def vadd_gpu(dtype, ndim):
     s, A, B, C = compute_add(dtype, ndim)
-    s = tvm.create_schedule(C.op)
+    s = tvm.te.create_schedule(C.op)
     axes = [axis for axis in C.op.axis]
     fused = s[C].fuse(*axes)
     bx, tx = s[C].split(fused, factor=64)
-    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
-    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
+    s[C].bind(bx, tvm.te.thread_axis("blockIdx.x"))
+    s[C].bind(tx, tvm.te.thread_axis("threadIdx.x"))
     return s, [A, B, C]
 
 
@@ -62,12 +62,12 @@ def compute_backward_vadd(dtype, ndim, reduce1st, req):
     # The compressed bit string is stored in `axes`, and `reduce1st` represents the first bit
     # of the compressed bit string. Credit to @junrushao1994 and @yzhliu.
     axes = ([reduce1st, 1 - reduce1st] * ndim)[:ndim]
-    X = tvm.placeholder([tvm.size_var() for _ in range(ndim)], name='X', dtype=dtype)
-    reducer = tvm.comm_reducer(lambda x, y: x + y,
-        lambda t: tvm.const(0, dtype=t), name="sum")
+    X = tvm.te.placeholder([tvm.te.size_var() for _ in range(ndim)], name='X', dtype=dtype)
+    reducer = tvm.te.comm_reducer(lambda x, y: x + y,
+        lambda t: tvm.tir.const(0, dtype=t), name="sum")
     ret = reduce_axes(X, axes, reducer)
     in_grad_a, in_grad = assign_by_req(ret, req)
-    s = tvm.create_schedule(in_grad.op)
+    s = tvm.te.create_schedule(in_grad.op)
     return s, X, in_grad_a, in_grad, [ret, in_grad]
 
 
@@ -90,8 +90,8 @@ def backward_vadd_gpu(dtype, ndim, reduce1st, req):
     s, X, in_grad_a, in_grad, c_list = compute_backward_vadd(dtype, ndim, reduce1st, req)
     num_thread = 64
     for t in c_list:
-        block_x = tvm.thread_axis("blockIdx.x")
-        thread_x = tvm.thread_axis("threadIdx.x")
+        block_x = tvm.te.thread_axis("blockIdx.x")
+        thread_x = tvm.te.thread_axis("threadIdx.x")
         axes = [axis for axis in t.op.axis]
         fused = s[t].fuse(*axes)
         bx, tx = s[t].split(fused, factor=num_thread)
@@ -101,15 +101,15 @@ def backward_vadd_gpu(dtype, ndim, reduce1st, req):
 
 
 def compute_degandrad(dtype, ndim, n):
-    A = tvm.placeholder([tvm.size_var() for _ in range(ndim)], name='A', dtype=dtype)
+    A = tvm.te.placeholder([tvm.te.size_var() for _ in range(ndim)], name='A', dtype=dtype)
     import math
     if n == 0:
-        B = tvm.compute([tvm.size_var() for _ in range(ndim)],
-                        lambda *index: A[index] * tvm.const(math.pi, dtype) / tvm.const(180, dtype), name='B')
+        B = tvm.te.compute([tvm.te.size_var() for _ in range(ndim)],
+                        lambda *index: A[index] * tvm.tir.const(math.pi, dtype) / tvm.tir.const(180, dtype), name='B')
     else:
-        B = tvm.compute([tvm.size_var() for _ in range(ndim)],
-                        lambda *index: A[index] / tvm.const(math.pi, dtype) * tvm.const(180, dtype), name='B')
-    s = tvm.create_schedule(B.op)
+        B = tvm.te.compute([tvm.te.size_var() for _ in range(ndim)],
+                        lambda *index: A[index] / tvm.tir.const(math.pi, dtype) * tvm.tir.const(180, dtype), name='B')
+    s = tvm.te.create_schedule(B.op)
     return s, A, B
 
 
@@ -137,12 +137,12 @@ def rad2deg(dtype, ndim):
        dtype=["float32", "float64"], ndim=list(range(0, 6)))
 def deg2rad_gpu(dtype, ndim):
     s, A, B = compute_degandrad(dtype, ndim, 0)
-    s = tvm.create_schedule(B.op)
+    s = tvm.te.create_schedule(B.op)
     axes = [axis for axis in B.op.axis]
     fused = s[B].fuse(*axes)
     bx, tx = s[B].split(fused, factor=64)
-    s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
-    s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
+    s[B].bind(bx, tvm.te.thread_axis("blockIdx.x"))
+    s[B].bind(tx, tvm.te.thread_axis("threadIdx.x"))
     return s, [A, B]
 
 
@@ -150,30 +150,30 @@ def deg2rad_gpu(dtype, ndim):
        dtype=["float32", "float64"], ndim=list(range(0, 6)))
 def rad2deg_gpu(dtype, ndim):
     s, A, B = compute_degandrad(dtype, ndim, 1)
-    s = tvm.create_schedule(B.op)
+    s = tvm.te.create_schedule(B.op)
     axes = [axis for axis in B.op.axis]
     fused = s[B].fuse(*axes)
     bx, tx = s[B].split(fused, factor=64)
-    s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
-    s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
+    s[B].bind(bx, tvm.te.thread_axis("blockIdx.x"))
+    s[B].bind(tx, tvm.te.thread_axis("threadIdx.x"))
     return s, [A, B]
 
 
 def compute_backward_degandrad(dtype, ndim, req, n):
-    ishape = [tvm.size_var() for _ in range(ndim)]
-    in_grad_tmp = tvm.placeholder(ishape, name='in_grad_tmp', dtype=dtype)
-    in_grad = tvm.placeholder(ishape, name='in_grad', dtype=dtype)
-    out_grad = tvm.placeholder(ishape, name='out_grad', dtype=dtype)
+    ishape = [tvm.te.size_var() for _ in range(ndim)]
+    in_grad_tmp = tvm.te.placeholder(ishape, name='in_grad_tmp', dtype=dtype)
+    in_grad = tvm.te.placeholder(ishape, name='in_grad', dtype=dtype)
+    out_grad = tvm.te.placeholder(ishape, name='out_grad', dtype=dtype)
     import math
     if n == 0:
-        ret = tvm.compute(ishape, lambda *index: out_grad[index] * tvm.const(math.pi, dtype) / tvm.const(180, dtype))
+        ret = tvm.te.compute(ishape, lambda *index: out_grad[index] * tvm.tir.const(math.pi, dtype) / tvm.tir.const(180, dtype))
     else:
-        ret = tvm.compute(ishape, lambda *index: out_grad[index] / tvm.const(math.pi, dtype) * tvm.const(180, dtype))
+        ret = tvm.te.compute(ishape, lambda *index: out_grad[index] / tvm.tir.const(math.pi, dtype) * tvm.tir.const(180, dtype))
     if (req == "kAddTo"):
-        in_grad = tvm.compute(ishape, lambda *index: in_grad_tmp[index] + ret[index])
+        in_grad = tvm.te.compute(ishape, lambda *index: in_grad_tmp[index] + ret[index])
     else:
-        in_grad = tvm.compute(ishape, lambda *index: ret[index])
-    s = tvm.create_schedule(in_grad.op)
+        in_grad = tvm.te.compute(ishape, lambda *index: ret[index])
+    s = tvm.te.create_schedule(in_grad.op)
     return s, out_grad, in_grad_tmp, in_grad, [ret, in_grad]
 
 
@@ -208,8 +208,8 @@ def cuda_backward_deg2rad(dtype, ndim, req):
     s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_degandrad(dtype, ndim, req, 0)
     num_thread = 64
     for t in c_list:
-        block_x = tvm.thread_axis("blockIdx.x")
-        thread_x = tvm.thread_axis("threadIdx.x")
+        block_x = tvm.te.thread_axis("blockIdx.x")
+        thread_x = tvm.te.thread_axis("threadIdx.x")
         axes = [axis for axis in t.op.axis]
         fused = s[t].fuse(*axes)
         bx, tx = s[t].split(fused, factor=num_thread)
@@ -225,8 +225,8 @@ def cuda_backward_rad2deg(dtype, ndim, req):
     s, out_grad, in_grad_tmp, in_grad, c_list = compute_backward_degandrad(dtype, ndim, req, 1)
     num_thread = 64
     for t in c_list:
-        block_x = tvm.thread_axis("blockIdx.x")
-        thread_x = tvm.thread_axis("threadIdx.x")
+        block_x = tvm.te.thread_axis("blockIdx.x")
+        thread_x = tvm.te.thread_axis("threadIdx.x")
         axes = [axis for axis in t.op.axis]
         fused = s[t].fuse(*axes)
         bx, tx = s[t].split(fused, factor=num_thread)
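
For readers following the API migration above: in TVM v0.7 the tensor-expression
primitives moved under tvm.te and the IR-level constants under tvm.tir, which is
all these hunks change. A minimal, self-contained sketch of the new spelling
(operator name, shape, and target are illustrative, not taken from the tvmop sources):

    import numpy as np
    import tvm
    from tvm import te

    # Symbolic-shape elementwise add, written against the v0.7 namespaces
    # (te.placeholder / te.compute / te.create_schedule).
    n = te.size_var("n")
    A = te.placeholder((n,), name="A", dtype="float32")
    B = te.placeholder((n,), name="B", dtype="float32")
    C = te.compute((n,), lambda i: A[i] + B[i], name="C")
    s = te.create_schedule(C.op)

    # Build and run on CPU; the lowering/building entry points are unchanged.
    func = tvm.build(s, [A, B, C], target="llvm", name="vadd")
    ctx = tvm.cpu(0)
    a = tvm.nd.array(np.arange(8, dtype="float32"), ctx)
    b = tvm.nd.array(np.ones(8, dtype="float32"), ctx)
    c = tvm.nd.empty((8,), "float32", ctx)
    func(a, b, c)
    np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
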
diff --git a/contrib/tvmop/compile.py b/contrib/tvmop/compile.py
index f15e5a7..6f9fb7c5 100644
--- a/contrib/tvmop/compile.py
+++ b/contrib/tvmop/compile.py
@@ -125,23 +125,27 @@ if __name__ == "__main__":
                         help="Path which stores the config file")
     arguments = parser.parse_args()
 
-    func_list_llvm = []
-    func_list_cuda = []
+    mod_llvm = tvm.IRModule({})
+    mod_cuda = tvm.IRModule({})
+    has_cuda = False
 
     # TODO: attach instruction features to the library, e.g., avx-512, etc.
     for operator_def in __OP_DEF__:
         for sch, args, name in operator_def.invoke_all():
             name = operator_def.get_op_name(name, args)
-            if tvm.module.enabled(get_target(operator_def.target)):
-                func_list = func_list_llvm if operator_def.target == "cpu" else func_list_cuda
+            if tvm.runtime.module.enabled(get_target(operator_def.target)):
                 func_lower = tvm.lower(sch, args,
                                        name=name,
                                        binds=operator_def.get_binds(args))
-                func_list.append(func_lower)
-
-    lowered_funcs = {get_target("cpu"): func_list_llvm}
-    if len(func_list_cuda) > 0:
-        lowered_funcs[get_target("cuda")] = func_list_cuda
+                if operator_def.target == "cpu":
+                    mod = mod_llvm.update(func_lower)
+                else:
+                    has_cuda = True
+                    mod_cuda.update(func_lower)
+
+    lowered_funcs = {get_target("cpu"): mod_llvm}
+    if has_cuda:
+        lowered_funcs[get_target("cuda")] = mod_cuda
         cuda_arch = get_cuda_arch(arguments.cuda_arch)
         if cuda_arch is None:
             logging.info('No cuda arch specified. TVM will try to detect it from the build platform.')
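
The compile.py hunk above replaces the per-target lists of lowered functions with
tvm.IRModule containers, because tvm.lower returns an IRModule in v0.7 and modules
are merged with IRModule.update. A sketch of that pattern under the same assumption,
with made-up operator names standing in for the real tvmop registry:

    import tvm
    from tvm import te

    def lower_vector_add(op_name):
        # Stand-in for one generated kernel; illustrative only.
        n = te.size_var("n")
        A = te.placeholder((n,), name="A", dtype="float32")
        B = te.placeholder((n,), name="B", dtype="float32")
        C = te.compute((n,), lambda i: A[i] + B[i], name="C")
        s = te.create_schedule(C.op)
        return tvm.lower(s, [A, B, C], name=op_name)  # IRModule in v0.7

    mod_llvm = tvm.IRModule({})                     # start from an empty module
    for op_name in ("vadd0", "vadd1"):              # hypothetical operator names
        mod_llvm.update(lower_vector_add(op_name))  # merge lowered funcs in place

    lowered_funcs = {"llvm": mod_llvm}              # target -> merged module
    lib = tvm.build(lowered_funcs)                  # build every kernel at once
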
diff --git a/contrib/tvmop/core/fromnumeric.py b/contrib/tvmop/core/fromnumeric.py
index 5b21cf8..f85f493 100644
--- a/contrib/tvmop/core/fromnumeric.py
+++ b/contrib/tvmop/core/fromnumeric.py
@@ -23,10 +23,10 @@ from ..utils import reduce_axes, assign_by_req
 
 def _compute_sum(itype, otype, ndim, reduce1st_dim, req):
     axes = ([reduce1st_dim, 1 - reduce1st_dim] * ndim)[:ndim]
-    a = tvm.placeholder([tvm.size_var() for _ in range(ndim)], name='a', dtype=itype)
-    reduce_output = reduce_axes(a, axes, tvm.sum, otype)
+    a = tvm.te.placeholder([tvm.te.size_var() for _ in range(ndim)], name='a', dtype=itype)
+    reduce_output = reduce_axes(a, axes, tvm.tir.sum, otype)
     output_placeholder, final_output = assign_by_req(reduce_output, req)
-    s = tvm.create_schedule(final_output.op)
+    s = tvm.te.create_schedule(final_output.op)
     return s, a, output_placeholder, final_output, [reduce_output, final_output]
 
 
@@ -53,8 +53,8 @@ def _sum_gpu(itype, otype, ndim, reduce1st_dim, req):
         itype, otype, ndim, reduce1st_dim, req)
     num_threads = 64
     for t in tensor_list:
-        block_x = tvm.thread_axis("blockIdx.x")
-        thread_x = tvm.thread_axis("threadIdx.x")
+        block_x = tvm.te.thread_axis("blockIdx.x")
+        thread_x = tvm.te.thread_axis("threadIdx.x")
         axes = [axis for axis in t.op.axis]
         fused = s[t].fuse(*axes)
         bx, tx = s[t].split(fused, factor=num_threads)
diff --git a/contrib/tvmop/core/multiarray.py b/contrib/tvmop/core/multiarray.py
index baccba9..b246df7 100644
--- a/contrib/tvmop/core/multiarray.py
+++ b/contrib/tvmop/core/multiarray.py
@@ -25,9 +25,9 @@ def compute_dot(A, B):
     M = A.shape[0]
     K = A.shape[1]
     N = B.shape[1]
-    k = tvm.reduce_axis((0, K), 'k')
-    C = tvm.compute((M, N),
-                    lambda x, y: tvm.sum(A[x, k] * B[k, y], axis=k),
+    k = tvm.te.reduce_axis((0, K), 'k')
+    C = tvm.te.compute((M, N),
+                    lambda x, y: tvm.tir.sum(A[x, k] * B[k, y], axis=k),
                     name='C')
     return C
 
@@ -37,13 +37,13 @@ def dot(dtype, fallback):
     cfg = autotvm.get_config()
     cfg.define_knob("bn", [64] if fallback else [64, 32])
     cfg.define_knob("factor", [4] if fallback else [4])
-    M = tvm.size_var("M")
-    K = tvm.size_var("K")
-    N = tvm.size_var("N")
-    A = tvm.placeholder((M, K), name='A', dtype=dtype)
-    B = tvm.placeholder((K, N), name='B', dtype=dtype)
+    M = tvm.te.size_var("M")
+    K = tvm.te.size_var("K")
+    N = tvm.te.size_var("N")
+    A = tvm.te.placeholder((M, K), name='A', dtype=dtype)
+    B = tvm.te.placeholder((K, N), name='B', dtype=dtype)
     C = compute_dot(A, B)
-    s = tvm.create_schedule(C.op)
+    s = tvm.te.create_schedule(C.op)
     # Blocking by loop tiling
     xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], cfg["bn"].val, cfg["bn"].val)
     k, = s[C].op.reduce_axis
diff --git a/contrib/tvmop/core/umath.py b/contrib/tvmop/core/umath.py
index 94f2f4a..d314bde 100644
--- a/contrib/tvmop/core/umath.py
+++ b/contrib/tvmop/core/umath.py
@@ -25,18 +25,18 @@ _bin_logic_op_map = {
     'less': lambda a, b, *idx: a[idx] < b[idx],
     'greater_equal': lambda a, b, *idx: a[idx] >= b[idx],
     'less_equal': lambda a, b, *idx: a[idx] <= b[idx],
-    'logical_and': lambda a, b, *idx: tvm.all(a[idx] != 0, b[idx] != 0),
-    'logical_or': lambda a, b, *idx: tvm.any(a[idx] != 0, b[idx] != 0),
-    'logical_xor': lambda a, b, *idx: tvm.all(tvm.any(a[idx] != 0, b[idx] != 0), tvm.any(a[idx] == 0, b[idx] == 0)),
+    'logical_and': lambda a, b, *idx: tvm.tir.all(a[idx] != 0, b[idx] != 0),
+    'logical_or': lambda a, b, *idx: tvm.tir.any(a[idx] != 0, b[idx] != 0),
+    'logical_xor': lambda a, b, *idx: tvm.tir.all(tvm.tir.any(a[idx] != 0, b[idx] != 0), tvm.tir.any(a[idx] == 0, b[idx] == 0)),
 }
 
 
 def _compute_binary_logic(op, dtype, ndim):
-    a = tvm.placeholder([tvm.size_var() for _ in range(ndim)], dtype=dtype, name='a')
-    b = tvm.placeholder([tvm.size_var() for _ in range(ndim)], dtype=dtype, name='b')
-    c = tvm.compute([tvm.size_var() for _ in range(ndim)],
+    a = tvm.te.placeholder([tvm.te.size_var() for _ in range(ndim)], dtype=dtype, name='a')
+    b = tvm.te.placeholder([tvm.te.size_var() for _ in range(ndim)], dtype=dtype, name='b')
+    c = tvm.te.compute([tvm.te.size_var() for _ in range(ndim)],
                     lambda *idx: _bin_logic_op_map[op](a, b, *idx), name='c')
-    s = tvm.create_schedule(c.op)
+    s = tvm.te.create_schedule(c.op)
     return s, a, b, c
 
 
@@ -70,8 +70,8 @@ def _binary_logic_gpu(compute_func, op, itype, ndim):
     axes = [axis for axis in c.op.axis]
     fused = s[c].fuse(*axes)
     bx, tx = s[c].split(fused, factor=64)
-    s[c].bind(bx, tvm.thread_axis('blockIdx.x'))
-    s[c].bind(tx, tvm.thread_axis('threadIdx.x'))
+    s[c].bind(bx, tvm.te.thread_axis('blockIdx.x'))
+    s[c].bind(tx, tvm.te.thread_axis('threadIdx.x'))
     return s, [a, b, c]
 
 
@@ -90,18 +90,18 @@ _bin_scalar_logic_op_map = {
     'less_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) < b,
     'greater_equal_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) >= b,
     'less_equal_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) <= b,
-    'logical_and_scalar': lambda a, b, *idx: tvm.all(a[idx].astype(b.dtype) != 0 , b != 0),
-    'logical_or_scalar': lambda a, b, *idx: tvm.any(a[idx].astype(b.dtype) != 0, b != 0),
-    'logical_xor_scalar': lambda a, b, *idx: tvm.all(tvm.any(a[idx].astype(b.dtype) != 0, b != 0), tvm.any(a[idx].astype(b.dtype) == 0, b == 0)),
+    'logical_and_scalar': lambda a, b, *idx: tvm.tir.all(a[idx].astype(b.dtype) != 0 , b != 0),
+    'logical_or_scalar': lambda a, b, *idx: tvm.tir.any(a[idx].astype(b.dtype) != 0, b != 0),
+    'logical_xor_scalar': lambda a, b, *idx: tvm.tir.all(tvm.tir.any(a[idx].astype(b.dtype) != 0, b != 0), tvm.tir.any(a[idx].astype(b.dtype) == 0, b == 0)),
 }
 
 
 def _compute_binary_scalar_logic(op, dtype, ndim):
-    a = tvm.placeholder([tvm.size_var() for _ in range(ndim)], name='a', dtype=dtype)
-    b = tvm.var('b', dtype='float64')
-    c = tvm.compute([tvm.size_var() for _ in range(ndim)],
+    a = tvm.te.placeholder([tvm.te.size_var() for _ in range(ndim)], name='a', dtype=dtype)
+    b = tvm.te.var('b', dtype='float64')
+    c = tvm.te.compute([tvm.te.size_var() for _ in range(ndim)],
                     lambda *idx: _bin_scalar_logic_op_map[op](a, b, *idx), name='c')
-    s = tvm.create_schedule(c.op)
+    s = tvm.te.create_schedule(c.op)
     return s, a, b, c
 
 
diff --git a/contrib/tvmop/opdef.py b/contrib/tvmop/opdef.py
index 1e0f346..7a7c27f 100644
--- a/contrib/tvmop/opdef.py
+++ b/contrib/tvmop/opdef.py
@@ -116,7 +116,7 @@ class OpDef:
 
     def get_binds(self, args):
         if self.auto_broadcast:
-            return {arg: tvm.decl_buffer(arg.shape, arg.dtype, buffer_type="auto_broadcast")
+            return {arg: tvm.tir.decl_buffer(arg.shape, arg.dtype, buffer_type="auto_broadcast")
                     for arg in args}
         return None
 
diff --git a/contrib/tvmop/utils.py b/contrib/tvmop/utils.py
index 07eb748..9c31eb3 100644
--- a/contrib/tvmop/utils.py
+++ b/contrib/tvmop/utils.py
@@ -23,12 +23,12 @@ RealTypes = ["float32", "float64", "float16"]
 
 
 def assign_by_req(a, req, otype=None):
-    b = tvm.placeholder(a.shape, name='assign_by_req_b', dtype=a.dtype)
+    b = tvm.te.placeholder(a.shape, name='assign_by_req_b', dtype=a.dtype)
     if req == "kAddTo":
-        c = tvm.compute(a.shape, lambda *idx: a[idx].astype(otype) + b[idx]
+        c = tvm.te.compute(a.shape, lambda *idx: a[idx].astype(otype) + b[idx]
                                               if otype else a[idx] + b[idx])
     else:
-        c = tvm.compute(a.shape, lambda *idx: a[idx].astype(otype) if otype else a[idx])
+        c = tvm.te.compute(a.shape, lambda *idx: a[idx].astype(otype) if otype else a[idx])
     return b, c
 
 
@@ -45,9 +45,9 @@ def reduce_axes(X, axes, reducer, atype=None):
     
     ishape = X.shape
     odim = (len(ishape) + 1 - axes[0]) // 2
-    oshape = [tvm.size_var() for _ in range(odim)]
-    ridx = [tvm.reduce_axis((0, ishape[i])) for (i, val) in enumerate(axes) if val == 1]
-    ret = tvm.compute(oshape, lambda *idx: reducer(X[get_index(idx, ridx)].astype(atype)
+    oshape = [tvm.te.size_var() for _ in range(odim)]
+    ridx = [tvm.te.reduce_axis((0, ishape[i])) for (i, val) in enumerate(axes) if val == 1]
+    ret = tvm.te.compute(oshape, lambda *idx: reducer(X[get_index(idx, ridx)].astype(atype)
                                                    if atype else X[get_index(idx, ridx)],
                                                    axis=ridx), name='ret')
     return ret
diff --git a/python/mxnet/libinfo.py b/python/mxnet/libinfo.py
index ea674ac..7f73276 100644
--- a/python/mxnet/libinfo.py
+++ b/python/mxnet/libinfo.py
@@ -32,6 +32,7 @@ def find_lib_path(prefix='libmxnet'):
     """
     lib_from_env = os.environ.get('MXNET_LIBRARY_PATH')
     if lib_from_env:
+        lib_from_env = lib_from_env.replace('libmxnet', prefix)
         if os.path.isfile(lib_from_env):
             if not os.path.isabs(lib_from_env):
                 logging.warning("MXNET_LIBRARY_PATH should be an absolute path, instead of: %s",
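
The libinfo.py change substitutes the requested library prefix into
MXNET_LIBRARY_PATH before the existence check, so the same environment variable
can also point find_lib_path() at sibling libraries. A small illustration of just
that substitution (the path and the 'libtvmop' prefix below are hypothetical):

    import os

    # Hypothetical user setting; find_lib_path() rewrites the 'libmxnet' stem
    # to whatever prefix the caller asked for before checking that the file exists.
    os.environ["MXNET_LIBRARY_PATH"] = "/opt/mxnet/lib/libmxnet.so"

    prefix = "libtvmop"
    lib_from_env = os.environ["MXNET_LIBRARY_PATH"].replace("libmxnet", prefix)
    print(lib_from_env)  # -> /opt/mxnet/lib/libtvmop.so
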
diff --git a/src/operator/tvmop/op_module.cc b/src/operator/tvmop/op_module.cc
index d833ae0..352e885 100644
--- a/src/operator/tvmop/op_module.cc
+++ b/src/operator/tvmop/op_module.cc
@@ -39,7 +39,7 @@ namespace tvm {
 namespace runtime {
 
 void TVMOpModule::Load(const std::string &filepath) {
-  static const PackedFunc *f_load = Registry::Get("module._LoadFromFile");
+  static const PackedFunc *f_load = Registry::Get("runtime.ModuleLoadFromFile");
   std::lock_guard<std::mutex> lock(mutex_);
   Module module = (*f_load)(filepath, "");
   module_ptr_ = std::make_shared<Module>();
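
The op_module.cc change tracks the v0.7 rename of the module-loading packed
function from "module._LoadFromFile" to "runtime.ModuleLoadFromFile". On the
Python side the same function can be looked up by that name, or reached through
tvm.runtime.load_module; a quick sketch (the library path is illustrative):

    import tvm

    # Raises if the v0.7 registry name is not present in the linked runtime.
    f_load = tvm.get_global_func("runtime.ModuleLoadFromFile")

    # High-level equivalent; uncomment with a real compiled library path.
    # mod = tvm.runtime.load_module("/path/to/libtvmop.so")
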
diff --git a/tests/python/gpu/test_tvm_bridge.py b/tests/python/gpu/test_tvm_bridge.py
index 7a4339c..b442aaf 100644
--- a/tests/python/gpu/test_tvm_bridge.py
+++ b/tests/python/gpu/test_tvm_bridge.py
@@ -33,11 +33,11 @@ def test_tvm_bridge():
 
     def check(target, dtype):
         shape = (20,)
-        scale = tvm.var("scale", dtype="float32")
-        x = tvm.placeholder(shape, dtype=dtype)
-        y = tvm.placeholder(shape, dtype=dtype)
-        z = tvm.compute(shape, lambda i: x[i] + y[i])
-        zz = tvm.compute(shape, lambda *i: z(*i) * scale.astype(dtype))
+        scale = tvm.te.var("scale", dtype="float32")
+        x = tvm.te.placeholder(shape, dtype=dtype)
+        y = tvm.te.placeholder(shape, dtype=dtype)
+        z = tvm.te.compute(shape, lambda i: x[i] + y[i])
+        zz = tvm.te.compute(shape, lambda *i: z(*i) * scale.astype(dtype))
         ctx = mx.gpu(0) if target == "cuda" else mx.cpu(0)
         target = tvm.target.create(target)