You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2019/11/13 06:21:54 UTC

[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #4307: [QNN][Legalize] Specialize for Platforms w/o fast Int8 support

FrozenGene commented on a change in pull request #4307: [QNN][Legalize] Specialize for Platforms w/o fast Int8 support
URL: https://github.com/apache/incubator-tvm/pull/4307#discussion_r345584407
 
 

 ##########
 File path: python/tvm/relay/qnn/op/legalizations.py
 ##########
 @@ -137,4 +165,124 @@ def _is_int8_hw_support(target):
     new_attrs = {k : attrs[k] for k in attrs.keys()}
     new_attrs['input_zero_point'] = input_zp
     new_attrs['kernel_zero_point'] = kernel_zp
-    return relay.qnn.op.conv2d(data, kernel, **new_attrs)
+    return relay_op(data, kernel, **new_attrs)
+
+# Helper function to change dtypes to be same. ARM dotprod instructions prefer this setting.
+def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
+    """ Sometimes MxNet + MLDNN can lead to uint8 x int8 datatypes for the conv inputs. However,
+    many devices like ARM prefer the datatypes to be same for the HW units. This helper transforms
+    conv2d/dense such that both the dtypes are same.
+
+    Parameters
+    ----------
+    attrs : tvm.attrs.Attrs
+        Attributes of current convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+
+    def _shift(data, out_dtype):
+        """Shifts (add/subtracts) the qnn tensor with +/-128)"""
+        if out_dtype == 'uint8':
+            shift = 128
+        elif out_dtype == 'int8':
+            shift = -128
+        else:
+            raise ValueError("Unsupport out dtype.")
+        data_modified = relay.cast(data, 'int32')
+        data_modified = relay.add(data_modified, relay.const(shift, 'int32'))
+        data_modified = relay.cast(data_modified, out_dtype)
+        return data_modified
+
+    # Collect the dtypes.
+    data_dtype = types[0].dtype
+    kernel_dtype = types[1].dtype
+
+    # Collect the input exprs.
+    data, kernel = inputs
+
+    if data_dtype == kernel_dtype:
+        return None
+
+    assert 'int8' in data_dtype and 'int8' in kernel_dtype, \
+            "Qnn Conv2D only accepts uint8 or int8 inputs"
+
+    # Shift input if necessary.
+    input_zp = attrs['input_zero_point']
+    data = _shift(data, kernel_dtype)
+    if data_dtype == 'int8':
+        input_zp = input_zp + 128
+    elif data_dtype == 'uint8':
+        input_zp = input_zp - 128
+    else:
+        raise RuntimeError("Qnn Conv2D only accepts uint8 or int8 inputs")
+
+    new_attrs = {k : attrs[k] for k in attrs.keys()}
+    new_attrs['input_zero_point'] = input_zp
+    return relay_op(data, kernel, **new_attrs)
+
+def is_fast_int8_hw_present():
 
 Review comment:
   Maybe we could break this into isolated functions for Intel and ARM, which would make the code cleaner. For example, if we add PowerPC support in the future, I would like to have a separate function, ppc_int8_hw_support. However, the current approach is acceptable too.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services