Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/12/23 12:00:21 UTC

[GitHub] [tvm] KJlaccHoeUM9l opened a new pull request, #13654: [ONNX] Add converter for QAttention from Microsoft onnxruntime contrib opset

KJlaccHoeUM9l opened a new pull request, #13654:
URL: https://github.com/apache/tvm/pull/13654

   This PR adds support for [QAttention](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QAttention), the quantized version of [Attention](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention) from the Microsoft onnxruntime contrib opset.
   An explanation and illustration of how this layer works can be found, for example, in @lena-voita's [NLP course](https://lena-voita.github.io/nlp_course/seq2seq_and_attention.html).
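
For readers who prefer code to prose, here is a minimal single-head sketch of scaled dot-product attention in plain Python. It is only an illustration of the float computation that, as far as the operator description goes, QAttention performs once its inputs are dequantized; the function names `softmax` and `attention` are hypothetical and are not part of TVM or onnxruntime. Masking, multiple heads, and the `past` state are deliberately left out.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, k, v):
    """Single-head attention; q, k, v are lists of rows of shape (seq, head_size)."""
    head_size = len(q[0])
    scale = 1.0 / math.sqrt(head_size)
    out = []
    for q_row in q:
        # Similarity of this query with every key, scaled by 1/sqrt(head_size).
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        weights = softmax(scores)
        # Weighted sum of the value rows.
        out.append(
            [sum(w * v_row[j] for w, v_row in zip(weights, v)) for j in range(len(v[0]))]
        )
    return out


# A zero query scores every key equally, so the output is the mean of the values.
result = attention([[0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
```

In the converter, Q, K, and V come from one fused weight of shape (in_hidden, 3 * out_hidden) that is split into three equal parts, which is why the shape checks require the last dimension to be divisible by 3.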


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] KJlaccHoeUM9l commented on a diff in pull request #13654: [ONNX] Add converter for QAttention from Microsoft onnxruntime contrib opset

Posted by GitBox <gi...@apache.org>.
KJlaccHoeUM9l commented on code in PR #13654:
URL: https://github.com/apache/tvm/pull/13654#discussion_r1058146999


##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -1379,6 +1379,298 @@ def massage(tensor):
         return _expr.TupleWrapper(_expr.Tuple([output, present]), 2)
 
 
+class QAttention(OnnxOpConverter):
+    """Operator converter for QAttention from Microsoft onnxruntime contrib opset.
+
+    This is the self-attention mechanism used in transformer models.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        # ************************* Read attrs *************************
+        num_heads = attr["num_heads"]
+        unidirectional = attr["unidirectional"]
+
+        # ************************* Read inputs *************************
+        # (batch, seq, in_hidden)
+        input_emb = inputs[0]
+
+        # (in_hidden, 3 * out_hidden), where out_hidden = num_heads * head_size
+        weight = inputs[1]
+
+        # (3 * out_hidden,)
+        bias = inputs[2]
+
+        # Scalar, which means a per-tensor/layer quantization
+        input_scale = inputs[3]
+
+        # Scalar or a 1D tensor, which means a per-tensor/per-column quantization.
+        # Its size should be 3 * out_hidden if it is per-column quantization
+        weight_scale = inputs[4]
+
+        # TODO(agladyshev):
+        #  ORT documentation says that shape is (batch,),
+        #  but the ORT source code has the following comment:
+        #       1. (batch_size)
+        #       2. (2 * batch_size)
+        #       3. (batch_size, 1)
+        #       4. (1, 1)
+        #       5. (batch_size, past_sequence_length + sequence_length)
+        #  In practice, for GPT-2 the shape is (batch, past_seq_length + seq_length).
+        #  Currently only (batch, past_seq_length + seq_length) shape is supported.
+        mask_index = inputs[5]
+
+        # Scalar, which means a per-tensor/layer quantization

Review Comment:
   done



##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -1379,6 +1379,298 @@ def massage(tensor):
         return _expr.TupleWrapper(_expr.Tuple([output, present]), 2)
 
 
+class QAttention(OnnxOpConverter):
+    """Operator converter for QAttention from Microsoft onnxruntime contrib opset.
+
+    This is the self-attention mechanism used in transformer models.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        # ************************* Read attrs *************************
+        num_heads = attr["num_heads"]
+        unidirectional = attr["unidirectional"]
+
+        # ************************* Read inputs *************************
+        # (batch, seq, in_hidden)
+        input_emb = inputs[0]
+
+        # (in_hidden, 3 * out_hidden), where out_hidden = num_heads * head_size
+        weight = inputs[1]
+
+        # (3 * out_hidden,)
+        bias = inputs[2]
+
+        # Scalar, which means a per-tensor/layer quantization
+        input_scale = inputs[3]
+
+        # Scalar or a 1D tensor, which means a per-tensor/per-column quantization.
+        # Its size should be 3 * out_hidden if it is per-column quantization
+        weight_scale = inputs[4]
+
+        # TODO(agladyshev):
+        #  ORT documentation says that shape is (batch,),
+        #  but the ORT source code has the following comment:
+        #       1. (batch_size)
+        #       2. (2 * batch_size)
+        #       3. (batch_size, 1)
+        #       4. (1, 1)
+        #       5. (batch_size, past_sequence_length + sequence_length)
+        #  In practice, for GPT-2 the shape is (batch, past_seq_length + seq_length).
+        #  Currently only (batch, past_seq_length + seq_length) shape is supported.
+        mask_index = inputs[5]
+
+        # Scalar, which means a per-tensor/layer quantization
+        input_zero_point = inputs[6]
+
+        # Scalar or a 1D tensor, which means a per-tensor/per-column quantization.
+        # Its size should be 3 * out_hidden if it is per-column quantization
+        weight_zero_point = inputs[7]
+
+        # (2, batch, num_heads, past_seq, head_size)
+        past = inputs[8]
+
+        # ************************* Parse inputs *************************
+        t1 = ["int8", "uint8"]
+        t2 = ["int8", "uint8"]
+        t3 = ["float32", "float16"]
+        t4 = ["int32"]
+
+        # input
+        assert infer_type(input_emb).checked_type.dtype in t1
+        assert (
+            len(infer_shape(input_emb)) == 3
+        ), "Input should be 3D tensor with shape (batch_size, sequence_length, input_hidden_size)"
+        (batch_size, seq_len, input_hidden) = infer_shape(input_emb)
+        assert input_hidden > 0, (
+            "The weight tensor has (input_hidden_size, 3 * output_hidden_size) shape, so it doesn't"
+            f" make sense to have ({input_hidden}, 3 * output_hidden_size) weight tensor."
+        )
+        assert seq_len > 0, (
+            "The output tensor has (batch_size, sequence_length, hidden_size) shape,"
+            f" so it doesn't make sense to have (batch_size, {seq_len}, hidden_size) output."
+        )
+
+        # weight
+        assert infer_type(weight).checked_type.dtype in t2
+        assert len(infer_shape(weight)) == 2, (
+            "Weight should be 2D input tensor with shape (input_hidden_size, 3 * hidden_size), "
+            "hidden_size = num_heads * head_size"
+        )
+        (input_hidden_weight, out_hidden_x3) = infer_shape(weight)
+        assert input_hidden == input_hidden_weight
+        assert out_hidden_x3 % 3 == 0, "output hidden shape should be divisible by 3: W_Q, W_K, W_V"
+        out_hidden = out_hidden_x3 // 3
+        assert (
+            out_hidden % num_heads == 0
+        ), "output hidden size should be divisible by number of attention heads"
+        head_size = out_hidden // num_heads
+
+        # bias
+        assert infer_type(bias).checked_type.dtype in t3
+        assert (
+            len(infer_shape(bias)) == 1
+        ), "Bias should be 1D input tensor with shape (3 * hidden_size)"
+        (out_hidden_x3_bias,) = infer_shape(bias)
+        assert out_hidden_x3 == out_hidden_x3_bias
+
+        # input_scale
+        assert infer_type(input_scale).checked_type.dtype in t3
+        input_scale = get_scalar(
+            input_scale, params, dtype=infer_type(input_scale).checked_type.dtype
+        )
+
+        # weight_scale
+        assert infer_type(weight_scale).checked_type.dtype in t3
+        # TODO(agladyshev): now QNN Batch Matmul only supports scalar types for scale and zero_point
+        weight_scale = get_scalar(
+            weight_scale, params, dtype=infer_type(weight_scale).checked_type.dtype
+        )
+
+        # mask_index
+        assert (
+            mask_index is not None
+        ), "Attention import currently only supports required mask_index"
+        assert infer_type(mask_index).checked_type.dtype in t4
+        mask_index_shape = infer_shape(mask_index)
+        assert (
+            len(mask_index_shape) == 2
+            and mask_index_shape[0] == batch_size
+            and mask_index_shape[1] >= seq_len
+        ), "currently only support (batch_size, sequence_length) mask index"
+
+        # TODO(agladyshev): int32 required for qnn.batch_matmul (QnnBatchMatmulRel)
+        zero_point_zero = _expr.const(0, "int32")
+
+        # input_zero_point
+        if input_zero_point is None:
+            input_zero_point = zero_point_zero
+        else:
+            assert infer_type(input_zero_point).checked_type.dtype in t1
+            # TODO(agladyshev): int32 required for qnn.batch_matmul (QnnBatchMatmulRel)
+            input_zero_point = get_scalar(input_zero_point, params, dtype="int32")
+
+        # weight_zero_point
+        if weight_zero_point is None:
+            weight_zero_point = zero_point_zero
+        else:
+            assert infer_type(weight_zero_point).checked_type.dtype in t2
+            # TODO(agladyshev): int32 required for qnn.batch_matmul (QnnBatchMatmulRel)
+            weight_zero_point = get_scalar(weight_zero_point, params, dtype="int32")
+
+        # past (2, batch_size, num_heads, past_sequence_length, head_size)
+        past_seq_len = 0
+        if past is not None:
+            assert infer_type(past).checked_type.dtype in t3
+            past_shape = infer_shape(past)
+            assert len(past_shape) == 5, "past should be 5D tensor"
+            assert (
+                past_shape[0] == 2
+                and past_shape[1] == batch_size
+                and past_shape[2] == num_heads
+                and past_shape[3] + seq_len == mask_index_shape[1]
+                and past_shape[4] == head_size
+            )
+            past_seq_len = past_shape[3]
+
+        # ************************* Create Relay *************************
+        # Add batch dimension for QNN Batch Matmul
+        weight = _op.expand_dims(weight, 0, num_newaxis=1)
+        weight = _op.concatenate([weight] * batch_size, axis=0)
+
+        # Split weight and biases and do the Matmul
+        w_Q, w_K, w_V = _op.split(weight, 3, axis=-1)
+        b_Q, b_K, b_V = _op.split(bias, 3, axis=-1)
+
+        def qmatmul_dequantize_bias(
+            lhs, rhs, lhs_scale, rhs_scale, lhs_zero_point, rhs_zero_point, bias
+        ):
+            rhs_transposed = _op.transpose(rhs, axes=[0, 2, 1])  # QNN Batch Matmul does: X * Y^T
+            result = _qnn.op.batch_matmul(
+                lhs, rhs_transposed, lhs_zero_point, rhs_zero_point, lhs_scale, rhs_scale
+            )
+            result = _qnn.op.dequantize(
+                result,
+                _op.multiply(lhs_scale, rhs_scale),
+                zero_point_zero,
+                axis=-1,  # TODO(agladyshev): what is 'axis' parameter for?

Review Comment:
   done
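
The `qmatmul_dequantize_bias` helper quoted above dequantizes the integer matmul result with `_op.multiply(lhs_scale, rhs_scale)` and a zero point of 0. A small worked example in plain Python may make the reason visible: once both zero points are subtracted inside the integer matmul, the accumulator only needs to be multiplied by the product of the two scales. All names below are hypothetical helpers for illustration, not TVM API, and the scales are chosen as powers of two so the float comparison is exact.

```python
def matmul(a, b):
    """Plain matrix product of two 2-D lists."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def dequantize(q, scale, zero_point):
    """Per-tensor dequantization: real = scale * (q - zero_point)."""
    return [[scale * (v - zero_point) for v in row] for row in q]


def integer_matmul(a_q, b_q, a_zp, b_zp):
    """Integer accumulator: subtract zero points, then multiply."""
    a_c = [[v - a_zp for v in row] for row in a_q]
    b_c = [[v - b_zp for v in row] for row in b_q]
    return matmul(a_c, b_c)


x_q = [[130, 120], [128, 140]]  # uint8-style activations
w_q = [[3, -2], [1, 4]]         # int8-style weights
x_scale, x_zp = 0.5, 128
w_scale, w_zp = 0.25, 0

# Path 1: integer matmul, then one dequantize with scale product and zero point 0.
acc = integer_matmul(x_q, w_q, x_zp, w_zp)
via_int = dequantize(acc, x_scale * w_scale, 0)

# Path 2: dequantize both operands first, then a float matmul.
via_float = matmul(dequantize(x_q, x_scale, x_zp), dequantize(w_q, w_scale, w_zp))
```

Both paths give the same matrix, which is the identity s_x * (x - z_x) * s_w * (w - z_w) = (s_x * s_w) * (x - z_x) * (w - z_w) applied elementwise inside the sum.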





[GitHub] [tvm] AndrewZhaoLuo commented on pull request #13654: [ONNX] Add converter for QAttention from Microsoft onnxruntime contrib opset

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on PR #13654:
URL: https://github.com/apache/tvm/pull/13654#issuecomment-1366442483

   Apologies, I've been quite sick. I'll try to look at this Thursday.




[GitHub] [tvm] AndrewZhaoLuo merged pull request #13654: [ONNX] Add converter for QAttention from Microsoft onnxruntime contrib opset

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo merged PR #13654:
URL: https://github.com/apache/tvm/pull/13654




[GitHub] [tvm] KJlaccHoeUM9l commented on pull request #13654: [ONNX] Add converter for QAttention from Microsoft onnxruntime contrib opset

Posted by GitBox <gi...@apache.org>.
KJlaccHoeUM9l commented on PR #13654:
URL: https://github.com/apache/tvm/pull/13654#issuecomment-1365292220

   Hello @vvchernov, @echuraev, @AndrewZhaoLuo!
   Could you review this PR?




[GitHub] [tvm] tvm-bot commented on pull request #13654: [ONNX] Add converter for QAttention from Microsoft onnxruntime contrib opset

Posted by GitBox <gi...@apache.org>.
tvm-bot commented on PR #13654:
URL: https://github.com/apache/tvm/pull/13654#issuecomment-1363893513

   Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from [Reviewers](https://github.com/apache/incubator-tvm/blob/master/CONTRIBUTORS.md#reviewers) by @-ing them in a comment.

    * cc @ehsanmok (see [#10317](https://github.com/apache/tvm/issues/10317) for details)

   Generated by [tvm-bot](https://github.com/apache/tvm/blob/main/ci/README.md#github-actions)




[GitHub] [tvm] echuraev commented on a diff in pull request #13654: [ONNX] Add converter for QAttention from Microsoft onnxruntime contrib opset

Posted by GitBox <gi...@apache.org>.
echuraev commented on code in PR #13654:
URL: https://github.com/apache/tvm/pull/13654#discussion_r1058087546


##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -1379,6 +1379,298 @@ def massage(tensor):
         return _expr.TupleWrapper(_expr.Tuple([output, present]), 2)
 
 
+class QAttention(OnnxOpConverter):
+    """Operator converter for QAttention from Microsoft onnxruntime contrib opset.
+
+    This is the self-attention mechanism used in transformer models.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        # ************************* Read attrs *************************
+        num_heads = attr["num_heads"]
+        unidirectional = attr["unidirectional"]
+
+        # ************************* Read inputs *************************
+        # (batch, seq, in_hidden)
+        input_emb = inputs[0]
+
+        # (in_hidden, 3 * out_hidden), where out_hidden = num_heads * head_size
+        weight = inputs[1]
+
+        # (3 * out_hidden,)
+        bias = inputs[2]
+
+        # Scalar, which means a per-tensor/layer quantization
+        input_scale = inputs[3]
+
+        # Scalar or a 1D tensor, which means a per-tensor/per-column quantization.
+        # Its size should be 3 * out_hidden if it is per-column quantization
+        weight_scale = inputs[4]
+
+        # TODO(agladyshev):
+        #  ORT documentation says that shape is (batch,),
+        #  but the ORT source code has the following comment:
+        #       1. (batch_size)
+        #       2. (2 * batch_size)
+        #       3. (batch_size, 1)
+        #       4. (1, 1)
+        #       5. (batch_size, past_sequence_length + sequence_length)
+        #  In practice, for GPT-2 the shape is (batch, past_seq_length + seq_length).
+        #  Currently only (batch, past_seq_length + seq_length) shape is supported.
+        mask_index = inputs[5]
+
+        # Scalar, which means a per-tensor/layer quantization

Review Comment:
   nit: You have exactly the same comment for `inputs[3]` and `inputs[6]`



##########
python/tvm/relay/frontend/onnx.py:
##########
@@ -1379,6 +1379,298 @@ def massage(tensor):
         return _expr.TupleWrapper(_expr.Tuple([output, present]), 2)
 
 
+class QAttention(OnnxOpConverter):
+    """Operator converter for QAttention from Microsoft onnxruntime contrib opset.
+
+    This is the self-attention mechanism used in transformer models.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        # ************************* Read attrs *************************
+        num_heads = attr["num_heads"]
+        unidirectional = attr["unidirectional"]
+
+        # ************************* Read inputs *************************
+        # (batch, seq, in_hidden)
+        input_emb = inputs[0]
+
+        # (in_hidden, 3 * out_hidden), where out_hidden = num_heads * head_size
+        weight = inputs[1]
+
+        # (3 * out_hidden,)
+        bias = inputs[2]
+
+        # Scalar, which means a per-tensor/layer quantization
+        input_scale = inputs[3]
+
+        # Scalar or a 1D tensor, which means a per-tensor/per-column quantization.
+        # Its size should be 3 * out_hidden if it is per-column quantization
+        weight_scale = inputs[4]
+
+        # TODO(agladyshev):
+        #  ORT documentation says that shape is (batch,),
+        #  but the ORT source code has the following comment:
+        #       1. (batch_size)
+        #       2. (2 * batch_size)
+        #       3. (batch_size, 1)
+        #       4. (1, 1)
+        #       5. (batch_size, past_sequence_length + sequence_length)
+        #  In practice, for GPT-2 the shape is (batch, past_seq_length + seq_length).
+        #  Currently only (batch, past_seq_length + seq_length) shape is supported.
+        mask_index = inputs[5]
+
+        # Scalar, which means a per-tensor/layer quantization
+        input_zero_point = inputs[6]
+
+        # Scalar or a 1D tensor, which means a per-tensor/per-column quantization.
+        # Its size should be 3 * out_hidden if it is per-column quantization
+        weight_zero_point = inputs[7]
+
+        # (2, batch, num_heads, past_seq, head_size)
+        past = inputs[8]
+
+        # ************************* Parse inputs *************************
+        t1 = ["int8", "uint8"]
+        t2 = ["int8", "uint8"]
+        t3 = ["float32", "float16"]
+        t4 = ["int32"]
+
+        # input
+        assert infer_type(input_emb).checked_type.dtype in t1
+        assert (
+            len(infer_shape(input_emb)) == 3
+        ), "Input should be 3D tensor with shape (batch_size, sequence_length, input_hidden_size)"
+        (batch_size, seq_len, input_hidden) = infer_shape(input_emb)
+        assert input_hidden > 0, (
+            "The weight tensor has (input_hidden_size, 3 * output_hidden_size) shape, so it doesn't"
+            f" make sense to have ({input_hidden}, 3 * output_hidden_size) weight tensor."
+        )
+        assert seq_len > 0, (
+            "The output tensor has (batch_size, sequence_length, hidden_size) shape,"
+            f" so it doesn't make sense to have (batch_size, {seq_len}, hidden_size) output."
+        )
+
+        # weight
+        assert infer_type(weight).checked_type.dtype in t2
+        assert len(infer_shape(weight)) == 2, (
+            "Weight should be 2D input tensor with shape (input_hidden_size, 3 * hidden_size), "
+            "hidden_size = num_heads * head_size"
+        )
+        (input_hidden_weight, out_hidden_x3) = infer_shape(weight)
+        assert input_hidden == input_hidden_weight
+        assert out_hidden_x3 % 3 == 0, "output hidden shape should be divisible by 3: W_Q, W_K, W_V"
+        out_hidden = out_hidden_x3 // 3
+        assert (
+            out_hidden % num_heads == 0
+        ), "output hidden size should be divisible by number of attention heads"
+        head_size = out_hidden // num_heads
+
+        # bias
+        assert infer_type(bias).checked_type.dtype in t3
+        assert (
+            len(infer_shape(bias)) == 1
+        ), "Bias should be 1D input tensor with shape (3 * hidden_size)"
+        (out_hidden_x3_bias,) = infer_shape(bias)
+        assert out_hidden_x3 == out_hidden_x3_bias
+
+        # input_scale
+        assert infer_type(input_scale).checked_type.dtype in t3
+        input_scale = get_scalar(
+            input_scale, params, dtype=infer_type(input_scale).checked_type.dtype
+        )
+
+        # weight_scale
+        assert infer_type(weight_scale).checked_type.dtype in t3
+        # TODO(agladyshev): now QNN Batch Matmul only supports scalar types for scale and zero_point
+        weight_scale = get_scalar(
+            weight_scale, params, dtype=infer_type(weight_scale).checked_type.dtype
+        )
+
+        # mask_index
+        assert (
+            mask_index is not None
+        ), "Attention import currently only supports required mask_index"
+        assert infer_type(mask_index).checked_type.dtype in t4
+        mask_index_shape = infer_shape(mask_index)
+        assert (
+            len(mask_index_shape) == 2
+            and mask_index_shape[0] == batch_size
+            and mask_index_shape[1] >= seq_len
+        ), "currently only support (batch_size, sequence_length) mask index"
+
+        # TODO(agladyshev): int32 required for qnn.batch_matmul (QnnBatchMatmulRel)
+        zero_point_zero = _expr.const(0, "int32")
+
+        # input_zero_point
+        if input_zero_point is None:
+            input_zero_point = zero_point_zero
+        else:
+            assert infer_type(input_zero_point).checked_type.dtype in t1
+            # TODO(agladyshev): int32 required for qnn.batch_matmul (QnnBatchMatmulRel)
+            input_zero_point = get_scalar(input_zero_point, params, dtype="int32")
+
+        # weight_zero_point
+        if weight_zero_point is None:
+            weight_zero_point = zero_point_zero
+        else:
+            assert infer_type(weight_zero_point).checked_type.dtype in t2
+            # TODO(agladyshev): int32 required for qnn.batch_matmul (QnnBatchMatmulRel)
+            weight_zero_point = get_scalar(weight_zero_point, params, dtype="int32")
+
+        # past (2, batch_size, num_heads, past_sequence_length, head_size)
+        past_seq_len = 0
+        if past is not None:
+            assert infer_type(past).checked_type.dtype in t3
+            past_shape = infer_shape(past)
+            assert len(past_shape) == 5, "past should be 5D tensor"
+            assert (
+                past_shape[0] == 2
+                and past_shape[1] == batch_size
+                and past_shape[2] == num_heads
+                and past_shape[3] + seq_len == mask_index_shape[1]
+                and past_shape[4] == head_size
+            )
+            past_seq_len = past_shape[3]
+
+        # ************************* Create Relay *************************
+        # Add batch dimension for QNN Batch Matmul
+        weight = _op.expand_dims(weight, 0, num_newaxis=1)
+        weight = _op.concatenate([weight] * batch_size, axis=0)
+
+        # Split weight and biases and do the Matmul
+        w_Q, w_K, w_V = _op.split(weight, 3, axis=-1)
+        b_Q, b_K, b_V = _op.split(bias, 3, axis=-1)
+
+        def qmatmul_dequantize_bias(
+            lhs, rhs, lhs_scale, rhs_scale, lhs_zero_point, rhs_zero_point, bias
+        ):
+            rhs_transposed = _op.transpose(rhs, axes=[0, 2, 1])  # QNN Batch Matmul does: X * Y^T
+            result = _qnn.op.batch_matmul(
+                lhs, rhs_transposed, lhs_zero_point, rhs_zero_point, lhs_scale, rhs_scale
+            )
+            result = _qnn.op.dequantize(
+                result,
+                _op.multiply(lhs_scale, rhs_scale),
+                zero_point_zero,
+                axis=-1,  # TODO(agladyshev): what is 'axis' parameter for?

Review Comment:
   Do you still need this todo comment?
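
On the question the TODO itself asks: to my understanding, the `axis` argument of `qnn.op.dequantize` names the channel axis that a 1-D scale or zero point is applied along (per-channel quantization), so with the scalar scale and zero point used in this converter it should make no difference. The plain-Python sketch below only illustrates that behaviour under this assumption; `dequantize` here is a hypothetical stand-in, not the TVM operator.

```python
def dequantize(q, scale, zero_point, axis=-1):
    """Dequantize a 2-D list.

    scale and zero_point are either scalars (per-tensor) or 1-D lists
    indexed along `axis` (per-channel).
    """
    axis = axis % 2  # normalize -1 to 1 for the 2-D case

    def pick(param, i, j):
        if isinstance(param, (int, float)):
            return param  # per-tensor: same value everywhere
        return param[i] if axis == 0 else param[j]

    return [
        [pick(scale, i, j) * (q[i][j] - pick(zero_point, i, j)) for j in range(len(q[0]))]
        for i in range(len(q))
    ]


q = [[2, 4], [6, 8]]

# Scalar scale: the axis argument has no effect.
per_tensor_last = dequantize(q, 0.5, 0, axis=-1)
per_tensor_first = dequantize(q, 0.5, 0, axis=0)

# 1-D scale: the axis decides whether scales follow columns or rows.
per_column = dequantize(q, [0.5, 0.25], 0, axis=-1)
per_row = dequantize(q, [0.5, 0.25], 0, axis=0)
```

Since the converter reduces `weight_scale` and the zero points to scalars via `get_scalar`, only the per-tensor case applies here.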


