Posted to commits@tvm.apache.org by cs...@apache.org on 2023/08/08 21:27:32 UTC

[tvm] branch main updated: [CodeGenC] Handle GlobalVar callee as internal function call (#15103)

This is an automated email from the ASF dual-hosted git repository.

csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9ff71f4a9f [CodeGenC] Handle GlobalVar callee as internal function call (#15103)
9ff71f4a9f is described below

commit 9ff71f4a9fed3ec9f82b999fb8c25ef1bf6e243c
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Tue Aug 8 16:27:24 2023 -0500

    [CodeGenC] Handle GlobalVar callee as internal function call (#15103)
    
    Analogous to #14901, treat GlobalVar callees as internal function
    calls in CodeGenC.  This PR doesn't provide new end-to-end
    functionality on its own, as the target="c" backend isn't
    compiled.  It does, however, lay the groundwork for supporting
    subroutines in any target whose codegen derives from CodeGenC,
    which will depend on the single-module lowering flow in #14985.
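
    As an illustrative sketch (function names invented here, not
    copied from the test suite), a PrimFunc that calls another
    PrimFunc through its GlobalVar can now be emitted as ordinary C
    along these lines:

        /* Forward declaration emitted for the callee. */
        static int32_t subroutine(int32_t* buf, int32_t n);

        /* The GlobalVar callee lowers to a plain internal call,
         * rather than a packed-function lookup. */
        int32_t main_func(int32_t* buf, int32_t n) {
          return subroutine(buf, n);
        }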
    
    * [CodeGenC] Added unit tests for desired behavior
    
    * [CodeGenC] Handle GlobalVar callee as internal function call
    
    * Update CodeGenC subclasses for updated interface
    
    - Call `DeclareFunction` for each `PrimFunc`, prior to any
      `AddFunction` calls.
    
    - Provide both `GlobalVar` and `PrimFunc` to `AddFunction` calls.
    
    * Updated CRT test to expect forward declaration
    
    * Provide forward declarations for call_extern in cmsis
    
    * Avoid duplicate forward declaration
    
    C's automatic pointer conversions (e.g. `void*` to `int*`) mean
    that using a call's arguments to infer the function signature may
    produce an incorrect declaration.  If a `call_extern` refers to a
    function within the same module, output only a single forward
    declaration, based on the PrimFunc's parameters rather than on the
    CallNode's arguments.
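
    For example (with a hypothetical function name, chosen for
    illustration), a call site passing an `int32_t*` argument to a
    callee whose PrimFunc parameter is a `void*` would otherwise
    produce two incompatible declarations:

        /* Inferred from the CallNode's arguments: */
        int32_t my_func(int32_t* buf);

        /* Emitted from the PrimFunc's parameters: */
        int32_t my_func(void* buf);

    A C compiler rejects this pair as conflicting types, even though
    either declaration alone would accept the same call site, since
    `int32_t*` converts implicitly to `void*`.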
    
    * Updated expected PTX output for CUDA
    
    * Cast the AOT pools to the arg type
    
    * Improved tvm::GetType for tvm_access_ptr and address_of
    
    `tvm::GetType` can now return a
    `PointerType(PrimType(pointee_dtype))` for these `Call` instances,
    rather than an opaque `PrimType(DataType::Handle())`.
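
    In effect (illustrative types and names, not a verbatim excerpt),
    a forward declaration whose argument comes from `address_of` on an
    int32 buffer can now be emitted as

        int32_t my_extern(int32_t* arg);  /* PointerType(PrimType(int32)) */

    instead of

        int32_t my_extern(void* arg);     /* PrimType(DataType::Handle()) */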
    
    * [ARM][Topi] Update micro kernels to use same argument type as caller
    
    Previously, the micro kernels for gemm, avg_pool, max_pool, and
    tensordot relied on C's implicit type conversions when the
    caller's argument types differed from the signature's parameter
    types.  This works, except when the codegen has auto-generated a
    forward declaration based on the caller's argument types (such as
    during AOT), which then conflicts with the kernel's definition.
    
    Since the codegen cannot determine the function names from the
    `"pragma_import_c"` in order to suppress these forward
    declarations, this conflict is more easily resolved by updating
    the micro kernel signatures.  The three types of mismatch are
    listed below.
    
    - Use of `int` or `long` parameters, whose width may vary by compiler,
      instead of fixed-width types.
    
    - TIR expecting the data array's integer type to also be used as an
      error code's return type, rather than the micro kernels' `int32_t`
      error code.
    
    - Pointer conversions performed implicitly during argument passing.
    
    Type conversions are done at the start of each micro kernel, to avoid
    changing types that are used within the computational sections of each
    micro kernel.
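
    Condensed from the gemm changes below (the real kernels use
    generated names such as `gemm_{M}x{K}x{N}_body_{uniq_id}`), the
    resulting pattern is:

        /* Before: the width of `int` varies by compiler, and the
         * signature conflicts with an auto-generated forward
         * declaration based on the caller's `int32_t` arguments. */
        static inline int32_t gemm_body(
            int8_t *aa, int8_t *bb, int32_t *cc, int A_stride);

        /* After: a fixed-width parameter, converted once at function
         * entry so the computational body keeps its original types. */
        static inline int32_t gemm_body(
            int8_t *aa, int8_t *bb, int32_t *cc, int32_t A_stride_arg) {
          int A_stride = A_stride_arg;
          /* ... computation unchanged ... */
        }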
    
    * Updated unit tests with private=True
    
    Required for internal functions after PR #15214
    
    * Docstring updates from review
---
 .../arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py  |   8 +-
 .../topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py |  87 +++++++++---
 .../arm_cpu/mprofile/dsp/micro_kernel/max_pool.py  |  13 +-
 .../arm_cpu/mprofile/dsp/micro_kernel/tensordot.py |   7 +-
 .../backend/contrib/cmsisnn/tir_to_runtime.cc      |  28 ++--
 .../contrib/example_target_hooks/tir_to_runtime.cc |  26 +++-
 src/relay/backend/contrib/uma/tir_to_runtime.cc    |  34 +++--
 src/target/opt/build_cuda_on.cc                    |  18 ++-
 src/target/source/codegen_aocl.cc                  |  19 ++-
 src/target/source/codegen_c.cc                     | 153 ++++++++++++++-------
 src/target/source/codegen_c.h                      |  59 +++++++-
 src/target/source/codegen_c_host.cc                |  93 ++++++-------
 src/target/source/codegen_c_host.h                 |   3 +-
 src/target/source/codegen_cuda.cc                  |   4 +-
 src/target/source/codegen_cuda.h                   |   2 +-
 src/target/source/codegen_metal.cc                 |  77 +++++------
 src/target/source/codegen_metal.h                  |   3 +-
 src/target/source/codegen_opencl.cc                |  24 ++--
 src/target/source/codegen_vhls.cc                  |  34 +++--
 src/target/source/codegen_webgpu.cc                |  79 +++++------
 src/target/source/codegen_webgpu.h                 |   4 +-
 src/target/source/source_module.cc                 |   6 +-
 src/tir/op/op.cc                                   |  26 ++++
 .../relay/aot/test_crt_forward_declarations.py     |   4 +-
 .../topi/python/test_topi_conv2d_tensordot_opts.py |  28 +++-
 .../python/unittest/test_target_codegen_c_host.py  |  48 +++++--
 .../test_tir_transform_inject_ptx_async_copy.py    |   1 +
 27 files changed, 591 insertions(+), 297 deletions(-)

diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py
index e8e45152aa..3eb32d8fdb 100644
--- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py
+++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py
@@ -55,7 +55,7 @@ def intrin_sum(shape, in_dtype, out_dtype, reset=False):
             ib = tvm.tir.ir_builder.create()
             ib.emit(
                 tvm.tir.call_extern(
-                    cc.dtype,
+                    "int32",
                     f"{func_prefix}_{width}_{uniq_id}",
                     aa.access_ptr("r"),
                     cc.access_ptr("w"),
@@ -68,7 +68,7 @@ def intrin_sum(shape, in_dtype, out_dtype, reset=False):
         def _reduce_reset():
             ib = tvm.tir.ir_builder.create()
             ib.emit(
-                tvm.tir.call_extern(cc.dtype, f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"))
+                tvm.tir.call_extern("int32", f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"))
             )
             return ib.get()
 
@@ -113,8 +113,8 @@ extern "C"
 __attribute__((always_inline)) static inline int32_t sum16_{N}_{uniq_id}(
     int16_t *arr,
     int16_t *res16,
-    long arr_offset,
-    int reset) {{
+    int32_t arr_offset,
+    int32_t reset) {{
   int n;
   int32_t *p32;
   int32_t res = reset ? 0 : *res16;
diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py
index 929dcc6557..e26e818fbd 100644
--- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py
+++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py
@@ -156,9 +156,14 @@ __attribute__((always_inline)) static inline const int8_t *read_and_pad(const in
 extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t gemm_{M}x{N}_body_rest_{uniq_id}(
-    int K,
+    int32_t K_arg,
     int8_t *aa, int8_t *bb, int32_t *cc,
-    int A_stride, int B_stride, int C_stride) {{
+    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+  int K = K_arg;
+  int A_stride = A_stride_arg;
+  int B_stride = B_stride_arg;
+  int C_stride = C_stride_arg;
+
   int k_base = (K / 4) * 4;
   switch ( K % 4 ) {{
   case 1:
@@ -200,7 +205,12 @@ extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_body_loop_{uniq_id}(
     int8_t *aa, int8_t *bb, int32_t *cc,
-    int A_stride, int B_stride, int C_stride) {{
+    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+  int A_stride = A_stride_arg;
+  int B_stride = B_stride_arg;
+  int C_stride = C_stride_arg;
+
+
   for (int i = 0; i < {M}; i++) {{
     for (int j = 0; j < {N}; j++) {{
       int32_t sum = 0;
@@ -221,7 +231,11 @@ extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_body_{uniq_id}(
     int8_t *aa, int8_t *bb, int32_t *cc,
-    int A_stride, int B_stride, int C_stride) {{
+    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+  int A_stride = A_stride_arg;
+  int B_stride = B_stride_arg;
+  int C_stride = C_stride_arg;
+
   int16_t bb_pad[{bb_pad_size}];
   int32_t retcode = 0;
 
@@ -265,9 +279,14 @@ out:
 extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t gemm_{M}x{N}_update_rest_{uniq_id}(
-    int K,
+    int32_t K_arg,
     int8_t *aa, int8_t *bb, int32_t *cc,
-    int A_stride, int B_stride, int C_stride) {{
+    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+  int K = K_arg;
+  int A_stride = A_stride_arg;
+  int B_stride = B_stride_arg;
+  int C_stride = C_stride_arg;
+
   int k_base = (K / 4) * 4;
   switch ( K % 4 ) {{
   case 1:
@@ -309,7 +328,11 @@ extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_update_loop_{uniq_id}(
     int8_t *aa, int8_t *bb, int32_t *cc,
-    int A_stride, int B_stride, int C_stride) {{
+    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+  int A_stride = A_stride_arg;
+  int B_stride = B_stride_arg;
+  int C_stride = C_stride_arg;
+
   for (int i = 0; i < {M}; i++) {{
     for (int j = 0; j < {N}; j++) {{
       int32_t sum = 0;
@@ -327,7 +350,11 @@ extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_update_{uniq_id}(
     int8_t *aa, int8_t *bb, int32_t *cc,
-    int A_stride, int B_stride, int C_stride) {{
+    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+  int A_stride = A_stride_arg;
+  int B_stride = B_stride_arg;
+  int C_stride = C_stride_arg;
+
   int16_t bb_pad[{bb_pad_size}];
   int32_t retcode = 0;
 
@@ -368,9 +395,14 @@ out:
 extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t gemm16_{M}x{N}_body_rest_{uniq_id}(
-    int K,
+    int32_t K_arg,
     int16_t *aa, int16_t *bb, int32_t *cc,
-    int A_stride, int B_stride, int C_stride) {{
+    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+  int K = K_arg;
+  int A_stride = A_stride_arg;
+  int B_stride = B_stride_arg;
+  int C_stride = C_stride_arg;
+
   int k_base = (K / 2) * 2;
   for (int i = 0; i < {M}; i++) {{
     for (int j = 0; j < {N}; j++) {{
@@ -387,7 +419,11 @@ extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_body_loop_{uniq_id}(
     int16_t *aa, int16_t *bb, int32_t *cc,
-    int A_stride, int B_stride, int C_stride) {{
+    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+  int A_stride = A_stride_arg;
+  int B_stride = B_stride_arg;
+  int C_stride = C_stride_arg;
+
   for (int i = 0; i < {M}; i++) {{
     for (int j = 0; j < {N}; j++) {{
       int32_t sum = 0;
@@ -408,7 +444,11 @@ extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_body_{uniq_id}(
     int16_t *aa, int16_t *bb, int32_t *cc,
-    int A_stride, int B_stride, int C_stride) {{
+    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+  int A_stride = A_stride_arg;
+  int B_stride = B_stride_arg;
+  int C_stride = C_stride_arg;
+
   int32_t retcode = 0;
 
   if ( {M} < 2 && {N} < 2 ) {{
@@ -450,9 +490,14 @@ out:
 extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t gemm16_{M}x{N}_update_rest_{uniq_id}(
-    int K,
+    int32_t K_arg,
     int16_t *aa, int16_t *bb, int32_t *cc,
-    int A_stride, int B_stride, int C_stride) {{
+    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+  int K = K_arg;
+  int A_stride = A_stride_arg;
+  int B_stride = B_stride_arg;
+  int C_stride = C_stride_arg;
+
   int k_base = (K / 2) * 2;
   for (int i = 0; i < {M}; i++) {{
     for (int j = 0; j < {N}; j++) {{
@@ -469,7 +514,11 @@ extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_update_loop_{uniq_id}(
     int16_t *aa, int16_t *bb, int32_t *cc,
-    int A_stride, int B_stride, int C_stride) {{
+    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+  int A_stride = A_stride_arg;
+  int B_stride = B_stride_arg;
+  int C_stride = C_stride_arg;
+
   for (int i = 0; i < {M}; i++) {{
     for (int j = 0; j < {N}; j++) {{
       int32_t sum = 0;
@@ -487,7 +536,11 @@ extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_update_{uniq_id}(
     int16_t *aa, int16_t *bb, int32_t *cc,
-    int A_stride, int B_stride, int C_stride) {{
+    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+  int A_stride = A_stride_arg;
+  int B_stride = B_stride_arg;
+  int C_stride = C_stride_arg;
+
   int32_t retcode = 0;
 
   if ( {M} < 2 && {N} < 2 ) {{
@@ -520,7 +573,7 @@ out:
 #ifdef __cplusplus
 extern "C"
 #endif
-__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int C_stride) {{
+__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int32_t C_stride) {{
   for (int i = 0; i < {M}; i++) {{
     for (int j = 0; j < {N}; j++) {{
       cc[i*C_stride + j] = 0;
diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py
index 66d712a4a0..cfed417c9f 100644
--- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py
+++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py
@@ -46,7 +46,7 @@ def intrin_max(shape, in_dtype, out_dtype):
             ib = tvm.tir.ir_builder.create()
             ib.emit(
                 tvm.tir.call_extern(
-                    cc.dtype,
+                    "int32",
                     f"{func_prefix}_{uniq_id}",
                     aa.access_ptr("r"),
                     cc.access_ptr("w"),
@@ -59,7 +59,7 @@ def intrin_max(shape, in_dtype, out_dtype):
             ib = tvm.tir.ir_builder.create()
             ib.emit(
                 tvm.tir.call_extern(
-                    cc.dtype, f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"), cc.strides[0]
+                    "int32", f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"), cc.strides[0]
                 )
             )
             return ib.get()
@@ -96,7 +96,7 @@ extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t max8_reset_{uniq_id}(
     int8_t *res,
-    int N) {{
+    int32_t N) {{
   memset(res, (int8_t)-128, N * sizeof(*res));
   return 0;
 }}
@@ -107,7 +107,9 @@ extern "C"
 __attribute__((always_inline)) static inline int32_t max8_loop_{uniq_id}(
     int8_t *arg,
     int8_t *res,
-    int N) {{
+    int32_t N_arg) {{
+  int N = N_arg;
+
   for ( int i = 0; i < N; ++ i )
     if ( arg[i] > res[i] )
       res[i] = arg[i];
@@ -120,7 +122,8 @@ extern "C"
 __attribute__((always_inline)) static inline int32_t max8_{uniq_id}(
     int8_t *arg,
     int8_t *res,
-    int N) {{
+    int32_t N_arg) {{
+  int N = N_arg;
   int32_t *parg32, *pres32;
   int una_arg = (int32_t)arg & 0x3, una_res = (int32_t)res & 0x3;
   int32_t retcode = 0;
diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py
index d2a8f1ef69..af3b23e01d 100644
--- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py
+++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py
@@ -390,8 +390,13 @@ def tensordot_int16_impl(
         #define {function_name.upper()}_EXISTS
         #include <arm_acle.h>
         __attribute__((always_inline)) static inline int32_t {function_name}(
-            int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale
+            int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
+            int32_t *bias, int32_t *scale
         ) {{
+          int32_t *output = output_arg;
+          int32_t *tensor = tensor_arg;
+          int32_t *kernel = kernel_arg;
+
           {_init_biased_accumulators(num_outputs)}
 
           {insert_lines(load_tensor_lines)}
diff --git a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
index ba2aea54bb..ea2eabd767 100644
--- a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
+++ b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
@@ -46,13 +46,6 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
     CodeGenCHost::Init(output_ssa, emit_asserts, emit_fwd_func_decl, target_str, devices);
   }
 
-  /*!
-   * \brief Emit code that offloads a subgraph to the Cortex-M
-   *
-   * \return string of code that offloads a subgraph to the Cortex-M
-   */
-  void AddFunction(const PrimFunc& prim_func) { CodeGenC::AddFunction(prim_func); }
-
  private:
   /*!  * \brief Enable storing the last error */
   bool debug_last_error;
@@ -519,11 +512,11 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
   bool emit_fwd_func_decl = false;
   bool debug_last_error = GetCompilerAttrs()->debug_last_error;
   CodeGenCMSISNN codegen;
-  Array<String> function_names;
   codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), debug_last_error);
-  std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> funcs;
-  for (auto kv : mod->functions) {
-    funcs.push_back(kv);
+
+  std::vector<std::pair<tvm::GlobalVar, tvm::PrimFunc>> funcs;
+  for (auto [gvar, base_func] : mod->functions) {
+    funcs.push_back({gvar, Downcast<PrimFunc>(base_func)});
   }
 
   std::sort(funcs.begin(), funcs.end(),
@@ -538,13 +531,16 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
               return name_hint_a < name_hint_b;
             });
 
-  for (auto kv : funcs) {
-    auto prim_func = Downcast<PrimFunc>(kv.second);
-    auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
-    function_names.push_back(global_symbol.value());
-    codegen.AddFunction(prim_func);
+  for (auto [gvar, prim_func] : funcs) {
+    codegen.AddFunction(gvar, prim_func);
   }
   std::string code = codegen.Finish();
+
+  Array<String> function_names;
+  for (auto [gvar, prim_func] : funcs) {
+    function_names.push_back(codegen.GetFunctionName(gvar));
+  }
+
   return codegen::CSourceModuleCreate(code, "c", function_names);
 }
 
diff --git a/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc b/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc
index 0db8d06c31..6f09e0a0c3 100644
--- a/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc
+++ b/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc
@@ -49,16 +49,30 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
   bool emit_asserts = false;
   bool emit_fwd_func_decl = false;
   CodeGenExampleTargetHook codegen;
-  Array<String> function_names;
+
   std::unordered_set<std::string> devices;
   codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices);
-  for (auto kv : mod->functions) {
-    auto prim_func = Downcast<PrimFunc>(kv.second);
-    auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
-    function_names.push_back(global_symbol.value());
-    codegen.AddFunction(prim_func);
+
+  Map<GlobalVar, PrimFunc> functions;
+  for (auto [gvar, base_func] : mod->functions) {
+    auto prim_func = Downcast<PrimFunc>(base_func);
+    functions.Set(gvar, prim_func);
+  }
+
+  for (auto [gvar, prim_func] : functions) {
+    codegen.DeclareFunction(gvar, prim_func);
+  }
+  for (auto [gvar, prim_func] : functions) {
+    codegen.AddFunction(gvar, prim_func, emit_fwd_func_decl);
   }
+
   std::string code = codegen.Finish();
+
+  Array<String> function_names;
+  for (auto [gvar, prim_func] : functions) {
+    function_names.push_back(codegen.GetFunctionName(gvar));
+  }
+
   return codegen::CSourceModuleCreate(code, "c", function_names);
 }
 
diff --git a/src/relay/backend/contrib/uma/tir_to_runtime.cc b/src/relay/backend/contrib/uma/tir_to_runtime.cc
index 3b58fda54b..487e247f5d 100644
--- a/src/relay/backend/contrib/uma/tir_to_runtime.cc
+++ b/src/relay/backend/contrib/uma/tir_to_runtime.cc
@@ -49,13 +49,6 @@ class UMACodegen : public codegen::CodeGenCHost {
     CodeGenCHost::Init(output_ssa, emit_asserts, emit_fwd_func_decl, target_str_, devices);
   }
 
-  /*!
-   * \brief Emit code that offloads a subgraph to the UMA target
-   *
-   * \return string of code that offloads a subgraph to the UMA target
-   */
-  void AddFunction(const PrimFunc& prim_func) { CodeGenC::AddFunction(prim_func); }
-
  private:
   String target_str_;
 };
@@ -63,17 +56,30 @@ class UMACodegen : public codegen::CodeGenCHost {
 runtime::Module TIRToRuntime(IRModule mod, Target target) {
   bool output_ssa = false;
   bool emit_asserts = false;
-  bool emit_fwd_func_decl = false;
+  bool emit_fwd_func_decl = true;
   UMACodegen codegen(target->kind->name);
-  Array<String> function_names;
   codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl);
-  for (auto kv : mod->functions) {
-    auto prim_func = Downcast<PrimFunc>(kv.second);
-    auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
-    function_names.push_back(global_symbol.value());
-    codegen.AddFunction(prim_func);
+
+  Map<GlobalVar, PrimFunc> functions;
+  for (auto [gvar, base_func] : mod->functions) {
+    auto prim_func = Downcast<PrimFunc>(base_func);
+    functions.Set(gvar, prim_func);
+  }
+
+  for (auto [gvar, prim_func] : functions) {
+    codegen.DeclareFunction(gvar, prim_func);
+  }
+  for (auto [gvar, prim_func] : functions) {
+    codegen.AddFunction(gvar, prim_func, emit_fwd_func_decl);
   }
+
   std::string code = codegen.Finish();
+
+  Array<String> function_names;
+  for (auto [gvar, prim_func] : functions) {
+    function_names.push_back(codegen.GetFunctionName(gvar));
+  }
+
   return codegen::CSourceModuleCreate(code, "c", function_names);
 }
 
diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc
index 1c0b5094ef..e0f53e3509 100644
--- a/src/target/opt/build_cuda_on.cc
+++ b/src/target/opt/build_cuda_on.cc
@@ -131,13 +131,21 @@ runtime::Module BuildCUDA(IRModule mod, Target target) {
   CodeGenCUDA cg;
   cg.Init(output_ssa);
 
-  for (auto kv : mod->functions) {
-    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc";
-    auto f = Downcast<PrimFunc>(kv.second);
-    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
+  Map<GlobalVar, PrimFunc> functions;
+  for (auto [gvar, base_func] : mod->functions) {
+    ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc";
+    auto prim_func = Downcast<PrimFunc>(base_func);
+    auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
     ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
         << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
-    cg.AddFunction(f);
+    functions.Set(gvar, prim_func);
+  }
+
+  for (auto [gvar, prim_func] : functions) {
+    cg.DeclareFunction(gvar, prim_func);
+  }
+  for (auto [gvar, prim_func] : functions) {
+    cg.AddFunction(gvar, prim_func);
   }
 
   std::string code = cg.Finish();
diff --git a/src/target/source/codegen_aocl.cc b/src/target/source/codegen_aocl.cc
index 700d85b4cc..dc3ba08751 100644
--- a/src/target/source/codegen_aocl.cc
+++ b/src/target/source/codegen_aocl.cc
@@ -40,13 +40,22 @@ runtime::Module BuildAOCL(IRModule mod, Target target, bool emulation) {
   CodeGenOpenCL cg;
   cg.Init(output_ssa);
 
-  for (auto kv : mod->functions) {
-    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodegenOpenCL: Can only take PrimFunc";
-    auto f = Downcast<PrimFunc>(kv.second);
-    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
+  Map<GlobalVar, PrimFunc> functions;
+  for (auto [gvar, base_func] : mod->functions) {
+    ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodegenOpenCL: Can only take PrimFunc";
+    auto prim_func = Downcast<PrimFunc>(base_func);
+    auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
     ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
         << "CodegenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
-    cg.AddFunction(f);
+    functions.Set(gvar, prim_func);
+  }
+
+  for (auto [gvar, prim_func] : functions) {
+    cg.DeclareFunction(gvar, prim_func);
+  }
+
+  for (auto [gvar, prim_func] : functions) {
+    cg.AddFunction(gvar, prim_func);
   }
 
   std::string code = cg.Finish();
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index a7cc320562..187bdc74fe 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -42,6 +42,7 @@ void CodeGenC::InitFuncState(const PrimFunc& f) {
   alloc_storage_scope_.clear();
   handle_data_type_.clear();
   CodeGenSourceBase::ClearFuncState();
+  ReserveKeywordsAsUnique();
 }
 
 void CodeGenC::ReserveKeywordsAsUnique() {
@@ -75,51 +76,92 @@ void CodeGenC::ReserveKeywordsAsUnique() {
   name_supply_->ReserveName("return");
 }
 
-void CodeGenC::AddFunction(const PrimFunc& f) {
-  // clear previous generated state.
-  this->InitFuncState(f);
-  // reserve keywords
-  ReserveKeywordsAsUnique();
+void CodeGenC::PrintFunctionSignature(const String& function_name, const PrimFunc& func,
+                                      std::ostream& os) {
+  PrintFuncPrefix(os);
+  PrintType(func->ret_type, os);
+  PrintExtraAttrs(func, os);
+  os << " " << function_name << "(";
+  for (size_t i = 0; i < func->params.size(); ++i) {
+    tir::Var v = func->params[i];
 
-  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
-  ICHECK(global_symbol.defined())
-      << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
-  bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
-
-  this->PrintFuncPrefix(stream);
-  PrintType(f->ret_type, stream);
-  this->PrintExtraAttrs(f);
-  this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";
-
-  for (size_t i = 0; i < f->params.size(); ++i) {
-    tir::Var v = f->params[i];
-    std::string vid = AllocVarID(v.get());
-    if (i != 0) stream << ", ";
-    if (v.dtype().is_handle()) {
-      auto it = alloc_storage_scope_.find(v.get());
-      if (it != alloc_storage_scope_.end()) {
-        PrintStorageScope(it->second, stream);
-      }
+    if (i > 0) {
+      os << ", ";
+    }
 
-      PrintType(GetType(v), stream);
-      // Register handle data type
-      // TODO(tvm-team): consider simply keep type info in the
-      // type annotation(via a normalizing rewriting).
-      if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) {
-        if (auto* prim = ptr->element_type.as<PrimTypeNode>()) {
-          RegisterHandleType(v.get(), prim->dtype);
-        }
-      }
+    if (auto it = alloc_storage_scope_.find(v.get()); it != alloc_storage_scope_.end()) {
+      PrintStorageScope(it->second, os);
+    }
 
-      if (no_alias) {
-        PrintRestrict(v, stream);
+    PrintType(GetType(v), os);
+
+    bool no_alias = func->HasNonzeroAttr(tir::attr::kNoAlias);
+    bool is_handle = v.dtype().is_handle();
+    if (no_alias && is_handle) {
+      PrintRestrict(v, os);
+    }
+
+    os << " " << AllocVarID(v.get());
+  }
+  os << ")";
+
+  // Register handle data type
+  // TODO(tvm-team): consider simply keep type info in the
+  // type annotation(via a normalizing rewriting).
+  for (const auto& param : func->params) {
+    if (auto* ptr = param->type_annotation.as<PointerTypeNode>()) {
+      if (auto* prim = ptr->element_type.as<PrimTypeNode>()) {
+        RegisterHandleType(param.get(), prim->dtype);
       }
-    } else {
-      PrintType(GetType(v), stream);
     }
-    stream << ' ' << vid;
   }
-  stream << ") {\n";
+}
+
+void CodeGenC::DeclareFunction(const GlobalVar& gvar, const PrimFunc& func) {
+  if (internal_functions_.count(gvar)) {
+    return;
+  }
+
+  auto function_name = [&]() -> String {
+    if (auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
+      auto name = global_symbol.value();
+      ICHECK(!func_name_supply_->ContainsName(name))
+          << "Function " << gvar << " must use global symbol " << name
+          << ", but this name has already been used.";
+      func_name_supply_->ReserveName(name);
+      return name;
+    } else {
+      func_name_supply_->ReserveName(gvar->name_hint);
+      return gvar->name_hint;
+    }
+  }();
+
+  internal_functions_.insert({gvar, function_name});
+
+  InitFuncState(func);
+  PrintFunctionSignature(function_name, func, fwd_decl_stream);
+  fwd_decl_stream << ";\n";
+}
+
+String CodeGenC::GetFunctionName(const GlobalVar& gvar) {
+  auto it = internal_functions_.find(gvar);
+  ICHECK(it != internal_functions_.end())
+      << "Attempted to find name of " << gvar
+      << ", but no function with this GlobalVar has been declared";
+  return it->second;
+}
+
+void CodeGenC::AddFunction(const GlobalVar& gvar, const PrimFunc& f) {
+  // If the function has already been forward-declared, this is a
+  // no-op.
+  DeclareFunction(gvar, f);
+  auto function_name = GetFunctionName(gvar);
+
+  // clear previous generated state.
+  InitFuncState(f);
+
+  PrintFunctionSignature(function_name, f, stream);
+  stream << " {\n";
   this->PreFunctionBody(f);
   int func_scope = this->BeginScope();
   this->PrintStmt(f->body);
@@ -130,9 +172,15 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
 
 void CodeGenC::PrintFuncPrefix(std::ostream& os) {}
 
-void CodeGenC::PrintExtraAttrs(const PrimFunc& f) {}
+void CodeGenC::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) {}
 
-std::string CodeGenC::Finish() { return decl_stream.str() + stream.str(); }
+std::string CodeGenC::Finish() {
+  std::ostringstream code;
+  code << decl_stream.str();
+  code << fwd_decl_stream.str();
+  code << stream.str();
+  return code.str();
+}
 
 void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) {  // NOLINT(*)
   if (print_ssa_form_) {
@@ -542,12 +590,17 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT(*)
       ICHECK_GE(op->args.size(), 1U);
       auto func = Downcast<StringImm>(op->args[0]);
       this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), func->value, op->args, true, os);
-      Array<Type> arg_types;
-      for (size_t i = 1; i < op->args.size(); i++) {
-        arg_types.push_back(GetType(op->args[i]));
+
+      // If the call_extern refers to a function within the IRModule, then
+      // the forward declaration is already provided by DeclareFunction.
+      if (!func_name_supply_->ContainsName(func->value)) {
+        Array<Type> arg_types;
+        for (size_t i = 1; i < op->args.size(); i++) {
+          arg_types.push_back(GetType(op->args[i]));
+        }
+        Type ret_type = GetTypeFromRuntimeDataType(op->dtype);
+        this->GenerateForwardFunctionDeclarations(func->value, arg_types, ret_type);
       }
-      Type ret_type = GetTypeFromRuntimeDataType(op->dtype);
-      this->GenerateForwardFunctionDeclarations(func->value, arg_types, ret_type);
     } else if (op_attr_global_symbol_.count(call_op)) {
       // call extern if the op itself have a global symbol.
       this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_attr_global_symbol_[call_op],
@@ -615,9 +668,13 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT(*)
     } else {
       LOG(FATAL) << "Unresolved call " << op->op;
     }
+  } else if (auto opt = op->op.as<GlobalVar>()) {
+    auto gvar = opt.value();
+    auto callee_name = GetFunctionName(gvar);
+    PrintCallExtern(GetType(GetRef<PrimExpr>(op)), callee_name, op->args, false, os);
   } else {
-    ICHECK(op->op.as<GlobalVarNode>());
-    LOG(FATAL) << "Do not yet support cross function call";
+    LOG(FATAL) << "CodeGenC: Unknown operation " << op->op << " is neither a recognized built-in, "
+               << "nor a GlobalVar reference to another function in the IRModule";
   }
 }
 
diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h
index 93f9ea519c..2921a56ef3 100644
--- a/src/target/source/codegen_c.h
+++ b/src/target/source/codegen_c.h
@@ -65,12 +65,33 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
    * \param output_ssa Whether output SSA.
    */
   void Init(bool output_ssa);
+
   /*!
-   * \brief Add the function to the generated module.
-   * \param f The function to be compiled.
+   * \brief Add the function declaration to the generated module,
+   * without defining it.
+   *
+   * \param gvar The GlobalVar representing the function.
+   * \param func The function to be compiled.
    * \param whether to append return 0 in the end.
    */
-  void AddFunction(const PrimFunc& f);
+  virtual void DeclareFunction(const GlobalVar& gvar, const PrimFunc& func);
+
+  /*!
+   * \brief Add the function to the generated module, including its
+   * declaration and definition.
+   *
+   * \param gvar The GlobalVar representing the function.
+   * \param func The function to be compiled.
+   */
+  virtual void AddFunction(const GlobalVar& gvar, const PrimFunc& func);
+
+  /*!
+   * \brief Get the name of a declared function
+   * \param gvar The GlobalVar of the function
+   * \returns The string name of the function
+   */
+  String GetFunctionName(const GlobalVar& gvar);
+
   /*!
    * \brief Finalize the compilation and return the code.
    * \return The code.
@@ -96,7 +117,23 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
     PrintExpr(n, os);
     return os.str();
   }
+
   // The following parts are overloadable print operations.
+
+  /*! \brief Print the function signature before the argument list
+   *
+   * The default implementation delegates out to PrintFuncPrefix and
+   * PrintExtraAttrs.
+   *
+   * \param function_name The name of the function
+   *
+   * \param func The function whose signature should be printed
+   *
+   * \param os The output stream
+   */
+  virtual void PrintFunctionSignature(const String& function_name, const PrimFunc& func,
+                                      std::ostream& os);
+
   /*!
    * \brief Print the function header before the argument list
    * \param os The output stream
@@ -109,7 +146,7 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
    *
    *  Example: __launch_bounds__(256) for CUDA functions
    */
-  virtual void PrintExtraAttrs(const PrimFunc& f);
+  virtual void PrintExtraAttrs(const PrimFunc& f, std::ostream& os);  // NOLINT(*)
   /*!
    * \brief Insert statement before function body.
    * \param f The function to be compiled.
@@ -284,10 +321,24 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
  private:
   /*! \brief set of volatile buf access */
   std::unordered_set<const VarNode*> volatile_buf_;
+
   // deep comparison of PrimExpr
   ExprDeepEqual deep_equal_;
+
   // binding of let variables. Enables duplicate var defs that map to same value
   std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
+
+  /* \brief Map of GlobalVar to their symbol.
+   *
+   * For externally-exposed functions, this is given by the
+   * tvm::attr::kGlobalSymbol attribute of the PrimFunc.  For internal
+   * functions, this is the name of the function's GlobalVar, possibly
+   * altered to prevent duplicate names.
+   */
+  std::unordered_map<GlobalVar, String, ObjectPtrHash, ObjectPtrEqual> internal_functions_;
+
+  /* \brief Name supply to generate unique function names */
+  NameSupply func_name_supply_{""};
 };
 
 }  // namespace codegen
diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc
index 3255e11c5d..caef43e8af 100644
--- a/src/target/source/codegen_c_host.cc
+++ b/src/target/source/codegen_c_host.cc
@@ -75,19 +75,24 @@ void CodeGenCHost::InitGlobalContext() {
 
 void CodeGenCHost::DefineModuleName() { decl_stream << "void* " << module_name_ << " = NULL;\n"; }
 
-void CodeGenCHost::AddFunction(const PrimFunc& f, bool emit_fwd_func_decl) {
-  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
-  ICHECK(global_symbol.defined())
-      << "CodeGenCHost: Expect PrimFunc to have the global_symbol attribute";
-  function_names_.push_back(global_symbol.value());
+void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func,
+                               bool emit_fwd_func_decl) {
+  auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+  if (global_symbol) {
+    function_names_.push_back(global_symbol.value());
+  }
 
   emit_fwd_func_decl_ = emit_fwd_func_decl;
-  CodeGenC::AddFunction(f);
-  if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
+  CodeGenC::AddFunction(gvar, func);
+  if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
+    ICHECK(global_symbol.defined())
+        << "CodeGenCHost: The entry func must have the global_symbol attribute, "
+        << "but function " << gvar << " only has attributes " << func->attrs;
+
     function_names_.push_back(runtime::symbol::tvm_module_main);
     stream << "// CodegenC: NOTE: Auto-generated entry function\n";
     PrintFuncPrefix(stream);
-    PrintType(f->ret_type, stream);
+    PrintType(func->ret_type, stream);
     stream << " " << tvm::runtime::symbol::tvm_module_main
            << "(void* args, int* arg_type_ids, int num_args, void* out_ret_value, "
            << "int* out_ret_tcode, void* resource_handle) {\n";
@@ -128,15 +133,6 @@ void CodeGenCHost::PrintFuncPrefix(std::ostream& os) {  // NOLINT(*)
      << "TVM_DLL ";
 }
 
-std::string CodeGenCHost::Finish() {  // NOLINT(*)
-  std::string ret = decl_stream.str();
-  if (emit_fwd_func_decl_) {
-    ret += fwd_decl_stream.str();
-  }
-  ret += stream.str();
-  return ret;
-}
-
 void CodeGenCHost::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
   int lanes = t.lanes();
   if (t.is_handle()) {
@@ -437,42 +433,38 @@ runtime::Module BuildCHost(IRModule mod, Target target) {
   CodeGenCHost cg;
   cg.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices);
   cg.SetConstantsByteAlignment(target->GetAttr<Integer>("constants-byte-alignment").value_or(16));
-  PrimFunc aot_executor_fn;
-
-  std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> funcs;
-  for (auto kv : mod->functions) {
-    // Make sure that the executor function is the last one to be code generated so that all the
-    // symbols are available to __tvm_main__
-    auto fun_name = std::string(kv.first->name_hint);
-    bool is_aot_executor_fn = kv.second->GetAttr<Bool>("runner_function", Bool(false)).value();
-
-    if (is_aot_executor_fn) {
-      aot_executor_fn = Downcast<PrimFunc>(kv.second);
-      continue;
-    }
-    funcs.push_back(kv);
+
+  auto is_aot_executor_fn = [](const PrimFunc& func) -> bool {
+    return func->GetAttr<Bool>("runner_function", Bool(false)).value();
+  };
+
+  std::vector<std::pair<GlobalVar, PrimFunc>> funcs;
+  for (auto [gvar, base_func] : mod->functions) {
+    ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodegenCHost: Can only take PrimFunc";
+    auto prim_func = Downcast<PrimFunc>(base_func);
+    funcs.push_back({gvar, prim_func});
   }
 
   // Sort functions
-  std::sort(funcs.begin(), funcs.end(),
-            [](std::pair<tvm::GlobalVar, tvm::BaseFunc> kv_a,
-               std::pair<tvm::GlobalVar, tvm::BaseFunc> kv_b) {
-              std::string name_hint_a = kv_a.first->name_hint;
-              std::string name_hint_b = kv_b.first->name_hint;
-              return name_hint_a < name_hint_b;
-            });
-
-  // Add all functions except __tvm_main__
-  for (auto& kv : funcs) {
-    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodegenCHost: Can only take PrimFunc";
-    auto f = Downcast<PrimFunc>(kv.second);
-    cg.AddFunction(f);
+  auto sort_key = [&is_aot_executor_fn](const auto& kv) {
+    return std::tuple{is_aot_executor_fn(kv.second), kv.first->name_hint};
+  };
+  std::sort(funcs.begin(), funcs.end(), [&sort_key](const auto& kv_a, const auto& kv_b) {
+    return sort_key(kv_a) < sort_key(kv_b);
+  });
+
+  // Declare all functions first.  This ensures that all functions,
+  // including the __tvm_main__ used in AOT, have access to forward
+  // declarations of other functions in the IRModule.
+  for (const auto& [gvar, prim_func] : funcs) {
+    cg.DeclareFunction(gvar, prim_func);
   }
 
-  // Add __tvm_main__
-  if (aot_executor_fn.defined()) {
-    emit_fwd_func_decl = true;
-    cg.AddFunction(aot_executor_fn, emit_fwd_func_decl);
+  // Codegen all functions.  Passing emit_fwd_func_decl=true adds a
+  // forward declaration for any `builtin::call_extern`, based on the
+  // arguments provided to it.
+  for (const auto& [gvar, prim_func] : funcs) {
+    cg.AddFunction(gvar, prim_func, emit_fwd_func_decl);
   }
 
   // NOTE: it's possible that kRuntime attr is not attached when the mod was built with tvm.build().
@@ -484,7 +476,10 @@ runtime::Module BuildCHost(IRModule mod, Target target) {
   } else {
     runtime = relay::Runtime::Create("cpp", {});
   }
-  if (aot_executor_fn.defined() && runtime->name == relay::kTvmRuntimeCpp) {
+
+  bool has_aot_executor_fn = std::any_of(
+      funcs.begin(), funcs.end(), [&](const auto& kv) { return is_aot_executor_fn(kv.second); });
+  if (has_aot_executor_fn && runtime->name == relay::kTvmRuntimeCpp) {
     cg.InitGlobalContext();
   }
 
diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h
index 694104afc0..aeba685f74 100644
--- a/src/target/source/codegen_c_host.h
+++ b/src/target/source/codegen_c_host.h
@@ -44,8 +44,7 @@ class CodeGenCHost : public CodeGenC {
             const std::unordered_set<std::string>& devices);
 
   void InitGlobalContext();
-  void AddFunction(const PrimFunc& f, bool emit_fwd_func_decl = false);
-  std::string Finish() final;
+  void AddFunction(const GlobalVar& gvar, const PrimFunc& f, bool emit_fwd_func_decl = false);
   /*!
    * \brief Add functions from the (unordered) range to the current module in a deterministic
    * order. This helps with debugging.
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 22103f7b0f..6c02348191 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -75,7 +75,7 @@ class ThreadIdxExtractor : public tir::StmtVisitor {
   PrimExpr threadIdx_z_ext = Integer(1);
 };
 
-void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f) {
+void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) {
   ThreadIdxExtractor extractor;
   extractor(f->body);
   arith::Analyzer analyzer;
@@ -86,7 +86,7 @@ void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f) {
       // unable to extract the number of threads per block, hence directly return
       return;
     }
-    stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")";
+    os << " __launch_bounds__(" << threadIdx_ext_int->value << ")";
   }
 }
 
diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h
index c6cf96d460..7de6ae05e8 100644
--- a/src/target/source/codegen_cuda.h
+++ b/src/target/source/codegen_cuda.h
@@ -47,7 +47,7 @@ class CodeGenCUDA final : public CodeGenC {
   }
   // override behavior
   void PrintFuncPrefix(std::ostream& os) final;
-  void PrintExtraAttrs(const PrimFunc& f) final;
+  void PrintExtraAttrs(const PrimFunc& f, std::ostream& os) final;  // NOLINT(*)
   void VisitStmt_(const ForNode* op) final;
   void PrintStorageSync(const CallNode* op) final;
   void PrintStorageScope(const std::string& scope, std::ostream& os) final;  // NOLINT(*)
diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc
index b8c30691e2..3db8d216b3 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -36,6 +36,8 @@ namespace codegen {
 
 void CodeGenMetal::InitFuncState(const PrimFunc& f) {
   CodeGenC::InitFuncState(f);
+  // skip the first underscore, so SSA variable starts from _1
+  name_supply_->FreshName("v_");
   // analyze the data;
   for (Var arg : f->params) {
     if (arg.dtype().is_handle()) {
@@ -52,37 +54,33 @@ CodeGenMetal::CodeGenMetal(Target target) : target_(target) {
               << "};\n\n";
 }
 
-void CodeGenMetal::AddFunction(const PrimFunc& f) {
-  // clear previous generated state.
-  this->InitFuncState(f);
-  // skip the first underscore, so SSA variable starts from _1
-  name_supply_->FreshName("v_");
-
+void CodeGenMetal::PrintFunctionSignature(const String& function_name, const PrimFunc& func,
+                                          std::ostream& os) {
   // add to alloc buffer type.
-  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
+  auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
   ICHECK(global_symbol.defined())
       << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
 
   // Function header.
-  this->stream << "kernel void " << static_cast<std::string>(global_symbol.value()) << "(";
+  os << "kernel void " << static_cast<std::string>(global_symbol.value()) << "(";
 
   // Buffer arguments
   size_t num_buffer = 0;
   size_t limit = target_->GetAttr<Integer>("max_function_args").value().IntValue();
-  if (f->params.size() > limit) {
+  if (func->params.size() > limit) {
     LOG(WARNING) << "Probably you won't be able to execute your kernel due to high number of "
                     "buffers in the kernel";
   }
-  for (size_t i = 0; i < f->params.size(); ++i, ++num_buffer) {
-    Var v = f->params[i];
+  for (size_t i = 0; i < func->params.size(); ++i, ++num_buffer) {
+    Var v = func->params[i];
     if (!v.dtype().is_handle()) break;
-    stream << "  ";
+    os << "  ";
     std::string vid = AllocVarID(v.get());
     auto it = alloc_storage_scope_.find(v.get());
     if (it != alloc_storage_scope_.end()) {
-      PrintStorageScope(it->second, stream);
+      PrintStorageScope(it->second, os);
     }
-    PrintType(GetType(v), stream);
+    PrintType(GetType(v), os);
     // Register handle data type
     // TODO(tvm-team): consider simply keep type info in the
     // type annotation(via a normalizing rewriting).
@@ -91,19 +89,18 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
         RegisterHandleType(v.get(), prim->dtype);
       }
     }
-    stream << ' ' << vid << " [[ buffer(" << i << ") ]],\n";
+    os << ' ' << vid << " [[ buffer(" << i << ") ]],\n";
   }
   // Setup normal arguments.
-  size_t nargs = f->params.size() - num_buffer;
+  size_t nargs = func->params.size() - num_buffer;
   std::string varg = name_supply_->FreshName("arg");
   if (nargs != 0) {
     std::string arg_buf_type = static_cast<std::string>(global_symbol.value()) + "_args_t";
-    stream << "  constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer
-           << ") ]],\n";
+    os << "  constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer << ") ]],\n";
     // declare the struct
     decl_stream << "struct " << arg_buf_type << " {\n";
-    for (size_t i = num_buffer; i < f->params.size(); ++i) {
-      Var v = f->params[i];
+    for (size_t i = num_buffer; i < func->params.size(); ++i) {
+      Var v = func->params[i];
       ICHECK(!v.dtype().is_handle());
       std::string vid = AllocVarID(v.get());
       std::ostringstream vref;
@@ -131,7 +128,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
   ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
   ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
   int work_dim = 0;
-  auto launch_params = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams).value();
+  auto launch_params = func->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams).value();
   for (const auto& tag : launch_params) {
     if (tag != runtime::launch_param::kUseDynamicSharedMemoryTag) {
       runtime::ThreadScope scope = runtime::ThreadScope::Create(tag);
@@ -150,13 +147,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
   }
   thread_work_dim_ = work_dim;
 
-  // the function scope.
-  stream << ") {\n";
-  int func_scope = this->BeginScope();
-  this->PrintStmt(f->body);
-  this->EndScope(func_scope);
-  this->PrintIndent();
-  this->stream << "}\n\n";
+  stream << ")";
 }
 
 void CodeGenMetal::BindThreadIndex(const IterVar& iv) {
@@ -342,27 +333,33 @@ runtime::Module BuildMetal(IRModule mod, Target target) {
   const auto* fmetal_compile = Registry::Get("tvm_callback_metal_compile");
   std::string fmt = fmetal_compile ? "metallib" : "metal";
 
-  for (auto kv : mod->functions) {
-    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenMetal: Can only take PrimFunc";
-    auto global_symbol = kv.second->GetAttr<String>(tvm::attr::kGlobalSymbol);
-    ICHECK(global_symbol.defined());
-    std::string func_name = global_symbol.value();
+  Map<GlobalVar, PrimFunc> functions;
+  for (auto [gvar, base_func] : mod->functions) {
+    ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenMetal: Can only take PrimFunc";
+    auto calling_conv = base_func->GetAttr<Integer>(tvm::attr::kCallingConv);
+    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
+        << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
+
+    auto prim_func = Downcast<PrimFunc>(base_func);
+    functions.Set(gvar, prim_func);
+  }
 
-    source_maker << "// Function: " << func_name << "\n";
+  for (auto [gvar, prim_func] : functions) {
+    source_maker << "// Function: " << gvar->name_hint << "\n";
     CodeGenMetal cg(target);
     cg.Init(output_ssa);
-    auto f = Downcast<PrimFunc>(kv.second);
-    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
-    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
-        << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
 
-    cg.AddFunction(f);
+    for (auto [other_gvar, other_prim_func] : functions) {
+      cg.DeclareFunction(other_gvar, other_prim_func);
+    }
+    cg.AddFunction(gvar, prim_func);
+
     std::string fsource = cg.Finish();
     source_maker << fsource << "\n";
     if (fmetal_compile) {
       fsource = (*fmetal_compile)(fsource, target).operator std::string();
     }
-    smap[func_name] = fsource;
+    smap[cg.GetFunctionName(gvar)] = fsource;
   }
 
   return MetalModuleCreate(smap, ExtractFuncInfo(mod), fmt, source_maker.str());
diff --git a/src/target/source/codegen_metal.h b/src/target/source/codegen_metal.h
index 36be10d163..26c991e60d 100644
--- a/src/target/source/codegen_metal.h
+++ b/src/target/source/codegen_metal.h
@@ -38,7 +38,8 @@ class CodeGenMetal final : public CodeGenC {
   explicit CodeGenMetal(Target target);
   // override print thread tag.
   void PrintArgUnionDecl();
-  void AddFunction(const PrimFunc& f);  // NOLINT(*)
+  void PrintFunctionSignature(const String& function_name, const PrimFunc& func,
+                              std::ostream& os) override;
   void InitFuncState(const PrimFunc& f) final;
   void PrintStorageScope(const std::string& scope, std::ostream& os) final;  // NOLINT(*)
   void PrintStorageSync(const CallNode* op) final;                           // NOLINT(*)
diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc
index c15d2253d7..da6a4de619 100644
--- a/src/target/source/codegen_opencl.cc
+++ b/src/target/source/codegen_opencl.cc
@@ -595,18 +595,26 @@ runtime::Module BuildOpenCL(IRModule mod, Target target) {
   using tvm::runtime::Registry;
   bool output_ssa = false;
 
+  Map<GlobalVar, PrimFunc> functions;
+  for (auto [gvar, base_func] : mod->functions) {
+    ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenOpenCL: Can only take PrimFunc";
+    auto prim_func = Downcast<PrimFunc>(base_func);
+    auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
+    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
+        << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
+    functions.Set(gvar, prim_func);
+  }
+
   std::stringstream code;
   const auto* fpostproc = Registry::Get("tvm_callback_opencl_postproc");
-  for (auto kv : mod->functions) {
-    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenOpenCL: Can only take PrimFunc";
-    code << "// Function: " << kv.first->name_hint << std::endl;
+  for (auto [gvar, prim_func] : functions) {
+    code << "// Function: " << gvar->name_hint << std::endl;
     CodeGenOpenCL cg;
     cg.Init(output_ssa);
-    auto f = Downcast<PrimFunc>(kv.second);
-    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
-    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
-        << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
-    cg.AddFunction(f);
+    for (auto [other_gvar, other_prim_func] : functions) {
+      cg.DeclareFunction(other_gvar, other_prim_func);
+    }
+    cg.AddFunction(gvar, prim_func);
     std::string fsource = cg.Finish();
     if (fpostproc) {
       fsource = (*fpostproc)(fsource, target).operator std::string();
diff --git a/src/target/source/codegen_vhls.cc b/src/target/source/codegen_vhls.cc
index 83046de107..aa7a32320c 100644
--- a/src/target/source/codegen_vhls.cc
+++ b/src/target/source/codegen_vhls.cc
@@ -145,13 +145,21 @@ runtime::Module BuildSDAccel(IRModule mod, Target target) {
   // Generate source code for get_source().
   cg.Init(output_ssa);
 
-  for (auto kv : mod->functions) {
-    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenVHLS: Can only take PrimFunc";
-    auto f = Downcast<PrimFunc>(kv.second);
-    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
+  Map<GlobalVar, PrimFunc> functions;
+  for (auto [gvar, base_func] : mod->functions) {
+    ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenVHLS: Can only take PrimFunc";
+    auto prim_func = Downcast<PrimFunc>(base_func);
+    auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
     ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
         << "CodeGenVLHS: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
-    cg.AddFunction(f);
+    functions.Set(gvar, prim_func);
+  }
+
+  for (auto [gvar, prim_func] : functions) {
+    cg.DeclareFunction(gvar, prim_func);
+  }
+  for (auto [gvar, prim_func] : functions) {
+    cg.AddFunction(gvar, prim_func);
   }
 
   std::string whole_code = cg.Finish();
@@ -159,21 +167,21 @@ runtime::Module BuildSDAccel(IRModule mod, Target target) {
   // Generate source code for compilation.
   Array<Array<runtime::String>> kernel_info;
 
-  for (auto kv : mod->functions) {
-    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenOpenCL: Can only take PrimFunc";
-    auto f = Downcast<PrimFunc>(kv.second);
+  for (auto [gvar, prim_func] : functions) {
     CodeGenVivadoHLS cg;
     cg.Init(output_ssa);
-    cg.AddFunction(f);
+
+    for (auto [other_gvar, other_prim_func] : functions) {
+      cg.DeclareFunction(other_gvar, other_prim_func);
+    }
+    cg.AddFunction(gvar, prim_func);
     std::string code = cg.Finish();
     if (const auto* f = runtime::Registry::Get("tvm_callback_vhls_postproc")) {
       code = (*f)(code, target).operator std::string();
     }
 
-    auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
-    ICHECK(global_symbol.defined())
-        << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
-    kernel_info.push_back({global_symbol.value(), code});
+    auto function_name = cg.GetFunctionName(gvar);
+    kernel_info.push_back({function_name, code});
   }
 
   std::string xclbin;
diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc
index 4d1d834c7f..6a6712a4ce 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -45,6 +45,12 @@ std::string CodeGenWebGPU::Finish() {
 
 void CodeGenWebGPU::InitFuncState(const PrimFunc& f) {
   CodeGenC::InitFuncState(f);
+  // skip the first underscore, so SSA variable starts from
+  name_supply_->FreshName("v_");
+  // Setup the thread group info.
+  ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
+  ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
+
   // analyze the data;
   for (Var arg : f->params) {
     if (arg.dtype().is_handle()) {
@@ -56,28 +62,12 @@ void CodeGenWebGPU::InitFuncState(const PrimFunc& f) {
 
 CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {}
 
-void CodeGenWebGPU::AddFunction(const PrimFunc& f) {
-  // clear previous generated state.
-  this->InitFuncState(f);
-  // skip the first underscore, so SSA variable starts from
-  name_supply_->FreshName("v_");
-  // Setup the thread group info.
-  ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
-  ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
-
-  // add to alloc buffer type.
-  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
-  ICHECK(global_symbol.defined())
-      << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute";
-
-  decl_stream << "//----------------------------------------\n"
-              << "// function: " << global_symbol.value() << "\n"
-              << "//----------------------------------------\n";
-
+void CodeGenWebGPU::PrintFunctionSignature(const String& function_name, const PrimFunc& func,
+                                           std::ostream& os) {
   std::vector<Var> pod_args;
   int num_buffer = 0;
   // setup buffer arguments
-  for (Var arg : f->params) {
+  for (Var arg : func->params) {
     DataType t = arg.dtype();
     if (t.is_handle()) {
       auto* ptr = arg->type_annotation.as<PointerTypeNode>();
@@ -111,16 +101,18 @@ void CodeGenWebGPU::AddFunction(const PrimFunc& f) {
   }
   // add to alloc buffer type.
   // Function header.
-  this->stream << "fn main(\n"
-               << "  @builtin(workgroup_id) blockIdx : vec3<u32>,\n"
-               << "  @builtin(local_invocation_id) threadIdx : vec3<u32>\n"
-               << ") {\n";
-  // the function scope.
-  int func_scope = this->BeginScope();
-  this->PrintStmt(f->body);
-  this->EndScope(func_scope);
-  this->PrintIndent();
-  this->stream << "}\n\n";
+  os << "fn main(\n"
+     << "  @builtin(workgroup_id) blockIdx : vec3<u32>,\n"
+     << "  @builtin(local_invocation_id) threadIdx : vec3<u32>\n"
+     << ")";
+}
+
+void CodeGenWebGPU::AddFunction(const GlobalVar& gvar, const PrimFunc& func) {
+  CodeGenC::AddFunction(gvar, func);
+  decl_stream << "//----------------------------------------\n"
+              << "// function: " << GetFunctionName(gvar) << "\n"
+              << "//----------------------------------------\n";
+
   // annotate workgroup
   this->fwd_decl_stream << "@compute @workgroup_size(" << workgroup_size_[0] << ", "
                         << workgroup_size_[1] << ", " << workgroup_size_[2] << ")\n";
@@ -524,22 +516,31 @@ runtime::Module BuildWebGPU(IRModule mod, Target target) {
   mod = tir::transform::PointerValueTypeRewrite()(std::move(mod));
   bool output_ssa = false;
 
-  std::unordered_map<std::string, std::string> smap;
-  for (auto kv : mod->functions) {
-    CodeGenWebGPU cg(target);
-    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenWebGPU: Can only take PrimFunc";
-    auto f = Downcast<PrimFunc>(kv.second);
-    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
+  Map<GlobalVar, PrimFunc> functions;
+  for (auto [gvar, base_func] : mod->functions) {
+    ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenWebGPU: Can only take PrimFunc";
+    auto prim_func = Downcast<PrimFunc>(base_func);
+    auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
     ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
         << "CodeGenWebGPU: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
-    auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
+    auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
     ICHECK(global_symbol.defined())
         << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute";
-    std::string f_name = global_symbol.value();
+    functions.Set(gvar, prim_func);
+  }
+
+  std::unordered_map<std::string, std::string> smap;
+  for (auto [gvar, prim_func] : functions) {
+    CodeGenWebGPU cg(target);
     cg.Init(output_ssa);
-    cg.AddFunction(f);
+
+    for (auto [other_gvar, other_prim_func] : functions) {
+      cg.DeclareFunction(other_gvar, other_prim_func);
+    }
+    cg.AddFunction(gvar, prim_func);
+
     std::string code = cg.Finish();
-    smap[f_name] = code;
+    smap[cg.GetFunctionName(gvar)] = code;
   }
   auto n = make_object<WebGPUSourceModuleNode>(smap, ExtractFuncInfo(mod));
   return runtime::Module(n);
diff --git a/src/target/source/codegen_webgpu.h b/src/target/source/codegen_webgpu.h
index 57f226ba8a..6ae942a3ad 100644
--- a/src/target/source/codegen_webgpu.h
+++ b/src/target/source/codegen_webgpu.h
@@ -48,7 +48,9 @@ class CodeGenWebGPU final : public CodeGenC {
   explicit CodeGenWebGPU(Target target);
   // overrides
   std::string Finish() final;
-  void AddFunction(const PrimFunc& f);  // NOLINT(*)
+  void PrintFunctionSignature(const String& function_name, const PrimFunc& func,
+                              std::ostream& os) final;
+  void AddFunction(const GlobalVar& gvar, const PrimFunc& f) final;
   void InitFuncState(const PrimFunc& f) final;
   void PrintStorageSync(const CallNode* op) final;     // NOLINT(*)
   void PrintType(DataType t, std::ostream& os) final;  // NOLINT(*)
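
Splitting PrintFunctionSignature out of AddFunction follows the
template-method pattern: the shared CodeGenC flow is assumed to call the
virtual signature hook and then print the body itself. A minimal,
self-contained sketch of that shape (illustrative classes, not TVM's
actual ones):

    #include <iostream>
    #include <string>

    struct BaseCodegen {
      virtual ~BaseCodegen() = default;
      void AddFunction(const std::string& name) {
        PrintFunctionSignature(name, std::cout);  // virtual hook
        std::cout << " { /* body */ }\n";         // shared body emission
      }
      virtual void PrintFunctionSignature(const std::string& name,
                                          std::ostream& os) {
        os << "void " << name << "()";  // C-style default
      }
    };

    struct WebGPUCodegen : BaseCodegen {
      void PrintFunctionSignature(const std::string& name,
                                  std::ostream& os) override {
        // WGSL-style fixed entry point, mirroring the override above.
        os << "fn main()";
      }
    };

This keeps scoping and body printing in CodeGenC while letting the WebGPU
backend emit its fixed "fn main(...)" entry point.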
diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc
index be5179e081..c75f3008ef 100644
--- a/src/target/source/source_module.cc
+++ b/src/target/source/source_module.cc
@@ -574,12 +574,14 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
       }
 
       for (const tir::Var& pool_var : metadata_->pools) {
+        call_args_ss << "((uint8_t*)";
         String pool_name = metadata_->pool_inputs.value()[pool_var]->pool_info->pool_name;
         if (IsInternalWorkspaceBuffer(pool_var)) {
-          call_args_ss << "&" << pool_name << ",";
+          call_args_ss << "&" << pool_name;
         } else {
-          call_args_ss << "workspace_pools->" << tvm::runtime::SanitizeName(pool_name) << ",";
+          call_args_ss << "workspace_pools->" << tvm::runtime::SanitizeName(pool_name);
         }
+        call_args_ss << "),";
       }
       for (const String& device : metadata_->devices) {
         call_args_ss << "devices->" << device << ",";
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 39214c4546..fd14f48921 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -70,6 +70,32 @@ Type GetType(const PrimExpr& expr) {
       return ptr->type_annotation;
     }
   }
+
+  if (auto* access = expr.as<tir::CallNode>()) {
+    if (access->op.same_as(builtin::tvm_access_ptr())) {
+      ICHECK(access->args.size()) << "Builtin tvm_access_ptr() may not have empty arguments";
+      auto type_annotation = Downcast<Call>(access->args[0]);
+      static auto builtin_op = Op::Get("tir.type_annotation");
+      ICHECK(type_annotation->op.same_as(builtin_op))
+          << "Expected the first argument of builtin tvm_access_ptr() "
+          << "to be a type annotation, but found " << type_annotation->op;
+      return PointerType(PrimType(type_annotation->dtype));
+    }
+  }
+
+  if (auto* address_of = expr.as<tir::CallNode>()) {
+    if (address_of->op.same_as(builtin::address_of())) {
+      ICHECK_EQ(address_of->args.size(), 1)
+          << "Builtin address_of() expects a single argument, but received arguments "
+          << address_of->args;
+      auto* address = address_of->args[0].as<BufferLoadNode>();
+      ICHECK(address)
+          << "Builtin address_of() expects the argument to be a BufferLoad, but received argument "
+          << address_of->args[0];
+
+      return PointerType(PrimType(address->dtype));
+    }
+  }
   // Default: return the type indicated by the dtype.
   runtime::DataType dtype = expr.dtype();
   return GetTypeFromRuntimeDataType(dtype);
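
A hedged sketch of the first new case, constructing the builtin call by
hand (buffer, offset, and extent are illustrative):

    #include <tvm/tir/builtin.h>
    #include <tvm/tir/expr.h>
    #include <tvm/tir/op.h>

    using namespace tvm;
    using namespace tvm::tir;

    Type AccessPtrType(const Var& buffer_var) {
      // The first argument of tvm_access_ptr() is a type-annotation call.
      PrimExpr annotation =
          Call(DataType::Float(32), Op::Get("tir.type_annotation"), {});
      Call access(DataType::Handle(), builtin::tvm_access_ptr(),
                  {annotation, buffer_var, 0, 16, 1});
      // Previously PrimType(DataType::Handle()); with this change,
      // PointerType(PrimType(float32)), so CodeGenC can emit "float*".
      return GetType(access);
    }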
diff --git a/tests/python/relay/aot/test_crt_forward_declarations.py b/tests/python/relay/aot/test_crt_forward_declarations.py
index 17af7a5d68..0c73f18e8a 100644
--- a/tests/python/relay/aot/test_crt_forward_declarations.py
+++ b/tests/python/relay/aot/test_crt_forward_declarations.py
@@ -160,8 +160,8 @@ def test_internal_calls(interface_api, use_unpacked_api, test_runner):
 
     lib_mod = compiled_models[0].executor_factory.lib.imported_modules[0]
     main_source = lib_mod.get_source()
-    assert main_source.count("int32_t tvmgen_default_fused_nn_contrib_depthwise_conv2d_NCHWc") == 1
-    assert main_source.count("int32_t tvmgen_default_fused_layout_transform") == 3
+    assert main_source.count("int32_t tvmgen_default_fused_nn_contrib_depthwise_conv2d_NCHWc") == 2
+    assert main_source.count("int32_t tvmgen_default_fused_layout_transform") == 6
 
 
 @tvm.testing.requires_corstone300
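
The doubled counts follow from each generated operator function now
appearing twice in the main source: once as a forward declaration and once
as a definition. A hedged sketch (the packed-call signature is
illustrative):

    #include <stdint.h>

    // forward declaration, emitted before any definition
    int32_t tvmgen_default_fused_layout_transform(void* args,
                                                  int32_t* type_codes,
                                                  int32_t num_args);

    // definition, emitted afterwards
    int32_t tvmgen_default_fused_layout_transform(void* args,
                                                  int32_t* type_codes,
                                                  int32_t num_args) {
      return 0;  // real body generated by CodeGenCHost
    }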
diff --git a/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py b/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py
index 7bea7577b6..f6145cd1c5 100644
--- a/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py
+++ b/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py
@@ -135,8 +135,13 @@ def test_write_3x3_depthwise_code():
     #define TENSORDOT_OPT_X1_INT16_W48_3X3_000_EXISTS
     #include <arm_acle.h>
     __attribute__((always_inline)) static inline int32_t tensordot_opt_x1_int16_w48_3x3_000(
-        int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale
+        int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
+        int32_t *bias, int32_t *scale
     ) {
+      int32_t *output = output_arg;
+      int32_t *tensor = tensor_arg;
+      int32_t *kernel = kernel_arg;
+
       int32_t sum_0 = *bias;
 
       int32_t tensor__y00_x00__y00_x01 = tensor[0];
@@ -188,8 +193,13 @@ def test_odd_width_3x3_depthwise_strides_code():
     #define TENSORDOT_OPT_X2_INT16_W49_3X3_000_2_4_EXISTS
     #include <arm_acle.h>
     __attribute__((always_inline)) static inline int32_t tensordot_opt_x2_int16_w49_3x3_000_2_4(
-        int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale
+        int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
+        int32_t *bias, int32_t *scale
     ) {
+      int32_t *output = output_arg;
+      int32_t *tensor = tensor_arg;
+      int32_t *kernel = kernel_arg;
+
       int32_t sum_0 = *bias, sum_1 = *bias;
 
       int32_t tensor__y00_x00__y00_x01 = tensor[0];
@@ -251,8 +261,13 @@ def test_1x1x8_convolution_code():
     #define TENSORDOT_OPT_X4_INT16_W384_1X8_000_8_1_EXISTS
     #include <arm_acle.h>
     __attribute__((always_inline)) static inline int32_t tensordot_opt_x4_int16_w384_1x8_000_8_1(
-        int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale
+        int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
+        int32_t *bias, int32_t *scale
     ) {
+      int32_t *output = output_arg;
+      int32_t *tensor = tensor_arg;
+      int32_t *kernel = kernel_arg;
+
       int32_t sum_0 = *bias, sum_1 = *bias, sum_2 = *bias, sum_3 = *bias;
 
       int32_t tensor__y00_x00__y00_x01 = tensor[0];
@@ -349,8 +364,13 @@ def test_3x3x3_offset_convolution_code():
     #define TENSORDOT_OPT_X1_INT16_W288_3X9_111_EXISTS
     #include <arm_acle.h>
     __attribute__((always_inline)) static inline int32_t tensordot_opt_x1_int16_w288_3x9_111(
-        int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale
+        int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
+        int32_t *bias, int32_t *scale
     ) {
+      int32_t *output = output_arg;
+      int32_t *tensor = tensor_arg;
+      int32_t *kernel = kernel_arg;
+
       int32_t sum_0 = *bias;
 
       int32_t tensor__unknown__y00_x00 = tensor[0];
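
The pattern in these updated kernels is visible in the diff itself: the
parameters now carry the caller's int16_t* type and are reinterpreted as
int32_t* inside the body, so a declaration derived from the call site
agrees with the definition. A compilable sketch of that pattern
(simplified to two parameters; the conversion mirrors the lines added
above):

    #include <stdint.h>

    __attribute__((always_inline)) static inline int32_t tensordot(
        int16_t* output_arg, int16_t* tensor_arg
    ) {
      // Reinterpret as 32-bit words, as the SIMD body expects.
      int32_t* output = (int32_t*)output_arg;
      int32_t* tensor = (int32_t*)tensor_arg;
      output[0] = tensor[0];
      return 0;
    }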
diff --git a/tests/python/unittest/test_target_codegen_c_host.py b/tests/python/unittest/test_target_codegen_c_host.py
index d02f8744f1..3aca0fc8c7 100644
--- a/tests/python/unittest/test_target_codegen_c_host.py
+++ b/tests/python/unittest/test_target_codegen_c_host.py
@@ -14,11 +14,15 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 import tvm
 import tvm.testing
+
 from tvm import te
-import numpy as np
 from tvm.contrib import utils
+from tvm.script import tir as T, ir as I
+
+import numpy as np
 
 
 def test_add():
@@ -228,11 +232,39 @@ def test_call_packed():
     check_global_packed_func()
 
 
+def test_subroutine_call():
+    @I.ir_module
+    class mod:
+        @T.prim_func
+        def main(A: T.Buffer(1, dtype="float32")):
+            mod.subroutine(A.data)
+
+        @T.prim_func(private=True)
+        def subroutine(A_data: T.handle("float32")):
+            A = T.decl_buffer(1, dtype="float32", data=A_data)
+            A[0] = 42.0
+
+    built = tvm.build(mod, target="c")
+
+    func_names = list(built["get_func_names"]())
+    assert (
+        "main" in func_names
+    ), "Externally exposed functions should be listed in available functions."
+    assert (
+        "subroutine" not in func_names
+    ), "Internal function should not be listed in available functions."
+
+    source = built.get_source()
+    assert (
+        source.count("main(void*") == 2
+    ), "Expected two occurrences, for forward-declaration and definition"
+    assert (
+        source.count("subroutine(float*") == 2
+    ), "Expected two occurrences, for forward-declaration and definition"
+    assert (
+        source.count("subroutine(") == 3
+    ), "Expected three occurrences, for forward-declaration, definition, and call from main."
+
+
 if __name__ == "__main__":
-    test_add()
-    test_add_pipeline()
-    test_reinterpret()
-    test_ceil()
-    test_floor()
-    test_round()
-    test_call_packed()
+    tvm.testing.main()
diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
index 3543f798c3..b39fca72c8 100644
--- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
@@ -204,6 +204,7 @@ expected_cuda_script = r"""
   #define int64_t long long
   #define uint64_t unsigned long long
 #endif
+extern "C" __global__ void __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C);
 extern "C" __global__ void __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) {
   __shared__ float A_shared[64];
   __shared__ float B_shared[64];