You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/02/18 21:31:46 UTC

[GitHub] [tvm] electriclilies opened a new pull request #7474: Quantization in TVM

electriclilies opened a new pull request #7474:
URL: https://github.com/apache/tvm/pull/7474


   This PR introduces a new framework for quantization to TVM. For details and explanation of the code structure, please see https://discuss.tvm.apache.org/t/rfc-quantization-quantization-in-tvm/9161. 
   
   Please let me know if you have any questions or comments.
   
   Also, I'd like to thank @mbrookhart and @jwfromm for their mentorship on this project!


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on a change in pull request #7474: Quantization in TVM

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #7474:
URL: https://github.com/apache/tvm/pull/7474#discussion_r578915649



##########
File path: src/relay/qnn/op/quantize.cc
##########
@@ -46,19 +47,22 @@ bool QuantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   }
 
   const auto input_dtype = data->dtype;
-  ICHECK(input_dtype == DataType::Float(32))
-      << "Input type should be one of float32 but was " << input_dtype;
+  ICHECK(input_dtype == DataType::Float(32) || input_dtype == DataType::Int(32))

Review comment:
       I have no idea why the input to quantize could be int32.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on a change in pull request #7474: Quantization in TVM

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #7474:
URL: https://github.com/apache/tvm/pull/7474#discussion_r580740763



##########
File path: src/relay/qnn/op/quantize.cc
##########
@@ -46,19 +47,22 @@ bool QuantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   }
 
   const auto input_dtype = data->dtype;
-  ICHECK(input_dtype == DataType::Float(32))
-      << "Input type should be one of float32 but was " << input_dtype;
+  ICHECK(input_dtype == DataType::Float(32) || input_dtype == DataType::Int(32))

Review comment:
       Sure but it is the output that needs to be int32. Here, you are allowing inputs to be int32. Bias is always float. Am I missing something?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] electriclilies commented on a change in pull request #7474: Quantization in TVM

Posted by GitBox <gi...@apache.org>.
electriclilies commented on a change in pull request #7474:
URL: https://github.com/apache/tvm/pull/7474#discussion_r580435933



##########
File path: include/tvm/relay/qnn/attrs.h
##########
@@ -78,13 +78,18 @@ struct QuantizeAttrs : public tvm::AttrsNode<QuantizeAttrs> {
 /*! \brief Attribute for dequantize operator */
 struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
   int axis;
+  DataType out_dtype;
 
   TVM_DECLARE_ATTRS(DequantizeAttrs, "relay.attrs.DequantizeAttrs") {
     TVM_ATTR_FIELD(axis)
         .describe(
             "The channel axis for channel wise dequantization. Default value is -1,"
             "which corresponds to the last axis.")
         .set_default(-1);
+    TVM_ATTR_FIELD(out_dtype)
+        .describe(
+            "The datatype we are dequantizing to (float32 or int32). Defaults to float32.")

Review comment:
       The reason I did this is that nn.conv2d sometimes has int32 as an output dtype. Since we are replacing nn.conv2d with a pattern whose final op is dequantize, dequantize also needs to be able to output int32 as a dtype. If I don't introduce an out_dtype, then sometimes the qnn graph won't pass the type checker because the consumer of dequantize expects an int32 when it can only output a float32. 
   
   Here's where I use out_dtype='int32': https://github.com/apache/tvm/blob/49801e85ff6244b1d0567965fec18544ce51dd70/python/tvm/relay/transform/quantize/_quantizer_patterns.py#L325-#L328
   

##########
File path: python/tvm/relay/transform/quantize/_quantizer_patterns.py
##########
@@ -0,0 +1,712 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Patterns to quantize and how to quantize them."""
+
+import tvm
+from tvm import relay
+
+from tvm.relay.transform.quantize import CalibrationCallback
+from tvm.relay.dataflow_pattern import (
+    is_op,
+    wildcard,
+    is_constant,
+    DFPatternCallback,
+    _DFPatternCallback,
+)
+from tvm.relay.dataflow_pattern import ffi as pattern_ffi
+from tvm.relay.frontend.common import infer_type
+from tvm.relay.op.nn.utils import get_pad_tuple2d
+
+
+class QuantizerPattern(DFPatternCallback):
+    """DFPatternCallback to rewrite patterns as quantized. Also contains extra information
+    used for quantization and calibration.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate the nn.conv2d pattern.
+    """
+
+    # Counts the number of times we've added a scale and zp for variable naming
+    # This needs to be a global variable and not initialized in __init__ because
+    # each scale and zero point must be unique, even if they are created by different
+    # instances.
+    scales_count = 0
+    zp_count = 0
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__()
+        self.calibration_callback = calibration_callback
+
+    def calibrate_pattern(self, calibration_info):
+        """Calculates the scale and zero points for quantizing parts of a generic pattern. By
+        default, we call the calibrate_pattern method of the CalibrationCallback object that is
+        passed into QuantizerPattern during initialization. However, if you want a pattern specific
+        quantization method or a per-channel quantization method, you should overwrite the
+        QuantizerPattern's calibrate_pattern method.
+
+        Parameters
+        ----------
+        calibration_info : CalibrationInfo
+            The class containing relevant information and utility functions to calibrate one
+            instance of a pattern.
+
+        Returns
+        -------
+        scale_zp_map : Dictionary
+            A map from the names of scales and zero point variables in this pattern to their
+            values.
+        """
+        return self.calibration_callback.calibrate_pattern(calibration_info)
+
+    def callback(self, pre, post, node_map):
+        raise NotImplementedError
+
+    def scale(self, name):
+        """Helper to create the scale variable for qnn.quantize when rewriting our pattern.
+
+        Parameters
+        ----------
+        name : str
+            Identifier at the beginning of the scale variable.
+
+        is_weight : bool
+            Whether this scale is a weight scale or a data scale. If it is a weight scale, we
+            the returned variable has shape (channels,). Only used for per-channel quantization.
+
+        Returns
+        -------
+        var : relay.Var
+            Relay variable for scale. If the input name is 'conv2d_data', then the name of the
+            relay variable might be 'conv2d_data_scale_0'.
+        """
+
+        var = relay.var(
+            str(name) + "_scale_" + str(QuantizerPattern.scales_count), shape=(), dtype="float32"
+        )
+        QuantizerPattern.scales_count += 1
+        return var
+
+    def zero_point(self, name):
+        """Helper to create the zero point variable for qnn.quantize when rewriting our
+        our pattern.
+
+        Parameters
+        ----------
+        name : str
+            Identifier at the beginning of the variable.
+
+        Returns
+        -------
+        var : relay.Var
+            Relay variable for scale. If the input name is 'conv2d_data', then the name of the
+            relay variable might be 'conv2d_data_zero_pt_0'.
+        """
+        var = relay.var(
+            str(name) + "_zero_pt_" + str(QuantizerPattern.zp_count), shape=(), dtype="int32"
+        )
+        QuantizerPattern.zp_count += 1
+        return var
+
+    def create_scale_zps(self, left_name, right_name):
+        """Helper to create scales and zero points for binops.
+
+        Parameters
+        ----------
+        left_name : str
+            Identifier of the left hand side scale and zero point.
+
+        right_name : str
+            Identifier of the right hand side scale and zero point.
+        """
+        data_scale = self.scale(left_name)
+        data_zp = self.zero_point(left_name)
+        weight_scale = self.scale(right_name)
+        weight_zp = self.zero_point(right_name)
+        self.scale_zps = [data_scale, data_zp, weight_scale, weight_zp]
+
+
+class Conv2DPattern(QuantizerPattern):
+    """Pattern to rewrite nn.conv2d ops as qnn.conv2d ops.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.input = wildcard()
+        self.conv_weight = wildcard()
+        self.inputs = [self.input, self.conv_weight]
+        self.conv2d = is_op("nn.conv2d")(self.input, self.conv_weight)
+        self.pattern = self.conv2d
+        self.attrs = None
+        self.weight_channel_axis = None
+        self.data_channel_axis = None
+        self.channels = None
+
+    def get_kernel_size(self, kernel_shape, kernel_layout):
+        """Gets the size of the kernel.
+
+        Parameters
+        ----------
+        kernel_shape : NDArray
+            Shape of the kernel
+
+        kernel_layout : str
+            Layout of the kernel
+
+        Returns
+        -------
+            kernel_size : NDArray
+                Size of the kernel
+        """
+        if kernel_layout == "OIHW":
+            kernel_size = tuple(kernel_shape[2:4])
+        elif kernel_layout == "HWIO":
+            kernel_size = tuple(kernel_shape[0:2])
+        else:
+            raise ValueError(
+                "Quantizting kernel layout %s for conv2d is not yet supported."
+                + "Please use OIHW or HWIO",
+                kernel_layout,
+            )
+        return kernel_size
+
+    def get_attrs(self, attrs, kernel_shape):
+        """Constructs the attributes for qnn.conv2d.
+
+        Parameters
+        ----------
+        attrs : dict
+            Attributes of the original nn.conv2d
+
+        kernel_shape : NDArray
+            Shape of the kernel
+
+        Returns
+        -------
+            quantized_attrs : dict
+                Attributes for the qnn.conv2d
+        """
+        new_attr_dict = {}
+        self.kernel_layout = attrs["kernel_layout"]
+        data_layout = attrs["data_layout"]
+
+        if self.kernel_layout == "OIHW":
+            self.weight_channel_axis = 0
+        elif self.kernel_layout == "HWIO":
+            self.weight_channel_axis = 3
+        else:
+            raise ValueError(
+                "Quantizing kernel layout %s for conv2d is not yet supported."
+                + "Please use OIHW or HWIO",
+                self.kernel_layout,
+            )
+
+        if data_layout == "NCHW":
+            self.data_channel_axis = 1
+        elif data_layout == "NHWC":
+            self.data_channel_axis = 3
+        else:
+            raise ValueError(
+                "Quantizing data layout %s for conv2d is not yet supported."
+                + "Please use NCHW or NHWC",
+                data_layout,
+            )
+
+        for attr in attrs.keys():
+            attr_value = attrs[attr]
+            if isinstance(attr_value, tvm.ir.container.Array):
+                attr_value = tuple(attr_value)
+            if attr == "kernel_size":
+                kernel_size = attrs[attr]
+                if kernel_size is None:
+                    kernel_size = self.get_kernel_size(self.kernel_layout, kernel_shape)
+                else:
+                    kernel_size = tuple([k.value for k in attrs[attr]])
+                new_attr_dict[attr] = kernel_size
+            elif attr == "channels":
+                self.channels = attrs[attr]
+                if self.channels is None:
+                    self.channels = kernel_shape[self.weight_channel_axis]
+                if isinstance(self.channels, tvm.tir.expr.IntImm):
+                    self.channels = self.channels.value
+                new_attr_dict[attr] = self.channels
+            elif attr == "padding":
+                # We don't need to put padding in attr dict because we explicitly construct padding
+                self.padding = attrs[attr]
+            else:
+                new_attr_dict[attr] = attr_value
+
+        new_attr_dict["out_dtype"] = "int32"
+        self.attrs = new_attr_dict
+
+    def quantize_args(self):
+        """Helper to quantize the arguments to the qnn.conv2d."""
+        quantized_data = relay.qnn.op.quantize(
+            self.args[0], self.scale_zps[0], self.scale_zps[1], axis=self.data_channel_axis

Review comment:
       Thanks for pointing this out. I'll add a dtype parameter to QuantizerPattern so we can pass in uint8 as well as int8.

##########
File path: python/tvm/relay/transform/quantize/_requantizer.py
##########
@@ -0,0 +1,312 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Removes extraneous qnn.quantize and qnn.dequantize from calibrated modules, and replaces them
+with qnn.requanize ops."""
+import math
+
+import tvm
+from tvm import relay
+from tvm.relay.dataflow_pattern import DFPatternCallback, wildcard, is_op, dominates, rewrite
+
+
+class Requantizer:
+    """Removes extraneous qnn.quantize and qnn.dequantize and replaces
+    them with qnn.requantize."""
+
+    class RequantizerCallback(DFPatternCallback):
+        """First pass that inserts requantize ops, specifically taking
+        qnn.dequantize -> qnn.quantize to qnn.requantize
+        and
+        qnn.dequantize -> int8_op* -> qnn.quantize to requantize -> int8_op*
+        """
+
+        def __init__(self):
+            super().__init__()
+
+            self.data = wildcard()
+            self.dequantize_scale = wildcard()
+            self.dequantize_zp = wildcard()
+
+            self.quantize_scale = wildcard()
+            self.quantize_zp = wildcard()
+
+            # Ops that are permitted inbetween quantize and dequantize if we are
+            # rewriting to requantize
+            self.is_int_8_op = (
+                is_op("nn.max_pool2d")(wildcard())
+                | is_op("nn.max_pool2d")(wildcard())
+                | is_op("nn.max_pool3d")(wildcard())
+                | is_op("nn.relu")(wildcard())
+                | is_op("transpose")(wildcard())
+                | is_op("reshape")(wildcard())
+                | is_op("nn.pad")(wildcard())
+                | is_op("squeeze")(wildcard())
+                | is_op("nn.global_avg_pool2d")
+                | is_op("nn.batch_flatten")
+                | is_op("copy")
+                | is_op("mean")
+                | is_op("sqrt")
+            )
+
+            # All ops in is_int_8_op must also be in self.op_map
+            self.op_map = {
+                relay.op.get("nn.max_pool2d"): relay.op.nn.max_pool2d,
+                relay.op.get("nn.max_pool3d"): relay.op.nn.max_pool3d,
+                relay.op.get("transpose"): relay.op.transpose,
+                relay.op.get("reshape"): relay.op.reshape,
+                relay.op.get("nn.pad"): relay.op.nn.pad,
+                relay.op.get("squeeze"): relay.op.squeeze,
+                relay.op.get("nn.global_avg_pool2d"): relay.op.nn.global_avg_pool2d,
+                relay.op.get("nn.batch_flatten"): relay.op.nn.batch_flatten,
+                relay.op.get("copy"): relay.op.copy,
+                relay.op.get("mean"): relay.op.mean,
+                relay.op.get("sqrt"): relay.op.sqrt,
+            }
+
+            # Main pattern -- quantize(is_int_8_op*(dequantize(data))) --
+            # (with 1 or more is_int_8_ops)
+            self.dequantize = is_op("qnn.dequantize")(
+                self.data, self.dequantize_scale, self.dequantize_zp
+            )
+
+            self.dominator = dominates(self.dequantize, self.is_int_8_op, self.is_int_8_op)
+            self.quantize = is_op("qnn.quantize")(
+                self.dominator, self.quantize_scale, self.quantize_zp
+            )
+
+            # Pattern with the null path : quantize(dequantize(data)) -- (no is_int_8_op inbetween)
+            # We have to do the null path outside the dominator pattern because of pattern matcher
+            # limitations
+            self.no_path_dequantize = is_op("qnn.dequantize")(
+                self.data, self.dequantize_scale, self.dequantize_zp
+            )
+            self.no_path_quantize = is_op("qnn.quantize")(
+                self.no_path_dequantize, self.quantize_scale, self.quantize_zp
+            )
+
+            self.pattern = self.quantize | self.no_path_quantize
+
+        def callback(self, pre, post, node_map):
+            # Extract data from the pattern
+            data = node_map[self.data][0]
+            dequantize_scale = node_map[self.dequantize_scale][0]
+            deq_zp = node_map[self.dequantize_zp][0]
+
+            quantize_scale = node_map[self.quantize_scale][0]
+            quantize_zp = node_map[self.quantize_zp][0]
+
+            # Case where there are no ops in between the dequantize and quantize
+            if self.no_path_quantize in node_map:
+                axis = node_map[self.no_path_dequantize][0].attrs.axis
+                res = relay.qnn.op.requantize(
+                    data, dequantize_scale, deq_zp, quantize_scale, quantize_zp, axis=axis
+                )
+            # Ops inbetween quantize and dequantize are dominated
+            elif self.quantize in node_map:
+
+                axis = node_map[self.dequantize][0].attrs.axis
+                transformed_data = relay.qnn.op.requantize(
+                    data, dequantize_scale, deq_zp, quantize_scale, quantize_zp, axis=axis
+                )
+                for i in range(len(node_map[self.is_int_8_op]) - 1, -1, -1):
+                    call = node_map[self.is_int_8_op][i]
+                    # Transform relu into max(zeropoint)
+                    if call.op == relay.op.get("nn.relu"):
+                        if (
+                            quantize_zp.data.asnumpy()
+                            == relay.const(0, dtype="int32").data.asnumpy()
+                        ):
+                            transformed_data = relay.op.nn.relu(transformed_data)
+                        else:
+                            transformed_data = relay.op.maximum(
+                                transformed_data, relay.cast(quantize_zp, "int8")
+                            )
+                    elif call.op in self.op_map.keys():
+                        transformed_data = self.op_map[call.op](transformed_data, **call.attrs)
+                    else:
+                        raise ValueError(
+                            "Uh oh, %s is not copied properly in the requantizer. " % str(call.op)
+                        )
+                res = transformed_data
+            return res
+
+    class RequantizeChainCallback(DFPatternCallback):
+        """Folds chains of requantizes into one requantize.
+        requantize(scale_a, zp_a, scale_b, zp_b) -> requantize(scale_b, zp_b, scale_c, zp_c) becomes

Review comment:
       I have definitely seen it which is why I put it in, but I'll have to find the example again. Let me get back to you about this.

##########
File path: src/relay/qnn/op/quantize.cc
##########
@@ -46,19 +47,22 @@ bool QuantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   }
 
   const auto input_dtype = data->dtype;
-  ICHECK(input_dtype == DataType::Float(32))
-      << "Input type should be one of float32 but was " << input_dtype;
+  ICHECK(input_dtype == DataType::Float(32) || input_dtype == DataType::Int(32))

Review comment:
       When quantizing nn.conv2d -> nn.bias_add, we quantize the bias weight to int32. This is what other quantization frameworks do, including pytorch (see https://stackoverflow.com/questions/63132181/how-does-bias-work-in-pytorch-quantized-convolution). 
   
   See this snippet from https://github.com/apache/tvm/blob/49801e85ff6244b1d0567965fec18544ce51dd70/python/tvm/relay/transform/quantize/_quantizer_patterns.py#L352-#L355
   
   ```
   quantized_bias = relay.qnn.op.quantize(
               self.args[2], self.scale_zps[0], self.scale_zps[1], axis=0, out_dtype="int32"
           )
           self.quantized_args.append(quantized_bias)
   ```
   

##########
File path: src/relay/qnn/op/dequantize.cc
##########
@@ -105,6 +115,10 @@ Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale,
 
   auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), expanded_input_zero_point);
   auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), expanded_input_scale);
