You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/11/15 11:25:28 UTC

[GitHub] [tvm] manupa-arm commented on a change in pull request #9508: [microNPU] Update Conv2D Tests to Use TF API to Gen Test Cases

manupa-arm commented on a change in pull request #9508:
URL: https://github.com/apache/tvm/pull/9508#discussion_r749226683



##########
File path: tests/python/contrib/test_ethosu/test_legalize.py
##########
@@ -30,9 +30,8 @@
 from tvm.relay.op.contrib import ethosu
 from tvm.relay.build_module import bind_params_by_name
 
-from . import relay_ir_builder
 from . import infra
-
+pytest.importorskip("ethosu.vela")

Review comment:
       Same goes for this one as well.

##########
File path: tests/python/contrib/test_ethosu/test_codegen.py
##########
@@ -48,122 +46,147 @@ def get_shape_expr(in_expr, out_expr):
     return shape
 
 
-@pytest.mark.parametrize(
-    "accel_type",
-    ACCEL_TYPES,
-)
-def test_ethosu_conv2d(accel_type):
-    def create_graph_single(input_tensor_name, input_tensor_shape, input_tensor_dtype):
-        c1_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype)
-        c1_params.ifm.shape = input_tensor_shape
-        c1_params.kernel.shape = (3, 3, c1_params.ifm.shape[3], 32)
-        c1_params.kernel.sc = relay.const(np.random.rand(32) * 2, "float32")
-        c1_params.strides = (1, 1)
-        c1_params.pad = "VALID"
-        c1_params.update_output_qnn_params(
-            input_tensor_dtype, input_tensor_dtype, input_tensor_dtype
-        )
-        input0 = relay.var(input_tensor_name, shape=c1_params.ifm.shape, dtype=c1_params.ifm.dtype)
-        c1, new_params = relay_ir_builder.create_qnn_conv2d(c1_params, input0)
-        c1_params.ofm.shape = get_shape_expr(input0, c1)
+@pytest.mark.parametrize("ifm_shape", [(1, 299, 299, 3), (1, 55, 55, 3)])
+@pytest.mark.parametrize("kernel_shape", [(3, 2, 3, 3), (1, 3, 3, 3)])
+@pytest.mark.parametrize("padding", ["SAME", "VALID"])
+@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
+def test_ethosu_conv2d(ifm_shape, kernel_shape, padding, accel_type):
+    dtype = "int8"
 
-        f = relay.Function([input0], c1)
-        mod = tvm.IRModule()
-        mod["main"] = f
-        return mod, [c1_params]
-
-    def create_graph_double(input_tensor_name, input_tensor_shape, input_tensor_dtype):
-        c1_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype)
-        c1_params.ifm.shape = input_tensor_shape
-        c1_params.kernel.shape = (7, 7, c1_params.ifm.shape[3], 8)
-        c1_params.strides = (2, 2)
-        c1_params.pad = "VALID"
-        c1_params.update_output_qnn_params(
-            input_tensor_dtype, input_tensor_dtype, input_tensor_dtype
-        )
-        input0 = relay.var(input_tensor_name, shape=c1_params.ifm.shape, dtype=c1_params.ifm.dtype)
-        c1, new_params = relay_ir_builder.create_qnn_conv2d(c1_params, input0)
-        c1_params.ofm.shape = get_shape_expr(input0, c1)
-
-        c2_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype)
-        c2_params.ifm.shape = c1_params.ofm.shape
-        c2_params.kernel.shape = (5, 5, c2_params.ifm.shape[3], 16)
-        c2_params.strides = (1, 1)
-        c2_params.pad = "SAME"
-        c2_params.update_output_qnn_params()
-        c2, new_params = relay_ir_builder.create_qnn_conv2d(c2_params, c1)
-        c2_params.ofm.shape = get_shape_expr(input0, c2)
-
-        f = relay.Function([input0], c2)
-        mod = tvm.IRModule()
-        mod["main"] = f
-        return mod, [c2_params, c1_params]
-
-    def create_graph_activation(input_tensor_name, input_tensor_shape, input_tensor_dtype):
-        c1_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype)
-        c1_params.ifm.shape = input_tensor_shape
-        c1_params.kernel.shape = (7, 7, c1_params.ifm.shape[3], 8)
-        c1_params.strides = (2, 2)
-        c1_params.pad = "VALID"
-        c1_params.activation = "CLIP"
-        c1_params.clip_min = 90
-        c1_params.clip_max = 110
-        c1_params.update_output_qnn_params(
-            input_tensor_dtype, input_tensor_dtype, input_tensor_dtype
+    def create_tflite_graph_single():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x):
+                # Use tf.nn API to create the model
+                op = tf.nn.conv2d(
+                    x,
+                    filters=tf.constant(np.random.uniform(size=kernel_shape), dtype=tf.float32),
+                    strides=(1, 1),
+                    padding=padding,
+                    data_format="NHWC",
+                    dilations=1,
+                )
+                return op
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
         )
