You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/11/22 15:08:01 UTC

[GitHub] [incubator-tvm] tqchen commented on a change in pull request #6917: Add Relay option to link parameters into runtime Modules

tqchen commented on a change in pull request #6917:
URL: https://github.com/apache/incubator-tvm/pull/6917#discussion_r528343180



##########
File path: include/tvm/tir/function.h
##########
@@ -150,6 +151,31 @@ class PrimFunc : public BaseFunc {
   TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode);
 };
 
+class LinkedParamNode : public Object {
+ public:
+  /*! \brief Unique numeric identifier used by runtimes to lookup this parameter. */
+  int64_t id;
+
+  /*! \brief Parameter data which should get linked into the final module. */
+  ::tvm::runtime::NDArray param;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("id", &id);
+    v->Visit("param", &param);
+  }
+
+  static constexpr const char* _type_key = "tir.LinkedParam";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LinkedParamNode, Object);
+};
+
+class LinkedParam : public ObjectRef {
+ public:
+  LinkedParam(int64_t id, ::tvm::runtime::NDArray param);

Review comment:
       Add TVM_DLL so that LinkedParam can be constructed from other libraries.

##########
File path: src/runtime/graph/graph_runtime.cc
##########
@@ -244,6 +249,39 @@ void GraphRuntime::ShareParams(const GraphRuntime& other, dmlc::Stream* strm) {
   this->SetupOpExecs();
 }
 
+void GraphRuntime::PreAllocatedDLTensorDeleter(DLManagedTensor* tensor) {

Review comment:
       This function is not used anywhere. We likely need an NDArray::Container deleter instead.

##########
File path: src/target/llvm/codegen_params.cc
##########
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file codegen_params.cc
+ */
+#ifdef TVM_LLVM_VERSION
+
+#include "codegen_params.h"
+
+#include <memory>
+#include <vector>
+
+namespace tvm {
+namespace codegen {
+
+namespace {
+class DLManagedTensorDeleter {
+ public:
+  void operator()(DLManagedTensor* ptr) { ptr->deleter(ptr); }
+};
+}  // namespace
+
+llvm::ConstantArray* NDArrayToLLVMArray(llvm::LLVMContext* ctx, ::tvm::runtime::NDArray arr) {
+  llvm::Type* element_type = nullptr;
+

Review comment:
       I wonder whether we should instead query LLVM's target machine for the endianness, and do the byte swap ourselves.
   
   This would allow us to generalize to data types beyond those recognized by LLVM, as long as we know the target machine's layout.
   
   See https://github.com/apache/incubator-tvm/blob/main/src/runtime/rpc/rpc_endpoint.cc#L429

##########
File path: src/relay/backend/build_module.cc
##########
@@ -443,16 +456,35 @@ class RelayBuildModule : public runtime::ModuleNode {
 
     auto lowered_funcs = graph_codegen_->GetIRModule();
 
-    // When there is no lowered_funcs due to reasons such as optimization.
-    if (lowered_funcs.size() == 0) {
-      Target target_host = GetTargetHost();
+    Target target_host = GetTargetHost();
+    // If no target_host has been set, we choose a default one, which is
+    // llvm if "codegen.LLVMModuleCreate" is accessible.
+    const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate");
+    if (!target_host.defined()) target_host = (pf != nullptr) ? Target("llvm") : Target("stackvm");
+
+    if (target_host->GetAttr<Bool>("link-params").value_or(Bool(false))) {

Review comment:
       Add comment to the code block. Generate a placeholder function that attaches linked params as its arguments

##########
File path: include/tvm/tir/function.h
##########
@@ -150,6 +151,31 @@ class PrimFunc : public BaseFunc {
   TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode);
 };
 
+class LinkedParamNode : public Object {

Review comment:
       Document this class name as well
   

##########
File path: src/runtime/graph/graph_runtime.cc
##########
@@ -64,14 +64,19 @@ void GraphRuntime::Run() {
  * processor.
  * \param ctxs The context of the host and devices where graph nodes will be
  * executed on.
+ * \param lookup_linked_param_func Linked parameter lookup function.
  */
 void GraphRuntime::Init(const std::string& graph_json, tvm::runtime::Module module,
-                        const std::vector<TVMContext>& ctxs) {
+                        const std::vector<TVMContext>& ctxs, PackedFunc lookup_linked_param_func) {

Review comment:
       Consider using TypedPackedFunc to add a type signature.

##########
File path: src/relay/backend/graph_runtime_codegen.cc
##########
@@ -582,7 +596,13 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode {
       return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
         String key = args[0];
         ICHECK_GT(this->output_.params.count(key), 0);
-        *rv = this->output_.params[key];
+        *rv = this->output_.params[key].second;
+      });
+    } else if (name == "get_param_id") {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        String key = args[0];
+        ICHECK_GT(this->output_.params.count(key), 0);

Review comment:
       Use `auto it = this->output_.params.find(key)` and return `it->first`, to avoid a double lookup of the hash table (`count` then `[]`).

##########
File path: src/runtime/rpc/rpc_module.cc
##########
@@ -36,6 +37,44 @@
 namespace tvm {
 namespace runtime {
 
+// deleter of RPC remote array
+static void RemoteNDArrayDeleter(Object* obj) {
+  auto* ptr = static_cast<NDArray::Container*>(obj);
+  RemoteSpace* space = static_cast<RemoteSpace*>(ptr->dl_tensor.data);
+  space->sess->FreeHandle(ptr->manager_ctx, kTVMNDArrayHandle);
+  delete space;
+  delete ptr;
+}
+
+/*!
+ * \brief Build a local NDArray with remote backing storage.
+ * \param sess the RPCSession which owns the given handle.
+ * \param handle A pointer valid on the remote end which should form the `data` field of the
+ *     underlying DLTensor.
+ * \param template_tensor An empty DLTensor whose shape and dtype fields are used to fill the newly
+ *     created array. Needed because it's difficult to pass a shape vector as a PackedFunc arg.
+ * \param ctx Remote context used with this tensor. Must have non-zero RPCSessMask.
+ * \param deleter A function invoked when the local NDArray object is no longer used. If `handle`
+ *      needs to be explicitly deleted after the NDArray is freed, this function should do that.
+ * \param deleter_ctx An opaque pointer passed to deleter to identify the tensor being deleted.
+ */
+NDArray NDArrayFromRemoteOpaqueHandle(std::shared_ptr<RPCSession> sess, void* handle,
+                                      DLTensor* template_tensor, TVMContext ctx,
+                                      ADTObj::FDeleter deleter, void* deleter_ctx) {

Review comment:
       It is a bit strange that a deleter is passed in. There is less flexibility in terms of ctx and interface. Consider inlining the call directly here.
   
   In the case where the remote memory is not deleted (static memory), the deleter still needs to delete the NDArray::Container.
   
   A better way might be to pass a flag, to indicate whether it is remote or static

##########
File path: include/tvm/tir/function.h
##########
@@ -150,6 +151,31 @@ class PrimFunc : public BaseFunc {
   TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode);
 };
 
+class LinkedParamNode : public Object {
+ public:
+  /*! \brief Unique numeric identifier used by runtimes to lookup this parameter. */
+  int64_t id;
+
+  /*! \brief Parameter data which should get linked into the final module. */
+  ::tvm::runtime::NDArray param;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("id", &id);
+    v->Visit("param", &param);
+  }
+
+  static constexpr const char* _type_key = "tir.LinkedParam";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LinkedParamNode, Object);
+};
+
+class LinkedParam : public ObjectRef {

Review comment:
       Document the class as a managed reference to LinkedParamNode.

##########
File path: src/runtime/graph/graph_runtime.cc
##########
@@ -244,6 +249,39 @@ void GraphRuntime::ShareParams(const GraphRuntime& other, dmlc::Stream* strm) {
   this->SetupOpExecs();
 }
 
+void GraphRuntime::PreAllocatedDLTensorDeleter(DLManagedTensor* tensor) {
+  // ctx is the DLTensor which needs to get deleted. The data member points to global const memory.
+  delete reinterpret_cast<DLTensor*>(tensor);
+}
+
+void GraphRuntime::DefaultLookupLinkedParam(TVMArgs args, TVMRetValue* rv) {
+  Module mod = args[0];
+  int64_t storage_id = args[1];
+  DLTensor* template_tensor = args[2];
+  TVMContext ctx = args[3];
+  // Get pre-linked parameter lookup function, if it was generated. When pf == nullptr, no linked
+  // params are present.
+  tvm::runtime::PackedFunc pf =

Review comment:
       Shall we cache the lookup in the graph runtime?

##########
File path: src/runtime/graph/graph_runtime.cc
##########
@@ -64,14 +64,19 @@ void GraphRuntime::Run() {
  * processor.
  * \param ctxs The context of the host and devices where graph nodes will be
  * executed on.
+ * \param lookup_linked_param_func Linked parameter lookup function.
  */
 void GraphRuntime::Init(const std::string& graph_json, tvm::runtime::Module module,
-                        const std::vector<TVMContext>& ctxs) {
+                        const std::vector<TVMContext>& ctxs, PackedFunc lookup_linked_param_func) {
   std::istringstream is(graph_json);
   dmlc::JSONReader reader(&is);
   this->Load(&reader);
   module_ = module;
   ctxs_ = ctxs;
+  lookup_linked_param_ = lookup_linked_param_func;
+  if (lookup_linked_param_ == nullptr) {
+    lookup_linked_param_ = PackedFunc(&GraphRuntime::DefaultLookupLinkedParam);

Review comment:
       & is not necessary

##########
File path: src/runtime/graph/graph_runtime.cc
##########
@@ -244,6 +249,39 @@ void GraphRuntime::ShareParams(const GraphRuntime& other, dmlc::Stream* strm) {
   this->SetupOpExecs();
 }
 
+void GraphRuntime::PreAllocatedDLTensorDeleter(DLManagedTensor* tensor) {
+  // ctx is the DLTensor which needs to get deleted. The data member points to global const memory.
+  delete reinterpret_cast<DLTensor*>(tensor);
+}
+
+void GraphRuntime::DefaultLookupLinkedParam(TVMArgs args, TVMRetValue* rv) {
+  Module mod = args[0];
+  int64_t storage_id = args[1];
+  DLTensor* template_tensor = args[2];
+  TVMContext ctx = args[3];
+  // Get pre-linked parameter lookup function, if it was generated. When pf == nullptr, no linked
+  // params are present.
+  tvm::runtime::PackedFunc pf =
+      mod.GetFunction(::tvm::runtime::symbol::tvm_lookup_linked_param, true);
+  if (pf == nullptr) {
+    *rv = nullptr;
+    return;
+  }
+
+  TVMRetValue opaque_handle = pf(storage_id);
+  if (opaque_handle.type_code() == kTVMNullptr) {
+    *rv = nullptr;
+    return;
+  }
+
+  std::vector<int64_t> shape_vec{template_tensor->shape,
+                                 template_tensor->shape + template_tensor->ndim};
+
+  std::unique_ptr<NDArray::Container> container{new NDArray::Container(

Review comment:
       Need to set the deleter. Refer to other NDArray allocation examples.

##########
File path: src/target/llvm/codegen_llvm.cc
##########
@@ -184,6 +188,104 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
   }
 }
 
+void CodeGenLLVM::LinkParameters(const Map<String, LinkedParam> params) {
+  // It would be nice to de-dupe these declarations frm src/tir/transforms/make_packed_api.cc,
+  // but they are at a different layer in the compiler...
+  std::vector<llvm::Type*> param_types;
+  // args
+  param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace()));
+  // tcodes
+  param_types.push_back(t_int_->getPointerTo(GetGlobalAddressSpace()));
+  // num_args
+  param_types.push_back(t_int_);
+  // ret_args
+  param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace()));
+  // ret_tcodes
+  param_types.push_back(t_int_->getPointerTo(GetGlobalAddressSpace()));
+  // resource_handle
+  param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace()));
+
+  // TODO(tvm-team):
+  // Update the function type to respect the ret_type field of f.
+  // Once we allow more flexibility in the PrimFunc.
+  llvm::FunctionType* ftype = llvm::FunctionType::get(t_int_, param_types, false);
+
+  llvm::Function* function =
+      llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
+                             ::tvm::runtime::symbol::tvm_lookup_linked_param, module_.get());
+  function->setCallingConv(llvm::CallingConv::C);
+  function->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
+
+  llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function);
+  builder_->SetInsertPoint(entry);
+  std::vector<llvm::Value*> zero_index_list{llvm::ConstantInt::get(t_int32_, 0)};
+  std::vector<llvm::Value*> zero_array_index_list{llvm::ConstantInt::get(t_int32_, 0),
+                                                  llvm::ConstantInt::get(t_int32_, 0)};
+  auto args_array = builder_->CreateBitCast(
+#if TVM_LLVM_VERSION >= 50
+      &function->arg_begin()[0],
+#else
+      &(*(function->arg_begin())),
+#endif
+      llvm::ArrayType::get(t_void_->getPointerTo(GetGlobalAddressSpace()), 1));
+  llvm::Value* sid = builder_->CreateBitCast(
+      builder_->CreateLoad(t_void_->getPointerTo(GetGlobalAddressSpace()),
+                           builder_->CreateInBoundsGEP(args_array, zero_index_list)),
+      t_int64_);
+
+  llvm::BasicBlock* default_block = llvm::BasicBlock::Create(*ctx_, "default_block", function);
+  llvm::SwitchInst* switch_inst = builder_->CreateSwitch(sid, default_block, params.size() + 1);
+
+  builder_->SetInsertPoint(default_block);
+  {
+    auto ret_types_array = builder_->CreateBitCast(
+#if TVM_LLVM_VERSION >= 50
+        &function->arg_begin()[4],
+#else
+        &(*(std::next(function->arg_begin(), 4))),
+#endif
+        llvm::ArrayType::get(t_int_, 1)->getPointerTo());
+
+    builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMNullptr),
+                          builder_->CreateGEP(ret_types_array, zero_array_index_list));
+    builder_->CreateRet(ConstInt32(kTvmErrorNoError));
+  }
+
+  llvm::raw_os_ostream os{std::cout};

Review comment:
       Add comment: add data to the global section

##########
File path: src/target/llvm/codegen_llvm.cc
##########
@@ -184,6 +188,104 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
   }
 }
 
+void CodeGenLLVM::LinkParameters(const Map<String, LinkedParam> params) {
+  // It would be nice to de-dupe these declarations frm src/tir/transforms/make_packed_api.cc,
+  // but they are at a different layer in the compiler...
+  std::vector<llvm::Type*> param_types;
+  // args
+  param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace()));
+  // tcodes
+  param_types.push_back(t_int_->getPointerTo(GetGlobalAddressSpace()));
+  // num_args
+  param_types.push_back(t_int_);
+  // ret_args
+  param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace()));
+  // ret_tcodes
+  param_types.push_back(t_int_->getPointerTo(GetGlobalAddressSpace()));
+  // resource_handle
+  param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace()));
+
+  // TODO(tvm-team):
+  // Update the function type to respect the ret_type field of f.

