You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/11/16 08:09:35 UTC

[tvm] branch main updated: [2/3][AOT][DeviceAPI] Add Hooks for Activate/Deactivate/Open/Close (#9500)

This is an automated email from the ASF dual-hosted git repository.

manupa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 948641c  [2/3][AOT][DeviceAPI] Add Hooks for Activate/Deactivate/Open/Close (#9500)
948641c is described below

commit 948641c7f4abef6376094d5eddcc52e79521d425
Author: Christopher Sidebottom <ch...@arm.com>
AuthorDate: Tue Nov 16 08:09:00 2021 +0000

    [2/3][AOT][DeviceAPI] Add Hooks for Activate/Deactivate/Open/Close (#9500)
    
    * [AOT][DeviceAPI] Add Hooks for Activate/Deactivate/Open/Close
    
    This adds the relevant hooks into their starting places in the code
    generation, as described in the [C Device API
    RFC](https://github.com/apache/tvm-rfcs/blob/main/rfcs/0031-devices-api.md).
    
    * Standardise on `lowered_ir_mods` and correct device_hook variable name
---
 apps/microtvm/ethosu/include/tvm_ethosu_runtime.h  |   9 +-
 apps/microtvm/ethosu/src/tvm_ethosu_runtime.c      |   8 +-
 python/tvm/relay/backend/executor_factory.py       |  24 ++++-
 python/tvm/relay/build_module.py                   |  15 ++-
 src/relay/backend/aot_executor_codegen.cc          |  73 +++++++++++++--
 .../contrib/ethosu/bare_metal/tvm_ethosu_runtime.c |   8 +-
 .../contrib/ethosu/bare_metal/tvm_ethosu_runtime.h |   9 +-
 tests/python/relay/aot/test_crt_aot.py             | 104 +++++++++++++++++++++
 8 files changed, 234 insertions(+), 16 deletions(-)

diff --git a/apps/microtvm/ethosu/include/tvm_ethosu_runtime.h b/apps/microtvm/ethosu/include/tvm_ethosu_runtime.h
index 06188ba..8352fa5 100644
--- a/apps/microtvm/ethosu/include/tvm_ethosu_runtime.h
+++ b/apps/microtvm/ethosu/include/tvm_ethosu_runtime.h
@@ -24,7 +24,14 @@
 #include <stddef.h>
 #include <stdint.h>
 
-int32_t TVMEthosULaunch(struct ethosu_driver* resource_handle, void* cms_data, size_t cms_data_size,
+typedef void tvm_device_ethos_u_t;
+
+int32_t TVMEthosULaunch(tvm_device_ethos_u_t* resource_handle, void* cms_data, size_t cms_data_size,
                         uint64_t* base_addrs, size_t* base_addrs_size, int num_tensors);
 
+int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context);
+int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context);
+int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context);
+int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context);
+
 #endif  // TVM_RUNTIME_CONTRIB_ETHOSU_ETHOSU_RUNTIME_H_
diff --git a/apps/microtvm/ethosu/src/tvm_ethosu_runtime.c b/apps/microtvm/ethosu/src/tvm_ethosu_runtime.c
index 6b7399b..8e50602 100644
--- a/apps/microtvm/ethosu/src/tvm_ethosu_runtime.c
+++ b/apps/microtvm/ethosu/src/tvm_ethosu_runtime.c
@@ -21,8 +21,9 @@
 
 #include <ethosu_driver.h>
 
-int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t cms_data_size,
+int32_t TVMEthosULaunch(tvm_device_ethos_u_t* context, void* cms_data, size_t cms_data_size,
                         uint64_t* base_addrs, size_t* base_addrs_size, int num_tensors) {
+  struct ethosu_driver* driver = (struct ethosu_driver*)context;
   int32_t result =
       ethosu_invoke(driver, cms_data, cms_data_size, base_addrs, base_addrs_size, num_tensors);
 
@@ -32,3 +33,8 @@ int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t cms
   }
   return 0;
 }
+// No-op C Device API hooks: each returns 0 (success), matching TVMEthosULaunch.
+int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context) { return 0; }
+int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context) { return 0; }
+int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context) { return 0; }
+int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context) { return 0; }
diff --git a/python/tvm/relay/backend/executor_factory.py b/python/tvm/relay/backend/executor_factory.py
index db33c1b..7b96dd8 100644
--- a/python/tvm/relay/backend/executor_factory.py
+++ b/python/tvm/relay/backend/executor_factory.py
@@ -75,6 +75,8 @@ class AOTExecutorFactoryModule(ExecutorFactoryModule):
     ----------
     ir_mod : :py:class:`~tvm.IRModule`
         The IR module to build.