-        input0 = relay.var(input_tensor_name, shape=c1_params.ifm.shape, dtype=c1_params.ifm.dtype)
-        c1, new_params = relay_ir_builder.create_qnn_conv2d(c1_params, input0)
-        c1_params.ofm.shape = get_shape_expr(input0, c1)
-
-        c2_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype)
-        c2_params.ifm.shape = c1_params.ofm.shape
-        c2_params.kernel.shape = (5, 5, c2_params.ifm.shape[3], 16)
-        c2_params.strides = (1, 1)
-        c2_params.pad = "SAME"
-        c2_params.update_output_qnn_params()
-        c2, new_params = relay_ir_builder.create_qnn_conv2d(c2_params, c1)
-        c2_params.ofm.shape = get_shape_expr(input0, c2)
-
-        f = relay.Function([input0], c2)
-        mod = tvm.IRModule()
-        mod["main"] = f
-        return mod, [c2_params, c1_params]
-
-    test_cases = [
-        (create_graph_single, ["input", (1, 300, 300, 3), "int8"]),
-        (create_graph_double, ["input", (1, 128, 256, 4), "int8"]),
-        (create_graph_activation, ["input", (1, 64, 100, 4), "int8"]),
-    ]
-    np.random.seed(42)
-    for test_case in test_cases:
-        relay_module, conv_params = test_case[0](*test_case[1])
-        input_tensor, input_shape, input_dtype = test_case[1]
-        mod = partition_for_ethosu(relay_module)
-
-        # Generate reference data
-        in_min, in_max = util.get_range_for_dtype_str(input_dtype)
-        input_data = {
-            input_tensor: np.random.randint(
-                in_min, high=in_max, size=input_shape, dtype=input_dtype
-            )
-        }
-        output_data = generate_ref_data(relay_module, input_data)
-
-        compiled_models = infra.build_source(
-            mod, input_data, output_data, accel_type, output_tolerance=1
+
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def create_tflite_graph_double():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function_double(self, x):
+                # Use tf.nn API to create the model with two convolutions
+                op = tf.nn.conv2d(
+                    x,
+                    filters=tf.constant(np.random.uniform(size=kernel_shape), dtype=tf.float32),
+                    strides=(1, 1),
+                    padding=padding,
+                    data_format="NHWC",
+                    dilations=1,
+                )
+                # Second convolution
+                op2 = tf.nn.conv2d(
+                    op,
+                    filters=tf.constant(np.random.uniform(size=kernel_shape), dtype=tf.float32),
+                    strides=(1, 1),
+                    padding=padding,
+                    data_format="NHWC",
+                    dilations=2,
+                )
+                return op2
+
+        model = Model()
+        concrete_func = model.tf_function_double.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
         )
 
-        # Assumes only two runtime.Modules are created -- i.e. single offload module
-        imported_modules = compiled_models[0].executor_factory.lib.imported_modules
-        assert len(imported_modules) == 2
-        ethosu_module = imported_modules[0]
-
-        # Verify generated C source
-        get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs")
-        cmms = get_cs(ethosu_module)
-        cmms = bytes.fromhex(cmms)
-        infra.print_payload(cmms)
-        infra.verify_source(compiled_models, accel_type)
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    tflite_graph_single = create_tflite_graph_single()
+    tflite_model_single = tflite.Model.Model.GetRootAsModel(tflite_graph_single, 0)
+
+    tflite_graph_double = create_tflite_graph_double()
+    tflite_model_double = tflite.Model.Model.GetRootAsModel(tflite_graph_double, 0)

