You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/09/20 00:59:03 UTC

[GitHub] marcoabreu closed pull request #12605: add fluent methods for softmin

marcoabreu closed pull request #12605: add fluent methods for softmin
URL: https://github.com/apache/incubator-mxnet/pull/12605
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this is a foreign pull request (from a fork), the diff is supplied
below; GitHub would not otherwise display it:

diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index d6d619f30ca..93f2bc49e9b 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -1678,6 +1678,14 @@ def log_softmax(self, *args, **kwargs):
         """
         return op.log_softmax(self, *args, **kwargs)
 
+    def softmin(self, *args, **kwargs):
+        """Convenience fluent method for :py:func:`softmin`.
+
+        The arguments are the same as for :py:func:`softmin`, with
+        this array as data.
+        """
+        return op.softmin(self, *args, **kwargs)
+
     def squeeze(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`squeeze`.
 
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 4864ce99163..554539b424a 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -2423,6 +2423,14 @@ def log_softmax(self, *args, **kwargs):
         """
         return op.log_softmax(self, *args, **kwargs)
 
+    def softmin(self, *args, **kwargs):
+        """Convenience fluent method for :py:func:`softmin`.
+
+        The arguments are the same as for :py:func:`softmin`, with
+        this array as data.
+        """
+        return op.softmin(self, *args, **kwargs)
+
     def squeeze(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`squeeze`.
 
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index c48801ec1ce..7a5c7ca4f1b 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -558,7 +558,7 @@ def test_broadcast_like_axis():
             [(1, 7, 9, 1, 1), (9, 1), (-2, -1), (-2, -1), (1, 7, 9, 9, 1)],
             [(2, 1), (1, 7, 9, 1, 1), (1,), (-3,), (2, 9)]
         ]
-        
+
         for test_data in testcases:
             lhs = mx.nd.random.uniform(shape=test_data[0])
             rhs = mx.nd.random.uniform(shape=test_data[1])
@@ -1039,7 +1039,7 @@ def test_ndarray_fluent():
                     'degrees', 'radians', 'sinh', 'cosh', 'tanh', 'arcsinh', 'arccosh', 'arctanh',
                     'exp', 'expm1', 'log', 'log10', 'log2', 'log1p', 'sqrt', 'rsqrt', 'square',
                     'reshape_like', 'cbrt', 'rcbrt', 'relu', 'sigmoid', 'softmax', 'log_softmax',
-                    'reciprocal'])
+                    'softmin', 'reciprocal'])
     def check_fluent_regular(func, kwargs, shape=(5, 17, 1), equal_nan=False):
         with mx.name.NameManager():
             data = mx.nd.random_uniform(shape=shape, ctx=default_context())
@@ -1058,7 +1058,7 @@ def check_fluent_regular(func, kwargs, shape=(5, 17, 1), equal_nan=False):
 
     for func in ['arccosh', 'arcsin', 'arccos', 'arctan', 'tan', 'sinh', 'cosh', 'tanh',
                  'arcsinh', 'arctanh', 'log', 'log10', 'log2', 'log1p', 'sqrt', 'rsqrt',
-                 'cbrt', 'rcbrt', 'relu', 'sigmoid', 'softmax', 'log_softmax']:
+                 'cbrt', 'rcbrt', 'relu', 'sigmoid', 'softmax', 'log_softmax', 'softmin']:
         check_fluent_regular(func, {}, equal_nan=True)
 
     for func in ['expand_dims', 'flip', 'sort', 'topk', 'argsort', 'argmax', 'argmin']:
diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py
index d022c68237a..c5c1b018b08 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -177,7 +177,7 @@ def test_symbol_fluent():
                     'degrees', 'radians', 'sinh', 'cosh', 'tanh', 'arcsinh', 'arccosh', 'arctanh',
                     'exp', 'expm1', 'log', 'log10', 'log2', 'log1p', 'sqrt', 'rsqrt',
                     'square', 'reciprocal' 'reshape_like', 'cbrt', 'rcbrt', 'relu', 'sigmoid',
-                    'softmax', 'log_softmax', 'rint', 'ceil', 'floor', 'trunc', 'fix'])
+                    'softmax', 'log_softmax', 'softmin', 'rint', 'ceil', 'floor', 'trunc', 'fix'])
 
     def check_fluent_regular(func, kwargs, shape=(5, 17, 1), equal_nan=False):
         with mx.name.NameManager():
@@ -196,7 +196,7 @@ def check_fluent_regular(func, kwargs, shape=(5, 17, 1), equal_nan=False):
 
     for func in ['arccosh', 'arcsin', 'arccos', 'arctan', 'tan', 'sinh', 'cosh', 'tanh',
                  'arcsinh', 'arctanh', 'log', 'log10', 'log2', 'log1p', 'sqrt', 'rsqrt',
-                 'cbrt', 'rcbrt', 'relu', 'sigmoid', 'softmax', 'log_softmax']:
+                 'cbrt', 'rcbrt', 'relu', 'sigmoid', 'softmax', 'log_softmax', 'softmin']:
         check_fluent_regular(func, {}, equal_nan=True)
 
     for func in ['expand_dims', 'flip', 'sort', 'topk', 'argsort', 'argmax', 'argmin']:


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services