You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/12/04 07:35:55 UTC
[tvm] branch main updated: [microNPU] Refactor codegen tests (#9623)
This is an automated email from the ASF dual-hosted git repository.
manupa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 643d991 [microNPU] Refactor codegen tests (#9623)
643d991 is described below
commit 643d9912189bf2e8cbc9279a01d82d52d6691df9
Author: Matthew Barrett <55...@users.noreply.github.com>
AuthorDate: Sat Dec 4 07:35:19 2021 +0000
[microNPU] Refactor codegen tests (#9623)
* [microNPU] Refactor codegen tests
Change-Id: I9c08520c9e03eb3fc32bd911b56c95981e851b4b
* Fix params
Change-Id: I8cea69ed3824c3a0417bb67abbabce460c17c4c6
* Remove prints
Change-Id: Iadf048e9590e724d73c2adac51bbe303de6f59a8
* Address review comments
Change-Id: I56d647d86e3d495abe38b13cca349a71ec81cf4d
---
tests/python/contrib/test_ethosu/test_codegen.py | 787 ++++++++---------------
1 file changed, 268 insertions(+), 519 deletions(-)
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index afd635d..42695db 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -24,7 +24,10 @@ import tflite.Model
import tvm
import tensorflow as tf
from tvm import relay
+from tvm.relay.expr_functor import ExprMutator
+from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.relay.backend.contrib.ethosu import util
+from tvm.relay.backend.contrib.ethosu import preprocess
from tvm.relay.op.contrib.ethosu import partition_for_ethosu
from tests.python.relay.aot.aot_test_utils import generate_ref_data
@@ -166,89 +169,146 @@ def test_ethosu_conv2d(accel_type):
infra.verify_source(compiled_models, accel_type)
-@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
-@pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3), (1, 23, 32, 7)])
-@pytest.mark.parametrize(
- "kernel_shape, activation",
- [((3, 3), "relu"), ((1, 2), None)],
-)
-@pytest.mark.parametrize("padding", ["SAME", "VALID"])
-@pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 2)), ((3, 2), (1, 1))])
-def test_tflite_depthwise_conv2d(
- accel_type,
- ifm_shape,
- kernel_shape,
- padding,
- strides,
- dilation,
- activation,
+def _compare_ethosu_with_reference(
+ mod, input_data, output_data, accel_type, output_tolerance=0, print_cmm=False
):
- dtype = "int8"
+ compiled_models = infra.build_source(
+ mod,
+ input_data,
+ output_data,
+ accel_type,
+ output_tolerance=output_tolerance,
+ )
- def create_tflite_graph():
- class Model(tf.Module):
- @tf.function
- def depthwise_conv2d(self, x):
- weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1]
- weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
- # The input strides to the TensorFlow API needs to be of shape 1x4
- tf_strides = [1, strides[0], strides[1], 1]
- op = tf.nn.depthwise_conv2d(
- x, weight, strides=tf_strides, padding=padding, dilations=dilation
- )
- if activation:
- op = tf.nn.relu(op)
- return op
+ # Assumes only two runtime.Modules are created -- i.e. single offload module
+ ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
- model = Model()
- concrete_func = model.depthwise_conv2d.get_concrete_function(
- tf.TensorSpec(ifm_shape, dtype=tf.float32)
- )
+ # Verify generated C source
+ if print_cmm:
+ get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
+ compilation_artifacts = get_artifacts(ethosu_module)
+ cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
+ infra.print_payload(cmms)
- # Convert the model
- def representative_dataset():
- for _ in range(100):
- data = np.random.rand(*tuple(ifm_shape))
- yield [data.astype(np.float32)]
+ infra.verify_source(compiled_models, accel_type)
- converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
- converter.optimizations = [tf.lite.Optimize.DEFAULT]
- converter.representative_dataset = representative_dataset
- converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
- converter.inference_input_type = tf.int8
- converter.inference_output_type = tf.int8
- tflite_model = converter.convert()
- return tflite_model
- tflite_graph = create_tflite_graph()
+def _compare_tvm_with_tflite(
+ tf_func, shapes, accel_type, ranges=None, output_tolerance=0, print_cmm=False
+):
+ tensor_specs = [tf.TensorSpec(shape, dtype=tf.float32) for shape in shapes]
+ if not ranges:
+ ranges = [(0, 1) for _ in shapes]
+ concrete_func = tf_func.get_concrete_function(*tensor_specs)
+
+ # Convert the model
+ def representative_dataset():
+ for _ in range(100):
+ inputs = []
+ for i, shape in enumerate(shapes):
+ data = np.random.uniform(
+ low=ranges[i][0], high=ranges[i][1], size=tuple(shape)
+ ).astype("float32")
+ inputs.append(data)
+
+ yield inputs
+
+ converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
+ converter.representative_dataset = representative_dataset
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+ converter.inference_input_type = tf.int8
+ converter.inference_output_type = tf.int8
+ tflite_graph = converter.convert()
+
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
- relay_module, params = relay.frontend.from_tflite(
- tflite_model,
- shape_dict={"input": ifm_shape},
- dtype_dict={"input": dtype},
- )
+ relay_module, params = relay.frontend.from_tflite(tflite_model)
mod = partition_for_ethosu(relay_module, params)
# Generate reference data
input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
- compiled_models = infra.build_source(
+ _compare_ethosu_with_reference(
mod,
input_data,
output_data,
accel_type,
+ output_tolerance=output_tolerance,
+ print_cmm=print_cmm,
)
- # Assumes only two runtime.Modules are created -- i.e. single offload module
- ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
- # Verify generated C source
- get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
- compilation_artifacts = get_artifacts(ethosu_module)
- cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
- infra.print_payload(cmms)
- infra.verify_source(compiled_models, accel_type)
+class EthosUAnnotator(ExprMutator):
+ """Annotate entire graph for Ethos-U offload"""
+
+ def __init__(self):
+ super(EthosUAnnotator, self).__init__()
+ # Compiler name attached to every compiler_begin/compiler_end annotation.
+ self.compiler = "ethos-u"
+ # True only until the first call node is visited; that call is the one
+ # that receives the compiler_end annotation.
+ self.last_call = True
+
+ def visit_call(self, call):
+ """Rebuild `call` with annotated arguments and return the new call.
+
+ Arguments that visit to a relay Var are wrapped in compiler_begin. The
+ first call visited by this annotator (presumably the outermost call,
+ i.e. the graph output) is additionally wrapped in compiler_end.
+ """
+ # Remember whether this is the first call visited, then clear the flag
+ # before recursing so that nested calls do not also get compiler_end.
+ curr_last = self.last_call
+ self.last_call = False
+
+ params = []
+ for arg in call.args:
+ param = super().visit(arg)
+ # Constants are already wrapped by visit_constant; only plain
+ # variables need compiler_begin added here.
+ if isinstance(param, relay.expr.Var):
+ param = compiler_begin(param, self.compiler)
+ params.append(param)
+
+ new_call = relay.Call(call.op, params, call.attrs)
+ if curr_last:
+ new_call = compiler_end(new_call, self.compiler)
+ return new_call
+
+ def visit_constant(self, constant):
+ """Return `constant` wrapped in a compiler_begin annotation."""
+ new_constant = compiler_begin(constant, self.compiler)
+ return new_constant
+
+
+def _create_ethosu_partition(mod):
+ """Annotate and partition `mod` for "ethos-u" and return the resulting module.
+
+ Note that mod["main"] is overwritten in place with the annotated function
+ before the transform passes are applied.
+ """
+ # Wrap the main function with compiler_begin/compiler_end annotations.
+ mod["main"] = EthosUAnnotator().visit(mod["main"])
+ # Merge the annotated regions, then split them out into partitioned
+ # functions; InferType is re-run after each structural change.
+ mod = relay.transform.MergeCompilerRegions()(mod)
+ mod = relay.transform.InferType()(mod)
+ mod = relay.transform.PartitionGraph()(mod)
+ mod = relay.transform.InferType()(mod)
+ # Ethos-U specific preprocessing of the external function's inputs/outputs.
+ mod = preprocess.preprocess_ext_io()(mod)
+ return mod
+
+
+@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
+@pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3), (1, 23, 32, 7)])
+@pytest.mark.parametrize(
+    "kernel_shape, activation_function",
+    [((3, 3), "RELU"), ((1, 2), "NONE")],
+)
+@pytest.mark.parametrize("padding", ["SAME", "VALID"])
+@pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 2)), ((3, 2), (1, 1))])
+def test_tflite_depthwise_conv2d(
+    accel_type,
+    ifm_shape,
+    kernel_shape,
+    padding,
+    strides,
+    dilation,
+    activation_function,
+):
+    """Build a TF depthwise conv2d (optionally followed by ReLU) and hand it to
+    _compare_tvm_with_tflite together with the input shape and accel_type."""
+
+    @tf.function
+    def depthwise_conv2d(x):
+        weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1]
+        weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+        # The input strides to the TensorFlow API need to be of shape 1x4
+        tf_strides = [1, strides[0], strides[1], 1]
+        op = tf.nn.depthwise_conv2d(
+            x, weight, strides=tf_strides, padding=padding, dilations=dilation
+        )
+        # Compare against "RELU" explicitly: "NONE" is a non-empty string and
+        # therefore truthy, so a bare truthiness check would always apply ReLU.
+        if activation_function == "RELU":
+            op = tf.nn.relu(op)
+        return op
+
+    _compare_tvm_with_tflite(depthwise_conv2d, [ifm_shape], accel_type)
@pytest.mark.parametrize(
@@ -270,69 +330,17 @@ def test_ethosu_pooling(
activation_function,
padding,
):
- dtype = "int8"
+ @tf.function
+ def pooling(x):
+ if pooling_type == "MAX":
+ op = tf.nn.max_pool(x, pool_shape, strides, padding)
+ elif pooling_type == "AVG":
+ op = tf.nn.avg_pool(x, pool_shape, strides, padding)
+ if activation_function == "RELU":
+ op = tf.nn.relu(op)
+ return op
- def create_tflite_graph():
- class Model(tf.Module):
- @tf.function
- def tf_function(self, x):
- if pooling_type == "MAX":
- op = tf.nn.max_pool(x, pool_shape, strides, padding)
- elif pooling_type == "AVG":
- op = tf.nn.avg_pool(x, pool_shape, strides, padding)
- if activation_function == "RELU":
- op = tf.nn.relu(op)
- return op
-
- model = Model()
- concrete_func = model.tf_function.get_concrete_function(
- tf.TensorSpec(ifm_shape, dtype=tf.float32)
- )
-
- # Convert the model
- def representative_dataset():
- for _ in range(100):
- data = np.random.rand(*tuple(ifm_shape))
- yield [data.astype(np.float32)]
-
- converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
- converter.optimizations = [tf.lite.Optimize.DEFAULT]
- converter.representative_dataset = representative_dataset
- converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
- converter.inference_input_type = tf.int8
- converter.inference_output_type = tf.int8
- tflite_model = converter.convert()
- return tflite_model
-
- tflite_graph = create_tflite_graph()
- tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
-
- relay_module, params = relay.frontend.from_tflite(
- tflite_model,
- shape_dict={"x": ifm_shape},
- dtype_dict={"x": dtype},
- )
- mod = partition_for_ethosu(relay_module, params)
-
- # Generate reference data
- input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
-
- compiled_models = infra.build_source(
- mod,
- input_data,
- output_data,
- accel_type,
- )
-
- # Assumes only two runtime.Modules are created -- i.e. single offload module
- ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
-
- # Verify generated C source
- get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
- compilation_artifacts = get_artifacts(ethosu_module)
- cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
- infra.print_payload(cmms)
- infra.verify_source(compiled_models, accel_type)
+ _compare_tvm_with_tflite(pooling, [ifm_shape], accel_type)
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -354,78 +362,30 @@ def test_ethosu_binary_elementwise(
ifm2_shape,
activation_function,
):
- dtype = "int8"
-
- def create_tflite_graph():
- class Model(tf.Module):
- @tf.function
- def tf_function(self, lhs, rhs):
- if operator_type == "ADD":
- op = tf.math.add(lhs, rhs)
- elif operator_type == "SUB":
- op = tf.math.subtract(lhs, rhs)
- elif operator_type == "MUL":
- op = tf.math.multiply(lhs, rhs)
- elif operator_type == "MIN":
- op = tf.math.minimum(lhs, rhs)
- elif operator_type == "MAX":
- op = tf.math.maximum(lhs, rhs)
- if activation_function == "RELU":
- op = tf.nn.relu(op)
- return op
-
- model = Model()
- concrete_func = model.tf_function.get_concrete_function(
- tf.TensorSpec(ifm_shape, dtype=tf.float32), tf.TensorSpec(ifm2_shape, dtype=tf.float32)
- )
-
- # Convert the model
- def representative_dataset():
- for _ in range(100):
- data = np.random.rand(*tuple(ifm_shape))
- data2 = np.random.rand(*tuple(ifm2_shape)) * 2
- yield [data.astype(np.float32), data2.astype(np.float32)]
-
- converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
- converter.optimizations = [tf.lite.Optimize.DEFAULT]
- converter.representative_dataset = representative_dataset
- converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
- converter.inference_input_type = tf.int8
- converter.inference_output_type = tf.int8
- tflite_model = converter.convert()
- return tflite_model
-
- tflite_graph = create_tflite_graph()
- tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
-
- mod, params = relay.frontend.from_tflite(
- tflite_model,
- shape_dict={"ifm": ifm_shape, "ifm2": ifm2_shape},
- dtype_dict={"ifm": dtype, "ifm2": dtype},
- )
- mod = partition_for_ethosu(mod, params)
-
- # Generate reference data
- input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
-
- compiled_models = infra.build_source(
- mod,
- input_data,
- output_data,
- accel_type,
+ @tf.function
+ def binary_elementwise(lhs, rhs):
+ if operator_type == "ADD":
+ op = tf.math.add(lhs, rhs)
+ elif operator_type == "SUB":
+ op = tf.math.subtract(lhs, rhs)
+ elif operator_type == "MUL":
+ op = tf.math.multiply(lhs, rhs)
+ elif operator_type == "MIN":
+ op = tf.math.minimum(lhs, rhs)
+ elif operator_type == "MAX":
+ op = tf.math.maximum(lhs, rhs)
+ if activation_function == "RELU":
+ op = tf.nn.relu(op)
+ return op
+
+ _compare_tvm_with_tflite(
+ binary_elementwise,
+ shapes=[ifm_shape, ifm2_shape],
+ ranges=[(0, 1), (0, 2)],
+ accel_type=accel_type,
output_tolerance=1 if operator_type == "MAX" else 0,
)
- # Assumes only two runtime.Modules are created -- i.e. single offload module
- ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
-
- # Verify generated C source
- get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
- compilation_artifacts = get_artifacts(ethosu_module)
- cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
- infra.print_payload(cmms)
- infra.verify_source(compiled_models, accel_type)
-
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize(
@@ -441,66 +401,17 @@ def test_binary_add_with_non_4d_shapes(
ifm_shape,
ifm2_shape,
):
- dtype = "int8"
-
- def create_tflite_graph():
- class Model(tf.Module):
- @tf.function
- def tf_function(self, lhs, rhs):
- return tf.math.add(lhs, rhs)
-
- model = Model()
- concrete_func = model.tf_function.get_concrete_function(
- tf.TensorSpec(ifm_shape, dtype=tf.float32), tf.TensorSpec(ifm2_shape, dtype=tf.float32)
- )
-
- # Convert the model
- def representative_dataset():
- for _ in range(100):
- data = np.random.rand(*tuple(ifm_shape))
- data2 = np.random.rand(*tuple(ifm2_shape)) * 2
- yield [data.astype(np.float32), data2.astype(np.float32)]
-
- converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
- converter.optimizations = [tf.lite.Optimize.DEFAULT]
- converter.representative_dataset = representative_dataset
- converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
- converter.inference_input_type = tf.int8
- converter.inference_output_type = tf.int8
- tflite_model = converter.convert()
- return tflite_model
-
- tflite_graph = create_tflite_graph()
- tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
-
- mod, params = relay.frontend.from_tflite(
- tflite_model,
- shape_dict={"ifm": ifm_shape, "ifm2": ifm2_shape},
- dtype_dict={"ifm": dtype, "ifm2": dtype},
- )
- mod = partition_for_ethosu(mod, params)
-
- # Generate reference data
- input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
-
- compiled_models = infra.build_source(
- mod,
- input_data,
- output_data,
- accel_type,
- output_tolerance=0,
+ @tf.function
+ def binary_elementwise(lhs, rhs):
+ return tf.math.add(lhs, rhs)
+
+ _compare_tvm_with_tflite(
+ binary_elementwise,
+ shapes=[ifm_shape, ifm2_shape],
+ ranges=[(0, 1), (0, 2)],
+ accel_type=accel_type,
)
- # Assumes only two runtime.Modules are created -- i.e. single offload module
- ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
-
- # Verify generated C source
- get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
- compilation_artifacts = get_artifacts(ethosu_module)
- cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
- infra.print_payload(cmms)
- infra.verify_source(compiled_models, accel_type)
-
@pytest.mark.parametrize(
"accel_type",
@@ -621,34 +532,19 @@ def test_binary_add_from_constant_scalar(accel_type):
relay.const(1.0, dtype="float32"),
relay.const(0, dtype="int32"),
)
- func = relay.Function(relay.analysis.free_vars(add), add)
- return tvm.IRModule.from_expr(func)
+ return tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(add), add))
- mod = create_relay_graph()
- partitioned_mod = partition_for_ethosu(mod)
+ cpu_mod = create_relay_graph()
+ ethosu_mod = partition_for_ethosu(cpu_mod)
# Generate reference data
input_data = {"input": np.random.randint(low=0, high=255, size=ifm_shape, dtype=dtype)}
- output_data = generate_ref_data(mod, input_data)
+ output_data = generate_ref_data(cpu_mod, input_data)
- compiled_models = infra.build_source(
- partitioned_mod,
- input_data,
- output_data,
- accel_type,
- output_tolerance=0,
+ _compare_ethosu_with_reference(
+ ethosu_mod, input_data, output_data, accel_type, output_tolerance=0
)
- # Assumes only two runtime.Modules are created -- i.e. single offload module
- ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
-
- # Verify generated C source
- get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
- compilation_artifacts = get_artifacts(ethosu_module)
- cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
- infra.print_payload(cmms)
- infra.verify_source(compiled_models, accel_type)
-
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize(
@@ -670,13 +566,9 @@ def test_ethosu_left_shift_binary_elemwise(
ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype)
ifm2 = relay.var("ifm2", shape=ifm2_shape, dtype=dtype)
c1 = relay.left_shift(ifm, ifm2)
- f = relay.Function([ifm, ifm2], c1)
- mod = tvm.IRModule()
- mod["main"] = f
- return mod
+ return tvm.IRModule.from_expr(relay.Function([ifm, ifm2], c1))
- relay_mod = create_model()
- mod = partition_for_ethosu(relay_mod)
+ cpu_mod = create_model()
# Generate reference data
in_min, in_max = util.get_range_for_dtype_str(dtype)
@@ -684,25 +576,13 @@ def test_ethosu_left_shift_binary_elemwise(
"ifm": np.random.randint(in_min, high=in_max, size=ifm_shape, dtype=dtype),
"ifm2": np.random.randint(0, high=32, size=ifm2_shape, dtype=dtype),
}
- output_data = generate_ref_data(relay_mod, input_data)
+ output_data = generate_ref_data(cpu_mod, input_data)
+ ethosu_mod = partition_for_ethosu(cpu_mod)
- compiled_models = infra.build_source(
- mod,
- input_data,
- output_data,
- accel_type,
+ _compare_ethosu_with_reference(
+ ethosu_mod, input_data, output_data, accel_type, output_tolerance=0
)
- # Assumes only two runtime.Modules are created -- i.e. single offload module
- ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
-
- # Verify generated C source
- get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
- compilation_artifacts = get_artifacts(ethosu_module)
- cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
- infra.print_payload(cmms)
- infra.verify_source(compiled_models, accel_type)
-
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize(
@@ -719,50 +599,33 @@ def test_ethosu_right_shift_binary_elemwise(
dtype = "int32"
def create_model():
- ifm_count = int(np.prod(ifm_shape))
- ifm2_count = int(np.prod(ifm2_shape))
-
- # Create a "partitioned" Relay function
- ifms = relay.var("ifms", shape=[ifm_count + ifm2_count], dtype=dtype)
- split = relay.split(ifms, [ifm_count])
- ifm = relay.reshape(split[0], newshape=ifm_shape)
- ifm2 = relay.reshape(split[1], newshape=ifm2_shape)
+ ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype)
+ ifm2 = relay.var("ifm2", shape=ifm2_shape, dtype=dtype)
shr_op = infra.make_ethosu_binary_elementwise(
ifm, ifm2, ifm_shape[3], ifm2_shape[3], "SHR", ofm_dtype, reversed_operands
)
+ return tvm.IRModule.from_expr(relay.Function([ifm, ifm2], shr_op))
+
+ def generate_output_data(input_data):
+ lhs = input_data["ifm"]
+ rhs = input_data["ifm2"]
+ if reversed_operands:
+ lhs = np.broadcast_to(lhs, ifm2_shape)
+ lhs, rhs = rhs, lhs
+ else:
+ rhs = np.broadcast_to(rhs, ifm_shape)
- glb_ethosu = relay.GlobalVar("tvmgen_default_ethos_u_main_0")
- func = (
- relay.Function([ifms], shr_op)
- .with_attr("Inline", 1)
- .with_attr("Compiler", "ethos-u")
- .with_attr("global_symbol", "tvmgen_default_ethos_u_main_0")
- .with_attr("Primitive", 1)
- )
- mod = tvm.IRModule()
- mod[glb_ethosu] = func
- mod = relay.transform.InferType()(mod)
+ def rounding_right_shift(lhs, rhs):
+ r = 1 << (rhs - 1)
+ return (lhs + r) >> rhs
- # Main
- ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype)
- ifm2 = relay.var("ifm2", shape=ifm2_shape, dtype=dtype)
- call = relay.Call(
- glb_ethosu,
- [
- relay.concatenate(
- data=(
- relay.reshape(ifm, newshape=ifm_count),
- relay.reshape(ifm2, newshape=ifm2_count),
- ),
- axis=0,
- )
- ],
- )
- mod["main"] = relay.Function([ifm, ifm2], call)
- mod = relay.transform.InferType()(mod)
- return mod
+ return [
+ np.array([rounding_right_shift(x[0], x[1]) for x in zip(lhs.flat, rhs.flat)]).astype(
+ ofm_dtype
+ )
+ ]
- mod = create_model()
+ cpu_mod = create_model()
# Generate reference data
in_min, in_max = util.get_range_for_dtype_str(dtype)
@@ -773,61 +636,39 @@ def test_ethosu_right_shift_binary_elemwise(
"ifm": lhs,
"ifm2": rhs,
}
+ output_data = generate_output_data(input_data)
+ ethosu_mod = _create_ethosu_partition(cpu_mod)
- if reversed_operands:
- lhs = np.broadcast_to(lhs, ifm2_shape)
- lhs, rhs = rhs, lhs
- else:
- rhs = np.broadcast_to(rhs, ifm_shape)
-
- def rounding_right_shift(lhs, rhs):
- r = 1 << (rhs - 1)
- return (lhs + r) >> rhs
-
- output_data = np.array(
- [rounding_right_shift(x[0], x[1]) for x in zip(lhs.flat, rhs.flat)]
- ).astype(ofm_dtype)
-
- compiled_models = infra.build_source(mod, input_data, [output_data], accel_type)
- # Assumes only two runtime.Modules are created -- i.e. single offload module
- ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
-
- # Verify generated C source
- get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
- compilation_artifacts = get_artifacts(ethosu_module)
- cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
- infra.print_payload(cmms)
- infra.verify_source(compiled_models, accel_type)
+ _compare_ethosu_with_reference(ethosu_mod, input_data, output_data, accel_type)
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize("ifm_shape", [(3, 2), (1, 15, 11, 7), (3, 1, 12), (400,)])
@pytest.mark.parametrize("ifm_scale, ifm_zp, ofm_scale, ofm_zp", [(1, 0, 1, 0), (0.015, 3, 0.2, 5)])
def test_ethosu_identity_codegen(ifm_shape, ifm_scale, ifm_zp, ofm_scale, ofm_zp, accel_type):
- # Create a "partitioned" Relay function
- ifm0 = relay.var("ifm0", shape=ifm_shape, dtype="int8")
- identity = infra.make_ethosu_identity(
- ifm0, ifm_scale=ifm_scale, ifm_zero_point=ifm_zp, ofm_scale=ofm_scale, ofm_zero_point=ofm_zp
- )
- mod = infra.make_partitioned_function(identity)
-
- in_data = np.random.randint(-120, high=120, size=ifm_shape, dtype="int8")
- requant_data = (ifm_scale * (in_data - ifm_zp)) / ofm_scale + ofm_zp
- out_data = np.round(np.clip(requant_data, -128, 127)).astype("int8")
+ def create_model():
+ ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
+ identity = infra.make_ethosu_identity(
+ ifm,
+ ifm_scale=ifm_scale,
+ ifm_zero_point=ifm_zp,
+ ofm_scale=ofm_scale,
+ ofm_zero_point=ofm_zp,
+ )
+ return tvm.IRModule.from_expr(relay.Function([ifm], identity))
- compiled_model = infra.build_source(
- mod, {"ifm": in_data}, [out_data], accel_type, output_tolerance=1
- )
+ def generate_output_data(input_data):
+ requant_data = (ifm_scale * (input_data["ifm"] - ifm_zp)) / ofm_scale + ofm_zp
+ return [np.round(np.clip(requant_data, -128, 127)).astype("int8")]
- # Assumes only two runtime.Modules are created -- i.e. single offload module
- ethosu_module = compiled_model[0].executor_factory.lib.imported_modules[0].imported_modules[0]
+ cpu_mod = create_model()
+ input_data = {"ifm": np.random.randint(-120, high=120, size=ifm_shape, dtype="int8")}
+ output_data = generate_output_data(input_data)
+ ethosu_mod = _create_ethosu_partition(cpu_mod)
- # Verify generated C source
- get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
- compilation_artifacts = get_artifacts(ethosu_module)
- cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
- infra.print_payload(cmms)
- infra.verify_source(compiled_model, accel_type)
+ _compare_ethosu_with_reference(
+ ethosu_mod, input_data, output_data, accel_type, output_tolerance=1
+ )
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -844,36 +685,17 @@ def test_ethosu_identity_codegen(ifm_shape, ifm_scale, ifm_zp, ofm_scale, ofm_zp
],
)
def test_relay_reshape_codegen(ifm_shape, new_shape, accel_type):
- # Create a "partitioned" Relay graph
- ifm0 = relay.var("ifm0", shape=ifm_shape, dtype="int8")
- reshape = relay.op.reshape(ifm0, newshape=new_shape)
- mod = infra.make_partitioned_function(reshape)
-
- data = np.random.randint(-128, high=127, size=ifm_shape, dtype="int8")
-
- # Generate a reference output using Relay reshape that doesn't get offloaded
- ref_mod = tvm.IRModule()
- ref_mod["main"] = relay.Function([ifm0], reshape)
- ref_mod = relay.transform.InferType()(ref_mod)
-
- out_data = generate_ref_data(ref_mod, {"ifm0": data})
-
- compiled_model = infra.build_source(
- mod,
- {"ifm": data},
- out_data,
- accel_type,
- )
+ def create_model():
+ ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
+ reshape = relay.op.reshape(ifm, newshape=new_shape)
+ return tvm.IRModule.from_expr(relay.Function([ifm], reshape))
- # Assumes only two runtime.Modules are created -- i.e. single offload module
- ethosu_module = compiled_model[0].executor_factory.lib.imported_modules[0].imported_modules[0]
+ cpu_mod = create_model()
+ input_data = {"ifm": np.random.randint(-128, high=127, size=ifm_shape, dtype="int8")}
+ output_data = generate_ref_data(cpu_mod, input_data)
+ ethosu_mod = _create_ethosu_partition(cpu_mod)
- # Verify generated C source
- get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
- compilation_artifacts = get_artifacts(ethosu_module)
- cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
- infra.print_payload(cmms)
- infra.verify_source(compiled_model, accel_type)
+ _compare_ethosu_with_reference(ethosu_mod, input_data, output_data, accel_type)
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -887,36 +709,17 @@ def test_relay_reshape_codegen(ifm_shape, new_shape, accel_type):
],
)
def test_relay_strided_slice_codegen(ifm_shape, begin, end, accel_type):
- # Create a "partitioned" Relay graph
- ifm0 = relay.var("ifm0", shape=ifm_shape, dtype="int8")
- strided_slice = relay.op.strided_slice(ifm0, begin, end)
- mod = infra.make_partitioned_function(strided_slice)
-
- input_data = np.random.randint(-128, high=127, size=ifm_shape, dtype="int8")
-
- # Generate a reference output using Relay strided slice that doesn't get offloaded
- ref_mod = tvm.IRModule()
- ref_mod["main"] = relay.Function([ifm0], strided_slice)
- ref_mod = relay.transform.InferType()(ref_mod)
-
- out_data = generate_ref_data(ref_mod, {"ifm0": input_data})
-
- compiled_model = infra.build_source(
- mod,
- {"ifm": input_data},
- out_data,
- accel_type,
- )
+ def create_model():
+ ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
+ strided_slice = relay.op.strided_slice(ifm, begin, end)
+ return tvm.IRModule.from_expr(relay.Function([ifm], strided_slice))
- # Assumes only two runtime.Modules are created -- i.e. single offload module
- ethosu_module = compiled_model[0].executor_factory.lib.imported_modules[0].imported_modules[0]
+ cpu_mod = create_model()
+ input_data = {"ifm": np.random.randint(-128, high=127, size=ifm_shape, dtype="int8")}
+ output_data = generate_ref_data(cpu_mod, input_data)
+ ethosu_mod = _create_ethosu_partition(cpu_mod)
- # Verify generated C source
- get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
- compilation_artifacts = get_artifacts(ethosu_module)
- cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
- infra.print_payload(cmms)
- infra.verify_source(compiled_model, accel_type)
+ _compare_ethosu_with_reference(ethosu_mod, input_data, output_data, accel_type)
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -930,66 +733,13 @@ def test_ethosu_unary_elementwise(
operator_type,
ifm_shape,
):
- dtype = "int8"
-
- def get_tflite_graph():
- class Model(tf.Module):
- @tf.function
- def abs_func(self, x):
- if operator_type == "ABS":
- op = tf.math.abs(x)
- return op
-
- model = Model()
-
- concrete_func = model.abs_func.get_concrete_function(
- tf.TensorSpec(ifm_shape, dtype=tf.float32)
- )
-
- # Convert the model
- def representative_dataset():
- for _ in range(100):
- data = np.random.rand(*tuple(ifm_shape))
- yield [data.astype(np.float32) * 2 - 1]
+ @tf.function
+ def abs_func(x):
+ if operator_type == "ABS":
+ op = tf.math.abs(x)
+ return op
- converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
- converter.optimizations = [tf.lite.Optimize.DEFAULT]
- converter.representative_dataset = representative_dataset
- converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
- converter.inference_input_type = tf.int8
- converter.inference_output_type = tf.int8
- tflite_model = converter.convert()
- return tflite_model
-
- tflite_graph = get_tflite_graph()
- tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
-
- relay_module, params = relay.frontend.from_tflite(
- tflite_model,
- shape_dict={"input": ifm_shape},
- dtype_dict={"input": dtype},
- )
- mod = partition_for_ethosu(relay_module, params)
-
- # Generate reference data
- input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
-
- compiled_models = infra.build_source(
- mod,
- input_data,
- output_data,
- accel_type,
- )
-
- # Assumes only two runtime.Modules are created -- i.e. single offload module
- ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
-
- # Verify generated C source
- get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
- compilation_artifacts = get_artifacts(ethosu_module)
- cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
- infra.print_payload(cmms)
- infra.verify_source(compiled_models, accel_type)
+ _compare_tvm_with_tflite(abs_func, [ifm_shape], accel_type)
def test_ethosu_section_name():
@@ -1046,33 +796,32 @@ def test_ethosu_section_name():
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
def test_ethosu_clz(accel_type):
ifm_shape = (1, 42, 5, 4)
- # Create a "partitioned" Relay function
- ifm0 = relay.var("ifm0", shape=ifm_shape, dtype="int32")
- clz = infra.make_ethosu_unary_elementwise(ifm0, 4, "CLZ")
- mod = infra.make_partitioned_function(clz)
-
- in_data = np.random.randint(-500000, high=500000, size=ifm_shape, dtype="int32")
- def clz_comp(n):
- n_bin = np.binary_repr(n)
- if n_bin[0] == "-":
- return 0
- else:
- return 32 - len(n_bin)
-
- out_data = np.array([clz_comp(i) for i in in_data.ravel()]).reshape(ifm_shape).astype("int32")
-
- compiled_model = infra.build_source(mod, {"ifm": in_data}, [out_data], accel_type)
-
- # Assumes only two runtime.Modules are created -- i.e. single offload module
- ethosu_module = compiled_model[0].executor_factory.lib.imported_modules[0].imported_modules[0]
-
- # Verify generated C source
- get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
- compilation_artifacts = get_artifacts(ethosu_module)
- cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
- infra.print_payload(cmms)
- infra.verify_source(compiled_model, accel_type)
+ def create_model():
+ ifm = relay.var("ifm", shape=ifm_shape, dtype="int32")
+ clz = infra.make_ethosu_unary_elementwise(ifm, 4, "CLZ")
+ return tvm.IRModule.from_expr(relay.Function([ifm], clz))
+
+ def generate_output_data(input_data):
+ def clz_comp(n):
+ n_bin = np.binary_repr(n)
+ if n_bin[0] == "-":
+ return 0
+ else:
+ return 32 - len(n_bin)
+
+ return [
+ np.array([clz_comp(i) for i in input_data["ifm"].ravel()])
+ .reshape(ifm_shape)
+ .astype("int32")
+ ]
+
+ cpu_mod = create_model()
+ input_data = {"ifm": np.random.randint(-500000, high=500000, size=ifm_shape, dtype="int32")}
+ output_data = generate_output_data(input_data)
+ ethosu_mod = _create_ethosu_partition(cpu_mod)
+
+ _compare_ethosu_with_reference(ethosu_mod, input_data, output_data, accel_type)
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)