Review comment:
       I think the code for tflite_model_double and tflite_model_single beyond this point is the same, if I am not mistaken. Could we refactor that bit to re-use the same code?

##########
File path: tests/python/contrib/test_ethosu/test_legalize.py
##########
@@ -221,135 +220,135 @@ def get_shape_expr(in_expr, out_expr):
     return shape
 
 
+def compute_ofm_shape(ifm_shape, padding, kernel_shape, strides, dilation):
+    if padding.lower() == "valid":
+        h = math.ceil((ifm_shape[1] - (kernel_shape[0] - 1) * dilation[0]) / strides[0])
+        w = math.ceil((ifm_shape[2] - (kernel_shape[1] - 1) * dilation[1]) / strides[1])
+    if padding.lower() == "same":
+        h = math.ceil(ifm_shape[1] / strides[0])
+        w = math.ceil(ifm_shape[2] / strides[1])
+    ofm_shape = [ifm_shape[0], h, w, kernel_shape[3]]
+    return ofm_shape
+
+
 INVERSE_LAYOUT_TRANSFORM_OHWI_MAP = {
     "HWIO": [1, 2, 3, 0],
     "HWOI": [1, 2, 0, 3],
     "OWHI": [0, 1, 2, 3],
 }
 
 
-def test_ethosu_conv2d_legalize():
-    def create_graph_single(input_tensor_name, input_tensor_shape, input_tensor_dtype):
-        c1_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype)
-        c1_params.ifm.shape = input_tensor_shape
-        c1_params.kernel.shape = (3, 3, c1_params.ifm.shape[3], 32)
-        c1_params.strides = (1, 1)
-        c1_params.pad = "VALID"
-        c1_params.activation = "CLIP"
-        c1_params.clip_min = 23
-        c1_params.clip_max = 180
-        input0 = relay.var(input_tensor_name, shape=c1_params.ifm.shape, dtype=c1_params.ifm.dtype)
-        c1, new_params = relay_ir_builder.create_qnn_conv2d(c1_params, input0)
-        c1_params.ofm.shape = get_shape_expr(input0, c1)
-
-        f = relay.Function([input0], c1)
-        mod = tvm.IRModule()
-        mod["main"] = f
-        return mod, [c1_params]
-
-    def create_graph_double(input_tensor_name, input_tensor_shape, input_tensor_dtype):
-        c1_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype)
-        c1_params.ifm.shape = input_tensor_shape
-        c1_params.kernel.shape = (7, 7, c1_params.ifm.shape[3], 8)
-        c1_params.strides = (2, 2)
-        c1_params.pad = "VALID"
-        c1_params.activation = "CLIP"
-        c1_params.clip_min = 10
-        c1_params.clip_max = 240
-        input0 = relay.var(input_tensor_name, shape=c1_params.ifm.shape, dtype=c1_params.ifm.dtype)
-        c1, new_params = relay_ir_builder.create_qnn_conv2d(c1_params, input0)
-        c1_params.ofm.shape = get_shape_expr(input0, c1)
-
-        c2_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype)
-        c2_params.ifm.shape = c1_params.ofm.shape
-        c2_params.kernel.shape = (5, 5, c2_params.ifm.shape[3], 16)
-        c2_params.strides = (1, 1)
-        c2_params.pad = "SAME"
-        c2, new_params = relay_ir_builder.create_qnn_conv2d(c2_params, c1)
-        c2_params.ofm.shape = get_shape_expr(input0, c2)
-
-        f = relay.Function([input0], c2)
-        mod = tvm.IRModule()
-        mod["main"] = f
-        return mod, [c2_params, c1_params]
+@pytest.mark.parametrize("ifm_shape", [(1, 299, 299, 3), (1, 55, 55, 3)])
+@pytest.mark.parametrize("kernel_shape", [(3, 2, 3, 3), (1, 3, 3, 3)])
+@pytest.mark.parametrize("padding", ["SAME", "VALID"])
+@pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))])
+@pytest.mark.parametrize("activation", [None, None])
+def test_tflite_conv_2d_legalize(ifm_shape, kernel_shape, padding, strides, dilation, activation):
+    dtype = "int8"
+
+    def create_tflite_graph_single():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, input_shape):
+                op = tf.nn.conv2d(
+                    input_shape,
+                    filters=tf.constant(np.random.uniform(size=kernel_shape), dtype=tf.float32),
+                    strides=strides,
+                    padding=padding,
+                    data_format="NHWC",
+                    dilations=dilation,
+                )
+                if activation:
+                    op = tf.nn.relu(op)
+                return op
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
 
