You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by jw...@apache.org on 2021/04/09 19:17:22 UTC

[tvm] 01/09: Initial commit for AMD proposal of ONNXRT<>TVM

This is an automated email from the ASF dual-hosted git repository.

jwfromm pushed a commit to branch checkpoint
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit df61188d938af20063c07ace40314a80b7f2dc32
Author: mei-ye <me...@yahoo.com>
AuthorDate: Thu Aug 20 22:56:52 2020 -0700

    Initial commit for AMD proposal of ONNXRT<>TVM
---
 include/tvm/driver/jit_interface.h | 10 +++++++
 src/driver/driver_api.cc           | 58 ++++++++++++++++++++++++++++++++++++++
 src/relay/backend/build_module.cc  | 43 ++++++++++++++++++++++++++++
 3 files changed, 111 insertions(+)

diff --git a/include/tvm/driver/jit_interface.h b/include/tvm/driver/jit_interface.h
new file mode 100644
index 0000000..966d5a8
--- /dev/null
+++ b/include/tvm/driver/jit_interface.h
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/driver/jit_interface.h
+ * \brief Experimental entry points for driving TVM as a JIT from ONNX Runtime.
+ */
+#ifndef TVM_DRIVER_JIT_INTERFACE_H_
+#define TVM_DRIVER_JIT_INTERFACE_H_
+
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/packed_func.h>
+
+#include <string>
+
+// Kept for source compatibility with existing users of EXPORT_DLL. TVM_DLL
+// (tvm/runtime/c_runtime_api.h) is TVM's portable symbol-export macro.
+#ifndef EXPORT_DLL
+#define EXPORT_DLL TVM_DLL
+#endif
+
+#ifdef __cplusplus
+// NOTE(review): these functions take and return C++ types, so C linkage only
+// yields unmangled symbol names (presumably for lookup via dlsym); they are
+// not callable from C code.
+extern "C" {
+
+/*!
+ * \brief Compile a model into a TVM runtime module.
+ * \param onnx_txt Serialized ONNX model text. NOTE(review): not parsed by the
+ *        current implementation yet.
+ * \param target Target string for the generated kernels, e.g. "llvm".
+ * \param target_host Target string for host-side code.
+ * \param opt_level Requested optimization level.
+ * \return The compiled runtime module.
+ */
+EXPORT_DLL tvm::runtime::Module TVMCompile(const std::string& onnx_txt, const std::string& target,
+                                           const std::string& target_host, int opt_level);
+
+/*!
+ * \brief Look up function \p name in \p mod and call it with \p args.
+ * \param mod Module to look the function up in, e.g. one returned by TVMCompile.
+ * \param name Name of the function to call.
+ * \param args Packed arguments forwarded unchanged to the function.
+ * \param ret Receives the function's return value.
+ */
+EXPORT_DLL void TVMRun(tvm::runtime::Module& mod, const std::string& name,
+                       tvm::runtime::TVMArgs& args, tvm::runtime::TVMRetValue* ret);
+
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // TVM_DRIVER_JIT_INTERFACE_H_
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index bbbb7e3..758f019 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -23,13 +23,25 @@
  */
 #include <dmlc/thread_local.h>
 #include <tvm/driver/driver_api.h>
+#include <tvm/driver/jit_interface.h>
+#include <tvm/ir/module.h>
 #include <tvm/ir/transform.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/op_strategy.h>
+#include <tvm/relay/transform.h>
+#include <tvm/relay/type.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/container.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/target/codegen.h>
 #include <tvm/te/operation.h>
 #include <tvm/tir/analysis.h>
 #include <tvm/tir/transform.h>
+#include <topi/generic/injective.h>
+#include <tvm/target/generic_func.h>
 
 #include <algorithm>
 #include <mutex>
@@ -324,3 +336,78 @@ runtime::Module build(const IRModule& funcs, const Target& target, const Target&
 }
 
 }  // namespace tvm