+
+  if (out_dtype != DataType::Float(32)) {
+    scaled_output = Cast(scaled_output, out_dtype);

Review comment:
       What would be correct?

##########
File path: python/tvm/relay/transform/quantize/_requantizer.py
##########
@@ -0,0 +1,312 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Removes extraneous qnn.quantize and qnn.dequantize from calibrated modules, and replaces them
+with qnn.requanize ops."""
+import math
+
+import tvm
+from tvm import relay
+from tvm.relay.dataflow_pattern import DFPatternCallback, wildcard, is_op, dominates, rewrite
+
+
+class Requantizer:
+    """Removes extraneous qnn.quantize and qnn.dequantize and replaces
+    them with qnn.requantize."""
+
+    class RequantizerCallback(DFPatternCallback):
+        """First pass that inserts requantize ops, specifically taking
+        qnn.dequantize -> qnn.quantize to qnn.requantize
+        and
+        qnn.dequantize -> int8_op* -> qnn.quantize to requantize -> int8_op*
+        """
+
+        def __init__(self):
+            super().__init__()
+
+            self.data = wildcard()
+            self.dequantize_scale = wildcard()
+            self.dequantize_zp = wildcard()
+
+            self.quantize_scale = wildcard()
+            self.quantize_zp = wildcard()
+
+            # Ops that are permitted inbetween quantize and dequantize if we are
+            # rewriting to requantize
+            self.is_int_8_op = (
+                is_op("nn.max_pool2d")(wildcard())
+                | is_op("nn.max_pool2d")(wildcard())
+                | is_op("nn.max_pool3d")(wildcard())
+                | is_op("nn.relu")(wildcard())
+                | is_op("transpose")(wildcard())
+                | is_op("reshape")(wildcard())
+                | is_op("nn.pad")(wildcard())
+                | is_op("squeeze")(wildcard())
+                | is_op("nn.global_avg_pool2d")
+                | is_op("nn.batch_flatten")
+                | is_op("copy")
+                | is_op("mean")
+                | is_op("sqrt")
+            )

Review comment:
       Can you provide more feedback about how it can break and what quantize and dequantizes it leaves? On all the graphs I've tried this on so far, I've been able to remove all extraneous quantize and dequantizes. 

##########
File path: python/tvm/relay/transform/quantize/_requantizer.py
##########
@@ -0,0 +1,312 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Removes extraneous qnn.quantize and qnn.dequantize from calibrated modules, and replaces them
+with qnn.requanize ops."""
+import math
+
+import tvm
+from tvm import relay
+from tvm.relay.dataflow_pattern import DFPatternCallback, wildcard, is_op, dominates, rewrite
+
+
+class Requantizer:
+    """Removes extraneous qnn.quantize and qnn.dequantize and replaces
+    them with qnn.requantize."""
+
+    class RequantizerCallback(DFPatternCallback):
+        """First pass that inserts requantize ops, specifically taking
+        qnn.dequantize -> qnn.quantize to qnn.requantize
+        and
+        qnn.dequantize -> int8_op* -> qnn.quantize to requantize -> int8_op*
+        """
+
+        def __init__(self):
+            super().__init__()
+
+            self.data = wildcard()
+            self.dequantize_scale = wildcard()
+            self.dequantize_zp = wildcard()
+
+            self.quantize_scale = wildcard()
+            self.quantize_zp = wildcard()
+
+            # Ops that are permitted inbetween quantize and dequantize if we are
+            # rewriting to requantize
+            self.is_int_8_op = (
+                is_op("nn.max_pool2d")(wildcard())
+                | is_op("nn.max_pool2d")(wildcard())
+                | is_op("nn.max_pool3d")(wildcard())
+                | is_op("nn.relu")(wildcard())
+                | is_op("transpose")(wildcard())
+                | is_op("reshape")(wildcard())
+                | is_op("nn.pad")(wildcard())
+                | is_op("squeeze")(wildcard())
+                | is_op("nn.global_avg_pool2d")
+                | is_op("nn.batch_flatten")
+                | is_op("copy")
+                | is_op("mean")
+                | is_op("sqrt")
+            )
+
+            # All ops in is_int_8_op must also be in self.op_map
+            self.op_map = {
+                relay.op.get("nn.max_pool2d"): relay.op.nn.max_pool2d,
+                relay.op.get("nn.max_pool3d"): relay.op.nn.max_pool3d,
+                relay.op.get("transpose"): relay.op.transpose,
+                relay.op.get("reshape"): relay.op.reshape,
+                relay.op.get("nn.pad"): relay.op.nn.pad,
+                relay.op.get("squeeze"): relay.op.squeeze,
+                relay.op.get("nn.global_avg_pool2d"): relay.op.nn.global_avg_pool2d,
+                relay.op.get("nn.batch_flatten"): relay.op.nn.batch_flatten,
+                relay.op.get("copy"): relay.op.copy,
+                relay.op.get("mean"): relay.op.mean,
+                relay.op.get("sqrt"): relay.op.sqrt,
+            }
+
+            # Main pattern -- quantize(is_int_8_op*(dequantize(data))) --
+            # (with 1 or more is_int_8_ops)
+            self.dequantize = is_op("qnn.dequantize")(
+                self.data, self.dequantize_scale, self.dequantize_zp
+            )
+
+            self.dominator = dominates(self.dequantize, self.is_int_8_op, self.is_int_8_op)
+            self.quantize = is_op("qnn.quantize")(
+                self.dominator, self.quantize_scale, self.quantize_zp
+            )
+
+            # Pattern with the null path : quantize(dequantize(data)) -- (no is_int_8_op inbetween)
+            # We have to do the null path outside the dominator pattern because of pattern matcher
+            # limitations
+            self.no_path_dequantize = is_op("qnn.dequantize")(
+                self.data, self.dequantize_scale, self.dequantize_zp
+            )
+            self.no_path_quantize = is_op("qnn.quantize")(
+                self.no_path_dequantize, self.quantize_scale, self.quantize_zp
+            )
+
+            self.pattern = self.quantize | self.no_path_quantize
+
+        def callback(self, pre, post, node_map):
+            # Extract data from the pattern
+            data = node_map[self.data][0]
+            dequantize_scale = node_map[self.dequantize_scale][0]
+            deq_zp = node_map[self.dequantize_zp][0]
+
+            quantize_scale = node_map[self.quantize_scale][0]
+            quantize_zp = node_map[self.quantize_zp][0]
+
+            # Case where there are no ops in between the dequantize and quantize
+            if self.no_path_quantize in node_map:
+                axis = node_map[self.no_path_dequantize][0].attrs.axis
+                res = relay.qnn.op.requantize(
+                    data, dequantize_scale, deq_zp, quantize_scale, quantize_zp, axis=axis
+                )
+            # Ops inbetween quantize and dequantize are dominated
+            elif self.quantize in node_map:
+
+                axis = node_map[self.dequantize][0].attrs.axis
+                transformed_data = relay.qnn.op.requantize(
+                    data, dequantize_scale, deq_zp, quantize_scale, quantize_zp, axis=axis
+                )
+                for i in range(len(node_map[self.is_int_8_op]) - 1, -1, -1):
+                    call = node_map[self.is_int_8_op][i]
+                    # Transform relu into max(zeropoint)
+                    if call.op == relay.op.get("nn.relu"):
+                        if (
+                            quantize_zp.data.asnumpy()
+                            == relay.const(0, dtype="int32").data.asnumpy()
+                        ):
+                            transformed_data = relay.op.nn.relu(transformed_data)
+                        else:
+                            transformed_data = relay.op.maximum(
+                                transformed_data, relay.cast(quantize_zp, "int8")
+                            )
+                    elif call.op in self.op_map.keys():
+                        transformed_data = self.op_map[call.op](transformed_data, **call.attrs)
+                    else:
+                        raise ValueError(
+                            "Uh oh, %s is not copied properly in the requantizer. " % str(call.op)

Review comment:
       The allow_overlapping_patterns option lets me match dequantizes which are consumed by two quantizes. 

##########
File path: python/tvm/relay/transform/quantize/_quantizer_patterns.py
##########
@@ -0,0 +1,712 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Patterns to quantize and how to quantize them."""
+
+import tvm
+from tvm import relay
+
+from tvm.relay.transform.quantize import CalibrationCallback
+from tvm.relay.dataflow_pattern import (
+    is_op,
+    wildcard,
+    is_constant,
+    DFPatternCallback,
+    _DFPatternCallback,
+)
+from tvm.relay.dataflow_pattern import ffi as pattern_ffi
+from tvm.relay.frontend.common import infer_type
+from tvm.relay.op.nn.utils import get_pad_tuple2d
+
+
+class QuantizerPattern(DFPatternCallback):
+    """DFPatternCallback to rewrite patterns as quantized. Also contains extra information
+    used for quantization and calibration.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate the nn.conv2d pattern.
+    """
+
+    # Counts the number of times we've added a scale and zp for variable naming
+    # This needs to be a global variable and not initialized in __init__ because
+    # each scale and zero point must be unique, even if they are created by different
+    # instances.
+    scales_count = 0
+    zp_count = 0
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__()
+        self.calibration_callback = calibration_callback
+
+    def calibrate_pattern(self, calibration_info):
+        """Calculates the scale and zero points for quantizing parts of a generic pattern. By
+        default, we call the calibrate_pattern method of the CalibrationCallback object that is
+        passed into QuantizerPattern during initialization. However, if you want a pattern specific
+        quantization method or a per-channel quantization method, you should overwrite the
+        QuantizerPattern's calibrate_pattern method.
+
+        Parameters
+        ----------
+        calibration_info : CalibrationInfo
+            The class containing relevant information and utility functions to calibrate one
+            instance of a pattern.
+
+        Returns
+        -------
+        scale_zp_map : Dictionary
+            A map from the names of scales and zero point variables in this pattern to their
+            values.
+        """
+        return self.calibration_callback.calibrate_pattern(calibration_info)
+
+    def callback(self, pre, post, node_map):
+        raise NotImplementedError
+
+    def scale(self, name):
+        """Helper to create the scale variable for qnn.quantize when rewriting our pattern.
+
+        Parameters
+        ----------
+        name : str
+            Identifier at the beginning of the scale variable.
+
+        is_weight : bool
+            Whether this scale is a weight scale or a data scale. If it is a weight scale, we
+            the returned variable has shape (channels,). Only used for per-channel quantization.
+
+        Returns
+        -------
+        var : relay.Var
+            Relay variable for scale. If the input name is 'conv2d_data', then the name of the
+            relay variable might be 'conv2d_data_scale_0'.
+        """
+
+        var = relay.var(
+            str(name) + "_scale_" + str(QuantizerPattern.scales_count), shape=(), dtype="float32"
+        )
+        QuantizerPattern.scales_count += 1
+        return var
+
+    def zero_point(self, name):
+        """Helper to create the zero point variable for qnn.quantize when rewriting our
+        our pattern.
+
+        Parameters
+        ----------
+        name : str
+            Identifier at the beginning of the variable.
+
+        Returns
+        -------
+        var : relay.Var
+            Relay variable for scale. If the input name is 'conv2d_data', then the name of the
+            relay variable might be 'conv2d_data_zero_pt_0'.
+        """
+        var = relay.var(
+            str(name) + "_zero_pt_" + str(QuantizerPattern.zp_count), shape=(), dtype="int32"
+        )
+        QuantizerPattern.zp_count += 1
+        return var
+
+    def create_scale_zps(self, left_name, right_name):
+        """Helper to create scales and zero points for binops.
+
+        Parameters
+        ----------
+        left_name : str
+            Identifier of the left hand side scale and zero point.
+
+        right_name : str
+            Identifier of the right hand side scale and zero point.
+        """
+        data_scale = self.scale(left_name)
+        data_zp = self.zero_point(left_name)
+        weight_scale = self.scale(right_name)
+        weight_zp = self.zero_point(right_name)
+        self.scale_zps = [data_scale, data_zp, weight_scale, weight_zp]
+
+
+class Conv2DPattern(QuantizerPattern):
+    """Pattern to rewrite nn.conv2d ops as qnn.conv2d ops.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.input = wildcard()
+        self.conv_weight = wildcard()
+        self.inputs = [self.input, self.conv_weight]
+        self.conv2d = is_op("nn.conv2d")(self.input, self.conv_weight)
+        self.pattern = self.conv2d
+        self.attrs = None
+        self.weight_channel_axis = None
+        self.data_channel_axis = None
+        self.channels = None
+
+    def get_kernel_size(self, kernel_shape, kernel_layout):
+        """Gets the size of the kernel.
+
+        Parameters
+        ----------
+        kernel_shape : NDArray
+            Shape of the kernel
+
+        kernel_layout : str
+            Layout of the kernel
+
+        Returns
+        -------
+            kernel_size : NDArray
+                Size of the kernel
+        """
+        if kernel_layout == "OIHW":
+            kernel_size = tuple(kernel_shape[2:4])
+        elif kernel_layout == "HWIO":
+            kernel_size = tuple(kernel_shape[0:2])
+        else:
+            raise ValueError(
+                "Quantizting kernel layout %s for conv2d is not yet supported."
+                + "Please use OIHW or HWIO",
+                kernel_layout,
+            )
+        return kernel_size
+
+    def get_attrs(self, attrs, kernel_shape):
+        """Constructs the attributes for qnn.conv2d.
+
+        Parameters
+        ----------
+        attrs : dict
+            Attributes of the original nn.conv2d
+
+        kernel_shape : NDArray
+            Shape of the kernel
+
+        Returns
+        -------
+            quantized_attrs : dict
+                Attributes for the qnn.conv2d
+        """
+        new_attr_dict = {}
+        self.kernel_layout = attrs["kernel_layout"]
+        data_layout = attrs["data_layout"]
+
+        if self.kernel_layout == "OIHW":
+            self.weight_channel_axis = 0
+        elif self.kernel_layout == "HWIO":
+            self.weight_channel_axis = 3
+        else:
+            raise ValueError(
+                "Quantizing kernel layout %s for conv2d is not yet supported."
+                + "Please use OIHW or HWIO",
+                self.kernel_layout,
+            )
+
+        if data_layout == "NCHW":
+            self.data_channel_axis = 1
+        elif data_layout == "NHWC":
+            self.data_channel_axis = 3
+        else:
+            raise ValueError(
+                "Quantizing data layout %s for conv2d is not yet supported."
+                + "Please use NCHW or NHWC",
+                data_layout,
+            )
+
+        for attr in attrs.keys():
+            attr_value = attrs[attr]
+            if isinstance(attr_value, tvm.ir.container.Array):
+                attr_value = tuple(attr_value)
+            if attr == "kernel_size":
+                kernel_size = attrs[attr]
+                if kernel_size is None:
+                    kernel_size = self.get_kernel_size(self.kernel_layout, kernel_shape)
+                else:
+                    kernel_size = tuple([k.value for k in attrs[attr]])
+                new_attr_dict[attr] = kernel_size
+            elif attr == "channels":
+                self.channels = attrs[attr]
+                if self.channels is None:
+                    self.channels = kernel_shape[self.weight_channel_axis]
+                if isinstance(self.channels, tvm.tir.expr.IntImm):
+                    self.channels = self.channels.value
+                new_attr_dict[attr] = self.channels
+            elif attr == "padding":
+                # We don't need to put padding in attr dict because we explicitly construct padding
+                self.padding = attrs[attr]
+            else:
+                new_attr_dict[attr] = attr_value
+
+        new_attr_dict["out_dtype"] = "int32"
+        self.attrs = new_attr_dict
+
+    def quantize_args(self):
+        """Helper to quantize the arguments to the qnn.conv2d."""
+        quantized_data = relay.qnn.op.quantize(
+            self.args[0], self.scale_zps[0], self.scale_zps[1], axis=self.data_channel_axis
+        )

Review comment:
       This is an interesting suggestion. It's important to introduce a new quantize node for every input so that we can correctly do requantize later. Instead of caching the whole quantize node, I could cache the scale and zero point from the previous time the node was cached. However, I'm not sure how we would deal with this during calibration, since we are calibrating on a per-pattern basis. Either the second value set for that scale and zero point would be the actual value, or we could try to remove the scale and zero point from the second pattern. But removing a scale and zero point from a pattern would introduce some inconsistency and complexity to the callbacks that I think is not good. 
   
   CSE does work for this when we have integer constants as scales and zero points, however, it might not work for floats. There are only two cases that this really becomes an issue though: 1. a weight gets quantized multiple times 2. we end up with multiple requantize ops that requantize the same node to slightly different values. I think that dealing with this in post processing (i.e., requantizer) is the best decision. Then we can combine scales and zero points through some rule, perhaps by averaging them. 

##########
File path: python/tvm/relay/transform/quantize/_quantizer.py
##########
@@ -0,0 +1,155 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Quantizes functions by inserting qnn.quantize and qnn.dequantize ops."""
+from typing import List
+
+import tvm
+from tvm import relay
+from tvm.relay.dataflow_pattern import _DFPatternCallback
+from tvm.relay.transform.quantize import QuantizerPattern
+from tvm.relay.frontend.common import infer_type
+
+from . import _ffi as ffi
+
+
+class Quantizer:

Review comment:
       There is definitely an argument to combine the Quantizer and the QuantizationCalibrator into one class, since they are pretty closely linked. 
   
   @jwfromm, @mbrookhart and I went back and forth about this. Initially, we made them separate because they were both large classes and did logically different things, and combining them would have resulted in a lot of code in one monolithic class. However, now that I have the QuantizerPattern class and the CalibrationInfo class, which contain a lot of the code, it could make sense to combine them into one class. 
   
   One argument to keep them separate is that the Quantizer does a lot of stuff using the C backend, and keeping that separate from the Calibrater is kind of nice for readability.
   
   @jwfromm @mbrookhart What do you think?

##########
File path: python/tvm/relay/transform/quantize/_requantizer.py
##########
@@ -0,0 +1,312 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Removes extraneous qnn.quantize and qnn.dequantize from calibrated modules, and replaces them
+with qnn.requanize ops."""
+import math
+
+import tvm
+from tvm import relay
+from tvm.relay.dataflow_pattern import DFPatternCallback, wildcard, is_op, dominates, rewrite
+
+
+class Requantizer:
+    """Removes extraneous qnn.quantize and qnn.dequantize and replaces
+    them with qnn.requantize."""
+
+    class RequantizerCallback(DFPatternCallback):
+        """First pass that inserts requantize ops, specifically taking
+        qnn.dequantize -> qnn.quantize to qnn.requantize
+        and
+        qnn.dequantize -> int8_op* -> qnn.quantize to requantize -> int8_op*
+        """
+
+        def __init__(self):
+            super().__init__()
+
+            self.data = wildcard()
+            self.dequantize_scale = wildcard()
+            self.dequantize_zp = wildcard()
+
+            self.quantize_scale = wildcard()
+            self.quantize_zp = wildcard()
+
+            # Ops that are permitted inbetween quantize and dequantize if we are
+            # rewriting to requantize
+            self.is_int_8_op = (
+                is_op("nn.max_pool2d")(wildcard())
+                | is_op("nn.max_pool2d")(wildcard())
+                | is_op("nn.max_pool3d")(wildcard())
+                | is_op("nn.relu")(wildcard())
+                | is_op("transpose")(wildcard())
+                | is_op("reshape")(wildcard())
+                | is_op("nn.pad")(wildcard())
+                | is_op("squeeze")(wildcard())
+                | is_op("nn.global_avg_pool2d")
+                | is_op("nn.batch_flatten")
+                | is_op("copy")
+                | is_op("mean")
+                | is_op("sqrt")
+            )
+
+            # All ops in is_int_8_op must also be in self.op_map
+            self.op_map = {
+                relay.op.get("nn.max_pool2d"): relay.op.nn.max_pool2d,
+                relay.op.get("nn.max_pool3d"): relay.op.nn.max_pool3d,
+                relay.op.get("transpose"): relay.op.transpose,
+                relay.op.get("reshape"): relay.op.reshape,
+                relay.op.get("nn.pad"): relay.op.nn.pad,
+                relay.op.get("squeeze"): relay.op.squeeze,
+                relay.op.get("nn.global_avg_pool2d"): relay.op.nn.global_avg_pool2d,
+                relay.op.get("nn.batch_flatten"): relay.op.nn.batch_flatten,
+                relay.op.get("copy"): relay.op.copy,
+                relay.op.get("mean"): relay.op.mean,
+                relay.op.get("sqrt"): relay.op.sqrt,
+            }
+
+            # Main pattern -- quantize(is_int_8_op*(dequantize(data))) --
+            # (with 1 or more is_int_8_ops)
+            self.dequantize = is_op("qnn.dequantize")(
+                self.data, self.dequantize_scale, self.dequantize_zp
+            )
+
+            self.dominator = dominates(self.dequantize, self.is_int_8_op, self.is_int_8_op)
+            self.quantize = is_op("qnn.quantize")(
+                self.dominator, self.quantize_scale, self.quantize_zp
+            )
+
+            # Pattern with the null path : quantize(dequantize(data)) -- (no is_int_8_op inbetween)
+            # We have to do the null path outside the dominator pattern because of pattern matcher
+            # limitations
+            self.no_path_dequantize = is_op("qnn.dequantize")(
+                self.data, self.dequantize_scale, self.dequantize_zp
+            )
+            self.no_path_quantize = is_op("qnn.quantize")(
+                self.no_path_dequantize, self.quantize_scale, self.quantize_zp
+            )
+
+            self.pattern = self.quantize | self.no_path_quantize
+
+        def callback(self, pre, post, node_map):
+            # Extract data from the pattern
+            data = node_map[self.data][0]
+            dequantize_scale = node_map[self.dequantize_scale][0]
+            deq_zp = node_map[self.dequantize_zp][0]
+
+            quantize_scale = node_map[self.quantize_scale][0]
+            quantize_zp = node_map[self.quantize_zp][0]
+
+            # Case where there are no ops in between the dequantize and quantize
+            if self.no_path_quantize in node_map:
+                axis = node_map[self.no_path_dequantize][0].attrs.axis
+                res = relay.qnn.op.requantize(
+                    data, dequantize_scale, deq_zp, quantize_scale, quantize_zp, axis=axis
+                )
+            # Ops inbetween quantize and dequantize are dominated
+            elif self.quantize in node_map:
+
+                axis = node_map[self.dequantize][0].attrs.axis
+                transformed_data = relay.qnn.op.requantize(
+                    data, dequantize_scale, deq_zp, quantize_scale, quantize_zp, axis=axis
+                )
+                for i in range(len(node_map[self.is_int_8_op]) - 1, -1, -1):
+                    call = node_map[self.is_int_8_op][i]
+                    # Transform relu into max(zeropoint)
+                    if call.op == relay.op.get("nn.relu"):
+                        if (
+                            quantize_zp.data.asnumpy()
+                            == relay.const(0, dtype="int32").data.asnumpy()
+                        ):
+                            transformed_data = relay.op.nn.relu(transformed_data)
+                        else:
+                            transformed_data = relay.op.maximum(
+                                transformed_data, relay.cast(quantize_zp, "int8")
+                            )
+                    elif call.op in self.op_map.keys():
+                        transformed_data = self.op_map[call.op](transformed_data, **call.attrs)
+                    else:
+                        raise ValueError(
+                            "Uh oh, %s is not copied properly in the requantizer. " % str(call.op)
+                        )
+                res = transformed_data
+            return res
+
+    class RequantizeChainCallback(DFPatternCallback):
+        """Folds chains of requantizes into one requantize.
+        requantize(scale_a, zp_a, scale_b, zp_b) -> requantize(scale_b, zp_b, scale_c, zp_c) becomes
+        requantize(scale_a, zp_a, scale_c, zp_c)
+        """
+
+        # Takes a chain of requantizes and turns them into one requantize
+        def __init__(self):
+            super().__init__()
+            self.data = wildcard()
+            self.rq_parent_scale_in = wildcard()
+            self.rq_parent_zp_in = wildcard()
+            self.rq_parent_scale_out = wildcard()
+            self.rq_parent_zp_out = wildcard()
+
+            self.rq_child_scale_in = wildcard()
+            self.rq_child_zp_in = wildcard()
+            self.rq_child_scale_out = wildcard()
+            self.rq_child_zp_out = wildcard()
+
+            self.rq_parent = is_op("qnn.requantize")(
+                self.data,
+                self.rq_parent_scale_in,
+                self.rq_parent_zp_in,
+                self.rq_parent_scale_out,
+                self.rq_parent_zp_out,
+            )
+            self.rq_child = is_op("qnn.requantize")(
+                wildcard(),
+                self.rq_child_scale_in,
+                self.rq_child_zp_in,
+                self.rq_child_scale_out,
+                self.rq_child_zp_out,
+            )
+
+            self.pattern = dominates(self.rq_parent, self.rq_child, self.rq_child)
+
+        def callback(self, pre, post, node_map):
+            data = node_map[self.data][0]
+            rq_parent = node_map[self.rq_parent][0]
+
+            rq_parent_scale_in = node_map[self.rq_parent_scale_in][0]
+            rq_parent_zp_in = node_map[self.rq_parent_zp_in][0]
+
+            rq_parent_scale_out = node_map[self.rq_parent_scale_out][0]
+            rq_parent_zp_out = node_map[self.rq_parent_zp_out][0]
+
+            child_in_scales = node_map[self.rq_child_scale_in]
+            child_in_zps = node_map[self.rq_child_zp_in]
+            child_out_scales = node_map[self.rq_child_scale_out]
+            child_out_zps = node_map[self.rq_child_zp_out]
+
+            len_children = len(node_map[self.rq_child_scale_out])
+
+            # Check to make sure output and input scales and zps match before we apply this
+            # transformation
+            out_scale = rq_parent_scale_out
+            out_zp = rq_parent_zp_out
+
+            for i in range(0, len_children):
+
+                in_scale = child_in_scales[i]
+                in_zp = child_in_zps[i]
+
+                assert math.isclose(
+                    out_scale.data.asnumpy(), in_scale.data.asnumpy(), rel_tol=1e-05, abs_tol=1e-05
+                ) and math.isclose(
+                    out_zp.data.asnumpy(), in_zp.data.asnumpy(), rel_tol=1e-05, abs_tol=1e-05
+                ), (
+                    "Out scales/zps should match in scales/zps. Indicates an internal issue "
+                    "in the quantizer somewhere."
+                )
+
+                out_scale = child_out_scales[i]
+                out_zp = child_out_zps[i]
+
+            parent_axis = rq_parent.attrs["axis"]
+
+            return relay.qnn.op.requantize(
+                data, rq_parent_scale_in, rq_parent_zp_in, out_scale, out_zp, axis=parent_axis
+            )
+
+    class ConsolidateRequantizeandQuantize(DFPatternCallback):
+        """Gets rid of unnecessary requantizes directly following a quantize. Takes
+        quantize(scale_a, zp_a) -> requantize(scale_a, zp_a, scale_b, zp_b) to

Review comment:
       A good example is actually the result of the add pattern. 
   
   ```
   quantize -> dequantize -> quantize -> 
   quantize -> dequantize -> quantize -> add -> dequantize
   ```
   becomes 
   ```
   quantize -> requantize -> 
   quantize -> requantize -> add -> dequantize
   ```
   Then we need to combine the quantize and requantize into one quantize:
   ```
   quantize -> 
   quantize -> add -> dequantize
   ```
   

##########
File path: python/tvm/relay/transform/quantize/_quantizer_patterns.py
##########
@@ -0,0 +1,712 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Patterns to quantize and how to quantize them."""
+
+import tvm
+from tvm import relay
+
+from tvm.relay.transform.quantize import CalibrationCallback
+from tvm.relay.dataflow_pattern import (
+    is_op,
+    wildcard,
+    is_constant,
+    DFPatternCallback,
+    _DFPatternCallback,
+)
+from tvm.relay.dataflow_pattern import ffi as pattern_ffi
+from tvm.relay.frontend.common import infer_type
+from tvm.relay.op.nn.utils import get_pad_tuple2d
+
+
+class QuantizerPattern(DFPatternCallback):
+    """DFPatternCallback to rewrite patterns as quantized. Also contains extra information
+    used for quantization and calibration.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate the nn.conv2d pattern.
+    """
+
+    # Counts the number of times we've added a scale and zp for variable naming
+    # This needs to be a global variable and not initialized in __init__ because
+    # each scale and zero point must be unique, even if they are created by different
+    # instances.
+    scales_count = 0
+    zp_count = 0
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__()
+        self.calibration_callback = calibration_callback
+
+    def calibrate_pattern(self, calibration_info):
+        """Calculates the scale and zero points for quantizing parts of a generic pattern. By
+        default, we call the calibrate_pattern method of the CalibrationCallback object that is
+        passed into QuantizerPattern during initialization. However, if you want a pattern specific
+        quantization method or a per-channel quantization method, you should overwrite the
+        QuantizerPattern's calibrate_pattern method.
+
+        Parameters
+        ----------
+        calibration_info : CalibrationInfo
+            The class containing relevant information and utility functions to calibrate one
+            instance of a pattern.
+
+        Returns
+        -------
+        scale_zp_map : Dictionary
+            A map from the names of scales and zero point variables in this pattern to their
+            values.
+        """
+        return self.calibration_callback.calibrate_pattern(calibration_info)
+
+    def callback(self, pre, post, node_map):
+        raise NotImplementedError
+
+    def scale(self, name):
+        """Helper to create the scale variable for qnn.quantize when rewriting our pattern.
+
+        Parameters
+        ----------
+        name : str
+            Identifier at the beginning of the scale variable.
+
+        is_weight : bool
+            Whether this scale is a weight scale or a data scale. If it is a weight scale, we
+            the returned variable has shape (channels,). Only used for per-channel quantization.
+
+        Returns
+        -------
+        var : relay.Var
+            Relay variable for scale. If the input name is 'conv2d_data', then the name of the
+            relay variable might be 'conv2d_data_scale_0'.
+        """
+
+        var = relay.var(
+            str(name) + "_scale_" + str(QuantizerPattern.scales_count), shape=(), dtype="float32"
+        )
+        QuantizerPattern.scales_count += 1
+        return var
+
+    def zero_point(self, name):
+        """Helper to create the zero point variable for qnn.quantize when rewriting our
+        our pattern.
+
+        Parameters
+        ----------
+        name : str
+            Identifier at the beginning of the variable.
+
+        Returns
+        -------
+        var : relay.Var
+            Relay variable for scale. If the input name is 'conv2d_data', then the name of the
+            relay variable might be 'conv2d_data_zero_pt_0'.
+        """
+        var = relay.var(
+            str(name) + "_zero_pt_" + str(QuantizerPattern.zp_count), shape=(), dtype="int32"
+        )
+        QuantizerPattern.zp_count += 1
+        return var
+
+    def create_scale_zps(self, left_name, right_name):
+        """Helper to create scales and zero points for binops.
+
+        Parameters
+        ----------
+        left_name : str
+            Identifier of the left hand side scale and zero point.
+
+        right_name : str
+            Identifier of the right hand side scale and zero point.
+        """
+        data_scale = self.scale(left_name)
+        data_zp = self.zero_point(left_name)
+        weight_scale = self.scale(right_name)
+        weight_zp = self.zero_point(right_name)
+        self.scale_zps = [data_scale, data_zp, weight_scale, weight_zp]
+
+
+class Conv2DPattern(QuantizerPattern):
+    """Pattern to rewrite nn.conv2d ops as qnn.conv2d ops.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.input = wildcard()
+        self.conv_weight = wildcard()
+        self.inputs = [self.input, self.conv_weight]
+        self.conv2d = is_op("nn.conv2d")(self.input, self.conv_weight)
+        self.pattern = self.conv2d
+        self.attrs = None
+        self.weight_channel_axis = None
+        self.data_channel_axis = None
+        self.channels = None
+
+    def get_kernel_size(self, kernel_shape, kernel_layout):
+        """Gets the size of the kernel.
+
+        Parameters
+        ----------
+        kernel_shape : NDArray
+            Shape of the kernel
+
+        kernel_layout : str
+            Layout of the kernel
+
+        Returns
+        -------
+            kernel_size : NDArray
+                Size of the kernel
+        """
+        if kernel_layout == "OIHW":
+            kernel_size = tuple(kernel_shape[2:4])
+        elif kernel_layout == "HWIO":
+            kernel_size = tuple(kernel_shape[0:2])
+        else:
+            raise ValueError(
+                "Quantizting kernel layout %s for conv2d is not yet supported."
+                + "Please use OIHW or HWIO",
+                kernel_layout,
+            )
+        return kernel_size
+
+    def get_attrs(self, attrs, kernel_shape):
+        """Constructs the attributes for qnn.conv2d.
+
+        Parameters
+        ----------
+        attrs : dict
+            Attributes of the original nn.conv2d
+
+        kernel_shape : NDArray
+            Shape of the kernel
+
+        Returns
+        -------
+            quantized_attrs : dict
+                Attributes for the qnn.conv2d
+        """
+        new_attr_dict = {}
+        self.kernel_layout = attrs["kernel_layout"]
+        data_layout = attrs["data_layout"]
+
+        if self.kernel_layout == "OIHW":
+            self.weight_channel_axis = 0
+        elif self.kernel_layout == "HWIO":
+            self.weight_channel_axis = 3
+        else:
+            raise ValueError(
+                "Quantizing kernel layout %s for conv2d is not yet supported."
+                + "Please use OIHW or HWIO",
+                self.kernel_layout,
+            )
+
+        if data_layout == "NCHW":
+            self.data_channel_axis = 1
+        elif data_layout == "NHWC":
+            self.data_channel_axis = 3
+        else:
+            raise ValueError(
+                "Quantizing data layout %s for conv2d is not yet supported."
+                + "Please use NCHW or NHWC",
+                data_layout,
+            )
+
+        for attr in attrs.keys():
+            attr_value = attrs[attr]
+            if isinstance(attr_value, tvm.ir.container.Array):
+                attr_value = tuple(attr_value)
+            if attr == "kernel_size":
+                kernel_size = attrs[attr]
+                if kernel_size is None:
+                    kernel_size = self.get_kernel_size(self.kernel_layout, kernel_shape)
+                else:
+                    kernel_size = tuple([k.value for k in attrs[attr]])
+                new_attr_dict[attr] = kernel_size
+            elif attr == "channels":
+                self.channels = attrs[attr]
+                if self.channels is None:
+                    self.channels = kernel_shape[self.weight_channel_axis]
+                if isinstance(self.channels, tvm.tir.expr.IntImm):
+                    self.channels = self.channels.value
+                new_attr_dict[attr] = self.channels
+            elif attr == "padding":
+                # We don't need to put padding in attr dict because we explicitly construct padding
+                self.padding = attrs[attr]
+            else:
+                new_attr_dict[attr] = attr_value
+
+        new_attr_dict["out_dtype"] = "int32"
+        self.attrs = new_attr_dict
+
+    def quantize_args(self):
+        """Helper to quantize the arguments to the qnn.conv2d."""
+        quantized_data = relay.qnn.op.quantize(
+            self.args[0], self.scale_zps[0], self.scale_zps[1], axis=self.data_channel_axis
+        )
+        quantized_weight = relay.qnn.op.quantize(
+            self.args[1], self.scale_zps[2], self.scale_zps[3], axis=self.weight_channel_axis
+        )
+        self.quantized_args = [quantized_data, quantized_weight]
+
+    def create_conv(self, args, node_map):
+        """Creates the qnn.conv2d.
+
+        Parameters
+        ----------
+        args : List[relay.Expr]
+            Quantized arguments for the qnn.conv2d.
+
+        node_map : tvm.ir.container.Map
+            Node map from DFPatternCallback's callback
+
+        Returns
+        -------
+        q_conv2d : relay.Expr
+            Quantized version of the pattern.
+        """
+        return relay.qnn.op.conv2d(*args, **self.attrs)
+
+    def callback(self, pre, post, node_map):
+        self.args = [node_map[i][0] for i in self.inputs]
+        conv2d = node_map[self.conv2d][0]
+
+        self.out_dtype = conv2d.checked_type.dtype
+
+        self.get_attrs(conv2d.attrs, infer_type(self.args[1]).checked_type.shape)
+
+        self.create_scale_zps("conv2d_data", "conv2d_weight")
+        self.quantize_args()
+
+        conv_scale = self.scale_zps[0] * self.scale_zps[2]  # data_scale * weight_scale
+
+        # Conv zp is zero since QNN deals with input zps for us
+        conv_zp = relay.const(0, dtype="int32")
+        # args = [quantized_data, quantized_weight, data_zp, weight_zp, data_scale, weight_scale]
+        args = self.quantized_args[0:2] + [self.scale_zps[i] for i in [1, 3, 0, 2]]
+
+        if self.padding is not None:
+
+            top, left, bottom, right = [p.value for p in get_pad_tuple2d(self.padding)]
+            if self.kernel_layout == "OIHW":
+                pad_width = ((0, 0), (0, 0), (top, bottom), (left, right))
+            elif self.kernel_layout == "HWIO":
+                pad_width = (
+                    (top, bottom),
+                    (left, right),
+                    (0, 0),
+                    (0, 0),
+                )
+            pad_val = 0
+            args[0] = relay.op.nn.pad(args[0], pad_width, pad_val)
+
+        # Construct quantized qnn.conv2d and dequantize
+        qnn_call = self.create_conv(args, node_map)
+        dequantized_call = relay.qnn.op.dequantize(
+            qnn_call, conv_scale, conv_zp, out_dtype=self.out_dtype, axis=self.data_channel_axis
+        )
+
+        return dequantized_call
+
+
+class Conv2DBiasAddPattern(Conv2DPattern):
+    """Pattern to rewrite nn.conv2d -> nn.bias_add pattern as qnn.conv2d -> nn.bias_add.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.bias_weight = is_constant()
+        self.inputs.append(self.bias_weight)
+        self.add = is_op("add")(self.conv2d, self.bias_weight)
+        self.bias_add = is_op("nn.bias_add")(self.conv2d, self.bias_weight)
+        self.pattern = self.bias_add | self.add
+
+    def quantize_args(self):
+        """Quantizes the arguments to the nn.conv2d -> nn.bias_add pattern."""
+        super().quantize_args()
+        quantized_bias = relay.qnn.op.quantize(
+            self.args[2], self.scale_zps[0], self.scale_zps[1], axis=0, out_dtype="int32"
+        )
+        self.quantized_args.append(quantized_bias)
+
+    def create_conv(self, args, node_map):
+        """Creates the qnn.dense -> nn.bias_add.
+
+        Parameters
+        ----------
+        args : List[relay.Expr]
+            Quantized arguments for the qnn.conv2d and bias_add.
+
+        node_map : tvm.ir.container.Map
+            Node map from DFPatternCallback's callback
+
+        Returns
+        -------
+        q_conv2d : relay.Expr
+            Quantized version of the pattern.
+        """
+        qnn_call = relay.qnn.op.conv2d(*args, **self.attrs)
+        if node_map.get(self.add) is not None:
+            bias_add = relay.op.add(qnn_call, self.quantized_args[2])
+        else:  # self.bias_add in node_map
+            bias_add = relay.op.nn.bias_add(
+                qnn_call, self.quantized_args[2], axis=self.data_channel_axis
+            )
+        return bias_add
+
+
+class DensePattern(QuantizerPattern):
+    """Pattern to rewrite nn.dense pattern as qnn.dense.
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.data = wildcard()
+        self.weight = wildcard()
+        self.inputs = [self.data, self.weight]
+
+        self.dense = is_op("nn.dense")(self.data, self.weight)
+
+        self.pattern = self.dense
+        self.attrs = None
+        self.units = None
+
+    def get_attrs(self, attrs, weight_shape):
+        """Constructs the attributes for qnn.conv2d.
+
+        Parameters
+        ----------
+        attrs : dict
+            Attributes of the original nn.dense
+
+        weight_shape : NDArray
+            Shape of the dense weights
+
+        Returns
+        -------
+            quantized_attrs : dict
+                Attributes for the qnn.conv2d
+        """
+        self.attrs = {}
+        units = attrs["units"]
+        if units is None:
+            units = weight_shape[0]
+        self.units = units.value
+        self.attrs["units"] = self.units
+
+    def quantize_args(self):
+        """Quantizes the arguments to the nn.dense pattern."""
+        # Quantize data and construct args for qnn.dense
+        quantized_data = relay.qnn.op.quantize(self.args[0], self.scale_zps[0], self.scale_zps[1])
+        quantized_weight = relay.qnn.op.quantize(
+            self.args[1], self.scale_zps[2], self.scale_zps[3], axis=0
+        )  # Axis = 0 for per channel quantization
+        self.quantized_args = [quantized_data, quantized_weight]
+
+    def create_dense(self, args, node_map):
+        """Creates the qnn.dense.
+
+        Parameters
+        ----------
+        args : List[relay.Expr]
+            Quantized arguments for the qnn.dense.
+
+        node_map : tvm.ir.container.Map
+            Node map from DFPatternCallback's callback
+
+        Returns
+        -------
+        q_dense : relay.Expr
+            Quantized version of the pattern.
+        """
+        qnn_call = relay.qnn.op.dense(*args, **self.attrs)
+        return qnn_call
+
+    def callback(self, pre, post, node_map):
+        self.args = [node_map[i][0] for i in self.inputs]
+        weight = node_map[self.weight][0]
+
+        dense = node_map[self.dense][0]
+        out_dtype = dense.checked_type.dtype
+        self.get_attrs(dense.attrs, infer_type(weight).checked_type.shape)
+        self.create_scale_zps("dense_data", "dense_weight")
+        self.quantize_args()
+
+        # args = [quantized_data, quantized_weight, data_zp, weight_zp, data_scale, weight_scale]
+        args = self.quantized_args[0:2] + [self.scale_zps[i] for i in [1, 3, 0, 2]]
+        qnn_call = self.create_dense(args, node_map)
+
+        deq_call = relay.qnn.op.dequantize(
+            qnn_call,
+            self.scale_zps[0] * self.scale_zps[2],
+            relay.const(0, dtype="int32"),
+            out_dtype=out_dtype,
+            axis=1,
+        )
+
+        return deq_call
+
+
+class DenseBiasAddPattern(DensePattern):
+    """Pattern to rewrite nn.dense -> add and nn.dense -> nn.bias_add pattern as
+    qnn.dense -> nn.bias_add.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.bias_weight = is_constant()
+        self.inputs.append(self.bias_weight)
+        self.bias_add = is_op("nn.bias_add")(self.dense, self.bias_weight)
+        self.add = is_op("add")(self.dense, self.bias_weight)
+        self.pattern = self.bias_add | self.add
+
+    def quantize_args(self):
+        super().quantize_args()
+        quantized_bias = relay.qnn.op.quantize(
+            self.args[2], self.scale_zps[0], self.scale_zps[1], axis=0, out_dtype="int32"
+        )
+        self.quantized_args.append(quantized_bias)
+
+    def create_dense(self, args, node_map):
+        qnn_call = relay.qnn.op.dense(*args, **self.attrs)
+        if node_map.get(self.add) is not None:
+            bias_add = relay.op.add(qnn_call, self.quantized_args[2])
+        else:  # self.bias_add in node_map
+            bias_add = relay.op.nn.bias_add(
+                qnn_call, self.quantized_args[2], axis=1  # Axis is always 1 for dense
+            )
+        return bias_add
+
+
+class AddPattern(QuantizerPattern):
+    """Pattern to rewrite add as quantized.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.lhs = wildcard()
+        self.rhs = wildcard()
+        self.add = is_op("add")(self.lhs, self.rhs)
+        self.pattern = self.add
+
+    def callback(self, pre, post, node_map):
+        lhs = node_map[self.lhs][0]
+        rhs = node_map[self.rhs][0]
+
+        add = node_map[self.add][0]
+
+        out_dtype = infer_type(add).checked_type.dtype
+
+        # Create quantization parameters for arguments to this addition
+        self.create_scale_zps("add_lhs", "add_rhs")
+
+        # Quantize, dequantize, and requantize inputs to have scale lhs_scale + rhs_scale
+        # (Scale represents the lowest possible value representable in the quantized type,
+        # so the smallest representable output is lhs_scale + rhs_scale)
+
+        # We do this to avoid the requantize op in qnn's add, which causes issues with compilation
+        # Requantize will be inserted in a future pass
+        lhs_scale, lhs_zp, rhs_scale, rhs_zp = self.scale_zps
+        quantized_lhs = relay.qnn.op.quantize(lhs, lhs_scale, lhs_zp)
+        quantized_rhs = relay.qnn.op.quantize(rhs, rhs_scale, rhs_zp)
+
+        dequantized_lhs = relay.qnn.op.dequantize(
+            quantized_lhs, lhs_scale, relay.const(0, dtype="int32"), out_dtype=out_dtype
+        )
+        dequantized_rhs = relay.qnn.op.dequantize(
+            quantized_rhs, rhs_scale, relay.const(0, dtype="int32"), out_dtype=out_dtype
+        )

Review comment:
       I will add a comment here explaining what is going on. The main reason I do it this way is because the qnn.add op implicitly quantizes the args, and uses requantize inside it and generally does some weird stuff. 
   
   We want to quantize lhs and rhs to lhs_scale and rhs_scale, respectively, and then requantize to lhs_scale + rhs_scale before adding and dequantizing. I have simply broken the requantize into dequantize -> quantize. 

##########
File path: python/tvm/relay/transform/quantize/_calibrator.py
##########
@@ -0,0 +1,382 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""API for calibrating a quantized function."""
+import numpy as np
+
+import tvm
+from tvm import relay
+from tvm.contrib import graph_runtime
+import tvm.relay.build_module as build_module
+
+
+class QuantizationCalibrator:
+    """The QuantizationCalibrator picks scales and zero points for all qnn ops in the quantized
+    module.
+
+    Parameters
+    ----------
+    quantizer : Quantizer
+        Quantizer created with the mod we are calibrating.
+
+    target : String, optional
+        The target to run the quantized function on during calibration.
+
+    ctx : String, optional
+        The ctx used for running the quantized function on during calibration.
+
+    dataset_manager : DatasetManager, optional
+        The dataset manager containing data used to run the graph during
+        data-aware calibration.
+    """
+
+    def __init__(self, quantizer, target="llvm", ctx=tvm.cpu(), dataset_manager=None,
+                 show_scale_zps=False):
+        self.quantizer = quantizer
+
+        self.calibration_info = CalibrationInfo(
+            quantizer.tuple_subgraph_func,
+            quantizer.q_tuple_subgraph_func,
+            quantizer.partition_infos,
+            dataset_manager,
+            target,
+            ctx,
+        )
+
+        self.show_scale_zps = show_scale_zps
+
+    def calibrate(self):
+        """Picks the scales and zero points for all qnn ops in the quantized graph, using the
+        calibrate_pattern function from the quantizer.
+
+        Returns
+        -------
+        calibrated_func : relay.Function
+            The quantized function with the values for scales and zero points substituted into the
+            function.
+        """
+        # Create a map of DFPatternCallback to QuantizerPattern
+        pattern_map = {pattern.pattern: pattern for pattern in self.quantizer.patterns}
+
+        for partition_info in self.calibration_info.partition_infos:
+            # Set the partition info so we can access it from the callback
+            self.calibration_info.set_current_partition_info(partition_info)
+            quantizer_pattern = pattern_map[partition_info.pattern]
+
+            # Get the values for scales and ZPs in this layer, store
+            scale_zps = quantizer_pattern.calibrate_pattern(self.calibration_info)
+            if self.show_scale_zps:
+                self.report_scale_zps(scale_zps)
+            self.calibration_info.update_scale_zp_map(scale_zps)
+
+        calibrated_func = build_module.bind_params_by_name(
+            self.quantizer.q_tuple_subgraph_func, self.calibration_info.scale_zp_value_map
+        )
+
+        # If num_orig_outputs is -1, original output wasn't a tuple
+        params = calibrated_func.params
+        if self.quantizer.num_orig_outputs == -1:
+            calibrated_func = relay.Function(params, calibrated_func.body.fields[0])
+        else:
+            new_body = relay.Tuple(calibrated_func.body.fields[0 : self.quantizer.num_orig_outputs])
+            calibrated_func = relay.Function(params, new_body)
+
+        return calibrated_func
+
+    def report_scale_zps(self, scale_zp_map):
+        """Prints the scales and zero points out.
+
+        Parameters
+        ----------
+        scale_zp_map : dict of str to value
+            The map from names of scale and zero point variables to their assigned values.
+        """
+        for key, value in scale_zp_map.items():
+            print("Set ", key, " variable to ", value)
+
+
+class CalibrationInfo:
+    """Helper class that contains information necessary for picking scales and zero points into
+    calibrate_pattern. The state of CalibrationInfo is updated by QuantizationCalibrator.
+
+    Parameters
+    ----------
+    tuple_subgraph_func : relay.Function
+        A function whose output is a tuple that contains values we will need to access during
+        calibration.
+
+    q_tuple_subgraph_func : relay.Function
+        A quantized version of the tuple_subgraph_func. Note that to run this function, you
+        must pass in values for scales and zero points.
+
+    partition_infos : List[PatternCalibrationInfo]
+        A list of objects that correspond to every pattern matched during quantization. Each
+        contains scale and zero point variables, and indices into the the tuple functions.
+
+    dataset_manager : DatasetManager
+        The dataset manager containing data used to run the graph during data-aware calibration.
+
+    target : String
+        The target to run the quantized function on during calibration.
+
+    ctx : String
+        The ctx used for running the quantized function on during calibration.
+    """
+
+    def __init__(
+        self,
+        tuple_subgraph_func,
+        q_tuple_subgraph_func,
+        partition_infos,
+        dataset_manager,
+        target,
+        ctx,
+    ):
+        self.tuple_subgraph_func = tuple_subgraph_func
+        self.q_tuple_subgraph_func = q_tuple_subgraph_func
+        self.dataset_manager = dataset_manager
+        self.partition_infos = partition_infos
+        self.target = target
+        self.ctx = ctx
+
+        self.partition_info = None
+        self.input_scale_zps = None
+
+        tuple_subgraph_mod = tvm.ir.IRModule.from_expr(self.tuple_subgraph_func)
+        q_tuple_subgraph_mod = tvm.ir.IRModule.from_expr(self.q_tuple_subgraph_func)
+
+        self.tuple_subgraph_graphmodule = None
+        self.q_tuple_subgraph_graphmodule = None
+        self.init_subgraph_graphmodules(tuple_subgraph_mod, q_tuple_subgraph_mod)
+
+        self.scale_zp_value_map = {}
+        self.initialize_scale_zp_map()
+
+    def init_subgraph_graphmodules(self, tuple_subgraph_mod, q_tuple_subgraph_mod):
+        """Builds the tuple subgraphs so they can be run during calibration.
+
+        Parameters
+        ----------
+        tuple_subgraph_mod : tvm.ir.IRModule
+            Module wrapping tuple_subgraph_func.
+
+        q_tuple_subgraph_mod : tvm.ir.IRModule
+            Module wrapping q_tuple_subgraph_func.
+        """
+        # AlterOpLayout is disabled because it inserts some pads and other ops
+        with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+            tuple_subgraph_lib = relay.build(tuple_subgraph_mod, target=self.target)
+            q_tuple_subgraph_lib = relay.build(q_tuple_subgraph_mod, target=self.target)
+
+        ts_graph_mod = graph_runtime.GraphModule(tuple_subgraph_lib["default"](self.ctx))
+        q_ts_graph_mod = graph_runtime.GraphModule(q_tuple_subgraph_lib["default"](self.ctx))
+        self.tuple_subgraph_graphmodule = ts_graph_mod
+        self.q_tuple_subgraph_graphmodule = q_ts_graph_mod
+
+    def initialize_scale_zp_map(self):
+        """Initializes scales to 1 and zero points to zero. These values will only be used
+        to calculate values in the tuple subgraph that are not returned to the user."""

Review comment:
       These initial values aren't exposed -- Right now, you have to pass scale and zero point values into get_quantized_layer_inputs. 1 is just a placeholder for scales the parts of the graph that we haven't called a calibrator callback for yet. 
   However, in case someone forgets to set a scale or zero point, maybe I should make the default value better. I actually think that a lower default scale might be better since most ML activation values are very small (maybe something like 0.05?) In my experience that scale at least produces non-zero outputs that are the same or similar magnitude to the non-quantized graph. 
   Also as a quick note, @jwfromm (as well as you) suggested changing the API a little bit so that you don't have to pass in scale and zero point to get_quantized_layer_inputs, so I'm going to do that. It makes this problem less pressing and also lets people directly use the value of quantized data to calculate scales. 

##########
File path: python/tvm/relay/transform/quantize/_quantizer_patterns.py
##########
@@ -0,0 +1,712 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Patterns to quantize and how to quantize them."""
+
+import tvm
+from tvm import relay
+
+from tvm.relay.transform.quantize import CalibrationCallback
+from tvm.relay.dataflow_pattern import (
+    is_op,
+    wildcard,
+    is_constant,
+    DFPatternCallback,
+    _DFPatternCallback,
+)
+from tvm.relay.dataflow_pattern import ffi as pattern_ffi
+from tvm.relay.frontend.common import infer_type
+from tvm.relay.op.nn.utils import get_pad_tuple2d
+
+
+class QuantizerPattern(DFPatternCallback):
+    """DFPatternCallback to rewrite patterns as quantized. Also contains extra information
+    used for quantization and calibration.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate the nn.conv2d pattern.
+    """
+
+    # Counts the number of times we've added a scale and zp for variable naming
+    # This needs to be a global variable and not initialized in __init__ because
+    # each scale and zero point must be unique, even if they are created by different
+    # instances.
+    scales_count = 0
+    zp_count = 0
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__()
+        self.calibration_callback = calibration_callback
+
+    def calibrate_pattern(self, calibration_info):
+        """Calculates the scale and zero points for quantizing parts of a generic pattern. By
+        default, we call the calibrate_pattern method of the CalibrationCallback object that is
+        passed into QuantizerPattern during initialization. However, if you want a pattern specific
+        quantization method or a per-channel quantization method, you should overwrite the
+        QuantizerPattern's calibrate_pattern method.
+
+        Parameters
+        ----------
+        calibration_info : CalibrationInfo
+            The class containing relevant information and utility functions to calibrate one
+            instance of a pattern.
+
+        Returns
+        -------
+        scale_zp_map : Dictionary
+            A map from the names of scales and zero point variables in this pattern to their
+            values.
+        """
+        return self.calibration_callback.calibrate_pattern(calibration_info)
+
+    def callback(self, pre, post, node_map):
+        raise NotImplementedError
+
+    def scale(self, name):
+        """Helper to create the scale variable for qnn.quantize when rewriting our pattern.
+
+        Parameters
+        ----------
+        name : str
+            Identifier at the beginning of the scale variable.
+
+        is_weight : bool
+            Whether this scale is a weight scale or a data scale. If it is a weight scale, we
+            the returned variable has shape (channels,). Only used for per-channel quantization.
+
+        Returns
+        -------
+        var : relay.Var
+            Relay variable for scale. If the input name is 'conv2d_data', then the name of the
+            relay variable might be 'conv2d_data_scale_0'.
+        """
+
+        var = relay.var(
+            str(name) + "_scale_" + str(QuantizerPattern.scales_count), shape=(), dtype="float32"
+        )
+        QuantizerPattern.scales_count += 1
+        return var
+
+    def zero_point(self, name):
+        """Helper to create the zero point variable for qnn.quantize when rewriting our
+        our pattern.
+
+        Parameters
+        ----------
+        name : str
+            Identifier at the beginning of the variable.
+
+        Returns
+        -------
+        var : relay.Var
+            Relay variable for scale. If the input name is 'conv2d_data', then the name of the
+            relay variable might be 'conv2d_data_zero_pt_0'.
+        """
+        var = relay.var(
+            str(name) + "_zero_pt_" + str(QuantizerPattern.zp_count), shape=(), dtype="int32"
+        )
+        QuantizerPattern.zp_count += 1
+        return var
+
+    def create_scale_zps(self, left_name, right_name):
+        """Helper to create scales and zero points for binops.
+
+        Parameters
+        ----------
+        left_name : str
+            Identifier of the left hand side scale and zero point.
+
+        right_name : str
+            Identifier of the right hand side scale and zero point.
+        """
+        data_scale = self.scale(left_name)
+        data_zp = self.zero_point(left_name)
+        weight_scale = self.scale(right_name)
+        weight_zp = self.zero_point(right_name)
+        self.scale_zps = [data_scale, data_zp, weight_scale, weight_zp]
+
+
+class Conv2DPattern(QuantizerPattern):
+    """Pattern to rewrite nn.conv2d ops as qnn.conv2d ops.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.input = wildcard()
+        self.conv_weight = wildcard()
+        self.inputs = [self.input, self.conv_weight]
+        self.conv2d = is_op("nn.conv2d")(self.input, self.conv_weight)
+        self.pattern = self.conv2d
+        self.attrs = None
+        self.weight_channel_axis = None
+        self.data_channel_axis = None
+        self.channels = None
+
+    def get_kernel_size(self, kernel_shape, kernel_layout):
+        """Gets the size of the kernel.
+
+        Parameters
+        ----------
+        kernel_shape : NDArray
+            Shape of the kernel
+
+        kernel_layout : str
+            Layout of the kernel
+
+        Returns
+        -------
+            kernel_size : NDArray
+                Size of the kernel
+        """
+        if kernel_layout == "OIHW":
+            kernel_size = tuple(kernel_shape[2:4])
+        elif kernel_layout == "HWIO":
+            kernel_size = tuple(kernel_shape[0:2])
+        else:
+            raise ValueError(
+                "Quantizting kernel layout %s for conv2d is not yet supported."
+                + "Please use OIHW or HWIO",
+                kernel_layout,
+            )
+        return kernel_size
+
+    def get_attrs(self, attrs, kernel_shape):
+        """Constructs the attributes for qnn.conv2d.
+
+        Parameters
+        ----------
+        attrs : dict
+            Attributes of the original nn.conv2d
+
+        kernel_shape : NDArray
+            Shape of the kernel
+
+        Returns
+        -------
+            quantized_attrs : dict
+                Attributes for the qnn.conv2d
+        """
+        new_attr_dict = {}
+        self.kernel_layout = attrs["kernel_layout"]
+        data_layout = attrs["data_layout"]
+
+        if self.kernel_layout == "OIHW":
+            self.weight_channel_axis = 0
+        elif self.kernel_layout == "HWIO":
+            self.weight_channel_axis = 3
+        else:
+            raise ValueError(
+                "Quantizing kernel layout %s for conv2d is not yet supported."
+                + "Please use OIHW or HWIO",
+                self.kernel_layout,
+            )
+
+        if data_layout == "NCHW":
+            self.data_channel_axis = 1
+        elif data_layout == "NHWC":
+            self.data_channel_axis = 3
+        else:
+            raise ValueError(
+                "Quantizing data layout %s for conv2d is not yet supported."
+                + "Please use NCHW or NHWC",
+                data_layout,
+            )
+
+        for attr in attrs.keys():
+            attr_value = attrs[attr]
+            if isinstance(attr_value, tvm.ir.container.Array):
+                attr_value = tuple(attr_value)
+            if attr == "kernel_size":
+                kernel_size = attrs[attr]
+                if kernel_size is None:
+                    kernel_size = self.get_kernel_size(self.kernel_layout, kernel_shape)
+                else:
+                    kernel_size = tuple([k.value for k in attrs[attr]])
+                new_attr_dict[attr] = kernel_size
+            elif attr == "channels":
+                self.channels = attrs[attr]
+                if self.channels is None:
+                    self.channels = kernel_shape[self.weight_channel_axis]
+                if isinstance(self.channels, tvm.tir.expr.IntImm):
+                    self.channels = self.channels.value
+                new_attr_dict[attr] = self.channels
+            elif attr == "padding":
+                # We don't need to put padding in attr dict because we explicitly construct padding
+                self.padding = attrs[attr]
+            else:
+                new_attr_dict[attr] = attr_value
+
+        new_attr_dict["out_dtype"] = "int32"
+        self.attrs = new_attr_dict
+
+    def quantize_args(self):
+        """Helper to quantize the arguments to the qnn.conv2d."""
+        quantized_data = relay.qnn.op.quantize(
+            self.args[0], self.scale_zps[0], self.scale_zps[1], axis=self.data_channel_axis
+        )
+        quantized_weight = relay.qnn.op.quantize(
+            self.args[1], self.scale_zps[2], self.scale_zps[3], axis=self.weight_channel_axis
+        )
+        self.quantized_args = [quantized_data, quantized_weight]
+
+    def create_conv(self, args, node_map):
+        """Creates the qnn.conv2d.
+
+        Parameters
+        ----------
+        args : List[relay.Expr]
+            Quantized arguments for the qnn.conv2d.
+
+        node_map : tvm.ir.container.Map
+            Node map from DFPatternCallback's callback
+
+        Returns
+        -------
+        q_conv2d : relay.Expr
+            Quantized version of the pattern.
+        """
+        return relay.qnn.op.conv2d(*args, **self.attrs)
+
+    def callback(self, pre, post, node_map):
+        self.args = [node_map[i][0] for i in self.inputs]
+        conv2d = node_map[self.conv2d][0]
+
+        self.out_dtype = conv2d.checked_type.dtype
+
+        self.get_attrs(conv2d.attrs, infer_type(self.args[1]).checked_type.shape)
+
+        self.create_scale_zps("conv2d_data", "conv2d_weight")
+        self.quantize_args()
+
+        conv_scale = self.scale_zps[0] * self.scale_zps[2]  # data_scale * weight_scale
+
+        # Conv zp is zero since QNN deals with input zps for us
+        conv_zp = relay.const(0, dtype="int32")
+        # args = [quantized_data, quantized_weight, data_zp, weight_zp, data_scale, weight_scale]
+        args = self.quantized_args[0:2] + [self.scale_zps[i] for i in [1, 3, 0, 2]]
+
+        if self.padding is not None:
+
+            top, left, bottom, right = [p.value for p in get_pad_tuple2d(self.padding)]
+            if self.kernel_layout == "OIHW":
+                pad_width = ((0, 0), (0, 0), (top, bottom), (left, right))
+            elif self.kernel_layout == "HWIO":
+                pad_width = (
+                    (top, bottom),
+                    (left, right),
+                    (0, 0),
+                    (0, 0),
+                )
+            pad_val = 0
+            args[0] = relay.op.nn.pad(args[0], pad_width, pad_val)
+
+        # Construct quantized qnn.conv2d and dequantize
+        qnn_call = self.create_conv(args, node_map)
+        dequantized_call = relay.qnn.op.dequantize(
+            qnn_call, conv_scale, conv_zp, out_dtype=self.out_dtype, axis=self.data_channel_axis
+        )
+
+        return dequantized_call
+
+
+class Conv2DBiasAddPattern(Conv2DPattern):
+    """Pattern to rewrite nn.conv2d -> nn.bias_add pattern as qnn.conv2d -> nn.bias_add.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.bias_weight = is_constant()
+        self.inputs.append(self.bias_weight)
+        self.add = is_op("add")(self.conv2d, self.bias_weight)
+        self.bias_add = is_op("nn.bias_add")(self.conv2d, self.bias_weight)
+        self.pattern = self.bias_add | self.add
+
+    def quantize_args(self):
+        """Quantizes the arguments to the nn.conv2d -> nn.bias_add pattern."""
+        super().quantize_args()
+        quantized_bias = relay.qnn.op.quantize(
+            self.args[2], self.scale_zps[0], self.scale_zps[1], axis=0, out_dtype="int32"
+        )
+        self.quantized_args.append(quantized_bias)
+
+    def create_conv(self, args, node_map):
+        """Creates the qnn.dense -> nn.bias_add.
+
+        Parameters
+        ----------
+        args : List[relay.Expr]
+            Quantized arguments for the qnn.conv2d and bias_add.
+
+        node_map : tvm.ir.container.Map
+            Node map from DFPatternCallback's callback
+
+        Returns
+        -------
+        q_conv2d : relay.Expr
+            Quantized version of the pattern.
+        """
+        qnn_call = relay.qnn.op.conv2d(*args, **self.attrs)
+        if node_map.get(self.add) is not None:
+            bias_add = relay.op.add(qnn_call, self.quantized_args[2])
+        else:  # self.bias_add in node_map
+            bias_add = relay.op.nn.bias_add(
+                qnn_call, self.quantized_args[2], axis=self.data_channel_axis
+            )
+        return bias_add
+
+
+class DensePattern(QuantizerPattern):
+    """Pattern to rewrite nn.dense pattern as qnn.dense.
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.data = wildcard()
+        self.weight = wildcard()
+        self.inputs = [self.data, self.weight]
+
+        self.dense = is_op("nn.dense")(self.data, self.weight)
+
+        self.pattern = self.dense
+        self.attrs = None
+        self.units = None
+
+    def get_attrs(self, attrs, weight_shape):
+        """Constructs the attributes for qnn.conv2d.
+
+        Parameters
+        ----------
+        attrs : dict
+            Attributes of the original nn.dense
+
+        weight_shape : NDArray
+            Shape of the dense weights
+
+        Returns
+        -------
+            quantized_attrs : dict
+                Attributes for the qnn.conv2d
+        """
+        self.attrs = {}
+        units = attrs["units"]
+        if units is None:
+            units = weight_shape[0]
+        self.units = units.value
+        self.attrs["units"] = self.units
+
+    def quantize_args(self):
+        """Quantizes the arguments to the nn.dense pattern."""
+        # Quantize data and construct args for qnn.dense
+        quantized_data = relay.qnn.op.quantize(self.args[0], self.scale_zps[0], self.scale_zps[1])
+        quantized_weight = relay.qnn.op.quantize(
+            self.args[1], self.scale_zps[2], self.scale_zps[3], axis=0
+        )  # Axis = 0 for per channel quantization
+        self.quantized_args = [quantized_data, quantized_weight]
+
+    def create_dense(self, args, node_map):
+        """Creates the qnn.dense.
+
+        Parameters
+        ----------
+        args : List[relay.Expr]
+            Quantized arguments for the qnn.dense.
+
+        node_map : tvm.ir.container.Map
+            Node map from DFPatternCallback's callback
+
+        Returns
+        -------
+        q_dense : relay.Expr
+            Quantized version of the pattern.
+        """
+        qnn_call = relay.qnn.op.dense(*args, **self.attrs)
+        return qnn_call
+
+    def callback(self, pre, post, node_map):
+        self.args = [node_map[i][0] for i in self.inputs]
+        weight = node_map[self.weight][0]
+
+        dense = node_map[self.dense][0]
+        out_dtype = dense.checked_type.dtype
+        self.get_attrs(dense.attrs, infer_type(weight).checked_type.shape)
+        self.create_scale_zps("dense_data", "dense_weight")
+        self.quantize_args()
+
+        # args = [quantized_data, quantized_weight, data_zp, weight_zp, data_scale, weight_scale]
+        args = self.quantized_args[0:2] + [self.scale_zps[i] for i in [1, 3, 0, 2]]
+        qnn_call = self.create_dense(args, node_map)
+
+        deq_call = relay.qnn.op.dequantize(
+            qnn_call,
+            self.scale_zps[0] * self.scale_zps[2],
+            relay.const(0, dtype="int32"),
+            out_dtype=out_dtype,
+            axis=1,
+        )
+
+        return deq_call
+
+
+class DenseBiasAddPattern(DensePattern):
+    """Pattern to rewrite nn.dense -> add and nn.dense -> nn.bias_add pattern as
+    qnn.dense -> nn.bias_add.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.bias_weight = is_constant()
+        self.inputs.append(self.bias_weight)
+        self.bias_add = is_op("nn.bias_add")(self.dense, self.bias_weight)
+        self.add = is_op("add")(self.dense, self.bias_weight)
+        self.pattern = self.bias_add | self.add
+
+    def quantize_args(self):
+        super().quantize_args()
+        quantized_bias = relay.qnn.op.quantize(
+            self.args[2], self.scale_zps[0], self.scale_zps[1], axis=0, out_dtype="int32"
+        )
+        self.quantized_args.append(quantized_bias)
+
+    def create_dense(self, args, node_map):
+        qnn_call = relay.qnn.op.dense(*args, **self.attrs)
+        if node_map.get(self.add) is not None:
+            bias_add = relay.op.add(qnn_call, self.quantized_args[2])
+        else:  # self.bias_add in node_map
+            bias_add = relay.op.nn.bias_add(
+                qnn_call, self.quantized_args[2], axis=1  # Axis is always 1 for dense
+            )
+        return bias_add
+
+
+class AddPattern(QuantizerPattern):
+    """Pattern to rewrite add as quantized.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.lhs = wildcard()
+        self.rhs = wildcard()
+        self.add = is_op("add")(self.lhs, self.rhs)
+        self.pattern = self.add
+
+    def callback(self, pre, post, node_map):
+        lhs = node_map[self.lhs][0]
+        rhs = node_map[self.rhs][0]
+
+        add = node_map[self.add][0]
+
+        out_dtype = infer_type(add).checked_type.dtype
+
+        # Create quantization parameters for arguments to this addition
+        self.create_scale_zps("add_lhs", "add_rhs")
+
+        # Quantize, dequantize, and requantize inputs to have scale lhs_scale + rhs_scale
+        # (Scale represents the lowest possible value representable in the quantized type,
+        # so the smallest representable output is lhs_scale + rhs_scale)
+
+        # We do this to avoid the requantize op in qnn's add, which causes issues with compilation
+        # Requantize will be inserted in a future pass
+        lhs_scale, lhs_zp, rhs_scale, rhs_zp = self.scale_zps
+        quantized_lhs = relay.qnn.op.quantize(lhs, lhs_scale, lhs_zp)
+        quantized_rhs = relay.qnn.op.quantize(rhs, rhs_scale, rhs_zp)
+
+        dequantized_lhs = relay.qnn.op.dequantize(
+            quantized_lhs, lhs_scale, relay.const(0, dtype="int32"), out_dtype=out_dtype
+        )
+        dequantized_rhs = relay.qnn.op.dequantize(
+            quantized_rhs, rhs_scale, relay.const(0, dtype="int32"), out_dtype=out_dtype
+        )
+
+        add_scale = relay.op.add(lhs_scale, rhs_scale)
+
+        requantized_lhs = relay.qnn.op.quantize(
+            dequantized_lhs, add_scale, relay.const(0, dtype="int32")
+        )
+        requantized_rhs = relay.qnn.op.quantize(
+            dequantized_rhs, add_scale, relay.const(0, dtype="int32")
+        )
+
+        add = relay.op.add(requantized_lhs, requantized_rhs)
+        dequantized_call = relay.qnn.op.dequantize(
+            add, add_scale, relay.const(0, dtype="int32"), out_dtype=out_dtype
+        )
+
+        return dequantized_call
+
+
+class MultiplyPattern(QuantizerPattern):
+    """Pattern to rewrite multiply as quantized.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.lhs = wildcard()
+        self.rhs = wildcard()
+
+        self.multiply = is_op("multiply")(self.lhs, self.rhs)
+        self.pattern = self.multiply
+
+    def callback(self, pre, post, node_map):
+        lhs = node_map[self.lhs][0]
+        rhs = node_map[self.rhs][0]
+
+        multiply = node_map[self.multiply][0]
+
+        out_dtype = infer_type(multiply).checked_type.dtype
+
+        # Create quantization parameters for arguments to this multiplication.
+        self.create_scale_zps("mul_lhs", "mul_rhs")
+        lhs_scale, lhs_zp, rhs_scale, rhs_zp = self.scale_zps
+
+        # Quantize inputs and construct args for multiply
+        quantized_lhs = tvm.relay.cast(relay.qnn.op.quantize(lhs, lhs_scale, lhs_zp), "int32")
+        quantized_rhs = tvm.relay.cast(relay.qnn.op.quantize(rhs, rhs_scale, rhs_zp), "int32")
+
+        # Use normal relay multiply instead of qnn multiply to avoid requantize in qnn.mul
+        # Subtract zero points to center on zero so that we can multiply lhs, rhs directly
+        zeroed_quantized_lhs = relay.op.subtract(quantized_lhs, lhs_zp)
+        zeroed_quantized_rhs = relay.op.subtract(quantized_rhs, rhs_zp)
+
+        multiply = relay.op.multiply(zeroed_quantized_lhs, zeroed_quantized_rhs)
+        dequantized_call = relay.qnn.op.dequantize(
+            multiply, lhs_scale * rhs_scale, relay.const(0, dtype="int32"), out_dtype=out_dtype
+        )
+
+        return dequantized_call
+
+
+class PerChannelPattern:
+    """A parent class for patterns that will be per-channel quantized. PerChannelPattern should
+    only be inherited by a class that also inherits QuantizerPattern or a subclass of it.
+    """
+
+    def extract_attrs(self, pre, post, node_map):
+        """A callback to get the quantized attributes of this pattern. Usually, we just call
+        self.get_attrs on the attributes of the original, unquantized node to construct the
+        quantized attributes. Since this callback is used by the pattern rewriter, we must return
+        a relay.Expr from it.
+
+        Parameters
+        ----------
+        pre : relay.Expr
+            Expression before transformation
+
+        post : relay.Expr
+            Expression after transformation
+
+        node_map : Map of pattern to relay.Expr
+            Contains expressions matching parts of the pattern.
+
+        Returns
+        -------
+        post : relay.Expr
+            Expression to rewrite the input expression as. We don't actually want to rewrite
+            anything in this pass, so you should just return post.
+        """
+        raise NotImplementedError()
+
+    def get_scale_size(self):
+        """Returns the size of the per-channel scale variable
+
+        Returns
+        -------
+        scale_size : tuple
+            The size of the scale variable
+        """
+        raise NotImplementedError
+
+    def weight_scale(self, name):
+        """Helper to create a variable for a per-channel scale.
+        Parameters
+        ----------
+        name : str
+            Name of the variable
+        """
+        var = relay.var(
+            str(name) + "_scale_" + str(QuantizerPattern.scales_count),
+            shape=self.get_scale_size(),
+            dtype="float32",
+        )
+        QuantizerPattern.scales_count += 1
+        return var
+
+    def create_scale_zps(self, left_name, right_name):
+        """Helper to create scales and zero points for binops, with the per channel weight scale
+        quantized.
+
+        Parameters
+        ----------
+        left_name : str
+            Identifier of the left hand side scale and zero point.
+
+        right_name : str
+            Identifier of the right hand side scale and zero point.
+        """
+        # Create quantization parameters for arguments with per channel on the right
+        data_scale = self.scale(left_name)
+        data_zp = self.zero_point(left_name)
+
+        weight_scale = self.weight_scale(right_name)
+        weight_zp = self.zero_point(right_name)
+        self.scale_zps = [data_scale, data_zp, weight_scale, weight_zp]
+
+    def attr_callback(self, expr):
+        """A function to get the attributes of the quantized version of the current
+        pattern. Meant to be called from inside calibrate_pattern.
+
+        Parameters
+        ----------
+        expr : relay.Expr
+            Expression that we want the attributes from. This will be the unquantized
+            version of the expression.
+        """
+        pattern_ffi.rewrite(
+            [_DFPatternCallback(self.pattern, self.extract_attrs, self.require_type)],

Review comment:
       Yes, they are. This class is only meant to be used for multiple inheritance with a QuantizerPattern, which is why it looks a little weird. When I test the AverageMaxPerChannelConv2dPattern and AverageMaxPerChannelDensePattern, all these methods are tested. 
   I will make PerChannelPattern and AverageMaxPerChannelPattern abstact to make it clear that they should not be instantiated. 




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] anijain2305 commented on a change in pull request #7474: Quantization in TVM

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on a change in pull request #7474:
URL: https://github.com/apache/tvm/pull/7474#discussion_r580763624



##########
File path: include/tvm/relay/qnn/attrs.h
##########
@@ -78,13 +78,18 @@ struct QuantizeAttrs : public tvm::AttrsNode<QuantizeAttrs> {
 /*! \brief Attribute for dequantize operator */
 struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
   int axis;
+  DataType out_dtype;
 
   TVM_DECLARE_ATTRS(DequantizeAttrs, "relay.attrs.DequantizeAttrs") {
     TVM_ATTR_FIELD(axis)
         .describe(
             "The channel axis for channel wise dequantization. Default value is -1,"
             "which corresponds to the last axis.")
         .set_default(-1);
+    TVM_ATTR_FIELD(out_dtype)
+        .describe(
+            "The datatype we are dequantizing to (float32 or int32). Defaults to float32.")

Review comment:
       Introducing a new quantization related operator like - `simulated_quantize` might be better. This op could take any input dtype and any out dtype, and you can handle all the cases internally. You can use this op for calibration, and once you have figured out good scale and zero points, you can replace this with its QNN counterpart depending on the in and out dtypes.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on pull request #7474: Quantization in TVM

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #7474:
URL: https://github.com/apache/tvm/pull/7474#issuecomment-781925088


   @electriclilies Can you add an end to end runnable example? Like importing a pytorch or onnx graph and quantize it. 


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on a change in pull request #7474: Quantization in TVM

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #7474:
URL: https://github.com/apache/tvm/pull/7474#discussion_r580740763



##########
File path: src/relay/qnn/op/quantize.cc
##########
@@ -46,19 +47,22 @@ bool QuantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   }
 
   const auto input_dtype = data->dtype;
-  ICHECK(input_dtype == DataType::Float(32))
-      << "Input type should be one of float32 but was " << input_dtype;
+  ICHECK(input_dtype == DataType::Float(32) || input_dtype == DataType::Int(32))

Review comment:
       Sure but it is the output that needs to be int32. Here, you are allowing inputs to be int32. Am I missing something?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on a change in pull request #7474: Quantization in TVM

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #7474:
URL: https://github.com/apache/tvm/pull/7474#discussion_r578914505



##########
File path: python/tvm/relay/transform/quantize/_requantizer.py
##########
@@ -0,0 +1,312 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Removes extraneous qnn.quantize and qnn.dequantize from calibrated modules, and replaces them
+with qnn.requanize ops."""
+import math
+
+import tvm
+from tvm import relay
+from tvm.relay.dataflow_pattern import DFPatternCallback, wildcard, is_op, dominates, rewrite
+
+
+class Requantizer:
+    """Removes extraneous qnn.quantize and qnn.dequantize and replaces
+    them with qnn.requantize."""
+
+    class RequantizerCallback(DFPatternCallback):
+        """First pass that inserts requantize ops, specifically taking
+        qnn.dequantize -> qnn.quantize to qnn.requantize
+        and
+        qnn.dequantize -> int8_op* -> qnn.quantize to requantize -> int8_op*
+        """
+
+        def __init__(self):
+            super().__init__()
+
+            self.data = wildcard()
+            self.dequantize_scale = wildcard()
+            self.dequantize_zp = wildcard()
+
+            self.quantize_scale = wildcard()
+            self.quantize_zp = wildcard()
+
+            # Ops that are permitted inbetween quantize and dequantize if we are
+            # rewriting to requantize
+            self.is_int_8_op = (
+                is_op("nn.max_pool2d")(wildcard())
+                | is_op("nn.max_pool2d")(wildcard())
+                | is_op("nn.max_pool3d")(wildcard())
+                | is_op("nn.relu")(wildcard())
+                | is_op("transpose")(wildcard())
+                | is_op("reshape")(wildcard())
+                | is_op("nn.pad")(wildcard())
+                | is_op("squeeze")(wildcard())
+                | is_op("nn.global_avg_pool2d")
+                | is_op("nn.batch_flatten")
+                | is_op("copy")
+                | is_op("mean")
+                | is_op("sqrt")
+            )
+
+            # All ops in is_int_8_op must also be in self.op_map
+            self.op_map = {
+                relay.op.get("nn.max_pool2d"): relay.op.nn.max_pool2d,
+                relay.op.get("nn.max_pool3d"): relay.op.nn.max_pool3d,
+                relay.op.get("transpose"): relay.op.transpose,
+                relay.op.get("reshape"): relay.op.reshape,
+                relay.op.get("nn.pad"): relay.op.nn.pad,
+                relay.op.get("squeeze"): relay.op.squeeze,
+                relay.op.get("nn.global_avg_pool2d"): relay.op.nn.global_avg_pool2d,
+                relay.op.get("nn.batch_flatten"): relay.op.nn.batch_flatten,
+                relay.op.get("copy"): relay.op.copy,
+                relay.op.get("mean"): relay.op.mean,
+                relay.op.get("sqrt"): relay.op.sqrt,
+            }
+
+            # Main pattern -- quantize(is_int_8_op*(dequantize(data))) --
+            # (with 1 or more is_int_8_ops)
+            self.dequantize = is_op("qnn.dequantize")(
+                self.data, self.dequantize_scale, self.dequantize_zp
+            )
+
+            self.dominator = dominates(self.dequantize, self.is_int_8_op, self.is_int_8_op)
+            self.quantize = is_op("qnn.quantize")(
+                self.dominator, self.quantize_scale, self.quantize_zp
+            )
+
+            # Pattern with the null path : quantize(dequantize(data)) -- (no is_int_8_op inbetween)
+            # We have to do the null path outside the dominator pattern because of pattern matcher
+            # limitations
+            self.no_path_dequantize = is_op("qnn.dequantize")(
+                self.data, self.dequantize_scale, self.dequantize_zp
+            )
+            self.no_path_quantize = is_op("qnn.quantize")(
+                self.no_path_dequantize, self.quantize_scale, self.quantize_zp
+            )
+
+            self.pattern = self.quantize | self.no_path_quantize
+
+        def callback(self, pre, post, node_map):
+            # Extract data from the pattern
+            data = node_map[self.data][0]
+            dequantize_scale = node_map[self.dequantize_scale][0]
+            deq_zp = node_map[self.dequantize_zp][0]
+
+            quantize_scale = node_map[self.quantize_scale][0]
+            quantize_zp = node_map[self.quantize_zp][0]
+
+            # Case where there are no ops in between the dequantize and quantize
+            if self.no_path_quantize in node_map:
+                axis = node_map[self.no_path_dequantize][0].attrs.axis
+                res = relay.qnn.op.requantize(
+                    data, dequantize_scale, deq_zp, quantize_scale, quantize_zp, axis=axis
+                )
+            # Ops inbetween quantize and dequantize are dominated
+            elif self.quantize in node_map:
+
+                axis = node_map[self.dequantize][0].attrs.axis
+                transformed_data = relay.qnn.op.requantize(
+                    data, dequantize_scale, deq_zp, quantize_scale, quantize_zp, axis=axis
+                )
+                for i in range(len(node_map[self.is_int_8_op]) - 1, -1, -1):
+                    call = node_map[self.is_int_8_op][i]
+                    # Transform relu into max(zeropoint)
+                    if call.op == relay.op.get("nn.relu"):
+                        if (
+                            quantize_zp.data.asnumpy()
+                            == relay.const(0, dtype="int32").data.asnumpy()
+                        ):
+                            transformed_data = relay.op.nn.relu(transformed_data)
+                        else:
+                            transformed_data = relay.op.maximum(
+                                transformed_data, relay.cast(quantize_zp, "int8")
+                            )
+                    elif call.op in self.op_map.keys():
+                        transformed_data = self.op_map[call.op](transformed_data, **call.attrs)
+                    else:
+                        raise ValueError(
+                            "Uh oh, %s is not copied properly in the requantizer. " % str(call.op)

Review comment:
       Even though you are using dominator pattern, rewritten this way it only supports a linear path graph. It breaks for diamond structures, e.g if the output of dequantize is consumed by multiple nodes




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on a change in pull request #7474: Quantization in TVM

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #7474:
URL: https://github.com/apache/tvm/pull/7474#discussion_r580740763



##########
File path: src/relay/qnn/op/quantize.cc
##########
@@ -46,19 +47,22 @@ bool QuantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   }
 
   const auto input_dtype = data->dtype;
-  ICHECK(input_dtype == DataType::Float(32))
-      << "Input type should be one of float32 but was " << input_dtype;
+  ICHECK(input_dtype == DataType::Float(32) || input_dtype == DataType::Int(32))

Review comment:
       Sure but it is the output that needs to be int32. Here, you are allowing inputs to be int32. Bias is aways float. Am I missing something?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on a change in pull request #7474: Quantization in TVM

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #7474:
URL: https://github.com/apache/tvm/pull/7474#discussion_r578915463



##########
File path: src/relay/qnn/op/dequantize.cc
##########
@@ -105,6 +115,10 @@ Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale,
 
   auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), expanded_input_zero_point);
   auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), expanded_input_scale);