Review comment:
       This TODO is not necessary, since the return value is passed via ret_tcode and ret_args. For a PackedFunc the return code is always int.

##########
File path: src/target/llvm/codegen_llvm.cc
##########
@@ -184,6 +188,104 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
   }
 }
 
+void CodeGenLLVM::LinkParameters(const Map<String, LinkedParam> params) {
+  // It would be nice to de-dupe these declarations frm src/tir/transforms/make_packed_api.cc,
+  // but they are at a different layer in the compiler...
+  std::vector<llvm::Type*> param_types;
+  // args
+  param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace()));
+  // tcodes
+  param_types.push_back(t_int_->getPointerTo(GetGlobalAddressSpace()));
+  // num_args
+  param_types.push_back(t_int_);
+  // ret_args
+  param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace()));
+  // ret_tcodes
+  param_types.push_back(t_int_->getPointerTo(GetGlobalAddressSpace()));
+  // resource_handle
+  param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace()));
+
+  // TODO(tvm-team):
+  // Update the function type to respect the ret_type field of f.
+  // Once we allow more flexibility in the PrimFunc.
+  llvm::FunctionType* ftype = llvm::FunctionType::get(t_int_, param_types, false);
+
+  llvm::Function* function =
+      llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
+                             ::tvm::runtime::symbol::tvm_lookup_linked_param, module_.get());
+  function->setCallingConv(llvm::CallingConv::C);
+  function->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
+
+  llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function);
+  builder_->SetInsertPoint(entry);
+  std::vector<llvm::Value*> zero_index_list{llvm::ConstantInt::get(t_int32_, 0)};
+  std::vector<llvm::Value*> zero_array_index_list{llvm::ConstantInt::get(t_int32_, 0),
+                                                  llvm::ConstantInt::get(t_int32_, 0)};
+  auto args_array = builder_->CreateBitCast(
+#if TVM_LLVM_VERSION >= 50
+      &function->arg_begin()[0],
+#else
+      &(*(function->arg_begin())),
+#endif
+      llvm::ArrayType::get(t_void_->getPointerTo(GetGlobalAddressSpace()), 1));
+  llvm::Value* sid = builder_->CreateBitCast(
+      builder_->CreateLoad(t_void_->getPointerTo(GetGlobalAddressSpace()),
+                           builder_->CreateInBoundsGEP(args_array, zero_index_list)),
+      t_int64_);
+
+  llvm::BasicBlock* default_block = llvm::BasicBlock::Create(*ctx_, "default_block", function);
+  llvm::SwitchInst* switch_inst = builder_->CreateSwitch(sid, default_block, params.size() + 1);
+
+  builder_->SetInsertPoint(default_block);
+  {
+    auto ret_types_array = builder_->CreateBitCast(
+#if TVM_LLVM_VERSION >= 50
+        &function->arg_begin()[4],
+#else
+        &(*(std::next(function->arg_begin(), 4))),
+#endif
+        llvm::ArrayType::get(t_int_, 1)->getPointerTo());
+
+    builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMNullptr),
+                          builder_->CreateGEP(ret_types_array, zero_array_index_list));

