Posted to commits@tvm.apache.org by ju...@apache.org on 2023/09/12 00:33:31 UTC

[tvm] 01/01: Revert "[CodeGenC] Handle GlobalVar callee as internal function call (#15103)"

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch revert-15103-codegen_c_support_subroutine_calls
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 2902c10cc700671cbb538c679b7a375d20596cee
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Mon Sep 11 17:33:25 2023 -0700

    Revert "[CodeGenC] Handle GlobalVar callee as internal function call (#15103)"
    
    This reverts commit 9ff71f4a9fed3ec9f82b999fb8c25ef1bf6e243c.
---
 .../arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py  |   8 +-
 .../topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py |  87 +++---------
 .../arm_cpu/mprofile/dsp/micro_kernel/max_pool.py  |  13 +-
 .../arm_cpu/mprofile/dsp/micro_kernel/tensordot.py |   7 +-
 .../backend/contrib/cmsisnn/tir_to_runtime.cc      |  28 ++--
 .../contrib/example_target_hooks/tir_to_runtime.cc |  26 +---
 src/relay/backend/contrib/uma/tir_to_runtime.cc    |  34 ++---
 src/target/opt/build_cuda_on.cc                    |  18 +--
 src/target/source/codegen_aocl.cc                  |  19 +--
 src/target/source/codegen_c.cc                     | 153 +++++++--------------
 src/target/source/codegen_c.h                      |  59 +-------
 src/target/source/codegen_c_host.cc                |  93 +++++++------
 src/target/source/codegen_c_host.h                 |   3 +-
 src/target/source/codegen_cuda.cc                  |   4 +-
 src/target/source/codegen_cuda.h                   |   2 +-
 src/target/source/codegen_metal.cc                 |  77 ++++++-----
 src/target/source/codegen_metal.h                  |   3 +-
 src/target/source/codegen_opencl.cc                |  24 ++--
 src/target/source/codegen_vhls.cc                  |  34 ++---
 src/target/source/codegen_webgpu.cc                |  79 ++++++-----
 src/target/source/codegen_webgpu.h                 |   4 +-
 src/target/source/source_module.cc                 |   6 +-
 src/tir/op/op.cc                                   |  26 ----
 .../relay/aot/test_crt_forward_declarations.py     |   4 +-
 .../topi/python/test_topi_conv2d_tensordot_opts.py |  28 +---
 .../python/unittest/test_target_codegen_c_host.py  |  48 ++-----
 .../test_tir_transform_inject_ptx_async_copy.py    |   1 -
 27 files changed, 297 insertions(+), 591 deletions(-)
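
The high-level effect of this revert, reconstructed from the codegen_c.h hunks below: PR #15103 had split C-source codegen into a two-phase, GlobalVar-keyed interface (declare every function first, then emit bodies, with GetFunctionName resolving callees), and this revert restores the single-phase interface whose function name comes from the kGlobalSymbol attribute. A minimal sketch of the two shapes; the struct names here are illustrative, while the member signatures are copied from the hunks:

    // Sketch of the interface difference (illustrative struct names; TVM types
    // come from the listed headers).
    #include <tvm/ir/expr.h>        // tvm::GlobalVar
    #include <tvm/tir/function.h>   // tvm::tir::PrimFunc

    // Removed by this revert (added in #15103): two-phase, keyed by GlobalVar,
    // so functions within one IRModule can call each other by resolved name.
    struct Interface15103 {
      virtual void DeclareFunction(const tvm::GlobalVar& gvar, const tvm::tir::PrimFunc& func) = 0;
      virtual void AddFunction(const tvm::GlobalVar& gvar, const tvm::tir::PrimFunc& func) = 0;
      virtual tvm::String GetFunctionName(const tvm::GlobalVar& gvar) = 0;
      virtual ~Interface15103() = default;
    };

    // Restored by this revert: single-phase; the emitted name is taken from the
    // mandatory kGlobalSymbol attribute of each PrimFunc.
    struct InterfaceRestored {
      virtual void AddFunction(const tvm::tir::PrimFunc& f) = 0;
      virtual ~InterfaceRestored() = default;
    };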

diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py
index 3eb32d8fdb..e8e45152aa 100644
--- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py
+++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py
@@ -55,7 +55,7 @@ def intrin_sum(shape, in_dtype, out_dtype, reset=False):
             ib = tvm.tir.ir_builder.create()
             ib.emit(
                 tvm.tir.call_extern(
-                    "int32",
+                    cc.dtype,
                     f"{func_prefix}_{width}_{uniq_id}",
                     aa.access_ptr("r"),
                     cc.access_ptr("w"),
@@ -68,7 +68,7 @@ def intrin_sum(shape, in_dtype, out_dtype, reset=False):
         def _reduce_reset():
             ib = tvm.tir.ir_builder.create()
             ib.emit(
-                tvm.tir.call_extern("int32", f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"))
+                tvm.tir.call_extern(cc.dtype, f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"))
             )
             return ib.get()
 
@@ -113,8 +113,8 @@ extern "C"
 __attribute__((always_inline)) static inline int32_t sum16_{N}_{uniq_id}(
     int16_t *arr,
     int16_t *res16,
-    int32_t arr_offset,
-    int32_t reset) {{
+    long arr_offset,
+    int reset) {{
   int n;
   int32_t *p32;
   int32_t res = reset ? 0 : *res16;
diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py
index e26e818fbd..929dcc6557 100644
--- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py
+++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py
@@ -156,14 +156,9 @@ __attribute__((always_inline)) static inline const int8_t *read_and_pad(const in
 extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t gemm_{M}x{N}_body_rest_{uniq_id}(
-    int32_t K_arg,
+    int K,
     int8_t *aa, int8_t *bb, int32_t *cc,
-    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
-  int K = K_arg;
-  int A_stride = A_stride_arg;
-  int B_stride = B_stride_arg;
-  int C_stride = C_stride_arg;
-
+    int A_stride, int B_stride, int C_stride) {{
   int k_base = (K / 4) * 4;
   switch ( K % 4 ) {{
   case 1:
@@ -205,12 +200,7 @@ extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_body_loop_{uniq_id}(
     int8_t *aa, int8_t *bb, int32_t *cc,
-    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
-  int A_stride = A_stride_arg;
-  int B_stride = B_stride_arg;
-  int C_stride = C_stride_arg;
-
-
+    int A_stride, int B_stride, int C_stride) {{
   for (int i = 0; i < {M}; i++) {{
     for (int j = 0; j < {N}; j++) {{
       int32_t sum = 0;
@@ -231,11 +221,7 @@ extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_body_{uniq_id}(
     int8_t *aa, int8_t *bb, int32_t *cc,
-    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
-  int A_stride = A_stride_arg;
-  int B_stride = B_stride_arg;
-  int C_stride = C_stride_arg;
-
+    int A_stride, int B_stride, int C_stride) {{
   int16_t bb_pad[{bb_pad_size}];
   int32_t retcode = 0;
 
@@ -279,14 +265,9 @@ out:
 extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t gemm_{M}x{N}_update_rest_{uniq_id}(
-    int32_t K_arg,
+    int K,
     int8_t *aa, int8_t *bb, int32_t *cc,
-    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
-  int K = K_arg;
-  int A_stride = A_stride_arg;
-  int B_stride = B_stride_arg;
-  int C_stride = C_stride_arg;
-
+    int A_stride, int B_stride, int C_stride) {{
   int k_base = (K / 4) * 4;
   switch ( K % 4 ) {{
   case 1:
@@ -328,11 +309,7 @@ extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_update_loop_{uniq_id}(
     int8_t *aa, int8_t *bb, int32_t *cc,
-    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
-  int A_stride = A_stride_arg;
-  int B_stride = B_stride_arg;
-  int C_stride = C_stride_arg;
-
+    int A_stride, int B_stride, int C_stride) {{
   for (int i = 0; i < {M}; i++) {{
     for (int j = 0; j < {N}; j++) {{
       int32_t sum = 0;
@@ -350,11 +327,7 @@ extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_update_{uniq_id}(
     int8_t *aa, int8_t *bb, int32_t *cc,
-    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
-  int A_stride = A_stride_arg;
-  int B_stride = B_stride_arg;
-  int C_stride = C_stride_arg;
-
+    int A_stride, int B_stride, int C_stride) {{
   int16_t bb_pad[{bb_pad_size}];
   int32_t retcode = 0;
 
@@ -395,14 +368,9 @@ out:
 extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t gemm16_{M}x{N}_body_rest_{uniq_id}(
-    int32_t K_arg,
+    int K,
     int16_t *aa, int16_t *bb, int32_t *cc,
-    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
-  int K = K_arg;
-  int A_stride = A_stride_arg;
-  int B_stride = B_stride_arg;
-  int C_stride = C_stride_arg;
-
+    int A_stride, int B_stride, int C_stride) {{
   int k_base = (K / 2) * 2;
   for (int i = 0; i < {M}; i++) {{
     for (int j = 0; j < {N}; j++) {{
@@ -419,11 +387,7 @@ extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_body_loop_{uniq_id}(
     int16_t *aa, int16_t *bb, int32_t *cc,
-    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
-  int A_stride = A_stride_arg;
-  int B_stride = B_stride_arg;
-  int C_stride = C_stride_arg;
-
+    int A_stride, int B_stride, int C_stride) {{
   for (int i = 0; i < {M}; i++) {{
     for (int j = 0; j < {N}; j++) {{
       int32_t sum = 0;
@@ -444,11 +408,7 @@ extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_body_{uniq_id}(
     int16_t *aa, int16_t *bb, int32_t *cc,
-    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
-  int A_stride = A_stride_arg;
-  int B_stride = B_stride_arg;
-  int C_stride = C_stride_arg;
-
+    int A_stride, int B_stride, int C_stride) {{
   int32_t retcode = 0;
 
   if ( {M} < 2 && {N} < 2 ) {{
@@ -490,14 +450,9 @@ out:
 extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t gemm16_{M}x{N}_update_rest_{uniq_id}(
-    int32_t K_arg,
+    int K,
     int16_t *aa, int16_t *bb, int32_t *cc,
-    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
-  int K = K_arg;
-  int A_stride = A_stride_arg;
-  int B_stride = B_stride_arg;
-  int C_stride = C_stride_arg;
-
+    int A_stride, int B_stride, int C_stride) {{
   int k_base = (K / 2) * 2;
   for (int i = 0; i < {M}; i++) {{
     for (int j = 0; j < {N}; j++) {{
@@ -514,11 +469,7 @@ extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_update_loop_{uniq_id}(
     int16_t *aa, int16_t *bb, int32_t *cc,
-    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
-  int A_stride = A_stride_arg;
-  int B_stride = B_stride_arg;
-  int C_stride = C_stride_arg;
-
+    int A_stride, int B_stride, int C_stride) {{
   for (int i = 0; i < {M}; i++) {{
     for (int j = 0; j < {N}; j++) {{
       int32_t sum = 0;
@@ -536,11 +487,7 @@ extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_update_{uniq_id}(
     int16_t *aa, int16_t *bb, int32_t *cc,
-    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
-  int A_stride = A_stride_arg;
-  int B_stride = B_stride_arg;
-  int C_stride = C_stride_arg;
-
+    int A_stride, int B_stride, int C_stride) {{
   int32_t retcode = 0;
 
   if ( {M} < 2 && {N} < 2 ) {{
@@ -573,7 +520,7 @@ out:
 #ifdef __cplusplus
 extern "C"
 #endif
-__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int32_t C_stride) {{
+__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int C_stride) {{
   for (int i = 0; i < {M}; i++) {{
     for (int j = 0; j < {N}; j++) {{
       cc[i*C_stride + j] = 0;
diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py
index cfed417c9f..66d712a4a0 100644
--- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py
+++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py
@@ -46,7 +46,7 @@ def intrin_max(shape, in_dtype, out_dtype):
             ib = tvm.tir.ir_builder.create()
             ib.emit(
                 tvm.tir.call_extern(
-                    "int32",
+                    cc.dtype,
                     f"{func_prefix}_{uniq_id}",
                     aa.access_ptr("r"),
                     cc.access_ptr("w"),
@@ -59,7 +59,7 @@ def intrin_max(shape, in_dtype, out_dtype):
             ib = tvm.tir.ir_builder.create()
             ib.emit(
                 tvm.tir.call_extern(
-                    "int32", f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"), cc.strides[0]
+                    cc.dtype, f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"), cc.strides[0]
                 )
             )
             return ib.get()
@@ -96,7 +96,7 @@ extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t max8_reset_{uniq_id}(
     int8_t *res,
-    int32_t N) {{
+    int N) {{
   memset(res, (int8_t)-128, N * sizeof(*res));
   return 0;
 }}
@@ -107,9 +107,7 @@ extern "C"
 __attribute__((always_inline)) static inline int32_t max8_loop_{uniq_id}(
     int8_t *arg,
     int8_t *res,
-    int32_t N_arg) {{
-  int N = N_arg;
-
+    int N) {{
   for ( int i = 0; i < N; ++ i )
     if ( arg[i] > res[i] )
       res[i] = arg[i];
@@ -122,8 +120,7 @@ extern "C"
 __attribute__((always_inline)) static inline int32_t max8_{uniq_id}(
     int8_t *arg,
     int8_t *res,
-    int32_t N_arg) {{
-  int N = N_arg;
+    int N) {{
   int32_t *parg32, *pres32;
   int una_arg = (int32_t)arg & 0x3, una_res = (int32_t)res & 0x3;
   int32_t retcode = 0;
diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py
index af3b23e01d..d2a8f1ef69 100644
--- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py
+++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py
@@ -390,13 +390,8 @@ def tensordot_int16_impl(
         #define {function_name.upper()}_EXISTS
         #include <arm_acle.h>
         __attribute__((always_inline)) static inline int32_t {function_name}(
-            int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
-            int32_t *bias, int32_t *scale
+            int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale
         ) {{
-          int32_t *output = output_arg;
-          int32_t *tensor = tensor_arg;
-          int32_t *kernel = kernel_arg;
-
           {_init_biased_accumulators(num_outputs)}
 
           {insert_lines(load_tensor_lines)}
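
The four micro-kernel files above (avg_pool.py, gemm.py, max_pool.py, tensordot.py) all revert one mechanical pattern: #15103 had widened the generated kernels' scalar parameters to fixed-width int32_t, renamed them *_arg, and copied them into plain-int locals so that the emitted signatures matched the int32 argument types at the TIR call sites; the revert restores the original narrow parameters. A condensed, hypothetical kernel (not one from the diff, modeled on the max8_loop hunk) showing both shapes:

    #include <stdint.h>

    /* Shape introduced by #15103, now removed: an int32_t parameter with a
       local copy back to the natural C type. */
    static inline int32_t max_loop_15103(int8_t *arg, int8_t *res, int32_t N_arg) {
      int N = N_arg;
      for (int i = 0; i < N; ++i)
        if (arg[i] > res[i]) res[i] = arg[i];
      return 0;
    }

    /* Shape restored by this revert: a plain int parameter, used directly. */
    static inline int32_t max_loop_restored(int8_t *arg, int8_t *res, int N) {
      for (int i = 0; i < N; ++i)
        if (arg[i] > res[i]) res[i] = arg[i];
      return 0;
    }
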
diff --git a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
index 6febfe3486..186fa30f20 100644
--- a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
+++ b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
@@ -46,6 +46,13 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
     CodeGenCHost::Init(output_ssa, emit_asserts, emit_fwd_func_decl, target_str, devices);
   }
 
+  /*!
+   * \brief Emit code that offloads a subgraph to the Cortex-M
+   *
+   * The generated code is accumulated internally and returned by Finish().
+   */
+  void AddFunction(const PrimFunc& prim_func) { CodeGenC::AddFunction(prim_func); }
+
  private:
   /*! \brief Enable storing the last error */
   bool debug_last_error;
@@ -568,11 +575,11 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
   bool emit_fwd_func_decl = false;
   bool debug_last_error = GetCompilerAttrs()->debug_last_error;
   CodeGenCMSISNN codegen;
+  Array<String> function_names;
   codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), debug_last_error);
-
-  std::vector<std::pair<tvm::GlobalVar, tvm::PrimFunc>> funcs;
-  for (auto [gvar, base_func] : mod->functions) {
-    funcs.push_back({gvar, Downcast<PrimFunc>(base_func)});
+  std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> funcs;
+  for (auto kv : mod->functions) {
+    funcs.push_back(kv);
   }
 
   std::sort(funcs.begin(), funcs.end(),
@@ -587,16 +594,13 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
               return name_hint_a < name_hint_b;
             });
 
-  for (auto [gvar, prim_func] : funcs) {
-    codegen.AddFunction(gvar, prim_func);
+  for (auto kv : funcs) {
+    auto prim_func = Downcast<PrimFunc>(kv.second);
+    auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+    function_names.push_back(global_symbol.value());
+    codegen.AddFunction(prim_func);
   }
   std::string code = codegen.Finish();
-
-  Array<String> function_names;
-  for (auto [gvar, prim_func] : funcs) {
-    function_names.push_back(codegen.GetFunctionName(gvar));
-  }
-
   return codegen::CSourceModuleCreate(code, "c", function_names);
 }
 
diff --git a/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc b/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc
index 6f09e0a0c3..0db8d06c31 100644
--- a/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc
+++ b/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc
@@ -49,30 +49,16 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
   bool emit_asserts = false;
   bool emit_fwd_func_decl = false;
   CodeGenExampleTargetHook codegen;
-
+  Array<String> function_names;
   std::unordered_set<std::string> devices;
   codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices);
-
-  Map<GlobalVar, PrimFunc> functions;
-  for (auto [gvar, base_func] : mod->functions) {
-    auto prim_func = Downcast<PrimFunc>(base_func);
-    functions.Set(gvar, prim_func);
-  }
-
-  for (auto [gvar, prim_func] : functions) {
-    codegen.DeclareFunction(gvar, prim_func);
-  }
-  for (auto [gvar, prim_func] : functions) {
-    codegen.AddFunction(gvar, prim_func, emit_fwd_func_decl);
+  for (auto kv : mod->functions) {
+    auto prim_func = Downcast<PrimFunc>(kv.second);
+    auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+    function_names.push_back(global_symbol.value());
+    codegen.AddFunction(prim_func);
   }
-
   std::string code = codegen.Finish();
-
-  Array<String> function_names;
-  for (auto [gvar, prim_func] : functions) {
-    function_names.push_back(codegen.GetFunctionName(gvar));
-  }
-
   return codegen::CSourceModuleCreate(code, "c", function_names);
 }
 
diff --git a/src/relay/backend/contrib/uma/tir_to_runtime.cc b/src/relay/backend/contrib/uma/tir_to_runtime.cc
index 487e247f5d..3b58fda54b 100644
--- a/src/relay/backend/contrib/uma/tir_to_runtime.cc
+++ b/src/relay/backend/contrib/uma/tir_to_runtime.cc
@@ -49,6 +49,13 @@ class UMACodegen : public codegen::CodeGenCHost {
     CodeGenCHost::Init(output_ssa, emit_asserts, emit_fwd_func_decl, target_str_, devices);
   }
 
+  /*!
+   * \brief Emit code that offloads a subgraph to the UMA target
+   *
+   * The generated code is accumulated internally and returned by Finish().
+   */
+  void AddFunction(const PrimFunc& prim_func) { CodeGenC::AddFunction(prim_func); }
+
  private:
   String target_str_;
 };
@@ -56,30 +63,17 @@ class UMACodegen : public codegen::CodeGenCHost {
 runtime::Module TIRToRuntime(IRModule mod, Target target) {
   bool output_ssa = false;
   bool emit_asserts = false;
-  bool emit_fwd_func_decl = true;
+  bool emit_fwd_func_decl = false;
   UMACodegen codegen(target->kind->name);
+  Array<String> function_names;
   codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl);
-
-  Map<GlobalVar, PrimFunc> functions;
-  for (auto [gvar, base_func] : mod->functions) {
-    auto prim_func = Downcast<PrimFunc>(base_func);
-    functions.Set(gvar, prim_func);
-  }
-
-  for (auto [gvar, prim_func] : functions) {
-    codegen.DeclareFunction(gvar, prim_func);
-  }
-  for (auto [gvar, prim_func] : functions) {
-    codegen.AddFunction(gvar, prim_func, emit_fwd_func_decl);
+  for (auto kv : mod->functions) {
+    auto prim_func = Downcast<PrimFunc>(kv.second);
+    auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+    function_names.push_back(global_symbol.value());
+    codegen.AddFunction(prim_func);
   }
-
   std::string code = codegen.Finish();
-
-  Array<String> function_names;
-  for (auto [gvar, prim_func] : functions) {
-    function_names.push_back(codegen.GetFunctionName(gvar));
-  }
-
   return codegen::CSourceModuleCreate(code, "c", function_names);
 }
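
The three contrib backends touched so far (cmsisnn, example_target_hooks, uma) end up with the same restored TIRToRuntime shape: iterate the module, read each PrimFunc's kGlobalSymbol, record it, and emit the function, instead of declaring everything up front and asking the codegen for its chosen names. A condensed sketch of that restored shape, assembled from the '+' lines of the three hunks above (a fragment assuming the source files' enclosing namespaces; sorting and error checks elided, and 'codegen' stands for the backend-specific CodeGenCHost subclass):

    Array<String> function_names;
    for (auto kv : mod->functions) {
      auto prim_func = Downcast<PrimFunc>(kv.second);
      auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
      function_names.push_back(global_symbol.value());  // name comes from the attribute
      codegen.AddFunction(prim_func);                   // single-phase emission
    }
    std::string code = codegen.Finish();
    return codegen::CSourceModuleCreate(code, "c", function_names);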
 
diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc
index e0f53e3509..1c0b5094ef 100644
--- a/src/target/opt/build_cuda_on.cc
+++ b/src/target/opt/build_cuda_on.cc
@@ -131,21 +131,13 @@ runtime::Module BuildCUDA(IRModule mod, Target target) {
   CodeGenCUDA cg;
   cg.Init(output_ssa);
 
-  Map<GlobalVar, PrimFunc> functions;
-  for (auto [gvar, base_func] : mod->functions) {
-    ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc";
-    auto prim_func = Downcast<PrimFunc>(base_func);
-    auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
+  for (auto kv : mod->functions) {
+    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc";
+    auto f = Downcast<PrimFunc>(kv.second);
+    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
     ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
         << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
-    functions.Set(gvar, prim_func);
-  }
-
-  for (auto [gvar, prim_func] : functions) {
-    cg.DeclareFunction(gvar, prim_func);
-  }
-  for (auto [gvar, prim_func] : functions) {
-    cg.AddFunction(gvar, prim_func);
+    cg.AddFunction(f);
   }
 
   std::string code = cg.Finish();
diff --git a/src/target/source/codegen_aocl.cc b/src/target/source/codegen_aocl.cc
index dc3ba08751..700d85b4cc 100644
--- a/src/target/source/codegen_aocl.cc
+++ b/src/target/source/codegen_aocl.cc
@@ -40,22 +40,13 @@ runtime::Module BuildAOCL(IRModule mod, Target target, bool emulation) {
   CodeGenOpenCL cg;
   cg.Init(output_ssa);
 
-  Map<GlobalVar, PrimFunc> functions;
-  for (auto [gvar, base_func] : mod->functions) {
-    ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodegenOpenCL: Can only take PrimFunc";
-    auto prim_func = Downcast<PrimFunc>(base_func);
-    auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
+  for (auto kv : mod->functions) {
+    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodegenOpenCL: Can only take PrimFunc";
+    auto f = Downcast<PrimFunc>(kv.second);
+    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
     ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
         << "CodegenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
-    functions.Set(gvar, prim_func);
-  }
-
-  for (auto [gvar, prim_func] : functions) {
-    cg.DeclareFunction(gvar, prim_func);
-  }
-
-  for (auto [gvar, prim_func] : functions) {
-    cg.AddFunction(gvar, prim_func);
+    cg.AddFunction(f);
   }
 
   std::string code = cg.Finish();
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 187bdc74fe..a7cc320562 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -42,7 +42,6 @@ void CodeGenC::InitFuncState(const PrimFunc& f) {
   alloc_storage_scope_.clear();
   handle_data_type_.clear();
   CodeGenSourceBase::ClearFuncState();
-  ReserveKeywordsAsUnique();
 }
 
 void CodeGenC::ReserveKeywordsAsUnique() {
@@ -76,92 +75,51 @@ void CodeGenC::ReserveKeywordsAsUnique() {
   name_supply_->ReserveName("return");
 }
 
-void CodeGenC::PrintFunctionSignature(const String& function_name, const PrimFunc& func,
-                                      std::ostream& os) {
-  PrintFuncPrefix(os);
-  PrintType(func->ret_type, os);
-  PrintExtraAttrs(func, os);
-  os << " " << function_name << "(";
-  for (size_t i = 0; i < func->params.size(); ++i) {
-    tir::Var v = func->params[i];
-
-    if (i > 0) {
-      os << ", ";
-    }
-
-    if (auto it = alloc_storage_scope_.find(v.get()); it != alloc_storage_scope_.end()) {
-      PrintStorageScope(it->second, os);
-    }
-
-    PrintType(GetType(v), os);
-
-    bool no_alias = func->HasNonzeroAttr(tir::attr::kNoAlias);
-    bool is_handle = v.dtype().is_handle();
-    if (no_alias && is_handle) {
-      PrintRestrict(v, os);
-    }
-
-    os << " " << AllocVarID(v.get());
-  }
-  os << ")";
+void CodeGenC::AddFunction(const PrimFunc& f) {
+  // clear previous generated state.
+  this->InitFuncState(f);
+  // reserve keywords
+  ReserveKeywordsAsUnique();
 
-  // Register handle data type
-  // TODO(tvm-team): consider simply keep type info in the
-  // type annotation(via a normalizing rewriting).
-  for (const auto& param : func->params) {
-    if (auto* ptr = param->type_annotation.as<PointerTypeNode>()) {
-      if (auto* prim = ptr->element_type.as<PrimTypeNode>()) {
-        RegisterHandleType(param.get(), prim->dtype);
+  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
+  ICHECK(global_symbol.defined())
+      << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
+  bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
+
+  this->PrintFuncPrefix(stream);
+  PrintType(f->ret_type, stream);
+  this->PrintExtraAttrs(f);
+  this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";
+
+  for (size_t i = 0; i < f->params.size(); ++i) {
+    tir::Var v = f->params[i];
+    std::string vid = AllocVarID(v.get());
+    if (i != 0) stream << ", ";
+    if (v.dtype().is_handle()) {
+      auto it = alloc_storage_scope_.find(v.get());
+      if (it != alloc_storage_scope_.end()) {
+        PrintStorageScope(it->second, stream);
       }
-    }
-  }
-}
 
-void CodeGenC::DeclareFunction(const GlobalVar& gvar, const PrimFunc& func) {
-  if (internal_functions_.count(gvar)) {
-    return;
-  }
+      PrintType(GetType(v), stream);
+      // Register handle data type
+      // TODO(tvm-team): consider simply keep type info in the
+      // type annotation(via a normalizing rewriting).
+      if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) {
+        if (auto* prim = ptr->element_type.as<PrimTypeNode>()) {
+          RegisterHandleType(v.get(), prim->dtype);
+        }
+      }
 
-  auto function_name = [&]() -> String {
-    if (auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
-      auto name = global_symbol.value();
-      ICHECK(!func_name_supply_->ContainsName(name))
-          << "Function " << gvar << " must use global symbol " << name
-          << ", but this name has already been used.";
-      func_name_supply_->ReserveName(name);
-      return name;
+      if (no_alias) {
+        PrintRestrict(v, stream);
+      }
     } else {
-      func_name_supply_->ReserveName(gvar->name_hint);
-      return gvar->name_hint;
+      PrintType(GetType(v), stream);
     }
-  }();
-
-  internal_functions_.insert({gvar, function_name});
-
-  InitFuncState(func);
-  PrintFunctionSignature(function_name, func, fwd_decl_stream);
-  fwd_decl_stream << ";\n";
-}
-
-String CodeGenC::GetFunctionName(const GlobalVar& gvar) {
-  auto it = internal_functions_.find(gvar);
-  ICHECK(it != internal_functions_.end())
-      << "Attempted to find name of " << gvar
-      << ", but no function with this GlobalVar has been declared";
-  return it->second;
-}
-
-void CodeGenC::AddFunction(const GlobalVar& gvar, const PrimFunc& f) {
-  // If the function has already been forward-declared, this is a
-  // no-op.
-  DeclareFunction(gvar, f);
-  auto function_name = GetFunctionName(gvar);
-
-  // clear previous generated state.
-  InitFuncState(f);
-
-  PrintFunctionSignature(function_name, f, stream);
-  stream << " {\n";
+    stream << ' ' << vid;
+  }
+  stream << ") {\n";
   this->PreFunctionBody(f);
   int func_scope = this->BeginScope();
   this->PrintStmt(f->body);
@@ -172,15 +130,9 @@ void CodeGenC::AddFunction(const GlobalVar& gvar, const PrimFunc& f) {
 
 void CodeGenC::PrintFuncPrefix(std::ostream& os) {}
 
-void CodeGenC::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) {}
+void CodeGenC::PrintExtraAttrs(const PrimFunc& f) {}
 
-std::string CodeGenC::Finish() {
-  std::ostringstream code;
-  code << decl_stream.str();
-  code << fwd_decl_stream.str();
-  code << stream.str();
-  return code.str();
-}
+std::string CodeGenC::Finish() { return decl_stream.str() + stream.str(); }
 
 void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) {  // NOLINT(*)
   if (print_ssa_form_) {
@@ -590,17 +542,12 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT(*)
       ICHECK_GE(op->args.size(), 1U);
       auto func = Downcast<StringImm>(op->args[0]);
       this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), func->value, op->args, true, os);
-
-      // If the call_extern refers to an function within the IRModule, then
-      // the forward declaration is already provided from DeclareFunction.
-      if (!func_name_supply_->ContainsName(func->value)) {
-        Array<Type> arg_types;
-        for (size_t i = 1; i < op->args.size(); i++) {
-          arg_types.push_back(GetType(op->args[i]));
-        }
-        Type ret_type = GetTypeFromRuntimeDataType(op->dtype);
-        this->GenerateForwardFunctionDeclarations(func->value, arg_types, ret_type);
+      Array<Type> arg_types;
+      for (size_t i = 1; i < op->args.size(); i++) {
+        arg_types.push_back(GetType(op->args[i]));
       }
+      Type ret_type = GetTypeFromRuntimeDataType(op->dtype);
+      this->GenerateForwardFunctionDeclarations(func->value, arg_types, ret_type);
     } else if (op_attr_global_symbol_.count(call_op)) {
       // call extern if the op itself have a global symbol.
       this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_attr_global_symbol_[call_op],
@@ -668,13 +615,9 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT(*)
     } else {
       LOG(FATAL) << "Unresolved call " << op->op;
     }
-  } else if (auto opt = op->op.as<GlobalVar>()) {
-    auto gvar = opt.value();
-    auto callee_name = GetFunctionName(gvar);
-    PrintCallExtern(GetType(GetRef<PrimExpr>(op)), callee_name, op->args, false, os);
   } else {
-    LOG(FATAL) << "CodeGenC: Unknown operation " << op->op << " is neither a recognized built-in, "
-               << "nor a GlobalVar reference to another function in the IRModule";
+    ICHECK(op->op.as<GlobalVarNode>());
+    LOG(FATAL) << "Do not yet support cross function call";
   }
 }
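
The hunk above carries the core behavioral change of the revert. With #15103, a CallNode whose op was a GlobalVar was resolved through GetFunctionName and printed as an ordinary C call; after the revert, such calls are rejected outright. A condensed sketch of the two behaviors as assumed helper functions (not in the diff), wrapping the logic shown in the hunk and assuming the CodeGenC member context and TVM logging macros:

    // #15103 behavior (removed): a GlobalVar callee becomes a direct call,
    // using the symbol chosen when the callee was declared.
    void EmitGlobalVarCall_15103(const CallNode* op, std::ostream& os) {
      auto gvar = Downcast<GlobalVar>(op->op);
      auto callee_name = GetFunctionName(gvar);
      PrintCallExtern(GetType(GetRef<PrimExpr>(op)), callee_name, op->args,
                      /*skip_first_arg=*/false, os);
    }

    // Restored behavior: GlobalVar callees are a hard error.
    void EmitGlobalVarCall_Reverted(const CallNode* op, std::ostream& os) {
      ICHECK(op->op.as<GlobalVarNode>());
      LOG(FATAL) << "Do not yet support cross function call";
    }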
 
diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h
index 2921a56ef3..93f9ea519c 100644
--- a/src/target/source/codegen_c.h
+++ b/src/target/source/codegen_c.h
@@ -65,33 +65,12 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
    * \param output_ssa Whether output SSA.
    */
   void Init(bool output_ssa);
-
   /*!
-   * \brief Add the function declaration to the generated module,
-   * without defining it.
-   *
-   * \param gvar The GlobalVar representing the function.
-   * \param func The function to be compiled.
+   * \brief Add the function to the generated module.
+   * \param f The function to be compiled.
    * \param whether to append return 0 in the end.
    */
-  virtual void DeclareFunction(const GlobalVar& gvar, const PrimFunc& func);
-
-  /*!
-   * \brief Add the function to the generated module, including its
-   * declaration and definition.
-   *
-   * \param gvar The GlobalVar representing the function.
-   * \param func The function to be compiled.
-   */
-  virtual void AddFunction(const GlobalVar& gvar, const PrimFunc& func);
-
-  /*!
-   * \brief Get the name of a declared function
-   * \param gvar The GlobalVar of the function
-   * \returns The string name of the function
-   */
-  String GetFunctionName(const GlobalVar& gvar);
-
+  void AddFunction(const PrimFunc& f);
   /*!
    * \brief Finalize the compilation and return the code.
    * \return The code.
@@ -117,23 +96,7 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
     PrintExpr(n, os);
     return os.str();
   }
-
   // The following parts are overloadable print operations.
-
-  /*! \brief Print the function signature before the argument list
-   *
-   * The default implementation delegates out to PrintFuncPrefix and
-   * PrintExtraAttrs.
-   *
-   * \param function_name The name of the function
-   *
-   * \param func The function whose signature should be printed
-   *
-   * \param os The output stream
-   */
-  virtual void PrintFunctionSignature(const String& function_name, const PrimFunc& func,
-                                      std::ostream& os);
-
   /*!
    * \brief Print the function header before the argument list
    * \param os The output stream
@@ -146,7 +109,7 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
    *
    *  Example: __launch_bounds__(256) for CUDA functions
    */
-  virtual void PrintExtraAttrs(const PrimFunc& f, std::ostream& os);  // NOLINT(*)
+  virtual void PrintExtraAttrs(const PrimFunc& f);
   /*!
    * \brief Insert statement before function body.
    * \param f The function to be compiled.
@@ -321,24 +284,10 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
  private:
   /*! \brief set of volatile buf access */
   std::unordered_set<const VarNode*> volatile_buf_;
-
   // deep comparison of PrimExpr
   ExprDeepEqual deep_equal_;
-
   // binding of let variables. Enables duplicate var defs that map to same value
   std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
-
-  /* \brief Map of GlobalVar to their symbol.
-   *
-   * For externally-exposed functions, this is given by the
-   * tvm::attr::kTarget attribute of the PrimFunc.  For internal
-   * functions, this is the name of the function's GlobalVar, possibly
-   * altered to prevent duplicate names.
-   */
-  std::unordered_map<GlobalVar, String, ObjectPtrHash, ObjectPtrEqual> internal_functions_;
-
-  /* \brief Name supply to generate unique function names */
-  NameSupply func_name_supply_{""};
 };
 
 }  // namespace codegen
diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc
index caef43e8af..3255e11c5d 100644
--- a/src/target/source/codegen_c_host.cc
+++ b/src/target/source/codegen_c_host.cc
@@ -75,24 +75,19 @@ void CodeGenCHost::InitGlobalContext() {
 
 void CodeGenCHost::DefineModuleName() { decl_stream << "void* " << module_name_ << " = NULL;\n"; }
 
-void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func,
-                               bool emit_fwd_func_decl) {
-  auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
-  if (global_symbol) {
-    function_names_.push_back(global_symbol.value());
-  }
+void CodeGenCHost::AddFunction(const PrimFunc& f, bool emit_fwd_func_decl) {
+  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
+  ICHECK(global_symbol.defined())
+      << "CodeGenCHost: Expect PrimFunc to have the global_symbol attribute";
+  function_names_.push_back(global_symbol.value());
 
   emit_fwd_func_decl_ = emit_fwd_func_decl;
-  CodeGenC::AddFunction(gvar, func);
-  if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
-    ICHECK(global_symbol.defined())
-        << "CodeGenCHost: The entry func must have the global_symbol attribute, "
-        << "but function " << gvar << " only has attributes " << func->attrs;
-
+  CodeGenC::AddFunction(f);
+  if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
     function_names_.push_back(runtime::symbol::tvm_module_main);
     stream << "// CodegenC: NOTE: Auto-generated entry function\n";
     PrintFuncPrefix(stream);
-    PrintType(func->ret_type, stream);
+    PrintType(f->ret_type, stream);
     stream << " " << tvm::runtime::symbol::tvm_module_main
            << "(void* args, int* arg_type_ids, int num_args, void* out_ret_value, "
            << "int* out_ret_tcode, void* resource_handle) {\n";
@@ -133,6 +128,15 @@ void CodeGenCHost::PrintFuncPrefix(std::ostream& os) {  // NOLINT(*)
      << "TVM_DLL ";
 }
 
+std::string CodeGenCHost::Finish() {  // NOLINT(*)
+  std::string ret = decl_stream.str();
+  if (emit_fwd_func_decl_) {
+    ret += fwd_decl_stream.str();
+  }
+  ret += stream.str();
+  return ret;
+}
+
 void CodeGenCHost::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
   int lanes = t.lanes();
   if (t.is_handle()) {
@@ -433,38 +437,42 @@ runtime::Module BuildCHost(IRModule mod, Target target) {
   CodeGenCHost cg;
   cg.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices);
   cg.SetConstantsByteAlignment(target->GetAttr<Integer>("constants-byte-alignment").value_or(16));
-
-  auto is_aot_executor_fn = [](const PrimFunc& func) -> bool {
-    return func->GetAttr<Bool>("runner_function", Bool(false)).value();
-  };
-
-  std::vector<std::pair<GlobalVar, PrimFunc>> funcs;
-  for (auto [gvar, base_func] : mod->functions) {
-    ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodegenCHost: Can only take PrimFunc";
-    auto prim_func = Downcast<PrimFunc>(base_func);
-    funcs.push_back({gvar, prim_func});
+  PrimFunc aot_executor_fn;
+
+  std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> funcs;
+  for (auto kv : mod->functions) {
+    // Make sure that the executor function is the last one to be code generated so that all the
+    // symbols are available to __tvm_main__
+    auto fun_name = std::string(kv.first->name_hint);
+    bool is_aot_executor_fn = kv.second->GetAttr<Bool>("runner_function", Bool(false)).value();
+
+    if (is_aot_executor_fn) {
+      aot_executor_fn = Downcast<PrimFunc>(kv.second);
+      continue;
+    }
+    funcs.push_back(kv);
   }
 
   // Sort functions
-  auto sort_key = [&is_aot_executor_fn](const auto& kv) {
-    return std::tuple{is_aot_executor_fn(kv.second), kv.first->name_hint};
-  };
-  std::sort(funcs.begin(), funcs.end(), [&sort_key](const auto& kv_a, const auto& kv_b) {
-    return sort_key(kv_a) < sort_key(kv_b);
-  });
-
-  // Declare all functions first.  This ensures that all functions,
-  // including the __tvm_main__ used in AOT, have access to forward
-  // declarations of other functions in the IRModule.
-  for (const auto& [gvar, prim_func] : funcs) {
-    cg.DeclareFunction(gvar, prim_func);
+  std::sort(funcs.begin(), funcs.end(),
+            [](std::pair<tvm::GlobalVar, tvm::BaseFunc> kv_a,
+               std::pair<tvm::GlobalVar, tvm::BaseFunc> kv_b) {
+              std::string name_hint_a = kv_a.first->name_hint;
+              std::string name_hint_b = kv_b.first->name_hint;
+              return name_hint_a < name_hint_b;
+            });
+
+  // Add all functions except __tvm_main__
+  for (auto& kv : funcs) {
+    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodegenCHost: Can only take PrimFunc";
+    auto f = Downcast<PrimFunc>(kv.second);
+    cg.AddFunction(f);
   }
 
-  // Codegen all functions.  Passing emit_fwd_func_decl=true adds a
-  // forward declaration for any `builtin::call_extern`, based on the
-  // arguments provided to it.
-  for (const auto& [gvar, prim_func] : funcs) {
-    cg.AddFunction(gvar, prim_func, emit_fwd_func_decl);
+  // Add __tvm_main__
+  if (aot_executor_fn.defined()) {
+    emit_fwd_func_decl = true;
+    cg.AddFunction(aot_executor_fn, emit_fwd_func_decl);
   }
 
   // NOTE: it's possible that kRuntime attr is not attached when the mod was built with tvm.build().
@@ -476,10 +484,7 @@ runtime::Module BuildCHost(IRModule mod, Target target) {
   } else {
     runtime = relay::Runtime::Create("cpp", {});
   }
-
-  bool has_aot_executor_fn = std::any_of(
-      funcs.begin(), funcs.end(), [&](const auto& kv) { return is_aot_executor_fn(kv.second); });
-  if (has_aot_executor_fn && runtime->name == relay::kTvmRuntimeCpp) {
+  if (aot_executor_fn.defined() && runtime->name == relay::kTvmRuntimeCpp) {
     cg.InitGlobalContext();
   }
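
The BuildCHost change above swaps strategies for making every symbol visible to the AOT entry point: #15103 forward-declared all functions before emitting any of them, while the restored code instead holds the AOT executor function back and emits it last, forcing emit_fwd_func_decl on for that one function. A condensed sketch of the restored driver loop (a fragment from the '+' lines above; sorting and checks elided):

    PrimFunc aot_executor_fn;
    for (auto kv : mod->functions) {
      // The executor is emitted last so every other symbol is already defined.
      if (kv.second->GetAttr<Bool>("runner_function", Bool(false)).value()) {
        aot_executor_fn = Downcast<PrimFunc>(kv.second);
        continue;
      }
      cg.AddFunction(Downcast<PrimFunc>(kv.second));
    }
    if (aot_executor_fn.defined()) {
      cg.AddFunction(aot_executor_fn, /*emit_fwd_func_decl=*/true);
    }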
 
diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h
index aeba685f74..694104afc0 100644
--- a/src/target/source/codegen_c_host.h
+++ b/src/target/source/codegen_c_host.h
@@ -44,7 +44,8 @@ class CodeGenCHost : public CodeGenC {
             const std::unordered_set<std::string>& devices);
 
   void InitGlobalContext();
-  void AddFunction(const GlobalVar& gvar, const PrimFunc& f, bool emit_fwd_func_decl = false);
+  void AddFunction(const PrimFunc& f, bool emit_fwd_func_decl = false);
+  std::string Finish() final;
   /*!
    * \brief Add functions from the (unordered) range to the current module in a deterministic
    * order. This helps with debugging.
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 7639ce6065..a91f8b0164 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -75,7 +75,7 @@ class ThreadIdxExtractor : public tir::StmtVisitor {
   PrimExpr threadIdx_z_ext = Integer(1);
 };
 
-void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) {
+void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f) {
   ThreadIdxExtractor extractor;
   extractor(f->body);
   arith::Analyzer analyzer;
@@ -86,7 +86,7 @@ void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) {
       // unable to extract the number of threads per block, hence directly return
       return;
     }
-    os << " __launch_bounds__(" << threadIdx_ext_int->value << ")";
+    stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")";
   }
 }
 
diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h
index bc7b34b500..3ec0c3bc2d 100644
--- a/src/target/source/codegen_cuda.h
+++ b/src/target/source/codegen_cuda.h
@@ -47,7 +47,7 @@ class CodeGenCUDA final : public CodeGenC {
   }
   // override behavior
   void PrintFuncPrefix(std::ostream& os) final;
-  void PrintExtraAttrs(const PrimFunc& f, std::ostream& os) final;  // NOLINT(*)
+  void PrintExtraAttrs(const PrimFunc& f) final;
   void VisitStmt_(const ForNode* op) final;
   void PrintStorageSync(const CallNode* op) final;
   void PrintStorageScope(const std::string& scope, std::ostream& os) final;  // NOLINT(*)
diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc
index 3db8d216b3..b8c30691e2 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -36,8 +36,6 @@ namespace codegen {
 
 void CodeGenMetal::InitFuncState(const PrimFunc& f) {
   CodeGenC::InitFuncState(f);
-  // skip the first underscore, so SSA variable starts from _1
-  name_supply_->FreshName("v_");
   // analyze the data;
   for (Var arg : f->params) {
     if (arg.dtype().is_handle()) {
@@ -54,33 +52,37 @@ CodeGenMetal::CodeGenMetal(Target target) : target_(target) {
               << "};\n\n";
 }
 
-void CodeGenMetal::PrintFunctionSignature(const String& function_name, const PrimFunc& func,
-                                          std::ostream& os) {
+void CodeGenMetal::AddFunction(const PrimFunc& f) {
+  // clear previous generated state.
+  this->InitFuncState(f);
+  // skip the first underscore, so SSA variable starts from _1
+  name_supply_->FreshName("v_");
+
   // add to alloc buffer type.
-  auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
   ICHECK(global_symbol.defined())
       << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
 
   // Function header.
-  os << "kernel void " << static_cast<std::string>(global_symbol.value()) << "(";
+  this->stream << "kernel void " << static_cast<std::string>(global_symbol.value()) << "(";
 
   // Buffer arguments
   size_t num_buffer = 0;
   size_t limit = target_->GetAttr<Integer>("max_function_args").value().IntValue();
-  if (func->params.size() > limit) {
+  if (f->params.size() > limit) {
     LOG(WARNING) << "Probably you won't be able to execute your kernel due to high number of "
                     "buffers in the kernel";
   }
-  for (size_t i = 0; i < func->params.size(); ++i, ++num_buffer) {
-    Var v = func->params[i];
+  for (size_t i = 0; i < f->params.size(); ++i, ++num_buffer) {
+    Var v = f->params[i];
     if (!v.dtype().is_handle()) break;
-    os << "  ";
+    stream << "  ";
     std::string vid = AllocVarID(v.get());
     auto it = alloc_storage_scope_.find(v.get());
     if (it != alloc_storage_scope_.end()) {
-      PrintStorageScope(it->second, os);
+      PrintStorageScope(it->second, stream);
     }
-    PrintType(GetType(v), os);
+    PrintType(GetType(v), stream);
     // Register handle data type
     // TODO(tvm-team): consider simply keep type info in the
     // type annotation(via a normalizing rewriting).
@@ -89,18 +91,19 @@ void CodeGenMetal::PrintFunctionSignature(const String& function_name, const Pri
         RegisterHandleType(v.get(), prim->dtype);
       }
     }
-    os << ' ' << vid << " [[ buffer(" << i << ") ]],\n";
+    stream << ' ' << vid << " [[ buffer(" << i << ") ]],\n";
   }
   // Setup normal arguments.
-  size_t nargs = func->params.size() - num_buffer;
+  size_t nargs = f->params.size() - num_buffer;
   std::string varg = name_supply_->FreshName("arg");
   if (nargs != 0) {
     std::string arg_buf_type = static_cast<std::string>(global_symbol.value()) + "_args_t";
-    os << "  constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer << ") ]],\n";
+    stream << "  constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer
+           << ") ]],\n";
     // declare the struct
     decl_stream << "struct " << arg_buf_type << " {\n";
-    for (size_t i = num_buffer; i < func->params.size(); ++i) {
-      Var v = func->params[i];
+    for (size_t i = num_buffer; i < f->params.size(); ++i) {
+      Var v = f->params[i];
       ICHECK(!v.dtype().is_handle());
       std::string vid = AllocVarID(v.get());
       std::ostringstream vref;
@@ -128,7 +131,7 @@ void CodeGenMetal::PrintFunctionSignature(const String& function_name, const Pri
   ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
   ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
   int work_dim = 0;
-  auto launch_params = func->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams).value();
+  auto launch_params = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams).value();
   for (const auto& tag : launch_params) {
     if (tag != runtime::launch_param::kUseDynamicSharedMemoryTag) {
       runtime::ThreadScope scope = runtime::ThreadScope::Create(tag);
@@ -147,7 +150,13 @@ void CodeGenMetal::PrintFunctionSignature(const String& function_name, const Pri
   }
   thread_work_dim_ = work_dim;
 
-  stream << ")";
+  // the function scope.
+  stream << ") {\n";
+  int func_scope = this->BeginScope();
+  this->PrintStmt(f->body);
+  this->EndScope(func_scope);
+  this->PrintIndent();
+  this->stream << "}\n\n";
 }
 
 void CodeGenMetal::BindThreadIndex(const IterVar& iv) {
@@ -333,33 +342,27 @@ runtime::Module BuildMetal(IRModule mod, Target target) {
   const auto* fmetal_compile = Registry::Get("tvm_callback_metal_compile");
   std::string fmt = fmetal_compile ? "metallib" : "metal";
 
-  Map<GlobalVar, PrimFunc> functions;
-  for (auto [gvar, base_func] : mod->functions) {
-    ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenMetal: Can only take PrimFunc";
-    auto calling_conv = base_func->GetAttr<Integer>(tvm::attr::kCallingConv);
-    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
-        << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
-
-    auto prim_func = Downcast<PrimFunc>(base_func);
-    functions.Set(gvar, prim_func);
-  }
+  for (auto kv : mod->functions) {
+    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenMetal: Can only take PrimFunc";
+    auto global_symbol = kv.second->GetAttr<String>(tvm::attr::kGlobalSymbol);
+    ICHECK(global_symbol.defined());
+    std::string func_name = global_symbol.value();
 
-  for (auto [gvar, prim_func] : functions) {
-    source_maker << "// Function: " << gvar->name_hint << "\n";
+    source_maker << "// Function: " << func_name << "\n";
     CodeGenMetal cg(target);
     cg.Init(output_ssa);
+    auto f = Downcast<PrimFunc>(kv.second);
+    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
+    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
+        << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
 
-    for (auto [other_gvar, other_prim_func] : functions) {
-      cg.DeclareFunction(other_gvar, other_prim_func);
-    }
-    cg.AddFunction(gvar, prim_func);
-
+    cg.AddFunction(f);
     std::string fsource = cg.Finish();
     source_maker << fsource << "\n";
     if (fmetal_compile) {
       fsource = (*fmetal_compile)(fsource, target).operator std::string();
     }
-    smap[cg.GetFunctionName(gvar)] = fsource;
+    smap[func_name] = fsource;
   }
 
   return MetalModuleCreate(smap, ExtractFuncInfo(mod), fmt, source_maker.str());
diff --git a/src/target/source/codegen_metal.h b/src/target/source/codegen_metal.h
index 26c991e60d..36be10d163 100644
--- a/src/target/source/codegen_metal.h
+++ b/src/target/source/codegen_metal.h
@@ -38,8 +38,7 @@ class CodeGenMetal final : public CodeGenC {
   explicit CodeGenMetal(Target target);
   // override print thread tag.
   void PrintArgUnionDecl();
-  void PrintFunctionSignature(const String& function_name, const PrimFunc& func,
-                              std::ostream& os) override;
+  void AddFunction(const PrimFunc& f);  // NOLINT(*)
   void InitFuncState(const PrimFunc& f) final;
   void PrintStorageScope(const std::string& scope, std::ostream& os) final;  // NOLINT(*)
   void PrintStorageSync(const CallNode* op) final;                           // NOLINT(*)
diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc
index da6a4de619..c15d2253d7 100644
--- a/src/target/source/codegen_opencl.cc
+++ b/src/target/source/codegen_opencl.cc
@@ -595,26 +595,18 @@ runtime::Module BuildOpenCL(IRModule mod, Target target) {
   using tvm::runtime::Registry;
   bool output_ssa = false;
 
-  Map<GlobalVar, PrimFunc> functions;
-  for (auto [gvar, base_func] : mod->functions) {
-    ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenOpenCL: Can only take PrimFunc";
-    auto prim_func = Downcast<PrimFunc>(base_func);
-    auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
-    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
-        << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
-    functions.Set(gvar, prim_func);
-  }
-
   std::stringstream code;
   const auto* fpostproc = Registry::Get("tvm_callback_opencl_postproc");
-  for (auto [gvar, prim_func] : functions) {
-    code << "// Function: " << gvar->name_hint << std::endl;
+  for (auto kv : mod->functions) {
+    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenOpenCL: Can only take PrimFunc";
+    code << "// Function: " << kv.first->name_hint << std::endl;
     CodeGenOpenCL cg;
     cg.Init(output_ssa);
-    for (auto [other_gvar, other_prim_func] : functions) {
-      cg.DeclareFunction(other_gvar, other_prim_func);
-    }
-    cg.AddFunction(gvar, prim_func);
+    auto f = Downcast<PrimFunc>(kv.second);
+    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
+    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
+        << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
+    cg.AddFunction(f);
     std::string fsource = cg.Finish();
     if (fpostproc) {
       fsource = (*fpostproc)(fsource, target).operator std::string();
diff --git a/src/target/source/codegen_vhls.cc b/src/target/source/codegen_vhls.cc
index aa7a32320c..83046de107 100644
--- a/src/target/source/codegen_vhls.cc
+++ b/src/target/source/codegen_vhls.cc
@@ -145,21 +145,13 @@ runtime::Module BuildSDAccel(IRModule mod, Target target) {
   // Generate source code for get_source().
   cg.Init(output_ssa);
 
-  Map<GlobalVar, PrimFunc> functions;
-  for (auto [gvar, base_func] : mod->functions) {
-    ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenVHLS: Can only take PrimFunc";
-    auto prim_func = Downcast<PrimFunc>(base_func);
-    auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
+  for (auto kv : mod->functions) {
+    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenVHLS: Can only take PrimFunc";
+    auto f = Downcast<PrimFunc>(kv.second);
+    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
     ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
         << "CodeGenVLHS: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
-    functions.Set(gvar, prim_func);
-  }
-
-  for (auto [gvar, prim_func] : functions) {
-    cg.DeclareFunction(gvar, prim_func);
-  }
-  for (auto [gvar, prim_func] : functions) {
-    cg.AddFunction(gvar, prim_func);
+    cg.AddFunction(f);
   }
 
   std::string whole_code = cg.Finish();
@@ -167,21 +159,21 @@ runtime::Module BuildSDAccel(IRModule mod, Target target) {
   // Generate source code for compilation.
   Array<Array<runtime::String>> kernel_info;
 
-  for (auto [gvar, prim_func] : functions) {
+  for (auto kv : mod->functions) {
+    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenVHLS: Can only take PrimFunc";
+    auto f = Downcast<PrimFunc>(kv.second);
     CodeGenVivadoHLS cg;
     cg.Init(output_ssa);
-
-    for (auto [other_gvar, other_prim_func] : functions) {
-      cg.DeclareFunction(other_gvar, other_prim_func);
-    }
-    cg.AddFunction(gvar, prim_func);
+    cg.AddFunction(f);
     std::string code = cg.Finish();
     if (const auto* f = runtime::Registry::Get("tvm_callback_vhls_postproc")) {
       code = (*f)(code, target).operator std::string();
     }
 
-    auto function_name = cg.GetFunctionName(gvar);
-    kernel_info.push_back({function_name, code});
+    auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
+    ICHECK(global_symbol.defined())
+        << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
+    kernel_info.push_back({global_symbol.value(), code});
   }
 
   std::string xclbin;
diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc
index 6a6712a4ce..4d1d834c7f 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -45,12 +45,6 @@ std::string CodeGenWebGPU::Finish() {
 
 void CodeGenWebGPU::InitFuncState(const PrimFunc& f) {
   CodeGenC::InitFuncState(f);
-  // skip the first underscore, so SSA variable starts from
-  name_supply_->FreshName("v_");
-  // Setup the thread group info.
-  ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
-  ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
-
   // analyze the data;
   for (Var arg : f->params) {
     if (arg.dtype().is_handle()) {
@@ -62,12 +56,28 @@ void CodeGenWebGPU::InitFuncState(const PrimFunc& f) {
 
 CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {}
 
-void CodeGenWebGPU::PrintFunctionSignature(const String& function_name, const PrimFunc& func,
-                                           std::ostream& os) {
+void CodeGenWebGPU::AddFunction(const PrimFunc& f) {
+  // clear previous generated state.
+  this->InitFuncState(f);
+  // skip the first underscore, so SSA variable starts from
+  name_supply_->FreshName("v_");
+  // Setup the thread group info.
+  ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
+  ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
+
+  // add to alloc buffer type.
+  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
+  ICHECK(global_symbol.defined())
+      << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute";
+
+  decl_stream << "//----------------------------------------\n"
+              << "// function: " << global_symbol.value() << "\n"
+              << "//----------------------------------------\n";
+
   std::vector<Var> pod_args;
   int num_buffer = 0;
   // setup buffer arguments
-  for (Var arg : func->params) {
+  for (Var arg : f->params) {
     DataType t = arg.dtype();
     if (t.is_handle()) {
       auto* ptr = arg->type_annotation.as<PointerTypeNode>();
@@ -101,18 +111,16 @@ void CodeGenWebGPU::PrintFunctionSignature(const String& function_name, const Pr
   }
   // add to alloc buffer type.
   // Function header.
-  os << "fn main(\n"
-     << "  @builtin(workgroup_id) blockIdx : vec3<u32>,\n"
-     << "  @builtin(local_invocation_id) threadIdx : vec3<u32>\n"
-     << ")";
-}
-
-void CodeGenWebGPU::AddFunction(const GlobalVar& gvar, const PrimFunc& func) {
-  CodeGenC::AddFunction(gvar, func);
-  decl_stream << "//----------------------------------------\n"
-              << "// function: " << GetFunctionName(gvar) << "\n"
-              << "//----------------------------------------\n";
-
+  this->stream << "fn main(\n"
+               << "  @builtin(workgroup_id) blockIdx : vec3<u32>,\n"
+               << "  @builtin(local_invocation_id) threadIdx : vec3<u32>\n"
+               << ") {\n";
+  // the function scope.
+  int func_scope = this->BeginScope();
+  this->PrintStmt(f->body);
+  this->EndScope(func_scope);
+  this->PrintIndent();
+  this->stream << "}\n\n";
   // annotate workgroup
   this->fwd_decl_stream << "@compute @workgroup_size(" << workgroup_size_[0] << ", "
                         << workgroup_size_[1] << ", " << workgroup_size_[2] << ")\n";
@@ -516,31 +524,22 @@ runtime::Module BuildWebGPU(IRModule mod, Target target) {
   mod = tir::transform::PointerValueTypeRewrite()(std::move(mod));
   bool output_ssa = false;
 
-  Map<GlobalVar, PrimFunc> functions;
-  for (auto [gvar, base_func] : mod->functions) {
-    ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenWebGPU: Can only take PrimFunc";
-    auto prim_func = Downcast<PrimFunc>(base_func);
-    auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
+  std::unordered_map<std::string, std::string> smap;
+  for (auto kv : mod->functions) {
+    CodeGenWebGPU cg(target);
+    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenWebGPU: Can only take PrimFunc";
+    auto f = Downcast<PrimFunc>(kv.second);
+    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
     ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
         << "CodeGenWebGPU: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
-    auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+    auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
     ICHECK(global_symbol.defined())
         << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute";
-    functions.Set(gvar, prim_func);
-  }
-
-  std::unordered_map<std::string, std::string> smap;
-  for (auto [gvar, prim_func] : functions) {
-    CodeGenWebGPU cg(target);
+    std::string f_name = global_symbol.value();
     cg.Init(output_ssa);
-
-    for (auto [other_gvar, other_prim_func] : functions) {
-      cg.DeclareFunction(other_gvar, other_prim_func);
-    }
-    cg.AddFunction(gvar, prim_func);
-
+    cg.AddFunction(f);
     std::string code = cg.Finish();
-    smap[cg.GetFunctionName(gvar)] = code;
+    smap[f_name] = code;
   }
   auto n = make_object<WebGPUSourceModuleNode>(smap, ExtractFuncInfo(mod));
   return runtime::Module(n);
diff --git a/src/target/source/codegen_webgpu.h b/src/target/source/codegen_webgpu.h
index 6ae942a3ad..57f226ba8a 100644
--- a/src/target/source/codegen_webgpu.h
+++ b/src/target/source/codegen_webgpu.h
@@ -48,9 +48,7 @@ class CodeGenWebGPU final : public CodeGenC {
   explicit CodeGenWebGPU(Target target);
   // overrides
   std::string Finish() final;
-  void PrintFunctionSignature(const String& function_name, const PrimFunc& func,
-                              std::ostream& os) final;
-  void AddFunction(const GlobalVar& gvar, const PrimFunc& f) final;
+  void AddFunction(const PrimFunc& f);  // NOLINT(*)
   void InitFuncState(const PrimFunc& f) final;
   void PrintStorageSync(const CallNode* op) final;     // NOLINT(*)
   void PrintType(DataType t, std::ostream& os) final;  // NOLINT(*)
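
A note for readers tracking the WebGPU change: after this revert, BuildWebGPU
code-generates every kernel in isolation, so each PrimFunc handed to it must
carry its own global_symbol and the kDeviceKernelLaunch calling convention,
and one kernel can no longer call another through a GlobalVar. Below is a
minimal TVMScript sketch of a kernel meeting that contract; the function and
buffer names are illustrative, not taken from this commit.

    import tvm
    from tvm.script import tir as T

    @T.prim_func
    def add_one_kernel(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
        # global_symbol is mandatory post-revert; calling_conv 2 corresponds to
        # CallingConv::kDeviceKernelLaunch, which BuildWebGPU checks for.
        T.func_attr({"global_symbol": "add_one_kernel", "calling_conv": 2})
        for i in T.thread_binding(16, thread="threadIdx.x"):
            B[i] = A[i] + T.float32(1)
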
diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc
index 90640a6db6..a6f4b5bb3e 100644
--- a/src/target/source/source_module.cc
+++ b/src/target/source/source_module.cc
@@ -613,14 +613,12 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
       }
 
       for (const tir::Var& pool_var : metadata_->pools) {
-        call_args_ss << "((uint8_t*)";
         String pool_name = metadata_->pool_inputs.value()[pool_var]->pool_info->pool_name;
         if (IsInternalWorkspaceBuffer(pool_var)) {
-          call_args_ss << "&" << pool_name;
+          call_args_ss << "&" << pool_name << ",";
         } else {
-          call_args_ss << "workspace_pools->" << tvm::runtime::SanitizeName(pool_name);
+          call_args_ss << "workspace_pools->" << tvm::runtime::SanitizeName(pool_name) << ",";
         }
-        call_args_ss << "),";
       }
       for (const String& device : metadata_->devices) {
         call_args_ss << "devices->" << device << ",";
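
To make the pool-argument hunk above concrete, here is a toy re-creation in
plain Python (not TVM code; the pool name "sram" is hypothetical) of the
argument text the two branches now emit, one trailing comma per argument and
no ((uint8_t*)...) cast. The real code additionally runs SanitizeName on
external pool names, which this sketch omits.

    def emit_pool_args(pool_names, is_internal):
        # Mimics how call_args_ss renders workspace-pool arguments post-revert.
        out = ""
        for name in pool_names:
            if is_internal(name):
                out += "&" + name + ","
            else:
                out += "workspace_pools->" + name + ","
        return out

    assert emit_pool_args(["sram"], lambda n: True) == "&sram,"
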
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index fd14f48921..39214c4546 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -70,32 +70,6 @@ Type GetType(const PrimExpr& expr) {
       return ptr->type_annotation;
     }
   }
-
-  if (auto* access = expr.as<tir::CallNode>()) {
-    if (access->op.same_as(builtin::tvm_access_ptr())) {
-      ICHECK(access->args.size()) << "Builtin tvm_access_ptr() may not have empty arguments";
-      auto type_annotation = Downcast<Call>(access->args[0]);
-      static auto builtin_op = Op::Get("tir.type_annotation");
-      ICHECK(type_annotation->op.same_as(builtin_op))
-          << "Expected the first argument of builtin tvm_access_ptr() "
-          << "to be a type annotation, but found " << type_annotation->op;
-      return PointerType(PrimType(type_annotation->dtype));
-    }
-  }
-
-  if (auto* address_of = expr.as<tir::CallNode>()) {
-    if (address_of->op.same_as(builtin::address_of())) {
-      ICHECK_EQ(address_of->args.size(), 1)
-          << "Builtin address_of() expects a single argument, but received arguments "
-          << address_of->args;
-      auto* address = address_of->args[0].as<BufferLoadNode>();
-      ICHECK(address)
-          << "Builtin address_of() expects the argument to be a BufferLoad, but received argument "
-          << address_of->args[0];
-
-      return PointerType(PrimType(address->dtype));
-    }
-  }
   // Default: return the type indicated by the dtype.
   runtime::DataType dtype = expr.dtype();
   return GetTypeFromRuntimeDataType(dtype);
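
The net effect of the GetType() hunk: calls to builtins such as
tir.address_of and tvm_access_ptr are once again typed from their runtime
dtype, which for these calls is the generic "handle", rather than a
PointerType reconstructed from their arguments. A small illustration using
the public tvm.tir constructors (the buffer name is arbitrary):

    import tvm
    from tvm import tir
    from tvm.ir import Op

    A = tir.decl_buffer((4,), "float32", name="A")
    addr = tir.Call("handle", Op.get("tir.address_of"), [tir.BufferLoad(A, [0])])
    # Post-revert, GetType(addr) falls back to addr.dtype ("handle") instead of
    # recovering PointerType(PrimType("float32")) from the BufferLoad argument.
    assert addr.dtype == "handle"
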
diff --git a/tests/python/relay/aot/test_crt_forward_declarations.py b/tests/python/relay/aot/test_crt_forward_declarations.py
index 99e2f0c923..72305b0030 100644
--- a/tests/python/relay/aot/test_crt_forward_declarations.py
+++ b/tests/python/relay/aot/test_crt_forward_declarations.py
@@ -160,8 +160,8 @@ def test_internal_calls(interface_api, use_unpacked_api, test_runner):
 
     lib_mod = compiled_models[0].executor_factory.lib.imported_modules[0]
     main_source = lib_mod.get_source()
-    assert main_source.count("int32_t tvmgen_default_fused_nn_contrib_depthwise_conv2d_NCHWc") == 2
-    assert main_source.count("int32_t tvmgen_default_fused_layout_transform") == 6
+    assert main_source.count("int32_t tvmgen_default_fused_nn_contrib_depthwise_conv2d_NCHWc") == 1
+    assert main_source.count("int32_t tvmgen_default_fused_layout_transform") == 3
 
 
 @tvm.testing.requires_corstone300
diff --git a/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py b/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py
index f6145cd1c5..7bea7577b6 100644
--- a/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py
+++ b/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py
@@ -135,13 +135,8 @@ def test_write_3x3_depthwise_code():
     #define TENSORDOT_OPT_X1_INT16_W48_3X3_000_EXISTS
     #include <arm_acle.h>
     __attribute__((always_inline)) static inline int32_t tensordot_opt_x1_int16_w48_3x3_000(
-        int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
-        int32_t *bias, int32_t *scale
+        int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale
     ) {
-      int32_t *output = output_arg;
-      int32_t *tensor = tensor_arg;
-      int32_t *kernel = kernel_arg;
-
       int32_t sum_0 = *bias;
 
       int32_t tensor__y00_x00__y00_x01 = tensor[0];
@@ -193,13 +188,8 @@ def test_odd_width_3x3_depthwise_strides_code():
     #define TENSORDOT_OPT_X2_INT16_W49_3X3_000_2_4_EXISTS
     #include <arm_acle.h>
     __attribute__((always_inline)) static inline int32_t tensordot_opt_x2_int16_w49_3x3_000_2_4(
-        int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
-        int32_t *bias, int32_t *scale
+        int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale
     ) {
-      int32_t *output = output_arg;
-      int32_t *tensor = tensor_arg;
-      int32_t *kernel = kernel_arg;
-
       int32_t sum_0 = *bias, sum_1 = *bias;
 
       int32_t tensor__y00_x00__y00_x01 = tensor[0];
@@ -261,13 +251,8 @@ def test_1x1x8_convolution_code():
     #define TENSORDOT_OPT_X4_INT16_W384_1X8_000_8_1_EXISTS
     #include <arm_acle.h>
     __attribute__((always_inline)) static inline int32_t tensordot_opt_x4_int16_w384_1x8_000_8_1(
-        int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
-        int32_t *bias, int32_t *scale
+        int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale
     ) {
-      int32_t *output = output_arg;
-      int32_t *tensor = tensor_arg;
-      int32_t *kernel = kernel_arg;
-
       int32_t sum_0 = *bias, sum_1 = *bias, sum_2 = *bias, sum_3 = *bias;
 
       int32_t tensor__y00_x00__y00_x01 = tensor[0];
@@ -364,13 +349,8 @@ def test_3x3x3_offset_convolution_code():
     #define TENSORDOT_OPT_X1_INT16_W288_3X9_111_EXISTS
     #include <arm_acle.h>
     __attribute__((always_inline)) static inline int32_t tensordot_opt_x1_int16_w288_3x9_111(
-        int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
-        int32_t *bias, int32_t *scale
+        int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale
     ) {
-      int32_t *output = output_arg;
-      int32_t *tensor = tensor_arg;
-      int32_t *kernel = kernel_arg;
-
       int32_t sum_0 = *bias;
 
       int32_t tensor__unknown__y00_x00 = tensor[0];
diff --git a/tests/python/unittest/test_target_codegen_c_host.py b/tests/python/unittest/test_target_codegen_c_host.py
index 3aca0fc8c7..d02f8744f1 100644
--- a/tests/python/unittest/test_target_codegen_c_host.py
+++ b/tests/python/unittest/test_target_codegen_c_host.py
@@ -14,15 +14,11 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
 import tvm
 import tvm.testing
-
 from tvm import te
-from tvm.contrib import utils
-from tvm.script import tir as T, ir as I
-
 import numpy as np
+from tvm.contrib import utils
 
 
 def test_add():
@@ -232,39 +228,11 @@ def test_call_packed():
     check_global_packed_func()
 
 
-def test_subroutine_call():
-    @I.ir_module
-    class mod:
-        @T.prim_func
-        def main(A: T.Buffer(1, dtype="float32")):
-            mod.subroutine(A.data)
-
-        @T.prim_func(private=True)
-        def subroutine(A_data: T.handle("float32")):
-            A = T.decl_buffer(1, dtype="float32", data=A_data)
-            A[0] = 42.0
-
-    built = tvm.build(mod, target="c")
-
-    func_names = list(built["get_func_names"]())
-    assert (
-        "main" in func_names
-    ), "Externally exposed functions should be listed in available functions."
-    assert (
-        "subroutine" not in func_names
-    ), "Internal function should not be listed in available functions."
-
-    source = built.get_source()
-    assert (
-        source.count("main(void*") == 2
-    ), "Expected two occurrences, for forward-declaration and definition"
-    assert (
-        source.count("subroutine(float*") == 2
-    ), "Expected two occurrences, for forward-declaration and definition"
-    assert (
-        source.count("subroutine(") == 3
-    ), "Expected three occurrences, for forward-declaration, definition, and call from main."
-
-
 if __name__ == "__main__":
-    tvm.testing.main()
+    test_add()
+    test_add_pipeline()
+    test_reinterpret()
+    test_ceil()
+    test_floor()
+    test_round()
+    test_call_packed()
diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
index 61f0892a9c..588a92d87c 100644
--- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
@@ -268,7 +268,6 @@ cast_smem_ptr_to_int(const void* const smem_ptr)
   #define int64_t long long
   #define uint64_t unsigned long long
 #endif
-extern "C" __global__ void __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C);
 extern "C" __global__ void __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) {
   __shared__ float A_shared[64];
   __shared__ float B_shared[64];