Posted to commits@tvm.apache.org by jr...@apache.org on 2021/03/24 02:40:50 UTC
[tvm] branch main updated: [ONNX] Onnx node tests (#7720)
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8131364 [ONNX] Onnx node tests (#7720)
8131364 is described below
commit 813136401a11a49d6c15e6013c34dd822a5c4ff6
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Tue Mar 23 20:40:32 2021 -0600
[ONNX] Onnx node tests (#7720)
* WIP
* some fixes
* more fixes
* fix some conv_transpose tests
* fix out of bounds slice
* fix flatten import
* fix logsoftmax and softmax tests
* fix Error in Upsample
* fix onehot
* normalize errors
* fix gather with negative indices
* parameterize test
* skip unsupported tests
* clean up
* fix rebase
* fix lint
* add an error message when we find an unidentified tensor
---
python/tvm/relay/frontend/onnx.py | 133 +++++++++++++++++------
python/tvm/relay/op/transform.py | 7 +-
tests/python/frontend/onnx/test_forward.py | 163 +++++++++++++++++++++++++++++
3 files changed, 269 insertions(+), 34 deletions(-)
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index fab4ae8..d9fc2ff 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -103,10 +103,11 @@ def get_numpy(tensor_proto):
def get_type(elem_type):
"""Converts onnx integer datatype to numpy datatype"""
try:
- from onnx import TensorProto
+ from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
except ImportError as e:
raise ImportError("Unable to import onnx which is required {}".format(e))
- return TensorProto.DataType.Name(elem_type).lower()
+
+ return str(TENSOR_TYPE_TO_NP_TYPE[elem_type])
def get_info(info_proto):
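Note: the previous lookup returned the lowercased ONNX enum name (e.g. "float" for FLOAT), which is not a valid numpy dtype string; the mapping table returns real numpy dtypes. A minimal check, assuming onnx is installed:

    from onnx import TensorProto
    from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE

    # The table maps the integer enum to a numpy dtype object, so str()
    # yields "float32" rather than the enum name "float".
    print(str(TENSOR_TYPE_TO_NP_TYPE[TensorProto.FLOAT]))  # float32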
@@ -157,7 +158,7 @@ def revert_caffe2_pad(pads):
return pads
-def get_pad_pair(input1d, kernel1d, stride1d):
+def get_pad_pair(input1d, kernel1d, stride1d, mode):
"""infer pad size"""
if input1d % stride1d == 0:
pad = max(kernel1d - stride1d, 0)
@@ -165,6 +166,8 @@ def get_pad_pair(input1d, kernel1d, stride1d):
pad = max(kernel1d - (input1d % stride1d), 0)
pad_before = pad // 2
pad_after = pad - pad_before
+ if "LOWER" in mode:
+ return [pad_after, pad_before]
return [pad_before, pad_after]
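The mode only matters when the total padding is odd: SAME_LOWER puts the extra pixel before the data, SAME_UPPER after it. The same arithmetic as a standalone sketch:

    def same_pad(input1d, kernel1d, stride1d, mode):
        # total padding so that output length == ceil(input1d / stride1d)
        if input1d % stride1d == 0:
            pad = max(kernel1d - stride1d, 0)
        else:
            pad = max(kernel1d - (input1d % stride1d), 0)
        pad_before = pad // 2
        pad_after = pad - pad_before
        if "LOWER" in mode:  # e.g. auto_pad == "SAME_LOWER"
            return [pad_after, pad_before]
        return [pad_before, pad_after]

    print(same_pad(5, 2, 2, "SAME_UPPER"))  # [0, 1]
    print(same_pad(5, 2, 2, "SAME_LOWER"))  # [1, 0]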
@@ -280,9 +283,9 @@ class Pool(OnnxOpConverter):
pad_tuple = []
for axis in range(len(input_shape) - 2):
axis_shape = input_shape[2 + axis]
- stride = attr["strides"][axis]
+ stride = attr.get("strides", [1] * ndim)[axis]
kernel = attr["kernel_shape"][axis]
- pad = get_pad_pair(axis_shape, kernel, stride)
+ pad = get_pad_pair(axis_shape, kernel, stride, attr["auto_pad"])
pad_tuple.append(pad)
pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair])
attr["pads"] = pad_tuple
@@ -444,9 +447,15 @@ class ConvTranspose(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
# get number of channels
- channels = infer_channels(inputs[1], True)
+ out_type = infer_type(inputs[1])
+ out_shapes = [get_const_tuple(out_type.checked_type.shape)]
+ channels = out_shapes[0][1]
attr["channels"] = channels
groups = attr.get("group", 1)
+
+ if "kernel_shape" not in attr:
+ attr["kernel_shape"] = out_shapes[0][2:]
+
attr["groups"] = groups
# infer pads for auto_pad
data = inputs[0]
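For ONNX ConvTranspose the weight is laid out as (in_channels, out_channels / groups, *kernel), so with groups == 1 the channel count and kernel shape can be read straight off the weight type. A shape-only illustration (the sizes here are hypothetical):

    # hypothetical 2D transposed-conv weight: (C_in, C_out, kH, kW)
    w_shape = (16, 32, 3, 3)
    channels = w_shape[1]       # 32, what infer_channels used to compute
    kernel_shape = w_shape[2:]  # (3, 3), used when the attribute is absent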
@@ -528,13 +537,11 @@ class Gemm(OnnxOpConverter):
if not transB:
inputs[1] = _op.transpose(inputs[1], axes=(1, 0))
inputs[0] = _op.nn.batch_flatten(inputs[0])
-
if alpha != 1.0:
inputs[0] *= _expr.const(alpha)
out = _op.nn.dense(inputs[0], inputs[1], units=channels)
-
if len(inputs) == 3:
- return _op.nn.bias_add(out, _expr.const(beta) * inputs[2])
+ out = out + _expr.const(beta) * inputs[2]
return out
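The converter now follows the full ONNX Gemm definition Y = alpha * A' @ B' + beta * C with a broadcasting add for the bias (Relay's bias_add expects a 1-D bias, which Gemm's C input does not guarantee). The reference semantics as a numpy sketch:

    import numpy as np

    def gemm_reference(a, b, c=None, alpha=1.0, beta=1.0, transA=0, transB=0):
        # ONNX Gemm: Y = alpha * A' @ B' + beta * C (' = optional transpose)
        if transA:
            a = a.T
        if transB:
            b = b.T
        out = alpha * (a @ b)
        if c is not None:
            out = out + beta * c  # broadcast add, like the new converter
        return out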
@@ -618,7 +625,7 @@ class Mod(OnnxOpConverter):
# Note: attr['fmod'] determines whether the operator should behave like np.fmod or np.mod.
# attr['fmod'] == 0 will behave as np.mod and attr['fmod'] == 1 will force fmod treatment.
# The relay equivalent of np.fmod is relay.mod and np.mod is relay.floor_mod
- if attr["fmod"] == 0:
+ if attr.get("fmod", 0) == 0:
op_name = "floor_mod"
else:
op_name = "mod"
@@ -849,12 +856,18 @@ class Flatten(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
axis = attr.get("axis", 1)
+ ishape = _op.shape_of(inputs[0])
+ ndim = infer_shape(ishape)[0]
+ if axis < 0:
+ axis = axis + ndim
+
if axis == 1:
out = _op.nn.batch_flatten(inputs[0])
else:
- newshape = [0] * (axis + 1)
- newshape[axis] = -1
- out = _op.reshape(inputs[0], list(newshape))
+ pre_shape = _op.prod(_op.strided_slice(ishape, [0], [axis], [1]), keepdims=True)
+ post_shape = _op.prod(_op.strided_slice(ishape, [axis], [ndim], [1]), keepdims=True)
+ newshape = _op.concatenate([pre_shape, post_shape], axis=0)
+ out = _op.reshape(inputs[0], newshape)
return out
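The reshape extents are now computed from the (possibly dynamic) runtime shape, which also handles negative axes; per the ONNX Flatten spec the result is always rank 2. The equivalent static computation, as a numpy sketch:

    import numpy as np

    def flatten_reference(x, axis):
        # collapse dims [0, axis) and [axis, ndim) into one extent each
        axis = axis + x.ndim if axis < 0 else axis
        pre = int(np.prod(x.shape[:axis], dtype=np.int64))   # 1 if axis == 0
        post = int(np.prod(x.shape[axis:], dtype=np.int64))
        return x.reshape(pre, post)

    print(flatten_reference(np.zeros((2, 3, 4, 5)), axis=2).shape)  # (6, 20)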
@@ -1036,7 +1049,7 @@ class Upsample(OnnxOpConverter):
# in 3d case, we use the purely static op
if dims == 5:
- if isinstance(scales, _expr.Call):
+ if isinstance(scales, _expr.Expr):
scale_h = _op.take(scales, _op.const(3))
scale_w = _op.take(scales, _op.const(4))
scale_d = _op.take(scales, _op.const(1))
@@ -1052,7 +1065,7 @@ class Upsample(OnnxOpConverter):
)
# in 2d case, use dynamic op
else:
- if isinstance(scales, _expr.Call):
+ if isinstance(scales, _expr.Expr):
scale_h = _op.take(scales, _op.const(3))
scale_w = _op.take(scales, _op.const(4))
else:
@@ -1247,7 +1260,13 @@ class Gather(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
axis = attr.get("axis", 0)
- return AttrCvt("take", extras={"axis": axis})(inputs, {})
+ data = inputs[0]
+ indices = inputs[1]
+ ind_dtype = infer_type(indices).checked_type.dtype
+ # Normalize the indices to a positive range
+ s = _op.take(_op.shape_of(data, dtype=ind_dtype), _op.const(axis))
+ indices = _op.where(indices < _op.const(0, ind_dtype), indices + s, indices)
+ return _op.take(data, indices, axis)
class GatherElements(OnnxOpConverter):
@@ -1258,6 +1277,10 @@ class GatherElements(OnnxOpConverter):
data = inputs[0]
indices = inputs[1]
axis = attr.get("axis", 0)
+ ind_dtype = infer_type(indices).checked_type.dtype
+ # Normalize the indices to a positive range
+ s = _op.take(_op.shape_of(data, dtype=ind_dtype), _op.const(axis))
+ indices = _op.where(indices < _op.const(0, ind_dtype), indices + s, indices)
return _op.gather(data, axis, indices)
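Gather and GatherElements now remap negative indices into the positive range before calling the Relay op, matching ONNX/numpy semantics. A minimal numpy illustration of the normalization:

    import numpy as np

    data = np.arange(12).reshape(3, 4)
    indices = np.array([-1, 0, -2])
    axis = 0

    dim = data.shape[axis]
    normalized = np.where(indices < 0, indices + dim, indices)
    print(normalized)                       # [2 0 1]
    print(np.take(data, normalized, axis))  # rows 2, 0, 1 of data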
@@ -1318,8 +1341,8 @@ class Maximum(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
- if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
- raise ValueError("Expect minimum 2 inputs")
+ if len(inputs) == 1:
+ return inputs[0]
_max = inputs[0]
for i in range(1, len(inputs)):
_max = AttrCvt("maximum")([_max, inputs[i]], {})
@@ -1331,8 +1354,8 @@ class Minimum(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
- if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
- raise ValueError("Expect minimum 2 inputs")
+ if len(inputs) == 1:
+ return inputs[0]
_min = inputs[0]
for i in range(1, len(inputs)):
_min = AttrCvt("minimum")([_min, inputs[i]], {})
@@ -1344,8 +1367,8 @@ class Mean(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
- if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
- raise ValueError("Expect minimum 2 inputs")
+ if len(inputs) == 1:
+ return inputs[0]
# avoid overflow
concat = _op.concatenate([_op.expand_dims(x, axis=0) for x in inputs], axis=0)
return _op.mean(concat, axis=0, keepdims=False)
@@ -1485,6 +1508,8 @@ class ArgMax(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
+ if "select_last_index" in attr:
+ raise NotImplementedError("select_last_index not supported in ArgMax")
axis = attr.get("axis", 0)
keepdims = attr.get("keepdims", True)
attr = {"axis": axis, "keepdims": keepdims}
@@ -1496,6 +1521,8 @@ class ArgMin(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
+ if "select_last_index" in attr:
+ raise NotImplementedError("select_last_index not supported in ArgMin")
axis = attr.get("axis", 0)
keepdims = attr.get("keepdims", True)
attr = {"axis": axis, "keepdims": keepdims}
@@ -1510,7 +1537,35 @@ class Softmax(OnnxOpConverter):
# set default value when axis is not set in the model
if "axis" not in attr:
attr["axis"] = 1
- return AttrCvt("softmax", transforms={"axis": ("axis", 1)})(inputs, attr, params)
+ axis = attr["axis"]
+ ndim = len(infer_shape(inputs[0]))
+ if axis < 0:
+ axis += ndim
+ axes = list(range(axis, ndim))
+ x = inputs[0]
+ m = _op.max(x, axes, keepdims=True)
+ e = _op.exp(x - m)
+ return e / _op.sum(e, axes, keepdims=True)
+
+
+class LogSoftmax(OnnxOpConverter):
+ """Operator converter for Softmax."""
+
+ @classmethod
+ def _impl_v1(cls, inputs, attr, params):
+ # set default value when axis is not set in the model
+ if "axis" not in attr:
+ attr["axis"] = 1
+ axis = attr["axis"]
+ ndim = len(infer_shape(inputs[0]))
+ if axis < 0:
+ axis += ndim
+ axes = list(range(axis, ndim))
+ x = inputs[0]
+ m = _op.max(x, axes, keepdims=True)
+ e = _op.exp(x - m)
+ s = _op.sum(e, axes, keepdims=True)
+ return x - m - _op.log(s)
class OneHot(OnnxOpConverter):
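Both converters subtract the running maximum before exponentiating (the standard overflow guard) and reduce over every axis from `axis` to the last, which matches the ONNX definition of softmax over the flattened trailing dimensions. A numpy sketch of the same math:

    import numpy as np

    def stable_softmax(x, axis):
        axes = tuple(range(axis % x.ndim, x.ndim))
        m = x.max(axis=axes, keepdims=True)  # max-subtraction for stability
        e = np.exp(x - m)
        return e / e.sum(axis=axes, keepdims=True)

    def stable_log_softmax(x, axis):
        axes = tuple(range(axis % x.ndim, x.ndim))
        m = x.max(axis=axes, keepdims=True)
        s = np.exp(x - m).sum(axis=axes, keepdims=True)
        return x - m - np.log(s)  # log-softmax without forming the quotient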
@@ -1520,14 +1575,24 @@ class OneHot(OnnxOpConverter):
def _impl_v9(cls, inputs, attr, params):
# Extract relay one_hot inputs.
indices, depth, values = inputs
+ ndim = len(infer_shape(indices))
# Split onnx on off values into two separate expressions.
off_value, on_value = _op.take(values, _op.const(0)), _op.take(values, _op.const(1))
# Extract the datatype of the output from on_value.
dtype = infer_type(on_value).checked_type.dtype
+ ind_dtype = infer_type(indices).checked_type.dtype
+ # Normalize the indices to a positive range
+ indices = _op.where(
+ indices < _op.const(0, ind_dtype), indices + _op.cast(depth, ind_dtype), indices
+ )
# set default value when axis is not set in the model
if "axis" not in attr:
attr["axis"] = -1
- return _op.one_hot(indices, on_value, off_value, depth, int(attr["axis"]), dtype=dtype)
+ axis = attr["axis"]
+ if axis < 0:
+ axis += ndim + 1
+
+ return _op.one_hot(indices, on_value, off_value, depth, axis, dtype=dtype)
class ConstantOfShape(OnnxOpConverter):
@@ -1552,7 +1617,7 @@ class Constant(OnnxOpConverter):
@classmethod
def _impl_v9(cls, inputs, attr, params):
if "value" not in attr:
- raise "No Value in Constant"
+ raise tvm.errors.OpAttributeRequired("no value in Constant")
np_value = get_numpy(attr.pop("value"))
dtype = np_value.dtype.name
value = _expr.const(np_value, dtype)
@@ -2042,7 +2107,7 @@ class TopK(OnnxOpConverter):
largest = attr.get("largest", 1)
if largest == 0:
- raise ValueError("TVM only supports finding TopK largest elements")
+ raise NotImplementedError("TVM only supports finding TopK largest elements")
return _op.topk(inputs[0], inputs[1], axis=axis, dtype="int64")
@@ -2087,7 +2152,7 @@ class RoiAlign(OnnxOpConverter):
batch_indices = inputs[2]
mode = attr.get("mode", b"avg")
if mode not in (b"avg", b"max"):
- raise ValueError("RoiAlign in Relay only uses avg and max modes")
+ raise NotImplementedError("RoiAlign in Relay only uses avg and max modes")
output_height = attr.get("output_height", 1)
output_width = attr.get("output_width", 1)
@@ -2128,7 +2193,8 @@ class Clip(OnnxOpConverter):
result = inputs[0]
for i, op in enumerate([_op.tensor.maximum, _op.tensor.minimum]):
if i < len(inputs) - 1:
- result = op(result, inputs[i + 1])
+ if inputs[i + 1] is not None:
+ result = op(result, inputs[i + 1])
return result
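Together with the GraphProto change below that maps omitted optional inputs to None, Clip now applies only the bounds that are actually present. In numpy terms:

    import numpy as np

    def clip_reference(x, amin=None, amax=None):
        if amin is not None:  # optional 'min' input supplied
            x = np.maximum(x, amin)
        if amax is not None:  # optional 'max' input supplied
            x = np.minimum(x, amax)
        return x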
@@ -2393,9 +2459,10 @@ class NonMaxSuppression(OnnxOpConverter):
dtype = infer_type(boxes).checked_type.dtype
if "center_point_box" in attr:
- assert (
- attr["center_point_box"] == 0
- ), "Only support center_point_box = 0 in onnx importer right now"
+ if attr["center_point_box"] != 0:
+ raise NotImplementedError(
+ "Only support center_point_box = 0 in ONNX NonMaxSuprresion"
+ )
if iou_threshold is None:
iou_threshold = _expr.const(0.0, dtype="float32")
@@ -2718,7 +2785,7 @@ def _get_convert_map(opset):
"Softplus": Softplus.get_converter(opset),
# softmax default axis is different in onnx
"Softmax": Softmax.get_converter(opset),
- "LogSoftmax": AttrCvt("log_softmax", {"axis": ("axis", 1)}),
+ "LogSoftmax": LogSoftmax.get_converter(opset),
"OneHot": OneHot.get_converter(opset),
# 'Hardmax'
"Softsign": Softsign.get_converter(opset),
@@ -2958,6 +3025,8 @@ class GraphProto:
for i in node.input:
if i != "":
inputs[i] = self._nodes[self._renames.get(i, i)]
+ else:
+ inputs[i] = None
i_name = self._parse_value_proto(node)
node_output = self._fix_outputs(op_name, node.output)
attr["tvm_custom"] = {}
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 4129b61..df0ae76 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -905,10 +905,13 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"):
end = const(list(end))
if isinstance(strides, (tuple, list)):
strides = const(list(strides))
- normalized_begin = _make.where(
+ begin = _make.where(
begin < cast_like(const(0), begin), begin + cast_like(shape_of(data), begin), begin
)
- return _dyn_make.strided_slice(data, normalized_begin, end, strides, slice_mode)
+ begin = _make.where(
+ begin >= cast_like(shape_of(data), begin), cast_like(shape_of(data), begin), begin
+ )
+ return _dyn_make.strided_slice(data, begin, end, strides, slice_mode)
return _make.strided_slice(data, begin, end, strides, slice_mode)
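The dynamic path now clamps `begin` into [0, dim] after wrapping negative starts, so a start past the end of the axis produces an empty slice instead of an out-of-bounds access. A numpy sketch of the normalization:

    import numpy as np

    def normalize_begin(begin, shape):
        begin, shape = np.asarray(begin), np.asarray(shape)
        begin = np.where(begin < 0, begin + shape, begin)  # wrap negatives
        return np.where(begin >= shape, shape, begin)      # clamp past-the-end

    print(normalize_begin([-1, 10], [4, 4]))  # [3 4]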
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 5a6216a..ec89a3d 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -4090,6 +4090,169 @@ def test_cumsum():
verify_cumsum(data, 1, 1, 1, type="int32")
+from onnx import numpy_helper
+
+f = onnx.__file__
+import glob
+
+onnx_test_folders = sorted(glob.glob("/".join(f.split("/")[0:-1]) + "/backend/test/data/node/*/"))
+
+unsupported_onnx_tests = [
+ "test_basic_convinteger/",
+ "test_bitshift_left_uint16/",
+ "test_bitshift_left_uint32/",
+ "test_bitshift_left_uint64/",
+ "test_bitshift_left_uint8/",
+ "test_bitshift_right_uint16/",
+ "test_bitshift_right_uint32/",
+ "test_bitshift_right_uint64/",
+ "test_bitshift_right_uint8/",
+ "test_cast_DOUBLE_to_FLOAT16/",
+ "test_cast_FLOAT16_to_DOUBLE/",
+ "test_cast_FLOAT16_to_FLOAT/",
+ "test_cast_FLOAT_to_FLOAT16/",
+ "test_cast_FLOAT_to_STRING/",
+ "test_cast_STRING_to_FLOAT/",
+ "test_compress_0/",
+ "test_compress_1/",
+ "test_compress_default_axis/",
+ "test_compress_negative_axis/",
+ "test_convinteger_with_padding/",
+ "test_convtranspose_dilations/",
+ "test_convtranspose_output_shape/",
+ "test_cumsum_1d/",
+ "test_cumsum_1d_exclusive/",
+ "test_cumsum_1d_reverse/",
+ "test_cumsum_1d_reverse_exclusive/",
+ "test_cumsum_2d_axis_0/",
+ "test_cumsum_2d_axis_1/",
+ "test_cumsum_2d_negative_axis/",
+ "test_dequantizelinear/",
+ "test_det_2d/",
+ "test_det_nd/",
+ "test_dynamicquantizelinear/",
+ "test_dynamicquantizelinear_expanded/",
+ "test_dynamicquantizelinear_max_adjusted/",
+ "test_dynamicquantizelinear_max_adjusted_expanded/",
+ "test_dynamicquantizelinear_min_adjusted/",
+ "test_dynamicquantizelinear_min_adjusted_expanded/",
+ "test_eyelike_populate_off_main_diagonal/",
+ "test_eyelike_with_dtype/",
+ "test_eyelike_without_dtype/",
+ "test_hardmax_axis_0/",
+ "test_hardmax_axis_1/",
+ "test_hardmax_axis_2/",
+ "test_hardmax_default_axis/",
+ "test_hardmax_example/",
+ "test_hardmax_negative_axis/",
+ "test_hardmax_one_hot/",
+ "test_isinf_negative/",
+ "test_isinf_positive/",
+ "test_lstm_defaults/",
+ "test_lstm_with_initial_bias/",
+ "test_lstm_with_peepholes/",
+ "test_matmulinteger/",
+ "test_maxpool_2d_dilations/",
+ "test_maxpool_2d_same_lower/",
+ "test_maxpool_2d_same_upper/",
+ "test_maxpool_with_argmax_2d_precomputed_pads/",
+ "test_maxpool_with_argmax_2d_precomputed_strides/",
+ "test_maxunpool_export_with_output_shape/",
+ "test_mvn/",
+ "test_nonmaxsuppression_center_point_box_format/",
+ "test_qlinearconv/",
+ "test_qlinearmatmul_2D/",
+ "test_qlinearmatmul_3D/",
+ "test_quantizelinear/",
+ "test_range_float_type_positive_delta_expanded/",
+ "test_range_int32_type_negative_delta_expanded/",
+ "test_resize_downsample_scales_cubic/",
+ "test_resize_downsample_scales_cubic_A_n0p5_exclude_outside/",
+ "test_resize_downsample_scales_cubic_align_corners/",
+ "test_resize_downsample_scales_linear/",
+ "test_resize_downsample_scales_nearest/",
+ "test_resize_downsample_sizes_cubic/",
+ "test_resize_downsample_sizes_linear_pytorch_half_pixel/",
+ "test_resize_downsample_sizes_nearest/",
+ "test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn/",
+ "test_resize_tf_crop_and_resize/",
+ "test_resize_upsample_scales_cubic/",
+ "test_resize_upsample_scales_cubic_A_n0p5_exclude_outside/",
+ "test_resize_upsample_scales_cubic_align_corners/",
+ "test_resize_upsample_scales_cubic_asymmetric/",
+ "test_resize_upsample_scales_linear/",
+ "test_resize_upsample_sizes_cubic/",
+ "test_resize_upsample_sizes_nearest_ceil_half_pixel/",
+ "test_resize_upsample_sizes_nearest_floor_align_corners/",
+ "test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric/",
+ "test_reversesequence_batch/",
+ "test_reversesequence_time/",
+ "test_rnn_seq_length/",
+ "test_roialign/",
+ "test_round/",
+ "test_scan9_sum/",
+ "test_scan_sum/",
+ "test_scatternd/",
+ "test_selu_default/",
+ "test_shrink_hard/",
+ "test_shrink_soft/",
+ "test_simple_rnn_defaults/",
+ "test_simple_rnn_with_initial_bias/",
+ "test_slice_neg_steps/",
+ "test_slice_start_out_of_bounds/",
+ "test_strnormalizer_export_monday_casesensintive_lower/",
+ "test_strnormalizer_export_monday_casesensintive_nochangecase/",
+ "test_strnormalizer_export_monday_casesensintive_upper/",
+ "test_strnormalizer_export_monday_empty_output/",
+ "test_strnormalizer_export_monday_insensintive_upper_twodim/",
+ "test_strnormalizer_nostopwords_nochangecase/",
+ "test_tfidfvectorizer_tf_batch_onlybigrams_skip0/",
+ "test_tfidfvectorizer_tf_batch_onlybigrams_skip5/",
+ "test_tfidfvectorizer_tf_batch_uniandbigrams_skip5/",
+ "test_tfidfvectorizer_tf_only_bigrams_skip0/",
+ "test_tfidfvectorizer_tf_onlybigrams_levelempty/",
+ "test_tfidfvectorizer_tf_onlybigrams_skip5/",
+ "test_tfidfvectorizer_tf_uniandbigrams_skip5/",
+ "test_top_k_smallest/",
+ "test_unique_not_sorted_without_axis/",
+ "test_unique_sorted_with_axis/",
+ "test_unique_sorted_with_axis_3d/",
+ "test_unique_sorted_with_negative_axis/",
+ "test_unique_sorted_without_axis/",
+ "test_unsqueeze_unsorted_axes/",
+ "test_upsample_nearest/",
+]
+
+
+@pytest.mark.parametrize("test", onnx_test_folders)
+def test_onnx_nodes(test):
+ for failure in unsupported_onnx_tests:
+ if failure in test:
+ pytest.skip()
+ break
+ onnx_model = onnx.load(test + "/model.onnx")
+ inputs = []
+ outputs = []
+ for dataset in glob.glob(test + "/*/"):
+ tensors = sorted(glob.glob(dataset + "/*.pb"))
+ for tensor in tensors:
+ new_tensor = onnx.TensorProto()
+ with open(tensor, "rb") as f:
+ new_tensor.ParseFromString(f.read())
+ if "input" in tensor.split("/")[-1]:
+ inputs.append(numpy_helper.to_array(new_tensor))
+ elif "output" in tensor.split("/")[-1]:
+ outputs.append(numpy_helper.to_array(new_tensor))
+ else:
+ raise ImportError(str(tensor) + " not labeled as an input or an output")
+ tvm_val = get_tvm_output_with_vm(onnx_model, inputs, "llvm", tvm.cpu(0))
+ if len(outputs) == 1:
+ tvm.testing.assert_allclose(outputs[0], tvm_val, rtol=1e-5, atol=1e-5)
+ else:
+ for output, val in zip(outputs, tvm_val):
+ tvm.testing.assert_allclose(output, val, rtol=1e-5, atol=1e-5)
+
+
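Each discovered folder follows the upstream ONNX backend-test layout: a model.onnx next to one or more test_data_set_*/ directories containing input_*.pb and output_*.pb tensors. A single node test can be run in isolation with pytest's -k filter, for example (the "relu" substring here is just an illustration):

    pytest tests/python/frontend/onnx/test_forward.py -k "test_onnx_nodes and relu"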
def test_wrong_input():
node = helper.make_node(
"Softplus",