You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2023/09/12 00:33:31 UTC
[tvm] 01/01: Revert "[CodeGenC] Handle GlobalVar callee as internal function call (#15103)"
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch revert-15103-codegen_c_support_subroutine_calls
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 2902c10cc700671cbb538c679b7a375d20596cee
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Mon Sep 11 17:33:25 2023 -0700
Revert "[CodeGenC] Handle GlobalVar callee as internal function call (#15103)"
This reverts commit 9ff71f4a9fed3ec9f82b999fb8c25ef1bf6e243c.
---
.../arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py | 8 +-
.../topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py | 87 +++---------
.../arm_cpu/mprofile/dsp/micro_kernel/max_pool.py | 13 +-
.../arm_cpu/mprofile/dsp/micro_kernel/tensordot.py | 7 +-
.../backend/contrib/cmsisnn/tir_to_runtime.cc | 28 ++--
.../contrib/example_target_hooks/tir_to_runtime.cc | 26 +---
src/relay/backend/contrib/uma/tir_to_runtime.cc | 34 ++---
src/target/opt/build_cuda_on.cc | 18 +--
src/target/source/codegen_aocl.cc | 19 +--
src/target/source/codegen_c.cc | 153 +++++++--------------
src/target/source/codegen_c.h | 59 +-------
src/target/source/codegen_c_host.cc | 93 +++++++------
src/target/source/codegen_c_host.h | 3 +-
src/target/source/codegen_cuda.cc | 4 +-
src/target/source/codegen_cuda.h | 2 +-
src/target/source/codegen_metal.cc | 77 ++++++-----
src/target/source/codegen_metal.h | 3 +-
src/target/source/codegen_opencl.cc | 24 ++--
src/target/source/codegen_vhls.cc | 34 ++---
src/target/source/codegen_webgpu.cc | 79 ++++++-----
src/target/source/codegen_webgpu.h | 4 +-
src/target/source/source_module.cc | 6 +-
src/tir/op/op.cc | 26 ----
.../relay/aot/test_crt_forward_declarations.py | 4 +-
.../topi/python/test_topi_conv2d_tensordot_opts.py | 28 +---
.../python/unittest/test_target_codegen_c_host.py | 48 ++-----
.../test_tir_transform_inject_ptx_async_copy.py | 1 -
27 files changed, 297 insertions(+), 591 deletions(-)
diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py
index 3eb32d8fdb..e8e45152aa 100644
--- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py
+++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py
@@ -55,7 +55,7 @@ def intrin_sum(shape, in_dtype, out_dtype, reset=False):
ib = tvm.tir.ir_builder.create()
ib.emit(
tvm.tir.call_extern(
- "int32",
+ cc.dtype,
f"{func_prefix}_{width}_{uniq_id}",
aa.access_ptr("r"),
cc.access_ptr("w"),
@@ -68,7 +68,7 @@ def intrin_sum(shape, in_dtype, out_dtype, reset=False):
def _reduce_reset():
ib = tvm.tir.ir_builder.create()
ib.emit(
- tvm.tir.call_extern("int32", f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"))
+ tvm.tir.call_extern(cc.dtype, f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"))
)
return ib.get()
@@ -113,8 +113,8 @@ extern "C"
__attribute__((always_inline)) static inline int32_t sum16_{N}_{uniq_id}(
int16_t *arr,
int16_t *res16,
- int32_t arr_offset,
- int32_t reset) {{
+ long arr_offset,
+ int reset) {{
int n;
int32_t *p32;
int32_t res = reset ? 0 : *res16;
diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py
index e26e818fbd..929dcc6557 100644
--- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py
+++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py
@@ -156,14 +156,9 @@ __attribute__((always_inline)) static inline const int8_t *read_and_pad(const in
extern "C"
#endif
__attribute__((always_inline)) static inline int32_t gemm_{M}x{N}_body_rest_{uniq_id}(
- int32_t K_arg,
+ int K,
int8_t *aa, int8_t *bb, int32_t *cc,
- int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
- int K = K_arg;
- int A_stride = A_stride_arg;
- int B_stride = B_stride_arg;
- int C_stride = C_stride_arg;
-
+ int A_stride, int B_stride, int C_stride) {{
int k_base = (K / 4) * 4;
switch ( K % 4 ) {{
case 1:
@@ -205,12 +200,7 @@ extern "C"
#endif
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_body_loop_{uniq_id}(
int8_t *aa, int8_t *bb, int32_t *cc,
- int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
- int A_stride = A_stride_arg;
- int B_stride = B_stride_arg;
- int C_stride = C_stride_arg;
-
-
+ int A_stride, int B_stride, int C_stride) {{
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
int32_t sum = 0;
@@ -231,11 +221,7 @@ extern "C"
#endif
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_body_{uniq_id}(
int8_t *aa, int8_t *bb, int32_t *cc,
- int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
- int A_stride = A_stride_arg;
- int B_stride = B_stride_arg;
- int C_stride = C_stride_arg;
-
+ int A_stride, int B_stride, int C_stride) {{
int16_t bb_pad[{bb_pad_size}];
int32_t retcode = 0;
@@ -279,14 +265,9 @@ out:
extern "C"
#endif
__attribute__((always_inline)) static inline int32_t gemm_{M}x{N}_update_rest_{uniq_id}(
- int32_t K_arg,
+ int K,
int8_t *aa, int8_t *bb, int32_t *cc,
- int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
- int K = K_arg;
- int A_stride = A_stride_arg;
- int B_stride = B_stride_arg;
- int C_stride = C_stride_arg;
-
+ int A_stride, int B_stride, int C_stride) {{
int k_base = (K / 4) * 4;
switch ( K % 4 ) {{
case 1:
@@ -328,11 +309,7 @@ extern "C"
#endif
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_update_loop_{uniq_id}(
int8_t *aa, int8_t *bb, int32_t *cc,
- int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
- int A_stride = A_stride_arg;
- int B_stride = B_stride_arg;
- int C_stride = C_stride_arg;
-
+ int A_stride, int B_stride, int C_stride) {{
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
int32_t sum = 0;
@@ -350,11 +327,7 @@ extern "C"
#endif
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_update_{uniq_id}(
int8_t *aa, int8_t *bb, int32_t *cc,
- int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
- int A_stride = A_stride_arg;
- int B_stride = B_stride_arg;
- int C_stride = C_stride_arg;
-
+ int A_stride, int B_stride, int C_stride) {{
int16_t bb_pad[{bb_pad_size}];
int32_t retcode = 0;
@@ -395,14 +368,9 @@ out:
extern "C"
#endif
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{N}_body_rest_{uniq_id}(
- int32_t K_arg,
+ int K,
int16_t *aa, int16_t *bb, int32_t *cc,
- int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
- int K = K_arg;
- int A_stride = A_stride_arg;
- int B_stride = B_stride_arg;
- int C_stride = C_stride_arg;
-
+ int A_stride, int B_stride, int C_stride) {{
int k_base = (K / 2) * 2;
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
@@ -419,11 +387,7 @@ extern "C"
#endif
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_body_loop_{uniq_id}(
int16_t *aa, int16_t *bb, int32_t *cc,
- int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
- int A_stride = A_stride_arg;
- int B_stride = B_stride_arg;
- int C_stride = C_stride_arg;
-
+ int A_stride, int B_stride, int C_stride) {{
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
int32_t sum = 0;
@@ -444,11 +408,7 @@ extern "C"
#endif
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_body_{uniq_id}(
int16_t *aa, int16_t *bb, int32_t *cc,
- int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
- int A_stride = A_stride_arg;
- int B_stride = B_stride_arg;
- int C_stride = C_stride_arg;
-
+ int A_stride, int B_stride, int C_stride) {{
int32_t retcode = 0;
if ( {M} < 2 && {N} < 2 ) {{
@@ -490,14 +450,9 @@ out:
extern "C"
#endif
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{N}_update_rest_{uniq_id}(
- int32_t K_arg,
+ int K,
int16_t *aa, int16_t *bb, int32_t *cc,
- int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
- int K = K_arg;
- int A_stride = A_stride_arg;
- int B_stride = B_stride_arg;
- int C_stride = C_stride_arg;
-
+ int A_stride, int B_stride, int C_stride) {{
int k_base = (K / 2) * 2;
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
@@ -514,11 +469,7 @@ extern "C"
#endif
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_update_loop_{uniq_id}(
int16_t *aa, int16_t *bb, int32_t *cc,
- int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
- int A_stride = A_stride_arg;
- int B_stride = B_stride_arg;
- int C_stride = C_stride_arg;
-
+ int A_stride, int B_stride, int C_stride) {{
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
int32_t sum = 0;
@@ -536,11 +487,7 @@ extern "C"
#endif
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_update_{uniq_id}(
int16_t *aa, int16_t *bb, int32_t *cc,
- int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
- int A_stride = A_stride_arg;
- int B_stride = B_stride_arg;
- int C_stride = C_stride_arg;
-
+ int A_stride, int B_stride, int C_stride) {{
int32_t retcode = 0;
if ( {M} < 2 && {N} < 2 ) {{
@@ -573,7 +520,7 @@ out:
#ifdef __cplusplus
extern "C"
#endif
-__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int32_t C_stride) {{
+__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int C_stride) {{
for (int i = 0; i < {M}; i++) {{
for (int j = 0; j < {N}; j++) {{
cc[i*C_stride + j] = 0;
diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py
index cfed417c9f..66d712a4a0 100644
--- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py
+++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py
@@ -46,7 +46,7 @@ def intrin_max(shape, in_dtype, out_dtype):
ib = tvm.tir.ir_builder.create()
ib.emit(
tvm.tir.call_extern(
- "int32",
+ cc.dtype,
f"{func_prefix}_{uniq_id}",
aa.access_ptr("r"),
cc.access_ptr("w"),
@@ -59,7 +59,7 @@ def intrin_max(shape, in_dtype, out_dtype):
ib = tvm.tir.ir_builder.create()
ib.emit(
tvm.tir.call_extern(
- "int32", f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"), cc.strides[0]
+ cc.dtype, f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"), cc.strides[0]
)
)
return ib.get()
@@ -96,7 +96,7 @@ extern "C"
#endif
__attribute__((always_inline)) static inline int32_t max8_reset_{uniq_id}(
int8_t *res,
- int32_t N) {{
+ int N) {{
memset(res, (int8_t)-128, N * sizeof(*res));
return 0;
}}
@@ -107,9 +107,7 @@ extern "C"
__attribute__((always_inline)) static inline int32_t max8_loop_{uniq_id}(
int8_t *arg,
int8_t *res,
- int32_t N_arg) {{
- int N = N_arg;
-
+ int N) {{
for ( int i = 0; i < N; ++ i )
if ( arg[i] > res[i] )
res[i] = arg[i];
@@ -122,8 +120,7 @@ extern "C"
__attribute__((always_inline)) static inline int32_t max8_{uniq_id}(
int8_t *arg,
int8_t *res,
- int32_t N_arg) {{
- int N = N_arg;
+ int N) {{
int32_t *parg32, *pres32;
int una_arg = (int32_t)arg & 0x3, una_res = (int32_t)res & 0x3;
int32_t retcode = 0;
diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py
index af3b23e01d..d2a8f1ef69 100644
--- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py
+++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py
@@ -390,13 +390,8 @@ def tensordot_int16_impl(
#define {function_name.upper()}_EXISTS
#include <arm_acle.h>
__attribute__((always_inline)) static inline int32_t {function_name}(
- int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
- int32_t *bias, int32_t *scale
+ int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale
) {{
- int32_t *output = output_arg;
- int32_t *tensor = tensor_arg;
- int32_t *kernel = kernel_arg;
-
{_init_biased_accumulators(num_outputs)}
{insert_lines(load_tensor_lines)}
diff --git a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
index 6febfe3486..186fa30f20 100644
--- a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
+++ b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
@@ -46,6 +46,13 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
CodeGenCHost::Init(output_ssa, emit_asserts, emit_fwd_func_decl, target_str, devices);
}
+ /*!
+ * \brief Emit code that offloads a subgraph to the Cortex-M
+ *
+ * \return string of code that offloads a subgraph to the Cortex-M
+ */
+ void AddFunction(const PrimFunc& prim_func) { CodeGenC::AddFunction(prim_func); }
+
private:
/*! * \brief Enable storing the last error */
bool debug_last_error;
@@ -568,11 +575,11 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
bool emit_fwd_func_decl = false;
bool debug_last_error = GetCompilerAttrs()->debug_last_error;
CodeGenCMSISNN codegen;
+ Array<String> function_names;
codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), debug_last_error);
-
- std::vector<std::pair<tvm::GlobalVar, tvm::PrimFunc>> funcs;
- for (auto [gvar, base_func] : mod->functions) {
- funcs.push_back({gvar, Downcast<PrimFunc>(base_func)});
+ std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> funcs;
+ for (auto kv : mod->functions) {
+ funcs.push_back(kv);
}
std::sort(funcs.begin(), funcs.end(),
@@ -587,16 +594,13 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
return name_hint_a < name_hint_b;
});
- for (auto [gvar, prim_func] : funcs) {
- codegen.AddFunction(gvar, prim_func);
+ for (auto kv : funcs) {
+ auto prim_func = Downcast<PrimFunc>(kv.second);
+ auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ function_names.push_back(global_symbol.value());
+ codegen.AddFunction(prim_func);
}
std::string code = codegen.Finish();
-
- Array<String> function_names;
- for (auto [gvar, prim_func] : funcs) {
- function_names.push_back(codegen.GetFunctionName(gvar));
- }
-
return codegen::CSourceModuleCreate(code, "c", function_names);
}
diff --git a/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc b/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc
index 6f09e0a0c3..0db8d06c31 100644
--- a/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc
+++ b/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc
@@ -49,30 +49,16 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
bool emit_asserts = false;
bool emit_fwd_func_decl = false;
CodeGenExampleTargetHook codegen;
-
+ Array<String> function_names;
std::unordered_set<std::string> devices;
codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices);
-
- Map<GlobalVar, PrimFunc> functions;
- for (auto [gvar, base_func] : mod->functions) {
- auto prim_func = Downcast<PrimFunc>(base_func);
- functions.Set(gvar, prim_func);
- }
-
- for (auto [gvar, prim_func] : functions) {
- codegen.DeclareFunction(gvar, prim_func);
- }
- for (auto [gvar, prim_func] : functions) {
- codegen.AddFunction(gvar, prim_func, emit_fwd_func_decl);
+ for (auto kv : mod->functions) {
+ auto prim_func = Downcast<PrimFunc>(kv.second);
+ auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ function_names.push_back(global_symbol.value());
+ codegen.AddFunction(prim_func);
}
-
std::string code = codegen.Finish();
-
- Array<String> function_names;
- for (auto [gvar, prim_func] : functions) {
- function_names.push_back(codegen.GetFunctionName(gvar));
- }
-
return codegen::CSourceModuleCreate(code, "c", function_names);
}
diff --git a/src/relay/backend/contrib/uma/tir_to_runtime.cc b/src/relay/backend/contrib/uma/tir_to_runtime.cc
index 487e247f5d..3b58fda54b 100644
--- a/src/relay/backend/contrib/uma/tir_to_runtime.cc
+++ b/src/relay/backend/contrib/uma/tir_to_runtime.cc
@@ -49,6 +49,13 @@ class UMACodegen : public codegen::CodeGenCHost {
CodeGenCHost::Init(output_ssa, emit_asserts, emit_fwd_func_decl, target_str_, devices);
}
+ /*!
+ * \brief Emit code that offloads a subgraph to the UMA target
+ *
+ * \return string of code that offloads a subgraph to the UMA target
+ */
+ void AddFunction(const PrimFunc& prim_func) { CodeGenC::AddFunction(prim_func); }
+
private:
String target_str_;
};
@@ -56,30 +63,17 @@ class UMACodegen : public codegen::CodeGenCHost {
runtime::Module TIRToRuntime(IRModule mod, Target target) {
bool output_ssa = false;
bool emit_asserts = false;
- bool emit_fwd_func_decl = true;
+ bool emit_fwd_func_decl = false;
UMACodegen codegen(target->kind->name);
+ Array<String> function_names;
codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl);
-
- Map<GlobalVar, PrimFunc> functions;
- for (auto [gvar, base_func] : mod->functions) {
- auto prim_func = Downcast<PrimFunc>(base_func);
- functions.Set(gvar, prim_func);
- }
-
- for (auto [gvar, prim_func] : functions) {
- codegen.DeclareFunction(gvar, prim_func);
- }
- for (auto [gvar, prim_func] : functions) {
- codegen.AddFunction(gvar, prim_func, emit_fwd_func_decl);
+ for (auto kv : mod->functions) {
+ auto prim_func = Downcast<PrimFunc>(kv.second);
+ auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ function_names.push_back(global_symbol.value());
+ codegen.AddFunction(prim_func);
}
-
std::string code = codegen.Finish();
-
- Array<String> function_names;
- for (auto [gvar, prim_func] : functions) {
- function_names.push_back(codegen.GetFunctionName(gvar));
- }
-
return codegen::CSourceModuleCreate(code, "c", function_names);
}
diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc
index e0f53e3509..1c0b5094ef 100644
--- a/src/target/opt/build_cuda_on.cc
+++ b/src/target/opt/build_cuda_on.cc
@@ -131,21 +131,13 @@ runtime::Module BuildCUDA(IRModule mod, Target target) {
CodeGenCUDA cg;
cg.Init(output_ssa);
- Map<GlobalVar, PrimFunc> functions;
- for (auto [gvar, base_func] : mod->functions) {
- ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc";
- auto prim_func = Downcast<PrimFunc>(base_func);
- auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
+ for (auto kv : mod->functions) {
+ ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc";
+ auto f = Downcast<PrimFunc>(kv.second);
+ auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
- functions.Set(gvar, prim_func);
- }
-
- for (auto [gvar, prim_func] : functions) {
- cg.DeclareFunction(gvar, prim_func);
- }
- for (auto [gvar, prim_func] : functions) {
- cg.AddFunction(gvar, prim_func);
+ cg.AddFunction(f);
}
std::string code = cg.Finish();
diff --git a/src/target/source/codegen_aocl.cc b/src/target/source/codegen_aocl.cc
index dc3ba08751..700d85b4cc 100644
--- a/src/target/source/codegen_aocl.cc
+++ b/src/target/source/codegen_aocl.cc
@@ -40,22 +40,13 @@ runtime::Module BuildAOCL(IRModule mod, Target target, bool emulation) {
CodeGenOpenCL cg;
cg.Init(output_ssa);
- Map<GlobalVar, PrimFunc> functions;
- for (auto [gvar, base_func] : mod->functions) {
- ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodegenOpenCL: Can only take PrimFunc";
- auto prim_func = Downcast<PrimFunc>(base_func);
- auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
+ for (auto kv : mod->functions) {
+ ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodegenOpenCL: Can only take PrimFunc";
+ auto f = Downcast<PrimFunc>(kv.second);
+ auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodegenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
- functions.Set(gvar, prim_func);
- }
-
- for (auto [gvar, prim_func] : functions) {
- cg.DeclareFunction(gvar, prim_func);
- }
-
- for (auto [gvar, prim_func] : functions) {
- cg.AddFunction(gvar, prim_func);
+ cg.AddFunction(f);
}
std::string code = cg.Finish();
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 187bdc74fe..a7cc320562 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -42,7 +42,6 @@ void CodeGenC::InitFuncState(const PrimFunc& f) {
alloc_storage_scope_.clear();
handle_data_type_.clear();
CodeGenSourceBase::ClearFuncState();
- ReserveKeywordsAsUnique();
}
void CodeGenC::ReserveKeywordsAsUnique() {
@@ -76,92 +75,51 @@ void CodeGenC::ReserveKeywordsAsUnique() {
name_supply_->ReserveName("return");
}
-void CodeGenC::PrintFunctionSignature(const String& function_name, const PrimFunc& func,
- std::ostream& os) {
- PrintFuncPrefix(os);
- PrintType(func->ret_type, os);
- PrintExtraAttrs(func, os);
- os << " " << function_name << "(";
- for (size_t i = 0; i < func->params.size(); ++i) {
- tir::Var v = func->params[i];
-
- if (i > 0) {
- os << ", ";
- }
-
- if (auto it = alloc_storage_scope_.find(v.get()); it != alloc_storage_scope_.end()) {
- PrintStorageScope(it->second, os);
- }
-
- PrintType(GetType(v), os);
-
- bool no_alias = func->HasNonzeroAttr(tir::attr::kNoAlias);
- bool is_handle = v.dtype().is_handle();
- if (no_alias && is_handle) {
- PrintRestrict(v, os);
- }
-
- os << " " << AllocVarID(v.get());
- }
- os << ")";
+void CodeGenC::AddFunction(const PrimFunc& f) {
+ // clear previous generated state.
+ this->InitFuncState(f);
+ // reserve keywords
+ ReserveKeywordsAsUnique();
- // Register handle data type
- // TODO(tvm-team): consider simply keep type info in the
- // type annotation(via a normalizing rewriting).
- for (const auto& param : func->params) {
- if (auto* ptr = param->type_annotation.as<PointerTypeNode>()) {
- if (auto* prim = ptr->element_type.as<PrimTypeNode>()) {
- RegisterHandleType(param.get(), prim->dtype);
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ ICHECK(global_symbol.defined())
+ << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
+ bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
+
+ this->PrintFuncPrefix(stream);
+ PrintType(f->ret_type, stream);
+ this->PrintExtraAttrs(f);
+ this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";
+
+ for (size_t i = 0; i < f->params.size(); ++i) {
+ tir::Var v = f->params[i];
+ std::string vid = AllocVarID(v.get());
+ if (i != 0) stream << ", ";
+ if (v.dtype().is_handle()) {
+ auto it = alloc_storage_scope_.find(v.get());
+ if (it != alloc_storage_scope_.end()) {
+ PrintStorageScope(it->second, stream);
}
- }
- }
-}
-void CodeGenC::DeclareFunction(const GlobalVar& gvar, const PrimFunc& func) {
- if (internal_functions_.count(gvar)) {
- return;
- }
+ PrintType(GetType(v), stream);
+ // Register handle data type
+ // TODO(tvm-team): consider simply keep type info in the
+ // type annotation(via a normalizing rewriting).
+ if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) {
+ if (auto* prim = ptr->element_type.as<PrimTypeNode>()) {
+ RegisterHandleType(v.get(), prim->dtype);
+ }
+ }
- auto function_name = [&]() -> String {
- if (auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
- auto name = global_symbol.value();
- ICHECK(!func_name_supply_->ContainsName(name))
- << "Function " << gvar << " must use global symbol " << name
- << ", but this name has already been used.";
- func_name_supply_->ReserveName(name);
- return name;
+ if (no_alias) {
+ PrintRestrict(v, stream);
+ }
} else {
- func_name_supply_->ReserveName(gvar->name_hint);
- return gvar->name_hint;
+ PrintType(GetType(v), stream);
}
- }();
-
- internal_functions_.insert({gvar, function_name});
-
- InitFuncState(func);
- PrintFunctionSignature(function_name, func, fwd_decl_stream);
- fwd_decl_stream << ";\n";
-}
-
-String CodeGenC::GetFunctionName(const GlobalVar& gvar) {
- auto it = internal_functions_.find(gvar);
- ICHECK(it != internal_functions_.end())
- << "Attempted to find name of " << gvar
- << ", but no function with this GlobalVar has been declared";
- return it->second;
-}
-
-void CodeGenC::AddFunction(const GlobalVar& gvar, const PrimFunc& f) {
- // If the function has already been forward-declared, this is a
- // no-op.
- DeclareFunction(gvar, f);
- auto function_name = GetFunctionName(gvar);
-
- // clear previous generated state.
- InitFuncState(f);
-
- PrintFunctionSignature(function_name, f, stream);
- stream << " {\n";
+ stream << ' ' << vid;
+ }
+ stream << ") {\n";
this->PreFunctionBody(f);
int func_scope = this->BeginScope();
this->PrintStmt(f->body);
@@ -172,15 +130,9 @@ void CodeGenC::AddFunction(const GlobalVar& gvar, const PrimFunc& f) {
void CodeGenC::PrintFuncPrefix(std::ostream& os) {}
-void CodeGenC::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) {}
+void CodeGenC::PrintExtraAttrs(const PrimFunc& f) {}
-std::string CodeGenC::Finish() {
- std::ostringstream code;
- code << decl_stream.str();
- code << fwd_decl_stream.str();
- code << stream.str();
- return code.str();
-}
+std::string CodeGenC::Finish() { return decl_stream.str() + stream.str(); }
void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) { // NOLINT(*)
if (print_ssa_form_) {
@@ -590,17 +542,12 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
ICHECK_GE(op->args.size(), 1U);
auto func = Downcast<StringImm>(op->args[0]);
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), func->value, op->args, true, os);
-
- // If the call_extern refers to an function within the IRModule, then
- // the forward declaration is already provided from DeclareFunction.
- if (!func_name_supply_->ContainsName(func->value)) {
- Array<Type> arg_types;
- for (size_t i = 1; i < op->args.size(); i++) {
- arg_types.push_back(GetType(op->args[i]));
- }
- Type ret_type = GetTypeFromRuntimeDataType(op->dtype);
- this->GenerateForwardFunctionDeclarations(func->value, arg_types, ret_type);
+ Array<Type> arg_types;
+ for (size_t i = 1; i < op->args.size(); i++) {
+ arg_types.push_back(GetType(op->args[i]));
}
+ Type ret_type = GetTypeFromRuntimeDataType(op->dtype);
+ this->GenerateForwardFunctionDeclarations(func->value, arg_types, ret_type);
} else if (op_attr_global_symbol_.count(call_op)) {
// call extern if the op itself have a global symbol.
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_attr_global_symbol_[call_op],
@@ -668,13 +615,9 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
} else {
LOG(FATAL) << "Unresolved call " << op->op;
}
- } else if (auto opt = op->op.as<GlobalVar>()) {
- auto gvar = opt.value();
- auto callee_name = GetFunctionName(gvar);
- PrintCallExtern(GetType(GetRef<PrimExpr>(op)), callee_name, op->args, false, os);
} else {
- LOG(FATAL) << "CodeGenC: Unknown operation " << op->op << " is neither a recognized built-in, "
- << "nor a GlobalVar reference to another function in the IRModule";
+ ICHECK(op->op.as<GlobalVarNode>());
+ LOG(FATAL) << "Do not yet support cross function call";
}
}
diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h
index 2921a56ef3..93f9ea519c 100644
--- a/src/target/source/codegen_c.h
+++ b/src/target/source/codegen_c.h
@@ -65,33 +65,12 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
* \param output_ssa Whether output SSA.
*/
void Init(bool output_ssa);
-
/*!
- * \brief Add the function declaration to the generated module,
- * without defining it.
- *
- * \param gvar The GlobalVar representing the function.
- * \param func The function to be compiled.
+ * \brief Add the function to the generated module.
+ * \param f The function to be compiled.
* \param whether to append return 0 in the end.
*/
- virtual void DeclareFunction(const GlobalVar& gvar, const PrimFunc& func);
-
- /*!
- * \brief Add the function to the generated module, including its
- * declaration and definition.
- *
- * \param gvar The GlobalVar representing the function.
- * \param func The function to be compiled.
- */
- virtual void AddFunction(const GlobalVar& gvar, const PrimFunc& func);
-
- /*!
- * \brief Get the name of a declared function
- * \param gvar The GlobalVar of the function
- * \returns The string name of the function
- */
- String GetFunctionName(const GlobalVar& gvar);
-
+ void AddFunction(const PrimFunc& f);
/*!
* \brief Finalize the compilation and return the code.
* \return The code.
@@ -117,23 +96,7 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
PrintExpr(n, os);
return os.str();
}
-
// The following parts are overloadable print operations.
-
- /*! \brief Print the function signature before the argument list
- *
- * The default implementation delegates out to PrintFuncPrefix and
- * PrintExtraAttrs.
- *
- * \param function_name The name of the function
- *
- * \param func The function whose signature should be printed
- *
- * \param os The output stream
- */
- virtual void PrintFunctionSignature(const String& function_name, const PrimFunc& func,
- std::ostream& os);
-
/*!
* \brief Print the function header before the argument list
* \param os The output stream
@@ -146,7 +109,7 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
*
* Example: __launch_bounds__(256) for CUDA functions
*/
- virtual void PrintExtraAttrs(const PrimFunc& f, std::ostream& os); // NOLINT(*)
+ virtual void PrintExtraAttrs(const PrimFunc& f);
/*!
* \brief Insert statement before function body.
* \param f The function to be compiled.
@@ -321,24 +284,10 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
private:
/*! \brief set of volatile buf access */
std::unordered_set<const VarNode*> volatile_buf_;
-
// deep comparison of PrimExpr
ExprDeepEqual deep_equal_;
-
// binding of let variables. Enables duplicate var defs that map to same value
std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
-
- /* \brief Map of GlobalVar to their symbol.
- *
- * For externally-exposed functions, this is given by the
- * tvm::attr::kTarget attribute of the PrimFunc. For internal
- * functions, this is the name of the function's GlobalVar, possibly
- * altered to prevent duplicate names.
- */
- std::unordered_map<GlobalVar, String, ObjectPtrHash, ObjectPtrEqual> internal_functions_;
-
- /* \brief Name supply to generate unique function names */
- NameSupply func_name_supply_{""};
};
} // namespace codegen
diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc
index caef43e8af..3255e11c5d 100644
--- a/src/target/source/codegen_c_host.cc
+++ b/src/target/source/codegen_c_host.cc
@@ -75,24 +75,19 @@ void CodeGenCHost::InitGlobalContext() {
void CodeGenCHost::DefineModuleName() { decl_stream << "void* " << module_name_ << " = NULL;\n"; }
-void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func,
- bool emit_fwd_func_decl) {
- auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
- if (global_symbol) {
- function_names_.push_back(global_symbol.value());
- }
+void CodeGenCHost::AddFunction(const PrimFunc& f, bool emit_fwd_func_decl) {
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ ICHECK(global_symbol.defined())
+ << "CodeGenCHost: Expect PrimFunc to have the global_symbol attribute";
+ function_names_.push_back(global_symbol.value());
emit_fwd_func_decl_ = emit_fwd_func_decl;
- CodeGenC::AddFunction(gvar, func);
- if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
- ICHECK(global_symbol.defined())
- << "CodeGenCHost: The entry func must have the global_symbol attribute, "
- << "but function " << gvar << " only has attributes " << func->attrs;
-
+ CodeGenC::AddFunction(f);
+ if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
function_names_.push_back(runtime::symbol::tvm_module_main);
stream << "// CodegenC: NOTE: Auto-generated entry function\n";
PrintFuncPrefix(stream);
- PrintType(func->ret_type, stream);
+ PrintType(f->ret_type, stream);
stream << " " << tvm::runtime::symbol::tvm_module_main
<< "(void* args, int* arg_type_ids, int num_args, void* out_ret_value, "
<< "int* out_ret_tcode, void* resource_handle) {\n";
@@ -133,6 +128,15 @@ void CodeGenCHost::PrintFuncPrefix(std::ostream& os) { // NOLINT(*)
<< "TVM_DLL ";
}
+std::string CodeGenCHost::Finish() { // NOLINT(*)
+ std::string ret = decl_stream.str();
+ if (emit_fwd_func_decl_) {
+ ret += fwd_decl_stream.str();
+ }
+ ret += stream.str();
+ return ret;
+}
+
void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
@@ -433,38 +437,42 @@ runtime::Module BuildCHost(IRModule mod, Target target) {
CodeGenCHost cg;
cg.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices);
cg.SetConstantsByteAlignment(target->GetAttr<Integer>("constants-byte-alignment").value_or(16));
-
- auto is_aot_executor_fn = [](const PrimFunc& func) -> bool {
- return func->GetAttr<Bool>("runner_function", Bool(false)).value();
- };
-
- std::vector<std::pair<GlobalVar, PrimFunc>> funcs;
- for (auto [gvar, base_func] : mod->functions) {
- ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodegenCHost: Can only take PrimFunc";
- auto prim_func = Downcast<PrimFunc>(base_func);
- funcs.push_back({gvar, prim_func});
+ PrimFunc aot_executor_fn;
+
+ std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> funcs;
+ for (auto kv : mod->functions) {
+ // Make sure that the executor function is the last one to be code generated so that all the
+ // symbols are available to __tvm_main__
+ auto fun_name = std::string(kv.first->name_hint);
+ bool is_aot_executor_fn = kv.second->GetAttr<Bool>("runner_function", Bool(false)).value();
+
+ if (is_aot_executor_fn) {
+ aot_executor_fn = Downcast<PrimFunc>(kv.second);
+ continue;
+ }
+ funcs.push_back(kv);
}
// Sort functions
- auto sort_key = [&is_aot_executor_fn](const auto& kv) {
- return std::tuple{is_aot_executor_fn(kv.second), kv.first->name_hint};
- };
- std::sort(funcs.begin(), funcs.end(), [&sort_key](const auto& kv_a, const auto& kv_b) {
- return sort_key(kv_a) < sort_key(kv_b);
- });
-
- // Declare all functions first. This ensures that all functions,
- // including the __tvm_main__ used in AOT, have access to forward
- // declarations of other functions in the IRModule.
- for (const auto& [gvar, prim_func] : funcs) {
- cg.DeclareFunction(gvar, prim_func);
+ std::sort(funcs.begin(), funcs.end(),
+ [](std::pair<tvm::GlobalVar, tvm::BaseFunc> kv_a,
+ std::pair<tvm::GlobalVar, tvm::BaseFunc> kv_b) {
+ std::string name_hint_a = kv_a.first->name_hint;
+ std::string name_hint_b = kv_b.first->name_hint;
+ return name_hint_a < name_hint_b;
+ });
+
+ // Add all functions except __tvm_main__
+ for (auto& kv : funcs) {
+ ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodegenCHost: Can only take PrimFunc";
+ auto f = Downcast<PrimFunc>(kv.second);
+ cg.AddFunction(f);
}
- // Codegen all functions. Passing emit_fwd_func_decl=true adds a
- // forward declaration for any `builtin::call_extern`, based on the
- // arguments provided to it.
- for (const auto& [gvar, prim_func] : funcs) {
- cg.AddFunction(gvar, prim_func, emit_fwd_func_decl);
+ // Add __tvm_main__
+ if (aot_executor_fn.defined()) {
+ emit_fwd_func_decl = true;
+ cg.AddFunction(aot_executor_fn, emit_fwd_func_decl);
}
// NOTE: it's possible that kRuntime attr is not attached when the mod was built with tvm.build().
@@ -476,10 +484,7 @@ runtime::Module BuildCHost(IRModule mod, Target target) {
} else {
runtime = relay::Runtime::Create("cpp", {});
}
-
- bool has_aot_executor_fn = std::any_of(
- funcs.begin(), funcs.end(), [&](const auto& kv) { return is_aot_executor_fn(kv.second); });
- if (has_aot_executor_fn && runtime->name == relay::kTvmRuntimeCpp) {
+ if (aot_executor_fn.defined() && runtime->name == relay::kTvmRuntimeCpp) {
cg.InitGlobalContext();
}
diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h
index aeba685f74..694104afc0 100644
--- a/src/target/source/codegen_c_host.h
+++ b/src/target/source/codegen_c_host.h
@@ -44,7 +44,8 @@ class CodeGenCHost : public CodeGenC {
const std::unordered_set<std::string>& devices);
void InitGlobalContext();
- void AddFunction(const GlobalVar& gvar, const PrimFunc& f, bool emit_fwd_func_decl = false);
+ void AddFunction(const PrimFunc& f, bool emit_fwd_func_decl = false);
+ std::string Finish() final;
/*!
* \brief Add functions from the (unordered) range to the current module in a deterministic
* order. This helps with debugging.
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 7639ce6065..a91f8b0164 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -75,7 +75,7 @@ class ThreadIdxExtractor : public tir::StmtVisitor {
PrimExpr threadIdx_z_ext = Integer(1);
};
-void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) {
+void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f) {
ThreadIdxExtractor extractor;
extractor(f->body);
arith::Analyzer analyzer;
@@ -86,7 +86,7 @@ void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) {
// unable to extract the number of threads per block, hence directly return
return;
}
- os << " __launch_bounds__(" << threadIdx_ext_int->value << ")";
+ stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")";
}
}
diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h
index bc7b34b500..3ec0c3bc2d 100644
--- a/src/target/source/codegen_cuda.h
+++ b/src/target/source/codegen_cuda.h
@@ -47,7 +47,7 @@ class CodeGenCUDA final : public CodeGenC {
}
// override behavior
void PrintFuncPrefix(std::ostream& os) final;
- void PrintExtraAttrs(const PrimFunc& f, std::ostream& os) final; // NOLINT(*)
+ void PrintExtraAttrs(const PrimFunc& f) final;
void VisitStmt_(const ForNode* op) final;
void PrintStorageSync(const CallNode* op) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc
index 3db8d216b3..b8c30691e2 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -36,8 +36,6 @@ namespace codegen {
void CodeGenMetal::InitFuncState(const PrimFunc& f) {
CodeGenC::InitFuncState(f);
- // skip the first underscore, so SSA variable starts from _1
- name_supply_->FreshName("v_");
// analyze the data;
for (Var arg : f->params) {
if (arg.dtype().is_handle()) {
@@ -54,33 +52,37 @@ CodeGenMetal::CodeGenMetal(Target target) : target_(target) {
<< "};\n\n";
}
-void CodeGenMetal::PrintFunctionSignature(const String& function_name, const PrimFunc& func,
- std::ostream& os) {
+void CodeGenMetal::AddFunction(const PrimFunc& f) {
+ // clear previous generated state.
+ this->InitFuncState(f);
+ // skip the first underscore, so SSA variable starts from _1
+ name_supply_->FreshName("v_");
+
// add to alloc buffer type.
- auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
// Function header.
- os << "kernel void " << static_cast<std::string>(global_symbol.value()) << "(";
+ this->stream << "kernel void " << static_cast<std::string>(global_symbol.value()) << "(";
// Buffer arguments
size_t num_buffer = 0;
size_t limit = target_->GetAttr<Integer>("max_function_args").value().IntValue();
- if (func->params.size() > limit) {
+ if (f->params.size() > limit) {
LOG(WARNING) << "Probably you won't be able to execute your kernel due to high number of "
"buffers in the kernel";
}
- for (size_t i = 0; i < func->params.size(); ++i, ++num_buffer) {
- Var v = func->params[i];
+ for (size_t i = 0; i < f->params.size(); ++i, ++num_buffer) {
+ Var v = f->params[i];
if (!v.dtype().is_handle()) break;
- os << " ";
+ stream << " ";
std::string vid = AllocVarID(v.get());
auto it = alloc_storage_scope_.find(v.get());
if (it != alloc_storage_scope_.end()) {
- PrintStorageScope(it->second, os);
+ PrintStorageScope(it->second, stream);
}
- PrintType(GetType(v), os);
+ PrintType(GetType(v), stream);
// Register handle data type
// TODO(tvm-team): consider simply keep type info in the
// type annotation(via a normalizing rewriting).
@@ -89,18 +91,19 @@ void CodeGenMetal::PrintFunctionSignature(const String& function_name, const Pri
RegisterHandleType(v.get(), prim->dtype);
}
}
- os << ' ' << vid << " [[ buffer(" << i << ") ]],\n";
+ stream << ' ' << vid << " [[ buffer(" << i << ") ]],\n";
}
// Setup normal arguments.
- size_t nargs = func->params.size() - num_buffer;
+ size_t nargs = f->params.size() - num_buffer;
std::string varg = name_supply_->FreshName("arg");
if (nargs != 0) {
std::string arg_buf_type = static_cast<std::string>(global_symbol.value()) + "_args_t";
- os << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer << ") ]],\n";
+ stream << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer
+ << ") ]],\n";
// declare the struct
decl_stream << "struct " << arg_buf_type << " {\n";
- for (size_t i = num_buffer; i < func->params.size(); ++i) {
- Var v = func->params[i];
+ for (size_t i = num_buffer; i < f->params.size(); ++i) {
+ Var v = f->params[i];
ICHECK(!v.dtype().is_handle());
std::string vid = AllocVarID(v.get());
std::ostringstream vref;
@@ -128,7 +131,7 @@ void CodeGenMetal::PrintFunctionSignature(const String& function_name, const Pri
ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
int work_dim = 0;
- auto launch_params = func->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams).value();
+ auto launch_params = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams).value();
for (const auto& tag : launch_params) {
if (tag != runtime::launch_param::kUseDynamicSharedMemoryTag) {
runtime::ThreadScope scope = runtime::ThreadScope::Create(tag);
@@ -147,7 +150,13 @@ void CodeGenMetal::PrintFunctionSignature(const String& function_name, const Pri
}
thread_work_dim_ = work_dim;
- stream << ")";
+ // the function scope.
+ stream << ") {\n";
+ int func_scope = this->BeginScope();
+ this->PrintStmt(f->body);
+ this->EndScope(func_scope);
+ this->PrintIndent();
+ this->stream << "}\n\n";
}
void CodeGenMetal::BindThreadIndex(const IterVar& iv) {
@@ -333,33 +342,27 @@ runtime::Module BuildMetal(IRModule mod, Target target) {
const auto* fmetal_compile = Registry::Get("tvm_callback_metal_compile");
std::string fmt = fmetal_compile ? "metallib" : "metal";
- Map<GlobalVar, PrimFunc> functions;
- for (auto [gvar, base_func] : mod->functions) {
- ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenMetal: Can only take PrimFunc";
- auto calling_conv = base_func->GetAttr<Integer>(tvm::attr::kCallingConv);
- ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
- << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
-
- auto prim_func = Downcast<PrimFunc>(base_func);
- functions.Set(gvar, prim_func);
- }
+ for (auto kv : mod->functions) {
+ ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenMetal: Can only take PrimFunc";
+ auto global_symbol = kv.second->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ ICHECK(global_symbol.defined());
+ std::string func_name = global_symbol.value();
- for (auto [gvar, prim_func] : functions) {
- source_maker << "// Function: " << gvar->name_hint << "\n";
+ source_maker << "// Function: " << func_name << "\n";
CodeGenMetal cg(target);
cg.Init(output_ssa);
+ auto f = Downcast<PrimFunc>(kv.second);
+ auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
+ ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
+ << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
- for (auto [other_gvar, other_prim_func] : functions) {
- cg.DeclareFunction(other_gvar, other_prim_func);
- }
- cg.AddFunction(gvar, prim_func);
-
+ cg.AddFunction(f);
std::string fsource = cg.Finish();
source_maker << fsource << "\n";
if (fmetal_compile) {
fsource = (*fmetal_compile)(fsource, target).operator std::string();
}
- smap[cg.GetFunctionName(gvar)] = fsource;
+ smap[func_name] = fsource;
}
return MetalModuleCreate(smap, ExtractFuncInfo(mod), fmt, source_maker.str());
diff --git a/src/target/source/codegen_metal.h b/src/target/source/codegen_metal.h
index 26c991e60d..36be10d163 100644
--- a/src/target/source/codegen_metal.h
+++ b/src/target/source/codegen_metal.h
@@ -38,8 +38,7 @@ class CodeGenMetal final : public CodeGenC {
explicit CodeGenMetal(Target target);
// override print thread tag.
void PrintArgUnionDecl();
- void PrintFunctionSignature(const String& function_name, const PrimFunc& func,
- std::ostream& os) override;
+ void AddFunction(const PrimFunc& f); // NOLINT(*)
void InitFuncState(const PrimFunc& f) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const CallNode* op) final; // NOLINT(*)
diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc
index da6a4de619..c15d2253d7 100644
--- a/src/target/source/codegen_opencl.cc
+++ b/src/target/source/codegen_opencl.cc
@@ -595,26 +595,18 @@ runtime::Module BuildOpenCL(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
- Map<GlobalVar, PrimFunc> functions;
- for (auto [gvar, base_func] : mod->functions) {
- ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenOpenCL: Can only take PrimFunc";
- auto prim_func = Downcast<PrimFunc>(base_func);
- auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
- ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
- << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
- functions.Set(gvar, prim_func);
- }
-
std::stringstream code;
const auto* fpostproc = Registry::Get("tvm_callback_opencl_postproc");
- for (auto [gvar, prim_func] : functions) {
- code << "// Function: " << gvar->name_hint << std::endl;
+ for (auto kv : mod->functions) {
+ ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenOpenCL: Can only take PrimFunc";
+ code << "// Function: " << kv.first->name_hint << std::endl;
CodeGenOpenCL cg;
cg.Init(output_ssa);
- for (auto [other_gvar, other_prim_func] : functions) {
- cg.DeclareFunction(other_gvar, other_prim_func);
- }
- cg.AddFunction(gvar, prim_func);
+ auto f = Downcast<PrimFunc>(kv.second);
+ auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
+ ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
+ << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
+ cg.AddFunction(f);
std::string fsource = cg.Finish();
if (fpostproc) {
fsource = (*fpostproc)(fsource, target).operator std::string();
diff --git a/src/target/source/codegen_vhls.cc b/src/target/source/codegen_vhls.cc
index aa7a32320c..83046de107 100644
--- a/src/target/source/codegen_vhls.cc
+++ b/src/target/source/codegen_vhls.cc
@@ -145,21 +145,13 @@ runtime::Module BuildSDAccel(IRModule mod, Target target) {
// Generate source code for get_source().
cg.Init(output_ssa);
- Map<GlobalVar, PrimFunc> functions;
- for (auto [gvar, base_func] : mod->functions) {
- ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenVHLS: Can only take PrimFunc";
- auto prim_func = Downcast<PrimFunc>(base_func);
- auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
+ for (auto kv : mod->functions) {
+ ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenVHLS: Can only take PrimFunc";
+ auto f = Downcast<PrimFunc>(kv.second);
+ auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenVLHS: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
- functions.Set(gvar, prim_func);
- }
-
- for (auto [gvar, prim_func] : functions) {
- cg.DeclareFunction(gvar, prim_func);
- }
- for (auto [gvar, prim_func] : functions) {
- cg.AddFunction(gvar, prim_func);
+ cg.AddFunction(f);
}
std::string whole_code = cg.Finish();
@@ -167,21 +159,21 @@ runtime::Module BuildSDAccel(IRModule mod, Target target) {
// Generate source code for compilation.
Array<Array<runtime::String>> kernel_info;
- for (auto [gvar, prim_func] : functions) {
+ for (auto kv : mod->functions) {
+ ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenOpenCL: Can only take PrimFunc";
+ auto f = Downcast<PrimFunc>(kv.second);
CodeGenVivadoHLS cg;
cg.Init(output_ssa);
-
- for (auto [other_gvar, other_prim_func] : functions) {
- cg.DeclareFunction(other_gvar, other_prim_func);
- }
- cg.AddFunction(gvar, prim_func);
+ cg.AddFunction(f);
std::string code = cg.Finish();
if (const auto* f = runtime::Registry::Get("tvm_callback_vhls_postproc")) {
code = (*f)(code, target).operator std::string();
}
- auto function_name = cg.GetFunctionName(gvar);
- kernel_info.push_back({function_name, code});
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ ICHECK(global_symbol.defined())
+ << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
+ kernel_info.push_back({global_symbol.value(), code});
}
std::string xclbin;
diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc
index 6a6712a4ce..4d1d834c7f 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -45,12 +45,6 @@ std::string CodeGenWebGPU::Finish() {
void CodeGenWebGPU::InitFuncState(const PrimFunc& f) {
CodeGenC::InitFuncState(f);
- // skip the first underscore, so SSA variable starts from
- name_supply_->FreshName("v_");
- // Setup the thread group info.
- ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
- ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
-
// analyze the data;
for (Var arg : f->params) {
if (arg.dtype().is_handle()) {
@@ -62,12 +56,28 @@ void CodeGenWebGPU::InitFuncState(const PrimFunc& f) {
CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {}
-void CodeGenWebGPU::PrintFunctionSignature(const String& function_name, const PrimFunc& func,
- std::ostream& os) {
+void CodeGenWebGPU::AddFunction(const PrimFunc& f) {
+ // clear previous generated state.
+ this->InitFuncState(f);
+ // skip the first underscore, so SSA variable starts from
+ name_supply_->FreshName("v_");
+ // Setup the thread group info.
+ ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
+ ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
+
+ // add to alloc buffer type.
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ ICHECK(global_symbol.defined())
+ << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute";
+
+ decl_stream << "//----------------------------------------\n"
+ << "// function: " << global_symbol.value() << "\n"
+ << "//----------------------------------------\n";
+
std::vector<Var> pod_args;
int num_buffer = 0;
// setup buffer argumemts
- for (Var arg : func->params) {
+ for (Var arg : f->params) {
DataType t = arg.dtype();
if (t.is_handle()) {
auto* ptr = arg->type_annotation.as<PointerTypeNode>();
@@ -101,18 +111,16 @@ void CodeGenWebGPU::PrintFunctionSignature(const String& function_name, const Pr
}
// add to alloc buffer type.
// Function header.
- os << "fn main(\n"
- << " @builtin(workgroup_id) blockIdx : vec3<u32>,\n"
- << " @builtin(local_invocation_id) threadIdx : vec3<u32>\n"
- << ")";
-}
-
-void CodeGenWebGPU::AddFunction(const GlobalVar& gvar, const PrimFunc& func) {
- CodeGenC::AddFunction(gvar, func);
- decl_stream << "//----------------------------------------\n"
- << "// function: " << GetFunctionName(gvar) << "\n"
- << "//----------------------------------------\n";
-
+ this->stream << "fn main(\n"
+ << " @builtin(workgroup_id) blockIdx : vec3<u32>,\n"
+ << " @builtin(local_invocation_id) threadIdx : vec3<u32>\n"
+ << ") {\n";
+ // the function scope.
+ int func_scope = this->BeginScope();
+ this->PrintStmt(f->body);
+ this->EndScope(func_scope);
+ this->PrintIndent();
+ this->stream << "}\n\n";
// anotate workgroup
this->fwd_decl_stream << "@compute @workgroup_size(" << workgroup_size_[0] << ", "
<< workgroup_size_[1] << ", " << workgroup_size_[2] << ")\n";
@@ -516,31 +524,22 @@ runtime::Module BuildWebGPU(IRModule mod, Target target) {
mod = tir::transform::PointerValueTypeRewrite()(std::move(mod));
bool output_ssa = false;
- Map<GlobalVar, PrimFunc> functions;
- for (auto [gvar, base_func] : mod->functions) {
- ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenWebGPU: Can only take PrimFunc";
- auto prim_func = Downcast<PrimFunc>(base_func);
- auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
+ std::unordered_map<std::string, std::string> smap;
+ for (auto kv : mod->functions) {
+ CodeGenWebGPU cg(target);
+ ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenWebGPU: Can only take PrimFunc";
+ auto f = Downcast<PrimFunc>(kv.second);
+ auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenWebGPU: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
- auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined())
<< "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute";
- functions.Set(gvar, prim_func);
- }
-
- std::unordered_map<std::string, std::string> smap;
- for (auto [gvar, prim_func] : functions) {
- CodeGenWebGPU cg(target);
+ std::string f_name = global_symbol.value();
cg.Init(output_ssa);
-
- for (auto [other_gvar, other_prim_func] : functions) {
- cg.DeclareFunction(other_gvar, other_prim_func);
- }
- cg.AddFunction(gvar, prim_func);
-
+ cg.AddFunction(f);
std::string code = cg.Finish();
- smap[cg.GetFunctionName(gvar)] = code;
+ smap[f_name] = code;
}
auto n = make_object<WebGPUSourceModuleNode>(smap, ExtractFuncInfo(mod));
return runtime::Module(n);
diff --git a/src/target/source/codegen_webgpu.h b/src/target/source/codegen_webgpu.h
index 6ae942a3ad..57f226ba8a 100644
--- a/src/target/source/codegen_webgpu.h
+++ b/src/target/source/codegen_webgpu.h
@@ -48,9 +48,7 @@ class CodeGenWebGPU final : public CodeGenC {
explicit CodeGenWebGPU(Target target);
// overrides
std::string Finish() final;
- void PrintFunctionSignature(const String& function_name, const PrimFunc& func,
- std::ostream& os) final;
- void AddFunction(const GlobalVar& gvar, const PrimFunc& f) final;
+ void AddFunction(const PrimFunc& f); // NOLINT(*)
void InitFuncState(const PrimFunc& f) final;
void PrintStorageSync(const CallNode* op) final; // NOLINT(*)
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc
index 90640a6db6..a6f4b5bb3e 100644
--- a/src/target/source/source_module.cc
+++ b/src/target/source/source_module.cc
@@ -613,14 +613,12 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
}
for (const tir::Var& pool_var : metadata_->pools) {
- call_args_ss << "((uint8_t*)";
String pool_name = metadata_->pool_inputs.value()[pool_var]->pool_info->pool_name;
if (IsInternalWorkspaceBuffer(pool_var)) {
- call_args_ss << "&" << pool_name;
+ call_args_ss << "&" << pool_name << ",";
} else {
- call_args_ss << "workspace_pools->" << tvm::runtime::SanitizeName(pool_name);
+ call_args_ss << "workspace_pools->" << tvm::runtime::SanitizeName(pool_name) << ",";
}
- call_args_ss << "),";
}
for (const String& device : metadata_->devices) {
call_args_ss << "devices->" << device << ",";
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index fd14f48921..39214c4546 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -70,32 +70,6 @@ Type GetType(const PrimExpr& expr) {
return ptr->type_annotation;
}
}
-
- if (auto* access = expr.as<tir::CallNode>()) {
- if (access->op.same_as(builtin::tvm_access_ptr())) {
- ICHECK(access->args.size()) << "Builtin tvm_access_ptr() may not have empty arguments";
- auto type_annotation = Downcast<Call>(access->args[0]);
- static auto builtin_op = Op::Get("tir.type_annotation");
- ICHECK(type_annotation->op.same_as(builtin_op))
- << "Expected the first argument of builtin tvm_access_ptr() "
- << "to be a type annotation, but found " << type_annotation->op;
- return PointerType(PrimType(type_annotation->dtype));
- }
- }
-
- if (auto* address_of = expr.as<tir::CallNode>()) {
- if (address_of->op.same_as(builtin::address_of())) {
- ICHECK_EQ(address_of->args.size(), 1)
- << "Builtin address_of() expects a single argument, but received arguments "
- << address_of->args;
- auto* address = address_of->args[0].as<BufferLoadNode>();
- ICHECK(address)
- << "Builtin address_of() expects the argument to be a BufferLoad, but received argument "
- << address_of->args[0];
-
- return PointerType(PrimType(address->dtype));
- }
- }
// Default: return the type indicated by the dtype.
runtime::DataType dtype = expr.dtype();
return GetTypeFromRuntimeDataType(dtype);
diff --git a/tests/python/relay/aot/test_crt_forward_declarations.py b/tests/python/relay/aot/test_crt_forward_declarations.py
index 99e2f0c923..72305b0030 100644
--- a/tests/python/relay/aot/test_crt_forward_declarations.py
+++ b/tests/python/relay/aot/test_crt_forward_declarations.py
@@ -160,8 +160,8 @@ def test_internal_calls(interface_api, use_unpacked_api, test_runner):
lib_mod = compiled_models[0].executor_factory.lib.imported_modules[0]
main_source = lib_mod.get_source()
- assert main_source.count("int32_t tvmgen_default_fused_nn_contrib_depthwise_conv2d_NCHWc") == 2
- assert main_source.count("int32_t tvmgen_default_fused_layout_transform") == 6
+ assert main_source.count("int32_t tvmgen_default_fused_nn_contrib_depthwise_conv2d_NCHWc") == 1
+ assert main_source.count("int32_t tvmgen_default_fused_layout_transform") == 3
@tvm.testing.requires_corstone300
diff --git a/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py b/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py
index f6145cd1c5..7bea7577b6 100644
--- a/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py
+++ b/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py
@@ -135,13 +135,8 @@ def test_write_3x3_depthwise_code():
#define TENSORDOT_OPT_X1_INT16_W48_3X3_000_EXISTS
#include <arm_acle.h>
__attribute__((always_inline)) static inline int32_t tensordot_opt_x1_int16_w48_3x3_000(
- int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
- int32_t *bias, int32_t *scale
+ int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale
) {
- int32_t *output = output_arg;
- int32_t *tensor = tensor_arg;
- int32_t *kernel = kernel_arg;
-
int32_t sum_0 = *bias;
int32_t tensor__y00_x00__y00_x01 = tensor[0];
@@ -193,13 +188,8 @@ def test_odd_width_3x3_depthwise_strides_code():
#define TENSORDOT_OPT_X2_INT16_W49_3X3_000_2_4_EXISTS
#include <arm_acle.h>
__attribute__((always_inline)) static inline int32_t tensordot_opt_x2_int16_w49_3x3_000_2_4(
- int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
- int32_t *bias, int32_t *scale
+ int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale
) {
- int32_t *output = output_arg;
- int32_t *tensor = tensor_arg;
- int32_t *kernel = kernel_arg;
-
int32_t sum_0 = *bias, sum_1 = *bias;
int32_t tensor__y00_x00__y00_x01 = tensor[0];
@@ -261,13 +251,8 @@ def test_1x1x8_convolution_code():
#define TENSORDOT_OPT_X4_INT16_W384_1X8_000_8_1_EXISTS
#include <arm_acle.h>
__attribute__((always_inline)) static inline int32_t tensordot_opt_x4_int16_w384_1x8_000_8_1(
- int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
- int32_t *bias, int32_t *scale
+ int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale
) {
- int32_t *output = output_arg;
- int32_t *tensor = tensor_arg;
- int32_t *kernel = kernel_arg;
-
int32_t sum_0 = *bias, sum_1 = *bias, sum_2 = *bias, sum_3 = *bias;
int32_t tensor__y00_x00__y00_x01 = tensor[0];
@@ -364,13 +349,8 @@ def test_3x3x3_offset_convolution_code():
#define TENSORDOT_OPT_X1_INT16_W288_3X9_111_EXISTS
#include <arm_acle.h>
__attribute__((always_inline)) static inline int32_t tensordot_opt_x1_int16_w288_3x9_111(
- int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
- int32_t *bias, int32_t *scale
+ int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale
) {
- int32_t *output = output_arg;
- int32_t *tensor = tensor_arg;
- int32_t *kernel = kernel_arg;
-
int32_t sum_0 = *bias;
int32_t tensor__unknown__y00_x00 = tensor[0];
diff --git a/tests/python/unittest/test_target_codegen_c_host.py b/tests/python/unittest/test_target_codegen_c_host.py
index 3aca0fc8c7..d02f8744f1 100644
--- a/tests/python/unittest/test_target_codegen_c_host.py
+++ b/tests/python/unittest/test_target_codegen_c_host.py
@@ -14,15 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
import tvm
import tvm.testing
-
from tvm import te
-from tvm.contrib import utils
-from tvm.script import tir as T, ir as I
-
import numpy as np
+from tvm.contrib import utils
def test_add():
@@ -232,39 +228,11 @@ def test_call_packed():
check_global_packed_func()
-def test_subroutine_call():
- @I.ir_module
- class mod:
- @T.prim_func
- def main(A: T.Buffer(1, dtype="float32")):
- mod.subroutine(A.data)
-
- @T.prim_func(private=True)
- def subroutine(A_data: T.handle("float32")):
- A = T.decl_buffer(1, dtype="float32", data=A_data)
- A[0] = 42.0
-
- built = tvm.build(mod, target="c")
-
- func_names = list(built["get_func_names"]())
- assert (
- "main" in func_names
- ), "Externally exposed functions should be listed in available functions."
- assert (
- "subroutine" not in func_names
- ), "Internal function should not be listed in available functions."
-
- source = built.get_source()
- assert (
- source.count("main(void*") == 2
- ), "Expected two occurrences, for forward-declaration and definition"
- assert (
- source.count("subroutine(float*") == 2
- ), "Expected two occurrences, for forward-declaration and definition"
- assert (
- source.count("subroutine(") == 3
- ), "Expected three occurrences, for forward-declaration, definition, and call from main."
-
-
if __name__ == "__main__":
- tvm.testing.main()
+ test_add()
+ test_add_pipeline()
+ test_reinterpret()
+ test_ceil()
+ test_floor()
+ test_round()
+ test_call_packed()
diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
index 61f0892a9c..588a92d87c 100644
--- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
@@ -268,7 +268,6 @@ cast_smem_ptr_to_int(const void* const smem_ptr)
#define int64_t long long
#define uint64_t unsigned long long
#endif
-extern "C" __global__ void __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C);
extern "C" __global__ void __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) {
__shared__ float A_shared[64];
__shared__ float B_shared[64];