You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by jw...@apache.org on 2021/04/09 19:17:29 UTC

[tvm] 08/09: fix for onnxruntime

This is an automated email from the ASF dual-hosted git repository.

jwfromm pushed a commit to branch checkpoint
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit bea9d9506a3881fa2b5ff63c7fa81e2096e110d4
Author: Tristan Konolige <tr...@gmail.com>
AuthorDate: Wed Mar 17 17:23:18 2021 -0700

    fix for onnxruntime
---
 cmake/modules/Vulkan.cmake            | 12 ++++++++----
 cmake/utils/FindVulkan.cmake          |  6 +++---
 python/tvm/relay/frontend/jit/onnx.py | 19 +++++++++++++------
 src/runtime/graph/graph_runtime.cc    | 15 ++++++++++++++-
 src/target/spirv/build_vulkan.cc      |  2 +-
 src/target/spirv/intrin_rule_spirv.cc |  2 +-
 src/target/spirv/ir_builder.h         |  2 +-
 7 files changed, 41 insertions(+), 17 deletions(-)

diff --git a/cmake/modules/Vulkan.cmake b/cmake/modules/Vulkan.cmake
index 4df8986..abed0ab 100644
--- a/cmake/modules/Vulkan.cmake
+++ b/cmake/modules/Vulkan.cmake
@@ -16,7 +16,18 @@
 # under the License.
 
