You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/07/18 18:08:53 UTC

[tvm] branch main updated: [AOT] Avoid call_extern() with incorrect argument count (#15301)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d81e8809b8 [AOT] Avoid call_extern() with incorrect argument count (#15301)
d81e8809b8 is described below

commit d81e8809b81b0746d4a7230c10fde858f5a177ba
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Tue Jul 18 13:08:47 2023 -0500

    [AOT] Avoid call_extern() with incorrect argument count (#15301)
    
    Prior to this commit, if device initialization was required, the AOT
    main function always produced a `call_extern()` that passed the device
    context as an argument.  This commit updates the AOT main function to
    pass the device context only if the function being called accepts a
    device context as a parameter.
    
    If an extra device context argument is included at the call site, the
    C codegen produces a function signature that includes the device
    context in the caller's compilation unit, but a signature without the
    device context in the callee's compilation unit.  While this may
    compile and run in some cases, it is undefined behavior for a
    function's signature to differ between compilation units, so it
    should be avoided.
    
    This was initially discovered while debugging
    https://github.com/apache/tvm/pull/14985, in which changes to the
    lowering flow resulted in the caller and callee being within the same
    compilation unit.
---
 src/relay/backend/aot_executor_codegen.cc | 38 ++++++++++++++++++++++++++++++-
 1 file changed, 37 insertions(+), 1 deletion(-)

diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc
index 945290f702..f698c654d6 100644
--- a/src/relay/backend/aot_executor_codegen.cc
+++ b/src/relay/backend/aot_executor_codegen.cc
@@ -454,7 +454,32 @@ class AOTExecutorCodegen : public MixedModeVisitor {
         // call_extern calling convention with optional context
         if (has_c_device_api_context) {
           device_context = device_contexts_.Get(global_var).value();
-          args.push_back(device_context);
+
+          // call_extern has no further legalization steps, and
+          // requires the number of arguments to match exactly. For
+          // internal calls, conditionally append the device context.
+          bool requires_device_context = [&]() -> bool {
+            Optional<Integer> opt = num_arguments_.Get(global_var);
+            if (!opt.defined()) {
+              // For external calls, we must trust that the user has
+              // supplied a kernel that accepts a device_context
+              // argument.
+              return true;
+            }
+            int num_callee_params = static_cast<int>(opt.value()->value);
+            int num_args = static_cast<int>(call_lowered_props.arguments.size());
+            if (num_callee_params == num_args) {
+              return false;
+            } else if (num_callee_params == num_args + 1) {
+              return true;
+            } else {
+              LOG(FATAL) << "Callee " << global_var << " requires " << num_callee_params
+                         << " arguments, but is called with " << num_args << " arguments.";
+            }
+          }();
+          if (requires_device_context) {
+            args.push_back(device_context);
+          }
         }
         func_call = tir::Evaluate(AddCheckReturn(
             tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), args)));
@@ -1007,6 +1032,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
   Map<String, tir::Var> devices_;
   /*! \brief map of GlobalVars to C Device API contexts */
   Map<GlobalVar, tir::Var> device_contexts_;
+  /*! \brief map of each lowered PrimFunc's GlobalVar to its parameter count */
+  Map<GlobalVar, Integer> num_arguments_;
   /*! \brief input and output variables belonging to the main function signature */
   Array<tir::Var> main_signature_;
   /*! \brief input and output variables belonging to the main function signature */
@@ -1183,6 +1210,15 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     }
 
     CollectDeviceVariables(lowered_mod->GetAttr<Map<GlobalVar, String>>("device_contexts").value());
+    num_arguments_ = [&]() -> Map<GlobalVar, Integer> {
+      Map<GlobalVar, Integer> arg_count;
+      for (const auto& [gvar, func] : lowered_mod->functions) {
+        if (const auto* prim_func = func.as<tir::PrimFuncNode>()) {
+          arg_count.Set(gvar, Integer(static_cast<int>(prim_func->params.size())));
+        }
+      }
+      return arg_count;
+    }();
     VisitExpr(lowered_main_func->body);
 
     // Create the runner function. Please note that the function is not legal yet