+
+  if (out_dtype != DataType::Float(32)) {
+    scaled_output = Cast(scaled_output, out_dtype);

Review comment:
       Casting the output of float multiply to int is wrong.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on a change in pull request #7474: Quantization in TVM

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #7474:
URL: https://github.com/apache/tvm/pull/7474#discussion_r579008341



##########
File path: python/tvm/relay/transform/quantize/_calibrator.py
##########
@@ -0,0 +1,382 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""API for calibrating a quantized function."""
+import numpy as np
+
+import tvm
+from tvm import relay
+from tvm.contrib import graph_runtime
+import tvm.relay.build_module as build_module
+
+
+class QuantizationCalibrator:
+    """The QuantizationCalibrator picks scales and zero points for all qnn ops in the quantized
+    module.
+
+    Parameters
+    ----------
+    quantizer : Quantizer
+        Quantizer created with the mod we are calibrating.
+
+    target : String, optional
+        The target to run the quantized function on during calibration.
+
+    ctx : String, optional
+        The ctx used for running the quantized function on during calibration.
+
+    dataset_manager : DatasetManager, optional
+        The dataset manager containing data used to run the graph during
+        data-aware calibration.
+    """
+
+    def __init__(self, quantizer, target="llvm", ctx=tvm.cpu(), dataset_manager=None,
+                 show_scale_zps=False):
+        self.quantizer = quantizer
+
+        self.calibration_info = CalibrationInfo(
+            quantizer.tuple_subgraph_func,
+            quantizer.q_tuple_subgraph_func,
+            quantizer.partition_infos,
+            dataset_manager,
+            target,
+            ctx,
+        )
+
+        self.show_scale_zps = show_scale_zps
+
+    def calibrate(self):
+        """Picks the scales and zero points for all qnn ops in the quantized graph, using the
+        calibrate_pattern function from the quantizer.
+
+        Returns
+        -------
+        calibrated_func : relay.Function
+            The quantized function with the values for scales and zero points substituted into the
+            function.
+        """
+        # Create a map of DFPatternCallback to QuantizerPattern
+        pattern_map = {pattern.pattern: pattern for pattern in self.quantizer.patterns}
+
+        for partition_info in self.calibration_info.partition_infos:
+            # Set the partition info so we can access it from the callback
+            self.calibration_info.set_current_partition_info(partition_info)
+            quantizer_pattern = pattern_map[partition_info.pattern]
+
+            # Get the values for scales and ZPs in this layer, store
+            scale_zps = quantizer_pattern.calibrate_pattern(self.calibration_info)
+            if self.show_scale_zps:
+                self.report_scale_zps(scale_zps)
+            self.calibration_info.update_scale_zp_map(scale_zps)
+
+        calibrated_func = build_module.bind_params_by_name(
+            self.quantizer.q_tuple_subgraph_func, self.calibration_info.scale_zp_value_map
+        )
+
+        # If num_orig_outputs is -1, original output wasn't a tuple
+        params = calibrated_func.params
+        if self.quantizer.num_orig_outputs == -1:
+            calibrated_func = relay.Function(params, calibrated_func.body.fields[0])
+        else:
+            new_body = relay.Tuple(calibrated_func.body.fields[0 : self.quantizer.num_orig_outputs])
+            calibrated_func = relay.Function(params, new_body)
+
+        return calibrated_func
+
+    def report_scale_zps(self, scale_zp_map):
+        """Prints the scales and zero points out.
+
+        Parameters
+        ----------
+        scale_zp_map : dict of str to value
+            The map from names of scale and zero point variables to their assigned values.
+        """
+        for key, value in scale_zp_map.items():
+            print("Set ", key, " variable to ", value)
+
+
+class CalibrationInfo:
+    """Helper class that contains information necessary for picking scales and zero points into
+    calibrate_pattern. The state of CalibrationInfo is updated by QuantizationCalibrator.
+
+    Parameters
+    ----------
+    tuple_subgraph_func : relay.Function
+        A function whose output is a tuple that contains values we will need to access during
+        calibration.
+
+    q_tuple_subgraph_func : relay.Function
+        A quantized version of the tuple_subgraph_func. Note that to run this function, you
+        must pass in values for scales and zero points.
+
+    partition_infos : List[PatternCalibrationInfo]
+        A list of objects that correspond to every pattern matched during quantization. Each
+        contains scale and zero point variables, and indices into the the tuple functions.
+
+    dataset_manager : DatasetManager
+        The dataset manager containing data used to run the graph during data-aware calibration.
+
+    target : String
+        The target to run the quantized function on during calibration.
+
+    ctx : String
+        The ctx used for running the quantized function on during calibration.
+    """
+
+    def __init__(
+        self,
+        tuple_subgraph_func,
+        q_tuple_subgraph_func,
+        partition_infos,
+        dataset_manager,
+        target,
+        ctx,
+    ):
+        self.tuple_subgraph_func = tuple_subgraph_func
+        self.q_tuple_subgraph_func = q_tuple_subgraph_func
+        self.dataset_manager = dataset_manager
+        self.partition_infos = partition_infos
+        self.target = target
+        self.ctx = ctx
+
+        self.partition_info = None
+        self.input_scale_zps = None
+
+        tuple_subgraph_mod = tvm.ir.IRModule.from_expr(self.tuple_subgraph_func)
+        q_tuple_subgraph_mod = tvm.ir.IRModule.from_expr(self.q_tuple_subgraph_func)
+
+        self.tuple_subgraph_graphmodule = None
+        self.q_tuple_subgraph_graphmodule = None
+        self.init_subgraph_graphmodules(tuple_subgraph_mod, q_tuple_subgraph_mod)
+
+        self.scale_zp_value_map = {}
+        self.initialize_scale_zp_map()
+
+    def init_subgraph_graphmodules(self, tuple_subgraph_mod, q_tuple_subgraph_mod):
+        """Builds the tuple subgraphs so they can be run during calibration.
+
+        Parameters
+        ----------
+        tuple_subgraph_mod : tvm.ir.IRModule
+            Module wrapping tuple_subgraph_func.
+
+        q_tuple_subgraph_mod : tvm.ir.IRModule
+            Module wrapping q_tuple_subgraph_func.
+        """
+        # AlterOpLayout is disabled because it inserts some pads and other ops
+        with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+            tuple_subgraph_lib = relay.build(tuple_subgraph_mod, target=self.target)
+            q_tuple_subgraph_lib = relay.build(q_tuple_subgraph_mod, target=self.target)
+
+        ts_graph_mod = graph_runtime.GraphModule(tuple_subgraph_lib["default"](self.ctx))
+        q_ts_graph_mod = graph_runtime.GraphModule(q_tuple_subgraph_lib["default"](self.ctx))
+        self.tuple_subgraph_graphmodule = ts_graph_mod
+        self.q_tuple_subgraph_graphmodule = q_ts_graph_mod
+
+    def initialize_scale_zp_map(self):
+        """Initializes scales to 1 and zero points to zero. These values will only be used
+        to calculate values in the tuple subgraph that are not returned to the user."""

Review comment:
       We should be careful with how we initialize the params here. A scale of 1 doesn't make sense, since it would essentially clamp the entire floating point range to [-128, 127]. So the outputs from the first run will likely be garbage.
   
   Does the choice of initialization affect the accuracy of the final model? If so, we should use more sensible values by default.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on a change in pull request #7474: Quantization in TVM

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #7474:
URL: https://github.com/apache/tvm/pull/7474#discussion_r579077073



##########
File path: python/tvm/relay/transform/quantize/_quantizer.py
##########
@@ -0,0 +1,155 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Quantizes functions by inserting qnn.quantize and qnn.dequantize ops."""
+from typing import List
+
+import tvm
+from tvm import relay
+from tvm.relay.dataflow_pattern import _DFPatternCallback
+from tvm.relay.transform.quantize import QuantizerPattern
+from tvm.relay.frontend.common import infer_type
+
+from . import _ffi as ffi
+
+
+class Quantizer:

Review comment:
       I think this is redundant, since all you do is to do some stuff in the constructor and immediately pass this object to `QuantizationCalibrator`. It is better to directly do the same initialization in the `QuantizationCalibrator` constructor.
   And probably I'd rename `QuantizationCalibrator` to `Quantizer`.

##########
File path: python/tvm/relay/transform/quantize/_quantizer.py
##########
@@ -0,0 +1,155 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Quantizes functions by inserting qnn.quantize and qnn.dequantize ops."""
+from typing import List
+
+import tvm
+from tvm import relay
+from tvm.relay.dataflow_pattern import _DFPatternCallback
+from tvm.relay.transform.quantize import QuantizerPattern
+from tvm.relay.frontend.common import infer_type
+
+from . import _ffi as ffi
+
+
+class Quantizer:

Review comment:
       I think this class is redundant, since all you do is to do some stuff in the constructor and immediately pass this object to `QuantizationCalibrator`. It is better to directly do the same initialization in the `QuantizationCalibrator` constructor.
   And probably I'd rename `QuantizationCalibrator` to `Quantizer`.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on a change in pull request #7474: Quantization in TVM

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #7474:
URL: https://github.com/apache/tvm/pull/7474#discussion_r578912891



##########
File path: python/tvm/relay/transform/quantize/_quantizer_patterns.py
##########
@@ -0,0 +1,712 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Patterns to quantize and how to quantize them."""
+
+import tvm
+from tvm import relay
+
+from tvm.relay.transform.quantize import CalibrationCallback
+from tvm.relay.dataflow_pattern import (
+    is_op,
+    wildcard,
+    is_constant,
+    DFPatternCallback,
+    _DFPatternCallback,
+)
+from tvm.relay.dataflow_pattern import ffi as pattern_ffi
+from tvm.relay.frontend.common import infer_type
+from tvm.relay.op.nn.utils import get_pad_tuple2d
+
+
+class QuantizerPattern(DFPatternCallback):
+    """DFPatternCallback to rewrite patterns as quantized. Also contains extra information
+    used for quantization and calibration.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate the nn.conv2d pattern.
+    """
+
+    # Counts the number of times we've added a scale and zp for variable naming
+    # This needs to be a global variable and not initialized in __init__ because
+    # each scale and zero point must be unique, even if they are created by different
+    # instances.
+    scales_count = 0
+    zp_count = 0
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__()
+        self.calibration_callback = calibration_callback
+
+    def calibrate_pattern(self, calibration_info):
+        """Calculates the scale and zero points for quantizing parts of a generic pattern. By
+        default, we call the calibrate_pattern method of the CalibrationCallback object that is
+        passed into QuantizerPattern during initialization. However, if you want a pattern specific
+        quantization method or a per-channel quantization method, you should overwrite the
+        QuantizerPattern's calibrate_pattern method.
+
+        Parameters
+        ----------
+        calibration_info : CalibrationInfo
+            The class containing relevant information and utility functions to calibrate one
+            instance of a pattern.
+
+        Returns
+        -------
+        scale_zp_map : Dictionary
+            A map from the names of scales and zero point variables in this pattern to their
+            values.
+        """
+        return self.calibration_callback.calibrate_pattern(calibration_info)
+
+    def callback(self, pre, post, node_map):
+        raise NotImplementedError
+
+    def scale(self, name):
+        """Helper to create the scale variable for qnn.quantize when rewriting our pattern.
+
+        Parameters
+        ----------
+        name : str
+            Identifier at the beginning of the scale variable.
+
+        is_weight : bool
+            Whether this scale is a weight scale or a data scale. If it is a weight scale, we
+            the returned variable has shape (channels,). Only used for per-channel quantization.
+
+        Returns
+        -------
+        var : relay.Var
+            Relay variable for scale. If the input name is 'conv2d_data', then the name of the
+            relay variable might be 'conv2d_data_scale_0'.
+        """
+
+        var = relay.var(
+            str(name) + "_scale_" + str(QuantizerPattern.scales_count), shape=(), dtype="float32"
+        )
+        QuantizerPattern.scales_count += 1
+        return var
+
+    def zero_point(self, name):
+        """Helper to create the zero point variable for qnn.quantize when rewriting our
+        our pattern.
+
+        Parameters
+        ----------
+        name : str
+            Identifier at the beginning of the variable.
+
+        Returns
+        -------
+        var : relay.Var
+            Relay variable for scale. If the input name is 'conv2d_data', then the name of the
+            relay variable might be 'conv2d_data_zero_pt_0'.
+        """
+        var = relay.var(
+            str(name) + "_zero_pt_" + str(QuantizerPattern.zp_count), shape=(), dtype="int32"
+        )
+        QuantizerPattern.zp_count += 1
+        return var
+
+    def create_scale_zps(self, left_name, right_name):
+        """Helper to create scales and zero points for binops.
+
+        Parameters
+        ----------
+        left_name : str
+            Identifier of the left hand side scale and zero point.
+
+        right_name : str
+            Identifier of the right hand side scale and zero point.
+        """
+        data_scale = self.scale(left_name)
+        data_zp = self.zero_point(left_name)
+        weight_scale = self.scale(right_name)
+        weight_zp = self.zero_point(right_name)
+        self.scale_zps = [data_scale, data_zp, weight_scale, weight_zp]
+
+
+class Conv2DPattern(QuantizerPattern):
+    """Pattern to rewrite nn.conv2d ops as qnn.conv2d ops.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.input = wildcard()
+        self.conv_weight = wildcard()
+        self.inputs = [self.input, self.conv_weight]
+        self.conv2d = is_op("nn.conv2d")(self.input, self.conv_weight)
+        self.pattern = self.conv2d
+        self.attrs = None
+        self.weight_channel_axis = None
+        self.data_channel_axis = None
+        self.channels = None
+
+    def get_kernel_size(self, kernel_shape, kernel_layout):
+        """Gets the size of the kernel.
+
+        Parameters
+        ----------
+        kernel_shape : NDArray
+            Shape of the kernel
+
+        kernel_layout : str
+            Layout of the kernel
+
+        Returns
+        -------
+            kernel_size : NDArray
+                Size of the kernel
+        """
+        if kernel_layout == "OIHW":
+            kernel_size = tuple(kernel_shape[2:4])
+        elif kernel_layout == "HWIO":
+            kernel_size = tuple(kernel_shape[0:2])
+        else:
+            raise ValueError(
+                "Quantizting kernel layout %s for conv2d is not yet supported."
+                + "Please use OIHW or HWIO",
+                kernel_layout,
+            )
+        return kernel_size
+
+    def get_attrs(self, attrs, kernel_shape):
+        """Constructs the attributes for qnn.conv2d.
+
+        Parameters
+        ----------
+        attrs : dict
+            Attributes of the original nn.conv2d
+
+        kernel_shape : NDArray
+            Shape of the kernel
+
+        Returns
+        -------
+            quantized_attrs : dict
+                Attributes for the qnn.conv2d
+        """
+        new_attr_dict = {}
+        self.kernel_layout = attrs["kernel_layout"]
+        data_layout = attrs["data_layout"]
+
+        if self.kernel_layout == "OIHW":
+            self.weight_channel_axis = 0
+        elif self.kernel_layout == "HWIO":
+            self.weight_channel_axis = 3
+        else:
+            raise ValueError(
+                "Quantizing kernel layout %s for conv2d is not yet supported."
+                + "Please use OIHW or HWIO",
+                self.kernel_layout,
+            )
+
+        if data_layout == "NCHW":
+            self.data_channel_axis = 1
+        elif data_layout == "NHWC":
+            self.data_channel_axis = 3
+        else:
+            raise ValueError(
+                "Quantizing data layout %s for conv2d is not yet supported."
+                + "Please use NCHW or NHWC",
+                data_layout,
+            )
+
+        for attr in attrs.keys():
+            attr_value = attrs[attr]
+            if isinstance(attr_value, tvm.ir.container.Array):
+                attr_value = tuple(attr_value)
+            if attr == "kernel_size":
+                kernel_size = attrs[attr]
+                if kernel_size is None:
+                    kernel_size = self.get_kernel_size(self.kernel_layout, kernel_shape)
+                else:
+                    kernel_size = tuple([k.value for k in attrs[attr]])
+                new_attr_dict[attr] = kernel_size
+            elif attr == "channels":
+                self.channels = attrs[attr]
+                if self.channels is None:
+                    self.channels = kernel_shape[self.weight_channel_axis]
+                if isinstance(self.channels, tvm.tir.expr.IntImm):
+                    self.channels = self.channels.value
+                new_attr_dict[attr] = self.channels
+            elif attr == "padding":
+                # We don't need to put padding in attr dict because we explicitly construct padding
+                self.padding = attrs[attr]
+            else:
+                new_attr_dict[attr] = attr_value
+
+        new_attr_dict["out_dtype"] = "int32"
+        self.attrs = new_attr_dict
+
+    def quantize_args(self):
+        """Helper to quantize the arguments to the qnn.conv2d."""
+        quantized_data = relay.qnn.op.quantize(
+            self.args[0], self.scale_zps[0], self.scale_zps[1], axis=self.data_channel_axis
+        )
+        quantized_weight = relay.qnn.op.quantize(
+            self.args[1], self.scale_zps[2], self.scale_zps[3], axis=self.weight_channel_axis
+        )
+        self.quantized_args = [quantized_data, quantized_weight]
+
+    def create_conv(self, args, node_map):
+        """Creates the qnn.conv2d.
+
+        Parameters
+        ----------
+        args : List[relay.Expr]
+            Quantized arguments for the qnn.conv2d.
+
+        node_map : tvm.ir.container.Map
+            Node map from DFPatternCallback's callback
+
+        Returns
+        -------
+        q_conv2d : relay.Expr
+            Quantized version of the pattern.
+        """
+        return relay.qnn.op.conv2d(*args, **self.attrs)
+
+    def callback(self, pre, post, node_map):
+        self.args = [node_map[i][0] for i in self.inputs]
+        conv2d = node_map[self.conv2d][0]
+
+        self.out_dtype = conv2d.checked_type.dtype
+
+        self.get_attrs(conv2d.attrs, infer_type(self.args[1]).checked_type.shape)
+
+        self.create_scale_zps("conv2d_data", "conv2d_weight")
+        self.quantize_args()
+
+        conv_scale = self.scale_zps[0] * self.scale_zps[2]  # data_scale * weight_scale
+
+        # Conv zp is zero since QNN deals with input zps for us
+        conv_zp = relay.const(0, dtype="int32")
+        # args = [quantized_data, quantized_weight, data_zp, weight_zp, data_scale, weight_scale]
+        args = self.quantized_args[0:2] + [self.scale_zps[i] for i in [1, 3, 0, 2]]
+
+        if self.padding is not None:
+
+            top, left, bottom, right = [p.value for p in get_pad_tuple2d(self.padding)]
+            if self.kernel_layout == "OIHW":
+                pad_width = ((0, 0), (0, 0), (top, bottom), (left, right))
+            elif self.kernel_layout == "HWIO":
+                pad_width = (
+                    (top, bottom),
+                    (left, right),
+                    (0, 0),
+                    (0, 0),
+                )
+            pad_val = 0
+            args[0] = relay.op.nn.pad(args[0], pad_width, pad_val)
+
+        # Construct quantized qnn.conv2d and dequantize
+        qnn_call = self.create_conv(args, node_map)
+        dequantized_call = relay.qnn.op.dequantize(
+            qnn_call, conv_scale, conv_zp, out_dtype=self.out_dtype, axis=self.data_channel_axis
+        )
+
+        return dequantized_call
+
+
+class Conv2DBiasAddPattern(Conv2DPattern):
+    """Pattern to rewrite nn.conv2d -> nn.bias_add pattern as qnn.conv2d -> nn.bias_add.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.bias_weight = is_constant()
+        self.inputs.append(self.bias_weight)
+        self.add = is_op("add")(self.conv2d, self.bias_weight)
+        self.bias_add = is_op("nn.bias_add")(self.conv2d, self.bias_weight)
+        self.pattern = self.bias_add | self.add
+
+    def quantize_args(self):
+        """Quantizes the arguments to the nn.conv2d -> nn.bias_add pattern."""
+        super().quantize_args()
+        quantized_bias = relay.qnn.op.quantize(
+            self.args[2], self.scale_zps[0], self.scale_zps[1], axis=0, out_dtype="int32"
+        )
+        self.quantized_args.append(quantized_bias)
+
+    def create_conv(self, args, node_map):
+        """Creates the qnn.dense -> nn.bias_add.
+
+        Parameters
+        ----------
+        args : List[relay.Expr]
+            Quantized arguments for the qnn.conv2d and bias_add.
+
+        node_map : tvm.ir.container.Map
+            Node map from DFPatternCallback's callback
+
+        Returns
+        -------
+        q_conv2d : relay.Expr
+            Quantized version of the pattern.
+        """
+        qnn_call = relay.qnn.op.conv2d(*args, **self.attrs)
+        if node_map.get(self.add) is not None:
+            bias_add = relay.op.add(qnn_call, self.quantized_args[2])
+        else:  # self.bias_add in node_map
+            bias_add = relay.op.nn.bias_add(
+                qnn_call, self.quantized_args[2], axis=self.data_channel_axis
+            )
+        return bias_add
+
+
+class DensePattern(QuantizerPattern):
+    """Pattern to rewrite nn.dense pattern as qnn.dense.
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.data = wildcard()
+        self.weight = wildcard()
+        self.inputs = [self.data, self.weight]
+
+        self.dense = is_op("nn.dense")(self.data, self.weight)
+
+        self.pattern = self.dense
+        self.attrs = None
+        self.units = None
+
+    def get_attrs(self, attrs, weight_shape):
+        """Constructs the attributes for qnn.conv2d.
+
+        Parameters
+        ----------
+        attrs : dict
+            Attributes of the original nn.dense
+
+        weight_shape : NDArray
+            Shape of the dense weights
+
+        Returns
+        -------
+            quantized_attrs : dict
+                Attributes for the qnn.conv2d
+        """
+        self.attrs = {}
+        units = attrs["units"]
+        if units is None:
+            units = weight_shape[0]
+        self.units = units.value
+        self.attrs["units"] = self.units
+
+    def quantize_args(self):
+        """Quantizes the arguments to the nn.dense pattern."""
+        # Quantize data and construct args for qnn.dense
+        quantized_data = relay.qnn.op.quantize(self.args[0], self.scale_zps[0], self.scale_zps[1])
+        quantized_weight = relay.qnn.op.quantize(
+            self.args[1], self.scale_zps[2], self.scale_zps[3], axis=0
+        )  # Axis = 0 for per channel quantization
+        self.quantized_args = [quantized_data, quantized_weight]
+
+    def create_dense(self, args, node_map):
+        """Creates the qnn.dense.
+
+        Parameters
+        ----------
+        args : List[relay.Expr]
+            Quantized arguments for the qnn.dense.
+
+        node_map : tvm.ir.container.Map
+            Node map from DFPatternCallback's callback
+
+        Returns
+        -------
+        q_dense : relay.Expr
+            Quantized version of the pattern.
+        """
+        qnn_call = relay.qnn.op.dense(*args, **self.attrs)
+        return qnn_call
+
+    def callback(self, pre, post, node_map):
+        self.args = [node_map[i][0] for i in self.inputs]
+        weight = node_map[self.weight][0]
+
+        dense = node_map[self.dense][0]
+        out_dtype = dense.checked_type.dtype
+        self.get_attrs(dense.attrs, infer_type(weight).checked_type.shape)
+        self.create_scale_zps("dense_data", "dense_weight")
+        self.quantize_args()
+
+        # args = [quantized_data, quantized_weight, data_zp, weight_zp, data_scale, weight_scale]
+        args = self.quantized_args[0:2] + [self.scale_zps[i] for i in [1, 3, 0, 2]]
+        qnn_call = self.create_dense(args, node_map)
+
+        deq_call = relay.qnn.op.dequantize(
+            qnn_call,
+            self.scale_zps[0] * self.scale_zps[2],
+            relay.const(0, dtype="int32"),
+            out_dtype=out_dtype,
+            axis=1,
+        )
+
+        return deq_call
+
+
+class DenseBiasAddPattern(DensePattern):
+    """Pattern to rewrite nn.dense -> add and nn.dense -> nn.bias_add pattern as
+    qnn.dense -> nn.bias_add.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.bias_weight = is_constant()
+        self.inputs.append(self.bias_weight)
+        self.bias_add = is_op("nn.bias_add")(self.dense, self.bias_weight)
+        self.add = is_op("add")(self.dense, self.bias_weight)
+        self.pattern = self.bias_add | self.add
+
+    def quantize_args(self):
+        super().quantize_args()
+        quantized_bias = relay.qnn.op.quantize(
+            self.args[2], self.scale_zps[0], self.scale_zps[1], axis=0, out_dtype="int32"
+        )
+        self.quantized_args.append(quantized_bias)
+
+    def create_dense(self, args, node_map):
+        qnn_call = relay.qnn.op.dense(*args, **self.attrs)
+        if node_map.get(self.add) is not None:
+            bias_add = relay.op.add(qnn_call, self.quantized_args[2])
+        else:  # self.bias_add in node_map
+            bias_add = relay.op.nn.bias_add(
+                qnn_call, self.quantized_args[2], axis=1  # Axis is always 1 for dense
+            )
+        return bias_add
+
+
+class AddPattern(QuantizerPattern):
+    """Pattern to rewrite add as quantized.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.lhs = wildcard()
+        self.rhs = wildcard()
+        self.add = is_op("add")(self.lhs, self.rhs)
+        self.pattern = self.add
+
+    def callback(self, pre, post, node_map):
+        lhs = node_map[self.lhs][0]
+        rhs = node_map[self.rhs][0]
+
+        add = node_map[self.add][0]
+
+        out_dtype = infer_type(add).checked_type.dtype
+
+        # Create quantization parameters for arguments to this addition
+        self.create_scale_zps("add_lhs", "add_rhs")
+
+        # Quantize, dequantize, and requantize inputs to have scale lhs_scale + rhs_scale
+        # (Scale represents the lowest possible value representable in the quantized type,
+        # so the smallest representable output is lhs_scale + rhs_scale)
+
+        # We do this to avoid the requantize op in qnn's add, which causes issues with compilation
+        # Requantize will be inserted in a future pass
+        lhs_scale, lhs_zp, rhs_scale, rhs_zp = self.scale_zps
+        quantized_lhs = relay.qnn.op.quantize(lhs, lhs_scale, lhs_zp)
+        quantized_rhs = relay.qnn.op.quantize(rhs, rhs_scale, rhs_zp)
+
+        dequantized_lhs = relay.qnn.op.dequantize(
+            quantized_lhs, lhs_scale, relay.const(0, dtype="int32"), out_dtype=out_dtype
+        )
+        dequantized_rhs = relay.qnn.op.dequantize(
+            quantized_rhs, rhs_scale, relay.const(0, dtype="int32"), out_dtype=out_dtype
+        )

Review comment:
       I have no idea what's going on here




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on a change in pull request #7474: Quantization in TVM

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #7474:
URL: https://github.com/apache/tvm/pull/7474#discussion_r578911395



##########
File path: python/tvm/relay/transform/quantize/_quantizer_patterns.py
##########
@@ -0,0 +1,712 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Patterns to quantize and how to quantize them."""
+
+import tvm
+from tvm import relay
+
+from tvm.relay.transform.quantize import CalibrationCallback
+from tvm.relay.dataflow_pattern import (
+    is_op,
+    wildcard,
+    is_constant,
+    DFPatternCallback,
+    _DFPatternCallback,
+)
+from tvm.relay.dataflow_pattern import ffi as pattern_ffi
+from tvm.relay.frontend.common import infer_type
+from tvm.relay.op.nn.utils import get_pad_tuple2d
+
+
+class QuantizerPattern(DFPatternCallback):
+    """DFPatternCallback to rewrite patterns as quantized. Also contains extra information
+    used for quantization and calibration.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate the nn.conv2d pattern.
+    """
+
+    # Counts the number of times we've added a scale and zp for variable naming
+    # This needs to be a global variable and not initialized in __init__ because
+    # each scale and zero point must be unique, even if they are created by different
+    # instances.
+    scales_count = 0
+    zp_count = 0
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__()
+        self.calibration_callback = calibration_callback
+
+    def calibrate_pattern(self, calibration_info):
+        """Calculates the scale and zero points for quantizing parts of a generic pattern. By
+        default, we call the calibrate_pattern method of the CalibrationCallback object that is
+        passed into QuantizerPattern during initialization. However, if you want a pattern specific
+        quantization method or a per-channel quantization method, you should overwrite the
+        QuantizerPattern's calibrate_pattern method.
+
+        Parameters
+        ----------
+        calibration_info : CalibrationInfo
+            The class containing relevant information and utility functions to calibrate one
+            instance of a pattern.
+
+        Returns
+        -------
+        scale_zp_map : Dictionary
+            A map from the names of scales and zero point variables in this pattern to their
+            values.
+        """
+        return self.calibration_callback.calibrate_pattern(calibration_info)
+
+    def callback(self, pre, post, node_map):
+        raise NotImplementedError
+
+    def scale(self, name):
+        """Helper to create the scale variable for qnn.quantize when rewriting our pattern.
+
+        Parameters
+        ----------
+        name : str
+            Identifier at the beginning of the scale variable.
+
+        is_weight : bool
+            Whether this scale is a weight scale or a data scale. If it is a weight scale, we
+            the returned variable has shape (channels,). Only used for per-channel quantization.
+
+        Returns
+        -------
+        var : relay.Var
+            Relay variable for scale. If the input name is 'conv2d_data', then the name of the
+            relay variable might be 'conv2d_data_scale_0'.
+        """
+
+        var = relay.var(
+            str(name) + "_scale_" + str(QuantizerPattern.scales_count), shape=(), dtype="float32"
+        )
+        QuantizerPattern.scales_count += 1
+        return var
+
+    def zero_point(self, name):
+        """Helper to create the zero point variable for qnn.quantize when rewriting our
+        our pattern.
+
+        Parameters
+        ----------
+        name : str
+            Identifier at the beginning of the variable.
+
+        Returns
+        -------
+        var : relay.Var
+            Relay variable for scale. If the input name is 'conv2d_data', then the name of the
+            relay variable might be 'conv2d_data_zero_pt_0'.
+        """
+        var = relay.var(
+            str(name) + "_zero_pt_" + str(QuantizerPattern.zp_count), shape=(), dtype="int32"
+        )
+        QuantizerPattern.zp_count += 1
+        return var
+
+    def create_scale_zps(self, left_name, right_name):
+        """Helper to create scales and zero points for binops.
+
+        Parameters
+        ----------
+        left_name : str
+            Identifier of the left hand side scale and zero point.
+
+        right_name : str
+            Identifier of the right hand side scale and zero point.
+        """
+        data_scale = self.scale(left_name)
+        data_zp = self.zero_point(left_name)
+        weight_scale = self.scale(right_name)
+        weight_zp = self.zero_point(right_name)
+        self.scale_zps = [data_scale, data_zp, weight_scale, weight_zp]
+
+
+class Conv2DPattern(QuantizerPattern):
+    """Pattern to rewrite nn.conv2d ops as qnn.conv2d ops.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.input = wildcard()
+        self.conv_weight = wildcard()
+        self.inputs = [self.input, self.conv_weight]
+        self.conv2d = is_op("nn.conv2d")(self.input, self.conv_weight)
+        self.pattern = self.conv2d
+        self.attrs = None
+        self.weight_channel_axis = None
+        self.data_channel_axis = None
+        self.channels = None
+
+    def get_kernel_size(self, kernel_shape, kernel_layout):
+        """Gets the size of the kernel.
+
+        Parameters
+        ----------
+        kernel_shape : NDArray
+            Shape of the kernel
+
+        kernel_layout : str
+            Layout of the kernel
+
+        Returns
+        -------
+            kernel_size : NDArray
+                Size of the kernel
+        """
+        if kernel_layout == "OIHW":
+            kernel_size = tuple(kernel_shape[2:4])
+        elif kernel_layout == "HWIO":
+            kernel_size = tuple(kernel_shape[0:2])
+        else:
+            raise ValueError(
+                "Quantizting kernel layout %s for conv2d is not yet supported."
+                + "Please use OIHW or HWIO",
+                kernel_layout,
+            )
+        return kernel_size
+
+    def get_attrs(self, attrs, kernel_shape):
+        """Constructs the attributes for qnn.conv2d.
+
+        Parameters
+        ----------
+        attrs : dict
+            Attributes of the original nn.conv2d
+
+        kernel_shape : NDArray
+            Shape of the kernel
+
+        Returns
+        -------
+            quantized_attrs : dict
+                Attributes for the qnn.conv2d
+        """
+        new_attr_dict = {}
+        self.kernel_layout = attrs["kernel_layout"]
+        data_layout = attrs["data_layout"]
+
+        if self.kernel_layout == "OIHW":
+            self.weight_channel_axis = 0
+        elif self.kernel_layout == "HWIO":
+            self.weight_channel_axis = 3
+        else:
+            raise ValueError(
+                "Quantizing kernel layout %s for conv2d is not yet supported."
+                + "Please use OIHW or HWIO",
+                self.kernel_layout,
+            )
+
+        if data_layout == "NCHW":
+            self.data_channel_axis = 1
+        elif data_layout == "NHWC":
+            self.data_channel_axis = 3
+        else:
+            raise ValueError(
+                "Quantizing data layout %s for conv2d is not yet supported."
+                + "Please use NCHW or NHWC",
+                data_layout,
+            )
+
+        for attr in attrs.keys():
+            attr_value = attrs[attr]
+            if isinstance(attr_value, tvm.ir.container.Array):
+                attr_value = tuple(attr_value)
+            if attr == "kernel_size":
+                kernel_size = attrs[attr]
+                if kernel_size is None:
+                    kernel_size = self.get_kernel_size(self.kernel_layout, kernel_shape)
+                else:
+                    kernel_size = tuple([k.value for k in attrs[attr]])
+                new_attr_dict[attr] = kernel_size
+            elif attr == "channels":
+                self.channels = attrs[attr]
+                if self.channels is None:
+                    self.channels = kernel_shape[self.weight_channel_axis]
+                if isinstance(self.channels, tvm.tir.expr.IntImm):
+                    self.channels = self.channels.value
+                new_attr_dict[attr] = self.channels
+            elif attr == "padding":
+                # We don't need to put padding in attr dict because we explicitly construct padding
+                self.padding = attrs[attr]
+            else:
+                new_attr_dict[attr] = attr_value
+
+        new_attr_dict["out_dtype"] = "int32"
+        self.attrs = new_attr_dict
+
+    def quantize_args(self):
+        """Helper to quantize the arguments to the qnn.conv2d."""
+        quantized_data = relay.qnn.op.quantize(
+            self.args[0], self.scale_zps[0], self.scale_zps[1], axis=self.data_channel_axis
+        )
+        quantized_weight = relay.qnn.op.quantize(
+            self.args[1], self.scale_zps[2], self.scale_zps[3], axis=self.weight_channel_axis
+        )
+        self.quantized_args = [quantized_data, quantized_weight]
+
+    def create_conv(self, args, node_map):
+        """Creates the qnn.conv2d.
+
+        Parameters
+        ----------
+        args : List[relay.Expr]
+            Quantized arguments for the qnn.conv2d.
+
+        node_map : tvm.ir.container.Map
+            Node map from DFPatternCallback's callback
+
+        Returns
+        -------
+        q_conv2d : relay.Expr
+            Quantized version of the pattern.
+        """
+        return relay.qnn.op.conv2d(*args, **self.attrs)
+
+    def callback(self, pre, post, node_map):
+        self.args = [node_map[i][0] for i in self.inputs]
+        conv2d = node_map[self.conv2d][0]
+
+        self.out_dtype = conv2d.checked_type.dtype
+
+        self.get_attrs(conv2d.attrs, infer_type(self.args[1]).checked_type.shape)
+
+        self.create_scale_zps("conv2d_data", "conv2d_weight")
+        self.quantize_args()
+
+        conv_scale = self.scale_zps[0] * self.scale_zps[2]  # data_scale * weight_scale
+
+        # Conv zp is zero since QNN deals with input zps for us
+        conv_zp = relay.const(0, dtype="int32")
+        # args = [quantized_data, quantized_weight, data_zp, weight_zp, data_scale, weight_scale]
+        args = self.quantized_args[0:2] + [self.scale_zps[i] for i in [1, 3, 0, 2]]
+
+        if self.padding is not None:
+
+            top, left, bottom, right = [p.value for p in get_pad_tuple2d(self.padding)]
+            if self.kernel_layout == "OIHW":
+                pad_width = ((0, 0), (0, 0), (top, bottom), (left, right))
+            elif self.kernel_layout == "HWIO":
+                pad_width = (
+                    (top, bottom),
+                    (left, right),
+                    (0, 0),
+                    (0, 0),
+                )
+            pad_val = 0

Review comment:
       again, this is assuming that zero point is always 0, so you are only supporting symmetric Q




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] electriclilies closed pull request #7474: [WIP] [Quantization] Quantization in TVM

Posted by GitBox <gi...@apache.org>.
electriclilies closed pull request #7474:
URL: https://github.com/apache/tvm/pull/7474


   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on a change in pull request #7474: Quantization in TVM

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #7474:
URL: https://github.com/apache/tvm/pull/7474#discussion_r579072951



##########
File path: python/tvm/relay/transform/quantize/_quantizer_patterns.py
##########
@@ -0,0 +1,712 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Patterns to quantize and how to quantize them."""
+
+import tvm
+from tvm import relay
+
+from tvm.relay.transform.quantize import CalibrationCallback
+from tvm.relay.dataflow_pattern import (
+    is_op,
+    wildcard,
+    is_constant,
+    DFPatternCallback,
+    _DFPatternCallback,
+)
+from tvm.relay.dataflow_pattern import ffi as pattern_ffi
+from tvm.relay.frontend.common import infer_type
+from tvm.relay.op.nn.utils import get_pad_tuple2d
+
+
+class QuantizerPattern(DFPatternCallback):
+    """DFPatternCallback to rewrite patterns as quantized. Also contains extra information
+    used for quantization and calibration.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate the nn.conv2d pattern.
+    """
+
+    # Counts the number of times we've added a scale and zp for variable naming
+    # This needs to be a global variable and not initialized in __init__ because
+    # each scale and zero point must be unique, even if they are created by different
+    # instances.
+    scales_count = 0
+    zp_count = 0
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__()
+        self.calibration_callback = calibration_callback
+
+    def calibrate_pattern(self, calibration_info):
+        """Calculates the scale and zero points for quantizing parts of a generic pattern. By
+        default, we call the calibrate_pattern method of the CalibrationCallback object that is
+        passed into QuantizerPattern during initialization. However, if you want a pattern specific
+        quantization method or a per-channel quantization method, you should overwrite the
+        QuantizerPattern's calibrate_pattern method.
+
+        Parameters
+        ----------
+        calibration_info : CalibrationInfo
+            The class containing relevant information and utility functions to calibrate one
+            instance of a pattern.
+
+        Returns
+        -------
+        scale_zp_map : Dictionary
+            A map from the names of scales and zero point variables in this pattern to their
+            values.
+        """
+        return self.calibration_callback.calibrate_pattern(calibration_info)
+
+    def callback(self, pre, post, node_map):
+        raise NotImplementedError
+
+    def scale(self, name):
+        """Helper to create the scale variable for qnn.quantize when rewriting our pattern.
+
+        Parameters
+        ----------
+        name : str
+            Identifier at the beginning of the scale variable.
+
+        is_weight : bool
+            Whether this scale is a weight scale or a data scale. If it is a weight scale, we
+            the returned variable has shape (channels,). Only used for per-channel quantization.
+
+        Returns
+        -------
+        var : relay.Var
+            Relay variable for scale. If the input name is 'conv2d_data', then the name of the
+            relay variable might be 'conv2d_data_scale_0'.
+        """
+
+        var = relay.var(
+            str(name) + "_scale_" + str(QuantizerPattern.scales_count), shape=(), dtype="float32"
+        )
+        QuantizerPattern.scales_count += 1
+        return var
+
+    def zero_point(self, name):
+        """Helper to create the zero point variable for qnn.quantize when rewriting our
+        our pattern.
+
+        Parameters
+        ----------
+        name : str
+            Identifier at the beginning of the variable.
+
+        Returns
+        -------
+        var : relay.Var
+            Relay variable for scale. If the input name is 'conv2d_data', then the name of the
+            relay variable might be 'conv2d_data_zero_pt_0'.
+        """
+        var = relay.var(
+            str(name) + "_zero_pt_" + str(QuantizerPattern.zp_count), shape=(), dtype="int32"
+        )
+        QuantizerPattern.zp_count += 1
+        return var
+
+    def create_scale_zps(self, left_name, right_name):
+        """Helper to create scales and zero points for binops.
+
+        Parameters
+        ----------
+        left_name : str
+            Identifier of the left hand side scale and zero point.
+
+        right_name : str
+            Identifier of the right hand side scale and zero point.
+        """
+        data_scale = self.scale(left_name)
+        data_zp = self.zero_point(left_name)
+        weight_scale = self.scale(right_name)
+        weight_zp = self.zero_point(right_name)
+        self.scale_zps = [data_scale, data_zp, weight_scale, weight_zp]
+
+
+class Conv2DPattern(QuantizerPattern):
+    """Pattern to rewrite nn.conv2d ops as qnn.conv2d ops.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.input = wildcard()
+        self.conv_weight = wildcard()
+        self.inputs = [self.input, self.conv_weight]
+        self.conv2d = is_op("nn.conv2d")(self.input, self.conv_weight)
+        self.pattern = self.conv2d
+        self.attrs = None
+        self.weight_channel_axis = None
+        self.data_channel_axis = None
+        self.channels = None
+
+    def get_kernel_size(self, kernel_shape, kernel_layout):
+        """Gets the size of the kernel.
+
+        Parameters
+        ----------
+        kernel_shape : NDArray
+            Shape of the kernel
+
+        kernel_layout : str
+            Layout of the kernel
+
+        Returns
+        -------
+            kernel_size : NDArray
+                Size of the kernel
+        """
+        if kernel_layout == "OIHW":
+            kernel_size = tuple(kernel_shape[2:4])
+        elif kernel_layout == "HWIO":
+            kernel_size = tuple(kernel_shape[0:2])
+        else:
+            raise ValueError(
+                "Quantizting kernel layout %s for conv2d is not yet supported."
+                + "Please use OIHW or HWIO",
+                kernel_layout,
+            )
+        return kernel_size
+
+    def get_attrs(self, attrs, kernel_shape):
+        """Constructs the attributes for qnn.conv2d.
+
+        Parameters
+        ----------
+        attrs : dict
+            Attributes of the original nn.conv2d
+
+        kernel_shape : NDArray
+            Shape of the kernel
+
+        Returns
+        -------
+            quantized_attrs : dict
+                Attributes for the qnn.conv2d
+        """
+        new_attr_dict = {}
+        self.kernel_layout = attrs["kernel_layout"]
+        data_layout = attrs["data_layout"]
+
+        if self.kernel_layout == "OIHW":
+            self.weight_channel_axis = 0
+        elif self.kernel_layout == "HWIO":
+            self.weight_channel_axis = 3
+        else:
+            raise ValueError(
+                "Quantizing kernel layout %s for conv2d is not yet supported."
+                + "Please use OIHW or HWIO",
+                self.kernel_layout,
+            )
+
+        if data_layout == "NCHW":
+            self.data_channel_axis = 1
+        elif data_layout == "NHWC":
+            self.data_channel_axis = 3
+        else:
+            raise ValueError(
+                "Quantizing data layout %s for conv2d is not yet supported."
+                + "Please use NCHW or NHWC",
+                data_layout,
+            )
+
+        for attr in attrs.keys():
+            attr_value = attrs[attr]
+            if isinstance(attr_value, tvm.ir.container.Array):
+                attr_value = tuple(attr_value)
+            if attr == "kernel_size":
+                kernel_size = attrs[attr]
+                if kernel_size is None:
+                    kernel_size = self.get_kernel_size(self.kernel_layout, kernel_shape)
+                else:
+                    kernel_size = tuple([k.value for k in attrs[attr]])
+                new_attr_dict[attr] = kernel_size
+            elif attr == "channels":
+                self.channels = attrs[attr]
+                if self.channels is None:
+                    self.channels = kernel_shape[self.weight_channel_axis]
+                if isinstance(self.channels, tvm.tir.expr.IntImm):
+                    self.channels = self.channels.value
+                new_attr_dict[attr] = self.channels
+            elif attr == "padding":
+                # We don't need to put padding in attr dict because we explicitly construct padding
+                self.padding = attrs[attr]
+            else:
+                new_attr_dict[attr] = attr_value
+
+        new_attr_dict["out_dtype"] = "int32"
+        self.attrs = new_attr_dict
+
+    def quantize_args(self):
+        """Helper to quantize the arguments to the qnn.conv2d."""
+        quantized_data = relay.qnn.op.quantize(
+            self.args[0], self.scale_zps[0], self.scale_zps[1], axis=self.data_channel_axis
+        )

Review comment:
       If you quantize this way, I think an argument that is consumed by multiple nodes will get quantized multiple times. I don't think CSE would help since you are using separate scale and zp vars. You should cache the quantized values and avoid quantizing the same node twice.  




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on a change in pull request #7474: Quantization in TVM

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #7474:
URL: https://github.com/apache/tvm/pull/7474#discussion_r578913474



##########
File path: python/tvm/relay/transform/quantize/_requantizer.py
##########
@@ -0,0 +1,312 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Removes extraneous qnn.quantize and qnn.dequantize from calibrated modules, and replaces them
+with qnn.requanize ops."""
+import math
+
+import tvm
+from tvm import relay
+from tvm.relay.dataflow_pattern import DFPatternCallback, wildcard, is_op, dominates, rewrite
+
+
+class Requantizer:
+    """Removes extraneous qnn.quantize and qnn.dequantize and replaces
+    them with qnn.requantize."""
+
+    class RequantizerCallback(DFPatternCallback):
+        """First pass that inserts requantize ops, specifically taking
+        qnn.dequantize -> qnn.quantize to qnn.requantize
+        and
+        qnn.dequantize -> int8_op* -> qnn.quantize to requantize -> int8_op*
+        """
+
+        def __init__(self):
+            super().__init__()
+
+            self.data = wildcard()
+            self.dequantize_scale = wildcard()
+            self.dequantize_zp = wildcard()
+
+            self.quantize_scale = wildcard()
+            self.quantize_zp = wildcard()
+
+            # Ops that are permitted inbetween quantize and dequantize if we are
+            # rewriting to requantize
+            self.is_int_8_op = (
+                is_op("nn.max_pool2d")(wildcard())
+                | is_op("nn.max_pool2d")(wildcard())
+                | is_op("nn.max_pool3d")(wildcard())
+                | is_op("nn.relu")(wildcard())
+                | is_op("transpose")(wildcard())
+                | is_op("reshape")(wildcard())
+                | is_op("nn.pad")(wildcard())
+                | is_op("squeeze")(wildcard())
+                | is_op("nn.global_avg_pool2d")
+                | is_op("nn.batch_flatten")
+                | is_op("copy")
+                | is_op("mean")
+                | is_op("sqrt")
+            )

Review comment:
       This is too ad hoc, it can easily break and leaves more dequantize/quantize than necessary. And the patterns are not correct.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on a change in pull request #7474: Quantization in TVM

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #7474:
URL: https://github.com/apache/tvm/pull/7474#discussion_r578915087



##########
File path: python/tvm/relay/transform/quantize/_requantizer.py
##########
@@ -0,0 +1,312 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Removes extraneous qnn.quantize and qnn.dequantize from calibrated modules, and replaces them
+with qnn.requanize ops."""
+import math
+
+import tvm
+from tvm import relay
+from tvm.relay.dataflow_pattern import DFPatternCallback, wildcard, is_op, dominates, rewrite
+
+
+class Requantizer:
+    """Removes extraneous qnn.quantize and qnn.dequantize and replaces
+    them with qnn.requantize."""
+
+    class RequantizerCallback(DFPatternCallback):
+        """First pass that inserts requantize ops, specifically taking
+        qnn.dequantize -> qnn.quantize to qnn.requantize
+        and
+        qnn.dequantize -> int8_op* -> qnn.quantize to requantize -> int8_op*
+        """
+
+        def __init__(self):
+            super().__init__()
+
+            self.data = wildcard()
+            self.dequantize_scale = wildcard()
+            self.dequantize_zp = wildcard()
+
+            self.quantize_scale = wildcard()
+            self.quantize_zp = wildcard()
+
+            # Ops that are permitted inbetween quantize and dequantize if we are
+            # rewriting to requantize
+            self.is_int_8_op = (
+                is_op("nn.max_pool2d")(wildcard())
+                | is_op("nn.max_pool2d")(wildcard())
+                | is_op("nn.max_pool3d")(wildcard())
+                | is_op("nn.relu")(wildcard())
+                | is_op("transpose")(wildcard())
+                | is_op("reshape")(wildcard())
+                | is_op("nn.pad")(wildcard())
+                | is_op("squeeze")(wildcard())
+                | is_op("nn.global_avg_pool2d")
+                | is_op("nn.batch_flatten")
+                | is_op("copy")
+                | is_op("mean")
+                | is_op("sqrt")
+            )
+
+            # All ops in is_int_8_op must also be in self.op_map
+            self.op_map = {
+                relay.op.get("nn.max_pool2d"): relay.op.nn.max_pool2d,
+                relay.op.get("nn.max_pool3d"): relay.op.nn.max_pool3d,
+                relay.op.get("transpose"): relay.op.transpose,
+                relay.op.get("reshape"): relay.op.reshape,
+                relay.op.get("nn.pad"): relay.op.nn.pad,
+                relay.op.get("squeeze"): relay.op.squeeze,
+                relay.op.get("nn.global_avg_pool2d"): relay.op.nn.global_avg_pool2d,
+                relay.op.get("nn.batch_flatten"): relay.op.nn.batch_flatten,
+                relay.op.get("copy"): relay.op.copy,
+                relay.op.get("mean"): relay.op.mean,
+                relay.op.get("sqrt"): relay.op.sqrt,
+            }
+
+            # Main pattern -- quantize(is_int_8_op*(dequantize(data))) --
+            # (with 1 or more is_int_8_ops)
+            self.dequantize = is_op("qnn.dequantize")(
+                self.data, self.dequantize_scale, self.dequantize_zp
+            )
+
+            self.dominator = dominates(self.dequantize, self.is_int_8_op, self.is_int_8_op)
+            self.quantize = is_op("qnn.quantize")(
+                self.dominator, self.quantize_scale, self.quantize_zp
+            )
+
+            # Pattern with the null path : quantize(dequantize(data)) -- (no is_int_8_op inbetween)
+            # We have to do the null path outside the dominator pattern because of pattern matcher
+            # limitations
+            self.no_path_dequantize = is_op("qnn.dequantize")(
+                self.data, self.dequantize_scale, self.dequantize_zp
+            )
+            self.no_path_quantize = is_op("qnn.quantize")(
+                self.no_path_dequantize, self.quantize_scale, self.quantize_zp
+            )
+
+            self.pattern = self.quantize | self.no_path_quantize
+
+        def callback(self, pre, post, node_map):
+            # Extract data from the pattern
+            data = node_map[self.data][0]
+            dequantize_scale = node_map[self.dequantize_scale][0]
+            deq_zp = node_map[self.dequantize_zp][0]
+
+            quantize_scale = node_map[self.quantize_scale][0]
+            quantize_zp = node_map[self.quantize_zp][0]
+
+            # Case where there are no ops in between the dequantize and quantize
+            if self.no_path_quantize in node_map:
+                axis = node_map[self.no_path_dequantize][0].attrs.axis
+                res = relay.qnn.op.requantize(
+                    data, dequantize_scale, deq_zp, quantize_scale, quantize_zp, axis=axis
+                )
+            # Ops inbetween quantize and dequantize are dominated
+            elif self.quantize in node_map:
+
+                axis = node_map[self.dequantize][0].attrs.axis
+                transformed_data = relay.qnn.op.requantize(
+                    data, dequantize_scale, deq_zp, quantize_scale, quantize_zp, axis=axis
+                )
+                for i in range(len(node_map[self.is_int_8_op]) - 1, -1, -1):
+                    call = node_map[self.is_int_8_op][i]
+                    # Transform relu into max(zeropoint)
+                    if call.op == relay.op.get("nn.relu"):
+                        if (
+                            quantize_zp.data.asnumpy()
+                            == relay.const(0, dtype="int32").data.asnumpy()
+                        ):
+                            transformed_data = relay.op.nn.relu(transformed_data)
+                        else:
+                            transformed_data = relay.op.maximum(
+                                transformed_data, relay.cast(quantize_zp, "int8")
+                            )
+                    elif call.op in self.op_map.keys():
+                        transformed_data = self.op_map[call.op](transformed_data, **call.attrs)
+                    else:
+                        raise ValueError(
+                            "Uh oh, %s is not copied properly in the requantizer. " % str(call.op)
+                        )
+                res = transformed_data
+            return res
+
+    class RequantizeChainCallback(DFPatternCallback):
+        """Folds chains of requantizes into one requantize.
+        requantize(scale_a, zp_a, scale_b, zp_b) -> requantize(scale_b, zp_b, scale_c, zp_c) becomes
+        requantize(scale_a, zp_a, scale_c, zp_c)
+        """
+
+        # Takes a chain of requantizes and turns them into one requantize
+        def __init__(self):
+            super().__init__()
+            self.data = wildcard()
+            self.rq_parent_scale_in = wildcard()
+            self.rq_parent_zp_in = wildcard()
+            self.rq_parent_scale_out = wildcard()
+            self.rq_parent_zp_out = wildcard()
+
+            self.rq_child_scale_in = wildcard()
+            self.rq_child_zp_in = wildcard()
+            self.rq_child_scale_out = wildcard()
+            self.rq_child_zp_out = wildcard()
+
+            self.rq_parent = is_op("qnn.requantize")(
+                self.data,
+                self.rq_parent_scale_in,
+                self.rq_parent_zp_in,
+                self.rq_parent_scale_out,
+                self.rq_parent_zp_out,
+            )
+            self.rq_child = is_op("qnn.requantize")(
+                wildcard(),
+                self.rq_child_scale_in,
+                self.rq_child_zp_in,
+                self.rq_child_scale_out,
+                self.rq_child_zp_out,
+            )
+
+            self.pattern = dominates(self.rq_parent, self.rq_child, self.rq_child)
+
+        def callback(self, pre, post, node_map):
+            data = node_map[self.data][0]
+            rq_parent = node_map[self.rq_parent][0]
+
+            rq_parent_scale_in = node_map[self.rq_parent_scale_in][0]
+            rq_parent_zp_in = node_map[self.rq_parent_zp_in][0]
+
+            rq_parent_scale_out = node_map[self.rq_parent_scale_out][0]
+            rq_parent_zp_out = node_map[self.rq_parent_zp_out][0]
+
+            child_in_scales = node_map[self.rq_child_scale_in]
+            child_in_zps = node_map[self.rq_child_zp_in]
+            child_out_scales = node_map[self.rq_child_scale_out]
+            child_out_zps = node_map[self.rq_child_zp_out]
+
+            len_children = len(node_map[self.rq_child_scale_out])
+
+            # Check to make sure output and input scales and zps match before we apply this
+            # transformation
+            out_scale = rq_parent_scale_out
+            out_zp = rq_parent_zp_out
+
+            for i in range(0, len_children):
+
+                in_scale = child_in_scales[i]
+                in_zp = child_in_zps[i]
+
+                assert math.isclose(
+                    out_scale.data.asnumpy(), in_scale.data.asnumpy(), rel_tol=1e-05, abs_tol=1e-05
+                ) and math.isclose(
+                    out_zp.data.asnumpy(), in_zp.data.asnumpy(), rel_tol=1e-05, abs_tol=1e-05
+                ), (
+                    "Out scales/zps should match in scales/zps. Indicates an internal issue "
+                    "in the quantizer somewhere."
+                )
+
+                out_scale = child_out_scales[i]
+                out_zp = child_out_zps[i]
+
+            parent_axis = rq_parent.attrs["axis"]
+
+            return relay.qnn.op.requantize(
+                data, rq_parent_scale_in, rq_parent_zp_in, out_scale, out_zp, axis=parent_axis
+            )
+
+    class ConsolidateRequantizeandQuantize(DFPatternCallback):
+        """Gets rid of unnecessary requantizes directly following a quantize. Takes
+        quantize(scale_a, zp_a) -> requantize(scale_a, zp_a, scale_b, zp_b) to

Review comment:
       I don't see how this pattern could arise in practice? Do you have an example?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on a change in pull request #7474: Quantization in TVM

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #7474:
URL: https://github.com/apache/tvm/pull/7474#discussion_r578909514



##########
File path: include/tvm/relay/qnn/attrs.h
##########
@@ -78,13 +78,18 @@ struct QuantizeAttrs : public tvm::AttrsNode<QuantizeAttrs> {
 /*! \brief Attribute for dequantize operator */
 struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
   int axis;
+  DataType out_dtype;
 
   TVM_DECLARE_ATTRS(DequantizeAttrs, "relay.attrs.DequantizeAttrs") {
     TVM_ATTR_FIELD(axis)
         .describe(
             "The channel axis for channel wise dequantization. Default value is -1,"
             "which corresponds to the last axis.")
         .set_default(-1);
+    TVM_ATTR_FIELD(out_dtype)
+        .describe(
+            "The datatype we are dequantizing to (float32 or int32). Defaults to float32.")

Review comment:
       The output of dequantize is always float, so this change doesn't make sense to me.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on a change in pull request #7474: Quantization in TVM

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #7474:
URL: https://github.com/apache/tvm/pull/7474#discussion_r579070402



##########
File path: python/tvm/relay/transform/quantize/_quantizer_patterns.py
##########
@@ -0,0 +1,712 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Patterns to quantize and how to quantize them."""
+
+import tvm
+from tvm import relay
+
+from tvm.relay.transform.quantize import CalibrationCallback
+from tvm.relay.dataflow_pattern import (
+    is_op,
+    wildcard,
+    is_constant,
+    DFPatternCallback,
+    _DFPatternCallback,
+)
+from tvm.relay.dataflow_pattern import ffi as pattern_ffi
+from tvm.relay.frontend.common import infer_type
+from tvm.relay.op.nn.utils import get_pad_tuple2d
+
+
+class QuantizerPattern(DFPatternCallback):
+    """DFPatternCallback to rewrite patterns as quantized. Also contains extra information
+    used for quantization and calibration.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate the nn.conv2d pattern.
+    """
+
+    # Counts the number of times we've added a scale and zp for variable naming
+    # This needs to be a global variable and not initialized in __init__ because
+    # each scale and zero point must be unique, even if they are created by different
+    # instances.
+    scales_count = 0
+    zp_count = 0
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__()
+        self.calibration_callback = calibration_callback
+
+    def calibrate_pattern(self, calibration_info):
+        """Calculates the scale and zero points for quantizing parts of a generic pattern. By
+        default, we call the calibrate_pattern method of the CalibrationCallback object that is
+        passed into QuantizerPattern during initialization. However, if you want a pattern specific
+        quantization method or a per-channel quantization method, you should overwrite the
+        QuantizerPattern's calibrate_pattern method.
+
+        Parameters
+        ----------
+        calibration_info : CalibrationInfo
+            The class containing relevant information and utility functions to calibrate one
+            instance of a pattern.
+
+        Returns
+        -------
+        scale_zp_map : Dictionary
+            A map from the names of scales and zero point variables in this pattern to their
+            values.
+        """
+        return self.calibration_callback.calibrate_pattern(calibration_info)
+
+    def callback(self, pre, post, node_map):
+        raise NotImplementedError
+
+    def scale(self, name):
+        """Helper to create the scale variable for qnn.quantize when rewriting our pattern.
+
+        Parameters
+        ----------
+        name : str
+            Identifier at the beginning of the scale variable.
+
+        is_weight : bool
+            Whether this scale is a weight scale or a data scale. If it is a weight scale, we
+            the returned variable has shape (channels,). Only used for per-channel quantization.
+
+        Returns
+        -------
+        var : relay.Var
+            Relay variable for scale. If the input name is 'conv2d_data', then the name of the
+            relay variable might be 'conv2d_data_scale_0'.
+        """
+
+        var = relay.var(
+            str(name) + "_scale_" + str(QuantizerPattern.scales_count), shape=(), dtype="float32"
+        )
+        QuantizerPattern.scales_count += 1
+        return var
+
+    def zero_point(self, name):
+        """Helper to create the zero point variable for qnn.quantize when rewriting our
+        our pattern.
+
+        Parameters
+        ----------
+        name : str
+            Identifier at the beginning of the variable.
+
+        Returns
+        -------
+        var : relay.Var
+            Relay variable for scale. If the input name is 'conv2d_data', then the name of the
+            relay variable might be 'conv2d_data_zero_pt_0'.
+        """
+        var = relay.var(
+            str(name) + "_zero_pt_" + str(QuantizerPattern.zp_count), shape=(), dtype="int32"
+        )
+        QuantizerPattern.zp_count += 1
+        return var
+
+    def create_scale_zps(self, left_name, right_name):
+        """Helper to create scales and zero points for binops.
+
+        Parameters
+        ----------
+        left_name : str
+            Identifier of the left hand side scale and zero point.
+
+        right_name : str
+            Identifier of the right hand side scale and zero point.
+        """
+        data_scale = self.scale(left_name)
+        data_zp = self.zero_point(left_name)
+        weight_scale = self.scale(right_name)
+        weight_zp = self.zero_point(right_name)
+        self.scale_zps = [data_scale, data_zp, weight_scale, weight_zp]
+
+
+class Conv2DPattern(QuantizerPattern):
+    """Pattern to rewrite nn.conv2d ops as qnn.conv2d ops.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.input = wildcard()
+        self.conv_weight = wildcard()
+        self.inputs = [self.input, self.conv_weight]
+        self.conv2d = is_op("nn.conv2d")(self.input, self.conv_weight)
+        self.pattern = self.conv2d
+        self.attrs = None
+        self.weight_channel_axis = None
+        self.data_channel_axis = None
+        self.channels = None
+
+    def get_kernel_size(self, kernel_shape, kernel_layout):
+        """Gets the size of the kernel.
+
+        Parameters
+        ----------
+        kernel_shape : NDArray
+            Shape of the kernel
+
+        kernel_layout : str
+            Layout of the kernel
+
+        Returns
+        -------
+            kernel_size : NDArray
+                Size of the kernel
+        """
+        if kernel_layout == "OIHW":
+            kernel_size = tuple(kernel_shape[2:4])
+        elif kernel_layout == "HWIO":
+            kernel_size = tuple(kernel_shape[0:2])
+        else:
+            raise ValueError(
+                "Quantizting kernel layout %s for conv2d is not yet supported."
+                + "Please use OIHW or HWIO",
+                kernel_layout,
+            )
+        return kernel_size
+
+    def get_attrs(self, attrs, kernel_shape):
+        """Constructs the attributes for qnn.conv2d.
+
+        Parameters
+        ----------
+        attrs : dict
+            Attributes of the original nn.conv2d
+
+        kernel_shape : NDArray
+            Shape of the kernel
+
+        Returns
+        -------
+            quantized_attrs : dict
+                Attributes for the qnn.conv2d
+        """
+        new_attr_dict = {}
+        self.kernel_layout = attrs["kernel_layout"]
+        data_layout = attrs["data_layout"]
+
+        if self.kernel_layout == "OIHW":
+            self.weight_channel_axis = 0
+        elif self.kernel_layout == "HWIO":
+            self.weight_channel_axis = 3
+        else:
+            raise ValueError(
+                "Quantizing kernel layout %s for conv2d is not yet supported."
+                + "Please use OIHW or HWIO",
+                self.kernel_layout,
+            )
+
+        if data_layout == "NCHW":
+            self.data_channel_axis = 1
+        elif data_layout == "NHWC":
+            self.data_channel_axis = 3
+        else:
+            raise ValueError(
+                "Quantizing data layout %s for conv2d is not yet supported."
+                + "Please use NCHW or NHWC",
+                data_layout,
+            )
+
+        for attr in attrs.keys():
+            attr_value = attrs[attr]
+            if isinstance(attr_value, tvm.ir.container.Array):
+                attr_value = tuple(attr_value)
+            if attr == "kernel_size":
+                kernel_size = attrs[attr]
+                if kernel_size is None:
+                    kernel_size = self.get_kernel_size(self.kernel_layout, kernel_shape)
+                else:
+                    kernel_size = tuple([k.value for k in attrs[attr]])
+                new_attr_dict[attr] = kernel_size
+            elif attr == "channels":
+                self.channels = attrs[attr]
+                if self.channels is None:
+                    self.channels = kernel_shape[self.weight_channel_axis]
+                if isinstance(self.channels, tvm.tir.expr.IntImm):
+                    self.channels = self.channels.value
+                new_attr_dict[attr] = self.channels
+            elif attr == "padding":
+                # We don't need to put padding in attr dict because we explicitly construct padding
+                self.padding = attrs[attr]
+            else:
+                new_attr_dict[attr] = attr_value
+
+        new_attr_dict["out_dtype"] = "int32"
+        self.attrs = new_attr_dict
+
+    def quantize_args(self):
+        """Helper to quantize the arguments to the qnn.conv2d."""
+        quantized_data = relay.qnn.op.quantize(
+            self.args[0], self.scale_zps[0], self.scale_zps[1], axis=self.data_channel_axis
+        )
+        quantized_weight = relay.qnn.op.quantize(
+            self.args[1], self.scale_zps[2], self.scale_zps[3], axis=self.weight_channel_axis
+        )
+        self.quantized_args = [quantized_data, quantized_weight]
+
+    def create_conv(self, args, node_map):
+        """Creates the qnn.conv2d.
+
+        Parameters
+        ----------
+        args : List[relay.Expr]
+            Quantized arguments for the qnn.conv2d.
+
+        node_map : tvm.ir.container.Map
+            Node map from DFPatternCallback's callback
+
+        Returns
+        -------
+        q_conv2d : relay.Expr
+            Quantized version of the pattern.
+        """
+        return relay.qnn.op.conv2d(*args, **self.attrs)
+
+    def callback(self, pre, post, node_map):
+        self.args = [node_map[i][0] for i in self.inputs]
+        conv2d = node_map[self.conv2d][0]
+
+        self.out_dtype = conv2d.checked_type.dtype
+
+        self.get_attrs(conv2d.attrs, infer_type(self.args[1]).checked_type.shape)
+
+        self.create_scale_zps("conv2d_data", "conv2d_weight")
+        self.quantize_args()
+
+        conv_scale = self.scale_zps[0] * self.scale_zps[2]  # data_scale * weight_scale
+
+        # Conv zp is zero since QNN deals with input zps for us
+        conv_zp = relay.const(0, dtype="int32")
+        # args = [quantized_data, quantized_weight, data_zp, weight_zp, data_scale, weight_scale]
+        args = self.quantized_args[0:2] + [self.scale_zps[i] for i in [1, 3, 0, 2]]
+
+        if self.padding is not None:
+
+            top, left, bottom, right = [p.value for p in get_pad_tuple2d(self.padding)]
+            if self.kernel_layout == "OIHW":
+                pad_width = ((0, 0), (0, 0), (top, bottom), (left, right))
+            elif self.kernel_layout == "HWIO":
+                pad_width = (
+                    (top, bottom),
+                    (left, right),
+                    (0, 0),
+                    (0, 0),
+                )
+            pad_val = 0
+            args[0] = relay.op.nn.pad(args[0], pad_width, pad_val)
+
+        # Construct quantized qnn.conv2d and dequantize
+        qnn_call = self.create_conv(args, node_map)
+        dequantized_call = relay.qnn.op.dequantize(
+            qnn_call, conv_scale, conv_zp, out_dtype=self.out_dtype, axis=self.data_channel_axis
+        )
+
+        return dequantized_call
+
+
+class Conv2DBiasAddPattern(Conv2DPattern):
+    """Pattern to rewrite nn.conv2d -> nn.bias_add pattern as qnn.conv2d -> nn.bias_add.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.bias_weight = is_constant()
+        self.inputs.append(self.bias_weight)
+        self.add = is_op("add")(self.conv2d, self.bias_weight)
+        self.bias_add = is_op("nn.bias_add")(self.conv2d, self.bias_weight)
+        self.pattern = self.bias_add | self.add
+
+    def quantize_args(self):
+        """Quantizes the arguments to the nn.conv2d -> nn.bias_add pattern."""
+        super().quantize_args()
+        quantized_bias = relay.qnn.op.quantize(
+            self.args[2], self.scale_zps[0], self.scale_zps[1], axis=0, out_dtype="int32"
+        )
+        self.quantized_args.append(quantized_bias)
+
+    def create_conv(self, args, node_map):
+        """Creates the qnn.dense -> nn.bias_add.
+
+        Parameters
+        ----------
+        args : List[relay.Expr]
+            Quantized arguments for the qnn.conv2d and bias_add.
+
+        node_map : tvm.ir.container.Map
+            Node map from DFPatternCallback's callback
+
+        Returns
+        -------
+        q_conv2d : relay.Expr
+            Quantized version of the pattern.
+        """
+        qnn_call = relay.qnn.op.conv2d(*args, **self.attrs)
+        if node_map.get(self.add) is not None:
+            bias_add = relay.op.add(qnn_call, self.quantized_args[2])
+        else:  # self.bias_add in node_map
+            bias_add = relay.op.nn.bias_add(
+                qnn_call, self.quantized_args[2], axis=self.data_channel_axis
+            )
+        return bias_add
+
+
+class DensePattern(QuantizerPattern):
+    """Pattern to rewrite nn.dense pattern as qnn.dense.
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.data = wildcard()
+        self.weight = wildcard()
+        self.inputs = [self.data, self.weight]
+
+        self.dense = is_op("nn.dense")(self.data, self.weight)
+
+        self.pattern = self.dense
+        self.attrs = None
+        self.units = None
+
+    def get_attrs(self, attrs, weight_shape):
+        """Constructs the attributes for qnn.conv2d.
+
+        Parameters
+        ----------
+        attrs : dict
+            Attributes of the original nn.dense
+
+        weight_shape : NDArray
+            Shape of the dense weights
+
+        Returns
+        -------
+            quantized_attrs : dict
+                Attributes for the qnn.conv2d
+        """
+        self.attrs = {}
+        units = attrs["units"]
+        if units is None:
+            units = weight_shape[0]
+        self.units = units.value
+        self.attrs["units"] = self.units
+
+    def quantize_args(self):
+        """Quantizes the arguments to the nn.dense pattern."""
+        # Quantize data and construct args for qnn.dense
+        quantized_data = relay.qnn.op.quantize(self.args[0], self.scale_zps[0], self.scale_zps[1])
+        quantized_weight = relay.qnn.op.quantize(
+            self.args[1], self.scale_zps[2], self.scale_zps[3], axis=0
+        )  # Axis = 0 for per channel quantization
+        self.quantized_args = [quantized_data, quantized_weight]
+
+    def create_dense(self, args, node_map):
+        """Creates the qnn.dense.
+
+        Parameters
+        ----------
+        args : List[relay.Expr]
+            Quantized arguments for the qnn.dense.
+
+        node_map : tvm.ir.container.Map
+            Node map from DFPatternCallback's callback
+
+        Returns
+        -------
+        q_dense : relay.Expr
+            Quantized version of the pattern.
+        """
+        qnn_call = relay.qnn.op.dense(*args, **self.attrs)
+        return qnn_call
+
+    def callback(self, pre, post, node_map):
+        self.args = [node_map[i][0] for i in self.inputs]
+        weight = node_map[self.weight][0]
+
+        dense = node_map[self.dense][0]
+        out_dtype = dense.checked_type.dtype
+        self.get_attrs(dense.attrs, infer_type(weight).checked_type.shape)
+        self.create_scale_zps("dense_data", "dense_weight")
+        self.quantize_args()
+
+        # args = [quantized_data, quantized_weight, data_zp, weight_zp, data_scale, weight_scale]
+        args = self.quantized_args[0:2] + [self.scale_zps[i] for i in [1, 3, 0, 2]]
+        qnn_call = self.create_dense(args, node_map)
+
+        deq_call = relay.qnn.op.dequantize(
+            qnn_call,
+            self.scale_zps[0] * self.scale_zps[2],
+            relay.const(0, dtype="int32"),
+            out_dtype=out_dtype,
+            axis=1,
+        )
+
+        return deq_call
+
+
+class DenseBiasAddPattern(DensePattern):
+    """Pattern to rewrite nn.dense -> add and nn.dense -> nn.bias_add pattern as
+    qnn.dense -> nn.bias_add.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.bias_weight = is_constant()
+        self.inputs.append(self.bias_weight)
+        self.bias_add = is_op("nn.bias_add")(self.dense, self.bias_weight)
+        self.add = is_op("add")(self.dense, self.bias_weight)
+        self.pattern = self.bias_add | self.add
+
+    def quantize_args(self):
+        super().quantize_args()
+        quantized_bias = relay.qnn.op.quantize(
+            self.args[2], self.scale_zps[0], self.scale_zps[1], axis=0, out_dtype="int32"
+        )
+        self.quantized_args.append(quantized_bias)
+
+    def create_dense(self, args, node_map):
+        qnn_call = relay.qnn.op.dense(*args, **self.attrs)
+        if node_map.get(self.add) is not None:
+            bias_add = relay.op.add(qnn_call, self.quantized_args[2])
+        else:  # self.bias_add in node_map
+            bias_add = relay.op.nn.bias_add(
+                qnn_call, self.quantized_args[2], axis=1  # Axis is always 1 for dense
+            )
+        return bias_add
+
+
+class AddPattern(QuantizerPattern):
+    """Pattern to rewrite add as quantized.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.lhs = wildcard()
+        self.rhs = wildcard()
+        self.add = is_op("add")(self.lhs, self.rhs)
+        self.pattern = self.add
+
+    def callback(self, pre, post, node_map):
+        lhs = node_map[self.lhs][0]
+        rhs = node_map[self.rhs][0]
+
+        add = node_map[self.add][0]
+
+        out_dtype = infer_type(add).checked_type.dtype
+
+        # Create quantization parameters for arguments to this addition
+        self.create_scale_zps("add_lhs", "add_rhs")
+
+        # Quantize, dequantize, and requantize inputs to have scale lhs_scale + rhs_scale
+        # (Scale represents the lowest possible value representable in the quantized type,
+        # so the smallest representable output is lhs_scale + rhs_scale)
+
+        # We do this to avoid the requantize op in qnn's add, which causes issues with compilation
+        # Requantize will be inserted in a future pass
+        lhs_scale, lhs_zp, rhs_scale, rhs_zp = self.scale_zps
+        quantized_lhs = relay.qnn.op.quantize(lhs, lhs_scale, lhs_zp)
+        quantized_rhs = relay.qnn.op.quantize(rhs, rhs_scale, rhs_zp)
+
+        dequantized_lhs = relay.qnn.op.dequantize(
+            quantized_lhs, lhs_scale, relay.const(0, dtype="int32"), out_dtype=out_dtype
+        )
+        dequantized_rhs = relay.qnn.op.dequantize(
+            quantized_rhs, rhs_scale, relay.const(0, dtype="int32"), out_dtype=out_dtype
+        )
+
+        add_scale = relay.op.add(lhs_scale, rhs_scale)
+
+        requantized_lhs = relay.qnn.op.quantize(
+            dequantized_lhs, add_scale, relay.const(0, dtype="int32")
+        )
+        requantized_rhs = relay.qnn.op.quantize(
+            dequantized_rhs, add_scale, relay.const(0, dtype="int32")
+        )
+
+        add = relay.op.add(requantized_lhs, requantized_rhs)
+        dequantized_call = relay.qnn.op.dequantize(
+            add, add_scale, relay.const(0, dtype="int32"), out_dtype=out_dtype
+        )
+
+        return dequantized_call
+
+
+class MultiplyPattern(QuantizerPattern):
+    """Pattern to rewrite multiply as quantized.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.lhs = wildcard()
+        self.rhs = wildcard()
+
+        self.multiply = is_op("multiply")(self.lhs, self.rhs)
+        self.pattern = self.multiply
+
+    def callback(self, pre, post, node_map):
+        lhs = node_map[self.lhs][0]
+        rhs = node_map[self.rhs][0]
+
+        multiply = node_map[self.multiply][0]
+
+        out_dtype = infer_type(multiply).checked_type.dtype
+
+        # Create quantization parameters for arguments to this multiplication.
+        self.create_scale_zps("mul_lhs", "mul_rhs")
+        lhs_scale, lhs_zp, rhs_scale, rhs_zp = self.scale_zps
+
+        # Quantize inputs and construct args for multiply
+        quantized_lhs = tvm.relay.cast(relay.qnn.op.quantize(lhs, lhs_scale, lhs_zp), "int32")
+        quantized_rhs = tvm.relay.cast(relay.qnn.op.quantize(rhs, rhs_scale, rhs_zp), "int32")
+
+        # Use normal relay multiply instead of qnn multiply to avoid requantize in qnn.mul
+        # Subtract zero points to center on zero so that we can multiply lhs, rhs directly
+        zeroed_quantized_lhs = relay.op.subtract(quantized_lhs, lhs_zp)
+        zeroed_quantized_rhs = relay.op.subtract(quantized_rhs, rhs_zp)
+
+        multiply = relay.op.multiply(zeroed_quantized_lhs, zeroed_quantized_rhs)
+        dequantized_call = relay.qnn.op.dequantize(
+            multiply, lhs_scale * rhs_scale, relay.const(0, dtype="int32"), out_dtype=out_dtype
+        )
+
+        return dequantized_call
+
+
+class PerChannelPattern:
+    """A parent class for patterns that will be per-channel quantized. PerChannelPattern should
+    only be inherited by a class that also inherits QuantizerPattern or a subclass of it.
+    """
+
+    def extract_attrs(self, pre, post, node_map):
+        """A callback to get the quantized attributes of this pattern. Usually, we just call
+        self.get_attrs on the attributes of the original, unquantized node to construct the
+        quantized attributes. Since this callback is used by the pattern rewriter, we must return
+        a relay.Expr from it.
+
+        Parameters
+        ----------
+        pre : relay.Expr
+            Expression before transformation
+
+        post : relay.Expr
+            Expression after transformation
+
+        node_map : Map of pattern to relay.Expr
+            Contains expressions matching parts of the pattern.
+
+        Returns
+        -------
+        post : relay.Expr
+            Expression to rewrite the input expression as. We don't actually want to rewrite
+            anything in this pass, so you should just return post.
+        """
+        raise NotImplementedError()
+
+    def get_scale_size(self):
+        """Returns the size of the per-channel scale variable
+
+        Returns
+        -------
+        scale_size : tuple
+            The size of the scale variable
+        """
+        raise NotImplementedError
+
+    def weight_scale(self, name):
+        """Helper to create a variable for a per-channel scale.
+        Parameters
+        ----------
+        name : str
+            Name of the variable
+        """
+        var = relay.var(
+            str(name) + "_scale_" + str(QuantizerPattern.scales_count),
+            shape=self.get_scale_size(),
+            dtype="float32",
+        )
+        QuantizerPattern.scales_count += 1
+        return var
+
+    def create_scale_zps(self, left_name, right_name):
+        """Helper to create scales and zero points for binops, with the per channel weight scale
+        quantized.
+
+        Parameters
+        ----------
+        left_name : str
+            Identifier of the left hand side scale and zero point.
+
+        right_name : str
+            Identifier of the right hand side scale and zero point.
+        """
+        # Create quantization parameters for arguments with per channel on the right
+        data_scale = self.scale(left_name)
+        data_zp = self.zero_point(left_name)
+
+        weight_scale = self.weight_scale(right_name)
+        weight_zp = self.zero_point(right_name)
+        self.scale_zps = [data_scale, data_zp, weight_scale, weight_zp]
+
+    def attr_callback(self, expr):
+        """A function to get the attributes of the quantized version of the current
+        pattern. Meant to be called from inside calibrate_pattern.
+
+        Parameters
+        ----------
+        expr : relay.Expr
+            Expression that we want the attributes from. This will be the unquantized
+            version of the expression.
+        """
+        pattern_ffi.rewrite(
+            [_DFPatternCallback(self.pattern, self.extract_attrs, self.require_type)],

Review comment:
       are the methods in this class tested? I don't find definitions of `self.pattern`, `self.scale` etc in this and derived classes.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on a change in pull request #7474: Quantization in TVM

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #7474:
URL: https://github.com/apache/tvm/pull/7474#discussion_r578915157



##########
File path: python/tvm/relay/transform/quantize/_requantizer.py
##########
@@ -0,0 +1,312 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Removes extraneous qnn.quantize and qnn.dequantize from calibrated modules, and replaces them
+with qnn.requanize ops."""
+import math
+
+import tvm
+from tvm import relay
+from tvm.relay.dataflow_pattern import DFPatternCallback, wildcard, is_op, dominates, rewrite
+
+
+class Requantizer:
+    """Removes extraneous qnn.quantize and qnn.dequantize and replaces
+    them with qnn.requantize."""
+
+    class RequantizerCallback(DFPatternCallback):
+        """First pass that inserts requantize ops, specifically taking
+        qnn.dequantize -> qnn.quantize to qnn.requantize
+        and
+        qnn.dequantize -> int8_op* -> qnn.quantize to requantize -> int8_op*
+        """
+
+        def __init__(self):
+            super().__init__()
+
+            self.data = wildcard()
+            self.dequantize_scale = wildcard()
+            self.dequantize_zp = wildcard()
+
+            self.quantize_scale = wildcard()
+            self.quantize_zp = wildcard()
+
+            # Ops that are permitted inbetween quantize and dequantize if we are
+            # rewriting to requantize
+            self.is_int_8_op = (
+                is_op("nn.max_pool2d")(wildcard())
+                | is_op("nn.max_pool2d")(wildcard())
+                | is_op("nn.max_pool3d")(wildcard())
+                | is_op("nn.relu")(wildcard())
+                | is_op("transpose")(wildcard())
+                | is_op("reshape")(wildcard())
+                | is_op("nn.pad")(wildcard())
+                | is_op("squeeze")(wildcard())
+                | is_op("nn.global_avg_pool2d")
+                | is_op("nn.batch_flatten")
+                | is_op("copy")
+                | is_op("mean")
+                | is_op("sqrt")
+            )
+
+            # All ops in is_int_8_op must also be in self.op_map
+            self.op_map = {
+                relay.op.get("nn.max_pool2d"): relay.op.nn.max_pool2d,
+                relay.op.get("nn.max_pool3d"): relay.op.nn.max_pool3d,
+                relay.op.get("transpose"): relay.op.transpose,
+                relay.op.get("reshape"): relay.op.reshape,
+                relay.op.get("nn.pad"): relay.op.nn.pad,
+                relay.op.get("squeeze"): relay.op.squeeze,
+                relay.op.get("nn.global_avg_pool2d"): relay.op.nn.global_avg_pool2d,
+                relay.op.get("nn.batch_flatten"): relay.op.nn.batch_flatten,
+                relay.op.get("copy"): relay.op.copy,
+                relay.op.get("mean"): relay.op.mean,
+                relay.op.get("sqrt"): relay.op.sqrt,
+            }
+
+            # Main pattern -- quantize(is_int_8_op*(dequantize(data))) --
+            # (with 1 or more is_int_8_ops)
+            self.dequantize = is_op("qnn.dequantize")(
+                self.data, self.dequantize_scale, self.dequantize_zp
+            )
+
+            self.dominator = dominates(self.dequantize, self.is_int_8_op, self.is_int_8_op)
+            self.quantize = is_op("qnn.quantize")(
+                self.dominator, self.quantize_scale, self.quantize_zp
+            )
+
+            # Pattern with the null path : quantize(dequantize(data)) -- (no is_int_8_op inbetween)
+            # We have to do the null path outside the dominator pattern because of pattern matcher
+            # limitations
+            self.no_path_dequantize = is_op("qnn.dequantize")(
+                self.data, self.dequantize_scale, self.dequantize_zp
+            )
+            self.no_path_quantize = is_op("qnn.quantize")(
+                self.no_path_dequantize, self.quantize_scale, self.quantize_zp
+            )
+
+            self.pattern = self.quantize | self.no_path_quantize
+
+        def callback(self, pre, post, node_map):
+            # Extract data from the pattern
+            data = node_map[self.data][0]
+            dequantize_scale = node_map[self.dequantize_scale][0]
+            deq_zp = node_map[self.dequantize_zp][0]
+
+            quantize_scale = node_map[self.quantize_scale][0]
+            quantize_zp = node_map[self.quantize_zp][0]
+
+            # Case where there are no ops in between the dequantize and quantize
+            if self.no_path_quantize in node_map:
+                axis = node_map[self.no_path_dequantize][0].attrs.axis
+                res = relay.qnn.op.requantize(
+                    data, dequantize_scale, deq_zp, quantize_scale, quantize_zp, axis=axis
+                )
+            # Ops inbetween quantize and dequantize are dominated
+            elif self.quantize in node_map:
+
+                axis = node_map[self.dequantize][0].attrs.axis
+                transformed_data = relay.qnn.op.requantize(
+                    data, dequantize_scale, deq_zp, quantize_scale, quantize_zp, axis=axis
+                )
+                for i in range(len(node_map[self.is_int_8_op]) - 1, -1, -1):
+                    call = node_map[self.is_int_8_op][i]
+                    # Transform relu into max(zeropoint)
+                    if call.op == relay.op.get("nn.relu"):
+                        if (
+                            quantize_zp.data.asnumpy()
+                            == relay.const(0, dtype="int32").data.asnumpy()
+                        ):
+                            transformed_data = relay.op.nn.relu(transformed_data)
+                        else:
+                            transformed_data = relay.op.maximum(
+                                transformed_data, relay.cast(quantize_zp, "int8")
+                            )
+                    elif call.op in self.op_map.keys():
+                        transformed_data = self.op_map[call.op](transformed_data, **call.attrs)
+                    else:
+                        raise ValueError(
+                            "Uh oh, %s is not copied properly in the requantizer. " % str(call.op)
+                        )
+                res = transformed_data
+            return res
+
+    class RequantizeChainCallback(DFPatternCallback):
+        """Folds chains of requantizes into one requantize.
+        requantize(scale_a, zp_a, scale_b, zp_b) -> requantize(scale_b, zp_b, scale_c, zp_c) becomes

Review comment:
       I don't see how this pattern could arise in practice? Do you have an example?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on a change in pull request #7474: Quantization in TVM

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #7474:
URL: https://github.com/apache/tvm/pull/7474#discussion_r578911104



##########
File path: python/tvm/relay/transform/quantize/_quantizer_patterns.py
##########
@@ -0,0 +1,712 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Patterns to quantize and how to quantize them."""
+
+import tvm
+from tvm import relay
+
+from tvm.relay.transform.quantize import CalibrationCallback
+from tvm.relay.dataflow_pattern import (
+    is_op,
+    wildcard,
+    is_constant,
+    DFPatternCallback,
+    _DFPatternCallback,
+)
+from tvm.relay.dataflow_pattern import ffi as pattern_ffi
+from tvm.relay.frontend.common import infer_type
+from tvm.relay.op.nn.utils import get_pad_tuple2d
+
+
+class QuantizerPattern(DFPatternCallback):
+    """DFPatternCallback to rewrite patterns as quantized. Also contains extra information
+    used for quantization and calibration.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate the nn.conv2d pattern.
+    """
+
+    # Counts the number of times we've added a scale and zp for variable naming
+    # This needs to be a global variable and not initialized in __init__ because
+    # each scale and zero point must be unique, even if they are created by different
+    # instances.
+    scales_count = 0
+    zp_count = 0
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__()
+        self.calibration_callback = calibration_callback
+
+    def calibrate_pattern(self, calibration_info):
+        """Calculates the scale and zero points for quantizing parts of a generic pattern. By
+        default, we call the calibrate_pattern method of the CalibrationCallback object that is
+        passed into QuantizerPattern during initialization. However, if you want a pattern specific
+        quantization method or a per-channel quantization method, you should overwrite the
+        QuantizerPattern's calibrate_pattern method.
+
+        Parameters
+        ----------
+        calibration_info : CalibrationInfo
+            The class containing relevant information and utility functions to calibrate one
+            instance of a pattern.
+
+        Returns
+        -------
+        scale_zp_map : Dictionary
+            A map from the names of scales and zero point variables in this pattern to their
+            values.
+        """
+        return self.calibration_callback.calibrate_pattern(calibration_info)
+
+    def callback(self, pre, post, node_map):
+        raise NotImplementedError
+
+    def scale(self, name):
+        """Helper to create the scale variable for qnn.quantize when rewriting our pattern.
+
+        Parameters
+        ----------
+        name : str
+            Identifier at the beginning of the scale variable.
+
+        is_weight : bool
+            Whether this scale is a weight scale or a data scale. If it is a weight scale, we
+            the returned variable has shape (channels,). Only used for per-channel quantization.
+
+        Returns
+        -------
+        var : relay.Var
+            Relay variable for scale. If the input name is 'conv2d_data', then the name of the
+            relay variable might be 'conv2d_data_scale_0'.
+        """
+
+        var = relay.var(
+            str(name) + "_scale_" + str(QuantizerPattern.scales_count), shape=(), dtype="float32"
+        )
+        QuantizerPattern.scales_count += 1
+        return var
+
+    def zero_point(self, name):
+        """Helper to create the zero point variable for qnn.quantize when rewriting our
+        our pattern.
+
+        Parameters
+        ----------
+        name : str
+            Identifier at the beginning of the variable.
+
+        Returns
+        -------
+        var : relay.Var
+            Relay variable for scale. If the input name is 'conv2d_data', then the name of the
+            relay variable might be 'conv2d_data_zero_pt_0'.
+        """
+        var = relay.var(
+            str(name) + "_zero_pt_" + str(QuantizerPattern.zp_count), shape=(), dtype="int32"
+        )
+        QuantizerPattern.zp_count += 1
+        return var
+
+    def create_scale_zps(self, left_name, right_name):
+        """Helper to create scales and zero points for binops.
+
+        Parameters
+        ----------
+        left_name : str
+            Identifier of the left hand side scale and zero point.
+
+        right_name : str
+            Identifier of the right hand side scale and zero point.
+        """
+        data_scale = self.scale(left_name)
+        data_zp = self.zero_point(left_name)
+        weight_scale = self.scale(right_name)
+        weight_zp = self.zero_point(right_name)
+        self.scale_zps = [data_scale, data_zp, weight_scale, weight_zp]
+
+
+class Conv2DPattern(QuantizerPattern):
+    """Pattern to rewrite nn.conv2d ops as qnn.conv2d ops.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.input = wildcard()
+        self.conv_weight = wildcard()
+        self.inputs = [self.input, self.conv_weight]
+        self.conv2d = is_op("nn.conv2d")(self.input, self.conv_weight)
+        self.pattern = self.conv2d
+        self.attrs = None
+        self.weight_channel_axis = None
+        self.data_channel_axis = None
+        self.channels = None
+
+    def get_kernel_size(self, kernel_shape, kernel_layout):
+        """Gets the size of the kernel.
+
+        Parameters
+        ----------
+        kernel_shape : NDArray
+            Shape of the kernel
+
+        kernel_layout : str
+            Layout of the kernel
+
+        Returns
+        -------
+            kernel_size : NDArray
+                Size of the kernel
+        """
+        if kernel_layout == "OIHW":
+            kernel_size = tuple(kernel_shape[2:4])
+        elif kernel_layout == "HWIO":
+            kernel_size = tuple(kernel_shape[0:2])
+        else:
+            raise ValueError(
+                "Quantizting kernel layout %s for conv2d is not yet supported."
+                + "Please use OIHW or HWIO",
+                kernel_layout,
+            )
+        return kernel_size
+
+    def get_attrs(self, attrs, kernel_shape):
+        """Constructs the attributes for qnn.conv2d.
+
+        Parameters
+        ----------
+        attrs : dict
+            Attributes of the original nn.conv2d
+
+        kernel_shape : NDArray
+            Shape of the kernel
+
+        Returns
+        -------
+            quantized_attrs : dict
+                Attributes for the qnn.conv2d
+        """
+        new_attr_dict = {}
+        self.kernel_layout = attrs["kernel_layout"]
+        data_layout = attrs["data_layout"]
+
+        if self.kernel_layout == "OIHW":
+            self.weight_channel_axis = 0
+        elif self.kernel_layout == "HWIO":
+            self.weight_channel_axis = 3
+        else:
+            raise ValueError(
+                "Quantizing kernel layout %s for conv2d is not yet supported."
+                + "Please use OIHW or HWIO",
+                self.kernel_layout,
+            )
+
+        if data_layout == "NCHW":
+            self.data_channel_axis = 1
+        elif data_layout == "NHWC":
+            self.data_channel_axis = 3
+        else:
+            raise ValueError(
+                "Quantizing data layout %s for conv2d is not yet supported."
+                + "Please use NCHW or NHWC",
+                data_layout,
+            )
+
+        for attr in attrs.keys():
+            attr_value = attrs[attr]
+            if isinstance(attr_value, tvm.ir.container.Array):
+                attr_value = tuple(attr_value)
+            if attr == "kernel_size":
+                kernel_size = attrs[attr]
+                if kernel_size is None:
+                    kernel_size = self.get_kernel_size(self.kernel_layout, kernel_shape)
+                else:
+                    kernel_size = tuple([k.value for k in attrs[attr]])
+                new_attr_dict[attr] = kernel_size
+            elif attr == "channels":
+                self.channels = attrs[attr]
+                if self.channels is None:
+                    self.channels = kernel_shape[self.weight_channel_axis]
+                if isinstance(self.channels, tvm.tir.expr.IntImm):
+                    self.channels = self.channels.value
+                new_attr_dict[attr] = self.channels
+            elif attr == "padding":
+                # We don't need to put padding in attr dict because we explicitly construct padding
+                self.padding = attrs[attr]
+            else:
+                new_attr_dict[attr] = attr_value
+
+        new_attr_dict["out_dtype"] = "int32"
+        self.attrs = new_attr_dict
+
+    def quantize_args(self):
+        """Helper to quantize the arguments to the qnn.conv2d."""
+        quantized_data = relay.qnn.op.quantize(
+            self.args[0], self.scale_zps[0], self.scale_zps[1], axis=self.data_channel_axis

Review comment:
       You should be explicit about what data type you are quantizing to. As rewritten this way, you are implicitly assuming that we are always doing symmetric quantization. Asymmetric quantation should also be supported.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on a change in pull request #7474: Quantization in TVM

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #7474:
URL: https://github.com/apache/tvm/pull/7474#discussion_r578913474



##########
File path: python/tvm/relay/transform/quantize/_requantizer.py
##########
@@ -0,0 +1,312 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Removes extraneous qnn.quantize and qnn.dequantize from calibrated modules, and replaces them
+with qnn.requanize ops."""
+import math
+
+import tvm
+from tvm import relay
+from tvm.relay.dataflow_pattern import DFPatternCallback, wildcard, is_op, dominates, rewrite
+
+
+class Requantizer:
+    """Removes extraneous qnn.quantize and qnn.dequantize and replaces
+    them with qnn.requantize."""
+
+    class RequantizerCallback(DFPatternCallback):
+        """First pass that inserts requantize ops, specifically taking
+        qnn.dequantize -> qnn.quantize to qnn.requantize
+        and
+        qnn.dequantize -> int8_op* -> qnn.quantize to requantize -> int8_op*
+        """
+
+        def __init__(self):
+            super().__init__()
+
+            self.data = wildcard()
+            self.dequantize_scale = wildcard()
+            self.dequantize_zp = wildcard()
+
+            self.quantize_scale = wildcard()
+            self.quantize_zp = wildcard()
+
+            # Ops that are permitted inbetween quantize and dequantize if we are
+            # rewriting to requantize
+            self.is_int_8_op = (
+                is_op("nn.max_pool2d")(wildcard())
+                | is_op("nn.max_pool2d")(wildcard())
+                | is_op("nn.max_pool3d")(wildcard())
+                | is_op("nn.relu")(wildcard())
+                | is_op("transpose")(wildcard())
+                | is_op("reshape")(wildcard())
+                | is_op("nn.pad")(wildcard())
+                | is_op("squeeze")(wildcard())
+                | is_op("nn.global_avg_pool2d")
+                | is_op("nn.batch_flatten")
+                | is_op("copy")
+                | is_op("mean")
+                | is_op("sqrt")
+            )

Review comment:
       This is too ad hoc, it can easily break and leaves more dequantize/quantize than necessary




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org