You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by jw...@apache.org on 2021/04/09 19:17:22 UTC
[tvm] 01/09: Initial commit for AMD proposal of ONNXRT<>TVM
This is an automated email from the ASF dual-hosted git repository.
jwfromm pushed a commit to branch checkpoint
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit df61188d938af20063c07ace40314a80b7f2dc32
Author: mei-ye <me...@yahoo.com>
AuthorDate: Thu Aug 20 22:56:52 2020 -0700
Initial commit for AMD proposal of ONNXRT<>TVM
---
include/tvm/driver/jit_interface.h | 10 +++++++
src/driver/driver_api.cc | 58 ++++++++++++++++++++++++++++++++++++++
src/relay/backend/build_module.cc | 43 ++++++++++++++++++++++++++++
3 files changed, 111 insertions(+)
diff --git a/include/tvm/driver/jit_interface.h b/include/tvm/driver/jit_interface.h
new file mode 100644
index 0000000..966d5a8
--- /dev/null
+++ b/include/tvm/driver/jit_interface.h
@@ -0,0 +1,10 @@
+// JIT interface exposed to ONNX Runtime: compile a model with TVM and run
+// functions from the resulting runtime module.
+#pragma once
+
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/packed_func.h>
+
+#include <string>
+
+#define EXPORT_DLL __attribute__((visibility("default")))
+
+#ifdef __cplusplus
+extern "C" {
+ // Compiles a model for `target`/`target_host` at `opt_level` and returns the
+ // compiled runtime module. NOTE(review): in this checkpoint the
+ // implementation builds a fixed test graph and does not parse `onnx_txt`.
+ EXPORT_DLL tvm::runtime::Module TVMCompile(const std::string& onnx_txt, const std::string& target, const std::string& target_host, int opt_level);
+ // Looks up packed function `name` in `mod` and invokes it with `args`;
+ // the call's result is written through `ret`.
+ EXPORT_DLL void TVMRun(tvm::runtime::Module& mod, const std::string& name, tvm::runtime::TVMArgs& args, tvm::runtime::TVMRetValue* ret);
+} // extern "C"
+#endif
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index bbbb7e3..758f019 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -23,13 +23,25 @@
*/
#include <dmlc/thread_local.h>
#include <tvm/driver/driver_api.h>
+#include <tvm/driver/jit_interface.h>
+#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/op_strategy.h>
+#include <tvm/relay/transform.h>
+#include <tvm/relay/type.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/codegen.h>
#include <tvm/te/operation.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/transform.h>
+#include <topi/generic/injective.h>
+#include <tvm/target/generic_func.h>
#include <algorithm>
#include <mutex>
@@ -324,3 +336,49 @@ runtime::Module build(const IRModule& funcs, const Target& target, const Target&
}
} // namespace tvm
+
+
+// Compiles a Relay module for the requested target and returns the resulting
+// runtime module.
+//
+// onnx_txt    - ONNX model text. NOTE(review): not parsed in this checkpoint;
+//               a fixed x*x*x*x*x "multiply" chain over a [1, 6] float32
+//               tensor is built instead -- TODO confirm intended.
+// target      - TVM target string (falls back to "llvm" when empty; the
+//               original hard-coded "llvm" and ignored this parameter).
+// target_host - host target string; reuses `target` when empty or equal.
+// opt_level   - requested optimization level. NOTE(review): not yet wired
+//               into a PassContext in this checkpoint.
+tvm::runtime::Module TVMCompile(const std::string& onnx_txt, const std::string& target, const std::string& target_host, int opt_level)
+{
+ // Placeholder graph: Y4 = X1 * X1 * X1 * X1 * X1.
+ auto tensor_type = tvm::relay::TensorType({1, 6}, tvm::runtime::DataType::Float(32));
+ auto X1 = tvm::relay::Var("X1", tensor_type);
+ auto mul_op = tvm::relay::Op::Get("multiply");
+ auto mul1 = tvm::relay::Call(mul_op, {X1, X1}, tvm::Attrs(), {});
+ auto mul2 = tvm::relay::Call(mul_op, {X1, mul1}, tvm::Attrs(), {});
+ auto mul3 = tvm::relay::Call(mul_op, {X1, mul2}, tvm::Attrs(), {});
+ auto Y4 = tvm::relay::Call(mul_op, {X1, mul3}, tvm::Attrs(), {});
+ auto func = tvm::relay::Function(tvm::relay::FreeVars(Y4), Y4, tvm::relay::Type(), {});
+
+ // Install the JIT strategy (registered in build_module.cc as "jit.strategy")
+ // as the FTVMStrategy attribute of "multiply".
+ auto reg = tvm::runtime::Registry::Get("ir.RegisterOpAttr");
+ if (!reg)
+ LOG(FATAL) << "no _Register";
+
+ auto fs = tvm::runtime::Registry::Get("jit.strategy");
+ if (!fs)
+ LOG(FATAL) << "No jit strategy registered.";
+
+ auto fgeneric = tvm::GenericFunc::Get("jit.strategy_generic").set_default(*fs);
+ (*reg)("multiply", "FTVMStrategy", fgeneric, 10);
+ (*reg)("multiply", "TShapeDataDependant", false, 10);
+
+ // Drive compilation through the FFI-exposed relay BuildModule.
+ auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule");
+ tvm::runtime::Module build_mod = (*pfb)();
+ auto build_f = build_mod.GetFunction("build", false);
+ auto mod_f = build_mod.GetFunction("get_module", false);
+ auto relay_mod = tvm::IRModule::FromExpr(func);
+ tvm::Map<tvm::Integer, tvm::Target> targets;
+ // Honor the caller-supplied target instead of hard-coding "llvm".
+ tvm::Target tgt = target.empty() ? tvm::Target::Create("llvm") : tvm::Target::Create(target);
+ targets.Set(0, tgt);
+ tvm::Target host = (target_host.empty() || target == target_host) ? tgt : tvm::Target::Create(target_host);
+ build_f(relay_mod, targets, host);
+ tvm::runtime::Module mod = mod_f();
+ return mod;
+}
+
+// Looks up packed function `name` in `mod` and invokes it with `args`;
+// the callee's result (if any) is written through `ret`.
+void TVMRun(tvm::runtime::Module& mod, const std::string& name, tvm::runtime::TVMArgs& args, tvm::runtime::TVMRetValue* ret)
+{
+ auto entry = mod.GetFunction(name);
+ entry.CallPacked(args, ret);
+ // TODO: post-process the return value (see TVMFuncCall in c_runtime_api.cc).
+}
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index 0884692..3c047e1 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -28,6 +28,10 @@
#include <tvm/relay/qnn/transform.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/device_api.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/op_strategy.h>
+#include <topi/broadcast.h>
+#include <topi/generic/injective.h>
#include <memory>
@@ -553,6 +557,45 @@ runtime::Module RelayBuildCreate() {
return runtime::Module(exec);
}
+#if 1
+// Compute + schedule strategy for the JIT path: driver_api.cc fetches this
+// global via Registry::Get("jit.strategy") and installs it as the
+// FTVMStrategy attribute of "multiply".
+TVM_REGISTER_GLOBAL("jit.strategy")
+ .set_body_typed([](const Attrs& attrs, const Array<te::Tensor>& inputs, const Type& out_type,
+ const Target& target) {
+ // Compute: elementwise multiply of exactly two input tensors.
+ FTVMCompute fcompute = [](const Attrs& attrs, const Array<te::Tensor>& inputs,
+ const Type& out_type) -> Array<te::Tensor> {
+ CHECK_EQ(inputs.size(), 2U);
+ return {topi::multiply(inputs[0], inputs[1])};
+ };
+ // Schedule: generic injective schedule built under the given target scope.
+ FTVMSchedule fschedule = [](const Attrs& attrs, const Array<te::Tensor>& outs,
+ const Target& target) {
+ With<Target> target_scope(target);
+ return topi::generic::schedule_injective(target, outs);
+ };
+
+ // Single implementation, priority 10.
+ auto n = make_object<OpStrategyNode>();
+ auto strategy = relay::OpStrategy(std::move(n));
+ strategy.AddImplementation(fcompute, fschedule, "jit.strategy", 10);
+ return strategy;
+});
+
+
+// Lowers a relay Call to TE tensors by invoking the op's registered
+// FTVMStrategy, then wraps the computed outputs into a LoweredOutput via the
+// FFI helper "relay.backend._make_LoweredOutput".
+TVM_REGISTER_GLOBAL("relay.backend.lower_call")
+ .set_body_typed([](const relay::Call& call, const Array<te::Tensor>& inputs,
+ const Target& target) {
+ static auto fstrategy = Op::GetAttrMap<relay::FTVMStrategy>("FTVMStrategy");
+ Op op = Downcast<Op>(call->op);
+ auto out_type = call->checked_type();
+ OpStrategy strategy = fstrategy[op](call->attrs, inputs, out_type, target);
+ // NOTE(review): unconditionally takes specializations[0]->implementations[0];
+ // assumes the strategy registered at least one implementation -- TODO confirm.
+ auto impl = strategy->specializations[0]->implementations[0];
+ auto outs = impl.Compute(call->attrs, inputs, out_type);
+ auto f = runtime::Registry::Get("relay.backend._make_LoweredOutput");
+ if (!f) {
+ LOG(FATAL) << "relay.backend._make_LoweredOutput is not registered";
+ }
+ return (*f)(outs, impl);
+});
+#endif
+
// FFI entry point constructing a RelayBuildModule; fetched via
// Registry::Get("relay.build_module._BuildModule") in TVMCompile.
TVM_REGISTER_GLOBAL("relay.build_module._BuildModule").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = RelayBuildCreate();
});