You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/04/08 22:19:23 UTC

[GitHub] [tvm] AndrewZhaoLuo commented on a diff in pull request #10949: [ONNX] Add imports for BERT contrib operators

AndrewZhaoLuo commented on code in PR #10949:
URL: https://github.com/apache/tvm/pull/10949#discussion_r846517205


##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -836,6 +837,192 @@ def _impl_v1(cls, inputs, attr, params):
         return Gelu._impl_v1([inp], attr, params)
 
 
+class EmbedLayerNormalization(OnnxOpConverter):
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        input_ids = inputs[0]
+        segment_ids = inputs[1]
+        word_emb = inputs[2]
+        pos_emb = inputs[3]
+        segment_emb = inputs[4]
+        gamma = inputs[5]
+        beta = inputs[6]
+
+        mask = inputs[7]

Review Comment:
   `mask` and `position_ids` are optional, so I believe you need to check the length of `inputs` before indexing into it. According to the docs, the inputs can be a list of length 2-9 (I assume segment_emb is fine even if it's optional).



##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -836,6 +837,192 @@ def _impl_v1(cls, inputs, attr, params):
         return Gelu._impl_v1([inp], attr, params)
 
 
+class EmbedLayerNormalization(OnnxOpConverter):
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        input_ids = inputs[0]
+        segment_ids = inputs[1]
+        word_emb = inputs[2]
+        pos_emb = inputs[3]
+        segment_emb = inputs[4]
+        gamma = inputs[5]
+        beta = inputs[6]
+
+        mask = inputs[7]
+        pos_ids = inputs[8]
+
+        eps = attr["epsilon"] if "epsilon" in attr else 1e-12
+
+        (batch_size, seq_len) = infer_shape(input_ids)
+
+        if segment_ids:
+            assert segment_emb
+
+        if pos_ids is None:
+            pos_ids = _op.const([list(range(seq_len))] * seq_len, dtype="int64")
+
+        word_vec = _op.take(word_emb, input_ids, axis=0)
+        segment_vec = _op.take(segment_emb, segment_ids, axis=0)
+        pos_vec = _op.take(pos_emb, pos_ids, axis=0)
+
+        vec_sum = _op.add(word_vec, pos_vec)
+        if segment_ids:
+            vec_sum = _op.add(vec_sum, segment_vec)
+
+        eps_dtype = infer_type(word_emb).checked_type.dtype
+
+        u, s = _op.mean_variance(vec_sum, axis=-1, keepdims=True)
+        ln = _op.divide(
+            _op.subtract(vec_sum, u),
+            _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
+        )
+        ln = _op.multiply(ln, gamma) + beta
+
+        mask_index = _op.const(np.zeros((batch_size,), dtype="int64"))
+        if mask:
+            # calculate number of words per sentence
+            mask_index = _op.sum(mask, axis=1)
+
+        return _expr.TupleWrapper(_expr.Tuple([ln, mask_index, vec_sum]), 3)
+
+
+class SkipLayerNormalization(OnnxOpConverter):
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        data = inputs[0]
+        skip = inputs[1]
+        gamma = inputs[2]
+        beta = inputs[3]

Review Comment:
   The inputs are a list of length 3-5, so you probably need to check the length before indexing.



##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -836,6 +837,192 @@ def _impl_v1(cls, inputs, attr, params):
         return Gelu._impl_v1([inp], attr, params)
 
 
+class EmbedLayerNormalization(OnnxOpConverter):
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        input_ids = inputs[0]
+        segment_ids = inputs[1]
+        word_emb = inputs[2]
+        pos_emb = inputs[3]
+        segment_emb = inputs[4]
+        gamma = inputs[5]
+        beta = inputs[6]
+
+        mask = inputs[7]
+        pos_ids = inputs[8]
+
+        eps = attr["epsilon"] if "epsilon" in attr else 1e-12
+
+        (batch_size, seq_len) = infer_shape(input_ids)
+
+        if segment_ids:
+            assert segment_emb
+
+        if pos_ids is None:
+            pos_ids = _op.const([list(range(seq_len))] * seq_len, dtype="int64")
+
+        word_vec = _op.take(word_emb, input_ids, axis=0)
+        segment_vec = _op.take(segment_emb, segment_ids, axis=0)
+        pos_vec = _op.take(pos_emb, pos_ids, axis=0)
+
+        vec_sum = _op.add(word_vec, pos_vec)
+        if segment_ids:
+            vec_sum = _op.add(vec_sum, segment_vec)
+
+        eps_dtype = infer_type(word_emb).checked_type.dtype
+
+        u, s = _op.mean_variance(vec_sum, axis=-1, keepdims=True)
+        ln = _op.divide(
+            _op.subtract(vec_sum, u),
+            _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
+        )
+        ln = _op.multiply(ln, gamma) + beta
+
+        mask_index = _op.const(np.zeros((batch_size,), dtype="int64"))
+        if mask:
+            # calculate number of words per sentence
+            mask_index = _op.sum(mask, axis=1)
+
+        return _expr.TupleWrapper(_expr.Tuple([ln, mask_index, vec_sum]), 3)
+
+
+class SkipLayerNormalization(OnnxOpConverter):
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        data = inputs[0]
+        skip = inputs[1]
+        gamma = inputs[2]
+        beta = inputs[3]
+        bias = inputs[4]
+
+        eps = attr["epsilon"] if "epsilon" in attr else 1e-12
+
+        x = _op.add(data, skip)
+        if bias is not None:
+            x = _op.add(x, bias)
+
+        eps_dtype = infer_type(x).checked_type.dtype
+
+        u, s = _op.mean_variance(x, axis=-1, keepdims=True)
+        output = _op.divide(
+            _op.subtract(x, u),
+            _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
+        )
+        output = _op.multiply(output, gamma)
+        if beta:
+            output = _op.add(output, beta)
+
+        placeholder = _op.const(0, dtype="float32")
+
+        return _expr.TupleWrapper(_expr.Tuple([output, placeholder, placeholder]), 3)
+
+
+class Attention(OnnxOpConverter):
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        num_heads = attr["num_heads"]
+        assert (
+            "qkv_hidden_sizes" not in attr
+        ), "different hidden sizes for Q, K, V are not currently supported"
+        assert "unidirectional" not in attr, "unidirectional attention not current supported"
+
+        # (batch, seq, in_hidden)
+        input_emb = inputs[0]
+
+        # (in_hidden, 3 * out_hidden), where out_hidden = num_heads * head_size
+        weight = inputs[1]
+
+        # (3 * out_hidden,)
+        bias = inputs[2]
+
+        # 1. (    batch,              1,        max_seq, max_seq)
+        # 2. (    batch, past_seq + seq,)
+        # 3. (    batch,            seq, past_seq + seq,)
+        # 4. (    batch,)
+        # 5. (2 * batch,)
+        # For now, we only support case 2.
+        mask_index = inputs[3]

Review Comment:
   Same thing here with the input lengths: check the length of `inputs` before indexing.



##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -4737,6 +4924,9 @@ def _get_convert_map(opset):
         "Elu": Elu.get_converter(opset),
         "Gelu": Gelu.get_converter(opset),
         "BiasGelu": BiasGelu.get_converter(opset),
+        "EmbedLayerNormalization": EmbedLayerNormalization.get_converter(opset),

Review Comment:
   If these are all contrib operators, we need a better way to handle different `domains` (see https://github.com/onnx/onnx/blob/main/docs/IR.md#nodes). In this case they are under `com.microsoft` I believe. 
   
   Right now everything is assumed to be under the default namespace.
   
   For now, just add a TODO here noting the namespace these operators belong to, and I will open an issue.



##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -836,6 +837,192 @@ def _impl_v1(cls, inputs, attr, params):
         return Gelu._impl_v1([inp], attr, params)
 
 
+class EmbedLayerNormalization(OnnxOpConverter):
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        input_ids = inputs[0]
+        segment_ids = inputs[1]
+        word_emb = inputs[2]
+        pos_emb = inputs[3]
+        segment_emb = inputs[4]
+        gamma = inputs[5]
+        beta = inputs[6]
+
+        mask = inputs[7]
+        pos_ids = inputs[8]
+
+        eps = attr["epsilon"] if "epsilon" in attr else 1e-12

Review Comment:
   Where did you get these default values?
   
   Also, I suggest using `attr.get('epsilon', 1e-12)`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org