You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ar...@apache.org on 2021/09/03 16:18:27 UTC

[tvm] branch main updated: Sanitize names of input tensors in interface header (#8720)

This is an automated email from the ASF dual-hosted git repository.

areusch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 06a0d63  Sanitize names of input tensors in interface header (#8720)
06a0d63 is described below

commit 06a0d63c43e251cfc9fa9e81bfad5aa45219652e
Author: Grant Watson <gr...@arm.com>
AuthorDate: Fri Sep 3 17:18:10 2021 +0100

    Sanitize names of input tensors in interface header (#8720)
    
    * Sanitize names of input tensors in interface header
    
    Change-Id: I7f02a993887bf84316262cd2586a734a9079c338
    
    * Update tensor name sanitizer tests to parameterize them.
    
    Change-Id: I157d8d8d607de2904285e403893f146e97b510d5
    
    * Only test unpacked, C interface API, AOT case
    
    Change-Id: I9082ae32079a1a3924c06c7f26c757aafa46dec2
---
 python/tvm/micro/interface_api.py        | 13 +++++++-
 src/target/source/source_module.cc       |  6 +++-
 tests/python/relay/aot/aot_test_utils.py | 17 +++++++---
 tests/python/relay/aot/test_crt_aot.py   | 56 ++++++++++++++++++++++++++++++++
 4 files changed, 86 insertions(+), 6 deletions(-)

diff --git a/python/tvm/micro/interface_api.py b/python/tvm/micro/interface_api.py
index 8086b1e..d9961e9 100644
--- a/python/tvm/micro/interface_api.py
+++ b/python/tvm/micro/interface_api.py
@@ -17,7 +17,13 @@
 
 """Defines functions for generating a C interface header"""
 
+# TODO: Currently the Interface API header is generated in Python but the source it references
+# is generated in C++. These should be consolidated to generate both header and source in C++
+# and avoid re-implementing logic, such as name sanitising, in the two different languages.
+# See https://github.com/apache/tvm/issues/8792 .
+
 import os
+import re
 
 from tvm.relay.backend.utils import mangle_module_name
 
@@ -58,8 +64,13 @@ def generate_c_interface_header(module_name, inputs, outputs, output_path):
 
         _emit_brief(header_file, module_name, "Input tensor pointers")
         header_file.write(f"struct {mangled_name}_inputs {{\n")
+        sanitized_names = []
         for input_name in inputs:
-            header_file.write(f"  void* {input_name};\n")
+            sanitized_input_name = re.sub(r"\W", "_", input_name)
+            if sanitized_input_name in sanitized_names:
+                raise ValueError(f"Sanitized input tensor name clash: {sanitized_input_name}")
+            sanitized_names.append(sanitized_input_name)
+            header_file.write(f"  void* {sanitized_input_name};\n")
         header_file.write("};\n\n")
 
         _emit_brief(header_file, module_name, "Output tensor pointers")
diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc
index 7728773..9b93b07 100644
--- a/src/target/source/source_module.cc
+++ b/src/target/source/source_module.cc
@@ -234,6 +234,8 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
     code_ << "}\n";
   }
 
+  static int isNotAlnum(char c) { return !std::isalnum(c); }
+
   void GenerateCInterfaceEntrypoint(const std::string& entrypoint_name, const std::string& run_func,
                                     const std::string& mod_name) {
     code_ << "#include <" << mod_name << ".h>\n";
@@ -252,7 +254,9 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
           << ") {";
     code_ << "return " << run_func << "(";
     for (const auto& input : metadata_->inputs) {
-      code_ << "inputs->" << input << ",";
+      std::string sanitised_input = input;
+      std::replace_if(sanitised_input.begin(), sanitised_input.end(), isNotAlnum, '_');
+      code_ << "inputs->" << sanitised_input << ",";
     }
     if (metadata_->num_outputs == 1) {
       code_ << "outputs->output";
diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py
index e5ac85b..baa2397 100644
--- a/tests/python/relay/aot/aot_test_utils.py
+++ b/tests/python/relay/aot/aot_test_utils.py
@@ -22,6 +22,7 @@ import logging
 import os
 import pathlib
 import platform
+import re
 import shutil
 import subprocess
 import tarfile
@@ -250,7 +251,10 @@ int main(){\n
 
 def emit_main_data(main_file, input_map, output_list, mod_name):
     for key in input_map:
-        main_file.write(f'#include "{mangle_name(mod_name,"input_data")}_{key}.h"\n')
+        sanitized_tensor_name = re.sub(r"\W", "_", key)
+        main_file.write(
+            f'#include "{mangle_name(mod_name,"input_data")}_{sanitized_tensor_name}.h"\n'
+        )
 
     for i in range(0, len(output_list)):
         main_file.write(f'#include "{mangle_name(mod_name,"expected_output_data")}{i}.h"\n')
@@ -262,7 +266,10 @@ def emit_main_data_structs(main_file, input_map, output_list, mod_name):
         f"struct {mangle_name(mod_name, 'inputs')} {mangle_name(mod_name, 'inputs')} = {{"
     )
     for key in input_map:
-        main_file.write(f"\t.{key} = {mangle_name(mod_name, 'input_data')}_{key},\n")
+        sanitized_tensor_name = re.sub(r"\W", "_", key)
+        main_file.write(
+            f"\t.{sanitized_tensor_name} = {mangle_name(mod_name, 'input_data')}_{sanitized_tensor_name},\n"
+        )
     main_file.write("};\n")
 
     main_file.write(
@@ -283,7 +290,8 @@ def emit_main_data_setup(main_file, input_map, output_list, mod_name):
 
     main_file.write(f'void* {mangle_name(mod_name,"inputs")}[{num_inputs}] = {{ ')
     for key in input_map:
-        main_file.write(f'{mangle_name(mod_name,"input_data")}_{key}, ')
+        sanitized_tensor_name = re.sub(r"\W", "_", key)
+        main_file.write(f'{mangle_name(mod_name,"input_data")}_{sanitized_tensor_name}, ')
     main_file.write("};\n")
 
     main_file.write(f'void* {mangle_name(mod_name,"outputs")}[{num_outputs}]  = {{ ')
@@ -521,8 +529,9 @@ def compile_and_run(
         workspace_bytes += extract_main_workspace_size_bytes(base_path)
 
         for key in model.inputs:
+            sanitized_tensor_name = re.sub(r"\W", "_", key)
             create_header_file(
-                f'{mangle_name(model.name, "input_data")}_{key}',
+                f'{mangle_name(model.name, "input_data")}_{sanitized_tensor_name}',
                 model.inputs[key],
                 include_path,
             )
diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py
index 36cffef..64000a9 100644
--- a/tests/python/relay/aot/test_crt_aot.py
+++ b/tests/python/relay/aot/test_crt_aot.py
@@ -503,5 +503,61 @@ def test_transpose(interface_api, use_unpacked_api, test_runner):
     )
 
 
+def test_name_sanitiser():
+    """Test that input tensors with special characters in the name don't break compilation"""
+
+    interface_api = "c"
+    use_unpacked_api = True
+    test_runner = AOT_DEFAULT_RUNNER
+
+    func = relay.var("input-x::2", "float32")
+    ident = relay.Function([func], func)
+    one = np.array(1.0, "float32")
+    inputs = {"input-x::2": one}
+    output_list = generate_ref_data(ident, inputs)
+
+    compile_and_run(
+        AOTTestModel(module=IRModule.from_expr(func), inputs=inputs, outputs=output_list),
+        test_runner,
+        interface_api,
+        use_unpacked_api,
+        enable_op_fusion=False,
+    )
+
+
+def test_name_sanitiser_name_clash():
+    """Test that two input tensors whose names clash once sanitized generate an error"""
+
+    interface_api = "c"
+    use_unpacked_api = True
+    test_runner = AOT_DEFAULT_RUNNER
+
+    dtype = "float32"
+    x = relay.var("input::-1", shape=(10, 5), dtype=dtype)
+    # Next 2 input tensor names will clash once sanitized.
+    y = relay.var("input::-2", shape=(10, 5), dtype=dtype)
+    t = relay.var("input:--2", shape=(), dtype=dtype)
+    a = relay.add(x, y)
+    b = relay.transpose(a)
+    z = relay.add(b, t)
+    # Build the function, its input data and the reference outputs.
+    func = relay.Function([x, y, t], z)
+    x_data = np.random.rand(10, 5).astype(dtype)
+    y_data = np.random.rand(10, 5).astype(dtype)
+    t_data = np.random.uniform(size=()).astype(dtype)
+
+    inputs = {"input::-1": x_data, "input::-2": y_data, "input:--2": t_data}
+    output_list = generate_ref_data(func, inputs)
+
+    with pytest.raises(ValueError, match="Sanitized input tensor name clash"):
+        compile_and_run(
+            AOTTestModel(module=IRModule.from_expr(func), inputs=inputs, outputs=output_list),
+            test_runner,
+            interface_api,
+            use_unpacked_api,
+            enable_op_fusion=False,
+        )
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))