+    lowered_ir_mods : dict[Target, IRModule]
+        The IR modules lowered per Target.
     target : tvm.Target
         The Target used to build this module.
     libmod : tvm.Module
@@ -89,8 +91,19 @@ class AOTExecutorFactoryModule(ExecutorFactoryModule):
         List of devices used in the module
     """
 
-    def __init__(self, ir_mod, target, libmod, libmod_name, params, function_metadata, devices):
+    def __init__(
+        self,
+        ir_mod,
+        lowered_ir_mods,
+        target,
+        libmod,
+        libmod_name,
+        params,
+        function_metadata,
+        devices,
+    ):
         self.ir_mod = ir_mod
+        self.lowered_ir_mods = lowered_ir_mods
         self.target = target
         self.lib = libmod
         self.libmod_name = libmod_name
@@ -136,7 +149,14 @@ class GraphExecutorFactoryModule(ExecutorFactoryModule):
     """
 
     def __init__(
-        self, ir_mod, target, graph_json_str, libmod, libmod_name, params, function_metadata
+        self,
+        ir_mod,
+        target,
+        graph_json_str,
+        libmod,
+        libmod_name,
+        params,
+        function_metadata,
     ):
         assert isinstance(graph_json_str, string_types)
         fcreate = get_global_func("tvm.graph_executor_factory.create")
diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index 38c9a40..b66d5fb 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -102,6 +102,7 @@ class BuildModule(object):
         self._get_params_func = self.mod["get_params"]
         self._get_function_metadata = self.mod["get_function_metadata"]
         self._get_devices = self.mod["get_devices"]
+        self._get_irmodule = self.mod["get_irmodule"]
 
     def build(
         self, mod, target=None, target_host=None, params=None, executor="graph", mod_name=None
@@ -249,6 +250,10 @@ class BuildModule(object):
             ret[key] = value.data
         return ret
 
+    def get_irmodule(self):
+        """Returns the IRModules produced by lowering, keyed by Target"""
+        return self._get_irmodule()
+
 
 @register_func("tvm.relay.module_export_library")
 def _module_export(module, file_name):  # fcompile, addons, kwargs?
@@ -376,10 +381,18 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"
         )
         func_metadata = bld_mod.get_function_metadata()
         devices = bld_mod.get_devices()
+        lowered_ir_mods = bld_mod.get_irmodule()
 
         if executor == "aot":
             executor_factory = _executor_factory.AOTExecutorFactoryModule(
-                ir_mod, target, runtime_mod, mod_name, params, func_metadata, devices
+                ir_mod,
+                lowered_ir_mods,
+                target,
+                runtime_mod,
+                mod_name,
+                params,
+                func_metadata,
+                devices,
             )
         elif executor == "graph":
             executor_factory = _executor_factory.GraphExecutorFactoryModule(
diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc
index c240ec8..fde1de0 100644
--- a/src/relay/backend/aot_executor_codegen.cc
+++ b/src/relay/backend/aot_executor_codegen.cc
@@ -349,11 +349,19 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     GlobalVar global_var = call_lowered_props.lowered_func;
     bool has_c_device_api_context = device_contexts_.count(global_var) != 0;
     if (has_c_device_api_context) {
+      tir::Var context = device_contexts_.Get(global_var).value();
       args.push_back(device_contexts_[global_var]);
-    }
 
-    tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args));
-    create_func_call_stmts.push_back(func_call);
+      tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args));
+      create_func_call_stmts.push_back(tir::SeqStmt({
+          GenerateDeviceHook(context, "Open"),
+          func_call,
+          GenerateDeviceHook(context, "Close"),
+      }));
+    } else {
+      tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args));
+      create_func_call_stmts.push_back(func_call);
+    }
 
     tir::Stmt body = tir::SeqStmt(create_func_call_stmts);
     stmts_.push_back(body);
@@ -417,7 +425,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
           device_context_var = (*pair).second;
         } else {
           main_signature_.push_back(device_context_var);
-          devices_.push_back(context_name);
+          devices_.Set(context_name, device_context_var);
           target_contexts.Set(target_kind.value(), device_context_var);
         }
 
@@ -426,6 +434,44 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     }
   }
 
