You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/05/24 16:30:39 UTC

[GitHub] [tvm] giuseros commented on a change in pull request #8023: [AOT] Initial implementation of --typed-operators

giuseros commented on a change in pull request #8023:
URL: https://github.com/apache/tvm/pull/8023#discussion_r638106241



##########
File path: src/target/source/source_module.cc
##########
@@ -192,17 +192,59 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
           << "}\n";
   }
 
+  /*!
+   * \brief Emit an entrypoint for the untyped (unpacked) calling convention.
+   *
+   * Writes a prototype for the run function taking one `void*` per model input
+   * and output, followed by a static `_tvm_entrypoint` wrapper with the packed
+   * signature. The wrapper reads `args` as an array of TVMValue whose
+   * `v_handle` points at a DLTensor, and forwards each tensor's `data` pointer:
+   * inputs first, then outputs.
+   */
+  void GenerateUntypedEntrypoint() {
+    const int total_args = metadata_->num_inputs + metadata_->num_outputs;
+    // Every parameter needs an explicit type: a bare identifier list such as
+    // `f(arg0,arg1)` is not allowed in a C declaration that is not a definition
+    // (compilers fall back to implicit int, truncating pointers) and is an
+    // error in C++. `void*` matches the DLTensor::data values forwarded below.
+    code_ << "TVM_DLL int32_t " << ::tvm::runtime::symbol::tvm_run_func_prefix << "(";
+    if (total_args == 0) {
+      code_ << "void";
+    }
+    for (int i = 0; i < total_args; ++i) {
+      if (i != 0) {
+        code_ << ",";
+      }
+      code_ << "void* arg" << i;
+    }
+    code_ << ");\n";
+    code_ << "static int32_t _tvm_entrypoint";
+    code_ << "(void* args, void* type_code, int num_args, void* out_value, void* "
+             "out_type_code, void* resource_handle) {\n";
+    code_ << "return " << ::tvm::runtime::symbol::tvm_run_func_prefix << "(";
+    // Inputs occupy args[0, num_inputs) and outputs args[num_inputs, total_args).
+    // One loop with leading separators avoids a trailing comma when there are
+    // no outputs.
+    for (int i = 0; i < total_args; ++i) {
+      if (i != 0) {
+        code_ << ",";
+      }
+      code_ << "((DLTensor*)(((TVMValue*)args)[" << i << "].v_handle))[0].data";
+    }
+    code_ << ");\n";
+    code_ << "}\n";
+  }
+
+  /*!
+   * \brief Emit an entrypoint for the typed (packed) calling convention.
+   *
+   * The run function already uses the packed signature, so the generated code
+   * declares it and defines a static `_tvm_entrypoint` that forwards all six
+   * arguments to it unchanged.
+   */
+  void GenerateTypedEntrypoint() {
+    // Shared parameter list of the declaration and the wrapper definition.
+    const char* packed_params =
+        "(void* args, void* type_code, int num_args, void* out_value, void* "
+        "out_type_code, void* resource_handle)";
+    code_ << "TVM_DLL int32_t " << ::tvm::runtime::symbol::tvm_run_func_prefix
+          << packed_params << ";\n"
+          << "static int32_t _tvm_entrypoint" << packed_params << " {\n"
+          << "return " << ::tvm::runtime::symbol::tvm_run_func_prefix
+          << "(args, type_code, num_args, out_value, out_type_code, resource_handle);\n"
+          << "}\n";
+  }
+
   void GenerateAOTDescriptor() {
     code_ << "#include \"tvm/runtime/crt/internal/aot_executor/aot_executor.h\"\n";
     code_ << "#include \"tvm/runtime/c_runtime_api.h\"\n";
     code_ << "#ifdef __cplusplus\n";
     code_ << "extern \"C\"\n";
     code_ << "#endif\n";
-    code_ << "TVM_DLL int32_t " << ::tvm::runtime::symbol::tvm_run_func_prefix;
-    code_ << "(void* args, void* type_code, int num_args, void* out_value, void* "
-             "out_type_code, void* resource_handle);\n";
+    if (target_->GetAttr<Bool>("typed-operators").value_or(Bool(true))) {
+      GenerateTypedEntrypoint();
+    } else {
+      GenerateUntypedEntrypoint();
+    }
     code_ << "const tvm_model_t network = {\n"
-          << "    .run_func = &" << ::tvm::runtime::symbol::tvm_run_func_prefix << ",\n"
+          << "    .run_func = &_tvm_entrypoint,\n"

Review comment:
      Maybe define a constant string and store it as a class member (e.g., `tvm_entrypoint_name`), so that the entrypoint name is defined in only one place.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org