-# Be compatible with older version of CMake
-find_vulkan(${USE_VULKAN})
+# Vulkan and the SPIR-V packages are mandatory only when USE_VULKAN is set;
+# otherwise they are probed quietly so that non-Vulkan builds still configure.
+if(USE_VULKAN)
+  set(TVM_VULKAN_FIND_MODE REQUIRED)
+else()
+  set(TVM_VULKAN_FIND_MODE QUIET)
+endif()
+find_package(Vulkan ${TVM_VULKAN_FIND_MODE})
+find_package(PkgConfig ${TVM_VULKAN_FIND_MODE})
+if(PKG_CONFIG_FOUND)
+  pkg_check_modules(SPIRV ${TVM_VULKAN_FIND_MODE} spirv)
+  pkg_check_modules(SPIRV_TOOLS ${TVM_VULKAN_FIND_MODE} SPIRV-Tools)
+endif()
 
 # Extra Vulkan runtime options, exposed for advanced users.
 tvm_option(USE_VULKAN_IMMEDIATE_MODE "Use Vulkan Immediate mode
@@ -29,7 +40,7 @@ tvm_option(USE_VULKAN_VALIDATION "Enable Vulkan API validation layers" OFF
 if(Vulkan_FOUND)
   # always set the includedir
   # avoid global retrigger of cmake
-  include_directories(SYSTEM ${Vulkan_INCLUDE_DIRS})
+  include_directories(SYSTEM ${Vulkan_INCLUDE_DIRS} ${SPIRV_INCLUDEDIR} ${SPIRV_TOOLS_INCLUDEDIR})
 endif(Vulkan_FOUND)
 
 if(USE_VULKAN)
@@ -41,8 +52,8 @@ if(USE_VULKAN)
   file(GLOB COMPILER_VULKAN_SRCS src/target/spirv/*.cc)
   list(APPEND RUNTIME_SRCS ${RUNTIME_VULKAN_SRCS})
   list(APPEND COMPILER_SRCS ${COMPILER_VULKAN_SRCS})
-  list(APPEND TVM_LINKER_LIBS ${Vulkan_SPIRV_TOOLS_LIBRARY})
-  list(APPEND TVM_RUNTIME_LINKER_LIBS ${Vulkan_LIBRARY})
+  list(APPEND TVM_LINKER_LIBS ${Vulkan_LIBRARIES} ${SPIRV_LIBRARIES} ${SPIRV_TOOLS_LIBRARIES})
+  list(APPEND TVM_RUNTIME_LINKER_LIBS ${Vulkan_LIBRARIES})
 
   if(USE_VULKAN_IMMEDIATE_MODE)
     message(STATUS "Build with Vulkan immediate mode")
diff --git a/cmake/utils/FindVulkan.cmake b/cmake/utils/FindVulkan.cmake
index feb5eec..ac2e32a 100644
--- a/cmake/utils/FindVulkan.cmake
+++ b/cmake/utils/FindVulkan.cmake
@@ -65,9 +65,7 @@ macro(find_vulkan use_vulkan)
         HINTS ${VULKAN_LIBRARY_PATH} ${VULKAN_LIBRARY_PATH}/spirv-tools ${VULKAN_SDK}/lib)
 
     find_path(_libspirv libspirv.h HINTS ${Vulkan_INCLUDE_DIRS} PATH_SUFFIXES vulkan spirv-tools)
-    find_path(_spirv spirv.hpp HINTS ${Vulkan_INCLUDE_DIRS} PATH_SUFFIXES vulkan SPIRV spirv/unified1 spirv-headers)
-    find_path(_glsl_std GLSL.std.450.h HINTS ${Vulkan_INCLUDE_DIRS} PATH_SUFFIXES vulkan SPIRV spirv/unified1 spirv-headers)
-    list(APPEND Vulkan_INCLUDE_DIRS ${_libspirv} ${_spirv} ${_glsl_std})
+    list(APPEND Vulkan_INCLUDE_DIRS ${_libspirv})
     message(STATUS "Vulkan_INCLUDE_DIRS=" ${Vulkan_INCLUDE_DIRS})
     message(STATUS "Vulkan_LIBRARY=" ${Vulkan_LIBRARY})
     message(STATUS "Vulkan_SPIRV_TOOLS_LIBRARY=" ${Vulkan_SPIRV_TOOLS_LIBRARY})
diff --git a/python/tvm/relay/frontend/jit/onnx.py b/python/tvm/relay/frontend/jit/onnx.py
index 0f2a79c..ae10915 100644
--- a/python/tvm/relay/frontend/jit/onnx.py
+++ b/python/tvm/relay/frontend/jit/onnx.py
@@ -21,6 +21,7 @@ import tvm.relay
 import tvm.autotvm as autotvm
 import timeit
 import numpy as np
+import collections
 
 @tvm.register_func("tvm_run_with_benchmark")
 def run_with_benchmark(mod):
@@ -40,18 +41,22 @@ def run_with_benchmark(mod):
 def onnx_compile(model_string, target, target_host, opt_level, input_shapes):
     model = onnx.load_model_from_string(bytes(model_string))
 
-    input_shapes = {name : shape for (name, shape) in zip([i.name for i in model.graph.input], input_shapes)}
+    # Pair each graph input name with its shape; an OrderedDict keeps the
+    # entries in model.graph.input order.
+    shape_dict = collections.OrderedDict(zip([i.name for i in model.graph.input], input_shapes))
 
-    irmod, params = tvm.relay.frontend.from_onnx(model, input_shapes, opset=11)
+    irmod, _params = tvm.relay.frontend.from_onnx(model, shape_dict, opset=11)
     with tvm.relay.build_config(opt_level=opt_level):
         tuning_logfile = os.getenv("AUTOTVM_TUNING_LOG")
         if tuning_logfile:
             with autotvm.apply_history_best(tuning_logfile):
-                graph, lib, params = tvm.relay.build(irmod, target_host=target_host, target=target, params=params)
+                # Do not pass params to relay.build, otherwise they are inlined into the module.
+                lib = tvm.relay.build(irmod, target_host=target_host, target=target)
         else:
-            graph, lib, params = tvm.relay.build(irmod, target_host=target_host, target=target, params=params)
+            lib = tvm.relay.build(irmod, target_host=target_host, target=target)
 
     ctx = tvm.context(target, 0)
-    m = tvm.contrib.graph_runtime.create(graph, lib, ctx)
-    m.set_input(**params)
+    # NOTE: params returned by from_onnx are deliberately not bound here;
+    # presumably the caller feeds them as ordinary inputs at run time -- confirm.
+    m = tvm.contrib.graph_runtime.GraphModule(lib["default"](ctx))
     return m.module
diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc
index 6d586cf..5df9068 100644
--- a/src/runtime/graph/graph_runtime.cc
+++ b/src/runtime/graph/graph_runtime.cc
@@ -120,7 +120,9 @@ void GraphRuntime::SetInputZeroCopy(int index, DLTensor* data_ref) {
 
   // check the consistency of input
   ICHECK_EQ(data_alignment_[eid], details::GetDataAlignment(*data_ref));
-  ICHECK_EQ(reinterpret_cast<size_t>(data_ref->data) % kAllocAlignment, 0);
+  // NOTE(review): the check that data_ref->data is kAllocAlignment-aligned was removed,
+  // presumably because externally supplied (onnxruntime) buffers are not always aligned.
+  // TODO: confirm generated kernels tolerate unaligned input pointers before merging.
   ICHECK_EQ(old_t->ndim, static_cast<size_t>(data_ref->ndim));
   ICHECK_EQ(old_t->ctx.device_type, data_ref->ctx.device_type);
   ICHECK_EQ(old_t->ctx.device_id, data_ref->ctx.device_id);
diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc
index a0f0b76..0b12c1c 100644
--- a/src/target/spirv/build_vulkan.cc
+++ b/src/target/spirv/build_vulkan.cc
@@ -23,7 +23,7 @@
  */
 // Use libspirv for parsing and validating code.
 #include <dmlc/memory_io.h>
-#include <libspirv.h>
+#include <spirv-tools/libspirv.h>
 #include <tvm/tir/transform.h>
 
 #include "../../runtime/vulkan/vulkan_module.h"
diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc
index 90b2eb2..601f74a 100644
--- a/src/target/spirv/intrin_rule_spirv.cc
+++ b/src/target/spirv/intrin_rule_spirv.cc
@@ -20,7 +20,7 @@
 /*!
  * \file intrin_rule_spirv.cc
  */
-#include <GLSL.std.450.h>
+#include <SPIRV/GLSL.std.450.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/expr.h>
diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h
index 8a08048..20e2f1b 100644
--- a/src/target/spirv/ir_builder.h
+++ b/src/target/spirv/ir_builder.h
@@ -34,7 +34,7 @@
 #include <unordered_map>
 #include <utility>
 #include <vector>
-#include <spirv.hpp>
+#include <SPIRV/spirv.hpp>
 // clang-format on
 
 namespace tvm {