Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/04/01 12:56:24 UTC

[GitHub] [tvm] manupa-arm commented on a change in pull request #10753: [AOT] Support LLVM backend with C++ runtime

manupa-arm commented on a change in pull request #10753:
URL: https://github.com/apache/tvm/pull/10753#discussion_r840519145



##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -260,9 +260,168 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
   std::vector<TensorType> return_ttypes_;
 };
 
-/*! \brief Code generator for AOT executor */
-class AOTExecutorCodegen : public MixedModeVisitor {
- protected:
+namespace {
+
+/*!
+ * \brief Utility function to convert a concrete integer to a PrimExpr.
+ * \param num the number to convert
+ * \return PrimExpr representing num
+ */
+inline PrimExpr ConstInt32(int32_t num) {
+  ICHECK_LE(num, std::numeric_limits<int>::max());
+  return tir::make_const(DataType::Int(32), static_cast<int>(num));
+}
+
+/*!
+ * \brief Emit a call to the C Device API.
+ * \param device_name Name of the device, used to prefix the function name.
+ * \param hook Name of the Device API function.
+ * \param context void* context arg passed to this API function.
+ */
+tir::Stmt MakeDeviceHookCall(const std::string& device_name, const std::string& hook,
+                             PrimExpr context) {
+  Array<String> sections = {"Device", device_name, hook};
+  String device_hook = ToCFunctionStyle(PrefixName(sections));
+
+  return tir::Evaluate(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
+                                 {tvm::tir::StringImm(device_hook), context}));
+}
+}  // namespace
+
+class AOTCallGenerator {
+ public:
+  explicit AOTCallGenerator(std::string func_name)
+      : func_name_{func_name}, args_{tvm::tir::StringImm(func_name)} {}
+
+  tir::Var PushArg(PrimExpr arg) {
+    if (!arg->IsInstance<tir::VarNode>()) {
+      arg = MakeLetBind(arg);
+    }
+    args_.push_back(arg);
+    return Downcast<tir::Var>(arg);
+  }
+
+  void PushStackDLTensor(const TensorType& ttype, PrimExpr data) {
+    auto dltensor_var = MakeLetBind(StackAlloca("array", 1));
+    auto shape_var = MakeLetBind(StackAlloca("shape", ttype->shape.size()));
+
+    // Populate DLTensor.data
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, tir::builtin::kArrData, data})));
+
+    // Populate DLTensor.device
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, tir::builtin::kArrDeviceType, kDLCPU})));
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, tir::builtin::kArrDeviceId, 0})));
+
+    // Populate DLTensor.ndim
+    prep_stmts_.push_back(tir::Evaluate(tvm::tir::Call(
+        DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+        {dltensor_var, 0, tir::builtin::kArrNDim, static_cast<int32_t>(ttype->shape.size())})));
+
+    // Populate DLTensor.dtype
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, tir::builtin::kArrTypeCode,
+                                      IntImm(DataType(kDLUInt, 8, 1), ttype->dtype.code())})));
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, tir::builtin::kArrTypeBits,
+                                      IntImm(DataType(kDLUInt, 8, 1), ttype->dtype.bits())})));
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, tir::builtin::kArrTypeLanes,
+                                      IntImm(DataType(kDLUInt, 16, 1), ttype->dtype.lanes())})));
+
+    // Populate DLTensor.shape
+    for (size_t i = 0; i < ttype->shape.size(); ++i) {
+      prep_stmts_.push_back(tvm::tir::Store(
+          shape_var, IntImm(DataType(kDLInt, 64, 1), Downcast<IntImm>(ttype->shape[i])->value),
+          IntImm(DataType(kDLUInt, 64, 1), i), tir::const_true()));
+    }
+
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, tir::builtin::kArrShape, shape_var})));
+
+    // Populate DLTensor.strides. DNS -- TODO actually pull correct byte_offset
+    prep_stmts_.push_back(tir::Evaluate(tvm::tir::Call(
+        DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+        {dltensor_var, 0, tir::builtin::kArrStrides, IntImm(DataType(kDLUInt, 64, 1), 0)})));
+
+    // Populate DLTensor.byte_offset. DNS -- TODO actually pull correct byte_offset
+    prep_stmts_.push_back(tir::Evaluate(tvm::tir::Call(
+        DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+        {dltensor_var, 0, tir::builtin::kArrByteOffset, IntImm(DataType(kDLUInt, 64, 1), 0)})));
+
+    args_.push_back(dltensor_var);
+  }
+
+  void PushStackDLTensors(const Expr& expr, std::vector<tir::Var> sids) {
+    const TupleNode* t = expr.as<TupleNode>();
+    if (t != nullptr) {
+      CHECK_EQ(sids.size(), t->fields.size()) << "Relay tuple does not map 1:1 into TIR; AOT can't "
+                                                 "handle this type of Relay Expr in a CallNode.";
+      for (size_t i = 0; i < sids.size(); i++) {
+        PushStackDLTensor(Downcast<TensorType>(t->fields[i]->checked_type()), sids[i]);
+      }
+    } else {
+      PushStackDLTensor(Downcast<TensorType>(expr->checked_type()), sids[0]);
+    }
+  }
+
+  tir::Stmt GenerateUnpacked(std::string device_name, PrimExpr device_context) {
+    auto make_call = [this] {
+      return tir::Evaluate(
+          tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), args_));
+    };
+    if (device_context.defined()) {
+      tir::Var context_var = PushArg(device_context);
+      return Generate(tir::SeqStmt({
+          MakeDeviceHookCall(device_name, "Open", context_var),
+          make_call(),
+          MakeDeviceHookCall(device_name, "Close", context_var),
+      }));
+    } else {
+      return Generate(make_call());
+    }
+  }
+
+  tir::Stmt GeneratePacked() {
+    return Generate(
+        tir::Evaluate(tvm::tir::Call(DataType::Int(32), tir::builtin::tvm_call_packed(), args_)));
+  }
+
+  tir::Stmt GenerateCPacked() {
+    // call_cpacked calling convention does not use a context
+    PushArg(tir::make_zero(DataType::Handle()));
+    return Generate(
+        tir::Evaluate(tvm::tir::Call(DataType::Int(32), tir::builtin::tvm_call_cpacked(), args_)));
+  }
+
+ private:
+  tir::Stmt Generate(tir::Stmt call_stmts) {
+    tir::Stmt body = tir::SeqStmt::Flatten(prep_stmts_, call_stmts);
+
+    for (auto bind : let_binds_) {
+      body = tir::LetStmt(bind.first, bind.second, body);

Review comment:
       Why do we need a Let binding here? Can't we just bind the arguments directly, without introducing a Let node here?
   
   https://github.com/apache/tvm/blob/95df0eb1461718d9d1453d2ba4beb9441c5cab3c/src/tir/transforms/arg_binder.h#L74-L75
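   
   For reference, my reading of the loop above is that it currently wraps the call as, schematically (TIR pseudo-form, names invented for illustration):
   
       let arg0 = <expr0> in
         let arg1 = <expr1> in
           <prep_stmts_>
           @tir.call_extern("some_func", arg0, arg1)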

##########
File path: src/target/llvm/llvm_module.cc
##########
@@ -527,6 +527,46 @@ TVM_REGISTER_GLOBAL("codegen.codegen_blob")
       return runtime::Module(n);
     });
 
+runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata, Target target,
+                                            tvm::relay::Runtime runtime) {
+  InitializeLLVM();
+  auto tm = GetLLVMTargetMachine(target);
+  bool system_lib = runtime->GetAttr<Bool>("system-lib").value_or(Bool(false));
+  auto ctx = std::make_shared<llvm::LLVMContext>();
+  std::unique_ptr<CodeGenCPU> cg{new CodeGenCPU()};
+
+  cg->Init("TVMMetadataMod", tm.get(), ctx.get(), system_lib, system_lib,
+           false /* target_c_runtime */);
+
+  cg->DefineMetadata(metadata);
+  auto mod = cg->Finish();
+  mod->addModuleFlag(llvm::Module::Warning, "tvm_target",
+                     llvm::MDString::get(*ctx, LLVMTargetToString(target)));
+  mod->addModuleFlag(llvm::Module::Override, "Debug Info Version", llvm::DEBUG_METADATA_VERSION);
+
+  if (tm->getTargetTriple().isOSDarwin()) {
+    mod->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2);
+  }
+
+  std::string verify_errors_storage;
+  llvm::raw_string_ostream verify_errors(verify_errors_storage);
+  LOG_IF(FATAL, llvm::verifyModule(*mod, &verify_errors))
+      << "LLVM module verification failed with the following errors: \n"
+      << verify_errors.str();
+
+  // std::string tmp;
+  // llvm::raw_string_ostream stream(tmp);
+  // mod->print(stream, nullptr);
+  // LOG(INFO) << "LLVM metadata IR: " << stream.str();

Review comment:
       Please remove this commented-out debug printing.

##########
File path: src/target/llvm/codegen_cpu.cc
##########
@@ -802,10 +803,14 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) {
 
 CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>& args,
                                                          const DataType& r_type,
-                                                         const int64_t begin, const int64_t end) {
+                                                         const int64_t begin, const int64_t end,
+                                                         bool use_string_lookup) {
   PackedCall pc;
   std::string func_name = args[0].as<StringImmNode>()->value;
-  llvm::Value* handle = GetPackedFuncHandle(func_name);
+  llvm::Value* handle = nullptr;
+  if (use_string_lookup) {

Review comment:
       I think we don't need to introduce `handle` just yet.
   
   Shall we just merge this into the single if/else further down, so it's clear what happens when string-based function lookup is not used?

##########
File path: src/target/llvm/codegen_llvm.cc
##########
@@ -1399,9 +1411,13 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) {
 }
 
 llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
+  //  LOG(INFO) << "Visit Call:" << GetRef<Call>(op);

Review comment:
       Please remove this commented-out log line.

##########
File path: src/target/metadata_module.cc
##########
@@ -144,6 +144,12 @@ static runtime::Module CreateCppMetadataModule(
         auto metadata_module = CreateCSourceCppMetadataModule(runtime_metadata);
         metadata_module->Import(target_module);
         target_module = metadata_module;
+#ifdef TVM_LLVM_VERSION

Review comment:
       Why is this #ifdef needed?

##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -260,9 +260,168 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
   std::vector<TensorType> return_ttypes_;
 };
 
-/*! \brief Code generator for AOT executor */
-class AOTExecutorCodegen : public MixedModeVisitor {
- protected:
+namespace {
+
+/*!
+ * \brief Utility function to convert a concrete integer to a PrimExpr.
+ * \param num the number to convert
+ * \return PrimExpr representing num
+ */
+inline PrimExpr ConstInt32(int32_t num) {
+  ICHECK_LE(num, std::numeric_limits<int>::max());
+  return tir::make_const(DataType::Int(32), static_cast<int>(num));
+}
+
+/*!
+ * \brief Emit a call to the C Device API.
+ * \param device_name Name of the device, used to prefix the function name.
+ * \param hook Name of the Device API function.
+ * \param context void* context arg passed to this API function.
+ */
+tir::Stmt MakeDeviceHookCall(const std::string& device_name, const std::string& hook,
+                             PrimExpr context) {
+  Array<String> sections = {"Device", device_name, hook};
+  String device_hook = ToCFunctionStyle(PrefixName(sections));
+
+  return tir::Evaluate(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
+                                 {tvm::tir::StringImm(device_hook), context}));
+}
+}  // namespace

Review comment:
       Why introduce this anonymous namespace here?

##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -260,9 +260,168 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
   std::vector<TensorType> return_ttypes_;
 };
 
-/*! \brief Code generator for AOT executor */
-class AOTExecutorCodegen : public MixedModeVisitor {
- protected:
+namespace {
+
+/*!
+ * \brief Utility function to convert a concrete integer to a PrimExpr.
+ * \param num the number to convert
+ * \return PrimExpr representing num
+ */
+inline PrimExpr ConstInt32(int32_t num) {

Review comment:
       Maybe not needed for this PR, but should we move this into tir as tir::make_const_int32?
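   
   Rough, untested sketch of what I have in mind (exact name and header up for discussion):
   
       // Hypothetical helper, e.g. alongside tir::make_const:
       inline PrimExpr make_const_int32(int32_t value) {
         return tir::make_const(DataType::Int(32), value);
       }
       // Call sites in this file would then become tir::make_const_int32(num).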

##########
File path: src/target/llvm/codegen_llvm.h
##########
@@ -389,6 +406,16 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
                                              unsigned int shared_address_space, int alignment,
                                              llvm::GlobalValue::LinkageTypes linkage);
 
+  llvm::Argument* GetArg(const llvm::Function* function, int i) const {

Review comment:
       Should we add docs for this?
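   
   Something along these lines, perhaps (assuming it simply returns the i-th argument; please correct if it does more):
   
       /*!
        * \brief Get the i-th argument of an LLVM function in an LLVM-version-portable way.
        * \param function The LLVM function to read the argument from.
        * \param i Index of the argument to return.
        * \return Pointer to the i-th llvm::Argument of the function.
        */
       llvm::Argument* GetArg(const llvm::Function* function, int i) const;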

##########
File path: src/target/llvm/codegen_cpu.cc
##########
@@ -822,14 +827,46 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>&
   TypedPointer ret_tcode =
       CreateBufferPtr(stack_tcode, DataType::Int(32), {ConstInt32(end)}, DataType::Int(32));
 
+  llvm::FunctionType* callee_ftype = nullptr;
+  llvm::Value* callee_value = nullptr;
+  std::vector<llvm::Value*> call_args;
+
+  if (use_string_lookup) {
+    callee_ftype = ftype_tvm_func_call_;
+    callee_value = RuntimeTVMFuncCall();
+    call_args.push_back(handle);
+  } else {
+    callee_ftype = ftype_tvm_backend_packed_c_func_;
+    callee_value = module_->getFunction(func_name);
+    if (callee_value == nullptr) {
+      callee_value =
+          llvm::Function::Create(ftype_tvm_backend_packed_c_func_, llvm::Function::ExternalLinkage,
+                                 func_name, module_.get());
+    }
+  }
+
+  if (use_string_lookup) {

Review comment:
       Here too, let's merge these if/else blocks.
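   
   Untested sketch of the merged form, so the non-lookup path is visible in one place:
   
       llvm::FunctionType* callee_ftype = nullptr;
       llvm::Value* callee_value = nullptr;
       std::vector<llvm::Value*> call_args;
       if (use_string_lookup) {
         // String-based lookup: call through TVMFuncCall with the resolved handle.
         callee_ftype = ftype_tvm_func_call_;
         callee_value = RuntimeTVMFuncCall();
         call_args.push_back(GetPackedFuncHandle(func_name));
       } else {
         // Direct call to the backend-packed C symbol; declare it if it is not in the module yet.
         callee_ftype = ftype_tvm_backend_packed_c_func_;
         callee_value = module_->getFunction(func_name);
         if (callee_value == nullptr) {
           callee_value = llvm::Function::Create(ftype_tvm_backend_packed_c_func_,
                                                 llvm::Function::ExternalLinkage, func_name,
                                                 module_.get());
         }
       }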

##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -260,9 +260,168 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
   std::vector<TensorType> return_ttypes_;
 };
 
-/*! \brief Code generator for AOT executor */
-class AOTExecutorCodegen : public MixedModeVisitor {
- protected:
+namespace {
+
+/*!
+ * \brief Utility function to convert a concrete integer to a PrimExpr.
+ * \param num the number to convert
+ * \return PrimExpr representing num
+ */
+inline PrimExpr ConstInt32(int32_t num) {
+  ICHECK_LE(num, std::numeric_limits<int>::max());
+  return tir::make_const(DataType::Int(32), static_cast<int>(num));
+}
+
+/*!
+ * \brief Emit a call to the C Device API.
+ * \param device_name Name of the device, used to prefix the function name.
+ * \param hook Name of the Device API function.
+ * \param context void* context arg passed to this API function.
+ */
+tir::Stmt MakeDeviceHookCall(const std::string& device_name, const std::string& hook,
+                             PrimExpr context) {
+  Array<String> sections = {"Device", device_name, hook};
+  String device_hook = ToCFunctionStyle(PrefixName(sections));
+
+  return tir::Evaluate(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
+                                 {tvm::tir::StringImm(device_hook), context}));
+}
+}  // namespace
+
+class AOTCallGenerator {

Review comment:
       Docs: we need docs for this, but I think this part is still WIP since I did not see who uses it yet.
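   
   Agreed it may still be in flux, but as a starting point, something like:
   
       /*!
        * \brief Helper that builds the TIR statements needed to invoke one lowered function
        *        from the generated AOT code.
        *
        * Call arguments are accumulated via PushArg / PushStackDLTensor(s); non-Var
        * arguments are let-bound first. GenerateUnpacked, GeneratePacked and GenerateCPacked
        * then emit the call in the corresponding calling convention, together with the
        * DLTensor setup statements and (for unpacked calls with a device context) the
        * Device API Open/Close hook calls.
        */
       class AOTCallGenerator {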

##########
File path: src/target/llvm/codegen_cpu.cc
##########
@@ -914,6 +952,321 @@ llvm::Value* CodeGenCPU::RuntimeTVMParallelBarrier() {
   return GetContextPtr(gv_tvm_parallel_barrier_);
 }
 
+/*! \brief Defines LLVM Types for each Metadata member type. */
+struct MetadataLlvmTypes {
+  llvm::Type* t_float64;
+  llvm::Type* t_uint8;
+  llvm::Type* t_int64;
+  llvm::Type* t_bool;
+  llvm::Type* t_cstring;
+  llvm::Type* t_void_p;
+  llvm::StructType* t_data_type;
+
+  /*! \brief Maps a MetadataBase subclass' type_key to its corresponding LLVM StructType. */
+  ::std::unordered_map<std::string, llvm::StructType*> structs_by_type_key;
+};
+
+class MetadataTypeDefiner : public AttrVisitor {
+ public:
+  MetadataTypeDefiner(llvm::LLVMContext* ctx, struct MetadataLlvmTypes* llvm_types)
+      : ctx_{ctx}, llvm_types_{llvm_types} {}
+
+  void Visit(const char* key, double* value) final {
+    elements_.emplace_back(llvm_types_->t_float64);
+  }
+  void Visit(const char* key, int64_t* value) final {
+    elements_.emplace_back(llvm_types_->t_int64);
+  }
+  void Visit(const char* key, uint64_t* value) final {
+    elements_.emplace_back(llvm_types_->t_int64);
+  }
+  void Visit(const char* key, int* value) final { elements_.emplace_back(llvm_types_->t_int64); }
+  void Visit(const char* key, bool* value) final { elements_.emplace_back(llvm_types_->t_bool); }
+  void Visit(const char* key, std::string* value) final {
+    elements_.emplace_back(llvm_types_->t_cstring);
+  }
+  void Visit(const char* key, void** value) final { elements_.emplace_back(llvm_types_->t_void_p); }
+  void Visit(const char* key, DataType* value) final {
+    elements_.emplace_back(llvm_types_->t_data_type);
+  }
+  void Visit(const char* key, runtime::NDArray* value) final {
+    CHECK(false) << "Do not support serializing NDArray";
+  }
+
+ private:
+  void VisitMetadataBase(runtime::metadata::MetadataBase metadata) {
+    elements_.emplace_back(llvm::PointerType::getUnqual(
+        llvm::StructType::create(*ctx_, metadata->get_c_struct_name())));
+    if (visited_.find(metadata->get_c_struct_name()) != visited_.end()) {
+      return;
+    }
+
+    if (to_visit_.find(metadata->get_c_struct_name()) != to_visit_.end()) {
+      return;
+    }
+    to_visit_[metadata->get_c_struct_name()] = metadata;
+  }
+
+ public:
+  using MetadataKind = runtime::metadata::MetadataKind;
+
+  void VisitArray(const runtime::metadata::MetadataArrayNode* arr) {
+    switch (arr->kind) {
+      case MetadataKind::kUint64:  // LLVM encodes signed and unsigned with same types.
+      case MetadataKind::kInt64:
+        elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_int64));
+        break;
+      case MetadataKind::kBool:
+        elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_bool));
+        break;
+      case MetadataKind::kString:
+        elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_cstring));
+        break;
+      case MetadataKind::kHandle:
+        CHECK(false) << "Do not support handle";
+        break;
+      case MetadataKind::kMetadata:
+        elements_.emplace_back(
+            llvm::PointerType::getUnqual(llvm_types_->structs_by_type_key[arr->type_key]));
+        break;
+      default:
+        CHECK(false) << "Unsupported metadata kind " << arr->kind;
+        break;
+    }
+  }
+
+  void Visit(const char* key, ObjectRef* value) final {
+    const runtime::metadata::MetadataArrayNode* arr =
+        value->as<runtime::metadata::MetadataArrayNode>();
+    if (arr != nullptr) {
+      VisitArray(arr);
+    } else {
+      elements_.emplace_back(
+          llvm::PointerType::getUnqual(llvm_types_->structs_by_type_key[(*value)->GetTypeKey()]));
+    }
+  }
+
+  void DefineType(runtime::metadata::MetadataBase metadata) {
+    ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this);
+    LOG(INFO) << "Created type for " << metadata->GetTypeKey() << ":";
+    for (auto e : elements_) {
+      std::string value;
+      llvm::raw_string_ostream os(value);
+      e->print(os, true);
+      //      LOG(INFO) << " - " << e << ", tyid=" << e->getTypeID() << " == " << value;

Review comment:
       Please remove this commented-out log line.

##########
File path: src/target/llvm/codegen_llvm.cc
##########
@@ -1399,9 +1411,13 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) {
 }
 
 llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
+  //  LOG(INFO) << "Visit Call:" << GetRef<Call>(op);
   if (auto* ptr_op = op->op.as<OpNode>()) {
     auto call_op = GetRef<Op>(ptr_op);
-    if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) {
+    if (op->op.same_as(builtin_lookup_param_)) {
+      //      return llvm::ConstantInt::get(t_void_p_, 0);

Review comment:
       Please remove this commented-out code.

##########
File path: src/target/metadata_utils.cc
##########
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/target/metadata_utils.cc
+ * \brief Defines utility functions and classes for emitting metadata.
+ */
+#include "metadata_utils.h"

Review comment:
       Please add docs for all the functions introduced here.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org