Posted to commits@tvm.apache.org by an...@apache.org on 2022/04/13 17:25:20 UTC

[tvm] branch main updated: [ONNX] Add imports for BERT contrib operators (#10949)

This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 11b8cd3ca1 [ONNX] Add imports for BERT contrib operators (#10949)
11b8cd3ca1 is described below

commit 11b8cd3ca167efeea6d96b34a988d874b038f1c9
Author: Altan Haan <31...@users.noreply.github.com>
AuthorDate: Wed Apr 13 10:25:14 2022 -0700

    [ONNX] Add imports for BERT contrib operators (#10949)
    
    * EmbedLayerNormalization, Attention
    
    * fix Attention
    
    * SkipLayerNormalization
    
    * fix dtype bug in Gelu
    
    Co-authored-by: An Wang <an...@gmail.com>
    
    * missing parameterize_targets
    
    * lint
    
    * lint
    
    * comments
    
    * fix small thing
    
    * factor out layer norm computation
    
    * layernorm func
    
    * add optional args to test
    
    * upgrade onnxrt version
    
    * no upgrade onnx
    
    * fix tests
    
    * int32
    
    * fix tests
    
    Co-authored-by: An Wang <an...@gmail.com>
---
 python/tvm/relay/frontend/onnx.py          | 224 ++++++++++++++++++++++++++++-
 tests/python/frontend/onnx/test_forward.py | 219 ++++++++++++++++++++++++++++
 2 files changed, 440 insertions(+), 3 deletions(-)
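
For reference, models using these com.microsoft contrib operators are imported through the usual ONNX frontend entry point. A minimal usage sketch (the model file name and input shapes below are illustrative assumptions, not part of this commit):

    import onnx
    import tvm
    from tvm import relay

    # hypothetical model containing the com.microsoft contrib ops
    model = onnx.load("bert_contrib.onnx")
    shape_dict = {"input_ids": (4, 4), "segment_ids": (4, 4)}
    mod, params = relay.frontend.from_onnx(model, shape_dict)
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm", params=params)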

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 168362e229..31b7c21e42 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -329,6 +329,22 @@ def matmul_out_dtype(inputs, out_dtype):
     return _op.nn.dense(inputs[0], input_1_t, out_dtype=out_dtype)
 
 
+def layer_norm(x, eps, gamma, beta):
+    """Common function to handle layer norm"""
+    eps_dtype = infer_type(x).checked_type.dtype
+
+    u, s = _op.mean_variance(x, axis=-1, keepdims=True)
+    output = _op.divide(
+        _op.subtract(x, u),
+        _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
+    )
+    output = _op.multiply(output, gamma)
+    if beta is not None:
+        output = _op.add(output, beta)
+
+    return output
+
+
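+
+The helper above mirrors the textbook layer-normalization formula over the last
+axis. A minimal numpy reference for the same computation (names illustrative,
+not part of the patch):
+
+    import numpy as np
+
+    def layer_norm_ref(x, eps, gamma, beta=None):
+        # normalize over the last axis, then scale (and optionally shift)
+        u = x.mean(axis=-1, keepdims=True)
+        s = x.var(axis=-1, keepdims=True)  # population variance, as in mean_variance
+        out = (x - u) / np.sqrt(s + eps) * gamma
+        return out + beta if beta is not None else out
+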
 class OnnxOpConverter(object):
     """A helper class for holding onnx op converters."""
 
@@ -807,9 +823,10 @@ class Gelu(OnnxOpConverter):
         x = inputs[0]
 
         # Declare consts
-        half = _expr.const(0.5)
-        one = _expr.const(1.0)
-        sqrt2 = _expr.const(math.sqrt(2))
+        const_dtype = infer_type(x).checked_type.dtype
+        half = _expr.const(0.5, dtype=const_dtype)
+        one = _expr.const(1.0, dtype=const_dtype)
+        sqrt2 = _expr.const(math.sqrt(2), dtype=const_dtype)
 
         # Compute gelu
         term1 = _op.multiply(half, x)
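
The hunk above pins the Gelu constants to the input's dtype so that, for example, a float16 graph no longer mixes in float32 constants. For reference, a numpy sketch of the erf-based formula the converter implements (scipy is assumed available; this is not part of the patch):

    import numpy as np
    from scipy.special import erf

    def gelu_ref(x):
        # 0.5 * x * (1 + erf(x / sqrt(2))), with constants in x's own dtype
        return (0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))).astype(x.dtype)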
@@ -836,6 +853,201 @@ class BiasGelu(OnnxOpConverter):
         return Gelu._impl_v1([inp], attr, params)
 
 
+class EmbedLayerNormalization(OnnxOpConverter):
+    """Operator converter for EmbedLayerNormalization from Microsoft onnxruntime contrib opset.
+
+    This layer looks up embeddings for the input tokens, sums them, and applies layer
+    normalization.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        input_ids = inputs[0]
+        segment_ids = inputs[1]
+        word_emb = inputs[2]
+        pos_emb = inputs[3]
+        segment_emb = inputs[4]
+        gamma = inputs[5]
+        beta = inputs[6]
+
+        mask = inputs[7]
+        pos_ids = inputs[8]
+
+        eps = attr.get("epsilon", 1e-12)
+
+        (batch_size, seq_len) = infer_shape(input_ids)
+
+        if segment_ids:
+            assert segment_emb
+
+        if pos_ids is None:
+            pos_ids = _op.const([list(range(seq_len))] * batch_size, dtype="int32")
+
+        word_vec = _op.take(word_emb, input_ids, axis=0)
+        segment_vec = _op.take(segment_emb, segment_ids, axis=0)
+        pos_vec = _op.take(pos_emb, pos_ids, axis=0)
+
+        vec_sum = _op.add(word_vec, pos_vec)
+        if segment_ids:
+            vec_sum = _op.add(vec_sum, segment_vec)
+
+        ln = layer_norm(vec_sum, eps, gamma, beta)
+
+        mask_index = _op.const(np.zeros((batch_size,), dtype="int32"))
+        if mask:
+            # calculate number of words per sentence
+            mask_index = _op.sum(mask, axis=1)
+
+        # TODO(@anwang2009): onnxruntime v1.10.0 requires a third output of vec_sum
+        return _expr.TupleWrapper(_expr.Tuple([ln, mask_index]), 2)
+
+
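+
+In numpy terms, the converter above computes roughly the following (a sketch
+under the importer's assumption of 2-D int32 ids; illustrative, not part of
+the patch):
+
+    import numpy as np
+
+    def embed_layer_norm_ref(input_ids, word_emb, pos_emb, gamma, beta,
+                             segment_ids=None, seg_emb=None, mask=None, eps=1e-12):
+        # sum word, position, and (optional) segment embeddings, then layer-norm
+        batch, seq = input_ids.shape
+        vec = word_emb[input_ids] + pos_emb[np.arange(seq)]
+        if segment_ids is not None:
+            vec = vec + seg_emb[segment_ids]
+        u = vec.mean(axis=-1, keepdims=True)
+        s = vec.var(axis=-1, keepdims=True)
+        out = (vec - u) / np.sqrt(s + eps) * gamma + beta
+        # mask_index counts the unmasked tokens in each sentence
+        mask_index = mask.sum(axis=1) if mask is not None else np.zeros(batch, "int32")
+        return out, mask_index
+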
+class SkipLayerNormalization(OnnxOpConverter):
+    """Operator converter for SkipLayerNormalization from Microsoft onnxruntime contrib opset.
+
+    This layer sums the two input tensors (along with optional bias), and applies layer
+    normalization.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        data = inputs[0]
+        skip = inputs[1]
+        gamma = inputs[2]
+        beta = inputs[3]
+        bias = inputs[4]
+
+        assert (
+            beta is not None and bias is not None
+        ), "SkipLayerNormalization import currently only supports required beta and bias"
+
+        eps = attr.get("epsilon", 1e-12)
+
+        x = _op.add(data, skip)
+        if bias is not None:
+            x = _op.add(x, bias)
+
+        output = layer_norm(x, eps, gamma, beta)
+
+        # onnxruntime doesn't compute the other outputs, despite the documentation
+        placeholder = _op.const(0, dtype="float32")
+
+        return _expr.TupleWrapper(_expr.Tuple([output, placeholder, placeholder]), 3)
+
+
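+
+The equivalent numpy sketch is short (illustrative only):
+
+    import numpy as np
+
+    def skip_layer_norm_ref(data, skip, gamma, beta, bias, eps=1e-12):
+        # residual add (plus bias), then layer-norm over the last axis
+        x = data + skip + bias
+        u = x.mean(axis=-1, keepdims=True)
+        s = x.var(axis=-1, keepdims=True)
+        return (x - u) / np.sqrt(s + eps) * gamma + beta
+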
+class Attention(OnnxOpConverter):
+    """Operator converter for Attention from Microsoft onnxruntime contrib opset.
+
+    This is the self-attention mechanism used in transformer models.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        num_heads = attr["num_heads"]
+        assert (
+            "qkv_hidden_sizes" not in attr
+        ), "different hidden sizes for Q, K, V are not currently supported"
+        assert "unidirectional" not in attr, "unidirectional attention not current supported"
+
+        # (batch, seq, in_hidden)
+        input_emb = inputs[0]
+
+        # (in_hidden, 3 * out_hidden), where out_hidden = num_heads * head_size
+        weight = inputs[1]
+
+        # (3 * out_hidden,)
+        bias = inputs[2]
+
+        # 1. (    batch,              1,        max_seq, max_seq)
+        # 2. (    batch, past_seq + seq,)
+        # 3. (    batch,            seq, past_seq + seq,)
+        # 4. (    batch,)
+        # 5. (2 * batch,)
+        # For now, we only support case 2.
+        mask_index = inputs[3]
+
+        # (2, batch, num_heads, past_seq, head_size)
+        past = inputs[4]
+
+        # (batch, num_heads, seq, seq)
+        extra_add = inputs[5]
+
+        (batch_size, seq_len, _) = infer_shape(input_emb)
+        (out_hidden_x3,) = infer_shape(bias)
+        assert out_hidden_x3 % 3 == 0, "bias shape should be divisible by 3"
+        out_hidden = out_hidden_x3 // 3
+        assert (
+            out_hidden % num_heads == 0
+        ), "output hidden size should be divisible by number of attention heads"
+        head_size = out_hidden // num_heads
+
+        assert (
+            mask_index is not None
+        ), "Attention import currently only supports required mask_index"
+        mask_index_shape = infer_shape(mask_index)
+        assert (
+            len(mask_index_shape) == 2
+            and mask_index_shape[0] == batch_size
+            and mask_index_shape[1] == seq_len
+        ), "currently only support (batch_size, sequence_length) mask index"
+
+        assert past is None, "past K, V state is not currently supported"
+        assert extra_add is None, "extra add to QxK not currently supported"
+
+        # split weight and biases and do the matmuls
+        w_Q, w_K, w_V = _op.split(weight, 3, axis=1)
+        b_Q, b_K, b_V = _op.split(bias, 3, axis=0)
+        # need to merge batch dimensions since TVM matmul is 2D
+        input_emb = _op.reverse_reshape(input_emb, (-1, 0))
+        Q = _op.add(_op.nn.matmul(input_emb, w_Q), b_Q)
+        K = _op.add(_op.nn.matmul(input_emb, w_K), b_K)
+        V = _op.add(_op.nn.matmul(input_emb, w_V), b_V)
+
+        # massage tensors in preparation for batched matmul
+        def massage(tensor):
+            tensor = _op.reshape(tensor, (batch_size, seq_len, num_heads, head_size))
+
+            # (batch_size, num_heads, seq_len, head_size)
+            tensor = _op.transpose(tensor, axes=[0, 2, 1, 3])
+
+            # (batch_size * num_heads, seq_len, head_size)
+            return _op.reverse_reshape(tensor, (-1, 0, 0))
+
+        Q = massage(Q)
+        K = massage(K)
+        V = massage(V)
+
+        K_present = _op.reshape(K, (batch_size, num_heads, seq_len, head_size))
+        V_present = _op.reshape(V, (batch_size, num_heads, seq_len, head_size))
+        present = _op.stack([K_present, V_present], axis=0)
+
+        att_scores = _op.nn.batch_matmul(Q, K, transpose_a=False, transpose_b=True)
+        score_dtype = infer_type(att_scores).checked_type.dtype
+        att_scores = _op.divide(
+            att_scores,
+            _op.const(np.sqrt(head_size), dtype=score_dtype),
+        )
+        att_scores = _op.reshape(att_scores, (batch_size, num_heads, seq_len, seq_len))
+
+        # build the attention mask
+        att_mask = _op.cast(mask_index, score_dtype)
+        att_mask = _op.expand_dims(att_mask, 1, num_newaxis=2)
+        att_mask = _op.subtract(_op.const(1, dtype=score_dtype), att_mask)
+        att_mask = _op.multiply(att_mask, _op.const(-10000, dtype=score_dtype))
+
+        # apply the mask
+        att_scores = _op.add(att_scores, att_mask)
+        att_scores = _op.reshape(att_scores, (batch_size * num_heads, seq_len, seq_len))
+
+        att_probs = _op.nn.softmax(att_scores, axis=-1)
+
+        output = _op.nn.batch_matmul(att_probs, V, transpose_a=False, transpose_b=False)
+        output = _op.reverse_reshape(output, (-1, num_heads, 0, 0))
+        output = _op.transpose(output, axes=[0, 2, 1, 3])
+        output = _op.reshape(output, (0, 0, out_hidden))
+
+        return _expr.TupleWrapper(_expr.Tuple([output, present]), 2)
+
+
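+
+Shape bookkeeping aside, the converter above implements standard scaled
+dot-product attention with a (batch, seq) padding mask. A compact numpy
+reference (illustrative, not part of the patch):
+
+    import numpy as np
+
+    def attention_ref(x, w, b, mask, num_heads):
+        # x: (batch, seq, in_hidden); w: (in_hidden, 3 * out_hidden); b: (3 * out_hidden,)
+        batch, seq, _ = x.shape
+        q, k, v = np.split(x @ w + b, 3, axis=-1)
+        head_size = q.shape[-1] // num_heads
+
+        def split_heads(t):  # -> (batch, num_heads, seq, head_size)
+            return t.reshape(batch, seq, num_heads, head_size).transpose(0, 2, 1, 3)
+
+        q, k, v = map(split_heads, (q, k, v))
+        scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_size)
+        scores = scores + (1.0 - mask[:, None, None, :]) * -10000.0  # mask padding
+        e = np.exp(scores - scores.max(axis=-1, keepdims=True))
+        probs = e / e.sum(axis=-1, keepdims=True)
+        out = probs @ v  # (batch, num_heads, seq, head_size)
+        return out.transpose(0, 2, 1, 3).reshape(batch, seq, -1)
+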
 class Gemm(OnnxOpConverter):
     """Operator converter for Gemm."""
 
@@ -4808,6 +5020,12 @@ def _get_convert_map(opset):
         "Elu": Elu.get_converter(opset),
         "Gelu": Gelu.get_converter(opset),
         "BiasGelu": BiasGelu.get_converter(opset),
+        # TODO: We need a better way to handle different domains, in case
+        # of name collisions. EmbedLayerNormalization, SkipLayerNormalization, and Attention
+        # are in the `com.microsoft` domain.
+        "EmbedLayerNormalization": EmbedLayerNormalization.get_converter(opset),
+        "SkipLayerNormalization": SkipLayerNormalization.get_converter(opset),
+        "Attention": Attention.get_converter(opset),
         "Exp": Renamer("exp"),
         "Greater": Renamer("greater"),
         "GreaterOrEqual": Renamer("greater_equal"),
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 12e02d5f29..5cc57c87e8 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -39,6 +39,10 @@ def get_input_data_shape_dict(graph_def, input_data):
         shape_dict = {}
         for i, _ in enumerate(input_data):
             input_names[i] = graph_def.graph.input[i].name
+            if input_data[i] is None or input_data[i].shape == ():
+                # Skip adding the input shape when the input is None or a scalar;
+                # this enables optional arguments for onnx operators.
+                continue
             shape_dict[input_names[i]] = input_data[i].shape
     else:
         input_names = graph_def.graph.input[0].name
@@ -5422,6 +5426,221 @@ def test_biasgelu(target, dev):
     verify_biasgelu(x, bias)
 
 
+@tvm.testing.parametrize_targets
+def test_embedlayernormalization(target, dev):
+    def verify_embedlayernormalization(
+        input_ids,
+        segment_ids,
+        word_embedding,
+        position_embedding,
+        segment_embedding,
+        gamma,
+        beta,
+    ):
+        node = onnx.helper.make_node(
+            "EmbedLayerNormalization",
+            inputs=[
+                "input_ids",
+                "" if segment_ids is None else "segment_ids",
+                "word_embedding",
+                "position_embedding",
+                "" if segment_embedding is None else "segment_embedding",
+                "gamma",
+                "beta",
+            ],
+            outputs=["output", "mask_index"],
+            domain="com.microsoft",
+        )
+
+        node.attribute.append(onnx.helper.make_attribute("epsilon", 1e-4))
+
+        segment_ids_shape = [] if segment_ids is None else segment_ids.shape
+        segment_embedding_shape = [] if segment_embedding is None else segment_embedding.shape
+
+        graph = helper.make_graph(
+            [node],
+            "embedlayernormalization_test",
+            inputs=[
+                helper.make_tensor_value_info(
+                    "input_ids", TensorProto.INT32, list(input_ids.shape)
+                ),
+                helper.make_tensor_value_info("segment_ids", TensorProto.INT32, segment_ids_shape),
+                helper.make_tensor_value_info(
+                    "word_embedding", TensorProto.FLOAT, list(word_embedding.shape)
+                ),
+                helper.make_tensor_value_info(
+                    "position_embedding", TensorProto.FLOAT, list(position_embedding.shape)
+                ),
+                helper.make_tensor_value_info(
+                    "segment_embedding", TensorProto.FLOAT, segment_embedding_shape
+                ),
+                helper.make_tensor_value_info("gamma", TensorProto.FLOAT, list(gamma.shape)),
+                helper.make_tensor_value_info("beta", TensorProto.FLOAT, list(beta.shape)),
+            ],
+            outputs=[
+                helper.make_tensor_value_info(
+                    "output", TensorProto.FLOAT, list((batch_size, sequence_length, hidden_size))
+                ),
+                helper.make_tensor_value_info("mask_index", TensorProto.INT32, [batch_size]),
+            ],
+        )
+
+        model = helper.make_model(graph, producer_name="embedlayernormalization_test")
+
+        # TODO(@anwang2009): onnxruntime v1.9.0 requires empty list for optional argument,
+        # but v1.10.0+ requires None instead.
+        verify_with_ort_with_inputs(
+            model,
+            [
+                input_ids,
+                np.empty(0, dtype="int32") if segment_ids is None else segment_ids,
+                word_embedding,
+                position_embedding,
+                np.empty(0, dtype="float32") if segment_embedding is None else segment_embedding,
+                gamma,
+                beta,
+            ],
+            [
+                (batch_size, sequence_length, hidden_size),
+                batch_size,
+            ],
+            target=target,
+            dev=dev,
+            rtol=1e-4,
+            atol=1e-4,
+        )
+
+    hidden_size = 384
+    batch_size = 4
+    sequence_length = 4
+    vocab_size = 5
+
+    input_ids = np.full((batch_size, sequence_length), 3).astype("int32")
+    segment_ids = np.zeros((batch_size, sequence_length)).astype("int32")
+    word_embedding = np.full((vocab_size, hidden_size), 1).astype("float32")
+    position_embedding = np.full((sequence_length, hidden_size), 2).astype("float32")
+    segment_embedding = np.full((vocab_size, hidden_size), 3).astype("float32")
+
+    gamma = np.random.uniform(0.5, 0.7, hidden_size).astype("float32")
+    beta = np.random.randn(hidden_size).astype("float32") * 0.1
+
+    verify_embedlayernormalization(
+        input_ids, segment_ids, word_embedding, position_embedding, segment_embedding, gamma, beta
+    )
+
+    # Test with undefined segment embedding
+    verify_embedlayernormalization(
+        input_ids, None, word_embedding, position_embedding, None, gamma, beta
+    )
+
+
+@tvm.testing.parametrize_targets
+def test_attention(target, dev):
+    def verify_attention(input, weight, bias, mask_index, num_heads):
+        node = onnx.helper.make_node(
+            "Attention",
+            inputs=["input", "weight", "bias", "mask_index"],
+            outputs=["output", "present"],
+            domain="com.microsoft",
+            num_heads=num_heads,
+        )
+
+        present_output_shape = (2, batch_size, num_heads, sequence_length, head_size)
+
+        graph = helper.make_graph(
+            [node],
+            "attention_test",
+            inputs=[
+                helper.make_tensor_value_info("input", TensorProto.FLOAT, list(input.shape)),
+                helper.make_tensor_value_info("weight", TensorProto.FLOAT, list(weight.shape)),
+                helper.make_tensor_value_info("bias", TensorProto.FLOAT, list(bias.shape)),
+                helper.make_tensor_value_info(
+                    "mask_index", TensorProto.INT32, list(mask_index.shape)
+                ),
+            ],
+            outputs=[
+                helper.make_tensor_value_info("output", TensorProto.FLOAT, list(input.shape)),
+                helper.make_tensor_value_info(
+                    "present", TensorProto.FLOAT, list(present_output_shape)
+                ),
+            ],
+        )
+
+        model = helper.make_model(graph, producer_name="attention_test")
+
+        # "present" output should be nullptr when the "past" input isn't included,
+        # but ort requires an output shape to be specified?
+        verify_with_ort_with_inputs(
+            model,
+            [input, weight, bias, mask_index],
+            [input.shape, present_output_shape],
+            target=target,
+            dev=dev,
+            rtol=1e-4,
+            atol=1e-4,
+        )
+
+    hidden_size = 384
+    batch_size = 4
+    sequence_length = 4
+    num_heads = 12
+    head_size = 32
+
+    dtype = "float32"
+    input = np.random.random((batch_size, sequence_length, hidden_size)).astype(dtype)
+    weight = np.random.normal(size=(hidden_size, 3 * hidden_size)).astype(dtype) * 0.1
+    bias = np.random.randn(3 * hidden_size).astype(dtype)
+    mask_index = np.full((batch_size, sequence_length), 1).astype("int32")
+
+    verify_attention(input, weight, bias, mask_index, num_heads)
+
+
+@tvm.testing.parametrize_targets
+def test_skiplayernormalization(target, dev):
+    def verify_skiplayernormalization(input, skip, gamma, beta, bias):
+        node = onnx.helper.make_node(
+            "SkipLayerNormalization",
+            inputs=["input", "skip", "gamma", "beta", "bias"],
+            outputs=["output"],
+            domain="com.microsoft",
+        )
+
+        node.attribute.append(onnx.helper.make_attribute("epsilon", 1e-4))
+
+        graph = helper.make_graph(
+            [node],
+            "skiplayernormalization_test",
+            inputs=[
+                helper.make_tensor_value_info("input", TensorProto.FLOAT, list(input.shape)),
+                helper.make_tensor_value_info("skip", TensorProto.FLOAT, list(skip.shape)),
+                helper.make_tensor_value_info("gamma", TensorProto.FLOAT, list(gamma.shape)),
+                helper.make_tensor_value_info("beta", TensorProto.FLOAT, list(beta.shape)),
+                helper.make_tensor_value_info("bias", TensorProto.FLOAT, list(bias.shape)),
+            ],
+            outputs=[
+                helper.make_tensor_value_info("output", TensorProto.FLOAT, list(input.shape)),
+            ],
+        )
+
+        model = helper.make_model(graph, producer_name="skiplayernormalization_test")
+        verify_with_ort_with_inputs(
+            model, [input, skip, gamma, beta, bias], [input.shape], target=target, dev=dev
+        )
+
+    hidden_size = 384
+    batch_size = 4
+    sequence_length = 4
+
+    dtype = "float32"
+    input = np.random.random((batch_size, sequence_length, hidden_size)).astype(dtype)
+    skip = np.random.random((batch_size, sequence_length, hidden_size)).astype(dtype)
+    gamma = np.random.uniform(0.5, 0.7, hidden_size).astype(dtype)
+    beta = np.random.randn(hidden_size).astype(dtype) * 0.1
+    bias = np.random.randn(hidden_size).astype(dtype)
+
+    verify_skiplayernormalization(input, skip, gamma, beta, bias)
+
+
 @tvm.testing.known_failing_targets("cuda")
 @tvm.testing.parametrize_targets
 def test_qlinearconv(target, dev):