+
+namespace {
+
+/*!
+ * \brief Register the "jit.strategy" strategy for "multiply" at most once per process.
+ *
+ * ir.RegisterOpAttr aborts when the same attribute is registered twice at the
+ * same plevel, and GenericFunc::set_default aborts on a second default unless
+ * allow_override is set, so this must not be repeated on every TVMCompile call.
+ */
+void RegisterJitStrategyOnce() {
+  static std::once_flag once;
+  std::call_once(once, [] {
+    const auto* reg = tvm::runtime::Registry::Get("ir.RegisterOpAttr");
+    if (!reg) LOG(FATAL) << "no _Register";
+
+    const auto* fs = tvm::runtime::Registry::Get("jit.strategy");
+    if (!fs) LOG(FATAL) << "No jit strategy registered.";
+
+    auto fgeneric = tvm::GenericFunc::Get("jit.strategy_generic").set_default(*fs);
+    (*reg)("multiply", "FTVMStrategy", fgeneric, 10);
+    (*reg)("multiply", "TShapeDataDependant", false, 10);
+  });
+}
+
+}  // namespace
+
+tvm::runtime::Module TVMCompile(const std::string& onnx_txt, const std::string& target,
+                                const std::string& target_host, int opt_level) {
+  // TODO(review): onnx_txt is not parsed yet; a fixed test graph is compiled instead:
+  // Y4 = X1 * (X1 * (X1 * (X1 * X1))) with X1 a float32 tensor of shape (1, 6).
+  auto tensor_type = tvm::relay::TensorType({1, 6}, tvm::runtime::DataType::Float(32));
+  auto X1 = tvm::relay::Var("X1", tensor_type);
+  auto mul_op = tvm::relay::Op::Get("multiply");
+  auto mul1 = tvm::relay::Call(mul_op, {X1, X1}, tvm::Attrs(), {});
+  auto mul2 = tvm::relay::Call(mul_op, {X1, mul1}, tvm::Attrs(), {});
+  auto mul3 = tvm::relay::Call(mul_op, {X1, mul2}, tvm::Attrs(), {});
+  auto Y4 = tvm::relay::Call(mul_op, {X1, mul3}, tvm::Attrs(), {});
+  auto func = tvm::relay::Function(tvm::relay::FreeVars(Y4), Y4, tvm::relay::Type(), {});
+
+  RegisterJitStrategyOnce();
+
+  const auto* pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule");
+  if (!pfb) LOG(FATAL) << "relay.build_module._BuildModule is not registered";
+  tvm::runtime::Module build_mod = (*pfb)();
+  auto build_f = build_mod.GetFunction("build", false);
+  auto mod_f = build_mod.GetFunction("get_module", false);
+  auto relay_mod = tvm::IRModule::FromExpr(func);
+
+  // An empty target falls back to the previously hard-coded "llvm"; an empty
+  // target_host, or one equal to target, reuses the device target for host code.
+  tvm::Target tgt = tvm::Target::Create(target.empty() ? "llvm" : target);
+  tvm::Target host = (target_host.empty() || target_host == target)
+                         ? tgt
+                         : tvm::Target::Create(target_host);
+  tvm::Map<tvm::Integer, tvm::Target> targets;
+  targets.Set(0, tgt);
+
+  // Run the Relay build under a PassContext carrying the requested opt_level.
+  auto pass_ctx = tvm::transform::PassContext::Create();
+  pass_ctx->opt_level = opt_level;
+  tvm::With<tvm::transform::PassContext> scope(pass_ctx);
+  build_f(relay_mod, targets, host);
+  tvm::runtime::Module mod = mod_f();
+  return mod;
+}
+
+void TVMRun(tvm::runtime::Module& mod, const std::string& name, tvm::runtime::TVMArgs& args,
+            tvm::runtime::TVMRetValue* ret) {
+  tvm::runtime::PackedFunc f = mod.GetFunction(name);
+  if (f == nullptr) LOG(FATAL) << "Function " << name << " is not found in the module";
+  // TODO(review): *ret is left as produced by the callee; see TVMFuncCall in
+  // c_runtime_api.cc for how return values are post-processed for C callers.
+  f.CallPacked(args, ret);
+}
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index 0884692..3c047e1 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -28,6 +28,10 @@
 #include <tvm/relay/qnn/transform.h>
 #include <tvm/relay/transform.h>
 #include <tvm/runtime/device_api.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/op_strategy.h>
