You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2019/11/12 05:48:25 UTC

[GitHub] [incubator-tvm] jackwish commented on a change in pull request #4307: [QNN][Legalize] Specialize for Platforms w/o fast Int8 support

jackwish commented on a change in pull request #4307: [QNN][Legalize] Specialize for Platforms w/o fast Int8 support
URL: https://github.com/apache/incubator-tvm/pull/4307#discussion_r345009110
 
 

 ##########
 File path: python/tvm/relay/qnn/op/legalizations.py
 ##########
 @@ -137,4 +165,124 @@ def _is_int8_hw_support(target):
     new_attrs = {k : attrs[k] for k in attrs.keys()}
     new_attrs['input_zero_point'] = input_zp
     new_attrs['kernel_zero_point'] = kernel_zp
-    return relay.qnn.op.conv2d(data, kernel, **new_attrs)
+    return relay_op(data, kernel, **new_attrs)
+
# Helper function to change dtypes to be same. ARM dotprod instructions prefer this setting.
def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
    """ Sometimes MxNet + MLDNN can lead to uint8 x int8 datatypes for the conv inputs. However,
    many devices like ARM prefer the datatypes to be same for the HW units. This helper transforms
    conv2d/dense such that both the dtypes are same.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types
    relay_op : callable
        The QNN relay op (e.g. relay.qnn.op.conv2d or relay.qnn.op.dense) used to
        rebuild the expr with the adjusted dtype/zero point.

    Returns
    -------
    result : tvm.relay.Expr or None
        The legalized expr, or None if the dtypes already match and no
        transformation is needed.
    """

    def _shift(data, out_dtype):
        """Shifts (adds/subtracts) the qnn tensor by +/-128 and casts it to out_dtype."""
        if out_dtype == 'uint8':
            shift = 128
        elif out_dtype == 'int8':
            shift = -128
        else:
            raise ValueError("Unsupported out dtype.")
        # Widen to int32 so the +/-128 shift cannot overflow the 8-bit storage type.
        data_modified = relay.cast(data, 'int32')
        data_modified = relay.add(data_modified, relay.const(shift, 'int32'))
        data_modified = relay.cast(data_modified, out_dtype)
        return data_modified

    # Collect the dtypes.
    data_dtype = types[0].dtype
    kernel_dtype = types[1].dtype

    # Collect the input exprs.
    data, kernel = inputs

    # Nothing to do when both operands already share a dtype.
    if data_dtype == kernel_dtype:
        return None

    assert 'int8' in data_dtype and 'int8' in kernel_dtype, \
            "Qnn Conv2D only accepts uint8 or int8 inputs"

    # Shift the data into the kernel's dtype and compensate by moving the zero
    # point in the same direction, so the represented real value is unchanged.
    input_zp = attrs['input_zero_point']
    data = _shift(data, kernel_dtype)
    if data_dtype == 'int8':
        input_zp = input_zp + 128
    elif data_dtype == 'uint8':
        input_zp = input_zp - 128
    else:
        raise RuntimeError("Qnn Conv2D only accepts uint8 or int8 inputs")

    new_attrs = {k : attrs[k] for k in attrs.keys()}
    new_attrs['input_zero_point'] = input_zp
    return relay_op(data, kernel, **new_attrs)
+
def is_fast_int8_hw_present():
    """
    Checks whether the hardware has support for fast Int8 arithmetic operations.
        1) Intel - Skylake/CascadeLake
        2) ARM - Dotprod
    We can extend this function to add more device targets.

    Returns
    -------
    result : bool
        True if the current target supports fast Int8 instructions.
    """

    target = tvm.target.current_target(allow_none=False)

    # Intel cpu. Coerce the set intersection to bool so the function always
    # returns a bool rather than a (possibly non-empty) set.
    intel_supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'}
    is_present_intel = bool(intel_supported_arches.intersection(set(target.options)))

    # ARM cpu - fast Int8 requires the dotprod target attribute.
    arm_supported_attr = '+v8.2a,+dotprod'
    is_present_arm = any(arm_supported_attr in opt for opt in target.options)

    return is_present_intel or is_present_arm
+
########################
# ARM CPU legalizations.
########################

@qnn_conv2d_legalize.register('arm_cpu')
def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
    # Without fast Int8 (dotprod) support, fall back to the non-QNN conv2d lowering.
    if not is_fast_int8_hw_present():
        return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
    # ARM HW units prefer both conv2d operands to share a dtype.
    return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
+
@qnn_dense_legalize.register('arm_cpu')
def _qnn_dense_legalize_arm_cpu(attrs, inputs, types):
    # Without fast Int8 (dotprod) support, fall back to the non-QNN dense lowering.
    if not is_fast_int8_hw_present():
        return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
    # ARM HW units prefer both dense operands to share a dtype.
    return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
+
+##########################
+# Intel CPU legalizations.
+##########################
+
+@qnn_conv2d_legalize.register('cpu')
+def _qnn_conv2d_legalize_intel_cpu(attrs, inputs, types):
+    # The VNNI transformations prefer uint8 x int8 datatypes.
+    if is_fast_int8_hw_present():
 
 Review comment:
  Since we are already on the Intel CPU path here, I think the HW feature check can target Intel CPU directly.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services