Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/09/24 18:57:23 UTC

[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #9102: [Frontend][PaddlePaddle] Add 100+ operators supporting

AndrewZhaoLuo commented on a change in pull request #9102:
URL: https://github.com/apache/tvm/pull/9102#discussion_r715840346



##########
File path: python/tvm/relay/frontend/paddlepaddle.py
##########
@@ -64,26 +129,193 @@ def _get_pad_size(in_size, dilated_kernel_size, stride_size):
     return [pad_before, pad_after]
 
 
+def _dtype_shape_promotion(inputs):
+    """promote data type and shape for list of tensors."""
+
+    dtype_order = ["bool", "int8", "int16", "int32", "int64", "float32", "float64"]
+
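+    # If the inputs mix 0-D scalars with 1-D tensors, expand the scalars
+    # to rank 1 so all inputs share a common rank.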
+    ranks = [len(infer_shape(x)) for x in inputs]
+    if set(ranks) == {0, 1}:
+        for i, r in enumerate(ranks):
+            if r == 0:
+                inputs[i] = _op.expand_dims(inputs[i], axis=0)
+
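+    # Rank each input's dtype by dtype_order and cast everything to the
+    # widest dtype present (a dtype missing from the list raises ValueError).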
+    dtypes = set(dtype_order.index(infer_type(x).checked_type.dtype) for x in inputs)
+    if len(dtypes) == 1:
+        return inputs
+    max_dtype = dtype_order[max(dtypes)]
+    for i, input_op in enumerate(inputs):
+        if infer_type(input_op).checked_type.dtype != max_dtype:
+            inputs[i] = input_op.astype(max_dtype)
+    return inputs
+
+
+def shape_of(x, dtype="int32"):
+    """Get shape of a tensor"""
+
+    ttype = infer_type(x).checked_type
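+    # For a fully static shape, fold it into a constant instead of
+    # emitting a runtime shape_of op.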
+    if not _ty.is_dynamic(ttype):
+        shape = list(ttype.shape)
+        return _expr.const(np.array(shape), dtype)
+    return _op.shape_of(x, dtype)
+
+
+def _infer_value(x, params):
+    """Try running infer_value, and if successful, return the inferred value.
+    Otherwise, return input"""
+
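+    # infer_value succeeds only when x can be constant-folded given
+    # params; otherwise fall back to the symbolic expression.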
+    try:
+        value = infer_value(x, params)
+        return value.numpy().tolist()
+    except Exception:  # pylint: disable=broad-except
+        return x
+
+
+def _convert_dtype_value(val):
+    """converts a Paddle type id to a string."""
+
+    convert_dtype_map = {
+        21: "int8",
+        20: "uint8",
+        6: "float64",
+        5: "float32",
+        4: "float16",
+        3: "int64",
+        2: "int32",
+        1: "int16",
+        0: "bool",
+    }
+    if val not in convert_dtype_map:
+        msg = "Paddle data type value %d is not handled yet." % (val)
+        raise NotImplementedError(msg)
+    return convert_dtype_map[val]
+
+
+def convert_unary_op(g, op, block):
+    """Operator converter for all the activation."""
+
+    op_map = {
+        "isinf_v2": _op.isinf,
+        "isfinite_v2": _op.isfinite,
+        "isnan_v2": _op.isnan,
+    }
+    if op.type in op_map:
+        unary_func = op_map[op.type]
+    else:
+        unary_func = get_relay_op(op.type)
+    out = unary_func(g.get_node(op.input("X")[0]))
+    g.add_node(op.output("Out")[0], out)
+
+
+def convert_addmm(g, op, block):
+    """Operator converter for addmm."""
+
+    input_x = g.get_node(op.input("Input")[0])
+    x = g.get_node(op.input("X")[0])
+    y = g.get_node(op.input("Y")[0])
+
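+    # addmm computes beta * input + alpha * (x @ y); the scaling below is
+    # skipped when a coefficient is statically 1.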
+    alpha = op.attr("Alpha")
+    beta = op.attr("Beta")
+    dtype = block.var(op.output("Out")[0]).dtype
+    dtype = str(dtype).strip().split(".")[1]
+
+    if not isinstance(alpha, _expr.Expr) and alpha != 1:
+        alpha = _expr.const(alpha, dtype)
+        x *= alpha
+
+    if not isinstance(beta, _expr.Expr) and beta != 1:
+        beta = _expr.const(beta, dtype)
+        input_x *= beta
+
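+    # nn.dense computes x @ weight.T, so transposing y first yields the
+    # plain matmul x @ y.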
+    transposed_y = _op.transpose(y, axes=[1, 0])
+    dense_out = _op.nn.dense(x, transposed_y)
+    out = dense_out + input_x
+    g.add_node(op.output("Out")[0], out)
+
+
+def convert_addn(g, op, block):
+    """Operator converter for sum(add_n)."""
+
+    inputs = op.input("X")
+    out = g.get_node(inputs[0])
+    for i in range(1, len(inputs)):
+        out += g.get_node(inputs[i])
+    g.add_node(op.output("Out")[0], out)
+
+
 def convert_arg_max(g, op, block):
     """Operator converter for arg_max."""
 
     axis = op.attr("axis")
     keepdims = op.attr("keepdims")
     flatten = op.attr("flatten")
+    dtype = op.attr("dtype")
+    dtype = _convert_dtype_value(dtype)
 
     x = g.get_node(op.input("X")[0])
     if axis is None or flatten:
         x = _op.reshape(x, [-1])
         out = _op.argmax(x, axis=None, keepdims=True)
     else:
         out = _op.argmax(x, axis=axis, keepdims=keepdims)
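+    # Paddle's arg_max carries a requested output dtype; cast when it
+    # differs from the dtype argmax produced.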
+    if dtype != infer_type(out).checked_type.dtype:
+        out = _op.cast(out, dtype)
+    g.add_node(op.output("Out")[0], out)
+
+
+def convert_arg_min(g, op, block):

Review comment:
       nit: argmin
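
The diff is truncated at convert_arg_min, but it presumably mirrors
convert_arg_max above with _op.argmin swapped in (and this docstring nit
applied). A minimal sketch under that assumption, not the PR's actual body:

    def convert_arg_min(g, op, block):
        """Operator converter for argmin."""

        # Hypothetical reconstruction -- the actual body is truncated in
        # this diff; it is assumed to mirror convert_arg_max above.
        axis = op.attr("axis")
        keepdims = op.attr("keepdims")
        flatten = op.attr("flatten")
        dtype = _convert_dtype_value(op.attr("dtype"))

        x = g.get_node(op.input("X")[0])
        if axis is None or flatten:
            x = _op.reshape(x, [-1])
            out = _op.argmin(x, axis=None, keepdims=True)
        else:
            out = _op.argmin(x, axis=axis, keepdims=keepdims)
        if dtype != infer_type(out).checked_type.dtype:
            out = _op.cast(out, dtype)
        g.add_node(op.output("Out")[0], out)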




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org