You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2019/11/12 07:36:06 UTC

[GitHub] [incubator-tvm] anijain2305 commented on a change in pull request #4307: [QNN][Legalize] Specialize for Platforms w/o fast Int8 support

anijain2305 commented on a change in pull request #4307: [QNN][Legalize] Specialize for Platforms w/o fast Int8 support
URL: https://github.com/apache/incubator-tvm/pull/4307#discussion_r345051410
 
 

 ##########
 File path: python/tvm/relay/qnn/op/legalizations.py
 ##########
 @@ -137,4 +165,124 @@ def _is_int8_hw_support(target):
     new_attrs = {k : attrs[k] for k in attrs.keys()}
     new_attrs['input_zero_point'] = input_zp
     new_attrs['kernel_zero_point'] = kernel_zp
-    return relay.qnn.op.conv2d(data, kernel, **new_attrs)
+    return relay_op(data, kernel, **new_attrs)
+
+# Helper function to change dtypes to be same. ARM dotprod instructions prefer this setting.
+def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
+    """ Sometimes MxNet + MLDNN can lead to uint8 x int8 datatypes for the conv inputs. However,
+    many devices like ARM prefer the datatypes to be same for the HW units. This helper transforms
+    conv2d/dense such that both the dtypes are same.
+
+    Parameters
+    ----------
+    attrs : tvm.attrs.Attrs
+        Attributes of current convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+
+    def _shift(data, out_dtype):
+        """Shifts (add/subtracts) the qnn tensor with +/-128)"""
+        if out_dtype == 'uint8':
+            shift = 128
+        elif out_dtype == 'int8':
+            shift = -128
+        else:
+            raise ValueError("Unsupport out dtype.")
+        data_modified = relay.cast(data, 'int32')
+        data_modified = relay.add(data_modified, relay.const(shift, 'int32'))
+        data_modified = relay.cast(data_modified, out_dtype)
+        return data_modified
+
+    # Collect the dtypes.
+    data_dtype = types[0].dtype
+    kernel_dtype = types[1].dtype
+
+    # Collect the input exprs.
+    data, kernel = inputs
+
+    if data_dtype == kernel_dtype:
+        return None
+
+    assert 'int8' in data_dtype and 'int8' in kernel_dtype, \
+            "Qnn Conv2D only accepts uint8 or int8 inputs"
+
+    # Shift input if necessary.
+    input_zp = attrs['input_zero_point']
+    data = _shift(data, kernel_dtype)
+    if data_dtype == 'int8':
+        input_zp = input_zp + 128
+    elif data_dtype == 'uint8':
+        input_zp = input_zp - 128
+    else:
+        raise RuntimeError("Qnn Conv2D only accepts uint8 or int8 inputs")
+
+    new_attrs = {k : attrs[k] for k in attrs.keys()}
+    new_attrs['input_zero_point'] = input_zp
+    return relay_op(data, kernel, **new_attrs)
+
+def is_fast_int8_hw_present():
+    """
+    Checks whether the hardware has support for fast Int8 arithmetic operations.
+        1) Intel - Skylake/CascadeLake
+        2) ARM - Dotprod
+    We can extend this function to add more device targets.
+    """
+
+    target = tvm.target.current_target(allow_none=False)
+
+    # Intel cpu
+    intel_supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'}
+    is_present_intel = intel_supported_arches.intersection(set(target.options))
+
+    # ARM cpu
+    arm_supported_attr = '+v8.2a,+dotprod'
+    is_present_arm = False
+    for opt in target.options:
+        if arm_supported_attr in opt:
+            is_present_arm = True
+
+    return is_present_intel or is_present_arm
+
+########################
+# ARM CPU legalizations.
+########################
+
+@qnn_conv2d_legalize.register('arm_cpu')
+def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
+    # ARM prefers the dtypes to be same.
+    if is_fast_int8_hw_present():
+        return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
+    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
+
+@qnn_dense_legalize.register('arm_cpu')
+def _qnn_dense_legalize_arm_cpu(attrs, inputs, types):
+    # ARM prefers the dtypes to be same.
+    if is_fast_int8_hw_present():
+        return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
+    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
+
+##########################
+# Intel CPU legalizations.
+##########################
+
+@qnn_conv2d_legalize.register('cpu')
+def _qnn_conv2d_legalize_intel_cpu(attrs, inputs, types):
+    # The VNNI transformations prefer uint8 x int8 datatypes.
+    if is_fast_int8_hw_present():
 
 Review comment:
  This function is used twice — for both conv2d and dense, even on Intel CPU — so I decided to put it into a function. I think this might be ok: it gives us a single place where we can filter out the targets that have fast int8 HW support.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services