Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/12/15 04:59:29 UTC

[GitHub] [incubator-mxnet] Zha0q1 commented on a change in pull request #19653: [wip] onnx support more ops

Zha0q1 commented on a change in pull request #19653:
URL: https://github.com/apache/incubator-mxnet/pull/19653#discussion_r543047281



##########
File path: python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
##########
@@ -2280,6 +2289,200 @@ def convert_layer_norm(node, **kwargs):
     return nodes
 
 
+def make_tensor(shape_list, shape_name, initializer, dtype='int64'):
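+    """Create a constant tensor named `shape_name` from `shape_list` and
+    append it to the graph's initializer list (int64 by default)."""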
+    import onnx  # lazy import, matching this module's convention for optional deps
+    shape_np = np.array(shape_list, dtype=dtype)
+    data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[shape_np.dtype]
+    dims = np.shape(shape_np)
+    initializer.append(
+        onnx.helper.make_tensor(
+            name=shape_name,
+            data_type=data_type,
+            dims=dims,
+            vals=shape_list,
+            raw=False,
+        )
+    )
+
+
+@mx_op.register("_contrib_interleaved_matmul_selfatt_qk")
+def convert_matmul_selfatt_qk(node, **kwargs):
+    from onnx.helper import make_node
+    from onnx import TensorProto
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    heads = int(attrs.get('heads'))
+
+    # a, b, c, d, e denote seq_len, batch_size, num_heads, 3, and head_dim, respectively
+    make_tensor([0], name+"_0", kwargs["initializer"])
+    make_tensor([1], name+"_1", kwargs["initializer"])
+    make_tensor([1], name+"_1_f", kwargs["initializer"], dtype='float32')
+    make_tensor([2], name+"_2", kwargs["initializer"])
+    make_tensor([3], name+"_3", kwargs["initializer"])
+    make_tensor([heads], name+"_c", kwargs["initializer"])
+    make_tensor([3], name+"_d", kwargs["initializer"])
+
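+    # Graph sketch: slice Q (index 0) and K (index 1) out of the interleaved
+    # QKV projection of shape (a, b, c*3*e), reshape each to (b*c, a, e), and
+    # emit scaled attention scores Q * (1/sqrt(e)) x K^T of shape (b*c, a, a).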
+    nodes = [
+            make_node('Shape', [input_nodes[0]], [name+"_data_shape"]),
+            make_node('Slice', [name+'_data_shape', name+'_0', name+'_1'], [name+"_a"]),
+            make_node('Slice', [name+'_data_shape', name+'_1', name+'_2'], [name+"_b"]),
+            make_node('Slice', [name+'_data_shape', name+'_2', name+'_3'], [name+"_cde"]),
+            make_node('Div', [name+'_cde', name+'_c'], [name+'_de']),
+            make_node('Div', [name+'_de', name+'_d'], [name+'_e']),
+            make_node('Cast', [name+'_e'], [name+'_e_f'], to=int(TensorProto.FLOAT)),
+            make_node('Sqrt', [name+'_e_f'], [name+'_sqrt_e']),
+            make_node('Div', [name+'_1_f', name+'_sqrt_e'], [name+'_1_over_sqrt_e']),
+            make_node('Mul', [name+'_b', name+'_c'], [name+'_bc']),
+
+            make_node("Concat", [name+'_a', name+'_b', name+'_c', name+'_d', name+'_e'], \
+                      [name+'_shape0'], axis=0),
+            make_node("Concat", [name+'_0', name+'_0', name+'_0', name+'_0', name+'_0'], \
+                      [name+'_slice_start0'], axis=0),
+            make_node("Concat", [name+'_a', name+'_b', name+'_c', name+'_1', name+'_e'], \
+                      [name+'_slice_end0'], axis=0),
+            make_node("Concat", [name+'_a', name+'_b', name+'_c', name+'_e'], \
+                      [name+'_shape1'], axis=0),
+            make_node("Concat", [name+'_bc', name+'_a', name+'_e'], \
+                      [name+'_shape2'], axis=0),
+            make_node("Concat", [name+'_0', name+'_0', name+'_0', name+'_1', name+'_0'], \
+                      [name+'_slice_start1'], axis=0),
+            make_node("Concat", [name+'_a', name+'_b', name+'_c', name+'_2', name+'_e'], \
+                      [name+'_slice_end1'], axis=0),
+
+            make_node('Reshape', [input_nodes[0], name+'_shape0'], [name+'_reshape0_out']),
+            make_node('Slice', [name+'_reshape0_out', name+'_slice_start0', name+'_slice_end0'], \
+                      [name+'_slice0_out']),
+            make_node('Reshape', [name+'_slice0_out', name+'_shape1'], [name+'_reshape1_out']),
+            make_node('Transpose', [name+'_reshape1_out'], [name+'_transpose0_out'], \
+                      perm=(1, 2, 0, 3)),
+            make_node('Reshape', [name+'_transpose0_out', name+'_shape2'], [name+'_reshape2_out']),
+            make_node('Mul', [name+'_reshape2_out', name+'_1_over_sqrt_e'], [name+'_mul0_out']),
+            make_node('Slice', [name+'_reshape0_out', name+'_slice_start1', name+'_slice_end1'], \
+                      [name+'_slice1_out']),
+            make_node('Reshape', [name+'_slice1_out', name+'_shape1'], [name+'_reshape3_out']),
+            make_node('Transpose', [name+'_reshape3_out'], [name+'_transpose1_out'], \
+                      perm=(1, 2, 0, 3)),
+            make_node('Reshape', [name+'_transpose1_out', name+'_shape2'], [name+'_reshape4_out']),
+            make_node('Transpose', [name+'_reshape4_out'], [name+'_transpose2_out'], \
+                      perm=(0, 2, 1)),
+            make_node('MatMul', [name+'_mul0_out', name+'_transpose2_out'], [name], name=name)
+        ]
+
+    return nodes
+
+
+@mx_op.register("broadcast_axis")
+def convert_broadcast_axis(node, **kwargs):
+    from onnx.helper import make_node
+    from onnx import TensorProto
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    data_shape_list = list(kwargs['in_shape'][0])
+    axes = convert_string_to_list(attrs.get('axis', '()'))
+    sizes = convert_string_to_list(attrs.get('size', '()'))
+    assert len(axes) == len(sizes)
+
+    make_tensor([0], name+'_0', kwargs["initializer"])
+    make_tensor([1], name+'_1', kwargs["initializer"])
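+    # `_void` is an empty shape tensor: Reshape(x, _void) squeezes a
+    # 1-element tensor to a scalar (used below as the Range limit).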
+    make_tensor([], name+'_void', kwargs["initializer"])
+    create_const_scalar_node(name+'_0_s', np.int64(0), kwargs)
+    create_const_scalar_node(name+'_1_s', np.int64(1), kwargs)
+
+    shape_name = name+'_shape_0'
+    nodes = [
+            make_node('Shape', [input_nodes[0]], [shape_name]),
+            make_node('Shape', [shape_name], [name+'_in_dim']),
+            make_node('Reshape', [name+'_in_dim', name+'_void'], [name+'_in_dim_s']),
+            make_node('Range', [name+'_0_s', name+'_in_dim_s', name+'_1_s'], [name+'_range']),
+        ]
+
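+    # For each (axis, size) pair, compute the next target shape as
+    # in_shape * ((size-1) * one_hot(axis) + 1): the dim at `axis` (which
+    # must be 1 in the input) becomes `size`, all other dims are unchanged.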
+    for i, axis in enumerate(axes):
+        if axis not in (0, 1):
+            make_tensor([axis], name+'_'+str(axis), kwargs["initializer"])
+        make_tensor([sizes[i]-1], name+'_size_'+str(i), kwargs["initializer"])
+        nodes += [
+             # one-hot int64 mask that is 1 at position `axis` of the shape vector
+             make_node('Equal', [name+'_range', name+'_'+str(axis)], [name+'_equal_'+str(i)]),
+             make_node('Cast', [name+'_equal_'+str(i)], [name+'_cast_'+str(i)], to=int(TensorProto.INT64)),
+             make_node('Mul', [name+'_size_'+str(i), name+'_cast_'+str(i)], [name+'_mul_'+str(i)]),
+             make_node('Add', [name+'_mul_'+str(i), name+'_1'], [name+'_add_'+str(i)]),
+             make_node('Mul', [name+'_add_'+str(i), shape_name], [name+'_shape_'+str(i+1)])
+            ]
+        shape_name = name+'_shape_'+str(i+1)
+
+    nodes += [make_node('Expand', [input_nodes[0], shape_name], [name], name=name)]
+
+    return nodes
+
+
+@mx_op.register("_contrib_interleaved_matmul_selfatt_valatt")
+def convert_interleaved_matmul_selfatt_valatt(node, **kwargs):
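+    # WIP: the valatt conversion is not implemented yet; this stub emits no nodes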
+    return []
+
+
+@mx_op.register("SequenceMask")
+def convert_sequencemask(node, **kwargs):
+    from onnx.helper import make_node
+    from onnx import TensorProto
+
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    use_sequence_length = attrs.get('use_sequence_length', 'False')
+    mask_val = float(attrs.get('value', '0'))
+    axis = int(attrs.get('axis', '0'))
+
+    if use_sequence_length == 'False':
+        return [make_node('Identity', [input_nodes[0]], [name], name=name)]
+
+    make_tensor([], name+'_void', kwargs["initializer"])
+    make_tensor([0], name+'_0', kwargs["initializer"])
+    make_tensor([1], name+'_1', kwargs["initializer"])
+    make_tensor([2], name+'_2', kwargs["initializer"])
+    create_const_scalar_node(name+'_0_s', np.int64(0), kwargs)
+    create_const_scalar_node(name+'_1_s', np.int64(1), kwargs)
+    create_const_scalar_node(name+'_2_s', np.int64(2), kwargs)
+    make_tensor([mask_val], name+'_mask_val', kwargs["initializer"], dtype='float32')
+
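+    # _shape_0 = (in_shape[0], 1); _shape_1 keeps the first two dims of
+    # in_shape and sets all remaining dims to 1 (broadcastable mask shapes).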
+    nodes = [
+        make_node('Shape', [input_nodes[0]], [name+'_in_shape']),
+        make_node('Slice', [name+'_in_shape', name+'_0', name+'_1'], [name+'_slice_0']),
+        make_node('Slice', [name+'_in_shape', name+'_1', name+'_2'], [name+'_slice_1']),
+        make_node('Concat', [name+'_slice_0', name+'_1'], [name+'_shape_0'], axis=0),
+        make_node('Shape', [name+'_in_shape'], [name+'_in_dim']),
+        make_node('Reshape', [name+'_in_dim', name+'_void'], [name+'_in_dim_s']),
+        make_node('Range', [name+'_0_s', name+'_in_dim_s', name+'_1_s'], [name+'_range_0']),
+        make_node('Less', [name+'_range_0', name+'_2'], [name+'_less_0']),
+        make_node('Where', [name+'_less_0', name+'_in_shape', name+'_1'], [name+'_shape_1'])
+        ]
+
+    if axis == 0:
+        nodes += [
+            make_node('Reshape', [name+'_slice_0', name+'_void'], [name+'_max_len'], name = '111'),

Review comment:
       Ohhh sorry, those were leftovers from when I was debugging. Will remove them in the next push.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org