You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by an...@apache.org on 2022/04/13 17:25:20 UTC
[tvm] branch main updated: [ONNX] Add imports for BERT contrib operators (#10949)
This is an automated email from the ASF dual-hosted git repository.
andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 11b8cd3ca1 [ONNX] Add imports for BERT contrib operators (#10949)
11b8cd3ca1 is described below
commit 11b8cd3ca167efeea6d96b34a988d874b038f1c9
Author: Altan Haan <31...@users.noreply.github.com>
AuthorDate: Wed Apr 13 10:25:14 2022 -0700
[ONNX] Add imports for BERT contrib operators (#10949)
* EmbedLayerNormalization, Attention
* fix Attention
* SkipLayerNormalization
* fix dtype bug in Gelu
Co-authored-by: An Wang <an...@gmail.com>
* missing parameterize_targets
* lint
* lint
* comments
* fix small thing
* factor out layer norm computation
* layernorm func
* add optional args to test
* upgrade onnxrt version
* no upgrade onnx
* fix tests
* int32
* fix tests
Co-authored-by: An Wang <an...@gmail.com>
---
python/tvm/relay/frontend/onnx.py | 224 ++++++++++++++++++++++++++++-
tests/python/frontend/onnx/test_forward.py | 219 ++++++++++++++++++++++++++++
2 files changed, 440 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 168362e229..31b7c21e42 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -329,6 +329,22 @@ def matmul_out_dtype(inputs, out_dtype):
return _op.nn.dense(inputs[0], input_1_t, out_dtype=out_dtype)
def layer_norm(x, eps, gamma, beta):
    """Apply layer normalization over the last axis of ``x``.

    Computes ``(x - mean) / sqrt(variance + eps) * gamma`` and, when ``beta`` is
    given, adds ``beta``. Mean and variance are taken over axis -1 with the reduced
    dimension kept, so they broadcast back against ``x``.

    Parameters
    ----------
    x : relay.Expr
        Input tensor; normalized along its last axis.
    eps : float
        Stability constant added to the variance. It is emitted as a constant
        of the same dtype as ``x``.
    gamma : relay.Expr
        Scale multiplied into the normalized tensor.
    beta : relay.Expr or None
        Optional shift; skipped entirely when ``None``.

    Returns
    -------
    relay.Expr
        The normalized (and scaled / shifted) tensor.
    """
    x_dtype = infer_type(x).checked_type.dtype

    mean, variance = _op.mean_variance(x, axis=-1, keepdims=True)
    denominator = _op.sqrt(_op.add(variance, _op.const(eps, dtype=x_dtype)))
    scaled = _op.multiply(_op.divide(_op.subtract(x, mean), denominator), gamma)

    return scaled if beta is None else _op.add(scaled, beta)
+
+
class OnnxOpConverter(object):
"""A helper class for holding onnx op converters."""
@@ -807,9 +823,10 @@ class Gelu(OnnxOpConverter):
x = inputs[0]
# Declare consts
- half = _expr.const(0.5)
- one = _expr.const(1.0)
- sqrt2 = _expr.const(math.sqrt(2))
+ const_dtype = infer_type(x).checked_type.dtype
+ half = _expr.const(0.5, dtype=const_dtype)
+ one = _expr.const(1.0, dtype=const_dtype)
+ sqrt2 = _expr.const(math.sqrt(2), dtype=const_dtype)
# Compute gelu
term1 = _op.multiply(half, x)
@@ -836,6 +853,201 @@ class BiasGelu(OnnxOpConverter):
return Gelu._impl_v1([inp], attr, params)
class EmbedLayerNormalization(OnnxOpConverter):
    """Operator converter for EmbedLayerNormalization from Microsoft onnxruntime contrib opset.

    This layer embeds the input tokens, sums them, and applies layer normalization.

    Inputs are read by position: input_ids, segment_ids (optional), word_embedding,
    position_embedding, segment_embedding (optional), gamma, beta, mask (optional),
    position_ids (optional). The converter returns a 2-tuple of
    ``(normalized_embeddings, mask_index)``.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        input_ids = inputs[0]
        segment_ids = inputs[1]
        word_emb = inputs[2]
        pos_emb = inputs[3]
        segment_emb = inputs[4]
        gamma = inputs[5]
        beta = inputs[6]

        mask = inputs[7]
        pos_ids = inputs[8]

        eps = attr.get("epsilon", 1e-12)

        # input_ids must have a static 2-D shape: (batch_size, seq_len)
        (batch_size, seq_len) = infer_shape(input_ids)

        has_segment = segment_ids is not None
        if has_segment:
            assert segment_emb is not None, "segment_ids were given without a segment_embedding"

        if pos_ids is None:
            # Default positions are 0..seq_len-1 for every sequence in the batch, so
            # the constant needs one row per batch element: shape (batch_size, seq_len).
            pos_ids = _op.const([list(range(seq_len))] * batch_size, dtype="int32")

        word_vec = _op.take(word_emb, input_ids, axis=0)
        pos_vec = _op.take(pos_emb, pos_ids, axis=0)
        vec_sum = _op.add(word_vec, pos_vec)

        if has_segment:
            # Only build the segment lookup when the optional inputs exist; otherwise
            # take() would be constructed on missing (None) inputs.
            segment_vec = _op.take(segment_emb, segment_ids, axis=0)
            vec_sum = _op.add(vec_sum, segment_vec)

        ln = layer_norm(vec_sum, eps, gamma, beta)

        if mask is not None:
            # calculate number of words per sentence
            mask_index = _op.sum(mask, axis=1)
        else:
            mask_index = _op.const(np.zeros((batch_size,), dtype="int32"))

        # TODO(@anwang2009): onnxruntime v1.10.0 requires a third output of vec_sum
        return _expr.TupleWrapper(_expr.Tuple([ln, mask_index]), 2)
+
+
class SkipLayerNormalization(OnnxOpConverter):
    """Operator converter for SkipLayerNormalization from Microsoft onnxruntime contrib opset.

    This layer sums the two input tensors (along with optional bias), and applies layer
    normalization.

    Inputs are read by position: input, skip, gamma, beta (optional), bias (optional).
    Both optional inputs may be absent (``None``): ``bias`` is then not added, and
    ``layer_norm`` skips the ``beta`` shift.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        data = inputs[0]
        skip = inputs[1]
        gamma = inputs[2]
        beta = inputs[3]
        bias = inputs[4]

        eps = attr.get("epsilon", 1e-12)

        x = _op.add(data, skip)
        if bias is not None:
            x = _op.add(x, bias)

        # layer_norm handles beta=None by omitting the shift
        output = layer_norm(x, eps, gamma, beta)

        # onnxruntime doesn't compute the other outputs, despite the documentation
        placeholder = _op.const(0, dtype="float32")

        return _expr.TupleWrapper(_expr.Tuple([output, placeholder, placeholder]), 3)
+
+
class Attention(OnnxOpConverter):
    """Operator converter for Attention from Microsoft onnxruntime contrib opset.

    This is the self-attention mechanism used in transformer models.

    Supported subset: a single fused QKV weight/bias, a 2-D
    ``(batch_size, sequence_length)`` mask, no ``past`` state, no ``extra_add``,
    and no causal (unidirectional) masking. Returns ``(output, present)``.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        num_heads = attr["num_heads"]
        assert (
            "qkv_hidden_sizes" not in attr
        ), "different hidden sizes for Q, K, V are not currently supported"
        # `unidirectional` defaults to 0 (bidirectional); only reject models that actually
        # request causal masking, not ones that explicitly spell out the default.
        assert not attr.get(
            "unidirectional", 0
        ), "unidirectional attention not current supported"

        # (batch, seq, in_hidden)
        input_emb = inputs[0]

        # (in_hidden, 3 * out_hidden), where out_hidden = num_heads * head_size
        weight = inputs[1]

        # (3 * out_hidden,)
        bias = inputs[2]

        # 1. ( batch, 1, max_seq, max_seq)
        # 2. ( batch, past_seq + seq,)
        # 3. ( batch, seq, past_seq + seq,)
        # 4. ( batch,)
        # 5. (2 * batch,)
        # For now, we only support case 2.
        mask_index = inputs[3]

        # (2, batch, num_heads, past_seq, head_size)
        past = inputs[4]

        # (batch, num_heads, seq, seq)
        extra_add = inputs[5]

        (batch_size, seq_len, _) = infer_shape(input_emb)
        (out_hidden_x3,) = infer_shape(bias)
        assert out_hidden_x3 % 3 == 0, "bias shape should be divisible by 3"
        out_hidden = out_hidden_x3 // 3
        assert (
            out_hidden % num_heads == 0
        ), "output hidden size should be divisible by number of attention heads"
        head_size = out_hidden // num_heads

        assert (
            mask_index is not None
        ), "Attention import currently only supports required mask_index"
        mask_index_shape = infer_shape(mask_index)
        assert (
            len(mask_index_shape) == 2
            and mask_index_shape[0] == batch_size
            and mask_index_shape[1] == seq_len
        ), "currently only support (batch_size, sequence_length) mask index"

        assert past is None, "past K, V state is not currently supported"
        assert extra_add is None, "extra add to QxK not currently supported"

        # split weight and biases and do the matmuls
        w_Q, w_K, w_V = _op.split(weight, 3, axis=1)
        b_Q, b_K, b_V = _op.split(bias, 3, axis=0)
        # need to merge batch dimensions since TVM matmul is 2D
        input_emb = _op.reverse_reshape(input_emb, (-1, 0))
        Q = _op.add(_op.nn.matmul(input_emb, w_Q), b_Q)
        K = _op.add(_op.nn.matmul(input_emb, w_K), b_K)
        V = _op.add(_op.nn.matmul(input_emb, w_V), b_V)

        # massage tensors in preparation for batched matmul
        def massage(tensor):
            tensor = _op.reshape(tensor, (batch_size, seq_len, num_heads, head_size))

            # (batch_size, num_heads, seq_len, head_size)
            tensor = _op.transpose(tensor, axes=[0, 2, 1, 3])

            # (batch_size * num_heads, seq_len, head_size)
            return _op.reverse_reshape(tensor, (-1, 0, 0))

        Q = massage(Q)
        K = massage(K)
        V = massage(V)

        K_present = _op.reshape(K, (batch_size, num_heads, seq_len, head_size))
        V_present = _op.reshape(V, (batch_size, num_heads, seq_len, head_size))
        present = _op.stack([K_present, V_present], axis=0)

        att_scores = _op.nn.batch_matmul(Q, K, transpose_a=False, transpose_b=True)
        # Type inference is relatively costly; compute the score dtype once and reuse it.
        score_dtype = infer_type(att_scores).checked_type.dtype
        att_scores = _op.divide(
            att_scores,
            _op.const(np.sqrt(head_size), dtype=score_dtype),
        )
        att_scores = _op.reshape(att_scores, (batch_size, num_heads, seq_len, seq_len))

        # build the attention mask: (batch, seq) -> (batch, 1, 1, seq); masked
        # positions (0) become -10000, kept positions (1) become 0
        att_mask = _op.cast(mask_index, score_dtype)
        att_mask = _op.expand_dims(att_mask, 1, num_newaxis=2)
        att_mask = _op.subtract(_op.const(1, dtype=score_dtype), att_mask)
        att_mask = _op.multiply(att_mask, _op.const(-10000, dtype=score_dtype))

        # apply the mask
        att_scores = _op.add(att_scores, att_mask)
        att_scores = _op.reshape(att_scores, (batch_size * num_heads, seq_len, seq_len))

        att_probs = _op.nn.softmax(att_scores, axis=-1)

        output = _op.nn.batch_matmul(att_probs, V, transpose_a=False, transpose_b=False)
        output = _op.reverse_reshape(output, (-1, num_heads, 0, 0))
        output = _op.transpose(output, axes=[0, 2, 1, 3])
        output = _op.reshape(output, (0, 0, out_hidden))

        return _expr.TupleWrapper(_expr.Tuple([output, present]), 2)
+
+
class Gemm(OnnxOpConverter):
"""Operator converter for Gemm."""
@@ -4808,6 +5020,12 @@ def _get_convert_map(opset):
"Elu": Elu.get_converter(opset),
"Gelu": Gelu.get_converter(opset),
"BiasGelu": BiasGelu.get_converter(opset),
+ # TODO: We need a better way to handle different domains, in case
+ # of name collisions. EmbedLayerNormalization, SkipLayerNormalization, and Attention
+ # are in the `com.microsoft` domain.
+ "EmbedLayerNormalization": EmbedLayerNormalization.get_converter(opset),
+ "SkipLayerNormalization": SkipLayerNormalization.get_converter(opset),
+ "Attention": Attention.get_converter(opset),
"Exp": Renamer("exp"),
"Greater": Renamer("greater"),
"GreaterOrEqual": Renamer("greater_equal"),
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 12e02d5f29..5cc57c87e8 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -39,6 +39,10 @@ def get_input_data_shape_dict(graph_def, input_data):
shape_dict = {}
for i, _ in enumerate(input_data):
input_names[i] = graph_def.graph.input[i].name
+ if input_data[i] is None or input_data[i].shape == ():
+ # Skip adding input shape data when the input data is None;
+ # This is to enable optional arguments for onnx operators.
+ continue
shape_dict[input_names[i]] = input_data[i].shape
else:
input_names = graph_def.graph.input[0].name
@@ -5422,6 +5426,221 @@ def test_biasgelu(target, dev):
verify_biasgelu(x, bias)
@tvm.testing.parametrize_targets
def test_embedlayernormalization(target, dev):
    def verify_embedlayernormalization(
        input_ids,
        segment_ids,
        word_embedding,
        position_embedding,
        segment_embedding,
        gamma,
        beta,
    ):
        # Derive all shapes from the inputs instead of relying on outer-scope names.
        batch_size, sequence_length = input_ids.shape
        hidden_size = word_embedding.shape[1]

        node = onnx.helper.make_node(
            "EmbedLayerNormalization",
            inputs=[
                "input_ids",
                "" if segment_ids is None else "segment_ids",
                "word_embedding",
                "position_embedding",
                "" if segment_embedding is None else "segment_embedding",
                "gamma",
                "beta",
            ],
            outputs=["output", "mask_index"],
            domain="com.microsoft",
        )

        node.attribute.append(onnx.helper.make_attribute("epsilon", 1e-4))

        segment_ids_shape = [] if segment_ids is None else segment_ids.shape
        segment_embedding_shape = [] if segment_embedding is None else segment_embedding.shape

        graph = helper.make_graph(
            [node],
            "embedlayernormalization_test",
            inputs=[
                helper.make_tensor_value_info(
                    "input_ids", TensorProto.INT32, list(input_ids.shape)
                ),
                helper.make_tensor_value_info("segment_ids", TensorProto.INT32, segment_ids_shape),
                helper.make_tensor_value_info(
                    "word_embedding", TensorProto.FLOAT, list(word_embedding.shape)
                ),
                helper.make_tensor_value_info(
                    "position_embedding", TensorProto.FLOAT, list(position_embedding.shape)
                ),
                helper.make_tensor_value_info(
                    "segment_embedding", TensorProto.FLOAT, segment_embedding_shape
                ),
                helper.make_tensor_value_info("gamma", TensorProto.FLOAT, list(gamma.shape)),
                helper.make_tensor_value_info("beta", TensorProto.FLOAT, list(beta.shape)),
            ],
            outputs=[
                helper.make_tensor_value_info(
                    "output", TensorProto.FLOAT, [batch_size, sequence_length, hidden_size]
                ),
                helper.make_tensor_value_info("mask_index", TensorProto.INT32, [batch_size]),
            ],
        )

        model = helper.make_model(graph, producer_name="embedlayernormalization_test")

        # TODO(@anwang2009): onnxruntime v1.9.0 requires empty list for optional argument,
        # but v1.10.0+ requires None instead.
        verify_with_ort_with_inputs(
            model,
            [
                input_ids,
                np.empty(0, dtype="int32") if segment_ids is None else segment_ids,
                word_embedding,
                position_embedding,
                np.empty(0, dtype="float32") if segment_embedding is None else segment_embedding,
                gamma,
                beta,
            ],
            [
                (batch_size, sequence_length, hidden_size),
                batch_size,
            ],
            target=target,
            dev=dev,
            rtol=1e-4,
            atol=1e-4,
        )

    hidden_size = 384
    # Keep batch and sequence sizes distinct so shape mix-ups between the two
    # (e.g. in the default position ids) are caught.
    batch_size = 4
    sequence_length = 3
    vocab_size = 5
    segment_vocab_size = 2

    # Use random ids and embeddings: with constant embeddings every summed token
    # vector is constant, its variance is zero, and layer norm collapses to `beta`,
    # so neither the lookups nor the normalization would actually be checked.
    input_ids = np.random.randint(0, vocab_size, (batch_size, sequence_length), dtype="int32")
    segment_ids = np.random.randint(
        0, segment_vocab_size, (batch_size, sequence_length), dtype="int32"
    )
    word_embedding = np.random.randn(vocab_size, hidden_size).astype("float32")
    position_embedding = np.random.randn(sequence_length, hidden_size).astype("float32")
    segment_embedding = np.random.randn(segment_vocab_size, hidden_size).astype("float32")

    gamma = np.random.uniform(0.5, 0.7, hidden_size).astype("float32")
    beta = np.random.randn(hidden_size).astype("float32") * 0.1

    verify_embedlayernormalization(
        input_ids, segment_ids, word_embedding, position_embedding, segment_embedding, gamma, beta
    )

    # Test with undefined segment embedding
    verify_embedlayernormalization(
        input_ids, None, word_embedding, position_embedding, None, gamma, beta
    )
+
+
@tvm.testing.parametrize_targets
def test_attention(target, dev):
    def verify_attention(input_emb, weight, bias, mask_index, num_heads):
        batch_size, sequence_length, _ = input_emb.shape
        # weight is (in_hidden, 3 * out_hidden); derive the per-head size from it
        # rather than hard-coding it, so the test stays consistent if sizes change.
        out_hidden = weight.shape[1] // 3
        head_size = out_hidden // num_heads
        output_shape = (batch_size, sequence_length, out_hidden)
        present_output_shape = (2, batch_size, num_heads, sequence_length, head_size)

        node = onnx.helper.make_node(
            "Attention",
            inputs=["input", "weight", "bias", "mask_index"],
            outputs=["output", "present"],
            domain="com.microsoft",
            num_heads=num_heads,
        )

        graph = helper.make_graph(
            [node],
            "attention_test",
            inputs=[
                helper.make_tensor_value_info("input", TensorProto.FLOAT, list(input_emb.shape)),
                helper.make_tensor_value_info("weight", TensorProto.FLOAT, list(weight.shape)),
                helper.make_tensor_value_info("bias", TensorProto.FLOAT, list(bias.shape)),
                helper.make_tensor_value_info(
                    "mask_index", TensorProto.INT32, list(mask_index.shape)
                ),
            ],
            outputs=[
                helper.make_tensor_value_info("output", TensorProto.FLOAT, list(output_shape)),
                helper.make_tensor_value_info(
                    "present", TensorProto.FLOAT, list(present_output_shape)
                ),
            ],
        )

        model = helper.make_model(graph, producer_name="attention_test")

        # "present" output should be nullptr when the "past" input isn't included,
        # but ort requires an output shape to be specified?
        verify_with_ort_with_inputs(
            model,
            [input_emb, weight, bias, mask_index],
            [output_shape, present_output_shape],
            target=target,
            dev=dev,
            rtol=1e-4,
            atol=1e-4,
        )

    hidden_size = 384
    batch_size = 4
    sequence_length = 4
    num_heads = 12

    dtype = "float32"
    input_emb = np.random.random((batch_size, sequence_length, hidden_size)).astype(dtype)
    weight = np.random.normal(size=(hidden_size, 3 * hidden_size)).astype(dtype) * 0.1
    bias = np.random.randn(3 * hidden_size).astype(dtype)
    # Mask out the last token for the second half of the batch so the masking path
    # is exercised (an all-ones mask adds zero to every attention score).
    mask_index = np.ones((batch_size, sequence_length), dtype="int32")
    mask_index[batch_size // 2 :, -1] = 0

    verify_attention(input_emb, weight, bias, mask_index, num_heads)
+
+
@tvm.testing.parametrize_targets
def test_skiplayernormalization(target, dev):
    def verify_skiplayernormalization(data, skip, gamma, beta, bias):
        # Graph input names paired with the arrays fed to them, in operator order.
        named_inputs = [
            ("input", data),
            ("skip", skip),
            ("gamma", gamma),
            ("beta", beta),
            ("bias", bias),
        ]

        node = onnx.helper.make_node(
            "SkipLayerNormalization",
            inputs=[name for name, _ in named_inputs],
            outputs=["output"],
            domain="com.microsoft",
            epsilon=1e-4,
        )

        graph = helper.make_graph(
            [node],
            "skiplayernormalization_test",
            inputs=[
                helper.make_tensor_value_info(name, TensorProto.FLOAT, list(array.shape))
                for name, array in named_inputs
            ],
            outputs=[
                helper.make_tensor_value_info("output", TensorProto.FLOAT, list(data.shape)),
            ],
        )

        model = helper.make_model(graph, producer_name="skiplayernormalization_test")
        verify_with_ort_with_inputs(
            model,
            [array for _, array in named_inputs],
            [data.shape],
            target=target,
            dev=dev,
        )

    hidden_size = 384
    batch_size = 4
    sequence_length = 4
    activation_shape = (batch_size, sequence_length, hidden_size)

    dtype = "float32"
    input_data = np.random.random(activation_shape).astype(dtype)
    skip = np.random.random(activation_shape).astype(dtype)
    gamma = np.random.uniform(0.5, 0.7, hidden_size).astype(dtype)
    beta = np.random.randn(hidden_size).astype(dtype) * 0.1
    bias = np.random.randn(hidden_size).astype(dtype)

    verify_skiplayernormalization(input_data, skip, gamma, beta, bias)
+
+
@tvm.testing.known_failing_targets("cuda")
@tvm.testing.parametrize_targets
def test_qlinearconv(target, dev):