Review comment:
       Consider using CreateInBoundsGEP instead and passing in a constant (instead of an array): https://github.com/apache/incubator-tvm/blob/main/src/target/llvm/codegen_llvm.cc#L660
   
   

##########
File path: src/target/llvm/codegen_llvm.cc
##########
@@ -184,6 +188,104 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
   }
 }
 
+void CodeGenLLVM::LinkParameters(const Map<String, LinkedParam> params) {
+  // It would be nice to de-dupe these declarations frm src/tir/transforms/make_packed_api.cc,

Review comment:
       It is possible to lift more of this into make_packed_api later, if we have intrinsics like
   
   ```
   value = tir.lookup_linked_param_in_attr(id)
   
   ```
   
   The next steps should be supported once we have return value in PackedFunc. cc @ZihengJiang 

##########
File path: src/target/llvm/codegen_llvm.cc
##########
@@ -184,6 +188,104 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
   }
 }
 
+void CodeGenLLVM::LinkParameters(const Map<String, LinkedParam> params) {
+  // It would be nice to de-dupe these declarations frm src/tir/transforms/make_packed_api.cc,
+  // but they are at a different layer in the compiler...
+  std::vector<llvm::Type*> param_types;
+  // args
+  param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace()));
+  // tcodes
+  param_types.push_back(t_int_->getPointerTo(GetGlobalAddressSpace()));
+  // num_args
+  param_types.push_back(t_int_);
+  // ret_args
+  param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace()));
+  // ret_tcodes
+  param_types.push_back(t_int_->getPointerTo(GetGlobalAddressSpace()));
+  // resource_handle
+  param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace()));
+
+  // TODO(tvm-team):
+  // Update the function type to respect the ret_type field of f.
+  // Once we allow more flexibility in the PrimFunc.
+  llvm::FunctionType* ftype = llvm::FunctionType::get(t_int_, param_types, false);
+
+  llvm::Function* function =
+      llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
+                             ::tvm::runtime::symbol::tvm_lookup_linked_param, module_.get());
+  function->setCallingConv(llvm::CallingConv::C);
+  function->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
+
+  llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function);
+  builder_->SetInsertPoint(entry);
+  std::vector<llvm::Value*> zero_index_list{llvm::ConstantInt::get(t_int32_, 0)};
+  std::vector<llvm::Value*> zero_array_index_list{llvm::ConstantInt::get(t_int32_, 0),
+                                                  llvm::ConstantInt::get(t_int32_, 0)};
+  auto args_array = builder_->CreateBitCast(
+#if TVM_LLVM_VERSION >= 50
+      &function->arg_begin()[0],
+#else
+      &(*(function->arg_begin())),
+#endif
+      llvm::ArrayType::get(t_void_->getPointerTo(GetGlobalAddressSpace()), 1));
+  llvm::Value* sid = builder_->CreateBitCast(
+      builder_->CreateLoad(t_void_->getPointerTo(GetGlobalAddressSpace()),
+                           builder_->CreateInBoundsGEP(args_array, zero_index_list)),
+      t_int64_);
+
+  llvm::BasicBlock* default_block = llvm::BasicBlock::Create(*ctx_, "default_block", function);
+  llvm::SwitchInst* switch_inst = builder_->CreateSwitch(sid, default_block, params.size() + 1);
+
+  builder_->SetInsertPoint(default_block);
+  {
+    auto ret_types_array = builder_->CreateBitCast(
+#if TVM_LLVM_VERSION >= 50
+        &function->arg_begin()[4],
+#else
+        &(*(std::next(function->arg_begin(), 4))),
+#endif
+        llvm::ArrayType::get(t_int_, 1)->getPointerTo());
+
+    builder_->CreateStore(llvm::ConstantInt::get(t_int_, kTVMNullptr),
+                          builder_->CreateGEP(ret_types_array, zero_array_index_list));
+    builder_->CreateRet(ConstInt32(kTvmErrorNoError));
+  }
+
+  llvm::raw_os_ostream os{std::cout};
+
+  for (auto kv : params) {
+    auto array = NDArrayToLLVMArray(ctx_, kv.second->param);
+    std::string symbol_name = std::string(::tvm::runtime::symbol::tvm_param_prefix) + kv.first;
+    llvm::GlobalVariable* param_symbol = new llvm::GlobalVariable(
+        *module_, array->getType(), true, llvm::GlobalValue::InternalLinkage, array, symbol_name);
+
+    llvm::BasicBlock* case_block = llvm::BasicBlock::Create(*ctx_, "case_" + symbol_name, function);
+    switch_inst->addCase(
+        llvm::cast<llvm::ConstantInt>(llvm::ConstantInt::get(t_int64_, kv.second->id)), case_block);
+    builder_->SetInsertPoint(case_block);
+    auto retval_array = builder_->CreateBitCast(
+#if TVM_LLVM_VERSION >= 50

Review comment:
       Consider moving the argument cast to the beginning.

##########
File path: src/target/llvm/codegen_llvm.cc
##########
@@ -184,6 +188,104 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
   }
 }
 
