Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/04/07 03:35:58 UTC

[GitHub] [tvm] areusch commented on a diff in pull request #10753: [AOT] Support LLVM backend with C++ runtime

areusch commented on code in PR #10753:
URL: https://github.com/apache/tvm/pull/10753#discussion_r843339941


##########
src/relay/backend/aot_executor_codegen.cc:
##########
@@ -260,9 +260,168 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
   std::vector<TensorType> return_ttypes_;
 };
 
-/*! \brief Code generator for AOT executor */
-class AOTExecutorCodegen : public MixedModeVisitor {
- protected:
+namespace {
+
+/*!
+ * \brief Utility function to convert a concrete integer to a PrimExpr.
+ * \param num the number to convert
+ * \return PrimExpr representing num
+ */
+inline PrimExpr ConstInt32(int32_t num) {
+  ICHECK_LE(num, std::numeric_limits<int>::max());
+  return tir::make_const(DataType::Int(32), static_cast<int>(num));
+}
+
+/*!
+ * \brief Emit a call to the C Device API.
+ * \param device_name Name of the device, used to prefix the function name.
+ * \param hook Name of the Device API function.
+ * \param context void* context arg passed to this API function.
+ */
+tir::Stmt MakeDeviceHookCall(const std::string& device_name, const std::string& hook,
+                             PrimExpr context) {
+  Array<String> sections = {"Device", device_name, hook};
+  String device_hook = ToCFunctionStyle(PrefixName(sections));
+
+  return tir::Evaluate(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
+                                 {tvm::tir::StringImm(device_hook), context}));
+}
+}  // namespace

Review Comment:
   This is to make it a "static" function, i.e. it cannot be linked against from elsewhere.
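
   For illustration, a minimal, self-contained sketch of the two equivalent ways to give a
   free function internal linkage in C++ (the names below are placeholders, not TVM code):

      // Option 1: an anonymous namespace, as the hunk above does.
      namespace {
      int HelperVisibleOnlyInThisTU(int x) { return x + 1; }
      }  // namespace

      // Option 2: the `static` keyword gives the same internal linkage.
      static int AnotherTranslationUnitLocalHelper(int x) { return x + 1; }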



##########
src/relay/backend/aot_executor_codegen.cc:
##########
@@ -260,9 +260,168 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
   std::vector<TensorType> return_ttypes_;
 };
 
-/*! \brief Code generator for AOT executor */
-class AOTExecutorCodegen : public MixedModeVisitor {
- protected:
+namespace {
+
+/*!
+ * \brief Utility function to convert a concrete integer to a PrimExpr.
+ * \param num the number to convert
+ * \return PrimExpr representing num
+ */
+inline PrimExpr ConstInt32(int32_t num) {

Review Comment:
   Changed to use `tir::make_const`, which I think handles it.
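
   As a usage sketch (the call sites below are hypothetical), the helper above turns a host
   integer into a 32-bit TIR constant; `tir::make_const` builds the constant with the
   requested dtype:

      // Hypothetical call sites for the ConstInt32 helper defined in the hunk above.
      tvm::PrimExpr ndim = ConstInt32(2);              // IntImm with dtype int32
      tvm::PrimExpr device_type = ConstInt32(kDLCPU);  // enum values narrow to int32_t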



##########
src/target/llvm/codegen_cpu.cc:
##########
@@ -822,14 +827,46 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>&
   TypedPointer ret_tcode =
       CreateBufferPtr(stack_tcode, DataType::Int(32), {ConstInt32(end)}, DataType::Int(32));
 
+  llvm::FunctionType* callee_ftype = nullptr;
+  llvm::Value* callee_value = nullptr;
+  std::vector<llvm::Value*> call_args;
+
+  if (use_string_lookup) {
+    callee_ftype = ftype_tvm_func_call_;
+    callee_value = RuntimeTVMFuncCall();
+    call_args.push_back(handle);
+  } else {
+    callee_ftype = ftype_tvm_backend_packed_c_func_;
+    callee_value = module_->getFunction(func_name);
+    if (callee_value == nullptr) {
+      callee_value =
+          llvm::Function::Create(ftype_tvm_backend_packed_c_func_, llvm::Function::ExternalLinkage,
+                                 func_name, module_.get());
+    }
+  }
+
+  if (use_string_lookup) {

Review Comment:
   done
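
   For context on why this hunk distinguishes the two callee types: with string lookup the
   generated code resolves a function handle and calls through the C++ runtime's TVMFuncCall,
   while the direct path calls the operator's TVMBackendPackedCFunc symbol. A hedged,
   illustrative sketch (not the generated IR; the real codegen resolves handles through the
   module context, and `my_op` is a hypothetical symbol):

      #include <tvm/runtime/c_backend_api.h>
      #include <tvm/runtime/c_runtime_api.h>

      // Hypothetical operator symbol following the TVMBackendPackedCFunc signature.
      extern "C" int my_op(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_value,
                           int* ret_tcode, void* resource_handle);

      int CallOp(bool use_string_lookup, TVMValue* values, int* tcodes, int num_args,
                 TVMValue* ret_value, int* ret_tcode) {
        if (use_string_lookup) {
          // C++ runtime path: resolve a TVMFunctionHandle by name, then call via TVMFuncCall.
          TVMFunctionHandle handle = nullptr;
          TVMFuncGetGlobal("my_op", &handle);
          if (handle == nullptr) return -1;
          return TVMFuncCall(handle, values, tcodes, num_args, ret_value, ret_tcode);
        }
        // Direct path: call the C-packed symbol, passing a null resource_handle.
        return my_op(values, tcodes, num_args, ret_value, ret_tcode, nullptr);
      }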



##########
src/target/metadata_utils.cc:
##########
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/target/metadata_utils.cc
+ * \brief Defines utility functions and classes for emitting metadata.
+ */
+#include "metadata_utils.h"

Review Comment:
   done (let me know if more are needed)



##########
src/target/metadata_module.cc:
##########
@@ -144,6 +144,12 @@ static runtime::Module CreateCppMetadataModule(
         auto metadata_module = CreateCSourceCppMetadataModule(runtime_metadata);
         metadata_module->Import(target_module);
         target_module = metadata_module;
+#ifdef TVM_LLVM_VERSION

Review Comment:
   This indicates USE_LLVM is ON. Added a comment.
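
   A small sketch of what the guard means (illustration only, not the comment added in the
   PR): TVM_LLVM_VERSION is defined by the build system only when USE_LLVM is ON, so code
   inside the guard is compiled only for LLVM-enabled builds.

      // Sketch only: compile-time detection of an LLVM-enabled build.
      #ifdef TVM_LLVM_VERSION
      static constexpr bool kLLVMEnabled = true;   // USE_LLVM was ON; macro holds the LLVM version
      #else
      static constexpr bool kLLVMEnabled = false;  // non-LLVM build
      #endif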



##########
src/target/llvm/llvm_module.cc:
##########
@@ -527,6 +527,46 @@ TVM_REGISTER_GLOBAL("codegen.codegen_blob")
       return runtime::Module(n);
     });
 
+runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata, Target target,
+                                            tvm::relay::Runtime runtime) {
+  InitializeLLVM();
+  auto tm = GetLLVMTargetMachine(target);
+  bool system_lib = runtime->GetAttr<Bool>("system-lib").value_or(Bool(false));
+  auto ctx = std::make_shared<llvm::LLVMContext>();
+  std::unique_ptr<CodeGenCPU> cg{new CodeGenCPU()};
+
+  cg->Init("TVMMetadataMod", tm.get(), ctx.get(), system_lib, system_lib,
+           false /* target_c_runtime */);
+
+  cg->DefineMetadata(metadata);
+  auto mod = cg->Finish();
+  mod->addModuleFlag(llvm::Module::Warning, "tvm_target",
+                     llvm::MDString::get(*ctx, LLVMTargetToString(target)));
+  mod->addModuleFlag(llvm::Module::Override, "Debug Info Version", llvm::DEBUG_METADATA_VERSION);
+
+  if (tm->getTargetTriple().isOSDarwin()) {
+    mod->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2);
+  }
+
+  std::string verify_errors_storage;
+  llvm::raw_string_ostream verify_errors(verify_errors_storage);
+  LOG_IF(FATAL, llvm::verifyModule(*mod, &verify_errors))
+      << "LLVM module verification failed with the following errors: \n"
+      << verify_errors.str();
+
+  // std::string tmp;
+  // llvm::raw_string_ostream stream(tmp);
+  // mod->print(stream, nullptr);
+  // LOG(INFO) << "LLVM metadata IR: " << stream.str();

Review Comment:
   done
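
   If that IR dump is still occasionally useful, one option (a suggestion on my part, not
   what the PR does) is to keep it behind TVM's verbose logging rather than as commented-out
   code; a sketch assuming VLOG from tvm/runtime/logging.h and llvm::raw_string_ostream:

      // Sketch: emit the generated metadata IR only when verbose logging is enabled.
      {
        std::string ir_text;
        llvm::raw_string_ostream stream(ir_text);
        mod->print(stream, nullptr);
        VLOG(2) << "LLVM metadata IR:\n" << stream.str();
      }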



##########
src/target/llvm/codegen_cpu.cc:
##########
@@ -914,6 +952,321 @@ llvm::Value* CodeGenCPU::RuntimeTVMParallelBarrier() {
   return GetContextPtr(gv_tvm_parallel_barrier_);
 }
 
+/*! \brief Defines LLVM Types for each Metadata member type. */
+struct MetadataLlvmTypes {
+  llvm::Type* t_float64;
+  llvm::Type* t_uint8;
+  llvm::Type* t_int64;
+  llvm::Type* t_bool;
+  llvm::Type* t_cstring;
+  llvm::Type* t_void_p;
+  llvm::StructType* t_data_type;
+
+  /*! \brief Maps a MetadataBase subclass' type_key to its corresponding LLVM StructType. */
+  ::std::unordered_map<std::string, llvm::StructType*> structs_by_type_key;
+};
+
+class MetadataTypeDefiner : public AttrVisitor {
+ public:
+  MetadataTypeDefiner(llvm::LLVMContext* ctx, struct MetadataLlvmTypes* llvm_types)
+      : ctx_{ctx}, llvm_types_{llvm_types} {}
+
+  void Visit(const char* key, double* value) final {
+    elements_.emplace_back(llvm_types_->t_float64);
+  }
+  void Visit(const char* key, int64_t* value) final {
+    elements_.emplace_back(llvm_types_->t_int64);
+  }
+  void Visit(const char* key, uint64_t* value) final {
+    elements_.emplace_back(llvm_types_->t_int64);
+  }
+  void Visit(const char* key, int* value) final { elements_.emplace_back(llvm_types_->t_int64); }
+  void Visit(const char* key, bool* value) final { elements_.emplace_back(llvm_types_->t_bool); }
+  void Visit(const char* key, std::string* value) final {
+    elements_.emplace_back(llvm_types_->t_cstring);
+  }
+  void Visit(const char* key, void** value) final { elements_.emplace_back(llvm_types_->t_void_p); }
+  void Visit(const char* key, DataType* value) final {
+    elements_.emplace_back(llvm_types_->t_data_type);
+  }
+  void Visit(const char* key, runtime::NDArray* value) final {
+    CHECK(false) << "Do not support serializing NDArray";
+  }
+
+ private:
+  void VisitMetadataBase(runtime::metadata::MetadataBase metadata) {
+    elements_.emplace_back(llvm::PointerType::getUnqual(
+        llvm::StructType::create(*ctx_, metadata->get_c_struct_name())));
+    if (visited_.find(metadata->get_c_struct_name()) != visited_.end()) {
+      return;
+    }
+
+    if (to_visit_.find(metadata->get_c_struct_name()) != to_visit_.end()) {
+      return;
+    }
+    to_visit_[metadata->get_c_struct_name()] = metadata;
+  }
+
+ public:
+  using MetadataKind = runtime::metadata::MetadataKind;
+
+  void VisitArray(const runtime::metadata::MetadataArrayNode* arr) {
+    switch (arr->kind) {
+      case MetadataKind::kUint64:  // LLVM encodes signed and unsigned with same types.
+      case MetadataKind::kInt64:
+        elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_int64));
+        break;
+      case MetadataKind::kBool:
+        elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_bool));
+        break;
+      case MetadataKind::kString:
+        elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_cstring));
+        break;
+      case MetadataKind::kHandle:
+        CHECK(false) << "Do not support handle";
+        break;
+      case MetadataKind::kMetadata:
+        elements_.emplace_back(
+            llvm::PointerType::getUnqual(llvm_types_->structs_by_type_key[arr->type_key]));
+        break;
+      default:
+        CHECK(false) << "Unsupported metadata kind " << arr->kind;
+        break;
+    }
+  }
+
+  void Visit(const char* key, ObjectRef* value) final {
+    const runtime::metadata::MetadataArrayNode* arr =
+        value->as<runtime::metadata::MetadataArrayNode>();
+    if (arr != nullptr) {
+      VisitArray(arr);
+    } else {
+      elements_.emplace_back(
+          llvm::PointerType::getUnqual(llvm_types_->structs_by_type_key[(*value)->GetTypeKey()]));
+    }
+  }
+
+  void DefineType(runtime::metadata::MetadataBase metadata) {
+    ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this);
+    LOG(INFO) << "Created type for " << metadata->GetTypeKey() << ":";
+    for (auto e : elements_) {
+      std::string value;
+      llvm::raw_string_ostream os(value);
+      e->print(os, true);
+      //      LOG(INFO) << " - " << e << ", tyid=" << e->getTypeID() << " == " << value;

Review Comment:
   done



##########
src/relay/backend/aot_executor_codegen.cc:
##########
@@ -260,9 +260,168 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
   std::vector<TensorType> return_ttypes_;
 };
 
-/*! \brief Code generator for AOT executor */
-class AOTExecutorCodegen : public MixedModeVisitor {
- protected:
+namespace {
+
+/*!
+ * \brief Utility function to convert a concrete integer to a PrimExpr.
+ * \param num the number to convert
+ * \return PrimExpr representing num
+ */
+inline PrimExpr ConstInt32(int32_t num) {
+  ICHECK_LE(num, std::numeric_limits<int>::max());
+  return tir::make_const(DataType::Int(32), static_cast<int>(num));
+}
+
+/*!
+ * \brief Emit a call to the C Device API.
+ * \param device_name Name of the device, used to prefix the function name.
+ * \param hook Name of the Device API function.
+ * \param context void* context arg passed to this API function.
+ */
+tir::Stmt MakeDeviceHookCall(const std::string& device_name, const std::string& hook,
+                             PrimExpr context) {
+  Array<String> sections = {"Device", device_name, hook};
+  String device_hook = ToCFunctionStyle(PrefixName(sections));
+
+  return tir::Evaluate(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
+                                 {tvm::tir::StringImm(device_hook), context}));
+}
+}  // namespace
+
+class AOTCallGenerator {

Review Comment:
   Good point -- this is cruft from the old approach, before I started relying on LegalizePackedCalls/LowerTVMBuiltin, so I can remove it now :)



##########
src/target/llvm/codegen_cpu.cc:
##########
@@ -802,10 +803,14 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) {
 
 CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>& args,
                                                          const DataType& r_type,
-                                                         const int64_t begin, const int64_t end) {
+                                                         const int64_t begin, const int64_t end,
+                                                         bool use_string_lookup) {
   PackedCall pc;
   std::string func_name = args[0].as<StringImmNode>()->value;
-  llvm::Value* handle = GetPackedFuncHandle(func_name);
+  llvm::Value* handle = nullptr;
+  if (use_string_lookup) {

Review Comment:
   done!



##########
src/target/llvm/codegen_llvm.cc:
##########
@@ -1399,9 +1411,13 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) {
 }
 
 llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
+  //  LOG(INFO) << "Visit Call:" << GetRef<Call>(op);
   if (auto* ptr_op = op->op.as<OpNode>()) {
     auto call_op = GetRef<Op>(ptr_op);
-    if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) {
+    if (op->op.same_as(builtin_lookup_param_)) {
+      //      return llvm::ConstantInt::get(t_void_p_, 0);

Review Comment:
   done



##########
src/target/llvm/codegen_llvm.cc:
##########
@@ -1399,9 +1411,13 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) {
 }
 
 llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
+  //  LOG(INFO) << "Visit Call:" << GetRef<Call>(op);

Review Comment:
   done



##########
src/target/llvm/codegen_llvm.h:
##########
@@ -389,6 +406,16 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
                                              unsigned int shared_address_space, int alignment,
                                              llvm::GlobalValue::LinkageTypes linkage);
 
+  llvm::Argument* GetArg(const llvm::Function* function, int i) const {

Review Comment:
   done
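
   For reference, a hedged sketch of how such a version-portable argument accessor can be
   written (llvm::Function::getArg(unsigned) exists from LLVM 10 onward; this is not
   necessarily the PR's exact body):

      #include <llvm/IR/Function.h>
      #include <iterator>

      // Sketch: return the i-th formal argument, portable across LLVM versions.
      inline llvm::Argument* GetArgSketch(const llvm::Function* function, int i) {
      #if TVM_LLVM_VERSION >= 100
        return function->getArg(i);  // getArg() is available since LLVM 10
      #else
        return const_cast<llvm::Argument*>(&*std::next(function->arg_begin(), i));
      #endif
      }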



##########
src/relay/backend/aot_executor_codegen.cc:
##########
@@ -260,9 +260,168 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
   std::vector<TensorType> return_ttypes_;
 };
 
-/*! \brief Code generator for AOT executor */
-class AOTExecutorCodegen : public MixedModeVisitor {
- protected:
+namespace {
+
+/*!
+ * \brief Utility function to convert a concrete integer to a PrimExpr.
+ * \param num the number to convert
+ * \return PrimExpr representing num
+ */
+inline PrimExpr ConstInt32(int32_t num) {
+  ICHECK_LE(num, std::numeric_limits<int>::max());
+  return tir::make_const(DataType::Int(32), static_cast<int>(num));
+}
+
+/*!
+ * \brief Emit a call to the C Device API.
+ * \param device_name Name of the device, used to prefix the function name.
+ * \param hook Name of the Device API function.
+ * \param context void* context arg passed to this API function.
+ */
+tir::Stmt MakeDeviceHookCall(const std::string& device_name, const std::string& hook,
+                             PrimExpr context) {
+  Array<String> sections = {"Device", device_name, hook};
+  String device_hook = ToCFunctionStyle(PrefixName(sections));
+
+  return tir::Evaluate(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
+                                 {tvm::tir::StringImm(device_hook), context}));
+}
+}  // namespace
+
+class AOTCallGenerator {
+ public:
+  explicit AOTCallGenerator(std::string func_name)
+      : func_name_{func_name}, args_{tvm::tir::StringImm(func_name)} {}
+
+  tir::Var PushArg(PrimExpr arg) {
+    if (!arg->IsInstance<tir::VarNode>()) {
+      arg = MakeLetBind(arg);
+    }
+    args_.push_back(arg);
+    return Downcast<tir::Var>(arg);
+  }
+
+  void PushStackDLTensor(const TensorType& ttype, PrimExpr data) {
+    auto dltensor_var = MakeLetBind(StackAlloca("array", 1));
+    auto shape_var = MakeLetBind(StackAlloca("shape", ttype->shape.size()));
+
+    // Populate DLTensor.data
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, tir::builtin::kArrData, data})));
+
+    // Populate DLTensor.device
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, tir::builtin::kArrDeviceType, kDLCPU})));
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, tir::builtin::kArrDeviceId, 0})));
+
+    // Populate DLTensor.ndim
+    prep_stmts_.push_back(tir::Evaluate(tvm::tir::Call(
+        DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+        {dltensor_var, 0, tir::builtin::kArrNDim, static_cast<int32_t>(ttype->shape.size())})));
+
+    // Populate DLTensor.dtype
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, tir::builtin::kArrTypeCode,
+                                      IntImm(DataType(kDLUInt, 8, 1), ttype->dtype.code())})));
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, tir::builtin::kArrTypeBits,
+                                      IntImm(DataType(kDLUInt, 8, 1), ttype->dtype.bits())})));
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, tir::builtin::kArrTypeLanes,
+                                      IntImm(DataType(kDLUInt, 16, 1), ttype->dtype.lanes())})));
+
+    // Populate DLTensor.shape
+    for (size_t i = 0; i < ttype->shape.size(); ++i) {
+      prep_stmts_.push_back(tvm::tir::Store(
+          shape_var, IntImm(DataType(kDLInt, 64, 1), Downcast<IntImm>(ttype->shape[i])->value),
+          IntImm(DataType(kDLUInt, 64, 1), i), tir::const_true()));
+    }
+
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, tir::builtin::kArrShape, shape_var})));
+
+    // Populate DLTensor.strides. DNS -- TODO actually pull correct byte_offset
+    prep_stmts_.push_back(tir::Evaluate(tvm::tir::Call(
+        DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+        {dltensor_var, 0, tir::builtin::kArrStrides, IntImm(DataType(kDLUInt, 64, 1), 0)})));
+
+    // Populate DLTensor.byte_offset. DNS -- TODO actually pull correct byte_offset
+    prep_stmts_.push_back(tir::Evaluate(tvm::tir::Call(
+        DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+        {dltensor_var, 0, tir::builtin::kArrByteOffset, IntImm(DataType(kDLUInt, 64, 1), 0)})));
+
+    args_.push_back(dltensor_var);
+  }
+
+  void PushStackDLTensors(const Expr& expr, std::vector<tir::Var> sids) {
+    const TupleNode* t = expr.as<TupleNode>();
+    if (t != nullptr) {
+      CHECK_EQ(sids.size(), t->fields.size()) << "Relay tuple does not map 1:1 into TIR; AOT can't "
+                                                 "handle this type of Relay Expr in a CallNode.";
+      for (size_t i = 0; i < sids.size(); i++) {
+        PushStackDLTensor(Downcast<TensorType>(t->fields[i]->checked_type()), sids[i]);
+      }
+    } else {
+      PushStackDLTensor(Downcast<TensorType>(expr->checked_type()), sids[0]);
+    }
+  }
+
+  tir::Stmt GenerateUnpacked(std::string device_name, PrimExpr device_context) {
+    auto make_call = [this] {
+      return tir::Evaluate(
+          tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), args_));
+    };
+    if (device_context.defined()) {
+      tir::Var context_var = PushArg(device_context);
+      return Generate(tir::SeqStmt({
+          MakeDeviceHookCall(device_name, "Open", context_var),
+          make_call(),
+          MakeDeviceHookCall(device_name, "Close", context_var),
+      }));
+    } else {
+      return Generate(make_call());
+    }
+  }
+
+  tir::Stmt GeneratePacked() {
+    return Generate(
+        tir::Evaluate(tvm::tir::Call(DataType::Int(32), tir::builtin::tvm_call_packed(), args_)));
+  }
+
+  tir::Stmt GenerateCPacked() {
+    // call_cpacked calling convention does not use a context
+    PushArg(tir::make_zero(DataType::Handle()));
+    return Generate(
+        tir::Evaluate(tvm::tir::Call(DataType::Int(32), tir::builtin::tvm_call_cpacked(), args_)));
+  }
+
+ private:
+  tir::Stmt Generate(tir::Stmt call_stmts) {
+    tir::Stmt body = tir::SeqStmt::Flatten(prep_stmts_, call_stmts);
+
+    for (auto bind : let_binds_) {
+      body = tir::LetStmt(bind.first, bind.second, body);

Review Comment:
   Yeah, sorry -- this part is cruft and should get deleted.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org