Posted to commits@tvm.apache.org by mo...@apache.org on 2021/02/13 04:15:14 UTC
[tvm] branch main updated: [ONNX] Make the ONNX Importer More Static (#7429)
This is an automated email from the ASF dual-hosted git repository.
moreau pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4e211a7 [ONNX] Make the ONNX Importer More Static (#7429)
4e211a7 is described below
commit 4e211a735221a9b9d188422025e2d464e37b3c96
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Fri Feb 12 21:14:56 2021 -0700
[ONNX] Make the ONNX Importer More Static (#7429)
* Construct static Ops if inputs are Constant
* Expose FoldConstant as a function in addition to the pass
* Refactor the ONNX importer to do more static imports by constant folding; fix pylint
* Fix test regressions
* Fix style and two bugs
* Pipe freeze_params through sub_graphs when importing loops and control flow
---
python/tvm/relay/frontend/common.py | 6 +
python/tvm/relay/frontend/onnx.py | 198 +++++++++++++++++-------------
python/tvm/relay/op/image/image.py | 4 +-
python/tvm/relay/op/nn/nn.py | 16 ++-
python/tvm/relay/op/tensor.py | 6 +-
python/tvm/relay/op/transform.py | 18 ++-
python/tvm/relay/transform/transform.py | 17 +++
src/relay/transforms/fold_constant.cc | 2 +
tests/python/relay/test_op_grad_level3.py | 2 +-
9 files changed, 180 insertions(+), 89 deletions(-)
diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py
index 6323c63..2db420a 100644
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -491,6 +491,12 @@ def infer_type(node, mod=None):
return ret
+def fold_constant(node, mod=None):
+ if mod is None:
+ mod = IRModule.from_expr(node)
+ return _transform.FoldConstantExpr(node, mod)
+
+
def infer_channels(inputs, transpose=False):
"""A hack for getting 'channels' or 'units' since caffe2 does not provide
these attributes. We check the shape of weights provided to get the number.
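For reference, a minimal sketch of what the new fold_constant helper does, using illustrative inputs that are not part of the commit: an expression built only from constants collapses to a single relay.Constant, which lets importer code treat it as a static attribute.

from tvm import relay
from tvm.relay.frontend.common import fold_constant

# An entirely-constant expression...
expr = relay.add(relay.const(1, "int64"), relay.const(2, "int64"))
folded = fold_constant(expr)
# ...folds to a single Constant with value 3.
assert isinstance(folded, relay.Constant)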
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index c9140d7..fb3d1c9 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -34,7 +34,7 @@ from .. import loops as _loops
from .. import ty as _ty
from .common import AttrCvt, Renamer
-from .common import get_relay_op, new_var, infer_shape, infer_channels, infer_value
+from .common import get_relay_op, new_var, infer_shape, infer_channels, infer_value, fold_constant
from .common import infer_type, get_name
@@ -364,7 +364,7 @@ def autopad(data, strides, kernel_shape, dilations, ndim, pad_type="constant", d
),
dtype="int64",
)
- shape = _op.strided_slice(_op.shape_of(data, dtype="int64"), [2], [ndim])
+ shape = _op.strided_slice(shape_of(data, dtype="int64"), [2], [ndim])
# get input shape
# set up integer constants
@@ -545,9 +545,9 @@ class MatMul(OnnxOpConverter):
def _impl_v1(cls, inputs, attr, params):
assert len(inputs) == 2, "MatMul op take 2 inputs, {} given".format(len(inputs))
# Need to check input shape as batch matmul must be supported.
- a_shape = _op.shape_of(inputs[0])
+ a_shape = shape_of(inputs[0])
a_rank = infer_shape(a_shape)[0]
- b_shape = _op.shape_of(inputs[1])
+ b_shape = shape_of(inputs[1])
b_rank = infer_shape(b_shape)[0]
# When performing a batch matmul, we need to properly handle N-dim shapes.
if a_rank > 2 or b_rank > 2:
@@ -555,9 +555,13 @@ class MatMul(OnnxOpConverter):
def flatten_to_3d(x, x_shape):
ndims = infer_shape(x_shape)[0]
newshape = _op.concatenate(
- [_expr.const([-1]), _op.strided_slice(x_shape, [ndims - 2], [ndims])], 0
+ [
+ _expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype),
+ _op.strided_slice(x_shape, [ndims - 2], [ndims]),
+ ],
+ 0,
)
- out = _op.reshape(x, newshape)
+ out = _op.reshape(x, fold_constant(newshape))
return out
# Convert a and b into 3 dimensional tensors.
@@ -598,7 +602,7 @@ class MatMul(OnnxOpConverter):
],
0,
)
- return _op.reshape(output, final_shape)
+ return _op.reshape(output, fold_constant(final_shape))
# Otherwise a simple dense op will get the job done.
input_1_t = _op.transpose(inputs[1], axes=(1, 0))
return _op.nn.dense(inputs[0], input_1_t)
@@ -646,7 +650,7 @@ class MaxUnpool(OnnxOpConverter):
multiplier = _op.concatenate(
[_expr.const([1, 1], dtype="int64"), _expr.const(list(strides), dtype="int64")], axis=0
)
- total_output_shape = multiplier * _op.shape_of(data, dtype="int64")
+ total_output_shape = multiplier * shape_of(data, dtype="int64")
# Add extra dimensions from kernel size and stride mismatch
total_output_shape += _op.concatenate(
[_expr.const([0, 0], "int64"), _expr.const(list(kernel_shape), "int64")], axis=0
@@ -792,11 +796,11 @@ class Pad(OnnxOpConverter):
def _impl_v11(cls, inputs, attr, params):
pads = inputs[1]
if len(inputs) == 3:
- value = _op.take(inputs[2], _op.const(0))
+ value = fold_constant(_op.take(inputs[2], _op.const(0)))
else:
value = 0
- pad_width_expr = _op.transpose(_op.reshape(pads, (2, -1)))
+ pad_width_expr = fold_constant(_op.transpose(_op.reshape(pads, (2, -1))))
pad_mode = attr.get("mode", b"constant").decode("utf-8")
if not pad_mode in ["constant", "edge", "reflect"]:
@@ -823,7 +827,7 @@ class Prelu(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
assert len(inputs) == 2, "Prelu need 2 inputs, {} given".format(len(inputs))
- input_shape = _op.shape_of(inputs[0])
+ input_shape = shape_of(inputs[0])
alpha = _op.broadcast_to_like(inputs[1], inputs[0])
alpha = _op.reshape(alpha, [-1])
output = _op.nn.prelu(_op.reshape(inputs[0], [-1]), alpha, axis=0)
@@ -875,7 +879,6 @@ class DepthToSpace(OnnxOpConverter):
@classmethod
def _impl_v11(cls, inputs, attr, params):
-
block_size = int(attr["blocksize"])
mode = attr.get("mode", b"DCR").decode("utf-8")
return _op.nn.depth_to_space(inputs[0], block_size, mode=mode)
@@ -1015,8 +1018,9 @@ class Upsample(OnnxOpConverter):
scales = params[inputs[1].name_hint].asnumpy()
else:
scales = inputs[1]
-
- if not isinstance(scales, _expr.Call):
+ if isinstance(scales, _expr.Constant):
+ scales = list(scales.data.asnumpy())
+ if not isinstance(scales, _expr.Expr):
assert scales[0] == 1.0 and scales[1] == 1.0
mode = attr.get("mode")
@@ -1067,12 +1071,20 @@ class Upsample(OnnxOpConverter):
return out
+def shape_of(x, dtype="int64"):
+ ttype = infer_type(x).checked_type
+ if not _ty.is_dynamic(ttype):
+ shape = list(ttype.shape)
+ return _expr.const(shape, dtype)
+ return _op.shape_of(x, dtype)
+
+
class Shape(OnnxOpConverter):
"""Operator converter for Shape."""
@classmethod
def _impl_v1(cls, inputs, attr, params):
- return _op.shape_of(inputs[0], "int64")
+ return shape_of(inputs[0], "int64")
class CumSum(OnnxOpConverter):
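A rough illustration of the new shape_of helper defined above, assuming the module-level function is importable and using illustrative variables: a statically typed tensor yields a constant shape, while a tensor with Any() dimensions still falls back to the runtime shape_of op.

from tvm import relay
from tvm.relay.frontend.onnx import shape_of

static_x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32")
dyn_x = relay.var("y", shape=(relay.Any(), 3), dtype="float32")

print(shape_of(static_x))  # relay.const([1, 3, 224, 224], "int64")
print(shape_of(dyn_x))     # a dynamic shape_of(y) call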
@@ -1204,7 +1216,7 @@ class Slice(OnnxOpConverter):
# Update the starts and ends according to axes if required.
if axes is not None:
- data_shape = _op.shape_of(inputs[0], dtype=infer_type(ends).checked_type.dtype)
+ data_shape = shape_of(inputs[0], dtype=infer_type(ends).checked_type.dtype)
starts = _op.scatter(
_op.const([0] * data_rank, dtype=infer_type(starts).checked_type.dtype),
axes,
@@ -1223,7 +1235,9 @@ class Slice(OnnxOpConverter):
if steps is None:
steps = _op.const([1] * data_rank, dtype=infer_type(starts).checked_type.dtype)
- return _op.strided_slice(inputs[0], starts, ends, steps)
+ return _op.strided_slice(
+ inputs[0], fold_constant(starts), fold_constant(ends), fold_constant(steps)
+ )
class Gather(OnnxOpConverter):
@@ -1531,6 +1545,19 @@ class ConstantOfShape(OnnxOpConverter):
return output
+class Constant(OnnxOpConverter):
+ """Operator converter for ConstantOfShape."""
+
+ @classmethod
+ def _impl_v9(cls, inputs, attr, params):
+ if "value" not in attr:
+ raise "No Value in Constant"
+ np_value = get_numpy(attr.pop("value"))
+ dtype = np_value.dtype.name
+ value = _expr.const(np_value, dtype)
+ return value
+
+
class Sign(OnnxOpConverter):
"""Operator converter for Sign."""
@@ -1591,12 +1618,14 @@ class Where(OnnxOpConverter):
# to that shape.
max_rank = max(ranks)
max_rank_idxs = [i for i, x in enumerate(ranks) if x == max_rank]
- broadcast_shape = _op.shape_of(inputs[max_rank_idxs[0]])
+ broadcast_shape = shape_of(inputs[max_rank_idxs[0]])
# If two or more inputs have the same rank, compute the broadcast
# shape by taking the maximum value of each dimensions.
if len(max_rank_idxs) > 1:
for idx in max_rank_idxs:
- broadcast_shape = _op.maximum(broadcast_shape, _op.shape_of(inputs[idx]))
+ broadcast_shape = _op.maximum(broadcast_shape, shape_of(inputs[idx]))
+
+ broadcast_shape = fold_constant(broadcast_shape)
condition = _op.broadcast_to(inputs[0], broadcast_shape)
x = _op.broadcast_to(inputs[1], broadcast_shape)
@@ -1618,7 +1647,7 @@ class Expand(OnnxOpConverter):
@classmethod
def _impl_v8(cls, inputs, attr, params):
dtype = infer_type(inputs[1]).checked_type.dtype
- in_shape = _op.shape_of(inputs[0], dtype=dtype)
+ in_shape = shape_of(inputs[0], dtype=dtype)
shape = inputs[1]
# Currently 'op.broadcast_to' expect the rank of the given 'shape'
@@ -1667,7 +1696,7 @@ class Expand(OnnxOpConverter):
new_shape = _op.maximum(in_shape, shape)
return new_shape
- shape = expand_shape(in_shape, shape)
+ shape = fold_constant(expand_shape(in_shape, shape))
return _op.broadcast_to(inputs[0], shape=shape)
@@ -1942,10 +1971,9 @@ class Resize(OnnxOpConverter):
)
scale = inputs[1]
- size = _op.cast(_op.shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale
-
+ size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale
layout = "NCHW" # ONNX assumes NCHW layout
- out_size = _op.strided_slice(size, [2], [4])
+ out_size = fold_constant(_op.strided_slice(size, [2], [4]))
return _op.image.resize(inputs[0], out_size, layout, method, "asymmetric")
@classmethod
@@ -1969,7 +1997,7 @@ class Resize(OnnxOpConverter):
size = inputs[3]
else:
assert len(scale_shape) != 0, "One of scale or size should be passed."
- size = _op.cast(_op.shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale
+ size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale
coord_trans = attr.get("coordinate_transformation_mode")
if coord_trans in [b"pytorch_half_pixel", b"half_pixel"]:
@@ -1983,7 +2011,7 @@ class Resize(OnnxOpConverter):
"Unsupported coordinate_transformation_mode: {}".format(coord_trans)
)
layout = "NCHW" # ONNX assumes NCHW layout
- out_size = _op.strided_slice(size, [2], [4])
+ out_size = fold_constant(_op.strided_slice(size, [2], [4]))
return _op.image.resize(inputs[0], out_size, layout, method, coord_trans)
@@ -2152,7 +2180,9 @@ class Loop(OnnxOpConverter):
# Get the current graph proto and create a clone for the subgraph
graph_scope = GraphProto.current
- subgraph_scope = GraphProto(graph_scope._shape, graph_scope._dtype)
+ subgraph_scope = GraphProto(
+ graph_scope._shape, graph_scope._dtype, graph_scope._freeze_params
+ )
# Load nodes from outer graph into inner graph.
subgraph_scope._nodes = graph_scope._nodes.copy()
@@ -2246,7 +2276,7 @@ class Loop(OnnxOpConverter):
expand_scan = _op.expand_dims(new_scan, axis=0)
# For non scalar outputs we need to broadcast the initial value.
if rank > 0:
- new_scan_shape = _op.shape_of(new_scan, dtype=iter_dtype)
+ new_scan_shape = shape_of(new_scan, dtype=iter_dtype)
scan_broadcast = _op.concatenate(
[_op.reshape(loop_count, [1]), new_scan_shape], axis=0
)
@@ -2264,7 +2294,7 @@ class Loop(OnnxOpConverter):
return [loop_count, max_count, new_cond] + new_loop_vars + combined_scan_outputs
# Create the loop function.
- loop = _loops.while_loop(cond_fn, loop_vars + scan_output_vars, body_fn)
+ loop = fold_constant(_loops.while_loop(cond_fn, loop_vars + scan_output_vars, body_fn))
# Now need to run initial values through the graph.
init_count = _expr.const(0, dtype=iter_dtype)
@@ -2287,6 +2317,7 @@ class Loop(OnnxOpConverter):
# Update outer graph with constants found in the subgraph.
free_vars = analysis.free_vars(loop)
graph_scope._params.update(subgraph_scope._params)
+ graph_scope._nodes.update(subgraph_scope._nodes)
for var in free_vars:
graph_scope._nodes.update({var.name_hint: var})
return outputs
@@ -2307,9 +2338,9 @@ class If(OnnxOpConverter):
# Create graph converters for both branches.
graph_scope = GraphProto.current
- then_graph = GraphProto(graph_scope._shape, graph_scope._dtype)
+ then_graph = GraphProto(graph_scope._shape, graph_scope._dtype, graph_scope._freeze_params)
then_graph._nodes = graph_scope._nodes.copy()
- else_graph = GraphProto(graph_scope._shape, graph_scope._dtype)
+ else_graph = GraphProto(graph_scope._shape, graph_scope._dtype, graph_scope._freeze_params)
else_graph._nodes = graph_scope._nodes.copy()
# Convert each branch to a relay expression.
@@ -2320,10 +2351,12 @@ class If(OnnxOpConverter):
# Add constants from both branches to parent graph.
graph_scope._params.update(then_graph._params)
+ graph_scope._nodes.update(then_graph._nodes)
then_free_vars = analysis.free_vars(then_expr)
for var in then_free_vars:
graph_scope._nodes.update({var.name_hint: var})
graph_scope._params.update(else_graph._params)
+ graph_scope._nodes.update(else_graph._nodes)
else_free_vars = analysis.free_vars(else_expr)
for var in else_free_vars:
graph_scope._nodes.update({var.name_hint: var})
@@ -2468,9 +2501,9 @@ class NonMaxSuppression(OnnxOpConverter):
# partially prepare ONNX output format by labeling batch_num, class_id
nms_padded_out = _op.expand_dims(nms_ret[0], -1, 1)
batch_num = _op.expand_dims(_op.arange(_op.squeeze(B, [0]), dtype="int64"), -1, 1)
- batch_num = _op.broadcast_to(batch_num, _op.shape_of(nms_ret[0], dtype="int64"))
+ batch_num = _op.broadcast_to(batch_num, shape_of(nms_ret[0], dtype="int64"))
batch_num = _op.expand_dims(batch_num, -1, 1)
- class_num = _op.broadcast_to(i, _op.shape_of(nms_padded_out, dtype="int64"))
+ class_num = _op.broadcast_to(i, shape_of(nms_padded_out, dtype="int64"))
new_onnx_out = _op.concatenate(
[batch_num, class_num, _op.cast(nms_padded_out, "int64")], -1
)
@@ -2570,7 +2603,7 @@ class NonMaxSuppression(OnnxOpConverter):
)
# Call the first loop, perform NMS
- B, C, S = _op.split(_op.shape_of(scores, dtype="int64"), 3)
+ B, C, S = _op.split(shape_of(scores, dtype="int64"), 3)
init_count = _op.const(np.array([0]), dtype="int64")
init_onnx_out = _op.const([1], dtype="int64")
init_onnx_out = _op.broadcast_to(init_onnx_out, _op.concatenate([B, one, S, three], 0))
@@ -2617,6 +2650,7 @@ def _get_convert_map(opset):
"ThresholdedRelu": ThresholdedRelu.get_converter(opset),
"ScaledTanh": ScaledTanh.get_converter(opset),
"ParametricSoftplus": ParametricSoftPlus.get_converter(opset),
+ "Constant": Constant.get_converter(opset),
"ConstantOfShape": ConstantOfShape.get_converter(opset),
# 'GivenTensorFill'
"FC": AttrCvt("dense", ignores=["axis", "axis_w"]),
@@ -2776,11 +2810,19 @@ class GraphProto:
dtype : str or dict of str to str
The input types to the graph
+
+ freeze_params: bool
+ If this parameter is true, the importer will take any provided
+ ONNX input values (weights, shapes, etc.) and embed them into the Relay model
+ as Constants instead of variables. This allows more aggressive optimization
+ at compile time and helps make models static when certain inputs represent
+ attributes that Relay would traditionally consider compile-time constants.
+
"""
current = None
- def __init__(self, shape, dtype):
+ def __init__(self, shape, dtype, freeze_params=False):
self._nodes = {}
self._params = {}
self._inputs = {}
@@ -2790,6 +2832,7 @@ class GraphProto:
self._shape = shape if shape else {}
self._dtype = dtype
self.opset = None
+ self._freeze_params = freeze_params
def __enter__(self):
self._old_manager = GraphProto.current
@@ -2808,7 +2851,7 @@ class GraphProto:
fn = _function.Function(analysis.free_vars(body), body)
return fn, {}
- def from_onnx(self, graph, opset, freeze_params=False, get_output_expr=False):
+ def from_onnx(self, graph, opset, get_output_expr=False):
"""Construct Relay expression from ONNX graph.
Onnx graph is a python protobuf object.
@@ -2825,13 +2868,6 @@ class GraphProto:
opset : opset version
- freeze_params: bool
- If this parameter is true, the importer will take any provided
- onnx input values (weights, shapes, etc) and embed them into the relay model
- as Constants instead of variables. This allows more aggressive optimizations
- at compile time and helps in making models static if certain inputs represent
- attributes relay would traditionally consider compile-time constants.
-
get_output_expr: bool
If set to true, this conversion will return each output expression rather
than a packaged module. This can be useful when converting subgraphs to
@@ -2850,12 +2886,16 @@ class GraphProto:
for init_tensor in graph.initializer:
if not init_tensor.name.strip():
raise ValueError("Tensor's name is required.")
- self._params[init_tensor.name] = self._parse_array(init_tensor)
- self._nodes[init_tensor.name] = new_var(
- init_tensor.name,
- shape=self._params[init_tensor.name].shape,
- dtype=self._params[init_tensor.name].dtype,
- )
+ array = self._parse_array(init_tensor)
+ if self._freeze_params:
+ self._nodes[init_tensor.name] = _expr.const(array)
+ else:
+ self._params[init_tensor.name] = array
+ self._nodes[init_tensor.name] = new_var(
+ init_tensor.name,
+ shape=self._params[init_tensor.name].shape,
+ dtype=self._params[init_tensor.name].dtype,
+ )
for i in graph.input:
# from onnx v0.2, GraphProto.input has type ValueInfoProto,
# and the name is 'i.name'
@@ -2867,6 +2907,8 @@ class GraphProto:
self._nodes[i_name] = new_var(
i_name, shape=self._params[i_name].shape, dtype=self._params[i_name].dtype
)
+ elif i_name in self._nodes:
+ continue
else:
self._num_input += 1
if i_name in self._shape:
@@ -2909,37 +2951,28 @@ class GraphProto:
for i in node.input:
if i != "":
inputs[i] = self._nodes[self._renames.get(i, i)]
- if op_name == "Constant":
- t_proto = self._parse_attr(node.attribute)["value"]
- self._num_param += 1
- # We should convert scalar integers to int32, to normalize.
- array = self._parse_array(t_proto)
- self._params[node.output[0]] = array
- self._nodes[node.output[0]] = new_var(
- node.output[0], shape=list(t_proto.dims), dtype=array.dtype
- )
+ i_name = self._parse_value_proto(node)
+ node_output = self._fix_outputs(op_name, node.output)
+ attr["tvm_custom"] = {}
+ attr["tvm_custom"]["name"] = i_name
+ attr["tvm_custom"]["num_outputs"] = len(node_output)
+
+ op = self._convert_operator(op_name, inputs, attr, opset)
+ if not isinstance(op, _expr.TupleWrapper):
+ outputs_num = 1
else:
- i_name = self._parse_value_proto(node)
- node_output = self._fix_outputs(op_name, node.output)
- attr["tvm_custom"] = {}
- attr["tvm_custom"]["name"] = i_name
- attr["tvm_custom"]["num_outputs"] = len(node_output)
-
- op = self._convert_operator(op_name, inputs, attr, opset)
- if not isinstance(op, _expr.TupleWrapper):
- outputs_num = 1
- else:
- outputs_num = len(op)
- assert (
- len(node_output) == outputs_num
- ), "Number of output mismatch {} vs {} in {}.".format(
- len(node_output), outputs_num, op_name
- )
- if outputs_num == 1:
- self._nodes[node_output[0]] = op
- else:
- for k, i in zip(list(node_output), range(len(node_output))):
- self._nodes[k] = op[i]
+ outputs_num = len(op)
+ assert (
+ len(node_output) == outputs_num
+ ), "Number of output mismatch {} vs {} in {}.".format(
+ len(node_output), outputs_num, op_name
+ )
+ if outputs_num == 1:
+ self._nodes[node_output[0]] = fold_constant(op)
+ else:
+ op = _expr.TupleWrapper(fold_constant(op.astuple()), len(op))
+ for k, i in zip(list(node_output), range(len(node_output))):
+ self._nodes[k] = op[i]
# now return the outputs
outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
@@ -2957,9 +2990,6 @@ class GraphProto:
self._inputs[i_name] = self._nodes[i_name]
# Create a function from our output expression and all input variables.
func = _function.Function([v for k, v in self._inputs.items()], outputs)
- if freeze_params:
- func, params = self.freeze(func, self._params)
- return IRModule.from_expr(func), params
return IRModule.from_expr(func), self._params
def _parse_value_proto(self, value_proto):
@@ -3100,7 +3130,7 @@ def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=Fals
warnings.warn(str(e))
except ImportError:
pass
- g = GraphProto(shape, dtype)
+ g = GraphProto(shape, dtype, freeze_params)
graph = model.graph
if opset is None:
try:
@@ -3109,5 +3139,5 @@ def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=Fals
opset = 1
# Use the graph proto as a scope so that ops can access other nodes if needed.
with g:
- mod, params = g.from_onnx(graph, opset, freeze_params)
+ mod, params = g.from_onnx(graph, opset)
return mod, params
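With these changes, freeze_params is supplied once to the public entry point and plumbed through the GraphProto constructor (including the Loop/If subgraphs above) instead of through GraphProto.from_onnx. A usage sketch, where onnx_model is a placeholder for a loaded model:

from tvm import relay

mod, params = relay.frontend.from_onnx(
    onnx_model, shape={"data": (1, 3, 224, 224)}, freeze_params=True
)
# With freeze_params=True, initializers are embedded as relay Constants,
# so params is empty and more of the graph can fold statically.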
diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py
index a3f3a3e..153439b 100644
--- a/python/tvm/relay/op/image/image.py
+++ b/python/tvm/relay/op/image/image.py
@@ -17,7 +17,7 @@
"""Image operations."""
from . import _make
from ..dyn.image import _make as _dyn_make
-from ...expr import Expr
+from ...expr import Expr, Constant
def resize(
@@ -66,6 +66,8 @@ def resize(
result: relay.Expr
The resized result.
"""
+ if isinstance(size, Constant):
+ size = list(size.data.asnumpy().astype("int32"))
if isinstance(size, Expr):
return _dyn_make.resize(
data, size, layout, method, coordinate_transformation_mode, out_dtype
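A minimal sketch of the effect of this image.py change, with illustrative values: a Constant size is unwrapped to a Python list, so the call lowers to the static resize rather than the dynamic variant.

from tvm import relay

x = relay.var("x", shape=(1, 3, 32, 32), dtype="float32")
size = relay.const([64, 64], dtype="int32")
y = relay.image.resize(x, size)  # static resize, not _dyn_make.resize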
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index 0c233a6..5135ac7 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -21,7 +21,7 @@ from tvm.relay import expr
from . import _make
from ..dyn.nn import _make as _dyn_make
from .utils import get_pad_tuple1d, get_pad_tuple2d, get_pad_tuple3d
-from ...expr import const, Expr
+from ...expr import const, Expr, Constant
def conv1d(
@@ -1279,6 +1279,10 @@ def upsampling(
result : tvm.relay.Expr
The computed result.
"""
+ if isinstance(scale_h, Constant):
+ scale_h = scale_h.data.asnumpy().item()
+ if isinstance(scale_w, Constant):
+ scale_w = scale_w.data.asnumpy().item()
if isinstance(scale_h, Expr) or isinstance(scale_w, Expr):
if not isinstance(scale_h, Expr):
scale_h = const(scale_h, "float64")
@@ -1338,6 +1342,12 @@ def upsampling3d(
result : tvm.relay.Expr
The computed result.
"""
+ if isinstance(scale_d, Constant):
+ scale_d = scale_d.data.asnumpy().item()
+ if isinstance(scale_h, Constant):
+ scale_h = scale_h.data.asnumpy().item()
+ if isinstance(scale_w, Constant):
+ scale_w = scale_w.data.asnumpy().item()
if isinstance(scale_d, Expr) or isinstance(scale_h, Expr) or isinstance(scale_w, Expr):
if not isinstance(scale_d, Expr):
scale_d = const(scale_d, "float64")
@@ -1596,6 +1606,10 @@ def pad(data, pad_width, pad_value=0, pad_mode="constant"):
result : tvm.relay.Expr
The computed result.
"""
+ if isinstance(pad_value, Constant):
+ pad_value = pad_value.data.asnumpy().item()
+ if isinstance(pad_width, Constant):
+ pad_width = [list(i) for i in pad_width.data.asnumpy()]
if isinstance(pad_width, Expr) or (isinstance(pad_value, Expr)):
if not isinstance(pad_width, Expr):
pad_width = const(list(pad_width))
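The nn.py changes follow the same pattern; for example (illustrative values), a Constant pad_width is turned back into nested Python lists so the static nn.pad path is taken:

from tvm import relay

x = relay.var("x", shape=(1, 3), dtype="float32")
pads = relay.const([[0, 0], [1, 1]], dtype="int64")
y = relay.nn.pad(x, pads)  # static pad, not the dynamic variant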
diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py
index 75e2987..5b01104 100644
--- a/python/tvm/relay/op/tensor.py
+++ b/python/tvm/relay/op/tensor.py
@@ -22,7 +22,7 @@ from tvm.te.hybrid import script
from . import _make
from .dyn import _make as _dyn_make
-from ..expr import Tuple, Expr
+from ..expr import Tuple, Expr, Constant
from . import op as reg
@@ -960,6 +960,8 @@ def zeros(shape, dtype):
result : relay.Expr
The resulting tensor.
"""
+ if isinstance(shape, Constant):
+ shape = list(shape.data.asnumpy())
if isinstance(shape, Expr):
return _dyn_make.zeros(shape, dtype)
if isinstance(shape, int):
@@ -1001,6 +1003,8 @@ def ones(shape, dtype):
result : relay.Expr
The resulting tensor.
"""
+ if isinstance(shape, Constant):
+ shape = list(shape.data.asnumpy())
if isinstance(shape, Expr):
return _dyn_make.ones(shape, dtype)
if isinstance(shape, int):
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index d42ef47..cda417c 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -21,7 +21,7 @@
from . import _make
from .dyn import _make as _dyn_make
from .tensor import shape_of
-from ..expr import TupleWrapper, const, Expr, Tuple
+from ..expr import TupleWrapper, const, Constant, Expr, Tuple
from ...tir import expr as _expr
@@ -216,6 +216,8 @@ def reshape(data, newshape):
result : relay.Expr
The reshaped result.
"""
+ if isinstance(newshape, Constant):
+ newshape = list(newshape.data.asnumpy())
if isinstance(newshape, Expr):
return _dyn_make.reshape(data, newshape)
if isinstance(newshape, int):
@@ -431,6 +433,8 @@ def full(fill_value, shape=(), dtype=""):
result : relay.Expr
The resulting tensor.
"""
+ if isinstance(shape, Constant):
+ shape = list(shape.data.asnumpy())
if isinstance(shape, Expr):
return _dyn_make.full(fill_value, shape, dtype)
if isinstance(shape, int):
@@ -614,6 +618,8 @@ def tile(data, reps):
data is promoted to be d-dimensional by prepending new axes.
If data.ndim >= d, reps is promoted to a.ndim by pre-pending 1's to it.
"""
+ if isinstance(reps, Constant):
+ reps = list(reps.data.asnumpy())
if isinstance(reps, Expr):
return _dyn_make.tile(data, reps)
return _make.tile(data, reps)
@@ -753,6 +759,8 @@ def broadcast_to(data, shape):
result : relay.Expr
The resulting tensor.
"""
+ if isinstance(shape, Constant):
+ shape = list(shape.data.asnumpy())
if isinstance(shape, Expr):
return _dyn_make.broadcast_to(data, shape)
if isinstance(shape, int):
@@ -884,6 +892,12 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"):
The computed result.
"""
strides = strides or [1]
+ if isinstance(begin, Constant):
+ begin = list(begin.data.asnumpy())
+ if isinstance(end, Constant):
+ end = list(end.data.asnumpy())
+ if isinstance(strides, Constant):
+ strides = list(strides.data.asnumpy())
if isinstance(begin, Expr) or isinstance(end, Expr) or isinstance(strides, Expr):
if isinstance(begin, (tuple, list)):
begin = const(list(begin))
@@ -1170,6 +1184,8 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype):
[0, 1, 0],
[0, 0, 1]]
"""
+ if isinstance(depth, Constant):
+ depth = depth.data.asnumpy().item()
if isinstance(depth, Expr):
return _dyn_make.one_hot(indices, on_value, off_value, depth, axis, dtype)
return _make.one_hot(indices, on_value, off_value, depth, axis, dtype)
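Likewise for the transform ops; a brief sketch with illustrative values: a Constant newshape is unwrapped before the Expr check, so reshape stays static instead of dispatching to dyn.reshape.

from tvm import relay

x = relay.var("x", shape=(2, 3, 4), dtype="float32")
y = relay.reshape(x, relay.const([6, 4], dtype="int64"))
# y is a static reshape to (6, 4), not a dynamic one.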
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index c6df8c1..f02f835 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -240,6 +240,23 @@ def LazyGradientInit():
return _ffi_api.LazyGradientInit()
+def FoldConstantExpr(expr, mod):
+ """Fold the constant expressions in a Relay program.
+ Parameters
+ ----------
+ expr: Expr
+ The expression to fold
+ mod: IRModule
+ The module the expr lives in (for global calls)
+
+ Returns
+ -------
+ new_expr: Expr
+ The expression after constant folding
+ """
+ return _ffi_api.FoldConstantExpr(expr, mod)
+
+
def FoldConstant():
"""Fold the constant expressions in a Relay program.
diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc
index 4454c9c..9416b0e 100644
--- a/src/relay/transforms/fold_constant.cc
+++ b/src/relay/transforms/fold_constant.cc
@@ -374,6 +374,8 @@ Expr FoldConstant(const Expr& expr, const IRModule& mod) {
return ConstantFolder(mod).Mutate(expr);
}
+TVM_REGISTER_GLOBAL("relay._transform.FoldConstantExpr").set_body_typed(FoldConstant);
+
namespace transform {
Pass FoldConstant() {
diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py
index 904576a..d43744b 100644
--- a/tests/python/relay/test_op_grad_level3.py
+++ b/tests/python/relay/test_op_grad_level3.py
@@ -146,7 +146,7 @@ def test_zeros_ones_grad_const_ints():
def test_zeros_ones_grad_const_expr():
# when shape is static (i.e. not an input), there is no gradient at all
- shape_const = relay.const(np.array([2, 3, 4]), dtype="int32")
+ shape_const = relay.const(np.array([2, 3, 4]), dtype="int32") * relay.const(1, dtype="int32")
static_ty = relay.TensorType([2, 3, 4], dtype="float32")
dyn_ty = relay.TensorType([relay.Any(), relay.Any(), relay.Any()], dtype="float32")
expected_ty_static = relay.TupleType([static_ty, relay.TupleType([])])
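For context on this test change (my reading of the diff, with an illustrative sketch): a bare Constant shape is now unwrapped by relay.zeros/ones into a static shape, so the test multiplies by const(1) to keep the shape a non-Constant Expr and continue exercising the dynamic path.

import numpy as np
from tvm import relay

# Now static: the Constant shape is unwrapped to [2, 3, 4].
z_static = relay.zeros(relay.const(np.array([2, 3, 4]), "int32"), "float32")
# Still dynamic: the product is a Call, not a Constant.
shape_expr = relay.const(np.array([2, 3, 4]), "int32") * relay.const(1, "int32")
z_dyn = relay.zeros(shape_expr, "float32")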