Posted to commits@singa.apache.org by ch...@apache.org on 2020/04/09 12:27:48 UTC
[singa] branch master updated: add mean square error back and fix test case
This is an automated email from the ASF dual-hosted git repository.
chrishkchris pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/singa.git
The following commit(s) were added to refs/heads/master by this push:
new b5e77bb add mean square error back and fix test case
new 2483326 Merge pull request #672 from joddiy/fix
b5e77bb is described below
commit b5e77bb3673cfebdddeb3b165509bde4f5a8e591
Author: joddiy <jo...@qq.com>
AuthorDate: Thu Apr 9 19:26:42 2020 +0800
add mean square error back and fix test case
---
examples/onnx/arcface.py | 1 +
examples/onnx/fer_emotion.py | 3 +-
examples/onnx/mnist.py | 8 +-
examples/onnx/mobilenet.py | 1 +
examples/onnx/resnet18.py | 3 +-
examples/onnx/tiny_yolov2.py | 1 -
examples/onnx/utils.py | 1 +
examples/onnx/vgg16.py | 3 +-
python/singa/autograd.py | 95 ++++++++++------
test/python/test_onnx_backend.py | 226 ++++++++++++++++++++++-----------------
test/python/test_operation.py | 43 ++++++--
11 files changed, 236 insertions(+), 149 deletions(-)
diff --git a/examples/onnx/arcface.py b/examples/onnx/arcface.py
index 933e1b1..e1cfa18 100644
--- a/examples/onnx/arcface.py
+++ b/examples/onnx/arcface.py
@@ -32,6 +32,7 @@ from utils import download_model, update_batch_size, check_exist_or_download
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
+
def preprocess(img):
w, h = img.size
img = img.crop((0, (h - w) // 2, w, h - (h - w) // 2))
diff --git a/examples/onnx/fer_emotion.py b/examples/onnx/fer_emotion.py
index adbadaf..46c0142 100644
--- a/examples/onnx/fer_emotion.py
+++ b/examples/onnx/fer_emotion.py
@@ -30,6 +30,7 @@ from utils import download_model, update_batch_size, check_exist_or_download
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
+
def preprocess(img):
input_shape = (1, 1, 64, 64)
img = img.resize((64, 64), Image.ANTIALIAS)
@@ -84,7 +85,7 @@ if __name__ == "__main__":
autograd.training = False
model = Infer(sg_ir)
- # verify the test
+ # verify the test
# from utils import load_dataset
# inputs, ref_outputs = load_dataset(os.path.join('/tmp', 'emotion_ferplus', 'test_data_set_0'))
# x_batch = tensor.Tensor(device=dev, data=inputs[0])
diff --git a/examples/onnx/mnist.py b/examples/onnx/mnist.py
index d4f7d2f..cf36727 100644
--- a/examples/onnx/mnist.py
+++ b/examples/onnx/mnist.py
@@ -21,7 +21,6 @@ import gzip
import numpy as np
import codecs
-
from singa import device
from singa import tensor
from singa import opt
@@ -33,6 +32,7 @@ from utils import check_exist_or_download
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)-15s %(message)s')
+
def load_dataset():
train_x_url = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'
train_y_url = 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'
@@ -143,7 +143,7 @@ def train(model,
if b % 1e2 == 0:
logging.info("acc %6.2f loss, %6.2f" %
- (accuracy_rate, tensor.to_numpy(loss)[0]))
+ (accuracy_rate, tensor.to_numpy(loss)[0]))
logging.info("training completed")
return x_batch, output_batch
@@ -196,7 +196,7 @@ def re_train(sg_ir,
if b % 1e2 == 0:
logging.info("acc %6.2f loss, %6.2f" %
- (accuracy_rate, tensor.to_numpy(loss)[0]))
+ (accuracy_rate, tensor.to_numpy(loss)[0]))
logging.info("re-training completed")
return new_model
@@ -251,7 +251,7 @@ def transfer_learning(sg_ir,
if b % 1e2 == 0:
logging.info("acc %6.2f loss, %6.2f" %
- (accuracy_rate, tensor.to_numpy(loss)[0]))
+ (accuracy_rate, tensor.to_numpy(loss)[0]))
logging.info("transfer-learning completed")
return trans_model
diff --git a/examples/onnx/mobilenet.py b/examples/onnx/mobilenet.py
index e9fd90c..75758f1 100644
--- a/examples/onnx/mobilenet.py
+++ b/examples/onnx/mobilenet.py
@@ -30,6 +30,7 @@ from utils import download_model, update_batch_size, check_exist_or_download
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
+
def preprocess(img):
img = img.resize((256, 256))
img = img.crop((16, 16, 240, 240))
diff --git a/examples/onnx/resnet18.py b/examples/onnx/resnet18.py
index c0ef13a..b3381c0 100644
--- a/examples/onnx/resnet18.py
+++ b/examples/onnx/resnet18.py
@@ -30,6 +30,7 @@ from utils import download_model, update_batch_size, check_exist_or_download
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
+
def preprocess(img):
img = img.resize((256, 256))
img = img.crop((16, 16, 240, 240))
@@ -89,7 +90,7 @@ if __name__ == "__main__":
autograd.training = False
model = Infer(sg_ir)
- # verify the test
+ # verify the test
# from utils import load_dataset
# inputs, ref_outputs = load_dataset(os.path.join('/tmp', 'resnet18v1', 'test_data_set_0'))
# x_batch = tensor.Tensor(device=dev, data=inputs[0])
diff --git a/examples/onnx/tiny_yolov2.py b/examples/onnx/tiny_yolov2.py
index 8aff769..e883117 100644
--- a/examples/onnx/tiny_yolov2.py
+++ b/examples/onnx/tiny_yolov2.py
@@ -20,7 +20,6 @@ import os
import numpy as np
from PIL import Image, ImageDraw
-
from singa import device
from singa import tensor
from singa import autograd
diff --git a/examples/onnx/utils.py b/examples/onnx/utils.py
index aff4492..71d1ef4 100644
--- a/examples/onnx/utils.py
+++ b/examples/onnx/utils.py
@@ -25,6 +25,7 @@ from onnx import numpy_helper
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)-15s %(message)s')
+
def download_model(url):
download_dir = '/tmp/'
with tarfile.open(check_exist_or_download(url), 'r') as t:
diff --git a/examples/onnx/vgg16.py b/examples/onnx/vgg16.py
index d97b025..b26ea94 100644
--- a/examples/onnx/vgg16.py
+++ b/examples/onnx/vgg16.py
@@ -30,6 +30,7 @@ from utils import download_model, update_batch_size, check_exist_or_download
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)-15s %(message)s')
+
def preprocess(img):
img = img.resize((256, 256))
img = img.crop((16, 16, 240, 240))
@@ -88,7 +89,7 @@ if __name__ == "__main__":
autograd.training = False
model = Infer(sg_ir)
- # verify the test
+ # verify the test
# from utils import load_dataset
# inputs, ref_outputs = load_dataset(os.path.join('/tmp', 'vgg16', 'test_data_set_0'))
# x_batch = tensor.Tensor(device=dev, data=inputs[0])
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index 3eb7c2b..d26e794 100644
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -586,6 +586,7 @@ def identity(x):
"""
return Identity()(x)[0]
+
class Matmul(Operation):
"""
Init matrix multiplication operator.
@@ -711,6 +712,7 @@ def add_bias(x, b, axis=0):
"""
return AddBias(axis)(x, b)[0]
+
class Reshape(Operation):
"""
Reshape the input tensor similar to np.reshape.
@@ -887,6 +889,7 @@ class Elu(Operation):
`f(x) = alpha * (exp(x) - 1.)` for x < 0, `f(x) = x` for x >= 0., is applied to
the tensor elementwise.
"""
+
def __init__(self, alpha=1.):
"""
Args:
@@ -948,6 +951,7 @@ class Equal(Operation):
Returns the tensor resulted from performing the equal logical operation
elementwise on the input tensors x and y.
"""
+
def __init__(self):
super(Equal, self).__init__()
@@ -1209,6 +1213,29 @@ def softmax_cross_entropy(x, t):
return SoftMaxCrossEntropy(t)(x)[0]
+class MeanSquareError(Operation):
+
+ def __init__(self):
+ super(MeanSquareError, self).__init__()
+
+ def forward(self, x, t):
+ self.err = singa.__sub__(x, t)
+ sqr = singa.Square(self.err)
+ loss = singa.SumAll(sqr)
+ loss /= (x.shape()[0] * 2)
+ return loss
+
+ def backward(self, dy=1.0):
+ dx = self.err
+ dx *= float(1 / self.err.shape()[0])
+ dx *= dy
+ return dx, None
+
+
+def mse_loss(x, t):
+ return MeanSquareError()(x, t)[0]
+
+
def ctensor2numpy(x):
"""
To be used in SoftMax Operation.
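For reference, the math in the new MeanSquareError op above: forward computes loss = sum((x - t)^2) / (2N), where N = x.shape()[0] is the batch size, and backward returns (x - t) / N (scaled by the upstream dy), which is exactly the derivative of that expression. A minimal NumPy sketch of the same math, using the values from the new test case (illustration only, not the SINGA API):

    import numpy as np

    x = np.array([[4.3, 5.4], [3.3, 3.6], [5.7, 6.0]], dtype=np.float32)
    t = np.array([[4.4, 5.3], [3.2, 3.7], [5.4, 6.3]], dtype=np.float32)
    N = x.shape[0]                         # batch size, 3 here
    loss = ((x - t) ** 2).sum() / (2 * N)  # forward: 0.22 / 6 ~= 0.036667
    grad = (x - t) / N                     # backward: d(loss)/dx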
@@ -1252,9 +1279,8 @@ class Flatten(Operation):
shape, axis = self.shape, self.axis
# the axis must be within this range (0, r-1)
assert axis <= len(
- shape
- ) - 1 or axis >= 0, "the axis must be within (0, %d-1)" % len(
- shape)
+ shape) - 1 or axis >= 0, "the axis must be within (0, %d-1)" % len(
+ shape)
# calculate the new shape
new_shape = (1, int(np.prod(shape))) if axis == 0 else (
int(np.prod(shape[0:axis]).astype(int)),
@@ -2360,7 +2386,7 @@ class Tanh(Operation):
dy (CTensor): the gradient tensor from upper operations
Returns:
CTensor, the gradient over input
- """
+ """
dx = singa.__mul__(self.cache[0], self.cache[0])
dx = singa.MultFloat(dx, -1.0)
dx = singa.AddFloat(dx, 1.0)
@@ -2404,7 +2430,7 @@ class Cos(Operation):
dy (CTensor): the gradient tensor from upper operations
Returns:
CTensor, the gradient over input
- """
+ """
dx = singa.Sin(self.input)
dx = singa.MultFloat(dx, -1.0)
dx *= dy
@@ -2418,7 +2444,7 @@ def cos(x):
x (Tensor): Input tensor
Returns:
Tensor, the output
- """
+ """
return Cos()(x)[0]
@@ -2448,7 +2474,7 @@ class Cosh(Operation):
dy (CTensor): the gradient tensor from upper operations
Returns:
CTensor, the gradient over input
- """
+ """
dx = singa.Sinh(self.input)
dx *= dy
return dx
@@ -2491,7 +2517,7 @@ class Acos(Operation):
dy (CTensor): the gradient tensor from upper operations
Returns:
CTensor, the gradient over input
- """
+ """
dx = singa.Square(self.input)
dx = singa.MultFloat(dx, -1.0)
dx = singa.AddFloat(dx, 1.0)
@@ -2538,7 +2564,7 @@ class Acosh(Operation):
dy (CTensor): the gradient tensor from upper operations
Returns:
CTensor, the gradient over input
- """
+ """
dx = singa.SubFloat(self.input, 1.0)
dx = singa.Sqrt(dx)
temp = singa.AddFloat(self.input, 1.0)
@@ -2585,7 +2611,7 @@ class Sin(Operation):
dy (CTensor): the gradient tensor from upper operations
Returns:
CTensor, the gradient over input
- """
+ """
dx = singa.Cos(self.input)
dx *= dy
return dx
@@ -2627,7 +2653,7 @@ class Sinh(Operation):
dy (CTensor): the gradient tensor from upper operations
Returns:
CTensor, the gradient over input
- """
+ """
dx = singa.Cosh(self.input)
dx *= dy
return dx
@@ -2669,7 +2695,7 @@ class Asin(Operation):
dy (CTensor): the gradient tensor from upper operations
Returns:
CTensor, the gradient over input
- """
+ """
dx = singa.Square(self.input)
dx = singa.MultFloat(dx, -1.0)
dx = singa.AddFloat(dx, 1.0)
@@ -2685,7 +2711,7 @@ def asin(x):
x (Tensor): Input tensor
Returns:
Tensor, the output
- """
+ """
return Asin()(x)[0]
@@ -2715,7 +2741,7 @@ class Asinh(Operation):
dy (CTensor): the gradient tensor from upper operations
Returns:
CTensor, the gradient over input
- """
+ """
dx = singa.Square(self.input)
dx = singa.AddFloat(dx, 1.0)
dx = singa.PowFloat(dx, -0.5)
@@ -2759,7 +2785,7 @@ class Tan(Operation):
dy (CTensor): the gradient tensor from upper operations
Returns:
CTensor, the gradient over input
- """
+ """
dx = singa.Cos(self.input)
dx = singa.Square(dx)
dx = singa.PowFloat(dx, -1.0)
@@ -2803,7 +2829,7 @@ class Atan(Operation):
dy (CTensor): the gradient tensor from upper operations
Returns:
CTensor, the gradient over input
- """
+ """
dx = singa.Square(self.input)
dx = singa.AddFloat(dx, 1.0)
dx = singa.PowFloat(dx, -1.0)
@@ -2818,7 +2844,7 @@ def atan(x):
x (Tensor): Input tensor
Returns:
Tensor, the output
- """
+ """
return Atan()(x)[0]
@@ -2847,7 +2873,7 @@ class Atanh(Operation):
dy (CTensor): the gradient tensor from upper operations
Returns:
CTensor, the gradient over input
- """
+ """
dx = singa.Square(self.input)
dx = singa.MultFloat(dx, -1.0)
dx = singa.AddFloat(dx, 1.0)
@@ -2863,7 +2889,7 @@ def atanh(x):
x (Tensor): Input tensor
Returns:
Tensor, the output
- """
+ """
return Atanh()(x)[0]
@@ -2893,7 +2919,7 @@ class Sigmoid(Operation):
dy (CTensor): the gradient tensor from upper operations
Returns:
CTensor, the gradient over input
- """
+ """
dx = singa.MultFloat(self.cache[0], -1.0)
dx = singa.AddFloat(dx, 1.0)
dx = singa.__mul__(self.cache[0], dx)
@@ -2908,7 +2934,7 @@ def sigmoid(x):
x (Tensor): Input tensor
Returns:
Tensor, the output
- """
+ """
return Sigmoid()(x)[0]
@@ -2916,7 +2942,7 @@ class Mul(Operation):
"""
Performs element-wise binary multiplication (with Numpy-style broadcasting
support).
- """
+ """
def __init__(self):
super(Mul, self).__init__()
@@ -3007,7 +3033,7 @@ class Unsqueeze(Operation):
dy (CTensor): the gradient tensor from upper operations
Returns:
CTensor, the gradient over input
- """
+ """
return singa.Reshape(dy, self.cache)
@@ -3019,7 +3045,7 @@ def unsqueeze(x, axis=-1):
axis (list of int): the dimensions to be inserted.
Returns:
Tensor, the output
- """
+ """
return Unsqueeze(axis)(x)[0]
@@ -3322,7 +3348,7 @@ class Abs(Operation):
dy (CTensor): the gradient tensor from upper operations
Returns:
CTensor, the gradient over input
- """
+ """
dx = singa.Sign(self.input)
dx *= dy
return dx
@@ -3712,7 +3738,7 @@ class Min(Operation):
*x (a list of CTensor): List of tensors for min.
Returns:
CTensor, the output
- """
+ """
assert (len(x) > 0)
self.l = len(x)
if len(x) == 1:
@@ -4039,7 +4065,7 @@ class Max(Operation):
Returns:
CTensor, the output
tuple of CTensor, mask tensor
- """
+ """
m = singa.__sub__(a, b)
mask0 = singa.GEFloat(m, 0)
mask1 = singa.LTFloat(m, 0)
@@ -4052,7 +4078,7 @@ class Max(Operation):
*x (a list of CTensor): List of tensors for max.
Returns:
CTensor, the output
- """
+ """
assert (len(x) > 0)
self.l = len(x)
if len(x) == 1:
@@ -4951,9 +4977,10 @@ class Split(Operation):
the output CTensor.
"""
x_shape = list(x.shape())
- self.axis = self.axis % len(x_shape)
+ self.axis = self.axis % len(x_shape)
if self.parts is None:
- self.parts = [x_shape[self.axis]//self.num_output] * self.num_output
+ self.parts = [x_shape[self.axis] // self.num_output
+ ] * self.num_output
xs = []
_s = 0
for _l in self.parts:
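The Split hunk above only reformats existing logic, but the default is worth spelling out: when self.parts is not given, the split axis is divided into num_output equal chunks. A tiny sketch of that default (hypothetical shapes, not taken from the diff):

    x_shape = [6, 4]        # assumed input shape
    axis, num_output = 0, 3
    parts = [x_shape[axis] // num_output] * num_output   # -> [2, 2, 2]
    # the loop that follows slices consecutive ranges of length parts[i]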
@@ -5267,7 +5294,6 @@ class Cast(Operation):
x = x.AsType(self.to)
return x
-
def backward(self, dy):
"""
backward of Cast
@@ -5337,15 +5363,16 @@ class OneHot(Operation):
self.axis += (rank + 1)
ls = values.shape[0:self.axis]
rs = values.shape[self.axis:rank]
- targets = np.reshape(depth_range, (1,) * len(ls) + depth_range.shape + (1,) * len(rs))
+ targets = np.reshape(depth_range, (1,) * len(ls) + depth_range.shape +
+ (1,) * len(rs))
values = np.reshape(np.mod(values, self.depth), ls + (1,) + rs)
np_tensor = np.asarray(targets == values, dtype=np.float32)
- np_tensor = np_tensor * (self.values[1] - self.values[0]) + self.values[0]
+ np_tensor = np_tensor * (self.values[1] -
+ self.values[0]) + self.values[0]
tmp_tensor = tensor.from_numpy(np_tensor)
tmp_tensor.to_device(indices.device())
return tmp_tensor.data
-
def backward(self, dy):
"""
backward of OneHot
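The OneHot forward shown above builds the encoding by broadcasting a depth range against the wrapped index values, then rescaling the boolean mask into [off_value, on_value]. A NumPy sketch of that broadcast with assumed inputs (rank-1 indices, axis normalized to the last dimension, so ls = indices.shape and rs is empty):

    import numpy as np

    indices = np.array([1, 3])     # assumed indices
    depth, off, on = 4, 0.0, 1.0   # assumed depth and value pair
    targets = np.arange(depth).reshape(1, depth)      # (1,)*len(ls) + (depth,)
    values = np.mod(indices, depth).reshape(2, 1)     # ls + (1,)
    one_hot = (targets == values).astype(np.float32)  # broadcast comparison
    one_hot = one_hot * (on - off) + off              # rescale to [off, on]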
diff --git a/test/python/test_onnx_backend.py b/test/python/test_onnx_backend.py
index be9777a..2144d1d 100644
--- a/test/python/test_onnx_backend.py
+++ b/test/python/test_onnx_backend.py
@@ -46,6 +46,7 @@ def expect(node,
name,
opset_version=_default_opset_version,
decimal=5):
+
def _helper(dev):
onnx_node = sonnx.OnnxNode(node)
input_tensors = {}
@@ -61,12 +62,14 @@ def expect(node,
outputs_dict = sonnx.run_node(onnx_node, input_tensors, opset_version)
for out1, out2 in zip(outputs, outputs_dict.values()):
np.testing.assert_array_almost_equal(out1,
- tensor.to_numpy(out2),
- decimal=decimal)
+ tensor.to_numpy(out2),
+ decimal=decimal)
+
_helper(cpu_dev)
if (singa.USE_CUDA):
_helper(gpu_dev)
+
class TestPythonOnnxBackend(unittest.TestCase):
"""
This class aims to test the backend functionality of sonnx,
@@ -2192,21 +2195,21 @@ class TestPythonOnnxBackend(unittest.TestCase):
name='test_gemm_all_attributes')
def test_constantOfShape_float_ones(self):
- x = np.array([4, 3, 2]).astype(np.int64)
- tensor_value = onnx.helper.make_tensor("value", onnx.TensorProto.FLOAT,
- [1], [1])
- node = onnx.helper.make_node(
- 'ConstantOfShape',
- inputs=['x'],
- outputs=['y'],
- value=tensor_value,
- )
+ x = np.array([4, 3, 2]).astype(np.int64)
+ tensor_value = onnx.helper.make_tensor("value", onnx.TensorProto.FLOAT,
+ [1], [1])
+ node = onnx.helper.make_node(
+ 'ConstantOfShape',
+ inputs=['x'],
+ outputs=['y'],
+ value=tensor_value,
+ )
- y = np.ones(x, dtype=np.float32)
- expect(node,
- inputs=[x],
- outputs=[y],
- name='test_constantofshape_float_ones')
+ y = np.ones(x, dtype=np.float32)
+ expect(node,
+ inputs=[x],
+ outputs=[y],
+ name='test_constantofshape_float_ones')
def test_constantOfShape_int32_zeros(self):
x = np.array([10, 6]).astype(np.int64)
@@ -2620,7 +2623,9 @@ class TestPythonOnnxBackend(unittest.TestCase):
axes = np.array([0, 1], dtype=np.int64)
steps = np.array([1, 1], dtype=np.int64)
- expect(node, inputs=[x, starts, ends, axes, steps], outputs=[y],
+ expect(node,
+ inputs=[x, starts, ends, axes, steps],
+ outputs=[y],
name='test_slice')
def test_slice_neg(self):
@@ -2637,7 +2642,9 @@ class TestPythonOnnxBackend(unittest.TestCase):
steps = np.array([1], dtype=np.int64)
y = x[:, 0:-1]
- expect(node, inputs=[x, starts, ends, axes, steps], outputs=[y],
+ expect(node,
+ inputs=[x, starts, ends, axes, steps],
+ outputs=[y],
name='test_slice_neg')
# not support empty tensor
@@ -2672,7 +2679,9 @@ class TestPythonOnnxBackend(unittest.TestCase):
steps = np.array([1], dtype=np.int64)
y = x[:, 1:1000]
- expect(node, inputs=[x, starts, ends, axes, steps], outputs=[y],
+ expect(node,
+ inputs=[x, starts, ends, axes, steps],
+ outputs=[y],
name='test_slice_end_out_of_bounds')
def test_slice_default_axes(self):
@@ -2687,7 +2696,9 @@ class TestPythonOnnxBackend(unittest.TestCase):
ends = np.array([20, 10, 4], dtype=np.int64)
y = x[:, :, 3:4]
- expect(node, inputs=[x, starts, ends], outputs=[y],
+ expect(node,
+ inputs=[x, starts, ends],
+ outputs=[y],
name='test_slice_default_axes')
def test_slice_default_steps(self):
@@ -2703,7 +2714,9 @@ class TestPythonOnnxBackend(unittest.TestCase):
axes = np.array([0, 1, 2], dtype=np.int64)
y = x[:, :, 3:4]
- expect(node, inputs=[x, starts, ends, axes], outputs=[y],
+ expect(node,
+ inputs=[x, starts, ends, axes],
+ outputs=[y],
name='test_slice_default_steps')
def test_slice_neg_steps(self):
@@ -2720,7 +2733,9 @@ class TestPythonOnnxBackend(unittest.TestCase):
steps = np.array([-1, -3, -2])
y = x[20:0:-1, 10:0:-3, 4:1:-2]
- expect(node, inputs=[x, starts, ends, axes, steps], outputs=[y],
+ expect(node,
+ inputs=[x, starts, ends, axes, steps],
+ outputs=[y],
name='test_slice_neg_steps')
def test_slice_negative_axes(self):
@@ -2736,7 +2751,9 @@ class TestPythonOnnxBackend(unittest.TestCase):
axes = np.array([0, -2, -1], dtype=np.int64)
y = x[:, :, 3:4]
- expect(node, inputs=[x, starts, ends, axes], outputs=[y],
+ expect(node,
+ inputs=[x, starts, ends, axes],
+ outputs=[y],
name='test_slice_negative_axes')
def test_split_1d(self):
@@ -2746,51 +2763,67 @@ class TestPythonOnnxBackend(unittest.TestCase):
'Split',
inputs=['input'],
outputs=['output_1', 'output_2', 'output_3'],
- axis=0
- )
+ axis=0)
- expected_outputs = [np.array([1., 2.]).astype(np.float32), np.array([3., 4.]).astype(np.float32), np.array([5., 6.]).astype(np.float32)]
- expect(node, inputs=[input], outputs=[y for y in expected_outputs], name='test_split_equal_parts_1d')
+ expected_outputs = [
+ np.array([1., 2.]).astype(np.float32),
+ np.array([3., 4.]).astype(np.float32),
+ np.array([5., 6.]).astype(np.float32)
+ ]
+ expect(node,
+ inputs=[input],
+ outputs=[y for y in expected_outputs],
+ name='test_split_equal_parts_1d')
- node = onnx.helper.make_node(
- 'Split',
- inputs=['input'],
- outputs=['output_1', 'output_2'],
- axis=0,
- split=[2, 4]
- )
+ node = onnx.helper.make_node('Split',
+ inputs=['input'],
+ outputs=['output_1', 'output_2'],
+ axis=0,
+ split=[2, 4])
- expected_outputs = [np.array([1., 2.]).astype(np.float32), np.array([3., 4., 5., 6.]).astype(np.float32)]
- expect(node, inputs=[input], outputs=[y for y in expected_outputs], name='test_split_variable_parts_1d')
+ expected_outputs = [
+ np.array([1., 2.]).astype(np.float32),
+ np.array([3., 4., 5., 6.]).astype(np.float32)
+ ]
+ expect(node,
+ inputs=[input],
+ outputs=[y for y in expected_outputs],
+ name='test_split_variable_parts_1d')
def test_split_2d(self):
- input = np.array([[1., 2., 3., 4., 5., 6.],
- [7., 8., 9., 10., 11., 12.]]).astype(np.float32)
+ input = np.array([[1., 2., 3., 4., 5., 6.], [7., 8., 9., 10., 11.,
+ 12.]]).astype(np.float32)
- node = onnx.helper.make_node(
- 'Split',
- inputs=['input'],
- outputs=['output_1', 'output_2'],
- axis=1
- )
+ node = onnx.helper.make_node('Split',
+ inputs=['input'],
+ outputs=['output_1', 'output_2'],
+ axis=1)
- expected_outputs = [np.array([[1., 2., 3.], [7., 8., 9.]]).astype(np.float32),
- np.array([[4., 5., 6.], [10., 11., 12.]]).astype(np.float32)]
+ expected_outputs = [
+ np.array([[1., 2., 3.], [7., 8., 9.]]).astype(np.float32),
+ np.array([[4., 5., 6.], [10., 11., 12.]]).astype(np.float32)
+ ]
- expect(node, inputs=[input], outputs=[y for y in expected_outputs], name='test_split_equal_parts_2d')
+ expect(node,
+ inputs=[input],
+ outputs=[y for y in expected_outputs],
+ name='test_split_equal_parts_2d')
- node = onnx.helper.make_node(
- 'Split',
- inputs=['input'],
- outputs=['output_1', 'output_2'],
- axis=1,
- split=[2, 4]
- )
+ node = onnx.helper.make_node('Split',
+ inputs=['input'],
+ outputs=['output_1', 'output_2'],
+ axis=1,
+ split=[2, 4])
- expected_outputs = [np.array([[1., 2.], [7., 8.]]).astype(np.float32),
- np.array([[3., 4., 5., 6.], [9., 10., 11., 12.]]).astype(np.float32)]
+ expected_outputs = [
+ np.array([[1., 2.], [7., 8.]]).astype(np.float32),
+ np.array([[3., 4., 5., 6.], [9., 10., 11., 12.]]).astype(np.float32)
+ ]
- expect(node, inputs=[input], outputs=[y for y in expected_outputs], name='test_split_variable_parts_2d')
+ expect(node,
+ inputs=[input],
+ outputs=[y for y in expected_outputs],
+ name='test_split_variable_parts_2d')
def test_split_default_values(self):
input = np.array([1., 2., 3., 4., 5., 6.]).astype(np.float32)
@@ -2799,21 +2832,31 @@ class TestPythonOnnxBackend(unittest.TestCase):
node = onnx.helper.make_node(
'Split',
inputs=['input'],
- outputs=['output_1', 'output_2', 'output_3']
- )
+ outputs=['output_1', 'output_2', 'output_3'])
- expected_outputs = [np.array([1., 2.]).astype(np.float32), np.array([3., 4.]).astype(np.float32), np.array([5., 6.]).astype(np.float32)]
- expect(node, inputs=[input], outputs=[y for y in expected_outputs], name='test_split_equal_parts_default_axis')
+ expected_outputs = [
+ np.array([1., 2.]).astype(np.float32),
+ np.array([3., 4.]).astype(np.float32),
+ np.array([5., 6.]).astype(np.float32)
+ ]
+ expect(node,
+ inputs=[input],
+ outputs=[y for y in expected_outputs],
+ name='test_split_equal_parts_default_axis')
- node = onnx.helper.make_node(
- 'Split',
- inputs=['input'],
- outputs=['output_1', 'output_2'],
- split=[2, 4]
- )
+ node = onnx.helper.make_node('Split',
+ inputs=['input'],
+ outputs=['output_1', 'output_2'],
+ split=[2, 4])
- expected_outputs = [np.array([1., 2.]).astype(np.float32), np.array([3., 4., 5., 6.]).astype(np.float32)]
- expect(node, inputs=[input], outputs=[y for y in expected_outputs], name='test_split_variable_parts_default_axis')
+ expected_outputs = [
+ np.array([1., 2.]).astype(np.float32),
+ np.array([3., 4., 5., 6.]).astype(np.float32)
+ ]
+ expect(node,
+ inputs=[input],
+ outputs=[y for y in expected_outputs],
+ name='test_split_variable_parts_default_axis')
# not support empty tensor
# def test_split_zero_size_splits(self):
@@ -2841,7 +2884,9 @@ class TestPythonOnnxBackend(unittest.TestCase):
indices = np.array([0, 1, 3])
y = np.take(data, indices, axis=0)
- expect(node, inputs=[data, indices.astype(np.int64)], outputs=[y],
+ expect(node,
+ inputs=[data, indices.astype(np.int64)],
+ outputs=[y],
name='test_gather_0')
def test_gather_1(self):
@@ -2855,7 +2900,9 @@ class TestPythonOnnxBackend(unittest.TestCase):
indices = np.array([0, 1, 3])
y = np.take(data, indices, axis=1)
- expect(node, inputs=[data, indices.astype(np.int64)], outputs=[y],
+ expect(node,
+ inputs=[data, indices.astype(np.int64)],
+ outputs=[y],
name='test_gather_1')
def test_gather_negative_indices(self):
@@ -2869,47 +2916,32 @@ class TestPythonOnnxBackend(unittest.TestCase):
indices = np.array([0, -9, -10])
y = np.take(data, indices, axis=0)
- expect(node, inputs=[data, indices.astype(np.int64)], outputs=[y],
+ expect(node,
+ inputs=[data, indices.astype(np.int64)],
+ outputs=[y],
name='test_gather_negative_indices')
def test_tile(self):
- node = onnx.helper.make_node(
- 'Tile',
- inputs=['x', 'y'],
- outputs=['z']
- )
+ node = onnx.helper.make_node('Tile', inputs=['x', 'y'], outputs=['z'])
x = np.random.rand(2, 3, 4, 5).astype(np.float32)
- repeats = np.random.randint(low=1, high=10, size=(np.ndim(x),)).astype(np.int64)
+ repeats = np.random.randint(low=1, high=10,
+ size=(np.ndim(x),)).astype(np.int64)
z = np.tile(x, repeats)
- expect(node,
- inputs=[x, repeats],
- outputs=[z],
- name='test_tile')
+ expect(node, inputs=[x, repeats], outputs=[z], name='test_tile')
def test_tile_precomputed(self):
- node = onnx.helper.make_node(
- 'Tile',
- inputs=['x', 'y'],
- outputs=['z']
- )
+ node = onnx.helper.make_node('Tile', inputs=['x', 'y'], outputs=['z'])
- x = np.array([
- [0, 1],
- [2, 3]
- ], dtype=np.float32)
+ x = np.array([[0, 1], [2, 3]], dtype=np.float32)
repeats = np.array([2, 2], dtype=np.int64)
- z = np.array([
- [0, 1, 0, 1],
- [2, 3, 2, 3],
- [0, 1, 0, 1],
- [2, 3, 2, 3]
- ], dtype=np.float32)
+ z = np.array([[0, 1, 0, 1], [2, 3, 2, 3], [0, 1, 0, 1], [2, 3, 2, 3]],
+ dtype=np.float32)
expect(node,
inputs=[x, repeats],
diff --git a/test/python/test_operation.py b/test/python/test_operation.py
index 44a903e..557fd49 100755
--- a/test/python/test_operation.py
+++ b/test/python/test_operation.py
@@ -518,6 +518,30 @@ class TestPythonOperation(unittest.TestCase):
def test_numerical_gradients_check_for_lstm_gpu(self):
self._numerical_gradients_check_for_lstm_helper(gpu_dev)
+ def _MeanSquareError_helper(self, dev):
+ X = np.array([4.3, 5.4, 3.3, 3.6, 5.7,
+ 6.0]).reshape(3, 2).astype(np.float32)
+ T = np.array([4.4, 5.3, 3.2, 3.7, 5.4,
+ 6.3]).reshape(3, 2).astype(np.float32)
+ x = tensor.from_numpy(X)
+ t = tensor.from_numpy(T)
+ x.to_device(dev)
+ t.to_device(dev)
+
+ loss = autograd.mse_loss(x, t)
+ dx = loss.creator.backward()[0]
+
+ loss_np = tensor.to_numpy(loss)[0]
+ self.assertAlmostEqual(loss_np, 0.0366666, places=4)
+ self.check_shape(dx.shape(), (3, 2))
+
+ def test_MeanSquareError_cpu(self):
+ self._MeanSquareError_helper(cpu_dev)
+
+ @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+ def test_MeanSquareError_gpu(self):
+ self._MeanSquareError_helper(gpu_dev)
+
def _Abs_helper(self, dev):
X = np.array([0.8, -1.2, 3.3, -3.6, -0.5,
0.5]).reshape(3, 2).astype(np.float32)
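The new test drives the op directly via loss.creator.backward(); in normal use the loss would go through autograd like any other. A hedged usage sketch (assumes a working SINGA build; names follow this commit):

    import numpy as np
    from singa import autograd, device, tensor

    dev = device.get_default_device()
    autograd.training = True
    x = tensor.from_numpy(np.random.rand(3, 2).astype(np.float32))
    t = tensor.from_numpy(np.random.rand(3, 2).astype(np.float32))
    x.to_device(dev)
    t.to_device(dev)
    loss = autograd.mse_loss(x, t)   # new op from this commit
    # autograd.backward(loss) would then propagate (x - t) / N back to x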
@@ -2364,20 +2388,19 @@ class TestPythonOperation(unittest.TestCase):
dx0, dx1 = result.creator.backward(dy.data)
# use relative and absolute tolerance instead of decimal places
np.testing.assert_allclose(tensor.to_numpy(result),
- y,
- rtol=1e-4,
- atol=1e-4)
+ y,
+ rtol=1e-4,
+ atol=1e-4)
np.testing.assert_allclose(tensor.to_numpy(
tensor.from_raw_tensor(dx0)),
- grad0,
- rtol=1e-4,
- atol=1e-4)
+ grad0,
+ rtol=1e-4,
+ atol=1e-4)
np.testing.assert_allclose(tensor.to_numpy(
tensor.from_raw_tensor(dx1)),
- grad1,
- rtol=1e-4,
- atol=1e-4)
-
+ grad1,
+ rtol=1e-4,
+ atol=1e-4)
def test_div_broadcast_cpu(self):
self._div_broadcast_helper(cpu_dev)