-    def verify_tensor(tensor_type, expr):
-        assert list(tensor_type.shape) == list(expr.checked_type.shape)
-        assert str(tensor_type.dtype) == str(expr.checked_type.dtype)
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
 
-    def verify_linear(ext_func, conv2d_params):
+    def verify(ext_func):
         op = ext_func.body
-        for param in conv2d_params:
-            verify_tensor(param.ifm, op.args[0])
-            verify_tensor(param.ofm, op)
-
-            # This will be in OHWI layout
-            weights_ohwi = op.args[1].data.asnumpy()
-            weights_layout = str(param.kernel.layout)
-            weights = np.transpose(weights_ohwi, INVERSE_LAYOUT_TRANSFORM_OHWI_MAP[weights_layout])
-            assert weights.shape == param.kernel.shape
-            assert weights.dtype == param.kernel.dtype
-
-            assert list(op.args[2].checked_type.shape)[0] == weights_ohwi.shape[0]
-
-            assert float(op.attrs.ifm_scale) == float(param.ifm.sc.data.asnumpy())
-            assert int(op.attrs.ifm_zero_point) == int(param.ifm.zp.data.asnumpy())
-            assert int(op.attrs.weight_zero_point) == int(param.kernel.zp.data.asnumpy())
-            assert float(op.attrs.ofm_scale) == float(param.ofm.sc.data.asnumpy())
-            assert int(op.attrs.ofm_zero_point) == int(param.ofm.zp.data.asnumpy())
-            assert int(op.attrs.ofm_channels) == int(weights_ohwi.shape[0])
-            assert list(op.attrs.padding) == list(param.pad)
-            assert list(op.attrs.strides) == list(param.strides)
-            assert list(op.attrs.dilation) == list(param.dilation)
-            assert str(op.attrs.activation) == str(param.activation)
-            assert int(op.attrs.clip_min) == int(param.clip_min)
-            assert int(op.attrs.clip_max) == int(param.clip_max)
-            op = op.args[0]
-
-    test_cases = [
-        (create_graph_single, ["input", (1, 299, 299, 3), "uint8"]),
-        (create_graph_double, ["input", (1, 128, 256, 4), "uint8"]),
-    ]
-    for test_case in test_cases:
-        mod, conv_params = test_case[0](*test_case[1])
-        mod = ethosu.partition_for_ethosu(mod)
-        mod = legalize.LegalizeConv2D()(mod)
-        verify_linear(mod["tvmgen_default_ethosu_main_0"], conv_params)
-
-
-def test_ethosu_conv2d_legalize_errors():
-    def create_graph_single_unsupported_ifm_layout(
-        input_tensor_name, input_tensor_shape, input_tensor_dtype
-    ):
-        c1_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype)
-        c1_params.ifm.shape = input_tensor_shape
-        c1_params.ifm.layout = "NCHW"
-        c1_params.kernel.shape = (3, 3, c1_params.ifm.shape[1], 32)
-        c1_params.strides = (1, 1)
-        c1_params.pad = "VALID"
-        c1_params.activation = "CLIP"
-        c1_params.clip_min = 23
-        c1_params.clip_max = 180
-        input0 = relay.var(input_tensor_name, shape=c1_params.ifm.shape, dtype=c1_params.ifm.dtype)
-        c1, new_params = relay_ir_builder.create_qnn_conv2d(c1_params, input0)
-        c1_params.ofm.shape = get_shape_expr(input0, c1)
-
-        f = relay.Function([input0], c1)
-        mod = tvm.IRModule()
-        mod["main"] = f
-        return mod, [c1_params]
+        ofm_channels = op.attrs.ofm_channels
 