+void CodeGenLLVM::LinkParameters(const Map<String, LinkedParam> params) {
+  // It would be nice to de-dupe these declarations frm src/tir/transforms/make_packed_api.cc,
+  // but they are at a different layer in the compiler...
+  std::vector<llvm::Type*> param_types;
+  // args
+  param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace()));
+  // tcodes
+  param_types.push_back(t_int_->getPointerTo(GetGlobalAddressSpace()));
+  // num_args
+  param_types.push_back(t_int_);
+  // ret_args
+  param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace()));
+  // ret_tcodes
+  param_types.push_back(t_int_->getPointerTo(GetGlobalAddressSpace()));
+  // resource_handle
+  param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace()));
+
+  // TODO(tvm-team):
+  // Update the function type to respect the ret_type field of f.
+  // Once we allow more flexibility in the PrimFunc.
+  llvm::FunctionType* ftype = llvm::FunctionType::get(t_int_, param_types, false);
+
+  llvm::Function* function =
+      llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
+                             ::tvm::runtime::symbol::tvm_lookup_linked_param, module_.get());
+  function->setCallingConv(llvm::CallingConv::C);
+  function->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
+
+  llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function);
+  builder_->SetInsertPoint(entry);
+  std::vector<llvm::Value*> zero_index_list{llvm::ConstantInt::get(t_int32_, 0)};
+  std::vector<llvm::Value*> zero_array_index_list{llvm::ConstantInt::get(t_int32_, 0),
+                                                  llvm::ConstantInt::get(t_int32_, 0)};
+  auto args_array = builder_->CreateBitCast(
+#if TVM_LLVM_VERSION >= 50
+      &function->arg_begin()[0],
+#else
+      &(*(function->arg_begin())),
+#endif
+      llvm::ArrayType::get(t_void_->getPointerTo(GetGlobalAddressSpace()), 1));
+  llvm::Value* sid = builder_->CreateBitCast(
+      builder_->CreateLoad(t_void_->getPointerTo(GetGlobalAddressSpace()),
+                           builder_->CreateInBoundsGEP(args_array, zero_index_list)),
+      t_int64_);
+
+  llvm::BasicBlock* default_block = llvm::BasicBlock::Create(*ctx_, "default_block", function);

