You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ar...@apache.org on 2021/09/03 16:18:27 UTC
[tvm] branch main updated: Sanitize names of input tensors in interface header (#8720)
This is an automated email from the ASF dual-hosted git repository.
areusch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 06a0d63 Sanitize names of input tensors in interface header (#8720)
06a0d63 is described below
commit 06a0d63c43e251cfc9fa9e81bfad5aa45219652e
Author: Grant Watson <gr...@arm.com>
AuthorDate: Fri Sep 3 17:18:10 2021 +0100
Sanitize names of input tensors in interface header (#8720)
* Sanitize names of input tensors in interface header
Change-Id: I7f02a993887bf84316262cd2586a734a9079c338
* Update tensor name sanitizer tests to parameterize them.
Change-Id: I157d8d8d607de2904285e403893f146e97b510d5
* Only test unpacked, C interface API, AOT case
Change-Id: I9082ae32079a1a3924c06c7f26c757aafa46dec2
---
python/tvm/micro/interface_api.py | 13 +++++++-
src/target/source/source_module.cc | 6 +++-
tests/python/relay/aot/aot_test_utils.py | 17 +++++++---
tests/python/relay/aot/test_crt_aot.py | 56 ++++++++++++++++++++++++++++++++
4 files changed, 86 insertions(+), 6 deletions(-)
diff --git a/python/tvm/micro/interface_api.py b/python/tvm/micro/interface_api.py
index 8086b1e..d9961e9 100644
--- a/python/tvm/micro/interface_api.py
+++ b/python/tvm/micro/interface_api.py
@@ -17,7 +17,13 @@
"""Defines functions for generating a C interface header"""
+# TODO: Currently the Interface API header is generated in Python but the source it references
+# is generated in C++. These should be consolidated to generate both header and source in C++
+# and avoid re-implementing logic, such as name sanitising, in the two different languages.
+# See https://github.com/apache/tvm/issues/8792 .
+
import os
+import re
from tvm.relay.backend.utils import mangle_module_name
@@ -58,8 +64,13 @@ def generate_c_interface_header(module_name, inputs, outputs, output_path):
_emit_brief(header_file, module_name, "Input tensor pointers")
header_file.write(f"struct {mangled_name}_inputs {{\n")
+ sanitized_names = []
for input_name in inputs:
- header_file.write(f" void* {input_name};\n")
+ sanitized_input_name = re.sub(r"\W", "_", input_name)
+ if sanitized_input_name in sanitized_names:
+ raise ValueError(f"Sanitized input tensor name clash: {sanitized_input_name}")
+ sanitized_names.append(sanitized_input_name)
+ header_file.write(f" void* {sanitized_input_name};\n")
header_file.write("};\n\n")
_emit_brief(header_file, module_name, "Output tensor pointers")
diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc
index 7728773..9b93b07 100644
--- a/src/target/source/source_module.cc
+++ b/src/target/source/source_module.cc
@@ -234,6 +234,8 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
code_ << "}\n";
}
+  /*!
+   * \brief Tell whether a character is not alphanumeric.
+   *
+   * Used as the predicate when replacing characters of input tensor names
+   * with '_' in the generated C entrypoint.
+   *
+   * \param c The character to classify.
+   * \return Non-zero if std::isalnum reports c as not alphanumeric, zero otherwise.
+   *
+   * \note std::isalnum has undefined behaviour when its argument is not
+   *       representable as unsigned char (and is not EOF), so the possibly
+   *       signed char is converted to unsigned char before the call.
+   */
+  static int isNotAlnum(char c) { return !std::isalnum(static_cast<unsigned char>(c)); }
+
void GenerateCInterfaceEntrypoint(const std::string& entrypoint_name, const std::string& run_func,
const std::string& mod_name) {
code_ << "#include <" << mod_name << ".h>\n";
@@ -252,7 +254,9 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
<< ") {";
code_ << "return " << run_func << "(";
for (const auto& input : metadata_->inputs) {
- code_ << "inputs->" << input << ",";
+ std::string sanitised_input = input;
+ std::replace_if(sanitised_input.begin(), sanitised_input.end(), isNotAlnum, '_');
+ code_ << "inputs->" << sanitised_input << ",";
}
if (metadata_->num_outputs == 1) {
code_ << "outputs->output";
diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py
index e5ac85b..baa2397 100644
--- a/tests/python/relay/aot/aot_test_utils.py
+++ b/tests/python/relay/aot/aot_test_utils.py
@@ -22,6 +22,7 @@ import logging
import os
import pathlib
import platform
+import re
import shutil
import subprocess
import tarfile
@@ -250,7 +251,10 @@ int main(){\n
def emit_main_data(main_file, input_map, output_list, mod_name):
for key in input_map:
- main_file.write(f'#include "{mangle_name(mod_name,"input_data")}_{key}.h"\n')
+ sanitized_tensor_name = re.sub(r"\W", "_", key)
+ main_file.write(
+ f'#include "{mangle_name(mod_name,"input_data")}_{sanitized_tensor_name}.h"\n'
+ )
for i in range(0, len(output_list)):
main_file.write(f'#include "{mangle_name(mod_name,"expected_output_data")}{i}.h"\n')
@@ -262,7 +266,10 @@ def emit_main_data_structs(main_file, input_map, output_list, mod_name):
f"struct {mangle_name(mod_name, 'inputs')} {mangle_name(mod_name, 'inputs')} = {{"
)
for key in input_map:
- main_file.write(f"\t.{key} = {mangle_name(mod_name, 'input_data')}_{key},\n")
+ sanitized_tensor_name = re.sub(r"\W", "_", key)
+ main_file.write(
+ f"\t.{sanitized_tensor_name} = {mangle_name(mod_name, 'input_data')}_{sanitized_tensor_name},\n"
+ )
main_file.write("};\n")
main_file.write(
@@ -283,7 +290,8 @@ def emit_main_data_setup(main_file, input_map, output_list, mod_name):
main_file.write(f'void* {mangle_name(mod_name,"inputs")}[{num_inputs}] = {{ ')
for key in input_map:
- main_file.write(f'{mangle_name(mod_name,"input_data")}_{key}, ')
+ sanitized_tensor_name = re.sub(r"\W", "_", key)
+ main_file.write(f'{mangle_name(mod_name,"input_data")}_{sanitized_tensor_name}, ')
main_file.write("};\n")
main_file.write(f'void* {mangle_name(mod_name,"outputs")}[{num_outputs}] = {{ ')
@@ -521,8 +529,9 @@ def compile_and_run(
workspace_bytes += extract_main_workspace_size_bytes(base_path)
for key in model.inputs:
+ sanitized_tensor_name = re.sub(r"\W", "_", key)
create_header_file(
- f'{mangle_name(model.name, "input_data")}_{key}',
+ f'{mangle_name(model.name, "input_data")}_{sanitized_tensor_name}',
model.inputs[key],
include_path,
)
diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py
index 36cffef..64000a9 100644
--- a/tests/python/relay/aot/test_crt_aot.py
+++ b/tests/python/relay/aot/test_crt_aot.py
@@ -503,5 +503,61 @@ def test_transpose(interface_api, use_unpacked_api, test_runner):
)
+def test_name_sanitiser():
+ """Test that input tensors with special characters in the name don't break compilation"""
+
+ interface_api = "c"
+ use_unpacked_api = True
+ test_runner = AOT_DEFAULT_RUNNER
+
+ func = relay.var("input-x::2", "float32")
+ ident = relay.Function([func], func)
+ one = np.array(1.0, "float32")
+ inputs = {"input-x::2": one}
+ output_list = generate_ref_data(ident, inputs)
+
+ compile_and_run(
+ AOTTestModel(module=IRModule.from_expr(func), inputs=inputs, outputs=output_list),
+ test_runner,
+ interface_api,
+ use_unpacked_api,
+ enable_op_fusion=False,
+ )
+
+
+def test_name_sanitiser_name_clash():
+ """Test that 2 input tensors with names that clash once sanitized, generates an error"""
+
+ interface_api = "c"
+ use_unpacked_api = True
+ test_runner = AOT_DEFAULT_RUNNER
+
+ dtype = "float32"
+ x = relay.var("input::-1", shape=(10, 5), dtype=dtype)
+ # Next 2 input tensor names will clash once sanitized.
+ y = relay.var("input::-2", shape=(10, 5), dtype=dtype)
+ t = relay.var("input:--2", shape=(), dtype=dtype)
+ a = relay.add(x, y)
+ b = relay.transpose(a)
+ z = relay.add(b, t)
+ # Check result.
+ func = relay.Function([x, y, t], z)
+ x_data = np.random.rand(10, 5).astype(dtype)
+ y_data = np.random.rand(10, 5).astype(dtype)
+ t_data = np.random.uniform(size=()).astype(dtype)
+
+ inputs = {"input::-1": x_data, "input::-2": y_data, "input:--2": t_data}
+ output_list = generate_ref_data(func, inputs)
+
+ with pytest.raises(ValueError, match="Sanitized input tensor name clash"):
+ compile_and_run(
+ AOTTestModel(module=IRModule.from_expr(func), inputs=inputs, outputs=output_list),
+ test_runner,
+ interface_api,
+ use_unpacked_api,
+ enable_op_fusion=False,
+ )
+
+
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))