You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/04/10 21:46:11 UTC
[incubator-tvm] branch master updated: [RELAY][FRONTEND][CAFFE2] add Mul and ConvTranspose operator (#5302)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 575d536 [RELAY][FRONTEND][CAFFE2] add Mul and ConvTranspose operator (#5302)
575d536 is described below
commit 575d53698e41aef85360d3445f478df2a8f2a9a2
Author: Huacong Yang <wi...@rock-chips.com>
AuthorDate: Sat Apr 11 05:46:03 2020 +0800
[RELAY][FRONTEND][CAFFE2] add Mul and ConvTranspose operator (#5302)
---
python/tvm/relay/frontend/caffe2.py | 35 +++++++++++++++++++++++++++++++++++
1 file changed, 35 insertions(+)
diff --git a/python/tvm/relay/frontend/caffe2.py b/python/tvm/relay/frontend/caffe2.py
index f4fcd92..8a5803f 100644
--- a/python/tvm/relay/frontend/caffe2.py
+++ b/python/tvm/relay/frontend/caffe2.py
@@ -172,6 +172,12 @@ class Add(Elemwise):
name = 'add'
+class Mul(Elemwise):
+ """ Operator converter for Caffe2 Mul (element-wise product).
+ Like Add (name = 'add'), only `name` is set; the Elemwise base presumably uses it to pick the Relay op - confirm. """
+ name = 'multiply'  # Relay op name for Caffe2 'Mul' (registered as 'Mul' in _get_convert_map)
+
+
class Pool(Caffe2OpConverter):
""" A helper class for pool op converters.
"""
@@ -233,6 +239,33 @@ class Conv(Caffe2OpConverter):
return out
+class ConvTranspose(Caffe2OpConverter):
+ """ Operator converter for Caffe2 ConvTranspose (transposed convolution).
+ inputs[:2] (presumably data, weight) feed the conv op; an optional inputs[2] is added as bias. """
+
+ @classmethod
+ def _impl(cls, inputs, args, params):
+ # channels are inferred from the weight (inputs[1]); the True flag presumably selects the transposed-kernel layout - confirm in infer_channels
+ channels = infer_channels(inputs[1], True)
+ args['channels'] = channels
+ _clean_up_pool_args(args)  # return value unused, so it presumably normalizes args in place - confirm
+ out = AttrCvt(
+ op_name=dimension_picker('conv', '_transpose'),  # presumably resolves to e.g. 'conv2d_transpose' from the kernel rank - confirm
+ transforms={  # NOTE(review): no 'group' -> 'groups' entry; confirm how a caffe2 'group' arg is handled
+ 'kernel_shape': 'kernel_size',
+ 'pads': ('padding', (0, 0), revert_caffe2_pad),  # caffe2 pads converted by revert_caffe2_pad; default (0, 0)
+ 'dilations': ('dilation', (1, 1)),
+ 'order': ('data_layout', ("NCHW"), lambda x: x if isinstance(x, str) else x.decode('UTF-8')),  # 'order' may arrive as bytes; decoded to str, default "NCHW"
+ },
+ excludes=[],
+ ignores=_caffe2_internal_args,  # caffe2-internal args are not forwarded (assumed AttrCvt semantics)
+ custom_check=dimension_constraint())(inputs[:2], args, params)  # only data and weight go to the conv; bias is applied below
+ use_bias = len(inputs) == 3  # an optional third input is the bias
+ if use_bias:
+ out = _op.nn.bias_add(out, inputs[2])  # NOTE(review): no axis given, so Relay's default (axis=1, NCHW channel axis) applies; confirm for order='NHWC'
+ return out
+
+
class Concat(Caffe2OpConverter):
""" Operator converter for Concat.
"""
@@ -353,12 +386,14 @@ def _get_convert_map():
# caffe2 common operators
'Add': Add.get_converter(),
'Sum': Sum.get_converter(),
+ 'Mul': Mul.get_converter(),
'Softmax': Softmax.get_converter(),
# nn
'AveragePool': AveragePool.get_converter(),
'MaxPool': MaxPool.get_converter(),
'Conv': Conv.get_converter(),
+ 'ConvTranspose': ConvTranspose.get_converter(),
'Concat': Concat.get_converter(),
'FC': FC.get_converter(),
'SpatialBN': SpatialBN.get_converter(),