-    test_cases = [
-        (create_graph_single_unsupported_ifm_layout, ["input", (1, 3, 299, 299), "uint8"]),
+        # check IFM
+        ifm = op.args[0].checked_type
+        assert list(ifm.shape) == list(ifm_shape)
+        assert str(ifm.dtype) == dtype
+        assert ifm.shape[3] == ofm_channels
+
+        # check OFM
+        ofm = op.checked_type
+        expected_ofm_shape = compute_ofm_shape(ifm_shape, padding, kernel_shape, strides, dilation)
+        assert list(ofm.shape) == list(expected_ofm_shape)
+        assert str(ofm.dtype) == dtype
+        assert ofm.shape[3] == ofm_channels
+
+        # check weights
+        weights_ohwi = op.args[1].data.asnumpy()
+        assert str(weights_ohwi.dtype) == dtype
+        assert weights_ohwi.shape[0] == ofm_channels
+        assert weights_ohwi.shape[1] == kernel_shape[0]
+        assert weights_ohwi.shape[2] == kernel_shape[1]
+        assert weights_ohwi.shape[3] == 3  # only depth multiplier 1 is supported
+
+        # Check that scale_bias matches weight tensor
+        assert list(op.args[2].checked_type.shape)[0] == ofm_channels
+
+        expected_padding = infra.compute_padding_shape(
+            ifm_shape,
+            expected_ofm_shape,
+            padding,
+            (kernel_shape[0], kernel_shape[1]),
+            strides,
+            dilation,
+        )
+        assert list(op.attrs.padding) == list(expected_padding)
+        assert op.attrs.ofm_channels == ofm_channels
+        assert list(op.attrs.strides) == list(strides)
+        assert list(op.attrs.dilation) == list(dilation)
+        if activation == "RELU":
+            assert str(op.attrs.activation) == "CLIP"
+
+    conv2d_pattern_table = [
+        (
+            ethosu.QnnConv2DParams.composite_name,
+            ethosu.qnn_conv2d_pattern(),
+            lambda pat: ethosu.QnnConv2DParams(pat).is_valid(),
+        )
     ]
 
-    for test_case in test_cases:
-        mod, conv_params = test_case[0](*test_case[1])
-        mod = ethosu.partition_for_ethosu(mod)
-        with pytest.raises(
-            tvm._ffi.base.TVMError, match="EthosUCodegenError: Unsupported Layout NCHW"

Review comment:
       I think we discussed this and decided to remove this error, as it would not be invoked from a user flow. cc: @ekalda.
   
   Please remove the following lines: https://github.com/apache/tvm/blob/13f54e03fd3b7066d9ce9b0b0c3e1ba297d627bb/python/tvm/relay/backend/contrib/ethosu/legalize.py#L141-L142 and also errors.py, as they are no longer needed.

##########
File path: tests/python/contrib/test_ethosu/test_codegen.py
##########
@@ -17,19 +17,17 @@
 # pylint: disable=invalid-name, unused-argument
 import pytest
 
-pytest.importorskip("ethosu.vela")
 import numpy as np
 import tflite.Model
 
 import tvm
 import tensorflow as tf
 from tvm import relay
-from tvm.relay.backend.contrib.ethosu import util
 from tvm.relay.op.contrib.ethosu import partition_for_ethosu
-from tests.python.relay.aot.aot_test_utils import generate_ref_data
 
-from . import relay_ir_builder
 from . import infra
+pytest.importorskip("ethosu.vela")

Review comment:
       Let's keep this at the very top so all dependencies are satisfied -- as it originally was.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org