You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ar...@apache.org on 2022/06/14 22:28:30 UTC
[tvm] branch main updated: [microTVM][zephyr] Add support for host-driven AoT execution on zephyr (#11650)
This is an automated email from the ASF dual-hosted git repository.
areusch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5b3cef30f9 [microTVM][zephyr] Add support for host-driven AoT execution on zephyr (#11650)
5b3cef30f9 is described below
commit 5b3cef30f963a236205088848d7dc660a1f6c7fc
Author: Alan MacDonald <al...@users.noreply.github.com>
AuthorDate: Tue Jun 14 15:28:25 2022 -0700
[microTVM][zephyr] Add support for host-driven AoT execution on zephyr (#11650)
* - add support for host-driven AoT execution on zephyr;
- add initial version of reference counting to prevent python code from inadvertently freeing tensors during garbage collection;
- add support for numerical indices to host-driven AoT get_input();
- add two initial tests for host-driven AoT execution on zephyr;
- rename existing zephyr AoT exec. test;
* address PR feedback
* increase stack size to accommodate qemu_riscv64 stack usage
---
.../template_project/crt_config/crt_config.h | 2 +-
.../zephyr/template_project/microtvm_api_server.py | 2 +-
python/tvm/micro/session.py | 10 +-
python/tvm/runtime/ndarray.py | 2 +-
src/runtime/crt/aot_executor/aot_executor.c | 12 +-
.../crt/aot_executor_module/aot_executor_module.c | 30 +++-
src/runtime/crt/common/crt_runtime_api.c | 49 +++----
src/runtime/crt/common/ndarray.c | 26 +++-
src/runtime/crt/graph_executor/graph_executor.c | 4 +-
.../graph_executor_module/graph_executor_module.c | 13 +-
src/runtime/crt/host/main.cc | 3 -
.../tvm/runtime/crt/internal/common/ndarray.h | 8 ++
src/runtime/crt/microtvm_rpc_server/rpc_server.cc | 6 +
src/runtime/graph_executor/graph_executor.h | 2 +-
tests/micro/zephyr/conftest.py | 4 +-
tests/micro/zephyr/test_zephyr_aot_exec.py | 157 +++++++++++++++++++++
...r_aot.py => test_zephyr_aot_exec_standalone.py} | 0
17 files changed, 276 insertions(+), 54 deletions(-)
diff --git a/apps/microtvm/zephyr/template_project/crt_config/crt_config.h b/apps/microtvm/zephyr/template_project/crt_config/crt_config.h
index c3beaed522..3481d342a1 100644
--- a/apps/microtvm/zephyr/template_project/crt_config/crt_config.h
+++ b/apps/microtvm/zephyr/template_project/crt_config/crt_config.h
@@ -36,7 +36,7 @@
#define TVM_CRT_MAX_ARGS 10
/*! Size of the global function registry, in bytes. */
-#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 256
+#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 512
/*! Maximum number of registered modules. */
#define TVM_CRT_MAX_REGISTERED_MODULES 2
diff --git a/apps/microtvm/zephyr/template_project/microtvm_api_server.py b/apps/microtvm/zephyr/template_project/microtvm_api_server.py
index bcf9f78f4b..dad4cdf9d6 100644
--- a/apps/microtvm/zephyr/template_project/microtvm_api_server.py
+++ b/apps/microtvm/zephyr/template_project/microtvm_api_server.py
@@ -420,7 +420,7 @@ class Handler(server.ProjectAPIHandler):
API_SERVER_CRT_LIBS_TOKEN = "<API_SERVER_CRT_LIBS>"
CRT_LIBS_BY_PROJECT_TYPE = {
- "host_driven": "microtvm_rpc_server microtvm_rpc_common common",
+ "host_driven": "microtvm_rpc_server microtvm_rpc_common aot_executor_module aot_executor common",
"aot_demo": "memory microtvm_rpc_common common",
}
diff --git a/python/tvm/micro/session.py b/python/tvm/micro/session.py
index 4c38476207..967eaee629 100644
--- a/python/tvm/micro/session.py
+++ b/python/tvm/micro/session.py
@@ -39,7 +39,7 @@ except ImportError:
@register_error
class SessionTerminatedError(Exception):
- """Raised when a transport read operationd discovers that the remote session is terminated."""
+ """Raised when a transport read operation discovers that the remote session is terminated."""
class Session:
@@ -86,12 +86,18 @@ class Session:
self._rpc = None
self._graph_executor = None
+ self._enable_rpc_logger = False
self._exit_called = False
def get_system_lib(self):
return self._rpc.get_function("runtime.SystemLib")()
+ def create_aot_executor(self):
+ return self._rpc.get_function("tvm.aot_executor.create")(
+ self.get_system_lib(), self.device, "default"
+ )
+
def _wrap_transport_read(self, n, timeout_microsec):
try:
return self.transport.read(
@@ -133,7 +139,7 @@ class Session:
int(timeouts.session_start_timeout_sec * 1e6),
int(timeouts.session_established_timeout_sec * 1e6),
self._cleanup,
- False,
+ self._enable_rpc_logger,
)
)
self.device = self._rpc.cpu(0)
diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py
index 3d4764d616..9d3a3aff21 100644
--- a/python/tvm/runtime/ndarray.py
+++ b/python/tvm/runtime/ndarray.py
@@ -127,7 +127,7 @@ class NDArray(NDArrayBase):
raise TypeError("type %s not supported" % str(type(value)))
def copyfrom(self, source_array):
- """Perform an synchronize copy from the array.
+ """Perform a synchronous copy from the array.
Parameters
----------
diff --git a/src/runtime/crt/aot_executor/aot_executor.c b/src/runtime/crt/aot_executor/aot_executor.c
index 1360c40b0f..1724fabec4 100644
--- a/src/runtime/crt/aot_executor/aot_executor.c
+++ b/src/runtime/crt/aot_executor/aot_executor.c
@@ -173,21 +173,29 @@ int TVMAotExecutor_Init(TVMAotExecutor* executor, TVMModuleHandle module_handle,
for (i = 0; i < md->num_inputs; ++i) {
LOG_DEBUG("input allocate[%d]: %s\n", i, md->inputs[i].name);
+ TVMNDArray* array = &executor->args[arg_idx++];
+
status = TVMNDArray_Empty(md->inputs[i].num_shape, md->inputs[i].shape, md->inputs[i].dtype,
- executor->device, &executor->args[arg_idx++]);
+ executor->device, array);
if (status != 0) {
return status;
}
+
+ TVMNDArray_IncrementReference(array);
}
for (i = 0; i < md->num_outputs; ++i) {
LOG_DEBUG("output allocate[%d]: %s\n", i, md->outputs[i].name);
+ TVMNDArray* array = &executor->args[arg_idx++];
+
status = TVMNDArray_Empty(md->outputs[i].num_shape, md->outputs[i].shape, md->outputs[i].dtype,
- executor->device, &executor->args[arg_idx++]);
+ executor->device, array);
if (status != 0) {
return status;
}
+
+ TVMNDArray_IncrementReference(array);
}
for (i = 0; i < md->num_pools; ++i) {
diff --git a/src/runtime/crt/aot_executor_module/aot_executor_module.c b/src/runtime/crt/aot_executor_module/aot_executor_module.c
index e1dbd533a3..5dd11c3dbc 100644
--- a/src/runtime/crt/aot_executor_module/aot_executor_module.c
+++ b/src/runtime/crt/aot_executor_module/aot_executor_module.c
@@ -80,13 +80,27 @@ int32_t TVMAotExecutorModule_NotImplemented(TVMValue* args, int* tcodes, int nar
int32_t TVMAotExecutorModule_GetInput(TVMValue* args, int* tcodes, int nargs, TVMValue* ret_values,
int* ret_tcodes, void* resource_handle) {
- int index = TVMAotExecutor_GetInputIndex(aot_executor.executor, args[0].v_str);
+ int64_t index;
- if (index < 0) {
- return kTvmErrorExecutorModuleNoSuchInput;
+ if (tcodes[0] == kTVMArgInt) {
+ if (args[0].v_int64 > TVMAotExecutor_GetNumInputs(aot_executor.executor)) {
+ return kTvmErrorFunctionCallInvalidArg;
+ }
+
+ index = args[0].v_int64;
+ } else {
+ index = TVMAotExecutor_GetInputIndex(aot_executor.executor, args[0].v_str);
+
+ if (index < 0) {
+ return kTvmErrorExecutorModuleNoSuchInput;
+ }
}
- ret_values[0].v_handle = (void*)&aot_executor.executor->args[index].dl_tensor;
+ TVMNDArray* array = &aot_executor.executor->args[index];
+
+ TVMNDArray_IncrementReference(array);
+
+ ret_values[0].v_handle = (void*)(&array->dl_tensor);
ret_tcodes[0] = kTVMNDArrayHandle;
return 0;
@@ -103,9 +117,13 @@ int32_t TVMAotExecutorModule_GetOutput(TVMValue* args, int* tcodes, int nargs, T
}
// index past the input entries
- int64_t idx = args[0].v_int64 + TVMAotExecutor_GetNumInputs(aot_executor.executor);
+ int64_t index = args[0].v_int64 + TVMAotExecutor_GetNumInputs(aot_executor.executor);
+
+ TVMNDArray* array = &aot_executor.executor->args[index];
+
+ TVMNDArray_IncrementReference(array);
- ret_values[0].v_handle = (void*)&aot_executor.executor->args[idx].dl_tensor;
+ ret_values[0].v_handle = (void*)(&array->dl_tensor);
ret_tcodes[0] = kTVMNDArrayHandle;
return 0;
diff --git a/src/runtime/crt/common/crt_runtime_api.c b/src/runtime/crt/common/crt_runtime_api.c
index 31ab3e9a69..a8a17041f5 100644
--- a/src/runtime/crt/common/crt_runtime_api.c
+++ b/src/runtime/crt/common/crt_runtime_api.c
@@ -76,9 +76,9 @@ int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_
}
int TVMArrayFree(TVMArrayHandle handle) {
- TVMNDArray arr;
- arr.dl_tensor = *handle;
- return TVMNDArray_Release(&arr);
+ TVMNDArray* arr = (TVMNDArray*)handle;
+
+ return TVMNDArray_Release(arr);
}
int TVMDeviceAllocDataSpace(DLDevice dev, size_t nbytes, size_t alignment, DLDataType type_hint,
@@ -149,7 +149,7 @@ static const TVMModule* registered_modules[TVM_CRT_MAX_REGISTERED_MODULES];
/*! \brief Passed as `module_index` to EncodeFunctionHandle. */
static const tvm_module_index_t kGlobalFuncModuleIndex = TVM_CRT_MAX_REGISTERED_MODULES;
-/*! \brief Special module handle for retur values from RPCTimeEvaluator. */
+/*! \brief Special module handle for return values from RPCTimeEvaluator. */
static const tvm_module_index_t kTimeEvaluatorModuleIndex = 0x7fff;
static int DecodeModuleHandle(TVMModuleHandle handle, tvm_module_index_t* out_module_index) {
@@ -202,8 +202,8 @@ int TVMModFree(TVMModuleHandle mod) {
return 0;
}
-int SystemLibraryCreate(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val,
- int* ret_type_codes) {
+static int SystemLibraryCreate(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val,
+ int* ret_type_codes) {
const TVMModule* system_lib;
if (system_lib_handle == kTVMModuleHandleUninitialized) {
@@ -400,8 +400,22 @@ int RPCGetCRTMaxPacketSize(TVMValue* args, int* type_codes, int num_args, TVMVal
return 0;
}
-int TVMContribRandomFill(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val,
- int* ret_type_code);
+// Fill the tensor in args[0] with random data using TVMPlatformGenerateRandom.
+static int RandomFill(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val,
+ int* ret_type_code) {
+ if (num_args != 1) {
+ return kTvmErrorFunctionCallNumArguments;
+ }
+
+ if (type_codes[0] != kTVMDLTensorHandle) {
+ return kTvmErrorFunctionCallWrongArgType;
+ }
+
+ DLTensor* tensor = (DLTensor*)args[0].v_handle;
+ TVMNDArray arr = {*tensor, 0};
+ return TVMNDArray_RandomFill(&arr);
+}
+
tvm_crt_error_t TVMInitializeRuntime() {
int idx = 0;
tvm_crt_error_t error = kTvmErrorNoError;
@@ -440,7 +454,7 @@ tvm_crt_error_t TVMInitializeRuntime() {
}
if (error == kTvmErrorNoError) {
- error = TVMFuncRegisterGlobal("tvm.contrib.random.random_fill", &TVMContribRandomFill, 0);
+ error = TVMFuncRegisterGlobal("tvm.contrib.random.random_fill", &RandomFill, 0);
}
if (error != kTvmErrorNoError) {
@@ -590,20 +604,3 @@ __attribute__((weak)) tvm_crt_error_t TVMPlatformBeforeMeasurement() { return kT
// Default implementation, overridden by the platform runtime.
__attribute__((weak)) tvm_crt_error_t TVMPlatformAfterMeasurement() { return kTvmErrorNoError; }
-
-// Fill the tensor in args[0] with random data using TVMPlatformGenerateRandom.
-// Named to correspond with the analogous function in the C++ runtime.
-int TVMContribRandomFill(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val,
- int* ret_type_code) {
- if (num_args != 1) {
- return kTvmErrorFunctionCallNumArguments;
- }
-
- if (type_codes[0] != kTVMDLTensorHandle) {
- return kTvmErrorFunctionCallWrongArgType;
- }
-
- DLTensor* tensor = (DLTensor*)args[0].v_handle;
- TVMNDArray arr = {*tensor};
- return TVMNDArray_RandomFill(&arr);
-}
diff --git a/src/runtime/crt/common/ndarray.c b/src/runtime/crt/common/ndarray.c
index 16bde3227f..b0e869766b 100644
--- a/src/runtime/crt/common/ndarray.c
+++ b/src/runtime/crt/common/ndarray.c
@@ -30,8 +30,8 @@
#include "crt_config.h"
-int TVMNDArray_Create(int32_t ndim, const tvm_index_t* shape, DLDataType dtype, DLDevice dev,
- TVMNDArray* array) {
+static int Create(int32_t ndim, const tvm_index_t* shape, DLDataType dtype, DLDevice dev,
+ TVMNDArray* array) {
memset(array, 0, sizeof(TVMNDArray));
array->dl_tensor.ndim = ndim;
tvm_crt_error_t err;
@@ -58,7 +58,7 @@ int64_t TVMNDArray_DataSizeBytes(TVMNDArray* array) {
int TVMNDArray_Empty(int32_t ndim, const tvm_index_t* shape, DLDataType dtype, DLDevice dev,
TVMNDArray* array) {
- int status = TVMNDArray_Create(ndim, shape, dtype, dev, array);
+ int status = Create(ndim, shape, dtype, dev, array);
if (status != 0) {
return status;
}
@@ -132,7 +132,7 @@ int TVMNDArray_Load(TVMNDArray* ret, const char** strm) {
int TVMNDArray_CreateView(TVMNDArray* arr, const tvm_index_t* shape, int32_t ndim, DLDataType dtype,
TVMNDArray* array_view) {
- int status = TVMNDArray_Create(ndim, shape, dtype, arr->dl_tensor.device, array_view);
+ int status = Create(ndim, shape, dtype, arr->dl_tensor.device, array_view);
if (status != 0) {
return status;
}
@@ -149,21 +149,35 @@ int TVMNDArray_RandomFill(TVMNDArray* arr) {
return TVMPlatformGenerateRandom(arr->dl_tensor.data, (size_t)num_bytes);
}
+void TVMNDArray_IncrementReference(TVMNDArray* arr) { arr->reference_count++; }
+
+uint32_t TVMNDArray_DecrementReference(TVMNDArray* arr) {
+ if (arr->reference_count > 0) {
+ arr->reference_count--;
+ }
+
+ return arr->reference_count;
+}
+
int TVMNDArray_Release(TVMNDArray* arr) {
tvm_crt_error_t err;
DLDevice dev = {kDLCPU, 0};
+ if (TVMNDArray_DecrementReference(arr) > 0) {
+ return 0;
+ }
+
err = TVMPlatformMemoryFree(arr->dl_tensor.data, dev);
if (err != kTvmErrorNoError) {
return err;
}
+ arr->dl_tensor.data = NULL;
- arr->dl_tensor.data = 0;
err = TVMPlatformMemoryFree(arr->dl_tensor.shape, dev);
if (err != kTvmErrorNoError) {
return err;
}
+ arr->dl_tensor.shape = NULL;
- arr->dl_tensor.shape = 0;
return 0;
}
diff --git a/src/runtime/crt/graph_executor/graph_executor.c b/src/runtime/crt/graph_executor/graph_executor.c
index 3fea408d97..395a343ccb 100644
--- a/src/runtime/crt/graph_executor/graph_executor.c
+++ b/src/runtime/crt/graph_executor/graph_executor.c
@@ -1014,7 +1014,7 @@ int TVMGraphExecutor_SetupStorage(TVMGraphExecutor* executor) {
executor->storage_pool_count++;
}
- // Assign the pooled entries. A unified memory pool is used to simplifiy
+ // Assign the pooled entries. A unified memory pool is used to simplify
// memory assignment for each node entry. The allocated memory on each device
// is mapped to this pool.
executor->data_entry_count = executor->node_row_ptr[executor->node_row_ptr_count - 1];
@@ -1031,6 +1031,8 @@ int TVMGraphExecutor_SetupStorage(TVMGraphExecutor* executor) {
attrs->shape + idx * TVM_CRT_MAX_NDIM, attrs->ndim[idx],
vtype[idx], &executor->data_entry[idx]);
CHECK_EQ(status, 0, "fail to create for node with idx=%d, storage_id=%u\n", idx, storage_id);
+
+ TVMNDArray_IncrementReference(&executor->data_entry[idx]);
}
// Release memory
diff --git a/src/runtime/crt/graph_executor_module/graph_executor_module.c b/src/runtime/crt/graph_executor_module/graph_executor_module.c
index 0ae12f5a9e..559b6896a5 100644
--- a/src/runtime/crt/graph_executor_module/graph_executor_module.c
+++ b/src/runtime/crt/graph_executor_module/graph_executor_module.c
@@ -95,7 +95,12 @@ int32_t TVMGraphExecutorModule_GetInput(TVMValue* args, int* tcodes, int nargs,
uint32_t eid = TVMGraphExecutor_GetEntryId(graph_executor.executor,
graph_executor.executor->input_nodes[index], 0);
- ret_values[0].v_handle = (void*)&graph_executor.executor->data_entry[eid].dl_tensor;
+
+ TVMNDArray* array = &graph_executor.executor->data_entry[eid];
+
+ TVMNDArray_IncrementReference(array);
+
+ ret_values[0].v_handle = (void*)(&array->dl_tensor);
ret_tcodes[0] = kTVMNDArrayHandle;
return 0;
}
@@ -158,7 +163,11 @@ int32_t TVMGraphExecutorModule_GetOutput(TVMValue* args, int* tcodes, int nargs,
uint32_t index = graph_executor.executor->outputs[output_index].index;
uint32_t eid = TVMGraphExecutor_GetEntryId(graph_executor.executor, nid, index);
- ret_values[0].v_handle = (void*)&(graph_executor.executor->data_entry[eid].dl_tensor);
+ TVMNDArray* array = &graph_executor.executor->data_entry[eid];
+
+ TVMNDArray_IncrementReference(array);
+
+ ret_values[0].v_handle = (void*)(&array->dl_tensor);
ret_tcodes[0] = kTVMNDArrayHandle;
return 0;
}
diff --git a/src/runtime/crt/host/main.cc b/src/runtime/crt/host/main.cc
index bf4a98569e..d8fa95fe23 100644
--- a/src/runtime/crt/host/main.cc
+++ b/src/runtime/crt/host/main.cc
@@ -139,9 +139,6 @@ int main(int argc, char** argv) {
"failed to register GraphExecutor TVMModule");
#endif
- CHECK_EQ(TVMAotExecutorModule_Register(), kTvmErrorNoError,
- "failed to register AoT Executor TVMModule");
-
int error = TVMFuncRegisterGlobal("tvm.testing.reset_server",
(TVMFunctionHandle)&testonly_reset_server, 0);
if (error) {
diff --git a/src/runtime/crt/include/tvm/runtime/crt/internal/common/ndarray.h b/src/runtime/crt/include/tvm/runtime/crt/internal/common/ndarray.h
index e5869ed2a3..0162c6eb4d 100644
--- a/src/runtime/crt/include/tvm/runtime/crt/internal/common/ndarray.h
+++ b/src/runtime/crt/include/tvm/runtime/crt/internal/common/ndarray.h
@@ -38,7 +38,11 @@ static const uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F;
static const uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;
typedef struct TVMNDArray {
+ /*! \brief the actual tensor in DLPack format. NOTE: this must be first element in struct */
DLTensor dl_tensor;
+
+ /*! \brief count of references to TVMNDArray to avoid early freeing by host */
+ uint32_t reference_count;
} TVMNDArray;
int TVMNDArray_Create(int32_t ndim, const tvm_index_t* shape, DLDataType dtype, DLDevice dev,
@@ -56,6 +60,10 @@ int TVMNDArray_Load(TVMNDArray* ret, const char** strm);
int TVMNDArray_CreateView(TVMNDArray* arr, const tvm_index_t* shape, int32_t ndim, DLDataType dtype,
TVMNDArray* array_view);
+void TVMNDArray_IncrementReference(TVMNDArray* arr);
+
+uint32_t TVMNDArray_DecrementReference(TVMNDArray* arr);
+
int TVMNDArray_Release(TVMNDArray* arr);
#endif // TVM_RUNTIME_CRT_INCLUDE_TVM_RUNTIME_CRT_INTERNAL_COMMON_NDARRAY_H_
diff --git a/src/runtime/crt/microtvm_rpc_server/rpc_server.cc b/src/runtime/crt/microtvm_rpc_server/rpc_server.cc
index b7bae243ec..1e5f625998 100644
--- a/src/runtime/crt/microtvm_rpc_server/rpc_server.cc
+++ b/src/runtime/crt/microtvm_rpc_server/rpc_server.cc
@@ -33,6 +33,7 @@
#define DMLC_CMAKE_LITTLE_ENDIAN DMLC_IO_USE_LITTLE_ENDIAN
#define DMLC_LITTLE_ENDIAN 1
#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/crt/aot_executor_module.h>
#include <tvm/runtime/crt/crt.h>
#include <tvm/runtime/crt/logging.h>
#include <tvm/runtime/crt/microtvm_rpc_server.h>
@@ -207,6 +208,11 @@ microtvm_rpc_server_t MicroTVMRpcServerInit(microtvm_rpc_channel_write_t write_f
TVMPlatformAbort(err);
}
+ err = TVMAotExecutorModule_Register();
+ if (err != kTvmErrorNoError) {
+ TVMPlatformAbort(err);
+ }
+
DLDevice dev = {kDLCPU, 0};
void* receive_buffer_memory;
err = TVMPlatformMemoryAllocate(TVM_CRT_MAX_PACKET_SIZE_BYTES, dev, &receive_buffer_memory);
diff --git a/src/runtime/graph_executor/graph_executor.h b/src/runtime/graph_executor/graph_executor.h
index 25b01a253c..2564f5b0d9 100644
--- a/src/runtime/graph_executor/graph_executor.h
+++ b/src/runtime/graph_executor/graph_executor.h
@@ -61,7 +61,7 @@ struct TVMOpParam {
/*!
* \brief Tiny graph executor.
*
- * This runtime can be acccesibly in various language via
+ * This runtime can be accessible in various languages via
* TVM runtime PackedFunc API.
*/
class TVM_DLL GraphExecutor : public ModuleNode {
diff --git a/tests/micro/zephyr/conftest.py b/tests/micro/zephyr/conftest.py
index 997237d370..c4de48a0a4 100644
--- a/tests/micro/zephyr/conftest.py
+++ b/tests/micro/zephyr/conftest.py
@@ -30,7 +30,7 @@ def pytest_addoption(parser):
"--zephyr-board",
required=True,
choices=test_utils.ZEPHYR_BOARDS.keys(),
- help=("Zephyr board for test."),
+ help="Zephyr board for test.",
)
parser.addoption(
"--west-cmd", default="west", help="Path to `west` command for flashing device."
@@ -92,5 +92,5 @@ def skip_by_board(request, board):
def pytest_configure(config):
config.addinivalue_line(
"markers",
- "skip_by_board(board): skip test for the given board",
+ "skip_boards(board): skip test for the given board",
)
diff --git a/tests/micro/zephyr/test_zephyr_aot_exec.py b/tests/micro/zephyr/test_zephyr_aot_exec.py
new file mode 100644
index 0000000000..1add0063bc
--- /dev/null
+++ b/tests/micro/zephyr/test_zephyr_aot_exec.py
@@ -0,0 +1,157 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import logging
+import os
+import pathlib
+import sys
+import logging
+
+import pytest
+import numpy as np
+
+import onnx
+from PIL import Image
+
+import tvm
+import tvm.testing
+import tvm.relay as relay
+from tvm.relay.backend import Executor, Runtime
+from tvm.relay.testing import byoc
+from tvm.contrib import utils
+from tvm.micro.testing.utils import check_tune_log
+from tvm._ffi import get_global_func, register_func
+
+import test_utils
+
+
+def _make_session(temp_dir, zephyr_board, west_cmd, mod, build_config):
+ config_main_stack_size = None
+ if test_utils.qemu_boards(zephyr_board):
+ # fyi: qemu_riscv64 seems to be the greediest stack user
+ config_main_stack_size = 4096
+
+ project_options = {
+ "project_type": "host_driven",
+ "west_cmd": west_cmd,
+ "verbose": bool(build_config.get("debug")),
+ "zephyr_board": zephyr_board,
+ }
+ if config_main_stack_size is not None:
+ project_options["config_main_stack_size"] = config_main_stack_size
+
+ project = tvm.micro.generate_project(
+ str(test_utils.TEMPLATE_PROJECT_DIR),
+ mod,
+ temp_dir / "project",
+ project_options,
+ )
+ project.build()
+ project.flash()
+ return tvm.micro.Session(project.transport())
+
+
+@tvm.testing.requires_micro
+def test_relay(temp_dir, board, west_cmd, tvm_debug):
+ """Testing a simple relay graph"""
+
+ model = test_utils.ZEPHYR_BOARDS[board]
+ build_config = {"debug": tvm_debug}
+ shape = (10,)
+ dtype = "int8"
+
+ # Construct Relay program.
+ x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype))
+ xx = relay.multiply(x, x)
+ z = relay.add(xx, relay.const(np.ones(shape=shape, dtype=dtype)))
+ func = relay.Function([x], z)
+ ir_mod = tvm.IRModule.from_expr(func)
+
+ runtime = Runtime("crt", {"system-lib": True})
+ executor = Executor("aot")
+ target = tvm.target.target.micro(model)
+ with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
+ mod = tvm.relay.build(ir_mod, target=target, runtime=runtime, executor=executor)
+
+ with _make_session(temp_dir, board, west_cmd, mod, build_config) as session:
+
+ aot_executor = tvm.runtime.executor.aot_executor.AotModule(session.create_aot_executor())
+
+ x_in = np.random.randint(10, size=shape[0], dtype=dtype)
+ aot_executor.run(x=x_in)
+ result = aot_executor.get_output(0).numpy()
+ tvm.testing.assert_allclose(aot_executor.get_input(0).numpy(), x_in)
+ tvm.testing.assert_allclose(result, x_in * x_in + 1)
+
+
+@tvm.testing.requires_micro
+def test_aot_executor(temp_dir, board, west_cmd, tvm_debug):
+ """Test use of the AOT executor with microTVM."""
+
+ model = test_utils.ZEPHYR_BOARDS[board]
+ build_config = {"debug": tvm_debug}
+ shape = (10,)
+ dtype = "int8"
+
+ print("test_relay: construct relay program\n")
+
+ # Construct Relay program.
+ relay_mod = tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), uint8]) {
+ %0 = %a + %b;
+ %0
+ }"""
+ )
+
+ runtime = Runtime("crt", {"system-lib": True})
+ executor = Executor("aot")
+ target = tvm.target.target.micro(model)
+ with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
+ mod = tvm.relay.build(relay_mod, target=target, runtime=runtime, executor=executor)
+
+ def do_test():
+
+ aot_executor = tvm.runtime.executor.aot_executor.AotModule(session.create_aot_executor())
+
+ assert aot_executor.get_input_index("a") == 0
+ assert aot_executor.get_input_index("b") == 1
+
+ assert aot_executor.get_num_inputs() == 2
+ assert aot_executor.get_num_outputs() == 1
+
+ A_np = np.array([[2, 3]], dtype="uint8")
+ B_np = np.array([[4, 7]], dtype="uint8")
+
+ A_data = aot_executor.get_input("a").copyfrom(A_np)
+ B_data = aot_executor.get_input("b").copyfrom(B_np)
+
+ aot_executor.run()
+
+ out = aot_executor.get_output(0)
+ assert (out.numpy() == np.array([6, 10])).all()
+
+ B_np_new = np.array([[5, 8]])
+ aot_executor.set_input("b", B_np_new)
+ assert (B_data.numpy() == B_np_new).all()
+
+ with _make_session(temp_dir, board, west_cmd, mod, build_config) as session:
+ do_test()
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/micro/zephyr/test_zephyr_aot.py b/tests/micro/zephyr/test_zephyr_aot_exec_standalone.py
similarity index 100%
rename from tests/micro/zephyr/test_zephyr_aot.py
rename to tests/micro/zephyr/test_zephyr_aot_exec_standalone.py