Review comment:
       default => end?

##########
File path: src/target/llvm/codegen_params.cc
##########
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file codegen_params.cc
+ */
+#ifdef TVM_LLVM_VERSION
+
+#include "codegen_params.h"
+
+#include <memory>
+#include <vector>
+
+namespace tvm {
+namespace codegen {
+
+namespace {
+class DLManagedTensorDeleter {
+ public:
+  void operator()(DLManagedTensor* ptr) { ptr->deleter(ptr); }
+};
+}  // namespace
+
+llvm::ConstantArray* NDArrayToLLVMArray(llvm::LLVMContext* ctx, ::tvm::runtime::NDArray arr) {
+  llvm::Type* element_type = nullptr;
+
+  auto arr_type = arr.DataType();
+  CHECK_EQ(arr_type.lanes(), 1) << "CodegenParams: only support generating 1-lane parameters; saw "
+                                << arr_type.lanes();
+
+  auto shape = arr.Shape();
+  int num_elements = 1;

Review comment:
       Need to CHECK that arr IsContiguous() and is on the CPU.

##########
File path: src/target/llvm/codegen_llvm.cc
##########
@@ -184,6 +188,104 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
   }
 }
 
+void CodeGenLLVM::LinkParameters(const Map<String, LinkedParam> params) {
+  // It would be nice to de-dupe these declarations frm src/tir/transforms/make_packed_api.cc,
+  // but they are at a different layer in the compiler...
+  std::vector<llvm::Type*> param_types;
+  // args
+  param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace()));
+  // tcodes
+  param_types.push_back(t_int_->getPointerTo(GetGlobalAddressSpace()));
+  // num_args
+  param_types.push_back(t_int_);
+  // ret_args
+  param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace()));
+  // ret_tcodes
+  param_types.push_back(t_int_->getPointerTo(GetGlobalAddressSpace()));
+  // resource_handle
+  param_types.push_back(t_void_->getPointerTo(GetGlobalAddressSpace()));
+
+  // TODO(tvm-team):
+  // Update the function type to respect the ret_type field of f.
+  // Once we allow more flexibility in the PrimFunc.
+  llvm::FunctionType* ftype = llvm::FunctionType::get(t_int_, param_types, false);
+
+  llvm::Function* function =
+      llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
+                             ::tvm::runtime::symbol::tvm_lookup_linked_param, module_.get());
+  function->setCallingConv(llvm::CallingConv::C);
+  function->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
+
+  llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function);
+  builder_->SetInsertPoint(entry);
+  std::vector<llvm::Value*> zero_index_list{llvm::ConstantInt::get(t_int32_, 0)};
+  std::vector<llvm::Value*> zero_array_index_list{llvm::ConstantInt::get(t_int32_, 0),
+                                                  llvm::ConstantInt::get(t_int32_, 0)};
+  auto args_array = builder_->CreateBitCast(
+#if TVM_LLVM_VERSION >= 50
+      &function->arg_begin()[0],
+#else
+      &(*(function->arg_begin())),
+#endif
+      llvm::ArrayType::get(t_void_->getPointerTo(GetGlobalAddressSpace()), 1));
+  llvm::Value* sid = builder_->CreateBitCast(
+      builder_->CreateLoad(t_void_->getPointerTo(GetGlobalAddressSpace()),
+                           builder_->CreateInBoundsGEP(args_array, zero_index_list)),
+      t_int64_);
+
+  llvm::BasicBlock* default_block = llvm::BasicBlock::Create(*ctx_, "default_block", function);
+  llvm::SwitchInst* switch_inst = builder_->CreateSwitch(sid, default_block, params.size() + 1);
+
+  builder_->SetInsertPoint(default_block);
+  {
+    auto ret_types_array = builder_->CreateBitCast(

Review comment:
       Consider moving the argument cast to the beginning.

##########
File path: src/target/llvm/codegen_params.cc
##########
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file codegen_params.cc
+ */
+#ifdef TVM_LLVM_VERSION
+
+#include "codegen_params.h"
+
+#include <memory>
+#include <vector>
+
+namespace tvm {
+namespace codegen {
+
+namespace {
+class DLManagedTensorDeleter {
+ public:
+  void operator()(DLManagedTensor* ptr) { ptr->deleter(ptr); }
+};
+}  // namespace
+
+llvm::ConstantArray* NDArrayToLLVMArray(llvm::LLVMContext* ctx, ::tvm::runtime::NDArray arr) {
+  llvm::Type* element_type = nullptr;
+
+  auto arr_type = arr.DataType();
+  CHECK_EQ(arr_type.lanes(), 1) << "CodegenParams: only support generating 1-lane parameters; saw "
+                                << arr_type.lanes();
+
+  auto shape = arr.Shape();
+  int num_elements = 1;
+  for (auto shape_elem : shape) {
+    num_elements *= shape_elem;
+  }
+

Review comment:
       I don't think we need to convert things to DLPack. arr.operator-> already gives you a DLTensor*




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org