+  /*!
+   * \brief Generates calls to a given hook for every device registered for the C Device API
+   * \param hook Name of hook to generate statements for (e.g. "Activate", "Deactivate")
+   * \return Statement with one hook call per device, in devices_ iteration order
+   */
+  tir::Stmt GenerateAllDeviceHook(const String& hook) {
+    std::vector<tir::Stmt> device_hooks;
+    for (const auto& it : devices_) {
+      const String& device_name = it.first;
+      const tir::Var& context = it.second;
+      Array<String> sections = {"Device", device_name, hook};
+      String device_hook_name = ToCFunctionStyle(PrefixName(sections));
+
+      tir::Evaluate device_hook(tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
+                                               {tvm::tir::StringImm(device_hook_name), context}));
+      device_hooks.push_back(device_hook);
+    }
+    return tir::SeqStmt(device_hooks);
+  }
+
+  /*!
+   * \brief Generates a call to a given hook for a single Device function
+   * \param context Device context to call hook on; must be registered in devices_
+   * \param hook Name of hook to generate statements for (e.g. "Open", "Close")
+   * \return Statement with function call to Device API
+   */
+  tir::Stmt GenerateDeviceHook(const tir::Var& context, const String& hook) {
+    // devices_ maps device name -> context Var, so search for the entry owning `context`.
+    auto it = std::find_if(std::begin(devices_), std::end(devices_), [&](const auto& entry) {
+      return entry.second->name_hint == context->name_hint;
+    });
+    ICHECK(it != std::end(devices_)) << "No device found for context " << context->name_hint;
+    Array<String> sections = {"Device", (*it).first, hook};
+    String device_hook = ToCFunctionStyle(PrefixName(sections));
+    return tir::Evaluate(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
+                                   {tvm::tir::StringImm(device_hook), context}));
+  }
+
   /*!
    * Utility function to string together different arguments
    */
@@ -587,8 +633,12 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     dict_attrs.Set("global_symbol", run_func_name);
     dict_attrs.Set("runner_function", Bool(true));
 
+    tir::Stmt device_activations = GenerateAllDeviceHook("Activate");
+    tir::Stmt device_deactivations = GenerateAllDeviceHook("Deactivate");
+    tir::Stmt final_body = tir::SeqStmt({device_activations, body, device_deactivations});
+
     // Make the PrimFunc
-    return tir::PrimFunc(main_signature_, body, VoidType(), Map<tir::Var, tir::Buffer>(),
+    return tir::PrimFunc(main_signature_, final_body, VoidType(), Map<tir::Var, tir::Buffer>(),
                          DictAttrs(dict_attrs));
   }
 
@@ -597,8 +647,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
   runtime::Module* mod_;
   /*! \brief list of input expressions (i.e., variable passed by the user) */
   std::vector<Var> input_vars_;
-  /*! \brief list of device contexts used */
-  std::vector<String> devices_;
+  /*! \brief map of device contexts variables */
+  Map<String, tir::Var> devices_;
   /*! \brief map of GlobalVars to C Device API contexts */
   Map<GlobalVar, tir::Var> device_contexts_;
   /*! \brief input and output variables belonging to the main function signature */
@@ -779,7 +829,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     std::transform(input_vars_.begin(), input_vars_.end(), input_var_names.begin(),
                    [](Var input_var) -> String { return input_var->name_hint(); });
 
-    ret.metadata = runtime::Metadata(input_var_names, devices_, return_sid_.size(),
+    ret.metadata = runtime::Metadata(input_var_names, ListDevices(), return_sid_.size(),
                                      runtime::kTvmExecutorAot, mod_name);
     return ret;
   }
@@ -788,7 +838,12 @@ class AOTExecutorCodegen : public MixedModeVisitor {
    * \brief Get list of devices found
    * \return List of devices
    */
-  Array<String> ListDevices() { return devices_; }
+  Array<String> ListDevices() {
+    // Keys of devices_ are the device names; the context Vars are not needed here.
+    Array<String> device_names;
+    for (const auto& device : devices_) device_names.push_back(device.first);
+    return device_names;
+  }
 };  // namespace backend
 
 class AOTExecutorCodegenModule : public runtime::ModuleNode {
diff --git a/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.c b/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.c
index 6b7399b..8e50602 100644
--- a/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.c
+++ b/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.c
@@ -21,8 +21,9 @@
 
 #include <ethosu_driver.h>
 
-int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t cms_data_size,
+int32_t TVMEthosULaunch(tvm_device_ethos_u_t* context, void* cms_data, size_t cms_data_size,
                         uint64_t* base_addrs, size_t* base_addrs_size, int num_tensors) {
+  struct ethosu_driver* driver = (struct ethosu_driver*)context;
   int32_t result =
       ethosu_invoke(driver, cms_data, cms_data_size, base_addrs, base_addrs_size, num_tensors);
 
@@ -32,3 +33,8 @@ int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t cms
   }
   return 0;
 }
