You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mo...@apache.org on 2021/02/13 04:15:14 UTC

[tvm] branch main updated: [ONNX] Make the ONNX Importer More Static (#7429)

This is an automated email from the ASF dual-hosted git repository.

moreau pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4e211a7  [ONNX] Make the ONNX Importer More Static (#7429)
4e211a7 is described below

commit 4e211a735221a9b9d188422025e2d464e37b3c96
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Fri Feb 12 21:14:56 2021 -0700

    [ONNX] Make the ONNX Importer More Static (#7429)
    
    * Construct static Ops if inputs are Constant
    
    * Expose FoldConstant as a function in addition to the pass
    
    * refactor onnx importer to do more static imports by constant folding
    
    fix pylint
    
    * fix test regressions
    
    * fix style, two bugs
    
    * pipe freeze_params through sub_graphs when importing loops and control flow
---
 python/tvm/relay/frontend/common.py       |   6 +
 python/tvm/relay/frontend/onnx.py         | 198 +++++++++++++++++-------------
 python/tvm/relay/op/image/image.py        |   4 +-
 python/tvm/relay/op/nn/nn.py              |  16 ++-
 python/tvm/relay/op/tensor.py             |   6 +-
 python/tvm/relay/op/transform.py          |  18 ++-
 python/tvm/relay/transform/transform.py   |  17 +++
 src/relay/transforms/fold_constant.cc     |   2 +
 tests/python/relay/test_op_grad_level3.py |   2 +-
 9 files changed, 180 insertions(+), 89 deletions(-)

diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py
index 6323c63..2db420a 100644
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -491,6 +491,12 @@ def infer_type(node, mod=None):
     return ret
 
 
+def fold_constant(node, mod=None):
+    """Constant-fold `node`, building an IRModule from it when `mod` is None."""
+    module = IRModule.from_expr(node) if mod is None else mod
+    return _transform.FoldConstantExpr(node, module)
+
+
 def infer_channels(inputs, transpose=False):
     """A hack for getting 'channels' or 'units' since caffe2 does not provide
     these attributes. We check the shape of weights provided to get the number.
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index c9140d7..fb3d1c9 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -34,7 +34,7 @@ from .. import loops as _loops
 from .. import ty as _ty
 
 from .common import AttrCvt, Renamer
-from .common import get_relay_op, new_var, infer_shape, infer_channels, infer_value
+from .common import get_relay_op, new_var, infer_shape, infer_channels, infer_value, fold_constant
 from .common import infer_type, get_name
 
 
@@ -364,7 +364,7 @@ def autopad(data, strides, kernel_shape, dilations, ndim, pad_type="constant", d
         ),
         dtype="int64",
     )
-    shape = _op.strided_slice(_op.shape_of(data, dtype="int64"), [2], [ndim])
+    shape = _op.strided_slice(shape_of(data, dtype="int64"), [2], [ndim])
     # get input shape
 
     # set up integer constants
@@ -545,9 +545,9 @@ class MatMul(OnnxOpConverter):
     def _impl_v1(cls, inputs, attr, params):
         assert len(inputs) == 2, "MatMul op take 2 inputs, {} given".format(len(inputs))
         # Need to check input shape as batch matmul must be supported.
-        a_shape = _op.shape_of(inputs[0])
+        a_shape = shape_of(inputs[0])
         a_rank = infer_shape(a_shape)[0]
-        b_shape = _op.shape_of(inputs[1])
+        b_shape = shape_of(inputs[1])
         b_rank = infer_shape(b_shape)[0]
         # When performing a batch matmul, we need to properly handle N-dim shapes.
         if a_rank > 2 or b_rank > 2:
@@ -555,9 +555,13 @@ class MatMul(OnnxOpConverter):
             def flatten_to_3d(x, x_shape):
                 ndims = infer_shape(x_shape)[0]
                 newshape = _op.concatenate(
-                    [_expr.const([-1]), _op.strided_slice(x_shape, [ndims - 2], [ndims])], 0
+                    [
+                        _expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype),
+                        _op.strided_slice(x_shape, [ndims - 2], [ndims]),
+                    ],
+                    0,
                 )
-                out = _op.reshape(x, newshape)
+                out = _op.reshape(x, fold_constant(newshape))
                 return out
 
             # Convert a and b into 3 dimensional tensors.
@@ -598,7 +602,7 @@ class MatMul(OnnxOpConverter):
                 ],
                 0,
             )
-            return _op.reshape(output, final_shape)
+            return _op.reshape(output, fold_constant(final_shape))
         # Otherwise a simple dense op will get the job done.
         input_1_t = _op.transpose(inputs[1], axes=(1, 0))
         return _op.nn.dense(inputs[0], input_1_t)
@@ -646,7 +650,7 @@ class MaxUnpool(OnnxOpConverter):
         multiplier = _op.concatenate(
             [_expr.const([1, 1], dtype="int64"), _expr.const(list(strides), dtype="int64")], axis=0
         )
-        total_output_shape = multiplier * _op.shape_of(data, dtype="int64")
+        total_output_shape = multiplier * shape_of(data, dtype="int64")
         # Add extra dimensions from kernel size and stride mismatch
         total_output_shape += _op.concatenate(
             [_expr.const([0, 0], "int64"), _expr.const(list(kernel_shape), "int64")], axis=0
@@ -792,11 +796,11 @@ class Pad(OnnxOpConverter):
     def _impl_v11(cls, inputs, attr, params):
         pads = inputs[1]
         if len(inputs) == 3:
-            value = _op.take(inputs[2], _op.const(0))
+            value = fold_constant(_op.take(inputs[2], _op.const(0)))
         else:
             value = 0
 
-        pad_width_expr = _op.transpose(_op.reshape(pads, (2, -1)))
+        pad_width_expr = fold_constant(_op.transpose(_op.reshape(pads, (2, -1))))
         pad_mode = attr.get("mode", b"constant").decode("utf-8")
 
         if not pad_mode in ["constant", "edge", "reflect"]:
@@ -823,7 +827,7 @@ class Prelu(OnnxOpConverter):
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         assert len(inputs) == 2, "Prelu need 2 inputs, {} given".format(len(inputs))
-        input_shape = _op.shape_of(inputs[0])
+        input_shape = shape_of(inputs[0])
         alpha = _op.broadcast_to_like(inputs[1], inputs[0])
         alpha = _op.reshape(alpha, [-1])
         output = _op.nn.prelu(_op.reshape(inputs[0], [-1]), alpha, axis=0)
@@ -875,7 +879,6 @@ class DepthToSpace(OnnxOpConverter):
 
     @classmethod
     def _impl_v11(cls, inputs, attr, params):
-
         block_size = int(attr["blocksize"])
         mode = attr.get("mode", b"DCR").decode("utf-8")
         return _op.nn.depth_to_space(inputs[0], block_size, mode=mode)
@@ -1015,8 +1018,9 @@ class Upsample(OnnxOpConverter):
                 scales = params[inputs[1].name_hint].asnumpy()
             else:
                 scales = inputs[1]
-
-        if not isinstance(scales, _expr.Call):
+        if isinstance(scales, _expr.Constant):
+            scales = list(scales.data.asnumpy())
+        if not isinstance(scales, _expr.Expr):
             assert scales[0] == 1.0 and scales[1] == 1.0
 
         mode = attr.get("mode")
@@ -1067,12 +1071,20 @@ class Upsample(OnnxOpConverter):
         return out
 
 
+def shape_of(x, dtype="int64"):
+    # A non-dynamic inferred type yields a constant; otherwise emit the shape_of op.
+    checked = infer_type(x).checked_type
+    if _ty.is_dynamic(checked):
+        return _op.shape_of(x, dtype)
+    return _expr.const(list(checked.shape), dtype)
+
+
 class Shape(OnnxOpConverter):
     """Operator converter for Shape."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        return _op.shape_of(inputs[0], "int64")
+        return shape_of(inputs[0], "int64")
 
 
 class CumSum(OnnxOpConverter):
@@ -1204,7 +1216,7 @@ class Slice(OnnxOpConverter):
 
         # Update the starts and ends according to axes if required.
         if axes is not None:
-            data_shape = _op.shape_of(inputs[0], dtype=infer_type(ends).checked_type.dtype)
+            data_shape = shape_of(inputs[0], dtype=infer_type(ends).checked_type.dtype)
             starts = _op.scatter(
                 _op.const([0] * data_rank, dtype=infer_type(starts).checked_type.dtype),
                 axes,
@@ -1223,7 +1235,9 @@ class Slice(OnnxOpConverter):
         if steps is None:
             steps = _op.const([1] * data_rank, dtype=infer_type(starts).checked_type.dtype)
 
-        return _op.strided_slice(inputs[0], starts, ends, steps)
+        return _op.strided_slice(
+            inputs[0], fold_constant(starts), fold_constant(ends), fold_constant(steps)
+        )
 
 
 class Gather(OnnxOpConverter):
@@ -1531,6 +1545,19 @@ class ConstantOfShape(OnnxOpConverter):
         return output
 
 
+class Constant(OnnxOpConverter):
+    """Operator converter for Constant."""
+
+    @classmethod
+    def _impl_v9(cls, inputs, attr, params):
+        # Exceptions must derive from BaseException; raising a bare str is itself a TypeError.
+        if "value" not in attr:
+            raise ValueError("No Value in Constant")
+        np_value = get_numpy(attr.pop("value"))
+        # Wrap the array in a relay constant that keeps the array's own dtype.
+        return _expr.const(np_value, np_value.dtype.name)
+
+
 class Sign(OnnxOpConverter):
     """Operator converter for Sign."""
 
@@ -1591,12 +1618,14 @@ class Where(OnnxOpConverter):
         # to that shape.
         max_rank = max(ranks)
         max_rank_idxs = [i for i, x in enumerate(ranks) if x == max_rank]
-        broadcast_shape = _op.shape_of(inputs[max_rank_idxs[0]])
+        broadcast_shape = shape_of(inputs[max_rank_idxs[0]])
         # If two or more inputs have the same rank, compute the broadcast
         # shape by taking the maximum value of each dimensions.
         if len(max_rank_idxs) > 1:
             for idx in max_rank_idxs:
-                broadcast_shape = _op.maximum(broadcast_shape, _op.shape_of(inputs[idx]))
+                broadcast_shape = _op.maximum(broadcast_shape, shape_of(inputs[idx]))
+
+        broadcast_shape = fold_constant(broadcast_shape)
 
         condition = _op.broadcast_to(inputs[0], broadcast_shape)
         x = _op.broadcast_to(inputs[1], broadcast_shape)
@@ -1618,7 +1647,7 @@ class Expand(OnnxOpConverter):
     @classmethod
     def _impl_v8(cls, inputs, attr, params):
         dtype = infer_type(inputs[1]).checked_type.dtype
-        in_shape = _op.shape_of(inputs[0], dtype=dtype)
+        in_shape = shape_of(inputs[0], dtype=dtype)
         shape = inputs[1]
 
         # Currently 'op.broadcast_to' expect the rank of the given 'shape'
@@ -1667,7 +1696,7 @@ class Expand(OnnxOpConverter):
             new_shape = _op.maximum(in_shape, shape)
             return new_shape
 
-        shape = expand_shape(in_shape, shape)
+        shape = fold_constant(expand_shape(in_shape, shape))
         return _op.broadcast_to(inputs[0], shape=shape)
 
 
@@ -1942,10 +1971,9 @@ class Resize(OnnxOpConverter):
             )
 
         scale = inputs[1]
-        size = _op.cast(_op.shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale
-
+        size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale
         layout = "NCHW"  # ONNX assumes NCHW layout
-        out_size = _op.strided_slice(size, [2], [4])
+        out_size = fold_constant(_op.strided_slice(size, [2], [4]))
         return _op.image.resize(inputs[0], out_size, layout, method, "asymmetric")
 
     @classmethod
@@ -1969,7 +1997,7 @@ class Resize(OnnxOpConverter):
             size = inputs[3]
         else:
             assert len(scale_shape) != 0, "One of scale or size should be passed."
-            size = _op.cast(_op.shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale
+            size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale
 
         coord_trans = attr.get("coordinate_transformation_mode")
         if coord_trans in [b"pytorch_half_pixel", b"half_pixel"]:
@@ -1983,7 +2011,7 @@ class Resize(OnnxOpConverter):
                 "Unsupported coordinate_transformation_mode: {}".format(coord_trans)
             )
         layout = "NCHW"  # ONNX assumes NCHW layout
-        out_size = _op.strided_slice(size, [2], [4])
+        out_size = fold_constant(_op.strided_slice(size, [2], [4]))
         return _op.image.resize(inputs[0], out_size, layout, method, coord_trans)
 
 
@@ -2152,7 +2180,9 @@ class Loop(OnnxOpConverter):
 
         # Get the current graph proto and create a clone for the subgraph
         graph_scope = GraphProto.current
-        subgraph_scope = GraphProto(graph_scope._shape, graph_scope._dtype)
+        subgraph_scope = GraphProto(
+            graph_scope._shape, graph_scope._dtype, graph_scope._freeze_params
+        )
         # Load nodes from outer graph into inner graph.
         subgraph_scope._nodes = graph_scope._nodes.copy()
 
@@ -2246,7 +2276,7 @@ class Loop(OnnxOpConverter):
                 expand_scan = _op.expand_dims(new_scan, axis=0)
                 # For non scalar outputs we need to broadcast the initial value.
                 if rank > 0:
-                    new_scan_shape = _op.shape_of(new_scan, dtype=iter_dtype)
+                    new_scan_shape = shape_of(new_scan, dtype=iter_dtype)
                     scan_broadcast = _op.concatenate(
                         [_op.reshape(loop_count, [1]), new_scan_shape], axis=0
                     )
@@ -2264,7 +2294,7 @@ class Loop(OnnxOpConverter):
             return [loop_count, max_count, new_cond] + new_loop_vars + combined_scan_outputs
 
         # Create the loop function.
-        loop = _loops.while_loop(cond_fn, loop_vars + scan_output_vars, body_fn)
+        loop = fold_constant(_loops.while_loop(cond_fn, loop_vars + scan_output_vars, body_fn))
 
         # Now need to run initial values through the graph.
         init_count = _expr.const(0, dtype=iter_dtype)
@@ -2287,6 +2317,7 @@ class Loop(OnnxOpConverter):
         # Update outer graph with constants found in the subgraph.
         free_vars = analysis.free_vars(loop)
         graph_scope._params.update(subgraph_scope._params)
+        graph_scope._nodes.update(subgraph_scope._nodes)
         for var in free_vars:
             graph_scope._nodes.update({var.name_hint: var})
         return outputs
@@ -2307,9 +2338,9 @@ class If(OnnxOpConverter):
 
         # Create graph converters for both branches.
         graph_scope = GraphProto.current
-        then_graph = GraphProto(graph_scope._shape, graph_scope._dtype)
+        then_graph = GraphProto(graph_scope._shape, graph_scope._dtype, graph_scope._freeze_params)
         then_graph._nodes = graph_scope._nodes.copy()
-        else_graph = GraphProto(graph_scope._shape, graph_scope._dtype)
+        else_graph = GraphProto(graph_scope._shape, graph_scope._dtype, graph_scope._freeze_params)
         else_graph._nodes = graph_scope._nodes.copy()
 
         # Convert each branch to a relay expression.
@@ -2320,10 +2351,12 @@ class If(OnnxOpConverter):
 
         # Add constants from both branches to parent graph.
         graph_scope._params.update(then_graph._params)
+        graph_scope._nodes.update(then_graph._nodes)
         then_free_vars = analysis.free_vars(then_expr)
         for var in then_free_vars:
             graph_scope._nodes.update({var.name_hint: var})
         graph_scope._params.update(else_graph._params)
+        graph_scope._nodes.update(else_graph._nodes)
         else_free_vars = analysis.free_vars(else_expr)
         for var in else_free_vars:
             graph_scope._nodes.update({var.name_hint: var})
@@ -2468,9 +2501,9 @@ class NonMaxSuppression(OnnxOpConverter):
             # partially prepare ONNX output format by labeling batch_num, class_id
             nms_padded_out = _op.expand_dims(nms_ret[0], -1, 1)
             batch_num = _op.expand_dims(_op.arange(_op.squeeze(B, [0]), dtype="int64"), -1, 1)
-            batch_num = _op.broadcast_to(batch_num, _op.shape_of(nms_ret[0], dtype="int64"))
+            batch_num = _op.broadcast_to(batch_num, shape_of(nms_ret[0], dtype="int64"))
             batch_num = _op.expand_dims(batch_num, -1, 1)
-            class_num = _op.broadcast_to(i, _op.shape_of(nms_padded_out, dtype="int64"))
+            class_num = _op.broadcast_to(i, shape_of(nms_padded_out, dtype="int64"))
             new_onnx_out = _op.concatenate(
                 [batch_num, class_num, _op.cast(nms_padded_out, "int64")], -1
             )
@@ -2570,7 +2603,7 @@ class NonMaxSuppression(OnnxOpConverter):
         )
 
         # Call the first loop, perform NMS
-        B, C, S = _op.split(_op.shape_of(scores, dtype="int64"), 3)
+        B, C, S = _op.split(shape_of(scores, dtype="int64"), 3)
         init_count = _op.const(np.array([0]), dtype="int64")
         init_onnx_out = _op.const([1], dtype="int64")
         init_onnx_out = _op.broadcast_to(init_onnx_out, _op.concatenate([B, one, S, three], 0))
@@ -2617,6 +2650,7 @@ def _get_convert_map(opset):
         "ThresholdedRelu": ThresholdedRelu.get_converter(opset),
         "ScaledTanh": ScaledTanh.get_converter(opset),
         "ParametricSoftplus": ParametricSoftPlus.get_converter(opset),
+        "Constant": Constant.get_converter(opset),
         "ConstantOfShape": ConstantOfShape.get_converter(opset),
         # 'GivenTensorFill'
         "FC": AttrCvt("dense", ignores=["axis", "axis_w"]),
@@ -2776,11 +2810,19 @@ class GraphProto:
 
     dtype : str or dict of str to str
         The input types to the graph
+
+    freeze_params: bool
+        If this parameter is true, the importer will take any provided
+        onnx input values (weights, shapes, etc) and embed them into the relay model
+        as Constants instead of variables. This allows more aggressive optimizations
+        at compile time and helps in making models static if certain inputs represent
+        attributes relay would traditionally consider compile-time constants.
+
     """
 
     current = None
 
-    def __init__(self, shape, dtype):
+    def __init__(self, shape, dtype, freeze_params=False):
         self._nodes = {}
         self._params = {}
         self._inputs = {}
@@ -2790,6 +2832,7 @@ class GraphProto:
         self._shape = shape if shape else {}
         self._dtype = dtype
         self.opset = None
+        self._freeze_params = freeze_params
 
     def __enter__(self):
         self._old_manager = GraphProto.current
@@ -2808,7 +2851,7 @@ class GraphProto:
         fn = _function.Function(analysis.free_vars(body), body)
         return fn, {}
 
-    def from_onnx(self, graph, opset, freeze_params=False, get_output_expr=False):
+    def from_onnx(self, graph, opset, get_output_expr=False):
         """Construct Relay expression from ONNX graph.
 
         Onnx graph is a python protobuf object.
@@ -2825,13 +2868,6 @@ class GraphProto:
 
         opset : opset version
 
-        freeze_params: bool
-            If this parameter is true, the importer will take any provided
-            onnx input values (weights, shapes, etc) and embed them into the relay model
-            as Constants instead of variables. This allows more aggressive optimizations
-            at compile time and helps in making models static if certain inputs represent
-            attributes relay would traditionally consider compile-time constants.
-
         get_output_expr: bool
             If set to true, this conversion will return each output expression rather
             than a packaged module. This can be useful when converting subgraphs to
@@ -2850,12 +2886,16 @@ class GraphProto:
         for init_tensor in graph.initializer:
             if not init_tensor.name.strip():
                 raise ValueError("Tensor's name is required.")
-            self._params[init_tensor.name] = self._parse_array(init_tensor)
-            self._nodes[init_tensor.name] = new_var(
-                init_tensor.name,
-                shape=self._params[init_tensor.name].shape,
-                dtype=self._params[init_tensor.name].dtype,
-            )
+            array = self._parse_array(init_tensor)
+            if self._freeze_params:
+                self._nodes[init_tensor.name] = _expr.const(array)
+            else:
+                self._params[init_tensor.name] = array
+                self._nodes[init_tensor.name] = new_var(
+                    init_tensor.name,
+                    shape=self._params[init_tensor.name].shape,
+                    dtype=self._params[init_tensor.name].dtype,
+                )
         for i in graph.input:
             # from onnx v0.2, GraphProto.input has type ValueInfoProto,
             #  and the name is 'i.name'
@@ -2867,6 +2907,8 @@ class GraphProto:
                 self._nodes[i_name] = new_var(
                     i_name, shape=self._params[i_name].shape, dtype=self._params[i_name].dtype
                 )
+            elif i_name in self._nodes:
+                continue
             else:
                 self._num_input += 1
                 if i_name in self._shape:
@@ -2909,37 +2951,28 @@ class GraphProto:
             for i in node.input:
                 if i != "":
                     inputs[i] = self._nodes[self._renames.get(i, i)]
-            if op_name == "Constant":
-                t_proto = self._parse_attr(node.attribute)["value"]
-                self._num_param += 1
-                # We should convert scalar integers to int32, to normalize.
-                array = self._parse_array(t_proto)
-                self._params[node.output[0]] = array
-                self._nodes[node.output[0]] = new_var(
-                    node.output[0], shape=list(t_proto.dims), dtype=array.dtype
-                )
+            i_name = self._parse_value_proto(node)
+            node_output = self._fix_outputs(op_name, node.output)
+            attr["tvm_custom"] = {}
+            attr["tvm_custom"]["name"] = i_name
+            attr["tvm_custom"]["num_outputs"] = len(node_output)
+
+            op = self._convert_operator(op_name, inputs, attr, opset)
+            if not isinstance(op, _expr.TupleWrapper):
+                outputs_num = 1
             else:
-                i_name = self._parse_value_proto(node)
-                node_output = self._fix_outputs(op_name, node.output)
-                attr["tvm_custom"] = {}
-                attr["tvm_custom"]["name"] = i_name
-                attr["tvm_custom"]["num_outputs"] = len(node_output)
-
-                op = self._convert_operator(op_name, inputs, attr, opset)
-                if not isinstance(op, _expr.TupleWrapper):
-                    outputs_num = 1
-                else:
-                    outputs_num = len(op)
-                assert (
-                    len(node_output) == outputs_num
-                ), "Number of output mismatch {} vs {} in {}.".format(
-                    len(node_output), outputs_num, op_name
-                )
-                if outputs_num == 1:
-                    self._nodes[node_output[0]] = op
-                else:
-                    for k, i in zip(list(node_output), range(len(node_output))):
-                        self._nodes[k] = op[i]
+                outputs_num = len(op)
+            assert (
+                len(node_output) == outputs_num
+            ), "Number of output mismatch {} vs {} in {}.".format(
+                len(node_output), outputs_num, op_name
+            )
+            if outputs_num == 1:
+                self._nodes[node_output[0]] = fold_constant(op)
+            else:
+                op = _expr.TupleWrapper(fold_constant(op.astuple()), len(op))
+                for k, i in zip(list(node_output), range(len(node_output))):
+                    self._nodes[k] = op[i]
 
         # now return the outputs
         outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
@@ -2957,9 +2990,6 @@ class GraphProto:
                 self._inputs[i_name] = self._nodes[i_name]
         # Create a function from our output expression and all input variables.
         func = _function.Function([v for k, v in self._inputs.items()], outputs)
-        if freeze_params:
-            func, params = self.freeze(func, self._params)
-            return IRModule.from_expr(func), params
         return IRModule.from_expr(func), self._params
 
     def _parse_value_proto(self, value_proto):
@@ -3100,7 +3130,7 @@ def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=Fals
                 warnings.warn(str(e))
     except ImportError:
         pass
-    g = GraphProto(shape, dtype)
+    g = GraphProto(shape, dtype, freeze_params)
     graph = model.graph
     if opset is None:
         try:
@@ -3109,5 +3139,5 @@ def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=Fals
             opset = 1
     # Use the graph proto as a scope so that ops can access other nodes if needed.
     with g:
-        mod, params = g.from_onnx(graph, opset, freeze_params)
+        mod, params = g.from_onnx(graph, opset)
     return mod, params
diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py
index a3f3a3e..153439b 100644
--- a/python/tvm/relay/op/image/image.py
+++ b/python/tvm/relay/op/image/image.py
@@ -17,7 +17,7 @@
 """Image operations."""
 from . import _make
 from ..dyn.image import _make as _dyn_make
-from ...expr import Expr
+from ...expr import Expr, Constant
 
 
 def resize(
@@ -66,6 +66,8 @@ def resize(
     result: relay.Expr
         The resized result.
     """
+    if isinstance(size, Constant):
+        size = list(size.data.asnumpy().astype("int32"))
     if isinstance(size, Expr):
         return _dyn_make.resize(
             data, size, layout, method, coordinate_transformation_mode, out_dtype
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index 0c233a6..5135ac7 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -21,7 +21,7 @@ from tvm.relay import expr
 from . import _make
 from ..dyn.nn import _make as _dyn_make
 from .utils import get_pad_tuple1d, get_pad_tuple2d, get_pad_tuple3d
-from ...expr import const, Expr
+from ...expr import const, Expr, Constant
 
 
 def conv1d(
@@ -1279,6 +1279,10 @@ def upsampling(
     result : tvm.relay.Expr
         The computed result.
     """
+    if isinstance(scale_h, Constant):
+        scale_h = scale_h.data.asnumpy().item()
+    if isinstance(scale_w, Constant):
+        scale_w = scale_w.data.asnumpy().item()
     if isinstance(scale_h, Expr) or isinstance(scale_w, Expr):
         if not isinstance(scale_h, Expr):
             scale_h = const(scale_h, "float64")
@@ -1338,6 +1342,12 @@ def upsampling3d(
     result : tvm.relay.Expr
         The computed result.
     """
+    if isinstance(scale_d, Constant):
+        scale_d = scale_d.data.asnumpy().item()
+    if isinstance(scale_h, Constant):
+        scale_h = scale_h.data.asnumpy().item()
+    if isinstance(scale_w, Constant):
+        scale_w = scale_w.data.asnumpy().item()
     if isinstance(scale_d, Expr) or isinstance(scale_h, Expr) or isinstance(scale_w, Expr):
         if not isinstance(scale_d, Expr):
             scale_d = const(scale_d, "float64")
@@ -1596,6 +1606,10 @@ def pad(data, pad_width, pad_value=0, pad_mode="constant"):
     result : tvm.relay.Expr
         The computed result.
     """
+    if isinstance(pad_value, Constant):
+        pad_value = pad_value.data.asnumpy().item()
+    if isinstance(pad_width, Constant):
+        pad_width = [list(i) for i in pad_width.data.asnumpy()]
     if isinstance(pad_width, Expr) or (isinstance(pad_value, Expr)):
         if not isinstance(pad_width, Expr):
             pad_width = const(list(pad_width))
diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py
index 75e2987..5b01104 100644
--- a/python/tvm/relay/op/tensor.py
+++ b/python/tvm/relay/op/tensor.py
@@ -22,7 +22,7 @@ from tvm.te.hybrid import script
 
 from . import _make
 from .dyn import _make as _dyn_make
-from ..expr import Tuple, Expr
+from ..expr import Tuple, Expr, Constant
 from . import op as reg
 
 
@@ -960,6 +960,8 @@ def zeros(shape, dtype):
     result : relay.Expr
         The resulting tensor.
     """
+    if isinstance(shape, Constant):
+        shape = list(shape.data.asnumpy())
     if isinstance(shape, Expr):
         return _dyn_make.zeros(shape, dtype)
     if isinstance(shape, int):
@@ -1001,6 +1003,8 @@ def ones(shape, dtype):
     result : relay.Expr
         The resulting tensor.
     """
+    if isinstance(shape, Constant):
+        shape = list(shape.data.asnumpy())
     if isinstance(shape, Expr):
         return _dyn_make.ones(shape, dtype)
     if isinstance(shape, int):
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index d42ef47..cda417c 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -21,7 +21,7 @@
 from . import _make
 from .dyn import _make as _dyn_make
 from .tensor import shape_of
-from ..expr import TupleWrapper, const, Expr, Tuple
+from ..expr import TupleWrapper, const, Constant, Expr, Tuple
 from ...tir import expr as _expr
 
 
@@ -216,6 +216,8 @@ def reshape(data, newshape):
     result : relay.Expr
         The reshaped result.
     """
+    if isinstance(newshape, Constant):
+        newshape = list(newshape.data.asnumpy())
     if isinstance(newshape, Expr):
         return _dyn_make.reshape(data, newshape)
     if isinstance(newshape, int):
@@ -431,6 +433,8 @@ def full(fill_value, shape=(), dtype=""):
     result : relay.Expr
         The resulting tensor.
     """
+    if isinstance(shape, Constant):
+        shape = list(shape.data.asnumpy())
     if isinstance(shape, Expr):
         return _dyn_make.full(fill_value, shape, dtype)
     if isinstance(shape, int):
@@ -614,6 +618,8 @@ def tile(data, reps):
     data is promoted to be d-dimensional by prepending new axes.
     If data.ndim >=  d, reps is promoted to a.ndim by pre-pending 1's to it.
     """
+    if isinstance(reps, Constant):
+        reps = list(reps.data.asnumpy())
     if isinstance(reps, Expr):
         return _dyn_make.tile(data, reps)
     return _make.tile(data, reps)
@@ -753,6 +759,8 @@ def broadcast_to(data, shape):
     result : relay.Expr
         The resulting tensor.
     """
+    if isinstance(shape, Constant):
+        shape = list(shape.data.asnumpy())
     if isinstance(shape, Expr):
         return _dyn_make.broadcast_to(data, shape)
     if isinstance(shape, int):
@@ -884,6 +892,12 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"):
         The computed result.
     """
     strides = strides or [1]
+    if isinstance(begin, Constant):
+        begin = list(begin.data.asnumpy())
+    if isinstance(end, Constant):
+        end = list(end.data.asnumpy())
+    if isinstance(strides, Constant):
+        strides = list(strides.data.asnumpy())
     if isinstance(begin, Expr) or isinstance(end, Expr) or isinstance(strides, Expr):
         if isinstance(begin, (tuple, list)):
             begin = const(list(begin))
@@ -1170,6 +1184,8 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype):
              [0, 1, 0],
              [0, 0, 1]]
     """
+    if isinstance(depth, Constant):
+        depth = depth.data.asnumpy().item()
     if isinstance(depth, Expr):
         return _dyn_make.one_hot(indices, on_value, off_value, depth, axis, dtype)
     return _make.one_hot(indices, on_value, off_value, depth, axis, dtype)
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index c6df8c1..f02f835 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -240,6 +240,23 @@ def LazyGradientInit():
     return _ffi_api.LazyGradientInit()
 
 
+def FoldConstantExpr(expr, mod):
+    """Fold the constant sub-expressions of a single Relay expression.
+    Parameters
+    ----------
+    expr: Expr
+        The expression to fold.
+    mod: IRModule
+        The module the expr lives in (for global calls).
+
+    Returns
+    -------
+    new_expr: Expr
+        The expression after constant folding.
+    """
+    return _ffi_api.FoldConstantExpr(expr, mod)
+
+
 def FoldConstant():
     """Fold the constant expressions in a Relay program.
 
diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc
index 4454c9c..9416b0e 100644
--- a/src/relay/transforms/fold_constant.cc
+++ b/src/relay/transforms/fold_constant.cc
@@ -374,6 +374,8 @@ Expr FoldConstant(const Expr& expr, const IRModule& mod) {
   return ConstantFolder(mod).Mutate(expr);
 }
 
+TVM_REGISTER_GLOBAL("relay._transform.FoldConstantExpr").set_body_typed(FoldConstant);
+
 namespace transform {
 
 Pass FoldConstant() {
diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py
index 904576a..d43744b 100644
--- a/tests/python/relay/test_op_grad_level3.py
+++ b/tests/python/relay/test_op_grad_level3.py
@@ -146,7 +146,7 @@ def test_zeros_ones_grad_const_ints():
 
 def test_zeros_ones_grad_const_expr():
     # when shape is static (i.e. not an input), there is no gradient at all
-    shape_const = relay.const(np.array([2, 3, 4]), dtype="int32")
+    shape_const = relay.const(np.array([2, 3, 4]), dtype="int32") * relay.const(1, dtype="int32")
     static_ty = relay.TensorType([2, 3, 4], dtype="float32")
     dyn_ty = relay.TensorType([relay.Any(), relay.Any(), relay.Any()], dtype="float32")
     expected_ty_static = relay.TupleType([static_ty, relay.TupleType([])])