Posted to commits@tvm.apache.org by ma...@apache.org on 2022/04/15 05:14:41 UTC

[tvm] branch main updated: [RELAY][FRONTEND] Initial OneFlow frontend support. (#8790)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 29774bddd8 [RELAY][FRONTEND] Initial OneFlow frontend support.  (#8790)
29774bddd8 is described below

commit 29774bddd8a61643b7836e212198a209619735f4
Author: JiaKui Hu <hj...@163.com>
AuthorDate: Fri Apr 15 13:14:35 2022 +0800

    [RELAY][FRONTEND] Initial OneFlow frontend support.  (#8790)
    
    * add relay.f.frontend.fm_oneflow support cnns
    
    * support cuda
    
    * fix mobilenetv2 and reviews
    
    * fix: model without meta info
    
    * support eager and yolo, add test
    
    * fix: license
    
    * add: tutorials
    
    * fix: support new graph
    
    * fix some comments
    
    * refine
    
    * fix concat op convert bug
    
    * refine
    
    * refine
    
    * change cuda to cpu
    
    * fix bug
    
    * fix ci error in tvm
    
    * fix pylint check
    
    * delete useless file
    
    * add skimage package in docker
    
    * fix ci error
    
    * fix bug
    
    * add oneflow fronted test in ci
    
    * merge conflict
    
    * fix tutorial
    
    * try to find error in ci
    
    * revert
    
    * merge conflict
    
    * black oneflow
    
    * Delete from_oneflow.py
    
    Co-authored-by: Xiaoyu Zhang <35...@users.noreply.github.com>
    Co-authored-by: BBuf <11...@qq.com>
---
 python/tvm/relay/frontend/__init__.py         |    1 +
 python/tvm/relay/frontend/oneflow.py          | 1821 +++++++++++++++++++++++++
 tests/python/frontend/oneflow/test_forward.py |  723 ++++++++++
 tests/scripts/task_python_frontend.sh         |    3 +
 4 files changed, 2548 insertions(+)
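
For readers skimming the diff, here is a rough usage sketch of the new frontend. It is
illustrative only: the model class, graph wrapper, and paths below are hypothetical, the
oneflow calls loosely follow the pattern used in tests/python/frontend/oneflow/test_forward.py
(added by this commit), and the exact arguments of relay.frontend.from_oneflow (the compiled
graph plus the directory holding the saved parameters) should be checked against
python/tvm/relay/frontend/oneflow.py.

    import oneflow as flow
    from tvm import relay

    class InferGraph(flow.nn.Graph):
        # Thin oneflow.nn.Graph wrapper around an eager module.
        def __init__(self, module):
            super().__init__()
            self.m = module

        def build(self, x):
            return self.m(x)

    model = MyModel().eval()                       # hypothetical eager oneflow module
    flow.save(model.state_dict(), "./model_dir")   # parameters the frontend reads back
    graph = InferGraph(model)
    graph._compile(flow.randn(1, 3, 224, 224))     # trace the graph once

    # Convert the compiled graph and the saved parameters to Relay.
    mod, params = relay.frontend.from_oneflow(graph, "./model_dir")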

diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py
index aa49b63203..fbbd4f9921 100644
--- a/python/tvm/relay/frontend/__init__.py
+++ b/python/tvm/relay/frontend/__init__.py
@@ -23,6 +23,7 @@ for Relay.
 from .mxnet import from_mxnet
 from .mxnet_qnn_op_utils import quantize_conv_bias_mkldnn_from_var
 from .keras import from_keras
+from .oneflow import from_oneflow
 from .onnx import from_onnx
 from .tflite import from_tflite
 from .coreml import from_coreml
diff --git a/python/tvm/relay/frontend/oneflow.py b/python/tvm/relay/frontend/oneflow.py
new file mode 100644
index 0000000000..c15b7b3c24
--- /dev/null
+++ b/python/tvm/relay/frontend/oneflow.py
@@ -0,0 +1,1821 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines
+# pylint: disable=import-outside-toplevel
+"""OneFlow: OneFlow is a performance-centered and open-source deep learning framework."""
+
+import os
+import re
+import copy
+import warnings
+
+import numpy as np
+import tvm
+from tvm.ir import IRModule
+from tvm.topi.utils import get_const_tuple
+
+from .. import analysis
+from .. import expr as _expr
+from .. import function as _function
+from .. import op as _op
+from .. import ty as _ty
+from .common import (
+    AttrCvt,
+    Renamer,
+    fold_constant,
+    get_relay_op,
+    infer_channels,
+    infer_shape,
+    infer_type,
+    new_var,
+)
+
+__all__ = ["from_oneflow"]
+
+FLOW_2_STR_DTYPE = {
+    2: "float32",
+    3: "float64",
+    6: "int64",
+    5: "int32",
+    4: "int8",
+    7: "uint8",
+    9: "float16",
+}
+
+
+def is_input_op(node):
+    """Return true when the node is the input of the graph."""
+    return node.WhichOneof("op_type") == "input_conf"
+
+
+def is_user_op(node):
+    """Return true when the node is the intermediate variables of graph."""
+    return node.WhichOneof("op_type") == "user_conf"
+
+
+def is_output_op(node):
+    """Return true when the node is the output of the graph."""
+    return node.WhichOneof("op_type") == "output_conf"
+
+
+def is_param_op(node):
+    """Return true when the node is the intermediate variables of model(saved)."""
+    return node.WhichOneof("op_type") == "variable_conf"
+
+
+def get_node_info(node):
+    """
+    Get basic information about nodes: shape, data_type
+    """
+    # list->tuple
+    shape = tuple(node.input_conf.blob_conf.shape.dim)
+    # get data type
+    dtype = node.input_conf.blob_conf.data_type
+    if dtype in FLOW_2_STR_DTYPE:
+        data_type = FLOW_2_STR_DTYPE[dtype]
+    else:
+        raise IndexError("Please check the data type of your node: %s" % node.name)
+
+    return shape, data_type
+
+
+def _dtype_shape_promotion(inputs):
+    """Promote data type and shape for list of tensors."""
+
+    dtype_order = ["bool", "int8", "int16", "int32", "int64", "float32", "float64"]
+
+    ranks = [len(infer_shape(x)) for x in inputs]
+    if set(ranks) == set([1, 0]):
+        for i, r in enumerate(ranks):
+            if r == 0:
+                inputs[i] = _op.expand_dims(inputs[i], axis=0)
+
+    dtypes = set(dtype_order.index(infer_type(x).checked_type.dtype) for x in inputs)
+    if len(dtypes) == 1:
+        return inputs
+    max_dtype = dtype_order[max(dtypes)]
+    for i, input_op in enumerate(inputs):
+        if infer_type(input_op).checked_type.dtype != max_dtype:
+            inputs[i] = input_op.astype(max_dtype)
+    return inputs
+
+
+def parse_attr(attr):
+    """Parse attribute of user op in oneflow."""
+    attrs = {}
+    for a in attr:
+        attr_str = str(attr[a])
+
+        if attr_str[0:7] == "at_list":
+            attr_str_ = attr_str.split(" ")[0]
+
+            if attr_str_ == "at_list_float":
+                attrs[a] = tuple(attr[a].at_list_float.val)
+            elif attr_str_ == "at_list_int32":
+                attrs[a] = tuple(attr[a].at_list_int32.val)
+            elif attr_str_ == "at_list_int64":
+                attrs[a] = tuple(attr[a].at_list_int64.val)
+
+        elif attr_str.split(":")[0] == "at_string":
+            attrs[a] = attr[a].at_string
+
+        elif attr_str.split(" ")[0] == "at_shape":
+            attrs[a] = tuple(list(attr[a].at_shape.dim))
+
+        else:
+            attr_str_ = attr_str.split(":")[0]
+            if attr_str_ == "at_bool":
+                attrs[a] = attr[a].at_bool
+            elif attr_str_ == "at_double":
+                attrs[a] = attr[a].at_double
+            elif attr_str_ == "at_float":
+                attrs[a] = attr[a].at_float
+            elif attr_str_ == "at_int32":
+                attrs[a] = attr[a].at_int32
+            elif attr_str_ == "at_int64":
+                attrs[a] = attr[a].at_int64
+
+    return attrs
+
+
+def shape_of(x, dtype="int64"):
+    ttype = infer_type(x).checked_type
+    if not _ty.is_dynamic(ttype):
+        shape = list(ttype.shape)
+        return _expr.const(shape, dtype)
+
+    return _op.shape_of(x, dtype)
+
+
+def dimension_constraint():
+    def _dim_check(attrs):
+        if len(attrs["kernel_size"]) in [1, 2, 3]:
+            return True
+        return False
+
+    return _dim_check, "Only 1d, 2d and 3d kernel supported."
+
+
+class OneFlowOpConverter(object):
+    """A helper class for holding oneflow op converters."""
+
+    @classmethod
+    def get_converter(cls):
+        """
+        Get the converter matching the given opset.
+        Parameters
+        ----------
+        None
+
+        Returns
+        -------
+        converter, which should be `_impl_vx`.
+        """
+        version = 1
+        if hasattr(cls, "_impl_v{}".format(version)):
+            return getattr(cls, "_impl_v{}".format(version))
+        raise NotImplementedError("version {} of {} not implemented".format(version, cls.__name__))
+
+
+class Pool(OneFlowOpConverter):
+    """A helper class for pool op converters."""
+
+    name = ""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        data = inputs[0]
+        attrs.pop("data_format")
+        out = AttrCvt(
+            op_name=cls.name,
+            transforms={
+                "kernel_size": "pool_size",
+                "stride": "strides",
+                "dilations": ("dilation", 1),
+            },
+            ignores=["return_indices", "divisor_override"],
+            custom_check=dimension_constraint(),
+        )([data], attrs, params)
+
+        return out
+
+
+class AdaptiveAvgPool2d(OneFlowOpConverter):
+    """Operator converter for AdaptiveAvgPool2d"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        return _op.nn.adaptive_avg_pool2d(inputs[0], output_size=attrs["output_size"])
+
+
+class AdaptiveMaxPool2d(OneFlowOpConverter):
+    """Operator converter for AdaptiveMaxPool2d"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        return _op.nn.adaptive_max_pool2d(inputs[0], output_size=attrs["output_size"])
+
+
+class GlobalAveragePool(OneFlowOpConverter):
+    """Operator converter for GlobalAveragePool"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        rank = len(infer_shape(inputs[0]))
+        if rank == 3:
+            return _op.nn.global_avg_pool1d(inputs[0])
+        if rank == 4:
+            return _op.nn.global_avg_pool2d(inputs[0])
+        if rank == 5:
+            return _op.nn.global_avg_pool3d(inputs[0])
+        raise NotImplementedError(
+            "Global average pooling is only implemented for 1D, 2D, and 3D kernels, got %dD."
+            % (rank - 2),
+        )
+
+
+class GlobalMaxPool(OneFlowOpConverter):
+    """Operator converter for GlobalMaxPool"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        rank = len(infer_shape(inputs[0]))
+        if rank == 3:
+            return _op.nn.global_max_pool1d(inputs[0])
+        if rank == 4:
+            return _op.nn.global_max_pool2d(inputs[0])
+        if rank == 5:
+            return _op.nn.global_max_pool3d(inputs[0])
+        raise NotImplementedError(
+            "Global max pooling is only implemented for 1D, 2D, and 3D kernels, got %dD."
+            % (rank - 2),
+        )
+
+
+class Conv(OneFlowOpConverter):
+    """A helper class for conv op converters."""
+
+    name = ""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        # The kernel is loaded from model_dir_path; its name contains ".weight".
+        # The data comes from the graph itself; its name contains "-input_".
+        in_names = ["-input_"]
+        kernel_names = [".weight"]
+        for i in inputs:
+            IN_NAMES = any(x in str(i) for x in in_names)
+            KERNEL_NAMES = any(x in str(i) for x in kernel_names)
+            if IN_NAMES:
+                data = i
+            elif KERNEL_NAMES:
+                kernel = i
+            else:
+                data = i
+
+        # Use shape of input to determine convolution type.
+        kernel_type = infer_type(kernel)
+        kernel_shapes = [get_const_tuple(kernel_type.checked_type.shape)]
+
+        if "kernel_size" not in attrs:
+            attrs["kernel_size"] = kernel_shapes[0][2:]
+        if "dilation_rate" in attrs:
+            attrs["dilation"] = list(attrs["dilation_rate"])
+            attrs.pop("dilation_rate")
+
+        pad_v = attrs.get("padding_before", [0, 0])
+        attrs["padding"] = [pad_v[0], pad_v[1], pad_v[0], pad_v[1]]
+
+        group_conv1d = False
+        if cls.name == "conv1d" and attrs.get("groups") != 1:
+            group_conv1d = True
+            # Expand input from NCW to NCHW
+            data = _op.expand_dims(data, axis=2)
+            # Expand kernel from OIW to OIHW
+            kernel = _op.expand_dims(kernel, axis=2)
+            # Add new value to kernel_size, strides, dilations, pads, if needed
+            attrs["kernel_size"] = [1] + list(attrs["kernel_size"])
+            if "strides" in attrs:
+                attrs["strides"] = [1] + list(attrs["strides"])
+            if "dilations" in attrs:
+                attrs["dilation"] = [1] + list(attrs["dilations"])
+
+        out = AttrCvt(
+            op_name=cls.name,
+            transforms={
+                "group": ("groups", 1),
+            },
+            ignores=["data_format", "filters", "padding_after", "padding_before"],
+            custom_check=dimension_constraint(),
+        )([data, kernel], attrs, params)
+
+        # If this was a group_conv1d, squish output back to NCW.
+        if group_conv1d:
+            out = _op.squeeze(out, axis=[2])
+
+        return out
+
+
+class ConvTranspose(OneFlowOpConverter):
+    """Operator converter for ConvTranspose."""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        in_names = ["-input_"]
+        kernel_names = [".weight"]
+        for i in inputs:
+            IN_NAMES = any(x in str(i) for x in in_names)
+            KERNEL_NAMES = any(x in str(i) for x in kernel_names)
+            if IN_NAMES:
+                data = i
+            elif KERNEL_NAMES:
+                kernel = i
+            else:
+                data = i
+
+        # get number of channels
+        attrs["channels"] = attrs.get("filters", 1)
+        attrs["groups"] = attrs.get("group", 1)
+
+        kernel_type = infer_type(kernel)
+        kernel_shapes = [get_const_tuple(kernel_type.checked_type.shape)]
+
+        if "kernel_size" not in attrs:
+            attrs["kernel_size"] = kernel_shapes[0][2:]
+
+        if "dilation_rate" in attrs:
+            attrs["dilation"] = list(attrs["dilation_rate"])
+            attrs.pop("dilation_rate")
+
+        pad_v = attrs.get("padding_before", [0, 0])
+        attrs["padding"] = [pad_v[0], pad_v[1], pad_v[0], pad_v[1]]
+
+        out = AttrCvt(
+            op_name=cls.name,
+            transforms={
+                "group": ("groups", 1),
+            },
+            disables=["filters", "data_format", "padding_before"],
+            custom_check=dimension_constraint(),
+        )([data, kernel], attrs, params)
+
+        return out
+
+
+class Upsample(OneFlowOpConverter):
+    """A helper class for upsample op converters"""
+
+    name = ""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        data = inputs[0]
+        input_shape = infer_shape(data)
+        dims = len(input_shape)
+
+        width_scale = attrs.get("width_scale", 1.0)
+        height_scale = attrs.get("height_scale", 1.0)
+        align_corners = attrs.get("align_corners", False)
+
+        if "nearest" in cls.name:
+            method = "nearest_neighbor"
+        elif "trilinear" in cls.name:
+            method = "trilinear"
+        elif "bilinear" in cls.name:
+            method = "bilinear"
+
+        # in 3d case, we use the purely static op
+        if dims == 5:
+            if isinstance(scales, _expr.Expr):
+                scale_h = _op.take(scales, _op.const(3))
+                scale_w = _op.take(scales, _op.const(4))
+                scale_d = _op.take(scales, _op.const(1))
+            else:
+                assert len(scales) == 5
+                scale_h = scales[-2]
+                scale_w = scales[-1]
+                scale_d = scales[-3]
+
+            layout = "NCDHW"
+            out = _op.nn.upsampling3d(
+                data,
+                scale_d,
+                scale_h,
+                scale_w,
+                layout=layout,
+                method=method,
+                coordinate_transformation_mode="asymmetric",
+            )
+        # in 2d case, use dynamic op
+        else:
+            if isinstance(height_scale, _expr.Expr):
+                height_scale = _op.take(height_scale, _op.const(3))
+                width_scale = _op.take(width_scale, _op.const(4))
+            layout = "NCHW"
+
+            out = _op.nn.upsampling(
+                inputs[0],
+                height_scale,
+                width_scale,
+                layout=layout,
+                method=method,
+                align_corners=align_corners,
+            )
+        return out
+
+
+class UpsampleNearest(Upsample):
+    """Operator converter for Upsample Nearest"""
+
+    name = "upsample_nearest"
+
+
+class UpsampleBiLinear(Upsample):
+    """Operator converter for Upsample Bilinear"""
+
+    name = "upsample_bilinear"
+
+
+class Conv2d(Conv):
+    """Operator converter for Conv2d"""
+
+    name = "conv2d"
+
+
+class ConvTranspose2d(ConvTranspose):
+    """Operator converter for ConvTranspose2d"""
+
+    name = "conv2d_transpose"
+
+
+class BatchNorm(OneFlowOpConverter):
+    """Operator converter for BatchNorm"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        # sort the inputs
+        sorted_inputs = copy.deepcopy(inputs)
+        for i in inputs:
+            IN_NAMES = "-input_" in str(i)
+            if IN_NAMES:
+                sorted_inputs[0] = i
+            elif "weight" in str(i) and not IN_NAMES:
+                sorted_inputs[1] = i
+            elif "bias" in str(i) and not IN_NAMES:
+                sorted_inputs[2] = i
+            elif "mean" in str(i) and not IN_NAMES:
+                sorted_inputs[3] = i
+            elif "var" in str(i) and not IN_NAMES:
+                sorted_inputs[4] = i
+
+        if "data_format" in attrs:
+            if attrs["data_format"] == "channel_first":
+                attrs["axis"] = 1
+
+        out = AttrCvt(op_name="batch_norm", ignores=["training"], disables=["momentum"])(
+            sorted_inputs, attrs, params
+        )
+        return out[0]
+
+
+class Flatten(OneFlowOpConverter):
+    """Operator converter for Flatten"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        axis = attrs.get("axis", 1)
+        ishape = _op.shape_of(inputs[0])
+        ndim = infer_shape(ishape)[0]
+        if axis < 0:
+            axis = axis + ndim
+
+        if axis == 1:
+            out = _op.nn.batch_flatten(inputs[0])
+        else:
+            pre_shape = _op.prod(_op.strided_slice(ishape, [0], [axis], [1]), keepdims=True)
+            post_shape = _op.prod(_op.strided_slice(ishape, [axis], [ndim], [1]), keepdims=True)
+            newshape = _op.concatenate([pre_shape, post_shape], axis=0)
+            out = _op.reshape(inputs[0], newshape)
+        return out
+
+
+class MatMul(OneFlowOpConverter):
+    """Operator converter for MatMul"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        assert len(inputs) == 2, "Gemm op take 2 inputs, {} given".format(len(inputs))
+        # Similar to 'class Conv'
+        true_names = ["weight"]
+        false_names = ["-input_"]
+        for i in inputs:
+            T_NAMES = any(x in str(i) for x in true_names)
+            F_NAMES = any(x in str(i) for x in false_names)
+            if T_NAMES and not F_NAMES:
+                matmul_b = i
+            else:
+                matmul_a = i
+
+        dtype = infer_type(matmul_a).checked_type.dtype
+
+        # Y = alpha * A * B
+        alpha = float(attrs.get("alpha", 1.0))
+        transA = bool(attrs.get("transpose_a", False))
+        transB = bool(attrs.get("transpose_b", False))
+
+        # get number of channels
+        channels = infer_channels(matmul_b, not transB)
+        if transA:
+            matmul_a = _op.transpose(matmul_a, axes=(1, 0))
+        if not transB:
+            matmul_b = _op.transpose(matmul_b, axes=(1, 0))
+        matmul_a = _op.nn.batch_flatten(matmul_a)
+        if alpha != 1.0:
+            matmul_a *= _expr.const(alpha, dtype=dtype)
+
+        return _op.nn.dense(matmul_a, matmul_b, units=channels)
+
+
+class Reduce(OneFlowOpConverter):
+    """Operator converter for reduce ops"""
+
+    name = ""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        attr = {"axis": attrs.get("axis", 0), "keepdims": attrs.get("keepdims", True)}
+        return AttrCvt(cls.name)(inputs, attr)
+
+
+class ReduceMax(Reduce):
+    """Operator converter for ReduceMax"""
+
+    name = "max"
+
+
+class ReduceMin(Reduce):
+    """Operator converter for ReduceMin"""
+
+    name = "min"
+
+
+class ReduceSum(Reduce):
+    """Operator converter for ReduceSum"""
+
+    name = "sum"
+
+
+class ReduceMean(Reduce):
+    """Operator converter for ReduceMean"""
+
+    name = "mean"
+
+
+class Square(OneFlowOpConverter):
+    """Operator converter for square"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        assert len(inputs) == 1, "Square op {} take 1 inputs, {} given".format(
+            cls.name, len(inputs)
+        )
+        return _op.multiply(inputs[0], inputs[0])
+
+
+class Add(OneFlowOpConverter):
+    """Operator converter for Add"""
+
+    name = "add"
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        assert len(inputs) == 2, "Math op {} take 2 inputs, {} given".format(cls.name, len(inputs))
+        axis = int(attrs.get("axis", 0))
+
+        true_names = ["weight", "bias"]
+        false_names = ["-input_"]
+
+        for i in inputs:
+            T_NAMES = any(x in str(i) for x in true_names)
+            F_NAMES = any(x in str(i) for x in false_names)
+            if T_NAMES and not F_NAMES:
+                add_b = i
+            else:
+                add_a = i
+
+        # fix the shape
+        add_shape = infer_shape(add_a)
+        if len(add_shape) > 2:
+            add_b = _op.expand_dims(add_b, axis=axis, num_newaxis=len(add_shape) - 2)
+        add_b_shape = list(infer_shape(add_b))
+        add_b_shape.insert(0, add_shape[0])
+
+        add_b = _op.reshape(add_b, tuple(add_b_shape))
+        out = get_relay_op(cls.name)(add_a, add_b)
+
+        return out
+
+
+class Expand(OneFlowOpConverter):
+    """Operator converter for Expand"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        input_shape = infer_shape(inputs[0])
+        assert input_shape == attrs["in_shape"], "input shape does not match attrs['in_shape']"
+
+        new_shape = attrs["out_shape"]
+        out = _op.broadcast_to(inputs[0], shape=new_shape)
+
+        return out
+
+
+class ExpandDim(OneFlowOpConverter):
+    """Operator converter for ExpandDim"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+
+        return _op.expand_dims(inputs[0], axis=attrs.get("axis", 0))
+
+
+class BroadcastMath(OneFlowOpConverter):
+    """Operator converter for broadcast math ops"""
+
+    name = ""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        assert len(inputs) == 2, "Math op {} take 2 inputs, {} given".format(cls.name, len(inputs))
+        beta_names = ["weight", "bias", "mean", "var", "Constant"]
+
+        for i in inputs:
+            T_NAMES = any([x in str(i) for x in beta_names])
+            if T_NAMES and "-input_" not in str(i):
+                input_b = i
+            else:
+                input_a = i
+
+        if cls.name == "divide":
+            length = []
+            for i in inputs:
+                length.append(len(str(i)))
+            for i in inputs:
+                if len(str(i)) == max(length):
+                    input_a = i
+                else:
+                    input_b = i
+        if cls.name == "subtract":
+            length = []
+            for i in inputs:
+                length.append(len(str(i)))
+            for i in inputs:
+                if len(str(i)) == max(length):
+                    input_b = i
+                else:
+                    input_a = i
+        try:
+            return get_relay_op(cls.name)(input_a, input_b)
+        except UnboundLocalError:
+            return get_relay_op(cls.name)(*inputs)
+
+
+class BroadcastMul(BroadcastMath):
+    """Operator converter for Mul broadcast"""
+
+    name = "multiply"
+
+
+class BroadcastAdd(BroadcastMath):
+    """Operator converter for Add broadcast"""
+
+    name = "add"
+
+
+class BroadcastSub(BroadcastMath):
+    """Operator converter for Sub broadcast"""
+
+    name = "subtract"
+
+
+class BroadcastDiv(BroadcastMath):
+    """Operator converter for Div broadcast"""
+
+    name = "divide"
+
+
+class Greater(OneFlowOpConverter):
+    """Operator converter for greater"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        return _op.greater(inputs[0], inputs[1])
+
+
+class Log1p(OneFlowOpConverter):
+    """Operator converter for Log1p"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        return _op.log(inputs[0] + _expr.const(1.0))
+
+
+class Expm1(OneFlowOpConverter):
+    """Operator converter for Expm1"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        return _op.exp(inputs[0]) - _expr.const(1.0)
+
+
+class Unary(OneFlowOpConverter):
+    """A helper class for unary op converters"""
+
+    name = ""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        assert len(inputs) == 1, "Unary math op {} takes 1 input, {} given".format(
+            cls.name, len(inputs)
+        )
+        return get_relay_op(cls.name)(*inputs)
+
+
+class Absolute(Unary):
+    """Operator converter for Absolute."""
+
+    name = "abs"
+
+
+class AddN(OneFlowOpConverter):
+    """Operator converter for Add_n"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        assert len(inputs) > 0, "add_n takes at least 1 input, but 0 given."
+
+        res = inputs[0]
+        for each in inputs[1:]:
+            res = _op.add(res, each)
+        return res
+
+
+class ScalarAdd(OneFlowOpConverter):
+    """Operator convert for Add_scalar"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        assert len(inputs) == 1, "add_scalar take == 1 inputs, but {} given.".format(len(inputs))
+
+        if attrs.get("has_int_operand", True):
+            res = inputs[0] + _expr.const(attrs["int_operand"])
+        elif attrs.get("has_float_operand", True):
+            res = inputs[0] + _expr.const(attrs["float_operand"])
+        else:
+            raise AttributeError(
+                "please check if has_int_operand or has_float_operand in your attrs"
+            )
+
+        return res
+
+
+class ScalarMul(OneFlowOpConverter):
+    """Operator convert for Mul_scalar"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        assert len(inputs) == 1, "add_scalar take == 1 inputs, but {} given.".format(len(inputs))
+
+        if attrs.get("has_int_operand", True):
+            res = inputs[0] * _expr.const(attrs["int_operand"], dtype="float32")
+        elif attrs.get("has_float_operand", True):
+            res = inputs[0] * _expr.const(attrs["float_operand"])
+        else:
+            raise AttributeError(
+                "please check if has_int_operand or has_float_operand in your attrs"
+            )
+
+        return res
+
+
+class ScalarPow(OneFlowOpConverter):
+    """Operator convert for Pow_scalar"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        exponent = attrs.get("exponent", 1.0)
+        exponent = _expr.const(exponent, dtype="float32")
+        return _op.power(inputs[0], exponent)
+
+
+class MaxPool2d(Pool):
+    """Operator converter for MaxPool"""
+
+    name = "max_pool2d"
+
+
+class AveragePool2d(Pool):
+    """Operator converter for AveragePool."""
+
+    name = "avg_pool2d"
+
+
+class Affine(OneFlowOpConverter):
+    """Operator converter for Affine transformation."""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        alpha = _expr.const(attrs.get("alpha", 1.0))
+        beta = _expr.const(attrs.get("beta", 0.0))
+        return (alpha * inputs[0]) + beta
+
+
+class Reshape(OneFlowOpConverter):
+    """Operator converter for Reshape."""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        return _op.reshape(inputs[0], attrs["shape"])
+
+
+class Softmax(OneFlowOpConverter):
+    """Operator converter for Softmax."""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        axis = attrs.get("axis", 1)
+        ndim = len(infer_shape(inputs[0]))
+        if axis < 0:
+            axis += ndim
+        axes = list(range(axis, ndim))
+        x = inputs[0]
+        m = _op.max(x, axes, keepdims=True)
+        e = _op.exp(x - m)
+        return e / _op.sum(e, axes, keepdims=True)
+
+
+class LogSoftmax(OneFlowOpConverter):
+    """Operator converter for LogSoftmax."""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        axis = attrs.get("axis", 1)
+        ndim = len(infer_shape(inputs[0]))
+        if axis < 0:
+            axis += ndim
+        axes = list(range(axis, ndim))
+        x = inputs[0]
+        m = _op.max(x, axes, keepdims=True)
+        e = _op.exp(x - m)
+        s = _op.sum(e, axes, keepdims=True)
+        return x - m - _op.log(s)
+
+
+class Dropout(OneFlowOpConverter):
+    """Operator converter for Dropout."""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        out = AttrCvt("dropout", {"ratio": "rate"}, ignores=["is_test"])
+        return out
+
+
+class ThresholdedRelu(OneFlowOpConverter):
+    """Operator converter for ThresholdedRelu."""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        alpha = float(attrs.get("alpha", 1.0))
+        alpha_tensor = _op.full_like(inputs[0], fill_value=_expr.const(alpha))
+        mask = _op.greater(inputs[0], alpha_tensor).astype("float32")
+        return inputs[0] * mask
+
+
+class Elu(OneFlowOpConverter):
+    """Operator converter for Elu"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        alpha = float(attrs.get("alpha", 1.0))
+        return _expr.const(-alpha) * _op.nn.relu(
+            _expr.const(1.0) - _op.exp(inputs[0])
+        ) + _op.nn.relu(inputs[0])
+
+
+class PReLU(OneFlowOpConverter):
+    """Operator converter for PReLU"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        assert len(inputs) == 2, "PReLU need 2 inputs, but {} given".format(len(inputs))
+        for i in inputs:
+            if "-input_" in str(i):
+                prelu_a = i
+            else:
+                prelu_b = i
+
+        input_shape = shape_of(prelu_a)
+        alpha = _op.broadcast_to_like(prelu_b, prelu_a)
+        alpha = _op.reshape(alpha, [-1])
+
+        output = _op.nn.prelu(_op.reshape(prelu_a, [-1]), alpha, axis=0)
+        out = _op.reshape(output, input_shape)
+        return out
+
+
+class Selu(OneFlowOpConverter):
+    """Operator converter for Selu"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        alpha = float(attrs.get("alpha", 1.67326319217681884765625))
+        gamma = float(attrs.get("gamma", 1.05070102214813232421875))
+        return _expr.const(gamma) * (
+            _expr.const(-alpha) * _op.nn.relu(_expr.const(1.0) - _op.exp(inputs[0]))
+            + _op.nn.relu(inputs[0])
+        )
+
+
+class Silu(OneFlowOpConverter):
+    """Operator converter for Silu"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        a = inputs[0]
+        b = _op.sigmoid(inputs[0])
+        return _op.multiply(a, b)
+
+
+class Gelu(OneFlowOpConverter):
+    """Operator converter for Gelu"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        data = inputs[0]
+        return data * (
+            _expr.const(0.5) + _op.erf(data * _expr.const(0.5**0.5)) * _expr.const(0.5)
+        )
+
+
+class HardTanh(OneFlowOpConverter):
+    """Operator converter for HardTanh"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        tanh_min = attrs.get("min_val", 0.0)
+        tanh_max = attrs.get("max_val", 0.0)
+        return _op.tensor.clip(inputs[0], tanh_min, tanh_max)
+
+
+class Softplus(OneFlowOpConverter):
+    """Operator converter for Softplus"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        data = inputs[0]
+        data_dtype = infer_type(data).checked_type.dtype
+        data = _op.exp(data) + _expr.const(1, dtype=data_dtype)
+        return _op.log(data)
+
+
+class Softsign(OneFlowOpConverter):
+    """Operator converter for Softsign"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        return inputs[0] / (_expr.const(1.0) + Absolute.get_converter()(inputs, attrs, params))
+
+
+class Concat(OneFlowOpConverter):
+    """Operator converter for Concat"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        attrs.pop("max_dim_size")
+        inputs = _dtype_shape_promotion(inputs)
+        return _op.concatenate(inputs, axis=attrs["axis"])
+
+
+class Clip(OneFlowOpConverter):
+    """Operator converter for Clip"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        attr = {}
+        dtype = infer_type(inputs[0])
+
+        if "float" in str(dtype):
+            attr["a_min"] = attrs["floating_min"]
+            attr["a_max"] = attrs["floating_max"]
+        elif "int" in str(dtype):
+            attr["a_min"] = attrs["integral_min"]
+            attr["a_max"] = attrs["integral_max"]
+        else:
+            attr["a_min"] = -np.inf
+            attr["a_max"] = np.inf
+
+        out = AttrCvt("clip")(inputs, attr, params)
+        return out
+
+
+class Slice(OneFlowOpConverter):
+    """Operator converter for Slice"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        starts = list(attrs["start"])
+        ends = list(attrs["stop"])
+        steps = list(attrs["step"])
+        return _op.strided_slice(inputs[0], starts, ends, steps)
+
+
+class Split(OneFlowOpConverter):
+    """Operator converter for Split"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        splits = attrs.get("split", None)
+        if splits is not None:
+            indices = []
+            attrs["indices_or_sections"] = []
+            index = 0
+            for i in splits[:-1]:
+                index += i
+                indices.append(index)
+        output = _op.split(inputs[0], indices, attrs.get("axis", 0))
+        # If the output of split is a single value, unpack it from the TupleWrapper
+        if len(output) == 1:
+            output = output[0]
+        return output
+
+
+class Scatter(OneFlowOpConverter):
+    """Operator converter for Scatter"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        axis = attrs.get("axis", 0)
+        return _op.scatter(inputs[0], inputs[1], inputs[2], axis)
+
+
+class Unsqueeze(OneFlowOpConverter):
+    """Operator converter for Unsqueeze"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        axes = sorted(attrs["axes"])
+        for axis in axes:
+            inputs[0] = _op.expand_dims(inputs[0], axis=axis, num_newaxis=1)
+        return inputs[0]
+
+
+class Sign(OneFlowOpConverter):
+    """Operator converter for Sign"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        return _op.sign(inputs[0])
+
+
+class Reciprocal(OneFlowOpConverter):
+    """Operator converter for Reciprocal"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        dtype = infer_type(inputs[0]).checked_type.dtype
+        return _expr.const(1.0, dtype=dtype) / inputs[0]
+
+
+class Erf(OneFlowOpConverter):
+    """Operator converter for Erf"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        return _op.erf(inputs[0])
+
+
+class Erfc(OneFlowOpConverter):
+    """Operator converter for Erfs"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        return _expr.const(1.0) - _op.erf(inputs[0])
+
+
+class HardSigmoid(OneFlowOpConverter):
+    """Operator converter for HardSigmoid"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        alpha = attrs.get("alpha", 0.2)
+        beta = attrs.get("beta", 0.5)
+        transformX = (inputs[0] * _expr.const(alpha)) + _expr.const(beta)
+        attr = {"a_min": 0, "a_max": 1}
+        return AttrCvt("clip")([transformX], attr)
+
+
+class OneHot(OneFlowOpConverter):
+    """Operator converter for OneHot"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        # Extract relay one_hot inputs.
+        indices, depth, values = inputs
+        ndim = len(infer_shape(indices))
+        # Split onnx on off values into two separate expressions.
+        off_value, on_value = _op.take(values, _op.const(0)), _op.take(values, _op.const(1))
+        # Extract the datatype of the output from on_value.
+        dtype = infer_type(on_value).checked_type.dtype
+        ind_dtype = infer_type(indices).checked_type.dtype
+        # Normalize the indices to a positive range
+        indices = _op.where(
+            indices < _op.const(0, ind_dtype), indices + _op.cast(depth, ind_dtype), indices
+        )
+        # set default value when axis is not set in the model
+        axis = attrs.get("axis", -1)
+        if axis < 0:
+            axis += ndim + 1
+
+        return _op.one_hot(indices, on_value, off_value, depth, axis, dtype=dtype)
+
+
+class Where(OneFlowOpConverter):
+    """Operator converter for Where"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        condition_rank = len(infer_shape(inputs[0]))
+        x_rank = len(infer_shape(inputs[1]))
+        y_rank = len(infer_shape(inputs[2]))
+        ranks = [condition_rank, x_rank, y_rank]
+
+        # If one rank is longer than others, then we can broadcast
+        # to that shape.
+        max_rank = max(ranks)
+        max_rank_idxs = [i for i, x in enumerate(ranks) if x == max_rank]
+        broadcast_shape = shape_of(inputs[max_rank_idxs[0]])
+        # If two or more inputs have the same rank, compute the broadcast
+        # shape by taking the maximum value of each dimensions.
+        if len(max_rank_idxs) > 1:
+            for idx in max_rank_idxs:
+                broadcast_shape = _op.maximum(broadcast_shape, shape_of(inputs[idx]))
+
+        broadcast_shape = fold_constant(broadcast_shape)
+
+        condition = _op.broadcast_to(inputs[0], broadcast_shape)
+        x = _op.broadcast_to(inputs[1], broadcast_shape)
+        y = _op.broadcast_to(inputs[2], broadcast_shape)
+        return _op.where(condition, x, y)
+
+
+class Constant(OneFlowOpConverter):
+    """Operator converter for Constant"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        is_float = attrs.get("is_floating_value", True)
+        shape = attrs.get("shape", (1,))
+        if is_float:
+            dtype = "float32"
+            value = attrs.pop("floating_value")
+        else:
+            dtype = "int8"
+            value = attrs.pop("integer_value")
+        np_array = np.zeros(shape)
+        np_array.fill(value)
+        value = _expr.const(np_array, dtype)
+        return value
+
+
+class Range(OneFlowOpConverter):
+    """Operator converter for Range"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        if len(inputs) != 0:
+            raise ValueError("Expect no inputs but get {}".format(len(inputs)))
+        start = attrs.get("start", 0.0)
+        limit = attrs.get("limit", 1.0)
+        delta = attrs.get("delta", 1.0)
+        return _op.arange(
+            _expr.const(start, dtype="float32"),
+            _expr.const(limit, dtype="float32"),
+            _expr.const(delta, dtype="float32"),
+        )
+
+
+class Cast(OneFlowOpConverter):
+    """Operator converter for Cast"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attrs, params):
+        attrs["dtype"] = infer_type(inputs[0]).checked_type.dtype
+        return AttrCvt(op_name="cast")(inputs, attrs)
+
+
+def get_convert_map():
+    # supported oneflow2relay op
+    return {
+        # defs/math
+        "bias_add": Add.get_converter(),
+        "scalar_add": ScalarAdd.get_converter(),
+        "scalar_mul": ScalarMul.get_converter(),
+        "scalar_pow": ScalarPow.get_converter(),
+        "reduce_sum": ReduceSum.get_converter(),
+        "reduce_max": ReduceMax.get_converter(),
+        "reduce_min": ReduceMin.get_converter(),
+        "reduce_mean": ReduceMean.get_converter(),
+        "broadcast_add": BroadcastAdd.get_converter(),
+        "broadcast_mul": BroadcastMul.get_converter(),
+        "broadcast_sub": BroadcastSub.get_converter(),
+        "broadcast_div": BroadcastDiv.get_converter(),
+        "broadcast_greater": Greater.get_converter(),
+        "log": Renamer("log"),
+        "log1p": Log1p.get_converter(),
+        "acos": Renamer("acos"),
+        "acosh": Renamer("acosh"),
+        "asin": Renamer("asin"),
+        "asinh": Renamer("asinh"),
+        "atan": Renamer("atan"),
+        "atanh": Renamer("atanh"),
+        "cos": Renamer("cos"),
+        "cosh": Renamer("cosh"),
+        "sin": Renamer("sin"),
+        "sinh": Renamer("sinh"),
+        "tan": Renamer("tan"),
+        "tanh": Renamer("tanh"),
+        "pow": Renamer("power"),
+        "exp": Renamer("exp"),
+        "expm1": Expm1.get_converter(),
+        "floor": Renamer("floor"),
+        "ceil": Renamer("ceil"),
+        "round": Renamer("round"),
+        "add_n": AddN.get_converter(),
+        "sqrt": Renamer("sqrt"),
+        "rsqrt": Renamer("rsqrt"),
+        "square": Square.get_converter(),
+        "sign": Sign.get_converter(),
+        "erf": Erf.get_converter(),
+        "erfc": Erfc.get_converter(),
+        "reciprocal_no_nan": Reciprocal.get_converter(),
+        # defs/activation
+        "softmax": Softmax.get_converter(),
+        "softsign": Softsign.get_converter(),
+        "hardtanh": HardTanh.get_converter(),
+        "relu": Renamer("relu"),
+        "leaky_relu": Renamer("leaky_relu"),
+        "prelu": PReLU.get_converter(),
+        "selu": Selu.get_converter(),
+        "silu": Silu.get_converter(),
+        "gelu": Gelu.get_converter(),
+        # defs/nn
+        "conv2d": Conv2d.get_converter(),
+        "deconv2d": ConvTranspose2d.get_converter(),
+        "maxpool_2d": MaxPool2d.get_converter(),
+        "avgpool_2d": AveragePool2d.get_converter(),
+        "adaptive_avg_pool2d": AdaptiveAvgPool2d.get_converter(),
+        "adaptive_max_pool2d": AdaptiveMaxPool2d.get_converter(),
+        "dropout": Dropout.get_converter(),
+        "normalization": BatchNorm.get_converter(),
+        "upsample_nearest_2d": UpsampleNearest.get_converter(),
+        "upsample_bilinear_2d": UpsampleBiLinear.get_converter(),
+        # defs/tensor
+        "matmul": MatMul.get_converter(),
+        "concat": Concat.get_converter(),
+        "clip_by_scalar": Clip.get_converter(),
+        "slice": Slice.get_converter(),
+        "expand": Expand.get_converter(),
+        "transpose": AttrCvt("transpose", {"perm": "axes"}),
+        "expand_dims": ExpandDim.get_converter(),
+        "range": Range.get_converter(),
+        "cast": Cast.get_converter(),
+        # defs/others
+        "reshape": Reshape.get_converter(),
+        "constant": Constant.get_converter(),
+        # "where": Where.get_converter(),
+        "flatten": Flatten.get_converter(),
+        "sigmoid": Renamer("sigmoid"),
+        "sigmoid_v2": Renamer("sigmoid"),
+        "hardsigmoid": HardSigmoid.get_converter(),
+        "squeeze": AttrCvt("squeeze", {"axes": "axis"}),
+        "unsqueeze": Unsqueeze.get_converter(),
+    }
+
+
+class oneflow_input(object):
+    """
+    Dual purpose list or dictionary access object
+    """
+
+    def __init__(self):
+        self.input_keys = []
+        self.input_dict = {}
+        self.n = 0
+
+    def __getitem__(self, item):
+        if isinstance(item, int):
+            if item > (len(self.input_keys) - 1):
+                return None
+            return self.input_dict[self.input_keys[item]]
+        if isinstance(item, str):
+            if item not in self.input_keys:
+                return None
+            return self.input_dict[item]
+        if isinstance(item, slice):
+            keys = self.input_keys[item]
+            return [self.input_dict[key] for key in keys]
+
+        raise ValueError("Only integer, string, and slice accesses allowed.")
+
+    def __setitem__(self, item, value):
+        if isinstance(item, int):
+            self.input_dict[self.input_keys[item]] = value
+        elif isinstance(item, str):
+            self.input_keys.append(item)
+            self.input_dict[item] = value
+        else:
+            raise ValueError("Only integer and string indexed writes allowed.")
+
+    def keys(self):
+        return self.input_keys
+
+    def __len__(self):
+        return len(self.input_keys)
+
+    def __iter__(self):
+        self.n = 0
+        return self
+
+    def __next__(self):
+        if self.n < len(self.input_keys):
+            output = self.input_dict[self.input_keys[self.n]]
+            self.n += 1
+            return output
+
+        raise StopIteration
+
+
+def deal_with_input_convert(
+    node_input, node_input_shape, node_input_dtype, node_path, _nodes, _input_path_2_name
+):
+    """deal with input convert in oneflow."""
+    if node_input not in _nodes:
+        if (
+            node_path not in _input_path_2_name
+            or "-input_" in node_input
+            or "FreeEagerTensor" in node_input
+        ):
+            _nodes[node_input] = new_var(
+                node_input,
+                shape=node_input_shape,
+                dtype=node_input_dtype,
+            )
+        else:
+            names = _input_path_2_name[node_path]
+            node_replace = None
+            for k in names:
+                if k in _nodes:
+                    node_replace = k
+            if node_replace is not None:
+                op_replace = copy.deepcopy(_nodes[node_replace])
+                _nodes[node_input] = op_replace
+            else:
+                print("{} will not be in _nodes".format(node_input))
+
+
+def deal_parameter_convert(
+    node_input_paths, model_dir_path, _input_path_2_name, _model_array, _params, _nodes
+):
+    """deal with parameter(weight) convert in oneflow."""
+    for node_input_path in node_input_paths:
+        node_path = os.path.join(model_dir_path, node_input_path.replace("m.", ""))
+        node_input_name = node_input_path.split("/")[0]
+        _input_path_2_name[node_path] = node_input_name
+        for param_name in _model_array:
+            node_p = _model_array[param_name]
+            if node_path == node_p["path"]:
+                node_array = node_p["params"]
+                _params[node_input_name] = node_array
+                _nodes[node_input_name] = new_var(
+                    node_input_name, shape=node_array.shape, dtype=str(node_array.dtype)
+                )
+                break
+
+
+class OneflowGraph(object):
+    """
+    A helper class for handling Relay expression
+
+    Parameters
+    ----------
+    shape : dict of str to tuple, optional
+        The input shape to the graph
+    dtype : dict of str to str
+        The input types to the graph
+
+    node name:
+    1. param: m.layer4.1.bn1.weight / ...
+    2. buffer: m.layer4.1.bn1.running_mean / ...
+    3. node inputs: m.layer4.1.bn1-input_0
+    4. node outputs: m.layer4.1.bn1-output_0
+    """
+
+    def __init__(self, shape, dtype, nodes, model_dir_path):
+        self._nodes = {}
+        self._params = {}
+        self._inputs = {}
+        self._num_input = 0
+        self._num_param = 0
+        self._input_names = []
+        self._model_array = {}
+        self._input_path_2_name = {}
+        self._output_path_2_name = {}
+        self._init_variable_node = []
+        self._shape = shape
+        self._dtype = dtype
+        self._identity_list = []
+        self._sort_inputs = {}
+
+        import oneflow
+
+        model = oneflow.load(model_dir_path)
+        # model_array: keys: layer_name,values: dict('path', 'params')
+        for layer_name in model:
+            layer = model[layer_name]
+            layer_node = {}
+            layer_node["path"] = os.path.join(model_dir_path, layer_name, "out")  # get path
+            if "System-Train" in layer_name:
+                continue
+            node_name = "m." + layer_name
+            shape = self._shape[node_name]
+            dtype = self._dtype[node_name]
+            array = layer.detach().cpu().numpy()
+            layer_node["params"] = array.reshape(shape)
+            self._model_array[layer_name] = layer_node
+
+        for node_name in nodes:
+            node = nodes[node_name]
+            if is_user_op(node):
+                for input_name in node.user_conf.input:
+                    node_input_paths = getattr(node.user_conf.input[input_name], "s")
+                    deal_parameter_convert(
+                        node_input_paths,
+                        model_dir_path,
+                        self._input_path_2_name,
+                        self._model_array,
+                        self._params,
+                        self._nodes,
+                    )
+                for output_name in node.user_conf.output:
+                    node_output_paths = getattr(node.user_conf.output[output_name], "s")
+                    for node_output_path in node_output_paths:
+                        node_path = os.path.join(model_dir_path, node_output_path.replace("m.", ""))
+                        node_output_name = node_output_path.split("/")[0]
+                        self._output_path_2_name[node_path] = node_output_name
+            elif is_output_op(node):
+                node_output_path = getattr(node.output_conf, "in")
+                output_path = os.path.join(
+                    model_dir_path, getattr(node.output_conf, "in").replace("m.", "")
+                )
+                self._output_path_2_name[output_path] = node_name
+            elif is_param_op(node):
+                if "FreeEagerTensor" in node.name:
+                    shape = tuple(node.variable_conf.shape.dim)
+                    dtype = FLOW_2_STR_DTYPE[node.variable_conf.data_type]
+                    self._shape[node.name] = shape
+                    self._dtype[node.name] = dtype
+                    self._init_variable_node.append(node.name)
+        if self._init_variable_node != []:
+            print("{} should be defined by user".format(self._init_variable_node))
+
+    def _parse_input(self, node, model_dir_path):
+        for input_name in node.user_conf.input:
+            node_input_paths = getattr(node.user_conf.input[input_name], "s")
+            for i in node_input_paths:
+                node_input = i.split("/")[0]
+                node_input_shape = self._shape[node_input]
+                node_input_dtype = self._dtype[node_input]
+                node_path = os.path.join(model_dir_path, i.replace("m.", ""))
+                deal_with_input_convert(
+                    node_input,
+                    node_input_shape,
+                    node_input_dtype,
+                    node_path,
+                    self._nodes,
+                    self._input_path_2_name,
+                )
+
+    def _parse_output(self, op_name, outputs, cnt_init=0):
+        """
+        o: m.classifier.1-output_xxx
+        new_o: m.classifier.1-conv2d_0
+        "_"+new_o is in self._shape
+        """
+        for o in outputs:
+            if "-output_" not in o:
+                new_o = o.replace("-" + op_name, "-output")
+                new_o = new_o.replace("_" + new_o.split("_")[-1], "_0")
+                self._shape[o] = self._shape["_" + new_o]
+                self._dtype[o] = self._dtype["_" + new_o]
+            elif len(outputs) > 1:
+                outputs.remove(o)
+        if op_name.lower() == "dropout":
+            if len(outputs) == 1:
+                return outputs
+            outputs = outputs[:-1]
+        elif op_name.lower() == "constant":
+            outputs = [self._init_variable_node[cnt_init]]
+
+        if len(outputs) > 1:
+            outputs = list(set(outputs))
+
+        return outputs
+
+    def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=None):
+        """
+        Parameters
+        ----------
+        nodes : dict, keys: node.name, value: node
+            contain the graph
+        model_dir_path: str
+            The path of parameter
+        freeze_params: bool
+            If freeze_params is True, the input of the computational graph is the input
+            of the first layer of the network and cannot be overridden by the user, e.g.
+            default input:      %v_ResNetGraph_0-input_0: Tensor[(1, 3, 224, 224), float32]
+            user-defined input: %_0-input_0: Tensor[(1, 3, 640, 480), float32]
+            With freeze_params enabled, conv1-in becomes the graph input instead of Input_0.
+        user_input: dict
+            User-defined input information for the graph
+            {
+                node1_name:
+                {
+                    'name':  node1_name,   # str, like "%v_ResNetGraph_0-input_0"
+                    'shape': node1_shape,  # tuple
+                    'dtype': node1_dtype   # str, like "float32"
+                }
+                ...
+            }
+        We recommend that users specify the input via the job function
+        rather than through this argument.
+
+        Returns
+        -------
+        mod : tvm.IRModule
+            The returned relay module
+        params : dict
+            A dict of name: tvm.nd.array pairs, used as pretrained weights
+        """
+        # step 1: get the graph input
+        if not freeze_params:
+            for node_init_name in user_input:
+                if "-input_" not in node_init_name:
+                    raise KeyError(
+                        "user_input['name'] should contain '-input_' "
+                        + "to let program know that this is input node"
+                    )
+                self._nodes[node_init_name] = new_var(
+                    node_init_name,
+                    shape=user_input[node_init_name]["shape"],
+                    dtype=user_input[node_init_name]["dtype"],
+                )
+                self._inputs[node_init_name] = self._nodes[node_init_name]
+
+        # step 2: find out if unsupported ops are used
+        convert_map = get_convert_map()
+        unsupported_ops = set()
+        for node_name in nodes:
+            node = nodes[node_name]
+            if is_user_op(node):
+                # op names, not the layer names
+                op_name = node.user_conf.op_type_name
+                if (
+                    op_name not in convert_map
+                    and "constant" not in op_name
+                    and op_name not in self._identity_list
+                ):
+                    unsupported_ops.add(op_name)
+        # find out the unsupported op
+        if unsupported_ops:
+            msg = "The following operators are not supported for frontend OneFlow: "
+            msg += ", ".join(unsupported_ops)
+            raise tvm.error.OpNotImplemented(msg)
+
+        # step 3: convert op
+        for node_name in nodes:
+            node = nodes[node_name]
+            if is_user_op(node):
+                # If there is a user-defined node, skip the following steps
+                if node_name in self._inputs:
+                    continue
+
+                op_name = node.user_conf.op_type_name
+                op_attr = parse_attr(node.user_conf.attr)
+
+                self._parse_input(node, model_dir_path=model_dir_path)
+
+                node_inputs = oneflow_input()
+                for input_name in node.user_conf.input:
+                    node_input_paths = getattr(node.user_conf.input[input_name], "s")
+                    for i in node_input_paths:
+                        node_input = i.split("/")[0]
+                        node_inputs[node_input] = self._nodes[node_input]
+
+                node_outputs = []
+                for output_name in node.user_conf.output:
+                    node_output_paths = getattr(node.user_conf.output[output_name], "s")
+                    for i in node_output_paths:
+                        node_output_path = os.path.join(model_dir_path, i.replace("m.", ""))
+                        if node_output_path in self._input_path_2_name:
+                            node_outputs.append(self._input_path_2_name[node_output_path])
+                        elif node_output_path in self._output_path_2_name:
+                            node_outputs.append(self._output_path_2_name[node_output_path])
+                node_outputs = self._parse_output(op_name, node_outputs)
+
+                # convert
+                op = self._convert_operator(op_name, node_inputs, op_attr)
+
+                if not isinstance(op, _expr.TupleWrapper):
+                    outputs_num = 1
+                else:
+                    outputs_num = len(op)
+
+                assert (
+                    len(node_outputs) == outputs_num
+                ), "Number of outputs mismatch: {} vs {} in {}.".format(
+                    len(node_outputs), outputs_num, op_name
+                )
+
+                if outputs_num == 1:
+                    op = fold_constant(op)
+                else:
+                    op = _expr.TupleWrapper(fold_constant(op.astuple()), len(op))
+
+                op_temp = [op]
+                for i, _ in enumerate(node_outputs):
+                    if isinstance(node_outputs[i], list):
+                        for k in node_outputs[i]:
+                            self._nodes[k] = op_temp[i]
+                    else:
+                        self._nodes[node_outputs[i]] = op_temp[i]
+
+        # step 4: get the outputs
+        outputs = []
+        for node_name in nodes:
+            node = nodes[node_name]
+            if is_output_op(node):
+                node_name_v2 = getattr(node.output_conf, "in").split("/")[0]
+                if node_name in self._nodes:
+                    outputs.append(self._nodes[node_name])
+                elif node_name_v2 in self._nodes:
+                    outputs.append(self._nodes[node_name_v2])
+        outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
+
+        # step 5: get the relay IR
+        free_vars = analysis.free_vars(outputs)
+
+        nodes = {v: k for k, v in self._nodes.items()}
+        free_vars = [nodes[var] for var in free_vars]
+
+        # step 6: make sure the node whose name contains '-input_0' comes first in self._inputs
+        for free_var in free_vars:
+            if free_var not in self._inputs:
+                self._inputs[free_var] = self._nodes[free_var]
+
+        input_names = list(self._inputs.keys())
+        for i, _ in enumerate(input_names):
+            if i != 0 and "-input_0" in input_names[i]:
+                str_buffer = copy.deepcopy(input_names[i])
+                del input_names[i]
+                input_names.insert(0, str_buffer)
+                break
+
+        for input_name in input_names:
+            if input_name in self._inputs:
+                self._sort_inputs[input_name] = self._inputs[input_name]
+            else:
+                raise IndexError("{} is not in self._inputs".format(input_name))
+
+        # step 7: create a function from our output expression and all input variables.
+        func = _function.Function([v for _, v in self._sort_inputs.items()], outputs)
+
+        return IRModule.from_expr(func), self._params
+
+    def _convert_operator(self, op_name, node_inputs, op_attr):
+        """
+        Parameters
+        ----------
+        op_name : str
+            Operator name, such as conv2d or relu
+        node_inputs : list of tvm.relay.function.Function
+            List of inputs.
+        op_attr : dict
+            Dict of operator attributes
+
+        Returns
+        -------
+        sym : tvm.relay.function.Function
+            Converted relay function
+        """
+        convert_map = get_convert_map()
+        if op_name in self._identity_list:
+            sym = get_relay_op(op_name)(*node_inputs, **op_attr)
+        elif op_name in convert_map:
+            sym = convert_map[op_name](node_inputs, op_attr, self._params)
+        else:
+            raise NotImplementedError("Operator {} not implemented.".format(op_name))
+
+        return sym
+
+
+def from_oneflow(graph, model_dir_path, freeze_params=True, user_input=None):
+    """
+    Convert a OneFlow nn.Graph into an equivalent Relay module;
+    see OneflowGraph.from_oneflow for the full description of
+    parameters and return values.
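+
+    A minimal usage sketch (the model, graph wrapper and path below are
+    illustrative; the module's parameters are assumed to have been saved
+    with flow.save to model_dir_path):
+
+        model = MyModel().eval()
+        graph = MyGraph(model)  # a flow.nn.Graph wrapping the module
+        graph._compile(flow.rand(1, 3, 224, 224))
+        mod, params = from_oneflow(graph, "path/to/saved/params")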
+    """
+    try:
+        import oneflow as flow
+    except ImportError:
+        raise ImportError("OneFlow is required by the OneFlow frontend, please install it first")
+
+    if not freeze_params and user_input is None:
+        raise ValueError("'user_input' must be provided when 'freeze_params' is False")
+    if freeze_params and user_input is not None:
+        warnings.warn("'user_input' is ignored when 'freeze_params' is True")
+
+    # get info of nodes
+    shape = {}
+    dtype = {}
+    graph_str = repr(graph)
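+    # A node line in repr(graph) looks roughly like (name and shape are illustrative):
+    #   INPUT:_MyGraph_0-input_0:tensor(..., size=(1, 3, 224, 224), dtype=oneflow.float32)
+    # On CUDA the device string ("cuda:0") contributes one extra ':', which is why
+    # size_where is bumped from 2 to 3 below.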
+    size_where = 2
+    if "cuda" in graph_str:
+        size_where = 3
+
+    p_size = re.compile(r"size=\(.*?\)", re.S)
+    p_type = re.compile(r"dtype=.*?\)", re.S)
+    types = ["INPUT", "PARAMETER", "BUFFER", "OUTPUT"]
+    for t in types:
+        data = re.finditer(t + ":.*", graph_str)
+        for i in data:
+            attrs = i.group().split(":")
+            size_str = re.findall(p_size, attrs[size_where])
+            type_str = re.findall(p_type, attrs[size_where])
+            assert size_str != [], "size should not be empty, please check repr(graph)"
+
+            size_attr = size_str[0].replace("size=", "")
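+            # drop the trailing comma of a single-element tuple, e.g. "(5,)"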
+            if size_attr[-2] == ",":
+                size_attr = size_attr.replace(",", "")
+            data_size = tuple(map(int, size_attr[1:-1].split(", ")))
+            node_name = attrs[1]
+            shape[node_name] = data_size
+            dtype[node_name] = "float32"
+
+            if type_str != []:
+                type_attr = type_str[0].replace("dtype=", "").replace(")", "")
+                if type_attr[-1] == ",":
+                    type_attr = type_attr.replace(",", "")
+                dtype[node_name] = type_attr.replace("oneflow.", "")
+
+    # get the graph proto; _graph_proto is None until the graph has been compiled
+    graph_input = re.search(r"INPUT:.*", graph_str).group().split(":")
+    shape_input = tuple(
+        map(
+            int,
+            re.findall(p_size, graph_input[size_where])[0].replace("size=", "")[1:-1].split(", "),
+        )
+    )
+    if not graph._is_compiled:
+        graph._compile(flow.rand(shape_input))
+    graph_proto = graph._graph_proto
+
+    # get all nodes
+    nodes = {}
+    for op in graph_proto.net.op:
+        nodes[op.name] = op
+
+    g = OneflowGraph(shape, dtype, nodes, model_dir_path)
+
+    # convert the OneFlow graph proto into a Relay module and a parameter dict
+    mod, params = g.from_oneflow(
+        nodes=nodes,
+        model_dir_path=model_dir_path,
+        freeze_params=freeze_params,
+        user_input=user_input,
+    )
+
+    return mod, params
diff --git a/tests/python/frontend/oneflow/test_forward.py b/tests/python/frontend/oneflow/test_forward.py
new file mode 100644
index 0000000000..d144cdad2b
--- /dev/null
+++ b/tests/python/frontend/oneflow/test_forward.py
@@ -0,0 +1,723 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=import-self, invalid-name
+# pylint: disable=arguments-differ, unused-argument, unused-import
+"""Unit tests for various models and operators"""
+import os
+import sys
+
+import numpy as np
+import pytest
+import tvm
+import tvm.testing
+import tvm.topi.testing
+from tvm import relay
+from tvm.contrib import graph_executor
+
+import oneflow as flow
+
+MODEL_HOME = "test_model"
+
+
+def mkdir(path):
+    # normalize the path before creating the directory
+    path = path.strip()
+    path = path.rstrip("\\")
+
+    if not os.path.exists(path):
+        os.makedirs(path)
+    else:
+        print("{} already exists".format(path))
+
+
+def rmdir(path):
+    for root, dirs, files in os.walk(path, topdown=False):
+        for name in files:
+            os.remove(os.path.join(root, name))
+        for name in dirs:
+            os.rmdir(os.path.join(root, name))
+    os.removedirs(path)
+
+
+def assert_shape(out1, out2):
+    if out1.shape != out2.shape:
+        msg = "Output shapes {} and {} don't match"
+        raise AssertionError(msg.format(out1.shape, out2.shape))
+
+
+class OneFlowGraph(flow.nn.Graph):
+    def __init__(self, module):
+        super().__init__()
+        self.m = module
+
+    def build(self, x):
+        out = self.m(x)
+        return out
+
+
+class OneFlowGraph_v2(flow.nn.Graph):
+    def __init__(self, module):
+        super().__init__()
+        self.m = module
+
+    def build(self, x1, x2, x3):
+        out = self.m(x1, x2, x3)
+        return out
+
+
+def get_oneflow_output(model, inputs):
+    flow_output = model(inputs)
+    return flow_output.numpy()
+
+
+def get_oneflow_concat_output(model, input1, input2, input3):
+    flow_output = model(input1, input2, input3).numpy()
+    return flow_output
+
+
+def get_tvm_output(graph, model_path, inputs: flow.tensor, target="llvm", dtype="float32"):
+    inputs_numpy = inputs.numpy()
+    if target == "llvm":
+        device = tvm.cpu(0)
+    elif target == "cuda":
+        device = tvm.cuda(0)
+
+    mod, params = relay.frontend.from_oneflow(graph, model_path)
+    with tvm.transform.PassContext(opt_level=10):
+        intrp = relay.build_module.create_executor("graph", mod, device, target)
+    tvm_output = intrp.evaluate()(tvm.nd.array(inputs_numpy.astype(dtype)), **params).numpy()
+    return tvm_output
+
+
+def get_tvm_concat_output(
+    graph,
+    model_path,
+    input1: flow.tensor,
+    input2: flow.tensor,
+    input3: flow.tensor,
+    target="llvm",
+    dtype="float32",
+):
+    input1_numpy = input1.numpy()
+    input2_numpy = input2.numpy()
+    input3_numpy = input3.numpy()
+    if target == "llvm":
+        device = tvm.cpu(0)
+    elif target == "cuda":
+        device = tvm.cuda(0)
+
+    mod, params = relay.frontend.from_oneflow(graph, model_path)
+    with tvm.transform.PassContext(opt_level=10):
+        intrp = relay.build_module.create_executor("graph", mod, device, target)
+    tvm_output = intrp.evaluate()(
+        tvm.nd.array(input1_numpy.astype(dtype)),
+        tvm.nd.array(input2_numpy.astype(dtype)),
+        tvm.nd.array(input3_numpy.astype(dtype)),
+        **params,
+    ).numpy()
+    return tvm_output
+
+
+def verify_conv(
+    model,
+    name="",
+    rtol=1e-5,
+    atol=1e-5,
+    inputs=flow.tensor(
+        np.random.rand(1, 3, 224, 224),
+        dtype=flow.float32,
+    ),
+    device="llvm",
+):
+    if device == "cuda":
+        model.to(device)
+        inputs = inputs.to(device)
+
+    graph = OneFlowGraph(model)
+    graph._compile(inputs)
+
+    mkdir(MODEL_HOME)
+    flow.save(model.state_dict(), MODEL_HOME)
+
+    out_flow = get_oneflow_output(graph, inputs)
+    out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device)
+    rmdir(MODEL_HOME)
+
+    assert_shape(out_flow, out_tvm)
+    tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol)
+
+
+def verify_pool(
+    model,
+    name="",
+    rtol=1e-5,
+    atol=1e-5,
+    inputs=flow.tensor(
+        np.random.rand(1, 3, 224, 224),
+        dtype=flow.float32,
+    ),
+    device="llvm",
+):
+    if device == "cuda":
+        model.to(device)
+        inputs = inputs.to(device)
+
+    graph = OneFlowGraph(model)
+    graph._compile(inputs)
+
+    mkdir(MODEL_HOME)
+    flow.save(model.state_dict(), MODEL_HOME)
+
+    out_flow = get_oneflow_output(graph, inputs)
+    out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device)
+    rmdir(MODEL_HOME)
+
+    assert_shape(out_flow, out_tvm)
+    tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol)
+
+
+def verify_normalization(
+    model,
+    name="",
+    rtol=1e-5,
+    atol=1e-5,
+    inputs=flow.tensor(
+        np.random.rand(1, 3, 224, 224),
+        dtype=flow.float32,
+    ),
+    device="llvm",
+):
+    if device == "cuda":
+        model.to(device)
+        inputs = inputs.to(device)
+
+    graph = OneFlowGraph(model)
+    graph._compile(inputs)
+
+    # write params
+    mkdir(MODEL_HOME)
+    flow.save(model.state_dict(), MODEL_HOME)
+
+    out_flow = get_oneflow_output(graph, inputs)
+    out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device)
+    rmdir(MODEL_HOME)
+
+    assert_shape(out_flow, out_tvm)
+    tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol)
+
+
+def verify_upsample(
+    model,
+    name="",
+    rtol=1e-5,
+    atol=1e-5,
+    inputs=flow.tensor(
+        np.random.rand(1, 3, 50, 50),
+        dtype=flow.float32,
+    ),
+    device="llvm",
+):
+    if device == "cuda":
+        model.to(device)
+        inputs = inputs.to(device)
+
+    graph = OneFlowGraph(model)
+    graph._compile(inputs)
+
+    mkdir(MODEL_HOME)
+    flow.save(model.state_dict(), MODEL_HOME)
+
+    out_flow = get_oneflow_output(graph, inputs)
+    out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device)
+    rmdir(MODEL_HOME)
+
+    assert_shape(out_flow, out_tvm)
+    tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol)
+
+
+def verify_convtran(
+    model,
+    name="",
+    rtol=1e-5,
+    atol=1e-5,
+    inputs=flow.tensor(
+        np.random.rand(1, 3, 50, 50),
+        dtype=flow.float32,
+    ),
+    device="llvm",
+):
+    if device == "cuda":
+        model.to(device)
+        inputs = inputs.to(device)
+
+    graph = OneFlowGraph(model)
+    graph._compile(inputs)
+
+    mkdir(MODEL_HOME)
+    flow.save(model.state_dict(), MODEL_HOME)
+
+    out_flow = get_oneflow_output(graph, inputs)
+    out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device)
+    rmdir(MODEL_HOME)
+
+    assert_shape(out_flow, out_tvm)
+    tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol)
+
+
+def verify_activation(
+    model,
+    name="",
+    rtol=1e-5,
+    atol=1e-5,
+    inputs=flow.tensor(
+        np.random.rand(10, 10),
+        dtype=flow.float32,
+    ),
+    device="llvm",
+):
+    if device == "cuda":
+        model.to(device)
+        inputs = inputs.to(device)
+
+    graph = OneFlowGraph(model)
+    graph._compile(inputs)
+
+    mkdir(MODEL_HOME)
+    flow.save(model.state_dict(), MODEL_HOME)
+
+    out_flow = get_oneflow_output(graph, inputs)
+    out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device)
+    rmdir(MODEL_HOME)
+
+    assert_shape(out_flow, out_tvm)
+    tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol)
+
+
+def verify_math(
+    model,
+    name="",
+    rtol=1e-5,
+    atol=1e-5,
+    inputs=flow.tensor(
+        np.random.rand(100, 1),
+        dtype=flow.float32,
+    ),
+    device="llvm",
+):
+    if device == "cuda":
+        model.to(device)
+        inputs = inputs.to(device)
+
+    graph = OneFlowGraph(model)
+    graph._compile(inputs)
+
+    mkdir(MODEL_HOME)
+    flow.save(model.state_dict(), MODEL_HOME)
+
+    out_flow = get_oneflow_output(graph, inputs)
+    out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device)
+    rmdir(MODEL_HOME)
+
+    assert_shape(out_flow, out_tvm)
+    tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol)
+
+
+def verify_concat(
+    model,
+    name="",
+    rtol=1e-5,
+    atol=1e-5,
+    inputs1=flow.tensor(np.random.randn(2, 5, 5, 4), dtype=flow.float32),
+    inputs2=flow.tensor(np.random.randn(2, 5, 5, 2), dtype=flow.float32),
+    inputs3=flow.tensor(np.random.randn(2, 5, 5, 3), dtype=flow.float32),
+    device="llvm",
+):
+    if device == "cuda":
+        model.to(device)
+        inputs1 = inputs1.to(device)
+        inputs2 = inputs2.to(device)
+        inputs3 = inputs3.to(device)
+
+    graph = OneFlowGraph_v2(model)
+    graph._compile(inputs1, inputs2, inputs3)
+
+    mkdir(MODEL_HOME)
+    flow.save(model.state_dict(), MODEL_HOME)
+
+    out_flow = get_oneflow_concat_output(graph, inputs1, inputs2, inputs3)
+    out_tvm = get_tvm_concat_output(graph, MODEL_HOME, inputs1, inputs2, inputs3, target=device)
+    rmdir(MODEL_HOME)
+
+    assert_shape(out_flow, out_tvm)
+    tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol)
+
+
+# defs/nn
+@tvm.testing.uses_gpu
+def test_conv2d():
+    class Conv2dModel(flow.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = flow.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
+
+        def forward(self, x):
+            x = self.conv(x)
+            return x
+
+    if os.path.exists(MODEL_HOME):
+        rmdir(MODEL_HOME)
+
+    model = Conv2dModel()
+    model.eval()
+
+    for device in ["llvm"]:
+        verify_conv(model, device=device)
+
+
+@tvm.testing.uses_gpu
+def test_pool2d():
+    class MaxPool2dModel(flow.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = flow.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+        def forward(self, x):
+            x = self.pool(x)
+            return x
+
+    class AvgPool2dModel(flow.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = flow.nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
+
+        def forward(self, x):
+            x = self.pool(x)
+            return x
+
+    class AdaptiveAvgPool2dModel(flow.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = flow.nn.AdaptiveAvgPool2d((None, 7))
+
+        def forward(self, x):
+            x = self.pool(x)
+            return x
+
+    if os.path.exists(MODEL_HOME):
+        rmdir(MODEL_HOME)
+
+    model1 = MaxPool2dModel().eval()
+    model2 = AvgPool2dModel().eval()
+    model3 = AdaptiveAvgPool2dModel().eval()
+
+    for device in ["llvm"]:
+        verify_pool(model1, device=device)
+        verify_pool(model2, device=device)
+        verify_pool(model3, device=device)
+
+
+@tvm.testing.uses_gpu
+def test_normalization():
+    class BatchNorm2dModel(flow.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.normalization = flow.nn.BatchNorm2d(3)
+
+        def forward(self, x):
+            x = self.normalization(x)
+            return x
+
+    if os.path.exists(MODEL_HOME):
+        rmdir(MODEL_HOME)
+
+    model = BatchNorm2dModel().eval()
+
+    for device in ["llvm"]:
+        verify_normalization(model, device=device)
+
+
+@tvm.testing.uses_gpu
+def test_upsample():
+    class UpsampleModel(flow.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.upsample = flow.nn.Upsample(scale_factor=2.0, mode="nearest")
+
+        def forward(self, x):
+            x = self.upsample(x)
+            return x
+
+    class UpsampleBiliModel(flow.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.upsample = flow.nn.UpsamplingBilinear2d(scale_factor=2.0)
+
+        def forward(self, x):
+            x = self.upsample(x)
+            return x
+
+    if os.path.exists(MODEL_HOME):
+        rmdir(MODEL_HOME)
+
+    model1 = UpsampleModel().eval()
+    model2 = UpsampleBiliModel().eval()
+
+    for device in ["llvm"]:
+        verify_upsample(model1, device=device)
+        verify_upsample(model2, device=device)
+
+
+@tvm.testing.uses_gpu
+def test_convtran():
+    class ConvTranModel(flow.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.convtran = flow.nn.ConvTranspose2d(3, 4, (3, 5), stride=(2, 1), padding=(4, 2))
+
+        def forward(self, x):
+            x = self.convtran(x)
+            return x
+
+    if os.path.exists(MODEL_HOME):
+        rmdir(MODEL_HOME)
+
+    model = ConvTranModel().eval()
+
+    for device in ["llvm"]:
+        verify_convtran(model, device=device)
+
+
+@tvm.testing.uses_gpu
+def test_activation():
+    class Softmax(flow.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.active = flow.nn.Softmax()
+
+        def forward(self, x):
+            x = self.active(x)
+            return x
+
+    class Softplus(flow.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.active = flow.nn.Softplus()
+
+        def forward(self, x):
+            x = self.active(x)
+            return x
+
+    class Softsign(flow.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.active = flow.nn.Softsign()
+
+        def forward(self, x):
+            x = self.active(x)
+            return x
+
+    class Tanh(flow.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.active = flow.nn.Tanh()
+
+        def forward(self, x):
+            x = self.active(x)
+            return x
+
+    class ReLU(flow.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.active = flow.nn.ReLU()
+
+        def forward(self, x):
+            x = self.active(x)
+            return x
+
+    class ReLU6(flow.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.active = flow.nn.ReLU6()
+
+        def forward(self, x):
+            x = self.active(x)
+            return x
+
+    class PReLU(flow.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.active = flow.nn.PReLU()
+
+        def forward(self, x):
+            x = self.active(x)
+            return x
+
+    class SELU(flow.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.active = flow.nn.SELU()
+
+        def forward(self, x):
+            x = self.active(x)
+            return x
+
+    class SiLU(flow.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.active = flow.nn.SiLU()
+
+        def forward(self, x):
+            x = self.active(x)
+            return x
+
+    class LeakyReLU(flow.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.active = flow.nn.LeakyReLU(0.1)
+
+        def forward(self, x):
+            x = self.active(x)
+            return x
+
+    class GELU(flow.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.active = flow.nn.GELU()
+
+        def forward(self, x):
+            x = self.active(x)
+            return x
+
+    if os.path.exists(MODEL_HOME):
+        rmdir(MODEL_HOME)
+
+    model1 = Softmax().eval()
+    model2 = Softplus().eval()
+    model3 = Softsign().eval()
+    model4 = Tanh().eval()
+    model5 = ReLU().eval()
+    model6 = ReLU6().eval()
+    model7 = PReLU().eval()
+    model8 = SELU().eval()
+    model9 = SiLU().eval()
+    model10 = LeakyReLU().eval()
+    model11 = GELU().eval()
+
+    for device in ["llvm"]:
+        verify_activation(model1, device=device)
+        # verify_activation(model2, device=device)  # Softplus conversion does not pass yet
+        verify_activation(model3, device=device)
+        verify_activation(model4, device=device)
+        verify_activation(model5, device=device)
+        verify_activation(model6, device=device)
+        verify_activation(model7, device=device)
+        verify_activation(model8, device=device)
+        verify_activation(model9, device=device)
+        verify_activation(model10, device=device)
+        verify_activation(model11, device=device)
+
+
+@tvm.testing.uses_gpu
+def test_math():
+    class Sigmoid(flow.nn.Module):
+        def forward(self, x):
+            return flow.sigmoid(x)
+
+    class Sign(flow.nn.Module):
+        def forward(self, x):
+            return flow.sign(x)
+
+    class Reciprocal(flow.nn.Module):
+        def forward(self, x):
+            return flow.reciprocal(x)
+
+    class Pow(flow.nn.Module):
+        def forward(self, x):
+            return flow.pow(x, 2.0)
+
+    class Log(flow.nn.Module):
+        def forward(self, x):
+            return flow.log(x)
+
+    class Log2(flow.nn.Module):
+        def forward(self, x):
+            return flow.log1p(x)
+
+    class Exp(flow.nn.Module):
+        def forward(self, x):
+            return flow.exp(x)
+
+    class Exp2(flow.nn.Module):
+        def forward(self, x):
+            return flow.expm1(x)
+
+    model1 = Sigmoid().eval()
+    model2 = Sign().eval()
+    model3 = Log().eval()
+    model4 = Log2().eval()
+    model5 = Exp().eval()
+    model6 = Exp2().eval()
+
+    for device in ["llvm"]:
+        verify_math(model1, device=device)
+        verify_math(model2, device=device)
+        verify_math(model3, device=device)
+        verify_math(model4, device=device)
+        verify_math(model5, device=device)
+        verify_math(model6, device=device)
+
+
+@tvm.testing.uses_gpu
+def test_slice():
+    class Slice(flow.nn.Module):
+        def forward(self, x):
+            tup_list = [[None, None, None], [0, 5, 2], [0, 6, 3]]
+            out = flow.slice(x, slice_tup_list=tup_list)
+            return out
+
+    model = Slice().eval()
+
+    for device in ["llvm"]:
+        verify_math(
+            model, device=device, inputs=flow.tensor(np.random.randn(3, 6, 9).astype(np.float32))
+        )
+
+
+@tvm.testing.uses_gpu
+def test_concat():
+    class Concat(flow.nn.Module):
+        def forward(self, x1, x2, x3):
+            out = flow.cat([x1, x2, x3], dim=-1)
+            return out
+
+    model = Concat().eval()
+
+    for device in ["llvm"]:
+        verify_concat(model, device=device)
+
+
+if __name__ == "__main__":
+    test_conv2d()
+    test_pool2d()
+    test_normalization()
+    test_upsample()
+    test_convtran()
+    test_activation()
+    test_math()
+    test_slice()
+    test_concat()
+    rmdir("log")
diff --git a/tests/scripts/task_python_frontend.sh b/tests/scripts/task_python_frontend.sh
index bbcba37c6d..2c7e34fac5 100755
--- a/tests/scripts/task_python_frontend.sh
+++ b/tests/scripts/task_python_frontend.sh
@@ -58,3 +58,6 @@ run_pytest cython python-frontend-paddlepaddle tests/python/frontend/paddlepaddl
 
 echo "Running relay CoreML frontend test..."
 run_pytest cython python-frontend-coreml tests/python/frontend/coreml
+
+echo "Running relay OneFlow frontend test..."
+run_pytest cython python-frontend-oneflow tests/python/frontend/oneflow