You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text version).
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/12/01 14:56:07 UTC

[GitHub] [tvm] mbaret commented on a change in pull request #9623: Refactor Ethos-U codegen tests

mbaret commented on a change in pull request #9623:
URL: https://github.com/apache/tvm/pull/9623#discussion_r760262949



##########
File path: tests/python/contrib/test_ethosu/test_codegen.py
##########
@@ -166,8 +169,110 @@ def create_graph_activation(input_tensor_name, input_tensor_shape, input_tensor_
         infra.verify_source(compiled_models, accel_type)
 
 
+def _compare_ethosu_with_reference(
+    mod, input_data, output_data, accel_type, output_tolerance=0, print_cmm=False
+):
+    compiled_models = infra.build_source(
+        mod,
+        input_data,
+        output_data,
+        accel_type,
+        output_tolerance=output_tolerance,
+    )
+
+    # Assumes only two runtime.Modules are created -- i.e. single offload module
+    imported_modules = compiled_models[0].executor_factory.lib.imported_modules
+    assert len(imported_modules) == 2
+    ethosu_module = imported_modules[0]
+
+    # Verify generated C source
+    if print_cmm:
+        get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
+        cmms = get_cs(ethosu_module)
+        cmms = bytes.fromhex(cmms)
+        infra.print_payload(cmms)
+
+    infra.verify_source(compiled_models, accel_type)
+
+
+def _compare_tvm_with_tflite(tf_func, shapes, accel_type, ranges=None, print_cmm=False):
+    tensor_specs = [tf.TensorSpec(shape, dtype=tf.float32) for shape in shapes]
+    if not ranges:
+        ranges = [(0, 1) for _ in shapes]
+    concrete_func = tf_func.get_concrete_function(*tensor_specs)
+
+    # Convert the model
+    def representative_dataset():
+        for _ in range(100):
+            inputs = []
+            for i, shape in enumerate(shapes):
+                data = np.random.uniform(
+                    low=ranges[i][0], high=ranges[i][1], size=tuple(shape)
+                ).astype("float32")
+                inputs.append(data)
+
+            yield inputs
+
+    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    converter.representative_dataset = representative_dataset
+    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+    converter.inference_input_type = tf.int8
+    converter.inference_output_type = tf.int8
+    tflite_graph = converter.convert()
+
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    relay_module, params = relay.frontend.from_tflite(tflite_model)
+    mod = partition_for_ethosu(relay_module, params)
+
+    # Generate reference data
+    input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
+
+    _compare_ethosu_with_reference(mod, input_data, output_data, accel_type, print_cmm=print_cmm)
+
+
+class EthosUAnnotator(ExprMutator):
+    """Annotate entire graph for Ethos-U offload"""
+
+    def __init__(self):
+        super(EthosUAnnotator, self).__init__()
+        self.compiler = "ethos-u"
+        self.last_call = True
+
+    def visit_call(self, call):
+        curr_last = self.last_call
+        self.last_call = False
+
+        params = []
+        for arg in call.args:
+            param = super().visit(arg)
+            if isinstance(param, relay.expr.Var):
+                param = compiler_begin(param, self.compiler)
+            params.append(param)
+
+        new_call = relay.Call(call.op, params, call.attrs)
+        if curr_last:
+            new_call = compiler_end(new_call, self.compiler)
+        return new_call
+
+    def visit_constant(self, constant):
+        new_constant = compiler_begin(constant, self.compiler)
+        return new_constant
+
+
+def _create_ethosu_partition(mod):
+    mod["main"] = EthosUAnnotator().visit(mod["main"])
+    mod = relay.transform.MergeCompilerRegions()(mod)
+    mod = relay.transform.InferType()(mod)
+    mod = relay.transform.PartitionGraph()(mod)
+    mod = relay.transform.InferType()(mod)
+    mod = preprocess.preprocess_ext_io()(mod)
+    return mod
+
+
 @pytest.mark.parametrize("accel_type", ACCEL_TYPES)
-@pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3), (1, 23, 32, 7)])
+@pytest.mark.parametrize("ifm_shape", [(1, 10, 10, 8), (1, 23, 32, 7)])

Review comment:
       Ah, was testing something else and forgot to change back... Good catch :)

##########
File path: tests/python/contrib/test_ethosu/test_codegen.py
##########
@@ -183,83 +288,28 @@ def test_tflite_depthwise_conv2d(
     dilation,
     activation,
 ):
-    dtype = "int8"
-
-    def create_tflite_graph():
-        class Model(tf.Module):
-            @tf.function
-            def depthwise_conv2d(self, x):
-                weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1]
-                weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
-                # The input strides to the TensorFlow API needs to be of shape 1x4
-                tf_strides = [1, strides[0], strides[1], 1]
-                op = tf.nn.depthwise_conv2d(
-                    x, weight, strides=tf_strides, padding=padding, dilations=dilation
-                )
-                if activation:
-                    op = tf.nn.relu(op)
-                return op
-
-        model = Model()
-        concrete_func = model.depthwise_conv2d.get_concrete_function(
-            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+    @tf.function
+    def depthwise_conv2d(x):
+        weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1]
+        weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+        # The input strides to the TensorFlow API needs to be of shape 1x4
+        tf_strides = [1, strides[0], strides[1], 1]
+        op = tf.nn.depthwise_conv2d(
+            x, weight, strides=tf_strides, padding=padding, dilations=dilation
         )
+        if activation:
+            op = tf.nn.relu(op)
+        return op
 
-        # Convert the model
-        def representative_dataset():
-            for _ in range(100):
-                data = np.random.rand(*tuple(ifm_shape))
-                yield [data.astype(np.float32)]
-
-        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
-        converter.optimizations = [tf.lite.Optimize.DEFAULT]
-        converter.representative_dataset = representative_dataset
-        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
-        converter.inference_input_type = tf.int8
-        converter.inference_output_type = tf.int8
-        tflite_model = converter.convert()
-        return tflite_model
-
-    tflite_graph = create_tflite_graph()
-    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
-
-    relay_module, params = relay.frontend.from_tflite(
-        tflite_model,
-        shape_dict={"input": ifm_shape},
-        dtype_dict={"input": dtype},
-    )
-    mod = partition_for_ethosu(relay_module, params)
-
-    # Generate reference data
-    input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
-
-    compiled_models = infra.build_source(
-        mod,
-        input_data,
-        output_data,
-        accel_type,
-    )
-
-    # Assumes only two runtime.Modules are created -- i.e. single offload module
-    imported_modules = compiled_models[0].executor_factory.lib.imported_modules
-    assert len(imported_modules) == 2
-    ethosu_module = imported_modules[0]
-
-    # Verify generated C source
-    get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
-    cmms = get_cs(ethosu_module)
-    cmms = bytes.fromhex(cmms)
-
-    infra.print_payload(cmms)
-    infra.verify_source(compiled_models, accel_type)
+    _compare_tvm_with_tflite(depthwise_conv2d, [ifm_shape], accel_type)
 
 
 @pytest.mark.parametrize(
     "accel_type",
     ACCEL_TYPES,
 )
 @pytest.mark.parametrize("pooling_type", ["MAX", "AVG"])
-@pytest.mark.parametrize("ifm_shape", [[1, 3, 4, 3], [1, 4, 5, 2]])
+@pytest.mark.parametrize("ifm_shape", [[1, 10, 10, 24], [1, 4, 5, 2]])

Review comment:
       Ack




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org