You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/11/16 08:09:35 UTC
[tvm] branch main updated: [2/3][AOT][DeviceAPI] Add Hooks for Activate/Deactivate/Open/Close (#9500)
This is an automated email from the ASF dual-hosted git repository.
manupa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 948641c [2/3][AOT][DeviceAPI] Add Hooks for Activate/Deactivate/Open/Close (#9500)
948641c is described below
commit 948641c7f4abef6376094d5eddcc52e79521d425
Author: Christopher Sidebottom <ch...@arm.com>
AuthorDate: Tue Nov 16 08:09:00 2021 +0000
[2/3][AOT][DeviceAPI] Add Hooks for Activate/Deactivate/Open/Close (#9500)
* [AOT][DeviceAPI] Add Hooks for Activate/Deactivate/Open/Close
This adds the relevant hooks at their starting places in the code
generation, as specified in the [C Device API
RFC](https://github.com/apache/tvm-rfcs/blob/main/rfcs/0031-devices-api.md).
* Standardise on `lowered_ir_mods` and correct the `device_hook` variable name
---
apps/microtvm/ethosu/include/tvm_ethosu_runtime.h | 9 +-
apps/microtvm/ethosu/src/tvm_ethosu_runtime.c | 8 +-
python/tvm/relay/backend/executor_factory.py | 24 ++++-
python/tvm/relay/build_module.py | 15 ++-
src/relay/backend/aot_executor_codegen.cc | 73 +++++++++++++--
.../contrib/ethosu/bare_metal/tvm_ethosu_runtime.c | 8 +-
.../contrib/ethosu/bare_metal/tvm_ethosu_runtime.h | 9 +-
tests/python/relay/aot/test_crt_aot.py | 104 +++++++++++++++++++++
8 files changed, 234 insertions(+), 16 deletions(-)
diff --git a/apps/microtvm/ethosu/include/tvm_ethosu_runtime.h b/apps/microtvm/ethosu/include/tvm_ethosu_runtime.h
index 06188ba..8352fa5 100644
--- a/apps/microtvm/ethosu/include/tvm_ethosu_runtime.h
+++ b/apps/microtvm/ethosu/include/tvm_ethosu_runtime.h
@@ -24,7 +24,14 @@
#include <stddef.h>
#include <stdint.h>
-int32_t TVMEthosULaunch(struct ethosu_driver* resource_handle, void* cms_data, size_t cms_data_size,
+typedef void tvm_device_ethos_u_t;
+
+int32_t TVMEthosULaunch(tvm_device_ethos_u_t* resource_handle, void* cms_data, size_t cms_data_size,
uint64_t* base_addrs, size_t* base_addrs_size, int num_tensors);
+int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context);
+int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context);
+int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context);
+int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context);
+
#endif // TVM_RUNTIME_CONTRIB_ETHOSU_ETHOSU_RUNTIME_H_
diff --git a/apps/microtvm/ethosu/src/tvm_ethosu_runtime.c b/apps/microtvm/ethosu/src/tvm_ethosu_runtime.c
index 6b7399b..8e50602 100644
--- a/apps/microtvm/ethosu/src/tvm_ethosu_runtime.c
+++ b/apps/microtvm/ethosu/src/tvm_ethosu_runtime.c
@@ -21,8 +21,9 @@
#include <ethosu_driver.h>
-int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t cms_data_size,
+int32_t TVMEthosULaunch(tvm_device_ethos_u_t* context, void* cms_data, size_t cms_data_size,
uint64_t* base_addrs, size_t* base_addrs_size, int num_tensors) {
+ struct ethosu_driver* driver = (struct ethosu_driver*)context;
int32_t result =
ethosu_invoke(driver, cms_data, cms_data_size, base_addrs, base_addrs_size, num_tensors);
@@ -32,3 +33,8 @@ int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t cms
}
return 0;
}
+
+int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context) {}
+int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context) {}
+int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context) {}
+int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context) {}
diff --git a/python/tvm/relay/backend/executor_factory.py b/python/tvm/relay/backend/executor_factory.py
index db33c1b..7b96dd8 100644
--- a/python/tvm/relay/backend/executor_factory.py
+++ b/python/tvm/relay/backend/executor_factory.py
@@ -75,6 +75,8 @@ class AOTExecutorFactoryModule(ExecutorFactoryModule):
----------
ir_mod : :py:class:`~tvm.IRModule`
The IR module to build.
+ lowered_ir_mods : dict[Target, IRModule]
+ The IR modules lowered per Target.
target : tvm.Target
The Target used to build this module.
libmod : tvm.Module
@@ -89,8 +91,19 @@ class AOTExecutorFactoryModule(ExecutorFactoryModule):
List of devices used in the module
"""
- def __init__(self, ir_mod, target, libmod, libmod_name, params, function_metadata, devices):
+ def __init__(
+ self,
+ ir_mod,
+ lowered_ir_mods,
+ target,
+ libmod,
+ libmod_name,
+ params,
+ function_metadata,
+ devices,
+ ):
self.ir_mod = ir_mod
+ self.lowered_ir_mods = lowered_ir_mods
self.target = target
self.lib = libmod
self.libmod_name = libmod_name
@@ -136,7 +149,14 @@ class GraphExecutorFactoryModule(ExecutorFactoryModule):
"""
def __init__(
- self, ir_mod, target, graph_json_str, libmod, libmod_name, params, function_metadata
+ self,
+ ir_mod,
+ target,
+ graph_json_str,
+ libmod,
+ libmod_name,
+ params,
+ function_metadata,
):
assert isinstance(graph_json_str, string_types)
fcreate = get_global_func("tvm.graph_executor_factory.create")
diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index 38c9a40..b66d5fb 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -102,6 +102,7 @@ class BuildModule(object):
self._get_params_func = self.mod["get_params"]
self._get_function_metadata = self.mod["get_function_metadata"]
self._get_devices = self.mod["get_devices"]
+ self._get_irmodule = self.mod["get_irmodule"]
def build(
self, mod, target=None, target_host=None, params=None, executor="graph", mod_name=None
@@ -249,6 +250,10 @@ class BuildModule(object):
ret[key] = value.data
return ret
+ def get_irmodule(self):
+ """Returns the Target IRModule's post-lowering"""
+ return self._get_irmodule()
+
@register_func("tvm.relay.module_export_library")
def _module_export(module, file_name): # fcompile, addons, kwargs?
@@ -376,10 +381,18 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"
)
func_metadata = bld_mod.get_function_metadata()
devices = bld_mod.get_devices()
+ lowered_ir_mods = bld_mod.get_irmodule()
if executor == "aot":
executor_factory = _executor_factory.AOTExecutorFactoryModule(
- ir_mod, target, runtime_mod, mod_name, params, func_metadata, devices
+ ir_mod,
+ lowered_ir_mods,
+ target,
+ runtime_mod,
+ mod_name,
+ params,
+ func_metadata,
+ devices,
)
elif executor == "graph":
executor_factory = _executor_factory.GraphExecutorFactoryModule(
diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc
index c240ec8..fde1de0 100644
--- a/src/relay/backend/aot_executor_codegen.cc
+++ b/src/relay/backend/aot_executor_codegen.cc
@@ -349,11 +349,19 @@ class AOTExecutorCodegen : public MixedModeVisitor {
GlobalVar global_var = call_lowered_props.lowered_func;
bool has_c_device_api_context = device_contexts_.count(global_var) != 0;
if (has_c_device_api_context) {
+ tir::Var context = device_contexts_.Get(global_var).value();
args.push_back(device_contexts_[global_var]);
- }
- tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args));
- create_func_call_stmts.push_back(func_call);
+ tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args));
+ create_func_call_stmts.push_back(tir::SeqStmt({
+ GenerateDeviceHook(context, "Open"),
+ func_call,
+ GenerateDeviceHook(context, "Close"),
+ }));
+ } else {
+ tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args));
+ create_func_call_stmts.push_back(func_call);
+ }
tir::Stmt body = tir::SeqStmt(create_func_call_stmts);
stmts_.push_back(body);
@@ -417,7 +425,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
device_context_var = (*pair).second;
} else {
main_signature_.push_back(device_context_var);
- devices_.push_back(context_name);
+ devices_.Set(context_name, device_context_var);
target_contexts.Set(target_kind.value(), device_context_var);
}
@@ -426,6 +434,44 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}
}
+ /**
+ * \brief Generates a call to a given hook for all Devices found for C Device API
+ * \param Name of hook to generate statements for
+ * \return Statement with function calls for each device
+ */
+ tir::Stmt GenerateAllDeviceHook(const String& hook) {
+ std::vector<tir::Stmt> device_hooks;
+ for (const auto& it : devices_) {
+ const String& device_name = it.first;
+ const tir::Var& context = it.second;
+ Array<String> sections = {"Device", device_name, hook};
+ String device_hook_name = ToCFunctionStyle(PrefixName(sections));
+
+ tir::Evaluate device_hook(tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
+ {tvm::tir::StringImm(device_hook_name), context}));
+ device_hooks.push_back(device_hook);
+ }
+ return tir::SeqStmt(device_hooks);
+ }
+
+ /**
+ * \brief Generates a call to a given hook for a single Device function
+ * \param Var Device context to call hook on
+ * \param Name of hook to generate statements for
+ * \return Statement with function call to Device API
+ */
+ tir::Stmt GenerateDeviceHook(const tir::Var& context, const String& hook) {
+ const auto& it = std::find_if(std::begin(devices_), std::end(devices_), [&](const auto& it) {
+ return it.second->name_hint == context->name_hint;
+ });
+ const String& device_name = (*it).first;
+ Array<String> sections = {"Device", device_name, hook};
+ String device_hook = ToCFunctionStyle(PrefixName(sections));
+
+ return tir::Evaluate(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
+ {tvm::tir::StringImm(device_hook), context}));
+ }
+
/*!
* Utility function to string together different arguments
*/
@@ -587,8 +633,12 @@ class AOTExecutorCodegen : public MixedModeVisitor {
dict_attrs.Set("global_symbol", run_func_name);
dict_attrs.Set("runner_function", Bool(true));
+ tir::Stmt device_activations = GenerateAllDeviceHook("Activate");
+ tir::Stmt device_deactivations = GenerateAllDeviceHook("Deactivate");
+ tir::Stmt final_body = tir::SeqStmt({device_activations, body, device_deactivations});
+
// Make the PrimFunc
- return tir::PrimFunc(main_signature_, body, VoidType(), Map<tir::Var, tir::Buffer>(),
+ return tir::PrimFunc(main_signature_, final_body, VoidType(), Map<tir::Var, tir::Buffer>(),
DictAttrs(dict_attrs));
}
@@ -597,8 +647,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
runtime::Module* mod_;
/*! \brief list of input expressions (i.e., variable passed by the user) */
std::vector<Var> input_vars_;
- /*! \brief list of device contexts used */
- std::vector<String> devices_;
+ /*! \brief map of device contexts variables */
+ Map<String, tir::Var> devices_;
/*! \brief map of GlobalVars to C Device API contexts */
Map<GlobalVar, tir::Var> device_contexts_;
/*! \brief input and output variables belonging to the main function signature */
@@ -779,7 +829,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
std::transform(input_vars_.begin(), input_vars_.end(), input_var_names.begin(),
[](Var input_var) -> String { return input_var->name_hint(); });
- ret.metadata = runtime::Metadata(input_var_names, devices_, return_sid_.size(),
+ ret.metadata = runtime::Metadata(input_var_names, ListDevices(), return_sid_.size(),
runtime::kTvmExecutorAot, mod_name);
return ret;
}
@@ -788,7 +838,12 @@ class AOTExecutorCodegen : public MixedModeVisitor {
* \brief Get list of devices found
* \return List of devices
*/
- Array<String> ListDevices() { return devices_; }
+ Array<String> ListDevices() {
+ std::vector<String> device_names(devices_.size());
+ std::transform(devices_.begin(), devices_.end(), device_names.begin(),
+ [](const auto& it) -> String { return it.first; });
+ return device_names;
+ }
}; // namespace backend
class AOTExecutorCodegenModule : public runtime::ModuleNode {
diff --git a/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.c b/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.c
index 6b7399b..8e50602 100644
--- a/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.c
+++ b/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.c
@@ -21,8 +21,9 @@
#include <ethosu_driver.h>
-int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t cms_data_size,
+int32_t TVMEthosULaunch(tvm_device_ethos_u_t* context, void* cms_data, size_t cms_data_size,
uint64_t* base_addrs, size_t* base_addrs_size, int num_tensors) {
+ struct ethosu_driver* driver = (struct ethosu_driver*)context;
int32_t result =
ethosu_invoke(driver, cms_data, cms_data_size, base_addrs, base_addrs_size, num_tensors);
@@ -32,3 +33,8 @@ int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t cms
}
return 0;
}
+
+int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context) {}
+int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context) {}
+int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context) {}
+int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context) {}
diff --git a/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.h b/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.h
index d62afc4..31d1755 100644
--- a/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.h
+++ b/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.h
@@ -24,7 +24,14 @@
#include <stddef.h>
#include <stdint.h>
-int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t cms_data_size,
+typedef void tvm_device_ethos_u_t;
+
+int32_t TVMEthosULaunch(tvm_device_ethos_u_t* resource_handle, void* cms_data, size_t cms_data_size,
uint64_t* base_addrs, size_t* base_addrs_size, int num_tensors);
+int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context);
+int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context);
+int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context);
+int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context);
+
#endif // TVM_RUNTIME_CONTRIB_ETHOSU_BARE_METAL_TVM_ETHOSU_RUNTIME_H_
diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py
index 94ecaba..e2bbb24 100644
--- a/tests/python/relay/aot/test_crt_aot.py
+++ b/tests/python/relay/aot/test_crt_aot.py
@@ -17,6 +17,7 @@
from collections import OrderedDict
import sys
+import re
import numpy as np
import pytest
@@ -693,5 +694,108 @@ def test_aot_codegen_backend_alloc_workspace_calls():
assert source.count("TVMBackendAllocWorkspace") == 3
+def test_device_api_hooks():
+ """Check for Device API hooks"""
+
+ # Ideally we should have a sample Target registered here
+ # but we're going to re-use this for now
+ pytest.importorskip("ethosu.vela")
+ import tensorflow as tf
+ import tflite.Model
+
+ from tests.python.contrib.test_ethosu import infra
+ from tvm.relay.op.contrib.ethosu import partition_for_ethosu
+
+ def create_tflite_graph():
+ tf.config.run_functions_eagerly(True)
+
+ class Model(tf.Module):
+ @tf.function
+ def tf_function(self, x):
+ return tf.nn.max_pool(x, [1, 2], [1, 2], "SAME")
+
+ def representative_dataset():
+ for _ in range(100):
+ data = np.random.rand(*tuple([1, 3, 4, 3]))
+ yield [data.astype(np.float32)]
+
+ model = Model()
+ concrete_func = model.tf_function.get_concrete_function(
+ tf.TensorSpec([1, 3, 4, 3], dtype=tf.float32)
+ )
+
+ converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
+ converter.representative_dataset = representative_dataset
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+ converter.inference_input_type = tf.int8
+ converter.inference_output_type = tf.int8
+ tflite_model = converter.convert()
+ return tflite_model
+
+ tflite_graph = create_tflite_graph()
+ tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+ relay_module, params = relay.frontend.from_tflite(
+ tflite_model,
+ shape_dict={"x": [1, 3, 4, 3]},
+ dtype_dict={"x": "int8"},
+ )
+ mod = partition_for_ethosu(relay_module, params)
+
+ # Generate reference data
+ input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
+
+ compiled_models = infra.build_source(
+ mod,
+ input_data,
+ output_data,
+ )
+ main_ir_module = list(compiled_models[0].executor_factory.lowered_ir_mods.values())[0]
+ main_func = main_ir_module["run_model"]
+
+ # Activate Device
+ assert (
+ str(main_func.body[0][0].value)
+ == "@tir.call_extern("
+ + '"TVMDeviceEthosUActivate",'
+ + " device_context_ethos_u: handle,"
+ + " dtype=int32)"
+ )
+ # Open Device
+ assert (
+ str(main_func.body[1].body.body[0][0][0].value)
+ == "@tir.call_extern("
+ + '"TVMDeviceEthosUOpen",'
+ + " device_context_ethos_u: handle,"
+ + " dtype=int32)"
+ )
+ # Device Call
+ assert (
+ str(main_func.body[1].body.body[0][0][1].value)
+ == "@tir.call_extern("
+ + '"tvmgen_default_ethos_u_main_0",'
+ + " input: handle, output: handle,"
+ + " device_context_ethos_u: handle,"
+ + " dtype=int32)"
+ )
+ # Close Device
+ assert (
+ str(main_func.body[1].body.body[0][0][2].value)
+ == "@tir.call_extern("
+ + '"TVMDeviceEthosUClose",'
+ + " device_context_ethos_u: handle,"
+ + " dtype=int32)"
+ )
+ # Deactivate Device
+ assert (
+ str(main_func.body[2][0].value)
+ == "@tir.call_extern("
+ + '"TVMDeviceEthosUDeactivate",'
+ + " device_context_ethos_u: handle,"
+ + " dtype=int32)"
+ )
+
+
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))