You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/10/06 15:00:19 UTC

[GitHub] [tvm] mbaret commented on a change in pull request #9209: Arm(R) Ethos(TM)-U NPU Depthwise2d operator support

mbaret commented on a change in pull request #9209:
URL: https://github.com/apache/tvm/pull/9209#discussion_r723369248



##########
File path: tests/python/contrib/test_ethosu/test_legalize.py
##########
@@ -327,12 +341,118 @@ def create_graph_single_unsupported_ifm_layout(
 
     for test_case in test_cases:
         mod, conv_params = test_case[0](*test_case[1])
-        mod = partition_for_ethosu(mod)
+        mod = ethosu.partition_for_ethosu(mod)
         with pytest.raises(
             tvm._ffi.base.TVMError, match="EthosUCodegenError: Unsupported Layout NCHW"
         ):
             mod = legalize.LegalizeEthosUConv2D()(mod)
 
 
+@pytest.mark.parametrize("ifm_shape", [(1, 299, 299, 3), (1, 123, 17, 7)])
+@pytest.mark.parametrize("kernel_shape", [(7, 3), (22, 5)])
+@pytest.mark.parametrize("padding", ["SAME", "VALID"])
+@pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))])
+@pytest.mark.parametrize("activation", ["RELU", None])
+def test_tflite_depthwise2d_legalize(
+    ifm_shape, kernel_shape, padding, strides, dilation, activation
+):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def depthwise2d(self, x):
+                weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1]
+                weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+                # The input strides to the TensorFlow API needs to be of shape 1x4
+                tf_strides = [1, strides[0], strides[1], 1]
+                op = tf.nn.depthwise_conv2d(
+                    x, weight, strides=tf_strides, padding=padding, dilations=dilation
+                )
+                if activation:
+                    op = tf.nn.relu(op)
+                return op
+
+        model = Model()
+        concrete_func = model.depthwise2d.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def verify(ext_func):
+        op = ext_func.body
+        ofm_channels = op.attrs.ofm_channels
+
+        # check IFM
+        ifm = op.args[0].checked_type
+        assert list(ifm.shape) == list(ifm_shape)
+        assert str(ifm.dtype) == dtype
+        assert ifm.shape[3] == ofm_channels
+
+        # check OFM
+        ofm = op.checked_type
+        expected_ofm_shape = infra.compute_ofm_shape(
+            ifm_shape, padding, kernel_shape, strides, dilation
+        )
+        assert list(ofm.shape) == list(expected_ofm_shape)
+        assert str(ofm.dtype) == dtype
+        assert ofm.shape[3] == ofm_channels
+
+        # check weights
+        weights_ohwi = op.args[1].data.asnumpy()
+        assert str(weights_ohwi.dtype) == dtype
+        assert weights_ohwi.shape[0] == ofm_channels
+        assert weights_ohwi.shape[1] == kernel_shape[0]
+        assert weights_ohwi.shape[2] == kernel_shape[1]
+        assert weights_ohwi.shape[3] == 1  # only depth multiplier 1 is supported
+
+        # Check that scale_bias matches weight tensor
+        assert list(op.args[2].checked_type.shape)[0] == ofm_channels
+
+        expected_padding = infra.compute_padding_shape(
+            ifm_shape, expected_ofm_shape, padding, kernel_shape, strides, dilation
+        )
+        assert list(op.attrs.padding) == list(expected_padding)
+        assert op.attrs.ofm_channels == ofm_channels
+        assert list(op.attrs.strides) == list(strides)
+        assert list(op.attrs.dilation) == list(dilation)
+        if activation == "RELU":
+            assert str(op.attrs.activation) == "CLIP"
+
+    depthwise_pattern_table = [
+        (
+            "ethosu.depthwise2d",
+            ethosu.qnn_depthwise2d_pattern(),
+            lambda pat: ethosu.QnnDepthwise2DParams(pat).is_valid(),
+        )
+    ]
+
+    tflite_model = create_tflite_graph()
+    tflite_mod = infra.parse_tflite_model(tflite_model)
+
+    mod, params = infra.parse_relay_tflite_model(tflite_mod, "input", ifm_shape, dtype)

Review comment:
       I think these infra helpers obscure the code without saving much. Shall we consider inlining them to make the test more self-contained?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org