+// No-op C Device API hooks: each returns 0 (success), matching TVMEthosULaunch.
+int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context) { return 0; }
+int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context) { return 0; }
+int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context) { return 0; }
+int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context) { return 0; }
diff --git a/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.h b/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.h
index d62afc4..31d1755 100644
--- a/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.h
+++ b/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.h
@@ -24,7 +24,14 @@
 #include <stddef.h>
 #include <stdint.h>
 
-int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t cms_data_size,
+typedef void tvm_device_ethos_u_t;
+
+int32_t TVMEthosULaunch(tvm_device_ethos_u_t* resource_handle, void* cms_data, size_t cms_data_size,
                         uint64_t* base_addrs, size_t* base_addrs_size, int num_tensors);
 
+int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context);
+int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context);
+int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context);
+int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context);
+
 #endif  // TVM_RUNTIME_CONTRIB_ETHOSU_BARE_METAL_TVM_ETHOSU_RUNTIME_H_
diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py
index 94ecaba..e2bbb24 100644
--- a/tests/python/relay/aot/test_crt_aot.py
+++ b/tests/python/relay/aot/test_crt_aot.py
@@ -17,6 +17,7 @@
 
 from collections import OrderedDict
 import sys
+import re
 
 import numpy as np
 import pytest
@@ -693,5 +694,108 @@ def test_aot_codegen_backend_alloc_workspace_calls():
     assert source.count("TVMBackendAllocWorkspace") == 3
 
 
+def test_device_api_hooks():
+    """Check for Device API hooks"""
+
+    # Ideally a sample Target that uses the C Device API would be registered here,
+    # but for now the Ethos-U target is reused, hence the ethosu.vela requirement.
+    pytest.importorskip("ethosu.vela")
+    import tensorflow as tf
+    import tflite.Model
+
+    from tests.python.contrib.test_ethosu import infra
+    from tvm.relay.op.contrib.ethosu import partition_for_ethosu
+
+    def create_tflite_graph():
+        tf.config.run_functions_eagerly(True)
+
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x):
+                return tf.nn.max_pool(x, [1, 2], [1, 2], "SAME")
+
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(1, 3, 4, 3)
+                yield [data.astype(np.float32)]
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec([1, 3, 4, 3], dtype=tf.float32)
+        )
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    relay_module, params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"x": [1, 3, 4, 3]},
+        dtype_dict={"x": "int8"},
+    )
+    mod = partition_for_ethosu(relay_module, params)
+
+    # Generate reference data
+    input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
+
+    compiled_models = infra.build_source(
+        mod,
+        input_data,
+        output_data,
+    )
+    main_ir_module = next(iter(compiled_models[0].executor_factory.lowered_ir_mods.values()))
+    main_func = main_ir_module["run_model"]
+
+    # Activate Device
+    assert (
+        str(main_func.body[0][0].value)
+        == "@tir.call_extern("
+        + '"TVMDeviceEthosUActivate",'
+        + " device_context_ethos_u: handle,"
+        + " dtype=int32)"
+    )
+    # Open Device
+    assert (
+        str(main_func.body[1].body.body[0][0][0].value)
+        == "@tir.call_extern("
+        + '"TVMDeviceEthosUOpen",'
+        + " device_context_ethos_u: handle,"
+        + " dtype=int32)"
+    )
+    # Device Call
+    assert (
+        str(main_func.body[1].body.body[0][0][1].value)
+        == "@tir.call_extern("
+        + '"tvmgen_default_ethos_u_main_0",'
+        + " input: handle, output: handle,"
+        + " device_context_ethos_u: handle,"
+        + " dtype=int32)"
+    )
+    # Close Device
+    assert (
+        str(main_func.body[1].body.body[0][0][2].value)
+        == "@tir.call_extern("
+        + '"TVMDeviceEthosUClose",'
+        + " device_context_ethos_u: handle,"
+        + " dtype=int32)"
+    )
+    # Deactivate Device
+    assert (
+        str(main_func.body[2][0].value)
+        == "@tir.call_extern("
+        + '"TVMDeviceEthosUDeactivate",'
+        + " device_context_ethos_u: handle,"
+        + " dtype=int32)"
+    )
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))