+#include <topi/broadcast.h>
+#include <topi/generic/injective.h>
 
 #include <memory>
 
@@ -553,6 +557,65 @@ runtime::Module RelayBuildCreate() {
   return runtime::Module(exec);
 }
 
+/*!
+ * \brief Strategy function used by the ONNXRT<>TVM JIT prototype for "multiply".
+ *
+ * Returns an OpStrategy with a single implementation whose compute is
+ * topi::multiply of the two inputs and whose schedule is
+ * topi::generic::schedule_injective under the given target. TVMCompile in
+ * src/driver/driver_api.cc installs it as the default of "jit.strategy_generic".
+ */
+TVM_REGISTER_GLOBAL("jit.strategy")
+    .set_body_typed([](const Attrs& attrs, const Array<te::Tensor>& inputs, const Type& out_type,
+                       const Target& target) {
+      FTVMCompute fcompute = [](const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                const Type& out_type) -> Array<te::Tensor> {
+        CHECK_EQ(inputs.size(), 2U) << "multiply expects exactly two inputs";
+        return {topi::multiply(inputs[0], inputs[1])};
+      };
+      FTVMSchedule fschedule = [](const Attrs& attrs, const Array<te::Tensor>& outs,
+                                  const Target& target) {
+        With<Target> target_scope(target);
+        return topi::generic::schedule_injective(target, outs);
+      };
+
+      auto n = make_object<OpStrategyNode>();
+      auto strategy = relay::OpStrategy(std::move(n));
+      strategy.AddImplementation(fcompute, fschedule, "jit.strategy", 10);
+      return strategy;
+    });
+
+/*!
+ * \brief C++ implementation of the "relay.backend.lower_call" hook.
+ *
+ * Evaluates the op's FTVMStrategy and lowers the call with the first
+ * implementation of the first specialization (no implementation selection is
+ * performed). The result is wrapped via relay.backend._make_LoweredOutput.
+ *
+ * NOTE(review): the Python package presumably registers the same global in
+ * python/tvm/relay/backend/compile_engine.py. If it does so without
+ * override=True, it will conflict with this static registration; verify before
+ * shipping this in builds that also load the Python frontend.
+ */
+TVM_REGISTER_GLOBAL("relay.backend.lower_call")
+    .set_body_typed([](const relay::Call& call, const Array<te::Tensor>& inputs,
+                       const Target& target) {
+      static auto fstrategy = Op::GetAttrMap<relay::FTVMStrategy>("FTVMStrategy");
+      Op op = Downcast<Op>(call->op);
+      auto out_type = call->checked_type();
+      OpStrategy strategy = fstrategy[op](call->attrs, inputs, out_type, target);
+      CHECK(!strategy->specializations.empty() &&
+            !strategy->specializations[0]->implementations.empty())
+          << "No implementation in the strategy of " << op->name;
+      auto impl = strategy->specializations[0]->implementations[0];
+      auto outs = impl.Compute(call->attrs, inputs, out_type);
+      auto f = runtime::Registry::Get("relay.backend._make_LoweredOutput");
+      if (!f) {
+        LOG(FATAL) << "relay.backend._make_LoweredOutput is not registered";
+      }
+      return (*f)(outs, impl);
+    });
+
 TVM_REGISTER_GLOBAL("relay.build_module._BuildModule").set_body([](TVMArgs args, TVMRetValue* rv) {
   *rv = RelayBuildCreate();
 });