You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2023/01/12 04:28:14 UTC
[tvm] branch main updated: [Relay][Frontend] Span Filling ONNX (#13767)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 079876ed54 [Relay][Frontend] Span Filling ONNX (#13767)
079876ed54 is described below
commit 079876ed5412e7f8beb4c8e82248cca461781ee2
Author: Chun-I Tsai <qu...@quicinc.com>
AuthorDate: Thu Jan 12 12:28:07 2023 +0800
[Relay][Frontend] Span Filling ONNX (#13767)
- Use the node name as the span's source name when converting an
ONNX model.
- Assign a name derived from the op type to any node whose name is empty.
- Add a function that exports the ONNX model after conversion, so that
renamed nodes can be looked up.
- Add structural_equal comparisons between conversions with and without
span filling to the existing test cases.
- Add span test cases for frequently used conversions.
- Add a span test case for the parameter that exports the renamed model.
Co-authored-by: Joey Tsai <ch...@qti.qualcomm.com>
---
python/tvm/relay/frontend/onnx.py | 133 +++++++++-
tests/python/frontend/onnx/test_forward.py | 376 ++++++++++++++++++++++++++++-
2 files changed, 487 insertions(+), 22 deletions(-)
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 328b5d7bd8..3e4c9db2b0 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -57,6 +57,7 @@ from .common import (
shape_of,
try_resolve_var_to_const,
unbind,
+ set_span,
)
__all__ = ["from_onnx"]
@@ -556,6 +557,37 @@ def layer_norm(x, eps, gamma, beta):
return output
+def get_source_name(node, type_dict):
+    """A helper function to get source information of onnx nodes.
+
+    If ``node`` already has a name it is returned unchanged. Otherwise a name of the
+    form "<op_type>_<index>" is generated, where ``index`` counts the unnamed nodes of
+    that op type seen so far (tracked in ``type_dict``). The generated name is written
+    back to ``node.name`` so that revisiting the node, or exporting the model, yields
+    the same name.
+
+    Parameters
+    ----------
+    node : object with ``name`` and ``op_type`` attributes
+        The ONNX node whose source name is requested.
+    type_dict : Dict[str, int] or None
+        Per-op-type counter of generated names, updated in place. ``None`` is tolerated
+        (``GraphProto`` defaults ``op_type_dict`` to ``None``); in that case every unnamed
+        node of a given op type is numbered 0, so generated names are not unique.
+
+    Returns
+    -------
+    str
+        The existing or newly assigned node name.
+    """
+    if node.name:
+        return node.name
+    if type_dict is None:
+        # Avoid a TypeError on `op_type in None`; the counter is not persisted here.
+        type_dict = {}
+    # first unnamed node of an op type gets index 0, the next 1, ...
+    op_idx = type_dict.get(node.op_type, -1) + 1
+    type_dict[node.op_type] = op_idx
+    # rewrite name property in case any revisiting occurs to current node
+    node.name = "{}_{}".format(node.op_type, str(op_idx))
+    return node.name
+
+
+def get_source_name_from_parameter(expr, name_sep="."):
+    """A helper function to get source information of graph node from parameter.
+
+    Parameter spans are built as "<node_name><name_sep><param_name>"
+    (``make_parameter_span([node_name, param_name])``). For a ``Var`` the trailing
+    "<name_sep><name_hint>" is removed to recover the span of the op node,
+    e.g. "conv2d.w" -> "conv2d".
+
+    Parameters
+    ----------
+    expr : relay.Expr
+        The parameter expression whose span is inspected.
+    name_sep : str
+        Separator used when the parameter span was created.
+
+    Returns
+    -------
+    str or None
+        The derived source name, or ``None`` if ``expr`` carries no span.
+    """
+    if not expr.span:
+        return None
+    source_name = expr.span.source_name.name
+    # discard variable/parameter name to get span of op node
+    if isinstance(expr, _expr.Var):
+        postfix = f"{name_sep}{expr.name_hint}"
+        # Strip only a real suffix: Loop/Scan body vars are tagged "<name_hint>.<body_name>",
+        # and blindly slicing off len(postfix) characters would corrupt such names.
+        if source_name.endswith(postfix):
+            source_name = source_name[: -len(postfix)]
+    return source_name
+
+
+def make_parameter_span(source_name_list, name_sep="."):
+    """Join span components into a single source name.
+
+    Parameters
+    ----------
+    source_name_list : List[str]
+        Components to concatenate, e.g. ``[node_name, param_name]``.
+    name_sep : str
+        Separator placed between consecutive components.
+
+    Returns
+    -------
+    str
+        The joined source name, e.g. "conv2d.w".
+    """
+    joined_name = name_sep.join(source_name_list)
+    return joined_name
+
+
class OnnxOpConverter(object):
"""A helper class for holding onnx op converters."""
@@ -2712,10 +2744,13 @@ class EyeLike(OnnxOpConverter):
else:
dtype = get_type(dtype)
- in_shape = _op.shape_of(inputs[0])
+ node_source_name = get_source_name_from_parameter(inputs[0])
+ # since there exists multi-comsumer for the same expression
+ # invoke set_span here to prevent expr-rewritten in span-filling stage
+ in_shape = set_span(_op.shape_of(inputs[0]), node_source_name)
zeros = _op.zeros(in_shape, dtype)
- dim = _op.take(in_shape, _op.const(0))
+ dim = set_span(_op.take(in_shape, _op.const(0)), node_source_name)
indices = _op.arange(_op.const(0), dim, dtype="int32")
ones = _op.full(_op.const(1), _op.reshape(dim, (1,)), dtype=dtype)
@@ -4128,7 +4163,10 @@ class Loop(OnnxOpConverter):
# Get the current graph proto and create a clone for the subgraph
graph_scope = GraphProto.current
subgraph_scope = GraphProto(
- graph_scope._shape, graph_scope._dtype, graph_scope._freeze_params
+ graph_scope._shape,
+ graph_scope._dtype,
+ graph_scope._freeze_params,
+ graph_scope._op_type_dict,
)
# Load nodes from outer graph into inner graph.
subgraph_scope._nodes = graph_scope._nodes.copy()
@@ -4159,6 +4197,11 @@ class Loop(OnnxOpConverter):
]
loop_vars += [get_var(body.input[i + 2].name, v) for i, v in enumerate(loop_deps)]
loop_var_names = [v.name_hint for v in loop_vars]
+ # get span information of loop body
+ body_source_name = get_source_name(body, subgraph_scope._op_type_dict)
+ # set span to inputs of loop body
+ for i, v in enumerate(loop_vars):
+ loop_vars[i] = set_span(v, make_parameter_span([v.name_hint, body_source_name]))
num_scan_outputs = len(body.output) - (1 + num_deps)
@@ -4287,9 +4330,19 @@ class If(OnnxOpConverter):
# Create graph converters for both branches.
graph_scope = GraphProto.current
- then_graph = GraphProto(graph_scope._shape, graph_scope._dtype, graph_scope._freeze_params)
+ then_graph = GraphProto(
+ graph_scope._shape,
+ graph_scope._dtype,
+ graph_scope._freeze_params,
+ graph_scope._op_type_dict,
+ )
then_graph._nodes = graph_scope._nodes.copy()
- else_graph = GraphProto(graph_scope._shape, graph_scope._dtype, graph_scope._freeze_params)
+ else_graph = GraphProto(
+ graph_scope._shape,
+ graph_scope._dtype,
+ graph_scope._freeze_params,
+ graph_scope._op_type_dict,
+ )
else_graph._nodes = graph_scope._nodes.copy()
# Convert each branch to a relay expression.
@@ -4386,7 +4439,10 @@ class Scan(OnnxOpConverter):
# Get the current graph proto and create a clone for the subgraph
graph_scope = GraphProto.current
subgraph_scope = GraphProto(
- graph_scope._shape, graph_scope._dtype, graph_scope._freeze_params
+ graph_scope._shape,
+ graph_scope._dtype,
+ graph_scope._freeze_params,
+ graph_scope._op_type_dict,
)
# Load nodes from outer graph into inner graph.
subgraph_scope._nodes = graph_scope._nodes.copy()
@@ -4440,6 +4496,12 @@ class Scan(OnnxOpConverter):
loop_vars += [
get_var(body.input[i].name, v) for i, v in enumerate(inputs) if i < num_state_inputs
]
+ # get span information of scan body
+ body_source_name = get_source_name(body, subgraph_scope._op_type_dict)
+ # set span to inputs of scan body
+ for i, v in enumerate(loop_vars):
+ loop_vars[i] = set_span(v, make_parameter_span([v.name_hint, body_source_name]))
+
loop_vars += scan_output_vars
body_input_var_names = ["iter"] + [body.input[i].name for i in range(len(body.input))]
@@ -6197,11 +6259,16 @@ class GraphProto:
at compile time and helps in making models static if certain inputs represent
attributes relay would traditionally consider compile-time constants.
+ op_type_dict: Dict[str, int]
+ Dictionary for span filling usage. If the name property of op was not set
+ op_type_dict will provide an alternative by combining literal op type with
+ its presenting order
+
"""
current = None
- def __init__(self, shape, dtype, freeze_params=False):
+ def __init__(self, shape, dtype, freeze_params=False, op_type_dict=None):
self._nodes = {}
self._params = {}
self._inputs = {}
@@ -6213,6 +6280,7 @@ class GraphProto:
self._dtype = dtype
self.opset = None
self._freeze_params = freeze_params
+ self._op_type_dict = op_type_dict
def __enter__(self):
self._old_manager = GraphProto.current
@@ -6365,6 +6433,9 @@ class GraphProto:
for node in graph.node:
op_name = node.op_type
attr = self._parse_attr(node.attribute)
+ # Fill in span of inputs
+ node_source_name = get_source_name(node, self._op_type_dict)
+ self._set_parameter_span(node, node_source_name)
# Create and populate input list.
inputs = onnx_input()
for i in node.input:
@@ -6389,6 +6460,8 @@ class GraphProto:
else:
op = _expr.TupleWrapper(fold_constant(op.astuple()), len(op))
+ op = set_span(op, node_source_name)
+
if outputs_num > 1:
# ONNX supports optional outputs for some nodes.
# This block searches for missing outputs in the ONNX graph
@@ -6427,6 +6500,19 @@ class GraphProto:
for k, i in zip(list(node_output), range(len(node_output))):
self._nodes[k] = op[i]
+    def _set_parameter_span(self, node, node_source_name):
+        """Tag the Var/Constant inputs of ``node`` with a parameter span.
+
+        Every input that resolves (through ``self._renames``) to a ``relay.Var`` or
+        ``relay.Constant`` in ``self._nodes`` is replaced by a copy whose span source name
+        is "<node_source_name>.<input_name>". Entries in ``self._inputs`` are updated as
+        well so they stay in sync with ``self._nodes``.
+        """
+        for input_name in node.input:
+            # empty names are skipped (ONNX uses them for omitted optional inputs)
+            if input_name == "":
+                continue
+            name = self._renames.get(input_name, input_name)
+            expr = self._nodes.get(name)
+            # relay.Var -> graph inputs / params
+            # relay.Constant -> frozen params / built-in constants
+            if not isinstance(expr, (relay.Var, relay.Constant)):
+                continue
+            tagged_expr = set_span(expr, make_parameter_span([node_source_name, name]))
+            self._nodes[name] = tagged_expr
+            if name in self._inputs:
+                self._inputs[name] = tagged_expr
+
def _parse_value_proto(self, value_proto):
"""Parse ValueProto or raw str."""
try:
@@ -6506,8 +6592,28 @@ class GraphProto:
return outputs
+def export_model(location, graph, **model_kwargs):
+    """Convert the graph to an onnx model and export it to the location.
+
+    Parameters
+    ----------
+    location : str
+        Directory to write into; it is created (with parents) if it does not exist.
+    graph : onnx.GraphProto
+        The graph to wrap with ``onnx.helper.make_model`` and save.
+    **model_kwargs
+        Optional keyword arguments forwarded to ``onnx.helper.make_model``
+        (e.g. ``opset_imports``). Without them, make_model applies its own defaults,
+        which need not match the opset of the model the graph came from.
+
+    Returns
+    -------
+    str
+        Path of the saved file, "tvm_exported_model_<%m_%d_%Y_%H_%M_%S>.onnx".
+        The timestamp has one-second resolution, so exports made within the same
+        second overwrite each other.
+    """
+    import datetime
+    import os
+
+    from onnx import save, helper
+
+    # exist_ok removes the race between an existence check and the creation
+    os.makedirs(location, exist_ok=True)
+    time_stamp = datetime.datetime.now().strftime("%m_%d_%Y_%H_%M_%S")
+    model = helper.make_model(graph, **model_kwargs)
+    model_path = os.path.join(location, "tvm_exported_model_{}.onnx".format(time_stamp))
+    save(model, model_path)
+    return model_path
+
+
def from_onnx(
- model, shape=None, dtype="float32", opset=None, freeze_params=True, convert_config=None
+ model,
+ shape=None,
+ dtype="float32",
+ opset=None,
+ freeze_params=True,
+ convert_config=None,
+ export_node_renamed_model_path=None,
):
"""Convert a ONNX model into an equivalent Relay Function.
@@ -6553,6 +6659,12 @@ def from_onnx(
True to convert qualified onnx `matmul` to `nn.batch_matmul` strict to NT format
(transpose_a=False, transpose_b=True).
+ export_node_renamed_model_path : str, optional
+ Export the node renamed onnx model to the path.
+ Some models do not contain names in their nodes. During the conversion, if names of nodes
+ are empty, new names will be assigned based on their op types. The exported model can be the
+ reference to spans.
+
Returns
-------
mod : tvm.IRModule
@@ -6577,7 +6689,7 @@ def from_onnx(
warnings.warn(str(e))
except ImportError:
pass
- g = GraphProto(shape, dtype, freeze_params)
+ g = GraphProto(shape, dtype, freeze_params, op_type_dict={})
graph = model.graph
try:
@@ -6607,6 +6719,9 @@ def from_onnx(
with g:
mod, params = g.from_onnx(graph, opset)
+ if export_node_renamed_model_path:
+ export_model(export_node_renamed_model_path, graph)
+
if freeze_params:
mod = relay.transform.DynamicToStatic()(mod)
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 09206b341d..c016078f8f 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -34,7 +34,10 @@ import tvm
import tvm.testing
import tvm.topi.testing
from tvm import relay
-from tvm.contrib import graph_executor
+from tvm.contrib import graph_executor, utils
+from tvm.relay.frontend.common import infer_type
+from tvm.relay.build_module import bind_params_by_name
+from relay.utils.tag_span import _create_span, _set_span, _verify_structural_equal_with_span
import onnx
import onnxruntime.backend
@@ -81,18 +84,31 @@ def get_tvm_output_with_vm(
opset=None,
freeze_params=False,
convert_config=None,
+ validate_structural_equal=True,
):
"""Generic function to execute and get tvm output with vm executor"""
if not isinstance(input_data, list):
input_data = [input_data]
_, shape_dict = get_input_data_shape_dict(graph_def, input_data)
- mod, params = relay.frontend.from_onnx(
- graph_def,
- shape_dict,
- opset=opset,
- freeze_params=freeze_params,
- convert_config=convert_config,
- )
+
+ with tvm.testing.disable_span_filling():
+ mod, params = relay.frontend.from_onnx(
+ graph_def,
+ shape_dict,
+ opset=opset,
+ freeze_params=freeze_params,
+ convert_config=convert_config,
+ )
+ if validate_structural_equal:
+ with tvm.testing.enable_span_filling():
+ mod_with_span, _ = relay.frontend.from_onnx(
+ graph_def,
+ shape_dict,
+ opset=opset,
+ freeze_params=freeze_params,
+ convert_config=convert_config,
+ )
+ assert tvm.ir.structural_equal(mod, mod_with_span)
result = relay.create_executor("vm", mod=mod, device=dev, target=target).evaluate()(
*input_data, **params
@@ -6667,7 +6683,13 @@ def test_random_uniform(target, dev):
outputs=[helper.make_tensor_value_info("out", ONNX_DTYPE, shape)],
)
model = helper.make_model(graph, producer_name="random_uniform_test")
- return get_tvm_output_with_vm(model, [], target=target, dev=dev)
+ return get_tvm_output_with_vm(
+ model,
+ [],
+ target=target,
+ dev=dev,
+ validate_structural_equal=(seed is not None),
+ )
# Check that function runs and produces proper shape.
vals = get_random_uniform([10], dtype="float32")
@@ -6733,7 +6755,13 @@ def test_random_uniform_like(target, dev):
outputs=[helper.make_tensor_value_info("out", ONNX_DTYPE, shape)],
)
model = helper.make_model(graph, producer_name="random_uniform_like_test")
- return get_tvm_output_with_vm(model, [input_], target=target, dev=dev)
+ return get_tvm_output_with_vm(
+ model,
+ [input_],
+ target=target,
+ dev=dev,
+ validate_structural_equal=(seed is not None),
+ )
# Check that function runs and produces proper shape and dtype.
shape = [10]
@@ -6797,7 +6825,13 @@ def test_random_normal(target, dev):
outputs=[helper.make_tensor_value_info("out", ONNX_DTYPE, shape)],
)
model = helper.make_model(graph, producer_name="random_normal_test")
- return get_tvm_output_with_vm(model, [], target=target, dev=dev)
+ return get_tvm_output_with_vm(
+ model,
+ [],
+ target=target,
+ dev=dev,
+ validate_structural_equal=(seed is not None),
+ )
# Test N-D tensor generation.
vals = get_random_normal([1, 3, 100, 100], dtype="float32")
@@ -6837,7 +6871,13 @@ def test_random_normal_like(target, dev):
outputs=[helper.make_tensor_value_info("out", ONNX_DTYPE, shape)],
)
model = helper.make_model(graph, producer_name="random_normal_like_test")
- return get_tvm_output_with_vm(model, [input_], target=target, dev=dev)
+ return get_tvm_output_with_vm(
+ model,
+ [input_],
+ target=target,
+ dev=dev,
+ validate_structural_equal=(seed is not None),
+ )
# Test N-D tensor generation.
shape = [1, 3, 100, 100]
@@ -6875,7 +6915,13 @@ def test_multinomial(target, dev):
outputs=[helper.make_tensor_value_info("out", OUT_DTYPE, shape)],
)
model = helper.make_model(graph, producer_name="multinomial_test")
- return get_tvm_output_with_vm(model, [input], target=target, dev=dev)
+ return get_tvm_output_with_vm(
+ model,
+ [input],
+ target=target,
+ dev=dev,
+ validate_structural_equal=(seed is not None),
+ )
# Test N-D tensor generation.
shape = [3]
@@ -7348,5 +7394,309 @@ def test_sequence(target, dev):
verify_sequence_ops((3, 3, 3, 3), 4, axis=2, new_axis=1)
+def test_exporting_node_renamed_model():
+    """test exporting model when export_node_renamed_model_path is set"""
+
+    a_name, a_shape = "a", (4, 3)
+    b_name, b_shape = "b", (3, 4)
+    out_name, out_shape = "out", [a_shape[0], b_shape[1]]
+    # Keep a reference to the TempDirectory: it removes its directory when it goes out of
+    # scope, so `utils.tempdir().path` would yield an already-deleted path that the export
+    # then re-creates and nothing cleans up afterwards.
+    tempdir_handle = utils.tempdir()
+    temp_dir = tempdir_handle.path
+
+    # model definition; the MatMul node is deliberately left unnamed
+    mul_node = helper.make_node("MatMul", [a_name, b_name], [out_name])
+    graph = helper.make_graph(
+        [mul_node],
+        "matmul_test",
+        inputs=[
+            helper.make_tensor_value_info(a_name, TensorProto.FLOAT, a_shape),
+            helper.make_tensor_value_info(b_name, TensorProto.FLOAT, b_shape),
+        ],
+        outputs=[helper.make_tensor_value_info(out_name, TensorProto.FLOAT, out_shape)],
+    )
+    model = helper.make_model(graph, producer_name="matmul_test")
+
+    # get frontend model
+    shape_dict = {a_name: a_shape, b_name: b_shape}
+    _, _ = relay.frontend.from_onnx(model, shape_dict, export_node_renamed_model_path=temp_dir)
+
+    # exactly one model is exported into the fresh directory
+    exported_files = os.listdir(temp_dir)
+    assert len(exported_files) == 1
+    exported_model_name = exported_files[0]
+    assert "tvm_exported_model_" in exported_model_name
+
+    # the unnamed node was renamed after its op type with index 0
+    exported_model = onnx.load(os.path.join(temp_dir, exported_model_name))
+    assert exported_model.graph.node[0].name == "MatMul_0"
+
+
+class TestSetSpan:
+ """Check the spans attached by the ONNX frontend.
+
+ Each test builds a small ONNX model and converts it (``_res``), then compares the result
+ with a hand-crafted relay function whose spans are set explicitly (``_golden``).
+ """
+
+ def _verify(self, res_fptr, golden_fptr):
+ """Run ``res_fptr`` with span filling enabled and disabled.
+
+ Both results must be structurally equal, and the span-filled one must also match
+ ``golden_fptr()`` under ``_verify_structural_equal_with_span``.
+ """
+ with tvm.testing.enable_span_filling():
+ with_span = res_fptr()
+ with tvm.testing.disable_span_filling():
+ without_span = res_fptr()
+ assert tvm.ir.structural_equal(with_span, without_span)
+ _verify_structural_equal_with_span(with_span, golden_fptr())
+
+ def test_conv2d_bias_add_span(self):
+ """Conv with bias: conv2d and bias_add carry the node name; x, w and b carry
+ "<node name>.<input name>"."""
+ padding = [0, 0, 0, 0]
+ k_shape = [7, 7]
+ y_shape, y_name = [1, 6, 10, 10], "y"
+ x_shape, x_name = [1, 3, 10, 10], "x"
+ b_shape, b_name = [6], "b"
+ b_val = np.random.random(b_shape).astype(np.float32)
+ w_shape, w_name = [6, 3, 7, 7], "w"
+ w_val = np.random.random(w_shape).astype(np.float32)
+ group, strides, dilations = 1, [1, 1], [1, 1]
+ conv_name = "conv2d"
+
+ def _res():
+ # model definition
+ node = helper.make_node(
+ "Conv",
+ inputs=[x_name, w_name, b_name],
+ outputs=[y_name],
+ kernel_shape=k_shape,
+ strides=strides,
+ dilations=dilations,
+ group=group,
+ pads=padding,
+ name=conv_name,
+ )
+ # w and b are initializers rather than graph inputs
+ graph = helper.make_graph(
+ [node],
+ "conv_test",
+ inputs=[helper.make_tensor_value_info(x_name, TensorProto.FLOAT, x_shape)],
+ outputs=[helper.make_tensor_value_info(y_name, TensorProto.FLOAT, y_shape)],
+ initializer=[
+ helper.make_tensor(
+ w_name,
+ TensorProto.FLOAT,
+ dims=w_shape,
+ vals=w_val.flatten(),
+ ),
+ helper.make_tensor(
+ b_name,
+ TensorProto.FLOAT,
+ dims=b_shape,
+ vals=b_val.flatten(),
+ ),
+ ],
+ )
+ model = helper.make_model(graph, producer_name="conv_test")
+
+ # get frontend model
+ shape_dict = {x_name: x_shape}
+ mod, _ = relay.frontend.from_onnx(model, shape_dict)
+ return mod["main"]
+
+ def _golden():
+ conv_si = conv_name
+ # parameters are tagged "<node name>.<input name>"
+ x = relay.var(
+ x_name,
+ shape=tuple(x_shape),
+ span=_create_span(f"{conv_si}.{x_name}"),
+ )
+ # the initializers are expected as constants (from_onnx freezes params by default)
+ conv_weight = relay.const(
+ w_val,
+ span=_create_span(f"{conv_si}.{w_name}"),
+ )
+ conv_bias = relay.const(
+ b_val,
+ span=_create_span(f"{conv_si}.{b_name}"),
+ )
+ # every op emitted for the Conv node carries the node name
+ conv_out = _set_span(
+ relay.nn.conv2d(
+ x,
+ conv_weight,
+ padding=[0] * 4,
+ channels=y_shape[1],
+ kernel_size=k_shape,
+ ),
+ conv_si,
+ )
+ bias_out = _set_span(relay.nn.bias_add(conv_out, conv_bias), conv_si)
+ return infer_type(relay.Function([x], bias_out))
+
+ self._verify(_res, _golden)
+
+ def test_batchnorm_span(self):
+ """BatchNormalization: the batch_norm call, its TupleGetItem and all five inputs
+ carry spans derived from the node name."""
+ input_name, in_shape = "x", [1, 16, 10, 10]
+ bn_name = "bn"
+ output_name = "y"
+ scale_name = "scale"
+ bias_name = "b"
+ mean_name = "mean"
+ var_name = "var"
+
+ def _res():
+ # model definition
+ batchnorm = onnx.helper.make_node(
+ "BatchNormalization",
+ inputs=[input_name, scale_name, bias_name, mean_name, var_name],
+ outputs=[output_name],
+ name=bn_name,
+ )
+ # all five operands are graph inputs, so they stay relay vars
+ graph = helper.make_graph(
+ [batchnorm],
+ "batchnorm_test",
+ inputs=[
+ helper.make_tensor_value_info(input_name, TensorProto.FLOAT, in_shape),
+ helper.make_tensor_value_info(scale_name, TensorProto.FLOAT, [in_shape[1]]),
+ helper.make_tensor_value_info(bias_name, TensorProto.FLOAT, [in_shape[1]]),
+ helper.make_tensor_value_info(mean_name, TensorProto.FLOAT, [in_shape[1]]),
+ helper.make_tensor_value_info(var_name, TensorProto.FLOAT, [in_shape[1]]),
+ ],
+ outputs=[helper.make_tensor_value_info(output_name, TensorProto.FLOAT, in_shape)],
+ )
+ model = helper.make_model(graph, producer_name="batchnorm_test")
+
+ # get frontend model
+ shape_dict = {input_name: in_shape}
+ mod, _ = relay.frontend.from_onnx(model, shape_dict)
+ return mod["main"]
+
+ def _golden():
+ bn_si = bn_name
+ x = relay.var(
+ input_name,
+ shape=tuple(in_shape),
+ span=_create_span(f"{bn_si}.{input_name}"),
+ )
+ bn_scale = relay.var(
+ scale_name,
+ shape=(in_shape[1],),
+ span=_create_span(f"{bn_si}.{scale_name}"),
+ )
+ bn_bias = relay.var(
+ bias_name,
+ shape=(in_shape[1],),
+ span=_create_span(f"{bn_si}.{bias_name}"),
+ )
+ bn_rm = relay.var(
+ mean_name,
+ shape=(in_shape[1],),
+ span=_create_span(f"{bn_si}.{mean_name}"),
+ )
+ bn_rv = relay.var(
+ var_name,
+ shape=(in_shape[1],),
+ span=_create_span(f"{bn_si}.{var_name}"),
+ )
+ bn_out = _set_span(
+ relay.nn.batch_norm(x, bn_scale, bn_bias, bn_rm, bn_rv),
+ bn_si,
+ )
+ # only output 0 of batch_norm is used as the function result; it is tagged too
+ bn_tuple_get_item = _set_span(relay.TupleGetItem(bn_out.tuple_value, 0), bn_si)
+ return infer_type(
+ relay.Function([x, bn_scale, bn_bias, bn_rm, bn_rv], bn_tuple_get_item)
+ )
+
+ self._verify(_res, _golden)
+
+ def test_reshape_span(self):
+ """Reshape whose target shape comes from a Constant node: the golden IR holds only
+ the data input and the reshape, both tagged from the Reshape node name."""
+ input_shape = [2, 1, 10, 1, 10]
+ new_shape = [2, 1, 10, 10]
+ input_name = "in"
+ output_name = "out"
+ ref_name = "ref_in"
+ const_name = "const"
+ reshape_name = "reshape"
+
+ def _res():
+ # model definition
+ ref_array = np.array(new_shape)
+ ref_node = helper.make_node(
+ "Constant",
+ inputs=[],
+ outputs=[ref_name],
+ value=helper.make_tensor(
+ name="const_tensor",
+ data_type=TensorProto.INT32,
+ dims=ref_array.shape,
+ vals=ref_array.flatten().astype(int),
+ ),
+ name=const_name,
+ )
+ reshape_node = helper.make_node(
+ "Reshape",
+ [input_name, ref_name],
+ [output_name],
+ name=reshape_name,
+ )
+ graph = helper.make_graph(
+ [ref_node, reshape_node],
+ "reshape_test",
+ inputs=[helper.make_tensor_value_info(input_name, TensorProto.FLOAT, input_shape)],
+ outputs=[helper.make_tensor_value_info(output_name, TensorProto.FLOAT, new_shape)],
+ )
+ model = helper.make_model(graph, producer_name="reshape_test")
+
+ # get frontend model
+ shape_dict = {input_name: input_shape}
+ mod, _ = relay.frontend.from_onnx(model, shape_dict)
+ return mod["main"]
+
+ def _golden():
+ reshape_si = reshape_name
+ x = relay.var(
+ input_name,
+ shape=tuple(input_shape),
+ span=_create_span(f"{reshape_si}.{input_name}"),
+ )
+ # the Constant node's value appears here only as the static newshape
+ reshape_out = _set_span(
+ relay.reshape(x, newshape=new_shape),
+ reshape_si,
+ )
+ return infer_type(relay.Function([x], reshape_out))
+
+ self._verify(_res, _golden)
+
+ def test_matmul_span(self):
+ """MatMul: the expected transpose + dense pair both carry the MatMul node name."""
+ a_name, a_shape = "a", (4, 3)
+ b_name, b_shape = "b", (3, 4)
+ out_name, out_shape = "out", [a_shape[0], b_shape[1]]
+ matmul_name = "matmul"
+
+ def _res():
+ # model definition
+ mul_node = helper.make_node("MatMul", [a_name, b_name], [out_name], name=matmul_name)
+ graph = helper.make_graph(
+ [mul_node],
+ "matmul_test",
+ inputs=[
+ helper.make_tensor_value_info(a_name, TensorProto.FLOAT, a_shape),
+ helper.make_tensor_value_info(b_name, TensorProto.FLOAT, b_shape),
+ ],
+ outputs=[helper.make_tensor_value_info(out_name, TensorProto.FLOAT, out_shape)],
+ )
+ model = helper.make_model(graph, producer_name="matmul_test")
+
+ # get frontend model
+ shape_dict = {a_name: a_shape, b_name: b_shape}
+ mod, _ = relay.frontend.from_onnx(model, shape_dict)
+ return mod["main"]
+
+ def _golden():
+ matmul_si = matmul_name
+ a = relay.var(
+ a_name,
+ shape=tuple(a_shape),
+ span=_create_span(f"{matmul_si}.{a_name}"),
+ )
+ b = relay.var(
+ b_name,
+ shape=tuple(b_shape),
+ span=_create_span(f"{matmul_si}.{b_name}"),
+ )
+ # MatMul is expected as transpose(b) followed by dense; both ops get the node name
+ b_t = _set_span(relay.transpose(b, axes=[1, 0]), matmul_si)
+ matmul_out = _set_span(
+ relay.nn.dense(a, b_t, out_dtype="float32"),
+ matmul_si,
+ )
+ return infer_type(relay.Function([a, b], matmul_out))
+
+ self._verify(_res, _golden)
+
+
if __name__ == "__main__":
tvm.testing.main()